├── LICENSE ├── ProRes ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── dataset_specific_learnt.cpython-38.pyc │ │ ├── pair_transforms.cpython-37.pyc │ │ ├── pair_transforms.cpython-38.pyc │ │ ├── sampler.cpython-38.pyc │ │ └── unif.cpython-38.pyc │ ├── data_mixup.py │ ├── data_multi.py │ ├── dataset_simple.py │ ├── dataset_specific_learnt.py │ ├── finetune_dataset.py │ ├── gen_json_deblur.py │ ├── gen_json_dehaze.py │ ├── gen_json_dpdd.py │ ├── gen_json_fivek.py │ ├── gen_json_lol.py │ ├── gen_json_rain.py │ ├── gen_json_sidd.py │ ├── pair_transforms.py │ ├── sampler.py │ ├── unif.py │ └── utils │ │ ├── check_black_image.py │ │ ├── gen_image_flip.py │ │ ├── get_num_obj.py │ │ ├── get_random_propmts.py │ │ ├── get_subset.py │ │ └── get_train2017_subset.py ├── datasets │ └── low_level │ │ ├── deblur.npy │ │ ├── derain.npy │ │ ├── enhance.npy │ │ ├── groundtruth-deblur_gopro_val.json │ │ ├── groundtruth-denoise_ssid_train448.json │ │ ├── groundtruth-denoise_ssid_val256.json │ │ ├── groundtruth_crop-deblur_gopro_train.json │ │ ├── gt-enhance_lol_eval.json │ │ ├── gt-enhance_lol_train.json │ │ ├── ssid.npy │ │ ├── target-derain_test_rain100h.json │ │ └── target-derain_train.json ├── demo │ ├── ddp_utils.py │ ├── eval_sidd.m │ ├── evaluate_PSNR_SSIM.m │ ├── matrix_nms.py │ ├── ours_inference_deblur_v2.py │ ├── ours_inference_derain_v2.py │ ├── ours_inference_lol_v2.py │ ├── ours_inference_sidd_v2.py │ └── prompt_save.py ├── engine_pretrain.py ├── eval_ours.sh ├── main_pretrain.py ├── masking_generator.py ├── models_ours.py ├── scripts │ └── train.sh └── util │ ├── __pycache__ │ ├── lr_decay.cpython-37.pyc │ ├── lr_decay.cpython-38.pyc │ ├── lr_sched.cpython-37.pyc │ ├── lr_sched.cpython-38.pyc │ ├── misc.cpython-37.pyc │ ├── misc.cpython-38.pyc │ ├── pos_embed.cpython-37.pyc │ ├── pos_embed.cpython-38.pyc │ ├── vitdet_utils.cpython-37.pyc │ └── vitdet_utils.cpython-38.pyc │ ├── crop.py │ ├── datasets.py │ ├── lars.py │ ├── lr_decay.py │ ├── lr_sched.py │ ├── metrics.py │ ├── misc.py │ ├── pos_embed.py │ └── vitdet_utils.py ├── README.md ├── figures ├── S1_independent.jpg ├── S2_irrelevant.jpg ├── S3_combine.jpg ├── intro_figure.jpg ├── main_figure.jpg ├── tuning_fivek.jpg └── tuning_reside.jpg ├── index.html └── static ├── css ├── bulma-carousel.min.css ├── bulma-slider.min.css ├── bulma.css.map.txt ├── bulma.min.css ├── fontawesome.all.min.css └── index.css ├── images └── icon.jpg └── js ├── bulma-carousel.js ├── bulma-carousel.min.js ├── bulma-slider.js ├── bulma-slider.min.js ├── fontawesome.all.min.js └── index.js /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Jiaqi Ma 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /ProRes/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/ProRes/data/__init__.py -------------------------------------------------------------------------------- /ProRes/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/ProRes/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ProRes/data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/ProRes/data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ProRes/data/__pycache__/dataset_specific_learnt.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/ProRes/data/__pycache__/dataset_specific_learnt.cpython-38.pyc -------------------------------------------------------------------------------- /ProRes/data/__pycache__/pair_transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/ProRes/data/__pycache__/pair_transforms.cpython-37.pyc -------------------------------------------------------------------------------- /ProRes/data/__pycache__/pair_transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/ProRes/data/__pycache__/pair_transforms.cpython-38.pyc -------------------------------------------------------------------------------- /ProRes/data/__pycache__/sampler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/ProRes/data/__pycache__/sampler.cpython-38.pyc -------------------------------------------------------------------------------- /ProRes/data/__pycache__/unif.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/ProRes/data/__pycache__/unif.cpython-38.pyc -------------------------------------------------------------------------------- /ProRes/data/data_mixup.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import json 
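# Module summary: PairDataset below loads (input image, ground-truth) pairs from the JSON
# files in json_path_list and adds a learnt task prompt (datasets/low_level/{derain,
# enhance,ssid,deblur}.npy) onto the input. With probability 0.75 the matching prompt is
# used (flag = +1); otherwise a mismatched prompt is injected as a negative sample
# (flag = -1). In the matching branch, a further 0.75-probability mixup blends in a second
# randomly drawn pair with a Beta(0.2, 0.2) coefficient. __getitem__ returns
# (image, target, mask, valid, flag, image_ori).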
3 | from typing import Any, Callable, List, Optional, Tuple 4 | import random 5 | 6 | from PIL import Image 7 | import numpy as np 8 | 9 | import torch 10 | from torchvision.datasets.vision import VisionDataset, StandardTransform 11 | import torch.nn.functional as F 12 | 13 | class PairDataset(VisionDataset): 14 | """`MS Coco Detection `_ Dataset. 15 | 16 | It requires the `COCO API to be installed `_. 17 | 18 | Args: 19 | root (string): Root directory where images are downloaded to. 20 | annFile (string): Path to json annotation file. 21 | transform (callable, optional): A function/transform that takes in an PIL image 22 | and returns a transformed version. E.g, ``transforms.PILToTensor`` 23 | target_transform (callable, optional): A function/transform that takes in the 24 | target and transforms it. 25 | transforms (callable, optional): A function/transform that takes input sample and its target as entry 26 | and returns a transformed version. 27 | """ 28 | 29 | def __init__( 30 | self, 31 | root: str, 32 | json_path_list: list, 33 | transform: Optional[Callable] = None, 34 | transform2: Optional[Callable] = None, 35 | transform3: Optional[Callable] = None, 36 | target_transform: Optional[Callable] = None, 37 | transforms: Optional[Callable] = None, 38 | masked_position_generator: Optional[Callable] = None, 39 | use_two_pairs: bool = False, 40 | half_mask_ratio:float = 0., 41 | ) -> None: 42 | super().__init__(root, transforms, transform, target_transform) 43 | 44 | self.pairs = [] 45 | self.weights = [] 46 | #type_weight_list = [10, 20, 2, 20, 40, 20, 2, 2, 2, 2] 47 | type_weight_list = [2, 3, 1, 2] 48 | # type_weight_list = [0.033, 1, 0.004, 0.029] 49 | #type_weight_list = [0.1, 0.25, 0.2, 0.2, 0.1, 0.05, 0.05, 0.05, 0.01] 50 | # type_weight_list = [0.1, 0.2, 0.15, 0.25, 0.2, 0.15, 0.05, 0.05, 0.01] 51 | #type_weight_list = [0.1, 0.15, 0.15, 0.3, 0.3, 0.2, 0.05, 0.05, 0.01] 52 | # type_weight_list = [0.04, 0.96] 53 | # derain 13712 enhance 485 54 | for idx, json_path in enumerate(json_path_list): 55 | cur_pairs = json.load(open(json_path)) 56 | self.pairs.extend(cur_pairs) 57 | cur_num = len(cur_pairs) 58 | self.weights.extend([type_weight_list[idx] * 1./cur_num]*cur_num) 59 | print(json_path, type_weight_list[idx]) 60 | 61 | #self.weights = [1./n for n in self.weights] 62 | self.use_two_pairs = use_two_pairs 63 | if self.use_two_pairs: 64 | self.pair_type_dict = {} 65 | for idx, pair in enumerate(self.pairs): 66 | if "type" in pair: 67 | if pair["type"] not in self.pair_type_dict: 68 | self.pair_type_dict[pair["type"]] = [idx] 69 | else: 70 | self.pair_type_dict[pair["type"]].append(idx) 71 | for t in self.pair_type_dict: 72 | print(t, len(self.pair_type_dict[t])) 73 | 74 | self.transforms = PairStandardTransform(transform, target_transform) if transform is not None else None 75 | self.transforms2 = PairStandardTransform(transform2, target_transform) if transform2 is not None else None 76 | self.transforms3 = PairStandardTransform(transform3, target_transform) if transform3 is not None else None 77 | self.masked_position_generator = masked_position_generator 78 | self.half_mask_ratio = half_mask_ratio 79 | 80 | def _load_image(self, path: str) -> Image.Image: 81 | while True: 82 | try: 83 | img = Image.open(os.path.join(self.root, path)) 84 | except OSError as e: 85 | print(f"Catched exception: {str(e)}. 
Re-trying...") 86 | import time 87 | time.sleep(1) 88 | else: 89 | break 90 | # process for nyuv2 depth: scale to 0~255 91 | if "sync_depth" in path: 92 | # nyuv2's depth range is 0~10m 93 | img = np.array(img) / 10000. 94 | img = img * 255 95 | img = Image.fromarray(img) 96 | img = img.convert("RGB") 97 | return img 98 | 99 | def _random_add_prompts_random_scales(self,image, prompt, prompt_range=[8,64], scale_range=[0.2,0.3]): 100 | image = np.asarray(image.permute(1,2,0)) 101 | prompt = prompt 102 | 103 | h, w = image.shape[0],image.shape[1] 104 | 105 | mask_image = np.ones((int(h),int(w),3),dtype=np.float32) 106 | mask_prompt = np.zeros((int(h),int(w),3),dtype=np.float32) 107 | 108 | ratio = 0 109 | 110 | while (scale_range[0] > ratio) == True or (ratio > scale_range[1])!=True: 111 | h_p = w_p = int(random.uniform(prompt_range[0], prompt_range[1])) 112 | point_h = int(random.uniform(h_p, h-h_p)) 113 | point_w = int(random.uniform(w_p, w-w_p)) 114 | 115 | mask_image[point_h:point_h+h_p,point_w:point_w+w_p] = 0.0 116 | mask_prompt[point_h:point_h+h_p,point_w:point_w+w_p] = 1.0 117 | prompts_token_num = np.sum(mask_prompt) 118 | ratio = prompts_token_num/(h*w) 119 | 120 | # image = image*mask_image 121 | 122 | # prompt = prompt * mask_prompt 123 | image = image + prompt 124 | 125 | return image 126 | 127 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 128 | pair = self.pairs[index] 129 | 130 | pair_type = pair['type'] 131 | if 'derain' in pair_type: 132 | type_dict = 'derain' 133 | elif 'enhance' in pair_type: 134 | type_dict = 'enhance' 135 | elif 'ssid' in pair_type: 136 | type_dict = 'ssid' 137 | elif 'deblur' in pair_type: 138 | type_dict = 'deblur' 139 | 140 | 141 | 142 | interpolation1 = 'bicubic' 143 | interpolation2 = 'bicubic' 144 | cur_transforms = self.transforms 145 | image = self._load_image(pair['image_path']) 146 | target = self._load_image(pair['target_path']) 147 | image, target = cur_transforms(image, target, interpolation1, interpolation2) 148 | image_ori = image 149 | 150 | prompt_dict = { 151 | 'prompt_derain': 'datasets/low_level/derain.npy', 152 | 'prompt_enhance': 'datasets/low_level/enhance.npy', 153 | 'prompt_ssid': 'datasets/low_level/ssid.npy', 154 | 'prompt_deblur': 'datasets/low_level/deblur.npy' 155 | } 156 | 157 | binary_posneg = np.random.binomial(n=1, p=0.75) 158 | if binary_posneg == 1: 159 | # use original 160 | key = next(k for k, v in prompt_dict.items() if type_dict in k) 161 | prompt = np.load(prompt_dict[key]) 162 | flag = torch.ones(()) 163 | 164 | binary_mixup = np.random.binomial(n=1, p=0.75) 165 | if binary_mixup ==1: 166 | alpha = 0.2 167 | lam = np.random.beta(alpha, alpha) 168 | rand_index = np.random.randint(0, len(self.pairs)) 169 | rand_pair = self.pairs[rand_index] 170 | rand_pair_type = rand_pair['type'] 171 | if 'derain' in rand_pair_type: 172 | rand_type_dict = 'derain' 173 | elif 'enhance' in rand_pair_type: 174 | rand_type_dict = 'enhance' 175 | elif 'ssid' in rand_pair_type: 176 | rand_type_dict = 'ssid' 177 | elif 'deblur' in rand_pair_type: 178 | rand_type_dict = 'deblur' 179 | rand_key = next(k for k, v in prompt_dict.items() if rand_type_dict in k) 180 | # print(rand_type_dict, rand_key) 181 | # exit() 182 | rand_prompt = np.load(prompt_dict[rand_key]) 183 | 184 | rand_image = self._load_image(rand_pair['image_path']) 185 | rand_target = self._load_image(rand_pair['target_path']) 186 | rand_image, rand_target = cur_transforms(rand_image, rand_target, interpolation1, interpolation2) 187 | 188 | # first resize to 448*448 
and combine them 189 | image = self._random_add_prompts_random_scales(image, prompt, prompt_range=[8,64], scale_range=[0.95,0.99]) 190 | rand_image = self._random_add_prompts_random_scales(rand_image, rand_prompt, prompt_range=[8,64], scale_range=[0.95,0.99]) 191 | image = torch.from_numpy(image.transpose(2, 0, 1)) 192 | rand_image = torch.from_numpy(rand_image.transpose(2, 0, 1)) 193 | 194 | image = lam * image + (1 - lam) * rand_image 195 | target = lam * target + (1 - lam) * rand_target 196 | else: 197 | image = self._random_add_prompts_random_scales(image, prompt, prompt_range=[8,64], scale_range=[0.95,0.99]) 198 | image = torch.from_numpy(image.transpose(2, 0, 1)) 199 | 200 | else: 201 | # remove original 202 | keys_to_remove = {k for k, v in prompt_dict.items() if type_dict in k} 203 | for key_to_remove in keys_to_remove: 204 | del prompt_dict[key_to_remove] 205 | key = random.choice(list(prompt_dict.keys())) 206 | prompt = np.load(prompt_dict[key]) 207 | flag = -torch.ones(()) 208 | image = self._random_add_prompts_random_scales(image, prompt, prompt_range=[8,64], scale_range=[0.95,0.99]) 209 | image = torch.from_numpy(image.transpose(2, 0, 1)) 210 | 211 | 212 | valid = torch.ones_like(target) 213 | mask = self.masked_position_generator() 214 | 215 | 216 | return image, target, mask, valid, flag, image_ori 217 | 218 | def __len__(self) -> int: 219 | return len(self.pairs) 220 | 221 | 222 | class PairStandardTransform(StandardTransform): 223 | def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None: 224 | super().__init__(transform=transform, target_transform=target_transform) 225 | 226 | def __call__(self, input: Any, target: Any, interpolation1: Any, interpolation2: Any) -> Tuple[Any, Any]: 227 | if self.transform is not None: 228 | input, target = self.transform(input, target, interpolation1, interpolation2) 229 | #if self.target_transform is not None: 230 | # target = self.target_transform(target) 231 | return input, target 232 | -------------------------------------------------------------------------------- /ProRes/data/data_multi.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import json 3 | from typing import Any, Callable, List, Optional, Tuple 4 | import random 5 | 6 | from PIL import Image 7 | import numpy as np 8 | 9 | import torch 10 | from torchvision.datasets.vision import VisionDataset, StandardTransform 11 | import torch.nn.functional as F 12 | 13 | class PairDataset(VisionDataset): 14 | """`MS Coco Detection `_ Dataset. 15 | 16 | It requires the `COCO API to be installed `_. 17 | 18 | Args: 19 | root (string): Root directory where images are downloaded to. 20 | annFile (string): Path to json annotation file. 21 | transform (callable, optional): A function/transform that takes in an PIL image 22 | and returns a transformed version. E.g, ``transforms.PILToTensor`` 23 | target_transform (callable, optional): A function/transform that takes in the 24 | target and transforms it. 25 | transforms (callable, optional): A function/transform that takes input sample and its target as entry 26 | and returns a transformed version. 
27 | """ 28 | 29 | def __init__( 30 | self, 31 | root: str, 32 | json_path_list: list, 33 | transform: Optional[Callable] = None, 34 | transform2: Optional[Callable] = None, 35 | transform3: Optional[Callable] = None, 36 | target_transform: Optional[Callable] = None, 37 | transforms: Optional[Callable] = None, 38 | masked_position_generator: Optional[Callable] = None, 39 | use_two_pairs: bool = False, 40 | half_mask_ratio:float = 0., 41 | ) -> None: 42 | super().__init__(root, transforms, transform, target_transform) 43 | 44 | self.pairs = [] 45 | self.weights = [] 46 | #type_weight_list = [10, 20, 2, 20, 40, 20, 2, 2, 2, 2] 47 | type_weight_list = [3, 3, 1, 3] 48 | # type_weight_list = [0.033, 1, 0.004, 0.029] 49 | #type_weight_list = [0.1, 0.25, 0.2, 0.2, 0.1, 0.05, 0.05, 0.05, 0.01] 50 | # type_weight_list = [0.1, 0.2, 0.15, 0.25, 0.2, 0.15, 0.05, 0.05, 0.01] 51 | #type_weight_list = [0.1, 0.15, 0.15, 0.3, 0.3, 0.2, 0.05, 0.05, 0.01] 52 | # type_weight_list = [0.04, 0.96] 53 | # derain 13712 enhance 485 54 | for idx, json_path in enumerate(json_path_list): 55 | cur_pairs = json.load(open(json_path)) 56 | self.pairs.extend(cur_pairs) 57 | cur_num = len(cur_pairs) 58 | self.weights.extend([type_weight_list[idx] * 1./cur_num]*cur_num) 59 | print(json_path, type_weight_list[idx]) 60 | 61 | #self.weights = [1./n for n in self.weights] 62 | self.use_two_pairs = use_two_pairs 63 | if self.use_two_pairs: 64 | self.pair_type_dict = {} 65 | for idx, pair in enumerate(self.pairs): 66 | if "type" in pair: 67 | if pair["type"] not in self.pair_type_dict: 68 | self.pair_type_dict[pair["type"]] = [idx] 69 | else: 70 | self.pair_type_dict[pair["type"]].append(idx) 71 | for t in self.pair_type_dict: 72 | print(t, len(self.pair_type_dict[t])) 73 | 74 | self.transforms = PairStandardTransform(transform, target_transform) if transform is not None else None 75 | self.transforms2 = PairStandardTransform(transform2, target_transform) if transform2 is not None else None 76 | self.transforms3 = PairStandardTransform(transform3, target_transform) if transform3 is not None else None 77 | self.masked_position_generator = masked_position_generator 78 | self.half_mask_ratio = half_mask_ratio 79 | 80 | def _load_image(self, path: str) -> Image.Image: 81 | while True: 82 | try: 83 | img = Image.open(os.path.join(self.root, path)) 84 | except OSError as e: 85 | print(f"Catched exception: {str(e)}. Re-trying...") 86 | import time 87 | time.sleep(1) 88 | else: 89 | break 90 | # process for nyuv2 depth: scale to 0~255 91 | if "sync_depth" in path: 92 | # nyuv2's depth range is 0~10m 93 | img = np.array(img) / 10000. 
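# stored depth values span 0~10000 for the 0~10 m range, so the division above maps them
# to 0~1 and the multiplication by 255 below rescales to 0~255 before the RGB conversion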
94 | img = img * 255 95 | img = Image.fromarray(img) 96 | img = img.convert("RGB") 97 | return img 98 | 99 | def _random_add_prompts_random_scales(self,image, prompt, prompt_range=[8,64], scale_range=[0.2,0.3]): 100 | image = np.asarray(image.permute(1,2,0)) 101 | prompt = prompt 102 | 103 | h, w = image.shape[0],image.shape[1] 104 | 105 | mask_image = np.ones((int(h),int(w),3),dtype=np.float32) 106 | mask_prompt = np.zeros((int(h),int(w),3),dtype=np.float32) 107 | 108 | ratio = 0 109 | 110 | while (scale_range[0] > ratio) == True or (ratio > scale_range[1])!=True: 111 | h_p = w_p = int(random.uniform(prompt_range[0], prompt_range[1])) 112 | point_h = int(random.uniform(h_p, h-h_p)) 113 | point_w = int(random.uniform(w_p, w-w_p)) 114 | 115 | mask_image[point_h:point_h+h_p,point_w:point_w+w_p] = 0.0 116 | mask_prompt[point_h:point_h+h_p,point_w:point_w+w_p] = 1.0 117 | prompts_token_num = np.sum(mask_prompt) 118 | ratio = prompts_token_num/(h*w) 119 | 120 | # image = image*mask_image 121 | 122 | # prompt = prompt * mask_prompt 123 | image = image + prompt 124 | 125 | return image 126 | 127 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 128 | pair = self.pairs[index] 129 | 130 | pair_type = pair['type'] 131 | if 'derain' in pair_type: 132 | prompt = np.load('datasets/low_level/derain.npy') 133 | elif 'enhance' in pair_type: 134 | prompt = np.load('datasets/low_level/enhance.npy') 135 | elif 'ssid' in pair_type: 136 | prompt = np.load('datasets/low_level/ssid.npy') 137 | elif 'deblur' in pair_type: 138 | prompt = np.load('datasets/low_level/deblur.npy') 139 | 140 | prompt = torch.from_numpy(prompt.transpose(2, 0, 1)) 141 | 142 | interpolation1 = 'bicubic' 143 | interpolation2 = 'bicubic' 144 | cur_transforms = self.transforms 145 | image = self._load_image(pair['image_path']) 146 | target = self._load_image(pair['target_path']) 147 | image, target = cur_transforms(image, target, interpolation1, interpolation2) 148 | 149 | #image = torch.from_numpy(image.transpose(2, 0, 1)) 150 | 151 | 152 | valid = torch.ones_like(target) 153 | mask = self.masked_position_generator() 154 | 155 | 156 | return image, target, mask, valid, prompt 157 | 158 | def __len__(self) -> int: 159 | return len(self.pairs) 160 | 161 | 162 | class PairStandardTransform(StandardTransform): 163 | def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None: 164 | super().__init__(transform=transform, target_transform=target_transform) 165 | 166 | def __call__(self, input: Any, target: Any, interpolation1: Any, interpolation2: Any) -> Tuple[Any, Any]: 167 | if self.transform is not None: 168 | input, target = self.transform(input, target, interpolation1, interpolation2) 169 | #if self.target_transform is not None: 170 | # target = self.target_transform(target) 171 | return input, target 172 | -------------------------------------------------------------------------------- /ProRes/data/dataset_simple.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import json 3 | from typing import Any, Callable, List, Optional, Tuple 4 | import random 5 | 6 | from PIL import Image 7 | import numpy as np 8 | 9 | import torch 10 | from torchvision.datasets.vision import VisionDataset, StandardTransform 11 | import torch.nn.functional as F 12 | 13 | class PairDataset(VisionDataset): 14 | """`MS Coco Detection `_ Dataset. 15 | 16 | It requires the `COCO API to be installed `_. 
17 | 18 | Args: 19 | root (string): Root directory where images are downloaded to. 20 | annFile (string): Path to json annotation file. 21 | transform (callable, optional): A function/transform that takes in an PIL image 22 | and returns a transformed version. E.g, ``transforms.PILToTensor`` 23 | target_transform (callable, optional): A function/transform that takes in the 24 | target and transforms it. 25 | transforms (callable, optional): A function/transform that takes input sample and its target as entry 26 | and returns a transformed version. 27 | """ 28 | 29 | def __init__( 30 | self, 31 | root: str, 32 | json_path_list: list, 33 | transform: Optional[Callable] = None, 34 | transform2: Optional[Callable] = None, 35 | transform3: Optional[Callable] = None, 36 | target_transform: Optional[Callable] = None, 37 | transforms: Optional[Callable] = None, 38 | masked_position_generator: Optional[Callable] = None, 39 | use_two_pairs: bool = False, 40 | half_mask_ratio:float = 0., 41 | ) -> None: 42 | super().__init__(root, transforms, transform, target_transform) 43 | 44 | self.pairs = [] 45 | self.weights = [] 46 | #type_weight_list = [10, 20, 2, 20, 40, 20, 2, 2, 2, 2] 47 | type_weight_list = [1] 48 | # type_weight_list = [0.033, 1, 0.004, 0.029] 49 | # derain 13712 enhance 485 50 | for idx, json_path in enumerate(json_path_list): 51 | cur_pairs = json.load(open(json_path)) 52 | self.pairs.extend(cur_pairs) 53 | cur_num = len(cur_pairs) 54 | self.weights.extend([type_weight_list[idx] * 1./cur_num]*cur_num) 55 | print(json_path, type_weight_list[idx]) 56 | 57 | #self.weights = [1./n for n in self.weights] 58 | self.use_two_pairs = use_two_pairs 59 | if self.use_two_pairs: 60 | self.pair_type_dict = {} 61 | for idx, pair in enumerate(self.pairs): 62 | if "type" in pair: 63 | if pair["type"] not in self.pair_type_dict: 64 | self.pair_type_dict[pair["type"]] = [idx] 65 | else: 66 | self.pair_type_dict[pair["type"]].append(idx) 67 | for t in self.pair_type_dict: 68 | print(t, len(self.pair_type_dict[t])) 69 | 70 | self.transforms = PairStandardTransform(transform, target_transform) if transform is not None else None 71 | self.transforms2 = PairStandardTransform(transform2, target_transform) if transform2 is not None else None 72 | self.transforms3 = PairStandardTransform(transform3, target_transform) if transform3 is not None else None 73 | self.masked_position_generator = masked_position_generator 74 | self.half_mask_ratio = half_mask_ratio 75 | 76 | def _load_image(self, path: str) -> Image.Image: 77 | while True: 78 | try: 79 | img = Image.open(os.path.join(self.root, path)) 80 | except OSError as e: 81 | print(f"Catched exception: {str(e)}. Re-trying...") 82 | import time 83 | time.sleep(1) 84 | else: 85 | break 86 | # process for nyuv2 depth: scale to 0~255 87 | if "sync_depth" in path: 88 | # nyuv2's depth range is 0~10m 89 | img = np.array(img) / 10000. 
90 | img = img * 255 91 | img = Image.fromarray(img) 92 | img = img.convert("RGB") 93 | return img 94 | 95 | 96 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 97 | pair = self.pairs[index] 98 | 99 | image = self._load_image(pair['image_path']) 100 | target = self._load_image(pair['target_path']) 101 | 102 | 103 | interpolation1 = 'bicubic' 104 | interpolation2 = 'bicubic' 105 | 106 | # no aug for instance segmentation 107 | if "inst" in pair['type'] and self.transforms2 is not None: 108 | cur_transforms = self.transforms2 109 | elif "pose" in pair['type'] and self.transforms3 is not None: 110 | cur_transforms = self.transforms3 111 | else: 112 | cur_transforms = self.transforms 113 | 114 | image, target = cur_transforms(image, target, interpolation1, interpolation2) 115 | 116 | 117 | valid = torch.ones_like(target) 118 | imagenet_mean=torch.tensor([0.485, 0.456, 0.406]) 119 | imagenet_std=torch.tensor([0.229, 0.224, 0.225]) 120 | 121 | 122 | mask = self.masked_position_generator() 123 | # mask all 0 124 | # valid all 1 125 | # Why? 126 | # 1 masked patch 127 | # 0 valid patch 128 | 129 | return image, target, mask, valid 130 | 131 | def __len__(self) -> int: 132 | return len(self.pairs) 133 | 134 | 135 | class PairStandardTransform(StandardTransform): 136 | def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None: 137 | super().__init__(transform=transform, target_transform=target_transform) 138 | 139 | def __call__(self, input: Any, target: Any, interpolation1: Any, interpolation2: Any) -> Tuple[Any, Any]: 140 | if self.transform is not None: 141 | input, target = self.transform(input, target, interpolation1, interpolation2) 142 | #if self.target_transform is not None: 143 | # target = self.target_transform(target) 144 | return input, target 145 | -------------------------------------------------------------------------------- /ProRes/data/dataset_specific_learnt.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import json 3 | from typing import Any, Callable, List, Optional, Tuple 4 | import random 5 | 6 | from PIL import Image 7 | import numpy as np 8 | 9 | import torch 10 | from torchvision.datasets.vision import VisionDataset, StandardTransform 11 | import torch.nn.functional as F 12 | 13 | class PairDataset(VisionDataset): 14 | """`MS Coco Detection `_ Dataset. 15 | 16 | It requires the `COCO API to be installed `_. 17 | 18 | Args: 19 | root (string): Root directory where images are downloaded to. 20 | annFile (string): Path to json annotation file. 21 | transform (callable, optional): A function/transform that takes in an PIL image 22 | and returns a transformed version. E.g, ``transforms.PILToTensor`` 23 | target_transform (callable, optional): A function/transform that takes in the 24 | target and transforms it. 25 | transforms (callable, optional): A function/transform that takes input sample and its target as entry 26 | and returns a transformed version. 
27 | """ 28 | 29 | def __init__( 30 | self, 31 | root: str, 32 | json_path_list: list, 33 | transform: Optional[Callable] = None, 34 | transform2: Optional[Callable] = None, 35 | transform3: Optional[Callable] = None, 36 | target_transform: Optional[Callable] = None, 37 | transforms: Optional[Callable] = None, 38 | ) -> None: 39 | super().__init__(root, transforms, transform, target_transform) 40 | 41 | self.pairs = [] 42 | self.weights = [] 43 | #type_weight_list = [1, 3, 3, 3, 3] 44 | type_weight_list = [1,1,1] 45 | # type_weight_list = [0.04, 0.96] 46 | # derain 13712 enhance 485 47 | for idx, json_path in enumerate(json_path_list): 48 | cur_pairs = json.load(open(json_path)) 49 | self.pairs.extend(cur_pairs) 50 | cur_num = len(cur_pairs) 51 | self.weights.extend([type_weight_list[idx] * 1./cur_num]*cur_num) 52 | print(json_path, type_weight_list[idx]) 53 | 54 | 55 | self.transforms = PairStandardTransform(transform, target_transform) if transform is not None else None 56 | self.transforms2 = PairStandardTransform(transform2, target_transform) if transform2 is not None else None 57 | self.transforms3 = PairStandardTransform(transform3, target_transform) if transform3 is not None else None 58 | 59 | 60 | def _load_image(self, path: str) -> Image.Image: 61 | while True: 62 | try: 63 | img = Image.open(os.path.join(self.root, path)) 64 | except OSError as e: 65 | print(f"Catched exception: {str(e)}. Re-trying...") 66 | import time 67 | time.sleep(1) 68 | else: 69 | break 70 | img = img.convert("RGB") 71 | return img 72 | 73 | 74 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 75 | pair = self.pairs[index] 76 | 77 | 78 | image = self._load_image(pair['image_path']) 79 | target = self._load_image(pair['target_path']) 80 | 81 | 82 | pair_type = pair['type'] 83 | # print(pair['image_path'],pair_type, bool('derain' in pair_type), bool('enhance' in pair_type), bool('ssid' in pair_type), bool('deblur' in pair_type)) 84 | # exit() 85 | if 'dehaze' in pair_type: 86 | type_dict = torch.tensor([1, 0, 0]).unsqueeze(0) 87 | elif 'denoise' in pair_type: 88 | type_dict = torch.tensor([0, 1, 0]).unsqueeze(0) 89 | elif 'derain' in pair_type: 90 | type_dict = torch.tensor([0, 0, 1]).unsqueeze(0) 91 | else: 92 | raise ValueError('Invalid path') 93 | 94 | interpolation1 = 'bicubic' 95 | interpolation2 = 'bicubic' 96 | cur_transforms = self.transforms 97 | image, target = cur_transforms(image, target, interpolation1, interpolation2) 98 | 99 | 100 | return image, target, type_dict 101 | 102 | def __len__(self) -> int: 103 | return len(self.pairs) 104 | 105 | 106 | class PairStandardTransform(StandardTransform): 107 | def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None: 108 | super().__init__(transform=transform, target_transform=target_transform) 109 | 110 | def __call__(self, input: Any, target: Any, interpolation1: Any, interpolation2: Any) -> Tuple[Any, Any]: 111 | if self.transform is not None: 112 | input, target = self.transform(input, target, interpolation1, interpolation2) 113 | #if self.target_transform is not None: 114 | # target = self.target_transform(target) 115 | return input, target 116 | -------------------------------------------------------------------------------- /ProRes/data/finetune_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import json 3 | from typing import Any, Callable, List, Optional, Tuple 4 | import random 5 | 6 | from PIL import Image 7 | import 
numpy as np 8 | 9 | import torch 10 | from torchvision.datasets.vision import VisionDataset, StandardTransform 11 | import torch.nn.functional as F 12 | 13 | class PairDataset(VisionDataset): 14 | """`MS Coco Detection `_ Dataset. 15 | 16 | It requires the `COCO API to be installed `_. 17 | 18 | Args: 19 | root (string): Root directory where images are downloaded to. 20 | annFile (string): Path to json annotation file. 21 | transform (callable, optional): A function/transform that takes in an PIL image 22 | and returns a transformed version. E.g, ``transforms.PILToTensor`` 23 | target_transform (callable, optional): A function/transform that takes in the 24 | target and transforms it. 25 | transforms (callable, optional): A function/transform that takes input sample and its target as entry 26 | and returns a transformed version. 27 | """ 28 | 29 | def __init__( 30 | self, 31 | root: str, 32 | json_path_list: list, 33 | transform: Optional[Callable] = None, 34 | transform2: Optional[Callable] = None, 35 | transform3: Optional[Callable] = None, 36 | target_transform: Optional[Callable] = None, 37 | transforms: Optional[Callable] = None, 38 | masked_position_generator: Optional[Callable] = None, 39 | use_two_pairs: bool = False, 40 | half_mask_ratio:float = 0., 41 | ) -> None: 42 | super().__init__(root, transforms, transform, target_transform) 43 | 44 | self.pairs = [] 45 | self.weights = [] 46 | type_weight_list = [1] 47 | # type_weight_list = [2, 3, 1, 2] 48 | # type_weight_list = [0.04, 0.96] 49 | # derain 13712 enhance 485 50 | for idx, json_path in enumerate(json_path_list): 51 | cur_pairs = json.load(open(json_path)) 52 | self.pairs.extend(cur_pairs) 53 | cur_num = len(cur_pairs) 54 | self.weights.extend([type_weight_list[idx] * 1./cur_num]*cur_num) 55 | print(json_path, type_weight_list[idx]) 56 | 57 | 58 | self.transforms = PairStandardTransform(transform, target_transform) if transform is not None else None 59 | self.transforms2 = PairStandardTransform(transform2, target_transform) if transform2 is not None else None 60 | self.transforms3 = PairStandardTransform(transform3, target_transform) if transform3 is not None else None 61 | self.masked_position_generator = masked_position_generator 62 | self.use_two_pairs = use_two_pairs 63 | self.half_mask_ratio = half_mask_ratio 64 | 65 | def _load_image(self, path: str) -> Image.Image: 66 | while True: 67 | try: 68 | img = Image.open(os.path.join(self.root, path)) 69 | except OSError as e: 70 | print(f"Catched exception: {str(e)}. 
Re-trying...") 71 | import time 72 | time.sleep(1) 73 | else: 74 | break 75 | img = img.convert("RGB") 76 | return img 77 | 78 | 79 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 80 | pair = self.pairs[index] 81 | 82 | 83 | image = self._load_image(pair['image_path']) 84 | target = self._load_image(pair['target_path']) 85 | 86 | 87 | pair_type = pair['type'] 88 | # print(pair['image_path'],pair_type, bool('derain' in pair_type), bool('enhance' in pair_type), bool('ssid' in pair_type), bool('deblur' in pair_type)) 89 | # exit() 90 | if 'derain' in pair_type: 91 | type_dict = torch.tensor([1, 0, 0, 0]).unsqueeze(0) 92 | elif 'fivek' in pair_type: 93 | type_dict = torch.tensor([0, 1, 0, 0]).unsqueeze(0) 94 | elif 'ssid' in pair_type: 95 | type_dict = torch.tensor([0, 0, 1, 0]).unsqueeze(0) 96 | elif 'deblur' in pair_type: 97 | type_dict = torch.tensor([0, 0, 0, 1]).unsqueeze(0) 98 | else: 99 | raise ValueError('Invalid path') 100 | 101 | interpolation1 = 'bicubic' 102 | interpolation2 = 'bicubic' 103 | cur_transforms = self.transforms 104 | image, target = cur_transforms(image, target, interpolation1, interpolation2) 105 | 106 | valid = torch.ones_like(target) 107 | mask = self.masked_position_generator() 108 | 109 | return image, target, mask, valid, type_dict 110 | 111 | def __len__(self) -> int: 112 | return len(self.pairs) 113 | 114 | 115 | class PairStandardTransform(StandardTransform): 116 | def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None: 117 | super().__init__(transform=transform, target_transform=target_transform) 118 | 119 | def __call__(self, input: Any, target: Any, interpolation1: Any, interpolation2: Any) -> Tuple[Any, Any]: 120 | if self.transform is not None: 121 | input, target = self.transform(input, target, interpolation1, interpolation2) 122 | #if self.target_transform is not None: 123 | # target = self.target_transform(target) 124 | return input, target 125 | -------------------------------------------------------------------------------- /ProRes/data/gen_json_deblur.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import tqdm 5 | 6 | 7 | if __name__ == "__main__": 8 | # type = 'gt_sub_input' 9 | type = 'groundtruth' 10 | 11 | # split = 'train' 12 | split = 'train' 13 | if split == 'train': 14 | image_dir = "./datasets/low_level/deblur/train/input" 15 | save_path = "./datasets/low_level/{}-deblur_gopro_train.json".format(type) 16 | elif split == 'val': 17 | image_dir = "./datasets/low_level/deblur/test/RealBlur_J/input" 18 | save_path = "./datasets/low_level/{}-deblur_realblur_j_val.json".format(type) 19 | else: 20 | raise NotImplementedError 21 | print(save_path) 22 | 23 | output_dict = [] 24 | 25 | image_path_list = glob.glob(os.path.join(image_dir, '*.png')) 26 | for image_path in tqdm.tqdm(image_path_list): 27 | # image_name = os.path.basename(image_path) 28 | target_path = image_path.replace('input', type) 29 | assert os.path.isfile(image_path) 30 | assert os.path.isfile(target_path) 31 | pair_dict = {} 32 | pair_dict["image_path"] = image_path 33 | pair_dict["target_path"] = target_path 34 | pair_dict["type"] = "{}_deblur".format(type) 35 | output_dict.append(pair_dict) 36 | 37 | json.dump(output_dict, open(save_path, 'w')) 38 | -------------------------------------------------------------------------------- /ProRes/data/gen_json_dehaze.py: 
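# Same pattern as the other gen_json_*.py scripts: glob the hazy/ inputs, pair each image
# with its GT/ counterpart, and dump a JSON list of {"image_path", "target_path", "type"}
# records.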
-------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import tqdm 5 | 6 | 7 | if __name__ == "__main__": 8 | type = 'GT' 9 | # type = 'gt_sub_input' 10 | 11 | # split = 'train' 12 | split = 'test' 13 | if split == 'train': 14 | image_dir = "./datasets/low_level/dehaze/train/hazy/" 15 | save_path = "./datasets/low_level/{}-dehaze_train.json".format(type) 16 | elif split == 'test': 17 | image_dir = "./datasets/low_level/dehaze/test/hazy/" 18 | save_path = "./datasets/low_level/{}-dehaze_test.json".format(type) 19 | else: 20 | raise NotImplementedError 21 | print(save_path) 22 | 23 | output_dict = [] 24 | 25 | image_path_list = glob.glob(os.path.join(image_dir, '*.jpg'))+glob.glob(os.path.join(image_dir, '*.png')) 26 | for image_path in tqdm.tqdm(image_path_list): 27 | # image_name = os.path.basename(image_path) 28 | target_path = image_path.replace('hazy', 'GT') 29 | assert os.path.isfile(image_path) 30 | assert os.path.isfile(target_path) 31 | pair_dict = {} 32 | pair_dict["image_path"] = image_path 33 | pair_dict["target_path"] = target_path 34 | pair_dict["type"] = "{}_dehaze".format(type) 35 | output_dict.append(pair_dict) 36 | 37 | json.dump(output_dict, open(save_path, 'w')) 38 | -------------------------------------------------------------------------------- /ProRes/data/gen_json_dpdd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import tqdm 5 | 6 | 7 | if __name__ == "__main__": 8 | 9 | type = 'target' 10 | 11 | split = 'train' 12 | # split = 'test' 13 | if split == 'train': 14 | image_dir = "./datasets/low_level/defocusblur/dpdd/train/inputC/" 15 | save_path = "./datasets/low_level/{}-defocusblur_dpdd_train.json".format(type) 16 | elif split == 'test': 17 | image_dir = "./datasets/low_level/defocusblur/dpdd/test/inputC/" 18 | save_path = "./datasets/low_level/{}-defocusblur_dpdd_test.json".format(type) 19 | else: 20 | raise NotImplementedError 21 | print(save_path) 22 | 23 | output_dict = [] 24 | 25 | image_path_list = glob.glob(os.path.join(image_dir, '*.png')) 26 | for image_path in tqdm.tqdm(image_path_list): 27 | # image_name = os.path.basename(image_path) 28 | target_path = image_path.replace('inputC', type) 29 | assert os.path.isfile(image_path) 30 | assert os.path.isfile(target_path) 31 | pair_dict = {} 32 | pair_dict["image_path"] = image_path 33 | pair_dict["target_path"] = target_path 34 | pair_dict["type"] = "{}_defocusblur_dpdd".format(type) 35 | output_dict.append(pair_dict) 36 | 37 | 38 | image_path_list = glob.glob(os.path.join(image_dir.replace('inputC', 'inputL'), '*.png')) 39 | for image_path in tqdm.tqdm(image_path_list): 40 | # image_name = os.path.basename(image_path) 41 | target_path = image_path.replace('inputL', type) 42 | assert os.path.isfile(image_path) 43 | assert os.path.isfile(target_path) 44 | pair_dict = {} 45 | pair_dict["image_path"] = image_path 46 | pair_dict["target_path"] = target_path 47 | pair_dict["type"] = "{}_defocusblur_dpdd".format(type) 48 | output_dict.append(pair_dict) 49 | 50 | 51 | image_path_list = glob.glob(os.path.join(image_dir.replace('inputC', 'inputR'), '*.png')) 52 | for image_path in tqdm.tqdm(image_path_list): 53 | # image_name = os.path.basename(image_path) 54 | target_path = image_path.replace('inputR', type) 55 | assert os.path.isfile(image_path) 56 | assert os.path.isfile(target_path) 57 | pair_dict = {} 58 | pair_dict["image_path"] = image_path 59 | 
pair_dict["target_path"] = target_path 60 | pair_dict["type"] = "{}_defocusblur_dpdd".format(type) 61 | output_dict.append(pair_dict) 62 | json.dump(output_dict, open(save_path, 'w')) -------------------------------------------------------------------------------- /ProRes/data/gen_json_fivek.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import tqdm 5 | 6 | 7 | if __name__ == "__main__": 8 | 9 | images_list = os.listdir("/horizon-bucket/BasicAlgorithm/Users/jiaqi.ma/fivek/input/") 10 | print(len(images_list)) 11 | 12 | images_list.sort() 13 | 14 | train_images_list = images_list[:4500] 15 | test_images_list = images_list[4500:] 16 | 17 | type = 'expertC_gt' 18 | 19 | 20 | output_dict_train = [] 21 | 22 | for image_path in tqdm.tqdm(train_images_list): 23 | # image_name = os.path.basename(image_path) 24 | image_dir = "./datasets/low_level/enhance/fivek/input/" 25 | image_path = os.path.join(image_dir, image_path) 26 | target_path = image_path.replace('input', 'expertC_gt') 27 | assert os.path.isfile(image_path) 28 | assert os.path.isfile(target_path) 29 | pair_dict = {} 30 | pair_dict["image_path"] = image_path 31 | pair_dict["target_path"] = target_path 32 | pair_dict["type"] = "{}_fivek".format(type) 33 | output_dict_train.append(pair_dict) 34 | save_path_train = "./datasets/low_level/{}-fivek_train.json".format(type) 35 | json.dump(output_dict_train, open(save_path_train, 'w')) 36 | 37 | 38 | output_dict_test = [] 39 | for image_path in tqdm.tqdm(test_images_list): 40 | # image_name = os.path.basename(image_path) 41 | image_dir = "./datasets/low_level/enhance/fivek/input/" 42 | image_path = os.path.join(image_dir, image_path) 43 | target_path = image_path.replace('input', 'expertC_gt') 44 | assert os.path.isfile(image_path) 45 | assert os.path.isfile(target_path) 46 | pair_dict = {} 47 | pair_dict["image_path"] = image_path 48 | pair_dict["target_path"] = target_path 49 | pair_dict["type"] = "{}_fivek".format(type) 50 | output_dict_test.append(pair_dict) 51 | save_path_test = "./datasets/low_level/{}-fivek_test.json".format(type) 52 | json.dump(output_dict_test, open(save_path_test, 'w')) -------------------------------------------------------------------------------- /ProRes/data/gen_json_lol.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import tqdm 5 | 6 | 7 | if __name__ == "__main__": 8 | type = 'gt' 9 | # type = 'gt_sub_input' 10 | 11 | # split = 'train' 12 | split = 'test' 13 | if split == 'train': 14 | image_dir = "./datasets/low_level/enhance/lol/our485/input/" 15 | save_path = "./datasets/low_level/{}-enhance_lol_train.json".format(type) 16 | elif split == 'test': 17 | image_dir = "./datasets/low_level/enhance/lol/eval15/input/" 18 | save_path = "./datasets/low_level/{}-enhance_lol_eval.json".format(type) 19 | else: 20 | raise NotImplementedError 21 | print(save_path) 22 | 23 | output_dict = [] 24 | 25 | image_path_list = glob.glob(os.path.join(image_dir, '*.png')) 26 | for image_path in tqdm.tqdm(image_path_list): 27 | # image_name = os.path.basename(image_path) 28 | target_path = image_path.replace('input', type) 29 | assert os.path.isfile(image_path) 30 | assert os.path.isfile(target_path) 31 | pair_dict = {} 32 | pair_dict["image_path"] = image_path 33 | pair_dict["target_path"] = target_path 34 | pair_dict["type"] = "{}_enhance_lol".format(type) 35 | output_dict.append(pair_dict) 36 | 37 | 
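# each record holds "image_path", "target_path" and "type"; the json.dump below writes the
# annotation file that PairDataset later reads via json_path_list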
json.dump(output_dict, open(save_path, 'w')) 38 | -------------------------------------------------------------------------------- /ProRes/data/gen_json_rain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import tqdm 5 | 6 | 7 | if __name__ == "__main__": 8 | type = 'target' 9 | # type = 'target_sub_input' 10 | 11 | # split = 'train' 12 | split = 'test' 13 | if split == 'train': 14 | image_dir = "./datasets/low_level/derain/train/input/" 15 | save_path = "./datasets/low_level/{}-derain_train.json".format(type) 16 | elif split == 'test': 17 | image_dir = "./datasets/low_level/derain/test/Rain100H/input/" 18 | save_path = "./datasets/low_level/{}-derain_test_rain100h.json".format(type) 19 | else: 20 | raise NotImplementedError 21 | print(save_path) 22 | 23 | output_dict = [] 24 | 25 | image_path_list = glob.glob(os.path.join(image_dir, '*.png')) 26 | # image_path_list = glob.glob(os.path.join(image_dir, '*.jpg')) 27 | for image_path in tqdm.tqdm(image_path_list): 28 | # image_name = os.path.basename(image_path) 29 | target_path = image_path.replace('input', type) 30 | assert os.path.isfile(image_path) 31 | assert os.path.isfile(target_path) 32 | pair_dict = {} 33 | pair_dict["image_path"] = image_path 34 | pair_dict["target_path"] = target_path 35 | pair_dict["type"] = "{}_derain".format(type) 36 | output_dict.append(pair_dict) 37 | 38 | json.dump(output_dict, open(save_path, 'w')) 39 | -------------------------------------------------------------------------------- /ProRes/data/gen_json_sidd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | import tqdm 5 | 6 | 7 | if __name__ == "__main__": 8 | # type = 'gt_sub_input' 9 | type = 'groundtruth' 10 | 11 | # split = 'train' 12 | split = 'val' 13 | if split == 'train': 14 | image_dir = "./datasets/low_level/denoising/sidd/train_448/input" 15 | save_path = "./datasets/low_level/{}-denoise_ssid_train448.json".format(type) 16 | elif split == 'val': 17 | image_dir = "./datasets/low_level/denoising/sidd/sidd_val_patch256/input" 18 | save_path = "./datasets/low_level/{}-denoise_ssid_val256.json".format(type) 19 | else: 20 | raise NotImplementedError 21 | print(save_path) 22 | 23 | output_dict = [] 24 | 25 | image_path_list = glob.glob(os.path.join(image_dir, '*.png')) 26 | for image_path in tqdm.tqdm(image_path_list): 27 | # image_name = os.path.basename(image_path) 28 | target_path = image_path.replace('input', type) 29 | assert os.path.isfile(image_path) 30 | assert os.path.isfile(target_path) 31 | pair_dict = {} 32 | pair_dict["image_path"] = image_path 33 | pair_dict["target_path"] = target_path 34 | pair_dict["type"] = "{}_denoise_ssid_448".format(type) 35 | output_dict.append(pair_dict) 36 | 37 | json.dump(output_dict, open(save_path, 'w')) 38 | -------------------------------------------------------------------------------- /ProRes/data/pair_transforms.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numbers 3 | import random 4 | import warnings 5 | from collections.abc import Sequence 6 | from typing import List, Optional, Tuple 7 | 8 | import torch 9 | from torch import Tensor 10 | import torchvision.transforms as transforms 11 | 12 | try: 13 | import accimage 14 | except ImportError: 15 | accimage = None 16 | 17 | import torchvision.transforms.functional as F 18 | from torchvision.transforms.functional import 
_interpolation_modes_from_int, InterpolationMode 19 | 20 | __all__ = [ 21 | "Compose", 22 | "ToTensor", 23 | "Normalize", 24 | "RandomHorizontalFlip", 25 | "RandomResizedCrop", 26 | ] 27 | 28 | 29 | 30 | class Compose(transforms.Compose): 31 | """Composes several transforms together. This transform does not support torchscript. 32 | Please, see the note below. 33 | Args: 34 | transforms (list of ``Transform`` objects): list of transforms to compose. 35 | """ 36 | 37 | def __init__(self, transforms): 38 | super().__init__(transforms) 39 | 40 | def __call__(self, img, tgt, interpolation1=None, interpolation2=None): 41 | for t in self.transforms: 42 | img, tgt = t(img, tgt) 43 | return img, tgt 44 | 45 | 46 | class ToTensor(transforms.ToTensor): 47 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. This transform does not support torchscript. 48 | Converts a PIL Image or numpy.ndarray (H x W x C) in the range 49 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] 50 | if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1) 51 | or if the numpy.ndarray has dtype = np.uint8 52 | In the other cases, tensors are returned without scaling. 53 | .. note:: 54 | Because the input image is scaled to [0.0, 1.0], this transformation should not be used when 55 | transforming target image masks. See the `references`_ for implementing the transforms for image masks. 56 | .. _references: https://github.com/pytorch/vision/tree/main/references/segmentation 57 | """ 58 | 59 | def __init__(self) -> None: 60 | super().__init__() 61 | 62 | def __call__(self, pic1, pic2, interpolation1=None, interpolation2=None): 63 | """ 64 | Args: 65 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 66 | Returns: 67 | Tensor: Converted image. 68 | """ 69 | return F.to_tensor(pic1), F.to_tensor(pic2) 70 | 71 | 72 | class Normalize(transforms.Normalize): 73 | """Normalize a tensor image with mean and standard deviation. 74 | This transform does not support PIL Image. 75 | Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n`` 76 | channels, this transform will normalize each channel of the input 77 | ``torch.*Tensor`` i.e., 78 | ``output[channel] = (input[channel] - mean[channel]) / std[channel]`` 79 | .. note:: 80 | This transform acts out of place, i.e., it does not mutate the input tensor. 81 | Args: 82 | mean (sequence): Sequence of means for each channel. 83 | std (sequence): Sequence of standard deviations for each channel. 84 | inplace(bool,optional): Bool to make this operation in-place. 85 | """ 86 | 87 | def __init__(self, mean, std, inplace=False): 88 | super().__init__(mean, std, inplace) 89 | 90 | def forward(self, tensor1: Tensor, tensor2: Tensor, interpolation1=None, interpolation2=None): 91 | """ 92 | Args: 93 | tensor (Tensor): Tensor image to be normalized. 94 | Returns: 95 | Tensor: Normalized Tensor image. 96 | """ 97 | return F.normalize(tensor1, self.mean, self.std, self.inplace), F.normalize(tensor2, self.mean, self.std, self.inplace) 98 | 99 | 100 | class RandomResizedCrop(transforms.RandomResizedCrop): 101 | """Crop a random portion of image and resize it to a given size. 102 | If the image is torch Tensor, it is expected 103 | to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions 104 | A crop of the original image is made: the crop has a random area (H * W) 105 | and a random aspect ratio. This crop is finally resized to the given 106 | size. 
This is popularly used to train the Inception networks. 107 | Args: 108 | size (int or sequence): expected output size of the crop, for each edge. If size is an 109 | int instead of sequence like (h, w), a square output size ``(size, size)`` is 110 | made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). 111 | .. note:: 112 | In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``. 113 | scale (tuple of float): Specifies the lower and upper bounds for the random area of the crop, 114 | before resizing. The scale is defined with respect to the area of the original image. 115 | ratio (tuple of float): lower and upper bounds for the random aspect ratio of the crop, before 116 | resizing. 117 | interpolation (InterpolationMode): Desired interpolation enum defined by 118 | :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. 119 | If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and 120 | ``InterpolationMode.BICUBIC`` are supported. 121 | For backward compatibility integer values (e.g. ``PIL.Image[.Resampling].NEAREST``) are still accepted, 122 | but deprecated since 0.13 and will be removed in 0.15. Please use InterpolationMode enum. 123 | """ 124 | 125 | def __init__( 126 | self, 127 | size, 128 | scale=(0.08, 1.0), 129 | ratio=(3.0 / 4.0, 4.0 / 3.0), 130 | interpolation=InterpolationMode.BILINEAR, 131 | ): 132 | super().__init__(size, scale=scale, ratio=ratio, interpolation=interpolation) 133 | 134 | def forward(self, img, tgt, interpolation1=None, interpolation2=None): 135 | """ 136 | Args: 137 | img (PIL Image or Tensor): Image to be cropped and resized. 138 | Returns: 139 | PIL Image or Tensor: Randomly cropped and resized image. 140 | """ 141 | i, j, h, w = self.get_params(img, self.scale, self.ratio) 142 | if interpolation1 == 'nearest': 143 | interpolation1 = InterpolationMode.NEAREST 144 | else: 145 | interpolation1 = InterpolationMode.BICUBIC 146 | if interpolation2 == 'nearest': 147 | interpolation2 = InterpolationMode.NEAREST 148 | else: 149 | interpolation2 = InterpolationMode.BICUBIC 150 | 151 | return F.resized_crop(img, i, j, h, w, self.size, interpolation1), \ 152 | F.resized_crop(tgt, i, j, h, w, self.size, interpolation2) 153 | 154 | 155 | class RandomHorizontalFlip(transforms.RandomHorizontalFlip): 156 | """Horizontally flip the given image randomly with a given probability. 157 | If the image is torch Tensor, it is expected 158 | to have [..., H, W] shape, where ... means an arbitrary number of leading 159 | dimensions 160 | Args: 161 | p (float): probability of the image being flipped. Default value is 0.5 162 | """ 163 | 164 | def __init__(self, p=0.5): 165 | super().__init__(p=p) 166 | 167 | def forward(self, img, tgt, interpolation1=None, interpolation2=None): 168 | """ 169 | Args: 170 | img (PIL Image or Tensor): Image to be flipped. 171 | Returns: 172 | PIL Image or Tensor: Randomly flipped image. 
173 | """ 174 | if torch.rand(1) < self.p: 175 | return F.hflip(img), F.hflip(tgt) 176 | return img, tgt 177 | 178 | -------------------------------------------------------------------------------- /ProRes/data/sampler.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, List, Optional, Union 2 | from collections import Counter 3 | import logging 4 | from operator import itemgetter 5 | from random import choices, sample 6 | 7 | import numpy as np 8 | 9 | import torch 10 | from torch.utils.data import Dataset, Sampler 11 | from torch.utils.data import DistributedSampler 12 | 13 | 14 | class DatasetFromSampler(Dataset): 15 | """Dataset to create indexes from `Sampler`. 16 | Args: 17 | sampler: PyTorch sampler 18 | """ 19 | 20 | def __init__(self, sampler: Sampler): 21 | """Initialisation for DatasetFromSampler.""" 22 | self.sampler = sampler 23 | self.sampler_list = None 24 | 25 | def __getitem__(self, index: int): 26 | """Gets element of the dataset. 27 | Args: 28 | index: index of the element in the dataset 29 | Returns: 30 | Single element by index 31 | """ 32 | if self.sampler_list is None: 33 | self.sampler_list = list(self.sampler) 34 | return self.sampler_list[index] 35 | 36 | def __len__(self) -> int: 37 | """ 38 | Returns: 39 | int: length of the dataset 40 | """ 41 | return len(self.sampler) 42 | 43 | 44 | class DistributedSamplerWrapper(DistributedSampler): 45 | """ 46 | Wrapper over `Sampler` for distributed training. 47 | Allows you to use any sampler in distributed mode. 48 | It is especially useful in conjunction with 49 | `torch.nn.parallel.DistributedDataParallel`. In such case, each 50 | process can pass a DistributedSamplerWrapper instance as a DataLoader 51 | sampler, and load a subset of subsampled data of the original dataset 52 | that is exclusive to it. 53 | .. note:: 54 | Sampler is assumed to be of constant size. 55 | """ 56 | 57 | def __init__( 58 | self, 59 | sampler, 60 | num_replicas: Optional[int] = None, 61 | rank: Optional[int] = None, 62 | shuffle: bool = True, 63 | ): 64 | """ 65 | Args: 66 | sampler: Sampler used for subsampling 67 | num_replicas (int, optional): Number of processes participating in 68 | distributed training 69 | rank (int, optional): Rank of the current process 70 | within ``num_replicas`` 71 | shuffle (bool, optional): If true (default), 72 | sampler will shuffle the indices 73 | """ 74 | super(DistributedSamplerWrapper, self).__init__( 75 | DatasetFromSampler(sampler), 76 | num_replicas=num_replicas, 77 | rank=rank, 78 | shuffle=shuffle, 79 | ) 80 | self.sampler = sampler 81 | 82 | def __iter__(self): 83 | """@TODO: Docs. Contribution is welcome.""" 84 | self.dataset = DatasetFromSampler(self.sampler) 85 | indexes_of_indexes = super().__iter__() 86 | subsampler_indexes = self.dataset 87 | return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) 88 | 89 | -------------------------------------------------------------------------------- /ProRes/data/unif.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import json 3 | from typing import Any, Callable, List, Optional, Tuple 4 | import random 5 | 6 | from PIL import Image 7 | import numpy as np 8 | 9 | import torch 10 | from torchvision.datasets.vision import VisionDataset, StandardTransform 11 | import torch.nn.functional as F 12 | 13 | class PairDataset(VisionDataset): 14 | """`MS Coco Detection `_ Dataset. 
15 | 16 | It requires the `COCO API to be installed `_. 17 | 18 | Args: 19 | root (string): Root directory where images are downloaded to. 20 | annFile (string): Path to json annotation file. 21 | transform (callable, optional): A function/transform that takes in an PIL image 22 | and returns a transformed version. E.g, ``transforms.PILToTensor`` 23 | target_transform (callable, optional): A function/transform that takes in the 24 | target and transforms it. 25 | transforms (callable, optional): A function/transform that takes input sample and its target as entry 26 | and returns a transformed version. 27 | """ 28 | 29 | def __init__( 30 | self, 31 | root: str, 32 | json_path_list: list, 33 | transform: Optional[Callable] = None, 34 | transform2: Optional[Callable] = None, 35 | transform3: Optional[Callable] = None, 36 | target_transform: Optional[Callable] = None, 37 | transforms: Optional[Callable] = None, 38 | masked_position_generator: Optional[Callable] = None, 39 | use_two_pairs: bool = False, 40 | half_mask_ratio:float = 0., 41 | ) -> None: 42 | super().__init__(root, transforms, transform, target_transform) 43 | 44 | self.pairs = [] 45 | self.weights = [] 46 | type_weight_list = [3, 3, 1, 3] 47 | # type_weight_list = [0.033, 1, 0.004, 0.029] 48 | # type_weight_list = [0.04, 0.96] 49 | # derain 13712 enhance 485 50 | for idx, json_path in enumerate(json_path_list): 51 | cur_pairs = json.load(open(json_path)) 52 | self.pairs.extend(cur_pairs) 53 | cur_num = len(cur_pairs) 54 | self.weights.extend([type_weight_list[idx] * 1./cur_num]*cur_num) 55 | print(json_path, type_weight_list[idx]) 56 | 57 | #self.weights = [1./n for n in self.weights] 58 | self.use_two_pairs = use_two_pairs 59 | if self.use_two_pairs: 60 | self.pair_type_dict = {} 61 | for idx, pair in enumerate(self.pairs): 62 | if "type" in pair: 63 | if pair["type"] not in self.pair_type_dict: 64 | self.pair_type_dict[pair["type"]] = [idx] 65 | else: 66 | self.pair_type_dict[pair["type"]].append(idx) 67 | for t in self.pair_type_dict: 68 | print(t, len(self.pair_type_dict[t])) 69 | 70 | self.transforms = PairStandardTransform(transform, target_transform) if transform is not None else None 71 | self.transforms2 = PairStandardTransform(transform2, target_transform) if transform2 is not None else None 72 | self.transforms3 = PairStandardTransform(transform3, target_transform) if transform3 is not None else None 73 | self.masked_position_generator = masked_position_generator 74 | self.half_mask_ratio = half_mask_ratio 75 | 76 | def _load_image(self, path: str) -> Image.Image: 77 | while True: 78 | try: 79 | img = Image.open(os.path.join(self.root, path)) 80 | except OSError as e: 81 | print(f"Catched exception: {str(e)}. Re-trying...") 82 | import time 83 | time.sleep(1) 84 | else: 85 | break 86 | # process for nyuv2 depth: scale to 0~255 87 | if "sync_depth" in path: 88 | # nyuv2's depth range is 0~10m 89 | img = np.array(img) / 10000. 
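            # the sync depth maps are presumably stored in millimetres (10 m -> 10000),
            # so dividing by 10000 maps the 0~10 m range into [0, 1] before the rescale
            # to 0~255 and the conversion to a 3-channel image below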
90 | img = img * 255 91 | img = Image.fromarray(img) 92 | img = img.convert("RGB") 93 | return img 94 | 95 | def _random_add_prompts_random_scales(self,image, prompt, prompt_range=[8,64], scale_range=[0.2,0.3]): 96 | image = np.asarray(image.permute(1,2,0)) 97 | prompt = prompt 98 | 99 | h, w = image.shape[0],image.shape[1] 100 | 101 | mask_image = np.ones((int(h),int(w),3),dtype=np.float32) 102 | mask_prompt = np.zeros((int(h),int(w),3),dtype=np.float32) 103 | 104 | ratio = 0 105 | 106 | while (scale_range[0] > ratio) == True or (ratio > scale_range[1])!=True: 107 | h_p = w_p = int(random.uniform(prompt_range[0], prompt_range[1])) 108 | point_h = int(random.uniform(h_p, h-h_p)) 109 | point_w = int(random.uniform(w_p, w-w_p)) 110 | 111 | mask_image[point_h:point_h+h_p,point_w:point_w+w_p] = 0.0 112 | mask_prompt[point_h:point_h+h_p,point_w:point_w+w_p] = 1.0 113 | prompts_token_num = np.sum(mask_prompt) 114 | ratio = prompts_token_num/(h*w) 115 | 116 | # image = image*mask_image 117 | 118 | # prompt = prompt * mask_prompt 119 | image = image + prompt 120 | 121 | return image 122 | 123 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 124 | pair = self.pairs[index] 125 | 126 | pair_type = pair['type'] 127 | if 'derain' in pair_type: 128 | type_dict = 'derain' 129 | elif 'enhance' in pair_type: 130 | type_dict = 'enhance' 131 | elif 'ssid' in pair_type: 132 | type_dict = 'ssid' 133 | elif 'deblur' in pair_type: 134 | type_dict = 'deblur' 135 | 136 | 137 | prompt_dict = { 138 | 'prompt_derain': 'datasets/low_level/derain.npy', 139 | 'prompt_enhance': 'datasets/low_level/enhance.npy', 140 | 'prompt_ssid': 'datasets/low_level/ssid.npy', 141 | 'prompt_deblur': 'datasets/low_level/deblur.npy' 142 | } 143 | 144 | key = next(k for k, v in prompt_dict.items() if type_dict in k) 145 | prompt = np.load(prompt_dict[key]) 146 | flag = torch.ones(()) 147 | # print(type_dict) 148 | # print(prompt_dict) 149 | # exit() 150 | 151 | image = self._load_image(pair['image_path']) 152 | target = self._load_image(pair['target_path']) 153 | 154 | 155 | interpolation1 = 'bicubic' 156 | interpolation2 = 'bicubic' 157 | 158 | # no aug for instance segmentation 159 | if "inst" in pair['type'] and self.transforms2 is not None: 160 | cur_transforms = self.transforms2 161 | elif "pose" in pair['type'] and self.transforms3 is not None: 162 | cur_transforms = self.transforms3 163 | else: 164 | cur_transforms = self.transforms 165 | 166 | image, target = cur_transforms(image, target, interpolation1, interpolation2) 167 | image_ori = image 168 | 169 | 170 | # first resize to 448*448 and combine them 171 | image = self._random_add_prompts_random_scales(image, prompt, prompt_range=[8,64], scale_range=[0.95,0.99]) 172 | 173 | image = torch.from_numpy(image.transpose(2, 0, 1)) 174 | 175 | valid = torch.ones_like(target) 176 | imagenet_mean=torch.tensor([0.485, 0.456, 0.406]) 177 | imagenet_std=torch.tensor([0.229, 0.224, 0.225]) 178 | 179 | 180 | mask = self.masked_position_generator() 181 | # mask all 0 182 | # valid all 1 183 | # 1 masked patch 184 | # 0 valid patch 185 | 186 | return image, target, mask, valid, flag, image_ori 187 | 188 | def __len__(self) -> int: 189 | return len(self.pairs) 190 | 191 | 192 | class PairStandardTransform(StandardTransform): 193 | def __init__(self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None: 194 | super().__init__(transform=transform, target_transform=target_transform) 195 | 196 | def __call__(self, input: Any, target: Any, interpolation1: 
Any, interpolation2: Any) -> Tuple[Any, Any]: 197 | if self.transform is not None: 198 | input, target = self.transform(input, target, interpolation1, interpolation2) 199 | #if self.target_transform is not None: 200 | # target = self.target_transform(target) 201 | return input, target 202 | -------------------------------------------------------------------------------- /ProRes/data/utils/check_black_image.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import glob 4 | import json 5 | import warnings 6 | import argparse 7 | import shutil 8 | 9 | import numpy as np 10 | import tqdm 11 | from PIL import Image 12 | 13 | 14 | def get_args_parser(): 15 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 16 | parser.add_argument('--tgt_dir', type=str, help='dir to ckpt', required=True) 17 | parser.add_argument('--remove', action="store_true", help='dir to ckpt', default=False) 18 | return parser.parse_args() 19 | 20 | 21 | def load_image_with_retry(image_path): 22 | while True: 23 | try: 24 | img = Image.open(image_path) 25 | return img 26 | except OSError as e: 27 | print(f"Catched exception: {str(e)}. Re-trying...") 28 | import time 29 | time.sleep(1) 30 | 31 | 32 | if __name__ == '__main__': 33 | args = get_args_parser() 34 | tgt_dir = args.tgt_dir 35 | 36 | image_list = glob.glob(os.path.join(tgt_dir, "*.png")) + glob.glob(os.path.join(tgt_dir, "*.jpg")) 37 | num_black = 0 38 | for image_path in tqdm.tqdm(image_list): 39 | image = load_image_with_retry(image_path) 40 | image = np.array(image) 41 | if (image == 0).all(): 42 | num_black += 1 43 | print("{}. {} is black!".format(num_black, image_path)) 44 | if args.remove: 45 | os.remove(image_path) 46 | 47 | print("num black: {}".format(num_black)) 48 | -------------------------------------------------------------------------------- /ProRes/data/utils/gen_image_flip.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import tqdm 4 | from PIL import Image 5 | import torchvision.transforms.functional as transform_func 6 | 7 | 8 | if __name__ == '__main__': 9 | image_src_dir = "coco/train2017" 10 | image_tgt_dir = "coco/train2017_flip" 11 | 12 | if not os.path.exists(image_tgt_dir): 13 | os.makedirs(image_tgt_dir) 14 | 15 | image_list = glob.glob(os.path.join(image_src_dir, "*.jpg")) 16 | for image_path in tqdm.tqdm(image_list): 17 | image = Image.open(image_path) 18 | image_flip = transform_func.hflip(image) 19 | 20 | file_name = os.path.basename(image_path).replace(".jpg", "_flip.jpg") 21 | tgt_path = os.path.join(image_tgt_dir, file_name) 22 | image_flip.save(tgt_path) -------------------------------------------------------------------------------- /ProRes/data/utils/get_num_obj.py: -------------------------------------------------------------------------------- 1 | """ 2 | get subset for quick evaluation 3 | """ 4 | import os 5 | import glob 6 | import json 7 | import tqdm 8 | import shutil 9 | 10 | 11 | if __name__ == "__main__": 12 | file_path = "coco/panoptic_val2017.json" 13 | 14 | data = json.load(open(file_path, 'r')) 15 | annotations = data['annotations'] # panoptic annos are saved in per image style 16 | categories = {category['id']: category for category in data['categories']} 17 | 18 | # note this includes crowd 19 | num_inst_list = [] 20 | for anno in annotations: 21 | num_inst = 0 22 | segments_info = anno['segments_info'] 23 | for seg in segments_info: 24 | if seg['iscrowd']: 
25 | continue 26 | if not categories[seg['category_id']]['isthing']: 27 | continue 28 | num_inst += 1 29 | # if num_inst != 90: 30 | num_inst_list.append(num_inst) 31 | 32 | print(max(num_inst_list)) 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /ProRes/data/utils/get_random_propmts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import glob 4 | 5 | 6 | def get_random_prompt(file_dir): 7 | file_list = glob.glob(os.path.join(file_dir, "*.jpg")) + glob.glob(os.path.join(file_dir, "*.png")) 8 | files = random.sample(file_list, 16) 9 | return files 10 | 11 | 12 | if __name__ == "__main__": 13 | # file_dir = "datasets/low_level/enhance/lol/our485/input" 14 | file_dir = "datasets/low_level/derain/train/input/" 15 | # file_dir = "data/low_level/denoising/sidd/train_448/input/" 16 | # file_dir = "data/coco/train2017" 17 | # file_dir = "data/coco/coco_pose_256x192/coco_pose_sigma1.5and3_train2017_maxoverlap_augflip1" 18 | # file_dir = "data/ade20k/images/training" 19 | # file_dir = "/sharefs/baaivision/xinlongwang/code/uip/data/nyuv2/sync/*/" 20 | files = get_random_prompt(file_dir) 21 | for f in files: 22 | # print(f) 23 | print(os.path.basename(f).split(".")[0]) 24 | -------------------------------------------------------------------------------- /ProRes/data/utils/get_subset.py: -------------------------------------------------------------------------------- 1 | """ 2 | get subset for quick evaluation 3 | """ 4 | import os 5 | import glob 6 | import json 7 | import tqdm 8 | import shutil 9 | 10 | 11 | def get_json_subset(file_path, output_path, num_keep=500): 12 | data = json.load(open(file_path, 'r')) 13 | # keys in data: dict_keys(['info', 'licenses', 'images', 'annotations', 'categories']) 14 | data['images'] = data['images'][:num_keep] 15 | with open(output_path, 'w') as f: 16 | json.dump(data, f) 17 | 18 | 19 | if __name__ == "__main__": 20 | num_keep = 500 21 | file_path = "coco/annotations/instances_val2017.json" 22 | output_path = file_path.replace(".json", "_first{}.json".format(num_keep)) 23 | if not os.path.exists(output_path): 24 | get_json_subset(file_path, output_path, num_keep) 25 | 26 | data_of_interest = json.load(open(output_path, 'r')) 27 | images = data_of_interest['images'] 28 | images_list = [img['file_name'] for img in images] 29 | 30 | # images_src_dir = "/sharefs/wwen/unified-vp/uip/models_inference/uip_rpe_vit_large_patch16_input640_win_dec64_8glb_lr1e-3_clip1.5_bs2x8x16_maeinit_mask392_depth_ade20k_cocomask_cocoins_cocosem_cocopose_bidi_new_nearest_25ep/" \ 31 | # "pano_inst_inference_epoch3_000000466730" 32 | images_src_dir = '/sharefs/wwen/unified-vp/uip/models_inference/' \ 33 | 'uip_rpe_vit_large_patch16_input640_win_dec64_8glb_lr1e-3_clip1.5_bs2x8x16_maeinit_mask392_depth_ade20k_cocomask_cocoins_cocosem_cocopose_wobidi_new_nearest_50ep_insworg_newweight_posex50_insx10/' \ 34 | 'pano_semseg_inference_epoch42_000000443397' 35 | # 'pano_inst_inference_epoch42_000000443397' 36 | assert images_src_dir[-1] != '/' 37 | images_tgt_dir = images_src_dir + "_first{}".format(num_keep) 38 | if not os.path.exists(images_tgt_dir): 39 | os.makedirs(images_tgt_dir) 40 | print(images_tgt_dir) 41 | 42 | image_path_list = glob.glob(os.path.join(images_src_dir, '*.png')) 43 | for image_path in tqdm.tqdm(image_path_list): 44 | # file_name = os.path.basename(image_path).split("_")[0] + ".jpg" 45 | file_name = os.path.basename(image_path).replace(".png", ".jpg") 
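        # predictions are saved as .png while the COCO file names in the subset json
        # are .jpg, so the extension is mapped back before checking membership in the
        # first-500 image list and copying matches into the subset directory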
46 | if file_name in images_list: 47 | shutil.copy(image_path, images_tgt_dir) 48 | 49 | -------------------------------------------------------------------------------- /ProRes/data/utils/get_train2017_subset.py: -------------------------------------------------------------------------------- 1 | """ 2 | get subset for quick evaluation 3 | """ 4 | import os 5 | import glob 6 | import json 7 | import tqdm 8 | import shutil 9 | 10 | 11 | if __name__ == "__main__": 12 | images_src_dir = "coco/train2017_copy" 13 | image_path_list = glob.glob(os.path.join(images_src_dir, '*.jpg')) 14 | num_images = len(image_path_list) 15 | 16 | images_tgt_dir = images_src_dir + "_{}".format(num_images // 2) 17 | if not os.path.exists(images_tgt_dir): 18 | os.makedirs(images_tgt_dir) 19 | else: 20 | raise NotImplementedError("{} exist!".format(images_tgt_dir)) 21 | print(images_tgt_dir) 22 | 23 | for image_path in tqdm.tqdm(image_path_list): 24 | images_tgt_path = os.path.join(images_tgt_dir, os.path.basename(image_path)) 25 | assert not os.path.isfile(images_tgt_path) 26 | shutil.move(image_path, images_tgt_dir) 27 | 28 | -------------------------------------------------------------------------------- /ProRes/datasets/low_level/deblur.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/ProRes/datasets/low_level/deblur.npy -------------------------------------------------------------------------------- /ProRes/datasets/low_level/derain.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/ProRes/datasets/low_level/derain.npy -------------------------------------------------------------------------------- /ProRes/datasets/low_level/enhance.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/ProRes/datasets/low_level/enhance.npy -------------------------------------------------------------------------------- /ProRes/datasets/low_level/gt-enhance_lol_eval.json: -------------------------------------------------------------------------------- 1 | [{"image_path": "./datasets/low_level/enhance/lol/eval15/input/748.png", "target_path": "./datasets/low_level/enhance/lol/eval15/gt/748.png", "type": "gt_enhance_lol"}, {"image_path": "./datasets/low_level/enhance/lol/eval15/input/665.png", "target_path": "./datasets/low_level/enhance/lol/eval15/gt/665.png", "type": "gt_enhance_lol"}, {"image_path": "./datasets/low_level/enhance/lol/eval15/input/111.png", "target_path": "./datasets/low_level/enhance/lol/eval15/gt/111.png", "type": "gt_enhance_lol"}, {"image_path": "./datasets/low_level/enhance/lol/eval15/input/493.png", "target_path": "./datasets/low_level/enhance/lol/eval15/gt/493.png", "type": "gt_enhance_lol"}, {"image_path": "./datasets/low_level/enhance/lol/eval15/input/22.png", "target_path": "./datasets/low_level/enhance/lol/eval15/gt/22.png", "type": "gt_enhance_lol"}, {"image_path": "./datasets/low_level/enhance/lol/eval15/input/23.png", "target_path": "./datasets/low_level/enhance/lol/eval15/gt/23.png", "type": "gt_enhance_lol"}, {"image_path": "./datasets/low_level/enhance/lol/eval15/input/669.png", "target_path": "./datasets/low_level/enhance/lol/eval15/gt/669.png", "type": "gt_enhance_lol"}, {"image_path": 
"./datasets/low_level/enhance/lol/eval15/input/547.png", "target_path": "./datasets/low_level/enhance/lol/eval15/gt/547.png", "type": "gt_enhance_lol"}, {"image_path": "./datasets/low_level/enhance/lol/eval15/input/778.png", "target_path": "./datasets/low_level/enhance/lol/eval15/gt/778.png", "type": "gt_enhance_lol"}, {"image_path": "./datasets/low_level/enhance/lol/eval15/input/55.png", "target_path": "./datasets/low_level/enhance/lol/eval15/gt/55.png", "type": "gt_enhance_lol"}, {"image_path": "./datasets/low_level/enhance/lol/eval15/input/179.png", "target_path": "./datasets/low_level/enhance/lol/eval15/gt/179.png", "type": "gt_enhance_lol"}, {"image_path": "./datasets/low_level/enhance/lol/eval15/input/780.png", "target_path": "./datasets/low_level/enhance/lol/eval15/gt/780.png", "type": "gt_enhance_lol"}, {"image_path": "./datasets/low_level/enhance/lol/eval15/input/79.png", "target_path": "./datasets/low_level/enhance/lol/eval15/gt/79.png", "type": "gt_enhance_lol"}, {"image_path": "./datasets/low_level/enhance/lol/eval15/input/146.png", "target_path": "./datasets/low_level/enhance/lol/eval15/gt/146.png", "type": "gt_enhance_lol"}, {"image_path": "./datasets/low_level/enhance/lol/eval15/input/1.png", "target_path": "./datasets/low_level/enhance/lol/eval15/gt/1.png", "type": "gt_enhance_lol"}] -------------------------------------------------------------------------------- /ProRes/datasets/low_level/ssid.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/ProRes/datasets/low_level/ssid.npy -------------------------------------------------------------------------------- /ProRes/demo/ddp_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from PIL import Image 4 | import numpy as np 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | import torch.distributed as dist 9 | 10 | 11 | class DatasetTest(Dataset): 12 | """ 13 | define dataset for ddp 14 | """ 15 | def __init__(self, img_src_dir, input_size, ext_list=('*.png', '*.jpg'), ): 16 | super(DatasetTest, self).__init__() 17 | self.img_src_dir = img_src_dir 18 | self.input_size = input_size 19 | 20 | img_path_list = [] 21 | for ext in ext_list: 22 | img_path_tmp = glob.glob(os.path.join(img_src_dir, ext)) 23 | img_path_list.extend(img_path_tmp) 24 | self.img_path_list = img_path_list 25 | 26 | def __len__(self): 27 | return len(self.img_path_list) 28 | 29 | def __getitem__(self, index): 30 | img_path = self.img_path_list[index] 31 | img = Image.open(img_path).convert("RGB") 32 | size_org = img.size 33 | img = img.resize((self.input_size, self.input_size)) 34 | img = np.array(img) / 255. 
35 | 36 | return img, img_path, size_org 37 | 38 | 39 | def collate_fn(batch): 40 | return batch 41 | # batch = list(zip(*batch)) 42 | # return tuple(batch) 43 | 44 | 45 | def setup_for_distributed(is_master): 46 | """ 47 | This function disables printing when not in master process 48 | """ 49 | import builtins as __builtin__ 50 | builtin_print = __builtin__.print 51 | 52 | def print(*args, **kwargs): 53 | force = kwargs.pop('force', False) 54 | if is_master or force: 55 | builtin_print(*args, **kwargs) 56 | 57 | __builtin__.print = print 58 | 59 | 60 | def is_dist_avail_and_initialized(): 61 | if not dist.is_available(): 62 | return False 63 | if not dist.is_initialized(): 64 | return False 65 | return True 66 | 67 | 68 | def get_world_size(): 69 | if not is_dist_avail_and_initialized(): 70 | return 1 71 | return dist.get_world_size() 72 | 73 | 74 | def get_rank(): 75 | if not is_dist_avail_and_initialized(): 76 | return 0 77 | return dist.get_rank() 78 | 79 | 80 | def is_main_process(): 81 | return get_rank() == 0 82 | 83 | 84 | def init_distributed_mode(args): 85 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ and 'LOCAL_RANK' in os.environ: 86 | args.rank = int(os.environ["RANK"]) 87 | args.world_size = int(os.environ['WORLD_SIZE']) 88 | args.gpu = int(os.environ['LOCAL_RANK']) 89 | elif 'SLURM_PROCID' in os.environ: 90 | args.rank = int(os.environ['SLURM_PROCID']) 91 | args.gpu = args.rank % torch.cuda.device_count() 92 | else: 93 | print('Not using distributed mode') 94 | args.distributed = False 95 | return args 96 | 97 | args.distributed = True 98 | 99 | torch.cuda.set_device(args.gpu) 100 | args.dist_backend = 'nccl' 101 | print('| distributed init (rank {}): {}'.format( 102 | args.rank, args.dist_url), flush=True) 103 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 104 | world_size=args.world_size, rank=args.rank) 105 | torch.distributed.barrier() 106 | setup_for_distributed(args.rank == 0) 107 | 108 | return args 109 | -------------------------------------------------------------------------------- /ProRes/demo/eval_sidd.m: -------------------------------------------------------------------------------- 1 | close all;clear all; 2 | 3 | % 5_282 4 | % 288_259 5 | % 215_127 6 | % 243_254 7 | % 263_83 8 | % 221_178 9 | % 104_111 10 | % 3_263 11 | % 300_46 12 | % 314_219 13 | % 147_261 14 | % 275_17 15 | % 100_172 16 | % 125_164 17 | % 311_45 18 | % 27_280 19 | 20 | denoised = load('/sharefs/wwen/unified-vp/uip/models_inference/new3_all_lr5e-4/sidd_inference_epoch14_27_280/Idenoised.mat'); 21 | gt = load('/sharefs/wwen/unified-vp/uip/data/low_level/denoising/sidd/val/ValidationGtBlocksSrgb.mat'); 22 | 23 | denoised = denoised.Idenoised; 24 | gt = gt.ValidationGtBlocksSrgb; 25 | gt = im2single(gt); 26 | 27 | total_psnr = 0; 28 | total_ssim = 0; 29 | for i = 1:40 30 | for k = 1:32 31 | denoised_patch = squeeze(denoised(i,k,:,:,:)); 32 | gt_patch = squeeze(gt(i,k,:,:,:)); 33 | ssim_val = ssim(denoised_patch, gt_patch); 34 | psnr_val = psnr(denoised_patch, gt_patch); 35 | total_ssim = total_ssim + ssim_val; 36 | total_psnr = total_psnr + psnr_val; 37 | end 38 | end 39 | qm_psnr = total_psnr / (40*32); 40 | qm_ssim = total_ssim / (40*32); 41 | 42 | fprintf('PSNR: %f SSIM: %f\n', qm_psnr, qm_ssim); -------------------------------------------------------------------------------- /ProRes/demo/evaluate_PSNR_SSIM.m: -------------------------------------------------------------------------------- 1 | % Multi-Stage Progressive Image 
Restoration 2 | % Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, Ming-Hsuan Yang, and Ling Shao 3 | % https://arxiv.org/abs/2102.02808 4 | 5 | close all;clear all; 6 | 7 | % datasets = {'Rain100L'}; 8 | datasets = {'Test100', 'Rain100H', 'Rain100L', 'Test2800', 'Test1200'}; 9 | num_set = length(datasets); 10 | 11 | psnr_alldatasets = 0; 12 | ssim_alldatasets = 0; 13 | 14 | tic 15 | 16 | % 13616 17 | % 9303 18 | % 1045 19 | % 9463 20 | % 5181 21 | % 4120 22 | % 6294 23 | % 1491 24 | % 770 25 | % 5459 26 | % 2606 27 | % 9069 28 | % 9210 29 | % 5865 30 | % 335 31 | % 12664 32 | 33 | 34 | for idx_set = 1:num_set 35 | file_path = strcat('/sharefs/wwen/unified-vp/uip/models_inference/new3_all_lr5e-4/derain_inference_epoch14_12664/', datasets{idx_set}, '/'); 36 | gt_path = strcat('/sharefs/wwen/unified-vp/uip/data/low_level/derain/testsets/', datasets{idx_set}, '/target/'); 37 | path_list = [dir(strcat(file_path,'*.jpg')); dir(strcat(file_path,'*.png'))]; 38 | gt_list = [dir(strcat(gt_path,'*.jpg')); dir(strcat(gt_path,'*.png'))]; 39 | img_num = length(path_list); 40 | 41 | total_psnr = 0; 42 | total_ssim = 0; 43 | if img_num > 0 44 | for j = 1:img_num 45 | image_name = path_list(j).name; 46 | gt_name = gt_list(j).name; 47 | input = imread(strcat(file_path,image_name)); 48 | gt = imread(strcat(gt_path, gt_name)); 49 | ssim_val = compute_ssim(input, gt); 50 | psnr_val = compute_psnr(input, gt); 51 | total_ssim = total_ssim + ssim_val; 52 | total_psnr = total_psnr + psnr_val; 53 | end 54 | end 55 | qm_psnr = total_psnr / img_num; 56 | qm_ssim = total_ssim / img_num; 57 | 58 | fprintf('For %s dataset PSNR: %f SSIM: %f\n', datasets{idx_set}, qm_psnr, qm_ssim); 59 | 60 | psnr_alldatasets = psnr_alldatasets + qm_psnr; 61 | ssim_alldatasets = ssim_alldatasets + qm_ssim; 62 | 63 | end 64 | 65 | fprintf('For all datasets PSNR: %f SSIM: %f\n', psnr_alldatasets/num_set, ssim_alldatasets/num_set); 66 | 67 | toc 68 | 69 | function ssim_mean=compute_ssim(img1,img2) 70 | if size(img1, 3) == 3 71 | img1 = rgb2ycbcr(img1); 72 | img1 = img1(:, :, 1); 73 | end 74 | 75 | if size(img2, 3) == 3 76 | img2 = rgb2ycbcr(img2); 77 | img2 = img2(:, :, 1); 78 | end 79 | ssim_mean = SSIM_index(img1, img2); 80 | end 81 | 82 | function psnr=compute_psnr(img1,img2) 83 | if size(img1, 3) == 3 84 | img1 = rgb2ycbcr(img1); 85 | img1 = img1(:, :, 1); 86 | end 87 | 88 | if size(img2, 3) == 3 89 | img2 = rgb2ycbcr(img2); 90 | img2 = img2(:, :, 1); 91 | end 92 | 93 | imdff = double(img1) - double(img2); 94 | imdff = imdff(:); 95 | rmse = sqrt(mean(imdff.^2)); 96 | psnr = 20*log10(255/rmse); 97 | 98 | end 99 | 100 | function [mssim, ssim_map] = SSIM_index(img1, img2, K, window, L) 101 | 102 | if (nargin < 2 || nargin > 5) 103 | ssim_index = -Inf; 104 | ssim_map = -Inf; 105 | return; 106 | end 107 | 108 | if (size(img1) ~= size(img2)) 109 | ssim_index = -Inf; 110 | ssim_map = -Inf; 111 | return; 112 | end 113 | 114 | [M N] = size(img1); 115 | 116 | if (nargin == 2) 117 | if ((M < 11) || (N < 11)) 118 | ssim_index = -Inf; 119 | ssim_map = -Inf; 120 | return 121 | end 122 | window = fspecial('gaussian', 11, 1.5); % 123 | K(1) = 0.01; % default settings 124 | K(2) = 0.03; % 125 | L = 255; % 126 | end 127 | 128 | if (nargin == 3) 129 | if ((M < 11) || (N < 11)) 130 | ssim_index = -Inf; 131 | ssim_map = -Inf; 132 | return 133 | end 134 | window = fspecial('gaussian', 11, 1.5); 135 | L = 255; 136 | if (length(K) == 2) 137 | if (K(1) < 0 || K(2) < 0) 138 | ssim_index = -Inf; 139 | ssim_map = -Inf; 140 | return; 
141 | end 142 | else 143 | ssim_index = -Inf; 144 | ssim_map = -Inf; 145 | return; 146 | end 147 | end 148 | 149 | if (nargin == 4) 150 | [H W] = size(window); 151 | if ((H*W) < 4 || (H > M) || (W > N)) 152 | ssim_index = -Inf; 153 | ssim_map = -Inf; 154 | return 155 | end 156 | L = 255; 157 | if (length(K) == 2) 158 | if (K(1) < 0 || K(2) < 0) 159 | ssim_index = -Inf; 160 | ssim_map = -Inf; 161 | return; 162 | end 163 | else 164 | ssim_index = -Inf; 165 | ssim_map = -Inf; 166 | return; 167 | end 168 | end 169 | 170 | if (nargin == 5) 171 | [H W] = size(window); 172 | if ((H*W) < 4 || (H > M) || (W > N)) 173 | ssim_index = -Inf; 174 | ssim_map = -Inf; 175 | return 176 | end 177 | if (length(K) == 2) 178 | if (K(1) < 0 || K(2) < 0) 179 | ssim_index = -Inf; 180 | ssim_map = -Inf; 181 | return; 182 | end 183 | else 184 | ssim_index = -Inf; 185 | ssim_map = -Inf; 186 | return; 187 | end 188 | end 189 | 190 | C1 = (K(1)*L)^2; 191 | C2 = (K(2)*L)^2; 192 | window = window/sum(sum(window)); 193 | img1 = double(img1); 194 | img2 = double(img2); 195 | 196 | mu1 = filter2(window, img1, 'valid'); 197 | mu2 = filter2(window, img2, 'valid'); 198 | mu1_sq = mu1.*mu1; 199 | mu2_sq = mu2.*mu2; 200 | mu1_mu2 = mu1.*mu2; 201 | sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq; 202 | sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq; 203 | sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2; 204 | 205 | if (C1 > 0 & C2 > 0) 206 | ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2)); 207 | else 208 | numerator1 = 2*mu1_mu2 + C1; 209 | numerator2 = 2*sigma12 + C2; 210 | denominator1 = mu1_sq + mu2_sq + C1; 211 | denominator2 = sigma1_sq + sigma2_sq + C2; 212 | ssim_map = ones(size(mu1)); 213 | index = (denominator1.*denominator2 > 0); 214 | ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index)); 215 | index = (denominator1 ~= 0) & (denominator2 == 0); 216 | ssim_map(index) = numerator1(index)./denominator1(index); 217 | end 218 | 219 | mssim = mean2(ssim_map); 220 | 221 | end 222 | -------------------------------------------------------------------------------- /ProRes/demo/matrix_nms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | 4 | 5 | def mask_matrix_nms(masks, 6 | labels, 7 | scores, 8 | filter_thr=-1, 9 | nms_pre=-1, 10 | max_num=-1, 11 | kernel='gaussian', 12 | sigma=2.0, 13 | mask_area=None): 14 | """Matrix NMS for multi-class masks. 15 | 16 | Args: 17 | masks (Tensor): Has shape (num_instances, h, w) 18 | labels (Tensor): Labels of corresponding masks, 19 | has shape (num_instances,). 20 | scores (Tensor): Mask scores of corresponding masks, 21 | has shape (num_instances). 22 | filter_thr (float): Score threshold to filter the masks 23 | after matrix nms. Default: -1, which means do not 24 | use filter_thr. 25 | nms_pre (int): The max number of instances to do the matrix nms. 26 | Default: -1, which means do not use nms_pre. 27 | max_num (int, optional): If there are more than max_num masks after 28 | matrix, only top max_num will be kept. Default: -1, which means 29 | do not use max_num. 30 | kernel (str): 'linear' or 'gaussian'. 31 | sigma (float): std in gaussian method. 32 | mask_area (Tensor): The sum of seg_masks. 33 | 34 | Returns: 35 | tuple(Tensor): Processed mask results. 36 | 37 | - scores (Tensor): Updated scores, has shape (n,). 
38 | - labels (Tensor): Remained labels, has shape (n,). 39 | - masks (Tensor): Remained masks, has shape (n, w, h). 40 | - keep_inds (Tensor): The indices number of 41 | the remaining mask in the input mask, has shape (n,). 42 | """ 43 | assert len(labels) == len(masks) == len(scores) 44 | if len(labels) == 0: 45 | return scores.new_zeros(0), labels.new_zeros(0), masks.new_zeros( 46 | 0, *masks.shape[-2:]), labels.new_zeros(0) 47 | if mask_area is None: 48 | mask_area = masks.sum((1, 2)).float() 49 | else: 50 | assert len(masks) == len(mask_area) 51 | 52 | # sort and keep top nms_pre 53 | scores, sort_inds = torch.sort(scores, descending=True) 54 | 55 | keep_inds = sort_inds 56 | if nms_pre > 0 and len(sort_inds) > nms_pre: 57 | sort_inds = sort_inds[:nms_pre] 58 | keep_inds = keep_inds[:nms_pre] 59 | scores = scores[:nms_pre] 60 | masks = masks[sort_inds] 61 | mask_area = mask_area[sort_inds] 62 | labels = labels[sort_inds] 63 | 64 | num_masks = len(labels) 65 | flatten_masks = masks.reshape(num_masks, -1).float() 66 | # inter. 67 | inter_matrix = torch.mm(flatten_masks, flatten_masks.transpose(1, 0)) 68 | expanded_mask_area = mask_area.expand(num_masks, num_masks) # todo: +1 to avoid nan? 69 | # Upper triangle iou matrix. 70 | iou_matrix = (inter_matrix / 71 | (expanded_mask_area + expanded_mask_area.transpose(1, 0) - 72 | inter_matrix)).triu(diagonal=1) 73 | # label_specific matrix. 74 | expanded_labels = labels.expand(num_masks, num_masks) 75 | # Upper triangle label matrix. 76 | label_matrix = (expanded_labels == expanded_labels.transpose( 77 | 1, 0)).triu(diagonal=1) 78 | 79 | # IoU compensation 80 | compensate_iou, _ = (iou_matrix * label_matrix).max(0) 81 | compensate_iou = compensate_iou.expand(num_masks, 82 | num_masks).transpose(1, 0) 83 | 84 | # IoU decay 85 | decay_iou = iou_matrix * label_matrix 86 | 87 | # Calculate the decay_coefficient 88 | if kernel == 'gaussian': 89 | decay_matrix = torch.exp(-1 * sigma * (decay_iou**2)) 90 | compensate_matrix = torch.exp(-1 * sigma * (compensate_iou**2)) 91 | decay_coefficient, _ = (decay_matrix / compensate_matrix).min(0) 92 | elif kernel == 'linear': 93 | decay_matrix = (1 - decay_iou) / (1 - compensate_iou) 94 | decay_coefficient, _ = decay_matrix.min(0) 95 | else: 96 | raise NotImplementedError( 97 | f'{kernel} kernel is not supported in matrix nms!') 98 | # update the score. 
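    # each remaining mask's score is decayed according to how strongly it overlaps
    # with higher-scoring masks of the same class: with the gaussian kernel the
    # coefficient for mask j is min_i exp(-sigma * (iou_ij**2 - compensate_i**2)),
    # and with the linear kernel it is min_i (1 - iou_ij) / (1 - compensate_i),
    # where compensate_i is the largest same-class IoU that suppressor i itself
    # receives from masks ranked above it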
99 | scores = scores * decay_coefficient 100 | 101 | if filter_thr > 0: 102 | keep = scores >= filter_thr 103 | keep_inds = keep_inds[keep] 104 | if not keep.any(): 105 | return scores.new_zeros(0), labels.new_zeros(0), masks.new_zeros( 106 | 0, *masks.shape[-2:]), labels.new_zeros(0) 107 | masks = masks[keep] 108 | scores = scores[keep] 109 | labels = labels[keep] 110 | 111 | # sort and keep top max_num 112 | scores, sort_inds = torch.sort(scores, descending=True) 113 | keep_inds = keep_inds[sort_inds] 114 | if max_num > 0 and len(sort_inds) > max_num: 115 | sort_inds = sort_inds[:max_num] 116 | keep_inds = keep_inds[:max_num] 117 | scores = scores[:max_num] 118 | masks = masks[sort_inds] 119 | labels = labels[sort_inds] 120 | 121 | return scores, labels, masks, keep_inds 122 | -------------------------------------------------------------------------------- /ProRes/demo/ours_inference_deblur_v2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import sys 4 | import os 5 | import warnings 6 | 7 | import requests 8 | import argparse 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | import numpy as np 13 | import glob 14 | import tqdm 15 | 16 | import matplotlib.pyplot as plt 17 | from PIL import Image 18 | 19 | sys.path.append('.') 20 | import models_ours 21 | 22 | import random 23 | from skimage.metrics import peak_signal_noise_ratio as psnr_loss 24 | from skimage.metrics import structural_similarity as ssim_loss 25 | from util.metrics import calculate_psnr, calculate_ssim 26 | 27 | imagenet_mean = np.array([0.485, 0.456, 0.406]) 28 | imagenet_std = np.array([0.229, 0.224, 0.225]) 29 | 30 | 31 | 32 | def get_args_parser(): 33 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 34 | parser.add_argument('--ckpt_dir', type=str, help='dir to ckpt', 35 | default='/sharefs/baaivision/xinlongwang/code/uip/models/' 36 | 'new3_all_lr5e-4') 37 | parser.add_argument('--model', type=str, help='dir to ckpt', 38 | default='uip_vit_large_patch16_input896x448_win_dec64_8glb_sl1') 39 | parser.add_argument('--prompt', type=str, help='prompt image in train set', 40 | default='100') 41 | parser.add_argument('--epoch', type=int, help='model epochs', 42 | default=14) 43 | parser.add_argument('--input_size', type=int, help='model epochs', 44 | default=448) 45 | parser.add_argument('--split', type=int, help='model epochs', choices=[1, 2, 3, 4], 46 | default=3) 47 | parser.add_argument('--pred_gt', action='store_true', help='trained by using gt as gt', 48 | default=False) 49 | parser.add_argument('--save', action='store_true', help='save predictions', 50 | default=False) 51 | return parser.parse_args() 52 | 53 | 54 | def prepare_model(chkpt_dir, arch='mae_vit_base_patch16'): 55 | # build model 56 | model = getattr(models_ours, arch)() 57 | # load model 58 | checkpoint = torch.load(chkpt_dir, map_location='cuda:0') 59 | msg = model.load_state_dict(checkpoint['model'], strict=False) 60 | print(msg) 61 | return model 62 | 63 | def random_add_prompts_random_scales(image, prompt, prompt_range=[8,64], scale_range=[0.2,0.3]): 64 | image = image 65 | prompt = prompt 66 | h, w = image.shape[0],image.shape[1] 67 | 68 | mask_image = np.ones((int(h),int(w),3)) 69 | mask_prompt = np.zeros((int(h),int(w),3)) 70 | 71 | ratio = 0 72 | 73 | while (scale_range[0] > ratio) == True or (ratio > scale_range[1])!=True: 74 | h_p = w_p = int(random.uniform(prompt_range[0], prompt_range[1])) 75 | point_h = int(random.uniform(h_p, h-h_p)) 
76 | point_w = int(random.uniform(w_p, w-w_p)) 77 | 78 | mask_image[point_h:point_h+h_p,point_w:point_w+w_p] = 0.0 79 | mask_prompt[point_h:point_h+h_p,point_w:point_w+w_p] = 1.0 80 | prompts_token_num = np.sum(mask_prompt) 81 | ratio = prompts_token_num/(h*w) 82 | 83 | # image = image*mask_image 84 | # prompt = prompt*mask_prompt 85 | image = image + prompt 86 | 87 | return image 88 | 89 | def run_one_image(img, tgt, prompt_org, size, model, out_path, device): 90 | x = torch.tensor(img) 91 | # make it a batch-like 92 | x = x.unsqueeze(dim=0) 93 | x = torch.einsum('nhwc->nchw', x) 94 | 95 | tgt = torch.tensor(tgt) 96 | # make it a batch-like 97 | tgt = tgt.unsqueeze(dim=0) 98 | tgt = torch.einsum('nhwc->nchw', tgt) 99 | 100 | prompt_org = torch.tensor(prompt_org) 101 | # make it a batch-like 102 | prompt_org = prompt_org.unsqueeze(dim=0) 103 | prompt_org = torch.einsum('nhwc->nchw', prompt_org) 104 | 105 | # bool_masked_pos = torch.zeros(model.patch_embed.num_patches) 106 | # bool_masked_pos[model.patch_embed.num_patches//2:] = 1 107 | # bool_masked_pos = bool_masked_pos.unsqueeze(dim=0) 108 | 109 | # run MAE 110 | loss, y = model(x.float().to(device), tgt.float().to(device)) 111 | y = model.unpatchify(y) 112 | y = torch.einsum('nchw->nhwc', y).detach().cpu() 113 | 114 | output = y[0, :, :, :] 115 | output = output * imagenet_std + imagenet_mean 116 | output = F.interpolate( 117 | output[None, ...].permute(0, 3, 1, 2), size=[size[1], size[0]], mode='bicubic').permute(0, 2, 3, 1)[0] 118 | 119 | return output.numpy() 120 | 121 | 122 | if __name__ == '__main__': 123 | args = get_args_parser() 124 | 125 | ckpt_dir = args.ckpt_dir 126 | model = args.model 127 | epoch = args.epoch 128 | prompt = args.prompt 129 | input_size = args.input_size 130 | prompt_type = 'groundtruth' if args.pred_gt else 'target_sub_input' 131 | 132 | ckpt_file = 'checkpoint-{}.pth'.format(epoch) 133 | assert ckpt_dir[-1] != "/" 134 | dst_dir = os.path.join('models_inference', ckpt_dir.split('/')[-1], 135 | "deblur_inference_epoch{}_{}".format(epoch, os.path.basename(prompt).split(".")[0])) 136 | 137 | if os.path.exists(dst_dir): 138 | # raise Exception("{} exist! make sure to overwrite?".format(dst_dir)) 139 | warnings.warn("{} exist! 
make sure to overwrite?".format(dst_dir)) 140 | else: 141 | os.makedirs(dst_dir) 142 | print("output_dir: {}".format(dst_dir)) 143 | 144 | ckpt_path = os.path.join(ckpt_dir, ckpt_file) 145 | model_mae = prepare_model(ckpt_path, model) 146 | print('Model loaded.') 147 | 148 | device = torch.device("cuda") 149 | model_mae.to(device) 150 | 151 | 152 | model_mae.eval() 153 | datasets = ['GoPro','HIDE','RealBlur_R','RealBlur_J'] 154 | 155 | psnr_alldatasets = [] 156 | ssim_alldatasets = [] 157 | print(datasets) 158 | img_src_dir = "datasets/low_level/deblur/test/" 159 | for dset in datasets: 160 | psnr_val_rgb = [] 161 | ssim_val_rgb = [] 162 | real_src_dir = os.path.join(img_src_dir, dset, 'input') 163 | real_dst_dir = os.path.join(dst_dir, dset) 164 | if not os.path.exists(real_dst_dir): 165 | os.makedirs(real_dst_dir) 166 | img_path_list = glob.glob(os.path.join(real_src_dir, "*.png")) + glob.glob(os.path.join(real_src_dir, "*.jpg")) 167 | for img_path in tqdm.tqdm(img_path_list): 168 | """ Load an image """ 169 | img_name = os.path.basename(img_path) 170 | out_path = os.path.join(real_dst_dir, img_name.replace('jpg', 'png')) # TODO: save all results as pngs 171 | img_org = Image.open(img_path).convert("RGB") 172 | 173 | size = img_org.size 174 | img = img_org.resize((input_size, input_size)) 175 | img = np.array(img) / 255. 176 | 177 | img = img - imagenet_mean 178 | img = img / imagenet_std 179 | 180 | prompt_org = np.load('datasets/low_level/deblur.npy') 181 | 182 | # img = random_add_prompts_random_scales(img,prompt_org,prompt_range=[8,64],scale_range=[0.2,0.3]) 183 | # simple add 184 | img = img + prompt_org 185 | 186 | 187 | # load gt 188 | rgb_gt = Image.open(img_path.replace('input', 'groundtruth')).convert("RGB") # irrelevant to prompt-type 189 | 190 | tgt = rgb_gt.resize((input_size, input_size)) 191 | tgt = np.array(tgt) / 255. 192 | 193 | 194 | # normalize by ImageNet mean and std 195 | tgt = tgt - imagenet_mean 196 | tgt = tgt / imagenet_std 197 | 198 | """### Run MAE on the image""" 199 | # make random mask reproducible (comment out to make it change) 200 | torch.manual_seed(2) 201 | 202 | output = run_one_image(img, tgt, prompt_org, size, model_mae, out_path, device) 203 | 204 | rgb_restored = output 205 | 206 | rgb_restored = np.clip(rgb_restored, 0, 1) 207 | 208 | 209 | rgb_gt = np.array(rgb_gt) / 255. 
210 | 211 | psnr = calculate_psnr(rgb_restored*255., rgb_gt*255., 0, test_y_channel=False) 212 | ssim = calculate_ssim(rgb_restored*255., rgb_gt*255., 0, test_y_channel=False) 213 | # psnr = psnr_loss(rgb_restored, rgb_gt) 214 | # ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True) 215 | psnr_val_rgb.append(psnr) 216 | ssim_val_rgb.append(ssim) 217 | # print("PSNR:", psnr, ",", img_name, rgb_restored.shape) 218 | # print("PSNR:", psnr, ", SSIM:", ssim, img_name) 219 | 220 | if args.save: 221 | # utils.save_img(out_path, img_as_ubyte(rgb_restored)) 222 | output = rgb_restored * 255 223 | output = Image.fromarray(output.astype(np.uint8)) 224 | output.save(out_path) 225 | 226 | # with open(os.path.join(dst_dir, 'psnr_ssim.txt'), 'a') as f: 227 | # # f.write(img_name+' ---->'+" PSNR: %.4f, SSIM: %.4f] " % (psnr, ssim)+'\n') 228 | # f.write(img_name+' ---->'+" PSNR: %.4f" % (psnr)+'\n') 229 | 230 | psnr_val_rgb = sum(psnr_val_rgb) / len(img_path_list) 231 | ssim_val_rgb = sum(ssim_val_rgb) / len(img_path_list) 232 | psnr_alldatasets.append(psnr_val_rgb) 233 | ssim_alldatasets.append(ssim_val_rgb) 234 | print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb)) 235 | 236 | psnr_all = sum(psnr_alldatasets) / len(datasets) 237 | ssim_all = sum(ssim_alldatasets) / len(datasets) 238 | print("PSNR: %f, SSIM: %f " % (psnr_all, ssim_all)) 239 | 240 | # # print("PSNR: %f" % (psnr_val_rgb)) 241 | # print(ckpt_path) 242 | # with open(os.path.join(dst_dir, 'psnr_ssim.txt'), 'a') as f: 243 | # f.write("PSNR: %.4f, SSIM: %.4f] " % (psnr_val_rgb, ssim_val_rgb)+'\n') 244 | # # f.write("PSNR: %.4f" % (psnr_val_rgb)+'\n') 245 | -------------------------------------------------------------------------------- /ProRes/demo/ours_inference_derain_v2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import sys 5 | import os 6 | import warnings 7 | 8 | import requests 9 | import argparse 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | import numpy as np 14 | import glob 15 | import tqdm 16 | 17 | import matplotlib.pyplot as plt 18 | from PIL import Image 19 | 20 | sys.path.append('.') 21 | import models_ours 22 | 23 | 24 | import random 25 | from skimage.metrics import peak_signal_noise_ratio as psnr_loss 26 | from skimage.metrics import structural_similarity as ssim_loss 27 | from util.metrics import calculate_psnr, calculate_ssim 28 | 29 | imagenet_mean = np.array([0.485, 0.456, 0.406]) 30 | imagenet_std = np.array([0.229, 0.224, 0.225]) 31 | 32 | 33 | 34 | 35 | def get_args_parser(): 36 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 37 | parser.add_argument('--ckpt_dir', type=str, help='dir to ckpt', 38 | default='/sharefs/baaivision/xinlongwang/code/uip/models/' 39 | 'new3_all_lr5e-4') 40 | parser.add_argument('--model', type=str, help='dir to ckpt', 41 | default='uip_vit_large_patch16_input896x448_win_dec64_8glb_sl1') 42 | parser.add_argument('--prompt', type=str, help='prompt image in train set', 43 | default='100') 44 | parser.add_argument('--epoch', type=int, help='model epochs', 45 | default=14) 46 | parser.add_argument('--input_size', type=int, help='model epochs', 47 | default=448) 48 | parser.add_argument('--split', type=int, help='model epochs', choices=[1, 2, 3, 4], 49 | default=3) 50 | parser.add_argument('--pred_gt', action='store_true', help='trained by using gt as gt', 51 | default=False) 52 | parser.add_argument('--save', action='store_true', help='save predictions', 53 | 
default=False) 54 | return parser.parse_args() 55 | 56 | 57 | def prepare_model(chkpt_dir, arch='mae_vit_base_patch16'): 58 | # build model 59 | model = getattr(models_ours, arch)() 60 | # load model 61 | checkpoint = torch.load(chkpt_dir, map_location='cuda:0') 62 | msg = model.load_state_dict(checkpoint['model'], strict=False) 63 | print(msg) 64 | return model 65 | 66 | def random_add_prompts_random_scales(image, prompt, prompt_range=[8,64], scale_range=[0.2,0.3]): 67 | image = image 68 | prompt = prompt 69 | h, w = image.shape[0],image.shape[1] 70 | 71 | mask_image = np.ones((int(h),int(w),3)) 72 | mask_prompt = np.zeros((int(h),int(w),3)) 73 | 74 | ratio = 0 75 | 76 | while (scale_range[0] > ratio) == True or (ratio > scale_range[1])!=True: 77 | h_p = w_p = int(random.uniform(prompt_range[0], prompt_range[1])) 78 | point_h = int(random.uniform(h_p, h-h_p)) 79 | point_w = int(random.uniform(w_p, w-w_p)) 80 | 81 | mask_image[point_h:point_h+h_p,point_w:point_w+w_p] = 0.0 82 | mask_prompt[point_h:point_h+h_p,point_w:point_w+w_p] = 1.0 83 | prompts_token_num = np.sum(mask_prompt) 84 | ratio = prompts_token_num/(h*w) 85 | 86 | # image = image*mask_image 87 | # prompt = prompt*mask_prompt 88 | image = image + prompt 89 | 90 | return image 91 | 92 | def run_one_image(img, tgt, prompt_org, size, model, out_path, device): 93 | x = torch.tensor(img) 94 | # make it a batch-like 95 | x = x.unsqueeze(dim=0) 96 | x = torch.einsum('nhwc->nchw', x) 97 | 98 | tgt = torch.tensor(tgt) 99 | # make it a batch-like 100 | tgt = tgt.unsqueeze(dim=0) 101 | tgt = torch.einsum('nhwc->nchw', tgt) 102 | 103 | # prompt_org = torch.tensor(prompt_org) 104 | # # make it a batch-like 105 | # prompt_org = prompt_org.unsqueeze(dim=0) 106 | # prompt_org = torch.einsum('nhwc->nchw', prompt_org) 107 | 108 | # bool_masked_pos = torch.zeros(model.patch_embed.num_patches) 109 | # bool_masked_pos[model.patch_embed.num_patches//2:] = 1 110 | # bool_masked_pos = bool_masked_pos.unsqueeze(dim=0) 111 | 112 | # run MAE 113 | loss, y = model(x.float().to(device), tgt.float().to(device)) 114 | y = model.unpatchify(y) 115 | y = torch.einsum('nchw->nhwc', y).detach().cpu() 116 | 117 | output = y[0, :, :, :] 118 | output = output * imagenet_std + imagenet_mean 119 | output = F.interpolate( 120 | output[None, ...].permute(0, 3, 1, 2), size=[size[1], size[0]], mode='bicubic').permute(0, 2, 3, 1)[0] 121 | 122 | return output.numpy() 123 | 124 | 125 | if __name__ == '__main__': 126 | args = get_args_parser() 127 | 128 | ckpt_dir = args.ckpt_dir 129 | model = args.model 130 | epoch = args.epoch 131 | prompt = args.prompt 132 | input_size = args.input_size 133 | prompt_type = 'target' if args.pred_gt else 'target_sub_input' 134 | 135 | ckpt_file = 'checkpoint-{}.pth'.format(epoch) 136 | assert ckpt_dir[-1] != "/" 137 | dst_dir = os.path.join('models_inference', ckpt_dir.split('/')[-1], 138 | "derain_inference_epoch{}_{}".format(epoch, os.path.basename(prompt).split(".")[0])) 139 | 140 | if os.path.exists(dst_dir): 141 | # raise Exception("{} exist! make sure to overwrite?".format(dst_dir)) 142 | warnings.warn("{} exist! 
make sure to overwrite?".format(dst_dir)) 143 | else: 144 | os.makedirs(dst_dir) 145 | print("output_dir: {}".format(dst_dir)) 146 | 147 | ckpt_path = os.path.join(ckpt_dir, ckpt_file) 148 | model_mae = prepare_model(ckpt_path, model) 149 | print('Model loaded.') 150 | 151 | device = torch.device("cuda") 152 | model_mae.to(device) 153 | 154 | 155 | model_mae.eval() 156 | # datasets = ['Rain100L', 'Rain100H', 'Test100', 'Test1200', 'Test2800'] 157 | if args.split == 1: 158 | datasets = ['Rain100L', 'Rain100H', 'Test100', 'Test1200'] 159 | elif args.split == 2: 160 | datasets = ['Test2800'] # this is bottleneck in time ~20min 161 | elif args.split == 3: 162 | datasets = ['Rain100L', 'Rain100H', 'Test100', 'Test1200', 'Test2800'] 163 | elif args.split == 4: 164 | datasets = ['Rain100H'] 165 | else: 166 | raise NotImplementedError(args.split) 167 | 168 | psnr_alldatasets = [] 169 | ssim_alldatasets = [] 170 | print(datasets) 171 | img_src_dir = "datasets/low_level/derain/test/" 172 | for dset in datasets: 173 | psnr_val_rgb = [] 174 | ssim_val_rgb = [] 175 | real_src_dir = os.path.join(img_src_dir, dset, 'input') 176 | real_dst_dir = os.path.join(dst_dir, dset) 177 | if not os.path.exists(real_dst_dir): 178 | os.makedirs(real_dst_dir) 179 | img_path_list = glob.glob(os.path.join(real_src_dir, "*.png")) + glob.glob(os.path.join(real_src_dir, "*.jpg")) 180 | for img_path in tqdm.tqdm(img_path_list): 181 | """ Load an image """ 182 | img_name = os.path.basename(img_path) 183 | out_path = os.path.join(real_dst_dir, img_name.replace('jpg', 'png')) # TODO: save all results as pngs 184 | img_org = Image.open(img_path).convert("RGB") 185 | 186 | 187 | size = img_org.size 188 | img = img_org.resize((input_size, input_size)) 189 | img = np.array(img) / 255. 190 | 191 | 192 | img = img - imagenet_mean 193 | img = img / imagenet_std 194 | 195 | 196 | prompt_org = np.load('datasets/low_level/derain.npy') 197 | # prompt_rand = np.load('datasets/low_level/ssid.npy') 198 | # alpha = 1 199 | # prompt_org = (1 - alpha) * prompt_org + alpha * prompt_rand 200 | 201 | # img = random_add_prompts_random_scales(img,prompt_org,prompt_range=[8,64],scale_range=[0.2,0.3]) 202 | # simple add 203 | img = img + prompt_org 204 | 205 | # load gt 206 | rgb_gt = Image.open(img_path.replace('input', 'target')).convert("RGB") # irrelevant to prompt-type 207 | 208 | tgt = rgb_gt.resize((input_size, input_size)) 209 | tgt = np.array(tgt) / 255. 210 | 211 | 212 | # normalize by ImageNet mean and std 213 | tgt = tgt - imagenet_mean 214 | tgt = tgt / imagenet_std 215 | 216 | """### Run MAE on the image""" 217 | # make random mask reproducible (comment out to make it change) 218 | torch.manual_seed(2) 219 | 220 | output = run_one_image(img, tgt, prompt_org, size, model_mae, out_path, device) 221 | 222 | rgb_restored = output 223 | 224 | rgb_restored = np.clip(rgb_restored, 0, 1) 225 | 226 | 227 | 228 | rgb_gt = np.array(rgb_gt) / 255. 
229 | 230 | psnr = calculate_psnr(rgb_restored*255., rgb_gt*255., 0, test_y_channel=True) 231 | ssim = calculate_ssim(rgb_restored*255., rgb_gt*255., 0, test_y_channel=True) 232 | # psnr = psnr_loss(rgb_restored, rgb_gt, data_range=1) 233 | # ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True, data_range=1) 234 | 235 | psnr_val_rgb.append(psnr) 236 | ssim_val_rgb.append(ssim) 237 | 238 | 239 | 240 | 241 | if args.save: 242 | # utils.save_img(out_path, img_as_ubyte(rgb_restored)) 243 | output = rgb_restored * 255 244 | output = Image.fromarray(output.astype(np.uint8)) 245 | output.save(out_path) 246 | 247 | psnr_val_rgb = sum(psnr_val_rgb) / len(img_path_list) 248 | ssim_val_rgb = sum(ssim_val_rgb) / len(img_path_list) 249 | psnr_alldatasets.append(psnr_val_rgb) 250 | ssim_alldatasets.append(ssim_val_rgb) 251 | print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb)) 252 | 253 | psnr_all = sum(psnr_alldatasets) / len(datasets) 254 | ssim_all = sum(ssim_alldatasets) / len(datasets) 255 | print("PSNR: %f, SSIM: %f " % (psnr_all, ssim_all)) -------------------------------------------------------------------------------- /ProRes/demo/ours_inference_lol_v2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import sys 4 | import os 5 | import warnings 6 | 7 | import requests 8 | import argparse 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | import numpy as np 13 | import glob 14 | import tqdm 15 | 16 | import matplotlib.pyplot as plt 17 | from PIL import Image 18 | 19 | sys.path.append('.') 20 | 21 | import models_ours 22 | import cv2 23 | 24 | import random 25 | from skimage.metrics import peak_signal_noise_ratio as psnr_loss 26 | from skimage.metrics import structural_similarity as ssim_loss 27 | from util.metrics import calculate_psnr, calculate_ssim 28 | 29 | 30 | imagenet_mean = np.array([0.485, 0.456, 0.406]) 31 | imagenet_std = np.array([0.229, 0.224, 0.225]) 32 | 33 | 34 | 35 | 36 | def get_args_parser(): 37 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 38 | parser.add_argument('--ckpt_dir', type=str, help='dir to ckpt', 39 | default='/sharefs/baaivision/xinlongwang/code/uip/models/' 40 | 'new3_all_lr5e-4') 41 | # 'new_ablation_bs2x32x4_enhance_gt_300ep_sl1_beta0.01_square896x448_fusefeat_mask0.75_merge3') 42 | parser.add_argument('--model', type=str, help='dir to ckpt', 43 | default='uip_vit_large_patch16_input896x448_win_dec64_8glb_sl1') 44 | parser.add_argument('--prompt', type=str, help='prompt image in train set', 45 | default='100') 46 | parser.add_argument('--epoch', type=int, help='model epochs', 47 | default=14) 48 | # default=150) 49 | parser.add_argument('--input_size', type=int, help='model epochs', 50 | default=448) 51 | parser.add_argument('--pred_gt', action='store_true', help='trained by using gt as gt', 52 | default=False) 53 | parser.add_argument('--save', action='store_true', help='save predictions', 54 | default=False) 55 | return parser.parse_args() 56 | 57 | 58 | def prepare_model(chkpt_dir, arch='mae_vit_base_patch16'): 59 | # build model 60 | model = getattr(models_ours, arch)() 61 | # load model 62 | checkpoint = torch.load(chkpt_dir, map_location='cuda:0') 63 | msg = model.load_state_dict(checkpoint['model'], strict=False) 64 | print(msg) 65 | return model 66 | 67 | def random_add_prompts_random_scales(image, prompt, prompt_range=[8,64], scale_range=[0.2,0.3]): 68 | image = image 69 | prompt = prompt 70 | h, w = image.shape[0],image.shape[1] 71 
| 72 | mask_image = np.ones((int(h),int(w),3)) 73 | mask_prompt = np.zeros((int(h),int(w),3)) 74 | 75 | ratio = 0 76 | 77 | while (scale_range[0] > ratio) == True or (ratio > scale_range[1])!=True: 78 | h_p = w_p = int(random.uniform(prompt_range[0], prompt_range[1])) 79 | point_h = int(random.uniform(h_p, h-h_p)) 80 | point_w = int(random.uniform(w_p, w-w_p)) 81 | 82 | mask_image[point_h:point_h+h_p,point_w:point_w+w_p] = 0.0 83 | mask_prompt[point_h:point_h+h_p,point_w:point_w+w_p] = 1.0 84 | prompts_token_num = np.sum(mask_prompt) 85 | ratio = prompts_token_num/(h*w) 86 | 87 | # image = image*mask_image 88 | # prompt = prompt*mask_prompt 89 | image = image + prompt 90 | 91 | return image 92 | 93 | def run_one_image(img, tgt, prompt_org, size, model, out_path, device): 94 | x = torch.tensor(img) 95 | # make it a batch-like 96 | x = x.unsqueeze(dim=0) 97 | x = torch.einsum('nhwc->nchw', x) 98 | 99 | tgt = torch.tensor(tgt) 100 | # make it a batch-like 101 | tgt = tgt.unsqueeze(dim=0) 102 | tgt = torch.einsum('nhwc->nchw', tgt) 103 | 104 | 105 | # prompt_org = torch.tensor(prompt_org) 106 | # # make it a batch-like 107 | # prompt_org = prompt_org.unsqueeze(dim=0) 108 | # prompt_org = torch.einsum('nhwc->nchw', prompt_org) 109 | 110 | # run MAE 111 | loss, y = model(x.float().to(device), tgt.float().to(device)) 112 | y = model.unpatchify(y) 113 | y = torch.einsum('nchw->nhwc', y).detach().cpu() 114 | 115 | output = y[0, :, :, :] 116 | output = output * imagenet_std + imagenet_mean 117 | output = F.interpolate( 118 | output[None, ...].permute(0, 3, 1, 2), size=[size[1], size[0]], mode='bicubic').permute(0, 2, 3, 1)[0] 119 | 120 | return output.numpy() 121 | 122 | 123 | # TODO: modified from impl. in git@github.com:swz30/MIRNet.git 124 | def myPSNR(tar_img, prd_img): 125 | imdff = np.clip(prd_img, 0, 1) - np.clip(tar_img, 0, 1) 126 | rmse = np.sqrt((imdff ** 2).mean()) 127 | ps = 20 * np.log10(1 / rmse) 128 | return ps 129 | 130 | 131 | if __name__ == '__main__': 132 | args = get_args_parser() 133 | 134 | ckpt_dir = args.ckpt_dir 135 | model = args.model 136 | epoch = args.epoch 137 | prompt = args.prompt 138 | input_size = args.input_size 139 | prompt_type = 'gt' if args.pred_gt else 'gt_sub_input' 140 | 141 | ckpt_file = 'checkpoint-{}.pth'.format(epoch) 142 | assert ckpt_dir[-1] != "/" 143 | dst_dir = os.path.join('models_inference', ckpt_dir.split('/')[-1], 144 | "lol_inference_epoch{}_{}".format(epoch, os.path.basename(prompt).split(".")[0])) 145 | 146 | if os.path.exists(dst_dir): 147 | # raise Exception("{} exist! make sure to overwrite?".format(dst_dir)) 148 | warnings.warn("{} exist! make sure to overwrite?".format(dst_dir)) 149 | else: 150 | os.makedirs(dst_dir) 151 | print("output_dir: {}".format(dst_dir)) 152 | 153 | ckpt_path = os.path.join(ckpt_dir, ckpt_file) 154 | model_mae = prepare_model(ckpt_path, model) 155 | print('Model loaded.') 156 | 157 | device = torch.device("cuda") 158 | model_mae.to(device) 159 | 160 | img_src_dir = "datasets/low_level/enhance/lol/eval15/input" 161 | img_path_list = glob.glob(os.path.join(img_src_dir, "*.png")) 162 | 163 | 164 | psnr_val_rgb = [] 165 | ssim_val_rgb = [] 166 | model_mae.eval() 167 | for img_path in tqdm.tqdm(img_path_list): 168 | """ Load an image """ 169 | img_name = os.path.basename(img_path) 170 | out_path = os.path.join(dst_dir, img_name) 171 | img_org = Image.open(img_path).convert("RGB") 172 | 173 | size = img_org.size 174 | img = img_org.resize((input_size, input_size)) 175 | img = np.array(img) / 255. 
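        # the input is scaled to [0, 1], standardized with the ImageNet mean/std
        # defined at the top of the script, and the low-light-enhancement prompt
        # (datasets/low_level/enhance.npy) is then added directly to the normalized
        # image ("simple add") before it is passed to the model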
176 | 177 | img = img - imagenet_mean 178 | img = img / imagenet_std 179 | 180 | prompt_org = np.load('datasets/low_level/enhance.npy') 181 | # prompt_rand = np.load('datasets/low_level/ssid.npy') 182 | # alpha = 0 183 | # prompt_org = alpha * prompt_rand + (1 - alpha) * prompt_org 184 | 185 | # img = random_add_prompts_random_scales(img,prompt_org,prompt_range=[8,64],scale_range=[0.6,0.8]) 186 | # simple add 187 | img = img + prompt_org 188 | 189 | 190 | # load gt 191 | rgb_gt = Image.open(img_path.replace('input', 'gt')).convert("RGB") # irrelevant to prompt-type 192 | tgt = rgb_gt.resize((input_size, input_size)) 193 | tgt = np.array(tgt) / 255. 194 | 195 | 196 | # normalize by ImageNet mean and std 197 | tgt = tgt - imagenet_mean 198 | tgt = tgt / imagenet_std 199 | 200 | """### Run MAE on the image""" 201 | # make random mask reproducible (comment out to make it change) 202 | torch.manual_seed(2) 203 | 204 | output = run_one_image(img, tgt, prompt_org, size, model_mae, out_path, device) 205 | 206 | rgb_restored = output 207 | 208 | rgb_restored = np.clip(rgb_restored, 0, 1) 209 | 210 | 211 | rgb_gt = np.array(rgb_gt) / 255. 212 | # psnr = myPSNR(rgb_restored, rgb_gt) 213 | # print(rgb_restored.shape, rgb_gt.max()) 214 | # exit() 215 | 216 | # rgb_restored = rgb_restored*255. 217 | # rgb_gt = rgb_gt*255. 218 | psnr = calculate_psnr(rgb_restored*255., rgb_gt*255., 0) 219 | ssim = calculate_ssim(rgb_restored*255., rgb_gt*255., 0, test_y_channel=False) 220 | # psnr = psnr_loss(rgb_restored, rgb_gt, data_range=1) 221 | # ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True, data_range=1) 222 | # print(rgb_restored.max()-rgb_restored.min(), rgb_gt.max()-rgb_gt.min()) 223 | psnr_val_rgb.append(psnr) 224 | ssim_val_rgb.append(ssim) 225 | # # print("PSNR:", psnr, ",", img_name, rgb_restored.shape) 226 | # print("PSNR:", psnr, ", SSIM:", ssim, img_name) 227 | 228 | if args.save: 229 | # utils.save_img(out_path, img_as_ubyte(rgb_restored)) 230 | output = rgb_restored*255. 
231 | output = Image.fromarray(output.astype(np.uint8)) 232 | output.save(out_path) 233 | 234 | # with open(os.path.join(dst_dir, 'psnr_ssim.txt'), 'a') as f: 235 | # # f.write(img_name+' ---->'+" PSNR: %.4f, SSIM: %.4f] " % (psnr, ssim)+'\n') 236 | # f.write(img_name+' ---->'+" PSNR: %.4f" % (psnr)+'\n') 237 | 238 | psnr_val_rgb = sum(psnr_val_rgb) / len(img_path_list) 239 | ssim_val_rgb = sum(ssim_val_rgb) / len(img_path_list) 240 | print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb)) 241 | # print("PSNR: %f" % (psnr_val_rgb)) 242 | # print(ssim_val_rgb) 243 | # with open(os.path.join(dst_dir, 'psnr_ssim.txt'), 'a') as f: 244 | # f.write("PSNR: %.4f, SSIM: %.4f] " % (psnr_val_rgb, ssim_val_rgb)+'\n') 245 | # # f.write("PSNR: %.4f" % (psnr_val_rgb)+'\n') 246 | -------------------------------------------------------------------------------- /ProRes/demo/ours_inference_sidd_v2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | 4 | import sys 5 | import os 6 | import warnings 7 | 8 | import requests 9 | import argparse 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | import numpy as np 14 | import glob 15 | import tqdm 16 | 17 | import matplotlib.pyplot as plt 18 | from PIL import Image 19 | import scipy.io as sio 20 | 21 | sys.path.append('.') 22 | 23 | import models_ours 24 | 25 | import random 26 | from skimage.metrics import peak_signal_noise_ratio as psnr_loss 27 | from skimage.metrics import structural_similarity as ssim_loss 28 | from util.metrics import calculate_psnr, calculate_ssim 29 | 30 | 31 | 32 | imagenet_mean = np.array([0.485, 0.456, 0.406]) 33 | imagenet_std = np.array([0.229, 0.224, 0.225]) 34 | 35 | 36 | 37 | 38 | 39 | def get_args_parser(): 40 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 41 | parser.add_argument('--ckpt_dir', type=str, help='dir to ckpt', 42 | default='/sharefs/baaivision/xinlongwang/code/uip/models/' 43 | 'new3_all_lr5e-4') 44 | # 'new_new_all_lr5e-4_newcocoins') 45 | parser.add_argument('--model', type=str, help='dir to ckpt', 46 | default='uip_vit_large_patch16_input896x448_win_dec64_8glb_sl1') 47 | parser.add_argument('--prompt', type=str, help='prompt image in train set', 48 | default='100') 49 | parser.add_argument('--epoch', type=int, help='model epochs', 50 | default=14) 51 | parser.add_argument('--input_size', type=int, help='model epochs', 52 | default=448) 53 | parser.add_argument('--pred_gt', action='store_true', help='trained by using gt as gt', 54 | default=False) 55 | parser.add_argument('--save', action='store_true', help='save predictions', 56 | default=False) 57 | return parser.parse_args() 58 | 59 | 60 | def prepare_model(chkpt_dir, arch='mae_vit_base_patch16'): 61 | # build model 62 | model = getattr(models_ours, arch)() 63 | # load model 64 | checkpoint = torch.load(chkpt_dir, map_location='cuda:0') 65 | msg = model.load_state_dict(checkpoint['model'], strict=False) 66 | print(msg) 67 | return model 68 | 69 | 70 | 71 | def random_add_prompts_random_scales(image, prompt, prompt_range=[8,64], scale_range=[0.2,0.3]): 72 | image = image 73 | prompt = prompt 74 | h, w = image.shape[0],image.shape[1] 75 | 76 | mask_image = np.ones((int(h),int(w),3)) 77 | mask_prompt = np.zeros((int(h),int(w),3)) 78 | 79 | ratio = 0 80 | 81 | while (scale_range[0] > ratio) == True or (ratio > scale_range[1])!=True: 82 | h_p = w_p = int(random.uniform(prompt_range[0], prompt_range[1])) 83 | point_h = int(random.uniform(h_p, h-h_p)) 84 | 
point_w = int(random.uniform(w_p, w-w_p)) 85 | 86 | mask_image[point_h:point_h+h_p,point_w:point_w+w_p] = 0.0 87 | mask_prompt[point_h:point_h+h_p,point_w:point_w+w_p] = 1.0 88 | prompts_token_num = np.sum(mask_prompt) 89 | ratio = prompts_token_num/(h*w) 90 | 91 | # image = image*mask_image 92 | # prompt = prompt*mask_prompt 93 | image = image + prompt 94 | 95 | return image 96 | 97 | def run_one_image(img, tgt, prompt_org, size, model, out_path, device): 98 | x = torch.tensor(img) 99 | # make it a batch-like 100 | x = x.unsqueeze(dim=0) 101 | x = torch.einsum('nhwc->nchw', x) 102 | 103 | tgt = torch.tensor(tgt) 104 | # make it a batch-like 105 | tgt = tgt.unsqueeze(dim=0) 106 | tgt = torch.einsum('nhwc->nchw', tgt) 107 | 108 | prompt_org = torch.tensor(prompt_org) 109 | # make it a batch-like 110 | prompt_org = prompt_org.unsqueeze(dim=0) 111 | prompt_org = torch.einsum('nhwc->nchw', prompt_org) 112 | 113 | 114 | # run MAE 115 | loss, y = model(x.float().to(device), tgt.float().to(device)) 116 | y = model.unpatchify(y) 117 | y = torch.einsum('nchw->nhwc', y).detach().cpu() 118 | 119 | output = y[0, :, :, :] 120 | output = output * imagenet_std + imagenet_mean 121 | output = F.interpolate( 122 | output[None, ...].permute(0, 3, 1, 2), size=[size[1], size[0]], mode='bicubic').permute(0, 2, 3, 1)[0] 123 | 124 | return output.numpy() 125 | 126 | 127 | if __name__ == '__main__': 128 | args = get_args_parser() 129 | 130 | ckpt_dir = args.ckpt_dir 131 | model = args.model 132 | epoch = args.epoch 133 | prompt = args.prompt 134 | input_size = args.input_size 135 | prompt_type = 'groundtruth' if args.pred_gt else 'gt_sub_input' 136 | 137 | ckpt_file = 'checkpoint-{}.pth'.format(epoch) 138 | assert ckpt_dir[-1] != "/" 139 | dst_dir = os.path.join('models_inference', ckpt_dir.split('/')[-1], 140 | "sidd_inference_epoch{}_{}".format(epoch, os.path.basename(prompt).split(".")[0])) 141 | 142 | if os.path.exists(dst_dir): 143 | # raise Exception("{} exist! make sure to overwrite?".format(dst_dir)) 144 | warnings.warn("{} exist! make sure to overwrite?".format(dst_dir)) 145 | else: 146 | os.makedirs(dst_dir) 147 | print("output_dir: {}".format(dst_dir)) 148 | 149 | ckpt_path = os.path.join(ckpt_dir, ckpt_file) 150 | model_mae = prepare_model(ckpt_path, model) 151 | print('Model loaded.') 152 | 153 | device = torch.device("cuda") 154 | model_mae.to(device) 155 | 156 | img_src_dir = "datasets/low_level/denoising/sidd/sidd_val_patch256/input" 157 | img_path_list = glob.glob(os.path.join(img_src_dir, "*.png")) 158 | 159 | 160 | 161 | model_mae.eval() 162 | psnr_val_rgb = [] 163 | ssim_val_rgb = [] 164 | for img_path in tqdm.tqdm(img_path_list): 165 | """ Load an image """ 166 | img_name = os.path.basename(img_path) 167 | out_path = os.path.join(dst_dir, img_name) 168 | img_org = Image.open(img_path).convert("RGB") 169 | 170 | size = img_org.size 171 | img = img_org.resize((input_size, input_size)) 172 | img = np.array(img) / 255. 173 | 174 | 175 | img = img - imagenet_mean 176 | img = img / imagenet_std 177 | 178 | 179 | prompt_org = np.load('datasets/low_level/ssid.npy') 180 | 181 | # img = random_add_prompts_random_scales(img,prompt_org,prompt_range=[8,64],scale_range=[0.6,0.8]) 182 | # # simple add 183 | img = img + prompt_org 184 | 185 | # load gt 186 | rgb_gt = Image.open(img_path.replace('input', 'groundtruth')).convert("RGB") # irrelevant to prompt-type 187 | 188 | tgt = rgb_gt.resize((input_size, input_size)) 189 | tgt = np.array(tgt) / 255. 
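Aside from the prompt file (ssid.npy instead of enhance.npy) and the ground-truth folder name, this script repeats the per-image recipe of the LOL script above; note that run_one_image tensorizes prompt_org but never feeds it to the model, since the prompt has already been added in pixel space. A condensed, hedged sketch of the shared recipe (the function name and signature below are illustrative, not part of the repo):

def restore_one(img_path, gt_path, prompt_path, model, device, input_size=448):
    # load and resize the degraded input and its reference
    img_org = Image.open(img_path).convert("RGB")
    size = img_org.size
    x = np.array(img_org.resize((input_size, input_size))) / 255.
    gt = np.array(Image.open(gt_path).convert("RGB").resize((input_size, input_size))) / 255.
    # ImageNet-normalize both, then add the task prompt to the input only
    prompt = np.load(prompt_path)
    x = (x - imagenet_mean) / imagenet_std + prompt
    gt = (gt - imagenet_mean) / imagenet_std
    # forward pass, un-normalization and resize back are handled by run_one_image
    out = run_one_image(x, gt, prompt, size, model, None, device)
    return np.clip(out, 0, 1)

With prompt_path='datasets/low_level/ssid.npy' and the 'groundtruth' folder this matches the loop body above; swapping the prompt file retargets the same model to another restoration task.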
190 | 191 | 192 | # normalize by ImageNet mean and std 193 | tgt = tgt - imagenet_mean 194 | tgt = tgt / imagenet_std 195 | 196 | """### Run MAE on the image""" 197 | # make random mask reproducible (comment out to make it change) 198 | torch.manual_seed(2) 199 | 200 | output = run_one_image(img, tgt, prompt_org, size, model_mae, out_path, device) 201 | 202 | rgb_restored = output 203 | 204 | rgb_restored = np.clip(rgb_restored, 0, 1) 205 | 206 | 207 | rgb_gt = np.array(rgb_gt) / 255. 208 | 209 | psnr = calculate_psnr(rgb_restored*255., rgb_gt*255., 0, test_y_channel=False) 210 | ssim = calculate_ssim(rgb_restored*255., rgb_gt*255., 0, test_y_channel=False) 211 | # psnr = psnr_loss(rgb_restored, rgb_gt) 212 | # ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True) 213 | psnr_val_rgb.append(psnr) 214 | ssim_val_rgb.append(ssim) 215 | # # print("PSNR:", psnr, ",", img_name, rgb_restored.shape) 216 | # print("PSNR:", psnr, ", SSIM:", ssim, img_name) 217 | 218 | if args.save: 219 | # utils.save_img(out_path, img_as_ubyte(rgb_restored)) 220 | output = rgb_restored * 255 221 | output = Image.fromarray(output.astype(np.uint8)) 222 | output.save(out_path) 223 | 224 | # with open(os.path.join(dst_dir, 'psnr_ssim.txt'), 'a') as f: 225 | # # f.write(img_name+' ---->'+" PSNR: %.4f, SSIM: %.4f] " % (psnr, ssim)+'\n') 226 | # f.write(img_name+' ---->'+" PSNR: %.4f" % (psnr)+'\n') 227 | 228 | psnr_val_rgb = sum(psnr_val_rgb) / len(img_path_list) 229 | ssim_val_rgb = sum(ssim_val_rgb) / len(img_path_list) 230 | print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb)) 231 | # # print("PSNR: %f" % (psnr_val_rgb)) 232 | # print(ckpt_path) 233 | # with open(os.path.join(dst_dir, 'psnr_ssim.txt'), 'a') as f: 234 | # f.write("PSNR: %.4f, SSIM: %.4f] " % (psnr_val_rgb, ssim_val_rgb)+'\n') 235 | # # f.write("PSNR: %.4f" % (psnr_val_rgb)+'\n') 236 | -------------------------------------------------------------------------------- /ProRes/demo/prompt_save.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import sys 4 | import os 5 | import warnings 6 | 7 | import requests 8 | import argparse 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | import numpy as np 13 | import glob 14 | import tqdm 15 | 16 | import matplotlib.pyplot as plt 17 | from PIL import Image 18 | 19 | sys.path.append('.') 20 | 21 | import models_unet 22 | import models_mirnetv2 23 | import models_mprnet 24 | 25 | import random 26 | from skimage.metrics import peak_signal_noise_ratio as psnr_loss 27 | from skimage.metrics import structural_similarity as ssim_loss 28 | from collections import OrderedDict 29 | import cv2 30 | 31 | imagenet_mean = np.array([0.485, 0.456, 0.406]) 32 | imagenet_std = np.array([0.229, 0.224, 0.225]) 33 | 34 | 35 | def get_args_parser(): 36 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 37 | parser.add_argument('--ckpt_dir', type=str, help='dir to ckpt', 38 | default='/sharefs/baaivision/xinlongwang/code/uip/models/' 39 | 'new3_all_lr5e-4') 40 | # 'new_ablation_bs2x32x4_enhance_gt_300ep_sl1_beta0.01_square896x448_fusefeat_mask0.75_merge3') 41 | parser.add_argument('--model', type=str, help='dir to ckpt', 42 | default='uip_vit_large_patch16_input896x448_win_dec64_8glb_sl1') 43 | parser.add_argument('--prompt', type=str, help='prompt image in train set', 44 | default='100') 45 | parser.add_argument('--epoch', type=int, help='model epochs', 46 | default=14) 47 | # default=150) 48 | 
parser.add_argument('--input_size', type=int, help='model epochs', 49 | default=448) 50 | parser.add_argument('--pred_gt', action='store_true', help='trained by using gt as gt', 51 | default=False) 52 | parser.add_argument('--save', action='store_true', help='save predictions', 53 | default=False) 54 | return parser.parse_args() 55 | 56 | 57 | def prepare_model(chkpt_dir, arch='mae_vit_base_patch16'): 58 | # build model 59 | model = getattr(models_mprnet, arch)() 60 | # load model 61 | checkpoint = torch.load(chkpt_dir, map_location='cuda:0') 62 | msg = model.load_state_dict(checkpoint['model'], strict=False) 63 | # print(msg) 64 | 65 | # state_dict = checkpoint["state_dict"] 66 | # model.load_state_dict(state_dict) 67 | return model 68 | 69 | 70 | 71 | 72 | def run_one_image(img, tgt, type_dict, size, model, out_path, device): 73 | x = torch.tensor(img) 74 | # make it a batch-like 75 | x = x.unsqueeze(dim=0) 76 | x = torch.einsum('nhwc->nchw', x) 77 | 78 | tgt = torch.tensor(tgt) 79 | # make it a batch-like 80 | tgt = tgt.unsqueeze(dim=0) 81 | tgt = torch.einsum('nhwc->nchw', tgt) 82 | 83 | 84 | 85 | # run MAE 86 | loss, y = model(x.float().to(device), tgt.float().to(device),type_dict.float().to(device)) 87 | y = torch.einsum('nchw->nhwc', y[0]).detach().cpu() 88 | 89 | output = y[0, :, :, :] 90 | output = output * imagenet_std + imagenet_mean 91 | output = F.interpolate( 92 | output[None, ...].permute(0, 3, 1, 2), size=[size[1], size[0]], mode='bicubic').permute(0, 2, 3, 1)[0] 93 | 94 | return output.numpy() 95 | 96 | 97 | # TODO: modified from impl. in git@github.com:swz30/MIRNet.git 98 | def myPSNR(tar_img, prd_img): 99 | imdff = np.clip(prd_img, 0, 1) - np.clip(tar_img, 0, 1) 100 | rmse = np.sqrt((imdff ** 2).mean()) 101 | ps = 20 * np.log10(1 / rmse) 102 | return ps 103 | 104 | 105 | if __name__ == '__main__': 106 | args = get_args_parser() 107 | 108 | ckpt_dir = args.ckpt_dir 109 | model = args.model 110 | epoch = args.epoch 111 | prompt = args.prompt 112 | input_size = args.input_size 113 | prompt_type = 'gt' if args.pred_gt else 'gt_sub_input' 114 | 115 | ckpt_file = 'checkpoint-{}.pth'.format(epoch) 116 | assert ckpt_dir[-1] != "/" 117 | dst_dir = os.path.join('models_inference', ckpt_dir.split('/')[-1], 118 | "enhance_inference_epoch{}_{}".format(epoch, os.path.basename(prompt).split(".")[0])) 119 | 120 | if os.path.exists(dst_dir): 121 | # raise Exception("{} exist! make sure to overwrite?".format(dst_dir)) 122 | warnings.warn("{} exist! 
make sure to overwrite?".format(dst_dir)) 123 | else: 124 | os.makedirs(dst_dir) 125 | print("output_dir: {}".format(dst_dir)) 126 | 127 | ckpt_path = os.path.join(ckpt_dir, ckpt_file) 128 | model_mae = prepare_model(ckpt_path, model) 129 | print('Model loaded.') 130 | 131 | device = torch.device("cuda") 132 | model_mae.to(device) 133 | 134 | prompt = model_mae.prompt 135 | prompt = torch.einsum('nchw->nhwc', prompt) 136 | 137 | output = prompt[3, :, :, :].detach().cpu().numpy() 138 | output_min, output_max = output.min(), output.max() 139 | output_normalized = (output - output_min) / (output_max - output_min)*255 140 | print(output_normalized.max(), output_normalized.min()) 141 | color = ('b', 'g', 'r') 142 | for i, col in enumerate(color): 143 | histr = cv2.calcHist([output_normalized], [i], None, [256], [0, 256]) 144 | plt.plot(histr, color=col) 145 | plt.xlim([0, 256]) 146 | plt.show() 147 | plt.savefig(os.path.join('./models_inference/', str(3) + '.png')) 148 | 149 | # # print(output.dtype, output.shape) 150 | # # np.save(os.path.join('./models_inference/', str(i) + '.npy'), output) 151 | # # print(output.max(), output.min()) 152 | # output = output * imagenet_std + imagenet_mean 153 | # print(output.max(), output.min()) 154 | # output = output * 255 155 | # output_min, output_max = output.min(), output.max() 156 | # output_normalized = (output - output_min) / (output_max - output_min) 157 | # # print(output_normalized.max(), output_normalized.min()) 158 | # output = Image.fromarray(output.astype(np.uint8)) 159 | # output.save(os.path.join('./models_inference/', str(i) + '.png')) 160 | 161 | 162 | # # img_src_dir = "datasets/low_level/denoising/sidd/sidd_val_patch256/input" 163 | # # type_dict = torch.tensor([0, 0, 1, 0]).unsqueeze(0).unsqueeze(0).cuda() 164 | 165 | # # img_src_dir = "datasets/low_level/deblur/test/GoPro/input" 166 | # # type_dict = torch.tensor([0, 0, 0, 1]).unsqueeze(0).unsqueeze(0).cuda() 167 | 168 | # # img_src_dir = "datasets/low_level/derain/test/Rain100L/input" 169 | # # type_dict = torch.tensor([1, 0, 0, 0]).unsqueeze(0).unsqueeze(0).cuda() 170 | 171 | # img_src_dir = "datasets/low_level/enhance/lol/eval15/input" 172 | # type_dict = torch.tensor([0, 1, 0, 0]).unsqueeze(0).unsqueeze(0).cuda() 173 | 174 | # # img_src_dir = "datasets/low_level/deblur/train/input_crop" 175 | 176 | # img_path_list = glob.glob(os.path.join(img_src_dir, "*.png")) 177 | 178 | 179 | 180 | # model_mae.eval() 181 | # for img_path in tqdm.tqdm(img_path_list): 182 | # """ Load an image """ 183 | # img_name = os.path.basename(img_path) 184 | # out_path = os.path.join(dst_dir, img_name) 185 | # img = Image.open(img_path).convert("RGB") 186 | 187 | # size = img.size 188 | # img = img.resize((input_size, input_size)) 189 | # img = np.array(img) / 255. 190 | 191 | # # load gt 192 | # # rgb_gt = Image.open(img_path.replace('input', 'groundtruth')).convert("RGB") 193 | # # rgb_gt = Image.open(img_path.replace('input', 'target')).convert("RGB") 194 | # rgb_gt = Image.open(img_path.replace('input', 'gt')).convert("RGB") 195 | 196 | 197 | # # irrelevant to prompt-type 198 | # rgb_gt = rgb_gt.resize((input_size, input_size)) 199 | # rgb_gt = np.array(rgb_gt) / 255. 
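The active block above only inspects the prompt at index 3. Assuming model_mae.prompt stores one learned prompt per task along its first dimension (as the indexing suggests), a hedged sketch for dumping every prompt as an image for inspection (file names are illustrative):

# Hedged sketch: save each learned task prompt as a min-max normalized PNG.
prompts = torch.einsum('nchw->nhwc', model_mae.prompt).detach().cpu().numpy()
for i, p in enumerate(prompts):
    p = (p - p.min()) / (p.max() - p.min() + 1e-8) * 255.0
    Image.fromarray(p.astype(np.uint8)).save(
        os.path.join('./models_inference/', 'prompt_{}.png'.format(i)))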
200 | 201 | 202 | # # normalize by ImageNet mean and std 203 | # img = img - imagenet_mean 204 | # img = img / imagenet_std 205 | 206 | # tgt = rgb_gt # tgt is not available 207 | # # normalize by ImageNet mean and std 208 | # tgt = tgt - imagenet_mean 209 | # tgt = tgt / imagenet_std 210 | 211 | # """### Run MAE on the image""" 212 | # # make random mask reproducible (comment out to make it change) 213 | # # torch.manual_seed(2) 214 | 215 | # output = run_one_image(img, tgt, type_dict, size, model_mae, out_path, device) 216 | 217 | # rgb_restored = output 218 | 219 | # rgb_restored = np.clip(rgb_restored, 0, 1) 220 | 221 | 222 | 223 | # if args.save: 224 | # # utils.save_img(out_path, img_as_ubyte(rgb_restored)) 225 | # output = rgb_restored * 255 226 | # output = Image.fromarray(output.astype(np.uint8)) 227 | # output.save(out_path) 228 | 229 | -------------------------------------------------------------------------------- /ProRes/engine_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import math 12 | import sys 13 | from typing import Iterable 14 | 15 | import torch 16 | 17 | import util.misc as misc 18 | import util.lr_sched as lr_sched 19 | 20 | import numpy as np 21 | import wandb 22 | 23 | 24 | def get_loss_scale_for_deepspeed(model): 25 | optimizer = model.optimizer 26 | loss_scale = None 27 | if hasattr(optimizer, 'loss_scale'): 28 | loss_scale = optimizer.loss_scale 29 | elif hasattr(optimizer, 'cur_scale'): 30 | loss_scale = optimizer.cur_scale 31 | return loss_scale, optimizer._global_grad_norm 32 | # return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale 33 | 34 | 35 | def train_one_epoch(model: torch.nn.Module, 36 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 37 | device: torch.device, epoch: int, loss_scaler, 38 | log_writer=None, 39 | global_rank=None, 40 | args=None): 41 | model.train(True) 42 | metric_logger = misc.MetricLogger(delimiter=" ") 43 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 44 | header = 'Epoch: [{}]'.format(epoch) 45 | print_freq = 20 46 | 47 | accum_iter = args.accum_iter 48 | 49 | optimizer.zero_grad() 50 | 51 | if log_writer is not None: 52 | print('log_dir: {}'.format(log_writer.log_dir)) 53 | 54 | wandb_images = [] 55 | for data_iter_step, (samples,targets, bool_masked_pos, valid, flag, image_ori) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 56 | # we use a per iteration (instead of per epoch) lr scheduler 57 | if data_iter_step % accum_iter == 0: 58 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 59 | 60 | 61 | samples= samples.to(device, non_blocking=True) 62 | targets = targets.to(device, non_blocking=True) 63 | bool_masked_pos = bool_masked_pos.to(device, non_blocking=True) 64 | valid = valid.to(device, non_blocking=True) 65 | #prompt = prompt.to(device, non_blocking=True) 66 | flag = flag.to(device, non_blocking=True) 67 | image_ori = image_ori.to(device, 
non_blocking=True) 68 | 69 | with torch.cuda.amp.autocast(): 70 | loss, y = model(samples, targets) 71 | #loss, y = model(samples, targets, prompt) 72 | loss_value = loss.item() 73 | 74 | if not math.isfinite(loss_value): 75 | print("Loss is {}, stopping training".format(loss_value)) 76 | sys.exit(1) 77 | 78 | if loss_scaler is None: 79 | loss /= accum_iter 80 | model.backward(loss) 81 | model.step() 82 | 83 | # if (data_iter_step + 1) % update_freq == 0: 84 | # model.zero_grad() 85 | # Deepspeed will call step() & model.zero_grad() automatic 86 | # grad_norm = None 87 | loss_scale_value, grad_norm = get_loss_scale_for_deepspeed(model) 88 | else: 89 | loss /= accum_iter 90 | grad_norm = loss_scaler(loss, optimizer, clip_grad=args.clip_grad, 91 | parameters=model.parameters(), 92 | update_grad=(data_iter_step + 1) % accum_iter == 0) 93 | if (data_iter_step + 1) % accum_iter == 0: 94 | optimizer.zero_grad() 95 | loss_scale_value = loss_scaler.state_dict()["scale"] 96 | 97 | torch.cuda.synchronize() 98 | 99 | metric_logger.update(loss=loss_value) 100 | 101 | lr = optimizer.param_groups[0]["lr"] 102 | metric_logger.update(lr=lr) 103 | 104 | metric_logger.update(loss_scale=loss_scale_value) 105 | metric_logger.update(grad_norm=grad_norm) 106 | 107 | loss_value_reduce = misc.all_reduce_mean(loss_value) 108 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 109 | """ We use epoch_1000x as the x-axis in tensorboard. 110 | This calibrates different curves when batch size changes. 111 | """ 112 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 113 | log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) 114 | log_writer.add_scalar('lr', lr, epoch_1000x) 115 | 116 | # gather the stats from all processes 117 | metric_logger.synchronize_between_processes() 118 | print("Averaged stats:", metric_logger) 119 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 120 | 121 | 122 | @torch.no_grad() 123 | def evaluate_pt(data_loader, model, device, epoch=None, global_rank=None, args=None): 124 | metric_logger = misc.MetricLogger(delimiter=" ") 125 | header = 'Test:' 126 | # switch to evaluation mode 127 | model.eval() 128 | wandb_images = [] 129 | for batch in metric_logger.log_every(data_loader, 10, header): 130 | samples = batch[0] 131 | targets = batch[1] 132 | bool_masked_pos = batch[2] 133 | valid = batch[3] 134 | #prompt = batch[4] 135 | flag = batch[4] 136 | image_ori = batch[5] 137 | 138 | 139 | 140 | samples = samples.to(device, non_blocking=True) 141 | targets = targets.to(device, non_blocking=True) 142 | bool_masked_pos = bool_masked_pos.to(device, non_blocking=True) 143 | valid = valid.to(device, non_blocking=True) 144 | #prompt = prompt.to(device, non_blocking=True) 145 | flag = flag.to(device, non_blocking=True) 146 | image_ori = image_ori.to(device, non_blocking=True) 147 | 148 | # compute output 149 | with torch.cuda.amp.autocast(): 150 | 151 | #bool_masked_pos[:, :bool_masked_pos.shape[1]//2] = 0 152 | #bool_masked_pos[:, bool_masked_pos.shape[1]//2:] = 1 153 | loss, y = model(samples, targets) 154 | #loss, y = model(samples, targets, prompt) 155 | 156 | metric_logger.update(loss=loss.item()) 157 | # if global_rank == 0 and args.log_wandb: 158 | # imagenet_mean = np.array([0.485, 0.456, 0.406]) 159 | # imagenet_std = np.array([0.229, 0.224, 0.225]) 160 | # y = y[[0]] 161 | # y = model.module.unpatchify(y) 162 | # y = torch.einsum('nchw->nhwc', y).detach().cpu() 163 | # mask = mask[[0]] 164 | # mask = 
mask.detach().float().cpu() 165 | # mask = mask.unsqueeze(-1).repeat(1, 1, model.module.patch_size**2 *3) # (N, H*W, p*p*3) 166 | # mask = model.module.unpatchify(mask) # 1 is removing, 0 is keeping 167 | # mask = torch.einsum('nchw->nhwc', mask).detach().cpu() 168 | # x = samples[[0]] 169 | # x = x.detach().float().cpu() 170 | # x = torch.einsum('nchw->nhwc', x) 171 | # tgt = targets[[0]] 172 | # tgt = tgt.detach().float().cpu() 173 | # tgt = torch.einsum('nchw->nhwc', tgt) 174 | # im_masked = tgt * (1 - mask) 175 | 176 | # frame = torch.cat((x, im_masked, y, tgt), dim=2) 177 | # frame = frame[0] 178 | # frame = torch.clip((frame * imagenet_std + imagenet_mean) * 255, 0, 255).int() 179 | # wandb_images.append(wandb.Image(frame.numpy(), caption="x; im_masked; y; tgt")) 180 | 181 | # gather the stats from all processes 182 | metric_logger.synchronize_between_processes() 183 | print('Val loss {losses.global_avg:.3f}'.format(losses=metric_logger.loss)) 184 | 185 | out = {k: meter.global_avg for k, meter in metric_logger.meters.items()} 186 | 187 | if global_rank == 0 and args.log_wandb: 188 | wandb.log({**{f'test_{k}': v for k, v in out.items()},'epoch': epoch}) 189 | if len(wandb_images) > 0: 190 | wandb.log({"Testing examples": wandb_images[::2][:20]}) 191 | return out 192 | -------------------------------------------------------------------------------- /ProRes/eval_ours.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | 3 | 4 | JOB_NAME="prores_vitl_pretrained_sl1_mprnetprompt_add" 5 | CKPT_DIR="models/${JOB_NAME}" 6 | EPOCH=49 7 | MODEL="uip_vit_large_patch16_input448x448_win_dec64_8glb_sl1" 8 | WORK_DIR="models_inference/${JOB_NAME}" 9 | 10 | CUDA_VISIBLE_DEVICES=4 python demo/ours_inference_sidd_v2.py \ 11 | --ckpt_dir ${CKPT_DIR} --model ${MODEL} \ 12 | --epoch ${EPOCH} --input_size 448 --pred_gt --save 13 | 14 | 15 | # CUDA_VISIBLE_DEVICES=0 python demo/ours_inference_lol_v2.py \ 16 | # --ckpt_dir ${CKPT_DIR} --model ${MODEL} \ 17 | # --epoch ${EPOCH} --input_size 448 --pred_gt --save 18 | 19 | 20 | # CUDA_VISIBLE_DEVICES=3 python demo/ours_inference_derain_v2.py \ 21 | # --ckpt_dir ${CKPT_DIR} --model ${MODEL} \ 22 | # --epoch ${EPOCH} --input_size 448 --pred_gt --save --split 3 23 | 24 | # CUDA_VISIBLE_DEVICES=3 python demo/ours_inference_deblur_v2.py \ 25 | # --ckpt_dir ${CKPT_DIR} --model ${MODEL} \ 26 | # --epoch ${EPOCH} --input_size 448 --pred_gt --save 27 | 28 | -------------------------------------------------------------------------------- /ProRes/masking_generator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0 3 | Copyright Zhun Zhong & Liang Zheng 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | 7 | Modified by Hangbo Bao, for generating the masked position for visual image transformer 8 | """ 9 | # -------------------------------------------------------- 10 | # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254) 11 | # Github source: https://github.com/microsoft/unilm/tree/master/beit 12 | # Copyright (c) 2021 Microsoft 13 | # Licensed under The MIT License [see LICENSE for details] 14 | # By Hangbo Bao 15 | # Based on timm, DINO and DeiT code bases 16 | # https://github.com/rwightman/pytorch-image-models/tree/master/timm 17 | # Originally inspired by impl at https://github.com/zhunzhong07/Random-Erasing, Apache 2.0 18 | # Copyright Zhun 
Zhong & Liang Zheng 19 | # 20 | # Hacked together by / Copyright 2020 Ross Wightman 21 | # 22 | # Modified by Hangbo Bao, for generating the masked position for visual image transformer 23 | # --------------------------------------------------------' 24 | import random 25 | import math 26 | import numpy as np 27 | 28 | 29 | class MaskingGenerator: 30 | def __init__( 31 | self, input_size, num_masking_patches, min_num_patches=0, max_num_patches=None, 32 | min_aspect=0.3, max_aspect=None): 33 | if not isinstance(input_size, tuple): 34 | input_size = (input_size,) * 2 35 | self.height, self.width = input_size 36 | 37 | self.num_patches = self.height * self.width 38 | self.num_masking_patches = num_masking_patches 39 | 40 | self.min_num_patches = min_num_patches 41 | self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches 42 | 43 | max_aspect = max_aspect or 1 / min_aspect 44 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 45 | # print(self.max_num_patches) 46 | def __repr__(self): 47 | repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( 48 | self.height, self.width, self.min_num_patches, self.max_num_patches, 49 | self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1]) 50 | return repr_str 51 | 52 | def get_shape(self): 53 | return self.height, self.width 54 | 55 | def _mask(self, mask, max_mask_patches): 56 | delta = 0 57 | for attempt in range(10): 58 | target_area = random.uniform(self.min_num_patches, max_mask_patches) 59 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 60 | h = int(round(math.sqrt(target_area * aspect_ratio))) 61 | w = int(round(math.sqrt(target_area / aspect_ratio))) 62 | if w < self.width and h < self.height: 63 | top = random.randint(0, self.height - h) 64 | left = random.randint(0, self.width - w) 65 | 66 | num_masked = mask[top: top + h, left: left + w].sum() 67 | # Overlap 68 | if 0 < h * w - num_masked <= max_mask_patches: 69 | for i in range(top, top + h): 70 | for j in range(left, left + w): 71 | if mask[i, j] == 0: 72 | mask[i, j] = 1 73 | delta += 1 74 | 75 | if delta > 0: 76 | break 77 | return delta 78 | 79 | def __call__(self): 80 | mask = np.zeros(shape=self.get_shape(), dtype=np.int32) 81 | mask_count = 0 82 | while mask_count < self.num_masking_patches: 83 | max_mask_patches = self.num_masking_patches - mask_count 84 | max_mask_patches = min(max_mask_patches, self.max_num_patches) 85 | 86 | delta = self._mask(mask, max_mask_patches) 87 | if delta == 0: 88 | break 89 | else: 90 | mask_count += delta 91 | 92 | # maintain a fix number {self.num_masking_patches} 93 | if mask_count > self.num_masking_patches: 94 | delta = mask_count - self.num_masking_patches 95 | mask_x, mask_y = mask.nonzero() 96 | to_vis = np.random.choice(mask_x.shape[0], delta, replace=False) 97 | mask[mask_x[to_vis], mask_y[to_vis]] = 0 98 | 99 | elif mask_count < self.num_masking_patches: 100 | delta = self.num_masking_patches - mask_count 101 | mask_x, mask_y = (mask == 0).nonzero() 102 | to_mask = np.random.choice(mask_x.shape[0], delta, replace=False) 103 | mask[mask_x[to_mask], mask_y[to_mask]] = 1 104 | 105 | assert mask.sum() == self.num_masking_patches, f"mask: {mask}, mask count {mask.sum()}" 106 | 107 | return mask 108 | 109 | 110 | if __name__ == '__main__': 111 | # import pdb 112 | 113 | generator = MaskingGenerator(input_size=28, num_masking_patches=0, min_num_patches=0, ) 114 | for i in range(2): 115 | mask = generator() 116 | # if mask.sum() != 118: 117 | # 
pdb.set_trace() 118 | print(mask.shape) 119 | print(mask.sum()) 120 | # exit() 121 | 122 | # 1 masked patch 123 | # 0 valid patch -------------------------------------------------------------------------------- /ProRes/scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # export CC=/cluster_home/custom_data/gcc/bin/gcc 3 | # export CXX=/cluster_home/custom_data/gcc/bin/g++ 4 | # export PATH=/cluster_home/custom_data/gcc/bin:$PATH 5 | # export LD_LIBRARY_PATH=/cluster_home/custom_data/gcc/lib:$LD_LIBRARY_PATH 6 | # export LD_LIBRARY_PATH=/cluster_home/custom_data/gcc/lib64:$LD_LIBRARY_PATH 7 | 8 | 9 | CONFIG=$1 10 | GPUS=$2 11 | NNODES=${NNODES:-1} 12 | NODE_RANK=${NODE_RANK:-0} 13 | PORT=${PORT:-29502} 14 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 15 | 16 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 17 | 18 | name=ProRes_ep50_lr1e-3 19 | python -m torch.distributed.launch \ 20 | --nnodes=1 \ 21 | --node_rank=$NODE_RANK \ 22 | --master_addr=$MASTER_ADDR \ 23 | --nproc_per_node=8 \ 24 | --master_port=$PORT \ 25 | --use_env main_pretrain.py \ 26 | --batch_size 16 \ 27 | --accum_iter 1 \ 28 | --model uip_vit_large_patch16_input448x448_win_dec64_8glb_sl1 \ 29 | --num_mask_patches 784 \ 30 | --max_mask_patches_per_block 0 \ 31 | --epochs 50 \ 32 | --warmup_epochs 2 \ 33 | --lr 1e-3 \ 34 | --weight_decay 0.05 \ 35 | --clip_grad 1.0 \ 36 | --opt_betas 0.9 0.999 \ 37 | --opt_eps 1e-8 \ 38 | --layer_decay 0.8 \ 39 | --drop_path 0.1 \ 40 | --min_random_scale 0.3 \ 41 | --input_size 448 448 \ 42 | --save_freq 1 \ 43 | --data_path ./ \ 44 | --json_path \ 45 | datasets/low_level/target-derain_train.json \ 46 | datasets/low_level/gt-enhance_lol_train.json \ 47 | datasets/low_level/groundtruth-denoise_ssid_train448.json \ 48 | datasets/low_level/groundtruth_crop-deblur_gopro_train.json \ 49 | --val_json_path \ 50 | datasets/low_level/target-derain_test_rain100h.json \ 51 | datasets/low_level/gt-enhance_lol_eval.json \ 52 | datasets/low_level/groundtruth-denoise_ssid_val256.json \ 53 | datasets/low_level/groundtruth-deblur_gopro_val.json \ 54 | --use_two_pairs \ 55 | --output_dir ./models/$name \ 56 | --log_dir ./models/$name/logs \ 57 | --finetune pretrained_weights/mae_pretrain_vit_large.pth \ 58 | 59 | # --log_wandb \ 60 | # --resume models/$name/checkpoint-2.pth \ 61 | # --seed 1000 \ 62 | -------------------------------------------------------------------------------- /ProRes/util/__pycache__/lr_decay.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/ProRes/util/__pycache__/lr_decay.cpython-37.pyc -------------------------------------------------------------------------------- /ProRes/util/__pycache__/lr_decay.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/ProRes/util/__pycache__/lr_decay.cpython-38.pyc -------------------------------------------------------------------------------- /ProRes/util/__pycache__/lr_sched.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/ProRes/util/__pycache__/lr_sched.cpython-37.pyc -------------------------------------------------------------------------------- 
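For reference, the launch in scripts/train.sh above trains with an effective batch size of 16 x 8 GPUs = 128 (with --accum_iter 1) and the linear-warmup plus half-cycle-cosine schedule implemented in util/lr_sched.py further below. A hedged sketch of the per-epoch learning rate implied by --lr 1e-3 --warmup_epochs 2 --epochs 50 (assuming the usual MAE default of min_lr = 0, which this dump does not show):

import math

def lr_at(epoch, lr=1e-3, min_lr=0.0, warmup_epochs=2, epochs=50):
    # mirrors util/lr_sched.adjust_learning_rate for a (possibly fractional) epoch value
    if epoch < warmup_epochs:
        return lr * epoch / warmup_epochs
    return min_lr + (lr - min_lr) * 0.5 * (
        1. + math.cos(math.pi * (epoch - warmup_epochs) / (epochs - warmup_epochs)))

print(lr_at(1.0), lr_at(2.0), lr_at(26.0), lr_at(49.0))  # warmup, peak, mid-decay, tail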
/ProRes/util/__pycache__/lr_sched.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/ProRes/util/__pycache__/lr_sched.cpython-38.pyc -------------------------------------------------------------------------------- /ProRes/util/__pycache__/misc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/ProRes/util/__pycache__/misc.cpython-37.pyc -------------------------------------------------------------------------------- /ProRes/util/__pycache__/misc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/ProRes/util/__pycache__/misc.cpython-38.pyc -------------------------------------------------------------------------------- /ProRes/util/__pycache__/pos_embed.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/ProRes/util/__pycache__/pos_embed.cpython-37.pyc -------------------------------------------------------------------------------- /ProRes/util/__pycache__/pos_embed.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/ProRes/util/__pycache__/pos_embed.cpython-38.pyc -------------------------------------------------------------------------------- /ProRes/util/__pycache__/vitdet_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/ProRes/util/__pycache__/vitdet_utils.cpython-37.pyc -------------------------------------------------------------------------------- /ProRes/util/__pycache__/vitdet_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/ProRes/util/__pycache__/vitdet_utils.cpython-38.pyc -------------------------------------------------------------------------------- /ProRes/util/crop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | import torch 10 | 11 | from torchvision import transforms 12 | from torchvision.transforms import functional as F 13 | 14 | 15 | class RandomResizedCrop(transforms.RandomResizedCrop): 16 | """ 17 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 18 | This may lead to results different with torchvision's version. 
19 | Following BYOL's TF code: 20 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 21 | """ 22 | @staticmethod 23 | def get_params(img, scale, ratio): 24 | width, height = F._get_image_size(img) 25 | area = height * width 26 | 27 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 28 | log_ratio = torch.log(torch.tensor(ratio)) 29 | aspect_ratio = torch.exp( 30 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 31 | ).item() 32 | 33 | w = int(round(math.sqrt(target_area * aspect_ratio))) 34 | h = int(round(math.sqrt(target_area / aspect_ratio))) 35 | 36 | w = min(w, width) 37 | h = min(h, height) 38 | 39 | i = torch.randint(0, height - h + 1, size=(1,)).item() 40 | j = torch.randint(0, width - w + 1, size=(1,)).item() 41 | 42 | return i, j, h, w -------------------------------------------------------------------------------- /ProRes/util/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | import PIL 13 | 14 | from torchvision import datasets, transforms 15 | 16 | from timm.data import create_transform 17 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 18 | 19 | 20 | def build_dataset(is_train, args): 21 | transform = build_transform(is_train, args) 22 | 23 | root = os.path.join(args.data_path, 'train' if is_train else 'val') 24 | dataset = datasets.ImageFolder(root, transform=transform) 25 | 26 | print(dataset) 27 | 28 | return dataset 29 | 30 | 31 | def build_transform(is_train, args): 32 | mean = IMAGENET_DEFAULT_MEAN 33 | std = IMAGENET_DEFAULT_STD 34 | # train transform 35 | if is_train: 36 | # this should always dispatch to transforms_imagenet_train 37 | transform = create_transform( 38 | input_size=args.input_size, 39 | is_training=True, 40 | color_jitter=args.color_jitter, 41 | auto_augment=args.aa, 42 | interpolation='bicubic', 43 | re_prob=args.reprob, 44 | re_mode=args.remode, 45 | re_count=args.recount, 46 | mean=mean, 47 | std=std, 48 | ) 49 | return transform 50 | 51 | # eval transform 52 | t = [] 53 | if args.input_size <= 224: 54 | crop_pct = 224 / 256 55 | else: 56 | crop_pct = 1.0 57 | size = int(args.input_size / crop_pct) 58 | t.append( 59 | transforms.Resize(size, interpolation=PIL.Image.BICUBIC), # to maintain same ratio w.r.t. 224 images 60 | ) 61 | t.append(transforms.CenterCrop(args.input_size)) 62 | 63 | t.append(transforms.ToTensor()) 64 | t.append(transforms.Normalize(mean, std)) 65 | return transforms.Compose(t) 66 | -------------------------------------------------------------------------------- /ProRes/util/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 17 | """ 18 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 19 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 20 | super().__init__(params, defaults) 21 | 22 | @torch.no_grad() 23 | def step(self): 24 | for g in self.param_groups: 25 | for p in g['params']: 26 | dp = p.grad 27 | 28 | if dp is None: 29 | continue 30 | 31 | if p.ndim > 1: # if not normalization gamma/beta or bias 32 | dp = dp.add(p, alpha=g['weight_decay']) 33 | param_norm = torch.norm(p) 34 | update_norm = torch.norm(dp) 35 | one = torch.ones_like(param_norm) 36 | q = torch.where(param_norm > 0., 37 | torch.where(update_norm > 0, 38 | (g['trust_coefficient'] * param_norm / update_norm), one), 39 | one) 40 | dp = dp.mul(q) 41 | 42 | param_state = self.state[p] 43 | if 'mu' not in param_state: 44 | param_state['mu'] = torch.zeros_like(p) 45 | mu = param_state['mu'] 46 | mu.mul_(g['momentum']).add_(dp) 47 | p.add_(mu, alpha=-g['lr']) -------------------------------------------------------------------------------- /ProRes/util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | 23 | num_layers = len(model.blocks) + 1 24 | 25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 26 | 27 | for n, p in model.named_parameters(): 28 | if not p.requires_grad: 29 | continue 30 | 31 | # no decay: all 1D parameters and model specific ones 32 | if p.ndim == 1 or n in no_weight_decay_list: 33 | g_decay = "no_decay" 34 | this_decay = 0. 
35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 60 | 61 | return list(param_groups.values()) 62 | 63 | 64 | def get_layer_id_for_vit(name, num_layers): 65 | """ 66 | Assign a parameter with its layer id 67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 68 | """ 69 | if name in ['cls_token', 'pos_embed']: 70 | return 0 71 | elif name.startswith('patch_embed'): 72 | return 0 73 | elif name.startswith('blocks'): 74 | return int(name.split('.')[1]) + 1 75 | else: 76 | return num_layers -------------------------------------------------------------------------------- /ProRes/util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | def adjust_learning_rate(optimizer, epoch, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if epoch < args.warmup_epochs: 12 | lr = args.lr * epoch / args.warmup_epochs 13 | else: 14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | -------------------------------------------------------------------------------- /ProRes/util/metrics.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import numpy as np 3 | import glob 4 | import tqdm 5 | 6 | import matplotlib.pyplot as plt 7 | from PIL import Image 8 | import cv2 9 | 10 | def _convert_input_type_range(img): 11 | """Convert the type and range of the input image. 12 | 13 | It converts the input image to np.float32 type and range of [0, 1]. 14 | It is mainly used for pre-processing the input image in colorspace 15 | convertion functions such as rgb2ycbcr and ycbcr2rgb. 16 | 17 | Args: 18 | img (ndarray): The input image. It accepts: 19 | 1. np.uint8 type with range [0, 255]; 20 | 2. np.float32 type with range [0, 1]. 21 | 22 | Returns: 23 | (ndarray): The converted image with type of np.float32 and range of 24 | [0, 1]. 25 | """ 26 | img_type = img.dtype 27 | img = img.astype(np.float32) 28 | if img_type == np.float32: 29 | pass 30 | elif img_type == np.uint8: 31 | img /= 255. 32 | else: 33 | raise TypeError('The img type should be np.float32 or np.uint8, ' 34 | f'but got {img_type}') 35 | return img 36 | 37 | 38 | def _convert_output_type_range(img, dst_type): 39 | """Convert the type and range of the image according to dst_type. 
40 | 41 | It converts the image to desired type and range. If `dst_type` is np.uint8, 42 | images will be converted to np.uint8 type with range [0, 255]. If 43 | `dst_type` is np.float32, it converts the image to np.float32 type with 44 | range [0, 1]. 45 | It is mainly used for post-processing images in colorspace convertion 46 | functions such as rgb2ycbcr and ycbcr2rgb. 47 | 48 | Args: 49 | img (ndarray): The image to be converted with np.float32 type and 50 | range [0, 255]. 51 | dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it 52 | converts the image to np.uint8 type with range [0, 255]. If 53 | dst_type is np.float32, it converts the image to np.float32 type 54 | with range [0, 1]. 55 | 56 | Returns: 57 | (ndarray): The converted image with desired type and range. 58 | """ 59 | if dst_type not in (np.uint8, np.float32): 60 | raise TypeError('The dst_type should be np.float32 or np.uint8, ' 61 | f'but got {dst_type}') 62 | if dst_type == np.uint8: 63 | img = img.round() 64 | else: 65 | img /= 255. 66 | return img.astype(dst_type) 67 | 68 | def bgr2ycbcr(img, y_only=False): 69 | """Convert a BGR image to YCbCr image. 70 | 71 | The bgr version of rgb2ycbcr. 72 | It implements the ITU-R BT.601 conversion for standard-definition 73 | television. See more details in 74 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. 75 | 76 | It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. 77 | In OpenCV, it implements a JPEG conversion. See more details in 78 | https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. 79 | 80 | Args: 81 | img (ndarray): The input image. It accepts: 82 | 1. np.uint8 type with range [0, 255]; 83 | 2. np.float32 type with range [0, 1]. 84 | y_only (bool): Whether to only return Y channel. Default: False. 85 | 86 | Returns: 87 | ndarray: The converted YCbCr image. The output image has the same type 88 | and range as input image. 89 | """ 90 | img_type = img.dtype 91 | img = _convert_input_type_range(img) 92 | if y_only: 93 | out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 94 | else: 95 | out_img = np.matmul( 96 | img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], 97 | [65.481, -37.797, 112.0]]) + [16, 128, 128] 98 | out_img = _convert_output_type_range(out_img, img_type) 99 | return out_img 100 | 101 | def reorder_image(img, input_order='HWC'): 102 | """Reorder images to 'HWC' order. 103 | 104 | If the input_order is (h, w), return (h, w, 1); 105 | If the input_order is (c, h, w), return (h, w, c); 106 | If the input_order is (h, w, c), return as it is. 107 | 108 | Args: 109 | img (ndarray): Input image. 110 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 111 | If the input image shape is (h, w), input_order will not have 112 | effects. Default: 'HWC'. 113 | 114 | Returns: 115 | ndarray: reordered image. 116 | """ 117 | 118 | if input_order not in ['HWC', 'CHW']: 119 | raise ValueError( 120 | f'Wrong input_order {input_order}. Supported input_orders are ' 121 | "'HWC' and 'CHW'") 122 | if len(img.shape) == 2: 123 | img = img[..., None] 124 | if input_order == 'CHW': 125 | img = img.transpose(1, 2, 0) 126 | return img 127 | 128 | 129 | def to_y_channel(img): 130 | """Change to Y channel of YCbCr. 131 | 132 | Args: 133 | img (ndarray): Images with range [0, 255]. 134 | 135 | Returns: 136 | (ndarray): Images with range [0, 255] (float type) without round. 137 | """ 138 | img = img.astype(np.float32) / 255. 
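A note on this helper: to_y_channel converts via bgr2ycbcr, so it assumes OpenCV's B, G, R channel order, while the inference scripts above pass PIL-derived RGB arrays; through that path a Y-channel metric is computed with the R and B weights swapped. A hedged BT.601 helper for RGB inputs (the name rgb2y is illustrative, not from the repo):

def rgb2y(img):
    # img: float RGB array with values in [0, 1], channel order R, G, B
    return np.dot(img, [65.481, 128.553, 24.966]) + 16.0  # Y in [16, 235]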
139 | if img.ndim == 3 and img.shape[2] == 3: 140 | img = bgr2ycbcr(img, y_only=True) 141 | img = img[..., None] 142 | return img * 255. 143 | 144 | def calculate_psnr(img1, 145 | img2, 146 | crop_border, 147 | input_order='HWC', 148 | test_y_channel=False): 149 | """Calculate PSNR (Peak Signal-to-Noise Ratio). 150 | 151 | Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio 152 | 153 | Args: 154 | img1 (ndarray): Images with range [0, 255]. 155 | img2 (ndarray): Images with range [0, 255]. 156 | crop_border (int): Cropped pixels in each edge of an image. These 157 | pixels are not involved in the PSNR calculation. 158 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 159 | Default: 'HWC'. 160 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 161 | 162 | Returns: 163 | float: psnr result. 164 | """ 165 | 166 | assert img1.shape == img2.shape, ( 167 | f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 168 | if input_order not in ['HWC', 'CHW']: 169 | raise ValueError( 170 | f'Wrong input_order {input_order}. Supported input_orders are ' 171 | '"HWC" and "CHW"') 172 | img1 = reorder_image(img1, input_order=input_order) 173 | img2 = reorder_image(img2, input_order=input_order) 174 | img1 = img1.astype(np.float64) 175 | img2 = img2.astype(np.float64) 176 | 177 | if crop_border != 0: 178 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] 179 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 180 | 181 | if test_y_channel: 182 | img1 = to_y_channel(img1) 183 | img2 = to_y_channel(img2) 184 | 185 | mse = np.mean((img1 - img2)**2) 186 | if mse == 0: 187 | return float('inf') 188 | return 20. * np.log10(255. / np.sqrt(mse)) 189 | 190 | 191 | def _ssim(img1, img2): 192 | """Calculate SSIM (structural similarity) for one channel images. 193 | 194 | It is called by func:`calculate_ssim`. 195 | 196 | Args: 197 | img1 (ndarray): Images with range [0, 255] with order 'HWC'. 198 | img2 (ndarray): Images with range [0, 255] with order 'HWC'. 199 | 200 | Returns: 201 | float: ssim result. 202 | """ 203 | 204 | C1 = (0.01 * 255)**2 205 | C2 = (0.03 * 255)**2 206 | 207 | img1 = img1.astype(np.float64) 208 | img2 = img2.astype(np.float64) 209 | kernel = cv2.getGaussianKernel(11, 1.5) 210 | window = np.outer(kernel, kernel.transpose()) 211 | 212 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] 213 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 214 | mu1_sq = mu1**2 215 | mu2_sq = mu2**2 216 | mu1_mu2 = mu1 * mu2 217 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 218 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 219 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 220 | 221 | ssim_map = ((2 * mu1_mu2 + C1) * 222 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 223 | (sigma1_sq + sigma2_sq + C2)) 224 | return ssim_map.mean() 225 | 226 | 227 | def calculate_ssim(img1, 228 | img2, 229 | crop_border, 230 | input_order='HWC', 231 | test_y_channel=False): 232 | """Calculate SSIM (structural similarity). 233 | 234 | Ref: 235 | Image quality assessment: From error visibility to structural similarity 236 | 237 | The results are the same as that of the official released MATLAB code in 238 | https://ece.uwaterloo.ca/~z70wang/research/ssim/. 239 | 240 | For three-channel images, SSIM is calculated for each channel and then 241 | averaged. 242 | 243 | Args: 244 | img1 (ndarray): Images with range [0, 255]. 245 | img2 (ndarray): Images with range [0, 255]. 
246 | crop_border (int): Cropped pixels in each edge of an image. These 247 | pixels are not involved in the SSIM calculation. 248 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 249 | Default: 'HWC'. 250 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 251 | 252 | Returns: 253 | float: SSIM result. 254 | """ 255 | 256 | assert img1.shape == img2.shape, ( 257 | f'Image shapes are different: {img1.shape}, {img2.shape}.') 258 | if input_order not in ['HWC', 'CHW']: 259 | raise ValueError( 260 | f'Wrong input_order {input_order}. Supported input_orders are ' 261 | '"HWC" and "CHW"') 262 | img1 = reorder_image(img1, input_order=input_order) 263 | img2 = reorder_image(img2, input_order=input_order) 264 | img1 = img1.astype(np.float64) 265 | img2 = img2.astype(np.float64) 266 | 267 | if crop_border != 0: 268 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] 269 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 270 | 271 | if test_y_channel: 272 | img1 = to_y_channel(img1) 273 | img2 = to_y_channel(img2) 274 | 275 | ssims = [] 276 | for i in range(img1.shape[2]): 277 | ssims.append(_ssim(img1[..., i], img2[..., i])) 278 | return np.array(ssims).mean() 279 | 280 | -------------------------------------------------------------------------------- /ProRes/util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | 
""" 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=np.float64) # np.float was removed in NumPy 1.24; float64 matches the old behavior 57 | omega /= embed_dim / 2. 58 | omega = 1. / 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | -------------------------------------------------------------------------------- /ProRes/util/vitdet_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | __all__ = [ 8 | "window_partition", 9 | "window_unpartition", 10 | "add_decomposed_rel_pos", 11 | "get_abs_pos", 12 | "PatchEmbed", 13 | ] 14 | 15 | 16 | def window_partition(x, window_size): 17 | """ 18 | Partition into non-overlapping windows with padding if needed. 19 | Args: 20 | x (tensor): input tokens with [B, H, W, C]. 21 | window_size (int): window size. 22 | 23 | Returns: 24 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
25 | (Hp, Wp): padded height and width before partition 26 | """ 27 | B, H, W, C = x.shape 28 | 29 | pad_h = (window_size - H % window_size) % window_size 30 | pad_w = (window_size - W % window_size) % window_size 31 | if pad_h > 0 or pad_w > 0: 32 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 33 | Hp, Wp = H + pad_h, W + pad_w 34 | 35 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 36 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 37 | return windows, (Hp, Wp) 38 | 39 | 40 | def window_unpartition(windows, window_size, pad_hw, hw): 41 | """ 42 | Window unpartition into original sequences and remove padding. 43 | Args: 44 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 45 | window_size (int): window size. 46 | pad_hw (Tuple): padded height and width (Hp, Wp). 47 | hw (Tuple): original height and width (H, W) before padding. 48 | 49 | Returns: 50 | x: unpartitioned sequences with [B, H, W, C]. 51 | """ 52 | Hp, Wp = pad_hw 53 | H, W = hw 54 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 55 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 56 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 57 | 58 | if Hp > H or Wp > W: 59 | x = x[:, :H, :W, :].contiguous() 60 | return x 61 | 62 | 63 | def get_rel_pos(q_size, k_size, rel_pos): 64 | """ 65 | Get relative positional embeddings according to the relative positions of 66 | query and key sizes. 67 | Args: 68 | q_size (int): size of query q. 69 | k_size (int): size of key k. 70 | rel_pos (Tensor): relative position embeddings (L, C). 71 | 72 | Returns: 73 | Extracted positional embeddings according to relative positions. 74 | """ 75 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 76 | # Interpolate rel pos if needed. 77 | if rel_pos.shape[0] != max_rel_dist: 78 | # Interpolate rel pos. 79 | rel_pos_resized = F.interpolate( 80 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 81 | size=max_rel_dist, 82 | mode="linear", 83 | ) 84 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 85 | else: 86 | rel_pos_resized = rel_pos 87 | 88 | # Scale the coords with short length if shapes for q and k are different. 89 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 90 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 91 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 92 | 93 | return rel_pos_resized[relative_coords.long()] 94 | 95 | 96 | def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size): 97 | """ 98 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 99 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 100 | Args: 101 | attn (Tensor): attention map. 102 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 103 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 104 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 105 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 106 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 107 | 108 | Returns: 109 | attn (Tensor): attention map with added relative positional embeddings. 
110 | """ 111 | q_h, q_w = q_size 112 | k_h, k_w = k_size 113 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 114 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 115 | 116 | B, _, dim = q.shape 117 | r_q = q.reshape(B, q_h, q_w, dim) 118 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 119 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 120 | 121 | attn = ( 122 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] 123 | ).view(B, q_h * q_w, k_h * k_w) 124 | 125 | return attn 126 | 127 | 128 | def get_abs_pos(abs_pos, has_cls_token, hw): 129 | """ 130 | Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token 131 | dimension for the original embeddings. 132 | Args: 133 | abs_pos (Tensor): absolute positional embeddings with (1, num_position, C). 134 | has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token. 135 | hw (Tuple): size of input image tokens. 136 | 137 | Returns: 138 | Absolute positional embeddings after processing with shape (1, H, W, C) 139 | """ 140 | h, w = hw 141 | if has_cls_token: 142 | abs_pos = abs_pos[:, 1:] 143 | xy_num = abs_pos.shape[1] 144 | size = int(math.sqrt(xy_num)) 145 | assert size * size == xy_num 146 | 147 | if size != h or size != w: 148 | new_abs_pos = F.interpolate( 149 | abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2), 150 | size=(h, w), 151 | mode="bicubic", 152 | align_corners=False, 153 | ) 154 | 155 | return new_abs_pos.permute(0, 2, 3, 1) 156 | else: 157 | return abs_pos.reshape(1, h, w, -1) 158 | 159 | 160 | class PatchEmbed(nn.Module): 161 | """ 162 | Image to Patch Embedding. 163 | """ 164 | 165 | def __init__( 166 | self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768 167 | ): 168 | """ 169 | Args: 170 | kernel_size (Tuple): kernel size of the projection layer. 171 | stride (Tuple): stride of the projection layer. 172 | padding (Tuple): padding size of the projection layer. 173 | in_chans (int): Number of input image channels. 174 | embed_dim (int): Patch embedding dimension. 175 | """ 176 | super().__init__() 177 | 178 | self.proj = nn.Conv2d( 179 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 180 | ) 181 | 182 | def forward(self, x): 183 | x = self.proj(x) 184 | # B C H W -> B H W C 185 | x = x.permute(0, 2, 3, 1) 186 | return x 187 | 188 | 189 | class LayerNorm2D(nn.Module): 190 | """ 191 | A LayerNorm variant, popularized by Transformers, that performs point-wise mean and 192 | variance normalization over the channel dimension for inputs that have shape 193 | (batch_size, channels, height, width). 
194 | https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950 195 | """ 196 | 197 | def __init__(self, normalized_shape, eps=1e-6): 198 | super().__init__() 199 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 200 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 201 | self.eps = eps 202 | self.normalized_shape = (normalized_shape,) 203 | 204 | def forward(self, x): 205 | u = x.mean(1, keepdim=True) 206 | s = (x - u).pow(2).mean(1, keepdim=True) 207 | x = (x - u) / torch.sqrt(s + self.eps) 208 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 209 | return x -------------------------------------------------------------------------------- /figures/S1_independent.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/figures/S1_independent.jpg -------------------------------------------------------------------------------- /figures/S2_irrelevant.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/figures/S2_irrelevant.jpg -------------------------------------------------------------------------------- /figures/S3_combine.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/figures/S3_combine.jpg -------------------------------------------------------------------------------- /figures/intro_figure.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/figures/intro_figure.jpg -------------------------------------------------------------------------------- /figures/main_figure.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/figures/main_figure.jpg -------------------------------------------------------------------------------- /figures/tuning_fivek.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/figures/tuning_fivek.jpg -------------------------------------------------------------------------------- /figures/tuning_reside.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/figures/tuning_reside.jpg -------------------------------------------------------------------------------- /static/css/bulma-carousel.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes 
spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform .3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:.5rem 1rem;text-align:center}.slider-pagination .slider-page{background:#fff;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel .slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10} 
-------------------------------------------------------------------------------- /static/css/bulma-slider.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}input[type=range].slider{-webkit-appearance:none;-moz-appearance:none;appearance:none;margin:1rem 0;background:0 0;touch-action:none}input[type=range].slider.is-fullwidth{display:block;width:100%}input[type=range].slider:focus{outline:0}input[type=range].slider:not([orient=vertical])::-webkit-slider-runnable-track{width:100%}input[type=range].slider:not([orient=vertical])::-moz-range-track{width:100%}input[type=range].slider:not([orient=vertical])::-ms-track{width:100%}input[type=range].slider:not([orient=vertical]).has-output+output,input[type=range].slider:not([orient=vertical]).has-output-tooltip+output{width:3rem;background:#4a4a4a;border-radius:4px;padding:.4rem .8rem;font-size:.75rem;line-height:.75rem;text-align:center;text-overflow:ellipsis;white-space:nowrap;color:#fff;overflow:hidden;pointer-events:none;z-index:200}input[type=range].slider:not([orient=vertical]).has-output-tooltip:disabled+output,input[type=range].slider:not([orient=vertical]).has-output:disabled+output{opacity:.5}input[type=range].slider:not([orient=vertical]).has-output{display:inline-block;vertical-align:middle;width:calc(100% - (4.2rem))}input[type=range].slider:not([orient=vertical]).has-output+output{display:inline-block;margin-left:.75rem;vertical-align:middle}input[type=range].slider:not([orient=vertical]).has-output-tooltip{display:block}input[type=range].slider:not([orient=vertical]).has-output-tooltip+output{position:absolute;left:0;top:-.1rem}input[type=range].slider[orient=vertical]{-webkit-appearance:slider-vertical;-moz-appearance:slider-vertical;appearance:slider-vertical;-webkit-writing-mode:bt-lr;-ms-writing-mode:bt-lr;writing-mode:bt-lr}input[type=range].slider[orient=vertical]::-webkit-slider-runnable-track{height:100%}input[type=range].slider[orient=vertical]::-moz-range-track{height:100%}input[type=range].slider[orient=vertical]::-ms-track{height:100%}input[type=range].slider::-webkit-slider-runnable-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-moz-range-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-ms-track{cursor:pointer;animate:.2s;box-shadow:0 0 0 #7a7a7a;background:#dbdbdb;border-radius:4px;border:0 solid #7a7a7a}input[type=range].slider::-ms-fill-lower{background:#dbdbdb;border-radius:4px}input[type=range].slider::-ms-fill-upper{background:#dbdbdb;border-radius:4px}input[type=range].slider::-webkit-slider-thumb{box-shadow:none;border:1px solid #b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-moz-range-thumb{box-shadow:none;border:1px solid #b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-ms-thumb{box-shadow:none;border:1px solid 
#b5b5b5;border-radius:4px;background:#fff;cursor:pointer}input[type=range].slider::-webkit-slider-thumb{-webkit-appearance:none;appearance:none}input[type=range].slider.is-circle::-webkit-slider-thumb{border-radius:290486px}input[type=range].slider.is-circle::-moz-range-thumb{border-radius:290486px}input[type=range].slider.is-circle::-ms-thumb{border-radius:290486px}input[type=range].slider:active::-webkit-slider-thumb{-webkit-transform:scale(1.25);transform:scale(1.25)}input[type=range].slider:active::-moz-range-thumb{transform:scale(1.25)}input[type=range].slider:active::-ms-thumb{transform:scale(1.25)}input[type=range].slider:disabled{opacity:.5;cursor:not-allowed}input[type=range].slider:disabled::-webkit-slider-thumb{cursor:not-allowed;-webkit-transform:scale(1);transform:scale(1)}input[type=range].slider:disabled::-moz-range-thumb{cursor:not-allowed;transform:scale(1)}input[type=range].slider:disabled::-ms-thumb{cursor:not-allowed;transform:scale(1)}input[type=range].slider:not([orient=vertical]){min-height:calc((1rem + 2px) * 1.25)}input[type=range].slider:not([orient=vertical])::-webkit-slider-runnable-track{height:.5rem}input[type=range].slider:not([orient=vertical])::-moz-range-track{height:.5rem}input[type=range].slider:not([orient=vertical])::-ms-track{height:.5rem}input[type=range].slider[orient=vertical]::-webkit-slider-runnable-track{width:.5rem}input[type=range].slider[orient=vertical]::-moz-range-track{width:.5rem}input[type=range].slider[orient=vertical]::-ms-track{width:.5rem}input[type=range].slider::-webkit-slider-thumb{height:1rem;width:1rem}input[type=range].slider::-moz-range-thumb{height:1rem;width:1rem}input[type=range].slider::-ms-thumb{height:1rem;width:1rem}input[type=range].slider::-ms-thumb{margin-top:0}input[type=range].slider::-webkit-slider-thumb{margin-top:-.25rem}input[type=range].slider[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.25rem}input[type=range].slider.is-small:not([orient=vertical]){min-height:calc((.75rem + 2px) * 1.25)}input[type=range].slider.is-small:not([orient=vertical])::-webkit-slider-runnable-track{height:.375rem}input[type=range].slider.is-small:not([orient=vertical])::-moz-range-track{height:.375rem}input[type=range].slider.is-small:not([orient=vertical])::-ms-track{height:.375rem}input[type=range].slider.is-small[orient=vertical]::-webkit-slider-runnable-track{width:.375rem}input[type=range].slider.is-small[orient=vertical]::-moz-range-track{width:.375rem}input[type=range].slider.is-small[orient=vertical]::-ms-track{width:.375rem}input[type=range].slider.is-small::-webkit-slider-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-moz-range-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-ms-thumb{height:.75rem;width:.75rem}input[type=range].slider.is-small::-ms-thumb{margin-top:0}input[type=range].slider.is-small::-webkit-slider-thumb{margin-top:-.1875rem}input[type=range].slider.is-small[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.1875rem}input[type=range].slider.is-medium:not([orient=vertical]){min-height:calc((1.25rem + 2px) * 
1.25)}input[type=range].slider.is-medium:not([orient=vertical])::-webkit-slider-runnable-track{height:.625rem}input[type=range].slider.is-medium:not([orient=vertical])::-moz-range-track{height:.625rem}input[type=range].slider.is-medium:not([orient=vertical])::-ms-track{height:.625rem}input[type=range].slider.is-medium[orient=vertical]::-webkit-slider-runnable-track{width:.625rem}input[type=range].slider.is-medium[orient=vertical]::-moz-range-track{width:.625rem}input[type=range].slider.is-medium[orient=vertical]::-ms-track{width:.625rem}input[type=range].slider.is-medium::-webkit-slider-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-moz-range-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-ms-thumb{height:1.25rem;width:1.25rem}input[type=range].slider.is-medium::-ms-thumb{margin-top:0}input[type=range].slider.is-medium::-webkit-slider-thumb{margin-top:-.3125rem}input[type=range].slider.is-medium[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.3125rem}input[type=range].slider.is-large:not([orient=vertical]){min-height:calc((1.5rem + 2px) * 1.25)}input[type=range].slider.is-large:not([orient=vertical])::-webkit-slider-runnable-track{height:.75rem}input[type=range].slider.is-large:not([orient=vertical])::-moz-range-track{height:.75rem}input[type=range].slider.is-large:not([orient=vertical])::-ms-track{height:.75rem}input[type=range].slider.is-large[orient=vertical]::-webkit-slider-runnable-track{width:.75rem}input[type=range].slider.is-large[orient=vertical]::-moz-range-track{width:.75rem}input[type=range].slider.is-large[orient=vertical]::-ms-track{width:.75rem}input[type=range].slider.is-large::-webkit-slider-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-moz-range-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-ms-thumb{height:1.5rem;width:1.5rem}input[type=range].slider.is-large::-ms-thumb{margin-top:0}input[type=range].slider.is-large::-webkit-slider-thumb{margin-top:-.375rem}input[type=range].slider.is-large[orient=vertical]::-webkit-slider-thumb{margin-top:auto;margin-left:-.375rem}input[type=range].slider.is-white::-moz-range-track{background:#fff!important}input[type=range].slider.is-white::-webkit-slider-runnable-track{background:#fff!important}input[type=range].slider.is-white::-ms-track{background:#fff!important}input[type=range].slider.is-white::-ms-fill-lower{background:#fff}input[type=range].slider.is-white::-ms-fill-upper{background:#fff}input[type=range].slider.is-white .has-output-tooltip+output,input[type=range].slider.is-white.has-output+output{background-color:#fff;color:#0a0a0a}input[type=range].slider.is-black::-moz-range-track{background:#0a0a0a!important}input[type=range].slider.is-black::-webkit-slider-runnable-track{background:#0a0a0a!important}input[type=range].slider.is-black::-ms-track{background:#0a0a0a!important}input[type=range].slider.is-black::-ms-fill-lower{background:#0a0a0a}input[type=range].slider.is-black::-ms-fill-upper{background:#0a0a0a}input[type=range].slider.is-black 
.has-output-tooltip+output,input[type=range].slider.is-black.has-output+output{background-color:#0a0a0a;color:#fff}input[type=range].slider.is-light::-moz-range-track{background:#f5f5f5!important}input[type=range].slider.is-light::-webkit-slider-runnable-track{background:#f5f5f5!important}input[type=range].slider.is-light::-ms-track{background:#f5f5f5!important}input[type=range].slider.is-light::-ms-fill-lower{background:#f5f5f5}input[type=range].slider.is-light::-ms-fill-upper{background:#f5f5f5}input[type=range].slider.is-light .has-output-tooltip+output,input[type=range].slider.is-light.has-output+output{background-color:#f5f5f5;color:#363636}input[type=range].slider.is-dark::-moz-range-track{background:#363636!important}input[type=range].slider.is-dark::-webkit-slider-runnable-track{background:#363636!important}input[type=range].slider.is-dark::-ms-track{background:#363636!important}input[type=range].slider.is-dark::-ms-fill-lower{background:#363636}input[type=range].slider.is-dark::-ms-fill-upper{background:#363636}input[type=range].slider.is-dark .has-output-tooltip+output,input[type=range].slider.is-dark.has-output+output{background-color:#363636;color:#f5f5f5}input[type=range].slider.is-primary::-moz-range-track{background:#00d1b2!important}input[type=range].slider.is-primary::-webkit-slider-runnable-track{background:#00d1b2!important}input[type=range].slider.is-primary::-ms-track{background:#00d1b2!important}input[type=range].slider.is-primary::-ms-fill-lower{background:#00d1b2}input[type=range].slider.is-primary::-ms-fill-upper{background:#00d1b2}input[type=range].slider.is-primary .has-output-tooltip+output,input[type=range].slider.is-primary.has-output+output{background-color:#00d1b2;color:#fff}input[type=range].slider.is-link::-moz-range-track{background:#3273dc!important}input[type=range].slider.is-link::-webkit-slider-runnable-track{background:#3273dc!important}input[type=range].slider.is-link::-ms-track{background:#3273dc!important}input[type=range].slider.is-link::-ms-fill-lower{background:#3273dc}input[type=range].slider.is-link::-ms-fill-upper{background:#3273dc}input[type=range].slider.is-link .has-output-tooltip+output,input[type=range].slider.is-link.has-output+output{background-color:#3273dc;color:#fff}input[type=range].slider.is-info::-moz-range-track{background:#209cee!important}input[type=range].slider.is-info::-webkit-slider-runnable-track{background:#209cee!important}input[type=range].slider.is-info::-ms-track{background:#209cee!important}input[type=range].slider.is-info::-ms-fill-lower{background:#209cee}input[type=range].slider.is-info::-ms-fill-upper{background:#209cee}input[type=range].slider.is-info .has-output-tooltip+output,input[type=range].slider.is-info.has-output+output{background-color:#209cee;color:#fff}input[type=range].slider.is-success::-moz-range-track{background:#23d160!important}input[type=range].slider.is-success::-webkit-slider-runnable-track{background:#23d160!important}input[type=range].slider.is-success::-ms-track{background:#23d160!important}input[type=range].slider.is-success::-ms-fill-lower{background:#23d160}input[type=range].slider.is-success::-ms-fill-upper{background:#23d160}input[type=range].slider.is-success 
.has-output-tooltip+output,input[type=range].slider.is-success.has-output+output{background-color:#23d160;color:#fff}input[type=range].slider.is-warning::-moz-range-track{background:#ffdd57!important}input[type=range].slider.is-warning::-webkit-slider-runnable-track{background:#ffdd57!important}input[type=range].slider.is-warning::-ms-track{background:#ffdd57!important}input[type=range].slider.is-warning::-ms-fill-lower{background:#ffdd57}input[type=range].slider.is-warning::-ms-fill-upper{background:#ffdd57}input[type=range].slider.is-warning .has-output-tooltip+output,input[type=range].slider.is-warning.has-output+output{background-color:#ffdd57;color:rgba(0,0,0,.7)}input[type=range].slider.is-danger::-moz-range-track{background:#ff3860!important}input[type=range].slider.is-danger::-webkit-slider-runnable-track{background:#ff3860!important}input[type=range].slider.is-danger::-ms-track{background:#ff3860!important}input[type=range].slider.is-danger::-ms-fill-lower{background:#ff3860}input[type=range].slider.is-danger::-ms-fill-upper{background:#ff3860}input[type=range].slider.is-danger .has-output-tooltip+output,input[type=range].slider.is-danger.has-output+output{background-color:#ff3860;color:#fff} -------------------------------------------------------------------------------- /static/css/index.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Noto Sans', sans-serif; 3 | } 4 | 5 | 6 | .footer .icon-link { 7 | font-size: 25px; 8 | color: #000; 9 | } 10 | 11 | .link-block a { 12 | margin-top: 5px; 13 | margin-bottom: 5px; 14 | } 15 | 16 | .dnerf { 17 | font-variant: small-caps; 18 | } 19 | 20 | 21 | .teaser .hero-body { 22 | padding-top: 0; 23 | padding-bottom: 3rem; 24 | } 25 | 26 | .teaser { 27 | font-family: 'Google Sans', sans-serif; 28 | } 29 | 30 | 31 | .publication-title { 32 | } 33 | 34 | .publication-banner { 35 | max-height: parent; 36 | 37 | } 38 | 39 | .publication-banner video { 40 | position: relative; 41 | left: auto; 42 | top: auto; 43 | transform: none; 44 | object-fit: fit; 45 | } 46 | 47 | .publication-header .hero-body { 48 | } 49 | 50 | .publication-title { 51 | font-family: 'Google Sans', sans-serif; 52 | } 53 | 54 | .publication-authors { 55 | font-family: 'Google Sans', sans-serif; 56 | } 57 | 58 | .publication-venue { 59 | color: #555; 60 | width: fit-content; 61 | font-weight: bold; 62 | } 63 | 64 | .publication-awards { 65 | color: #ff3860; 66 | width: fit-content; 67 | font-weight: bolder; 68 | } 69 | 70 | .publication-authors { 71 | } 72 | 73 | .publication-authors a { 74 | color: hsl(204, 86%, 53%) !important; 75 | } 76 | 77 | .publication-authors a:hover { 78 | text-decoration: underline; 79 | } 80 | 81 | .author-block { 82 | display: inline-block; 83 | } 84 | 85 | .publication-banner img { 86 | } 87 | 88 | .publication-authors { 89 | /*color: #4286f4;*/ 90 | } 91 | 92 | .publication-video { 93 | position: relative; 94 | width: 100%; 95 | height: 0; 96 | padding-bottom: 56.25%; 97 | 98 | overflow: hidden; 99 | border-radius: 10px !important; 100 | } 101 | 102 | .publication-video iframe { 103 | position: absolute; 104 | top: 0; 105 | left: 0; 106 | width: 100%; 107 | height: 100%; 108 | } 109 | 110 | .publication-body img { 111 | } 112 | 113 | .results-carousel { 114 | overflow: hidden; 115 | } 116 | 117 | .results-carousel .item { 118 | margin: 5px; 119 | overflow: hidden; 120 | padding: 20px; 121 | font-size: 0; 122 | } 123 | 124 | .results-carousel video { 125 | margin: 0; 126 | } 127 | 128 
| .slider-pagination .slider-page { 129 | background: #000000; 130 | } 131 | 132 | .eql-cntrb { 133 | font-size: smaller; 134 | } 135 | 136 | 137 | 138 | -------------------------------------------------------------------------------- /static/images/icon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonmakise/ProRes/8e028c95ad3019a09930d25279c4b996f92f723b/static/images/icon.jpg -------------------------------------------------------------------------------- /static/js/bulma-slider.min.js: -------------------------------------------------------------------------------- 1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default}); -------------------------------------------------------------------------------- /static/js/index.js: -------------------------------------------------------------------------------- 1 | window.HELP_IMPROVE_VIDEOJS = false; 2 | 3 | 4 | $(document).ready(function() { 5 | // Check for click events on the navbar burger icon 6 | 7 | var options = { 8 | slidesToScroll: 1, 9 | slidesToShow: 1, 10 | loop: true, 11 | infinite: true, 12 | autoplay: true, 13 | autoplaySpeed: 5000, 14 | } 15 | 16 | // Initialize all div with carousel class 17 | var carousels = bulmaCarousel.attach('.carousel', options); 18 | 19 | bulmaSlider.attach(); 20 | 21 | }) 22 | --------------------------------------------------------------------------------
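Usage sketch (illustrative, not a file tracked in the repository): the short Python snippet below shows one way to exercise the utilities reproduced above, namely calculate_psnr / calculate_ssim from ProRes/util/metrics.py and get_2d_sincos_pos_embed from ProRes/util/pos_embed.py. The sys.path manipulation, the random test images, and the 14x14 / 768-dim configuration are assumptions made for the example only, not values taken from the repository.

import sys
import numpy as np

sys.path.append('ProRes/util')  # assumption: the script is run from the repository root
from metrics import calculate_psnr, calculate_ssim  # defined in ProRes/util/metrics.py above
from pos_embed import get_2d_sincos_pos_embed  # defined in ProRes/util/pos_embed.py above

# Build a toy ground-truth / restored pair: values in [0, 255], HWC order, as the docstrings require.
rng = np.random.default_rng(0)
gt = rng.integers(0, 256, size=(64, 64, 3)).astype(np.float64)
restored = np.clip(gt + rng.normal(0.0, 5.0, size=gt.shape), 0, 255)

print('PSNR:', calculate_psnr(restored, gt, crop_border=0))  # in dB, higher is better
print('SSIM:', calculate_ssim(restored, gt, crop_border=0))  # in [0, 1], higher is better

# 2D sine-cosine position embedding for a hypothetical 14x14 patch grid with 768-dim tokens.
pos_embed = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14)
print(pos_embed.shape)  # (196, 768)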