├── README.md
├── data
│   ├── NTIRE_Val
│   │   └── test
│   │       └── 000936
│   │           ├── alignratio.npy
│   │           ├── metadata.npy
│   │           ├── raw_1.npy
│   │           ├── raw_16.npy
│   │           ├── raw_256.npy
│   │           ├── raw_4.npy
│   │           ├── raw_64.npy
│   │           ├── rgb_vis_1.png
│   │           ├── rgb_vis_16.png
│   │           ├── rgb_vis_256.png
│   │           ├── rgb_vis_4.png
│   │           └── rgb_vis_64.png
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── bracketire_dataset.py
│   ├── bracketireplus_dataset.py
│   └── degrade
│       ├── degrade_kernel.py
│       ├── process.py
│       └── unprocess.py
├── imgs
│   ├── Overview of CRNet.png
│   ├── multi_pro.png
│   ├── out1.png
│   └── out2.png
├── isp
│   ├── __init__.py
│   ├── demosaic_bayer.py
│   ├── dng_opcode.py
│   ├── exif_data_formats.py
│   ├── exif_utils.py
│   ├── isp.py
│   ├── model.bin
│   ├── pipeline.py
│   ├── pipeline_utils.py
│   └── tone_curve.mat
├── mixer.yaml
├── models
│   ├── __init__.py
│   ├── base_model.py
│   ├── blocks.py
│   ├── cat_model.py
│   ├── degrade
│   │   ├── degrade_kernel.py
│   │   ├── process.py
│   │   └── unprocess.py
│   ├── losses.py
│   └── networks.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── spynet
│   └── spynet_20210409-c6c1bd09.pth
├── test.py
├── test_track1.sh
├── train.py
├── train_track1.sh
└── util
    ├── __init__.py
    ├── util.py
    └── visualizer.py
/README.md:
--------------------------------------------------------------------------------
1 | # CRNet
2 |
3 | PyTorch implementation of CRNet. Our model achieved third place in track 1 of the Bracketing Image Restoration and Enhancement Challenge.
4 | ## 1. Abstract
5 |
6 | It is challenging but highly desired to acquire high-quality photos with clear content in low-light environments. Although multi-image processing methods (using burst, dual-exposure, or multi-exposure images) have made significant progress in addressing this issue, they typically focus exclusively on specific restoration or enhancement tasks and fall short of fully exploiting multi-image information. Motivated by the observation that multi-exposure images are complementary in denoising, deblurring, high dynamic range imaging, and super-resolution, we propose to utilize bracketing photography to unify restoration and enhancement tasks in this work. Due to the difficulty of collecting real-world pairs, we suggest a solution that first pre-trains the model with synthetic paired data and then adapts it to real-world unlabeled images. In particular, a temporally modulated recurrent network (TMRNet) and a self-supervised adaptation method are proposed. Moreover, we construct a data simulation pipeline to synthesize pairs and collect real-world images from 200 nighttime scenarios. Experiments on both datasets show that our method performs favorably against state-of-the-art multi-image processing methods.
7 |
8 | ## 2. Overview of CRNet
9 |
10 | 
11 |
12 | ## 3. Comparison with other methods in track 1 of the Bracketing Image Restoration and Enhancement Challenge
13 |
14 | 
15 |
16 | ## 4. Example Results
17 |
18 | 
19 | 
20 |
21 | ## 5. Checkpoint
22 |
23 | Pre-trained checkpoint (Baidu Netdisk): https://pan.baidu.com/s/17DDXbthjvLjyoHOnUzw-hg?pwd=g1nw
24 |
25 | ## 6. Dataset
26 |
27 | The download link for the dataset is https://github.com/cszhilu1998/BracketIRE.
28 |
--------------------------------------------------------------------------------
/data/NTIRE_Val/test/000936/alignratio.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/alignratio.npy
--------------------------------------------------------------------------------
/data/NTIRE_Val/test/000936/metadata.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/metadata.npy
--------------------------------------------------------------------------------
/data/NTIRE_Val/test/000936/raw_1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/raw_1.npy
--------------------------------------------------------------------------------
/data/NTIRE_Val/test/000936/raw_16.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/raw_16.npy
--------------------------------------------------------------------------------
/data/NTIRE_Val/test/000936/raw_256.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/raw_256.npy
--------------------------------------------------------------------------------
/data/NTIRE_Val/test/000936/raw_4.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/raw_4.npy
--------------------------------------------------------------------------------
/data/NTIRE_Val/test/000936/raw_64.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/raw_64.npy
--------------------------------------------------------------------------------
/data/NTIRE_Val/test/000936/rgb_vis_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/rgb_vis_1.png
--------------------------------------------------------------------------------
/data/NTIRE_Val/test/000936/rgb_vis_16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/rgb_vis_16.png
--------------------------------------------------------------------------------
/data/NTIRE_Val/test/000936/rgb_vis_256.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/rgb_vis_256.png
--------------------------------------------------------------------------------
/data/NTIRE_Val/test/000936/rgb_vis_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/rgb_vis_4.png
--------------------------------------------------------------------------------
/data/NTIRE_Val/test/000936/rgb_vis_64.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/data/NTIRE_Val/test/000936/rgb_vis_64.png
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import torch.utils.data
3 | from data.base_dataset import BaseDataset
4 |
5 |
6 | def find_dataset_using_name(dataset_name, split='train'):
7 | dataset_filename = "data." + dataset_name + "_dataset"
8 | datasetlib = importlib.import_module(dataset_filename)
9 |
10 | dataset = None
11 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
12 | for name, cls in datasetlib.__dict__.items():
13 | if name.lower() == target_dataset_name.lower() \
14 | and issubclass(cls, BaseDataset):
15 | dataset = cls
16 |
17 | if dataset is None:
18 | raise NotImplementedError("In %s.py, there should be a subclass of "
19 | "BaseDataset with class name that matches %s in "
20 | "lowercase." % (dataset_filename, target_dataset_name))
21 | return dataset
22 |
23 |
24 | def create_dataset(dataset_name, split, opt):
25 | data_loader = CustomDatasetDataLoader(dataset_name, split, opt)
26 | dataset = data_loader.load_data()
27 | return dataset
28 |
29 |
30 | class CustomDatasetDataLoader():
31 | def __init__(self, dataset_name, split, opt):
32 | self.opt = opt
33 | dataset_class = find_dataset_using_name(dataset_name, split)
34 | self.dataset = dataset_class(opt, split, dataset_name)
35 | # self.imio = self.dataset.imio
36 | print("dataset [%s(%s)] created" % (dataset_name, split))
37 | self.dataloader = torch.utils.data.DataLoader(
38 | self.dataset,
39 | batch_size=opt.batch_size if split=='train' else 1,
40 | shuffle=opt.shuffle and split=='train',
41 | num_workers=int(opt.num_dataloader),
42 | drop_last=opt.drop_last)
43 |
44 | def load_data(self):
45 | return self
46 |
47 | def __len__(self):
48 | """Return the number of data in the dataset"""
49 | return min(len(self.dataset), self.opt.max_dataset_size)
50 |
51 | def __iter__(self):
52 | """Return a batch of data"""
53 | for i, data in enumerate(self.dataloader):
54 | if i * self.opt.batch_size >= self.opt.max_dataset_size:
55 | break
56 | yield data
57 |
58 |
--------------------------------------------------------------------------------
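
Note on the registration convention in `data/__init__.py`: `find_dataset_using_name` imports `data.<name>_dataset` and selects the `BaseDataset` subclass whose class name equals the dataset name with underscores removed plus `dataset`, compared case-insensitively. A minimal sketch of resolving and iterating a loader; the option values below are illustrative assumptions (the real defaults presumably live under `options/`), and it assumes the sample `NTIRE_Val` data shipped with the repo plus the `cv2`/`tqdm` dependencies are in place:

```python
from types import SimpleNamespace
from data import create_dataset

# Hypothetical options namespace standing in for the parsed command-line options.
opt = SimpleNamespace(dataroot='./data', batch_size=1, shuffle=False,
                      num_dataloader=0, drop_last=False,
                      max_dataset_size=float('inf'),
                      patch_size=128, frame_num=5)

# 'bracketire' -> imports data.bracketire_dataset -> matches class BracketIREDataset
loader = create_dataset('bracketire', 'test', opt)
for batch in loader:
    print(batch['fname'], batch['raws'].shape)  # raws: [1, T=5, 4, H, W]
    break
```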
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from abc import ABC, abstractmethod
3 |
4 |
5 | class BaseDataset(data.Dataset, ABC):
6 | def __init__(self, opt, split, dataset_name):
7 | self.opt = opt
8 | self.split = split
9 | self.root = opt.dataroot
10 | self.dataset_name = dataset_name.lower()
11 |
12 | @abstractmethod
13 | def __len__(self):
14 | return 0
15 |
16 | @abstractmethod
17 | def __getitem__(self, index):
18 | pass
19 |
20 |
--------------------------------------------------------------------------------
/data/bracketire_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import cv2
4 | import torch
5 | import random
6 | from tqdm import tqdm
7 | from os.path import join as opj
8 | from multiprocessing.dummy import Pool
9 | from data.base_dataset import BaseDataset
10 |
11 |
12 | # BracketIRE dataset
13 | class BracketIREDataset(BaseDataset):
14 | def __init__(self, opt, split='train', dataset_name='BracketIRE'):
15 | super(BracketIREDataset, self).__init__(opt, split, dataset_name)
16 |
17 | self.batch_size = opt.batch_size
18 | self.patch_size = opt.patch_size
19 | self.frame_num = opt.frame_num
20 |
21 | if split == 'train':
22 | self._getitem = self._getitem_train
23 | self.names, self.meta_dirs, self.raw_dirs, self.gt_dirs = self._get_image_dir(self.root, split,
24 | name='Train')
25 | self.len_data = 50000 * self.batch_size
26 | elif split == 'test':
27 | self._getitem = self._getitem_test
28 | self.names, self.meta_dirs, self.raw_dirs, self.gt_dirs = self._get_image_dir(self.root, split,
29 | name='NTIRE_Val')
30 | self.len_data = len(self.names)
31 | self.meta_data = [0] * len(self.names)
32 | self.raw_images = [0] * len(self.names)
33 | self.gt_images = [0] * len(self.names)
34 | read_images(self)
35 | else:
36 | raise ValueError
37 |
38 | self.split = split
39 |
40 |
41 | def __getitem__(self, index):
42 | return self._getitem(index)
43 |
44 | def __len__(self):
45 | return self.len_data
46 |
47 | def _getitem_train(self, idx):
48 | idx = idx % len(self.names)
49 |
50 | meta_data = np.load(self.meta_dirs[idx], allow_pickle=True)
51 | gt_images = np.load(self.gt_dirs[idx], allow_pickle=True).transpose(2, 0, 1)
52 | imgs = []
53 | for m in range(self.frame_num):
54 | imgs.append(np.load(self.raw_dirs[idx][m], allow_pickle=True).transpose(2, 0, 1))
55 | raw_images = imgs
56 |
57 | raws = torch.from_numpy(np.float32(np.array(raw_images))) / (2 ** 10 - 1)
58 | gt = torch.from_numpy(np.float32(gt_images))
59 |
60 | raws, gt = self._crop_patch(raws, gt, self.patch_size)
61 |
62 | return {'gt': gt, # [4, H, W]
63 | 'raws': raws, # [T=5, 4, H, W]
64 | 'fname': self.names[idx]}
65 |
66 | def _getitem_test(self, idx):
67 | raws = torch.from_numpy(np.float32(np.array(self.raw_images[idx]))) / (2 ** 10 - 1)
68 | meta = self._process_metadata(self.meta_data[idx])
69 |
70 | return {'meta': meta,
71 | 'gt': raws[0],
72 | 'raws': raws,
73 | 'fname': self.names[idx]}
74 |
75 | def _crop_patch(self, raws, gt, p):
76 | ih, iw = raws.shape[-2:]
77 | ph = random.randrange(10, ih - p + 1 - 10)
78 | pw = random.randrange(10, iw - p + 1 - 10)
79 | return raws[..., ph:ph + p, pw:pw + p], \
80 | gt[..., ph:ph + p, pw:pw + p]
81 |
82 | def _process_metadata(self, metadata):
83 | metadata_item = metadata.item()
84 | meta = {}
85 | for key in metadata_item:
86 | meta[key] = torch.from_numpy(metadata_item[key])
87 | return meta
88 |
89 | def _read_raw_path(self, root):
90 | img_paths = []
91 | for expo in range(self.frame_num):
92 | img_paths.append(opj(root, 'raw_' + str(4 ** expo) + '.npy'))
93 | return img_paths
94 |
95 | def _get_image_dir(self, dataroot, split=None, name=None):
96 | image_names = []
97 | meta_dirs = []
98 | raw_dirs = []
99 | gt_dirs = []
100 |
101 | for scene_file in sorted(os.listdir(opj(dataroot, name))):
102 | for image_file in sorted(os.listdir(opj(dataroot, name, scene_file))):
103 | image_root = opj(dataroot, name, scene_file, image_file)
104 | image_names.append(scene_file + '-' + image_file)
105 | meta_dirs.append(opj(image_root, 'metadata.npy'))
106 | raw_dirs.append(self._read_raw_path(image_root))
107 | if split == 'train':
108 | gt_dirs.append(opj(image_root, 'raw_gt.npy'))
109 | elif split == 'test':
110 | gt_dirs = []
111 |
112 | return image_names, meta_dirs, raw_dirs, gt_dirs
113 |
114 |
115 | def iter_obj(num, objs):
116 | for i in range(num):
117 | yield (i, objs)
118 |
119 |
120 | def imreader(arg):
121 | i, obj = arg
122 | for _ in range(3):
123 | try:
124 | imgs = []
125 | for m in range(obj.frame_num):
126 | imgs.append(np.load(obj.raw_dirs[i][m], allow_pickle=True).transpose(2, 0, 1))
127 | obj.raw_images[i] = imgs
128 | if obj.split == 'train':
129 | obj.gt_images[i] = np.load(obj.gt_dirs[i], allow_pickle=True).transpose(2, 0, 1)
130 | obj.meta_data[i] = np.load(obj.meta_dirs[i], allow_pickle=True)
131 | failed = False
132 | break
133 | except Exception:
134 | failed = True
135 | if failed: print('%s fails!' % obj.names[i])
136 |
137 |
138 | def read_images(obj):
139 | # may use `from multiprocessing import Pool` instead, but it is less efficient, and
140 | # NOTE: `multiprocessing.Pool` will duplicate the given object for each process.
141 | print('Starting to load images via multiple imreaders')
142 | pool = Pool() # use all threads by default
143 | for _ in tqdm(pool.imap(imreader, iter_obj(len(obj.names), obj)), total=len(obj.names)):
144 | pass
145 | pool.close()
146 | pool.join()
147 |
148 |
149 | if __name__ == '__main__':
150 | pass
151 |
--------------------------------------------------------------------------------
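
The five inputs per scene follow a geometric exposure ladder: `_read_raw_path` builds paths `raw_{4**i}.npy` for i = 0..4 (raw_1 through raw_256, matching the sample scene above), and raw values are 10-bit, so the loader divides by 2**10 - 1. A standalone sketch of that convention, with a toy array in place of real data:

```python
import numpy as np

frame_num = 5
names = ['raw_%d.npy' % 4 ** i for i in range(frame_num)]
print(names)  # ['raw_1.npy', 'raw_4.npy', 'raw_16.npy', 'raw_64.npy', 'raw_256.npy']

# 10-bit raw -> float in [0, 1], matching `raws / (2 ** 10 - 1)` in the dataset
raw = np.array([0, 512, 1023], dtype=np.float32)
print(raw / (2 ** 10 - 1))  # [0.0, ~0.5, 1.0]
```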
/data/bracketireplus_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import cv2
4 | import torch
5 | import random
6 | from tqdm import tqdm
7 | from os.path import join as opj
8 | from multiprocessing.dummy import Pool
9 | from data.base_dataset import BaseDataset
10 |
11 |
12 | # BracketIRE+ dataset
13 | class BracketIREPlusDataset(BaseDataset):
14 | def __init__(self, opt, split='train', dataset_name='BracketIREPlus'):
15 | super(BracketIREPlusDataset, self).__init__(opt, split, dataset_name)
16 |
17 | self.batch_size = opt.batch_size
18 | self.patch_size = opt.patch_size
19 | self.frame_num = opt.frame_num
20 | self.scale = 4
21 |
22 | if split == 'train':
23 | self._getitem = self._getitem_train
24 | self.names, self.meta_dirs, self.raw_dirs, self.gt_dirs = self._get_image_dir(self.root, split, name='Train')
25 | self.len_data = 500 * self.batch_size
26 | elif split == 'test':
27 | self._getitem = self._getitem_test
28 | self.names, self.meta_dirs, self.raw_dirs, self.gt_dirs = self._get_image_dir(self.root, split, name='NTIRE_Val')
29 | self.len_data = len(self.names)
30 | else:
31 | raise ValueError
32 |
33 | self.meta_data = [0] * len(self.names)
34 | self.raw_images = [0] * len(self.names)
35 | self.gt_images = [0] * len(self.names)
36 | read_images(self)
37 |
38 | def __getitem__(self, index):
39 | return self._getitem(index)
40 |
41 | def __len__(self):
42 | return self.len_data
43 |
44 | def _getitem_train(self, idx):
45 | idx = idx % len(self.names)
46 |
47 | raws = torch.from_numpy(np.float32(np.array(self.raw_images[idx]))) / (2**10 - 1)
48 | gt = torch.from_numpy(np.float32(self.gt_images[idx]))
49 |
50 | raws, gt = self._crop_patch(raws, gt, self.patch_size, self.scale)
51 |
52 | return {'gt': gt, # [4, H, W]
53 | 'raws': raws, # [T=5, 4, H, W]
54 | 'fname': self.names[idx]}
55 |
56 | def _getitem_test(self, idx):
57 | raws = torch.from_numpy(np.float32(np.array(self.raw_images[idx]))) / (2**10 - 1)
58 | meta = self._process_metadata(self.meta_data[idx])
59 |
60 | return {'meta': meta,
61 | 'gt': raws[0],
62 | 'raws': raws,
63 | 'fname': self.names[idx]}
64 |
65 | def _crop_patch(self, raws, gt, p, s):
66 | ih, iw = raws.shape[-2:]
67 | ph = random.randrange(4, ih - p + 1 - 4)
68 | pw = random.randrange(4, iw - p + 1 - 4)
69 | return raws[..., ph:ph+p, pw:pw+p], \
70 | gt[..., ph*s:(ph+p)*s, pw*s:(pw+p)*s]
71 |
72 | def _process_metadata(self, metadata):
73 | metadata_item = metadata.item()
74 | meta = {}
75 | for key in metadata_item:
76 | meta[key] = torch.from_numpy(metadata_item[key])
77 | return meta
78 |
79 | def _read_raw_path(self, root):
80 | img_paths = []
81 | for expo in range(self.frame_num):
82 | img_paths.append(opj(root, 'x'+str(self.scale), 'raw_' + str(4**expo) + '.npy'))
83 | return img_paths
84 |
85 | def _get_image_dir(self, dataroot, split=None, name=None):
86 | image_names = []
87 | meta_dirs = []
88 | raw_dirs = []
89 | gt_dirs = []
90 |
91 | for scene_file in sorted(os.listdir(opj(dataroot, name))):
92 | for image_file in sorted(os.listdir(opj(dataroot, name, scene_file))):
93 | image_root = opj(dataroot, name, scene_file, image_file)
94 | image_names.append(scene_file + '-' + image_file)
95 | meta_dirs.append(opj(image_root, 'metadata.npy'))
96 | raw_dirs.append(self._read_raw_path(image_root))
97 | if split == 'train':
98 | gt_dirs.append(opj(image_root, 'raw_gt.npy'))
99 | elif split == 'test':
100 | gt_dirs = []
101 |
102 | return image_names, meta_dirs, raw_dirs, gt_dirs
103 |
104 |
105 | def iter_obj(num, objs):
106 | for i in range(num):
107 | yield (i, objs)
108 |
109 | def imreader(arg):
110 | i, obj = arg
111 | for _ in range(3):
112 | try:
113 | imgs = []
114 | for m in range(obj.frame_num):
115 | imgs.append(np.load(obj.raw_dirs[i][m], allow_pickle=True).transpose(2, 0, 1))
116 | obj.raw_images[i] = imgs
117 | if obj.split == 'train':
118 | obj.gt_images[i] = np.load(obj.gt_dirs[i], allow_pickle=True).transpose(2, 0, 1)
119 | obj.meta_data[i] = np.load(obj.meta_dirs[i], allow_pickle=True)
120 | failed = False
121 | break
122 | except Exception:
123 | failed = True
124 | if failed: print('%s fails!' % obj.names[i])
125 |
126 | def read_images(obj):
127 | # may use `from multiprocessing import Pool` instead, but it is less efficient, and
128 | # NOTE: `multiprocessing.Pool` will duplicate the given object for each process.
129 | print('Starting to load images via multiple imreaders')
130 | pool = Pool() # use all threads by default
131 | for _ in tqdm(pool.imap(imreader, iter_obj(len(obj.names), obj)), total=len(obj.names)):
132 | pass
133 | pool.close()
134 | pool.join()
135 |
136 | if __name__ == '__main__':
137 | pass
138 |
--------------------------------------------------------------------------------
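
The substantive difference from the BracketIRE loader is the x4 super-resolution target: `_crop_patch` samples a patch position on the low-resolution raws and multiplies both edges by the scale `s` to take the corresponding high-resolution ground-truth crop. A small sketch of that index arithmetic, assuming random tensors in place of real data:

```python
import random
import torch

p, s = 32, 4                        # LR patch size and scale (self.scale = 4)
raws = torch.rand(5, 4, 128, 128)   # [T, C, h, w] low-resolution inputs
gt = torch.rand(4, 512, 512)        # [C, h*s, w*s] high-resolution target

ph = random.randrange(4, 128 - p + 1 - 4)   # 4-pixel safety margin, as in the dataset
pw = random.randrange(4, 128 - p + 1 - 4)
lr = raws[..., ph:ph + p, pw:pw + p]                     # [5, 4, 32, 32]
hr = gt[..., ph * s:(ph + p) * s, pw * s:(pw + p) * s]   # [4, 128, 128], same region
print(lr.shape, hr.shape)
```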
/data/degrade/degrade_kernel.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import random
4 | import torch
5 | from scipy import ndimage
6 | from scipy.interpolate import interp2d
7 | from .unprocess import unprocess, random_noise_levels, add_noise
8 | from .process import process
9 | from PIL import Image
10 |
11 |
12 |
13 | # def get_rgb2raw2rgb(img):
14 | # img = torch.from_numpy(np.array(img)) / 255.0
15 | # deg_img, features = unprocess(img)
16 | # shot_noise, read_noise = random_noise_levels()
17 | # deg_img = add_noise(deg_img, shot_noise, read_noise)
18 | # deg_img = deg_img.unsqueeze(0)
19 | # features['red_gain'] = features['red_gain'].unsqueeze(0)
20 | # features['blue_gain'] = features['blue_gain'].unsqueeze(0)
21 | # features['cam2rgb'] = features['cam2rgb'].unsqueeze(0)
22 | # deg_img = process(deg_img, features['red_gain'], features['blue_gain'], features['cam2rgb'])
23 | # deg_img = deg_img.squeeze(0)
24 | # deg_img = torch.clamp(deg_img * 255.0, 0.0, 255.0).numpy()
25 | # deg_img = deg_img.astype(np.uint8)
26 | # return Image.fromarray(deg_img)
27 |
28 |
29 | # def get_rgb2raw_noise(img, noise_level, features=None):
30 | # # img = np.transpose(img, (1, 2, 0))
31 | # img = torch.from_numpy(np.array(img)) / 255.0
32 |
33 | # deg_img, features = unprocess(img, features)
34 | # shot_noise, read_noise = random_noise_levels(noise_level)
35 | # deg_img_noise = add_noise(deg_img, shot_noise, read_noise)
36 | # # deg_img_noise = torch.clamp(deg_img_noise, min=0.0, max=1.0)
37 |
38 | # # deg_img = np.transpose(deg_img, (2, 0, 1))
39 | # # deg_img_noise = np.transpose(deg_img_noise, (2, 0, 1))
40 | # return deg_img_noise, features
41 |
42 |
43 | def get_rgb2raw(img, features=None):
44 | # img = np.transpose(img, (1, 2, 0))
45 | device = img.device
46 | deg_img, features = unprocess(img, features, device)
47 | return deg_img, features
48 |
49 |
50 | def get_raw2rgb(img, features, demosaic='net', lineRGB=False):
51 | # img = np.transpose(img, (1, 2, 0))
52 | # img = torch.from_numpy(np.array(img))
53 | img = img.unsqueeze(0)
54 | device = img.device
55 | deg_img = process(img, features['red_gain'].to(device), features['blue_gain'].to(device),
56 | features['cam2rgb'].to(device), demosaic, lineRGB)
57 | deg_img = deg_img.squeeze(0)
58 | # deg_img = torch.clamp(deg_img * 255.0, 0.0, 255.0).numpy()
59 | # deg_img = deg_img.astype(np.uint8)
60 | return deg_img
61 |
62 |
63 | # def pack_raw_image(im_raw): # HxW
64 | # """ Packs a single channel bayer image into 4 channel tensor, where channels contain R, G, G, and B values"""
65 | # if isinstance(im_raw, np.ndarray):
66 | # im_out = np.zeros_like(im_raw, shape=(4, im_raw.shape[0] // 2, im_raw.shape[1] // 2))
67 | # elif isinstance(im_raw, torch.Tensor):
68 | # im_out = torch.zeros((4, im_raw.shape[0] // 2, im_raw.shape[1] // 2), dtype=im_raw.dtype)
69 | # else:
70 | # raise Exception
71 |
72 | # im_out[0, :, :] = im_raw[0::2, 0::2]
73 | # im_out[1, :, :] = im_raw[0::2, 1::2]
74 | # im_out[2, :, :] = im_raw[1::2, 0::2]
75 | # im_out[3, :, :] = im_raw[1::2, 1::2]
76 | # return im_out # 4xHxW
77 |
78 |
79 | # def flatten_raw_image(im_raw_4ch): # 4xHxW
80 | # """ unpack a 4-channel tensor into a single channel bayer image"""
81 | # if isinstance(im_raw_4ch, np.ndarray):
82 | # im_out = np.zeros_like(im_raw_4ch, shape=(im_raw_4ch.shape[1] * 2, im_raw_4ch.shape[2] * 2))
83 | # elif isinstance(im_raw_4ch, torch.Tensor):
84 | # im_out = torch.zeros((im_raw_4ch.shape[1] * 2, im_raw_4ch.shape[2] * 2), dtype=im_raw_4ch.dtype)
85 | # else:
86 | # raise Exception
87 |
88 | # im_out[0::2, 0::2] = im_raw_4ch[0, :, :]
89 | # im_out[0::2, 1::2] = im_raw_4ch[1, :, :]
90 | # im_out[1::2, 0::2] = im_raw_4ch[2, :, :]
91 | # im_out[1::2, 1::2] = im_raw_4ch[3, :, :]
92 |
93 | # return im_out # HxW
--------------------------------------------------------------------------------
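
The two active helpers wrap the Google "unprocessing" pipeline: `get_rgb2raw` turns an sRGB tensor into a packed RGGB mosaic plus the camera metadata needed to invert the process later, and `get_raw2rgb` runs the forward ISP. A minimal CPU sketch of the raw-synthesis direction with a random input (run from the repo root; importing this module pulls in `process.py`, which needs the `colour_demosaicing` package installed):

```python
import torch
from data.degrade.degrade_kernel import get_rgb2raw

rgb = torch.rand(128, 128, 3)                # HxWx3 sRGB in [0, 1]; H, W must be even
raw, meta = get_rgb2raw(rgb, features=None)  # random CCM and gains drawn internally
print(raw.shape)        # torch.Size([64, 64, 4]), packed RGGB planes
print(sorted(meta.keys()))  # ['blue_gain', 'cam2rgb', 'red_gain', 'rgb2cam', 'rgb_gain']
```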
/data/degrade/process.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Forward processing of raw data to sRGB images.
17 |
18 | Unprocessing Images for Learned Raw Denoising
19 | http://timothybrooks.com/tech/unprocessing
20 | """
21 |
22 | import numpy as np
23 | import torch
24 | import torch.nn as nn
25 | import torch.distributions as tdist
26 | from colour_demosaicing import demosaicing_CFA_Bayer_Menon2007
27 | import os
28 | from isp import demosaic_bayer
29 |
30 |
31 | def apply_gains(bayer_images, red_gains, blue_gains):
32 | """Applies white balance gains to a batch of Bayer images."""
33 | red_gains = red_gains.squeeze(1)
34 | blue_gains= blue_gains.squeeze(1)
35 | green_gains = torch.ones_like(red_gains)
36 | gains = torch.stack([red_gains, green_gains, green_gains, blue_gains], dim=-1)
37 | gains = gains[:, None, None, :]
38 | # print(bayer_images.shape, gains.shape)
39 | outs = bayer_images * gains
40 | return outs
41 |
42 |
43 | def demosaic(bayer_images):
44 | def SpaceToDepth_fact2(x):
45 | # From here - https://discuss.pytorch.org/t/is-there-any-layer-like-tensorflows-space-to-depth-function/3487/14
46 | bs = 2
47 | N, C, H, W = x.size()
48 | x = x.view(N, C, H // bs, bs, W // bs, bs) # (N, C, H//bs, bs, W//bs, bs)
49 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
50 | x = x.view(N, C * (bs ** 2), H // bs, W // bs) # (N, C*bs^2, H//bs, W//bs)
51 | return x
52 | def DepthToSpace_fact2(x):
53 | # From here - https://discuss.pytorch.org/t/is-there-any-layer-like-tensorflows-space-to-depth-function/3487/14
54 | bs = 2
55 | N, C, H, W = x.size()
56 | x = x.view(N, bs, bs, C // (bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W)
57 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs)
58 | x = x.view(N, C // (bs ** 2), H * bs, W * bs) # (N, C//bs^2, H * bs, W * bs)
59 | return x
60 |
61 | """Bilinearly demosaics a batch of RGGB Bayer images."""
62 |
63 | shape = bayer_images.size()
64 | shape = [shape[1] * 2, shape[2] * 2]
65 |
66 | red = bayer_images[Ellipsis, 0:1]
67 | upsamplebyX = nn.Upsample(size=shape, mode='bilinear', align_corners=False)
68 | red = upsamplebyX(red.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
69 |
70 | green_red = bayer_images[Ellipsis, 1:2]
71 | green_red = torch.flip(green_red, dims=[1]) # Flip left-right
72 | green_red = upsamplebyX(green_red.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
73 | green_red = torch.flip(green_red, dims=[1]) # Flip left-right
74 | green_red = SpaceToDepth_fact2(green_red.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
75 |
76 | green_blue = bayer_images[Ellipsis, 2:3]
77 | green_blue = torch.flip(green_blue, dims=[0]) # Flip up-down
78 | green_blue = upsamplebyX(green_blue.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
79 | green_blue = torch.flip(green_blue, dims=[0]) # Flip up-down
80 | green_blue = SpaceToDepth_fact2(green_blue.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
81 |
82 | green_at_red = (green_red[Ellipsis, 0] + green_blue[Ellipsis, 0]) / 2
83 | green_at_green_red = green_red[Ellipsis, 1]
84 | green_at_green_blue = green_blue[Ellipsis, 2]
85 | green_at_blue = (green_red[Ellipsis, 3] + green_blue[Ellipsis, 3]) / 2
86 |
87 | green_planes = [
88 | green_at_red, green_at_green_red, green_at_green_blue, green_at_blue
89 | ]
90 | green = DepthToSpace_fact2(torch.stack(green_planes, dim=-1).permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
91 |
92 | blue = bayer_images[Ellipsis, 3:4]
93 | blue = torch.flip(torch.flip(blue, dims=[1]), dims=[0])
94 | blue = upsamplebyX(blue.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
95 | blue = torch.flip(torch.flip(blue, dims=[1]), dims=[0])
96 |
97 | rgb_images = torch.cat([red, green, blue], dim=-1)
98 | return rgb_images
99 |
100 |
101 | def apply_ccms(images, ccms):
102 | """Applies color correction matrices."""
103 | images = images[:, :, :, None, :]
104 | ccms = ccms[:, None, None, :, :]
105 | outs = torch.sum(images * ccms, dim=-1)
106 | return outs
107 |
108 |
109 | def gamma_compression(images, gamma=2.2):
110 | """Converts from linear to gamma space."""
111 | # Clamps to prevent numerical instability of gradients near zero.
112 | Mask = lambda x: (x>0.0031308).float()
113 | sRGBDeLinearize = lambda x,m: m * (1.055 * (m * x) ** (1/2.4) - 0.055) + (1-m) * (12.92 * x)
114 | return sRGBDeLinearize(images, Mask(images))
115 | # outs = torch.clamp(images, min=1e-8) ** (1.0 / gamma)
116 | # return outs
117 |
118 |
119 | def process(bayer_images, red_gains, blue_gains, cam2rgbs, demosaic_type, lineRGB):
120 | # print(bayer_images.shape, red_gains.shape, cam2rgbs.shape)
121 | """Processes a batch of Bayer RGGB images into sRGB images."""
122 | # White balance.
123 | bayer_images = apply_gains(bayer_images, red_gains, blue_gains)
124 | # Demosaic.
125 | bayer_images = torch.clamp(bayer_images, min=0.0, max=1.0)
126 |
127 | if demosaic_type == 'default':
128 | images = demosaic(bayer_images)
129 | elif demosaic_type == 'menon2007':
130 | # print(bayer_images.size())
131 | bayer_images = flatten_raw_image(bayer_images.squeeze(0))
132 | images = demosaicing_CFA_Bayer_Menon2007(bayer_images.cpu().numpy(), 'RGGB')
133 | images = torch.from_numpy(images).unsqueeze(0).to(red_gains.device)
134 | elif demosaic_type == 'net':
135 | bayer_images = flatten_raw_image(bayer_images.squeeze(0)).cpu().numpy()
136 | bayer = np.power(np.clip(bayer_images.astype(dtype=np.float32), 0, 1), 1 / 2.2)
137 | pretrained_model_path = "./isp/model.bin"
138 | demosaic_net = demosaic_bayer.get_demosaic_net_model(pretrained=pretrained_model_path, device=red_gains.device,
139 | cfa='bayer', state_dict=True)
140 | rgb = demosaic_bayer.demosaic_by_demosaic_net(bayer=bayer, cfa='RGGB',
141 | demosaic_net=demosaic_net, device=red_gains.device)
142 | images = np.power(np.clip(rgb, 0, 1), 2.2)
143 | images = torch.from_numpy(images).unsqueeze(0).to(red_gains.device)
144 |
145 | # Color correction.
146 | images = apply_ccms(images, cam2rgbs)
147 | # Gamma compression.
148 | images = torch.clamp(images, min=0.0, max=1.0)
149 | if not lineRGB:
150 | images = gamma_compression(images)
151 | return images
152 |
153 |
154 | def flatten_raw_image(im_raw_4ch): # HxWx4
155 | """ unpack a 4-channel tensor into a single channel bayer image"""
156 | if isinstance(im_raw_4ch, np.ndarray):
157 | im_out = np.zeros_like(im_raw_4ch, shape=(im_raw_4ch.shape[0] * 2, im_raw_4ch.shape[1] * 2))
158 | elif isinstance(im_raw_4ch, torch.Tensor):
159 | im_out = torch.zeros((im_raw_4ch.shape[0] * 2, im_raw_4ch.shape[1] * 2), dtype=im_raw_4ch.dtype)
160 | else:
161 | raise Exception
162 |
163 | im_out[0::2, 0::2] = im_raw_4ch[:, :, 0]
164 | im_out[0::2, 1::2] = im_raw_4ch[:, :, 1]
165 | im_out[1::2, 0::2] = im_raw_4ch[:, :, 2]
166 | im_out[1::2, 1::2] = im_raw_4ch[:, :, 3]
167 |
168 | return im_out # HxW
--------------------------------------------------------------------------------
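
`flatten_raw_image` is the inverse of the channel packing used throughout the repo: plane 0 goes to the even-row/even-column sites, plane 1 to even/odd, plane 2 to odd/even, and plane 3 to odd/odd, i.e. an RGGB Bayer layout. A tiny sketch of that layout (run from the repo root, with the module's `colour_demosaicing` dependency installed):

```python
import torch
from data.degrade.process import flatten_raw_image

packed = torch.arange(16, dtype=torch.float32).reshape(2, 2, 4)  # HxWx4
bayer = flatten_raw_image(packed)                                # -> 2H x 2W mosaic
assert bayer.shape == (4, 4)
# For each 2x2 cell: R at (0,0), Gr at (0,1), Gb at (1,0), B at (1,1)
assert bayer[0, 0] == packed[0, 0, 0] and bayer[1, 1] == packed[0, 0, 3]
```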
/data/degrade/unprocess.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Unprocesses sRGB images into realistic raw data.
17 |
18 | Unprocessing Images for Learned Raw Denoising
19 | http://timothybrooks.com/tech/unprocessing
20 | """
21 |
22 | import numpy as np
23 | import torch
24 | import torch.distributions as tdist
25 |
26 |
27 | def random_ccm(device):
28 | """Generates random RGB -> Camera color correction matrices."""
29 | # Takes a random convex combination of XYZ -> Camera CCMs.
30 | xyz2cams = [[[1.0234, -0.2969, -0.2266],
31 | [-0.5625, 1.6328, -0.0469],
32 | [-0.0703, 0.2188, 0.6406]],
33 | [[0.4913, -0.0541, -0.0202],
34 | [-0.613, 1.3513, 0.2906],
35 | [-0.1564, 0.2151, 0.7183]],
36 | [[0.838, -0.263, -0.0639],
37 | [-0.2887, 1.0725, 0.2496],
38 | [-0.0627, 0.1427, 0.5438]],
39 | [[0.6596, -0.2079, -0.0562],
40 | [-0.4782, 1.3016, 0.1933],
41 | [-0.097, 0.1581, 0.5181]]]
42 | num_ccms = len(xyz2cams)
43 | xyz2cams = torch.FloatTensor(xyz2cams)
44 | weights = torch.FloatTensor(num_ccms, 1, 1).uniform_(1e-8, 1e8)
45 | weights_sum = torch.sum(weights, dim=0)
46 | xyz2cam = torch.sum(xyz2cams * weights, dim=0) / weights_sum
47 |
48 | # Multiplies with RGB -> XYZ to get RGB -> Camera CCM.
49 | rgb2xyz = torch.FloatTensor([[0.4124564, 0.3575761, 0.1804375],
50 | [0.2126729, 0.7151522, 0.0721750],
51 | [0.0193339, 0.1191920, 0.9503041]])
52 | rgb2cam = torch.mm(xyz2cam.to(device), rgb2xyz.to(device))
53 |
54 | # Normalizes each row.
55 | rgb2cam = rgb2cam / torch.sum(rgb2cam, dim=-1, keepdim=True)
56 | return rgb2cam
57 |
58 |
59 | def random_gains(device):
60 | """Generates random gains for brightening and white balance."""
61 | # RGB gain represents brightening.
62 | n = tdist.Normal(loc=torch.tensor([0.8]), scale=torch.tensor([0.1]))
63 | rgb_gain = 1.0 / n.sample()
64 |
65 | # Red and blue gains represent white balance.
66 | red_gain = torch.FloatTensor(1).uniform_(1.9, 2.4)
67 | blue_gain = torch.FloatTensor(1).uniform_(1.5, 1.9)
68 | return rgb_gain.to(device), red_gain.to(device), blue_gain.to(device)
69 |
70 |
71 | def inverse_smoothstep(image):
72 | """Approximately inverts a global tone mapping curve."""
73 | image = torch.clamp(image, min=0.0, max=1.0)
74 | out = 0.5 - torch.sin(torch.asin(1.0 - 2.0 * image) / 3.0)
75 | return out
76 |
77 |
78 | def gamma_expansion(image):
79 | """Converts from gamma to linear space."""
80 | # Clamps to prevent numerical instability of gradients near zero.
81 | Mask = lambda x: (x>0.04045).float()
82 | sRGBLinearize = lambda x,m: m * ((m * x + 0.055) / 1.055) ** 2.4 + (1-m) * (x / 12.92)
83 | return sRGBLinearize(image, Mask(image))
84 | # out = torch.clamp(image, min=1e-8) ** 2.2
85 | # return out
86 |
87 |
88 | def apply_ccm(image, ccm):
89 | """Applies a color correction matrix."""
90 | shape = image.size()
91 | image = torch.reshape(image, [-1, 3])
92 | image = torch.tensordot(image, ccm, dims=[[-1], [-1]])
93 | out = torch.reshape(image, shape)
94 | return out
95 |
96 |
97 | def safe_invert_gains(image, rgb_gain, red_gain, blue_gain, device):
98 | """Inverts gains while safely handling saturated pixels."""
99 | gains = torch.stack((1.0 / red_gain, torch.tensor([1.0]).to(device), 1.0 / blue_gain)) # / rgb_gain
100 | gains = gains.to(device).squeeze()
101 | gains = gains[None, None, :]
102 | # Prevents dimming of saturated pixels by smoothly masking gains near white.
103 | gray = torch.mean(image, dim=-1, keepdim=True)
104 | inflection = 0.9
105 | mask = (torch.clamp(gray - inflection, min=0.0) / (1.0 - inflection)) ** 2.0
106 | safe_gains = torch.max(mask + (1.0 - mask) * gains, gains)
107 | out = image * safe_gains
108 | return out
109 |
110 |
111 | def mosaic(image):
112 | """Extracts RGGB Bayer planes from an RGB image."""
113 | shape = image.size()
114 | red = image[0::2, 0::2, 0]
115 | green_red = image[0::2, 1::2, 1]
116 | green_blue = image[1::2, 0::2, 1]
117 | blue = image[1::2, 1::2, 2]
118 | out = torch.stack((red, green_red, green_blue, blue), dim=-1)
119 | out = torch.reshape(out, (shape[0] // 2, shape[1] // 2, 4))
120 | return out
121 |
122 |
123 | def unprocess(image, features=None, device=None):
124 | """Unprocesses an image from sRGB to realistic raw data."""
125 |
126 | if features is None:
127 | # Randomly creates image metadata.
128 | rgb2cam = random_ccm(device)
129 | cam2rgb = torch.inverse(rgb2cam)
130 | rgb_gain, red_gain, blue_gain = random_gains(device)
131 | else:
132 | rgb2cam = features['rgb2cam']
133 | cam2rgb = features['cam2rgb']
134 | rgb_gain = features['rgb_gain']
135 | red_gain = features['red_gain']
136 | blue_gain = features['blue_gain']
137 | # Approximately inverts global tone mapping.
138 | # image = inverse_smoothstep(image)
139 | # Inverts gamma compression.
140 | image = gamma_expansion(image)
141 | # Inverts color correction.
142 | image = apply_ccm(image, rgb2cam)
143 | # Approximately inverts white balance and brightening.
144 | image = safe_invert_gains(image, rgb_gain, red_gain, blue_gain, device)
145 | # Clips saturated pixels.
146 | # image = torch.clamp(image, min=0.0, max=1.0)
147 | # Applies a Bayer mosaic.
148 | image = mosaic(image)
149 |
150 | metadata = {
151 | 'rgb2cam': rgb2cam,
152 | 'cam2rgb': cam2rgb,
153 | 'rgb_gain': rgb_gain,
154 | 'red_gain': red_gain,
155 | 'blue_gain': blue_gain,
156 | }
157 | return image, metadata
158 |
159 |
160 | # ############### If the target dataset is DND, use this function #####################
161 | # def random_noise_levels():
162 | # """Generates random noise levels from a log-log linear distribution."""
163 | # log_min_shot_noise = np.log(0.0001)
164 | # log_max_shot_noise = np.log(0.012)
165 | # log_shot_noise = torch.FloatTensor(1).uniform_(log_min_shot_noise, log_max_shot_noise)
166 | # shot_noise = torch.exp(log_shot_noise)
167 |
168 | # line = lambda x: 2.18 * x + 1.20
169 | # n = tdist.Normal(loc=torch.tensor([0.0]), scale=torch.tensor([0.26]))
170 | # log_read_noise = line(log_shot_noise) + n.sample()
171 | # read_noise = torch.exp(log_read_noise)
172 | # return shot_noise, read_noise
173 |
174 |
175 | def add_noise(image, shot_noise, read_noise):
176 | var = image * shot_noise + read_noise
177 | noise = tdist.Normal(loc=torch.zeros_like(var), scale=torch.sqrt(var)).sample()
178 | out = image + noise
179 | return out
180 |
181 |
182 | ################ If the target dataset is SIDD, use this function #####################
183 | def random_noise_levels(noise_level):
184 | """ Where read_noise in SIDD is not 0 """
185 | log_min_shot_noise = torch.log(torch.tensor(0.0012)).to(noise_level.device)
186 | log_max_shot_noise = torch.log(torch.tensor(0.0048)).to(noise_level.device)
187 | log_shot_noise = log_min_shot_noise + noise_level * (log_max_shot_noise - log_min_shot_noise)
188 | shot_noise = torch.exp(log_shot_noise)
189 |
190 | line = lambda x: 1.869 * x + 0.3276
191 | n = tdist.Normal(loc=torch.tensor([0.0], device=noise_level.device),
192 | scale=torch.tensor([0.30], device=noise_level.device))
193 |
194 | log_read_noise = line(log_shot_noise) + n.sample()
195 | read_noise = torch.exp(log_read_noise)
196 |
197 | return shot_noise, read_noise
198 |
199 |
200 | # def add_noise(image, shot_noise=0.01, read_noise=0.0005):
201 | # """Adds random shot (proportional to image) and read (independent) noise."""
202 | # variance = image * shot_noise + read_noise
203 | # n = tdist.Normal(loc=torch.zeros_like(variance), scale=torch.sqrt(variance))
204 | # noise = n.sample()
205 | # out = image + noise
206 | # return out
207 |
208 |
209 | # ################ If the target dataset is SIDD, use this function #####################
210 | # def random_noise_levels(noise_level):
211 | # """ Where read_noise in SIDD is not 0 """
212 | # log_min_shot_noise = np.log(0.00068674)
213 | # log_max_shot_noise = np.log(0.02194856)
214 | # # log_shot_noise = torch.FloatTensor(1).uniform_(log_min_shot_noise, log_max_shot_noise)
215 | # log_shot_noise = torch.FloatTensor([log_min_shot_noise + noise_level * (log_max_shot_noise - log_min_shot_noise)])
216 | # shot_noise = torch.exp(log_shot_noise)
217 |
218 | # line = lambda x: 1.85 * x + 0.30
219 | # n = tdist.Normal(loc=torch.tensor([0.0]), scale=torch.tensor([0.20]))
220 | # log_read_noise = line(log_shot_noise) + n.sample()
221 | # read_noise = torch.exp(log_read_noise)
222 | # return shot_noise, read_noise
--------------------------------------------------------------------------------
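
The SIDD-style noise model samples a shot-noise level on a log scale between 0.0012 and 0.0048, derives a correlated read-noise level from the fitted line, and `add_noise` then perturbs each pixel with Gaussian noise of variance `shot * x + read` (heteroscedastic: brighter pixels get more noise). A short sketch, assuming a mid-range noise level:

```python
import torch
from data.degrade.unprocess import random_noise_levels, add_noise

noise_level = torch.tensor(0.5)            # 0 = cleanest, 1 = noisiest
shot, read = random_noise_levels(noise_level)
clean = torch.rand(64, 64, 4)              # packed raw in [0, 1]
noisy = add_noise(clean, shot, read)       # per-pixel var = clean * shot + read
print(shot.item(), read.item(), (noisy - clean).std())
```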
/imgs/Overview of CRNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/imgs/Overview of CRNet.png
--------------------------------------------------------------------------------
/imgs/multi_pro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/imgs/multi_pro.png
--------------------------------------------------------------------------------
/imgs/out1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/imgs/out1.png
--------------------------------------------------------------------------------
/imgs/out2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/imgs/out2.png
--------------------------------------------------------------------------------
/isp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/isp/__init__.py
--------------------------------------------------------------------------------
/isp/demosaic_bayer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | # sys.path.insert(0, os.path.dirname(__file__))
4 | import torch
5 | import numpy as np
6 | import pdb
7 | import copy
8 | from collections import OrderedDict
9 | import numpy as np
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 |
14 | class BayerNetwork(nn.Module):
15 | """Released version of the network, best quality.
16 |
17 | This model differs from the published description. It has a mask/filter split
18 | towards the end of the processing. Masks and filters are multiplied with each
19 | other. This is not key to performance and can be ignored when training new
20 | models from scratch.
21 | """
22 | def __init__(self, depth=15, width=64):
23 | super(BayerNetwork, self).__init__()
24 |
25 | self.depth = depth
26 | self.width = width
27 |
28 | # self.debug_layer = nn.Conv2d(3, 4, 2, stride=2)
29 | # self.debug_layer1 =nn.Conv2d(in_channels=4,out_channels=64,kernel_size=3,stride=1,padding=1)
30 | # self.debug_layer2 =nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,stride=1,padding=1)
31 | # self.debug_layer3 =nn.Conv2d(in_channels=64,out_channels=128,kernel_size=3,stride=1,padding=1)
32 |
33 | layers = OrderedDict([
34 | ("pack_mosaic", nn.Conv2d(3, 4, 2, stride=2)), # Downsample 2x2 to re-establish translation invariance.
35 | ]) #
36 | # the output of 'pack_mosaic' will be half width and height of the input
37 | # [batch_size, 4, h/2, w/2] = pack_mosaic ( [batch_size, 3, h, w] )
38 |
39 | for i in range(depth):
40 | #num of in and out neurons in each layers
41 | n_out = width
42 | n_in = width
43 |
44 | if i == 0: # the 1st layer in main_processor
45 | n_in = 4
46 | if i == depth-1: # the last layer in main_processor
47 | n_out = 2*width
48 |
49 | # layers["conv{}".format(i+1)] = nn.Conv2d(n_in, n_out, 3)
50 | layers["conv{}".format(i + 1)] = nn.Conv2d(n_in, n_out, 3,stride=1,padding=1)
51 | # padding is set to be 1 so that the h and w won't change after conv2d (using kernel size 3)
52 | layers["relu{}".format(i+1)] = nn.ReLU(inplace=True)
53 |
54 |
55 | # main conv layer
56 | self.main_processor = nn.Sequential(layers)
57 | # residual layer
58 | self.residual_predictor = nn.Conv2d(width, 12, 1)
59 | # upsample layer
60 | self.upsampler = nn.ConvTranspose2d(12, 3, 2, stride=2, groups=3)
61 |
62 | # full-res layer
63 | self.fullres_processor = nn.Sequential(OrderedDict([
64 | # ("post_conv", nn.Conv2d(6, width, 3)),
65 | ("post_conv", nn.Conv2d(6, width, 3,stride=1,padding=1)),
66 | # padding is set to be 1 so that the h and w won't change after conv2d (using kernel size 3)
67 | ("post_relu", nn.ReLU(inplace=True)),
68 | ("output", nn.Conv2d(width, 3, 1)),
69 | ]))
70 |
71 |
72 | # samples structure
73 | # sample = {
74 | # "mosaic": mosaic,
75 | # # model input [batch_size, 3, h,w]. unknown pixels are set to 0.
76 | # "mask": mask,
77 | # # "noise_variance": np.array([std]),
78 | # "target": im,
79 | # # model output [m,n,3]
80 | # }
81 | def forward(self, samples):
82 |
83 | mosaic = samples["mosaic"]
84 | # [batch_size, 3, h, w]
85 |
86 | features = self.main_processor(mosaic)
87 | # [batch_size, self.width*2, hf,wf]
88 |
89 | filters, masks = features[:, :self.width], features[:, self.width:]
90 |
91 | filtered = filters * masks
92 | # [batch_size, self.width, hf,wf]
93 |
94 | residual = self.residual_predictor(filtered)
95 | # [batch_size, 12, hf, wf]
96 |
97 | upsampled = self.upsampler(residual)
98 | # [batch_size, 3, hf*2, wf*2]. upsampled will be 2x2 upsample of residual using ConvTranspose2d()
99 |
100 | # crop original mosaic to match output size
101 | cropped = crop_like(mosaic, upsampled)
102 |
103 | # Concated input samples and residual for further filtering
104 | packed = torch.cat([cropped, upsampled], 1)
105 |
106 | output = self.fullres_processor(packed)
107 |
108 | return output
109 |
110 |
111 | class Converter(object):
112 | def __init__(self, pretrained_dir, model_type):
113 | self.basedir = pretrained_dir
114 |
115 | def convert(self, model):
116 | for n, p in model.named_parameters():
117 | name, tp = n.split(".")[-2:]
118 |
119 | old_name = self._remap(name)
120 | # print(old_name, "->", name)
121 |
122 | if tp == "bias":
123 | idx = 1
124 | else:
125 | idx = 0
126 | path = os.path.join(self.basedir, "{}_{}.npy".format(old_name, idx))
127 | data = np.load(path)
128 | # print(name, tp, data.shape, p.shape)
129 |
130 | # Overwrite
131 | # print(p.mean().item(), p.std().item())
132 | # import ipdb; ipdb.set_trace()
133 | # print(name, old_name, p.shape, data.shape)
134 | p.data.copy_(torch.from_numpy(data))
135 | # print(p.mean().item(), p.std().item())
136 |
137 | def _remap(self, s):
138 | if s == "pack_mosaic":
139 | return "pack_mosaick"
140 | if s == "residual_predictor":
141 | return "residual"
142 | if s == "upsampler":
143 | return "unpack_mosaick"
144 | if s == "post_conv":
145 | return "post_conv1"
146 | return s
147 |
148 |
149 | def crop_like(src, tgt):
150 | src_sz = np.array(src.shape)
151 | tgt_sz = np.array(tgt.shape)
152 | crop = (src_sz[2:4]-tgt_sz[2:4]) // 2
153 | if (crop > 0).any():
154 | return src[:, :, crop[0]:src_sz[2]-crop[0], crop[1]:src_sz[3]-crop[1], ...]
155 | else:
156 | return src
157 |
158 |
159 | def get_modules(params):
160 | params = copy.deepcopy(params) # do not touch the original
161 |
162 | # get the model name from the input params
163 | model_name = params.pop("model", None)
164 |
165 | if model_name is None:
166 | raise ValueError("model has not been specified!")
167 |
168 | # get the model structure by model_name
169 | return getattr(sys.modules[__name__], model_name)(**params)
170 |
171 |
172 | def get_demosaic_net_model(pretrained, device, cfa='bayer', state_dict=False):
173 | '''
174 | get demosaic network
175 | :param pretrained:
176 | path to the demosaic-network model file [string]
177 | :param device:
178 | 'cuda:0', e.g.
179 | :param state_dict:
180 | whether to use a packed state dictionary for model weights
181 | :return:
182 | model_ref: demosaic-net model
183 |
184 | '''
185 |
186 | model_ref = get_modules({"model": "BayerNetwork"}) # load model coefficients if 'pretrained'=True
187 | if not state_dict:
188 | cvt = Converter(pretrained, "BayerNetwork")
189 | cvt.convert(model_ref)
190 | for p in model_ref.parameters():
191 | p.requires_grad = False
192 | model_ref = model_ref.to(device)
193 | else:
194 | model_ref.load_state_dict(torch.load(pretrained))
195 | model_ref = model_ref.to(device)
196 |
197 | model_ref.eval()
198 |
199 | return model_ref
200 |
201 |
202 | def demosaic_by_demosaic_net(bayer, cfa, demosaic_net, device):
203 | '''
204 | demosaic the bayer to get RGB by demosaic-net. The func will convert the numpy array to tensor for demosaic-net,
205 | after which the tensor will be converted back to numpy array to return.
206 |
207 | :param bayer:
208 | [m,n]. numpy float32 in the range of [0,1] linear bayer
209 | :param cfa:
210 | [string], 'RGGB', e.g. only GBRG, RGGB, BGGR or GRBG is supported so far!
211 | :param demosaic_net:
212 | demosaic_net object
213 | :param device:
214 | 'cuda:0', e.g.
215 |
216 | :return:
217 | [m,n,3]. np array float32 in the range of [0,1]
218 |
219 | '''
220 |
221 |
222 | assert (cfa == 'GBRG') or (cfa == 'RGGB') or (cfa == 'GRBG') or (cfa == 'BGGR'), 'only GBRG, RGGB, BGGR, GRBG are supported so far!'
223 |
224 | # if the bayer resolution is too high (more than 1000x1000,e.g.), may cause memory error.
225 |
226 | bayer = np.clip(bayer ,0 ,1)
227 | bayer = torch.from_numpy(bayer).float()
228 | bayer = bayer.to(device)
229 | bayer = torch.unsqueeze(bayer, 0)
230 | bayer = torch.unsqueeze(bayer, 0)
231 |
232 | with torch.no_grad():
233 | rgb = predict_rgb_from_bayer_tensor(bayer, cfa=cfa, demosaic_net=demosaic_net, device=device)
234 |
235 | rgb = rgb.detach().cpu()[0].permute(1, 2, 0).numpy() # torch tensor -> numpy array
236 | # rgb = np.clip(rgb, 0, 1)
237 |
238 | return rgb
239 |
240 |
241 | def predict_rgb_from_bayer_tensor(im,cfa,demosaic_net,device):
242 | '''
243 | predict the RGB image from bayer pattern mosaic using demosaic net
244 |
245 | :param im:
246 | [batch_sz, 1, m,n] tensor. the bayer pattern mosaic.
247 |
248 | :param cfa:
249 | the cfa layout. the demosaic net is trained w/ GRBG. If the input is other than GRBG, need padding or cropping
250 |
251 | :param demosaic_net:
252 | demosaic-net
253 |
254 | :param device:
255 | 'cuda:0', e.g.
256 |
257 | :return:
258 | rgb_hat:
259 | [batch_size, 3, m,n] the rgb image predicted by the demosaic-net using our bayer input
260 | '''
261 |
262 | assert (cfa == 'GBRG') or (cfa == 'RGGB') or (cfa == 'GRBG') or (cfa == 'BGGR')
263 | # 'only GBRG, RGGB, BGGR, GRBG are supported so far!'
264 |
265 | # print(im.shape)
266 |
267 | n_channel = im.shape[1]
268 |
269 | if n_channel==1: # gray scale image
270 | im= torch.cat((im, im, im), 1)
271 |
272 | if cfa == 'GBRG': # the demosiac net is trained w/ GRBG
273 | im = pad_gbrg_2_grbg(im,device)
274 | elif cfa == 'RGGB':
275 | im = pad_rggb_2_grbg(im, device)
276 | elif cfa == 'BGGR':
277 | im = pad_bggr_2_grbg(im, device)
278 |
279 | im= bayer_mosaic_tensor(im,device)
280 |
281 | sample = {"mosaic": im}
282 |
283 | rgb_hat = demosaic_net(sample)
284 |
285 | if cfa == 'GBRG':
286 | # an extra row and col is padded on four sides of the bayer before using demosaic-net. Need to trim the padded rows and cols of demosaiced rgb
287 | rgb_hat = unpad_grbg_2_gbrg(rgb_hat)
288 | elif cfa == 'RGGB':
289 | rgb_hat = unpad_grbg_2_rggb(rgb_hat)
290 | elif cfa == 'BGGR':
291 | rgb_hat = unpad_grbg_2_bggr(rgb_hat)
292 |
293 | rgb_hat = torch.clamp(rgb_hat, min=0, max=1)
294 |
295 | return rgb_hat
296 |
297 |
298 | def pad_bggr_2_grbg(bayer, device):
299 | '''
300 | pad bggr bayer pattern to get grbg (for demosaic-net)
301 |
302 | :param bayer:
303 | 2d tensor [bsz,ch, h,w]
304 | :param device:
305 | 'cuda:0' or 'cpu', or ...
306 | :return:
307 | bayer: 2d tensor [bsz,ch,h+2,w]
308 |
309 | '''
310 | bsz, ch, h, w = bayer.shape
311 |
312 | bayer2 = torch.zeros([bsz, ch, h + 2, w], dtype=torch.float32)
313 | bayer2 = bayer2.to(device)
314 |
315 | bayer2[:, :, 1:-1, :] = bayer
316 |
317 | bayer2[:, :, 0, :] = bayer[:, :, 1, :]
318 | bayer2[:, :, -1, :] = bayer2[:, :, -2, :]
319 |
320 | bayer = bayer2
321 |
322 | return bayer
323 |
324 |
325 | def pad_rggb_2_grbg(bayer,device):
326 | '''
327 | pad rggb bayer pattern to get grbg (for demosaic-net)
328 |
329 | :param bayer:
330 | 2d tensor [bsz,ch, h,w]
331 | :param device:
332 | 'cuda:0' or 'cpu', or ...
333 | :return:
334 | bayer: 2d tensor [bsz,ch,h,w+2]
335 |
336 | '''
337 | bsz, ch, h, w = bayer.shape
338 |
339 | bayer2 = torch.zeros([bsz,ch,h, w+2], dtype=torch.float32)
340 | bayer2 = bayer2.to(device)
341 |
342 | bayer2[:,:,:, 1:-1] = bayer
343 |
344 | bayer2[:,:,:, 0] = bayer[:,:,:, 1]
345 | bayer2[:,:,:, -1] = bayer2[:,:,:, -2]
346 |
347 | bayer = bayer2
348 |
349 | return bayer
350 |
351 |
352 | def pad_gbrg_2_grbg(bayer,device):
353 | '''
354 | pad gbrg bayer pattern to get grbg (for demosaic-net)
355 |
356 | :param bayer:
357 | 2d tensor [bsz,ch, h,w]
358 | :param device:
359 | 'cuda:0' or 'cpu', or ...
360 | :return:
361 | bayer: 2d tensor [bsz,ch,h+2,w+2]
362 |
363 | '''
364 | bsz, ch, h, w = bayer.shape
365 |
366 | bayer2 = torch.zeros([bsz,ch,h+2, w+2], dtype=torch.float32)
367 | bayer2 = bayer2.to(device)
368 |
369 | bayer2[:,:,1:-1, 1:-1] = bayer
370 | bayer2[:,:,0, 1:-1] = bayer[:,:,1, :]
371 | bayer2[:,:,-1, 1:-1] = bayer[:,:,-2, :]
372 |
373 | bayer2[:,:,:, 0] = bayer2[:,:,:, 2]
374 | bayer2[:,:,:, -1] = bayer2[:,:,:, -3]
375 |
376 | bayer = bayer2
377 |
378 | return bayer
379 |
380 |
381 | def unpad_grbg_2_gbrg(rgb):
382 | '''
383 | unpad the rgb image. this is used after pad_gbrg_2_grbg()
384 | :param rgb:
385 | tensor. [1,3,m,n]
386 | :return:
387 | tensor [1,3,m-2,n-2]
388 |
389 | '''
390 | rgb = rgb[:,:,1:-1,1:-1]
391 |
392 | return rgb
393 |
394 |
395 | def unpad_grbg_2_bggr(rgb):
396 | '''
397 | unpad the rgb image. this is used after pad_bggr_2_grbg()
398 | :param rgb:
399 | tensor. [1,3,m,n]
400 | :return:
401 | tensor [1,3,m-2,n]
402 |
403 | '''
404 | rgb = rgb[:, :, 1:-1 , : ]
405 |
406 | return rgb
407 |
408 |
409 | def unpad_grbg_2_rggb(rgb):
410 | '''
411 | unpad the rgb image. this is used after pad_rggb_2_grbg()
412 | :param rgb:
413 | tensor. [1,3,m,n]
414 | :return:
415 | tensor [1,3,m,n-2]
416 |
417 | '''
418 | rgb = rgb[:,:,:,1:-1]
419 |
420 | return rgb
421 |
422 |
423 | def bayer_mosaic_tensor(im,device):
424 | '''
425 | create bayer mosaic to set as input to demosaic-net.
426 | make sure the input bayer (im) is GRBG.
427 |
428 | :param im:
429 | [batch_size, 3, m,n]. The color is in RGB order.
430 | :param device:
431 | 'cuda:0', e.g.
432 | :return:
433 | '''
434 |
435 | """GRBG Bayer mosaic."""
436 |
437 | batch_size=im.shape[0]
438 | hh=im.shape[2]
439 | ww=im.shape[3]
440 |
441 | mask = torch.ones([batch_size,3,hh, ww], dtype=torch.float32)
442 | mask = mask.to(device)
443 |
444 | # red
445 | mask[:,0, ::2, 0::2] = 0
446 | mask[:,0, 1::2, :] = 0
447 |
448 | # green
449 | mask[:,1, ::2, 1::2] = 0
450 | mask[:,1, 1::2, ::2] = 0
451 |
452 | # blue
453 | mask[:,2, 0::2, :] = 0
454 | mask[:,2, 1::2, 1::2] = 0
455 |
456 | return im*mask
--------------------------------------------------------------------------------
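
End to end, `process.py` uses these helpers by loading the bundled weights (`./isp/model.bin`, a packed `state_dict`) and demosaicing a gamma-encoded bayer plane. A hedged usage sketch along the same lines; it assumes the checkpoint is present and loadable on the chosen device (`torch.load` inside `get_demosaic_net_model` may need a `map_location` if the checkpoint was saved on GPU):

```python
import numpy as np
from isp.demosaic_bayer import get_demosaic_net_model, demosaic_by_demosaic_net

device = 'cpu'  # or 'cuda:0'
net = get_demosaic_net_model(pretrained='./isp/model.bin', device=device,
                             cfa='bayer', state_dict=True)

bayer = np.random.rand(64, 64).astype(np.float32)  # [m, n] in [0, 1], RGGB layout
rgb = demosaic_by_demosaic_net(bayer, cfa='RGGB', demosaic_net=net, device=device)
print(rgb.shape, rgb.dtype)                        # (64, 64, 3) float32
```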
/isp/dng_opcode.py:
--------------------------------------------------------------------------------
1 | """
2 | Authors(s):
3 | Abdelrahman Abdelhamed (a.abdelhamed@samsung.com)
4 |
5 | Utility functions for handling DNG opcode lists.
6 | """
7 | import struct
8 | import numpy as np
9 | from .exif_utils import get_tag_values_from_ifds
10 |
11 |
12 | class Opcode:
13 | def __init__(self, id_, dng_spec_ver, option_bits, size_bytes, data):
14 | self.id = id_
15 | self.dng_spec_ver = dng_spec_ver
16 | self.size_bytes = size_bytes
17 | self.option_bits = option_bits
18 | self.data = data
19 |
20 |
21 | def parse_opcode_lists(ifds):
22 | # OpcodeList1, 51008, 0xC740
23 | # Applied to the raw image as read directly from the file
24 |
25 | # OpcodeList2, 51009, 0xC741
26 | # Applied to raw image after being mapped to linear reference values
27 | # That is, after linearization, black level subtraction, normalization, and clipping
28 |
29 | # OpcodeList3, 51022, 0xC74E
30 | # Applied to raw image after being demosaiced
31 |
32 | opcode_list_tag_nums = [51008, 51009, 51022]
33 | opcode_lists = {}
34 | for i, tag_num in enumerate(opcode_list_tag_nums):
35 | opcode_list_ = get_tag_values_from_ifds(tag_num, ifds)
36 | if opcode_list_ is not None:
37 | opcode_list_ = bytearray(opcode_list_)
38 | opcodes = parse_opcodes(opcode_list_)
39 | opcode_lists.update({tag_num: opcodes})
40 | else:
41 | pass
42 |
43 | return opcode_lists
44 |
45 |
46 | def parse_opcodes(opcode_list):
47 | """
48 | Parse a byte array representing an opcode list.
49 | :param opcode_list: An opcode list as a byte array.
50 |     :return: Parsed opcodes as a dictionary keyed by opcode ID.
51 | """
52 | # opcode lists are always stored in big endian
53 | endian_sign = ">"
54 |
55 | # opcode IDs
56 | # 9: GainMap
57 | # 1: Rectilinear Warp
58 |
59 | # clip to
60 | # [0, 2^32 - 1] for OpcodeList1
61 | # [0, 2^16 - 1] for OpcodeList2
62 | # [0, 1] for OpcodeList3
63 |
64 | i = 0
65 | num_opcodes = struct.unpack(endian_sign + "I", opcode_list[i:i + 4])[0]
66 | i += 4
67 |
68 | opcodes = {}
69 | for j in range(num_opcodes):
70 | opcode_id_ = struct.unpack(endian_sign + "I", opcode_list[i:i + 4])[0]
71 | i += 4
72 | dng_spec_ver = [struct.unpack(endian_sign + "B", opcode_list[i + k:i + k + 1])[0] for k in range(4)]
73 | i += 4
74 | option_bits = struct.unpack(endian_sign + "I", opcode_list[i:i + 4])[0]
75 | i += 4
76 |
77 | # option bits
78 | if option_bits & 1 == 1: # optional/unknown
79 | pass
80 | elif option_bits & 2 == 2: # can be skipped for "preview quality", needed for "full quality"
81 | pass
82 | else:
83 | pass
84 |
85 | opcode_size_bytes = struct.unpack(endian_sign + "I", opcode_list[i:i + 4])[0]
86 | i += 4
87 |
88 | opcode_data = opcode_list[i:i + 4 * opcode_size_bytes]
89 | i += 4 * opcode_size_bytes
90 |
91 | # GainMap (lens shading correction map)
92 | if opcode_id_ == 9:
93 | opcode_gain_map_data = parse_opcode_gain_map(opcode_data)
94 | opcode_data = opcode_gain_map_data
95 |
96 | # set opcode object
97 | opcode = Opcode(id_=opcode_id_, dng_spec_ver=dng_spec_ver, option_bits=option_bits,
98 | size_bytes=opcode_size_bytes,
99 | data=opcode_data)
100 | opcodes.update({opcode_id_: opcode})
101 |
102 | return opcodes
103 |
104 |
105 | def parse_opcode_gain_map(opcode_data):
106 | endian_sign = ">" # big
107 | opcode_dict = {}
108 | keys = ['top', 'left', 'bottom', 'right', 'plane', 'planes', 'row_pitch', 'col_pitch', 'map_points_v',
109 | 'map_points_h', 'map_spacing_v', 'map_spacing_h', 'map_origin_v', 'map_origin_h', 'map_planes', 'map_gain']
110 | dtypes = ['L'] * 10 + ['d'] * 4 + ['L'] + ['f']
111 | dtype_sizes = [4] * 10 + [8] * 4 + [4] * 2 # data type size in bytes
112 | counts = [1] * 15 + [0] # 0 count means variable count, depending on map_points_v and map_points_h
113 | # values = []
114 |
115 | i = 0
116 | for k in range(len(keys)):
117 | if counts[k] == 0: # map_gain
118 | counts[k] = opcode_dict['map_points_v'] * opcode_dict['map_points_h']
119 |
120 | if counts[k] == 1:
121 | vals = struct.unpack(endian_sign + dtypes[k], opcode_data[i:i + dtype_sizes[k]])[0]
122 | i += dtype_sizes[k]
123 | else:
124 | vals = []
125 | for j in range(counts[k]):
126 | vals.append(struct.unpack(endian_sign + dtypes[k], opcode_data[i:i + dtype_sizes[k]])[0])
127 | i += dtype_sizes[k]
128 |
129 | opcode_dict[keys[k]] = vals
130 |
131 | opcode_dict['map_gain_2d'] = np.reshape(opcode_dict['map_gain'],
132 | (opcode_dict['map_points_v'], opcode_dict['map_points_h']))
133 |
134 | return opcode_dict
135 |
--------------------------------------------------------------------------------
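For a DNG that actually carries opcode lists, the parser above plugs into the EXIF utilities from `exif_utils.py`. A hedged usage sketch (`photo.dng` is a placeholder path):

```python
from isp.exif_utils import parse_exif
from isp.dng_opcode import parse_opcode_lists

ifds = parse_exif('photo.dng', verbose=False)   # placeholder DNG path
opcode_lists = parse_opcode_lists(ifds)

# OpcodeList2 (tag 51009) may carry a GainMap (opcode ID 9), which the
# pipeline uses for lens shading correction.
if 51009 in opcode_lists and 9 in opcode_lists[51009]:
    gain_map_2d = opcode_lists[51009][9].data['map_gain_2d']
    print(gain_map_2d.shape)
```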
/isp/exif_data_formats.py:
--------------------------------------------------------------------------------
1 | class ExifFormat:
2 | def __init__(self, id, name, size, short_name):
3 | self.id = id
4 | self.name = name
5 | self.size = size
6 | self.short_name = short_name # used with struct.unpack()
7 |
8 |
9 | exif_formats = {
10 | 1: ExifFormat(1, 'unsigned byte', 1, 'B'),
11 | 2: ExifFormat(2, 'ascii string', 1, 's'),
12 | 3: ExifFormat(3, 'unsigned short', 2, 'H'),
13 | 4: ExifFormat(4, 'unsigned long', 4, 'L'),
14 | 5: ExifFormat(5, 'unsigned rational', 8, ''),
15 | 6: ExifFormat(6, 'signed byte', 1, 'b'),
16 | 7: ExifFormat(7, 'undefined', 1, 'B'), # consider `undefined` as `unsigned byte`
17 | 8: ExifFormat(8, 'signed short', 2, 'h'),
18 | 9: ExifFormat(9, 'signed long', 4, 'l'),
19 | 10: ExifFormat(10, 'signed rational', 8, ''),
20 | 11: ExifFormat(11, 'single float', 4, 'f'),
21 | 12: ExifFormat(12, 'double float', 8, 'd'),
22 | }
23 |
--------------------------------------------------------------------------------
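The `short_name` field maps each EXIF data format directly onto a `struct` format character. For example (the bytes below are illustrative):

```python
import struct
from isp.exif_data_formats import exif_formats

fmt = exif_formats[3]                                        # unsigned short
value = struct.unpack('>' + fmt.short_name, b'\x00\x2a')[0]  # big-endian
print(fmt.name, fmt.size, value)                             # unsigned short 2 42
```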
/isp/exif_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Author(s):
3 | Abdelrahman Abdelhamed
4 |
5 | Manual parsing of image file directories (IFDs).
6 | """
7 |
8 |
9 | import struct
10 | from fractions import Fraction
11 | from .exif_data_formats import exif_formats
12 |
13 |
14 | class Ifd:
15 | def __init__(self):
16 | self.offset = -1
17 | self.tags = {} # dict; tag number will be key.
18 |
19 |
20 | class Tag:
21 | def __init__(self):
22 | self.offset = -1
23 | self.tag_num = -1
24 | self.data_format = -1
25 | self.num_values = -1
26 | self.values = []
27 |
28 |
29 | def parse_exif(image_path, verbose=True):
30 | """
31 | Parse EXIF tags from a binary file and return IFDs.
32 | Returned IFDs include EXIF SubIFDs, if any.
33 | """
34 |
35 | def print_(str_):
36 | if verbose:
37 | print(str_)
38 |
39 | ifds = {} # dict of pairs; using offset to IFD as key.
40 |
41 | with open(image_path, 'rb') as fid:
42 | fid.seek(0)
43 | b0 = fid.read(1)
44 | _ = fid.read(1)
45 | # byte storage direction (endian):
46 | # +1: b'M' (big-endian/Motorola)
47 | # -1: b'I' (little-endian/Intel)
48 | endian = 1 if b0 == b'M' else -1
49 | print_("Endian = {}".format(b0))
50 | endian_sign = "<" if endian == -1 else ">" # used in struct.unpack
51 | print_("Endian sign = {}".format(endian_sign))
52 | _ = fid.read(2) # 0x002A
53 | b4_7 = fid.read(4) # offset to first IFD
54 | offset_ = struct.unpack(endian_sign + "I", b4_7)[0]
55 | i = 0
56 | ifd_offsets = [offset_]
57 | while len(ifd_offsets) > 0:
58 | offset_ = ifd_offsets.pop(0)
59 | # check if IFD at this offset was already parsed before
60 | if offset_ in ifds:
61 | continue
62 | print_("=========== Parsing IFD # {} ===========".format(i))
63 | ifd_ = parse_exif_ifd(fid, offset_, endian_sign, verbose)
64 | ifds.update({ifd_.offset: ifd_})
65 | print_("=========== Finished parsing IFD # {} ===========".format(i))
66 | i += 1
67 | # check SubIFDs; zero or more offsets at tag 0x014a
68 | sub_idfs_tag_num = int('0x014a', 16)
69 | if sub_idfs_tag_num in ifd_.tags:
70 | ifd_offsets.extend(ifd_.tags[sub_idfs_tag_num].values)
71 |             # check Exif SubIFD; usually one offset at tag 0x8769
72 | exif_sub_idf_tag_num = int('0x8769', 16)
73 | if exif_sub_idf_tag_num in ifd_.tags:
74 | ifd_offsets.extend(ifd_.tags[exif_sub_idf_tag_num].values)
75 | return ifds
76 |
77 |
78 | def parse_exif_ifd(binary_file, offset_, endian_sign, verbose=True):
79 | """
80 | Parse an EXIF IFD.
81 | """
82 |
83 | def print_(str_):
84 | if verbose:
85 | print(str_)
86 |
87 | ifd = Ifd()
88 | ifd.offset = offset_
89 | print_("IFD offset = {}".format(ifd.offset))
90 | binary_file.seek(offset_)
91 | num_entries = struct.unpack(endian_sign + "H", binary_file.read(2))[0] # format H = unsigned short
92 | print_("Number of entries = {}".format(num_entries))
93 | for t in range(num_entries):
94 | print_("---------- Tag {} / {} ----------".format(t + 1, num_entries))
95 |
96 |
97 | tag_ = parse_exif_tag(binary_file, endian_sign, verbose)
98 | ifd.tags.update({tag_.tag_num: tag_}) # supposedly, EXIF tag numbers won't repeat in the same IFD
99 | # TODO: check for subsequent IFDs by parsing the next 4 bytes immediately after the IFD
100 | return ifd
101 |
102 |
103 | def parse_exif_tag(binary_file, endian_sign, verbose=True):
104 | """
105 |     Parse an EXIF tag from a binary file, starting at the current file pointer, and return the parsed tag.
106 | """
107 |
108 | def print_(str_):
109 | if verbose:
110 | print(str_)
111 |
112 | tag = Tag()
113 |
114 | # tag offset
115 | tag.offset = binary_file.tell()
116 | print_("Tag offset = {}".format(tag.offset))
117 |
118 | # tag number
119 | bytes_ = binary_file.read(2)
120 | tag.tag_num = struct.unpack(endian_sign + "H", bytes_)[0] # H: unsigned 2-byte short
121 | print_("Tag number = {} = 0x{:04x}".format(tag.tag_num, tag.tag_num))
122 |
123 | # data format (some value between [1, 12])
124 | tag.data_format = struct.unpack(endian_sign + "H", binary_file.read(2))[0] # H: unsigned 2-byte short
125 | exif_format = exif_formats[tag.data_format]
126 | print_("Data format = {} = {}".format(tag.data_format, exif_format.name))
127 |
128 | # number of components/values
129 | tag.num_values = struct.unpack(endian_sign + "I", binary_file.read(4))[0] # I: unsigned 4-byte integer
130 | print_("Number of values = {}".format(tag.num_values))
131 |
132 | # total number of data bytes
133 | total_bytes = tag.num_values * exif_format.size
134 | print_("Total bytes = {}".format(total_bytes))
135 |
136 | # seek to data offset (if needed)
137 | data_is_offset = False
138 | current_offset = binary_file.tell()
139 | if total_bytes > 4:
140 | print_("Total bytes > 4; The next 4 bytes are an offset.")
141 | data_is_offset = True
142 | data_offset = struct.unpack(endian_sign + "I", binary_file.read(4))[0]
143 | current_offset = binary_file.tell()
144 | print_("Current offset = {}".format(current_offset))
145 | print_("Seeking to data offset = {}".format(data_offset))
146 | binary_file.seek(data_offset)
147 |
148 | # read values
149 | # TODO: need to distinguish between numeric and text values?
150 | if tag.num_values == 1 and total_bytes < 4:
151 |         # special case: a single value shorter than 4 bytes is stored inside the 4-byte field; mind the endianness
152 | val_bytes = binary_file.read(4)
153 | # if endian_sign == ">":
154 | # val_bytes = val_bytes[4 - total_bytes:]
155 | # else:
156 | # val_bytes = val_bytes[:total_bytes][::-1]
157 | val_bytes = val_bytes[:total_bytes]
158 | tag.values.append(struct.unpack(endian_sign + exif_format.short_name, val_bytes)[0])
159 | else:
160 | # read data values one by one
161 | for k in range(tag.num_values):
162 | val_bytes = binary_file.read(exif_format.size)
163 | if exif_format.name == 'unsigned rational':
164 | tag.values.append(eight_bytes_to_fraction(val_bytes, endian_sign, signed=False))
165 | elif exif_format.name == 'signed rational':
166 | tag.values.append(eight_bytes_to_fraction(val_bytes, endian_sign, signed=True))
167 | else:
168 | tag.values.append(struct.unpack(endian_sign + exif_format.short_name, val_bytes)[0])
169 | if total_bytes < 4:
170 | # special case: multiple values less than 4 bytes in total, inside the 4 bytes; skip the extra bytes
171 | binary_file.seek(4 - total_bytes, 1)
172 |
173 | if verbose:
174 | if len(tag.values) > 100:
175 | print_("Got more than 100 values; printing first 100 only:")
176 | print_("Tag values = {}".format(tag.values[:100]))
177 | else:
178 | print_("Tag values = {}".format(tag.values))
179 | if tag.data_format == 2:
180 | print_("Tag values (string) = {}".format(b''.join(tag.values).decode()))
181 |
182 | if data_is_offset:
183 | # seek back to current position to read the next tag
184 | print_("Seeking back to current offset = {}".format(current_offset))
185 | binary_file.seek(current_offset)
186 |
187 | return tag
188 |
189 |
190 | def get_tag_values_from_ifds(tag_num, ifds):
191 | """
192 | Return values of a tag, if found in ifds. Return None otherwise.
193 |     Assumes any tag appears at most once across all IFDs.
194 | """
195 | for key, ifd in ifds.items():
196 | if tag_num in ifd.tags:
197 | return ifd.tags[tag_num].values
198 | return None
199 |
200 |
201 | def eight_bytes_to_fraction(eight_bytes, endian_sign, signed):
202 | """
203 | Convert 8-byte array into a Fraction. Take care of endian and sign.
204 | """
205 | if signed:
206 | num = struct.unpack(endian_sign + "l", eight_bytes[:4])[0]
207 | den = struct.unpack(endian_sign + "l", eight_bytes[4:])[0]
208 | else:
209 | num = struct.unpack(endian_sign + "L", eight_bytes[:4])[0]
210 | den = struct.unpack(endian_sign + "L", eight_bytes[4:])[0]
211 | den = den if den != 0 else 1
212 | return Fraction(num, den)
213 |
--------------------------------------------------------------------------------
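Putting the IFD parser together with `get_tag_values_from_ifds` gives direct access to individual DNG tags. A minimal sketch (the path is a placeholder; the tag numbers are the BlackLevel/WhiteLevel tags used elsewhere in this repo):

```python
from isp.exif_utils import parse_exif, get_tag_values_from_ifds

ifds = parse_exif('photo.dng', verbose=False)        # placeholder DNG path
black_level = get_tag_values_from_ifds(50714, ifds)  # BlackLevel (0xC61A)
white_level = get_tag_values_from_ifds(50717, ifds)  # WhiteLevel (0xC61D)
print(black_level, white_level)
```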
/isp/isp.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import cv2
4 | import numpy as np
5 |
6 | from .pipeline import run_pipeline_v2
7 | from .pipeline_utils import get_visible_raw_image, get_metadata
8 |
9 | params = {
10 | 'input_stage': 'normal', # options: 'raw', 'normal', 'white_balance', 'demosaic', 'xyz', 'srgb', 'gamma', 'tone'
11 | 'output_stage': 'srgb', # options: 'normal', 'white_balance', 'demosaic', 'xyz', 'srgb', 'gamma', 'tone'
12 | 'save_as': 'png', # options: 'jpg', 'png', 'tif', etc.
13 |     'demosaic_type': 'net', # options: 'menon2007', 'EA', 'VNG', 'net', 'down'
14 | 'save_dtype': np.uint16
15 | }
16 |
17 | def reshape_back_raw(bayer):  # re-interleave [4, h, w] packed planes into a [2h, 2w] Bayer mosaic
18 | H = bayer.shape[1]
19 | W = bayer.shape[2]
20 | newH = int(H*2)
21 | newW = int(W*2)
22 | bayer_back = np.zeros((newH, newW))
23 | bayer_back[0:newH:2, 0:newW:2] = bayer[3]
24 | bayer_back[0:newH:2, 1:newW:2] = bayer[1]
25 | bayer_back[1:newH:2, 0:newW:2] = bayer[2]
26 | bayer_back[1:newH:2, 1:newW:2] = bayer[0]
27 |
28 | return bayer_back
29 |
30 | def isp_pip(raw_image, meta_data, device='0'):
31 | # raw_image = reshape_back_raw(npy_img) / ratio
32 |
33 | # metadata
34 | # meta_npy = meta_data.items()
35 | metadata = get_metadata(meta_data)
36 | # raw_image = raw_image * (2**10 - 1) # + meta_npy['black_level'][0]
37 |
38 | # render
39 | output_image = run_pipeline_v2(raw_image, params, metadata=metadata, device=device)
40 | # output_image = output_image[..., ::-1] * 255
41 | # output_image = np.clip(output_image, 0, 255)
42 |
43 | return output_image #.astype(params['save_dtype'])
44 |
--------------------------------------------------------------------------------
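`reshape_back_raw` re-interleaves the four quarter-resolution planes into a full-resolution mosaic, which `isp_pip` then renders from the 'normal' stage to sRGB. A sketch with random data (an empty metadata dict makes `get_metadata` fall back to its printed defaults; the device string follows the repo's own default and assumes the demosaic-net weights in `isp/model.bin` are available):

```python
import numpy as np
from isp.isp import reshape_back_raw, isp_pip

packed = np.random.rand(4, 64, 64).astype(np.float32)  # [4, H/2, W/2] planes
bayer = reshape_back_raw(packed)                        # (128, 128) mosaic

srgb = isp_pip(bayer, {}, device='0')   # defaults: RGGB CFA, unit white balance
print(srgb.shape)                       # (128, 128, 3)
```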
/isp/model.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/isp/model.bin
--------------------------------------------------------------------------------
/isp/pipeline.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .pipeline_utils import get_visible_raw_image, get_metadata, normalize, white_balance, demosaic, \
3 | apply_color_space_transform, transform_xyz_to_srgb, apply_gamma, apply_tone_map, fix_orientation, \
4 | lens_shading_correction
5 |
6 |
7 | def run_pipeline_v2(image_or_path, params=None, metadata=None, fix_orient=True, device='0'):
8 | params_ = params.copy()
9 | if type(image_or_path) == str:
10 | image_path = image_or_path
11 | # raw image data
12 | raw_image = get_visible_raw_image(image_path)
13 | # metadata
14 | metadata = get_metadata(image_path)
15 | else:
16 | raw_image = image_or_path.copy()
17 | # must provide metadata
18 | if metadata is None:
19 | raise ValueError("Must provide metadata when providing image data in first argument.")
20 |
21 | current_stage = 'raw'
22 | current_image = raw_image
23 |
24 | if params_['input_stage'] == current_stage:
25 | # linearization
26 | linearization_table = metadata['linearization_table']
27 | if linearization_table is not None:
28 | print('Linearization table found. Not handled.')
29 | # TODO
30 |
31 | current_image = normalize(current_image, metadata['black_level'], metadata['white_level'])
32 | params_['input_stage'] = 'normal'
33 |
34 | current_stage = 'normal'
35 |
36 | if params_['output_stage'] == current_stage:
37 | return current_image
38 |
39 | if params_['input_stage'] == current_stage:
40 | gain_map_opcode = None
41 | if 'opcode_lists' in metadata:
42 | if 51009 in metadata['opcode_lists']:
43 | opcode_list_2 = metadata['opcode_lists'][51009]
44 | gain_map_opcode = opcode_list_2[9]
45 | if gain_map_opcode is not None:
46 | current_image = lens_shading_correction(current_image, gain_map_opcode=gain_map_opcode,
47 | bayer_pattern=metadata['cfa_pattern'])
48 | params_['input_stage'] = 'lens_shading_correction'
49 |
50 | current_stage = 'lens_shading_correction'
51 |
52 | if params_['output_stage'] == current_stage:
53 | return current_image
54 |
55 | if params_['input_stage'] == current_stage:
56 | current_image = white_balance(current_image, metadata['as_shot_neutral'], metadata['cfa_pattern'])
57 | params_['input_stage'] = 'white_balance'
58 |
59 | current_stage = 'white_balance'
60 |
61 | if params_['output_stage'] == current_stage:
62 | return current_image
63 |
64 | if params_['input_stage'] == current_stage:
65 | current_image = demosaic(current_image, metadata['cfa_pattern'], output_channel_order='RGB',
66 | alg_type=params_['demosaic_type'], device=device)
67 | params_['input_stage'] = 'demosaic'
68 |
69 | current_stage = 'demosaic'
70 |
71 | if params_['output_stage'] == current_stage:
72 | return current_image
73 |
74 | if params_['input_stage'] == current_stage:
75 | current_image = apply_color_space_transform(current_image, metadata['color_matrix_1'],
76 | metadata['color_matrix_2'])
77 | params_['input_stage'] = 'xyz'
78 |
79 | current_stage = 'xyz'
80 |
81 | if params_['output_stage'] == current_stage:
82 | return current_image
83 |
84 | if params_['input_stage'] == current_stage:
85 | current_image = transform_xyz_to_srgb(current_image)
86 | params_['input_stage'] = 'srgb'
87 |
88 | current_stage = 'srgb'
89 |
90 | if fix_orient:
91 | # fix image orientation, if needed (after srgb stage, ok?)
92 | current_image = fix_orientation(current_image, metadata['orientation'])
93 |
94 | if params_['output_stage'] == current_stage:
95 | return current_image
96 |
97 | if params_['input_stage'] == current_stage:
98 | current_image = apply_gamma(current_image)
99 | params_['input_stage'] = 'gamma'
100 |
101 | current_stage = 'gamma'
102 |
103 | if params_['output_stage'] == current_stage:
104 | return current_image
105 |
106 | if params_['input_stage'] == current_stage:
107 | current_image = apply_tone_map(current_image)
108 | params_['input_stage'] = 'tone'
109 |
110 | current_stage = 'tone'
111 |
112 | if params_['output_stage'] == current_stage:
113 | return current_image
114 |
115 | # invalid input/output stage!
116 | raise ValueError('Invalid input/output stage: input_stage = {}, output_stage = {}'.format(params_['input_stage'],
117 | params_['output_stage']))
118 |
119 |
120 | def run_pipeline(image_path, params):
121 | # raw image data
122 | raw_image = get_visible_raw_image(image_path)
123 |
124 | # metadata
125 | metadata = get_metadata(image_path)
126 |
127 | # linearization
128 | linearization_table = metadata['linearization_table']
129 | if linearization_table is not None:
130 | print('Linearization table found. Not handled.')
131 | # TODO
132 |
133 | normalized_image = normalize(raw_image, metadata['black_level'], metadata['white_level'])
134 |
135 | if params['output_stage'] == 'normal':
136 | return normalized_image
137 |
138 | white_balanced_image = white_balance(normalized_image, metadata['as_shot_neutral'], metadata['cfa_pattern'])
139 |
140 | if params['output_stage'] == 'white_balance':
141 | return white_balanced_image
142 |
143 | demosaiced_image = demosaic(white_balanced_image, metadata['cfa_pattern'], output_channel_order='BGR',
144 | alg_type=params['demosaic_type'])
145 |
146 | # fix image orientation, if needed
147 | demosaiced_image = fix_orientation(demosaiced_image, metadata['orientation'])
148 |
149 | if params['output_stage'] == 'demosaic':
150 | return demosaiced_image
151 |
152 | xyz_image = apply_color_space_transform(demosaiced_image, metadata['color_matrix_1'], metadata['color_matrix_2'])
153 |
154 | if params['output_stage'] == 'xyz':
155 | return xyz_image
156 |
157 | srgb_image = transform_xyz_to_srgb(xyz_image)
158 |
159 | if params['output_stage'] == 'srgb':
160 | return srgb_image
161 |
162 | gamma_corrected_image = apply_gamma(srgb_image)
163 |
164 | if params['output_stage'] == 'gamma':
165 | return gamma_corrected_image
166 |
167 | tone_mapped_image = apply_tone_map(gamma_corrected_image)
168 | if params['output_stage'] == 'tone':
169 | return tone_mapped_image
170 |
171 | output_image = None
172 | return output_image
173 |
--------------------------------------------------------------------------------
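`run_pipeline_v2` can start and stop at any intermediate stage, which is exactly what `isp_pip` relies on. A minimal sketch that white-balances an already-normalized GRBG mosaic (the input data and neutral values are illustrative):

```python
import numpy as np
from isp.pipeline import run_pipeline_v2

params = {
    'input_stage': 'normal',          # image is already black-level corrected
    'output_stage': 'white_balance',  # stop right after white balancing
    'demosaic_type': 'EA',            # unused before the demosaic stage
}
image = np.random.rand(64, 64).astype(np.float32)
metadata = {
    'as_shot_neutral': [0.6, 1.0, 0.7],
    'cfa_pattern': [1, 0, 2, 1],      # GRBG
}
wb = run_pipeline_v2(image, params, metadata=metadata)
print(wb.shape)                       # (64, 64)
```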
/isp/pipeline_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Author(s):
3 | Abdelrahman Abdelhamed
4 |
5 | Camera pipeline utilities.
6 | """
7 |
8 | import os
9 | from fractions import Fraction
10 |
11 | import cv2
12 | import numpy as np
13 | import exifread
14 | # from exifread import Ratio
15 | from exifread.utils import Ratio
16 | import rawpy
17 | from scipy.io import loadmat
18 | from colour_demosaicing import demosaicing_CFA_Bayer_Menon2007
19 | import struct
20 |
21 | from .dng_opcode import parse_opcode_lists
22 | from .exif_data_formats import exif_formats
23 | from .exif_utils import parse_exif_tag, parse_exif, get_tag_values_from_ifds
24 | from isp import demosaic_bayer
25 |
26 |
27 | def get_visible_raw_image(image_path):
28 | raw_image = rawpy.imread(image_path).raw_image_visible.copy()
29 | # raw_image = rawpy.imread(image_path).raw_image.copy()
30 | return raw_image
31 |
32 |
33 | def get_image_tags(image_path):
34 | with open(image_path, 'rb') as f:
35 | tags = exifread.process_file(f)
36 | return tags
37 |
38 |
39 | def get_image_ifds(image_path):
40 | ifds = parse_exif(image_path, verbose=False)
41 | return ifds
42 |
43 |
44 | def get_metadata(image_path):
45 | metadata = {}
46 |     tags = image_path   # in this pipeline a metadata dict is passed in directly
47 |     ifds = image_path   # (file-based loading is kept below, commented out)
48 | # tags = get_image_tags(image_path)
49 | # ifds = get_image_ifds(image_path)
50 | metadata['linearization_table'] = get_linearization_table(tags, ifds)
51 | metadata['black_level'] = get_black_level(tags, ifds)
52 | metadata['white_level'] = get_white_level(tags, ifds)
53 | metadata['cfa_pattern'] = get_cfa_pattern(tags, ifds)
54 | metadata['as_shot_neutral'] = get_as_shot_neutral(tags, ifds)
55 | color_matrix_1, color_matrix_2 = get_color_matrices(tags, ifds)
56 | metadata['color_matrix_1'] = color_matrix_1
57 | metadata['color_matrix_2'] = color_matrix_2
58 | metadata['orientation'] = get_orientation(tags, ifds)
59 | metadata['noise_profile'] = get_noise_profile(tags, ifds)
60 | # ...
61 |
62 | # opcode lists
63 | # metadata['opcode_lists'] = parse_opcode_lists(ifds)
64 |
65 | # fall back to default values, if necessary
66 | if metadata['black_level'] is None:
67 | metadata['black_level'] = 0
68 | print("Black level is None; using 0.")
69 | if metadata['white_level'] is None:
70 | metadata['white_level'] = 2 ** 16
71 | print("White level is None; using 2 ** 16.")
72 | if metadata['cfa_pattern'] is None:
73 | metadata['cfa_pattern'] = [0, 1, 1, 2]
74 | print("CFAPattern is None; using [0, 1, 1, 2] (RGGB)")
75 | if metadata['as_shot_neutral'] is None:
76 | metadata['as_shot_neutral'] = [1, 1, 1]
77 | print("AsShotNeutral is None; using [1, 1, 1]")
78 | if metadata['color_matrix_1'] is None:
79 | metadata['color_matrix_1'] = [1] * 9
80 | print("ColorMatrix1 is None; using [1, 1, 1, 1, 1, 1, 1, 1, 1]")
81 | if metadata['color_matrix_2'] is None:
82 | metadata['color_matrix_2'] = [1] * 9
83 | print("ColorMatrix2 is None; using [1, 1, 1, 1, 1, 1, 1, 1, 1]")
84 | if metadata['orientation'] is None:
85 | metadata['orientation'] = 0
86 | print("Orientation is None; using 0.")
87 | # ...
88 | return metadata
89 |
90 |
91 | def get_linearization_table(tags, ifds):
92 | possible_keys = ['Image Tag 0xC618', 'Image Tag 50712', 'LinearizationTable', 'Image LinearizationTable']
93 | return get_values(tags, possible_keys)
94 |
95 |
96 | def get_black_level(tags, ifds):
97 | possible_keys = ['Image Tag 0xC61A', 'Image Tag 50714', 'BlackLevel', 'Image BlackLevel']
98 | vals = get_values(tags, possible_keys)
99 | if vals is None:
100 | # print("Black level not found in exifread tags. Searching IFDs.")
101 | vals = get_tag_values_from_ifds(50714, ifds)
102 | return vals
103 |
104 |
105 | def get_white_level(tags, ifds):
106 | possible_keys = ['Image Tag 0xC61D', 'Image Tag 50717', 'WhiteLevel', 'Image WhiteLevel']
107 | vals = get_values(tags, possible_keys)
108 | if vals is None:
109 | # print("White level not found in exifread tags. Searching IFDs.")
110 | vals = get_tag_values_from_ifds(50717, ifds)
111 | return vals
112 |
113 |
114 | def get_cfa_pattern(tags, ifds):
115 | possible_keys = ['CFAPattern', 'Image CFAPattern']
116 | vals = get_values(tags, possible_keys)
117 | if vals is None:
118 | # print("CFAPattern not found in exifread tags. Searching IFDs.")
119 | vals = get_tag_values_from_ifds(33422, ifds)
120 | return vals
121 |
122 |
123 | def get_as_shot_neutral(tags, ifds):
124 | possible_keys = ['Image Tag 0xC628', 'Image Tag 50728', 'AsShotNeutral', 'Image AsShotNeutral']
125 | return get_values(tags, possible_keys)
126 |
127 |
128 | def get_color_matrices(tags, ifds):
129 | possible_keys_1 = ['Image Tag 0xC621', 'Image Tag 50721', 'ColorMatrix1', 'Image ColorMatrix1']
130 | color_matrix_1 = get_values(tags, possible_keys_1)
131 | possible_keys_2 = ['Image Tag 0xC622', 'Image Tag 50722', 'ColorMatrix2', 'Image ColorMatrix2']
132 | color_matrix_2 = get_values(tags, possible_keys_2)
133 | return color_matrix_1, color_matrix_2
134 |
135 |
136 | def get_orientation(tags, ifds):
137 | possible_tags = ['Orientation', 'Image Orientation']
138 | return get_values(tags, possible_tags)
139 |
140 |
141 | def get_noise_profile(tags, ifds):
142 | possible_keys = ['Image Tag 0xC761', 'Image Tag 51041', 'NoiseProfile', 'Image NoiseProfile']
143 | vals = get_values(tags, possible_keys)
144 | if vals is None:
145 | # print("Noise profile not found in exifread tags. Searching IFDs.")
146 | vals = get_tag_values_from_ifds(51041, ifds)
147 | return vals
148 |
149 |
150 | def get_values(tags, possible_keys):
151 | values = None
152 | for key in possible_keys:
153 | if key in tags.keys():
154 | values = tags[key] #.values
155 | return values
156 |
157 |
158 | def normalize(raw_image, black_level, white_level):
159 | if type(black_level) is list and len(black_level) == 1:
160 | black_level = float(black_level[0])
161 | if type(white_level) is list and len(white_level) == 1:
162 | white_level = float(white_level[0])
163 | black_level_mask = black_level
164 | if type(black_level) is list and len(black_level) == 4:
165 | if type(black_level[0]) is Ratio:
166 | black_level = ratios2floats(black_level)
167 | black_level_mask = np.zeros(raw_image.shape)
168 | idx2by2 = [[0, 0], [0, 1], [1, 0], [1, 1]]
169 | step2 = 2
170 | for i, idx in enumerate(idx2by2):
171 | black_level_mask[idx[0]::step2, idx[1]::step2] = black_level[i]
172 | normalized_image = raw_image.astype(np.float32) - black_level_mask
173 | # if some values were smaller than black level
174 | normalized_image[normalized_image < 0] = 0
175 | normalized_image = normalized_image / (white_level - black_level_mask)
176 | return normalized_image
177 |
178 |
179 | def ratios2floats(ratios):
180 | floats = []
181 | for ratio in ratios:
182 | floats.append(float(ratio.num) / ratio.den)
183 | return floats
184 |
185 |
186 | def lens_shading_correction(raw_image, gain_map_opcode, bayer_pattern, gain_map=None, clip=True):
187 | """
188 | Apply lens shading correction map.
189 | :param raw_image: Input normalized (in [0, 1]) raw image.
190 | :param gain_map_opcode: Gain map opcode.
191 | :param bayer_pattern: Bayer pattern (RGGB, GRBG, ...).
192 | :param gain_map: Optional gain map to replace gain_map_opcode. 1 or 4 channels in order: R, Gr, Gb, and B.
193 | :param clip: Whether to clip result image to [0, 1].
194 | :return: Image with gain map applied; lens shading corrected.
195 | """
196 |
197 | if gain_map is None and gain_map_opcode:
198 | gain_map = gain_map_opcode.data['map_gain_2d']
199 |
200 | # resize gain map, make it 4 channels, if needed
201 | gain_map = cv2.resize(gain_map, dsize=(raw_image.shape[1] // 2, raw_image.shape[0] // 2),
202 | interpolation=cv2.INTER_LINEAR)
203 | if len(gain_map.shape) == 2:
204 | gain_map = np.tile(gain_map[..., np.newaxis], [1, 1, 4])
205 |
206 | if gain_map_opcode:
207 | # TODO: consider other parameters
208 |
209 | top = gain_map_opcode.data['top']
210 | left = gain_map_opcode.data['left']
211 | bottom = gain_map_opcode.data['bottom']
212 | right = gain_map_opcode.data['right']
213 | rp = gain_map_opcode.data['row_pitch']
214 | cp = gain_map_opcode.data['col_pitch']
215 |
216 | gm_w = right - left
217 | gm_h = bottom - top
218 |
219 | # gain_map = cv2.resize(gain_map, dsize=(gm_w, gm_h), interpolation=cv2.INTER_LINEAR)
220 |
221 | # TODO
222 | # if top > 0:
223 | # pass
224 | # elif left > 0:
225 | # left_col = gain_map[:, 0:1]
226 | # rep_left_col = np.tile(left_col, [1, left])
227 | # gain_map = np.concatenate([rep_left_col, gain_map], axis=1)
228 | # elif bottom < raw_image.shape[0]:
229 | # pass
230 | # elif right < raw_image.shape[1]:
231 | # pass
232 |
233 | result_image = raw_image.copy()
234 |
235 | # one channel
236 | # result_image[::rp, ::cp] *= gain_map[::rp, ::cp]
237 |
238 | # per bayer channel
239 | upper_left_idx = [[0, 0], [0, 1], [1, 0], [1, 1]]
240 | bayer_pattern_idx = np.array(bayer_pattern)
241 | # blue channel index --> 3
242 | bayer_pattern_idx[bayer_pattern_idx == 2] = 3
243 | # second green channel index --> 2
244 | if bayer_pattern_idx[3] == 1:
245 | bayer_pattern_idx[3] = 2
246 | else:
247 | bayer_pattern_idx[2] = 2
248 | for c in range(4):
249 | i0 = upper_left_idx[c][0]
250 | j0 = upper_left_idx[c][1]
251 | result_image[i0::2, j0::2] *= gain_map[:, :, bayer_pattern_idx[c]]
252 |
253 | if clip:
254 | result_image = np.clip(result_image, 0.0, 1.0)
255 |
256 | return result_image
257 |
258 |
259 | def white_balance(normalized_image, as_shot_neutral, cfa_pattern):
260 | if type(as_shot_neutral[0]) is Ratio:
261 | as_shot_neutral = ratios2floats(as_shot_neutral)
262 | idx2by2 = [[0, 0], [0, 1], [1, 0], [1, 1]]
263 | step2 = 2
264 | white_balanced_image = np.zeros(normalized_image.shape)
265 | for i, idx in enumerate(idx2by2):
266 | idx_y = idx[0]
267 | idx_x = idx[1]
268 | white_balanced_image[idx_y::step2, idx_x::step2] = \
269 | normalized_image[idx_y::step2, idx_x::step2] / as_shot_neutral[cfa_pattern[i]]
270 | white_balanced_image = np.clip(white_balanced_image, 0.0, 1.0)
271 | return white_balanced_image
272 |
273 |
274 | def get_opencv_demsaic_flag(cfa_pattern, output_channel_order, alg_type='VNG'):
275 | # using opencv edge-aware demosaicing
276 | if alg_type != '':
277 | alg_type = '_' + alg_type
278 | if output_channel_order == 'BGR':
279 | if cfa_pattern == [0, 1, 1, 2]: # RGGB
280 | opencv_demosaic_flag = eval('cv2.COLOR_BAYER_BG2BGR' + alg_type)
281 | elif cfa_pattern == [2, 1, 1, 0]: # BGGR
282 | opencv_demosaic_flag = eval('cv2.COLOR_BAYER_RG2BGR' + alg_type)
283 | elif cfa_pattern == [1, 0, 2, 1]: # GRBG
284 | opencv_demosaic_flag = eval('cv2.COLOR_BAYER_GB2BGR' + alg_type)
285 | elif cfa_pattern == [1, 2, 0, 1]: # GBRG
286 | opencv_demosaic_flag = eval('cv2.COLOR_BAYER_GR2BGR' + alg_type)
287 | else:
288 | opencv_demosaic_flag = eval('cv2.COLOR_BAYER_BG2BGR' + alg_type)
289 | print("CFA pattern not identified.")
290 | else: # RGB
291 | if cfa_pattern == [0, 1, 1, 2]: # RGGB
292 | opencv_demosaic_flag = eval('cv2.COLOR_BAYER_BG2RGB' + alg_type)
293 | elif cfa_pattern == [2, 1, 1, 0]: # BGGR
294 | opencv_demosaic_flag = eval('cv2.COLOR_BAYER_RG2RGB' + alg_type)
295 | elif cfa_pattern == [1, 0, 2, 1]: # GRBG
296 | opencv_demosaic_flag = eval('cv2.COLOR_BAYER_GB2RGB' + alg_type)
297 | elif cfa_pattern == [1, 2, 0, 1]: # GBRG
298 | opencv_demosaic_flag = eval('cv2.COLOR_BAYER_GR2RGB' + alg_type)
299 | else:
300 | opencv_demosaic_flag = eval('cv2.COLOR_BAYER_BG2RGB' + alg_type)
301 | print("CFA pattern not identified.")
302 | return opencv_demosaic_flag
303 |
304 |
305 | def demosaic(white_balanced_image, cfa_pattern, output_channel_order='BGR', alg_type='VNG', device='0'):
306 | """
307 | Demosaic a Bayer image.
308 | :param white_balanced_image:
309 | :param cfa_pattern:
310 | :param output_channel_order:
311 |     :param alg_type: algorithm type. options: '', 'EA' (edge-aware), 'VNG' (variable number of gradients), 'menon2007', 'down', 'net'
312 | :return: Demosaiced image
313 | """
314 | if alg_type == 'VNG':
315 |         max_val = 255  # OpenCV's VNG demosaicing expects 8-bit input; scaling by 16383 would overflow uint8
316 |         wb_image = (white_balanced_image * max_val).astype(dtype=np.uint8)
317 | else:
318 | max_val = 16383
319 | wb_image = (white_balanced_image * max_val).astype(dtype=np.uint16)
320 |
321 | if alg_type in ['', 'EA', 'VNG']:
322 | opencv_demosaic_flag = get_opencv_demsaic_flag(cfa_pattern, output_channel_order, alg_type=alg_type)
323 | demosaiced_image = cv2.cvtColor(wb_image, opencv_demosaic_flag)
324 | elif alg_type == 'menon2007':
325 | cfa_pattern_str = "".join(["RGB"[i] for i in cfa_pattern])
326 | demosaiced_image = demosaicing_CFA_Bayer_Menon2007(wb_image, pattern=cfa_pattern_str)
327 | elif alg_type == 'down': # G R B G
328 | demosaiced_image = np.zeros([wb_image.shape[0]//2, wb_image.shape[1]//2, 3], dtype=np.float32)
329 | demosaiced_image[:,:,0] = wb_image[0::2, 1::2]
330 | demosaiced_image[:,:,1] = wb_image[0::2, 0::2]
331 | demosaiced_image[:,:,2] = wb_image[1::2, 0::2]
332 | elif alg_type == 'net':
333 | cfa_pattern_str = "".join(["RGB"[i] for i in cfa_pattern])
334 | bayer = np.power(np.clip(wb_image.astype(dtype=np.float32) / max_val, 0, 1), 1 / 2.2)
335 | pretrained_model_path = os.path.dirname(__file__) + "/model.bin"
336 | demosaic_net = demosaic_bayer.get_demosaic_net_model(pretrained=pretrained_model_path, device=device,
337 | cfa='bayer', state_dict=True)
338 | rgb = demosaic_bayer.demosaic_by_demosaic_net(bayer=bayer, cfa=cfa_pattern_str,
339 | demosaic_net=demosaic_net, device=device)
340 | demosaiced_image = np.power(np.clip(rgb, 0, 1), 2.2) * max_val
341 |
342 | demosaiced_image = demosaiced_image.astype(dtype=np.float32) / max_val
343 |
344 | return demosaiced_image
345 |
346 |
347 | def apply_color_space_transform(demosaiced_image, color_matrix_1, color_matrix_2):
348 | if type(color_matrix_1[0]) is Ratio:
349 | color_matrix_1 = ratios2floats(color_matrix_1)
350 | if type(color_matrix_2[0]) is Ratio:
351 | color_matrix_2 = ratios2floats(color_matrix_2)
352 | xyz2cam1 = np.reshape(np.asarray(color_matrix_1), (3, 3))
353 | xyz2cam2 = np.reshape(np.asarray(color_matrix_2), (3, 3))
354 | # normalize rows (needed?)
355 | xyz2cam1 = xyz2cam1 / np.sum(xyz2cam1, axis=1, keepdims=True)
356 |     xyz2cam2 = xyz2cam2 / np.sum(xyz2cam2, axis=1, keepdims=True)
357 | # inverse
358 | cam2xyz1 = np.linalg.inv(xyz2cam1)
359 | cam2xyz2 = np.linalg.inv(xyz2cam2)
360 |     # for now, use one matrix # TODO: interpolate between both
361 | # simplified matrix multiplication
362 | xyz_image = cam2xyz1[np.newaxis, np.newaxis, :, :] * demosaiced_image[:, :, np.newaxis, :]
363 | xyz_image = np.sum(xyz_image, axis=-1)
364 | xyz_image = np.clip(xyz_image, 0.0, 1.0)
365 | return xyz_image
366 |
367 |
368 | def transform_xyz_to_srgb(xyz_image):
369 | # srgb2xyz = np.array([[0.4124564, 0.3575761, 0.1804375],
370 | # [0.2126729, 0.7151522, 0.0721750],
371 | # [0.0193339, 0.1191920, 0.9503041]])
372 |
373 | # xyz2srgb = np.linalg.inv(srgb2xyz)
374 |
375 | xyz2srgb = np.array([[3.2404542, -1.5371385, -0.4985314],
376 | [-0.9692660, 1.8760108, 0.0415560],
377 | [0.0556434, -0.2040259, 1.0572252]])
378 |
379 | # normalize rows (needed?)
380 | xyz2srgb = xyz2srgb / np.sum(xyz2srgb, axis=-1, keepdims=True)
381 |
382 | srgb_image = xyz2srgb[np.newaxis, np.newaxis, :, :] * xyz_image[:, :, np.newaxis, :]
383 | srgb_image = np.sum(srgb_image, axis=-1)
384 | srgb_image = np.clip(srgb_image, 0.0, 1.0)
385 | return srgb_image
386 |
387 |
388 | def fix_orientation(image, orientation):
389 | # 1 = Horizontal(normal)
390 | # 2 = Mirror horizontal
391 | # 3 = Rotate 180
392 | # 4 = Mirror vertical
393 | # 5 = Mirror horizontal and rotate 270 CW
394 | # 6 = Rotate 90 CW
395 | # 7 = Mirror horizontal and rotate 90 CW
396 | # 8 = Rotate 270 CW
397 |
398 | if type(orientation) is list:
399 | orientation = orientation[0]
400 |
401 | if orientation == 1:
402 | pass
403 | elif orientation == 2:
404 | image = cv2.flip(image, 0)
405 | elif orientation == 3:
406 | image = cv2.rotate(image, cv2.ROTATE_180)
407 | elif orientation == 4:
408 | image = cv2.flip(image, 1)
409 | elif orientation == 5:
410 | image = cv2.flip(image, 0)
411 | image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
412 | elif orientation == 6:
413 | image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
414 | elif orientation == 7:
415 | image = cv2.flip(image, 0)
416 | image = cv2.rotate(image, cv2.ROTATE_90_CLOCKWISE)
417 | elif orientation == 8:
418 | image = cv2.rotate(image, cv2.ROTATE_90_COUNTERCLOCKWISE)
419 |
420 | return image
421 |
422 |
423 | def reverse_orientation(image, orientation):
424 | # 1 = Horizontal(normal)
425 | # 2 = Mirror horizontal
426 | # 3 = Rotate 180
427 | # 4 = Mirror vertical
428 | # 5 = Mirror horizontal and rotate 270 CW
429 | # 6 = Rotate 90 CW
430 | # 7 = Mirror horizontal and rotate 90 CW
431 | # 8 = Rotate 270 CW
432 | rev_orientations = np.array([1, 2, 3, 4, 5, 8, 7, 6])
433 | return fix_orientation(image, rev_orientations[orientation - 1])
434 |
435 |
436 | def apply_gamma(x):
437 | return x ** (1.0 / 2.2)
438 |
439 |
440 | def apply_tone_map(x):
441 | # simple tone curve
442 | # return 3 * x ** 2 - 2 * x ** 3
443 |
444 | # tone_curve = loadmat('tone_curve.mat')
445 | tone_curve = loadmat(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'tone_curve.mat'))
446 | tone_curve = tone_curve['tc']
447 | x = np.round(x * (len(tone_curve) - 1)).astype(int)
448 | tone_mapped_image = np.squeeze(tone_curve[x])
449 | return tone_mapped_image
450 |
451 |
452 | def raw_rgb_to_cct(rawRgb, xyz2cam1, xyz2cam2):
453 | """Convert raw-RGB triplet to corresponding correlated color temperature (CCT)"""
454 | pass
455 | # pxyz = [.5, 1, .5]
456 | # loss = 1e10
457 | # k = 1
458 | # while loss > 1e-4:
459 | # cct = XyzToCct(pxyz)
460 | # xyz = RawRgbToXyz(rawRgb, cct, xyz2cam1, xyz2cam2)
461 | # loss = norm(xyz - pxyz)
462 | # pxyz = xyz
463 | # fprintf('k = %d, loss = %f\n', [k, loss])
464 | # k = k + 1
465 | # end
466 | # temp = cct
467 |
--------------------------------------------------------------------------------
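Most of these helpers are pure NumPy and can be exercised in isolation. A small sketch of the normalization and white-balance steps (the values are illustrative; a length-4 black level is broadcast over the 2x2 Bayer tiles):

```python
import numpy as np
from isp.pipeline_utils import normalize, white_balance

raw = np.random.randint(64, 1024, size=(64, 64)).astype(np.float32)
norm = normalize(raw, black_level=[64, 64, 64, 64], white_level=[1023])
wb = white_balance(norm, as_shot_neutral=[0.5, 1.0, 0.6],
                   cfa_pattern=[0, 1, 1, 2])   # RGGB
print(norm.max() <= 1.0, wb.max() <= 1.0)      # True True
```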
/isp/tone_curve.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/isp/tone_curve.mat
--------------------------------------------------------------------------------
/mixer.yaml:
--------------------------------------------------------------------------------
1 | name: mixer
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/msys2/
6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
9 | - defaults
10 | dependencies:
11 | - _libgcc_mutex=0.1=main
12 | - _openmp_mutex=5.1=1_gnu
13 | - blas=1.0=mkl
14 | - bzip2=1.0.8=h5eee18b_5
15 | - ca-certificates=2023.12.12=h06a4308_0
16 | - certifi=2024.2.2=py39h06a4308_0
17 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
18 | - cuda-cudart=11.8.89=0
19 | - cuda-cupti=11.8.87=0
20 | - cuda-libraries=11.8.0=0
21 | - cuda-nvrtc=11.8.89=0
22 | - cuda-nvtx=11.8.86=0
23 | - cuda-runtime=11.8.0=0
24 | - ffmpeg=4.3=hf484d3e_0
25 | - filelock=3.13.1=py39h06a4308_0
26 | - freetype=2.12.1=h4a9f257_0
27 | - gmp=6.2.1=h295c915_3
28 | - gmpy2=2.1.2=py39heeb90bb_0
29 | - gnutls=3.6.15=he1e5248_0
30 | - idna=3.4=py39h06a4308_0
31 | - intel-openmp=2023.1.0=hdb19cb5_46306
32 | - jinja2=3.1.3=py39h06a4308_0
33 | - jpeg=9e=h5eee18b_1
34 | - lame=3.100=h7b6447c_0
35 | - lcms2=2.12=h3be6417_0
36 | - ld_impl_linux-64=2.38=h1181459_1
37 | - lerc=3.0=h295c915_0
38 | - libcublas=11.11.3.6=0
39 | - libcufft=10.9.0.58=0
40 | - libcufile=1.8.1.2=0
41 | - libcurand=10.3.4.107=0
42 | - libcusolver=11.4.1.48=0
43 | - libcusparse=11.7.5.86=0
44 | - libdeflate=1.17=h5eee18b_1
45 | - libffi=3.4.4=h6a678d5_0
46 | - libgcc-ng=11.2.0=h1234567_1
47 | - libgomp=11.2.0=h1234567_1
48 | - libiconv=1.16=h7f8727e_2
49 | - libidn2=2.3.4=h5eee18b_0
50 | - libjpeg-turbo=2.0.0=h9bf148f_0
51 | - libnpp=11.8.0.86=0
52 | - libnvjpeg=11.9.0.86=0
53 | - libpng=1.6.39=h5eee18b_0
54 | - libstdcxx-ng=11.2.0=h1234567_1
55 | - libtasn1=4.19.0=h5eee18b_0
56 | - libtiff=4.5.1=h6a678d5_0
57 | - libunistring=0.9.10=h27cfd23_0
58 | - libwebp-base=1.3.2=h5eee18b_0
59 | - llvm-openmp=14.0.6=h9e868ea_0
60 | - lz4-c=1.9.4=h6a678d5_0
61 | - markupsafe=2.1.3=py39h5eee18b_0
62 | - mkl=2023.1.0=h213fc3f_46344
63 | - mkl-service=2.4.0=py39h5eee18b_1
64 | - mkl_fft=1.3.8=py39h5eee18b_0
65 | - mkl_random=1.2.4=py39hdb19cb5_0
66 | - mpc=1.1.0=h10f8cd9_1
67 | - mpfr=4.0.2=hb69a4c5_1
68 | - mpmath=1.3.0=py39h06a4308_0
69 | - ncurses=6.4=h6a678d5_0
70 | - nettle=3.7.3=hbbd107a_1
71 | - networkx=3.1=py39h06a4308_0
72 | - numpy=1.26.4=py39h5f9d8c6_0
73 | - numpy-base=1.26.4=py39hb5e798b_0
74 | - openh264=2.1.1=h4ff587b_0
75 | - openjpeg=2.4.0=h3ad879b_0
76 | - openssl=3.0.13=h7f8727e_0
77 | - pillow=10.2.0=py39h5eee18b_0
78 | - pip=23.3.1=py39h06a4308_0
79 | - python=3.9.16=h955ad1f_3
80 | - pytorch-cuda=11.8=h7e8668a_5
81 | - pytorch-mutex=1.0=cuda
82 | - pyyaml=6.0.1=py39h5eee18b_0
83 | - readline=8.2=h5eee18b_0
84 | - requests=2.31.0=py39h06a4308_1
85 | - setuptools=68.2.2=py39h06a4308_0
86 | - sqlite=3.41.2=h5eee18b_0
87 | - sympy=1.12=py39h06a4308_0
88 | - tbb=2021.8.0=hdb19cb5_0
89 | - tk=8.6.12=h1ccaba5_0
90 | - torchaudio=2.2.1=py39_cu118
91 | - typing_extensions=4.9.0=py39h06a4308_1
92 | - tzdata=2024a=h04d1e81_0
93 | - urllib3=2.1.0=py39h06a4308_0
94 | - wheel=0.41.2=py39h06a4308_0
95 | - xz=5.4.6=h5eee18b_0
96 | - yaml=0.2.5=h7b6447c_0
97 | - zlib=1.2.13=h5eee18b_0
98 | - zstd=1.5.5=hc292b87_0
99 | - pip:
100 | - absl-py==2.1.0
101 | - accelerate==0.27.2
102 | - addict==2.4.0
103 | - basicsr==1.4.2
104 | - bypy==1.8.5
105 | - cmake==3.28.3
106 | - colour-demosaicing==0.2.5
107 | - colour-science==0.4.4
108 | - contourpy==1.2.0
109 | - cycler==0.12.1
110 | - dill==0.3.8
111 | - einops==0.7.0
112 | - fonttools==4.50.0
113 | - fsspec==2024.2.0
114 | - future==1.0.0
115 | - grpcio==1.62.0
116 | - huggingface-hub==0.21.3
117 | - imageio==2.34.0
118 | - importlib-metadata==7.0.1
119 | - importlib-resources==6.4.0
120 | - kiwisolver==1.4.5
121 | - lazy-loader==0.3
122 | - lit==17.0.6
123 | - lmdb==1.4.1
124 | - markdown==3.5.2
125 | - matplotlib==3.8.3
126 | - multiprocess==0.70.16
127 | - nvidia-cublas-cu11==11.10.3.66
128 | - nvidia-cuda-cupti-cu11==11.7.101
129 | - nvidia-cuda-nvrtc-cu11==11.7.99
130 | - nvidia-cuda-runtime-cu11==11.7.99
131 | - nvidia-cudnn-cu11==8.5.0.96
132 | - nvidia-cufft-cu11==10.9.0.58
133 | - nvidia-curand-cu11==10.2.10.91
134 | - nvidia-cusolver-cu11==11.4.0.1
135 | - nvidia-cusparse-cu11==11.7.4.91
136 | - nvidia-nccl-cu11==2.14.3
137 | - nvidia-nvtx-cu11==11.7.91
138 | - opencv-python==4.9.0.80
139 | - packaging==23.2
140 | - platformdirs==4.2.0
141 | - protobuf==4.25.3
142 | - psutil==5.9.8
143 | - pyparsing==3.1.2
144 | - python-dateutil==2.9.0.post0
145 | - pywavelets==1.5.0
146 | - requests-toolbelt==1.0.0
147 | - safetensors==0.4.2
148 | - scikit-image==0.22.0
149 | - scipy==1.12.0
150 | - six==1.16.0
151 | - tb-nightly==2.17.0a20240229
152 | - tensorboard-data-server==0.7.2
153 | - tensorboardx==2.6.2.2
154 | - thop==0.1.1-2209072238
155 | - tifffile==2024.2.12
156 | - timm==0.9.16
157 | - tomli==2.0.1
158 | - torch==2.0.1
159 | - torchvision==0.15.2
160 | - tqdm==4.66.2
161 | - triton==2.0.0
162 | - werkzeug==3.0.1
163 | - yapf==0.40.2
164 | - zipp==3.17.0
165 |
166 |
--------------------------------------------------------------------------------
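The pinned environment above can be recreated with `conda env create -f mixer.yaml` followed by `conda activate mixer`; conda installs the `pip:` section as part of the same command.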
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from models.base_model import BaseModel
3 |
4 |
5 | def find_model_using_name(model_name):
6 | """Import the module "models/[model_name]_model.py".
7 |
8 |     In the file, the class called [ModelName]Model will
9 |     be instantiated. It has to be a subclass of BaseModel,
10 |     and the name matching is case-insensitive.
11 | """
12 | model_filename = "models." + model_name + "_model"
13 | modellib = importlib.import_module(model_filename)
14 | model = None
15 | target_model_name = model_name.replace('_', '') + 'model'
16 | for name, cls in modellib.__dict__.items():
17 | if name.lower() == target_model_name.lower() \
18 | and issubclass(cls, BaseModel):
19 | model = cls
20 |
21 | if model is None:
22 | raise NotImplementedError("In %s.py, there should be a subclass of "
23 | "BaseModel with class name that matches %s in "
24 | "lowercase." % (model_filename, target_model_name))
25 |
26 | return model
27 |
28 |
29 | def get_option_setter(model_name):
30 | model_class = find_model_using_name(model_name)
31 | return model_class.modify_commandline_options
32 |
33 |
34 | def create_model(opt):
35 | """Create a model given the option.
36 |
37 |     This function wraps the creation of a model class instance.
38 | This is the main interface between this package and 'train.py'/'test.py'
39 |
40 | Example:
41 | >>> from models import create_model
42 | >>> model = create_model(opt)
43 | """
44 | model = find_model_using_name(opt.model)
45 | instance = model(opt)
46 | print("model [%s] was created" % type(instance).__name__)
47 | return instance
48 |
--------------------------------------------------------------------------------
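`create_model` only needs an options object with a `model` attribute to resolve the class; everything else depends on the chosen model's constructor. A hedged sketch (the attribute set below is a guess at the minimum that `BaseModel.__init__` touches, and it assumes `models/cat_model.py` defines a `BaseModel` subclass whose lowercased name is `catmodel`):

```python
from types import SimpleNamespace
from models import create_model

opt = SimpleNamespace(model='cat', gpu_ids=[], isTrain=False,
                      checkpoints_dir='./ckpt', name='demo')
model = create_model(opt)   # imports models/cat_model.py and instantiates the match
```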
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from abc import ABC, abstractmethod
5 | from . import networks
6 | import torch
7 | from util.util import torch_save
8 | import math
9 | import torch.nn.functional as F
10 | from data.degrade.process import demosaic
11 |
12 |
13 | class BaseModel(ABC):
14 | def __init__(self, opt):
15 | self.opt = opt
16 | self.gpu_ids = opt.gpu_ids
17 | self.isTrain = opt.isTrain
18 | # self.scale = opt.scale
19 |
20 | if len(self.gpu_ids) > 0:
21 | self.device = torch.device('cuda', self.gpu_ids[0])
22 | else:
23 | self.device = torch.device('cpu')
24 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
25 | self.loss_names = []
26 | self.model_names = []
27 | self.visual_names = []
28 | self.optimizers = []
29 | self.optimizer_names = []
30 | self.image_paths = []
31 | self.metric = 0 # used for learning rate policy 'plateau'
32 | self.start_epoch = 0
33 |
34 | self.backwarp_tenGrid = {}
35 | self.backwarp_tenPartial = {}
36 |
37 | @staticmethod
38 | def modify_commandline_options(parser, is_train):
39 | return parser
40 |
41 | @abstractmethod
42 | def set_input(self, input):
43 | pass
44 |
45 | @abstractmethod
46 | def forward(self):
47 | pass
48 |
49 | @abstractmethod
50 | def optimize_parameters(self):
51 | pass
52 |
53 | def setup(self, opt=None):
54 | opt = opt if opt is not None else self.opt
55 | if self.isTrain:
56 | self.schedulers = [networks.get_scheduler(optimizer, opt) \
57 | for optimizer in self.optimizers]
58 | for scheduler in self.schedulers:
59 | scheduler.last_epoch = opt.load_iter
60 | if opt.load_iter > 0 or opt.load_path != '':
61 | load_suffix = opt.load_iter
62 | self.load_networks(load_suffix)
63 | if opt.load_optimizers:
64 | self.load_optimizers(opt.load_iter)
65 |
66 | self.print_networks(opt.verbose)
67 |
68 | def eval(self):
69 | for name in self.model_names:
70 | net = getattr(self, 'net' + name)
71 | net.eval()
72 |
73 | def train(self):
74 | self.isTrain = True
75 | for name in self.model_names:
76 | net = getattr(self, 'net' + name)
77 | net.train()
78 |
79 | def test(self):
80 | self.isTrain = False
81 | with torch.no_grad():
82 | self.forward()
83 |
84 | def get_image_paths(self):
85 | return self.image_paths
86 |
87 | def update_learning_rate(self, epoch):
88 | for i, scheduler in enumerate(self.schedulers):
89 | if scheduler.__class__.__name__ == 'ReduceLROnPlateau':
90 | scheduler.step(self.metric)
91 | elif scheduler.__class__.__name__ == 'CosineLRScheduler':
92 | scheduler.step(epoch)
93 | else:
94 | scheduler.step()
95 | print('lr of %s = %.7f' % (
96 | self.optimizer_names[i], self.optimizers[i].param_groups[0]['lr']))
97 |
98 | def post_process(self, image, max=255):
99 | image = image.permute(0, 2, 3, 1)
100 | image = image / image.max()
101 | image = demosaic(image)
102 | image = image.clamp(0.0, 1.0) ** (1/2.2)
103 | if max == 255:
104 | image = torch.clamp(image * 255, 0, 255).round()
105 | image = image.permute(0, 3, 1, 2)
106 | return image
107 |
108 | def get_current_visuals(self):
109 | visual_ret = OrderedDict()
110 | if self.isTrain:
111 | for name in self.visual_names:
112 | if isinstance(getattr(self, name), list):
113 | visual_ret[name] = self.post_process(getattr(self, name)[-1][0:1].detach())
114 | elif isinstance(getattr(self, name), torch.Tensor):
115 | visual_ret[name] = self.post_process(getattr(self, name)[0:1].detach())
116 | else:
117 | raise Exception
118 | # visual_ret[name] = getattr(self, name)[0:1].detach()
119 | else:
120 | for name in self.visual_names:
121 | visual_ret[name] = getattr(self, name).clamp_min_(0).detach()
122 | return visual_ret
123 |
124 | def get_current_losses(self):
125 | errors_ret = OrderedDict()
126 | for name in self.loss_names:
127 | errors_ret[name] = float(getattr(self, 'loss_' + name))
128 | return errors_ret
129 |
130 | def save_networks(self, epoch):
131 | for name in self.model_names:
132 | save_filename = '%s_model_%d.pth' % (name, epoch)
133 | save_path = os.path.join(self.save_dir, save_filename)
134 | net = getattr(self, 'net' + name)
135 | if self.device.type == 'cuda':
136 | state = {'state_dict': net.module.cpu().state_dict()}
137 | torch_save(state, save_path)
138 | net.to(self.device)
139 | else:
140 | state = {'state_dict': net.state_dict()}
141 | torch_save(state, save_path)
142 | self.save_optimizers(epoch)
143 |
144 | def load_networks(self, epoch):
145 | for name in self.model_names: #[0:1]:
146 | load_filename = '%s_model_%d.pth' % (name, epoch)
147 | if self.opt.load_path != '':
148 | load_path = self.opt.load_path
149 | else:
150 | load_path = os.path.join(self.save_dir, load_filename)
151 | net = getattr(self, 'net' + name)
152 |
153 | state_dict = torch.load(load_path, map_location=self.device)
154 | print('loading the model from %s' % (load_path))
155 | if hasattr(state_dict, '_metadata'):
156 | del state_dict._metadata
157 |
158 | net_state = net.state_dict()
159 | is_loaded = {n:False for n in net_state.keys()}
160 | for name, param in state_dict['state_dict'].items():
161 | if name in net_state:
162 | try:
163 | net_state[name].copy_(param)
164 | is_loaded[name] = True
165 | except Exception:
166 | print('While copying the parameter named [%s], '
167 | 'whose dimensions in the model are %s and '
168 | 'whose dimensions in the checkpoint are %s.'
169 | % (name, list(net_state[name].shape),
170 | list(param.shape)))
171 | raise RuntimeError
172 | else:
173 | print('Saved parameter named [%s] is skipped' % name)
174 | mark = True
175 | for name in is_loaded:
176 | if not is_loaded[name]:
177 | print('Parameter named [%s] is randomly initialized' % name)
178 | mark = False
179 | if mark:
180 | print('All parameters are initialized using [%s]' % load_path)
181 |
182 | self.start_epoch = epoch
183 |
184 | def load_network_path(self, net, path):
185 | if isinstance(net, torch.nn.DataParallel):
186 | net = net.module
187 | state_dict = torch.load(path, map_location=self.device)
188 | print('loading the model from %s' % (path))
189 | if hasattr(state_dict, '_metadata'):
190 | del state_dict._metadata
191 |
192 | net_state = net.state_dict()
193 | is_loaded = {n:False for n in net_state.keys()}
194 | for name, param in state_dict['state_dict'].items():
195 | if name in net_state:
196 | try:
197 | net_state[name].copy_(param)
198 | is_loaded[name] = True
199 | except Exception:
200 | print('While copying the parameter named [%s], '
201 | 'whose dimensions in the model are %s and '
202 | 'whose dimensions in the checkpoint are %s.'
203 | % (name, list(net_state[name].shape),
204 | list(param.shape)))
205 | raise RuntimeError
206 | else:
207 | print('Saved parameter named [%s] is skipped' % name)
208 | mark = True
209 | for name in is_loaded:
210 | if not is_loaded[name]:
211 | print('Parameter named [%s] is randomly initialized' % name)
212 | mark = False
213 | if mark:
214 | print('All parameters are initialized using [%s]' % path)
215 |
216 | def save_optimizers(self, epoch):
217 | assert len(self.optimizers) == len(self.optimizer_names)
218 | for id, optimizer in enumerate(self.optimizers):
219 | save_filename = self.optimizer_names[id]
220 | state = {'name': save_filename,
221 | 'epoch': epoch,
222 | 'state_dict': optimizer.state_dict()}
223 | save_path = os.path.join(self.save_dir, save_filename+'.pth')
224 | torch_save(state, save_path)
225 |
226 | def load_optimizers(self, epoch):
227 | assert len(self.optimizers) == len(self.optimizer_names)
228 | for id, optimizer in enumerate(self.optimizer_names):
229 | load_filename = self.optimizer_names[id]
230 | load_path = os.path.join(self.save_dir, load_filename+'.pth')
231 | print('loading the optimizer from %s' % load_path)
232 | state_dict = torch.load(load_path)
233 | assert optimizer == state_dict['name']
234 | assert epoch == state_dict['epoch']
235 | self.optimizers[id].load_state_dict(state_dict['state_dict'])
236 |
237 | def print_networks(self, verbose):
238 | print('---------- Networks initialized -------------')
239 |
240 | print('-----------------------------------------------')
241 |
242 | def estimate(self, tenFirst, tenSecond, net):
243 | assert(tenFirst.shape[3] == tenSecond.shape[3])
244 | assert(tenFirst.shape[2] == tenSecond.shape[2])
245 | intWidth = tenFirst.shape[3]
246 | intHeight = tenFirst.shape[2]
247 | # tenPreprocessedFirst = tenFirst.view(1, 3, intHeight, intWidth)
248 | # tenPreprocessedSecond = tenSecond.view(1, 3, intHeight, intWidth)
249 |
250 | intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0))
251 | intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0))
252 |
253 | tenPreprocessedFirst = F.interpolate(input=tenFirst,
254 | size=(intPreprocessedHeight, intPreprocessedWidth),
255 | mode='bilinear', align_corners=False)
256 | tenPreprocessedSecond = F.interpolate(input=tenSecond,
257 | size=(intPreprocessedHeight, intPreprocessedWidth),
258 | mode='bilinear', align_corners=False)
259 |
260 | tenFlow = 20.0 * F.interpolate(
261 | input=net(tenPreprocessedFirst, tenPreprocessedSecond),
262 | size=(intHeight, intWidth), mode='bilinear', align_corners=False)
263 |
264 | tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth)
265 | tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight)
266 |
267 | return tenFlow[:, :, :, :]
268 |
269 | def backwarp(self, tenInput, tenFlow):
270 | index = str(tenFlow.shape) + str(tenInput.device)
271 | if index not in self.backwarp_tenGrid:
272 | tenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]),
273 | tenFlow.shape[3]).view(1, 1, 1, -1).expand(-1, -1, tenFlow.shape[2], -1)
274 | tenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]),
275 | tenFlow.shape[2]).view(1, 1, -1, 1).expand(-1, -1, -1, tenFlow.shape[3])
276 | self.backwarp_tenGrid[index] = torch.cat([tenHor, tenVer], 1).to(tenInput.device)
277 |
278 | if index not in self.backwarp_tenPartial:
279 | self.backwarp_tenPartial[index] = tenFlow.new_ones([
280 | tenFlow.shape[0], 1, tenFlow.shape[2], tenFlow.shape[3]])
281 |
282 | tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
283 | tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
284 | tenInput = torch.cat([tenInput, self.backwarp_tenPartial[index]], 1)
285 |
286 | tenOutput = F.grid_sample(input=tenInput,
287 | grid=(self.backwarp_tenGrid[index] + tenFlow).permute(0, 2, 3, 1),
288 | mode='bilinear', padding_mode='zeros', align_corners=False)
289 |
290 | return tenOutput
291 |
292 | def get_backwarp(self, tenFirst, tenSecond, net, flow=None):
293 | if flow is None:
294 | flow = self.get_flow(tenFirst, tenSecond, net)
295 |
296 | tenoutput = self.backwarp(tenSecond, flow)
297 | tenMask = tenoutput[:, -1:, :, :]
298 | tenMask[tenMask > 0.999] = 1.0
299 | tenMask[tenMask < 1.0] = 0.0
300 | return tenoutput[:, :-1, :, :] * tenMask, tenMask
301 |
302 | def get_flow(self, tenFirst, tenSecond, net):
303 | with torch.no_grad():
304 | net.eval()
305 | flow = self.estimate(tenFirst, tenSecond, net)
306 | return flow
--------------------------------------------------------------------------------
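`estimate` resizes both frames so the flow network sees spatial sizes that are multiples of 64, then rescales the predicted flow back to the original resolution. The padding arithmetic in isolation:

```python
import math

intWidth, intHeight = 1000, 750
intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0))    # 1024
intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0))  # 768

# the flow is computed at the padded size, so its x/y components are scaled
# by the width/height ratios when resized back to (750, 1000)
print(intPreprocessedWidth, intPreprocessedHeight)
```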
/models/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | class invertedBlock(nn.Module):
5 | def __init__(self, in_channel, out_channel,ratio=2):
6 | super(invertedBlock, self).__init__()
7 | internal_channel = in_channel * ratio
8 | self.relu = nn.GELU()
9 | ## 7x7 depthwise convolution; the parallel 3x3 branch is commented out in hifi()
10 | self.conv1 = nn.Conv2d(internal_channel, internal_channel, 7, 1, 3, groups=in_channel,bias=False)
11 |
12 | self.convFFN = ConvFFN(in_channels=in_channel, out_channels=in_channel)
13 | self.layer_norm = nn.LayerNorm(in_channel)
14 | self.pw1 = nn.Conv2d(in_channels=in_channel, out_channels=internal_channel, kernel_size=1, stride=1,
15 | padding=0, groups=1,bias=False)
16 | self.pw2 = nn.Conv2d(in_channels=internal_channel, out_channels=in_channel, kernel_size=1, stride=1,
17 | padding=0, groups=1,bias=False)
18 |
19 |
20 | def hifi(self,x):
21 |
22 | x1=self.pw1(x)
23 | x1=self.relu(x1)
24 | x1=self.conv1(x1)
25 | x1=self.relu(x1)
26 | x1=self.pw2(x1)
27 | x1=self.relu(x1)
28 | # x2 = self.conv2(x)
29 | x3 = x1+x
30 |
31 | x3 = x3.permute(0, 2, 3, 1).contiguous()
32 | x3 = self.layer_norm(x3)
33 | x3 = x3.permute(0, 3, 1, 2).contiguous()
34 | x4 = self.convFFN(x3)
35 |
36 | return x4
37 |
38 | def forward(self, x):
39 | return self.hifi(x)+x
40 | class ConvFFN(nn.Module):
41 |
42 | def __init__(self, in_channels, out_channels, expand_ratio=4):
43 | super().__init__()
44 |
45 | internal_channels = in_channels * expand_ratio
46 | self.pw1 = nn.Conv2d(in_channels=in_channels, out_channels=internal_channels, kernel_size=1, stride=1,
47 | padding=0, groups=1,bias=False)
48 | self.pw2 = nn.Conv2d(in_channels=internal_channels, out_channels=out_channels, kernel_size=1, stride=1,
49 | padding=0, groups=1,bias=False)
50 | self.nonlinear = nn.GELU()
51 |
52 | def forward(self, x):
53 | x1 = self.pw1(x)
54 | x2 = self.nonlinear(x1)
55 | x3 = self.pw2(x2)
56 | x4 = self.nonlinear(x3)
57 | return x4 + x
58 |
59 | class mixblock(nn.Module):
60 | def __init__(self, n_feats):
61 | super(mixblock, self).__init__()
62 | self.conv1=nn.Sequential(nn.Conv2d(n_feats,n_feats,3,1,1,bias=False),nn.GELU())
63 | self.conv2=nn.Sequential(nn.Conv2d(n_feats,n_feats,3,1,1,bias=False),nn.GELU(),nn.Conv2d(n_feats,n_feats,3,1,1,bias=False),nn.GELU(),nn.Conv2d(n_feats,n_feats,3,1,1,bias=False),nn.GELU())
64 | self.alpha=nn.Parameter(torch.ones(1))
65 | self.beta=nn.Parameter(torch.ones(1))
66 | def forward(self,x):
67 | return self.alpha*self.conv1(x)+self.beta*self.conv2(x)
68 | class CALayer(nn.Module):
69 | def __init__(self, channel, reduction=16):
70 | super(CALayer, self).__init__()
71 | # global average pooling: feature --> point
72 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
73 | # feature channel downscale and upscale --> channel weight
74 | self.conv_du = nn.Sequential(
75 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
76 | nn.ReLU(inplace=True),
77 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
78 | nn.Sigmoid()
79 | )
80 |
81 | def forward(self, x):
82 | y = self.avg_pool(x)
83 | y = self.conv_du(y)
84 | return x * y
85 | class Downupblock(nn.Module):
86 | def __init__(self, n_feats):
87 | super(Downupblock, self).__init__()
88 | self.encoder = mixblock(n_feats)
89 | self.decoder_high = mixblock(n_feats) # nn.Sequential(one_module(n_feats),
90 |
91 | self.decoder_low = nn.Sequential(mixblock(n_feats), mixblock(n_feats), mixblock(n_feats))
92 | self.alise = nn.Conv2d(n_feats,n_feats,1,1,0,bias=False) # one_module(n_feats)
93 | self.alise2 = nn.Conv2d(n_feats*2,n_feats,3,1,1,bias=False) # one_module(n_feats)
94 | self.down = nn.AvgPool2d(kernel_size=2)
95 | self.att = CALayer(n_feats)
96 | self.raw_alpha=nn.Parameter(torch.ones(1))
97 |
98 | self.raw_alpha.data.fill_(0)
99 | self.ega=selfAttention(n_feats, n_feats)
100 |
101 | def forward(self, x,raw):
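# Band-split forward pass: the avg-pooled branch carries low frequencies,
# while `high` is a Laplacian-like residual (input minus the upsampled low
# band). Self-attention on the high band is gated by raw_alpha, which is
# zero-initialized so the attention path fades in during training; the two
# bands are then fused with channel attention. Note `raw` is unused here.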
102 | x1 = self.encoder(x)
103 | x2 = self.down(x1)
104 | high = x1 - F.interpolate(x2, size=x.size()[-2:], mode='bilinear', align_corners=True)
105 |
106 | high=high+self.ega(high,high)*self.raw_alpha
107 | x2=self.decoder_low(x2)
108 | x3 = x2
109 | # x3 = self.decoder_low(x2)
110 | high1 = self.decoder_high(high)
111 | x4 = F.interpolate(x3, size=x.size()[-2:], mode='bilinear', align_corners=True)
112 | return self.alise(self.att(self.alise2(torch.cat([x4, high1], dim=1)))) + x
113 | class Updownblock(nn.Module):
114 | def __init__(self, n_feats):
115 | super(Updownblock, self).__init__()
116 | self.encoder = mixblock(n_feats)
117 | self.decoder_high = mixblock(n_feats) # nn.Sequential(one_module(n_feats),
118 | # one_module(n_feats),
119 | # one_module(n_feats))
120 | self.decoder_low = nn.Sequential(mixblock(n_feats), mixblock(n_feats), mixblock(n_feats))
121 |
122 | self.alise = nn.Conv2d(n_feats,n_feats,1,1,0,bias=False) # one_module(n_feats)
123 | self.alise2 = nn.Conv2d(n_feats*2,n_feats,3,1,1,bias=False) # one_module(n_feats)
124 | self.down = nn.AvgPool2d(kernel_size=2)
125 | self.att = CALayer(n_feats)
126 | self.raw_alpha=nn.Parameter(torch.ones(1))
127 | # fill 0
128 | self.raw_alpha.data.fill_(0)
129 | self.ega=selfAttention(n_feats, n_feats)
130 |
131 | def forward(self, x,raw):
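# Same band-split design as Downupblock.forward above; `raw` is likewise unused.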
132 | x1 = self.encoder(x)
133 | x2 = self.down(x1)
134 | high = x1 - F.interpolate(x2, size=x.size()[-2:], mode='bilinear', align_corners=True)
135 | high=high+self.ega(high,high)*self.raw_alpha
136 | x2=self.decoder_low(x2)
137 | x3 = x2
138 | high1 = self.decoder_high(high)
139 | x4 = F.interpolate(x3, size=x.size()[-2:], mode='bilinear', align_corners=True)
140 | return self.alise(self.att(self.alise2(torch.cat([x4, high1], dim=1)))) + x
141 | class basic_block(nn.Module):
142 | ## two parallel branches: a channel branch and a spatial branch
143 | def __init__(self, in_channel, out_channel, depth,ratio=1):
144 | super(basic_block, self).__init__()
145 |
146 |
147 |
148 |
149 | # stack `depth` invertedBlock modules
150 |
151 | self.rep1 = nn.Sequential(*[invertedBlock(in_channel=in_channel, out_channel=in_channel,ratio=ratio) for i in range(depth)])
152 |
153 |
154 | self.relu=nn.GELU()
155 | # one path applies three 3x3 convolutions, the other applies a single one
156 |
157 | self.updown=Updownblock(in_channel)
158 | self.downup=Downupblock(in_channel)
159 | def forward(self, x,raw=None):
160 |
161 |
162 | x1 = self.rep1(x)
163 |
164 |
165 | x1=self.updown(x1,raw)
166 | x1=self.downup(x1,raw)
167 | return x1+x
168 |
169 | import torchvision
170 | class VGG_aware(nn.Module):
171 | def __init__(self,outFeature):
172 | super(VGG_aware, self).__init__()
173 | blocks = []
174 | blocks.append(torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1).features[:4].eval())
175 |
176 | for bl in blocks:
177 | for p in bl.parameters():
178 | p.requires_grad = False
179 | self.blocks = torch.nn.ModuleList(blocks)
180 |
181 |
182 | def forward(self, x):
183 | return self.blocks[0](x)
184 |
185 | import torch.nn.functional as f
186 | class selfAttention(nn.Module):
187 | def __init__(self, in_channels, out_channels):
188 | super(selfAttention, self).__init__()
189 | self.query_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
190 | self.key_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
191 | self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
192 | self.scale = 1.0 / (out_channels ** 0.5)
193 |
194 | def forward(self, feature, feature_map):
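# Note: inputs keep their (N, C, H, W) layout, so the matmul below computes
# attention across image rows independently for each channel (the score
# tensor has shape (N, C, H, H)) rather than full spatial self-attention.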
195 | query = self.query_conv(feature)
196 | key = self.key_conv(feature)
197 | value = self.value_conv(feature)
198 | attention_scores = torch.matmul(query, key.transpose(-2, -1))
199 | attention_scores = attention_scores * self.scale
200 |
201 | attention_weights = f.softmax(attention_scores, dim=-1)
202 |
203 | attended_values = torch.matmul(attention_weights, value)
204 |
205 | output_feature_map = (feature_map + attended_values)
206 |
207 | return output_feature_map
--------------------------------------------------------------------------------
/models/cat_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from .base_model import BaseModel
3 | from . import networks as N
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from . import losses as L
7 | import torch.nn.functional as F
8 | import torchvision.ops as ops
9 | from util.util import mu_tonemap
10 |
11 |
12 | # For BracketIRE Task
13 | class CatModel(BaseModel):
14 | @staticmethod
15 | def modify_commandline_options(parser, is_train=True):
16 | return parser
17 |
18 | def __init__(self, opt):
19 | super(CatModel, self).__init__(opt)
20 |
21 | self.opt = opt
22 | self.loss_names = ['TMRNet_l1', 'Total']
23 | self.visual_names = ['data_gt', 'data_in', 'data_out']
24 | self.model_names = ['TMRNet']
25 | self.optimizer_names = ['TMRNet_optimizer_%s' % opt.optimizer]
26 |
27 | tmrnet = TMRNet(opt)
28 | # tmrnet = AHDR(8, 6, 64, 32)
29 | self.netTMRNet = N.init_net(tmrnet, opt.init_type, opt.init_gain, opt.gpu_ids)
30 |
31 | if self.isTrain:
32 | self.optimizer_TMRNet = optim.AdamW(self.netTMRNet.parameters(),
33 | lr=opt.lr,
34 | betas=(opt.beta1, opt.beta2),
35 | weight_decay=opt.weight_decay)
36 | self.optimizers = [self.optimizer_TMRNet]
37 |
38 | self.criterionL1 = N.init_net(L.L1Loss(), gpu_ids=opt.gpu_ids)
39 |
40 | def set_input(self, input):
41 | self.data_gt = input['gt'].to(self.device)
42 | self.data_raws = input['raws'].to(self.device)
43 | self.image_paths = input['fname']
44 |
45 | expo = torch.stack([torch.pow(torch.tensor(4, dtype=torch.float32, device=self.data_raws.device), 2-x)
46 | for x in range(0, self.data_raws.shape[1])])
47 | self.expo = expo[None,:,None,None,None]
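# Exposure normalization: frame x is weighted by 4^(2 - x), assuming the
# BracketIRE setting of 4x exposure-ratio steps. For the 5-frame case:
#   [4.0 ** (2 - x) for x in range(5)] -> [16.0, 4.0, 1.0, 0.25, 0.0625]
# so all frames are brought to a common radiometric scale before fusion.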
48 |
49 | def forward(self):
50 |
51 |
52 | self.data_raws = self.data_raws * self.expo
53 | self.data_in = self.data_raws[:,0,...].squeeze(1)
54 |
55 | if self.isTrain or not self.opt.chop:
56 | # data_raws: [N, T, C, H, W]; unbinding along dim 1 gives T tensors of shape [N, C, H, W]
57 | # x1 = self.data_raws[:, 0, :, :, :]
58 | # x2 = self.data_raws[:, 1, :, :, :]
59 | # x3 = self.data_raws[:, 2, :, :, :]
60 | # x4= self.data_raws[:, 3, :, :, :]
61 | # x5 = self.data_raws[:, 4, :, :, :]
62 | # x1=torch.cat((x1,ldr_to_hdr(x1,1,1/2.2)),1)
63 | # x2=torch.cat((x2,ldr_to_hdr(x2,1,1/2.2)),1)
64 | # x3=torch.cat((x3,ldr_to_hdr(x3,1,1/2.2)),1)
65 | # x4=torch.cat((x4,ldr_to_hdr(x4,1,1/2.2)),1)
66 | # x5=torch.cat((x5,ldr_to_hdr(x5,1,1/2.2)),1)
67 |
68 | # self.data_out = self.netTMRNet(x1, x2, x3, x4, x5)
69 | self.data_out = self.netTMRNet(self.data_raws)
70 | elif self.opt.chop:
71 | self.data_out = self.forward_chop(self.data_raws)
72 |
73 | def forward_chop(self, data_raws, chop_size=800):
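# Tiled inference for large inputs: reflect-pad to a multiple of chop_size,
# run the network tile by tile, then crop the padding away. Tiles are
# processed independently, trading possible seams at tile borders for
# bounded memory use.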
74 | n, t, c, h, w = data_raws.shape
75 |
76 | num_h = h // chop_size + 1
77 | num_w = w // chop_size + 1
78 | new_h = num_h * chop_size
79 | new_w = num_w * chop_size
80 |
81 | pad_h = new_h - h
82 | pad_w = new_w - w
83 |
84 | pad_top = int(pad_h / 2.)
85 | pad_bottom = pad_h - pad_top
86 | pad_left = int(pad_w / 2.)
87 | pad_right = pad_w - pad_left
88 |
89 | paddings = (pad_left, pad_right, pad_top, pad_bottom)
90 | new_input0 = torch.nn.ReflectionPad2d(paddings)(data_raws[0])
91 |
92 | out = torch.zeros([1, c, new_h, new_w], dtype=torch.float32, device=data_raws.device)
93 | for i in range(num_h):
94 | for j in range(num_w):
95 | out[:, :, i*chop_size:(i+1)*chop_size, j*chop_size:(j+1)*chop_size] = self.netTMRNet(
96 | new_input0.unsqueeze(0)[:,:,:,i*chop_size:(i+1)*chop_size, j*chop_size:(j+1)*chop_size])
97 | return out[:, :, pad_top:pad_top+h, pad_left:pad_left+w]
98 |
99 | def backward(self, epoch):
100 | self.loss_TMRNet_l1 = self.criterionL1(
101 | mu_tonemap(torch.clamp(self.data_out / 4**2, min=0)),
102 | mu_tonemap(torch.clamp(self.data_gt / 4**2, 0, 1))).mean()
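# The L1 loss is computed in the mu-tonemapped domain (mu_tonemap is defined
# in util.util; in HDR work this is typically log(1 + mu*x) / log(1 + mu)),
# which compresses highlights so dark regions still contribute to the loss.
# Dividing by 4**2 rescales prediction and GT back toward [0, 1] after the
# 16x exposure gain applied in set_input.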
103 |
104 | self.loss_Total = self.loss_TMRNet_l1
105 | self.loss_Total.backward()
106 |
107 | def optimize_parameters(self, epoch):
108 | self.forward()
109 | self.optimizer_TMRNet.zero_grad()
110 | self.backward(epoch)
111 | self.optimizer_TMRNet.step()
112 |
113 |
114 | class TMRNet(nn.Module):
115 | def __init__(self, opt, mid_channels=64, max_residue_magnitude=10):
116 |
117 | super().__init__()
118 | self.mid_channels = mid_channels
119 |
120 | # optical flow
121 | self.spynet = SPyNet()
122 | if opt.isTrain:
123 | N.load_spynet(self.spynet, './spynet/spynet_20210409-c6c1bd09.pth')
124 |
125 | self.dcn_alignment = DeformableAlignment(mid_channels, mid_channels, 3, padding=1, deform_groups=8,
126 | max_residue_magnitude=max_residue_magnitude)
127 |
128 | # feature extraction module
129 | self.feat_extract = ResidualBlocksWithInputConv(2 * 4, mid_channels, 5)
130 |
131 | # propagation branches
132 |
133 | self.backbone = nn.ModuleDict()
134 | if opt.block == 'tmrnet':
135 | self.backbone['backward'] = ResidualBlocksWithInputConv(5 * mid_channels, mid_channels, 16)
136 |
137 | for i in range(0, 5):
138 | self.backbone['backward_rec_%d' % (i + 1)] = ResidualBlocksWithInputConv(1 * mid_channels, mid_channels,
139 | 24)
140 |
141 |
142 | if opt.block == 'Convnext':
143 | self.backbone['backward'] = ConvnextBlocksWithInputConv(5 * mid_channels, mid_channels, 1)
144 |
145 | for i in range(0, 2):
146 | self.backbone['backward_rec_%d' % (i + 1)] = ConvnextBlocksWithInputConv(1 * mid_channels, mid_channels,
147 | 1)
148 | self.reconstruction = ResidualBlocksWithInputConv(2 * mid_channels, mid_channels, 5)
149 | self.down = ResidualBlocksWithInputConv(3 * mid_channels, mid_channels, 3)
150 | self.skipup1 = PixelShufflePack(4, mid_channels, 1, upsample_kernel=3)
151 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
152 |
153 | self.conv_hr = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1)
154 | self.conv_last = nn.Conv2d(mid_channels, 4, 3, 1, 1)
155 |
156 | def compute_flow(self, lqs):
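# Flow is estimated on a pseudo-RGB proxy of the packed raw: the channels
# are reduced to (R, mean(G1, G2), B) and gamma-compressed with 1/2.2 so the
# input roughly matches the sRGB statistics SPyNet was trained on. Flows run
# from the reference frame (index 0) to every other frame, and the same
# tensor is reused for both "forward" and "backward" directions.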
157 | lqs = torch.stack((lqs[:, :, 0], lqs[:, :, 1:3].mean(dim=2), lqs[:, :, 3]), dim=2)
158 | lqs = torch.pow(torch.clamp(lqs, 0, 1), 1 / 2.2)
159 | n, t, c, h, w = lqs.size()
160 | oth = lqs[:, 1:, :, :, :].reshape(-1, c, h, w)
161 | ref = lqs[:, :1, :, :, :].repeat(1, t - 1, 1, 1, 1).reshape(-1, c, h, w)
162 | flows_backward = self.spynet(ref, oth).view(n, t - 1, 2, h, w)
163 | flows_forward = flows_backward
164 | return flows_forward, flows_backward
165 |
166 | def upsample(self, lqs, feats, base_feat):
167 | skip1 = self.skipup1(lqs[:, 0, :, :, :])
168 | hr = self.down(feats)
169 | hr = torch.cat([base_feat, hr], dim=1)
170 | hr = self.reconstruction(hr)
171 | hr = hr + skip1
172 | hr = self.lrelu(self.conv_hr(hr))
173 | hr = self.conv_last(hr)
174 | return hr
175 |
176 | def forward(self, lqs):
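# Pipeline sketch: (1) build an 8-channel input by interleaving the linear
# raw with a gamma-compressed copy, (2) extract per-frame features,
# (3) warp each non-reference frame toward frame 0 with SPyNet flow and
# refine alignment with flow-guided deformable convolution, (4) concatenate
# the aligned feature maps (the explicit cat below assumes t == 5) and run
# the stacked backbone blocks, (5) reconstruct with a skip connection from
# the raw reference frame.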
177 | n, t, c, h, w = lqs.size() # (n, t, c, h, w)
178 | lqs_downsample = lqs.clone()
179 |
180 | feats = {}
181 | lqs_view = lqs.view(-1, c, h, w)
182 | lqs_in = torch.zeros([n * t, 2 * c, h, w], dtype=lqs_view.dtype, device=lqs_view.device)
183 | lqs_in[:, 0::2, :, :] = lqs_view
184 | # clamp to non-negative before the fractional power to avoid NaNs
185 | lqs_in[:, 1::2, :, :] = torch.pow(torch.clamp(lqs_view, min=0), 1 / 2.2)
186 |
187 | feats_ = self.feat_extract(lqs_in) # (N*T, C, H, W)
188 | h, w = feats_.shape[2:]
189 | feats_ = feats_.view(n, t, -1, h, w)
190 |
191 | _, flows_backward = self.compute_flow(lqs_downsample)
192 | flows_backward = flows_backward.view(-1, 2, *feats_.shape[-2:])
193 |
194 | ref_feat = feats_[:, :1, :, :, :].repeat(1, t - 1, 1, 1, 1).view(-1, *feats_.shape[-3:])
195 | oth_feat = feats_[:, 1:, :, :, :].contiguous().view(-1, *feats_.shape[-3:])
196 |
197 | oth_feat_warped = N.flow_warp(oth_feat, flows_backward.permute(0, 2, 3, 1))
198 | oth_feat = self.dcn_alignment(oth_feat, ref_feat, oth_feat_warped, flows_backward)
199 | oth_feat = oth_feat.view(n, t - 1, -1, h, w)
200 | ref_feat = ref_feat.view(n, t - 1, -1, h, w)[:, :1, :, :, :]
201 |
202 | feats_ = torch.cat((ref_feat, oth_feat), dim=1) # (N, T, C, H, W)
203 |
204 | base_feat = feats_[:, 0, :, :, :]
205 | feats_ = torch.cat((feats_[:, 0, :, :, :], feats_[:, 1, :, :, :], feats_[:, 2, :, :, :], feats_[:, 3, :, :, :],
206 | feats_[:, 4, :, :, :]), 1)
207 | # feature propagation
208 | module = 'backward'
209 | feats0 = self.backbone[module](feats_)
210 | feats1 = self.backbone['backward_rec_%d' % (0 + 1)](feats0)
211 | feats2 = self.backbone['backward_rec_%d' % (1 + 1)](feats1)
212 |
213 | feats0 = torch.cat((feats0, feats1, feats2), 1)
214 |
215 | out = self.upsample(lqs, feats0, base_feat)
216 |
217 | return out
218 |
219 |
220 |
221 |
222 | from .blocks import *
223 |
224 |
225 | class ConvnextBlocksWithInputConv(nn.Module):
226 | def __init__(self, in_channels, out_channels=64, num_blocks=3):
227 | super().__init__()
228 |
229 | main = []
230 | # a convolution used to match the channels of the residual blocks
231 | main.append(nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=True))
232 | main.append(nn.LeakyReLU(negative_slope=0.1, inplace=True))
233 |
234 | # residual blocks
235 | for i in range(num_blocks):
236 | main.append(
237 | basic_block(out_channels, out_channels, depth=10, ratio=1))
238 |
239 | self.main = nn.Sequential(*main)
240 |
241 | def forward(self, feat):
242 | return self.main(feat)
243 |
244 |
245 | class ResidualBlocksWithInputConv(nn.Module):
246 | def __init__(self, in_channels, out_channels=64, num_blocks=30):
247 | super().__init__()
248 |
249 | main = []
250 | # a convolution used to match the channels of the residual blocks
251 | main.append(nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=True))
252 | main.append(nn.LeakyReLU(negative_slope=0.1, inplace=True))
253 |
254 | # residual blocks
255 | main.append(
256 | N.make_layer(
257 | ResidualBlockNoBN, num_blocks, mid_channels=out_channels))
258 |
259 | self.main = nn.Sequential(*main)
260 |
261 | def forward(self, feat):
262 | return self.main(feat)
263 |
264 |
265 | class ResidualBlockNoBN(nn.Module):
266 | def __init__(self, mid_channels=64, res_scale=1):
267 | super().__init__()
268 | self.res_scale = res_scale
269 | self.conv1 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=True)
270 | self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, 1, 1, bias=True)
271 |
272 | self.relu = nn.ReLU(inplace=True)
273 |
274 | self.init_weights()
275 |
276 | def init_weights(self):
277 | N.init_weights(self.conv1, init_type='kaiming')
278 | N.init_weights(self.conv2, init_type='kaiming')
279 | self.conv1.weight.data *= 0.1
280 | self.conv2.weight.data *= 0.1
281 |
282 | def forward(self, x):
283 | identity = x
284 | out = self.conv2(self.relu(self.conv1(x)))
285 | return identity + out * self.res_scale
286 |
287 |
288 | class PixelShufflePack(nn.Module):
289 | def __init__(self, in_channels, out_channels, scale_factor,
290 | upsample_kernel):
291 | super().__init__()
292 | self.in_channels = in_channels
293 | self.out_channels = out_channels
294 | self.scale_factor = scale_factor
295 | self.upsample_kernel = upsample_kernel
296 | self.upsample_conv = nn.Conv2d(
297 | self.in_channels,
298 | self.out_channels * scale_factor * scale_factor,
299 | self.upsample_kernel,
300 | padding=(self.upsample_kernel - 1) // 2)
301 | self.init_weights()
302 |
303 | def init_weights(self):
304 | """Initialize weights for PixelShufflePack."""
305 | N.init_weights(self.upsample_conv, init_type='kaiming')
306 |
307 | def forward(self, x):
308 | x = self.upsample_conv(x)
309 | if self.scale_factor > 1:
310 | x = F.pixel_shuffle(x, self.scale_factor)
311 | return x
312 |
313 |
314 | class DeformableAlignment(nn.Module):
315 | def __init__(self, in_channels, out_channels, kernel=3, padding=1, deform_groups=8, max_residue_magnitude=10):
316 | super().__init__()
317 |
318 | self.max_residue_magnitude = max_residue_magnitude
319 |
320 | self.conv_offset = nn.Sequential(
321 | nn.Conv2d(2 * out_channels + 2, out_channels, 3, 1, 1),
322 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
323 | nn.Conv2d(out_channels, out_channels, 3, 1, 1),
324 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
325 | nn.Conv2d(out_channels, out_channels, 3, 1, 1),
326 | nn.LeakyReLU(negative_slope=0.1, inplace=True),
327 | nn.Conv2d(out_channels, 27 * deform_groups, 3, 1, 1),
328 | )
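# The last conv emits 27 * deform_groups channels: per group, 2*3*3 offset
# values (x and y for each tap of the 3x3 kernel) plus a 3*3 modulation
# mask, which forward() splits apart with torch.chunk.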
329 |
330 | self.deform_conv = ops.DeformConv2d(in_channels, out_channels, kernel_size=kernel, stride=1,
331 | padding=padding, dilation=1, bias=True, groups=deform_groups)
332 |
333 | self.init_offset()
334 |
335 | def init_offset(self):
336 | N.init_weights(self.conv_offset[-1], init_type='constant')
337 |
338 | def forward(self, cur_feat, ref_feat, warped_feat, flow):
339 | extra_feat = torch.cat([warped_feat, ref_feat, flow], dim=1)
340 | out = self.conv_offset(extra_feat)
341 | o1, o2, mask = torch.chunk(out, 3, dim=1)
342 | offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1))
343 | offset = offset + flow.flip(1).repeat(1, offset.size(1) // 2, 1, 1)
344 | mask = torch.sigmoid(mask)
345 | return self.deform_conv(cur_feat, offset, mask=mask)
346 |
347 |
348 | class SPyNet(nn.Module):
349 | """SPyNet network structure.
350 |
351 | The difference to the SPyNet in [tof.py] is that
352 | 1. more SPyNetBasicModule is used in this version, and
353 | 2. no batch normalization is used in this version.
354 |
355 | Paper:
356 | Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
357 |
358 | Note:
359 | Pretrained weights are loaded externally via N.load_spynet.
360 | """
361 |
362 | def __init__(self):
363 | super().__init__()
364 |
365 | self.basic_module = nn.ModuleList(
366 | [SPyNetBasicModule() for _ in range(6)])
367 |
368 | self.register_buffer(
369 | 'mean',
370 | torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
371 | self.register_buffer(
372 | 'std',
373 | torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
374 |
375 | def compute_flow(self, ref, supp):
376 | """Compute flow from ref to supp.
377 |
378 | Note that in this function, the images are already resized to a
379 | multiple of 32.
380 |
381 | Args:
382 | ref (Tensor): Reference image with shape of (n, 3, h, w).
383 | supp (Tensor): Supporting image with shape of (n, 3, h, w).
384 |
385 | Returns:
386 | Tensor: Estimated optical flow: (n, 2, h, w).
387 | """
388 | n, _, h, w = ref.size()
389 |
390 | # normalize the input images
391 | ref = [(ref - self.mean) / self.std]
392 | supp = [(supp - self.mean) / self.std]
393 |
394 | # generate downsampled frames
395 | for level in range(5):
396 | ref.append(
397 | F.avg_pool2d(
398 | input=ref[-1],
399 | kernel_size=2,
400 | stride=2,
401 | count_include_pad=False))
402 | supp.append(
403 | F.avg_pool2d(
404 | input=supp[-1],
405 | kernel_size=2,
406 | stride=2,
407 | count_include_pad=False))
408 | ref = ref[::-1]
409 | supp = supp[::-1]
410 |
411 | # flow computation
412 | flow = ref[0].new_zeros(n, 2, h // 32, w // 32)
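# Coarse-to-fine refinement: start from zero flow at 1/32 resolution and,
# at each pyramid level, predict a residual on top of the upsampled flow
# (scaled by 2 because the spatial resolution doubles).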
413 | for level in range(len(ref)):
414 | if level == 0:
415 | flow_up = flow
416 | else:
417 | flow_up = F.interpolate(
418 | input=flow,
419 | scale_factor=2,
420 | mode='bilinear',
421 | align_corners=True) * 2.0
422 |
423 | # add the residue to the upsampled flow
424 | flow = flow_up + self.basic_module[level](
425 | torch.cat([
426 | ref[level],
427 | N.flow_warp(
428 | supp[level],
429 | flow_up.permute(0, 2, 3, 1),
430 | padding_mode='border'), flow_up
431 | ], 1))
432 |
433 | return flow
434 |
435 | def forward(self, ref, supp):
436 | """Forward function of SPyNet.
437 |
438 | This function computes the optical flow from ref to supp.
439 |
440 | Args:
441 | ref (Tensor): Reference image with shape of (n, 3, h, w).
442 | supp (Tensor): Supporting image with shape of (n, 3, h, w).
443 |
444 | Returns:
445 | Tensor: Estimated optical flow: (n, 2, h, w).
446 | """
447 |
448 | # upsize to a multiple of 32
449 | h, w = ref.shape[2:4]
450 | w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1)
451 | h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1)
452 | ref = F.interpolate(
453 | input=ref, size=(h_up, w_up), mode='bilinear', align_corners=False)
454 | supp = F.interpolate(
455 | input=supp,
456 | size=(h_up, w_up),
457 | mode='bilinear',
458 | align_corners=False)
459 |
460 | # compute flow, and resize back to the original resolution
461 | flow = F.interpolate(
462 | input=self.compute_flow(ref, supp),
463 | size=(h, w),
464 | mode='bilinear',
465 | align_corners=False)
466 |
467 | # adjust the flow values
468 | flow[:, 0, :, :] *= float(w) / float(w_up)
469 | flow[:, 1, :, :] *= float(h) / float(h_up)
470 |
471 | return flow
472 |
473 |
474 | class SPyNetBasicModule(nn.Module):
475 | """Basic Module for SPyNet.
476 |
477 | Paper:
478 | Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
479 | """
480 |
481 | def __init__(self):
482 | super().__init__()
483 |
484 | self.basic_module = nn.Sequential(
485 | nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3),
486 | nn.ReLU(inplace=True),
487 |
488 | nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3),
489 | nn.ReLU(inplace=True),
490 |
491 | nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3),
492 | nn.ReLU(inplace=True),
493 |
494 | nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3),
495 | nn.ReLU(inplace=True),
496 |
497 | nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3)
498 | )
499 |
500 | def forward(self, tensor_input):
501 | """
502 | Args:
503 | tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
504 | 8 channels contain:
505 | [reference image (3), neighbor image (3), initial flow (2)].
506 |
507 | Returns:
508 | Tensor: Refined flow with shape (b, 2, h, w)
509 | """
510 | return self.basic_module(tensor_input)
511 |
512 |
--------------------------------------------------------------------------------
/models/degrade/degrade_kernel.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import random
4 | import torch
5 | from scipy import ndimage
6 | from scipy.interpolate import interp2d
7 | from .unprocess import unprocess, random_noise_levels, add_noise
8 | from .process import process
9 | from PIL import Image
10 |
11 |
12 |
13 | # def get_rgb2raw2rgb(img):
14 | # img = torch.from_numpy(np.array(img)) / 255.0
15 | # deg_img, features = unprocess(img)
16 | # shot_noise, read_noise = random_noise_levels()
17 | # deg_img = add_noise(deg_img, shot_noise, read_noise)
18 | # deg_img = deg_img.unsqueeze(0)
19 | # features['red_gain'] = features['red_gain'].unsqueeze(0)
20 | # features['blue_gain'] = features['blue_gain'].unsqueeze(0)
21 | # features['cam2rgb'] = features['cam2rgb'].unsqueeze(0)
22 | # deg_img = process(deg_img, features['red_gain'], features['blue_gain'], features['cam2rgb'])
23 | # deg_img = deg_img.squeeze(0)
24 | # deg_img = torch.clamp(deg_img * 255.0, 0.0, 255.0).numpy()
25 | # deg_img = deg_img.astype(np.uint8)
26 | # return Image.fromarray(deg_img)
27 |
28 |
29 | # def get_rgb2raw_noise(img, noise_level, features=None):
30 | # # img = np.transpose(img, (1, 2, 0))
31 | # img = torch.from_numpy(np.array(img)) / 255.0
32 |
33 | # deg_img, features = unprocess(img, features)
34 | # shot_noise, read_noise = random_noise_levels(noise_level)
35 | # deg_img_noise = add_noise(deg_img, shot_noise, read_noise)
36 | # # deg_img_noise = torch.clamp(deg_img_noise, min=0.0, max=1.0)
37 |
38 | # # deg_img = np.transpose(deg_img, (2, 0, 1))
39 | # # deg_img_noise = np.transpose(deg_img_noise, (2, 0, 1))
40 | # return deg_img_noise, features
41 |
42 |
43 | def get_rgb2raw(img, features=None):
44 | # img = np.transpose(img, (1, 2, 0))
45 | device = img.device
46 | deg_img, features = unprocess(img, features, device)
47 | return deg_img, features
48 |
49 |
50 | def get_raw2rgb(img, features, demosaic='net', lineRGB=False):
51 | # img = np.transpose(img, (1, 2, 0))
52 | # img = torch.from_numpy(np.array(img))
53 | img = img.unsqueeze(0)
54 | device = img.device
55 | deg_img = process(img, features['red_gain'].to(device), features['blue_gain'].to(device),
56 | features['cam2rgb'].to(device), demosaic, lineRGB)
57 | deg_img = deg_img.squeeze(0)
58 | # deg_img = torch.clamp(deg_img * 255.0, 0.0, 255.0).numpy()
59 | # deg_img = deg_img.astype(np.uint8)
60 | return deg_img
61 |
62 | def get_raw2rgb2(img, features, demosaic='net', lineRGB=False):
63 | # img = np.transpose(img, (1, 2, 0))
64 | # img = torch.from_numpy(np.array(img))
65 | deg_images = []
66 | # img = img.unsqueeze(0)
67 | device = img.device
68 | # `process` expects a single image together with its own white-balance
69 | # gains and CCM, so iterate over the batch dimension and run each
70 | # frame through the pipeline separately
71 | for i in range(img.shape[0]):
72 | im = img[i].unsqueeze(0)
73 | deg_img = process(im, features['red_gain'][i].unsqueeze(0).to(device), features['blue_gain'][i].unsqueeze(0).to(device),
74 | features['cam2rgb'][i].unsqueeze(0).to(device), demosaic, lineRGB)
75 | deg_img = deg_img.squeeze(0)
76 | deg_images.append(deg_img)
77 | result=torch.stack(deg_images)
78 | # deg_img = deg_img.squeeze(0)
79 | # deg_img = torch.clamp(deg_img * 255.0, 0.0, 255.0).numpy()
80 | # deg_img = deg_img.astype(np.uint8)
81 | return result
82 | # def pack_raw_image(im_raw): # HxW
83 | # """ Packs a single channel bayer image into 4 channel tensor, where channels contain R, G, G, and B values"""
84 | # if isinstance(im_raw, np.ndarray):
85 | # im_out = np.zeros_like(im_raw, shape=(4, im_raw.shape[0] // 2, im_raw.shape[1] // 2))
86 | # elif isinstance(im_raw, torch.Tensor):
87 | # im_out = torch.zeros((4, im_raw.shape[0] // 2, im_raw.shape[1] // 2), dtype=im_raw.dtype)
88 | # else:
89 | # raise Exception
90 |
91 | # im_out[0, :, :] = im_raw[0::2, 0::2]
92 | # im_out[1, :, :] = im_raw[0::2, 1::2]
93 | # im_out[2, :, :] = im_raw[1::2, 0::2]
94 | # im_out[3, :, :] = im_raw[1::2, 1::2]
95 | # return im_out # 4xHxW
96 |
97 |
98 | # def flatten_raw_image(im_raw_4ch): # 4xHxW
99 | # """ unpack a 4-channel tensor into a single channel bayer image"""
100 | # if isinstance(im_raw_4ch, np.ndarray):
101 | # im_out = np.zeros_like(im_raw_4ch, shape=(im_raw_4ch.shape[1] * 2, im_raw_4ch.shape[2] * 2))
102 | # elif isinstance(im_raw_4ch, torch.Tensor):
103 | # im_out = torch.zeros((im_raw_4ch.shape[1] * 2, im_raw_4ch.shape[2] * 2), dtype=im_raw_4ch.dtype)
104 | # else:
105 | # raise Exception
106 |
107 | # im_out[0::2, 0::2] = im_raw_4ch[0, :, :]
108 | # im_out[0::2, 1::2] = im_raw_4ch[1, :, :]
109 | # im_out[1::2, 0::2] = im_raw_4ch[2, :, :]
110 | # im_out[1::2, 1::2] = im_raw_4ch[3, :, :]
111 |
112 | # return im_out # HxW
--------------------------------------------------------------------------------
/models/degrade/process.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Forward processing of raw data to sRGB images.
17 |
18 | Unprocessing Images for Learned Raw Denoising
19 | http://timothybrooks.com/tech/unprocessing
20 | """
21 |
22 | import numpy as np
23 | import torch
24 | import torch.nn as nn
25 | import torch.distributions as tdist
26 | from colour_demosaicing import demosaicing_CFA_Bayer_Menon2007
27 | import os
28 | from isp import demosaic_bayer
29 |
30 |
31 | def apply_gains(bayer_images, red_gains, blue_gains):
32 | """Applies white balance gains to a batch of Bayer images."""
33 | red_gains = red_gains.squeeze(1)
34 | blue_gains= blue_gains.squeeze(1)
35 | green_gains = torch.ones_like(red_gains)
36 | gains = torch.stack([red_gains, green_gains, green_gains, blue_gains], dim=-1)
37 | gains = gains[:, None, None, :]
38 | # print(bayer_images.shape, gains.shape)
39 | outs = bayer_images * gains
40 | return outs
41 |
42 |
43 | def demosaic(bayer_images):
44 | def SpaceToDepth_fact2(x):
45 | # From here - https://discuss.pytorch.org/t/is-there-any-layer-like-tensorflows-space-to-depth-function/3487/14
46 | bs = 2
47 | N, C, H, W = x.size()
48 | x = x.view(N, C, H // bs, bs, W // bs, bs) # (N, C, H//bs, bs, W//bs, bs)
49 | x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # (N, bs, bs, C, H//bs, W//bs)
50 | x = x.view(N, C * (bs ** 2), H // bs, W // bs) # (N, C*bs^2, H//bs, W//bs)
51 | return x
52 | def DepthToSpace_fact2(x):
53 | # From here - https://discuss.pytorch.org/t/is-there-any-layer-like-tensorflows-space-to-depth-function/3487/14
54 | bs = 2
55 | N, C, H, W = x.size()
56 | x = x.view(N, bs, bs, C // (bs ** 2), H, W) # (N, bs, bs, C//bs^2, H, W)
57 | x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # (N, C//bs^2, H, bs, W, bs)
58 | x = x.view(N, C // (bs ** 2), H * bs, W * bs) # (N, C//bs^2, H * bs, W * bs)
59 | return x
60 |
61 | """Bilinearly demosaics a batch of RGGB Bayer images."""
62 |
63 | shape = bayer_images.size()
64 | shape = [shape[1] * 2, shape[2] * 2]
65 |
66 | red = bayer_images[Ellipsis, 0:1]
67 | upsamplebyX = nn.Upsample(size=shape, mode='bilinear', align_corners=False)
68 | red = upsamplebyX(red.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
69 |
70 | green_red = bayer_images[Ellipsis, 1:2]
71 | green_red = torch.flip(green_red, dims=[1]) # Flip left-right
72 | green_red = upsamplebyX(green_red.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
73 | green_red = torch.flip(green_red, dims=[1]) # Flip left-right
74 | green_red = SpaceToDepth_fact2(green_red.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
75 |
76 | green_blue = bayer_images[Ellipsis, 2:3]
77 | green_blue = torch.flip(green_blue, dims=[0]) # Flip up-down
78 | green_blue = upsamplebyX(green_blue.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
79 | green_blue = torch.flip(green_blue, dims=[0]) # Flip up-down
80 | green_blue = SpaceToDepth_fact2(green_blue.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
81 |
82 | green_at_red = (green_red[Ellipsis, 0] + green_blue[Ellipsis, 0]) / 2
83 | green_at_green_red = green_red[Ellipsis, 1]
84 | green_at_green_blue = green_blue[Ellipsis, 2]
85 | green_at_blue = (green_red[Ellipsis, 3] + green_blue[Ellipsis, 3]) / 2
86 |
87 | green_planes = [
88 | green_at_red, green_at_green_red, green_at_green_blue, green_at_blue
89 | ]
90 | green = DepthToSpace_fact2(torch.stack(green_planes, dim=-1).permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
91 |
92 | blue = bayer_images[Ellipsis, 3:4]
93 | blue = torch.flip(torch.flip(blue, dims=[1]), dims=[0])
94 | blue = upsamplebyX(blue.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
95 | blue = torch.flip(torch.flip(blue, dims=[1]), dims=[0])
96 |
97 | rgb_images = torch.cat([red, green, blue], dim=-1)
98 | return rgb_images
99 |
100 |
101 | def apply_ccms(images, ccms):
102 | """Applies color correction matrices."""
103 | images = images[:, :, :, None, :]
104 | ccms = ccms[:, None, None, :, :]
105 | outs = torch.sum(images * ccms, dim=-1)
106 | return outs
107 |
108 |
109 | def gamma_compression(images, gamma=2.2):
110 | """Converts from linear to gamma space."""
111 | # Clamps to prevent numerical instability of gradients near zero.
112 | Mask = lambda x: (x>0.0031308).float()
113 | sRGBDeLinearize = lambda x,m: m * (1.055 * (m * x) ** (1/2.4) - 0.055) + (1-m) * (12.92 * x)
114 | return sRGBDeLinearize(images, Mask(images))
115 | # outs = torch.clamp(images, min=1e-8) ** (1.0 / gamma)
116 | # return outs
117 |
118 |
119 | def process(bayer_images, red_gains, blue_gains, cam2rgbs, demosaic_type, lineRGB):
120 | # print(bayer_images.shape, red_gains.shape, cam2rgbs.shape)
121 | """Processes a batch of Bayer RGGB images into sRGB images."""
122 | # White balance.
123 | bayer_images = apply_gains(bayer_images, red_gains, blue_gains)
124 | # Demosaic.
125 | bayer_images = torch.clamp(bayer_images, min=0.0, max=1.0)
126 |
127 | if demosaic_type == 'default':
128 | images = demosaic(bayer_images)
129 | elif demosaic_type == 'menon2007':
130 | # print(bayer_images.size())
131 | bayer_images = flatten_raw_image(bayer_images.squeeze(0))
132 | images = demosaicing_CFA_Bayer_Menon2007(bayer_images.cpu().numpy(), 'RGGB')
133 | images = torch.from_numpy(images).unsqueeze(0).to(red_gains.device)
134 | elif demosaic_type == 'net':
135 | bayer_images = flatten_raw_image(bayer_images.squeeze(0)).cpu().numpy()
136 | bayer = np.power(np.clip(bayer_images.astype(dtype=np.float32), 0, 1), 1 / 2.2)
137 | pretrained_model_path = "./isp/model.bin"
138 | demosaic_net = demosaic_bayer.get_demosaic_net_model(pretrained=pretrained_model_path, device=red_gains.device,
139 | cfa='bayer', state_dict=True)
140 | rgb = demosaic_bayer.demosaic_by_demosaic_net(bayer=bayer, cfa='RGGB',
141 | demosaic_net=demosaic_net, device=red_gains.device)
142 | images = np.power(np.clip(rgb, 0, 1), 2.2)
143 | images = torch.from_numpy(images).unsqueeze(0).to(red_gains.device)
144 | elif demosaic_type == 'net2':
145 | bayer_images = flatten_raw_image(bayer_images.squeeze(0))
146 | # bayer = np.power(np.clip(bayer_images.astype(dtype=np.float32), 0, 1), 1 / 2.2)
147 | # torch-tensor rewrite of the numpy 'net' branch above, keeping data on device
148 | bayer= torch.pow(torch.clamp(bayer_images, 0, 1), 1 / 2.2)
149 | pretrained_model_path = "./isp/model.bin"
150 | demosaic_net = demosaic_bayer.get_demosaic_net_model(pretrained=pretrained_model_path, device=red_gains.device,
151 | cfa='bayer', state_dict=True)
152 | rgb = demosaic_bayer.demosaic_by_demosaic_net(bayer=bayer, cfa='RGGB',
153 | demosaic_net=demosaic_net, device=red_gains.device)
154 | # images = np.power(np.clip(rgb, 0, 1), 2.2)
155 | images= torch.pow(torch.clamp(rgb, 0, 1), 2.2)
156 | images = images.to(red_gains.device)
157 | # Color correction.
158 | images = apply_ccms(images, cam2rgbs)
159 | # Gamma compression.
160 | images = torch.clamp(images, min=0.0, max=1.0)
161 | if not lineRGB:
162 | images = gamma_compression(images)
163 | return images
164 |
165 |
166 | def flatten_raw_image(im_raw_4ch): # HxWx4
167 | """ unpack a 4-channel tensor into a single channel bayer image"""
168 | if isinstance(im_raw_4ch, np.ndarray):
169 | im_out = np.zeros_like(im_raw_4ch, shape=(im_raw_4ch.shape[0] * 2, im_raw_4ch.shape[1] * 2))
170 | elif isinstance(im_raw_4ch, torch.Tensor):
171 | im_out = torch.zeros((im_raw_4ch.shape[0] * 2, im_raw_4ch.shape[1] * 2), dtype=im_raw_4ch.dtype)
172 | else:
173 | raise Exception
174 |
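# RGGB phase layout: R at (even row, even col), Gr at (even, odd),
# Gb at (odd, even), B at (odd, odd).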
175 | im_out[0::2, 0::2] = im_raw_4ch[:, :, 0]
176 | im_out[0::2, 1::2] = im_raw_4ch[:, :, 1]
177 | im_out[1::2, 0::2] = im_raw_4ch[:, :, 2]
178 | im_out[1::2, 1::2] = im_raw_4ch[:, :, 3]
179 |
180 | return im_out # HxW
--------------------------------------------------------------------------------
/models/degrade/unprocess.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Unprocesses sRGB images into realistic raw data.
17 |
18 | Unprocessing Images for Learned Raw Denoising
19 | http://timothybrooks.com/tech/unprocessing
20 | """
21 |
22 | import numpy as np
23 | import torch
24 | import torch.distributions as tdist
25 |
26 |
27 | def random_ccm(device):
28 | """Generates random RGB -> Camera color correction matrices."""
29 | # Takes a random convex combination of XYZ -> Camera CCMs.
30 | xyz2cams = [[[1.0234, -0.2969, -0.2266],
31 | [-0.5625, 1.6328, -0.0469],
32 | [-0.0703, 0.2188, 0.6406]],
33 | [[0.4913, -0.0541, -0.0202],
34 | [-0.613, 1.3513, 0.2906],
35 | [-0.1564, 0.2151, 0.7183]],
36 | [[0.838, -0.263, -0.0639],
37 | [-0.2887, 1.0725, 0.2496],
38 | [-0.0627, 0.1427, 0.5438]],
39 | [[0.6596, -0.2079, -0.0562],
40 | [-0.4782, 1.3016, 0.1933],
41 | [-0.097, 0.1581, 0.5181]]]
42 | num_ccms = len(xyz2cams)
43 | xyz2cams = torch.FloatTensor(xyz2cams)
44 | weights = torch.FloatTensor(num_ccms, 1, 1).uniform_(1e-8, 1e8)
45 | weights_sum = torch.sum(weights, dim=0)
46 | xyz2cam = torch.sum(xyz2cams * weights, dim=0) / weights_sum
47 |
48 | # Multiplies with RGB -> XYZ to get RGB -> Camera CCM.
49 | rgb2xyz = torch.FloatTensor([[0.4124564, 0.3575761, 0.1804375],
50 | [0.2126729, 0.7151522, 0.0721750],
51 | [0.0193339, 0.1191920, 0.9503041]])
52 | rgb2cam = torch.mm(xyz2cam.to(device), rgb2xyz.to(device))
53 |
54 | # Normalizes each row.
55 | rgb2cam = rgb2cam / torch.sum(rgb2cam, dim=-1, keepdim=True)
56 | return rgb2cam
57 |
58 |
59 | def random_gains(device):
60 | """Generates random gains for brightening and white balance."""
61 | # RGB gain represents brightening.
62 | n = tdist.Normal(loc=torch.tensor([0.8]), scale=torch.tensor([0.1]))
63 | rgb_gain = 1.0 / n.sample()
64 |
65 | # Red and blue gains represent white balance.
66 | red_gain = torch.FloatTensor(1).uniform_(1.9, 2.4)
67 | blue_gain = torch.FloatTensor(1).uniform_(1.5, 1.9)
68 | return rgb_gain.to(device), red_gain.to(device), blue_gain.to(device)
69 |
70 |
71 | def inverse_smoothstep(image):
72 | """Approximately inverts a global tone mapping curve."""
73 | image = torch.clamp(image, min=0.0, max=1.0)
74 | out = 0.5 - torch.sin(torch.asin(1.0 - 2.0 * image) / 3.0)
75 | return out
76 |
77 |
78 | def gamma_expansion(image):
79 | """Converts from gamma to linear space."""
80 | # Clamps to prevent numerical instability of gradients near zero.
81 | Mask = lambda x: (x>0.04045).float()
82 | sRGBLinearize = lambda x,m: m * ((m * x + 0.055) / 1.055) ** 2.4 + (1-m) * (x / 12.92)
83 | return sRGBLinearize(image, Mask(image))
84 | # out = torch.clamp(image, min=1e-8) ** 2.2
85 | # return out
86 |
87 |
88 | def apply_ccm(image, ccm):
89 | """Applies a color correction matrix."""
90 | shape = image.size()
91 | image = torch.reshape(image, [-1, 3])
92 | image = torch.tensordot(image, ccm, dims=[[-1], [-1]])
93 | out = torch.reshape(image, shape)
94 | return out
95 |
96 |
97 | def safe_invert_gains(image, rgb_gain, red_gain, blue_gain, device):
98 | """Inverts gains while safely handling saturated pixels."""
99 | gains = torch.stack((1.0 / red_gain, torch.tensor([1.0]).to(device), 1.0 / blue_gain)) # / rgb_gain
100 | gains = gains.to(device).squeeze()
101 | gains = gains[None, None, :]
102 | # Prevents dimming of saturated pixels by smoothly masking gains near white.
103 | gray = torch.mean(image, dim=-1, keepdim=True)
104 | inflection = 0.9
105 | mask = (torch.clamp(gray - inflection, min=0.0) / (1.0 - inflection)) ** 2.0
106 | safe_gains = torch.max(mask + (1.0 - mask) * gains, gains)
107 | out = image * safe_gains
108 | return out
109 |
110 |
111 | def mosaic(image):
112 | """Extracts RGGB Bayer planes from an RGB image."""
113 | shape = image.size()
114 | red = image[0::2, 0::2, 0]
115 | green_red = image[0::2, 1::2, 1]
116 | green_blue = image[1::2, 0::2, 1]
117 | blue = image[1::2, 1::2, 2]
118 | out = torch.stack((red, green_red, green_blue, blue), dim=-1)
119 | out = torch.reshape(out, (shape[0] // 2, shape[1] // 2, 4))
120 | return out
121 |
122 |
123 | def unprocess(image, features=None, device=None):
124 | """Unprocesses an image from sRGB to realistic raw data."""
125 |
126 | if features is None:
127 | # Randomly creates image metadata.
128 | rgb2cam = random_ccm(device)
129 | cam2rgb = torch.inverse(rgb2cam)
130 | rgb_gain, red_gain, blue_gain = random_gains(device)
131 | else:
132 | rgb2cam = features['rgb2cam']
133 | cam2rgb = features['cam2rgb']
134 | rgb_gain = features['rgb_gain']
135 | red_gain = features['red_gain']
136 | blue_gain = features['blue_gain']
137 | # Approximately inverts global tone mapping.
138 | # image = inverse_smoothstep(image)
139 | # Inverts gamma compression.
140 | image = gamma_expansion(image)
141 | # Inverts color correction.
142 | image = apply_ccm(image, rgb2cam)
143 | # Approximately inverts white balance and brightening.
144 | image = safe_invert_gains(image, rgb_gain, red_gain, blue_gain, device)
145 | # Clips saturated pixels.
146 | # image = torch.clamp(image, min=0.0, max=1.0)
147 | # Applies a Bayer mosaic.
148 | image = mosaic(image)
149 |
150 | metadata = {
151 | 'rgb2cam': rgb2cam,
152 | 'cam2rgb': cam2rgb,
153 | 'rgb_gain': rgb_gain,
154 | 'red_gain': red_gain,
155 | 'blue_gain': blue_gain,
156 | }
157 | return image, metadata
158 |
159 |
160 | # ############### If the target dataset is DND, use this function #####################
161 | # def random_noise_levels():
162 | # """Generates random noise levels from a log-log linear distribution."""
163 | # log_min_shot_noise = np.log(0.0001)
164 | # log_max_shot_noise = np.log(0.012)
165 | # log_shot_noise = torch.FloatTensor(1).uniform_(log_min_shot_noise, log_max_shot_noise)
166 | # shot_noise = torch.exp(log_shot_noise)
167 |
168 | # line = lambda x: 2.18 * x + 1.20
169 | # n = tdist.Normal(loc=torch.tensor([0.0]), scale=torch.tensor([0.26]))
170 | # log_read_noise = line(log_shot_noise) + n.sample()
171 | # read_noise = torch.exp(log_read_noise)
172 | # return shot_noise, read_noise
173 |
174 |
175 | def add_noise(image, shot_noise, read_noise):
176 | var = image * shot_noise + read_noise
177 | noise = tdist.Normal(loc=torch.zeros_like(var), scale=torch.sqrt(var)).sample()
178 | out = image + noise
179 | return out
180 |
181 |
182 | ################ If the target dataset is SIDD, use this function #####################
183 | def random_noise_levels(noise_level):
184 | """ Where read_noise in SIDD is not 0 """
185 | log_min_shot_noise = torch.log(torch.tensor(0.0012)).to(noise_level.device)
186 | log_max_shot_noise = torch.log(torch.tensor(0.0048)).to(noise_level.device)
187 | log_shot_noise = log_min_shot_noise + noise_level * (log_max_shot_noise - log_min_shot_noise)
188 | shot_noise = torch.exp(log_shot_noise)
189 |
190 | line = lambda x: 1.869 * x + 0.3276
191 | n = tdist.Normal(loc=torch.tensor([0.0], device=noise_level.device),
192 | scale=torch.tensor([0.30], device=noise_level.device))
193 |
194 | log_read_noise = line(log_shot_noise) + n.sample()
195 | read_noise = torch.exp(log_read_noise)
196 |
197 | return shot_noise, read_noise
198 |
199 |
200 | # def add_noise(image, shot_noise=0.01, read_noise=0.0005):
201 | # """Adds random shot (proportional to image) and read (independent) noise."""
202 | # variance = image * shot_noise + read_noise
203 | # n = tdist.Normal(loc=torch.zeros_like(variance), scale=torch.sqrt(variance))
204 | # noise = n.sample()
205 | # out = image + noise
206 | # return out
207 |
208 |
209 | # ################ If the target dataset is SIDD, use this function #####################
210 | # def random_noise_levels(noise_level):
211 | # """ Where read_noise in SIDD is not 0 """
212 | # log_min_shot_noise = np.log(0.00068674)
213 | # log_max_shot_noise = np.log(0.02194856)
214 | # # log_shot_noise = torch.FloatTensor(1).uniform_(log_min_shot_noise, log_max_shot_noise)
215 | # log_shot_noise = torch.FloatTensor([log_min_shot_noise + noise_level * (log_max_shot_noise - log_min_shot_noise)])
216 | # shot_noise = torch.exp(log_shot_noise)
217 |
218 | # line = lambda x: 1.85 * x + 0.30
219 | # n = tdist.Normal(loc=torch.tensor([0.0]), scale=torch.tensor([0.20]))
220 | # log_read_noise = line(log_shot_noise) + n.sample()
221 | # read_noise = torch.exp(log_read_noise)
222 | # return shot_noise, read_noise
--------------------------------------------------------------------------------
/models/losses.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.autograd import Variable
7 | import numpy as np
8 | from math import exp, sqrt
9 | from torch.nn import L1Loss, MSELoss
10 | from torchvision import models
11 | from data.degrade.process import demosaic
12 |
13 |
14 | def gaussian(window_size, sigma):
15 | gauss = torch.Tensor([exp(
16 | -(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) \
17 | for x in range(window_size)])
18 | return gauss / gauss.sum()
19 |
20 | def create_window(window_size, channel):
21 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
22 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
23 | window = Variable(_2D_window.expand(
24 | channel, 1, window_size, window_size).contiguous())
25 | return window
26 |
27 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
28 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
29 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
30 |
31 | mu1_sq = mu1.pow(2)
32 | mu2_sq = mu2.pow(2)
33 | mu1_mu2 = mu1 * mu2
34 |
35 | sigma1_sq = F.conv2d(
36 | img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
37 | sigma2_sq = F.conv2d(
38 | img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
39 | sigma12 = F.conv2d(
40 | img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
41 |
42 | C1 = 0.01 ** 2
43 | C2 = 0.03 ** 2
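# SSIM stability constants: C_i = (K_i * L)^2 with K1 = 0.01, K2 = 0.03 and
# dynamic range L = 1 for inputs scaled to [0, 1].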
44 |
45 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
46 | ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
47 |
48 | if size_average:
49 | return ssim_map.mean()
50 | else:
51 | return ssim_map.mean(1).mean(1).mean(1)
52 |
53 | def ssim(img1, img2, window_size=11, size_average=True):
54 | (_, channel, _, _) = img1.size()
55 | window = create_window(window_size, channel)
56 |
57 | if img1.is_cuda:
58 | window = window.cuda(img1.get_device())
59 | window = window.type_as(img1)
60 |
61 | return _ssim(img1, img2, window, window_size, channel, size_average)
62 |
63 | class SSIMLoss(nn.Module):
64 | def __init__(self, window_size=11, size_average=True):
65 | super(SSIMLoss, self).__init__()
66 | self.window_size = window_size
67 | self.size_average = size_average
68 | self.channel = 1
69 | self.window = create_window(window_size, self.channel)
70 |
71 | def forward(self, img1, img2):
72 | (_, channel, _, _) = img1.size()
73 |
74 | if channel == self.channel and \
75 | self.window.data.type() == img1.data.type():
76 | window = self.window
77 | else:
78 | window = create_window(self.window_size, channel)
79 |
80 | if img1.is_cuda:
81 | window = window.cuda(img1.get_device())
82 | window = window.type_as(img1)
83 |
84 | self.window = window
85 | self.channel = channel
86 |
87 | return _ssim(img1, img2, window, self.window_size,
88 | channel, self.size_average)
89 |
90 | def post_process(image):
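# Collapse packed 4-channel RGGB into a pseudo-RGB triplet by averaging the
# two green planes, so raw tensors can be fed to the RGB-pretrained VGG19.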
91 | image = torch.stack((image[:, 0], image[:, 1:3].mean(dim=1), image[:, 3]), dim=1)
92 | # image = torch.pow(torch.clamp(image, 1e-6, 1), 1/2.2)
93 | return image
94 |
95 | def normalize_batch(batch):
96 | batch = post_process(batch)
97 | mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
98 | std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
99 | return (batch - mean) / std
100 |
101 | class VGG19(torch.nn.Module):
102 | def __init__(self):
103 | super(VGG19, self).__init__()
104 | features = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
105 | self.relu1_1 = torch.nn.Sequential()
106 | self.relu1_2 = torch.nn.Sequential()
107 |
108 | self.relu2_1 = torch.nn.Sequential()
109 | self.relu2_2 = torch.nn.Sequential()
110 |
111 | self.relu3_1 = torch.nn.Sequential()
112 | self.relu3_2 = torch.nn.Sequential()
113 | self.relu3_3 = torch.nn.Sequential()
114 | self.relu3_4 = torch.nn.Sequential()
115 |
116 | self.relu4_1 = torch.nn.Sequential()
117 | self.relu4_2 = torch.nn.Sequential()
118 | self.relu4_3 = torch.nn.Sequential()
119 | self.relu4_4 = torch.nn.Sequential()
120 |
121 | self.relu5_1 = torch.nn.Sequential()
122 | self.relu5_2 = torch.nn.Sequential()
123 | self.relu5_3 = torch.nn.Sequential()
124 | self.relu5_4 = torch.nn.Sequential()
125 |
126 | for x in range(2):
127 | self.relu1_1.add_module(str(x), features[x])
128 |
129 | for x in range(2, 4):
130 | self.relu1_2.add_module(str(x), features[x])
131 |
132 | for x in range(4, 7):
133 | self.relu2_1.add_module(str(x), features[x])
134 |
135 | for x in range(7, 9):
136 | self.relu2_2.add_module(str(x), features[x])
137 |
138 | for x in range(9, 12):
139 | self.relu3_1.add_module(str(x), features[x])
140 |
141 | for x in range(12, 14):
142 | self.relu3_2.add_module(str(x), features[x])
143 |
144 | for x in range(14, 16):
145 | self.relu3_3.add_module(str(x), features[x])
146 |
147 | for x in range(16, 18):
148 | self.relu3_4.add_module(str(x), features[x])
149 |
150 | for x in range(18, 21):
151 | self.relu4_1.add_module(str(x), features[x])
152 |
153 | for x in range(21, 23):
154 | self.relu4_2.add_module(str(x), features[x])
155 |
156 | for x in range(23, 25):
157 | self.relu4_3.add_module(str(x), features[x])
158 |
159 | for x in range(25, 27):
160 | self.relu4_4.add_module(str(x), features[x])
161 |
162 | for x in range(27, 30):
163 | self.relu5_1.add_module(str(x), features[x])
164 |
165 | for x in range(30, 32):
166 | self.relu5_2.add_module(str(x), features[x])
167 |
168 | for x in range(32, 34):
169 | self.relu5_3.add_module(str(x), features[x])
170 |
171 | for x in range(34, 36):
172 | self.relu5_4.add_module(str(x), features[x])
173 |
174 | # don't need the gradients, just want the features
175 | for param in self.parameters():
176 | param.requires_grad = False
177 |
178 | def forward(self, x):
179 | relu1_1 = self.relu1_1(x)
180 | relu1_2 = self.relu1_2(relu1_1)
181 |
182 | relu2_1 = self.relu2_1(relu1_2)
183 | relu2_2 = self.relu2_2(relu2_1)
184 |
185 | relu3_1 = self.relu3_1(relu2_2)
186 | relu3_2 = self.relu3_2(relu3_1)
187 | relu3_3 = self.relu3_3(relu3_2)
188 | relu3_4 = self.relu3_4(relu3_3)
189 |
190 | relu4_1 = self.relu4_1(relu3_4)
191 | relu4_2 = self.relu4_2(relu4_1)
192 | relu4_3 = self.relu4_3(relu4_2)
193 | relu4_4 = self.relu4_4(relu4_3)
194 |
195 | relu5_1 = self.relu5_1(relu4_4)
196 | relu5_2 = self.relu5_2(relu5_1)
197 | relu5_3 = self.relu5_3(relu5_2)
198 | relu5_4 = self.relu5_4(relu5_3)
199 |
200 | out = {
201 | 'relu1_1': relu1_1,
202 | 'relu1_2': relu1_2,
203 |
204 | 'relu2_1': relu2_1,
205 | 'relu2_2': relu2_2,
206 |
207 | 'relu3_1': relu3_1,
208 | 'relu3_2': relu3_2,
209 | 'relu3_3': relu3_3,
210 | 'relu3_4': relu3_4,
211 |
212 | 'relu4_1': relu4_1,
213 | 'relu4_2': relu4_2,
214 | 'relu4_3': relu4_3,
215 | 'relu4_4': relu4_4,
216 |
217 | 'relu5_1': relu5_1,
218 | 'relu5_2': relu5_2,
219 | 'relu5_3': relu5_3,
220 | 'relu5_4': relu5_4,
221 | }
222 | return out
223 |
224 | class VGGLoss(nn.Module):
225 | def __init__(self):
226 | super(VGGLoss, self).__init__()
227 | self.add_module('vgg', VGG19())
228 | self.criterion = torch.nn.L1Loss()
229 |
230 | def forward(self, img1, img2, p=6):
231 | x = normalize_batch(img1)
232 | y = normalize_batch(img2)
233 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
234 |
235 | content_loss = 0.0
236 | # # content_loss += self.criterion(x_vgg['relu1_2'], y_vgg['relu1_2']) * 0.1
237 | # # content_loss += self.criterion(x_vgg['relu2_2'], y_vgg['relu2_2']) * 0.2
238 | content_loss += self.criterion(x_vgg['relu3_2'], y_vgg['relu3_2']) * 1
239 | content_loss += self.criterion(x_vgg['relu4_2'], y_vgg['relu4_2']) * 1
240 | content_loss += self.criterion(x_vgg['relu5_2'], y_vgg['relu5_2']) * 2
241 |
242 | return content_loss / 4.
243 |
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | from torch.optim import lr_scheduler
5 | from collections import OrderedDict
6 | import torch.nn.functional as F
7 | from util.util import SSIM
8 | import random
9 | import torch.distributions as tdist
10 | import numpy as np
11 | from timm.scheduler.cosine_lr import CosineLRScheduler
12 | from einops import rearrange
13 | import torchvision.ops as ops
14 | import numbers
15 | from torch.nn.modules.utils import _triple
16 | import math
17 | import cv2
18 |
19 |
20 |
21 | # Augment
22 | def augment_func(img, hflip, vflip, rot90): # CxHxW
23 | if hflip: img = torch.flip(img, dims=[-1])
24 | if vflip: img = torch.flip(img, dims=[-2])
25 | if rot90: img = img.transpose(img.ndim-1, img.ndim-2)
26 | return img
27 |
28 |
29 | def augment(*imgs): # CxHxW
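# The flip/rotation coin flips are sampled once and shared across all
# inputs, so paired tensors (e.g. raws and ground truth) receive identical
# augmentations and stay spatially aligned.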
30 | hflip = random.random() < 0.5
31 | vflip = random.random() < 0.5
32 | rot90 = random.random() < 0.5
33 | return (augment_func(img, hflip, vflip, rot90) for img in imgs)
34 |
35 |
36 | # Pack Images
37 | def pack_raw_image(im_raw): # GT: N x 1 x H x W ; RAWs: N x T x H x W
38 | """ Packs a single channel bayer image into 4 channel tensor, where channels contain R, G, G, and B values"""
39 | im_out = torch.zeros([im_raw.shape[0], im_raw.shape[1], 4, im_raw.shape[2] // 2, im_raw.shape[3] // 2],
40 | dtype=im_raw.dtype, device=im_raw.device)
41 |
42 | im_out[..., 0, :, :] = im_raw[:, :, 0::2, 0::2]
43 | im_out[..., 1, :, :] = im_raw[:, :, 0::2, 1::2]
44 | im_out[..., 2, :, :] = im_raw[:, :, 1::2, 0::2]
45 | im_out[..., 3, :, :] = im_raw[:, :, 1::2, 1::2]
46 |
47 | # GT: N x 1 x 4 x (H//2) x (W//2)
48 | # RAWs: N x T x 4 x (H//2) x (W//2)
49 | return im_out
50 |
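# Shape sketch (illustrative sizes): packing halves the spatial resolution and
# expands the single Bayer channel into 4 color-plane channels.
def _demo_pack_raw_image():
    raws = torch.rand(2, 5, 128, 128)  # N x T single-channel Bayer frames
    packed = pack_raw_image(raws)      # -> 2 x 5 x 4 x 64 x 64
    assert packed.shape == (2, 5, 4, 64, 64)
    return packed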
51 |
52 | def get_scheduler(optimizer, opt):
53 | if opt.lr_policy == 'linear':
54 | def lambda_rule(epoch):
55 | return 1 - max(0, epoch-opt.niter) / max(1, float(opt.niter_decay))
56 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
57 | elif opt.lr_policy == 'step':
58 | scheduler = lr_scheduler.StepLR(optimizer,
59 | step_size=opt.lr_decay_iters,
60 | gamma=0.5)
61 | elif opt.lr_policy == 'plateau':
62 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
63 | mode='min',
64 | factor=0.2,
65 | threshold=0.01,
66 | patience=5)
67 | elif opt.lr_policy == 'cosine':
68 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
69 | T_max=opt.niter,
70 | eta_min=1e-6)
71 | elif opt.lr_policy == 'cosine_warmup':
72 | scheduler = CosineLRScheduler(optimizer,
73 | t_initial=opt.niter,
74 | lr_min=1e-6,
75 | warmup_t=5,
76 | warmup_lr_init=1e-5)
77 | else:
78 |         raise NotImplementedError('lr [%s] is not implemented' % opt.lr_policy)
79 | return scheduler
80 |
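# Usage sketch: get_scheduler only reads opt.lr_policy plus the fields of the
# chosen branch, so a plain namespace with illustrative values is enough here.
def _demo_get_scheduler():
    from argparse import Namespace
    net = nn.Conv2d(4, 4, 3)
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
    opt = Namespace(lr_policy='step', lr_decay_iters=200)
    return get_scheduler(optimizer, opt)  # StepLR halving the lr every 200 scheduler steps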
81 |
82 | def init_weights(net, init_type='normal', init_gain=0.02):
83 | def init_func(m): # define the initialization function
84 | classname = m.__class__.__name__
85 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 \
86 | or classname.find('Linear') != -1):
87 | if init_type == 'normal':
88 | init.normal_(m.weight.data, 0.0, init_gain)
89 | elif init_type == 'xavier':
90 | init.xavier_normal_(m.weight.data, gain=init_gain)
91 | elif init_type == 'kaiming':
92 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
93 | elif init_type == 'orthogonal':
94 | init.orthogonal_(m.weight.data, gain=init_gain)
95 | elif init_type == 'uniform':
96 | init.uniform_(m.weight.data, b=init_gain)
97 | elif init_type == 'constant':
98 | init.constant_(m.weight.data, 0.0)
99 | else:
100 | raise NotImplementedError('[%s] is not implemented' % init_type)
101 | elif hasattr(m, 'bias') and m.bias is not None:
102 | init.constant_(m.bias.data, 0.0)
103 | elif classname.find('BatchNorm2d') != -1:
104 | init.normal_(m.weight.data, 1.0, init_gain)
105 | init.constant_(m.bias.data, 0.0)
106 |
107 | # print('initialize network with %s' % init_type)
108 | net.apply(init_func) # apply the initialization function
109 |
110 |
111 | def init_net(net, init_type='default', init_gain=0.02, gpu_ids=[]):
112 | if len(gpu_ids) > 0:
113 | assert(torch.cuda.is_available())
114 | net.to(gpu_ids[0])
115 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
116 | if init_type != 'default' and init_type is not None:
117 | init_weights(net, init_type, init_gain=init_gain)
118 | return net
119 |
120 |
121 | def set_requires_grad(nets, requires_grad=False):
122 | if not isinstance(nets, list):
123 | nets = [nets]
124 | for net in nets:
125 | if net is not None:
126 | for param in net.parameters():
127 | param.requires_grad = requires_grad
128 |
129 |
130 | def make_layer(block, num_blocks, **kwarg):
131 | """Make layers by stacking the same blocks.
132 |
133 | Args:
134 | block (nn.module): nn.module class for basic block.
135 | num_blocks (int): number of blocks.
136 |
137 | Returns:
138 | nn.Sequential: Stacked blocks in nn.Sequential.
139 | """
140 | layers = []
141 | for _ in range(num_blocks):
142 | layers.append(block(**kwarg))
143 | return nn.Sequential(*layers)
144 |
145 |
146 | def flow_warp(x,
147 | flow,
148 | interpolation='bilinear',
149 | padding_mode='zeros',
150 | align_corners=True):
151 | """Warp an image or a feature map with optical flow.
152 |
153 | Args:
154 | x (Tensor): Tensor with size (n, c, h, w).
155 |         flow (Tensor): Tensor with size (n, h, w, 2). The last dimension
156 |             holds two channels, denoting the horizontal and vertical relative offsets.
157 | Note that the values are not normalized to [-1, 1].
158 | interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
159 | Default: 'bilinear'.
160 | padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'.
161 | Default: 'zeros'.
162 |         align_corners (bool): Whether to align corners. Default: True.
163 |
164 | Returns:
165 | Tensor: Warped image or feature map.
166 | """
167 | if x.size()[-2:] != flow.size()[1:3]:
168 | raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
169 | f'flow ({flow.size()[1:3]}) are not the same.')
170 | _, _, h, w = x.size()
171 | # create mesh grid
172 | device = flow.device
173 | grid_y, grid_x = torch.meshgrid(
174 | torch.arange(0, h, device=device, dtype=x.dtype),
175 | torch.arange(0, w, device=device, dtype=x.dtype))
176 | grid = torch.stack((grid_x, grid_y), 2) # h, w, 2
177 | grid.requires_grad = False
178 |
179 | grid_flow = grid + flow
180 | # scale grid_flow to [-1,1]
181 | grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
182 | grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
183 | grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)
184 | output = F.grid_sample(
185 | x,
186 | grid_flow,
187 | mode=interpolation,
188 | padding_mode=padding_mode,
189 | align_corners=align_corners)
190 | return output
191 |
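# Sanity-check sketch: a zero flow field leaves the sampling grid at identity,
# so flow_warp should reproduce the input up to interpolation error.
def _demo_flow_warp():
    x = torch.rand(1, 3, 32, 32)
    flow = torch.zeros(1, 32, 32, 2)   # n x h x w x 2, offsets in pixel units
    warped = flow_warp(x, flow)
    assert torch.allclose(warped, x, atol=1e-5)
    return warped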
192 |
193 | def load_spynet(net, path):
194 | if isinstance(net, torch.nn.DataParallel):
195 | net = net.module
196 | state_dict = torch.load(path)
197 |
198 | print('loading the model from %s' % (path))
199 | if hasattr(state_dict, '_metadata'):
200 | del state_dict._metadata
201 |
202 | net_state = net.state_dict()
203 | is_loaded = {n:False for n in net_state.keys()}
204 | for name, param in state_dict.items():
205 | name = name.replace('basic_module.0.conv', 'basic_module.0')
206 | name = name.replace('basic_module.1.conv', 'basic_module.2')
207 | name = name.replace('basic_module.2.conv', 'basic_module.4')
208 | name = name.replace('basic_module.3.conv', 'basic_module.6')
209 | name = name.replace('basic_module.4.conv', 'basic_module.8')
210 | if name in net_state:
211 | try:
212 | net_state[name].copy_(param)
213 | is_loaded[name] = True
214 | except Exception:
215 | print('While copying the parameter named [%s], '
216 | 'whose dimensions in the model are %s and '
217 | 'whose dimensions in the checkpoint are %s.'
218 | % (name, list(net_state[name].shape),
219 | list(param.shape)))
220 | raise RuntimeError
221 | else:
222 | print('Saved parameter named [%s] is skipped' % name)
223 | mark = True
224 | for name in is_loaded:
225 | if not is_loaded[name]:
226 | print('Parameter named [%s] is not initialized' % name)
227 | mark = False
228 | if mark:
229 | print('All parameters are initialized using [%s]' % path)
230 |
231 |
232 | '''
233 | # ===================================
234 | # Advanced nn.Sequential
235 | # reform nn.Sequentials and nn.Modules
236 | # to a single nn.Sequential
237 | # ===================================
238 | '''
239 |
240 | def seq(*args):
241 | if len(args) == 1:
242 | args = args[0]
243 | if isinstance(args, nn.Module):
244 | return args
245 | modules = OrderedDict()
246 | if isinstance(args, OrderedDict):
247 | for k, v in args.items():
248 | modules[k] = seq(v)
249 | return nn.Sequential(modules)
250 | assert isinstance(args, (list, tuple))
251 | return nn.Sequential(*[seq(i) for i in args])
252 |
253 | '''
254 | # ===================================
255 | # Useful blocks
256 | # --------------------------------
257 | # conv (+ normalization + relu)
258 | # concat
259 | # sum
260 | # resblock (ResBlock)
261 | # resdenseblock (ResidualDenseBlock_5C)
262 | # resinresdenseblock (RRDB)
263 | # ===================================
264 | '''
265 |
266 | # -------------------------------------------------------
267 | # return nn.Sequential of (Conv + BN + ReLU)
268 | # -------------------------------------------------------
269 | def conv(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1,
270 | output_padding=0, dilation=1, groups=1, bias=True,
271 | padding_mode='zeros', mode='CBR'):
272 | L = []
273 | for t in mode:
274 | if t == 'C':
275 | L.append(nn.Conv2d(in_channels=in_channels,
276 | out_channels=out_channels,
277 | kernel_size=kernel_size,
278 | stride=stride,
279 | padding=padding,
280 | dilation=dilation,
281 | groups=groups,
282 | bias=bias,
283 | padding_mode=padding_mode))
284 | elif t == 'X':
285 | assert in_channels == out_channels
286 | L.append(nn.Conv2d(in_channels=in_channels,
287 | out_channels=out_channels,
288 | kernel_size=kernel_size,
289 | stride=stride,
290 | padding=padding,
291 | dilation=dilation,
292 | groups=in_channels,
293 | bias=bias,
294 | padding_mode=padding_mode))
295 | elif t == 'T':
296 | L.append(nn.ConvTranspose2d(in_channels=in_channels,
297 | out_channels=out_channels,
298 | kernel_size=kernel_size,
299 | stride=stride,
300 | padding=padding,
301 | output_padding=output_padding,
302 | groups=groups,
303 | bias=bias,
304 | dilation=dilation,
305 | padding_mode=padding_mode))
306 | elif t == 'B':
307 | L.append(nn.BatchNorm2d(out_channels))
308 | elif t == 'I':
309 | L.append(nn.InstanceNorm2d(out_channels, affine=True))
310 | elif t == 'i':
311 | L.append(nn.InstanceNorm2d(out_channels))
312 | elif t == 'R':
313 | L.append(nn.ReLU(inplace=True))
314 | elif t == 'r':
315 | L.append(nn.ReLU(inplace=False))
316 | elif t == 'S':
317 | L.append(nn.Sigmoid())
318 | elif t == 'P':
319 | L.append(nn.PReLU())
320 | elif t == 'L':
321 | L.append(nn.LeakyReLU(negative_slope=1e-1, inplace=True))
322 | elif t == 'l':
323 | L.append(nn.LeakyReLU(negative_slope=1e-1, inplace=False))
324 | elif t == '2':
325 | L.append(nn.PixelShuffle(upscale_factor=2))
326 | elif t == '3':
327 | L.append(nn.PixelShuffle(upscale_factor=3))
328 | elif t == '4':
329 | L.append(nn.PixelShuffle(upscale_factor=4))
330 | elif t == 'U':
331 | L.append(nn.Upsample(scale_factor=2, mode='nearest'))
332 | elif t == 'u':
333 | L.append(nn.Upsample(scale_factor=3, mode='nearest'))
334 | elif t == 'M':
335 | L.append(nn.MaxPool2d(kernel_size=kernel_size,
336 | stride=stride,
337 | padding=0))
338 | elif t == 'A':
339 | L.append(nn.AvgPool2d(kernel_size=kernel_size,
340 | stride=stride,
341 | padding=0))
342 | else:
343 |             raise NotImplementedError('Undefined type: {}'.format(t))
344 | return seq(*L)
345 |
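# Mode-string sketch: each character appends one module, e.g. 'CBR' builds
# Conv2d -> BatchNorm2d -> ReLU and 'C2' builds Conv2d -> PixelShuffle(2).
def _demo_conv_modes():
    cbr = conv(4, 64, mode='CBR')      # conv + batchnorm + relu
    up = conv(64, 16, mode='C2')       # conv + 2x pixel shuffle -> 4 channels
    x = torch.rand(1, 4, 16, 16)
    return up(cbr(x))                  # -> 1 x 4 x 32 x 32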
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
1 | """This package options includes option modules."""
2 |
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import re
4 | from util import util
5 | import torch
6 | import models
7 | import time
8 |
9 | def str2bool(v):
10 | return v.lower() in ('yes', 'y', 'true', 't', '1')
11 |
12 | inf = float('inf')
13 |
14 | class BaseOptions():
15 | def __init__(self):
16 | """Reset the class; indicates the class hasn't been initailized"""
17 | self.initialized = False
18 |
19 | def initialize(self, parser):
20 | """Define the common options that are used in both training and test."""
21 | # data parameters
22 | parser.add_argument('--dataroot', type=str, default="data")
23 | parser.add_argument('--dataset_name', type=str, default=['bracketire'], nargs='+')
24 | parser.add_argument('--max_dataset_size', type=int, default=inf)
25 | # parser.add_argument('--scale', type=int, default=4, help='Super-resolution scale.')
26 | parser.add_argument('--frame_num', type=int, default=5)
27 | parser.add_argument('--batch_size', type=int, default=2)
28 | parser.add_argument('--patch_size', type=int, default=128)
29 | parser.add_argument('--shuffle', type=str2bool, default=True)
30 | parser.add_argument('-j', '--num_dataloader', default=4, type=int)
31 | parser.add_argument('--drop_last', type=str2bool, default=True)
32 |
33 | # device parameters
34 | parser.add_argument('--gpu_ids', type=str, default='all',
35 | help='Separate the GPU ids by `,`, using all GPUs by default. '
36 |                             'e.g., `--gpu_ids 0`, `--gpu_ids 2,3`, `--gpu_ids -1` (CPU)')
37 | parser.add_argument('--checkpoints_dir', type=str, default='./ckpt')
38 | parser.add_argument('-v', '--verbose', type=str2bool, default=True)
39 | parser.add_argument('--suffix', default='', type=str)
40 |
41 | # model parameters
42 | parser.add_argument('--name', type=str, default='track1',
43 | help='Name of the folder to save models and logs.')
44 | parser.add_argument('--model', type=str, default='cat')
45 | parser.add_argument('--block', type=str, default='Convnext')
46 | parser.add_argument('--load_path', type=str, default='',
47 | help='Will load pre-trained model if load_path is set')
48 | parser.add_argument('--load_iter', type=int, default=[500], nargs='+',
49 | help='Load parameters if > 0 and load_path is not set. '
50 | 'Set the value of `last_epoch`')
51 | parser.add_argument('--chop', type=str2bool, default=False)
52 | parser.add_argument('--crop_patch', type=int, default=48)
53 | parser.add_argument('--self_weight', type=float, default=1)
54 | parser.add_argument('--neg_weight', type=float, default=1)
55 | parser.add_argument('--exposure', type=int, default=1)
56 |
57 | # training parameters
58 | parser.add_argument('--init_type', type=str, default='default',
59 | choices=['default', 'normal', 'xavier',
60 | 'kaiming', 'orthogonal', 'uniform'],
61 | help='`default` means using PyTorch default init functions.')
62 | parser.add_argument('--init_gain', type=float, default=0.02)
63 | parser.add_argument('--optimizer', type=str, default='Adam',
64 | choices=['Adam', 'SGD', 'RMSprop'])
65 | parser.add_argument('--niter', type=int, default=1000)
66 | parser.add_argument('--niter_decay', type=int, default=0)
67 | parser.add_argument('--lr_policy', type=str, default='step')
68 | parser.add_argument('--lr_decay_iters', type=int, default=200)
69 | parser.add_argument('--lr', type=float, default=0.0001)
70 |
71 | # Optimizer
72 | parser.add_argument('--load_optimizers', type=str2bool, default=False,
73 | help='Loading optimizer parameters for continuing training.')
74 | parser.add_argument('--weight_decay', type=float, default=0)
75 | # Adam
76 | parser.add_argument('--beta1', type=float, default=0.9)
77 | parser.add_argument('--beta2', type=float, default=0.999)
78 | # SGD & RMSprop
79 | parser.add_argument('--momentum', type=float, default=0)
80 | # RMSprop
81 | parser.add_argument('--alpha', type=float, default=0.99)
82 |
83 | # visualization parameters
84 | parser.add_argument('--print_freq', type=int, default=100)
85 | parser.add_argument('--test_every', type=int, default=1)
86 | parser.add_argument('--save_epoch_freq', type=int, default=1)
87 | parser.add_argument('--calc_metrics', type=str2bool, default=True)
88 | parser.add_argument('--save_imgs', type=str2bool, default=True)
89 | parser.add_argument('--visual_full_imgs', type=str2bool, default=False)
90 |
91 | parser.add_argument('--n_scales', type=int, default=3, help='multi-scale deblurring level')
92 | parser.add_argument('--n_feats', type=int, default=64, help='number of feature maps')
93 | parser.add_argument('--rgb_range', type=int, default=1, help='RGB pixel value ranging from 0')
94 |
95 | self.initialized = True
96 | return parser
97 |
98 | def gather_options(self):
99 | """Initialize our parser with basic options(only once).
100 | Add additional model-specific and dataset-specific options.
101 | These options are difined in the function
102 | in model and dataset classes.
103 | """
104 | if not self.initialized: # check if it has been initialized
105 | parser = argparse.ArgumentParser(formatter_class=
106 | argparse.ArgumentDefaultsHelpFormatter)
107 | parser = self.initialize(parser)
108 |
109 | # get the basic options
110 | opt, _ = parser.parse_known_args()
111 |
112 | # modify model-related parser options
113 | model_name = opt.model
114 | model_option_setter = models.get_option_setter(model_name)
115 | parser = model_option_setter(parser, self.isTrain)
116 | opt, _ = parser.parse_known_args() # parse again with new defaults
117 |
118 | # save and return the parser
119 | self.parser = parser
120 | return parser.parse_args()
121 |
122 | def print_options(self, opt):
123 | """Print and save options
124 |
125 |         It will print both current options and default values (if different).
126 | It will save options into a text file / [checkpoints_dir] / opt.txt
127 | """
128 | message = ''
129 | message += '----------------- Options ---------------\n'
130 | for k, v in sorted(vars(opt).items()):
131 | comment = ''
132 | default = self.parser.get_default(k)
133 | if v != default:
134 | comment = '\t[default: %s]' % str(default)
135 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
136 | message += '----------------- End -------------------'
137 | print(message)
138 |
139 | # save to the disk
140 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
141 | util.mkdirs(expr_dir)
142 | file_name = os.path.join(expr_dir, 'opt_%s.txt'
143 | % ('train' if self.isTrain else 'test'))
144 | with open(file_name, 'wt') as opt_file:
145 | opt_file.write(message)
146 | opt_file.write('\n')
147 |
148 | def parse(self):
149 | opt = self.gather_options()
150 | opt.isTrain = self.isTrain # train or test
151 | opt.serial_batches = not opt.shuffle
152 |
153 | if self.isTrain and (opt.load_iter != [0] or opt.load_path != '') \
154 | and not opt.load_optimizers:
155 | util.prompt('You are loading a checkpoint and continuing training, '
156 | 'and no optimizer parameters are loaded. Please make '
157 | 'sure that the hyper parameters are correctly set.', 80)
158 | time.sleep(3)
159 |
160 | opt.model = opt.model.lower()
161 | opt.name = opt.name.lower()
162 |
163 |         scale_patch = {2: 96, 3: 144, 4: 192}
164 |         if opt.patch_size is None:
165 |             opt.patch_size = scale_patch[opt.scale]  # needs the (commented-out) --scale option
166 |
167 | if opt.name.startswith(opt.checkpoints_dir):
168 | opt.name = opt.name.replace(opt.checkpoints_dir+'/', '')
169 | if opt.name.endswith('/'):
170 | opt.name = opt.name[:-1]
171 |
172 | if len(opt.dataset_name) == 1:
173 | opt.dataset_name = opt.dataset_name[0]
174 |
175 | if len(opt.load_iter) == 1:
176 | opt.load_iter = opt.load_iter[0]
177 |
178 | # process opt.suffix
179 | if opt.suffix != '':
180 | suffix = ('_' + opt.suffix.format(**vars(opt)))
181 | opt.name = opt.name + suffix
182 |
183 | self.print_options(opt)
184 |
185 | # set gpu ids
186 | cuda_device_count = torch.cuda.device_count()
187 | if opt.gpu_ids == 'all':
188 | # GT 710 (3.5), GT 610 (2.1)
189 | gpu_ids = [i for i in range(cuda_device_count)]
190 | else:
191 | p = re.compile('[^-0-9]+')
192 | gpu_ids = [int(i) for i in re.split(p, opt.gpu_ids) if int(i) >= 0]
193 | opt.gpu_ids = [i for i in gpu_ids \
194 | if torch.cuda.get_device_capability(i) >= (4,0)]
195 |
196 | if len(opt.gpu_ids) == 0 and len(gpu_ids) > 0:
197 | opt.gpu_ids = gpu_ids
198 | util.prompt('You\'re using GPUs with computing capability < 4')
199 | elif len(opt.gpu_ids) != len(gpu_ids):
200 | util.prompt('GPUs(computing capability < 4) have been disabled')
201 |
202 | if len(opt.gpu_ids) > 0:
203 | assert torch.cuda.is_available(), 'No cuda available !!!'
204 | torch.cuda.set_device(opt.gpu_ids[0])
205 | print('The GPUs you are using:')
206 | for gpu_id in opt.gpu_ids:
207 | print(' %2d *%s* with capability %d.%d' % (
208 | gpu_id,
209 | torch.cuda.get_device_name(gpu_id),
210 | *torch.cuda.get_device_capability(gpu_id)))
211 | else:
212 | util.prompt('You are using CPU mode')
213 |
214 | self.opt = opt
215 | return self.opt
216 |
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 | def initialize(self, parser):
6 | parser = BaseOptions.initialize(self, parser)
7 | self.isTrain = False
8 | return parser
9 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TrainOptions(BaseOptions):
5 | def initialize(self, parser):
6 | parser = BaseOptions.initialize(self, parser)
7 | self.isTrain = True
8 | return parser
9 |
--------------------------------------------------------------------------------
/spynet/spynet_20210409-c6c1bd09.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CalvinYang0/CRNet/77ff4f79b54131ea696a3655777940c1073adb35/spynet/spynet_20210409-c6c1bd09.pth
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from options.test_options import TestOptions
4 | from data import create_dataset
5 | from models import create_model
6 | from util.visualizer import Visualizer
7 | from tqdm import tqdm
8 | from util.util import calc_psnr as calc_psnr
9 | import time
10 | import numpy as np
11 | from collections import OrderedDict as odict
12 | from copy import deepcopy
13 | from data.degrade.degrade_kernel import get_raw2rgb
14 | from data.degrade.process import gamma_compression
15 | from util.util import mu_tonemap, save_hdr
16 | import cv2
17 |
18 |
19 | if __name__ == '__main__':
20 | opt = TestOptions().parse()
21 |
22 | if not isinstance(opt.load_iter, list):
23 | load_iters = [opt.load_iter]
24 | else:
25 | load_iters = deepcopy(opt.load_iter)
26 |
27 | if not isinstance(opt.dataset_name, list):
28 | dataset_names = [opt.dataset_name]
29 | else:
30 | dataset_names = deepcopy(opt.dataset_name)
31 | datasets = odict()
32 | for dataset_name in dataset_names:
33 | dataset = create_dataset(dataset_name, 'test', opt)
34 | datasets[dataset_name] = tqdm(dataset)
35 |
36 | for load_iter in load_iters:
37 | opt.load_iter = load_iter
38 | model = create_model(opt)
39 | model.setup(opt)
40 | model.eval()
41 | print(torch.cuda.memory_allocated())
42 | for dataset_name in dataset_names:
43 | opt.dataset_name = dataset_name
44 | tqdm_val = datasets[dataset_name]
45 | dataset_test = tqdm_val.iterable
46 | dataset_size_test = len(dataset_test)
47 |
48 | print('='*80)
49 | print(dataset_name + ' dataset')
50 | tqdm_val.reset()
51 |
52 | psnr = [0.0] * dataset_size_test
53 |
54 | time_val = 0
55 |
56 | folder_dir = './ckpt/%s/output_vispng_%d' % (opt.name, load_iter)
57 | os.makedirs(folder_dir, exist_ok=True)
58 |
59 | for i, data in enumerate(tqdm_val):
60 | torch.cuda.empty_cache()
61 | model.set_input(data)
62 | torch.cuda.synchronize()
63 | time_val_start = time.time()
64 | model.test()
65 | torch.cuda.synchronize()
66 | time_val += time.time() - time_val_start
67 | res = model.get_current_visuals()
68 |
69 | if opt.save_imgs:
70 | save_dir_vispng = '%s/%s.png' % (folder_dir, data['fname'][0])
71 | raw_img = res['data_out'][0].permute(1, 2, 0) / 16
72 | img = get_raw2rgb(raw_img, data['meta'], demosaic='net', lineRGB=True)
73 | img = torch.clamp(mu_tonemap(img, mu=5e3)*65535, 0, 65535)
74 | img = img.cpu().numpy()[..., ::-1]
75 | cv2.imwrite(save_dir_vispng, img.astype(np.uint16))
76 |
77 |             avg_psnr = '%.2f' % np.mean(psnr)  # placeholder: psnr is never filled in this loop
78 | print(torch.cuda.max_memory_allocated())
79 | for dataset in datasets:
80 | datasets[dataset].close()
81 |
--------------------------------------------------------------------------------
/test_track1.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | echo "Start to test the model...."
3 | device="0"
4 |
5 |
6 | dataroot="" # including 'Train' and 'NTIRE_Val' floders
7 | name=""
8 |
9 | python test.py \
10 | --dataset_name bracketire --model cat --name $name --dataroot $dataroot \
11 | --load_iter 500 --save_imgs True --calc_metrics False --gpu_ids $device -j 8 --block Convnext
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #-*- encoding: UTF-8 -*-
2 | # import sys
3 | # reload(sys)
4 | # sys.setdefaultencoding("utf-8")
5 |
6 | import time
7 | import torch
8 | from options.train_options import TrainOptions
9 | from data import create_dataset
10 | from models import create_model
11 | from util.visualizer import Visualizer
12 | import numpy as np
13 | import sys
14 | import random
15 | def setup_seed(seed=0):
16 | torch.manual_seed(seed)
17 | torch.cuda.manual_seed_all(seed)
18 | np.random.seed(seed)
19 | random.seed(seed)
20 | torch.backends.cudnn.deterministic = True
21 |     torch.backends.cudnn.benchmark = True  # note: benchmark=True favors speed over strict determinism
22 |
23 | if __name__ == '__main__':
24 | setup_seed(seed=0)
25 |
26 | opt = TrainOptions().parse()
27 | dataset_train = create_dataset(opt.dataset_name, 'train', opt)
28 | dataset_size_train = len(dataset_train)
29 | print('The number of training images = %d' % dataset_size_train)
30 |
31 |
32 |
33 |
34 | model = create_model(opt)
35 | model.setup(opt)
36 | visualizer = Visualizer(opt)
37 | total_iters = ((model.start_epoch * (dataset_size_train // opt.batch_size)) \
38 | // opt.print_freq) * opt.print_freq
39 |     total_iters_start = total_iters
40 |
41 |
42 | for epoch in range(model.start_epoch + 1, opt.niter + opt.niter_decay + 1):
43 | # training
44 | epoch_start_time = time.time()
45 | epoch_iter = 0
46 | model.train()
47 |
48 | iter_data_time = iter_start_time = time.time()
49 | for i, data in enumerate(dataset_train):
50 | if total_iters % opt.print_freq == 0:
51 | t_data = time.time() - iter_data_time
52 | total_iters += 1
53 | epoch_iter += 1
54 | model.set_input(data)
55 | model.optimize_parameters(epoch)
56 |
57 | if total_iters % opt.print_freq == 0 or total_iters==total_iters_start:
58 | losses = model.get_current_losses()
59 | t_comp = (time.time() - iter_start_time)
60 | visualizer.print_current_losses(
61 | epoch, epoch_iter, losses, t_comp, t_data, total_iters)
62 | if opt.save_imgs: # Too many images
63 | visualizer.display_current_results(
64 | 'train', model.get_current_visuals(), total_iters)
65 | iter_start_time = time.time()
66 |
67 | iter_data_time = time.time()
68 |
69 | if epoch % opt.save_epoch_freq == 0:
70 | print('saving the model at the end of epoch %d, iters %d'
71 | % (epoch, total_iters))
72 | model.save_networks(epoch)
73 |
74 | print('End of epoch %d / %d \t Time Taken: %.3f sec'
75 | % (epoch, opt.niter + opt.niter_decay,
76 | time.time() - epoch_start_time))
77 | model.update_learning_rate(epoch)
78 |
79 |
80 |
81 | sys.stdout.flush()
82 |
--------------------------------------------------------------------------------
/train_track1.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | echo "Start to train the model...."
3 | dataroot="" # including 'Train' and 'NTIRE_Val' floders
4 |
5 | device=''
6 | name=""
7 |
8 | build_dir="./ckpt/"$name
9 |
10 | if [ ! -d "$build_dir" ]; then
11 | mkdir $build_dir
12 | fi
13 |
14 | LOG=./ckpt/$name/`date +%Y-%m-%d-%H-%M-%S`.txt
15 |
16 | python train.py \
17 | --dataset_name bracketire --model cat --name $name --lr_policy step \
18 | --patch_size 128 --niter 400 --save_imgs True --lr 1e-4 --dataroot $dataroot \
19 | --batch_size 36 --print_freq 500 --calc_metrics True --weight_decay 0.01 \
20 | --gpu_ids $device -j 8 --lr_decay_iters 27 --block Convnext --load_optimizers False | tee $LOG
21 |
22 |
23 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes a miscellaneous collection of helper functions."""
2 |
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | """This module contains simple helper functions """
2 | from __future__ import print_function
3 | import torch
4 | import numpy as np
5 | from PIL import Image
6 | import os
7 | import time
8 | from functools import wraps
9 | import torch
10 | import random
11 | import numpy as np
12 | import cv2
13 | import torch
14 | import colour_demosaicing
15 | import glob
16 | import torch
17 | import torch.nn.functional as F
18 | from torch.autograd import Variable
19 | import numpy as np
20 | from math import exp
21 |
22 |
23 | def mu_tonemap(hdr_image, mu=5000):
24 | if isinstance(hdr_image, np.ndarray):
25 | return np.log(1 + mu * hdr_image) / np.log(1 + mu)
26 | elif isinstance(hdr_image, torch.Tensor):
27 | mu = torch.tensor(mu).to(hdr_image.device)
28 | return torch.log(1 + mu * hdr_image) / torch.log(1 + mu)
29 | else:
30 |         raise TypeError('mu_tonemap expects a numpy array or a torch tensor')
31 |
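# Worked example: mu-law tonemapping compresses HDR radiance into [0, 1] via
# log(1 + mu*x) / log(1 + mu); 0 maps to 0 and 1 maps to 1 for any mu.
def _demo_mu_tonemap():
    hdr = np.array([0.0, 0.01, 1.0], dtype=np.float32)
    ldr = mu_tonemap(hdr, mu=5000)
    assert ldr[0] == 0.0 and abs(ldr[-1] - 1.0) < 1e-6
    return ldr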
32 |
33 | def radiance_writer(out_path, image):
34 | with open(out_path, "wb") as f:
35 | f.write(b"#?RADIANCE\n# Made with Python & Numpy\nFORMAT=32-bit_rle_rgbe\n\n")
36 | f.write(b"-Y %d +X %d\n" %(image.shape[0], image.shape[1]))
37 | brightest = np.maximum(np.maximum(image[...,0], image[...,1]), image[...,2]) + 1e-8
38 |         # np.frexp requires an integer output array for the exponent, so
39 |         # take the (mantissa, exponent) pair it returns instead of float buffers
40 |         mantissa, exponent = np.frexp(brightest)
41 | scaled_mantissa = mantissa * 255.0 / brightest
42 | rgbe = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
43 | rgbe[...,0:3] = np.around(image[...,0:3] * scaled_mantissa[...,None])
44 | rgbe[...,3] = np.around(exponent + 128)
45 |
46 | rgbe.flatten().tofile(f)
47 |
48 |
49 | def save_hdr(path, image):
50 | return radiance_writer(path, image)
51 |
52 |
53 | # Decorator: retry the wrapped function up to 600 times, sleeping 1 second between attempts.
54 | # It wraps func directly; the drawback is that func's own hints (signature/docstring) are hidden.
55 | def loop_until_success(func):
56 | @wraps(func)
57 | def wrapper(*args, **kwargs):
58 |         for i in range(600):
59 |             try:
60 |                 return func(*args, **kwargs)
61 |             except OSError:
62 |                 time.sleep(1)
63 |         # all retries failed: make a final attempt so the error propagates
64 |         return func(*args, **kwargs)
65 | return wrapper
66 |
67 | # Retry-wrapped versions of the print and torch.save functions
68 | @loop_until_success
69 | def loop_print(*args, **kwargs):
70 | print(*args, **kwargs)
71 |
72 | @loop_until_success
73 | def torch_save(*args, **kwargs):
74 | torch.save(*args, **kwargs)
75 |
76 | def calc_psnr(sr, hr, range=255.):
77 | # shave = 2
78 | with torch.no_grad():
79 | diff = (sr - hr) / hr.max()
80 | # diff = diff[:, :, shave:-shave, shave:-shave]
81 | mse = torch.pow(diff, 2).mean()
82 | return (-10 * torch.log10(mse)).item()
83 |
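# Note: calc_psnr normalizes the error by hr.max() rather than by the `range`
# argument, which is currently unused. A quick sketch with illustrative tensors:
def _demo_calc_psnr():
    hr = torch.rand(1, 3, 16, 16) + 0.5
    sr = hr + 0.01 * torch.randn_like(hr)
    return calc_psnr(sr, hr)           # PSNR (dB) of the max-normalized error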
84 | def diagnose_network(net, name='network'):
85 | """Calculate and print the mean of average absolute(gradients)
86 |
87 | Parameters:
88 | net (torch network) -- Torch network
89 | name (str) -- the name of the network
90 | """
91 | mean = 0.0
92 | count = 0
93 | for param in net.parameters():
94 | if param.grad is not None:
95 | mean += torch.mean(torch.abs(param.grad.data))
96 | count += 1
97 | if count > 0:
98 | mean = mean / count
99 | print(name)
100 | print(mean)
101 |
102 | def print_numpy(x, val=True, shp=True):
103 | """Print the mean, min, max, median, std, and size of a numpy array
104 |
105 | Parameters:
106 | val (bool) -- if print the values of the numpy array
107 | shp (bool) -- if print the shape of the numpy array
108 | """
109 | x = x.astype(np.float64)
110 | if shp:
111 | print('shape,', x.shape)
112 | if val:
113 | x = x.flatten()
114 | print('mean = %3.3f, min = %3.3f, max = %3.3f, mid = %3.3f, std=%3.3f'
115 | % (np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
116 |
117 | def mkdirs(paths):
118 | """create empty directories if they don't exist
119 |
120 | Parameters:
121 | paths (str list) -- a list of directory paths
122 | """
123 | if isinstance(paths, list) and not isinstance(paths, str):
124 | for path in paths:
125 | mkdir(path)
126 | else:
127 | mkdir(paths)
128 |
129 | def mkdir(path):
130 | """create a single empty directory if it didn't exist
131 |
132 | Parameters:
133 | path (str) -- a single directory path
134 | """
135 | if not os.path.exists(path):
136 | os.makedirs(path)
137 |
138 | def prompt(s, width=66):
139 | print('='*(width+4))
140 | ss = s.split('\n')
141 | if len(ss) == 1 and len(s) <= width:
142 | print('= ' + s.center(width) + ' =')
143 | else:
144 | for s in ss:
145 | for i in split_str(s, width):
146 | print('= ' + i.ljust(width) + ' =')
147 | print('='*(width+4))
148 |
149 | def split_str(s, width):
150 | ss = []
151 | while len(s) > width:
152 | idx = s.rfind(' ', 0, width+1)
153 | if idx > width >> 1:
154 | ss.append(s[:idx])
155 | s = s[idx+1:]
156 | else:
157 | ss.append(s[:width])
158 | s = s[width:]
159 | if s.strip() != '':
160 | ss.append(s)
161 | return ss
162 |
163 | # def augment_func(img, hflip, vflip, rot90): # CxHxW
164 | # if hflip: img = img[:, :, ::-1]
165 | # if vflip: img = img[:, ::-1, :]
166 | # if rot90: img = img.transpose(0, 2, 1)
167 | # return np.ascontiguousarray(img)
168 |
169 | # def augment(*imgs): # CxHxW
170 | # hflip = random.random() < 0.5
171 | # vflip = random.random() < 0.5
172 | # rot90 = random.random() < 0.5
173 | # return (augment_func(img, hflip, vflip, rot90) for img in imgs)
174 |
175 | def remove_black_level(img, black_lv=63, white_lv=4*255):
176 | img = np.maximum(img.astype(np.float32)-black_lv, 0) / (white_lv-black_lv)
177 | return img
178 |
179 | def gamma_correction(img, r=1/2.2):
180 | img = np.maximum(img, 0)
181 | img = np.power(img, r)
182 | return img
183 |
184 | def extract_bayer_channels(raw): # HxW
185 | ch_R = raw[0::2, 0::2]
186 | ch_Gb = raw[0::2, 1::2]
187 | ch_Gr = raw[1::2, 0::2]
188 | ch_B = raw[1::2, 1::2]
189 | raw_combined = np.dstack((ch_B, ch_Gb, ch_R, ch_Gr))
190 | raw_combined = np.ascontiguousarray(raw_combined.transpose((2, 0, 1)))
191 | return raw_combined # 4xHxW
192 |
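# Layout sketch: the four half-resolution planes are sampled at the RGGB
# offsets above and stacked in (B, Gb, R, Gr) order.
def _demo_extract_bayer_channels():
    raw = np.random.rand(64, 64).astype(np.float32)  # HxW mosaic
    packed = extract_bayer_channels(raw)             # -> 4 x 32 x 32
    assert packed.shape == (4, 32, 32)
    return packed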
193 | def get_raw_demosaic(raw, pattern='RGGB'): # HxW
194 | raw_demosaic = colour_demosaicing.demosaicing_CFA_Bayer_bilinear(raw, pattern=pattern)
195 | raw_demosaic = np.ascontiguousarray(raw_demosaic.astype(np.float32).transpose((2, 0, 1)))
196 | return raw_demosaic # 3xHxW
197 |
198 | def get_coord(H, W, x=448/3968, y=448/2976):
199 | x_coord = np.linspace(-x + (x / W), x - (x / W), W)
200 | x_coord = np.expand_dims(x_coord, axis=0)
201 | x_coord = np.tile(x_coord, (H, 1))
202 | x_coord = np.expand_dims(x_coord, axis=0)
203 |
204 | y_coord = np.linspace(-y + (y / H), y - (y / H), H)
205 | y_coord = np.expand_dims(y_coord, axis=1)
206 | y_coord = np.tile(y_coord, (1, W))
207 | y_coord = np.expand_dims(y_coord, axis=0)
208 |
209 | coord = np.ascontiguousarray(np.concatenate([x_coord, y_coord]))
210 | coord = np.float32(coord)
211 |
212 | return coord
213 |
214 | def read_wb(txtfile, key):
215 | wb = np.zeros((1,4))
216 | with open(txtfile) as f:
217 | for l in f:
218 | if key in l:
219 | for i in range(wb.shape[0]):
220 | nextline = next(f)
221 | try:
222 | wb[i,:] = nextline.split()
223 | except:
224 | print("WB error XXXXXXX")
225 | print(txtfile)
226 | wb = wb.astype(np.float32)
227 | return wb
228 |
229 |
230 | def gaussian(window_size, sigma):
231 | gauss = torch.Tensor([exp(-(x - window_size/2)**2/float(2*sigma**2)) for x in range(window_size)])
232 | return gauss/gauss.sum()
233 |
234 | def create_window(window_size, channel, device):
235 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1).to(device)
236 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
237 | window = _2D_window.expand(channel, 1, window_size, window_size)
238 | return window
239 |
240 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
241 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
242 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
243 |
244 | mu1_sq = mu1.pow(2)
245 | mu2_sq = mu2.pow(2)
246 | mu1_mu2 = mu1*mu2
247 |
248 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
249 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
250 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
251 |
252 | C1 = 0.01**2
253 | C2 = 0.03**2
254 |
255 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
256 |
257 | if size_average:
258 | return ssim_map.mean()
259 | else:
260 | return ssim_map.mean(1).mean(1).mean(1)
261 |
262 | class SSIM(torch.nn.Module):
263 | def __init__(self, window_size = 11, size_average = True):
264 | super(SSIM, self).__init__()
265 | self.window_size = window_size
266 | self.size_average = size_average
267 | self.channel = 1
268 | # self.window = create_window(window_size, self.channel)
269 |
270 | def forward(self, img1, img2):
271 | (_, channel, _, _) = img1.size()
272 |
273 | window = create_window(self.window_size, channel, img1.device)
274 | self.window = window
275 | self.channel = channel
276 |
277 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
278 |
279 | def ssim(img1, img2, window_size = 11, size_average = True):
280 | (_, channel, _, _) = img1.size()
281 |
282 | window = create_window(window_size, channel, img1.device)
283 | return _ssim(img1, img2, window, window_size, channel, size_average)
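# Usage sketch: both the SSIM module and the functional form take N x C x H x W
# tensors in the same value range; identical inputs give an SSIM of 1.
def _demo_ssim():
    img = torch.rand(1, 3, 32, 32)
    assert abs(ssim(img, img).item() - 1.0) < 1e-4
    return SSIM()(img, img + 0.05 * torch.randn_like(img))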
--------------------------------------------------------------------------------
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from os.path import join
3 | from tensorboardX import SummaryWriter
4 | from matplotlib import pyplot as plt
5 | from io import BytesIO
6 | from PIL import Image
7 | from functools import partial
8 | from functools import wraps
9 | import time
10 |
11 | def write_until_success(func):
12 | @wraps(func)
13 | def wrapper(*args, **kwargs):
14 |         for i in range(30):
15 |             try:
16 |                 return func(*args, **kwargs)
17 |             except OSError:
18 |                 print('%s OSError' % str(args))
19 |                 time.sleep(1)
20 |         # all retries failed: surface the error on a final attempt
21 |         return func(*args, **kwargs)
22 | return wrapper
23 |
24 | class Visualizer():
25 | def __init__(self, opt):
26 | self.opt = opt
27 | if opt.isTrain:
28 | self.name = opt.name
29 | self.save_dir = join(opt.checkpoints_dir, opt.name, 'log')
30 | self.writer = SummaryWriter(logdir=join(self.save_dir))
31 | else:
32 | self.name = '%s_%s_%d' % (
33 | opt.name, opt.dataset_name, opt.load_iter)
34 | self.save_dir = join(opt.checkpoints_dir, opt.name)
35 | if opt.save_imgs:
36 | self.writer = SummaryWriter(logdir=join(
37 | self.save_dir, 'ckpts', self.name))
38 |
39 | @write_until_success
40 | def display_current_results(self, phase, visuals, iters):
41 | for k, v in visuals.items():
42 | v = v.cpu()
43 | self.writer.add_image('%s/%s'%(phase, k), v[0]/255, iters)
44 | self.writer.flush()
45 |
46 | @write_until_success
47 | def print_current_losses(self, epoch, iters, losses,
48 | t_comp, t_data, total_iters):
49 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' \
50 | % (epoch, iters, t_comp, t_data)
51 | for k, v in losses.items():
52 | message += '%s: %.4e ' % (k, v)
53 | self.writer.add_scalar('loss/%s'%k, v, total_iters)
54 | print(message)
55 |
56 | @write_until_success
57 | def print_psnr(self, epoch, total_epoch, time_val, mean_psnr):
58 | self.writer.add_scalar('val/psnr', mean_psnr, epoch)
59 | print('End of epoch %d / %d (Val) \t Time Taken: %.3f s \t PSNR: %f'
60 | % (epoch, total_epoch, time_val, mean_psnr))
61 |
62 |
63 |
--------------------------------------------------------------------------------