├── .gitignore
├── LICENSE
├── Procfile
├── README.md
├── app.py
├── assets
│   ├── lol_results.gif
│   └── mirnet_architecture.png
├── checkpoints
│   └── .gitkeep
├── main.py
├── mirnet
│   ├── __init__.py
│   ├── dataloaders
│   │   ├── __init__.py
│   │   ├── common.py
│   │   └── dataloader.py
│   ├── inference.py
│   ├── losses.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── dual_attention_unit
│   │   │   ├── __init__.py
│   │   │   ├── attention_blocks.py
│   │   │   └── dau.py
│   │   ├── mirnet_model.py
│   │   ├── residual_resizing_modules.py
│   │   └── skff.py
│   ├── train.py
│   └── utils.py
├── notebooks
│   ├── .gitkeep
│   ├── MIRNet_LOL_Inference.ipynb
│   ├── MIRNet_LOL_Inference_256x256.ipynb
│   └── MIRNet_Low_Light_Train.ipynb
├── requirements.txt
├── run.sh
├── setup.sh
└── test.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 | venv/
3 | .idea/
4 | **/__pycache__/**
5 | .DS_Store
6 | checkpoints/
7 | **.h5
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/Procfile:
--------------------------------------------------------------------------------
1 | web: sh setup.sh && streamlit run app.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MIRNet
2 |
3 | [![Streamlit App](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://share.streamlit.io/soumik12345/mirnet/app.py)
4 |
5 | TensorFlow implementation of the MIRNet architecture, as proposed in [Learning Enriched Features for Real Image
6 | Restoration and Enhancement](https://arxiv.org/pdf/2003.06792v2.pdf).
7 |
8 | **Launch Notebooks:** [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/soumik12345/MIRNet/HEAD)
9 |
10 | **Wandb Logs:** [https://wandb.ai/19soumik-rakshit96/mirnet](https://wandb.ai/19soumik-rakshit96/mirnet)
11 |
12 | **Blog Post:** [https://keras.io/examples/vision/mirnet/](https://keras.io/examples/vision/mirnet/)
13 |
14 | **TFLite Variant of MIRNet:** [https://github.com/sayakpaul/MIRNet-TFLite](https://github.com/sayakpaul/MIRNet-TFLite)
15 | 
16 | **TFLite Models on TensorFlow Hub:** [https://tfhub.dev/sayakpaul/lite-model/mirnet-fixed/dr/1](https://tfhub.dev/sayakpaul/lite-model/mirnet-fixed/dr/1)
17 | 
18 | **TensorFlow JS Variant of MIRNet:** [https://github.com/Rishit-dagli/MIRNet-TFJS](https://github.com/Rishit-dagli/MIRNet-TFJS)
19 |
20 | ![MIRNet architecture](./assets/mirnet_architecture.png)
21 |
22 | ![Low-light enhancement results on the LOL dataset](./assets/lol_results.gif)
23 |
24 | ## Pre-trained Weights
25 |
26 | - **Trained on 128x128 patches:** [https://drive.google.com/file/d/1sUlRD5MTRKKGxtqyYDpTv7T3jOW6aVAL/view?usp=sharing](https://drive.google.com/file/d/1sUlRD5MTRKKGxtqyYDpTv7T3jOW6aVAL/view?usp=sharing)
27 |
28 | - **Trained on 256x256 patches:** [https://drive.google.com/file/d/1sUlRD5MTRKKGxtqyYDpTv7T3jOW6aVAL/view?usp=sharing](https://drive.google.com/file/d/1sUlRD5MTRKKGxtqyYDpTv7T3jOW6aVAL/view?usp=sharing)
29 |
30 | ## Citation
31 |
32 | ```
33 | @misc{
34 | 2003.06792,
35 | Author = {Syed Waqas Zamir and Aditya Arora and Salman Khan and Munawar Hayat and Fahad Shahbaz Khan and Ming-Hsuan Yang and Ling Shao},
36 | Title = {Learning Enriched Features for Real Image Restoration and Enhancement},
37 | Year = {2020},
38 | Eprint = {arXiv:2003.06792},
39 | }
40 | ```
41 |
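42 | ## Usage
43 | 
44 | A minimal inference sketch. The weight-file name, Google Drive file ID, and model hyperparameters below are the ones hard-coded in `app.py`; the input path is illustrative:
45 | 
46 | ```python
47 | from mirnet.inference import Inferer
48 | 
49 | inferer = Inferer()
50 | # downloads low_light_weights_best.h5 (same Drive file ID that app.py uses)
51 | inferer.download_weights('1sUlRD5MTRKKGxtqyYDpTv7T3jOW6aVAL')
52 | inferer.build_model(
53 |     num_rrg=3, num_mrb=2, channels=64,
54 |     weights_path='low_light_weights_best.h5'
55 | )
56 | original, enhanced = inferer.infer('path/to/low_light_image.png')
57 | enhanced.save('enhanced.png')
58 | ```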
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | from PIL import Image
4 | import streamlit as st
5 | from mirnet.inference import Inferer
6 |
7 |
8 | def main():
9 | st.markdown(
10 | '
Low-light Image Enhancement using MIRNet
',
11 | unsafe_allow_html=True
12 | )
13 | inferer = Inferer()
14 | if not os.path.exists('low_light_weights_best.h5'):
15 | st.sidebar.text('Downloading Model weights...')
16 | inferer.download_weights('1sUlRD5MTRKKGxtqyYDpTv7T3jOW6aVAL')
17 | st.sidebar.text('Done')
18 | st.sidebar.text('Building MIRNet Model...')
19 | inferer.build_model(
20 | num_rrg=3, num_mrb=2, channels=64,
21 | weights_path='low_light_weights_best.h5'
22 | )
23 | st.sidebar.text('Done')
24 | uploaded_files = st.sidebar.file_uploader(
25 | 'Please Upload your Low-light Images',
26 | accept_multiple_files=True
27 | )
28 | col_1, col_2 = st.beta_columns(2)
29 | if len(uploaded_files) > 0:
30 | for uploaded_file in uploaded_files:
31 | pil_image = Image.open(uploaded_file)
32 | original_image, output_image = inferer.infer_streamlit(pil_image)
33 | with col_1:
34 | st.image(
35 | original_image, use_column_width=True,
36 | caption='Original Image'
37 | )
38 | with col_2:
39 | st.image(
40 | output_image, use_column_width=True,
41 | caption='Predicted Image'
42 | )
43 | st.markdown('---')
44 | if os.path.exists('low_light_weights_best.h5'):
45 | subprocess.run(['rm', 'low_light_weights_best.h5'])
46 |
47 |
48 | if __name__ == '__main__':
49 | main()
50 |
--------------------------------------------------------------------------------
/assets/lol_results.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soumik12345/MIRNet/c6812f5bc4ac87e4e63af21aa4e0db84597a17c8/assets/lol_results.gif
--------------------------------------------------------------------------------
/assets/mirnet_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soumik12345/MIRNet/c6812f5bc4ac87e4e63af21aa4e0db84597a17c8/assets/mirnet_architecture.png
--------------------------------------------------------------------------------
/checkpoints/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soumik12345/MIRNet/c6812f5bc4ac87e4e63af21aa4e0db84597a17c8/checkpoints/.gitkeep
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | from mirnet.train import LowLightTrainer
3 | from mirnet.utils import init_wandb, download_dataset
4 |
5 |
6 | download_dataset('LOL')
7 |
8 | init_wandb(
9 | project_name='mirnet', experiment_name='LOL_lowlight',
10 | wandb_api_key='cf0947ccde62903d4df0742a58b8a54ca4c11673'
11 | )
12 |
13 | train_low_light_images = glob('./our485/low/*')
14 | train_high_light_images = glob('./our485/high/*')
15 | valid_low_light_images = glob('./eval15/low/*')
16 | valid_high_light_images = glob('./eval15/high/*')
17 |
18 | trainer = LowLightTrainer()
19 | trainer.build_dataset(
20 | train_low_light_images, train_high_light_images,
21 | valid_low_light_images, valid_high_light_images,
22 | crop_size=128, batch_size=16
23 | )
24 |
25 | trainer.compile()
26 |
27 | trainer.train(epochs=100, checkpoint_dir='./checkpoints')
28 |
--------------------------------------------------------------------------------
/mirnet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soumik12345/MIRNet/c6812f5bc4ac87e4e63af21aa4e0db84597a17c8/mirnet/__init__.py
--------------------------------------------------------------------------------
/mirnet/dataloaders/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataloader import LOLDataLoader
2 |
--------------------------------------------------------------------------------
/mirnet/dataloaders/common.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def read_images(image_files):
5 | dataset = tf.data.Dataset.from_tensor_slices(image_files)
6 | dataset = dataset.map(tf.io.read_file)
7 | dataset = dataset.map(
8 | lambda x: tf.image.decode_png(x, channels=3),
9 | num_parallel_calls=tf.data.experimental.AUTOTUNE
10 | )
11 | return dataset
12 |
13 |
14 | def random_crop(low_image, enhanced_image, low_crop_size, enhanced_crop_size):
15 | low_image_shape = tf.shape(low_image)[:2]
16 | low_w = tf.random.uniform(
17 | shape=(), maxval=low_image_shape[1] - low_crop_size + 1, dtype=tf.int32)
18 | low_h = tf.random.uniform(
19 | shape=(), maxval=low_image_shape[0] - low_crop_size + 1, dtype=tf.int32)
20 | enhanced_w = low_w
21 | enhanced_h = low_h
22 | low_image_cropped = low_image[
23 | low_h:low_h + low_crop_size,
24 | low_w:low_w + low_crop_size
25 | ]
26 | enhanced_image_cropped = enhanced_image[
27 | enhanced_h:enhanced_h + enhanced_crop_size,
28 | enhanced_w:enhanced_w + enhanced_crop_size
29 | ]
30 | return low_image_cropped, enhanced_image_cropped
31 |
32 |
33 | def random_flip(low_image, enhanced_image):
34 | return tf.cond(
35 | tf.random.uniform(shape=(), maxval=1) < 0.5,
36 | lambda: (low_image, enhanced_image),
37 | lambda: (
38 | tf.image.flip_left_right(low_image),
39 | tf.image.flip_left_right(enhanced_image)
40 | )
41 | )
42 |
43 |
44 | def random_rotate(low_image, enhanced_image):
45 | condition = tf.random.uniform(shape=(), maxval=4, dtype=tf.int32)
46 | return tf.image.rot90(low_image, condition), tf.image.rot90(enhanced_image, condition)
47 |
48 |
49 | def apply_scaling(low_image, enhanced_image):
50 | low_image = tf.cast(low_image, tf.float32)
51 | enhanced_image = tf.cast(enhanced_image, tf.float32)
52 | low_image = low_image / 255.0
53 | enhanced_image = enhanced_image / 255.0
54 | return low_image, enhanced_image
55 |
--------------------------------------------------------------------------------
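A quick eager-mode check of the paired augmentations above (image sizes are illustrative). The key detail is that `random_crop` draws a single offset and applies it to both images, which keeps the low-light input and its ground truth pixel-aligned:

```python
import tensorflow as tf

from mirnet.dataloaders.common import random_crop, random_flip, random_rotate

low = tf.random.uniform((400, 600, 3))   # stand-in low-light image
high = tf.random.uniform((400, 600, 3))  # stand-in ground-truth image
low_patch, high_patch = random_crop(low, high, 128, 128)
print(low_patch.shape, high_patch.shape)  # (128, 128, 3) (128, 128, 3)
# rotation and flip are also applied jointly: one random draw controls both images
low_patch, high_patch = random_flip(*random_rotate(low_patch, high_patch))
```
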
/mirnet/dataloaders/dataloader.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from .common import *
3 |
4 |
5 | class LOLDataLoader:
6 |
7 | def __init__(self, images_lowlight: List[str], images_highlight: List[str]):
8 | self.images_lowlight = images_lowlight
9 | self.images_highlight = images_highlight
10 |
11 | def __len__(self):
12 | assert len(self.images_lowlight) == len(self.images_highlight)
13 | return len(self.images_lowlight)
14 |
15 | def build_dataset(self, image_crop_size: int, batch_size: int, is_dataset_train: bool):
16 | low_light_dataset = read_images(self.images_lowlight)
17 | high_light_dataset = read_images(self.images_highlight)
18 | dataset = tf.data.Dataset.zip((low_light_dataset, high_light_dataset))
19 | dataset = dataset.map(apply_scaling, num_parallel_calls=tf.data.experimental.AUTOTUNE)
20 | dataset = dataset.map(
21 | lambda low, high: random_crop(low, high, image_crop_size, image_crop_size),
22 | num_parallel_calls=tf.data.experimental.AUTOTUNE
23 | )
24 | if is_dataset_train:
25 | dataset = dataset.map(random_rotate, num_parallel_calls=tf.data.experimental.AUTOTUNE)
26 | dataset = dataset.map(random_flip, num_parallel_calls=tf.data.experimental.AUTOTUNE)
27 | dataset = dataset.batch(batch_size)
28 | dataset = dataset.repeat(1)
29 | dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
30 | return dataset
31 |
--------------------------------------------------------------------------------
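A sketch of consuming the loader directly (file patterns follow `main.py`; it assumes the LOL dataset has been unpacked into the working directory):

```python
from glob import glob

from mirnet.dataloaders import LOLDataLoader

dataset = LOLDataLoader(
    images_lowlight=glob('./our485/low/*'),
    images_highlight=glob('./our485/high/*')
).build_dataset(image_crop_size=128, batch_size=16, is_dataset_train=True)

low_batch, high_batch = next(iter(dataset))
print(low_batch.shape, high_batch.shape)  # (16, 128, 128, 3) (16, 128, 128, 3)
```
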
/mirnet/inference.py:
--------------------------------------------------------------------------------
1 | import gdown
2 | import numpy as np
3 | from PIL import Image
4 | import tensorflow as tf
5 | from .model import mirnet_model
6 | from .utils import closest_number
7 |
8 |
9 | class Inferer:
10 |
11 | def __init__(self):
12 | self.model = None
13 |
14 | @staticmethod
15 | def download_weights(file_id: str):
16 | gdown.download(
17 | 'https://drive.google.com/uc?id={}'.format(file_id),
18 | 'low_light_weights_best.h5', quiet=False
19 | )
20 |
21 | def build_model(self, num_rrg: int, num_mrb: int, channels: int, weights_path: str):
22 | self.model = mirnet_model(
23 | image_size=None, num_rrg=num_rrg,
24 | num_mrb=num_mrb, channels=channels
25 | )
26 | self.model.load_weights(weights_path)
27 |
28 | def _predict(self, original_image, image_resize_factor: float = 1.):
29 | width, height = original_image.size
30 | target_width, target_height = (
31 | closest_number(width // image_resize_factor, 4),
32 | closest_number(height // image_resize_factor, 4)
33 | )
34 | original_image = original_image.resize(
35 | (target_width, target_height), Image.ANTIALIAS
36 | )
37 | image = tf.keras.preprocessing.image.img_to_array(original_image)
38 | image = image.astype('float32') / 255.0
39 | image = np.expand_dims(image, axis=0)
40 | output = self.model.predict(image)
41 | output_image = output[0] * 255.0
42 | output_image = output_image.clip(0, 255)
43 | output_image = output_image.reshape(
44 | (np.shape(output_image)[0], np.shape(output_image)[1], 3)
45 | )
46 | output_image = Image.fromarray(np.uint8(output_image))
47 | original_image = Image.fromarray(np.uint8(original_image))
48 | return output_image
49 |
50 | def infer(self, image_path, image_resize_factor: float = 1.):
51 | original_image = Image.open(image_path)
52 | output_image = self._predict(original_image, image_resize_factor)
53 | return original_image, output_image
54 |
55 | def infer_streamlit(self, image_pil, image_resize_factor: float = 1.):
56 | original_image = image_pil
57 | output_image = self._predict(original_image, image_resize_factor)
58 | return original_image, output_image
59 |
--------------------------------------------------------------------------------
/mirnet/losses.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def charbonnier_loss(y_true, y_pred):
5 | return tf.reduce_mean(
6 | tf.sqrt(tf.square(y_true - y_pred) + tf.square(1e-3))
7 | )
8 |
--------------------------------------------------------------------------------
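The Charbonnier loss above is a smooth L1: sqrt((y_true - y_pred)^2 + eps^2) with eps = 1e-3, so it behaves like L2 for errors below eps and like L1 above it. A small numeric check (values illustrative):

```python
import tensorflow as tf

from mirnet.losses import charbonnier_loss

y_true = tf.zeros((1, 4, 4, 3))
y_pred = tf.fill((1, 4, 4, 3), 0.1)
# the error (0.1) dwarfs eps (1e-3), so the loss is essentially mean |error|
print(charbonnier_loss(y_true, y_pred).numpy())  # ~0.100005
```
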
/mirnet/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .mirnet_model import mirnet_model
2 |
--------------------------------------------------------------------------------
/mirnet/model/dual_attention_unit/__init__.py:
--------------------------------------------------------------------------------
1 | from .dau import dual_attention_unit_block
2 |
--------------------------------------------------------------------------------
/mirnet/model/dual_attention_unit/attention_blocks.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def spatial_attention_block(input_tensor):
5 | """Spatial Attention Block"""
6 | max_pooling = tf.reduce_max(input_tensor, axis=-1)
7 | max_pooling = tf.expand_dims(max_pooling, axis=-1)
8 | average_pooling = tf.reduce_mean(input_tensor, axis=-1)
9 | average_pooling = tf.expand_dims(average_pooling, axis=-1)
10 | concatenated = tf.keras.layers.Concatenate(axis=-1)([max_pooling, average_pooling])
11 | feature_map = tf.keras.layers.Conv2D(1, kernel_size=(1, 1))(concatenated)
12 | feature_map = tf.nn.sigmoid(feature_map)
13 | return input_tensor * feature_map
14 |
15 |
16 | def channel_attention_block(input_tensor):
17 | """Channel Attention Block"""
18 | channels = list(input_tensor.shape)[-1]
19 | average_pooling = tf.keras.layers.GlobalAveragePooling2D()(input_tensor)
20 | feature_descriptor = tf.reshape(average_pooling, shape=(-1, 1, 1, channels))
21 | feature_activations = tf.keras.layers.ReLU()(
22 | tf.keras.layers.Conv2D(
23 | filters=channels // 8, kernel_size=(1, 1)
24 | )(feature_descriptor)
25 | )
26 | feature_activations = tf.nn.sigmoid(
27 | tf.keras.layers.Conv2D(
28 | filters=channels, kernel_size=(1, 1)
29 | )(feature_activations)
30 | )
31 | return input_tensor * feature_activations
32 |
--------------------------------------------------------------------------------
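Both attention blocks return a tensor with the input's shape: the channel block rescales each feature map by a squeeze-and-excitation-style descriptor, while the spatial block rescales each pixel by a mask computed from channel-wise max and mean. A shape check (sizes illustrative):

```python
import tensorflow as tf

from mirnet.model.dual_attention_unit.attention_blocks import (
    channel_attention_block, spatial_attention_block)

x = tf.random.uniform((1, 32, 32, 64))
print(channel_attention_block(x).shape)  # (1, 32, 32, 64): per-channel gating
print(spatial_attention_block(x).shape)  # (1, 32, 32, 64): per-pixel gating
```
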
/mirnet/model/dual_attention_unit/dau.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from .attention_blocks import (
3 | channel_attention_block,
4 | spatial_attention_block
5 | )
6 |
7 |
8 | def dual_attention_unit_block(input_tensor):
9 | """Dual Attention Unit Block"""
10 | channels = list(input_tensor.shape)[-1]
11 | feature_map = tf.keras.layers.Conv2D(
12 | channels, kernel_size=(3, 3), padding='same')(input_tensor)
13 | feature_map = tf.keras.layers.ReLU()(feature_map)
14 | feature_map = tf.keras.layers.Conv2D(
15 | channels, kernel_size=(3, 3), padding='same')(feature_map)
16 | channel_attention = channel_attention_block(feature_map)
17 | spatial_attention = spatial_attention_block(feature_map)
18 | concatenation = tf.keras.layers.Concatenate(axis=-1)([
19 | channel_attention, spatial_attention])
20 | concatenation = tf.keras.layers.Conv2D(channels, kernel_size=(1, 1))(concatenation)
21 | return tf.keras.layers.Add()([input_tensor, concatenation])
22 |
--------------------------------------------------------------------------------
/mirnet/model/mirnet_model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from .skff import selective_kernel_feature_fusion
3 | from .dual_attention_unit import dual_attention_unit_block
4 | from .residual_resizing_modules import up_sampling_module, down_sampling_module
5 |
6 |
7 | def multi_scale_residual_block(input_tensor, channels):
8 | # features
9 | level1 = input_tensor
10 | level2 = down_sampling_module(input_tensor)
11 | level3 = down_sampling_module(level2)
12 | # DAU
13 | level1_dau = dual_attention_unit_block(level1)
14 | level2_dau = dual_attention_unit_block(level2)
15 | level3_dau = dual_attention_unit_block(level3)
16 | # SKFF
17 | level1_skff = selective_kernel_feature_fusion(
18 | level1_dau, up_sampling_module(level2_dau),
19 | up_sampling_module(up_sampling_module(level3_dau))
20 | )
21 | level2_skff = selective_kernel_feature_fusion(
22 | down_sampling_module(level1_dau), level2_dau,
23 | up_sampling_module(level3_dau)
24 | )
25 | level3_skff = selective_kernel_feature_fusion(
26 | down_sampling_module(down_sampling_module(level1_dau)),
27 | down_sampling_module(level2_dau), level3_dau
28 | )
29 | # DAU 2
30 | level1_dau_2 = dual_attention_unit_block(level1_skff)
31 | level2_dau_2 = up_sampling_module((dual_attention_unit_block(level2_skff)))
32 | level3_dau_2 = up_sampling_module(up_sampling_module(dual_attention_unit_block(level3_skff)))
33 | # SKFF 2
34 | skff_ = selective_kernel_feature_fusion(
35 | level1_dau_2, level2_dau_2, level3_dau_2)
36 | conv = tf.keras.layers.Conv2D(channels, kernel_size=(3, 3), padding='same')(skff_)
37 | return tf.keras.layers.Add()([input_tensor, conv])
38 |
39 |
40 | def recursive_residual_group(input_tensor, num_mrb, channels):
41 | conv1 = tf.keras.layers.Conv2D(
42 | channels, kernel_size=(3, 3), padding='same')(input_tensor)
43 | for _ in range(num_mrb):
44 | conv1 = multi_scale_residual_block(conv1, channels)
45 | conv2 = tf.keras.layers.Conv2D(
46 | channels, kernel_size=(3, 3), padding='same')(conv1)
47 | return tf.keras.layers.Add()([conv2, input_tensor])
48 |
49 |
50 | def mirnet_model(image_size: int, num_rrg: int, num_mrb: int, channels: int):
51 | input_tensor = tf.keras.Input(shape=[image_size, image_size, 3])
52 | x1 = tf.keras.layers.Conv2D(
53 | channels, kernel_size=(3, 3), padding='same')(input_tensor)
54 | for _ in range(num_rrg):
55 | x1 = recursive_residual_group(x1, num_mrb, channels)
56 | conv = tf.keras.layers.Conv2D(
57 | 3, kernel_size=(3, 3), padding='same')(x1)
58 | output_tensor = tf.keras.layers.Add()([input_tensor, conv])
59 | return tf.keras.Model(input_tensor, output_tensor)
60 |
--------------------------------------------------------------------------------
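Because every block above is shape-preserving and the final convolution is residual, the model maps an H x W x 3 image to an H x W x 3 image. Passing `image_size=None` (as `mirnet/inference.py` does) builds a fully convolutional model that accepts variable input sizes, provided H and W are divisible by 4 so the two downsampling levels line up. A sketch, with the hyperparameters used by `app.py` and `LowLightTrainer.compile`:

```python
import tensorflow as tf

from mirnet.model import mirnet_model

model = mirnet_model(image_size=None, num_rrg=3, num_mrb=2, channels=64)
x = tf.random.uniform((1, 128, 128, 3))
print(model(x).shape)  # (1, 128, 128, 3): restored image at the input resolution
```
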
/mirnet/model/residual_resizing_modules.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def down_sampling_module(input_tensor):
5 | """Downsampling Module"""
6 | channels = list(input_tensor.shape)[-1]
7 | main_branch = tf.keras.layers.Conv2D(
8 | channels, kernel_size=(1, 1))(input_tensor)
9 | main_branch = tf.nn.relu(main_branch)
10 | # main_branch = tf.keras.layers.Conv2D(
11 | # channels, kernel_size=(3, 3), padding='same')(input_tensor)
12 | main_branch = tf.keras.layers.Conv2D(
13 | channels, kernel_size=(3, 3), padding='same')(main_branch)
14 | main_branch = tf.nn.relu(main_branch)
15 | main_branch = tf.keras.layers.MaxPooling2D()(main_branch)
16 | main_branch = tf.keras.layers.Conv2D(
17 | channels * 2, kernel_size=(1, 1))(main_branch)
18 | skip_branch = tf.keras.layers.MaxPooling2D()(input_tensor)
19 | skip_branch = tf.keras.layers.Conv2D(
20 | channels * 2, kernel_size=(1, 1))(skip_branch)
21 | return tf.keras.layers.Add()([skip_branch, main_branch])
22 |
23 |
24 | def up_sampling_module(input_tensor):
25 | """Upsampling Module"""
26 | channels = list(input_tensor.shape)[-1]
27 | main_branch = tf.keras.layers.Conv2D(
28 | channels, kernel_size=(1, 1))(input_tensor)
29 | main_branch = tf.nn.relu(main_branch)
30 | # main_branch = tf.keras.layers.Conv2D(
31 | # channels, kernel_size=(3, 3), padding='same')(input_tensor)
32 | main_branch = tf.keras.layers.Conv2D(
33 | channels, kernel_size=(3, 3), padding='same')(main_branch)
34 | main_branch = tf.nn.relu(main_branch)
35 | main_branch = tf.keras.layers.UpSampling2D()(main_branch)
36 | main_branch = tf.keras.layers.Conv2D(
37 | channels // 2, kernel_size=(1, 1))(main_branch)
38 | skip_branch = tf.keras.layers.UpSampling2D()(input_tensor)
39 | skip_branch = tf.keras.layers.Conv2D(
40 | channels // 2, kernel_size=(1, 1))(skip_branch)
41 | return tf.keras.layers.Add()([skip_branch, main_branch])
42 |
--------------------------------------------------------------------------------
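The two modules trade spatial resolution for channels in opposite directions: downsampling halves H and W while doubling C, and upsampling does the reverse, so a down/up round trip restores the original shape. A shape check (sizes illustrative):

```python
import tensorflow as tf

from mirnet.model.residual_resizing_modules import (
    down_sampling_module, up_sampling_module)

x = tf.random.uniform((1, 64, 64, 32))
print(down_sampling_module(x).shape)  # (1, 32, 32, 64): H, W halved; channels doubled
print(up_sampling_module(x).shape)    # (1, 128, 128, 16): H, W doubled; channels halved
```
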
/mirnet/model/skff.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def selective_kernel_feature_fusion(
5 | multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3):
6 | """Selective Kernel Feature Fusion Block"""
7 | channels = list(multi_scale_feature_1.shape)[-1]
8 | combined_feature = tf.keras.layers.Add()([
9 | multi_scale_feature_1, multi_scale_feature_2, multi_scale_feature_3])
10 | gap = tf.keras.layers.GlobalAveragePooling2D()(combined_feature)
11 | channel_wise_statistics = tf.reshape(gap, shape=(-1, 1, 1, channels))
12 | compact_feature_representation = tf.keras.layers.ReLU()(
13 | tf.keras.layers.Conv2D(
14 | filters=channels // 8, kernel_size=(1, 1)
15 | )(channel_wise_statistics)
16 | )
17 | feature_descriptor_1 = tf.nn.softmax(
18 | tf.keras.layers.Conv2D(channels, kernel_size=(1, 1))(compact_feature_representation)
19 | )
20 | feature_descriptor_2 = tf.nn.softmax(
21 | tf.keras.layers.Conv2D(channels, kernel_size=(1, 1))(compact_feature_representation)
22 | )
23 | feature_descriptor_3 = tf.nn.softmax(
24 | tf.keras.layers.Conv2D(channels, kernel_size=(1, 1))(compact_feature_representation)
25 | )
26 | feature_1 = multi_scale_feature_1 * feature_descriptor_1
27 | feature_2 = multi_scale_feature_2 * feature_descriptor_2
28 | feature_3 = multi_scale_feature_3 * feature_descriptor_3
29 | aggregated_feature = tf.keras.layers.Add()([feature_1, feature_2, feature_3])
30 | return aggregated_feature
31 |
--------------------------------------------------------------------------------
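SKFF expects its three inputs to share one shape; in `multi_scale_residual_block` the other scales are first brought to the target resolution with the resizing modules. The fused output keeps that same shape:

```python
import tensorflow as tf

from mirnet.model.skff import selective_kernel_feature_fusion

# three same-shape feature streams, as produced after up/down-sampling in the MSRB
streams = [tf.random.uniform((1, 32, 32, 64)) for _ in range(3)]
fused = selective_kernel_feature_fusion(*streams)
print(fused.shape)  # (1, 32, 32, 64)
```
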
/mirnet/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 | from typing import List
4 | from .utils import psnr
5 | from .model import mirnet_model
6 | from .losses import charbonnier_loss
7 | from wandb.keras import WandbCallback
8 | from .dataloaders import LOLDataLoader
9 |
10 |
11 | class LowLightTrainer:
12 |
13 | def __init__(self):
14 | self.model = None
15 | self.crop_size = None
16 | self.train_dataset = None
17 | self.valid_dataset = None
18 | # self.strategy = tf.distribute.OneDeviceStrategy("GPU:0")
19 | # if len(tf.config.list_physical_devices('GPU')) > 1:
20 | # self.strategy = tf.distribute.MirroredStrategy()
21 |
22 | def build_dataset(
23 | self, train_low_light_images: List[str], train_high_light_images: List[str],
24 | valid_low_light_images: List[str], valid_high_light_images: List[str],
25 | crop_size: int, batch_size: int):
26 | self.crop_size = crop_size
27 | self.train_dataset = LOLDataLoader(
28 | images_lowlight=train_low_light_images,
29 | images_highlight=train_high_light_images
30 | ).build_dataset(
31 | image_crop_size=crop_size, batch_size=batch_size, is_dataset_train=True)
32 | self.valid_dataset = LOLDataLoader(
33 | images_lowlight=valid_low_light_images,
34 | images_highlight=valid_high_light_images
35 | ).build_dataset(
36 | image_crop_size=crop_size, batch_size=batch_size, is_dataset_train=False)
37 |
38 | def compile(self, num_rrg=3, num_mrb=2, channels=64, learning_rate=1e-4, use_mae_loss=True):
39 | self.model = mirnet_model(self.crop_size, num_rrg, num_mrb, channels)
40 | loss_function = tf.keras.losses.MeanAbsoluteError() if use_mae_loss else charbonnier_loss
41 | optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
42 | self.model.compile(optimizer=optimizer, loss=loss_function, metrics=[psnr])
43 |
44 | def train(self, epochs: int, checkpoint_dir: str):
45 | callbacks = [
46 | tf.keras.callbacks.EarlyStopping(
47 | monitor="val_psnr",
48 | patience=10, mode='max'
49 | ),
50 | tf.keras.callbacks.ReduceLROnPlateau(
51 | monitor='val_psnr', factor=0.5,
52 | patience=5, verbose=1, min_delta=1e-7, mode='max'
53 | ),
54 | tf.keras.callbacks.ModelCheckpoint(
55 | os.path.join(checkpoint_dir, 'low_light_weights_best.h5'),
56 | monitor="val_psnr", save_weights_only=True,
57 | mode="max", save_best_only=True, save_freq='epoch'
58 | ), WandbCallback()
59 | ]
60 | history = self.model.fit(
61 | self.train_dataset, validation_data=self.valid_dataset,
62 | epochs=epochs, callbacks=callbacks, verbose=1
63 | )
64 | return history
65 |
--------------------------------------------------------------------------------
/mirnet/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gdown
3 | import wandb
4 | import subprocess
5 | import tensorflow as tf
6 | from matplotlib import pyplot as plt
7 |
8 |
9 | def psnr(y_true, y_pred):
10 | return tf.image.psnr(y_pred, y_true, max_val=255.0)  # note: inputs are scaled to [0, 1], so values read ~48 dB (20 * log10(255)) above PSNR computed with max_val=1.0
11 |
12 |
13 | def init_wandb(project_name, experiment_name, wandb_api_key):
14 | """Initialize Wandb
15 | Args:
16 | project_name: project name on Wandb
17 | experiment_name: experiment name on Wandb
18 | wandb_api_key: Wandb API Key
19 | """
20 | if project_name is not None and experiment_name is not None:
21 | os.environ['WANDB_API_KEY'] = wandb_api_key
22 | wandb.init(project=project_name, name=experiment_name)
23 |
24 |
25 | def download_dataset(dataset_tag):
26 | """Utility for downloading and unpacking dataset dataset
27 | Args:
28 | dataset_tag: Tag for the respective dataset.
29 | Available tags -> ('LOL')
30 | """
31 | print('Downloading dataset...')
32 | if dataset_tag == 'LOL':
33 | gdown.download(
34 | 'https://drive.google.com/uc?id=157bjO1_cFuSd0HWDUuAmcHRJDVyWpOxB',
35 | 'LOLdataset.zip', quiet=False
36 | )
37 | print('Unpacking Dataset')
38 | subprocess.run(['unzip', 'LOLdataset.zip'])
39 | print('Done!!!')
40 | else:
41 | raise AssertionError('Dataset tag not found')
42 |
43 |
44 | def plot_result(image, enhanced):
45 | """Utility for Plotting inference result
46 | Args:
47 | image: original image
48 | enhanced: enhanced image
49 | """
50 | fig = plt.figure(figsize=(12, 12))
51 | fig.add_subplot(1, 2, 1).set_title('Original Image')
52 | _ = plt.imshow(image)
53 | fig.add_subplot(1, 2, 2).set_title('Enhanced Image')
54 | _ = plt.imshow(enhanced)
55 | plt.show()
56 |
57 |
58 | def closest_number(n, m):
59 | q = int(n / m)
60 | n1 = m * q
61 | if (n * m) > 0:
62 | n2 = (m * (q + 1))
63 | else:
64 | n2 = (m * (q - 1))
65 | if abs(n - n1) < abs(n - n2):
66 | return n1
67 | return n2
68 |
--------------------------------------------------------------------------------
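`closest_number(n, m)` returns the multiple of m nearest to n. `Inferer` uses it with m = 4 because the model's two downsampling levels require input sides divisible by 4. For example:

```python
from mirnet.utils import closest_number

print(closest_number(1023, 4))  # 1024: the nearest multiple of 4
print(closest_number(600, 4))   # 600: already a multiple of 4
```
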
/notebooks/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/soumik12345/MIRNet/c6812f5bc4ac87e4e63af21aa4e0db84597a17c8/notebooks/.gitkeep
--------------------------------------------------------------------------------
/notebooks/MIRNet_Low_Light_Train.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "MIRNet-Low-Light-Train.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "authorship_tag": "ABX9TyPYkiW3WGoZrgLnaTdb98k7",
10 | "include_colab_link": true
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | },
16 | "accelerator": "GPU"
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "metadata": {
22 | "id": "view-in-github",
23 | "colab_type": "text"
24 | },
25 | "source": [
26 | "
"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "metadata": {
32 | "colab": {
33 | "base_uri": "https://localhost:8080/"
34 | },
35 | "id": "u6SxYHy-0tAX",
36 | "outputId": "0b8689e6-a29e-4bd0-9f6f-f4d81cc7f836"
37 | },
38 | "source": [
39 | "!nvidia-smi"
40 | ],
41 | "execution_count": 1,
42 | "outputs": [
43 | {
44 | "output_type": "stream",
45 | "text": [
46 | "Tue Dec 1 07:16:36 2020 \n",
47 | "+-----------------------------------------------------------------------------+\n",
48 | "| NVIDIA-SMI 455.38 Driver Version: 418.67 CUDA Version: 10.1 |\n",
49 | "|-------------------------------+----------------------+----------------------+\n",
50 | "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
51 | "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
52 | "| | | MIG M. |\n",
53 | "|===============================+======================+======================|\n",
54 | "| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n",
55 | "| N/A 66C P8 11W / 70W | 0MiB / 15079MiB | 0% Default |\n",
56 | "| | | ERR! |\n",
57 | "+-------------------------------+----------------------+----------------------+\n",
58 | " \n",
59 | "+-----------------------------------------------------------------------------+\n",
60 | "| Processes: |\n",
61 | "| GPU GI CI PID Type Process name GPU Memory |\n",
62 | "| ID ID Usage |\n",
63 | "|=============================================================================|\n",
64 | "| No running processes found |\n",
65 | "+-----------------------------------------------------------------------------+\n"
66 | ],
67 | "name": "stdout"
68 | }
69 | ]
70 | },
71 | {
72 | "cell_type": "code",
73 | "metadata": {
74 | "colab": {
75 | "base_uri": "https://localhost:8080/"
76 | },
77 | "id": "H42UtLtD3QMh",
78 | "outputId": "9e3d5c8d-7f11-40c6-a0c5-a1d2ddbcc0c0"
79 | },
80 | "source": [
81 | "!git clone https://github.com/soumik12345/MIRNet\n",
82 | "%cd MIRNet"
83 | ],
84 | "execution_count": 2,
85 | "outputs": [
86 | {
87 | "output_type": "stream",
88 | "text": [
89 | "Cloning into 'MIRNet'...\n",
90 | "remote: Enumerating objects: 138, done.\u001b[K\n",
91 | "remote: Counting objects: 100% (138/138), done.\u001b[K\n",
92 | "remote: Compressing objects: 100% (104/104), done.\u001b[K\n",
93 | "remote: Total 138 (delta 52), reused 106 (delta 25), pack-reused 0\u001b[K\n",
94 | "Receiving objects: 100% (138/138), 9.84 MiB | 14.21 MiB/s, done.\n",
95 | "Resolving deltas: 100% (52/52), done.\n",
96 | "/content/MIRNet\n"
97 | ],
98 | "name": "stdout"
99 | }
100 | ]
101 | },
102 | {
103 | "cell_type": "code",
104 | "metadata": {
105 | "colab": {
106 | "base_uri": "https://localhost:8080/"
107 | },
108 | "id": "GMhYQqQ_6Qh5",
109 | "outputId": "f263304a-bb43-458d-e06f-b345cc496f60"
110 | },
111 | "source": [
112 | "!pip install -qq wandb"
113 | ],
114 | "execution_count": 3,
115 | "outputs": [
116 | {
117 | "output_type": "stream",
118 | "text": [
119 | "\u001b[K |████████████████████████████████| 1.8MB 23.2MB/s \n",
120 | "\u001b[K |████████████████████████████████| 102kB 14.7MB/s \n",
121 | "\u001b[K |████████████████████████████████| 133kB 60.6MB/s \n",
122 | "\u001b[K |████████████████████████████████| 102kB 13.9MB/s \n",
123 | "\u001b[K |████████████████████████████████| 163kB 59.2MB/s \n",
124 | "\u001b[K |████████████████████████████████| 71kB 11.6MB/s \n",
125 | "\u001b[?25h Building wheel for watchdog (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
126 | " Building wheel for subprocess32 (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
127 | " Building wheel for pathtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
128 | ],
129 | "name": "stdout"
130 | }
131 | ]
132 | },
133 | {
134 | "cell_type": "code",
135 | "metadata": {
136 | "id": "hAXrPnAx-seN"
137 | },
138 | "source": [
139 | "from glob import glob\n",
140 | "import tensorflow as tf\n",
141 | "from mirnet.train import LowLightTrainer\n",
142 | "from mirnet.utils import init_wandb, download_dataset\n",
143 | "\n",
144 | "\n",
145 | "tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)"
146 | ],
147 | "execution_count": 4,
148 | "outputs": []
149 | },
150 | {
151 | "cell_type": "code",
152 | "metadata": {
153 | "colab": {
154 | "base_uri": "https://localhost:8080/"
155 | },
156 | "id": "fT7Fr8Tq_bqn",
157 | "outputId": "96ccfa74-77e8-453d-d6c1-7227387daea2"
158 | },
159 | "source": [
160 | "download_dataset('LOL')"
161 | ],
162 | "execution_count": 5,
163 | "outputs": [
164 | {
165 | "output_type": "stream",
166 | "text": [
167 | "Downloading dataset...\n"
168 | ],
169 | "name": "stdout"
170 | },
171 | {
172 | "output_type": "stream",
173 | "text": [
174 | "Downloading...\n",
175 | "From: https://drive.google.com/uc?id=157bjO1_cFuSd0HWDUuAmcHRJDVyWpOxB\n",
176 | "To: /content/MIRNet/LOLdataset.zip\n",
177 | "347MB [00:01, 211MB/s]\n"
178 | ],
179 | "name": "stderr"
180 | },
181 | {
182 | "output_type": "stream",
183 | "text": [
184 | "Unpacking Dataset\n",
185 | "Done!!!\n"
186 | ],
187 | "name": "stdout"
188 | }
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "metadata": {
194 | "colab": {
195 | "base_uri": "https://localhost:8080/",
196 | "height": 136
197 | },
198 | "id": "0lO57o2W_sqq",
199 | "outputId": "ac006479-bf25-428f-e087-7ff523e89bba"
200 | },
201 | "source": [
202 | "init_wandb(\n",
203 | " project_name='mirnet',\n",
204 | " experiment_name='LOL_lowlight_experiment_2_256x256',\n",
205 | " wandb_api_key='cf0947ccde62903d4df0742a58b8a54ca4c11673'\n",
206 | ")"
207 | ],
208 | "execution_count": 6,
209 | "outputs": [
210 | {
211 | "output_type": "stream",
212 | "text": [
213 | "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m19soumik-rakshit96\u001b[0m (use `wandb login --relogin` to force relogin)\n"
214 | ],
215 | "name": "stderr"
216 | },
217 | {
218 | "output_type": "display_data",
219 | "data": {
220 | "text/html": [
221 | "\n",
222 | " Tracking run with wandb version 0.10.11
\n",
223 | " Syncing run LOL_lowlight_experiment_2_256x256 to Weights & Biases (Documentation).
\n",
224 | " Project page: https://wandb.ai/19soumik-rakshit96/mirnet
\n",
225 | " Run page: https://wandb.ai/19soumik-rakshit96/mirnet/runs/15w0gpyv
\n",
226 | " Run data is saved locally in /content/MIRNet/wandb/run-20201201_071736-15w0gpyv
\n",
227 | " "
228 | ],
229 | "text/plain": [
230 | ""
231 | ]
232 | },
233 | "metadata": {
234 | "tags": []
235 | }
236 | }
237 | ]
238 | },
239 | {
240 | "cell_type": "code",
241 | "metadata": {
242 | "id": "Kiom1ZjH8eKi"
243 | },
244 | "source": [
245 | "trainer = LowLightTrainer()\n",
246 | "train_low_light_images = glob('./our485/low/*')\n",
247 | "train_high_light_images = glob('./our485/high/*')\n",
248 | "valid_low_light_images = glob('./eval15/low/*')\n",
249 | "valid_high_light_images = glob('./eval15/high/*')"
250 | ],
251 | "execution_count": 7,
252 | "outputs": []
253 | },
254 | {
255 | "cell_type": "code",
256 | "metadata": {
257 | "id": "Kz0i4QoQ8qw1"
258 | },
259 | "source": [
260 | "trainer.build_dataset(\n",
261 | " train_low_light_images, train_high_light_images,\n",
262 | " valid_low_light_images, valid_high_light_images,\n",
263 | " crop_size=256, batch_size=2\n",
264 | ")\n",
265 | "trainer.compile()"
266 | ],
267 | "execution_count": 8,
268 | "outputs": []
269 | },
270 | {
271 | "cell_type": "code",
272 | "metadata": {
273 | "colab": {
274 | "base_uri": "https://localhost:8080/"
275 | },
276 | "id": "zeCmSZqI9GBU",
277 | "outputId": "e939e100-091d-4596-c542-1516c60f401f"
278 | },
279 | "source": [
280 | "trainer.train(epochs=100, checkpoint_dir='./checkpoints')"
281 | ],
282 | "execution_count": 9,
283 | "outputs": [
284 | {
285 | "output_type": "stream",
286 | "text": [
287 | "Epoch 1/100\n",
288 | "243/243 [==============================] - 442s 2s/step - loss: 0.1592 - psnr: 64.1155 - val_loss: 0.1249 - val_psnr: 65.5065\n",
289 | "Epoch 2/100\n",
290 | "243/243 [==============================] - 433s 2s/step - loss: 0.1509 - psnr: 64.5416 - val_loss: 0.1429 - val_psnr: 64.6090\n",
291 | "Epoch 3/100\n",
292 | "243/243 [==============================] - 431s 2s/step - loss: 0.1439 - psnr: 65.2337 - val_loss: 0.1303 - val_psnr: 65.2745\n",
293 | "Epoch 4/100\n",
294 | "243/243 [==============================] - 436s 2s/step - loss: 0.1285 - psnr: 65.8450 - val_loss: 0.1140 - val_psnr: 66.2238\n",
295 | "Epoch 5/100\n",
296 | "243/243 [==============================] - 438s 2s/step - loss: 0.1194 - psnr: 66.4395 - val_loss: 0.1135 - val_psnr: 66.3191\n",
297 | "Epoch 6/100\n",
298 | "243/243 [==============================] - 437s 2s/step - loss: 0.1102 - psnr: 67.0214 - val_loss: 0.1049 - val_psnr: 66.9083\n",
299 | "Epoch 7/100\n",
300 | "243/243 [==============================] - 430s 2s/step - loss: 0.1060 - psnr: 67.3232 - val_loss: 0.1132 - val_psnr: 66.4320\n",
301 | "Epoch 8/100\n",
302 | "243/243 [==============================] - 431s 2s/step - loss: 0.1069 - psnr: 67.3737 - val_loss: 0.1094 - val_psnr: 66.4070\n",
303 | "Epoch 9/100\n",
304 | "243/243 [==============================] - 431s 2s/step - loss: 0.1044 - psnr: 67.4384 - val_loss: 0.1210 - val_psnr: 65.9124\n",
305 | "Epoch 10/100\n",
306 | "243/243 [==============================] - 432s 2s/step - loss: 0.1080 - psnr: 67.3665 - val_loss: 0.1115 - val_psnr: 66.9443\n",
307 | "Epoch 11/100\n",
308 | "243/243 [==============================] - 438s 2s/step - loss: 0.1041 - psnr: 67.5550 - val_loss: 0.1022 - val_psnr: 67.0890\n",
309 | "Epoch 12/100\n",
310 | "243/243 [==============================] - 432s 2s/step - loss: 0.1043 - psnr: 67.4996 - val_loss: 0.1067 - val_psnr: 67.0900\n",
311 | "Epoch 13/100\n",
312 | "243/243 [==============================] - 438s 2s/step - loss: 0.1059 - psnr: 67.4166 - val_loss: 0.0977 - val_psnr: 67.6418\n",
313 | "Epoch 14/100\n",
314 | "243/243 [==============================] - 438s 2s/step - loss: 0.0989 - psnr: 67.9418 - val_loss: 0.0963 - val_psnr: 68.1703\n",
315 | "Epoch 15/100\n",
316 | "243/243 [==============================] - 438s 2s/step - loss: 0.1015 - psnr: 67.7641 - val_loss: 0.0939 - val_psnr: 68.4126\n",
317 | "Epoch 16/100\n",
318 | "243/243 [==============================] - 431s 2s/step - loss: 0.0982 - psnr: 67.9785 - val_loss: 0.0951 - val_psnr: 68.0276\n",
319 | "Epoch 17/100\n",
320 | "243/243 [==============================] - 431s 2s/step - loss: 0.0981 - psnr: 68.0821 - val_loss: 0.1169 - val_psnr: 66.7316\n",
321 | "Epoch 18/100\n",
322 | "243/243 [==============================] - 430s 2s/step - loss: 0.1004 - psnr: 67.8817 - val_loss: 0.1016 - val_psnr: 67.3338\n",
323 | "Epoch 19/100\n",
324 | "243/243 [==============================] - 430s 2s/step - loss: 0.0959 - psnr: 68.2082 - val_loss: 0.0995 - val_psnr: 67.4730\n",
325 | "Epoch 20/100\n",
326 | "243/243 [==============================] - ETA: 0s - loss: 0.1026 - psnr: 67.7257\n",
327 | "Epoch 00020: ReduceLROnPlateau reducing learning rate to 4.999999873689376e-05.\n",
328 | "243/243 [==============================] - 430s 2s/step - loss: 0.1026 - psnr: 67.7257 - val_loss: 0.1003 - val_psnr: 67.3352\n",
329 | "Epoch 21/100\n",
330 | "243/243 [==============================] - 436s 2s/step - loss: 0.0919 - psnr: 68.6314 - val_loss: 0.0826 - val_psnr: 69.3167\n",
331 | "Epoch 22/100\n",
332 | "243/243 [==============================] - 430s 2s/step - loss: 0.0936 - psnr: 68.3714 - val_loss: 0.1035 - val_psnr: 67.5037\n",
333 | "Epoch 23/100\n",
334 | "243/243 [==============================] - 432s 2s/step - loss: 0.0896 - psnr: 68.7174 - val_loss: 0.0862 - val_psnr: 69.1280\n",
335 | "Epoch 24/100\n",
336 | "243/243 [==============================] - 435s 2s/step - loss: 0.0917 - psnr: 68.6044 - val_loss: 0.0796 - val_psnr: 69.3943\n",
337 | "Epoch 25/100\n",
338 | "243/243 [==============================] - 431s 2s/step - loss: 0.0895 - psnr: 68.7495 - val_loss: 0.0898 - val_psnr: 68.8661\n",
339 | "Epoch 26/100\n",
340 | "243/243 [==============================] - 430s 2s/step - loss: 0.0890 - psnr: 68.8537 - val_loss: 0.0983 - val_psnr: 68.0671\n",
341 | "Epoch 27/100\n",
342 | "243/243 [==============================] - 431s 2s/step - loss: 0.0920 - psnr: 68.5267 - val_loss: 0.0975 - val_psnr: 67.9822\n",
343 | "Epoch 28/100\n",
344 | "243/243 [==============================] - 431s 2s/step - loss: 0.0897 - psnr: 68.7615 - val_loss: 0.1008 - val_psnr: 67.6455\n",
345 | "Epoch 29/100\n",
346 | "243/243 [==============================] - ETA: 0s - loss: 0.0864 - psnr: 69.0921\n",
347 | "Epoch 00029: ReduceLROnPlateau reducing learning rate to 2.499999936844688e-05.\n",
348 | "243/243 [==============================] - 430s 2s/step - loss: 0.0864 - psnr: 69.0921 - val_loss: 0.0940 - val_psnr: 68.5634\n",
349 | "Epoch 30/100\n",
350 | "243/243 [==============================] - 431s 2s/step - loss: 0.0868 - psnr: 69.0637 - val_loss: 0.0917 - val_psnr: 68.4715\n",
351 | "Epoch 31/100\n",
352 | "243/243 [==============================] - 436s 2s/step - loss: 0.0845 - psnr: 69.2952 - val_loss: 0.0791 - val_psnr: 69.2714\n",
353 | "Epoch 32/100\n",
354 | "243/243 [==============================] - 431s 2s/step - loss: 0.0867 - psnr: 69.1577 - val_loss: 0.0809 - val_psnr: 69.5186\n",
355 | "Epoch 33/100\n",
356 | "243/243 [==============================] - 430s 2s/step - loss: 0.0860 - psnr: 69.3142 - val_loss: 0.0868 - val_psnr: 68.9742\n",
357 | "Epoch 34/100\n",
358 | "243/243 [==============================] - 431s 2s/step - loss: 0.0858 - psnr: 69.1895 - val_loss: 0.0890 - val_psnr: 68.3421\n",
359 | "Epoch 35/100\n",
360 | "243/243 [==============================] - 430s 2s/step - loss: 0.0860 - psnr: 69.1920 - val_loss: 0.0791 - val_psnr: 69.2999\n",
361 | "Epoch 36/100\n",
362 | "243/243 [==============================] - 430s 2s/step - loss: 0.0867 - psnr: 69.2137 - val_loss: 0.0832 - val_psnr: 68.8973\n",
363 | "Epoch 37/100\n",
364 | "243/243 [==============================] - 436s 2s/step - loss: 0.0853 - psnr: 69.2414 - val_loss: 0.0716 - val_psnr: 70.2069\n",
365 | "Epoch 38/100\n",
366 | "243/243 [==============================] - 431s 2s/step - loss: 0.0858 - psnr: 69.2487 - val_loss: 0.0895 - val_psnr: 68.9241\n",
367 | "Epoch 39/100\n",
368 | "243/243 [==============================] - 431s 2s/step - loss: 0.0856 - psnr: 69.2450 - val_loss: 0.0896 - val_psnr: 68.6078\n",
369 | "Epoch 40/100\n",
370 | "243/243 [==============================] - 430s 2s/step - loss: 0.0840 - psnr: 69.2957 - val_loss: 0.0752 - val_psnr: 69.8205\n",
371 | "Epoch 41/100\n",
372 | "243/243 [==============================] - 430s 2s/step - loss: 0.0843 - psnr: 69.3408 - val_loss: 0.0924 - val_psnr: 68.4360\n",
373 | "Epoch 42/100\n",
374 | "243/243 [==============================] - ETA: 0s - loss: 0.0834 - psnr: 69.4409\n",
375 | "Epoch 00042: ReduceLROnPlateau reducing learning rate to 1.249999968422344e-05.\n",
376 | "243/243 [==============================] - 430s 2s/step - loss: 0.0834 - psnr: 69.4409 - val_loss: 0.0966 - val_psnr: 67.8450\n",
377 | "Epoch 43/100\n",
378 | "243/243 [==============================] - 431s 2s/step - loss: 0.0803 - psnr: 69.8295 - val_loss: 0.0844 - val_psnr: 69.0862\n",
379 | "Epoch 44/100\n",
380 | "243/243 [==============================] - 431s 2s/step - loss: 0.0801 - psnr: 69.7533 - val_loss: 0.0876 - val_psnr: 68.7879\n",
381 | "Epoch 45/100\n",
382 | "243/243 [==============================] - 431s 2s/step - loss: 0.0826 - psnr: 69.6043 - val_loss: 0.0934 - val_psnr: 68.2593\n",
383 | "Epoch 46/100\n",
384 | "243/243 [==============================] - 430s 2s/step - loss: 0.0789 - psnr: 69.9293 - val_loss: 0.0947 - val_psnr: 68.2925\n",
385 | "Epoch 47/100\n",
386 | "243/243 [==============================] - ETA: 0s - loss: 0.0803 - psnr: 69.7273\n",
387 | "Epoch 00047: ReduceLROnPlateau reducing learning rate to 6.24999984211172e-06.\n",
388 | "243/243 [==============================] - 429s 2s/step - loss: 0.0803 - psnr: 69.7273 - val_loss: 0.0935 - val_psnr: 68.1041\n"
389 | ],
390 | "name": "stdout"
391 | },
392 | {
393 | "output_type": "execute_result",
394 | "data": {
395 | "text/plain": [
396 | ""
397 | ]
398 | },
399 | "metadata": {
400 | "tags": []
401 | },
402 | "execution_count": 9
403 | }
404 | ]
405 | },
406 | {
407 | "cell_type": "code",
408 | "metadata": {
409 | "colab": {
410 | "base_uri": "https://localhost:8080/",
411 | "height": 17
412 | },
413 | "id": "dGkCwyM2_odJ",
414 | "outputId": "451d95d4-8113-4322-fbbb-6f4b40c0898a"
415 | },
416 | "source": [
417 | "from glob import glob\n",
418 | "from google.colab import files\n",
419 | "\n",
420 | "\n",
421 | "for file in glob('/content/MIRNet/checkpoints/*'):\n",
422 | " files.download(file)"
423 | ],
424 | "execution_count": 14,
425 | "outputs": [
426 | {
427 | "output_type": "display_data",
428 | "data": {
429 | "application/javascript": [
430 | "\n",
431 | " async function download(id, filename, size) {\n",
432 | " if (!google.colab.kernel.accessAllowed) {\n",
433 | " return;\n",
434 | " }\n",
435 | " const div = document.createElement('div');\n",
436 | " const label = document.createElement('label');\n",
437 | " label.textContent = `Downloading \"${filename}\": `;\n",
438 | " div.appendChild(label);\n",
439 | " const progress = document.createElement('progress');\n",
440 | " progress.max = size;\n",
441 | " div.appendChild(progress);\n",
442 | " document.body.appendChild(div);\n",
443 | "\n",
444 | " const buffers = [];\n",
445 | " let downloaded = 0;\n",
446 | "\n",
447 | " const channel = await google.colab.kernel.comms.open(id);\n",
448 | " // Send a message to notify the kernel that we're ready.\n",
449 | " channel.send({})\n",
450 | "\n",
451 | " for await (const message of channel.messages) {\n",
452 | " // Send a message to notify the kernel that we're ready.\n",
453 | " channel.send({})\n",
454 | " if (message.buffers) {\n",
455 | " for (const buffer of message.buffers) {\n",
456 | " buffers.push(buffer);\n",
457 | " downloaded += buffer.byteLength;\n",
458 | " progress.value = downloaded;\n",
459 | " }\n",
460 | " }\n",
461 | " }\n",
462 | " const blob = new Blob(buffers, {type: 'application/binary'});\n",
463 | " const a = document.createElement('a');\n",
464 | " a.href = window.URL.createObjectURL(blob);\n",
465 | " a.download = filename;\n",
466 | " div.appendChild(a);\n",
467 | " a.click();\n",
468 | " div.remove();\n",
469 | " }\n",
470 | " "
471 | ],
472 | "text/plain": [
473 | ""
474 | ]
475 | },
476 | "metadata": {
477 | "tags": []
478 | }
479 | },
480 | {
481 | "output_type": "display_data",
482 | "data": {
483 | "application/javascript": [
484 | "download(\"download_6ea7b0ed-1fe2-4f54-a838-3f3f7d13aa3c\", \"low_light_weights_best_256x256.h5\", 147825648)"
485 | ],
486 | "text/plain": [
487 | ""
488 | ]
489 | },
490 | "metadata": {
491 | "tags": []
492 | }
493 | }
494 | ]
495 | },
496 | {
497 | "cell_type": "code",
498 | "metadata": {
499 | "id": "IrgoJL0xZTuv"
500 | },
501 | "source": [
502 | ""
503 | ],
504 | "execution_count": null,
505 | "outputs": []
506 | }
507 | ]
508 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | gdown==3.12.2
2 | matplotlib==3.3.3
3 | streamlit==0.74.1
4 | tensorflow==2.4.0
5 | wandb==0.10.14
6 | Pillow~=8.1.0
7 | numpy~=1.19.5
8 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | streamlit run app.py
2 |
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | mkdir -p ~/.streamlit/
2 |
3 | echo "\
4 | [server]\n\
5 | port = $PORT\n\
6 | enableCORS = false\n\
7 | headless = true\n\
8 | \n\
9 | " > ~/.streamlit/config.toml
10 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | import tensorflow as tf
3 | from matplotlib import pyplot as plt
4 | from mirnet.model import mirnet_model
5 | from mirnet.dataloaders import LOLDataLoader
6 |
7 |
8 | def test_output_dim():
9 | mirnet = mirnet_model(256, 3, 2, 64)
10 | x = tf.ones((1, 256, 256, 3))
11 | y = mirnet(x)
12 | assert x.shape == y.shape
13 |
14 |
15 | def test_dataloader():
16 | lowlight_images = glob('./data/LOLdataset/our485/low/*')
17 | highlight_images = glob('./data/LOLdataset/our485/high/*')
18 | dataset = LOLDataLoader(
19 | images_lowlight=lowlight_images,
20 | images_highlight=highlight_images
21 | ).build_dataset(
22 | image_crop_size=128, batch_size=1, is_dataset_train=True
23 | )
24 | print(dataset)
25 | x, y = next(iter(dataset))
26 | print(x.shape, y.shape)
27 | plt.imshow(tf.cast(x[0] * 255, dtype=tf.uint8))
28 | plt.title('Low Light Patch (128 x 128)')
29 | plt.show()
30 | plt.title('High Light Patch (128 x 128)')
31 | plt.imshow(tf.cast(y[0] * 255, dtype=tf.uint8))
32 | plt.show()
33 |
34 |
35 | if __name__ == '__main__':
36 | # test_dataloader()
37 | test_output_dim()
38 |
--------------------------------------------------------------------------------