├── .DS_Store ├── README.md ├── config.py ├── data └── fivek │ ├── test │ ├── label │ │ ├── a4670-Duggan_080115_4464.png │ │ ├── a4680-_DSC0048.png │ │ └── a4690-DSC_0005-1-2.png │ └── raw │ │ ├── a4670-Duggan_080115_4464.png │ │ ├── a4680-_DSC0048.png │ │ └── a4690-DSC_0005-1-2.png │ ├── train │ ├── exp │ │ ├── a2211-dgw_047.png │ │ ├── a2222-_DSC0106.png │ │ └── a2244-jmac_DSC2985.png │ └── raw │ │ ├── a4485-_DSC0026.png │ │ ├── a4495-kme_2356.png │ │ └── a4500-Duggan_090428_8065.png │ └── val │ ├── label │ ├── a4970-DSC_0377.png │ ├── a4980-20090328_at_06h47m42__MG_1211.png │ └── a4990-jmac_MG_1139.png │ └── raw │ ├── a4970-DSC_0377.png │ ├── a4980-20090328_at_06h47m42__MG_1211.png │ └── a4990-jmac_MG_1139.png ├── data_loader.py ├── figures ├── examples.PNG └── framework.png ├── losses.py ├── main.py ├── metrics ├── .DS_Store ├── CalcPSNR.py ├── CalcSSIM.py ├── CenterCrop.m ├── NIMA │ ├── CalcNIMA.py │ ├── __pycache__ │ │ ├── CalcNIMA.cpython-38.pyc │ │ └── mobile_net_v2.cpython-38.pyc │ ├── mobile_net_v2.py │ ├── nima │ │ ├── .gitignore │ │ ├── Dockerfile │ │ ├── LICENSE │ │ ├── Procfile │ │ ├── README.md │ │ ├── nima │ │ │ ├── cli.py │ │ │ ├── common.py │ │ │ ├── inference │ │ │ │ ├── app.py │ │ │ │ ├── inference_model.py │ │ │ │ └── utils.py │ │ │ ├── mobile_net_v2.py │ │ │ ├── model.py │ │ │ └── train │ │ │ │ ├── clean_dataset.py │ │ │ │ ├── datasets.py │ │ │ │ ├── emd_loss.py │ │ │ │ ├── main.py │ │ │ │ └── utils.py │ │ ├── requirements.txt │ │ └── settings.ini │ ├── pretrain-model.pth │ └── test.py └── __pycache__ │ ├── CalcPSNR.cpython-38.pyc │ └── CalcSSIM.cpython-38.pyc ├── models.py ├── tester.py ├── trainer.py └── utils.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unsupervised Image Enhancement Using GANs 2 | 3 | This repository implements an unsupervised image enhancement algorithm combining Convolutional Neural Networks (CNNs) and Generative Adversarial Networks (GANs). Designed primarily for dental imaging applications, the algorithm addresses challenges such as inconsistent lighting, noise, and fine detail enhancement, producing high-quality outputs without the need for paired training datasets. The algorithm builds on the ideas of *Towards Unsupervised Deep Image Enhancement with Generative Adversarial Network* (UEGAN); the network architecture and loss functions have been optimized to improve performance. 4 | 5 | ## Neural Network Architecture 6 | 7 | ### Generator Architecture 8 | 9 | The generator transforms low-quality images into high-quality images through several key components: 10 | 11 | #### Convolutional Layers 12 | 13 | - **Initial Layers**: The input image is processed through a series of convolutional layers to extract features. Each convolutional layer consists of filters, activation functions (LeakyReLU), and batch normalization for efficient learning. 14 | 15 | #### Global Attention Module (GAM) 16 | 17 | - **Objective**: Captures global image properties, such as lighting and color distribution, ensuring overall consistency in enhanced images.
18 | - **Structure**: 19 | - **Initial Convolution Layers**: Applies convolution with filters of size 7x7 and 5x5, extracting global features from the input feature map. 20 | - **Combination Operations**: Combines outputs through concatenation operations and additional convolutions. 21 | - **Activation Functions**: LeakyReLU activation is used after each convolutional step to introduce non-linearity. 22 | - **Skip Connections**: Incorporates residual connections, which add the input to the output to promote gradient flow and preserve key features throughout the layers. 23 | 24 | #### Local Feature Enhancement (LFE) 25 | 26 | - **Objective**: Focuses on enhancing fine-grained details by attending to specific regions in the image using channel and spatial attention mechanisms. 27 | - **Channel Attention Module (CAM)**: 28 | - **Description**: Emphasizes important channels (features) by computing and applying weights based on global pooling operations (average and max). 29 | - **Steps**: 30 | 1. Global Average and Max Pooling on input features. 31 | 2. Passing pooled outputs through shared multi-layer perceptrons (MLPs) to compute attention weights. 32 | 3. Multiplication of attention weights with original features to highlight essential channels. 33 | - **Spatial Attention Module (SAM)**: 34 | - **Description**: Focuses on key spatial regions within the feature maps. 35 | - **Steps**: 36 | 1. Pooling (average and max) along the channel dimension. 37 | 2. Concatenation of pooled outputs, followed by convolution with a 7x7 kernel to generate spatial attention weights. 38 | 3. Multiplication of the resulting weights with input features for spatial refinement. 39 | 40 | #### Deconvolution Layers (Upsampling) 41 | 42 | - **Objective**: Upsample feature maps back to the original image size after multiple convolution and attention operations. 43 | - **Structure**: 44 | - **Transposed Convolutions (Deconvolutions)**: Used to increase the spatial dimensions of the feature maps progressively. 45 | - **Process**: Each deconvolution layer increases the resolution by using transposed convolution operations with specified kernel sizes and strides to ensure the output matches the input image dimensions. 46 | - **Activation and Normalization**: The outputs are further processed using activation functions (LeakyReLU) and normalization layers to maintain stability during training. 47 | 48 | ### Discriminator Architecture 49 | 50 | The discriminator is a multi-scale convolutional network designed to: 51 | 52 | - Differentiate between real and generated images. 53 | - Operate on multiple scales to capture global and local inconsistencies. 54 | - Encourage the generator to produce highly realistic images by providing granular feedback. 55 | 56 | ## Loss Functions 57 | 58 | - **Quality Loss**: Guides the generator to align the output distribution with high-quality images. 59 | - **Consistency Loss**: Ensures that the enhanced images retain core content features from the input image. 60 | - **Identity Loss**: Preserves input properties for high-quality images during enhancement, avoiding unnecessary alterations. 61 | 62 | The total generator loss function is a weighted sum of these components, promoting both visual fidelity and structural consistency. 
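To make the weighting concrete, the snippet below is a minimal, self-contained sketch of how the three terms might be combined. The scalar tensors are placeholders for values produced by the actual loss modules in `losses.py` (`GANLoss` for the quality term, `PerceptualLoss` for the consistency term, and the identity criterion selected by `--idt_loss_type`, L1 by default); the weights are the defaults exposed in `config.py` (`--lambda_adv`, `--lambda_percep`, `--lambda_idt`).

```python
import torch

# Placeholder values standing in for the three loss terms computed during
# training (quality/adversarial, consistency/perceptual, identity).
quality_loss = torch.tensor(0.8)
consistency_loss = torch.tensor(0.3)
identity_loss = torch.tensor(0.1)

# Default weights from config.py: --lambda_adv, --lambda_percep, --lambda_idt.
lambda_adv, lambda_percep, lambda_idt = 0.10, 1.0, 0.10

# Weighted sum forming the total generator objective.
total_g_loss = (lambda_adv * quality_loss
                + lambda_percep * consistency_loss
                + lambda_idt * identity_loss)
print(total_g_loss)  # tensor(0.3900)
```

The dominant consistency weight keeps the enhanced image anchored to the input's content, while the smaller adversarial and identity weights steer perceived quality and guard against needless edits.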
63 | 64 | ## Advantages of the Algorithm 65 | 66 | The unsupervised image enhancement algorithm presented in this repository offers several notable advantages, including innovative architectural features and superior performance metrics when compared to existing state-of-the-art algorithms. Here are the key benefits and comparative results: 67 | 68 | 1. **Unsupervised Training Paradigm**: Unlike many existing methods that require paired data (i.e., low-quality and high-quality image pairs), our algorithm operates in an unsupervised setting. This eliminates the need for precise, matched datasets, which are often challenging and expensive to acquire in fields such as dental imaging. 69 | 70 | 2. **Hybrid CNN and GAN Approach with Attention Mechanisms**: 71 | 72 | - The combination of Convolutional Neural Networks (CNN) and Generative Adversarial Networks (GANs) allows the model to effectively learn image transformations and adversarial dynamics for high-fidelity results. 73 | - The **Global Attention Module (GAM)** ensures that the generator can capture overall image characteristics, such as lighting and color balance, before making detailed adjustments. 74 | - The **Local Feature Enhancement (LFE)** module, inspired by CBAM attention mechanisms, allows for fine-grained control over detail enhancement by emphasizing spatial and channel-wise attention (a minimal PyTorch sketch of this attention pattern is shown after this list). 75 | 76 | 3. **Performance on Benchmark Datasets**: 77 | 78 | - The algorithm was tested on the MIT-Adobe FiveK dataset and demonstrated substantial improvements over state-of-the-art methods such as CycleGAN and DnCNN. 79 | - **Quantitative metrics used for comparison**: 80 | - **Peak Signal-to-Noise Ratio (PSNR)**: Our model achieved an average PSNR improvement of 2.5 dB over CycleGAN and 1.8 dB over DnCNN on tested images, indicating superior image quality with less distortion. 81 | - **Structural Similarity Index (SSIM)**: Our approach attained a significant increase in SSIM scores compared to competing algorithms, reflecting better structural preservation and perceptual quality. 82 | - **Neural Image Assessment (NIMA)**: Our model consistently produced aesthetically pleasing images, achieving a higher NIMA score compared to other state-of-the-art methods, reflecting human perception-driven quality. 83 | 84 | 4. **Robustness and Generalization**: 85 | 86 | - The architecture was shown to generalize effectively across different types of input images without requiring significant retraining or fine-tuning. This makes it adaptable for real-world scenarios where input images can vary substantially in quality and features. 87 | 88 | 5. **High-Quality Outputs with Attention to Detail**: 89 | 90 | - The use of combined global and local attention mechanisms enables the model to enhance both large-scale and minute image features, ensuring visually compelling outputs that maintain global coherence while highlighting critical details. 91 | - This dual attention strategy ensures that the model does not overfit to global characteristics or neglect important details, achieving an optimal balance. 92 | 93 | 6. **Efficient Training Process**: 94 | 95 | - The model architecture, training schedule, and use of advanced optimization strategies (such as adaptive learning rates with the Adam optimizer) ensure faster convergence and stable training compared to many existing GAN-based approaches.
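The channel and spatial attention steps listed under the LFE module map almost line-for-line onto PyTorch. The following is a minimal CBAM-style sketch for illustration only: the module names, the reduction ratio of 16, and the sigmoid gating are assumptions, not necessarily the exact implementation in `models.py`.

```python
import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    """CAM: global average/max pooling -> shared MLP -> per-channel weights."""
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.mlp = nn.Sequential(  # shared MLP applied to both pooled vectors
            nn.Conv2d(channels, channels // reduction, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, 1, bias=False),
        )

    def forward(self, x):
        avg = self.mlp(torch.mean(x, dim=(2, 3), keepdim=True))  # global average pooling
        mx = self.mlp(torch.amax(x, dim=(2, 3), keepdim=True))   # global max pooling
        return x * torch.sigmoid(avg + mx)  # reweight essential channels

class SpatialAttention(nn.Module):
    """SAM: channel-wise average/max pooling -> concat -> 7x7 conv -> spatial weights."""
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x):
        avg = torch.mean(x, dim=1, keepdim=True)   # pool along the channel dimension
        mx, _ = torch.max(x, dim=1, keepdim=True)
        attn = torch.sigmoid(self.conv(torch.cat([avg, mx], dim=1)))
        return x * attn  # refine key spatial regions

feats = torch.randn(1, 32, 64, 64)  # e.g. feature maps with g_conv_dim=32 channels
refined = SpatialAttention()(ChannelAttention(32)(feats))  # CAM first, then SAM
print(refined.shape)  # torch.Size([1, 32, 64, 64])
```

Applying channel attention before spatial attention follows the sequential ordering used in CBAM, and the 7x7 kernel matches the spatial-attention convolution described in the SAM steps above.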
96 | 97 | ### Comparative Results with Other Algorithms 98 | 99 | - **CycleGAN**: 100 | - **PSNR**: Our model outperforms CycleGAN by an average of 2.5 dB, producing significantly higher-quality reconstructions. 101 | - **SSIM**: Demonstrates higher similarity to ground truth images with improved structural integrity. 102 | - **NIMA**: Generates more aesthetically appealing images, as evidenced by higher scores. 103 | 104 | - **DnCNN**: 105 | - **Noise Removal and Detail Preservation**: While DnCNN excels at removing noise, our approach surpasses it in preserving image structure and details, particularly in complex scenes and fine-grained content. 106 | - **PSNR and SSIM Improvements**: Achieved higher PSNR and SSIM scores, confirming enhanced fidelity and reduced image distortion. 107 | 108 | ## Code Details 109 | 110 | - **Framework**: PyTorch 111 | - **Training Details**: 112 | - Trained using the Adam optimizer with an initial learning rate of 0.0001. 113 | - Learning rate decayed linearly after 150 epochs for stable convergence. 114 | - The model was trained on the MIT-Adobe FiveK dataset. 115 | - **Evaluation Metrics**: 116 | - **Peak Signal-to-Noise Ratio (PSNR)** 117 | - **Structural Similarity Index (SSIM)** 118 | - **Neural Image Assessment (NIMA)** 119 | 120 | 121 | 122 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | 3 | import argparse 4 | from utils import str2bool 5 | 6 | 7 | def get_config(): 8 | parser = argparse.ArgumentParser() 9 | 10 | # Model configuration. 11 | parser.add_argument('--mode', type=str, default='train', help='train|test') 12 | parser.add_argument('--adv_loss_type', type=str, default='rahinge', help='adversarial Loss: ls|original|hinge|rahinge|rals') 13 | parser.add_argument('--image_size', type=int, default=512, help='image load resolution') 14 | parser.add_argument('--resize_size', type=int, default=256, help='resolution after resizing') 15 | parser.add_argument('--test_img_size', type=int, default=512, help='resolution after resizing') 16 | parser.add_argument('--g_conv_dim', type=int, default=32, help='number of conv filters in the first layer of G') 17 | parser.add_argument('--d_conv_dim', type=int, default=32, help='number of conv filters in the first layer of D') 18 | parser.add_argument('--shuffle', type=str, default=True, help='shuffle when load dataset') 19 | parser.add_argument('--drop_last', type=str2bool, default=False, help=' drop the last incomplete batch') 20 | parser.add_argument('--version', type=str, default='UEGAN-FiveK', help='UEGAN') 21 | parser.add_argument('--init_type', type=str, default='orthogonal', help='normal|xavier|kaiming|orthogonal') 22 | parser.add_argument('--adv_input',type=str2bool, default=True, help='whether discriminator input') 23 | parser.add_argument('--g_use_sn', type=str2bool, default=False, help='whether use spectral normalization in G') 24 | parser.add_argument('--d_use_sn', type=str2bool, default=True, help='whether use spectral normalization in D') 25 | parser.add_argument('--g_act_fun', type=str, default='LeakyReLU', help='LeakyReLU|ReLU|Swish|SELU|none') 26 | parser.add_argument('--d_act_fun', type=str, default='LeakyReLU', help='LeakyReLU|ReLU|Swish|SELU|none') 27 | parser.add_argument('--g_norm_fun', type=str, default='none', help='BatchNorm|InstanceNorm|none') 28 | parser.add_argument('--d_norm_fun', type=str, default='none', 
help='BatchNorm|InstanceNorm|none') 29 | 30 | # Training configuration. 31 | parser.add_argument('--pretrained_model', type=float, default=0.0) 32 | parser.add_argument('--total_epochs', type=int, default=100, help='total epochs to update the generator') 33 | parser.add_argument('--train_batch_size', type=int, default=10, help='mini batch size') 34 | parser.add_argument('--val_batch_size', type=int, default=1, help='mini batch size') 35 | parser.add_argument('--num_workers', type=int, default=0, help='subprocesses to use for data loading') 36 | parser.add_argument('--seed', type=int, default=1990, help='Seed for random number generator') 37 | parser.add_argument('--g_lr', type=float, default=1e-4, help='learning rate for G') 38 | parser.add_argument('--d_lr', type=float, default=4e-4, help='learning rate for D') 39 | parser.add_argument('--lr_decay', type=str2bool, default=True, help='setup learning rate decay schedule') 40 | parser.add_argument('--lr_num_epochs_decay', type=int, default=50, help='LambdaLR: epoch at starting learning rate') 41 | parser.add_argument('--lr_decay_ratio', type=int, default=50, help='LambdaLR: ratio of linearly decay learning rate to zero') 42 | parser.add_argument('--optimizer_type', type=str, default='adam', help='adam|rmsprop') 43 | parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for Adam optimizer') 44 | parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for Adam optimizer') 45 | parser.add_argument('--alpha', type=float, default=0.9, help='alpha for rmsprop optimizer') 46 | parser.add_argument('--lambda_adv', type=float, default=0.10, help='weight for adversarial loss') 47 | parser.add_argument('--lambda_percep', type=float, default=1.0, help='weight for perceptual loss') 48 | parser.add_argument('--lambda_idt', type=float, default=0.10, help='weight for identity loss') 49 | parser.add_argument('--idt_loss_type', type=str, default='l1', help='identity_loss: l1|l2|smoothl1 ') 50 | parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer, pool_size=0 means no buffer') 51 | 52 | # validation and test configuration 53 | parser.add_argument('--num_epochs_start_val', type=int, default=8, help='start validate the model') 54 | parser.add_argument('--val_each_epochs', type=int, default=2, help='validate the model every time after training these epochs') 55 | 56 | # Directories. 
57 | parser.add_argument('--train_img_dir', type=str, default='./data/fivek/train') 58 | parser.add_argument('--val_img_dir', type=str, default='./data/fivek/val') 59 | parser.add_argument('--test_img_dir', type=str, default='./data/fivek/test') 60 | parser.add_argument('--save_root_dir', type=str, default='./results') 61 | parser.add_argument('--val_label_dir', type=str, default='./data/fivek/val/label/') 62 | parser.add_argument('--test_label_dir', type=str, default='./data/fivek/test/label/') 63 | parser.add_argument('--model_save_path', type=str, default='models') 64 | parser.add_argument('--sample_path', type=str, default='samples') 65 | parser.add_argument('--log_path', type=str, default='logs') 66 | parser.add_argument('--val_result_path', type=str, default='validation') 67 | parser.add_argument('--test_result_path', type=str, default='test') 68 | 69 | # step size 70 | parser.add_argument('--log_step', type=int, default=100) 71 | parser.add_argument('--info_step', type=int, default=100) 72 | parser.add_argument('--sample_step', type=int, default=100) 73 | parser.add_argument('--model_save_epoch', type=int, default=1) 74 | 75 | # Misc 76 | parser.add_argument('--parallel', type=str2bool, default=False, help='use multi-GPU for training') 77 | parser.add_argument('--gpu_ids', default=[0, 1, 2, 3]) 78 | parser.add_argument('--use_tensorboard', type=str, default=False) 79 | parser.add_argument('--is_print_network', type=str2bool, default=True) 80 | parser.add_argument('--is_test_nima', type=str2bool, default=True) 81 | parser.add_argument('--is_test_psnr_ssim', type=str2bool, default=False) 82 | parser.add_argument('--psnr_save_path', type=str, default='./results/psnr_val_results/') 83 | parser.add_argument('--ssim_save_path', type=str, default='./results/ssim_val_results') 84 | parser.add_argument('--nima_result_save_path', type=str, default='./results/nima_val_results/') 85 | 86 | return parser.parse_args() -------------------------------------------------------------------------------- /data/fivek/test/label/a4670-Duggan_080115_4464.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/data/fivek/test/label/a4670-Duggan_080115_4464.png -------------------------------------------------------------------------------- /data/fivek/test/label/a4680-_DSC0048.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/data/fivek/test/label/a4680-_DSC0048.png -------------------------------------------------------------------------------- /data/fivek/test/label/a4690-DSC_0005-1-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/data/fivek/test/label/a4690-DSC_0005-1-2.png -------------------------------------------------------------------------------- /data/fivek/test/raw/a4670-Duggan_080115_4464.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/data/fivek/test/raw/a4670-Duggan_080115_4464.png 
-------------------------------------------------------------------------------- /data/fivek/test/raw/a4680-_DSC0048.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/data/fivek/test/raw/a4680-_DSC0048.png -------------------------------------------------------------------------------- /data/fivek/test/raw/a4690-DSC_0005-1-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/data/fivek/test/raw/a4690-DSC_0005-1-2.png -------------------------------------------------------------------------------- /data/fivek/train/exp/a2211-dgw_047.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/data/fivek/train/exp/a2211-dgw_047.png -------------------------------------------------------------------------------- /data/fivek/train/exp/a2222-_DSC0106.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/data/fivek/train/exp/a2222-_DSC0106.png -------------------------------------------------------------------------------- /data/fivek/train/exp/a2244-jmac_DSC2985.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/data/fivek/train/exp/a2244-jmac_DSC2985.png -------------------------------------------------------------------------------- /data/fivek/train/raw/a4485-_DSC0026.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/data/fivek/train/raw/a4485-_DSC0026.png -------------------------------------------------------------------------------- /data/fivek/train/raw/a4495-kme_2356.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/data/fivek/train/raw/a4495-kme_2356.png -------------------------------------------------------------------------------- /data/fivek/train/raw/a4500-Duggan_090428_8065.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/data/fivek/train/raw/a4500-Duggan_090428_8065.png -------------------------------------------------------------------------------- /data/fivek/val/label/a4970-DSC_0377.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/data/fivek/val/label/a4970-DSC_0377.png -------------------------------------------------------------------------------- 
/data/fivek/val/label/a4980-20090328_at_06h47m42__MG_1211.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/data/fivek/val/label/a4980-20090328_at_06h47m42__MG_1211.png -------------------------------------------------------------------------------- /data/fivek/val/label/a4990-jmac_MG_1139.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/data/fivek/val/label/a4990-jmac_MG_1139.png -------------------------------------------------------------------------------- /data/fivek/val/raw/a4970-DSC_0377.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/data/fivek/val/raw/a4970-DSC_0377.png -------------------------------------------------------------------------------- /data/fivek/val/raw/a4980-20090328_at_06h47m42__MG_1211.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/data/fivek/val/raw/a4980-20090328_at_06h47m42__MG_1211.png -------------------------------------------------------------------------------- /data/fivek/val/raw/a4990-jmac_MG_1139.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/data/fivek/val/raw/a4990-jmac_MG_1139.png -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | 3 | from pathlib import Path 4 | from itertools import chain 5 | import os 6 | 7 | from munch import Munch 8 | from PIL import Image 9 | 10 | import torch 11 | from torch.utils import data 12 | from torchvision import transforms 13 | 14 | 15 | def listdir(dname): 16 | fnames = list(chain(*[list(Path(dname).rglob('*.' 
+ ext)) 17 | for ext in ['png', 'jpg', 'jpeg', 'JPG']])) 18 | return fnames 19 | 20 | 21 | class DefaultDataset(data.Dataset): 22 | def __init__(self, root, transform=None): 23 | self.samples = listdir(root) 24 | self.samples.sort() 25 | self.transform = transform 26 | self.targets = None 27 | 28 | def __getitem__(self, index): 29 | fname = self.samples[index] 30 | img = Image.open(fname).convert('RGB') 31 | if self.transform is not None: 32 | img = self.transform(img) 33 | return img 34 | 35 | def __len__(self): 36 | return len(self.samples) 37 | 38 | 39 | class ReferenceDataset(data.Dataset): 40 | def __init__(self, root, transform=None): 41 | self.samples = self._make_dataset(root) 42 | self.transform = transform 43 | 44 | def _make_dataset(self, root): 45 | domains = os.listdir(root) 46 | fnames, fnames2 = [], [] 47 | 48 | for idx, domain in enumerate(sorted(domains)): 49 | class_dir = os.path.join(root, domain) 50 | cls_fnames = listdir(class_dir) 51 | if idx == 0: 52 | fnames += cls_fnames 53 | elif idx == 1: 54 | fnames2 += cls_fnames 55 | 56 | return list(zip(fnames, fnames2)) 57 | 58 | def __getitem__(self, index): 59 | fname, fname2 = self.samples[index] 60 | img_name = os.path.splitext(os.path.basename(fname2))[0] 61 | img = Image.open(fname).convert('RGB') 62 | img2 = Image.open(fname2).convert('RGB') 63 | if self.transform is not None: 64 | img = self.transform(img) 65 | img2 = self.transform(img2) 66 | return img, img2, img_name 67 | 68 | def __len__(self): 69 | return len(self.samples) 70 | 71 | 72 | def get_train_loader(root, img_size=512, resize_size=256, batch_size=8, shuffle=True, num_workers=8, drop_last=True): 73 | transform = transforms.Compose([ 74 | transforms.RandomCrop(img_size), 75 | transforms.Resize([resize_size, resize_size]), 76 | transforms.RandomHorizontalFlip(p=0.5), 77 | transforms.RandomVerticalFlip(p=0.5), 78 | transforms.ToTensor(), 79 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 80 | ]) 81 | dataset = ReferenceDataset(root, transform) 82 | return data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True, drop_last=drop_last) 83 | 84 | 85 | def get_test_loader(root, img_size=512, batch_size=8, shuffle=False, num_workers=0): 86 | transform = transforms.Compose([ 87 | transforms.Resize([img_size, img_size]), 88 | transforms.ToTensor(), 89 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 90 | ]) 91 | dataset = ReferenceDataset(root, transform) 92 | return data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=True) 93 | 94 | 95 | class InputFetcher: 96 | def __init__(self, loader): 97 | self.loader = loader 98 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 99 | 100 | def _fetch_refs(self): 101 | try: 102 | x, y, name = next(self.iter) 103 | except (AttributeError, StopIteration): 104 | self.iter = iter(self.loader) 105 | x, y, name = next(self.iter) 106 | return x, y, name 107 | 108 | def __next__(self): 109 | x, y, img_name = self._fetch_refs() 110 | x, y = x.to(self.device), y.to(self.device) 111 | inputs = Munch(img_exp=x, img_raw=y, img_name=img_name) 112 | return inputs 113 | 114 | -------------------------------------------------------------------------------- /figures/examples.PNG: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/figures/examples.PNG -------------------------------------------------------------------------------- /figures/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/figures/framework.png -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | # -*-coding:utf-8-*- 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from math import exp 6 | import torch.nn as nn 7 | import torchvision.models as models 8 | import os 9 | from math import pi 10 | 11 | 12 | class PerceptualLoss(nn.Module): 13 | def __init__(self): 14 | super(PerceptualLoss, self).__init__() 15 | self.add_module('vgg', VGG19_relu()) 16 | self.criterion = torch.nn.MSELoss() 17 | self.weights = [1.0/64, 1.0/64, 1.0/32, 1.0/32, 1.0/1] 18 | self.IN = nn.InstanceNorm2d(512, affine=False, track_running_stats=False) 19 | self.mean = torch.tensor([0.485, 0.456, 0.406]).view(1,-1,1,1) 20 | self.std = torch.tensor([0.229, 0.224, 0.225]).view(1,-1,1,1) 21 | 22 | def __call__(self, x, y): 23 | if x.shape[1] != 3: 24 | x = x.repeat(1, 3, 1, 1) 25 | y = y.repeat(1, 3, 1, 1) 26 | x = (x - self.mean.to(x)) / self.std.to(x) 27 | y = (y - self.mean.to(y)) / self.std.to(y) 28 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 29 | 30 | loss = self.weights[0] * self.criterion(self.IN(x_vgg['relu1_1']), self.IN(y_vgg['relu1_1'])) 31 | loss += self.weights[1] * self.criterion(self.IN(x_vgg['relu2_1']), self.IN(y_vgg['relu2_1'])) 32 | loss += self.weights[2] * self.criterion(self.IN(x_vgg['relu3_1']), self.IN(y_vgg['relu3_1'])) 33 | loss += self.weights[3] * self.criterion(self.IN(x_vgg['relu4_1']), self.IN(y_vgg['relu4_1'])) 34 | loss += self.weights[4] * self.criterion(self.IN(x_vgg['relu5_1']), self.IN(y_vgg['relu5_1'])) 35 | 36 | return loss 37 | 38 | 39 | class VGG19_relu(torch.nn.Module): 40 | def __init__(self): 41 | super(VGG19_relu, self).__init__() 42 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 43 | cnn = models.vgg19(pretrained=True) 44 | # cnn.load_state_dict(torch.load(os.path.join('./models/', 'vgg19-dcbb9e9d.pth'))) 45 | cnn = cnn.to(self.device) 46 | features = cnn.features 47 | self.relu1_1 = torch.nn.Sequential() 48 | self.relu1_2 = torch.nn.Sequential() 49 | 50 | self.relu2_1 = torch.nn.Sequential() 51 | self.relu2_2 = torch.nn.Sequential() 52 | 53 | self.relu3_1 = torch.nn.Sequential() 54 | self.relu3_2 = torch.nn.Sequential() 55 | self.relu3_3 = torch.nn.Sequential() 56 | self.relu3_4 = torch.nn.Sequential() 57 | 58 | self.relu4_1 = torch.nn.Sequential() 59 | self.relu4_2 = torch.nn.Sequential() 60 | self.relu4_3 = torch.nn.Sequential() 61 | self.relu4_4 = torch.nn.Sequential() 62 | 63 | self.relu5_1 = torch.nn.Sequential() 64 | self.relu5_2 = torch.nn.Sequential() 65 | self.relu5_3 = torch.nn.Sequential() 66 | self.relu5_4 = torch.nn.Sequential() 67 | 68 | for x in range(2): 69 | self.relu1_1.add_module(str(x), features[x]) 70 | 71 | for x in range(2, 4): 72 | self.relu1_2.add_module(str(x), features[x]) 73 | 74 | for x in range(4, 7): 75 | self.relu2_1.add_module(str(x), features[x]) 76 | 77 | for x in range(7, 9): 78 | self.relu2_2.add_module(str(x), features[x]) 79 | 80 
| for x in range(9, 12): 81 | self.relu3_1.add_module(str(x), features[x]) 82 | 83 | for x in range(12, 14): 84 | self.relu3_2.add_module(str(x), features[x]) 85 | 86 | for x in range(14, 16): 87 | self.relu3_3.add_module(str(x), features[x]) 88 | 89 | for x in range(16, 18): 90 | self.relu3_4.add_module(str(x), features[x]) 91 | 92 | for x in range(18, 21): 93 | self.relu4_1.add_module(str(x), features[x]) 94 | 95 | for x in range(21, 23): 96 | self.relu4_2.add_module(str(x), features[x]) 97 | 98 | for x in range(23, 25): 99 | self.relu4_3.add_module(str(x), features[x]) 100 | 101 | for x in range(25, 27): 102 | self.relu4_4.add_module(str(x), features[x]) 103 | 104 | for x in range(27, 30): 105 | self.relu5_1.add_module(str(x), features[x]) 106 | 107 | for x in range(30, 32): 108 | self.relu5_2.add_module(str(x), features[x]) 109 | 110 | for x in range(32, 34): 111 | self.relu5_3.add_module(str(x), features[x]) 112 | 113 | for x in range(34, 36): 114 | self.relu5_4.add_module(str(x), features[x]) 115 | 116 | # don't need the gradients, just want the features 117 | for param in self.parameters(): 118 | param.requires_grad = False 119 | 120 | def forward(self, x): 121 | relu1_1 = self.relu1_1(x) 122 | relu1_2 = self.relu1_2(relu1_1) 123 | 124 | relu2_1 = self.relu2_1(relu1_2) 125 | relu2_2 = self.relu2_2(relu2_1) 126 | 127 | relu3_1 = self.relu3_1(relu2_2) 128 | relu3_2 = self.relu3_2(relu3_1) 129 | relu3_3 = self.relu3_3(relu3_2) 130 | relu3_4 = self.relu3_4(relu3_3) 131 | 132 | relu4_1 = self.relu4_1(relu3_4) 133 | relu4_2 = self.relu4_2(relu4_1) 134 | relu4_3 = self.relu4_3(relu4_2) 135 | relu4_4 = self.relu4_4(relu4_3) 136 | 137 | relu5_1 = self.relu5_1(relu4_4) 138 | relu5_2 = self.relu5_2(relu5_1) 139 | relu5_3 = self.relu5_3(relu5_2) 140 | relu5_4 = self.relu5_4(relu5_3) 141 | 142 | out = { 143 | 'relu1_1': relu1_1, 144 | 'relu1_2': relu1_2, 145 | 146 | 'relu2_1': relu2_1, 147 | 'relu2_2': relu2_2, 148 | 149 | 'relu3_1': relu3_1, 150 | 'relu3_2': relu3_2, 151 | 'relu3_3': relu3_3, 152 | 'relu3_4': relu3_4, 153 | 154 | 'relu4_1': relu4_1, 155 | 'relu4_2': relu4_2, 156 | 'relu4_3': relu4_3, 157 | 'relu4_4': relu4_4, 158 | 159 | 'relu5_1': relu5_1, 160 | 'relu5_2': relu5_2, 161 | 'relu5_3': relu5_3, 162 | 'relu5_4': relu5_4, 163 | } 164 | return out 165 | 166 | 167 | class TVLoss(nn.Module): 168 | def __init__(self, tv_loss_weight=1): 169 | super(TVLoss, self).__init__() 170 | self.tv_loss_weight = tv_loss_weight 171 | 172 | def forward(self, x): 173 | batch_size = x.size()[0] 174 | h_x = x.size()[2] 175 | w_x = x.size()[3] 176 | count_h = self.tensor_size(x[:, :, 1:, :]) 177 | count_w = self.tensor_size(x[:, :, :, 1:]) 178 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum() 179 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum() 180 | return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size 181 | 182 | @staticmethod 183 | def tensor_size(t): 184 | return t.size()[1] * t.size()[2] * t.size()[3] 185 | 186 | 187 | class AngularLoss(torch.nn.Module): 188 | def __init__(self): 189 | super(AngularLoss, self).__init__() 190 | 191 | def forward(self, feature1, feature2): 192 | cos_criterion = torch.nn.CosineSimilarity(dim=1) 193 | cos = cos_criterion(feature1, feature2) 194 | clip_bound = 0.999999 195 | cos = torch.clamp(cos, -clip_bound, clip_bound) 196 | if False: 197 | return 1 - torch.mean(cos) 198 | else: 199 | return torch.mean(torch.acos(cos)) * 180 / pi 200 | 201 | 202 | class MultiscaleRecLoss(nn.Module): 203 | def 
__init__(self, scale=3, rec_loss_type='l1', multiscale=True): 204 | super(MultiscaleRecLoss, self).__init__() 205 | self.multiscale = multiscale 206 | if rec_loss_type == 'l1': 207 | self.criterion = nn.L1Loss() 208 | elif rec_loss_type == 'smoothl1': 209 | self.criterion = nn.SmoothL1Loss() 210 | elif rec_loss_type == 'l2': 211 | self.criterion = nn.MSELoss() 212 | else: 213 | raise NotImplementedError('Loss [{}] is not implemented'.format(rec_loss_type)) 214 | self.downsample = nn.AvgPool2d(2, stride=2, count_include_pad=False) 215 | if self.multiscale: 216 | self.weights = [1.0, 1.0/2, 1.0/4] 217 | self.weights = self.weights[:scale] 218 | 219 | def forward(self, input, target): 220 | loss = 0 221 | pred = input.clone() 222 | gt = target.clone() 223 | if self.multiscale: 224 | for i in range(len(self.weights)): 225 | loss += self.weights[i] * self.criterion(pred, gt) 226 | if i != len(self.weights) - 1: 227 | pred = self.downsample(pred) 228 | gt = self.downsample(gt) 229 | else: 230 | loss = self.criterion(pred, gt) 231 | return loss 232 | 233 | 234 | def hingeloss(x, y, mode='fake'): 235 | if mode == 'fake': 236 | return torch.mean(nn.ReLU()(x + y)) 237 | elif mode == 'real': 238 | return torch.mean(nn.ReLU()(x - y)) 239 | else: 240 | raise NotImplementedError("=== Mode [{}] is not implemented. ===".format(mode)) 241 | 242 | def diff(x, y, mode=True): 243 | if mode: 244 | return x - torch.mean(y) 245 | else: 246 | return torch.mean(x) - y 247 | 248 | def calc_l2(x, y, mode=False): 249 | if mode: 250 | return torch.mean((x - y) ** 2) 251 | else: 252 | return torch.mean((x + y) ** 2) 253 | 254 | 255 | class GANLoss(nn.Module): 256 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0, 257 | tensor=torch.FloatTensor, opt=None): 258 | super(GANLoss, self).__init__() 259 | self.real_label = target_real_label 260 | self.fake_label = target_fake_label 261 | self.real_label_tensor = None 262 | self.fake_label_tensor = None 263 | self.zero_tensor = None 264 | self.Tensor = tensor 265 | self.gan_mode = gan_mode 266 | self.opt = opt 267 | if gan_mode == 'ls': 268 | pass 269 | elif gan_mode == 'original': 270 | pass 271 | elif gan_mode == 'w': 272 | pass 273 | elif gan_mode == 'hinge': 274 | pass 275 | elif gan_mode == 'rahinge': 276 | pass 277 | elif gan_mode == 'rals': 278 | pass 279 | else: 280 | raise ValueError('Unexpected gan_mode {}'.format(gan_mode)) 281 | 282 | def get_target_tensor(self, input, target_is_real): 283 | if target_is_real: 284 | if self.real_label_tensor is None: 285 | self.real_label_tensor = self.Tensor(1).fill_(self.real_label) 286 | self.real_label_tensor.requires_grad_(False) 287 | return self.real_label_tensor.expand_as(input) 288 | else: 289 | if self.fake_label_tensor is None: 290 | self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label) 291 | self.fake_label_tensor.requires_grad_(False) 292 | return self.fake_label_tensor.expand_as(input) 293 | 294 | def get_zero_tensor(self, input): 295 | if self.zero_tensor is None: 296 | self.zero_tensor = self.Tensor(1).fill_(0) 297 | self.zero_tensor.requires_grad_(False) 298 | return self.zero_tensor.expand_as(input) 299 | 300 | def loss(self, real_preds, fake_preds, target_is_real, for_real=None, for_fake=None, for_discriminator=True): 301 | if self.gan_mode == 'original': # cross entropy loss 302 | if for_real: 303 | target_tensor = self.get_target_tensor(real_preds, target_is_real) 304 | loss = F.binary_cross_entropy_with_logits(real_preds, target_tensor) 305 | return loss 306 | elif for_fake: 
307 | target_tensor = self.get_target_tensor(fake_preds, target_is_real) 308 | loss = F.binary_cross_entropy_with_logits(fake_preds, target_tensor) 309 | return loss 310 | else: 311 | raise NotImplementedError("neither for real_preds nor for fake_preds") 312 | elif self.gan_mode == 'ls': 313 | if for_real: 314 | target_tensor = self.get_target_tensor(real_preds, target_is_real) 315 | return F.mse_loss(real_preds, target_tensor) 316 | elif for_fake: 317 | target_tensor = self.get_target_tensor(fake_preds, target_is_real) 318 | return F.mse_loss(fake_preds, target_tensor) 319 | else: 320 | raise NotImplementedError("neither for real_preds nor for fake_preds") 321 | elif self.gan_mode == 'hinge': 322 | if for_real: 323 | if for_discriminator: 324 | if target_is_real: 325 | minval = torch.min(real_preds - 1, self.get_zero_tensor(real_preds)) 326 | loss = -torch.mean(minval) 327 | else: 328 | minval = torch.min(-real_preds - 1, self.get_zero_tensor(real_preds)) 329 | loss = -torch.mean(minval) 330 | else: 331 | assert target_is_real, "The generator's hinge loss must be aiming for real" 332 | loss = -torch.mean(real_preds) 333 | return loss 334 | elif for_fake: 335 | if for_discriminator: 336 | if target_is_real: 337 | minval = torch.min(fake_preds - 1, self.get_zero_tensor(fake_preds)) 338 | loss = -torch.mean(minval) 339 | else: 340 | minval = torch.min(-fake_preds - 1, self.get_zero_tensor(fake_preds)) 341 | loss = -torch.mean(minval) 342 | else: 343 | assert target_is_real, "The generator's hinge loss must be aiming for real" 344 | loss = -torch.mean(fake_preds) 345 | return loss 346 | else: 347 | raise NotImplementedError("neither for real_preds nor for fake_preds") 348 | elif self.gan_mode == 'rahinge': 349 | if for_discriminator: 350 | ## difference between real and fake 351 | r_f_diff = real_preds - torch.mean(fake_preds) 352 | ## difference between fake and real 353 | f_r_diff = fake_preds - torch.mean(real_preds) 354 | loss = torch.mean(torch.nn.ReLU()(1 - r_f_diff)) + torch.mean(torch.nn.ReLU()(1 + f_r_diff)) 355 | return loss / 2 356 | else: 357 | ## difference between real and fake 358 | r_f_diff = real_preds - torch.mean(fake_preds) 359 | ## difference between fake and real 360 | f_r_diff = fake_preds - torch.mean(real_preds) 361 | loss = torch.mean(torch.nn.ReLU()(1 + r_f_diff)) + torch.mean(torch.nn.ReLU()(1 - f_r_diff)) 362 | return loss / 2 363 | elif self.gan_mode == 'rals': 364 | if for_discriminator: 365 | ## difference between real and fake 366 | r_f_diff = real_preds - torch.mean(fake_preds) 367 | ## difference between fake and real 368 | f_r_diff = fake_preds - torch.mean(real_preds) 369 | loss = torch.mean((r_f_diff - 1) ** 2) + torch.mean((f_r_diff + 1) ** 2) 370 | return loss / 2 371 | else: 372 | ## difference between real and fake 373 | r_f_diff = real_preds - torch.mean(fake_preds) 374 | ## difference between fake and real 375 | f_r_diff = fake_preds - torch.mean(real_preds) 376 | loss = torch.mean((r_f_diff + 1) ** 2) + torch.mean((f_r_diff - 1) ** 2) 377 | return loss / 2 378 | else: 379 | # wgan 380 | if for_real: 381 | if target_is_real: 382 | return -real_preds.mean() 383 | else: 384 | return real_preds.mean() 385 | elif for_fake: 386 | if target_is_real: 387 | return -fake_preds.mean() 388 | else: 389 | return fake_preds.mean() 390 | else: 391 | raise NotImplementedError("neither for real_preds nor for fake_preds") 392 | 393 | def __call__(self, real_preds, fake_preds, target_is_real, for_real=None, for_fake=None, for_discriminator=True): 394 | ## computing loss is a bit complicated because |input| may not be 395 | ## a tensor, but list of tensors in case of multiscale discriminator 396 | if isinstance(real_preds, list): 397 | loss = 0 398 | for (pred_real_i, pred_fake_i) in zip(real_preds, fake_preds): 399 | if isinstance(pred_real_i, list): 400 | pred_real_i = pred_real_i[-1] 401 | if isinstance(pred_fake_i, list): 402 | pred_fake_i = pred_fake_i[-1] 403 | 404 | loss_tensor = self.loss(pred_real_i, pred_fake_i, target_is_real, for_real, for_fake, for_discriminator) 405 | 406 | bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0) 407 | new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1) 408 | loss += new_loss 409 | return loss 410 | else: 411 | return self.loss(real_preds, fake_preds, target_is_real, for_real, for_fake, for_discriminator)  # single-scale case: pass fake_preds and the routing flags through to loss() 412 | 413 | 414 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | 3 | import os 4 | import argparse 5 | from trainer import Trainer 6 | from tester import Tester 7 | from utils import create_folder, setup_seed 8 | from config import get_config 9 | import torch 10 | from munch import Munch 11 | from data_loader import get_train_loader, get_test_loader 12 | 13 | 14 | def main(args): 15 | # for fast training. 16 | torch.backends.cudnn.benchmark = True 17 | 18 | setup_seed(args.seed) 19 | 20 | # create directories if not exist. 21 | create_folder(args.save_root_dir, args.version, args.model_save_path) 22 | create_folder(args.save_root_dir, args.version, args.sample_path) 23 | create_folder(args.save_root_dir, args.version, args.log_path) 24 | create_folder(args.save_root_dir, args.version, args.val_result_path) 25 | create_folder(args.save_root_dir, args.version, args.test_result_path) 26 | 27 | if args.mode == 'train': 28 | loaders = Munch(ref=get_train_loader(root=args.train_img_dir, 29 | img_size=args.image_size, 30 | resize_size=args.resize_size, 31 | batch_size=args.train_batch_size, 32 | shuffle=args.shuffle, 33 | num_workers=args.num_workers, 34 | drop_last=args.drop_last), 35 | val=get_test_loader(root=args.val_img_dir, 36 | batch_size=args.val_batch_size, 37 | shuffle=True, 38 | num_workers=args.num_workers)) 39 | trainer = Trainer(loaders, args) 40 | trainer.train() 41 | elif args.mode == 'test': 42 | loaders = Munch(tes=get_test_loader(root=args.test_img_dir, 43 | img_size=args.test_img_size, 44 | batch_size=args.val_batch_size, 45 | shuffle=True, 46 | num_workers=args.num_workers)) 47 | tester = Tester(loaders, args) 48 | tester.test() 49 | else: 50 | raise NotImplementedError('Mode [{}] is not found'.format(args.mode)) 51 | 52 | 53 | if __name__ == '__main__': 54 | 55 | args = get_config() 56 | 57 | # if args.is_print_network: 58 | # print(args) 59 | 60 | main(args) -------------------------------------------------------------------------------- /metrics/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/metrics/.DS_Store -------------------------------------------------------------------------------- /metrics/CalcPSNR.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import math 4 | import numpy as np 5 | # import skimage.metrics.peak_signal_noise_ratio as psnr_skimage 6 | import cv2 7 | import glob 8 | import datetime
9 | 10 | 11 | def calc_psnr(folder_Gen, folder_GT, result_save_path, epoch): 12 | 13 | if not os.path.exists(result_save_path): 14 | os.makedirs(result_save_path) 15 | 16 | PSNR, total_psnr, avg_psnr = 0.0, 0.0, 0.0 17 | epoch_result = result_save_path + 'PSNR_epoch_' + str(epoch) + '.csv' 18 | epochfile = open(epoch_result, 'w') 19 | epochfile.write('image_name' + ','+ 'psnr' + '\n') 20 | 21 | total_result = result_save_path + 'PSNR_total_results_epoch_avgpsnr.csv' 22 | totalfile = open(total_result, 'a+') 23 | 24 | crop_border = 4 25 | test_Y = False # True: test Y channel only; False: test RGB channels 26 | 27 | img_list = sorted(glob.glob(folder_Gen + '/*')) 28 | 29 | if test_Y: 30 | print('Testing Y channel.') 31 | else: 32 | print('Testing RGB channels.') 33 | 34 | starttime = datetime.datetime.now() 35 | 36 | for i, img_path in enumerate(img_list): 37 | im_Gen = cv2.imread(img_path) / 255. 38 | 39 | base_name = os.path.splitext(os.path.basename(img_path))[0] 40 | imgName, _, _ = base_name.rsplit('_', 2) 41 | # print(base_name) 42 | # print(imgName) 43 | GT_imgName = imgName + '.png' 44 | # print(GT_imgName) 45 | im_GT = cv2.imread(os.path.join(folder_GT, GT_imgName)) / 255. 46 | 47 | if test_Y and im_GT.shape[2] == 3: # evaluate on Y channel in YCbCr color space 48 | im_GT_in = bgr2ycbcr(im_GT) 49 | im_Gen_in = bgr2ycbcr(im_Gen) 50 | else: 51 | im_GT_in = im_GT 52 | im_Gen_in = im_Gen 53 | 54 | # crop borders 55 | if im_GT_in.ndim == 3: 56 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border, :] 57 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border, :] 58 | elif im_GT_in.ndim == 2: 59 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border] 60 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border] 61 | else: 62 | raise ValueError('Wrong image dimension: {}. 
Should be 2 or 3.'.format(im_GT_in.ndim)) 63 | 64 | # calculate PSNR and SSIM 65 | PSNR = calculate_psnr(cropped_GT * 255, cropped_Gen * 255) 66 | 67 | # print("=== PSNR is {:.4f} of {:>3d}-th img ===".format(PSNR, i)) 68 | 69 | epochfile.write(GT_imgName + ',' + str(round(PSNR, 6)) + '\n') 70 | 71 | total_psnr += PSNR 72 | if i % 50 == 0: 73 | print("=== PSNR is processing {:>3d}-th image ===".format(i)) 74 | 75 | endtime = datetime.datetime.now() 76 | print("======================= Complete the PSNR test of {:>3d} images, take {} seconds ======================= ".format(i+1, (endtime - starttime).seconds)) 77 | avg_psnr = total_psnr / (i + 1)  # i is the zero-based index of the last image, so divide by the image count 78 | epochfile.write('Average' + ',' + str(round(avg_psnr, 6)) + '\n') 79 | epochfile.close() 80 | totalfile.write(str(epoch) + ',' + str(round(avg_psnr, 6)) + '\n') 81 | totalfile.close() 82 | return avg_psnr 83 | 84 | 85 | def calculate_psnr(img1, img2, data_range=255): 86 | # img1 and img2 have range [0, 255] 87 | img1, img2 = img1.astype(np.float64), img2.astype(np.float64) 88 | mse = np.mean((img1 - img2)**2, dtype=np.float64) 89 | if mse == 0: 90 | return float('inf') 91 | # return 20 * math.log10(255.0 / math.sqrt(mse)) 92 | return 10 * np.log10((data_range ** 2)/ mse) 93 | 94 | 95 | def ssim(img1, img2): 96 | C1 = (0.01 * 255)**2 97 | C2 = (0.03 * 255)**2 98 | 99 | img1 = img1.astype(np.float64) 100 | img2 = img2.astype(np.float64) 101 | kernel = cv2.getGaussianKernel(11, 1.5) 102 | window = np.outer(kernel, kernel.transpose()) 103 | 104 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 105 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 106 | mu1_sq = mu1**2 107 | mu2_sq = mu2**2 108 | mu1_mu2 = mu1 * mu2 109 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 110 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 111 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 112 | 113 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 114 | (sigma1_sq + sigma2_sq + C2)) 115 | return ssim_map.mean() 116 | 117 | 118 | def calculate_ssim(img1, img2): 119 | '''calculate SSIM 120 | the same outputs as MATLAB's 121 | img1, img2: [0, 255] 122 | ''' 123 | if not img1.shape == img2.shape: 124 | raise ValueError('Input images must have the same dimensions.') 125 | if img1.ndim == 2: 126 | return ssim(img1, img2) 127 | elif img1.ndim == 3: 128 | if img1.shape[2] == 3: 129 | ssims = [] 130 | for i in range(3): 131 | ssims.append(ssim(img1[:, :, i], img2[:, :, i]))  # per-channel SSIM, averaged below 132 | return np.array(ssims).mean() 133 | elif img1.shape[2] == 1: 134 | return ssim(np.squeeze(img1), np.squeeze(img2)) 135 | else: 136 | raise ValueError('Wrong input image dimensions.') 137 | 138 | 139 | def bgr2ycbcr(img, only_y=True): 140 | '''same as matlab rgb2ycbcr 141 | only_y: only return Y channel 142 | Input: 143 | uint8, [0, 255] 144 | float, [0, 1] 145 | ''' 146 | in_img_type = img.dtype 147 | img = img.astype(np.float32)  # work on a float copy instead of discarding the conversion 148 | if in_img_type != np.uint8: 149 | img *= 255. 150 | # convert 151 | if only_y: 152 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 153 | else: 154 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], 155 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] 156 | if in_img_type == np.uint8: 157 | rlt = rlt.round() 158 | else: 159 | rlt /= 255.
160 | return rlt.astype(in_img_type) -------------------------------------------------------------------------------- /metrics/CalcSSIM.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import math 4 | import numpy as np 5 | from skimage.metrics import structural_similarity as ssim_skimage 6 | import cv2 7 | import glob 8 | import datetime 9 | 10 | 11 | def calc_ssim(folder_Gen, folder_GT, result_save_path, epoch): 12 | 13 | if not os.path.exists(result_save_path): 14 | os.makedirs(result_save_path) 15 | 16 | SSIM, total_ssim, avg_ssim = 0.0, 0.0, 0.0 17 | epoch_result = result_save_path + 'SSIM_epoch_' + str(epoch) + '.csv' 18 | epochfile = open(epoch_result, 'w') 19 | epochfile.write('image_name' + ','+ 'ssim' + '\n') 20 | 21 | total_result = result_save_path + 'SSIM_total_results_epoch_avgssim.csv' 22 | totalfile = open(total_result, 'a+') 23 | 24 | crop_border = 4 25 | test_Y = False # True: test Y channel only; False: test RGB channels 26 | 27 | img_list = sorted(glob.glob(folder_Gen + '/*')) 28 | 29 | if test_Y: 30 | print('Testing Y channel.') 31 | else: 32 | print('Testing RGB channels.') 33 | 34 | starttime = datetime.datetime.now() 35 | 36 | for i, img_path in enumerate(img_list): 37 | im_Gen = cv2.imread(img_path) / 255. 38 | 39 | base_name = os.path.splitext(os.path.basename(img_path))[0] 40 | imgName, _, _ = base_name.rsplit('_', 2) 41 | GT_imgName = imgName + '.png' 42 | im_GT = cv2.imread(os.path.join(folder_GT, GT_imgName)) / 255. 43 | 44 | if test_Y and im_GT.shape[2] == 3: # evaluate on Y channel in YCbCr color space 45 | im_GT_in = bgr2ycbcr(im_GT) 46 | im_Gen_in = bgr2ycbcr(im_Gen) 47 | else: 48 | im_GT_in = im_GT 49 | im_Gen_in = im_Gen 50 | 51 | # crop borders 52 | if im_GT_in.ndim == 3: 53 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border, :] 54 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border, :] 55 | elif im_GT_in.ndim == 2: 56 | cropped_GT = im_GT_in[crop_border:-crop_border, crop_border:-crop_border] 57 | cropped_Gen = im_Gen_in[crop_border:-crop_border, crop_border:-crop_border] 58 | else: 59 | raise ValueError('Wrong image dimension: {}. 
Should be 2 or 3.'.format(im_GT_in.ndim)) 60 | print(cropped_GT.shape, cropped_Gen.shape) 61 | # calculate PSNR and SSIM 62 | # SSIM = calculate_ssim(cropped_GT * 255, cropped_Gen * 255) 63 | SSIM = ssim_skimage(cropped_GT * 255, cropped_Gen * 255, data_range=255, channel_axis=2)  # channel_axis supersedes the deprecated multichannel flag 64 | 65 | # print("=== SSIM is {:.4f} of {:>3d}-th img ===".format(SSIM, i)) 66 | 67 | epochfile.write(GT_imgName + ',' + str(round(SSIM, 6)) + '\n') 68 | 69 | total_ssim += SSIM 70 | if i % 50 == 0: 71 | print("=== SSIM is processing {:>3d}-th image ===".format(i)) 72 | 73 | endtime = datetime.datetime.now() 74 | print("======================= Complete the SSIM test of {:>3d} images, take {} seconds ======================= ".format(i+1, (endtime - starttime).seconds)) 75 | avg_ssim = total_ssim / (i + 1)  # i is the zero-based index of the last image, so divide by the image count 76 | epochfile.write('Average' + ',' + str(round(avg_ssim, 6)) + '\n') 77 | epochfile.close() 78 | totalfile.write(str(epoch) + ',' + str(round(avg_ssim, 6)) + '\n') 79 | totalfile.close() 80 | return avg_ssim 81 | 82 | 83 | def calculate_psnr(img1, img2): 84 | # img1 and img2 have range [0, 255] 85 | img1 = img1.astype(np.float64) 86 | img2 = img2.astype(np.float64) 87 | mse = np.mean((img1 - img2)**2) 88 | if mse == 0: 89 | return float('inf') 90 | return 20 * math.log10(255.0 / math.sqrt(mse)) 91 | 92 | 93 | def ssim(img1, img2): 94 | C1 = (0.01 * 255)**2 95 | C2 = (0.03 * 255)**2 96 | 97 | img1 = img1.astype(np.float64) 98 | img2 = img2.astype(np.float64) 99 | kernel = cv2.getGaussianKernel(11, 1.5) 100 | window = np.outer(kernel, kernel.transpose()) 101 | 102 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 103 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 104 | mu1_sq = mu1**2 105 | mu2_sq = mu2**2 106 | mu1_mu2 = mu1 * mu2 107 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 108 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 109 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 110 | 111 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 112 | (sigma1_sq + sigma2_sq + C2)) 113 | return ssim_map.mean() 114 | 115 | 116 | def calculate_ssim(img1, img2): 117 | '''calculate SSIM 118 | the same outputs as MATLAB's 119 | img1, img2: [0, 255] 120 | ''' 121 | if not img1.shape == img2.shape: 122 | raise ValueError('Input images must have the same dimensions.') 123 | if img1.ndim == 2: 124 | return ssim(img1, img2) 125 | elif img1.ndim == 3: 126 | if img1.shape[2] == 3: 127 | ssims = [] 128 | for i in range(3): 129 | ssims.append(ssim(img1[:, :, i], img2[:, :, i]))  # per-channel SSIM, averaged below 130 | return np.array(ssims).mean() 131 | elif img1.shape[2] == 1: 132 | return ssim(np.squeeze(img1), np.squeeze(img2)) 133 | else: 134 | raise ValueError('Wrong input image dimensions.') 135 | 136 | 137 | def bgr2ycbcr(img, only_y=True): 138 | '''same as matlab rgb2ycbcr 139 | only_y: only return Y channel 140 | Input: 141 | uint8, [0, 255] 142 | float, [0, 1] 143 | ''' 144 | in_img_type = img.dtype 145 | img = img.astype(np.float32)  # work on a float copy instead of discarding the conversion 146 | if in_img_type != np.uint8: 147 | img *= 255. 148 | # convert 149 | if only_y: 150 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 151 | else: 152 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], 153 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] 154 | if in_img_type == np.uint8: 155 | rlt = rlt.round() 156 | else: 157 | rlt /= 255.
--------------------------------------------------------------------------------
/metrics/CenterCrop.m:
--------------------------------------------------------------------------------
1 | close all
2 | clear
3 | clc
4 | 
5 | rootdir='./gt';
6 | saveddir='./gt_crop';
7 | subdir=dir(rootdir);
8 | 
9 | qmkdir(saveddir);
10 | for i=1:length(subdir)
11 |     subdirpath=fullfile(rootdir,subdir(i).name,'*.png');
12 |     images=dir(subdirpath);
13 |     for j=1:length(images)
14 |         ImageName=fullfile(rootdir,subdir(i).name,images(j).name);
15 |         [ImageData, map] = imread(ImageName);
16 | 
17 |         sz = size(ImageData);
18 |         x = floor(sz(2)/2);
19 |         y = floor(sz(1)/2);
20 | 
21 |         %%% crop to a multiple of 16
22 |         high = floor(sz(1)/16) * 16;
23 |         width = floor(sz(2)/16) * 16;
24 |         %%% crop to uniform size (512)
25 |         % high = 512;
26 |         % width = 512;
27 | 
28 |         patch = ImageData(y-(floor(high/2))+1:y+(floor(high/2)), x-(floor(width/2))+1:x+(floor(width/2)), :);
29 | 
30 |         savedname=fullfile(saveddir, strcat(images(j).name(1:end-4), '.png'));
31 |         imwrite(patch, savedname, 'Mode', 'lossless')
32 |         fprintf('Image No. = %d\n', j);
33 |     end
34 | end
35 | 
36 | function dir = qmkdir(dir)
37 |     [success, message] = mkdir(dir);
38 | end
39 | 
40 | 
--------------------------------------------------------------------------------
/metrics/NIMA/CalcNIMA.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import torchvision.models
4 | import torchvision.transforms as transforms
5 | from PIL import Image
6 | import os
7 | import torch.nn as nn
8 | from .mobile_net_v2 import mobile_net_v2
9 | import numpy as np
10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
11 | 
12 | # def get_mean_score(score):
13 | #     buckets = np.arange(1, 11)
14 | #     mu = (buckets * score).sum()
15 | #     return mu
16 | 
17 | # def get_std_score(scores):
18 | #     si = np.arange(1, 11)
19 | #     mean = get_mean_score(scores)
20 | #     std = np.sqrt(np.sum(((si - mean) ** 2) * scores))
21 | #     return std
22 | 
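The commented-out `get_mean_score` / `get_std_score` helpers above (they also appear in `nima/common.py` later in this repository) are what `calc_nima` below re-implements with explicit accumulation loops: the NIMA head outputs a probability distribution over the scores 1..10, and the quality estimate is that distribution's mean and standard deviation. A vectorized sketch of the same computation:

```python
import numpy as np

def nima_mean_std(preds):
    """preds: softmax output over the score buckets 1..10, shape (10,)."""
    buckets = np.arange(1, 11)
    mean = float((buckets * preds).sum())
    std = float(np.sqrt((((buckets - mean) ** 2) * preds).sum()))
    return mean, std
```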
23 | class NIMA(nn.Module):
24 |     def __init__(self, pretrained_base_model=False):
25 |         super(NIMA, self).__init__()
26 |         base_model = mobile_net_v2(pretrained=pretrained_base_model)
27 |         base_model = nn.Sequential(*list(base_model.children())[:-1])
28 | 
29 |         self.base_model = base_model
30 | 
31 |         self.head = nn.Sequential(
32 |             nn.ReLU(inplace=True),
33 |             nn.Dropout(p=0.75),
34 |             nn.Linear(1280, 10),
35 |             nn.Softmax(dim=1)
36 |         )
37 | 
38 |     def forward(self, x):
39 |         x = self.base_model(x)
40 |         x = x.view(x.size(0), -1)
41 |         x = self.head(x)
42 |         return x
43 | 
44 | 
45 | def prepare_image(image):
46 |     if image.mode != 'RGB':
47 |         image = image.convert("RGB")
48 |     Transform = transforms.Compose([
49 |         transforms.Resize(256),
50 |         transforms.CenterCrop(224),
51 |         transforms.ToTensor(),
52 |     ])
53 |     image = Transform(image)
54 |     image = image[None]
55 |     return image
56 | 
57 | 
58 | def calc_nima(img_path, result_save_path, epoch):
59 | 
60 |     # result_save_path = './results/nima_val_results/'  # do not override the caller-supplied path
61 |     if not os.path.exists(result_save_path):
62 |         os.makedirs(result_save_path)
63 | 
64 |     model = NIMA()
65 |     # print(model)
66 |     model.load_state_dict(torch.load('./metrics/NIMA/pretrain-model.pth'))
67 |     print('======================= start to calculate NIMA =======================')
68 |     model.to(device).eval()
69 | 
70 |     mean, deviation, total_mean, total_std = 0.0, 0.0, 0.0, 0.0
71 |     epoch_result = result_save_path + 'NIMA_epoch_' + str(epoch) + '_mean_std.csv'
72 |     epochfile = open(epoch_result, 'w')
73 |     epochfile.write('image_name' + ',' + 'mean' + ',' + 'std' + '\n')
74 | 
75 |     total_result = result_save_path + 'NIMA_total_results_' + 'epoch' + '_mean_std.csv'
76 |     totalfile = open(total_result, 'a+')
77 |     # totalfile.write('epoch' + ',' + 'mean' + ',' + 'std' + '\n')
78 | 
79 |     test_imgs = [f for f in os.listdir(img_path)]
80 |     for i, img in enumerate(test_imgs):
81 |         image = Image.open(os.path.join(img_path, img))
82 |         image = prepare_image(image).to(device)
83 |         with torch.no_grad():
84 |             preds = model(image).data.cpu().numpy()[0]
85 | 
86 |         for j, e in enumerate(preds, 1):
87 |             mean += j * e
88 | 
89 |         for k, e in enumerate(preds, 1):
90 |             deviation += (e * (k - mean) ** 2)
91 |         std = deviation ** (0.5)
92 |         epochfile.write(img + ',' + str(round(mean, 6)) + ',' + str(round(std, 6)) + '\n')
93 |         total_mean += mean
94 |         total_std += std
95 |         mean, deviation = 0.0, 0.0
96 |         if i % 50 == 0:
97 |             print("=== NIMA is processing {:>3d}-th image ===".format(i))
98 |     print("======================= Complete the NIMA test of {:>3d} images ======================= ".format(i+1))
99 |     total_mean = total_mean / (i + 1)  # i is the last zero-based index, so average over i + 1 images
100 |     total_std = total_std / (i + 1)
101 |     epochfile.write('Average' + ',' + str(round(total_mean, 6)) + ',' + str(round(total_std, 6)) + '\n')
102 |     epochfile.close()
103 |     totalfile.write(str(epoch) + ',' + str(round(total_mean, 6)) + ',' + str(round(total_std, 6)) + '\n')
104 |     totalfile.close()
105 |     return total_mean
106 | 
107 | 
108 | 
--------------------------------------------------------------------------------
/metrics/NIMA/__pycache__/CalcNIMA.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/metrics/NIMA/__pycache__/CalcNIMA.cpython-38.pyc
--------------------------------------------------------------------------------
/metrics/NIMA/__pycache__/mobile_net_v2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/metrics/NIMA/__pycache__/mobile_net_v2.cpython-38.pyc
--------------------------------------------------------------------------------
/metrics/NIMA/mobile_net_v2.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | 
4 | import torch
5 | import torch.nn as nn
6 | 
7 | # from common import download_file
8 | 
9 | MOBILE_NET_V2_UTR = 'https://s3-us-west-1.amazonaws.com/models-nima/mobilenetv2.pth.tar'
10 | 
11 | 
12 | def conv_bn(inp, oup, stride):
13 |     return nn.Sequential(
14 |         nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
15 |         nn.BatchNorm2d(oup),
16 |         nn.ReLU(inplace=True)
17 |     )
18 | 
19 | 
20 | def conv_1x1_bn(inp, oup):
21 |     return nn.Sequential(
22 |         nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
23 |         nn.BatchNorm2d(oup),
24 |         nn.ReLU(inplace=True)
25 |     )
26 | 
27 | 
28 | class InvertedResidual(nn.Module):
29 |     def __init__(self, inp, oup, stride, expand_ratio):
30 |         super(InvertedResidual, self).__init__()
31 |         self.stride = stride
32 |         assert stride in [1, 2]
33 | 
34 |         self.use_res_connect = self.stride == 1 and inp == oup
35 | 
36 |         self.conv = nn.Sequential(
37 |             # pw
38 |             nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False),
39 |             nn.BatchNorm2d(inp * expand_ratio),
40 |             nn.ReLU6(inplace=True),
41 |             # dw
42
| nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3, stride, 1, groups=inp * expand_ratio, bias=False), 43 | nn.BatchNorm2d(inp * expand_ratio), 44 | nn.ReLU6(inplace=True), 45 | # pw-linear 46 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), 47 | nn.BatchNorm2d(oup), 48 | ) 49 | 50 | def forward(self, x): 51 | if self.use_res_connect: 52 | return x + self.conv(x) 53 | else: 54 | return self.conv(x) 55 | 56 | 57 | class MobileNetV2(nn.Module): 58 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 59 | super(MobileNetV2, self).__init__() 60 | # setting of inverted residual blocks 61 | self.interverted_residual_setting = [ 62 | # t, c, n, s 63 | [1, 16, 1, 1], 64 | [6, 24, 2, 2], 65 | [6, 32, 3, 2], 66 | [6, 64, 4, 2], 67 | [6, 96, 3, 1], 68 | [6, 160, 3, 2], 69 | [6, 320, 1, 1], 70 | ] 71 | 72 | # building first layer 73 | assert input_size % 32 == 0 74 | input_channel = int(32 * width_mult) 75 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280 76 | self.features = [conv_bn(3, input_channel, 2)] 77 | # building inverted residual blocks 78 | for t, c, n, s in self.interverted_residual_setting: 79 | output_channel = int(c * width_mult) 80 | for i in range(n): 81 | if i == 0: 82 | self.features.append(InvertedResidual(input_channel, output_channel, s, t)) 83 | else: 84 | self.features.append(InvertedResidual(input_channel, output_channel, 1, t)) 85 | input_channel = output_channel 86 | # building last several layers 87 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 88 | self.features.append(nn.AvgPool2d(input_size // 32)) 89 | # make it nn.Sequential 90 | self.features = nn.Sequential(*self.features) 91 | 92 | # building classifier 93 | self.classifier = nn.Sequential( 94 | nn.Dropout(), 95 | nn.Linear(self.last_channel, n_class), 96 | ) 97 | 98 | self._initialize_weights() 99 | 100 | def forward(self, x): 101 | x = self.features(x) 102 | x = x.view(-1, self.last_channel) 103 | x = self.classifier(x) 104 | return x 105 | 106 | def _initialize_weights(self): 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 110 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 111 | if m.bias is not None: 112 | m.bias.data.zero_() 113 | elif isinstance(m, nn.BatchNorm2d): 114 | m.weight.data.fill_(1) 115 | m.bias.data.zero_() 116 | elif isinstance(m, nn.Linear): 117 | n = m.weight.size(1) 118 | m.weight.data.normal_(0, 0.01) 119 | m.bias.data.zero_() 120 | 121 | 122 | def mobile_net_v2(pretrained=True): 123 | model = MobileNetV2() 124 | # if pretrained: 125 | # path_to_model = '/tmp/mobilenetv2.pth.tar' 126 | # if not os.path.exists(path_to_model): 127 | # path_to_model = download_file(MOBILE_NET_V2_UTR, path_to_model) 128 | # state_dict = torch.load(path_to_model, map_location=lambda storage, loc: storage) 129 | # model.load_state_dict(state_dict) 130 | return model 131 | -------------------------------------------------------------------------------- /metrics/NIMA/nima/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | .idea/ 103 | DATA/ 104 | *.pth 105 | test.jpg 106 | custom_setup.md -------------------------------------------------------------------------------- /metrics/NIMA/nima/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM floydhub/dl-base:3.0.0-gpu-py3.22 2 | 3 | ENV LC_ALL C.UTF-8 4 | ENV LANG C.UTF-8 5 | 6 | RUN apt-get update 7 | RUN apt-get install -y python3-tk zlib1g-dev libjpeg-dev 8 | 9 | 10 | ENV APP_DIR /app 11 | WORKDIR $APP_DIR 12 | 13 | # if CPU SSE4-capable add pillow-simd with AVX2-enabled version 14 | RUN pip uninstall -y pillow 15 | RUN CC="cc -mavx2" pip install -U --force-reinstall pillow-simd 16 | 17 | 18 | COPY requirements.txt requirements.txt 19 | RUN pip install -r requirements.txt 20 | COPY . 
$APP_DIR
21 | 
--------------------------------------------------------------------------------
/metrics/NIMA/nima/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 Kyryl Truskovskyi
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/metrics/NIMA/nima/Procfile:
--------------------------------------------------------------------------------
1 | web: gunicorn -b 0.0.0.0:$PORT nima.inference.app:app
2 | 
--------------------------------------------------------------------------------
/metrics/NIMA/nima/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch NIMA: Neural IMage Assessment
2 | 
3 | PyTorch implementation of [Neural IMage Assessment](https://arxiv.org/abs/1709.05424) by Hossein Talebi and Peyman Milanfar. You can learn more from [this post at the Google Research Blog](https://research.googleblog.com/2017/12/introducing-nima-neural-image-assessment.html).
4 | 
5 | 
6 | ## Installing
7 | 
8 | ```bash
9 | git clone https://github.com/truskovskiyk/nima.pytorch.git
10 | cd nima.pytorch
11 | virtualenv -p python3.6 env
12 | source ./env/bin/activate
13 | pip install -r requirements/linux_gpu.txt
14 | ```
15 | 
16 | Alternatively, you can use the ready-made [Dockerfile](./Dockerfile).
17 | 
18 | 
19 | ## Dataset
20 | 
21 | The model was trained on the [AVA (Aesthetic Visual Analysis) dataset](http://refbase.cvc.uab.es/files/MMP2012a.pdf).
22 | You can get it from [here](https://github.com/mtobeiyf/ava_downloader).
23 | Here are some examples of images with their scores:
24 | ![result1](https://3.bp.blogspot.com/-_BuiLfAsHGE/WjgoftooRiI/AAAAAAAACR0/mB3tOfinfgA5Z7moldaLIGn92ounSOb8ACLcBGAs/s1600/image2.png)
25 | 
26 | ## Model
27 | 
28 | Uses the MobileNetV2 architecture as described in [Inverted Residuals and Linear Bottlenecks: Mobile Networks for Classification, Detection and Segmentation](https://arxiv.org/pdf/1801.04381).
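For orientation, scoring a single image with the pre-trained weights boils down to a few lines. A sketch (not part of this README), reusing the `NIMA` and `prepare_image` definitions from `metrics/NIMA/test.py` in this repository; `test_image.jpg` is a placeholder path:

```python
import torch
from PIL import Image

model = NIMA()  # MobileNetV2 backbone + 10-way softmax head
model.load_state_dict(torch.load('pretrain-model.pth', map_location='cpu'))
model.eval()

image = prepare_image(Image.open('test_image.jpg'))  # resize to 256, center-crop 224, to tensor
with torch.no_grad():
    preds = model(image)[0]  # probability distribution over scores 1..10
mean_score = sum(j * p.item() for j, p in enumerate(preds, 1))
print('aesthetic mean score:', round(mean_score, 3))
```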
29 | 
30 | ## Pre-trained model
31 | 
32 | You can use this [pre-trained model](https://s3-us-west-1.amazonaws.com/models-nima/pretrain-model.pth), which achieves
33 | ```bash
34 | val_emd_loss = 0.079
35 | test_emd_loss = 0.080
36 | ```
37 | ## Deployment
38 | 
39 | The model is deployed on [Heroku](https://www.heroku.com/) at https://neural-image-assessment.herokuapp.com/. You can use it to test your own images, but note that it is a free service, so it cannot handle too many requests. Here is a simple curl command to test the deployed model:
40 | ```bash
41 | curl -X POST -F "file=@123.jpg" https://neural-image-assessment.herokuapp.com/api/get_scores
42 | ```
43 | Please use our [swagger](https://neural-image-assessment.herokuapp.com/apidocs) for interactive testing.
44 | 
45 | 
46 | ## Usage
47 | ```bash
48 | export PYTHONPATH=.
49 | export PATH_TO_AVA_TXT=/storage/DATA/ava/AVA.txt
50 | export PATH_TO_IMAGES=/storage/DATA/images/
51 | export PATH_TO_CSV=/storage/DATA/ava/
52 | export BATCH_SIZE=16
53 | export NUM_WORKERS=2
54 | export NUM_EPOCH=50
55 | export INIT_LR=0.0001
56 | export EXPERIMENT_DIR_NAME=/storage/experiment_n0001
57 | ```
58 | Clean and prepare the dataset:
59 | ```bash
60 | python nima/cli.py prepare_dataset --path_to_ava_txt $PATH_TO_AVA_TXT \
61 |                                    --path_to_save_csv $PATH_TO_CSV \
62 |                                    --path_to_images $PATH_TO_IMAGES
63 | 
64 | ```
65 | 
66 | Train the model:
67 | ```bash
68 | python nima/cli.py train_model --path_to_save_csv $PATH_TO_CSV \
69 |                                --path_to_images $PATH_TO_IMAGES \
70 |                                --batch_size $BATCH_SIZE \
71 |                                --num_workers $NUM_WORKERS \
72 |                                --num_epoch $NUM_EPOCH \
73 |                                --init_lr $INIT_LR \
74 |                                --experiment_dir_name $EXPERIMENT_DIR_NAME
75 | 
76 | 
77 | ```
78 | Use TensorBoard to track training progress:
79 | 
80 | ```bash
81 | tensorboard --logdir .
82 | ```
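The `emd_loss` figures quoted under "Pre-trained model" above are the r=2 earth mover's distance between the predicted and ground-truth score distributions. A sketch of the computation, mirroring `nima/train/emd_loss.py` further down in this repository (there, as here, the batch mean is taken before the square root):

```python
import torch

def emd_loss(p_target: torch.Tensor, p_estimate: torch.Tensor) -> torch.Tensor:
    """Squared EMD between score distributions of shape (batch, 10)."""
    # CDFs over the ordered score buckets 1..10
    cdf_diff = torch.cumsum(p_estimate, dim=1) - torch.cumsum(p_target, dim=1)
    return torch.sqrt((cdf_diff ** 2).mean())
```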
83 | Validate the model on the val and test datasets:
84 | ```bash
85 | python nima/cli.py validate_model --path_to_model_weight ./pretrain-model.pth \
86 |                                   --path_to_save_csv $PATH_TO_CSV \
87 |                                   --path_to_images $PATH_TO_IMAGES \
88 |                                   --batch_size $BATCH_SIZE \
89 |                                   --num_workers $NUM_WORKERS
90 | ```
91 | Get scores for one image:
92 | ```bash
93 | python nima/cli.py get_image_score --path_to_model_weight ./pretrain-model.pth --path_to_image test_image.jpg
94 | ```
95 | 
96 | ## Contributing
97 | 
98 | Contributions are welcome.
99 | 
100 | 
101 | ## License
102 | 
103 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
104 | 
105 | ## Acknowledgments
106 | 
107 | * [neural-image-assessment in keras](https://github.com/titu1994/neural-image-assessment)
108 | * [Neural-IMage-Assessment in pytorch](https://github.com/kentsyx/Neural-IMage-Assessment)
109 | * [pytorch-mobilenet-v2](https://github.com/tonylins/pytorch-mobilenet-v2)
110 | * [original NIMA paper](https://arxiv.org/abs/1709.05424)
111 | * [original MobileNetV2 paper](https://arxiv.org/pdf/1801.04381)
112 | * [Post at the Google Research Blog](https://research.googleblog.com/2017/12/introducing-nima-neural-image-assessment.html)
113 | * [Heroku: Cloud Application Platform](https://www.heroku.com/)
114 | 
--------------------------------------------------------------------------------
/metrics/NIMA/nima/nima/cli.py:
--------------------------------------------------------------------------------
1 | import click
2 | 
3 | from nima.train.clean_dataset import clean_and_split
4 | from nima.train.utils import TrainParams, ValidateParams
5 | from nima.train.main import start_train, start_check_model
6 | from nima.inference.inference_model import InferenceModel
7 | 
8 | 
9 | @click.group()
10 | def cli():
11 |     pass
12 | 
13 | 
14 | @click.command()
15 | @click.option('--path_to_ava_txt', help='original AVA.txt file', required=True)
16 | @click.option('--path_to_save_csv', help='where to save train.csv|val.csv|test.csv', required=True)
17 | @click.option('--path_to_images', help='images directory', required=True)
18 | def prepare_dataset(path_to_ava_txt, path_to_save_csv, path_to_images):
19 |     click.echo('Clean and split dataset to train|val|test')
20 |     clean_and_split(path_to_ava_txt=path_to_ava_txt, path_to_save_csv=path_to_save_csv, path_to_images=path_to_images)
21 |     click.echo('Done')
22 | 
23 | 
24 | @click.command()
25 | @click.option('--path_to_save_csv', help='where to save train.csv|val.csv|test.csv', required=True)
26 | @click.option('--path_to_images', help='images directory', required=True)
27 | @click.option('--experiment_dir_name', help='unique experiment name and directory to save all logs and weights',
28 |               required=True)
29 | @click.option('--batch_size', help='batch size', required=True, type=int)
30 | @click.option('--num_workers', help='number of reading workers', required=True, type=int)
31 | @click.option('--num_epoch', help='number of epochs', required=True, type=int)
32 | @click.option('--init_lr', help='initial learning rate', required=True, type=float)
33 | def train_model(path_to_save_csv, path_to_images, experiment_dir_name, batch_size, num_workers, num_epoch, init_lr):
34 |     click.echo('Train and validate the model; save logs to tensorboard and params to params.json')
35 |     params = TrainParams(path_to_save_csv=path_to_save_csv, path_to_images=path_to_images,
36 |                          experiment_dir_name=experiment_dir_name, batch_size=batch_size, num_workers=num_workers,
37 |                          num_epoch=num_epoch, init_lr=init_lr)
38 | 
start_train(params) 39 | 40 | 41 | @click.command() 42 | @click.option('--path_to_model_weight', help='path to model weight .pth file', required=True) 43 | @click.option('--path_to_image', help='image ', required=True) 44 | def get_image_score(path_to_model_weight, path_to_image): 45 | model = InferenceModel(path_to_model=path_to_model_weight) 46 | result = model.predict_from_file(path_to_image) 47 | click.echo(result) 48 | 49 | 50 | @click.command() 51 | @click.option('--path_to_model_weight', help='path to model weight .pth file', required=True) 52 | @click.option('--path_to_save_csv', help='where save train.csv|val.csv|test.csv', required=True) 53 | @click.option('--path_to_images', help='images directory', required=True) 54 | @click.option('--batch_size', help='batch size', required=True, type=int) 55 | @click.option('--num_workers', help='number of reading workers', required=True, type=int) 56 | def validate_model(path_to_model_weight, path_to_save_csv, path_to_images, batch_size, num_workers): 57 | params = ValidateParams(path_to_save_csv=path_to_save_csv, path_to_model_weight=path_to_model_weight, 58 | path_to_images=path_to_images, num_workers=num_workers, batch_size=batch_size) 59 | 60 | val_loss, test_loss = start_check_model(params) 61 | click.echo(f"val_loss = {val_loss}; test_loss = {test_loss}") 62 | 63 | 64 | cli.add_command(prepare_dataset) 65 | cli.add_command(train_model) 66 | cli.add_command(validate_model) 67 | cli.add_command(get_image_score) 68 | 69 | 70 | if __name__ == '__main__': 71 | cli() 72 | -------------------------------------------------------------------------------- /metrics/NIMA/nima/nima/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import numpy as np 4 | from torchvision import transforms 5 | 6 | 7 | IMAGE_NET_MEAN = [0.485, 0.456, 0.406] 8 | IMAGE_NET_STD = [0.229, 0.224, 0.225] 9 | 10 | 11 | class Transform: 12 | def __init__(self): 13 | normalize = transforms.Normalize( 14 | mean=IMAGE_NET_MEAN, 15 | std=IMAGE_NET_STD) 16 | 17 | self._train_transform = transforms.Compose([ 18 | transforms.Resize((256, 256)), 19 | transforms.RandomHorizontalFlip(), 20 | transforms.RandomCrop((224, 224)), 21 | transforms.ToTensor(), 22 | normalize]) 23 | 24 | self._val_transform = transforms.Compose([ 25 | transforms.Resize((224, 224)), 26 | transforms.ToTensor(), 27 | normalize]) 28 | 29 | @property 30 | def train_transform(self): 31 | return self._train_transform 32 | 33 | @property 34 | def val_transform(self): 35 | return self._val_transform 36 | 37 | 38 | def get_mean_score(score): 39 | buckets = np.arange(1, 11) 40 | mu = (buckets * score).sum() 41 | return mu 42 | 43 | 44 | def get_std_score(scores): 45 | si = np.arange(1, 11) 46 | mean = get_mean_score(scores) 47 | std = np.sqrt(np.sum(((si - mean) ** 2) * scores)) 48 | return std 49 | 50 | 51 | def download_file(url, local_filename, chunk_size=1024): 52 | if os.path.exists(local_filename): 53 | return local_filename 54 | r = requests.get(url, stream=True) 55 | with open(local_filename, 'wb') as f: 56 | for chunk in r.iter_content(chunk_size=chunk_size): 57 | if chunk: 58 | f.write(chunk) 59 | return local_filename 60 | -------------------------------------------------------------------------------- /metrics/NIMA/nima/nima/inference/app.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, redirect, url_for, request, jsonify 2 | from PIL import Image 3 | 4 | from 
flasgger import Swagger 5 | 6 | from nima.inference.inference_model import InferenceModel 7 | 8 | app = Flask(__name__) 9 | Swagger(app=app) 10 | app.model = InferenceModel.create_model() 11 | 12 | 13 | @app.route('/') 14 | def index(): 15 | return redirect(url_for('health_check')) 16 | 17 | 18 | @app.route('/api/health_check') 19 | def health_check(): 20 | return "ok" 21 | 22 | 23 | @app.route('/api/get_scores', methods=['POST']) 24 | def get_scores(): 25 | """ 26 | NIMA Pytorch 27 | 28 | --- 29 | tags: 30 | - Get Scores 31 | consumes: 32 | - multipart/form-data 33 | parameters: 34 | - in: formData 35 | type: file 36 | name: file 37 | required: true 38 | description: Upload your file. 39 | responses: 40 | 200: 41 | description: Scores for image 42 | schema: 43 | id: Palette 44 | type: object 45 | properties: 46 | mean_score: 47 | type: float 48 | std_score: 49 | type: float 50 | scores: 51 | type: array 52 | items: 53 | type: float 54 | examples: 55 | { 56 | "mean_score": 5.385255615692586, 57 | "scores": [ 58 | 0.0049467734061181545, 59 | 0.018246186897158623, 60 | 0.05434520170092583, 61 | 0.16275958716869354, 62 | 0.3268744945526123, 63 | 0.24433879554271698, 64 | 0.11257114261388779, 65 | 0.05015537887811661, 66 | 0.017528045922517776, 67 | 0.00823438260704279 68 | ], 69 | "std_score": 1.451693009595486 70 | } 71 | """ 72 | 73 | img = Image.open(request.files['file']) 74 | result = app.model.predict_from_pil_image(img) 75 | return jsonify(result) 76 | 77 | 78 | if __name__ == '__main__': 79 | app.run(host='0.0.0.0', port=5000, debug=True) 80 | -------------------------------------------------------------------------------- /metrics/NIMA/nima/nima/inference/inference_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.datasets.folder import default_loader 3 | 4 | from decouple import config 5 | 6 | from nima.model import NIMA 7 | from nima.common import Transform, get_mean_score, get_std_score 8 | from nima.common import download_file 9 | from nima.inference.utils import format_output 10 | 11 | use_cuda = torch.cuda.is_available() 12 | device = torch.device("cuda" if use_cuda else "cpu") 13 | 14 | 15 | class InferenceModel: 16 | @classmethod 17 | def create_model(cls): 18 | path_to_model = download_file(config('MODEL_URL'), config('MODEL_PATH')) 19 | return cls(path_to_model) 20 | 21 | def __init__(self, path_to_model): 22 | self.transform = Transform().val_transform 23 | self.model = NIMA(pretrained_base_model=False) 24 | state_dict = torch.load(path_to_model, map_location=lambda storage, loc: storage) 25 | self.model.load_state_dict(state_dict) 26 | self.model = self.model.to(device) 27 | self.model.eval() 28 | 29 | def predict_from_file(self, image_path): 30 | image = default_loader(image_path) 31 | return self.predict(image) 32 | 33 | def predict_from_pil_image(self, image): 34 | image = image.convert('RGB') 35 | return self.predict(image) 36 | 37 | def predict(self, image): 38 | image = self.transform(image) 39 | image = image.unsqueeze_(0) 40 | image = image.to(device) 41 | image = torch.autograd.Variable(image, volatile=True) 42 | prob = self.model(image).data.cpu().numpy()[0] 43 | 44 | mean_score = get_mean_score(prob) 45 | std_score = get_std_score(prob) 46 | 47 | return format_output(mean_score, std_score, prob) 48 | -------------------------------------------------------------------------------- /metrics/NIMA/nima/nima/inference/utils.py: 
-------------------------------------------------------------------------------- 1 | def format_output(mean_score, std_score, prob): 2 | return { 3 | 'mean_score': float(mean_score), 4 | 'std_score': float(std_score), 5 | 'scores': [float(x) for x in prob] 6 | } 7 | -------------------------------------------------------------------------------- /metrics/NIMA/nima/nima/mobile_net_v2.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from nima.common import download_file 8 | 9 | MOBILE_NET_V2_UTR = 'https://s3-us-west-1.amazonaws.com/models-nima/mobilenetv2.pth.tar' 10 | 11 | 12 | def conv_bn(inp, oup, stride): 13 | return nn.Sequential( 14 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 15 | nn.BatchNorm2d(oup), 16 | nn.ReLU(inplace=True) 17 | ) 18 | 19 | 20 | def conv_1x1_bn(inp, oup): 21 | return nn.Sequential( 22 | nn.Conv2d(inp, oup, 1, 1, 0, bias=False), 23 | nn.BatchNorm2d(oup), 24 | nn.ReLU(inplace=True) 25 | ) 26 | 27 | 28 | class InvertedResidual(nn.Module): 29 | def __init__(self, inp, oup, stride, expand_ratio): 30 | super(InvertedResidual, self).__init__() 31 | self.stride = stride 32 | assert stride in [1, 2] 33 | 34 | self.use_res_connect = self.stride == 1 and inp == oup 35 | 36 | self.conv = nn.Sequential( 37 | # pw 38 | nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False), 39 | nn.BatchNorm2d(inp * expand_ratio), 40 | nn.ReLU6(inplace=True), 41 | # dw 42 | nn.Conv2d(inp * expand_ratio, inp * expand_ratio, 3, stride, 1, groups=inp * expand_ratio, bias=False), 43 | nn.BatchNorm2d(inp * expand_ratio), 44 | nn.ReLU6(inplace=True), 45 | # pw-linear 46 | nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False), 47 | nn.BatchNorm2d(oup), 48 | ) 49 | 50 | def forward(self, x): 51 | if self.use_res_connect: 52 | return x + self.conv(x) 53 | else: 54 | return self.conv(x) 55 | 56 | 57 | class MobileNetV2(nn.Module): 58 | def __init__(self, n_class=1000, input_size=224, width_mult=1.): 59 | super(MobileNetV2, self).__init__() 60 | # setting of inverted residual blocks 61 | self.interverted_residual_setting = [ 62 | # t, c, n, s 63 | [1, 16, 1, 1], 64 | [6, 24, 2, 2], 65 | [6, 32, 3, 2], 66 | [6, 64, 4, 2], 67 | [6, 96, 3, 1], 68 | [6, 160, 3, 2], 69 | [6, 320, 1, 1], 70 | ] 71 | 72 | # building first layer 73 | assert input_size % 32 == 0 74 | input_channel = int(32 * width_mult) 75 | self.last_channel = int(1280 * width_mult) if width_mult > 1.0 else 1280 76 | self.features = [conv_bn(3, input_channel, 2)] 77 | # building inverted residual blocks 78 | for t, c, n, s in self.interverted_residual_setting: 79 | output_channel = int(c * width_mult) 80 | for i in range(n): 81 | if i == 0: 82 | self.features.append(InvertedResidual(input_channel, output_channel, s, t)) 83 | else: 84 | self.features.append(InvertedResidual(input_channel, output_channel, 1, t)) 85 | input_channel = output_channel 86 | # building last several layers 87 | self.features.append(conv_1x1_bn(input_channel, self.last_channel)) 88 | self.features.append(nn.AvgPool2d(input_size // 32)) 89 | # make it nn.Sequential 90 | self.features = nn.Sequential(*self.features) 91 | 92 | # building classifier 93 | self.classifier = nn.Sequential( 94 | nn.Dropout(), 95 | nn.Linear(self.last_channel, n_class), 96 | ) 97 | 98 | self._initialize_weights() 99 | 100 | def forward(self, x): 101 | x = self.features(x) 102 | x = x.view(-1, self.last_channel) 103 | x = self.classifier(x) 104 | return x 105 | 106 | def 
_initialize_weights(self): 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 110 | m.weight.data.normal_(0, math.sqrt(2. / n)) 111 | if m.bias is not None: 112 | m.bias.data.zero_() 113 | elif isinstance(m, nn.BatchNorm2d): 114 | m.weight.data.fill_(1) 115 | m.bias.data.zero_() 116 | elif isinstance(m, nn.Linear): 117 | n = m.weight.size(1) 118 | m.weight.data.normal_(0, 0.01) 119 | m.bias.data.zero_() 120 | 121 | 122 | def mobile_net_v2(pretrained=True): 123 | model = MobileNetV2() 124 | if pretrained: 125 | path_to_model = '/tmp/mobilenetv2.pth.tar' 126 | if not os.path.exists(path_to_model): 127 | path_to_model = download_file(MOBILE_NET_V2_UTR, path_to_model) 128 | state_dict = torch.load(path_to_model, map_location=lambda storage, loc: storage) 129 | model.load_state_dict(state_dict) 130 | return model 131 | -------------------------------------------------------------------------------- /metrics/NIMA/nima/nima/model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from nima.mobile_net_v2 import mobile_net_v2 4 | 5 | 6 | class NIMA(nn.Module): 7 | def __init__(self, pretrained_base_model=True): 8 | super(NIMA, self).__init__() 9 | base_model = mobile_net_v2(pretrained=pretrained_base_model) 10 | base_model = nn.Sequential(*list(base_model.children())[:-1]) 11 | 12 | self.base_model = base_model 13 | 14 | self.head = nn.Sequential( 15 | nn.ReLU(inplace=True), 16 | nn.Dropout(p=0.75), 17 | nn.Linear(1280, 10), 18 | nn.Softmax(dim=1) 19 | ) 20 | 21 | def forward(self, x): 22 | x = self.base_model(x) 23 | x = x.view(x.size(0), -1) 24 | x = self.head(x) 25 | return x 26 | -------------------------------------------------------------------------------- /metrics/NIMA/nima/nima/train/clean_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from concurrent.futures import ThreadPoolExecutor, as_completed 3 | 4 | import pandas as pd 5 | import numpy as np 6 | from tqdm import tqdm 7 | from torchvision.datasets.folder import default_loader 8 | from sklearn.model_selection import train_test_split 9 | 10 | 11 | from nima.train.utils import SCORE_NAMES, TAG_NAMES 12 | 13 | 14 | def _remove_all_not_found_image(df: pd.DataFrame, path_to_images: str) -> pd.DataFrame: 15 | clean_rows = [] 16 | for _, row in df.iterrows(): 17 | image_id = row['image_id'] 18 | try: 19 | _ = default_loader(os.path.join(path_to_images, f"{image_id}.jpg")) 20 | except (FileNotFoundError, OSError): 21 | pass 22 | else: 23 | clean_rows.append(row) 24 | df_clean = pd.DataFrame(clean_rows) 25 | return df_clean 26 | 27 | 28 | def remove_all_not_found_image(df: pd.DataFrame, path_to_images: str, num_workers: int = 64) -> pd.DataFrame: 29 | futures = [] 30 | results = [] 31 | with ThreadPoolExecutor(max_workers=num_workers) as executor: 32 | for df_batch in np.array_split(df, num_workers): 33 | future = executor.submit(_remove_all_not_found_image, df=df_batch, path_to_images=path_to_images) 34 | futures.append(future) 35 | for future in tqdm(as_completed(futures)): 36 | results.append(future.result()) 37 | new_df = pd.concat(results) 38 | return new_df 39 | 40 | 41 | def _read_ava_txt(path_to_ava: str) -> pd.DataFrame: 42 | df = pd.read_csv(path_to_ava, header=None, sep=' ') 43 | del df[0] 44 | scores_names = SCORE_NAMES 45 | tag_names = TAG_NAMES 46 | df.columns = ['image_id'] + scores_names + tag_names 47 | return df 48 | 49 
| 50 | def clean_and_split(path_to_ava_txt: str, path_to_save_csv: str, path_to_images: str): 51 | df = _read_ava_txt(path_to_ava_txt) 52 | df = remove_all_not_found_image(df, path_to_images) 53 | 54 | df_train, df_val_test = train_test_split(df, train_size=0.9) 55 | df_val, df_test = train_test_split(df_val_test, train_size=0.5) 56 | 57 | df_train.to_csv(os.path.join(path_to_save_csv, 'train.csv'), index=False) 58 | df_val.to_csv(os.path.join(path_to_save_csv, 'val.csv'), index=False) 59 | df_test.to_csv(os.path.join(path_to_save_csv, 'test.csv'), index=False) 60 | -------------------------------------------------------------------------------- /metrics/NIMA/nima/nima/train/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | import numpy as np 5 | 6 | from torch.utils.data import Dataset 7 | from torchvision.datasets.folder import default_loader 8 | 9 | 10 | from nima.train.utils import SCORE_NAMES 11 | 12 | 13 | class AVADataset(Dataset): 14 | def __init__(self, path_to_csv: str, images_path: str, transform): 15 | self.df = pd.read_csv(path_to_csv) 16 | self.images_path = images_path 17 | self.transform = transform 18 | 19 | def __len__(self): 20 | return self.df.shape[0] 21 | 22 | def __getitem__(self, item): 23 | row = self.df.iloc[item] 24 | y = np.array([row[k] for k in SCORE_NAMES]) 25 | p = y / y.sum() 26 | 27 | image_id = row['image_id'] 28 | image_path = os.path.join(self.images_path, f'{image_id}.jpg') 29 | image = default_loader(image_path) 30 | x = self.transform(image) 31 | return x, p.astype('float32') 32 | -------------------------------------------------------------------------------- /metrics/NIMA/nima/nima/train/emd_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | 6 | class EDMLoss(nn.Module): 7 | def __init__(self): 8 | super(EDMLoss, self).__init__() 9 | 10 | def forward(self, p_target: Variable, p_estimate: Variable): 11 | assert p_target.shape == p_estimate.shape 12 | # cdf for values [1, 2, ..., 10] 13 | cdf_target = torch.cumsum(p_target, dim=1) 14 | # cdf for values [1, 2, ..., 10] 15 | cdf_estimate = torch.cumsum(p_estimate, dim=1) 16 | cdf_diff = cdf_estimate - cdf_target 17 | samplewise_emd = torch.sqrt(torch.mean(torch.pow(torch.abs(cdf_diff), 2))) 18 | return samplewise_emd.mean() 19 | -------------------------------------------------------------------------------- /metrics/NIMA/nima/nima/train/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import torch 4 | 5 | from torch.utils.data import DataLoader 6 | 7 | from tensorboardX import SummaryWriter 8 | 9 | 10 | from nima.model import NIMA 11 | from nima.train.datasets import AVADataset 12 | from nima.train.emd_loss import EDMLoss 13 | from nima.common import Transform 14 | from nima.train.utils import TrainParams, ValidateParams, AverageMeter 15 | 16 | use_gpu = torch.cuda.is_available() 17 | device = torch.device("cuda" if use_gpu else "cpu") 18 | 19 | 20 | def train(model, loader, optimizer, criterion, writer=None, global_step=None, name=None): 21 | model.train() 22 | train_losses = AverageMeter() 23 | for idx, (x, y) in enumerate(tqdm(loader)): 24 | x = x.to(device) 25 | y = y.to(device) 26 | y_pred = model(x) 27 | loss = criterion(p_target=y, p_estimate=y_pred) 28 | optimizer.zero_grad() 29 | 
loss.backward() 30 | optimizer.step() 31 | train_losses.update(loss.item(), x.size(0)) 32 | 33 | if writer is not None: 34 | writer.add_scalar(f"{name}/train_loss.avg", train_losses.avg, global_step=global_step + idx) 35 | return train_losses.avg 36 | 37 | 38 | def validate(model, loader, criterion, writer=None, global_step=None, name=None): 39 | model.eval() 40 | validate_losses = AverageMeter() 41 | for idx, (x, y) in enumerate(tqdm(loader)): 42 | x = x.to(device) 43 | y = y.to(device) 44 | y_pred = model(x) 45 | loss = criterion(p_target=y, p_estimate=y_pred) 46 | validate_losses.update(loss.item(), x.size(0)) 47 | 48 | if writer is not None: 49 | writer.add_scalar(f"{name}/val_loss.avg", validate_losses.avg, global_step=global_step + idx) 50 | return validate_losses.avg 51 | 52 | 53 | def _create_train_data_part(params: TrainParams): 54 | train_csv_path = os.path.join(params.path_to_save_csv, 'train.csv') 55 | val_csv_path = os.path.join(params.path_to_save_csv, 'val.csv') 56 | 57 | transform = Transform() 58 | train_ds = AVADataset(train_csv_path, params.path_to_images, transform.train_transform) 59 | val_ds = AVADataset(val_csv_path, params.path_to_images, transform.val_transform) 60 | 61 | train_loader = DataLoader(train_ds, batch_size=params.batch_size, num_workers=params.num_workers, shuffle=True) 62 | val_loader = DataLoader(val_ds, batch_size=params.batch_size, num_workers=params.num_workers, shuffle=False) 63 | 64 | return train_loader, val_loader 65 | 66 | 67 | def _create_val_data_part(params: TrainParams): 68 | val_csv_path = os.path.join(params.path_to_save_csv, 'val.csv') 69 | test_csv_path = os.path.join(params.path_to_save_csv, 'test.csv') 70 | 71 | transform = Transform() 72 | val_ds = AVADataset(val_csv_path, params.path_to_images, transform.val_transform) 73 | test_ds = AVADataset(test_csv_path, params.path_to_images, transform.val_transform) 74 | 75 | val_loader = DataLoader(val_ds, batch_size=params.batch_size, num_workers=params.num_workers, shuffle=False) 76 | test_loader = DataLoader(test_ds, batch_size=params.batch_size, num_workers=params.num_workers, shuffle=False) 77 | 78 | return val_loader, test_loader 79 | 80 | 81 | def start_train(params: TrainParams): 82 | train_loader, val_loader = _create_train_data_part(params=params) 83 | model = NIMA() 84 | optimizer = torch.optim.Adam(model.parameters(), lr=params.init_lr) 85 | criterion = EDMLoss() 86 | model = model.to(device) 87 | criterion.to(device) 88 | 89 | writer = SummaryWriter(log_dir=os.path.join(params.experiment_dir_name, 'logs')) 90 | os.makedirs(params.experiment_dir_name, exist_ok=True) 91 | params.save_params(os.path.join(params.experiment_dir_name, 'params.json')) 92 | 93 | for e in range(params.num_epoch): 94 | train_loss = train(model=model, loader=train_loader, optimizer=optimizer, criterion=criterion, 95 | writer=writer, global_step=len(train_loader.dataset) * e, 96 | name=f"{params.experiment_dir_name}_by_batch") 97 | val_loss = validate(model=model, loader=val_loader, criterion=criterion, 98 | writer=writer, global_step=len(train_loader.dataset) * e, 99 | name=f"{params.experiment_dir_name}_by_batch") 100 | 101 | model_name = f"emd_loss_epoch_{e}_train_{train_loss}_{val_loss}.pth" 102 | torch.save(model.module.state_dict(), os.path.join(params.experiment_dir_name, model_name)) 103 | writer.add_scalar(f"{params.experiment_dir_name}_by_epoch/train_loss", train_loss, global_step=e) 104 | writer.add_scalar(f"{params.experiment_dir_name}_by_epoch/val_loss", val_loss, global_step=e) 105 | 106 | 
writer.export_scalars_to_json(os.path.join(params.experiment_dir_name, 'all_scalars.json')) 107 | writer.close() 108 | 109 | 110 | def start_check_model(params: ValidateParams): 111 | val_loader, test_loader = _create_val_data_part(params) 112 | model = NIMA() 113 | model.load_state_dict(torch.load(params.path_to_model_weight)) 114 | criterion = EDMLoss() 115 | 116 | model = model.to(device) 117 | criterion.to(device) 118 | 119 | val_loss = validate(model=model, loader=val_loader, criterion=criterion) 120 | test_loss = validate(model=model, loader=test_loader, criterion=criterion) 121 | return val_loss, test_loss 122 | -------------------------------------------------------------------------------- /metrics/NIMA/nima/nima/train/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import namedtuple 3 | 4 | _SCORE_FIRST_COLUMN = 2 5 | _SCORE_LAST_COLUMN = 12 6 | _TAG_FIRST_COLUMN = 1 7 | _TAG_LAST_COLUMN = 4 8 | 9 | SCORE_NAMES = [f'score{i}' for i in range(_SCORE_FIRST_COLUMN, _SCORE_LAST_COLUMN)] 10 | TAG_NAMES = [f'tag{i}' for i in range(_TAG_FIRST_COLUMN, _TAG_LAST_COLUMN)] 11 | 12 | 13 | class TrainParams(namedtuple('TrainParams', ['path_to_save_csv', 'path_to_images', 14 | 'experiment_dir_name', 'batch_size', 15 | 'num_workers', 'num_epoch', 'init_lr'])): 16 | def save_params(self, file_path: str): 17 | with open(file_path, 'w') as f: 18 | json.dump(self._asdict(), f) 19 | 20 | 21 | class ValidateParams(namedtuple('TrainParams', ['path_to_save_csv', 'path_to_model_weight', 22 | 'path_to_images', 'batch_size', 23 | 'num_workers'])): 24 | pass 25 | 26 | 27 | class AverageMeter(object): 28 | def __init__(self): 29 | self.reset() 30 | 31 | def reset(self): 32 | self.val = 0 33 | self.avg = 0 34 | self.sum = 0 35 | self.count = 0 36 | 37 | def update(self, val, n=1): 38 | self.val = val 39 | self.sum += val * n 40 | self.count += n 41 | self.avg = self.sum / self.count 42 | -------------------------------------------------------------------------------- /metrics/NIMA/nima/requirements.txt: -------------------------------------------------------------------------------- 1 | Flask==0.12.2 2 | gunicorn==19.7.1 3 | ipython==6.2.1 4 | pandas==0.22.0 5 | tqdm==4.19.5 6 | scikit-learn==0.19.1 7 | scipy==1.0.0 8 | python-decouple==3.1 9 | requests==2.18.4 10 | flasgger==0.8.1 11 | tensorflow==1.6.0 12 | tensorboardX==1.1 13 | torch==0.4.0 14 | torchvision==0.2.1 -------------------------------------------------------------------------------- /metrics/NIMA/nima/settings.ini: -------------------------------------------------------------------------------- 1 | [settings] 2 | 3 | MODEL_URL = https://s3-us-west-1.amazonaws.com/models-nima/pretrain-model.pth 4 | MODEL_PATH = /tmp/nima-pretrain-model.pth -------------------------------------------------------------------------------- /metrics/NIMA/pretrain-model.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/metrics/NIMA/pretrain-model.pth -------------------------------------------------------------------------------- /metrics/NIMA/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import argparse 3 | import torch 4 | import torchvision.models 5 | import torchvision.transforms as transforms 6 | from PIL import Image 7 | import os 8 | import 
torch.nn as nn 9 | from mobile_net_v2 import mobile_net_v2 10 | import numpy as np 11 | import argparse 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--method', type=str, help='name of method') 17 | parser.add_argument('--dataset', type=str, default='fivek', help='dataset') 18 | parser.add_argument('--test_images', type=str, default='./images/', help='path to folder containing images') 19 | args = parser.parse_args() 20 | 21 | 22 | # def get_mean_score(score): 23 | # buckets = np.arange(1, 11) 24 | # mu = (buckets * score).sum() 25 | # return mu 26 | 27 | 28 | # def get_std_score(scores): 29 | # si = np.arange(1, 11) 30 | # mean = get_mean_score(scores) 31 | # std = np.sqrt(np.sum(((si - mean) ** 2) * scores)) 32 | # return std 33 | 34 | class NIMA(nn.Module): 35 | def __init__(self, pretrained_base_model=False): 36 | super(NIMA, self).__init__() 37 | base_model = mobile_net_v2(pretrained=pretrained_base_model) 38 | base_model = nn.Sequential(*list(base_model.children())[:-1]) 39 | 40 | self.base_model = base_model 41 | 42 | self.head = nn.Sequential( 43 | nn.ReLU(inplace=True), 44 | nn.Dropout(p=0.75), 45 | nn.Linear(1280, 10), 46 | nn.Softmax(dim=1) 47 | ) 48 | 49 | def forward(self, x): 50 | x = self.base_model(x) 51 | x = x.view(x.size(0), -1) 52 | x = self.head(x) 53 | return x 54 | 55 | 56 | def prepare_image(image): 57 | if image.mode != 'RGB': 58 | image = image.convert("RGB") 59 | Transform = transforms.Compose([ 60 | transforms.Resize(256), 61 | transforms.CenterCrop(224), 62 | transforms.ToTensor(), 63 | ]) 64 | image = Transform(image) 65 | image = image[None] 66 | return image 67 | 68 | 69 | def main(model): 70 | 71 | mean, deviation = 0.0, 0.0 72 | 73 | result_save_path = './nima_results' 74 | if not os.path.exists(result_save_path): 75 | os.makedirs(result_save_path) 76 | 77 | txt_name = './nima_results/result_' + args.dataset + '_' + args.method + '_mean_std.csv' 78 | outfile = open(txt_name, 'w') 79 | outfile.write('image_name' + ','+ 'mean' +',' + 'std' +'\n') 80 | 81 | test_imgs = [f for f in os.listdir(args.test_images)] 82 | for i, img in enumerate(test_imgs): 83 | image = Image.open(os.path.join(args.test_images, img)) 84 | image = prepare_image(image).to(device) 85 | with torch.no_grad(): 86 | preds = model(image).data.cpu().numpy()[0] 87 | # mean = get_mean_score(preds) 88 | # std = get_std_score(preds) 89 | # print('preds: ', preds) 90 | 91 | for j, e in enumerate(preds, 1): 92 | mean += j * e 93 | # print('mean: ', round(mean.item(),5)) 94 | 95 | for k, e in enumerate(preds, 1): 96 | deviation += (e * (k - mean) ** 2) 97 | # print('deviation: ', round(deviation.item(),5)) 98 | std = deviation ** (0.5) 99 | 100 | outfile.write(img + ',' + str(round(mean, 5)) + ',' + str(round(std, 5)) + '\n') 101 | print('processing {:>4d}-th image{:s}: mean: {:>2.6f} and std: {:>2.6f}'.format(i+1, img, round(mean,5).item(), round(std, 5).item())) 102 | 103 | mean, deviation = 0.0, 0.0 104 | 105 | outfile.close() 106 | 107 | 108 | if __name__ == '__main__': 109 | model = NIMA() 110 | # print(model) 111 | model.load_state_dict(torch.load('pretrain-model.pth')) 112 | print('Successfully loaded pretrained model...') 113 | model.to(device).eval() 114 | 115 | main(model) 116 | 117 | -------------------------------------------------------------------------------- /metrics/__pycache__/CalcPSNR.cpython-38.pyc: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/metrics/__pycache__/CalcPSNR.cpython-38.pyc
--------------------------------------------------------------------------------
/metrics/__pycache__/CalcSSIM.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Falling-dow/Unsupervised-Image-Enhancement-with-CNN-and-GAN/05f63986413037eed50de93a20de7fb0ce21d901/metrics/__pycache__/CalcSSIM.cpython-38.pyc
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import numpy as np
7 | import functools
8 | 
9 | 
10 | class Generator(nn.Module):
11 |     """Generator network"""
12 |     def __init__(self, conv_dim, norm_fun, act_fun, use_sn):
13 |         super(Generator, self).__init__()
14 | 
15 |         ###### encoder
16 |         self.enc1 = ConvBlock(in_channels=3, out_channels=conv_dim* 1, kernel_size=7, stride=1, padding=0, dilation=1, use_bias=True, norm_fun=norm_fun, act_fun=act_fun, use_sn=use_sn) # 256*256*3 --> 256*256*32
17 |         self.enc2 = ConvBlock(in_channels=conv_dim*1, out_channels=conv_dim* 2, kernel_size=3, stride=2, padding=0, dilation=1, use_bias=True, norm_fun=norm_fun, act_fun=act_fun, use_sn=use_sn) # 256*256*32 --> 128*128*64
18 |         self.enc3 = ConvBlock(in_channels=conv_dim*2, out_channels=conv_dim* 4, kernel_size=3, stride=2, padding=0, dilation=1, use_bias=True, norm_fun=norm_fun, act_fun=act_fun, use_sn=use_sn) # 128*128*64 --> 64*64*128
19 |         self.enc4 = ConvBlock(in_channels=conv_dim*4, out_channels=conv_dim* 8, kernel_size=3, stride=2, padding=0, dilation=1, use_bias=True, norm_fun=norm_fun, act_fun=act_fun, use_sn=use_sn) # 64*64*128 --> 32*32*256
20 |         self.enc5 = ConvBlock(in_channels=conv_dim*8, out_channels=conv_dim*16, kernel_size=3, stride=2, padding=0, dilation=1, use_bias=True, norm_fun=norm_fun, act_fun=act_fun, use_sn=use_sn) # 32*32*256 --> 16*16*512
21 | 
22 |         ###### decoder
23 |         self.upsample1 = nn.Sequential(Interpolate(2, 'bilinear', True), SNConv(conv_dim*16, conv_dim*8, 1, 1, 0, 1, True, use_sn))
24 |         self.upsample2 = nn.Sequential(Interpolate(2, 'bilinear', True), SNConv(conv_dim* 8, conv_dim*4, 1, 1, 0, 1, True, use_sn))
25 |         self.upsample3 = nn.Sequential(Interpolate(2, 'bilinear', True), SNConv(conv_dim* 4, conv_dim*2, 1, 1, 0, 1, True, use_sn))
26 |         self.upsample4 = nn.Sequential(Interpolate(2, 'bilinear', True), SNConv(conv_dim* 2, conv_dim*1, 1, 1, 0, 1, True, use_sn))
27 | 
28 |         self.dec1 = ConvBlock(in_channels=conv_dim*16, out_channels=conv_dim*8, kernel_size=3, stride=1, padding=0, dilation=1, use_bias=True, norm_fun=norm_fun, act_fun=act_fun, use_sn=use_sn) # 32*32*512 --> 32*32*256
29 |         self.dec2 = ConvBlock(in_channels=conv_dim* 8, out_channels=conv_dim*4, kernel_size=3, stride=1, padding=0, dilation=1, use_bias=True, norm_fun=norm_fun, act_fun=act_fun, use_sn=use_sn) # 64*64*256 --> 64*64*128
30 |         self.dec3 = ConvBlock(in_channels=conv_dim* 4, out_channels=conv_dim*2, kernel_size=3, stride=1, padding=0, dilation=1, use_bias=True, norm_fun=norm_fun, act_fun=act_fun, use_sn=use_sn) # 128*128*128 --> 128*128*64
31 |         self.dec4 = ConvBlock(in_channels=conv_dim* 2, out_channels=conv_dim*1, kernel_size=3, stride=1,
padding=0, dilation=1, use_bias=True, norm_fun=norm_fun, act_fun=act_fun, use_sn=use_sn) # 256*256*64 --> 256*256*32 32 | self.dec5 = nn.Sequential( 33 | SNConv(in_channels=conv_dim*1, out_channels=conv_dim*1, kernel_size=3, stride=1, padding=0, dilation=1, use_bias=True, use_sn=False), 34 | SNConv(in_channels=conv_dim*1, out_channels=3, kernel_size=7, stride=1, padding=0, dilation=1, use_bias=True, use_sn=False), 35 | nn.Tanh() 36 | ) 37 | 38 | self.ga5 = GAM(conv_dim*16, conv_dim*16, reduction=8, bias=False, use_sn=use_sn, norm=True) 39 | self.ga4 = GAM(conv_dim* 8, conv_dim* 8, reduction=8, bias=False, use_sn=use_sn, norm=True) 40 | self.ga3 = GAM(conv_dim* 4, conv_dim* 4, reduction=8, bias=False, use_sn=use_sn, norm=True) 41 | self.ga2 = GAM(conv_dim* 2, conv_dim* 2, reduction=8, bias=False, use_sn=use_sn, norm=True) 42 | self.ga1 = GAM(conv_dim* 1, conv_dim* 1, reduction=8, bias=False, use_sn=use_sn, norm=True) 43 | 44 | def forward(self, x): 45 | ### encoder 46 | x1 = self.enc1( x) 47 | x2 = self.enc2(x1) 48 | x3 = self.enc3(x2) 49 | x4 = self.enc4(x3) 50 | x5 = self.enc5(x4) 51 | x5 = self.ga5(x5) 52 | 53 | ### decoder 54 | y1 = self.upsample1(x5) 55 | y1 = torch.cat([y1, self.ga4(x4)], dim=1) 56 | y1 = self.dec1(y1) 57 | 58 | y2 = self.upsample2(y1) 59 | y2 = torch.cat([y2, self.ga3(x3)], dim=1) 60 | y2 = self.dec2(y2) 61 | 62 | y3 = self.upsample3(y2) 63 | y3 = torch.cat([y3, self.ga2(x2)], dim=1) 64 | y3 = self.dec3(y3) 65 | 66 | y4 = self.upsample4(y3) 67 | y4 = torch.cat([y4, self.ga1(x1)], dim=1) 68 | y4 = self.dec4(y4) 69 | 70 | res = self.dec5(y4.mul(x1)) 71 | 72 | out = torch.clamp((res + x), min=-1.0, max=1.0) 73 | 74 | return out 75 | 76 | 77 | class SNConv(nn.Module): 78 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, use_bias, use_sn): 79 | super(SNConv, self).__init__() 80 | self.padding = (kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2 81 | self.main = nn.Sequential( 82 | nn.ReflectionPad2d(self.padding), 83 | SpectralNorm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=0, dilation=dilation, bias=use_bias), use_sn), 84 | ) 85 | def forward(self, x): 86 | return self.main(x) 87 | 88 | class ConvBlock(nn.Module): 89 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, use_bias, norm_fun, act_fun, use_sn): 90 | super(ConvBlock, self).__init__() 91 | self.padding = (kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2 92 | main = [] 93 | main.append(nn.ReflectionPad2d(self.padding)) 94 | main.append(SpectralNorm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=0, dilation=dilation, bias=use_bias), use_sn)) 95 | norm_fun = get_norm_fun(norm_fun) 96 | main.append(norm_fun(out_channels)) 97 | main.append(get_act_fun(act_fun)) 98 | self.main = nn.Sequential(*main) 99 | 100 | def forward(self, x): 101 | return self.main(x) 102 | 103 | 104 | class Discriminator(nn.Module): 105 | def __init__(self, conv_dim, norm_fun, act_fun, use_sn, adv_loss_type): 106 | super(Discriminator, self).__init__() 107 | 108 | # scale 1 and prediction of scale 1 128 109 | d_1 = [dis_conv_block(in_channels=3, out_channels=conv_dim, kernel_size=7, stride=2, padding=3, dilation=1, use_bias=True, norm_fun=norm_fun, act_fun=act_fun, use_sn=use_sn)] 110 | d_1_pred = [dis_pred_conv_block(in_channels=conv_dim, out_channels=1, kernel_size=7, stride=1, padding=3, dilation=1, use_bias=False, 
type=adv_loss_type)] 111 | 112 | # scale 2 64 113 | d_2 = [dis_conv_block(in_channels=conv_dim, out_channels=conv_dim * 2, kernel_size=7, stride=2, padding=3, dilation=1, norm_fun=norm_fun, use_bias=True, act_fun=act_fun, use_sn=use_sn)] 114 | d_2_pred = [dis_pred_conv_block(in_channels=conv_dim * 2, out_channels=1, kernel_size=7, stride=1, padding=3, dilation=1, use_bias=False, type=adv_loss_type)] 115 | 116 | # scale 3 and prediction of scale 3 32 117 | d_3 = [dis_conv_block(in_channels=conv_dim* 2, out_channels=conv_dim* 4, kernel_size=7, stride=2, padding=3, dilation=1, use_bias=True, norm_fun=norm_fun, act_fun=act_fun, use_sn=use_sn)] 118 | d_3_pred = [dis_pred_conv_block(in_channels=conv_dim* 4, out_channels=1, kernel_size=7, stride=1, padding=3, dilation=1, use_bias=False, type=adv_loss_type)] 119 | 120 | # scale 4 16 121 | d_4 = [dis_conv_block(in_channels=conv_dim* 4, out_channels=conv_dim* 8, kernel_size=5, stride=2, padding=2, dilation=1, use_bias=True, norm_fun=norm_fun, act_fun=act_fun, use_sn=use_sn)] 122 | d_4_pred = [dis_pred_conv_block(in_channels=conv_dim * 8, out_channels=1, kernel_size=5, stride=1, padding=2, dilation=1, use_bias=False, type=adv_loss_type)] 123 | 124 | # scale 5 and prediction of scale 5 8 125 | d_5 = [dis_conv_block(in_channels=conv_dim* 8, out_channels=conv_dim* 16, kernel_size=5, stride=2, padding=2, dilation=1, use_bias=True, norm_fun=norm_fun, act_fun=act_fun, use_sn=use_sn)] 126 | d_5_pred = [dis_pred_conv_block(in_channels=conv_dim* 16, out_channels=1, kernel_size=5, stride=1, padding=2, dilation=1, use_bias=False, type=adv_loss_type)] 127 | 128 | self.d1 = nn.Sequential(*d_1) 129 | self.d1_pred = nn.Sequential(*d_1_pred) 130 | self.d2 = nn.Sequential(*d_2) 131 | self.d2_pred = nn.Sequential(*d_2_pred) 132 | self.d3 = nn.Sequential(*d_3) 133 | self.d3_pred = nn.Sequential(*d_3_pred) 134 | self.d4 = nn.Sequential(*d_4) 135 | self.d4_pred = nn.Sequential(*d_4_pred) 136 | self.d5 = nn.Sequential(*d_5) 137 | self.d5_pred = nn.Sequential(*d_5_pred) 138 | 139 | def forward(self, x): 140 | ds1 = self.d1(x) 141 | ds1_pred = self.d1_pred(ds1) 142 | 143 | ds2 = self.d2(ds1) 144 | ds2_pred = self.d2_pred(ds2) 145 | 146 | ds3 = self.d3(ds2) 147 | ds3_pred = self.d3_pred(ds3) 148 | 149 | ds4 = self.d4(ds3) 150 | ds4_pred = self.d4_pred(ds4) 151 | 152 | ds5 = self.d5(ds4) 153 | ds5_pred = self.d5_pred(ds5) 154 | 155 | return [ds1_pred, ds2_pred, ds3_pred, ds4_pred, ds5_pred] 156 | 157 | 158 | def dis_conv_block(in_channels, out_channels, kernel_size, stride, padding, dilation, use_bias, norm_fun, act_fun, use_sn): 159 | padding = (kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2 160 | main = [] 161 | main.append(nn.ReflectionPad2d(padding)) 162 | main.append(SpectralNorm(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=0, dilation=dilation, bias=use_bias), use_sn)) 163 | norm_fun = get_norm_fun(norm_fun) 164 | main.append(norm_fun(out_channels)) 165 | main.append(get_act_fun(act_fun)) 166 | main = nn.Sequential(*main) 167 | return main 168 | 169 | 170 | def dis_pred_conv_block(in_channels, out_channels, kernel_size, stride, padding, dilation, use_bias, type): 171 | padding = (kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2 172 | main = [] 173 | main.append(nn.ReflectionPad2d(padding)) 174 | main.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=0, dilation=dilation, bias=use_bias)) 175 | if type in ['ls', 'rals']: 
176 |         main.append(nn.Sigmoid())
177 |     elif type in ['hinge', 'rahinge']:
178 |         main.append(nn.Tanh())
179 |     else:
180 |         raise NotImplementedError("Adversarial loss [{}] is not found".format(type))
181 |     main = nn.Sequential(*main)
182 |     return main
183 | 
184 | 
185 | def SpectralNorm(module, mode=True):
186 |     if mode:
187 |         return nn.utils.spectral_norm(module)
188 |     return module
189 | 
190 | 
191 | class Interpolate(nn.Module):
192 |     def __init__(self, scale_factor, mode, align_corners):
193 |         super(Interpolate, self).__init__()
194 |         self.interp = nn.functional.interpolate
195 |         self.scale_factor = scale_factor
196 |         self.mode = mode
197 |         self.align_corners = align_corners
198 | 
199 |     def forward(self, x):
200 |         out = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
201 |         return out
202 | 
203 | 
204 | def calc_mean_std(feat, eps=1e-5):
205 |     # eps is a small value added to the variance to avoid divide-by-zero.
206 |     size = feat.data.size()
207 |     assert (len(size) == 4)
208 |     N, C = size[:2]
209 |     feat_var = feat.view(N, C, -1).var(dim=2) + eps
210 |     feat_std = feat_var.sqrt().view(N, C, 1, 1)
211 |     feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
212 |     return feat_mean, feat_std
213 | class ChannelAttention(nn.Module):
214 |     def __init__(self, channel, ratio=4):
215 |         super(ChannelAttention, self).__init__()
216 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)
217 |         self.max_pool = nn.AdaptiveMaxPool2d(1)
218 |         self.sharedMLP = nn.Sequential(
219 |             nn.Conv2d(channel, channel // ratio, 1, bias=False),
220 |             nn.ReLU(),
221 |             nn.Conv2d(channel // ratio, channel, 1, bias=False))
222 |         self.sigmoid = nn.Sigmoid()
223 | 
224 |     def forward(self, x):
225 |         avgout = self.sharedMLP(self.avg_pool(x))
226 |         maxout = self.sharedMLP(self.max_pool(x))
227 |         return self.sigmoid(avgout + maxout)
228 | 
229 | class SpatialAttention(nn.Module):
230 |     def __init__(self):
231 |         super(SpatialAttention, self).__init__()
232 |         self.conv = nn.Conv2d(2, 1, kernel_size=3, padding=1)
233 |         self.sigmoid = nn.Sigmoid()
234 | 
235 |     def forward(self, x):
236 |         avgout = torch.mean(x, dim=1, keepdim=True)
237 |         maxout, _ = torch.max(x, dim=1, keepdim=True)
238 |         x = torch.cat([avgout, maxout], dim=1)
239 |         x = self.conv(x)
240 |         return self.sigmoid(x)
241 | 
242 | class BasicBlock(nn.Module):
243 |     def __init__(self, channel):
244 |         super(BasicBlock, self).__init__()
245 |         self.conv1 = nn.Conv2d(channel, channel, kernel_size=3, padding=1)
246 |         self.bn1 = nn.BatchNorm2d(channel)
247 |         self.relu = nn.ReLU(inplace=True)
248 |         self.conv2 = nn.Conv2d(channel, channel, kernel_size=3, padding=1)
249 |         self.bn2 = nn.BatchNorm2d(channel)
250 |         self.ca = ChannelAttention(channel)
251 |         self.sa = SpatialAttention()
252 | 
253 |     def forward(self, x):
254 |         out = self.conv1(x)
255 |         out = self.bn1(out)
256 |         out = self.relu(out)
257 |         out = self.conv2(out)
258 |         out = self.bn2(out)
259 |         out = self.ca(out) * out  # channel attention weights broadcast over the spatial dimensions
260 |         out = self.sa(out) * out  # spatial attention weights broadcast over the channel dimension
261 |         return self.relu(out + x)
262 | 
263 | 
264 | class GAM(nn.Module):
265 |     """Global attention module"""
266 |     def __init__(self, in_nc, out_nc, reduction=8, bias=False, use_sn=False, norm=False):
267 |         super(GAM, self).__init__()
268 |         self.conv = nn.Sequential(
269 |             nn.Conv2d(in_channels=in_nc*2, out_channels=in_nc//reduction, kernel_size=1, stride=1, bias=bias, padding=0, dilation=1),
270 |             nn.ReLU(inplace=True),
271 |             nn.Conv2d(in_channels=in_nc//reduction, out_channels=out_nc, kernel_size=1, stride=1, bias=bias, padding=0,
dilation=1), 272 | ) 273 | self.fuse = nn.Sequential( 274 | SpectralNorm(nn.Conv2d(in_channels=in_nc * 2, out_channels=out_nc, kernel_size=1, stride=1, bias=True, padding=0, dilation=1), use_sn), 275 | ) 276 | self.in_norm = nn.InstanceNorm2d(out_nc) 277 | self.norm = norm 278 | self.CBAM = BasicBlock(out_nc) 279 | def forward(self, x): 280 | x_mean, x_std = calc_mean_std(x) 281 | out = self.conv(torch.cat([x_mean, x_std], dim=1)) 282 | # out = self.conv(x_mean) 283 | out = self.fuse(torch.cat([x, out.expand_as(x)], dim=1)) 284 | if self.norm: 285 | out = self.in_norm(out) 286 | out = self.CBAM(out) 287 | return out 288 | 289 | 290 | class Swish(nn.Module): 291 | def __init__(self): 292 | super(Swish, self).__init__() 293 | self.s = nn.Sigmoid() 294 | 295 | def forward(self, x): 296 | return x * self.s(x) 297 | 298 | 299 | def get_act_fun(act_fun_type='LeakyReLU'): 300 | if isinstance(act_fun_type, str): 301 | if act_fun_type == 'LeakyReLU': 302 | return nn.LeakyReLU(0.2, inplace=True) 303 | elif act_fun_type == 'ReLU': 304 | return nn.ReLU(inplace=True) 305 | elif act_fun_type == 'Swish': 306 | return Swish() 307 | elif act_fun_type == 'SELU': 308 | return nn.SELU(inplace=True) 309 | elif act_fun_type == 'none': 310 | return nn.Sequential() 311 | else: 312 | raise NotImplementedError('activation function [%s] is not found' % act_fun_type) 313 | else: 314 | return act_fun_type() 315 | 316 | 317 | class Identity(nn.Module): 318 | def forward(self, x): 319 | return x 320 | 321 | 322 | def get_norm_fun(norm_fun_type='none'): 323 | if norm_fun_type == 'BatchNorm': 324 | norm_fun = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) 325 | elif norm_fun_type == 'InstanceNorm': 326 | norm_fun = functools.partial(nn.InstanceNorm2d, affine=True, track_running_stats=True) 327 | elif norm_fun_type == 'none': 328 | norm_fun = lambda x: Identity() 329 | else: 330 | raise NotImplementedError('normalization function [%s] is not found' % norm_fun_type) 331 | return norm_fun 332 | -------------------------------------------------------------------------------- /tester.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 -*- 2 | 3 | import os 4 | import time 5 | import torch 6 | import datetime 7 | import torch.nn as nn 8 | from torchvision.utils import save_image 9 | from losses import PerceptualLoss, TVLoss 10 | from utils import Logger, denorm, ImagePool, GaussianNoise 11 | from models import Generator, Discriminator 12 | from metrics.NIMA.CalcNIMA import calc_nima 13 | from metrics.CalcPSNR import calc_psnr 14 | from metrics.CalcSSIM import calc_ssim 15 | from tqdm import * 16 | from data_loader import InputFetcher 17 | 18 | 19 | class Tester(object): 20 | def __init__(self, loaders, args): 21 | 22 | # data loader 23 | self.loaders = loaders 24 | 25 | # Model configuration. 26 | self.args = args 27 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 28 | 29 | # Directories. 30 | self.model_save_path = os.path.join(args.save_root_dir, args.version, args.model_save_path) 31 | self.sample_path = os.path.join(args.save_root_dir, args.version, args.sample_path) 32 | self.log_path = os.path.join(args.save_root_dir, args.version, args.log_path) 33 | self.test_result_path = os.path.join(args.save_root_dir, args.version, args.test_result_path) 34 | 35 | # Build the model and tensorboard. 
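        # build_model() (defined later in this class) instantiates the Generator and
        # Discriminator from models.py and moves them to self.device; the optional
        # tensorboard path wraps the Logger class from utils.py.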
36 |         self.build_model()
37 |         if self.args.use_tensorboard:
38 |             self.build_tensorboard()
39 | 
40 | 
41 |     def test(self):
42 |         """Test UEGAN."""
43 |         self.load_pretrained_model(self.args.pretrained_model)
44 |         start_time = time.time()
45 |         test_start = 0
46 |         test_total_steps = len(self.loaders.tes)
47 |         self.fetcher_test = InputFetcher(self.loaders.tes)
48 | 
49 |         test = {}
50 |         test_save_path = self.test_result_path + '/' + 'test_results'
51 |         test_compare_save_path = self.test_result_path + '/' + 'test_compare'
52 | 
53 |         if not os.path.exists(test_save_path):
54 |             os.makedirs(test_save_path)
55 |         if not os.path.exists(test_compare_save_path):
56 |             os.makedirs(test_compare_save_path)
57 | 
58 |         self.G.eval()
59 | 
60 |         pbar = tqdm(total=(test_total_steps - test_start), desc='Test epochs', position=test_start)
61 |         pbar.write("============================== Start testing ==============================")
62 |         with torch.no_grad():
63 |             for test_step in range(test_start, test_total_steps):
64 |                 input = next(self.fetcher_test)
65 |                 test_real_raw, test_name = input.img_raw, input.img_name
66 | 
67 |                 test_fake_exp = self.G(test_real_raw)
68 | 
69 |                 for i in range(0, denorm(test_real_raw.data).size(0)):
70 |                     save_imgs = denorm(test_fake_exp.data)[i:i + 1,:,:,:]
71 |                     save_image(save_imgs, os.path.join(test_save_path, '{:s}_{:0>3.2f}_testFakeExp.png'.format(test_name[i], self.args.pretrained_model)))
72 | 
73 |                     save_imgs_compare = torch.cat([denorm(test_real_raw.data)[i:i + 1,:,:,:], denorm(test_fake_exp.data)[i:i + 1,:,:,:]], 3)
74 |                     save_image(save_imgs_compare, os.path.join(test_compare_save_path, '{:s}_{:0>3.2f}_testRealRaw_testFakeExp.png'.format(test_name[i], self.args.pretrained_model)))
75 | 
76 |                 elapsed = time.time() - start_time
77 |                 elapsed = str(datetime.timedelta(seconds=elapsed))
78 |                 if test_step % self.args.info_step == 0:
79 |                     pbar.write("=== Elapsed: {}, Save {:>3d}-th test_fake_exp images into {} ===".format(elapsed, test_step, test_save_path))
80 | 
81 |                 test['test/testFakeExp'] = denorm(test_fake_exp.detach().cpu())
82 | 
83 |                 test['test_compare/testRealRaw_testFakeExp'] = torch.cat([denorm(test_real_raw.cpu()), denorm(test_fake_exp.detach().cpu())], 3)
84 | 
85 |                 pbar.update(1)
86 | 
87 |                 if self.args.use_tensorboard:
88 |                     for tag, images in test.items():
89 |                         self.logger.images_summary(tag, images, test_step + 1)
90 | 
91 |         if self.args.is_test_nima:
92 |             self.nima_result_save_path = './results/nima_test_results/'
93 |             curr_nima = calc_nima(test_save_path, self.nima_result_save_path, self.args.pretrained_model)
94 |             print("====== Avg. NIMA: {:>.4f} ======".format(curr_nima))
95 | 
96 |         if self.args.is_test_psnr_ssim:
97 |             self.psnr_save_path = './results/psnr_test_results/'
98 |             curr_psnr = calc_psnr(test_save_path, self.args.test_label_dir, self.psnr_save_path, self.args.pretrained_model)
99 |             print("====== Avg. PSNR: {:>.4f} dB ======".format(curr_psnr))
100 | 
101 |             self.ssim_save_path = './results/ssim_test_results/'
102 |             curr_ssim = calc_ssim(test_save_path, self.args.test_label_dir, self.ssim_save_path, self.args.pretrained_model)
103 |             print("====== Avg.
SSIM: {:>.4f} ======".format(curr_ssim)) 104 | 105 | 106 | """define some functions""" 107 | def build_model(self): 108 | """Create a generator and a discriminator.""" 109 | self.G = Generator(self.args.g_conv_dim, self.args.g_norm_fun, self.args.g_act_fun, self.args.g_use_sn).to(self.device) 110 | self.D = Discriminator(self.args.d_conv_dim, self.args.d_norm_fun, self.args.d_act_fun, self.args.d_use_sn, self.args.adv_loss_type).to(self.device) 111 | if self.args.parallel: 112 | self.G.to(self.args.gpu_ids[0]) 113 | self.D.to(self.args.gpu_ids[0]) 114 | self.G = nn.DataParallel(self.G, self.args.gpu_ids) 115 | self.D = nn.DataParallel(self.D, self.args.gpu_ids) 116 | print("=== Models have been created ===") 117 | 118 | # print network 119 | if self.args.is_print_network: 120 | self.print_network(self.G, 'Generator') 121 | self.print_network(self.D, 'Discriminator') 122 | 123 | 124 | def print_network(self, model, name): 125 | """Print out the network information.""" 126 | num_params = 0 127 | for p in model.parameters(): 128 | num_params += p.numel() 129 | # print(model) 130 | print("=== The number of parameters of the above model [{}] is [{}] or [{:>.4f}M] ===".format(name, num_params, num_params / 1e6)) 131 | 132 | 133 | def load_pretrained_model(self, resume_epochs): 134 | checkpoint_path = os.path.join(self.model_save_path, '{}_{}_{}.pth'.format(self.args.version, self.args.adv_loss_type, resume_epochs)) 135 | if torch.cuda.is_available(): 136 | # save on GPU, load on GPU 137 | checkpoint = torch.load(checkpoint_path) 138 | self.G.load_state_dict(checkpoint['G_net']) 139 | self.D.load_state_dict(checkpoint['D_net']) 140 | else: 141 | # save on GPU, load on CPU 142 | checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) 143 | self.G.load_state_dict(checkpoint['G_net']) 144 | self.D.load_state_dict(checkpoint['D_net']) 145 | 146 | print("=========== loaded trained models (epochs: {})! ===========".format(resume_epochs)) 147 | 148 | 149 | def build_tensorboard(self): 150 | """Build a tensorboard logger.""" 151 | self.logger = Logger(self.log_path) -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 -*- 2 | 3 | import os 4 | import time 5 | import torch 6 | import datetime 7 | import torch.nn as nn 8 | from torchvision.utils import save_image 9 | from losses import PerceptualLoss, GANLoss, MultiscaleRecLoss 10 | from utils import Logger, denorm, ImagePool 11 | from models import Generator, Discriminator 12 | from metrics.NIMA.CalcNIMA import calc_nima 13 | from metrics.CalcPSNR import calc_psnr 14 | from metrics.CalcSSIM import calc_ssim 15 | from tqdm import * 16 | from data_loader import InputFetcher 17 | 18 | 19 | class Trainer(object): 20 | def __init__(self, loaders, args): 21 | # data loader 22 | self.loaders = loaders 23 | 24 | # Model configuration. 25 | self.args = args 26 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 27 | self.psnr_save_path = args.psnr_save_path 28 | self.ssim_save_path = args.ssim_save_path 29 | self.nima_result_save_path = args.nima_result_save_path 30 | # Directories. 
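        # Illustrative layout (the actual names come from config.py, so they may
        # differ), assuming save_root_dir='./results' and version='v1':
        #   ./results/v1/<model_save_path>  -> checkpoints (*.pth)
        #   ./results/v1/<sample_path>      -> training sample grids
        #   ./results/v1/<log_path>         -> tensorboard event files
        #   ./results/v1/<val_result_path>  -> validation outputs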
31 |         self.model_save_path = os.path.join(args.save_root_dir, args.version, args.model_save_path)
32 |         self.sample_path = os.path.join(args.save_root_dir, args.version, args.sample_path)
33 |         self.log_path = os.path.join(args.save_root_dir, args.version, args.log_path)
34 |         self.val_result_path = os.path.join(args.save_root_dir, args.version, args.val_result_path)
35 | 
36 |         # Build the model and tensorboard.
37 |         self.build_model()
38 |         # if self.args.use_tensorboard:
39 |         #     self.build_tensorboard()
40 | 
41 |     def train(self):
42 |         """Train UEGAN."""
43 |         self.fetcher = InputFetcher(self.loaders.ref)
44 |         self.fetcher_val = InputFetcher(self.loaders.val)
45 | 
46 |         self.train_steps_per_epoch = len(self.loaders.ref)
47 |         self.model_save_step = int(self.args.model_save_epoch * self.train_steps_per_epoch)
48 | 
49 |         # set nima, psnr, ssim global parameters
50 |         if self.args.is_test_nima:
51 |             self.best_nima_epoch, self.best_nima = 0, 0.0
52 |         if self.args.is_test_psnr_ssim:
53 |             self.best_psnr_epoch, self.best_psnr = 0, 0.0
54 |             self.best_ssim_epoch, self.best_ssim = 0, 0.0
55 | 
56 |         # set loss functions
57 |         self.criterionPercep = PerceptualLoss()
58 |         self.criterionIdt = MultiscaleRecLoss(scale=3, rec_loss_type=self.args.idt_loss_type, multiscale=True)
59 |         self.criterionGAN = GANLoss(self.args.adv_loss_type, tensor=torch.cuda.FloatTensor)  # note: assumes CUDA is available
60 | 
61 |         # start from scratch or trained models
62 |         if self.args.pretrained_model:
63 |             start_step = int(self.args.pretrained_model * self.train_steps_per_epoch)
64 |             self.load_pretrained_model(self.args.pretrained_model)
65 |         else:
66 |             start_step = 0
67 | 
68 |         # start training
69 |         print("======================================= start training =======================================")
70 |         self.start_time = time.time()
71 |         total_steps = int(self.args.total_epochs * self.train_steps_per_epoch)
72 |         self.val_start_steps = int(self.args.num_epochs_start_val * self.train_steps_per_epoch)
73 |         self.val_each_steps = int(self.args.val_each_epochs * self.train_steps_per_epoch)
74 | 
75 |         print("=========== start to iteratively train generator and discriminator ===========")
76 |         pbar = tqdm(total=total_steps, desc='Train epochs', initial=start_step)
77 |         print(start_step, total_steps)
78 |         for step in range(start_step, total_steps):
79 | 
80 |             ########## model train
81 |             self.G.train()
82 |             self.D.train()
83 | 
84 |             ########## data iter
85 |             input = next(self.fetcher)
86 |             self.real_raw, self.real_exp, self.real_raw_name = input.img_raw, input.img_exp, input.img_name
87 | 
88 |             ########## forward
89 |             self.fake_exp = self.G(self.real_raw)
90 |             self.fake_exp_store = self.fake_exp_pool.query(self.fake_exp)
91 | 
92 |             ########## update D
93 |             self.d_optimizer.zero_grad()
94 |             real_exp_preds = self.D(self.real_exp)
95 |             fake_exp_preds = self.D(self.fake_exp_store.detach())
96 |             d_loss = self.criterionGAN(real_exp_preds, fake_exp_preds, None, None, for_discriminator=True)
97 |             if self.args.adv_input:
98 |                 input_preds = self.D(self.real_raw)
99 |                 d_loss += self.criterionGAN(real_exp_preds, input_preds, None, None, for_discriminator=True)
100 |             d_loss.backward()
101 |             self.d_optimizer.step()
102 |             self.d_loss = d_loss.item()
103 | 
104 |             ########## update G
105 |             self.g_optimizer.zero_grad()
106 |             real_exp_preds = self.D(self.real_exp)
107 |             fake_exp_preds = self.D(self.fake_exp)
108 |             g_adv_loss = self.args.lambda_adv * self.criterionGAN(real_exp_preds, fake_exp_preds, None, None, for_discriminator=False)
109 |             self.g_adv_loss = g_adv_loss.item()
110 |             g_loss = g_adv_loss
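            # Total generator objective, accumulated term by term below:
            #   g_loss = lambda_adv * adv_loss + lambda_percep * percep_loss + lambda_idt * idt_loss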
111 | 
112 |             g_percep_loss = self.args.lambda_percep * self.criterionPercep((self.fake_exp+1.)/2., (self.real_raw+1.)/2.)
113 |             self.g_percep_loss = g_percep_loss.item()
114 |             g_loss += g_percep_loss
115 | 
116 |             self.real_exp_idt = self.G(self.real_exp)
117 |             g_idt_loss = self.args.lambda_idt * self.criterionIdt(self.real_exp_idt, self.real_exp)
118 |             self.g_idt_loss = g_idt_loss.item()
119 |             g_loss += g_idt_loss
120 | 
121 |             g_loss.backward()
122 |             self.g_optimizer.step()
123 |             self.g_loss = g_loss.item()
124 | 
125 |             ### print info and save models
126 |             self.print_info(step, total_steps, pbar)
127 | 
128 |             ### logging using tensorboard
129 |             # self.logging(step)
130 | 
131 |             ### validation
132 |             self.model_validation(step)
133 | 
134 |             ### learning rate update
135 |             if step % self.train_steps_per_epoch == 0:
136 |                 current_epoch = step // self.train_steps_per_epoch
137 |                 self.lr_scheduler_g.step(epoch=current_epoch)
138 |                 self.lr_scheduler_d.step(epoch=current_epoch)
139 |                 for param_group in self.g_optimizer.param_groups:
140 |                     pbar.write("====== Epoch: {:>3d}/{}, Learning rate(lr) of Encoder(E) and Generator(G): [{}], ".format(((step + 1) // self.train_steps_per_epoch), self.args.total_epochs, param_group['lr']), end='')
141 |                 for param_group in self.d_optimizer.param_groups:
142 |                     pbar.write("Learning rate (lr) of Discriminator(D): [{}] ======".format(param_group['lr']))
143 | 
144 |             pbar.update(1)
145 |             pbar.set_description("Train epoch %.2f" % ((step + 1.0) / self.train_steps_per_epoch))
146 | 
147 |         # self.val_best_results()
148 | 
149 |         pbar.write("=========== Complete training ===========")
150 |         pbar.close()
151 | 
152 | 
153 |     # def logging(self, step):
154 |     #     self.loss = {}
155 |     #     self.images = {}
156 |     #     self.loss['D/Total'] = self.d_loss
157 |     #     self.loss['G/Total'] = self.g_loss
158 |     #     self.loss['G/adv_loss'] = self.g_adv_loss
159 |     #     self.loss['G/percep_loss'] = self.g_percep_loss
160 |     #     self.loss['G/idt_loss'] = self.g_idt_loss
161 | 
162 |     #     self.images['Train_realExpIdt/realExp_realExpIdt'] = torch.cat([denorm(self.real_exp.cpu()), denorm(self.real_exp_idt.detach().cpu())], 3)
163 |     #     self.images['Train_compare/realRaw_fakeExp_realExp'] = torch.cat([denorm(self.real_raw.cpu()), denorm(self.fake_exp.detach().cpu()), denorm(self.real_exp.cpu())], 3)
164 |     #     self.images['Train_fakeExp/fakeExp'] = denorm(self.fake_exp.detach().cpu())
165 |     #     self.images['Train_fakeExpStore/fakeExpStore'] = denorm(self.fake_exp_store.detach().cpu())
166 | 
167 |     #     if (step+1) % self.args.log_step == 0:
168 |     #         if self.args.use_tensorboard:
169 |     #             for tag, value in self.loss.items():
170 |     #                 self.logger.scalar_summary(tag, value, step+1)
171 |     #             for tag, image in self.images.items():
172 |     #                 self.logger.images_summary(tag, image, step+1)
173 | 
174 | 
175 |     def print_info(self, step, total_steps, pbar):
176 |         current_epoch = (step+1) / self.train_steps_per_epoch
177 | 
178 |         if (step + 1) % self.args.info_step == 0:
179 |             elapsed_num = time.time() - self.start_time
180 |             elapsed = str(datetime.timedelta(seconds=elapsed_num))
181 |             pbar.write("Elapsed: {:>.12s}, D_Step:{:>6d}/{}, G_Step:{:>6d}/{}, D_loss:{:>.4f}, G_loss:{:>.4f}, G_percep_loss:{:>.4f}, G_adv_loss:{:>.4f}, G_idt_loss:{:>.4f}".format(elapsed, step + 1, total_steps, (step + 1), total_steps, self.d_loss, self.g_loss, self.g_percep_loss, self.g_adv_loss, self.g_idt_loss))
182 | 
183 |         # sample images
184 |         if (step + 1) % self.args.sample_step == 0:
185 |             for i in range(0, self.real_raw.size(0)):
186 |                 save_imgs = 
torch.cat([denorm(self.real_raw.data)[i:i + 1,:,:,:], denorm(self.fake_exp.data)[i:i + 1,:,:,:], denorm(self.real_exp.data)[i:i + 1,:,:,:]], 3)
187 |                 save_image(save_imgs, os.path.join(self.sample_path, '{:s}_{:0>3.2f}_{:0>2d}_realRaw_fakeExp_realExp.png'.format(self.real_raw_name[i], current_epoch, i)))
188 | 
189 |         # save models
190 |         if (step + 1) % self.model_save_step == 0:
191 |             # unwrap nn.DataParallel (.module) before saving, so checkpoints load without the wrapper
192 |             if self.args.parallel and torch.cuda.device_count() > 1:
193 |                 checkpoint = {
194 |                     "G_net": self.G.module.state_dict(),
195 |                     "D_net": self.D.module.state_dict(),
196 |                     "epoch": current_epoch,
197 |                     "g_optimizer": self.g_optimizer.state_dict(),
198 |                     "d_optimizer": self.d_optimizer.state_dict(),
199 |                     "lr_scheduler_g": self.lr_scheduler_g.state_dict(),
200 |                     "lr_scheduler_d": self.lr_scheduler_d.state_dict()
201 |                 }
202 |             else:
203 |                 checkpoint = {
204 |                     "G_net": self.G.state_dict(),
205 |                     "D_net": self.D.state_dict(),
206 |                     "epoch": current_epoch,
207 |                     "g_optimizer": self.g_optimizer.state_dict(),
208 |                     "d_optimizer": self.d_optimizer.state_dict(),
209 |                     "lr_scheduler_g": self.lr_scheduler_g.state_dict(),
210 |                     "lr_scheduler_d": self.lr_scheduler_d.state_dict()
211 |                 }
212 |             torch.save(checkpoint, os.path.join(self.model_save_path, '{}_{}_{}.pth'.format(self.args.version, self.args.adv_loss_type, current_epoch)))
213 | 
214 |             pbar.write("======= Save model checkpoints into {} ======".format(self.model_save_path))
215 | 
216 | 
217 |     def model_validation(self, step):
218 |         if (step + 1) > self.val_start_steps:
219 |             if (step + 1) % self.val_each_steps == 0:
220 |                 val = {}
221 |                 current_epoch = (step + 1) / self.train_steps_per_epoch
222 |                 val_save_path = self.val_result_path + '/' + 'validation_' + str(current_epoch)
223 |                 val_compare_save_path = self.val_result_path + '/' + 'validation_compare_' + str(current_epoch)
224 |                 val_start = 0
225 |                 val_total_steps = len(self.loaders.val)
226 | 
227 |                 if not os.path.exists(val_save_path):
228 |                     os.makedirs(val_save_path)
229 |                 if not os.path.exists(val_compare_save_path):
230 |                     os.makedirs(val_compare_save_path)
231 | 
232 |                 self.G.eval()
233 | 
234 |                 pbar = tqdm(total=(val_total_steps - val_start), desc='Validation epochs', position=val_start)
235 |                 pbar.write("============================== Start validation ==============================")
236 |                 with torch.no_grad():
237 |                     for val_step in range(val_start, val_total_steps):
238 | 
239 |                         input = next(self.fetcher_val)
240 |                         val_real_raw, val_name = input.img_raw, input.img_name
241 | 
242 |                         val_fake_exp = self.G(val_real_raw)
243 | 
244 |                         for i in range(0, denorm(val_real_raw.data).size(0)):
245 |                             save_imgs = denorm(val_fake_exp.data)[i:i + 1,:,:,:]
246 |                             save_image(save_imgs, os.path.join(val_save_path, '{:s}_{:0>3.2f}_valFakeExp.png'.format(val_name[i], current_epoch)))
247 | 
248 |                             save_imgs_compare = torch.cat([denorm(val_real_raw.data)[i:i + 1,:,:,:], denorm(val_fake_exp.data)[i:i + 1,:,:,:]], 3)
249 |                             save_image(save_imgs_compare, os.path.join(val_compare_save_path, '{:s}_{:0>3.2f}_valRealRaw_valFakeExp.png'.format(val_name[i], current_epoch)))
250 | 
251 |                         elapsed = time.time() - self.start_time
252 |                         elapsed = str(datetime.timedelta(seconds=elapsed))
253 |                         if val_step % self.args.info_step == 0:
254 |                             pbar.write("=== Elapsed: {}, Save {:>3d}-th val_fake_exp images into {} ===".format(elapsed, val_step, val_save_path))
255 | 
256 |                         val['val/valFakeExp'] = denorm(val_fake_exp.detach().cpu())
257 | 
258 |                         val['val_compare/valRealRaw_valFakeExp'] = torch.cat([denorm(val_real_raw.cpu()), 
denorm(val_fake_exp.detach().cpu())], 3) 259 | 260 | pbar.update(1) 261 | 262 | # if self.args.use_tensorboard: 263 | # for tag, images in val.items(): 264 | # self.logger.images_summary(tag, images, val_step + 1) 265 | 266 | pbar.close() 267 | if self.args.is_test_nima: 268 | self.nima_result_save_path = './results/nima_val_results/' 269 | curr_nima = calc_nima(val_save_path, self.nima_result_save_path, current_epoch) 270 | if self.best_nima < curr_nima: 271 | self.best_nima = curr_nima 272 | self.best_nima_epoch = current_epoch 273 | print("====== Avg. NIMA: {:>.4f} ======".format(curr_nima)) 274 | 275 | if self.args.is_test_psnr_ssim: 276 | self.psnr_save_path = './results/psnr_val_results/' 277 | curr_psnr = calc_psnr(val_save_path, self.args.val_label_dir, self.psnr_save_path, current_epoch) 278 | if self.best_psnr < curr_psnr: 279 | self.best_psnr = curr_psnr 280 | self.best_psnr_epoch = current_epoch 281 | print("====== Avg. PSNR: {:>.4f} dB ======".format(curr_psnr)) 282 | 283 | self.ssim_save_path = './results/ssim_val_results/' 284 | curr_ssim = calc_ssim(val_save_path, self.args.val_label_dir, self.ssim_save_path, current_epoch) 285 | if self.best_ssim < curr_ssim: 286 | self.best_ssim = curr_ssim 287 | self.best_ssim_epoch = current_epoch 288 | print("====== Avg. SSIM: {:>.4f} ======".format(curr_ssim)) 289 | torch.cuda.empty_cache() 290 | time.sleep(2) 291 | 292 | 293 | def val_best_results(self): 294 | if self.args.is_test_psnr_ssim: 295 | if not os.path.exists(self.psnr_save_path): 296 | os.makedirs(self.psnr_save_path) 297 | psnr_result = self.psnr_save_path + 'PSNR_total_results_epoch_avgpsnr.csv' 298 | psnrfile = open(psnr_result, 'a+') 299 | psnrfile.write('Best epoch: ' + str(self.best_psnr_epoch) + ',' + str(round(self.best_psnr, 6)) + '\n') 300 | psnrfile.close() 301 | 302 | if not os.path.exists(self.ssim_save_path): 303 | os.makedirs(self.ssim_save_path) 304 | ssim_result = self.ssim_save_path + 'SSIM_total_results_epoch_avgssim.csv' 305 | ssimfile = open(ssim_result, 'a+') 306 | ssimfile.write('Best epoch: ' + str(self.best_ssim_epoch) + ',' + str(round(self.best_ssim, 6)) + '\n') 307 | ssimfile.close() 308 | 309 | if self.args.is_test_nima: 310 | nima_total_result = self.nima_result_save_path + 'NIMA_total_results_epoch_mean_std.csv' 311 | totalfile = open(nima_total_result, 'a+') 312 | totalfile.write('Best epoch:' + str(self.best_nima_epoch) + ',' + str(round(self.best_nima, 6)) + '\n') 313 | totalfile.close() 314 | 315 | 316 | """define some functions""" 317 | def build_model(self): 318 | """Create a generator and a discriminator.""" 319 | self.G = Generator(self.args.g_conv_dim, self.args.g_norm_fun, self.args.g_act_fun, self.args.g_use_sn).to(self.device) 320 | self.D = Discriminator(self.args.d_conv_dim, self.args.d_norm_fun, self.args.d_act_fun, self.args.d_use_sn, self.args.adv_loss_type).to(self.device) 321 | if self.args.parallel: 322 | self.G.to(self.args.gpu_ids[0]) 323 | self.D.to(self.args.gpu_ids[0]) 324 | self.G = nn.DataParallel(self.G, self.args.gpu_ids) 325 | self.D = nn.DataParallel(self.D, self.args.gpu_ids) 326 | print("=== Models have been created ===") 327 | 328 | # print network 329 | if self.args.is_print_network: 330 | self.print_network(self.G, 'Generator') 331 | self.print_network(self.D, 'Discriminator') 332 | 333 | # init network 334 | if self.args.init_type: 335 | self.init_weights(self.G, init_type=self.args.init_type, gain=0.02) 336 | self.init_weights(self.D, init_type=self.args.init_type, gain=0.02) 337 | 338 | # optimizer 339 | 
if self.args.optimizer_type == 'adam': 340 | # Adam optimizer 341 | self.g_optimizer = torch.optim.Adam(params=self.G.parameters(), lr=self.args.g_lr, betas=[self.args.beta1, self.args.beta2], weight_decay=0.0001) 342 | self.d_optimizer = torch.optim.Adam(params=self.D.parameters(), lr=self.args.d_lr, betas=[self.args.beta1, self.args.beta2], weight_decay=0.0001) 343 | elif self.args.optimizer_type == 'rmsprop': 344 | # RMSprop optimizer 345 | self.g_optimizer = torch.optim.RMSprop(params=self.G.parameters(), lr=self.args.g_lr, alpha=self.args.alpha) 346 | self.d_optimizer = torch.optim.RMSprop(params=self.D.parameters(), lr=self.args.d_lr, alpha=self.args.alpha) 347 | else: 348 | raise NotImplementedError("=== Optimizer [{}] is not found ===".format(self.args.optimizer_type)) 349 | 350 | # learning rate decay 351 | if self.args.lr_decay: 352 | def lambda_rule(epoch): 353 | return 1.0 - max(0, epoch + 1 - self.args.lr_num_epochs_decay) / self.args.lr_decay_ratio 354 | self.lr_scheduler_g = torch.optim.lr_scheduler.LambdaLR(self.g_optimizer, lr_lambda=lambda_rule) 355 | self.lr_scheduler_d = torch.optim.lr_scheduler.LambdaLR(self.d_optimizer, lr_lambda=lambda_rule) 356 | print("=== Set learning rate decay policy for Generator(G) and Discriminator(D) ===") 357 | 358 | self.fake_exp_pool = ImagePool(self.args.pool_size) 359 | 360 | 361 | def init_weights(self, net, init_type='kaiming', gain=0.02): 362 | def init_func(m): 363 | classname = m.__class__.__name__ 364 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 365 | if init_type == 'normal': 366 | torch.nn.init.normal_(m.weight.data, 0.0, gain) 367 | elif init_type == 'xavier': 368 | torch.nn.init.xavier_normal_(m.weight.data, gain=gain) 369 | elif init_type == 'xavier_uniform': 370 | torch.nn.init.xavier_uniform_(m.weight.data, gain=1.0) 371 | elif init_type == 'kaiming': 372 | torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 373 | elif init_type == 'kaiming_uniform': 374 | torch.nn.init.kaiming_uniform_(m.weight.data, a=0, mode='fan_in') 375 | elif init_type == 'orthogonal': 376 | torch.nn.init.orthogonal_(m.weight.data, gain=gain) 377 | elif init_type == 'none': # uses pytorch's default init method 378 | m.reset_parameters() 379 | else: 380 | raise NotImplementedError('Initialization method [{}] is not implemented'.format(init_type)) 381 | if hasattr(m, 'bias') and m.bias is not None: 382 | torch.nn.init.constant_(m.bias.data, 0.0) 383 | elif classname.find('BatchNorm2d') != -1: 384 | if hasattr(m, 'weight') and m.weight is not None: 385 | torch.nn.init.normal_(m.weight.data, 1.0, gain) 386 | if hasattr(m, 'bias') and m.bias is not None: 387 | torch.nn.init.constant_(m.bias.data, 0.0) 388 | elif classname.find('InstanceNorm2d') != -1: 389 | if hasattr(m, 'weight') and m.weight is not None: 390 | torch.nn.init.normal_(m.weight.data, 1.0, gain) 391 | if hasattr(m, 'bias') and m.bias is not None: 392 | torch.nn.init.constant_(m.bias.data, 0.0) 393 | print("=== Initialize network with [{}] ===".format(init_type)) 394 | net.apply(init_func) 395 | 396 | 397 | def print_network(self, model, name): 398 | """Print out the network information.""" 399 | num_params = 0 400 | for p in model.parameters(): 401 | num_params += p.numel() 402 | # print(model) 403 | print("=== The number of parameters of the above model [{}] is [{}] or [{:>.4f}M] ===".format(name, num_params, num_params / 1e6)) 404 | 405 | 406 | def load_pretrained_model(self, resume_epochs): 407 | checkpoint_path = 
os.path.join(self.model_save_path, '{}_{}_{}.pth'.format(self.args.version, self.args.adv_loss_type, resume_epochs)) 408 | if torch.cuda.is_available(): 409 | # save on GPU, load on GPU 410 | checkpoint = torch.load(checkpoint_path) 411 | self.G.load_state_dict(checkpoint['G_net']) 412 | self.D.load_state_dict(checkpoint['D_net']) 413 | self.g_optimizer.load_state_dict(checkpoint['g_optimizer']) 414 | self.d_optimizer.load_state_dict(checkpoint['d_optimizer']) 415 | self.lr_scheduler_g.load_state_dict(checkpoint['lr_scheduler_g']) 416 | self.lr_scheduler_d.load_state_dict(checkpoint['lr_scheduler_d']) 417 | else: 418 | # save on GPU, load on CPU 419 | checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage) 420 | self.G.load_state_dict(checkpoint['G_net']) 421 | self.D.load_state_dict(checkpoint['D_net']) 422 | self.g_optimizer.load_state_dict(checkpoint['g_optimizer']) 423 | self.d_optimizer.load_state_dict(checkpoint['d_optimizer']) 424 | self.lr_scheduler_g.load_state_dict(checkpoint['lr_scheduler_g']) 425 | self.lr_scheduler_d.load_state_dict(checkpoint['lr_scheduler_d']) 426 | 427 | print("=========== loaded trained models (epochs: {})! ===========".format(resume_epochs)) 428 | 429 | 430 | def build_tensorboard(self): 431 | """Build a tensorboard logger.""" 432 | self.logger = Logger(self.log_path) 433 | 434 | 435 | def identity_loss(self, idt_loss_type): 436 | if idt_loss_type == 'l1': 437 | criterion = nn.L1Loss() 438 | return criterion 439 | elif idt_loss_type == 'smoothl1': 440 | criterion = nn.SmoothL1Loss() 441 | return criterion 442 | elif idt_loss_type == 'l2': 443 | criterion = nn.MSELoss() 444 | return criterion 445 | else: 446 | raise NotImplementedError("=== Identity loss type [{}] is not implemented. ===".format(self.args.idt_loss_type)) 447 | 448 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | #-*-coding:utf-8-*- 2 | 3 | import os 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | import numbers 8 | import torch.nn.functional as F 9 | import numpy as np 10 | import csv 11 | import random 12 | import tensorflow as tf 13 | from torch.utils.tensorboard import SummaryWriter 14 | import torchvision 15 | import scipy.misc 16 | from torch.optim.optimizer import Optimizer, required 17 | try: 18 | from StringIO import StringIO # Python 2.7 19 | except ImportError: 20 | from io import BytesIO # Python 3.x 21 | 22 | 23 | class ImagePool(): 24 | def __init__(self, pool_size): 25 | self.pool_size = pool_size 26 | if self.pool_size > 0: 27 | self.num_imgs = 0 28 | self.images = [] 29 | 30 | def query(self, images): 31 | if self.pool_size == 0: 32 | return images 33 | return_images = [] 34 | for image in images: 35 | image = torch.unsqueeze(image.data, 0) 36 | if self.num_imgs < self.pool_size: 37 | self.num_imgs = self.num_imgs + 1 38 | self.images.append(image) 39 | return_images.append(image) 40 | else: 41 | p = random.uniform(0, 1) 42 | if p > 0.5: 43 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 44 | tmp = self.images[random_id].clone() 45 | self.images[random_id] = image 46 | return_images.append(tmp) 47 | else: 48 | return_images.append(image) 49 | return_images = torch.cat(return_images, 0) 50 | return return_images 51 | 52 | 53 | class Logger(object): 54 | """Create a tensorboard logger to log_dir.""" 55 | def __init__(self, log_dir): 56 | """Initialize summary writer.""" 57 | 
self.writer = SummaryWriter(log_dir)
58 | 
59 |     def scalar_summary(self, tag, value, step):
60 |         """Add a scalar summary."""
61 |         self.writer.add_scalar(tag, value, step)
62 | 
63 |     def images_summary(self, tag, images, step):
64 |         """Log a batch of images (a float tensor in NCHW layout with values in [0, 1])."""
65 |         self.writer.add_images(tag, images, step)
66 | 
67 |     def histo_summary(self, tag, values, step, bins=1000):
68 |         """Log a histogram of the tensor of values."""
69 |         self.writer.add_histogram(tag, values, step, bins=bins)
70 |         self.writer.flush()
71 | 
72 | 
115 | def create_folder(root_dir, path, version):
116 |     if not os.path.exists(os.path.join(root_dir, path, version)):
117 |         os.makedirs(os.path.join(root_dir, path, version))
118 | 
119 | 
120 | def var2tensor(x):
121 |     return x.data.cpu()
122 | 
123 | 
124 | def var2numpy(x):
125 |     return x.data.cpu().numpy()
126 | 
127 | 
128 | def denorm(x):
129 |     out = (x + 1) / 2.0
130 |     return out.clamp_(0, 1)
131 | 
132 | 
133 | def str2bool(v):
134 |     return v.lower() in ('true',)
135 | 
136 | 
137 | def tensor2im(input_image, imtype=np.uint8):
138 |     if isinstance(input_image, torch.Tensor):
139 |         image_tensor = input_image.data
140 |     else:
141 |         return input_image
142 |     image_numpy = image_tensor[0].cpu().float().numpy()
143 |     if image_numpy.shape[0] == 1:
144 |         image_numpy = np.tile(image_numpy, (3, 1, 1))
145 |     image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
146 |     return image_numpy.astype(imtype)
147 | 
148 | 
149 | def setup_seed(seed):
150 |     np.random.seed(seed)
151 |     random.seed(seed)
152 |     torch.manual_seed(seed)
153 |     torch.cuda.manual_seed_all(seed)
154 |     torch.backends.cudnn.deterministic = True
155 |     torch.backends.cudnn.enabled = True
156 | 
157 | 
158 | class GaussianSmoothing(nn.Module):
159 |     """
160 |     Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed separately for each channel in the input using a depthwise convolution.
161 |     Arguments:
162 |         channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well. 
163 |         kernel_size (int, sequence): Size of the gaussian kernel.
164 |         sigma (float, sequence): Standard deviation of the gaussian kernel.
165 |         dim (int, optional): The number of dimensions of the data. Default value is 2 (spatial).
166 |     """
167 |     def __init__(self, channels=3, kernel_size=21, sigma=3, dim=2):
168 |         super(GaussianSmoothing, self).__init__()
169 |         self.padding = nn.ReflectionPad2d(kernel_size // 2)
170 |         if isinstance(kernel_size, numbers.Number):
171 |             kernel_size = [kernel_size] * dim
172 |         if isinstance(sigma, numbers.Number):
173 |             sigma = [sigma] * dim
174 | 
175 |         # The gaussian kernel is the product of the gaussian function of each dimension.
176 |         kernel = 1
177 |         meshgrids = torch.meshgrid([
178 |             torch.arange(size, dtype=torch.float32)
179 |             for size in kernel_size
180 |         ])
181 |         for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
182 |             mean = (size - 1) / 2
183 |             kernel *= 1 / (std * math.sqrt(2 * math.pi)) * torch.exp(-((mgrid - mean) / std) ** 2 / 2)
184 | 
185 |         # Make sure sum of values in gaussian kernel equals 1.
186 |         kernel = kernel / torch.sum(kernel)
187 | 
188 |         # Reshape to depthwise convolutional weight
189 |         kernel = kernel.view(1, 1, *kernel.size())
190 |         kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
191 | 
192 |         self.register_buffer('weight', kernel)
193 |         self.groups = channels
194 | 
195 |         if dim == 1:
196 |             self.conv = F.conv1d
197 |         elif dim == 2:
198 |             self.conv = F.conv2d
199 |         elif dim == 3:
200 |             self.conv = F.conv3d
201 |         else:
202 |             raise RuntimeError('Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim))
203 | 
204 |     def forward(self, input):
205 |         input = self.padding(input)
206 |         out = self.conv(input, weight=self.weight, groups=self.groups)
207 | 
208 |         return out
209 | 
210 | 
211 | def gray_scale(image):
212 |     """
213 |     image : a float tensor of shape (batch_size, 3, height, width) holding RGB images
214 |     return : a tensor of shape (batch_size, 1, height, width) with the ITU-R BT.601 luma
215 |     """
216 |     # rgb_image = np.reshape(image, newshape=(-1, channels, height, width)) # 3 channel which is rgb
217 |     # gray_image = np.reshape(gray_image, newshape=(-1, 1, height, width))
218 |     # gray_image = torch.from_numpy(gray_image)
219 | 
220 |     gray_image = torch.unsqueeze(image[:, 0] * 0.299 + image[:, 1] * 0.587 + image[:, 2] * 0.114, 1)
221 | 
222 |     return gray_image
223 | 
224 | 
225 | class GaussianNoise(nn.Module):
226 |     """A gaussian noise module.
227 | 
228 |     Args:
229 |         mean (float): The mean of the normal distribution. Default: 0.0.
230 |         stddev (float): The standard deviation of the normal distribution. Default: 0.1.
231 | 
232 |     Shape:
233 |         - Input: (batch, *)
234 |         - Output: (batch, *) (same shape as input)
235 |     """
236 | 
237 |     def __init__(self, mean=0.0, stddev=0.1):
238 |         super(GaussianNoise, self).__init__()
239 |         self.mean = mean
240 |         self.stddev = stddev
241 | 
242 |     def forward(self, x):
243 |         noise = torch.empty_like(x)
244 |         noise.normal_(self.mean, self.stddev)  # use the configured mean rather than a hard-coded zero
245 | 
246 |         return x + noise
247 | 
248 | 
--------------------------------------------------------------------------------
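# --------------------------------------------------------------------------------
# Minimal usage sketch for the helpers in utils.py (illustrative only; the batch
# shape and pool size below are assumptions, not values fixed by this repository):
#
#   import torch
#   from utils import denorm, GaussianNoise, ImagePool
#
#   x = torch.rand(4, 3, 128, 128) * 2 - 1          # fake batch in [-1, 1], as the models expect
#   noisy = GaussianNoise(mean=0.0, stddev=0.1)(x)  # add zero-mean Gaussian noise
#   imgs = denorm(noisy)                            # map back to [0, 1] for save_image/visualization
#
#   pool = ImagePool(pool_size=50)                  # history buffer of previously generated images
#   mixed = pool.query(imgs)                        # a mix of current and stored images for D updates
# --------------------------------------------------------------------------------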