├── .gitignore ├── README.md ├── README ├── adamaaliamix crop.png └── obamaResults.png ├── analyze_lighting.py ├── analyze_lighting_multiple.py ├── data └── test │ ├── images │ └── obama.jpg │ └── light │ ├── rotate_light_00.txt │ ├── rotate_light_01.txt │ ├── rotate_light_02.txt │ ├── rotate_light_03.txt │ ├── rotate_light_04.txt │ ├── rotate_light_05.txt │ └── rotate_light_06.txt ├── face_detect ├── faceDetect.py └── haarcascade_frontalface_default.xml ├── gui.py ├── live_lighting_transfer.py ├── model ├── data.py ├── debug.py ├── loss.py ├── model.py └── train.py ├── relight.py ├── requirements.txt ├── testPostRotate.jpg ├── testPreRotate.jpg ├── test_network.py ├── trained_models ├── official.t7 └── trained.pt └── utils ├── clean_data.py ├── utils_SH.py ├── utils_SH.pyc ├── utils_normal.py └── utils_shtools.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Large Files 2 | data/train/* 3 | result/* 4 | *.MP4 5 | *.avi 6 | *.MOV 7 | 8 | # Misc 9 | .idea 10 | **/.DS_Store 11 | **/__pycache__ 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Portrait Image Relighting 2 | 3 | Brown CS1430 Final Project: Portrait Image Relighting 4 | 5 | # Team 6 | 7 | Aalia Habib, Adam Pikielny, Jack Dermer, Ben Givertz 8 | 9 | Monsters.inc 10 | 11 | # Project 12 | 13 | ![Obama image](https://github.com/APikielny/image-relighting/blob/master/README/obamaResults.png) 14 | ![Face to face relighting](https://github.com/APikielny/image-relighting/blob/master/README/adamaaliamix%20crop.png) 15 | 16 | ## Overview 17 | 18 | Traditional methods for relighting faces require knowledge of the subject's reflectance, lighting, and structure. We sought to implement a deep learning algorithm to solve this task and relight portraits given only a single image as input. We implemented an Hourglass-shaped CNN from research by Zhou et al., [Deep Single Image Portrait Relighting](https://zhhoper.github.io/dpr.html), in order to relight portrait images. The model first separates the input image into facial and lighting features, from which a specialized lighting network predicts the direction of light. Then, the facial features are combined with the desired new lighting. Using a synthesized data set of portrait images under various artificial lighting conditions for training and ground truth, we were able to achieve realistic results, outputting images at a resolution of 256\*256. 19 | 20 | ## Dataset 21 | 22 | Due to computational and storage limitations, we used a scaled down version of the 23 | DPR dataset created by the original paper. The images were scaled down to both 24 | 128x128 and 256x256. Both datasets are available for download on [Google Drive](https://drive.google.com/open?id=1v-8FebXQPk5YqlWYYDe7frwy9OkJ24yq). 25 | 26 | # Usage 27 | 28 | The dependencies for this project can be found in the `requirements.txt` file and 29 | can be installed with `pip install -r requirements.txt`. 30 | 31 | ## Model Training 32 | 33 | The `train.py` file can be found in the model directory along with files for data loading, 34 | the loss function, and the model itself. To train the model, the image folders from the dataset must be moved into the `data/train/` directory. 
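Each training sample is a folder of renderings of one face under five different lightings, plus the matching spherical-harmonic lighting files. `model/data.py` assumes the (10-character) folder name prefixes every file inside it. A sketch of the expected layout (the folder name `imgHQ00000` is illustrative of the DPR naming; only the prefix/suffix pattern is required):

```
data/train/
└── imgHQ00000/
    ├── imgHQ00000_00.jpg ... imgHQ00000_04.jpg
    └── imgHQ00000_light_00.txt ... imgHQ00000_light_04.txt
```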
Train a new model using:

- `python train.py [-h] [--epochs EPOCHS] [--batch BATCH] [--lr LR] [--data DATA] [--model MODEL] [--verbose] [--debug]`

## Model Testing

There are multiple ways to test our model, detailed below. For each test, we allow specification of the input image(s), the model to use, and whether or not to use a GPU. (A minimal scripted version of the underlying pipeline is sketched after the list.)

The image(s) should be stored in `data/test/images/` and the model in `trained_models/`. Use the `--gpu` flag if you'd like to run on a CUDA GPU (such as on Google Cloud Platform).

1. To **relight a face from several angles**, run `test_network.py`:

   - `python test_network.py [-h] [--image IMAGE] [--model MODEL] [--gpu]`

2. To **relight based on lighting from another face**, use:

   - `python relight.py [-h] [--source_image SOURCE_IMAGE] [--light_image LIGHT_IMAGE] [--model MODEL] [--gpu] [--face_detect FACE_DETECT]`

   - The `[--face_detect]` flag can be passed "both" or "light". "light" runs face detection only on the lighting reference, which is recommended; "both" also crops the source face, so the output will be cropped as well.

3. The `live_lighting_transfer.py` file can be run to **see a live webcam view with dynamic relighting**:

   - `python live_lighting_transfer.py [-h] [--light_image LIGHT_IMAGE] [--light_text LIGHT_TEXT] [--input_path INPUT_PATH] [--output_path OUTPUT_PATH]`

   - `[--light_text]` is a path to a text file with the nine target spherical-harmonic lighting coefficients (see `data/test/light/` for examples); use it in place of `[--light_image]`.

   - `[--input_path]` relights a pre-recorded video instead of the live webcam; `[--output_path]` sets where the resulting video is written (default `output.avi`).

4. For a more **user-friendly** approach, run `python gui.py` from the repository root.

   - Use any image that contains faces as the lighting reference; no cropping is necessary.
   - Crop the image you would like to apply lighting to closely around the face before input.
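The same pipeline the scripts above wrap can also be driven directly from Python. The sketch below mirrors what `relight.py` and `gui.py` do; the `to_L_tensor` helper, the reuse of the bundled `obama.jpg` in both roles, and the output path are illustrative only. Run it from the repository root so the relative paths resolve.

```python
import sys
sys.path.append('model')

import cv2
import numpy as np
import torch

from model import HourglassNet

def to_L_tensor(img_bgr):
    """Resize to 256x256, convert to Lab, return the L channel as a (1, 1, 256, 256) tensor plus the Lab image."""
    img_bgr = cv2.resize(img_bgr, (256, 256))
    lab = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2LAB)
    L = lab[:, :, 0].astype(np.float32) / 255.0
    return torch.from_numpy(L[None, None, ...]), lab

net = HourglassNet()
net.load_state_dict(torch.load('trained_models/trained.pt', map_location=torch.device('cpu')))
net.train(False)

# 1. estimate the spherical-harmonic lighting of the reference face
ref_L, _ = to_L_tensor(cv2.imread('data/test/images/obama.jpg'))
_, target_sh = net(ref_L, torch.zeros((1, 9, 1, 1)), 0)

# 2. re-render the source face under that lighting
src_L, src_lab = to_L_tensor(cv2.imread('data/test/images/obama.jpg'))
relit, _ = net(src_L, target_sh, 0)

# 3. put the relit L channel back into the Lab image and save as BGR
src_lab[:, :, 0] = (relit[0].cpu().data.numpy().squeeze() * 255.0).astype(np.uint8)
cv2.imwrite('relit.jpg', cv2.cvtColor(src_lab, cv2.COLOR_LAB2BGR))
```

`relight.py` adds optional face-detection cropping on top of this and resizes the result back to the source resolution; the sketch writes the output at 256x256.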
68 | -------------------------------------------------------------------------------- /README/adamaaliamix crop.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/APikielny/image-relighting/edaa4aa0e2f2ed846810e8b7df35d7c33ee6a557/README/adamaaliamix crop.png -------------------------------------------------------------------------------- /README/obamaResults.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/APikielny/image-relighting/edaa4aa0e2f2ed846810e8b7df35d7c33ee6a557/README/obamaResults.png -------------------------------------------------------------------------------- /analyze_lighting.py: -------------------------------------------------------------------------------- 1 | #Adam Pikielny 2 | #Fall 2020 3 | #analyze lighting of a face, outputting SH coordinates (by plotting, sphere, or array) 4 | 5 | import sys 6 | sys.path.append('model') 7 | sys.path.append('utils') 8 | 9 | from utils_SH import * 10 | 11 | from face_detect.faceDetect import cropFace 12 | 13 | # other modules 14 | import os 15 | import numpy as np 16 | 17 | from torch.autograd import Variable 18 | import torch 19 | import cv2 20 | import argparse 21 | import matplotlib.pyplot as plt 22 | from pyshtools import rotate 23 | 24 | 25 | 26 | # This code is adapted from https://github.com/zhhoper/DPR 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser( 30 | description="image relighting training.") 31 | parser.add_argument( 32 | '--model', 33 | default='trained.pt', 34 | help='model file to use stored in trained_model/' 35 | ) 36 | parser.add_argument( 37 | '--gpu', 38 | action='store_true', 39 | help='cpu vs. gpu' 40 | ) 41 | parser.add_argument( 42 | '--face_detect', 43 | default='Neither', 44 | help='Options: "both" or "light". Face detection/cropping for more accurate relighting.' 45 | ) 46 | parser.add_argument( 47 | '--video_path', 48 | default='/video.avi', 49 | help='video path to analyze' 50 | ) 51 | parser.add_argument( 52 | '--output_light_path', 53 | help='output path for sphere video' 54 | ) 55 | parser.add_argument( 56 | '--output_sphere', 57 | help='output path for sphere image' 58 | ) 59 | parser.add_argument( 60 | '--plot_path', 61 | help='output path for plots of SH coordinates' 62 | ) 63 | parser.add_argument( 64 | '--frames', 65 | default = 30, 66 | help='number of frames to analyze' 67 | ) 68 | 69 | 70 | return parser.parse_args() 71 | 72 | def preprocess_image(src_img, srcOrLight): 73 | # src_img = cv2.imread(img_path) 74 | if (ARGS.face_detect == 'both') or (ARGS.face_detect == 'light' and srcOrLight == 2): 75 | src_img = cropFace(src_img) 76 | row, col, _ = src_img.shape 77 | src_img = cv2.resize(src_img, (256, 256)) 78 | Lab = cv2.cvtColor(src_img, cv2.COLOR_BGR2LAB) #converts image to one color space LAB 79 | 80 | inputL = Lab[:,:,0] #taking only the L channel 81 | inputL = inputL.astype(np.float32)/255.0 #normalise 82 | inputL = inputL.transpose((0,1)) 83 | inputL = inputL[None,None,...] 
#not sure what's happening here 84 | inputL = Variable(torch.from_numpy(inputL)) 85 | if (ARGS.gpu): 86 | inputL = inputL.cuda() 87 | return inputL, row, col, Lab 88 | 89 | ## copied from test_network.py 90 | def render_half_sphere(sh): 91 | img_size = 256 92 | x = np.linspace(-1, 1, img_size) 93 | z = np.linspace(1, -1, img_size) 94 | x, z = np.meshgrid(x, z) 95 | 96 | mag = np.sqrt(x**2 + z**2) 97 | valid = mag <=1 98 | y = -np.sqrt(1 - (x*valid)**2 - (z*valid)**2) 99 | x = x * valid 100 | y = y * valid 101 | z = z * valid 102 | normal = np.concatenate((x[...,None], y[...,None], z[...,None]), axis=2) 103 | normal = np.reshape(normal, (-1, 3)) 104 | 105 | sh = np.squeeze(sh) 106 | shading = get_shading(normal, sh) 107 | value = np.percentile(shading, 95) 108 | ind = shading > value 109 | shading[ind] = value 110 | shading = (shading - np.min(shading))/(np.max(shading) - np.min(shading)) 111 | shading = (shading *255.0).astype(np.uint8) 112 | shading = np.reshape(shading, (256, 256)) 113 | shading = shading * valid 114 | # print("outputting to ", ARGS.output_light_path) 115 | # cv2.imwrite(ARGS.output_light_path, shading) 116 | return shading 117 | 118 | 119 | ARGS = parse_args() 120 | 121 | modelFolder = 'trained_models/' 122 | 123 | # load model 124 | from model import * 125 | my_network = HourglassNet() 126 | 127 | if (ARGS.gpu): 128 | my_network.load_state_dict(torch.load(os.path.join(modelFolder, ARGS.model))) 129 | my_network.cuda() 130 | else: 131 | my_network.load_state_dict(torch.load(os.path.join(modelFolder, ARGS.model), map_location=torch.device('cpu'))) 132 | 133 | my_network.train(False) 134 | 135 | # create video reader and writer 136 | if (ARGS.video_path is not None): 137 | vc = cv2.VideoCapture(ARGS.video_path) 138 | else: 139 | pass 140 | 141 | if (ARGS.gpu): 142 | sh = sh.cuda() 143 | 144 | if (ARGS.output_light_path is not None): 145 | videoWriter = cv2.VideoWriter(ARGS.output_light_path,cv2.VideoWriter_fourcc(*'MJPG'), 30, (256,256)) 146 | 147 | SHs = [] 148 | squashedOutput = [] #for plotting 149 | 150 | _, img = vc.read() 151 | i = 0 152 | # while img is not None: 153 | frames = ARGS.frames 154 | for f in range(frames): 155 | light_img, _, _, _ = preprocess_image(img, 2) 156 | 157 | sh = torch.zeros((1,9,1,1)) 158 | 159 | _, outputSH = my_network(light_img, sh, 0) 160 | SHs.append(outputSH) 161 | squashedOutput.append(torch.reshape(outputSH, (9,)).cpu().data.numpy()) 162 | 163 | 164 | ########## 165 | # rendering SH coords as sphere image/video 166 | frame = render_half_sphere(outputSH.cpu().data.numpy()) 167 | 168 | # cv2.imwrite('/Users/Adam/Desktop/brown/junior/cs1970/image-relighting/analyzeLightPics/frame' + str(i) + '.jpg', frame) 169 | i += 1 170 | 171 | # frame = (frame*255).astype('uint8') 172 | # videoWriter.write(frame) 173 | ########## 174 | 175 | _, img = vc.read() 176 | 177 | # if videoWriter is not None: 178 | # videoWriter.release() 179 | # print(SHs) 180 | 181 | mean = torch.mean(torch.stack(SHs), dim = 0) 182 | var = torch.var(torch.stack(SHs), dim = 0) 183 | 184 | #### rotating the SH coords 185 | print("sh coords shape", mean.cpu().data.numpy().shape) 186 | 187 | 188 | lmax = 1 189 | rcoeffs = np.random.normal(size=(2, lmax + 1, lmax + 1)) 190 | print("rcoeffs shape", rcoeffs.shape) 191 | cv2.imwrite('testPreRotate.jpg', rcoeffs) 192 | dj_matrix = rotate.djpi2(lmax) 193 | angles = np.radians([20, 20, 20]) 194 | rotated = rotate.SHRotateRealCoef(rcoeffs, angles, dj_matrix) 195 | cv2.imwrite('testPostRotate.jpg', rotated) 196 | #### 197 | 198 | # 
print("mean of SHs:", mean) 199 | # print("var of SHs:", var) 200 | 201 | # if (ARGS.plot_path is not None): 202 | # plt.plot(squashedOutput) 203 | # plt.title('SHs over first ' + str(frames) + ' frames, for video: ' + ARGS.video_path) 204 | # plt.xlabel('Frame') 205 | # plt.savefig(ARGS.plot_path) 206 | 207 | frame = render_half_sphere(mean.cpu().data.numpy()) 208 | cv2.imwrite(ARGS.output_sphere, frame) 209 | -------------------------------------------------------------------------------- /analyze_lighting_multiple.py: -------------------------------------------------------------------------------- 1 | #Adam Pikielny 2 | #Fall 2020 3 | #analyze lighting of a face, outputting SH coordinates into a matlab file 4 | 5 | import sys 6 | sys.path.append('model') 7 | sys.path.append('utils') 8 | 9 | from utils_SH import * 10 | 11 | from face_detect.faceDetect import cropFace 12 | 13 | # other modules 14 | import os 15 | import numpy as np 16 | 17 | from torch.autograd import Variable 18 | import torch 19 | import cv2 20 | import argparse 21 | import matplotlib.pyplot as plt 22 | from scipy.io import savemat 23 | import re 24 | 25 | 26 | # This code is adapted from https://github.com/zhhoper/DPR 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser( 30 | description="image relighting training.") 31 | parser.add_argument( 32 | '--model', 33 | default='trained.pt', 34 | help='model file to use stored in trained_model/' 35 | ) 36 | parser.add_argument( 37 | '--gpu', 38 | action='store_true', 39 | help='cpu vs. gpu' 40 | ) 41 | parser.add_argument( 42 | '--face_detect', 43 | default='Neither', 44 | help='Options: "both" or "light". Face detection/cropping for more accurate relighting.' 45 | ) 46 | parser.add_argument( 47 | '--videos_path', 48 | default='', 49 | help='folder with videos to put in dictionary' 50 | ) 51 | parser.add_argument( 52 | '--fake_path', 53 | default='', 54 | help='path to fake to add to dictionary' 55 | ) 56 | parser.add_argument( 57 | '--output_light_path', 58 | help='output path for sphere video' 59 | ) 60 | parser.add_argument( 61 | '--output_sphere', 62 | help='output path for sphere image' 63 | ) 64 | parser.add_argument( 65 | '--plot_path', 66 | help='output path for plots of SH coordinates' 67 | ) 68 | parser.add_argument( 69 | '--frames', 70 | default = 30, 71 | help='number of frames to analyze' 72 | ) 73 | parser.add_argument( 74 | '--mat_path', 75 | default='', 76 | help='output .mat file path' 77 | ) 78 | 79 | 80 | return parser.parse_args() 81 | 82 | def preprocess_image(src_img, srcOrLight): 83 | # src_img = cv2.imread(img_path) 84 | if (ARGS.face_detect == 'both') or (ARGS.face_detect == 'light' and srcOrLight == 2): 85 | src_img = cropFace(src_img) 86 | row, col, _ = src_img.shape 87 | src_img = cv2.resize(src_img, (256, 256)) 88 | Lab = cv2.cvtColor(src_img, cv2.COLOR_BGR2LAB) #converts image to one color space LAB 89 | 90 | inputL = Lab[:,:,0] #taking only the L channel 91 | inputL = inputL.astype(np.float32)/255.0 #normalise 92 | inputL = inputL.transpose((0,1)) 93 | inputL = inputL[None,None,...] 
#not sure what's happening here 94 | inputL = Variable(torch.from_numpy(inputL)) 95 | if (ARGS.gpu): 96 | inputL = inputL.cuda() 97 | return inputL, row, col, Lab 98 | 99 | ## copied from test_network.py 100 | def render_half_sphere(sh): 101 | img_size = 256 102 | x = np.linspace(-1, 1, img_size) 103 | z = np.linspace(1, -1, img_size) 104 | x, z = np.meshgrid(x, z) 105 | 106 | mag = np.sqrt(x**2 + z**2) 107 | valid = mag <=1 108 | y = -np.sqrt(1 - (x*valid)**2 - (z*valid)**2) 109 | x = x * valid 110 | y = y * valid 111 | z = z * valid 112 | normal = np.concatenate((x[...,None], y[...,None], z[...,None]), axis=2) 113 | normal = np.reshape(normal, (-1, 3)) 114 | 115 | sh = np.squeeze(sh) 116 | shading = get_shading(normal, sh) 117 | value = np.percentile(shading, 95) 118 | ind = shading > value 119 | shading[ind] = value 120 | shading = (shading - np.min(shading))/(np.max(shading) - np.min(shading)) 121 | shading = (shading *255.0).astype(np.uint8) 122 | shading = np.reshape(shading, (256, 256)) 123 | shading = shading * valid 124 | # print("outputting to ", ARGS.output_light_path) 125 | # cv2.imwrite(ARGS.output_light_path, shading) 126 | return shading 127 | 128 | 129 | ARGS = parse_args() 130 | 131 | modelFolder = 'trained_models/' 132 | 133 | # load model 134 | from model import * 135 | my_network = HourglassNet() 136 | 137 | if (ARGS.gpu): 138 | my_network.load_state_dict(torch.load(os.path.join(modelFolder, ARGS.model))) 139 | my_network.cuda() 140 | else: 141 | my_network.load_state_dict(torch.load(os.path.join(modelFolder, ARGS.model), map_location=torch.device('cpu'))) 142 | 143 | my_network.train(False) 144 | 145 | 146 | if (ARGS.gpu): 147 | sh = sh.cuda() 148 | 149 | filePaths = os.listdir(ARGS.videos_path) 150 | filePaths.append(ARGS.fake_path) 151 | 152 | i = 1 153 | dataDict = {} 154 | for filename in filePaths: 155 | if (filename == ARGS.fake_path) or ((filename.endswith(".avi") or filename.endswith(".MP4") or filename.endswith(".mp4")) and (re.search("camera\d.MP4", filename) is not None or re.search("cam\d.avi", filename) is not None)): 156 | # create video reader and writer 157 | vc = cv2.VideoCapture(ARGS.videos_path + filename) 158 | _, img = vc.read() 159 | if img is None: 160 | vc = cv2.VideoCapture(filename) 161 | _, img = vc.read() 162 | 163 | # i = 0 164 | # while img is not None: 165 | frames = int(ARGS.frames) 166 | SHs = np.zeros((frames, 9)) 167 | 168 | for f in range(frames): 169 | if img is not None: 170 | 171 | light_img, _, _, _ = preprocess_image(img, 2) 172 | 173 | sh = torch.zeros((1,9,1,1)) 174 | 175 | _, outputSH = my_network(light_img, sh, 0) 176 | # SHs.append(outputSH) 177 | # squashedOutput.append(torch.reshape(outputSH, (9,)).cpu().data.numpy()) 178 | SHs[f] = torch.reshape(outputSH, (9,)).cpu().data.numpy() 179 | 180 | 181 | _, img = vc.read() 182 | if (filename == ARGS.fake_path): 183 | dataDict['fake'] = (SHs - np.mean(SHs)) / np.std(SHs) 184 | else: 185 | dataDict['cam' + str(i)] = (SHs - np.mean(SHs)) / np.std(SHs) 186 | i += 1 187 | 188 | # print(dataDict) 189 | savemat(ARGS.mat_path, dataDict) 190 | -------------------------------------------------------------------------------- /data/test/images/obama.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/APikielny/image-relighting/edaa4aa0e2f2ed846810e8b7df35d7c33ee6a557/data/test/images/obama.jpg -------------------------------------------------------------------------------- /data/test/light/rotate_light_00.txt: 
-------------------------------------------------------------------------------- 1 | 1.084125496282453138e+00 2 | -4.642676300617166185e-01 3 | 2.837846795150648915e-02 4 | 6.765292733937575687e-01 5 | -3.594067725393816914e-01 6 | 4.790996460111427574e-02 7 | -2.280054643781863066e-01 8 | -8.125983081159608712e-02 9 | 2.881082012687687932e-01 10 | -------------------------------------------------------------------------------- /data/test/light/rotate_light_01.txt: -------------------------------------------------------------------------------- 1 | 1.084125496282453138e+00 2 | -4.642676300617170626e-01 3 | 5.466255701105990905e-01 4 | 3.996219229512094628e-01 5 | -2.615439760463462715e-01 6 | -2.511241554473071513e-01 7 | 6.495694866016435420e-02 8 | 3.510322039081858470e-01 9 | 1.189662732386344152e-01 10 | -------------------------------------------------------------------------------- /data/test/light/rotate_light_02.txt: -------------------------------------------------------------------------------- 1 | 1.084125496282453138e+00 2 | -4.642676300617179508e-01 3 | 6.532524688468428486e-01 4 | -1.782088862752457814e-01 5 | 3.326676893441832261e-02 6 | -3.610566644446819295e-01 7 | 3.647561777790956361e-01 8 | -7.496419691318900735e-02 9 | -5.412289239602386531e-02 10 | -------------------------------------------------------------------------------- /data/test/light/rotate_light_03.txt: -------------------------------------------------------------------------------- 1 | 1.084125496282453138e+00 2 | -4.642676300617186724e-01 3 | 2.679669346194941126e-01 4 | -6.218447693376460972e-01 5 | 3.030269583891490037e-01 6 | -1.991061409014726058e-01 7 | -6.162944418511027977e-02 8 | -3.176699976873690878e-01 9 | 1.920509612235956343e-01 10 | -------------------------------------------------------------------------------- /data/test/light/rotate_light_04.txt: -------------------------------------------------------------------------------- 1 | 1.084125496282453138e+00 2 | -4.642676300617186724e-01 3 | -3.191031669056417219e-01 4 | -5.972188577671910803e-01 5 | 3.446016675533919993e-01 6 | 1.127753677656503223e-01 7 | -1.716692196540034188e-01 8 | 2.163406460637767315e-01 9 | 2.555824552121269688e-01 10 | -------------------------------------------------------------------------------- /data/test/light/rotate_light_05.txt: -------------------------------------------------------------------------------- 1 | 1.084125496282453138e+00 2 | -4.642676300617178398e-01 3 | -6.658820752324799974e-01 4 | -1.228749652534838893e-01 5 | 1.266842924569576145e-01 6 | 3.397347243069742673e-01 7 | 3.036887095295650041e-01 8 | 2.213893524577207617e-01 9 | -1.886557316342868038e-02 10 | -------------------------------------------------------------------------------- /data/test/light/rotate_light_06.txt: -------------------------------------------------------------------------------- 1 | 1.084125496282453138e+00 2 | -4.642676300617169516e-01 3 | -5.112381993903207800e-01 4 | 4.439962822886048266e-01 5 | -1.866289387481862572e-01 6 | 3.108669041197227867e-01 7 | 2.021743042675238355e-01 8 | -3.148681770175290051e-01 9 | 3.974379604123656762e-02 10 | -------------------------------------------------------------------------------- /face_detect/faceDetect.py: -------------------------------------------------------------------------------- 1 | # taken from https://machinelearningmastery.com/how-to-perform-face-detection-with-classical-and-deep-learning-methods-in-python-with-keras/ 2 | # adapted by Adam Pikielny, May 2020 3 | # 
plot photo with detected faces using opencv cascade classifier 4 | import cv2 5 | from cv2 import imread, resize 6 | from cv2 import imshow 7 | from cv2 import waitKey 8 | from cv2 import destroyAllWindows 9 | from cv2 import CascadeClassifier 10 | from cv2 import rectangle 11 | 12 | def cropFace(img): 13 | 14 | # load the photograph 15 | # pixels = img 16 | pixels = img 17 | 18 | # load the pre-trained model 19 | classifier = CascadeClassifier('face_detect/haarcascade_frontalface_default.xml') 20 | 21 | # perform face detection 22 | bboxes = classifier.detectMultiScale(pixels) 23 | 24 | if len(bboxes) == 0: 25 | print("ERROR: No faces found.") 26 | 27 | # extract 28 | x, y, width, height = bboxes[0] 29 | x2, y2 = x + width, y + height 30 | 31 | BUFFER = int(width * 0.25) 32 | 33 | # show the image 34 | image = pixels[max(y - BUFFER, 0):min(y2 + BUFFER, pixels.shape[0]), max(x - BUFFER, 0):min(x2 + BUFFER, pixels.shape[1])] 35 | # imshow('hehe', image) 36 | # waitKey(0) 37 | return image 38 | 39 | 40 | def cropFace2(img_path): 41 | img = cv2.imread(img_path) 42 | 43 | # load the photograph 44 | # pixels = img 45 | pixels = img 46 | 47 | # load the pre-trained model 48 | classifier = CascadeClassifier('face_detect/haarcascade_frontalface_default.xml') 49 | 50 | # perform face detection 51 | bboxes = classifier.detectMultiScale(pixels) 52 | 53 | if len(bboxes) == 0: 54 | print("ERROR: No faces found.") 55 | return None 56 | 57 | # extract 58 | x, y, width, height = bboxes[0] 59 | x2, y2 = x + width, y + height 60 | 61 | BUFFER = int(width * 0.25) 62 | 63 | images = [] 64 | 65 | # show the image 66 | for i in range(len(bboxes)): 67 | x, y, width, height = bboxes[i] 68 | x2, y2 = x + width, y + height 69 | images.append(pixels[max(y - BUFFER, 0):min(y2 + BUFFER, pixels.shape[0]), 70 | max(x - BUFFER, 0):min(x2 + BUFFER, pixels.shape[1])]) 71 | # imshow('hehe', images[i]) 72 | # waitKey(0) 73 | images[i] = cv2.cvtColor(images[i], cv2.COLOR_BGR2RGB) 74 | 75 | return images 76 | #cropFace(1) 77 | -------------------------------------------------------------------------------- /gui.py: -------------------------------------------------------------------------------- 1 | from tkinter import * 2 | from PIL import ImageTk, Image 3 | from tkinter import filedialog 4 | import cv2 5 | import os 6 | 7 | import sys 8 | sys.path.append('model') 9 | sys.path.append('utils') 10 | 11 | from utils_SH import * 12 | 13 | # other modules 14 | import numpy as np 15 | 16 | from torch.autograd import Variable 17 | import torch 18 | 19 | from model import HourglassNet 20 | from face_detect.faceDetect import cropFace2 21 | 22 | class ImportImg(): 23 | 24 | def __init__(self): 25 | self.root = Tk() 26 | self.root.title('Face Relighting - Import Image') 27 | 28 | self.pull_button = Button(self.root, text="Select Image To Pull Lighting From", command=self.getPullFile, padx=50, pady=50) 29 | self.apply_button = Button(self.root, text="Select Image To Apply Lighting To", command=self.getApplyFile, padx=50, pady=50) 30 | 31 | self.next_button = Button(self.root, text="Next", state=DISABLED, command=self.loadNextSection) 32 | 33 | self.next_button.grid(row=2, column=1) 34 | self.pull_button.grid(row=0, column=0) 35 | self.apply_button.grid(row=0, column=2) 36 | 37 | self.pull_filepath = None 38 | self.apply_filepath = None 39 | 40 | self.root.mainloop() 41 | 42 | 43 | def getPullFile(self): 44 | self.pull_filepath = filedialog.askopenfilename(title="Select Image", filetypes=(("png", ".png"), ("jpeg", ".jpeg"), ("jpg", 
".jpg"))) 45 | pull_label = Label(self.root, text=self.pull_filepath, bg="blue", fg="white") 46 | pull_label.grid(row=1, column=0) 47 | 48 | if self.apply_filepath != None: 49 | self.nextButtonReady() 50 | 51 | 52 | def getApplyFile(self): 53 | self.apply_filepath = filedialog.askopenfilename(title="Select Image", filetypes=(("png", ".png"), ("jpeg", ".jpeg"), ("jpg", ".jpg"))) 54 | apply_label = Label(self.root, text=self.apply_filepath, bg="blue", fg="white") 55 | apply_label.grid(row=1, column=2) 56 | 57 | if self.pull_filepath != None: 58 | self.nextButtonReady() 59 | 60 | def nextButtonReady(self): 61 | self.next_button.config(state=NORMAL, bg="green") 62 | 63 | def loadNextSection(self): 64 | self.root.destroy() 65 | SelectLightFace(self.pull_filepath, self.apply_filepath) 66 | 67 | class SelectLightFace(): 68 | def __init__(self, pull_filepath, apply_filepath): 69 | self.root = Tk() 70 | # self.root.geometry("600x600") 71 | self.root.title('Face Relighting - Select Lighting Face') 72 | 73 | self.pull_filepath = pull_filepath 74 | self.apply_filepath = apply_filepath 75 | 76 | instruction_label = Label(self.root, text="Pick a face on the left to use as lighting reference") 77 | 78 | apply_img = ImageTk.PhotoImage(Image.open(self.apply_filepath).resize((512, 512))) 79 | apply_img_label = Label(self.root, image=apply_img) 80 | 81 | 82 | self.face_list = cropFace2(self.pull_filepath) 83 | self.face_list_np = self.face_list.copy() 84 | self.convertNumpyImgs(self.face_list) 85 | self.curr = 0 86 | 87 | self.curr_carosel = Label(self.root, image=self.face_list[0]) 88 | 89 | left_button = Button(self.root, text="<<", command=self.left) 90 | right_button = Button(self.root, text=">>", command=self.right) 91 | select_button = Button(self.root, text="Apply Lighting", command=self.select) 92 | 93 | 94 | instruction_label.grid(row=0, column=1) 95 | left_button.grid(row=1, column=0) 96 | self.curr_carosel.grid(row=1, column=1) 97 | right_button.grid(row=1, column=2) 98 | apply_img_label.grid(row=1, column=3) 99 | select_button.grid(row=2, column=1) 100 | 101 | self.root.mainloop() 102 | 103 | def left(self): 104 | if self.curr > 0: 105 | self.curr -= 1 106 | self.curr_carosel.config(image=self.face_list[self.curr]) 107 | 108 | def right(self): 109 | if self.curr < (len(self.face_list) - 1): 110 | self.curr += 1 111 | self.curr_carosel.config(image=self.face_list[self.curr]) 112 | 113 | def select(self): 114 | 115 | pull_img = self.face_list_np[self.curr] 116 | 117 | apply_img = cv2.imread(self.apply_filepath) 118 | 119 | dest = filedialog.askdirectory(title="Select Output Folder") 120 | 121 | Relight(apply_img, pull_img, dest) 122 | 123 | self.root.destroy() 124 | 125 | def convertNumpyImgs(self, np_faces): 126 | for i in range(len(np_faces)): 127 | np_faces[i] = ImageTk.PhotoImage(Image.fromarray(np_faces[i], 'RGB').resize((512, 512))) 128 | 129 | 130 | class Relight(): 131 | def __init__(self, source, light, dest): 132 | self.relighting(source, light, dest) 133 | 134 | def preprocess_image(self, img): 135 | row, col, _ = img.shape 136 | src_img = cv2.resize(img, (256, 256)) 137 | Lab = cv2.cvtColor(src_img, cv2.COLOR_BGR2LAB) 138 | 139 | inputL = Lab[:, :, 0] 140 | inputL = inputL.astype(np.float32) / 255.0 141 | inputL = inputL.transpose((0, 1)) 142 | inputL = inputL[None, None, ...] 
143 | inputL = Variable(torch.from_numpy(inputL)) 144 | 145 | return inputL, row, col, Lab 146 | 147 | def relighting(self, source, light, dest): 148 | # load model 149 | my_network = HourglassNet() 150 | 151 | my_network.load_state_dict(torch.load('trained_models/trained.pt', map_location=torch.device('cpu'))) 152 | 153 | my_network.train(False) 154 | 155 | light_img, _, _, _ = self.preprocess_image(light) 156 | 157 | sh = torch.zeros((1, 9, 1, 1)) 158 | 159 | _, outputSH = my_network(light_img, sh, 0) 160 | 161 | src_img, row, col, Lab = self.preprocess_image(source) 162 | 163 | outputImg, _ = my_network(src_img, outputSH, 0) 164 | 165 | outputImg = outputImg[0].cpu().data.numpy() 166 | outputImg = outputImg.transpose((1, 2, 0)) 167 | outputImg = np.squeeze(outputImg) 168 | outputImg = (outputImg * 255.0).astype(np.uint8) 169 | Lab[:, :, 0] = outputImg 170 | resultLab = cv2.cvtColor(Lab, cv2.COLOR_LAB2BGR) 171 | resultLab = cv2.resize(resultLab, (col, row)) 172 | 173 | cv2.imwrite(os.path.join(dest, 174 | 'relit.jpg'), resultLab) 175 | 176 | ImportImg() -------------------------------------------------------------------------------- /live_lighting_transfer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | #coding: utf8 3 | 4 | import sys 5 | sys.path.append('model') 6 | sys.path.append('utils') 7 | 8 | import cv2 9 | import time 10 | import math 11 | import numpy as np 12 | from scipy import ndimage 13 | from skimage import io 14 | from skimage import img_as_float, img_as_ubyte 15 | from skimage.color import rgb2gray 16 | 17 | import sys 18 | import os 19 | 20 | from torch.autograd import Variable 21 | import torch 22 | import argparse 23 | 24 | from model import * 25 | 26 | def parse_args(): 27 | parser = argparse.ArgumentParser( 28 | description="live image relighting") 29 | parser.add_argument( 30 | '--light_image', 31 | default=None, 32 | help='path to image light to copy', 33 | ) 34 | parser.add_argument( 35 | '--light_text', 36 | default=None, 37 | help='path to lighting matrix to copy', 38 | ) 39 | parser.add_argument( 40 | '--input_path', 41 | default=None, 42 | help='specify a path to a video to be relit, instead of the webcam' 43 | ) 44 | parser.add_argument( 45 | '--output_path', 46 | default='output.avi', 47 | help='specify path to write output video to, if input was specified' 48 | ) 49 | 50 | return parser.parse_args() 51 | 52 | 53 | def preprocess_image(img): 54 | row, col, _ = img.shape 55 | img = cv2.resize(img, (256, 256)) 56 | Lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) #converts image to one color space LAB 57 | 58 | inputL = Lab[:,:,0] #taking only the L channel 59 | inputL = inputL.astype(np.float32)/255.0 #normalise 60 | inputL = inputL.transpose((0,1)) 61 | inputL = inputL[None,None,...] 
#not sure what's happening here 62 | 63 | inputL = Variable(torch.from_numpy(inputL)) 64 | return inputL, row, col, Lab 65 | 66 | def relight_image(model, src_img, target): 67 | src_img, row, col, Lab = preprocess_image(src_img) 68 | 69 | outputImg, _ = model(src_img, target, 0) 70 | 71 | outputImg = outputImg[0].cpu().data.numpy() 72 | outputImg = outputImg.transpose((1,2,0)) 73 | outputImg = np.squeeze(outputImg) 74 | outputImg = (outputImg*255.0).astype(np.uint8) 75 | Lab[:,:,0] = outputImg 76 | resultLab = cv2.cvtColor(Lab, cv2.COLOR_LAB2BGR) 77 | resultLab = cv2.resize(resultLab, (col, row)) 78 | return resultLab 79 | 80 | class live_transfer_handler(): 81 | """ 82 | This function shows the live Fourier transform of a continuous stream of 83 | images captured from an attached camera. 84 | 85 | """ 86 | 87 | wn = "Image Lighting Transfer" 88 | use_camera = True 89 | im = 0 90 | model = 0 91 | target = None 92 | 93 | def __init__(self, target_img_path, target_text_path, input_path, **kwargs): 94 | # Camera device 95 | if (input_path is not None): 96 | self.vc = cv2.VideoCapture(input_path) 97 | else: 98 | self.vc = cv2.VideoCapture(0) 99 | 100 | if not self.vc.isOpened(): 101 | print( "No camera found or error opening camera." ) 102 | self.use_camera = False 103 | return 104 | 105 | else: 106 | # We found a camera! 107 | # Requested camera size. This will be cropped square later on, e.g., 240 x 240 108 | if (input_path is None): 109 | # Set the size of the output window 110 | 111 | self.vc.set(cv2.CAP_PROP_FRAME_WIDTH, 320) 112 | self.vc.set(cv2.CAP_PROP_FRAME_HEIGHT, 240) 113 | cv2.namedWindow(self.wn, 0) 114 | 115 | # load model 116 | my_network = HourglassNet() 117 | my_network.load_state_dict(torch.load("trained_models/trained.pt", map_location=torch.device('cpu'))) 118 | my_network.train(False) 119 | self.model = my_network 120 | 121 | # load target 122 | if target_img_path: 123 | target_img = cv2.imread(target_img_path) 124 | light_img, _, _, _ = preprocess_image(target_img) 125 | sh = torch.zeros((1,9,1,1)) 126 | _, outputSH = self.model(light_img, sh, 0) 127 | self.target = outputSH 128 | 129 | elif target_text_path: 130 | sh = np.loadtxt(target_text_path) 131 | sh = sh[0:9] 132 | sh = sh * 0.5 133 | sh = np.reshape(sh, (1,9,1,1)).astype(np.float32) 134 | outputSH = Variable(torch.from_numpy(sh)) 135 | self.target = outputSH 136 | 137 | else: 138 | print("No target specified") 139 | return 140 | 141 | videoWriter = None 142 | if (ARGS.output_path is not None): 143 | _, img = self.vc.read() 144 | dim1 = img.shape[0] 145 | dim2 = img.shape[1] 146 | #for now I am fixing this as a square because it gets better results 147 | videoWriter = cv2.VideoWriter(ARGS.output_path,cv2.VideoWriter_fourcc(*'MJPG'), 30, (dim1,dim1)) #change to dim2, dim1 to output not a square 148 | 149 | #if flag is true, pass to relighter 150 | # Main loop 151 | end = False 152 | while not end: 153 | a = time.perf_counter() 154 | end = self.relighter(videoWriter) 155 | if (ARGS.input_path is None): 156 | print('framerate = {} fps \r'.format(1. 
/ (time.perf_counter() - a))) 157 | 158 | if self.use_camera: 159 | # Stop camera 160 | self.vc.release() 161 | 162 | def relighter(self, writer = None): 163 | 164 | if self.use_camera: 165 | # Read image 166 | _, im = self.vc.read() 167 | if im is None: 168 | writer.release() 169 | return True 170 | 171 | # if ARGS.input_path is None: 172 | if True: #for now I am fixing this as a square because it gets better results 173 | if im.shape[1] > im.shape[0]: 174 | cropx = int((im.shape[1]-im.shape[0])/2) 175 | cropy = 0 176 | elif im.shape[0] > im.shape[1]: 177 | cropx = 0 178 | cropy = int((im.shape[0]-im.shape[1])/2) 179 | 180 | self.im = im[cropy:im.shape[0]-cropy, cropx:im.shape[1]-cropx] 181 | 182 | else: 183 | self.im = im 184 | 185 | if (ARGS.input_path is None): 186 | # Set size 187 | width = 256 188 | height = 256 189 | cv2.resizeWindow(self.wn, width*2, height*2) 190 | 191 | real = img_as_float(self.im) 192 | relit = relight_image(self.model, self.im, self.target) 193 | relit = img_as_float(relit) 194 | output = np.clip(np.concatenate((real,relit),axis = 1),0,1) 195 | 196 | 197 | if writer is not None: 198 | frame = (relit*255).astype('uint8') 199 | writer.write(frame) 200 | 201 | if (ARGS.input_path is None): 202 | cv2.imshow(self.wn, relit) 203 | cv2.waitKey(1) 204 | 205 | return 206 | 207 | ARGS = parse_args() 208 | 209 | live_transfer_handler(ARGS.light_image, ARGS.light_text, ARGS.input_path) -------------------------------------------------------------------------------- /model/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from torch.autograd import Variable 4 | import torch 5 | import os 6 | from torch.utils.data import Dataset 7 | 8 | class CelebData(Dataset): 9 | def __init__(self, root_dir, max_data): 10 | paths = [] 11 | img_folders = os.listdir(root_dir) 12 | for img_folder in img_folders: 13 | img_path = os.path.join(root_dir, img_folder) 14 | if os.path.isdir(img_path): 15 | paths.append(img_path) 16 | if len(paths) == max_data: 17 | break 18 | if len(paths) == max_data: 19 | break 20 | add = max_data - len(paths) 21 | if add <= 0: 22 | self.paths = paths 23 | else: 24 | self.paths = loop_data_helper(paths, add) 25 | 26 | def __len__(self): 27 | return len(self.paths) 28 | 29 | def __getitem__(self, idx): 30 | folder_path = self.paths[idx] 31 | pair = np.random.choice(5, 2) 32 | img_folder_name = folder_path[-10:] 33 | 34 | image_s_path = os.path.join(folder_path, img_folder_name + "_0" + str(pair[0]) + ".jpg") 35 | image_t_path = os.path.join(folder_path, img_folder_name + "_0" + str(pair[1]) + ".jpg") 36 | lighting_s_path = os.path.join(folder_path, img_folder_name + "_light_0" + str(pair[0]) + ".txt") 37 | lighting_t_path = os.path.join(folder_path, img_folder_name + "_light_0" + str(pair[1]) + ".txt") 38 | 39 | I_s = get_image(image_s_path) 40 | I_t = get_image(image_t_path) 41 | L_s = get_lighting(lighting_s_path) 42 | L_t = get_lighting(lighting_t_path) 43 | 44 | return I_s, I_t, L_s, L_t 45 | 46 | 47 | def get_image(path_to_image): 48 | img = cv2.imread(path_to_image) 49 | Lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) # converts image to one color space LAB 50 | 51 | inputL = Lab[:, :, 0] # taking only the L channel 52 | inputL = inputL.astype(np.float32) / 255.0 # normalise 53 | inputL = inputL.transpose((0, 1)) 54 | inputL = inputL[None, None, ...] 
# not sure what's happening here 55 | inputL = Variable(torch.from_numpy(inputL)) 56 | 57 | return inputL 58 | 59 | 60 | def get_lighting(path_to_light): 61 | sh = np.loadtxt(path_to_light) 62 | sh = sh[0:9] 63 | sh = sh * 0.7 64 | 65 | sh = np.reshape(sh, (1, 9, 1, 1)).astype(np.float32) 66 | sh = Variable(torch.from_numpy(sh)) 67 | return sh 68 | 69 | def loop_data_helper(paths, add_dup): 70 | data_size = len(paths) 71 | orig_paths = paths.copy() 72 | 73 | num_add = add_dup // data_size 74 | for i in range(num_add): 75 | paths.extend(orig_paths) 76 | rem = add_dup % data_size 77 | paths.extend(orig_paths[0:rem]) 78 | 79 | return paths -------------------------------------------------------------------------------- /model/debug.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('model') 3 | sys.path.append('utils') 4 | 5 | from datetime import datetime 6 | 7 | import os 8 | import numpy as np 9 | 10 | from torch.autograd import Variable 11 | import torch 12 | import cv2 13 | 14 | # Debugs the model during training 15 | 16 | def debug(model, epoch, modelId = None): 17 | lightFolder = '../data/test/light/' 18 | imgPath = '../data/test/images/obama.jpg' 19 | 20 | ##### getting image 21 | img = cv2.imread(imgPath) 22 | row, col, _ = img.shape 23 | img = cv2.resize(img, (128, 128)) 24 | Lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) #converts image to one color space LAB 25 | 26 | inputL = Lab[:,:,0] #taking only the L channel 27 | inputL = inputL.astype(np.float32)/255.0 #normalise 28 | inputL = inputL.transpose((0,1)) 29 | inputL = inputL[None,None,...] #not sure what's happening here 30 | inputL = Variable(torch.from_numpy(inputL).cuda()) 31 | 32 | if (epoch == 0): 33 | print("datetime", datetime.now()) 34 | modelId = datetime.now().strftime("%d/%m/%Y %H:%M:%S") 35 | newModelId = modelId[0:2] + "&" + modelId[3:5] + "&" + modelId[6:10] + "," + modelId[11:] 36 | modelId = newModelId 37 | 38 | # newModelId = "" 39 | # for i in range(len(modelId)): 40 | # if i == 10: 41 | # newModelId += "," 42 | # elif i == 2 or i == 5: 43 | # newModelId += "&" 44 | # else: 45 | # newModelId += modelId[i] 46 | # modelId = newModelId 47 | 48 | 49 | print("Fixed modelId:", modelId) 50 | 51 | saveFolder = '../result/debug/' + modelId 52 | if not os.path.exists(saveFolder): 53 | os.makedirs(saveFolder) 54 | 55 | 56 | for i in range(7): 57 | ##### getting sh 58 | 59 | sh = np.loadtxt(os.path.join(lightFolder, 'rotate_light_{:02d}.txt'.format(i))) 60 | sh = sh[0:9] 61 | sh = sh * 0.7 62 | sh = np.reshape(sh, (1,9,1,1)).astype(np.float32) 63 | sh = Variable(torch.from_numpy(sh).cuda()) 64 | ##### 65 | 66 | outputImg, outputSH = model.forward(inputL, sh, 0) 67 | outputImg = outputImg[0].cpu().data.numpy() 68 | outputImg = outputImg.transpose((1,2,0)) 69 | outputImg = np.squeeze(outputImg) 70 | outputImg = (outputImg*255.0).astype(np.uint8) 71 | Lab[:,:,0] = outputImg 72 | resultLab = cv2.cvtColor(Lab, cv2.COLOR_LAB2BGR) 73 | resultLab = cv2.resize(resultLab, (col, row)) 74 | #img_name, e = os.path.splitext(ARGS.image) 75 | img_name = "Epoch" + str(epoch) + "Light" + '{:02}'.format(i) 76 | 77 | cv2.imwrite(os.path.join(saveFolder, 78 | '{}.jpg'.format(img_name)), resultLab) 79 | 80 | return modelId -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia.filters import SpatialGradient 3 | 4 | # Loss 
function for training 5 | 6 | def L1(I_t, I_tp, L_s, L_sp): 7 | img_l1 = torch.sum(torch.abs(I_t - I_tp)) 8 | 9 | I_t_grad = SpatialGradient()(I_t) 10 | I_tp_grad = SpatialGradient()(I_tp) 11 | 12 | grad_l1 = torch.sum(torch.abs(I_t_grad - I_tp_grad)) 13 | 14 | light_l2 = torch.sum((L_s - L_sp) ** 2) 15 | 16 | N = I_t.shape[2] 17 | loss = ((img_l1 + grad_l1) / (N * N)) + (light_l2 / 9) 18 | return loss -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | # This code is taken directly from https://github.com/zhhoper/DPR 6 | 7 | # we define Hour Glass network based on the paper 8 | # Stacked Hourglass Networks for Human Pose Estimation 9 | # Alejandro Newell, Kaiyu Yang, and Jia Deng 10 | # the code is adapted from 11 | # https://github.com/umich-vl/pose-hg-train/blob/master/src/models/hg.lua 12 | 13 | 14 | def conv3X3(in_planes, out_planes, stride=1): 15 | """3x3 convolution with padding""" 16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 17 | padding=1, bias=False) 18 | 19 | # define the network 20 | class BasicBlock(nn.Module): 21 | def __init__(self, inplanes, outplanes, batchNorm_type=0, stride=1, downsample=None): 22 | super(BasicBlock, self).__init__() 23 | # batchNorm_type 0 means batchnormalization 24 | # 1 means instance normalization 25 | self.inplanes = inplanes 26 | self.outplanes = outplanes 27 | self.conv1 = conv3X3(inplanes, outplanes, 1) 28 | self.conv2 = conv3X3(outplanes, outplanes, 1) 29 | if batchNorm_type == 0: 30 | self.bn1 = nn.BatchNorm2d(outplanes) 31 | self.bn2 = nn.BatchNorm2d(outplanes) 32 | else: 33 | self.bn1 = nn.InstanceNorm2d(outplanes) 34 | self.bn2 = nn.InstanceNorm2d(outplanes) 35 | 36 | self.shortcuts = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False) 37 | 38 | def forward(self, x): 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = F.relu(out) 42 | out = self.conv2(out) 43 | out = self.bn2(out) 44 | 45 | if self.inplanes != self.outplanes: 46 | out += self.shortcuts(x) 47 | else: 48 | out += x 49 | 50 | out = F.relu(out) 51 | return out 52 | 53 | 54 | class HourglassBlock(nn.Module): 55 | ''' 56 | define a basic block for hourglass neetwork 57 | ^-------------------------upper conv------------------- 58 | | | 59 | | V 60 | input------>downsample-->low1-->middle-->low2-->upsample-->+-->output 61 | NOTE about output: 62 | Since we need the lighting from the inner most layer, 63 | let's also output the results from middel layer 64 | ''' 65 | 66 | def __init__(self, inplane, mid_plane, middleNet, skipLayer=True): 67 | super(HourglassBlock, self).__init__() 68 | # upper branch 69 | self.skipLayer = True 70 | self.upper = BasicBlock(inplane, inplane, batchNorm_type=1) 71 | 72 | # lower branch 73 | self.downSample = nn.MaxPool2d(kernel_size=2, stride=2) 74 | self.upSample = nn.Upsample(scale_factor=2, mode='nearest') 75 | self.low1 = BasicBlock(inplane, mid_plane) 76 | self.middle = middleNet 77 | self.low2 = BasicBlock(mid_plane, inplane, batchNorm_type=1) 78 | 79 | def forward(self, x, light, count, skip_count): 80 | # we use count to indicate wich layer we are in 81 | # max_count indicates the from which layer, we would use skip connections 82 | out_upper = self.upper(x) 83 | out_lower = self.downSample(x) 84 | out_lower = self.low1(out_lower) 85 | out_lower, out_middle = self.middle(out_lower, 
light, count + 1, skip_count) 86 | out_lower = self.low2(out_lower) 87 | out_lower = self.upSample(out_lower) 88 | 89 | if count >= skip_count and self.skipLayer: 90 | # withSkip is true, then we use skip layer 91 | # easy for analysis 92 | out = out_lower + out_upper 93 | else: 94 | out = out_lower 95 | # out = out_upper 96 | return out, out_middle 97 | 98 | 99 | class lightingNet(nn.Module): 100 | ''' 101 | define lighting network 102 | ''' 103 | 104 | def __init__(self, ncInput, ncOutput, ncMiddle): 105 | super(lightingNet, self).__init__() 106 | self.ncInput = ncInput 107 | self.ncOutput = ncOutput 108 | self.ncMiddle = ncMiddle 109 | 110 | # basic idea is to compute the average of the channel corresponding to lighting 111 | # using fully connected layers to get the lighting 112 | # then fully connected layers to get back to the output size 113 | 114 | self.predict_FC1 = nn.Conv2d(self.ncInput, self.ncMiddle, kernel_size=1, stride=1, bias=False) 115 | self.predict_relu1 = nn.PReLU() 116 | self.predict_FC2 = nn.Conv2d(self.ncMiddle, self.ncOutput, kernel_size=1, stride=1, bias=False) 117 | 118 | self.post_FC1 = nn.Conv2d(self.ncOutput, self.ncMiddle, kernel_size=1, stride=1, bias=False) 119 | self.post_relu1 = nn.PReLU() 120 | self.post_FC2 = nn.Conv2d(self.ncMiddle, self.ncInput, kernel_size=1, stride=1, bias=False) 121 | self.post_relu2 = nn.ReLU() # to be consistance with the original feature 122 | 123 | def forward(self, innerFeat, target_light, count, skip_count): 124 | x = innerFeat[:, 0:self.ncInput, :, :] # lighting feature 125 | _, _, row, col = x.shape 126 | 127 | # predict lighting 128 | feat = x.mean(dim=(2, 3), keepdim=True) 129 | light = self.predict_relu1(self.predict_FC1(feat)) 130 | light = self.predict_FC2(light) 131 | 132 | # get back the feature space 133 | upFeat = self.post_relu1(self.post_FC1(target_light)) 134 | upFeat = self.post_relu2(self.post_FC2(upFeat)) 135 | upFeat = upFeat.repeat((1, 1, row, col)) 136 | innerFeat[:, 0:self.ncInput, :, :] = upFeat 137 | return innerFeat, light #(old return statement pre Zf) 138 | #return innerFeat, innerFeat[:, self.ncInput:, :, :], light 139 | 140 | 141 | class HourglassNet(nn.Module): 142 | ''' 143 | basic idea: low layers are shared, upper layers are different 144 | lighting should be estimated from the inner most layer 145 | NOTE: we split the bottle neck layer into albedo, normal and lighting 146 | ''' 147 | 148 | def __init__(self, baseFilter=16, gray=True): 149 | super(HourglassNet, self).__init__() 150 | 151 | self.ncLight = 27 # number of channels for input to lighting network 152 | self.baseFilter = baseFilter 153 | 154 | # number of channles for output of lighting network 155 | if gray: 156 | self.ncOutLight = 9 # gray: channel is 1 157 | else: 158 | self.ncOutLight = 27 # color: channel is 3 159 | 160 | self.ncPre = self.baseFilter # number of channels for pre-convolution 161 | 162 | # number of channels 163 | self.ncHG3 = self.baseFilter 164 | self.ncHG2 = 2 * self.baseFilter 165 | self.ncHG1 = 4 * self.baseFilter 166 | self.ncHG0 = 8 * self.baseFilter + self.ncLight 167 | 168 | self.pre_conv = nn.Conv2d(1, self.ncPre, kernel_size=5, stride=1, padding=2) 169 | self.pre_bn = nn.BatchNorm2d(self.ncPre) 170 | 171 | self.light = lightingNet(self.ncLight, self.ncOutLight, 128) 172 | self.HG0 = HourglassBlock(self.ncHG1, self.ncHG0, self.light) 173 | self.HG1 = HourglassBlock(self.ncHG2, self.ncHG1, self.HG0) 174 | self.HG2 = HourglassBlock(self.ncHG3, self.ncHG2, self.HG1) 175 | self.HG3 = 
HourglassBlock(self.ncPre, self.ncHG3, self.HG2) 176 | 177 | self.conv_1 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=3, stride=1, padding=1) 178 | self.bn_1 = nn.BatchNorm2d(self.ncPre) 179 | self.conv_2 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=1, stride=1, padding=0) 180 | self.bn_2 = nn.BatchNorm2d(self.ncPre) 181 | self.conv_3 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=1, stride=1, padding=0) 182 | self.bn_3 = nn.BatchNorm2d(self.ncPre) 183 | 184 | self.output = nn.Conv2d(self.ncPre, 1, kernel_size=1, stride=1, padding=0) 185 | 186 | def forward(self, x, target_light, skip_count): 187 | feat = self.pre_conv(x) 188 | feat = F.relu(self.pre_bn(feat)) 189 | # get the inner most features 190 | feat, out_light = self.HG3(feat, target_light, 0, skip_count) 191 | #feat, out_features, out_light = self.HG3(feat, target_light, 0, skip_count) 192 | 193 | feat = F.relu(self.bn_1(self.conv_1(feat))) 194 | feat = F.relu(self.bn_2(self.conv_2(feat))) 195 | feat = F.relu(self.bn_3(self.conv_3(feat))) 196 | out_img = self.output(feat) 197 | out_img = torch.sigmoid(out_img) 198 | return out_img, out_light 199 | 200 | 201 | if __name__ == '__main__': 202 | pass 203 | -------------------------------------------------------------------------------- /model/train.py: -------------------------------------------------------------------------------- 1 | # Libraries 2 | import torch 3 | import time 4 | import os 5 | import argparse 6 | from torch.utils.data import DataLoader 7 | from datetime import datetime 8 | 9 | # Local Files 10 | from model import HourglassNet 11 | from loss import L1 12 | from data import CelebData 13 | from debug import debug 14 | 15 | # Script to train the model, which is saved in trained_models 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser( 19 | description="Train a new model.") 20 | parser.add_argument( 21 | '--epochs', 22 | default=10, 23 | type=int, 24 | help='number of epochs', 25 | ) 26 | parser.add_argument( 27 | '--batch', 28 | default=100, 29 | type=int, 30 | help='batch size' 31 | ) 32 | parser.add_argument( 33 | '--lr', 34 | type=float, 35 | default=0.0001, 36 | help='learning rate for Adam optimizer' 37 | ) 38 | parser.add_argument( 39 | '--data', 40 | default=30000, 41 | type=int, 42 | help='number of data points to use' 43 | ) 44 | parser.add_argument( 45 | '--model', 46 | default=None, 47 | help='name of the model to be saved' 48 | ) 49 | parser.add_argument( 50 | '--verbose', 51 | action='store_true', 52 | help='print additional information') 53 | parser.add_argument( 54 | '--debug', 55 | action='store_true', 56 | help='debug model by outputting intermediate images') 57 | 58 | return parser.parse_args() 59 | 60 | ARGS = parse_args() 61 | # Settings 62 | VERBOSE = bool(ARGS.verbose) 63 | DEBUG = bool(ARGS.debug) 64 | 65 | # Hyper parameters 66 | EPOCHS = ARGS.epochs 67 | BATCH_SIZE = ARGS.batch 68 | LEARNING_RATE = ARGS.lr 69 | MAX_DATA = ARGS.data 70 | 71 | def train(model, optimizer, dataloader, skip_count): 72 | 73 | num_batches = MAX_DATA // BATCH_SIZE 74 | 75 | epoch_loss = torch.tensor([0], dtype=torch.float32).cuda() 76 | 77 | for j, data in enumerate(dataloader, 0): 78 | I_sbatch, I_tbatch, L_sbatch, L_tbatch = data 79 | 80 | I_sbatch = torch.squeeze(I_sbatch, dim=1).cuda() 81 | L_tbatch = torch.squeeze(L_tbatch, dim=1).cuda() 82 | 83 | I_tbatch = torch.squeeze(I_tbatch, dim=1).cuda() 84 | L_sbatch = torch.squeeze(L_sbatch, dim=1).cuda() 85 | 86 | I_tp_batch, L_sp_batch = model.forward(I_sbatch, L_tbatch, skip_count) 87 | 88 | 
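        # L1 (model/loss.py): per-pixel L1 + image-gradient L1 (kornia SpatialGradient), both
        # normalized by N*N pixels, plus an L2 penalty on the 9 predicted SH coefficients:
        #   loss = (|I_t - I_t'|_1 + |grad I_t - grad I_t'|_1) / N^2 + |L_s - L_s'|_2^2 / 9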
loss = L1(I_tbatch, I_tp_batch, L_sbatch, L_sp_batch) 89 | 90 | if (VERBOSE): 91 | print("Batch # {} / {} loss: {}".format(j + 1, num_batches, loss)) 92 | 93 | epoch_loss += loss 94 | 95 | optimizer.zero_grad() 96 | loss.backward() 97 | optimizer.step() 98 | 99 | epoch_loss = epoch_loss / num_batches 100 | print("Epoch loss: ", epoch_loss) 101 | 102 | model = HourglassNet(gray=True) 103 | model.cuda() 104 | model.train(True) 105 | modelId = None 106 | optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) 107 | 108 | dataset = CelebData('../data/train/', int(ARGS.data)) 109 | dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True) 110 | for i in range(EPOCHS): 111 | if (DEBUG): 112 | print("Outputing debug image.") 113 | if (i == 0): 114 | modelId = debug(model, i) 115 | else: 116 | debug(model, i, modelId) 117 | print("Finished outputting debug image. Continuing training") 118 | 119 | start = time.time() 120 | print("Training epoch #", i + 1, "/", EPOCHS) 121 | 122 | train(model, optimizer, dataloader, 0) 123 | end = time.time() 124 | print("Time elapsed to train epoch #", i + 1, ":", end - start) 125 | 126 | if ARGS.model is None: 127 | now = datetime.now() 128 | model_name = 'model_{}.pt'.format(now.strftime("%m-%d-%H%M")) 129 | else: 130 | model_name = ARGS.model 131 | 132 | print("Done training! Saving model as {}".format(model_name)) 133 | torch.save(model.state_dict(), os.path.join('../trained_models/', model_name)) -------------------------------------------------------------------------------- /relight.py: -------------------------------------------------------------------------------- 1 | ''' 2 | this is a simple test file 3 | ''' 4 | import sys 5 | sys.path.append('model') 6 | sys.path.append('utils') 7 | 8 | from utils_SH import * 9 | 10 | from face_detect.faceDetect import cropFace 11 | 12 | # other modules 13 | import os 14 | import numpy as np 15 | 16 | from torch.autograd import Variable 17 | import torch 18 | import cv2 19 | import argparse 20 | 21 | # This code is adapted from https://github.com/zhhoper/DPR 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser( 25 | description="image relighting training.") 26 | parser.add_argument( 27 | '--source_image', 28 | default='obama.jpg', 29 | help='name of image stored in data/', 30 | ) 31 | parser.add_argument( 32 | '--light_image', 33 | default='obama.jpg', 34 | help='name of image stored in data/', 35 | ) 36 | parser.add_argument( 37 | '--model', 38 | default='trained.pt', 39 | help='model file to use stored in trained_model/' 40 | ) 41 | parser.add_argument( 42 | '--gpu', 43 | action='store_true', 44 | help='cpu vs. gpu' 45 | ) 46 | parser.add_argument( 47 | '--face_detect', 48 | default='Neither', 49 | help='Options: "both" or "light". Face detection/cropping for more accurate relighting.' 50 | ) 51 | 52 | 53 | return parser.parse_args() 54 | 55 | def preprocess_image(img_path, srcOrLight): 56 | src_img = cv2.imread(img_path) 57 | if (ARGS.face_detect == 'both') or (ARGS.face_detect == 'light' and srcOrLight == 2): 58 | src_img = cropFace(src_img) 59 | row, col, _ = src_img.shape 60 | src_img = cv2.resize(src_img, (256, 256)) 61 | Lab = cv2.cvtColor(src_img, cv2.COLOR_BGR2LAB) #converts image to one color space LAB 62 | 63 | inputL = Lab[:,:,0] #taking only the L channel 64 | inputL = inputL.astype(np.float32)/255.0 #normalise 65 | inputL = inputL.transpose((0,1)) 66 | inputL = inputL[None,None,...] 
#not sure what's happening here 67 | inputL = Variable(torch.from_numpy(inputL)) 68 | if (ARGS.gpu): 69 | inputL = inputL.cuda() 70 | return inputL, row, col, Lab 71 | 72 | 73 | ARGS = parse_args() 74 | 75 | modelFolder = 'trained_models/' 76 | 77 | # load model 78 | from model import * 79 | my_network = HourglassNet() 80 | 81 | if (ARGS.gpu): 82 | my_network.load_state_dict(torch.load(os.path.join(modelFolder, ARGS.model))) 83 | my_network.cuda() 84 | else: 85 | my_network.load_state_dict(torch.load(os.path.join(modelFolder, ARGS.model), map_location=torch.device('cpu'))) 86 | 87 | my_network.train(False) 88 | 89 | saveFolder = 'result' 90 | saveFolder = os.path.join(saveFolder, ARGS.model.split(".")[0]) 91 | if not os.path.exists(saveFolder): 92 | os.makedirs(saveFolder) 93 | 94 | light_img, _, _, _ = preprocess_image('data/test/images/{}'.format(ARGS.light_image), 2) 95 | 96 | sh = torch.zeros((1,9,1,1)) 97 | if (ARGS.gpu): 98 | sh = sh.cuda() 99 | 100 | _, outputSH = my_network(light_img, sh, 0) 101 | 102 | src_img, row, col, Lab = preprocess_image('data/test/images/{}'.format(ARGS.source_image), 1) 103 | 104 | outputImg, _ = my_network(src_img, outputSH, 0) 105 | 106 | 107 | outputImg = outputImg[0].cpu().data.numpy() 108 | outputImg = outputImg.transpose((1,2,0)) 109 | outputImg = np.squeeze(outputImg) 110 | outputImg = (outputImg*255.0).astype(np.uint8) 111 | Lab[:,:,0] = outputImg 112 | resultLab = cv2.cvtColor(Lab, cv2.COLOR_LAB2BGR) 113 | resultLab = cv2.resize(resultLab, (col, row)) 114 | img_name, e = os.path.splitext(ARGS.source_image) 115 | if (ARGS.face_detect == 'both'): 116 | img_name += "_faceDetectBoth" 117 | if (ARGS.face_detect == 'light'): 118 | img_name += "_faceDetectLight" 119 | cv2.imwrite(os.path.join(saveFolder, 120 | '{}_relit.jpg'.format(img_name)), resultLab) 121 | #---------------------------------------------- -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | kornia>=0.3.0 2 | numpy>=1.18.1 3 | opencv-python>=3.4.8.29 4 | torch>=1.5.0 5 | torchvision>=0.6.0 6 | pillow>=7.0.0 7 | -------------------------------------------------------------------------------- /testPostRotate.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/APikielny/image-relighting/edaa4aa0e2f2ed846810e8b7df35d7c33ee6a557/testPostRotate.jpg -------------------------------------------------------------------------------- /testPreRotate.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/APikielny/image-relighting/edaa4aa0e2f2ed846810e8b7df35d7c33ee6a557/testPreRotate.jpg -------------------------------------------------------------------------------- /test_network.py: -------------------------------------------------------------------------------- 1 | ''' 2 | this is a simple test file 3 | ''' 4 | import sys 5 | sys.path.append('model') 6 | sys.path.append('utils') 7 | 8 | from utils_SH import * 9 | 10 | # other modules 11 | import os 12 | import numpy as np 13 | 14 | from torch.autograd import Variable 15 | import torch 16 | import cv2 17 | import argparse 18 | 19 | # This code is adapted from https://github.com/zhhoper/DPR 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser( 23 | description="image relighting training.") 24 | parser.add_argument( 25 | '--image', 26 | default='obama.jpg', 27 | help='name of 
image stored in data/', 28 | ) 29 | parser.add_argument( 30 | '--model', 31 | default='trained.pt', 32 | help='model file to use stored in trained_model/' 33 | ) 34 | parser.add_argument( 35 | '--gpu', 36 | action='store_true', 37 | help='cpu vs. gpu' 38 | ) 39 | 40 | return parser.parse_args() 41 | 42 | 43 | ARGS = parse_args() 44 | 45 | # ---------------- create normal for rendering half sphere ------ 46 | img_size = 256 47 | x = np.linspace(-1, 1, img_size) 48 | z = np.linspace(1, -1, img_size) 49 | x, z = np.meshgrid(x, z) 50 | 51 | mag = np.sqrt(x**2 + z**2) 52 | valid = mag <=1 53 | y = -np.sqrt(1 - (x*valid)**2 - (z*valid)**2) 54 | x = x * valid 55 | y = y * valid 56 | z = z * valid 57 | normal = np.concatenate((x[...,None], y[...,None], z[...,None]), axis=2) 58 | normal = np.reshape(normal, (-1, 3)) 59 | #----------------------------------------------------------------- 60 | 61 | modelFolder = 'trained_models/' 62 | 63 | # load model 64 | from model import * 65 | my_network = HourglassNet() 66 | if (ARGS.gpu): 67 | my_network.load_state_dict(torch.load(os.path.join(modelFolder, ARGS.model))) 68 | my_network.cuda() 69 | else: 70 | my_network.load_state_dict(torch.load(os.path.join(modelFolder, ARGS.model), map_location=torch.device('cpu'))) 71 | my_network.train(False) 72 | 73 | lightFolder = 'data/test/light/' 74 | 75 | saveFolder = 'result' 76 | saveFolder = os.path.join(saveFolder, ARGS.model.split(".")[0]) 77 | if not os.path.exists(saveFolder): 78 | os.makedirs(saveFolder) 79 | 80 | img = cv2.imread('data/test/images/{}'.format(ARGS.image)) 81 | row, col, _ = img.shape 82 | img = cv2.resize(img, (256, 256)) 83 | Lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) #converts image to one color space LAB 84 | 85 | inputL = Lab[:,:,0] #taking only the L channel 86 | inputL = inputL.astype(np.float32)/255.0 #normalise 87 | inputL = inputL.transpose((0,1)) 88 | inputL = inputL[None,None,...] 
# add batch and channel dims: (H, W) -> (1, 1, H, W) 89 | inputL = Variable(torch.from_numpy(inputL)) 90 | if ARGS.gpu: 91 | inputL = inputL.cuda() 92 | 93 | def render_half_sphere(sh, output): # save a preview of the SH lighting rendered on a half sphere 94 | sh = np.squeeze(sh) 95 | shading = get_shading(normal, sh) 96 | value = np.percentile(shading, 95) 97 | ind = shading > value 98 | shading[ind] = value 99 | shading = (shading - np.min(shading))/(np.max(shading) - np.min(shading)) 100 | shading = (shading *255.0).astype(np.uint8) 101 | shading = np.reshape(shading, (256, 256)) 102 | shading = shading * valid 103 | if output: 104 | cv2.imwrite(os.path.join(saveFolder,'light_predicted_{:02d}.png'.format(i)), shading) 105 | else: 106 | cv2.imwrite(os.path.join(saveFolder,'light_{:02d}.png'.format(i)), shading) 107 | 108 | for i in range(7): 109 | sh = np.loadtxt(os.path.join(lightFolder, 'rotate_light_{:02d}.txt'.format(i))) 110 | sh = sh[0:9] 111 | sh = sh * 0.5 112 | 113 | render_half_sphere(sh, False) 114 | 115 | # rendering images using the network 116 | sh = np.reshape(sh, (1,9,1,1)).astype(np.float32) 117 | sh = Variable(torch.from_numpy(sh)) 118 | if ARGS.gpu: 119 | sh = sh.cuda() 120 | #sh = Variable(torch.from_numpy(sh)) 121 | 122 | 123 | outputImg, outputSH = my_network(inputL, sh, 0) 124 | 125 | outputSH = outputSH.cpu().data.numpy() 126 | render_half_sphere(outputSH, True) 127 | 128 | outputImg = outputImg[0].cpu().data.numpy() 129 | outputImg = outputImg.transpose((1,2,0)) 130 | outputImg = np.squeeze(outputImg) 131 | outputImg = (outputImg*255.0).astype(np.uint8) 132 | Lab[:,:,0] = outputImg 133 | resultLab = cv2.cvtColor(Lab, cv2.COLOR_LAB2BGR) 134 | resultLab = cv2.resize(resultLab, (col, row)) 135 | img_name, e = os.path.splitext(ARGS.image) 136 | cv2.imwrite(os.path.join(saveFolder, 137 | '{}_{:02d}.jpg'.format(img_name,i)), resultLab) 138 | #---------------------------------------------- -------------------------------------------------------------------------------- /trained_models/official.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/APikielny/image-relighting/edaa4aa0e2f2ed846810e8b7df35d7c33ee6a557/trained_models/official.t7 -------------------------------------------------------------------------------- /trained_models/trained.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/APikielny/image-relighting/edaa4aa0e2f2ed846810e8b7df35d7c33ee6a557/trained_models/trained.pt -------------------------------------------------------------------------------- /utils/clean_data.py: -------------------------------------------------------------------------------- 1 | # Script to descend into a DPR folder, remove all excess 2 | # information and resize images.
3 | # 4 | # Usage: python clean_data.py --dir --size --save 5 | # - dir_path: path to a DPR folder containing image folders 6 | # - size: size to rescale images to 7 | # - save_path: path to a folder to save all image folders 8 | 9 | from PIL import Image 10 | import os, shutil 11 | import argparse 12 | 13 | def clean_data(path, size, save): 14 | # Loop through all folders in the dir 15 | image_folders = os.listdir(path) 16 | for imgf in image_folders: 17 | imgf_path = os.path.join(path, imgf) 18 | if os.path.isdir(imgf_path): 19 | save_path = os.path.join(save, imgf) 20 | if not os.path.exists(save_path): 21 | os.makedirs(save_path) 22 | for item in os.listdir(imgf_path): 23 | i_path = os.path.join(imgf_path, item) 24 | i_name = os.path.basename(i_path) 25 | 26 | if (imgf in i_name) and ('.png' in i_name): 27 | im = Image.open(i_path) 28 | im = im.resize((size, size)) 29 | im.save(os.path.join(save_path, (i_name[:len(i_name)-4] + ".jpg")), 'JPEG') 30 | elif '_light_' in i_name: 31 | shutil.copyfile(i_path, os.path.join(save_path, i_name)) 32 | 33 | def parse_args(): 34 | parser = argparse.ArgumentParser( 35 | description="Script to clean data for image-relighting") 36 | parser.add_argument( 37 | '--size', 38 | required=True, 39 | type=int, 40 | help='the size to rescale images to') 41 | parser.add_argument( 42 | '--dir', 43 | required=True, 44 | help='the directory containing all the data' 45 | ) 46 | parser.add_argument( 47 | '--save', 48 | required=True, 49 | help='path to save data' 50 | ) 51 | return parser.parse_args() 52 | 53 | ARGS = parse_args() 54 | size = ARGS.size 55 | path = ARGS.dir 56 | save = ARGS.save 57 | 58 | clean_data(path, size, save) -------------------------------------------------------------------------------- /utils/utils_SH.py: -------------------------------------------------------------------------------- 1 | ''' 2 | construct shading using sh basis 3 | ''' 4 | import numpy as np 5 | 6 | # This code is taken directly from https://github.com/zhhoper/DPR 7 | 8 | def SH_basis(normal): 9 | ''' 10 | get SH basis based on normal 11 | normal is a Nx3 matrix 12 | return a Nx9 matrix 13 | The order of SH here is: 14 | 1, Y, Z, X, YX, YZ, 3Z^2-1, XZ, X^2-y^2 15 | ''' 16 | numElem = normal.shape[0] 17 | 18 | norm_X = normal[:,0] 19 | norm_Y = normal[:,1] 20 | norm_Z = normal[:,2] 21 | 22 | sh_basis = np.zeros((numElem, 9)) 23 | att= np.pi*np.array([1, 2.0/3.0, 1/4.0]) 24 | sh_basis[:,0] = 0.5/np.sqrt(np.pi)*att[0] 25 | 26 | sh_basis[:,1] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_Y*att[1] 27 | sh_basis[:,2] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_Z*att[1] 28 | sh_basis[:,3] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_X*att[1] 29 | 30 | sh_basis[:,4] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_Y*norm_X*att[2] 31 | sh_basis[:,5] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_Y*norm_Z*att[2] 32 | sh_basis[:,6] = np.sqrt(5)/4/np.sqrt(np.pi)*(3*norm_Z**2-1)*att[2] 33 | sh_basis[:,7] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_X*norm_Z*att[2] 34 | sh_basis[:,8] = np.sqrt(15)/4/np.sqrt(np.pi)*(norm_X**2-norm_Y**2)*att[2] 35 | return sh_basis 36 | 37 | def SH_basis_noAtt(normal): 38 | ''' 39 | get SH basis based on normal 40 | normal is a Nx3 matrix 41 | return a Nx9 matrix 42 | The order of SH here is: 43 | 1, Y, Z, X, YX, YZ, 3Z^2-1, XZ, X^2-y^2 44 | ''' 45 | numElem = normal.shape[0] 46 | 47 | norm_X = normal[:,0] 48 | norm_Y = normal[:,1] 49 | norm_Z = normal[:,2] 50 | 51 | sh_basis = np.zeros((numElem, 9)) 52 | sh_basis[:,0] = 0.5/np.sqrt(np.pi) 53 | 54 | sh_basis[:,1] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_Y 55 | 
sh_basis[:,2] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_Z 56 | sh_basis[:,3] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_X 57 | 58 | sh_basis[:,4] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_Y*norm_X 59 | sh_basis[:,5] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_Y*norm_Z 60 | sh_basis[:,6] = np.sqrt(5)/4/np.sqrt(np.pi)*(3*norm_Z**2-1) 61 | sh_basis[:,7] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_X*norm_Z 62 | sh_basis[:,8] = np.sqrt(15)/4/np.sqrt(np.pi)*(norm_X**2-norm_Y**2) 63 | return sh_basis 64 | 65 | def get_shading(normal, SH): 66 | ''' 67 | get shading based on normals and SH 68 | normal is Nx3 matrix 69 | SH: 9 x m vector 70 | return Nxm vector, where m is the number of returned images 71 | ''' 72 | sh_basis = SH_basis(normal) 73 | shading = np.matmul(sh_basis, SH) 74 | #shading = np.matmul(np.reshape(sh_basis, (-1, 9)), SH) 75 | #shading = np.reshape(shading, normal.shape[0:2]) 76 | return shading 77 | 78 | def SH_basis_debug(normal): 79 | ''' 80 | get SH basis based on normal 81 | normal is a Nx3 matrix 82 | return a Nx9 matrix 83 | The order of SH here is: 84 | 1, Y, Z, X, YX, YZ, 3Z^2-1, XZ, X^2-y^2 85 | ''' 86 | numElem = normal.shape[0] 87 | 88 | norm_X = normal[:,0] 89 | norm_Y = normal[:,1] 90 | norm_Z = normal[:,2] 91 | 92 | sh_basis = np.zeros((numElem, 9)) 93 | att= np.pi*np.array([1, 2.0/3.0, 1/4.0]) 94 | # att = [1,1,1] 95 | sh_basis[:,0] = 0.5/np.sqrt(np.pi)*att[0] 96 | 97 | sh_basis[:,1] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_Y*att[1] 98 | sh_basis[:,2] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_Z*att[1] 99 | sh_basis[:,3] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_X*att[1] 100 | 101 | sh_basis[:,4] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_Y*norm_X*att[2] 102 | sh_basis[:,5] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_Y*norm_Z*att[2] 103 | sh_basis[:,6] = np.sqrt(5)/4/np.sqrt(np.pi)*(3*norm_Z**2-1)*att[2] 104 | sh_basis[:,7] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_X*norm_Z*att[2] 105 | sh_basis[:,8] = np.sqrt(15)/4/np.sqrt(np.pi)*(norm_X**2-norm_Y**2)*att[2] 106 | return sh_basis 107 | 108 | def get_shading_debug(normal, SH): 109 | ''' 110 | get shading based on normals and SH 111 | normal is Nx3 matrix 112 | SH: 9 x m vector 113 | return Nxm vector, where m is the number of returned images 114 | ''' 115 | sh_basis = SH_basis_debug(normal) 116 | shading = np.matmul(sh_basis, SH) 117 | #shading = sh_basis*SH[0] 118 | return shading 119 | -------------------------------------------------------------------------------- /utils/utils_SH.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/APikielny/image-relighting/edaa4aa0e2f2ed846810e8b7df35d7c33ee6a557/utils/utils_SH.pyc -------------------------------------------------------------------------------- /utils/utils_normal.py: -------------------------------------------------------------------------------- 1 | ''' 2 | adjust normals according to which SH we want to use 3 | ''' 4 | import numpy as np 5 | import sys 6 | from utils_shtools import * 7 | from pyshtools.rotate import djpi2, SHRotateRealCoef 8 | 9 | # This code is taken directly from https://github.com/zhhoper/DPR 10 | 11 | class sh_cvt(): 12 | ''' 13 | the normal direction we get from projection is: 14 | 15 | > z 16 | | / 17 | | / 18 | |/ 19 | --------------------------> x 20 | | 21 | | 22 | v y 23 | 24 | the x, y, z direction of SH from SHtools is 25 | ^ z > y 26 | | / 27 | | / 28 | |/ 29 | --------------------------> x 30 | | 31 | | 32 | 33 | the bip lighting coordinate is 34 | > z 35 | | / 36 | | / 37 | |/ 38 | <-------------------------- 39 | x | 40 | | 41 | v y 42 
| 43 | the sfs lighting coordinate is 44 | | 45 | | 46 | --------------------------> y 47 | / | 48 | / | 49 | z / v x 50 | ''' 51 | def __init__(self): 52 | self.SH_DEGREE = 2 53 | self.dj = djpi2(self.SH_DEGREE) 54 | 55 | 56 | def cvt2shtools(self, normalImages): 57 | ''' 58 | align coordinates of normal with shtools 59 | ''' 60 | newNormals = normalImages.copy() 61 | # new y is the old z 62 | newNormals[:,:,1] = normalImages[:,:,2] 63 | # new z is the negative old y 64 | newNormals[:,:,2] = -1*normalImages[:,:,1] 65 | return newNormals 66 | 67 | def bip2shtools(self, lighting): 68 | ''' 69 | lighting is n x 9 matrix of bip lighting, we want to convert it 70 | to the coordinate of shtools so we can use the same coordinate 71 | --we use shtools to rotate the coordinate: 72 | we use shtools to rotate the object: 73 | we need to use x convention, 74 | alpha_x = -pi (contour clock-wise rotate along z by pi) 75 | beta_x = -pi/2 (contour clock-wise rotate along new x by pi/2) 76 | gamma_x = 0 77 | then y convention is: 78 | alpha_y = alpha_x - pi/2 = 0 79 | beta_y = beta_x = -pi/2 80 | gamma_y = gamma_x + pi/2 = pi/2 81 | reference: https://shtools.oca.eu/shtools/pyshrotaterealcoef.html 82 | ''' 83 | new_lighting = np.zeros(lighting.shape) 84 | n = lighting.shape[0] 85 | for i in range(n): 86 | shMatrix = shtools_sh2matrix(lighting[i,:], self.SH_DEGREE) 87 | # rotate coordinate 88 | shMatrix = SHRotateRealCoef(shMatrix, np.array([0, -np.pi/2, np.pi/2]), self.dj) 89 | # rotate object 90 | #shMatrix = SHRotateRealCoef(shMatrix, np.array([-np.pi/2, np.pi/2, -np.pi/2]), self.dj) 91 | new_lighting[i,:] = shtools_matrix2vec(shMatrix) 92 | return new_lighting 93 | 94 | def sfs2shtools(self, lighting): 95 | ''' 96 | convert sfs SH to shtools 97 | --we use shtools to rotate the coordinate: 98 | we use shtools to rotate the object: 99 | 100 | we need to use x convention, 101 | we use shtools to rotate the coordinate: 102 | we need to use x convention, 103 | alpha_x = pi/2 (clock-wise rotate along z axis by pi/2) 104 | beta_x = -pi/2 (contour clock-wise rotate along new x by pi/2) 105 | gamma_x = 0 106 | then y convention is: 107 | alpha_y = alpha_x - pi/2 = 0 108 | beta_y = beta_x = -pi/2 109 | gamma_y = gamma_x + pi/2 = pi/2 110 | reference: https://shtools.oca.eu/shtools/pyshrotaterealcoef.html 111 | ''' 112 | new_lighting = np.zeros(lighting.shape) 113 | n = lighting.shape[0] 114 | for i in range(n): 115 | shMatrix = shtools_sh2matrix(lighting[i,:], self.SH_DEGREE) 116 | # rotate coordinate 117 | shMatrix = SHRotateRealCoef(shMatrix, np.array([0, -np.pi/2, np.pi/2]), self.dj) 118 | # rotate object 119 | #shMatrix = SHRotateRealCoef(shMatrix, np.array([np.pi/2, -np.pi/2, 0]), self.dj) 120 | new_lighting[i,:] = shtools_matrix2vec(shMatrix) 121 | return new_lighting 122 | -------------------------------------------------------------------------------- /utils/utils_shtools.py: -------------------------------------------------------------------------------- 1 | ''' 2 | define some helper functions for shtools 3 | ''' 4 | import pyshtools 5 | from pyshtools.expand import MakeGridDH 6 | import numpy as np 7 | 8 | # This code is taken directly from https://github.com/zhhoper/DPR 9 | 10 | def shtools_matrix2vec(SH_matrix): 11 | ''' 12 | for the sh matrix created by sh tools, 13 | we create the vector of the sh 14 | ''' 15 | numOrder = SH_matrix.shape[1] 16 | vec_SH = np.zeros(numOrder**2) 17 | count = 0 18 | for i in range(numOrder): 19 | for j in range(i,0,-1): 20 | vec_SH[count] = SH_matrix[1,i,j] 21 | count 
= count + 1 22 | for j in range(0,i+1): 23 | vec_SH[count]= SH_matrix[0, i,j] 24 | count = count + 1 25 | return vec_SH 26 | 27 | def shtools_sh2matrix(coefficients, degree): 28 | ''' 29 | convert vector of sh to matrix 30 | ''' 31 | coeffs_matrix = np.zeros((2, degree + 1, degree + 1)) 32 | current_zero_index = 0 33 | for l in range(0, degree + 1): 34 | coeffs_matrix[0, l, 0] = coefficients[current_zero_index] 35 | for m in range(1, l + 1): 36 | coeffs_matrix[0, l, m] = coefficients[current_zero_index + m] 37 | coeffs_matrix[1, l, m] = coefficients[current_zero_index - m] 38 | current_zero_index += 2*(l+1) 39 | return coeffs_matrix 40 | 41 | def shtools_getSH(envMap, order=5): 42 | ''' 43 | get SH based on the envmap 44 | ''' 45 | SH = pyshtools.expand.SHExpandDH(envMap, sampling=2, lmax_calc=order, norm=4) 46 | return SH 47 | --------------------------------------------------------------------------------
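
For quick reference, the sketch below (illustrative only, not part of the repository) shows how the SH utilities above fit together outside of `test_network.py`: it rebuilds the half-sphere normal grid, loads one of the provided `data/test/light/rotate_light_*.txt` files, and renders a lighting preview with `utils_SH.get_shading`, mirroring `render_half_sphere()`. Running from the repository root and the output name `sphere_preview.png` are assumptions made for illustration.

```python
# Minimal sketch: render an SH lighting vector on a half sphere using
# the helpers in utils/utils_SH.py (mirrors render_half_sphere in test_network.py).
import sys
sys.path.append('utils')

import numpy as np
import cv2
from utils_SH import get_shading

# build normals for the visible half sphere on a 256x256 grid
img_size = 256
x = np.linspace(-1, 1, img_size)
z = np.linspace(1, -1, img_size)
x, z = np.meshgrid(x, z)
mag = np.sqrt(x**2 + z**2)
valid = mag <= 1                                   # mask of the unit disk
y = -np.sqrt(1 - (x * valid)**2 - (z * valid)**2)
normal = np.stack((x * valid, y * valid, z * valid), axis=2).reshape(-1, 3)

# load a 9-dim SH lighting vector and evaluate the shading per normal
sh = np.loadtxt('data/test/light/rotate_light_00.txt')[0:9] * 0.5
shading = get_shading(normal, sh)

# clip highlights, normalize to [0, 255], and mask out the background
value = np.percentile(shading, 95)
shading = np.clip(shading, None, value)
shading = (shading - shading.min()) / (shading.max() - shading.min())
shading = (shading * 255.0).astype(np.uint8).reshape(256, 256) * valid
cv2.imwrite('sphere_preview.png', shading)         # assumed output name
```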
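A small sanity check on the SH vector layout used throughout (`1, Y, Z, X, YX, YZ, 3Z^2-1, XZ, X^2-Y^2`): `shtools_sh2matrix` and `shtools_matrix2vec` in `utils/utils_shtools.py` invert each other for a degree-2 vector, which is what the lighting-rotation code in `utils/utils_normal.py` relies on. This is an illustrative sketch, not repository code; note that importing `utils_shtools` requires `pyshtools`, which is not listed in `requirements.txt`.

```python
# Illustrative round-trip: 9-dim SH vector -> SHTOOLS coefficient matrix -> vector.
import sys
sys.path.append('utils')

import numpy as np
from utils_shtools import shtools_sh2matrix, shtools_matrix2vec

sh_vec = np.arange(9, dtype=np.float64)      # any degree-2 SH lighting vector
sh_mat = shtools_sh2matrix(sh_vec, 2)        # shape (2, 3, 3)
assert np.allclose(shtools_matrix2vec(sh_mat), sh_vec)
```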