├── README.md ├── data ├── NTIRE_Val │ └── test │ │ └── 000936 │ │ ├── alignratio.npy │ │ ├── metadata.npy │ │ ├── raw_1.npy │ │ ├── raw_16.npy │ │ ├── raw_256.npy │ │ ├── raw_4.npy │ │ ├── raw_64.npy │ │ ├── rgb_vis_1.png │ │ ├── rgb_vis_16.png │ │ ├── rgb_vis_256.png │ │ ├── rgb_vis_4.png │ │ └── rgb_vis_64.png ├── __init__.py ├── base_dataset.py ├── bracketire_dataset.py ├── bracketireplus_dataset.py └── degrade │ ├── degrade_kernel.py │ ├── process.py │ └── unprocess.py ├── imgs ├── Overview of CRNet.png ├── multi_pro.png ├── out1.png └── out2.png ├── isp ├── __init__.py ├── demosaic_bayer.py ├── dng_opcode.py ├── exif_data_formats.py ├── exif_utils.py ├── isp.py ├── model.bin ├── pipeline.py ├── pipeline_utils.py └── tone_curve.mat ├── mixer.yaml ├── models ├── __init__.py ├── base_model.py ├── blocks.py ├── cat_model.py ├── degrade │ ├── degrade_kernel.py │ ├── process.py │ └── unprocess.py ├── losses.py └── networks.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── spynet └── spynet_20210409-c6c1bd09.pth ├── test.py ├── test_track1.sh ├── train.py ├── train_track1.sh └── util ├── __init__.py ├── util.py └── visualizer.py /README.md: -------------------------------------------------------------------------------- 1 | # CRNet 2 | 3 | PyTorch implementation of CRNet. Our model achieved third place in track 1 of the Bracketing Image Restoration and Enhancement Challenge. 4 | ## 1. Abstract 5 | 6 | It is challenging but highly desired to acquire high-quality photos with clear content in low-light environments. Although multi-image processing methods (using burst, dual-exposure, or multi-exposure images) have made significant progress in addressing this issue, they typically focus exclusively on specific restoration or enhancement tasks and are insufficient in exploiting multi-image information. Motivated by the fact that multi-exposure images are complementary in denoising, deblurring, high dynamic range imaging, and super-resolution, we propose to utilize bracketing photography to unify restoration and enhancement tasks in this work. Due to the difficulty in collecting real-world pairs, we suggest a solution that first pre-trains the model with synthetic paired data and then adapts it to real-world unlabeled images. In particular, a temporally modulated recurrent network (TMRNet) and a self-supervised adaptation method are proposed. Moreover, we construct a data simulation pipeline to synthesize pairs and collect real-world images from 200 nighttime scenarios. Experiments on both datasets show that our method performs favorably against state-of-the-art multi-image processing methods. 7 | 8 | ## 2. Overview of CRNet 9 | 10 |
![Overview of CRNet](imgs/Overview%20of%20CRNet.png)
11 | 12 | ## 3. Comparison with other methods in track 1 of the Bracketing Image Restoration and Enhancement Challenge 13 | 14 |
![Comparison in track 1](imgs/multi_pro.png)
15 | 16 | ## 4. Example Results 17 | 18 |
![Example result 1](imgs/out1.png)
19 |
![Example result 2](imgs/out2.png)
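For quick inspection, here is a minimal sketch of loading the bundled validation scene, mirroring the loading logic in `data/bracketire_dataset.py` (it assumes you run from the repository root; the frame count and 10-bit normalization come from that dataset class):

```python
import numpy as np
import torch

root = 'data/NTIRE_Val/test/000936'  # sample scene shipped in this repo

# Five multi-exposure RAW frames with exposure ratios 4**0 .. 4**4,
# matching BracketIREDataset._read_raw_path().
raws = [np.load(f'{root}/raw_{4 ** e}.npy', allow_pickle=True).transpose(2, 0, 1)
        for e in range(5)]

# Normalize the 10-bit data to [0, 1], as BracketIREDataset does.
raws = torch.from_numpy(np.float32(np.array(raws))) / (2 ** 10 - 1)  # [T=5, 4, H, W]

# Per-scene camera metadata consumed by the ISP when rendering outputs.
meta = np.load(f'{root}/metadata.npy', allow_pickle=True).item()

print(raws.shape, list(meta.keys()))
```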
20 | 21 | ## 5. Checkpoint 22 | 23 | https://pan.baidu.com/s/17DDXbthjvLjyoHOnUzw-hg?pwd=g1nw 24 | 25 | ## 6. Dataset 26 | 27 | The download link for the dataset is https://github.com/cszhilu1998/BracketIRE. 28 | -------------------------------------------------------------------------------- /data/NTIRE_Val/test/000936/alignratio.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/alignratio.npy -------------------------------------------------------------------------------- /data/NTIRE_Val/test/000936/metadata.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/metadata.npy -------------------------------------------------------------------------------- /data/NTIRE_Val/test/000936/raw_1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/raw_1.npy -------------------------------------------------------------------------------- /data/NTIRE_Val/test/000936/raw_16.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/raw_16.npy -------------------------------------------------------------------------------- /data/NTIRE_Val/test/000936/raw_256.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/raw_256.npy -------------------------------------------------------------------------------- /data/NTIRE_Val/test/000936/raw_4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/raw_4.npy -------------------------------------------------------------------------------- /data/NTIRE_Val/test/000936/raw_64.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/raw_64.npy -------------------------------------------------------------------------------- /data/NTIRE_Val/test/000936/rgb_vis_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/rgb_vis_1.png -------------------------------------------------------------------------------- /data/NTIRE_Val/test/000936/rgb_vis_16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/rgb_vis_16.png -------------------------------------------------------------------------------- /data/NTIRE_Val/test/000936/rgb_vis_256.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/rgb_vis_256.png -------------------------------------------------------------------------------- /data/NTIRE_Val/test/000936/rgb_vis_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/rgb_vis_4.png -------------------------------------------------------------------------------- /data/NTIRE_Val/test/000936/rgb_vis_64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/rgb_vis_64.png -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch.utils.data 3 | from data.base_dataset import BaseDataset 4 | 5 | 6 | def find_dataset_using_name(dataset_name, split='train'): 7 | dataset_filename = "data." + dataset_name + "_dataset" 8 | datasetlib = importlib.import_module(dataset_filename) 9 | 10 | dataset = None 11 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 12 | for name, cls in datasetlib.__dict__.items(): 13 | if name.lower() == target_dataset_name.lower() \ 14 | and issubclass(cls, BaseDataset): 15 | dataset = cls 16 | 17 | if dataset is None: 18 | raise NotImplementedError("In %s.py, there should be a subclass of " 19 | "BaseDataset with class name that matches %s in " 20 | "lowercase." % (dataset_filename, target_dataset_name)) 21 | return dataset 22 | 23 | 24 | def create_dataset(dataset_name, split, opt): 25 | data_loader = CustomDatasetDataLoader(dataset_name, split, opt) 26 | dataset = data_loader.load_data() 27 | return dataset 28 | 29 | 30 | class CustomDatasetDataLoader(): 31 | def __init__(self, dataset_name, split, opt): 32 | self.opt = opt 33 | dataset_class = find_dataset_using_name(dataset_name, split) 34 | self.dataset = dataset_class(opt, split, dataset_name) 35 | # self.imio = self.dataset.imio 36 | print("dataset [%s(%s)] created" % (dataset_name, split)) 37 | self.dataloader = torch.utils.data.DataLoader( 38 | self.dataset, 39 | batch_size=opt.batch_size if split=='train' else 1, 40 | shuffle=opt.shuffle and split=='train', 41 | num_workers=int(opt.num_dataloader), 42 | drop_last=opt.drop_last) 43 | 44 | def load_data(self): 45 | return self 46 | 47 | def __len__(self): 48 | """Return the number of data in the dataset""" 49 | return min(len(self.dataset), self.opt.max_dataset_size) 50 | 51 | def __iter__(self): 52 | """Return a batch of data""" 53 | for i, data in enumerate(self.dataloader): 54 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 55 | break 56 | yield data 57 | 58 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from abc import ABC, abstractmethod 3 | 4 | 5 | class BaseDataset(data.Dataset, ABC): 6 | def __init__(self, opt, split, dataset_name): 7 | self.opt = opt 8 | self.split = split 9 | self.root = opt.dataroot 10 | self.dataset_name = dataset_name.lower() 11 | 12 | @abstractmethod 13 | def __len__(self): 14 | return 0 15 | 16 | @abstractmethod 17 | def __getitem__(self, 
index): 18 | pass 19 | 20 | -------------------------------------------------------------------------------- /data/bracketire_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import cv2 4 | import torch 5 | import random 6 | from tqdm import tqdm 7 | from os.path import join as opj 8 | from multiprocessing.dummy import Pool 9 | from data.base_dataset import BaseDataset 10 | 11 | 12 | # BracketIRE dataset 13 | class BracketIREDataset(BaseDataset): 14 | def __init__(self, opt, split='train', dataset_name='BracketIRE'): 15 | super(BracketIREDataset, self).__init__(opt, split, dataset_name) 16 | 17 | self.batch_size = opt.batch_size 18 | self.patch_size = opt.patch_size 19 | self.frame_num = opt.frame_num 20 | 21 | if split == 'train': 22 | self._getitem = self._getitem_train 23 | self.names, self.meta_dirs, self.raw_dirs, self.gt_dirs = self._get_image_dir(self.root, split, 24 | name='Train') 25 | self.len_data = 50000 * self.batch_size 26 | elif split == 'test': 27 | self._getitem = self._getitem_test 28 | self.names, self.meta_dirs, self.raw_dirs, self.gt_dirs = self._get_image_dir(self.root, split, 29 | name='NTIRE_Val') 30 | self.len_data = len(self.names) 31 | self.meta_data = [0] * len(self.names) 32 | self.raw_images = [0] * len(self.names) 33 | self.gt_images = [0] * len(self.names) 34 | read_images(self) 35 | else: 36 | raise ValueError 37 | 38 | self.split = split 39 | 40 | 41 | def __getitem__(self, index): 42 | return self._getitem(index) 43 | 44 | def __len__(self): 45 | return self.len_data 46 | 47 | def _getitem_train(self, idx): 48 | idx = idx % len(self.names) 49 | 50 | meta_data = np.load(self.meta_dirs[idx], allow_pickle=True) 51 | gt_images = np.load(self.gt_dirs[idx], allow_pickle=True).transpose(2, 0, 1) 52 | imgs = [] 53 | for m in range(5): 54 | imgs.append(np.load(self.raw_dirs[idx][m], allow_pickle=True).transpose(2, 0, 1)) 55 | raw_images = imgs 56 | 57 | raws = torch.from_numpy(np.float32(np.array(raw_images))) / (2 ** 10 - 1) 58 | gt = torch.from_numpy(np.float32(gt_images)) 59 | 60 | raws, gt = self._crop_patch(raws, gt, self.patch_size) 61 | 62 | return {'gt': gt, # [4, H, W] 63 | 'raws': raws, # [T=5, 4, H, W] 64 | 'fname': self.names[idx]} 65 | 66 | def _getitem_test(self, idx): 67 | raws = torch.from_numpy(np.float32(np.array(self.raw_images[idx]))) / (2 ** 10 - 1) 68 | meta = self._process_metadata(self.meta_data[idx]) 69 | 70 | return {'meta': meta, 71 | 'gt': raws[0], 72 | 'raws': raws, 73 | 'fname': self.names[idx]} 74 | 75 | def _crop_patch(self, raws, gt, p): 76 | ih, iw = raws.shape[-2:] 77 | ph = random.randrange(10, ih - p + 1 - 10) 78 | pw = random.randrange(10, iw - p + 1 - 10) 79 | return raws[..., ph:ph + p, pw:pw + p], \ 80 | gt[..., ph:ph + p, pw:pw + p] 81 | 82 | def _process_metadata(self, metadata): 83 | metadata_item = metadata.item() 84 | meta = {} 85 | for key in metadata_item: 86 | meta[key] = torch.from_numpy(metadata_item[key]) 87 | return meta 88 | 89 | def _read_raw_path(self, root): 90 | img_paths = [] 91 | for expo in range(self.frame_num): 92 | img_paths.append(opj(root, 'raw_' + str(4 ** expo) + '.npy')) 93 | return img_paths 94 | 95 | def _get_image_dir(self, dataroot, split=None, name=None): 96 | image_names = [] 97 | meta_dirs = [] 98 | raw_dirs = [] 99 | gt_dirs = [] 100 | 101 | for scene_file in sorted(os.listdir(opj(dataroot, name))): 102 | for image_file in sorted(os.listdir(opj(dataroot, name, scene_file))): 103 | image_root = 
opj(dataroot, name, scene_file, image_file) 104 | image_names.append(scene_file + '-' + image_file) 105 | meta_dirs.append(opj(image_root, 'metadata.npy')) 106 | raw_dirs.append(self._read_raw_path(image_root)) 107 | if split == 'train': 108 | gt_dirs.append(opj(image_root, 'raw_gt.npy')) 109 | elif split == 'test': 110 | gt_dirs = [] 111 | 112 | return image_names, meta_dirs, raw_dirs, gt_dirs 113 | 114 | 115 | def iter_obj(num, objs): 116 | for i in range(num): 117 | yield (i, objs) 118 | 119 | 120 | def imreader(arg): 121 | i, obj = arg 122 | for _ in range(3): 123 | try: 124 | imgs = [] 125 | for m in range(obj.frame_num): 126 | imgs.append(np.load(obj.raw_dirs[i][m], allow_pickle=True).transpose(2, 0, 1)) 127 | obj.raw_images[i] = imgs 128 | if obj.split == 'train': 129 | obj.gt_images[i] = np.load(obj.gt_dirs[i], allow_pickle=True).transpose(2, 0, 1) 130 | obj.meta_data[i] = np.load(obj.meta_dirs[i], allow_pickle=True) 131 | failed = False 132 | break 133 | except: 134 | failed = True 135 | if failed: print('%s fails!' % obj.names[i]) 136 | 137 | 138 | def read_images(obj): 139 | # may use `from multiprocessing import Pool` instead, but less efficient and 140 | # NOTE: `multiprocessing.Pool` will duplicate given object for each process. 141 | print('Starting to load images via multiple imreaders') 142 | pool = Pool() # use all threads by default 143 | for _ in tqdm(pool.imap(imreader, iter_obj(len(obj.names), obj)), total=len(obj.names)): 144 | pass 145 | pool.close() 146 | pool.join() 147 | 148 | 149 | if __name__ == '__main__': 150 | pass 151 | -------------------------------------------------------------------------------- /data/bracketireplus_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import cv2 4 | import torch 5 | import random 6 | from tqdm import tqdm 7 | from os.path import join as opj 8 | from multiprocessing.dummy import Pool 9 | from data.base_dataset import BaseDataset 10 | 11 | 12 | # BracketIRE+ dataset 13 | class BracketIREPlusDataset(BaseDataset): 14 | def __init__(self, opt, split='train', dataset_name='BracketIREPlus'): 15 | super(BracketIREPlusDataset, self).__init__(opt, split, dataset_name) 16 | 17 | self.batch_size = opt.batch_size 18 | self.patch_size = opt.patch_size 19 | self.frame_num = opt.frame_num 20 | self.scale = 4 21 | 22 | if split == 'train': 23 | self._getitem = self._getitem_train 24 | self.names, self.meta_dirs, self.raw_dirs, self.gt_dirs = self._get_image_dir(self.root, split, name='Train') 25 | self.len_data = 500 * self.batch_size 26 | elif split == 'test': 27 | self._getitem = self._getitem_test 28 | self.names, self.meta_dirs, self.raw_dirs, self.gt_dirs = self._get_image_dir(self.root, split, name='NTIRE_Val') 29 | self.len_data = len(self.names) 30 | else: 31 | raise ValueError 32 | 33 | self.meta_data = [0] * len(self.names) 34 | self.raw_images = [0] * len(self.names) 35 | self.gt_images = [0] * len(self.names) 36 | read_images(self) 37 | 38 | def __getitem__(self, index): 39 | return self._getitem(index) 40 | 41 | def __len__(self): 42 | return self.len_data 43 | 44 | def _getitem_train(self, idx): 45 | idx = idx % len(self.names) 46 | 47 | raws = torch.from_numpy(np.float32(np.array(self.raw_images[idx]))) / (2**10 - 1) 48 | gt = torch.from_numpy(np.float32(self.gt_images[idx])) 49 | 50 | raws, gt = self._crop_patch(raws, gt, self.patch_size, self.scale) 51 | 52 | return {'gt': gt, # [4, H, W] 53 | 'raws': raws, # [T=5, 4, H, W] 54 | 'fname': 
self.names[idx]} 55 | 56 | def _getitem_test(self, idx): 57 | raws = torch.from_numpy(np.float32(np.array(self.raw_images[idx]))) / (2**10 - 1) 58 | meta = self._process_metadata(self.meta_data[idx]) 59 | 60 | return {'meta': meta, 61 | 'gt': raws[0], 62 | 'raws': raws, 63 | 'fname': self.names[idx]} 64 | 65 | def _crop_patch(self, raws, gt, p, s): 66 | ih, iw = raws.shape[-2:] 67 | ph = random.randrange(4, ih - p + 1 - 4) 68 | pw = random.randrange(4, iw - p + 1 - 4) 69 | return raws[..., ph:ph+p, pw:pw+p], \ 70 | gt[..., ph*s:(ph+p)*s, pw*s:(pw+p)*s] 71 | 72 | def _process_metadata(self, metadata): 73 | metadata_item = metadata.item() 74 | meta = {} 75 | for key in metadata_item: 76 | meta[key] = torch.from_numpy(metadata_item[key]) 77 | return meta 78 | 79 | def _read_raw_path(self, root): 80 | img_paths = [] 81 | for expo in range(self.frame_num): 82 | img_paths.append(opj(root, 'x'+str(self.scale), 'raw_' + str(4**expo) + '.npy')) 83 | return img_paths 84 | 85 | def _get_image_dir(self, dataroot, split=None, name=None): 86 | image_names = [] 87 | meta_dirs = [] 88 | raw_dirs = [] 89 | gt_dirs = [] 90 | 91 | for scene_file in sorted(os.listdir(opj(dataroot, name))): 92 | for image_file in sorted(os.listdir(opj(dataroot, name, scene_file))): 93 | image_root = opj(dataroot, name, scene_file, image_file) 94 | image_names.append(scene_file + '-' + image_file) 95 | meta_dirs.append(opj(image_root, 'metadata.npy')) 96 | raw_dirs.append(self._read_raw_path(image_root)) 97 | if split == 'train': 98 | gt_dirs.append(opj(image_root, 'raw_gt.npy')) 99 | elif split == 'test': 100 | gt_dirs = [] 101 | 102 | return image_names, meta_dirs, raw_dirs, gt_dirs 103 | 104 | 105 | def iter_obj(num, objs): 106 | for i in range(num): 107 | yield (i, objs) 108 | 109 | def imreader(arg): 110 | i, obj = arg 111 | for _ in range(3): 112 | try: 113 | imgs = [] 114 | for m in range(obj.frame_num): 115 | imgs.append(np.load(obj.raw_dirs[i][m], allow_pickle=True).transpose(2, 0, 1)) 116 | obj.raw_images[i] = imgs 117 | if obj.split == 'train': 118 | obj.gt_images[i] = np.load(obj.gt_dirs[i], allow_pickle=True).transpose(2, 0, 1) 119 | obj.meta_data[i] = np.load(obj.meta_dirs[i], allow_pickle=True) 120 | failed = False 121 | break 122 | except: 123 | failed = True 124 | if failed: print('%s fails!' % obj.names[i]) 125 | 126 | def read_images(obj): 127 | # may use `from multiprocessing import Pool` instead, but less efficient and 128 | # NOTE: `multiprocessing.Pool` will duplicate given object for each process. 
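    # (`multiprocessing.dummy.Pool` used below is thread-based, so every `imreader`
    # shares `obj` and fills `obj.raw_images` / `obj.gt_images` / `obj.meta_data` in place.)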
129 | print('Starting to load images via multiple imreaders') 130 | pool = Pool() # use all threads by default 131 | for _ in tqdm(pool.imap(imreader, iter_obj(len(obj.names), obj)), total=len(obj.names)): 132 | pass 133 | pool.close() 134 | pool.join() 135 | 136 | if __name__ == '__main__': 137 | pass 138 | -------------------------------------------------------------------------------- /data/degrade/degrade_kernel.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import random 4 | import torch 5 | from scipy import ndimage 6 | from scipy.interpolate import interp2d 7 | from .unprocess import unprocess, random_noise_levels, add_noise 8 | from .process import process 9 | from PIL import Image 10 | 11 | 12 | 13 | # def get_rgb2raw2rgb(img): 14 | # img = torch.from_numpy(np.array(img)) / 255.0 15 | # deg_img, features = unprocess(img) 16 | # shot_noise, read_noise = random_noise_levels() 17 | # deg_img = add_noise(deg_img, shot_noise, read_noise) 18 | # deg_img = deg_img.unsqueeze(0) 19 | # features['red_gain'] = features['red_gain'].unsqueeze(0) 20 | # features['blue_gain'] = features['blue_gain'].unsqueeze(0) 21 | # features['cam2rgb'] = features['cam2rgb'].unsqueeze(0) 22 | # deg_img = process(deg_img, features['red_gain'], features['blue_gain'], features['cam2rgb']) 23 | # deg_img = deg_img.squeeze(0) 24 | # deg_img = torch.clamp(deg_img * 255.0, 0.0, 255.0).numpy() 25 | # deg_img = deg_img.astype(np.uint8) 26 | # return Image.fromarray(deg_img) 27 | 28 | 29 | # def get_rgb2raw_noise(img, noise_level, features=None): 30 | # # img = np.transpose(img, (1, 2, 0)) 31 | # img = torch.from_numpy(np.array(img)) / 255.0 32 | 33 | # deg_img, features = unprocess(img, features) 34 | # shot_noise, read_noise = random_noise_levels(noise_level) 35 | # deg_img_noise = add_noise(deg_img, shot_noise, read_noise) 36 | # # deg_img_noise = torch.clamp(deg_img_noise, min=0.0, max=1.0) 37 | 38 | # # deg_img = np.transpose(deg_img, (2, 0, 1)) 39 | # # deg_img_noise = np.transpose(deg_img_noise, (2, 0, 1)) 40 | # return deg_img_noise, features 41 | 42 | 43 | def get_rgb2raw(img, features=None): 44 | # img = np.transpose(img, (1, 2, 0)) 45 | device = img.device 46 | deg_img, features = unprocess(img, features, device) 47 | return deg_img, features 48 | 49 | 50 | def get_raw2rgb(img, features, demosaic='net', lineRGB=False): 51 | # img = np.transpose(img, (1, 2, 0)) 52 | # img = torch.from_numpy(np.array(img)) 53 | img = img.unsqueeze(0) 54 | device = img.device 55 | deg_img = process(img, features['red_gain'].to(device), features['blue_gain'].to(device), 56 | features['cam2rgb'].to(device), demosaic, lineRGB) 57 | deg_img = deg_img.squeeze(0) 58 | # deg_img = torch.clamp(deg_img * 255.0, 0.0, 255.0).numpy() 59 | # deg_img = deg_img.astype(np.uint8) 60 | return deg_img 61 | 62 | 63 | # def pack_raw_image(im_raw): # HxW 64 | # """ Packs a single channel bayer image into 4 channel tensor, where channels contain R, G, G, and B values""" 65 | # if isinstance(im_raw, np.ndarray): 66 | # im_out = np.zeros_like(im_raw, shape=(4, im_raw.shape[0] // 2, im_raw.shape[1] // 2)) 67 | # elif isinstance(im_raw, torch.Tensor): 68 | # im_out = torch.zeros((4, im_raw.shape[0] // 2, im_raw.shape[1] // 2), dtype=im_raw.dtype) 69 | # else: 70 | # raise Exception 71 | 72 | # im_out[0, :, :] = im_raw[0::2, 0::2] 73 | # im_out[1, :, :] = im_raw[0::2, 1::2] 74 | # im_out[2, :, :] = im_raw[1::2, 0::2] 75 | # im_out[3, :, :] = im_raw[1::2, 1::2] 76 | # return 
im_out # 4xHxW 77 | 78 | 79 | # def flatten_raw_image(im_raw_4ch): # 4xHxW 80 | # """ unpack a 4-channel tensor into a single channel bayer image""" 81 | # if isinstance(im_raw_4ch, np.ndarray): 82 | # im_out = np.zeros_like(im_raw_4ch, shape=(im_raw_4ch.shape[1] * 2, im_raw_4ch.shape[2] * 2)) 83 | # elif isinstance(im_raw_4ch, torch.Tensor): 84 | # im_out = torch.zeros((im_raw_4ch.shape[1] * 2, im_raw_4ch.shape[2] * 2), dtype=im_raw_4ch.dtype) 85 | # else: 86 | # raise Exception 87 | 88 | # im_out[0::2, 0::2] = im_raw_4ch[0, :, :] 89 | # im_out[0::2, 1::2] = im_raw_4ch[1, :, :] 90 | # im_out[1::2, 0::2] = im_raw_4ch[2, :, :] 91 | # im_out[1::2, 1::2] = im_raw_4ch[3, :, :] 92 | 93 | # return im_out # HxW -------------------------------------------------------------------------------- /data/degrade/process.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Forward processing of raw data to sRGB images. 17 | 18 | Unprocessing Images for Learned Raw Denoising 19 | http://timothybrooks.com/tech/unprocessing 20 | """ 21 | 22 | import numpy as np 23 | import torch 24 | import torch.nn as nn 25 | import torch.distributions as tdist 26 | from colour_demosaicing import demosaicing_CFA_Bayer_Menon2007 27 | import os 28 | from isp import demosaic_bayer 29 | 30 | 31 | def apply_gains(bayer_images, red_gains, blue_gains): 32 | """Applies white balance gains to a batch of Bayer images.""" 33 | red_gains = red_gains.squeeze(1) 34 | blue_gains= blue_gains.squeeze(1) 35 | green_gains = torch.ones_like(red_gains) 36 | gains = torch.stack([red_gains, green_gains, green_gains, blue_gains], dim=-1) 37 | gains = gains[:, None, None, :] 38 | # print(bayer_images.shape, gains.shape) 39 | outs = bayer_images * gains 40 | return outs 41 | 42 | 43 | def demosaic(bayer_images): 44 | def SpaceToDepth_fact2(x): 45 | # From here - https://discuss.pytorch.org/t/is-there-any-layer-like-tensorflows-space-to-depth-function/3487/14 46 | bs = 2 47 | N, C, H, W = x.size() 48 | x = x.view(N, C, H // bs, bs, W // bs, bs) # (N, C, H//bs, bs, W//bs, bs) 49 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 50 | x = x.view(N, C * (bs ** 2), H // bs, W // bs) # (N, C*bs^2, H//bs, W//bs) 51 | return x 52 | def DepthToSpace_fact2(x): 53 | # From here - https://discuss.pytorch.org/t/is-there-any-layer-like-tensorflows-space-to-depth-function/3487/14 54 | bs = 2 55 | N, C, H, W = x.size() 56 | x = x.view(N, bs, bs, C // (bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) 57 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) 58 | x = x.view(N, C // (bs ** 2), H * bs, W * bs) # (N, C//bs^2, H * bs, W * bs) 59 | return x 60 | 61 | """Bilinearly demosaics a batch of RGGB Bayer images.""" 62 | 63 | shape = bayer_images.size() 64 | shape = [shape[1] * 2, shape[2] * 2] 65 | 66 | red = 
bayer_images[Ellipsis, 0:1] 67 | upsamplebyX = nn.Upsample(size=shape, mode='bilinear', align_corners=False) 68 | red = upsamplebyX(red.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 69 | 70 | green_red = bayer_images[Ellipsis, 1:2] 71 | green_red = torch.flip(green_red, dims=[1]) # Flip left-right 72 | green_red = upsamplebyX(green_red.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 73 | green_red = torch.flip(green_red, dims=[1]) # Flip left-right 74 | green_red = SpaceToDepth_fact2(green_red.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 75 | 76 | green_blue = bayer_images[Ellipsis, 2:3] 77 | green_blue = torch.flip(green_blue, dims=[0]) # Flip up-down 78 | green_blue = upsamplebyX(green_blue.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 79 | green_blue = torch.flip(green_blue, dims=[0]) # Flip up-down 80 | green_blue = SpaceToDepth_fact2(green_blue.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 81 | 82 | green_at_red = (green_red[Ellipsis, 0] + green_blue[Ellipsis, 0]) / 2 83 | green_at_green_red = green_red[Ellipsis, 1] 84 | green_at_green_blue = green_blue[Ellipsis, 2] 85 | green_at_blue = (green_red[Ellipsis, 3] + green_blue[Ellipsis, 3]) / 2 86 | 87 | green_planes = [ 88 | green_at_red, green_at_green_red, green_at_green_blue, green_at_blue 89 | ] 90 | green = DepthToSpace_fact2(torch.stack(green_planes, dim=-1).permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 91 | 92 | blue = bayer_images[Ellipsis, 3:4] 93 | blue = torch.flip(torch.flip(blue, dims=[1]), dims=[0]) 94 | blue = upsamplebyX(blue.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 95 | blue = torch.flip(torch.flip(blue, dims=[1]), dims=[0]) 96 | 97 | rgb_images = torch.cat([red, green, blue], dim=-1) 98 | return rgb_images 99 | 100 | 101 | def apply_ccms(images, ccms): 102 | """Applies color correction matrices.""" 103 | images = images[:, :, :, None, :] 104 | ccms = ccms[:, None, None, :, :] 105 | outs = torch.sum(images * ccms, dim=-1) 106 | return outs 107 | 108 | 109 | def gamma_compression(images, gamma=2.2): 110 | """Converts from linear to gamma space.""" 111 | # Clamps to prevent numerical instability of gradients near zero. 112 | Mask = lambda x: (x>0.0031308).float() 113 | sRGBDeLinearize = lambda x,m: m * (1.055 * (m * x) ** (1/2.4) - 0.055) + (1-m) * (12.92 * x) 114 | return sRGBDeLinearize(images, Mask(images)) 115 | # outs = torch.clamp(images, min=1e-8) ** (1.0 / gamma) 116 | # return outs 117 | 118 | 119 | def process(bayer_images, red_gains, blue_gains, cam2rgbs, demosaic_type, lineRGB): 120 | # print(bayer_images.shape, red_gains.shape, cam2rgbs.shape) 121 | """Processes a batch of Bayer RGGB images into sRGB images.""" 122 | # White balance. 123 | bayer_images = apply_gains(bayer_images, red_gains, blue_gains) 124 | # Demosaic. 
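    # `demosaic_type` selects the bilinear demosaic() above ('default'),
    # colour_demosaicing's Menon (2007) algorithm ('menon2007'),
    # or the pretrained demosaic-net bundled under ./isp ('net').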
125 | bayer_images = torch.clamp(bayer_images, min=0.0, max=1.0) 126 | 127 | if demosaic_type == 'default': 128 | images = demosaic(bayer_images) 129 | elif demosaic_type == 'menon2007': 130 | # print(bayer_images.size()) 131 | bayer_images = flatten_raw_image(bayer_images.squeeze(0)) 132 | images = demosaicing_CFA_Bayer_Menon2007(bayer_images.cpu().numpy(), 'RGGB') 133 | images = torch.from_numpy(images).unsqueeze(0).to(red_gains.device) 134 | elif demosaic_type == 'net': 135 | bayer_images = flatten_raw_image(bayer_images.squeeze(0)).cpu().numpy() 136 | bayer = np.power(np.clip(bayer_images.astype(dtype=np.float32), 0, 1), 1 / 2.2) 137 | pretrained_model_path = "./isp/model.bin" 138 | demosaic_net = demosaic_bayer.get_demosaic_net_model(pretrained=pretrained_model_path, device=red_gains.device, 139 | cfa='bayer', state_dict=True) 140 | rgb = demosaic_bayer.demosaic_by_demosaic_net(bayer=bayer, cfa='RGGB', 141 | demosaic_net=demosaic_net, device=red_gains.device) 142 | images = np.power(np.clip(rgb, 0, 1), 2.2) 143 | images = torch.from_numpy(images).unsqueeze(0).to(red_gains.device) 144 | 145 | # Color correction. 146 | images = apply_ccms(images, cam2rgbs) 147 | # Gamma compression. 148 | images = torch.clamp(images, min=0.0, max=1.0) 149 | if not lineRGB: 150 | images = gamma_compression(images) 151 | return images 152 | 153 | 154 | def flatten_raw_image(im_raw_4ch): # HxWx4 155 | """ unpack a 4-channel tensor into a single channel bayer image""" 156 | if isinstance(im_raw_4ch, np.ndarray): 157 | im_out = np.zeros_like(im_raw_4ch, shape=(im_raw_4ch.shape[0] * 2, im_raw_4ch.shape[1] * 2)) 158 | elif isinstance(im_raw_4ch, torch.Tensor): 159 | im_out = torch.zeros((im_raw_4ch.shape[0] * 2, im_raw_4ch.shape[1] * 2), dtype=im_raw_4ch.dtype) 160 | else: 161 | raise Exception 162 | 163 | im_out[0::2, 0::2] = im_raw_4ch[:, :, 0] 164 | im_out[0::2, 1::2] = im_raw_4ch[:, :, 1] 165 | im_out[1::2, 0::2] = im_raw_4ch[:, :, 2] 166 | im_out[1::2, 1::2] = im_raw_4ch[:, :, 3] 167 | 168 | return im_out # HxW -------------------------------------------------------------------------------- /data/degrade/unprocess.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Unprocesses sRGB images into realistic raw data. 17 | 18 | Unprocessing Images for Learned Raw Denoising 19 | http://timothybrooks.com/tech/unprocessing 20 | """ 21 | 22 | import numpy as np 23 | import torch 24 | import torch.distributions as tdist 25 | 26 | 27 | def random_ccm(device): 28 | """Generates random RGB -> Camera color correction matrices.""" 29 | # Takes a random convex combination of XYZ -> Camera CCMs. 
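    # Weights drawn uniformly below are normalized by their sum,
    # so the result stays a convex combination of the four camera CCMs.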
30 | xyz2cams = [[[1.0234, -0.2969, -0.2266], 31 | [-0.5625, 1.6328, -0.0469], 32 | [-0.0703, 0.2188, 0.6406]], 33 | [[0.4913, -0.0541, -0.0202], 34 | [-0.613, 1.3513, 0.2906], 35 | [-0.1564, 0.2151, 0.7183]], 36 | [[0.838, -0.263, -0.0639], 37 | [-0.2887, 1.0725, 0.2496], 38 | [-0.0627, 0.1427, 0.5438]], 39 | [[0.6596, -0.2079, -0.0562], 40 | [-0.4782, 1.3016, 0.1933], 41 | [-0.097, 0.1581, 0.5181]]] 42 | num_ccms = len(xyz2cams) 43 | xyz2cams = torch.FloatTensor(xyz2cams) 44 | weights = torch.FloatTensor(num_ccms, 1, 1).uniform_(1e-8, 1e8) 45 | weights_sum = torch.sum(weights, dim=0) 46 | xyz2cam = torch.sum(xyz2cams * weights, dim=0) / weights_sum 47 | 48 | # Multiplies with RGB -> XYZ to get RGB -> Camera CCM. 49 | rgb2xyz = torch.FloatTensor([[0.4124564, 0.3575761, 0.1804375], 50 | [0.2126729, 0.7151522, 0.0721750], 51 | [0.0193339, 0.1191920, 0.9503041]]) 52 | rgb2cam = torch.mm(xyz2cam.to(device), rgb2xyz.to(device)) 53 | 54 | # Normalizes each row. 55 | rgb2cam = rgb2cam / torch.sum(rgb2cam, dim=-1, keepdim=True) 56 | return rgb2cam 57 | 58 | 59 | def random_gains(device): 60 | """Generates random gains for brightening and white balance.""" 61 | # RGB gain represents brightening. 62 | n = tdist.Normal(loc=torch.tensor([0.8]), scale=torch.tensor([0.1])) 63 | rgb_gain = 1.0 / n.sample() 64 | 65 | # Red and blue gains represent white balance. 66 | red_gain = torch.FloatTensor(1).uniform_(1.9, 2.4) 67 | blue_gain = torch.FloatTensor(1).uniform_(1.5, 1.9) 68 | return rgb_gain.to(device), red_gain.to(device), blue_gain.to(device) 69 | 70 | 71 | def inverse_smoothstep(image): 72 | """Approximately inverts a global tone mapping curve.""" 73 | image = torch.clamp(image, min=0.0, max=1.0) 74 | out = 0.5 - torch.sin(torch.asin(1.0 - 2.0 * image) / 3.0) 75 | return out 76 | 77 | 78 | def gamma_expansion(image): 79 | """Converts from gamma to linear space.""" 80 | # Clamps to prevent numerical instability of gradients near zero. 81 | Mask = lambda x: (x>0.04045).float() 82 | sRGBLinearize = lambda x,m: m * ((m * x + 0.055) / 1.055) ** 2.4 + (1-m) * (x / 12.92) 83 | return sRGBLinearize(image, Mask(image)) 84 | # out = torch.clamp(image, min=1e-8) ** 2.2 85 | # return out 86 | 87 | 88 | def apply_ccm(image, ccm): 89 | """Applies a color correction matrix.""" 90 | shape = image.size() 91 | image = torch.reshape(image, [-1, 3]) 92 | image = torch.tensordot(image, ccm, dims=[[-1], [-1]]) 93 | out = torch.reshape(image, shape) 94 | return out 95 | 96 | 97 | def safe_invert_gains(image, rgb_gain, red_gain, blue_gain, device): 98 | """Inverts gains while safely handling saturated pixels.""" 99 | gains = torch.stack((1.0 / red_gain, torch.tensor([1.0]).to(device), 1.0 / blue_gain)) # / rgb_gain 100 | gains = gains.to(device).squeeze() 101 | gains = gains[None, None, :] 102 | # Prevents dimming of saturated pixels by smoothly masking gains near white. 
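    # The mask below ramps from 0 (for gray <= 0.9) to 1 at pure white,
    # so near-saturated pixels keep gains close to 1.0 instead of being dimmed.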
103 | gray = torch.mean(image, dim=-1, keepdim=True) 104 | inflection = 0.9 105 | mask = (torch.clamp(gray - inflection, min=0.0) / (1.0 - inflection)) ** 2.0 106 | safe_gains = torch.max(mask + (1.0 - mask) * gains, gains) 107 | out = image * safe_gains 108 | return out 109 | 110 | 111 | def mosaic(image): 112 | """Extracts RGGB Bayer planes from an RGB image.""" 113 | shape = image.size() 114 | red = image[0::2, 0::2, 0] 115 | green_red = image[0::2, 1::2, 1] 116 | green_blue = image[1::2, 0::2, 1] 117 | blue = image[1::2, 1::2, 2] 118 | out = torch.stack((red, green_red, green_blue, blue), dim=-1) 119 | out = torch.reshape(out, (shape[0] // 2, shape[1] // 2, 4)) 120 | return out 121 | 122 | 123 | def unprocess(image, features=None, device=None): 124 | """Unprocesses an image from sRGB to realistic raw data.""" 125 | 126 | if features == None: 127 | # Randomly creates image metadata. 128 | rgb2cam = random_ccm(device) 129 | cam2rgb = torch.inverse(rgb2cam) 130 | rgb_gain, red_gain, blue_gain = random_gains(device) 131 | else: 132 | rgb2cam = features['rgb2cam'] 133 | cam2rgb = features['cam2rgb'] 134 | rgb_gain = features['rgb_gain'] 135 | red_gain = features['red_gain'] 136 | blue_gain = features['blue_gain'] 137 | # Approximately inverts global tone mapping. 138 | # image = inverse_smoothstep(image) 139 | # Inverts gamma compression. 140 | image = gamma_expansion(image) 141 | # Inverts color correction. 142 | image = apply_ccm(image, rgb2cam) 143 | # Approximately inverts white balance and brightening. 144 | image = safe_invert_gains(image, rgb_gain, red_gain, blue_gain, device) 145 | # Clips saturated pixels. 146 | # image = torch.clamp(image, min=0.0, max=1.0) 147 | # Applies a Bayer mosaic. 148 | image = mosaic(image) 149 | 150 | metadata = { 151 | 'rgb2cam': rgb2cam, 152 | 'cam2rgb': cam2rgb, 153 | 'rgb_gain': rgb_gain, 154 | 'red_gain': red_gain, 155 | 'blue_gain': blue_gain, 156 | } 157 | return image, metadata 158 | 159 | 160 | # ############### If the target dataset is DND, use this function ##################### 161 | # def random_noise_levels(): 162 | # """Generates random noise levels from a log-log linear distribution.""" 163 | # log_min_shot_noise = np.log(0.0001) 164 | # log_max_shot_noise = np.log(0.012) 165 | # log_shot_noise = torch.FloatTensor(1).uniform_(log_min_shot_noise, log_max_shot_noise) 166 | # shot_noise = torch.exp(log_shot_noise) 167 | 168 | # line = lambda x: 2.18 * x + 1.20 169 | # n = tdist.Normal(loc=torch.tensor([0.0]), scale=torch.tensor([0.26])) 170 | # log_read_noise = line(log_shot_noise) + n.sample() 171 | # read_noise = torch.exp(log_read_noise) 172 | # return shot_noise, read_noise 173 | 174 | 175 | def add_noise(image, shot_noise, read_noise): 176 | var = image * shot_noise + read_noise 177 | noise = tdist.Normal(loc=torch.zeros_like(var), scale=torch.sqrt(var)).sample() 178 | out = image + noise 179 | return out 180 | 181 | 182 | ################ If the target dataset is SIDD, use this function ##################### 183 | def random_noise_levels(noise_level): 184 | """ Where read_noise in SIDD is not 0 """ 185 | log_min_shot_noise = torch.log(torch.tensor(0.0012)).to(noise_level.device) 186 | log_max_shot_noise = torch.log(torch.tensor(0.0048)).to(noise_level.device) 187 | log_shot_noise = log_min_shot_noise + noise_level * (log_max_shot_noise - log_min_shot_noise) 188 | shot_noise = torch.exp(log_shot_noise) 189 | 190 | line = lambda x: 1.869 * x + 0.3276 191 | n = tdist.Normal(loc=torch.tensor([0.0], device=noise_level.device), 192 | 
scale=torch.tensor([0.30], device=noise_level.device)) 193 | 194 | log_read_noise = line(log_shot_noise) + n.sample() 195 | read_noise = torch.exp(log_read_noise) 196 | 197 | return shot_noise, read_noise 198 | 199 | 200 | # def add_noise(image, shot_noise=0.01, read_noise=0.0005): 201 | # """Adds random shot (proportional to image) and read (independent) noise.""" 202 | # variance = image * shot_noise + read_noise 203 | # n = tdist.Normal(loc=torch.zeros_like(variance), scale=torch.sqrt(variance)) 204 | # noise = n.sample() 205 | # out = image + noise 206 | # return out 207 | 208 | 209 | # ################ If the target dataset is SIDD, use this function ##################### 210 | # def random_noise_levels(noise_level): 211 | # """ Where read_noise in SIDD is not 0 """ 212 | # log_min_shot_noise = np.log(0.00068674) 213 | # log_max_shot_noise = np.log(0.02194856) 214 | # # log_shot_noise = torch.FloatTensor(1).uniform_(log_min_shot_noise, log_max_shot_noise) 215 | # log_shot_noise = torch.FloatTensor([log_min_shot_noise + noise_level * (log_max_shot_noise - log_min_shot_noise)]) 216 | # shot_noise = torch.exp(log_shot_noise) 217 | 218 | # line = lambda x: 1.85 * x + 0.30 219 | # n = tdist.Normal(loc=torch.tensor([0.0]), scale=torch.tensor([0.20])) 220 | # log_read_noise = line(log_shot_noise) + n.sample() 221 | # read_noise = torch.exp(log_read_noise) 222 | # return shot_noise, read_noise -------------------------------------------------------------------------------- /imgs/Overview of CRNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/imgs/Overview of CRNet.png -------------------------------------------------------------------------------- /imgs/multi_pro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/imgs/multi_pro.png -------------------------------------------------------------------------------- /imgs/out1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/imgs/out1.png -------------------------------------------------------------------------------- /imgs/out2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/imgs/out2.png -------------------------------------------------------------------------------- /isp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/isp/__init__.py -------------------------------------------------------------------------------- /isp/demosaic_bayer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | # sys.path.insert(0, os.path.dirname(__file__)) 4 | import torch 5 | import numpy as np 6 | import pdb 7 | import copy 8 | from collections import OrderedDict 9 | import numpy as np 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class BayerNetwork(nn.Module): 15 | """Released version of the network, best quality. 16 | 17 | This model differs from the published description. 
It has a mask/filter split 18 | towards the end of the processing. Masks and filters are multiplied with each 19 | other. This is not key to performance and can be ignored when training new 20 | models from scratch. 21 | """ 22 | def __init__(self, depth=15, width=64): 23 | super(BayerNetwork, self).__init__() 24 | 25 | self.depth = depth 26 | self.width = width 27 | 28 | # self.debug_layer = nn.Conv2d(3, 4, 2, stride=2) 29 | # self.debug_layer1 =nn.Conv2d(in_channels=4,out_channels=64,kernel_size=3,stride=1,padding=1) 30 | # self.debug_layer2 =nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,stride=1,padding=1) 31 | # self.debug_layer3 =nn.Conv2d(in_channels=64,out_channels=128,kernel_size=3,stride=1,padding=1) 32 | 33 | layers = OrderedDict([ 34 | ("pack_mosaic", nn.Conv2d(3, 4, 2, stride=2)), # Downsample 2x2 to re-establish translation invariance. 35 | ]) # 36 | # the output of 'pack_mosaic' will be half width and height of the input 37 | # [batch_size, 4, h/2, w/2] = pack_mosaic ( [batch_size, 3, h, w] ) 38 | 39 | for i in range(depth): 40 | #num of in and out neurons in each layers 41 | n_out = width 42 | n_in = width 43 | 44 | if i == 0: # the 1st layer in main_processor 45 | n_in = 4 46 | if i == depth-1: # the last layer in main_processor 47 | n_out = 2*width 48 | 49 | # layers["conv{}".format(i+1)] = nn.Conv2d(n_in, n_out, 3) 50 | layers["conv{}".format(i + 1)] = nn.Conv2d(n_in, n_out, 3,stride=1,padding=1) 51 | # padding is set to be 1 so that the h and w won't change after conv2d (using kernal size 3) 52 | layers["relu{}".format(i+1)] = nn.ReLU(inplace=True) 53 | 54 | 55 | # main conv layer 56 | self.main_processor = nn.Sequential(layers) 57 | # residual layer 58 | self.residual_predictor = nn.Conv2d(width, 12, 1) 59 | # upsample layer 60 | self.upsampler = nn.ConvTranspose2d(12, 3, 2, stride=2, groups=3) 61 | 62 | # full-res layer 63 | self.fullres_processor = nn.Sequential(OrderedDict([ 64 | # ("post_conv", nn.Conv2d(6, width, 3)), 65 | ("post_conv", nn.Conv2d(6, width, 3,stride=1,padding=1)), 66 | # padding is set to be 1 so that the h and w won't change after conv2d (using kernal size 3) 67 | ("post_relu", nn.ReLU(inplace=True)), 68 | ("output", nn.Conv2d(width, 3, 1)), 69 | ])) 70 | 71 | 72 | # samples structure 73 | # sample = { 74 | # "mosaic": mosaic, 75 | # # model input [batch_size, 3, h,w]. unknown pixels are set to 0. 76 | # "mask": mask, 77 | # # "noise_variance": np.array([std]), 78 | # "target": im, 79 | # # model output [m,n,3] 80 | # } 81 | def forward(self, samples): 82 | 83 | mosaic = samples["mosaic"] 84 | # [batch_size, 3, h, w] 85 | 86 | features = self.main_processor(mosaic) 87 | # [batch_size, self.width*2, hf,wf] 88 | 89 | filters, masks = features[:, :self.width], features[:, self.width:] 90 | 91 | filtered = filters * masks 92 | # [batch_size, self.width, hf,wf] 93 | 94 | residual = self.residual_predictor(filtered) 95 | # [batch_size, 12, hf, wf] 96 | 97 | upsampled = self.upsampler(residual) 98 | # [batch_size, 3, hf*2, wf*2]. 
upsampled will be 2x2 upsample of residual using ConvTranspose2d() 99 | 100 | # crop original mosaic to match output size 101 | cropped = crop_like(mosaic, upsampled) 102 | 103 | # Concated input samples and residual for further filtering 104 | packed = torch.cat([cropped, upsampled], 1) 105 | 106 | output = self.fullres_processor(packed) 107 | 108 | return output 109 | 110 | 111 | class Converter(object): 112 | def __init__(self, pretrained_dir, model_type): 113 | self.basedir = pretrained_dir 114 | 115 | def convert(self, model): 116 | for n, p in model.named_parameters(): 117 | name, tp = n.split(".")[-2:] 118 | 119 | old_name = self._remap(name) 120 | # print(old_name, "->", name) 121 | 122 | if tp == "bias": 123 | idx = 1 124 | else: 125 | idx = 0 126 | path = os.path.join(self.basedir, "{}_{}.npy".format(old_name, idx)) 127 | data = np.load(path) 128 | # print(name, tp, data.shape, p.shape) 129 | 130 | # Overwiter 131 | # print(p.mean().item(), p.std().item()) 132 | # import ipdb; ipdb.set_trace() 133 | # print(name, old_name, p.shape, data.shape) 134 | p.data.copy_(torch.from_numpy(data)) 135 | # print(p.mean().item(), p.std().item()) 136 | 137 | def _remap(self, s): 138 | if s == "pack_mosaic": 139 | return "pack_mosaick" 140 | if s == "residual_predictor": 141 | return "residual" 142 | if s == "upsampler": 143 | return "unpack_mosaick" 144 | if s == "post_conv": 145 | return "post_conv1" 146 | return s 147 | 148 | 149 | def crop_like(src, tgt): 150 | src_sz = np.array(src.shape) 151 | tgt_sz = np.array(tgt.shape) 152 | crop = (src_sz[2:4]-tgt_sz[2:4]) // 2 153 | if (crop > 0).any(): 154 | return src[:, :, crop[0]:src_sz[2]-crop[0], crop[1]:src_sz[3]-crop[1], ...] 155 | else: 156 | return src 157 | 158 | 159 | def get_modules(params): 160 | params = copy.deepcopy(params) # do not touch the original 161 | 162 | # get the model name from the input params 163 | model_name = params.pop("model", None) 164 | 165 | if model_name is None: 166 | raise ValueError("model has not been specified!") 167 | 168 | # get the model structure by model_name 169 | return getattr(sys.modules[__name__], model_name)(**params) 170 | 171 | 172 | def get_demosaic_net_model(pretrained, device, cfa='bayer', state_dict=False): 173 | ''' 174 | get demosaic network 175 | :param pretrained: 176 | path to the demosaic-network model file [string] 177 | :param device: 178 | 'cuda:0', e.g. 179 | :param state_dict: 180 | whether to use a packed state dictionary for model weights 181 | :return: 182 | model_ref: demosaic-net model 183 | 184 | ''' 185 | 186 | model_ref = get_modules({"model": "BayerNetwork"}) # load model coefficients if 'pretrained'=True 187 | if not state_dict: 188 | cvt = Converter(pretrained, "BayerNetwork") 189 | cvt.convert(model_ref) 190 | for p in model_ref.parameters(): 191 | p.requires_grad = False 192 | model_ref = model_ref.to(device) 193 | else: 194 | model_ref.load_state_dict(torch.load(pretrained)) 195 | model_ref = model_ref.to(device) 196 | 197 | model_ref.eval() 198 | 199 | return model_ref 200 | 201 | 202 | def demosaic_by_demosaic_net(bayer, cfa, demosaic_net, device): 203 | ''' 204 | demosaic the bayer to get RGB by demosaic-net. The func will covnert the numpy array to tensor for demosaic-net, 205 | after which the tensor will be converted back to numpy array to return. 206 | 207 | :param bayer: 208 | [m,n]. numpy float32 in the rnage of [0,1] linear bayer 209 | :param cfa: 210 | [string], 'RGGB', e.g. only GBRG, RGGB, BGGR or GRBG is supported so far! 
211 | :param demosaic_net: 212 | demosaic_net object 213 | :param device: 214 | 'cuda:0', e.g. 215 | 216 | :return: 217 | [m,n,3]. np array float32 in the rnage of [0,1] 218 | 219 | ''' 220 | 221 | 222 | assert (cfa == 'GBRG') or (cfa == 'RGGB') or (cfa == 'GRBG') or (cfa == 'BGGR'), 'only GBRG, RGGB, BGGR, GRBG are supported so far!' 223 | 224 | # if the bayer resolution is too high (more than 1000x1000,e.g.), may cause memory error. 225 | 226 | bayer = np.clip(bayer ,0 ,1) 227 | bayer = torch.from_numpy(bayer).float() 228 | bayer = bayer.to(device) 229 | bayer = torch.unsqueeze(bayer, 0) 230 | bayer = torch.unsqueeze(bayer, 0) 231 | 232 | with torch.no_grad(): 233 | rgb = predict_rgb_from_bayer_tensor(bayer, cfa=cfa, demosaic_net=demosaic_net, device=device) 234 | 235 | rgb = rgb.detach().cpu()[0].permute(1, 2, 0).numpy() # torch tensor -> numpy array 236 | # rgb = np.clip(rgb, 0, 1) 237 | 238 | return rgb 239 | 240 | 241 | def predict_rgb_from_bayer_tensor(im,cfa,demosaic_net,device): 242 | ''' 243 | predict the RGB imgae from bayer pattern mosaic using demosaic net 244 | 245 | :param im: 246 | [batch_sz, 1, m,n] tensor. the bayer pattern mosiac. 247 | 248 | :param cfa: 249 | the cfa layout. the demosaic net is trained w/ GRBG. If the input is other than GRBG, need padding or cropping 250 | 251 | :param demosaic_net: 252 | demosaic-net 253 | 254 | :param device: 255 | 'cuda:0', e.g. 256 | 257 | :return: 258 | rgb_hat: 259 | [batch_size, 3, m,n] the rgb image predicted by the demosaic-net using our bayer input 260 | ''' 261 | 262 | assert (cfa == 'GBRG') or (cfa == 'RGGB') or (cfa == 'GRBG') or (cfa == 'BGGR') 263 | # 'only GBRG, RGGB, BGGR, GRBG are supported so far!' 264 | 265 | # print(im.shape) 266 | 267 | n_channel = im.shape[1] 268 | 269 | if n_channel==1: # gray scale image 270 | im= torch.cat((im, im, im), 1) 271 | 272 | if cfa == 'GBRG': # the demosiac net is trained w/ GRBG 273 | im = pad_gbrg_2_grbg(im,device) 274 | elif cfa == 'RGGB': 275 | im = pad_rggb_2_grbg(im, device) 276 | elif cfa == 'BGGR': 277 | im = pad_bggr_2_grbg(im, device) 278 | 279 | im= bayer_mosaic_tensor(im,device) 280 | 281 | sample = {"mosaic": im} 282 | 283 | rgb_hat = demosaic_net(sample) 284 | 285 | if cfa == 'GBRG': 286 | # an extra row and col is padded on four sides of the bayer before using demosaic-net. Need to trim the padded rows and cols of demosaiced rgb 287 | rgb_hat = unpad_grbg_2_gbrg(rgb_hat) 288 | elif cfa == 'RGGB': 289 | rgb_hat = unpad_grbg_2_rggb(rgb_hat) 290 | elif cfa == 'BGGR': 291 | rgb_hat = unpad_grbg_2_bggr(rgb_hat) 292 | 293 | rgb_hat = torch.clamp(rgb_hat, min=0, max=1) 294 | 295 | return rgb_hat 296 | 297 | 298 | def pad_bggr_2_grbg(bayer, device): 299 | ''' 300 | pad bggr bayer pattern to get grbg (for demosaic-net) 301 | 302 | :param bayer: 303 | 2d tensor [bsz,ch, h,w] 304 | :param device: 305 | 'cuda:0' or 'cpu', or ... 306 | :return: 307 | bayer: 2d tensor [bsz,ch,h,w+2] 308 | 309 | ''' 310 | bsz, ch, h, w = bayer.shape 311 | 312 | bayer2 = torch.zeros([bsz, ch, h + 2, w], dtype=torch.float32) 313 | bayer2 = bayer2.to(device) 314 | 315 | bayer2[:, :, 1:-1, :] = bayer 316 | 317 | bayer2[:, :, 0, :] = bayer[:, :, 1, :] 318 | bayer2[:, :, -1, :] = bayer2[:, :, -2, :] 319 | 320 | bayer = bayer2 321 | 322 | return bayer 323 | 324 | 325 | def pad_rggb_2_grbg(bayer,device): 326 | ''' 327 | pad rggb bayer pattern to get grbg (for demosaic-net) 328 | 329 | :param bayer: 330 | 2d tensor [bsz,ch, h,w] 331 | :param device: 332 | 'cuda:0' or 'cpu', or ... 
333 | :return: 334 | bayer: 2d tensor [bsz,ch,h,w+2] 335 | 336 | ''' 337 | bsz, ch, h, w = bayer.shape 338 | 339 | bayer2 = torch.zeros([bsz,ch,h, w+2], dtype=torch.float32) 340 | bayer2 = bayer2.to(device) 341 | 342 | bayer2[:,:,:, 1:-1] = bayer 343 | 344 | bayer2[:,:,:, 0] = bayer[:,:,:, 1] 345 | bayer2[:,:,:, -1] = bayer2[:,:,:, -2] 346 | 347 | bayer = bayer2 348 | 349 | return bayer 350 | 351 | 352 | def pad_gbrg_2_grbg(bayer,device): 353 | ''' 354 | pad gbrg bayer pattern to get grbg (for demosaic-net) 355 | 356 | :param bayer: 357 | 2d tensor [bsz,ch, h,w] 358 | :param device: 359 | 'cuda:0' or 'cpu', or ... 360 | :return: 361 | bayer: 2d tensor [bsz,ch,h+4,w+4] 362 | 363 | ''' 364 | bsz, ch, h, w = bayer.shape 365 | 366 | bayer2 = torch.zeros([bsz,ch,h+2, w+2], dtype=torch.float32) 367 | bayer2 = bayer2.to(device) 368 | 369 | bayer2[:,:,1:-1, 1:-1] = bayer 370 | bayer2[:,:,0, 1:-1] = bayer[:,:,1, :] 371 | bayer2[:,:,-1, 1:-1] = bayer[:,:,-2, :] 372 | 373 | bayer2[:,:,:, 0] = bayer2[:,:,:, 2] 374 | bayer2[:,:,:, -1] = bayer2[:,:,:, -3] 375 | 376 | bayer = bayer2 377 | 378 | return bayer 379 | 380 | 381 | def unpad_grbg_2_gbrg(rgb): 382 | ''' 383 | unpad the rgb image. this is used after pad_gbrg_2_grbg() 384 | :param rgb: 385 | tensor. [1,3,m,n] 386 | :return: 387 | tensor [1,3,m-2,n-2] 388 | 389 | ''' 390 | rgb = rgb[:,:,1:-1,1:-1] 391 | 392 | return rgb 393 | 394 | 395 | def unpad_grbg_2_bggr(rgb): 396 | ''' 397 | unpad the rgb image. this is used after pad_bggr_2_grbg() 398 | :param rgb: 399 | tensor. [1,3,m,n] 400 | :return: 401 | tensor [1,3,m,n-2] 402 | 403 | ''' 404 | rgb = rgb[:, :, 1:-1 , : ] 405 | 406 | return rgb 407 | 408 | 409 | def unpad_grbg_2_rggb(rgb): 410 | ''' 411 | unpad the rgb image. this is used after pad_rggb_2_grbg() 412 | :param rgb: 413 | tensor. [1,3,m,n] 414 | :return: 415 | tensor [1,3,m,n-2] 416 | 417 | ''' 418 | rgb = rgb[:,:,:,1:-1] 419 | 420 | return rgb 421 | 422 | 423 | def bayer_mosaic_tensor(im,device): 424 | ''' 425 | create bayer mosaic to set as input to demosaic-net. 426 | make sure the input bayer (im) is GRBG. 427 | 428 | :param im: 429 | [batch_size, 3, m,n]. The color is in RGB order. 430 | :param device: 431 | 'cuda:0', e.g. 432 | :return: 433 | ''' 434 | 435 | """GRBG Bayer mosaic.""" 436 | 437 | batch_size=im.shape[0] 438 | hh=im.shape[2] 439 | ww=im.shape[3] 440 | 441 | mask = torch.ones([batch_size,3,hh, ww], dtype=torch.float32) 442 | mask = mask.to(device) 443 | 444 | # red 445 | mask[:,0, ::2, 0::2] = 0 446 | mask[:,0, 1::2, :] = 0 447 | 448 | # green 449 | mask[:,1, ::2, 1::2] = 0 450 | mask[:,1, 1::2, ::2] = 0 451 | 452 | # blue 453 | mask[:,2, 0::2, :] = 0 454 | mask[:,2, 1::2, 1::2] = 0 455 | 456 | return im*mask -------------------------------------------------------------------------------- /isp/dng_opcode.py: -------------------------------------------------------------------------------- 1 | """ 2 | Authors(s): 3 | Abdelrahman Abdelhamed (a.abdelhamed@samsung.com) 4 | 5 | Utility functions for handling DNG opcode lists. 
6 | """ 7 | import struct 8 | import numpy as np 9 | from .exif_utils import get_tag_values_from_ifds 10 | 11 | 12 | class Opcode: 13 | def __init__(self, id_, dng_spec_ver, option_bits, size_bytes, data): 14 | self.id = id_ 15 | self.dng_spec_ver = dng_spec_ver 16 | self.size_bytes = size_bytes 17 | self.option_bits = option_bits 18 | self.data = data 19 | 20 | 21 | def parse_opcode_lists(ifds): 22 | # OpcodeList1, 51008, 0xC740 23 | # Applied to raw image as read directly form file 24 | 25 | # OpcodeList2, 51009, 0xC741 26 | # Applied to raw image after being mapped to linear reference values 27 | # That is, after linearization, black level subtraction, normalization, and clipping 28 | 29 | # OpcodeList3, 51022, 0xC74E 30 | # Applied to raw image after being demosaiced 31 | 32 | opcode_list_tag_nums = [51008, 51009, 51022] 33 | opcode_lists = {} 34 | for i, tag_num in enumerate(opcode_list_tag_nums): 35 | opcode_list_ = get_tag_values_from_ifds(tag_num, ifds) 36 | if opcode_list_ is not None: 37 | opcode_list_ = bytearray(opcode_list_) 38 | opcodes = parse_opcodes(opcode_list_) 39 | opcode_lists.update({tag_num: opcodes}) 40 | else: 41 | pass 42 | 43 | return opcode_lists 44 | 45 | 46 | def parse_opcodes(opcode_list): 47 | """ 48 | Parse a byte array representing an opcode list. 49 | :param opcode_list: An opcode list as a byte array. 50 | :return: Opcode lists as a dictionary. 51 | """ 52 | # opcode lists are always stored in big endian 53 | endian_sign = ">" 54 | 55 | # opcode IDs 56 | # 9: GainMap 57 | # 1: Rectilinear Warp 58 | 59 | # clip to 60 | # [0, 2^32 - 1] for OpcodeList1 61 | # [0, 2^16 - 1] for OpcodeList2 62 | # [0, 1] for OpcodeList3 63 | 64 | i = 0 65 | num_opcodes = struct.unpack(endian_sign + "I", opcode_list[i:i + 4])[0] 66 | i += 4 67 | 68 | opcodes = {} 69 | for j in range(num_opcodes): 70 | opcode_id_ = struct.unpack(endian_sign + "I", opcode_list[i:i + 4])[0] 71 | i += 4 72 | dng_spec_ver = [struct.unpack(endian_sign + "B", opcode_list[i + k:i + k + 1])[0] for k in range(4)] 73 | i += 4 74 | option_bits = struct.unpack(endian_sign + "I", opcode_list[i:i + 4])[0] 75 | i += 4 76 | 77 | # option bits 78 | if option_bits & 1 == 1: # optional/unknown 79 | pass 80 | elif option_bits & 2 == 2: # can be skipped for "preview quality", needed for "full quality" 81 | pass 82 | else: 83 | pass 84 | 85 | opcode_size_bytes = struct.unpack(endian_sign + "I", opcode_list[i:i + 4])[0] 86 | i += 4 87 | 88 | opcode_data = opcode_list[i:i + 4 * opcode_size_bytes] 89 | i += 4 * opcode_size_bytes 90 | 91 | # GainMap (lens shading correction map) 92 | if opcode_id_ == 9: 93 | opcode_gain_map_data = parse_opcode_gain_map(opcode_data) 94 | opcode_data = opcode_gain_map_data 95 | 96 | # set opcode object 97 | opcode = Opcode(id_=opcode_id_, dng_spec_ver=dng_spec_ver, option_bits=option_bits, 98 | size_bytes=opcode_size_bytes, 99 | data=opcode_data) 100 | opcodes.update({opcode_id_: opcode}) 101 | 102 | return opcodes 103 | 104 | 105 | def parse_opcode_gain_map(opcode_data): 106 | endian_sign = ">" # big 107 | opcode_dict = {} 108 | keys = ['top', 'left', 'bottom', 'right', 'plane', 'planes', 'row_pitch', 'col_pitch', 'map_points_v', 109 | 'map_points_h', 'map_spacing_v', 'map_spacing_h', 'map_origin_v', 'map_origin_h', 'map_planes', 'map_gain'] 110 | dtypes = ['L'] * 10 + ['d'] * 4 + ['L'] + ['f'] 111 | dtype_sizes = [4] * 10 + [8] * 4 + [4] * 2 # data type size in bytes 112 | counts = [1] * 15 + [0] # 0 count means variable count, depending on map_points_v and map_points_h 113 | # values = 
[] 114 | 115 | i = 0 116 | for k in range(len(keys)): 117 | if counts[k] == 0: # map_gain 118 | counts[k] = opcode_dict['map_points_v'] * opcode_dict['map_points_h'] 119 | 120 | if counts[k] == 1: 121 | vals = struct.unpack(endian_sign + dtypes[k], opcode_data[i:i + dtype_sizes[k]])[0] 122 | i += dtype_sizes[k] 123 | else: 124 | vals = [] 125 | for j in range(counts[k]): 126 | vals.append(struct.unpack(endian_sign + dtypes[k], opcode_data[i:i + dtype_sizes[k]])[0]) 127 | i += dtype_sizes[k] 128 | 129 | opcode_dict[keys[k]] = vals 130 | 131 | opcode_dict['map_gain_2d'] = np.reshape(opcode_dict['map_gain'], 132 | (opcode_dict['map_points_v'], opcode_dict['map_points_h'])) 133 | 134 | return opcode_dict 135 | -------------------------------------------------------------------------------- /isp/exif_data_formats.py: -------------------------------------------------------------------------------- 1 | class ExifFormat: 2 | def __init__(self, id, name, size, short_name): 3 | self.id = id 4 | self.name = name 5 | self.size = size 6 | self.short_name = short_name # used with struct.unpack() 7 | 8 | 9 | exif_formats = { 10 | 1: ExifFormat(1, 'unsigned byte', 1, 'B'), 11 | 2: ExifFormat(2, 'ascii string', 1, 's'), 12 | 3: ExifFormat(3, 'unsigned short', 2, 'H'), 13 | 4: ExifFormat(4, 'unsigned long', 4, 'L'), 14 | 5: ExifFormat(5, 'unsigned rational', 8, ''), 15 | 6: ExifFormat(6, 'signed byte', 1, 'b'), 16 | 7: ExifFormat(7, 'undefined', 1, 'B'), # consider `undefined` as `unsigned byte` 17 | 8: ExifFormat(8, 'signed short', 2, 'h'), 18 | 9: ExifFormat(9, 'signed long', 4, 'l'), 19 | 10: ExifFormat(10, 'signed rational', 8, ''), 20 | 11: ExifFormat(11, 'single float', 4, 'f'), 21 | 12: ExifFormat(12, 'double float', 8, 'd'), 22 | } 23 | -------------------------------------------------------------------------------- /isp/exif_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author(s): 3 | Abdelrahman Abdelhamed 4 | 5 | Manual parsing of image file directories (IFDs). 6 | """ 7 | 8 | 9 | import struct 10 | from fractions import Fraction 11 | from .exif_data_formats import exif_formats 12 | 13 | 14 | class Ifd: 15 | def __init__(self): 16 | self.offset = -1 17 | self.tags = {} # dict; tag number will be key. 18 | 19 | 20 | class Tag: 21 | def __init__(self): 22 | self.offset = -1 23 | self.tag_num = -1 24 | self.data_format = -1 25 | self.num_values = -1 26 | self.values = [] 27 | 28 | 29 | def parse_exif(image_path, verbose=True): 30 | """ 31 | Parse EXIF tags from a binary file and return IFDs. 32 | Returned IFDs include EXIF SubIFDs, if any. 33 | """ 34 | 35 | def print_(str_): 36 | if verbose: 37 | print(str_) 38 | 39 | ifds = {} # dict of pairs; using offset to IFD as key. 
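    # The loop below walks the TIFF/EXIF header by hand. As a minimal sketch,
    # the 8-byte header layout it relies on (standard TIFF 6.0, not specific to this repo):
    #
    #   bytes 0-1: byte order, b'II' (little-endian/Intel) or b'MM' (big-endian/Motorola)
    #   bytes 2-3: magic number 42 (0x002A)
    #   bytes 4-7: offset of the first IFD, in the header's byte order
    #
    # For example, reading the first-IFD offset could look like this (hypothetical
    # illustration, not part of the pipeline):
    #
    #   with open(image_path, 'rb') as f:
    #       header = f.read(8)
    #       assert header[:2] in (b'II', b'MM')
    #       sign = '<' if header[:2] == b'II' else '>'
    #       first_ifd_offset = struct.unpack(sign + 'I', header[4:8])[0]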
40 | 
41 |     with open(image_path, 'rb') as fid:
42 |         fid.seek(0)
43 |         b0 = fid.read(1)
44 |         _ = fid.read(1)
45 |         # byte storage direction (endian):
46 |         # +1: b'M' (big-endian/Motorola)
47 |         # -1: b'I' (little-endian/Intel)
48 |         endian = 1 if b0 == b'M' else -1
49 |         print_("Endian = {}".format(b0))
50 |         endian_sign = "<" if endian == -1 else ">"  # used in struct.unpack
51 |         print_("Endian sign = {}".format(endian_sign))
52 |         _ = fid.read(2)  # 0x002A
53 |         b4_7 = fid.read(4)  # offset to first IFD
54 |         offset_ = struct.unpack(endian_sign + "I", b4_7)[0]
55 |         i = 0
56 |         ifd_offsets = [offset_]
57 |         while len(ifd_offsets) > 0:
58 |             offset_ = ifd_offsets.pop(0)
59 |             # check if IFD at this offset was already parsed before
60 |             if offset_ in ifds:
61 |                 continue
62 |             print_("=========== Parsing IFD # {} ===========".format(i))
63 |             ifd_ = parse_exif_ifd(fid, offset_, endian_sign, verbose)
64 |             ifds.update({ifd_.offset: ifd_})
65 |             print_("=========== Finished parsing IFD # {} ===========".format(i))
66 |             i += 1
67 |             # check SubIFDs; zero or more offsets at tag 0x014a
68 |             sub_idfs_tag_num = int('0x014a', 16)
69 |             if sub_idfs_tag_num in ifd_.tags:
70 |                 ifd_offsets.extend(ifd_.tags[sub_idfs_tag_num].values)
71 |             # check Exif SubIFD; usually one offset at tag 0x8769
72 |             exif_sub_idf_tag_num = int('0x8769', 16)
73 |             if exif_sub_idf_tag_num in ifd_.tags:
74 |                 ifd_offsets.extend(ifd_.tags[exif_sub_idf_tag_num].values)
75 |     return ifds
76 | 
77 | 
78 | def parse_exif_ifd(binary_file, offset_, endian_sign, verbose=True):
79 |     """
80 |     Parse an EXIF IFD.
81 |     """
82 | 
83 |     def print_(str_):
84 |         if verbose:
85 |             print(str_)
86 | 
87 |     ifd = Ifd()
88 |     ifd.offset = offset_
89 |     print_("IFD offset = {}".format(ifd.offset))
90 |     binary_file.seek(offset_)
91 |     num_entries = struct.unpack(endian_sign + "H", binary_file.read(2))[0]  # format H = unsigned short
92 |     print_("Number of entries = {}".format(num_entries))
93 |     for t in range(num_entries):
94 |         print_("---------- Tag {} / {} ----------".format(t + 1, num_entries))
95 | 
96 | 
97 |         tag_ = parse_exif_tag(binary_file, endian_sign, verbose)
98 |         ifd.tags.update({tag_.tag_num: tag_})  # supposedly, EXIF tag numbers won't repeat in the same IFD
99 |     # TODO: check for subsequent IFDs by parsing the next 4 bytes immediately after the IFD
100 |     return ifd
101 | 
102 | 
103 | def parse_exif_tag(binary_file, endian_sign, verbose=True):
104 |     """
105 |     Parse an EXIF tag from a binary file, starting at the current file pointer, and return the parsed tag.
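    The layout being decoded is the standard 12-byte IFD entry:
    bytes 0-1 hold the tag number, bytes 2-3 the data format (1-12),
    bytes 4-7 the number of values, and bytes 8-11 either the values
    themselves (if they fit in 4 bytes) or an offset to where they are stored.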
106 | """ 107 | 108 | def print_(str_): 109 | if verbose: 110 | print(str_) 111 | 112 | tag = Tag() 113 | 114 | # tag offset 115 | tag.offset = binary_file.tell() 116 | print_("Tag offset = {}".format(tag.offset)) 117 | 118 | # tag number 119 | bytes_ = binary_file.read(2) 120 | tag.tag_num = struct.unpack(endian_sign + "H", bytes_)[0] # H: unsigned 2-byte short 121 | print_("Tag number = {} = 0x{:04x}".format(tag.tag_num, tag.tag_num)) 122 | 123 | # data format (some value between [1, 12]) 124 | tag.data_format = struct.unpack(endian_sign + "H", binary_file.read(2))[0] # H: unsigned 2-byte short 125 | exif_format = exif_formats[tag.data_format] 126 | print_("Data format = {} = {}".format(tag.data_format, exif_format.name)) 127 | 128 | # number of components/values 129 | tag.num_values = struct.unpack(endian_sign + "I", binary_file.read(4))[0] # I: unsigned 4-byte integer 130 | print_("Number of values = {}".format(tag.num_values)) 131 | 132 | # total number of data bytes 133 | total_bytes = tag.num_values * exif_format.size 134 | print_("Total bytes = {}".format(total_bytes)) 135 | 136 | # seek to data offset (if needed) 137 | data_is_offset = False 138 | current_offset = binary_file.tell() 139 | if total_bytes > 4: 140 | print_("Total bytes > 4; The next 4 bytes are an offset.") 141 | data_is_offset = True 142 | data_offset = struct.unpack(endian_sign + "I", binary_file.read(4))[0] 143 | current_offset = binary_file.tell() 144 | print_("Current offset = {}".format(current_offset)) 145 | print_("Seeking to data offset = {}".format(data_offset)) 146 | binary_file.seek(data_offset) 147 | 148 | # read values 149 | # TODO: need to distinguish between numeric and text values? 150 | if tag.num_values == 1 and total_bytes < 4: 151 | # special case: data is a single value that is less than 4 bytes inside 4 bytes, take care of endian 152 | val_bytes = binary_file.read(4) 153 | # if endian_sign == ">": 154 | # val_bytes = val_bytes[4 - total_bytes:] 155 | # else: 156 | # val_bytes = val_bytes[:total_bytes][::-1] 157 | val_bytes = val_bytes[:total_bytes] 158 | tag.values.append(struct.unpack(endian_sign + exif_format.short_name, val_bytes)[0]) 159 | else: 160 | # read data values one by one 161 | for k in range(tag.num_values): 162 | val_bytes = binary_file.read(exif_format.size) 163 | if exif_format.name == 'unsigned rational': 164 | tag.values.append(eight_bytes_to_fraction(val_bytes, endian_sign, signed=False)) 165 | elif exif_format.name == 'signed rational': 166 | tag.values.append(eight_bytes_to_fraction(val_bytes, endian_sign, signed=True)) 167 | else: 168 | tag.values.append(struct.unpack(endian_sign + exif_format.short_name, val_bytes)[0]) 169 | if total_bytes < 4: 170 | # special case: multiple values less than 4 bytes in total, inside the 4 bytes; skip the extra bytes 171 | binary_file.seek(4 - total_bytes, 1) 172 | 173 | if verbose: 174 | if len(tag.values) > 100: 175 | print_("Got more than 100 values; printing first 100 only:") 176 | print_("Tag values = {}".format(tag.values[:100])) 177 | else: 178 | print_("Tag values = {}".format(tag.values)) 179 | if tag.data_format == 2: 180 | print_("Tag values (string) = {}".format(b''.join(tag.values).decode())) 181 | 182 | if data_is_offset: 183 | # seek back to current position to read the next tag 184 | print_("Seeking back to current offset = {}".format(current_offset)) 185 | binary_file.seek(current_offset) 186 | 187 | return tag 188 | 189 | 190 | def get_tag_values_from_ifds(tag_num, ifds): 191 | """ 192 | Return values of a tag, if found 
in ifds. Return None otherwise. 193 | Assuming any tag exists only once in all ifds. 194 | """ 195 | for key, ifd in ifds.items(): 196 | if tag_num in ifd.tags: 197 | return ifd.tags[tag_num].values 198 | return None 199 | 200 | 201 | def eight_bytes_to_fraction(eight_bytes, endian_sign, signed): 202 | """ 203 | Convert 8-byte array into a Fraction. Take care of endian and sign. 204 | """ 205 | if signed: 206 | num = struct.unpack(endian_sign + "l", eight_bytes[:4])[0] 207 | den = struct.unpack(endian_sign + "l", eight_bytes[4:])[0] 208 | else: 209 | num = struct.unpack(endian_sign + "L", eight_bytes[:4])[0] 210 | den = struct.unpack(endian_sign + "L", eight_bytes[4:])[0] 211 | den = den if den != 0 else 1 212 | return Fraction(num, den) 213 | -------------------------------------------------------------------------------- /isp/isp.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import cv2 4 | import numpy as np 5 | 6 | from .pipeline import run_pipeline_v2 7 | from .pipeline_utils import get_visible_raw_image, get_metadata 8 | 9 | params = { 10 | 'input_stage': 'normal', # options: 'raw', 'normal', 'white_balance', 'demosaic', 'xyz', 'srgb', 'gamma', 'tone' 11 | 'output_stage': 'srgb', # options: 'normal', 'white_balance', 'demosaic', 'xyz', 'srgb', 'gamma', 'tone' 12 | 'save_as': 'png', # options: 'jpg', 'png', 'tif', etc. 13 | 'demosaic_type': 'net', # 'menon2007', 'EA', 'VNG', 'net', 'down 14 | 'save_dtype': np.uint16 15 | } 16 | 17 | def reshape_back_raw(bayer): 18 | H = bayer.shape[1] 19 | W = bayer.shape[2] 20 | newH = int(H*2) 21 | newW = int(W*2) 22 | bayer_back = np.zeros((newH, newW)) 23 | bayer_back[0:newH:2, 0:newW:2] = bayer[3] 24 | bayer_back[0:newH:2, 1:newW:2] = bayer[1] 25 | bayer_back[1:newH:2, 0:newW:2] = bayer[2] 26 | bayer_back[1:newH:2, 1:newW:2] = bayer[0] 27 | 28 | return bayer_back 29 | 30 | def isp_pip(raw_image, meta_data, device='0'): 31 | # raw_image = reshape_back_raw(npy_img) / ratio 32 | 33 | # metadata 34 | # meta_npy = meta_data.items() 35 | metadata = get_metadata(meta_data) 36 | # raw_image = raw_image * (2**10 - 1) # + meta_npy['black_level'][0] 37 | 38 | # render 39 | output_image = run_pipeline_v2(raw_image, params, metadata=metadata, device=device) 40 | # output_image = output_image[..., ::-1] * 255 41 | # output_image = np.clip(output_image, 0, 255) 42 | 43 | return output_image #.astype(params['save_dtype']) 44 | -------------------------------------------------------------------------------- /isp/model.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/isp/model.bin -------------------------------------------------------------------------------- /isp/pipeline.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .pipeline_utils import get_visible_raw_image, get_metadata, normalize, white_balance, demosaic, \ 3 | apply_color_space_transform, transform_xyz_to_srgb, apply_gamma, apply_tone_map, fix_orientation, \ 4 | lens_shading_correction 5 | 6 | 7 | def run_pipeline_v2(image_or_path, params=None, metadata=None, fix_orient=True, device='0'): 8 | params_ = params.copy() 9 | if type(image_or_path) == str: 10 | image_path = image_or_path 11 | # raw image data 12 | raw_image = get_visible_raw_image(image_path) 13 | # metadata 14 | metadata = get_metadata(image_path) 15 | else: 16 | raw_image = 
image_or_path.copy() 17 | # must provide metadata 18 | if metadata is None: 19 | raise ValueError("Must provide metadata when providing image data in first argument.") 20 | 21 | current_stage = 'raw' 22 | current_image = raw_image 23 | 24 | if params_['input_stage'] == current_stage: 25 | # linearization 26 | linearization_table = metadata['linearization_table'] 27 | if linearization_table is not None: 28 | print('Linearization table found. Not handled.') 29 | # TODO 30 | 31 | current_image = normalize(current_image, metadata['black_level'], metadata['white_level']) 32 | params_['input_stage'] = 'normal' 33 | 34 | current_stage = 'normal' 35 | 36 | if params_['output_stage'] == current_stage: 37 | return current_image 38 | 39 | if params_['input_stage'] == current_stage: 40 | gain_map_opcode = None 41 | if 'opcode_lists' in metadata: 42 | if 51009 in metadata['opcode_lists']: 43 | opcode_list_2 = metadata['opcode_lists'][51009] 44 | gain_map_opcode = opcode_list_2[9] 45 | if gain_map_opcode is not None: 46 | current_image = lens_shading_correction(current_image, gain_map_opcode=gain_map_opcode, 47 | bayer_pattern=metadata['cfa_pattern']) 48 | params_['input_stage'] = 'lens_shading_correction' 49 | 50 | current_stage = 'lens_shading_correction' 51 | 52 | if params_['output_stage'] == current_stage: 53 | return current_image 54 | 55 | if params_['input_stage'] == current_stage: 56 | current_image = white_balance(current_image, metadata['as_shot_neutral'], metadata['cfa_pattern']) 57 | params_['input_stage'] = 'white_balance' 58 | 59 | current_stage = 'white_balance' 60 | 61 | if params_['output_stage'] == current_stage: 62 | return current_image 63 | 64 | if params_['input_stage'] == current_stage: 65 | current_image = demosaic(current_image, metadata['cfa_pattern'], output_channel_order='RGB', 66 | alg_type=params_['demosaic_type'], device=device) 67 | params_['input_stage'] = 'demosaic' 68 | 69 | current_stage = 'demosaic' 70 | 71 | if params_['output_stage'] == current_stage: 72 | return current_image 73 | 74 | if params_['input_stage'] == current_stage: 75 | current_image = apply_color_space_transform(current_image, metadata['color_matrix_1'], 76 | metadata['color_matrix_2']) 77 | params_['input_stage'] = 'xyz' 78 | 79 | current_stage = 'xyz' 80 | 81 | if params_['output_stage'] == current_stage: 82 | return current_image 83 | 84 | if params_['input_stage'] == current_stage: 85 | current_image = transform_xyz_to_srgb(current_image) 86 | params_['input_stage'] = 'srgb' 87 | 88 | current_stage = 'srgb' 89 | 90 | if fix_orient: 91 | # fix image orientation, if needed (after srgb stage, ok?) 92 | current_image = fix_orientation(current_image, metadata['orientation']) 93 | 94 | if params_['output_stage'] == current_stage: 95 | return current_image 96 | 97 | if params_['input_stage'] == current_stage: 98 | current_image = apply_gamma(current_image) 99 | params_['input_stage'] = 'gamma' 100 | 101 | current_stage = 'gamma' 102 | 103 | if params_['output_stage'] == current_stage: 104 | return current_image 105 | 106 | if params_['input_stage'] == current_stage: 107 | current_image = apply_tone_map(current_image) 108 | params_['input_stage'] = 'tone' 109 | 110 | current_stage = 'tone' 111 | 112 | if params_['output_stage'] == current_stage: 113 | return current_image 114 | 115 | # invalid input/output stage! 
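    # The valid stages, in pipeline order, are:
    #   raw -> normal -> lens_shading_correction -> white_balance -> demosaic -> xyz -> srgb -> gamma -> tone
    # Reaching this raise means the requested input_stage/output_stage pair never matched the chain above.
    #
    # A minimal usage sketch (hypothetical values; `raw` is a Bayer array and `meta` the metadata
    # dict produced by get_metadata):
    #
    #   params = {'input_stage': 'raw', 'output_stage': 'srgb', 'demosaic_type': 'EA'}
    #   srgb = run_pipeline_v2(raw, params, metadata=meta)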
116 | raise ValueError('Invalid input/output stage: input_stage = {}, output_stage = {}'.format(params_['input_stage'], 117 | params_['output_stage'])) 118 | 119 | 120 | def run_pipeline(image_path, params): 121 | # raw image data 122 | raw_image = get_visible_raw_image(image_path) 123 | 124 | # metadata 125 | metadata = get_metadata(image_path) 126 | 127 | # linearization 128 | linearization_table = metadata['linearization_table'] 129 | if linearization_table is not None: 130 | print('Linearization table found. Not handled.') 131 | # TODO 132 | 133 | normalized_image = normalize(raw_image, metadata['black_level'], metadata['white_level']) 134 | 135 | if params['output_stage'] == 'normal': 136 | return normalized_image 137 | 138 | white_balanced_image = white_balance(normalized_image, metadata['as_shot_neutral'], metadata['cfa_pattern']) 139 | 140 | if params['output_stage'] == 'white_balance': 141 | return white_balanced_image 142 | 143 | demosaiced_image = demosaic(white_balanced_image, metadata['cfa_pattern'], output_channel_order='BGR', 144 | alg_type=params['demosaic_type']) 145 | 146 | # fix image orientation, if needed 147 | demosaiced_image = fix_orientation(demosaiced_image, metadata['orientation']) 148 | 149 | if params['output_stage'] == 'demosaic': 150 | return demosaiced_image 151 | 152 | xyz_image = apply_color_space_transform(demosaiced_image, metadata['color_matrix_1'], metadata['color_matrix_2']) 153 | 154 | if params['output_stage'] == 'xyz': 155 | return xyz_image 156 | 157 | srgb_image = transform_xyz_to_srgb(xyz_image) 158 | 159 | if params['output_stage'] == 'srgb': 160 | return srgb_image 161 | 162 | gamma_corrected_image = apply_gamma(srgb_image) 163 | 164 | if params['output_stage'] == 'gamma': 165 | return gamma_corrected_image 166 | 167 | tone_mapped_image = apply_tone_map(gamma_corrected_image) 168 | if params['output_stage'] == 'tone': 169 | return tone_mapped_image 170 | 171 | output_image = None 172 | return output_image 173 | -------------------------------------------------------------------------------- /isp/pipeline_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author(s): 3 | Abdelrahman Abdelhamed 4 | 5 | Camera pipeline utilities. 
6 | """ 7 | 8 | import os 9 | from fractions import Fraction 10 | 11 | import cv2 12 | import numpy as np 13 | import exifread 14 | # from exifread import Ratio 15 | from exifread.utils import Ratio 16 | import rawpy 17 | from scipy.io import loadmat 18 | from colour_demosaicing import demosaicing_CFA_Bayer_Menon2007 19 | import struct 20 | 21 | from .dng_opcode import parse_opcode_lists 22 | from .exif_data_formats import exif_formats 23 | from .exif_utils import parse_exif_tag, parse_exif, get_tag_values_from_ifds 24 | from isp import demosaic_bayer 25 | 26 | 27 | def get_visible_raw_image(image_path): 28 | raw_image = rawpy.imread(image_path).raw_image_visible.copy() 29 | # raw_image = rawpy.imread(image_path).raw_image.copy() 30 | return raw_image 31 | 32 | 33 | def get_image_tags(image_path): 34 | with open(image_path, 'rb') as f: 35 | tags = exifread.process_file(f) 36 | return tags 37 | 38 | 39 | def get_image_ifds(image_path): 40 | ifds = parse_exif(image_path, verbose=False) 41 | return ifds 42 | 43 | 44 | def get_metadata(image_path): 45 | metadata = {} 46 | tags = image_path 47 | ifds = image_path 48 | # tags = get_image_tags(image_path) 49 | # ifds = get_image_ifds(image_path) 50 | metadata['linearization_table'] = get_linearization_table(tags, ifds) 51 | metadata['black_level'] = get_black_level(tags, ifds) 52 | metadata['white_level'] = get_white_level(tags, ifds) 53 | metadata['cfa_pattern'] = get_cfa_pattern(tags, ifds) 54 | metadata['as_shot_neutral'] = get_as_shot_neutral(tags, ifds) 55 | color_matrix_1, color_matrix_2 = get_color_matrices(tags, ifds) 56 | metadata['color_matrix_1'] = color_matrix_1 57 | metadata['color_matrix_2'] = color_matrix_2 58 | metadata['orientation'] = get_orientation(tags, ifds) 59 | metadata['noise_profile'] = get_noise_profile(tags, ifds) 60 | # ... 61 | 62 | # opcode lists 63 | # metadata['opcode_lists'] = parse_opcode_lists(ifds) 64 | 65 | # fall back to default values, if necessary 66 | if metadata['black_level'] is None: 67 | metadata['black_level'] = 0 68 | print("Black level is None; using 0.") 69 | if metadata['white_level'] is None: 70 | metadata['white_level'] = 2 ** 16 71 | print("White level is None; using 2 ** 16.") 72 | if metadata['cfa_pattern'] is None: 73 | metadata['cfa_pattern'] = [0, 1, 1, 2] 74 | print("CFAPattern is None; using [0, 1, 1, 2] (RGGB)") 75 | if metadata['as_shot_neutral'] is None: 76 | metadata['as_shot_neutral'] = [1, 1, 1] 77 | print("AsShotNeutral is None; using [1, 1, 1]") 78 | if metadata['color_matrix_1'] is None: 79 | metadata['color_matrix_1'] = [1] * 9 80 | print("ColorMatrix1 is None; using [1, 1, 1, 1, 1, 1, 1, 1, 1]") 81 | if metadata['color_matrix_2'] is None: 82 | metadata['color_matrix_2'] = [1] * 9 83 | print("ColorMatrix2 is None; using [1, 1, 1, 1, 1, 1, 1, 1, 1]") 84 | if metadata['orientation'] is None: 85 | metadata['orientation'] = 0 86 | print("Orientation is None; using 0.") 87 | # ... 88 | return metadata 89 | 90 | 91 | def get_linearization_table(tags, ifds): 92 | possible_keys = ['Image Tag 0xC618', 'Image Tag 50712', 'LinearizationTable', 'Image LinearizationTable'] 93 | return get_values(tags, possible_keys) 94 | 95 | 96 | def get_black_level(tags, ifds): 97 | possible_keys = ['Image Tag 0xC61A', 'Image Tag 50714', 'BlackLevel', 'Image BlackLevel'] 98 | vals = get_values(tags, possible_keys) 99 | if vals is None: 100 | # print("Black level not found in exifread tags. 
Searching IFDs.") 101 | vals = get_tag_values_from_ifds(50714, ifds) 102 | return vals 103 | 104 | 105 | def get_white_level(tags, ifds): 106 | possible_keys = ['Image Tag 0xC61D', 'Image Tag 50717', 'WhiteLevel', 'Image WhiteLevel'] 107 | vals = get_values(tags, possible_keys) 108 | if vals is None: 109 | # print("White level not found in exifread tags. Searching IFDs.") 110 | vals = get_tag_values_from_ifds(50717, ifds) 111 | return vals 112 | 113 | 114 | def get_cfa_pattern(tags, ifds): 115 | possible_keys = ['CFAPattern', 'Image CFAPattern'] 116 | vals = get_values(tags, possible_keys) 117 | if vals is None: 118 | # print("CFAPattern not found in exifread tags. Searching IFDs.") 119 | vals = get_tag_values_from_ifds(33422, ifds) 120 | return vals 121 | 122 | 123 | def get_as_shot_neutral(tags, ifds): 124 | possible_keys = ['Image Tag 0xC628', 'Image Tag 50728', 'AsShotNeutral', 'Image AsShotNeutral'] 125 | return get_values(tags, possible_keys) 126 | 127 | 128 | def get_color_matrices(tags, ifds): 129 | possible_keys_1 = ['Image Tag 0xC621', 'Image Tag 50721', 'ColorMatrix1', 'Image ColorMatrix1'] 130 | color_matrix_1 = get_values(tags, possible_keys_1) 131 | possible_keys_2 = ['Image Tag 0xC622', 'Image Tag 50722', 'ColorMatrix2', 'Image ColorMatrix2'] 132 | color_matrix_2 = get_values(tags, possible_keys_2) 133 | return color_matrix_1, color_matrix_2 134 | 135 | 136 | def get_orientation(tags, ifds): 137 | possible_tags = ['Orientation', 'Image Orientation'] 138 | return get_values(tags, possible_tags) 139 | 140 | 141 | def get_noise_profile(tags, ifds): 142 | possible_keys = ['Image Tag 0xC761', 'Image Tag 51041', 'NoiseProfile', 'Image NoiseProfile'] 143 | vals = get_values(tags, possible_keys) 144 | if vals is None: 145 | # print("Noise profile not found in exifread tags. Searching IFDs.") 146 | vals = get_tag_values_from_ifds(51041, ifds) 147 | return vals 148 | 149 | 150 | def get_values(tags, possible_keys): 151 | values = None 152 | for key in possible_keys: 153 | if key in tags.keys(): 154 | values = tags[key] #.values 155 | return values 156 | 157 | 158 | def normalize(raw_image, black_level, white_level): 159 | if type(black_level) is list and len(black_level) == 1: 160 | black_level = float(black_level[0]) 161 | if type(white_level) is list and len(white_level) == 1: 162 | white_level = float(white_level[0]) 163 | black_level_mask = black_level 164 | if type(black_level) is list and len(black_level) == 4: 165 | if type(black_level[0]) is Ratio: 166 | black_level = ratios2floats(black_level) 167 | black_level_mask = np.zeros(raw_image.shape) 168 | idx2by2 = [[0, 0], [0, 1], [1, 0], [1, 1]] 169 | step2 = 2 170 | for i, idx in enumerate(idx2by2): 171 | black_level_mask[idx[0]::step2, idx[1]::step2] = black_level[i] 172 | normalized_image = raw_image.astype(np.float32) - black_level_mask 173 | # if some values were smaller than black level 174 | normalized_image[normalized_image < 0] = 0 175 | normalized_image = normalized_image / (white_level - black_level_mask) 176 | return normalized_image 177 | 178 | 179 | def ratios2floats(ratios): 180 | floats = [] 181 | for ratio in ratios: 182 | floats.append(float(ratio.num) / ratio.den) 183 | return floats 184 | 185 | 186 | def lens_shading_correction(raw_image, gain_map_opcode, bayer_pattern, gain_map=None, clip=True): 187 | """ 188 | Apply lens shading correction map. 189 | :param raw_image: Input normalized (in [0, 1]) raw image. 190 | :param gain_map_opcode: Gain map opcode. 
191 | :param bayer_pattern: Bayer pattern (RGGB, GRBG, ...). 192 | :param gain_map: Optional gain map to replace gain_map_opcode. 1 or 4 channels in order: R, Gr, Gb, and B. 193 | :param clip: Whether to clip result image to [0, 1]. 194 | :return: Image with gain map applied; lens shading corrected. 195 | """ 196 | 197 | if gain_map is None and gain_map_opcode: 198 | gain_map = gain_map_opcode.data['map_gain_2d'] 199 | 200 | # resize gain map, make it 4 channels, if needed 201 | gain_map = cv2.resize(gain_map, dsize=(raw_image.shape[1] // 2, raw_image.shape[0] // 2), 202 | interpolation=cv2.INTER_LINEAR) 203 | if len(gain_map.shape) == 2: 204 | gain_map = np.tile(gain_map[..., np.newaxis], [1, 1, 4]) 205 | 206 | if gain_map_opcode: 207 | # TODO: consider other parameters 208 | 209 | top = gain_map_opcode.data['top'] 210 | left = gain_map_opcode.data['left'] 211 | bottom = gain_map_opcode.data['bottom'] 212 | right = gain_map_opcode.data['right'] 213 | rp = gain_map_opcode.data['row_pitch'] 214 | cp = gain_map_opcode.data['col_pitch'] 215 | 216 | gm_w = right - left 217 | gm_h = bottom - top 218 | 219 | # gain_map = cv2.resize(gain_map, dsize=(gm_w, gm_h), interpolation=cv2.INTER_LINEAR) 220 | 221 | # TODO 222 | # if top > 0: 223 | # pass 224 | # elif left > 0: 225 | # left_col = gain_map[:, 0:1] 226 | # rep_left_col = np.tile(left_col, [1, left]) 227 | # gain_map = np.concatenate([rep_left_col, gain_map], axis=1) 228 | # elif bottom < raw_image.shape[0]: 229 | # pass 230 | # elif right < raw_image.shape[1]: 231 | # pass 232 | 233 | result_image = raw_image.copy() 234 | 235 | # one channel 236 | # result_image[::rp, ::cp] *= gain_map[::rp, ::cp] 237 | 238 | # per bayer channel 239 | upper_left_idx = [[0, 0], [0, 1], [1, 0], [1, 1]] 240 | bayer_pattern_idx = np.array(bayer_pattern) 241 | # blue channel index --> 3 242 | bayer_pattern_idx[bayer_pattern_idx == 2] = 3 243 | # second green channel index --> 2 244 | if bayer_pattern_idx[3] == 1: 245 | bayer_pattern_idx[3] = 2 246 | else: 247 | bayer_pattern_idx[2] = 2 248 | for c in range(4): 249 | i0 = upper_left_idx[c][0] 250 | j0 = upper_left_idx[c][1] 251 | result_image[i0::2, j0::2] *= gain_map[:, :, bayer_pattern_idx[c]] 252 | 253 | if clip: 254 | result_image = np.clip(result_image, 0.0, 1.0) 255 | 256 | return result_image 257 | 258 | 259 | def white_balance(normalized_image, as_shot_neutral, cfa_pattern): 260 | if type(as_shot_neutral[0]) is Ratio: 261 | as_shot_neutral = ratios2floats(as_shot_neutral) 262 | idx2by2 = [[0, 0], [0, 1], [1, 0], [1, 1]] 263 | step2 = 2 264 | white_balanced_image = np.zeros(normalized_image.shape) 265 | for i, idx in enumerate(idx2by2): 266 | idx_y = idx[0] 267 | idx_x = idx[1] 268 | white_balanced_image[idx_y::step2, idx_x::step2] = \ 269 | normalized_image[idx_y::step2, idx_x::step2] / as_shot_neutral[cfa_pattern[i]] 270 | white_balanced_image = np.clip(white_balanced_image, 0.0, 1.0) 271 | return white_balanced_image 272 | 273 | 274 | def get_opencv_demsaic_flag(cfa_pattern, output_channel_order, alg_type='VNG'): 275 | # using opencv edge-aware demosaicing 276 | if alg_type != '': 277 | alg_type = '_' + alg_type 278 | if output_channel_order == 'BGR': 279 | if cfa_pattern == [0, 1, 1, 2]: # RGGB 280 | opencv_demosaic_flag = eval('cv2.COLOR_BAYER_BG2BGR' + alg_type) 281 | elif cfa_pattern == [2, 1, 1, 0]: # BGGR 282 | opencv_demosaic_flag = eval('cv2.COLOR_BAYER_RG2BGR' + alg_type) 283 | elif cfa_pattern == [1, 0, 2, 1]: # GRBG 284 | opencv_demosaic_flag = eval('cv2.COLOR_BAYER_GB2BGR' + alg_type) 285 | 
elif cfa_pattern == [1, 2, 0, 1]:  # GBRG
286 |             opencv_demosaic_flag = eval('cv2.COLOR_BAYER_GR2BGR' + alg_type)
287 |         else:
288 |             opencv_demosaic_flag = eval('cv2.COLOR_BAYER_BG2BGR' + alg_type)
289 |             print("CFA pattern not identified.")
290 |     else:  # RGB
291 |         if cfa_pattern == [0, 1, 1, 2]:  # RGGB
292 |             opencv_demosaic_flag = eval('cv2.COLOR_BAYER_BG2RGB' + alg_type)
293 |         elif cfa_pattern == [2, 1, 1, 0]:  # BGGR
294 |             opencv_demosaic_flag = eval('cv2.COLOR_BAYER_RG2RGB' + alg_type)
295 |         elif cfa_pattern == [1, 0, 2, 1]:  # GRBG
296 |             opencv_demosaic_flag = eval('cv2.COLOR_BAYER_GB2RGB' + alg_type)
297 |         elif cfa_pattern == [1, 2, 0, 1]:  # GBRG
298 |             opencv_demosaic_flag = eval('cv2.COLOR_BAYER_GR2RGB' + alg_type)
299 |         else:
300 |             opencv_demosaic_flag = eval('cv2.COLOR_BAYER_BG2RGB' + alg_type)
301 |             print("CFA pattern not identified.")
302 |     return opencv_demosaic_flag
303 | 
304 | 
305 | def demosaic(white_balanced_image, cfa_pattern, output_channel_order='BGR', alg_type='VNG', device='0'):
306 |     """
307 |     Demosaic a Bayer image.
308 |     :param white_balanced_image: Single-channel Bayer mosaic image with values in [0, 1].
309 |     :param cfa_pattern: CFA pattern as a list of RGB indices, e.g. [0, 1, 1, 2] for RGGB.
310 |     :param output_channel_order: 'RGB' or 'BGR'.
311 |     :param alg_type: algorithm type. options: '', 'EA' for edge-aware, 'VNG' for variable number of gradients, 'menon2007', 'down' for downsampling, 'net' for a demosaicing network
312 |     :return: Demosaiced image
313 |     """
314 |     if alg_type == 'VNG':
315 |         max_val = 255  # OpenCV's VNG demosaicing requires 8-bit input
316 |         wb_image = (white_balanced_image * max_val).astype(dtype=np.uint8)
317 |     else:
318 |         max_val = 16383
319 |         wb_image = (white_balanced_image * max_val).astype(dtype=np.uint16)
320 | 
321 |     if alg_type in ['', 'EA', 'VNG']:
322 |         opencv_demosaic_flag = get_opencv_demsaic_flag(cfa_pattern, output_channel_order, alg_type=alg_type)
323 |         demosaiced_image = cv2.cvtColor(wb_image, opencv_demosaic_flag)
324 |     elif alg_type == 'menon2007':
325 |         cfa_pattern_str = "".join(["RGB"[i] for i in cfa_pattern])
326 |         demosaiced_image = demosaicing_CFA_Bayer_Menon2007(wb_image, pattern=cfa_pattern_str)
327 |     elif alg_type == 'down':  # G R B G
328 |         demosaiced_image = np.zeros([wb_image.shape[0]//2, wb_image.shape[1]//2, 3], dtype=np.float32)
329 |         demosaiced_image[:,:,0] = wb_image[0::2, 1::2]
330 |         demosaiced_image[:,:,1] = wb_image[0::2, 0::2]
331 |         demosaiced_image[:,:,2] = wb_image[1::2, 0::2]
332 |     elif alg_type == 'net':
333 |         cfa_pattern_str = "".join(["RGB"[i] for i in cfa_pattern])
334 |         bayer = np.power(np.clip(wb_image.astype(dtype=np.float32) / max_val, 0, 1), 1 / 2.2)
335 |         pretrained_model_path = os.path.dirname(__file__) + "/model.bin"
336 |         demosaic_net = demosaic_bayer.get_demosaic_net_model(pretrained=pretrained_model_path, device=device,
337 |                                                              cfa='bayer', state_dict=True)
338 |         rgb = demosaic_bayer.demosaic_by_demosaic_net(bayer=bayer, cfa=cfa_pattern_str,
339 |                                                       demosaic_net=demosaic_net, device=device)
340 |         demosaiced_image = np.power(np.clip(rgb, 0, 1), 2.2) * max_val
341 | 
342 |     demosaiced_image = demosaiced_image.astype(dtype=np.float32) / max_val
343 | 
344 |     return demosaiced_image
345 | 
346 | 
347 | def apply_color_space_transform(demosaiced_image, color_matrix_1, color_matrix_2):
348 |     if type(color_matrix_1[0]) is Ratio:
349 |         color_matrix_1 = ratios2floats(color_matrix_1)
350 |     if type(color_matrix_2[0]) is Ratio:
351 |         color_matrix_2 = ratios2floats(color_matrix_2)
352 |     xyz2cam1 = np.reshape(np.asarray(color_matrix_1), (3, 3))
353 |     xyz2cam2 = np.reshape(np.asarray(color_matrix_2), (3, 3))
354 |     # normalize rows (needed?)
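    # One plausible reading of this normalization (the comment above leaves it open):
    # dividing each row by its sum makes the matrix map an all-ones vector to all-ones,
    # a rough white-preservation heuristic for the camera matrices. A small sketch of the idea:
    #
    #   M = np.array([[2., 1., 1.], [1., 2., 1.], [1., 1., 2.]])
    #   M = M / np.sum(M, axis=1, keepdims=True)   # rows now sum to 1
    #   np.allclose(M @ np.ones(3), np.ones(3))    # -> True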
355 |     xyz2cam1 = xyz2cam1 / np.sum(xyz2cam1, axis=1, keepdims=True)
356 |     xyz2cam2 = xyz2cam2 / np.sum(xyz2cam2, axis=1, keepdims=True)
357 |     # inverse
358 |     cam2xyz1 = np.linalg.inv(xyz2cam1)
359 |     cam2xyz2 = np.linalg.inv(xyz2cam2)
360 |     # for now, use one matrix  # TODO: interpolate between both
361 |     # simplified matrix multiplication
362 |     xyz_image = cam2xyz1[np.newaxis, np.newaxis, :, :] * demosaiced_image[:, :, np.newaxis, :]
363 |     xyz_image = np.sum(xyz_image, axis=-1)
364 |     xyz_image = np.clip(xyz_image, 0.0, 1.0)
365 |     return xyz_image
366 | 
367 | 
368 | def transform_xyz_to_srgb(xyz_image):
369 |     # srgb2xyz = np.array([[0.4124564, 0.3575761, 0.1804375],
370 |     #                      [0.2126729, 0.7151522, 0.0721750],
371 |     #                      [0.0193339, 0.1191920, 0.9503041]])
372 | 
373 |     # xyz2srgb = np.linalg.inv(srgb2xyz)
374 | 
375 |     xyz2srgb = np.array([[3.2404542, -1.5371385, -0.4985314],
376 |                          [-0.9692660, 1.8760108, 0.0415560],
377 |                          [0.0556434, -0.2040259, 1.0572252]])
378 | 
379 |     # normalize rows (needed?)
380 |     xyz2srgb = xyz2srgb / np.sum(xyz2srgb, axis=-1, keepdims=True)
381 | 
382 |     srgb_image = xyz2srgb[np.newaxis, np.newaxis, :, :] * xyz_image[:, :, np.newaxis, :]
383 |     srgb_image = np.sum(srgb_image, axis=-1)
384 |     srgb_image = np.clip(srgb_image, 0.0, 1.0)
385 |     return srgb_image
386 | 
387 | 
388 | def fix_orientation(image, orientation):
389 |     # 1 = Horizontal (normal)
390 |     # 2 = Mirror horizontal
391 |     # 3 = Rotate 180
392 |     # 4 = Mirror vertical
393 |     # 5 = Mirror horizontal and rotate 270 CW
394 |     # 6 = Rotate 90 CW
395 |     # 7 = Mirror horizontal and rotate 90 CW
396 |     # 8 = Rotate 270 CW
397 | 
398 |     if type(orientation) is list:
399 |         orientation = orientation[0]
400 | 
401 |     if orientation == 1:
402 |         pass
403 |     elif orientation == 2:
404 |         image = cv2.flip(image, 0)
405 |     elif orientation == 3:
406 |         image = cv2.rotate(image, cv2.ROTATE_180)
407 |     elif orientation == 4:
408 |         image = cv2.flip(image, 1)
409 |     elif orientation == 5:
410 |         image = cv2.flip(image, 0)
411 |         image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
412 |     elif orientation == 6:
413 |         image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
414 |     elif orientation == 7:
415 |         image = cv2.flip(image, 0)
416 |         image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
417 |     elif orientation == 8:
418 |         image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
419 | 
420 |     return image
421 | 
422 | 
423 | def reverse_orientation(image, orientation):
424 |     # 1 = Horizontal (normal)
425 |     # 2 = Mirror horizontal
426 |     # 3 = Rotate 180
427 |     # 4 = Mirror vertical
428 |     # 5 = Mirror horizontal and rotate 270 CW
429 |     # 6 = Rotate 90 CW
430 |     # 7 = Mirror horizontal and rotate 90 CW
431 |     # 8 = Rotate 270 CW
432 |     rev_orientations = np.array([1, 2, 3, 4, 5, 8, 7, 6])
433 |     return fix_orientation(image, rev_orientations[orientation - 1])
434 | 
435 | 
436 | def apply_gamma(x):
437 |     return x ** (1.0 / 2.2)
438 | 
439 | 
440 | def apply_tone_map(x):
441 |     # simple tone curve
442 |     # return 3 * x ** 2 - 2 * x ** 3
443 | 
444 |     # tone_curve = loadmat('tone_curve.mat')
445 |     tone_curve = loadmat(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'tone_curve.mat'))
446 |     tone_curve = tone_curve['tc']
447 |     x = np.round(x * (len(tone_curve) - 1)).astype(int)
448 |     tone_mapped_image = np.squeeze(tone_curve[x])
449 |     return tone_mapped_image
450 | 
451 | 
452 | def raw_rgb_to_cct(rawRgb, xyz2cam1, xyz2cam2):
453 |     """Convert a raw-RGB triplet to the corresponding correlated color temperature (CCT)."""
454 |     pass
455 |     # pxyz = [.5, 1, .5]
456 |     # loss = 1e10
457 |     # 
k = 1 458 | # while loss > 1e-4: 459 | # cct = XyzToCct(pxyz) 460 | # xyz = RawRgbToXyz(rawRgb, cct, xyz2cam1, xyz2cam2) 461 | # loss = norm(xyz - pxyz) 462 | # pxyz = xyz 463 | # fprintf('k = %d, loss = %f\n', [k, loss]) 464 | # k = k + 1 465 | # end 466 | # temp = cct 467 | -------------------------------------------------------------------------------- /isp/tone_curve.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/isp/tone_curve.mat -------------------------------------------------------------------------------- /mixer.yaml: -------------------------------------------------------------------------------- 1 | name: mixer 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/msys2/ 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge 7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 9 | - defaults 10 | dependencies: 11 | - _libgcc_mutex=0.1=main 12 | - _openmp_mutex=5.1=1_gnu 13 | - blas=1.0=mkl 14 | - bzip2=1.0.8=h5eee18b_5 15 | - ca-certificates=2023.12.12=h06a4308_0 16 | - certifi=2024.2.2=py39h06a4308_0 17 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 18 | - cuda-cudart=11.8.89=0 19 | - cuda-cupti=11.8.87=0 20 | - cuda-libraries=11.8.0=0 21 | - cuda-nvrtc=11.8.89=0 22 | - cuda-nvtx=11.8.86=0 23 | - cuda-runtime=11.8.0=0 24 | - ffmpeg=4.3=hf484d3e_0 25 | - filelock=3.13.1=py39h06a4308_0 26 | - freetype=2.12.1=h4a9f257_0 27 | - gmp=6.2.1=h295c915_3 28 | - gmpy2=2.1.2=py39heeb90bb_0 29 | - gnutls=3.6.15=he1e5248_0 30 | - idna=3.4=py39h06a4308_0 31 | - intel-openmp=2023.1.0=hdb19cb5_46306 32 | - jinja2=3.1.3=py39h06a4308_0 33 | - jpeg=9e=h5eee18b_1 34 | - lame=3.100=h7b6447c_0 35 | - lcms2=2.12=h3be6417_0 36 | - ld_impl_linux-64=2.38=h1181459_1 37 | - lerc=3.0=h295c915_0 38 | - libcublas=11.11.3.6=0 39 | - libcufft=10.9.0.58=0 40 | - libcufile=1.8.1.2=0 41 | - libcurand=10.3.4.107=0 42 | - libcusolver=11.4.1.48=0 43 | - libcusparse=11.7.5.86=0 44 | - libdeflate=1.17=h5eee18b_1 45 | - libffi=3.4.4=h6a678d5_0 46 | - libgcc-ng=11.2.0=h1234567_1 47 | - libgomp=11.2.0=h1234567_1 48 | - libiconv=1.16=h7f8727e_2 49 | - libidn2=2.3.4=h5eee18b_0 50 | - libjpeg-turbo=2.0.0=h9bf148f_0 51 | - libnpp=11.8.0.86=0 52 | - libnvjpeg=11.9.0.86=0 53 | - libpng=1.6.39=h5eee18b_0 54 | - libstdcxx-ng=11.2.0=h1234567_1 55 | - libtasn1=4.19.0=h5eee18b_0 56 | - libtiff=4.5.1=h6a678d5_0 57 | - libunistring=0.9.10=h27cfd23_0 58 | - libwebp-base=1.3.2=h5eee18b_0 59 | - llvm-openmp=14.0.6=h9e868ea_0 60 | - lz4-c=1.9.4=h6a678d5_0 61 | - markupsafe=2.1.3=py39h5eee18b_0 62 | - mkl=2023.1.0=h213fc3f_46344 63 | - mkl-service=2.4.0=py39h5eee18b_1 64 | - mkl_fft=1.3.8=py39h5eee18b_0 65 | - mkl_random=1.2.4=py39hdb19cb5_0 66 | - mpc=1.1.0=h10f8cd9_1 67 | - mpfr=4.0.2=hb69a4c5_1 68 | - mpmath=1.3.0=py39h06a4308_0 69 | - ncurses=6.4=h6a678d5_0 70 | - nettle=3.7.3=hbbd107a_1 71 | - networkx=3.1=py39h06a4308_0 72 | - numpy=1.26.4=py39h5f9d8c6_0 73 | - numpy-base=1.26.4=py39hb5e798b_0 74 | - openh264=2.1.1=h4ff587b_0 75 | - openjpeg=2.4.0=h3ad879b_0 76 | - openssl=3.0.13=h7f8727e_0 77 | - pillow=10.2.0=py39h5eee18b_0 78 | - pip=23.3.1=py39h06a4308_0 79 | - python=3.9.16=h955ad1f_3 80 | - pytorch-cuda=11.8=h7e8668a_5 81 | - pytorch-mutex=1.0=cuda 82 | - pyyaml=6.0.1=py39h5eee18b_0 83 | - readline=8.2=h5eee18b_0 84 | - requests=2.31.0=py39h06a4308_1 85 | - 
setuptools=68.2.2=py39h06a4308_0 86 | - sqlite=3.41.2=h5eee18b_0 87 | - sympy=1.12=py39h06a4308_0 88 | - tbb=2021.8.0=hdb19cb5_0 89 | - tk=8.6.12=h1ccaba5_0 90 | - torchaudio=2.2.1=py39_cu118 91 | - typing_extensions=4.9.0=py39h06a4308_1 92 | - tzdata=2024a=h04d1e81_0 93 | - urllib3=2.1.0=py39h06a4308_0 94 | - wheel=0.41.2=py39h06a4308_0 95 | - xz=5.4.6=h5eee18b_0 96 | - yaml=0.2.5=h7b6447c_0 97 | - zlib=1.2.13=h5eee18b_0 98 | - zstd=1.5.5=hc292b87_0 99 | - pip: 100 | - absl-py==2.1.0 101 | - accelerate==0.27.2 102 | - addict==2.4.0 103 | - basicsr==1.4.2 104 | - bypy==1.8.5 105 | - cmake==3.28.3 106 | - colour-demosaicing==0.2.5 107 | - colour-science==0.4.4 108 | - contourpy==1.2.0 109 | - cycler==0.12.1 110 | - dill==0.3.8 111 | - einops==0.7.0 112 | - fonttools==4.50.0 113 | - fsspec==2024.2.0 114 | - future==1.0.0 115 | - grpcio==1.62.0 116 | - huggingface-hub==0.21.3 117 | - imageio==2.34.0 118 | - importlib-metadata==7.0.1 119 | - importlib-resources==6.4.0 120 | - kiwisolver==1.4.5 121 | - lazy-loader==0.3 122 | - lit==17.0.6 123 | - lmdb==1.4.1 124 | - markdown==3.5.2 125 | - matplotlib==3.8.3 126 | - multiprocess==0.70.16 127 | - nvidia-cublas-cu11==11.10.3.66 128 | - nvidia-cuda-cupti-cu11==11.7.101 129 | - nvidia-cuda-nvrtc-cu11==11.7.99 130 | - nvidia-cuda-runtime-cu11==11.7.99 131 | - nvidia-cudnn-cu11==8.5.0.96 132 | - nvidia-cufft-cu11==10.9.0.58 133 | - nvidia-curand-cu11==10.2.10.91 134 | - nvidia-cusolver-cu11==11.4.0.1 135 | - nvidia-cusparse-cu11==11.7.4.91 136 | - nvidia-nccl-cu11==2.14.3 137 | - nvidia-nvtx-cu11==11.7.91 138 | - opencv-python==4.9.0.80 139 | - packaging==23.2 140 | - platformdirs==4.2.0 141 | - protobuf==4.25.3 142 | - psutil==5.9.8 143 | - pyparsing==3.1.2 144 | - python-dateutil==2.9.0.post0 145 | - pywavelets==1.5.0 146 | - requests-toolbelt==1.0.0 147 | - safetensors==0.4.2 148 | - scikit-image==0.22.0 149 | - scipy==1.12.0 150 | - six==1.16.0 151 | - tb-nightly==2.17.0a20240229 152 | - tensorboard-data-server==0.7.2 153 | - tensorboardx==2.6.2.2 154 | - thop==0.1.1-2209072238 155 | - tifffile==2024.2.12 156 | - timm==0.9.16 157 | - tomli==2.0.1 158 | - torch==2.0.1 159 | - torchvision==0.15.2 160 | - tqdm==4.66.2 161 | - triton==2.0.0 162 | - werkzeug==3.0.1 163 | - yapf==0.40.2 164 | - zipp==3.17.0 165 | 166 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from models.base_model import BaseModel 3 | 4 | 5 | def find_model_using_name(model_name): 6 | """Import the module "models/[model_name]_model.py". 7 | 8 | In the file, the class called DatasetNameModel() will 9 | be instantiated. It has to be a subclass of BaseModel, 10 | and it is case-insensitive. 11 | """ 12 | model_filename = "models." + model_name + "_model" 13 | modellib = importlib.import_module(model_filename) 14 | model = None 15 | target_model_name = model_name.replace('_', '') + 'model' 16 | for name, cls in modellib.__dict__.items(): 17 | if name.lower() == target_model_name.lower() \ 18 | and issubclass(cls, BaseModel): 19 | model = cls 20 | 21 | if model is None: 22 | raise NotImplementedError("In %s.py, there should be a subclass of " 23 | "BaseModel with class name that matches %s in " 24 | "lowercase." 
% (model_filename, target_model_name)) 25 | 26 | return model 27 | 28 | 29 | def get_option_setter(model_name): 30 | model_class = find_model_using_name(model_name) 31 | return model_class.modify_commandline_options 32 | 33 | 34 | def create_model(opt): 35 | """Create a model given the option. 36 | 37 | This function warps the class CustomDatasetDataLoader. 38 | This is the main interface between this package and 'train.py'/'test.py' 39 | 40 | Example: 41 | >>> from models import create_model 42 | >>> model = create_model(opt) 43 | """ 44 | model = find_model_using_name(opt.model) 45 | instance = model(opt) 46 | print("model [%s] was created" % type(instance).__name__) 47 | return instance 48 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from abc import ABC, abstractmethod 5 | from . import networks 6 | import torch 7 | from util.util import torch_save 8 | import math 9 | import torch.nn.functional as F 10 | from data.degrade.process import demosaic 11 | 12 | 13 | class BaseModel(ABC): 14 | def __init__(self, opt): 15 | self.opt = opt 16 | self.gpu_ids = opt.gpu_ids 17 | self.isTrain = opt.isTrain 18 | # self.scale = opt.scale 19 | 20 | if len(self.gpu_ids) > 0: 21 | self.device = torch.device('cuda', self.gpu_ids[0]) 22 | else: 23 | self.device = torch.device('cpu') 24 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 25 | self.loss_names = [] 26 | self.model_names = [] 27 | self.visual_names = [] 28 | self.optimizers = [] 29 | self.optimizer_names = [] 30 | self.image_paths = [] 31 | self.metric = 0 # used for learning rate policy 'plateau' 32 | self.start_epoch = 0 33 | 34 | self.backwarp_tenGrid = {} 35 | self.backwarp_tenPartial = {} 36 | 37 | @staticmethod 38 | def modify_commandline_options(parser, is_train): 39 | return parser 40 | 41 | @abstractmethod 42 | def set_input(self, input): 43 | pass 44 | 45 | @abstractmethod 46 | def forward(self): 47 | pass 48 | 49 | @abstractmethod 50 | def optimize_parameters(self): 51 | pass 52 | 53 | def setup(self, opt=None): 54 | opt = opt if opt is not None else self.opt 55 | if self.isTrain: 56 | self.schedulers = [networks.get_scheduler(optimizer, opt) \ 57 | for optimizer in self.optimizers] 58 | for scheduler in self.schedulers: 59 | scheduler.last_epoch = opt.load_iter 60 | if opt.load_iter > 0 or opt.load_path != '': 61 | load_suffix = opt.load_iter 62 | self.load_networks(load_suffix) 63 | if opt.load_optimizers: 64 | self.load_optimizers(opt.load_iter) 65 | 66 | self.print_networks(opt.verbose) 67 | 68 | def eval(self): 69 | for name in self.model_names: 70 | net = getattr(self, 'net' + name) 71 | net.eval() 72 | 73 | def train(self): 74 | self.isTrain = True 75 | for name in self.model_names: 76 | net = getattr(self, 'net' + name) 77 | net.train() 78 | 79 | def test(self): 80 | self.isTrain = False 81 | with torch.no_grad(): 82 | self.forward() 83 | 84 | def get_image_paths(self): 85 | return self.image_paths 86 | 87 | def update_learning_rate(self, epoch): 88 | for i, scheduler in enumerate(self.schedulers): 89 | if scheduler.__class__.__name__ == 'ReduceLROnPlateau': 90 | scheduler.step(self.metric) 91 | elif scheduler.__class__.__name__ == 'CosineLRScheduler': 92 | scheduler.step(epoch) 93 | else: 94 | scheduler.step() 95 | print('lr of %s = %.7f' % ( 96 | self.optimizer_names[i], 
self.optimizers[i].param_groups[0]['lr'])) 97 | 98 | def post_process(self, image, max=255): 99 | image = image.permute(0, 2, 3, 1) 100 | image = image / image.max() 101 | image = demosaic(image) 102 | image = image.clamp(0.0, 1.0) ** (1/2.2) 103 | if max == 255: 104 | image = torch.clamp(image * 255, 0, 255).round() 105 | image = image.permute(0, 3, 1, 2) 106 | return image 107 | 108 | def get_current_visuals(self): 109 | visual_ret = OrderedDict() 110 | if self.isTrain: 111 | for name in self.visual_names: 112 | if isinstance(getattr(self, name), list): 113 | visual_ret[name] = self.post_process(getattr(self, name)[-1][0:1].detach()) 114 | elif isinstance(getattr(self, name), torch.Tensor): 115 | visual_ret[name] = self.post_process(getattr(self, name)[0:1].detach()) 116 | else: 117 | raise Exception 118 | # visual_ret[name] = getattr(self, name)[0:1].detach() 119 | else: 120 | for name in self.visual_names: 121 | visual_ret[name] = getattr(self, name).clamp_min_(0).detach() 122 | return visual_ret 123 | 124 | def get_current_losses(self): 125 | errors_ret = OrderedDict() 126 | for name in self.loss_names: 127 | errors_ret[name] = float(getattr(self, 'loss_' + name)) 128 | return errors_ret 129 | 130 | def save_networks(self, epoch): 131 | for name in self.model_names: 132 | save_filename = '%s_model_%d.pth' % (name, epoch) 133 | save_path = os.path.join(self.save_dir, save_filename) 134 | net = getattr(self, 'net' + name) 135 | if self.device.type == 'cuda': 136 | state = {'state_dict': net.module.cpu().state_dict()} 137 | torch_save(state, save_path) 138 | net.to(self.device) 139 | else: 140 | state = {'state_dict': net.state_dict()} 141 | torch_save(state, save_path) 142 | self.save_optimizers(epoch) 143 | 144 | def load_networks(self, epoch): 145 | for name in self.model_names: #[0:1]: 146 | load_filename = '%s_model_%d.pth' % (name, epoch) 147 | if self.opt.load_path != '': 148 | load_path = self.opt.load_path 149 | else: 150 | load_path = os.path.join(self.save_dir, load_filename) 151 | net = getattr(self, 'net' + name) 152 | 153 | state_dict = torch.load(load_path, map_location=self.device) 154 | print('loading the model from %s' % (load_path)) 155 | if hasattr(state_dict, '_metadata'): 156 | del state_dict._metadata 157 | 158 | net_state = net.state_dict() 159 | is_loaded = {n:False for n in net_state.keys()} 160 | for name, param in state_dict['state_dict'].items(): 161 | if name in net_state: 162 | try: 163 | net_state[name].copy_(param) 164 | is_loaded[name] = True 165 | except Exception: 166 | print('While copying the parameter named [%s], ' 167 | 'whose dimensions in the model are %s and ' 168 | 'whose dimensions in the checkpoint are %s.' 
169 | % (name, list(net_state[name].shape), 170 | list(param.shape))) 171 | raise RuntimeError 172 | else: 173 | print('Saved parameter named [%s] is skipped' % name) 174 | mark = True 175 | for name in is_loaded: 176 | if not is_loaded[name]: 177 | print('Parameter named [%s] is randomly initialized' % name) 178 | mark = False 179 | if mark: 180 | print('All parameters are initialized using [%s]' % load_path) 181 | 182 | self.start_epoch = epoch 183 | 184 | def load_network_path(self, net, path): 185 | if isinstance(net, torch.nn.DataParallel): 186 | net = net.module 187 | state_dict = torch.load(path, map_location=self.device) 188 | print('loading the model from %s' % (path)) 189 | if hasattr(state_dict, '_metadata'): 190 | del state_dict._metadata 191 | 192 | net_state = net.state_dict() 193 | is_loaded = {n:False for n in net_state.keys()} 194 | for name, param in state_dict['state_dict'].items(): 195 | if name in net_state: 196 | try: 197 | net_state[name].copy_(param) 198 | is_loaded[name] = True 199 | except Exception: 200 | print('While copying the parameter named [%s], ' 201 | 'whose dimensions in the model are %s and ' 202 | 'whose dimensions in the checkpoint are %s.' 203 | % (name, list(net_state[name].shape), 204 | list(param.shape))) 205 | raise RuntimeError 206 | else: 207 | print('Saved parameter named [%s] is skipped' % name) 208 | mark = True 209 | for name in is_loaded: 210 | if not is_loaded[name]: 211 | print('Parameter named [%s] is randomly initialized' % name) 212 | mark = False 213 | if mark: 214 | print('All parameters are initialized using [%s]' % path) 215 | 216 | def save_optimizers(self, epoch): 217 | assert len(self.optimizers) == len(self.optimizer_names) 218 | for id, optimizer in enumerate(self.optimizers): 219 | save_filename = self.optimizer_names[id] 220 | state = {'name': save_filename, 221 | 'epoch': epoch, 222 | 'state_dict': optimizer.state_dict()} 223 | save_path = os.path.join(self.save_dir, save_filename+'.pth') 224 | torch_save(state, save_path) 225 | 226 | def load_optimizers(self, epoch): 227 | assert len(self.optimizers) == len(self.optimizer_names) 228 | for id, optimizer in enumerate(self.optimizer_names): 229 | load_filename = self.optimizer_names[id] 230 | load_path = os.path.join(self.save_dir, load_filename+'.pth') 231 | print('loading the optimizer from %s' % load_path) 232 | state_dict = torch.load(load_path) 233 | assert optimizer == state_dict['name'] 234 | assert epoch == state_dict['epoch'] 235 | self.optimizers[id].load_state_dict(state_dict['state_dict']) 236 | 237 | def print_networks(self, verbose): 238 | print('---------- Networks initialized -------------') 239 | 240 | print('-----------------------------------------------') 241 | 242 | def estimate(self, tenFirst, tenSecond, net): 243 | assert(tenFirst.shape[3] == tenSecond.shape[3]) 244 | assert(tenFirst.shape[2] == tenSecond.shape[2]) 245 | intWidth = tenFirst.shape[3] 246 | intHeight = tenFirst.shape[2] 247 | # tenPreprocessedFirst = tenFirst.view(1, 3, intHeight, intWidth) 248 | # tenPreprocessedSecond = tenSecond.view(1, 3, intHeight, intWidth) 249 | 250 | intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0)) 251 | intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0)) 252 | 253 | tenPreprocessedFirst = F.interpolate(input=tenFirst, 254 | size=(intPreprocessedHeight, intPreprocessedWidth), 255 | mode='bilinear', align_corners=False) 256 | tenPreprocessedSecond = F.interpolate(input=tenSecond, 257 | size=(intPreprocessedHeight, 
intPreprocessedWidth),
258 |                                               mode='bilinear', align_corners=False)
259 | 
260 |         tenFlow = 20.0 * F.interpolate(
261 |             input=net(tenPreprocessedFirst, tenPreprocessedSecond),
262 |             size=(intHeight, intWidth), mode='bilinear', align_corners=False)
263 | 
264 |         tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth)
265 |         tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight)
266 | 
267 |         return tenFlow[:, :, :, :]
268 | 
269 |     def backwarp(self, tenInput, tenFlow):
270 |         index = str(tenFlow.shape) + str(tenInput.device)
271 |         if index not in self.backwarp_tenGrid:
272 |             tenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]),
273 |                                     tenFlow.shape[3]).view(1, 1, 1, -1).expand(-1, -1, tenFlow.shape[2], -1)
274 |             tenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]),
275 |                                     tenFlow.shape[2]).view(1, 1, -1, 1).expand(-1, -1, -1, tenFlow.shape[3])
276 |             self.backwarp_tenGrid[index] = torch.cat([tenHor, tenVer], 1).to(tenInput.device)
277 | 
278 |         if index not in self.backwarp_tenPartial:
279 |             self.backwarp_tenPartial[index] = tenFlow.new_ones([
280 |                 tenFlow.shape[0], 1, tenFlow.shape[2], tenFlow.shape[3]])
281 | 
282 |         tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
283 |                              tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
284 |         tenInput = torch.cat([tenInput, self.backwarp_tenPartial[index]], 1)
285 | 
286 |         tenOutput = F.grid_sample(input=tenInput,
287 |                                   grid=(self.backwarp_tenGrid[index] + tenFlow).permute(0, 2, 3, 1),
288 |                                   mode='bilinear', padding_mode='zeros', align_corners=False)
289 | 
290 |         return tenOutput
291 | 
292 |     def get_backwarp(self, tenFirst, tenSecond, net, flow=None):
293 |         if flow is None:
294 |             flow = self.get_flow(tenFirst, tenSecond, net)
295 | 
296 |         tenoutput = self.backwarp(tenSecond, flow)
297 |         tenMask = tenoutput[:, -1:, :, :]
298 |         tenMask[tenMask > 0.999] = 1.0
299 |         tenMask[tenMask < 1.0] = 0.0
300 |         return tenoutput[:, :-1, :, :] * tenMask, tenMask
301 | 
302 |     def get_flow(self, tenFirst, tenSecond, net):
303 |         with torch.no_grad():
304 |             net.eval()
305 |             flow = self.estimate(tenFirst, tenSecond, net)
306 |         return flow
--------------------------------------------------------------------------------
/models/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | class invertedBlock(nn.Module):
5 |     def __init__(self, in_channel, out_channel, ratio=2):
6 |         super(invertedBlock, self).__init__()
7 |         internal_channel = in_channel * ratio
8 |         self.relu = nn.GELU()
9 |         ## 7x7 convolution, in parallel with a 3x3 convolution
10 |         self.conv1 = nn.Conv2d(internal_channel, internal_channel, 7, 1, 3, groups=in_channel, bias=False)
11 | 
12 |         self.convFFN = ConvFFN(in_channels=in_channel, out_channels=in_channel)
13 |         self.layer_norm = nn.LayerNorm(in_channel)
14 |         self.pw1 = nn.Conv2d(in_channels=in_channel, out_channels=internal_channel, kernel_size=1, stride=1,
15 |                              padding=0, groups=1, bias=False)
16 |         self.pw2 = nn.Conv2d(in_channels=internal_channel, out_channels=in_channel, kernel_size=1, stride=1,
17 |                              padding=0, groups=1, bias=False)
18 | 
19 | 
20 |     def hifi(self, x):
21 | 
22 |         x1 = self.pw1(x)
23 |         x1 = self.relu(x1)
24 |         x1 = self.conv1(x1)
25 |         x1 = self.relu(x1)
26 |         x1 = self.pw2(x1)
27 |         x1 = self.relu(x1)
28 |         # x2 = self.conv2(x)
29 |         x3 = x1 + x
30 | 
31 |         x3 = x3.permute(0, 2, 3, 1).contiguous()
32 |         x3 = self.layer_norm(x3)
33 |         x3 = x3.permute(0, 3, 1, 2).contiguous()
34 | 
x4 = self.convFFN(x3) 35 | 36 | return x4 37 | 38 | def forward(self, x): 39 | return self.hifi(x)+x 40 | class ConvFFN(nn.Module): 41 | 42 | def __init__(self, in_channels, out_channels, expend_ratio=4): 43 | super().__init__() 44 | 45 | internal_channels = in_channels * expend_ratio 46 | self.pw1 = nn.Conv2d(in_channels=in_channels, out_channels=internal_channels, kernel_size=1, stride=1, 47 | padding=0, groups=1,bias=False) 48 | self.pw2 = nn.Conv2d(in_channels=internal_channels, out_channels=out_channels, kernel_size=1, stride=1, 49 | padding=0, groups=1,bias=False) 50 | self.nonlinear = nn.GELU() 51 | 52 | def forward(self, x): 53 | x1 = self.pw1(x) 54 | x2 = self.nonlinear(x1) 55 | x3 = self.pw2(x2) 56 | x4 = self.nonlinear(x3) 57 | return x4 + x 58 | 59 | class mixblock(nn.Module): 60 | def __init__(self, n_feats): 61 | super(mixblock, self).__init__() 62 | self.conv1=nn.Sequential(nn.Conv2d(n_feats,n_feats,3,1,1,bias=False),nn.GELU()) 63 | self.conv2=nn.Sequential(nn.Conv2d(n_feats,n_feats,3,1,1,bias=False),nn.GELU(),nn.Conv2d(n_feats,n_feats,3,1,1,bias=False),nn.GELU(),nn.Conv2d(n_feats,n_feats,3,1,1,bias=False),nn.GELU()) 64 | self.alpha=nn.Parameter(torch.ones(1)) 65 | self.beta=nn.Parameter(torch.ones(1)) 66 | def forward(self,x): 67 | return self.alpha*self.conv1(x)+self.beta*self.conv2(x) 68 | class CALayer(nn.Module): 69 | def __init__(self, channel, reduction=16): 70 | super(CALayer, self).__init__() 71 | # global average pooling: feature --> point 72 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 73 | # feature channel downscale and upscale --> channel weight 74 | self.conv_du = nn.Sequential( 75 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 76 | nn.ReLU(inplace=True), 77 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 78 | nn.Sigmoid() 79 | ) 80 | 81 | def forward(self, x): 82 | y = self.avg_pool(x) 83 | y = self.conv_du(y) 84 | return x * y 85 | class Downupblock(nn.Module): 86 | def __init__(self, n_feats): 87 | super(Downupblock, self).__init__() 88 | self.encoder = mixblock(n_feats) 89 | self.decoder_high = mixblock(n_feats) # nn.Sequential(one_module(n_feats), 90 | 91 | self.decoder_low = nn.Sequential(mixblock(n_feats), mixblock(n_feats), mixblock(n_feats)) 92 | self.alise = nn.Conv2d(n_feats,n_feats,1,1,0,bias=False) # one_module(n_feats) 93 | self.alise2 = nn.Conv2d(n_feats*2,n_feats,3,1,1,bias=False) # one_module(n_feats) 94 | self.down = nn.AvgPool2d(kernel_size=2) 95 | self.att = CALayer(n_feats) 96 | self.raw_alpha=nn.Parameter(torch.ones(1)) 97 | 98 | self.raw_alpha.data.fill_(0) 99 | self.ega=selfAttention(n_feats, n_feats) 100 | 101 | def forward(self, x,raw): 102 | x1 = self.encoder(x) 103 | x2 = self.down(x1) 104 | high = x1 - F.interpolate(x2, size=x.size()[-2:], mode='bilinear', align_corners=True) 105 | 106 | high=high+self.ega(high,high)*self.raw_alpha 107 | x2=self.decoder_low(x2) 108 | x3 = x2 109 | # x3 = self.decoder_low(x2) 110 | high1 = self.decoder_high(high) 111 | x4 = F.interpolate(x3, size=x.size()[-2:], mode='bilinear', align_corners=True) 112 | return self.alise(self.att(self.alise2(torch.cat([x4, high1], dim=1)))) + x 113 | class Updownblock(nn.Module): 114 | def __init__(self, n_feats): 115 | super(Updownblock, self).__init__() 116 | self.encoder = mixblock(n_feats) 117 | self.decoder_high = mixblock(n_feats) # nn.Sequential(one_module(n_feats), 118 | # one_module(n_feats), 119 | # one_module(n_feats)) 120 | self.decoder_low = nn.Sequential(mixblock(n_feats), mixblock(n_feats), mixblock(n_feats)) 
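        # Both Updownblock and Downupblock split features into frequency bands,
        # Laplacian-pyramid style: downsample, then take the residual against the
        # bilinearly upsampled low-resolution copy as the high-frequency band, e.g.
        #
        #   low  = F.avg_pool2d(x, 2)                                    # low band
        #   high = x - F.interpolate(low, size=x.shape[-2:],
        #                            mode='bilinear', align_corners=True)  # high band
        #
        # The two bands are processed separately and fused back with a 3x3 conv
        # plus channel attention (see forward below).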
121 |
122 | self.alise = nn.Conv2d(n_feats,n_feats,1,1,0,bias=False) # one_module(n_feats)
123 | self.alise2 = nn.Conv2d(n_feats*2,n_feats,3,1,1,bias=False) # one_module(n_feats)
124 | self.down = nn.AvgPool2d(kernel_size=2)
125 | self.att = CALayer(n_feats)
126 | self.raw_alpha=nn.Parameter(torch.ones(1))
127 | # fill 0
128 | self.raw_alpha.data.fill_(0)
129 | self.ega=selfAttention(n_feats, n_feats)
130 |
131 | def forward(self, x,raw):
132 | x1 = self.encoder(x)
133 | x2 = self.down(x1)
134 | high = x1 - F.interpolate(x2, size=x.size()[-2:], mode='bilinear', align_corners=True)
135 | high=high+self.ega(high,high)*self.raw_alpha
136 | x2=self.decoder_low(x2)
137 | x3 = x2
138 | high1 = self.decoder_high(high)
139 | x4 = F.interpolate(x3, size=x.size()[-2:], mode='bilinear', align_corners=True)
140 | return self.alise(self.att(self.alise2(torch.cat([x4, high1], dim=1)))) + x
141 | class basic_block(nn.Module):
142 | ## Two parallel branches: a channel branch and a spatial branch
143 | def __init__(self, in_channel, out_channel, depth,ratio=1):
144 | super(basic_block, self).__init__()
145 |
146 |
147 |
148 |
149 | # depth stacked inverted blocks
150 |
151 | self.rep1 = nn.Sequential(*[invertedBlock(in_channel=in_channel, out_channel=in_channel,ratio=ratio) for i in range(depth)])
152 |
153 |
154 | self.relu=nn.GELU()
155 | # one part applies three 3x3 convolutions, the other applies a single one
156 |
157 | self.updown=Updownblock(in_channel)
158 | self.downup=Downupblock(in_channel)
159 | def forward(self, x,raw=None):
160 |
161 |
162 | x1 = self.rep1(x)
163 |
164 |
165 | x1=self.updown(x1,raw)
166 | x1=self.downup(x1,raw)
167 | return x1+x
168 |
169 | import torchvision
170 | class VGG_aware(nn.Module):
171 | def __init__(self,outFeature):
172 | super(VGG_aware, self).__init__()
173 | blocks = []
174 | blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
175 |
176 | for bl in blocks:
177 | for p in bl.parameters(): # iterate parameters, not submodules, so the weights are actually frozen
178 | p.requires_grad = False
179 | self.blocks = torch.nn.ModuleList(blocks)
180 |
181 |
182 | def forward(self, x):
183 | return self.blocks[0](x)
184 |
185 | import torch.nn.functional as f
186 | class selfAttention(nn.Module):
187 | def __init__(self, in_channels, out_channels):
188 | super(selfAttention, self).__init__()
189 | self.query_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
190 | self.key_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
191 | self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
192 | self.scale = 1.0 / (out_channels ** 0.5)
193 |
194 | def forward(self, feature, feature_map):
195 | query = self.query_conv(feature)
196 | key = self.key_conv(feature)
197 | value = self.value_conv(feature)
198 | attention_scores = torch.matmul(query, key.transpose(-2, -1))
199 | attention_scores = attention_scores * self.scale
200 |
201 | attention_weights = f.softmax(attention_scores, dim=-1)
202 |
203 | attended_values = torch.matmul(attention_weights, value)
204 |
205 | output_feature_map = (feature_map + attended_values)
206 |
207 | return output_feature_map
-------------------------------------------------------------------------------- /models/cat_model.py: --------------------------------------------------------------------------------
1 | import torch
2 | from .base_model import BaseModel
3 | from . import networks as N
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from . import losses as L
7 | import torch.nn.functional as F
8 | import torchvision.ops as ops
9 | from util.util import mu_tonemap
10 |
11 |
12 | # For BracketIRE Task
13 | class CatModel(BaseModel):
14 | @staticmethod
15 | def modify_commandline_options(parser, is_train=True):
16 | return parser
17 |
18 | def __init__(self, opt):
19 | super(CatModel, self).__init__(opt)
20 |
21 | self.opt = opt
22 | self.loss_names = ['TMRNet_l1', 'Total']
23 | self.visual_names = ['data_gt', 'data_in', 'data_out']
24 | self.model_names = ['TMRNet']
25 | self.optimizer_names = ['TMRNet_optimizer_%s' % opt.optimizer]
26 |
27 | tmrnet = TMRNet(opt)
28 | # tmrnet = AHDR(8, 6, 64, 32)
29 | self.netTMRNet = N.init_net(tmrnet, opt.init_type, opt.init_gain, opt.gpu_ids)
30 |
31 | if self.isTrain:
32 | self.optimizer_TMRNet = optim.AdamW(self.netTMRNet.parameters(),
33 | lr=opt.lr,
34 | betas=(opt.beta1, opt.beta2),
35 | weight_decay=opt.weight_decay)
36 | self.optimizers = [self.optimizer_TMRNet]
37 |
38 | self.criterionL1 = N.init_net(L.L1Loss(), gpu_ids=opt.gpu_ids)
39 |
40 | def set_input(self, input):
41 | self.data_gt = input['gt'].to(self.device)
42 | self.data_raws = input['raws'].to(self.device)
43 | self.image_paths = input['fname']
44 |
45 | expo = torch.stack([torch.pow(torch.tensor(4, dtype=torch.float32, device=self.data_raws.device), 2-x)
46 | for x in range(0, self.data_raws.shape[1])])
47 | self.expo = expo[None,:,None,None,None]
48 |
49 | def forward(self):
50 |
51 |
52 | self.data_raws = self.data_raws * self.expo
53 | self.data_in = self.data_raws[:,0,...].squeeze(1)
54 |
55 | if self.isTrain or (not self.isTrain and not self.opt.chop):
56 | # data_raws: [N, T, C, H, W]; unbind along the second dimension into T tensors, each of shape [N, C, H, W]
57 | # x1 = self.data_raws[:, 0, :, :, :]
58 | # x2 = self.data_raws[:, 1, :, :, :]
59 | # x3 = self.data_raws[:, 2, :, :, :]
60 | # x4= self.data_raws[:, 3, :, :, :]
61 | # x5 = self.data_raws[:, 4, :, :, :]
62 | # x1=torch.cat((x1,ldr_to_hdr(x1,1,1/2.2)),1)
63 | # x2=torch.cat((x2,ldr_to_hdr(x2,1,1/2.2)),1)
64 | # x3=torch.cat((x3,ldr_to_hdr(x3,1,1/2.2)),1)
65 | # x4=torch.cat((x4,ldr_to_hdr(x4,1,1/2.2)),1)
66 | # x5=torch.cat((x5,ldr_to_hdr(x5,1,1/2.2)),1)
67 |
68 | # self.data_out = self.netTMRNet(x1, x2, x3, x4, x5)
69 | self.data_out = self.netTMRNet(self.data_raws)
70 | elif self.opt.chop:
71 | self.data_out = self.forward_chop(self.data_raws)
72 |
73 | def forward_chop(self, data_raws, chop_size=800):
74 | n, t, c, h, w = data_raws.shape
75 |
76 | num_h = h // chop_size + 1
77 | num_w = w // chop_size + 1
78 | new_h = num_h * chop_size
79 | new_w = num_w * chop_size
80 |
81 | pad_h = new_h - h
82 | pad_w = new_w - w
83 |
84 | pad_top = int(pad_h / 2.)
85 | pad_bottom = pad_h - pad_top
86 | pad_left = int(pad_w / 2.)
87 | pad_right = pad_w - pad_left 88 | 89 | paddings = (pad_left, pad_right, pad_top, pad_bottom) 90 | new_input0 = torch.nn.ReflectionPad2d(paddings)(data_raws[0]) 91 | 92 | out = torch.zeros([1, c, new_h, new_w], dtype=torch.float32, device=data_raws.device) 93 | for i in range(num_h): 94 | for j in range(num_w): 95 | out[:, :, i*chop_size:(i+1)*chop_size, j*chop_size:(j+1)*chop_size] = self.netTMRNet( 96 | new_input0.unsqueeze(0)[:,:,:,i*chop_size:(i+1)*chop_size, j*chop_size:(j+1)*chop_size]) 97 | return out[:, :, pad_top:pad_top+h, pad_left:pad_left+w] 98 | 99 | def backward(self, epoch): 100 | self.loss_TMRNet_l1 = self.criterionL1( 101 | mu_tonemap(torch.clamp(self.data_out / 4**2, min=0)), 102 | mu_tonemap(torch.clamp(self.data_gt / 4**2, 0, 1))).mean() 103 | 104 | self.loss_Total = self.loss_TMRNet_l1 105 | self.loss_Total.backward() 106 | 107 | def optimize_parameters(self, epoch): 108 | self.forward() 109 | self.optimizer_TMRNet.zero_grad() 110 | self.backward(epoch) 111 | self.optimizer_TMRNet.step() 112 | 113 | 114 | class TMRNet(nn.Module): 115 | def __init__(self, opt, mid_channels=64, max_residue_magnitude=10): 116 | 117 | super().__init__() 118 | self.mid_channels = mid_channels 119 | 120 | # optical flow 121 | self.spynet = SPyNet() 122 | if opt.isTrain: 123 | N.load_spynet(self.spynet, './spynet/spynet_20210409-c6c1bd09.pth') 124 | 125 | self.dcn_alignment = DeformableAlignment(mid_channels, mid_channels, 3, padding=1, deform_groups=8, 126 | max_residue_magnitude=max_residue_magnitude) 127 | 128 | # feature extraction module 129 | self.feat_extract = ResidualBlocksWithInputConv(2 * 4, mid_channels, 5) 130 | 131 | # propagation branches 132 | 133 | self.backbone = nn.ModuleDict() 134 | if opt.block == 'tmrnet': 135 | self.backbone['backward'] = ResidualBlocksWithInputConv(5 * mid_channels, mid_channels, 16) 136 | 137 | for i in range(0, 5): 138 | self.backbone['backward_rec_%d' % (i + 1)] = ResidualBlocksWithInputConv(1 * mid_channels, mid_channels, 139 | 24) 140 | 141 | 142 | if opt.block == 'Convnext': 143 | self.backbone['backward'] = ConvnextBlocksWithInputConv(5 * mid_channels, mid_channels, 1) 144 | 145 | for i in range(0, 2): 146 | self.backbone['backward_rec_%d' % (i + 1)] = ConvnextBlocksWithInputConv(1 * mid_channels, mid_channels, 147 | 1) 148 | self.reconstruction = ResidualBlocksWithInputConv(2 * mid_channels, mid_channels, 5) 149 | self.down = ResidualBlocksWithInputConv(3 * mid_channels, mid_channels, 3) 150 | self.skipup1 = PixelShufflePack(4, mid_channels, 1, upsample_kernel=3) 151 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 152 | 153 | self.conv_hr = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1) 154 | self.conv_last = nn.Conv2d(mid_channels, 4, 3, 1, 1) 155 | 156 | def compute_flow(self, lqs): 157 | lqs = torch.stack((lqs[:, :, 0], lqs[:, :, 1:3].mean(dim=2), lqs[:, :, 3]), dim=2) 158 | lqs = torch.pow(torch.clamp(lqs, 0, 1), 1 / 2.2) 159 | n, t, c, h, w = lqs.size() 160 | oth = lqs[:, 1:, :, :, :].reshape(-1, c, h, w) 161 | ref = lqs[:, :1, :, :, :].repeat(1, t - 1, 1, 1, 1).reshape(-1, c, h, w) 162 | flows_backward = self.spynet(ref, oth).view(n, t - 1, 2, h, w) 163 | flows_forward = flows_backward 164 | return flows_forward, flows_backward 165 | 166 | def upsample(self, lqs, feats, base_feat): 167 | skip1 = self.skipup1(lqs[:, 0, :, :, :]) 168 | hr = self.down(feats) 169 | hr = torch.cat([base_feat, hr], dim=1) 170 | hr = self.reconstruction(hr) 171 | hr = hr + skip1 172 | hr = self.lrelu(self.conv_hr(hr)) 173 | hr = 
self.conv_last(hr)
174 | return hr
175 |
176 | def forward(self, lqs):
177 | n, t, c, h, w = lqs.size() # (n, t, c, h, w)
178 | lqs_downsample = lqs.clone()
179 |
180 | feats = {}
181 | lqs_view = lqs.view(-1, c, h, w)
182 | lqs_in = torch.zeros([n * t, 2 * c, h, w], dtype=lqs_view.dtype, device=lqs_view.device)
183 | lqs_in[:, 0::2, :, :] = lqs_view
184 | # interleave a gamma-corrected copy of the input; torch.clamp(min=0) keeps the fractional power well-defined
185 | lqs_in[:, 1::2, :, :] = torch.pow(torch.clamp(lqs_view, min=0), 1 / 2.2)
186 |
187 | feats_ = self.feat_extract(lqs_in) # (N*T, C, H, W)
188 | h, w = feats_.shape[2:]
189 | feats_ = feats_.view(n, t, -1, h, w)
190 |
191 | _, flows_backward = self.compute_flow(lqs_downsample)
192 | flows_backward = flows_backward.view(-1, 2, *feats_.shape[-2:])
193 |
194 | ref_feat = feats_[:, :1, :, :, :].repeat(1, t - 1, 1, 1, 1).view(-1, *feats_.shape[-3:])
195 | oth_feat = feats_[:, 1:, :, :, :].contiguous().view(-1, *feats_.shape[-3:])
196 |
197 | oth_feat_warped = N.flow_warp(oth_feat, flows_backward.permute(0, 2, 3, 1))
198 | oth_feat = self.dcn_alignment(oth_feat, ref_feat, oth_feat_warped, flows_backward)
199 | oth_feat = oth_feat.view(n, t - 1, -1, h, w)
200 | ref_feat = ref_feat.view(n, t - 1, -1, h, w)[:, :1, :, :, :]
201 |
202 | feats_ = torch.cat((ref_feat, oth_feat), dim=1) # (N, T, C, H, W)
203 |
204 | base_feat = feats_[:, 0, :, :, :]
205 | feats_ = torch.cat((feats_[:, 0, :, :, :], feats_[:, 1, :, :, :], feats_[:, 2, :, :, :], feats_[:, 3, :, :, :],
206 | feats_[:, 4, :, :, :]), 1)
207 | # feature propagation
208 | module = 'backward'
209 | feats0 = self.backbone[module](feats_)
210 | feats1 = self.backbone['backward_rec_%d' % (0 + 1)](feats0)
211 | feats2 = self.backbone['backward_rec_%d' % (1 + 1)](feats1)
212 |
213 | feats0 = torch.cat((feats0, feats1, feats2), 1)
214 |
215 | out = self.upsample(lqs, feats0, base_feat)
216 |
217 | return out
218 |
219 |
220 |
221 |
222 | from .blocks import *
223 |
224 |
225 | class ConvnextBlocksWithInputConv(nn.Module):
226 | def __init__(self, in_channels, out_channels=64, num_blocks=3):
227 | super().__init__()
228 |
229 | main = []
230 | # a convolution used to match the channels of the residual blocks
231 | main.append(nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=True))
232 | main.append(nn.LeakyReLU(negative_slope=0.1, inplace=True))
233 |
234 | # residual blocks
235 | for i in range(num_blocks):
236 | main.append(
237 | basic_block(out_channels, out_channels, depth=10, ratio=1))
238 |
239 | self.main = nn.Sequential(*main)
240 |
241 | def forward(self, feat):
242 | return self.main(feat)
243 |
244 |
245 | class ResidualBlocksWithInputConv(nn.Module):
246 | def __init__(self, in_channels, out_channels=64, num_blocks=30):
247 | super().__init__()
248 |
249 | main = []
250 | # a convolution used to match the channels of the residual blocks
251 | main.append(nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=True))
252 | main.append(nn.LeakyReLU(negative_slope=0.1, inplace=True))
253 |
254 | # residual blocks
255 | main.append(
256 | N.make_layer(
257 | ResidualBlockNoBN, num_blocks, mid_channels=out_channels))
258 |
259 | self.main = nn.Sequential(*main)
260 |
261 | def forward(self, feat):
262 | return self.main(feat)
263 |
264 |
265 | class ResidualBlockNoBN(nn.Module):
266 | def __init__(self, mid_channels=64, res_scale=1):
267 | super().__init__()
268 | self.res_scale = res_scale
269 | self.conv1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=True)
270 | self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=True)
271 |
272 | self.relu =
nn.ReLU(inplace=True) 273 | 274 | self.init_weights() 275 | 276 | def init_weights(self): 277 | N.init_weights(self.conv1, init_type='kaiming') 278 | N.init_weights(self.conv2, init_type='kaiming') 279 | self.conv1.weight.data *= 0.1 280 | self.conv2.weight.data *= 0.1 281 | 282 | def forward(self, x): 283 | identity = x 284 | out = self.conv2(self.relu(self.conv1(x))) 285 | return identity + out * self.res_scale 286 | 287 | 288 | class PixelShufflePack(nn.Module): 289 | def __init__(self, in_channels, out_channels, scale_factor, 290 | upsample_kernel): 291 | super().__init__() 292 | self.in_channels = in_channels 293 | self.out_channels = out_channels 294 | self.scale_factor = scale_factor 295 | self.upsample_kernel = upsample_kernel 296 | self.upsample_conv = nn.Conv2d( 297 | self.in_channels, 298 | self.out_channels * scale_factor * scale_factor, 299 | self.upsample_kernel, 300 | padding=(self.upsample_kernel - 1) // 2) 301 | self.init_weights() 302 | 303 | def init_weights(self): 304 | """Initialize weights for PixelShufflePack.""" 305 | N.init_weights(self.upsample_conv, init_type='kaiming') 306 | 307 | def forward(self, x): 308 | x = self.upsample_conv(x) 309 | if self.scale_factor > 1: 310 | x = F.pixel_shuffle(x, self.scale_factor) 311 | return x 312 | 313 | 314 | class DeformableAlignment(nn.Module): 315 | def __init__(self, in_channels, out_channels, kernel=3, padding=1, deform_groups=8, max_residue_magnitude=10): 316 | super().__init__() 317 | 318 | self.max_residue_magnitude = max_residue_magnitude 319 | 320 | self.conv_offset = nn.Sequential( 321 | nn.Conv2d(2 * out_channels + 2, out_channels, 3, 1, 1), 322 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 323 | nn.Conv2d(out_channels, out_channels, 3, 1, 1), 324 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 325 | nn.Conv2d(out_channels, out_channels, 3, 1, 1), 326 | nn.LeakyReLU(negative_slope=0.1, inplace=True), 327 | nn.Conv2d(out_channels, 27 * deform_groups, 3, 1, 1), 328 | ) 329 | 330 | self.deform_conv = ops.DeformConv2d(in_channels, out_channels, kernel_size=kernel, stride=1, 331 | padding=padding, dilation=1, bias=True, groups=deform_groups) 332 | 333 | self.init_offset() 334 | 335 | def init_offset(self): 336 | N.init_weights(self.conv_offset[-1], init_type='constant') 337 | 338 | def forward(self, cur_feat, ref_feat, warped_feat, flow): 339 | extra_feat = torch.cat([warped_feat, ref_feat, flow], dim=1) 340 | out = self.conv_offset(extra_feat) 341 | o1, o2, mask = torch.chunk(out, 3, dim=1) 342 | offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1)) 343 | offset = offset + flow.flip(1).repeat(1, offset.size(1) // 2, 1, 1) 344 | mask = torch.sigmoid(mask) 345 | return self.deform_conv(cur_feat, offset, mask=mask) 346 | 347 | 348 | class SPyNet(nn.Module): 349 | """SPyNet network structure. 350 | 351 | The difference to the SPyNet in [tof.py] is that 352 | 1. more SPyNetBasicModule is used in this version, and 353 | 2. no batch normalization is used in this version. 354 | 355 | Paper: 356 | Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017 357 | 358 | Args: 359 | pretrained (str): path for pre-trained SPyNet. Default: None. 
360 | """ 361 | 362 | def __init__(self): 363 | super().__init__() 364 | 365 | self.basic_module = nn.ModuleList( 366 | [SPyNetBasicModule() for _ in range(6)]) 367 | 368 | self.register_buffer( 369 | 'mean', 370 | torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 371 | self.register_buffer( 372 | 'std', 373 | torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 374 | 375 | def compute_flow(self, ref, supp): 376 | """Compute flow from ref to supp. 377 | 378 | Note that in this function, the images are already resized to a 379 | multiple of 32. 380 | 381 | Args: 382 | ref (Tensor): Reference image with shape of (n, 3, h, w). 383 | supp (Tensor): Supporting image with shape of (n, 3, h, w). 384 | 385 | Returns: 386 | Tensor: Estimated optical flow: (n, 2, h, w). 387 | """ 388 | n, _, h, w = ref.size() 389 | 390 | # normalize the input images 391 | ref = [(ref - self.mean) / self.std] 392 | supp = [(supp - self.mean) / self.std] 393 | 394 | # generate downsampled frames 395 | for level in range(5): 396 | ref.append( 397 | F.avg_pool2d( 398 | input=ref[-1], 399 | kernel_size=2, 400 | stride=2, 401 | count_include_pad=False)) 402 | supp.append( 403 | F.avg_pool2d( 404 | input=supp[-1], 405 | kernel_size=2, 406 | stride=2, 407 | count_include_pad=False)) 408 | ref = ref[::-1] 409 | supp = supp[::-1] 410 | 411 | # flow computation 412 | flow = ref[0].new_zeros(n, 2, h // 32, w // 32) 413 | for level in range(len(ref)): 414 | if level == 0: 415 | flow_up = flow 416 | else: 417 | flow_up = F.interpolate( 418 | input=flow, 419 | scale_factor=2, 420 | mode='bilinear', 421 | align_corners=True) * 2.0 422 | 423 | # add the residue to the upsampled flow 424 | flow = flow_up + self.basic_module[level]( 425 | torch.cat([ 426 | ref[level], 427 | N.flow_warp( 428 | supp[level], 429 | flow_up.permute(0, 2, 3, 1), 430 | padding_mode='border'), flow_up 431 | ], 1)) 432 | 433 | return flow 434 | 435 | def forward(self, ref, supp): 436 | """Forward function of SPyNet. 437 | 438 | This function computes the optical flow from ref to supp. 439 | 440 | Args: 441 | ref (Tensor): Reference image with shape of (n, 3, h, w). 442 | supp (Tensor): Supporting image with shape of (n, 3, h, w). 443 | 444 | Returns: 445 | Tensor: Estimated optical flow: (n, 2, h, w). 446 | """ 447 | 448 | # upsize to a multiple of 32 449 | h, w = ref.shape[2:4] 450 | w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1) 451 | h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1) 452 | ref = F.interpolate( 453 | input=ref, size=(h_up, w_up), mode='bilinear', align_corners=False) 454 | supp = F.interpolate( 455 | input=supp, 456 | size=(h_up, w_up), 457 | mode='bilinear', 458 | align_corners=False) 459 | 460 | # compute flow, and resize back to the original resolution 461 | flow = F.interpolate( 462 | input=self.compute_flow(ref, supp), 463 | size=(h, w), 464 | mode='bilinear', 465 | align_corners=False) 466 | 467 | # adjust the flow values 468 | flow[:, 0, :, :] *= float(w) / float(w_up) 469 | flow[:, 1, :, :] *= float(h) / float(h_up) 470 | 471 | return flow 472 | 473 | 474 | class SPyNetBasicModule(nn.Module): 475 | """Basic Module for SPyNet. 
476 | 477 | Paper: 478 | Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017 479 | """ 480 | 481 | def __init__(self): 482 | super().__init__() 483 | 484 | self.basic_module = nn.Sequential( 485 | nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), 486 | nn.ReLU(inplace=True), 487 | 488 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), 489 | nn.ReLU(inplace=True), 490 | 491 | nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), 492 | nn.ReLU(inplace=True), 493 | 494 | nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), 495 | nn.ReLU(inplace=True), 496 | 497 | nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3) 498 | ) 499 | 500 | def forward(self, tensor_input): 501 | """ 502 | Args: 503 | tensor_input (Tensor): Input tensor with shape (b, 8, h, w). 504 | 8 channels contain: 505 | [reference image (3), neighbor image (3), initial flow (2)]. 506 | 507 | Returns: 508 | Tensor: Refined flow with shape (b, 2, h, w) 509 | """ 510 | return self.basic_module(tensor_input) 511 | 512 | -------------------------------------------------------------------------------- /models/degrade/degrade_kernel.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import random 4 | import torch 5 | from scipy import ndimage 6 | from scipy.interpolate import interp2d 7 | from .unprocess import unprocess, random_noise_levels, add_noise 8 | from .process import process 9 | from PIL import Image 10 | 11 | 12 | 13 | # def get_rgb2raw2rgb(img): 14 | # img = torch.from_numpy(np.array(img)) / 255.0 15 | # deg_img, features = unprocess(img) 16 | # shot_noise, read_noise = random_noise_levels() 17 | # deg_img = add_noise(deg_img, shot_noise, read_noise) 18 | # deg_img = deg_img.unsqueeze(0) 19 | # features['red_gain'] = features['red_gain'].unsqueeze(0) 20 | # features['blue_gain'] = features['blue_gain'].unsqueeze(0) 21 | # features['cam2rgb'] = features['cam2rgb'].unsqueeze(0) 22 | # deg_img = process(deg_img, features['red_gain'], features['blue_gain'], features['cam2rgb']) 23 | # deg_img = deg_img.squeeze(0) 24 | # deg_img = torch.clamp(deg_img * 255.0, 0.0, 255.0).numpy() 25 | # deg_img = deg_img.astype(np.uint8) 26 | # return Image.fromarray(deg_img) 27 | 28 | 29 | # def get_rgb2raw_noise(img, noise_level, features=None): 30 | # # img = np.transpose(img, (1, 2, 0)) 31 | # img = torch.from_numpy(np.array(img)) / 255.0 32 | 33 | # deg_img, features = unprocess(img, features) 34 | # shot_noise, read_noise = random_noise_levels(noise_level) 35 | # deg_img_noise = add_noise(deg_img, shot_noise, read_noise) 36 | # # deg_img_noise = torch.clamp(deg_img_noise, min=0.0, max=1.0) 37 | 38 | # # deg_img = np.transpose(deg_img, (2, 0, 1)) 39 | # # deg_img_noise = np.transpose(deg_img_noise, (2, 0, 1)) 40 | # return deg_img_noise, features 41 | 42 | 43 | def get_rgb2raw(img, features=None): 44 | # img = np.transpose(img, (1, 2, 0)) 45 | device = img.device 46 | deg_img, features = unprocess(img, features, device) 47 | return deg_img, features 48 | 49 | 50 | def get_raw2rgb(img, features, demosaic='net', lineRGB=False): 51 | # img = np.transpose(img, (1, 2, 0)) 52 | # img = torch.from_numpy(np.array(img)) 53 | img = img.unsqueeze(0) 54 | device = img.device 55 | deg_img = process(img, features['red_gain'].to(device), features['blue_gain'].to(device), 56 | features['cam2rgb'].to(device), demosaic, lineRGB) 57 
| deg_img = deg_img.squeeze(0) 58 | # deg_img = torch.clamp(deg_img * 255.0, 0.0, 255.0).numpy() 59 | # deg_img = deg_img.astype(np.uint8) 60 | return deg_img 61 | 62 | def get_raw2rgb2(img, features, demosaic='net', lineRGB=False): 63 | # img = np.transpose(img, (1, 2, 0)) 64 | # img = torch.from_numpy(np.array(img)) 65 | image=[] 66 | for i in range(img.shape[0]): 67 | image.append(img[i,:,:]) 68 | deg_images=[] 69 | # img = img.unsqueeze(0) 70 | device = img.device 71 | for i in range(img.shape[0]): 72 | im=image[i].unsqueeze(0) 73 | deg_img = process(im, features['red_gain'][i].unsqueeze(0).to(device), features['blue_gain'][i].unsqueeze(0).to(device), 74 | features['cam2rgb'][i].unsqueeze(0).to(device), demosaic, lineRGB) 75 | deg_img = deg_img.squeeze(0) 76 | deg_images.append(deg_img) 77 | result=torch.stack(deg_images) 78 | # deg_img = deg_img.squeeze(0) 79 | # deg_img = torch.clamp(deg_img * 255.0, 0.0, 255.0).numpy() 80 | # deg_img = deg_img.astype(np.uint8) 81 | return result 82 | # def pack_raw_image(im_raw): # HxW 83 | # """ Packs a single channel bayer image into 4 channel tensor, where channels contain R, G, G, and B values""" 84 | # if isinstance(im_raw, np.ndarray): 85 | # im_out = np.zeros_like(im_raw, shape=(4, im_raw.shape[0] // 2, im_raw.shape[1] // 2)) 86 | # elif isinstance(im_raw, torch.Tensor): 87 | # im_out = torch.zeros((4, im_raw.shape[0] // 2, im_raw.shape[1] // 2), dtype=im_raw.dtype) 88 | # else: 89 | # raise Exception 90 | 91 | # im_out[0, :, :] = im_raw[0::2, 0::2] 92 | # im_out[1, :, :] = im_raw[0::2, 1::2] 93 | # im_out[2, :, :] = im_raw[1::2, 0::2] 94 | # im_out[3, :, :] = im_raw[1::2, 1::2] 95 | # return im_out # 4xHxW 96 | 97 | 98 | # def flatten_raw_image(im_raw_4ch): # 4xHxW 99 | # """ unpack a 4-channel tensor into a single channel bayer image""" 100 | # if isinstance(im_raw_4ch, np.ndarray): 101 | # im_out = np.zeros_like(im_raw_4ch, shape=(im_raw_4ch.shape[1] * 2, im_raw_4ch.shape[2] * 2)) 102 | # elif isinstance(im_raw_4ch, torch.Tensor): 103 | # im_out = torch.zeros((im_raw_4ch.shape[1] * 2, im_raw_4ch.shape[2] * 2), dtype=im_raw_4ch.dtype) 104 | # else: 105 | # raise Exception 106 | 107 | # im_out[0::2, 0::2] = im_raw_4ch[0, :, :] 108 | # im_out[0::2, 1::2] = im_raw_4ch[1, :, :] 109 | # im_out[1::2, 0::2] = im_raw_4ch[2, :, :] 110 | # im_out[1::2, 1::2] = im_raw_4ch[3, :, :] 111 | 112 | # return im_out # HxW -------------------------------------------------------------------------------- /models/degrade/process.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Forward processing of raw data to sRGB images. 
17 | 18 | Unprocessing Images for Learned Raw Denoising 19 | http://timothybrooks.com/tech/unprocessing 20 | """ 21 | 22 | import numpy as np 23 | import torch 24 | import torch.nn as nn 25 | import torch.distributions as tdist 26 | from colour_demosaicing import demosaicing_CFA_Bayer_Menon2007 27 | import os 28 | from isp import demosaic_bayer 29 | 30 | 31 | def apply_gains(bayer_images, red_gains, blue_gains): 32 | """Applies white balance gains to a batch of Bayer images.""" 33 | red_gains = red_gains.squeeze(1) 34 | blue_gains= blue_gains.squeeze(1) 35 | green_gains = torch.ones_like(red_gains) 36 | gains = torch.stack([red_gains, green_gains, green_gains, blue_gains], dim=-1) 37 | gains = gains[:, None, None, :] 38 | # print(bayer_images.shape, gains.shape) 39 | outs = bayer_images * gains 40 | return outs 41 | 42 | 43 | def demosaic(bayer_images): 44 | def SpaceToDepth_fact2(x): 45 | # From here - https://discuss.pytorch.org/t/is-there-any-layer-like-tensorflows-space-to-depth-function/3487/14 46 | bs = 2 47 | N, C, H, W = x.size() 48 | x = x.view(N, C, H // bs, bs, W // bs, bs) # (N, C, H//bs, bs, W//bs, bs) 49 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs) 50 | x = x.view(N, C * (bs ** 2), H // bs, W // bs) # (N, C*bs^2, H//bs, W//bs) 51 | return x 52 | def DepthToSpace_fact2(x): 53 | # From here - https://discuss.pytorch.org/t/is-there-any-layer-like-tensorflows-space-to-depth-function/3487/14 54 | bs = 2 55 | N, C, H, W = x.size() 56 | x = x.view(N, bs, bs, C // (bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W) 57 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs) 58 | x = x.view(N, C // (bs ** 2), H * bs, W * bs) # (N, C//bs^2, H * bs, W * bs) 59 | return x 60 | 61 | """Bilinearly demosaics a batch of RGGB Bayer images.""" 62 | 63 | shape = bayer_images.size() 64 | shape = [shape[1] * 2, shape[2] * 2] 65 | 66 | red = bayer_images[Ellipsis, 0:1] 67 | upsamplebyX = nn.Upsample(size=shape, mode='bilinear', align_corners=False) 68 | red = upsamplebyX(red.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 69 | 70 | green_red = bayer_images[Ellipsis, 1:2] 71 | green_red = torch.flip(green_red, dims=[1]) # Flip left-right 72 | green_red = upsamplebyX(green_red.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 73 | green_red = torch.flip(green_red, dims=[1]) # Flip left-right 74 | green_red = SpaceToDepth_fact2(green_red.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 75 | 76 | green_blue = bayer_images[Ellipsis, 2:3] 77 | green_blue = torch.flip(green_blue, dims=[0]) # Flip up-down 78 | green_blue = upsamplebyX(green_blue.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 79 | green_blue = torch.flip(green_blue, dims=[0]) # Flip up-down 80 | green_blue = SpaceToDepth_fact2(green_blue.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 81 | 82 | green_at_red = (green_red[Ellipsis, 0] + green_blue[Ellipsis, 0]) / 2 83 | green_at_green_red = green_red[Ellipsis, 1] 84 | green_at_green_blue = green_blue[Ellipsis, 2] 85 | green_at_blue = (green_red[Ellipsis, 3] + green_blue[Ellipsis, 3]) / 2 86 | 87 | green_planes = [ 88 | green_at_red, green_at_green_red, green_at_green_blue, green_at_blue 89 | ] 90 | green = DepthToSpace_fact2(torch.stack(green_planes, dim=-1).permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 91 | 92 | blue = bayer_images[Ellipsis, 3:4] 93 | blue = torch.flip(torch.flip(blue, dims=[1]), dims=[0]) 94 | blue = upsamplebyX(blue.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 95 | blue = torch.flip(torch.flip(blue, dims=[1]), dims=[0]) 96 | 97 | rgb_images = torch.cat([red, green, blue], 
dim=-1)
98 | return rgb_images
99 |
100 |
101 | def apply_ccms(images, ccms):
102 | """Applies color correction matrices."""
103 | images = images[:, :, :, None, :]
104 | ccms = ccms[:, None, None, :, :]
105 | outs = torch.sum(images * ccms, dim=-1)
106 | return outs
107 |
108 |
109 | def gamma_compression(images, gamma=2.2):
110 | """Converts from linear to gamma space."""
111 | # Clamps to prevent numerical instability of gradients near zero.
112 | Mask = lambda x: (x>0.0031308).float()
113 | sRGBDeLinearize = lambda x,m: m * (1.055 * (m * x) ** (1/2.4) - 0.055) + (1-m) * (12.92 * x)
114 | return sRGBDeLinearize(images, Mask(images))
115 | # outs = torch.clamp(images, min=1e-8) ** (1.0 / gamma)
116 | # return outs
117 |
118 |
119 | def process(bayer_images, red_gains, blue_gains, cam2rgbs, demosaic_type, lineRGB):
120 | # print(bayer_images.shape, red_gains.shape, cam2rgbs.shape)
121 | """Processes a batch of Bayer RGGB images into sRGB images."""
122 | # White balance.
123 | bayer_images = apply_gains(bayer_images, red_gains, blue_gains)
124 | # Demosaic.
125 | bayer_images = torch.clamp(bayer_images, min=0.0, max=1.0)
126 |
127 | if demosaic_type == 'default':
128 | images = demosaic(bayer_images)
129 | elif demosaic_type == 'menon2007':
130 | # print(bayer_images.size())
131 | bayer_images = flatten_raw_image(bayer_images.squeeze(0))
132 | images = demosaicing_CFA_Bayer_Menon2007(bayer_images.cpu().numpy(), 'RGGB')
133 | images = torch.from_numpy(images).unsqueeze(0).to(red_gains.device)
134 | elif demosaic_type == 'net':
135 | bayer_images = flatten_raw_image(bayer_images.squeeze(0)).cpu().numpy()
136 | bayer = np.power(np.clip(bayer_images.astype(dtype=np.float32), 0, 1), 1 / 2.2)
137 | pretrained_model_path = "./isp/model.bin"
138 | demosaic_net = demosaic_bayer.get_demosaic_net_model(pretrained=pretrained_model_path, device=red_gains.device,
139 | cfa='bayer', state_dict=True)
140 | rgb = demosaic_bayer.demosaic_by_demosaic_net(bayer=bayer, cfa='RGGB',
141 | demosaic_net=demosaic_net, device=red_gains.device)
142 | images = np.power(np.clip(rgb, 0, 1), 2.2)
143 | images = torch.from_numpy(images).unsqueeze(0).to(red_gains.device)
144 | elif demosaic_type == 'net2':
145 | bayer_images = flatten_raw_image(bayer_images.squeeze(0))
146 | # bayer = np.power(np.clip(bayer_images.astype(dtype=np.float32), 0, 1), 1 / 2.2)
147 | # same computation as above, rewritten with torch tensors instead of numpy
148 | bayer= torch.pow(torch.clamp(bayer_images, 0, 1), 1 / 2.2)
149 | pretrained_model_path = "./isp/model.bin"
150 | demosaic_net = demosaic_bayer.get_demosaic_net_model(pretrained=pretrained_model_path, device=red_gains.device,
151 | cfa='bayer', state_dict=True)
152 | rgb = demosaic_bayer.demosaic_by_demosaic_net(bayer=bayer, cfa='RGGB',
153 | demosaic_net=demosaic_net, device=red_gains.device)
154 | # images = np.power(np.clip(rgb, 0, 1), 2.2)
155 | images= torch.pow(torch.clamp(rgb, 0, 1), 2.2)
156 | images = images.to(red_gains.device)
157 | # Color correction.
158 | images = apply_ccms(images, cam2rgbs)
159 | # Gamma compression.
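# ---------------------------------------------------------------------------
# [Editor's note: illustrative sketch, not part of the original process.py]
# gamma_compression() above implements the piecewise sRGB delinearization
# (12.92*x below the 0.0031308 knee, else 1.055*x**(1/2.4) - 0.055) and is
# the inverse of gamma_expansion() in unprocess.py. A self-contained
# round-trip check:
#
#     import torch
#     x = torch.rand(10000)                    # linear intensities in [0, 1)
#     m = (x > 0.0031308).float()
#     srgb = m * (1.055 * x ** (1 / 2.4) - 0.055) + (1 - m) * (12.92 * x)
#     m2 = (srgb > 0.04045).float()            # 12.92 * 0.0031308 ~= 0.04045
#     lin = m2 * ((srgb + 0.055) / 1.055) ** 2.4 + (1 - m2) * (srgb / 12.92)
#     assert torch.allclose(lin, x, atol=1e-5)
# ---------------------------------------------------------------------------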
160 | images = torch.clamp(images, min=0.0, max=1.0) 161 | if not lineRGB: 162 | images = gamma_compression(images) 163 | return images 164 | 165 | 166 | def flatten_raw_image(im_raw_4ch): # HxWx4 167 | """ unpack a 4-channel tensor into a single channel bayer image""" 168 | if isinstance(im_raw_4ch, np.ndarray): 169 | im_out = np.zeros_like(im_raw_4ch, shape=(im_raw_4ch.shape[0] * 2, im_raw_4ch.shape[1] * 2)) 170 | elif isinstance(im_raw_4ch, torch.Tensor): 171 | im_out = torch.zeros((im_raw_4ch.shape[0] * 2, im_raw_4ch.shape[1] * 2), dtype=im_raw_4ch.dtype) 172 | else: 173 | raise Exception 174 | 175 | im_out[0::2, 0::2] = im_raw_4ch[:, :, 0] 176 | im_out[0::2, 1::2] = im_raw_4ch[:, :, 1] 177 | im_out[1::2, 0::2] = im_raw_4ch[:, :, 2] 178 | im_out[1::2, 1::2] = im_raw_4ch[:, :, 3] 179 | 180 | return im_out # HxW -------------------------------------------------------------------------------- /models/degrade/unprocess.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Unprocesses sRGB images into realistic raw data. 17 | 18 | Unprocessing Images for Learned Raw Denoising 19 | http://timothybrooks.com/tech/unprocessing 20 | """ 21 | 22 | import numpy as np 23 | import torch 24 | import torch.distributions as tdist 25 | 26 | 27 | def random_ccm(device): 28 | """Generates random RGB -> Camera color correction matrices.""" 29 | # Takes a random convex combination of XYZ -> Camera CCMs. 30 | xyz2cams = [[[1.0234, -0.2969, -0.2266], 31 | [-0.5625, 1.6328, -0.0469], 32 | [-0.0703, 0.2188, 0.6406]], 33 | [[0.4913, -0.0541, -0.0202], 34 | [-0.613, 1.3513, 0.2906], 35 | [-0.1564, 0.2151, 0.7183]], 36 | [[0.838, -0.263, -0.0639], 37 | [-0.2887, 1.0725, 0.2496], 38 | [-0.0627, 0.1427, 0.5438]], 39 | [[0.6596, -0.2079, -0.0562], 40 | [-0.4782, 1.3016, 0.1933], 41 | [-0.097, 0.1581, 0.5181]]] 42 | num_ccms = len(xyz2cams) 43 | xyz2cams = torch.FloatTensor(xyz2cams) 44 | weights = torch.FloatTensor(num_ccms, 1, 1).uniform_(1e-8, 1e8) 45 | weights_sum = torch.sum(weights, dim=0) 46 | xyz2cam = torch.sum(xyz2cams * weights, dim=0) / weights_sum 47 | 48 | # Multiplies with RGB -> XYZ to get RGB -> Camera CCM. 49 | rgb2xyz = torch.FloatTensor([[0.4124564, 0.3575761, 0.1804375], 50 | [0.2126729, 0.7151522, 0.0721750], 51 | [0.0193339, 0.1191920, 0.9503041]]) 52 | rgb2cam = torch.mm(xyz2cam.to(device), rgb2xyz.to(device)) 53 | 54 | # Normalizes each row. 55 | rgb2cam = rgb2cam / torch.sum(rgb2cam, dim=-1, keepdim=True) 56 | return rgb2cam 57 | 58 | 59 | def random_gains(device): 60 | """Generates random gains for brightening and white balance.""" 61 | # RGB gain represents brightening. 62 | n = tdist.Normal(loc=torch.tensor([0.8]), scale=torch.tensor([0.1])) 63 | rgb_gain = 1.0 / n.sample() 64 | 65 | # Red and blue gains represent white balance. 
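# ---------------------------------------------------------------------------
# [Editor's note: illustrative sketch, not part of the original unprocess.py]
# The red/blue gains sampled below are divided out by safe_invert_gains()
# during unprocessing and multiplied back in by apply_gains() in process.py.
# Ignoring the highlight-protection mask, the two steps are exact inverses;
# a self-contained check on hypothetical RGGB planes (R, Gr, Gb, B):
#
#     import torch
#     red_gain, blue_gain = torch.tensor(2.1), torch.tensor(1.7)  # typical draws
#     rggb = torch.rand(4, 8, 8)
#     inv = rggb.clone()
#     inv[0] = inv[0] / red_gain      # unprocess: undo white balance
#     inv[3] = inv[3] / blue_gain
#     out = inv.clone()
#     out[0] = out[0] * red_gain      # process: reapply white balance
#     out[3] = out[3] * blue_gain
#     assert torch.allclose(out, rggb, atol=1e-6)
# ---------------------------------------------------------------------------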
66 | red_gain = torch.FloatTensor(1).uniform_(1.9, 2.4) 67 | blue_gain = torch.FloatTensor(1).uniform_(1.5, 1.9) 68 | return rgb_gain.to(device), red_gain.to(device), blue_gain.to(device) 69 | 70 | 71 | def inverse_smoothstep(image): 72 | """Approximately inverts a global tone mapping curve.""" 73 | image = torch.clamp(image, min=0.0, max=1.0) 74 | out = 0.5 - torch.sin(torch.asin(1.0 - 2.0 * image) / 3.0) 75 | return out 76 | 77 | 78 | def gamma_expansion(image): 79 | """Converts from gamma to linear space.""" 80 | # Clamps to prevent numerical instability of gradients near zero. 81 | Mask = lambda x: (x>0.04045).float() 82 | sRGBLinearize = lambda x,m: m * ((m * x + 0.055) / 1.055) ** 2.4 + (1-m) * (x / 12.92) 83 | return sRGBLinearize(image, Mask(image)) 84 | # out = torch.clamp(image, min=1e-8) ** 2.2 85 | # return out 86 | 87 | 88 | def apply_ccm(image, ccm): 89 | """Applies a color correction matrix.""" 90 | shape = image.size() 91 | image = torch.reshape(image, [-1, 3]) 92 | image = torch.tensordot(image, ccm, dims=[[-1], [-1]]) 93 | out = torch.reshape(image, shape) 94 | return out 95 | 96 | 97 | def safe_invert_gains(image, rgb_gain, red_gain, blue_gain, device): 98 | """Inverts gains while safely handling saturated pixels.""" 99 | gains = torch.stack((1.0 / red_gain, torch.tensor([1.0]).to(device), 1.0 / blue_gain)) # / rgb_gain 100 | gains = gains.to(device).squeeze() 101 | gains = gains[None, None, :] 102 | # Prevents dimming of saturated pixels by smoothly masking gains near white. 103 | gray = torch.mean(image, dim=-1, keepdim=True) 104 | inflection = 0.9 105 | mask = (torch.clamp(gray - inflection, min=0.0) / (1.0 - inflection)) ** 2.0 106 | safe_gains = torch.max(mask + (1.0 - mask) * gains, gains) 107 | out = image * safe_gains 108 | return out 109 | 110 | 111 | def mosaic(image): 112 | """Extracts RGGB Bayer planes from an RGB image.""" 113 | shape = image.size() 114 | red = image[0::2, 0::2, 0] 115 | green_red = image[0::2, 1::2, 1] 116 | green_blue = image[1::2, 0::2, 1] 117 | blue = image[1::2, 1::2, 2] 118 | out = torch.stack((red, green_red, green_blue, blue), dim=-1) 119 | out = torch.reshape(out, (shape[0] // 2, shape[1] // 2, 4)) 120 | return out 121 | 122 | 123 | def unprocess(image, features=None, device=None): 124 | """Unprocesses an image from sRGB to realistic raw data.""" 125 | 126 | if features == None: 127 | # Randomly creates image metadata. 128 | rgb2cam = random_ccm(device) 129 | cam2rgb = torch.inverse(rgb2cam) 130 | rgb_gain, red_gain, blue_gain = random_gains(device) 131 | else: 132 | rgb2cam = features['rgb2cam'] 133 | cam2rgb = features['cam2rgb'] 134 | rgb_gain = features['rgb_gain'] 135 | red_gain = features['red_gain'] 136 | blue_gain = features['blue_gain'] 137 | # Approximately inverts global tone mapping. 138 | # image = inverse_smoothstep(image) 139 | # Inverts gamma compression. 140 | image = gamma_expansion(image) 141 | # Inverts color correction. 142 | image = apply_ccm(image, rgb2cam) 143 | # Approximately inverts white balance and brightening. 144 | image = safe_invert_gains(image, rgb_gain, red_gain, blue_gain, device) 145 | # Clips saturated pixels. 146 | # image = torch.clamp(image, min=0.0, max=1.0) 147 | # Applies a Bayer mosaic. 
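# ---------------------------------------------------------------------------
# [Editor's note: illustrative sketch, not part of the original unprocess.py]
# The mosaic() call below subsamples an HxWx3 image into four half-resolution
# RGGB Bayer planes, mirroring the indexing of mosaic() defined above:
#
#     import torch
#     img = torch.rand(4, 6, 3)                  # hypothetical HxWx3 input
#     planes = torch.stack((img[0::2, 0::2, 0],  # R  (even rows, even cols)
#                           img[0::2, 1::2, 1],  # G  on red rows
#                           img[1::2, 0::2, 1],  # G  on blue rows
#                           img[1::2, 1::2, 2]), # B  (odd rows, odd cols)
#                          dim=-1)
#     assert planes.shape == (2, 3, 4)           # (H//2, W//2, 4)
# ---------------------------------------------------------------------------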
148 | image = mosaic(image) 149 | 150 | metadata = { 151 | 'rgb2cam': rgb2cam, 152 | 'cam2rgb': cam2rgb, 153 | 'rgb_gain': rgb_gain, 154 | 'red_gain': red_gain, 155 | 'blue_gain': blue_gain, 156 | } 157 | return image, metadata 158 | 159 | 160 | # ############### If the target dataset is DND, use this function ##################### 161 | # def random_noise_levels(): 162 | # """Generates random noise levels from a log-log linear distribution.""" 163 | # log_min_shot_noise = np.log(0.0001) 164 | # log_max_shot_noise = np.log(0.012) 165 | # log_shot_noise = torch.FloatTensor(1).uniform_(log_min_shot_noise, log_max_shot_noise) 166 | # shot_noise = torch.exp(log_shot_noise) 167 | 168 | # line = lambda x: 2.18 * x + 1.20 169 | # n = tdist.Normal(loc=torch.tensor([0.0]), scale=torch.tensor([0.26])) 170 | # log_read_noise = line(log_shot_noise) + n.sample() 171 | # read_noise = torch.exp(log_read_noise) 172 | # return shot_noise, read_noise 173 | 174 | 175 | def add_noise(image, shot_noise, read_noise): 176 | var = image * shot_noise + read_noise 177 | noise = tdist.Normal(loc=torch.zeros_like(var), scale=torch.sqrt(var)).sample() 178 | out = image + noise 179 | return out 180 | 181 | 182 | ################ If the target dataset is SIDD, use this function ##################### 183 | def random_noise_levels(noise_level): 184 | """ Where read_noise in SIDD is not 0 """ 185 | log_min_shot_noise = torch.log(torch.tensor(0.0012)).to(noise_level.device) 186 | log_max_shot_noise = torch.log(torch.tensor(0.0048)).to(noise_level.device) 187 | log_shot_noise = log_min_shot_noise + noise_level * (log_max_shot_noise - log_min_shot_noise) 188 | shot_noise = torch.exp(log_shot_noise) 189 | 190 | line = lambda x: 1.869 * x + 0.3276 191 | n = tdist.Normal(loc=torch.tensor([0.0], device=noise_level.device), 192 | scale=torch.tensor([0.30], device=noise_level.device)) 193 | 194 | log_read_noise = line(log_shot_noise) + n.sample() 195 | read_noise = torch.exp(log_read_noise) 196 | 197 | return shot_noise, read_noise 198 | 199 | 200 | # def add_noise(image, shot_noise=0.01, read_noise=0.0005): 201 | # """Adds random shot (proportional to image) and read (independent) noise.""" 202 | # variance = image * shot_noise + read_noise 203 | # n = tdist.Normal(loc=torch.zeros_like(variance), scale=torch.sqrt(variance)) 204 | # noise = n.sample() 205 | # out = image + noise 206 | # return out 207 | 208 | 209 | # ################ If the target dataset is SIDD, use this function ##################### 210 | # def random_noise_levels(noise_level): 211 | # """ Where read_noise in SIDD is not 0 """ 212 | # log_min_shot_noise = np.log(0.00068674) 213 | # log_max_shot_noise = np.log(0.02194856) 214 | # # log_shot_noise = torch.FloatTensor(1).uniform_(log_min_shot_noise, log_max_shot_noise) 215 | # log_shot_noise = torch.FloatTensor([log_min_shot_noise + noise_level * (log_max_shot_noise - log_min_shot_noise)]) 216 | # shot_noise = torch.exp(log_shot_noise) 217 | 218 | # line = lambda x: 1.85 * x + 0.30 219 | # n = tdist.Normal(loc=torch.tensor([0.0]), scale=torch.tensor([0.20])) 220 | # log_read_noise = line(log_shot_noise) + n.sample() 221 | # read_noise = torch.exp(log_read_noise) 222 | # return shot_noise, read_noise -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd 
import Variable 7 | import numpy as np 8 | from math import exp, sqrt 9 | from torch.nn import L1Loss, MSELoss 10 | from torchvision import models 11 | from data.degrade.process import demosaic 12 | 13 | 14 | def gaussian(window_size, sigma): 15 | gauss = torch.Tensor([exp( 16 | -(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) \ 17 | for x in range(window_size)]) 18 | return gauss / gauss.sum() 19 | 20 | def create_window(window_size, channel): 21 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 22 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 23 | window = Variable(_2D_window.expand( 24 | channel, 1, window_size, window_size).contiguous()) 25 | return window 26 | 27 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 28 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 29 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 30 | 31 | mu1_sq = mu1.pow(2) 32 | mu2_sq = mu2.pow(2) 33 | mu1_mu2 = mu1 * mu2 34 | 35 | sigma1_sq = F.conv2d( 36 | img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 37 | sigma2_sq = F.conv2d( 38 | img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 39 | sigma12 = F.conv2d( 40 | img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 41 | 42 | C1 = 0.01 ** 2 43 | C2 = 0.03 ** 2 44 | 45 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \ 46 | ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 47 | 48 | if size_average: 49 | return ssim_map.mean() 50 | else: 51 | return ssim_map.mean(1).mean(1).mean(1) 52 | 53 | def ssim(img1, img2, window_size=11, size_average=True): 54 | (_, channel, _, _) = img1.size() 55 | window = create_window(window_size, channel) 56 | 57 | if img1.is_cuda: 58 | window = window.cuda(img1.get_device()) 59 | window = window.type_as(img1) 60 | 61 | return _ssim(img1, img2, window, window_size, channel, size_average) 62 | 63 | class SSIMLoss(nn.Module): 64 | def __init__(self, window_size=11, size_average=True): 65 | super(SSIMLoss, self).__init__() 66 | self.window_size = window_size 67 | self.size_average = size_average 68 | self.channel = 1 69 | self.window = create_window(window_size, self.channel) 70 | 71 | def forward(self, img1, img2): 72 | (_, channel, _, _) = img1.size() 73 | 74 | if channel == self.channel and \ 75 | self.window.data.type() == img1.data.type(): 76 | window = self.window 77 | else: 78 | window = create_window(self.window_size, channel) 79 | 80 | if img1.is_cuda: 81 | window = window.cuda(img1.get_device()) 82 | window = window.type_as(img1) 83 | 84 | self.window = window 85 | self.channel = channel 86 | 87 | return _ssim(img1, img2, window, self.window_size, 88 | channel, self.size_average) 89 | 90 | def post_process(image): 91 | image = torch.stack((image[:, 0], image[:, 1:3].mean(dim=1), image[:, 3]), dim=1) 92 | # image = torch.pow(torch.clamp(image, 1e-6, 1), 1/2.2) 93 | return image 94 | 95 | def normalize_batch(batch): 96 | batch = post_process(batch) 97 | mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1) 98 | std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1) 99 | return (batch - mean) / std 100 | 101 | class VGG19(torch.nn.Module): 102 | def __init__(self): 103 | super(VGG19, self).__init__() 104 | features = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features 105 | self.relu1_1 = torch.nn.Sequential() 106 | self.relu1_2 = torch.nn.Sequential() 107 | 108 | self.relu2_1 = torch.nn.Sequential() 109 | 
self.relu2_2 = torch.nn.Sequential() 110 | 111 | self.relu3_1 = torch.nn.Sequential() 112 | self.relu3_2 = torch.nn.Sequential() 113 | self.relu3_3 = torch.nn.Sequential() 114 | self.relu3_4 = torch.nn.Sequential() 115 | 116 | self.relu4_1 = torch.nn.Sequential() 117 | self.relu4_2 = torch.nn.Sequential() 118 | self.relu4_3 = torch.nn.Sequential() 119 | self.relu4_4 = torch.nn.Sequential() 120 | 121 | self.relu5_1 = torch.nn.Sequential() 122 | self.relu5_2 = torch.nn.Sequential() 123 | self.relu5_3 = torch.nn.Sequential() 124 | self.relu5_4 = torch.nn.Sequential() 125 | 126 | for x in range(2): 127 | self.relu1_1.add_module(str(x), features[x]) 128 | 129 | for x in range(2, 4): 130 | self.relu1_2.add_module(str(x), features[x]) 131 | 132 | for x in range(4, 7): 133 | self.relu2_1.add_module(str(x), features[x]) 134 | 135 | for x in range(7, 9): 136 | self.relu2_2.add_module(str(x), features[x]) 137 | 138 | for x in range(9, 12): 139 | self.relu3_1.add_module(str(x), features[x]) 140 | 141 | for x in range(12, 14): 142 | self.relu3_2.add_module(str(x), features[x]) 143 | 144 | for x in range(14, 16): 145 | self.relu3_3.add_module(str(x), features[x]) 146 | 147 | for x in range(16, 18): 148 | self.relu3_4.add_module(str(x), features[x]) 149 | 150 | for x in range(18, 21): 151 | self.relu4_1.add_module(str(x), features[x]) 152 | 153 | for x in range(21, 23): 154 | self.relu4_2.add_module(str(x), features[x]) 155 | 156 | for x in range(23, 25): 157 | self.relu4_3.add_module(str(x), features[x]) 158 | 159 | for x in range(25, 27): 160 | self.relu4_4.add_module(str(x), features[x]) 161 | 162 | for x in range(27, 30): 163 | self.relu5_1.add_module(str(x), features[x]) 164 | 165 | for x in range(30, 32): 166 | self.relu5_2.add_module(str(x), features[x]) 167 | 168 | for x in range(32, 34): 169 | self.relu5_3.add_module(str(x), features[x]) 170 | 171 | for x in range(34, 36): 172 | self.relu5_4.add_module(str(x), features[x]) 173 | 174 | # don't need the gradients, just want the features 175 | for param in self.parameters(): 176 | param.requires_grad = False 177 | 178 | def forward(self, x): 179 | relu1_1 = self.relu1_1(x) 180 | relu1_2 = self.relu1_2(relu1_1) 181 | 182 | relu2_1 = self.relu2_1(relu1_2) 183 | relu2_2 = self.relu2_2(relu2_1) 184 | 185 | relu3_1 = self.relu3_1(relu2_2) 186 | relu3_2 = self.relu3_2(relu3_1) 187 | relu3_3 = self.relu3_3(relu3_2) 188 | relu3_4 = self.relu3_4(relu3_3) 189 | 190 | relu4_1 = self.relu4_1(relu3_4) 191 | relu4_2 = self.relu4_2(relu4_1) 192 | relu4_3 = self.relu4_3(relu4_2) 193 | relu4_4 = self.relu4_4(relu4_3) 194 | 195 | relu5_1 = self.relu5_1(relu4_4) 196 | relu5_2 = self.relu5_2(relu5_1) 197 | relu5_3 = self.relu5_3(relu5_2) 198 | relu5_4 = self.relu5_4(relu5_3) 199 | 200 | out = { 201 | 'relu1_1': relu1_1, 202 | 'relu1_2': relu1_2, 203 | 204 | 'relu2_1': relu2_1, 205 | 'relu2_2': relu2_2, 206 | 207 | 'relu3_1': relu3_1, 208 | 'relu3_2': relu3_2, 209 | 'relu3_3': relu3_3, 210 | 'relu3_4': relu3_4, 211 | 212 | 'relu4_1': relu4_1, 213 | 'relu4_2': relu4_2, 214 | 'relu4_3': relu4_3, 215 | 'relu4_4': relu4_4, 216 | 217 | 'relu5_1': relu5_1, 218 | 'relu5_2': relu5_2, 219 | 'relu5_3': relu5_3, 220 | 'relu5_4': relu5_4, 221 | } 222 | return out 223 | 224 | class VGGLoss(nn.Module): 225 | def __init__(self): 226 | super(VGGLoss, self).__init__() 227 | self.add_module('vgg', VGG19()) 228 | self.criterion = torch.nn.L1Loss() 229 | 230 | def forward(self, img1, img2, p=6): 231 | x = normalize_batch(img1) 232 | y = normalize_batch(img2) 233 | x_vgg, y_vgg = 
self.vgg(x), self.vgg(y)
234 |
235 | content_loss = 0.0
236 | # # content_loss += self.criterion(x_vgg['relu1_2'], y_vgg['relu1_2']) * 0.1
237 | # # content_loss += self.criterion(x_vgg['relu2_2'], y_vgg['relu2_2']) * 0.2
238 | content_loss += self.criterion(x_vgg['relu3_2'], y_vgg['relu3_2']) * 1
239 | content_loss += self.criterion(x_vgg['relu4_2'], y_vgg['relu4_2']) * 1
240 | content_loss += self.criterion(x_vgg['relu5_2'], y_vgg['relu5_2']) * 2
241 |
242 | return content_loss / 4.
243 |
-------------------------------------------------------------------------------- /models/networks.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | from torch.optim import lr_scheduler
5 | from collections import OrderedDict
6 | import torch.nn.functional as F
7 | from util.util import SSIM
8 | import random
9 | import torch.distributions as tdist
10 | import numpy as np
11 | from timm.scheduler.cosine_lr import CosineLRScheduler
12 | from einops import rearrange
13 | import torchvision.ops as ops
14 | import numbers
15 | from torch.nn.modules.utils import _triple
16 | import math
17 | import cv2
18 |
19 |
20 |
21 | # Augment
22 | def augment_func(img, hflip, vflip, rot90): # CxHxW
23 | if hflip: img = torch.flip(img, dims=[-1])
24 | if vflip: img = torch.flip(img, dims=[-2])
25 | if rot90: img = img.transpose(img.ndim-1, img.ndim-2)
26 | return img
27 |
28 |
29 | def augment(*imgs): # CxHxW
30 | hflip = random.random() < 0.5
31 | vflip = random.random() < 0.5
32 | rot90 = random.random() < 0.5
33 | return (augment_func(img, hflip, vflip, rot90) for img in imgs)
34 |
35 |
36 | # Pack Images
37 | def pack_raw_image(im_raw): # GT: N x 1 x H x W ; RAWs: N x T x H x W
38 | """ Packs a single channel bayer image into 4 channel tensor, where channels contain R, G, G, and B values"""
39 | im_out = torch.zeros([im_raw.shape[0], im_raw.shape[1], 4, im_raw.shape[2] // 2, im_raw.shape[3] // 2],
40 | dtype=im_raw.dtype, device=im_raw.device)
41 |
42 | im_out[..., 0, :, :] = im_raw[:, :, 0::2, 0::2]
43 | im_out[..., 1, :, :] = im_raw[:, :, 0::2, 1::2]
44 | im_out[..., 2, :, :] = im_raw[:, :, 1::2, 0::2]
45 | im_out[..., 3, :, :] = im_raw[:, :, 1::2, 1::2]
46 |
47 | # GT: N x 1 x 4 x (H//2) x (W//2)
48 | # RAWs: N x T x 4 x (H//2) x (W//2)
49 | return im_out
50 |
51 |
52 | def get_scheduler(optimizer, opt):
53 | if opt.lr_policy == 'linear':
54 | def lambda_rule(epoch):
55 | return 1 - max(0, epoch-opt.niter) / max(1, float(opt.niter_decay))
56 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
57 | elif opt.lr_policy == 'step':
58 | scheduler = lr_scheduler.StepLR(optimizer,
59 | step_size=opt.lr_decay_iters,
60 | gamma=0.5)
61 | elif opt.lr_policy == 'plateau':
62 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
63 | mode='min',
64 | factor=0.2,
65 | threshold=0.01,
66 | patience=5)
67 | elif opt.lr_policy == 'cosine':
68 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
69 | T_max=opt.niter,
70 | eta_min=1e-6)
71 | elif opt.lr_policy == 'cosine_warmup':
72 | scheduler = CosineLRScheduler(optimizer,
73 | t_initial=opt.niter,
74 | lr_min=1e-6,
75 | warmup_t=5,
76 | warmup_lr_init=1e-5)
77 | else:
78 | raise NotImplementedError('lr [%s] is not implemented' % opt.lr_policy)
79 | return scheduler
80 |
81 |
82 | def init_weights(net, init_type='normal', init_gain=0.02):
83 | def init_func(m): # define the initialization function
84 | classname = m.__class__.__name__
85 | if
hasattr(m, 'weight') and (classname.find('Conv') != -1 \ 86 | or classname.find('Linear') != -1): 87 | if init_type == 'normal': 88 | init.normal_(m.weight.data, 0.0, init_gain) 89 | elif init_type == 'xavier': 90 | init.xavier_normal_(m.weight.data, gain=init_gain) 91 | elif init_type == 'kaiming': 92 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 93 | elif init_type == 'orthogonal': 94 | init.orthogonal_(m.weight.data, gain=init_gain) 95 | elif init_type == 'uniform': 96 | init.uniform_(m.weight.data, b=init_gain) 97 | elif init_type == 'constant': 98 | init.constant_(m.weight.data, 0.0) 99 | else: 100 | raise NotImplementedError('[%s] is not implemented' % init_type) 101 | elif hasattr(m, 'bias') and m.bias is not None: 102 | init.constant_(m.bias.data, 0.0) 103 | elif classname.find('BatchNorm2d') != -1: 104 | init.normal_(m.weight.data, 1.0, init_gain) 105 | init.constant_(m.bias.data, 0.0) 106 | 107 | # print('initialize network with %s' % init_type) 108 | net.apply(init_func) # apply the initialization function 109 | 110 | 111 | def init_net(net, init_type='default', init_gain=0.02, gpu_ids=[]): 112 | if len(gpu_ids) > 0: 113 | assert(torch.cuda.is_available()) 114 | net.to(gpu_ids[0]) 115 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 116 | if init_type != 'default' and init_type is not None: 117 | init_weights(net, init_type, init_gain=init_gain) 118 | return net 119 | 120 | 121 | def set_requires_grad(nets, requires_grad=False): 122 | if not isinstance(nets, list): 123 | nets = [nets] 124 | for net in nets: 125 | if net is not None: 126 | for param in net.parameters(): 127 | param.requires_grad = requires_grad 128 | 129 | 130 | def make_layer(block, num_blocks, **kwarg): 131 | """Make layers by stacking the same blocks. 132 | 133 | Args: 134 | block (nn.module): nn.module class for basic block. 135 | num_blocks (int): number of blocks. 136 | 137 | Returns: 138 | nn.Sequential: Stacked blocks in nn.Sequential. 139 | """ 140 | layers = [] 141 | for _ in range(num_blocks): 142 | layers.append(block(**kwarg)) 143 | return nn.Sequential(*layers) 144 | 145 | 146 | def flow_warp(x, 147 | flow, 148 | interpolation='bilinear', 149 | padding_mode='zeros', 150 | align_corners=True): 151 | """Warp an image or a feature map with optical flow. 152 | 153 | Args: 154 | x (Tensor): Tensor with size (n, c, h, w). 155 | flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is 156 | a two-channel, denoting the width and height relative offsets. 157 | Note that the values are not normalized to [-1, 1]. 158 | interpolation (str): Interpolation mode: 'nearest' or 'bilinear'. 159 | Default: 'bilinear'. 160 | padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'. 161 | Default: 'zeros'. 162 | align_corners (bool): Whether align corners. Default: True. 163 | 164 | Returns: 165 | Tensor: Warped image or feature map. 
166 | """ 167 | if x.size()[-2:] != flow.size()[1:3]: 168 | raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and ' 169 | f'flow ({flow.size()[1:3]}) are not the same.') 170 | _, _, h, w = x.size() 171 | # create mesh grid 172 | device = flow.device 173 | grid_y, grid_x = torch.meshgrid( 174 | torch.arange(0, h, device=device, dtype=x.dtype), 175 | torch.arange(0, w, device=device, dtype=x.dtype)) 176 | grid = torch.stack((grid_x, grid_y), 2) # h, w, 2 177 | grid.requires_grad = False 178 | 179 | grid_flow = grid + flow 180 | # scale grid_flow to [-1,1] 181 | grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0 182 | grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0 183 | grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3) 184 | output = F.grid_sample( 185 | x, 186 | grid_flow, 187 | mode=interpolation, 188 | padding_mode=padding_mode, 189 | align_corners=align_corners) 190 | return output 191 | 192 | 193 | def load_spynet(net, path): 194 | if isinstance(net, torch.nn.DataParallel): 195 | net = net.module 196 | state_dict = torch.load(path) 197 | 198 | print('loading the model from %s' % (path)) 199 | if hasattr(state_dict, '_metadata'): 200 | del state_dict._metadata 201 | 202 | net_state = net.state_dict() 203 | is_loaded = {n:False for n in net_state.keys()} 204 | for name, param in state_dict.items(): 205 | name = name.replace('basic_module.0.conv', 'basic_module.0') 206 | name = name.replace('basic_module.1.conv', 'basic_module.2') 207 | name = name.replace('basic_module.2.conv', 'basic_module.4') 208 | name = name.replace('basic_module.3.conv', 'basic_module.6') 209 | name = name.replace('basic_module.4.conv', 'basic_module.8') 210 | if name in net_state: 211 | try: 212 | net_state[name].copy_(param) 213 | is_loaded[name] = True 214 | except Exception: 215 | print('While copying the parameter named [%s], ' 216 | 'whose dimensions in the model are %s and ' 217 | 'whose dimensions in the checkpoint are %s.' 
218 | % (name, list(net_state[name].shape), 219 | list(param.shape))) 220 | raise RuntimeError 221 | else: 222 | print('Saved parameter named [%s] is skipped' % name) 223 | mark = True 224 | for name in is_loaded: 225 | if not is_loaded[name]: 226 | print('Parameter named [%s] is not initialized' % name) 227 | mark = False 228 | if mark: 229 | print('All parameters are initialized using [%s]' % path) 230 | 231 | 232 | ''' 233 | # =================================== 234 | # Advanced nn.Sequential 235 | # reform nn.Sequentials and nn.Modules 236 | # to a single nn.Sequential 237 | # =================================== 238 | ''' 239 | 240 | def seq(*args): 241 | if len(args) == 1: 242 | args = args[0] 243 | if isinstance(args, nn.Module): 244 | return args 245 | modules = OrderedDict() 246 | if isinstance(args, OrderedDict): 247 | for k, v in args.items(): 248 | modules[k] = seq(v) 249 | return nn.Sequential(modules) 250 | assert isinstance(args, (list, tuple)) 251 | return nn.Sequential(*[seq(i) for i in args]) 252 | 253 | ''' 254 | # =================================== 255 | # Useful blocks 256 | # -------------------------------- 257 | # conv (+ normalization + relu) 258 | # concat 259 | # sum 260 | # resblock (ResBlock) 261 | # resdenseblock (ResidualDenseBlock_5C) 262 | # resinresdenseblock (RRDB) 263 | # =================================== 264 | ''' 265 | 266 | # ------------------------------------------------------- 267 | # return nn.Sequential of (Conv + BN + ReLU) 268 | # ------------------------------------------------------- 269 | def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, 270 | output_padding=0, dilation=1, groups=1, bias=True, 271 | padding_mode='zeros', mode='CBR'): 272 | L = [] 273 | for t in mode: 274 | if t == 'C': 275 | L.append(nn.Conv2d(in_channels=in_channels, 276 | out_channels=out_channels, 277 | kernel_size=kernel_size, 278 | stride=stride, 279 | padding=padding, 280 | dilation=dilation, 281 | groups=groups, 282 | bias=bias, 283 | padding_mode=padding_mode)) 284 | elif t == 'X': 285 | assert in_channels == out_channels 286 | L.append(nn.Conv2d(in_channels=in_channels, 287 | out_channels=out_channels, 288 | kernel_size=kernel_size, 289 | stride=stride, 290 | padding=padding, 291 | dilation=dilation, 292 | groups=in_channels, 293 | bias=bias, 294 | padding_mode=padding_mode)) 295 | elif t == 'T': 296 | L.append(nn.ConvTranspose2d(in_channels=in_channels, 297 | out_channels=out_channels, 298 | kernel_size=kernel_size, 299 | stride=stride, 300 | padding=padding, 301 | output_padding=output_padding, 302 | groups=groups, 303 | bias=bias, 304 | dilation=dilation, 305 | padding_mode=padding_mode)) 306 | elif t == 'B': 307 | L.append(nn.BatchNorm2d(out_channels)) 308 | elif t == 'I': 309 | L.append(nn.InstanceNorm2d(out_channels, affine=True)) 310 | elif t == 'i': 311 | L.append(nn.InstanceNorm2d(out_channels)) 312 | elif t == 'R': 313 | L.append(nn.ReLU(inplace=True)) 314 | elif t == 'r': 315 | L.append(nn.ReLU(inplace=False)) 316 | elif t == 'S': 317 | L.append(nn.Sigmoid()) 318 | elif t == 'P': 319 | L.append(nn.PReLU()) 320 | elif t == 'L': 321 | L.append(nn.LeakyReLU(negative_slope=1e-1, inplace=True)) 322 | elif t == 'l': 323 | L.append(nn.LeakyReLU(negative_slope=1e-1, inplace=False)) 324 | elif t == '2': 325 | L.append(nn.PixelShuffle(upscale_factor=2)) 326 | elif t == '3': 327 | L.append(nn.PixelShuffle(upscale_factor=3)) 328 | elif t == '4': 329 | L.append(nn.PixelShuffle(upscale_factor=4)) 330 | elif t == 'U': 331 | L.append(nn.Upsample(scale_factor=2, mode='nearest')) 332 | elif t == 'u': 333 | L.append(nn.Upsample(scale_factor=3, mode='nearest')) 334 | elif t == 'M': 335 | L.append(nn.MaxPool2d(kernel_size=kernel_size, 336 | stride=stride, 337 | padding=0)) 338 | elif t == 'A': 339 | L.append(nn.AvgPool2d(kernel_size=kernel_size, 340 | stride=stride, 341 | padding=0)) 342 | else: 343 | raise NotImplementedError('Undefined type: {}'.format(t)) 344 | return seq(*L)
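# ----- Usage sketch (illustrative; not part of the original file) -----
# Each character of conv()'s mode string appends one layer, e.g.:
#   conv(64, 64, mode='CBR')   # Conv2d(64, 64, 3, padding=1) -> BatchNorm2d -> ReLU
#   conv(64, 256, mode='C2')   # Conv2d -> PixelShuffle(2), a 2x upsampling head (256 = 64 * 2**2)
# seq() then flattens the collected modules into a single nn.Sequential.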
345 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes option modules.""" 2 | -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | from util import util 5 | import torch 6 | import models 7 | import time 8 | 9 | def str2bool(v): 10 | return v.lower() in ('yes', 'y', 'true', 't', '1') 11 | 12 | inf = float('inf') 13 | 14 | class BaseOptions(): 15 | def __init__(self): 16 | """Reset the class; indicates the class hasn't been initialized""" 17 | self.initialized = False 18 | 19 | def initialize(self, parser): 20 | """Define the common options that are used in both training and test.""" 21 | # data parameters 22 | parser.add_argument('--dataroot', type=str, default="data") 23 | parser.add_argument('--dataset_name', type=str, default=['bracketire'], nargs='+') 24 | parser.add_argument('--max_dataset_size', type=int, default=inf) 25 | # parser.add_argument('--scale', type=int, default=4, help='Super-resolution scale.') 26 | parser.add_argument('--frame_num', type=int, default=5) 27 | parser.add_argument('--batch_size', type=int, default=2) 28 | parser.add_argument('--patch_size', type=int, default=128) 29 | parser.add_argument('--shuffle', type=str2bool, default=True) 30 | parser.add_argument('-j', '--num_dataloader', default=4, type=int) 31 | parser.add_argument('--drop_last', type=str2bool, default=True) 32 | 33 | # device parameters 34 | parser.add_argument('--gpu_ids', type=str, default='all', 35 | help='Separate the GPU ids by `,`, using all GPUs by default. ' 36 | 'e.g., `--gpu_ids 0`, `--gpu_ids 2,3`, `--gpu_ids -1` (CPU)') 37 | parser.add_argument('--checkpoints_dir', type=str, default='./ckpt') 38 | parser.add_argument('-v', '--verbose', type=str2bool, default=True) 39 | parser.add_argument('--suffix', default='', type=str) 40 | 41 | # model parameters 42 | parser.add_argument('--name', type=str, default='track1', 43 | help='Name of the folder to save models and logs.') 44 | parser.add_argument('--model', type=str, default='cat') 45 | parser.add_argument('--block', type=str, default='Convnext') 46 | parser.add_argument('--load_path', type=str, default='', 47 | help='Will load pre-trained model if load_path is set') 48 | parser.add_argument('--load_iter', type=int, default=[500], nargs='+', 49 | help='Load parameters if > 0 and load_path is not set. 
' 50 | 'Set the value of `last_epoch`') 51 | parser.add_argument('--chop', type=str2bool, default=False) 52 | parser.add_argument('--crop_patch', type=int, default=48) 53 | parser.add_argument('--self_weight', type=float, default=1) 54 | parser.add_argument('--neg_weight', type=float, default=1) 55 | parser.add_argument('--exposure', type=int, default=1) 56 | 57 | # training parameters 58 | parser.add_argument('--init_type', type=str, default='default', 59 | choices=['default', 'normal', 'xavier', 60 | 'kaiming', 'orthogonal', 'uniform'], 61 | help='`default` means using PyTorch default init functions.') 62 | parser.add_argument('--init_gain', type=float, default=0.02) 63 | parser.add_argument('--optimizer', type=str, default='Adam', 64 | choices=['Adam', 'SGD', 'RMSprop']) 65 | parser.add_argument('--niter', type=int, default=1000) 66 | parser.add_argument('--niter_decay', type=int, default=0) 67 | parser.add_argument('--lr_policy', type=str, default='step') 68 | parser.add_argument('--lr_decay_iters', type=int, default=200) 69 | parser.add_argument('--lr', type=float, default=0.0001) 70 | 71 | # Optimizer 72 | parser.add_argument('--load_optimizers', type=str2bool, default=False, 73 | help='Loading optimizer parameters for continuing training.') 74 | parser.add_argument('--weight_decay', type=float, default=0) 75 | # Adam 76 | parser.add_argument('--beta1', type=float, default=0.9) 77 | parser.add_argument('--beta2', type=float, default=0.999) 78 | # SGD & RMSprop 79 | parser.add_argument('--momentum', type=float, default=0) 80 | # RMSprop 81 | parser.add_argument('--alpha', type=float, default=0.99) 82 | 83 | # visualization parameters 84 | parser.add_argument('--print_freq', type=int, default=100) 85 | parser.add_argument('--test_every', type=int, default=1) 86 | parser.add_argument('--save_epoch_freq', type=int, default=1) 87 | parser.add_argument('--calc_metrics', type=str2bool, default=True) 88 | parser.add_argument('--save_imgs', type=str2bool, default=True) 89 | parser.add_argument('--visual_full_imgs', type=str2bool, default=False) 90 | 91 | parser.add_argument('--n_scales', type=int, default=3, help='multi-scale deblurring level') 92 | parser.add_argument('--n_feats', type=int, default=64, help='number of feature maps') 93 | parser.add_argument('--rgb_range', type=int, default=1, help='RGB pixel value ranging from 0') 94 | 95 | self.initialized = True 96 | return parser 97 | 98 | def gather_options(self): 99 | """Initialize our parser with basic options (only once). 100 | Add additional model-specific and dataset-specific options. 101 | These options are defined in the corresponding functions 102 | in model and dataset classes. 103 | """ 104 | if not self.initialized: # check if it has been initialized 105 | parser = argparse.ArgumentParser(formatter_class= 106 | argparse.ArgumentDefaultsHelpFormatter) 107 | parser = self.initialize(parser) 108 | 109 | # get the basic options 110 | opt, _ = parser.parse_known_args() 111 | 112 | # modify model-related parser options 113 | model_name = opt.model 114 | model_option_setter = models.get_option_setter(model_name) 115 | parser = model_option_setter(parser, self.isTrain) 116 | opt, _ = parser.parse_known_args() # parse again with new defaults 117 | 118 | # save and return the parser 119 | self.parser = parser 120 | return parser.parse_args() 121 | 122 | def print_options(self, opt): 123 | """Print and save options 124 | 125 | It will print both current options and default values (if different).
126 | It will save options into a text file / [checkpoints_dir] / opt.txt 127 | """ 128 | message = '' 129 | message += '----------------- Options ---------------\n' 130 | for k, v in sorted(vars(opt).items()): 131 | comment = '' 132 | default = self.parser.get_default(k) 133 | if v != default: 134 | comment = '\t[default: %s]' % str(default) 135 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 136 | message += '----------------- End -------------------' 137 | print(message) 138 | 139 | # save to the disk 140 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 141 | util.mkdirs(expr_dir) 142 | file_name = os.path.join(expr_dir, 'opt_%s.txt' 143 | % ('train' if self.isTrain else 'test')) 144 | with open(file_name, 'wt') as opt_file: 145 | opt_file.write(message) 146 | opt_file.write('\n') 147 | 148 | def parse(self): 149 | opt = self.gather_options() 150 | opt.isTrain = self.isTrain # train or test 151 | opt.serial_batches = not opt.shuffle 152 | 153 | if self.isTrain and (opt.load_iter != [0] or opt.load_path != '') \ 154 | and not opt.load_optimizers: 155 | util.prompt('You are loading a checkpoint and continuing training, ' 156 | 'and no optimizer parameters are loaded. Please make ' 157 | 'sure that the hyper parameters are correctly set.', 80) 158 | time.sleep(3) 159 | 160 | opt.model = opt.model.lower() 161 | opt.name = opt.name.lower() 162 | 163 | scale_patch = {2: 96, 3: 144, 4: 192} 164 | if opt.patch_size is None: 165 | opt.patch_size = scale_patch[opt.scale] 166 | 167 | if opt.name.startswith(opt.checkpoints_dir): 168 | opt.name = opt.name.replace(opt.checkpoints_dir+'/', '') 169 | if opt.name.endswith('/'): 170 | opt.name = opt.name[:-1] 171 | 172 | if len(opt.dataset_name) == 1: 173 | opt.dataset_name = opt.dataset_name[0] 174 | 175 | if len(opt.load_iter) == 1: 176 | opt.load_iter = opt.load_iter[0] 177 | 178 | # process opt.suffix 179 | if opt.suffix != '': 180 | suffix = ('_' + opt.suffix.format(**vars(opt))) 181 | opt.name = opt.name + suffix 182 | 183 | self.print_options(opt) 184 | 185 | # set gpu ids 186 | cuda_device_count = torch.cuda.device_count() 187 | if opt.gpu_ids == 'all': 188 | # GT 710 (3.5), GT 610 (2.1) 189 | gpu_ids = [i for i in range(cuda_device_count)] 190 | else: 191 | p = re.compile('[^-0-9]+') 192 | gpu_ids = [int(i) for i in re.split(p, opt.gpu_ids) if int(i) >= 0] 193 | opt.gpu_ids = [i for i in gpu_ids \ 194 | if torch.cuda.get_device_capability(i) >= (4,0)] 195 | 196 | if len(opt.gpu_ids) == 0 and len(gpu_ids) > 0: 197 | opt.gpu_ids = gpu_ids 198 | util.prompt('You\'re using GPUs with computing capability < 4') 199 | elif len(opt.gpu_ids) != len(gpu_ids): 200 | util.prompt('GPUs(computing capability < 4) have been disabled') 201 | 202 | if len(opt.gpu_ids) > 0: 203 | assert torch.cuda.is_available(), 'No cuda available !!!' 
204 | torch.cuda.set_device(opt.gpu_ids[0]) 205 | print('The GPUs you are using:') 206 | for gpu_id in opt.gpu_ids: 207 | print(' %2d *%s* with capability %d.%d' % ( 208 | gpu_id, 209 | torch.cuda.get_device_name(gpu_id), 210 | *torch.cuda.get_device_capability(gpu_id))) 211 | else: 212 | util.prompt('You are using CPU mode') 213 | 214 | self.opt = opt 215 | return self.opt 216 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | def initialize(self, parser): 6 | parser = BaseOptions.initialize(self, parser) 7 | self.isTrain = False 8 | return parser 9 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def initialize(self, parser): 6 | parser = BaseOptions.initialize(self, parser) 7 | self.isTrain = True 8 | return parser 9 | -------------------------------------------------------------------------------- /spynet/spynet_20210409-c6c1bd09.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/spynet/spynet_20210409-c6c1bd09.pth -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from options.test_options import TestOptions 4 | from data import create_dataset 5 | from models import create_model 6 | from util.visualizer import Visualizer 7 | from tqdm import tqdm 8 | from util.util import calc_psnr as calc_psnr 9 | import time 10 | import numpy as np 11 | from collections import OrderedDict as odict 12 | from copy import deepcopy 13 | from data.degrade.degrade_kernel import get_raw2rgb 14 | from data.degrade.process import gamma_compression 15 | from util.util import mu_tonemap, save_hdr 16 | import cv2 17 | 18 | 19 | if __name__ == '__main__': 20 | opt = TestOptions().parse() 21 | 22 | if not isinstance(opt.load_iter, list): 23 | load_iters = [opt.load_iter] 24 | else: 25 | load_iters = deepcopy(opt.load_iter) 26 | 27 | if not isinstance(opt.dataset_name, list): 28 | dataset_names = [opt.dataset_name] 29 | else: 30 | dataset_names = deepcopy(opt.dataset_name) 31 | datasets = odict() 32 | for dataset_name in dataset_names: 33 | dataset = create_dataset(dataset_name, 'test', opt) 34 | datasets[dataset_name] = tqdm(dataset) 35 | 36 | for load_iter in load_iters: 37 | opt.load_iter = load_iter 38 | model = create_model(opt) 39 | model.setup(opt) 40 | model.eval() 41 | print(torch.cuda.memory_allocated()) 42 | for dataset_name in dataset_names: 43 | opt.dataset_name = dataset_name 44 | tqdm_val = datasets[dataset_name] 45 | dataset_test = tqdm_val.iterable 46 | dataset_size_test = len(dataset_test) 47 | 48 | print('='*80) 49 | print(dataset_name + ' dataset') 50 | tqdm_val.reset() 51 | 52 | psnr = [0.0] * dataset_size_test 53 | 54 | time_val = 0 55 | 56 | folder_dir = './ckpt/%s/output_vispng_%d' % (opt.name, load_iter) 57 | os.makedirs(folder_dir, exist_ok=True) 58 | 59 | for i, data in enumerate(tqdm_val): 60 | torch.cuda.empty_cache() 61 | model.set_input(data) 62 | 
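# cuda.synchronize() below brackets the timed call so queued asynchronous GPU work is fully counted in time_val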
torch.cuda.synchronize() 63 | time_val_start = time.time() 64 | model.test() 65 | torch.cuda.synchronize() 66 | time_val += time.time() - time_val_start 67 | res = model.get_current_visuals() 68 | 69 | if opt.save_imgs: 70 | save_dir_vispng = '%s/%s.png' % (folder_dir, data['fname'][0]) 71 | raw_img = res['data_out'][0].permute(1, 2, 0) / 16 72 | img = get_raw2rgb(raw_img, data['meta'], demosaic='net', lineRGB=True) 73 | img = torch.clamp(mu_tonemap(img, mu=5e3)*65535, 0, 65535) 74 | img = img.cpu().numpy()[..., ::-1] # RGB -> BGR for cv2.imwrite 75 | cv2.imwrite(save_dir_vispng, img.astype(np.uint16)) # 16-bit PNG 76 | 77 | avg_psnr = '%.2f'%np.mean(psnr) 78 | print(torch.cuda.max_memory_allocated()) 79 | for dataset in datasets: 80 | datasets[dataset].close() 81 | -------------------------------------------------------------------------------- /test_track1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo "Start to test the model...." 3 | device="0" 4 | 5 | 6 | dataroot="" # including 'Train' and 'NTIRE_Val' folders 7 | name="" 8 | 9 | python test.py \ 10 | --dataset_name bracketire --model cat --name $name --dataroot $dataroot \ 11 | --load_iter 500 --save_imgs True --calc_metrics False --gpu_ids $device -j 8 --block Convnext 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #-*- encoding: UTF-8 -*- 2 | # import sys 3 | # reload(sys) 4 | # sys.setdefaultencoding("utf-8") 5 | 6 | import time 7 | import torch 8 | from options.train_options import TrainOptions 9 | from data import create_dataset 10 | from models import create_model 11 | from util.visualizer import Visualizer 12 | import numpy as np 13 | import sys 14 | import random 15 | def setup_seed(seed=0): 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | np.random.seed(seed) 19 | random.seed(seed) 20 | torch.backends.cudnn.deterministic = True 21 | torch.backends.cudnn.benchmark = False # benchmarking selects algorithms non-deterministically 22 | 23 | if __name__ == '__main__': 24 | setup_seed(seed=0) 25 | 26 | opt = TrainOptions().parse() 27 | dataset_train = create_dataset(opt.dataset_name, 'train', opt) 28 | dataset_size_train = len(dataset_train) 29 | print('The number of training images = %d' % dataset_size_train) 30 | 31 | 32 | 33 | 34 | model = create_model(opt) 35 | model.setup(opt) 36 | visualizer = Visualizer(opt) 37 | total_iters = ((model.start_epoch * (dataset_size_train // opt.batch_size)) \ 38 | // opt.print_freq) * opt.print_freq 39 | total_iters_start = ((model.start_epoch * (dataset_size_train // opt.batch_size)) \ 40 | // opt.print_freq) * opt.print_freq 41 | 42 | for epoch in range(model.start_epoch + 1, opt.niter + opt.niter_decay + 1): 43 | # training 44 | epoch_start_time = time.time() 45 | epoch_iter = 0 46 | model.train() 47 | 48 | iter_data_time = iter_start_time = time.time() 49 | for i, data in enumerate(dataset_train): 50 | if total_iters % opt.print_freq == 0: 51 | t_data = time.time() - iter_data_time 52 | total_iters += 1 53 | epoch_iter += 1 54 | model.set_input(data) 55 | model.optimize_parameters(epoch) 56 | 57 | if total_iters % opt.print_freq == 0 or total_iters == total_iters_start: 58 | losses = model.get_current_losses() 59 | t_comp = (time.time() - iter_start_time) 60 | visualizer.print_current_losses( 61 | epoch, epoch_iter, losses, t_comp, t_data, total_iters) 62 | if opt.save_imgs: # Too many images 63 | visualizer.display_current_results( 64 | 'train', model.get_current_visuals(), total_iters) 65 | iter_start_time = time.time() 66 | 67 | iter_data_time = time.time() 68 | 69 | if epoch % opt.save_epoch_freq == 0: 70 | print('saving the model at the end of epoch %d, iters %d' 71 | % (epoch, total_iters)) 72 | model.save_networks(epoch) 73 | 74 | print('End of epoch %d / %d \t Time Taken: %.3f sec' 75 | % (epoch, opt.niter + opt.niter_decay, 76 | time.time() - epoch_start_time)) 77 | model.update_learning_rate(epoch) 78 | 79 | 80 | 81 | sys.stdout.flush()
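# Usage sketch (illustrative): a minimal launch mirroring train_track1.sh,
# assuming the dataset sits under ./data:
#   python train.py --dataset_name bracketire --model cat --name track1 \
#       --dataroot ./data --batch_size 36 --niter 400 --block Convnext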
82 | -------------------------------------------------------------------------------- /train_track1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo "Start to train the model...." 3 | dataroot="" # including 'Train' and 'NTIRE_Val' folders 4 | 5 | device='' 6 | name="" 7 | 8 | build_dir="./ckpt/"$name 9 | 10 | if [ ! -d "$build_dir" ]; then 11 | mkdir $build_dir 12 | fi 13 | 14 | LOG=./ckpt/$name/`date +%Y-%m-%d-%H-%M-%S`.txt 15 | 16 | python train.py \ 17 | --dataset_name bracketire --model cat --name $name --lr_policy step \ 18 | --patch_size 128 --niter 400 --save_imgs True --lr 1e-4 --dataroot $dataroot \ 19 | --batch_size 36 --print_freq 500 --calc_metrics True --weight_decay 0.01 \ 20 | --gpu_ids $device -j 8 --lr_decay_iters 27 --block Convnext --load_optimizers False | tee $LOG 21 | 22 | 23 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes a miscellaneous collection of helper functions.""" 2 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions """ 2 | from __future__ import print_function 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | import time 8 | from functools import wraps 9 | 10 | import random 11 | 12 | import cv2 13 | 14 | import colour_demosaicing 15 | import glob 16 | 17 | import torch.nn.functional as F 18 | from torch.autograd import Variable 19 | 20 | from math import exp 21 | 22 | 23 | def mu_tonemap(hdr_image, mu=5000): 24 | if isinstance(hdr_image, np.ndarray): 25 | return np.log(1 + mu * hdr_image) / np.log(1 + mu) 26 | elif isinstance(hdr_image, torch.Tensor): 27 | mu = torch.tensor(mu).to(hdr_image.device) 28 | return torch.log(1 + mu * hdr_image) / torch.log(1 + mu) 29 | else: 30 | raise TypeError('hdr_image must be an np.ndarray or a torch.Tensor')
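# Worked example (illustrative): mu-law tonemapping compresses HDR radiance.
# With mu=5000, an input of 1.0 maps to 1.0, while 0.01 maps to
# log(1 + 5000*0.01) / log(1 + 5000) = log(51)/log(5001) ~= 0.46,
# strongly boosting dark regions relative to a linear scale.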
31 | 32 | 33 | def radiance_writer(out_path, image): 34 | with open(out_path, "wb") as f: 35 | f.write(b"#?RADIANCE\n# Made with Python & Numpy\nFORMAT=32-bit_rle_rgbe\n\n") 36 | f.write(b"-Y %d +X %d\n" %(image.shape[0], image.shape[1])) 37 | brightest = np.maximum(np.maximum(image[...,0], image[...,1]), image[...,2]) + 1e-8 38 | mantissa = np.zeros_like(brightest) 39 | exponent = np.zeros_like(brightest) 40 | np.frexp(brightest, mantissa, exponent) 41 | scaled_mantissa = mantissa * 255.0 / brightest 42 | rgbe = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8) 43 | rgbe[...,0:3] = np.around(image[...,0:3] * scaled_mantissa[...,None]) 44 | rgbe[...,3] = np.around(exponent + 128) 45 | 46 | rgbe.flatten().tofile(f) 47 | 48 | 49 | def save_hdr(path, image): 50 | return radiance_writer(path, image) 51 | 52 | 53 | # Decorator: retry the wrapped function up to 600 times, sleeping 1 second between attempts. 54 | # It wraps func directly; the drawback is that func's own signature hints are hidden. 55 | def loop_until_success(func): 56 | @wraps(func) 57 | def wrapper(*args, **kwargs): 58 | for i in range(600): 59 | try: 60 | ret = func(*args, **kwargs) 61 | break 62 | except OSError: 63 | time.sleep(1) 64 | return ret 65 | return wrapper 66 | 67 | # Retry-wrapped versions of print and torch.save 68 | @loop_until_success 69 | def loop_print(*args, **kwargs): 70 | print(*args, **kwargs) 71 | 72 | @loop_until_success 73 | def torch_save(*args, **kwargs): 74 | torch.save(*args, **kwargs) 75 | 76 | def calc_psnr(sr, hr, range=255.): # NOTE: 'range' is unused; diff is normalized by hr.max() 77 | # shave = 2 78 | with torch.no_grad(): 79 | diff = (sr - hr) / hr.max() 80 | # diff = diff[:, :, shave:-shave, shave:-shave] 81 | mse = torch.pow(diff, 2).mean() 82 | return (-10 * torch.log10(mse)).item() 83 | 84 | def diagnose_network(net, name='network'): 85 | """Calculate and print the mean of average absolute gradients 86 | 87 | Parameters: 88 | net (torch network) -- Torch network 89 | name (str) -- the name of the network 90 | """ 91 | mean = 0.0 92 | count = 0 93 | for param in net.parameters(): 94 | if param.grad is not None: 95 | mean += torch.mean(torch.abs(param.grad.data)) 96 | count += 1 97 | if count > 0: 98 | mean = mean / count 99 | print(name) 100 | print(mean) 101 | 102 | def print_numpy(x, val=True, shp=True): 103 | """Print the mean, min, max, median, std, and size of a numpy array 104 | 105 | Parameters: 106 | val (bool) -- whether to print the values of the numpy array 107 | shp (bool) -- whether to print the shape of the numpy array 108 | """ 109 | x = x.astype(np.float64) 110 | if shp: 111 | print('shape,', x.shape) 112 | if val: 113 | x = x.flatten() 114 | print('mean = %3.3f, min = %3.3f, max = %3.3f, mid = %3.3f, std=%3.3f' 115 | % (np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 116 | 117 | def mkdirs(paths): 118 | """create empty directories if they don't exist 119 | 120 | Parameters: 121 | paths (str list) -- a list of directory paths 122 | """ 123 | if isinstance(paths, list) and not isinstance(paths, str): 124 | for path in paths: 125 | mkdir(path) 126 | else: 127 | mkdir(paths) 128 | 129 | def mkdir(path): 130 | """create a single empty directory if it doesn't exist 131 | 132 | Parameters: 133 | path (str) -- a single directory path 134 | """ 135 | if not os.path.exists(path): 136 | os.makedirs(path) 137 | 138 | def prompt(s, width=66): 139 | print('='*(width+4)) 140 | ss = s.split('\n') 141 | if len(ss) == 1 and len(s) <= width: 142 | print('= ' + s.center(width) + ' =') 143 | else: 144 | for s in ss: 145 | for i in split_str(s, width): 146 | print('= ' + i.ljust(width) + ' =') 147 | print('='*(width+4)) 148 | 149 | def split_str(s, width): 150 | ss = [] 151 | while len(s) > width: 152 | idx = s.rfind(' ', 0, width+1) 153 | if idx > width >> 1: 154 | ss.append(s[:idx]) 155 | s = s[idx+1:] 156 | else: 157 | ss.append(s[:width]) 158 | s = s[width:] 159 | if s.strip() != '': 160 | ss.append(s) 161 | return ss 162 | 163 | # def augment_func(img, hflip, vflip, rot90): # CxHxW 164 | # if hflip: img = img[:, :, ::-1] 165 | # if vflip: img = img[:, ::-1, :] 166 | # if rot90: img = img.transpose(0, 2, 1) 167 | # return np.ascontiguousarray(img) 168 | 169 | # def augment(*imgs): # CxHxW 170 | # hflip = random.random() < 0.5 171 | # vflip = random.random() < 0.5 172 | # rot90 = random.random() < 0.5 173 | # return (augment_func(img, hflip, vflip, rot90) for img in imgs) 174 | 175 | def remove_black_level(img, black_lv=63, white_lv=4*255): 176 | img = np.maximum(img.astype(np.float32)-black_lv, 0) / (white_lv-black_lv) 177 | 
return img 178 | 179 | def gamma_correction(img, r=1/2.2): 180 | img = np.maximum(img, 0) 181 | img = np.power(img, r) 182 | return img 183 | 184 | def extract_bayer_channels(raw): # HxW 185 | ch_R = raw[0::2, 0::2] 186 | ch_Gb = raw[0::2, 1::2] 187 | ch_Gr = raw[1::2, 0::2] 188 | ch_B = raw[1::2, 1::2] 189 | raw_combined = np.dstack((ch_B, ch_Gb, ch_R, ch_Gr)) 190 | raw_combined = np.ascontiguousarray(raw_combined.transpose((2, 0, 1))) 191 | return raw_combined # 4xHxW 192 | 193 | def get_raw_demosaic(raw, pattern='RGGB'): # HxW 194 | raw_demosaic = colour_demosaicing.demosaicing_CFA_Bayer_bilinear(raw, pattern=pattern) 195 | raw_demosaic = np.ascontiguousarray(raw_demosaic.astype(np.float32).transpose((2, 0, 1))) 196 | return raw_demosaic # 3xHxW 197 | 198 | def get_coord(H, W, x=448/3968, y=448/2976): 199 | x_coord = np.linspace(-x + (x / W), x - (x / W), W) 200 | x_coord = np.expand_dims(x_coord, axis=0) 201 | x_coord = np.tile(x_coord, (H, 1)) 202 | x_coord = np.expand_dims(x_coord, axis=0) 203 | 204 | y_coord = np.linspace(-y + (y / H), y - (y / H), H) 205 | y_coord = np.expand_dims(y_coord, axis=1) 206 | y_coord = np.tile(y_coord, (1, W)) 207 | y_coord = np.expand_dims(y_coord, axis=0) 208 | 209 | coord = np.ascontiguousarray(np.concatenate([x_coord, y_coord])) 210 | coord = np.float32(coord) 211 | 212 | return coord 213 | 214 | def read_wb(txtfile, key): 215 | wb = np.zeros((1,4)) 216 | with open(txtfile) as f: 217 | for l in f: 218 | if key in l: 219 | for i in range(wb.shape[0]): 220 | nextline = next(f) 221 | try: 222 | wb[i,:] = nextline.split() 223 | except: 224 | print("WB error XXXXXXX") 225 | print(txtfile) 226 | wb = wb.astype(np.float32) 227 | return wb 228 | 229 | 230 | def gaussian(window_size, sigma): 231 | gauss = torch.Tensor([exp(-(x - window_size/2)**2/float(2*sigma**2)) for x in range(window_size)]) 232 | return gauss/gauss.sum() 233 | 234 | def create_window(window_size, channel, device): 235 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1).to(device) 236 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 237 | window = _2D_window.expand(channel, 1, window_size, window_size) 238 | return window 239 | 240 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 241 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 242 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 243 | 244 | mu1_sq = mu1.pow(2) 245 | mu2_sq = mu2.pow(2) 246 | mu1_mu2 = mu1*mu2 247 | 248 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 249 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 250 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 251 | 252 | C1 = 0.01**2 253 | C2 = 0.03**2 254 | 255 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 256 | 257 | if size_average: 258 | return ssim_map.mean() 259 | else: 260 | return ssim_map.mean(1).mean(1).mean(1) 261 | 262 | class SSIM(torch.nn.Module): 263 | def __init__(self, window_size = 11, size_average = True): 264 | super(SSIM, self).__init__() 265 | self.window_size = window_size 266 | self.size_average = size_average 267 | self.channel = 1 268 | # self.window = create_window(window_size, self.channel) 269 | 270 | def forward(self, img1, img2): 271 | (_, channel, _, _) = img1.size() 272 | 273 | window = create_window(self.window_size, channel, img1.device) 
274 | self.window = window 275 | self.channel = channel 276 | 277 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 278 | 279 | def ssim(img1, img2, window_size = 11, size_average = True): 280 | (_, channel, _, _) = img1.size() 281 | 282 | window = create_window(window_size, channel, img1.device) 283 | return _ssim(img1, img2, window, window_size, channel, size_average) -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from os.path import join 3 | from tensorboardX import SummaryWriter 4 | from matplotlib import pyplot as plt 5 | from io import BytesIO 6 | from PIL import Image 7 | from functools import partial 8 | from functools import wraps 9 | import time 10 | 11 | def write_until_success(func): 12 | @wraps(func) 13 | def wrapper(*args, **kwargs): 14 | for i in range(30): 15 | try: 16 | ret = func(*args, **kwargs) 17 | break 18 | except OSError: 19 | print('%s OSError' % str(args)) 20 | time.sleep(1) 21 | return ret 22 | return wrapper 23 | 24 | class Visualizer(): 25 | def __init__(self, opt): 26 | self.opt = opt 27 | if opt.isTrain: 28 | self.name = opt.name 29 | self.save_dir = join(opt.checkpoints_dir, opt.name, 'log') 30 | self.writer = SummaryWriter(logdir=join(self.save_dir)) 31 | else: 32 | self.name = '%s_%s_%d' % ( 33 | opt.name, opt.dataset_name, opt.load_iter) 34 | self.save_dir = join(opt.checkpoints_dir, opt.name) 35 | if opt.save_imgs: 36 | self.writer = SummaryWriter(logdir=join( 37 | self.save_dir, 'ckpts', self.name)) 38 | 39 | @write_until_success 40 | def display_current_results(self, phase, visuals, iters): 41 | for k, v in visuals.items(): 42 | v = v.cpu() 43 | self.writer.add_image('%s/%s'%(phase, k), v[0]/255, iters) 44 | self.writer.flush() 45 | 46 | @write_until_success 47 | def print_current_losses(self, epoch, iters, losses, 48 | t_comp, t_data, total_iters): 49 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' \ 50 | % (epoch, iters, t_comp, t_data) 51 | for k, v in losses.items(): 52 | message += '%s: %.4e ' % (k, v) 53 | self.writer.add_scalar('loss/%s'%k, v, total_iters) 54 | print(message) 55 | 56 | @write_until_success 57 | def print_psnr(self, epoch, total_epoch, time_val, mean_psnr): 58 | self.writer.add_scalar('val/psnr', mean_psnr, epoch) 59 | print('End of epoch %d / %d (Val) \t Time Taken: %.3f s \t PSNR: %f' 60 | % (epoch, total_epoch, time_val, mean_psnr)) 61 | 62 | 63 | --------------------------------------------------------------------------------