├── .gitignore ├── LICENSE.txt ├── VERSION ├── assets ├── gopro.gif ├── memory_comparison.png ├── nightrain30.gif ├── raindrop.gif ├── snowwww.gif └── turtle.png ├── basicsr ├── data │ ├── __init__.py │ ├── data_sampler.py │ ├── data_util.py │ ├── prefetch_dataloader.py │ ├── transforms.py │ ├── video_image_dataset.py │ └── video_super_image_dataset.py ├── inference.py ├── inference_no_ground_truth.py ├── loss │ └── __init__.py ├── metrics │ ├── __init__.py │ ├── metric_util.py │ └── psnr_ssim.py ├── models │ ├── __init__.py │ ├── archs │ │ ├── turtle_arch.py │ │ ├── turtle_t1_arch.py │ │ └── turtlesuper_t1_arch.py │ ├── base_model.py │ ├── losses │ │ ├── __init__.py │ │ ├── loss_util.py │ │ └── losses.py │ ├── lr_scheduler.py │ └── video_restoration_model.py ├── train.py ├── utils │ ├── __init__.py │ ├── create_lmdb.py │ ├── dist_util.py │ ├── download_util.py │ ├── face_util.py │ ├── file_client.py │ ├── flow_util.py │ ├── img_util.py │ ├── lmdb_util.py │ ├── logger.py │ ├── matlab_functions.py │ ├── misc.py │ ├── options.py │ ├── util.py │ └── utils_video.py └── version.py ├── cog.yaml ├── make_video.py ├── options ├── Turtle_Deblur_Gopro.yml ├── Turtle_Denoise_Davis.yml ├── Turtle_Derain.yml ├── Turtle_Derain_VRDS.yml ├── Turtle_Desnow.yml └── Turtle_SR_MVSR.yml ├── readme.md ├── requirements.txt ├── setup.cfg ├── setup.py └── video_to_frames.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea/* 3 | experiments 4 | logs/ 5 | *results* 6 | *__pycache__* 7 | *.sh 8 | datasets 9 | basicsr.egg-info 10 | tb_logger 11 | placeholder_datasetDAVIS 12 | placeholder_dataset 13 | outputs 14 | options 15 | basicsr/inference_outputs 16 | inference_outputs -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Huawei Technologies Co., Ltd. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 1.0.0 2 | -------------------------------------------------------------------------------- /assets/gopro.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ascend-Research/Turtle/dea2cd990f6d6f3b21e68872ae81a76736a7c205/assets/gopro.gif -------------------------------------------------------------------------------- /assets/memory_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ascend-Research/Turtle/dea2cd990f6d6f3b21e68872ae81a76736a7c205/assets/memory_comparison.png -------------------------------------------------------------------------------- /assets/nightrain30.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ascend-Research/Turtle/dea2cd990f6d6f3b21e68872ae81a76736a7c205/assets/nightrain30.gif -------------------------------------------------------------------------------- /assets/raindrop.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ascend-Research/Turtle/dea2cd990f6d6f3b21e68872ae81a76736a7c205/assets/raindrop.gif -------------------------------------------------------------------------------- /assets/snowwww.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ascend-Research/Turtle/dea2cd990f6d6f3b21e68872ae81a76736a7c205/assets/snowwww.gif -------------------------------------------------------------------------------- /assets/turtle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ascend-Research/Turtle/dea2cd990f6d6f3b21e68872ae81a76736a7c205/assets/turtle.png -------------------------------------------------------------------------------- /basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 3 | # Copyright 2018-2020 BasicSR Authors 4 | # ------------------------------------------------------------------------ 5 | 6 | import importlib 7 | import numpy as np 8 | import random 9 | import torch 10 | import torch.utils.data 11 | from functools import partial 12 | from os import path as osp 13 | 14 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader 15 | from basicsr.utils import get_root_logger, scandir 16 | from basicsr.utils.dist_util import get_dist_info 17 | 18 | __all__ = ['create_dataset', 'create_dataloader'] 19 | 20 | # automatically scan and import dataset modules 21 | # scan all the files under the data folder with '_dataset' in file names 22 | data_folder = osp.dirname(osp.abspath(__file__)) 23 | dataset_filenames = [ 24 | osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) 25 | if v.endswith('_dataset.py') 26 | ] 27 | # import all the dataset modules 28 | _dataset_modules = [ 29 | importlib.import_module(f'basicsr.data.{file_name}') 30 | for file_name in dataset_filenames 31 | ] 32 | 33 | 34 | def create_dataset(dataset_opt): 35 | """Create dataset. 36 | 37 | Args: 38 | dataset_opt (dict): Configuration for dataset. 
It constains: 39 | name (str): Dataset name. 40 | type (str): Dataset type. 41 | """ 42 | dataset_type = dataset_opt['type'] 43 | 44 | # dynamic instantiation 45 | for module in _dataset_modules: 46 | dataset_cls = getattr(module, dataset_type, None) 47 | if dataset_cls is not None: 48 | break 49 | if dataset_cls is None: 50 | raise ValueError(f'Dataset {dataset_type} is not found.') 51 | 52 | dataset = dataset_cls(dataset_opt) 53 | 54 | logger = get_root_logger() 55 | logger.info( 56 | f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} ' 57 | 'is created.') 58 | return dataset 59 | 60 | 61 | def create_dataloader(dataset, 62 | dataset_opt, 63 | num_gpu=1, 64 | dist=False, 65 | sampler=None, 66 | seed=None): 67 | """Create dataloader. 68 | 69 | Args: 70 | dataset (torch.utils.data.Dataset): Dataset. 71 | dataset_opt (dict): Dataset options. It contains the following keys: 72 | phase (str): 'train' or 'val'. 73 | num_worker_per_gpu (int): Number of workers for each GPU. 74 | batch_size_per_gpu (int): Training batch size for each GPU. 75 | num_gpu (int): Number of GPUs. Used only in the train phase. 76 | Default: 1. 77 | dist (bool): Whether in distributed training. Used only in the train 78 | phase. Default: False. 79 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 80 | seed (int | None): Seed. Default: None 81 | """ 82 | phase = dataset_opt['phase'] 83 | rank, _ = get_dist_info() 84 | 85 | if phase == 'train': 86 | if dist: # distributed training 87 | batch_size = dataset_opt['batch_size_per_gpu'] 88 | num_workers = dataset_opt['num_worker_per_gpu'] 89 | else: # non-distributed training 90 | multiplier = 1 if num_gpu == 0 else num_gpu 91 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 92 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 93 | dataloader_args = dict( 94 | dataset=dataset, 95 | batch_size=batch_size, 96 | shuffle=False, 97 | num_workers=num_workers, 98 | sampler=sampler, 99 | drop_last=True) 100 | if sampler is None: 101 | dataloader_args['shuffle'] = True 102 | dataloader_args['worker_init_fn'] = partial( 103 | worker_init_fn, num_workers=num_workers, rank=rank, 104 | seed=seed) if seed is not None else None 105 | 106 | elif phase in ['val', 'test']: # validation 107 | dataloader_args = dict( 108 | dataset=dataset, 109 | batch_size=1, 110 | shuffle=False, 111 | num_workers=0) 112 | 113 | else: 114 | raise ValueError(f'Wrong dataset phase: {phase}. 
' 115 | "Supported ones are 'train', 'val' and 'test'.") 116 | 117 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 118 | 119 | prefetch_mode = dataset_opt.get('prefetch_mode') 120 | if prefetch_mode == 'cpu': # CPUPrefetcher 121 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 122 | logger = get_root_logger() 123 | logger.info(f'Use {prefetch_mode} prefetch dataloader: ' 124 | f'num_prefetch_queue = {num_prefetch_queue}') 125 | return PrefetchDataLoader( 126 | num_prefetch_queue=num_prefetch_queue, **dataloader_args) 127 | else: 128 | # prefetch_mode=None: Normal dataloader 129 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 130 | return torch.utils.data.DataLoader(**dataloader_args) 131 | 132 | 133 | def worker_init_fn(worker_id, num_workers, rank, seed): 134 | # Set the worker seed to num_workers * rank + worker_id + seed 135 | worker_seed = num_workers * rank + worker_id + seed 136 | np.random.seed(worker_seed) 137 | random.seed(worker_seed) 138 | -------------------------------------------------------------------------------- /basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 3 | # Copyright 2018-2020 BasicSR Authors 4 | # ------------------------------------------------------------------------ 5 | 6 | import math 7 | import torch 8 | from torch.utils.data.sampler import Sampler 9 | 10 | 11 | class EnlargedSampler(Sampler): 12 | """Sampler that restricts data loading to a subset of the dataset. 13 | 14 | Modified from torch.utils.data.distributed.DistributedSampler 15 | Support enlarging the dataset for iteration-based training, for saving 16 | time when restart the dataloader after each epoch 17 | 18 | Args: 19 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 20 | num_replicas (int | None): Number of processes participating in 21 | the training. It is usually the world_size. 22 | rank (int | None): Rank of the current process within num_replicas. 23 | ratio (int): Enlarging ratio. Default: 1. 
24 | """ 25 | 26 | def __init__(self, dataset, num_replicas, rank, ratio=1): 27 | self.dataset = dataset 28 | self.num_replicas = num_replicas 29 | self.rank = rank 30 | self.epoch = 0 31 | self.num_samples = math.ceil( 32 | len(self.dataset) * ratio / self.num_replicas) 33 | self.total_size = self.num_samples * self.num_replicas 34 | 35 | def __iter__(self): 36 | # deterministically shuffle based on epoch 37 | g = torch.Generator() 38 | g.manual_seed(self.epoch) 39 | indices = torch.randperm(self.total_size, generator=g).tolist() 40 | 41 | dataset_size = len(self.dataset) 42 | indices = [v % dataset_size for v in indices] 43 | 44 | # subsample 45 | indices = indices[self.rank:self.total_size:self.num_replicas] 46 | assert len(indices) == self.num_samples 47 | 48 | return iter(indices) 49 | 50 | def __len__(self): 51 | return self.num_samples 52 | 53 | def set_epoch(self, epoch): 54 | self.epoch = epoch 55 | -------------------------------------------------------------------------------- /basicsr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 3 | # Copyright 2018-2020 BasicSR Authors 4 | # ------------------------------------------------------------------------ 5 | import queue as Queue 6 | import threading 7 | import torch 8 | from torch.utils.data import DataLoader 9 | 10 | 11 | class PrefetchGenerator(threading.Thread): 12 | """A general prefetch generator. 13 | 14 | Ref: 15 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 16 | 17 | Args: 18 | generator: Python generator. 19 | num_prefetch_queue (int): Number of prefetch queue. 20 | """ 21 | 22 | def __init__(self, generator, num_prefetch_queue): 23 | threading.Thread.__init__(self) 24 | self.queue = Queue.Queue(num_prefetch_queue) 25 | self.generator = generator 26 | self.daemon = True 27 | self.start() 28 | 29 | def run(self): 30 | for item in self.generator: 31 | self.queue.put(item) 32 | self.queue.put(None) 33 | 34 | def __next__(self): 35 | next_item = self.queue.get() 36 | if next_item is None: 37 | raise StopIteration 38 | return next_item 39 | 40 | def __iter__(self): 41 | return self 42 | 43 | 44 | class PrefetchDataLoader(DataLoader): 45 | """Prefetch version of dataloader. 46 | 47 | Ref: 48 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 49 | 50 | TODO: 51 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 52 | ddp. 53 | 54 | Args: 55 | num_prefetch_queue (int): Number of prefetch queue. 56 | kwargs (dict): Other arguments for dataloader. 57 | """ 58 | 59 | def __init__(self, num_prefetch_queue, **kwargs): 60 | self.num_prefetch_queue = num_prefetch_queue 61 | super(PrefetchDataLoader, self).__init__(**kwargs) 62 | 63 | def __iter__(self): 64 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 65 | 66 | 67 | class CPUPrefetcher(): 68 | """CPU prefetcher. 69 | 70 | Args: 71 | loader: Dataloader. 72 | """ 73 | 74 | def __init__(self, loader): 75 | self.ori_loader = loader 76 | self.loader = iter(loader) 77 | 78 | def next(self): 79 | try: 80 | return next(self.loader) 81 | except StopIteration: 82 | return None 83 | 84 | def reset(self): 85 | self.loader = iter(self.ori_loader) 86 | 87 | 88 | class CUDAPrefetcher(): 89 | """CUDA prefetcher. 
90 | 91 | Ref: 92 | https://github.com/NVIDIA/apex/issues/304# 93 | 94 | It may consums more GPU memory. 95 | 96 | Args: 97 | loader: Dataloader. 98 | opt (dict): Options. 99 | """ 100 | 101 | def __init__(self, loader, opt): 102 | self.ori_loader = loader 103 | self.loader = iter(loader) 104 | self.opt = opt 105 | self.stream = torch.cuda.Stream() 106 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 107 | self.preload() 108 | 109 | def preload(self): 110 | try: 111 | self.batch = next(self.loader) # self.batch is a dict 112 | except StopIteration: 113 | self.batch = None 114 | return None 115 | # put tensors to gpu 116 | with torch.cuda.stream(self.stream): 117 | for k, v in self.batch.items(): 118 | if torch.is_tensor(v): 119 | self.batch[k] = self.batch[k].to( 120 | device=self.device, non_blocking=True) 121 | 122 | def next(self): 123 | torch.cuda.current_stream().wait_stream(self.stream) 124 | batch = self.batch 125 | self.preload() 126 | return batch 127 | 128 | def reset(self): 129 | self.loader = iter(self.ori_loader) 130 | self.preload() 131 | -------------------------------------------------------------------------------- /basicsr/data/transforms.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import cv2 8 | import random, numpy as np 9 | 10 | def mod_crop(img, scale): 11 | """Mod crop images, used during testing. 12 | 13 | Args: 14 | img (ndarray): Input image. 15 | scale (int): Scale factor. 16 | 17 | Returns: 18 | ndarray: Result image. 19 | """ 20 | img = img.copy() 21 | if img.ndim in (2, 3): 22 | h, w = img.shape[0], img.shape[1] 23 | h_remainder, w_remainder = h % scale, w % scale 24 | img = img[:h - h_remainder, :w - w_remainder, ...] 25 | else: 26 | raise ValueError(f'Wrong img ndim: {img.ndim}.') 27 | return img 28 | 29 | 30 | def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path): 31 | """Paired random crop. 32 | 33 | It crops lists of lq and gt images with corresponding locations. 34 | 35 | Args: 36 | img_gts (list[ndarray] | ndarray): GT images. Note that all images 37 | should have the same shape. If the input is an ndarray, it will 38 | be transformed to a list containing itself. 39 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images 40 | should have the same shape. If the input is an ndarray, it will 41 | be transformed to a list containing itself. 42 | gt_patch_size (int): GT patch size. 43 | scale (int): Scale factor. 44 | gt_path (str): Path to ground-truth. 45 | 46 | Returns: 47 | list[ndarray] | ndarray: GT images and LQ images. If returned results 48 | only have one element, just return ndarray. 49 | """ 50 | 51 | if not isinstance(img_gts, list): 52 | img_gts = [img_gts] 53 | if not isinstance(img_lqs, list): 54 | img_lqs = [img_lqs] 55 | 56 | h_lq, w_lq, _ = img_lqs[0].shape 57 | h_gt, w_gt, _ = img_gts[0].shape 58 | lq_patch_size = gt_patch_size // scale 59 | 60 | if h_gt != h_lq * scale or w_gt != w_lq * scale: 61 | raise ValueError( 62 | f'Scale mismatches. 
GT ({h_gt}, {w_gt}) is not {scale}x ', 63 | f'multiplication of LQ ({h_lq}, {w_lq}).') 64 | if h_lq < lq_patch_size or w_lq < lq_patch_size: 65 | raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' 66 | f'({lq_patch_size}, {lq_patch_size}). ' 67 | f'Please remove {gt_path}.') 68 | 69 | # randomly choose top and left coordinates for lq patch 70 | top = random.randint(0, h_lq - lq_patch_size) 71 | left = random.randint(0, w_lq - lq_patch_size) 72 | 73 | # crop lq patch 74 | img_lqs = [ 75 | v[top:top + lq_patch_size, left:left + lq_patch_size, ...] 76 | for v in img_lqs 77 | ] 78 | 79 | # crop corresponding gt patch 80 | top_gt, left_gt = int(top * scale), int(left * scale) 81 | img_gts = [ 82 | v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] 83 | for v in img_gts 84 | ] 85 | if len(img_gts) == 1: 86 | img_gts = img_gts[0] 87 | if len(img_lqs) == 1: 88 | img_lqs = img_lqs[0] 89 | return img_gts, img_lqs 90 | 91 | 92 | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): 93 | """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). 94 | 95 | We use vertical flip and transpose for rotation implementation. 96 | All the images in the list use the same augmentation. 97 | 98 | Args: 99 | imgs (list[ndarray] | ndarray): Images to be augmented. If the input 100 | is an ndarray, it will be transformed to a list. 101 | hflip (bool): Horizontal flip. Default: True. 102 | rotation (bool): Ratotation. Default: True. 103 | flows (list[ndarray]: Flows to be augmented. If the input is an 104 | ndarray, it will be transformed to a list. 105 | Dimension is (h, w, 2). Default: None. 106 | return_status (bool): Return the status of flip and rotation. 107 | Default: False. 108 | 109 | Returns: 110 | list[ndarray] | ndarray: Augmented images and flows. If returned 111 | results only have one element, just return ndarray. 112 | 113 | """ 114 | hflip = hflip and random.random() < 0.5 115 | vflip = rotation and random.random() < 0.5 116 | rot90 = rotation and random.random() < 0.5 117 | 118 | def _augment(img): 119 | if hflip: # horizontal 120 | cv2.flip(img, 1, img) 121 | if vflip: # vertical 122 | cv2.flip(img, 0, img) 123 | if rot90: 124 | img = img.transpose(1, 0, 2) 125 | return img 126 | 127 | def _augment_flow(flow): 128 | if hflip: # horizontal 129 | cv2.flip(flow, 1, flow) 130 | flow[:, :, 0] *= -1 131 | if vflip: # vertical 132 | cv2.flip(flow, 0, flow) 133 | flow[:, :, 1] *= -1 134 | if rot90: 135 | flow = flow.transpose(1, 0, 2) 136 | flow = flow[:, :, [1, 0]] 137 | return flow 138 | 139 | if not isinstance(imgs, list): 140 | imgs = [imgs] 141 | imgs = [_augment(img) for img in imgs] 142 | if len(imgs) == 1: 143 | imgs = imgs[0] 144 | 145 | if flows is not None: 146 | if not isinstance(flows, list): 147 | flows = [flows] 148 | flows = [_augment_flow(flow) for flow in flows] 149 | if len(flows) == 1: 150 | flows = flows[0] 151 | return imgs, flows 152 | else: 153 | if return_status: 154 | return imgs, (hflip, vflip, rot90) 155 | else: 156 | return imgs 157 | 158 | 159 | def img_rotate(img, angle, center=None, scale=1.0): 160 | """Rotate image. 161 | 162 | Args: 163 | img (ndarray): Image to be rotated. 164 | angle (float): Rotation angle in degrees. Positive values mean 165 | counter-clockwise rotation. 166 | center (tuple[int]): Rotation center. If the center is None, 167 | initialize it as the center of the image. Default: None. 168 | scale (float): Isotropic scale factor. Default: 1.0. 
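# A small usage sketch for the two helpers defined above (shapes, patch size
# and the path are assumptions for illustration only): `paired_random_crop`
# cuts a gt_patch_size crop from the GT frame together with the aligned
# gt_patch_size//scale crop from the LQ frame, and `augment` then applies the
# same random flip/rotation to both.
import numpy as np
gt_frame = np.zeros((256, 256, 3), dtype=np.uint8)  # GT, 4x the LQ size
lq_frame = np.zeros((64, 64, 3), dtype=np.uint8)    # LQ
gt_patch, lq_patch = paired_random_crop(gt_frame, lq_frame,
                                        gt_patch_size=128, scale=4,
                                        gt_path='example/gt/000.png')
# contiguous copies, since cv2.flip writes into the arrays in place
gt_patch, lq_patch = augment([np.ascontiguousarray(gt_patch),
                              np.ascontiguousarray(lq_patch)],
                             hflip=True, rotation=True)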
169 | """ 170 | (h, w) = img.shape[:2] 171 | 172 | if center is None: 173 | center = (w // 2, h // 2) 174 | 175 | matrix = cv2.getRotationMatrix2D(center, angle, scale) 176 | rotated_img = cv2.warpAffine(img, matrix, (w, h)) 177 | return rotated_img 178 | 179 | def data_augmentation(image, mode): 180 | """ 181 | Performs data augmentation of the input image 182 | Input: 183 | image: a cv2 (OpenCV) image 184 | mode: int. Choice of transformation to apply to the image 185 | 0 - no transformation 186 | 1 - flip up and down 187 | 2 - rotate counterwise 90 degree 188 | 3 - rotate 90 degree and flip up and down 189 | 4 - rotate 180 degree 190 | 5 - rotate 180 degree and flip 191 | 6 - rotate 270 degree 192 | 7 - rotate 270 degree and flip 193 | """ 194 | if mode == 0: 195 | # original 196 | out = image 197 | elif mode == 1: 198 | # flip up and down 199 | out = np.flipud(image) 200 | elif mode == 2: 201 | # rotate counterwise 90 degree 202 | out = np.rot90(image) 203 | elif mode == 3: 204 | # rotate 90 degree and flip up and down 205 | out = np.rot90(image) 206 | out = np.flipud(out) 207 | elif mode == 4: 208 | # rotate 180 degree 209 | out = np.rot90(image, k=2) 210 | elif mode == 5: 211 | # rotate 180 degree and flip 212 | out = np.rot90(image, k=2) 213 | out = np.flipud(out) 214 | elif mode == 6: 215 | # rotate 270 degree 216 | out = np.rot90(image, k=3) 217 | elif mode == 7: 218 | # rotate 270 degree and flip 219 | out = np.rot90(image, k=3) 220 | out = np.flipud(out) 221 | else: 222 | raise Exception('Invalid choice of image transformation') 223 | 224 | return out 225 | 226 | def random_augmentation(*args): 227 | ## older random augmentation 228 | out = [] 229 | if random.randint(0,1) == 1: 230 | flag_aug = random.randint(1,7) 231 | for data in args: 232 | out.append(data_augmentation(data, flag_aug).copy()) 233 | else: 234 | for data in args: 235 | out.append(data) 236 | return out 237 | 238 | # restormer's augmentation 239 | # def random_augmentation(*args): 240 | # out = [] 241 | # flag_aug = random.randint(0, 7) 242 | # for data in args: 243 | # out.append(data_augmentation(data, flag_aug).copy()) 244 | # return out -------------------------------------------------------------------------------- /basicsr/data/video_image_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data as data 2 | from basicsr.data.data_util import np2Tensor, get_patch 3 | from basicsr.data.transforms import random_augmentation 4 | import os 5 | import glob, imageio 6 | import numpy as np 7 | import torch 8 | 9 | class VideoImageDataset(data.Dataset): 10 | def __init__(self, args, phase): 11 | self.args = args 12 | self.name = args['name'] 13 | self.phase = phase 14 | self.n_seq = args['n_sequence'] 15 | self.n_frames_video = [] 16 | if self.phase == "train": 17 | self._set_filesystem(args['dir_data'], 18 | self.phase) 19 | else: 20 | self._set_filesystem(args['datasets']['val']['dir_data'], 21 | self.phase) 22 | 23 | self.images_gt, self.images_input = self._scan() 24 | self.num_video = len(self.images_gt) 25 | self.num_frame = sum(self.n_frames_video) - (self.n_seq - 1) * len(self.n_frames_video) 26 | print("Number of videos to load:", self.num_video) 27 | self.n_colors = args['n_colors'] 28 | self.rgb_range = args['rgb_range'] 29 | self.patch_size = args['patch_size'] 30 | self.no_augment = args['no_augment'] 31 | self.size_must_mode = args['size_must_mode'] 32 | 33 | def _set_filesystem(self, dir_data, phase): 34 | print("Loading {} => {} 
DataSet".format(f"{phase}", self.name)) 35 | if isinstance(dir_data, list): 36 | self.dir_gt = [] 37 | self.apath = [] 38 | self.dir_input = [] 39 | for path in dir_data: 40 | self.apath.append(path) 41 | self.dir_gt.append(os.path.join(path, 'gt')) 42 | self.dir_input.append(os.path.join(path, 'blur')) 43 | else: 44 | self.apath = dir_data 45 | self.dir_gt = os.path.join(self.apath, 'gt') 46 | self.dir_input = os.path.join(self.apath, 'blur') 47 | 48 | def _scan(self): 49 | if isinstance(self.dir_gt, list): 50 | vid_gt_names_combined = [] 51 | vid_input_names_combined = [] 52 | 53 | for ix in range(len(self.dir_gt)): 54 | vid_gt_names = sorted(glob.glob(os.path.join(self.dir_gt[ix], '*'))) 55 | vid_input_names = sorted(glob.glob(os.path.join(self.dir_input[ix], '*'))) 56 | 57 | vid_gt_names_combined.append(vid_gt_names) 58 | vid_input_names_combined.append(vid_input_names) 59 | assert len(vid_gt_names) == len(vid_input_names), "len(vid_gt_names) must equal len(vid_input_names)" 60 | else: 61 | vid_gt_names_combined = vid_gt_names 62 | vid_input_names_combined = vid_input_names 63 | 64 | images_gt = [] 65 | images_input = [] 66 | for vid_gt, vid_input in zip(vid_gt_names_combined, vid_input_names_combined): 67 | for vid_gt_name, vid_input_name in zip(vid_gt, vid_input): 68 | gt_dir_names = sorted(glob.glob(os.path.join(vid_gt_name, '*'))) 69 | input_dir_names = sorted(glob.glob(os.path.join(vid_input_name, '*'))) 70 | 71 | images_gt.append(gt_dir_names) 72 | images_input.append(input_dir_names) 73 | self.n_frames_video.append(len(gt_dir_names)) 74 | return images_gt, images_input 75 | 76 | def _load(self, images_gt, images_input): 77 | data_input = [] 78 | data_gt = [] 79 | n_videos = len(images_gt) 80 | for idx in range(n_videos): 81 | if idx % 10 == 0: 82 | print("Loading video %d" % idx) 83 | gts = np.array([imageio.imread(hr_name) for hr_name in images_gt[idx]]) 84 | inputs = np.array([imageio.imread(lr_name) for lr_name in images_input[idx]]) 85 | data_input.append(inputs) 86 | data_gt.append(gts) 87 | return data_gt, data_input 88 | 89 | def add_noise(self, x): 90 | # x is numpy here 91 | x = torch.tensor(x).unsqueeze(0).permute(0, 3, 1, 2) 92 | if self.phase == "train": 93 | # uniform sampling from [20, 50] 94 | r1 = 20.0/255.0 95 | r2 = 50.0/255.0 96 | stdn = np.random.rand(1,1,1,1) * (r2-r1) + r1 97 | stdn = torch.FloatTensor(stdn) 98 | noise = torch.zeros_like(x) 99 | noise = torch.normal(mean=noise.float(), 100 | std=stdn.expand_as(noise)) 101 | lq = (noise + x/255.0)*255 102 | else: 103 | # in validation, the noise is fixed to 50.0/255.0. 
104 | r2 = 50.0/255.0 105 | stdn = [r2] 106 | stdn = torch.FloatTensor(stdn) 107 | noise = torch.zeros_like(x) 108 | noise = torch.normal(mean=noise.float(), 109 | std=stdn.expand_as(noise)) 110 | lq = (noise + x/255.0)*255 111 | 112 | return lq.squeeze(0).permute(1, 2, 0).numpy() 113 | 114 | def __getitem__(self, idx): 115 | inputs, gts, filenames_prompts, filenames = self._load_file(idx) 116 | inputs_list = [inputs[i, :, :, :] for i in range(self.n_seq)] 117 | inputs_concat = np.concatenate(inputs_list, axis=2) 118 | gts_list = [gts[i, :, :, :] for i in range(self.n_seq)] 119 | gts_concat = np.concatenate(gts_list, axis=2) 120 | inputs_concat, gts_concat = self.get_patch(inputs_concat, gts_concat, self.size_must_mode) 121 | inputs_list = [inputs_concat[:, :, i*self.n_colors:(i+1)*self.n_colors] for i in range(self.n_seq)] 122 | gts_list = [gts_concat[:, :, i*self.n_colors:(i+1)*self.n_colors] for i in range(self.n_seq)] 123 | 124 | inputs_updated = [] 125 | for ix in range(len(filenames_prompts)): 126 | _filename_ = filenames_prompts[ix] 127 | _img_ = inputs_list[ix] 128 | if "DAVIS" in _filename_: 129 | # denoising dataset, add noise. 130 | noise_added_img = self.add_noise(_img_) 131 | inputs_updated.append(noise_added_img) 132 | else: 133 | # let it go as is. 134 | inputs_updated.append(_img_) 135 | 136 | inputs = np.array(inputs_updated) 137 | gts = np.array(gts_list) 138 | 139 | input_tensors = np2Tensor(*inputs, rgb_range=self.rgb_range, n_colors=self.n_colors) 140 | gt_tensors = np2Tensor(*gts, rgb_range=self.rgb_range, n_colors=self.n_colors) 141 | return torch.stack(input_tensors), torch.stack(gt_tensors), filenames_prompts, filenames 142 | 143 | def __len__(self): 144 | return self.num_frame 145 | 146 | def _get_index(self, idx): 147 | return idx % self.num_frame 148 | 149 | def _find_video_num(self, idx, n_frame): 150 | for i, j in enumerate(n_frame): 151 | if idx < j: return i, idx 152 | else: idx -= j 153 | 154 | def _load_file(self, idx): 155 | idx = self._get_index(idx) 156 | n_poss_frames = [n - self.n_seq + 1 for n in self.n_frames_video] 157 | video_idx, frame_idx = self._find_video_num(idx, n_poss_frames) 158 | f_gts = self.images_gt[video_idx][frame_idx:frame_idx + self.n_seq] 159 | f_inputs = self.images_input[video_idx][frame_idx:frame_idx + self.n_seq] 160 | gts = np.array([imageio.imread(hr_name) for hr_name in f_gts]) 161 | inputs = np.array([imageio.imread(lr_name) for lr_name in f_inputs]) 162 | filenames = [os.path.split(os.path.dirname(name))[-1] + '.' + os.path.splitext(os.path.basename(name))[0] 163 | for name in f_gts] 164 | filenames_prompts = [x for x in f_inputs] 165 | return inputs, gts, filenames_prompts, filenames 166 | 167 | def _load_file_from_loaded_data(self, idx): 168 | idx = self._get_index(idx) 169 | 170 | n_poss_frames = [n - self.n_seq + 1 for n in self.n_frames_video] 171 | video_idx, frame_idx = self._find_video_num(idx, n_poss_frames) 172 | gts = self.data_gt[video_idx][frame_idx:frame_idx + self.n_seq] 173 | inputs = self.data_input[video_idx][frame_idx:frame_idx + self.n_seq] 174 | filenames = [os.path.split(os.path.dirname(name))[-1] + '.' 
+ os.path.splitext(os.path.basename(name))[0] 175 | for name in self.images_gt[video_idx][frame_idx:frame_idx + self.n_seq]] 176 | return inputs, gts, filenames 177 | 178 | def get_patch(self, input, gt, size_must_mode=1): 179 | if True: 180 | input, gt = get_patch(input, gt, patch_size=self.patch_size) 181 | h, w, c = input.shape 182 | new_h, new_w = h - h % size_must_mode, w - w % size_must_mode 183 | input, gt = input[:new_h, :new_w, :], gt[:new_h, :new_w, :] 184 | if not self.no_augment and self.phase == "train": 185 | input, gt = random_augmentation(input, gt) 186 | return input, gt -------------------------------------------------------------------------------- /basicsr/data/video_super_image_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data as data 2 | from basicsr.data.data_util import np2Tensor 3 | from basicsr.data.transforms import random_augmentation 4 | import os 5 | import glob, imageio 6 | import numpy as np 7 | import torch 8 | import cv2, random 9 | 10 | class VideoSuperImageDataset(data.Dataset): 11 | def __init__(self, args, phase): 12 | self.args = args 13 | self.name = args['name'] 14 | self.phase = phase 15 | self.n_seq = args['n_sequence'] 16 | print("n_seq:", self.n_seq) 17 | self.n_frames_video = [] 18 | if self.phase == "train": 19 | self._set_filesystem(args['dir_data'], 20 | self.phase) 21 | else: 22 | self._set_filesystem(args['datasets']['val']['dir_data'], 23 | self.phase) 24 | 25 | self.images_gt, self.images_input = self._scan() 26 | self.num_video = len(self.images_gt) 27 | self.num_frame = sum(self.n_frames_video) - (self.n_seq - 1) * len(self.n_frames_video) 28 | print("Number of videos to load:", self.num_video) 29 | self.n_colors = args['n_colors'] 30 | self.rgb_range = args['rgb_range'] 31 | self.patch_size = args['patch_size'] 32 | self.no_augment = args['no_augment'] 33 | self.size_must_mode = args['size_must_mode'] 34 | 35 | def _set_filesystem(self, dir_data, phase): 36 | print("Loading {} => {} DataSet".format(f"{phase}", self.name)) 37 | if isinstance(dir_data, list): 38 | self.dir_gt = [] 39 | self.apath = [] 40 | self.dir_input = [] 41 | for path in dir_data: 42 | self.apath.append(path) 43 | self.dir_gt.append(os.path.join(path, 'gt')) 44 | self.dir_input.append(os.path.join(path, 'blur')) 45 | else: 46 | self.apath = dir_data 47 | self.dir_gt = os.path.join(self.apath, 'gt') 48 | self.dir_input = os.path.join(self.apath, 'blur') 49 | 50 | def _scan(self): 51 | if isinstance(self.dir_gt, list): 52 | vid_gt_names_combined = [] 53 | vid_input_names_combined = [] 54 | 55 | for ix in range(len(self.dir_gt)): 56 | vid_gt_names = sorted(glob.glob(os.path.join(self.dir_gt[ix], '*'))) 57 | vid_input_names = sorted(glob.glob(os.path.join(self.dir_input[ix], '*'))) 58 | 59 | vid_gt_names_combined.append(vid_gt_names) 60 | vid_input_names_combined.append(vid_input_names) 61 | assert len(vid_gt_names) == len(vid_input_names), "len(vid_gt_names) must equal len(vid_input_names)" 62 | else: 63 | vid_gt_names_combined = vid_gt_names 64 | vid_input_names_combined = vid_input_names 65 | 66 | images_gt = [] 67 | images_input = [] 68 | for vid_gt, vid_input in zip(vid_gt_names_combined, vid_input_names_combined): 69 | for vid_gt_name, vid_input_name in zip(vid_gt, vid_input): 70 | gt_dir_names = sorted(glob.glob(os.path.join(vid_gt_name, '*'))) 71 | input_dir_names = sorted(glob.glob(os.path.join(vid_input_name, '*'))) 72 | 73 | images_gt.append(gt_dir_names) 74 | 
images_input.append(input_dir_names) 75 | self.n_frames_video.append(len(gt_dir_names)) 76 | return images_gt, images_input 77 | 78 | def _load(self, images_gt, images_input): 79 | data_input = [] 80 | data_gt = [] 81 | n_videos = len(images_gt) 82 | for idx in range(n_videos): 83 | if idx % 10 == 0: 84 | print("Loading video %d" % idx) 85 | gts = np.array([imageio.imread(hr_name) for hr_name in images_gt[idx]]) 86 | inputs = np.array([imageio.imread(lr_name) for lr_name in images_input[idx]]) 87 | data_input.append(inputs) 88 | data_gt.append(gts) 89 | return data_gt, data_input 90 | 91 | def __getitem__(self, idx): 92 | inputs, gts, filenames, filenames_prompts = self._load_file(idx) 93 | inputs_list = [inputs[i, :, :, :] for i in range(self.n_seq)] 94 | inputs_concat = np.concatenate(inputs_list, axis=2) 95 | gts_list = [gts[i, :, :, :] for i in range(self.n_seq)] 96 | gts_concat = np.concatenate(gts_list, axis=2) 97 | 98 | inputs_concat, gts_concat = self._crop_patch(inputs_concat, gts_concat) 99 | inputs_list = [inputs_concat[:, :, i*self.n_colors:(i+1)*self.n_colors] for i in range(self.n_seq)] 100 | gts_list = [gts_concat[:, :, i*self.n_colors:(i+1)*self.n_colors] for i in range(self.n_seq)] 101 | inputs = np.array(inputs_list) 102 | gts = np.array(gts_list) 103 | 104 | input_tensors = np2Tensor(*inputs, rgb_range=self.rgb_range, n_colors=self.n_colors) 105 | gt_tensors = np2Tensor(*gts, rgb_range=self.rgb_range, n_colors=self.n_colors) 106 | return torch.stack(input_tensors), torch.stack(gt_tensors), filenames, filenames_prompts 107 | 108 | def __len__(self): 109 | return self.num_frame 110 | 111 | def _get_index(self, idx): 112 | return idx % self.num_frame 113 | 114 | def _find_video_num(self, idx, n_frame): 115 | for i, j in enumerate(n_frame): 116 | if idx < j: return i, idx 117 | else: idx -= j 118 | 119 | def _load_file(self, idx): 120 | idx = self._get_index(idx) 121 | n_poss_frames = [n - self.n_seq + 1 for n in self.n_frames_video] 122 | video_idx, frame_idx = self._find_video_num(idx, n_poss_frames) 123 | f_gts = self.images_gt[video_idx][frame_idx:frame_idx + self.n_seq] 124 | f_inputs = self.images_input[video_idx][frame_idx:frame_idx + self.n_seq] 125 | inputs = [] 126 | gts = np.array([imageio.imread(hr_name) for hr_name in f_gts]) 127 | # inputs = np.array([imageio.imread(lr_name) for lr_name in f_inputs]) 128 | inputs = [] 129 | for lr_name in f_inputs: 130 | lq_img = imageio.imread(lr_name) 131 | h,w,_ = lq_img.shape 132 | lq_img_ = cv2.resize(lq_img, (w//4, h//4), 133 | interpolation=cv2.INTER_CUBIC) 134 | inputs.append(lq_img_) 135 | inputs = np.array(inputs) 136 | filenames = [os.path.split(os.path.dirname(name))[-1] + '.' + os.path.splitext(os.path.basename(name))[0] 137 | for name in f_gts] 138 | filenames_prompts = [x for x in f_inputs] 139 | return inputs, gts, filenames, filenames_prompts 140 | 141 | def _load_file_from_loaded_data(self, idx): 142 | idx = self._get_index(idx) 143 | 144 | n_poss_frames = [n - self.n_seq + 1 for n in self.n_frames_video] 145 | video_idx, frame_idx = self._find_video_num(idx, n_poss_frames) 146 | gts = self.data_gt[video_idx][frame_idx:frame_idx + self.n_seq] 147 | inputs = self.data_input[video_idx][frame_idx:frame_idx + self.n_seq] 148 | filenames = [os.path.split(os.path.dirname(name))[-1] + '.' 
+ os.path.splitext(os.path.basename(name))[0] 149 | for name in self.images_gt[video_idx][frame_idx:frame_idx + self.n_seq]] 150 | return inputs, gts, filenames 151 | 152 | def _crop_patch(self, lr_seq, hr_seq, patch_size=48, scale=4): 153 | ih, iw, _ = lr_seq.shape 154 | pw = random.randrange(0, iw - patch_size + 1) 155 | ph = random.randrange(0, ih - patch_size + 1) 156 | 157 | hpw, hph = scale * pw, scale * ph 158 | hr_patch_size = scale * patch_size 159 | 160 | lr_patch_seq = lr_seq[ph:ph+patch_size, pw:pw+patch_size, :] 161 | hr_patch_seq = hr_seq[hph:hph+hr_patch_size, hpw:hpw+hr_patch_size, :] 162 | if not self.no_augment and self.phase == "train": 163 | lr_patch_seq, hr_patch_seq = random_augmentation(lr_patch_seq, hr_patch_seq) 164 | return lr_patch_seq, hr_patch_seq -------------------------------------------------------------------------------- /basicsr/loss/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class L1BaseLoss(nn.Module): 8 | def __init__(self, loss_weight=1.0, reduction='mean'): 9 | super(L1BaseLoss, self).__init__() 10 | self.loss_weight = loss_weight 11 | self.reduction = reduction 12 | 13 | def forward(self, pred, target): 14 | l1_loss = nn.L1Loss() 15 | l1_base = l1_loss(pred, target) 16 | return l1_base 17 | 18 | class PSNRLoss(nn.Module): 19 | def __init__(self, loss_weight=1.0, reduction='mean', toY=False): 20 | super(PSNRLoss, self).__init__() 21 | assert reduction == 'mean' 22 | self.loss_weight = loss_weight 23 | self.scale = 10 / np.log(10) 24 | self.toY = toY 25 | self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1) 26 | self.first = True 27 | 28 | def forward(self, pred, target): 29 | assert len(pred.size()) == 4 30 | if self.toY: 31 | if self.first: 32 | self.coef = self.coef.to(pred.device) 33 | self.first = False 34 | 35 | pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. 36 | target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. 37 | 38 | pred, target = pred / 255., target / 255. 39 | pass 40 | assert len(pred.size()) == 4 41 | 42 | return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean() -------------------------------------------------------------------------------- /basicsr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 3 | # Copyright 2018-2020 BasicSR Authors 4 | # ------------------------------------------------------------------------ 5 | # from .niqe import calculate_niqe 6 | from .psnr_ssim import calculate_psnr, calculate_ssim 7 | 8 | __all__ = ['calculate_psnr', 'calculate_ssim'] 9 | -------------------------------------------------------------------------------- /basicsr/metrics/metric_util.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import numpy as np 8 | 9 | from basicsr.utils.matlab_functions import bgr2ycbcr 10 | 11 | 12 | def reorder_image(img, input_order='HWC'): 13 | """Reorder images to 'HWC' order. 14 | 15 | If the input_order is (h, w), return (h, w, 1); 16 | If the input_order is (c, h, w), return (h, w, c); 17 | If the input_order is (h, w, c), return as it is. 18 | 19 | Args: 20 | img (ndarray): Input image. 21 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 22 | If the input image shape is (h, w), input_order will not have 23 | effects. Default: 'HWC'. 24 | 25 | Returns: 26 | ndarray: reordered image. 27 | """ 28 | 29 | if input_order not in ['HWC', 'CHW']: 30 | raise ValueError( 31 | f'Wrong input_order {input_order}. Supported input_orders are ' 32 | "'HWC' and 'CHW'") 33 | if len(img.shape) == 2: 34 | img = img[..., None] 35 | if input_order == 'CHW': 36 | img = img.transpose(1, 2, 0) 37 | return img 38 | 39 | 40 | def to_y_channel(img): 41 | """Change to Y channel of YCbCr. 42 | 43 | Args: 44 | img (ndarray): Images with range [0, 255]. 45 | 46 | Returns: 47 | (ndarray): Images with range [0, 255] (float type) without round. 48 | """ 49 | img = img.astype(np.float32) / 255. 50 | if img.ndim == 3 and img.shape[2] == 3: 51 | img = bgr2ycbcr(img, y_only=True) 52 | img = img[..., None] 53 | return img * 255. 54 | -------------------------------------------------------------------------------- /basicsr/metrics/psnr_ssim.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 3 | # Copyright 2018-2020 BasicSR Authors 4 | # ------------------------------------------------------------------------ 5 | import cv2 6 | import numpy as np 7 | 8 | from basicsr.metrics.metric_util import reorder_image, to_y_channel 9 | import skimage.metrics 10 | import torch 11 | 12 | 13 | def calculate_psnr(img1, 14 | img2, 15 | crop_border, 16 | input_order='HWC', 17 | test_y_channel=False): 18 | """Calculate PSNR (Peak Signal-to-Noise Ratio). 19 | 20 | Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio 21 | 22 | Args: 23 | img1 (ndarray/tensor): Images with range [0, 255]/[0, 1]. 24 | img2 (ndarray/tensor): Images with range [0, 255]/[0, 1]. 25 | crop_border (int): Cropped pixels in each edge of an image. These 26 | pixels are not involved in the PSNR calculation. 27 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 28 | Default: 'HWC'. 29 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 30 | 31 | Returns: 32 | float: psnr result. 33 | """ 34 | 35 | assert img1.shape == img2.shape, ( 36 | f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 37 | if input_order not in ['HWC', 'CHW']: 38 | raise ValueError( 39 | f'Wrong input_order {input_order}. 
Supported input_orders are ' 40 | '"HWC" and "CHW"') 41 | if type(img1) == torch.Tensor: 42 | if len(img1.shape) == 4: 43 | img1 = img1.squeeze(0) 44 | img1 = img1.detach().cpu().numpy().transpose(1,2,0) 45 | if type(img2) == torch.Tensor: 46 | if len(img2.shape) == 4: 47 | img2 = img2.squeeze(0) 48 | img2 = img2.detach().cpu().numpy().transpose(1,2,0) 49 | 50 | img1 = reorder_image(img1, input_order=input_order) 51 | img2 = reorder_image(img2, input_order=input_order) 52 | img1 = img1.astype(np.float64) 53 | img2 = img2.astype(np.float64) 54 | 55 | if crop_border != 0: 56 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] 57 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 58 | 59 | if test_y_channel: 60 | img1 = to_y_channel(img1) 61 | img2 = to_y_channel(img2) 62 | 63 | mse = np.mean((img1 - img2)**2) 64 | if mse == 0: 65 | return float('inf') 66 | max_value = 1. if img1.max() <= 1 else 255. 67 | return 20. * np.log10(max_value / np.sqrt(mse)) 68 | 69 | 70 | def _ssim(img1, img2): 71 | """Calculate SSIM (structural similarity) for one channel images. 72 | 73 | It is called by func:`calculate_ssim`. 74 | 75 | Args: 76 | img1 (ndarray): Images with range [0, 255] with order 'HWC'. 77 | img2 (ndarray): Images with range [0, 255] with order 'HWC'. 78 | 79 | Returns: 80 | float: ssim result. 81 | """ 82 | 83 | C1 = (0.01 * 255)**2 84 | C2 = (0.03 * 255)**2 85 | 86 | img1 = img1.astype(np.float64) 87 | img2 = img2.astype(np.float64) 88 | kernel = cv2.getGaussianKernel(11, 1.5) 89 | window = np.outer(kernel, kernel.transpose()) 90 | 91 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] 92 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 93 | mu1_sq = mu1**2 94 | mu2_sq = mu2**2 95 | mu1_mu2 = mu1 * mu2 96 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 97 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 98 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 99 | 100 | ssim_map = ((2 * mu1_mu2 + C1) * 101 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 102 | (sigma1_sq + sigma2_sq + C2)) 103 | return ssim_map.mean() 104 | 105 | def prepare_for_ssim(img, k): 106 | import torch 107 | with torch.no_grad(): 108 | img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float() 109 | conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k//2, padding_mode='reflect') 110 | conv.weight.requires_grad = False 111 | conv.weight[:, :, :, :] = 1. / (k * k) 112 | 113 | img = conv(img) 114 | 115 | img = img.squeeze(0).squeeze(0) 116 | img = img[0::k, 0::k] 117 | return img.detach().cpu().numpy() 118 | 119 | def prepare_for_ssim_rgb(img, k): 120 | import torch 121 | with torch.no_grad(): 122 | img = torch.from_numpy(img).float() #HxWx3 123 | 124 | conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k // 2, padding_mode='reflect') 125 | conv.weight.requires_grad = False 126 | conv.weight[:, :, :, :] = 1. 
/ (k * k) 127 | 128 | new_img = [] 129 | 130 | for i in range(3): 131 | new_img.append(conv(img[:, :, i].unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)[0::k, 0::k]) 132 | 133 | return torch.stack(new_img, dim=2).detach().cpu().numpy() 134 | 135 | def _3d_gaussian_calculator(img, conv3d): 136 | out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0) 137 | return out 138 | 139 | def _generate_3d_gaussian_kernel(): 140 | kernel = cv2.getGaussianKernel(11, 1.5) 141 | window = np.outer(kernel, kernel.transpose()) 142 | kernel_3 = cv2.getGaussianKernel(11, 1.5) 143 | kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0)) 144 | conv3d = torch.nn.Conv3d(1, 1, (11, 11, 11), stride=1, padding=(5, 5, 5), bias=False, padding_mode='replicate') 145 | conv3d.weight.requires_grad = False 146 | conv3d.weight[0, 0, :, :, :] = kernel 147 | return conv3d 148 | 149 | def _ssim_3d(img1, img2, max_value): 150 | assert len(img1.shape) == 3 and len(img2.shape) == 3 151 | """Calculate SSIM (structural similarity) for one channel images. 152 | 153 | It is called by func:`calculate_ssim`. 154 | 155 | Args: 156 | img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. 157 | img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. 158 | 159 | Returns: 160 | float: ssim result. 161 | """ 162 | C1 = (0.01 * max_value) ** 2 163 | C2 = (0.03 * max_value) ** 2 164 | img1 = img1.astype(np.float64) 165 | img2 = img2.astype(np.float64) 166 | 167 | kernel = _generate_3d_gaussian_kernel().cuda() 168 | 169 | img1 = torch.tensor(img1).float().cuda() 170 | img2 = torch.tensor(img2).float().cuda() 171 | 172 | 173 | mu1 = _3d_gaussian_calculator(img1, kernel) 174 | mu2 = _3d_gaussian_calculator(img2, kernel) 175 | 176 | mu1_sq = mu1 ** 2 177 | mu2_sq = mu2 ** 2 178 | mu1_mu2 = mu1 * mu2 179 | sigma1_sq = _3d_gaussian_calculator(img1 ** 2, kernel) - mu1_sq 180 | sigma2_sq = _3d_gaussian_calculator(img2 ** 2, kernel) - mu2_sq 181 | sigma12 = _3d_gaussian_calculator(img1*img2, kernel) - mu1_mu2 182 | 183 | ssim_map = ((2 * mu1_mu2 + C1) * 184 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 185 | (sigma1_sq + sigma2_sq + C2)) 186 | return float(ssim_map.mean()) 187 | 188 | def _ssim_cly(img1, img2): 189 | assert len(img1.shape) == 2 and len(img2.shape) == 2 190 | """Calculate SSIM (structural similarity) for one channel images. 191 | 192 | It is called by func:`calculate_ssim`. 193 | 194 | Args: 195 | img1 (ndarray): Images with range [0, 255] with order 'HWC'. 196 | img2 (ndarray): Images with range [0, 255] with order 'HWC'. 197 | 198 | Returns: 199 | float: ssim result. 
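# The statistic computed by `_ssim_3d` above, and by the Y-channel variant
# below, is the standard SSIM index
#   SSIM(x, y) = ((2*mu_x*mu_y + C1) * (2*sigma_xy + C2))
#                / ((mu_x^2 + mu_y^2 + C1) * (sigma_x^2 + sigma_y^2 + C2))
# with stabilising constants C1 = (0.01*L)^2 and C2 = (0.03*L)^2, where L is
# the dynamic range (255 for uint8 inputs, 1 for normalised ones); for
# L = 255 that gives C1 = 6.5025 and C2 = 58.5225.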
200 | """ 201 | 202 | C1 = (0.01 * 255)**2 203 | C2 = (0.03 * 255)**2 204 | img1 = img1.astype(np.float64) 205 | img2 = img2.astype(np.float64) 206 | 207 | kernel = cv2.getGaussianKernel(11, 1.5) 208 | # print(kernel) 209 | window = np.outer(kernel, kernel.transpose()) 210 | 211 | bt = cv2.BORDER_REPLICATE 212 | 213 | mu1 = cv2.filter2D(img1, -1, window, borderType=bt) 214 | mu2 = cv2.filter2D(img2, -1, window,borderType=bt) 215 | 216 | mu1_sq = mu1**2 217 | mu2_sq = mu2**2 218 | mu1_mu2 = mu1 * mu2 219 | sigma1_sq = cv2.filter2D(img1**2, -1, window, borderType=bt) - mu1_sq 220 | sigma2_sq = cv2.filter2D(img2**2, -1, window, borderType=bt) - mu2_sq 221 | sigma12 = cv2.filter2D(img1 * img2, -1, window, borderType=bt) - mu1_mu2 222 | 223 | ssim_map = ((2 * mu1_mu2 + C1) * 224 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 225 | (sigma1_sq + sigma2_sq + C2)) 226 | return ssim_map.mean() 227 | 228 | 229 | def calculate_ssim(img1, 230 | img2, 231 | crop_border, 232 | input_order='HWC', 233 | test_y_channel=False): 234 | """Calculate SSIM (structural similarity). 235 | 236 | Ref: 237 | Image quality assessment: From error visibility to structural similarity 238 | 239 | The results are the same as that of the official released MATLAB code in 240 | https://ece.uwaterloo.ca/~z70wang/research/ssim/. 241 | 242 | For three-channel images, SSIM is calculated for each channel and then 243 | averaged. 244 | 245 | Args: 246 | img1 (ndarray): Images with range [0, 255]. 247 | img2 (ndarray): Images with range [0, 255]. 248 | crop_border (int): Cropped pixels in each edge of an image. These 249 | pixels are not involved in the SSIM calculation. 250 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 251 | Default: 'HWC'. 252 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 253 | 254 | Returns: 255 | float: ssim result. 256 | """ 257 | 258 | assert img1.shape == img2.shape, ( 259 | f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 260 | if input_order not in ['HWC', 'CHW']: 261 | raise ValueError( 262 | f'Wrong input_order {input_order}. Supported input_orders are ' 263 | '"HWC" and "CHW"') 264 | 265 | if type(img1) == torch.Tensor: 266 | if len(img1.shape) == 4: 267 | img1 = img1.squeeze(0) 268 | img1 = img1.detach().cpu().numpy().transpose(1,2,0) 269 | if type(img2) == torch.Tensor: 270 | if len(img2.shape) == 4: 271 | img2 = img2.squeeze(0) 272 | img2 = img2.detach().cpu().numpy().transpose(1,2,0) 273 | 274 | img1 = reorder_image(img1, input_order=input_order) 275 | img2 = reorder_image(img2, input_order=input_order) 276 | 277 | img1 = img1.astype(np.float64) 278 | img2 = img2.astype(np.float64) 279 | 280 | if crop_border != 0: 281 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] 282 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 283 | 284 | if test_y_channel: 285 | img1 = to_y_channel(img1) 286 | img2 = to_y_channel(img2) 287 | return _ssim_cly(img1[..., 0], img2[..., 0]) 288 | 289 | 290 | ssims = [] 291 | 292 | max_value = 1 if img1.max() <= 1 else 255 293 | with torch.no_grad(): 294 | final_ssim = _ssim_3d(img1, img2, max_value) 295 | ssims.append(final_ssim) 296 | 297 | return np.array(ssims).mean() 298 | -------------------------------------------------------------------------------- /basicsr/models/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. 
All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import importlib 8 | from os import path as osp 9 | 10 | from basicsr.utils import get_root_logger, scandir 11 | 12 | # automatically scan and import model modules 13 | # scan all the files under the 'models' folder and collect files ending with 14 | # '_model.py' 15 | model_folder = osp.dirname(osp.abspath(__file__)) 16 | model_filenames = [ 17 | osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) 18 | if v.endswith('_model.py') 19 | ] 20 | # import all the model modules 21 | _model_modules = [ 22 | importlib.import_module(f'basicsr.models.{file_name}') 23 | for file_name in model_filenames 24 | ] 25 | 26 | 27 | def create_model(opt): 28 | """Create model. 29 | 30 | Args: 31 | opt (dict): Configuration. It constains: 32 | model_type (str): Model type. 33 | """ 34 | model_type = opt['model_type'] 35 | 36 | # dynamic instantiation 37 | for module in _model_modules: 38 | model_cls = getattr(module, model_type, None) 39 | if model_cls is not None: 40 | break 41 | if model_cls is None: 42 | raise ValueError(f'Model {model_type} is not found.') 43 | 44 | model = model_cls(opt) 45 | 46 | logger = get_root_logger() 47 | logger.info(f'Model [{model.__class__.__name__}] is created.') 48 | return model 49 | -------------------------------------------------------------------------------- /basicsr/models/base_model.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import logging 8 | import os 9 | import torch 10 | from collections import OrderedDict 11 | from copy import deepcopy 12 | from torch.nn.parallel import DataParallel, DistributedDataParallel 13 | 14 | from basicsr.models import lr_scheduler as lr_scheduler 15 | from basicsr.utils.dist_util import master_only 16 | 17 | logger = logging.getLogger('basicsr') 18 | 19 | 20 | class BaseModel(): 21 | """Base model.""" 22 | 23 | def __init__(self, opt): 24 | self.opt = opt 25 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 26 | self.is_train = opt['is_train'] 27 | self.schedulers = [] 28 | self.optimizers = [] 29 | 30 | def feed_data(self, data): 31 | pass 32 | 33 | def optimize_parameters(self): 34 | pass 35 | 36 | def get_current_visuals(self): 37 | pass 38 | 39 | def save(self, epoch, current_iter): 40 | """Save networks and training state.""" 41 | pass 42 | 43 | def validation(self, dataloader, current_iter, tb_logger, save_img=False, rgb2bgr=True, use_image=True): 44 | """Validation function. 45 | 46 | Args: 47 | dataloader (torch.utils.data.DataLoader): Validation dataloader. 48 | current_iter (int): Current iteration. 49 | tb_logger (tensorboard logger): Tensorboard logger. 50 | save_img (bool): Whether to save images. Default: False. 51 | rgb2bgr (bool): Whether to save images using rgb2bgr. 
Default: True 52 | use_image (bool): Whether to use saved images to compute metrics (PSNR, SSIM), if not, then use data directly from network' output. Default: True 53 | """ 54 | if self.opt['dist']: 55 | return self.dist_validation(dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image) 56 | else: 57 | return self.nondist_validation(dataloader, current_iter, tb_logger, 58 | save_img, rgb2bgr, use_image) 59 | 60 | def get_current_log(self): 61 | return self.log_dict 62 | 63 | def model_to_device(self, net): 64 | """Model to device. It also warps models with DistributedDataParallel 65 | or DataParallel. 66 | 67 | Args: 68 | net (nn.Module) 69 | """ 70 | 71 | net = net.to(self.device) 72 | if self.opt['dist']: 73 | find_unused_parameters = True 74 | net = DistributedDataParallel( 75 | net, 76 | device_ids=[torch.cuda.current_device()], 77 | find_unused_parameters=find_unused_parameters) 78 | elif self.opt['num_gpu'] > 1: 79 | net = DataParallel(net) 80 | return net 81 | 82 | def setup_schedulers(self): 83 | """Set up schedulers.""" 84 | train_opt = self.opt['train'] 85 | scheduler_type = train_opt['scheduler'].pop('type') 86 | if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']: 87 | for optimizer in self.optimizers: 88 | self.schedulers.append( 89 | lr_scheduler.MultiStepRestartLR(optimizer, 90 | **train_opt['scheduler'])) 91 | elif scheduler_type == 'CosineAnnealingRestartLR': 92 | for optimizer in self.optimizers: 93 | self.schedulers.append( 94 | lr_scheduler.CosineAnnealingRestartLR( 95 | optimizer, **train_opt['scheduler'])) 96 | elif scheduler_type == 'TrueCosineAnnealingLR': 97 | print('..', 'cosineannealingLR') 98 | for optimizer in self.optimizers: 99 | self.schedulers.append( 100 | torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **train_opt['scheduler'])) 101 | elif scheduler_type == 'LinearLR': 102 | for optimizer in self.optimizers: 103 | self.schedulers.append( 104 | lr_scheduler.LinearLR( 105 | optimizer, train_opt['total_iter'])) 106 | elif scheduler_type == 'VibrateLR': 107 | for optimizer in self.optimizers: 108 | self.schedulers.append( 109 | lr_scheduler.VibrateLR( 110 | optimizer, train_opt['total_iter'])) 111 | else: 112 | raise NotImplementedError( 113 | f'Scheduler {scheduler_type} is not implemented yet.') 114 | 115 | def get_bare_model(self, net): 116 | """Get bare model, especially under wrapping with 117 | DistributedDataParallel or DataParallel. 118 | """ 119 | if isinstance(net, (DataParallel, DistributedDataParallel)): 120 | net = net.module 121 | return net 122 | 123 | @master_only 124 | def print_network(self, net): 125 | """Print the str and parameter number of a network. 126 | 127 | Args: 128 | net (nn.Module) 129 | """ 130 | if isinstance(net, (DataParallel, DistributedDataParallel)): 131 | net_cls_str = (f'{net.__class__.__name__} - ' 132 | f'{net.module.__class__.__name__}') 133 | else: 134 | net_cls_str = f'{net.__class__.__name__}' 135 | 136 | net = self.get_bare_model(net) 137 | net_str = str(net) 138 | net_params = sum(map(lambda x: x.numel(), net.parameters())) 139 | 140 | logger.info( 141 | f'Network: {net_cls_str}, with parameters: {net_params:,d}') 142 | logger.info(net_str) 143 | 144 | def _set_lr(self, lr_groups_l): 145 | """Set learning rate for warmup. 146 | 147 | Args: 148 | lr_groups_l (list): List for lr_groups, each for an optimizer. 
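# A minimal sketch of the options consumed by `setup_schedulers` above (the
# concrete numbers are illustrative assumptions, not values copied from the
# YAML files under options/): the 'type' key is popped and the remaining keys
# are forwarded as keyword arguments to the chosen scheduler class.
train_opt = {
    'scheduler': {'type': 'TrueCosineAnnealingLR',
                  'T_max': 400000, 'eta_min': 1e-7},
    'total_iter': 400000,
}
scheduler_type = train_opt['scheduler'].pop('type')
# 'TrueCosineAnnealingLR' then resolves to
# torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=400000,
#                                            eta_min=1e-7)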
149 | """ 150 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): 151 | for param_group, lr in zip(optimizer.param_groups, lr_groups): 152 | param_group['lr'] = lr 153 | 154 | def _get_init_lr(self): 155 | """Get the initial lr, which is set by the scheduler. 156 | """ 157 | init_lr_groups_l = [] 158 | for optimizer in self.optimizers: 159 | init_lr_groups_l.append( 160 | [v['initial_lr'] for v in optimizer.param_groups]) 161 | return init_lr_groups_l 162 | 163 | def update_learning_rate(self, current_iter, warmup_iter=-1): 164 | """Update learning rate. 165 | 166 | Args: 167 | current_iter (int): Current iteration. 168 | warmup_iter (int): Warmup iter numbers. -1 for no warmup. 169 | Default: -1. 170 | """ 171 | if current_iter > 1: 172 | for scheduler in self.schedulers: 173 | scheduler.step() 174 | # set up warm-up learning rate 175 | if current_iter < warmup_iter: 176 | # get initial lr for each group 177 | init_lr_g_l = self._get_init_lr() 178 | # modify warming-up learning rates 179 | # currently only support linearly warm up 180 | warm_up_lr_l = [] 181 | for init_lr_g in init_lr_g_l: 182 | warm_up_lr_l.append( 183 | [v / warmup_iter * current_iter for v in init_lr_g]) 184 | # set learning rate 185 | self._set_lr(warm_up_lr_l) 186 | 187 | def get_current_learning_rate(self): 188 | return [ 189 | param_group['lr'] 190 | for param_group in self.optimizers[0].param_groups 191 | ] 192 | 193 | @master_only 194 | def save_network(self, net, net_label, current_iter, param_key='params'): 195 | """Save networks. 196 | 197 | Args: 198 | net (nn.Module | list[nn.Module]): Network(s) to be saved. 199 | net_label (str): Network label. 200 | current_iter (int): Current iter number. 201 | param_key (str | list[str]): The parameter key(s) to save network. 202 | Default: 'params'. 203 | """ 204 | if current_iter == -1: 205 | current_iter = 'latest' 206 | save_filename = f'{net_label}_{current_iter}.pth' 207 | save_path = os.path.join(self.opt['path']['models'], save_filename) 208 | 209 | net = net if isinstance(net, list) else [net] 210 | param_key = param_key if isinstance(param_key, list) else [param_key] 211 | assert len(net) == len( 212 | param_key), 'The lengths of net and param_key should be the same.' 213 | 214 | save_dict = {} 215 | for net_, param_key_ in zip(net, param_key): 216 | net_ = self.get_bare_model(net_) 217 | state_dict = net_.state_dict() 218 | for key, param in state_dict.items(): 219 | if key.startswith('module.'): # remove unnecessary 'module.' 220 | key = key[7:] 221 | state_dict[key] = param.cpu() 222 | save_dict[param_key_] = state_dict 223 | 224 | torch.save(save_dict, save_path) 225 | 226 | def _print_different_keys_loading(self, crt_net, load_net, strict=True): 227 | """Print keys with differnet name or different size when loading models. 228 | 229 | 1. Print keys with differnet names. 230 | 2. If strict=False, print the same key but with different tensor size. 231 | It also ignore these keys with different sizes (not load). 232 | 233 | Args: 234 | crt_net (torch model): Current network. 235 | load_net (dict): Loaded network. 236 | strict (bool): Whether strictly loaded. Default: True. 
237 | """ 238 | crt_net = self.get_bare_model(crt_net) 239 | crt_net = crt_net.state_dict() 240 | crt_net_keys = set(crt_net.keys()) 241 | load_net_keys = set(load_net.keys()) 242 | 243 | if crt_net_keys != load_net_keys: 244 | logger.warning('Current net - loaded net:') 245 | for v in sorted(list(crt_net_keys - load_net_keys)): 246 | logger.warning(f' {v}') 247 | logger.warning('Loaded net - current net:') 248 | for v in sorted(list(load_net_keys - crt_net_keys)): 249 | logger.warning(f' {v}') 250 | 251 | # check the size for the same keys 252 | if not strict: 253 | common_keys = crt_net_keys & load_net_keys 254 | for k in common_keys: 255 | if crt_net[k].size() != load_net[k].size(): 256 | logger.warning( 257 | f'Size different, ignore [{k}]: crt_net: ' 258 | f'{crt_net[k].shape}; load_net: {load_net[k].shape}') 259 | load_net[k + '.ignore'] = load_net.pop(k) 260 | 261 | def load_network(self, net, load_path, strict=True, param_key='params'): 262 | """Load network. 263 | 264 | Args: 265 | load_path (str): The path of networks to be loaded. 266 | net (nn.Module): Network. 267 | strict (bool): Whether strictly loaded. 268 | param_key (str): The parameter key of loaded network. If set to 269 | None, use the root 'path'. 270 | Default: 'params'. 271 | """ 272 | net = self.get_bare_model(net) 273 | logger.info( 274 | f'Loading {net.__class__.__name__} model from {load_path}.') 275 | load_net = torch.load( 276 | load_path, map_location=lambda storage, loc: storage) 277 | if param_key is not None: 278 | load_net = load_net[param_key] 279 | print(' load net keys', load_net.keys) 280 | # remove unnecessary 'module.' 281 | for k, v in deepcopy(load_net).items(): 282 | if k.startswith('module.'): 283 | load_net[k[7:]] = v 284 | load_net.pop(k) 285 | self._print_different_keys_loading(net, load_net, strict) 286 | net.load_state_dict(load_net, strict=strict) 287 | 288 | @master_only 289 | def save_training_state(self, epoch, current_iter): 290 | """Save training states during training, which will be used for 291 | resuming. 292 | 293 | Args: 294 | epoch (int): Current epoch. 295 | current_iter (int): Current iteration. 296 | """ 297 | if current_iter != -1: 298 | state = { 299 | 'epoch': epoch, 300 | 'iter': current_iter, 301 | 'optimizers': [], 302 | 'schedulers': [] 303 | } 304 | for o in self.optimizers: 305 | state['optimizers'].append(o.state_dict()) 306 | for s in self.schedulers: 307 | state['schedulers'].append(s.state_dict()) 308 | save_filename = f'{current_iter}.state' 309 | save_path = os.path.join(self.opt['path']['training_states'], 310 | save_filename) 311 | torch.save(state, save_path) 312 | 313 | def resume_training(self, resume_state): 314 | """Reload the optimizers and schedulers for resumed training. 315 | 316 | Args: 317 | resume_state (dict): Resume state. 
318 | """ 319 | resume_optimizers = resume_state['optimizers'] 320 | resume_schedulers = resume_state['schedulers'] 321 | assert len(resume_optimizers) == len( 322 | self.optimizers), 'Wrong lengths of optimizers' 323 | assert len(resume_schedulers) == len( 324 | self.schedulers), 'Wrong lengths of schedulers' 325 | for i, o in enumerate(resume_optimizers): 326 | self.optimizers[i].load_state_dict(o) 327 | for i, s in enumerate(resume_schedulers): 328 | self.schedulers[i].load_state_dict(s) 329 | 330 | def gather_tensors(self, tensor_to_gather): 331 | if tensor_to_gather is None: 332 | return None 333 | group = torch.distributed.group.WORLD 334 | group_size = torch.distributed.get_world_size(group) 335 | gather_t_tensors = [torch.zeros_like(tensor_to_gather) for _ 336 | in range(group_size)] 337 | torch.distributed.all_gather(gather_t_tensors, tensor_to_gather) 338 | return torch.cat(gather_t_tensors, dim=0) 339 | 340 | def reduce_loss_dict(self, loss_dict): 341 | """reduce loss dict. 342 | 343 | In distributed training, it averages the losses among different GPUs. 344 | 345 | Args: 346 | loss_dict (OrderedDict): Loss dict. 347 | """ 348 | with torch.no_grad(): 349 | if self.opt['dist']: 350 | keys = [] 351 | losses = [] 352 | for name, value in loss_dict.items(): 353 | keys.append(name) 354 | losses.append(value) 355 | losses = torch.stack(losses, 0) 356 | torch.distributed.reduce(losses, dst=0) 357 | if self.opt['rank'] == 0: 358 | losses /= self.opt['world_size'] 359 | loss_dict = {key: loss for key, loss in zip(keys, losses)} 360 | 361 | log_dict = OrderedDict() 362 | for name, value in loss_dict.items(): 363 | log_dict[name] = value.mean().item() 364 | 365 | return log_dict 366 | -------------------------------------------------------------------------------- /basicsr/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | from .losses import (L1Loss, MSELoss, PSNRLoss) 8 | 9 | __all__ = [ 10 | 'L1Loss', 'MSELoss', 'PSNRLoss', 11 | ] 12 | -------------------------------------------------------------------------------- /basicsr/models/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import functools 8 | from torch.nn import functional as F 9 | 10 | 11 | def reduce_loss(loss, reduction): 12 | """Reduce loss as specified. 13 | 14 | Args: 15 | loss (Tensor): Elementwise loss tensor. 16 | reduction (str): Options are 'none', 'mean' and 'sum'. 17 | 18 | Returns: 19 | Tensor: Reduced loss tensor. 
20 | """ 21 | reduction_enum = F._Reduction.get_enum(reduction) 22 | # none: 0, elementwise_mean:1, sum: 2 23 | if reduction_enum == 0: 24 | return loss 25 | elif reduction_enum == 1: 26 | return loss.mean() 27 | else: 28 | return loss.sum() 29 | 30 | 31 | def weight_reduce_loss(loss, weight=None, reduction='mean'): 32 | """Apply element-wise weight and reduce loss. 33 | 34 | Args: 35 | loss (Tensor): Element-wise loss. 36 | weight (Tensor): Element-wise weights. Default: None. 37 | reduction (str): Same as built-in losses of PyTorch. Options are 38 | 'none', 'mean' and 'sum'. Default: 'mean'. 39 | 40 | Returns: 41 | Tensor: Loss values. 42 | """ 43 | # if weight is specified, apply element-wise weight 44 | if weight is not None: 45 | assert weight.dim() == loss.dim() 46 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 47 | loss = loss * weight 48 | 49 | # if weight is not specified or reduction is sum, just reduce the loss 50 | if weight is None or reduction == 'sum': 51 | loss = reduce_loss(loss, reduction) 52 | # if reduction is mean, then compute mean over weight region 53 | elif reduction == 'mean': 54 | if weight.size(1) > 1: 55 | weight = weight.sum() 56 | else: 57 | weight = weight.sum() * loss.size(1) 58 | loss = loss.sum() / weight 59 | 60 | return loss 61 | 62 | 63 | def weighted_loss(loss_func): 64 | """Create a weighted version of a given loss function. 65 | 66 | To use this decorator, the loss function must have the signature like 67 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 68 | element-wise loss without any reduction. This decorator will add weight 69 | and reduction arguments to the function. The decorated function will have 70 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 71 | **kwargs)`. 72 | 73 | :Example: 74 | 75 | >>> import torch 76 | >>> @weighted_loss 77 | >>> def l1_loss(pred, target): 78 | >>> return (pred - target).abs() 79 | 80 | >>> pred = torch.Tensor([0, 2, 3]) 81 | >>> target = torch.Tensor([1, 1, 1]) 82 | >>> weight = torch.Tensor([1, 0, 1]) 83 | 84 | >>> l1_loss(pred, target) 85 | tensor(1.3333) 86 | >>> l1_loss(pred, target, weight) 87 | tensor(1.5000) 88 | >>> l1_loss(pred, target, reduction='none') 89 | tensor([1., 1., 2.]) 90 | >>> l1_loss(pred, target, weight, reduction='sum') 91 | tensor(3.) 92 | """ 93 | 94 | @functools.wraps(loss_func) 95 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs): 96 | # get element-wise loss 97 | loss = loss_func(pred, target, **kwargs) 98 | loss = weight_reduce_loss(loss, weight, reduction) 99 | return loss 100 | 101 | return wrapper 102 | -------------------------------------------------------------------------------- /basicsr/models/losses/losses.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import torch 8 | from torch import nn as nn 9 | from torch.nn import functional as F 10 | import numpy as np 11 | 12 | from basicsr.models.losses.loss_util import weighted_loss 13 | 14 | _reduction_modes = ['none', 'mean', 'sum'] 15 | 16 | 17 | @weighted_loss 18 | def l1_loss(pred, target): 19 | return F.l1_loss(pred, target, reduction='none') 20 | 21 | 22 | @weighted_loss 23 | def mse_loss(pred, target): 24 | return F.mse_loss(pred, target, reduction='none') 25 | 26 | 27 | # @weighted_loss 28 | # def charbonnier_loss(pred, target, eps=1e-12): 29 | # return torch.sqrt((pred - target)**2 + eps) 30 | 31 | 32 | class L1Loss(nn.Module): 33 | """L1 (mean absolute error, MAE) loss. 34 | 35 | Args: 36 | loss_weight (float): Loss weight for L1 loss. Default: 1.0. 37 | reduction (str): Specifies the reduction to apply to the output. 38 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 39 | """ 40 | 41 | def __init__(self, loss_weight=1.0, reduction='mean'): 42 | super(L1Loss, self).__init__() 43 | if reduction not in ['none', 'mean', 'sum']: 44 | raise ValueError(f'Unsupported reduction mode: {reduction}. ' 45 | f'Supported ones are: {_reduction_modes}') 46 | 47 | self.loss_weight = loss_weight 48 | self.reduction = reduction 49 | 50 | def forward(self, pred, target, weight=None, **kwargs): 51 | """ 52 | Args: 53 | pred (Tensor): of shape (N, C, H, W). Predicted tensor. 54 | target (Tensor): of shape (N, C, H, W). Ground truth tensor. 55 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise 56 | weights. Default: None. 57 | """ 58 | return self.loss_weight * l1_loss( 59 | pred, target, weight, reduction=self.reduction) 60 | 61 | class MSELoss(nn.Module): 62 | """MSE (L2) loss. 63 | 64 | Args: 65 | loss_weight (float): Loss weight for MSE loss. Default: 1.0. 66 | reduction (str): Specifies the reduction to apply to the output. 67 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 68 | """ 69 | 70 | def __init__(self, loss_weight=1.0, reduction='mean'): 71 | super(MSELoss, self).__init__() 72 | if reduction not in ['none', 'mean', 'sum']: 73 | raise ValueError(f'Unsupported reduction mode: {reduction}. ' 74 | f'Supported ones are: {_reduction_modes}') 75 | 76 | self.loss_weight = loss_weight 77 | self.reduction = reduction 78 | 79 | def forward(self, pred, target, weight=None, **kwargs): 80 | """ 81 | Args: 82 | pred (Tensor): of shape (N, C, H, W). Predicted tensor. 83 | target (Tensor): of shape (N, C, H, W). Ground truth tensor. 84 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise 85 | weights. Default: None. 
86 | """ 87 | return self.loss_weight * mse_loss( 88 | pred, target, weight, reduction=self.reduction) 89 | 90 | class PSNRLoss(nn.Module): 91 | 92 | def __init__(self, loss_weight=1.0, reduction='mean', toY=False): 93 | super(PSNRLoss, self).__init__() 94 | assert reduction == 'mean' 95 | self.loss_weight = loss_weight 96 | self.scale = 10 / np.log(10) 97 | self.toY = toY 98 | self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1) 99 | self.first = True 100 | 101 | def forward(self, pred, target): 102 | assert len(pred.size()) == 4 103 | if self.toY: 104 | if self.first: 105 | self.coef = self.coef.to(pred.device) 106 | self.first = False 107 | 108 | pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. 109 | target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. 110 | 111 | pred, target = pred / 255., target / 255. 112 | pass 113 | assert len(pred.size()) == 4 114 | 115 | return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean() 116 | 117 | -------------------------------------------------------------------------------- /basicsr/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import math 8 | from collections import Counter 9 | from torch.optim.lr_scheduler import _LRScheduler 10 | 11 | 12 | class MultiStepRestartLR(_LRScheduler): 13 | """ MultiStep with restarts learning rate scheme. 14 | 15 | Args: 16 | optimizer (torch.nn.optimizer): Torch optimizer. 17 | milestones (list): Iterations that will decrease learning rate. 18 | gamma (float): Decrease ratio. Default: 0.1. 19 | restarts (list): Restart iterations. Default: [0]. 20 | restart_weights (list): Restart weights at each restart iteration. 21 | Default: [1]. 22 | last_epoch (int): Used in _LRScheduler. Default: -1. 23 | """ 24 | 25 | def __init__(self, 26 | optimizer, 27 | milestones, 28 | gamma=0.1, 29 | restarts=(0, ), 30 | restart_weights=(1, ), 31 | last_epoch=-1): 32 | self.milestones = Counter(milestones) 33 | self.gamma = gamma 34 | self.restarts = restarts 35 | self.restart_weights = restart_weights 36 | assert len(self.restarts) == len( 37 | self.restart_weights), 'restarts and their weights do not match.' 38 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 39 | 40 | def get_lr(self): 41 | if self.last_epoch in self.restarts: 42 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 43 | return [ 44 | group['initial_lr'] * weight 45 | for group in self.optimizer.param_groups 46 | ] 47 | if self.last_epoch not in self.milestones: 48 | return [group['lr'] for group in self.optimizer.param_groups] 49 | return [ 50 | group['lr'] * self.gamma**self.milestones[self.last_epoch] 51 | for group in self.optimizer.param_groups 52 | ] 53 | 54 | class LinearLR(_LRScheduler): 55 | """ 56 | 57 | Args: 58 | optimizer (torch.nn.optimizer): Torch optimizer. 59 | milestones (list): Iterations that will decrease learning rate. 60 | gamma (float): Decrease ratio. Default: 0.1. 61 | last_epoch (int): Used in _LRScheduler. Default: -1. 
62 | """ 63 | 64 | def __init__(self, 65 | optimizer, 66 | total_iter, 67 | last_epoch=-1): 68 | self.total_iter = total_iter 69 | super(LinearLR, self).__init__(optimizer, last_epoch) 70 | 71 | def get_lr(self): 72 | process = self.last_epoch / self.total_iter 73 | weight = (1 - process) 74 | # print('get lr ', [weight * group['initial_lr'] for group in self.optimizer.param_groups]) 75 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups] 76 | 77 | class VibrateLR(_LRScheduler): 78 | """ 79 | 80 | Args: 81 | optimizer (torch.nn.optimizer): Torch optimizer. 82 | milestones (list): Iterations that will decrease learning rate. 83 | gamma (float): Decrease ratio. Default: 0.1. 84 | last_epoch (int): Used in _LRScheduler. Default: -1. 85 | """ 86 | 87 | def __init__(self, 88 | optimizer, 89 | total_iter, 90 | last_epoch=-1): 91 | self.total_iter = total_iter 92 | super(VibrateLR, self).__init__(optimizer, last_epoch) 93 | 94 | def get_lr(self): 95 | process = self.last_epoch / self.total_iter 96 | 97 | f = 0.1 98 | if process < 3 / 8: 99 | f = 1 - process * 8 / 3 100 | elif process < 5 / 8: 101 | f = 0.2 102 | 103 | T = self.total_iter // 80 104 | Th = T // 2 105 | 106 | t = self.last_epoch % T 107 | 108 | f2 = t / Th 109 | if t >= Th: 110 | f2 = 2 - f2 111 | 112 | weight = f * f2 113 | 114 | if self.last_epoch < Th: 115 | weight = max(0.1, weight) 116 | 117 | # print('f {}, T {}, Th {}, t {}, f2 {}'.format(f, T, Th, t, f2)) 118 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups] 119 | 120 | def get_position_from_periods(iteration, cumulative_period): 121 | """Get the position from a period list. 122 | 123 | It will return the index of the right-closest number in the period list. 124 | For example, the cumulative_period = [100, 200, 300, 400], 125 | if iteration == 50, return 0; 126 | if iteration == 210, return 2; 127 | if iteration == 300, return 2. 128 | 129 | Args: 130 | iteration (int): Current iteration. 131 | cumulative_period (list[int]): Cumulative period list. 132 | 133 | Returns: 134 | int: The position of the right-closest number in the period list. 135 | """ 136 | for i, period in enumerate(cumulative_period): 137 | if iteration <= period: 138 | return i 139 | 140 | 141 | class CosineAnnealingRestartLR(_LRScheduler): 142 | """ Cosine annealing with restarts learning rate scheme. 143 | 144 | An example of config: 145 | periods = [10, 10, 10, 10] 146 | restart_weights = [1, 0.5, 0.5, 0.5] 147 | eta_min=1e-7 148 | 149 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the 150 | scheduler will restart with the weights in restart_weights. 151 | 152 | Args: 153 | optimizer (torch.nn.optimizer): Torch optimizer. 154 | periods (list): Period for each cosine anneling cycle. 155 | restart_weights (list): Restart weights at each restart iteration. 156 | Default: [1]. 157 | eta_min (float): The mimimum lr. Default: 0. 158 | last_epoch (int): Used in _LRScheduler. Default: -1. 159 | """ 160 | 161 | def __init__(self, 162 | optimizer, 163 | periods, 164 | restart_weights=(1, ), 165 | eta_min=0, 166 | last_epoch=-1): 167 | self.periods = periods 168 | self.restart_weights = restart_weights 169 | self.eta_min = eta_min 170 | assert (len(self.periods) == len(self.restart_weights) 171 | ), 'periods and restart_weights should have the same length.' 
172 | self.cumulative_period = [ 173 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) 174 | ] 175 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) 176 | 177 | def get_lr(self): 178 | idx = get_position_from_periods(self.last_epoch, 179 | self.cumulative_period) 180 | current_weight = self.restart_weights[idx] 181 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 182 | current_period = self.periods[idx] 183 | 184 | return [ 185 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * 186 | (1 + math.cos(math.pi * ( 187 | (self.last_epoch - nearest_restart) / current_period))) 188 | for base_lr in self.base_lrs 189 | ] 190 | -------------------------------------------------------------------------------- /basicsr/models/video_restoration_model.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch 3 | from collections import OrderedDict 4 | from copy import deepcopy 5 | from os import path as osp 6 | from tqdm import tqdm 7 | from torch.nn.parallel import DataParallel, DistributedDataParallel 8 | from basicsr.utils.dist_util import get_dist_info 9 | from basicsr.models.base_model import BaseModel 10 | from basicsr.utils import get_root_logger, imwrite, tensor2img 11 | from importlib import import_module 12 | import basicsr.loss as loss 13 | import numpy as np 14 | import matplotlib.pyplot as plt 15 | 16 | import json 17 | 18 | def create_video_model(opt): 19 | module = import_module('basicsr.models.archs.' + opt['model'].lower()) 20 | model = module.make_model(opt) 21 | return model 22 | 23 | metric_module = importlib.import_module('basicsr.metrics') 24 | 25 | class VideoRestorationModel(BaseModel): 26 | def __init__(self, opt): 27 | super(VideoRestorationModel, self).__init__(opt) 28 | self.net_g = create_video_model(opt) 29 | self.net_g = self.model_to_device(self.net_g) 30 | self.n_sequence = opt['n_sequence'] 31 | load_path = self.opt['path'].get('pretrain_network_g', None) 32 | if load_path is not None: 33 | self.load_network(self.net_g, load_path, 34 | self.opt['path'].get('strict_load_g', True), param_key=self.opt['path'].get('param_key', 'params')) 35 | print("load_model", load_path) 36 | if self.is_train: 37 | self.init_training_settings() 38 | self.loss = loss.L1BaseLoss() 39 | self.scaler = torch.cuda.amp.GradScaler() 40 | 41 | def init_training_settings(self): 42 | self.net_g.train() 43 | # set up optimizers and schedulers 44 | self.setup_optimizers() 45 | self.setup_schedulers() 46 | 47 | def model_to_device(self, net): 48 | net = net.to(self.device) 49 | if self.opt['dist']: 50 | net = DistributedDataParallel( 51 | net, 52 | device_ids=[torch.cuda.current_device()], 53 | find_unused_parameters=False) 54 | elif self.opt['num_gpu'] > 1: 55 | net = DataParallel(net) 56 | return net 57 | 58 | def setup_optimizers(self): 59 | train_opt = self.opt['train'] 60 | optim_params = [] 61 | for k, v in self.net_g.named_parameters(): 62 | if v.requires_grad: 63 | optim_params.append(v) 64 | else: 65 | logger = get_root_logger() 66 | logger.warning(f'Params {k} will not be optimized.') 67 | train_opt['optim_g'].pop('type') 68 | self.optimizer_g = torch.optim.AdamW([{'params': optim_params}], 69 | **train_opt['optim_g']) 70 | self.optimizers.append(self.optimizer_g) 71 | 72 | # method to feed the data to the model. 
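    # A minimal sketch of what feed_data expects (shapes illustrative, not
    # taken from any specific config):
    #   lq, gt, _, _ = data      # lq / gt: (N, T, C, H, W) video clips
    #   lq -> device, cast to half precision for the autocast training loop
    #   gt -> device, kept in full precision for the L1-based loss
    # The remaining items in the batch tuple are ignored here.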
73 | def feed_data(self, data): 74 | lq, gt, _, _ = data 75 | self.lq = lq.to(self.device).half() 76 | self.gt = gt.to(self.device) 77 | 78 | def optimize_parameters(self, current_iter): 79 | self.optimizer_g.zero_grad() 80 | with torch.cuda.amp.autocast(): 81 | loss_dict = OrderedDict() 82 | loss_dict['l_pix'] = 0 83 | 84 | frame_num = self.lq.shape[1] 85 | k_cache, v_cache = None, None 86 | for j in range(frame_num): 87 | target_g_images = self.gt[:, j, :, :, :] 88 | current_input = self.lq[:, j,:, :, :].unsqueeze(1) 89 | pre_input = self.lq[:, j if j == 0 else j-1, :, :, :].unsqueeze(1) 90 | 91 | input = torch.concat([pre_input, current_input], dim=1) 92 | (out_g, k_cache, v_cache) = self.net_g(input, k_cache, v_cache) 93 | 94 | l_pix = self.loss(out_g, target_g_images) 95 | loss_dict['l_pix'] += l_pix 96 | 97 | # normalize w.r.t. total frames seen. 98 | loss_dict['l_pix'] /= frame_num 99 | l_total = loss_dict['l_pix'] + 0 * sum(p.sum() for p in self.net_g.parameters()) 100 | loss_dict['l_pix'] = loss_dict['l_pix'] 101 | 102 | self.scaler.scale(l_total).backward() 103 | self.scaler.unscale_(self.optimizer_g) 104 | # do gradient clipping to avoid larger updates. 105 | # torch.nn.utils.clip_grad_norm_(self.net_g.parameters(), 0.01) 106 | self.scaler.step(self.optimizer_g) 107 | self.scaler.update() 108 | self.log_dict = self.reduce_loss_dict(loss_dict) 109 | 110 | def test(self): 111 | self.net_g.eval() 112 | with torch.no_grad(): 113 | self.outputs_list = [] 114 | self.gt_lists = [] 115 | self.lq_lists = [] 116 | frame_num = self.lq.shape[1] 117 | k_cache, v_cache = None, None 118 | for j in range(frame_num): 119 | target_g_images = self.gt[:, j, :, :, :] 120 | current_input = self.lq[:, j,:, :, :].unsqueeze(1) 121 | pre_input = self.lq[:, j if j == 0 else j-1, :, :, :].unsqueeze(1) 122 | input = torch.concat([pre_input, current_input], dim=1) 123 | out_g, k_cache, v_cache = self.net_g(input.float(), 124 | k_cache, 125 | v_cache) 126 | self.outputs_list.append(out_g) 127 | self.gt_lists.append(target_g_images) 128 | self.lq_lists.append(self.lq[:, j,:, :, :]) 129 | self.net_g.train() 130 | 131 | def non_cached_test(self): 132 | # proxy to the actual scores to save time. 
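        # Unlike test(), which streams (previous, current) frame pairs and
        # carries the k/v caches across frames, this pass feeds the whole clip
        # once with empty caches, so its scores only approximate the cached,
        # frame-by-frame results.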
133 | self.net_g.eval() 134 | with torch.no_grad(): 135 | k_cache, v_cache = None, None 136 | pred, _, _, _ = self.net_g(self.lq.float(), k_cache, v_cache) 137 | if isinstance(pred, list): 138 | pred = pred[-1] 139 | self.output = pred 140 | self.net_g.train() 141 | 142 | def dist_validation(self, dataloader, current_iter, tb_logger, save_img, rgb2bgr, use_image): 143 | logger = get_root_logger() 144 | import os 145 | return self.nondist_validation(dataloader, current_iter, 146 | tb_logger, save_img, 147 | rgb2bgr, use_image) 148 | 149 | def nondist_validation(self, dataloader, current_iter, tb_logger, 150 | save_img, rgb2bgr, use_image): 151 | with_metrics = self.opt['val'].get('metrics') is not None 152 | if with_metrics: 153 | self.metric_results = { 154 | metric: 0 155 | for metric in self.opt['val']['metrics'].keys() 156 | } 157 | rank, world_size = get_dist_info() 158 | if rank == 0: 159 | pbar = tqdm(total=len(dataloader), unit='image') 160 | cnt = 0 161 | 162 | for idx, val_data in enumerate(dataloader): 163 | if idx % world_size != rank: 164 | continue 165 | 166 | folder_name, img_name = val_data[len(val_data)-1][0][0].split('.') 167 | self.feed_data(val_data) 168 | self.test() 169 | 170 | for temp_i in range(len(self.outputs_list)): 171 | sr_img = tensor2img(self.outputs_list[temp_i], rgb2bgr=rgb2bgr) 172 | gt_img = tensor2img(self.gt_lists[temp_i], rgb2bgr=rgb2bgr) 173 | lq_img = tensor2img(self.lq_lists[temp_i], rgb2bgr=rgb2bgr) 174 | 175 | if save_img: 176 | # if self.opt['is_train']: 177 | save_img_path = osp.join(self.opt['path']['visualization'], 178 | folder_name, 179 | f'{img_name}_frame{temp_i}_res.png') 180 | 181 | save_gt_img_path = osp.join(self.opt['path']['visualization'], 182 | folder_name, 183 | f'{img_name}_frame{temp_i}_gt.png') 184 | 185 | save_lq_img_path = osp.join(self.opt['path']['visualization'], 186 | folder_name, 187 | f'{img_name}_frame{temp_i}_lq.png') 188 | 189 | imwrite(sr_img, save_img_path) 190 | imwrite(gt_img, save_gt_img_path) 191 | imwrite(lq_img, save_lq_img_path) 192 | 193 | if with_metrics: 194 | # calculate metrics 195 | opt_metric = deepcopy(self.opt['val']['metrics']) 196 | if use_image: 197 | for name, opt_ in opt_metric.items(): 198 | metric_type = opt_.pop('type') 199 | self.metric_results[name] += getattr( 200 | metric_module, metric_type)(sr_img, gt_img, **opt_) 201 | else: 202 | for name, opt_ in opt_metric.items(): 203 | metric_type = opt_.pop('type') 204 | self.metric_results[name] += getattr( 205 | metric_module, metric_type)(self.outputs_list[temp_i], self.gt_lists[temp_i], **opt_) 206 | 207 | cnt += 1 208 | if rank == 0: 209 | for _ in range(world_size): 210 | pbar.update(1) 211 | pbar.set_description(f'Test {img_name}') 212 | 213 | if rank == 0: 214 | pbar.close() 215 | 216 | current_metric = 0. 
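        # Average each accumulated metric over the number of evaluated frames;
        # the value of the last metric in the dict is returned as current_metric.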
217 | if with_metrics: 218 | for metric in self.metric_results.keys(): 219 | self.metric_results[metric] /= cnt 220 | current_metric = self.metric_results[metric] 221 | 222 | self._log_validation_metric_values(current_iter, 223 | tb_logger) 224 | return current_metric 225 | 226 | 227 | def _log_validation_metric_values(self, current_iter, tb_logger): 228 | log_str = f'Validation,\t' 229 | for metric, value in self.metric_results.items(): 230 | log_str += f'\t # {metric}: {value:.4f}' 231 | logger = get_root_logger() 232 | logger.info(log_str) 233 | if tb_logger: 234 | for metric, value in self.metric_results.items(): 235 | tb_logger.add_scalar(f'metrics/{metric}', value, current_iter) 236 | 237 | def get_current_visuals(self): 238 | # pick the current frame. 239 | out_dict = OrderedDict() 240 | out_dict['lq'] = self.lq[:,1,:,:,:].detach().cpu() 241 | out_dict['result'] = self.output.detach().cpu() 242 | if hasattr(self, 'gt'): 243 | out_dict['gt'] = self.gt[:,1,:,:,:].detach().cpu() 244 | return out_dict 245 | 246 | def save(self, epoch, current_iter): 247 | self.save_network(self.net_g, 'net_g', current_iter) 248 | self.save_training_state(epoch, current_iter) -------------------------------------------------------------------------------- /basicsr/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import logging 4 | import math 5 | import random 6 | import time 7 | import torch 8 | import sys 9 | from pathlib import Path 10 | sys.path.append(str(Path(__file__).parents[1])) 11 | from os import path as osp 12 | 13 | from basicsr.data import create_dataloader 14 | from basicsr.data.data_sampler import EnlargedSampler 15 | from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher 16 | from basicsr.models import create_model 17 | from basicsr.utils import (MessageLogger, check_resume, get_env_info, 18 | get_root_logger, get_time_str, init_tb_logger, 19 | init_wandb_logger, make_exp_dirs, mkdir_and_rename, 20 | set_random_seed) 21 | from basicsr.utils.dist_util import get_dist_info, init_dist 22 | from basicsr.utils.options import dict2str, parse 23 | 24 | # for superresolution uncomment this line, and comment line # 26. 25 | # from basicsr.data.video_super_image_dataset import VideoSuperImageDataset as VideoImageDataset 26 | 27 | # for deblurring/deraining/etc. comment the line above, and uncomment the next line. 28 | from basicsr.data.video_image_dataset import VideoImageDataset 29 | import torch.distributed as dist 30 | 31 | import os 32 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 33 | def parse_options(is_train=True): 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument( 36 | '-opt', type=str, required=True, help='Path to option YAML file.') 37 | parser.add_argument( 38 | '--launcher', 39 | choices=['none', 'pytorch', 'slurm'], 40 | default='none', 41 | help='job launcher') 42 | parser.add_argument('--local_rank', type=int, default=0) 43 | args, unknown = parser.parse_known_args() 44 | opt = parse(args.opt, is_train=is_train) 45 | 46 | # distributed settings 47 | if args.launcher == 'none': 48 | opt['dist'] = False 49 | print('Disable distributed.', flush=True) 50 | else: 51 | opt['dist'] = True 52 | # increase timeout to 1.5 hours 53 | opt['dist_params']['timeout'] = datetime.timedelta(seconds=5400) 54 | if args.launcher == 'slurm' and 'dist_params' in opt: 55 | init_dist(args.launcher, **opt['dist_params']) 56 | else: 57 | init_dist(args.launcher) 58 | print('init dist .. 
', args.launcher) 59 | 60 | opt['rank'], opt['world_size'] = get_dist_info() 61 | 62 | # random seed 63 | seed = opt.get('manual_seed') 64 | if seed is None: 65 | seed = random.randint(1, 10000) 66 | opt['manual_seed'] = seed 67 | set_random_seed(seed + opt['rank']) 68 | torch.manual_seed(seed+opt['rank']) 69 | 70 | return opt 71 | 72 | 73 | def init_loggers(opt): 74 | log_file = osp.join(opt['path']['log'], 75 | f"train_{opt['name']}_{get_time_str()}.log") 76 | logger = get_root_logger( 77 | logger_name='basicsr', log_level=logging.INFO, log_file=log_file) 78 | logger.info(get_env_info()) 79 | logger.info(dict2str(opt)) 80 | 81 | # initialize wandb logger before tensorboard logger to allow proper sync: 82 | if (opt['logger'].get('wandb') 83 | is not None) and (opt['logger']['wandb'].get('project') 84 | is not None) and ('debug' not in opt['name']): 85 | assert opt['logger'].get('use_tb_logger') is True, ( 86 | 'should turn on tensorboard when using wandb') 87 | init_wandb_logger(opt) 88 | tb_logger = None 89 | if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']: 90 | tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name'])) 91 | return logger, tb_logger 92 | 93 | 94 | def create_train_val_dataloader(opt, logger): 95 | # create train and val dataloaders 96 | train_loader, val_loader = None, None 97 | for phase, dataset_opt in opt['datasets'].items(): 98 | if phase == 'train': 99 | dataset_enlarge_ratio = 1 100 | train_set = VideoImageDataset(opt, phase) 101 | train_sampler = EnlargedSampler(train_set, opt['world_size'], 102 | opt['rank'], dataset_enlarge_ratio) 103 | train_loader = create_dataloader( 104 | train_set, 105 | dataset_opt, 106 | num_gpu=opt['num_gpu'], 107 | dist=opt['dist'], 108 | sampler=train_sampler, 109 | seed=opt['manual_seed']) 110 | 111 | num_iter_per_epoch = math.ceil( 112 | len(train_set) * dataset_enlarge_ratio / 113 | (dataset_opt['batch_size_per_gpu'] * opt['world_size'])) 114 | total_iters = int(opt['train']['total_iter']) 115 | total_epochs = math.ceil(total_iters / (num_iter_per_epoch)) 116 | logger.info( 117 | 'Training statistics:' 118 | f'\n\tNumber of train images: {len(train_set)}' 119 | f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}' 120 | f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}' 121 | f'\n\tWorld size (gpu number): {opt["world_size"]}' 122 | f'\n\tRequire iter number per epoch: {num_iter_per_epoch}' 123 | f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.') 124 | elif phase == 'val': 125 | val_set = VideoImageDataset(opt, phase) 126 | val_loader = create_dataloader( 127 | val_set, 128 | dataset_opt, 129 | num_gpu=opt['num_gpu'], 130 | dist=opt['dist'], 131 | sampler=None, 132 | seed=opt['manual_seed']) 133 | logger.info( 134 | f'Number of val images/folders in dataset: ' 135 | f'{len(val_set)}') 136 | else: 137 | raise ValueError(f'Dataset phase {phase} is not recognized') 138 | 139 | return train_loader, train_sampler, val_loader, total_epochs, total_iters 140 | 141 | def main(): 142 | # parse options, set distributed setting, set ramdom seed 143 | opt = parse_options(is_train=True) 144 | 145 | torch.backends.cudnn.benchmark = True 146 | # torch.backends.cudnn.deterministic = True 147 | state_folder_path = 'experiments/{}/training_states/'.format(opt['name']) 148 | import os 149 | try: 150 | states = os.listdir(state_folder_path) 151 | except: 152 | states = [] 153 | 154 | resume_state = None 155 | if len(states) > 0: 156 | max_state_file = '{}.state'.format(max([int(x[0:-6]) for x in 
states])) 157 | resume_state = os.path.join(state_folder_path, max_state_file) 158 | opt['path']['resume_state'] = resume_state 159 | 160 | # load resume states if necessary 161 | if opt['path'].get('resume_state'): 162 | device_id = torch.cuda.current_device() 163 | resume_state = torch.load( 164 | opt['path']['resume_state'], 165 | map_location=lambda storage, loc: storage.cuda(device_id)) 166 | else: 167 | resume_state = None 168 | 169 | # mkdir for experiments and logger 170 | if resume_state is None: 171 | make_exp_dirs(opt) 172 | if opt['logger'].get('use_tb_logger') and 'debug' not in opt[ 173 | 'name'] and opt['rank'] == 0: 174 | mkdir_and_rename(osp.join('tb_logger', opt['name'])) 175 | 176 | # initialize loggers 177 | logger, tb_logger = init_loggers(opt) 178 | 179 | # create train and validation dataloaders 180 | result = create_train_val_dataloader(opt, logger) 181 | train_loader, train_sampler, val_loader, total_epochs, total_iters = result 182 | print("Len --- train_loader", len(train_loader)) 183 | if resume_state: # resume training 184 | print("resuming is True") 185 | check_resume(opt, resume_state['iter']) 186 | model = create_model(opt) 187 | model.resume_training(resume_state) # handle optimizers and scheduler 188 | logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " 189 | f"iter: {resume_state['iter']}.") 190 | start_epoch = resume_state['epoch'] 191 | current_iter = resume_state['iter'] 192 | del resume_state 193 | torch.cuda.empty_cache() 194 | else: 195 | model = create_model(opt) 196 | start_epoch = 0 197 | current_iter = 0 198 | 199 | # create message logger (formatted outputs) 200 | msg_logger = MessageLogger(opt, current_iter, tb_logger) 201 | 202 | # dataloader prefetcher 203 | prefetch_mode = opt['datasets']['train'].get('prefetch_mode') 204 | if prefetch_mode is None or prefetch_mode == 'cpu': 205 | prefetcher = CPUPrefetcher(train_loader) 206 | elif prefetch_mode == 'cuda': 207 | prefetcher = CUDAPrefetcher(train_loader, opt) 208 | logger.info(f'Use {prefetch_mode} prefetch dataloader') 209 | if opt['datasets']['train'].get('pin_memory') is not True: 210 | raise ValueError('Please set pin_memory=True for CUDAPrefetcher.') 211 | else: 212 | raise ValueError(f'Wrong prefetch_mode {prefetch_mode}.' 
213 | "Supported ones are: None, 'cuda', 'cpu'.") 214 | 215 | # training 216 | logger.info( 217 | f'Start training from epoch: {start_epoch}, iter: {current_iter}') 218 | data_time, iter_time = time.time(), time.time() 219 | start_time = time.time() 220 | 221 | epoch = start_epoch 222 | while current_iter <= total_iters: 223 | train_sampler.set_epoch(epoch) 224 | prefetcher.reset() 225 | train_data = prefetcher.next() 226 | 227 | while train_data is not None: 228 | data_time = time.time() - data_time 229 | 230 | current_iter += 1 231 | if current_iter > total_iters: 232 | break 233 | # update learning rate 234 | model.update_learning_rate( 235 | current_iter, warmup_iter=opt['train'].get('warmup_iter', -1)) 236 | # training 237 | model.feed_data(train_data) 238 | model.optimize_parameters(current_iter) 239 | iter_time = time.time() - iter_time 240 | # log 241 | if dist.get_rank() == 0 and current_iter % opt['logger']['print_freq'] == 0: 242 | log_vars = {'epoch': epoch, 'iter': current_iter} 243 | log_vars.update({'lrs': model.get_current_learning_rate()}) 244 | log_vars.update({'time': iter_time, 'data_time': data_time}) 245 | log_vars.update(model.get_current_log()) 246 | print("Loss at iteration", current_iter, model.get_current_log()['l_pix']) 247 | msg_logger(log_vars) 248 | 249 | # save models and training states 250 | if dist.get_rank() == 0 and current_iter % opt['logger']['save_checkpoint_freq'] == 0: 251 | logger.info('Saving models and training states.') 252 | print("saving") 253 | model.save(epoch, current_iter) 254 | 255 | # validation 256 | if opt.get('val') is not None and (current_iter % opt['val']['val_freq'] == 0): 257 | rgb2bgr = opt['val'].get('rgb2bgr', True) 258 | use_image = opt['val'].get('use_image', True) 259 | model.validation(val_loader, current_iter, tb_logger, 260 | opt['val']['save_img'], 261 | rgb2bgr, use_image) 262 | log_vars = {'epoch': epoch, 'iter': current_iter, 'total_iter': total_iters} 263 | log_vars.update({'lrs': model.get_current_learning_rate()}) 264 | log_vars.update(model.get_current_log()) 265 | msg_logger(log_vars) 266 | 267 | data_time = time.time() 268 | iter_time = time.time() 269 | train_data = prefetcher.next() 270 | 271 | if dist.get_rank() == 0: 272 | print(f"Loss at the end of Epoch {epoch} is {model.get_current_log()['l_pix']}.") 273 | # end of iter 274 | epoch += 1 275 | 276 | # end of epoch 277 | consumed_time = str( 278 | datetime.timedelta(seconds=int(time.time() - start_time))) 279 | logger.info(f'End of training. 
Time consumed: {consumed_time}') 280 | logger.info('Save the latest model.') 281 | model.save(epoch=-1, current_iter=-1) # -1 stands for the latest 282 | if opt.get('val') is not None: 283 | rgb2bgr = opt['val'].get('rgb2bgr', True) 284 | use_image = opt['val'].get('use_image', True) 285 | model.validation(val_loader, current_iter, tb_logger, 286 | opt['val']['save_img'], 287 | rgb2bgr, use_image) 288 | if tb_logger: 289 | tb_logger.close() 290 | 291 | 292 | if __name__ == '__main__': 293 | main() 294 | -------------------------------------------------------------------------------- /basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 3 | # Copyright 2018-2020 BasicSR Authors 4 | # ------------------------------------------------------------------------ 5 | from .file_client import FileClient 6 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding, pltimwrite 7 | from .logger import (MessageLogger, get_env_info, get_root_logger, 8 | init_tb_logger, init_wandb_logger) 9 | from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, 10 | scandir, scandir_SIDD, set_random_seed, sizeof_fmt) 11 | from .create_lmdb import (create_lmdb_for_reds, create_lmdb_for_gopro, create_lmdb_for_rain13k) 12 | 13 | __all__ = [ 14 | # file_client.py 15 | 'FileClient', 16 | # img_util.py 17 | 'img2tensor', 18 | 'tensor2img', 19 | 'imfrombytes', 20 | 'imwrite', 21 | 'pltimwrite', 22 | 'crop_border', 23 | # logger.py 24 | 'MessageLogger', 25 | 'init_tb_logger', 26 | 'init_wandb_logger', 27 | 'get_root_logger', 28 | 'get_env_info', 29 | # misc.py 30 | 'set_random_seed', 31 | 'get_time_str', 32 | 'mkdir_and_rename', 33 | 'make_exp_dirs', 34 | 'scandir', 35 | 'check_resume', 36 | 'sizeof_fmt', 37 | 'padding', 38 | 'create_lmdb_for_reds', 39 | 'create_lmdb_for_gopro', 40 | 'create_lmdb_for_rain13k', 41 | ] 42 | -------------------------------------------------------------------------------- /basicsr/utils/create_lmdb.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 3 | # Copyright 2018-2020 BasicSR Authors 4 | # ------------------------------------------------------------------------ 5 | import argparse 6 | from os import path as osp 7 | 8 | from basicsr.utils import scandir 9 | from basicsr.utils.lmdb_util import make_lmdb_from_imgs 10 | 11 | def prepare_keys(folder_path, suffix='png'): 12 | """Prepare image path list and keys for DIV2K dataset. 13 | 14 | Args: 15 | folder_path (str): Folder path. 16 | 17 | Returns: 18 | list[str]: Image path list. 19 | list[str]: Key list. 
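        Example (illustrative): with suffix 'png', an image named
        'frame_0001.png' is keyed as 'frame_0001'.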
20 | """ 21 | print('Reading image path list ...') 22 | img_path_list = sorted( 23 | list(scandir(folder_path, suffix=suffix, recursive=False))) 24 | keys = [img_path.split('.{}'.format(suffix))[0] for img_path in sorted(img_path_list)] 25 | 26 | return img_path_list, keys 27 | 28 | def create_lmdb_for_reds(): 29 | folder_path = './datasets/REDS/val/sharp_300' 30 | lmdb_path = './datasets/REDS/val/sharp_300.lmdb' 31 | img_path_list, keys = prepare_keys(folder_path, 'png') 32 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 33 | # 34 | folder_path = './datasets/REDS/val/blur_300' 35 | lmdb_path = './datasets/REDS/val/blur_300.lmdb' 36 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 37 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 38 | 39 | folder_path = './datasets/REDS/train/train_sharp' 40 | lmdb_path = './datasets/REDS/train/train_sharp.lmdb' 41 | img_path_list, keys = prepare_keys(folder_path, 'png') 42 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 43 | 44 | folder_path = './datasets/REDS/train/train_blur_jpeg' 45 | lmdb_path = './datasets/REDS/train/train_blur_jpeg.lmdb' 46 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 47 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 48 | 49 | 50 | def create_lmdb_for_gopro(): 51 | folder_path = './datasets/GoPro/train/blur_crops' 52 | lmdb_path = './datasets/GoPro/train/blur_crops.lmdb' 53 | 54 | img_path_list, keys = prepare_keys(folder_path, 'png') 55 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 56 | 57 | folder_path = './datasets/GoPro/train/sharp_crops' 58 | lmdb_path = './datasets/GoPro/train/sharp_crops.lmdb' 59 | 60 | img_path_list, keys = prepare_keys(folder_path, 'png') 61 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 62 | 63 | folder_path = './datasets/GoPro/test/target' 64 | lmdb_path = './datasets/GoPro/test/target.lmdb' 65 | 66 | img_path_list, keys = prepare_keys(folder_path, 'png') 67 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 68 | 69 | folder_path = './datasets/GoPro/test/input' 70 | lmdb_path = './datasets/GoPro/test/input.lmdb' 71 | 72 | img_path_list, keys = prepare_keys(folder_path, 'png') 73 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 74 | 75 | def create_lmdb_for_rain13k(): 76 | folder_path = './datasets/Rain13k/train/input' 77 | lmdb_path = './datasets/Rain13k/train/input.lmdb' 78 | 79 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 80 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 81 | 82 | folder_path = './datasets/Rain13k/train/target' 83 | lmdb_path = './datasets/Rain13k/train/target.lmdb' 84 | 85 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 86 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 87 | 88 | def create_lmdb_for_SIDD(): 89 | folder_path = './datasets/SIDD/train/input_crops' 90 | lmdb_path = './datasets/SIDD/train/input_crops.lmdb' 91 | 92 | img_path_list, keys = prepare_keys(folder_path, 'PNG') 93 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 94 | 95 | folder_path = './datasets/SIDD/train/gt_crops' 96 | lmdb_path = './datasets/SIDD/train/gt_crops.lmdb' 97 | 98 | img_path_list, keys = prepare_keys(folder_path, 'PNG') 99 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 100 | 101 | #for val 102 | folder_path = './datasets/SIDD/val/input_crops' 103 | lmdb_path = './datasets/SIDD/val/input_crops.lmdb' 104 | mat_path = 
'./datasets/SIDD/ValidationNoisyBlocksSrgb.mat' 105 | if not osp.exists(folder_path): 106 | os.makedirs(folder_path) 107 | assert osp.exists(mat_path) 108 | data = scio.loadmat(mat_path)['ValidationNoisyBlocksSrgb'] 109 | N, B, H ,W, C = data.shape 110 | data = data.reshape(N*B, H, W, C) 111 | for i in tqdm(range(N*B)): 112 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR)) 113 | img_path_list, keys = prepare_keys(folder_path, 'png') 114 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 115 | 116 | folder_path = './datasets/SIDD/val/gt_crops' 117 | lmdb_path = './datasets/SIDD/val/gt_crops.lmdb' 118 | mat_path = './datasets/SIDD/ValidationGtBlocksSrgb.mat' 119 | if not osp.exists(folder_path): 120 | os.makedirs(folder_path) 121 | assert osp.exists(mat_path) 122 | data = scio.loadmat(mat_path)['ValidationGtBlocksSrgb'] 123 | N, B, H ,W, C = data.shape 124 | data = data.reshape(N*B, H, W, C) 125 | for i in tqdm(range(N*B)): 126 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR)) 127 | img_path_list, keys = prepare_keys(folder_path, 'png') 128 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 129 | -------------------------------------------------------------------------------- /basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 3 | # Copyright 2018-2020 BasicSR Authors 4 | # ------------------------------------------------------------------------ 5 | 6 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 7 | import functools 8 | import os 9 | import subprocess 10 | import torch 11 | import torch.distributed as dist 12 | import torch.multiprocessing as mp 13 | import datetime 14 | 15 | def init_dist(launcher, backend='nccl', **kwargs): 16 | if mp.get_start_method(allow_none=True) is None: 17 | mp.set_start_method('spawn') 18 | if launcher == 'pytorch': 19 | _init_dist_pytorch(backend, **kwargs) 20 | elif launcher == 'slurm': 21 | _init_dist_slurm(backend, **kwargs) 22 | else: 23 | raise ValueError(f'Invalid launcher type: {launcher}') 24 | 25 | 26 | def _init_dist_pytorch(backend, **kwargs): 27 | rank = int(os.environ['RANK']) 28 | num_gpus = torch.cuda.device_count() 29 | torch.cuda.set_device(rank % num_gpus) 30 | dist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=5400), **kwargs) 31 | 32 | 33 | def _init_dist_slurm(backend, port=None): 34 | """Initialize slurm distributed training environment. 35 | 36 | If argument ``port`` is not specified, then the master port will be system 37 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 38 | environment variable, then a default port ``29500`` will be used. 39 | 40 | Args: 41 | backend (str): Backend of torch.distributed. 42 | port (int, optional): Master port. Defaults to None. 
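        The process rank, world size and master address are taken from the
        SLURM_PROCID, SLURM_NTASKS and SLURM_NODELIST environment variables
        (the address is resolved with 'scontrol show hostname').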
43 | """ 44 | proc_id = int(os.environ['SLURM_PROCID']) 45 | ntasks = int(os.environ['SLURM_NTASKS']) 46 | node_list = os.environ['SLURM_NODELIST'] 47 | num_gpus = torch.cuda.device_count() 48 | torch.cuda.set_device(proc_id % num_gpus) 49 | addr = subprocess.getoutput( 50 | f'scontrol show hostname {node_list} | head -n1') 51 | # specify master port 52 | if port is not None: 53 | os.environ['MASTER_PORT'] = str(port) 54 | elif 'MASTER_PORT' in os.environ: 55 | pass # use MASTER_PORT in the environment variable 56 | else: 57 | # 29500 is torch.distributed default port 58 | os.environ['MASTER_PORT'] = '29500' 59 | os.environ['MASTER_ADDR'] = addr 60 | os.environ['WORLD_SIZE'] = str(ntasks) 61 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 62 | os.environ['RANK'] = str(proc_id) 63 | dist.init_process_group(backend=backend) 64 | 65 | 66 | def get_dist_info(): 67 | if dist.is_available(): 68 | initialized = dist.is_initialized() 69 | else: 70 | initialized = False 71 | if initialized: 72 | rank = dist.get_rank() 73 | world_size = dist.get_world_size() 74 | else: 75 | rank = 0 76 | world_size = 1 77 | return rank, world_size 78 | 79 | 80 | def master_only(func): 81 | 82 | @functools.wraps(func) 83 | def wrapper(*args, **kwargs): 84 | rank, _ = get_dist_info() 85 | if rank == 0: 86 | return func(*args, **kwargs) 87 | 88 | return wrapper 89 | -------------------------------------------------------------------------------- /basicsr/utils/download_util.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import math 8 | import requests 9 | from tqdm import tqdm 10 | 11 | from .misc import sizeof_fmt 12 | 13 | 14 | def download_file_from_google_drive(file_id, save_path): 15 | """Download files from google drive. 16 | 17 | Ref: 18 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 19 | 20 | Args: 21 | file_id (str): File id. 22 | save_path (str): Save path. 
23 | """ 24 | 25 | session = requests.Session() 26 | URL = 'https://docs.google.com/uc?export=download' 27 | params = {'id': file_id} 28 | 29 | response = session.get(URL, params=params, stream=True) 30 | token = get_confirm_token(response) 31 | if token: 32 | params['confirm'] = token 33 | response = session.get(URL, params=params, stream=True) 34 | 35 | # get file size 36 | response_file_size = session.get( 37 | URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 38 | if 'Content-Range' in response_file_size.headers: 39 | file_size = int( 40 | response_file_size.headers['Content-Range'].split('/')[1]) 41 | else: 42 | file_size = None 43 | 44 | save_response_content(response, save_path, file_size) 45 | 46 | 47 | def get_confirm_token(response): 48 | for key, value in response.cookies.items(): 49 | if key.startswith('download_warning'): 50 | return value 51 | return None 52 | 53 | 54 | def save_response_content(response, 55 | destination, 56 | file_size=None, 57 | chunk_size=32768): 58 | if file_size is not None: 59 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 60 | 61 | readable_file_size = sizeof_fmt(file_size) 62 | else: 63 | pbar = None 64 | 65 | with open(destination, 'wb') as f: 66 | downloaded_size = 0 67 | for chunk in response.iter_content(chunk_size): 68 | downloaded_size += chunk_size 69 | if pbar is not None: 70 | pbar.update(1) 71 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} ' 72 | f'/ {readable_file_size}') 73 | if chunk: # filter out keep-alive new chunks 74 | f.write(chunk) 75 | if pbar is not None: 76 | pbar.close() 77 | -------------------------------------------------------------------------------- /basicsr/utils/face_util.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import cv2 8 | import numpy as np 9 | import os 10 | import torch 11 | from skimage import transform as trans 12 | 13 | from basicsr.utils import imwrite 14 | 15 | try: 16 | import dlib 17 | except ImportError: 18 | print('Please install dlib before testing face restoration.' 
19 | 'Reference: https://github.com/davisking/dlib') 20 | 21 | 22 | class FaceRestorationHelper(object): 23 | """Helper for the face restoration pipeline.""" 24 | 25 | def __init__(self, upscale_factor, face_size=512): 26 | self.upscale_factor = upscale_factor 27 | self.face_size = (face_size, face_size) 28 | 29 | # standard 5 landmarks for FFHQ faces with 1024 x 1024 30 | self.face_template = np.array([[686.77227723, 488.62376238], 31 | [586.77227723, 493.59405941], 32 | [337.91089109, 488.38613861], 33 | [437.95049505, 493.51485149], 34 | [513.58415842, 678.5049505]]) 35 | self.face_template = self.face_template / (1024 // face_size) 36 | # for estimation the 2D similarity transformation 37 | self.similarity_trans = trans.SimilarityTransform() 38 | 39 | self.all_landmarks_5 = [] 40 | self.all_landmarks_68 = [] 41 | self.affine_matrices = [] 42 | self.inverse_affine_matrices = [] 43 | self.cropped_faces = [] 44 | self.restored_faces = [] 45 | self.save_png = True 46 | 47 | def init_dlib(self, detection_path, landmark5_path, landmark68_path): 48 | """Initialize the dlib detectors and predictors.""" 49 | self.face_detector = dlib.cnn_face_detection_model_v1(detection_path) 50 | self.shape_predictor_5 = dlib.shape_predictor(landmark5_path) 51 | self.shape_predictor_68 = dlib.shape_predictor(landmark68_path) 52 | 53 | def free_dlib_gpu_memory(self): 54 | del self.face_detector 55 | del self.shape_predictor_5 56 | del self.shape_predictor_68 57 | 58 | def read_input_image(self, img_path): 59 | # self.input_img is Numpy array, (h, w, c) with RGB order 60 | self.input_img = dlib.load_rgb_image(img_path) 61 | 62 | def detect_faces(self, 63 | img_path, 64 | upsample_num_times=1, 65 | only_keep_largest=False): 66 | """ 67 | Args: 68 | img_path (str): Image path. 69 | upsample_num_times (int): Upsamples the image before running the 70 | face detector 71 | 72 | Returns: 73 | int: Number of detected faces. 74 | """ 75 | self.read_input_image(img_path) 76 | det_faces = self.face_detector(self.input_img, upsample_num_times) 77 | if len(det_faces) == 0: 78 | print('No face detected. Try to increase upsample_num_times.') 79 | else: 80 | if only_keep_largest: 81 | print('Detect several faces and only keep the largest.') 82 | face_areas = [] 83 | for i in range(len(det_faces)): 84 | face_area = (det_faces[i].rect.right() - 85 | det_faces[i].rect.left()) * ( 86 | det_faces[i].rect.bottom() - 87 | det_faces[i].rect.top()) 88 | face_areas.append(face_area) 89 | largest_idx = face_areas.index(max(face_areas)) 90 | self.det_faces = [det_faces[largest_idx]] 91 | else: 92 | self.det_faces = det_faces 93 | return len(self.det_faces) 94 | 95 | def get_face_landmarks_5(self): 96 | for face in self.det_faces: 97 | shape = self.shape_predictor_5(self.input_img, face.rect) 98 | landmark = np.array([[part.x, part.y] for part in shape.parts()]) 99 | self.all_landmarks_5.append(landmark) 100 | return len(self.all_landmarks_5) 101 | 102 | def get_face_landmarks_68(self): 103 | """Get 68 densemarks for cropped images. 104 | 105 | Should only have one face at most in the cropped image. 106 | """ 107 | num_detected_face = 0 108 | for idx, face in enumerate(self.cropped_faces): 109 | # face detection 110 | det_face = self.face_detector(face, 1) # TODO: can we remove it? 111 | if len(det_face) == 0: 112 | print(f'Cannot find faces in cropped image with index {idx}.') 113 | self.all_landmarks_68.append(None) 114 | else: 115 | if len(det_face) > 1: 116 | print('Detect several faces in the cropped face. 
Use the ' 117 | ' largest one. Note that it will also cause overlap ' 118 | 'during paste_faces_to_input_image.') 119 | face_areas = [] 120 | for i in range(len(det_face)): 121 | face_area = (det_face[i].rect.right() - 122 | det_face[i].rect.left()) * ( 123 | det_face[i].rect.bottom() - 124 | det_face[i].rect.top()) 125 | face_areas.append(face_area) 126 | largest_idx = face_areas.index(max(face_areas)) 127 | face_rect = det_face[largest_idx].rect 128 | else: 129 | face_rect = det_face[0].rect 130 | shape = self.shape_predictor_68(face, face_rect) 131 | landmark = np.array([[part.x, part.y] 132 | for part in shape.parts()]) 133 | self.all_landmarks_68.append(landmark) 134 | num_detected_face += 1 135 | 136 | return num_detected_face 137 | 138 | def warp_crop_faces(self, 139 | save_cropped_path=None, 140 | save_inverse_affine_path=None): 141 | """Get affine matrix, warp and cropped faces. 142 | 143 | Also get inverse affine matrix for post-processing. 144 | """ 145 | for idx, landmark in enumerate(self.all_landmarks_5): 146 | # use 5 landmarks to get affine matrix 147 | self.similarity_trans.estimate(landmark, self.face_template) 148 | affine_matrix = self.similarity_trans.params[0:2, :] 149 | self.affine_matrices.append(affine_matrix) 150 | # warp and crop faces 151 | cropped_face = cv2.warpAffine(self.input_img, affine_matrix, 152 | self.face_size) 153 | self.cropped_faces.append(cropped_face) 154 | # save the cropped face 155 | if save_cropped_path is not None: 156 | path, ext = os.path.splitext(save_cropped_path) 157 | if self.save_png: 158 | save_path = f'{path}_{idx:02d}.png' 159 | else: 160 | save_path = f'{path}_{idx:02d}{ext}' 161 | 162 | imwrite( 163 | cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR), save_path) 164 | 165 | # get inverse affine matrix 166 | self.similarity_trans.estimate(self.face_template, 167 | landmark * self.upscale_factor) 168 | inverse_affine = self.similarity_trans.params[0:2, :] 169 | self.inverse_affine_matrices.append(inverse_affine) 170 | # save inverse affine matrices 171 | if save_inverse_affine_path is not None: 172 | path, _ = os.path.splitext(save_inverse_affine_path) 173 | save_path = f'{path}_{idx:02d}.pth' 174 | torch.save(inverse_affine, save_path) 175 | 176 | def add_restored_face(self, face): 177 | self.restored_faces.append(face) 178 | 179 | def paste_faces_to_input_image(self, save_path): 180 | # operate in the BGR order 181 | input_img = cv2.cvtColor(self.input_img, cv2.COLOR_RGB2BGR) 182 | h, w, _ = input_img.shape 183 | h_up, w_up = h * self.upscale_factor, w * self.upscale_factor 184 | # simply resize the background 185 | upsample_img = cv2.resize(input_img, (w_up, h_up)) 186 | assert len(self.restored_faces) == len(self.inverse_affine_matrices), ( 187 | 'length of restored_faces and affine_matrices are different.') 188 | for restored_face, inverse_affine in zip(self.restored_faces, 189 | self.inverse_affine_matrices): 190 | inv_restored = cv2.warpAffine(restored_face, inverse_affine, 191 | (w_up, h_up)) 192 | mask = np.ones((*self.face_size, 3), dtype=np.float32) 193 | inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) 194 | # remove the black borders 195 | inv_mask_erosion = cv2.erode( 196 | inv_mask, 197 | np.ones((2 * self.upscale_factor, 2 * self.upscale_factor), 198 | np.uint8)) 199 | inv_restored_remove_border = inv_mask_erosion * inv_restored 200 | total_face_area = np.sum(inv_mask_erosion) // 3 201 | # compute the fusion edge based on the area of face 202 | w_edge = int(total_face_area**0.5) // 20 203 | erosion_radius = 
w_edge * 2 204 | inv_mask_center = cv2.erode( 205 | inv_mask_erosion, 206 | np.ones((erosion_radius, erosion_radius), np.uint8)) 207 | blur_size = w_edge * 2 208 | inv_soft_mask = cv2.GaussianBlur(inv_mask_center, 209 | (blur_size + 1, blur_size + 1), 0) 210 | upsample_img = inv_soft_mask * inv_restored_remove_border + ( 211 | 1 - inv_soft_mask) * upsample_img 212 | if self.save_png: 213 | save_path = save_path.replace('.jpg', 214 | '.png').replace('.jpeg', '.png') 215 | imwrite(upsample_img.astype(np.uint8), save_path) 216 | 217 | def clean_all(self): 218 | self.all_landmarks_5 = [] 219 | self.all_landmarks_68 = [] 220 | self.restored_faces = [] 221 | self.affine_matrices = [] 222 | self.cropped_faces = [] 223 | self.inverse_affine_matrices = [] 224 | -------------------------------------------------------------------------------- /basicsr/utils/file_client.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 8 | from abc import ABCMeta, abstractmethod 9 | 10 | 11 | class BaseStorageBackend(metaclass=ABCMeta): 12 | """Abstract class of storage backends. 13 | 14 | All backends need to implement two apis: ``get()`` and ``get_text()``. 15 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file 16 | as texts. 17 | """ 18 | 19 | @abstractmethod 20 | def get(self, filepath): 21 | pass 22 | 23 | @abstractmethod 24 | def get_text(self, filepath): 25 | pass 26 | 27 | 28 | class MemcachedBackend(BaseStorageBackend): 29 | """Memcached storage backend. 30 | 31 | Attributes: 32 | server_list_cfg (str): Config file for memcached server list. 33 | client_cfg (str): Config file for memcached client. 34 | sys_path (str | None): Additional path to be appended to `sys.path`. 35 | Default: None. 
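        Example (illustrative; both config paths are placeholders and a working
        memcached setup plus the ``mc`` client package are assumed)::

            >>> backend = MemcachedBackend('/etc/memcached/server_list.conf',
            ...                            '/etc/memcached/client.conf')
            >>> value_buf = backend.get('datasets/GoPro/train/000001.png')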
36 | """ 37 | 38 | def __init__(self, server_list_cfg, client_cfg, sys_path=None): 39 | if sys_path is not None: 40 | import sys 41 | sys.path.append(sys_path) 42 | try: 43 | import mc 44 | except ImportError: 45 | raise ImportError( 46 | 'Please install memcached to enable MemcachedBackend.') 47 | 48 | self.server_list_cfg = server_list_cfg 49 | self.client_cfg = client_cfg 50 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, 51 | self.client_cfg) 52 | # mc.pyvector servers as a point which points to a memory cache 53 | self._mc_buffer = mc.pyvector() 54 | 55 | def get(self, filepath): 56 | filepath = str(filepath) 57 | import mc 58 | self._client.Get(filepath, self._mc_buffer) 59 | value_buf = mc.ConvertBuffer(self._mc_buffer) 60 | return value_buf 61 | 62 | def get_text(self, filepath): 63 | raise NotImplementedError 64 | 65 | 66 | class HardDiskBackend(BaseStorageBackend): 67 | """Raw hard disks storage backend.""" 68 | 69 | def get(self, filepath): 70 | filepath = str(filepath) 71 | with open(filepath, 'rb') as f: 72 | value_buf = f.read() 73 | return value_buf 74 | 75 | def get_text(self, filepath): 76 | filepath = str(filepath) 77 | with open(filepath, 'r') as f: 78 | value_buf = f.read() 79 | return value_buf 80 | 81 | 82 | class LmdbBackend(BaseStorageBackend): 83 | """Lmdb storage backend. 84 | 85 | Args: 86 | db_paths (str | list[str]): Lmdb database paths. 87 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'. 88 | readonly (bool, optional): Lmdb environment parameter. If True, 89 | disallow any write operations. Default: True. 90 | lock (bool, optional): Lmdb environment parameter. If False, when 91 | concurrent access occurs, do not lock the database. Default: False. 92 | readahead (bool, optional): Lmdb environment parameter. If False, 93 | disable the OS filesystem readahead mechanism, which may improve 94 | random read performance when a database is larger than RAM. 95 | Default: False. 96 | 97 | Attributes: 98 | db_paths (list): Lmdb database path. 99 | _client (list): A list of several lmdb envs. 100 | """ 101 | 102 | def __init__(self, 103 | db_paths, 104 | client_keys='default', 105 | readonly=True, 106 | lock=False, 107 | readahead=False, 108 | **kwargs): 109 | try: 110 | import lmdb 111 | except ImportError: 112 | raise ImportError('Please install lmdb to enable LmdbBackend.') 113 | 114 | if isinstance(client_keys, str): 115 | client_keys = [client_keys] 116 | 117 | if isinstance(db_paths, list): 118 | self.db_paths = [str(v) for v in db_paths] 119 | elif isinstance(db_paths, str): 120 | self.db_paths = [str(db_paths)] 121 | assert len(client_keys) == len(self.db_paths), ( 122 | 'client_keys and db_paths should have the same length, ' 123 | f'but received {len(client_keys)} and {len(self.db_paths)}.') 124 | 125 | self._client = {} 126 | 127 | for client, path in zip(client_keys, self.db_paths): 128 | self._client[client] = lmdb.open( 129 | path, 130 | readonly=readonly, 131 | lock=lock, 132 | readahead=readahead, 133 | map_size=8*1024*10485760, 134 | # max_readers=1, 135 | **kwargs) 136 | 137 | def get(self, filepath, client_key): 138 | """Get values according to the filepath from one lmdb named client_key. 139 | 140 | Args: 141 | filepath (str | obj:`Path`): Here, filepath is the lmdb key. 142 | client_key (str): Used for distinguishing differnet lmdb envs. 
143 | """ 144 | filepath = str(filepath) 145 | assert client_key in self._client, (f'client_key {client_key} is not ' 146 | 'in lmdb clients.') 147 | client = self._client[client_key] 148 | with client.begin(write=False) as txn: 149 | value_buf = txn.get(filepath.encode('ascii')) 150 | return value_buf 151 | 152 | def get_text(self, filepath): 153 | raise NotImplementedError 154 | 155 | 156 | class FileClient(object): 157 | """A general file client to access files in different backend. 158 | 159 | The client loads a file or text in a specified backend from its path 160 | and return it as a binary file. it can also register other backend 161 | accessor with a given name and backend class. 162 | 163 | Attributes: 164 | backend (str): The storage backend type. Options are "disk", 165 | "memcached" and "lmdb". 166 | client (:obj:`BaseStorageBackend`): The backend object. 167 | """ 168 | 169 | _backends = { 170 | 'disk': HardDiskBackend, 171 | 'memcached': MemcachedBackend, 172 | 'lmdb': LmdbBackend, 173 | } 174 | 175 | def __init__(self, backend='disk', **kwargs): 176 | if backend not in self._backends: 177 | raise ValueError( 178 | f'Backend {backend} is not supported. Currently supported ones' 179 | f' are {list(self._backends.keys())}') 180 | self.backend = backend 181 | self.client = self._backends[backend](**kwargs) 182 | 183 | def get(self, filepath, client_key='default'): 184 | # client_key is used only for lmdb, where different fileclients have 185 | # different lmdb environments. 186 | if self.backend == 'lmdb': 187 | return self.client.get(filepath, client_key) 188 | else: 189 | return self.client.get(filepath) 190 | 191 | def get_text(self, filepath): 192 | return self.client.get_text(filepath) 193 | -------------------------------------------------------------------------------- /basicsr/utils/flow_util.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501 8 | import cv2 9 | import numpy as np 10 | import os 11 | 12 | 13 | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): 14 | """Read an optical flow map. 15 | 16 | Args: 17 | flow_path (ndarray or str): Flow path. 18 | quantize (bool): whether to read quantized pair, if set to True, 19 | remaining args will be passed to :func:`dequantize_flow`. 20 | concat_axis (int): The axis that dx and dy are concatenated, 21 | can be either 0 or 1. Ignored if quantize is False. 
22 | 23 | Returns: 24 | ndarray: Optical flow represented as a (h, w, 2) numpy array 25 | """ 26 | if quantize: 27 | assert concat_axis in [0, 1] 28 | cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED) 29 | if cat_flow.ndim != 2: 30 | raise IOError(f'{flow_path} is not a valid quantized flow file, ' 31 | f'its dimension is {cat_flow.ndim}.') 32 | assert cat_flow.shape[concat_axis] % 2 == 0 33 | dx, dy = np.split(cat_flow, 2, axis=concat_axis) 34 | flow = dequantize_flow(dx, dy, *args, **kwargs) 35 | else: 36 | with open(flow_path, 'rb') as f: 37 | try: 38 | header = f.read(4).decode('utf-8') 39 | except Exception: 40 | raise IOError(f'Invalid flow file: {flow_path}') 41 | else: 42 | if header != 'PIEH': 43 | raise IOError(f'Invalid flow file: {flow_path}, ' 44 | 'header does not contain PIEH') 45 | 46 | w = np.fromfile(f, np.int32, 1).squeeze() 47 | h = np.fromfile(f, np.int32, 1).squeeze() 48 | flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2)) 49 | 50 | return flow.astype(np.float32) 51 | 52 | 53 | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): 54 | """Write optical flow to file. 55 | 56 | If the flow is not quantized, it will be saved as a .flo file losslessly, 57 | otherwise a jpeg image which is lossy but of much smaller size. (dx and dy 58 | will be concatenated horizontally into a single image if quantize is True.) 59 | 60 | Args: 61 | flow (ndarray): (h, w, 2) array of optical flow. 62 | filename (str): Output filepath. 63 | quantize (bool): Whether to quantize the flow and save it to 2 jpeg 64 | images. If set to True, remaining args will be passed to 65 | :func:`quantize_flow`. 66 | concat_axis (int): The axis that dx and dy are concatenated, 67 | can be either 0 or 1. Ignored if quantize is False. 68 | """ 69 | if not quantize: 70 | with open(filename, 'wb') as f: 71 | f.write('PIEH'.encode('utf-8')) 72 | np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) 73 | flow = flow.astype(np.float32) 74 | flow.tofile(f) 75 | f.flush() 76 | else: 77 | assert concat_axis in [0, 1] 78 | dx, dy = quantize_flow(flow, *args, **kwargs) 79 | dxdy = np.concatenate((dx, dy), axis=concat_axis) 80 | os.makedirs(filename, exist_ok=True) 81 | cv2.imwrite(dxdy, filename) 82 | 83 | 84 | def quantize_flow(flow, max_val=0.02, norm=True): 85 | """Quantize flow to [0, 255]. 86 | 87 | After this step, the size of flow will be much smaller, and can be 88 | dumped as jpeg images. 89 | 90 | Args: 91 | flow (ndarray): (h, w, 2) array of optical flow. 92 | max_val (float): Maximum value of flow, values beyond 93 | [-max_val, max_val] will be truncated. 94 | norm (bool): Whether to divide flow values by image width/height. 95 | 96 | Returns: 97 | tuple[ndarray]: Quantized dx and dy. 98 | """ 99 | h, w, _ = flow.shape 100 | dx = flow[..., 0] 101 | dy = flow[..., 1] 102 | if norm: 103 | dx = dx / w # avoid inplace operations 104 | dy = dy / h 105 | # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. 106 | flow_comps = [ 107 | quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy] 108 | ] 109 | return tuple(flow_comps) 110 | 111 | 112 | def dequantize_flow(dx, dy, max_val=0.02, denorm=True): 113 | """Recover from quantized flow. 114 | 115 | Args: 116 | dx (ndarray): Quantized dx. 117 | dy (ndarray): Quantized dy. 118 | max_val (float): Maximum value used when quantizing. 119 | denorm (bool): Whether to multiply flow values with width/height. 120 | 121 | Returns: 122 | ndarray: Dequantized flow. 
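        Example (illustrative round trip with the ``quantize_flow`` helper
        defined above; ``flow`` stands for any (h, w, 2) float array)::

            >>> dx, dy = quantize_flow(flow, max_val=0.02, norm=True)
            >>> flow_restored = dequantize_flow(dx, dy, max_val=0.02, denorm=True)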
123 | """ 124 | assert dx.shape == dy.shape 125 | assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) 126 | 127 | dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] 128 | 129 | if denorm: 130 | dx *= dx.shape[1] 131 | dy *= dx.shape[0] 132 | flow = np.dstack((dx, dy)) 133 | return flow 134 | 135 | 136 | def quantize(arr, min_val, max_val, levels, dtype=np.int64): 137 | """Quantize an array of (-inf, inf) to [0, levels-1]. 138 | 139 | Args: 140 | arr (ndarray): Input array. 141 | min_val (scalar): Minimum value to be clipped. 142 | max_val (scalar): Maximum value to be clipped. 143 | levels (int): Quantization levels. 144 | dtype (np.type): The type of the quantized array. 145 | 146 | Returns: 147 | tuple: Quantized array. 148 | """ 149 | if not (isinstance(levels, int) and levels > 1): 150 | raise ValueError( 151 | f'levels must be a positive integer, but got {levels}') 152 | if min_val >= max_val: 153 | raise ValueError( 154 | f'min_val ({min_val}) must be smaller than max_val ({max_val})') 155 | 156 | arr = np.clip(arr, min_val, max_val) - min_val 157 | quantized_arr = np.minimum( 158 | np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) 159 | 160 | return quantized_arr 161 | 162 | 163 | def dequantize(arr, min_val, max_val, levels, dtype=np.float64): 164 | """Dequantize an array. 165 | 166 | Args: 167 | arr (ndarray): Input array. 168 | min_val (scalar): Minimum value to be clipped. 169 | max_val (scalar): Maximum value to be clipped. 170 | levels (int): Quantization levels. 171 | dtype (np.type): The type of the dequantized array. 172 | 173 | Returns: 174 | tuple: Dequantized array. 175 | """ 176 | if not (isinstance(levels, int) and levels > 1): 177 | raise ValueError( 178 | f'levels must be a positive integer, but got {levels}') 179 | if min_val >= max_val: 180 | raise ValueError( 181 | f'min_val ({min_val}) must be smaller than max_val ({max_val})') 182 | 183 | dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - 184 | min_val) / levels + min_val 185 | 186 | return dequantized_arr 187 | -------------------------------------------------------------------------------- /basicsr/utils/img_util.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import cv2 8 | import math 9 | import numpy as np 10 | import os 11 | import torch 12 | from torchvision.utils import make_grid 13 | import matplotlib.pyplot as plt 14 | 15 | def img2tensor(imgs, bgr2rgb=True, float32=True): 16 | """Numpy array to tensor. 17 | 18 | Args: 19 | imgs (list[ndarray] | ndarray): Input images. 20 | bgr2rgb (bool): Whether to change bgr to rgb. 21 | float32 (bool): Whether to change to float32. 22 | 23 | Returns: 24 | list[tensor] | tensor: Tensor images. If returned results only have 25 | one element, just return tensor. 
26 | """ 27 | 28 | def _totensor(img, bgr2rgb, float32): 29 | if img.shape[2] == 3 and bgr2rgb: 30 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 31 | img = torch.from_numpy(img.transpose(2, 0, 1)) 32 | if float32: 33 | img = img.float() 34 | return img 35 | 36 | if isinstance(imgs, list): 37 | return [_totensor(img, bgr2rgb, float32) for img in imgs] 38 | else: 39 | return _totensor(imgs, bgr2rgb, float32) 40 | 41 | 42 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): 43 | """Convert torch Tensors into image numpy arrays. 44 | 45 | After clamping to [min, max], values will be normalized to [0, 1]. 46 | 47 | Args: 48 | tensor (Tensor or list[Tensor]): Accept shapes: 49 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); 50 | 2) 3D Tensor of shape (3/1 x H x W); 51 | 3) 2D Tensor of shape (H x W). 52 | Tensor channel should be in RGB order. 53 | rgb2bgr (bool): Whether to change rgb to bgr. 54 | out_type (numpy type): output types. If ``np.uint8``, transform outputs 55 | to uint8 type with range [0, 255]; otherwise, float type with 56 | range [0, 1]. Default: ``np.uint8``. 57 | min_max (tuple[int]): min and max values for clamp. 58 | 59 | Returns: 60 | (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of 61 | shape (H x W). The channel order is BGR. 62 | """ 63 | if not (torch.is_tensor(tensor) or 64 | (isinstance(tensor, list) 65 | and all(torch.is_tensor(t) for t in tensor))): 66 | raise TypeError( 67 | f'tensor or list of tensors expected, got {type(tensor)}') 68 | 69 | if torch.is_tensor(tensor): 70 | tensor = [tensor] 71 | result = [] 72 | for _tensor in tensor: 73 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) 74 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) 75 | 76 | n_dim = _tensor.dim() 77 | if n_dim == 4: 78 | img_np = make_grid( 79 | _tensor, nrow=int(math.sqrt(_tensor.size(0))), 80 | normalize=False).numpy() 81 | img_np = img_np.transpose(1, 2, 0) 82 | if rgb2bgr: 83 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 84 | elif n_dim == 3: 85 | img_np = _tensor.numpy() 86 | img_np = img_np.transpose(1, 2, 0) 87 | if img_np.shape[2] == 1: # gray image 88 | img_np = np.squeeze(img_np, axis=2) 89 | else: 90 | if rgb2bgr: 91 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 92 | elif n_dim == 2: 93 | img_np = _tensor.numpy() 94 | else: 95 | raise TypeError('Only support 4D, 3D or 2D tensor. ' 96 | f'But received with dimension: {n_dim}') 97 | if out_type == np.uint8: 98 | # Unlike MATLAB, numpy.unit8() WILL NOT round by default. 99 | img_np = (img_np * 255.0).round() 100 | img_np = img_np.astype(out_type) 101 | result.append(img_np) 102 | if len(result) == 1: 103 | result = result[0] 104 | return result 105 | 106 | 107 | def imfrombytes(content, flag='color', float32=False): 108 | """Read an image from bytes. 109 | 110 | Args: 111 | content (bytes): Image bytes got from files or other streams. 112 | flag (str): Flags specifying the color type of a loaded image, 113 | candidates are `color`, `grayscale` and `unchanged`. 114 | float32 (bool): Whether to change to float32., If True, will also norm 115 | to [0, 1]. Default: False. 116 | 117 | Returns: 118 | ndarray: Loaded image array. 119 | """ 120 | img_np = np.frombuffer(content, np.uint8) 121 | imread_flags = { 122 | 'color': cv2.IMREAD_COLOR, 123 | 'grayscale': cv2.IMREAD_GRAYSCALE, 124 | 'unchanged': cv2.IMREAD_UNCHANGED 125 | } 126 | if img_np is None: 127 | raise Exception('None .. 
!!!') 128 | img = cv2.imdecode(img_np, imread_flags[flag]) 129 | if float32: 130 | img = img.astype(np.float32) / 255. 131 | return img 132 | 133 | def padding(img_lq, img_gt, gt_size): 134 | h, w, _ = img_lq.shape 135 | 136 | h_pad = max(0, gt_size - h) 137 | w_pad = max(0, gt_size - w) 138 | 139 | if h_pad == 0 and w_pad == 0: 140 | return img_lq, img_gt 141 | 142 | img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) 143 | img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) 144 | # print('img_lq', img_lq.shape, img_gt.shape) 145 | return img_lq, img_gt 146 | 147 | def imwrite(img, file_path, params=None, auto_mkdir=True): 148 | """Write image to file. 149 | 150 | Args: 151 | img (ndarray): Image array to be written. 152 | file_path (str): Image file path. 153 | params (None or list): Same as opencv's :func:`imwrite` interface. 154 | auto_mkdir (bool): If the parent folder of `file_path` does not exist, 155 | whether to create it automatically. 156 | 157 | Returns: 158 | bool: Successful or not. 159 | """ 160 | if auto_mkdir: 161 | dir_name = os.path.abspath(os.path.dirname(file_path)) 162 | os.makedirs(dir_name, exist_ok=True) 163 | return cv2.imwrite(file_path, img, params) 164 | 165 | def pltimwrite(img, file_path): 166 | dir_name = os.path.abspath(os.path.dirname(file_path)) 167 | os.makedirs(dir_name, exist_ok=True) 168 | plt.imshow(img) 169 | plt.savefig(file_path) 170 | plt.show() 171 | 172 | def crop_border(imgs, crop_border): 173 | """Crop borders of images. 174 | 175 | Args: 176 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c). 177 | crop_border (int): Crop border for each end of height and weight. 178 | 179 | Returns: 180 | list[ndarray]: Cropped images. 181 | """ 182 | if crop_border == 0: 183 | return imgs 184 | else: 185 | if isinstance(imgs, list): 186 | return [ 187 | v[crop_border:-crop_border, crop_border:-crop_border, ...] 188 | for v in imgs 189 | ] 190 | else: 191 | return imgs[crop_border:-crop_border, crop_border:-crop_border, 192 | ...] 193 | -------------------------------------------------------------------------------- /basicsr/utils/lmdb_util.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import cv2 8 | import lmdb 9 | import sys 10 | from multiprocessing import Pool 11 | from os import path as osp 12 | from tqdm import tqdm 13 | 14 | 15 | def make_lmdb_from_imgs(data_path, 16 | lmdb_path, 17 | img_path_list, 18 | keys, 19 | batch=5000, 20 | compress_level=1, 21 | multiprocessing_read=False, 22 | n_thread=40, 23 | map_size=None): 24 | """Make lmdb from images. 25 | 26 | Contents of lmdb. The file structure is: 27 | example.lmdb 28 | ├── data.mdb 29 | ├── lock.mdb 30 | ├── meta_info.txt 31 | 32 | The data.mdb and lock.mdb are standard lmdb files and you can refer to 33 | https://lmdb.readthedocs.io/en/release/ for more details. 34 | 35 | The meta_info.txt is a specified txt file to record the meta information 36 | of our datasets. It will be automatically created when preparing 37 | datasets by our provided dataset tools. 
38 | Each line in the txt file records 1)image name (with extension), 39 | 2)image shape, and 3)compression level, separated by a white space. 40 | 41 | For example, the meta information could be: 42 | `000_00000000.png (720,1280,3) 1`, which means: 43 | 1) image name (with extension): 000_00000000.png; 44 | 2) image shape: (720,1280,3); 45 | 3) compression level: 1 46 | 47 | We use the image name without extension as the lmdb key. 48 | 49 | If `multiprocessing_read` is True, it will read all the images to memory 50 | using multiprocessing. Thus, your server needs to have enough memory. 51 | 52 | Args: 53 | data_path (str): Data path for reading images. 54 | lmdb_path (str): Lmdb save path. 55 | img_path_list (str): Image path list. 56 | keys (str): Used for lmdb keys. 57 | batch (int): After processing batch images, lmdb commits. 58 | Default: 5000. 59 | compress_level (int): Compress level when encoding images. Default: 1. 60 | multiprocessing_read (bool): Whether use multiprocessing to read all 61 | the images to memory. Default: False. 62 | n_thread (int): For multiprocessing. 63 | map_size (int | None): Map size for lmdb env. If None, use the 64 | estimated size from images. Default: None 65 | """ 66 | 67 | assert len(img_path_list) == len(keys), ( 68 | 'img_path_list and keys should have the same length, ' 69 | f'but got {len(img_path_list)} and {len(keys)}') 70 | print(f'Create lmdb for {data_path}, save to {lmdb_path}...') 71 | print(f'Total images: {len(img_path_list)}') 72 | if not lmdb_path.endswith('.lmdb'): 73 | raise ValueError("lmdb_path must end with '.lmdb'.") 74 | if osp.exists(lmdb_path): 75 | print(f'Folder {lmdb_path} already exists. Exit.') 76 | sys.exit(1) 77 | 78 | if multiprocessing_read: 79 | # read all the images to memory (multiprocessing) 80 | dataset = {} # use dict to keep the order for multiprocessing 81 | shapes = {} 82 | print(f'Read images with multiprocessing, #thread: {n_thread} ...') 83 | pbar = tqdm(total=len(img_path_list), unit='image') 84 | 85 | def callback(arg): 86 | """get the image data and update pbar.""" 87 | key, dataset[key], shapes[key] = arg 88 | pbar.update(1) 89 | pbar.set_description(f'Read {key}') 90 | 91 | pool = Pool(n_thread) 92 | for path, key in zip(img_path_list, keys): 93 | pool.apply_async( 94 | read_img_worker, 95 | args=(osp.join(data_path, path), key, compress_level), 96 | callback=callback) 97 | pool.close() 98 | pool.join() 99 | pbar.close() 100 | print(f'Finish reading {len(img_path_list)} images.') 101 | 102 | # create lmdb environment 103 | if map_size is None: 104 | # obtain data size for one image 105 | img = cv2.imread( 106 | osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) 107 | _, img_byte = cv2.imencode( 108 | '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 109 | data_size_per_img = img_byte.nbytes 110 | print('Data size per image is: ', data_size_per_img) 111 | data_size = data_size_per_img * len(img_path_list) 112 | map_size = data_size * 10 113 | 114 | env = lmdb.open(lmdb_path, map_size=map_size) 115 | 116 | # write data to lmdb 117 | pbar = tqdm(total=len(img_path_list), unit='chunk') 118 | txn = env.begin(write=True) 119 | txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') 120 | for idx, (path, key) in enumerate(zip(img_path_list, keys)): 121 | pbar.update(1) 122 | pbar.set_description(f'Write {key}') 123 | key_byte = key.encode('ascii') 124 | if multiprocessing_read: 125 | img_byte = dataset[key] 126 | h, w, c = shapes[key] 127 | else: 128 | _, img_byte, img_shape = 
read_img_worker( 129 | osp.join(data_path, path), key, compress_level) 130 | h, w, c = img_shape 131 | 132 | txn.put(key_byte, img_byte) 133 | # write meta information 134 | txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') 135 | if idx % batch == 0: 136 | txn.commit() 137 | txn = env.begin(write=True) 138 | pbar.close() 139 | txn.commit() 140 | env.close() 141 | txt_file.close() 142 | print('\nFinish writing lmdb.') 143 | 144 | 145 | def read_img_worker(path, key, compress_level): 146 | """Read image worker. 147 | 148 | Args: 149 | path (str): Image path. 150 | key (str): Image key. 151 | compress_level (int): Compress level when encoding images. 152 | 153 | Returns: 154 | str: Image key. 155 | byte: Image byte. 156 | tuple[int]: Image shape. 157 | """ 158 | 159 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 160 | if img.ndim == 2: 161 | h, w = img.shape 162 | c = 1 163 | else: 164 | h, w, c = img.shape 165 | _, img_byte = cv2.imencode('.png', img, 166 | [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 167 | return (key, img_byte, (h, w, c)) 168 | 169 | 170 | class LmdbMaker(): 171 | """LMDB Maker. 172 | 173 | Args: 174 | lmdb_path (str): Lmdb save path. 175 | map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. 176 | batch (int): After processing batch images, lmdb commits. 177 | Default: 5000. 178 | compress_level (int): Compress level when encoding images. Default: 1. 179 | """ 180 | 181 | def __init__(self, 182 | lmdb_path, 183 | map_size=1024**4, 184 | batch=5000, 185 | compress_level=1): 186 | if not lmdb_path.endswith('.lmdb'): 187 | raise ValueError("lmdb_path must end with '.lmdb'.") 188 | if osp.exists(lmdb_path): 189 | print(f'Folder {lmdb_path} already exists. Exit.') 190 | sys.exit(1) 191 | 192 | self.lmdb_path = lmdb_path 193 | self.batch = batch 194 | self.compress_level = compress_level 195 | self.env = lmdb.open(lmdb_path, map_size=map_size) 196 | self.txn = self.env.begin(write=True) 197 | self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') 198 | self.counter = 0 199 | 200 | def put(self, img_byte, key, img_shape): 201 | self.counter += 1 202 | key_byte = key.encode('ascii') 203 | self.txn.put(key_byte, img_byte) 204 | # write meta information 205 | h, w, c = img_shape 206 | self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') 207 | if self.counter % self.batch == 0: 208 | self.txn.commit() 209 | self.txn = self.env.begin(write=True) 210 | 211 | def close(self): 212 | self.txn.commit() 213 | self.env.close() 214 | self.txt_file.close() 215 | -------------------------------------------------------------------------------- /basicsr/utils/logger.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import datetime 8 | import logging 9 | import time 10 | 11 | from .dist_util import get_dist_info, master_only 12 | 13 | 14 | class MessageLogger(): 15 | """Message logger for printing. 16 | 17 | Args: 18 | opt (dict): Config. It contains the following keys: 19 | name (str): Exp name. 20 | logger (dict): Contains 'print_freq' (str) for logger interval. 
21 | train (dict): Contains 'total_iter' (int) for total iters. 22 | use_tb_logger (bool): Use tensorboard logger. 23 | start_iter (int): Start iter. Default: 1. 24 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. 25 | """ 26 | 27 | def __init__(self, opt, start_iter=1, tb_logger=None): 28 | self.exp_name = opt['name'] 29 | self.interval = opt['logger']['print_freq'] 30 | self.start_iter = start_iter 31 | self.max_iters = opt['train']['total_iter'] 32 | self.use_tb_logger = opt['logger']['use_tb_logger'] 33 | self.tb_logger = tb_logger 34 | self.start_time = time.time() 35 | self.logger = get_root_logger() 36 | 37 | @master_only 38 | def __call__(self, log_vars): 39 | """Format logging message. 40 | 41 | Args: 42 | log_vars (dict): It contains the following keys: 43 | epoch (int): Epoch number. 44 | iter (int): Current iter. 45 | lrs (list): List for learning rates. 46 | 47 | time (float): Iter time. 48 | data_time (float): Data time for each iter. 49 | """ 50 | # epoch, iter, learning rates 51 | epoch = log_vars.pop('epoch') 52 | current_iter = log_vars.pop('iter') 53 | lrs = log_vars.pop('lrs') 54 | 55 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' 56 | f'iter:{current_iter:8,d}, lr:(') 57 | for v in lrs: 58 | message += f'{v:.3e},' 59 | message += ')] ' 60 | 61 | # time and estimated time 62 | if 'time' in log_vars.keys(): 63 | iter_time = log_vars.pop('time') 64 | data_time = log_vars.pop('data_time') 65 | 66 | total_time = time.time() - self.start_time 67 | time_sec_avg = total_time / (current_iter - self.start_iter + 1) 68 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) 69 | eta_str = str(datetime.timedelta(seconds=int(eta_sec))) 70 | message += f'[eta: {eta_str}, ' 71 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' 72 | 73 | # other items, especially losses 74 | for k, v in log_vars.items(): 75 | message += f'{k}: {v:.4e} ' 76 | # tensorboard logger 77 | if self.use_tb_logger and 'debug' not in self.exp_name: 78 | if k.startswith('l_'): 79 | self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) 80 | else: 81 | self.tb_logger.add_scalar(k, v, current_iter) 82 | self.logger.info(message) 83 | 84 | 85 | @master_only 86 | def init_tb_logger(log_dir): 87 | from torch.utils.tensorboard import SummaryWriter 88 | tb_logger = SummaryWriter(log_dir=log_dir) 89 | return tb_logger 90 | 91 | 92 | @master_only 93 | def init_wandb_logger(opt): 94 | """We now only use wandb to sync tensorboard log.""" 95 | import wandb 96 | logger = logging.getLogger('basicsr') 97 | 98 | project = opt['logger']['wandb']['project'] 99 | resume_id = opt['logger']['wandb'].get('resume_id') 100 | if resume_id: 101 | wandb_id = resume_id 102 | resume = 'allow' 103 | logger.warning(f'Resume wandb logger with id={wandb_id}.') 104 | else: 105 | wandb_id = wandb.util.generate_id() 106 | resume = 'never' 107 | 108 | wandb.init( 109 | id=wandb_id, 110 | resume=resume, 111 | name=opt['name'], 112 | config=opt, 113 | project=project, 114 | sync_tensorboard=True) 115 | 116 | logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') 117 | 118 | 119 | def get_root_logger(logger_name='basicsr', 120 | log_level=logging.INFO, 121 | log_file=None): 122 | """Get the root logger. 123 | 124 | The logger will be initialized if it has not been initialized. By default a 125 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 126 | also be added. 127 | 128 | Args: 129 | logger_name (str): root logger name. Default: 'basicsr'. 
130 | log_file (str | None): The log filename. If specified, a FileHandler 131 | will be added to the root logger. 132 | log_level (int): The root logger level. Note that only the process of 133 | rank 0 is affected, while other processes will set the level to 134 | "Error" and be silent most of the time. 135 | 136 | Returns: 137 | logging.Logger: The root logger. 138 | """ 139 | logger = logging.getLogger(logger_name) 140 | # if the logger has been initialized, just return it 141 | if logger.hasHandlers(): 142 | return logger 143 | 144 | format_str = '%(asctime)s %(levelname)s: %(message)s' 145 | logging.basicConfig(format=format_str, level=log_level) 146 | rank, _ = get_dist_info() 147 | if rank != 0: 148 | logger.setLevel('ERROR') 149 | elif log_file is not None: 150 | file_handler = logging.FileHandler(log_file, 'w') 151 | file_handler.setFormatter(logging.Formatter(format_str)) 152 | file_handler.setLevel(log_level) 153 | logger.addHandler(file_handler) 154 | 155 | return logger 156 | 157 | 158 | def get_env_info(): 159 | """Get environment information. 160 | 161 | Currently, only log the software version. 162 | """ 163 | import torch 164 | import torchvision 165 | 166 | from basicsr.version import __version__ 167 | msg = r""" 168 | ____ _ _____ ____ 169 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \ 170 | / __ |/ __ `// ___// // ___/\__ \ / /_/ / 171 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ 172 | /_____/ \__,_//____//_/ \___//____//_/ |_| 173 | ______ __ __ __ __ 174 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / / 175 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / 176 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ 177 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) 178 | """ 179 | msg += ('\nVersion Information: ' 180 | f'\n\tBasicSR: {__version__}' 181 | f'\n\tPyTorch: {torch.__version__}' 182 | f'\n\tTorchVision: {torchvision.__version__}') 183 | return msg 184 | -------------------------------------------------------------------------------- /basicsr/utils/misc.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import numpy as np 8 | import os 9 | import random 10 | import time 11 | import torch 12 | from os import path as osp 13 | 14 | from .dist_util import master_only 15 | from .logger import get_root_logger 16 | 17 | 18 | def set_random_seed(seed): 19 | """Set random seeds.""" 20 | random.seed(seed) 21 | np.random.seed(seed) 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed(seed) 24 | torch.cuda.manual_seed_all(seed) 25 | 26 | 27 | def get_time_str(): 28 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 29 | 30 | 31 | def mkdir_and_rename(path): 32 | """mkdirs. If path exists, rename it with timestamp and create a new one. 33 | 34 | Args: 35 | path (str): Folder path. 36 | """ 37 | if osp.exists(path): 38 | new_name = path + '_archived_' + get_time_str() 39 | print(f'Path already exists. 
Rename it to {new_name}', flush=True) 40 | os.rename(path, new_name) 41 | os.makedirs(path, exist_ok=True) 42 | 43 | 44 | @master_only 45 | def make_exp_dirs(opt): 46 | """Make dirs for experiments.""" 47 | path_opt = opt['path'].copy() 48 | if opt['is_train']: 49 | mkdir_and_rename(path_opt.pop('experiments_root')) 50 | else: 51 | mkdir_and_rename(path_opt.pop('results_root')) 52 | for key, path in path_opt.items(): 53 | if ('strict_load' not in key) and ('pretrain_network' 54 | not in key) and ('resume' 55 | not in key): 56 | os.makedirs(path, exist_ok=True) 57 | 58 | 59 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 60 | """Scan a directory to find the interested files. 61 | 62 | Args: 63 | dir_path (str): Path of the directory. 64 | suffix (str | tuple(str), optional): File suffix that we are 65 | interested in. Default: None. 66 | recursive (bool, optional): If set to True, recursively scan the 67 | directory. Default: False. 68 | full_path (bool, optional): If set to True, include the dir_path. 69 | Default: False. 70 | 71 | Returns: 72 | A generator for all the interested files with relative pathes. 73 | """ 74 | 75 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 76 | raise TypeError('"suffix" must be a string or tuple of strings') 77 | 78 | root = dir_path 79 | 80 | def _scandir(dir_path, suffix, recursive): 81 | for entry in os.scandir(dir_path): 82 | if not entry.name.startswith('.') and entry.is_file(): 83 | if full_path: 84 | return_path = entry.path 85 | else: 86 | return_path = osp.relpath(entry.path, root) 87 | 88 | if suffix is None: 89 | yield return_path 90 | elif return_path.endswith(suffix): 91 | yield return_path 92 | else: 93 | if recursive: 94 | yield from _scandir( 95 | entry.path, suffix=suffix, recursive=recursive) 96 | else: 97 | continue 98 | 99 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 100 | 101 | def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False): 102 | """Scan a directory to find the interested files. 103 | 104 | Args: 105 | dir_path (str): Path of the directory. 106 | keywords (str | tuple(str), optional): File keywords that we are 107 | interested in. Default: None. 108 | recursive (bool, optional): If set to True, recursively scan the 109 | directory. Default: False. 110 | full_path (bool, optional): If set to True, include the dir_path. 111 | Default: False. 112 | 113 | Returns: 114 | A generator for all the interested files with relative pathes. 115 | """ 116 | 117 | if (keywords is not None) and not isinstance(keywords, (str, tuple)): 118 | raise TypeError('"keywords" must be a string or tuple of strings') 119 | 120 | root = dir_path 121 | 122 | def _scandir(dir_path, keywords, recursive): 123 | for entry in os.scandir(dir_path): 124 | if not entry.name.startswith('.') and entry.is_file(): 125 | if full_path: 126 | return_path = entry.path 127 | else: 128 | return_path = osp.relpath(entry.path, root) 129 | 130 | if keywords is None: 131 | yield return_path 132 | elif return_path.find(keywords) > 0: 133 | yield return_path 134 | else: 135 | if recursive: 136 | yield from _scandir( 137 | entry.path, keywords=keywords, recursive=recursive) 138 | else: 139 | continue 140 | 141 | return _scandir(dir_path, keywords=keywords, recursive=recursive) 142 | 143 | def check_resume(opt, resume_iter): 144 | """Check resume states and pretrain_network paths. 145 | 146 | Args: 147 | opt (dict): Options. 148 | resume_iter (int): Resume iteration. 
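        Note:
            ``opt`` is modified in place: when a resume state is given, each
            ``pretrain_network_*`` entry is redirected to the checkpoint of
            ``resume_iter`` unless the network is listed in
            ``ignore_resume_networks``.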
149 | """ 150 | logger = get_root_logger() 151 | if opt['path']['resume_state']: 152 | # get all the networks 153 | networks = [key for key in opt.keys() if key.startswith('network_')] 154 | flag_pretrain = False 155 | for network in networks: 156 | if opt['path'].get(f'pretrain_{network}') is not None: 157 | flag_pretrain = True 158 | if flag_pretrain: 159 | logger.warning( 160 | 'pretrain_network path will be ignored during resuming.') 161 | # set pretrained model paths 162 | for network in networks: 163 | name = f'pretrain_{network}' 164 | basename = network.replace('network_', '') 165 | if opt['path'].get('ignore_resume_networks') is None or ( 166 | basename not in opt['path']['ignore_resume_networks']): 167 | opt['path'][name] = osp.join( 168 | opt['path']['models'], f'net_{basename}_{resume_iter}.pth') 169 | logger.info(f"Set {name} to {opt['path'][name]}") 170 | 171 | 172 | def sizeof_fmt(size, suffix='B'): 173 | """Get human readable file size. 174 | 175 | Args: 176 | size (int): File size. 177 | suffix (str): Suffix. Default: 'B'. 178 | 179 | Return: 180 | str: Formated file siz. 181 | """ 182 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 183 | if abs(size) < 1024.0: 184 | return f'{size:3.1f} {unit}{suffix}' 185 | size /= 1024.0 186 | return f'{size:3.1f} Y{suffix}' 187 | -------------------------------------------------------------------------------- /basicsr/utils/options.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2021 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import yaml 8 | from collections import OrderedDict 9 | from os import path as osp 10 | 11 | 12 | def ordered_yaml(): 13 | """Support OrderedDict for yaml. 14 | 15 | Returns: 16 | yaml Loader and Dumper. 17 | """ 18 | try: 19 | from yaml import CDumper as Dumper 20 | from yaml import CLoader as Loader 21 | except ImportError: 22 | from yaml import Dumper, Loader 23 | 24 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 25 | 26 | def dict_representer(dumper, data): 27 | return dumper.represent_dict(data.items()) 28 | 29 | def dict_constructor(loader, node): 30 | return OrderedDict(loader.construct_pairs(node)) 31 | 32 | Dumper.add_representer(OrderedDict, dict_representer) 33 | Loader.add_constructor(_mapping_tag, dict_constructor) 34 | return Loader, Dumper 35 | 36 | 37 | def parse(opt_path, is_train=True): 38 | """Parse option file. 39 | 40 | Args: 41 | opt_path (str): Option file path. 42 | is_train (str): Indicate whether in training or not. Default: True. 43 | 44 | Returns: 45 | (dict): Options. 
46 | """ 47 | with open(opt_path, mode='r') as f: 48 | Loader, _ = ordered_yaml() 49 | opt = yaml.load(f, Loader=Loader) 50 | 51 | opt['is_train'] = is_train 52 | 53 | # datasets 54 | if 'datasets' in opt: 55 | for phase, dataset in opt['datasets'].items(): 56 | # for several datasets, e.g., test_1, test_2 57 | phase = phase.split('_')[0] 58 | dataset['phase'] = phase 59 | if 'scale' in opt: 60 | dataset['scale'] = opt['scale'] 61 | if dataset.get('dataroot_gt') is not None: 62 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 63 | if dataset.get('dataroot_lq') is not None: 64 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 65 | 66 | # paths 67 | for key, val in opt['path'].items(): 68 | if (val is not None) and ('resume_state' in key 69 | or 'pretrain_network' in key): 70 | opt['path'][key] = osp.expanduser(val) 71 | opt['path']['root'] = osp.abspath( 72 | osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) 73 | if is_train: 74 | experiments_root = osp.join(opt['path']['root'], 'experiments', 75 | opt['name']) 76 | opt['path']['experiments_root'] = experiments_root 77 | opt['path']['models'] = osp.join(experiments_root, 'models') 78 | opt['path']['training_states'] = osp.join(experiments_root, 79 | 'training_states') 80 | opt['path']['log'] = experiments_root 81 | opt['path']['visualization'] = osp.join(experiments_root, 82 | 'visualization') 83 | 84 | # change some options for debug mode 85 | if 'debug' in opt['name']: 86 | if 'val' in opt: 87 | opt['val']['val_freq'] = 8 88 | opt['logger']['print_freq'] = 1 89 | opt['logger']['save_checkpoint_freq'] = 8 90 | else: # test 91 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 92 | opt['path']['results_root'] = results_root 93 | opt['path']['log'] = results_root 94 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 95 | 96 | return opt 97 | 98 | 99 | def dict2str(opt, indent_level=1): 100 | """dict to string for printing options. 101 | 102 | Args: 103 | opt (dict): Option dict. 104 | indent_level (int): Indent level. Default: 1. 105 | 106 | Return: 107 | (str): Option string for printing. 
108 | """ 109 | msg = '\n' 110 | for k, v in opt.items(): 111 | if isinstance(v, dict): 112 | msg += ' ' * (indent_level * 2) + k + ':[' 113 | msg += dict2str(v, indent_level + 1) 114 | msg += ' ' * (indent_level * 2) + ']\n' 115 | else: 116 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 117 | return msg 118 | -------------------------------------------------------------------------------- /basicsr/version.py: -------------------------------------------------------------------------------- 1 | # GENERATED VERSION FILE 2 | # TIME: Mon May 21 02:15:04 2024 3 | __version__ = '1.0.0+ebd1331' 4 | short_version = '1.0.0' 5 | version_info = (1, 0, 0) 6 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | cuda: "11.3" 3 | gpu: true 4 | python_version: "3.9" 5 | system_packages: 6 | - "libgl1-mesa-glx" 7 | - "libglib2.0-0" 8 | python_packages: 9 | - "numpy==1.21.1" 10 | - "ipython==7.21.0" 11 | - "addict==2.4.0" 12 | - "future==0.18.2" 13 | - "lmdb==1.3.0" 14 | - "opencv-python==4.5.5.64" 15 | - "Pillow==9.1.0" 16 | - "pyyaml==6.0" 17 | - "torch==1.11.0" 18 | - "torchvision==0.12.0" 19 | - "tqdm==4.64.0" 20 | - "scipy==1.8.0" 21 | - "scikit-image==0.19.2" 22 | - "matplotlib==3.5.1" 23 | 24 | predict: "predict.py:Predictor" 25 | -------------------------------------------------------------------------------- /make_video.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | 4 | 5 | """ 6 | This script creates a side-by-side comparison video from pairs of input and predicted frames stored in a directory. 7 | A sliding line moves across the frames to visually compare the differences, and the resulting video is saved to an output file. 
8 | """ 9 | 10 | 11 | # Directory path for Input and Restored frames 12 | frames_dir = 'path to the low quality ad=nd high quality frames' 13 | 14 | # Output video parameters 15 | output_video_path = 'path_to_save_video/x.mp4' 16 | fps = 20 # Set the frames per second for the output video 17 | 18 | 19 | 20 | # Initialize video writer 21 | frame_example = cv2.imread(os.path.join(frames_dir, os.listdir(frames_dir)[1])) 22 | height, width, layers = frame_example.shape 23 | print(height, width) 24 | out = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)) 25 | 26 | # Get sorted list of frame filenames 27 | all_files = os.listdir(frames_dir) 28 | input_frames = sorted([f for f in all_files if 'Input' in f], key=lambda x: int(x.split('_')[1])) 29 | pred_frames = sorted([f for f in all_files if 'Pred' in f], key=lambda x: int(x.split('_')[1])) 30 | 31 | # Total number of frames 32 | total_frames = min(len(input_frames), len(pred_frames)) 33 | 34 | for i in range(total_frames): 35 | # Construct the filenames based on the sorted lists 36 | low_quality_filename = input_frames[i] 37 | high_quality_filename = pred_frames[i] 38 | 39 | # Read frames 40 | frame1 = cv2.imread(os.path.join(frames_dir, low_quality_filename)) 41 | frame2 = cv2.imread(os.path.join(frames_dir, high_quality_filename)) 42 | # print(frame1.shape,frame2.shape) 43 | 44 | # Compute the position of the sliding line 45 | slider_position = int((i / total_frames) * width) 46 | 47 | # Create a combined frame 48 | combined_frame = frame1.copy() 49 | combined_frame[:, :slider_position] = frame2[:, :slider_position] 50 | 51 | # Draw the sliding line 52 | cv2.line(combined_frame, (slider_position, 0), (slider_position, height), (0, 255, 0), 2) 53 | 54 | # Write the combined frame to the output video 55 | out.write(combined_frame) 56 | 57 | # Release video writer 58 | out.release() 59 | 60 | print("Video has been created and saved as", output_video_path) 61 | -------------------------------------------------------------------------------- /options/Turtle_Deblur_Gopro.yml: -------------------------------------------------------------------------------- 1 | name: Final_Gaia_Gopro 2 | model_type: VideoRestorationModel 3 | scale: 1 4 | num_gpu: 8 5 | manual_seed: 10 6 | n_sequence: 5 # n_frames 7 | dir_data: ['/home/amir/datasets/GoPro/train/'] 8 | n_colors: 3 9 | rgb_range: 1 10 | no_augment: False 11 | loss_type: 1*L1 12 | patch_size: 192 13 | size_must_mode: 4 14 | model: Turtle_t1_arch 15 | pretrain_models_dir: None 16 | type: deblurring 17 | dim: 64 18 | Enc_blocks: [2, 6, 10] 19 | Middle_blocks: 11 20 | Dec_blocks: [10, 6, 2] 21 | num_refinement_blocks: 2 22 | use_both_input: False 23 | num_heads: [1, 2, 4, 8] 24 | num_frames_tocache: 3 25 | ffn_expansion_factor: 2.5 26 | 27 | encoder1_attn_type1 : "ReducedAttn" 28 | encoder1_attn_type2 : "ReducedAttn" 29 | encoder1_ffw_type : "FFW" 30 | 31 | encoder2_attn_type1 : "ReducedAttn" 32 | encoder2_attn_type2 : "ReducedAttn" 33 | encoder2_ffw_type : "FFW" 34 | 35 | encoder3_attn_type1 : "Channel" 36 | encoder3_attn_type2 : "Channel" 37 | encoder3_ffw_type : "GFFW" 38 | 39 | decoder1_attn_type1 : "Channel" 40 | decoder1_attn_type2 : "CHM" 41 | decoder1_ffw_type : "GFFW" 42 | 43 | decoder2_attn_type1 : "Channel" 44 | decoder2_attn_type2 : "CHM" 45 | decoder2_ffw_type : "GFFW" 46 | 47 | decoder3_attn_type1 : "Channel" 48 | decoder3_attn_type2 : "CHM" 49 | decoder3_ffw_type : "GFFW" 50 | 51 | latent_attn_type1 : "FHR" 52 | latent_attn_type2 : "Channel" 53 | 
latent_attn_type3 : "FHR" 54 | latent_ffw_type : "GFFW" 55 | 56 | refinement_attn_type1 : "ReducedAttn" 57 | refinement_attn_type2 : "ReducedAttn" 58 | refinement_ffw_type : "GFFW" 59 | 60 | datasets: 61 | train: 62 | name: gopro-train 63 | filename_tmpl: '{}' 64 | io_backend: 65 | type: lmdb 66 | 67 | gt_size: 192 68 | use_flip: false 69 | use_rot: false 70 | 71 | # data loader 72 | use_shuffle: true 73 | num_worker_per_gpu: 8 74 | batch_size_per_gpu: 2 75 | dataset_enlarge_ratio: 1 76 | prefetch_mode: ~ 77 | 78 | val: 79 | name: gopro-test 80 | dir_data: ['/home/amir/datasets/GoPro/test/'] 81 | 82 | path: 83 | pretrain_network_g: ~ 84 | strict_load_g: true 85 | resume_state: ~ 86 | 87 | train: 88 | optim_g: 89 | type: Adam 90 | lr: !!float 4e-4 91 | weight_decay: 0 92 | betas: [0.9, 0.99] 93 | 94 | scheduler: 95 | type: TrueCosineAnnealingLR 96 | T_max: 200000 97 | eta_min: !!float 1e-7 98 | 99 | total_iter: 200000 100 | warmup_iter: -1 # no warm up 101 | 102 | # losses 103 | pixel_opt: 104 | type: L1Loss 105 | loss_weight: 1 106 | reduction: mean 107 | 108 | # validation settings 109 | val: 110 | val_freq: 10000 111 | save_img: true 112 | grids: true 113 | crop_size: 192 114 | max_minibatch: 8 115 | 116 | metrics: 117 | psnr: # metric name, can be arbitrary 118 | type: calculate_psnr 119 | crop_border: 0 120 | test_y_channel: false 121 | 122 | # logging settings 123 | logger: 124 | print_freq: 200 125 | save_checkpoint_freq: !!float 10000 126 | use_tb_logger: true 127 | wandb: 128 | project: ~ 129 | resume_id: ~ 130 | 131 | # dist training settings 132 | dist_params: 133 | backend: nccl 134 | port: 29500 -------------------------------------------------------------------------------- /options/Turtle_Denoise_Davis.yml: -------------------------------------------------------------------------------- 1 | name: Gaia_Denoise_Davis_noise20-50 2 | model_type: VideoRestorationModel 3 | scale: 1 4 | num_gpu: 8 5 | manual_seed: 10 6 | n_sequence: 5 # n_frames 7 | dir_data: ['/datasets/DAVIS/JPEGImages/'] 8 | n_colors: 3 9 | rgb_range: 1 10 | no_augment: False 11 | loss_type: 1*L1 12 | patch_size: 192 13 | size_must_mode: 4 14 | model: Turtle_t1_arch 15 | 16 | 17 | pretrain_models_dir: None 18 | type: denoising 19 | dim: 64 20 | Enc_blocks: [2, 6, 10] 21 | Middle_blocks: 11 22 | Dec_blocks: [10, 6, 2] 23 | num_refinement_blocks: 2 24 | use_both_input: False 25 | num_heads: [1, 2, 4, 8] 26 | num_frames_tocache: 3 27 | ffn_expansion_factor: 2.5 28 | 29 | encoder1_attn_type1 : "ReducedAttn" 30 | encoder1_attn_type2 : "ReducedAttn" 31 | encoder1_ffw_type : "FFW" 32 | 33 | encoder2_attn_type1 : "ReducedAttn" 34 | encoder2_attn_type2 : "ReducedAttn" 35 | encoder2_ffw_type : "FFW" 36 | 37 | encoder3_attn_type1 : "Channel" 38 | encoder3_attn_type2 : "Channel" 39 | encoder3_ffw_type : "GFFW" 40 | 41 | decoder1_attn_type1 : "Channel" 42 | decoder1_attn_type2 : "CHM" 43 | decoder1_ffw_type : "GFFW" 44 | 45 | decoder2_attn_type1 : "Channel" 46 | decoder2_attn_type2 : "CHM" 47 | decoder2_ffw_type : "GFFW" 48 | 49 | decoder3_attn_type1 : "Channel" 50 | decoder3_attn_type2 : "CHM" 51 | decoder3_ffw_type : "GFFW" 52 | 53 | latent_attn_type1 : "FHR" 54 | latent_attn_type2 : "Channel" 55 | latent_attn_type3 : "FHR" 56 | latent_ffw_type : "GFFW" 57 | 58 | refinement_attn_type1 : "ReducedAttn" 59 | refinement_attn_type2 : "ReducedAttn" 60 | refinement_ffw_type : "GFFW" 61 | 62 | prompt_attn: "NoAttn" 63 | prompt_ffw: "GFFW" 64 | 65 | datasets: 66 | train: 67 | name: rsvd-train 68 | filename_tmpl: '{}' 69 | 
io_backend: 70 | type: lmdb 71 | 72 | gt_size: 192 73 | use_flip: false 74 | use_rot: false 75 | 76 | # data loader 77 | use_shuffle: true 78 | num_worker_per_gpu: 8 79 | batch_size_per_gpu: 2 80 | dataset_enlarge_ratio: 1 81 | prefetch_mode: ~ 82 | 83 | val: 84 | name: davis 85 | dir_data: ['/datasets/DAVIS_testdev/DAVIS/JPEGImages/'] 86 | 87 | path: 88 | pretrain_network_g: ~ 89 | resume_state: ~ 90 | 91 | train: 92 | optim_g: 93 | type: Adam 94 | lr: !!float 4e-4 95 | weight_decay: 0 96 | betas: [0.9, 0.99] 97 | 98 | scheduler: 99 | type: TrueCosineAnnealingLR 100 | T_max: 250000 101 | eta_min: !!float 1e-7 102 | 103 | total_iter: 250000 104 | warmup_iter: -1 # no warm up 105 | 106 | # losses 107 | pixel_opt: 108 | type: L1Loss 109 | loss_weight: 1 110 | reduction: mean 111 | 112 | # validation settings 113 | val: 114 | val_freq: 10000 115 | save_img: true 116 | grids: true 117 | crop_size: 192 118 | max_minibatch: 8 119 | 120 | metrics: 121 | psnr: # metric name, can be arbitrary 122 | type: calculate_psnr 123 | crop_border: 0 124 | test_y_channel: false 125 | 126 | # logging settings 127 | logger: 128 | print_freq: 200 129 | save_checkpoint_freq: !!float 10000 130 | use_tb_logger: true 131 | wandb: 132 | project: ~ 133 | resume_id: ~ 134 | 135 | # dist training settings 136 | dist_params: 137 | backend: nccl 138 | port: 29500 139 | -------------------------------------------------------------------------------- /options/Turtle_Derain.yml: -------------------------------------------------------------------------------- 1 | name: Turtle_Derain 2 | model_type: VideoRestorationModel 3 | scale: 1 4 | num_gpu: 8 5 | manual_seed: 10 6 | n_sequence: 5 # n_frames 7 | dir_data: ['/datasets/NightRain/train/'] 8 | n_colors: 3 9 | rgb_range: 1 10 | no_augment: False 11 | loss_type: 1*L1 12 | patch_size: 192 13 | size_must_mode: 4 14 | model: Turtle_arch 15 | pretrain_models_dir: None 16 | type: deraining 17 | dim: 64 18 | Enc_blocks: [2, 6, 10] 19 | Middle_blocks: 11 20 | Dec_blocks: [10, 6, 2] 21 | num_refinement_blocks: 2 22 | use_both_input: False 23 | num_heads: [1, 2, 4, 8] 24 | num_frames_tocache: 3 25 | ffn_expansion_factor: 2.5 26 | 27 | 28 | encoder1_attn_type1 : "ReducedAttn" 29 | encoder1_attn_type2 : "ReducedAttn" 30 | encoder1_ffw_type : "FFW" 31 | 32 | encoder2_attn_type1 : "ReducedAttn" 33 | encoder2_attn_type2 : "ReducedAttn" 34 | encoder2_ffw_type : "FFW" 35 | 36 | encoder3_attn_type1 : "Channel" 37 | encoder3_attn_type2 : "Channel" 38 | encoder3_ffw_type : "GFFW" 39 | 40 | decoder1_attn_type1 : "Channel" 41 | decoder1_attn_type2 : "CHM" 42 | decoder1_ffw_type : "GFFW" 43 | 44 | decoder2_attn_type1 : "Channel" 45 | decoder2_attn_type2 : "CHM" 46 | decoder2_ffw_type : "GFFW" 47 | 48 | decoder3_attn_type1 : "Channel" 49 | decoder3_attn_type2 : "CHM" 50 | decoder3_ffw_type : "GFFW" 51 | 52 | latent_attn_type1 : "FHR" 53 | latent_attn_type2 : "Channel" 54 | latent_attn_type3 : "FHR" 55 | latent_ffw_type : "GFFW" 56 | 57 | refinement_attn_type1 : "ReducedAttn" 58 | refinement_attn_type2 : "ReducedAttn" 59 | refinement_ffw_type : "GFFW" 60 | 61 | 62 | datasets: 63 | train: 64 | name: ngtrain-train 65 | filename_tmpl: '{}' 66 | io_backend: 67 | type: lmdb 68 | 69 | gt_size: 192 70 | use_flip: false 71 | use_rot: false 72 | 73 | # data loader 74 | use_shuffle: true 75 | num_worker_per_gpu: 8 76 | batch_size_per_gpu: 2 77 | dataset_enlarge_ratio: 1 78 | prefetch_mode: ~ 79 | 80 | val: 81 | name: ngtrain-test 82 | dir_data: ['/datasets/NightRain/test/'] 83 | 84 | path: 85 | 
pretrain_network_g: ~ 86 | strict_load_g: true 87 | resume_state: ~ 88 | 89 | train: 90 | optim_g: 91 | type: Adam 92 | lr: !!float 4e-4 93 | weight_decay: 0 94 | betas: [0.9, 0.99] 95 | 96 | scheduler: 97 | type: TrueCosineAnnealingLR 98 | T_max: 200000 99 | eta_min: !!float 1e-7 100 | 101 | total_iter: 200000 102 | warmup_iter: -1 # no warm up 103 | 104 | # losses 105 | pixel_opt: 106 | type: L1Loss 107 | loss_weight: 1 108 | reduction: mean 109 | 110 | # validation settings 111 | val: 112 | val_freq: 50000 113 | save_img: true 114 | grids: true 115 | crop_size: 192 116 | max_minibatch: 8 117 | 118 | metrics: 119 | psnr: # metric name, can be arbitrary 120 | type: calculate_psnr 121 | crop_border: 0 122 | test_y_channel: false 123 | 124 | # logging settings 125 | logger: 126 | print_freq: 200 127 | save_checkpoint_freq: !!float 10000 128 | use_tb_logger: true 129 | wandb: 130 | project: ~ 131 | resume_id: ~ 132 | 133 | # dist training settings 134 | dist_params: 135 | backend: nccl 136 | port: 29500 -------------------------------------------------------------------------------- /options/Turtle_Derain_VRDS.yml: -------------------------------------------------------------------------------- 1 | name: Turtle_Derain 2 | model_type: VideoRestorationModel 3 | scale: 1 4 | num_gpu: 8 5 | manual_seed: 10 6 | n_sequence: 5 # n_frames 7 | dir_data: ['/datasets/VRDS/train/'] 8 | n_colors: 3 9 | rgb_range: 1 10 | no_augment: False 11 | loss_type: 1*L1 12 | patch_size: 192 13 | size_must_mode: 4 14 | model: Turtle_t1_arch 15 | pretrain_models_dir: None 16 | type: deraining 17 | dim: 64 18 | Enc_blocks: [2, 6, 10] 19 | Middle_blocks: 11 20 | Dec_blocks: [10, 6, 2] 21 | num_refinement_blocks: 2 22 | use_both_input: False 23 | num_heads: [1, 2, 4, 8] 24 | num_frames_tocache: 3 25 | ffn_expansion_factor: 2.5 26 | 27 | encoder1_attn_type1 : "ReducedAttn" 28 | encoder1_attn_type2 : "ReducedAttn" 29 | encoder1_ffw_type : "FFW" 30 | 31 | encoder2_attn_type1 : "ReducedAttn" 32 | encoder2_attn_type2 : "ReducedAttn" 33 | encoder2_ffw_type : "FFW" 34 | 35 | encoder3_attn_type1 : "Channel" 36 | encoder3_attn_type2 : "Channel" 37 | encoder3_ffw_type : "GFFW" 38 | 39 | decoder1_attn_type1 : "Channel" 40 | decoder1_attn_type2 : "CHM" 41 | decoder1_ffw_type : "GFFW" 42 | 43 | decoder2_attn_type1 : "Channel" 44 | decoder2_attn_type2 : "CHM" 45 | decoder2_ffw_type : "GFFW" 46 | 47 | decoder3_attn_type1 : "Channel" 48 | decoder3_attn_type2 : "CHM" 49 | decoder3_ffw_type : "GFFW" 50 | 51 | latent_attn_type1 : "FHR" 52 | latent_attn_type2 : "Channel" 53 | latent_attn_type3 : "FHR" 54 | latent_ffw_type : "GFFW" 55 | 56 | refinement_attn_type1 : "ReducedAttn" 57 | refinement_attn_type2 : "ReducedAttn" 58 | refinement_ffw_type : "GFFW" 59 | 60 | 61 | datasets: 62 | train: 63 | name: VRDS-train 64 | filename_tmpl: '{}' 65 | io_backend: 66 | type: lmdb 67 | 68 | gt_size: 192 69 | use_flip: false 70 | use_rot: false 71 | 72 | # data loader 73 | use_shuffle: true 74 | num_worker_per_gpu: 8 75 | batch_size_per_gpu: 2 76 | dataset_enlarge_ratio: 1 77 | prefetch_mode: ~ 78 | 79 | val: 80 | name: VRDS-test 81 | dir_data: ['/datasets/VRDS/test/'] 82 | 83 | path: 84 | pretrain_network_g: ~ 85 | strict_load_g: true 86 | resume_state: ~ 87 | 88 | train: 89 | optim_g: 90 | type: Adam 91 | lr: !!float 4e-4 92 | weight_decay: 0 93 | betas: [0.9, 0.99] 94 | 95 | scheduler: 96 | type: TrueCosineAnnealingLR 97 | T_max: 200000 98 | eta_min: !!float 1e-7 99 | 100 | total_iter: 200000 101 | warmup_iter: -1 # no warm up 102 | 103 | # losses 104 
| pixel_opt: 105 | type: L1Loss 106 | loss_weight: 1 107 | reduction: mean 108 | 109 | # validation settings 110 | val: 111 | val_freq: 50000 112 | save_img: true 113 | grids: true 114 | crop_size: 192 115 | max_minibatch: 8 116 | 117 | metrics: 118 | psnr: # metric name, can be arbitrary 119 | type: calculate_psnr 120 | crop_border: 0 121 | test_y_channel: false 122 | 123 | # logging settings 124 | logger: 125 | print_freq: 200 126 | save_checkpoint_freq: !!float 10000 127 | use_tb_logger: true 128 | wandb: 129 | project: ~ 130 | resume_id: ~ 131 | 132 | # dist training settings 133 | dist_params: 134 | backend: nccl 135 | port: 29500 -------------------------------------------------------------------------------- /options/Turtle_Desnow.yml: -------------------------------------------------------------------------------- 1 | name: Turtle_Desnow 2 | model_type: VideoRestorationModel 3 | scale: 1 4 | num_gpu: 8 5 | manual_seed: 10 6 | n_sequence: 5 # n_frames 7 | dir_data: ['/datasets/Desnowing/rsvd/train/'] 8 | n_colors: 3 9 | rgb_range: 1 10 | no_augment: False 11 | loss_type: 1*L1 12 | patch_size: 192 13 | size_must_mode: 4 14 | model: Turtle_arch 15 | 16 | pretrain_models_dir: None 17 | type: desnowing 18 | dim: 64 19 | Enc_blocks: [2, 6, 10] 20 | Middle_blocks: 11 21 | Dec_blocks: [10, 6, 2] 22 | num_refinement_blocks: 2 23 | use_both_input: False 24 | num_heads: [1, 2, 4, 8] 25 | num_frames_tocache: 3 26 | ffn_expansion_factor: 2.5 27 | 28 | encoder1_attn_type1 : "ReducedAttn" 29 | encoder1_attn_type2 : "ReducedAttn" 30 | encoder1_ffw_type : "FFW" 31 | 32 | encoder2_attn_type1 : "ReducedAttn" 33 | encoder2_attn_type2 : "ReducedAttn" 34 | encoder2_ffw_type : "FFW" 35 | 36 | encoder3_attn_type1 : "Channel" 37 | encoder3_attn_type2 : "Channel" 38 | encoder3_ffw_type : "GFFW" 39 | 40 | decoder1_attn_type1 : "Channel" 41 | decoder1_attn_type2 : "CHM" 42 | decoder1_ffw_type : "GFFW" 43 | 44 | decoder2_attn_type1 : "Channel" 45 | decoder2_attn_type2 : "CHM" 46 | decoder2_ffw_type : "GFFW" 47 | 48 | decoder3_attn_type1 : "Channel" 49 | decoder3_attn_type2 : "CHM" 50 | decoder3_ffw_type : "GFFW" 51 | 52 | latent_attn_type1 : "FHR" 53 | latent_attn_type2 : "Channel" 54 | latent_attn_type3 : "FHR" 55 | latent_ffw_type : "GFFW" 56 | 57 | refinement_attn_type1 : "ReducedAttn" 58 | refinement_attn_type2 : "ReducedAttn" 59 | refinement_ffw_type : "GFFW" 60 | 61 | datasets: 62 | train: 63 | name: rsvd-train 64 | filename_tmpl: '{}' 65 | io_backend: 66 | type: lmdb 67 | 68 | gt_size: 192 69 | use_flip: false 70 | use_rot: false 71 | 72 | # data loader 73 | use_shuffle: true 74 | num_worker_per_gpu: 8 75 | batch_size_per_gpu: 2 76 | dataset_enlarge_ratio: 1 77 | prefetch_mode: ~ 78 | 79 | val: 80 | name: rsvd-test 81 | dir_data: ['/datasets/Desnowing/rsvd/test/'] 82 | 83 | path: 84 | pretrain_network_g: ~ 85 | strict_load_g: true 86 | resume_state: ~ 87 | train: 88 | optim_g: 89 | type: Adam 90 | lr: !!float 4e-4 91 | weight_decay: 0 92 | betas: [0.9, 0.99] 93 | 94 | scheduler: 95 | type: TrueCosineAnnealingLR 96 | T_max: 250000 97 | eta_min: !!float 1e-7 98 | 99 | total_iter: 250000 100 | warmup_iter: -1 # no warm up 101 | 102 | # losses 103 | pixel_opt: 104 | type: L1Loss 105 | loss_weight: 1 106 | reduction: mean 107 | 108 | # validation settings 109 | val: 110 | val_freq: 20000 111 | save_img: true 112 | grids: true 113 | crop_size: 192 114 | max_minibatch: 8 115 | 116 | metrics: 117 | psnr: # metric name, can be arbitrary 118 | type: calculate_psnr 119 | crop_border: 0 120 | test_y_channel: false 
121 | 122 | # logging settings 123 | logger: 124 | print_freq: 200 125 | save_checkpoint_freq: !!float 10000 126 | use_tb_logger: true 127 | wandb: 128 | project: ~ 129 | resume_id: ~ 130 | 131 | # dist training settings 132 | dist_params: 133 | backend: nccl 134 | port: 29500 -------------------------------------------------------------------------------- /options/Turtle_SR_MVSR.yml: -------------------------------------------------------------------------------- 1 | name: Turtle_SR_MVSR 2 | model_type: VideoRestorationModel 3 | scale: 1 4 | num_gpu: 8 5 | manual_seed: 10 6 | n_sequence: 5 # n_frames 7 | 8 | dir_data: ['/datasets/MVSR4x/train/'] 9 | 10 | n_colors: 3 11 | rgb_range: 1 12 | no_augment: False 13 | loss_type: 1*L1 14 | patch_size: 192 15 | size_must_mode: 4 16 | model: Turtlesuper_t1_arch 17 | pretrain_models_dir: None 18 | type: superresolution 19 | dim: 64 20 | Enc_blocks: [2, 6, 10] 21 | Middle_blocks: 11 22 | Dec_blocks: [10, 6, 2] 23 | num_refinement_blocks: 2 24 | use_both_input: False 25 | num_heads: [1, 2, 4, 8] 26 | num_frames_tocache: 3 27 | ffn_expansion_factor: 2.5 28 | 29 | encoder1_attn_type1 : "ReducedAttn" 30 | encoder1_attn_type2 : "ReducedAttn" 31 | encoder1_ffw_type : "FFW" 32 | 33 | encoder2_attn_type1 : "ReducedAttn" 34 | encoder2_attn_type2 : "ReducedAttn" 35 | encoder2_ffw_type : "FFW" 36 | 37 | encoder3_attn_type1 : "Channel" 38 | encoder3_attn_type2 : "Channel" 39 | encoder3_ffw_type : "GFFW" 40 | 41 | decoder1_attn_type1 : "Channel" 42 | decoder1_attn_type2 : "CHM" 43 | decoder1_ffw_type : "GFFW" 44 | 45 | decoder2_attn_type1 : "Channel" 46 | decoder2_attn_type2 : "CHM" 47 | decoder2_ffw_type : "GFFW" 48 | 49 | decoder3_attn_type1 : "Channel" 50 | decoder3_attn_type2 : "CHM" 51 | decoder3_ffw_type : "GFFW" 52 | 53 | latent_attn_type1 : "FHR" 54 | latent_attn_type2 : "Channel" 55 | latent_attn_type3 : "FHR" 56 | latent_ffw_type : "GFFW" 57 | 58 | refinement_attn_type1 : "ReducedAttn" 59 | refinement_attn_type2 : "ReducedAttn" 60 | refinement_ffw_type : "GFFW" 61 | 62 | prompt_attn: "NoAttn" 63 | prompt_ffw: "GFFW" 64 | 65 | datasets: 66 | train: 67 | name: mvsr-train 68 | filename_tmpl: '{}' 69 | io_backend: 70 | type: lmdb 71 | 72 | gt_size: 192 73 | use_flip: false 74 | use_rot: false 75 | 76 | # data loader 77 | use_shuffle: true 78 | num_worker_per_gpu: 8 79 | batch_size_per_gpu: 2 80 | dataset_enlarge_ratio: 1 81 | prefetch_mode: ~ 82 | 83 | val: 84 | name: mvsr-test 85 | dir_data: ['/datasets/MVSR4x/test/'] 86 | 87 | path: 88 | pretrain_network_g: ~ 89 | strict_load_g: true 90 | resume_state: ~ 91 | 92 | train: 93 | optim_g: 94 | type: Adam 95 | lr: !!float 4e-4 96 | weight_decay: 0 97 | betas: [0.9, 0.99] 98 | 99 | scheduler: 100 | type: TrueCosineAnnealingLR 101 | T_max: 200000 102 | eta_min: !!float 1e-7 103 | 104 | total_iter: 200000 105 | warmup_iter: -1 # no warm up 106 | 107 | # losses 108 | pixel_opt: 109 | type: L1Loss 110 | loss_weight: 1 111 | reduction: mean 112 | 113 | # validation settings 114 | val: 115 | val_freq: 10000 116 | save_img: true 117 | grids: true 118 | crop_size: 192 119 | max_minibatch: 8 120 | 121 | metrics: 122 | psnr: # metric name, can be arbitrary 123 | type: calculate_psnr 124 | crop_border: 0 125 | test_y_channel: false 126 | 127 | # logging settings 128 | logger: 129 | print_freq: 200 130 | save_checkpoint_freq: !!float 5000 131 | use_tb_logger: true 132 | wandb: 133 | project: ~ 134 | resume_id: ~ 135 | 136 | # dist training settings 137 | dist_params: 138 | backend: nccl 139 | port: 29500 140 | 
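All six option files above describe the network with the same set of "Model Forge" fields: the block counts (`Enc_blocks`, `Middle_blocks`, `Dec_blocks`, `num_refinement_blocks`) and, for every encoder, decoder, latent, and refinement stage, the attention and feed-forward block types. As a rough, editorial illustration only (this is not a script shipped with the repository), the sketch below reads one of these YAML files with PyYAML, which `requirements.txt` already lists, and prints that per-stage layout; the chosen path and the printed summary are assumptions for the example, and the actual training and inference code parses these options through `basicsr/utils/options.py`.

```python
# Illustrative sketch: inspect the "Model Forge" layout of a Turtle option file.
# Assumes it is run from the repository root; not part of the repository itself.
import yaml

with open('options/Turtle_Deblur_Gopro.yml', 'r') as f:
    opt = yaml.safe_load(f)  # PyYAML resolves `!!float 4e-4` and `~` (None) on its own

print('arch:', opt['model'], '| dim:', opt['dim'])
print('blocks (enc / middle / dec):', opt['Enc_blocks'], opt['Middle_blocks'], opt['Dec_blocks'])

# Collect which attention / feed-forward block each stage uses.
suffixes = ('_attn_type1', '_attn_type2', '_attn_type3', '_ffw_type')
layout = {k: v for k, v in opt.items() if k.endswith(suffixes)}
for stage, block in sorted(layout.items()):
    print(f'{stage:28s} -> {block}')
```

Swapping a block type (for example, replacing a "Channel" entry with "ReducedAttn") is then a one-line change to the YAML, which is what the readme's Model Forge note refers to.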
-------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/learning-truncated-causal-history-model-for/deblurring-on-beam-splitter-deblurring-bsd)](https://paperswithcode.com/sota/deblurring-on-beam-splitter-deblurring-bsd?p=learning-truncated-causal-history-model-for) 2 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/learning-truncated-causal-history-model-for/rain-removal-on-nighrain)](https://paperswithcode.com/sota/rain-removal-on-nighrain?p=learning-truncated-causal-history-model-for) 3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/learning-truncated-causal-history-model-for/snow-removal-on-rvsd)](https://paperswithcode.com/sota/snow-removal-on-rvsd?p=learning-truncated-causal-history-model-for) 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/learning-truncated-causal-history-model-for/video-deraining-on-vrds)](https://paperswithcode.com/sota/video-deraining-on-vrds?p=learning-truncated-causal-history-model-for) 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/learning-truncated-causal-history-model-for/video-denoising-on-set8-sigma50)](https://paperswithcode.com/sota/video-denoising-on-set8-sigma50?p=learning-truncated-causal-history-model-for) 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/learning-truncated-causal-history-model-for/deblurring-on-gopro)](https://paperswithcode.com/sota/deblurring-on-gopro?p=learning-truncated-causal-history-model-for) 7 | 8 | 9 | 10 | # Lego Turtle **Turtle: Learning Truncated Causal History Model for Video Restoration [NeurIPS'2024]** 11 | 12 | [📄 arxiv](https://arxiv.org/abs/2410.03936) 13 | **|** 14 | [🌐 Website](https://kjanjua26.github.io/turtle/) 15 | 16 | The official PyTorch implementation for **Learning Truncated Causal History Model for Video Restoration**, accepted to NeurIPS 2024. 17 | 18 | - Turtle achieves state-of-the-art results on multiple video restoration benchmarks, offering superior computational efficiency and enhanced restoration quality 🔥🔥🔥. 19 | - **🛠️💡Model Forge**: Easily design your own architecture by modifying the option file. 20 | - You have the flexibility to choose from various types of layers—such as channel attention, simple channel attention, CHM, FHR, or custom blocks—as well as different types of feed-forward layers. 21 | - This setup allows you to create custom networks and experiment with layer and feed-forward configurations to suit your needs. 22 | - If you like this project, please give us a ⭐ on Github!🚀 23 | 24 |

25 | Restored video demos: Restored Video 1 · Restored Video 2 · Restored Video 3 · Restored Video 4

33 | 34 | ### 🔥 📰 News 🔥 35 | - Oct. 10, 2024: The paper is now available on [arxiv](http://export.arxiv.org/abs/2410.03936) along with the code and pretrained models. 36 | - Sept. 25, 2024: Turtle is accepted to NeurIPS'2024. 37 | 38 | 39 | ## Table of Contents 40 | 1. [Installation](#installation) 41 | 2. [Trained Models](#trained-models) 42 | 3. [Dataset Preparation](#1-dataset-preparation) 43 | 4. [Training](#2-training) 44 | 5. [Evaluation](#3-evaluation) 45 | - [Testing the Model](#31-testing-the-model) 46 | - [Inference on Given Videos](#32-running-turtle-on-custom-videos) 47 | 6. [Model Complexity and Inference Speed](#4-model-complexity-and-inference-speed) 48 | 7. [Acknowledgments](#acknowledgments) 49 | 8. [Citation](#citation) 50 | 51 | 52 | ## Installation 53 | This implementation is based on [BasicSR](https://github.com/xinntao/BasicSR), an open-source toolbox for image/video restoration tasks. 54 | 55 | ```python 56 | python 3.9.5 57 | pytorch 1.11.0 58 | cuda 11.3 59 | ``` 60 | 61 | ``` 62 | pip install -r requirements.txt 63 | python setup.py develop --no_cuda_ext 64 | ``` 65 | 66 | ## Trained Models 67 | 68 | You can download our trained models from Google Drive: [Trained Models](https://drive.google.com/drive/folders/1Mur4IboaNgEW5qyynTIHq8CSAGtyykrA?usp=sharing) 69 | 70 | 71 | ## 1. Dataset Preparation 72 | To obtain the datasets, follow the official instructions provided by each dataset's provider and download them into the dataset folder. You can download the datasets for each of the tasks from the following links (official sources reported by their respective authors). 73 | 74 | 1. Desnowing: [RSVD](https://haoyuchen.com/VideoDesnowing) 75 | 2. Raindrops and Rainstreaks Removal: [VRDS](https://hkustgz-my.sharepoint.com/personal/hwu375_connect_hkust-gz_edu_cn/_layouts/15/onedrive.aspx?id=%2Fpersonal%2Fhwu375%5Fconnect%5Fhkust%2Dgz%5Fedu%5Fcn%2FDocuments%2FVRDS&ga=1) 76 | 3. Night Deraining: [NightRain](https://drive.google.com/drive/folders/1zsW1D8Wtj_0GH1OOHSL7dwR_MIkZ8-zp?usp=sharing) 77 | 4. Synthetic Deblurring: [GoPro](https://seungjunnah.github.io/Datasets/gopro) 78 | 5. Real-World Deblurring: [BSD3ms-24ms](https://drive.google.com/drive/folders/1LKLCE_RqPF5chqWgmh3pj7cg-t9KM2Hd?usp=sharing) 79 | 6. Denoising: [DAVIS](https://github.com/m-tassano/fastdvdnet?tab=readme-ov-file) | [Set8](https://drive.google.com/drive/folders/11chLkbcX-oKGLOLONuDpXZM2-vujn_KD?usp=sharing) 80 | 7. Real-World Super Resolution: [MVSR](https://github.com/HITRainer/EAVSR?tab=readme-ov-file) 81 | 82 | The directory structure, including the ground truth ('gt') for reference frames and 'blur' for degraded images, should be organized as follows: 83 | 84 | ```bash 85 | ./datasets/ 86 | └── Dataset_name/ 87 | ├── train/ 88 | └── test/ 89 | ├── blur 90 | ├── video_1 91 | │ ├── Frame1 92 | │ .... 93 | └── video_n 94 | │ ├── Frame1 95 | │ .... 96 | └── gt 97 | ├── video_1 98 | │ ├── Frame1 99 | │ .... 100 | └── video_n 101 | │ ├── Frame1 102 | │ .... 103 | ``` 104 | 105 | ## 2. Training 106 | To train the model, make sure you select the appropriate data loader in `train.py`. There are two options: 107 | 108 | 1. For deblurring, denoising, deraining, etc., keep the following import line and comment out the superresolution one. 109 | `from basicsr.data.video_image_dataset import VideoImageDataset` 110 | 111 | 2. For superresolution, keep the following import line and comment out the previous one.
112 | `from basicsr.data.video_super_image_dataset import VideoSuperImageDataset as VideoImageDataset` 113 | 114 | ``` 115 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=8080 basicsr/train.py -opt /options/option_file_name.yml --launcher pytorch 116 | ``` 117 | 118 | ## 3. Evaluation 119 | 120 | The pretrained models can be downloaded from the [GDrive link](https://drive.google.com/drive/folders/1Mur4IboaNgEW5qyynTIHq8CSAGtyykrA?usp=sharing). 121 | 122 | ### 3.1 Testing the Model 123 | To evaluate the pre-trained model, use this command: 124 | 125 | ``` 126 | python inference.py 127 | ``` 128 | 129 | Adjust the function parameters in the Python file according to each task's requirements: 130 | 1. `config`: Specify the path to the option file. 131 | 2. `model_path`: Provide the location of the pre-trained model. 132 | 3. `dataset_name`: Select the dataset you are using ("RSVD", "GoPro", "SR", "NightRain", "DVD", "Set8"). 133 | 4. `task_name`: Choose the restoration task ("Desnowing", "Deblurring", "SR", "Deraining", "Denoising"). 134 | 5. `model_type`: Indicate the model type ("t0", "t1", "SR"). 135 | 6. `save_image`: Set to `True` if you want to save the output images; provide the output path in `image_out_path`. 136 | 7. `do_patches`: Enable if processing images in patches; adjust `tile` and `tile_overlap` as needed. The default values are 320 and 128 (a rough sketch of this overlapping-tile scheme is included further below). 137 | 8. `y_channel_PSNR`: Enable if you need to calculate PSNR/SSIM on the Y channel; the default is False. 138 | 139 | 140 | ### 3.2 Running Turtle on Custom Videos 141 | 142 | This pipeline processes a video by extracting frames and running a pre-trained model for tasks like desnowing: 143 | 144 | #### Step 1: Extract Frames from Video 145 | 146 | 1. Edit `video_to_frames.py`: 147 | - Set the `video_path` to your input video file. 148 | - Set the `output_folder` to save extracted frames. 149 | 150 | 2. Run the script: 151 | ```bash 152 | python video_to_frames.py 153 | ``` 154 | 155 | #### Step 2: Run Model Inference 156 | 157 | 1. Edit `inference_no_ground_truth.py`: 158 | - Set paths for `config`, `model_path`, `data_dir` (extracted frames), and `image_out_path` (output frames). 159 | 160 | 2. Run the script: 161 | ```bash 162 | python inference_no_ground_truth.py 163 | ``` 164 | 165 | 166 | ## 4. Model Complexity and Inference Speed 167 | * To get the parameter count, MACs, and inference speed, use this command: 168 | ``` 169 | python basicsr/models/archs/turtle_arch.py 170 | ``` 171 | 172 | ### Contributions 📝📝 173 | 174 | We invite the community to contribute to extending **TURTLE** to other low-level vision tasks. Below is a list of specific areas where contributions could be highly valuable if the models are open-sourced. If you have other suggestions or requests, please feel free to open an issue. 175 | 176 | 1. **Training TURTLE for Synthetic Super-Resolution Tasks** 177 | - **Bicubic (BI) Degradation**: Train on REDS, Vimeo90K and evaluate on REDS4, Vimeo90K-T. 178 | - **Blur-Downsampling (BD) Degradation**: Train on Vimeo90K and evaluate on Vimeo90K-T, Vid4, UDM10. 179 | 180 | For more information on dataset selection and data preparation, please refer to Section 4.3 in this [paper](https://arxiv.org/pdf/2206.02146). 181 | 182 | 183 | ### Acknowledgments 184 | 185 | This codebase borrows from the [BasicSR](https://github.com/xinntao/BasicSR) and [ShiftNet](https://github.com/dasongli1/Shift-Net) repositories.
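The `do_patches`, `tile`, and `tile_overlap` settings described in Section 3.1 refer to overlapping-tile inference: the frame is split into overlapping windows, each window is restored independently, and the overlaps are averaged back together so that memory use stays bounded on large frames. The snippet below is only an illustrative sketch of that general scheme using the readme's default values; `tiled_forward` and `model` are hypothetical names rather than the repository's actual inference code, and the sketch assumes a same-resolution task (not super-resolution).

```python
# Illustrative sketch of overlapping-tile inference (hypothetical helper, not inference.py).
import torch

def tiled_forward(model, img, tile=320, tile_overlap=128):
    """Run `model` on overlapping (tile x tile) patches of `img` (1, C, H, W) and average overlaps."""
    _, _, h, w = img.shape
    stride = tile - tile_overlap
    ys = list(range(0, max(h - tile, 0) + 1, stride))
    xs = list(range(0, max(w - tile, 0) + 1, stride))
    if ys[-1] != max(h - tile, 0):  # ensure the last tile reaches the bottom border
        ys.append(max(h - tile, 0))
    if xs[-1] != max(w - tile, 0):  # ensure the last tile reaches the right border
        xs.append(max(w - tile, 0))
    out = torch.zeros_like(img)
    weight = torch.zeros_like(img)
    for y in ys:
        for x in xs:
            patch = img[..., y:y + tile, x:x + tile]
            out[..., y:y + tile, x:x + tile] += model(patch)
            weight[..., y:y + tile, x:x + tile] += 1
    return out / weight  # every pixel is covered by at least one tile
```

Averaging the overlapping regions avoids visible seams at tile boundaries; the repository's `inference.py` may blend or pad borders differently.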
186 | 187 | ### Citation 188 | 189 | If you find our work useful, please consider citing our paper in your research. 190 | 191 | 192 | ``` 193 | @inproceedings{ghasemabadilearning, 194 | title={Learning Truncated Causal History Model for Video Restoration}, 195 | author={Ghasemabadi, Amirhosein and Janjua, Muhammad Kamran and Salameh, Mohammad and Niu, Di}, 196 | booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems} 197 | } 198 | ``` 199 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | addict 2 | future 3 | lmdb 4 | numpy>=1.17 5 | opencv-python 6 | Pillow 7 | pyyaml 8 | requests 9 | scikit-image 10 | scipy 11 | tb-nightly 12 | torch>=1.7 13 | torchvision 14 | tqdm 15 | yapf -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | # line break before binary operator (W503) 4 | W503, 5 | # line break after binary operator (W504) 6 | W504, 7 | max-line-length=79 8 | 9 | [yapf] 10 | based_on_style = pep8 11 | blank_line_before_nested_class_or_def = true 12 | split_before_expression_after_opening_paren = true 13 | 14 | [isort] 15 | line_length = 79 16 | multi_line_output = 0 17 | known_standard_library = pkg_resources,setuptools 18 | known_first_party = basicsr 19 | known_third_party = PIL,cv2,lmdb,numpy,requests,scipy,skimage,torch,torchvision,tqdm,yaml 20 | no_lines_before = STDLIB,LOCALFOLDER 21 | default_section = THIRDPARTY 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 3 | # Copyright 2018-2020 BasicSR Authors 4 | # ------------------------------------------------------------------------ 5 | #!/usr/bin/env python 6 | 7 | from setuptools import find_packages, setup 8 | 9 | import os 10 | import subprocess 11 | import sys 12 | import time 13 | import torch 14 | from torch.utils.cpp_extension import (BuildExtension, CppExtension, 15 | CUDAExtension) 16 | 17 | version_file = 'basicsr/version.py' 18 | 19 | 20 | def readme(): 21 | return '' 22 | 23 | 24 | def get_git_hash(): 25 | 26 | def _minimal_ext_cmd(cmd): 27 | # construct minimal environment 28 | env = {} 29 | for k in ['SYSTEMROOT', 'PATH', 'HOME']: 30 | v = os.environ.get(k) 31 | if v is not None: 32 | env[k] = v 33 | # LANGUAGE is used on win32 34 | env['LANGUAGE'] = 'C' 35 | env['LANG'] = 'C' 36 | env['LC_ALL'] = 'C' 37 | out = subprocess.Popen( 38 | cmd, stdout=subprocess.PIPE, env=env).communicate()[0] 39 | return out 40 | 41 | try: 42 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) 43 | sha = out.strip().decode('ascii') 44 | except OSError: 45 | sha = 'unknown' 46 | 47 | return sha 48 | 49 | 50 | def get_hash(): 51 | if os.path.exists('.git'): 52 | sha = get_git_hash()[:7] 53 | elif os.path.exists(version_file): 54 | try: 55 | from basicsr.version import __version__ 56 | sha = __version__.split('+')[-1] 57 | except ImportError: 58 | raise ImportError('Unable to get git version') 59 | else: 60 | sha = 'unknown' 61 | 62 | return sha 63 | 64 | 65 | def write_version_py(): 66 | content = """# GENERATED VERSION FILE 67 | # TIME: {} 68 | 
__version__ = '{}' 69 | short_version = '{}' 70 | version_info = ({}) 71 | """ 72 | sha = get_hash() 73 | with open('VERSION', 'r') as f: 74 | SHORT_VERSION = f.read().strip() 75 | VERSION_INFO = ', '.join( 76 | [x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) 77 | VERSION = SHORT_VERSION + '+' + sha 78 | 79 | version_file_str = content.format(time.asctime(), VERSION, SHORT_VERSION, 80 | VERSION_INFO) 81 | with open(version_file, 'w') as f: 82 | f.write(version_file_str) 83 | 84 | 85 | def get_version(): 86 | with open(version_file, 'r') as f: 87 | exec(compile(f.read(), version_file, 'exec')) 88 | return locals()['__version__'] 89 | 90 | 91 | def make_cuda_ext(name, module, sources, sources_cuda=None): 92 | if sources_cuda is None: 93 | sources_cuda = [] 94 | define_macros = [] 95 | extra_compile_args = {'cxx': []} 96 | 97 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': 98 | define_macros += [('WITH_CUDA', None)] 99 | extension = CUDAExtension 100 | extra_compile_args['nvcc'] = [ 101 | '-D__CUDA_NO_HALF_OPERATORS__', 102 | '-D__CUDA_NO_HALF_CONVERSIONS__', 103 | '-D__CUDA_NO_HALF2_OPERATORS__', 104 | ] 105 | sources += sources_cuda 106 | else: 107 | print(f'Compiling {name} without CUDA') 108 | extension = CppExtension 109 | 110 | return extension( 111 | name=f'{module}.{name}', 112 | sources=[os.path.join(*module.split('.'), p) for p in sources], 113 | define_macros=define_macros, 114 | extra_compile_args=extra_compile_args) 115 | 116 | 117 | def get_requirements(filename='requirements.txt'): 118 | return [] 119 | here = os.path.dirname(os.path.realpath(__file__)) 120 | with open(os.path.join(here, filename), 'r') as f: 121 | requires = [line.replace('\n', '') for line in f.readlines()] 122 | return requires 123 | 124 | 125 | if __name__ == '__main__': 126 | if '--no_cuda_ext' in sys.argv: 127 | ext_modules = [] 128 | sys.argv.remove('--no_cuda_ext') 129 | else: 130 | ext_modules = [ 131 | make_cuda_ext( 132 | name='deform_conv_ext', 133 | module='basicsr.models.ops.dcn', 134 | sources=['src/deform_conv_ext.cpp'], 135 | sources_cuda=[ 136 | 'src/deform_conv_cuda.cpp', 137 | 'src/deform_conv_cuda_kernel.cu' 138 | ]), 139 | make_cuda_ext( 140 | name='fused_act_ext', 141 | module='basicsr.models.ops.fused_act', 142 | sources=['src/fused_bias_act.cpp'], 143 | sources_cuda=['src/fused_bias_act_kernel.cu']), 144 | make_cuda_ext( 145 | name='upfirdn2d_ext', 146 | module='basicsr.models.ops.upfirdn2d', 147 | sources=['src/upfirdn2d.cpp'], 148 | sources_cuda=['src/upfirdn2d_kernel.cu']), 149 | ] 150 | 151 | write_version_py() 152 | setup( 153 | name='basicsr', 154 | version=get_version(), 155 | description='Open Source Image and Video Super-Resolution Toolbox', 156 | long_description=readme(), 157 | author='Xintao Wang', 158 | author_email='xintao.wang@outlook.com', 159 | keywords='computer vision, restoration, super resolution', 160 | url='https://github.com/xinntao/BasicSR', 161 | packages=find_packages( 162 | exclude=('options', 'datasets', 'experiments', 'results', 163 | 'tb_logger', 'wandb')), 164 | classifiers=[ 165 | 'Development Status :: 4 - Beta', 166 | 'License :: OSI Approved :: Apache Software License', 167 | 'Operating System :: OS Independent', 168 | 'Programming Language :: Python :: 3', 169 | 'Programming Language :: Python :: 3.7', 170 | 'Programming Language :: Python :: 3.8', 171 | ], 172 | license='Apache License 2.0', 173 | setup_requires=['cython', 'numpy'], 174 | install_requires=get_requirements(), 175 | ext_modules=ext_modules, 
176 | cmdclass={'build_ext': BuildExtension}, 177 | zip_safe=False) 178 | -------------------------------------------------------------------------------- /video_to_frames.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import shutil 4 | 5 | """ 6 | This script converts a video into individual frames and saves them to a specified directory. 7 | Only a subset of frames is saved based on a defined interval to match the desired frame rate for further processing. 8 | """ 9 | 10 | 11 | # Path to the video file 12 | video_path = 'video.mp4' 13 | 14 | # Directory to save the frames 15 | output_folder = 'output_dir' 16 | 17 | if os.path.exists(output_folder): 18 | shutil.rmtree(output_folder) 19 | os.makedirs(output_folder, exist_ok=True) 20 | 21 | 22 | 23 | # Load the video 24 | cap = cv2.VideoCapture(video_path) 25 | 26 | # Get the frame rate of the video 27 | fps = cap.get(cv2.CAP_PROP_FPS) 28 | frames_per_second = 10 29 | interval = int(fps / frames_per_second) 30 | 31 | frame_count = 0 32 | saved_frame_count = 0 33 | 34 | while cap.isOpened(): 35 | ret, frame = cap.read() 36 | if not ret: 37 | break 38 | 39 | # Save frame if it's in the interval 40 | if frame_count % interval == 0: 41 | frame_filename = os.path.join(output_folder, f'frame_{saved_frame_count:04d}.png') 42 | cv2.imwrite(frame_filename, frame) 43 | saved_frame_count += 1 44 | 45 | frame_count += 1 46 | 47 | cap.release() 48 | print(f'Extracted {saved_frame_count} frames to {output_folder}') 49 | --------------------------------------------------------------------------------