├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── calib
│   └── calibrate.py
├── config.ipynb
├── datasets
│   ├── genDB-r2rNet.ipynb
│   └── getStat-r2rNet.ipynb
├── demo.ipynb
├── demo.py
├── genValset.ipynb
├── models
│   ├── module.ipynb
│   └── network.ipynb
├── samples
│   └── DJI_0899.DNG
├── test.py
├── train-gainEst.ipynb
├── train-r2rNet.ipynb
├── train-rawProcess.ipynb
└── utils
    ├── dataLoader.ipynb
    ├── monitor.ipynb
    └── util.ipynb

/.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies.
92 | #Pipfile.lock 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | .dmypy.json 122 | dmypy.json 123 | 124 | # Pyre type checker 125 | .pyre/ 126 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 YC.L 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepSelfie: Single-shot Low-light Enhancement for Selfies 2 | 3 | If you are interested in using our code and find it useful in your research, please consider citing the following paper: 4 | 5 | ```latex 6 | @article{lu2020deepselfie, 7 | author = {Lu, Yucheng and Kim, Dong-Wook and Jung, Seung-Won}, 8 | year = {2020}, 9 | month = {07}, 10 | pages = {1-1}, 11 | title = {DeepSelfie: Single-shot Low-light Enhancement for Selfies}, 12 | volume = {PP}, 13 | journal = {IEEE Access}, 14 | doi = {10.1109/ACCESS.2020.3006525} 15 | } 16 | ``` 17 | 18 | 19 | 20 | ### System Requirements 21 | 22 | - Python 3 23 | - PyTorch 24 | - torchvision 25 | - Jupyter Notebook (for training) 26 | - Visdom (for training) 27 | - OpenCV (for training) 28 | - numpy 29 | - build-essential: sudo apt-get install build-essential 30 | - python-all-dev: sudo apt-get install python-all-dev 31 | - libexiv2-dev: sudo apt-get install libexiv2-dev 32 | - libboost-python-dev: sudo apt-get install libboost-python-dev 33 | - pyexiv2: pip install py3exiv2 34 | - libraw: sudo apt-get install libraw-dev 35 | - rawpy: pip install rawpy 36 | 37 | Please note that the libraw installed via apt-get is an outdated version; if it gives you wrong results, build LibRaw from [source](https://github.com/LibRaw/LibRaw) instead.
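For convenience, the system packages and the two Python bindings listed above can be installed in one pass (a sketch assuming a Debian/Ubuntu system; PyTorch, torchvision, numpy, and OpenCV are installed separately as usual):

```bash
sudo apt-get install build-essential python-all-dev libexiv2-dev \
    libboost-python-dev libraw-dev
pip install py3exiv2 rawpy
```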
38 | 39 | 40 | 41 | ### Run Demo 42 | 43 | To run the demo, download the pretrained models from [onedrive](https://dongguk0-my.sharepoint.com/:f:/g/personal/yc_lu_dongguk_edu/EoIwKeaFgZhAj9UaLoxVWDEBX3Yhs07mpXyJn5Y_Xj6aTQ?e=30kDcX), unzip and copy the files to the "saves" folder, add your .DNG files to the "samples" folder, then run the following command: 44 | 45 | `python ./demo.py` 46 | 47 | The generated images will be saved to the "results" folder. 48 | 49 | Alternatively, you can specify your own input or output path by passing "--input" and "--output": 50 | 51 | `python ./demo.py --input PATH-TO-INPUT-FOLDER --output PATH-TO-OUTPUT-FOLDER` 52 | 53 | The default device used for running this demo is the CPU; if you want to use a GPU instead, pass "--device cuda": 54 | 55 | `python ./demo.py --device cuda` 56 | 57 | If you encounter an out-of-memory error, try down-sampling the input before further processing by passing the target height and width: 58 | 59 | `python ./demo.py --resize 600,800` 60 | 61 | The output images are not calibrated and thus show lens distortion; if you want to perform camera calibration, pass "--calib": 62 | 63 | `python ./demo.py --calib` 64 | 65 | 66 | 67 | ### Train from scratch 68 | 69 | Before training from scratch, you first need to download the FivekNight dataset from [onedrive](https://dongguk0-my.sharepoint.com/:f:/g/personal/yc_lu_dongguk_edu/EoIwKeaFgZhAj9UaLoxVWDEBX3Yhs07mpXyJn5Y_Xj6aTQ?e=30kDcX). You also need to take some images with your camera in RAW format (the more the better, covering various exposure levels and ISO settings). To start training, follow the instructions below: 70 | 71 | 1. Generate the training dataset for r2rNet by running "genDB-r2rNet.ipynb" under /datasets. 72 | 2. Get the statistical information of your own dataset by running "getStat-r2rNet.ipynb" under /datasets, then copy the results to "config.ipynb" under /. 73 | 3. Train r2rNet by running "train-r2rNet.ipynb" under /; you may want to specify a new port for Visdom. 74 | 4. Generate the validation dataset by running "genValset.ipynb" under /; this step should be performed after r2rNet training is completed. 75 | 5. Train the gain estimation network by running "train-gainEst.ipynb" under /; you may want to specify a new port for Visdom. 76 | 6. Train the raw processing network by running "train-rawProcess.ipynb" under /; you may want to specify a new port for Visdom. 77 | 78 | Note: You may need to change the paths at steps 1 and 2, and the port used by Visdom at steps 3, 5, and 6. 79 | 80 | -------------------------------------------------------------------------------- /calib/calibrate.py: -------------------------------------------------------------------------------- 1 | # Fisheye Camera Calibration 2 | # Modified from https://medium.com/@kennethjiang/calibrate-fisheye-lens-using-opencv-333b05afa0b0 3 | 4 | import cv2 5 | import numpy as np 6 | import os 7 | import glob 8 | import rawpy as rp 9 | 10 | CHECKERBOARD = (6, 9) 11 | subpix_criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 30, 0.1) 12 | calibration_flags = cv2.fisheye.CALIB_RECOMPUTE_EXTRINSIC + cv2.fisheye.CALIB_CHECK_COND + cv2.fisheye.CALIB_FIX_SKEW 13 | objp = np.zeros((1, CHECKERBOARD[0] * CHECKERBOARD[1], 3), np.float32) 14 | objp[0, :, :2] = np.mgrid[0:CHECKERBOARD[0], 15 | 0:CHECKERBOARD[1]].T.reshape(-1, 2) 16 | _img_shape = None 17 | objpoints = [] # 3d point in real world space 18 | imgpoints = [] # 2d points in image plane.
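# Loop over all RAW captures in the working directory: postprocess each to
# sRGB, detect the inner checkerboard corners, refine them to sub-pixel
# accuracy, and collect the 3D-2D correspondences for cv2.fisheye.calibrate.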
19 | images = glob.glob('*.DNG') 20 | for fname in images: 21 | raw_img = rp.imread(fname) 22 | img = raw_img.postprocess(use_camera_wb=True) 23 | if _img_shape == None: 24 | _img_shape = img.shape[:2] 25 | else: 26 | assert _img_shape == img.shape[: 27 | 2], "All images must share the same size." 28 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 29 | # Find the chess board corners 30 | ret, corners = cv2.findChessboardCorners( 31 | gray, CHECKERBOARD, cv2.CALIB_CB_ADAPTIVE_THRESH + 32 | cv2.CALIB_CB_FAST_CHECK + cv2.CALIB_CB_NORMALIZE_IMAGE) 33 | # If found, add object points, image points (after refining them) 34 | if ret == True: 35 | objpoints.append(objp) 36 | cv2.cornerSubPix(gray, corners, (3, 3), (-1, -1), subpix_criteria) 37 | imgpoints.append(corners) 38 | N_OK = len(objpoints) 39 | K = np.zeros((3, 3)) 40 | D = np.zeros((4, 1)) 41 | rvecs = [np.zeros((1, 1, 3), dtype=np.float64) for i in range(N_OK)] 42 | tvecs = [np.zeros((1, 1, 3), dtype=np.float64) for i in range(N_OK)] 43 | rms, _, _, _, _ = \ 44 | cv2.fisheye.calibrate( 45 | objpoints, 46 | imgpoints, 47 | gray.shape[::-1], 48 | K, 49 | D, 50 | rvecs, 51 | tvecs, 52 | calibration_flags, 53 | (cv2.TERM_CRITERIA_EPS+cv2.TERM_CRITERIA_MAX_ITER, 30, 1e-6) 54 | ) 55 | print("Found " + str(N_OK) + " valid images for calibration") 56 | print("DIM=" + str(_img_shape[::-1])) 57 | print("K=np.array(" + str(K.tolist()) + ")") 58 | print("D=np.array(" + str(D.tolist()) + ")") -------------------------------------------------------------------------------- /config.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Network Configuration" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": { 14 | "code_folding": [] 15 | }, 16 | "outputs": [], 17 | "source": [ 18 | "# camera parameters\n", 19 | "wb_stat = {\n", 20 | " 'slope': -0.933335,\n", 21 | " 'const': -1.249542,\n", 22 | " 'std': 0.063877,\n", 23 | " 'min': -1.020177,\n", 24 | " 'max': -0.042540\n", 25 | "}\n", 26 | "noise_stat = {\n", 27 | " 'slope': 2.385028,\n", 28 | " 'const': 4.055563,\n", 29 | " 'std': 0.100832,\n", 30 | " 'min': -8.373093,\n", 31 | " 'max': -7.049319\n", 32 | "}\n", 33 | "d65_wb = [2.408501, 1.000000, 1.499017, 1.000000]\n", 34 | "cam_matrix = [[1.4462, -0.8273, -0.1252], [-0.1818, 1.0996, 0.0918],\n", 35 | " [0.0193, 0.044, 0.7474]]\n", 36 | "amp_range = (1, 20)\n", 37 | "\n", 38 | "# dataset paths\n", 39 | "r2r_path = '/home/lab/Documents/ssd/r2rSet'\n", 40 | "fivek_path = '/home/lab/Documents/ssd/fivekNight'\n", 41 | "\n", 42 | "\n", 43 | "class r2rNetConf():\n", 44 | " data_root = r2r_path\n", 45 | " save_root = None\n", 46 | "\n", 47 | " # network parameters\n", 48 | " r2r_size = 128 # input size for r2rNet\n", 49 | " batch_size = 8\n", 50 | " lr = 1e-4\n", 51 | " lr_decay = 0.95\n", 52 | " upd_freq = 20 # learning rate update frequency\n", 53 | " max_epoch = 1500\n", 54 | "\n", 55 | " # camera parameters\n", 56 | " d65_wb = d65_wb\n", 57 | " cam_matrix = cam_matrix\n", 58 | "\n", 59 | " # other parameters\n", 60 | " num_workers = 8\n", 61 | " save_epoch = 1000 # epoch to be saved\n", 62 | "\n", 63 | "\n", 64 | "class mainConf():\n", 65 | " data_root = fivek_path\n", 66 | " save_root = None\n", 67 | "\n", 68 | " # network parameters\n", 69 | " att_size = (256, 192) # input size for att module\n", 70 | " isp_size = (640, 480) # input size for isp module\n", 71 | " batch_size = 4\n", 
72 | " lr = 1e-4\n", 73 | "\n", 74 | " # camera parameters\n", 75 | " wb_stat = wb_stat\n", 76 | " noise_stat = noise_stat\n", 77 | " amp_range = amp_range\n", 78 | "\n", 79 | " # other parameters\n", 80 | " num_workers = 8\n", 81 | " plot_freq = 100 # plot frequency\n", 82 | " save_freq = 2500 # save frequency\n", 83 | "\n", 84 | "\n", 85 | "class valConf():\n", 86 | " data_root = fivek_path\n", 87 | "\n", 88 | " # network parameters\n", 89 | " att_size = (256, 192)\n", 90 | " isp_size = (768, 576)\n", 91 | "\n", 92 | " # camera parameters\n", 93 | " wb_stat = wb_stat\n", 94 | " noise_stat = noise_stat\n", 95 | " amp_range = amp_range" 96 | ] 97 | } 98 | ], 99 | "metadata": { 100 | "kernelspec": { 101 | "display_name": "Python 3", 102 | "language": "python", 103 | "name": "python3" 104 | }, 105 | "language_info": { 106 | "codemirror_mode": { 107 | "name": "ipython", 108 | "version": 3 109 | }, 110 | "file_extension": ".py", 111 | "mimetype": "text/x-python", 112 | "name": "python", 113 | "nbconvert_exporter": "python", 114 | "pygments_lexer": "ipython3", 115 | "version": "3.8.3" 116 | }, 117 | "toc": { 118 | "nav_menu": {}, 119 | "number_sections": true, 120 | "sideBar": true, 121 | "skip_h1_title": true, 122 | "title_cell": "Table of Contents", 123 | "title_sidebar": "Contents", 124 | "toc_cell": false, 125 | "toc_position": {}, 126 | "toc_section_display": true, 127 | "toc_window_display": false 128 | }, 129 | "varInspector": { 130 | "cols": { 131 | "lenName": 16, 132 | "lenType": 16, 133 | "lenVar": 40 134 | }, 135 | "kernels_config": { 136 | "python": { 137 | "delete_cmd_postfix": "", 138 | "delete_cmd_prefix": "del ", 139 | "library": "var_list.py", 140 | "varRefreshCmd": "print(var_dic_list())" 141 | }, 142 | "r": { 143 | "delete_cmd_postfix": ") ", 144 | "delete_cmd_prefix": "rm(", 145 | "library": "var_list.r", 146 | "varRefreshCmd": "cat(var_dic_list()) " 147 | } 148 | }, 149 | "types_to_exclude": [ 150 | "module", 151 | "function", 152 | "builtin_function_or_method", 153 | "instance", 154 | "_Feature" 155 | ], 156 | "window_display": false 157 | } 158 | }, 159 | "nbformat": 4, 160 | "nbformat_minor": 2 161 | } 162 | -------------------------------------------------------------------------------- /datasets/genDB-r2rNet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Dataset Generation" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Includes" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "code_folding": [ 22 | 0 23 | ] 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "# mass includes\n", 28 | "import os\n", 29 | "import pickle\n", 30 | "import pyexiv2 as exiv2\n", 31 | "import rawpy as rp\n", 32 | "import numpy as np\n", 33 | "import torch as t\n", 34 | "from rawpy import HighlightMode\n", 35 | "from tqdm.notebook import tqdm\n", 36 | "from torch.utils import data" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": {}, 42 | "source": [ 43 | "## Initialization" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": { 50 | "code_folding": [ 51 | 0 52 | ] 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "# configuration\n", 57 | "data_root = '/home/lab/Documents/ssd/DJI' # dataset path\n", 58 | "save_root = '/home/lab/Documents/ssd/r2rSet' # save path\n", 59 | "file_ext = '.DNG' # 
extension of raw file\n", 60 | "train_num = 710 # num of images for training\n", 61 | "patch_size = (400, 300) # size of each patch" 62 | ] 63 | }, 64 | { 65 | "cell_type": "markdown", 66 | "metadata": {}, 67 | "source": [ 68 | "## RAW data manipulation" 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "execution_count": null, 74 | "metadata": { 75 | "code_folding": [ 76 | 0 77 | ] 78 | }, 79 | "outputs": [], 80 | "source": [ 81 | "# get file list\n", 82 | "file_list = [file for file in os.listdir(data_root) if file_ext in file]\n", 83 | "file_list.sort()\n", 84 | "\n", 85 | "# make new folders\n", 86 | "train_path = os.path.join(save_root, 'train')\n", 87 | "os.makedirs(train_path)\n", 88 | "val_path = os.path.join(save_root, 'val')\n", 89 | "os.makedirs(val_path)\n", 90 | "\n", 91 | "for index, file in tqdm(enumerate(file_list),\n", 92 | " desc='progress',\n", 93 | " total=len(file_list)):\n", 94 | " # find black, saturation, and whitebalance\n", 95 | " img_md = exiv2.ImageMetadata(os.path.join(data_root, file))\n", 96 | " img_md.read()\n", 97 | "\n", 98 | " blk_level = img_md['Exif.SubImage1.BlackLevel'].value\n", 99 | " sat_level = img_md['Exif.SubImage1.WhiteLevel'].value\n", 100 | " cam_wb = img_md['Exif.Image.AsShotNeutral'].value\n", 101 | "\n", 102 | " # convert flat Bayer pattern to 4D tensor (RGGB)\n", 103 | " raw_img = rp.imread(os.path.join(data_root, file))\n", 104 | " flat_bayer = raw_img.raw_image_visible\n", 105 | " raw_data = np.stack((flat_bayer[0::2, 0::2], flat_bayer[0::2, 1::2],\n", 106 | " flat_bayer[1::2, 0::2], flat_bayer[1::2, 1::2]),\n", 107 | " axis=2)\n", 108 | "\n", 109 | " # get ground-truth sRGB image\n", 110 | " gt_img = raw_img.postprocess(use_camera_wb=True,\n", 111 | " output_bps=16,\n", 112 | " no_auto_bright=True,\n", 113 | " adjust_maximum_thr=0.0,\n", 114 | " highlight_mode=HighlightMode.Ignore)\n", 115 | "\n", 116 | " # split to small patches\n", 117 | " part_idx = 0\n", 118 | " raw_hei = gt_img.shape[0] / 2\n", 119 | " raw_wid = gt_img.shape[1] / 2\n", 120 | " for i in range(0, int(raw_hei / patch_size[1])):\n", 121 | " for j in range(0, int(raw_wid / patch_size[0])):\n", 122 | " crop_h = i * patch_size[1]\n", 123 | " crop_w = j * patch_size[0]\n", 124 | " raw_patch = raw_data[crop_h:crop_h + patch_size[1],\n", 125 | " crop_w:crop_w + patch_size[0], :]\n", 126 | " gt_patch = gt_img[2 * crop_h:2 * (crop_h + patch_size[1]),\n", 127 | " 2 * crop_w:2 * (crop_w + patch_size[0]), :]\n", 128 | "\n", 129 | " # save to files\n", 130 | " patch = {}\n", 131 | " patch['blk_level'] = np.array(blk_level, dtype=np.uint16)\n", 132 | " patch['sat_level'] = np.array(sat_level, dtype=np.uint16)\n", 133 | " patch['cam_wb'] = np.array(cam_wb, dtype=np.float32)\n", 134 | " patch['raw'] = np.transpose(raw_patch, (2, 0, 1))\n", 135 | " patch['img'] = np.transpose(gt_patch, (2, 0, 1))\n", 136 | " if index < train_num:\n", 137 | " file_path = os.path.join(\n", 138 | " train_path, '%s_p%03d.pkl' % (file[:-4], part_idx))\n", 139 | " else:\n", 140 | " file_path = os.path.join(\n", 141 | " val_path, '%s_p%03d.pkl' % (file[:-4], part_idx))\n", 142 | " with open(file_path, 'wb') as pkl_file:\n", 143 | " pickle.dump(patch, pkl_file)\n", 144 | "\n", 145 | " # update part index\n", 146 | " part_idx += 1" 147 | ] 148 | } 149 | ], 150 | "metadata": { 151 | "kernelspec": { 152 | "display_name": "Python 3", 153 | "language": "python", 154 | "name": "python3" 155 | }, 156 | "language_info": { 157 | "codemirror_mode": { 158 | "name": "ipython", 159 | "version": 3 160 | }, 161 | 
"file_extension": ".py", 162 | "mimetype": "text/x-python", 163 | "name": "python", 164 | "nbconvert_exporter": "python", 165 | "pygments_lexer": "ipython3", 166 | "version": "3.8.3" 167 | } 168 | }, 169 | "nbformat": 4, 170 | "nbformat_minor": 2 171 | } 172 | -------------------------------------------------------------------------------- /datasets/getStat-r2rNet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Dataset Statistical Analysis" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Includes" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "code_folding": [ 22 | 0 23 | ] 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "# mass includes\n", 28 | "import os\n", 29 | "import pickle\n", 30 | "import pyexiv2 as exiv2\n", 31 | "import numpy as np\n", 32 | "import matplotlib.pyplot as plt\n", 33 | "from tqdm.notebook import tqdm\n", 34 | "from torch.utils import data" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "## Initialization" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": { 48 | "code_folding": [ 49 | 0 50 | ] 51 | }, 52 | "outputs": [], 53 | "source": [ 54 | "# configuration\n", 55 | "data_root = '/home/lab/Documents/ssd//DJI' # dataset path\n", 56 | "file_ext = '.DNG' # extension of RAW file" 57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "## Linear fit" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": { 70 | "code_folding": [ 71 | 1 72 | ] 73 | }, 74 | "outputs": [], 75 | "source": [ 76 | "# MLE algorithm\n", 77 | "def linearFit(sample_list):\n", 78 | " data_x = sample_list[:, 0]\n", 79 | " data_y = sample_list[:, 1]\n", 80 | "\n", 81 | " # intermediate variables\n", 82 | " x_mean = np.mean(data_x)\n", 83 | " y_mean = np.mean(data_y)\n", 84 | " lxx = np.sum((data_x - x_mean)**2)\n", 85 | " lyy = np.sum((data_y - y_mean)**2)\n", 86 | " lxy = np.sum((data_x - x_mean) * (data_y - y_mean))\n", 87 | "\n", 88 | " # MLE\n", 89 | " slope = lxy / lxx\n", 90 | " const = y_mean - slope * x_mean\n", 91 | " std = np.sqrt((lyy - slope * lxy) / (len(data_x) - 2))\n", 92 | "\n", 93 | " return slope, const, std" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "## Statistical analysis" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "metadata": { 107 | "code_folding": [ 108 | 0 109 | ] 110 | }, 111 | "outputs": [], 112 | "source": [ 113 | "# get file list\n", 114 | "file_list = [file for file in os.listdir(data_root) if file_ext in file]\n", 115 | "file_list.sort()\n", 116 | "\n", 117 | "wb_list = []\n", 118 | "noise_list = []\n", 119 | "for index, file in tqdm(enumerate(file_list),\n", 120 | " desc='progress',\n", 121 | " total=len(file_list)):\n", 122 | " # load a new sample\n", 123 | " img_md = exiv2.ImageMetadata(os.path.join(data_root, file))\n", 124 | " img_md.read()\n", 125 | "\n", 126 | " # extract metadata\n", 127 | " cam_wb = img_md['Exif.Image.AsShotNeutral'].value\n", 128 | " wb_list.append(np.array([cam_wb[0], cam_wb[2]], dtype=np.float32))\n", 129 | " cam_noise = img_md['Exif.Image.NoiseProfile'].raw_value.split()\n", 130 | " noise_list.append(np.array(cam_noise, 
dtype=np.float32))\n", 131 | "\n", 132 | "# compute slope,const, and std\n", 133 | "wb_list = np.log(np.array(wb_list))\n", 134 | "noise_list = np.log(np.array(noise_list))\n", 135 | "wb_s, wb_c, wb_std = linearFit(wb_list)\n", 136 | "noise_s, noise_c, noise_std = linearFit(noise_list)\n", 137 | "\n", 138 | "# print results\n", 139 | "print(\n", 140 | " \"stat info for wb {'slope': %f, 'const': %f, 'std': %f, 'min': %f, 'max': %f}\"\n", 141 | " % (wb_s, wb_c, wb_std, np.min(wb_list[:, 0]), np.max(wb_list[:, 0])))\n", 142 | "print(\n", 143 | " \"stat info for noise {'slope': %f, 'const': %f, 'std': %f, 'min': %f, 'max': %f}\"\n", 144 | " % (noise_s, noise_c, noise_std, np.min(\n", 145 | " noise_list[:, 0]), np.max(noise_list[:, 0])))\n", 146 | "\n", 147 | "# plot resu\n", 148 | "fig1 = plt.figure()\n", 149 | "ax1 = fig1.add_subplot(2, 1, 1)\n", 150 | "ax1.plot(wb_list[:, 0], wb_list[:, 1], 'bo', markersize=3)\n", 151 | "plt.xlabel('$\\log(w_{r})$', fontsize=12)\n", 152 | "plt.ylabel('$\\log(w_{b})$', fontsize=12)\n", 153 | "plt.tight_layout()\n", 154 | "\n", 155 | "fig2 = plt.figure()\n", 156 | "ax2 = fig2.add_subplot(2, 1, 2)\n", 157 | "ax2.plot(noise_list[:, 0], noise_list[:, 1], 'bo', markersize=3)\n", 158 | "plt.xlabel('$\\log(\\lambda_{shot})$', fontsize=12)\n", 159 | "plt.ylabel('$\\log(\\lambda_{read})$', fontsize=12)\n", 160 | "plt.tight_layout()\n", 161 | "\n", 162 | "# save to figure if needed\n", 163 | "fig1.savefig('stat1.png', bbox_inches='tight')\n", 164 | "fig2.savefig('stat2.png', bbox_inches='tight')" 165 | ] 166 | } 167 | ], 168 | "metadata": { 169 | "kernelspec": { 170 | "display_name": "Python 3", 171 | "language": "python", 172 | "name": "python3" 173 | }, 174 | "language_info": { 175 | "codemirror_mode": { 176 | "name": "ipython", 177 | "version": 3 178 | }, 179 | "file_extension": ".py", 180 | "mimetype": "text/x-python", 181 | "name": "python", 182 | "nbconvert_exporter": "python", 183 | "pygments_lexer": "ipython3", 184 | "version": "3.8.3" 185 | } 186 | }, 187 | "nbformat": 4, 188 | "nbformat_minor": 2 189 | } 190 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Network Testing" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Includes" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "code_folding": [ 22 | 0 23 | ] 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "# mass includes\n", 28 | "import os, sys, argparse\n", 29 | "import numpy as np\n", 30 | "import pyexiv2 as exiv2\n", 31 | "import rawpy as rp\n", 32 | "import torch as t\n", 33 | "from torchvision.utils import save_image" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "## Modules" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": { 47 | "code_folding": [ 48 | 0, 49 | 19, 50 | 37, 51 | 62, 52 | 83, 53 | 114, 54 | 176, 55 | 225 56 | ] 57 | }, 58 | "outputs": [], 59 | "source": [ 60 | "class BasicModule(t.nn.Module):\n", 61 | " def __init__(self):\n", 62 | " super(BasicModule, self).__init__()\n", 63 | " self.model_name = str(type(self))\n", 64 | "\n", 65 | " def load(self, root, device=None):\n", 66 | " save_list = [\n", 67 | " file for file in os.listdir(root)\n", 68 | " if 
file.startswith(self.model_name)\n", 69 | " ]\n", 70 | " save_list.sort()\n", 71 | " file_path = os.path.join(root, save_list[-1])\n", 72 | " state_dict = t.load(file_path, map_location=device)\n", 73 | " self.load_state_dict(t.load(file_path, map_location=device))\n", 74 | " print('Weights loaded: %s' % file_path)\n", 75 | "\n", 76 | " return\n", 77 | "\n", 78 | "\n", 79 | "class channelAtt(BasicModule):\n", 80 | " def __init__(self, channels):\n", 81 | " super(channelAtt, self).__init__()\n", 82 | "\n", 83 | " # squeeze-excitation layer\n", 84 | " self.glb_pool = t.nn.AdaptiveAvgPool2d((1, 1))\n", 85 | " self.squeeze_excite = t.nn.Sequential(\n", 86 | " t.nn.Linear(channels, int(channels / 16)), t.nn.LeakyReLU(0.2),\n", 87 | " t.nn.Linear(int(channels / 16), channels), t.nn.Sigmoid())\n", 88 | "\n", 89 | " def forward(self, x):\n", 90 | " scale = self.glb_pool(x)\n", 91 | " scale = self.squeeze_excite(scale.squeeze())\n", 92 | " x = scale.view((x.size(0), x.size(1), 1, 1)) * x\n", 93 | "\n", 94 | " return x\n", 95 | "\n", 96 | "\n", 97 | "class encode(BasicModule):\n", 98 | " def __init__(self, in_channels, out_channels, max_pool=True):\n", 99 | " super(encode, self).__init__()\n", 100 | "\n", 101 | " # features\n", 102 | " if max_pool:\n", 103 | " self.features = t.nn.Sequential(\n", 104 | " t.nn.MaxPool2d((2, 2)),\n", 105 | " t.nn.Conv2d(in_channels, out_channels, 3, padding=1),\n", 106 | " t.nn.LeakyReLU(0.2),\n", 107 | " t.nn.Conv2d(out_channels, out_channels, 3, padding=1),\n", 108 | " channelAtt(out_channels), t.nn.LeakyReLU(0.2))\n", 109 | " else:\n", 110 | " self.features = t.nn.Sequential(\n", 111 | " t.nn.Conv2d(in_channels, out_channels, 3, padding=1),\n", 112 | " t.nn.LeakyReLU(0.2),\n", 113 | " t.nn.Conv2d(out_channels, out_channels, 3, padding=1),\n", 114 | " channelAtt(out_channels), t.nn.LeakyReLU(0.2))\n", 115 | "\n", 116 | " def forward(self, x):\n", 117 | " x = self.features(x)\n", 118 | "\n", 119 | " return x\n", 120 | "\n", 121 | "\n", 122 | "class skipConn(BasicModule):\n", 123 | " def __init__(self, in_channels, out_channels, avg_pool=True):\n", 124 | " super(skipConn, self).__init__()\n", 125 | "\n", 126 | " # features\n", 127 | " if avg_pool:\n", 128 | " self.features = t.nn.Sequential(\n", 129 | " t.nn.AvgPool2d((2, 2)),\n", 130 | " t.nn.Conv2d(in_channels, out_channels, 1),\n", 131 | " channelAtt(out_channels), t.nn.Tanh())\n", 132 | " else:\n", 133 | " self.features = t.nn.Sequential(\n", 134 | " t.nn.Conv2d(in_channels, out_channels, 1),\n", 135 | " channelAtt(out_channels), t.nn.Tanh())\n", 136 | "\n", 137 | " def forward(self, x):\n", 138 | " x = self.features(x)\n", 139 | "\n", 140 | " return x\n", 141 | "\n", 142 | "\n", 143 | "class decode(BasicModule):\n", 144 | " def __init__(self,\n", 145 | " in_channels,\n", 146 | " inter_channels,\n", 147 | " out_channels,\n", 148 | " up_sample=True):\n", 149 | " super(decode, self).__init__()\n", 150 | "\n", 151 | " # features\n", 152 | " if up_sample:\n", 153 | " self.features = t.nn.Sequential(\n", 154 | " t.nn.Conv2d(in_channels, inter_channels, 1),\n", 155 | " t.nn.Conv2d(inter_channels, inter_channels, 3, padding=1),\n", 156 | " t.nn.LeakyReLU(0.2),\n", 157 | " t.nn.Upsample(scale_factor=2, mode='nearest'),\n", 158 | " t.nn.Conv2d(inter_channels, out_channels, 3, padding=1),\n", 159 | " channelAtt(out_channels), t.nn.LeakyReLU(0.2))\n", 160 | " else:\n", 161 | " self.features = t.nn.Sequential(\n", 162 | " t.nn.Conv2d(in_channels, inter_channels, 1),\n", 163 | " t.nn.Conv2d(inter_channels, inter_channels, 3, 
padding=1),\n", 164 | " t.nn.LeakyReLU(0.2),\n", 165 | " t.nn.Conv2d(inter_channels, out_channels, 3, padding=1),\n", 166 | " t.nn.LeakyReLU(0.2))\n", 167 | "\n", 168 | " def forward(self, x):\n", 169 | " x = self.features(x)\n", 170 | "\n", 171 | " return x\n", 172 | "\n", 173 | "\n", 174 | "class gainEst(BasicModule):\n", 175 | " def __init__(self):\n", 176 | " super(gainEst, self).__init__()\n", 177 | " self.model_name = 'gainEst'\n", 178 | "\n", 179 | " # encoders\n", 180 | " self.head = encode(3, 64, max_pool=False)\n", 181 | " self.down1 = encode(64, 96, max_pool=True)\n", 182 | " self.down2 = encode(96, 128, max_pool=True)\n", 183 | " self.down3 = encode(128, 192, max_pool=True)\n", 184 | "\n", 185 | " # bottleneck\n", 186 | " self.bottleneck = t.nn.Sequential(\n", 187 | " t.nn.MaxPool2d(2, 2), t.nn.Conv2d(192, 256, 3, padding=1),\n", 188 | " t.nn.LeakyReLU(0.2), t.nn.Conv2d(256, 256, 3, padding=1),\n", 189 | " t.nn.Upsample(scale_factor=2, mode='nearest'),\n", 190 | " t.nn.Conv2d(256, 192, 3, padding=1), channelAtt(192),\n", 191 | " t.nn.LeakyReLU(0.2))\n", 192 | "\n", 193 | " # decoders\n", 194 | " self.up1 = decode(384, 384, 128, up_sample=True)\n", 195 | " self.up2 = decode(256, 256, 96, up_sample=True)\n", 196 | " self.up3 = decode(192, 192, 64, up_sample=True)\n", 197 | " self.seg_out = t.nn.Sequential(decode(128, 128, 64, up_sample=False),\n", 198 | " t.nn.Conv2d(64, 2, 1))\n", 199 | "\n", 200 | " # external actication\n", 201 | " self.sigmoid = t.nn.Sigmoid()\n", 202 | "\n", 203 | " # prediction\n", 204 | " self.features = t.nn.Sequential(\n", 205 | " t.nn.Conv2d(5, 64, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2),\n", 206 | " t.nn.Conv2d(64, 96, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2),\n", 207 | " t.nn.Conv2d(96, 128, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2),\n", 208 | " t.nn.Conv2d(128, 192, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2),\n", 209 | " t.nn.Conv2d(192, 256, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2))\n", 210 | " self.amp_out = t.nn.Sequential(t.nn.Linear(8 * 6 * 256,\n", 211 | " 128), t.nn.LeakyReLU(0.2),\n", 212 | " t.nn.Linear(128, 64),\n", 213 | " t.nn.LeakyReLU(0.2), t.nn.Linear(64, 2))\n", 214 | "\n", 215 | " def forward(self, thumb_img, struct_img):\n", 216 | " # segmentation\n", 217 | " out_head = self.head(struct_img)\n", 218 | " out_d1 = self.down1(out_head)\n", 219 | " out_d2 = self.down2(out_d1)\n", 220 | " out_d3 = self.down3(out_d2)\n", 221 | " out_bottleneck = self.bottleneck(out_d3)\n", 222 | " out_u1 = self.up1(t.cat([out_d3, out_bottleneck], dim=1))\n", 223 | " out_u2 = self.up2(t.cat([out_d2, out_u1], dim=1))\n", 224 | " out_u3 = self.up3(t.cat([out_d1, out_u2], dim=1))\n", 225 | " out_mask = self.seg_out(t.cat([out_head, out_u3], dim=1))\n", 226 | "\n", 227 | " # prediction\n", 228 | " out_features = self.features(\n", 229 | " t.cat([thumb_img, self.sigmoid(out_mask)], dim=1))\n", 230 | " out_amp = self.amp_out(out_features.view(out_features.size(0), -1))\n", 231 | " out_amp = t.clamp(out_amp, 0.0, 1.0)\n", 232 | "\n", 233 | " return out_mask, out_amp\n", 234 | "\n", 235 | "\n", 236 | "class ispNet(BasicModule):\n", 237 | " def __init__(self):\n", 238 | " super(ispNet, self).__init__()\n", 239 | "\n", 240 | " # encoders\n", 241 | " self.head = encode(8, 64, max_pool=False)\n", 242 | " self.down1 = encode(64, 64, max_pool=True)\n", 243 | " self.down2 = encode(64, 64, max_pool=True)\n", 244 | "\n", 245 | " # skip connections\n", 246 | " self.skip1 = skipConn(1, 64, avg_pool=False)\n", 247 | " self.skip2 = skipConn(64, 64, 
avg_pool=True)\n", 248 | " self.skip3 = skipConn(64, 64, avg_pool=True)\n", 249 | "\n", 250 | " # decoders\n", 251 | " self.up1 = decode(128, 64, 64, up_sample=True)\n", 252 | " self.up2 = decode(128, 64, 64, up_sample=True)\n", 253 | " self.srgb_out = t.nn.Sequential(\n", 254 | " decode(128, 64, 64, up_sample=False),\n", 255 | " t.nn.Upsample(scale_factor=2, mode='nearest'),\n", 256 | " t.nn.Conv2d(64, 3, 3, padding=1))\n", 257 | "\n", 258 | " def forward(self, color_map, mag_map, amp, wb):\n", 259 | " # to prevent saturation\n", 260 | " mag_map = amp.view(-1, 1, 1, 1) * mag_map\n", 261 | " mag_map = t.nn.functional.tanh(mag_map - 0.5)\n", 262 | " max_mag = 2.0 * amp.view(-1, 1, 1, 1)\n", 263 | " max_mag = t.nn.functional.tanh(max_mag - 0.5)\n", 264 | " mag_map = mag_map / max_mag\n", 265 | "\n", 266 | " # encoder outputs\n", 267 | " out_head = self.head(t.cat([color_map, mag_map, wb], dim=1))\n", 268 | " out_d1 = self.down1(out_head)\n", 269 | " out_d2 = self.down2(out_d1)\n", 270 | "\n", 271 | " # skip connection outputs\n", 272 | " out_s1 = self.skip1(mag_map)\n", 273 | " out_s2 = self.skip2(out_head)\n", 274 | " out_s3 = self.skip3(out_d1)\n", 275 | "\n", 276 | " # decoder outputs\n", 277 | " out_u1 = self.up1(t.cat([out_s3, out_d2], dim=1))\n", 278 | " out_u2 = self.up2(t.cat([out_s2, out_u1], dim=1))\n", 279 | " out_srgb = self.srgb_out(t.cat([out_s1, out_u2], dim=1))\n", 280 | " out_srgb = t.clamp(out_srgb, 0.0, 1.0)\n", 281 | "\n", 282 | " return out_srgb\n", 283 | "\n", 284 | "\n", 285 | "class rawProcess(BasicModule):\n", 286 | " def __init__(self):\n", 287 | " super(rawProcess, self).__init__()\n", 288 | " self.model_name = 'rawProcess'\n", 289 | "\n", 290 | " # isp module\n", 291 | " self.isp_net = ispNet()\n", 292 | "\n", 293 | " # fusion\n", 294 | " self.fusion = t.nn.Sequential(t.nn.Conv2d(6, 128, 3, padding=1),\n", 295 | " channelAtt(128),\n", 296 | " t.nn.Conv2d(128, 3, 3, padding=1))\n", 297 | "\n", 298 | " def forward(self, raw_data, amp_high, amp_low, wb):\n", 299 | " # convert to color map and mgnitude map\n", 300 | " mag_map = t.sqrt(t.sum(t.pow(raw_data, 2), 1, keepdim=True))\n", 301 | " color_map = raw_data / (mag_map + 1e-4)\n", 302 | "\n", 303 | " # convert to sRGB images\n", 304 | " out_high = self.isp_net(color_map, mag_map, amp_high, wb)\n", 305 | " out_low = self.isp_net(color_map, mag_map, amp_low, wb)\n", 306 | "\n", 307 | " # image fusion\n", 308 | " out_fused = self.fusion(t.cat([out_high, out_low], dim=1))\n", 309 | " out_fused = t.clamp(out_fused, 0.0, 1.0)\n", 310 | "\n", 311 | " return out_fused" 312 | ] 313 | }, 314 | { 315 | "cell_type": "markdown", 316 | "metadata": {}, 317 | "source": [ 318 | "## Test" 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": null, 324 | "metadata": { 325 | "code_folding": [ 326 | 1, 327 | 13, 328 | 34, 329 | 48, 330 | 158 331 | ] 332 | }, 333 | "outputs": [], 334 | "source": [ 335 | "# normalization\n", 336 | "def normalize(raw_data, bk_level, sat_level):\n", 337 | " normal_raw = t.empty_like(raw_data)\n", 338 | " for index in range(raw_data.size(0)):\n", 339 | " for channel in range(raw_data.size(1)):\n", 340 | " normal_raw[index, channel, :, :] = (\n", 341 | " raw_data[index, channel, :, :] -\n", 342 | " bk_level[channel]) / (sat_level - bk_level[channel])\n", 343 | "\n", 344 | " return normal_raw\n", 345 | "\n", 346 | "\n", 347 | "# resize Bayer pattern\n", 348 | "def downSample(raw_data, struct_img_size):\n", 349 | " # convert Bayer pattern to down-sized sRGB image\n", 350 | " batch, _, 
hei, wid = raw_data.size()\n", 351 | " raw_img = raw_data.new_empty((batch, 3, hei, wid))\n", 352 | " raw_img[:, 0, :, :] = raw_data[:, 0, :, :] # R\n", 353 | " raw_img[:,\n", 354 | " 1, :, :] = (raw_data[:, 1, :, :] + raw_data[:, 2, :, :]) / 2.0 # G\n", 355 | " raw_img[:, 2, :, :] = raw_data[:, 3, :, :] # B\n", 356 | "\n", 357 | " # down-sample to small size\n", 358 | " if hei != struct_img_size[1] and wid != struct_img_size[0]:\n", 359 | " raw_img = t.nn.functional.interpolate(raw_img,\n", 360 | " size=(struct_img_size[1],\n", 361 | " struct_img_size[0]),\n", 362 | " mode='bicubic')\n", 363 | " raw_img = t.clamp(raw_img, 0.0, 1.0)\n", 364 | "\n", 365 | " return raw_img\n", 366 | "\n", 367 | "\n", 368 | "# image standardization (mean 0, std 1)\n", 369 | "def standardize(srgb_img):\n", 370 | " struct_img = t.empty_like(srgb_img)\n", 371 | " adj_std = 1.0 / t.sqrt(srgb_img.new_tensor(srgb_img[0, :, :, :].numel()))\n", 372 | " for index in range(srgb_img.size(0)):\n", 373 | " mean = t.mean(srgb_img[index, :, :, :])\n", 374 | " std = t.std(srgb_img[index, :, :, :])\n", 375 | " adj_std = t.max(std, adj_std)\n", 376 | " struct_img[index, :, :, :] = (srgb_img[index, :, :, :] -\n", 377 | " mean) / adj_std\n", 378 | "\n", 379 | " return struct_img\n", 380 | "\n", 381 | "\n", 382 | "# main entry\n", 383 | "def main(args):\n", 384 | " # initialization\n", 385 | " att_size = (256, 192)\n", 386 | " amp_range = (1, 20)\n", 387 | "\n", 388 | " # choose GPU if available\n", 389 | " os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n", 390 | " os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n", 391 | " device = t.device(args.device)\n", 392 | "\n", 393 | " # define models\n", 394 | " gain_est_model = gainEst().to(device)\n", 395 | " gain_est_model.load('./saves', device=device)\n", 396 | " gain_est_model.eval()\n", 397 | " raw_process_model = rawProcess().to(device)\n", 398 | " raw_process_model.load('./saves', device=device)\n", 399 | " raw_process_model.eval()\n", 400 | "\n", 401 | " # search for valid files\n", 402 | " file_list = [file for file in os.listdir(args.input) if '.DNG' in file]\n", 403 | " file_list.sort()\n", 404 | " if not os.path.exists(args.output):\n", 405 | " os.makedirs(args.output)\n", 406 | "\n", 407 | " # loop to process\n", 408 | " for file in file_list:\n", 409 | " # read black, saturation, and whitebalance\n", 410 | " img_md = exiv2.ImageMetadata(os.path.join(args.input, file))\n", 411 | " img_md.read()\n", 412 | "\n", 413 | " blk_level = img_md['Exif.SubImage1.BlackLevel'].value\n", 414 | " sat_level = img_md['Exif.SubImage1.WhiteLevel'].value\n", 415 | " cam_wb = img_md['Exif.Image.AsShotNeutral'].value\n", 416 | "\n", 417 | " # convert flat Bayer pattern to 4D tensor (RGGB)\n", 418 | " raw_img = rp.imread(os.path.join(args.input, file))\n", 419 | " flat_bayer = raw_img.raw_image_visible\n", 420 | " raw_data = np.stack((flat_bayer[0::2, 0::2], flat_bayer[0::2, 1::2],\n", 421 | " flat_bayer[1::2, 0::2], flat_bayer[1::2, 1::2]),\n", 422 | " axis=2)\n", 423 | "\n", 424 | " with t.no_grad():\n", 425 | " # copy to device\n", 426 | " blk_level = t.from_numpy(np.array(blk_level,\n", 427 | " dtype=np.float32)).to(device)\n", 428 | " sat_level = t.from_numpy(np.array(sat_level,\n", 429 | " dtype=np.float32)).to(device)\n", 430 | " cam_wb = t.from_numpy(np.array(cam_wb,\n", 431 | " dtype=np.float32)).to(device)\n", 432 | " raw_data = t.from_numpy(raw_data.astype(np.float32)).to(device)\n", 433 | " raw_data = raw_data.permute(2, 0, 1).unsqueeze(0)\n", 434 | "\n", 435 | " # downsample\n", 
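"            # note: raw_data is the half-resolution RGGB stack and the network\n", "            # upsamples 2x, so --resize 600,800 gives a 1200x1600 sRGB output\n",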
436 | " if args.resize:\n", 437 | " raw_data = t.nn.functional.interpolate(raw_data,\n", 438 | " size=args.resize,\n", 439 | " mode='bicubic')\n", 440 | "\n", 441 | " # pre-processing\n", 442 | " raw_data = normalize(raw_data, blk_level, sat_level)\n", 443 | " cam_wb = cam_wb.view([1, 3, 1, 1]).expand(\n", 444 | " [1, 3, raw_data.size(2),\n", 445 | " raw_data.size(3)])\n", 446 | " cam_wb = cam_wb.clone()\n", 447 | " thumb_img = downSample(raw_data, att_size)\n", 448 | " struct_img = standardize(thumb_img)\n", 449 | "\n", 450 | " # run model\n", 451 | " _, pred_amp = gain_est_model(thumb_img, struct_img)\n", 452 | " pred_amp = t.clamp(pred_amp * amp_range[1], amp_range[0],\n", 453 | " amp_range[1])\n", 454 | " print('Predicted ratio(fg/bg) for %s: %.2f, %.2f.' %\n", 455 | " (file, pred_amp[0, 0], pred_amp[0, 1]))\n", 456 | " amp_high, _ = t.max(pred_amp, 1)\n", 457 | " amp_low, _ = t.min(pred_amp, 1)\n", 458 | " pred_fused = raw_process_model(raw_data, amp_high, amp_low, cam_wb)\n", 459 | "\n", 460 | " # save to images\n", 461 | " save_image(\n", 462 | " pred_fused.cpu().squeeze(),\n", 463 | " os.path.join(args.output,\n", 464 | " '%s' % file.replace('.DNG', '-fuse.png')))\n", 465 | "\n", 466 | " # fisheye lens calibration\n", 467 | " # modified from https://medium.com/@kennethjiang/calibrate-fisheye-lens-using-opencv-333b05afa0b0\n", 468 | " if args.calib:\n", 469 | " import cv2\n", 470 | "\n", 471 | " DIM = (4000, 3000)\n", 472 | " K = np.array([[1715.9053454852321, 0.0, 2025.0267134780845],\n", 473 | " [0.0, 1713.8092418955127, 1511.2242172068645],\n", 474 | " [0.0, 0.0, 1.0]])\n", 475 | " D = np.array([[0.21801544244553403], [0.011549797903321477],\n", 476 | " [-0.05436236262851618], [-0.01888678272481524]])\n", 477 | " img = cv2.imread(\n", 478 | " os.path.join(args.output,\n", 479 | " '%s' % file.replace('.DNG', '-fuse.png')))\n", 480 | " map1, map2 = cv2.fisheye.initUndistortRectifyMap(\n", 481 | " K, D, np.eye(3), K, DIM, cv2.CV_16SC2)\n", 482 | " calib_img = cv2.remap(img,\n", 483 | " map1,\n", 484 | " map2,\n", 485 | " interpolation=cv2.INTER_LINEAR,\n", 486 | " borderMode=cv2.BORDER_CONSTANT)\n", 487 | " cv2.imwrite(\n", 488 | " os.path.join(args.output,\n", 489 | " '%s' % file.replace('.DNG', '-calib.png')),\n", 490 | " calib_img)\n", 491 | "\n", 492 | "\n", 493 | "if __name__ == '__main__':\n", 494 | " parser = argparse.ArgumentParser()\n", 495 | " parser.add_argument('--input', default='./samples', help='input directory')\n", 496 | " parser.add_argument('--output',\n", 497 | " default='./results',\n", 498 | " help='output directory')\n", 499 | " parser.add_argument('--resize',\n", 500 | " default=None,\n", 501 | " type=tuple,\n", 502 | " help='downsample to smaller size (hxw)')\n", 503 | " parser.add_argument('--device',\n", 504 | " default='cpu',\n", 505 | " help='device to be used (cpu or cuda)')\n", 506 | " parser.add_argument('--calib',\n", 507 | " action='store_true',\n", 508 | " help='perform fisheye calibration')\n", 509 | " args = parser.parse_args()\n", 510 | " main(args)" 511 | ] 512 | } 513 | ], 514 | "metadata": { 515 | "kernelspec": { 516 | "display_name": "Python 3", 517 | "language": "python", 518 | "name": "python3" 519 | }, 520 | "language_info": { 521 | "codemirror_mode": { 522 | "name": "ipython", 523 | "version": 3 524 | }, 525 | "file_extension": ".py", 526 | "mimetype": "text/x-python", 527 | "name": "python", 528 | "nbconvert_exporter": "python", 529 | "pygments_lexer": "ipython3", 530 | "version": "3.8.3" 531 | } 532 | }, 533 | "nbformat": 4, 534 | 
"nbformat_minor": 2 535 | } 536 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # Network Testing 5 | 6 | # ## Includes 7 | 8 | # In[ ]: 9 | 10 | 11 | # mass includes 12 | import os, sys, argparse 13 | import numpy as np 14 | import pyexiv2 as exiv2 15 | import rawpy as rp 16 | import torch as t 17 | from torchvision.utils import save_image 18 | 19 | 20 | # ## Modules 21 | 22 | # In[ ]: 23 | 24 | 25 | class BasicModule(t.nn.Module): 26 | def __init__(self): 27 | super(BasicModule, self).__init__() 28 | self.model_name = str(type(self)) 29 | 30 | def load(self, root, device=None): 31 | save_list = [ 32 | file for file in os.listdir(root) 33 | if file.startswith(self.model_name) 34 | ] 35 | save_list.sort() 36 | file_path = os.path.join(root, save_list[-1]) 37 | state_dict = t.load(file_path, map_location=device) 38 | self.load_state_dict(t.load(file_path, map_location=device)) 39 | print('Weights loaded: %s' % file_path) 40 | 41 | return 42 | 43 | 44 | class channelAtt(BasicModule): 45 | def __init__(self, channels): 46 | super(channelAtt, self).__init__() 47 | 48 | # squeeze-excitation layer 49 | self.glb_pool = t.nn.AdaptiveAvgPool2d((1, 1)) 50 | self.squeeze_excite = t.nn.Sequential( 51 | t.nn.Linear(channels, int(channels / 16)), t.nn.LeakyReLU(0.2), 52 | t.nn.Linear(int(channels / 16), channels), t.nn.Sigmoid()) 53 | 54 | def forward(self, x): 55 | scale = self.glb_pool(x) 56 | scale = self.squeeze_excite(scale.squeeze()) 57 | x = scale.view((x.size(0), x.size(1), 1, 1)) * x 58 | 59 | return x 60 | 61 | 62 | class encode(BasicModule): 63 | def __init__(self, in_channels, out_channels, max_pool=True): 64 | super(encode, self).__init__() 65 | 66 | # features 67 | if max_pool: 68 | self.features = t.nn.Sequential( 69 | t.nn.MaxPool2d((2, 2)), 70 | t.nn.Conv2d(in_channels, out_channels, 3, padding=1), 71 | t.nn.LeakyReLU(0.2), 72 | t.nn.Conv2d(out_channels, out_channels, 3, padding=1), 73 | channelAtt(out_channels), t.nn.LeakyReLU(0.2)) 74 | else: 75 | self.features = t.nn.Sequential( 76 | t.nn.Conv2d(in_channels, out_channels, 3, padding=1), 77 | t.nn.LeakyReLU(0.2), 78 | t.nn.Conv2d(out_channels, out_channels, 3, padding=1), 79 | channelAtt(out_channels), t.nn.LeakyReLU(0.2)) 80 | 81 | def forward(self, x): 82 | x = self.features(x) 83 | 84 | return x 85 | 86 | 87 | class skipConn(BasicModule): 88 | def __init__(self, in_channels, out_channels, avg_pool=True): 89 | super(skipConn, self).__init__() 90 | 91 | # features 92 | if avg_pool: 93 | self.features = t.nn.Sequential( 94 | t.nn.AvgPool2d((2, 2)), 95 | t.nn.Conv2d(in_channels, out_channels, 1), 96 | channelAtt(out_channels), t.nn.Tanh()) 97 | else: 98 | self.features = t.nn.Sequential( 99 | t.nn.Conv2d(in_channels, out_channels, 1), 100 | channelAtt(out_channels), t.nn.Tanh()) 101 | 102 | def forward(self, x): 103 | x = self.features(x) 104 | 105 | return x 106 | 107 | 108 | class decode(BasicModule): 109 | def __init__(self, 110 | in_channels, 111 | inter_channels, 112 | out_channels, 113 | up_sample=True): 114 | super(decode, self).__init__() 115 | 116 | # features 117 | if up_sample: 118 | self.features = t.nn.Sequential( 119 | t.nn.Conv2d(in_channels, inter_channels, 1), 120 | t.nn.Conv2d(inter_channels, inter_channels, 3, padding=1), 121 | t.nn.LeakyReLU(0.2), 122 | t.nn.Upsample(scale_factor=2, mode='nearest'), 123 | t.nn.Conv2d(inter_channels, 
out_channels, 3, padding=1), 124 | channelAtt(out_channels), t.nn.LeakyReLU(0.2)) 125 | else: 126 | self.features = t.nn.Sequential( 127 | t.nn.Conv2d(in_channels, inter_channels, 1), 128 | t.nn.Conv2d(inter_channels, inter_channels, 3, padding=1), 129 | t.nn.LeakyReLU(0.2), 130 | t.nn.Conv2d(inter_channels, out_channels, 3, padding=1), 131 | t.nn.LeakyReLU(0.2)) 132 | 133 | def forward(self, x): 134 | x = self.features(x) 135 | 136 | return x 137 | 138 | 139 | class gainEst(BasicModule): 140 | def __init__(self): 141 | super(gainEst, self).__init__() 142 | self.model_name = 'gainEst' 143 | 144 | # encoders 145 | self.head = encode(3, 64, max_pool=False) 146 | self.down1 = encode(64, 96, max_pool=True) 147 | self.down2 = encode(96, 128, max_pool=True) 148 | self.down3 = encode(128, 192, max_pool=True) 149 | 150 | # bottleneck 151 | self.bottleneck = t.nn.Sequential( 152 | t.nn.MaxPool2d(2, 2), t.nn.Conv2d(192, 256, 3, padding=1), 153 | t.nn.LeakyReLU(0.2), t.nn.Conv2d(256, 256, 3, padding=1), 154 | t.nn.Upsample(scale_factor=2, mode='nearest'), 155 | t.nn.Conv2d(256, 192, 3, padding=1), channelAtt(192), 156 | t.nn.LeakyReLU(0.2)) 157 | 158 | # decoders 159 | self.up1 = decode(384, 384, 128, up_sample=True) 160 | self.up2 = decode(256, 256, 96, up_sample=True) 161 | self.up3 = decode(192, 192, 64, up_sample=True) 162 | self.seg_out = t.nn.Sequential(decode(128, 128, 64, up_sample=False), 163 | t.nn.Conv2d(64, 2, 1)) 164 | 165 | # external actication 166 | self.sigmoid = t.nn.Sigmoid() 167 | 168 | # prediction 169 | self.features = t.nn.Sequential( 170 | t.nn.Conv2d(5, 64, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2), 171 | t.nn.Conv2d(64, 96, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2), 172 | t.nn.Conv2d(96, 128, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2), 173 | t.nn.Conv2d(128, 192, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2), 174 | t.nn.Conv2d(192, 256, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2)) 175 | self.amp_out = t.nn.Sequential(t.nn.Linear(8 * 6 * 256, 176 | 128), t.nn.LeakyReLU(0.2), 177 | t.nn.Linear(128, 64), 178 | t.nn.LeakyReLU(0.2), t.nn.Linear(64, 2)) 179 | 180 | def forward(self, thumb_img, struct_img): 181 | # segmentation 182 | out_head = self.head(struct_img) 183 | out_d1 = self.down1(out_head) 184 | out_d2 = self.down2(out_d1) 185 | out_d3 = self.down3(out_d2) 186 | out_bottleneck = self.bottleneck(out_d3) 187 | out_u1 = self.up1(t.cat([out_d3, out_bottleneck], dim=1)) 188 | out_u2 = self.up2(t.cat([out_d2, out_u1], dim=1)) 189 | out_u3 = self.up3(t.cat([out_d1, out_u2], dim=1)) 190 | out_mask = self.seg_out(t.cat([out_head, out_u3], dim=1)) 191 | 192 | # prediction 193 | out_features = self.features( 194 | t.cat([thumb_img, self.sigmoid(out_mask)], dim=1)) 195 | out_amp = self.amp_out(out_features.view(out_features.size(0), -1)) 196 | out_amp = t.clamp(out_amp, 0.0, 1.0) 197 | 198 | return out_mask, out_amp 199 | 200 | 201 | class ispNet(BasicModule): 202 | def __init__(self): 203 | super(ispNet, self).__init__() 204 | 205 | # encoders 206 | self.head = encode(8, 64, max_pool=False) 207 | self.down1 = encode(64, 64, max_pool=True) 208 | self.down2 = encode(64, 64, max_pool=True) 209 | 210 | # skip connections 211 | self.skip1 = skipConn(1, 64, avg_pool=False) 212 | self.skip2 = skipConn(64, 64, avg_pool=True) 213 | self.skip3 = skipConn(64, 64, avg_pool=True) 214 | 215 | # decoders 216 | self.up1 = decode(128, 64, 64, up_sample=True) 217 | self.up2 = decode(128, 64, 64, up_sample=True) 218 | self.srgb_out = t.nn.Sequential( 219 | decode(128, 64, 64, 
up_sample=False), 220 | t.nn.Upsample(scale_factor=2, mode='nearest'), 221 | t.nn.Conv2d(64, 3, 3, padding=1)) 222 | 223 | def forward(self, color_map, mag_map, amp, wb): 224 | # to prevent saturation 225 | mag_map = amp.view(-1, 1, 1, 1) * mag_map 226 | mag_map = t.nn.functional.tanh(mag_map - 0.5) 227 | max_mag = 2.0 * amp.view(-1, 1, 1, 1) 228 | max_mag = t.nn.functional.tanh(max_mag - 0.5) 229 | mag_map = mag_map / max_mag 230 | 231 | # encoder outputs 232 | out_head = self.head(t.cat([color_map, mag_map, wb], dim=1)) 233 | out_d1 = self.down1(out_head) 234 | out_d2 = self.down2(out_d1) 235 | 236 | # skip connection outputs 237 | out_s1 = self.skip1(mag_map) 238 | out_s2 = self.skip2(out_head) 239 | out_s3 = self.skip3(out_d1) 240 | 241 | # decoder outputs 242 | out_u1 = self.up1(t.cat([out_s3, out_d2], dim=1)) 243 | out_u2 = self.up2(t.cat([out_s2, out_u1], dim=1)) 244 | out_srgb = self.srgb_out(t.cat([out_s1, out_u2], dim=1)) 245 | out_srgb = t.clamp(out_srgb, 0.0, 1.0) 246 | 247 | return out_srgb 248 | 249 | 250 | class rawProcess(BasicModule): 251 | def __init__(self): 252 | super(rawProcess, self).__init__() 253 | self.model_name = 'rawProcess' 254 | 255 | # isp module 256 | self.isp_net = ispNet() 257 | 258 | # fusion 259 | self.fusion = t.nn.Sequential(t.nn.Conv2d(6, 128, 3, padding=1), 260 | channelAtt(128), 261 | t.nn.Conv2d(128, 3, 3, padding=1)) 262 | 263 | def forward(self, raw_data, amp_high, amp_low, wb): 264 | # convert to color map and mgnitude map 265 | mag_map = t.sqrt(t.sum(t.pow(raw_data, 2), 1, keepdim=True)) 266 | color_map = raw_data / (mag_map + 1e-4) 267 | 268 | # convert to sRGB images 269 | out_high = self.isp_net(color_map, mag_map, amp_high, wb) 270 | out_low = self.isp_net(color_map, mag_map, amp_low, wb) 271 | 272 | # image fusion 273 | out_fused = self.fusion(t.cat([out_high, out_low], dim=1)) 274 | out_fused = t.clamp(out_fused, 0.0, 1.0) 275 | 276 | return out_fused 277 | 278 | 279 | # ## Test 280 | 281 | # In[ ]: 282 | 283 | 284 | # normalization 285 | def normalize(raw_data, bk_level, sat_level): 286 | normal_raw = t.empty_like(raw_data) 287 | for index in range(raw_data.size(0)): 288 | for channel in range(raw_data.size(1)): 289 | normal_raw[index, channel, :, :] = ( 290 | raw_data[index, channel, :, :] - 291 | bk_level[channel]) / (sat_level - bk_level[channel]) 292 | 293 | return normal_raw 294 | 295 | 296 | # resize Bayer pattern 297 | def downSample(raw_data, struct_img_size): 298 | # convert Bayer pattern to down-sized sRGB image 299 | batch, _, hei, wid = raw_data.size() 300 | raw_img = raw_data.new_empty((batch, 3, hei, wid)) 301 | raw_img[:, 0, :, :] = raw_data[:, 0, :, :] # R 302 | raw_img[:, 303 | 1, :, :] = (raw_data[:, 1, :, :] + raw_data[:, 2, :, :]) / 2.0 # G 304 | raw_img[:, 2, :, :] = raw_data[:, 3, :, :] # B 305 | 306 | # down-sample to small size 307 | if hei != struct_img_size[1] and wid != struct_img_size[0]: 308 | raw_img = t.nn.functional.interpolate(raw_img, 309 | size=(struct_img_size[1], 310 | struct_img_size[0]), 311 | mode='bicubic') 312 | raw_img = t.clamp(raw_img, 0.0, 1.0) 313 | 314 | return raw_img 315 | 316 | 317 | # image standardization (mean 0, std 1) 318 | def standardize(srgb_img): 319 | struct_img = t.empty_like(srgb_img) 320 | adj_std = 1.0 / t.sqrt(srgb_img.new_tensor(srgb_img[0, :, :, :].numel())) 321 | for index in range(srgb_img.size(0)): 322 | mean = t.mean(srgb_img[index, :, :, :]) 323 | std = t.std(srgb_img[index, :, :, :]) 324 | adj_std = t.max(std, adj_std) 325 | struct_img[index, :, :, :] = 
(srgb_img[index, :, :, :] - 326 | mean) / adj_std 327 | 328 | return struct_img 329 | 330 | 331 | # main entry 332 | def main(args): 333 | # initialization 334 | att_size = (256, 192) 335 | amp_range = (1, 20) 336 | 337 | # choose GPU if available 338 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 339 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 340 | device = t.device(args.device) 341 | 342 | # define models 343 | gain_est_model = gainEst().to(device) 344 | gain_est_model.load('./saves', device=device) 345 | gain_est_model.eval() 346 | raw_process_model = rawProcess().to(device) 347 | raw_process_model.load('./saves', device=device) 348 | raw_process_model.eval() 349 | 350 | # search for valid files 351 | file_list = [file for file in os.listdir(args.input) if '.DNG' in file] 352 | file_list.sort() 353 | if not os.path.exists(args.output): 354 | os.makedirs(args.output) 355 | 356 | # loop to process 357 | for file in file_list: 358 | # read black, saturation, and whitebalance 359 | img_md = exiv2.ImageMetadata(os.path.join(args.input, file)) 360 | img_md.read() 361 | 362 | blk_level = img_md['Exif.SubImage1.BlackLevel'].value 363 | sat_level = img_md['Exif.SubImage1.WhiteLevel'].value 364 | cam_wb = img_md['Exif.Image.AsShotNeutral'].value 365 | 366 | # convert flat Bayer pattern to 4D tensor (RGGB) 367 | raw_img = rp.imread(os.path.join(args.input, file)) 368 | flat_bayer = raw_img.raw_image_visible 369 | raw_data = np.stack((flat_bayer[0::2, 0::2], flat_bayer[0::2, 1::2], 370 | flat_bayer[1::2, 0::2], flat_bayer[1::2, 1::2]), 371 | axis=2) 372 | 373 | with t.no_grad(): 374 | # copy to device 375 | blk_level = t.from_numpy(np.array(blk_level, 376 | dtype=np.float32)).to(device) 377 | sat_level = t.from_numpy(np.array(sat_level, 378 | dtype=np.float32)).to(device) 379 | cam_wb = t.from_numpy(np.array(cam_wb, 380 | dtype=np.float32)).to(device) 381 | raw_data = t.from_numpy(raw_data.astype(np.float32)).to(device) 382 | raw_data = raw_data.permute(2, 0, 1).unsqueeze(0) 383 | 384 | # downsample 385 | if args.resize: 386 | raw_data = t.nn.functional.interpolate(raw_data, 387 | size=args.resize, 388 | mode='bicubic') 389 | 390 | # pre-processing 391 | raw_data = normalize(raw_data, blk_level, sat_level) 392 | cam_wb = cam_wb.view([1, 3, 1, 1]).expand( 393 | [1, 3, raw_data.size(2), 394 | raw_data.size(3)]) 395 | cam_wb = cam_wb.clone() 396 | thumb_img = downSample(raw_data, att_size) 397 | struct_img = standardize(thumb_img) 398 | 399 | # run model 400 | _, pred_amp = gain_est_model(thumb_img, struct_img) 401 | pred_amp = t.clamp(pred_amp * amp_range[1], amp_range[0], 402 | amp_range[1]) 403 | print('Predicted ratio(fg/bg) for %s: %.2f, %.2f.' 
% 404 | (file, pred_amp[0, 0], pred_amp[0, 1])) 405 | amp_high, _ = t.max(pred_amp, 1) 406 | amp_low, _ = t.min(pred_amp, 1) 407 | pred_fused = raw_process_model(raw_data, amp_high, amp_low, cam_wb) 408 | 409 | # save to images 410 | save_image( 411 | pred_fused.cpu().squeeze(), 412 | os.path.join(args.output, 413 | '%s' % file.replace('.DNG', '-fuse.png'))) 414 | 415 | # fisheye lens calibration 416 | # modified from https://medium.com/@kennethjiang/calibrate-fisheye-lens-using-opencv-333b05afa0b0 417 | if args.calib: 418 | import cv2 419 | 420 | DIM = (4000, 3000) 421 | K = np.array([[1715.9053454852321, 0.0, 2025.0267134780845], 422 | [0.0, 1713.8092418955127, 1511.2242172068645], 423 | [0.0, 0.0, 1.0]]) 424 | D = np.array([[0.21801544244553403], [0.011549797903321477], 425 | [-0.05436236262851618], [-0.01888678272481524]]) 426 | img = cv2.imread( 427 | os.path.join(args.output, 428 | '%s' % file.replace('.DNG', '-fuse.png'))) 429 | map1, map2 = cv2.fisheye.initUndistortRectifyMap( 430 | K, D, np.eye(3), K, DIM, cv2.CV_16SC2) 431 | calib_img = cv2.remap(img, 432 | map1, 433 | map2, 434 | interpolation=cv2.INTER_LINEAR, 435 | borderMode=cv2.BORDER_CONSTANT) 436 | cv2.imwrite( 437 | os.path.join(args.output, 438 | '%s' % file.replace('.DNG', '-calib.png')), 439 | calib_img) 440 | 441 | 442 | if __name__ == '__main__': 443 | parser = argparse.ArgumentParser() 444 | parser.add_argument('--input', default='./samples', help='input directory') 445 | parser.add_argument('--output', 446 | default='./results', 447 | help='output directory') 448 | parser.add_argument('--resize', 449 | default=None, 450 | type=tuple, 451 | help='downsample to smaller size (hxw)') 452 | parser.add_argument('--device', 453 | default='cpu', 454 | help='device to be used (cpu or cuda)') 455 | parser.add_argument('--calib', 456 | action='store_true', 457 | help='perform fisheye calibration') 458 | args = parser.parse_args() 459 | main(args) 460 | 461 | -------------------------------------------------------------------------------- /genValset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Generate Validation Set" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": { 13 | "heading_collapsed": true 14 | }, 15 | "source": [ 16 | "## Includes" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": { 23 | "code_folding": [ 24 | 0 25 | ], 26 | "hidden": true 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "# mass includes\n", 31 | "import os, sys, warnings\n", 32 | "import ipdb\n", 33 | "import pickle\n", 34 | "import torch as t\n", 35 | "import torchvision as tv\n", 36 | "from tqdm.notebook import tqdm\n", 37 | "\n", 38 | "# add paths for all sub-folders\n", 39 | "paths = [root for root, dirs, files in os.walk('.')]\n", 40 | "for item in paths:\n", 41 | " sys.path.append(item)\n", 42 | "\n", 43 | "from ipynb.fs.full.config import valConf\n", 44 | "from ipynb.fs.full.monitor import Visualizer\n", 45 | "from ipynb.fs.full.network import r2rNet\n", 46 | "from ipynb.fs.full.dataLoader import fivekNight\n", 47 | "from ipynb.fs.full.util import *" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": { 53 | "heading_collapsed": true 54 | }, 55 | "source": [ 56 | "## Initialization" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": { 63 | "code_folding": [ 64 | 0 65 | ], 66 | 
"hidden": true 67 | }, 68 | "outputs": [], 69 | "source": [ 70 | "# for debugging only\n", 71 | "%pdb off\n", 72 | "warnings.filterwarnings('ignore')\n", 73 | "\n", 74 | "# choose GPU if available\n", 75 | "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n", 76 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n", 77 | "device = t.device('cuda' if t.cuda.is_available() else 'cpu')\n", 78 | "\n", 79 | "# define model\n", 80 | "opt = valConf()\n", 81 | "converter = r2rNet().to(device)\n", 82 | "converter.load('./saves')\n", 83 | "converter.eval()\n", 84 | "\n", 85 | "# dataloader for training\n", 86 | "val_dataset = fivekNight(opt)\n", 87 | "val_loader = t.utils.data.DataLoader(val_dataset, shuffle=True)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "## Training entry" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "metadata": { 101 | "code_folding": [ 102 | 0 103 | ] 104 | }, 105 | "outputs": [], 106 | "source": [ 107 | "# make new folder\n", 108 | "save_path = os.path.join(opt.data_root, 'val%d' % opt.amp_range[1])\n", 109 | "if not os.path.exists(save_path):\n", 110 | " os.makedirs(save_path)\n", 111 | "\n", 112 | "for index, (syth_img, syth_mask) in tqdm(enumerate(val_loader), total=50):\n", 113 | " # copy to device\n", 114 | " syth_img = syth_img.to(device)\n", 115 | " syth_mask = syth_mask.to(device)\n", 116 | "\n", 117 | " # convert to training samples\n", 118 | " thumb_img, struct_img, seg_mask, amp, noisy_raw, sorted_mask, wb = toRaw(\n", 119 | " converter, syth_img, syth_mask, opt)\n", 120 | "\n", 121 | " # save to files\n", 122 | " file_path = os.path.join(save_path, 'img%04d.jpg' % (index))\n", 123 | " data_dict = {}\n", 124 | " data_dict['syth_img'] = syth_img.squeeze().cpu()\n", 125 | " data_dict['thumb_img'] = thumb_img.squeeze().cpu()\n", 126 | " data_dict['struct_img'] = struct_img.squeeze().cpu()\n", 127 | " data_dict['seg_mask'] = seg_mask.squeeze().cpu()\n", 128 | " data_dict['amp'] = amp.squeeze().cpu()\n", 129 | " data_dict['noisy_raw'] = noisy_raw.squeeze().cpu()\n", 130 | " data_dict['sorted_mask'] = sorted_mask.squeeze().cpu()\n", 131 | " data_dict['wb'] = wb.squeeze().cpu()\n", 132 | " with open(file_path.replace('jpg', 'pkl'), 'wb') as pkl_file:\n", 133 | " pickle.dump(data_dict, pkl_file)\n", 134 | " tv.utils.save_image(syth_img.squeeze(), file_path)\n", 135 | "\n", 136 | " if index >= 50:\n", 137 | " break" 138 | ] 139 | } 140 | ], 141 | "metadata": { 142 | "kernelspec": { 143 | "display_name": "Python 3", 144 | "language": "python", 145 | "name": "python3" 146 | }, 147 | "language_info": { 148 | "codemirror_mode": { 149 | "name": "ipython", 150 | "version": 3 151 | }, 152 | "file_extension": ".py", 153 | "mimetype": "text/x-python", 154 | "name": "python", 155 | "nbconvert_exporter": "python", 156 | "pygments_lexer": "ipython3", 157 | "version": "3.8.3" 158 | }, 159 | "toc": { 160 | "nav_menu": {}, 161 | "number_sections": true, 162 | "sideBar": true, 163 | "skip_h1_title": true, 164 | "title_cell": "Table of Contents", 165 | "title_sidebar": "Contents", 166 | "toc_cell": false, 167 | "toc_position": {}, 168 | "toc_section_display": true, 169 | "toc_window_display": false 170 | }, 171 | "varInspector": { 172 | "cols": { 173 | "lenName": 16, 174 | "lenType": 16, 175 | "lenVar": 40 176 | }, 177 | "kernels_config": { 178 | "python": { 179 | "delete_cmd_postfix": "", 180 | "delete_cmd_prefix": "del ", 181 | "library": "var_list.py", 182 | "varRefreshCmd": "print(var_dic_list())" 
183 | }, 184 | "r": { 185 | "delete_cmd_postfix": ") ", 186 | "delete_cmd_prefix": "rm(", 187 | "library": "var_list.r", 188 | "varRefreshCmd": "cat(var_dic_list()) " 189 | } 190 | }, 191 | "types_to_exclude": [ 192 | "module", 193 | "function", 194 | "builtin_function_or_method", 195 | "instance", 196 | "_Feature" 197 | ], 198 | "window_display": false 199 | } 200 | }, 201 | "nbformat": 4, 202 | "nbformat_minor": 2 203 | } 204 | -------------------------------------------------------------------------------- /models/module.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Basic Model Manipulations" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": { 13 | "heading_collapsed": true 14 | }, 15 | "source": [ 16 | "## Includes" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": { 23 | "code_folding": [ 24 | 0 25 | ], 26 | "hidden": true 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "# mass includes\n", 31 | "import os\n", 32 | "import time\n", 33 | "import torch as t\n", 34 | "import math as m" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": { 40 | "heading_collapsed": true 41 | }, 42 | "source": [ 43 | "## Basic methods" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": { 50 | "code_folding": [ 51 | 0 52 | ], 53 | "hidden": true 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "class BasicModule(t.nn.Module):\n", 58 | " def __init__(self):\n", 59 | " super(BasicModule, self).__init__()\n", 60 | " self.model_name = str(type(self))\n", 61 | "\n", 62 | " def load(self, root, device=None):\n", 63 | " save_list = [\n", 64 | " file for file in os.listdir(root)\n", 65 | " if file.startswith(self.model_name)\n", 66 | " ]\n", 67 | " save_list.sort()\n", 68 | " file_path = os.path.join(root, save_list[-1])\n", 69 | " state_dict = t.load(file_path, map_location=device)\n", 70 | " self.load_state_dict(t.load(file_path, map_location=device))\n", 71 | " print('Weights loaded: %s' % file_path)\n", 72 | "\n", 73 | " return len(save_list)\n", 74 | "\n", 75 | " def loadPartialDict(self, file_path, device=None):\n", 76 | " pretrained_dict = t.load(file_path, map_location=device)\n", 77 | " model_dict = self.state_dict()\n", 78 | " pretrained_dict = {\n", 79 | " key: value\n", 80 | " for key, value in pretrained_dict.items() if key in model_dict\n", 81 | " }\n", 82 | " model_dict.update(pretrained_dict)\n", 83 | " self.load_state_dict(model_dict)\n", 84 | " print('Partial weights loaded: %s' % file_path)\n", 85 | "\n", 86 | " def save(self):\n", 87 | " prefix = './saves/' + self.model_name + '_'\n", 88 | " file_name = time.strftime(prefix + '%m%d-%H%M%S.pth')\n", 89 | " t.save(self.state_dict(), file_name)\n", 90 | " print('Weights saved: %s' % file_name)\n", 91 | "\n", 92 | " def initLayers(self):\n", 93 | " for module in self.modules():\n", 94 | " if isinstance(module, (t.nn.Conv2d, t.nn.Linear)):\n", 95 | " t.nn.init.xavier_normal_(module.weight)" 96 | ] 97 | } 98 | ], 99 | "metadata": { 100 | "kernelspec": { 101 | "display_name": "Python 3", 102 | "language": "python", 103 | "name": "python3" 104 | }, 105 | "language_info": { 106 | "codemirror_mode": { 107 | "name": "ipython", 108 | "version": 3 109 | }, 110 | "file_extension": ".py", 111 | "mimetype": "text/x-python", 112 | "name": "python", 113 | "nbconvert_exporter": "python", 114 | 
"pygments_lexer": "ipython3", 115 | "version": "3.6.9" 116 | }, 117 | "toc": { 118 | "nav_menu": {}, 119 | "number_sections": true, 120 | "sideBar": true, 121 | "skip_h1_title": true, 122 | "title_cell": "Table of Contents", 123 | "title_sidebar": "Contents", 124 | "toc_cell": false, 125 | "toc_position": {}, 126 | "toc_section_display": true, 127 | "toc_window_display": false 128 | }, 129 | "varInspector": { 130 | "cols": { 131 | "lenName": 16, 132 | "lenType": 16, 133 | "lenVar": 40 134 | }, 135 | "kernels_config": { 136 | "python": { 137 | "delete_cmd_postfix": "", 138 | "delete_cmd_prefix": "del ", 139 | "library": "var_list.py", 140 | "varRefreshCmd": "print(var_dic_list())" 141 | }, 142 | "r": { 143 | "delete_cmd_postfix": ") ", 144 | "delete_cmd_prefix": "rm(", 145 | "library": "var_list.r", 146 | "varRefreshCmd": "cat(var_dic_list()) " 147 | } 148 | }, 149 | "types_to_exclude": [ 150 | "module", 151 | "function", 152 | "builtin_function_or_method", 153 | "instance", 154 | "_Feature" 155 | ], 156 | "window_display": false 157 | } 158 | }, 159 | "nbformat": 4, 160 | "nbformat_minor": 2 161 | } 162 | -------------------------------------------------------------------------------- /models/network.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# User Defined Network" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": { 13 | "heading_collapsed": true 14 | }, 15 | "source": [ 16 | "## Includes" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": { 23 | "code_folding": [ 24 | 0 25 | ], 26 | "hidden": true 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "# mass includes\n", 31 | "import math as m\n", 32 | "import torch as t\n", 33 | "from collections import OrderedDict\n", 34 | "from ipynb.fs.full.module import BasicModule" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "## Modules" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": { 48 | "code_folding": [ 49 | 0, 50 | 30, 51 | 48, 52 | 73, 53 | 94 54 | ] 55 | }, 56 | "outputs": [], 57 | "source": [ 58 | "class resDense(BasicModule):\n", 59 | " def __init__(self, channels, layers, growth_rate):\n", 60 | " super(resDense, self).__init__()\n", 61 | "\n", 62 | " # residual dense layers\n", 63 | " self.features = t.nn.ModuleList([])\n", 64 | " inter_channels = channels\n", 65 | " for index in range(0, layers):\n", 66 | " self.features.append(\n", 67 | " t.nn.Conv2d(inter_channels,\n", 68 | " growth_rate,\n", 69 | " 3,\n", 70 | " padding=1,\n", 71 | " bias=False))\n", 72 | " inter_channels += growth_rate\n", 73 | " self.relu = t.nn.ReLU()\n", 74 | "\n", 75 | " # fusion layer\n", 76 | " self.fusion = t.nn.Conv2d(inter_channels, channels, 1, bias=False)\n", 77 | "\n", 78 | " def forward(self, x):\n", 79 | " res = x\n", 80 | " for layer in self.features:\n", 81 | " out = self.relu(layer(res))\n", 82 | " res = t.cat([res, out], dim=1)\n", 83 | " out = self.fusion(res)\n", 84 | "\n", 85 | " return x + out\n", 86 | "\n", 87 | "\n", 88 | "class channelAtt(BasicModule):\n", 89 | " def __init__(self, channels):\n", 90 | " super(channelAtt, self).__init__()\n", 91 | "\n", 92 | " # squeeze-excitation layer\n", 93 | " self.glb_pool = t.nn.AdaptiveAvgPool2d((1, 1))\n", 94 | " self.squeeze_excite = t.nn.Sequential(\n", 95 | " t.nn.Linear(channels, int(channels / 16)), 
t.nn.LeakyReLU(0.2),\n", 96 | " t.nn.Linear(int(channels / 16), channels), t.nn.Sigmoid())\n", 97 | "\n", 98 | " def forward(self, x):\n", 99 | " scale = self.glb_pool(x)\n", 100 | " scale = self.squeeze_excite(scale.squeeze())\n", 101 | " x = scale.view((x.size(0), x.size(1), 1, 1)) * x\n", 102 | "\n", 103 | " return x\n", 104 | "\n", 105 | "\n", 106 | "class encode(BasicModule):\n", 107 | " def __init__(self, in_channels, out_channels, max_pool=True):\n", 108 | " super(encode, self).__init__()\n", 109 | "\n", 110 | " # features\n", 111 | " if max_pool:\n", 112 | " self.features = t.nn.Sequential(\n", 113 | " t.nn.MaxPool2d((2, 2)),\n", 114 | " t.nn.Conv2d(in_channels, out_channels, 3, padding=1),\n", 115 | " t.nn.LeakyReLU(0.2),\n", 116 | " t.nn.Conv2d(out_channels, out_channels, 3, padding=1),\n", 117 | " channelAtt(out_channels), t.nn.LeakyReLU(0.2))\n", 118 | " else:\n", 119 | " self.features = t.nn.Sequential(\n", 120 | " t.nn.Conv2d(in_channels, out_channels, 3, padding=1),\n", 121 | " t.nn.LeakyReLU(0.2),\n", 122 | " t.nn.Conv2d(out_channels, out_channels, 3, padding=1),\n", 123 | " channelAtt(out_channels), t.nn.LeakyReLU(0.2))\n", 124 | "\n", 125 | " def forward(self, x):\n", 126 | " x = self.features(x)\n", 127 | "\n", 128 | " return x\n", 129 | "\n", 130 | "\n", 131 | "class skipConn(BasicModule):\n", 132 | " def __init__(self, in_channels, out_channels, avg_pool=True):\n", 133 | " super(skipConn, self).__init__()\n", 134 | "\n", 135 | " # features\n", 136 | " if avg_pool:\n", 137 | " self.features = t.nn.Sequential(\n", 138 | " t.nn.AvgPool2d((2, 2)),\n", 139 | " t.nn.Conv2d(in_channels, out_channels, 1),\n", 140 | " channelAtt(out_channels), t.nn.Tanh())\n", 141 | " else:\n", 142 | " self.features = t.nn.Sequential(\n", 143 | " t.nn.Conv2d(in_channels, out_channels, 1),\n", 144 | " channelAtt(out_channels), t.nn.Tanh())\n", 145 | "\n", 146 | " def forward(self, x):\n", 147 | " x = self.features(x)\n", 148 | "\n", 149 | " return x\n", 150 | "\n", 151 | "\n", 152 | "class decode(BasicModule):\n", 153 | " def __init__(self,\n", 154 | " in_channels,\n", 155 | " inter_channels,\n", 156 | " out_channels,\n", 157 | " up_sample=True):\n", 158 | " super(decode, self).__init__()\n", 159 | "\n", 160 | " # features\n", 161 | " if up_sample:\n", 162 | " self.features = t.nn.Sequential(\n", 163 | " t.nn.Conv2d(in_channels, inter_channels, 1),\n", 164 | " t.nn.Conv2d(inter_channels, inter_channels, 3, padding=1),\n", 165 | " t.nn.LeakyReLU(0.2),\n", 166 | " t.nn.Upsample(scale_factor=2, mode='nearest'),\n", 167 | " t.nn.Conv2d(inter_channels, out_channels, 3, padding=1),\n", 168 | " channelAtt(out_channels), t.nn.LeakyReLU(0.2))\n", 169 | " else:\n", 170 | " self.features = t.nn.Sequential(\n", 171 | " t.nn.Conv2d(in_channels, inter_channels, 1),\n", 172 | " t.nn.Conv2d(inter_channels, inter_channels, 3, padding=1),\n", 173 | " t.nn.LeakyReLU(0.2),\n", 174 | " t.nn.Conv2d(inter_channels, out_channels, 3, padding=1),\n", 175 | " t.nn.LeakyReLU(0.2))\n", 176 | "\n", 177 | " def forward(self, x):\n", 178 | " x = self.features(x)\n", 179 | "\n", 180 | " return x" 181 | ] 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "metadata": {}, 186 | "source": [ 187 | "## r2rNet" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": null, 193 | "metadata": { 194 | "code_folding": [ 195 | 0 196 | ] 197 | }, 198 | "outputs": [], 199 | "source": [ 200 | "class r2rNet(BasicModule):\n", 201 | " def __init__(self, channels=64, rdbs=4, convs=8, growth_rate=32):\n", 202 | " 
super(r2rNet, self).__init__()\n", 203 | " self.model_name = 'r2rNet'\n", 204 | " self.rdbs = rdbs\n", 205 | "\n", 206 | " # feature extraction\n", 207 | " self.head = t.nn.Conv2d(7, channels, 3, padding=1)\n", 208 | "\n", 209 | " # RDBs\n", 210 | " self.features = t.nn.ModuleList(\n", 211 | " [t.nn.Conv2d(channels, channels, 3, padding=1)])\n", 212 | " for index in range(0, rdbs):\n", 213 | " self.features.append(resDense(channels, convs, growth_rate))\n", 214 | " self.features.append(t.nn.Conv2d(channels * rdbs, channels, 1))\n", 215 | " self.features.append(t.nn.Conv2d(channels, channels, 3, padding=1))\n", 216 | "\n", 217 | " # final fusion\n", 218 | " self.final = t.nn.Sequential(\n", 219 | " t.nn.Conv2d(channels, channels, 3, padding=1),\n", 220 | " t.nn.Conv2d(channels, 4, 1))\n", 221 | "\n", 222 | " def forward(self, img, wb):\n", 223 | " out = self.head(t.cat([img, wb], dim=1))\n", 224 | " res_0 = self.features[0](out)\n", 225 | " res_n = [res_0]\n", 226 | " for index in range(1, self.rdbs + 1):\n", 227 | " res_n.append(self.features[index](res_n[index - 1]))\n", 228 | " res_r2 = self.features[-2](t.cat(res_n[1:], dim=1))\n", 229 | " res_r1 = self.features[-1](res_r2)\n", 230 | " out = self.final(out + res_r1)\n", 231 | "\n", 232 | " return out" 233 | ] 234 | }, 235 | { 236 | "cell_type": "markdown", 237 | "metadata": {}, 238 | "source": [ 239 | "## Gain estimation module" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": null, 245 | "metadata": { 246 | "code_folding": [] 247 | }, 248 | "outputs": [], 249 | "source": [ 250 | "class gainEst(BasicModule):\n", 251 | " def __init__(self):\n", 252 | " super(gainEst, self).__init__()\n", 253 | " self.model_name = 'gainEst'\n", 254 | "\n", 255 | " # encoders\n", 256 | " self.head = encode(3, 64, max_pool=False)\n", 257 | " self.down1 = encode(64, 96, max_pool=True)\n", 258 | " self.down2 = encode(96, 128, max_pool=True)\n", 259 | " self.down3 = encode(128, 192, max_pool=True)\n", 260 | "\n", 261 | " # bottleneck\n", 262 | " self.bottleneck = t.nn.Sequential(\n", 263 | " t.nn.MaxPool2d(2, 2), t.nn.Conv2d(192, 256, 3, padding=1),\n", 264 | " t.nn.LeakyReLU(0.2), t.nn.Conv2d(256, 256, 3, padding=1),\n", 265 | " t.nn.Upsample(scale_factor=2, mode='nearest'),\n", 266 | " t.nn.Conv2d(256, 192, 3, padding=1), channelAtt(192),\n", 267 | " t.nn.LeakyReLU(0.2))\n", 268 | "\n", 269 | " # decoders\n", 270 | " self.up1 = decode(384, 384, 128, up_sample=True)\n", 271 | " self.up2 = decode(256, 256, 96, up_sample=True)\n", 272 | " self.up3 = decode(192, 192, 64, up_sample=True)\n", 273 | " self.seg_out = t.nn.Sequential(decode(128, 128, 64, up_sample=False),\n", 274 | " t.nn.Conv2d(64, 2, 1))\n", 275 | "\n", 276 | " # external actication\n", 277 | " self.sigmoid = t.nn.Sigmoid()\n", 278 | "\n", 279 | " # prediction\n", 280 | " self.features = t.nn.Sequential(\n", 281 | " t.nn.Conv2d(5, 64, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2),\n", 282 | " t.nn.Conv2d(64, 96, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2),\n", 283 | " t.nn.Conv2d(96, 128, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2),\n", 284 | " t.nn.Conv2d(128, 192, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2),\n", 285 | " t.nn.Conv2d(192, 256, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2))\n", 286 | " self.amp_out = t.nn.Sequential(t.nn.Linear(8 * 6 * 256,\n", 287 | " 128), t.nn.LeakyReLU(0.2),\n", 288 | " t.nn.Linear(128, 64),\n", 289 | " t.nn.LeakyReLU(0.2), t.nn.Linear(64, 2))\n", 290 | "\n", 291 | " # initialization\n", 292 | " self.initLayers()\n", 293 | "\n", 294 
| " def forward(self, thumb_img, struct_img):\n", 295 | " # segmentation\n", 296 | " out_head = self.head(struct_img)\n", 297 | " out_d1 = self.down1(out_head)\n", 298 | " out_d2 = self.down2(out_d1)\n", 299 | " out_d3 = self.down3(out_d2)\n", 300 | " out_bottleneck = self.bottleneck(out_d3)\n", 301 | " out_u1 = self.up1(t.cat([out_d3, out_bottleneck], dim=1))\n", 302 | " out_u2 = self.up2(t.cat([out_d2, out_u1], dim=1))\n", 303 | " out_u3 = self.up3(t.cat([out_d1, out_u2], dim=1))\n", 304 | " out_mask = self.seg_out(t.cat([out_head, out_u3], dim=1))\n", 305 | "\n", 306 | " # prediction\n", 307 | " out_features = self.features(\n", 308 | " t.cat([thumb_img, self.sigmoid(out_mask)], dim=1))\n", 309 | " out_amp = self.amp_out(out_features.view(out_features.size(0), -1))\n", 310 | " out_amp = t.clamp(out_amp, 0.0, 1.0)\n", 311 | "\n", 312 | " return out_mask, out_amp" 313 | ] 314 | }, 315 | { 316 | "cell_type": "markdown", 317 | "metadata": {}, 318 | "source": [ 319 | "## Raw processing module" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": null, 325 | "metadata": { 326 | "code_folding": [ 327 | 49 328 | ] 329 | }, 330 | "outputs": [], 331 | "source": [ 332 | "class ispNet(BasicModule):\n", 333 | " def __init__(self):\n", 334 | " super(ispNet, self).__init__()\n", 335 | "\n", 336 | " # encoders\n", 337 | " self.head = encode(8, 64, max_pool=False)\n", 338 | " self.down1 = encode(64, 64, max_pool=True)\n", 339 | " self.down2 = encode(64, 64, max_pool=True)\n", 340 | "\n", 341 | " # skip connections\n", 342 | " self.skip1 = skipConn(1, 64, avg_pool=False)\n", 343 | " self.skip2 = skipConn(64, 64, avg_pool=True)\n", 344 | " self.skip3 = skipConn(64, 64, avg_pool=True)\n", 345 | "\n", 346 | " # decoders\n", 347 | " self.up1 = decode(128, 64, 64, up_sample=True)\n", 348 | " self.up2 = decode(128, 64, 64, up_sample=True)\n", 349 | " self.srgb_out = t.nn.Sequential(\n", 350 | " decode(128, 64, 64, up_sample=False),\n", 351 | " t.nn.Upsample(scale_factor=2, mode='nearest'),\n", 352 | " t.nn.Conv2d(64, 3, 3, padding=1))\n", 353 | "\n", 354 | " def forward(self, color_map, mag_map, amp, wb):\n", 355 | " # to prevent saturation\n", 356 | " mag_map = amp.view(-1, 1, 1, 1) * mag_map\n", 357 | " mag_map = t.nn.functional.tanh(mag_map - 0.5)\n", 358 | " max_mag = 2.0 * amp.view(-1, 1, 1, 1)\n", 359 | " max_mag = t.nn.functional.tanh(max_mag - 0.5)\n", 360 | " mag_map = mag_map / max_mag\n", 361 | "\n", 362 | " # encoder outputs\n", 363 | " out_head = self.head(t.cat([color_map, mag_map, wb], dim=1))\n", 364 | " out_d1 = self.down1(out_head)\n", 365 | " out_d2 = self.down2(out_d1)\n", 366 | "\n", 367 | " # skip connection outputs\n", 368 | " out_s1 = self.skip1(mag_map)\n", 369 | " out_s2 = self.skip2(out_head)\n", 370 | " out_s3 = self.skip3(out_d1)\n", 371 | "\n", 372 | " # decoder outputs\n", 373 | " out_u1 = self.up1(t.cat([out_s3, out_d2], dim=1))\n", 374 | " out_u2 = self.up2(t.cat([out_s2, out_u1], dim=1))\n", 375 | " out_srgb = self.srgb_out(t.cat([out_s1, out_u2], dim=1))\n", 376 | " out_srgb = t.clamp(out_srgb, 0.0, 1.0)\n", 377 | "\n", 378 | " return out_srgb\n", 379 | "\n", 380 | "\n", 381 | "class rawProcess(BasicModule):\n", 382 | " def __init__(self):\n", 383 | " super(rawProcess, self).__init__()\n", 384 | " self.model_name = 'rawProcess'\n", 385 | "\n", 386 | " # isp module\n", 387 | " self.isp_net = ispNet()\n", 388 | "\n", 389 | " # fusion\n", 390 | " self.fusion = t.nn.Sequential(t.nn.Conv2d(6, 128, 3, padding=1),\n", 391 | " channelAtt(128),\n", 392 | " 
t.nn.Conv2d(128, 3, 3, padding=1))\n", 393 | "\n", 394 | " # initialization\n", 395 | " self.initLayers()\n", 396 | "\n", 397 | " def forward(self, raw_data, amp_high, amp_low, wb):\n", 398 | " # convert to color map and mgnitude map\n", 399 | " mag_map = t.sqrt(t.sum(t.pow(raw_data, 2), 1, keepdim=True))\n", 400 | " color_map = raw_data / (mag_map + 1e-4)\n", 401 | "\n", 402 | " # convert to sRGB images\n", 403 | " out_high = self.isp_net(color_map, mag_map, amp_high, wb)\n", 404 | " out_low = self.isp_net(color_map, mag_map, amp_low, wb)\n", 405 | "\n", 406 | " # image fusion\n", 407 | " out_fused = self.fusion(t.cat([out_high, out_low], dim=1))\n", 408 | " out_fused = t.clamp(out_fused, 0.0, 1.0)\n", 409 | "\n", 410 | " return out_high, out_low, out_fused" 411 | ] 412 | } 413 | ], 414 | "metadata": { 415 | "kernelspec": { 416 | "display_name": "Python 3", 417 | "language": "python", 418 | "name": "python3" 419 | }, 420 | "language_info": { 421 | "codemirror_mode": { 422 | "name": "ipython", 423 | "version": 3 424 | }, 425 | "file_extension": ".py", 426 | "mimetype": "text/x-python", 427 | "name": "python", 428 | "nbconvert_exporter": "python", 429 | "pygments_lexer": "ipython3", 430 | "version": "3.8.3" 431 | } 432 | }, 433 | "nbformat": 4, 434 | "nbformat_minor": 2 435 | } 436 | -------------------------------------------------------------------------------- /samples/DJI_0899.DNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YCL92/DeepSelfie/6ca0c0872b8f1ec1f5e784e4630094e2d035fcf6/samples/DJI_0899.DNG -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | # # Network Testing 5 | 6 | # ## Includes 7 | 8 | # In[ ]: 9 | 10 | 11 | # mass includes 12 | import os, sys, argparse 13 | import numpy as np 14 | import pyexiv2 as exiv2 15 | import rawpy as rp 16 | import torch as t 17 | from torchvision.utils import save_image 18 | 19 | 20 | # ## Modules 21 | 22 | # In[ ]: 23 | 24 | 25 | class BasicModule(t.nn.Module): 26 | def __init__(self): 27 | super(BasicModule, self).__init__() 28 | self.model_name = str(type(self)) 29 | 30 | def load(self, root, device=None): 31 | save_list = [ 32 | file for file in os.listdir(root) 33 | if file.startswith(self.model_name) 34 | ] 35 | save_list.sort() 36 | file_path = os.path.join(root, save_list[-1]) 37 | state_dict = t.load(file_path, map_location=device) 38 | self.load_state_dict(t.load(file_path, map_location=device)) 39 | print('Weights loaded: %s' % file_path) 40 | 41 | return save_list[-1].split('_')[-1][:-4] 42 | 43 | 44 | class channelAtt(BasicModule): 45 | def __init__(self, channels): 46 | super(channelAtt, self).__init__() 47 | 48 | # squeeze-excitation layer 49 | self.glb_pool = t.nn.AdaptiveAvgPool2d((1, 1)) 50 | self.squeeze_excite = t.nn.Sequential( 51 | t.nn.Linear(channels, int(channels / 16)), t.nn.LeakyReLU(0.2), 52 | t.nn.Linear(int(channels / 16), channels), t.nn.Sigmoid()) 53 | 54 | def forward(self, x): 55 | scale = self.glb_pool(x) 56 | scale = self.squeeze_excite(scale.squeeze()) 57 | x = scale.view((x.size(0), x.size(1), 1, 1)) * x 58 | 59 | return x 60 | 61 | 62 | class encode(BasicModule): 63 | def __init__(self, in_channels, out_channels, max_pool=True): 64 | super(encode, self).__init__() 65 | 66 | # features 67 | if max_pool: 68 | self.features = t.nn.Sequential( 69 | t.nn.MaxPool2d((2, 
2)), 70 | t.nn.Conv2d(in_channels, out_channels, 3, padding=1), 71 | t.nn.LeakyReLU(0.2), 72 | t.nn.Conv2d(out_channels, out_channels, 3, padding=1), 73 | channelAtt(out_channels), t.nn.LeakyReLU(0.2)) 74 | else: 75 | self.features = t.nn.Sequential( 76 | t.nn.Conv2d(in_channels, out_channels, 3, padding=1), 77 | t.nn.LeakyReLU(0.2), 78 | t.nn.Conv2d(out_channels, out_channels, 3, padding=1), 79 | channelAtt(out_channels), t.nn.LeakyReLU(0.2)) 80 | 81 | def forward(self, x): 82 | x = self.features(x) 83 | 84 | return x 85 | 86 | 87 | class skipConn(BasicModule): 88 | def __init__(self, in_channels, out_channels, avg_pool=True): 89 | super(skipConn, self).__init__() 90 | 91 | # features 92 | if avg_pool: 93 | self.features = t.nn.Sequential( 94 | t.nn.AvgPool2d((2, 2)), 95 | t.nn.Conv2d(in_channels, out_channels, 1), 96 | channelAtt(out_channels), t.nn.Tanh()) 97 | else: 98 | self.features = t.nn.Sequential( 99 | t.nn.Conv2d(in_channels, out_channels, 1), 100 | channelAtt(out_channels), t.nn.Tanh()) 101 | 102 | def forward(self, x): 103 | x = self.features(x) 104 | 105 | return x 106 | 107 | 108 | class decode(BasicModule): 109 | def __init__(self, 110 | in_channels, 111 | inter_channels, 112 | out_channels, 113 | up_sample=True): 114 | super(decode, self).__init__() 115 | 116 | # features 117 | if up_sample: 118 | self.features = t.nn.Sequential( 119 | t.nn.Conv2d(in_channels, inter_channels, 1), 120 | t.nn.Conv2d(inter_channels, inter_channels, 3, padding=1), 121 | t.nn.LeakyReLU(0.2), 122 | t.nn.Upsample(scale_factor=2, mode='nearest'), 123 | t.nn.Conv2d(inter_channels, out_channels, 3, padding=1), 124 | channelAtt(out_channels), t.nn.LeakyReLU(0.2)) 125 | else: 126 | self.features = t.nn.Sequential( 127 | t.nn.Conv2d(in_channels, inter_channels, 1), 128 | t.nn.Conv2d(inter_channels, inter_channels, 3, padding=1), 129 | t.nn.LeakyReLU(0.2), 130 | t.nn.Conv2d(inter_channels, out_channels, 3, padding=1), 131 | t.nn.LeakyReLU(0.2)) 132 | 133 | def forward(self, x): 134 | x = self.features(x) 135 | 136 | return x 137 | 138 | 139 | class gainEst(BasicModule): 140 | def __init__(self): 141 | super(gainEst, self).__init__() 142 | self.model_name = 'gainEst' 143 | 144 | # encoders 145 | self.head = encode(3, 64, max_pool=False) 146 | self.down1 = encode(64, 96, max_pool=True) 147 | self.down2 = encode(96, 128, max_pool=True) 148 | self.down3 = encode(128, 192, max_pool=True) 149 | 150 | # bottleneck 151 | self.bottleneck = t.nn.Sequential( 152 | t.nn.MaxPool2d(2, 2), t.nn.Conv2d(192, 256, 3, padding=1), 153 | t.nn.LeakyReLU(0.2), t.nn.Conv2d(256, 256, 3, padding=1), 154 | t.nn.Upsample(scale_factor=2, mode='nearest'), 155 | t.nn.Conv2d(256, 192, 3, padding=1), channelAtt(192), 156 | t.nn.LeakyReLU(0.2)) 157 | 158 | # decoders 159 | self.up1 = decode(384, 384, 128, up_sample=True) 160 | self.up2 = decode(256, 256, 96, up_sample=True) 161 | self.up3 = decode(192, 192, 64, up_sample=True) 162 | self.seg_out = t.nn.Sequential(decode(128, 128, 64, up_sample=False), 163 | t.nn.Conv2d(64, 2, 1)) 164 | 165 | # external actication 166 | self.sigmoid = t.nn.Sigmoid() 167 | 168 | # prediction 169 | self.features = t.nn.Sequential( 170 | t.nn.Conv2d(5, 64, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2), 171 | t.nn.Conv2d(64, 96, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2), 172 | t.nn.Conv2d(96, 128, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2), 173 | t.nn.Conv2d(128, 192, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2), 174 | t.nn.Conv2d(192, 256, 3, stride=2, padding=1), t.nn.LeakyReLU(0.2)) 175 | self.amp_out 
= t.nn.Sequential(t.nn.Linear(8 * 6 * 256,
176 |                                       128), t.nn.LeakyReLU(0.2),
177 |                                        t.nn.Linear(128, 64),
178 |                                        t.nn.LeakyReLU(0.2), t.nn.Linear(64, 2))
179 | 
180 |     def forward(self, thumb_img, struct_img):
181 |         # segmentation
182 |         out_head = self.head(struct_img)
183 |         out_d1 = self.down1(out_head)
184 |         out_d2 = self.down2(out_d1)
185 |         out_d3 = self.down3(out_d2)
186 |         out_bottleneck = self.bottleneck(out_d3)
187 |         out_u1 = self.up1(t.cat([out_d3, out_bottleneck], dim=1))
188 |         out_u2 = self.up2(t.cat([out_d2, out_u1], dim=1))
189 |         out_u3 = self.up3(t.cat([out_d1, out_u2], dim=1))
190 |         out_mask = self.seg_out(t.cat([out_head, out_u3], dim=1))
191 | 
192 |         # prediction
193 |         out_features = self.features(
194 |             t.cat([thumb_img, self.sigmoid(out_mask)], dim=1))
195 |         out_amp = self.amp_out(out_features.view(out_features.size(0), -1))
196 |         out_amp = t.clamp(out_amp, 0.0, 1.0)
197 | 
198 |         return out_mask, out_amp
199 | 
200 | 
201 | class ispNet(BasicModule):
202 |     def __init__(self):
203 |         super(ispNet, self).__init__()
204 | 
205 |         # encoders
206 |         self.head = encode(8, 64, max_pool=False)
207 |         self.down1 = encode(64, 64, max_pool=True)
208 |         self.down2 = encode(64, 64, max_pool=True)
209 | 
210 |         # skip connections
211 |         self.skip1 = skipConn(1, 64, avg_pool=False)
212 |         self.skip2 = skipConn(64, 64, avg_pool=True)
213 |         self.skip3 = skipConn(64, 64, avg_pool=True)
214 | 
215 |         # decoders
216 |         self.up1 = decode(128, 64, 64, up_sample=True)
217 |         self.up2 = decode(128, 64, 64, up_sample=True)
218 |         self.srgb_out = t.nn.Sequential(
219 |             decode(128, 64, 64, up_sample=False),
220 |             t.nn.Upsample(scale_factor=2, mode='nearest'),
221 |             t.nn.Conv2d(64, 3, 3, padding=1))
222 | 
223 |     def forward(self, color_map, mag_map, amp, wb):
224 |         # rescale the magnitude map to prevent saturation
225 |         mag_map = amp.view(-1, 1, 1, 1) * mag_map
226 |         mag_map = t.tanh(mag_map - 0.5)
227 |         max_mag = 2.0 * amp.view(-1, 1, 1, 1)
228 |         max_mag = t.tanh(max_mag - 0.5)
229 |         mag_map = mag_map / max_mag
230 | 
231 |         # encoder outputs
232 |         out_head = self.head(t.cat([color_map, mag_map, wb], dim=1))
233 |         out_d1 = self.down1(out_head)
234 |         out_d2 = self.down2(out_d1)
235 | 
236 |         # skip connection outputs
237 |         out_s1 = self.skip1(mag_map)
238 |         out_s2 = self.skip2(out_head)
239 |         out_s3 = self.skip3(out_d1)
240 | 
241 |         # decoder outputs
242 |         out_u1 = self.up1(t.cat([out_s3, out_d2], dim=1))
243 |         out_u2 = self.up2(t.cat([out_s2, out_u1], dim=1))
244 |         out_srgb = self.srgb_out(t.cat([out_s1, out_u2], dim=1))
245 |         out_srgb = t.clamp(out_srgb, 0.0, 1.0)
246 | 
247 |         return out_srgb
248 | 
249 | 
250 | class rawProcess(BasicModule):
251 |     def __init__(self):
252 |         super(rawProcess, self).__init__()
253 |         self.model_name = 'rawProcess'
254 | 
255 |         # isp module
256 |         self.isp_net = ispNet()
257 | 
258 |         # fusion
259 |         self.fusion = t.nn.Sequential(t.nn.Conv2d(6, 128, 3, padding=1),
260 |                                       channelAtt(128),
261 |                                       t.nn.Conv2d(128, 3, 3, padding=1))
262 | 
263 |     def forward(self, raw_data, amp_high, amp_low, wb):
264 |         # convert to color map and magnitude map
265 |         mag_map = t.sqrt(t.sum(t.pow(raw_data, 2), 1, keepdim=True))
266 |         color_map = raw_data / (mag_map + 1e-4)
267 | 
268 |         # convert to sRGB images
269 |         out_high = self.isp_net(color_map, mag_map, amp_high, wb)
270 |         out_low = self.isp_net(color_map, mag_map, amp_low, wb)
271 | 
272 |         # image fusion
273 |         out_fused = self.fusion(t.cat([out_high, out_low], dim=1))
274 |         out_fused = t.clamp(out_fused, 0.0, 1.0)
275 | 
276 |         return out_high, out_low, out_fused
277 | 
278 | 
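
# (added) minimal shape smoke test for the modules above -- a sketch, not
# part of the original test script. It assumes only the classes defined in
# this file plus torch (imported as t). rawProcess maps a packed RGGB tensor
# (B, 4, H, W) to full-resolution sRGB images (B, 3, 2H, 2W); H and W must be
# multiples of 4 because ispNet pools twice. Run it manually, e.g.
# `python -c "import test; test._smoke_test()"`.
def _smoke_test(hei=64, wid=48):
    model = rawProcess().eval()
    raw = t.rand(1, 4, hei, wid)   # normalized RGGB phases
    wb = t.rand(1, 3, hei, wid)    # per-pixel white-balance planes
    amp_high = t.full((1,), 8.0)   # amplification ratios, one per batch item
    amp_low = t.full((1,), 2.0)
    with t.no_grad():
        out_high, out_low, out_fused = model(raw, amp_high, amp_low, wb)
    assert out_fused.shape == (1, 3, 2 * hei, 2 * wid)
    assert out_high.shape == out_low.shape == out_fused.shape
    print('rawProcess smoke test passed:', tuple(out_fused.shape))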
# ## Test
280 | 
281 | # In[ ]:
282 | 
283 | 
284 | # normalization
285 | def normalize(raw_data, bk_level, sat_level):
286 |     normal_raw = t.empty_like(raw_data)
287 |     for index in range(raw_data.size(0)):
288 |         for channel in range(raw_data.size(1)):
289 |             normal_raw[index, channel, :, :] = (
290 |                 raw_data[index, channel, :, :] -
291 |                 bk_level[channel]) / (sat_level - bk_level[channel])
292 | 
293 |     return normal_raw
294 | 
295 | 
296 | # resize Bayer pattern
297 | def downSample(raw_data, struct_img_size):
298 |     # convert Bayer pattern to down-sized sRGB image
299 |     batch, _, hei, wid = raw_data.size()
300 |     raw_img = raw_data.new_empty((batch, 3, hei, wid))
301 |     raw_img[:, 0, :, :] = raw_data[:, 0, :, :]  # R
302 |     raw_img[:,
303 |             1, :, :] = (raw_data[:, 1, :, :] + raw_data[:, 2, :, :]) / 2.0  # G
304 |     raw_img[:, 2, :, :] = raw_data[:, 3, :, :]  # B
305 | 
306 |     # down-sample to small size (resize whenever either dimension differs)
307 |     if hei != struct_img_size[1] or wid != struct_img_size[0]:
308 |         raw_img = t.nn.functional.interpolate(raw_img,
309 |                                               size=(struct_img_size[1],
310 |                                                     struct_img_size[0]),
311 |                                               mode='bicubic')
312 |         raw_img = t.clamp(raw_img, 0.0, 1.0)
313 | 
314 |     return raw_img
315 | 
316 | 
317 | # image standardization (mean 0, std 1)
318 | def standardize(srgb_img):
319 |     struct_img = t.empty_like(srgb_img)
320 |     min_std = 1.0 / t.sqrt(srgb_img.new_tensor(srgb_img[0, :, :, :].numel()))
321 |     for index in range(srgb_img.size(0)):
322 |         mean = t.mean(srgb_img[index, :, :, :])
323 |         std = t.std(srgb_img[index, :, :, :])
324 |         adj_std = t.max(std, min_std)  # per-image floor, not a running max
325 |         struct_img[index, :, :, :] = (srgb_img[index, :, :, :] -
326 |                                       mean) / adj_std
327 | 
328 |     return struct_img
329 | 
330 | 
331 | # main entry
332 | def main(args):
333 |     # initialization
334 |     att_size = (256, 192)
335 |     amp_range = (1, 20)
336 | 
337 |     # choose GPU if available
338 |     os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
339 |     os.environ["CUDA_VISIBLE_DEVICES"] = '0'
340 |     device = t.device(args.device)
341 | 
342 |     # define models
343 |     gain_est_model = gainEst().to(device)
344 |     gain_est_model.load('./saves', device=device)
345 |     gain_est_model.eval()
346 |     raw_process_model = rawProcess().to(device)
347 |     weight_file = raw_process_model.load('./saves', device=device)
348 |     raw_process_model.eval()
349 |     args.output = args.output or os.path.join('.', weight_file)  # honor --output if given
350 | 
351 |     # search for valid files
352 |     file_list = [file for file in os.listdir(args.input) if '.DNG' in file]
353 |     file_list.sort()
354 |     if not os.path.exists(args.output):
355 |         os.makedirs(args.output)
356 | 
357 |     # loop to process
358 |     for file in file_list:
359 |         # read black level, saturation level, and white balance
360 |         img_md = exiv2.ImageMetadata(os.path.join(args.input, file))
361 |         img_md.read()
362 | 
363 |         blk_level = img_md['Exif.SubImage1.BlackLevel'].value
364 |         sat_level = img_md['Exif.SubImage1.WhiteLevel'].value
365 |         cam_wb = img_md['Exif.Image.AsShotNeutral'].value
366 | 
367 |         # convert flat Bayer pattern to 4D tensor (RGGB)
368 |         raw_img = rp.imread(os.path.join(args.input, file))
369 |         flat_bayer = raw_img.raw_image_visible
370 |         raw_data = np.stack((flat_bayer[0::2, 0::2], flat_bayer[0::2, 1::2],
371 |                              flat_bayer[1::2, 0::2], flat_bayer[1::2, 1::2]),
372 |                             axis=2)
373 | 
374 |         with t.no_grad():
375 |             # copy to device
376 |             blk_level = t.from_numpy(np.array(blk_level,
377 |                                               dtype=np.float32)).to(device)
378 |             sat_level = t.from_numpy(np.array(sat_level,
379 |                                               dtype=np.float32)).to(device)
380 |             cam_wb = t.from_numpy(np.array(cam_wb,
381 |                                            dtype=np.float32)).to(device)
382 |             raw_data = t.from_numpy(raw_data.astype(np.float32)).to(device)
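
            # (added note) raw_data is still channel-last here: stacking the
            # four Bayer phases above gives (H/2, W/2, 4); the permute on the
            # next line moves it to the (1, 4, H/2, W/2) layout that
            # normalize(), downSample() and rawProcess expect. Cheap sanity
            # check, assuming an RGGB-pattern sensor:
            assert raw_data.dim() == 3 and raw_data.size(-1) == 4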
383 |             raw_data = raw_data.permute(2, 0, 1).unsqueeze(0)
384 | 
385 |             # downsample
386 |             if args.resize:
387 |                 raw_data = t.nn.functional.interpolate(raw_data,
388 |                                                        size=args.resize,
389 |                                                        mode='bicubic')
390 | 
391 |             # pre-processing
392 |             raw_data = normalize(raw_data, blk_level, sat_level)
393 |             cam_wb = cam_wb.view([1, 3, 1, 1]).expand(
394 |                 [1, 3, raw_data.size(2),
395 |                  raw_data.size(3)])
396 |             cam_wb = cam_wb.clone()
397 |             thumb_img = downSample(raw_data, att_size)
398 |             struct_img = standardize(thumb_img)
399 | 
400 |             # run model
401 |             _, pred_amp = gain_est_model(thumb_img, struct_img)
402 |             pred_amp = t.clamp(pred_amp * amp_range[1], amp_range[0],
403 |                                amp_range[1])
404 |             print('Predicted ratio(fg/bg) for %s: %.2f, %.2f.' %
405 |                   (file, pred_amp[0, 0], pred_amp[0, 1]))
406 |             amp_high, _ = t.max(pred_amp, 1)
407 |             amp_low, _ = t.min(pred_amp, 1)
408 |             pred_high, pred_low, pred_fused = raw_process_model(raw_data, amp_high, amp_low, cam_wb)
409 | 
410 |         # save to images (uncomment below to also dump the high/low exposures)
411 |         # save_image(
412 |         #     pred_high.cpu().squeeze(),
413 |         #     os.path.join(args.output,
414 |         #                  '%s' % file.replace('.DNG', '-hi(%.2f).png' % amp_high.squeeze().item())))
415 |         # save_image(
416 |         #     pred_low.cpu().squeeze(),
417 |         #     os.path.join(args.output,
418 |         #                  '%s' % file.replace('.DNG', '-lo(%.2f).png' % amp_low.squeeze().item())))
419 |         save_image(
420 |             pred_fused.cpu().squeeze(),
421 |             os.path.join(args.output,
422 |                          '%s' % file.replace('.DNG', '-fuse.png')))
423 | 
424 |         # fisheye lens calibration
425 |         # modified from https://medium.com/@kennethjiang/calibrate-fisheye-lens-using-opencv-333b05afa0b0
426 |         if args.calib:
427 |             import cv2
428 | 
429 |             DIM = (4000, 3000)
430 |             K = np.array([[1715.9053454852321, 0.0, 2025.0267134780845],
431 |                           [0.0, 1713.8092418955127, 1511.2242172068645],
432 |                           [0.0, 0.0, 1.0]])
433 |             D = np.array([[0.21801544244553403], [0.011549797903321477],
434 |                           [-0.05436236262851618], [-0.01888678272481524]])
435 |             img = cv2.imread(os.path.join(args.output,
436 |                                           '%s' % file.replace('.DNG', '-fuse.png')))
437 |             map1, map2 = cv2.fisheye.initUndistortRectifyMap(
438 |                 K, D, np.eye(3), K, DIM, cv2.CV_16SC2)
439 |             calib_img = cv2.remap(img,
440 |                                   map1,
441 |                                   map2,
442 |                                   interpolation=cv2.INTER_LINEAR,
443 |                                   borderMode=cv2.BORDER_CONSTANT)
444 |             cv2.imwrite(os.path.join(args.output,
445 |                                      '%s' % file.replace('.DNG', '-calib.png')),
446 |                         calib_img)
447 | 
448 | 
449 | if __name__ == '__main__':
450 |     parser = argparse.ArgumentParser()
451 |     parser.add_argument('--input', default='./samples', help='input directory')
452 |     parser.add_argument('--output',
453 |                         default=None,
454 |                         help='output directory')
455 |     parser.add_argument('--resize',
456 |                         default=(600, 800),
457 |                         nargs=2, type=int,
458 |                         help='downsample to smaller size (h w)')
459 |     parser.add_argument('--device',
460 |                         default='cpu',
461 |                         help='device to be used (cpu or cuda)')
462 |     parser.add_argument('--calib',
463 |                         action='store_true',
464 |                         help='perform fisheye calibration')
465 |     args = parser.parse_args()
466 |     main(args)
467 | 
468 | 
-------------------------------------------------------------------------------- /train-gainEst.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Network Training" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": { 13 | "heading_collapsed": true 14 | }, 15 | "source": [ 16 | "## Includes" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | 
"execution_count": null, 22 | "metadata": { 23 | "code_folding": [ 24 | 0 25 | ], 26 | "hidden": true 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "# mass includes\n", 31 | "import os, sys, warnings\n", 32 | "import ipdb\n", 33 | "import torch as t\n", 34 | "import torchvision as tv\n", 35 | "import torchnet as tnt\n", 36 | "from tqdm.notebook import tqdm\n", 37 | "\n", 38 | "# add paths for all sub-folders\n", 39 | "paths = [root for root, dirs, files in os.walk('.')]\n", 40 | "for item in paths:\n", 41 | " sys.path.append(item)\n", 42 | "\n", 43 | "from ipynb.fs.full.config import mainConf\n", 44 | "from ipynb.fs.full.monitor import Visualizer\n", 45 | "from ipynb.fs.full.network import r2rNet, gainEst\n", 46 | "from ipynb.fs.full.dataLoader import fivekNight, valSet\n", 47 | "from ipynb.fs.full.util import *" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": { 53 | "heading_collapsed": true 54 | }, 55 | "source": [ 56 | "## Initialization" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": { 63 | "code_folding": [ 64 | 0 65 | ], 66 | "hidden": true, 67 | "scrolled": true 68 | }, 69 | "outputs": [], 70 | "source": [ 71 | "# for debugging only\n", 72 | "%pdb off\n", 73 | "warnings.filterwarnings('ignore')\n", 74 | "\n", 75 | "# choose GPU if available\n", 76 | "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n", 77 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n", 78 | "device = t.device('cuda' if t.cuda.is_available() else 'cpu')\n", 79 | "\n", 80 | "# define models\n", 81 | "opt = mainConf()\n", 82 | "converter = r2rNet().to(device)\n", 83 | "converter.load('./saves')\n", 84 | "converter.eval()\n", 85 | "gain_est_model = gainEst().to(device)\n", 86 | "\n", 87 | "# load pre-trained model if necessary\n", 88 | "if opt.save_root:\n", 89 | " _ = gain_est_model.load(opt.save_root)\n", 90 | "\n", 91 | "# dataloader for training\n", 92 | "train_dataset = fivekNight(opt)\n", 93 | "train_loader = t.utils.data.DataLoader(train_dataset,\n", 94 | " batch_size=opt.batch_size,\n", 95 | " shuffle=True,\n", 96 | " num_workers=opt.num_workers,\n", 97 | " pin_memory=True)\n", 98 | "\n", 99 | "# dataloader for validation\n", 100 | "val_dataset = valSet(opt)\n", 101 | "val_loader = t.utils.data.DataLoader(val_dataset)\n", 102 | "\n", 103 | "# optimizer\n", 104 | "bce_loss = t.nn.BCEWithLogitsLoss()\n", 105 | "l2_loss = t.nn.MSELoss()\n", 106 | "gain_est_optim = t.optim.Adam(gain_est_model.parameters(), lr=opt.lr)\n", 107 | "\n", 108 | "# visualizer\n", 109 | "vis = Visualizer(env='deepSelfie(gainEst)', port=8686)\n", 110 | "gain_est_meter = tnt.meter.AverageValueMeter()" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": {}, 116 | "source": [ 117 | "## Validation" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": { 124 | "code_folding": [ 125 | 0 126 | ] 127 | }, 128 | "outputs": [], 129 | "source": [ 130 | "def validate():\n", 131 | " # set to evaluation mode\n", 132 | " gain_est_model.eval()\n", 133 | "\n", 134 | " mask_error = 0.0\n", 135 | " amp_error = 0.0\n", 136 | " for (_, thumb_img, struct_img, seg_mask, amp, _, _, _) in val_loader:\n", 137 | " with t.no_grad():\n", 138 | " # copy to device\n", 139 | " thumb_img = thumb_img.to(device)\n", 140 | " struct_img = struct_img.to(device)\n", 141 | " seg_mask = seg_mask.to(device)\n", 142 | " amp = amp.to(device)\n", 143 | "\n", 144 | " # inference\n", 145 | " pred_mask, pred_amp = gain_est_model(thumb_img, 
struct_img)\n", 146 | "\n", 147 | " # compute mse\n", 148 | " mask_error += t.mean(\n", 149 | " t.abs(t.nn.functional.sigmoid(pred_mask) - seg_mask))\n", 150 | " amp_error += t.mean(t.abs(pred_amp - amp / opt.amp_range[1]))\n", 151 | " mask_error /= len(val_loader)\n", 152 | " amp_error /= len(val_loader)\n", 153 | "\n", 154 | " # set to training mode\n", 155 | " gain_est_model.train(mode=True)\n", 156 | "\n", 157 | " return mask_error, amp_error" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "metadata": {}, 163 | "source": [ 164 | "## Training entry" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "metadata": { 171 | "code_folding": [ 172 | 0 173 | ] 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "for epoch in range(0, 2):\n", 178 | " # reset meter and gradient\n", 179 | " gain_est_meter.reset()\n", 180 | " gain_est_optim.zero_grad()\n", 181 | "\n", 182 | " for index, (syth_img, syth_mask) in tqdm(enumerate(train_loader),\n", 183 | " desc='epoch %d' % epoch,\n", 184 | " total=len(train_loader)):\n", 185 | " # copy to device\n", 186 | " syth_img = syth_img.to(device)\n", 187 | " syth_mask = syth_mask.to(device)\n", 188 | "\n", 189 | " # convert to training sample\n", 190 | " thumb_img, struct_img, seg_mask, amp, _, _, _ = toRaw(\n", 191 | " converter, syth_img, syth_mask, opt)\n", 192 | "\n", 193 | " # inference\n", 194 | " pred_mask, pred_amp = gain_est_model(thumb_img, struct_img)\n", 195 | "\n", 196 | " # compute loss\n", 197 | " gain_est_loss = bce_loss(pred_mask, seg_mask) + l2_loss(\n", 198 | " pred_amp, amp / opt.amp_range[1])\n", 199 | "\n", 200 | " # compute gradient\n", 201 | " gain_est_loss.backward()\n", 202 | "\n", 203 | " # update parameter and reset gradient\n", 204 | " gain_est_optim.step()\n", 205 | " gain_est_optim.zero_grad()\n", 206 | "\n", 207 | " # add to loss meter for logging\n", 208 | " gain_est_meter.add(gain_est_loss.item())\n", 209 | "\n", 210 | " # show intermediate result\n", 211 | " if (index + 1) % opt.plot_freq == 0:\n", 212 | " vis.plot('loss (gain est)', gain_est_meter.value()[0])\n", 213 | " gain_est_plot = t.cat(\n", 214 | " [seg_mask, t.nn.functional.sigmoid(pred_mask)],\n", 215 | " dim=-1)[0, 0, :, :]\n", 216 | " vis.img('gain est mask gt/pred', gain_est_plot.cpu() * 255)\n", 217 | "\n", 218 | " # save model\n", 219 | " if (index + 1) % opt.save_freq == 0:\n", 220 | " gain_est_model.save()\n", 221 | " mask_error, amp_error = validate()\n", 222 | " vis.log('epoch: %d, err(mask/amp): %.4f, %.4f' %\n", 223 | " (epoch, mask_error, amp_error))" 224 | ] 225 | } 226 | ], 227 | "metadata": { 228 | "kernelspec": { 229 | "display_name": "Python 3", 230 | "language": "python", 231 | "name": "python3" 232 | }, 233 | "language_info": { 234 | "codemirror_mode": { 235 | "name": "ipython", 236 | "version": 3 237 | }, 238 | "file_extension": ".py", 239 | "mimetype": "text/x-python", 240 | "name": "python", 241 | "nbconvert_exporter": "python", 242 | "pygments_lexer": "ipython3", 243 | "version": "3.8.3" 244 | }, 245 | "toc": { 246 | "nav_menu": {}, 247 | "number_sections": true, 248 | "sideBar": true, 249 | "skip_h1_title": true, 250 | "title_cell": "Table of Contents", 251 | "title_sidebar": "Contents", 252 | "toc_cell": false, 253 | "toc_position": {}, 254 | "toc_section_display": true, 255 | "toc_window_display": false 256 | }, 257 | "varInspector": { 258 | "cols": { 259 | "lenName": 16, 260 | "lenType": 16, 261 | "lenVar": 40 262 | }, 263 | "kernels_config": { 264 | "python": { 265 | 
"delete_cmd_postfix": "", 266 | "delete_cmd_prefix": "del ", 267 | "library": "var_list.py", 268 | "varRefreshCmd": "print(var_dic_list())" 269 | }, 270 | "r": { 271 | "delete_cmd_postfix": ") ", 272 | "delete_cmd_prefix": "rm(", 273 | "library": "var_list.r", 274 | "varRefreshCmd": "cat(var_dic_list()) " 275 | } 276 | }, 277 | "types_to_exclude": [ 278 | "module", 279 | "function", 280 | "builtin_function_or_method", 281 | "instance", 282 | "_Feature" 283 | ], 284 | "window_display": false 285 | } 286 | }, 287 | "nbformat": 4, 288 | "nbformat_minor": 2 289 | } 290 | -------------------------------------------------------------------------------- /train-r2rNet.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Network Training" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": { 13 | "heading_collapsed": true 14 | }, 15 | "source": [ 16 | "## Includes" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": { 23 | "code_folding": [ 24 | 0 25 | ], 26 | "hidden": true 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "# mass includes\n", 31 | "import os, sys, warnings\n", 32 | "import ipdb\n", 33 | "import torch as t\n", 34 | "import torchnet as tnt\n", 35 | "from tqdm.notebook import tqdm\n", 36 | "\n", 37 | "# add paths for all sub-folders\n", 38 | "paths = [root for root, dirs, files in os.walk('.')]\n", 39 | "for item in paths:\n", 40 | " sys.path.append(item)\n", 41 | "\n", 42 | "from ipynb.fs.full.config import r2rNetConf\n", 43 | "from ipynb.fs.full.monitor import Visualizer\n", 44 | "from ipynb.fs.full.network import r2rNet\n", 45 | "from ipynb.fs.full.dataLoader import r2rSet\n", 46 | "from ipynb.fs.full.util import *" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": { 52 | "heading_collapsed": true 53 | }, 54 | "source": [ 55 | "## Initialization" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": { 62 | "code_folding": [ 63 | 0 64 | ], 65 | "hidden": true, 66 | "scrolled": true 67 | }, 68 | "outputs": [], 69 | "source": [ 70 | "# for debugging only\n", 71 | "%pdb off\n", 72 | "warnings.filterwarnings('ignore')\n", 73 | "\n", 74 | "# choose GPU if available\n", 75 | "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n", 76 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n", 77 | "device = t.device('cuda' if t.cuda.is_available() else 'cpu')\n", 78 | "\n", 79 | "# define model\n", 80 | "opt = r2rNetConf()\n", 81 | "model = r2rNet().to(device)\n", 82 | "\n", 83 | "# load pre-trained model if necessary\n", 84 | "if opt.save_root:\n", 85 | " last_epoch = model.load(opt.save_root)\n", 86 | " last_epoch += opt.save_epoch\n", 87 | "else:\n", 88 | " last_epoch = 0\n", 89 | "\n", 90 | "# dataloader for training\n", 91 | "train_dataset = r2rSet(opt, mode='train')\n", 92 | "train_loader = t.utils.data.DataLoader(train_dataset,\n", 93 | " batch_size=opt.batch_size,\n", 94 | " shuffle=True,\n", 95 | " num_workers=opt.num_workers,\n", 96 | " pin_memory=True)\n", 97 | "\n", 98 | "# dataloader for validation\n", 99 | "val_dataset = r2rSet(opt, mode='val')\n", 100 | "val_loader = t.utils.data.DataLoader(val_dataset)\n", 101 | "\n", 102 | "# optimizer\n", 103 | "last_lr = opt.lr * opt.lr_decay**(last_epoch // opt.upd_freq)\n", 104 | "optimizer = t.optim.Adam(model.parameters(), lr=last_lr)\n", 105 | "scheduler = 
t.optim.lr_scheduler.StepLR(optimizer,\n", 106 | " step_size=opt.upd_freq,\n", 107 | " gamma=opt.lr_decay)\n", 108 | "\n", 109 | "# visualizer\n", 110 | "vis = Visualizer(env='r2rNet', port=8686)\n", 111 | "loss_meter = tnt.meter.AverageValueMeter()" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "## Validation" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": { 125 | "code_folding": [ 126 | 0 127 | ] 128 | }, 129 | "outputs": [], 130 | "source": [ 131 | "def validate():\n", 132 | " # set to evaluation mode\n", 133 | " model.eval()\n", 134 | "\n", 135 | " psnr = 0.0\n", 136 | " for (raw_patch, srgb_patch, cam_wb) in val_loader:\n", 137 | " with t.no_grad():\n", 138 | " # copy to device\n", 139 | " raw_patch = raw_patch.to(device)\n", 140 | " srgb_patch = srgb_patch.to(device)\n", 141 | " rggb_patch = toRGGB(srgb_patch)\n", 142 | " cam_wb = cam_wb.to(device)\n", 143 | "\n", 144 | " # inference\n", 145 | " pred_patch = model(rggb_patch, cam_wb)\n", 146 | " pred_patch = t.clamp(pred_patch, 0.0, 1.0)\n", 147 | "\n", 148 | " # compute psnr\n", 149 | " mse = t.mean((pred_patch - raw_patch)**2)\n", 150 | " psnr += 10 * t.log10(1 / mse)\n", 151 | " psnr /= len(val_loader)\n", 152 | "\n", 153 | " # set to training mode\n", 154 | " model.train(mode=True)\n", 155 | "\n", 156 | " return psnr" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "## Training entry" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "metadata": { 170 | "code_folding": [ 171 | 0 172 | ], 173 | "scrolled": true 174 | }, 175 | "outputs": [], 176 | "source": [ 177 | "for epoch in tqdm(range(last_epoch, opt.max_epoch),\n", 178 | " desc='epoch',\n", 179 | " total=opt.max_epoch - last_epoch):\n", 180 | " # reset meter and update learning rate\n", 181 | " loss_meter.reset()\n", 182 | " scheduler.step()\n", 183 | "\n", 184 | " for (raw_patch, srgb_patch, cam_wb) in train_loader:\n", 185 | " # reset gradient\n", 186 | " optimizer.zero_grad()\n", 187 | "\n", 188 | " # copy to device\n", 189 | " raw_patch = raw_patch.to(device)\n", 190 | " srgb_patch = srgb_patch.to(device)\n", 191 | " rggb_patch = toRGGB(srgb_patch)\n", 192 | " cam_wb = cam_wb.to(device)\n", 193 | "\n", 194 | " # inference\n", 195 | " pred_patch = model(rggb_patch, cam_wb)\n", 196 | "\n", 197 | " # compute loss\n", 198 | " loss = t.mean(t.abs(pred_patch - raw_patch))\n", 199 | "\n", 200 | " # backpropagation\n", 201 | " loss.backward()\n", 202 | " optimizer.step()\n", 203 | "\n", 204 | " # add to loss meter for logging\n", 205 | " loss_meter.add(loss.item())\n", 206 | "\n", 207 | " # show training status\n", 208 | " vis.plot('loss', loss_meter.value()[0])\n", 209 | " gt_img = raw2Img(raw_patch[0, :, :, :],\n", 210 | " wb=opt.d65_wb,\n", 211 | " cam_matrix=opt.cam_matrix)\n", 212 | " pred_img = raw2Img(pred_patch[0, :, :, :],\n", 213 | " wb=opt.d65_wb,\n", 214 | " cam_matrix=opt.cam_matrix)\n", 215 | " vis.img('gt/pred/mask', t.cat([gt_img, pred_img], dim=2).cpu() * 255)\n", 216 | "\n", 217 | " # save model and do validation\n", 218 | " if (epoch + 1) > opt.save_epoch or (epoch + 1) % 50 == 0:\n", 219 | " model.save()\n", 220 | " psnr = validate()\n", 221 | " vis.log('epoch: %d, psnr: %.2f' % (epoch, psnr))" 222 | ] 223 | } 224 | ], 225 | "metadata": { 226 | "kernelspec": { 227 | "display_name": "Python 3", 228 | "language": "python", 229 | "name": "python3" 230 | }, 231 | 
"language_info": { 232 | "codemirror_mode": { 233 | "name": "ipython", 234 | "version": 3 235 | }, 236 | "file_extension": ".py", 237 | "mimetype": "text/x-python", 238 | "name": "python", 239 | "nbconvert_exporter": "python", 240 | "pygments_lexer": "ipython3", 241 | "version": "3.8.3" 242 | }, 243 | "toc": { 244 | "nav_menu": {}, 245 | "number_sections": true, 246 | "sideBar": true, 247 | "skip_h1_title": true, 248 | "title_cell": "Table of Contents", 249 | "title_sidebar": "Contents", 250 | "toc_cell": false, 251 | "toc_position": {}, 252 | "toc_section_display": true, 253 | "toc_window_display": false 254 | }, 255 | "varInspector": { 256 | "cols": { 257 | "lenName": 16, 258 | "lenType": 16, 259 | "lenVar": 40 260 | }, 261 | "kernels_config": { 262 | "python": { 263 | "delete_cmd_postfix": "", 264 | "delete_cmd_prefix": "del ", 265 | "library": "var_list.py", 266 | "varRefreshCmd": "print(var_dic_list())" 267 | }, 268 | "r": { 269 | "delete_cmd_postfix": ") ", 270 | "delete_cmd_prefix": "rm(", 271 | "library": "var_list.r", 272 | "varRefreshCmd": "cat(var_dic_list()) " 273 | } 274 | }, 275 | "types_to_exclude": [ 276 | "module", 277 | "function", 278 | "builtin_function_or_method", 279 | "instance", 280 | "_Feature" 281 | ], 282 | "window_display": false 283 | } 284 | }, 285 | "nbformat": 4, 286 | "nbformat_minor": 2 287 | } 288 | -------------------------------------------------------------------------------- /train-rawProcess.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Network Training" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": { 13 | "heading_collapsed": true 14 | }, 15 | "source": [ 16 | "## Includes" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": { 23 | "code_folding": [ 24 | 0 25 | ], 26 | "hidden": true 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "# mass includes\n", 31 | "import os, sys, warnings\n", 32 | "import ipdb\n", 33 | "import torch as t\n", 34 | "import torchvision as tv\n", 35 | "import torchnet as tnt\n", 36 | "from tqdm.notebook import tqdm\n", 37 | "\n", 38 | "# add paths for all sub-folders\n", 39 | "paths = [root for root, dirs, files in os.walk('.')]\n", 40 | "for item in paths:\n", 41 | " sys.path.append(item)\n", 42 | "\n", 43 | "from ipynb.fs.full.config import mainConf\n", 44 | "from ipynb.fs.full.monitor import Visualizer\n", 45 | "from ipynb.fs.full.network import r2rNet, rawProcess\n", 46 | "from ipynb.fs.full.dataLoader import fivekNight, valSet\n", 47 | "from ipynb.fs.full.util import *" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": { 53 | "heading_collapsed": true 54 | }, 55 | "source": [ 56 | "## Initialization" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": { 63 | "code_folding": [ 64 | 0 65 | ], 66 | "hidden": true, 67 | "scrolled": true 68 | }, 69 | "outputs": [], 70 | "source": [ 71 | "# for debugging only\n", 72 | "%pdb off\n", 73 | "warnings.filterwarnings('ignore')\n", 74 | "\n", 75 | "# choose GPU if available\n", 76 | "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n", 77 | "os.environ[\"CUDA_VISIBLE_DEVICES\"] = '0'\n", 78 | "device = t.device('cuda' if t.cuda.is_available() else 'cpu')\n", 79 | "\n", 80 | "# define models\n", 81 | "opt = mainConf()\n", 82 | "converter = r2rNet().to(device)\n", 83 | "converter.load('./saves')\n", 84 | 
"converter.eval()\n", 85 | "raw_process_model = rawProcess().to(device)\n", 86 | "\n", 87 | "# load pre-trained model if necessary\n", 88 | "if opt.save_root:\n", 89 | " last_epoch = raw_process_model.load(opt.save_root)\n", 90 | "else:\n", 91 | " last_epoch = 0\n", 92 | "\n", 93 | "# dataloader for training\n", 94 | "train_dataset = fivekNight(opt)\n", 95 | "train_loader = t.utils.data.DataLoader(train_dataset,\n", 96 | " batch_size=opt.batch_size,\n", 97 | " shuffle=True,\n", 98 | " num_workers=opt.num_workers,\n", 99 | " pin_memory=True)\n", 100 | "\n", 101 | "# dataloader for validation\n", 102 | "val_dataset = valSet(opt)\n", 103 | "val_loader = t.utils.data.DataLoader(val_dataset)\n", 104 | "\n", 105 | "# optimizer\n", 106 | "img_loss = imgLoss(device=device)\n", 107 | "raw_process_optim = t.optim.Adam(raw_process_model.parameters(), lr=opt.lr)\n", 108 | "\n", 109 | "# visualizer\n", 110 | "vis = Visualizer(env='deepSelfie(rawProcess)', port=8686)\n", 111 | "raw_process_meter = tnt.meter.AverageValueMeter()" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 | "metadata": {}, 117 | "source": [ 118 | "## Validation" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": null, 124 | "metadata": { 125 | "code_folding": [ 126 | 0 127 | ] 128 | }, 129 | "outputs": [], 130 | "source": [ 131 | "def validate():\n", 132 | " # set to evaluation mode\n", 133 | " raw_process_model.eval()\n", 134 | "\n", 135 | " isp_psnr = 0.0\n", 136 | " fuse_psnr = 0.0\n", 137 | " for (syth_img, _, _, _, amp, noisy_raw, sorted_mask, wb) in val_loader:\n", 138 | " with t.no_grad():\n", 139 | " # copy to device\n", 140 | " syth_img = syth_img.to(device)\n", 141 | " amp = amp.to(device)\n", 142 | " noisy_raw = noisy_raw.to(device)\n", 143 | " sorted_mask = sorted_mask.to(device)\n", 144 | " wb = wb.to(device)\n", 145 | "\n", 146 | " # pre-processing\n", 147 | " amp_high, _ = t.max(amp, 1)\n", 148 | " amp_low, _ = t.min(amp, 1)\n", 149 | "\n", 150 | " # inference\n", 151 | " pred_high, pred_low, pred_fused = raw_process_model(\n", 152 | " noisy_raw, amp_high, amp_low, wb)\n", 153 | "\n", 154 | " # compute mse\n", 155 | " pred_masked = sorted_mask[:, 0, :, :].unsqueeze(\n", 156 | " 1) * pred_high + sorted_mask[:,\n", 157 | " 1, :, :].unsqueeze(1) * pred_low\n", 158 | " isp_mse = t.nn.functional.mse_loss(pred_masked, syth_img)\n", 159 | " isp_psnr += 10 * t.log10(1 / isp_mse)\n", 160 | " fuse_mse = t.nn.functional.mse_loss(pred_fused, syth_img)\n", 161 | " fuse_psnr += 10 * t.log10(1 / fuse_mse)\n", 162 | " isp_psnr /= len(val_loader)\n", 163 | " fuse_psnr /= len(val_loader)\n", 164 | "\n", 165 | " # set to training mode\n", 166 | " raw_process_model.train(mode=True)\n", 167 | "\n", 168 | " return isp_psnr, fuse_psnr" 169 | ] 170 | }, 171 | { 172 | "cell_type": "markdown", 173 | "metadata": {}, 174 | "source": [ 175 | "## Training entry" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": null, 181 | "metadata": { 182 | "code_folding": [ 183 | 0 184 | ] 185 | }, 186 | "outputs": [], 187 | "source": [ 188 | "# reset meter and gradient\n", 189 | "raw_process_meter.reset()\n", 190 | "raw_process_optim.zero_grad()\n", 191 | "\n", 192 | "for index, (syth_img, syth_mask) in tqdm(enumerate(train_loader),\n", 193 | " desc='progress',\n", 194 | " total=len(train_loader)):\n", 195 | " # copy to device\n", 196 | " syth_img = syth_img.to(device)\n", 197 | " syth_mask = syth_mask.to(device)\n", 198 | "\n", 199 | " # convert to training sample\n", 200 | " _, _, _, amp, 
noisy_raw, sorted_mask, wb = toRaw(converter, syth_img,\n", 201 | " syth_mask, opt)\n", 202 | " amp_high, _ = t.max(amp, 1)\n", 203 | " amp_low, _ = t.min(amp, 1)\n", 204 | "\n", 205 | " # inference\n", 206 | " pred_high, pred_low, pred_fused = raw_process_model(\n", 207 | " noisy_raw, amp_high, amp_low, wb)\n", 208 | "\n", 209 | " # compute loss\n", 210 | " pred_masked = sorted_mask[:, 0, :, :].unsqueeze(\n", 211 | " 1) * pred_high + sorted_mask[:, 1, :, :].unsqueeze(1) * pred_low\n", 212 | " raw_process_loss = img_loss(pred_masked, pred_fused, syth_img)\n", 213 | "\n", 214 | " # compute gradient\n", 215 | " raw_process_loss.backward()\n", 216 | "\n", 217 | " # update parameter and reset gradient\n", 218 | " raw_process_optim.step()\n", 219 | " raw_process_optim.zero_grad()\n", 220 | "\n", 221 | " # add to loss meter for logging\n", 222 | " raw_process_meter.add(raw_process_loss.item())\n", 223 | "\n", 224 | " # show intermediate result\n", 225 | " if (index + 1) % opt.plot_freq == 0:\n", 226 | " vis.plot('loss (raw process)', raw_process_meter.value()[0])\n", 227 | " raw_process_plot = t.nn.functional.interpolate(\n", 228 | " t.clamp(t.cat([syth_img, pred_high, pred_low, pred_fused], dim=-1),\n", 229 | " 0.0, 1.0),\n", 230 | " scale_factor=0.5)[0, :, :, :]\n", 231 | " vis.img('raw process gt/hi/lo/fuse', raw_process_plot.cpu() * 255)\n", 232 | "\n", 233 | " # save model\n", 234 | " if (index + 1) % opt.save_freq == 0:\n", 235 | " raw_process_model.save()\n", 236 | " isp_psnr, fuse_psnr = validate()\n", 237 | " vis.log('psnr(isp/fuse): %.2f, %.2f' % (isp_psnr, fuse_psnr))" 238 | ] 239 | } 240 | ], 241 | "metadata": { 242 | "kernelspec": { 243 | "display_name": "Python 3", 244 | "language": "python", 245 | "name": "python3" 246 | }, 247 | "language_info": { 248 | "codemirror_mode": { 249 | "name": "ipython", 250 | "version": 3 251 | }, 252 | "file_extension": ".py", 253 | "mimetype": "text/x-python", 254 | "name": "python", 255 | "nbconvert_exporter": "python", 256 | "pygments_lexer": "ipython3", 257 | "version": "3.8.3" 258 | }, 259 | "toc": { 260 | "nav_menu": {}, 261 | "number_sections": true, 262 | "sideBar": true, 263 | "skip_h1_title": true, 264 | "title_cell": "Table of Contents", 265 | "title_sidebar": "Contents", 266 | "toc_cell": false, 267 | "toc_position": {}, 268 | "toc_section_display": true, 269 | "toc_window_display": false 270 | }, 271 | "varInspector": { 272 | "cols": { 273 | "lenName": 16, 274 | "lenType": 16, 275 | "lenVar": 40 276 | }, 277 | "kernels_config": { 278 | "python": { 279 | "delete_cmd_postfix": "", 280 | "delete_cmd_prefix": "del ", 281 | "library": "var_list.py", 282 | "varRefreshCmd": "print(var_dic_list())" 283 | }, 284 | "r": { 285 | "delete_cmd_postfix": ") ", 286 | "delete_cmd_prefix": "rm(", 287 | "library": "var_list.r", 288 | "varRefreshCmd": "cat(var_dic_list()) " 289 | } 290 | }, 291 | "types_to_exclude": [ 292 | "module", 293 | "function", 294 | "builtin_function_or_method", 295 | "instance", 296 | "_Feature" 297 | ], 298 | "window_display": false 299 | } 300 | }, 301 | "nbformat": 4, 302 | "nbformat_minor": 2 303 | } 304 | -------------------------------------------------------------------------------- /utils/dataLoader.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Dataloader" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Includes" 15 | ] 16 | }, 17 | { 18 
| "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": { 21 | "code_folding": [ 22 | 0 23 | ] 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "# mass includes\n", 28 | "import os, sys\n", 29 | "import cv2\n", 30 | "import pickle\n", 31 | "import rawpy as rp\n", 32 | "import numpy as np\n", 33 | "import torch as t\n", 34 | "from torch.utils import data\n", 35 | "from torch.distributions.multivariate_normal import MultivariateNormal" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": { 41 | "heading_collapsed": true 42 | }, 43 | "source": [ 44 | "## Dataset for r2rNet" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "metadata": { 51 | "code_folding": [ 52 | 0 53 | ], 54 | "hidden": true 55 | }, 56 | "outputs": [], 57 | "source": [ 58 | "class r2rSet(data.Dataset):\n", 59 | " def __init__(self, opt, mode='train'):\n", 60 | " self.mode = mode\n", 61 | " self.data_root = os.path.join(opt.data_root, self.mode)\n", 62 | " self.r2r_size = opt.r2r_size\n", 63 | " self.file_list = [\n", 64 | " file for file in os.listdir(self.data_root) if '.pkl' in file\n", 65 | " ]\n", 66 | "\n", 67 | " def __getitem__(self, index):\n", 68 | " # load a new sample\n", 69 | " with open(os.path.join(self.data_root, self.file_list[index]),\n", 70 | " 'rb') as file:\n", 71 | " data_dict = pickle.load(file)\n", 72 | "\n", 73 | " # read from file\n", 74 | " raw_data = data_dict['raw'].astype(np.float32)\n", 75 | " srgb_data = data_dict['img'].astype(np.float32)\n", 76 | " blk_level = data_dict['blk_level'].astype(np.float32)\n", 77 | " sat_level = data_dict['sat_level'].astype(np.float32)\n", 78 | " cam_wb = data_dict['cam_wb'].astype(np.float32)\n", 79 | "\n", 80 | " # random transforms\n", 81 | " if self.mode == 'train':\n", 82 | " # random crop\n", 83 | " crop_h = np.random.randint(0, raw_data.shape[1] - self.r2r_size)\n", 84 | " crop_w = np.random.randint(0, raw_data.shape[2] - self.r2r_size)\n", 85 | " raw_patch = raw_data[:, crop_h:crop_h + self.r2r_size,\n", 86 | " crop_w:crop_w + self.r2r_size]\n", 87 | " srgb_patch = srgb_data[:, 2 * crop_h:2 * (crop_h + self.r2r_size),\n", 88 | " 2 * crop_w:2 * (crop_w + self.r2r_size)]\n", 89 | " else:\n", 90 | " raw_patch = raw_data[:, :, :]\n", 91 | " srgb_patch = srgb_data[:, :, :]\n", 92 | "\n", 93 | " # normalization\n", 94 | " raw_patch = np.clip((raw_patch - np.resize(blk_level, [4, 1, 1])) /\n", 95 | " (sat_level - np.resize(blk_level, [4, 1, 1])), 0.0,\n", 96 | " 1.0)\n", 97 | " srgb_patch = srgb_patch / 65535.0\n", 98 | "\n", 99 | " # to pyTorch tensor\n", 100 | " raw_patch = t.from_numpy(raw_patch)\n", 101 | " srgb_patch = t.from_numpy(srgb_patch)\n", 102 | " cam_wb = t.from_numpy(cam_wb).view([3, 1, 1])\n", 103 | " if self.mode == 'train':\n", 104 | " cam_wb = cam_wb.expand([3, self.r2r_size, self.r2r_size])\n", 105 | " else:\n", 106 | " cam_wb = cam_wb.expand([3, raw_patch.size(1), raw_patch.size(2)])\n", 107 | "\n", 108 | " return raw_patch, srgb_patch, cam_wb\n", 109 | "\n", 110 | " def __len__(self):\n", 111 | "\n", 112 | " return len(self.file_list)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": {}, 118 | "source": [ 119 | "## Dataset for fivekNight" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": { 126 | "code_folding": [ 127 | 1, 128 | 31, 129 | 69, 130 | 123 131 | ] 132 | }, 133 | "outputs": [], 134 | "source": [ 135 | "# random cropping and flipping\n", 136 | "def randTransform(bg_img, fg_img, mask, 
img_size):\n",
137 |     "    # crop a random 4:3 region no smaller than img_size\n",
138 |     "    if bg_img.shape[1] > bg_img.shape[0]:\n",
139 |     "        crop_h = np.random.randint(img_size[1], bg_img.shape[0])\n",
140 |     "        crop_w = np.round(crop_h / 0.75)\n",
141 |     "    else:\n",
142 |     "        crop_w = np.random.randint(img_size[0], bg_img.shape[1])\n",
143 |     "        crop_h = np.round(crop_w * 0.75)\n",
144 |     "    crop_y = int(np.random.randint(0, bg_img.shape[0] - crop_h))\n",
145 |     "    crop_x = int(np.random.randint(0, bg_img.shape[1] - crop_w))\n",
146 |     "    crop_h = int(crop_h)\n",
147 |     "    crop_w = int(crop_w)\n",
148 |     "    bg_img = bg_img[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w, :]\n",
149 |     "\n",
150 |     "    # flipping\n",
151 |     "    rand_var = np.random.rand()\n",
152 |     "    if rand_var < 0.25:\n",
153 |     "        bg_img = cv2.flip(bg_img, 1)\n",
154 |     "    elif rand_var < 0.5:\n",
155 |     "        fg_img = cv2.flip(fg_img, 1)\n",
156 |     "        mask = cv2.flip(mask, 1)\n",
157 |     "    elif rand_var < 0.75:\n",
158 |     "        bg_img = cv2.flip(bg_img, 1)\n",
159 |     "        fg_img = cv2.flip(fg_img, 1)\n",
160 |     "        mask = cv2.flip(mask, 1)\n",
161 |     "\n",
162 |     "    return bg_img, fg_img, mask\n",
163 |     "\n",
164 |     "\n",
165 |     "# image blending and rescaling\n",
166 |     "def imgBlend(bg_img, fg_img, mask, img_size):\n",
167 |     "    # randomly scale the foreground so it fits inside the background\n",
168 |     "    scale = np.random.uniform(0.5, 1.0) * np.minimum(\n",
169 |     "        bg_img.shape[0] / mask.shape[0], bg_img.shape[1] / mask.shape[1])\n",
170 |     "    mask = cv2.resize(mask,\n",
171 |     "                      None,\n",
172 |     "                      fx=scale,\n",
173 |     "                      fy=scale,\n",
174 |     "                      interpolation=cv2.INTER_CUBIC)\n",
175 |     "    fg_img = cv2.resize(fg_img,\n",
176 |     "                        None,\n",
177 |     "                        fx=scale,\n",
178 |     "                        fy=scale,\n",
179 |     "                        interpolation=cv2.INTER_CUBIC)\n",
180 |     "    offset = np.random.randint(0, bg_img.shape[1] - mask.shape[1] + 1)\n",
181 |     "\n",
182 |     "    # paste the scaled foreground onto an empty canvas\n",
183 |     "    syth_mask = np.zeros((bg_img.shape[0], bg_img.shape[1]), dtype=np.float32)\n",
184 |     "    syth_mask[bg_img.shape[0] - mask.shape[0]:bg_img.shape[0],\n",
185 |     "              offset:offset + mask.shape[1]] = mask\n",
186 |     "    syth_mask = np.repeat(syth_mask[:, :, np.newaxis], 3, axis=2) / 255.0\n",
187 |     "    syth_img = np.zeros((bg_img.shape[0], bg_img.shape[1], 3),\n",
188 |     "                        dtype=np.float32)\n",
189 |     "    syth_img[bg_img.shape[0] - mask.shape[0]:bg_img.shape[0],\n",
190 |     "             offset:offset + mask.shape[1], :] = fg_img\n",
191 |     "    syth_img = (syth_mask * syth_img + (1 - syth_mask) * bg_img) / 65535.0\n",
192 |     "\n",
193 |     "    # resize to fixed shape\n",
194 |     "    syth_img = cv2.resize(syth_img, img_size, interpolation=cv2.INTER_CUBIC)\n",
195 |     "    syth_mask = cv2.resize(syth_mask, img_size, interpolation=cv2.INTER_CUBIC)\n",
196 |     "\n",
197 |     "    # clip to 0-1\n",
198 |     "    syth_img = np.clip(syth_img, 0.0, 1.0)\n",
199 |     "    syth_mask = np.clip(syth_mask, 0.0, 1.0)\n",
200 |     "\n",
201 |     "    return syth_img, syth_mask\n",
202 |     "\n",
203 |     "\n",
204 |     "class fivekNight(data.Dataset):\n",
205 |     "    def __init__(self, opt):\n",
206 |     "        # get sample list\n",
207 |     "        self.img_size = opt.isp_size\n",
208 |     "        self.bg_path = os.path.join(opt.data_root, 'scene')\n",
209 |     "        self.fg_path = os.path.join(opt.data_root, 'people')\n",
210 |     "        self.bg_list = [\n",
211 |     "            file[:-4] for file in os.listdir(os.path.join(self.bg_path, 'raw'))\n",
212 |     "            if '.png' in file\n",
213 |     "        ]\n",
214 |     "        self.fg_list = [\n",
215 |     "            file[:-4] for file in os.listdir(os.path.join(self.fg_path, 'raw'))\n",
216 |     "            if '.png' in file\n",
217 |     "        ]\n",
218 |     "\n",
219 |     "    def __getitem__(self, index):\n",
220 |     "        # read images and mask\n",
221 |     "        bg_index = int(index // 
len(self.fg_list))\n", 222 | " fg_index = int(index % len(self.fg_list))\n", 223 | " bg_img = cv2.imread(\n", 224 | " os.path.join(self.bg_path, 'raw', self.bg_list[bg_index] + '.png'),\n", 225 | " cv2.IMREAD_UNCHANGED)\n", 226 | " fg_img = cv2.imread(\n", 227 | " os.path.join(self.fg_path, 'raw', self.fg_list[fg_index] + '.png'),\n", 228 | " cv2.IMREAD_UNCHANGED)\n", 229 | " with open(\n", 230 | " os.path.join(self.fg_path, 'mask',\n", 231 | " self.fg_list[fg_index] + '.pkl'), 'rb') as pkl:\n", 232 | " mask = pickle.load(pkl)\n", 233 | "\n", 234 | " # BGR to RGB\n", 235 | " bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB)\n", 236 | " fg_img = cv2.cvtColor(fg_img, cv2.COLOR_BGR2RGB)\n", 237 | "\n", 238 | " # random transforms\n", 239 | " bg_img, fg_img, mask = randTransform(bg_img, fg_img, mask,\n", 240 | " self.img_size)\n", 241 | "\n", 242 | " # image blending and rescaling\n", 243 | " syth_img, syth_mask = imgBlend(bg_img, fg_img, mask, self.img_size)\n", 244 | "\n", 245 | " # convert to tensor and normalize\n", 246 | " syth_img = t.tensor(syth_img, dtype=t.float).permute(2, 0, 1)\n", 247 | " syth_mask = t.tensor(np.stack(\n", 248 | " [syth_mask[:, :, 0], 1.0 - syth_mask[:, :, 0]], axis=0),\n", 249 | " dtype=t.float)\n", 250 | "\n", 251 | " return syth_img, syth_mask\n", 252 | "\n", 253 | " def __len__(self):\n", 254 | "\n", 255 | " return len(self.fg_list) * len(self.bg_list)\n", 256 | "\n", 257 | "\n", 258 | "class valSet(data.Dataset):\n", 259 | " def __init__(self, opt):\n", 260 | " self.data_root = os.path.join(opt.data_root,\n", 261 | " 'val%d' % opt.amp_range[1])\n", 262 | " self.file_list = [\n", 263 | " file for file in os.listdir(self.data_root) if '.pkl' in file\n", 264 | " ]\n", 265 | "\n", 266 | " def __getitem__(self, index):\n", 267 | " # load a new sample\n", 268 | " with open(os.path.join(self.data_root, self.file_list[index]),\n", 269 | " 'rb') as file:\n", 270 | " data_dict = pickle.load(file)\n", 271 | "\n", 272 | " # read from file\n", 273 | " syth_img = data_dict['syth_img']\n", 274 | " thumb_img = data_dict['thumb_img']\n", 275 | " struct_img = data_dict['struct_img']\n", 276 | " seg_mask = data_dict['seg_mask']\n", 277 | " amp = data_dict['amp']\n", 278 | " noisy_raw = data_dict['noisy_raw']\n", 279 | " sorted_mask = data_dict['sorted_mask']\n", 280 | " wb = data_dict['wb']\n", 281 | "\n", 282 | " return syth_img, thumb_img, struct_img, seg_mask, amp, noisy_raw, sorted_mask, wb\n", 283 | "\n", 284 | " def __len__(self):\n", 285 | "\n", 286 | " return len(self.file_list)" 287 | ] 288 | } 289 | ], 290 | "metadata": { 291 | "kernelspec": { 292 | "display_name": "Python 3", 293 | "language": "python", 294 | "name": "python3" 295 | }, 296 | "language_info": { 297 | "codemirror_mode": { 298 | "name": "ipython", 299 | "version": 3 300 | }, 301 | "file_extension": ".py", 302 | "mimetype": "text/x-python", 303 | "name": "python", 304 | "nbconvert_exporter": "python", 305 | "pygments_lexer": "ipython3", 306 | "version": "3.8.3" 307 | } 308 | }, 309 | "nbformat": 4, 310 | "nbformat_minor": 2 311 | } 312 | -------------------------------------------------------------------------------- /utils/monitor.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Visualizer" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": { 13 | "heading_collapsed": true 14 | }, 15 | "source": [ 16 | "## Includes" 17 | ] 18 | }, 19 | { 20 | 
"cell_type": "code", 21 | "execution_count": null, 22 | "metadata": { 23 | "code_folding": [ 24 | 0 25 | ], 26 | "hidden": true 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "# mass includes\n", 31 | "import time\n", 32 | "import visdom\n", 33 | "import numpy as np\n", 34 | "import torchvision as tv" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": { 40 | "heading_collapsed": true 41 | }, 42 | "source": [ 43 | "## Some useful visdom methods" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "metadata": { 50 | "code_folding": [], 51 | "hidden": true 52 | }, 53 | "outputs": [], 54 | "source": [ 55 | "class Visualizer():\n", 56 | " def __init__(self, env='main', port=8097, **kwargs):\n", 57 | " self.vis = visdom.Visdom(port=port, env=env, **kwargs)\n", 58 | " self.index = {}\n", 59 | " self.log_text = ''\n", 60 | "\n", 61 | " def reinit(self, env='main', port=8097, **kwargs):\n", 62 | " self.vis = visdom.Visdom(port=port, env=env, **kwargs)\n", 63 | "\n", 64 | " return self\n", 65 | "\n", 66 | " # show a new log\n", 67 | " def log(self, info, win='log_text'):\n", 68 | " self.log_text += (\n", 69 | " '[%s] %s
' % (time.strftime('%m%d_%H%M%S'), info))\n", 70 | " self.vis.text(self.log_text, win=win)\n", 71 | "\n", 72 | " # plot single data\n", 73 | " def plot(self, name, y):\n", 74 | " x = self.index.get(name, 0)\n", 75 | " self.vis.line(\n", 76 | " Y=np.array([y]),\n", 77 | " X=np.array([x]),\n", 78 | " win=name,\n", 79 | " opts=dict(title=name),\n", 80 | " update=None if x == 0 else 'append')\n", 81 | " self.index[name] = x + 1\n", 82 | "\n", 83 | " # plot multiple data\n", 84 | " def multiPlot(self, d):\n", 85 | " for k, v in d.items():\n", 86 | " self.plot(k, v)\n", 87 | "\n", 88 | " # plot single image\n", 89 | " def img(self, name, img):\n", 90 | " if len(img.size()) < 3:\n", 91 | " img = img.cpu().unsqueeze(0)\n", 92 | " self.vis.image(img.cpu(), win=name, opts=dict(title=name))\n", 93 | "\n", 94 | " # plot multiple images\n", 95 | " def multiImg(self, d):\n", 96 | " for k, v in d.items():\n", 97 | " self.img(k, v)\n", 98 | "\n", 99 | " # plot multiple images in one grid\n", 100 | " def img_grid(self, name, input_3d):\n", 101 | " self.img(\n", 102 | " name,\n", 103 | " tv.utils.make_grid(input_3d.cpu()[0].unsqueeze(1).clamp(\n", 104 | " max=1, min=0)))\n", 105 | "\n", 106 | " # plot multiple image grids\n", 107 | " def img_grid_many(self, d):\n", 108 | " for k, v in d.items():\n", 109 | " self.img_grid(k, v)\n", 110 | "\n", 111 | " # other visdom methods\n", 112 | " def __getattr__(self, name):\n", 113 | "\n", 114 | " return getattr(self.vis, name)" 115 | ] 116 | } 117 | ], 118 | "metadata": { 119 | "kernelspec": { 120 | "display_name": "Python 3", 121 | "language": "python", 122 | "name": "python3" 123 | }, 124 | "language_info": { 125 | "codemirror_mode": { 126 | "name": "ipython", 127 | "version": 3 128 | }, 129 | "file_extension": ".py", 130 | "mimetype": "text/x-python", 131 | "name": "python", 132 | "nbconvert_exporter": "python", 133 | "pygments_lexer": "ipython3", 134 | "version": "3.6.9" 135 | }, 136 | "toc": { 137 | "nav_menu": {}, 138 | "number_sections": true, 139 | "sideBar": true, 140 | "skip_h1_title": true, 141 | "title_cell": "Table of Contents", 142 | "title_sidebar": "Contents", 143 | "toc_cell": false, 144 | "toc_position": {}, 145 | "toc_section_display": true, 146 | "toc_window_display": false 147 | }, 148 | "varInspector": { 149 | "cols": { 150 | "lenName": 16, 151 | "lenType": 16, 152 | "lenVar": 40 153 | }, 154 | "kernels_config": { 155 | "python": { 156 | "delete_cmd_postfix": "", 157 | "delete_cmd_prefix": "del ", 158 | "library": "var_list.py", 159 | "varRefreshCmd": "print(var_dic_list())" 160 | }, 161 | "r": { 162 | "delete_cmd_postfix": ") ", 163 | "delete_cmd_prefix": "rm(", 164 | "library": "var_list.r", 165 | "varRefreshCmd": "cat(var_dic_list()) " 166 | } 167 | }, 168 | "types_to_exclude": [ 169 | "module", 170 | "function", 171 | "builtin_function_or_method", 172 | "instance", 173 | "_Feature" 174 | ], 175 | "window_display": false 176 | } 177 | }, 178 | "nbformat": 4, 179 | "nbformat_minor": 2 180 | } 181 | -------------------------------------------------------------------------------- /utils/util.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Other Useful Functions" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": { 13 | "heading_collapsed": true 14 | }, 15 | "source": [ 16 | "## Includes" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": { 23 | 
"code_folding": [ 24 | 0 25 | ], 26 | "hidden": true 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "# mass includes\n", 31 | "import torch as t\n", 32 | "import torchvision as tv" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "metadata": {}, 38 | "source": [ 39 | "## Manual post-process of raw" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": { 46 | "code_folding": [ 47 | 1, 48 | 11, 49 | 20, 50 | 31, 51 | 51 52 | ] 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "# convert RAW to sRGB image\n", 57 | "def raw2Img(raw_data, wb, cam_matrix):\n", 58 | " raw_data = applyWB(raw_data, wb)\n", 59 | " img = demosaic(raw_data)\n", 60 | " img = cam2sRGB(img, cam_matrix)\n", 61 | " img = applyGamma(img)\n", 62 | "\n", 63 | " return t.clamp(img, 0.0, 1.0)\n", 64 | "\n", 65 | "\n", 66 | "# apply white balancing\n", 67 | "def applyWB(raw_data, wb):\n", 68 | " raw_out = raw_data.clone()\n", 69 | " raw_out[0, :, :] *= wb[0]\n", 70 | " raw_out[3, :, :] *= wb[2]\n", 71 | "\n", 72 | " return raw_out\n", 73 | "\n", 74 | "\n", 75 | "# demosaicing\n", 76 | "def demosaic(raw_data):\n", 77 | " _, hei, wid = raw_data.size()\n", 78 | " img = raw_data.new_empty([3, hei, wid])\n", 79 | " img[0, :, :] = raw_data[0, :, :] # R\n", 80 | " img[1, :, :] = (raw_data[1, :, :] + raw_data[2, :, :]) / 2 # G1+G2\n", 81 | " img[2, :, :] = raw_data[3, :, :] # B\n", 82 | "\n", 83 | " return img\n", 84 | "\n", 85 | "\n", 86 | "# color space conversion\n", 87 | "def cam2sRGB(img, cam_matrix):\n", 88 | " cam_matrix = img.new_tensor(cam_matrix)\n", 89 | " xyz_matrix = img.new_tensor([[0.4124564, 0.3575761, 0.1804375],\n", 90 | " [0.2126729, 0.7151522, 0.0721750],\n", 91 | " [0.0193339, 0.1191920, 0.9503041]])\n", 92 | " trans_matrix = t.matmul(cam_matrix, xyz_matrix)\n", 93 | " trans_matrix /= t.sum(trans_matrix, 1, keepdim=True).repeat(1, 3)\n", 94 | " trans_matrix = t.inverse(trans_matrix)\n", 95 | " new_img = t.empty_like(img)\n", 96 | " new_img[0, :, :] = img[0, :, :] * trans_matrix[0, 0] + img[\n", 97 | " 1, :, :] * trans_matrix[0, 1] + img[2, :, :] * trans_matrix[0, 2]\n", 98 | " new_img[1, :, :] = img[0, :, :] * trans_matrix[1, 0] + img[\n", 99 | " 1, :, :] * trans_matrix[1, 1] + img[2, :, :] * trans_matrix[1, 2]\n", 100 | " new_img[2, :, :] = img[0, :, :] * trans_matrix[2, 0] + img[\n", 101 | " 1, :, :] * trans_matrix[2, 1] + img[2, :, :] * trans_matrix[2, 2]\n", 102 | "\n", 103 | " return new_img\n", 104 | "\n", 105 | "\n", 106 | "# gamma correction\n", 107 | "def applyGamma(img):\n", 108 | " new_img = t.pow(img, 1 / 2.2)\n", 109 | "\n", 110 | " return new_img" 111 | ] 112 | }, 113 | { 114 | "cell_type": "markdown", 115 | "metadata": {}, 116 | "source": [ 117 | "## Raw data manipulation" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": { 124 | "code_folding": [ 125 | 1, 126 | 13, 127 | 34 128 | ] 129 | }, 130 | "outputs": [], 131 | "source": [ 132 | "# normalization\n", 133 | "def normalize(raw_data, bk_level, sat_level):\n", 134 | " normal_raw = t.empty_like(raw_data)\n", 135 | " for index in range(raw_data.size(0)):\n", 136 | " for channel in range(raw_data.size(1)):\n", 137 | " normal_raw[index, channel, :, :] = (\n", 138 | " raw_data[index, channel, :, :] -\n", 139 | " bk_level[channel]) / (sat_level - bk_level[channel])\n", 140 | "\n", 141 | " return normal_raw\n", 142 | "\n", 143 | "\n", 144 | "# resize Bayer pattern\n", 145 | "def downSample(raw_data, struct_img_size):\n", 146 | " # convert Bayer pattern to 
down-sized sRGB image\n",
147 |     "    batch, _, hei, wid = raw_data.size()\n",
148 |     "    raw_img = raw_data.new_empty((batch, 3, hei, wid))\n",
149 |     "    raw_img[:, 0, :, :] = raw_data[:, 0, :, :]  # R\n",
150 |     "    raw_img[:,\n",
151 |     "            1, :, :] = (raw_data[:, 1, :, :] + raw_data[:, 2, :, :]) / 2.0  # G\n",
152 |     "    raw_img[:, 2, :, :] = raw_data[:, 3, :, :]  # B\n",
153 |     "\n",
154 |     "    # down-sample to small size (resize if either dimension differs)\n",
155 |     "    if hei != struct_img_size[1] or wid != struct_img_size[0]:\n",
156 |     "        raw_img = t.nn.functional.interpolate(raw_img,\n",
157 |     "                                              size=(struct_img_size[1],\n",
158 |     "                                                    struct_img_size[0]),\n",
159 |     "                                              mode='bicubic')\n",
160 |     "    raw_img = t.clamp(raw_img, 0.0, 1.0)\n",
161 |     "\n",
162 |     "    return raw_img\n",
163 |     "\n",
164 |     "\n",
165 |     "# image standardization (mean 0, std 1)\n",
166 |     "def standardize(srgb_img):\n",
167 |     "    struct_img = t.empty_like(srgb_img)\n",
168 |     "    min_std = 1.0 / t.sqrt(srgb_img.new_tensor(srgb_img[0, :, :, :].numel()))\n",
169 |     "    for index in range(srgb_img.size(0)):\n",
170 |     "        mean = t.mean(srgb_img[index, :, :, :])\n",
171 |     "        std = t.std(srgb_img[index, :, :, :])\n",
172 |     "        adj_std = t.max(std, min_std)  # per-image floor on the std\n",
173 |     "        struct_img[index, :, :, :] = (srgb_img[index, :, :, :] -\n",
174 |     "                                      mean) / adj_std\n",
175 |     "\n",
176 |     "    return struct_img"
177 |    ]
178 |   },
179 |   {
180 |    "cell_type": "markdown",
181 |    "metadata": {},
182 |    "source": [
183 |     "## Training sample synthesis"
184 |    ]
185 |   },
186 |   {
187 |    "cell_type": "code",
188 |    "execution_count": null,
189 |    "metadata": {
190 |     "code_folding": [
191 |      1,
192 |      11,
193 |      30
194 |     ]
195 |    },
196 |    "outputs": [],
197 |    "source": [
198 |     "# convert sRGB image to RGGB pattern\n",
199 |     "def toRGGB(srgb_img):\n",
200 |     "    rggb_img = t.stack(\n",
201 |     "        (srgb_img[:, 0, 0::2, 0::2], srgb_img[:, 1, 0::2, 1::2],\n",
202 |     "         srgb_img[:, 1, 1::2, 0::2], srgb_img[:, 2, 1::2, 1::2]),\n",
203 |     "        dim=1)\n",
204 |     "\n",
205 |     "    return rggb_img\n",
206 |     "\n",
207 |     "\n",
208 |     "# add noise to Bayer pattern\n",
209 |     "def addPGNoise(raw_data, noise_stat):\n",
210 |     "    # add noise to each sample\n",
211 |     "    noisy_raw = t.empty_like(raw_data)\n",
212 |     "    for index in range(raw_data.size(0)):\n",
213 |     "        log_shot = raw_data.new_empty(1).uniform_(noise_stat['min'],\n",
214 |     "                                                  noise_stat['max'])\n",
215 |     "        log_read = raw_data.new_empty(1).normal_(\n",
216 |     "            mean=noise_stat['slope'] * log_shot.item() + noise_stat['const'],\n",
217 |     "            std=noise_stat['std'])\n",
218 |     "        delta_final = t.sqrt(\n",
219 |     "            t.exp(log_shot) * raw_data[index, :, :, :] + t.exp(log_read))\n",
220 |     "        pg_noise = delta_final * t.randn_like(raw_data[index, :, :, :])\n",
221 |     "        noisy_raw[index, :, :, :] = raw_data[index, :, :, :] + pg_noise\n",
222 |     "    noisy_raw = t.clamp(noisy_raw, 0.0, 1.0)\n",
223 |     "\n",
224 |     "    return noisy_raw\n",
225 |     "\n",
226 |     "\n",
227 |     "# blend weighted fg & bg and convert to Bayer pattern\n",
228 |     "def toRaw(r2rNet, syth_img, syth_mask, opt):\n",
229 |     "    # convert sRGB image to half-size RGGB pattern\n",
230 |     "    rggb_raw = toRGGB(syth_img)\n",
231 |     "\n",
232 |     "    # extract saturation mask\n",
233 |     "    sat_mask = rggb_raw.new_tensor(t.mean(rggb_raw, 1, keepdim=True) > 0.95)\n",
234 |     "\n",
235 |     "    # random white balance\n",
236 |     "    batch, _, hei, wid = rggb_raw.size()\n",
237 |     "    wb = rggb_raw.new_empty((batch, 3, hei, wid))\n",
238 |     "    for index in range(0, batch):\n",
239 |     "        wb_r = rggb_raw.new_empty(1).uniform_(opt.wb_stat['min'],\n",
240 |     "                                              opt.wb_stat['max'])\n",
241 |     "        wb_b = 
rggb_raw.new_empty(1).normal_(\n",
242 |     "            mean=opt.wb_stat['slope'] * wb_r.item() + opt.wb_stat['const'],\n",
243 |     "            std=opt.wb_stat['std'])\n",
244 |     "        wb[index, 0, :, :] = t.exp(wb_r)\n",
245 |     "        wb[index, 1, :, :] = 1.0\n",
246 |     "        wb[index, 2, :, :] = t.exp(wb_b)\n",
247 |     "\n",
248 |     "    # convert to Bayer pattern\n",
249 |     "    with t.no_grad():\n",
250 |     "        org_raw = r2rNet(rggb_raw, wb)\n",
251 |     "        org_raw = t.clamp(org_raw, 0.0, 1.0)\n",
252 |     "\n",
253 |     "\n",
254 |     "    # random amplification ratio\n",
255 |     "    sorted_mask = syth_mask.clone()\n",
256 |     "    half_mask = t.nn.functional.interpolate(syth_mask, scale_factor=0.5)\n",
257 |     "    half_mask = t.clamp(half_mask, 0.0, 1.0)\n",
258 |     "    amp = org_raw.new_empty((batch, 2))\n",
259 |     "    clean_raw = t.empty_like(org_raw)\n",
260 |     "    for index in range(0, batch):\n",
261 |     "        amp[index, :] = t.clamp(\n",
262 |     "            syth_img.new_empty((2, )).uniform_(0.0, opt.amp_range[1]), 1.0,\n",
263 |     "            opt.amp_range[1])\n",
264 |     "        clean_raw[index, :, :, :] = half_mask[index, 0, :, :].unsqueeze(\n",
265 |     "            0) * org_raw[index, :, :, :] / amp[index, 0] + half_mask[\n",
266 |     "                index,\n",
267 |     "                1, :, :].unsqueeze(0) * org_raw[index, :, :, :] / amp[index, 1]\n",
268 |     "        if amp[index, 0] < amp[index, 1]:\n",
269 |     "            sorted_mask[index, :, :, :] = t.flip(sorted_mask[index, :, :, :],\n",
270 |     "                                                 [0])\n",
271 |     "\n",
272 |     "    # preserve saturation\n",
273 |     "    clean_raw = t.max(clean_raw, sat_mask)\n",
274 |     "\n",
275 |     "    # add noise\n",
276 |     "    noisy_raw = addPGNoise(clean_raw, opt.noise_stat)\n",
277 |     "\n",
278 |     "    # down-sample to fixed size\n",
279 |     "    thumb_img = downSample(clean_raw, opt.att_size)\n",
280 |     "    struct_img = standardize(thumb_img)\n",
281 |     "    seg_mask = t.nn.functional.interpolate(syth_mask,\n",
282 |     "                                           size=(opt.att_size[1],\n",
283 |     "                                                 opt.att_size[0]))\n",
284 |     "    seg_mask = t.clamp(seg_mask, 0.0, 1.0)\n",
285 |     "\n",
286 |     "    return thumb_img, struct_img, seg_mask, amp, noisy_raw, sorted_mask, wb"
287 |    ]
288 |   },
289 |   {
290 |    "cell_type": "markdown",
291 |    "metadata": {},
292 |    "source": [
293 |     "## Loss function"
294 |    ]
295 |   },
296 |   {
297 |    "cell_type": "code",
298 |    "execution_count": null,
299 |    "metadata": {
300 |     "code_folding": [
301 |      0,
302 |      23
303 |     ]
304 |    },
305 |    "outputs": [],
306 |    "source": [
307 |     "class vgg16Loss(t.nn.Module):\n",
308 |     "    def __init__(self, device):\n",
309 |     "        super(vgg16Loss, self).__init__()\n",
310 |     "        features = list(tv.models.vgg16(pretrained=True).features)[:23]\n",
311 |     "        self.features = t.nn.ModuleList(features).to(device).eval()\n",
312 |     "        for param in self.parameters():\n",
313 |     "            param.requires_grad = False\n",
314 |     "\n",
315 |     "    def forward(self, pred_img, gt_img):\n",
316 |     "        x = pred_img\n",
317 |     "        y = gt_img\n",
318 |     "        vgg_loss = 0.0\n",
319 |     "\n",
320 |     "        # use outputs of relu1_2, relu2_2, relu3_3, relu4_3 as loss\n",
321 |     "        for index, layer in enumerate(self.features):\n",
322 |     "            x = layer(x)\n",
323 |     "            y = layer(y)\n",
324 |     "            if index in {3, 8, 15, 22}:\n",
325 |     "                vgg_loss += t.nn.functional.mse_loss(x, y)\n",
326 |     "\n",
327 |     "        return vgg_loss / 4.0\n",
328 |     "\n",
329 |     "\n",
330 |     "class imgLoss(t.nn.Module):\n",
331 |     "    def __init__(self, device):\n",
332 |     "        super(imgLoss, self).__init__()\n",
333 |     "        self.l2_loss = t.nn.MSELoss()\n",
334 |     "        self.vgg_loss = vgg16Loss(device)\n",
335 |     "\n",
336 |     "    def forward(self, masked_img, fused_img, gt_img):\n",
337 |     "        l2_loss = (self.l2_loss(masked_img, gt_img) +\n",
338 |     "                   self.l2_loss(fused_img, gt_img)) / 2.0\n",
339 |     "        
vgg_loss = (self.vgg_loss(masked_img, gt_img) +\n", 340 | " self.vgg_loss(fused_img, gt_img)) / 2.0\n", 341 | "\n", 342 | " return l2_loss + vgg_loss" 343 | ] 344 | } 345 | ], 346 | "metadata": { 347 | "kernelspec": { 348 | "display_name": "Python 3", 349 | "language": "python", 350 | "name": "python3" 351 | }, 352 | "language_info": { 353 | "codemirror_mode": { 354 | "name": "ipython", 355 | "version": 3 356 | }, 357 | "file_extension": ".py", 358 | "mimetype": "text/x-python", 359 | "name": "python", 360 | "nbconvert_exporter": "python", 361 | "pygments_lexer": "ipython3", 362 | "version": "3.8.3" 363 | } 364 | }, 365 | "nbformat": 4, 366 | "nbformat_minor": 2 367 | } 368 | --------------------------------------------------------------------------------
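A quick usage note on the post-processing helpers in `utils/util.ipynb`: `raw2Img` chains `applyWB` → `demosaic` → `cam2sRGB` → `applyGamma` to turn a packed 4-channel RGGB tensor into a viewable sRGB image, which is how the training notebooks render their Visdom previews. Below is a minimal sketch, assuming the helpers are already in scope (e.g. via `from ipynb.fs.full.util import *` as the training notebooks do); the white-balance gains and identity camera matrix are made-up placeholders, not the calibrated `opt.d65_wb` / `opt.cam_matrix` values from the project's config.

```python
import torch as t

# made-up placeholders -- NOT the project's calibrated values
d65_wb = [1.2, 1.0, 1.1]            # R/G/B white-balance gains (hypothetical)
cam_matrix = [[1.0, 0.0, 0.0],      # identity stand-in for the camera
              [0.0, 1.0, 0.0],      # RGB -> XYZ color matrix
              [0.0, 0.0, 1.0]]

# a flat gray packed-RGGB patch: 4 planes (R, G1, G2, B), values in [0, 1]
raw = t.full((4, 128, 128), 0.25)

# white balance -> naive half-res demosaic -> color conversion -> gamma
img = raw2Img(raw, wb=d65_wb, cam_matrix=cam_matrix)
print(img.shape)  # torch.Size([3, 128, 128]), values clamped to [0, 1]
```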