├── 1_cor.png
├── cc_model.h5
├── GT_Illum_Mat
│   ├── Canon600D_gt.mat
│   └── real_illum_568.mat
├── Model
│   └── Shi-Gehler
│       ├── cc_model.h5
│       └── cc_model.json
├── __pycache__
│   ├── local_2_global.cpython-37.pyc
│   ├── gamma_linearization.cpython-37.pyc
│   ├── create_local_illum_map.cpython-37.pyc
│   └── linear_nonlinear_transform.cpython-37.pyc
├── calculate_error.py
├── progress_timer.py
├── linear_nonlinear_transform.py
├── error_measurement.py
├── create_local_illum_map.py
├── white_balancing.py
├── patch_sampling.py
├── gamma_linearization.py
├── local_2_global.py
├── CNN_keras.py
├── color_space_conversion.py
├── cc_model.json
├── FCN.py
├── README.md
├── generate_data.py
└── prepare_dataset.py

/1_cor.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WhuEven/CNN_model_ColorConstancy/HEAD/1_cor.png
--------------------------------------------------------------------------------
/cc_model.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WhuEven/CNN_model_ColorConstancy/HEAD/cc_model.h5
--------------------------------------------------------------------------------
/GT_Illum_Mat/Canon600D_gt.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WhuEven/CNN_model_ColorConstancy/HEAD/GT_Illum_Mat/Canon600D_gt.mat
--------------------------------------------------------------------------------
/Model/Shi-Gehler/cc_model.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WhuEven/CNN_model_ColorConstancy/HEAD/Model/Shi-Gehler/cc_model.h5
--------------------------------------------------------------------------------
/GT_Illum_Mat/real_illum_568.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WhuEven/CNN_model_ColorConstancy/HEAD/GT_Illum_Mat/real_illum_568.mat
--------------------------------------------------------------------------------
/__pycache__/local_2_global.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WhuEven/CNN_model_ColorConstancy/HEAD/__pycache__/local_2_global.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/gamma_linearization.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WhuEven/CNN_model_ColorConstancy/HEAD/__pycache__/gamma_linearization.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/create_local_illum_map.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WhuEven/CNN_model_ColorConstancy/HEAD/__pycache__/create_local_illum_map.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/linear_nonlinear_transform.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WhuEven/CNN_model_ColorConstancy/HEAD/__pycache__/linear_nonlinear_transform.cpython-37.pyc
--------------------------------------------------------------------------------
/calculate_error.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Wed May 2 11:44:06 2018
4 | 
5 | @author: phamh
6 | """
7 | 
8 | import numpy as np;
9 | import cv2;
10 | 
11 | from error_measurement import angular_error;
12 | from glob import glob;
13 | from white_balancing import white_balancing;
14 | from progress_timer import progress_timer;
15 | 
16 | casted_list = glob('Color-casted\\*.png');
17 | corrected_list = glob('Corrected\\*.png');
18 | gt_list = glob('Ground-truth\\*.png');
19 | ang_error = [];
20 | patch_size = (64, 64);
21 | 
22 | pt = progress_timer(n_iter = len(casted_list), description = 'Processing :');
23 | 
24 | for i in range (0, len(casted_list)):
25 | 
26 |     img = cv2.imread(casted_list[i]);
27 |     image_name = casted_list[i].replace('Color-casted\\', '');
28 |     image_name = image_name.replace('.png', '');
29 | 
30 |     img_gt = cv2.imread(gt_list[i]);
31 | 
32 |     img_cor = white_balancing(img, image_name, patch_size);
33 |     cv2.imwrite('Corrected\\' + image_name + '_cor.png', img_cor);
34 | 
35 |     error = angular_error(img_gt, img_cor, 'mean');
36 |     ang_error.append(error);
37 | 
38 |     pt.update();
39 | 
40 | avg_ang_error = np.mean(ang_error);
41 | pt.finish();
--------------------------------------------------------------------------------
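The script above reports only the mean angular error over the test folder. Color constancy papers usually quote the mean and the median (the median is more robust when a few scenes fail badly); a small helper along those lines, as a sketch that could be appended to `calculate_error.py`:

```python
import numpy as np

def summarize_errors(errors):
    # mean and median of the per-image angular errors (in degrees)
    errors = np.asarray(errors, dtype=np.float64)
    return {'mean': float(np.mean(errors)), 'median': float(np.median(errors))}

# e.g. print(summarize_errors(ang_error)) after the loop above
```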
/progress_timer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Thu Apr 12 10:00:05 2018
4 | 
5 | @author: phamh
6 | """
7 | #import libraries
8 | import progressbar as pb
9 | 
10 | #define progress timer class
11 | class progress_timer:
12 | 
13 |     def __init__(self, n_iter, description="Something"):
14 |         self.n_iter = n_iter
15 |         self.iter = 0
16 |         self.description = description + ': '
17 |         self.timer = None
18 |         self.initialize()
19 | 
20 |     def initialize(self):
21 |         #initialize timer
22 |         widgets = [self.description, pb.Percentage(), ' ',
23 |                    pb.Bar('=', '[', ']'), ' ', pb.ETA()]
24 |         self.timer = pb.ProgressBar(widgets=widgets, maxval=self.n_iter).start()
25 | 
26 |     def update(self, q=1):
27 |         #update timer
28 |         self.timer.update(self.iter)
29 |         self.iter += q
30 | 
31 |     def finish(self):
32 |         #end timer
33 |         self.timer.finish()
34 | 
35 | 
36 | # =============================================================================
37 | # #initialize
38 | # pt = progress_timer(description= 'For loop example', n_iter=1000000)
39 | # #for loop example
40 | # for i in range(0,1000000):
41 | #     #update
42 | #     pt.update()
43 | # #finish
44 | # pt.finish()
45 | # =============================================================================
46 | 
47 | 
--------------------------------------------------------------------------------
/linear_nonlinear_transform.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Mon Apr 9 09:59:00 2018
4 | 
5 | @author: phamh
6 | """
7 | import numpy as np;
8 | import cv2;
9 | 
10 | def linear_transform(channel_sRGB, gamma, f, s, T):
11 |     m, n = channel_sRGB.shape;
12 |     channel_sRGB_norm = channel_sRGB/255; channel_linear = channel_sRGB_norm.copy();  #work on a float copy so the (possibly uint8) input is not modified in place
13 |     for i in range(0, m):
14 |         for j in range(0, n):
15 |             if (T < channel_sRGB_norm[i][j] < 1):
16 |                 channel_linear[i][j] = ((channel_sRGB_norm[i][j] + f)/(1 + f))**(1/gamma);
17 |             elif (0 <= channel_sRGB_norm[i][j] <= T):
18 |                 channel_linear[i][j] = channel_sRGB_norm[i][j]/s;
19 | 
20 |     channel_linear = 255*channel_linear;
21 |     channel_linear[channel_linear > 255] = 255;
22 |     return channel_linear;
23 | 
24 | def non_linear_transform(channel, gamma, f, s, t):
25 |     m, n = channel.shape;
26 |     channel_norm = channel/255; channel_gamma_cor = channel_norm.copy();  #work on a float copy so the (possibly uint8) input is not modified in place
27 |     for i in range(0, m):
28 |         for j in range(0, n):
29 |             if (t < channel_norm[i][j] < 1):
30 |                 channel_gamma_cor[i][j] = (1 + f)*(channel_norm[i][j]**gamma) - f;
31 |             elif (0 <= channel_norm[i][j] <= t):
32 |                 channel_gamma_cor[i][j] = s*channel_norm[i][j];
33 | 
34 |     channel_gamma_cor = channel_gamma_cor*255;
35 |     channel_gamma_cor[channel_gamma_cor > 255] = 255;
36 |     return channel_gamma_cor;
37 | 
--------------------------------------------------------------------------------
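The per-pixel double loops above are very slow on full-resolution images. A vectorized sketch of the same piecewise sRGB linearization (same `gamma`, `f`, `s`, `T` parameters; a sketch, not a tested drop-in replacement for the function above):

```python
import numpy as np

def linear_transform_vectorized(channel_sRGB, gamma, f, s, T):
    # normalize to [0, 1] and apply the piecewise sRGB linearization in one shot
    v = channel_sRGB.astype(np.float64) / 255.0
    linear = np.where(v > T, ((v + f) / (1 + f)) ** (1 / gamma), v / s)
    return np.clip(255.0 * linear, 0, 255)
```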
/error_measurement.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Mon Apr 9 10:23:49 2018
4 | 
5 | @author: phamh
6 | """
7 | import numpy as np;
8 | import cv2;
9 | from color_space_conversion import rgb2lab;
10 | 
11 | def angular_error(ground_truth_image, corrected_image, measurement_type):
12 |     B_gt, G_gt, R_gt = cv2.split(ground_truth_image);
13 |     B_cor, G_cor, R_cor = cv2.split(corrected_image);
14 | 
15 |     if measurement_type == 'mean':
16 |         e_gt = np.array([np.mean(B_gt), np.mean(G_gt), np.mean(R_gt)]);
17 |         e_est = np.array([np.mean(B_cor), np.mean(G_cor), np.mean(R_cor)]);
18 | 
19 |     elif measurement_type == 'median':
20 |         e_gt = np.array([np.median(B_gt), np.median(G_gt), np.median(R_gt)]);
21 |         e_est = np.array([np.median(B_cor), np.median(G_cor), np.median(R_cor)]);
22 | 
23 |     error_cos = np.dot(e_gt, e_est)/(np.linalg.norm(e_gt)*np.linalg.norm(e_est));
24 |     e_angular = np.degrees(np.arccos(error_cos));
25 |     return e_angular;
26 | 
27 | def Euclidean_distance(ground_truth_image, corrected_image):
28 |     B_gt, G_gt, R_gt = cv2.split(ground_truth_image);
29 |     B_cor, G_cor, R_cor = cv2.split(corrected_image);
30 | 
31 |     L_gt, a_gt, b_gt = rgb2lab(B_gt, G_gt, R_gt, 'AdobeRGB', 'D65');
32 |     L_cor, a_cor, b_cor = rgb2lab(B_cor, G_cor, R_cor, 'AdobeRGB', 'D65');
33 | 
34 |     delta = np.sqrt(np.square(L_gt - L_cor) + np.square(a_gt - a_cor) + np.square(b_gt - b_cor));
35 |     average_Euclidean_dist = np.mean(delta);
36 | 
37 |     return average_Euclidean_dist;
38 | 
--------------------------------------------------------------------------------
/create_local_illum_map.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Thu Apr 26 14:52:45 2018
4 | 
5 | @author: phamh
6 | """
7 | 
8 | import numpy as np;
9 | import h5py;
10 | import cv2;
11 | 
12 | from glob import glob;
13 | from tensorflow.python.keras.models import model_from_json;
14 | 
15 | def create_local_illum_map(image, patch_size):
16 | 
17 |     #load cc_model
18 |     json_file = open('cc_model.json', 'r');
19 |     loaded_model_json = json_file.read();
20 |     json_file.close();
21 |     cc_model = model_from_json(loaded_model_json);
22 | 
23 |     #load weights into cc_model
24 |     cc_model.load_weights("cc_model.h5");
25 |     cc_model.compile(optimizer = 'Adam', loss = 'cosine_proximity', metrics = ['acc']);
26 | 
27 |     n_r, n_c, _ = image.shape;
28 |     patch_r, patch_c = patch_size;
29 | 
30 |     total_patch = int(((n_r - n_r%patch_r)/patch_r)*((n_c - n_c%patch_c)/patch_c));
31 | 
32 |     img_resize = cv2.resize(image, ((n_r - n_r%patch_r), (n_c - n_c%patch_c)));
33 |     img_reshape = np.reshape(img_resize, (int(patch_r), -1, 3));
34 | 
35 |     illum_estimate = [];
36 | 
37 |     for i in range(total_patch):
38 | 
39 |         img_patch = img_reshape[0:patch_r, i*patch_c:(i+1)*patch_c];
40 | 
41 |         patch_cvt = cv2.cvtColor(img_patch, cv2.COLOR_BGR2RGB);
42 |         patch_cvt = np.expand_dims(patch_cvt, axis=0);
43 |         illum_estimate.append(cc_model.predict(patch_cvt));
44 | 
45 |     illum_estimate = np.asarray(illum_estimate);
46 |     illum_map = np.reshape(illum_estimate, (int((n_r - n_r%patch_r)/patch_r), int((n_c - n_c%patch_c)/patch_c), 3));
47 | 
48 |     return illum_map;
49 | 
50 | 
--------------------------------------------------------------------------------
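A caveat on the patch logic above (the same pattern also appears in `patch_sampling.py` and `generate_data.py`): `cv2.resize` expects its size argument as `(width, height)`, and the `reshape` to `(patch_r, -1, 3)` packs the image into a single row-band rather than preserving the 2-D patch grid, so the slices are only true spatial patches under specific shape assumptions. An explicit crop-and-slice sketch of the same patch grid (a hypothetical helper, not the author's code):

```python
import numpy as np

def iter_patches(image, patch_size):
    # crop to a multiple of the patch size, then walk the grid row by row
    patch_r, patch_c = patch_size
    n_r = image.shape[0] - image.shape[0] % patch_r
    n_c = image.shape[1] - image.shape[1] % patch_c
    cropped = image[:n_r, :n_c]
    for r in range(0, n_r, patch_r):
        for c in range(0, n_c, patch_c):
            yield cropped[r:r + patch_r, c:c + patch_c]
```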
/white_balancing.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Thu Apr 26 18:28:36 2018
4 | 
5 | @author: phamh
6 | """
7 | 
8 | import numpy as np
9 | import cv2
10 | from gamma_linearization import gamma_decode, gamma_encode
11 | from local_2_global import local_2_global
12 | 
13 | #the image name corresponds to the patches sub-folder
14 | #e.g. image 0015 corresponds to the patches/0015 folder
15 | 
16 | def white_balancing(img, image_name, patch_size):
17 | 
18 |     m, n, _ = img.shape
19 |     B_channel, G_channel, R_channel = cv2.split(img)
20 |     Color_space = 'AdobeRGB'
21 | 
22 |     #Undo Gamma Correction
23 |     B_channel, G_channel, R_channel = gamma_decode(B_channel, G_channel, R_channel, Color_space)
24 | 
25 |     #Compute local illuminant map then aggregate to global illuminant
26 |     illum_global = local_2_global(image_name, img, patch_size)
27 | 
28 |     #Per-channel gains relative to the strongest component of the global illuminant estimate
29 |     alpha = np.max(illum_global)/illum_global[0, 2]
30 |     beta = np.max(illum_global)/illum_global[0, 1]
31 |     ceta = np.max(illum_global)/illum_global[0, 0]
32 | 
33 |     #Corrected Image
34 |     B_cor = alpha*B_channel
35 |     G_cor = beta*G_channel
36 |     R_cor = ceta*R_channel
37 | 
38 |     #Gamma correction to display
39 |     B_cor, G_cor, R_cor = gamma_encode(B_cor, G_cor, R_cor, Color_space)
40 |     B_cor[B_cor > 255] = 255
41 |     G_cor[G_cor > 255] = 255
42 |     R_cor[R_cor > 255] = 255
43 | 
44 |     #Convert to uint8 to display
45 |     B_cor = B_cor.astype(np.uint8)
46 |     G_cor = G_cor.astype(np.uint8)
47 |     R_cor = R_cor.astype(np.uint8)
48 |     img_cor = cv2.merge((B_cor, G_cor, R_cor))
49 | 
50 |     return img_cor
51 | 
52 | #demo guarded so that importing this module (e.g. from calculate_error.py) does not run it
53 | if __name__ == '__main__':
54 |     img = cv2.imread('E:\\e_data\\1.tif')
55 |     image_name = '1'
56 |     patch_size = (32, 32)
57 |     img_white_balance = white_balancing(img, image_name, patch_size)
58 |     cv2.imwrite(image_name + '_cor.png', img_white_balance)
--------------------------------------------------------------------------------
/patch_sampling.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Spyder Editor
4 | 
5 | This is a temporary script file.
6 | """ 7 | 8 | import numpy as np; 9 | import cv2; 10 | from glob import glob; 11 | import os; 12 | from progress_timer import progress_timer; 13 | 14 | def patch_sampling(img, patch_size, index): 15 | 16 | n_r, n_c, _ = np.shape(img); 17 | patch_r, patch_c = patch_size; 18 | total_patch = int(((n_r - n_r%patch_r)/patch_r)*((n_c - n_c%patch_c)/patch_c)); 19 | 20 | img_resize = cv2.resize(img, ((n_r - n_r%patch_r), (n_c - n_c%patch_c))); 21 | img_reshape = np.reshape(img_resize, (int(patch_r), int(patch_c*total_patch), 3)); 22 | 23 | #Create CLAHE object 24 | clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8,8)); 25 | 26 | for i in range(total_patch): 27 | img_patch = img_reshape[0:patch_r, i*patch_c:(i+1)*patch_c]; 28 | 29 | #Convert image to Lab to perform contrast normalizing 30 | lab= cv2.cvtColor(img_patch, cv2.COLOR_BGR2LAB); 31 | l, a, b = cv2.split(lab); 32 | cl = clahe.apply(l); 33 | clab = cv2.merge((cl, a, b)); 34 | 35 | img_patch = cv2.cvtColor(clab, cv2.COLOR_LAB2BGR); 36 | count = "%04d" % index; 37 | path_name = 'patches\\' + count + "\\%04d_" % index + "%04d.png" % (i+1); 38 | 39 | if not os.path.exists('patches\\' + count): 40 | os.makedirs('patches\\' + count); 41 | 42 | cv2.imwrite(path_name, img_patch); 43 | 44 | return 0; 45 | 46 | path = r'C:\Users\phamh\Workspace\Dataset\Shi_Gehler\Converted'; 47 | img_list = glob(path + "\*.png"); 48 | patch_size = (32, 32); 49 | 50 | pt = progress_timer(n_iter = len(img_list), description = 'Preparing Patches :'); 51 | 52 | for i in range(len(img_list)): 53 | index = i + 1; 54 | img = cv2.imread(img_list[i], 1); 55 | a = patch_sampling(img, patch_size, index); 56 | 57 | pt.update(); 58 | 59 | pt.finish(); -------------------------------------------------------------------------------- /gamma_linearization.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Apr 9 09:58:22 2018 4 | 5 | @author: phamh 6 | """ 7 | import numpy as np; 8 | from linear_nonlinear_transform import linear_transform, non_linear_transform; 9 | 10 | def gamma_decode(B_gamma, G_gamma, R_gamma, decoding_type): 11 | B_gamma = B_gamma/255; G_gamma = G_gamma/255; R_gamma = R_gamma/255; 12 | if decoding_type == 'AdobeRGB': 13 | #Adobe RGB 14 | gamma = 1/2.2; 15 | B_gamma_decode = 255*(B_gamma**(1/gamma)); 16 | G_gamma_decode = 255*(G_gamma**(1/gamma)); 17 | R_gamma_decode = 255*(R_gamma**(1/gamma)); 18 | return (B_gamma_decode, G_gamma_decode, R_gamma_decode); 19 | elif decoding_type == 'sRGB': 20 | #sRGB 21 | gamma = 1/2.4; 22 | f = 0.055; s = 12.92; T = 0.04045; 23 | B_gamma_inverse = linear_transform(B_gamma, gamma, f, s, T); 24 | G_gamma_inverse = linear_transform(G_gamma, gamma, f, s, T); 25 | R_gamma_inverse = linear_transform(R_gamma, gamma, f, s, T); 26 | return (B_gamma_inverse, G_gamma_inverse, R_gamma_inverse); 27 | 28 | def gamma_encode(B_channel, G_channel, R_channel, encoding_type): 29 | B_channel = B_channel/255; G_channel = G_channel/255; R_channel = R_channel/255; 30 | #Non linear encoding 31 | if encoding_type == 'AdobeRGB': 32 | #Adobe RGB 33 | gamma = 1/2.2; 34 | if np.all(B_channel <= 0): 35 | B_gamma_cor = (B_channel**(gamma + 0j)); 36 | B_gamma_cor = 255*(abs(B_gamma_cor)); 37 | else: 38 | B_gamma_cor = 255*(B_channel**gamma); 39 | 40 | if np.all(G_channel <= 0): 41 | G_gamma_cor = (G_channel**(gamma + 0j)); 42 | G_gamma_cor = 255*(abs(G_gamma_cor)); 43 | else: 44 | G_gamma_cor = 255*(G_channel**gamma); 45 | 46 | if np.all(R_channel <= 0): 47 | 
R_gamma_cor = (R_channel**(gamma + 0j)); 48 | R_gamma_cor = 255*(abs(R_gamma_cor)); 49 | else: 50 | R_gamma_cor = 255*(R_channel**gamma); 51 | 52 | return (B_gamma_cor, G_gamma_cor, R_gamma_cor); 53 | elif encoding_type == 'sRGB': 54 | #sRGB 55 | gamma = 1/2.4; 56 | f = 0.055; s = 12.92; t = 0.0031308; 57 | B_gamma_cor = non_linear_transform(B_channel, gamma, f, s, t); 58 | G_gamma_cor = non_linear_transform(G_channel, gamma, f, s, t); 59 | R_gamma_cor = non_linear_transform(R_channel, gamma, f, s, t); 60 | return (B_gamma_cor, G_gamma_cor, R_gamma_cor); 61 | -------------------------------------------------------------------------------- /local_2_global.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Apr 26 11:34:51 2018 4 | 5 | @author: phamh 6 | """ 7 | 8 | import numpy as np 9 | import cv2 10 | import os 11 | 12 | from create_local_illum_map import create_local_illum_map 13 | #from sklearn.svm import SVR 14 | 15 | 16 | def pool_function(illum_map, pool_size, stride, mode): 17 | 18 | n_h, n_w, n_c = illum_map.shape 19 | pool_h, pool_w = pool_size 20 | stride_h, stride_w = stride 21 | 22 | # Define the dimensions of the output 23 | h_out = int(1 + (n_h - pool_h) / stride_h) 24 | w_out = int(1 + (n_w - pool_w) / stride_w) 25 | c_out = n_c 26 | 27 | pool_out = np.zeros((h_out, w_out, c_out)) 28 | 29 | for h in range(h_out): 30 | for w in range(w_out): 31 | for c in range (c_out): 32 | 33 | vert_start = h*stride_h 34 | vert_end = h*stride_h + pool_h 35 | horiz_start = w*stride_w 36 | horiz_end = w*stride_w + pool_w 37 | 38 | slice_window = illum_map[vert_start:vert_end, horiz_start:horiz_end, c] 39 | 40 | if mode == "std": 41 | pool_out[h, w, c] = np.std(slice_window) 42 | elif mode == "average": 43 | pool_out[h, w, c] = np.mean(slice_window) 44 | 45 | return pool_out 46 | 47 | def local_2_global(image_name, image, patch_size): 48 | 49 | n_r, n_c, _ = image.shape 50 | 51 | illum_map = create_local_illum_map(image, patch_size) #patch_size = (32, 32) 52 | 53 | #Gaussian Smooting with 5x5 kernel 54 | illum_map_smoothed = cv2.GaussianBlur(illum_map, (5,5), cv2.BORDER_DEFAULT) 55 | 56 | #Perform median pooling 57 | median_pooling = np.zeros((1, 3)) 58 | median_pooling[0, 0] = np.median(illum_map_smoothed[:, :, 0]) 59 | median_pooling[0, 1] = np.median(illum_map_smoothed[:, :, 1]) 60 | median_pooling[0, 2] = np.median(illum_map_smoothed[:, :, 2]) 61 | 62 | #Perform avarage pooling 63 | n_r, n_c, _ = illum_map.shape 64 | pool_size = (int(n_r/3), int(n_c/3)) 65 | stride = (int(n_r/3), int(n_c/3)) 66 | average_pooling = pool_function(illum_map, pool_size, stride, mode = 'average') 67 | 68 | average_pooling[0, 0] = np.mean(average_pooling[:, :, 0]) 69 | average_pooling[0, 1] = np.mean(average_pooling[:, :, 1]) 70 | average_pooling[0, 2] = np.mean(average_pooling[:, :, 2]) 71 | 72 | #reshape_average = np.reshape(average_pooling, (-1, )); 73 | #svr_rbf = SVR(kernel='rbf', C=1e3, gamma=0.1) 74 | #illum_global = svr_rbf.fit(np.reshape(median_pooling, (3, 1)), reshape_average).predict(median_pooling); 75 | 76 | #Switch to average pooling if you prefer: 77 | #illum_global = average_pooling; 78 | illum_global = median_pooling 79 | 80 | return illum_global 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /CNN_keras.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Apr 25 
14:27:20 2018 4 | 5 | @author: phamh 6 | """ 7 | import h5py 8 | import numpy as np 9 | import cv2; 10 | import tensorflow as tf; 11 | 12 | from tensorflow.python.keras import layers, optimizers 13 | from tensorflow.python.keras.layers import Input, Dense, Activation, ZeroPadding2D, BatchNormalization, Flatten, Conv2D 14 | from tensorflow.python.keras.layers import AveragePooling2D, MaxPooling2D, Dropout, GlobalMaxPooling2D, GlobalAveragePooling2D 15 | from tensorflow.python.keras.models import Model 16 | from tensorflow.python.keras.preprocessing import image 17 | from tensorflow.python.keras.initializers import glorot_uniform 18 | from tensorflow.python.keras import backend as K 19 | from generate_data import generate_train_data, generate_test_data; 20 | 21 | 22 | 23 | def ColorNet(input_shape, channels = 3): 24 | 25 | # Define the input as a tensor with shape input_shape 26 | X_input = Input(input_shape) 27 | 28 | # Stage 1 29 | X = Conv2D(240, (1, 1), strides = (1, 1), name = 'conv1', kernel_initializer = glorot_uniform(seed=0))(X_input) 30 | X = BatchNormalization(name = 'bn_conv1')(X) 31 | 32 | X = MaxPooling2D((8, 8), strides=(8, 8))(X) 33 | 34 | # output layer 35 | X = Flatten()(X) 36 | X = Dense(40, activation='relu', name='fc' + str(40))(X); 37 | X = Dropout(rate = 0.5)(X); 38 | 39 | X = Dense(channels, activation=None, name='fc' + str(channels))(X); 40 | 41 | # Create model 42 | color_model = Model(inputs = X_input, outputs = X, name='ColorNet'); 43 | 44 | return color_model; 45 | 46 | #X_train = np.load('X_train.npy'); Y_train = np.load('Y_train.npy'); 47 | #X_test = np.load('X_test.npy'); Y_test = np.load('Y_test.npy'); 48 | 49 | X_train, Y_train, _ = generate_train_data(train_size = 9000, set_name = 'Shi-Gehler', patch_size = (64, 64)); 50 | X_test, Y_test, _ = generate_test_data(test_size = 2420, set_name = 'Shi-Gehler', patch_size = (64, 64)); 51 | 52 | rmsprop = optimizers.RMSprop(lr = 0.001, rho=0.9, epsilon=None, decay=0.0); 53 | 54 | cc_model = ColorNet(input_shape = X_train.shape[1:4]); 55 | cc_model.compile(optimizer = 'Adam', loss = 'cosine_proximity', metrics = ['acc']); 56 | 57 | estimate = cc_model.fit(X_train, Y_train, validation_split = 0.3333, epochs = 20, batch_size = 160); 58 | 59 | preds = cc_model.evaluate(X_test, Y_test); 60 | print(); 61 | print ("Loss = " + str(preds[0])); 62 | print ("Test Accuracy = " + str(preds[1])); 63 | 64 | # serialize model to JSON 65 | model_json = cc_model.to_json() 66 | with open("cc_model.json", "w") as json_file: 67 | json_file.write(model_json) 68 | # serialize weights to HDF5 69 | cc_model.save_weights("cc_model.h5") 70 | print("Saved model to disk") 71 | 72 | 73 | # ============================================================================= 74 | # img1 = cv2.imread('0044_0002.png'); 75 | # img1 = np.expand_dims(img1, axis=0); 76 | # a = cc_model.predict(img1); 77 | # 78 | # img2 = cv2.imread('0015_0015.png'); 79 | # img2 = np.expand_dims(img2, axis=0); 80 | # b = cc_model.predict(img2); 81 | # ============================================================================= 82 | -------------------------------------------------------------------------------- /color_space_conversion.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Apr 9 14:30:04 2018 4 | 5 | @author: phamh 6 | """ 7 | import numpy as np; 8 | 9 | def rgb2xyz(B_channel, G_channel, R_channel, color_space): 10 | B_channel = B_channel/255; G_channel = G_channel/255; 
R_channel = R_channel/255;
11 |     if color_space == 'AdobeRGB':
12 |         X = 0.5767*R_channel + 0.1856*G_channel + 0.1882*B_channel;
13 |         Y = 0.2974*R_channel + 0.6273*G_channel + 0.0753*B_channel;
14 |         Z = 0.0270*R_channel + 0.0707*G_channel + 0.9911*B_channel;
15 |         return (X, Y, Z);
16 |     elif color_space == 'sRGB':
17 |         X = 0.4124*R_channel + 0.3576*G_channel + 0.1805*B_channel;
18 |         Y = 0.2126*R_channel + 0.7152*G_channel + 0.0722*B_channel;
19 |         Z = 0.0193*R_channel + 0.1192*G_channel + 0.9505*B_channel;
20 |         return (X, Y, Z);
21 | 
22 | def xyz2rgb(X, Y, Z, color_space):
23 |     if color_space == 'AdobeRGB':
24 |         B_channel = 0.0134*X + -0.1184*Y + 1.0154*Z;
25 |         G_channel = -0.9693*X + 1.8760*Y + 0.0416*Z;
26 |         R_channel = 2.0414*X + -0.5649*Y + -0.3447*Z;
27 |         B_channel = B_channel*255; G_channel = G_channel*255; R_channel = R_channel*255;
28 |         return (B_channel, G_channel, R_channel);
29 |     elif color_space == 'sRGB':
30 |         B_channel = 0.0556*X + -0.2040*Y + 1.0572*Z;
31 |         G_channel = -0.9693*X + 1.8760*Y + 0.0416*Z;
32 |         R_channel = 3.2405*X + -1.5371*Y + -0.4985*Z;
33 |         B_channel = B_channel*255; G_channel = G_channel*255; R_channel = R_channel*255;
34 |         return (B_channel, G_channel, R_channel);
35 | 
36 | def f_function(i, i_ref):
37 |     r = i/i_ref;
38 |     m, n = r.shape;
39 |     f = np.zeros((m, n));
40 |     x, y = np.where(r > 0.008856);
41 |     f[x, y] = r[x, y]**(1/3);  #cube root; note ** binds tighter than /, so the parentheses are required
42 |     x, y = np.where(r <= 0.008856);
43 |     f[x, y] = (7.787*r[x, y] + 16/116);
44 |     return f;
45 | 
46 | def xyz2lab(X, Y, Z, illuminant):
47 |     if illuminant == 'D65':
48 |         Xn = 95.047;
49 |         Yn = 100;
50 |         Zn = 108.883;
51 |     elif illuminant == 'D50':
52 |         Xn = 96.6797;
53 |         Yn = 100;
54 |         Zn = 82.5188;
55 | 
56 |     L = 116*f_function(Y, Yn) - 16;
57 |     a = 500*f_function(X, Xn) + -500*f_function(Y, Yn);
58 |     b = 200*f_function(Y, Yn) + -200*f_function(Z, Zn);
59 | 
60 |     return (L, a, b);
61 | 
62 | def lab2xyz(L, a, b, illuminant):
63 |     if illuminant == 'D65':
64 |         Xn = 95.047;
65 |         Yn = 100;
66 |         Zn = 108.883;
67 |     elif illuminant == 'D50':
68 |         Xn = 96.6797;
69 |         Yn = 100;
70 |         Zn = 82.5188;
71 | 
72 |     if L > 7.9996:
73 |         X = Xn*((L/116 + a/500 + 16/116)**3);
74 |         Y = Yn*((L/116 + 16/116)**3);
75 |         Z = Zn*((L/116 - b/200 + 16/116)**3);
76 |     elif L <= 7.9996:
77 |         X = Xn*(1/7.787)*(L/116 + a/500);
78 |         Y = Yn*(1/7.787)*(L/116);
79 |         Z = Zn*(1/7.787)*(L/116 - b/200);
80 | 
81 |     return (X, Y, Z);
82 | 
83 | def rgb2lab(B_channel, G_channel, R_channel, color_space, illuminant):
84 |     X, Y, Z = rgb2xyz(B_channel, G_channel, R_channel, color_space);
85 |     L, a, b = xyz2lab(X, Y, Z, illuminant);
86 |     return (L, a, b);
87 | 
88 | def lab2rgb(L, a, b, color_space, illuminant):
89 |     X, Y, Z = lab2xyz(L, a, b, illuminant);
90 |     B_channel, G_channel, R_channel = xyz2rgb(X, Y, Z, color_space);
91 |     return (B_channel, G_channel, R_channel);
--------------------------------------------------------------------------------
/cc_model.json:
--------------------------------------------------------------------------------
1 | {"class_name": "Model", "config": {"name": "ColorNet", "layers": [{"name": "input_1", "class_name": "InputLayer", "config": {"batch_input_shape": [null, 32, 32, 3], "dtype": "float32", "sparse": false, "name": "input_1"}, "inbound_nodes": []}, {"name": "conv1", "class_name": "Conv2D", "config": {"name": "conv1", "trainable": true, "dtype": "float32", "filters": 240, "kernel_size": [1, 1], "strides": [1, 1], "padding": "valid", "data_format": "channels_last", "dilation_rate": [1, 1], "activation": "linear", "use_bias": true, "kernel_initializer": 
{"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": 0, "dtype": "float32"}}, "bias_initializer": {"class_name": "Zeros", "config": {"dtype": "float32"}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["input_1", 0, 0, {}]]]}, {"name": "bn_conv1", "class_name": "BatchNormalization", "config": {"name": "bn_conv1", "trainable": true, "dtype": "float32", "axis": [3], "momentum": 0.99, "epsilon": 0.001, "center": true, "scale": true, "beta_initializer": {"class_name": "Zeros", "config": {"dtype": "float32"}}, "gamma_initializer": {"class_name": "Ones", "config": {"dtype": "float32"}}, "moving_mean_initializer": {"class_name": "Zeros", "config": {"dtype": "float32"}}, "moving_variance_initializer": {"class_name": "Ones", "config": {"dtype": "float32"}}, "beta_regularizer": null, "gamma_regularizer": null, "beta_constraint": null, "gamma_constraint": null}, "inbound_nodes": [[["conv1", 0, 0, {}]]]}, {"name": "max_pooling2d_1", "class_name": "MaxPooling2D", "config": {"name": "max_pooling2d_1", "trainable": true, "dtype": "float32", "pool_size": [8, 8], "padding": "valid", "strides": [8, 8], "data_format": "channels_last"}, "inbound_nodes": [[["bn_conv1", 0, 0, {}]]]}, {"name": "flatten_1", "class_name": "Flatten", "config": {"name": "flatten_1", "trainable": true, "dtype": "float32"}, "inbound_nodes": [[["max_pooling2d_1", 0, 0, {}]]]}, {"name": "fc40", "class_name": "Dense", "config": {"name": "fc40", "trainable": true, "dtype": "float32", "units": 40, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null, "dtype": "float32"}}, "bias_initializer": {"class_name": "Zeros", "config": {"dtype": "float32"}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["flatten_1", 0, 0, {}]]]}, {"name": "dropout_1", "class_name": "Dropout", "config": {"name": "dropout_1", "trainable": true, "dtype": "float32", "rate": 0.5, "noise_shape": null, "seed": null}, "inbound_nodes": [[["fc40", 0, 0, {}]]]}, {"name": "fc3", "class_name": "Dense", "config": {"name": "fc3", "trainable": true, "dtype": "float32", "units": 3, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null, "dtype": "float32"}}, "bias_initializer": {"class_name": "Zeros", "config": {"dtype": "float32"}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["dropout_1", 0, 0, {}]]]}], "input_layers": [["input_1", 0, 0]], "output_layers": [["fc3", 0, 0]]}, "keras_version": "2.1.4-tf", "backend": "tensorflow"} -------------------------------------------------------------------------------- /Model/Shi-Gehler/cc_model.json: -------------------------------------------------------------------------------- 1 | {"class_name": "Model", "config": {"name": "ColorNet", "layers": [{"name": "input_1", "class_name": "InputLayer", "config": {"batch_input_shape": [null, 32, 32, 3], "dtype": "float32", "sparse": false, "name": "input_1"}, "inbound_nodes": []}, {"name": "conv1", "class_name": "Conv2D", "config": {"name": "conv1", "trainable": 
true, "dtype": "float32", "filters": 240, "kernel_size": [1, 1], "strides": [1, 1], "padding": "valid", "data_format": "channels_last", "dilation_rate": [1, 1], "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": 0, "dtype": "float32"}}, "bias_initializer": {"class_name": "Zeros", "config": {"dtype": "float32"}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["input_1", 0, 0, {}]]]}, {"name": "bn_conv1", "class_name": "BatchNormalization", "config": {"name": "bn_conv1", "trainable": true, "dtype": "float32", "axis": [3], "momentum": 0.99, "epsilon": 0.001, "center": true, "scale": true, "beta_initializer": {"class_name": "Zeros", "config": {"dtype": "float32"}}, "gamma_initializer": {"class_name": "Ones", "config": {"dtype": "float32"}}, "moving_mean_initializer": {"class_name": "Zeros", "config": {"dtype": "float32"}}, "moving_variance_initializer": {"class_name": "Ones", "config": {"dtype": "float32"}}, "beta_regularizer": null, "gamma_regularizer": null, "beta_constraint": null, "gamma_constraint": null}, "inbound_nodes": [[["conv1", 0, 0, {}]]]}, {"name": "max_pooling2d_1", "class_name": "MaxPooling2D", "config": {"name": "max_pooling2d_1", "trainable": true, "dtype": "float32", "pool_size": [8, 8], "padding": "valid", "strides": [8, 8], "data_format": "channels_last"}, "inbound_nodes": [[["bn_conv1", 0, 0, {}]]]}, {"name": "flatten_1", "class_name": "Flatten", "config": {"name": "flatten_1", "trainable": true, "dtype": "float32"}, "inbound_nodes": [[["max_pooling2d_1", 0, 0, {}]]]}, {"name": "fc40", "class_name": "Dense", "config": {"name": "fc40", "trainable": true, "dtype": "float32", "units": 40, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null, "dtype": "float32"}}, "bias_initializer": {"class_name": "Zeros", "config": {"dtype": "float32"}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["flatten_1", 0, 0, {}]]]}, {"name": "dropout_1", "class_name": "Dropout", "config": {"name": "dropout_1", "trainable": true, "dtype": "float32", "rate": 0.5, "noise_shape": null, "seed": null}, "inbound_nodes": [[["fc40", 0, 0, {}]]]}, {"name": "fc3", "class_name": "Dense", "config": {"name": "fc3", "trainable": true, "dtype": "float32", "units": 3, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "VarianceScaling", "config": {"scale": 1.0, "mode": "fan_avg", "distribution": "uniform", "seed": null, "dtype": "float32"}}, "bias_initializer": {"class_name": "Zeros", "config": {"dtype": "float32"}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}, "inbound_nodes": [[["dropout_1", 0, 0, {}]]]}], "input_layers": [["input_1", 0, 0]], "output_layers": [["fc3", 0, 0]]}, "keras_version": "2.1.4-tf", "backend": "tensorflow"} -------------------------------------------------------------------------------- /FCN.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Jun 4 13:36:15 2018 4 | 5 | @author: phamh 6 | """ 7 | 8 | import numpy as np; 
9 | import h5py; 10 | import tensorflow as tf; 11 | 12 | from tensorflow.python.keras.applications.vgg16 import VGG16 13 | 14 | from tensorflow.python.keras.layers import Input, Conv2D, Dropout, Activation, MaxPooling2D, Dense, Lambda; 15 | from tensorflow.python.keras import backend as K 16 | from SqueezeNet_simple_bypass import SqueezeNet, extract_layer_from_model; 17 | from tensorflow.python.keras.models import Model, model_from_json, save_model, load_model; 18 | from tensorflow.python.keras.optimizers import Adam; 19 | from generate_data import generate_data; 20 | 21 | SEPERATE_CONFIDENCE = False; 22 | 23 | def FCN(input_shape): 24 | 25 | 26 | vgg16_model = VGG16(weights = 'imagenet', include_top = False, input_shape = input_shape); 27 | 28 | #Sq_net = squeezenet(float(input_shape)); 29 | fire8 = extract_layer_from_model(vgg16_model, layer_name = 'block4_pool'); 30 | 31 | pool8 = MaxPooling2D((3,3), strides = (2,2), name = 'pool8')(fire8.output); 32 | 33 | fc1 = Conv2D(64, (6,6), strides= (1, 1), padding = 'same', name = 'fc1')(pool8); 34 | 35 | fc1 = Dropout(rate = 0.5)(fc1); 36 | 37 | 38 | if SEPERATE_CONFIDENCE: 39 | fc2 = Conv2D(4 , (1, 1), strides = (1, 1), padding = 'same', activation = 'relu', name = 'fc2')(fc1); 40 | rgb = K.l2_normalize(fc2[:, :, :, 0:3], axis = 3); 41 | w, h = map(int, fc2.get_shape()[1:3]); 42 | 43 | confidence = fc2[:, :, :, 3:4]; 44 | confidence = np.reshape(confidence, [-1, w*h]); 45 | confidence = K.softmax(confidence); 46 | confidence = np.reshape(confidence, shape=[-1, w, h, 1]); 47 | 48 | fc2 = rgb * confidence; 49 | 50 | else: 51 | fc2 = Conv2D(3, (1, 1), strides = (1, 1), padding = 'same', name = 'fc2')(fc1); 52 | 53 | fc2 = Activation('relu')(fc2); 54 | 55 | fc2 = Conv2D(3, (15, 15), padding = 'valid', name = 'fc_pooling')(fc2); 56 | 57 | 58 | def norm(fc2): 59 | 60 | fc2_norm = K.l2_normalize(fc2, axis = 3); 61 | illum_est = K.tf.reduce_sum(fc2_norm, axis = (1, 2)); 62 | illum_est = K.l2_normalize(illum_est); 63 | 64 | return illum_est; 65 | 66 | #illum_est = Dense(3)(fc2); 67 | 68 | illum_est = Lambda(norm)(fc2); 69 | 70 | 71 | FCN_model = Model(inputs = vgg16_model.input, outputs = illum_est, name = 'FC4'); 72 | 73 | return FCN_model; 74 | 75 | ''' 76 | squeezenet_file = open('squeezenet.json', 'r'); 77 | loaded_squeezenet_json = squeezenet_file.read(); 78 | squeezenet_file.close(); 79 | squeezenet = model_from_json(loaded_squeezenet_json); 80 | 81 | #load weights into squeezenet 82 | squeezenet.load_weights("squeezenet.h5"); 83 | squeezenet.compile(optimizer = 'Adam', loss = 'categorical_crossentropy', metrics = ['acc']); 84 | ''' 85 | train_size = 500; 86 | test_size = 500; 87 | patch_size = (512, 512); 88 | set_name = 'Shi-Gehler'; 89 | 90 | 91 | X_train_origin, Y_train_origin, X_train_norm, Y_train_norm, _, = generate_data(train_size, set_name, patch_size, job = 'Train'); 92 | 93 | #X_test, Y_test, _, = generate_data(test_size, set_name, patch_size, job = 'Test'); 94 | 95 | fcn_model = FCN(input_shape = X_train_origin.shape[1:4]); 96 | 97 | fcn_model.compile(optimizer = Adam(lr = 0.0002), loss = 'cosine_proximity', metrics = ['acc']); 98 | fcn_model.fit(X_train_origin, Y_train_origin, validation_split = 0.2, epochs = 20, batch_size = 16); 99 | 100 | save_model(fcn_model, 'fcn_model.h5'); 101 | 102 | 103 | 104 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Color Constancy(CC) using CNN 2 | 3 | Solving the 
color constancy problem with a CNN model proposed by [Bianco et al.](https://arxiv.org/pdf/1504.04548.pdf)
4 | 
5 | ## Getting Started
6 | 
7 | 
8 | ### Prerequisites
9 | 
10 | * You will need the OpenCV library, TensorFlow and Keras in order to run the code.
11 | 
12 | * The dataset I used for the CC problem is Shi and Gehler's; you can find more information and the download link [here](http://www.cs.sfu.ca/~colour/data/shi_gehler/).
13 | 
14 | * [Here](https://drive.google.com/file/d/1aHx7v-VGREQfyemsGqB8-1Ccfz4lrzKr/view?usp=sharing) is a smaller set from Shi-Gehler for you to use with the CNN model.
15 | 
16 | * If you prefer the fastest way to play with the model, you can skip directly to the section **"Running the tests"**, where you can use my pre-trained CNN model to see how it works on the CC problem.
17 | 
18 | * If you have downloaded the Shi-Gehler dataset, you will need to process its HDR images; the Python file for this job is `prepare_dataset.py`. At the end of the file, replace the `path` variable with the directory of the **Shi_Gehler folder**, for example `path = ...//Dataset//Shi_Gehler`. Inside the **Shi_Gehler folder** you will need to create subfolders that look like [this](https://imgur.com/a/tIfyEMp), and inside the **gt folder** you put the ground-truth illuminant files so that it looks like [this](https://imgur.com/a/CJ1ELtP). After that, you are ready to go.
19 | 
20 | ### Establishing Dataset
21 | 
22 | * The idea of the algorithm is to train the CNN on patches sampled from the images.
23 | 
24 | * First, you can create your own train and test set by running the script
25 | 
26 | ```
27 | generate_data.py
28 | ```
29 | You need the 'color-casted' images and the corresponding ground-truth illuminant matrix. We work with the Shi-Gehler dataset, so these 'color-casted' images are the processed versions of the original [HDR images](http://www.cs.sfu.ca/~colour/data/shi_gehler/), with the corresponding ground-truth illuminants [here](http://www.cs.sfu.ca/~colour/data/shi_gehler/groundtruth_568.zip). If you have already processed the HDR images, you should have about 500+ 'color-casted' images. You will need to divide these images into two parts, one for training and one for testing; the division ratio is your own choice, and I chose 2/3 for training and 1/3 for testing.
30 | 
31 | Then, inside `generate_data.py`, within the function `generate_train_data`, replace the `path` variable with the directory of your train set, for example `path = 'C:\\Users\\...\\Shi_Gehler\\Train_set\\'`, and do the same in the function `generate_test_data` with the directory of your test set. Also, in the line `illum_mat = scipy.io.loadmat('GT_Illum_Mat\\' + mat_name, squeeze_me = True, struct_as_record = False);`, replace `'GT_Illum_Mat'` with the directory where you put the file 'real_illum_568.mat'.
32 | 
33 | After completing all the steps above, you can generate train and test data of your own choice; I have given an example at the end of the `generate_data.py` file.
34 | 
35 | * You can change the size of your train (test) set and the number of train (test) ground-truth illuminants by simply changing these arguments:
36 | 
37 | ```
38 | train_size, test_size, number_of_train_gt, number_of_test_gt
39 | ```
40 | 
41 | * However, if you prefer a faster way, you can directly download the train and test set I have created [here](https://drive.google.com/file/d/1w-qfkDugvs1oUdob2_DI4uo26OuGUK-S/view?usp=sharing).
42 | These are .npy files; you can simply use **numpy.load** to load them:
43 | 
44 | ```
45 | X_train = np.load('X_train.npy');...
46 | ```
47 | 
48 | ### Training the model
49 | 
50 | After finishing the dataset preparation, you can start training the model with **CNN_keras.py**: simply load your train and test set and run it.
51 | 
52 | **Remark :** this is a CC problem, so I used the **"cosine_proximity"** loss function, as it is the closest to the **"[angular error](https://fr.mathworks.com/help/images/examples/comparison-of-auto-white-balance-algorithms.html)"** (see the worked example below).
53 | 
54 | 
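To make that connection concrete, here is a minimal sketch (with made-up illuminant values) showing how the cosine similarity that `cosine_proximity` drives maps to the angular error in degrees:

```python
import numpy as np

def angular_error_deg(e_gt, e_est):
    # cosine similarity between ground-truth and estimated illuminant vectors;
    # cosine_proximity minimizes its negative, so maximizing it minimizes the angle
    cos_sim = np.dot(e_gt, e_est) / (np.linalg.norm(e_gt) * np.linalg.norm(e_est))
    return np.degrees(np.arccos(np.clip(cos_sim, -1.0, 1.0)))

e_gt  = np.array([0.57, 0.81, 0.16])   # made-up ground-truth illuminant (RGB)
e_est = np.array([0.60, 0.79, 0.12])   # made-up network estimate
print(angular_error_deg(e_gt, e_est))  # about 3.1 degrees
```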
55 | ## Running the tests
56 | 
57 | * You can find my pre-trained model and weights in these files:
58 | 
59 | ```
60 | cc_model.h5 <--- weights
61 | cc_model.json <--- pre-trained model
62 | ```
63 | 
64 | * If you wish to start your own model, simply delete them and start from **"Establishing Dataset"** to build your own. Whichever you choose, you can test the models with the images downloaded from the smaller Shi-Gehler set I have provided above (**Prerequisites**).
65 | 
66 | * Now, you can easily test the model by running **white_balancing.py**, for example:
67 | 
68 | ```
69 | img = cv2.imread('0001.png')
70 | image_name = '0001'
71 | patch_size = (32, 32)
72 | img_white_balance = white_balancing(img, image_name, patch_size)
73 | ```
74 | 
75 | Remark: modify the argument **image_name** to the name of the image you want to test in the **Color-casted** folder
76 | (remember to put the image inside the same folder as **white_balancing.py**, otherwise you have to include the path in the **image_name** argument):
77 | 
78 | ```
79 | image_name = '0005'
80 | ```
81 | 
82 | if not in the same folder:
83 | 
84 | ```
85 | image_name = '.../yourfolder/0005'
86 | ```
87 | * After running the script, you can compare the results with the corresponding image in the **Ground-truth** folder; a small quantitative evaluation sketch is given right after this README.
88 | 
89 | ### Some notes
90 | 
91 | * There are many things that are hard to explain if you are not familiar with color science. For this reason, if you have questions, do not hesitate to contact me.
92 | 
93 | 
94 | ## Authors
95 | 
96 | * **Simone Bianco** - *Initial work* - [Color Constancy Using CNNs](https://arxiv.org/pdf/1504.04548.pdf)
97 | 
98 | * **Hien Pham** - *Re-implementation*
99 | 
100 | ## License
101 | 
102 | This project is under the license of Technicolor.
103 | 
104 | ## Acknowledgments
105 | 
106 | * This is an implementation of Simone Bianco's work. If you use this code for research purposes, please cite [Bianco's paper](https://arxiv.org/pdf/1504.04548.pdf) and my implementation in the references.
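For the quantitative comparison mentioned above, a minimal evaluation sketch using the repo's own `angular_error` and `white_balancing` functions (the file name `0005.png` and the folder layout are assumptions; adjust to your setup):

```python
import cv2
from error_measurement import angular_error
from white_balancing import white_balancing

# assumed layout: a casted image and its ground truth with the same name
img    = cv2.imread('Color-casted/0005.png')
img_gt = cv2.imread('Ground-truth/0005.png')

img_cor = white_balancing(img, '0005', (32, 32))
print('angular error (deg):', angular_error(img_gt, img_cor, 'mean'))
```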
107 | 108 | 109 | 110 | -------------------------------------------------------------------------------- /generate_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Apr 24 15:49:57 2018 4 | 5 | @author: phamh 6 | """ 7 | import numpy as np; 8 | import cv2; 9 | import scipy.io; 10 | from glob import glob; 11 | from random import randint; 12 | from progress_timer import progress_timer; 13 | 14 | def generate_train_data(train_size, set_name, patch_size): 15 | 16 | #Load ground truth illum value 17 | if (set_name == 'Shi-Gehler'): 18 | mat_name = 'real_illum_568.mat'; 19 | key = 'real_rgb'; 20 | path = 'C:\\Users\\phamh\\Workspace\\Dataset\\Shi_Gehler\\Train_set\\'; 21 | 22 | elif (set_name == 'Canon'): 23 | mat_name = 'Canon600D_gt.mat'; 24 | key = 'groundtruth_illuminants'; 25 | path = 'C:\\Users\\phamh\\Workspace\\Dataset\\Canon_600D\\Train_set\\'; 26 | 27 | illum_mat = scipy.io.loadmat('GT_Illum_Mat\\' + mat_name, squeeze_me = True, struct_as_record = False); 28 | ground_truth_illum = illum_mat[key]; 29 | 30 | flist = glob(path + '*.png'); 31 | number_of_train_gt = len(flist); 32 | 33 | pt = progress_timer(n_iter = number_of_train_gt, description = 'Generating Training Data :'); 34 | 35 | patches_per_image = int(train_size/number_of_train_gt); 36 | 37 | X_train_origin, Y_train_origin, name_train = [], [], []; 38 | i = 0; 39 | patch_r, patch_c = patch_size; 40 | 41 | while (i < number_of_train_gt): 42 | 43 | image_number = flist[i]; 44 | index = (image_number.replace(path ,'')).replace('.png', ''); 45 | 46 | image = cv2.imread(image_number); 47 | n_r, n_c, _ = np.shape(image); 48 | total_patch = int(((n_r - n_r%patch_r)/patch_r)*((n_c - n_c%patch_c)/patch_c)); 49 | 50 | img_resize = cv2.resize(image, ((n_r - n_r%patch_r), (n_c - n_c%patch_c))); 51 | img_reshape = np.reshape(img_resize, (int(patch_r), -1, 3)); 52 | 53 | #Create CLAHE object 54 | clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8)); 55 | 56 | for j in range (0, patches_per_image): 57 | 58 | rd = randint(0, total_patch - 1); 59 | img_patch = img_reshape[0:patch_r, rd*patch_c:(rd+1)*patch_c]; 60 | 61 | #Convert image to Lab to perform contrast normalizing 62 | lab= cv2.cvtColor(img_patch, cv2.COLOR_BGR2LAB); 63 | 64 | #Contrast normalizing(Stretching) 65 | l, a, b = cv2.split(lab); 66 | cl = clahe.apply(l); 67 | clab = cv2.merge((cl, a, b)); 68 | 69 | #Convert back to BGR 70 | img_patch = cv2.cvtColor(clab, cv2.COLOR_LAB2BGR); 71 | 72 | img_patch = cv2.cvtColor(img_patch, cv2.COLOR_BGR2RGB); 73 | 74 | X_train_origin.append(img_patch); 75 | Y_train_origin.append(ground_truth_illum[int(index) - 1]); 76 | 77 | name_train.append('%04d' % (int(index) - 1)); 78 | 79 | i += 1; 80 | 81 | pt.update(); 82 | 83 | X_train_origin = np.asarray(X_train_origin); 84 | Y_train_origin = np.asarray(Y_train_origin); 85 | 86 | X_train_origin = X_train_origin/255; 87 | max_Y = np.amax(Y_train_origin, 1); 88 | Y_train_origin[:, 0] = Y_train_origin[:, 0]/max_Y; 89 | Y_train_origin[:, 1] = Y_train_origin[:, 1]/max_Y; 90 | Y_train_origin[:, 2] = Y_train_origin[:, 2]/max_Y; 91 | 92 | seed = randint(1, 5000); 93 | np.random.seed(seed); 94 | X_train_origin = np.random.permutation(X_train_origin); 95 | 96 | np.random.seed(seed); 97 | Y_train_origin = np.random.permutation(Y_train_origin); 98 | 99 | pt.finish(); 100 | 101 | return X_train_origin, Y_train_origin, name_train; 102 | 103 | def generate_test_data(test_size, set_name, patch_size): 104 | 105 | #Load 
ground truth illum value 106 | if (set_name == 'Shi-Gehler'): 107 | mat_name = 'real_illum_568.mat'; 108 | key = 'real_rgb'; 109 | path = 'C:\\Users\\phamh\\Workspace\\Dataset\\Shi_Gehler\\Test_set\\'; 110 | 111 | elif (set_name == 'Canon'): 112 | mat_name = 'Canon600D_gt.mat'; 113 | key = 'groundtruth_illuminants'; 114 | path = 'C:\\Users\\phamh\\Workspace\\Dataset\\Canon_600D\\Test_set\\' 115 | 116 | illum_mat = scipy.io.loadmat('GT_Illum_Mat\\' + mat_name, squeeze_me = True, struct_as_record = False); 117 | ground_truth_illum = illum_mat[key]; 118 | 119 | flist = glob(path + '*.png'); 120 | 121 | number_of_test_gt = len(flist); 122 | 123 | pt = progress_timer(n_iter = number_of_test_gt, description = 'Generating Testing Data :'); 124 | 125 | patches_per_image = int(test_size/number_of_test_gt); 126 | X_test, Y_test, name_test = [], [], []; 127 | i = 0; 128 | 129 | patch_r, patch_c = patch_size; 130 | 131 | while (i < number_of_test_gt): 132 | 133 | image_number = flist[i]; 134 | index = (image_number.replace(path ,'')).replace('.png', ''); 135 | 136 | image = cv2.imread(image_number); 137 | n_r, n_c, _ = np.shape(image); 138 | total_patch = int(((n_r - n_r%patch_r)/patch_r)*((n_c - n_c%patch_c)/patch_c)); 139 | 140 | img_resize = cv2.resize(image, ((n_r - n_r%patch_r), (n_c - n_c%patch_c))); 141 | img_reshape = np.reshape(img_resize, (int(patch_r), -1, 3)); 142 | 143 | #Create CLAHE object 144 | clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8)); 145 | 146 | for j in range (0, patches_per_image): 147 | 148 | rd = randint(0, total_patch - 1); 149 | img_patch = img_reshape[0:patch_r, rd*patch_c:(rd+1)*patch_c]; 150 | 151 | #Convert image to Lab to perform contrast normalizing 152 | lab= cv2.cvtColor(img_patch, cv2.COLOR_BGR2LAB); 153 | 154 | #Contrast normalizing(Stretching) 155 | l, a, b = cv2.split(lab); 156 | cl = clahe.apply(l); 157 | clab = cv2.merge((cl, a, b)); 158 | 159 | #Convert back to BGR 160 | img_patch = cv2.cvtColor(clab, cv2.COLOR_LAB2BGR); 161 | 162 | img_patch = cv2.cvtColor(img_patch, cv2.COLOR_BGR2RGB); 163 | 164 | X_test.append(img_patch); 165 | Y_test.append(ground_truth_illum[int(index) - 1]); 166 | 167 | name_test.append('%04d' % (int(index) - 1)); 168 | 169 | i += 1; 170 | pt.update(); 171 | 172 | X_test = np.asarray(X_test); 173 | Y_test = np.asarray(Y_test); 174 | 175 | X_test = X_test/255; 176 | max_Y = np.amax(Y_test, 1); 177 | Y_test[:, 0] = Y_test[:, 0]/max_Y; 178 | Y_test[:, 1] = Y_test[:, 1]/max_Y; 179 | Y_test[:, 2] = Y_test[:, 2]/max_Y; 180 | 181 | seed = randint(1, 5000); 182 | np.random.seed(seed); 183 | X_test = np.random.permutation(X_test); 184 | 185 | np.random.seed(seed); 186 | Y_test = np.random.permutation(Y_test); 187 | 188 | pt.finish(); 189 | 190 | return X_test, Y_test, name_test; 191 | 192 | #train_size = 3000; 193 | #test_size = 2420; 194 | #patch_size = (32, 32); 195 | #set_name = 'Shi-Gehler'; 196 | # 197 | #X_train, Y_train, name_train = generate_train_data(train_size, set_name, patch_size); 198 | #np.save('X_train.npy', X_train); np.save('Y_train.npy', Y_train); np.save('name_train.npy', name_train); 199 | # 200 | #X_test, Y_test, name_test = generate_test_data(test_size, set_name, patch_size); 201 | #np.save('X_test.npy', X_test); np.save('Y_test.npy', Y_test); np.save('name_test.npy', name_test); -------------------------------------------------------------------------------- /prepare_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on 
Wed Apr 11 15:56:18 2018 4 | 5 | @author: phamh 6 | """ 7 | 8 | import numpy as np; 9 | import os; 10 | import cv2; 11 | from glob import glob; 12 | import scipy.io 13 | from progress_timer import progress_timer; 14 | 15 | def set_name(path): 16 | index_list = glob(path + '\INDEX\*.mat'); 17 | name, count = [], []; 18 | 19 | for i in range(len(index_list)): 20 | index_mat = (scipy.io.loadmat(index_list[i], squeeze_me=True, struct_as_record = False)); 21 | name.append(index_mat['index'].set_name); 22 | count.append((index_mat['index'].cc).size + (index_mat['index'].img).size); 23 | 24 | return (name, count); 25 | 26 | def check_set(path): 27 | index_list = glob(path + '\\real_illum\*.mat'); 28 | name, count = set_name(path); 29 | i, j = 0, 0; 30 | name_out, count_out = [], []; 31 | 32 | while i < len(index_list): 33 | if name[j] in index_list[i]: 34 | name_out.append(name[j]); 35 | count_out.append(count[j]); 36 | j += 1; i += 1; 37 | 38 | else: 39 | j += 1; 40 | 41 | return (name_out, count_out); 42 | 43 | 44 | def convert_to_8bit(arr, clip_percentile): 45 | arr = np.clip(arr * (255.0 / np.percentile(arr, 100 - clip_percentile, keepdims=True)), 0, 255) 46 | return arr.astype(np.uint8) 47 | 48 | def convert2png(path, bit, Dataset): 49 | if Dataset == 'Funt_et_al': 50 | (name, _) = set_name(path); 51 | name.remove('S0170'); name.remove('S0390'); name.remove('S0540'); name.remove('S1150'); 52 | name.remove('S0030'); name.remove('S0040'); name.remove('S0120'); name.remove('S0180'); 53 | name.remove('S0430'); name.remove('S0500'); name.remove('S1100'); 54 | path_hdr = path + '\hdr_cc'; 55 | flist = glob(path_hdr + '\*.hdr'); 56 | else: 57 | path_img = path +'\png1'; 58 | flist = glob(path_img + '\*.png'); 59 | 60 | pt = progress_timer(n_iter = len(flist), description = 'Preparing Color-casted Images :'); 61 | 62 | for i in range(len(flist)): 63 | 64 | img12 = cv2.imread(flist[i], cv2.IMREAD_UNCHANGED).astype(np.float32); 65 | if Dataset == 'Shi_Gehler': 66 | if i >= 86: 67 | img12 = np.maximum(0.,img12-129.); 68 | clip_percentile = 2.5; 69 | img12 = (img12/(2**bit-1))*100.0; 70 | image = convert_to_8bit(img12, clip_percentile); 71 | image = (cv2.cvtColor(image, cv2.COLOR_BGR2RGB)); 72 | gamma = 1/2.2; 73 | image = pow(image, gamma) * (255.0/pow(255,gamma)); 74 | img8 = np.array(image,dtype=np.uint8) 75 | image8 = (cv2.cvtColor(img8, cv2.COLOR_BGR2RGB)); 76 | 77 | elif Dataset == 'Funt_et_al': 78 | gamma_tone = 2.2; clip_percentile = 0.2; 79 | tonemap = cv2.createTonemap(gamma_tone); 80 | img12 = tonemap.process(img12); 81 | if np.any(np.isnan(img12)): 82 | img12[np.isnan(img12)] = 0; 83 | 84 | image8 = convert_to_8bit(img12, clip_percentile); 85 | 86 | elif Dataset == 'Canon_600D': 87 | img12 = np.maximum(0., img12 - 2048); 88 | clip_percentile = 2.5; 89 | img12 = (img12/(2**bit-1))*100.0 90 | image = convert_to_8bit(img12, clip_percentile); 91 | image = (cv2.cvtColor(image, cv2.COLOR_BGR2RGB)); 92 | gamma = 1/2.2; 93 | image = pow(image, gamma) * (255.0/pow(255,gamma)); 94 | img8 = np.array(image,dtype=np.uint8) 95 | image8 = (cv2.cvtColor(img8, cv2.COLOR_BGR2RGB)); 96 | 97 | if Dataset == 'Shi_Gehler' or Dataset == 'Canon_600D': 98 | save_dir = path_img.replace('png1', 'Converted\\') + "%04d.png" % (i+1); 99 | cv2.imwrite(save_dir, image8); 100 | 101 | elif Dataset == 'Funt_et_al': 102 | img_name = name[i] + "_%04d.png" % (i+1); 103 | save_dir = path_hdr.replace('hdr_cc', 'Converted\\') + img_name; 104 | cv2.imwrite(save_dir, image8); 105 | 106 | 107 | #cv2.imwrite(save_dir, image8); 108 | 
pt.update();
109 | 
110 | pt.finish();
111 | 
112 | def mat_illum(path):
113 |     index_list = glob(path + '\\real_illum\*.mat');
114 |     mat = [];
115 | 
116 |     for i in range(len(index_list)):
117 |         mat_load = scipy.io.loadmat(index_list[i], squeeze_me=True, struct_as_record = False);
118 |         mat.append(mat_load['real_illum'].real_illum_by_cc19_PNG);
119 | 
120 |     return mat;
121 | 
122 | def white_balance_image(img12, bit, Gain_R, Gain_G, Gain_B):
123 |     img12 = (img12/(2**bit - 1))*100.0; #12bit data
124 |     image = convert_to_8bit(img12, 2.5)
125 |     image = (cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
126 | 
127 |     image[:, :, 0] = np.minimum(image[:, :, 0] * Gain_R,255)
128 |     image[:, :, 1] = np.minimum(image[:, :, 1] * Gain_G,255)
129 |     image[:, :, 2] = np.minimum(image[:, :, 2] * Gain_B,255)
130 | 
131 |     gamma = 1/2.2
132 |     image = pow(image, gamma) * (255.0/pow(255,gamma))
133 |     image = np.array(image,dtype=np.uint8)
134 |     image8 = (cv2.cvtColor(image, cv2.COLOR_BGR2RGB));
135 | 
136 |     return image8;
137 | 
138 | 
139 | def WhiteBalanceDataSet(path, bit, Dataset):
140 | 
141 |     if Dataset == 'Shi_Gehler':
142 |         path_gt = path + '\gt';
143 |         mat = (scipy.io.loadmat(path_gt + '\\real_illum_568.mat', squeeze_me=True, struct_as_record = False));
144 |         path_img = path + '\png1';
145 |         flist = glob(path_img + '\*.png');
146 | 
147 |     elif Dataset == 'Funt_et_al':
148 |         mat = mat_illum(path);
149 |         name, _ = check_set(path);
150 |         path_hdr = path + '\hdr_cc';
151 |         flist = glob(path_hdr + '\*.hdr');
152 | 
153 |     elif Dataset == 'Canon_600D':
154 |         mat = (scipy.io.loadmat(path + '\\Canon600D_gt.mat', squeeze_me=True, struct_as_record = False));
155 |         path_img = path + '\png1';
156 |         flist = glob(path_img + '\*.png');
157 | 
158 |     pt = progress_timer(n_iter = len(flist), description = 'Preparing Groundtruth Images :');
159 | 
160 |     for i in range(0, len(flist)):
161 | 
162 |         if Dataset == 'Shi_Gehler':
163 |             img12 = cv2.imread(flist[i], cv2.IMREAD_UNCHANGED).astype(np.float32);  #read the current image
164 | 
165 |             if i >= 86:
166 |                 img12 = np.maximum(0.,img12-129.);
167 | 
168 |             Gain_R= float(np.max(mat['real_rgb'][i]))/float((mat['real_rgb'][i][0]))
169 |             Gain_G= float(np.max(mat['real_rgb'][i]))/float((mat['real_rgb'][i][1]))
170 |             Gain_B= float(np.max(mat['real_rgb'][i]))/float((mat['real_rgb'][i][2]))
171 | 
172 |             image8 = white_balance_image(img12, bit, Gain_R, Gain_G, Gain_B);
173 |             save_dir = path_img.replace('png1', 'GroundTruth_1') + "\%04d.png" % (i+1);
174 |             cv2.imwrite(save_dir, image8);
175 | 
176 | 
177 |         elif Dataset == 'Funt_et_al':
178 |             #k = 0;
179 |             gamma_tone = 2.2;
180 |             tonemap = cv2.createTonemap(gamma_tone);
181 | 
182 | 
183 |             img12 = cv2.imread(flist[i], cv2.IMREAD_UNCHANGED);
184 |             avg_R = np.mean(mat[i][:, 0]); avg_G = np.mean(mat[i][:, 1]); avg_B = np.mean(mat[i][:, 2]);
185 |             Gain_R = float(np.max((avg_R, avg_G, avg_B))/float(avg_R));
186 |             Gain_G = float(np.max((avg_R, avg_G, avg_B))/float(avg_G));
187 |             Gain_B = float(np.max((avg_R, avg_G, avg_B))/float(avg_B));
188 | 
189 |             img12[:, :, 0] = np.minimum(img12[:, :, 0] * Gain_B,255);
190 |             img12[:, :, 1] = np.minimum(img12[:, :, 1] * Gain_G,255);
191 |             img12[:, :, 2] = np.minimum(img12[:, :, 2] * Gain_R,255);
192 | 
193 |             image = tonemap.process(img12);
194 |             clip_percentile = 0.2;
195 | 
196 |             if np.any(np.isnan(image)):
197 |                 image[np.isnan(image)] = 0;
198 | 
199 |             image = np.clip(image * (255.0 / np.percentile(image, 100 - clip_percentile, keepdims=True)), 0, 255);
200 | 
201 |             img_name = name[i] + "_GT_%01d.png" % (i+1);
202 |             save_dir = path_hdr.replace('hdr_cc', 
'GroundTruth\\') + img_name; 203 | cv2.imwrite(save_dir, image); 204 | # ============================================================================= 205 | # while k < 4: 206 | # img12 = cv2.imread(flist[i], cv2.IMREAD_UNCHANGED); 207 | # Gain_R= float(np.max(mat[i][k]))/float((mat[i][k][0])); 208 | # Gain_G= float(np.max(mat[i][k]))/float((mat[i][k][1])); 209 | # Gain_B= float(np.max(mat[i][k]))/float((mat[i][k][2])); 210 | # 211 | # img12[:, :, 0] = np.minimum(img12[:, :, 0] * Gain_B,255); 212 | # img12[:, :, 1] = np.minimum(img12[:, :, 1] * Gain_G,255); 213 | # img12[:, :, 2] = np.minimum(img12[:, :, 2] * Gain_R,255); 214 | # 215 | # 216 | # image = tonemap.process(img12); 217 | # clip_percentile = 0.2; 218 | # 219 | # if np.any(np.isnan(image)): 220 | # image[np.isnan(image)] = 0; 221 | # 222 | # image = np.clip(image * (255.0 / np.percentile(image, 100 - clip_percentile, keepdims=True)), 0, 255); 223 | # 224 | # img_name = name[i] + "_GT_%01d.png" % (k+1); 225 | # save_dir = path_hdr.replace('hdr_cc', 'GroundTruth\\') + img_name; 226 | # cv2.imwrite(save_dir, image); 227 | # 228 | # k += 1; 229 | # 230 | # ============================================================================= 231 | elif Dataset == 'Canon_600D': 232 | img12 = cv2.imread(flist[i], cv2.IMREAD_UNCHANGED).astype(np.float32); 233 | 234 | img12 = np.maximum(0.,img12 - 2048.); 235 | 236 | Gain_R= float(np.max(mat['groundtruth_illuminants'][i]))/float((mat['groundtruth_illuminants'][i][0])) 237 | Gain_G= float(np.max(mat['groundtruth_illuminants'][i]))/float((mat['groundtruth_illuminants'][i][1])) 238 | Gain_B= float(np.max(mat['groundtruth_illuminants'][i]))/float((mat['groundtruth_illuminants'][i][2])) 239 | 240 | image8 = white_balance_image(img12, bit, Gain_R, Gain_G, Gain_B); 241 | save_dir = path_img.replace('png1', 'GroundTruth') + "\%04d.png" % (i+1); 242 | cv2.imwrite(save_dir, image8); 243 | i += 1; 244 | 245 | 246 | pt.update(); 247 | 248 | pt.finish(); 249 | 250 | 251 | Dataset = 'Shi_Gehler'; 252 | img_path = os.getcwd(); 253 | path = img_path.replace('ColorConstancy', 'Dataset\\' + Dataset); 254 | 255 | #convert2png(path, 14, Dataset); 256 | WhiteBalanceDataSet(path, 12, Dataset); --------------------------------------------------------------------------------