├── README.md
├── environment.yml
├── license.txt
├── model.png
└── src
    ├── config.json
    ├── data.py
    ├── losses.py
    ├── main.py
    ├── model.py
    ├── test.py
    ├── test_one.py
    ├── train.py
    └── utils.py

/README.md:
--------------------------------------------------------------------------------
# Deep-based film grain removal and synthesis
Zoubida Ameur, Wassim Hamidouche, Edouard François, Milos Radosavljevic, Daniel Ménard and Claire-Hélène Demarty


## Abstract

In this paper, deep learning-based techniques for film grain removal and synthesis that can be applied in video coding are proposed. Film grain is inherent in analog film content because of the physical process of capturing images and video on film. It can also be present in digital content, where it is purposely added to reflect the era of analog film, to evoke certain emotions in the viewer, or to enhance the perceived quality. In the context of video coding, the random nature of film grain makes it both difficult to preserve and very expensive to compress. To better preserve it while compressing the content efficiently, film grain is removed and modeled before video encoding and then restored after video decoding. In this paper, a film grain removal model based on an encoder-decoder architecture and a film grain synthesis model based on a conditional generative adversarial network (cGAN) are proposed. Both models are trained on a large dataset of pairs of clean (grain-free) and grainy images. Quantitative and qualitative evaluations of the developed solutions show that the proposed film grain removal model is effective in filtering film grain at different intensity levels using two configurations: 1) a non-blind configuration, where the film grain level of the grainy input is known and provided as input, and 2) a blind configuration, where the film grain level is unknown. As for the film grain synthesis task, the experimental results show that the proposed model is able to reproduce realistic film grain with a controllable intensity level specified as input.


## Network architecture

![](model.png)


## Structure of the repository
- `environment.yml`: conda environment specifications and required packages
- `src`:
  - `data.py`: data loader
  - `model.py`: models builder
  - `utils.py`: utility functions
  - `losses.py`: loss functions
  - `train.py`: training loops
  - `main.py`: script for training
  - `test.py`: script for evaluating a trained model on a test dataset
  - `test_one.py`: script for evaluating a trained model on a single test image
  - `config.json`: training parameters

File `config.json` contains the following training parameters:
```
{
    "epochs" : 5,
    "batch_size" : 1,
    "input_dim" : 256,
    "learning_rate_alpha_gen" : 3e-4,
    "learning_rate_alpha_dis" : 1e-4,
    "learning_rate_beta_gen" : 0.5,
    "levels" : [0.01, 0.025, 0.05, 0.075, 0.1]
}
```


## Structure of the associated FilmGrain dataset

In this paper, we also provide a FilmGrain dataset, available at:
https://www.interdigital.com/data_sets/filmgrain-dataset

The structure of the dataset is the following:

```
FilmGrainDataset
├── org
└── fg
    ├── 01
    ├── 025
    ├── 05
    ├── 075
    └── 1
```

Subfolder `org` contains the original images without film grain. Subfolder `fg` contains the images with film grain, one subfolder per grain intensity level.

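Judging from the tree above, each `fg` subfolder name appears to be the grain level with the leading `0.` dropped (so `025` would hold the 0.025-level images). This naming rule is an assumption inferred from the tree shown here, not a documented API; a minimal sketch of the mapping under that assumption:

```python
# Assumed mapping from FilmGrain subfolder names to numeric grain levels,
# inferred from the dataset tree above (hypothetical helper, not part of the repo).
def folder_to_level(name: str) -> float:
    return float("0." + name)

assert folder_to_level("025") == 0.025  # and folder_to_level("1") == 0.1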


## 1. Getting started

Clone the repository:

```bash
git clone https://github.com/InterDigitalInc/DeepFilmGrain.git
cd DeepFilmGrain/
```

Create and activate a conda environment with the required packages:
```bash
conda env create --name example-environment --file environment.yml
conda activate example-environment
```

## 2. Finding a free GPU

When running on GPU, you have to explicitly select the device on which to launch your command (via `CUDA_VISIBLE_DEVICES`). To see the available GPUs and their current utilization under Linux, type:
```bash
nvidia-smi
```


## 3. Testing


### On a test dataset

To test on a folder of test images, run the following (assuming GPU #2 is free):
```bash
CUDA_VISIBLE_DEVICES=2 python3 src/test.py --pretrained_model path/to/pretrained/model/model.h5 --level 0.01 --input_path path/to/input/folder/ --output_path path/to/output/folder/
```

| Option | Description |
|--------|-------------|
| --pretrained_model | path to the pretrained model |
| --level | film grain level, one of {0.01, 0.025, 0.05, 0.075, 0.1} |
| --input_path | path to the input folder |
| --output_path | path to the output folder |

When testing on the FilmGrain dataset:
* if the model was trained for removing grain, use subfolder `fg` (samples with grain) as input;
* if the model was trained to synthesize grain, use subfolder `org` (samples without grain) as input.

### On one test image with different film grain patterns
To test on a single image, run the following (same options as above, assuming GPU #2 is free):

```bash
CUDA_VISIBLE_DEVICES=2 python3 src/test_one.py --pretrained_model path/to/pretrained/model/model.h5 --level 0.1 --input_path path/to/input/image --output_path path/to/output/folder/
```

## 4. Training
### On GPU
An example command is (assuming GPU #2 is free):
```bash
CUDA_VISIBLE_DEVICES=2 python3 src/main.py --path path/to/dataset/ --task removal_case_1
```


| Option | Description | Usage |
|--------|-------------|-------|
| --path | path of the custom dataset | Path to the dataset to be used for training |
| --task | task you want to learn | Task and case of the ablation study, described below. |

The different configurations considered in the ablation study are the following:

| Tasks | Inputs | Backbone | Loss functions |
|-------|--------|----------|----------------|
| removal_case_1 | grainy image | U-Net | l1 |
| removal_case_2 | grainy image | U-Net + residual blocks | l1 |
| removal_case_3 | grainy image | U-Net + residual blocks | l1 + MS-SSIM |
| removal_case_4 | grainy image (blind: no grain level input) | U-Net + residual blocks | l1 + MS-SSIM |
| synthesis_case_1 | clean image | U-Net | l1 |
| synthesis_case_2 | clean image | U-Net + residual blocks | l1 |
| synthesis_case_3 | clean image | U-Net + residual blocks + PatchGAN | l1 + adv loss |

Note: `synthesis_case_3_gray`, `removal_case_3_gray` and `removal_case_4_gray` are the grayscale versions of the corresponding cases. Except for the blind `removal_case_4`, the models also receive the grain level as an additional input, encoded as a constant map (see the sketch below).
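For the non-blind models, the grain level is passed to the network as a constant single-channel map concatenated with the image. A minimal sketch of how such a map is built, mirroring what `src/test.py` does (variable names are illustrative):

```python
import numpy as np

# Constant per-pixel plane holding the grain intensity, as in src/test.py.
level = 0.05
batch, height, width = 1, 256, 256
level_map = np.full((batch, height, width, 1), level, dtype=np.float32)
# The generator is then called as: model([image_batch, level_map])
```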

### Train on the FilmGrain dataset
```bash
CUDA_VISIBLE_DEVICES=2 python3 src/main.py --path path/to/FilmGrain/dataset/ --task synthesis_case_1
```

### Train on your dataset
If you want to train on your own dataset, use the option `--path` to set the dataset path and set the folder tree as follows:

```
dataset
├── org
└── fg
    ├── level1
    ├── level2
    ├── ...
    └── leveln
```

level1, level2, ..., leveln are the different film grain intensity levels.
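Note that `pickle_files` in `src/data.py` builds its training lists by walking a single grain subfolder (`fg/01/`) and deriving all other paths by string replacement: the subfolder of a level is named `str(level).replace('.', '')`, and every level subfolder (as well as `org`) must contain identically named `.tiff` files. A sketch of the derived paths for a hypothetical file:

```python
# Hypothetical file found on disk by pickle_files() in src/data.py:
seed = "dataset/fg/01/scene_0001.tiff"

# Counterparts derived by string replacement, which must also exist:
grainy_005 = seed.replace("fg/01/", "fg/005/")  # level 0.05 -> subfolder "005"
clean      = seed.replace("fg/01/", "org/")
```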

## Citation
```
@article{ameur2022deep,
  title={Deep-based Film Grain Removal and Synthesis},
  author={Ameur, Zoubida and Hamidouche, Wassim and Fran{\c{c}}ois, Edouard and Radosavljevi{\'c}, Milo{\v{s}} and M{\'e}nard, Daniel and Demarty, Claire-H{\'e}l{\`e}ne},
  journal={arXiv preprint arXiv:2206.07411},
  year={2022}
}
```
## License
Copyright © 2022, InterDigital R&D France. All rights reserved.
This source code is made available under the license found in `license.txt` in the root directory of this source tree.
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: example-environment

dependencies:
  - tensorflow-gpu=2.2
  - cudatoolkit=10.1
  - cudnn=7.6
  - python=3.8
  - pip=20.0
  - imageio
  - matplotlib
  - numpy
  - pillow
  - scikit-image
--------------------------------------------------------------------------------
/license.txt:
--------------------------------------------------------------------------------
Copyright (c) 2022 InterDigital R&D France
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted (subject to the limitations in the disclaimer below) provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
* Neither the name of InterDigital R&D France nor its affiliates in the InterDigital group nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InterDigitalInc/DeepFilmGrain/45b2bfd9c08c970bb17b71731863da992499a943/model.png
--------------------------------------------------------------------------------
/src/config.json:
--------------------------------------------------------------------------------
{
    "epochs" : 5,
    "batch_size" : 1,
    "input_dim" : 256,
    "learning_rate_alpha_gen" : 3e-4,
    "learning_rate_alpha_dis" : 1e-4,
    "learning_rate_beta_gen" : 0.5,
    "levels" : [0.01, 0.025, 0.05, 0.075, 0.1]
}
--------------------------------------------------------------------------------
/src/data.py:
--------------------------------------------------------------------------------
"""
Copyright (c) 2022 InterDigital R&D France
All rights reserved.
Licensed under BSD 3 clause clear license - check license.txt for more information
"""


import tensorflow as tf
import numpy as np
from PIL import Image
import pickle
from numpy import asarray
import os
from math import ceil
import random


class Data:
    def __init__(self, list_IDs, list_levels, input_dim=256, batch_size=1, task="synthesis"):
        self.list_IDs = list_IDs
        self.list_levels = list_levels
        self.task = task
        self.input_dim = input_dim
        self.batch_size = batch_size
        self.list_IDs_org, self.list_IDs_fg, self.list_levels = self.load_pkl_multilevel(list_IDs, list_levels)
        self.on_epoch_end()

    def load_pkl_multilevel(self, list_IDs, list_levels):
        # Load the pickled (org, fg) path lists and the per-sample grain levels.
        with open(list_IDs, 'rb') as pickle_in:
            partition = pickle.load(pickle_in)
        list_IDs_org = partition["org"]
        list_IDs_fg = partition["fg"]
        with open(list_levels, 'rb') as pickle_in:
            list_levels = pickle.load(pickle_in)
        return list_IDs_org, list_IDs_fg, list_levels

    def batches(self):
        return int(np.floor(len(self.list_IDs_org) / self.batch_size))

    def on_epoch_end(self):
        self.indexes = np.arange(len(self.list_IDs_org))
        return self.indexes

    def data_generation(self):
        X = []
        for i, ID in enumerate(self.list_IDs_temp):
            image = Image.open(ID)
            image = asarray(image)
            image = tf.cast(image, tf.float32)
            if "removal" in self.task:
                # Removal models work on inputs normalized to [-1, 1].
                image = (image / 127.5) - 1
            X.append(image)
        return tf.stack(X)

    def generate_samples_multilevel(self, index):
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        self.list_IDs_temp = [self.list_IDs_org[k] for k in indexes]
        X_org = self.data_generation()
        self.list_IDs_temp = [self.list_IDs_fg[k] for k in indexes]
        X_fg = self.data_generation()
        # Constant per-pixel map holding the grain level of each sample.
        X_levels = np.empty((self.batch_size, self.input_dim, self.input_dim, 1), dtype=np.float32)
        for i, ID in enumerate(self.list_IDs_temp):
            X_levels[i, :, :, :] = np.full((self.input_dim, self.input_dim, 1), self.list_levels[ID])
        if "removal" in self.task:
            return X_fg, X_org, X_levels
        if "synthesis" in self.task:
            return X_org, X_fg, X_levels
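
# Editor's note (illustrative, not part of the original API): for a "removal"
# task, generate_samples_multilevel returns (grainy batch, clean batch, level
# map); for a "synthesis" task the first two are swapped, so the training loop
# can always read the result as (input_image, target_image, X_level).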

def process_input(input_image, normalize):
    image = Image.open(input_image)
    image = asarray(image, dtype=np.float32)
    image = tf.cast(image, tf.float32)
    image = tf.expand_dims(image, 0, name=None)
    original_shape = image.shape
    # Pad each spatial dimension up to the next multiple of 32 so the network
    # can downsample it cleanly.
    width_add = ceil(original_shape[1] / 32) * 32 - original_shape[1]
    height_add = ceil(original_shape[2] / 32) * 32 - original_shape[2]
    if len(original_shape) != 3:
        paddings = tf.constant([[0, 0], [0, width_add], [0, height_add], [0, 0]])
    else:
        # Grayscale input: (1, H, W) after expand_dims, so no channel axis to pad.
        paddings = tf.constant([[0, 0], [0, width_add], [0, height_add]])
    image = tf.pad(image, paddings, "SYMMETRIC")

    if normalize:
        image = (image / 127.5) - 1
    return image, original_shape
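
# Editor's note (illustrative sketch, not part of the original API): the spatial
# dimensions are rounded up to a multiple of 32 because the U-Net generator
# halves the resolution at five successive stride-2 stages (2**5 = 32), so any
# padded input downsamples cleanly.
def _pad_amount(dim, multiple=32):
    """Return how many pixels process_input adds to reach the next multiple."""
    return ceil(dim / multiple) * multiple - dim  # e.g. _pad_amount(500) == 12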

def pickle_files(path, levels):
    # Walk the "fg/01/" subfolder and derive the paths of the other grain
    # levels and of the clean originals by string replacement.
    arr = []
    for root, dirs, files in os.walk(path + "fg/01/"):
        for file in files:
            if ".tiff" in file:
                arr.append(os.path.join(root, file))

    partition = {}
    org = []
    fg = []
    level = {}

    for i in range(len(arr)):
        name = arr[i]
        for j in range(len(levels)):
            fg.append(name.replace("fg/01/", "fg/" + str(levels[j]).replace(".", "") + "/"))
            org.append(name.replace("fg/01/", "org/"))
            level[name.replace("fg/01/", "fg/" + str(levels[j]).replace(".", "") + "/")] = levels[j]

    # Shuffle the (fg, org) pairs together so they stay aligned.
    c = list(zip(fg, org))
    random.shuffle(c)
    fg, org = zip(*c)

    partition["org"] = org
    partition["fg"] = fg
    list_IDs = "data/list_samples.pickle"
    with open(list_IDs, "wb") as filehandler:
        pickle.dump(partition, filehandler)

    list_levels = "data/list_levels.pickle"
    with open(list_levels, "wb") as filehandler:
        pickle.dump(level, filehandler)
    return list_IDs, list_levels
--------------------------------------------------------------------------------
/src/losses.py:
--------------------------------------------------------------------------------
"""
Copyright (c) 2022 InterDigital R&D France
All rights reserved.
Licensed under BSD 3 clause clear license - check license.txt for more information
"""
import tensorflow as tf
import numpy as np


loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=False)


def generator_loss(disc_generated_output, gen_output, target):
    gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    total_gen_loss = gan_loss + (0.1 * l1_loss)
    return total_gen_loss, gan_loss, l1_loss


def discriminator_loss(disc_real_output, disc_generated_output):
    real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)
    generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)
    total_disc_loss = real_loss + generated_loss
    return total_disc_loss


class LOSS(object):
    def __init__(self, k1=0.01, k2=0.02, L=2, window_size=11, channel=3):
        self.k1 = k1
        self.k2 = k2  # small constants for numerical stability
        self.L = L  # the value range of input image pixels
        self.WS = window_size
        self.channel = channel

    def _tf_fspecial_gauss(self, size, sigma=1.5):
        """Function to mimic the 'fspecial' gaussian MATLAB function"""
        x_data, y_data = np.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1]

        x_data = np.expand_dims(x_data, axis=-1)
        x_data = np.expand_dims(x_data, axis=-1)

        y_data = np.expand_dims(y_data, axis=-1)
        y_data = np.expand_dims(y_data, axis=-1)

        x = tf.constant(x_data, dtype=tf.float32)
        y = tf.constant(y_data, dtype=tf.float32)

        g = tf.exp(-((x**2 + y**2) / (2.0 * sigma**2)))
        return g / tf.reduce_sum(g)

    def ssim_loss(self, img1, img2):
        """
        Calculate the SSIM score between two image batches.
        """
        window = self._tf_fspecial_gauss(size=self.WS)  # output size is (window_size, window_size, 1, 1)

        (_, _, _, self.channel) = img1.shape.as_list()

        window = tf.tile(window, [1, 1, self.channel, 1])

        mu1 = tf.nn.depthwise_conv2d(img1, window, strides=[1, 1, 1, 1], padding='VALID')
        mu2 = tf.nn.depthwise_conv2d(img2, window, strides=[1, 1, 1, 1], padding='VALID')

        mu1_sq = mu1 * mu1
        mu2_sq = mu2 * mu2
        mu1_mu2 = mu1 * mu2

        img1_2 = img1 * img1
        sigma1_sq = tf.subtract(tf.nn.depthwise_conv2d(img1_2, window, strides=[1, 1, 1, 1], padding='VALID'), mu1_sq)
        img2_2 = img2 * img2
        sigma2_sq = tf.subtract(tf.nn.depthwise_conv2d(img2_2, window, strides=[1, 1, 1, 1], padding='VALID'), mu2_sq)
        img12_2 = img1 * img2
        sigma1_2 = tf.subtract(tf.nn.depthwise_conv2d(img12_2, window, strides=[1, 1, 1, 1], padding='VALID'), mu1_mu2)

        c1 = (self.k1 * self.L)**2
        c2 = (self.k2 * self.L)**2

        v1 = 2.0 * sigma1_2 + c2
        v2 = sigma1_sq + sigma2_sq + c2
        cs = v1 / v2
        ssim_map = ((2 * mu1_mu2 + c1) * v1) / ((mu1_sq + mu2_sq + c1) * v2)

        return tf.reduce_mean(ssim_map), tf.reduce_mean(cs)
    def ms_ssim_l1(self, img1, img2, level=5):
        # Standard MS-SSIM per-scale weights.
        weight = tf.constant([0.0448, 0.2856, 0.3001, 0.2363, 0.1333], dtype=tf.float32)
        mssim = []
        mcs = []

        img1_zero = tf.pad(img1, [[0, 0], [5, 5], [5, 5], [0, 0]], "CONSTANT")
        img2_zero = tf.pad(img2, [[0, 0], [5, 5], [5, 5], [0, 0]], "CONSTANT")

        l1 = tf.abs(img1_zero - img2_zero)

        # Reflect padding for the SSIM computation
        img1 = tf.pad(img1, [[0, 0], [5, 5], [5, 5], [0, 0]], "REFLECT")
        img2 = tf.pad(img2, [[0, 0], [5, 5], [5, 5], [0, 0]], "REFLECT")

        for l in range(level):
            ssim_map, cs_map = self.ssim_loss(img1, img2)
            mssim.append(ssim_map)
            mcs.append(cs_map)

            filtered_im1 = tf.nn.avg_pool2d(img1, [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME')
            filtered_im2 = tf.nn.avg_pool2d(img2, [1, 2, 2, 1], [1, 2, 2, 1], padding='SAME')
            img1 = filtered_im1
            img2 = filtered_im2

        mssim = tf.stack(mssim)
        mcs = tf.stack(mcs)

        # Rescale from [-1, 1] to [0, 1] before exponentiation.
        mssim = (mssim + 1) / 2
        mcs = (mcs + 1) / 2

        pow1 = mcs ** weight
        pow2 = mssim ** weight

        ms_loss = 1 - (tf.math.reduce_prod(pow1[:-1]) * pow2[-1])
        window = self._tf_fspecial_gauss(size=self.WS)  # output size is (window_size, window_size, 1, 1)
        window = tf.tile(window, [1, 1, self.channel, 1])
        weighted_l1 = tf.nn.depthwise_conv2d(l1, window, strides=[1, 1, 1, 1], padding='VALID')
        weighted_l1 = tf.reduce_mean(weighted_l1)

        return 0.84 * ms_loss + 0.16 * weighted_l1

    def l1(self, img1, img2):
        return tf.reduce_mean(tf.abs(img1 - img2))
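
# Editor's note (illustrative): ms_ssim_l1 appears to follow the mixed loss of
# Zhao et al., "Loss Functions for Image Restoration with Neural Networks"
# (0.84 * MS-SSIM term + 0.16 * Gaussian-weighted l1). Typical use, assuming
# batches normalized to [-1, 1] (hence L=2, as in src/train.py):
#     loss = LOSS(k1=0.01, k2=0.03, L=2, window_size=11)
#     value = loss.ms_ssim_l1(generated_batch, target_batch)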
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
"""
Copyright (c) 2022 InterDigital R&D France
All rights reserved.
Licensed under BSD 3 clause clear license - check license.txt for more information
"""

import tensorflow as tf
import argparse
from train import trainloop
import json


# ============================ Arguments parsing ============================ #
parser = argparse.ArgumentParser()
parser.add_argument('--task', default='removal_case_1', type=str, help='Specify task')
parser.add_argument('--path', type=str, help='Dataset path')
args = parser.parse_args()

with open('src/config.json', 'r') as config_file:
    data_config = json.load(config_file)

trainloop(data_config, args.task, args.path)
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
"""
Copyright (c) 2022 InterDigital R&D France
All rights reserved.
Licensed under BSD 3 clause clear license - check license.txt for more information
"""

from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2DTranspose, ZeroPadding2D, Conv2D, MaxPooling2D, UpSampling2D
from tensorflow.keras.layers import ReLU, LeakyReLU
from tensorflow.keras.layers import Concatenate, Input
from tensorflow.keras.layers import BatchNormalization, Activation, Add


###### Define encoder block: Conv + BN + ReLU ###########
def encoder_block(inputs, num_filters, batchnorm=True):
    init = RandomNormal(stddev=0.02)
    x = Conv2D(num_filters, kernel_size=(4, 4), strides=(2, 2), padding="same", kernel_initializer=init, use_bias=False)(inputs)
    if batchnorm:
        x = BatchNormalization()(x)
    x = ReLU()(x)
    return x


###### Define decoder block: transposed Conv + BN + ReLU + skip concat ###########
def decoder_block(inputs, skip_in, num_filters, batchnorm=True):
    init = RandomNormal(stddev=0.02)
    x = Conv2DTranspose(num_filters, kernel_size=(4, 4), strides=(2, 2), padding="same", kernel_initializer=init, use_bias=False)(inputs)
    if batchnorm:
        x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Concatenate()([x, skip_in])
    return x


###### Define encoder block for the discriminator: Conv + BN + LeakyReLU ###########
def encoder_block_disc(inputs, num_filters, batchnorm=True):
    init = RandomNormal(stddev=0.02)
    x = Conv2D(num_filters, kernel_size=(4, 4), strides=(2, 2), padding="same", kernel_initializer=init, use_bias=False)(inputs)
    if batchnorm:
        x = BatchNormalization()(x)
    x = LeakyReLU()(x)
    return x


####### Define generator based on U-Net #############
def unet(input_dim=None, output_channel=3, normalized=False):
    init = RandomNormal(stddev=0.02)
    inputs = Input((input_dim, input_dim, output_channel))
    level = Input((input_dim, input_dim, 1))
    x = Concatenate()([inputs, level])

    ####### 5 encoder blocks #########
    s1 = encoder_block(x, 64, batchnorm=False)
    s2 = encoder_block(s1, 128)
    s3 = encoder_block(s2, 256)
    s4 = encoder_block(s3, 512)

    b = encoder_block(s4, 512)

    d4 = decoder_block(b, s4, 512)
    d5 = decoder_block(d4, s3, 256)
    d6 = decoder_block(d5, s2, 128)
    d7 = decoder_block(d6, s1, 64)

    ###### Output ########
    outputs = Conv2DTranspose(output_channel, kernel_size=(4, 4), strides=(2, 2), padding="same", kernel_initializer=init)(d7)
    if normalized:
        outputs = Activation('tanh')(outputs)
    model = Model([inputs, level], outputs, name="U-Net")
    return model


####### Define PatchGAN discriminator #############
def patchgan(input_dim=256, output_channel=3):
    init = RandomNormal(stddev=0.02)
    real = Input((input_dim, input_dim, output_channel))
    fake = Input((input_dim, input_dim, output_channel))
    level = Input((input_dim, input_dim, 1))

    inputs = Concatenate()([real, fake, level])

    a1 = encoder_block_disc(inputs, 64, batchnorm=False)
    a2 = encoder_block_disc(a1, 128)
    a3 = encoder_block_disc(a2, 256)

    a = ZeroPadding2D()(a3)

    x = Conv2D(512, kernel_size=(4, 4), strides=(1, 1), kernel_initializer=init)(a)
    x = BatchNormalization()(x)
    x = ReLU()(x)

    a = ZeroPadding2D()(x)

    output = Conv2D(output_channel, kernel_size=(4, 4), strides=(1, 1), kernel_initializer=init, activation="sigmoid")(a)

    model = Model([real, fake, level], output, name="PatchGAN")
    return model

####### Define generator based on U-Net with residual blocks #############
def res_unet(filter_root, depth, input_dim=None, output_channel=3, normalized=False, blind=False):

    inputs = Input((input_dim, input_dim, output_channel))
    level = Input((input_dim, input_dim, 1))
    x = Concatenate()([inputs, level])

    if blind:
        # Blind variant: no grain level input, the image is used as-is.
        x = inputs

    # Dictionary for long skip connections
    long_connection_store = {}

    # Down sampling
    for i in range(depth):
        out_channel = 2**i * filter_root

        # Residual/skip connection
        res = Conv2D(out_channel, kernel_size=1, padding='same', use_bias=False)(x)

        # First conv block with Conv, BN and activation
        conv1 = Conv2D(out_channel, kernel_size=3, padding='same')(x)
        conv1 = BatchNormalization()(conv1)
        act1 = Activation('relu')(conv1)

        # Second conv block with Conv and BN only
        conv2 = Conv2D(out_channel, kernel_size=3, padding='same')(act1)
        conv2 = BatchNormalization()(conv2)

        # Add + ReLU
        resconnection = Add()([res, conv2])
        act2 = Activation('relu')(resconnection)

        # Max pooling
        if i < depth - 1:
            long_connection_store[str(i)] = act2
            x = MaxPooling2D(padding='same')(act2)
        else:
            x = act2

    # Upsampling
    for i in range(depth - 2, -1, -1):
        out_channel = 2**(i) * filter_root
        long_connection = long_connection_store[str(i)]

        # Upsampling + conv
        up1 = UpSampling2D()(x)
        up_conv1 = Conv2D(out_channel, 2, activation='relu', padding='same')(up1)

        # Long skip connection
        up_conc = Concatenate(axis=-1)([up_conv1, long_connection])

        # Residual/skip connection
        res = Conv2D(out_channel, kernel_size=1, padding='same', use_bias=False)(up_conc)

        # First conv block with Conv, BN and activation
        up_conv2 = Conv2D(out_channel, 3, padding='same')(up_conc)
        up_conv2 = BatchNormalization()(up_conv2)
        up_act1 = Activation('relu')(up_conv2)

        # Second conv block with Conv and BN only
        up_conv2 = Conv2D(out_channel, 3, padding='same')(up_act1)
        up_conv2 = BatchNormalization()(up_conv2)

        # Add + ReLU
        resconnection = Add()([res, up_conv2])
        x = Activation('relu')(resconnection)

    # Final convolution
    output = Conv2D(output_channel, 1, padding='same', name='output')(x)
    if normalized:
        output = Activation('tanh')(output)
    if blind:
        return Model(inputs, outputs=output, name='Res-UNet-Blind')
    return Model([inputs, level], outputs=output, name='Res-UNet')
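
# Editor's note (illustrative usage, mirroring the calls in src/train.py):
#     generator = unet(256, 3, normalized=True)                      # removal_case_1
#     generator = res_unet(64, 5, 256, 3, normalized=False)          # synthesis_case_3
#     discriminator = patchgan(256, 3)
#     grainy = generator([image_batch, level_map])                   # level-conditioned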
--------------------------------------------------------------------------------
/src/test.py:
--------------------------------------------------------------------------------
"""
Copyright (c) 2022 InterDigital R&D France
All rights reserved.
Licensed under BSD 3 clause clear license - check license.txt for more information
"""

import logging
logging.getLogger('tensorflow').setLevel(logging.FATAL)
import argparse
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
import tensorflow as tf
from tensorflow.keras.models import load_model
import numpy as np
import imageio
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
from data import process_input


def main():

    # ----------------------------------------
    # Preparation
    # ----------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--pretrained_model', type=str, help='path/of/pretrained/model')
    parser.add_argument('--level', type=float, default=0.01, help='0.01, 0.025, 0.05, 0.075, 0.1')
    parser.add_argument('--input_path', type=str, help='path/to/test/set')
    parser.add_argument('--output_path', type=str, help='path/to/test/result')

    args = parser.parse_args()

    list_levels = [0.01, 0.025, 0.05, 0.075, 0.1]

    if ".h5" not in args.pretrained_model:
        raise argparse.ArgumentTypeError(
            'Incorrect pretrained model format')

    if args.level not in list_levels:
        raise argparse.ArgumentTypeError(
            'Specified level not supported, supported levels = {0.01, 0.025, 0.05, 0.075, 0.1}')

    if "removal" in args.pretrained_model:
        print("Task : Film grain removal")
    else:
        print("Task : Film grain synthesis")

    print("Model : " + args.pretrained_model)
    level = str(args.level).replace(".", "")
    print("Level : " + level)

    # ----------------------------------------
    # Input and output paths
    # ----------------------------------------
    L_path = args.input_path
    E_path = args.output_path

    if not os.path.exists(E_path):
        os.makedirs(E_path)
    paths = os.listdir(L_path)

    ## ----------------------------------------
    ## load model
    ## ----------------------------------------
    model = load_model(args.pretrained_model, compile=False)

    ## ----------------------------------------
    ## Test model
    ## ----------------------------------------
    if "removal" in args.pretrained_model:
        for idx, img in enumerate(paths):
            image_name = L_path + img
            image_L, original_shape = process_input(image_name, True)
            if 'non_blind' in args.pretrained_model:
                # Non-blind model: feed the grain level as a constant map.
                shape = image_L.shape
                X_levels = np.empty((1, shape[1], shape[2], 1), dtype=np.float32)
                X_levels[0, :, :, :] = np.full((shape[1], shape[2], 1), args.level)
                image_E = model([image_L, X_levels], training=False)
            else:
                image_E = model(image_L, training=False)

            # Crop the padding away and map [-1, 1] back to 8-bit.
            image_E = image_E[0, :original_shape[1], :original_shape[2], :]
            image_E = np.array(image_E) * 0.5 + 0.5
            image_E = np.clip(image_E, 0.0, 1.0) * 255
            image_E = image_E.astype(np.uint8)
            imageio.imsave(E_path + img, image_E)

    if "synthesis" in args.pretrained_model:
        for idx, img in enumerate(paths):
            image_name = L_path + img
            image_L, original_shape = process_input(image_name, False)
            shape = image_L.shape
            X_levels = np.empty((1, shape[1], shape[2], 1), dtype=np.float32)
            X_levels[0, :, :, :] = np.full((shape[1], shape[2], 1), args.level)
            image_E = model([image_L, X_levels], training=True)
            image_E = image_E[0, :original_shape[1], :original_shape[2], :]
            image_E = np.clip(image_E, 0.0, 255.0)
            image_E = image_E.astype(np.uint8)
            imageio.imsave(E_path + img, image_E)


if __name__ == '__main__':
    main()
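
# Editor's note: the task is inferred from substrings of the model file name
# ("removal", "synthesis", "non_blind"), so pretrained models must keep those
# markers in their names, as produced by src/train.py (<task>_<epoch>.h5).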
--------------------------------------------------------------------------------
/src/test_one.py:
--------------------------------------------------------------------------------
"""
Copyright (c) 2022 InterDigital R&D France
All rights reserved.
Licensed under BSD 3 clause clear license - check license.txt for more information
"""

import logging
logging.getLogger('tensorflow').setLevel(logging.FATAL)
import argparse
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
import tensorflow as tf
from tensorflow.keras.models import load_model
import numpy as np
import imageio
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
from data import process_input


def main():

    # ----------------------------------------
    # Preparation
    # ----------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--pretrained_model', type=str, default='synthesis', help='path to the pretrained model (synthesis_RGB, synthesis_Y, removal_blind_RGB, removal_blind_Y, removal_non_blind_RGB, removal_non_blind_Y)')
    parser.add_argument('--level', type=float, default=0.01, help='0.01, 0.025, 0.05, 0.075, 0.1')
    parser.add_argument('--input_path', type=str)
    parser.add_argument('--output_path', type=str)
    args = parser.parse_args()

    list_levels = [0.01, 0.025, 0.05, 0.075, 0.1]

    if ".h5" not in args.pretrained_model:
        raise argparse.ArgumentTypeError(
            'Incorrect pretrained model format')

    if args.level not in list_levels:
        raise argparse.ArgumentTypeError(
            'Specified level not supported, supported levels = {0.01, 0.025, 0.05, 0.075, 0.1}')

    if not os.path.exists(args.output_path):
        os.makedirs(args.output_path)

    ## ----------------------------------------
    ## load model
    ## ----------------------------------------
    model = load_model(args.pretrained_model, compile=False)

    ## ----------------------------------------
    ## Test model
    ## ----------------------------------------
    if "removal" in args.pretrained_model:
        image_L, original_shape = process_input(args.input_path, True)
        if "non_blind" in args.pretrained_model:
            # Non-blind model: feed the grain level as a constant map.
            shape = image_L.shape
            X_levels = np.empty((1, shape[1], shape[2], 1), dtype=np.float32)
            X_levels[0, :, :, :] = np.full((shape[1], shape[2], 1), args.level)
            image_E = model([image_L, X_levels], training=False)
        else:
            image_E = model(image_L, training=True)
        # Crop the padding away and map [-1, 1] back to 8-bit.
        image_E = image_E[0, :original_shape[1], :original_shape[2], :]
        image_E = np.array(image_E) * 0.5 + 0.5
        image_E = np.clip(image_E, 0.0, 1.0) * 255
        image_E = image_E.astype(np.uint8)
        imageio.imsave(args.output_path + os.path.basename(args.input_path), image_E)

    if "synthesis" in args.pretrained_model:
        image_L, original_shape = process_input(args.input_path, False)
        shape = image_L.shape
        X_levels = np.empty((1, shape[1], shape[2], 1), dtype=np.float32)
        X_levels[0, :, :, :] = np.full((shape[1], shape[2], 1), args.level)
        image_E = model([image_L, X_levels], training=True)
        image_E = image_E[0, :original_shape[1], :original_shape[2], :]
        image_E = np.clip(image_E, 0.0, 255.0)
        image_E = image_E.astype(np.uint8)
        imageio.imsave(args.output_path + os.path.basename(args.input_path), image_E)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
"""
Copyright (c) 2022 InterDigital R&D France
All rights reserved.
Licensed under BSD 3 clause clear license - check license.txt for more information
"""

import logging
import os


logging.getLogger('tensorflow').setLevel(logging.FATAL)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
from warnings import simplefilter
simplefilter(action='ignore', category=FutureWarning)
from data import Data, pickle_files
from model import *
from losses import *
from utils import *
import tensorflow as tf
import json
import sys


def trainloop(data_config, task, path):
    if not os.path.exists("train_output"):
        os.makedirs("train_output")
    logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.FileHandler("debug.log"), logging.StreamHandler(sys.stdout)])

    if "gray" in task:
        output_channel = 1
        path_channel = "Y"
    else:
        output_channel = 3
        path_channel = "RGB"

    logging.info('Data loading')

    list_IDs, list_levels = pickle_files(path, data_config['levels'])

    DATA = Data(list_IDs, list_levels, data_config['input_dim'], data_config['batch_size'], task)
    list_IDs_org, list_IDs_fg, list_levels = DATA.load_pkl_multilevel(list_IDs, list_levels)
    num_batches = DATA.batches()
    logging.info('Batches per epoch: %d', num_batches)
    epochs = data_config["epochs"]
    input_dim = data_config["input_dim"]
    batch_size = data_config["batch_size"]
    # Adam(learning_rate, beta_1): the "alpha" config entries are learning rates,
    # "learning_rate_beta_gen" is the beta_1 momentum term of both optimizers.
    generator_optimizer = tf.keras.optimizers.Adam(data_config['learning_rate_alpha_gen'], data_config['learning_rate_beta_gen'])
    discriminator_optimizer = tf.keras.optimizers.Adam(data_config['learning_rate_alpha_dis'], data_config['learning_rate_beta_gen'])
    # L=2 because removal models work on inputs normalized to [-1, 1].
    loss = LOSS(k1=0.01, k2=0.03, L=2, window_size=11)


    logging.info('Model building')
    GAN = False

    if "synthesis" in task:
        gen_loss = generator_loss
        disc_loss = discriminator_loss
        normalized = False

        if "case_1" in task:  # U-Net optimized with l1
            generator = unet(input_dim, output_channel, normalized)
            gen_loss = loss.l1

        elif "case_2" in task:  # cGAN with U-Net as backbone for generator
            generator = unet(input_dim, output_channel, normalized)
            discriminator = patchgan(input_dim, output_channel)
            GAN = True

        elif "case_3_gray" in task:  # grayscale cGAN with U-Net + residual blocks as backbone
            # Checked before "case_3", which is a substring of "case_3_gray".
            generator = res_unet(64, 5, input_dim, output_channel, normalized, blind=False)
            discriminator = patchgan(input_dim, output_channel)
            GAN = True

        elif "case_3" in task:  # cGAN with U-Net + residual blocks as backbone for generator
            generator = res_unet(64, 5, input_dim, output_channel, normalized, blind=False)
            discriminator = patchgan(input_dim, output_channel)
            GAN = True
        generator.summary()

    elif "removal" in task:
        normalized = True
        blind = False
        if "case_1" in task:  # U-Net optimized with l1
            generator = unet(input_dim, output_channel, normalized)
            gen_loss = loss.l1

        elif "case_2" in task:  # U-Net with residual blocks optimized with l1
            generator = res_unet(64, 5, input_dim, output_channel, normalized, blind)
            gen_loss = loss.l1
        elif "case_3_gray" in task:  # grayscale variant of case_3, checked before "case_3"
            generator = res_unet(64, 5, input_dim, output_channel, normalized, blind)
            gen_loss = loss.ms_ssim_l1

        elif "case_4_gray" in task:  # blind grayscale variant, checked before "case_4"
            blind = True
            generator = res_unet(64, 5, input_dim, output_channel, normalized, blind)
            gen_loss = loss.ms_ssim_l1

        elif "case_3" in task:  # U-Net with residual blocks optimized with a mix of l1 and MS-SSIM
            generator = res_unet(64, 5, input_dim, output_channel, normalized, blind)
            gen_loss = loss.ms_ssim_l1

        elif "case_4" in task:  # same as case_3 but blind: no grain level input
            blind = True
            generator = res_unet(64, 5, input_dim, output_channel, normalized, blind)
            gen_loss = loss.ms_ssim_l1
        generator.summary()


    for epoch in range(epochs):
        print("Epoch : " + str(epoch))
        for batch_num in range(num_batches):
            input_image, target_image, X_level = DATA.generate_samples_multilevel(batch_num)
            if "gray" in task:
                input_image = tf.reshape(input_image, (batch_size, 256, 256, 1))
                target_image = tf.reshape(target_image, (batch_size, 256, 256, 1))
                X_level = tf.reshape(X_level, (batch_size, 256, 256, 1))

            if GAN:
                with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                    # Forward
                    gen_output = generator([input_image, X_level], training=True)
                    disc_real_output = discriminator([input_image, target_image, X_level], training=True)
                    disc_generated_output = discriminator([input_image, gen_output, X_level], training=True)

                    # Loss computing
                    gen_total_loss, gen_gan_loss, gen_l1_loss = gen_loss(disc_generated_output, gen_output, target_image)
                    disc_total_loss = disc_loss(disc_real_output, disc_generated_output)

                # Gradients computing
                generator_gradients = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
                discriminator_gradients = disc_tape.gradient(disc_total_loss, discriminator.trainable_variables)

                # Gradients application
                generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))
                discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))

            else:
                with tf.GradientTape() as gen_tape:
                    if "case_4" in task:
                        # Blind model: no grain level input.
                        gen_output = generator(input_image, training=True)
                    else:
                        gen_output = generator([input_image, X_level], training=True)

                    gen_total_loss = gen_loss(gen_output, target_image)

                generator_gradients = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
                generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))

            if batch_num % 20 == 0:
                if "removal" in task:
                    visualize_org(input_image, gen_output, target_image, X_level, task)
                elif "synthesis" in task:
                    visualize_fg(input_image, gen_output, target_image, X_level, task)
        generator.save("train_output/" + task + "_" + str(epoch) + ".h5")
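
# Editor's note: generator.save(...) writes one Keras .h5 checkpoint per epoch
# under train_output/; these are the files that src/test.py and src/test_one.py
# load with tensorflow.keras.models.load_model(path, compile=False).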
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
"""
Copyright (c) 2022 InterDigital R&D France
All rights reserved.
Licensed under BSD 3 clause clear license - check license.txt for more information
"""


import numpy as np
from matplotlib import pyplot as plt


def visualize_fg(input_image, gen_output, target_image, X_level, name):
    # Synthesis visualization: tensors are in the [0, 255] range.
    prediction = np.clip(gen_output[0], 0.0, 255.0)
    prediction = prediction.astype(np.uint8)

    target_input = target_image[0]
    target_input = np.clip(target_input, 0.0, 255.0)
    target_input = target_input.astype(np.uint8)

    test_input = input_image[0]
    test_input = np.clip(test_input, 0.0, 255.0)
    test_input = test_input.astype(np.uint8)

    level = X_level[0, 0, 0]
    level = level[0]
    plt.figure(figsize=(15, 15))

    display_list = [test_input, prediction, target_input]
    title = ['Input', 'Predicted', 'Target (level = ' + str(level) + ')']

    for i in range(3):
        plt.subplot(1, 3, i + 1)
        plt.title(title[i])
        # squeeze drops the trailing channel axis of grayscale (H, W, 1) images,
        # which imshow does not accept; it is a no-op for RGB images.
        plt.imshow(np.squeeze(display_list[i]))
        plt.axis('off')
    plt.savefig("train_output/test_" + name + ".tiff")
    plt.close()


def visualize_org(input_image, gen_output, target_image, X_level, name):
    # Removal visualization: tensors are in [-1, 1], map back to 8-bit.
    prediction = gen_output[0] * 0.5 + 0.5
    prediction = np.clip(prediction, 0.0, 1.0) * 255
    prediction = prediction.astype(np.uint8)

    target_input = target_image[0] * 0.5 + 0.5
    target_input = np.clip(target_input, 0.0, 1.0) * 255
    target_input = target_input.astype(np.uint8)

    test_input = input_image[0] * 0.5 + 0.5
    test_input = np.clip(test_input, 0.0, 1.0) * 255
    test_input = test_input.astype(np.uint8)

    level = X_level[0, 0, 0]
    level = level[0]
    plt.figure(figsize=(15, 15))

    display_list = [test_input, prediction, target_input]
    title = ['Input (level = ' + str(level) + ')', 'Predicted', 'Target']

    for i in range(3):
        plt.subplot(1, 3, i + 1)
        plt.title(title[i])
        plt.imshow(np.squeeze(display_list[i]))
        plt.axis('off')
    plt.savefig("train_output/test_" + name + ".tiff")
    plt.close()
--------------------------------------------------------------------------------