├── .gitignore
├── .idea
│   └── .gitignore
├── LICENSE
├── README.md
├── README
│   └── image-20230818115344581.png
├── __init__.py
├── data
│   ├── JSH_dataset_train.py
│   ├── JSH_dataset_val.py
│   ├── __init__.py
│   ├── data_sampler.py
│   └── util.py
├── models
│   ├── CSNorm_model.py
│   ├── __init__.py
│   ├── base_model.py
│   ├── ckpts
│   │   └── NAF_LOL.pth
│   ├── lr_scheduler.py
│   ├── modules
│   │   ├── NAFNet
│   │   │   ├── Baseline_arch.py
│   │   │   ├── NAFNet.py
│   │   │   ├── arch_util.py
│   │   │   └── local_arch.py
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── loss.py
│   │   ├── loss_new.py
│   │   └── module_util.py
│   └── networks.py
├── options
│   ├── __init__.py
│   ├── options.py
│   ├── test
│   │   └── test.yml
│   └── train
│       └── train_InvDN.yml
├── test.py
├── train.py
└── utils
    ├── __init__.py
    ├── pytorch_ssim
    │   └── __init__.py
    └── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | *.iml
162 | *.xml
163 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 mingde-yao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # [ICCV 2023 :fire:] Generalized Lightness Adaptation with Channel Selective Normalization.
5 |
6 | [Mingde Yao](https://scholar.google.com/citations?user=fsE3MzwAAAAJ&hl=en)\*, [Jie Huang](https://huangkevinj.github.io/)\*, [Xin Jin](http://home.ustc.edu.cn/~jinxustc/), [Ruikang Xu](https://scholar.google.com/citations?user=PulrrscAAAAJ&hl=en), Shenglong Zhou, [Man Zhou](https://manman1995.github.io/), [Zhiwei Xiong](http://staff.ustc.edu.cn/~zwxiong/)
7 |
8 | University of Science and Technology of China
9 |
10 | Eastern Institute of Technology
11 |
12 | Nanyang Technological University
13 |
14 |
15 | [[`Paper`](https://arxiv.org/pdf/2308.13783.pdf)] [[`BibTeX`](#heart-citing-us)] :zap: :rocket: :fire:
16 |
17 | [](https://github.com/pre-commit/pre-commit)
18 | [](https://pytorch.org/get-started/locally/)
19 | [](#license)
20 |
21 | :rocket: Welcome! This is the official repository of [ICCV'23] Generalized Lightness Adaptation with Channel Selective Normalization.
22 |
23 |
24 |
25 |
26 |
27 | ## 📌 Overview
28 |
29 | >Lightness adaptation is vital to the success of image processing to avoid unexpected visual deterioration, which covers multiple aspects, e.g., low-light image enhancement, image retouching, and inverse tone mapping. Existing methods typically work well on their trained lightness conditions but perform poorly in unknown ones due to their limited generalization ability. To address this limitation, we propose a novel generalized lightness adaptation algorithm that extends conventional normalization techniques through a channel filtering design, dubbed Channel Selective Normalization (CSNorm). The proposed CSNorm purposely normalizes the statistics of lightness-relevant channels and keeps other channels unchanged, so as to improve feature generalization and discrimination. To optimize CSNorm, we propose an alternating training strategy that effectively identifies lightness-relevant channels. The model equipped with our CSNorm only needs to be trained on one lightness condition and can be well generalized to unknown lightness conditions. Experimental results on multiple benchmark datasets demonstrate the effectiveness of CSNorm in enhancing the generalization ability for the existing lightness adaptation methods.
30 |
31 |
32 | 
33 |
34 | Overview of our proposed method. (a) Channel selective normalization (CSNorm), which consists of an instance-level normalization module and a differential gating module. (b) Differential gating module. It outputs a series of on-off switch gates for binarized channel selection in CSNorm. (c) Alternating training strategy. In the first step, we optimize the parameters outside CSNorm to keep an essential ability for lightness adaptation. In the second step, we only update the parameters inside the CSNorm (see (a)&(b)) with lightness-perturbed images. The two steps drive the CSNorm to select channels sensitive to lightness changes, which are normalized in $x_{n+1}$.
35 |
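In code, CSNorm reduces to a per-channel gated blend of instance-normalized and identity features (cf. the `forward()` snippet in the Usage section below):

$$x_{n+1} = g \odot \mathrm{IN}(x_n) + (1 - g) \odot x_n$$

where $g \in [0, 1)^C$ is the channel-wise gate produced by the differential gating module and $\mathrm{IN}(\cdot)$ denotes per-channel instance normalization.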
36 |
37 |
43 |
44 |
45 |
46 | ## :sunflower: Results
47 |
48 | 
49 |
50 | Visual comparisons of the generalized image retouching on the MIT-Adobe FiveK dataset. The models are trained on the original dataset and tested on the unseen lightness condition.
51 |
52 |
53 |
54 | ## :rocket: Usage
55 |
56 |
57 |
58 |
59 | To train the model equipped with CSNorm:
60 |
 61 | 1. Modify the training and testing paths in the configuration file (`options/train/train_InvDN.yml`); a sketch of the relevant fields follows this list.
 62 | 2. Run `python train.py -opt options/train/train_InvDN.yml`.
63 | 3. Drink a cup of coffee or have a nice sleep.
64 | 4. Get the trained model.
65 |
66 |
 67 | We employ [NAFNet](https://github.com/mdyao/CSNorm/blob/62056d2ba45c6ab356a29e4a155d2f72c4c87beb/models/modules/NAFNet/NAFNet.py) as our base model to demonstrate the integration of CSNorm.
68 |
69 | Feel free to replace NAFNet with your preferred backbone when incorporating CSNorm:
70 |
71 |
 72 | 1. Define the on-off switch gate function, where `CHANNEL_NUM` (the number of feature channels) should be pre-defined.
73 |
 74 | ```python
 75 | import torch.nn as nn  # assumed import; CHANNEL_NUM must be defined beforehand
 76 |
 77 | class Generate_gate(nn.Module):
 78 |     def __init__(self):
 79 |         super(Generate_gate, self).__init__()
 80 |         self.proj = nn.Sequential(nn.AdaptiveAvgPool2d(1),
 81 |                                   nn.Conv2d(CHANNEL_NUM, CHANNEL_NUM, 1),
 82 |                                   nn.ReLU(),
 83 |                                   nn.Conv2d(CHANNEL_NUM, CHANNEL_NUM, 1),
 84 |                                   nn.ReLU())
 85 |         self.epsilon = 1e-8
 86 |
 87 |     def forward(self, x):
 88 |         # project global context to one coefficient per channel
 89 |         alpha = self.proj(x)
 90 |         # alpha**2 / (alpha**2 + eps) acts as a soft on-off switch in [0, 1)
 91 |         gate = (alpha**2) / (alpha**2 + self.epsilon)
 92 |         return gate
 93 |
 94 |
 95 | def freeze(layer):
 96 |     for child in layer.children():
 97 |         for param in child.parameters():
 98 |             param.requires_grad = False
 99 |
100 |
101 | def freeze_direct(layer):
102 |     for param in layer.parameters():
103 |         param.requires_grad = False
104 | ```
105 |
106 | 2. Initialize CSNorm in the `__init__()` method of your backbone, where `CHANNEL_NUM` should be pre-defined:
107 |
108 | ```python
109 | self.gate = Generate_gate()
110 | for i in range(CHANNEL_NUM):
111 | setattr(self, 'CSN_' + str(i), nn.InstanceNorm2d(1, affine=True))
112 | freeze_direct(getattr(self, 'CSN_' + str(i)))
113 | freeze(self.gate)
114 | ```
115 |
116 | 3. Integrate the following into the `forward()` method of your backbone, where `CHANNEL_NUM` should be pre-defined:
117 |
118 | ```python
119 | x = conv(x)
120 | ...
121 | gate = self.gate(x)
122 | lq_copy = torch.cat([getattr(self, 'CSN_' + str(i))(x[:, i, :, :][:, None, :, :]) for i in range(CHANNEL_NUM)], dim=1)
123 | x = gate * lq_copy + (1 - gate) * x
124 | ```
125 |
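Putting the three steps together, a minimal self-contained sketch (assuming the step-1 definitions above are in scope; the single-conv "backbone" and `CHANNEL_NUM = 32` are placeholders, not the repository's actual NAFNet integration):

```python
import torch
import torch.nn as nn

CHANNEL_NUM = 32  # placeholder feature width

class ToyBlockWithCSNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, CHANNEL_NUM, 3, padding=1)
        self.gate = Generate_gate()                           # step 1
        for i in range(CHANNEL_NUM):                          # step 2
            setattr(self, 'CSN_' + str(i), nn.InstanceNorm2d(1, affine=True))
            freeze_direct(getattr(self, 'CSN_' + str(i)))
        freeze(self.gate)  # CSNorm params stay frozen outside the aug step

    def forward(self, x):                                     # step 3
        x = self.conv(x)
        gate = self.gate(x)
        lq_copy = torch.cat([getattr(self, 'CSN_' + str(i))(x[:, i, :, :][:, None, :, :])
                             for i in range(CHANNEL_NUM)], dim=1)
        return gate * lq_copy + (1 - gate) * x

out = ToyBlockWithCSNorm()(torch.rand(1, 3, 64, 64))  # -> torch.Size([1, 32, 64, 64])
```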
126 | 4. The two-step alternating training strategy is implemented in [models/CSNorm_model.py](https://github.com/mdyao/CSNorm/blob/main/models/CSNorm_model.py); a condensed sketch follows.
127 |
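For reference, the alternation in `optimize_parameters()` condenses to roughly the following (losses and optimizer setup elided; `model`, `opt_outside`, `opt_csnorm`, and `csnorm_param_names` are stand-ins for the objects the model class builds):

```python
def alternating_step(model, opt_outside, opt_csnorm, csnorm_param_names,
                     lq, lq_perturbed, gt, rec_loss):
    # Step 1: update everything EXCEPT CSNorm on the original input.
    for name, p in model.named_parameters():
        p.requires_grad = name not in csnorm_param_names
    opt_outside.zero_grad()
    rec_loss(model(lq), gt).backward()
    opt_outside.step()

    # Step 2: update ONLY the CSNorm parameters (gate + instance norms)
    # on a lightness-perturbed input (amplitude-mixed, see amp_aug).
    for name, p in model.named_parameters():
        p.requires_grad = name in csnorm_param_names
    opt_csnorm.zero_grad()
    rec_loss(model(lq_perturbed), gt).backward()
    opt_csnorm.step()
```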
128 | ## :heart: Citing Us
129 | If you find this repository or our work useful, please consider giving it a star :star: and a citation :t-rex:, which would be greatly appreciated:
130 |
131 | ```bibtex
132 | @inproceedings{yao2023csnorm,
133 | title={Generalized Lightness Adaptation with Channel Selective Normalization},
134 |   author={Yao, Mingde and Huang, Jie and Jin, Xin and Xu, Ruikang and Zhou, Shenglong and Zhou, Man and Xiong, Zhiwei},
135 |   booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
136 | year={2023}
137 | }
138 | ```
139 |
140 |
141 | ## :email: Contact
142 |
143 |
144 |
145 | For any inquiries or questions, please contact me by email (mdyao@mail.ustc.edu.cn).
146 |
--------------------------------------------------------------------------------
/README/image-20230818115344581.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mdyao/CSNorm/49bf5a07ac1c58c8d2c221ac86022698c7f1c897/README/image-20230818115344581.png
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mdyao/CSNorm/49bf5a07ac1c58c8d2c221ac86022698c7f1c897/__init__.py
--------------------------------------------------------------------------------
/data/JSH_dataset_train.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import cv2
4 | import h5py
5 | import torch
6 | import torch.utils.data as data
7 | import data.util as util
8 | import glob
9 | import os
10 |
11 | class JSHDataset(data.Dataset):
12 |     '''
13 |     Read paired GT and LQ (low-quality, e.g. low-light) images from two folders.
14 |     Pairing relies on sorted file order, so the GT and LQ file names must follow
15 |     the same naming convention.
16 |     '''
17 |
18 | def __init__(self, opt):
19 | super(JSHDataset, self).__init__()
20 | self.opt = opt
21 | self.data_type = self.opt['data_type']
22 | self.gtimglist = sorted(glob.glob(os.path.join(self.opt['dataroot_gt'], '*')))
23 | self.inputimglist = sorted(glob.glob(os.path.join(self.opt['dataroot_lq'], '*')))
24 | self.length = len(self.gtimglist)
25 |
26 | def __getitem__(self, index):
27 | self.input = cv2.imread(self.inputimglist[index])
28 | self.gt = cv2.imread(self.gtimglist[index])
29 | GT_size = self.opt['GT_size']
30 |
31 |         # normalize to [0, 1] and change HWC to CHW
32 | input_img = self.input/255.0
33 | gt_img = self.gt/255.0
34 | input_img = input_img.transpose(2,0,1)
35 | gt_img = gt_img.transpose(2,0,1)
36 |
37 | if self.opt['phase'] == 'train':
38 | C, H, W = input_img.shape
39 | x = random.randint(0, W - GT_size)
40 | y = random.randint(0, H - GT_size)
41 | input_img = input_img[:, y:y + GT_size, x:x + GT_size]
42 | gt_img = gt_img[:, y:y + GT_size, x:x + GT_size]
43 |
44 | # augmentation - flip, rotate
45 | input_img, gt_img = util.augment([input_img, gt_img], self.opt['use_flip'],
46 | self.opt['use_rot'])
47 |
48 |         # numpy array to torch tensor
49 | input_img = torch.from_numpy(np.ascontiguousarray(input_img)).float()
50 | gt_img = torch.from_numpy(np.ascontiguousarray(gt_img)).float()
51 |
52 | return {'gt_img': gt_img, 'lq_img': input_img}
53 |
54 | def __len__(self):
55 | return self.length
56 |
57 |
58 |
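A minimal usage sketch for this dataset (the option keys mirror what `__init__`/`__getitem__` read; the paths are placeholders):

```python
from data.JSH_dataset_train import JSHDataset

opt = {
    'data_type': 'img',
    'dataroot_gt': '/path/to/train/gt',   # placeholder
    'dataroot_lq': '/path/to/train/low',  # placeholder
    'phase': 'train',
    'GT_size': 256,
    'use_flip': True,
    'use_rot': True,
}
sample = JSHDataset(opt)[0]
print(sample['gt_img'].shape, sample['lq_img'].shape)  # torch.Size([3, 256, 256]) each
```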
--------------------------------------------------------------------------------
/data/JSH_dataset_val.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import cv2
4 | import h5py
5 | import torch
6 | import torch.utils.data as data
7 | import data.util as util
8 | import glob
9 | import os
10 |
11 | class JSHDataset(data.Dataset):
12 |     '''
13 |     Read paired GT and LQ (low-quality, e.g. low-light) images from two folders.
14 |     Pairing relies on sorted file order, so the GT and LQ file names must follow
15 |     the same naming convention.
16 |     '''
17 |
18 | def __init__(self, opt):
19 | super(JSHDataset, self).__init__()
20 | self.opt = opt
21 | self.data_type = self.opt['data_type']
22 | self.gtimglist = sorted(glob.glob(os.path.join(self.opt['dataroot_gt'], '*')))
23 | self.inputimglist = sorted(glob.glob(os.path.join(self.opt['dataroot_lq'], '*')))
24 | self.length = len(self.gtimglist)
25 |
26 | def __getitem__(self, index):
27 | self.input = cv2.imread(self.inputimglist[index])
28 | self.gt = cv2.imread(self.gtimglist[index])
29 | GT_size = self.opt['GT_size']
30 |
31 |         # normalize to [0, 1] and change HWC to CHW
32 | input_img = self.input/255.0
33 | gt_img = self.gt/255.0
34 | input_img = input_img.transpose(2,0,1)
35 | gt_img = gt_img.transpose(2,0,1)
36 |
37 | if self.opt['phase'] == 'train':
38 | C, H, W = input_img.shape
39 | x = random.randint(0, W - GT_size)
40 | y = random.randint(0, H - GT_size)
41 | # input_img = input_img[:, y:y + GT_size, x:x + GT_size]
42 |
43 | # augmentation - flip, rotate
44 | input_img, gt_img = util.augment([input_img, gt_img], self.opt['use_flip'],
45 | self.opt['use_rot'])
46 |         # numpy array to torch tensor
47 | input_img = torch.from_numpy(np.ascontiguousarray(input_img)).float()
48 | gt_img = torch.from_numpy(np.ascontiguousarray(gt_img)).float()
49 |
50 | return {'gt_img': gt_img, 'lq_img': input_img}
51 |
52 | def __len__(self):
53 | return self.length
54 |
55 |
56 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | '''create dataset and dataloader'''
2 | import logging
3 | import torch
4 | import torch.utils.data
5 |
6 |
7 | def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
8 | phase = dataset_opt['phase']
9 | if phase == 'train':
10 | if opt['dist']:
11 | world_size = torch.distributed.get_world_size()
12 | num_workers = dataset_opt['n_workers']
13 | assert dataset_opt['batch_size'] % world_size == 0
14 | batch_size = dataset_opt['batch_size'] // world_size
15 | shuffle = False
16 | else:
17 | num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
18 | batch_size = dataset_opt['batch_size']
19 | shuffle = True
20 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
21 | num_workers=num_workers, sampler=sampler, drop_last=True,
22 | pin_memory=False)
23 | else:
24 | return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=6,
25 | pin_memory=False)
26 |
27 |
28 | def create_dataset(dataset_opt):
29 | mode = dataset_opt['mode']
30 | if mode == 'JSH_train':
31 | from data.JSH_dataset_train import JSHDataset as D
32 |     elif mode == 'JSH_val':
33 | from data.JSH_dataset_val import JSHDataset as D
34 | else:
35 | raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
36 | dataset = D(dataset_opt)
37 |
38 | logger = logging.getLogger('base')
39 | logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
40 | dataset_opt['name']))
41 | return dataset
42 |
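A hedged sketch of how these factories are typically wired together (the option dicts contain only the keys this module and the JSH datasets read; paths are placeholders):

```python
from data import create_dataset, create_dataloader

dataset_opt = {
    'name': 'demo', 'phase': 'train', 'mode': 'JSH_train', 'data_type': 'img',
    'dataroot_gt': '/path/to/gt', 'dataroot_lq': '/path/to/low',  # placeholders
    'GT_size': 256, 'use_flip': True, 'use_rot': True,
    'batch_size': 8, 'n_workers': 4,
}
opt = {'dist': False, 'gpu_ids': [0]}

train_set = create_dataset(dataset_opt)
train_loader = create_dataloader(train_set, dataset_opt, opt)
```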
--------------------------------------------------------------------------------
/data/data_sampler.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from torch.utils.data.distributed.DistributedSampler
3 | Support enlarging the dataset for *iter-oriented* training, for saving time when restart the
4 | dataloader after each epoch
5 | """
6 | import math
7 | import torch
8 | from torch.utils.data.sampler import Sampler
9 | import torch.distributed as dist
10 |
11 |
12 | class DistIterSampler(Sampler):
13 | """Sampler that restricts data loading to a subset of the dataset.
14 |
15 | It is especially useful in conjunction with
16 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
17 | process can pass a DistributedSampler instance as a DataLoader sampler,
18 | and load a subset of the original dataset that is exclusive to it.
19 |
20 | .. note::
21 | Dataset is assumed to be of constant size.
22 |
23 | Arguments:
24 | dataset: Dataset used for sampling.
25 | num_replicas (optional): Number of processes participating in
26 | distributed training.
27 | rank (optional): Rank of the current process within num_replicas.
28 | """
29 |
30 | def __init__(self, dataset, num_replicas=None, rank=None, ratio=100):
31 | if num_replicas is None:
32 | if not dist.is_available():
33 | raise RuntimeError("Requires distributed package to be available")
34 | num_replicas = dist.get_world_size()
35 | if rank is None:
36 | if not dist.is_available():
37 | raise RuntimeError("Requires distributed package to be available")
38 | rank = dist.get_rank()
39 | self.dataset = dataset
40 | self.num_replicas = num_replicas
41 | self.rank = rank
42 | self.epoch = 0
43 | self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas))
44 | self.total_size = self.num_samples * self.num_replicas
45 |
46 | def __iter__(self):
47 | # deterministically shuffle based on epoch
48 | g = torch.Generator()
49 | g.manual_seed(self.epoch)
50 | indices = torch.randperm(self.total_size, generator=g).tolist()
51 |
52 | dsize = len(self.dataset)
53 | indices = [v % dsize for v in indices]
54 |
55 | # subsample
56 | indices = indices[self.rank:self.total_size:self.num_replicas]
57 | assert len(indices) == self.num_samples
58 |
59 | return iter(indices)
60 |
61 | def __len__(self):
62 | return self.num_samples
63 |
64 | def set_epoch(self, epoch):
65 | self.epoch = epoch
66 |
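With `ratio=100`, one sampler "epoch" yields 100 passes' worth of indices, so the dataloader restarts far less often. A quick single-process sanity check (`num_replicas` and `rank` passed explicitly since no process group is initialized here):

```python
sampler = DistIterSampler(dataset=range(10), num_replicas=1, rank=0, ratio=100)
print(len(sampler))  # 1000 == ceil(10 * 100 / 1)
assert all(0 <= i < 10 for i in iter(sampler))  # indices wrap around the real dataset
```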
--------------------------------------------------------------------------------
/data/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import pickle
4 | import random
5 | import numpy as np
6 | import torch
7 | import cv2
8 | import h5py
9 |
10 | ####################
11 | # Files & IO
12 | ####################
13 |
14 | ###################### get image path list ######################
15 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
16 |
17 |
18 | def is_image_file(filename):
19 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
20 |
21 |
22 | def _get_paths_from_images(path):
23 | '''get image path list from image folder'''
24 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
25 | images = []
26 | for dirpath, _, fnames in sorted(os.walk(path)):
27 | for fname in sorted(fnames):
28 | if is_image_file(fname):
29 | img_path = os.path.join(dirpath, fname)
30 | images.append(img_path)
31 | assert images, '{:s} has no valid image file'.format(path)
32 | return images
33 |
34 |
35 | def _get_paths_from_lmdb(dataroot):
36 | '''get image path list from lmdb meta info'''
37 | meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb'))
38 | paths = meta_info['keys']
39 | sizes = meta_info['resolution']
40 | if len(sizes) == 1:
41 | sizes = sizes * len(paths)
42 | return paths, sizes
43 |
 44 | def _get_paths_from_mat(dataroot):
 45 |     '''get the image key and spatial size from a .mat (HDF5) meta file'''
 46 |     meta_info = h5py.File(dataroot, 'r')
 47 |     key = list(meta_info.keys())[0]
 48 |     sizes = meta_info[key].shape[-2:]
 49 |     return key, sizes
50 |
51 | def get_image_paths(data_type, dataroot):
52 | '''get image path list
53 | support lmdb or image files'''
54 | paths, sizes = None, None
55 | if dataroot is not None:
56 | if data_type == 'mat':
57 | paths, sizes = _get_paths_from_mat(dataroot)
58 | elif data_type == 'img':
59 | paths = sorted(_get_paths_from_images(dataroot))
60 | else:
61 | raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type))
62 | return paths, sizes
63 |
64 |
65 | ###################### read images ######################
66 | def _read_img_lmdb(env, key, size):
67 | '''read image from lmdb with key (w/ and w/o fixed size)
68 | size: (C, H, W) tuple'''
69 | with env.begin(write=False) as txn:
70 | buf = txn.get(key.encode('ascii'))
71 | img_flat = np.frombuffer(buf, dtype=np.uint8)
72 | C, H, W = size
73 | img = img_flat.reshape(H, W, C)
74 | return img
75 |
76 |
77 | def read_img(env, path, size=None):
78 | '''read image by cv2 or from lmdb
79 | return: Numpy float32, HWC, BGR, [0,1]'''
80 | if env is None: # img
81 | #img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
82 | img = cv2.imread(path, cv2.IMREAD_COLOR)
83 | else:
84 | img = _read_img_lmdb(env, path, size)
85 | img = img.astype(np.float32) / 255.
86 | if img.ndim == 2:
87 | img = np.expand_dims(img, axis=2)
88 | # some images have 4 channels
89 | if img.shape[2] > 3:
90 | img = img[:, :, :3]
91 | return img
92 |
93 | def read_img_array(img):
94 | '''read image array and preprocess
95 | return: Numpy float32, HWC, BGR, [0,1]'''
96 | img = img.astype(np.float32) / 255.
97 | if img.ndim == 2:
98 | img = np.expand_dims(img, axis=2)
99 | return img
100 |
101 | ####################
102 | # image processing
103 | # process on numpy image
104 | ####################
105 |
106 |
107 | def augment(img_list, hflip=True, rot=True):
108 | # horizontal flip OR rotate
109 | hflip = hflip and random.random() < 0.5
110 | vflip = rot and random.random() < 0.5
111 | rot90 = rot and random.random() < 0.5
112 |
113 | def _augment(img):
114 | if isinstance(img, list):
115 | if hflip:
116 | img = [image[:, ::-1, :] for image in img]
117 | if vflip:
118 | img = [image[:, :, ::-1] for image in img]
119 | if rot90:
120 | img = [image.transpose(0, 2, 1) for image in img]
121 | else:
122 | if hflip:
123 | img = img[:, ::-1, :]
124 | if vflip:
125 | img = img[:, :, ::-1]
126 | if rot90:
127 | img = img.transpose(0, 2, 1)
128 | return img
129 |
130 | return [_augment(img) for img in img_list]
131 |
132 |
133 | def augment_flow(img_list, flow_list, hflip=True, rot=True):
134 | # horizontal flip OR rotate
135 | hflip = hflip and random.random() < 0.5
136 | vflip = rot and random.random() < 0.5
137 | rot90 = rot and random.random() < 0.5
138 |
139 | def _augment(img):
140 | if hflip:
141 | img = img[:, ::-1, :]
142 | if vflip:
143 | img = img[::-1, :, :]
144 | if rot90:
145 | img = img.transpose(1, 0, 2)
146 | return img
147 |
148 | def _augment_flow(flow):
149 | if hflip:
150 | flow = flow[:, ::-1, :]
151 | flow[:, :, 0] *= -1
152 | if vflip:
153 | flow = flow[::-1, :, :]
154 | flow[:, :, 1] *= -1
155 | if rot90:
156 | flow = flow.transpose(1, 0, 2)
157 | flow = flow[:, :, [1, 0]]
158 | return flow
159 |
160 | rlt_img_list = [_augment(img) for img in img_list]
161 | rlt_flow_list = [_augment_flow(flow) for flow in flow_list]
162 |
163 | return rlt_img_list, rlt_flow_list
164 |
165 |
166 | def channel_convert(in_c, tar_type, img_list):
167 | # conversion among BGR, gray and y
168 | if in_c == 3 and tar_type == 'gray': # BGR to gray
169 | gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
170 | return [np.expand_dims(img, axis=2) for img in gray_list]
171 | elif in_c == 3 and tar_type == 'y': # BGR to y
172 | y_list = [bgr2ycbcr(img, only_y=False) for img in img_list]
173 | return y_list
174 | # return [np.expand_dims(img, axis=2) for img in y_list]
175 | elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
176 | return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
177 | else:
178 | return img_list
179 |
180 |
181 | def rgb2ycbcr(img, only_y=True):
182 | '''same as matlab rgb2ycbcr
183 | only_y: only return Y channel
184 | Input:
185 | uint8, [0, 255]
186 | float, [0, 1]
187 | '''
188 | in_img_type = img.dtype
189 |     img = img.astype(np.float32)
190 | if in_img_type != np.uint8:
191 | img *= 255.
192 | # convert
193 | if only_y:
194 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
195 | else:
196 | rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
197 | [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
198 | if in_img_type == np.uint8:
199 | rlt = rlt.round()
200 | else:
201 | rlt /= 255.
202 | return rlt.astype(in_img_type)
203 |
204 |
205 | def bgr2ycbcr(img, only_y=True):
206 | '''bgr version of rgb2ycbcr
207 | only_y: only return Y channel
208 | Input:
209 | uint8, [0, 255]
210 | float, [0, 1]
211 | '''
212 | in_img_type = img.dtype
213 |     img = img.astype(np.float32)
214 | if in_img_type != np.uint8:
215 | img *= 255.
216 | # convert
217 | if only_y:
218 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
219 | else:
220 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
221 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
222 | if in_img_type == np.uint8:
223 | rlt = rlt.round()
224 | else:
225 | rlt /= 255.
226 | return rlt.astype(in_img_type)
227 |
228 |
229 | def ycbcr2rgb(img):
230 | '''same as matlab ycbcr2rgb
231 | Input:
232 | uint8, [0, 255]
233 | float, [0, 1]
234 | '''
235 | in_img_type = img.dtype
236 |     img = img.astype(np.float32)
237 | if in_img_type != np.uint8:
238 | img *= 255.
239 | # convert
240 | rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
241 | [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
242 | if in_img_type == np.uint8:
243 | rlt = rlt.round()
244 | else:
245 | rlt /= 255.
246 | return rlt.astype(in_img_type)
247 |
248 |
249 | def modcrop(img_in, scale):
250 | # img_in: Numpy, CHW or HW
251 | img = np.copy(img_in)
252 | if img.ndim == 2:
253 | H, W = img.shape
254 | H_r, W_r = H % scale, W % scale
255 | img = img[:H - H_r, :W - W_r]
256 | elif img.ndim == 3:
257 | C, H, W = img.shape
258 | H_r, W_r = H % scale, W % scale
259 | img = img[:, :H - H_r, :W - W_r]
260 | else:
261 | raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
262 | return img
263 |
264 |
265 | ####################
266 | # Functions
267 | ####################
268 |
269 |
270 | # matlab 'imresize' function, now only support 'bicubic'
271 | def cubic(x):
272 | absx = torch.abs(x)
273 | absx2 = absx**2
274 | absx3 = absx**3
275 | return (1.5 * absx3 - 2.5 * absx2 + 1) * (
276 | (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * ((
277 | (absx > 1) * (absx <= 2)).type_as(absx))
278 |
279 |
280 | def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
281 | if (scale < 1) and (antialiasing):
282 | # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width
283 | kernel_width = kernel_width / scale
284 |
285 | # Output-space coordinates
286 | x = torch.linspace(1, out_length, out_length)
287 |
288 | # Input-space coordinates. Calculate the inverse mapping such that 0.5
289 | # in output space maps to 0.5 in input space, and 0.5+scale in output
290 | # space maps to 1.5 in input space.
291 | u = x / scale + 0.5 * (1 - 1 / scale)
292 |
293 | # What is the left-most pixel that can be involved in the computation?
294 | left = torch.floor(u - kernel_width / 2)
295 |
296 | # What is the maximum number of pixels that can be involved in the
297 | # computation? Note: it's OK to use an extra pixel here; if the
298 | # corresponding weights are all zero, it will be eliminated at the end
299 | # of this function.
300 | P = math.ceil(kernel_width) + 2
301 |
302 | # The indices of the input pixels involved in computing the k-th output
303 | # pixel are in row k of the indices matrix.
304 | indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
305 | 1, P).expand(out_length, P)
306 |
307 | # The weights used to compute the k-th output pixel are in row k of the
308 | # weights matrix.
309 | distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
310 | # apply cubic kernel
311 | if (scale < 1) and (antialiasing):
312 | weights = scale * cubic(distance_to_center * scale)
313 | else:
314 | weights = cubic(distance_to_center)
315 | # Normalize the weights matrix so that each row sums to 1.
316 | weights_sum = torch.sum(weights, 1).view(out_length, 1)
317 | weights = weights / weights_sum.expand(out_length, P)
318 |
319 | # If a column in weights is all zero, get rid of it. only consider the first and last column.
320 | weights_zero_tmp = torch.sum((weights == 0), 0)
321 | if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
322 | indices = indices.narrow(1, 1, P - 2)
323 | weights = weights.narrow(1, 1, P - 2)
324 | if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
325 | indices = indices.narrow(1, 0, P - 2)
326 | weights = weights.narrow(1, 0, P - 2)
327 | weights = weights.contiguous()
328 | indices = indices.contiguous()
329 | sym_len_s = -indices.min() + 1
330 | sym_len_e = indices.max() - in_length
331 | indices = indices + sym_len_s - 1
332 | return weights, indices, int(sym_len_s), int(sym_len_e)
333 |
334 |
335 | def imresize(img, scale, antialiasing=True):
336 | # Now the scale should be the same for H and W
337 | # input: img: CHW RGB [0,1]
338 | # output: CHW RGB [0,1] w/o round
339 |
340 | in_C, in_H, in_W = img.size()
341 | _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
342 | kernel_width = 4
343 | kernel = 'cubic'
344 |
345 | # Return the desired dimension order for performing the resize. The
346 | # strategy is to perform the resize first along the dimension with the
347 | # smallest scale factor.
348 | # Now we do not support this.
349 |
350 | # get weights and indices
351 | weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
352 | in_H, out_H, scale, kernel, kernel_width, antialiasing)
353 | weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
354 | in_W, out_W, scale, kernel, kernel_width, antialiasing)
355 | # process H dimension
356 | # symmetric copying
357 | img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
358 | img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
359 |
360 | sym_patch = img[:, :sym_len_Hs, :]
361 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
362 | sym_patch_inv = sym_patch.index_select(1, inv_idx)
363 | img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
364 |
365 | sym_patch = img[:, -sym_len_He:, :]
366 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
367 | sym_patch_inv = sym_patch.index_select(1, inv_idx)
368 | img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
369 |
370 | out_1 = torch.FloatTensor(in_C, out_H, in_W)
371 | kernel_width = weights_H.size(1)
372 | for i in range(out_H):
373 | idx = int(indices_H[i][0])
374 | out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
375 | out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
376 | out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
377 |
378 | # process W dimension
379 | # symmetric copying
380 | out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
381 | out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
382 |
383 | sym_patch = out_1[:, :, :sym_len_Ws]
384 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
385 | sym_patch_inv = sym_patch.index_select(2, inv_idx)
386 | out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
387 |
388 | sym_patch = out_1[:, :, -sym_len_We:]
389 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
390 | sym_patch_inv = sym_patch.index_select(2, inv_idx)
391 | out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
392 |
393 | out_2 = torch.FloatTensor(in_C, out_H, out_W)
394 | kernel_width = weights_W.size(1)
395 | for i in range(out_W):
396 | idx = int(indices_W[i][0])
397 | out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(weights_W[i])
398 | out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(weights_W[i])
399 | out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(weights_W[i])
400 |
401 | return out_2
402 |
403 |
404 | def imresize_np(img, scale, antialiasing=True):
405 | # Now the scale should be the same for H and W
406 | # input: img: Numpy, HWC BGR [0,1]
407 | # output: HWC BGR [0,1] w/o round
408 | img = torch.from_numpy(img)
409 |
410 | in_H, in_W, in_C = img.size()
411 | _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
412 | kernel_width = 4
413 | kernel = 'cubic'
414 |
415 | # Return the desired dimension order for performing the resize. The
416 | # strategy is to perform the resize first along the dimension with the
417 | # smallest scale factor.
418 | # Now we do not support this.
419 |
420 | # get weights and indices
421 | weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
422 | in_H, out_H, scale, kernel, kernel_width, antialiasing)
423 | weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
424 | in_W, out_W, scale, kernel, kernel_width, antialiasing)
425 | # process H dimension
426 | # symmetric copying
427 | img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
428 | img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
429 |
430 | sym_patch = img[:sym_len_Hs, :, :]
431 | inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
432 | sym_patch_inv = sym_patch.index_select(0, inv_idx)
433 | img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
434 |
435 | sym_patch = img[-sym_len_He:, :, :]
436 | inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
437 | sym_patch_inv = sym_patch.index_select(0, inv_idx)
438 | img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
439 |
440 | out_1 = torch.FloatTensor(out_H, in_W, in_C)
441 | kernel_width = weights_H.size(1)
442 | for i in range(out_H):
443 | idx = int(indices_H[i][0])
444 | out_1[i, :, 0] = img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i])
445 | out_1[i, :, 1] = img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i])
446 | out_1[i, :, 2] = img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i])
447 |
448 | # process W dimension
449 | # symmetric copying
450 | out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
451 | out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
452 |
453 | sym_patch = out_1[:, :sym_len_Ws, :]
454 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
455 | sym_patch_inv = sym_patch.index_select(1, inv_idx)
456 | out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
457 |
458 | sym_patch = out_1[:, -sym_len_We:, :]
459 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
460 | sym_patch_inv = sym_patch.index_select(1, inv_idx)
461 | out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
462 |
463 | out_2 = torch.FloatTensor(out_H, out_W, in_C)
464 | kernel_width = weights_W.size(1)
465 | for i in range(out_W):
466 | idx = int(indices_W[i][0])
467 | out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].mv(weights_W[i])
468 | out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].mv(weights_W[i])
469 | out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].mv(weights_W[i])
470 |
471 | return out_2.numpy()
472 |
473 |
474 | if __name__ == '__main__':
475 | # test imresize function
476 | # read images
477 | img = cv2.imread('test.png')
478 | img = img * 1.0 / 255
479 | img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
480 | # imresize
481 | scale = 1 / 4
482 | import time
483 | total_time = 0
484 | for i in range(10):
485 | start_time = time.time()
486 | rlt = imresize(img, scale, antialiasing=True)
487 | use_time = time.time() - start_time
488 | total_time += use_time
489 | print('average time: {}'.format(total_time / 10))
490 |
491 | import torchvision.utils
492 | torchvision.utils.save_image((rlt * 255).round() / 255, 'rlt.png', nrow=1, padding=0,
493 | normalize=False)
494 |
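Note that `augment` applies the same randomly drawn flips/transpose to every array in the list, which is what keeps LQ/GT pairs aligned; the datasets in this repo call it on CHW arrays. A quick sketch:

```python
import numpy as np

chw = np.arange(3 * 4 * 5, dtype=np.float32).reshape(3, 4, 5)  # C, H, W
out_lq, out_gt = augment([chw, chw], hflip=True, rot=True)
assert np.array_equal(out_lq, out_gt)  # identical ops on both -> pairs stay aligned
```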
--------------------------------------------------------------------------------
/models/CSNorm_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import OrderedDict
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn.parallel import DataParallel, DistributedDataParallel
7 | import models.networks as networks
8 | import models.lr_scheduler as lr_scheduler
9 | from .base_model import BaseModel
10 | from models.modules.loss import FFT_Loss
11 | import numpy as np
12 | import time
13 | from models.modules.loss_new import SSIMLoss
14 | import re
15 |
16 | logger = logging.getLogger('base')
17 |
18 |
19 | class CSNorm_Model(BaseModel):
20 | def __init__(self, opt):
21 | super(CSNorm_Model, self).__init__(opt)
22 |
23 |
24 | if opt['dist']:
25 | self.rank = torch.distributed.get_rank()
26 | else:
27 | self.rank = -1 # non dist training
28 | train_opt = opt['train']
29 | test_opt = opt['test']
30 | self.train_opt = train_opt
31 | self.test_opt = test_opt
32 |
33 | self.netG = networks.define_G(opt).to(self.device)
34 | if opt['dist']:
35 | self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
36 | else:
37 | self.netG = DataParallel(self.netG)
38 |
39 | ######################### set parameters in CSNorm ###############################
40 | target_layer_patterns = re.compile(r'module\.(gate\.proj|CSN_\d+)\.')
41 | # target_layer_patterns = re.compile(r'(gate\.proj|CSN_\d+)\.')
42 |
43 | self.layer_aug = [
44 | name for name, param in self.netG.named_parameters()
45 | if target_layer_patterns.search(name)
46 | ]
47 | # print('parameters in CSNorm:',self.layer_aug)
48 | ######################### set parameters in CSNorm ###############################
49 |
50 | # loss
51 | self.Back_rec = torch.nn.L1Loss()
52 | self.ssim_loss = SSIMLoss()
53 | self.fft_loss = FFT_Loss()
54 | # self.print_network()
55 | self.load()
56 |
57 | if self.is_train:
58 | self.netG.train()
59 |
60 | # optimizers
61 | wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
62 | optim_params = []
63 | optim_params_aug = []
64 | for k, v in self.netG.named_parameters():
65 | if v.requires_grad:
66 | optim_params.append(v)
67 | else:
68 | if self.rank <= 0:
 69 |                     logger.warning('Params [{:s}] will not be optimized.'.format(k))
70 |
71 | for k, v in self.netG.named_parameters():
72 | if k in self.layer_aug:
73 | optim_params_aug.append(v)
74 | else:
75 | if self.rank <= 0:
 76 |                     logger.warning('Params [{:s}] will not be optimized in the aug step.'.format(k))
77 |
78 |
79 | self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
80 | weight_decay=wd_G,
81 | betas=(train_opt['beta1'], train_opt['beta2']))
82 |
83 | self.optimizer_G_aug = torch.optim.Adam(optim_params_aug, lr=train_opt['lr_G'],
84 | weight_decay=wd_G,
85 | betas=(train_opt['beta1'], train_opt['beta2']))
86 |
87 | self.optimizers.append(self.optimizer_G)
88 | self.optimizers.append(self.optimizer_G_aug)
89 |
90 | # schedulers
91 | if train_opt['lr_scheme'] == 'MultiStepLR':
92 | for optimizer in self.optimizers:
93 | self.schedulers.append(
94 | lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
95 | restarts=train_opt['restarts'],
96 | weights=train_opt['restart_weights'],
97 | gamma=train_opt['lr_gamma'],
98 | clear_state=train_opt['clear_state']))
99 | elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
100 | for optimizer in self.optimizers:
101 | self.schedulers.append(
102 | lr_scheduler.CosineAnnealingLR_Restart(
103 | optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
104 | restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
105 | else:
106 |             raise NotImplementedError('Only MultiStepLR and CosineAnnealingLR_Restart schemes are supported.')
107 |
108 | self.log_dict = OrderedDict()
109 |
110 |     def amp_aug(self, x, y):
111 |         x = x + 1e-8
112 |         y = y + 1e-8
113 |         x_freq = torch.fft.rfft2(x, norm='backward')
114 |         x_amp = torch.abs(x_freq)
115 |         x_phase = torch.angle(x_freq)
116 |
117 |         y_freq = torch.fft.rfft2(y, norm='backward')
118 |         y_amp = torch.abs(y_freq)
119 |         y_phase = torch.angle(y_freq)
120 |
121 |         mix_alpha = torch.rand(1).to(self.device) / 0.5
122 |         mix_alpha = torch.clip(mix_alpha, 0, 0.5)  # random mixing weight in [0, 0.5]
123 |         y_amp = mix_alpha * y_amp + (1 - mix_alpha) * x_amp  # blend amplitude of y toward x
124 |
125 |         real = y_amp * torch.cos(y_phase)  # recombine mixed amplitude with the phase of y
126 |         imag = y_amp * torch.sin(y_phase)
127 |         y_out = torch.complex(real, imag) + 1e-8
128 |         y_out = torch.fft.irfft2(y_out) + 1e-8
129 |
130 |         return y_out
131 |
132 | def feed_data(self, data):
133 | self.img_gt = data['gt_img'].to(self.device)
134 | self.img_input = data['lq_img'].to(self.device)
135 | self.img_input_aug = self.amp_aug(self.img_gt, self.img_input)
136 |
137 | def feed_data_test(self, data):
138 | # self.ref_L = data['LQ'].to(self.device)
139 | self.img_gt = data['gt_img'].to(self.device)
140 | self.img_input = data['lq_img'].to(self.device)
141 |
142 |     def loss_forward(self, img, gt):
143 | loss = 1 * self.Back_rec(img, gt)
144 | loss_ssim = self.ssim_loss(img, gt)
145 |
146 | return loss, loss_ssim
147 |
148 |
149 |     def loss_forward_aug(self, img, gt):
150 | loss = 1 * self.Back_rec(img, gt)
151 | loss_ssim = self.ssim_loss(img, gt)
152 |
153 | l_amp, _ = self.fft_loss(img, gt)
154 | return loss, loss_ssim, l_amp
155 |
156 | def optimize_parameters(self, step):
157 |
158 | ############## optimize parameters outside CSNorm ############################
159 | for k, v in self.netG.named_parameters():
160 | if k not in self.layer_aug:
161 | v.requires_grad = True
162 | else:
163 | v.requires_grad = False
164 | self.optimizer_G.zero_grad()
165 |
166 | # forward
167 | self.img_pred = self.netG(self.img_input, aug=True)
168 | loss, l_ssim = self.loss_forward(self.img_pred, self.img_gt)
169 | loss = loss + l_ssim
170 |
171 | # backward
172 | loss.backward()
173 |
174 | # gradient clipping
175 | if self.train_opt['gradient_clipping']:
176 | nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping'])
177 |
178 | self.optimizer_G.step()
179 |
180 |
181 | ############## optimize parameters inside CSNorm ############################
182 | for k, v in self.netG.named_parameters():
183 | if k in self.layer_aug:
184 | v.requires_grad = True
185 | else:
186 | v.requires_grad = False
187 |
188 | self.optimizer_G_aug.zero_grad()
189 |
190 | # forward
191 | self.img_pred = self.netG(self.img_input_aug, aug=True)
192 | loss_back, l_ssim, l_amp = self.loss_forward_aug(self.img_pred, self.img_gt)
193 | loss_aug = loss_back + l_ssim + l_amp
194 | # backward
195 | loss_aug.backward()
196 |
197 | # gradient clipping
198 | if self.train_opt['gradient_clipping']:
199 | nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping'])
200 |
201 | self.optimizer_G_aug.step()
202 |
203 |
204 | # set log
205 | self.log_dict['loss'] = loss.item()
206 | self.log_dict['l_amp'] = l_amp.item()
207 | self.log_dict['l_ssim'] = l_ssim.item()
208 |
209 | def test(self):
210 |
211 | self.netG.eval()
212 | with torch.no_grad():
213 | self.img_pred = self.netG(self.img_input, aug=True)
214 |
215 | self.netG.train()
216 |
217 | def get_current_log(self):
218 | return self.log_dict
219 |
220 | def get_current_visuals(self):
221 | out_dict = OrderedDict()
222 | out_dict['img_pred'] = self.img_pred.detach()[0].float().cpu()
223 | out_dict['img_input'] = self.img_input.detach()[0].float().cpu()
224 | out_dict['img_gt'] = self.img_gt.detach()[0].float().cpu()
225 | return out_dict
226 |
227 | def print_network(self):
228 | s, n = self.get_network_description(self.netG)
229 | if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel):
230 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
231 | self.netG.module.__class__.__name__)
232 | else:
233 | net_struc_str = '{}'.format(self.netG.__class__.__name__)
234 | if self.rank <= 0:
235 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
236 | logger.info(s)
237 |
238 | def load(self):
239 |
240 | load_path_G = self.opt['path']['pretrain_model_G']
241 | if load_path_G is not None:
242 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
243 | self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
244 |
245 | def save(self, iter_label):
246 | self.save_network(self.netG, 'G', iter_label)
247 |
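A small sketch of how the regex above partitions parameters between the two optimizers (the names are illustrative; with `DataParallel`, parameter names carry the `module.` prefix the pattern expects):

```python
import re

pattern = re.compile(r'module\.(gate\.proj|CSN_\d+)\.')
names = [
    'module.conv.weight',         # backbone        -> optimizer_G
    'module.gate.proj.1.weight',  # gating module   -> optimizer_G_aug
    'module.CSN_0.weight',        # instance norms  -> optimizer_G_aug
]
print([n for n in names if pattern.search(n)])      # the two CSNorm entries
print([n for n in names if not pattern.search(n)])  # ['module.conv.weight']
```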
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | logger = logging.getLogger('base')
3 |
4 |
5 | def create_model(opt):
6 | model = opt['model']
7 |
8 | if model == 'CSNorm':
9 | from .CSNorm_model import CSNorm_Model as M
10 | else:
11 | raise NotImplementedError('Model [{:s}] not recognized.'.format(model))
12 | m = M(opt)
13 | logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
14 | return m
15 |
--------------------------------------------------------------------------------
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.parallel import DistributedDataParallel
6 |
7 |
8 | class BaseModel():
9 | def __init__(self, opt):
10 | self.opt = opt
11 | self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu')
12 | self.is_train = opt['is_train']
13 | self.schedulers = []
14 | self.optimizers = []
15 |
16 | def feed_data(self, data):
17 | pass
18 |
19 | def optimize_parameters(self):
20 | pass
21 |
22 | def get_current_visuals(self):
23 | pass
24 |
25 | def get_current_losses(self):
26 | pass
27 |
28 | def print_network(self):
29 | pass
30 |
31 | def save(self, label):
32 | pass
33 |
34 | def load(self):
35 | pass
36 |
37 | def _set_lr(self, lr_groups_l):
 38 |         '''Set the learning rate for warmup.
 39 |         lr_groups_l: list of lr groups, one for each optimizer.'''
40 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
41 | for param_group, lr in zip(optimizer.param_groups, lr_groups):
42 | param_group['lr'] = lr
43 |
44 | def _get_init_lr(self):
45 | # get the initial lr, which is set by the scheduler
46 | init_lr_groups_l = []
47 | for optimizer in self.optimizers:
48 | init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
49 | return init_lr_groups_l
50 |
51 | def update_learning_rate(self, cur_iter, warmup_iter=-1):
52 | for scheduler in self.schedulers:
53 | scheduler.step()
54 | #### set up warm up learning rate
55 | if cur_iter < warmup_iter:
56 | # get initial lr for each group
57 | init_lr_g_l = self._get_init_lr()
58 | # modify warming-up learning rates
59 | warm_up_lr_l = []
60 | for init_lr_g in init_lr_g_l:
61 | warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g])
62 | # set learning rate
63 | self._set_lr(warm_up_lr_l)
64 |
65 | def get_current_learning_rate(self):
66 | # return self.schedulers[0].get_lr()[0]
67 | return self.optimizers[0].param_groups[0]['lr']
68 |
69 | def get_network_description(self, network):
70 | '''Get the string and total parameters of the network'''
71 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
72 | network = network.module
73 | s = str(network)
74 | n = sum(map(lambda x: x.numel(), network.parameters()))
75 | return s, n
76 |
77 | def save_network(self, network, network_label, iter_label):
78 | save_filename = '{}_{}.pth'.format(iter_label, network_label)
79 | save_path = os.path.join(self.opt['path']['models'], save_filename)
80 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
81 | network = network.module
82 | state_dict = network.state_dict()
83 | for key, param in state_dict.items():
84 | state_dict[key] = param.cpu()
85 | torch.save(state_dict, save_path)
86 |
87 | def load_network(self, load_path, network, strict=True):
88 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel):
89 | network = network.module
90 | load_net = torch.load(load_path)
91 | load_net_clean = OrderedDict() # remove unnecessary 'module.'
92 | for k, v in load_net.items():
93 | if k.startswith('module.'):
94 | load_net_clean[k[7:]] = v
95 | else:
96 | load_net_clean[k] = v
97 | network.load_state_dict(load_net_clean, strict=strict)
98 |
99 | def save_training_state(self, epoch, iter_step):
100 | '''Saves training state during training, which will be used for resuming'''
101 | state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []}
102 | for s in self.schedulers:
103 | state['schedulers'].append(s.state_dict())
104 | for o in self.optimizers:
105 | state['optimizers'].append(o.state_dict())
106 | save_filename = '{}.state'.format(iter_step)
107 | save_path = os.path.join(self.opt['path']['training_state'], save_filename)
108 | torch.save(state, save_path)
109 |
110 | def resume_training(self, resume_state):
111 | '''Resume the optimizers and schedulers for training'''
112 | resume_optimizers = resume_state['optimizers']
113 | resume_schedulers = resume_state['schedulers']
114 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
115 | assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
116 | for i, o in enumerate(resume_optimizers):
117 | self.optimizers[i].load_state_dict(o)
118 | for i, s in enumerate(resume_schedulers):
119 | self.schedulers[i].load_state_dict(s)
120 |
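During warmup, `update_learning_rate` linearly ramps each group from zero to its scheduler-assigned initial value; a quick illustration with hypothetical numbers:

```python
warmup_iter, initial_lr = 4000, 2e-4
for cur_iter in (1, 1000, 2000, 4000):
    lr = initial_lr / warmup_iter * cur_iter if cur_iter < warmup_iter else initial_lr
    print(cur_iter, f'{lr:.2e}')  # 5.00e-08, 5.00e-05, 1.00e-04, 2.00e-04
```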
--------------------------------------------------------------------------------
/models/ckpts/NAF_LOL.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mdyao/CSNorm/49bf5a07ac1c58c8d2c221ac86022698c7f1c897/models/ckpts/NAF_LOL.pth
--------------------------------------------------------------------------------
/models/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import Counter
3 | from collections import defaultdict
4 | import torch
5 | from torch.optim.lr_scheduler import _LRScheduler
6 |
7 |
8 | class MultiStepLR_Restart(_LRScheduler):
9 | def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
10 | clear_state=False, last_epoch=-1):
11 | self.milestones = Counter(milestones)
12 | self.gamma = gamma
13 | self.clear_state = clear_state
14 | self.restarts = restarts if restarts else [0]
15 | self.restart_weights = weights if weights else [1]
16 | assert len(self.restarts) == len(
17 | self.restart_weights), 'restarts and their weights do not match.'
18 | super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
19 |
20 | def get_lr(self):
21 | if self.last_epoch in self.restarts:
22 | if self.clear_state:
23 | self.optimizer.state = defaultdict(dict)
24 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
25 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
26 | if self.last_epoch not in self.milestones:
27 | return [group['lr'] for group in self.optimizer.param_groups]
28 | return [
29 | group['lr'] * self.gamma**self.milestones[self.last_epoch]
30 | for group in self.optimizer.param_groups
31 | ]
32 |
33 |
34 | class CosineAnnealingLR_Restart(_LRScheduler):
35 | def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
36 | self.T_period = T_period
37 | self.T_max = self.T_period[0] # current T period
38 | self.eta_min = eta_min
39 | self.restarts = restarts if restarts else [0]
40 | self.restart_weights = weights if weights else [1]
41 | self.last_restart = 0
42 | assert len(self.restarts) == len(
43 | self.restart_weights), 'restarts and their weights do not match.'
44 | super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch)
45 |
46 | def get_lr(self):
47 | if self.last_epoch == 0:
48 | return self.base_lrs
49 | elif self.last_epoch in self.restarts:
50 | self.last_restart = self.last_epoch
51 | self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1]
52 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
53 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
54 | elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0:
55 | return [
56 | group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
57 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
58 | ]
59 | return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) /
60 | (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) *
61 | (group['lr'] - self.eta_min) + self.eta_min
62 | for group in self.optimizer.param_groups]
63 |
64 |
65 | if __name__ == "__main__":
66 | optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0,
67 | betas=(0.9, 0.99))
68 | ##############################
69 | # MultiStepLR_Restart
70 | ##############################
71 |     ## Original (each variant below overwrites the previous one; only the last assignment of lr_steps/restarts is used)
72 | lr_steps = [200000, 400000, 600000, 800000]
73 | restarts = None
74 | restart_weights = None
75 |
76 | ## two
77 | lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000]
78 | restarts = [500000]
79 | restart_weights = [1]
80 |
81 | ## four
82 | lr_steps = [
83 | 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000,
84 | 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000
85 | ]
86 | restarts = [250000, 500000, 750000]
87 | restart_weights = [1, 1, 1]
88 |
89 | scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5,
90 | clear_state=False)
91 |
92 | ##############################
93 | # Cosine Annealing Restart
94 | ##############################
95 | ## two
96 | T_period = [500000, 500000]
97 | restarts = [500000]
98 | restart_weights = [1]
99 |
100 | ## four
101 | T_period = [250000, 250000, 250000, 250000]
102 | restarts = [250000, 500000, 750000]
103 | restart_weights = [1, 1, 1]
104 |
105 | scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts,
106 | weights=restart_weights)
107 |
108 | ##############################
109 | # Draw figure
110 | ##############################
111 | N_iter = 1000000
112 | lr_l = list(range(N_iter))
113 | for i in range(N_iter):
114 | scheduler.step()
115 | current_lr = optimizer.param_groups[0]['lr']
116 | lr_l[i] = current_lr
117 |
118 | import matplotlib as mpl
119 | from matplotlib import pyplot as plt
120 | import matplotlib.ticker as mtick
121 | mpl.style.use('default')
122 | import seaborn
123 | seaborn.set(style='whitegrid')
124 | seaborn.set_context('paper')
125 |
126 | plt.figure(1)
127 | plt.subplot(111)
128 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
129 |     plt.title('Learning rate scheme', fontsize=16, color='k')
130 | plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme')
131 | legend = plt.legend(loc='upper right', shadow=False)
132 | ax = plt.gca()
133 | labels = ax.get_xticks().tolist()
134 | for k, v in enumerate(labels):
135 | labels[k] = str(int(v / 1000)) + 'K'
136 | ax.set_xticklabels(labels)
137 | ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
138 |
139 | ax.set_ylabel('Learning rate')
140 | ax.set_xlabel('Iteration')
141 | fig = plt.gcf()
142 | plt.show()
143 |
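
A minimal usage sketch for the schedulers above (a sketch, assuming this file is importable as models.lr_scheduler): the scheduler is stepped once per training iteration, and the learning rate snaps back to initial_lr * weight at each restart boundary.

    import torch
    from models.lr_scheduler import CosineAnnealingLR_Restart

    # one trainable parameter is enough to exercise the scheduler
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.Adam([param], lr=2e-4)

    # two cosine periods of 10 iterations each, full-strength restart at iteration 10
    scheduler = CosineAnnealingLR_Restart(optimizer, T_period=[10, 10],
                                          restarts=[10], weights=[1], eta_min=1e-7)
    for it in range(20):
        optimizer.step()
        scheduler.step()
        print(it, optimizer.param_groups[0]['lr'])  # jumps back to 2e-4 when last_epoch hits 10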
--------------------------------------------------------------------------------
/models/modules/NAFNet/Baseline_arch.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 |
5 | '''
6 | Simple Baselines for Image Restoration
7 |
8 | @article{chen2022simple,
9 | title={Simple Baselines for Image Restoration},
10 | author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
11 | journal={arXiv preprint arXiv:2204.04676},
12 | year={2022}
13 | }
14 | '''
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | from models.modules.NAFNet.arch_util import LayerNorm2d
20 | from models.modules.NAFNet.local_arch import Local_Base
21 |
22 | class BaselineBlock(nn.Module):
23 | def __init__(self, c, DW_Expand=1, FFN_Expand=2, drop_out_rate=0.):
24 | super().__init__()
25 | dw_channel = c * DW_Expand
26 | self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
27 | self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
28 | bias=True)
29 | self.conv3 = nn.Conv2d(in_channels=dw_channel, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
30 |
31 | # Channel Attention
32 | self.se = nn.Sequential(
33 | nn.AdaptiveAvgPool2d(1),
34 | nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
35 | groups=1, bias=True),
36 | nn.ReLU(inplace=True),
37 | nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel, kernel_size=1, padding=0, stride=1,
38 | groups=1, bias=True),
39 | nn.Sigmoid()
40 | )
41 |
42 | # GELU
43 | self.gelu = nn.GELU()
44 |
45 | ffn_channel = FFN_Expand * c
46 | self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
47 | self.conv5 = nn.Conv2d(in_channels=ffn_channel, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
48 |
49 | self.norm1 = LayerNorm2d(c)
50 | self.norm2 = LayerNorm2d(c)
51 |
52 | self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
53 | self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
54 |
55 | self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
56 | self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
57 |
58 | def forward(self, inp):
59 | x = inp
60 |
61 | x = self.norm1(x)
62 |
63 | x = self.conv1(x)
64 | x = self.conv2(x)
65 | x = self.gelu(x)
66 | x = x * self.se(x)
67 | x = self.conv3(x)
68 |
69 | x = self.dropout1(x)
70 |
71 | y = inp + x * self.beta
72 |
73 | x = self.conv4(self.norm2(y))
74 | x = self.gelu(x)
75 | x = self.conv5(x)
76 |
77 | x = self.dropout2(x)
78 |
79 | return y + x * self.gamma
80 |
81 |
82 | class Baseline(nn.Module):
83 |
84 | def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[], dw_expand=1, ffn_expand=2):
85 | super().__init__()
86 |
87 | self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
88 | bias=True)
89 | self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
90 | bias=True)
91 |
92 | self.encoders = nn.ModuleList()
93 | self.decoders = nn.ModuleList()
94 | self.middle_blks = nn.ModuleList()
95 | self.ups = nn.ModuleList()
96 | self.downs = nn.ModuleList()
97 |
98 | chan = width
99 | for num in enc_blk_nums:
100 | self.encoders.append(
101 | nn.Sequential(
102 | *[BaselineBlock(chan, dw_expand, ffn_expand) for _ in range(num)]
103 | )
104 | )
105 | self.downs.append(
106 | nn.Conv2d(chan, 2*chan, 2, 2)
107 | )
108 | chan = chan * 2
109 |
110 | self.middle_blks = \
111 | nn.Sequential(
112 | *[BaselineBlock(chan, dw_expand, ffn_expand) for _ in range(middle_blk_num)]
113 | )
114 |
115 | for num in dec_blk_nums:
116 | self.ups.append(
117 | nn.Sequential(
118 | nn.Conv2d(chan, chan * 2, 1, bias=False),
119 | nn.PixelShuffle(2)
120 | )
121 | )
122 | chan = chan // 2
123 | self.decoders.append(
124 | nn.Sequential(
125 | *[BaselineBlock(chan, dw_expand, ffn_expand) for _ in range(num)]
126 | )
127 | )
128 |
129 | self.padder_size = 2 ** len(self.encoders)
130 |
131 | def forward(self, inp):
132 | B, C, H, W = inp.shape
133 | inp = self.check_image_size(inp)
134 |
135 | x = self.intro(inp)
136 |
137 | encs = []
138 |
139 | for encoder, down in zip(self.encoders, self.downs):
140 | x = encoder(x)
141 | encs.append(x)
142 | x = down(x)
143 |
144 | x = self.middle_blks(x)
145 |
146 | for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
147 | x = up(x)
148 | x = x + enc_skip
149 | x = decoder(x)
150 |
151 | x = self.ending(x)
152 | x = x + inp
153 |
154 | return x[:, :, :H, :W]
155 |
156 | def check_image_size(self, x):
157 | _, _, h, w = x.size()
158 | mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
159 | mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
160 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
161 | return x
162 |
163 | class BaselineLocal(Local_Base, Baseline):
164 | def __init__(self, *args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs):
165 | Local_Base.__init__(self)
166 | Baseline.__init__(self, *args, **kwargs)
167 |
168 | N, C, H, W = train_size
169 | base_size = (int(H * 1.5), int(W * 1.5))
170 |
171 | self.eval()
172 | with torch.no_grad():
173 | self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)
174 |
175 | if __name__ == '__main__':
176 | img_channel = 3
177 | width = 32
178 |
179 | dw_expand = 1
180 | ffn_expand = 2
181 |
182 | # enc_blks = [2, 2, 4, 8]
183 | # middle_blk_num = 12
184 | # dec_blks = [2, 2, 2, 2]
185 |
186 | enc_blks = [1, 1, 1, 28]
187 | middle_blk_num = 1
188 | dec_blks = [1, 1, 1, 1]
189 |
190 | net = Baseline(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
191 | enc_blk_nums=enc_blks, dec_blk_nums=dec_blks, dw_expand=dw_expand, ffn_expand=ffn_expand)
192 |
193 | inp_shape = (3, 256, 256)
194 |
195 | from ptflops import get_model_complexity_info
196 |
197 | macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=False)
198 |
199 |     params = float(params.split(' ')[0])
200 |     macs = float(macs.split(' ')[0])
201 |
202 | print(macs, params)
203 |
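
A quick shape check for the model above (a sketch; assumes the local arch_util/local_arch imports at the top of this file resolve): check_image_size pads the input up to a multiple of 2 ** len(enc_blk_nums) and forward crops back, so arbitrary sizes round-trip.

    import torch
    from models.modules.NAFNet.Baseline_arch import Baseline

    net = Baseline(img_channel=3, width=16, middle_blk_num=1,
                   enc_blk_nums=[1, 1], dec_blk_nums=[1, 1])
    x = torch.rand(1, 3, 30, 45)   # deliberately not a multiple of 2**2
    with torch.no_grad():
        y = net(x)
    print(y.shape)  # torch.Size([1, 3, 30, 45]) -- padded internally, cropped on return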
--------------------------------------------------------------------------------
/models/modules/NAFNet/NAFNet.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 |
5 | '''
6 | Simple Baselines for Image Restoration
7 |
8 | @article{chen2022simple,
9 | title={Simple Baselines for Image Restoration},
10 | author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
11 | journal={arXiv preprint arXiv:2204.04676},
12 | year={2022}
13 | }
14 | '''
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.nn.functional as F
19 | from models.modules.NAFNet.arch_util import LayerNorm2d
20 | from models.modules.NAFNet.local_arch import Local_Base
21 |
22 | class SimpleGate(nn.Module):
23 | def forward(self, x):
24 | x1, x2 = x.chunk(2, dim=1)
25 | return x1 * x2
26 |
27 | class NAFBlock(nn.Module):
28 | def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
29 | super().__init__()
30 | dw_channel = c * DW_Expand
31 | self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
32 | self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
33 | bias=True)
34 | self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
35 |
36 | # Simplified Channel Attention
37 | self.sca = nn.Sequential(
38 | nn.AdaptiveAvgPool2d(1),
39 | nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
40 | groups=1, bias=True),
41 | )
42 |
43 | # SimpleGate
44 | self.sg = SimpleGate()
45 |
46 | ffn_channel = FFN_Expand * c
47 | self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
48 | self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
49 |
50 | self.norm1 = LayerNorm2d(c)
51 | self.norm2 = LayerNorm2d(c)
52 |
53 | self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
54 | self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
55 |
56 | self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
57 | self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
58 |
59 | def forward(self, inp):
60 | x = inp
61 |
62 | x = self.norm1(x)
63 |
64 | x = self.conv1(x)
65 | x = self.conv2(x)
66 | x = self.sg(x)
67 | x = x * self.sca(x)
68 | x = self.conv3(x)
69 |
70 | x = self.dropout1(x)
71 |
72 | y = inp + x * self.beta
73 |
74 | x = self.conv4(self.norm2(y))
75 | x = self.sg(x)
76 | x = self.conv5(x)
77 |
78 | x = self.dropout2(x)
79 |
80 | return y + x * self.gamma
81 |
82 |
83 | class Generate_gate(nn.Module):
84 | def __init__(self):
85 | super(Generate_gate, self).__init__()
86 | self.proj = nn.Sequential(nn.AdaptiveAvgPool2d(1),
87 | nn.Conv2d(512,256, 1),
88 | nn.ReLU(),
89 | nn.Conv2d(256,512, 1),
90 | nn.ReLU())
91 |
92 | self.epsilon = 1e-8
93 |     def forward(self, x):
94 |         # squeeze-excitation style projection: one response per channel
95 |         alpha = self.proj(x)
96 |         # alpha**2 / (alpha**2 + eps) is ~1 for any non-zero response and ~0 at alpha == 0
97 | gate = (alpha**2) / (alpha**2 + self.epsilon)
98 |
99 | return gate
100 |
101 | def freeze(layer):
102 | for child in layer.children():
103 | for param in child.parameters():
104 | param.requires_grad = False
105 |
106 |
107 | def unfreeze(layer):
108 | for child in layer.children():
109 | for param in child.parameters():
110 | param.requires_grad = True
111 |
112 | def freeze_direct(layer):
113 | for param in layer.parameters():
114 | param.requires_grad = False
115 |
116 |
117 | class NAFNet(nn.Module):
118 |
119 | def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[]):
120 | super().__init__()
121 |
122 | self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
123 | bias=True)
124 | self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
125 | bias=True)
126 |
127 | self.encoders = nn.ModuleList()
128 | self.decoders = nn.ModuleList()
129 | self.middle_blks = nn.ModuleList()
130 | self.ups = nn.ModuleList()
131 | self.downs = nn.ModuleList()
132 |
133 | chan = width
134 | for num in enc_blk_nums:
135 | self.encoders.append(
136 | nn.Sequential(
137 | *[NAFBlock(chan) for _ in range(num)]
138 | )
139 | )
140 | self.downs.append(
141 | nn.Conv2d(chan, 2*chan, 2, 2)
142 | )
143 | chan = chan * 2
144 |
145 | self.middle_blks = \
146 | nn.Sequential(
147 | *[NAFBlock(chan) for _ in range(middle_blk_num)]
148 | )
149 |
150 | for num in dec_blk_nums:
151 | self.ups.append(
152 | nn.Sequential(
153 | nn.Conv2d(chan, chan * 2, 1, bias=False),
154 | nn.PixelShuffle(2)
155 | )
156 | )
157 | chan = chan // 2
158 | self.decoders.append(
159 | nn.Sequential(
160 | *[NAFBlock(chan) for _ in range(num)]
161 | )
162 | )
163 |
164 | self.padder_size = 2 ** len(self.encoders)
165 |
166 |
167 |         # ---- init CSNorm: gate + one frozen per-channel InstanceNorm ----
168 |         # (hard-coded 512 assumes width * 2 ** len(enc_blk_nums) == 512, e.g. width=32 with four stages)
169 |         self.gate = Generate_gate()
170 |         for i in range(512):
171 |             setattr(self, 'CSN_' + str(i), nn.InstanceNorm2d(1, affine=True))
172 |             freeze_direct(getattr(self, 'CSN_' + str(i)))
173 |         freeze(self.gate)
174 |
175 |
176 | def forward(self, inp, aug=False):
177 | B, C, H, W = inp.shape
178 | inp = self.check_image_size(inp)
179 |
180 | x = self.intro(inp)
181 |
182 | encs = []
183 |
184 | for encoder, down in zip(self.encoders, self.downs):
185 | x = encoder(x)
186 | encs.append(x)
187 | x = down(x)
188 |
189 |         # ---- CSNorm: instance-normalize each bottleneck channel, then softly gate it back in ----
190 |         if aug:
191 |             gate = self.gate(x)
192 |             lq_copy = torch.cat([getattr(self, 'CSN_' + str(i))(x[:, i:i + 1, :, :]) for i in range(512)], dim=1)
193 |             x = gate * lq_copy + (1 - gate) * x
194 |         # ---- end CSNorm ----
195 |
196 |
197 | x = self.middle_blks(x)
198 | for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
199 | x = up(x)
200 | x = x + enc_skip
201 | x = decoder(x)
202 |
203 | x = self.ending(x)
204 | x = x + inp
205 |
206 | return x[:, :, :H, :W]
207 |
208 | def check_image_size(self, x):
209 | _, _, h, w = x.size()
210 | mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
211 | mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
212 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
213 | return x
214 |
215 | class NAFNetLocal(Local_Base, NAFNet):
216 | def __init__(self, *args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs):
217 | Local_Base.__init__(self)
218 | NAFNet.__init__(self, *args, **kwargs)
219 |
220 | N, C, H, W = train_size
221 | base_size = (int(H * 1.5), int(W * 1.5))
222 |
223 | self.eval()
224 | with torch.no_grad():
225 | self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)
226 |
227 | if __name__ == '__main__':
228 | img_channel = 3
229 | width = 32
230 | enc_blks = [1, 1, 1, 1]
231 | middle_blk_num = 1
232 | dec_blks = [1, 1, 1, 1]
233 |
234 | model = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
235 | enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)
236 |
237 | inp_shape = (1, 3, 64, 64)
238 |
239 | device = "cpu"
240 | if torch.cuda.is_available():
241 | device = "cuda"
242 | input1 = torch.randn(inp_shape).to(device)
243 | model = model.to(device)
244 |
245 | import re
246 | # layer_pattern = re.compile(r'module\.(gate\.proj|bn_\d+)\.')
247 | layer_pattern = re.compile(r'(gate\.proj|CSN_\d+)\.')
248 |
249 | selected_params = [
250 | name for name, param in model.named_parameters()
251 | if layer_pattern.search(name)
252 | ]
253 | print(selected_params)
254 |
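
A minimal sketch of the CSNorm switch above: with aug=False the forward pass is a plain NAFNet; with aug=True each of the 512 bottleneck channels is instance-normalized and blended back in through the gate. The hard-coded 512 requires width * 2 ** len(enc_blk_nums) == 512 (e.g. width=32 with four encoder stages).

    import torch
    from models.modules.NAFNet.NAFNet import NAFNet, unfreeze

    net = NAFNet(img_channel=3, width=32, middle_blk_num=1,
                 enc_blk_nums=[1, 1, 1, 1], dec_blk_nums=[1, 1, 1, 1])
    x = torch.rand(1, 3, 64, 64)
    with torch.no_grad():
        plain = net(x)             # CSNorm branch skipped
        styled = net(x, aug=True)  # gated channel-wise instance norm at the bottleneck
    print(plain.shape, styled.shape)

    # The CSNorm parameters are frozen at construction; re-enable them before optimizing:
    unfreeze(net.gate)
    for i in range(512):
        for p in getattr(net, 'CSN_' + str(i)).parameters():
            p.requires_grad = True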
--------------------------------------------------------------------------------
/models/modules/NAFNet/arch_util.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5 | # Copyright 2018-2020 BasicSR Authors
6 | # ------------------------------------------------------------------------
7 | import math
8 | import torch
9 | from torch import nn as nn
10 | from torch.nn import functional as F
11 | from torch.nn import init as init
12 | from torch.nn.modules.batchnorm import _BatchNorm
13 |
14 | # from basicsr.utils import get_root_logger
15 |
16 | # try:
17 | # from basicsr.models.ops.dcn import (ModulatedDeformConvPack,
18 | # modulated_deform_conv)
19 | # except ImportError:
20 | # # print('Cannot import dcn. Ignore this warning if dcn is not used. '
21 | # # 'Otherwise install BasicSR with compiling dcn.')
22 | #
23 |
24 | @torch.no_grad()
25 | def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
26 | """Initialize network weights.
27 |
28 | Args:
29 | module_list (list[nn.Module] | nn.Module): Modules to be initialized.
30 | scale (float): Scale initialized weights, especially for residual
31 | blocks. Default: 1.
32 | bias_fill (float): The value to fill bias. Default: 0
33 | kwargs (dict): Other arguments for initialization function.
34 | """
35 | if not isinstance(module_list, list):
36 | module_list = [module_list]
37 | for module in module_list:
38 | for m in module.modules():
39 | if isinstance(m, nn.Conv2d):
40 | init.kaiming_normal_(m.weight, **kwargs)
41 | m.weight.data *= scale
42 | if m.bias is not None:
43 | m.bias.data.fill_(bias_fill)
44 | elif isinstance(m, nn.Linear):
45 | init.kaiming_normal_(m.weight, **kwargs)
46 | m.weight.data *= scale
47 | if m.bias is not None:
48 | m.bias.data.fill_(bias_fill)
49 | elif isinstance(m, _BatchNorm):
50 | init.constant_(m.weight, 1)
51 | if m.bias is not None:
52 | m.bias.data.fill_(bias_fill)
53 |
54 |
55 | def make_layer(basic_block, num_basic_block, **kwarg):
56 | """Make layers by stacking the same blocks.
57 |
58 | Args:
59 | basic_block (nn.module): nn.module class for basic block.
60 | num_basic_block (int): number of blocks.
61 |
62 | Returns:
63 | nn.Sequential: Stacked blocks in nn.Sequential.
64 | """
65 | layers = []
66 | for _ in range(num_basic_block):
67 | layers.append(basic_block(**kwarg))
68 | return nn.Sequential(*layers)
69 |
70 |
71 | class ResidualBlockNoBN(nn.Module):
72 | """Residual block without BN.
73 |
74 | It has a style of:
75 | ---Conv-ReLU-Conv-+-
76 | |________________|
77 |
78 | Args:
79 | num_feat (int): Channel number of intermediate features.
80 | Default: 64.
81 | res_scale (float): Residual scale. Default: 1.
82 | pytorch_init (bool): If set to True, use pytorch default init,
83 | otherwise, use default_init_weights. Default: False.
84 | """
85 |
86 | def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
87 | super(ResidualBlockNoBN, self).__init__()
88 | self.res_scale = res_scale
89 | self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
90 | self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
91 | self.relu = nn.ReLU(inplace=True)
92 |
93 | if not pytorch_init:
94 | default_init_weights([self.conv1, self.conv2], 0.1)
95 |
96 | def forward(self, x):
97 | identity = x
98 | out = self.conv2(self.relu(self.conv1(x)))
99 | return identity + out * self.res_scale
100 |
101 |
102 | class Upsample(nn.Sequential):
103 | """Upsample module.
104 |
105 | Args:
106 | scale (int): Scale factor. Supported scales: 2^n and 3.
107 | num_feat (int): Channel number of intermediate features.
108 | """
109 |
110 | def __init__(self, scale, num_feat):
111 | m = []
112 | if (scale & (scale - 1)) == 0: # scale = 2^n
113 | for _ in range(int(math.log(scale, 2))):
114 | m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
115 | m.append(nn.PixelShuffle(2))
116 | elif scale == 3:
117 | m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
118 | m.append(nn.PixelShuffle(3))
119 | else:
120 | raise ValueError(f'scale {scale} is not supported. '
121 | 'Supported scales: 2^n and 3.')
122 | super(Upsample, self).__init__(*m)
123 |
124 |
125 | def flow_warp(x,
126 | flow,
127 | interp_mode='bilinear',
128 | padding_mode='zeros',
129 | align_corners=True):
130 | """Warp an image or feature map with optical flow.
131 |
132 | Args:
133 | x (Tensor): Tensor with size (n, c, h, w).
134 | flow (Tensor): Tensor with size (n, h, w, 2), normal value.
135 | interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
136 | padding_mode (str): 'zeros' or 'border' or 'reflection'.
137 | Default: 'zeros'.
138 | align_corners (bool): Before pytorch 1.3, the default value is
139 | align_corners=True. After pytorch 1.3, the default value is
140 |             align_corners=False. Here, we use True as the default.
141 |
142 | Returns:
143 | Tensor: Warped image or feature map.
144 | """
145 | assert x.size()[-2:] == flow.size()[1:3]
146 | _, _, h, w = x.size()
147 | # create mesh grid
148 | grid_y, grid_x = torch.meshgrid(
149 | torch.arange(0, h).type_as(x),
150 | torch.arange(0, w).type_as(x))
151 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
152 | grid.requires_grad = False
153 |
154 | vgrid = grid + flow
155 | # scale grid to [-1,1]
156 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
157 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
158 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
159 | output = F.grid_sample(
160 | x,
161 | vgrid_scaled,
162 | mode=interp_mode,
163 | padding_mode=padding_mode,
164 | align_corners=align_corners)
165 |
166 | # TODO, what if align_corners=False
167 | return output
168 |
169 |
170 | def resize_flow(flow,
171 | size_type,
172 | sizes,
173 | interp_mode='bilinear',
174 | align_corners=False):
175 | """Resize a flow according to ratio or shape.
176 |
177 | Args:
178 | flow (Tensor): Precomputed flow. shape [N, 2, H, W].
179 | size_type (str): 'ratio' or 'shape'.
180 | sizes (list[int | float]): the ratio for resizing or the final output
181 | shape.
182 | 1) The order of ratio should be [ratio_h, ratio_w]. For
183 | downsampling, the ratio should be smaller than 1.0 (i.e., ratio
184 | < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
185 | ratio > 1.0).
186 | 2) The order of output_size should be [out_h, out_w].
187 | interp_mode (str): The mode of interpolation for resizing.
188 | Default: 'bilinear'.
189 | align_corners (bool): Whether align corners. Default: False.
190 |
191 | Returns:
192 | Tensor: Resized flow.
193 | """
194 | _, _, flow_h, flow_w = flow.size()
195 | if size_type == 'ratio':
196 | output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
197 | elif size_type == 'shape':
198 | output_h, output_w = sizes[0], sizes[1]
199 | else:
200 | raise ValueError(
201 | f'Size type should be ratio or shape, but got type {size_type}.')
202 |
203 | input_flow = flow.clone()
204 | ratio_h = output_h / flow_h
205 | ratio_w = output_w / flow_w
206 | input_flow[:, 0, :, :] *= ratio_w
207 | input_flow[:, 1, :, :] *= ratio_h
208 | resized_flow = F.interpolate(
209 | input=input_flow,
210 | size=(output_h, output_w),
211 | mode=interp_mode,
212 | align_corners=align_corners)
213 | return resized_flow
214 |
215 |
216 | # TODO: may write a cpp file
217 | def pixel_unshuffle(x, scale):
218 | """ Pixel unshuffle.
219 |
220 | Args:
221 | x (Tensor): Input feature with shape (b, c, hh, hw).
222 | scale (int): Downsample ratio.
223 |
224 | Returns:
225 | Tensor: the pixel unshuffled feature.
226 | """
227 | b, c, hh, hw = x.size()
228 | out_channel = c * (scale**2)
229 | assert hh % scale == 0 and hw % scale == 0
230 | h = hh // scale
231 | w = hw // scale
232 | x_view = x.view(b, c, h, scale, w, scale)
233 | return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
234 |
235 |
236 | # class DCNv2Pack(ModulatedDeformConvPack):
237 | # """Modulated deformable conv for deformable alignment.
238 | #
239 | # Different from the official DCNv2Pack, which generates offsets and masks
240 | # from the preceding features, this DCNv2Pack takes another different
241 | # features to generate offsets and masks.
242 | #
243 | # Ref:
244 | # Delving Deep into Deformable Alignment in Video Super-Resolution.
245 | # """
246 | #
247 | # def forward(self, x, feat):
248 | # out = self.conv_offset(feat)
249 | # o1, o2, mask = torch.chunk(out, 3, dim=1)
250 | # offset = torch.cat((o1, o2), dim=1)
251 | # mask = torch.sigmoid(mask)
252 | #
253 | # offset_absmean = torch.mean(torch.abs(offset))
254 | # if offset_absmean > 50:
255 | # logger = get_root_logger()
256 | # logger.warning(
257 | # f'Offset abs mean is {offset_absmean}, larger than 50.')
258 | #
259 | # return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
260 | # self.stride, self.padding, self.dilation,
261 | # self.groups, self.deformable_groups)
262 |
263 |
264 | class LayerNormFunction(torch.autograd.Function):
265 |
266 | @staticmethod
267 | def forward(ctx, x, weight, bias, eps):
268 | ctx.eps = eps
269 | N, C, H, W = x.size()
270 | mu = x.mean(1, keepdim=True)
271 | var = (x - mu).pow(2).mean(1, keepdim=True)
272 | y = (x - mu) / (var + eps).sqrt()
273 | ctx.save_for_backward(y, var, weight)
274 | y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
275 | return y
276 |
277 | @staticmethod
278 | def backward(ctx, grad_output):
279 | eps = ctx.eps
280 |
281 | N, C, H, W = grad_output.size()
282 |         y, var, weight = ctx.saved_tensors
283 | g = grad_output * weight.view(1, C, 1, 1)
284 | mean_g = g.mean(dim=1, keepdim=True)
285 |
286 | mean_gy = (g * y).mean(dim=1, keepdim=True)
287 | gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
288 | return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
289 | dim=0), None
290 |
291 | class LayerNorm2d(nn.Module):
292 |
293 | def __init__(self, channels, eps=1e-6):
294 | super(LayerNorm2d, self).__init__()
295 | self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
296 | self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
297 | self.eps = eps
298 |
299 | def forward(self, x):
300 | return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
301 |
302 | # handle multiple input
303 | class MySequential(nn.Sequential):
304 | def forward(self, *inputs):
305 | for module in self._modules.values():
306 | if type(inputs) == tuple:
307 | inputs = module(*inputs)
308 | else:
309 | inputs = module(inputs)
310 | return inputs
311 |
312 | import time
313 | def measure_inference_speed(model, data, max_iter=200, log_interval=50):
314 | model.eval()
315 |
316 | # the first several iterations may be very slow so skip them
317 | num_warmup = 5
318 | pure_inf_time = 0
319 | fps = 0
320 |
321 |     # benchmark over max_iter images (after a short warm-up) and take the average
322 | for i in range(max_iter):
323 |
324 | torch.cuda.synchronize()
325 | start_time = time.perf_counter()
326 |
327 | with torch.no_grad():
328 | model(*data)
329 |
330 | torch.cuda.synchronize()
331 | elapsed = time.perf_counter() - start_time
332 |
333 | if i >= num_warmup:
334 | pure_inf_time += elapsed
335 | if (i + 1) % log_interval == 0:
336 | fps = (i + 1 - num_warmup) / pure_inf_time
337 | print(
338 | f'Done image [{i + 1:<3}/ {max_iter}], '
339 | f'fps: {fps:.1f} img / s, '
340 | f'times per image: {1000 / fps:.1f} ms / img',
341 | flush=True)
342 |
343 | if (i + 1) == max_iter:
344 | fps = (i + 1 - num_warmup) / pure_inf_time
345 | print(
346 | f'Overall fps: {fps:.1f} img / s, '
347 | f'times per image: {1000 / fps:.1f} ms / img',
348 | flush=True)
349 | break
350 | return fps
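
Two quick sanity checks for the helpers above (a sketch, with this module's names in scope): pixel_unshuffle exactly inverts nn.PixelShuffle, and flow_warp with a zero flow field returns its input, since align_corners=True matches the (w - 1) normalization.

    import torch
    import torch.nn as nn

    x = torch.rand(2, 4, 8, 8)
    shuffled = nn.PixelShuffle(2)(x)                     # (2, 1, 16, 16)
    print(torch.equal(pixel_unshuffle(shuffled, 2), x))  # True

    feat = torch.rand(1, 3, 8, 8)
    zero_flow = torch.zeros(1, 8, 8, 2)
    print(torch.allclose(flow_warp(feat, zero_flow), feat, atol=1e-5))  # True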
--------------------------------------------------------------------------------
/models/modules/NAFNet/local_arch.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | class AvgPool2d(nn.Module):
11 | def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None):
12 | super().__init__()
13 | self.kernel_size = kernel_size
14 | self.base_size = base_size
15 | self.auto_pad = auto_pad
16 |
17 | # only used for fast implementation
18 | self.fast_imp = fast_imp
19 | self.rs = [5, 4, 3, 2, 1]
20 | self.max_r1 = self.rs[0]
21 | self.max_r2 = self.rs[0]
22 | self.train_size = train_size
23 |
24 |     def extra_repr(self) -> str:
25 |         return 'kernel_size={}, base_size={}, auto_pad={}, fast_imp={}'.format(
26 |             self.kernel_size, self.base_size, self.auto_pad, self.fast_imp
27 |         )
28 |
29 | def forward(self, x):
30 | if self.kernel_size is None and self.base_size:
31 | train_size = self.train_size
32 | if isinstance(self.base_size, int):
33 | self.base_size = (self.base_size, self.base_size)
34 | self.kernel_size = list(self.base_size)
35 | self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2]
36 | self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1]
37 |
38 | # only used for fast implementation
39 | self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
40 | self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])
41 |
42 | if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1):
43 | return F.adaptive_avg_pool2d(x, 1)
44 |
45 | if self.fast_imp: # Non-equivalent implementation but faster
46 | h, w = x.shape[2:]
47 | if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
48 | out = F.adaptive_avg_pool2d(x, 1)
49 | else:
50 | r1 = [r for r in self.rs if h % r == 0][0]
51 | r2 = [r for r in self.rs if w % r == 0][0]
52 | # reduction_constraint
53 | r1 = min(self.max_r1, r1)
54 | r2 = min(self.max_r2, r2)
55 | s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2)
56 | n, c, h, w = s.shape
57 | k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2)
58 | out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
59 | out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2))
60 | else:
61 | n, c, h, w = x.shape
62 | s = x.cumsum(dim=-1).cumsum_(dim=-2)
63 | s = torch.nn.functional.pad(s, (1, 0, 1, 0)) # pad 0 for convenience
64 | k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1])
65 | s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:]
66 | out = s4 + s1 - s2 - s3
67 | out = out / (k1 * k2)
68 |
69 | if self.auto_pad:
70 | n, c, h, w = x.shape
71 | _h, _w = out.shape[2:]
72 | # print(x.shape, self.kernel_size)
73 | pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
74 | out = torch.nn.functional.pad(out, pad2d, mode='replicate')
75 |
76 | return out
77 |
78 | def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
79 | for n, m in model.named_children():
80 | if len(list(m.children())) > 0:
81 | ## compound module, go inside it
82 | replace_layers(m, base_size, train_size, fast_imp, **kwargs)
83 |
84 | if isinstance(m, nn.AdaptiveAvgPool2d):
85 | pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size)
86 | assert m.output_size == 1
87 | setattr(model, n, pool)
88 |
89 |
90 | '''
91 | ref.
92 | @article{chu2021tlsc,
93 | title={Revisiting Global Statistics Aggregation for Improving Image Restoration},
94 |   author={Chu, Xiaojie and Chen, Liangyu and Chen, Chengpeng and Lu, Xin},
95 | journal={arXiv preprint arXiv:2112.04491},
96 | year={2021}
97 | }
98 | '''
99 | class Local_Base():
100 | def convert(self, *args, train_size, **kwargs):
101 | replace_layers(self, *args, train_size=train_size, **kwargs)
102 | imgs = torch.rand(train_size)
103 | with torch.no_grad():
104 | self.forward(imgs)
105 |
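
The non-fast branch of AvgPool2d above is an integral-image ("summed-area table") average pool; a minimal standalone check of that identity:

    import torch
    import torch.nn.functional as F

    x = torch.rand(1, 1, 8, 8)
    k = 3
    s = F.pad(x.cumsum(-1).cumsum(-2), (1, 0, 1, 0))  # pad 0 for convenience, as above
    win = (s[..., k:, k:] + s[..., :-k, :-k] - s[..., :-k, k:] - s[..., k:, :-k]) / (k * k)
    print(torch.allclose(win, F.avg_pool2d(x, k, stride=1), atol=1e-6))  # True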
--------------------------------------------------------------------------------
/models/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mdyao/CSNorm/49bf5a07ac1c58c8d2c221ac86022698c7f1c897/models/modules/__init__.py
--------------------------------------------------------------------------------
/models/modules/common.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | def default_conv(in_channels, out_channels, kernel_size, bias=True):
9 | return nn.Conv2d(
10 | in_channels, out_channels, kernel_size,
11 | padding=(kernel_size // 2), bias=bias)
12 |
13 |
14 | class MeanShift(nn.Conv2d):
15 | def __init__(self, rgb_range, rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
16 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
17 | std = torch.Tensor(rgb_std)
18 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
19 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
20 | for p in self.parameters():
21 | p.requires_grad = False
22 |
23 |
24 | class BasicBlock(nn.Sequential):
25 | def __init__(
26 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False,
27 | bn=True, act=nn.ReLU(True)):
28 |
29 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
30 | if bn:
31 | m.append(nn.BatchNorm2d(out_channels))
32 | if act is not None:
33 | m.append(act)
34 |
35 | super(BasicBlock, self).__init__(*m)
36 |
37 |
38 | class ResBlock(nn.Module):
39 | def __init__(
40 | self, conv, n_feats, kernel_size,
41 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
42 |
43 | super(ResBlock, self).__init__()
44 | m = []
45 | for i in range(2):
46 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
47 | if bn:
48 | m.append(nn.BatchNorm2d(n_feats))
49 | if i == 0:
50 | m.append(act)
51 |
52 | self.body = nn.Sequential(*m)
53 | self.res_scale = res_scale
54 |
55 | def forward(self, x):
56 |         res = self.body(x).mul(self.res_scale)
57 | res += x
58 |
59 | return res
60 |
61 |
62 | class Upsampler(nn.Sequential):
63 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
64 |
65 | m = []
66 | if (scale & (scale - 1)) == 0: # Is scale = 2^n?
67 | for _ in range(int(math.log(scale, 2))):
68 | m.append(conv(n_feats, 4 * n_feats, 3, bias))
69 | m.append(nn.PixelShuffle(2))
70 | if bn:
71 | m.append(nn.BatchNorm2d(n_feats))
72 | if act == 'relu':
73 | m.append(nn.ReLU(True))
74 | elif act == 'prelu':
75 | m.append(nn.PReLU(n_feats))
76 |
77 | elif scale == 3:
78 | m.append(conv(n_feats, 9 * n_feats, 3, bias))
79 | m.append(nn.PixelShuffle(3))
80 | if bn:
81 | m.append(nn.BatchNorm2d(n_feats))
82 | if act == 'relu':
83 | m.append(nn.ReLU(True))
84 | elif act == 'prelu':
85 | m.append(nn.PReLU(n_feats))
86 | else:
87 | raise NotImplementedError
88 |
89 | super(Upsampler, self).__init__(*m)
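
A minimal shape check for the blocks above (a sketch, with this file's definitions in scope): an x2 Upsampler built from default_conv expands 16 features to 64, then PixelShuffle(2) trades them back for a 2x larger map, while ResBlock preserves the feature shape.

    import torch

    up = Upsampler(default_conv, scale=2, n_feats=16)
    x = torch.rand(1, 16, 8, 8)
    print(up(x).shape)  # torch.Size([1, 16, 16, 16])

    block = ResBlock(default_conv, n_feats=16, kernel_size=3)
    print(block(x).shape)  # torch.Size([1, 16, 8, 8])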
--------------------------------------------------------------------------------
/models/modules/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from torchvision.models.vgg import vgg16
5 | from torch.nn import functional as F
6 | import torch.fft as fft
7 |
8 | class ReconstructionLoss(nn.Module):
9 | def __init__(self, losstype='l2', eps=1e-3):
10 | super(ReconstructionLoss, self).__init__()
11 | self.losstype = losstype
12 | self.eps = eps
13 |
14 | def forward(self, x, target):
15 | if self.losstype == 'l2':
16 | return torch.mean(torch.sum((x - target)**2, (1, 2, 3)))
17 | elif self.losstype == 'l1':
18 | diff = x - target
19 |             return torch.mean(torch.sum(torch.sqrt(diff * diff + self.eps), (1, 2, 3)))  # Charbonnier-smoothed L1
20 | elif self.losstype == 'l_log':
21 | diff = x - target
22 | eps = 1e-6
23 | return torch.mean(torch.sum(-torch.log(1-diff.abs()+eps), (1, 2, 3)))
24 |         else:
25 |             raise ValueError(
26 |                 'unsupported reconstruction loss type: {}'.format(self.losstype))
27 |
28 |
29 | class FFT_Loss(nn.Module):
30 | def __init__(self, losstype='l2', eps=1e-3):
31 | super(FFT_Loss, self).__init__()
32 | # self.fpre =
33 | def forward(self, x, gt):
34 | x = x + 1e-8
35 | gt = gt + 1e-8
36 |         x_freq = torch.fft.rfft2(x, norm='backward')
37 |         x_amp = torch.abs(x_freq)
38 |         x_phase = torch.angle(x_freq)
39 |
40 |         gt_freq = torch.fft.rfft2(gt, norm='backward')
41 | gt_amp = torch.abs(gt_freq)
42 | gt_phase = torch.angle(gt_freq)
43 |
44 | loss_amp = torch.mean(torch.sum((x_amp - gt_amp) ** 2))
45 | loss_phase = torch.mean(torch.sum((x_phase - gt_phase) ** 2))
46 | return loss_amp, loss_phase
47 |
48 | # Gradient Loss
49 | class Gradient_Loss(nn.Module):
50 | def __init__(self, losstype='l2'):
51 | super(Gradient_Loss, self).__init__()
52 | a = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]])
53 | conv1 = nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False, groups=3)
54 | a = torch.from_numpy(a).float().unsqueeze(0)
55 | a = torch.stack((a, a, a))
56 | conv1.weight = nn.Parameter(a, requires_grad=False)
57 | self.conv1 = conv1.cuda()
58 |
59 | b = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])
60 | conv2 = nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False, groups=3)
61 | b = torch.from_numpy(b).float().unsqueeze(0)
62 | b = torch.stack((b, b, b))
63 | conv2.weight = nn.Parameter(b, requires_grad=False)
64 | self.conv2 = conv2.cuda()
65 |
66 | # self.Loss_criterion = ReconstructionLoss(losstype)
67 | self.Loss_criterion = nn.L1Loss()
68 |
69 | def forward(self, x, y):
70 | x1 = self.conv1(x)
71 | x2 = self.conv2(x)
72 | # x_total = torch.sqrt(torch.pow(x1, 2) + torch.pow(x2, 2))
73 |
74 | y1 = self.conv1(y)
75 | y2 = self.conv2(y)
76 | # y_total = torch.sqrt(torch.pow(y1, 2) + torch.pow(y2, 2))
77 |
78 | l_h = self.Loss_criterion(x1, y1)
79 | l_v = self.Loss_criterion(x2, y2)
80 | # l_total = self.Loss_criterion(x_total, y_total)
81 | return l_h + l_v #+ l_total
82 |
83 |
84 | class SSIM_Loss(nn.Module):
85 | """Layer to compute the SSIM loss between a pair of images
86 | """
87 | def __init__(self):
88 | super(SSIM_Loss, self).__init__()
89 | self.mu_x_pool = nn.AvgPool2d(3, 1)
90 | self.mu_y_pool = nn.AvgPool2d(3, 1)
91 | self.sig_x_pool = nn.AvgPool2d(3, 1)
92 | self.sig_y_pool = nn.AvgPool2d(3, 1)
93 | self.sig_xy_pool = nn.AvgPool2d(3, 1)
94 |
95 | self.refl = nn.ReflectionPad2d(1)
96 |
97 | self.C1 = 0.01 ** 2
98 | self.C2 = 0.03 ** 2
99 |
100 | def forward(self, x, y):
101 | x = self.refl(x)
102 | y = self.refl(y)
103 |
104 | mu_x = self.mu_x_pool(x)
105 | mu_y = self.mu_y_pool(y)
106 |
107 | sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2
108 | sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2
109 | sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y
110 |
111 | SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
112 | SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2)
113 |
114 | return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
115 |
116 | # Define GAN loss: [vanilla | lsgan | wgan-gp]
117 | class GANLoss(nn.Module):
118 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
119 | super(GANLoss, self).__init__()
120 | self.gan_type = gan_type.lower()
121 | self.real_label_val = real_label_val
122 | self.fake_label_val = fake_label_val
123 |
124 | if self.gan_type == 'gan' or self.gan_type == 'ragan':
125 | self.loss = nn.BCEWithLogitsLoss()
126 | elif self.gan_type == 'lsgan':
127 | self.loss = nn.MSELoss()
128 | elif self.gan_type == 'wgan-gp':
129 |
130 | def wgan_loss(input, target):
131 | # target is boolean
132 | return -1 * input.mean() if target else input.mean()
133 |
134 | self.loss = wgan_loss
135 | else:
136 | raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))
137 |
138 | def get_target_label(self, input, target_is_real):
139 | if self.gan_type == 'wgan-gp':
140 | return target_is_real
141 | if target_is_real:
142 | return torch.empty_like(input).fill_(self.real_label_val)
143 | else:
144 | return torch.empty_like(input).fill_(self.fake_label_val)
145 |
146 | def forward(self, input, target_is_real):
147 | target_label = self.get_target_label(input, target_is_real)
148 | loss = self.loss(input, target_label)
149 | return loss
150 |
151 |
152 | class GradientPenaltyLoss(nn.Module):
153 | def __init__(self, device=torch.device('cpu')):
154 | super(GradientPenaltyLoss, self).__init__()
155 | self.register_buffer('grad_outputs', torch.Tensor())
156 | self.grad_outputs = self.grad_outputs.to(device)
157 |
158 | def get_grad_outputs(self, input):
159 | if self.grad_outputs.size() != input.size():
160 | self.grad_outputs.resize_(input.size()).fill_(1.0)
161 | return self.grad_outputs
162 |
163 | def forward(self, interp, interp_crit):
164 | grad_outputs = self.get_grad_outputs(interp_crit)
165 | grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp,
166 | grad_outputs=grad_outputs, create_graph=True,
167 | retain_graph=True, only_inputs=True)[0]
168 | grad_interp = grad_interp.view(grad_interp.size(0), -1)
169 | grad_interp_norm = grad_interp.norm(2, dim=1)
170 |
171 | loss = ((grad_interp_norm - 1)**2).mean()
172 | return loss
173 |
174 | class TVLoss(nn.Module):
175 | def __init__(self, TVLoss_weight=1):
176 | super(TVLoss, self).__init__()
177 | self.TVLoss_weight = TVLoss_weight
178 |
179 | def forward(self, x):
180 | batch_size = x.size()[0]
181 | h_x = x.size()[2]
182 | w_x = x.size()[3]
183 | count_h = self._tensor_size(x[:, :, 1:, :])
184 | count_w = self._tensor_size(x[:, :, :, 1:])
185 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
186 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
187 | return self.TVLoss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size
188 |
189 | def _tensor_size(self, t):
190 | return t.size()[1] * t.size()[2] * t.size()[3]
191 |
192 | class TV_extractor(nn.Module):
193 | def __init__(self, TVLoss_weight=1):
194 | super(TV_extractor, self).__init__()
195 | self.TVLoss_weight = TVLoss_weight
196 | self.fil = nn.Parameter(torch.ones(1, 1, 3, 3)/9, requires_grad=False)
197 |
198 | def forward(self, x):
199 | batch_size = x.size()[0]
200 | h_x = x.size()[2]
201 | w_x = x.size()[3]
202 | count_h = self._tensor_size(x[:, :, 1:, :])
203 | count_w = self._tensor_size(x[:, :, :, 1:])
204 | h_tv = torch.abs((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]))
205 | w_tv = torch.abs((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]))
206 | h_tv = F.pad(h_tv, [0,0,0,1], "constant", 0)
207 | w_tv = F.pad(w_tv, [0,1,0,0], "constant", 0)
208 |
209 | h_tv = F.conv2d(h_tv, self.fil, stride=1, padding=1, groups=1)
210 | w_tv = F.conv2d(w_tv, self.fil, stride=1, padding=1, groups=1)
211 |
212 | # print(h_tv.shape, w_tv.shape)
213 | tv = torch.abs(h_tv)+torch.abs(w_tv)
214 | return tv
215 |
216 | def _tensor_size(self, t):
217 | return t.size()[1] * t.size()[2] * t.size()[3]
218 |
219 | class CL_Loss(nn.Module):
220 | def __init__(self, opt):
221 | super(CL_Loss, self).__init__()
222 | self.opt = opt
223 |         self.d = nn.MSELoss(reduction='mean')
224 | vgg = vgg16(pretrained=False).cuda()
225 | vgg.load_state_dict(torch.load(self.opt['vgg16_model']))
226 | self.loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
227 | for param in self.loss_network.parameters():
228 | param.requires_grad = False
229 |
230 |     def forward(self, anchor, positive, negative):
231 |         anchor_f = self.loss_network(anchor)
232 |         positive_f = self.loss_network(positive)
233 | negative_f = self.loss_network(negative)
234 |
235 | loss = self.d(anchor_f, positive_f)/self.d(anchor_f, negative_f)
236 | return loss
237 |
238 | class Percep_Loss(nn.Module):
239 | def __init__(self, opt):
240 | super(Percep_Loss, self).__init__()
241 | self.opt = opt
242 |         self.d = nn.MSELoss(reduction='mean')
243 | vgg = vgg16(pretrained=True).cuda()
244 | # vgg.load_state_dict(torch.load(self.opt['vgg16_model']))
245 | # self.loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
246 | # for param in self.loss_network.parameters():
247 | # param.requires_grad = False
248 |
249 | blocks = []
250 | blocks.append(vgg.features[:4].eval())
251 | blocks.append(vgg.features[4:9].eval())
252 | blocks.append(vgg.features[9:16].eval())
253 | blocks.append(vgg.features[16:23].eval())
254 | for bl in blocks:
255 | for p in bl.parameters():
256 | p.requires_grad = False
257 |
258 | self.blocks = torch.nn.ModuleList(blocks)
259 | self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
260 | self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
261 |
262 |     def forward(self, input, target, feature_layers=[0, 1, 2, 3], weights=[1, 1, 1, 1]):
263 | if input.shape[1] != 3:
264 | input = input.repeat(1, 3, 1, 1)
265 | target = target.repeat(1, 3, 1, 1)
266 | # input = (input-self.mean) / self.std
267 | # target = (target-self.mean) / self.std
268 | loss = 0.0
269 | x = input
270 | y = target
271 | for i,block in enumerate(self.blocks):
272 | x = block(x)
273 | y = block(y)
274 | if i in feature_layers:
275 | loss += weights[i] * self.d(x, y)
276 | return loss
277 |
278 | class SID_loss(nn.Module):
279 |     """Spectral information divergence between two non-negative tensors."""
280 |     def __init__(self):
281 |         super(SID_loss, self).__init__()
282 |
283 |     def forward(self, x, y):
284 |         eps = 1e-8
285 |         # normalize both inputs to probability distributions
286 |         p = x / (torch.sum(x) + eps) + eps
287 |         q = y / (torch.sum(y) + eps) + eps
288 |         # symmetric KL-style divergence between the two distributions
289 |         return torch.sum(p * torch.log10(p / q) + q * torch.log10(q / p))
290 |
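
A minimal usage sketch for the CPU-friendly losses above (Gradient_Loss, CL_Loss and Percep_Loss move weights to CUDA at construction, so they are skipped here; the combination weights below are illustrative, not taken from this repo):

    import torch

    pred = torch.rand(2, 3, 16, 16)
    gt = torch.rand(2, 3, 16, 16)

    rec = ReconstructionLoss(losstype='l1')
    fft_loss = FFT_Loss()
    tv = TVLoss(TVLoss_weight=1)

    loss_amp, loss_phase = fft_loss(pred, gt)
    total = rec(pred, gt) + 0.01 * (loss_amp + loss_phase) + 0.1 * tv(pred)
    print(total.item())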
--------------------------------------------------------------------------------
/models/modules/loss_new.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from math import exp
6 | import numpy as np
7 | from torchvision import models
8 |
9 |
10 | #########################################################################################################################################
11 |
12 | import torch
13 | import torch.nn.functional as F
14 | from torch.autograd import Variable
15 | import numpy as np
16 | from math import exp
17 |
18 |
19 | def gaussian(window_size, sigma):
20 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
21 | return gauss / gauss.sum()
22 |
23 |
24 | def create_window(window_size, channel):
25 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
26 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
27 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
28 | return window
29 |
30 |
31 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
32 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
33 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
34 |
35 | mu1_sq = mu1.pow(2)
36 | mu2_sq = mu2.pow(2)
37 | mu1_mu2 = mu1 * mu2
38 |
39 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
40 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
41 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
42 |
43 | C1 = 0.01 ** 2
44 | C2 = 0.03 ** 2
45 |
46 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
47 |
48 | if size_average:
49 | return (-1) * ssim_map.mean()
50 | else:
51 | return (-1) * ssim_map.mean(1).mean(1).mean(1)
52 |
53 |
54 | class SSIMLoss(torch.nn.Module):
55 | def __init__(self, window_size=11, size_average=True):
56 | super(SSIMLoss, self).__init__()
57 | self.window_size = window_size
58 | self.size_average = size_average
59 | self.channel = 1
60 | self.window = create_window(window_size, self.channel)
61 |
62 | def forward(self, img1, img2):
63 | (_, channel, _, _) = img1.size()
64 |
65 | if channel == self.channel and self.window.data.type() == img1.data.type():
66 | window = self.window
67 | else:
68 | window = create_window(self.window_size, channel)
69 |
70 | if img1.is_cuda:
71 | window = window.cuda(img1.get_device())
72 | window = window.type_as(img1)
73 |
74 | self.window = window
75 | self.channel = channel
76 |
77 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
78 |
79 |
80 | def ssim(img1, img2, window_size=11, size_average=True):
81 | (_, channel, _, _) = img1.size()
82 | window = create_window(window_size, channel)
83 |
84 | if img1.is_cuda:
85 | window = window.cuda(img1.get_device())
86 | window = window.type_as(img1)
87 |
88 | return _ssim(img1, img2, window, window_size, channel, size_average)
89 |
90 |
91 | ###########################################################################################################################
92 |
93 |
94 |
95 | class Vgg19(nn.Module):
96 | def __init__(self, id, requires_grad=False):
97 | super(Vgg19, self).__init__()
98 | vgg = models.vgg19(pretrained=False)
99 | vgg.load_state_dict(torch.load('/model/1760921465/NewWork2021/vgg19-dcbb9e9d.pth'))
100 | vgg.eval()
101 | vgg_pretrained_features = vgg.features
102 | self.slice1 = torch.nn.Sequential()
103 | self.slice2 = torch.nn.Sequential()
104 | self.slice3 = torch.nn.Sequential()
105 | self.slice4 = torch.nn.Sequential()
106 | self.slice5 = torch.nn.Sequential()
107 | for x in range(3):
108 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
109 | for x in range(3, 7):
110 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
111 | for x in range(7, 12):
112 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
113 | for x in range(12, 21):
114 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
115 | for x in range(21, 30):
116 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
117 | self.id = id
118 | if not requires_grad:
119 | for param in self.parameters():
120 | param.requires_grad = False
121 |
122 | def forward(self, X):
123 | h_relu1 = self.slice1(X)
124 | h_relu2 = self.slice2(h_relu1)
125 | h_relu3 = self.slice3(h_relu2)
126 | h_relu4 = self.slice4(h_relu3)
127 | h_relu5 = self.slice5(h_relu4)
128 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
129 | return out[self.id]
130 |
131 |
132 | class VGGLoss(nn.Module):
133 | def __init__(self, id, gpu_id=0):
134 | super(VGGLoss, self).__init__()
135 | self.vgg = Vgg19(id).cuda(gpu_id)
136 | self.criterion = nn.MSELoss()
137 | self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
138 | self.downsample = nn.AvgPool2d(2, stride=2, count_include_pad=False)
139 |
140 | def forward(self, x, y):
141 | while x.size()[3] > 4096:
142 | x, y = self.downsample(x), self.downsample(y)
143 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
144 | # loss = 0
145 | # for i in range(len(x_vgg)):
146 | loss = self.criterion(x_vgg, y_vgg.detach())
147 | return loss
148 |
149 |
150 | ############################################################################################################################
151 |
152 |
153 | class GradientLoss(nn.Module):
154 | """Gradient Histogram Loss"""
155 | def __init__(self):
156 | super(GradientLoss, self).__init__()
157 | self.bin_num = 64
158 | self.delta = 0.2
159 | self.clip_radius = 0.2
160 | assert(self.clip_radius>0 and self.clip_radius<=1)
161 | self.bin_width = 2*self.clip_radius/self.bin_num
162 | if self.bin_width*255<1:
163 | raise RuntimeError("bin width is too small")
164 | self.bin_mean = np.arange(-self.clip_radius+self.bin_width*0.5, self.clip_radius, self.bin_width)
165 | self.gradient_hist_loss_function = 'L2'
166 | # default is KL loss
167 | if self.gradient_hist_loss_function == 'L2':
168 | self.criterion = nn.MSELoss()
169 | elif self.gradient_hist_loss_function == 'L1':
170 | self.criterion = nn.L1Loss()
171 | else:
172 | self.criterion = nn.KLDivLoss()
173 |
174 | def get_response(self, gradient, mean):
175 | # tmp = torch.mul(torch.pow(torch.add(gradient, -mean), 2), self.delta_square_inverse)
176 | s = (-1) / (self.delta ** 2)
177 | tmp = ((gradient - mean) ** 2) * s
178 | return torch.mean(torch.exp(tmp))
179 |
180 | def get_gradient(self, src):
181 | right_src = src[:, :, 1:, 0:-1] # shift src image right by one pixel
182 | down_src = src[:, :, 0:-1, 1:] # shift src image down by one pixel
183 | clip_src = src[:, :, 0:-1, 0:-1] # make src same size as shift version
184 | d_x = right_src - clip_src
185 | d_y = down_src - clip_src
186 |
187 | return d_x, d_y
188 |
189 | def get_gradient_hist(self, gradient_x, gradient_y):
190 | lx = None
191 | ly = None
192 | for ind_bin in range(self.bin_num):
193 | fx = self.get_response(gradient_x, self.bin_mean[ind_bin])
194 | fy = self.get_response(gradient_y, self.bin_mean[ind_bin])
195 | fx = torch.cuda.FloatTensor([fx])
196 | fy = torch.cuda.FloatTensor([fy])
197 |
198 | if lx is None:
199 | lx = fx
200 | ly = fy
201 | else:
202 | lx = torch.cat((lx, fx), 0)
203 | ly = torch.cat((ly, fy), 0)
204 | # lx = torch.div(lx, torch.sum(lx))
205 | # ly = torch.div(ly, torch.sum(ly))
206 | return lx, ly
207 |
208 | def forward(self, output, target):
209 | output_gradient_x, output_gradient_y = self.get_gradient(output)
210 | target_gradient_x, target_gradient_y = self.get_gradient(target)
211 |
212 | output_gradient_x_hist, output_gradient_y_hist = self.get_gradient_hist(output_gradient_x, output_gradient_y)
213 | target_gradient_x_hist, target_gradient_y_hist = self.get_gradient_hist(target_gradient_x, target_gradient_y)
214 | # loss = self.criterion(output_gradient_x_hist, target_gradient_x_hist) + self.criterion(output_gradient_y_hist, target_gradient_y_hist)
215 | loss = self.criterion(output_gradient_x,target_gradient_x)+self.criterion(output_gradient_y,target_gradient_y)
216 | return loss
217 |
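
A minimal check of the SSIM implementation above (with this file's names in scope): _ssim returns the negated mean SSIM so it can be minimized directly, hence identical inputs score about -1.

    import torch

    img = torch.rand(1, 3, 32, 32)
    loss_fn = SSIMLoss(window_size=11)
    print(loss_fn(img, img).item())  # ~ -1.0 (perfect similarity)
    print(ssim(img, img).item())     # functional form, same value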
--------------------------------------------------------------------------------
/models/modules/module_util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 | import numpy as np
6 |
7 | def initialize_weights(net_l, scale=1):
8 | if not isinstance(net_l, list):
9 | net_l = [net_l]
10 | for net in net_l:
11 | for m in net.modules():
12 | if isinstance(m, nn.Conv2d):
13 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
14 | m.weight.data *= scale # for residual block
15 | if m.bias is not None:
16 | m.bias.data.zero_()
17 | elif isinstance(m, nn.Linear):
18 | init.kaiming_normal_(m.weight, a=0, mode='fan_in')
19 | m.weight.data *= scale
20 | if m.bias is not None:
21 | m.bias.data.zero_()
22 | elif isinstance(m, nn.BatchNorm2d):
23 | init.constant_(m.weight, 1)
24 | init.constant_(m.bias.data, 0.0)
25 |
26 |
27 | def initialize_weights_xavier(net_l, scale=1):
28 | if not isinstance(net_l, list):
29 | net_l = [net_l]
30 | for net in net_l:
31 | for m in net.modules():
32 | if isinstance(m, nn.Conv2d):
33 | init.xavier_normal_(m.weight)
34 | m.weight.data *= scale # for residual block
35 | if m.bias is not None:
36 | m.bias.data.zero_()
37 | elif isinstance(m, nn.Linear):
38 | init.xavier_normal_(m.weight)
39 | m.weight.data *= scale
40 | if m.bias is not None:
41 | m.bias.data.zero_()
42 | elif isinstance(m, nn.BatchNorm2d):
43 | init.constant_(m.weight, 1)
44 | init.constant_(m.bias.data, 0.0)
45 |
46 | def sine_init(m):
47 | with torch.no_grad():
48 | if hasattr(m, 'weight'):
49 | num_input = m.weight.size(-1)
50 | # See supplement Sec. 1.5 for discussion of factor 30
51 | m.weight.uniform_(-np.sqrt(6 / num_input) / 30, np.sqrt(6 / num_input) / 30)
52 |
53 |
54 | def first_layer_sine_init(m):
55 | with torch.no_grad():
56 | if hasattr(m, 'weight'):
57 | num_input = m.weight.size(-1)
58 | # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30
59 | m.weight.uniform_(-1 / num_input, 1 / num_input)
60 |
61 | def make_layer(block, n_layers):
62 | layers = []
63 | for _ in range(n_layers):
64 | layers.append(block())
65 | return nn.Sequential(*layers)
66 |
67 |
68 | class ResidualBlock_noBN(nn.Module):
69 | '''Residual block w/o BN
70 | ---Conv-ReLU-Conv-+-
71 | |________________|
72 | '''
73 |
74 | def __init__(self, nf=64):
75 | super(ResidualBlock_noBN, self).__init__()
76 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
77 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
78 |
79 | # initialization
80 | initialize_weights([self.conv1, self.conv2], 0.1)
81 |
82 | def forward(self, x):
83 | identity = x
84 | out = F.relu(self.conv1(x), inplace=True)
85 | out = self.conv2(out)
86 | return identity + out
87 |
88 |
89 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'):
90 | """Warp an image or feature map with optical flow
91 | Args:
92 | x (Tensor): size (N, C, H, W)
93 | flow (Tensor): size (N, H, W, 2), normal value
94 | interp_mode (str): 'nearest' or 'bilinear'
95 | padding_mode (str): 'zeros' or 'border' or 'reflection'
96 |
97 | Returns:
98 | Tensor: warped image or feature map
99 | """
100 | assert x.size()[-2:] == flow.size()[1:3]
101 | B, C, H, W = x.size()
102 | # mesh grid
103 |     grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))  # 'ij' indexing; newer PyTorch expects an explicit indexing= argument
104 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
105 | grid.requires_grad = False
106 | grid = grid.type_as(x)
107 | vgrid = grid + flow
108 | # scale grid to [-1,1]
109 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
110 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
111 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
112 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode)
113 | return output
114 |
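A small usage sketch for the helpers in this file; with an all-zero flow, flow_warp resamples the input in place (shapes are illustrative):

import functools
import torch

trunk = make_layer(functools.partial(ResidualBlock_noBN, nf=64), n_layers=3)
feat = torch.rand(1, 64, 16, 16)
assert trunk(feat).shape == feat.shape   # residual blocks preserve the shape

flow = torch.zeros(1, 16, 16, 2)         # (N, H, W, 2), zero displacement
warped = flow_warp(feat, flow)
# equality is exact only when grid_sample runs with align_corners=True (the old
# PyTorch default); newer versions interpolate slightly under this normalization
print((warped - feat).abs().max())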
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import logging
3 | from models.modules.NAFNet.NAFNet import NAFNet
4 |
5 | import math
6 | logger = logging.getLogger('base')
7 |
8 |
9 | ####################
10 | # define network
11 | ####################
12 | def define_G(opt):
13 | img_channel = 3
14 | width = 32
15 |     enc_blks = [2, 2, 4, 8]
16 |     middle_blk_num = 6
17 |     dec_blks = [2, 2, 2, 2]
18 |
19 | netG = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
20 | enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)
21 |
22 | return netG
23 |
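A hedged usage sketch: define_G ignores its opt argument and always builds this hard-coded NAFNet; assuming the network preserves spatial resolution, a dummy forward pass looks like:

import torch

netG = define_G(opt=None)        # opt is unused in this definition
x = torch.rand(1, 3, 64, 64)     # dummy low-quality input
with torch.no_grad():
    y = netG(x)
print(y.shape)                   # expected: torch.Size([1, 3, 64, 64])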
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mdyao/CSNorm/49bf5a07ac1c58c8d2c221ac86022698c7f1c897/options/__init__.py
--------------------------------------------------------------------------------
/options/options.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import logging
4 | import yaml
5 | from utils.util import OrderedYaml
6 | Loader, Dumper = OrderedYaml()
7 |
8 |
9 | def parse(opt_path, is_train=True):
10 | with open(opt_path, mode='r') as f:
11 | opt = yaml.load(f, Loader=Loader)
12 | # export CUDA_VISIBLE_DEVICES
13 | gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
14 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
15 | print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
16 |
17 | opt['is_train'] = is_train
18 | scale = opt['scale']
19 |
20 | # datasets
21 | for phase, dataset in opt['datasets'].items():
22 | phase = phase.split('_')[0]
23 | dataset['phase'] = phase
24 | dataset['scale'] = scale
25 | is_mat = False
26 | if dataset.get('dataroot_gt', None) is not None:
27 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
28 | if dataset['dataroot_gt'].endswith('mat'):
29 | is_mat = True
30 | # if dataset.get('dataroot_GT_bg', None) is not None:
31 | # dataset['dataroot_GT_bg'] = osp.expanduser(dataset['dataroot_GT_bg'])
32 | if dataset.get('dataroot_lq', None) is not None:
33 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
34 | if dataset['dataroot_lq'].endswith('mat'):
35 | is_mat = True
36 | dataset['data_type'] = 'mat' if is_mat else 'img'
37 | if dataset['mode'].endswith('mc'): # for memcached
38 | dataset['data_type'] = 'mc'
39 | dataset['mode'] = dataset['mode'].replace('_mc', '')
40 |
41 | # path
42 |     for key, path in opt['path'].items():
43 |         if path and key != 'strict_load':  # key is always present while iterating items
44 | opt['path'][key] = osp.expanduser(path)
45 |
46 |     if opt['path']['root'] is None:
47 | opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
48 |
49 | if is_train:
50 | experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name'])
51 | opt['path']['experiments_root'] = experiments_root
52 | opt['path']['models'] = osp.join(experiments_root, 'models')
53 | opt['path']['training_state'] = osp.join(experiments_root, 'training_state')
54 | opt['path']['log'] = experiments_root
55 | opt['path']['val_images'] = osp.join(experiments_root, 'val_images')
56 |
57 | # change some options for debug mode
58 | if 'debug' in opt['name']:
59 | opt['train']['val_freq'] = 8
60 | opt['logger']['print_freq'] = 1
61 | opt['logger']['save_checkpoint_freq'] = 8
62 | else: # test
63 | results_root = osp.join(opt['path']['root'], 'results', opt['name'])
64 | opt['path']['results_root'] = results_root
65 | opt['path']['log'] = results_root
66 | opt['network_G']['scale'] = scale
67 |
68 | return opt
69 |
70 |
71 | def dict2str(opt, indent_l=1):
72 | '''dict to string for logger'''
73 | msg = ''
74 | for k, v in opt.items():
75 | if isinstance(v, dict):
76 | msg += ' ' * (indent_l * 2) + k + ':[\n'
77 | msg += dict2str(v, indent_l + 1)
78 | msg += ' ' * (indent_l * 2) + ']\n'
79 | else:
80 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
81 | return msg
82 |
83 |
84 | class NoneDict(dict):
85 | def __missing__(self, key):
86 | return None
87 |
88 |
89 | # convert to NoneDict, which return None for missing key.
90 | def dict_to_nonedict(opt):
91 | if isinstance(opt, dict):
92 | new_opt = dict()
93 | for key, sub_opt in opt.items():
94 | new_opt[key] = dict_to_nonedict(sub_opt)
95 | return NoneDict(**new_opt)
96 | elif isinstance(opt, list):
97 | return [dict_to_nonedict(sub_opt) for sub_opt in opt]
98 | else:
99 | return opt
100 |
101 |
102 | def check_resume(opt, resume_iter):
103 | '''Check resume states and pretrain_model paths'''
104 | logger = logging.getLogger('base')
105 | if opt['path']['resume_state']:
106 | if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get(
107 | 'pretrain_model_D', None) is not None:
108 | logger.warning('pretrain_model path will be ignored when resuming training.')
109 |
110 | opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
111 | '{}_G.pth'.format(resume_iter))
112 | logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
113 | if 'gan' in opt['model']:
114 | opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
115 | '{}_D.pth'.format(resume_iter))
116 | logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
117 |
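A self-contained check of the two helpers above: dict_to_nonedict makes missing keys return None instead of raising, and dict2str renders nesting with indentation:

opt = dict_to_nonedict({'train': {'lr_G': 1e-4}, 'gpu_ids': [0]})
assert opt['train']['lr_G'] == 1e-4
assert opt['missing_key'] is None            # NoneDict returns None for absent keys
print(dict2str({'train': {'lr_G': 1e-4}}))   # indented key: value lines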
--------------------------------------------------------------------------------
/options/test/test.yml:
--------------------------------------------------------------------------------
1 |
2 | #### general settings
3 |
4 | name: test
5 | use_tb_logger: False
6 | model: CSNorm
7 | scale: 2
8 | gpu_ids: [0]
9 |
10 | #### datasets
11 |
12 | datasets:
13 | val:
14 | name: data_val
15 | mode: JSH_val
16 | # dataroot_gt: './data/example' # path to validation Clean images
17 | # dataroot_lq: './data/example' # path to validation Noisy images
18 | dataroot_gt: './README' # path to validation Clean images
19 | dataroot_lq: './README' # path to validation Noisy images
20 |
21 | #### network structures
22 |
23 | network_G:
24 | which_model_G:
25 | subnet_type: Resnet
26 | in_nc: 3
27 | out_nc: 3
28 | block_num: [8, 8]
29 | scale: 2
30 | init: xavier
31 |
32 |
33 | #### path
34 |
35 | path:
36 | root: ./
37 | pretrain_model_G: ./models/ckpts/NAF_LOL.pth
38 | strict_load: true
39 | resume_state: ~
40 |
41 |
42 | #### training settings: learning rate scheme, loss
43 |
44 | train:
45 | lr_G: !!float 1e-4
46 | beta1: 0.9
47 | beta2: 0.999
48 | niter: 600000
49 | warmup_iter: -1 # no warm up
50 |
51 | lr_scheme: MultiStepLR
52 | lr_steps: [5000, 10000, 15000, 30000, 500000]
53 | lr_gamma: 0.5
54 |
55 | pixel_criterion_forw: l2
56 | pixel_criterion_back: l1
57 | pixel_criterion_hist: l2
58 |
59 | manual_seed: 9
60 |
61 | val_freq: !!float 2000
62 |
63 | vgg16_model:
64 |
65 | lambda_fit_forw: 10
66 | lambda_vgg_forw: 0.
67 | lambda_structure_forw: 1
68 | lambda_orth_forw: 1
69 |
70 | lambda_rec_back: 1
71 | lambda_structure_back: 1
72 | lambda_orth_back: 1
73 |
74 | weight_decay_G: !!float 1e-8
75 | gradient_clipping: 10
76 |
77 |
78 | #### logger
79 |
80 | logger:
81 | print_freq: 500
82 | save_checkpoint_freq: !!float 5000
83 |
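For reference, options.parse consumes this file roughly as follows (a minimal sketch of the raw load, using the OrderedYaml loader from utils.util):

import yaml
from utils.util import OrderedYaml

Loader, Dumper = OrderedYaml()
with open('options/test/test.yml') as f:
    opt = yaml.load(f, Loader=Loader)
print(opt['model'], opt['scale'])   # -> CSNorm 2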
--------------------------------------------------------------------------------
/options/train/train_InvDN.yml:
--------------------------------------------------------------------------------
1 |
2 | #### general settings
3 |
4 | name: CSNorm_log
5 | use_tb_logger: False
6 | model: CSNorm
7 | gpu_ids: [0]
8 |
9 | #### datasets
10 |
11 | datasets:
12 | train:
13 | name: data_train
14 | mode: JSH_train
15 | dataroot_gt: 'TRAINDATA/GT' # path to training Clean images
16 | dataroot_lq: 'TRAINDATA/LQ' # path to training Noisy images
17 |
18 | use_shuffle: true
19 | n_workers: 4 # per GPU
20 | batch_size: 4
21 | GT_size: 256
22 | use_flip: true
23 | use_rot: true
24 | color: RGB
25 |
26 | val:
27 | name: data_val
28 | mode: JSH_val
29 |     dataroot_gt: './VALIDATA/GT' # path to validation Clean images
30 |     dataroot_lq: './VALIDATA/LQ' # path to validation Noisy images
31 |
32 | #### network structures
33 |
34 | network_G:
35 | which_model_G:
36 | subnet_type: Resnet
37 | in_nc: 3
38 | out_nc: 3
39 | block_num: [8, 8]
40 | scale: 2
41 | init: xavier
42 |
43 |
44 | #### path
45 |
46 | path:
47 | root: ./
48 | pretrain_model_G:
49 | strict_load: true
50 | resume_state: ~
51 |
52 |
53 | #### training settings: learning rate scheme, loss
54 |
55 | train:
56 | lr_G: !!float 2e-4
57 | beta1: 0.9
58 | beta2: 0.999
59 | niter: 600000
60 | warmup_iter: -1 # no warm up
61 |
62 | lr_scheme: MultiStepLR
63 | lr_steps: [50000, 80000, 100000, 200000, 500000]
64 | lr_gamma: 0.5
65 |
66 | pixel_criterion_forw: l2
67 | pixel_criterion_back: l2
68 | pixel_criterion_hist: l2
69 |
70 | manual_seed: 9
71 |
72 | val_freq: !!float 100
73 |
74 | vgg16_model:
75 |
76 | lambda_fit_forw: 10
77 | lambda_vgg_forw: 0.
78 | lambda_structure_forw: 1
79 | lambda_orth_forw: 1
80 |
81 | lambda_rec_back: 1
82 | lambda_structure_back: 1
83 | lambda_orth_back: 1
84 |
85 | weight_decay_G: !!float 1e-8
86 | gradient_clipping: 10
87 |
88 |
89 | #### logger
90 |
91 | logger:
92 | print_freq: 40
93 | save_checkpoint_freq: !!float 5000
94 |
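A quick sketch of what the MultiStepLR settings above imply: lr_G starts at 2e-4 and is halved (lr_gamma: 0.5) at each milestone in lr_steps:

lr = 2e-4
for step in [50000, 80000, 100000, 200000, 500000]:
    lr *= 0.5
    print('after iter {}: lr = {:.2e}'.format(step, lr))
# final value: 2e-4 * 0.5**5 = 6.25e-06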
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import argparse
4 | import random
5 | import logging
6 |
7 | import torch
8 | import torch.distributed as dist
9 | import torch.multiprocessing as mp
10 | from data.data_sampler import DistIterSampler
11 |
12 | import options.options as option
13 | from utils import util
14 | from data import create_dataloader, create_dataset
15 | from models import create_model
16 | import numpy as np
17 |
18 |
19 | def init_dist(backend='nccl', **kwargs):
20 | ''' initialization for distributed training'''
21 | # if mp.get_start_method(allow_none=True) is None:
22 | if mp.get_start_method(allow_none=True) != 'spawn':
23 | mp.set_start_method('spawn')
24 | rank = int(os.environ['RANK'])
25 | num_gpus = torch.cuda.device_count()
26 | torch.cuda.set_device(rank % num_gpus)
27 | dist.init_process_group(backend=backend, **kwargs)
28 |
29 |
30 | def main():
31 | #### options
32 | parser = argparse.ArgumentParser()
33 |     parser.add_argument('-opt', type=str, help='Path to option YAML file.')
34 | parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
35 | help='job launcher')
36 | parser.add_argument('--local_rank', type=int, default=0)
37 | args = parser.parse_args()
38 |     opt = option.parse(args.opt, is_train=True)  # is_train=True so experiment dirs (e.g. val_images) get created
39 |
40 | #### distributed training settings
41 | if args.launcher == 'none': # disabled distributed training
42 | opt['dist'] = False
43 | rank = -1
44 | print('Disabled distributed training.')
45 | else:
46 | opt['dist'] = True
47 | init_dist()
48 | world_size = torch.distributed.get_world_size()
49 | rank = torch.distributed.get_rank()
50 |
51 | #### loading resume state if exists
52 | if opt['path'].get('resume_state', None):
53 | # distributed resuming: all load into default GPU
54 | device_id = torch.cuda.current_device()
55 | resume_state = torch.load(opt['path']['resume_state'],
56 | map_location=lambda storage, loc: storage.cuda(device_id))
57 | option.check_resume(opt, resume_state['iter']) # check resume options
58 | else:
59 | resume_state = None
60 |
61 | #### mkdir and loggers
62 | if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
63 | if resume_state is None:
64 | util.mkdir_and_rename(
65 | opt['path']['experiments_root']) # rename experiment folder if exists
66 | util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
67 | and 'pretrain_model' not in key and 'resume' not in key))
68 |
69 | # config loggers. Before it, the log will not work
70 | util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
71 | screen=True, tofile=True)
72 | util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
73 | screen=True, tofile=True)
74 | logger = logging.getLogger('base')
75 | logger.info(option.dict2str(opt))
76 | # tensorboard logger
77 | if opt['use_tb_logger'] and 'debug' not in opt['name']:
78 |             version = float(torch.__version__[0:3])  # fragile: misparses versions like '1.10'
79 | if version >= 1.1: # PyTorch 1.1
80 | from tensorboardX import SummaryWriter
81 | else:
82 | logger.info(
83 | 'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
84 | from tensorboardX import SummaryWriter
85 | tb_logger = SummaryWriter(log_dir='./tb_logger/' + opt['name'])
86 | else:
87 | util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
88 | logger = logging.getLogger('base')
89 |
90 | # convert to NoneDict, which returns None for missing keys
91 | opt = option.dict_to_nonedict(opt)
92 |
93 | #### random seed
94 | seed = opt['train']['manual_seed']
95 | if seed is None:
96 | seed = random.randint(1, 10000)
97 | if rank <= 0:
98 | logger.info('Random seed: {}'.format(seed))
99 | util.set_random_seed(seed)
100 |
101 | torch.backends.cudnn.benchmark = True
102 | # torch.backends.cudnn.deterministic = True
103 |
104 | #### create train and val dataloader
105 |     dataset_ratio = 200  # unused here: only the val loader is built
106 | for phase, dataset_opt in opt['datasets'].items():
107 | if phase == 'val':
108 | val_set = create_dataset(dataset_opt)
109 | val_loader = create_dataloader(val_set, dataset_opt, opt, None)
110 | if rank <= 0:
111 | logger.info('Number of val images in [{:s}]: {:d}'.format(
112 | dataset_opt['name'], len(val_set)))
113 | else:
114 | raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
115 |
116 | #### create model
117 | model = create_model(opt)
118 |
119 | # #### resume training
120 | # if resume_state:
121 | # logger.info('Resuming training from epoch: {}, iter: {}.'.format(
122 | # resume_state['epoch'], resume_state['iter']))
123 | #
124 | # start_epoch = resume_state['epoch']
125 | # current_step = resume_state['iter']
126 | # model.resume_training(resume_state) # handle optimizers and schedulers
127 | # else:
128 | # current_step = 0
129 | # start_epoch = 0
130 |
131 | #### test
132 | avg_psnr = 0.0
133 | idx = 0
134 | for val_data in val_loader:
135 | idx += 1
136 | # img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
137 | # img_dir = os.path.join(opt['path']['val_images'], img_name)
138 | # util.mkdir(img_dir)
139 | model.feed_data_test(val_data)
140 | model.test()
141 | visuals = model.get_current_visuals()
142 | img_input = visuals['img_input'].numpy()
143 | img_pred = visuals['img_pred'].numpy()
144 | img_gt = visuals['img_gt'].numpy()
145 |
146 |         ########################## save images for visualization ##########################
147 |         img_input = img_input[::-1, :, :]  # reverse channel order
148 |         img_pred1 = img_pred[::-1, :, :]
149 |         img_gt1 = img_gt[::-1, :, :]
150 |
151 |         img_input = img_input.transpose(1, 2, 0)  # CHW -> HWC for PIL
152 |         img_pred1 = img_pred1.transpose(1, 2, 0)
153 |         img_gt1 = img_gt1.transpose(1, 2, 0)
154 |
155 |         from PIL import Image  # local import; cached after the first iteration
156 |
157 | img_pred1 = np.clip(img_pred1,0,1)
158 | Image.fromarray((img_pred1*255).astype(np.uint8)).save(os.path.join(opt['path']['val_images'], '%03d.png'%idx))
159 |
160 | img_input = np.clip(img_input,0,1)
161 | Image.fromarray((img_input*255).astype(np.uint8)).save(os.path.join(opt['path']['val_images'], '%03d_i.png'%idx))
162 |
163 | img_gt1 = np.clip(img_gt1,0,1)
164 | Image.fromarray((img_gt1*255).astype(np.uint8)).save(os.path.join(opt['path']['val_images'], '%03d_t.png'%idx))
165 |
166 |
167 | def compute_psnr(img_orig, img_out, peak):
168 | mse = np.mean(np.square(img_orig - img_out))
169 |             psnr = 10 * np.log10(peak * peak / mse)  # assumes mse > 0
170 | return psnr
171 | curr_psnr = compute_psnr(img_pred, img_gt, 1)
172 | avg_psnr += curr_psnr
173 | print('idx', idx, curr_psnr)
174 |
175 | avg_psnr = avg_psnr / idx
176 |
177 | logger.info('# Validation # PSNR: {:.4e}.'.format(avg_psnr))
178 |
179 |
180 |
181 | if __name__ == '__main__':
182 | main()
183 |
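A numeric sanity check of the PSNR arithmetic used in the loop above (restated standalone; note it assumes mse > 0):

import numpy as np

def compute_psnr(img_orig, img_out, peak):
    mse = np.mean(np.square(img_orig - img_out))
    return 10 * np.log10(peak * peak / mse)

a = np.zeros((4, 4))
b = np.full((4, 4), 0.1)      # uniform error of 0.1 against a peak of 1
print(compute_psnr(a, b, 1))  # mse = 0.01 -> 10 * log10(1 / 0.01) = 20.0 dB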
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import argparse
4 | import random
5 | import logging
6 |
7 | import torch
8 | import torch.distributed as dist
9 | import torch.multiprocessing as mp
10 | from data.data_sampler import DistIterSampler
11 |
12 | import options.options as option
13 | from utils import util
14 | from data import create_dataloader, create_dataset
15 | from models import create_model
16 | import numpy as np
17 |
18 |
19 | def init_dist(backend='nccl', **kwargs):
20 | ''' initialization for distributed training'''
21 | # if mp.get_start_method(allow_none=True) is None:
22 | if mp.get_start_method(allow_none=True) != 'spawn':
23 | mp.set_start_method('spawn')
24 | rank = int(os.environ['RANK'])
25 | num_gpus = torch.cuda.device_count()
26 | torch.cuda.set_device(rank % num_gpus)
27 | dist.init_process_group(backend=backend, **kwargs)
28 |
29 |
30 | def main():
31 | #### options
32 | parser = argparse.ArgumentParser()
33 |     parser.add_argument('-opt', type=str, help='Path to option YAML file.')
34 | parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
35 | help='job launcher')
36 | parser.add_argument('--local_rank', type=int, default=0)
37 | args = parser.parse_args()
38 | opt = option.parse(args.opt, is_train=True)
39 |
40 | #### distributed training settings
41 | if args.launcher == 'none': # disabled distributed training
42 | opt['dist'] = False
43 | rank = -1
44 | print('Disabled distributed training.')
45 | else:
46 | opt['dist'] = True
47 | init_dist()
48 | world_size = torch.distributed.get_world_size()
49 | rank = torch.distributed.get_rank()
50 |
51 | #### loading resume state if exists
52 | if opt['path'].get('resume_state', None):
53 | # distributed resuming: all load into default GPU
54 | device_id = torch.cuda.current_device()
55 | resume_state = torch.load(opt['path']['resume_state'],
56 | map_location=lambda storage, loc: storage.cuda(device_id))
57 | option.check_resume(opt, resume_state['iter']) # check resume options
58 | else:
59 | resume_state = None
60 |
61 | #### mkdir and loggers
62 | if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
63 | if resume_state is None:
64 | util.mkdir_and_rename(
65 | opt['path']['experiments_root']) # rename experiment folder if exists
66 | util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
67 | and 'pretrain_model' not in key and 'resume' not in key))
68 |
69 | # config loggers. Before it, the log will not work
70 | util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
71 | screen=True, tofile=True)
72 | util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
73 | screen=True, tofile=True)
74 | logger = logging.getLogger('base')
75 | logger.info(option.dict2str(opt))
76 | # tensorboard logger
77 | if opt['use_tb_logger'] and 'debug' not in opt['name']:
78 |             version = float(torch.__version__[0:3])  # fragile: misparses versions like '1.10'
79 | if version >= 1.1: # PyTorch 1.1
80 | from tensorboardX import SummaryWriter
81 | else:
82 | logger.info(
83 | 'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
84 | from tensorboardX import SummaryWriter
85 | tb_logger = SummaryWriter(log_dir='./tb_logger/' + opt['name'])
86 | else:
87 | util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
88 | logger = logging.getLogger('base')
89 |
90 | # convert to NoneDict, which returns None for missing keys
91 | opt = option.dict_to_nonedict(opt)
92 |
93 | #### random seed
94 | seed = opt['train']['manual_seed']
95 | if seed is None:
96 | seed = random.randint(1, 10000)
97 | if rank <= 0:
98 | logger.info('Random seed: {}'.format(seed))
99 | util.set_random_seed(seed)
100 |
101 | torch.backends.cudnn.benchmark = True
102 | # torch.backends.cudnn.deterministic = True
103 |
104 | #### create train and val dataloader
105 |     dataset_ratio = 1  # virtually enlarges each epoch (1 = no enlargement)
106 | for phase, dataset_opt in opt['datasets'].items():
107 | if phase == 'train':
108 | train_set = create_dataset(dataset_opt)
109 | train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
110 |
111 | total_iters = int(opt['train']['niter'])
112 | total_epochs = int(math.ceil(total_iters / train_size))
113 | if opt['dist']:
114 | train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
115 | total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
116 | else:
117 | train_sampler = None
118 | train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
119 | if rank <= 0:
120 | logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
121 | len(train_set), train_size))
122 | logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
123 | total_epochs, total_iters))
124 | elif phase == 'val':
125 | val_set = create_dataset(dataset_opt)
126 | val_loader = create_dataloader(val_set, dataset_opt, opt, None)
127 | if rank <= 0:
128 | logger.info('Number of val images in [{:s}]: {:d}'.format(
129 | dataset_opt['name'], len(val_set)))
130 | else:
131 | raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
132 | assert train_loader is not None
133 |
134 | #### create model
135 | model = create_model(opt)
136 |
137 | #### resume training
138 | if resume_state:
139 | logger.info('Resuming training from epoch: {}, iter: {}.'.format(
140 | resume_state['epoch'], resume_state['iter']))
141 |
142 | start_epoch = resume_state['epoch']
143 | current_step = resume_state['iter']
144 | model.resume_training(resume_state) # handle optimizers and schedulers
145 | else:
146 | current_step = 0
147 | start_epoch = 0
148 |
149 | val_freq = opt['train']['val_freq']
150 | #### training
151 | logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
152 | for epoch in range(start_epoch, total_epochs + 1):
153 | if opt['dist']:
154 | train_sampler.set_epoch(epoch)
155 | for _, train_data in enumerate(train_loader):
156 | current_step += 1
157 |
158 | if current_step > total_iters:
159 | break
160 | #### training
161 | model.feed_data(train_data)
162 | model.optimize_parameters(current_step)
163 |
164 | #### update learning rate
165 | model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
166 |
167 | #### log
168 | if current_step % opt['logger']['print_freq'] == 0:
169 | logs = model.get_current_log()
170 |                 message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
171 |                     epoch, current_step, model.get_current_learning_rate())
172 | for k, v in logs.items():
173 | message += '{:s}: {:.4e} '.format(k, v)
174 | # tensorboard logger
175 | if opt['use_tb_logger'] and 'debug' not in opt['name']:
176 | if rank <= 0:
177 | tb_logger.add_scalar(k, v, current_step)
178 | if rank <= 0:
179 | logger.info(message)
180 |
181 | # validation
182 | if current_step % val_freq == 0 and rank <= 0:
183 | avg_psnr = 0.0
184 | idx = 0
185 | for val_data in val_loader:
186 | idx += 1
187 | model.feed_data_test(val_data)
188 | model.test()
189 | visuals = model.get_current_visuals()
190 | img_pred = visuals['img_pred'].numpy()
191 | img_gt = visuals['img_gt'].numpy()
192 |
193 | def compute_psnr(img_orig, img_out, peak):
194 | mse = np.mean(np.square(img_orig - img_out))
195 |                         psnr = 10 * np.log10(peak * peak / mse)  # assumes mse > 0
196 | return psnr
197 | curr_psnr = compute_psnr(img_pred, img_gt, 1)
198 | avg_psnr += curr_psnr
199 | print('idx', idx, curr_psnr)
200 |
201 | avg_psnr = avg_psnr / idx
202 |
203 | # log
204 | logger.info('# Validation # PSNR: {:.4e}.'.format(avg_psnr))
205 | logger_val = logging.getLogger('val') # validation logger
206 |                 logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}.'.format(
207 |                     epoch, current_step, avg_psnr))
208 |
209 | #### save models and training states
210 | if current_step % opt['logger']['save_checkpoint_freq'] == 0:
211 | if rank <= 0:
212 | logger.info('Saving models and training states.')
213 | model.save(current_step)
214 | # model.save_training_state(epoch, current_step)
215 |
216 | if rank <= 0:
217 | logger.info('Saving the final model.')
218 | model.save('latest')
219 | logger.info('End of training.')
220 |
221 |
222 | if __name__ == '__main__':
223 | main()
224 |
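A worked example of the epoch bookkeeping above, under assumed sizes: with 1,000 training images (an assumption), batch_size 4 from the train yml, and niter 600000:

import math

len_train_set, batch_size, niter = 1000, 4, 600000        # 1,000 images is an assumption
train_size = int(math.ceil(len_train_set / batch_size))   # 250 iterations per epoch
total_epochs = int(math.ceil(niter / train_size))         # 2400 epochs
print(train_size, total_epochs)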
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mdyao/CSNorm/49bf5a07ac1c58c8d2c221ac86022698c7f1c897/utils/__init__.py
--------------------------------------------------------------------------------
/utils/pytorch_ssim/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 |
7 | def gaussian(window_size, sigma):
8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9 | return gauss/gauss.sum()
10 |
11 | def create_window(window_size, channel):
12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
15 | return window
16 |
17 | def _ssimmap(img1, window, window_size, channel, size_average = True):
18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
19 | mu1_sq = mu1.pow(2)
20 |
21 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
22 | # sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
23 | # sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
24 |
25 | C1 = 0.01**2
26 | C2 = 0.03**2
27 |
28 | # ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
29 | feat_map = torch.cat((mu1, C1/(mu1+C1), sigma1_sq, C2/(sigma1_sq+C2)), 1)
30 |
31 | return feat_map
32 |
33 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
34 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
35 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
36 |
37 | mu1_sq = mu1.pow(2)
38 | mu2_sq = mu2.pow(2)
39 | mu1_mu2 = mu1*mu2
40 |
41 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
42 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
43 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
44 |
45 | C1 = 0.01**2
46 | C2 = 0.03**2
47 |
48 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
49 |
50 |     if size_average:
51 |         return (-1)*ssim_map.mean()  # negated: this module returns -SSIM for loss use
52 |     else:
53 |         return (-1)*ssim_map.mean(1).mean(1).mean(1)
54 |
55 | class SSIMMap(torch.nn.Module):
56 | def __init__(self, window_size = 11, size_average = True):
57 | super(SSIMMap, self).__init__()
58 | self.window_size = window_size
59 | self.size_average = size_average
60 | self.channel = 1
61 | self.window = create_window(window_size, self.channel)
62 |
63 | def forward(self, img1):
64 | (_, channel, _, _) = img1.size()
65 |
66 | if channel == self.channel and self.window.data.type() == img1.data.type():
67 | window = self.window
68 | else:
69 | window = create_window(self.window_size, channel)
70 |
71 | if img1.is_cuda:
72 | window = window.cuda(img1.get_device())
73 | window = window.type_as(img1)
74 |
75 | self.window = window
76 | self.channel = channel
77 |
78 | return _ssimmap(img1, window, self.window_size, channel, self.size_average)
79 |
80 | class SSIM(torch.nn.Module):
81 | def __init__(self, window_size = 11, size_average = True):
82 | super(SSIM, self).__init__()
83 | self.window_size = window_size
84 | self.size_average = size_average
85 | self.channel = 1
86 | self.window = create_window(window_size, self.channel)
87 |
88 | def forward(self, img1, img2):
89 | (_, channel, _, _) = img1.size()
90 |
91 | if channel == self.channel and self.window.data.type() == img1.data.type():
92 | window = self.window
93 | else:
94 | window = create_window(self.window_size, channel)
95 |
96 | if img1.is_cuda:
97 | window = window.cuda(img1.get_device())
98 | window = window.type_as(img1)
99 |
100 | self.window = window
101 | self.channel = channel
102 |
103 |
104 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
105 |
106 | def ssim(img1, img2, window_size = 11, size_average = True):
107 | (_, channel, _, _) = img1.size()
108 | window = create_window(window_size, channel)
109 |
110 | if img1.is_cuda:
111 | window = window.cuda(img1.get_device())
112 | window = window.type_as(img1)
113 |
114 | return _ssim(img1, img2, window, window_size, channel, size_average)
115 |
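A usage sketch for the functional entry point above; note that _ssim negates the map mean, so this module returns -SSIM (identical inputs give -1):

import torch

x = torch.rand(1, 3, 32, 32)
print(ssim(x, x))                   # -> -1.0 (negated SSIM, suited to minimization)
print(ssim(x, torch.rand_like(x)))  # much closer to 0 for unrelated inputs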
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import math
5 | from datetime import datetime
6 | import random
7 | import logging
8 | from collections import OrderedDict
9 | import numpy as np
10 | import cv2
11 | import torch
12 | from torchvision.utils import make_grid
13 | from shutil import get_terminal_size
14 | import imageio
15 | import yaml
16 | try:
17 | from yaml import CLoader as Loader, CDumper as Dumper
18 | except ImportError:
19 | from yaml import Loader, Dumper
20 |
21 | def rgb2yuv(img):
22 | y = 0.299 * img[:, 0] + 0.587 * img[:, 1] + 0.114 * img[:, 2]
23 | u = -0.169 * img[:, 0] - 0.331 * img[:, 1] + 0.5 * img[:, 2] + 0.5
24 | v = 0.5 * img[:, 0] - 0.419 * img[:, 1] - 0.081 * img[:, 2] + 0.5
25 | out = torch.stack((y, u, v))
26 | out = out.transpose(0, 1)
27 | return out
28 |
29 |
30 | def yuv2rgb(img):
31 | r = img[:, 0] + 1.4075 * (img[:, 2] - 0.5)
32 | g = img[:, 0] - 0.3455 * (img[:, 1] - 0.5) - 0.7169 * (img[:, 2] - 0.5)
33 | b = img[:, 0] + 1.779 * (img[:, 1] - 0.5)
34 | out = torch.stack((r, g, b))
35 | out = out.transpose(0, 1)
36 | return out
37 |
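A quick round-trip check of the two conversions above; the composition is close to identity, up to small rounding in the hard-coded coefficients:

import torch

img = torch.rand(2, 3, 8, 8)     # NCHW RGB in [0, 1]
back = yuv2rgb(rgb2yuv(img))
print((back - img).abs().max())  # small but nonzero: the matrices are rounded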
38 | def OrderedYaml():
39 | '''yaml orderedDict support'''
40 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
41 |
42 | def dict_representer(dumper, data):
43 | return dumper.represent_dict(data.items())
44 |
45 | def dict_constructor(loader, node):
46 | return OrderedDict(loader.construct_pairs(node))
47 |
48 | Dumper.add_representer(OrderedDict, dict_representer)
49 | Loader.add_constructor(_mapping_tag, dict_constructor)
50 | return Loader, Dumper
51 |
52 | def save_results_yuv(pred, index, test_img_dir):
53 | test_pred = np.squeeze(pred)
54 | test_pred = np.clip(test_pred, 0, 1) * 1023
55 | test_pred = np.uint16(test_pred)
56 |
57 | # split image
58 | pred_y = test_pred[:, :, 0]
59 | pred_u = test_pred[:, :, 1]
60 | pred_v = test_pred[:, :, 2]
61 |
62 | # save prediction - must be saved in separate channels due to 16-bit pixel depth
63 | imageio.imwrite(os.path.join(test_img_dir, "{}-y_pred.png".format(str(int(index) + 1).zfill(2))),
64 | pred_y)
65 | imageio.imwrite(os.path.join(test_img_dir, "{}-u_pred.png".format(str(int(index) + 1).zfill(2))),
66 | pred_u)
67 | imageio.imwrite(os.path.join(test_img_dir, "{}-v_pred.png".format(str(int(index) + 1).zfill(2))),
68 | pred_v)
69 |
70 |
71 |
72 | ####################
73 | # miscellaneous
74 | ####################
75 |
76 |
77 | def get_timestamp():
78 | return datetime.now().strftime('%y%m%d-%H%M%S')
79 |
80 |
81 | def mkdir(path):
82 | if not os.path.exists(path):
83 | os.makedirs(path)
84 |
85 |
86 | def mkdirs(paths):
87 | if isinstance(paths, str):
88 | mkdir(paths)
89 | else:
90 | for path in paths:
91 | mkdir(path)
92 |
93 |
94 | def mkdir_and_rename(path):
95 | if os.path.exists(path):
96 | new_name = path + '_archived_' + get_timestamp()
97 | print('Path already exists. Rename it to [{:s}]'.format(new_name))
98 | logger = logging.getLogger('base')
99 | logger.info('Path already exists. Rename it to [{:s}]'.format(new_name))
100 | os.rename(path, new_name)
101 | os.makedirs(path)
102 |
103 |
104 | def set_random_seed(seed):
105 | random.seed(seed)
106 | np.random.seed(seed)
107 | torch.manual_seed(seed)
108 | torch.cuda.manual_seed_all(seed)
109 |
110 |
111 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
112 | '''set up logger'''
113 | lg = logging.getLogger(logger_name)
114 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
115 | datefmt='%y-%m-%d %H:%M:%S')
116 | lg.setLevel(level)
117 | if tofile:
118 | log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp()))
119 | fh = logging.FileHandler(log_file, mode='w')
120 | fh.setFormatter(formatter)
121 | lg.addHandler(fh)
122 | if screen:
123 | sh = logging.StreamHandler()
124 | sh.setFormatter(formatter)
125 | lg.addHandler(sh)
126 |
127 |
128 | ####################
129 | # image convert
130 | ####################
131 |
132 |
133 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
134 | '''
135 | Converts a torch Tensor into an image Numpy array
136 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
137 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
138 | '''
139 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp
140 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
141 | n_dim = tensor.dim()
142 | if n_dim == 4:
143 | n_img = len(tensor)
144 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
145 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
146 | elif n_dim == 3:
147 | img_np = tensor.numpy()
148 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
149 | elif n_dim == 2:
150 | img_np = tensor.numpy()
151 | else:
152 | raise TypeError(
153 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
154 | if out_type == np.uint8:
155 | img_np = (img_np * 255.0).round()
156 |         # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
157 | return img_np.astype(out_type)
158 |
159 | def tensor2img_Real(tensor, out_type=np.uint8, min_max=(0, 1)):
160 | '''
161 | Converts a torch Tensor into an image Numpy array
162 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
163 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
164 | '''
165 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp
166 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
167 | n_dim = tensor.dim()
168 | if n_dim == 4:
169 | # n_img = len(tensor)
170 | # img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
171 | img_np = tensor.numpy()
172 | # img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
173 | elif n_dim == 3:
174 | img_np = tensor.numpy()
175 | # img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
176 | elif n_dim == 2:
177 | img_np = tensor.numpy()
178 | else:
179 | raise TypeError(
180 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
181 | if out_type == np.uint8:
182 | img_np = (img_np * 255.0).round()
183 |         # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
184 | return img_np.astype(out_type)
185 |
186 | def save_img(img, img_path, mode='RGB'):
187 | cv2.imwrite(img_path, img)
188 |
189 |
190 | ####################
191 | # metric
192 | ####################
193 |
194 |
195 | def calculate_psnr(img1, img2):
196 | # img1 and img2 have range [0, 255]
197 | img1 = img1.astype(np.float64)
198 | img2 = img2.astype(np.float64)
199 | mse = np.mean((img1 - img2)**2)
200 | if mse == 0:
201 | return float('inf')
202 | return 20 * math.log10(255.0 / math.sqrt(mse))
203 |
204 |
205 | def ssim(img1, img2):
206 | C1 = (0.01 * 255)**2
207 | C2 = (0.03 * 255)**2
208 |
209 | img1 = img1.astype(np.float64)
210 | img2 = img2.astype(np.float64)
211 | kernel = cv2.getGaussianKernel(11, 1.5)
212 | window = np.outer(kernel, kernel.transpose())
213 |
214 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
215 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
216 | mu1_sq = mu1**2
217 | mu2_sq = mu2**2
218 | mu1_mu2 = mu1 * mu2
219 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
220 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
221 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
222 |
223 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
224 | (sigma1_sq + sigma2_sq + C2))
225 | return ssim_map.mean()
226 |
227 |
228 | def calculate_ssim(img1, img2):
229 | '''calculate SSIM
230 | the same outputs as MATLAB's
231 | img1, img2: [0, 255]
232 | '''
233 | if not img1.shape == img2.shape:
234 | raise ValueError('Input images must have the same dimensions.')
235 | if img1.ndim == 2:
236 | return ssim(img1, img2)
237 | elif img1.ndim == 3:
238 | if img1.shape[2] == 3:
239 | ssims = []
240 |             for i in range(3):
241 |                 ssims.append(ssim(img1[..., i], img2[..., i]))  # SSIM per channel, averaged below
242 | return np.array(ssims).mean()
243 | elif img1.shape[2] == 1:
244 | return ssim(np.squeeze(img1), np.squeeze(img2))
245 | else:
246 | raise ValueError('Wrong input image dimensions.')
247 |
248 |
249 | class ProgressBar(object):
250 | '''A progress bar which can print the progress
251 | modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
252 | '''
253 |
254 | def __init__(self, task_num=0, bar_width=50, start=True):
255 | self.task_num = task_num
256 | max_bar_width = self._get_max_bar_width()
257 | self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width)
258 | self.completed = 0
259 | if start:
260 | self.start()
261 |
262 | def _get_max_bar_width(self):
263 | terminal_width, _ = get_terminal_size()
264 | max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
265 | if max_bar_width < 10:
266 |             print('terminal width is too small ({}), please consider widening the terminal for better '
267 |                   'progressbar visualization'.format(terminal_width))
268 | max_bar_width = 10
269 | return max_bar_width
270 |
271 | def start(self):
272 | if self.task_num > 0:
273 | sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(
274 | ' ' * self.bar_width, self.task_num, 'Start...'))
275 | else:
276 | sys.stdout.write('completed: 0, elapsed: 0s')
277 | sys.stdout.flush()
278 | self.start_time = time.time()
279 |
280 | def update(self, msg='In progress...'):
281 | self.completed += 1
282 | elapsed = time.time() - self.start_time
283 | fps = self.completed / elapsed
284 | if self.task_num > 0:
285 | percentage = self.completed / float(self.task_num)
286 | eta = int(elapsed * (1 - percentage) / percentage + 0.5)
287 | mark_width = int(self.bar_width * percentage)
288 | bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
289 | sys.stdout.write('\033[2F') # cursor up 2 lines
290 | sys.stdout.write('\033[J') # clean the output (remove extra chars since last display)
291 | sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format(
292 | bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg))
293 | else:
294 | sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
295 | self.completed, int(elapsed + 0.5), fps))
296 | sys.stdout.flush()
297 |
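Finally, a usage sketch for ProgressBar (the task count and sleep are illustrative stand-ins for real work):

import time

pb = ProgressBar(task_num=5)
for i in range(5):
    time.sleep(0.1)
    pb.update('processing item {}'.format(i))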
--------------------------------------------------------------------------------