├── Images
│   ├── Decoder.JPG
│   ├── Encoder.JPG
│   └── Discriminator.JPG
├── main.py
├── LICENSE
├── png2npz.py
├── README.md
└── Low_Light_Image_Enhancement_using_GAN.ipynb
/Images/Decoder.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarylFernandes99/Low-light-Image-Enhancement-using-GAN/HEAD/Images/Decoder.JPG
--------------------------------------------------------------------------------
/Images/Encoder.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarylFernandes99/Low-light-Image-Enhancement-using-GAN/HEAD/Images/Encoder.JPG
--------------------------------------------------------------------------------
/Images/Discriminator.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DarylFernandes99/Low-light-Image-Enhancement-using-GAN/HEAD/Images/Discriminator.JPG
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | from tensorflow.keras.models import load_model
4 | from tensorflow.keras.preprocessing.image import img_to_array, save_img, array_to_img
5 | 
6 | model_path = ""   # path to the trained generator model (.h5)
7 | image_path = ""   # path to the low-light input image
8 | 
9 | # Processing Image: convert BGR to RGB, normalize to [-1, 1], resize to the 256x256 model input
10 | img = cv2.imread(image_path)
11 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
12 | img_arr = (img_to_array(img) - 127.5) / 127.5
13 | resized = cv2.resize(img_arr, (256, 256), interpolation=cv2.INTER_AREA)
14 | ready_img = np.expand_dims(resized, axis=0)
15 | 
16 | # Loading Model
17 | model = load_model(model_path)
18 | 
19 | # Predicting Image: run the generator, apply median blur, rescale output from [-1, 1] to [0, 1]
20 | pred = model.predict(ready_img)
21 | pred = (cv2.medianBlur(pred[0], 1) + 1) / 2
22 | pred = array_to_img(pred)
23 | save_img("./output.png", pred)
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Daryl Fernandes
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /png2npz.py: -------------------------------------------------------------------------------- 1 | ################################ Imports #################################### 2 | from os import listdir 3 | from keras.preprocessing.image import img_to_array,load_img 4 | from numpy import asarray, savez_compressed 5 | 6 | 7 | ################################ Load Dataset #################################### 8 | 9 | def load_images(path, size=(256,256)): 10 | data_list = list() 11 | # enumerate filenames in directory, assume all are images 12 | for filename in listdir(path): 13 | # load and resize the image 14 | pixels = load_img(path + filename, target_size=size) 15 | # convert to numpy array 16 | pixels = img_to_array(pixels) 17 | # store 18 | data_list.append(pixels) 19 | 20 | return asarray(data_list) 21 | 22 | 23 | ################################ Main Function #################################### 24 | def main(): 25 | 26 | # load images 27 | path = "" 28 | high = load_images(path + 'ground_truth/') #directory "path/ground_truth" should be present 29 | low = load_images(path + 'low/') #directory "path/low" should be present 30 | 31 | savez_compressed('dataset.npz',high,low) 32 | 33 | if __name__ == '__main__': 34 | main() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Low-Light Image Enhancement using GAN 2 | 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 4 | [![Python](https://img.shields.io/badge/Python-3.7+-blue.svg)](https://www.python.org/downloads/) 5 | [![TensorFlow](https://img.shields.io/badge/TensorFlow-2.4.1+-orange.svg)](https://tensorflow.org/) 6 | 7 | A state-of-the-art deep learning solution that transforms images captured in poor lighting conditions into well-illuminated, high-quality versions using Generative Adversarial Networks (GANs). This project implements a sophisticated encoder-decoder GAN architecture to tackle the challenging problem of low-light image enhancement. 8 | 9 | ## 🌟 Overview 10 | 11 | Low-light image enhancement is a critical challenge in computer vision, affecting photography, surveillance, and autonomous systems. This project addresses this problem using advanced deep learning techniques, specifically a custom GAN architecture that learns to map low-light images to their well-illuminated counterparts while preserving important details and improving overall image quality. 12 | 13 | ## ✨ Key Features 14 | 15 | - **Advanced GAN Architecture**: Custom encoder-decoder generator with skip connections 16 | - **Multi-scale Processing**: Handles various lighting conditions and image types 17 | - **Detail Preservation**: Maintains fine-grained features while enhancing illumination 18 | - **Robust Training**: Implements adversarial training with multiple loss functions 19 | - **Multiple Format Support**: Compatible with JPG, PNG, and other standard image formats 20 | - **Real-time Inference**: Optimized for efficient image enhancement 21 | - **Comprehensive Dataset Support**: Trained on multiple benchmark datasets 22 | 23 | ## 🏗️ Architecture 24 | 25 | ### System Overview 26 | 27 | The system consists of three main components working in an adversarial framework: 28 | 29 | #### 1. 
Generator (Encoder-Decoder Architecture) 30 | 31 | **Encoder Network** 32 | ![Encoder Architecture](./Images/Encoder.JPG) 33 | 34 | - Extracts multi-scale features from low-light input images 35 | - Uses convolutional layers with increasing channel depth 36 | - Implements skip connections for feature preservation 37 | - Batch normalization and LeakyReLU activations 38 | 39 | **Decoder Network** 40 | ![Decoder Architecture](./Images/Decoder.JPG) 41 | 42 | - Reconstructs enhanced images from encoded features 43 | - Utilizes upsampling and concatenation operations 44 | - Integrates skip connections from encoder 45 | - Applies dropout for regularization 46 | 47 | #### 2. Discriminator Network 48 | 49 | ![Discriminator Architecture](./Images/Discriminator.JPG) 50 | 51 | - PatchGAN discriminator for local texture analysis 52 | - Distinguishes between real and generated enhanced images 53 | - Processes concatenated source and target images 54 | - Progressive downsampling with increasing channel depth 55 | 56 | ### Technical Specifications 57 | 58 | **Generator Architecture:** 59 | ``` 60 | Input: 256×256×3 RGB Image 61 | ├── Encoder Branch: 62 | │ ├── Conv2D(64, 7×7) + BatchNorm + LeakyReLU 63 | │ ├── Conv2D(128, 3×3, stride=2) + BatchNorm + LeakyReLU 64 | │ └── Conv2D(256, 3×3, stride=2) + BatchNorm + LeakyReLU 65 | ├── Residual Blocks: 6× Conv2D(256, 3×3) + BatchNorm + LeakyReLU 66 | └── Decoder Branch: 67 | ├── UpSampling2D(2×2) + Conv2D(128, 1×1) + Dropout(0.5) 68 | ├── UpSampling2D(2×2) + Conv2D(64, 1×1) + Dropout(0.5) 69 | └── Conv2D(3, 7×7) + BatchNorm + Tanh 70 | Output: 256×256×3 Enhanced Image 71 | ``` 72 | 73 | **Discriminator Architecture:** 74 | ``` 75 | Inputs: Source Image (256×256×3) + Target Image (256×256×3) 76 | ├── Concatenate → 256×256×6 77 | ├── Conv2D(64, 4×4, stride=2) + LeakyReLU 78 | ├── Conv2D(128, 4×4, stride=2) + BatchNorm + LeakyReLU 79 | ├── Conv2D(256, 4×4, stride=2) + BatchNorm + LeakyReLU 80 | ├── Conv2D(512, 4×4, stride=2) + BatchNorm + LeakyReLU 81 | ├── Conv2D(512, 4×4) + BatchNorm + LeakyReLU 82 | └── Conv2D(1, 4×4) + Sigmoid 83 | Output: Patch Classification Map 84 | ``` 85 | 86 | ## 📊 Datasets 87 | 88 | The model is trained on multiple high-quality datasets to ensure robust performance across various lighting conditions: 89 | 90 | ### Primary Datasets 91 | 92 | 1. **LOL Dataset** - Low Light Paired Dataset 93 | - [Download Link](https://drive.google.com/file/d/157bjO1_cFuSd0HWDUuAmcHRJDVyWpOxB/view) 94 | - Paired low-light and normal-light images 95 | - Real-world indoor and outdoor scenes 96 | 97 | 2. **SID Dataset** - See-in-the-Dark Dataset 98 | - **Sony**: [Download](https://storage.googleapis.com/isl-datasets/SID/Sony.zip) 99 | - **Fuji**: [Download](https://storage.googleapis.com/isl-datasets/SID/Fuji.zip) 100 | - RAW sensor data for extreme low-light conditions 101 | 102 | 3. **SICE Dataset** - Single Image Contrast Enhancement 103 | - [Part 1](https://drive.google.com/file/d/1HiLtYiyT9R7dR9DRTLRlUUrAicC4zzWN/view) 104 | - [Part 2](https://drive.google.com/file/d/16VoHNPAZ5Js19zspjFOsKiGRrfkDgHoN/view) 105 | - Multi-exposure sequences for training 106 | 107 | ### Synthetic Dataset 108 | 109 | 4. 
**Custom Synthetic Dataset**
110 | - [Kaggle Link](https://www.kaggle.com/basu369victor/low-light-image-enhancement-with-cnn)
111 | - [Synthetic Pairs](https://drive.google.com/file/d/1G6fi9Kiu7CDnW2Sh7UQ5ikvScRv8Q14F/view)
112 | - Generated from high-quality images with simulated low-light conditions
113 | 
114 | ## 🛠️ Installation
115 | 
116 | ### Prerequisites
117 | 
118 | ```bash
119 | Python 3.7+
120 | CUDA-compatible GPU (recommended)
121 | 16GB+ RAM (for training)
122 | ```
123 | 
124 | ### Dependencies
125 | 
126 | ```bash
127 | pip install tensorflow==2.4.1
128 | pip install keras
129 | pip install numpy
130 | pip install opencv-python
131 | pip install pillow
132 | pip install matplotlib
133 | ```
134 | 
135 | ### Environment Setup
136 | 
137 | **Recommended Development Environment:**
138 | - **GPU**: NVIDIA Tesla T4/P100 16GB or equivalent
139 | - **RAM**: 12GB+ system memory
140 | - **Storage**: 50GB+ for datasets and models
141 | - **OS**: Ubuntu 18.04+ or Windows 10+
142 | 
143 | ## 🚀 Quick Start
144 | 
145 | ### 1. Data Preparation
146 | 
147 | Organize your dataset structure:
148 | ```
149 | dataset/
150 | ├── ground_truth/   # Well-lit reference images
151 | └── low/            # Corresponding low-light images
152 | ```
153 | 
154 | Convert images to NPZ format using the preprocessing script:
155 | 
156 | ```bash
157 | # 1. Edit png2npz.py and set the dataset path (keep the trailing slash), e.g.
158 | #    path = "path/to/your/dataset/"
159 | 
160 | # 2. Run preprocessing
161 | python png2npz.py
162 | ```
163 | 
164 | **png2npz.py Configuration:**
165 | ```python
166 | # png2npz.py has no named configuration constants; adjust these values directly in the script:
167 | # load_images(path, size=(256, 256))          # target size every image is resized to
168 | # savez_compressed('dataset.npz', high, low)  # name of the compressed output file
169 | ```
170 | 
171 | ### 2. Training the Model
172 | 
173 | #### Option A: Google Colab (Recommended)
174 | 
175 | 1. Upload the notebook `Low_Light_Image_Enhancement_using_GAN.ipynb` to Google Colab
176 | 2. Mount Google Drive and upload your `dataset.npz`
177 | 3. Update the dataset path in the notebook
178 | 4. Execute all cells sequentially
179 | 
180 | #### Option B: Local Training
181 | 
182 | ```python
183 | # Configure training parameters (the functions below are defined in the training notebook)
184 | BATCH_SIZE = 12   # Adjust based on GPU memory
185 | EPOCHS = 100      # The Adam learning rate (0.0002) is set inside the model-building functions
186 | 
187 | # Load dataset and derive the model input shape
188 | dataset = load_real_samples('dataset.npz')
189 | image_shape = dataset[0].shape[1:]
190 | 
191 | # Initialize models
192 | d_model = define_discriminator(image_shape)
193 | g_model = define_generator(image_shape)
194 | gan_model = define_gan(g_model, d_model, image_shape)
195 | 
196 | # Start training
197 | train(d_model, g_model, gan_model, dataset, n_epochs=EPOCHS, n_batch=BATCH_SIZE)
198 | ```
199 | 
200 | ### 3. Image Enhancement
201 | 
202 | Use the trained model for inference:
203 | 
204 | ```bash
205 | # 1. Edit main.py and set your paths:
206 | #    model_path = "path/to/saved/generator_model.h5"
207 | #    image_path = "path/to/low_light_image.jpg"
208 | 
209 | # 2. Run enhancement
210 | python main.py
211 | ```
212 | 
213 | **main.py Workflow:**
214 | 1. Load and preprocess the input image (convert to RGB, normalize to [-1,1], resize to 256×256)
215 | 2. Load the trained generator model
216 | 3. Generate enhanced image
217 | 4. Apply median blur for artifact reduction
218 | 5. 
Save result as `output.png` 219 | 220 | ## 📈 Training Details 221 | 222 | ### Training Strategy 223 | 224 | - **Adversarial Training**: Alternating discriminator and generator updates 225 | - **L1 Regularization**: Preserves structural details (λ=100) 226 | - **Adam Optimizer**: Learning rate = 0.0002, β₁ = 0.5 227 | - **Batch Size**: 12 (adjustable based on GPU memory) 228 | - **Progressive Saving**: Models saved every epoch for monitoring 229 | 230 | ### Training Monitoring 231 | 232 | The training process generates: 233 | - **Loss Plots**: Real-time loss tracking for both networks 234 | - **Sample Images**: Visual comparison of low-light, generated, and ground truth 235 | - **Model Checkpoints**: Saved every epoch for evaluation 236 | 237 | ## 🔬 Technical Implementation 238 | 239 | ### Key Components 240 | 241 | **1. Skip Connections** 242 | - Preserve fine-grained details during upsampling 243 | - Connect encoder features directly to decoder 244 | - Prevent information loss in deep network 245 | 246 | **2. Batch Normalization** 247 | - Stabilizes training process 248 | - Enables higher learning rates 249 | - Applied throughout both networks 250 | 251 | **3. Residual Blocks** 252 | - Six residual blocks in generator bottleneck 253 | - Helps with gradient flow 254 | - Improves feature representation 255 | 256 | **4. PatchGAN Discriminator** 257 | - Focuses on local texture quality 258 | - More efficient than full-image discrimination 259 | - Better for preserving high-frequency details 260 | 261 | ### Preprocessing Pipeline 262 | 263 | ```python 264 | def preprocess_image(image_path): 265 | # Load image 266 | img = cv2.imread(image_path) 267 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 268 | 269 | # Normalize to [-1, 1] 270 | img_arr = (img_to_array(img) - 127.5) / 127.5 271 | 272 | # Resize to model input size 273 | resized = cv2.resize(img_arr, (256, 256), interpolation=cv2.INTER_AREA) 274 | 275 | # Add batch dimension 276 | return np.expand_dims(resized, axis=0) 277 | ``` 278 | 279 | ### Post-processing 280 | 281 | ```python 282 | def postprocess_output(prediction): 283 | # Remove batch dimension 284 | pred = prediction[0] 285 | 286 | # Apply median blur to reduce artifacts 287 | pred = cv2.medianBlur(pred, 1) 288 | 289 | # Denormalize from [-1, 1] to [0, 1] 290 | pred = (pred + 1) / 2 291 | 292 | return pred 293 | ``` 294 | 295 | ## 📋 File Structure 296 | 297 | ``` 298 | Low-light-Image-Enhancement-using-GAN/ 299 | ├── Images/ # Architecture diagrams 300 | │ ├── Decoder.JPG # Decoder network visualization 301 | │ ├── Discriminator.JPG # Discriminator architecture 302 | │ └── Encoder.JPG # Encoder network structure 303 | ├── Low_Light_Image_Enhancement_using_GAN.ipynb # Main training notebook 304 | ├── main.py # Inference script 305 | ├── png2npz.py # Dataset preprocessing utility 306 | ├── README.md # Project documentation 307 | ├── LICENSE # MIT license 308 | └── .git/ # Git repository data 309 | ``` 310 | 311 | ## 🎯 Results and Performance 312 | 313 | ### Quantitative Metrics 314 | 315 | The model achieves significant improvements in: 316 | - **PSNR (Peak Signal-to-Noise Ratio)**: Enhanced image quality 317 | - **SSIM (Structural Similarity Index)**: Preserved structural content 318 | - **LPIPS (Learned Perceptual Image Patch Similarity)**: Better perceptual quality 319 | 320 | ### Qualitative Improvements 321 | 322 | - **Brightness Enhancement**: Significant illumination improvement 323 | - **Contrast Restoration**: Better dynamic range 324 | - **Detail Preservation**: Fine 
features maintained 325 | - **Color Accuracy**: Natural color reproduction 326 | - **Noise Reduction**: Cleaner output images 327 | - **Artifact Minimization**: Reduced over-enhancement effects 328 | 329 | ### Use Cases 330 | 331 | - **Photography**: Post-processing of low-light photos 332 | - **Surveillance**: Enhanced security camera footage 333 | - **Medical Imaging**: Improved visibility in medical scans 334 | - **Autonomous Vehicles**: Better night vision capabilities 335 | - **Mobile Photography**: Real-time enhancement on smartphones 336 | 337 | ## 📚 Research and Publication 338 | 339 | This project has been published in a peer-reviewed research paper: 340 | 341 | **Publication Details:** 342 | - **Title**: Low-Light Image Enhancement using Generative Adversarial Networks 343 | - **Authors**: Daryl Fernandes, et al. 344 | - **Journal**: International Research Journal of Engineering and Technology (IRJET) 345 | - **Volume**: 8, Issue 6 346 | - **Year**: 2021 347 | - **Pages**: 136-142 348 | - **ISSN**: 2395-0072 349 | - **Link**: [IRJET Publication](https://www.irjet.net/archives/V8/i6/IRJET-V8I6136.pdf) 350 | 351 | ## 📖 Citation 352 | 353 | If you use this work in your research, please cite our paper: 354 | 355 | ### BibTeX 356 | ```bibtex 357 | @article{fernandes2021lowlight, 358 | title={Low-Light Image Enhancement using Generative Adversarial Networks}, 359 | author={Fernandes, Daryl and others}, 360 | journal={International Research Journal of Engineering and Technology (IRJET)}, 361 | volume={8}, 362 | number={6}, 363 | pages={136--142}, 364 | year={2021}, 365 | publisher={IRJET}, 366 | issn={2395-0072}, 367 | url={https://www.irjet.net/archives/V8/i6/IRJET-V8I6136.pdf} 368 | } 369 | ``` 370 | 371 | ### APA Style 372 | ``` 373 | Fernandes, D., et al. (2021). Low-Light Image Enhancement using Generative Adversarial Networks. 374 | International Research Journal of Engineering and Technology (IRJET), 8(6), 136-142. 375 | Retrieved from https://www.irjet.net/archives/V8/i6/IRJET-V8I6136.pdf 376 | ``` 377 | 378 | ### IEEE Style 379 | ``` 380 | D. Fernandes et al., "Low-Light Image Enhancement using Generative Adversarial Networks," 381 | International Research Journal of Engineering and Technology (IRJET), vol. 8, no. 6, 382 | pp. 136-142, 2021. [Online]. Available: https://www.irjet.net/archives/V8/i6/IRJET-V8I6136.pdf 383 | ``` 384 | 385 | ## 🤝 Contributing 386 | 387 | We welcome contributions to improve this project! Please follow these guidelines: 388 | 389 | 1. Fork the repository 390 | 2. Create a feature branch (`git checkout -b feature/amazing-feature`) 391 | 3. Commit your changes (`git commit -m 'Add amazing feature'`) 392 | 4. Push to the branch (`git push origin feature/amazing-feature`) 393 | 5. Open a Pull Request 394 | 395 | ### Development Guidelines 396 | 397 | - Follow PEP 8 style guidelines 398 | - Add comprehensive docstrings 399 | - Include unit tests for new features 400 | - Update documentation as needed 401 | 402 | ## 📄 License 403 | 404 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. 405 | 406 | ## 🔧 Troubleshooting 407 | 408 | ### Common Issues 409 | 410 | **1. GPU Memory Issues** 411 | ```bash 412 | # Reduce batch size in training 413 | BATCH_SIZE = 6 # Instead of 12 414 | 415 | # Enable GPU memory growth 416 | physical_devices = tf.config.list_physical_devices('GPU') 417 | tf.config.experimental.set_memory_growth(physical_devices[0], True) 418 | ``` 419 | 420 | **2. 
Dataset Loading Errors** 421 | ```bash 422 | # Ensure correct path format 423 | path = "dataset/" # Include trailing slash 424 | # Check image formats are consistent 425 | # Verify directory structure matches requirements 426 | ``` 427 | 428 | **3. Model Convergence Issues** 429 | ```bash 430 | # Adjust learning rate 431 | LEARNING_RATE = 0.0001 # Reduce if training unstable 432 | 433 | # Monitor discriminator/generator balance 434 | # Adjust loss weights if needed 435 | ``` 436 | 437 | ### Performance Optimization 438 | 439 | **For Training:** 440 | - Use mixed precision training for faster convergence 441 | - Implement data augmentation for better generalization 442 | - Use gradient accumulation for larger effective batch sizes 443 | 444 | **For Inference:** 445 | - Convert model to TensorFlow Lite for mobile deployment 446 | - Use TensorRT for optimized GPU inference 447 | - Implement batch processing for multiple images 448 | 449 | ## 📞 Contact and Support 450 | 451 | For questions, issues, or collaborations: 452 | 453 | - **GitHub Issues**: [Create an issue](https://github.com/your-username/Low-light-Image-Enhancement-using-GAN/issues) 454 | - **Research Paper**: [IRJET Publication](https://www.irjet.net/archives/V8/i6/IRJET-V8I6136.pdf) 455 | 456 | ## 🙏 Acknowledgments 457 | 458 | - **Datasets**: Thanks to the creators of LOL, SID, and SICE datasets 459 | - **Research Community**: Built upon advances in GAN research 460 | - **TensorFlow Team**: For the excellent deep learning framework 461 | - **Google Colab**: For providing accessible GPU resources 462 | 463 | ## 🔮 Future Work 464 | 465 | - **Real-time Processing**: Optimize for video enhancement 466 | - **Mobile Deployment**: Create mobile app versions 467 | - **Multi-scale Training**: Handle various image resolutions 468 | - **Unsupervised Learning**: Reduce dependency on paired data 469 | - **Advanced Architectures**: Explore attention mechanisms and transformers 470 | 471 | --- 472 | 473 | *This project demonstrates the power of GANs in solving challenging computer vision problems. 
We hope it serves as a valuable resource for researchers and practitioners working on image enhancement tasks.* 474 | -------------------------------------------------------------------------------- /Low_Light_Image_Enhancement_using_GAN.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Low-Light Image Enhancement using GAN.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "language_info": { 15 | "name": "python" 16 | } 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "markdown", 21 | "metadata": { 22 | "id": "dgn0Oca-prAe" 23 | }, 24 | "source": [ 25 | "# Fetching Dataset from Google Drive" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "metadata": { 31 | "id": "_AlJyDJupuQX" 32 | }, 33 | "source": [ 34 | "from google.colab import drive\n", 35 | "drive.mount('/content/drive')\n", 36 | "\n", 37 | "import shutil\n", 38 | "shutil.copy(\"/content/drive/\", \"\")" 39 | ], 40 | "execution_count": null, 41 | "outputs": [] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": { 46 | "id": "RyCiAtbNoHxQ" 47 | }, 48 | "source": [ 49 | "#Importing Libraries" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "metadata": { 55 | "id": "eaxJFRPsnW04" 56 | }, 57 | "source": [ 58 | "import os\n", 59 | "import time\n", 60 | "import datetime\n", 61 | "import tensorflow as tf\n", 62 | "import matplotlib.pyplot as plt\n", 63 | "from numpy.random import randint\n", 64 | "from tensorflow.keras import Input\n", 65 | "from numpy import load, zeros, ones\n", 66 | "from tensorflow.keras.optimizers import Adam\n", 67 | "from tensorflow.keras.models import Model, load_model\n", 68 | "from tensorflow.keras.initializers import RandomNormal\n", 69 | "from tensorflow.keras.layers import Conv2D, UpSampling2D, LeakyReLU, Activation\n", 70 | "from tensorflow.keras.layers import Concatenate, Dropout, BatchNormalization, LeakyReLU" 71 | ], 72 | "execution_count": null, 73 | "outputs": [] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": { 78 | "id": "uKEC_Mfopb2s" 79 | }, 80 | "source": [ 81 | "# Creating Functions" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "metadata": { 87 | "id": "YTesJLExoK9h" 88 | }, 89 | "source": [ 90 | "## Defining Discriminator" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "metadata": { 96 | "id": "mbxlfZtrnc7c" 97 | }, 98 | "source": [ 99 | "def define_discriminator(image_shape):\n", 100 | "\tinit = RandomNormal(stddev=0.02)\n", 101 | " \n", 102 | "\tin_src_image = Input(shape=image_shape)\n", 103 | "\tin_target_image = Input(shape=image_shape)\n", 104 | " \n", 105 | "\tmerged = Concatenate()([in_src_image, in_target_image])\n", 106 | " \n", 107 | "\td = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(merged)\n", 108 | "\td = LeakyReLU(alpha=0.2)(d)\n", 109 | " \n", 110 | "\td = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)\n", 111 | "\td = BatchNormalization()(d)\n", 112 | "\td = LeakyReLU(alpha=0.2)(d)\n", 113 | " \n", 114 | "\td = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)\n", 115 | "\td = BatchNormalization()(d)\n", 116 | "\td = LeakyReLU(alpha=0.2)(d)\n", 117 | " \n", 118 | "\td = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)\n", 119 | "\td = BatchNormalization()(d)\n", 120 | "\td = 
LeakyReLU(alpha=0.2)(d)\n", 121 | " \n", 122 | "\td = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)\n", 123 | "\td = BatchNormalization()(d)\n", 124 | "\td = LeakyReLU(alpha=0.2)(d)\n", 125 | " \n", 126 | "\td = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)\n", 127 | "\tpatch_out = Activation('sigmoid')(d)\n", 128 | " \n", 129 | "\tmodel = Model([in_src_image, in_target_image], patch_out)\n", 130 | " \n", 131 | "\topt = Adam(lr=0.0002, beta_1=0.5)\n", 132 | "\tmodel.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])\n", 133 | "\treturn model" 134 | ], 135 | "execution_count": null, 136 | "outputs": [] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "metadata": { 141 | "id": "IWqF429qoP1H" 142 | }, 143 | "source": [ 144 | "## Defining Generator" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "metadata": { 150 | "id": "71cxSSWUnhIq" 151 | }, 152 | "source": [ 153 | "def define_generator(image_shape = (256, 256, 3)):\n", 154 | " init = RandomNormal(stddev=0.02)\n", 155 | " in_image = Input(shape=image_shape)\n", 156 | "\n", 157 | " g = Conv2D(64, (7,7), padding='same', kernel_initializer=init)(in_image)\n", 158 | " g = BatchNormalization()(g, training=True)\n", 159 | " g3 = LeakyReLU(alpha=0.2)(g)\n", 160 | "\n", 161 | " g = Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g3)\n", 162 | " g = BatchNormalization()(g, training=True)\n", 163 | " g2 = LeakyReLU(alpha=0.2)(g)\n", 164 | "\n", 165 | " g = Conv2D(256, (3,3), strides=(2,2), padding='same', kernel_initializer=init)(g2)\n", 166 | " g = BatchNormalization()(g, training=True)\n", 167 | " g1 = LeakyReLU(alpha=0.2)(g)\n", 168 | "\n", 169 | " for _ in range(6):\n", 170 | " g = Conv2D(256, (3,3), padding='same', kernel_initializer=init)(g1)\n", 171 | " g = BatchNormalization()(g, training=True)\n", 172 | " g = LeakyReLU(alpha=0.2)(g)\n", 173 | "\n", 174 | " g = Conv2D(256, (3,3), padding='same', kernel_initializer=init)(g)\n", 175 | " g = BatchNormalization()(g, training=True)\n", 176 | "\n", 177 | " g1 = Concatenate()([g, g1])\n", 178 | "\n", 179 | " g = UpSampling2D((2, 2))(g1)\n", 180 | " g = Conv2D(128, (1, 1), kernel_initializer=init)(g)\n", 181 | " g = Dropout(0.5)(g, training=True)\n", 182 | " g = Concatenate()([g, g2])\n", 183 | " g = BatchNormalization()(g, training=True)\n", 184 | " g = LeakyReLU(alpha=0.2)(g)\n", 185 | "\n", 186 | " g = UpSampling2D((2, 2))(g)\n", 187 | " g = Conv2D(64, (1, 1), kernel_initializer=init)(g)\n", 188 | " g = Dropout(0.5)(g, training=True)\n", 189 | " g = Concatenate()([g, g3])\n", 190 | " g = BatchNormalization()(g, training=True)\n", 191 | " g = LeakyReLU(alpha=0.2)(g)\n", 192 | "\n", 193 | " g = Conv2D(3, (7,7), padding='same', kernel_initializer=init)(g)\n", 194 | " g = BatchNormalization()(g, training=True)\n", 195 | " out_image = Activation('tanh')(g)\n", 196 | "\n", 197 | " model = Model(in_image, out_image)\n", 198 | " return model" 199 | ], 200 | "execution_count": null, 201 | "outputs": [] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "metadata": { 206 | "id": "Jj8HuE2qoS3f" 207 | }, 208 | "source": [ 209 | "## Initializing GAN training" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "metadata": { 215 | "id": "KTqa4cLEnkmV" 216 | }, 217 | "source": [ 218 | "def define_gan(g_model, d_model, image_shape):\n", 219 | "\t# make weights in the discriminator not trainable\n", 220 | "\tfor layer in d_model.layers:\n", 221 | "\t\tif not isinstance(layer, BatchNormalization):\n", 222 | 
"\t\t\tlayer.trainable = False\n", 223 | "\t# define the source image\n", 224 | "\tin_src = Input(shape=image_shape)\n", 225 | "\t# connect the source image to the generator input\n", 226 | "\tgen_out = g_model(in_src)\n", 227 | "\t# connect the source input and generator output to the discriminator input\n", 228 | "\tdis_out = d_model([in_src, gen_out])\n", 229 | "\t# src image as input, generated image and classification output\n", 230 | "\tmodel = Model(in_src, [dis_out, gen_out])\n", 231 | "\t# compile model\n", 232 | "\topt = Adam(lr=0.0002, beta_1=0.5)\n", 233 | "\tmodel.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1,100])\n", 234 | "\treturn model" 235 | ], 236 | "execution_count": null, 237 | "outputs": [] 238 | }, 239 | { 240 | "cell_type": "markdown", 241 | "metadata": { 242 | "id": "LBFQpdz5oX4s" 243 | }, 244 | "source": [ 245 | "## Loading Real Samples" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "metadata": { 251 | "id": "sQpWSsK0nolM" 252 | }, 253 | "source": [ 254 | "def load_real_samples(filename):\n", 255 | "\t# load compressed arrays\n", 256 | "\tdata = load(filename)\n", 257 | "\t# unpack arrays\n", 258 | "\tX1, X2 = data['arr_0'], data['arr_1']\n", 259 | "\t# scale from [0,255] to [-1,1]\n", 260 | "\tX1 = (X1 - 127.5) / 127.5\n", 261 | "\tX2 = (X2 - 127.5) / 127.5\n", 262 | "\treturn [X2, X1]" 263 | ], 264 | "execution_count": null, 265 | "outputs": [] 266 | }, 267 | { 268 | "cell_type": "markdown", 269 | "metadata": { 270 | "id": "HZQ9GO20oZgQ" 271 | }, 272 | "source": [ 273 | "## Generating Real Fake Samples" 274 | ] 275 | }, 276 | { 277 | "cell_type": "code", 278 | "metadata": { 279 | "id": "FlngHBPLnrof" 280 | }, 281 | "source": [ 282 | "def generate_real_samples(dataset, n_samples, patch_shape):\n", 283 | "\t# unpack dataset\n", 284 | "\ttrainA, trainB = dataset\n", 285 | "\t# choose random instances\n", 286 | "\tix = randint(0, trainA.shape[0], n_samples)\n", 287 | "\t# retrieve selected images\n", 288 | "\tX1, X2 = trainA[ix], trainB[ix]\n", 289 | "\t# generate 'real' class labels (1)\n", 290 | "\ty = ones((n_samples, patch_shape, patch_shape, 1))\n", 291 | "\treturn [X1, X2], y" 292 | ], 293 | "execution_count": null, 294 | "outputs": [] 295 | }, 296 | { 297 | "cell_type": "markdown", 298 | "metadata": { 299 | "id": "d0nsxrucoepT" 300 | }, 301 | "source": [ 302 | "## Generating Fake Samples" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "metadata": { 308 | "id": "mAnQEQxNnuJD" 309 | }, 310 | "source": [ 311 | "def generate_fake_samples(g_model, samples, patch_shape):\n", 312 | "\t# generate fake instance\n", 313 | "\tX = g_model.predict(samples)\n", 314 | "\t# create 'fake' class labels (0)\n", 315 | "\ty = zeros((len(X), patch_shape, patch_shape, 1))\n", 316 | "\treturn X, y" 317 | ], 318 | "execution_count": null, 319 | "outputs": [] 320 | }, 321 | { 322 | "cell_type": "markdown", 323 | "metadata": { 324 | "id": "_qs-bEhiohte" 325 | }, 326 | "source": [ 327 | "## Summarizing Training and Saving Model" 328 | ] 329 | }, 330 | { 331 | "cell_type": "code", 332 | "metadata": { 333 | "id": "hmzXpDOanw3z" 334 | }, 335 | "source": [ 336 | "def summarize_performance(step, g_model, d_model, dataset, n_samples=3):\n", 337 | " # select a sample of input images\n", 338 | " [X_realA, X_realB], _ = generate_real_samples(dataset, n_samples, 1)\n", 339 | " # generate a batch of fake samples\n", 340 | " X_fakeB, _ = generate_fake_samples(g_model, X_realA, 1)\n", 341 | " # scale all pixels from [-1,1] to [0,1]\n", 342 | " 
X_realA = (X_realA + 1) / 2.0\n", 343 | " X_realB = (X_realB + 1) / 2.0\n", 344 | " X_fakeB = (X_fakeB + 1) / 2.0\n", 345 | " # plot real source images\n", 346 | " plt.figure(figsize=(14, 14))\n", 347 | " for i in range(n_samples):\n", 348 | " plt.subplot(3, n_samples, 1 + i)\n", 349 | " plt.axis('off')\n", 350 | " plt.title('Low-Light')\n", 351 | " plt.imshow(X_realA[i])\n", 352 | " # plot generated target image\n", 353 | " for i in range(n_samples):\n", 354 | " plt.subplot(3, n_samples, 1 + n_samples + i)\n", 355 | " plt.axis('off')\n", 356 | " plt.title('Generated')\n", 357 | " plt.imshow(X_fakeB[i])\n", 358 | " # plot real target image\n", 359 | " for i in range(n_samples):\n", 360 | " plt.subplot(3, n_samples, 1 + n_samples*2 + i)\n", 361 | " plt.axis('off')\n", 362 | " plt.title('Ground Truth')\n", 363 | " plt.imshow(X_realB[i])\n", 364 | " # save plot to file\n", 365 | " filename1 = step_output + 'plot_%06d.png' % (step+1)\n", 366 | " plt.savefig(filename1)\n", 367 | " plt.close()\n", 368 | " # save the generator model\n", 369 | " filename2 = model_output + 'gen_model_%06d.h5' % (step+1)\n", 370 | " g_model.save(filename2)\n", 371 | " # save the discriminator model\n", 372 | " filename3 = model_output + 'disc_model_%06d.h5' % (step+1)\n", 373 | " d_model.save(filename3)\n", 374 | " print('[.] Saved Step : %s' % (filename1))\n", 375 | " print('[.] Saved Model: %s' % (filename2))\n", 376 | " print('[.] Saved Model: %s' % (filename3))" 377 | ], 378 | "execution_count": null, 379 | "outputs": [] 380 | }, 381 | { 382 | "cell_type": "markdown", 383 | "metadata": { 384 | "id": "uIChDKI8omnh" 385 | }, 386 | "source": [ 387 | "## Training Function" 388 | ] 389 | }, 390 | { 391 | "cell_type": "code", 392 | "metadata": { 393 | "id": "INv0tUj3n0cj" 394 | }, 395 | "source": [ 396 | "def train(d_model, g_model, gan_model, dataset, n_epochs=100, n_batch=12):\n", 397 | " # determine the output square shape of the discriminator\n", 398 | " n_patch = d_model.output_shape[1]\n", 399 | " # unpack dataset\n", 400 | " trainA, trainB = dataset\n", 401 | " # calculate the number of batches per training epoch\n", 402 | " bat_per_epo = int(len(trainA) / n_batch)\n", 403 | " # calculate the number of training iterations\n", 404 | " n_steps = bat_per_epo * n_epochs\n", 405 | " print(\"[!] Number of steps {}\".format(n_steps))\n", 406 | " print(\"[!] 
Saving model/step output every {} steps\".format(bat_per_epo * 1))\n",
407 | "    # manually enumerate epochs
408 | "    for i in range(n_steps):\n",
409 | "        start = time.time()\n",
410 | "        # select a batch of real samples\n",
411 | "        [X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)\n",
412 | "        # generate a batch of fake samples\n",
413 | "        X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)\n",
414 | "        # update discriminator for real samples\n",
415 | "        d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)\n",
416 | "        # update discriminator for generated samples\n",
417 | "        d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)\n",
418 | "        # update the generator\n",
419 | "        g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])\n",
420 | "        # summarize performance\n",
421 | "        time_taken = time.time() - start\n",
422 | "        print(\n",
423 | "            '[*] %06d, d1[%.3f] d2[%.3f] g[%06.3f] ---> time[%.2f], time_left[%.08s]'\n",
424 | "            %\n",
425 | "            (i+1, d_loss1, d_loss2, g_loss, time_taken, str(datetime.timedelta(seconds=((time_taken) * (n_steps - (i + 1))))).split('.')[0].zfill(8))\n",
426 | "        )\n",
427 | "        # summarize model performance\n",
428 | "        if (i+1) % (bat_per_epo * 1) == 0:\n",
429 | "            summarize_performance(i, g_model, d_model, dataset)"
430 | ],
431 | "execution_count": null,
432 | "outputs": []
433 | },
434 | {
435 | "cell_type": "markdown",
436 | "metadata": {
437 | "id": "9Ir5XcIdo8XF"
438 | },
439 | "source": [
440 | "# Main Function"
441 | ]
442 | },
443 | {
444 | "cell_type": "markdown",
445 | "metadata": {
446 | "id": "Rk1JU6YXo1e7"
447 | },
448 | "source": [
449 | "## Loading Dataset"
450 | ]
451 | },
452 | {
453 | "cell_type": "code",
454 | "metadata": {
455 | "id": "6jX2Qli9o1CW"
456 | },
457 | "source": [
458 | "dataset = load_real_samples('')\n",
459 | "print('Loaded', dataset[0].shape, dataset[1].shape)\n",
460 | "image_shape = dataset[0].shape[1:]"
461 | ],
462 | "execution_count": null,
463 | "outputs": []
464 | },
465 | {
466 | "cell_type": "markdown",
467 | "metadata": {
468 | "id": "7qst6rVopRJT"
469 | },
470 | "source": [
471 | "## Defining Models"
472 | ]
473 | },
474 | {
475 | "cell_type": "code",
476 | "metadata": {
477 | "id": "827JDijxpQm6"
478 | },
479 | "source": [
480 | "d_model = define_discriminator(image_shape)\n",
481 | "g_model = define_generator(image_shape)\n",
482 | "gan_model = define_gan(g_model, d_model, image_shape)"
483 | ],
484 | "execution_count": null,
485 | "outputs": []
486 | },
487 | {
488 | "cell_type": "markdown",
489 | "metadata": {
490 | "id": "_WcK8E0fo_Ae"
491 | },
492 | "source": [
493 | "## Creating Model Directory and Calling Train Function"
494 | ]
495 | },
496 | {
497 | "cell_type": "code",
498 | "metadata": {
499 | "id": "iPuLnNjrn6kv"
500 | },
501 | "source": [
502 | "dir = ''\n",
503 | "fileName = 'Enhancement Model'\n",
504 | "step_output = dir + fileName + \"/Step Output/\"\n",
505 | "model_output = dir + fileName + \"/Model Output/\"\n",
506 | "if fileName not in os.listdir(dir):\n",
507 | "  os.mkdir(dir + fileName)\n",
508 | "  os.mkdir(step_output)\n",
509 | "  os.mkdir(model_output)\n",
510 | "\n",
511 | "train(d_model, g_model, gan_model, dataset, n_batch=12)"
512 | ],
513 | "execution_count": null,
514 | "outputs": []
515 | }
516 | ]
517 | }
--------------------------------------------------------------------------------