├── .gitignore
├── LICENSE.txt
├── VERSION
├── assets
│ ├── gopro.gif
│ ├── memory_comparison.png
│ ├── nightrain30.gif
│ ├── raindrop.gif
│ ├── snowwww.gif
│ └── turtle.png
├── basicsr
│ ├── data
│ │ ├── __init__.py
│ │ ├── data_sampler.py
│ │ ├── data_util.py
│ │ ├── prefetch_dataloader.py
│ │ ├── transforms.py
│ │ ├── video_image_dataset.py
│ │ └── video_super_image_dataset.py
│ ├── inference.py
│ ├── inference_no_ground_truth.py
│ ├── loss
│ │ └── __init__.py
│ ├── metrics
│ │ ├── __init__.py
│ │ ├── metric_util.py
│ │ └── psnr_ssim.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── archs
│ │ │ ├── turtle_arch.py
│ │ │ ├── turtle_t1_arch.py
│ │ │ └── turtlesuper_t1_arch.py
│ │ ├── base_model.py
│ │ ├── losses
│ │ │ ├── __init__.py
│ │ │ ├── loss_util.py
│ │ │ └── losses.py
│ │ ├── lr_scheduler.py
│ │ └── video_restoration_model.py
│ ├── train.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── create_lmdb.py
│ │ ├── dist_util.py
│ │ ├── download_util.py
│ │ ├── face_util.py
│ │ ├── file_client.py
│ │ ├── flow_util.py
│ │ ├── img_util.py
│ │ ├── lmdb_util.py
│ │ ├── logger.py
│ │ ├── matlab_functions.py
│ │ ├── misc.py
│ │ ├── options.py
│ │ ├── util.py
│ │ └── utils_video.py
│ └── version.py
├── cog.yaml
├── make_video.py
├── options
│ ├── Turtle_Deblur_Gopro.yml
│ ├── Turtle_Denoise_Davis.yml
│ ├── Turtle_Derain.yml
│ ├── Turtle_Derain_VRDS.yml
│ ├── Turtle_Desnow.yml
│ └── Turtle_SR_MVSR.yml
├── readme.md
├── requirements.txt
├── setup.cfg
├── setup.py
└── video_to_frames.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea/*
3 | experiments
4 | logs/
5 | *results*
6 | *__pycache__*
7 | *.sh
8 | datasets
9 | basicsr.egg-info
10 | tb_logger
11 | placeholder_datasetDAVIS
12 | placeholder_dataset
13 | outputs
14 | options
15 | basicsr/inference_outputs
16 | inference_outputs
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Huawei Technologies Co., Ltd.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/VERSION:
--------------------------------------------------------------------------------
1 | 1.0.0
2 |
--------------------------------------------------------------------------------
/assets/gopro.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ascend-Research/Turtle/dea2cd990f6d6f3b21e68872ae81a76736a7c205/assets/gopro.gif
--------------------------------------------------------------------------------
/assets/memory_comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ascend-Research/Turtle/dea2cd990f6d6f3b21e68872ae81a76736a7c205/assets/memory_comparison.png
--------------------------------------------------------------------------------
/assets/nightrain30.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ascend-Research/Turtle/dea2cd990f6d6f3b21e68872ae81a76736a7c205/assets/nightrain30.gif
--------------------------------------------------------------------------------
/assets/raindrop.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ascend-Research/Turtle/dea2cd990f6d6f3b21e68872ae81a76736a7c205/assets/raindrop.gif
--------------------------------------------------------------------------------
/assets/snowwww.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ascend-Research/Turtle/dea2cd990f6d6f3b21e68872ae81a76736a7c205/assets/snowwww.gif
--------------------------------------------------------------------------------
/assets/turtle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ascend-Research/Turtle/dea2cd990f6d6f3b21e68872ae81a76736a7c205/assets/turtle.png
--------------------------------------------------------------------------------
/basicsr/data/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
3 | # Copyright 2018-2020 BasicSR Authors
4 | # ------------------------------------------------------------------------
5 |
6 | import importlib
7 | import numpy as np
8 | import random
9 | import torch
10 | import torch.utils.data
11 | from functools import partial
12 | from os import path as osp
13 |
14 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader
15 | from basicsr.utils import get_root_logger, scandir
16 | from basicsr.utils.dist_util import get_dist_info
17 |
18 | __all__ = ['create_dataset', 'create_dataloader']
19 |
20 | # automatically scan and import dataset modules
21 | # scan all the files under the data folder with '_dataset' in file names
22 | data_folder = osp.dirname(osp.abspath(__file__))
23 | dataset_filenames = [
24 | osp.splitext(osp.basename(v))[0] for v in scandir(data_folder)
25 | if v.endswith('_dataset.py')
26 | ]
27 | # import all the dataset modules
28 | _dataset_modules = [
29 | importlib.import_module(f'basicsr.data.{file_name}')
30 | for file_name in dataset_filenames
31 | ]
32 |
33 |
34 | def create_dataset(dataset_opt):
35 | """Create dataset.
36 |
37 | Args:
38 |         dataset_opt (dict): Configuration for dataset. It contains:
39 | name (str): Dataset name.
40 | type (str): Dataset type.
41 | """
42 | dataset_type = dataset_opt['type']
43 |
44 | # dynamic instantiation
45 | for module in _dataset_modules:
46 | dataset_cls = getattr(module, dataset_type, None)
47 | if dataset_cls is not None:
48 | break
49 | if dataset_cls is None:
50 | raise ValueError(f'Dataset {dataset_type} is not found.')
51 |
52 | dataset = dataset_cls(dataset_opt)
53 |
54 | logger = get_root_logger()
55 | logger.info(
56 | f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} '
57 | 'is created.')
58 | return dataset
59 |
60 |
61 | def create_dataloader(dataset,
62 | dataset_opt,
63 | num_gpu=1,
64 | dist=False,
65 | sampler=None,
66 | seed=None):
67 | """Create dataloader.
68 |
69 | Args:
70 | dataset (torch.utils.data.Dataset): Dataset.
71 | dataset_opt (dict): Dataset options. It contains the following keys:
72 | phase (str): 'train' or 'val'.
73 | num_worker_per_gpu (int): Number of workers for each GPU.
74 | batch_size_per_gpu (int): Training batch size for each GPU.
75 | num_gpu (int): Number of GPUs. Used only in the train phase.
76 | Default: 1.
77 | dist (bool): Whether in distributed training. Used only in the train
78 | phase. Default: False.
79 | sampler (torch.utils.data.sampler): Data sampler. Default: None.
80 | seed (int | None): Seed. Default: None
81 | """
82 | phase = dataset_opt['phase']
83 | rank, _ = get_dist_info()
84 |
85 | if phase == 'train':
86 | if dist: # distributed training
87 | batch_size = dataset_opt['batch_size_per_gpu']
88 | num_workers = dataset_opt['num_worker_per_gpu']
89 | else: # non-distributed training
90 | multiplier = 1 if num_gpu == 0 else num_gpu
91 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
92 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
93 | dataloader_args = dict(
94 | dataset=dataset,
95 | batch_size=batch_size,
96 | shuffle=False,
97 | num_workers=num_workers,
98 | sampler=sampler,
99 | drop_last=True)
100 | if sampler is None:
101 | dataloader_args['shuffle'] = True
102 | dataloader_args['worker_init_fn'] = partial(
103 | worker_init_fn, num_workers=num_workers, rank=rank,
104 | seed=seed) if seed is not None else None
105 |
106 | elif phase in ['val', 'test']: # validation
107 | dataloader_args = dict(
108 | dataset=dataset,
109 | batch_size=1,
110 | shuffle=False,
111 | num_workers=0)
112 |
113 | else:
114 | raise ValueError(f'Wrong dataset phase: {phase}. '
115 | "Supported ones are 'train', 'val' and 'test'.")
116 |
117 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
118 |
119 | prefetch_mode = dataset_opt.get('prefetch_mode')
120 | if prefetch_mode == 'cpu': # CPUPrefetcher
121 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
122 | logger = get_root_logger()
123 | logger.info(f'Use {prefetch_mode} prefetch dataloader: '
124 | f'num_prefetch_queue = {num_prefetch_queue}')
125 | return PrefetchDataLoader(
126 | num_prefetch_queue=num_prefetch_queue, **dataloader_args)
127 | else:
128 | # prefetch_mode=None: Normal dataloader
129 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher
130 | return torch.utils.data.DataLoader(**dataloader_args)
131 |
132 |
133 | def worker_init_fn(worker_id, num_workers, rank, seed):
134 | # Set the worker seed to num_workers * rank + worker_id + seed
135 | worker_seed = num_workers * rank + worker_id + seed
136 | np.random.seed(worker_seed)
137 | random.seed(worker_seed)
138 |
--------------------------------------------------------------------------------
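
Usage sketch (not a file in the repository): driving create_dataloader with an in-memory toy dataset. The option keys mirror the docstring above; create_dataset works the same way but looks up dataset_opt['type'] among the classes defined in the *_dataset.py modules.

    import torch
    from torch.utils.data import TensorDataset
    from basicsr.data import create_dataloader

    # Toy stand-in for a real restoration dataset (pairs of LQ/GT tensors).
    toy_set = TensorDataset(torch.randn(8, 3, 64, 64), torch.randn(8, 3, 64, 64))

    dataset_opt = {
        'phase': 'train',            # 'train', 'val' or 'test'
        'batch_size_per_gpu': 2,
        'num_worker_per_gpu': 0,
        'pin_memory': False,
        'prefetch_mode': None,       # None -> plain DataLoader, 'cpu' -> PrefetchDataLoader
    }

    loader = create_dataloader(toy_set, dataset_opt, num_gpu=1, dist=False,
                               sampler=None, seed=0)
    for lq, gt in loader:
        print(lq.shape, gt.shape)    # torch.Size([2, 3, 64, 64]) twice
        break
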
/basicsr/data/data_sampler.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
3 | # Copyright 2018-2020 BasicSR Authors
4 | # ------------------------------------------------------------------------
5 |
6 | import math
7 | import torch
8 | from torch.utils.data.sampler import Sampler
9 |
10 |
11 | class EnlargedSampler(Sampler):
12 | """Sampler that restricts data loading to a subset of the dataset.
13 |
14 | Modified from torch.utils.data.distributed.DistributedSampler
15 |     Supports enlarging the dataset for iteration-based training, saving
16 |     time when restarting the dataloader after each epoch.
17 |
18 | Args:
19 | dataset (torch.utils.data.Dataset): Dataset used for sampling.
20 | num_replicas (int | None): Number of processes participating in
21 | the training. It is usually the world_size.
22 | rank (int | None): Rank of the current process within num_replicas.
23 | ratio (int): Enlarging ratio. Default: 1.
24 | """
25 |
26 | def __init__(self, dataset, num_replicas, rank, ratio=1):
27 | self.dataset = dataset
28 | self.num_replicas = num_replicas
29 | self.rank = rank
30 | self.epoch = 0
31 | self.num_samples = math.ceil(
32 | len(self.dataset) * ratio / self.num_replicas)
33 | self.total_size = self.num_samples * self.num_replicas
34 |
35 | def __iter__(self):
36 | # deterministically shuffle based on epoch
37 | g = torch.Generator()
38 | g.manual_seed(self.epoch)
39 | indices = torch.randperm(self.total_size, generator=g).tolist()
40 |
41 | dataset_size = len(self.dataset)
42 | indices = [v % dataset_size for v in indices]
43 |
44 | # subsample
45 | indices = indices[self.rank:self.total_size:self.num_replicas]
46 | assert len(indices) == self.num_samples
47 |
48 | return iter(indices)
49 |
50 | def __len__(self):
51 | return self.num_samples
52 |
53 | def set_epoch(self, epoch):
54 | self.epoch = epoch
55 |
--------------------------------------------------------------------------------
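
Usage sketch (not a file in the repository): pairing EnlargedSampler with create_dataloader in a single-process setting. The toy TensorDataset stands in for a real training set.

    import torch
    from torch.utils.data import TensorDataset
    from basicsr.data import create_dataloader
    from basicsr.data.data_sampler import EnlargedSampler

    train_set = TensorDataset(torch.randn(16, 3, 32, 32), torch.randn(16, 3, 32, 32))

    # ratio=100 virtually repeats the dataset 100x, so one dataloader pass covers
    # many iterations without restarting the workers (iteration-based training).
    sampler = EnlargedSampler(train_set, num_replicas=1, rank=0, ratio=100)
    loader = create_dataloader(train_set,
                               {'phase': 'train', 'batch_size_per_gpu': 4,
                                'num_worker_per_gpu': 0},
                               num_gpu=1, dist=False, sampler=sampler, seed=10)

    for epoch in range(2):
        sampler.set_epoch(epoch)     # deterministic reshuffle per epoch
        for lq, gt in loader:
            pass                     # 16 * 100 / 4 = 400 batches per pass
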
/basicsr/data/prefetch_dataloader.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
3 | # Copyright 2018-2020 BasicSR Authors
4 | # ------------------------------------------------------------------------
5 | import queue as Queue
6 | import threading
7 | import torch
8 | from torch.utils.data import DataLoader
9 |
10 |
11 | class PrefetchGenerator(threading.Thread):
12 | """A general prefetch generator.
13 |
14 | Ref:
15 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
16 |
17 | Args:
18 | generator: Python generator.
19 |         num_prefetch_queue (int): Size of the prefetch queue.
20 | """
21 |
22 | def __init__(self, generator, num_prefetch_queue):
23 | threading.Thread.__init__(self)
24 | self.queue = Queue.Queue(num_prefetch_queue)
25 | self.generator = generator
26 | self.daemon = True
27 | self.start()
28 |
29 | def run(self):
30 | for item in self.generator:
31 | self.queue.put(item)
32 | self.queue.put(None)
33 |
34 | def __next__(self):
35 | next_item = self.queue.get()
36 | if next_item is None:
37 | raise StopIteration
38 | return next_item
39 |
40 | def __iter__(self):
41 | return self
42 |
43 |
44 | class PrefetchDataLoader(DataLoader):
45 | """Prefetch version of dataloader.
46 |
47 | Ref:
48 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
49 |
50 | TODO:
51 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in
52 | ddp.
53 |
54 | Args:
55 |         num_prefetch_queue (int): Size of the prefetch queue.
56 | kwargs (dict): Other arguments for dataloader.
57 | """
58 |
59 | def __init__(self, num_prefetch_queue, **kwargs):
60 | self.num_prefetch_queue = num_prefetch_queue
61 | super(PrefetchDataLoader, self).__init__(**kwargs)
62 |
63 | def __iter__(self):
64 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
65 |
66 |
67 | class CPUPrefetcher():
68 | """CPU prefetcher.
69 |
70 | Args:
71 | loader: Dataloader.
72 | """
73 |
74 | def __init__(self, loader):
75 | self.ori_loader = loader
76 | self.loader = iter(loader)
77 |
78 | def next(self):
79 | try:
80 | return next(self.loader)
81 | except StopIteration:
82 | return None
83 |
84 | def reset(self):
85 | self.loader = iter(self.ori_loader)
86 |
87 |
88 | class CUDAPrefetcher():
89 | """CUDA prefetcher.
90 |
91 | Ref:
92 | https://github.com/NVIDIA/apex/issues/304#
93 |
94 |     It may consume more GPU memory.
95 |
96 | Args:
97 | loader: Dataloader.
98 | opt (dict): Options.
99 | """
100 |
101 | def __init__(self, loader, opt):
102 | self.ori_loader = loader
103 | self.loader = iter(loader)
104 | self.opt = opt
105 | self.stream = torch.cuda.Stream()
106 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
107 | self.preload()
108 |
109 | def preload(self):
110 | try:
111 | self.batch = next(self.loader) # self.batch is a dict
112 | except StopIteration:
113 | self.batch = None
114 | return None
115 | # put tensors to gpu
116 | with torch.cuda.stream(self.stream):
117 | for k, v in self.batch.items():
118 | if torch.is_tensor(v):
119 | self.batch[k] = self.batch[k].to(
120 | device=self.device, non_blocking=True)
121 |
122 | def next(self):
123 | torch.cuda.current_stream().wait_stream(self.stream)
124 | batch = self.batch
125 | self.preload()
126 | return batch
127 |
128 | def reset(self):
129 | self.loader = iter(self.ori_loader)
130 | self.preload()
131 |
--------------------------------------------------------------------------------
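
Usage sketch (not a file in the repository): the loop shape used with CPUPrefetcher. CUDAPrefetcher is driven the same way but expects dict batches and an opt containing 'num_gpu', and copies tensors to the GPU on a side stream.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from basicsr.data.prefetch_dataloader import CPUPrefetcher

    loader = DataLoader(TensorDataset(torch.randn(8, 3, 16, 16)), batch_size=2)

    # next() returns None once the underlying loader is exhausted;
    # reset() rewinds it for the next epoch.
    prefetcher = CPUPrefetcher(loader)
    batch = prefetcher.next()
    while batch is not None:
        (x,) = batch                 # training step would go here
        batch = prefetcher.next()
    prefetcher.reset()
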
/basicsr/data/transforms.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2021 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5 | # Copyright 2018-2020 BasicSR Authors
6 | # ------------------------------------------------------------------------
7 | import cv2
8 | import random, numpy as np
9 |
10 | def mod_crop(img, scale):
11 | """Mod crop images, used during testing.
12 |
13 | Args:
14 | img (ndarray): Input image.
15 | scale (int): Scale factor.
16 |
17 | Returns:
18 | ndarray: Result image.
19 | """
20 | img = img.copy()
21 | if img.ndim in (2, 3):
22 | h, w = img.shape[0], img.shape[1]
23 | h_remainder, w_remainder = h % scale, w % scale
24 | img = img[:h - h_remainder, :w - w_remainder, ...]
25 | else:
26 | raise ValueError(f'Wrong img ndim: {img.ndim}.')
27 | return img
28 |
29 |
30 | def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
31 | """Paired random crop.
32 |
33 | It crops lists of lq and gt images with corresponding locations.
34 |
35 | Args:
36 | img_gts (list[ndarray] | ndarray): GT images. Note that all images
37 | should have the same shape. If the input is an ndarray, it will
38 | be transformed to a list containing itself.
39 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
40 | should have the same shape. If the input is an ndarray, it will
41 | be transformed to a list containing itself.
42 | gt_patch_size (int): GT patch size.
43 | scale (int): Scale factor.
44 | gt_path (str): Path to ground-truth.
45 |
46 | Returns:
47 | list[ndarray] | ndarray: GT images and LQ images. If returned results
48 | only have one element, just return ndarray.
49 | """
50 |
51 | if not isinstance(img_gts, list):
52 | img_gts = [img_gts]
53 | if not isinstance(img_lqs, list):
54 | img_lqs = [img_lqs]
55 |
56 | h_lq, w_lq, _ = img_lqs[0].shape
57 | h_gt, w_gt, _ = img_gts[0].shape
58 | lq_patch_size = gt_patch_size // scale
59 |
60 | if h_gt != h_lq * scale or w_gt != w_lq * scale:
61 | raise ValueError(
62 |             f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
63 |             f'multiplication of LQ ({h_lq}, {w_lq}).')
64 | if h_lq < lq_patch_size or w_lq < lq_patch_size:
65 | raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
66 | f'({lq_patch_size}, {lq_patch_size}). '
67 | f'Please remove {gt_path}.')
68 |
69 | # randomly choose top and left coordinates for lq patch
70 | top = random.randint(0, h_lq - lq_patch_size)
71 | left = random.randint(0, w_lq - lq_patch_size)
72 |
73 | # crop lq patch
74 | img_lqs = [
75 | v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
76 | for v in img_lqs
77 | ]
78 |
79 | # crop corresponding gt patch
80 | top_gt, left_gt = int(top * scale), int(left * scale)
81 | img_gts = [
82 | v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
83 | for v in img_gts
84 | ]
85 | if len(img_gts) == 1:
86 | img_gts = img_gts[0]
87 | if len(img_lqs) == 1:
88 | img_lqs = img_lqs[0]
89 | return img_gts, img_lqs
90 |
91 |
92 | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
93 | """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
94 |
95 | We use vertical flip and transpose for rotation implementation.
96 | All the images in the list use the same augmentation.
97 |
98 | Args:
99 | imgs (list[ndarray] | ndarray): Images to be augmented. If the input
100 | is an ndarray, it will be transformed to a list.
101 | hflip (bool): Horizontal flip. Default: True.
102 |         rotation (bool): Rotation. Default: True.
103 |         flows (list[ndarray]): Flows to be augmented. If the input is an
104 | ndarray, it will be transformed to a list.
105 | Dimension is (h, w, 2). Default: None.
106 | return_status (bool): Return the status of flip and rotation.
107 | Default: False.
108 |
109 | Returns:
110 | list[ndarray] | ndarray: Augmented images and flows. If returned
111 | results only have one element, just return ndarray.
112 |
113 | """
114 | hflip = hflip and random.random() < 0.5
115 | vflip = rotation and random.random() < 0.5
116 | rot90 = rotation and random.random() < 0.5
117 |
118 | def _augment(img):
119 | if hflip: # horizontal
120 | cv2.flip(img, 1, img)
121 | if vflip: # vertical
122 | cv2.flip(img, 0, img)
123 | if rot90:
124 | img = img.transpose(1, 0, 2)
125 | return img
126 |
127 | def _augment_flow(flow):
128 | if hflip: # horizontal
129 | cv2.flip(flow, 1, flow)
130 | flow[:, :, 0] *= -1
131 | if vflip: # vertical
132 | cv2.flip(flow, 0, flow)
133 | flow[:, :, 1] *= -1
134 | if rot90:
135 | flow = flow.transpose(1, 0, 2)
136 | flow = flow[:, :, [1, 0]]
137 | return flow
138 |
139 | if not isinstance(imgs, list):
140 | imgs = [imgs]
141 | imgs = [_augment(img) for img in imgs]
142 | if len(imgs) == 1:
143 | imgs = imgs[0]
144 |
145 | if flows is not None:
146 | if not isinstance(flows, list):
147 | flows = [flows]
148 | flows = [_augment_flow(flow) for flow in flows]
149 | if len(flows) == 1:
150 | flows = flows[0]
151 | return imgs, flows
152 | else:
153 | if return_status:
154 | return imgs, (hflip, vflip, rot90)
155 | else:
156 | return imgs
157 |
158 |
159 | def img_rotate(img, angle, center=None, scale=1.0):
160 | """Rotate image.
161 |
162 | Args:
163 | img (ndarray): Image to be rotated.
164 | angle (float): Rotation angle in degrees. Positive values mean
165 | counter-clockwise rotation.
166 | center (tuple[int]): Rotation center. If the center is None,
167 | initialize it as the center of the image. Default: None.
168 | scale (float): Isotropic scale factor. Default: 1.0.
169 | """
170 | (h, w) = img.shape[:2]
171 |
172 | if center is None:
173 | center = (w // 2, h // 2)
174 |
175 | matrix = cv2.getRotationMatrix2D(center, angle, scale)
176 | rotated_img = cv2.warpAffine(img, matrix, (w, h))
177 | return rotated_img
178 |
179 | def data_augmentation(image, mode):
180 | """
181 | Performs data augmentation of the input image
182 | Input:
183 | image: a cv2 (OpenCV) image
184 | mode: int. Choice of transformation to apply to the image
185 | 0 - no transformation
186 | 1 - flip up and down
187 |         2 - rotate counterclockwise 90 degrees
188 |         3 - rotate 90 degrees and flip up and down
189 |         4 - rotate 180 degrees
190 |         5 - rotate 180 degrees and flip
191 |         6 - rotate 270 degrees
192 |         7 - rotate 270 degrees and flip
193 | """
194 | if mode == 0:
195 | # original
196 | out = image
197 | elif mode == 1:
198 | # flip up and down
199 | out = np.flipud(image)
200 | elif mode == 2:
201 |         # rotate counterclockwise 90 degrees
202 | out = np.rot90(image)
203 | elif mode == 3:
204 | # rotate 90 degree and flip up and down
205 | out = np.rot90(image)
206 | out = np.flipud(out)
207 | elif mode == 4:
208 | # rotate 180 degree
209 | out = np.rot90(image, k=2)
210 | elif mode == 5:
211 | # rotate 180 degree and flip
212 | out = np.rot90(image, k=2)
213 | out = np.flipud(out)
214 | elif mode == 6:
215 | # rotate 270 degree
216 | out = np.rot90(image, k=3)
217 | elif mode == 7:
218 | # rotate 270 degree and flip
219 | out = np.rot90(image, k=3)
220 | out = np.flipud(out)
221 | else:
222 | raise Exception('Invalid choice of image transformation')
223 |
224 | return out
225 |
226 | def random_augmentation(*args):
227 | ## older random augmentation
228 | out = []
229 | if random.randint(0,1) == 1:
230 | flag_aug = random.randint(1,7)
231 | for data in args:
232 | out.append(data_augmentation(data, flag_aug).copy())
233 | else:
234 | for data in args:
235 | out.append(data)
236 | return out
237 |
238 | # restormer's augmentation
239 | # def random_augmentation(*args):
240 | # out = []
241 | # flag_aug = random.randint(0, 7)
242 | # for data in args:
243 | # out.append(data_augmentation(data, flag_aug).copy())
244 | # return out
--------------------------------------------------------------------------------
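
Usage sketch (not a file in the repository): cropping a matched GT/LQ patch pair and applying the same randomly chosen flip/rotation to both. The gt_path argument is only used in error messages, so the value here is a placeholder.

    import numpy as np
    from basicsr.data.transforms import paired_random_crop, random_augmentation

    gt = np.random.rand(256, 256, 3).astype(np.float32)   # HR frame
    lq = np.random.rand(64, 64, 3).astype(np.float32)     # LR frame, 4x smaller

    # Crop GT/LQ patches at the same (scaled) location, then augment both identically.
    gt_patch, lq_patch = paired_random_crop(gt, lq, gt_patch_size=128, scale=4,
                                            gt_path='example/gt/00000000.png')
    gt_patch, lq_patch = random_augmentation(gt_patch, lq_patch)
    print(gt_patch.shape, lq_patch.shape)                  # (128, 128, 3) (32, 32, 3)
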
/basicsr/data/video_image_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils import data as data
2 | from basicsr.data.data_util import np2Tensor, get_patch
3 | from basicsr.data.transforms import random_augmentation
4 | import os
5 | import glob, imageio
6 | import numpy as np
7 | import torch
8 |
9 | class VideoImageDataset(data.Dataset):
10 | def __init__(self, args, phase):
11 | self.args = args
12 | self.name = args['name']
13 | self.phase = phase
14 | self.n_seq = args['n_sequence']
15 | self.n_frames_video = []
16 | if self.phase == "train":
17 | self._set_filesystem(args['dir_data'],
18 | self.phase)
19 | else:
20 | self._set_filesystem(args['datasets']['val']['dir_data'],
21 | self.phase)
22 |
23 | self.images_gt, self.images_input = self._scan()
24 | self.num_video = len(self.images_gt)
25 | self.num_frame = sum(self.n_frames_video) - (self.n_seq - 1) * len(self.n_frames_video)
26 | print("Number of videos to load:", self.num_video)
27 | self.n_colors = args['n_colors']
28 | self.rgb_range = args['rgb_range']
29 | self.patch_size = args['patch_size']
30 | self.no_augment = args['no_augment']
31 | self.size_must_mode = args['size_must_mode']
32 |
33 | def _set_filesystem(self, dir_data, phase):
34 | print("Loading {} => {} DataSet".format(f"{phase}", self.name))
35 | if isinstance(dir_data, list):
36 | self.dir_gt = []
37 | self.apath = []
38 | self.dir_input = []
39 | for path in dir_data:
40 | self.apath.append(path)
41 | self.dir_gt.append(os.path.join(path, 'gt'))
42 | self.dir_input.append(os.path.join(path, 'blur'))
43 | else:
44 | self.apath = dir_data
45 | self.dir_gt = os.path.join(self.apath, 'gt')
46 | self.dir_input = os.path.join(self.apath, 'blur')
47 |
48 | def _scan(self):
49 | if isinstance(self.dir_gt, list):
50 | vid_gt_names_combined = []
51 | vid_input_names_combined = []
52 |
53 | for ix in range(len(self.dir_gt)):
54 | vid_gt_names = sorted(glob.glob(os.path.join(self.dir_gt[ix], '*')))
55 | vid_input_names = sorted(glob.glob(os.path.join(self.dir_input[ix], '*')))
56 |
57 | vid_gt_names_combined.append(vid_gt_names)
58 | vid_input_names_combined.append(vid_input_names)
59 | assert len(vid_gt_names) == len(vid_input_names), "len(vid_gt_names) must equal len(vid_input_names)"
60 |         else:  # single dataset directory
61 |             vid_gt_names_combined = [sorted(glob.glob(os.path.join(self.dir_gt, '*')))]
62 |             vid_input_names_combined = [sorted(glob.glob(os.path.join(self.dir_input, '*')))]
63 |
64 | images_gt = []
65 | images_input = []
66 | for vid_gt, vid_input in zip(vid_gt_names_combined, vid_input_names_combined):
67 | for vid_gt_name, vid_input_name in zip(vid_gt, vid_input):
68 | gt_dir_names = sorted(glob.glob(os.path.join(vid_gt_name, '*')))
69 | input_dir_names = sorted(glob.glob(os.path.join(vid_input_name, '*')))
70 |
71 | images_gt.append(gt_dir_names)
72 | images_input.append(input_dir_names)
73 | self.n_frames_video.append(len(gt_dir_names))
74 | return images_gt, images_input
75 |
76 | def _load(self, images_gt, images_input):
77 | data_input = []
78 | data_gt = []
79 | n_videos = len(images_gt)
80 | for idx in range(n_videos):
81 | if idx % 10 == 0:
82 | print("Loading video %d" % idx)
83 | gts = np.array([imageio.imread(hr_name) for hr_name in images_gt[idx]])
84 | inputs = np.array([imageio.imread(lr_name) for lr_name in images_input[idx]])
85 | data_input.append(inputs)
86 | data_gt.append(gts)
87 | return data_gt, data_input
88 |
89 | def add_noise(self, x):
90 | # x is numpy here
91 | x = torch.tensor(x).unsqueeze(0).permute(0, 3, 1, 2)
92 | if self.phase == "train":
93 | # uniform sampling from [20, 50]
94 | r1 = 20.0/255.0
95 | r2 = 50.0/255.0
96 | stdn = np.random.rand(1,1,1,1) * (r2-r1) + r1
97 | stdn = torch.FloatTensor(stdn)
98 | noise = torch.zeros_like(x)
99 | noise = torch.normal(mean=noise.float(),
100 | std=stdn.expand_as(noise))
101 | lq = (noise + x/255.0)*255
102 | else:
103 | # in validation, the noise is fixed to 50.0/255.0.
104 | r2 = 50.0/255.0
105 | stdn = [r2]
106 | stdn = torch.FloatTensor(stdn)
107 | noise = torch.zeros_like(x)
108 | noise = torch.normal(mean=noise.float(),
109 | std=stdn.expand_as(noise))
110 | lq = (noise + x/255.0)*255
111 |
112 | return lq.squeeze(0).permute(1, 2, 0).numpy()
113 |
114 | def __getitem__(self, idx):
115 | inputs, gts, filenames_prompts, filenames = self._load_file(idx)
116 | inputs_list = [inputs[i, :, :, :] for i in range(self.n_seq)]
117 | inputs_concat = np.concatenate(inputs_list, axis=2)
118 | gts_list = [gts[i, :, :, :] for i in range(self.n_seq)]
119 | gts_concat = np.concatenate(gts_list, axis=2)
120 | inputs_concat, gts_concat = self.get_patch(inputs_concat, gts_concat, self.size_must_mode)
121 | inputs_list = [inputs_concat[:, :, i*self.n_colors:(i+1)*self.n_colors] for i in range(self.n_seq)]
122 | gts_list = [gts_concat[:, :, i*self.n_colors:(i+1)*self.n_colors] for i in range(self.n_seq)]
123 |
124 | inputs_updated = []
125 | for ix in range(len(filenames_prompts)):
126 | _filename_ = filenames_prompts[ix]
127 | _img_ = inputs_list[ix]
128 | if "DAVIS" in _filename_:
129 | # denoising dataset, add noise.
130 | noise_added_img = self.add_noise(_img_)
131 | inputs_updated.append(noise_added_img)
132 | else:
133 | # let it go as is.
134 | inputs_updated.append(_img_)
135 |
136 | inputs = np.array(inputs_updated)
137 | gts = np.array(gts_list)
138 |
139 | input_tensors = np2Tensor(*inputs, rgb_range=self.rgb_range, n_colors=self.n_colors)
140 | gt_tensors = np2Tensor(*gts, rgb_range=self.rgb_range, n_colors=self.n_colors)
141 | return torch.stack(input_tensors), torch.stack(gt_tensors), filenames_prompts, filenames
142 |
143 | def __len__(self):
144 | return self.num_frame
145 |
146 | def _get_index(self, idx):
147 | return idx % self.num_frame
148 |
149 | def _find_video_num(self, idx, n_frame):
150 | for i, j in enumerate(n_frame):
151 | if idx < j: return i, idx
152 | else: idx -= j
153 |
154 | def _load_file(self, idx):
155 | idx = self._get_index(idx)
156 | n_poss_frames = [n - self.n_seq + 1 for n in self.n_frames_video]
157 | video_idx, frame_idx = self._find_video_num(idx, n_poss_frames)
158 | f_gts = self.images_gt[video_idx][frame_idx:frame_idx + self.n_seq]
159 | f_inputs = self.images_input[video_idx][frame_idx:frame_idx + self.n_seq]
160 | gts = np.array([imageio.imread(hr_name) for hr_name in f_gts])
161 | inputs = np.array([imageio.imread(lr_name) for lr_name in f_inputs])
162 | filenames = [os.path.split(os.path.dirname(name))[-1] + '.' + os.path.splitext(os.path.basename(name))[0]
163 | for name in f_gts]
164 | filenames_prompts = [x for x in f_inputs]
165 | return inputs, gts, filenames_prompts, filenames
166 |
167 | def _load_file_from_loaded_data(self, idx):
168 | idx = self._get_index(idx)
169 |
170 | n_poss_frames = [n - self.n_seq + 1 for n in self.n_frames_video]
171 | video_idx, frame_idx = self._find_video_num(idx, n_poss_frames)
172 | gts = self.data_gt[video_idx][frame_idx:frame_idx + self.n_seq]
173 | inputs = self.data_input[video_idx][frame_idx:frame_idx + self.n_seq]
174 | filenames = [os.path.split(os.path.dirname(name))[-1] + '.' + os.path.splitext(os.path.basename(name))[0]
175 | for name in self.images_gt[video_idx][frame_idx:frame_idx + self.n_seq]]
176 | return inputs, gts, filenames
177 |
178 | def get_patch(self, input, gt, size_must_mode=1):
179 |         # crop a random patch, then trim H and W to multiples of size_must_mode
180 |         input, gt = get_patch(input, gt, patch_size=self.patch_size)
181 |         h, w, c = input.shape
182 |         new_h, new_w = h - h % size_must_mode, w - w % size_must_mode
183 |         input, gt = input[:new_h, :new_w, :], gt[:new_h, :new_w, :]
184 |         if not self.no_augment and self.phase == "train":
185 |             input, gt = random_augmentation(input, gt)
186 |         return input, gt
--------------------------------------------------------------------------------
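
Usage sketch (not a file in the repository): constructing the training dataset directly. The option values and the 'placeholder_dataset/train' path are illustrative; the dataset expects <dir_data>/gt/<video>/<frame> and <dir_data>/blur/<video>/<frame>, adds Gaussian noise on the fly for inputs whose path contains 'DAVIS', and the exact tensor shapes depend on get_patch/np2Tensor in data_util.py (not shown above).

    from basicsr.data.video_image_dataset import VideoImageDataset

    dataset_args = {
        'name': 'GoPro',                        # illustrative
        'n_sequence': 5,                        # frames per training sample
        'dir_data': 'placeholder_dataset/train',
        'n_colors': 3,
        'rgb_range': 1,
        'patch_size': 192,
        'no_augment': False,
        'size_must_mode': 4,
    }
    train_set = VideoImageDataset(dataset_args, phase='train')
    lq, gt, prompts, names = train_set[0]
    print(lq.shape, gt.shape)                   # e.g. torch.Size([5, 3, 192, 192]) each
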
/basicsr/data/video_super_image_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils import data as data
2 | from basicsr.data.data_util import np2Tensor
3 | from basicsr.data.transforms import random_augmentation
4 | import os
5 | import glob, imageio
6 | import numpy as np
7 | import torch
8 | import cv2, random
9 |
10 | class VideoSuperImageDataset(data.Dataset):
11 | def __init__(self, args, phase):
12 | self.args = args
13 | self.name = args['name']
14 | self.phase = phase
15 | self.n_seq = args['n_sequence']
16 | print("n_seq:", self.n_seq)
17 | self.n_frames_video = []
18 | if self.phase == "train":
19 | self._set_filesystem(args['dir_data'],
20 | self.phase)
21 | else:
22 | self._set_filesystem(args['datasets']['val']['dir_data'],
23 | self.phase)
24 |
25 | self.images_gt, self.images_input = self._scan()
26 | self.num_video = len(self.images_gt)
27 | self.num_frame = sum(self.n_frames_video) - (self.n_seq - 1) * len(self.n_frames_video)
28 | print("Number of videos to load:", self.num_video)
29 | self.n_colors = args['n_colors']
30 | self.rgb_range = args['rgb_range']
31 | self.patch_size = args['patch_size']
32 | self.no_augment = args['no_augment']
33 | self.size_must_mode = args['size_must_mode']
34 |
35 | def _set_filesystem(self, dir_data, phase):
36 | print("Loading {} => {} DataSet".format(f"{phase}", self.name))
37 | if isinstance(dir_data, list):
38 | self.dir_gt = []
39 | self.apath = []
40 | self.dir_input = []
41 | for path in dir_data:
42 | self.apath.append(path)
43 | self.dir_gt.append(os.path.join(path, 'gt'))
44 | self.dir_input.append(os.path.join(path, 'blur'))
45 | else:
46 | self.apath = dir_data
47 | self.dir_gt = os.path.join(self.apath, 'gt')
48 | self.dir_input = os.path.join(self.apath, 'blur')
49 |
50 | def _scan(self):
51 | if isinstance(self.dir_gt, list):
52 | vid_gt_names_combined = []
53 | vid_input_names_combined = []
54 |
55 | for ix in range(len(self.dir_gt)):
56 | vid_gt_names = sorted(glob.glob(os.path.join(self.dir_gt[ix], '*')))
57 | vid_input_names = sorted(glob.glob(os.path.join(self.dir_input[ix], '*')))
58 |
59 | vid_gt_names_combined.append(vid_gt_names)
60 | vid_input_names_combined.append(vid_input_names)
61 | assert len(vid_gt_names) == len(vid_input_names), "len(vid_gt_names) must equal len(vid_input_names)"
62 |         else:  # single dataset directory
63 |             vid_gt_names_combined = [sorted(glob.glob(os.path.join(self.dir_gt, '*')))]
64 |             vid_input_names_combined = [sorted(glob.glob(os.path.join(self.dir_input, '*')))]
65 |
66 | images_gt = []
67 | images_input = []
68 | for vid_gt, vid_input in zip(vid_gt_names_combined, vid_input_names_combined):
69 | for vid_gt_name, vid_input_name in zip(vid_gt, vid_input):
70 | gt_dir_names = sorted(glob.glob(os.path.join(vid_gt_name, '*')))
71 | input_dir_names = sorted(glob.glob(os.path.join(vid_input_name, '*')))
72 |
73 | images_gt.append(gt_dir_names)
74 | images_input.append(input_dir_names)
75 | self.n_frames_video.append(len(gt_dir_names))
76 | return images_gt, images_input
77 |
78 | def _load(self, images_gt, images_input):
79 | data_input = []
80 | data_gt = []
81 | n_videos = len(images_gt)
82 | for idx in range(n_videos):
83 | if idx % 10 == 0:
84 | print("Loading video %d" % idx)
85 | gts = np.array([imageio.imread(hr_name) for hr_name in images_gt[idx]])
86 | inputs = np.array([imageio.imread(lr_name) for lr_name in images_input[idx]])
87 | data_input.append(inputs)
88 | data_gt.append(gts)
89 | return data_gt, data_input
90 |
91 | def __getitem__(self, idx):
92 | inputs, gts, filenames, filenames_prompts = self._load_file(idx)
93 | inputs_list = [inputs[i, :, :, :] for i in range(self.n_seq)]
94 | inputs_concat = np.concatenate(inputs_list, axis=2)
95 | gts_list = [gts[i, :, :, :] for i in range(self.n_seq)]
96 | gts_concat = np.concatenate(gts_list, axis=2)
97 |
98 | inputs_concat, gts_concat = self._crop_patch(inputs_concat, gts_concat)
99 | inputs_list = [inputs_concat[:, :, i*self.n_colors:(i+1)*self.n_colors] for i in range(self.n_seq)]
100 | gts_list = [gts_concat[:, :, i*self.n_colors:(i+1)*self.n_colors] for i in range(self.n_seq)]
101 | inputs = np.array(inputs_list)
102 | gts = np.array(gts_list)
103 |
104 | input_tensors = np2Tensor(*inputs, rgb_range=self.rgb_range, n_colors=self.n_colors)
105 | gt_tensors = np2Tensor(*gts, rgb_range=self.rgb_range, n_colors=self.n_colors)
106 | return torch.stack(input_tensors), torch.stack(gt_tensors), filenames, filenames_prompts
107 |
108 | def __len__(self):
109 | return self.num_frame
110 |
111 | def _get_index(self, idx):
112 | return idx % self.num_frame
113 |
114 | def _find_video_num(self, idx, n_frame):
115 | for i, j in enumerate(n_frame):
116 | if idx < j: return i, idx
117 | else: idx -= j
118 |
119 | def _load_file(self, idx):
120 | idx = self._get_index(idx)
121 | n_poss_frames = [n - self.n_seq + 1 for n in self.n_frames_video]
122 | video_idx, frame_idx = self._find_video_num(idx, n_poss_frames)
123 | f_gts = self.images_gt[video_idx][frame_idx:frame_idx + self.n_seq]
124 | f_inputs = self.images_input[video_idx][frame_idx:frame_idx + self.n_seq]
125 |         gts = np.array([imageio.imread(hr_name) for hr_name in f_gts])
126 |         # the LQ frames are built on the fly by 4x bicubic downscaling of the
127 |         # input frames instead of being read at full resolution
128 |         inputs = []
129 | for lr_name in f_inputs:
130 | lq_img = imageio.imread(lr_name)
131 | h,w,_ = lq_img.shape
132 | lq_img_ = cv2.resize(lq_img, (w//4, h//4),
133 | interpolation=cv2.INTER_CUBIC)
134 | inputs.append(lq_img_)
135 | inputs = np.array(inputs)
136 | filenames = [os.path.split(os.path.dirname(name))[-1] + '.' + os.path.splitext(os.path.basename(name))[0]
137 | for name in f_gts]
138 | filenames_prompts = [x for x in f_inputs]
139 | return inputs, gts, filenames, filenames_prompts
140 |
141 | def _load_file_from_loaded_data(self, idx):
142 | idx = self._get_index(idx)
143 |
144 | n_poss_frames = [n - self.n_seq + 1 for n in self.n_frames_video]
145 | video_idx, frame_idx = self._find_video_num(idx, n_poss_frames)
146 | gts = self.data_gt[video_idx][frame_idx:frame_idx + self.n_seq]
147 | inputs = self.data_input[video_idx][frame_idx:frame_idx + self.n_seq]
148 | filenames = [os.path.split(os.path.dirname(name))[-1] + '.' + os.path.splitext(os.path.basename(name))[0]
149 | for name in self.images_gt[video_idx][frame_idx:frame_idx + self.n_seq]]
150 | return inputs, gts, filenames
151 |
152 | def _crop_patch(self, lr_seq, hr_seq, patch_size=48, scale=4):
153 | ih, iw, _ = lr_seq.shape
154 | pw = random.randrange(0, iw - patch_size + 1)
155 | ph = random.randrange(0, ih - patch_size + 1)
156 |
157 | hpw, hph = scale * pw, scale * ph
158 | hr_patch_size = scale * patch_size
159 |
160 | lr_patch_seq = lr_seq[ph:ph+patch_size, pw:pw+patch_size, :]
161 | hr_patch_seq = hr_seq[hph:hph+hr_patch_size, hpw:hpw+hr_patch_size, :]
162 | if not self.no_augment and self.phase == "train":
163 | lr_patch_seq, hr_patch_seq = random_augmentation(lr_patch_seq, hr_patch_seq)
164 | return lr_patch_seq, hr_patch_seq
--------------------------------------------------------------------------------
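
Note (not a file in the repository): unlike VideoImageDataset, this SR dataset builds its LQ frames on the fly in _load_file by 4x bicubic downscaling of the input frames, and _crop_patch then cuts a 48x48 LR patch together with the matching 192x192 HR patch (scale=4). The standalone snippet below just reproduces that downscaling step.

    import cv2
    import numpy as np

    hr = (np.random.rand(720, 1280, 3) * 255).astype(np.uint8)   # illustrative frame
    h, w, _ = hr.shape
    lq = cv2.resize(hr, (w // 4, h // 4), interpolation=cv2.INTER_CUBIC)
    print(lq.shape)   # (180, 320, 3)
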
/basicsr/loss/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | class L1BaseLoss(nn.Module):
8 | def __init__(self, loss_weight=1.0, reduction='mean'):
9 | super(L1BaseLoss, self).__init__()
10 | self.loss_weight = loss_weight
11 | self.reduction = reduction
12 |
13 | def forward(self, pred, target):
14 | l1_loss = nn.L1Loss()
15 | l1_base = l1_loss(pred, target)
16 | return l1_base
17 |
18 | class PSNRLoss(nn.Module):
19 | def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
20 | super(PSNRLoss, self).__init__()
21 | assert reduction == 'mean'
22 | self.loss_weight = loss_weight
23 | self.scale = 10 / np.log(10)
24 | self.toY = toY
25 | self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
26 | self.first = True
27 |
28 | def forward(self, pred, target):
29 | assert len(pred.size()) == 4
30 | if self.toY:
31 | if self.first:
32 | self.coef = self.coef.to(pred.device)
33 | self.first = False
34 |
35 | pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
36 | target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
37 |
38 | pred, target = pred / 255., target / 255.
39 | pass
40 | assert len(pred.size()) == 4
41 |
42 | return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()
--------------------------------------------------------------------------------
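
Note (not a file in the repository): PSNRLoss evaluates (10 / ln 10) * ln(MSE) = 10 * log10(MSE), i.e. the negative of the PSNR in dB for inputs in [0, 1], so minimizing it maximizes PSNR. A quick numerical check:

    import torch
    from basicsr.loss import PSNRLoss

    pred = torch.rand(2, 3, 64, 64)
    target = torch.rand(2, 3, 64, 64)

    loss = PSNRLoss(loss_weight=1.0)(pred, target)

    mse = ((pred - target) ** 2).mean(dim=(1, 2, 3))
    psnr = (10 * torch.log10(1.0 / mse)).mean()
    print(loss.item(), -psnr.item())   # nearly identical (the loss adds a 1e-8 epsilon)
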
/basicsr/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
3 | # Copyright 2018-2020 BasicSR Authors
4 | # ------------------------------------------------------------------------
5 | # from .niqe import calculate_niqe
6 | from .psnr_ssim import calculate_psnr, calculate_ssim
7 |
8 | __all__ = ['calculate_psnr', 'calculate_ssim']
9 |
--------------------------------------------------------------------------------
/basicsr/metrics/metric_util.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2021 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5 | # Copyright 2018-2020 BasicSR Authors
6 | # ------------------------------------------------------------------------
7 | import numpy as np
8 |
9 | from basicsr.utils.matlab_functions import bgr2ycbcr
10 |
11 |
12 | def reorder_image(img, input_order='HWC'):
13 | """Reorder images to 'HWC' order.
14 |
15 | If the input_order is (h, w), return (h, w, 1);
16 | If the input_order is (c, h, w), return (h, w, c);
17 | If the input_order is (h, w, c), return as it is.
18 |
19 | Args:
20 | img (ndarray): Input image.
21 | input_order (str): Whether the input order is 'HWC' or 'CHW'.
22 | If the input image shape is (h, w), input_order will not have
23 | effects. Default: 'HWC'.
24 |
25 | Returns:
26 | ndarray: reordered image.
27 | """
28 |
29 | if input_order not in ['HWC', 'CHW']:
30 | raise ValueError(
31 | f'Wrong input_order {input_order}. Supported input_orders are '
32 | "'HWC' and 'CHW'")
33 | if len(img.shape) == 2:
34 | img = img[..., None]
35 | if input_order == 'CHW':
36 | img = img.transpose(1, 2, 0)
37 | return img
38 |
39 |
40 | def to_y_channel(img):
41 | """Change to Y channel of YCbCr.
42 |
43 | Args:
44 | img (ndarray): Images with range [0, 255].
45 |
46 | Returns:
47 |         (ndarray): Images with range [0, 255] (float type) without rounding.
48 | """
49 | img = img.astype(np.float32) / 255.
50 | if img.ndim == 3 and img.shape[2] == 3:
51 | img = bgr2ycbcr(img, y_only=True)
52 | img = img[..., None]
53 | return img * 255.
54 |
--------------------------------------------------------------------------------
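
Usage sketch (not a file in the repository): reorder_image normalizes a CHW array to HWC, and to_y_channel converts a BGR image in [0, 255] to its Y channel, still in [0, 255] and unrounded.

    import numpy as np
    from basicsr.metrics.metric_util import reorder_image, to_y_channel

    img_chw = np.random.randint(0, 256, (3, 64, 64)).astype(np.float64)
    img_hwc = reorder_image(img_chw, input_order='CHW')   # -> (64, 64, 3)
    img_y = to_y_channel(img_hwc)                         # -> (64, 64, 1)
    print(img_hwc.shape, img_y.shape)
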
/basicsr/metrics/psnr_ssim.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
3 | # Copyright 2018-2020 BasicSR Authors
4 | # ------------------------------------------------------------------------
5 | import cv2
6 | import numpy as np
7 |
8 | from basicsr.metrics.metric_util import reorder_image, to_y_channel
9 | import skimage.metrics
10 | import torch
11 |
12 |
13 | def calculate_psnr(img1,
14 | img2,
15 | crop_border,
16 | input_order='HWC',
17 | test_y_channel=False):
18 | """Calculate PSNR (Peak Signal-to-Noise Ratio).
19 |
20 | Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
21 |
22 | Args:
23 | img1 (ndarray/tensor): Images with range [0, 255]/[0, 1].
24 | img2 (ndarray/tensor): Images with range [0, 255]/[0, 1].
25 | crop_border (int): Cropped pixels in each edge of an image. These
26 | pixels are not involved in the PSNR calculation.
27 | input_order (str): Whether the input order is 'HWC' or 'CHW'.
28 | Default: 'HWC'.
29 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
30 |
31 | Returns:
32 | float: psnr result.
33 | """
34 |
35 | assert img1.shape == img2.shape, (
36 |         f'Image shapes are different: {img1.shape}, {img2.shape}.')
37 | if input_order not in ['HWC', 'CHW']:
38 | raise ValueError(
39 | f'Wrong input_order {input_order}. Supported input_orders are '
40 | '"HWC" and "CHW"')
41 | if type(img1) == torch.Tensor:
42 | if len(img1.shape) == 4:
43 | img1 = img1.squeeze(0)
44 | img1 = img1.detach().cpu().numpy().transpose(1,2,0)
45 | if type(img2) == torch.Tensor:
46 | if len(img2.shape) == 4:
47 | img2 = img2.squeeze(0)
48 | img2 = img2.detach().cpu().numpy().transpose(1,2,0)
49 |
50 | img1 = reorder_image(img1, input_order=input_order)
51 | img2 = reorder_image(img2, input_order=input_order)
52 | img1 = img1.astype(np.float64)
53 | img2 = img2.astype(np.float64)
54 |
55 | if crop_border != 0:
56 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
57 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
58 |
59 | if test_y_channel:
60 | img1 = to_y_channel(img1)
61 | img2 = to_y_channel(img2)
62 |
63 | mse = np.mean((img1 - img2)**2)
64 | if mse == 0:
65 | return float('inf')
66 | max_value = 1. if img1.max() <= 1 else 255.
67 | return 20. * np.log10(max_value / np.sqrt(mse))
68 |
69 |
70 | def _ssim(img1, img2):
71 | """Calculate SSIM (structural similarity) for one channel images.
72 |
73 | It is called by func:`calculate_ssim`.
74 |
75 | Args:
76 | img1 (ndarray): Images with range [0, 255] with order 'HWC'.
77 | img2 (ndarray): Images with range [0, 255] with order 'HWC'.
78 |
79 | Returns:
80 | float: ssim result.
81 | """
82 |
83 | C1 = (0.01 * 255)**2
84 | C2 = (0.03 * 255)**2
85 |
86 | img1 = img1.astype(np.float64)
87 | img2 = img2.astype(np.float64)
88 | kernel = cv2.getGaussianKernel(11, 1.5)
89 | window = np.outer(kernel, kernel.transpose())
90 |
91 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
92 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
93 | mu1_sq = mu1**2
94 | mu2_sq = mu2**2
95 | mu1_mu2 = mu1 * mu2
96 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
97 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
98 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
99 |
100 | ssim_map = ((2 * mu1_mu2 + C1) *
101 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
102 | (sigma1_sq + sigma2_sq + C2))
103 | return ssim_map.mean()
104 |
105 | def prepare_for_ssim(img, k):
106 | import torch
107 | with torch.no_grad():
108 | img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()
109 | conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k//2, padding_mode='reflect')
110 | conv.weight.requires_grad = False
111 | conv.weight[:, :, :, :] = 1. / (k * k)
112 |
113 | img = conv(img)
114 |
115 | img = img.squeeze(0).squeeze(0)
116 | img = img[0::k, 0::k]
117 | return img.detach().cpu().numpy()
118 |
119 | def prepare_for_ssim_rgb(img, k):
120 | import torch
121 | with torch.no_grad():
122 | img = torch.from_numpy(img).float() #HxWx3
123 |
124 | conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k // 2, padding_mode='reflect')
125 | conv.weight.requires_grad = False
126 | conv.weight[:, :, :, :] = 1. / (k * k)
127 |
128 | new_img = []
129 |
130 | for i in range(3):
131 | new_img.append(conv(img[:, :, i].unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)[0::k, 0::k])
132 |
133 | return torch.stack(new_img, dim=2).detach().cpu().numpy()
134 |
135 | def _3d_gaussian_calculator(img, conv3d):
136 | out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
137 | return out
138 |
139 | def _generate_3d_gaussian_kernel():
140 | kernel = cv2.getGaussianKernel(11, 1.5)
141 | window = np.outer(kernel, kernel.transpose())
142 | kernel_3 = cv2.getGaussianKernel(11, 1.5)
143 | kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0))
144 | conv3d = torch.nn.Conv3d(1, 1, (11, 11, 11), stride=1, padding=(5, 5, 5), bias=False, padding_mode='replicate')
145 | conv3d.weight.requires_grad = False
146 | conv3d.weight[0, 0, :, :, :] = kernel
147 | return conv3d
148 |
149 | def _ssim_3d(img1, img2, max_value):
150 |     """Calculate SSIM (structural similarity) with a 3-D Gaussian window.
151 |
152 |     It is called by func:`calculate_ssim`.
153 |
154 |     Args:
155 |         img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
156 |         img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
157 |
158 |     Returns:
159 |         float: ssim result.
160 |     """
161 |     assert len(img1.shape) == 3 and len(img2.shape) == 3
162 | C1 = (0.01 * max_value) ** 2
163 | C2 = (0.03 * max_value) ** 2
164 | img1 = img1.astype(np.float64)
165 | img2 = img2.astype(np.float64)
166 |
167 | kernel = _generate_3d_gaussian_kernel().cuda()
168 |
169 | img1 = torch.tensor(img1).float().cuda()
170 | img2 = torch.tensor(img2).float().cuda()
171 |
172 |
173 | mu1 = _3d_gaussian_calculator(img1, kernel)
174 | mu2 = _3d_gaussian_calculator(img2, kernel)
175 |
176 | mu1_sq = mu1 ** 2
177 | mu2_sq = mu2 ** 2
178 | mu1_mu2 = mu1 * mu2
179 | sigma1_sq = _3d_gaussian_calculator(img1 ** 2, kernel) - mu1_sq
180 | sigma2_sq = _3d_gaussian_calculator(img2 ** 2, kernel) - mu2_sq
181 | sigma12 = _3d_gaussian_calculator(img1*img2, kernel) - mu1_mu2
182 |
183 | ssim_map = ((2 * mu1_mu2 + C1) *
184 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
185 | (sigma1_sq + sigma2_sq + C2))
186 | return float(ssim_map.mean())
187 |
188 | def _ssim_cly(img1, img2):
189 |     """Calculate SSIM (structural similarity) for single-channel (Y) images.
190 |
191 |     It is called by func:`calculate_ssim`.
192 |
193 |     Args:
194 |         img1 (ndarray): Images with range [0, 255] with order 'HW'.
195 |         img2 (ndarray): Images with range [0, 255] with order 'HW'.
196 |
197 |     Returns:
198 |         float: ssim result.
199 |     """
200 |     assert len(img1.shape) == 2 and len(img2.shape) == 2
201 |
202 | C1 = (0.01 * 255)**2
203 | C2 = (0.03 * 255)**2
204 | img1 = img1.astype(np.float64)
205 | img2 = img2.astype(np.float64)
206 |
207 | kernel = cv2.getGaussianKernel(11, 1.5)
208 | # print(kernel)
209 | window = np.outer(kernel, kernel.transpose())
210 |
211 | bt = cv2.BORDER_REPLICATE
212 |
213 | mu1 = cv2.filter2D(img1, -1, window, borderType=bt)
214 | mu2 = cv2.filter2D(img2, -1, window,borderType=bt)
215 |
216 | mu1_sq = mu1**2
217 | mu2_sq = mu2**2
218 | mu1_mu2 = mu1 * mu2
219 | sigma1_sq = cv2.filter2D(img1**2, -1, window, borderType=bt) - mu1_sq
220 | sigma2_sq = cv2.filter2D(img2**2, -1, window, borderType=bt) - mu2_sq
221 | sigma12 = cv2.filter2D(img1 * img2, -1, window, borderType=bt) - mu1_mu2
222 |
223 | ssim_map = ((2 * mu1_mu2 + C1) *
224 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
225 | (sigma1_sq + sigma2_sq + C2))
226 | return ssim_map.mean()
227 |
228 |
229 | def calculate_ssim(img1,
230 | img2,
231 | crop_border,
232 | input_order='HWC',
233 | test_y_channel=False):
234 | """Calculate SSIM (structural similarity).
235 |
236 | Ref:
237 | Image quality assessment: From error visibility to structural similarity
238 |
239 | The results are the same as that of the official released MATLAB code in
240 | https://ece.uwaterloo.ca/~z70wang/research/ssim/.
241 |
242 | For three-channel images, SSIM is calculated for each channel and then
243 | averaged.
244 |
245 | Args:
246 | img1 (ndarray): Images with range [0, 255].
247 | img2 (ndarray): Images with range [0, 255].
248 | crop_border (int): Cropped pixels in each edge of an image. These
249 | pixels are not involved in the SSIM calculation.
250 | input_order (str): Whether the input order is 'HWC' or 'CHW'.
251 | Default: 'HWC'.
252 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
253 |
254 | Returns:
255 | float: ssim result.
256 | """
257 |
258 | assert img1.shape == img2.shape, (
259 |         f'Image shapes are different: {img1.shape}, {img2.shape}.')
260 | if input_order not in ['HWC', 'CHW']:
261 | raise ValueError(
262 | f'Wrong input_order {input_order}. Supported input_orders are '
263 | '"HWC" and "CHW"')
264 |
265 | if type(img1) == torch.Tensor:
266 | if len(img1.shape) == 4:
267 | img1 = img1.squeeze(0)
268 | img1 = img1.detach().cpu().numpy().transpose(1,2,0)
269 | if type(img2) == torch.Tensor:
270 | if len(img2.shape) == 4:
271 | img2 = img2.squeeze(0)
272 | img2 = img2.detach().cpu().numpy().transpose(1,2,0)
273 |
274 | img1 = reorder_image(img1, input_order=input_order)
275 | img2 = reorder_image(img2, input_order=input_order)
276 |
277 | img1 = img1.astype(np.float64)
278 | img2 = img2.astype(np.float64)
279 |
280 | if crop_border != 0:
281 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
282 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
283 |
284 | if test_y_channel:
285 | img1 = to_y_channel(img1)
286 | img2 = to_y_channel(img2)
287 | return _ssim_cly(img1[..., 0], img2[..., 0])
288 |
289 |
290 | ssims = []
291 |
292 | max_value = 1 if img1.max() <= 1 else 255
293 | with torch.no_grad():
294 | final_ssim = _ssim_3d(img1, img2, max_value)
295 | ssims.append(final_ssim)
296 |
297 | return np.array(ssims).mean()
298 |
--------------------------------------------------------------------------------
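
Usage sketch (not a file in the repository): both metrics take HWC arrays (or CHW via input_order) in the [0, 255] range; crop_border trims edge pixels before scoring and test_y_channel switches to Y-channel evaluation. Note that calculate_ssim runs _ssim_3d on the GPU (it calls .cuda()), so it needs a CUDA device.

    import numpy as np
    from basicsr.metrics import calculate_psnr, calculate_ssim

    gt = np.random.randint(0, 256, (128, 128, 3)).astype(np.float64)
    restored = np.clip(gt + np.random.randn(128, 128, 3) * 5, 0, 255)

    psnr = calculate_psnr(restored, gt, crop_border=4)
    ssim = calculate_ssim(restored, gt, crop_border=4)   # requires CUDA
    print(f'PSNR: {psnr:.2f} dB, SSIM: {ssim:.4f}')
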
/basicsr/models/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2021 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5 | # Copyright 2018-2020 BasicSR Authors
6 | # ------------------------------------------------------------------------
7 | import importlib
8 | from os import path as osp
9 |
10 | from basicsr.utils import get_root_logger, scandir
11 |
12 | # automatically scan and import model modules
13 | # scan all the files under the 'models' folder and collect files ending with
14 | # '_model.py'
15 | model_folder = osp.dirname(osp.abspath(__file__))
16 | model_filenames = [
17 | osp.splitext(osp.basename(v))[0] for v in scandir(model_folder)
18 | if v.endswith('_model.py')
19 | ]
20 | # import all the model modules
21 | _model_modules = [
22 | importlib.import_module(f'basicsr.models.{file_name}')
23 | for file_name in model_filenames
24 | ]
25 |
26 |
27 | def create_model(opt):
28 | """Create model.
29 |
30 | Args:
31 |         opt (dict): Configuration. It contains:
32 | model_type (str): Model type.
33 | """
34 | model_type = opt['model_type']
35 |
36 | # dynamic instantiation
37 | for module in _model_modules:
38 | model_cls = getattr(module, model_type, None)
39 | if model_cls is not None:
40 | break
41 | if model_cls is None:
42 | raise ValueError(f'Model {model_type} is not found.')
43 |
44 | model = model_cls(opt)
45 |
46 | logger = get_root_logger()
47 | logger.info(f'Model [{model.__class__.__name__}] is created.')
48 | return model
49 |
--------------------------------------------------------------------------------
/basicsr/models/base_model.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2021 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5 | # Copyright 2018-2020 BasicSR Authors
6 | # ------------------------------------------------------------------------
7 | import logging
8 | import os
9 | import torch
10 | from collections import OrderedDict
11 | from copy import deepcopy
12 | from torch.nn.parallel import DataParallel, DistributedDataParallel
13 |
14 | from basicsr.models import lr_scheduler as lr_scheduler
15 | from basicsr.utils.dist_util import master_only
16 |
17 | logger = logging.getLogger('basicsr')
18 |
19 |
20 | class BaseModel():
21 | """Base model."""
22 |
23 | def __init__(self, opt):
24 | self.opt = opt
25 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
26 | self.is_train = opt['is_train']
27 | self.schedulers = []
28 | self.optimizers = []
29 |
30 | def feed_data(self, data):
31 | pass
32 |
33 | def optimize_parameters(self):
34 | pass
35 |
36 | def get_current_visuals(self):
37 | pass
38 |
39 | def save(self, epoch, current_iter):
40 | """Save networks and training state."""
41 | pass
42 |
43 | def validation(self, dataloader, current_iter, tb_logger, save_img=False, rgb2bgr=True, use_image=True):
44 | """Validation function.
45 |
46 | Args:
47 | dataloader (torch.utils.data.DataLoader): Validation dataloader.
48 | current_iter (int): Current iteration.
49 | tb_logger (tensorboard logger): Tensorboard logger.
50 | save_img (bool): Whether to save images. Default: False.
51 |             rgb2bgr (bool): Whether to save images using rgb2bgr. Default: True.
52 |             use_image (bool): Whether to use saved images to compute metrics (PSNR, SSIM); if not, metrics are computed directly on the network's output. Default: True.
53 | """
54 | if self.opt['dist']:
55 | return self.dist_validation(dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image)
56 | else:
57 | return self.nondist_validation(dataloader, current_iter, tb_logger,
58 | save_img, rgb2bgr, use_image)
59 |
60 | def get_current_log(self):
61 | return self.log_dict
62 |
63 | def model_to_device(self, net):
64 |         """Move the model to the target device and, when needed, wrap it with
65 |         DistributedDataParallel or DataParallel.
66 |
67 | Args:
68 | net (nn.Module)
69 | """
70 |
71 | net = net.to(self.device)
72 | if self.opt['dist']:
73 | find_unused_parameters = True
74 | net = DistributedDataParallel(
75 | net,
76 | device_ids=[torch.cuda.current_device()],
77 | find_unused_parameters=find_unused_parameters)
78 | elif self.opt['num_gpu'] > 1:
79 | net = DataParallel(net)
80 | return net
81 |
82 | def setup_schedulers(self):
83 | """Set up schedulers."""
84 | train_opt = self.opt['train']
85 | scheduler_type = train_opt['scheduler'].pop('type')
86 | if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
87 | for optimizer in self.optimizers:
88 | self.schedulers.append(
89 | lr_scheduler.MultiStepRestartLR(optimizer,
90 | **train_opt['scheduler']))
91 | elif scheduler_type == 'CosineAnnealingRestartLR':
92 | for optimizer in self.optimizers:
93 | self.schedulers.append(
94 | lr_scheduler.CosineAnnealingRestartLR(
95 | optimizer, **train_opt['scheduler']))
96 | elif scheduler_type == 'TrueCosineAnnealingLR':
97 | print('..', 'cosineannealingLR')
98 | for optimizer in self.optimizers:
99 | self.schedulers.append(
100 | torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **train_opt['scheduler']))
101 | elif scheduler_type == 'LinearLR':
102 | for optimizer in self.optimizers:
103 | self.schedulers.append(
104 | lr_scheduler.LinearLR(
105 | optimizer, train_opt['total_iter']))
106 | elif scheduler_type == 'VibrateLR':
107 | for optimizer in self.optimizers:
108 | self.schedulers.append(
109 | lr_scheduler.VibrateLR(
110 | optimizer, train_opt['total_iter']))
111 | else:
112 | raise NotImplementedError(
113 | f'Scheduler {scheduler_type} is not implemented yet.')
114 |
115 | def get_bare_model(self, net):
116 |         """Get the bare model, especially when it is wrapped with
117 |         DistributedDataParallel or DataParallel.
118 | """
119 | if isinstance(net, (DataParallel, DistributedDataParallel)):
120 | net = net.module
121 | return net
122 |
123 | @master_only
124 | def print_network(self, net):
125 | """Print the str and parameter number of a network.
126 |
127 | Args:
128 | net (nn.Module)
129 | """
130 | if isinstance(net, (DataParallel, DistributedDataParallel)):
131 | net_cls_str = (f'{net.__class__.__name__} - '
132 | f'{net.module.__class__.__name__}')
133 | else:
134 | net_cls_str = f'{net.__class__.__name__}'
135 |
136 | net = self.get_bare_model(net)
137 | net_str = str(net)
138 | net_params = sum(map(lambda x: x.numel(), net.parameters()))
139 |
140 | logger.info(
141 | f'Network: {net_cls_str}, with parameters: {net_params:,d}')
142 | logger.info(net_str)
143 |
144 | def _set_lr(self, lr_groups_l):
145 | """Set learning rate for warmup.
146 |
147 | Args:
148 | lr_groups_l (list): List for lr_groups, each for an optimizer.
149 | """
150 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
151 | for param_group, lr in zip(optimizer.param_groups, lr_groups):
152 | param_group['lr'] = lr
153 |
154 | def _get_init_lr(self):
155 | """Get the initial lr, which is set by the scheduler.
156 | """
157 | init_lr_groups_l = []
158 | for optimizer in self.optimizers:
159 | init_lr_groups_l.append(
160 | [v['initial_lr'] for v in optimizer.param_groups])
161 | return init_lr_groups_l
162 |
163 | def update_learning_rate(self, current_iter, warmup_iter=-1):
164 | """Update learning rate.
165 |
166 | Args:
167 | current_iter (int): Current iteration.
168 | warmup_iter (int): Warmup iter numbers. -1 for no warmup.
169 | Default: -1.
170 | """
171 | if current_iter > 1:
172 | for scheduler in self.schedulers:
173 | scheduler.step()
174 | # set up warm-up learning rate
175 | if current_iter < warmup_iter:
176 | # get initial lr for each group
177 | init_lr_g_l = self._get_init_lr()
178 | # modify warming-up learning rates
179 | # currently only support linearly warm up
180 | warm_up_lr_l = []
181 | for init_lr_g in init_lr_g_l:
182 | warm_up_lr_l.append(
183 | [v / warmup_iter * current_iter for v in init_lr_g])
184 | # set learning rate
185 | self._set_lr(warm_up_lr_l)
186 |
187 | def get_current_learning_rate(self):
188 | return [
189 | param_group['lr']
190 | for param_group in self.optimizers[0].param_groups
191 | ]
192 |
193 | @master_only
194 | def save_network(self, net, net_label, current_iter, param_key='params'):
195 | """Save networks.
196 |
197 | Args:
198 | net (nn.Module | list[nn.Module]): Network(s) to be saved.
199 | net_label (str): Network label.
200 | current_iter (int): Current iter number.
201 | param_key (str | list[str]): The parameter key(s) to save network.
202 | Default: 'params'.
203 | """
204 | if current_iter == -1:
205 | current_iter = 'latest'
206 | save_filename = f'{net_label}_{current_iter}.pth'
207 | save_path = os.path.join(self.opt['path']['models'], save_filename)
208 |
209 | net = net if isinstance(net, list) else [net]
210 | param_key = param_key if isinstance(param_key, list) else [param_key]
211 | assert len(net) == len(
212 | param_key), 'The lengths of net and param_key should be the same.'
213 |
214 | save_dict = {}
215 | for net_, param_key_ in zip(net, param_key):
216 | net_ = self.get_bare_model(net_)
217 | state_dict = net_.state_dict()
218 | for key, param in state_dict.items():
219 | if key.startswith('module.'): # remove unnecessary 'module.'
220 | key = key[7:]
221 | state_dict[key] = param.cpu()
222 | save_dict[param_key_] = state_dict
223 |
224 | torch.save(save_dict, save_path)
225 |
226 | def _print_different_keys_loading(self, crt_net, load_net, strict=True):
227 |         """Print keys with different names or different sizes when loading models.
228 |
229 |         1. Print keys with different names.
230 |         2. If strict=False, print keys that share a name but differ in tensor
231 |             size; such keys are ignored (not loaded).
232 |
233 | Args:
234 | crt_net (torch model): Current network.
235 | load_net (dict): Loaded network.
236 | strict (bool): Whether strictly loaded. Default: True.
237 | """
238 | crt_net = self.get_bare_model(crt_net)
239 | crt_net = crt_net.state_dict()
240 | crt_net_keys = set(crt_net.keys())
241 | load_net_keys = set(load_net.keys())
242 |
243 | if crt_net_keys != load_net_keys:
244 | logger.warning('Current net - loaded net:')
245 | for v in sorted(list(crt_net_keys - load_net_keys)):
246 | logger.warning(f' {v}')
247 | logger.warning('Loaded net - current net:')
248 | for v in sorted(list(load_net_keys - crt_net_keys)):
249 | logger.warning(f' {v}')
250 |
251 | # check the size for the same keys
252 | if not strict:
253 | common_keys = crt_net_keys & load_net_keys
254 | for k in common_keys:
255 | if crt_net[k].size() != load_net[k].size():
256 | logger.warning(
257 | f'Size different, ignore [{k}]: crt_net: '
258 | f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
259 | load_net[k + '.ignore'] = load_net.pop(k)
260 |
261 | def load_network(self, net, load_path, strict=True, param_key='params'):
262 | """Load network.
263 |
264 | Args:
265 | load_path (str): The path of networks to be loaded.
266 | net (nn.Module): Network.
267 | strict (bool): Whether strictly loaded.
268 | param_key (str): The parameter key of loaded network. If set to
269 | None, use the root 'path'.
270 | Default: 'params'.
271 | """
272 | net = self.get_bare_model(net)
273 | logger.info(
274 | f'Loading {net.__class__.__name__} model from {load_path}.')
275 | load_net = torch.load(
276 | load_path, map_location=lambda storage, loc: storage)
277 | if param_key is not None:
278 | load_net = load_net[param_key]
279 |             print(' load net keys', load_net.keys())
280 | # remove unnecessary 'module.'
281 | for k, v in deepcopy(load_net).items():
282 | if k.startswith('module.'):
283 | load_net[k[7:]] = v
284 | load_net.pop(k)
285 | self._print_different_keys_loading(net, load_net, strict)
286 | net.load_state_dict(load_net, strict=strict)
287 |
288 | @master_only
289 | def save_training_state(self, epoch, current_iter):
290 | """Save training states during training, which will be used for
291 | resuming.
292 |
293 | Args:
294 | epoch (int): Current epoch.
295 | current_iter (int): Current iteration.
296 | """
297 | if current_iter != -1:
298 | state = {
299 | 'epoch': epoch,
300 | 'iter': current_iter,
301 | 'optimizers': [],
302 | 'schedulers': []
303 | }
304 | for o in self.optimizers:
305 | state['optimizers'].append(o.state_dict())
306 | for s in self.schedulers:
307 | state['schedulers'].append(s.state_dict())
308 | save_filename = f'{current_iter}.state'
309 | save_path = os.path.join(self.opt['path']['training_states'],
310 | save_filename)
311 | torch.save(state, save_path)
312 |
313 | def resume_training(self, resume_state):
314 | """Reload the optimizers and schedulers for resumed training.
315 |
316 | Args:
317 | resume_state (dict): Resume state.
318 | """
319 | resume_optimizers = resume_state['optimizers']
320 | resume_schedulers = resume_state['schedulers']
321 | assert len(resume_optimizers) == len(
322 | self.optimizers), 'Wrong lengths of optimizers'
323 | assert len(resume_schedulers) == len(
324 | self.schedulers), 'Wrong lengths of schedulers'
325 | for i, o in enumerate(resume_optimizers):
326 | self.optimizers[i].load_state_dict(o)
327 | for i, s in enumerate(resume_schedulers):
328 | self.schedulers[i].load_state_dict(s)
329 |
330 | def gather_tensors(self, tensor_to_gather):
331 | if tensor_to_gather is None:
332 | return None
333 | group = torch.distributed.group.WORLD
334 | group_size = torch.distributed.get_world_size(group)
335 | gather_t_tensors = [torch.zeros_like(tensor_to_gather) for _
336 | in range(group_size)]
337 | torch.distributed.all_gather(gather_t_tensors, tensor_to_gather)
338 | return torch.cat(gather_t_tensors, dim=0)
339 |
340 | def reduce_loss_dict(self, loss_dict):
341 | """reduce loss dict.
342 |
343 | In distributed training, it averages the losses among different GPUs.
344 |
345 | Args:
346 | loss_dict (OrderedDict): Loss dict.
347 | """
348 | with torch.no_grad():
349 | if self.opt['dist']:
350 | keys = []
351 | losses = []
352 | for name, value in loss_dict.items():
353 | keys.append(name)
354 | losses.append(value)
355 | losses = torch.stack(losses, 0)
356 | torch.distributed.reduce(losses, dst=0)
357 | if self.opt['rank'] == 0:
358 | losses /= self.opt['world_size']
359 | loss_dict = {key: loss for key, loss in zip(keys, losses)}
360 |
361 | log_dict = OrderedDict()
362 | for name, value in loss_dict.items():
363 | log_dict[name] = value.mean().item()
364 |
365 | return log_dict
366 |
--------------------------------------------------------------------------------
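A small numeric sketch (not repo code) of the linear warm-up used by BaseModel.update_learning_rate above; the learning-rate value is an assumed example:

def warmup_lr(init_lr_groups_l, current_iter, warmup_iter):
    # mirrors the expression `v / warmup_iter * current_iter` in update_learning_rate()
    return [[lr / warmup_iter * current_iter for lr in groups]
            for groups in init_lr_groups_l]

init_lr_groups_l = [[1e-3]]                       # one optimizer, one param group (assumed lr)
print(warmup_lr(init_lr_groups_l, 100, 1000))     # [[0.0001]]: 10% of the way through warm-up
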
/basicsr/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2021 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5 | # Copyright 2018-2020 BasicSR Authors
6 | # ------------------------------------------------------------------------
7 | from .losses import (L1Loss, MSELoss, PSNRLoss)
8 |
9 | __all__ = [
10 | 'L1Loss', 'MSELoss', 'PSNRLoss',
11 | ]
12 |
--------------------------------------------------------------------------------
/basicsr/models/losses/loss_util.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2021 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5 | # Copyright 2018-2020 BasicSR Authors
6 | # ------------------------------------------------------------------------
7 | import functools
8 | from torch.nn import functional as F
9 |
10 |
11 | def reduce_loss(loss, reduction):
12 | """Reduce loss as specified.
13 |
14 | Args:
15 | loss (Tensor): Elementwise loss tensor.
16 | reduction (str): Options are 'none', 'mean' and 'sum'.
17 |
18 | Returns:
19 | Tensor: Reduced loss tensor.
20 | """
21 | reduction_enum = F._Reduction.get_enum(reduction)
22 | # none: 0, elementwise_mean:1, sum: 2
23 | if reduction_enum == 0:
24 | return loss
25 | elif reduction_enum == 1:
26 | return loss.mean()
27 | else:
28 | return loss.sum()
29 |
30 |
31 | def weight_reduce_loss(loss, weight=None, reduction='mean'):
32 | """Apply element-wise weight and reduce loss.
33 |
34 | Args:
35 | loss (Tensor): Element-wise loss.
36 | weight (Tensor): Element-wise weights. Default: None.
37 | reduction (str): Same as built-in losses of PyTorch. Options are
38 | 'none', 'mean' and 'sum'. Default: 'mean'.
39 |
40 | Returns:
41 | Tensor: Loss values.
42 | """
43 | # if weight is specified, apply element-wise weight
44 | if weight is not None:
45 | assert weight.dim() == loss.dim()
46 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
47 | loss = loss * weight
48 |
49 | # if weight is not specified or reduction is sum, just reduce the loss
50 | if weight is None or reduction == 'sum':
51 | loss = reduce_loss(loss, reduction)
52 | # if reduction is mean, then compute mean over weight region
53 | elif reduction == 'mean':
54 | if weight.size(1) > 1:
55 | weight = weight.sum()
56 | else:
57 | weight = weight.sum() * loss.size(1)
58 | loss = loss.sum() / weight
59 |
60 | return loss
61 |
62 |
63 | def weighted_loss(loss_func):
64 | """Create a weighted version of a given loss function.
65 |
66 | To use this decorator, the loss function must have the signature like
67 | `loss_func(pred, target, **kwargs)`. The function only needs to compute
68 | element-wise loss without any reduction. This decorator will add weight
69 | and reduction arguments to the function. The decorated function will have
70 | the signature like `loss_func(pred, target, weight=None, reduction='mean',
71 | **kwargs)`.
72 |
73 | :Example:
74 |
75 | >>> import torch
76 | >>> @weighted_loss
77 | >>> def l1_loss(pred, target):
78 | >>> return (pred - target).abs()
79 |
80 | >>> pred = torch.Tensor([0, 2, 3])
81 | >>> target = torch.Tensor([1, 1, 1])
82 | >>> weight = torch.Tensor([1, 0, 1])
83 |
84 | >>> l1_loss(pred, target)
85 | tensor(1.3333)
86 | >>> l1_loss(pred, target, weight)
87 | tensor(1.5000)
88 | >>> l1_loss(pred, target, reduction='none')
89 | tensor([1., 1., 2.])
90 | >>> l1_loss(pred, target, weight, reduction='sum')
91 | tensor(3.)
92 | """
93 |
94 | @functools.wraps(loss_func)
95 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
96 | # get element-wise loss
97 | loss = loss_func(pred, target, **kwargs)
98 | loss = weight_reduce_loss(loss, weight, reduction)
99 | return loss
100 |
101 | return wrapper
102 |
--------------------------------------------------------------------------------
/basicsr/models/losses/losses.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2021 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5 | # Copyright 2018-2020 BasicSR Authors
6 | # ------------------------------------------------------------------------
7 | import torch
8 | from torch import nn as nn
9 | from torch.nn import functional as F
10 | import numpy as np
11 |
12 | from basicsr.models.losses.loss_util import weighted_loss
13 |
14 | _reduction_modes = ['none', 'mean', 'sum']
15 |
16 |
17 | @weighted_loss
18 | def l1_loss(pred, target):
19 | return F.l1_loss(pred, target, reduction='none')
20 |
21 |
22 | @weighted_loss
23 | def mse_loss(pred, target):
24 | return F.mse_loss(pred, target, reduction='none')
25 |
26 |
27 | # @weighted_loss
28 | # def charbonnier_loss(pred, target, eps=1e-12):
29 | # return torch.sqrt((pred - target)**2 + eps)
30 |
31 |
32 | class L1Loss(nn.Module):
33 | """L1 (mean absolute error, MAE) loss.
34 |
35 | Args:
36 | loss_weight (float): Loss weight for L1 loss. Default: 1.0.
37 | reduction (str): Specifies the reduction to apply to the output.
38 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
39 | """
40 |
41 | def __init__(self, loss_weight=1.0, reduction='mean'):
42 | super(L1Loss, self).__init__()
43 | if reduction not in ['none', 'mean', 'sum']:
44 | raise ValueError(f'Unsupported reduction mode: {reduction}. '
45 | f'Supported ones are: {_reduction_modes}')
46 |
47 | self.loss_weight = loss_weight
48 | self.reduction = reduction
49 |
50 | def forward(self, pred, target, weight=None, **kwargs):
51 | """
52 | Args:
53 | pred (Tensor): of shape (N, C, H, W). Predicted tensor.
54 | target (Tensor): of shape (N, C, H, W). Ground truth tensor.
55 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise
56 | weights. Default: None.
57 | """
58 | return self.loss_weight * l1_loss(
59 | pred, target, weight, reduction=self.reduction)
60 |
61 | class MSELoss(nn.Module):
62 | """MSE (L2) loss.
63 |
64 | Args:
65 | loss_weight (float): Loss weight for MSE loss. Default: 1.0.
66 | reduction (str): Specifies the reduction to apply to the output.
67 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
68 | """
69 |
70 | def __init__(self, loss_weight=1.0, reduction='mean'):
71 | super(MSELoss, self).__init__()
72 | if reduction not in ['none', 'mean', 'sum']:
73 | raise ValueError(f'Unsupported reduction mode: {reduction}. '
74 | f'Supported ones are: {_reduction_modes}')
75 |
76 | self.loss_weight = loss_weight
77 | self.reduction = reduction
78 |
79 | def forward(self, pred, target, weight=None, **kwargs):
80 | """
81 | Args:
82 | pred (Tensor): of shape (N, C, H, W). Predicted tensor.
83 | target (Tensor): of shape (N, C, H, W). Ground truth tensor.
84 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise
85 | weights. Default: None.
86 | """
87 | return self.loss_weight * mse_loss(
88 | pred, target, weight, reduction=self.reduction)
89 |
90 | class PSNRLoss(nn.Module):
91 |
92 | def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
93 | super(PSNRLoss, self).__init__()
94 | assert reduction == 'mean'
95 | self.loss_weight = loss_weight
96 | self.scale = 10 / np.log(10)
97 | self.toY = toY
98 | self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
99 | self.first = True
100 |
101 | def forward(self, pred, target):
102 | assert len(pred.size()) == 4
103 | if self.toY:
104 | if self.first:
105 | self.coef = self.coef.to(pred.device)
106 | self.first = False
107 |
108 | pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
109 | target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
110 |
111 | pred, target = pred / 255., target / 255.
112 | pass
113 | assert len(pred.size()) == 4
114 |
115 | return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()
116 |
117 |
--------------------------------------------------------------------------------
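A quick check (not repo code) that PSNRLoss above is, up to the 1e-8 epsilon, the negative PSNR in dB for inputs in [0, 1], since 10/ln(10) * ln(MSE) = 10*log10(MSE) = -PSNR:

import torch

from basicsr.models.losses import PSNRLoss

pred = torch.rand(1, 3, 8, 8)
target = torch.rand(1, 3, 8, 8)

mse = ((pred - target) ** 2).mean()
psnr = 10 * torch.log10(1.0 / mse)     # standard PSNR for data in [0, 1]
loss = PSNRLoss()(pred, target)

print(loss.item(), (-psnr).item())     # approximately equal
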
/basicsr/models/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2021 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5 | # Copyright 2018-2020 BasicSR Authors
6 | # ------------------------------------------------------------------------
7 | import math
8 | from collections import Counter
9 | from torch.optim.lr_scheduler import _LRScheduler
10 |
11 |
12 | class MultiStepRestartLR(_LRScheduler):
13 | """ MultiStep with restarts learning rate scheme.
14 |
15 | Args:
16 | optimizer (torch.nn.optimizer): Torch optimizer.
17 | milestones (list): Iterations that will decrease learning rate.
18 | gamma (float): Decrease ratio. Default: 0.1.
19 | restarts (list): Restart iterations. Default: [0].
20 | restart_weights (list): Restart weights at each restart iteration.
21 | Default: [1].
22 | last_epoch (int): Used in _LRScheduler. Default: -1.
23 | """
24 |
25 | def __init__(self,
26 | optimizer,
27 | milestones,
28 | gamma=0.1,
29 | restarts=(0, ),
30 | restart_weights=(1, ),
31 | last_epoch=-1):
32 | self.milestones = Counter(milestones)
33 | self.gamma = gamma
34 | self.restarts = restarts
35 | self.restart_weights = restart_weights
36 | assert len(self.restarts) == len(
37 | self.restart_weights), 'restarts and their weights do not match.'
38 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
39 |
40 | def get_lr(self):
41 | if self.last_epoch in self.restarts:
42 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
43 | return [
44 | group['initial_lr'] * weight
45 | for group in self.optimizer.param_groups
46 | ]
47 | if self.last_epoch not in self.milestones:
48 | return [group['lr'] for group in self.optimizer.param_groups]
49 | return [
50 | group['lr'] * self.gamma**self.milestones[self.last_epoch]
51 | for group in self.optimizer.param_groups
52 | ]
53 |
54 | class LinearLR(_LRScheduler):
55 |     """Linearly decay the learning rate to zero over ``total_iter`` iterations.
56 |
57 |     Args:
58 |         optimizer (torch.nn.optimizer): Torch optimizer.
59 |         total_iter (int): Total number of training iterations; the lr weight
60 |             at iteration t is (1 - t / total_iter).
61 |         last_epoch (int): Used in _LRScheduler. Default: -1.
62 |     """
63 |
64 | def __init__(self,
65 | optimizer,
66 | total_iter,
67 | last_epoch=-1):
68 | self.total_iter = total_iter
69 | super(LinearLR, self).__init__(optimizer, last_epoch)
70 |
71 | def get_lr(self):
72 | process = self.last_epoch / self.total_iter
73 | weight = (1 - process)
74 | # print('get lr ', [weight * group['initial_lr'] for group in self.optimizer.param_groups])
75 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups]
76 |
77 | class VibrateLR(_LRScheduler):
78 |     """Oscillating learning rate scheme: the lr follows a triangular wave whose
79 |     envelope decays over training (see ``get_lr`` for the exact schedule).
80 |
81 |     Args:
82 |         optimizer (torch.nn.optimizer): Torch optimizer.
83 |         total_iter (int): Total number of training iterations.
84 |         last_epoch (int): Used in _LRScheduler. Default: -1.
85 |     """
86 |
87 | def __init__(self,
88 | optimizer,
89 | total_iter,
90 | last_epoch=-1):
91 | self.total_iter = total_iter
92 | super(VibrateLR, self).__init__(optimizer, last_epoch)
93 |
94 | def get_lr(self):
95 | process = self.last_epoch / self.total_iter
96 |
97 | f = 0.1
98 | if process < 3 / 8:
99 | f = 1 - process * 8 / 3
100 | elif process < 5 / 8:
101 | f = 0.2
102 |
103 | T = self.total_iter // 80
104 | Th = T // 2
105 |
106 | t = self.last_epoch % T
107 |
108 | f2 = t / Th
109 | if t >= Th:
110 | f2 = 2 - f2
111 |
112 | weight = f * f2
113 |
114 | if self.last_epoch < Th:
115 | weight = max(0.1, weight)
116 |
117 | # print('f {}, T {}, Th {}, t {}, f2 {}'.format(f, T, Th, t, f2))
118 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups]
119 |
120 | def get_position_from_periods(iteration, cumulative_period):
121 | """Get the position from a period list.
122 |
123 | It will return the index of the right-closest number in the period list.
124 | For example, the cumulative_period = [100, 200, 300, 400],
125 | if iteration == 50, return 0;
126 | if iteration == 210, return 2;
127 | if iteration == 300, return 2.
128 |
129 | Args:
130 | iteration (int): Current iteration.
131 | cumulative_period (list[int]): Cumulative period list.
132 |
133 | Returns:
134 | int: The position of the right-closest number in the period list.
135 | """
136 | for i, period in enumerate(cumulative_period):
137 | if iteration <= period:
138 | return i
139 |
140 |
141 | class CosineAnnealingRestartLR(_LRScheduler):
142 | """ Cosine annealing with restarts learning rate scheme.
143 |
144 | An example of config:
145 | periods = [10, 10, 10, 10]
146 | restart_weights = [1, 0.5, 0.5, 0.5]
147 | eta_min=1e-7
148 |
149 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
150 | scheduler will restart with the weights in restart_weights.
151 |
152 | Args:
153 | optimizer (torch.nn.optimizer): Torch optimizer.
154 |         periods (list): Period for each cosine annealing cycle.
155 | restart_weights (list): Restart weights at each restart iteration.
156 | Default: [1].
157 |         eta_min (float): The minimum lr. Default: 0.
158 | last_epoch (int): Used in _LRScheduler. Default: -1.
159 | """
160 |
161 | def __init__(self,
162 | optimizer,
163 | periods,
164 | restart_weights=(1, ),
165 | eta_min=0,
166 | last_epoch=-1):
167 | self.periods = periods
168 | self.restart_weights = restart_weights
169 | self.eta_min = eta_min
170 | assert (len(self.periods) == len(self.restart_weights)
171 | ), 'periods and restart_weights should have the same length.'
172 | self.cumulative_period = [
173 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
174 | ]
175 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
176 |
177 | def get_lr(self):
178 | idx = get_position_from_periods(self.last_epoch,
179 | self.cumulative_period)
180 | current_weight = self.restart_weights[idx]
181 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
182 | current_period = self.periods[idx]
183 |
184 | return [
185 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
186 | (1 + math.cos(math.pi * (
187 | (self.last_epoch - nearest_restart) / current_period)))
188 | for base_lr in self.base_lrs
189 | ]
190 |
--------------------------------------------------------------------------------
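A short usage sketch (not repo code) of CosineAnnealingRestartLR with the example configuration from its docstring; the optimizer and base lr are assumptions:

import torch

from basicsr.models.lr_scheduler import CosineAnnealingRestartLR

param = torch.nn.Parameter(torch.zeros(1))          # one dummy parameter to schedule
optimizer = torch.optim.SGD([param], lr=1e-3)
scheduler = CosineAnnealingRestartLR(
    optimizer,
    periods=[10, 10, 10, 10],
    restart_weights=[1, 0.5, 0.5, 0.5],
    eta_min=1e-7)

for it in range(20):
    optimizer.step()
    scheduler.step()
    if it in (8, 9, 10, 11):
        # the lr anneals towards eta_min at the end of a cycle, then restarts
        # at restart_weight * base_lr for the next cycle
        print(it, scheduler.get_last_lr())
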
/basicsr/models/video_restoration_model.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import torch
3 | from collections import OrderedDict
4 | from copy import deepcopy
5 | from os import path as osp
6 | from tqdm import tqdm
7 | from torch.nn.parallel import DataParallel, DistributedDataParallel
8 | from basicsr.utils.dist_util import get_dist_info
9 | from basicsr.models.base_model import BaseModel
10 | from basicsr.utils import get_root_logger, imwrite, tensor2img
11 | from importlib import import_module
12 | import basicsr.loss as loss
13 | import numpy as np
14 | import matplotlib.pyplot as plt
15 |
16 | import json
17 |
18 | def create_video_model(opt):
19 | module = import_module('basicsr.models.archs.' + opt['model'].lower())
20 | model = module.make_model(opt)
21 | return model
22 |
23 | metric_module = importlib.import_module('basicsr.metrics')
24 |
25 | class VideoRestorationModel(BaseModel):
26 | def __init__(self, opt):
27 | super(VideoRestorationModel, self).__init__(opt)
28 | self.net_g = create_video_model(opt)
29 | self.net_g = self.model_to_device(self.net_g)
30 | self.n_sequence = opt['n_sequence']
31 | load_path = self.opt['path'].get('pretrain_network_g', None)
32 | if load_path is not None:
33 | self.load_network(self.net_g, load_path,
34 | self.opt['path'].get('strict_load_g', True), param_key=self.opt['path'].get('param_key', 'params'))
35 | print("load_model", load_path)
36 | if self.is_train:
37 | self.init_training_settings()
38 | self.loss = loss.L1BaseLoss()
39 | self.scaler = torch.cuda.amp.GradScaler()
40 |
41 | def init_training_settings(self):
42 | self.net_g.train()
43 | # set up optimizers and schedulers
44 | self.setup_optimizers()
45 | self.setup_schedulers()
46 |
47 | def model_to_device(self, net):
48 | net = net.to(self.device)
49 | if self.opt['dist']:
50 | net = DistributedDataParallel(
51 | net,
52 | device_ids=[torch.cuda.current_device()],
53 | find_unused_parameters=False)
54 | elif self.opt['num_gpu'] > 1:
55 | net = DataParallel(net)
56 | return net
57 |
58 | def setup_optimizers(self):
59 | train_opt = self.opt['train']
60 | optim_params = []
61 | for k, v in self.net_g.named_parameters():
62 | if v.requires_grad:
63 | optim_params.append(v)
64 | else:
65 | logger = get_root_logger()
66 | logger.warning(f'Params {k} will not be optimized.')
67 | train_opt['optim_g'].pop('type')
68 | self.optimizer_g = torch.optim.AdamW([{'params': optim_params}],
69 | **train_opt['optim_g'])
70 | self.optimizers.append(self.optimizer_g)
71 |
72 | # method to feed the data to the model.
73 | def feed_data(self, data):
74 | lq, gt, _, _ = data
75 | self.lq = lq.to(self.device).half()
76 | self.gt = gt.to(self.device)
77 |
78 | def optimize_parameters(self, current_iter):
79 | self.optimizer_g.zero_grad()
80 | with torch.cuda.amp.autocast():
81 | loss_dict = OrderedDict()
82 | loss_dict['l_pix'] = 0
83 |
84 | frame_num = self.lq.shape[1]
85 | k_cache, v_cache = None, None
86 | for j in range(frame_num):
87 | target_g_images = self.gt[:, j, :, :, :]
88 | current_input = self.lq[:, j,:, :, :].unsqueeze(1)
89 | pre_input = self.lq[:, j if j == 0 else j-1, :, :, :].unsqueeze(1)
90 |
91 | input = torch.concat([pre_input, current_input], dim=1)
92 | (out_g, k_cache, v_cache) = self.net_g(input, k_cache, v_cache)
93 |
94 | l_pix = self.loss(out_g, target_g_images)
95 | loss_dict['l_pix'] += l_pix
96 |
97 | # normalize w.r.t. total frames seen.
98 | loss_dict['l_pix'] /= frame_num
99 |             l_total = loss_dict['l_pix'] + 0 * sum(p.sum() for p in self.net_g.parameters())  # zero-weighted term touches every parameter so all of them appear in the autograd graph
100 | loss_dict['l_pix'] = loss_dict['l_pix']
101 |
102 | self.scaler.scale(l_total).backward()
103 | self.scaler.unscale_(self.optimizer_g)
104 | # do gradient clipping to avoid larger updates.
105 | # torch.nn.utils.clip_grad_norm_(self.net_g.parameters(), 0.01)
106 | self.scaler.step(self.optimizer_g)
107 | self.scaler.update()
108 | self.log_dict = self.reduce_loss_dict(loss_dict)
109 |
110 | def test(self):
111 | self.net_g.eval()
112 | with torch.no_grad():
113 | self.outputs_list = []
114 | self.gt_lists = []
115 | self.lq_lists = []
116 | frame_num = self.lq.shape[1]
117 | k_cache, v_cache = None, None
118 | for j in range(frame_num):
119 | target_g_images = self.gt[:, j, :, :, :]
120 | current_input = self.lq[:, j,:, :, :].unsqueeze(1)
121 | pre_input = self.lq[:, j if j == 0 else j-1, :, :, :].unsqueeze(1)
122 | input = torch.concat([pre_input, current_input], dim=1)
123 | out_g, k_cache, v_cache = self.net_g(input.float(),
124 | k_cache,
125 | v_cache)
126 | self.outputs_list.append(out_g)
127 | self.gt_lists.append(target_g_images)
128 | self.lq_lists.append(self.lq[:, j,:, :, :])
129 | self.net_g.train()
130 |
131 | def non_cached_test(self):
132 | # proxy to the actual scores to save time.
133 | self.net_g.eval()
134 | with torch.no_grad():
135 | k_cache, v_cache = None, None
136 | pred, _, _, _ = self.net_g(self.lq.float(), k_cache, v_cache)
137 | if isinstance(pred, list):
138 | pred = pred[-1]
139 | self.output = pred
140 | self.net_g.train()
141 |
142 | def dist_validation(self, dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image):
143 | logger = get_root_logger()
144 |         # distributed validation simply reuses the non-distributed path
145 | return self.nondist_validation(dataloader, current_iter,
146 | tb_logger, save_img,
147 | rgb2bgr, use_image)
148 |
149 | def nondist_validation(self, dataloader, current_iter, tb_logger,
150 | save_img, rgb2bgr, use_image):
151 | with_metrics = self.opt['val'].get('metrics') is not None
152 | if with_metrics:
153 | self.metric_results = {
154 | metric: 0
155 | for metric in self.opt['val']['metrics'].keys()
156 | }
157 | rank, world_size = get_dist_info()
158 | if rank == 0:
159 | pbar = tqdm(total=len(dataloader), unit='image')
160 | cnt = 0
161 |
162 | for idx, val_data in enumerate(dataloader):
163 | if idx % world_size != rank:
164 | continue
165 |
166 | folder_name, img_name = val_data[len(val_data)-1][0][0].split('.')
167 | self.feed_data(val_data)
168 | self.test()
169 |
170 | for temp_i in range(len(self.outputs_list)):
171 | sr_img = tensor2img(self.outputs_list[temp_i], rgb2bgr=rgb2bgr)
172 | gt_img = tensor2img(self.gt_lists[temp_i], rgb2bgr=rgb2bgr)
173 | lq_img = tensor2img(self.lq_lists[temp_i], rgb2bgr=rgb2bgr)
174 |
175 | if save_img:
176 | # if self.opt['is_train']:
177 | save_img_path = osp.join(self.opt['path']['visualization'],
178 | folder_name,
179 | f'{img_name}_frame{temp_i}_res.png')
180 |
181 | save_gt_img_path = osp.join(self.opt['path']['visualization'],
182 | folder_name,
183 | f'{img_name}_frame{temp_i}_gt.png')
184 |
185 | save_lq_img_path = osp.join(self.opt['path']['visualization'],
186 | folder_name,
187 | f'{img_name}_frame{temp_i}_lq.png')
188 |
189 | imwrite(sr_img, save_img_path)
190 | imwrite(gt_img, save_gt_img_path)
191 | imwrite(lq_img, save_lq_img_path)
192 |
193 | if with_metrics:
194 | # calculate metrics
195 | opt_metric = deepcopy(self.opt['val']['metrics'])
196 | if use_image:
197 | for name, opt_ in opt_metric.items():
198 | metric_type = opt_.pop('type')
199 | self.metric_results[name] += getattr(
200 | metric_module, metric_type)(sr_img, gt_img, **opt_)
201 | else:
202 | for name, opt_ in opt_metric.items():
203 | metric_type = opt_.pop('type')
204 | self.metric_results[name] += getattr(
205 | metric_module, metric_type)(self.outputs_list[temp_i], self.gt_lists[temp_i], **opt_)
206 |
207 | cnt += 1
208 | if rank == 0:
209 | for _ in range(world_size):
210 | pbar.update(1)
211 | pbar.set_description(f'Test {img_name}')
212 |
213 | if rank == 0:
214 | pbar.close()
215 |
216 | current_metric = 0.
217 | if with_metrics:
218 | for metric in self.metric_results.keys():
219 | self.metric_results[metric] /= cnt
220 | current_metric = self.metric_results[metric]
221 |
222 | self._log_validation_metric_values(current_iter,
223 | tb_logger)
224 | return current_metric
225 |
226 |
227 | def _log_validation_metric_values(self, current_iter, tb_logger):
228 | log_str = f'Validation,\t'
229 | for metric, value in self.metric_results.items():
230 | log_str += f'\t # {metric}: {value:.4f}'
231 | logger = get_root_logger()
232 | logger.info(log_str)
233 | if tb_logger:
234 | for metric, value in self.metric_results.items():
235 | tb_logger.add_scalar(f'metrics/{metric}', value, current_iter)
236 |
237 | def get_current_visuals(self):
238 | # pick the current frame.
239 | out_dict = OrderedDict()
240 | out_dict['lq'] = self.lq[:,1,:,:,:].detach().cpu()
241 | out_dict['result'] = self.output.detach().cpu()
242 | if hasattr(self, 'gt'):
243 | out_dict['gt'] = self.gt[:,1,:,:,:].detach().cpu()
244 | return out_dict
245 |
246 | def save(self, epoch, current_iter):
247 | self.save_network(self.net_g, 'net_g', current_iter)
248 | self.save_training_state(epoch, current_iter)
--------------------------------------------------------------------------------
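An illustration (not repo code) of the frame pairing used in optimize_parameters() and test() above: each step feeds the previous and current frame, with frame 0 paired with itself; tensor sizes are dummies:

import torch

lq = torch.rand(1, 5, 3, 16, 16)                       # (batch, frames, channels, H, W)
for j in range(lq.shape[1]):
    current_input = lq[:, j].unsqueeze(1)
    pre_input = lq[:, j if j == 0 else j - 1].unsqueeze(1)
    pair = torch.cat([pre_input, current_input], dim=1)
    print(j, tuple(pair.shape))                        # (1, 2, 3, 16, 16) at every step
    # `pair`, together with the running k/v caches, is what self.net_g consumes
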
/basicsr/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import logging
4 | import math
5 | import random
6 | import time
7 | import torch
8 | import sys
9 | from pathlib import Path
10 | sys.path.append(str(Path(__file__).parents[1]))
11 | from os import path as osp
12 |
13 | from basicsr.data import create_dataloader
14 | from basicsr.data.data_sampler import EnlargedSampler
15 | from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
16 | from basicsr.models import create_model
17 | from basicsr.utils import (MessageLogger, check_resume, get_env_info,
18 | get_root_logger, get_time_str, init_tb_logger,
19 | init_wandb_logger, make_exp_dirs, mkdir_and_rename,
20 | set_random_seed)
21 | from basicsr.utils.dist_util import get_dist_info, init_dist
22 | from basicsr.utils.options import dict2str, parse
23 |
24 | # For super-resolution, uncomment the import below and comment out the VideoImageDataset import further down.
25 | # from basicsr.data.video_super_image_dataset import VideoSuperImageDataset as VideoImageDataset
26 |
27 | # For deblurring/deraining/denoising/desnowing, use the following import (leave the one above commented out).
28 | from basicsr.data.video_image_dataset import VideoImageDataset
29 | import torch.distributed as dist
30 |
31 | import os
32 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
33 | def parse_options(is_train=True):
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument(
36 | '-opt', type=str, required=True, help='Path to option YAML file.')
37 | parser.add_argument(
38 | '--launcher',
39 | choices=['none', 'pytorch', 'slurm'],
40 | default='none',
41 | help='job launcher')
42 | parser.add_argument('--local_rank', type=int, default=0)
43 | args, unknown = parser.parse_known_args()
44 | opt = parse(args.opt, is_train=is_train)
45 |
46 | # distributed settings
47 | if args.launcher == 'none':
48 | opt['dist'] = False
49 | print('Disable distributed.', flush=True)
50 | else:
51 | opt['dist'] = True
52 | # increase timeout to 1.5 hours
53 | opt['dist_params']['timeout'] = datetime.timedelta(seconds=5400)
54 | if args.launcher == 'slurm' and 'dist_params' in opt:
55 | init_dist(args.launcher, **opt['dist_params'])
56 | else:
57 | init_dist(args.launcher)
58 | print('init dist .. ', args.launcher)
59 |
60 | opt['rank'], opt['world_size'] = get_dist_info()
61 |
62 | # random seed
63 | seed = opt.get('manual_seed')
64 | if seed is None:
65 | seed = random.randint(1, 10000)
66 | opt['manual_seed'] = seed
67 | set_random_seed(seed + opt['rank'])
68 | torch.manual_seed(seed+opt['rank'])
69 |
70 | return opt
71 |
72 |
73 | def init_loggers(opt):
74 | log_file = osp.join(opt['path']['log'],
75 | f"train_{opt['name']}_{get_time_str()}.log")
76 | logger = get_root_logger(
77 | logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
78 | logger.info(get_env_info())
79 | logger.info(dict2str(opt))
80 |
81 | # initialize wandb logger before tensorboard logger to allow proper sync:
82 | if (opt['logger'].get('wandb')
83 | is not None) and (opt['logger']['wandb'].get('project')
84 | is not None) and ('debug' not in opt['name']):
85 | assert opt['logger'].get('use_tb_logger') is True, (
86 | 'should turn on tensorboard when using wandb')
87 | init_wandb_logger(opt)
88 | tb_logger = None
89 | if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']:
90 | tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name']))
91 | return logger, tb_logger
92 |
93 |
94 | def create_train_val_dataloader(opt, logger):
95 | # create train and val dataloaders
96 | train_loader, val_loader = None, None
97 | for phase, dataset_opt in opt['datasets'].items():
98 | if phase == 'train':
99 | dataset_enlarge_ratio = 1
100 | train_set = VideoImageDataset(opt, phase)
101 | train_sampler = EnlargedSampler(train_set, opt['world_size'],
102 | opt['rank'], dataset_enlarge_ratio)
103 | train_loader = create_dataloader(
104 | train_set,
105 | dataset_opt,
106 | num_gpu=opt['num_gpu'],
107 | dist=opt['dist'],
108 | sampler=train_sampler,
109 | seed=opt['manual_seed'])
110 |
111 | num_iter_per_epoch = math.ceil(
112 | len(train_set) * dataset_enlarge_ratio /
113 | (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
114 | total_iters = int(opt['train']['total_iter'])
115 | total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
116 | logger.info(
117 | 'Training statistics:'
118 | f'\n\tNumber of train images: {len(train_set)}'
119 | f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
120 | f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
121 | f'\n\tWorld size (gpu number): {opt["world_size"]}'
122 | f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
123 | f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
124 | elif phase == 'val':
125 | val_set = VideoImageDataset(opt, phase)
126 | val_loader = create_dataloader(
127 | val_set,
128 | dataset_opt,
129 | num_gpu=opt['num_gpu'],
130 | dist=opt['dist'],
131 | sampler=None,
132 | seed=opt['manual_seed'])
133 | logger.info(
134 | f'Number of val images/folders in dataset: '
135 | f'{len(val_set)}')
136 | else:
137 | raise ValueError(f'Dataset phase {phase} is not recognized')
138 |
139 | return train_loader, train_sampler, val_loader, total_epochs, total_iters
140 |
141 | def main():
142 |     # parse options, set distributed setting, set random seed
143 | opt = parse_options(is_train=True)
144 |
145 | torch.backends.cudnn.benchmark = True
146 | # torch.backends.cudnn.deterministic = True
147 | state_folder_path = 'experiments/{}/training_states/'.format(opt['name'])
148 |     # look for existing training states to auto-resume from
149 |     try:
150 |         states = os.listdir(state_folder_path)
151 |     except OSError:
152 |         states = []
153 |
154 | resume_state = None
155 | if len(states) > 0:
156 | max_state_file = '{}.state'.format(max([int(x[0:-6]) for x in states]))
157 | resume_state = os.path.join(state_folder_path, max_state_file)
158 | opt['path']['resume_state'] = resume_state
159 |
160 | # load resume states if necessary
161 | if opt['path'].get('resume_state'):
162 | device_id = torch.cuda.current_device()
163 | resume_state = torch.load(
164 | opt['path']['resume_state'],
165 | map_location=lambda storage, loc: storage.cuda(device_id))
166 | else:
167 | resume_state = None
168 |
169 | # mkdir for experiments and logger
170 | if resume_state is None:
171 | make_exp_dirs(opt)
172 | if opt['logger'].get('use_tb_logger') and 'debug' not in opt[
173 | 'name'] and opt['rank'] == 0:
174 | mkdir_and_rename(osp.join('tb_logger', opt['name']))
175 |
176 | # initialize loggers
177 | logger, tb_logger = init_loggers(opt)
178 |
179 | # create train and validation dataloaders
180 | result = create_train_val_dataloader(opt, logger)
181 | train_loader, train_sampler, val_loader, total_epochs, total_iters = result
182 | print("Len --- train_loader", len(train_loader))
183 | if resume_state: # resume training
184 | print("resuming is True")
185 | check_resume(opt, resume_state['iter'])
186 | model = create_model(opt)
187 | model.resume_training(resume_state) # handle optimizers and scheduler
188 | logger.info(f"Resuming training from epoch: {resume_state['epoch']}, "
189 | f"iter: {resume_state['iter']}.")
190 | start_epoch = resume_state['epoch']
191 | current_iter = resume_state['iter']
192 | del resume_state
193 | torch.cuda.empty_cache()
194 | else:
195 | model = create_model(opt)
196 | start_epoch = 0
197 | current_iter = 0
198 |
199 | # create message logger (formatted outputs)
200 | msg_logger = MessageLogger(opt, current_iter, tb_logger)
201 |
202 | # dataloader prefetcher
203 | prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
204 | if prefetch_mode is None or prefetch_mode == 'cpu':
205 | prefetcher = CPUPrefetcher(train_loader)
206 | elif prefetch_mode == 'cuda':
207 | prefetcher = CUDAPrefetcher(train_loader, opt)
208 | logger.info(f'Use {prefetch_mode} prefetch dataloader')
209 | if opt['datasets']['train'].get('pin_memory') is not True:
210 | raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
211 | else:
212 | raise ValueError(f'Wrong prefetch_mode {prefetch_mode}.'
213 | "Supported ones are: None, 'cuda', 'cpu'.")
214 |
215 | # training
216 | logger.info(
217 | f'Start training from epoch: {start_epoch}, iter: {current_iter}')
218 | data_time, iter_time = time.time(), time.time()
219 | start_time = time.time()
220 |
221 | epoch = start_epoch
222 | while current_iter <= total_iters:
223 | train_sampler.set_epoch(epoch)
224 | prefetcher.reset()
225 | train_data = prefetcher.next()
226 |
227 | while train_data is not None:
228 | data_time = time.time() - data_time
229 |
230 | current_iter += 1
231 | if current_iter > total_iters:
232 | break
233 | # update learning rate
234 | model.update_learning_rate(
235 | current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
236 | # training
237 | model.feed_data(train_data)
238 | model.optimize_parameters(current_iter)
239 | iter_time = time.time() - iter_time
240 | # log
241 |             if opt['rank'] == 0 and current_iter % opt['logger']['print_freq'] == 0:
242 | log_vars = {'epoch': epoch, 'iter': current_iter}
243 | log_vars.update({'lrs': model.get_current_learning_rate()})
244 | log_vars.update({'time': iter_time, 'data_time': data_time})
245 | log_vars.update(model.get_current_log())
246 | print("Loss at iteration", current_iter, model.get_current_log()['l_pix'])
247 | msg_logger(log_vars)
248 |
249 | # save models and training states
250 |             if opt['rank'] == 0 and current_iter % opt['logger']['save_checkpoint_freq'] == 0:
251 | logger.info('Saving models and training states.')
252 | print("saving")
253 | model.save(epoch, current_iter)
254 |
255 | # validation
256 | if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0):
257 | rgb2bgr = opt['val'].get('rgb2bgr', True)
258 | use_image = opt['val'].get('use_image', True)
259 | model.validation(val_loader, current_iter, tb_logger,
260 | opt['val']['save_img'],
261 | rgb2bgr, use_image)
262 | log_vars = {'epoch': epoch, 'iter': current_iter, 'total_iter': total_iters}
263 | log_vars.update({'lrs': model.get_current_learning_rate()})
264 | log_vars.update(model.get_current_log())
265 | msg_logger(log_vars)
266 |
267 | data_time = time.time()
268 | iter_time = time.time()
269 | train_data = prefetcher.next()
270 |
271 |         if opt['rank'] == 0:
272 | print(f"Loss at the end of Epoch {epoch} is {model.get_current_log()['l_pix']}.")
273 | # end of iter
274 | epoch += 1
275 |
276 | # end of epoch
277 | consumed_time = str(
278 | datetime.timedelta(seconds=int(time.time() - start_time)))
279 | logger.info(f'End of training. Time consumed: {consumed_time}')
280 | logger.info('Save the latest model.')
281 | model.save(epoch=-1, current_iter=-1) # -1 stands for the latest
282 | if opt.get('val') is not None:
283 | rgb2bgr = opt['val'].get('rgb2bgr', True)
284 | use_image = opt['val'].get('use_image', True)
285 | model.validation(val_loader, current_iter, tb_logger,
286 | opt['val']['save_img'],
287 | rgb2bgr, use_image)
288 | if tb_logger:
289 | tb_logger.close()
290 |
291 |
292 | if __name__ == '__main__':
293 | main()
294 |
--------------------------------------------------------------------------------
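A tiny sketch (not repo code) of how main() above selects the latest training state to auto-resume from; the file names are made-up examples:

# filenames as written by BaseModel.save_training_state(): '<iteration>.state'
states = ['5000.state', '7500.state', '10000.state']
max_state_file = '{}.state'.format(max(int(x[0:-6]) for x in states))
print(max_state_file)   # 10000.state
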
/basicsr/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
3 | # Copyright 2018-2020 BasicSR Authors
4 | # ------------------------------------------------------------------------
5 | from .file_client import FileClient
6 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding, pltimwrite
7 | from .logger import (MessageLogger, get_env_info, get_root_logger,
8 | init_tb_logger, init_wandb_logger)
9 | from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename,
10 | scandir, scandir_SIDD, set_random_seed, sizeof_fmt)
11 | from .create_lmdb import (create_lmdb_for_reds, create_lmdb_for_gopro, create_lmdb_for_rain13k)
12 |
13 | __all__ = [
14 | # file_client.py
15 | 'FileClient',
16 | # img_util.py
17 | 'img2tensor',
18 | 'tensor2img',
19 | 'imfrombytes',
20 | 'imwrite',
21 | 'pltimwrite',
22 | 'crop_border',
23 | # logger.py
24 | 'MessageLogger',
25 | 'init_tb_logger',
26 | 'init_wandb_logger',
27 | 'get_root_logger',
28 | 'get_env_info',
29 | # misc.py
30 | 'set_random_seed',
31 | 'get_time_str',
32 | 'mkdir_and_rename',
33 | 'make_exp_dirs',
34 | 'scandir',
35 | 'check_resume',
36 | 'sizeof_fmt',
37 | 'padding',
38 | 'create_lmdb_for_reds',
39 | 'create_lmdb_for_gopro',
40 | 'create_lmdb_for_rain13k',
41 | ]
42 |
--------------------------------------------------------------------------------
/basicsr/utils/create_lmdb.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
3 | # Copyright 2018-2020 BasicSR Authors
4 | # ------------------------------------------------------------------------
5 | import argparse
6 | import cv2, os, scipy.io as scio  # needed by create_lmdb_for_SIDD() below
7 | from os import path as osp
8 | from tqdm import tqdm
9 | from basicsr.utils import scandir
10 | from basicsr.utils.lmdb_util import make_lmdb_from_imgs
11 | def prepare_keys(folder_path, suffix='png'):
12 |     """Prepare image path list and keys for a dataset folder.
13 |
14 | Args:
15 | folder_path (str): Folder path.
16 |
17 | Returns:
18 | list[str]: Image path list.
19 | list[str]: Key list.
20 | """
21 | print('Reading image path list ...')
22 | img_path_list = sorted(
23 | list(scandir(folder_path, suffix=suffix, recursive=False)))
24 | keys = [img_path.split('.{}'.format(suffix))[0] for img_path in sorted(img_path_list)]
25 |
26 | return img_path_list, keys
27 |
28 | def create_lmdb_for_reds():
29 | folder_path = './datasets/REDS/val/sharp_300'
30 | lmdb_path = './datasets/REDS/val/sharp_300.lmdb'
31 | img_path_list, keys = prepare_keys(folder_path, 'png')
32 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
33 | #
34 | folder_path = './datasets/REDS/val/blur_300'
35 | lmdb_path = './datasets/REDS/val/blur_300.lmdb'
36 | img_path_list, keys = prepare_keys(folder_path, 'jpg')
37 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
38 |
39 | folder_path = './datasets/REDS/train/train_sharp'
40 | lmdb_path = './datasets/REDS/train/train_sharp.lmdb'
41 | img_path_list, keys = prepare_keys(folder_path, 'png')
42 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
43 |
44 | folder_path = './datasets/REDS/train/train_blur_jpeg'
45 | lmdb_path = './datasets/REDS/train/train_blur_jpeg.lmdb'
46 | img_path_list, keys = prepare_keys(folder_path, 'jpg')
47 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
48 |
49 |
50 | def create_lmdb_for_gopro():
51 | folder_path = './datasets/GoPro/train/blur_crops'
52 | lmdb_path = './datasets/GoPro/train/blur_crops.lmdb'
53 |
54 | img_path_list, keys = prepare_keys(folder_path, 'png')
55 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
56 |
57 | folder_path = './datasets/GoPro/train/sharp_crops'
58 | lmdb_path = './datasets/GoPro/train/sharp_crops.lmdb'
59 |
60 | img_path_list, keys = prepare_keys(folder_path, 'png')
61 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
62 |
63 | folder_path = './datasets/GoPro/test/target'
64 | lmdb_path = './datasets/GoPro/test/target.lmdb'
65 |
66 | img_path_list, keys = prepare_keys(folder_path, 'png')
67 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
68 |
69 | folder_path = './datasets/GoPro/test/input'
70 | lmdb_path = './datasets/GoPro/test/input.lmdb'
71 |
72 | img_path_list, keys = prepare_keys(folder_path, 'png')
73 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
74 |
75 | def create_lmdb_for_rain13k():
76 | folder_path = './datasets/Rain13k/train/input'
77 | lmdb_path = './datasets/Rain13k/train/input.lmdb'
78 |
79 | img_path_list, keys = prepare_keys(folder_path, 'jpg')
80 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
81 |
82 | folder_path = './datasets/Rain13k/train/target'
83 | lmdb_path = './datasets/Rain13k/train/target.lmdb'
84 |
85 | img_path_list, keys = prepare_keys(folder_path, 'jpg')
86 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
87 |
88 | def create_lmdb_for_SIDD():
89 | folder_path = './datasets/SIDD/train/input_crops'
90 | lmdb_path = './datasets/SIDD/train/input_crops.lmdb'
91 |
92 | img_path_list, keys = prepare_keys(folder_path, 'PNG')
93 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
94 |
95 | folder_path = './datasets/SIDD/train/gt_crops'
96 | lmdb_path = './datasets/SIDD/train/gt_crops.lmdb'
97 |
98 | img_path_list, keys = prepare_keys(folder_path, 'PNG')
99 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
100 |
101 | #for val
102 | folder_path = './datasets/SIDD/val/input_crops'
103 | lmdb_path = './datasets/SIDD/val/input_crops.lmdb'
104 | mat_path = './datasets/SIDD/ValidationNoisyBlocksSrgb.mat'
105 | if not osp.exists(folder_path):
106 | os.makedirs(folder_path)
107 | assert osp.exists(mat_path)
108 | data = scio.loadmat(mat_path)['ValidationNoisyBlocksSrgb']
109 | N, B, H ,W, C = data.shape
110 | data = data.reshape(N*B, H, W, C)
111 | for i in tqdm(range(N*B)):
112 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR))
113 | img_path_list, keys = prepare_keys(folder_path, 'png')
114 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
115 |
116 | folder_path = './datasets/SIDD/val/gt_crops'
117 | lmdb_path = './datasets/SIDD/val/gt_crops.lmdb'
118 | mat_path = './datasets/SIDD/ValidationGtBlocksSrgb.mat'
119 | if not osp.exists(folder_path):
120 | os.makedirs(folder_path)
121 | assert osp.exists(mat_path)
122 | data = scio.loadmat(mat_path)['ValidationGtBlocksSrgb']
123 | N, B, H ,W, C = data.shape
124 | data = data.reshape(N*B, H, W, C)
125 | for i in tqdm(range(N*B)):
126 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR))
127 | img_path_list, keys = prepare_keys(folder_path, 'png')
128 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
129 |
--------------------------------------------------------------------------------
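A tiny illustration (not repo code) of the key format produced by prepare_keys(): keys are the image paths with the '.<suffix>' extension removed; the file names are made up:

suffix = 'png'
img_path_list = ['00000001.png', '00000002.png']
keys = [p.split('.{}'.format(suffix))[0] for p in sorted(img_path_list)]
print(keys)   # ['00000001', '00000002']
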
/basicsr/utils/dist_util.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
3 | # Copyright 2018-2020 BasicSR Authors
4 | # ------------------------------------------------------------------------
5 |
6 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
7 | import functools
8 | import os
9 | import subprocess
10 | import torch
11 | import torch.distributed as dist
12 | import torch.multiprocessing as mp
13 | import datetime
14 |
15 | def init_dist(launcher, backend='nccl', **kwargs):
16 | if mp.get_start_method(allow_none=True) is None:
17 | mp.set_start_method('spawn')
18 | if launcher == 'pytorch':
19 | _init_dist_pytorch(backend, **kwargs)
20 | elif launcher == 'slurm':
21 | _init_dist_slurm(backend, **kwargs)
22 | else:
23 | raise ValueError(f'Invalid launcher type: {launcher}')
24 |
25 |
26 | def _init_dist_pytorch(backend, **kwargs):
27 | rank = int(os.environ['RANK'])
28 | num_gpus = torch.cuda.device_count()
29 | torch.cuda.set_device(rank % num_gpus)
30 | dist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=5400), **kwargs)
31 |
32 |
33 | def _init_dist_slurm(backend, port=None):
34 | """Initialize slurm distributed training environment.
35 |
36 | If argument ``port`` is not specified, then the master port will be system
37 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
38 | environment variable, then a default port ``29500`` will be used.
39 |
40 | Args:
41 | backend (str): Backend of torch.distributed.
42 | port (int, optional): Master port. Defaults to None.
43 | """
44 | proc_id = int(os.environ['SLURM_PROCID'])
45 | ntasks = int(os.environ['SLURM_NTASKS'])
46 | node_list = os.environ['SLURM_NODELIST']
47 | num_gpus = torch.cuda.device_count()
48 | torch.cuda.set_device(proc_id % num_gpus)
49 | addr = subprocess.getoutput(
50 | f'scontrol show hostname {node_list} | head -n1')
51 | # specify master port
52 | if port is not None:
53 | os.environ['MASTER_PORT'] = str(port)
54 | elif 'MASTER_PORT' in os.environ:
55 | pass # use MASTER_PORT in the environment variable
56 | else:
57 | # 29500 is torch.distributed default port
58 | os.environ['MASTER_PORT'] = '29500'
59 | os.environ['MASTER_ADDR'] = addr
60 | os.environ['WORLD_SIZE'] = str(ntasks)
61 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
62 | os.environ['RANK'] = str(proc_id)
63 | dist.init_process_group(backend=backend)
64 |
65 |
66 | def get_dist_info():
67 | if dist.is_available():
68 | initialized = dist.is_initialized()
69 | else:
70 | initialized = False
71 | if initialized:
72 | rank = dist.get_rank()
73 | world_size = dist.get_world_size()
74 | else:
75 | rank = 0
76 | world_size = 1
77 | return rank, world_size
78 |
79 |
80 | def master_only(func):
81 |
82 | @functools.wraps(func)
83 | def wrapper(*args, **kwargs):
84 | rank, _ = get_dist_info()
85 | if rank == 0:
86 | return func(*args, **kwargs)
87 |
88 | return wrapper
89 |
--------------------------------------------------------------------------------
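A usage sketch (not repo code) of the master_only decorator: the wrapped function becomes a no-op on non-zero ranks and always runs in single-process mode; the function and path below are hypothetical:

from basicsr.utils.dist_util import master_only

@master_only
def save_checkpoint(path):
    # hypothetical function: only rank 0 ever reaches this body
    print(f'saving to {path}')

save_checkpoint('net_g_latest.pth')   # runs here because a single process has rank 0
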
/basicsr/utils/download_util.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2021 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5 | # Copyright 2018-2020 BasicSR Authors
6 | # ------------------------------------------------------------------------
7 | import math
8 | import requests
9 | from tqdm import tqdm
10 |
11 | from .misc import sizeof_fmt
12 |
13 |
14 | def download_file_from_google_drive(file_id, save_path):
15 | """Download files from google drive.
16 |
17 | Ref:
18 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
19 |
20 | Args:
21 | file_id (str): File id.
22 | save_path (str): Save path.
23 | """
24 |
25 | session = requests.Session()
26 | URL = 'https://docs.google.com/uc?export=download'
27 | params = {'id': file_id}
28 |
29 | response = session.get(URL, params=params, stream=True)
30 | token = get_confirm_token(response)
31 | if token:
32 | params['confirm'] = token
33 | response = session.get(URL, params=params, stream=True)
34 |
35 | # get file size
36 | response_file_size = session.get(
37 | URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
38 | if 'Content-Range' in response_file_size.headers:
39 | file_size = int(
40 | response_file_size.headers['Content-Range'].split('/')[1])
41 | else:
42 | file_size = None
43 |
44 | save_response_content(response, save_path, file_size)
45 |
46 |
47 | def get_confirm_token(response):
48 | for key, value in response.cookies.items():
49 | if key.startswith('download_warning'):
50 | return value
51 | return None
52 |
53 |
54 | def save_response_content(response,
55 | destination,
56 | file_size=None,
57 | chunk_size=32768):
58 | if file_size is not None:
59 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
60 |
61 | readable_file_size = sizeof_fmt(file_size)
62 | else:
63 | pbar = None
64 |
65 | with open(destination, 'wb') as f:
66 | downloaded_size = 0
67 | for chunk in response.iter_content(chunk_size):
68 | downloaded_size += chunk_size
69 | if pbar is not None:
70 | pbar.update(1)
71 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
72 | f'/ {readable_file_size}')
73 | if chunk: # filter out keep-alive new chunks
74 | f.write(chunk)
75 | if pbar is not None:
76 | pbar.close()
77 |
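78 | 
79 | if __name__ == '__main__':
80 |     # Illustrative usage sketch. The file id below is a placeholder, not a
81 |     # real file; replace it with the id of a publicly shared Google Drive
82 |     # file before running.
83 |     download_file_from_google_drive(
84 |         file_id='YOUR_GOOGLE_DRIVE_FILE_ID',
85 |         save_path='downloaded_file.bin')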
--------------------------------------------------------------------------------
/basicsr/utils/face_util.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2021 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5 | # Copyright 2018-2020 BasicSR Authors
6 | # ------------------------------------------------------------------------
7 | import cv2
8 | import numpy as np
9 | import os
10 | import torch
11 | from skimage import transform as trans
12 |
13 | from basicsr.utils import imwrite
14 |
15 | try:
16 | import dlib
17 | except ImportError:
18 |     print('Please install dlib before testing face restoration. '
19 |           'Reference: https://github.com/davisking/dlib')
20 |
21 |
22 | class FaceRestorationHelper(object):
23 | """Helper for the face restoration pipeline."""
24 |
25 | def __init__(self, upscale_factor, face_size=512):
26 | self.upscale_factor = upscale_factor
27 | self.face_size = (face_size, face_size)
28 |
29 | # standard 5 landmarks for FFHQ faces with 1024 x 1024
30 | self.face_template = np.array([[686.77227723, 488.62376238],
31 | [586.77227723, 493.59405941],
32 | [337.91089109, 488.38613861],
33 | [437.95049505, 493.51485149],
34 | [513.58415842, 678.5049505]])
35 | self.face_template = self.face_template / (1024 // face_size)
36 |         # for estimating the 2D similarity transformation
37 | self.similarity_trans = trans.SimilarityTransform()
38 |
39 | self.all_landmarks_5 = []
40 | self.all_landmarks_68 = []
41 | self.affine_matrices = []
42 | self.inverse_affine_matrices = []
43 | self.cropped_faces = []
44 | self.restored_faces = []
45 | self.save_png = True
46 |
47 | def init_dlib(self, detection_path, landmark5_path, landmark68_path):
48 | """Initialize the dlib detectors and predictors."""
49 | self.face_detector = dlib.cnn_face_detection_model_v1(detection_path)
50 | self.shape_predictor_5 = dlib.shape_predictor(landmark5_path)
51 | self.shape_predictor_68 = dlib.shape_predictor(landmark68_path)
52 |
53 | def free_dlib_gpu_memory(self):
54 | del self.face_detector
55 | del self.shape_predictor_5
56 | del self.shape_predictor_68
57 |
58 | def read_input_image(self, img_path):
59 | # self.input_img is Numpy array, (h, w, c) with RGB order
60 | self.input_img = dlib.load_rgb_image(img_path)
61 |
62 | def detect_faces(self,
63 | img_path,
64 | upsample_num_times=1,
65 | only_keep_largest=False):
66 | """
67 | Args:
68 | img_path (str): Image path.
69 | upsample_num_times (int): Upsamples the image before running the
70 | face detector
71 |
72 | Returns:
73 | int: Number of detected faces.
74 | """
75 | self.read_input_image(img_path)
76 | det_faces = self.face_detector(self.input_img, upsample_num_times)
77 | if len(det_faces) == 0:
78 | print('No face detected. Try to increase upsample_num_times.')
79 | else:
80 | if only_keep_largest:
81 |                 print('Detected several faces; only keeping the largest.')
82 | face_areas = []
83 | for i in range(len(det_faces)):
84 | face_area = (det_faces[i].rect.right() -
85 | det_faces[i].rect.left()) * (
86 | det_faces[i].rect.bottom() -
87 | det_faces[i].rect.top())
88 | face_areas.append(face_area)
89 | largest_idx = face_areas.index(max(face_areas))
90 | self.det_faces = [det_faces[largest_idx]]
91 | else:
92 | self.det_faces = det_faces
93 | return len(self.det_faces)
94 |
95 | def get_face_landmarks_5(self):
96 | for face in self.det_faces:
97 | shape = self.shape_predictor_5(self.input_img, face.rect)
98 | landmark = np.array([[part.x, part.y] for part in shape.parts()])
99 | self.all_landmarks_5.append(landmark)
100 | return len(self.all_landmarks_5)
101 |
102 | def get_face_landmarks_68(self):
103 |         """Get 68 landmarks for cropped images.
104 |
105 | Should only have one face at most in the cropped image.
106 | """
107 | num_detected_face = 0
108 | for idx, face in enumerate(self.cropped_faces):
109 | # face detection
110 | det_face = self.face_detector(face, 1) # TODO: can we remove it?
111 | if len(det_face) == 0:
112 | print(f'Cannot find faces in cropped image with index {idx}.')
113 | self.all_landmarks_68.append(None)
114 | else:
115 | if len(det_face) > 1:
116 |                     print('Detected several faces in the cropped image. Using '
117 |                           'the largest one. Note that it will also cause overlap '
118 |                           'during paste_faces_to_input_image.')
119 | face_areas = []
120 | for i in range(len(det_face)):
121 | face_area = (det_face[i].rect.right() -
122 | det_face[i].rect.left()) * (
123 | det_face[i].rect.bottom() -
124 | det_face[i].rect.top())
125 | face_areas.append(face_area)
126 | largest_idx = face_areas.index(max(face_areas))
127 | face_rect = det_face[largest_idx].rect
128 | else:
129 | face_rect = det_face[0].rect
130 | shape = self.shape_predictor_68(face, face_rect)
131 | landmark = np.array([[part.x, part.y]
132 | for part in shape.parts()])
133 | self.all_landmarks_68.append(landmark)
134 | num_detected_face += 1
135 |
136 | return num_detected_face
137 |
138 | def warp_crop_faces(self,
139 | save_cropped_path=None,
140 | save_inverse_affine_path=None):
141 | """Get affine matrix, warp and cropped faces.
142 |
143 | Also get inverse affine matrix for post-processing.
144 | """
145 | for idx, landmark in enumerate(self.all_landmarks_5):
146 | # use 5 landmarks to get affine matrix
147 | self.similarity_trans.estimate(landmark, self.face_template)
148 | affine_matrix = self.similarity_trans.params[0:2, :]
149 | self.affine_matrices.append(affine_matrix)
150 | # warp and crop faces
151 | cropped_face = cv2.warpAffine(self.input_img, affine_matrix,
152 | self.face_size)
153 | self.cropped_faces.append(cropped_face)
154 | # save the cropped face
155 | if save_cropped_path is not None:
156 | path, ext = os.path.splitext(save_cropped_path)
157 | if self.save_png:
158 | save_path = f'{path}_{idx:02d}.png'
159 | else:
160 | save_path = f'{path}_{idx:02d}{ext}'
161 |
162 | imwrite(
163 | cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR), save_path)
164 |
165 | # get inverse affine matrix
166 | self.similarity_trans.estimate(self.face_template,
167 | landmark * self.upscale_factor)
168 | inverse_affine = self.similarity_trans.params[0:2, :]
169 | self.inverse_affine_matrices.append(inverse_affine)
170 | # save inverse affine matrices
171 | if save_inverse_affine_path is not None:
172 | path, _ = os.path.splitext(save_inverse_affine_path)
173 | save_path = f'{path}_{idx:02d}.pth'
174 | torch.save(inverse_affine, save_path)
175 |
176 | def add_restored_face(self, face):
177 | self.restored_faces.append(face)
178 |
179 | def paste_faces_to_input_image(self, save_path):
180 | # operate in the BGR order
181 | input_img = cv2.cvtColor(self.input_img, cv2.COLOR_RGB2BGR)
182 | h, w, _ = input_img.shape
183 | h_up, w_up = h * self.upscale_factor, w * self.upscale_factor
184 | # simply resize the background
185 | upsample_img = cv2.resize(input_img, (w_up, h_up))
186 | assert len(self.restored_faces) == len(self.inverse_affine_matrices), (
187 |             'lengths of restored_faces and inverse_affine_matrices differ.')
188 | for restored_face, inverse_affine in zip(self.restored_faces,
189 | self.inverse_affine_matrices):
190 | inv_restored = cv2.warpAffine(restored_face, inverse_affine,
191 | (w_up, h_up))
192 | mask = np.ones((*self.face_size, 3), dtype=np.float32)
193 | inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
194 | # remove the black borders
195 | inv_mask_erosion = cv2.erode(
196 | inv_mask,
197 | np.ones((2 * self.upscale_factor, 2 * self.upscale_factor),
198 | np.uint8))
199 | inv_restored_remove_border = inv_mask_erosion * inv_restored
200 | total_face_area = np.sum(inv_mask_erosion) // 3
201 | # compute the fusion edge based on the area of face
202 | w_edge = int(total_face_area**0.5) // 20
203 | erosion_radius = w_edge * 2
204 | inv_mask_center = cv2.erode(
205 | inv_mask_erosion,
206 | np.ones((erosion_radius, erosion_radius), np.uint8))
207 | blur_size = w_edge * 2
208 | inv_soft_mask = cv2.GaussianBlur(inv_mask_center,
209 | (blur_size + 1, blur_size + 1), 0)
210 | upsample_img = inv_soft_mask * inv_restored_remove_border + (
211 | 1 - inv_soft_mask) * upsample_img
212 | if self.save_png:
213 | save_path = save_path.replace('.jpg',
214 | '.png').replace('.jpeg', '.png')
215 | imwrite(upsample_img.astype(np.uint8), save_path)
216 |
217 | def clean_all(self):
218 | self.all_landmarks_5 = []
219 | self.all_landmarks_68 = []
220 | self.restored_faces = []
221 | self.affine_matrices = []
222 | self.cropped_faces = []
223 | self.inverse_affine_matrices = []
224 |
--------------------------------------------------------------------------------
/basicsr/utils/file_client.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2021 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5 | # Copyright 2018-2020 BasicSR Authors
6 | # ------------------------------------------------------------------------
7 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
8 | from abc import ABCMeta, abstractmethod
9 |
10 |
11 | class BaseStorageBackend(metaclass=ABCMeta):
12 | """Abstract class of storage backends.
13 |
14 | All backends need to implement two apis: ``get()`` and ``get_text()``.
15 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
16 | as texts.
17 | """
18 |
19 | @abstractmethod
20 | def get(self, filepath):
21 | pass
22 |
23 | @abstractmethod
24 | def get_text(self, filepath):
25 | pass
26 |
27 |
28 | class MemcachedBackend(BaseStorageBackend):
29 | """Memcached storage backend.
30 |
31 | Attributes:
32 | server_list_cfg (str): Config file for memcached server list.
33 | client_cfg (str): Config file for memcached client.
34 | sys_path (str | None): Additional path to be appended to `sys.path`.
35 | Default: None.
36 | """
37 |
38 | def __init__(self, server_list_cfg, client_cfg, sys_path=None):
39 | if sys_path is not None:
40 | import sys
41 | sys.path.append(sys_path)
42 | try:
43 | import mc
44 | except ImportError:
45 | raise ImportError(
46 | 'Please install memcached to enable MemcachedBackend.')
47 |
48 | self.server_list_cfg = server_list_cfg
49 | self.client_cfg = client_cfg
50 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg,
51 | self.client_cfg)
52 |         # mc.pyvector serves as a pointer to a memory cache
53 | self._mc_buffer = mc.pyvector()
54 |
55 | def get(self, filepath):
56 | filepath = str(filepath)
57 | import mc
58 | self._client.Get(filepath, self._mc_buffer)
59 | value_buf = mc.ConvertBuffer(self._mc_buffer)
60 | return value_buf
61 |
62 | def get_text(self, filepath):
63 | raise NotImplementedError
64 |
65 |
66 | class HardDiskBackend(BaseStorageBackend):
67 | """Raw hard disks storage backend."""
68 |
69 | def get(self, filepath):
70 | filepath = str(filepath)
71 | with open(filepath, 'rb') as f:
72 | value_buf = f.read()
73 | return value_buf
74 |
75 | def get_text(self, filepath):
76 | filepath = str(filepath)
77 | with open(filepath, 'r') as f:
78 | value_buf = f.read()
79 | return value_buf
80 |
81 |
82 | class LmdbBackend(BaseStorageBackend):
83 | """Lmdb storage backend.
84 |
85 | Args:
86 | db_paths (str | list[str]): Lmdb database paths.
87 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
88 | readonly (bool, optional): Lmdb environment parameter. If True,
89 | disallow any write operations. Default: True.
90 | lock (bool, optional): Lmdb environment parameter. If False, when
91 | concurrent access occurs, do not lock the database. Default: False.
92 | readahead (bool, optional): Lmdb environment parameter. If False,
93 | disable the OS filesystem readahead mechanism, which may improve
94 | random read performance when a database is larger than RAM.
95 | Default: False.
96 |
97 | Attributes:
98 | db_paths (list): Lmdb database path.
99 | _client (list): A list of several lmdb envs.
100 | """
101 |
102 | def __init__(self,
103 | db_paths,
104 | client_keys='default',
105 | readonly=True,
106 | lock=False,
107 | readahead=False,
108 | **kwargs):
109 | try:
110 | import lmdb
111 | except ImportError:
112 | raise ImportError('Please install lmdb to enable LmdbBackend.')
113 |
114 | if isinstance(client_keys, str):
115 | client_keys = [client_keys]
116 |
117 | if isinstance(db_paths, list):
118 | self.db_paths = [str(v) for v in db_paths]
119 | elif isinstance(db_paths, str):
120 | self.db_paths = [str(db_paths)]
121 | assert len(client_keys) == len(self.db_paths), (
122 | 'client_keys and db_paths should have the same length, '
123 | f'but received {len(client_keys)} and {len(self.db_paths)}.')
124 |
125 | self._client = {}
126 |
127 | for client, path in zip(client_keys, self.db_paths):
128 | self._client[client] = lmdb.open(
129 | path,
130 | readonly=readonly,
131 | lock=lock,
132 | readahead=readahead,
133 | map_size=8*1024*10485760,
134 | # max_readers=1,
135 | **kwargs)
136 |
137 | def get(self, filepath, client_key):
138 | """Get values according to the filepath from one lmdb named client_key.
139 |
140 | Args:
141 | filepath (str | obj:`Path`): Here, filepath is the lmdb key.
142 |             client_key (str): Used for distinguishing different lmdb envs.
143 | """
144 | filepath = str(filepath)
145 | assert client_key in self._client, (f'client_key {client_key} is not '
146 | 'in lmdb clients.')
147 | client = self._client[client_key]
148 | with client.begin(write=False) as txn:
149 | value_buf = txn.get(filepath.encode('ascii'))
150 | return value_buf
151 |
152 | def get_text(self, filepath):
153 | raise NotImplementedError
154 |
155 |
156 | class FileClient(object):
157 | """A general file client to access files in different backend.
158 |
159 | The client loads a file or text in a specified backend from its path
160 |     and returns it as a binary file. It can also register other backend
161 |     accessors with a given name and backend class.
162 |
163 | Attributes:
164 | backend (str): The storage backend type. Options are "disk",
165 | "memcached" and "lmdb".
166 | client (:obj:`BaseStorageBackend`): The backend object.
167 | """
168 |
169 | _backends = {
170 | 'disk': HardDiskBackend,
171 | 'memcached': MemcachedBackend,
172 | 'lmdb': LmdbBackend,
173 | }
174 |
175 | def __init__(self, backend='disk', **kwargs):
176 | if backend not in self._backends:
177 | raise ValueError(
178 | f'Backend {backend} is not supported. Currently supported ones'
179 | f' are {list(self._backends.keys())}')
180 | self.backend = backend
181 | self.client = self._backends[backend](**kwargs)
182 |
183 | def get(self, filepath, client_key='default'):
184 | # client_key is used only for lmdb, where different fileclients have
185 | # different lmdb environments.
186 | if self.backend == 'lmdb':
187 | return self.client.get(filepath, client_key)
188 | else:
189 | return self.client.get(filepath)
190 |
191 | def get_text(self, filepath):
192 | return self.client.get_text(filepath)
193 |
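194 | 
195 | if __name__ == '__main__':
196 |     # Illustrative usage sketch for the default disk backend: write a small
197 |     # temporary file, then read it back as bytes and as text.
198 |     import tempfile
199 | 
200 |     client = FileClient(backend='disk')
201 |     with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
202 |         f.write('hello file client')
203 |         tmp_path = f.name
204 |     print(client.get(tmp_path))       # b'hello file client'
205 |     print(client.get_text(tmp_path))  # 'hello file client'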
--------------------------------------------------------------------------------
/basicsr/utils/flow_util.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2021 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5 | # Copyright 2018-2020 BasicSR Authors
6 | # ------------------------------------------------------------------------
7 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501
8 | import cv2
9 | import numpy as np
10 | import os
11 |
12 |
13 | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
14 | """Read an optical flow map.
15 |
16 | Args:
17 | flow_path (ndarray or str): Flow path.
18 | quantize (bool): whether to read quantized pair, if set to True,
19 | remaining args will be passed to :func:`dequantize_flow`.
20 | concat_axis (int): The axis that dx and dy are concatenated,
21 | can be either 0 or 1. Ignored if quantize is False.
22 |
23 | Returns:
24 | ndarray: Optical flow represented as a (h, w, 2) numpy array
25 | """
26 | if quantize:
27 | assert concat_axis in [0, 1]
28 | cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
29 | if cat_flow.ndim != 2:
30 | raise IOError(f'{flow_path} is not a valid quantized flow file, '
31 | f'its dimension is {cat_flow.ndim}.')
32 | assert cat_flow.shape[concat_axis] % 2 == 0
33 | dx, dy = np.split(cat_flow, 2, axis=concat_axis)
34 | flow = dequantize_flow(dx, dy, *args, **kwargs)
35 | else:
36 | with open(flow_path, 'rb') as f:
37 | try:
38 | header = f.read(4).decode('utf-8')
39 | except Exception:
40 | raise IOError(f'Invalid flow file: {flow_path}')
41 | else:
42 | if header != 'PIEH':
43 | raise IOError(f'Invalid flow file: {flow_path}, '
44 | 'header does not contain PIEH')
45 |
46 | w = np.fromfile(f, np.int32, 1).squeeze()
47 | h = np.fromfile(f, np.int32, 1).squeeze()
48 | flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
49 |
50 | return flow.astype(np.float32)
51 |
52 |
53 | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
54 | """Write optical flow to file.
55 |
56 | If the flow is not quantized, it will be saved as a .flo file losslessly,
57 | otherwise a jpeg image which is lossy but of much smaller size. (dx and dy
58 | will be concatenated horizontally into a single image if quantize is True.)
59 |
60 | Args:
61 | flow (ndarray): (h, w, 2) array of optical flow.
62 | filename (str): Output filepath.
63 | quantize (bool): Whether to quantize the flow and save it to 2 jpeg
64 | images. If set to True, remaining args will be passed to
65 | :func:`quantize_flow`.
66 | concat_axis (int): The axis that dx and dy are concatenated,
67 | can be either 0 or 1. Ignored if quantize is False.
68 | """
69 | if not quantize:
70 | with open(filename, 'wb') as f:
71 | f.write('PIEH'.encode('utf-8'))
72 | np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
73 | flow = flow.astype(np.float32)
74 | flow.tofile(f)
75 | f.flush()
76 | else:
77 | assert concat_axis in [0, 1]
78 | dx, dy = quantize_flow(flow, *args, **kwargs)
79 | dxdy = np.concatenate((dx, dy), axis=concat_axis)
80 |         os.makedirs(os.path.dirname(filename) or '.', exist_ok=True)
81 |         cv2.imwrite(filename, dxdy)
82 |
83 |
84 | def quantize_flow(flow, max_val=0.02, norm=True):
85 | """Quantize flow to [0, 255].
86 |
87 | After this step, the size of flow will be much smaller, and can be
88 | dumped as jpeg images.
89 |
90 | Args:
91 | flow (ndarray): (h, w, 2) array of optical flow.
92 | max_val (float): Maximum value of flow, values beyond
93 | [-max_val, max_val] will be truncated.
94 | norm (bool): Whether to divide flow values by image width/height.
95 |
96 | Returns:
97 | tuple[ndarray]: Quantized dx and dy.
98 | """
99 | h, w, _ = flow.shape
100 | dx = flow[..., 0]
101 | dy = flow[..., 1]
102 | if norm:
103 | dx = dx / w # avoid inplace operations
104 | dy = dy / h
105 | # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
106 | flow_comps = [
107 | quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]
108 | ]
109 | return tuple(flow_comps)
110 |
111 |
112 | def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
113 | """Recover from quantized flow.
114 |
115 | Args:
116 | dx (ndarray): Quantized dx.
117 | dy (ndarray): Quantized dy.
118 | max_val (float): Maximum value used when quantizing.
119 | denorm (bool): Whether to multiply flow values with width/height.
120 |
121 | Returns:
122 | ndarray: Dequantized flow.
123 | """
124 | assert dx.shape == dy.shape
125 | assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
126 |
127 | dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
128 |
129 | if denorm:
130 | dx *= dx.shape[1]
131 | dy *= dx.shape[0]
132 | flow = np.dstack((dx, dy))
133 | return flow
134 |
135 |
136 | def quantize(arr, min_val, max_val, levels, dtype=np.int64):
137 | """Quantize an array of (-inf, inf) to [0, levels-1].
138 |
139 | Args:
140 | arr (ndarray): Input array.
141 | min_val (scalar): Minimum value to be clipped.
142 | max_val (scalar): Maximum value to be clipped.
143 | levels (int): Quantization levels.
144 | dtype (np.type): The type of the quantized array.
145 |
146 | Returns:
147 | tuple: Quantized array.
148 | """
149 | if not (isinstance(levels, int) and levels > 1):
150 | raise ValueError(
151 | f'levels must be a positive integer, but got {levels}')
152 | if min_val >= max_val:
153 | raise ValueError(
154 | f'min_val ({min_val}) must be smaller than max_val ({max_val})')
155 |
156 | arr = np.clip(arr, min_val, max_val) - min_val
157 | quantized_arr = np.minimum(
158 | np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
159 |
160 | return quantized_arr
161 |
162 |
163 | def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
164 | """Dequantize an array.
165 |
166 | Args:
167 | arr (ndarray): Input array.
168 | min_val (scalar): Minimum value to be clipped.
169 | max_val (scalar): Maximum value to be clipped.
170 | levels (int): Quantization levels.
171 | dtype (np.type): The type of the dequantized array.
172 |
173 | Returns:
174 | tuple: Dequantized array.
175 | """
176 | if not (isinstance(levels, int) and levels > 1):
177 | raise ValueError(
178 | f'levels must be a positive integer, but got {levels}')
179 | if min_val >= max_val:
180 | raise ValueError(
181 | f'min_val ({min_val}) must be smaller than max_val ({max_val})')
182 |
183 | dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
184 | min_val) / levels + min_val
185 |
186 | return dequantized_arr
187 |
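188 | 
189 | if __name__ == '__main__':
190 |     # Illustrative usage sketch: quantize a small random flow field to uint8
191 |     # and recover it with dequantize_flow. With 255 levels over [-0.02, 0.02]
192 |     # the reconstruction error is bounded by roughly one quantization step.
193 |     flow = np.random.uniform(-0.02, 0.02, size=(4, 6, 2)).astype(np.float32)
194 |     dx, dy = quantize_flow(flow, max_val=0.02, norm=False)
195 |     restored = dequantize_flow(dx, dy, max_val=0.02, denorm=False)
196 |     print(np.abs(restored - flow).max())  # small quantization error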
--------------------------------------------------------------------------------
/basicsr/utils/img_util.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2021 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5 | # Copyright 2018-2020 BasicSR Authors
6 | # ------------------------------------------------------------------------
7 | import cv2
8 | import math
9 | import numpy as np
10 | import os
11 | import torch
12 | from torchvision.utils import make_grid
13 | import matplotlib.pyplot as plt
14 |
15 | def img2tensor(imgs, bgr2rgb=True, float32=True):
16 | """Numpy array to tensor.
17 |
18 | Args:
19 | imgs (list[ndarray] | ndarray): Input images.
20 | bgr2rgb (bool): Whether to change bgr to rgb.
21 | float32 (bool): Whether to change to float32.
22 |
23 | Returns:
24 | list[tensor] | tensor: Tensor images. If returned results only have
25 | one element, just return tensor.
26 | """
27 |
28 | def _totensor(img, bgr2rgb, float32):
29 | if img.shape[2] == 3 and bgr2rgb:
30 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
31 | img = torch.from_numpy(img.transpose(2, 0, 1))
32 | if float32:
33 | img = img.float()
34 | return img
35 |
36 | if isinstance(imgs, list):
37 | return [_totensor(img, bgr2rgb, float32) for img in imgs]
38 | else:
39 | return _totensor(imgs, bgr2rgb, float32)
40 |
41 |
42 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
43 | """Convert torch Tensors into image numpy arrays.
44 |
45 | After clamping to [min, max], values will be normalized to [0, 1].
46 |
47 | Args:
48 | tensor (Tensor or list[Tensor]): Accept shapes:
49 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
50 | 2) 3D Tensor of shape (3/1 x H x W);
51 | 3) 2D Tensor of shape (H x W).
52 | Tensor channel should be in RGB order.
53 | rgb2bgr (bool): Whether to change rgb to bgr.
54 | out_type (numpy type): output types. If ``np.uint8``, transform outputs
55 | to uint8 type with range [0, 255]; otherwise, float type with
56 | range [0, 1]. Default: ``np.uint8``.
57 | min_max (tuple[int]): min and max values for clamp.
58 |
59 | Returns:
60 |         (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D
61 |         ndarray of shape (H x W). The channel order is BGR.
62 | """
63 | if not (torch.is_tensor(tensor) or
64 | (isinstance(tensor, list)
65 | and all(torch.is_tensor(t) for t in tensor))):
66 | raise TypeError(
67 | f'tensor or list of tensors expected, got {type(tensor)}')
68 |
69 | if torch.is_tensor(tensor):
70 | tensor = [tensor]
71 | result = []
72 | for _tensor in tensor:
73 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
74 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
75 |
76 | n_dim = _tensor.dim()
77 | if n_dim == 4:
78 | img_np = make_grid(
79 | _tensor, nrow=int(math.sqrt(_tensor.size(0))),
80 | normalize=False).numpy()
81 | img_np = img_np.transpose(1, 2, 0)
82 | if rgb2bgr:
83 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
84 | elif n_dim == 3:
85 | img_np = _tensor.numpy()
86 | img_np = img_np.transpose(1, 2, 0)
87 | if img_np.shape[2] == 1: # gray image
88 | img_np = np.squeeze(img_np, axis=2)
89 | else:
90 | if rgb2bgr:
91 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
92 | elif n_dim == 2:
93 | img_np = _tensor.numpy()
94 | else:
95 | raise TypeError('Only support 4D, 3D or 2D tensor. '
96 | f'But received with dimension: {n_dim}')
97 | if out_type == np.uint8:
98 |             # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
99 | img_np = (img_np * 255.0).round()
100 | img_np = img_np.astype(out_type)
101 | result.append(img_np)
102 | if len(result) == 1:
103 | result = result[0]
104 | return result
105 |
106 |
107 | def imfrombytes(content, flag='color', float32=False):
108 | """Read an image from bytes.
109 |
110 | Args:
111 | content (bytes): Image bytes got from files or other streams.
112 | flag (str): Flags specifying the color type of a loaded image,
113 | candidates are `color`, `grayscale` and `unchanged`.
114 |         float32 (bool): Whether to change to float32. If True, will also
115 |             normalize to [0, 1]. Default: False.
116 |
117 | Returns:
118 | ndarray: Loaded image array.
119 | """
120 | img_np = np.frombuffer(content, np.uint8)
121 | imread_flags = {
122 | 'color': cv2.IMREAD_COLOR,
123 | 'grayscale': cv2.IMREAD_GRAYSCALE,
124 | 'unchanged': cv2.IMREAD_UNCHANGED
125 | }
126 | if img_np is None:
127 |         raise Exception('Failed to load image: empty byte buffer.')
128 | img = cv2.imdecode(img_np, imread_flags[flag])
129 | if float32:
130 | img = img.astype(np.float32) / 255.
131 | return img
132 |
133 | def padding(img_lq, img_gt, gt_size):
134 | h, w, _ = img_lq.shape
135 |
136 | h_pad = max(0, gt_size - h)
137 | w_pad = max(0, gt_size - w)
138 |
139 | if h_pad == 0 and w_pad == 0:
140 | return img_lq, img_gt
141 |
142 | img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
143 | img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
144 | # print('img_lq', img_lq.shape, img_gt.shape)
145 | return img_lq, img_gt
146 |
147 | def imwrite(img, file_path, params=None, auto_mkdir=True):
148 | """Write image to file.
149 |
150 | Args:
151 | img (ndarray): Image array to be written.
152 | file_path (str): Image file path.
153 | params (None or list): Same as opencv's :func:`imwrite` interface.
154 | auto_mkdir (bool): If the parent folder of `file_path` does not exist,
155 | whether to create it automatically.
156 |
157 | Returns:
158 | bool: Successful or not.
159 | """
160 | if auto_mkdir:
161 | dir_name = os.path.abspath(os.path.dirname(file_path))
162 | os.makedirs(dir_name, exist_ok=True)
163 | return cv2.imwrite(file_path, img, params)
164 |
165 | def pltimwrite(img, file_path):
166 | dir_name = os.path.abspath(os.path.dirname(file_path))
167 | os.makedirs(dir_name, exist_ok=True)
168 | plt.imshow(img)
169 | plt.savefig(file_path)
170 | plt.show()
171 |
172 | def crop_border(imgs, crop_border):
173 | """Crop borders of images.
174 |
175 | Args:
176 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
177 |         crop_border (int): Crop border for each end of height and width.
178 |
179 | Returns:
180 | list[ndarray]: Cropped images.
181 | """
182 | if crop_border == 0:
183 | return imgs
184 | else:
185 | if isinstance(imgs, list):
186 | return [
187 | v[crop_border:-crop_border, crop_border:-crop_border, ...]
188 | for v in imgs
189 | ]
190 | else:
191 | return imgs[crop_border:-crop_border, crop_border:-crop_border,
192 | ...]
193 |
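194 | 
195 | if __name__ == '__main__':
196 |     # Illustrative usage sketch: convert a random BGR uint8 image to a float
197 |     # tensor in [0, 1] and back to a uint8 BGR numpy image with tensor2img.
198 |     img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
199 |     tensor = img2tensor(img.astype(np.float32) / 255., bgr2rgb=True)
200 |     restored = tensor2img(tensor, rgb2bgr=True, min_max=(0, 1))
201 |     print(tensor.shape, restored.shape)  # torch.Size([3, 32, 32]) (32, 32, 3)
202 |     print(np.abs(restored.astype(np.int32) - img).max())  # 0, lossless here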
--------------------------------------------------------------------------------
/basicsr/utils/lmdb_util.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2021 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5 | # Copyright 2018-2020 BasicSR Authors
6 | # ------------------------------------------------------------------------
7 | import cv2
8 | import lmdb
9 | import sys
10 | from multiprocessing import Pool
11 | from os import path as osp
12 | from tqdm import tqdm
13 |
14 |
15 | def make_lmdb_from_imgs(data_path,
16 | lmdb_path,
17 | img_path_list,
18 | keys,
19 | batch=5000,
20 | compress_level=1,
21 | multiprocessing_read=False,
22 | n_thread=40,
23 | map_size=None):
24 | """Make lmdb from images.
25 |
26 | Contents of lmdb. The file structure is:
27 | example.lmdb
28 | ├── data.mdb
29 | ├── lock.mdb
30 | ├── meta_info.txt
31 |
32 | The data.mdb and lock.mdb are standard lmdb files and you can refer to
33 | https://lmdb.readthedocs.io/en/release/ for more details.
34 |
35 | The meta_info.txt is a specified txt file to record the meta information
36 | of our datasets. It will be automatically created when preparing
37 | datasets by our provided dataset tools.
38 | Each line in the txt file records 1)image name (with extension),
39 | 2)image shape, and 3)compression level, separated by a white space.
40 |
41 | For example, the meta information could be:
42 | `000_00000000.png (720,1280,3) 1`, which means:
43 | 1) image name (with extension): 000_00000000.png;
44 | 2) image shape: (720,1280,3);
45 | 3) compression level: 1
46 |
47 | We use the image name without extension as the lmdb key.
48 |
49 | If `multiprocessing_read` is True, it will read all the images to memory
50 | using multiprocessing. Thus, your server needs to have enough memory.
51 |
52 | Args:
53 | data_path (str): Data path for reading images.
54 | lmdb_path (str): Lmdb save path.
55 |         img_path_list (list[str]): Image path list.
56 |         keys (list[str]): Used for lmdb keys.
57 | batch (int): After processing batch images, lmdb commits.
58 | Default: 5000.
59 | compress_level (int): Compress level when encoding images. Default: 1.
60 | multiprocessing_read (bool): Whether use multiprocessing to read all
61 | the images to memory. Default: False.
62 | n_thread (int): For multiprocessing.
63 | map_size (int | None): Map size for lmdb env. If None, use the
64 | estimated size from images. Default: None
65 | """
66 |
67 | assert len(img_path_list) == len(keys), (
68 | 'img_path_list and keys should have the same length, '
69 | f'but got {len(img_path_list)} and {len(keys)}')
70 | print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
71 | print(f'Total images: {len(img_path_list)}')
72 | if not lmdb_path.endswith('.lmdb'):
73 | raise ValueError("lmdb_path must end with '.lmdb'.")
74 | if osp.exists(lmdb_path):
75 | print(f'Folder {lmdb_path} already exists. Exit.')
76 | sys.exit(1)
77 |
78 | if multiprocessing_read:
79 | # read all the images to memory (multiprocessing)
80 | dataset = {} # use dict to keep the order for multiprocessing
81 | shapes = {}
82 | print(f'Read images with multiprocessing, #thread: {n_thread} ...')
83 | pbar = tqdm(total=len(img_path_list), unit='image')
84 |
85 | def callback(arg):
86 | """get the image data and update pbar."""
87 | key, dataset[key], shapes[key] = arg
88 | pbar.update(1)
89 | pbar.set_description(f'Read {key}')
90 |
91 | pool = Pool(n_thread)
92 | for path, key in zip(img_path_list, keys):
93 | pool.apply_async(
94 | read_img_worker,
95 | args=(osp.join(data_path, path), key, compress_level),
96 | callback=callback)
97 | pool.close()
98 | pool.join()
99 | pbar.close()
100 | print(f'Finish reading {len(img_path_list)} images.')
101 |
102 | # create lmdb environment
103 | if map_size is None:
104 | # obtain data size for one image
105 | img = cv2.imread(
106 | osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
107 | _, img_byte = cv2.imencode(
108 | '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
109 | data_size_per_img = img_byte.nbytes
110 | print('Data size per image is: ', data_size_per_img)
111 | data_size = data_size_per_img * len(img_path_list)
112 | map_size = data_size * 10
113 |
114 | env = lmdb.open(lmdb_path, map_size=map_size)
115 |
116 | # write data to lmdb
117 | pbar = tqdm(total=len(img_path_list), unit='chunk')
118 | txn = env.begin(write=True)
119 | txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
120 | for idx, (path, key) in enumerate(zip(img_path_list, keys)):
121 | pbar.update(1)
122 | pbar.set_description(f'Write {key}')
123 | key_byte = key.encode('ascii')
124 | if multiprocessing_read:
125 | img_byte = dataset[key]
126 | h, w, c = shapes[key]
127 | else:
128 | _, img_byte, img_shape = read_img_worker(
129 | osp.join(data_path, path), key, compress_level)
130 | h, w, c = img_shape
131 |
132 | txn.put(key_byte, img_byte)
133 | # write meta information
134 | txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
135 | if idx % batch == 0:
136 | txn.commit()
137 | txn = env.begin(write=True)
138 | pbar.close()
139 | txn.commit()
140 | env.close()
141 | txt_file.close()
142 | print('\nFinish writing lmdb.')
143 |
144 |
145 | def read_img_worker(path, key, compress_level):
146 | """Read image worker.
147 |
148 | Args:
149 | path (str): Image path.
150 | key (str): Image key.
151 | compress_level (int): Compress level when encoding images.
152 |
153 | Returns:
154 | str: Image key.
155 | byte: Image byte.
156 | tuple[int]: Image shape.
157 | """
158 |
159 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
160 | if img.ndim == 2:
161 | h, w = img.shape
162 | c = 1
163 | else:
164 | h, w, c = img.shape
165 | _, img_byte = cv2.imencode('.png', img,
166 | [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
167 | return (key, img_byte, (h, w, c))
168 |
169 |
170 | class LmdbMaker():
171 | """LMDB Maker.
172 |
173 | Args:
174 | lmdb_path (str): Lmdb save path.
175 | map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
176 | batch (int): After processing batch images, lmdb commits.
177 | Default: 5000.
178 | compress_level (int): Compress level when encoding images. Default: 1.
179 | """
180 |
181 | def __init__(self,
182 | lmdb_path,
183 | map_size=1024**4,
184 | batch=5000,
185 | compress_level=1):
186 | if not lmdb_path.endswith('.lmdb'):
187 | raise ValueError("lmdb_path must end with '.lmdb'.")
188 | if osp.exists(lmdb_path):
189 | print(f'Folder {lmdb_path} already exists. Exit.')
190 | sys.exit(1)
191 |
192 | self.lmdb_path = lmdb_path
193 | self.batch = batch
194 | self.compress_level = compress_level
195 | self.env = lmdb.open(lmdb_path, map_size=map_size)
196 | self.txn = self.env.begin(write=True)
197 | self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
198 | self.counter = 0
199 |
200 | def put(self, img_byte, key, img_shape):
201 | self.counter += 1
202 | key_byte = key.encode('ascii')
203 | self.txn.put(key_byte, img_byte)
204 | # write meta information
205 | h, w, c = img_shape
206 | self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
207 | if self.counter % self.batch == 0:
208 | self.txn.commit()
209 | self.txn = self.env.begin(write=True)
210 |
211 | def close(self):
212 | self.txn.commit()
213 | self.env.close()
214 | self.txt_file.close()
215 |
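216 | 
217 | if __name__ == '__main__':
218 |     # Illustrative usage sketch: write one randomly generated image into a
219 |     # fresh lmdb with LmdbMaker. The target folder must end with '.lmdb' and
220 |     # must not exist yet; here it is created inside a temporary directory.
221 |     import tempfile
222 |     import numpy as np
223 | 
224 |     lmdb_path = osp.join(tempfile.mkdtemp(), 'demo.lmdb')
225 |     img = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
226 |     _, img_byte = cv2.imencode('.png', img,
227 |                                [cv2.IMWRITE_PNG_COMPRESSION, 1])
228 |     maker = LmdbMaker(lmdb_path, map_size=10 * 1024 ** 2)
229 |     maker.put(img_byte, key='000_00000000', img_shape=(64, 64, 3))
230 |     maker.close()
231 |     print(f'lmdb written to {lmdb_path}')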
--------------------------------------------------------------------------------
/basicsr/utils/logger.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2021 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5 | # Copyright 2018-2020 BasicSR Authors
6 | # ------------------------------------------------------------------------
7 | import datetime
8 | import logging
9 | import time
10 |
11 | from .dist_util import get_dist_info, master_only
12 |
13 |
14 | class MessageLogger():
15 | """Message logger for printing.
16 |
17 | Args:
18 | opt (dict): Config. It contains the following keys:
19 | name (str): Exp name.
20 | logger (dict): Contains 'print_freq' (str) for logger interval.
21 | train (dict): Contains 'total_iter' (int) for total iters.
22 | use_tb_logger (bool): Use tensorboard logger.
23 | start_iter (int): Start iter. Default: 1.
24 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
25 | """
26 |
27 | def __init__(self, opt, start_iter=1, tb_logger=None):
28 | self.exp_name = opt['name']
29 | self.interval = opt['logger']['print_freq']
30 | self.start_iter = start_iter
31 | self.max_iters = opt['train']['total_iter']
32 | self.use_tb_logger = opt['logger']['use_tb_logger']
33 | self.tb_logger = tb_logger
34 | self.start_time = time.time()
35 | self.logger = get_root_logger()
36 |
37 | @master_only
38 | def __call__(self, log_vars):
39 | """Format logging message.
40 |
41 | Args:
42 | log_vars (dict): It contains the following keys:
43 | epoch (int): Epoch number.
44 | iter (int): Current iter.
45 | lrs (list): List for learning rates.
46 |
47 | time (float): Iter time.
48 | data_time (float): Data time for each iter.
49 | """
50 | # epoch, iter, learning rates
51 | epoch = log_vars.pop('epoch')
52 | current_iter = log_vars.pop('iter')
53 | lrs = log_vars.pop('lrs')
54 |
55 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, '
56 | f'iter:{current_iter:8,d}, lr:(')
57 | for v in lrs:
58 | message += f'{v:.3e},'
59 | message += ')] '
60 |
61 | # time and estimated time
62 | if 'time' in log_vars.keys():
63 | iter_time = log_vars.pop('time')
64 | data_time = log_vars.pop('data_time')
65 |
66 | total_time = time.time() - self.start_time
67 | time_sec_avg = total_time / (current_iter - self.start_iter + 1)
68 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
69 | eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
70 | message += f'[eta: {eta_str}, '
71 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
72 |
73 | # other items, especially losses
74 | for k, v in log_vars.items():
75 | message += f'{k}: {v:.4e} '
76 | # tensorboard logger
77 | if self.use_tb_logger and 'debug' not in self.exp_name:
78 | if k.startswith('l_'):
79 | self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
80 | else:
81 | self.tb_logger.add_scalar(k, v, current_iter)
82 | self.logger.info(message)
83 |
84 |
85 | @master_only
86 | def init_tb_logger(log_dir):
87 | from torch.utils.tensorboard import SummaryWriter
88 | tb_logger = SummaryWriter(log_dir=log_dir)
89 | return tb_logger
90 |
91 |
92 | @master_only
93 | def init_wandb_logger(opt):
94 | """We now only use wandb to sync tensorboard log."""
95 | import wandb
96 | logger = logging.getLogger('basicsr')
97 |
98 | project = opt['logger']['wandb']['project']
99 | resume_id = opt['logger']['wandb'].get('resume_id')
100 | if resume_id:
101 | wandb_id = resume_id
102 | resume = 'allow'
103 | logger.warning(f'Resume wandb logger with id={wandb_id}.')
104 | else:
105 | wandb_id = wandb.util.generate_id()
106 | resume = 'never'
107 |
108 | wandb.init(
109 | id=wandb_id,
110 | resume=resume,
111 | name=opt['name'],
112 | config=opt,
113 | project=project,
114 | sync_tensorboard=True)
115 |
116 | logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
117 |
118 |
119 | def get_root_logger(logger_name='basicsr',
120 | log_level=logging.INFO,
121 | log_file=None):
122 | """Get the root logger.
123 |
124 | The logger will be initialized if it has not been initialized. By default a
125 | StreamHandler will be added. If `log_file` is specified, a FileHandler will
126 | also be added.
127 |
128 | Args:
129 | logger_name (str): root logger name. Default: 'basicsr'.
130 | log_file (str | None): The log filename. If specified, a FileHandler
131 | will be added to the root logger.
132 | log_level (int): The root logger level. Note that only the process of
133 | rank 0 is affected, while other processes will set the level to
134 | "Error" and be silent most of the time.
135 |
136 | Returns:
137 | logging.Logger: The root logger.
138 | """
139 | logger = logging.getLogger(logger_name)
140 | # if the logger has been initialized, just return it
141 | if logger.hasHandlers():
142 | return logger
143 |
144 | format_str = '%(asctime)s %(levelname)s: %(message)s'
145 | logging.basicConfig(format=format_str, level=log_level)
146 | rank, _ = get_dist_info()
147 | if rank != 0:
148 | logger.setLevel('ERROR')
149 | elif log_file is not None:
150 | file_handler = logging.FileHandler(log_file, 'w')
151 | file_handler.setFormatter(logging.Formatter(format_str))
152 | file_handler.setLevel(log_level)
153 | logger.addHandler(file_handler)
154 |
155 | return logger
156 |
157 |
158 | def get_env_info():
159 | """Get environment information.
160 |
161 | Currently, only log the software version.
162 | """
163 | import torch
164 | import torchvision
165 |
166 | from basicsr.version import __version__
167 | msg = r"""
168 | ____ _ _____ ____
169 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \
170 | / __ |/ __ `// ___// // ___/\__ \ / /_/ /
171 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
172 | /_____/ \__,_//____//_/ \___//____//_/ |_|
173 | ______ __ __ __ __
174 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / /
175 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
176 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
177 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
178 | """
179 | msg += ('\nVersion Information: '
180 | f'\n\tBasicSR: {__version__}'
181 | f'\n\tPyTorch: {torch.__version__}'
182 | f'\n\tTorchVision: {torchvision.__version__}')
183 | return msg
184 |
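185 | 
186 | if __name__ == '__main__':
187 |     # Illustrative usage sketch (assumes the basicsr package is importable):
188 |     # create the root logger and print the collected environment information.
189 |     logger = get_root_logger(log_level=logging.INFO)
190 |     logger.info(get_env_info())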
--------------------------------------------------------------------------------
/basicsr/utils/misc.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2021 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5 | # Copyright 2018-2020 BasicSR Authors
6 | # ------------------------------------------------------------------------
7 | import numpy as np
8 | import os
9 | import random
10 | import time
11 | import torch
12 | from os import path as osp
13 |
14 | from .dist_util import master_only
15 | from .logger import get_root_logger
16 |
17 |
18 | def set_random_seed(seed):
19 | """Set random seeds."""
20 | random.seed(seed)
21 | np.random.seed(seed)
22 | torch.manual_seed(seed)
23 | torch.cuda.manual_seed(seed)
24 | torch.cuda.manual_seed_all(seed)
25 |
26 |
27 | def get_time_str():
28 | return time.strftime('%Y%m%d_%H%M%S', time.localtime())
29 |
30 |
31 | def mkdir_and_rename(path):
32 | """mkdirs. If path exists, rename it with timestamp and create a new one.
33 |
34 | Args:
35 | path (str): Folder path.
36 | """
37 | if osp.exists(path):
38 | new_name = path + '_archived_' + get_time_str()
39 | print(f'Path already exists. Rename it to {new_name}', flush=True)
40 | os.rename(path, new_name)
41 | os.makedirs(path, exist_ok=True)
42 |
43 |
44 | @master_only
45 | def make_exp_dirs(opt):
46 | """Make dirs for experiments."""
47 | path_opt = opt['path'].copy()
48 | if opt['is_train']:
49 | mkdir_and_rename(path_opt.pop('experiments_root'))
50 | else:
51 | mkdir_and_rename(path_opt.pop('results_root'))
52 | for key, path in path_opt.items():
53 | if ('strict_load' not in key) and ('pretrain_network'
54 | not in key) and ('resume'
55 | not in key):
56 | os.makedirs(path, exist_ok=True)
57 |
58 |
59 | def scandir(dir_path, suffix=None, recursive=False, full_path=False):
60 | """Scan a directory to find the interested files.
61 |
62 | Args:
63 | dir_path (str): Path of the directory.
64 | suffix (str | tuple(str), optional): File suffix that we are
65 | interested in. Default: None.
66 | recursive (bool, optional): If set to True, recursively scan the
67 | directory. Default: False.
68 | full_path (bool, optional): If set to True, include the dir_path.
69 | Default: False.
70 |
71 | Returns:
72 |         A generator for all the interested files with relative paths.
73 | """
74 |
75 | if (suffix is not None) and not isinstance(suffix, (str, tuple)):
76 | raise TypeError('"suffix" must be a string or tuple of strings')
77 |
78 | root = dir_path
79 |
80 | def _scandir(dir_path, suffix, recursive):
81 | for entry in os.scandir(dir_path):
82 | if not entry.name.startswith('.') and entry.is_file():
83 | if full_path:
84 | return_path = entry.path
85 | else:
86 | return_path = osp.relpath(entry.path, root)
87 |
88 | if suffix is None:
89 | yield return_path
90 | elif return_path.endswith(suffix):
91 | yield return_path
92 | else:
93 | if recursive:
94 | yield from _scandir(
95 | entry.path, suffix=suffix, recursive=recursive)
96 | else:
97 | continue
98 |
99 | return _scandir(dir_path, suffix=suffix, recursive=recursive)
100 |
101 | def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False):
102 | """Scan a directory to find the interested files.
103 |
104 | Args:
105 | dir_path (str): Path of the directory.
106 | keywords (str | tuple(str), optional): File keywords that we are
107 | interested in. Default: None.
108 | recursive (bool, optional): If set to True, recursively scan the
109 | directory. Default: False.
110 | full_path (bool, optional): If set to True, include the dir_path.
111 | Default: False.
112 |
113 | Returns:
114 |         A generator for all the interested files with relative paths.
115 | """
116 |
117 | if (keywords is not None) and not isinstance(keywords, (str, tuple)):
118 | raise TypeError('"keywords" must be a string or tuple of strings')
119 |
120 | root = dir_path
121 |
122 | def _scandir(dir_path, keywords, recursive):
123 | for entry in os.scandir(dir_path):
124 | if not entry.name.startswith('.') and entry.is_file():
125 | if full_path:
126 | return_path = entry.path
127 | else:
128 | return_path = osp.relpath(entry.path, root)
129 |
130 | if keywords is None:
131 | yield return_path
132 | elif return_path.find(keywords) > 0:
133 | yield return_path
134 | else:
135 | if recursive:
136 | yield from _scandir(
137 | entry.path, keywords=keywords, recursive=recursive)
138 | else:
139 | continue
140 |
141 | return _scandir(dir_path, keywords=keywords, recursive=recursive)
142 |
143 | def check_resume(opt, resume_iter):
144 | """Check resume states and pretrain_network paths.
145 |
146 | Args:
147 | opt (dict): Options.
148 | resume_iter (int): Resume iteration.
149 | """
150 | logger = get_root_logger()
151 | if opt['path']['resume_state']:
152 | # get all the networks
153 | networks = [key for key in opt.keys() if key.startswith('network_')]
154 | flag_pretrain = False
155 | for network in networks:
156 | if opt['path'].get(f'pretrain_{network}') is not None:
157 | flag_pretrain = True
158 | if flag_pretrain:
159 | logger.warning(
160 | 'pretrain_network path will be ignored during resuming.')
161 | # set pretrained model paths
162 | for network in networks:
163 | name = f'pretrain_{network}'
164 | basename = network.replace('network_', '')
165 | if opt['path'].get('ignore_resume_networks') is None or (
166 | basename not in opt['path']['ignore_resume_networks']):
167 | opt['path'][name] = osp.join(
168 | opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
169 | logger.info(f"Set {name} to {opt['path'][name]}")
170 |
171 |
172 | def sizeof_fmt(size, suffix='B'):
173 | """Get human readable file size.
174 |
175 | Args:
176 | size (int): File size.
177 | suffix (str): Suffix. Default: 'B'.
178 |
179 | Return:
180 |         str: Formatted file size.
181 | """
182 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
183 | if abs(size) < 1024.0:
184 | return f'{size:3.1f} {unit}{suffix}'
185 | size /= 1024.0
186 | return f'{size:3.1f} Y{suffix}'
187 |
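188 | 
189 | if __name__ == '__main__':
190 |     # Illustrative usage sketch: seed all RNGs, format a byte count and list
191 |     # the python files in the current directory.
192 |     set_random_seed(10)
193 |     print(sizeof_fmt(123456789))  # 117.7 MB
194 |     for path in scandir('.', suffix='.py', recursive=False):
195 |         print(path)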
--------------------------------------------------------------------------------
/basicsr/utils/options.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2021 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR)
5 | # Copyright 2018-2020 BasicSR Authors
6 | # ------------------------------------------------------------------------
7 | import yaml
8 | from collections import OrderedDict
9 | from os import path as osp
10 |
11 |
12 | def ordered_yaml():
13 | """Support OrderedDict for yaml.
14 |
15 | Returns:
16 | yaml Loader and Dumper.
17 | """
18 | try:
19 | from yaml import CDumper as Dumper
20 | from yaml import CLoader as Loader
21 | except ImportError:
22 | from yaml import Dumper, Loader
23 |
24 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
25 |
26 | def dict_representer(dumper, data):
27 | return dumper.represent_dict(data.items())
28 |
29 | def dict_constructor(loader, node):
30 | return OrderedDict(loader.construct_pairs(node))
31 |
32 | Dumper.add_representer(OrderedDict, dict_representer)
33 | Loader.add_constructor(_mapping_tag, dict_constructor)
34 | return Loader, Dumper
35 |
36 |
37 | def parse(opt_path, is_train=True):
38 | """Parse option file.
39 |
40 | Args:
41 | opt_path (str): Option file path.
42 |         is_train (bool): Indicate whether in training mode or not. Default: True.
43 |
44 | Returns:
45 | (dict): Options.
46 | """
47 | with open(opt_path, mode='r') as f:
48 | Loader, _ = ordered_yaml()
49 | opt = yaml.load(f, Loader=Loader)
50 |
51 | opt['is_train'] = is_train
52 |
53 | # datasets
54 | if 'datasets' in opt:
55 | for phase, dataset in opt['datasets'].items():
56 | # for several datasets, e.g., test_1, test_2
57 | phase = phase.split('_')[0]
58 | dataset['phase'] = phase
59 | if 'scale' in opt:
60 | dataset['scale'] = opt['scale']
61 | if dataset.get('dataroot_gt') is not None:
62 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
63 | if dataset.get('dataroot_lq') is not None:
64 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
65 |
66 | # paths
67 | for key, val in opt['path'].items():
68 | if (val is not None) and ('resume_state' in key
69 | or 'pretrain_network' in key):
70 | opt['path'][key] = osp.expanduser(val)
71 | opt['path']['root'] = osp.abspath(
72 | osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
73 | if is_train:
74 | experiments_root = osp.join(opt['path']['root'], 'experiments',
75 | opt['name'])
76 | opt['path']['experiments_root'] = experiments_root
77 | opt['path']['models'] = osp.join(experiments_root, 'models')
78 | opt['path']['training_states'] = osp.join(experiments_root,
79 | 'training_states')
80 | opt['path']['log'] = experiments_root
81 | opt['path']['visualization'] = osp.join(experiments_root,
82 | 'visualization')
83 |
84 | # change some options for debug mode
85 | if 'debug' in opt['name']:
86 | if 'val' in opt:
87 | opt['val']['val_freq'] = 8
88 | opt['logger']['print_freq'] = 1
89 | opt['logger']['save_checkpoint_freq'] = 8
90 | else: # test
91 | results_root = osp.join(opt['path']['root'], 'results', opt['name'])
92 | opt['path']['results_root'] = results_root
93 | opt['path']['log'] = results_root
94 | opt['path']['visualization'] = osp.join(results_root, 'visualization')
95 |
96 | return opt
97 |
98 |
99 | def dict2str(opt, indent_level=1):
100 | """dict to string for printing options.
101 |
102 | Args:
103 | opt (dict): Option dict.
104 | indent_level (int): Indent level. Default: 1.
105 |
106 | Return:
107 | (str): Option string for printing.
108 | """
109 | msg = '\n'
110 | for k, v in opt.items():
111 | if isinstance(v, dict):
112 | msg += ' ' * (indent_level * 2) + k + ':['
113 | msg += dict2str(v, indent_level + 1)
114 | msg += ' ' * (indent_level * 2) + ']\n'
115 | else:
116 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
117 | return msg
118 |
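119 | 
120 | if __name__ == '__main__':
121 |     # Illustrative usage sketch: load a small YAML snippet with the
122 |     # OrderedDict-aware loader from ordered_yaml and pretty-print it.
123 |     Loader, _ = ordered_yaml()
124 |     opt = yaml.load('name: demo\ntrain:\n  total_iter: 100\n', Loader=Loader)
125 |     print(dict2str(opt))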
--------------------------------------------------------------------------------
/basicsr/version.py:
--------------------------------------------------------------------------------
1 | # GENERATED VERSION FILE
2 | # TIME: Mon May 21 02:15:04 2024
3 | __version__ = '1.0.0+ebd1331'
4 | short_version = '1.0.0'
5 | version_info = (1, 0, 0)
6 |
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | build:
2 | cuda: "11.3"
3 | gpu: true
4 | python_version: "3.9"
5 | system_packages:
6 | - "libgl1-mesa-glx"
7 | - "libglib2.0-0"
8 | python_packages:
9 | - "numpy==1.21.1"
10 | - "ipython==7.21.0"
11 | - "addict==2.4.0"
12 | - "future==0.18.2"
13 | - "lmdb==1.3.0"
14 | - "opencv-python==4.5.5.64"
15 | - "Pillow==9.1.0"
16 | - "pyyaml==6.0"
17 | - "torch==1.11.0"
18 | - "torchvision==0.12.0"
19 | - "tqdm==4.64.0"
20 | - "scipy==1.8.0"
21 | - "scikit-image==0.19.2"
22 | - "matplotlib==3.5.1"
23 |
24 | predict: "predict.py:Predictor"
25 |
--------------------------------------------------------------------------------
/make_video.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 |
4 |
5 | """
6 | This script creates a side-by-side comparison video from pairs of input and predicted frames stored in a directory.
7 | A sliding line moves across the frames to visually compare the differences, and the resulting video is saved to an output file.
8 | """
9 |
10 |
11 | # Directory path for Input and Restored frames
12 | frames_dir = 'path to the low quality and high quality frames'
13 |
14 | # Output video parameters
15 | output_video_path = 'path_to_save_video/x.mp4'
16 | fps = 20 # Set the frames per second for the output video
17 |
18 |
19 |
20 | # Initialize video writer
21 | frame_example = cv2.imread(os.path.join(frames_dir, sorted(os.listdir(frames_dir))[0]))
22 | height, width, layers = frame_example.shape
23 | print(height, width)
24 | out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
25 |
26 | # Get sorted list of frame filenames
27 | all_files = os.listdir(frames_dir)
28 | input_frames = sorted([f for f in all_files if 'Input' in f], key=lambda x: int(x.split('_')[1]))
29 | pred_frames = sorted([f for f in all_files if 'Pred' in f], key=lambda x: int(x.split('_')[1]))
30 |
31 | # Total number of frames
32 | total_frames = min(len(input_frames), len(pred_frames))
33 |
34 | for i in range(total_frames):
35 | # Construct the filenames based on the sorted lists
36 | low_quality_filename = input_frames[i]
37 | high_quality_filename = pred_frames[i]
38 |
39 | # Read frames
40 | frame1 = cv2.imread(os.path.join(frames_dir, low_quality_filename))
41 | frame2 = cv2.imread(os.path.join(frames_dir, high_quality_filename))
42 | # print(frame1.shape,frame2.shape)
43 |
44 | # Compute the position of the sliding line
45 | slider_position = int((i / total_frames) * width)
46 |
47 | # Create a combined frame
48 | combined_frame = frame1.copy()
49 | combined_frame[:, :slider_position] = frame2[:, :slider_position]
50 |
51 | # Draw the sliding line
52 | cv2.line(combined_frame, (slider_position, 0), (slider_position, height), (0, 255, 0), 2)
53 |
54 | # Write the combined frame to the output video
55 | out.write(combined_frame)
56 |
57 | # Release video writer
58 | out.release()
59 |
60 | print("Video has been created and saved as", output_video_path)
61 |
--------------------------------------------------------------------------------
/options/Turtle_Deblur_Gopro.yml:
--------------------------------------------------------------------------------
1 | name: Final_Gaia_Gopro
2 | model_type: VideoRestorationModel
3 | scale: 1
4 | num_gpu: 8
5 | manual_seed: 10
6 | n_sequence: 5 # n_frames
7 | dir_data: ['/home/amir/datasets/GoPro/train/']
8 | n_colors: 3
9 | rgb_range: 1
10 | no_augment: False
11 | loss_type: 1*L1
12 | patch_size: 192
13 | size_must_mode: 4
14 | model: Turtle_t1_arch
15 | pretrain_models_dir: None
16 | type: deblurring
17 | dim: 64
18 | Enc_blocks: [2, 6, 10]
19 | Middle_blocks: 11
20 | Dec_blocks: [10, 6, 2]
21 | num_refinement_blocks: 2
22 | use_both_input: False
23 | num_heads: [1, 2, 4, 8]
24 | num_frames_tocache: 3
25 | ffn_expansion_factor: 2.5
26 |
27 | encoder1_attn_type1 : "ReducedAttn"
28 | encoder1_attn_type2 : "ReducedAttn"
29 | encoder1_ffw_type : "FFW"
30 |
31 | encoder2_attn_type1 : "ReducedAttn"
32 | encoder2_attn_type2 : "ReducedAttn"
33 | encoder2_ffw_type : "FFW"
34 |
35 | encoder3_attn_type1 : "Channel"
36 | encoder3_attn_type2 : "Channel"
37 | encoder3_ffw_type : "GFFW"
38 |
39 | decoder1_attn_type1 : "Channel"
40 | decoder1_attn_type2 : "CHM"
41 | decoder1_ffw_type : "GFFW"
42 |
43 | decoder2_attn_type1 : "Channel"
44 | decoder2_attn_type2 : "CHM"
45 | decoder2_ffw_type : "GFFW"
46 |
47 | decoder3_attn_type1 : "Channel"
48 | decoder3_attn_type2 : "CHM"
49 | decoder3_ffw_type : "GFFW"
50 |
51 | latent_attn_type1 : "FHR"
52 | latent_attn_type2 : "Channel"
53 | latent_attn_type3 : "FHR"
54 | latent_ffw_type : "GFFW"
55 |
56 | refinement_attn_type1 : "ReducedAttn"
57 | refinement_attn_type2 : "ReducedAttn"
58 | refinement_ffw_type : "GFFW"
59 |
60 | datasets:
61 | train:
62 | name: gopro-train
63 | filename_tmpl: '{}'
64 | io_backend:
65 | type: lmdb
66 |
67 | gt_size: 192
68 | use_flip: false
69 | use_rot: false
70 |
71 | # data loader
72 | use_shuffle: true
73 | num_worker_per_gpu: 8
74 | batch_size_per_gpu: 2
75 | dataset_enlarge_ratio: 1
76 | prefetch_mode: ~
77 |
78 | val:
79 | name: gopro-test
80 | dir_data: ['/home/amir/datasets/GoPro/test/']
81 |
82 | path:
83 | pretrain_network_g: ~
84 | strict_load_g: true
85 | resume_state: ~
86 |
87 | train:
88 | optim_g:
89 | type: Adam
90 | lr: !!float 4e-4
91 | weight_decay: 0
92 | betas: [0.9, 0.99]
93 |
94 | scheduler:
95 | type: TrueCosineAnnealingLR
96 | T_max: 200000
97 | eta_min: !!float 1e-7
98 |
99 | total_iter: 200000
100 | warmup_iter: -1 # no warm up
101 |
102 | # losses
103 | pixel_opt:
104 | type: L1Loss
105 | loss_weight: 1
106 | reduction: mean
107 |
108 | # validation settings
109 | val:
110 | val_freq: 10000
111 | save_img: true
112 | grids: true
113 | crop_size: 192
114 | max_minibatch: 8
115 |
116 | metrics:
117 | psnr: # metric name, can be arbitrary
118 | type: calculate_psnr
119 | crop_border: 0
120 | test_y_channel: false
121 |
122 | # logging settings
123 | logger:
124 | print_freq: 200
125 | save_checkpoint_freq: !!float 10000
126 | use_tb_logger: true
127 | wandb:
128 | project: ~
129 | resume_id: ~
130 |
131 | # dist training settings
132 | dist_params:
133 | backend: nccl
134 | port: 29500
--------------------------------------------------------------------------------
/options/Turtle_Denoise_Davis.yml:
--------------------------------------------------------------------------------
1 | name: Gaia_Denoise_Davis_noise20-50
2 | model_type: VideoRestorationModel
3 | scale: 1
4 | num_gpu: 8
5 | manual_seed: 10
6 | n_sequence: 5 # n_frames
7 | dir_data: ['/datasets/DAVIS/JPEGImages/']
8 | n_colors: 3
9 | rgb_range: 1
10 | no_augment: False
11 | loss_type: 1*L1
12 | patch_size: 192
13 | size_must_mode: 4
14 | model: Turtle_t1_arch
15 |
16 |
17 | pretrain_models_dir: None
18 | type: denoising
19 | dim: 64
20 | Enc_blocks: [2, 6, 10]
21 | Middle_blocks: 11
22 | Dec_blocks: [10, 6, 2]
23 | num_refinement_blocks: 2
24 | use_both_input: False
25 | num_heads: [1, 2, 4, 8]
26 | num_frames_tocache: 3
27 | ffn_expansion_factor: 2.5
28 |
29 | encoder1_attn_type1 : "ReducedAttn"
30 | encoder1_attn_type2 : "ReducedAttn"
31 | encoder1_ffw_type : "FFW"
32 |
33 | encoder2_attn_type1 : "ReducedAttn"
34 | encoder2_attn_type2 : "ReducedAttn"
35 | encoder2_ffw_type : "FFW"
36 |
37 | encoder3_attn_type1 : "Channel"
38 | encoder3_attn_type2 : "Channel"
39 | encoder3_ffw_type : "GFFW"
40 |
41 | decoder1_attn_type1 : "Channel"
42 | decoder1_attn_type2 : "CHM"
43 | decoder1_ffw_type : "GFFW"
44 |
45 | decoder2_attn_type1 : "Channel"
46 | decoder2_attn_type2 : "CHM"
47 | decoder2_ffw_type : "GFFW"
48 |
49 | decoder3_attn_type1 : "Channel"
50 | decoder3_attn_type2 : "CHM"
51 | decoder3_ffw_type : "GFFW"
52 |
53 | latent_attn_type1 : "FHR"
54 | latent_attn_type2 : "Channel"
55 | latent_attn_type3 : "FHR"
56 | latent_ffw_type : "GFFW"
57 |
58 | refinement_attn_type1 : "ReducedAttn"
59 | refinement_attn_type2 : "ReducedAttn"
60 | refinement_ffw_type : "GFFW"
61 |
62 | prompt_attn: "NoAttn"
63 | prompt_ffw: "GFFW"
64 |
65 | datasets:
66 | train:
67 | name: rsvd-train
68 | filename_tmpl: '{}'
69 | io_backend:
70 | type: lmdb
71 |
72 | gt_size: 192
73 | use_flip: false
74 | use_rot: false
75 |
76 | # data loader
77 | use_shuffle: true
78 | num_worker_per_gpu: 8
79 | batch_size_per_gpu: 2
80 | dataset_enlarge_ratio: 1
81 | prefetch_mode: ~
82 |
83 | val:
84 | name: davis
85 | dir_data: ['/datasets/DAVIS_testdev/DAVIS/JPEGImages/']
86 |
87 | path:
88 | pretrain_network_g: ~
89 | resume_state: ~
90 |
91 | train:
92 | optim_g:
93 | type: Adam
94 | lr: !!float 4e-4
95 | weight_decay: 0
96 | betas: [0.9, 0.99]
97 |
98 | scheduler:
99 | type: TrueCosineAnnealingLR
100 | T_max: 250000
101 | eta_min: !!float 1e-7
102 |
103 | total_iter: 250000
104 | warmup_iter: -1 # no warm up
105 |
106 | # losses
107 | pixel_opt:
108 | type: L1Loss
109 | loss_weight: 1
110 | reduction: mean
111 |
112 | # validation settings
113 | val:
114 | val_freq: 10000
115 | save_img: true
116 | grids: true
117 | crop_size: 192
118 | max_minibatch: 8
119 |
120 | metrics:
121 | psnr: # metric name, can be arbitrary
122 | type: calculate_psnr
123 | crop_border: 0
124 | test_y_channel: false
125 |
126 | # logging settings
127 | logger:
128 | print_freq: 200
129 | save_checkpoint_freq: !!float 10000
130 | use_tb_logger: true
131 | wandb:
132 | project: ~
133 | resume_id: ~
134 |
135 | # dist training settings
136 | dist_params:
137 | backend: nccl
138 | port: 29500
139 |
--------------------------------------------------------------------------------
/options/Turtle_Derain.yml:
--------------------------------------------------------------------------------
1 | name: Turtle_Derain
2 | model_type: VideoRestorationModel
3 | scale: 1
4 | num_gpu: 8
5 | manual_seed: 10
6 | n_sequence: 5 # n_frames
7 | dir_data: ['/datasets/NightRain/train/']
8 | n_colors: 3
9 | rgb_range: 1
10 | no_augment: False
11 | loss_type: 1*L1
12 | patch_size: 192
13 | size_must_mode: 4
14 | model: Turtle_arch
15 | pretrain_models_dir: None
16 | type: deraining
17 | dim: 64
18 | Enc_blocks: [2, 6, 10]
19 | Middle_blocks: 11
20 | Dec_blocks: [10, 6, 2]
21 | num_refinement_blocks: 2
22 | use_both_input: False
23 | num_heads: [1, 2, 4, 8]
24 | num_frames_tocache: 3
25 | ffn_expansion_factor: 2.5
26 |
27 |
28 | encoder1_attn_type1 : "ReducedAttn"
29 | encoder1_attn_type2 : "ReducedAttn"
30 | encoder1_ffw_type : "FFW"
31 |
32 | encoder2_attn_type1 : "ReducedAttn"
33 | encoder2_attn_type2 : "ReducedAttn"
34 | encoder2_ffw_type : "FFW"
35 |
36 | encoder3_attn_type1 : "Channel"
37 | encoder3_attn_type2 : "Channel"
38 | encoder3_ffw_type : "GFFW"
39 |
40 | decoder1_attn_type1 : "Channel"
41 | decoder1_attn_type2 : "CHM"
42 | decoder1_ffw_type : "GFFW"
43 |
44 | decoder2_attn_type1 : "Channel"
45 | decoder2_attn_type2 : "CHM"
46 | decoder2_ffw_type : "GFFW"
47 |
48 | decoder3_attn_type1 : "Channel"
49 | decoder3_attn_type2 : "CHM"
50 | decoder3_ffw_type : "GFFW"
51 |
52 | latent_attn_type1 : "FHR"
53 | latent_attn_type2 : "Channel"
54 | latent_attn_type3 : "FHR"
55 | latent_ffw_type : "GFFW"
56 |
57 | refinement_attn_type1 : "ReducedAttn"
58 | refinement_attn_type2 : "ReducedAttn"
59 | refinement_ffw_type : "GFFW"
60 |
61 |
62 | datasets:
63 | train:
64 | name: ngtrain-train
65 | filename_tmpl: '{}'
66 | io_backend:
67 | type: lmdb
68 |
69 | gt_size: 192
70 | use_flip: false
71 | use_rot: false
72 |
73 | # data loader
74 | use_shuffle: true
75 | num_worker_per_gpu: 8
76 | batch_size_per_gpu: 2
77 | dataset_enlarge_ratio: 1
78 | prefetch_mode: ~
79 |
80 | val:
81 | name: ngtrain-test
82 | dir_data: ['/datasets/NightRain/test/']
83 |
84 | path:
85 | pretrain_network_g: ~
86 | strict_load_g: true
87 | resume_state: ~
88 |
89 | train:
90 | optim_g:
91 | type: Adam
92 | lr: !!float 4e-4
93 | weight_decay: 0
94 | betas: [0.9, 0.99]
95 |
96 | scheduler:
97 | type: TrueCosineAnnealingLR
98 | T_max: 200000
99 | eta_min: !!float 1e-7
100 |
101 | total_iter: 200000
102 | warmup_iter: -1 # no warm up
103 |
104 | # losses
105 | pixel_opt:
106 | type: L1Loss
107 | loss_weight: 1
108 | reduction: mean
109 |
110 | # validation settings
111 | val:
112 | val_freq: 50000
113 | save_img: true
114 | grids: true
115 | crop_size: 192
116 | max_minibatch: 8
117 |
118 | metrics:
119 | psnr: # metric name, can be arbitrary
120 | type: calculate_psnr
121 | crop_border: 0
122 | test_y_channel: false
123 |
124 | # logging settings
125 | logger:
126 | print_freq: 200
127 | save_checkpoint_freq: !!float 10000
128 | use_tb_logger: true
129 | wandb:
130 | project: ~
131 | resume_id: ~
132 |
133 | # dist training settings
134 | dist_params:
135 | backend: nccl
136 | port: 29500
--------------------------------------------------------------------------------
/options/Turtle_Derain_VRDS.yml:
--------------------------------------------------------------------------------
1 | name: Turtle_Derain_VRDS
2 | model_type: VideoRestorationModel
3 | scale: 1
4 | num_gpu: 8
5 | manual_seed: 10
6 | n_sequence: 5 # n_frames
7 | dir_data: ['/datasets/VRDS/train/']
8 | n_colors: 3
9 | rgb_range: 1
10 | no_augment: False
11 | loss_type: 1*L1
12 | patch_size: 192
13 | size_must_mode: 4
14 | model: Turtle_t1_arch
15 | pretrain_models_dir: None
16 | type: deraining
17 | dim: 64
18 | Enc_blocks: [2, 6, 10]
19 | Middle_blocks: 11
20 | Dec_blocks: [10, 6, 2]
21 | num_refinement_blocks: 2
22 | use_both_input: False
23 | num_heads: [1, 2, 4, 8]
24 | num_frames_tocache: 3
25 | ffn_expansion_factor: 2.5
26 |
27 | encoder1_attn_type1 : "ReducedAttn"
28 | encoder1_attn_type2 : "ReducedAttn"
29 | encoder1_ffw_type : "FFW"
30 |
31 | encoder2_attn_type1 : "ReducedAttn"
32 | encoder2_attn_type2 : "ReducedAttn"
33 | encoder2_ffw_type : "FFW"
34 |
35 | encoder3_attn_type1 : "Channel"
36 | encoder3_attn_type2 : "Channel"
37 | encoder3_ffw_type : "GFFW"
38 |
39 | decoder1_attn_type1 : "Channel"
40 | decoder1_attn_type2 : "CHM"
41 | decoder1_ffw_type : "GFFW"
42 |
43 | decoder2_attn_type1 : "Channel"
44 | decoder2_attn_type2 : "CHM"
45 | decoder2_ffw_type : "GFFW"
46 |
47 | decoder3_attn_type1 : "Channel"
48 | decoder3_attn_type2 : "CHM"
49 | decoder3_ffw_type : "GFFW"
50 |
51 | latent_attn_type1 : "FHR"
52 | latent_attn_type2 : "Channel"
53 | latent_attn_type3 : "FHR"
54 | latent_ffw_type : "GFFW"
55 |
56 | refinement_attn_type1 : "ReducedAttn"
57 | refinement_attn_type2 : "ReducedAttn"
58 | refinement_ffw_type : "GFFW"
59 |
60 |
61 | datasets:
62 | train:
63 | name: VRDS-train
64 | filename_tmpl: '{}'
65 | io_backend:
66 | type: lmdb
67 |
68 | gt_size: 192
69 | use_flip: false
70 | use_rot: false
71 |
72 | # data loader
73 | use_shuffle: true
74 | num_worker_per_gpu: 8
75 | batch_size_per_gpu: 2
76 | dataset_enlarge_ratio: 1
77 | prefetch_mode: ~
78 |
79 | val:
80 | name: VRDS-test
81 | dir_data: ['/datasets/VRDS/test/']
82 |
83 | path:
84 | pretrain_network_g: ~
85 | strict_load_g: true
86 | resume_state: ~
87 |
88 | train:
89 | optim_g:
90 | type: Adam
91 | lr: !!float 4e-4
92 | weight_decay: 0
93 | betas: [0.9, 0.99]
94 |
95 | scheduler:
96 | type: TrueCosineAnnealingLR
97 | T_max: 200000
98 | eta_min: !!float 1e-7
99 |
100 | total_iter: 200000
101 | warmup_iter: -1 # no warm up
102 |
103 | # losses
104 | pixel_opt:
105 | type: L1Loss
106 | loss_weight: 1
107 | reduction: mean
108 |
109 | # validation settings
110 | val:
111 | val_freq: 50000
112 | save_img: true
113 | grids: true
114 | crop_size: 192
115 | max_minibatch: 8
116 |
117 | metrics:
118 | psnr: # metric name, can be arbitrary
119 | type: calculate_psnr
120 | crop_border: 0
121 | test_y_channel: false
122 |
123 | # logging settings
124 | logger:
125 | print_freq: 200
126 | save_checkpoint_freq: !!float 10000
127 | use_tb_logger: true
128 | wandb:
129 | project: ~
130 | resume_id: ~
131 |
132 | # dist training settings
133 | dist_params:
134 | backend: nccl
135 | port: 29500
--------------------------------------------------------------------------------
/options/Turtle_Desnow.yml:
--------------------------------------------------------------------------------
1 | name: Turtle_Desnow
2 | model_type: VideoRestorationModel
3 | scale: 1
4 | num_gpu: 8
5 | manual_seed: 10
6 | n_sequence: 5 # n_frames
7 | dir_data: ['/datasets/Desnowing/rsvd/train/']
8 | n_colors: 3
9 | rgb_range: 1
10 | no_augment: False
11 | loss_type: 1*L1
12 | patch_size: 192
13 | size_must_mode: 4
14 | model: Turtle_arch
15 |
16 | pretrain_models_dir: None
17 | type: desnowing
18 | dim: 64
19 | Enc_blocks: [2, 6, 10]
20 | Middle_blocks: 11
21 | Dec_blocks: [10, 6, 2]
22 | num_refinement_blocks: 2
23 | use_both_input: False
24 | num_heads: [1, 2, 4, 8]
25 | num_frames_tocache: 3
26 | ffn_expansion_factor: 2.5
27 |
28 | encoder1_attn_type1 : "ReducedAttn"
29 | encoder1_attn_type2 : "ReducedAttn"
30 | encoder1_ffw_type : "FFW"
31 |
32 | encoder2_attn_type1 : "ReducedAttn"
33 | encoder2_attn_type2 : "ReducedAttn"
34 | encoder2_ffw_type : "FFW"
35 |
36 | encoder3_attn_type1 : "Channel"
37 | encoder3_attn_type2 : "Channel"
38 | encoder3_ffw_type : "GFFW"
39 |
40 | decoder1_attn_type1 : "Channel"
41 | decoder1_attn_type2 : "CHM"
42 | decoder1_ffw_type : "GFFW"
43 |
44 | decoder2_attn_type1 : "Channel"
45 | decoder2_attn_type2 : "CHM"
46 | decoder2_ffw_type : "GFFW"
47 |
48 | decoder3_attn_type1 : "Channel"
49 | decoder3_attn_type2 : "CHM"
50 | decoder3_ffw_type : "GFFW"
51 |
52 | latent_attn_type1 : "FHR"
53 | latent_attn_type2 : "Channel"
54 | latent_attn_type3 : "FHR"
55 | latent_ffw_type : "GFFW"
56 |
57 | refinement_attn_type1 : "ReducedAttn"
58 | refinement_attn_type2 : "ReducedAttn"
59 | refinement_ffw_type : "GFFW"
60 |
61 | datasets:
62 | train:
63 | name: rsvd-train
64 | filename_tmpl: '{}'
65 | io_backend:
66 | type: lmdb
67 |
68 | gt_size: 192
69 | use_flip: false
70 | use_rot: false
71 |
72 | # data loader
73 | use_shuffle: true
74 | num_worker_per_gpu: 8
75 | batch_size_per_gpu: 2
76 | dataset_enlarge_ratio: 1
77 | prefetch_mode: ~
78 |
79 | val:
80 | name: rsvd-test
81 | dir_data: ['/datasets/Desnowing/rsvd/test/']
82 |
83 | path:
84 | pretrain_network_g: ~
85 | strict_load_g: true
86 | resume_state: ~
87 | train:
88 | optim_g:
89 | type: Adam
90 | lr: !!float 4e-4
91 | weight_decay: 0
92 | betas: [0.9, 0.99]
93 |
94 | scheduler:
95 | type: TrueCosineAnnealingLR
96 | T_max: 250000
97 | eta_min: !!float 1e-7
98 |
99 | total_iter: 250000
100 | warmup_iter: -1 # no warm up
101 |
102 | # losses
103 | pixel_opt:
104 | type: L1Loss
105 | loss_weight: 1
106 | reduction: mean
107 |
108 | # validation settings
109 | val:
110 | val_freq: 20000
111 | save_img: true
112 | grids: true
113 | crop_size: 192
114 | max_minibatch: 8
115 |
116 | metrics:
117 | psnr: # metric name, can be arbitrary
118 | type: calculate_psnr
119 | crop_border: 0
120 | test_y_channel: false
121 |
122 | # logging settings
123 | logger:
124 | print_freq: 200
125 | save_checkpoint_freq: !!float 10000
126 | use_tb_logger: true
127 | wandb:
128 | project: ~
129 | resume_id: ~
130 |
131 | # dist training settings
132 | dist_params:
133 | backend: nccl
134 | port: 29500
--------------------------------------------------------------------------------
/options/Turtle_SR_MVSR.yml:
--------------------------------------------------------------------------------
1 | name: Turtle_SR_MVSR
2 | model_type: VideoRestorationModel
3 | scale: 1
4 | num_gpu: 8
5 | manual_seed: 10
6 | n_sequence: 5 # n_frames
7 |
8 | dir_data: ['/datasets/MVSR4x/train/']
9 |
10 | n_colors: 3
11 | rgb_range: 1
12 | no_augment: False
13 | loss_type: 1*L1
14 | patch_size: 192
15 | size_must_mode: 4
16 | model: Turtlesuper_t1_arch
17 | pretrain_models_dir: None
18 | type: superresolution
19 | dim: 64
20 | Enc_blocks: [2, 6, 10]
21 | Middle_blocks: 11
22 | Dec_blocks: [10, 6, 2]
23 | num_refinement_blocks: 2
24 | use_both_input: False
25 | num_heads: [1, 2, 4, 8]
26 | num_frames_tocache: 3
27 | ffn_expansion_factor: 2.5
28 |
29 | encoder1_attn_type1 : "ReducedAttn"
30 | encoder1_attn_type2 : "ReducedAttn"
31 | encoder1_ffw_type : "FFW"
32 |
33 | encoder2_attn_type1 : "ReducedAttn"
34 | encoder2_attn_type2 : "ReducedAttn"
35 | encoder2_ffw_type : "FFW"
36 |
37 | encoder3_attn_type1 : "Channel"
38 | encoder3_attn_type2 : "Channel"
39 | encoder3_ffw_type : "GFFW"
40 |
41 | decoder1_attn_type1 : "Channel"
42 | decoder1_attn_type2 : "CHM"
43 | decoder1_ffw_type : "GFFW"
44 |
45 | decoder2_attn_type1 : "Channel"
46 | decoder2_attn_type2 : "CHM"
47 | decoder2_ffw_type : "GFFW"
48 |
49 | decoder3_attn_type1 : "Channel"
50 | decoder3_attn_type2 : "CHM"
51 | decoder3_ffw_type : "GFFW"
52 |
53 | latent_attn_type1 : "FHR"
54 | latent_attn_type2 : "Channel"
55 | latent_attn_type3 : "FHR"
56 | latent_ffw_type : "GFFW"
57 |
58 | refinement_attn_type1 : "ReducedAttn"
59 | refinement_attn_type2 : "ReducedAttn"
60 | refinement_ffw_type : "GFFW"
61 |
62 | prompt_attn: "NoAttn"
63 | prompt_ffw: "GFFW"
64 |
65 | datasets:
66 | train:
67 | name: mvsr-train
68 | filename_tmpl: '{}'
69 | io_backend:
70 | type: lmdb
71 |
72 | gt_size: 192
73 | use_flip: false
74 | use_rot: false
75 |
76 | # data loader
77 | use_shuffle: true
78 | num_worker_per_gpu: 8
79 | batch_size_per_gpu: 2
80 | dataset_enlarge_ratio: 1
81 | prefetch_mode: ~
82 |
83 | val:
84 | name: mvsr-test
85 | dir_data: ['/datasets/MVSR4x/test/']
86 |
87 | path:
88 | pretrain_network_g: ~
89 | strict_load_g: true
90 | resume_state: ~
91 |
92 | train:
93 | optim_g:
94 | type: Adam
95 | lr: !!float 4e-4
96 | weight_decay: 0
97 | betas: [0.9, 0.99]
98 |
99 | scheduler:
100 | type: TrueCosineAnnealingLR
101 | T_max: 200000
102 | eta_min: !!float 1e-7
103 |
104 | total_iter: 200000
105 | warmup_iter: -1 # no warm up
106 |
107 | # losses
108 | pixel_opt:
109 | type: L1Loss
110 | loss_weight: 1
111 | reduction: mean
112 |
113 | # validation settings
114 | val:
115 | val_freq: 10000
116 | save_img: true
117 | grids: true
118 | crop_size: 192
119 | max_minibatch: 8
120 |
121 | metrics:
122 | psnr: # metric name, can be arbitrary
123 | type: calculate_psnr
124 | crop_border: 0
125 | test_y_channel: false
126 |
127 | # logging settings
128 | logger:
129 | print_freq: 200
130 | save_checkpoint_freq: !!float 5000
131 | use_tb_logger: true
132 | wandb:
133 | project: ~
134 | resume_id: ~
135 |
136 | # dist training settings
137 | dist_params:
138 | backend: nccl
139 | port: 29500
140 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | [PWC: Deblurring on Beam-Splitter Deblurring (BSD)](https://paperswithcode.com/sota/deblurring-on-beam-splitter-deblurring-bsd?p=learning-truncated-causal-history-model-for)
2 | [PWC: Rain Removal on NightRain](https://paperswithcode.com/sota/rain-removal-on-nighrain?p=learning-truncated-causal-history-model-for)
3 | [PWC: Snow Removal on RVSD](https://paperswithcode.com/sota/snow-removal-on-rvsd?p=learning-truncated-causal-history-model-for)
4 | [PWC: Video Deraining on VRDS](https://paperswithcode.com/sota/video-deraining-on-vrds?p=learning-truncated-causal-history-model-for)
5 | [PWC: Video Denoising on Set8 (sigma 50)](https://paperswithcode.com/sota/video-denoising-on-set8-sigma50?p=learning-truncated-causal-history-model-for)
6 | [PWC: Deblurring on GoPro](https://paperswithcode.com/sota/deblurring-on-gopro?p=learning-truncated-causal-history-model-for)
7 |
8 |
9 |
10 | # **Turtle: Learning Truncated Causal History Model for Video Restoration [NeurIPS'2024]**
11 |
12 | [📄 arXiv](https://arxiv.org/abs/2410.03936)
13 | **|**
14 | [🌐 Website](https://kjanjua26.github.io/turtle/)
15 |
16 | The official PyTorch implementation of **Learning Truncated Causal History Model for Video Restoration**, accepted to NeurIPS 2024.
17 |
18 | - Turtle achieves state-of-the-art results on multiple video restoration benchmarks, offering superior computational efficiency and enhanced restoration quality 🔥🔥🔥.
19 | - **🛠️💡Model Forge**: Easily design your own architecture by modifying the option file.
20 | - You can choose from several layer types, such as channel attention, simple channel attention, CHM, FHR, or custom blocks, as well as different feed-forward variants.
21 | - This setup lets you assemble custom networks and experiment with layer and feed-forward configurations; a minimal option-file sketch follows this list.
22 | - If you like this project, please give us a ⭐ on GitHub! 🚀
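
As a minimal, illustrative sketch (not an official template), the per-stage choices live directly in the option YAML; the keys and values below are the ones shipped in `options/Turtle_Deblur_Gopro.yml`, and you can swap any of them to prototype a new block:

```yaml
# Hypothetical variation on the shipped configs: pick one attention type per slot
# and one feed-forward type per stage (these values appear in options/Turtle_Deblur_Gopro.yml).
encoder1_attn_type1: "ReducedAttn"
encoder1_attn_type2: "ReducedAttn"
encoder1_ffw_type: "FFW"

latent_attn_type1: "FHR"
latent_attn_type2: "Channel"
latent_attn_type3: "FHR"
latent_ffw_type: "GFFW"

decoder1_attn_type1: "Channel"
decoder1_attn_type2: "CHM"
decoder1_ffw_type: "GFFW"
```

The remaining stages (encoder2/3, decoder2/3, middle, refinement) follow the same key pattern, as the full files under `options/` show.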
23 |
24 |
25 |
26 |
27 |
30 |
31 |
32 |