├── .gitignore
├── README.md
├── model
│   ├── metrics.py
│   ├── dataset.py
│   └── bokeh.py
├── synth_new.py
├── testmodel.py
└── LICENSE

/.gitignore:
--------------------------------------------------------------------------------
1 | 
2 | model/.DS_Store
3 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BokehOrNot
2 | ## Transforming Bokeh Effect with Image Transformer and Lens Metadata Embedding
3 | 
4 | Official PyTorch implementation of BokehOrNot from the NTIRE 2023 Challenge.
5 | 
6 | 
7 | 
8 | Model download link: [model.pt](https://1drv.ms/u/s!AhLc1l9ln_UugpYSbRBcihTHhGZXEA?e=cc3Vas)
9 | 
10 | synth_new.py: Synthesizes images with the transformed background bokeh effect.
11 | 
12 | testmodel.py: Computes overall metrics (LPIPS, PSNR, SSIM) on a validation set.
--------------------------------------------------------------------------------
/model/metrics.py:
--------------------------------------------------------------------------------
1 | import lpips
2 | import numpy as np
3 | import torch
4 | from skimage.metrics import structural_similarity as ssim
5 | 
6 | lpips_alex = lpips.LPIPS(net="alex").cuda()
7 | 
8 | 
9 | def calculate_lpips(img0: torch.Tensor, img1: torch.Tensor):
10 |     # NOTE: LPIPS expects image normalized to [-1, 1]
11 |     img0 = 2 * img0 - 1.0
12 |     img1 = 2 * img1 - 1.0
13 | 
14 |     with torch.no_grad():
15 |         distance = lpips_alex(img0, img1)
16 |     return distance.item()
17 | 
18 | 
19 | def calculate_psnr(img0, img1):
20 |     mse = np.mean((img0.astype(np.float32) - img1.astype(np.float32)) ** 2)
21 |     if mse == 0:
22 |         return float("inf")
23 |     max_val = np.max(img0)
24 |     return 20 * np.log10(max_val / np.sqrt(mse))
25 | 
26 | 
27 | def calculate_ssim(img0, img1):
28 |     val = ssim(img0, img1, channel_axis=2)
29 |     return val
--------------------------------------------------------------------------------
/synth_new.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | import numpy as np
3 | import PIL.Image as Image
4 | import torch
5 | from torchvision import transforms, datasets
6 | from torchvision.transforms import ToPILImage, ToTensor, PILToTensor
7 | from model.bokeh import bokeh
8 | from PIL import Image
9 | from tqdm import tqdm
10 | import PIL.ImageChops as chops
11 | import time
12 | 
13 | device = torch.device("cuda")
14 | 
15 | bokehnet = bokeh().to(device)
16 | to_pil = ToPILImage()
17 | 
18 | PATH = "type model path here"
19 | val_path = 'folder path with images you want to run inference on'
20 | 
21 | bokehnet = torch.load(PATH, map_location=device)
22 | bokehnet.eval()
23 | print("weights loaded!!")
24 | 
25 | from os import walk
26 | pic_list = []
27 | 
28 | for f, _, i in walk(val_path):
29 |     for j in i:
30 |         if 'txt' in j or 'tgt' in j or 'alpha' in j:
31 |             continue
32 |         else:
33 |             pic_list.append(j)
34 | meta = {}
35 | with open(val_path + "/meta.txt", "r") as f:
36 |     lines = f.readlines()
37 | 
38 | for line in lines:
39 |     id, src_lens, tgt_lens, disparity = [part.strip() for part in line.split(",")]
40 |     meta[id] = (src_lens, tgt_lens, disparity)
41 | 
42 | trans = transforms.Compose([transforms.ToTensor()])
43 | transform = transforms.Compose([
44 |     transforms.PILToTensor()])
45 | 
46 | timelapse1 = []
47 | timelapse2 = []
48 | 
49 | for picture in pic_list:
50 |     T1 = time.perf_counter()
51 |     pic4inference = Image.open(val_path + "/" + str(picture))
52 |     pic4inference_src = Image.open(val_path + "/" + str(picture))
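    # NOTE (illustrative, not used by this script): the if/elif chains below encode each
    # lens name from meta.txt as a one-element list holding a signed aperture value
    # (Sony lenses -> positive f-numbers, Canon lenses -> negative ones). Assuming that
    # encoding, a hypothetical LENS_APERTURE lookup table would express the same mapping
    # more compactly:
    #
    #     LENS_APERTURE = {
    #         "Sony50mmf1.8BS": 1.8,
    #         "Sony50mmf16.0BS": 16.0,
    #         "Sony50mmf1.4BS": 1.4,
    #         "Canon50mmf1.8BS": -1.8,
    #         "Canon50mmf1.4BS": -1.4,
    #     }
    #     src_lens = torch.FloatTensor([LENS_APERTURE[src_lens]]).to(device)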
53 |     pic4inference = trans(pic4inference)
54 |     id = picture.split(".")[0]
55 |     src_lens, tgt_lens, disparity = meta[id]
56 | 
57 |     if src_lens == "Sony50mmf1.8BS":
58 |         src_lens = [1.8]
59 |     elif src_lens == "Sony50mmf16.0BS":
60 |         src_lens = [16.0]
61 |     elif src_lens == "Canon50mmf1.8BS":
62 |         src_lens = [-1.8]
63 |     elif src_lens == "Canon50mmf1.4BS":
64 |         src_lens = [-1.4]
65 |     elif src_lens == "Sony50mmf1.4BS":
66 |         src_lens = [1.4]
67 |     src_lens = torch.FloatTensor(src_lens).to(device)
68 | 
69 |     if tgt_lens == "Sony50mmf1.8BS":
70 |         tgt_lens = [1.8]
71 |     elif tgt_lens == "Sony50mmf16.0BS":
72 |         tgt_lens = [16.0]
73 |     elif tgt_lens == "Canon50mmf1.8BS":
74 |         tgt_lens = [-1.8]
75 |     elif tgt_lens == "Canon50mmf1.4BS":
76 |         tgt_lens = [-1.4]
77 |     elif tgt_lens == "Sony50mmf1.4BS":
78 |         tgt_lens = [1.4]
79 |     tgt_lens = torch.FloatTensor(tgt_lens).to(device)
80 | 
81 |     disparity = [int(disparity)]
82 |     disparity = torch.FloatTensor(disparity).to(device)
83 | 
84 |     pic4inference = pic4inference.to(device)
85 | 
86 |     with torch.no_grad():
87 |         T2 = time.perf_counter()
88 |         bok_pred = bokehnet(pic4inference, src_lens, tgt_lens, disparity)
89 |         T3 = time.perf_counter()
90 | 
91 |     bok_pred = bok_pred.detach().cpu().squeeze(0)
92 | 
93 |     result = bok_pred
94 |     T4 = time.perf_counter()
95 | 
96 |     result = to_pil(result.clip(0, 1))
97 | 
98 |     result.save('type output path here' + str(picture.split('.')[0]) + '.src.png')
99 |     timelapse1.append(T4 - T1)
100 |     timelapse2.append(T3 - T2)
101 | 
102 | print("Avg Synth Time:", sum(timelapse1) / len(timelapse1))
103 | print("Avg Inference Time:", sum(timelapse2) / len(timelapse2))
104 | 
--------------------------------------------------------------------------------
/model/dataset.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os.path as osp
3 | from typing import Callable, Optional
4 | 
5 | from PIL import Image
6 | from torch.utils.data import Dataset
7 | import random
8 | 
9 | class BokehDataset(Dataset):
10 |     def __init__(self, root_folder: str, transform: Optional[Callable] = None):
11 |         self._root_folder = root_folder
12 |         self._transform = transform
13 | 
14 |         self._source_paths = sorted(glob.glob(osp.join(root_folder, "*.src.jpg")))
15 |         self._target_paths = sorted(glob.glob(osp.join(root_folder, "*.tgt.jpg")))
16 |         self._alpha_paths = sorted(glob.glob(osp.join(root_folder, "*.alpha.png")))
17 | 
18 |         self._meta_data = self._read_meta_data(osp.join(root_folder, "meta.txt"))
19 | 
20 |         file_counts = [
21 |             len(self._source_paths),
22 |             len(self._target_paths),
23 |             len(self._alpha_paths),
24 |             len(self._meta_data),
25 |         ]
26 |         if not file_counts[0] or len(set(file_counts)) != 1:
27 |             raise ValueError(
28 |                 f"Empty or non-matching number of files in root dir: {file_counts}. "
29 |                 "Expected an equal number of source, target, source-alpha and target-alpha files. "
30 |                 "Also expecting matching meta file entries."
31 |             )
32 | 
33 |     def __len__(self):
34 |         return len(self._source_paths)
35 | 
36 |     def _read_meta_data(self, meta_file_path: str):
37 |         """Read the meta file containing source / target lens and disparity for each image.
38 | 
39 |         Args:
40 |             meta_file_path (str): File path
41 | 
42 |         Raises:
43 |             ValueError: File not found.
44 | 
45 |         Returns:
46 |             dict: Meta dict of tuples like {id: (src_lens, tgt_lens, disparity)}.
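
        Example:
            An illustrative meta.txt line and the entry parsed from it (the id and
            disparity values are placeholders; the lens names match the identifiers
            handled in synth_new.py and testmodel.py):

                42, Sony50mmf1.8BS, Sony50mmf16.0BS, 50
                -> meta["42"] == ("Sony50mmf1.8BS", "Sony50mmf16.0BS", "50")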
47 | """ 48 | if not osp.isfile(meta_file_path): 49 | raise ValueError(f"Meta file missing under {meta_file_path}.") 50 | 51 | meta = {} 52 | with open(meta_file_path, "r") as f: 53 | lines = f.readlines() 54 | 55 | for line in lines: 56 | id, src_lens, tgt_lens, disparity = [part.strip() for part in line.split(",")] 57 | meta[id] = (src_lens, tgt_lens, disparity) 58 | 59 | return meta 60 | 61 | def __getitem__(self, index): 62 | source = Image.open(self._source_paths[index]) 63 | target = Image.open(self._target_paths[index]) 64 | alpha = Image.open(self._alpha_paths[index]) 65 | LR_size = 384 66 | rnd_h = random.randint(0, max(0, 1440 - LR_size)) 67 | rnd_w = random.randint(0, max(0, 1920 - LR_size)) 68 | (left, upper, right, lower) = (rnd_w, rnd_h, rnd_w + LR_size, rnd_h + LR_size) 69 | source = source.crop((left, upper, right, lower)) 70 | target = target.crop((left, upper, right, lower)) 71 | alpha = alpha.crop((left, upper, right, lower)) 72 | 73 | filename = osp.basename(self._source_paths[index]) 74 | id = filename.split(".")[0] 75 | src_lens, tgt_lens, disparity = self._meta_data[id] 76 | 77 | if self._transform: 78 | source = self._transform(source) 79 | target = self._transform(target) 80 | alpha = self._transform(alpha) 81 | #src_lens = self._transform(src_lens) 82 | #tgt_lens = self._transform(tgt_lens) 83 | #disparity = self._transform(disparity) 84 | 85 | return { 86 | "source": source, 87 | "target": target, 88 | "alpha": alpha, 89 | "src_lens": src_lens, 90 | "tgt_lens": tgt_lens, 91 | "disparity": disparity, 92 | } 93 | -------------------------------------------------------------------------------- /testmodel.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn as nn 6 | from torch.nn import functional as F 7 | from torch.optim import Adam 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms 10 | from torchvision.transforms import ToPILImage, ToTensor 11 | 12 | from model.dataset_forTesting import BokehDataset 13 | from model.metrics import calculate_lpips, calculate_psnr, calculate_ssim 14 | from model.bokeh import bokeh 15 | import datetime 16 | from PIL import Image 17 | 18 | to_tensor = ToTensor() 19 | to_pil = ToPILImage() 20 | lpips_list = [] 21 | ssim_list = [] 22 | psnr_list = [] 23 | 24 | PATH = " type model path here " 25 | val_path = 'type validation set path here' 26 | 27 | device = torch.device("cuda:0") 28 | bokehnet = torch.load(PATH,map_location=device).module 29 | 30 | 31 | from os import walk 32 | pic_list = [] 33 | 34 | for f, _, i in walk(val_path): 35 | for j in i: 36 | if 'txt' in j or 'png' in j or 'tgt' in j: 37 | continue 38 | else: 39 | pic_list.append(j) 40 | meta = {} 41 | with open(val_path+"/meta.txt", "r") as f: 42 | lines = f.readlines() 43 | 44 | for line in lines: 45 | id, src_lens, tgt_lens, disparity = [part.strip() for part in line.split(",")] 46 | meta[id] = (src_lens, tgt_lens, disparity) 47 | 48 | trans = ToTensor() 49 | 50 | for picture in pic_list: 51 | id = picture.split(".")[0] 52 | pic4inference = Image.open(val_path+"/"+str(picture)) 53 | target = Image.open(val_path+"/"+str(id)+'.tgt.jpg') 54 | pic4inference = trans(pic4inference) 55 | target = trans(target) 56 | 57 | src_lens, tgt_lens, disparity = meta[id] 58 | 59 | if src_lens == "Sony50mmf1.8BS": 60 | src_lens = [1.8] 61 | elif src_lens == "Sony50mmf16.0BS": 62 | src_lens = [16.0] 63 | elif src_lens == "Canon50mmf1.8BS": 64 | 
65 |     elif src_lens == "Canon50mmf1.4BS":
66 |         src_lens = [-1.4]
67 |     elif src_lens == "Sony50mmf1.4BS":
68 |         src_lens = [1.4]
69 |     src_lens = torch.FloatTensor(src_lens).to(device)
70 | 
71 |     if tgt_lens == "Sony50mmf1.8BS":
72 |         tgt_lens = [1.8]
73 |     elif tgt_lens == "Sony50mmf16.0BS":
74 |         tgt_lens = [16.0]
75 |     elif tgt_lens == "Canon50mmf1.8BS":
76 |         tgt_lens = [-1.8]
77 |     elif tgt_lens == "Canon50mmf1.4BS":
78 |         tgt_lens = [-1.4]
79 |     elif tgt_lens == "Sony50mmf1.4BS":
80 |         tgt_lens = [1.4]
81 |     tgt_lens = torch.FloatTensor(tgt_lens).to(device)
82 | 
83 |     disparity = [int(disparity)]
84 |     disparity = torch.FloatTensor(disparity).to(device)
85 | 
86 |     #print(pic4inference.shape)
87 |     pic4inference = pic4inference.to(device)
88 |     target = target.to(device)
89 | 
90 |     #pic4inference = img_transform(pic4inference).unsqueeze(0)
91 |     pic4inference = pic4inference.unsqueeze(0)
92 |     with torch.no_grad():
93 |         output = bokehnet(pic4inference, src_lens, tgt_lens, disparity)
94 |     target = target.unsqueeze(0)
95 | 
96 |     # Calculate metrics
97 |     lpips = np.mean([calculate_lpips(img0, img1) for img0, img1 in zip(output, target)])
98 |     psnr = np.mean(
99 |         [calculate_psnr(np.asarray(to_pil(img0)), np.asarray(to_pil(img1))) for img0, img1 in zip(output, target)]
100 |     )
101 |     ssim = np.mean(
102 |         [calculate_ssim(np.asarray(to_pil(img0)), np.asarray(to_pil(img1))) for img0, img1 in zip(output, target)]
103 |     )
104 |     lpips_list.append(lpips)
105 |     psnr_list.append(psnr)
106 |     ssim_list.append(ssim)
107 |     print(id, f"Metrics: lpips={lpips:0.03f}, psnr={psnr:0.03f}, ssim={ssim:0.03f}")
108 | 
109 | lpips = sum(lpips_list) / len(lpips_list)
110 | psnr = sum(psnr_list) / len(psnr_list)
111 | ssim = sum(ssim_list) / len(ssim_list)
112 | print("Result:", f"Metrics: lpips={lpips:0.03f}, psnr={psnr:0.03f}, ssim={ssim:0.03f}")
113 | 
114 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |                                  Apache License
2 |                            Version 2.0, January 2004
3 |                         http://www.apache.org/licenses/
4 | 
5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 |    1. Definitions.
8 | 
9 |       "License" shall mean the terms and conditions for use, reproduction,
10 |       and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 |       "Licensor" shall mean the copyright owner or entity authorized by
13 |       the copyright owner that is granting the License.
14 | 
15 |       "Legal Entity" shall mean the union of the acting entity and all
16 |       other entities that control, are controlled by, or are under common
17 |       control with that entity. For the purposes of this definition,
18 |       "control" means (i) the power, direct or indirect, to cause the
19 |       direction or management of such entity, whether by contract or
20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 |       outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 |       "You" (or "Your") shall mean an individual or Legal Entity
24 |       exercising permissions granted by this License.
25 | 
26 |       "Source" form shall mean the preferred form for making modifications,
27 |       including but not limited to software source code, documentation
28 |       source, and configuration files.
29 | 
30 |       "Object" form shall mean any form resulting from mechanical
31 |       transformation or translation of a Source form, including but
32 |       not limited to compiled object code, generated documentation,
33 |       and conversions to other media types.
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /model/bokeh.py: -------------------------------------------------------------------------------- 1 | ## Restormer: Efficient Transformer for High-Resolution Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, and Ming-Hsuan Yang 3 | ## https://arxiv.org/abs/2111.09881 4 | 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from pdb import set_trace as stx 10 | import numbers 11 | import math 12 | from einops import rearrange 13 | 14 | #device = torch.device("cuda") 15 | 16 | ########################################################################## 17 | ## Layer Norm 18 | 19 | def to_3d(x): 20 | if len(x.shape) == 3: 21 | x = x[None, ...] 22 | return rearrange(x, 'b c h w -> b (h w) c') 23 | 24 | def to_4d(x,h,w): 25 | return rearrange(x, 'b (h w) c -> b c h w',h=h,w=w) 26 | 27 | class BiasFree_LayerNorm(nn.Module): 28 | def __init__(self, normalized_shape): 29 | super(BiasFree_LayerNorm, self).__init__() 30 | if isinstance(normalized_shape, numbers.Integral): 31 | normalized_shape = (normalized_shape,) 32 | normalized_shape = torch.Size(normalized_shape) 33 | 34 | assert len(normalized_shape) == 1 35 | 36 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 37 | self.normalized_shape = normalized_shape 38 | 39 | def forward(self, x): 40 | sigma = x.var(-1, keepdim=True, unbiased=False) 41 | return x / torch.sqrt(sigma+1e-5) * self.weight 42 | 43 | class WithBias_LayerNorm(nn.Module): 44 | def __init__(self, normalized_shape): 45 | super(WithBias_LayerNorm, self).__init__() 46 | if isinstance(normalized_shape, numbers.Integral): 47 | normalized_shape = (normalized_shape,) 48 | normalized_shape = torch.Size(normalized_shape) 49 | 50 | assert len(normalized_shape) == 1 51 | 52 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 53 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 54 | self.normalized_shape = normalized_shape 55 | 56 | def forward(self, x): 57 | mu = x.mean(-1, keepdim=True) 58 | sigma = x.var(-1, keepdim=True, unbiased=False) 59 | return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias 60 | 61 | 62 | class LayerNorm(nn.Module): 63 | def __init__(self, dim, LayerNorm_type): 64 | super(LayerNorm, self).__init__() 65 | if LayerNorm_type =='BiasFree': 66 | self.body = BiasFree_LayerNorm(dim) 67 | else: 68 | self.body = WithBias_LayerNorm(dim) 69 | 70 | def forward(self, x): 71 | h, w = x.shape[-2:] 72 | return to_4d(self.body(to_3d(x)), h, w) 73 | 74 | 75 | 76 | ########################################################################## 77 | ## Gated-Dconv Feed-Forward Network (GDFN) 78 | class FeedForward(nn.Module): 79 | def __init__(self, dim, ffn_expansion_factor, bias): 80 | super(FeedForward, self).__init__() 81 | 82 | hidden_features = int(dim*ffn_expansion_factor) 83 | 84 | self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) 85 | 86 | self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias) 87 | 88 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) 89 | 90 | def forward(self, x): 91 | x = self.project_in(x) 92 | x1, x2 = self.dwconv(x).chunk(2, dim=1) 93 | x = F.gelu(x1) * x2 94 | x = self.project_out(x) 95 | return x 96 | 97 | 98 | 99 | ########################################################################## 100 | ## Multi-DConv Head 
Transposed Self-Attention (MDTA) 101 | class Attention(nn.Module): 102 | def __init__(self, dim, num_heads, bias): 103 | super(Attention, self).__init__() 104 | self.num_heads = num_heads 105 | self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) 106 | 107 | self.qkv = nn.Conv2d(dim, dim*3, kernel_size=1, bias=bias) 108 | self.qkv_dwconv = nn.Conv2d(dim*3, dim*3, kernel_size=3, stride=1, padding=1, groups=dim*3, bias=bias) 109 | self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias) 110 | 111 | 112 | 113 | def forward(self, x): 114 | b,c,h,w = x.shape 115 | 116 | qkv = self.qkv_dwconv(self.qkv(x)) 117 | q,k,v = qkv.chunk(3, dim=1) 118 | 119 | q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 120 | k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 121 | v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads) 122 | 123 | q = torch.nn.functional.normalize(q, dim=-1) 124 | k = torch.nn.functional.normalize(k, dim=-1) 125 | 126 | attn = (q @ k.transpose(-2, -1)) * self.temperature 127 | attn = attn.softmax(dim=-1) 128 | 129 | out = (attn @ v) 130 | 131 | out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w) 132 | 133 | out = self.project_out(out) 134 | return out 135 | 136 | 137 | 138 | ########################################################################## 139 | class TransformerBlock(nn.Module): 140 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): 141 | super(TransformerBlock, self).__init__() 142 | 143 | self.mlp = nn.Sequential(nn.Linear(dim, 3*dim), nn.ReLU(), nn.Linear(3*dim, 2*dim)) 144 | self.norm1 = LayerNorm(dim, LayerNorm_type) 145 | self.attn = Attention(dim, num_heads, bias) 146 | self.norm2 = LayerNorm(dim, LayerNorm_type) 147 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias) 148 | 149 | def forward(self, input): 150 | 151 | x, embed = input[0], input[1] 152 | #print(x.shape,embed.shape) 153 | em = self.mlp(embed) 154 | em = em[...,None,None] 155 | #print(em.shape) 156 | scale, shift = torch.chunk(em, 2,dim=1) 157 | #print(scale) 158 | x = x + self.attn(self.norm1(x)) 159 | #print(x.shape,scale.shape,shift.shape) 160 | #print(x.shape,scale.shape) 161 | x = x * (scale + 1) + shift 162 | 163 | x = x + self.ffn(self.norm2(x)) 164 | 165 | return [x,embed] 166 | 167 | class TransformerBlock2(nn.Module): 168 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): 169 | super(TransformerBlock2, self).__init__() 170 | 171 | self.mlp = nn.Sequential(nn.Linear(dim, 3*dim), nn.ReLU(), nn.Linear(3*dim, 512)) 172 | self.norm1 = LayerNorm(dim, LayerNorm_type) 173 | self.attn = Attention(dim, num_heads, bias) 174 | self.norm2 = LayerNorm(dim, LayerNorm_type) 175 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias) 176 | 177 | def forward(self, x): 178 | 179 | 180 | x = x + self.attn(self.norm1(x)) 181 | 182 | x = x + self.ffn(self.norm2(x)) 183 | 184 | return x 185 | 186 | class TransformerBlock3(nn.Module): 187 | def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type): 188 | super(TransformerBlock3, self).__init__() 189 | 190 | self.mlp = nn.Sequential(nn.Linear(144, 144), nn.ReLU(), nn.Linear(144, 512)) 191 | self.norm1 = LayerNorm(dim, LayerNorm_type) 192 | self.attn = Attention(dim, num_heads, bias) 193 | self.norm2 = LayerNorm(dim, LayerNorm_type) 194 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias) 195 | 196 | def forward(self, input): 197 | 198 | x, embed = input[0], 
input[1] 199 | #print(x.shape,embed.shape) 200 | em = self.mlp(embed) 201 | em = em[...,None,None] 202 | #print(em.shape) 203 | scale, shift = torch.chunk(em, 2,dim=1) 204 | #print(scale) 205 | x = x + self.attn(self.norm1(x)) 206 | #print(x.shape,scale.shape,shift.shape) 207 | x = x * scale + shift 208 | 209 | x = x + self.ffn(self.norm2(x)) 210 | 211 | return [x,embed] 212 | 213 | 214 | 215 | ########################################################################## 216 | ## Overlapped image patch embedding with 3x3 Conv 217 | class OverlapPatchEmbed(nn.Module): 218 | def __init__(self, in_c=3, embed_dim=48, bias=False): 219 | super(OverlapPatchEmbed, self).__init__() 220 | 221 | self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) 222 | 223 | def forward(self, x): 224 | x = self.proj(x) 225 | 226 | return x 227 | 228 | 229 | # sinusoidal positional embeds 230 | class SinusoidalPosEmb(nn.Module): 231 | def __init__(self, dim): 232 | super().__init__() 233 | self.dim = dim 234 | 235 | def forward(self, x): 236 | device = x.device 237 | half_dim = self.dim // 2 238 | emb = math.log(10000) / (half_dim - 1) 239 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 240 | emb = x[:, None] * emb[None, :] 241 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 242 | return emb 243 | 244 | 245 | ########################################################################## 246 | ## Resizing modules 247 | class Downsample(nn.Module): 248 | def __init__(self, n_feat): 249 | super(Downsample, self).__init__() 250 | 251 | self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat//2, kernel_size=3, stride=1, padding=1, bias=False), 252 | nn.PixelUnshuffle(2)) 253 | 254 | def forward(self, x): 255 | #print("InputSize in DownSample:",x[0].shape,'and',x[1].shape) 256 | x = x[0] 257 | #print("OperatingSize in DownSample:",x.shape) 258 | return self.body(x) 259 | 260 | class Upsample2(nn.Module): 261 | def __init__(self, n_feat): 262 | super(Upsample2, self).__init__() 263 | 264 | self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False), 265 | nn.PixelShuffle(2)) 266 | 267 | def forward(self, x): 268 | return self.body(x) 269 | 270 | class Upsample(nn.Module): 271 | def __init__(self, n_feat): 272 | super(Upsample, self).__init__() 273 | 274 | self.body = nn.Sequential(nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False), 275 | nn.PixelShuffle(2)) 276 | 277 | def forward(self, x): 278 | x = x[0] 279 | return self.body(x) 280 | 281 | 282 | ##---------- BokehOrNot ----------------------- 283 | class bokeh(nn.Module): 284 | def __init__(self, 285 | inp_channels=3, 286 | out_channels=3, 287 | dim = 48, 288 | num_blocks = [2,4,4,6], 289 | num_refinement_blocks = 6, 290 | heads = [1,2,4,8], 291 | ffn_expansion_factor = 2.66, 292 | bias = False, 293 | LayerNorm_type = 'WithBias', 294 | dual_pixel_task = False 295 | ): 296 | 297 | super(bokeh, self).__init__() 298 | 299 | self.cam_embed = SinusoidalPosEmb(dim) # tensor -> 1d vector 300 | self.cam_mlp = nn.Sequential(nn.Linear(3*dim, dim), nn.ReLU(), nn.Linear(dim, dim)) 301 | 302 | 303 | self.patch_embed = OverlapPatchEmbed(inp_channels, dim) 304 | 305 | self.encoder_level1 = nn.Sequential(*[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) 306 | 307 | self.down1_2 = Downsample(dim) ## From Level 1 to Level 2 308 | self.encoder_level2 = 
nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) 309 | 310 | self.down2_3 = Downsample(int(dim*2**1)) ## From Level 2 to Level 3 311 | self.encoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) 312 | 313 | self.down3_4 = Downsample(int(dim*2**2)) ## From Level 3 to Level 4 314 | self.latent = nn.Sequential(*[TransformerBlock2(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[3])]) 315 | 316 | self.up4_3 = Upsample2(int(dim*2**3)) ## From Level 4 to Level 3 317 | self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias) 318 | self.decoder_level3 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[2])]) 319 | 320 | 321 | self.up3_2 = Upsample(int(dim*2**2)) ## From Level 3 to Level 2 322 | self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias) 323 | self.decoder_level2 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[1])]) 324 | 325 | self.up2_1 = Upsample(int(dim*2**1)) ## From Level 2 to Level 1 (NO 1x1 conv to reduce channels) 326 | 327 | self.decoder_level1 = nn.Sequential(*[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_blocks[0])]) 328 | 329 | self.refinement = nn.Sequential(*[TransformerBlock2(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor, bias=bias, LayerNorm_type=LayerNorm_type) for i in range(num_refinement_blocks)]) 330 | 331 | #### For Dual-Pixel Defocus Deblurring Task #### 332 | self.dual_pixel_task = dual_pixel_task 333 | if self.dual_pixel_task: 334 | self.skip_conv = nn.Conv2d(dim, int(dim*2**1), kernel_size=1, bias=bias) 335 | ########################### 336 | 337 | self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) 338 | 339 | 340 | def forward(self, inp_img, incam, outcam, disparity): 341 | 342 | incam_embed = self.cam_embed(incam) 343 | outcam_embed = self.cam_embed(outcam) 344 | disparity_embed = self.cam_embed(disparity) 345 | 346 | cam_embed = torch.cat([incam_embed, outcam_embed, disparity_embed], dim=1) 347 | cam_embed = self.cam_mlp(cam_embed) 348 | 349 | inp_enc_level1 = self.patch_embed(inp_img) #img tensor 350 | 351 | out_enc_level1 = self.encoder_level1([inp_enc_level1, cam_embed]) #list 352 | 353 | inp_enc_level2 = self.down1_2(out_enc_level1) #img tensor 354 | 355 | out_enc_level2 = self.encoder_level2([inp_enc_level2, cam_embed]) 356 | 357 | inp_enc_level3 = self.down2_3(out_enc_level2) 358 | 359 | out_enc_level3 = self.encoder_level3([inp_enc_level3, cam_embed]) 360 | 361 | inp_enc_level4 = self.down3_4(out_enc_level3) 362 | 363 | latent = self.latent(inp_enc_level4) 364 | 365 | inp_dec_level3 = self.up4_3(latent) 366 | 367 | inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3[0]], 1) 368 | 369 | inp_dec_level3 = 
self.reduce_chan_level3(inp_dec_level3) 370 | 371 | out_dec_level3 = self.decoder_level3([inp_dec_level3, cam_embed]) 372 | 373 | 374 | inp_dec_level2 = self.up3_2(out_dec_level3) 375 | inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2[0]], 1) 376 | inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2) 377 | out_dec_level2 = self.decoder_level2([inp_dec_level2, cam_embed]) 378 | 379 | inp_dec_level1 = self.up2_1(out_dec_level2) 380 | inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1[0]], 1) 381 | out_dec_level1 = self.decoder_level1([inp_dec_level1, cam_embed]) 382 | 383 | out_dec_level1 = self.refinement(out_dec_level1[0]) 384 | 385 | 386 | #### For Dual-Pixel Defocus Deblurring Task #### 387 | if self.dual_pixel_task: 388 | out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1) 389 | out_dec_level1 = self.output(out_dec_level1) 390 | ########################### 391 | else: 392 | out_dec_level1 = self.output(out_dec_level1) + inp_img 393 | 394 | 395 | return out_dec_level1 396 | --------------------------------------------------------------------------------
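Usage note (illustrative): a minimal sketch of wiring model/dataset.py to a PyTorch DataLoader for
inspecting training or validation crops. The folder path is a placeholder for a directory laid out
like the challenge data (*.src.jpg, *.tgt.jpg, *.alpha.png and meta.txt); this snippet is a sketch
under those assumptions, not a script shipped with the repository.

    from torch.utils.data import DataLoader
    from torchvision.transforms import ToTensor

    from model.dataset import BokehDataset

    # Placeholder path: a folder containing *.src.jpg / *.tgt.jpg / *.alpha.png files plus meta.txt.
    dataset = BokehDataset("path/to/val_folder", transform=ToTensor())
    loader = DataLoader(dataset, batch_size=1, shuffle=False)

    for batch in loader:
        # Each batch is a dict with the keys returned by BokehDataset.__getitem__.
        source = batch["source"]        # [B, 3, 384, 384] random source crop
        target = batch["target"]        # [B, 3, 384, 384] matching target-bokeh crop
        alpha = batch["alpha"]          # alpha matte crop (channel count depends on the PNG)
        src_lens = batch["src_lens"]    # lens-name strings, e.g. "Sony50mmf1.8BS"
        tgt_lens = batch["tgt_lens"]
        disparity = batch["disparity"]  # disparity strings parsed from meta.txt
        print(source.shape, target.shape, src_lens[0], tgt_lens[0], disparity[0])
        break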