├── .gitignore
├── LICENSE
├── Procfile
├── README.md
├── app.py
├── assets
    ├── lol_results.gif
    └── mirnet_architecture.png
├── checkpoints
    └── .gitkeep
├── main.py
├── mirnet
    ├── __init__.py
    ├── dataloaders
    │   ├── __init__.py
    │   ├── common.py
    │   └── dataloader.py
    ├── inference.py
    ├── losses.py
    ├── model
    │   ├── __init__.py
    │   ├── dual_attention_unit
    │   │   ├── __init__.py
    │   │   ├── attention_blocks.py
    │   │   └── dau.py
    │   ├── mirnet_model.py
    │   ├── residual_resizing_modules.py
    │   └── skff.py
    ├── train.py
    └── utils.py
├── notebooks
    ├── .gitkeep
    ├── MIRNet_LOL_Inference.ipynb
    ├── MIRNet_LOL_Inference_256x256.ipynb
    └── MIRNet_Low_Light_Train.ipynb
├── requirements.txt
├── run.sh
├── setup.sh
└── test.py
/.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | venv/ 3 | .idea/ 4 | **/__pycache__/** 5 | .DS_Store 6 | checkpoints/ 7 | **.h5 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /Procfile: -------------------------------------------------------------------------------- 1 | web: sh setup.sh && streamlit run app.py -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MIRNet 2 | 3 | [![](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://share.streamlit.io/soumik12345/mirnet/app.py) 4 | 5 | TensorFlow implementation of the MIRNet architecture proposed in [Learning Enriched Features for Real Image 6 | Restoration and Enhancement](https://arxiv.org/pdf/2003.06792v2.pdf). 7 | 8 | **Launch Notebooks:** [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/soumik12345/MIRNet/HEAD) 9 | 10 | **Wandb Logs:** [https://wandb.ai/19soumik-rakshit96/mirnet](https://wandb.ai/19soumik-rakshit96/mirnet) 11 | 12 | **Blog Post:** [https://keras.io/examples/vision/mirnet/](https://keras.io/examples/vision/mirnet/) 13 | 14 | **TFLite Variant of MIRNet:** [https://github.com/sayakpaul/MIRNet-TFLite](https://github.com/sayakpaul/MIRNet-TFLite) 15 | 16 | **TFLite Models on TensorFlow Hub:** [https://tfhub.dev/sayakpaul/lite-model/mirnet-fixed/dr/1](https://tfhub.dev/sayakpaul/lite-model/mirnet-fixed/dr/1) 17 | 18 | **TensorFlow JS Variant of MIRNet:** [https://github.com/Rishit-dagli/MIRNet-TFJS](https://github.com/Rishit-dagli/MIRNet-TFJS) 19 | 20 | ![](./assets/mirnet_architecture.png) 21 | 22 | ![](./assets/lol_results.gif) 23 | 24 | ## Pre-trained Weights 25 | 26 | - **Trained on 128x128 patches:** [https://drive.google.com/file/d/1sUlRD5MTRKKGxtqyYDpTv7T3jOW6aVAL/view?usp=sharing](https://drive.google.com/file/d/1sUlRD5MTRKKGxtqyYDpTv7T3jOW6aVAL/view?usp=sharing) 27 | 28 | - **Trained on 256x256 patches:** [https://drive.google.com/file/d/1sUlRD5MTRKKGxtqyYDpTv7T3jOW6aVAL/view?usp=sharing](https://drive.google.com/file/d/1sUlRD5MTRKKGxtqyYDpTv7T3jOW6aVAL/view?usp=sharing) 29 | 30 | ## Citation 31 | 32 | ``` 33 | @misc{ 34 | 2003.06792, 35 | Author = {Syed Waqas Zamir and Aditya Arora and Salman Khan and Munawar Hayat and Fahad Shahbaz Khan and Ming-Hsuan Yang and Ling Shao}, 36 | Title = {Learning Enriched Features for Real Image Restoration and Enhancement}, 37 | Year = {2020}, 38 | Eprint = {arXiv:2003.06792}, 39 | } 40 | ``` 41 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from PIL import Image 4 | import streamlit as st 5 | from mirnet.inference import Inferer 6 | 7 | 8 | def main(): 9 | st.markdown( 10 | '
<h1 align="center">Low-light Image Enhancement using MIRNet</h1>
', 11 | unsafe_allow_html=True 12 | ) 13 | inferer = Inferer() 14 | if not os.path.exists('low_light_weights_best.h5'): 15 | st.sidebar.text('Downloading Model weights...') 16 | inferer.download_weights('1sUlRD5MTRKKGxtqyYDpTv7T3jOW6aVAL') 17 | st.sidebar.text('Done') 18 | st.sidebar.text('Building MIRNet Model...') 19 | inferer.build_model( 20 | num_rrg=3, num_mrb=2, channels=64, 21 | weights_path='low_light_weights_best.h5' 22 | ) 23 | st.sidebar.text('Done') 24 | uploaded_files = st.sidebar.file_uploader( 25 | 'Please Upload your Low-light Images', 26 | accept_multiple_files=True 27 | ) 28 | col_1, col_2 = st.beta_columns(2) 29 | if len(uploaded_files) > 0: 30 | for uploaded_file in uploaded_files: 31 | pil_image = Image.open(uploaded_file) 32 | original_image, output_image = inferer.infer_streamlit(pil_image) 33 | with col_1: 34 | st.image( 35 | original_image, use_column_width=True, 36 | caption='Original Image' 37 | ) 38 | with col_2: 39 | st.image( 40 | output_image, use_column_width=True, 41 | caption='Predicted Image' 42 | ) 43 | st.markdown('---') 44 | if os.path.exists('low_light_weights_best.h5'): 45 | subprocess.run(['rm', 'low_light_weights_best.h5']) 46 | 47 | 48 | if __name__ == '__main__': 49 | main() 50 | -------------------------------------------------------------------------------- /assets/lol_results.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/MIRNet/c6812f5bc4ac87e4e63af21aa4e0db84597a17c8/assets/lol_results.gif -------------------------------------------------------------------------------- /assets/mirnet_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/MIRNet/c6812f5bc4ac87e4e63af21aa4e0db84597a17c8/assets/mirnet_architecture.png -------------------------------------------------------------------------------- /checkpoints/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/MIRNet/c6812f5bc4ac87e4e63af21aa4e0db84597a17c8/checkpoints/.gitkeep -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from mirnet.train import LowLightTrainer 3 | from mirnet.utils import init_wandb, download_dataset 4 | 5 | 6 | download_dataset('LOL') 7 | 8 | init_wandb( 9 | project_name='mirnet', experiment_name='LOL_lowlight', 10 | wandb_api_key='cf0947ccde62903d4df0742a58b8a54ca4c11673' 11 | ) 12 | 13 | train_low_light_images = glob('./our485/low/*') 14 | train_high_light_images = glob('./our485/high/*') 15 | valid_low_light_images = glob('./eval15/low/*') 16 | valid_high_light_images = glob('./eval15/high/*') 17 | 18 | trainer = LowLightTrainer() 19 | trainer.build_dataset( 20 | train_low_light_images, train_high_light_images, 21 | valid_low_light_images, valid_high_light_images, 22 | crop_size=128, batch_size=16 23 | ) 24 | 25 | trainer.compile() 26 | 27 | trainer.train(epochs=100, checkpoint_dir='./checkpoints') 28 | -------------------------------------------------------------------------------- /mirnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/MIRNet/c6812f5bc4ac87e4e63af21aa4e0db84597a17c8/mirnet/__init__.py 
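main.py above drives training end to end. For single-image enhancement outside the Streamlit app, a minimal inference sketch built on the `Inferer` class (defined in mirnet/inference.py further below) might look like this — the weight file id and build arguments are taken from app.py, while the input image path is a hypothetical placeholder:

```python
from mirnet.inference import Inferer
from mirnet.utils import plot_result

inferer = Inferer()
# download the published LOL low-light weights (same file id app.py uses)
inferer.download_weights('1sUlRD5MTRKKGxtqyYDpTv7T3jOW6aVAL')
# rebuild the architecture with the hyperparameters the weights were trained with
inferer.build_model(
    num_rrg=3, num_mrb=2, channels=64,
    weights_path='low_light_weights_best.h5'
)
# 'my_low_light_image.png' is a hypothetical input path
original, enhanced = inferer.infer('my_low_light_image.png')
plot_result(original, enhanced)
```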
-------------------------------------------------------------------------------- /mirnet/dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataloader import LOLDataLoader 2 | -------------------------------------------------------------------------------- /mirnet/dataloaders/common.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def read_images(image_files): 5 | dataset = tf.data.Dataset.from_tensor_slices(image_files) 6 | dataset = dataset.map(tf.io.read_file) 7 | dataset = dataset.map( 8 | lambda x: tf.image.decode_png(x, channels=3), 9 | num_parallel_calls=tf.data.experimental.AUTOTUNE 10 | ) 11 | return dataset 12 | 13 | 14 | def random_crop(low_image, enhanced_image, low_crop_size, enhanced_crop_size): 15 | low_image_shape = tf.shape(low_image)[:2] 16 | low_w = tf.random.uniform( 17 | shape=(), maxval=low_image_shape[1] - low_crop_size + 1, dtype=tf.int32) 18 | low_h = tf.random.uniform( 19 | shape=(), maxval=low_image_shape[0] - low_crop_size + 1, dtype=tf.int32) 20 | enhanced_w = low_w 21 | enhanced_h = low_h 22 | low_image_cropped = low_image[ 23 | low_h:low_h + low_crop_size, 24 | low_w:low_w + low_crop_size 25 | ] 26 | enhanced_image_cropped = enhanced_image[ 27 | enhanced_h:enhanced_h + enhanced_crop_size, 28 | enhanced_w:enhanced_w + enhanced_crop_size 29 | ] 30 | return low_image_cropped, enhanced_image_cropped 31 | 32 | 33 | def random_flip(low_image, enhanced_image): 34 | return tf.cond( 35 | tf.random.uniform(shape=(), maxval=1) < 0.5, 36 | lambda: (low_image, enhanced_image), 37 | lambda: ( 38 | tf.image.flip_left_right(low_image), 39 | tf.image.flip_left_right(enhanced_image) 40 | ) 41 | ) 42 | 43 | 44 | def random_rotate(low_image, enhanced_image): 45 | condition = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32) 46 | return tf.image.rot90(low_image, condition), tf.image.rot90(enhanced_image, condition) 47 | 48 | 49 | def apply_scaling(low_image, enhanced_image): 50 | low_image = tf.cast(low_image, tf.float32) 51 | enhanced_image = tf.cast(enhanced_image, tf.float32) 52 | low_image = low_image / 255.0 53 | enhanced_image = enhanced_image / 255.0 54 | return low_image, enhanced_image 55 | -------------------------------------------------------------------------------- /mirnet/dataloaders/dataloader.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from .common import * 3 | 4 | 5 | class LOLDataLoader: 6 | 7 | def __init__(self, images_lowlight: List[str], images_highlight: List[str]): 8 | self.images_lowlight = images_lowlight 9 | self.images_highlight = images_highlight 10 | 11 | def __len__(self): 12 | assert len(self.images_lowlight) == len(self.images_highlight) 13 | return len(self.images_lowlight) 14 | 15 | def build_dataset(self, image_crop_size: int, batch_size: int, is_dataset_train: bool): 16 | low_light_dataset = read_images(self.images_lowlight) 17 | high_light_dataset = read_images(self.images_highlight) 18 | dataset = tf.data.Dataset.zip((low_light_dataset, high_light_dataset)) 19 | dataset = dataset.map(apply_scaling, num_parallel_calls=tf.data.experimental.AUTOTUNE) 20 | dataset = dataset.map( 21 | lambda low, high: random_crop(low, high, image_crop_size, image_crop_size), 22 | num_parallel_calls=tf.data.experimental.AUTOTUNE 23 | ) 24 | if is_dataset_train: 25 | dataset = dataset.map(random_rotate, 
num_parallel_calls=tf.data.experimental.AUTOTUNE) 26 | dataset = dataset.map(random_flip, num_parallel_calls=tf.data.experimental.AUTOTUNE) 27 | dataset = dataset.batch(batch_size) 28 | dataset = dataset.repeat(1) 29 | dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE) 30 | return dataset 31 | -------------------------------------------------------------------------------- /mirnet/inference.py: -------------------------------------------------------------------------------- 1 | import gdown 2 | import numpy as np 3 | from PIL import Image 4 | import tensorflow as tf 5 | from .model import mirnet_model 6 | from .utils import closest_number 7 | 8 | 9 | class Inferer: 10 | 11 | def __init__(self): 12 | self.model = None 13 | 14 | @staticmethod 15 | def download_weights(file_id: str): 16 | gdown.download( 17 | 'https://drive.google.com/uc?id={}'.format(file_id), 18 | 'low_light_weights_best.h5', quiet=False 19 | ) 20 | 21 | def build_model(self, num_rrg: int, num_mrb: int, channels: int, weights_path: str): 22 | self.model = mirnet_model( 23 | image_size=None, num_rrg=num_rrg, 24 | num_mrb=num_mrb, channels=channels 25 | ) 26 | self.model.load_weights(weights_path) 27 | 28 | def _predict(self, original_image, image_resize_factor: float = 1.): 29 | width, height = original_image.size 30 | target_width, target_height = ( 31 | closest_number(width // image_resize_factor, 4), 32 | closest_number(height // image_resize_factor, 4) 33 | ) 34 | original_image = original_image.resize( 35 | (target_width, target_height), Image.ANTIALIAS 36 | ) 37 | image = tf.keras.preprocessing.image.img_to_array(original_image) 38 | image = image.astype('float32') / 255.0 39 | image = np.expand_dims(image, axis=0) 40 | output = self.model.predict(image) 41 | output_image = output[0] * 255.0 42 | output_image = output_image.clip(0, 255) 43 | output_image = output_image.reshape( 44 | (np.shape(output_image)[0], np.shape(output_image)[1], 3) 45 | ) 46 | output_image = Image.fromarray(np.uint8(output_image)) 47 | original_image = Image.fromarray(np.uint8(original_image)) 48 | return output_image 49 | 50 | def infer(self, image_path, image_resize_factor: float = 1.): 51 | original_image = Image.open(image_path) 52 | output_image = self._predict(original_image, image_resize_factor) 53 | return original_image, output_image 54 | 55 | def infer_streamlit(self, image_pil, image_resize_factor: float = 1.): 56 | original_image = image_pil 57 | output_image = self._predict(original_image, image_resize_factor) 58 | return original_image, output_image 59 | -------------------------------------------------------------------------------- /mirnet/losses.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def charbonnier_loss(y_true, y_pred): 5 | return tf.reduce_mean( 6 | tf.sqrt(tf.square(y_true - y_pred) + tf.square(1e-3)) 7 | ) 8 | -------------------------------------------------------------------------------- /mirnet/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .mirnet_model import mirnet_model 2 | -------------------------------------------------------------------------------- /mirnet/model/dual_attention_unit/__init__.py: -------------------------------------------------------------------------------- 1 | from .dau import dual_attention_unit_block 2 | -------------------------------------------------------------------------------- 
/mirnet/model/dual_attention_unit/attention_blocks.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def spatial_attention_block(input_tensor): 5 | """Spatial Attention Block""" 6 | max_pooling = tf.reduce_max(input_tensor, axis=-1) 7 | max_pooling = tf.expand_dims(max_pooling, axis=-1) 8 | average_pooling = tf.reduce_mean(input_tensor, axis=-1) 9 | average_pooling = tf.expand_dims(average_pooling, axis=-1) 10 | concatenated = tf.keras.layers.Concatenate(axis=-1)([max_pooling, average_pooling]) 11 | feature_map = tf.keras.layers.Conv2D(1, kernel_size=(1, 1))(concatenated) 12 | feature_map = tf.nn.sigmoid(feature_map) 13 | return input_tensor * feature_map 14 | 15 | 16 | def channel_attention_block(input_tensor): 17 | """Channel Attention Block""" 18 | channels = list(input_tensor.shape)[-1] 19 | average_pooling = tf.keras.layers.GlobalAveragePooling2D()(input_tensor) 20 | feature_descriptor = tf.reshape(average_pooling, shape=(-1, 1, 1, channels)) 21 | feature_activations = tf.keras.layers.ReLU()( 22 | tf.keras.layers.Conv2D( 23 | filters=channels // 8, kernel_size=(1, 1) 24 | )(feature_descriptor) 25 | ) 26 | feature_activations = tf.nn.sigmoid( 27 | tf.keras.layers.Conv2D( 28 | filters=channels, kernel_size=(1, 1) 29 | )(feature_activations) 30 | ) 31 | return input_tensor * feature_activations 32 | -------------------------------------------------------------------------------- /mirnet/model/dual_attention_unit/dau.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from .attention_blocks import ( 3 | channel_attention_block, 4 | spatial_attention_block 5 | ) 6 | 7 | 8 | def dual_attention_unit_block(input_tensor): 9 | """Dual Attention Unit Block""" 10 | channels = list(input_tensor.shape)[-1] 11 | feature_map = tf.keras.layers.Conv2D( 12 | channels, kernel_size=(3, 3), padding='same')(input_tensor) 13 | feature_map = tf.keras.layers.ReLU()(feature_map) 14 | feature_map = tf.keras.layers.Conv2D( 15 | channels, kernel_size=(3, 3), padding='same')(feature_map) 16 | channel_attention = channel_attention_block(feature_map) 17 | spatial_attention = spatial_attention_block(feature_map) 18 | concatenation = tf.keras.layers.Concatenate(axis=-1)([ 19 | channel_attention, spatial_attention]) 20 | concatenation = tf.keras.layers.Conv2D(channels, kernel_size=(1, 1))(concatenation) 21 | return tf.keras.layers.Add()([input_tensor, concatenation]) 22 | -------------------------------------------------------------------------------- /mirnet/model/mirnet_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from .skff import selective_kernel_feature_fusion 3 | from .dual_attention_unit import dual_attention_unit_block 4 | from .residual_resizing_modules import up_sampling_module, down_sampling_module 5 | 6 | 7 | def multi_scale_residual_block(input_tensor, channels): 8 | # features 9 | level1 = input_tensor 10 | level2 = down_sampling_module(input_tensor) 11 | level3 = down_sampling_module(level2) 12 | # DAU 13 | level1_dau = dual_attention_unit_block(level1) 14 | level2_dau = dual_attention_unit_block(level2) 15 | level3_dau = dual_attention_unit_block(level3) 16 | # SKFF 17 | level1_skff = selective_kernel_feature_fusion( 18 | level1_dau, up_sampling_module(level2_dau), 19 | up_sampling_module(up_sampling_module(level3_dau)) 20 | ) 21 | level2_skff = selective_kernel_feature_fusion( 22 | 
down_sampling_module(level1_dau), level2_dau, 23 | up_sampling_module(level3_dau) 24 | ) 25 | level3_skff = selective_kernel_feature_fusion( 26 | down_sampling_module(down_sampling_module(level1_dau)), 27 | down_sampling_module(level2_dau), level3_dau 28 | ) 29 | # DAU 2 30 | level1_dau_2 = dual_attention_unit_block(level1_skff) 31 | level2_dau_2 = up_sampling_module(dual_attention_unit_block(level2_skff)) 32 | level3_dau_2 = up_sampling_module(up_sampling_module(dual_attention_unit_block(level3_skff))) 33 | # SKFF 2 34 | skff_ = selective_kernel_feature_fusion(level1_dau_2, level2_dau_2, level3_dau_2) 35 | conv = tf.keras.layers.Conv2D(channels, kernel_size=(3, 3), padding='same')(skff_) 36 | return tf.keras.layers.Add()([input_tensor, conv]) 37 | 38 | 39 | def recursive_residual_group(input_tensor, num_mrb, channels): 40 | conv1 = tf.keras.layers.Conv2D( 41 | channels, kernel_size=(3, 3), padding='same')(input_tensor) 42 | for _ in range(num_mrb): 43 | conv1 = multi_scale_residual_block(conv1, channels) 44 | conv2 = tf.keras.layers.Conv2D( 45 | channels, kernel_size=(3, 3), padding='same')(conv1) 46 | return tf.keras.layers.Add()([conv2, input_tensor]) 47 | 48 | 49 | def mirnet_model(image_size: int, num_rrg: int, num_mrb: int, channels: int): 50 | input_tensor = tf.keras.Input(shape=[image_size, image_size, 3]) 51 | x1 = tf.keras.layers.Conv2D( 52 | channels, kernel_size=(3, 3), padding='same')(input_tensor) 53 | for _ in range(num_rrg): 54 | x1 = recursive_residual_group(x1, num_mrb, channels) 55 | conv = tf.keras.layers.Conv2D( 56 | 3, kernel_size=(3, 3), padding='same')(x1) 57 | output_tensor = tf.keras.layers.Add()([input_tensor, conv]) 58 | return tf.keras.Model(input_tensor, output_tensor) 59 | -------------------------------------------------------------------------------- /mirnet/model/residual_resizing_modules.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def down_sampling_module(input_tensor): 5 | """Downsampling Module""" 6 | channels = list(input_tensor.shape)[-1] 7 | main_branch = tf.keras.layers.Conv2D( 8 | channels, kernel_size=(1, 1))(input_tensor) 9 | main_branch = tf.nn.relu(main_branch) 10 | # main_branch = tf.keras.layers.Conv2D( 11 | # channels, kernel_size=(3, 3), padding='same')(input_tensor) 12 | main_branch = tf.keras.layers.Conv2D( 13 | channels, kernel_size=(3, 3), padding='same')(main_branch) 14 | main_branch = tf.nn.relu(main_branch) 15 | main_branch = tf.keras.layers.MaxPooling2D()(main_branch) 16 | main_branch = tf.keras.layers.Conv2D( 17 | channels * 2, kernel_size=(1, 1))(main_branch) 18 | skip_branch = tf.keras.layers.MaxPooling2D()(input_tensor) 19 | skip_branch = tf.keras.layers.Conv2D( 20 | channels * 2, kernel_size=(1, 1))(skip_branch) 21 | return tf.keras.layers.Add()([skip_branch, main_branch]) 22 | 23 | 24 | def up_sampling_module(input_tensor): 25 | """Upsampling Module""" 26 | channels = list(input_tensor.shape)[-1] 27 | main_branch = tf.keras.layers.Conv2D( 28 | channels, kernel_size=(1, 1))(input_tensor) 29 | main_branch = tf.nn.relu(main_branch) 30 | # main_branch = tf.keras.layers.Conv2D( 31 | # channels, kernel_size=(3, 3), padding='same')(input_tensor) 32 | main_branch = tf.keras.layers.Conv2D( 33 | channels, kernel_size=(3, 3), padding='same')(main_branch) 34 | main_branch = tf.nn.relu(main_branch) 35 | main_branch = tf.keras.layers.UpSampling2D()(main_branch) 36 | 
main_branch = tf.keras.layers.Conv2D( 37 | channels // 2, kernel_size=(1, 1))(main_branch) 38 | skip_branch = tf.keras.layers.UpSampling2D()(input_tensor) 39 | skip_branch = tf.keras.layers.Conv2D( 40 | channels // 2, kernel_size=(1, 1))(skip_branch) 41 | return tf.keras.layers.Add()([skip_branch, main_branch]) 42 | -------------------------------------------------------------------------------- /mirnet/model/skff.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def selective_kernel_feature_fusion( 5 | multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3): 6 | """Selective Kernel Feature Fusion Block""" 7 | channels = list(multi_scale_feature_1.shape)[-1] 8 | combined_feature = tf.keras.layers.Add()([ 9 | multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3]) 10 | gap = tf.keras.layers.GlobalAveragePooling2D()(combined_feature) 11 | channel_wise_statistics = tf.reshape(gap, shape=(-1, 1, 1, channels)) 12 | compact_feature_representation = tf.keras.layers.ReLU()( 13 | tf.keras.layers.Conv2D( 14 | filters=channels // 8, kernel_size=(1, 1) 15 | )(channel_wise_statistics) 16 | ) 17 | feature_descriptor_1 = tf.nn.softmax( 18 | tf.keras.layers.Conv2D(channels, kernel_size=(1, 1))(compact_feature_representation) 19 | ) 20 | feature_descriptor_2 = tf.nn.softmax( 21 | tf.keras.layers.Conv2D(channels, kernel_size=(1, 1))(compact_feature_representation) 22 | ) 23 | feature_descriptor_3 = tf.nn.softmax( 24 | tf.keras.layers.Conv2D(channels, kernel_size=(1, 1))(compact_feature_representation) 25 | ) 26 | feature_1 = multi_scale_feature_1 * feature_descriptor_1 27 | feature_2 = multi_scale_feature_2 * feature_descriptor_2 28 | feature_3 = multi_scale_feature_3 * feature_descriptor_3 29 | aggregated_feature = tf.keras.layers.Add()([feature_1, feature_2, feature_3]) 30 | return aggregated_feature 31 | -------------------------------------------------------------------------------- /mirnet/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | from typing import List 4 | from .utils import psnr 5 | from .model import mirnet_model 6 | from .losses import charbonnier_loss 7 | from wandb.keras import WandbCallback 8 | from .dataloaders import LOLDataLoader 9 | 10 | 11 | class LowLightTrainer: 12 | 13 | def __init__(self): 14 | self.model = None 15 | self.crop_size = None 16 | self.train_dataset = None 17 | self.valid_dataset = None 18 | # self.strategy = tf.distribute.OneDeviceStrategy("GPU:0") 19 | # if len(tf.config.list_physical_devices('GPU')) > 1: 20 | # self.strategy = tf.distribute.MirroredStrategy() 21 | 22 | def build_dataset( 23 | self, train_low_light_images: List[str], train_high_light_images: List[str], 24 | valid_low_light_images: List[str], valid_high_light_images: List[str], 25 | crop_size: int, batch_size: int): 26 | self.crop_size = crop_size 27 | self.train_dataset = LOLDataLoader( 28 | images_lowlight=train_low_light_images, 29 | images_highlight=train_high_light_images 30 | ).build_dataset( 31 | image_crop_size=crop_size, batch_size=batch_size, is_dataset_train=True) 32 | self.valid_dataset = LOLDataLoader( 33 | images_lowlight=valid_low_light_images, 34 | images_highlight=valid_high_light_images 35 | ).build_dataset( 36 | image_crop_size=crop_size, batch_size=batch_size, is_dataset_train=False) 37 | 38 | def compile(self, num_rrg=3, num_mrb=2, channels=64, learning_rate=1e-4, use_mae_loss=True): 39 | 
self.model = mirnet_model(self.crop_size, num_rrg, num_mrb, channels) 40 | loss_function = tf.keras.losses.MeanAbsoluteError() if use_mae_loss else charbonnier_loss 41 | optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) 42 | self.model.compile(optimizer=optimizer, loss=loss_function, metrics=[psnr]) 43 | 44 | def train(self, epochs: int, checkpoint_dir: str): 45 | callbacks = [ 46 | tf.keras.callbacks.EarlyStopping( 47 | monitor="val_psnr", 48 | patience=10, mode='max' 49 | ), 50 | tf.keras.callbacks.ReduceLROnPlateau( 51 | monitor='val_psnr', factor=0.5, 52 | patience=5, verbose=1, min_delta=1e-7, mode='max' 53 | ), 54 | tf.keras.callbacks.ModelCheckpoint( 55 | os.path.join(checkpoint_dir, 'low_light_weights_best.h5'), 56 | monitor="val_psnr", save_weights_only=True, 57 | mode="max", save_best_only=True, save_freq='epoch' 58 | ), WandbCallback() 59 | ] 60 | history = self.model.fit( 61 | self.train_dataset, validation_data=self.valid_dataset, 62 | epochs=epochs, callbacks=callbacks, verbose=1 63 | ) 64 | return history 65 | -------------------------------------------------------------------------------- /mirnet/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gdown 3 | import wandb 4 | import subprocess 5 | import tensorflow as tf 6 | from matplotlib import pyplot as plt 7 | 8 | 9 | def psnr(y_true, y_pred): 10 | return tf.image.psnr(y_pred, y_true, max_val=255.0) 11 | 12 | 13 | def init_wandb(project_name, experiment_name, wandb_api_key): 14 | """Initialize Wandb 15 | Args: 16 | project_name: project name on Wandb 17 | experiment_name: experiment name on Wandb 18 | wandb_api_key: Wandb API Key 19 | """ 20 | if project_name is not None and experiment_name is not None: 21 | os.environ['WANDB_API_KEY'] = wandb_api_key 22 | wandb.init(project=project_name, name=experiment_name) 23 | 24 | 25 | def download_dataset(dataset_tag): 26 | """Utility for downloading and unpacking a dataset 27 | Args: 28 | dataset_tag: Tag for the respective dataset. 
29 | Available tags -> ('LOL') 30 | """ 31 | print('Downloading dataset...') 32 | if dataset_tag == 'LOL': 33 | gdown.download( 34 | 'https://drive.google.com/uc?id=157bjO1_cFuSd0HWDUuAmcHRJDVyWpOxB', 35 | 'LOLdataset.zip', quiet=False 36 | ) 37 | print('Unpacking Dataset') 38 | subprocess.run(['unzip', 'LOLdataset.zip']) 39 | print('Done!!!') 40 | else: 41 | raise AssertionError('Dataset tag not found') 42 | 43 | 44 | def plot_result(image, enhanced): 45 | """Utility for Plotting inference result 46 | Args: 47 | image: original image 48 | enhanced: enhanced image 49 | """ 50 | fig = plt.figure(figsize=(12, 12)) 51 | fig.add_subplot(1, 2, 1).set_title('Original Image') 52 | _ = plt.imshow(image) 53 | fig.add_subplot(1, 2, 2).set_title('Enhanced Image') 54 | _ = plt.imshow(enhanced) 55 | plt.show() 56 | 57 | 58 | def closest_number(n, m): 59 | q = int(n / m) 60 | n1 = m * q 61 | if (n * m) > 0: 62 | n2 = (m * (q + 1)) 63 | else: 64 | n2 = (m * (q - 1)) 65 | if abs(n - n1) < abs(n - n2): 66 | return n1 67 | return n2 68 | -------------------------------------------------------------------------------- /notebooks/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/soumik12345/MIRNet/c6812f5bc4ac87e4e63af21aa4e0db84597a17c8/notebooks/.gitkeep -------------------------------------------------------------------------------- /notebooks/MIRNet_Low_Light_Train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "MIRNet-Low-Light-Train.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyPYkiW3WGoZrgLnaTdb98k7", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "accelerator": "GPU" 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "view-in-github", 23 | "colab_type": "text" 24 | }, 25 | "source": [ 26 | "\"Open" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "metadata": { 32 | "colab": { 33 | "base_uri": "https://localhost:8080/" 34 | }, 35 | "id": "u6SxYHy-0tAX", 36 | "outputId": "0b8689e6-a29e-4bd0-9f6f-f4d81cc7f836" 37 | }, 38 | "source": [ 39 | "!nvidia-smi" 40 | ], 41 | "execution_count": 1, 42 | "outputs": [ 43 | { 44 | "output_type": "stream", 45 | "text": [ 46 | "Tue Dec 1 07:16:36 2020 \n", 47 | "+-----------------------------------------------------------------------------+\n", 48 | "| NVIDIA-SMI 455.38 Driver Version: 418.67 CUDA Version: 10.1 |\n", 49 | "|-------------------------------+----------------------+----------------------+\n", 50 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", 51 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", 52 | "| | | MIG M. |\n", 53 | "|===============================+======================+======================|\n", 54 | "| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n", 55 | "| N/A 66C P8 11W / 70W | 0MiB / 15079MiB | 0% Default |\n", 56 | "| | | ERR! 
|\n", 57 | "+-------------------------------+----------------------+----------------------+\n", 58 | " \n", 59 | "+-----------------------------------------------------------------------------+\n", 60 | "| Processes: |\n", 61 | "| GPU GI CI PID Type Process name GPU Memory |\n", 62 | "| ID ID Usage |\n", 63 | "|=============================================================================|\n", 64 | "| No running processes found |\n", 65 | "+-----------------------------------------------------------------------------+\n" 66 | ], 67 | "name": "stdout" 68 | } 69 | ] 70 | }, 71 | { 72 | "cell_type": "code", 73 | "metadata": { 74 | "colab": { 75 | "base_uri": "https://localhost:8080/" 76 | }, 77 | "id": "H42UtLtD3QMh", 78 | "outputId": "9e3d5c8d-7f11-40c6-a0c5-a1d2ddbcc0c0" 79 | }, 80 | "source": [ 81 | "!git clone https://github.com/soumik12345/MIRNet\n", 82 | "%cd MIRNet" 83 | ], 84 | "execution_count": 2, 85 | "outputs": [ 86 | { 87 | "output_type": "stream", 88 | "text": [ 89 | "Cloning into 'MIRNet'...\n", 90 | "remote: Enumerating objects: 138, done.\u001b[K\n", 91 | "remote: Counting objects: 100% (138/138), done.\u001b[K\n", 92 | "remote: Compressing objects: 100% (104/104), done.\u001b[K\n", 93 | "remote: Total 138 (delta 52), reused 106 (delta 25), pack-reused 0\u001b[K\n", 94 | "Receiving objects: 100% (138/138), 9.84 MiB | 14.21 MiB/s, done.\n", 95 | "Resolving deltas: 100% (52/52), done.\n", 96 | "/content/MIRNet\n" 97 | ], 98 | "name": "stdout" 99 | } 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "metadata": { 105 | "colab": { 106 | "base_uri": "https://localhost:8080/" 107 | }, 108 | "id": "GMhYQqQ_6Qh5", 109 | "outputId": "f263304a-bb43-458d-e06f-b345cc496f60" 110 | }, 111 | "source": [ 112 | "!pip install -qq wandb" 113 | ], 114 | "execution_count": 3, 115 | "outputs": [ 116 | { 117 | "output_type": "stream", 118 | "text": [ 119 | "\u001b[K |████████████████████████████████| 1.8MB 23.2MB/s \n", 120 | "\u001b[K |████████████████████████████████| 102kB 14.7MB/s \n", 121 | "\u001b[K |████████████████████████████████| 133kB 60.6MB/s \n", 122 | "\u001b[K |████████████████████████████████| 102kB 13.9MB/s \n", 123 | "\u001b[K |████████████████████████████████| 163kB 59.2MB/s \n", 124 | "\u001b[K |████████████████████████████████| 71kB 11.6MB/s \n", 125 | "\u001b[?25h Building wheel for watchdog (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 126 | " Building wheel for subprocess32 (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 127 | " Building wheel for pathtools (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n" 128 | ], 129 | "name": "stdout" 130 | } 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "metadata": { 136 | "id": "hAXrPnAx-seN" 137 | }, 138 | "source": [ 139 | "from glob import glob\n", 140 | "import tensorflow as tf\n", 141 | "from mirnet.train import LowLightTrainer\n", 142 | "from mirnet.utils import init_wandb, download_dataset\n", 143 | "\n", 144 | "\n", 145 | "tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)" 146 | ], 147 | "execution_count": 4, 148 | "outputs": [] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "metadata": { 153 | "colab": { 154 | "base_uri": "https://localhost:8080/" 155 | }, 156 | "id": "fT7Fr8Tq_bqn", 157 | "outputId": "96ccfa74-77e8-453d-d6c1-7227387daea2" 158 | }, 159 | "source": [ 160 | "download_dataset('LOL')" 161 | ], 162 | "execution_count": 5, 163 | "outputs": [ 164 | { 165 | "output_type": "stream", 166 | "text": [ 167 | "Downloading dataset...\n" 168 | ], 169 | "name": "stdout" 170 | }, 171 | { 172 | "output_type": "stream", 173 | "text": [ 174 | "Downloading...\n", 175 | "From: https://drive.google.com/uc?id=157bjO1_cFuSd0HWDUuAmcHRJDVyWpOxB\n", 176 | "To: /content/MIRNet/LOLdataset.zip\n", 177 | "347MB [00:01, 211MB/s]\n" 178 | ], 179 | "name": "stderr" 180 | }, 181 | { 182 | "output_type": "stream", 183 | "text": [ 184 | "Unpacking Dataset\n", 185 | "Done!!!\n" 186 | ], 187 | "name": "stdout" 188 | } 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "metadata": { 194 | "colab": { 195 | "base_uri": "https://localhost:8080/", 196 | "height": 136 197 | }, 198 | "id": "0lO57o2W_sqq", 199 | "outputId": "ac006479-bf25-428f-e087-7ff523e89bba" 200 | }, 201 | "source": [ 202 | "init_wandb(\n", 203 | " project_name='mirnet',\n", 204 | " experiment_name='LOL_lowlight_experiment_2_256x256',\n", 205 | " wandb_api_key='cf0947ccde62903d4df0742a58b8a54ca4c11673'\n", 206 | ")" 207 | ], 208 | "execution_count": 6, 209 | "outputs": [ 210 | { 211 | "output_type": "stream", 212 | "text": [ 213 | "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m19soumik-rakshit96\u001b[0m (use `wandb login --relogin` to force relogin)\n" 214 | ], 215 | "name": "stderr" 216 | }, 217 | { 218 | "output_type": "display_data", 219 | "data": { 220 | "text/html": [ 221 | "\n", 222 | " Tracking run with wandb version 0.10.11
\n", 223 | " Syncing run LOL_lowlight_experiment_2_256x256 to Weights & Biases (Documentation).
\n", 224 | " Project page: https://wandb.ai/19soumik-rakshit96/mirnet
\n", 225 | " Run page: https://wandb.ai/19soumik-rakshit96/mirnet/runs/15w0gpyv
\n", 226 | " Run data is saved locally in /content/MIRNet/wandb/run-20201201_071736-15w0gpyv

\n", 227 | " " 228 | ], 229 | "text/plain": [ 230 | "" 231 | ] 232 | }, 233 | "metadata": { 234 | "tags": [] 235 | } 236 | } 237 | ] 238 | }, 239 | { 240 | "cell_type": "code", 241 | "metadata": { 242 | "id": "Kiom1ZjH8eKi" 243 | }, 244 | "source": [ 245 | "trainer = LowLightTrainer()\n", 246 | "train_low_light_images = glob('./our485/low/*')\n", 247 | "train_high_light_images = glob('./our485/high/*')\n", 248 | "valid_low_light_images = glob('./eval15/low/*')\n", 249 | "valid_high_light_images = glob('./eval15/high/*')" 250 | ], 251 | "execution_count": 7, 252 | "outputs": [] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "metadata": { 257 | "id": "Kz0i4QoQ8qw1" 258 | }, 259 | "source": [ 260 | "trainer.build_dataset(\n", 261 | " train_low_light_images, train_high_light_images,\n", 262 | " valid_low_light_images, valid_high_light_images,\n", 263 | " crop_size=256, batch_size=2\n", 264 | ")\n", 265 | "trainer.compile()" 266 | ], 267 | "execution_count": 8, 268 | "outputs": [] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "metadata": { 273 | "colab": { 274 | "base_uri": "https://localhost:8080/" 275 | }, 276 | "id": "zeCmSZqI9GBU", 277 | "outputId": "e939e100-091d-4596-c542-1516c60f401f" 278 | }, 279 | "source": [ 280 | "trainer.train(epochs=100, checkpoint_dir='./checkpoints')" 281 | ], 282 | "execution_count": 9, 283 | "outputs": [ 284 | { 285 | "output_type": "stream", 286 | "text": [ 287 | "Epoch 1/100\n", 288 | "243/243 [==============================] - 442s 2s/step - loss: 0.1592 - psnr: 64.1155 - val_loss: 0.1249 - val_psnr: 65.5065\n", 289 | "Epoch 2/100\n", 290 | "243/243 [==============================] - 433s 2s/step - loss: 0.1509 - psnr: 64.5416 - val_loss: 0.1429 - val_psnr: 64.6090\n", 291 | "Epoch 3/100\n", 292 | "243/243 [==============================] - 431s 2s/step - loss: 0.1439 - psnr: 65.2337 - val_loss: 0.1303 - val_psnr: 65.2745\n", 293 | "Epoch 4/100\n", 294 | "243/243 [==============================] - 436s 2s/step - loss: 0.1285 - psnr: 65.8450 - val_loss: 0.1140 - val_psnr: 66.2238\n", 295 | "Epoch 5/100\n", 296 | "243/243 [==============================] - 438s 2s/step - loss: 0.1194 - psnr: 66.4395 - val_loss: 0.1135 - val_psnr: 66.3191\n", 297 | "Epoch 6/100\n", 298 | "243/243 [==============================] - 437s 2s/step - loss: 0.1102 - psnr: 67.0214 - val_loss: 0.1049 - val_psnr: 66.9083\n", 299 | "Epoch 7/100\n", 300 | "243/243 [==============================] - 430s 2s/step - loss: 0.1060 - psnr: 67.3232 - val_loss: 0.1132 - val_psnr: 66.4320\n", 301 | "Epoch 8/100\n", 302 | "243/243 [==============================] - 431s 2s/step - loss: 0.1069 - psnr: 67.3737 - val_loss: 0.1094 - val_psnr: 66.4070\n", 303 | "Epoch 9/100\n", 304 | "243/243 [==============================] - 431s 2s/step - loss: 0.1044 - psnr: 67.4384 - val_loss: 0.1210 - val_psnr: 65.9124\n", 305 | "Epoch 10/100\n", 306 | "243/243 [==============================] - 432s 2s/step - loss: 0.1080 - psnr: 67.3665 - val_loss: 0.1115 - val_psnr: 66.9443\n", 307 | "Epoch 11/100\n", 308 | "243/243 [==============================] - 438s 2s/step - loss: 0.1041 - psnr: 67.5550 - val_loss: 0.1022 - val_psnr: 67.0890\n", 309 | "Epoch 12/100\n", 310 | "243/243 [==============================] - 432s 2s/step - loss: 0.1043 - psnr: 67.4996 - val_loss: 0.1067 - val_psnr: 67.0900\n", 311 | "Epoch 13/100\n", 312 | "243/243 [==============================] - 438s 2s/step - loss: 0.1059 - psnr: 67.4166 - val_loss: 0.0977 - val_psnr: 67.6418\n", 313 | "Epoch 14/100\n", 314 | "243/243 
[==============================] - 438s 2s/step - loss: 0.0989 - psnr: 67.9418 - val_loss: 0.0963 - val_psnr: 68.1703\n", 315 | "Epoch 15/100\n", 316 | "243/243 [==============================] - 438s 2s/step - loss: 0.1015 - psnr: 67.7641 - val_loss: 0.0939 - val_psnr: 68.4126\n", 317 | "Epoch 16/100\n", 318 | "243/243 [==============================] - 431s 2s/step - loss: 0.0982 - psnr: 67.9785 - val_loss: 0.0951 - val_psnr: 68.0276\n", 319 | "Epoch 17/100\n", 320 | "243/243 [==============================] - 431s 2s/step - loss: 0.0981 - psnr: 68.0821 - val_loss: 0.1169 - val_psnr: 66.7316\n", 321 | "Epoch 18/100\n", 322 | "243/243 [==============================] - 430s 2s/step - loss: 0.1004 - psnr: 67.8817 - val_loss: 0.1016 - val_psnr: 67.3338\n", 323 | "Epoch 19/100\n", 324 | "243/243 [==============================] - 430s 2s/step - loss: 0.0959 - psnr: 68.2082 - val_loss: 0.0995 - val_psnr: 67.4730\n", 325 | "Epoch 20/100\n", 326 | "243/243 [==============================] - ETA: 0s - loss: 0.1026 - psnr: 67.7257\n", 327 | "Epoch 00020: ReduceLROnPlateau reducing learning rate to 4.999999873689376e-05.\n", 328 | "243/243 [==============================] - 430s 2s/step - loss: 0.1026 - psnr: 67.7257 - val_loss: 0.1003 - val_psnr: 67.3352\n", 329 | "Epoch 21/100\n", 330 | "243/243 [==============================] - 436s 2s/step - loss: 0.0919 - psnr: 68.6314 - val_loss: 0.0826 - val_psnr: 69.3167\n", 331 | "Epoch 22/100\n", 332 | "243/243 [==============================] - 430s 2s/step - loss: 0.0936 - psnr: 68.3714 - val_loss: 0.1035 - val_psnr: 67.5037\n", 333 | "Epoch 23/100\n", 334 | "243/243 [==============================] - 432s 2s/step - loss: 0.0896 - psnr: 68.7174 - val_loss: 0.0862 - val_psnr: 69.1280\n", 335 | "Epoch 24/100\n", 336 | "243/243 [==============================] - 435s 2s/step - loss: 0.0917 - psnr: 68.6044 - val_loss: 0.0796 - val_psnr: 69.3943\n", 337 | "Epoch 25/100\n", 338 | "243/243 [==============================] - 431s 2s/step - loss: 0.0895 - psnr: 68.7495 - val_loss: 0.0898 - val_psnr: 68.8661\n", 339 | "Epoch 26/100\n", 340 | "243/243 [==============================] - 430s 2s/step - loss: 0.0890 - psnr: 68.8537 - val_loss: 0.0983 - val_psnr: 68.0671\n", 341 | "Epoch 27/100\n", 342 | "243/243 [==============================] - 431s 2s/step - loss: 0.0920 - psnr: 68.5267 - val_loss: 0.0975 - val_psnr: 67.9822\n", 343 | "Epoch 28/100\n", 344 | "243/243 [==============================] - 431s 2s/step - loss: 0.0897 - psnr: 68.7615 - val_loss: 0.1008 - val_psnr: 67.6455\n", 345 | "Epoch 29/100\n", 346 | "243/243 [==============================] - ETA: 0s - loss: 0.0864 - psnr: 69.0921\n", 347 | "Epoch 00029: ReduceLROnPlateau reducing learning rate to 2.499999936844688e-05.\n", 348 | "243/243 [==============================] - 430s 2s/step - loss: 0.0864 - psnr: 69.0921 - val_loss: 0.0940 - val_psnr: 68.5634\n", 349 | "Epoch 30/100\n", 350 | "243/243 [==============================] - 431s 2s/step - loss: 0.0868 - psnr: 69.0637 - val_loss: 0.0917 - val_psnr: 68.4715\n", 351 | "Epoch 31/100\n", 352 | "243/243 [==============================] - 436s 2s/step - loss: 0.0845 - psnr: 69.2952 - val_loss: 0.0791 - val_psnr: 69.2714\n", 353 | "Epoch 32/100\n", 354 | "243/243 [==============================] - 431s 2s/step - loss: 0.0867 - psnr: 69.1577 - val_loss: 0.0809 - val_psnr: 69.5186\n", 355 | "Epoch 33/100\n", 356 | "243/243 [==============================] - 430s 2s/step - loss: 0.0860 - psnr: 69.3142 - val_loss: 0.0868 - val_psnr: 68.9742\n", 357 | 
"Epoch 34/100\n", 358 | "243/243 [==============================] - 431s 2s/step - loss: 0.0858 - psnr: 69.1895 - val_loss: 0.0890 - val_psnr: 68.3421\n", 359 | "Epoch 35/100\n", 360 | "243/243 [==============================] - 430s 2s/step - loss: 0.0860 - psnr: 69.1920 - val_loss: 0.0791 - val_psnr: 69.2999\n", 361 | "Epoch 36/100\n", 362 | "243/243 [==============================] - 430s 2s/step - loss: 0.0867 - psnr: 69.2137 - val_loss: 0.0832 - val_psnr: 68.8973\n", 363 | "Epoch 37/100\n", 364 | "243/243 [==============================] - 436s 2s/step - loss: 0.0853 - psnr: 69.2414 - val_loss: 0.0716 - val_psnr: 70.2069\n", 365 | "Epoch 38/100\n", 366 | "243/243 [==============================] - 431s 2s/step - loss: 0.0858 - psnr: 69.2487 - val_loss: 0.0895 - val_psnr: 68.9241\n", 367 | "Epoch 39/100\n", 368 | "243/243 [==============================] - 431s 2s/step - loss: 0.0856 - psnr: 69.2450 - val_loss: 0.0896 - val_psnr: 68.6078\n", 369 | "Epoch 40/100\n", 370 | "243/243 [==============================] - 430s 2s/step - loss: 0.0840 - psnr: 69.2957 - val_loss: 0.0752 - val_psnr: 69.8205\n", 371 | "Epoch 41/100\n", 372 | "243/243 [==============================] - 430s 2s/step - loss: 0.0843 - psnr: 69.3408 - val_loss: 0.0924 - val_psnr: 68.4360\n", 373 | "Epoch 42/100\n", 374 | "243/243 [==============================] - ETA: 0s - loss: 0.0834 - psnr: 69.4409\n", 375 | "Epoch 00042: ReduceLROnPlateau reducing learning rate to 1.249999968422344e-05.\n", 376 | "243/243 [==============================] - 430s 2s/step - loss: 0.0834 - psnr: 69.4409 - val_loss: 0.0966 - val_psnr: 67.8450\n", 377 | "Epoch 43/100\n", 378 | "243/243 [==============================] - 431s 2s/step - loss: 0.0803 - psnr: 69.8295 - val_loss: 0.0844 - val_psnr: 69.0862\n", 379 | "Epoch 44/100\n", 380 | "243/243 [==============================] - 431s 2s/step - loss: 0.0801 - psnr: 69.7533 - val_loss: 0.0876 - val_psnr: 68.7879\n", 381 | "Epoch 45/100\n", 382 | "243/243 [==============================] - 431s 2s/step - loss: 0.0826 - psnr: 69.6043 - val_loss: 0.0934 - val_psnr: 68.2593\n", 383 | "Epoch 46/100\n", 384 | "243/243 [==============================] - 430s 2s/step - loss: 0.0789 - psnr: 69.9293 - val_loss: 0.0947 - val_psnr: 68.2925\n", 385 | "Epoch 47/100\n", 386 | "243/243 [==============================] - ETA: 0s - loss: 0.0803 - psnr: 69.7273\n", 387 | "Epoch 00047: ReduceLROnPlateau reducing learning rate to 6.24999984211172e-06.\n", 388 | "243/243 [==============================] - 429s 2s/step - loss: 0.0803 - psnr: 69.7273 - val_loss: 0.0935 - val_psnr: 68.1041\n" 389 | ], 390 | "name": "stdout" 391 | }, 392 | { 393 | "output_type": "execute_result", 394 | "data": { 395 | "text/plain": [ 396 | "" 397 | ] 398 | }, 399 | "metadata": { 400 | "tags": [] 401 | }, 402 | "execution_count": 9 403 | } 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "metadata": { 409 | "colab": { 410 | "base_uri": "https://localhost:8080/", 411 | "height": 17 412 | }, 413 | "id": "dGkCwyM2_odJ", 414 | "outputId": "451d95d4-8113-4322-fbbb-6f4b40c0898a" 415 | }, 416 | "source": [ 417 | "from glob import glob\n", 418 | "from google.colab import files\n", 419 | "\n", 420 | "\n", 421 | "for file in glob('/content/MIRNet/checkpoints/*'):\n", 422 | " files.download(file)" 423 | ], 424 | "execution_count": 14, 425 | "outputs": [ 426 | { 427 | "output_type": "display_data", 428 | "data": { 429 | "application/javascript": [ 430 | "\n", 431 | " async function download(id, filename, size) {\n", 432 | " if 
(!google.colab.kernel.accessAllowed) {\n", 433 | " return;\n", 434 | " }\n", 435 | " const div = document.createElement('div');\n", 436 | " const label = document.createElement('label');\n", 437 | " label.textContent = `Downloading \"${filename}\": `;\n", 438 | " div.appendChild(label);\n", 439 | " const progress = document.createElement('progress');\n", 440 | " progress.max = size;\n", 441 | " div.appendChild(progress);\n", 442 | " document.body.appendChild(div);\n", 443 | "\n", 444 | " const buffers = [];\n", 445 | " let downloaded = 0;\n", 446 | "\n", 447 | " const channel = await google.colab.kernel.comms.open(id);\n", 448 | " // Send a message to notify the kernel that we're ready.\n", 449 | " channel.send({})\n", 450 | "\n", 451 | " for await (const message of channel.messages) {\n", 452 | " // Send a message to notify the kernel that we're ready.\n", 453 | " channel.send({})\n", 454 | " if (message.buffers) {\n", 455 | " for (const buffer of message.buffers) {\n", 456 | " buffers.push(buffer);\n", 457 | " downloaded += buffer.byteLength;\n", 458 | " progress.value = downloaded;\n", 459 | " }\n", 460 | " }\n", 461 | " }\n", 462 | " const blob = new Blob(buffers, {type: 'application/binary'});\n", 463 | " const a = document.createElement('a');\n", 464 | " a.href = window.URL.createObjectURL(blob);\n", 465 | " a.download = filename;\n", 466 | " div.appendChild(a);\n", 467 | " a.click();\n", 468 | " div.remove();\n", 469 | " }\n", 470 | " " 471 | ], 472 | "text/plain": [ 473 | "" 474 | ] 475 | }, 476 | "metadata": { 477 | "tags": [] 478 | } 479 | }, 480 | { 481 | "output_type": "display_data", 482 | "data": { 483 | "application/javascript": [ 484 | "download(\"download_6ea7b0ed-1fe2-4f54-a838-3f3f7d13aa3c\", \"low_light_weights_best_256x256.h5\", 147825648)" 485 | ], 486 | "text/plain": [ 487 | "" 488 | ] 489 | }, 490 | "metadata": { 491 | "tags": [] 492 | } 493 | } 494 | ] 495 | }, 496 | { 497 | "cell_type": "code", 498 | "metadata": { 499 | "id": "IrgoJL0xZTuv" 500 | }, 501 | "source": [ 502 | "" 503 | ], 504 | "execution_count": null, 505 | "outputs": [] 506 | } 507 | ] 508 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gdown==3.12.2 2 | matplotlib==3.3.3 3 | streamlit==0.74.1 4 | tensorflow==2.4.0 5 | wandb==0.10.14 6 | Pillow~=8.1.0 7 | numpy~=1.19.5 8 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | streamlit run app.py 2 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | mkdir -p ~/.streamlit/ 2 | 3 | echo "\ 4 | [server]\n\ 5 | port = $PORT\n\ 6 | enableCORS = false\n\ 7 | headless = true\n\ 8 | \n\ 9 | " > ~/.streamlit/config.toml 10 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import tensorflow as tf 3 | from matplotlib import pyplot as plt 4 | from mirnet.model import mirnet_model 5 | from mirnet.dataloaders import LOLDataLoader 6 | 7 | 8 | def test_output_dim(): 9 | mirnet = mirnet_model(256, 3, 2, 64) 10 | x = tf.ones((1, 256, 256, 3)) 11 | y = mirnet(x) 12 | assert x.shape == y.shape 13 | 14 | 15 | def test_dataloader(): 16 | 
lowlight_images = glob('./data/LOLdataset/our485/low/*') 17 | highlight_images = glob('./data/LOLdataset/our485/high/*') 18 | dataset = LOLDataLoader( 19 | images_lowlight=lowlight_images, 20 | images_highlight=highlight_images 21 | ).build_dataset( 22 | image_crop_size=128, batch_size=1, is_dataset_train=True 23 | ) 24 | print(dataset) 25 | x, y = next(iter(dataset)) 26 | print(x.shape, y.shape) 27 | plt.imshow(tf.cast(x[0] * 255, dtype=tf.uint8)) 28 | plt.title('Low Light Patch (128 x 128)') 29 | plt.show() 30 | plt.title('High Light Patch (128 x 128)') 31 | plt.imshow(tf.cast(y[0] * 255, dtype=tf.uint8)) 32 | plt.show() 33 | 34 | 35 | if __name__ == '__main__': 36 | # test_dataloader() 37 | test_output_dim() 38 | --------------------------------------------------------------------------------
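Beyond the two smoke tests above, a hedged evaluation sketch along the following lines could report PSNR on the eval15 split; the checkpoint location and the 128-pixel evaluation crop are assumptions rather than part of the repository:

```python
import tensorflow as tf
from glob import glob

from mirnet.model import mirnet_model
from mirnet.utils import psnr
from mirnet.dataloaders import LOLDataLoader

# rebuild the architecture with the hyperparameters used throughout the repo
model = mirnet_model(image_size=128, num_rrg=3, num_mrb=2, channels=64)
# assumed location of the weights saved by LowLightTrainer
model.load_weights('./checkpoints/low_light_weights_best.h5')
model.compile(loss=tf.keras.losses.MeanAbsoluteError(), metrics=[psnr])

# evaluate on the eval15 split using deterministic (non-augmented) batches
valid_dataset = LOLDataLoader(
    images_lowlight=glob('./data/LOLdataset/eval15/low/*'),
    images_highlight=glob('./data/LOLdataset/eval15/high/*')
).build_dataset(image_crop_size=128, batch_size=1, is_dataset_train=False)

loss, psnr_value = model.evaluate(valid_dataset)
print('Validation PSNR:', psnr_value)
```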