├── config ├── __pycache__ │ ├── config.cpython-37.pyc │ └── test_config.cpython-37.pyc ├── config.ini └── config.py ├── run.py ├── MIT License ├── test.py ├── README.md ├── inpaint.py └── rmnet.py /config/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jireh-Jam/R-MNet-Inpainting-keras/HEAD/config/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /config/__pycache__/test_config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jireh-Jam/R-MNet-Inpainting-keras/HEAD/config/__pycache__/test_config.cpython-37.pyc -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Jan 24 19:57:01 2021 4 | 5 | @author: Jireh Jam 6 | """ 7 | 8 | from rmnet import RMNETWGAN 9 | from config import config 10 | 11 | CONFIG_FILE = './config/config.ini' 12 | config = config.MainConfig(CONFIG_FILE).training 13 | 14 | if __name__ == '__main__': 15 | r_mnetwgan = RMNETWGAN(config) 16 | r_mnetwgan.train() 17 | -------------------------------------------------------------------------------- /config/config.ini: -------------------------------------------------------------------------------- 1 | [TESTING] 2 | BATCH_SIZE = 1 3 | IMG_HEIGHT = 256 4 | IMG_WIDTH = 256 5 | CHANNELS = 3 6 | MASK_CHANNELS = 1 7 | IMGS_IN_BATCH = 1 8 | LAST_IMG_ON = 0 9 | NUM_EPOCHS = 100 10 | GEN_LEARNING_RATE = 1E-4 11 | DISC_LEARNING_RATE = 1E-12 12 | GEN_FILTER = 64 13 | DISC_FILTER=64 14 | CURRNT_EPOCH = 0 15 | SAMPLE_INTERVAL = 10 16 | BETA_1 = 0.9 17 | BETA_2 = 0.999 18 | EPSILON= 1e-08 19 | LAST_TRAINED_EPOCH=99 20 | 21 | [TRAINING] 22 | BATCH_SIZE = 5 23 | IMG_HEIGHT = 256 24 | IMG_WIDTH = 256 25 | 
NUM_CHANNELS = 3 26 | MASK_CHANNELS = 1 27 | NUM_EPOCHS = 100 28 | GEN_LEARNING_RATE = 1E-4 29 | DISC_LEARNING_RATE = 1E-12 30 | GEN_FILTER = 64 31 | DISC_FILTER=64 32 | CURRNT_EPOCH = 0 33 | SAMPLE_INTERVAL = 10 34 | BETA_1 = 0.9 35 | BETA_2 = 0.999 36 | EPSILON= 1e-08 37 | LAST_TRAINED_EPOCH=0 38 | -------------------------------------------------------------------------------- /MIT License: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jireh Jam 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
# -*- coding: utf-8 -*-
"""
Created on Mon Jan 25 14:24:02 2021

@author: Jireh Jam
"""

from configparser import ConfigParser

# Each entry is (attribute name, ini key, converter). The tuples are kept in
# the order in which the options are read from the section.
_SIZE_FIELDS = (
    ('batch_size', 'BATCH_SIZE', int),
    ('img_height', 'IMG_HEIGHT', int),
    ('img_width', 'IMG_WIDTH', int),
)

_OPTIMISATION_FIELDS = (
    ('num_epochs', 'NUM_EPOCHS', int),
    ('g_learning_rate', 'GEN_LEARNING_RATE', float),
    ('d_learning_rate', 'DISC_LEARNING_RATE', float),
    ('gf', 'GEN_FILTER', int),
    ('df', 'DISC_FILTER', int),
    # The ini key really is spelled "CURRNT_EPOCH".
    ('current_epoch', 'CURRNT_EPOCH', int),
    ('sample_interval', 'SAMPLE_INTERVAL', int),
    ('beta_1', 'BETA_1', float),
    ('beta_2', 'BETA_2', float),
    ('epsilon', 'EPSILON', float),
    ('last_trained_epoch', 'LAST_TRAINED_EPOCH', int),
)


def _populate(target, section, fields):
    """Convert every listed option of *section* and store it on *target*.

    A missing option raises ``KeyError`` and a malformed number raises
    ``ValueError``, exactly as a direct ``int(section[key])`` would.
    """
    for attribute, ini_key, convert in fields:
        setattr(target, attribute, convert(section[ini_key]))


class MainConfig:
    """Typed view of the inpainting ini file.

    Exposes the ``[TESTING]`` section as ``self.testing`` and the
    ``[TRAINING]`` section as ``self.training``.
    """

    def __init__(self, file_path):
        """Parse *file_path*.

        Raises:
            ValueError: if the file could not be read.
            KeyError: if a section or option is missing.
        """
        parser = ConfigParser()
        # ConfigParser.read() skips unreadable files and returns the list of
        # files it actually parsed; an empty list means nothing was loaded.
        if not parser.read(file_path):
            raise ValueError('No config file found!')
        print(parser.sections())
        self.testing = TestingConfig(parser['TESTING'])
        self.training = TrainingConfig(parser['TRAINING'])


class TestingConfig:
    """Settings read from the ``[TESTING]`` section."""

    _FIELDS = _SIZE_FIELDS + (
        ('channels', 'CHANNELS', int),
        ('mask_channels', 'MASK_CHANNELS', int),
        ('imgs_in_batch', 'IMGS_IN_BATCH', int),
        ('last_img_on', 'LAST_IMG_ON', int),
    ) + _OPTIMISATION_FIELDS

    def __init__(self, test_section):
        _populate(self, test_section, self._FIELDS)


class TrainingConfig:
    """Settings read from the ``[TRAINING]`` section."""

    _FIELDS = _SIZE_FIELDS + (
        # The training section names the channel count NUM_CHANNELS.
        ('channels', 'NUM_CHANNELS', int),
        ('mask_channels', 'MASK_CHANNELS', int),
    ) + _OPTIMISATION_FIELDS

    def __init__(self, training_section):
        _populate(self, training_section, self._FIELDS)
(last_img_on + imgs_in_batch) >= len(test_imgs_path): 36 | imgs_in_batch = len(test_imgs_path)-last_img_on 37 | imgs = np.zeros((config.imgs_in_batch,config.img_width,config.img_height,config.channels)) 38 | masked_imgs = np.zeros((config.imgs_in_batch,config.img_width,config.img_height,config.channels)) 39 | masks = np.zeros((config.imgs_in_batch,config.img_width,config.img_height,config.mask_channels)) 40 | idx = 0 41 | for i in range(imgs_in_batch): 42 | print("\rloading Image " + str(i) + ' of ' +str(len(test_imgs_path)), end=" ") 43 | img = (cv2.imread(test_img_dir+test_imgs_path[last_img_on],1)) 44 | img = cv2.resize(img,(config.img_width, config.img_height)) 45 | img = img[..., [2, 1, 0]] 46 | img = (img - 127.5) / 127.5 47 | mask = (cv2.imread(test_mask_dir+test_masks_path[last_img_on],0)) 48 | mask[mask == 255] = 1 49 | mask = cv2.resize(mask,(config.img_width, config.img_height)) 50 | mask = np.reshape(mask,(config.img_width,config.img_height,config.mask_channels)) 51 | masks[i] = mask 52 | masked_imgs[i] = deepcopy(img) 53 | masked_imgs[i][np.where((mask == [1,1,1]).all(axis=2))]=[255,255,255] 54 | imgs[i] = img 55 | last_img_on += 1 56 | idx+=1 57 | return last_img_on, imgs,masks,masked_imgs 58 | 59 | #Inpaint imgaes 60 | def inpaint(): 61 | imgs_in_batch = config.imgs_in_batch 62 | last_img_on =config.last_img_on 63 | rmnet_model = RMNETWGAN(config) 64 | #Edit last_trained_epoch in config.ini 65 | rmnet_model.generator.load_weights('{}/weight_{}.h5'.format(trained_model_path,config.last_trained_epoch)) 66 | for i in range(len(test_imgs_path)): 67 | if not os.path.exists(inpainted_dir): 68 | os.makedirs(inpainted_dir) 69 | if not os.path.exists(imgs_dir): 70 | os.makedirs(imgs_dir) 71 | if not os.path.exists(masked_dir): 72 | os.makedirs(masked_dir) 73 | last_img_on, imgs, masks,masked_imgs = generate_test_batch(last_img_on, imgs_in_batch) 74 | gen_imgs = rmnet_model.generator.predict([imgs,masks], config.batch_size) 75 | gen_imgsRGB = 
gen_imgs[:,:,:,0:3] 76 | input_img = np.expand_dims(imgs[0], 0) 77 | input_mask = np.expand_dims(masks[0], 0) 78 | maskedImg = ((1 - input_mask)*input_img) + input_mask 79 | cv2.imwrite(r'./' + imgs_dir +'/'+str(i) +'.jpg',(imgs[0][..., [2, 1, 0]]* 127.5 + 127.5).astype("uint8")) 80 | cv2.imwrite(r'./' + masked_dir +'/' +str(i) +'.jpg',(maskedImg[0][..., [2, 1, 0]]* 127.5 + 127.5).astype("uint8")) 81 | cv2.imwrite(r'./' + inpainted_dir +'/' +str(i) +'.jpg',(gen_imgsRGB[0][..., [2, 1, 0]]* 127.5 + 127.5).astype("uint8")) 82 | 83 | if __name__=='__main__': 84 | inpaint() 85 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![](https://img.shields.io/badge/Python-3.6-yewllo.svg) ![](https://img.shields.io/badge/Keras-2.3.1-yewllo.svg) ![](https://img.shields.io/badge/TensorFlow-1.13.1-yewllo.svg) ![](https://img.shields.io/badge/License-MIT-yewllo.svg) 2 | # R-MNET-A-Perceptual-Adversarial-Network-for-Image-Inpainting in Keras 3 | R-MNET: A Perceptual Adversarial Network for Image Inpainting. 4 | Jireh Jam, Connah Kendrick, Vincent Drouard, Kevin Walker, Gee-Sern Hsu, Moi Hoon Yap 5 | ### 6 | Keras implementation of R-MNET model proposed at WACV2021. 7 | ### 8 | https://arxiv.org/pdf/2008.04621.pdf 9 | 10 | 11 | ### Architecture 12 | 13 | 14 | ## Requirements 15 | ### Download Trained Model For Inference 16 | Download pre-trained model and create a director in the order "models/RMNet_WACV2021/" and save the pre-trained weight here before running the inpaint.py file. Note that we used quickdraw mask dataset and this can be altererd accordingly as per the script. All instructions are there. 
17 | [Download CelebA-HQ](https://drive.google.com/drive/folders/1ZzswYSyCs4Z3pyR1feVJ6EfkBPhw9jf5?usp=sharing) 18 | ### Images dataset 19 | [Download Places2 Dataset](http://data.csail.mit.edu/places/places365/places365standard_easyformat.tar) and [CelebA-HQ Dataset](https://github.com/willylulu/celeba-hq-modified) 20 | ### Mask dataset 21 | The training mask dataset used for training our model: [QD-IMD: Quick Draw Irregular Mask Dataset](https://github.com/karfly/qd-imd) 22 | NVIDIA's mask dataset is available [here](https://nv-adlr.github.io/publication/partialconv-inpainting) 23 | ### Folder structure 24 | After downloading the datasets, create the folders `/images/train/train_images` and `/masks/train/train_masks`. Place the images and masks in train_images and train_masks respectively, so that the layout looks like this: 25 | 26 | ``` 27 | -- images 28 | ---- train 29 | ------ train_images 30 | ---- celebA_HQ_test 31 | -- masks 32 | ---- train 33 | ------ train_masks 34 | ---- test_masks 35 | ``` 36 | That is, create /images/train/train_images and /masks/train/train_masks, and place the images and masks in train_images and train_masks respectively. 37 | Make sure the directory paths are 38 | 39 | ``` 40 | --self.train_mask_dir='./masks/train/' 41 | --self.train_img_dir = './images/train/' 42 | --test_img_dir ='./images/celebA_HQ_test/' 43 | --test_mask_dir ='./masks/test_masks/' 44 | ``` 45 | ### Python requirements 46 | - Python 3.6 47 | - Tensorflow 1.13.1 48 | - keras 2.3.1 49 | - opencv 50 | - Numpy 51 | 52 | ### Training and Testing scripts. 53 | Use the run.py file to train the model and inpaint.py to test the model. We recommend training for 100 epochs as a benchmark based on the state-of-the-art used to compare with our model. 54 | ## Code Reference 55 | 1. Wasserstein GAN was implemented based on: [Wasserstein GAN Keras](https://github.com/eriklindernoren/Keras-GAN/blob/master/wgan/wgan.py) 56 | 2.
Generative Multi-column Convolutional Neural Networks inpainting model in Keras : [Image Inpainting via Generative Multi-column Convolutional Neural Networks](https://github.com/tlatkowski/inpainting-gmcnn-keras/) 57 | 3. Nvidia Mask Dataset, based on the paper: [Image Inpainting for Irregular Holes Using Partial Convolutions](https://eccv2018.org/openaccess/content_ECCV_2018/papers/Guilin_Liu_Image_Inpainting_for_ECCV_2018_paper.pdf) 58 | ## Citing this script 59 | If you use this script, please consider citing [R-MNet](https://openaccess.thecvf.com/content/WACV2021/papers/Jam_R-MNet_A_Perceptual_Adversarial_Network_for_Image_Inpainting_WACV_2021_paper.pdf): 60 | ``` 61 | @inproceedings{jam2021r, 62 | title={R-mnet: A perceptual adversarial network for image inpainting}, 63 | author={Jam, Jireh and Kendrick, Connah and Drouard, Vincent and Walker, Kevin and Hsu, Gee-Sern and Yap, Moi Hoon}, 64 | booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision}, 65 | pages={2714--2723}, 66 | year={2021} 67 | } 68 | ``` 69 | ``` 70 | @article{jam2020r, 71 | title={R-MNet: A Perceptual Adversarial Network for Image Inpainting}, 72 | author={Jam, Jireh and Kendrick, Connah and Drouard, Vincent and Walker, Kevin and Hsu, Gee-Sern and Yap, Moi Hoon}, 73 | journal={arXiv preprint arXiv:2008.04621}, 74 | year={2020} 75 | } 76 | ``` 77 | 78 | 79 | -------------------------------------------------------------------------------- /inpaint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from copy import deepcopy 5 | from rmnet import RMNETWGAN 6 | from config import config 7 | import random 8 | 9 | # =================================================================================== # 10 | # 1. 
# Config Loader
# NOTE: this rebinds the imported ``config`` module name to the parsed
# [TESTING] settings object; everything below reads settings from it.
CONFIG_FILE = './config/config.ini'
config = config.MainConfig(CONFIG_FILE).testing

# =================================================================================== #
#                               2. Data params                                        #
# =================================================================================== #

test_img_dir ='./images/celebA_HQ_test/'
test_mask_dir ='./masks/test_masks/'
# os.listdir() returns entries in arbitrary order. Images and masks are paired
# by position, so sort both listings to make the pairing reproducible.
test_imgs_path = sorted(os.listdir(test_img_dir))
test_masks_path = sorted(os.listdir(test_mask_dir))

# =================================================================================== #
#                               3. Directories                                        #
# =================================================================================== #

imgs_dir = 'real_images_rmnet'
masked_dir = 'masked_images_rmnet'
inpainted_dir = 'inpainted_images_rmnet'
trained_model_path = r'./models/RMNet_WACV2021'
# Data Loader

def generate_test_batch(last_img_on, imgs_in_batch):
    """Load the next test images and their masks into fixed-size batch arrays.

    Args:
        last_img_on: index into ``test_imgs_path`` of the first image to load.
        imgs_in_batch: number of images requested; clamped to what is left.

    Returns:
        ``(last_img_on, imgs, masks, masked_imgs)`` where ``last_img_on`` is
        the index following the last image loaded. The arrays always have
        ``config.imgs_in_batch`` rows; rows that were not loaded stay zero.
        ``imgs`` is scaled to [-1, 1]; ``masks`` holds ``1 - file / 255``;
        ``masked_imgs`` is the image with every pixel whose mask value is
        exactly 1 painted with the value 1.

    Raises:
        OSError: if an image or mask cannot be read, or no masks exist.
    """
    remaining = len(test_imgs_path) - last_img_on
    # Never read past the end of the test set; an exhausted set loads nothing.
    imgs_in_batch = max(0, min(imgs_in_batch, remaining))
    if imgs_in_batch and not test_masks_path:
        raise OSError("No mask files found in " + test_mask_dir)

    # NOTE(review): arrays are laid out as (width, height) while cv2.resize
    # returns (height, width); this only lines up for square sizes -- confirm
    # before using a non-square config.
    batch_shape = (config.imgs_in_batch, config.img_width, config.img_height)
    imgs = np.zeros(batch_shape + (config.channels,))
    masked_imgs = np.zeros(batch_shape + (config.channels,))
    masks = np.zeros(batch_shape + (config.mask_channels,))

    for i in range(imgs_in_batch):
        print("\rLoading image number "+ str(i) + " of " + str(len(test_imgs_path)), end = " ")

        img_file = test_img_dir + test_imgs_path[last_img_on]
        raw_img = cv2.imread(img_file, 1)
        # cv2.imread signals failure by returning None instead of raising.
        if raw_img is None:
            raise OSError("Could not read test image: " + img_file)
        img = raw_img.astype('float') / 127.5 - 1
        img = cv2.resize(img, (config.img_width, config.img_height))

        # Reuse masks cyclically when there are fewer masks than images.
        mask_file = test_mask_dir + test_masks_path[last_img_on % len(test_masks_path)]
        raw_mask = cv2.imread(mask_file, 0)
        if raw_mask is None:
            raise OSError("Could not read test mask: " + mask_file)
        #If Mask regions are white, DO NOT subtract from 1.
        #If mask regions are black, subtract from 1.
        mask = 1 - raw_mask.astype('float') / 255
        mask = cv2.resize(mask, (config.img_width, config.img_height))
        mask = np.reshape(mask, (config.img_width, config.img_height, config.mask_channels))

        masks[i] = mask
        imgs[i] = img
        # Slice assignment copies the pixel data, so no explicit copy is needed.
        masked_imgs[i] = img
        masked_imgs[i][(mask == 1).all(axis=2)] = 1
        last_img_on += 1
    return last_img_on, imgs, masks, masked_imgs
# Inpainting driver

def inpaint():
    """Inpaint every test image once and save the results as JPEG files.

    Loads the generator weights for ``config.last_trained_epoch`` from
    ``trained_model_path``, then walks the test set in batches of
    ``config.imgs_in_batch`` starting at ``config.last_img_on``. For each
    image loaded, three files numbered from 0 are written: the generator
    output (``inpainted_dir``), the masked input (``masked_dir``) and the
    original (``imgs_dir``).
    """
    rmnet_model = RMNETWGAN(config)
    #Edit last_trained_epoch in config.ini
    rmnet_model.generator.load_weights('{}/weight_{}.h5'.format(trained_model_path,config.last_trained_epoch))

    # Create the output folders once, before any image is processed.
    for out_dir in (inpainted_dir, imgs_dir, masked_dir):
        os.makedirs(out_dir, exist_ok=True)

    def to_uint8(image):
        # Map from the [-1, 1] range used by the model back to 8-bit pixels.
        return ((image + 1) * 127.5).astype("uint8")

    imgs_in_batch = config.imgs_in_batch
    last_img_on = config.last_img_on
    saved = 0
    # Stop when the test set is exhausted instead of running a fixed number of
    # iterations, which kept predicting on empty (all-zero) batches.
    while last_img_on < len(test_imgs_path):
        batch_start = last_img_on
        last_img_on, imgs, masks, masked_imgs = generate_test_batch(last_img_on, imgs_in_batch)
        loaded = last_img_on - batch_start
        if loaded <= 0:
            # A batch size of zero would otherwise loop forever.
            break
        gen_imgs = rmnet_model.generator.predict([imgs,masks], config.batch_size)
        # The generator appends the mask after the RGB channels; keep RGB only.
        gen_imgsRGB = gen_imgs[:,:,:,0:3]
        # Save every image that was actually loaded, not just the first one.
        for k in range(loaded):
            file_name = "%d.jpg" % saved
            cv2.imwrite(os.path.join(inpainted_dir, file_name), to_uint8(gen_imgsRGB[k]))
            cv2.imwrite(os.path.join(masked_dir, file_name), to_uint8(masked_imgs[k]))
            cv2.imwrite(os.path.join(imgs_dir, file_name), to_uint8(imgs[k]))
            saved += 1


if __name__=='__main__':
    inpaint()
from keras.applications import VGG19
from keras.layers import Input, Dense, Flatten, Dropout, Concatenate, Multiply, Lambda, Add
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D,MaxPooling2D,Conv2DTranspose
from keras.models import Model
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import multi_gpu_model
from keras import backend as K

import tensorflow as tf

import cv2
import matplotlib.pyplot as plt
import numpy as np

import os
import datetime
import time
import gc
import random


class RMNETWGAN():
    """R-MNet inpainting model: a generator, a discriminator and their stack."""

    def __init__(self,config):
        """Build and compile the discriminator, generator and combined model.

        Args:
            config: settings object providing ``img_width``, ``img_height``,
                ``channels``, ``mask_channels``, ``num_epochs``,
                ``batch_size``, ``sample_interval``, ``current_epoch``,
                ``last_trained_epoch``, ``gf``, ``df``, ``g_learning_rate``,
                ``d_learning_rate``, ``beta_1``, ``beta_2`` and ``epsilon``.
        """
        #Input shape
        self.img_width=config.img_width
        self.img_height=config.img_height
        self.channels=config.channels
        # Attribute name keeps its historical spelling.
        self.mask_channles = config.mask_channels
        self.img_shape=(self.img_width, self.img_height, self.channels)
        self.img_shape_mask=(self.img_width, self.img_height, self.mask_channles)
        self.missing_shape = (self.img_width, self.img_height, self.channels)
        self.num_epochs = config.num_epochs
        self.batch_size = config.batch_size
        self.start_time = time.time()
        self.end_time = time.time()
        self.sample_interval = config.sample_interval
        self.current_epoch =config.current_epoch
        self.last_trained_epoch = config.last_trained_epoch

        #Folders
        self.dataset_name = 'RMNet_WACV2021'
        self.models_path = 'models'

        #Configure Loader
        self.img_dir = r'./images/train/celebA_HQ_train/'
        self.masks_dir = r'./masks/train/qd_imd/train/'
        # The model is also constructed for inference, where the training
        # folders may not exist; fall back to empty listings in that case.
        self.imgs_in_path = os.listdir(self.img_dir) if os.path.isdir(self.img_dir) else []
        self.masks_in_path = os.listdir(self.masks_dir) if os.path.isdir(self.masks_dir) else []

        # Number of filters in the first layer of G and D
        self.gf = config.gf
        # Was read from config.gf, which silently ignored DISC_FILTER.
        self.df = config.df
        self.continue_train = True


        #Optimizer
        self.g_optimizer = Adam(lr=config.g_learning_rate,
                                beta_1=config.beta_1,
                                beta_2=config.beta_2,
                                epsilon=config.epsilon)
        self.d_optimizer = Adam(lr=config.d_learning_rate,
                                beta_1=config.beta_1,
                                beta_2=config.beta_2,
                                epsilon=config.epsilon)

        # =================================================================================== #
        #                             1. Build and compile the discriminator                  #
        # =================================================================================== #

        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss=[self.wasserstein_loss],
                                   optimizer=self.d_optimizer,
                                   metrics=['accuracy'])

        # =================================================================================== #
        #                             2. Build the generator                                  #
        # =================================================================================== #
        self.generator = self.build_generator()

        # =================================================================================== #
        #                             3. The combined model (stacked generator and discriminator) #
        #                             Trains the generator to fool the discriminator          #
        # =================================================================================== #

        # The previous code wrapped ``self.combined`` for two GPUs inside a
        # ``try`` before ``self.combined`` had been created; the resulting
        # AttributeError was swallowed by a bare ``except`` and this
        # single-model path always ran. Build it directly so real errors are
        # no longer hidden.
        self.combined = self.build_gan(self.generator, self.discriminator)
        self.combined.compile(loss=[self.generator_loss, self.wasserstein_loss],
                              loss_weights=[1, 1e-3],
                              optimizer=self.g_optimizer)

    def build_gan(self, generator, discriminator):
        """Stack generator and discriminator into one model.

        The returned model maps ``[image, mask]`` to
        ``[generator output, discriminator score]``. The discriminator is
        marked non-trainable so that only the generator is updated through
        the stacked model.
        """
        #Generator takes mask and image as input
        image = Input(shape=self.img_shape)
        mask = Input(shape=self.img_shape_mask)

        #Generator predicts image
        gen_output = generator([image, mask])

        #Train the generator only for the combined model
        discriminator.trainable = False

        #Discriminator validates the predicted image
        # The generator output carries the mask after the first three
        # channels; only those three channels are scored.
        gen_img = Lambda(lambda x : x[:,:,:,0:3])(gen_output)
        score = discriminator(gen_img)


        model = Model([image, mask], [gen_output, score])
        return model
    # =================================================================================== #
    #               4. Define the discriminator and generator losses                      #
    # =================================================================================== #

    def wasserstein_loss(self, y_true, y_pred):
        """Return the negated mean of ``y_true * y_pred``."""
        return -K.mean(y_true * y_pred)

    def generator_loss(self, y_true, y_pred):
        """Perceptual loss on VGG19 ``block3_conv3`` features.

        ``y_true`` and ``y_pred`` are sliced as channels ``0:3`` (image) and,
        for ``y_true``, channels ``3:`` (mask). The result is
        ``0.6 * full-image feature MSE + 0.4 * feature MSE of both images
        multiplied by (1 - mask)``.
        """
        mask = Lambda(lambda x : x[:,:,:,3:])(y_true)
        reversed_mask = Lambda(self.reverse_mask, output_shape=(self.img_shape_mask))(mask)

        input_img = Lambda(lambda x : x[:,:,:,0:3])(y_true)
        output_img = Lambda(lambda x : x[:,:,:,0:3])(y_pred)

        # A fresh, frozen VGG19 feature extractor is built on every call.
        vgg = VGG19(include_top=False, weights='imagenet', input_shape=self.img_shape)
        loss_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
        loss_model.trainable = False
        p_loss = K.mean(K.square(loss_model(output_img) - loss_model(input_img)))

        masking = Multiply()([reversed_mask,input_img])
        predicting = Multiply()([reversed_mask, output_img])
        reversed_mask_loss = (K.mean(K.square(loss_model(predicting) - loss_model(masking))))
        new_loss = 0.6*(p_loss) + 0.4*reversed_mask_loss
        return new_loss
Define the reverese mask # 152 | # =================================================================================== # 153 | 154 | def reverse_mask(self,x): 155 | return 1-x 156 | 157 | # =================================================================================== # 158 | # 6. Define the generator # 159 | # =================================================================================== # 160 | 161 | def build_generator(self): 162 | 163 | #compute inputs 164 | input_img = Input(shape=(self.img_shape), dtype='float32', name='image_input') 165 | input_mask = Input(shape=(self.img_shape_mask), dtype='float32',name='mask_input') 166 | reversed_mask = Lambda(self.reverse_mask,output_shape=(self.img_shape_mask))(input_mask) 167 | masked_image = Multiply()([input_img,reversed_mask]) 168 | 169 | #encoder 170 | x =(Conv2D(self.gf,(5, 5), dilation_rate=2, input_shape=self.img_shape, padding="same",name="enc_conv_1"))(masked_image) 171 | x =(LeakyReLU(alpha=0.2))(x) 172 | x =(BatchNormalization(momentum=0.8))(x) 173 | 174 | pool_1 = MaxPooling2D(pool_size=(2,2))(x) 175 | 176 | x =(Conv2D(self.gf,(5, 5), dilation_rate=2, padding="same",name="enc_conv_2"))(pool_1) 177 | x =(LeakyReLU(alpha=0.2))(x) 178 | x =(BatchNormalization(momentum=0.8))(x) 179 | 180 | pool_2 = MaxPooling2D(pool_size=(2,2))(x) 181 | 182 | x =(Conv2D(self.gf*2, (5, 5), dilation_rate=2, padding="same",name="enc_conv_3"))(pool_2) 183 | x =(LeakyReLU(alpha=0.2))(x) 184 | x =(BatchNormalization(momentum=0.8))(x) 185 | 186 | pool_3 = MaxPooling2D(pool_size=(2,2))(x) 187 | 188 | x =(Conv2D(self.gf*4, (5, 5), dilation_rate=2, padding="same",name="enc_conv_4"))(pool_3) 189 | x =(LeakyReLU(alpha=0.2))(x) 190 | x =(BatchNormalization(momentum=0.8))(x) 191 | 192 | pool_4 = MaxPooling2D(pool_size=(2,2))(x) 193 | 194 | x =(Conv2D(self.gf*8, (5, 5), dilation_rate=2, padding="same",name="enc_conv_5"))(pool_4) 195 | x =(LeakyReLU(alpha=0.2))(x) 196 | x =(Dropout(0.5))(x) 197 | 198 | #Decoder 199 | x 
=(UpSampling2D(size=(2, 2), interpolation='bilinear'))(x) 200 | x =(Conv2DTranspose(self.gf*8, (3, 3), padding="same",name="upsample_conv_1"))(x) 201 | x = Lambda(lambda x: tf.pad(x,[[0,0],[0,0],[0,0],[0,0]],'REFLECT'))(x) 202 | x =(Activation('relu'))(x) 203 | x =(BatchNormalization(momentum=0.8))(x) 204 | 205 | x =(UpSampling2D(size=(2, 2), interpolation='bilinear'))(x) 206 | x = (Conv2DTranspose(self.gf*4, (3, 3), padding="same",name="upsample_conv_2"))(x) 207 | x = Lambda(lambda x: tf.pad(x,[[0,0],[0,0],[0,0],[0,0]],'REFLECT'))(x) 208 | x =(Activation('relu'))(x) 209 | x =(BatchNormalization(momentum=0.8))(x) 210 | 211 | x =(UpSampling2D(size=(2, 2), interpolation='bilinear'))(x) 212 | x = (Conv2DTranspose(self.gf*2, (3, 3), padding="same",name="upsample_conv_3"))(x) 213 | x = Lambda(lambda x: tf.pad(x,[[0,0],[0,0],[0,0],[0,0]],'REFLECT'))(x) 214 | x =(Activation('relu'))(x) 215 | x =(BatchNormalization(momentum=0.8))(x) 216 | 217 | x =(UpSampling2D(size=(2, 2), interpolation='bilinear'))(x) 218 | x = (Conv2DTranspose(self.gf, (3, 3), padding="same",name="upsample_conv_4"))(x) 219 | x = Lambda(lambda x: tf.pad(x,[[0,0],[0,0],[0,0],[0,0]],'REFLECT'))(x) 220 | x =(Activation('relu'))(x) 221 | x =(BatchNormalization(momentum=0.8))(x) 222 | 223 | x = (Conv2DTranspose(self.channels, (3, 3), padding="same",name="final_output"))(x) 224 | x =(Activation('tanh'))(x) 225 | decoded_output = x 226 | reversed_mask_image = Multiply()([decoded_output, input_mask]) 227 | output_img = Add()([masked_image,reversed_mask_image]) 228 | concat_output_img = Concatenate()([output_img,input_mask]) 229 | model = Model(inputs = [input_img, input_mask], outputs = [concat_output_img]) 230 | print("====Generator Summary===") 231 | model.summary() 232 | return model 233 | 234 | # =================================================================================== # 235 | # 7. 
Define the discriminator # 236 | # =================================================================================== # 237 | 238 | def build_discriminator(self): 239 | input_img = Input(shape=(self.missing_shape), dtype='float32', name='d_input') 240 | 241 | dis = (Conv2D(self.df, kernel_size=3, strides=2, input_shape=self.missing_shape, padding="same"))(input_img) 242 | dis = (LeakyReLU(alpha=0.2))(dis) 243 | dis = (Dropout(0.25))(dis) 244 | dis = (Conv2D(self.df*2, kernel_size=3, strides=2, padding="same"))(dis) 245 | dis = (ZeroPadding2D(padding=((0,1),(0,1))))(dis) 246 | dis = (BatchNormalization(momentum=0.8))(dis) 247 | dis = (LeakyReLU(alpha=0.2))(dis) 248 | dis = (Dropout(0.25))(dis) 249 | dis = (Conv2D(self.df*4, kernel_size=3, strides=2, padding="same"))(dis) 250 | dis = (BatchNormalization(momentum=0.8))(dis) 251 | dis = (LeakyReLU(alpha=0.2))(dis) 252 | dis = (Dropout(0.25))(dis) 253 | dis = (Conv2D(self.df*8, kernel_size=3, strides=2, padding="same"))(dis) 254 | dis = (BatchNormalization(momentum=0.8))(dis) 255 | dis = (LeakyReLU(alpha=0.2))(dis) 256 | dis = (Dropout(0.25))(dis) 257 | dis = (Flatten())(dis) 258 | dis = (Dense(1))(dis) 259 | 260 | model = Model(inputs=[input_img], outputs=dis) 261 | print("====Discriminator Summary===") 262 | model.summary() 263 | return model 264 | 265 | # =================================================================================== # 266 | # 8. 
Define the loading function # 267 | # =================================================================================== # 268 | def get_batch(self, imgs_index, batch_imgs): 269 | if(imgs_index+batch_imgs) >= len(self.imgs_in_path): 270 | batch_imgs = len(self.imgs_in_path)-imgs_index 271 | real_imgs = np.zeros((batch_imgs, self.img_width, self.img_height,3)) 272 | masks = np.zeros((batch_imgs, self.img_width, self.img_height,1)) 273 | masked_imgs = np.zeros((batch_imgs, self.img_width, self.img_height,3)) 274 | masks_index = random.sample(range(1,len(self.masks_in_path)), batch_imgs) 275 | maskindx = 0 276 | for i in range(batch_imgs): 277 | print("\rLoading image number "+ str(i) + " of " + str(len(self.imgs_in_path)), end = " ") 278 | real_img = cv2.imread(self.img_dir + self.imgs_in_path[imgs_index], 1).astype('float')/ 127.5 -1 279 | real_img = cv2.resize(real_img,(self.img_width, self.img_height)) 280 | #If masks bits are white, DO NOT subtract from 1. 281 | #If masks bits are black, subtract from 1. 282 | mask = 1-cv2.imread(self.masks_dir + self.masks_in_path[masks_index[maskindx]],0).astype('float')/ 255 283 | mask = cv2.resize(mask,(self.img_width, self.img_height)) 284 | mask = np.reshape(mask,(self.img_width, self.img_height,1)) 285 | 286 | masks[i] = mask 287 | real_imgs[i] = real_img 288 | #masked_imgs[np.where((mask ==[1,1,1]).all(axis=2))]=[255,255,255] 289 | masked_imgs[i][np.where(mask == 0)]=1 290 | maskindx +=1 291 | imgs_index +=1 292 | if(imgs_index >= len(self.imgs_in_path)): 293 | imgs_index = 0 294 | # cv2.imwrite(os.path.join(path, 'mask_'+str(i)+'.jpg'),rawmask) 295 | # cv2.imshow("mask",((masked_imgs[0]+1)* 127.5).astype("uint8")) 296 | # cv2.waitKey(0 ) 297 | return imgs_index,real_imgs, masks,masked_imgs 298 | # =================================================================================== # 299 | # 8. 
# =================================================================================== #
#                                  Train the networks                                 #
# =================================================================================== #
def train(self):
    """Run the adversarial training loop.

    For every epoch the data is pulled in 27 loads of up to 1000 images via
    ``self.get_batch``.  Each load is shuffled once and consumed in
    mini-batches of ``self.batch_size``: the discriminator is trained on the
    real images (target ``+1``) and on the generator output (target ``-1``),
    then the generator is trained through ``self.combined``.  After every
    data load the generator weights are written to
    ``{models_path}/{dataset_name}/weight_{epoch}.h5`` and a preview image is
    produced with ``self.sample_images``.

    Returns:
        None.
    """
    # Ground truths for adversarial loss
    valid = np.ones([self.batch_size, 1])
    fake = -np.ones((self.batch_size, 1))
    total_files = 27000
    batch_imgs = 1000
    imgs_index = 0
    dataLoads = total_files // batch_imgs
    # To resume from a checkpoint, load the generator weights here, e.g.:
    # self.generator.load_weights(r'./{}/{}/weight_{}.h5'.format(self.models_path, self.dataset_name, self.last_trained_epoch))
    for epoch in range(1, self.num_epochs + 1):

        for databatchs in range(dataLoads):
            imgs_index, imgs, masks, _ = self.get_batch(imgs_index, batch_imgs)
            batches = imgs.shape[0] // self.batch_size
            if batches == 0:
                # Not enough images for a single full mini-batch.
                continue

            # One permutation per data load so the mini-batch slices below
            # are disjoint and every image is visited at most once.
            idx = np.random.permutation(imgs.shape[0])
            for batch in range(batches):
                idx_batches = idx[batch * self.batch_size:(batch + 1) * self.batch_size]
                real_batch = imgs[idx_batches]
                mask_batch = masks[idx_batches]
                # Only the first three channels of the generator output are
                # treated as the image.
                gen_imgs = self.generator.predict([real_batch, mask_batch], self.batch_size)
                gen_imgs = gen_imgs[:, :, :, 0:3]

                # ------------------------------------------------------- #
                #  8.2. Train the discriminator
                # ------------------------------------------------------- #
                self.discriminator.trainable = True
                d_loss_real = self.discriminator.train_on_batch(real_batch, valid)
                d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

                # ------------------------------------------------------- #
                #  8.3. Train the generator
                # ------------------------------------------------------- #
                self.discriminator.trainable = False
                # Images have 3 channels and masks 1, so they cannot be
                # stacked on a new axis (stack needs identical shapes); join
                # them along the channel axis instead.
                # NOTE(review): assumes the generator loss expects a
                # 4-channel image+mask target -- confirm against the loss.
                gen_target = np.concatenate([real_batch, mask_batch], axis=-1)
                g_loss = self.combined.train_on_batch([real_batch, mask_batch],
                                                      [gen_target, valid])

                # ------------------------------------------------------- #
                #  8.4. Plot the progress
                # ------------------------------------------------------- #
                print ("Epoch: %d Batch: %d/%d dataloads: %d/%d [D loss: %f, op_acc: %.2f%%] [G loss: %f MSE loss: %f]" % (epoch+self.current_epoch,
                       batch, batches, databatchs, dataLoads, d_loss[0], 100*d_loss[1], g_loss[0], g_loss[1]))

            # Use the first image of the shuffled order as the preview
            # sample, without copying the whole array.
            input_img = np.expand_dims(imgs[idx[0]], 0)
            input_mask = np.expand_dims(masks[idx[0]], 0)

            # Checkpoint the generator and write a preview image.
            os.makedirs("{}/{}/".format(self.models_path, self.dataset_name), exist_ok=True)
            name = "{}/{}/weight_{}.h5".format(self.models_path, self.dataset_name, epoch+self.current_epoch)
            self.generator.save_weights(name)

            os.makedirs(self.dataset_name, exist_ok=True)
            predicted_img = self.generator.predict([input_img, input_mask])
            self.sample_images(self.dataset_name, input_img, predicted_img[:, :, :, 0:3],
                               input_mask, epoch)

        # NOTE(review): start_time / end_time are not updated in this method;
        # presumably they are set elsewhere -- verify.
        print("Total Processing time:: {:4.2f}min" .format((self.end_time - self.start_time)/60))
        self.epoch += 1
# =================================================================================== #
#                            Sample images during training                            #
# =================================================================================== #
def sample_images(self, dataset_name, input_img, sample_pred, mask, epoch):
    """Write a side-by-side preview image for one training sample.

    The saved picture places three panels next to each other along axis 1:
    the masked input (``(1 - mask) * image + mask``), the prediction and the
    original image.  Values are mapped with ``(x + 1) * 127.5`` and stored
    as ``pred_{epoch + self.current_epoch}.jpg`` inside
    ``self.dataset_name``.

    Args:
        dataset_name: Unused; the output directory is taken from
            ``self.dataset_name``.  Kept so existing callers keep working.
        input_img: Batch of images; only the first entry is used.
        sample_pred: Batch of predictions; only the first entry is used.
        mask: Batch of masks; only the first entry is used.
        epoch: Epoch number used (plus ``self.current_epoch``) in the file
            name.
    """
    os.makedirs(self.dataset_name, exist_ok=True)

    def _to_uint8(arr):
        # Clip before casting: out-of-range floats would otherwise wrap
        # around when converted to uint8 and show up as speckle artefacts.
        return np.clip((arr + 1) * 127.5, 0, 255).astype("uint8")

    input_img = np.expand_dims(input_img[0], 0)
    input_mask = np.expand_dims(mask[0], 0)
    maskedImg = ((1 - input_mask) * input_img) + input_mask
    img = np.concatenate((_to_uint8(maskedImg[0]),
                          _to_uint8(sample_pred[0]),
                          _to_uint8(input_img[0])), axis=1)
    img_filepath = os.path.join(self.dataset_name, 'pred_{}.jpg'.format(epoch+self.current_epoch))

    cv2.imwrite(img_filepath, img)


# =================================================================================== #
#                   Plot the discriminator and generator losses                       #
# =================================================================================== #
def plot_logs(self, epoch, avg_d_loss, avg_g_loss):
    """Plot both loss histories and save the figure as a PDF.

    The file is written to
    ``LogsUnet/{dataset_name}_paper/log_ep{epoch + current_epoch}.pdf``.

    Args:
        epoch: Epoch number used (plus ``self.current_epoch``) in the file
            name.
        avg_d_loss: Sequence of discriminator loss values.
        avg_g_loss: Sequence of generator (adversarial) loss values.
    """
    # Create the full target directory, not just "LogsUnet"; savefig does
    # not create missing parent directories.
    os.makedirs("LogsUnet/{}_paper".format(self.dataset_name), exist_ok=True)
    plt.figure()
    plt.plot(range(len(avg_d_loss)), avg_d_loss,
             color='red', label='Discriminator loss')
    plt.plot(range(len(avg_g_loss)), avg_g_loss,
             color='blue', label='Adversarial loss')
    plt.title('Discriminator and Adversarial loss')
    plt.xlabel('Iterations')
    plt.ylabel('Loss (Adversarial/Discriminator)')
    plt.legend()
    plt.savefig("LogsUnet/{}_paper/log_ep{}.pdf".format(self.dataset_name, epoch+self.current_epoch))
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close()