├── README.md
├── datapipline.py
├── evaluation.py
├── models.py
├── results
│   ├── enlighten00001.png
│   ├── enlighten10499.png
│   ├── low00001.png
│   └── low10499.png
├── test.py
├── train.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------

## RetinexNet Pytorch

This repository reproduces **Deep Retinex Decomposition for Low-Light Enhancement** as a PyTorch project.

I largely followed the settings in the authors' [code](https://github.com/weichen582/RetinexNet), which was written in TensorFlow.

I did this project for an interview and am no longer working on it; I publish the code in case it is helpful to others.

Please refer to the authors' code if mine confuses you.

#### Results

Before / after:

![low10499](./results/low10499.png)

![low10499](./results/enlighten10499.png)

#### Requirements

torch 1.0.0

PIL

#### Datasets

[Google Drive (including train and test)](https://drive.google.com/open?id=1-PqpEKjJxfAH0GmVwsPPQB-R3NqiPVCO)

#### Project Structure

```
- Desktop
  - Retinex_pytorch
    ......
  - final_dataset
    - trainA
    - trainB
  - test_dataset
    - testA
    - testB
    (- resultsA)  # this dir will be created during testing
```

#### Usage

Training:

```
python train.py
```

Testing:

```
python test.py
```

Evaluating (reports the PSNR score between testA and resultsA; run after testing):

```
python evaluation.py
```

#### Acknowledgements

[Authors' project website](https://daooshee.github.io/BMVC2018website/), [source paper](https://arxiv.org/pdf/1808.04560.pdf), [data pipeline code](https://github.com/TAMU-VITA/EnlightenGAN)
--------------------------------------------------------------------------------
/datapipline.py:
--------------------------------------------------------------------------------

import torch.utils.data as data
import torchvision.transforms as transforms

import os
import random
import torch

from PIL import Image

from utils import make_dataset


class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        pass


class PairDataset(BaseDataset):
    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot
        self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')
        self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')

        self.A_paths = make_dataset(self.dir_A)
        self.B_paths = make_dataset(self.dir_B)

        self.A_paths = sorted(self.A_paths)
        self.B_paths = sorted(self.B_paths)
        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)

        transform_list = []
        transform_list += [transforms.ToTensor(),
                           transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5))]

        self.transform = transforms.Compose(transform_list)

    def __getitem__(self, index):  # called by the DataLoader once per sample index
        A_path = self.A_paths[index % self.A_size]
        B_path = self.B_paths[index % self.B_size]

        A_img = Image.open(A_path).convert('RGB')
        # the paired normal-light image is located by name substitution on A_path
        B_img = Image.open(A_path.replace("low", "normal").replace("A", "B")).convert('RGB')

        A_img = self.transform(A_img)
        B_img = self.transform(B_img)

        if self.opt.resize_or_crop == 'no':
            r, g, b = A_img[0] + 1, A_img[1] + 1, A_img[2] + 1
            A_gray = 1. - (0.299 * r + 0.587 * g + 0.114 * b) / 2.
            A_gray = torch.unsqueeze(A_gray, 0)
            input_img = A_img
        else:
            if (not self.opt.no_flip) and random.random() < 0.5:  # random horizontal flip
                idx = [i for i in range(A_img.size(2) - 1, -1, -1)]
                idx = torch.LongTensor(idx)
                A_img = A_img.index_select(2, idx)
                B_img = B_img.index_select(2, idx)
            if (not self.opt.no_flip) and random.random() < 0.5:  # random vertical flip
                idx = [i for i in range(A_img.size(1) - 1, -1, -1)]
                idx = torch.LongTensor(idx)
                A_img = A_img.index_select(1, idx)
                B_img = B_img.index_select(1, idx)
            if (not self.opt.no_flip) and random.random() < 0.5:  # random darkening
                times = random.randint(self.opt.low_times, self.opt.high_times) / 100.
                input_img = (A_img + 1) / 2. / times
                input_img = input_img * 2 - 1
            else:
                input_img = A_img
            r, g, b = input_img[0] + 1, input_img[1] + 1, input_img[2] + 1
            A_gray = 1. - (0.299 * r + 0.587 * g + 0.114 * b) / 2.
            A_gray = torch.unsqueeze(A_gray, 0)
        return {'A': A_img, 'B': B_img, 'A_gray': A_gray, 'input_img': input_img,
                'A_paths': A_path, 'B_paths': B_path}

    def __len__(self):
        return self.A_size

    def name(self):
        return 'PairDataset'


class Structure():
    def __init__(self, opt):
        self.__dict__.update(opt)


def Get_paired_dataset(batch_size):
    data = PairDataset()
    opt = {"dataroot": "../final_dataset", "phase": "train",
           "resize_or_crop": False,  # never equals 'no', so __getitem__ takes the else branch
           "isTrain": True, "no_flip": True, "vary": 1, "lighten": True, "fineSize": 400}
    opt = Structure(opt)
    data.initialize(opt)
    dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size)
    return dataloader
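
A minimal usage sketch for the dataloader above (assuming the `../final_dataset/trainA` and `../final_dataset/trainB` layout described in the README already exists on disk):

```python
# Minimal sketch: pull one batch from the paired dataloader and inspect it.
from datapipline import Get_paired_dataset

loader = Get_paired_dataset(batch_size=1)
batch = next(iter(loader))
print(batch['A'].shape, batch['B'].shape)   # e.g. torch.Size([1, 3, H, W]) each
print(batch['A_paths'], batch['B_paths'])   # the file paths of the low/normal pair
```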
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------

from PIL import Image
import torch
import numpy as np
import os
import re


class Evaluation():
    """
    Evaluation over two directories of paired images; the pair index must be
    part of each image name.
    """
    def __init__(self, root_1, root_2):
        self.path = {}
        self._traverse_test_dir(root_1, self.path)
        self._traverse_test_dir(root_2, self.path)

    def calculate_psnr_score(self):
        score_list = []
        for item in self.path:
            img_1, img_2 = self.path[item]
            img_1 = np.array(Image.open(img_1))
            img_2 = np.array(Image.open(img_2))
            img_1 = torch.Tensor(img_1)
            img_2 = torch.Tensor(img_2)

            # per-channel MSE, averaged over R, G, B
            r = torch.nn.functional.mse_loss(img_1[:, :, 0], img_2[:, :, 0])
            g = torch.nn.functional.mse_loss(img_1[:, :, 1], img_2[:, :, 1])
            b = torch.nn.functional.mse_loss(img_1[:, :, 2], img_2[:, :, 2])

            t = (r + g + b) / 3.
            t = t.data.numpy()
            score_list.append(10. * np.log10(255. ** 2 / t))

        return score_list

    def _traverse_test_dir(self, root_dir, path):
        for (r, v, file_names) in os.walk(root_dir):
            for f in file_names:
                if f.endswith('.png') and not f.startswith("._"):
                    # if the index is new, create a list for it;
                    # otherwise append this path to the existing pair
                    idx = int(re.findall(string=f, pattern=r'\d+')[0])
                    img_path = os.path.join(r, f)
                    if idx not in path.keys():
                        path[idx] = [img_path]
                    else:
                        path[idx].append(img_path)


root_1 = "../test_dataset/testA/"
root_2 = "../test_dataset/resultsA/"
evaluator = Evaluation(root_1, root_2)
scores = evaluator.calculate_psnr_score()
print(np.mean(scores))
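
The score above is the standard PSNR, 10 * log10(MAX^2 / MSE) with MAX = 255 for 8-bit images. A quick self-contained check on synthetic arrays (hypothetical data, not from the dataset):

```python
# PSNR sanity check: every pixel differs by 10, so MSE = 100.
import numpy as np

a = np.zeros((4, 4), dtype=np.float64)
b = np.full((4, 4), 10.0)
mse = np.mean((a - b) ** 2)
psnr = 10. * np.log10(255. ** 2 / mse)
print(psnr)  # ~28.13 dB
```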
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------

import torch
import torch.nn as nn


class Decom_Net(nn.Module):
    def __init__(self, num_layers):
        super(Decom_Net, self).__init__()
        layers = []
        layers.append(nn.Conv2d(in_channels=4, out_channels=64, kernel_size=9, stride=1, padding=4))
        for i in range(num_layers):
            layers.append(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU(inplace=True))
        # only the first 4 of the 64 output channels are used below (R: 3, I: 1)
        layers.append(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1))
        layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        # concatenate the per-pixel max over RGB as a 4th input channel
        img = torch.cat((torch.max(input=img, dim=1, keepdim=True)[0], img), dim=1)
        output = self.model(img)
        # reflectance R (3 channels) and illumination I (1 channel)
        R, I = output[:, :3, :, :], output[:, 3:4, :, :]
        return R, I


class Enhance_Net(nn.Module):
    def __init__(self):
        super(Enhance_Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=4, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.relu1 = nn.ReLU(True)

        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.relu2 = nn.ReLU(True)

        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.relu3 = nn.ReLU(True)

        self.up_conv3 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.up_relu3 = nn.ReLU(True)

        self.up_conv2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.up_relu2 = nn.ReLU(True)

        self.up_conv1 = nn.ConvTranspose2d(in_channels=128, out_channels=1, kernel_size=4, stride=2, padding=1)
        self.activation = nn.ReLU(True)

        # fusion input: x (1) + c1 (128) + c2 (128) + c3 (64) = 321 channels
        self.fusion_conv = nn.Conv2d(in_channels=321, out_channels=64, kernel_size=1, stride=1, padding=0)
        self.final_conv = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1)
        self.final_activation = nn.Sigmoid()

    def forward(self, R, I):
        # Enhance_Net takes two inputs: reflectance R and illumination I
        I = torch.cat((R, I), dim=1)
        h1 = self.conv1(I)
        x = self.relu1(h1)

        h2 = self.conv2(x)
        x = self.relu2(h2)

        h3 = self.conv3(x)

        h2_ = self.up_conv3(h3)
        h2_ = torch.cat((h2, h2_), dim=1)
        x = self.up_relu3(h2_)

        h1_ = self.up_conv2(x)
        h1_ = torch.cat((h1, h1_), dim=1)
        x = self.up_relu2(h1_)

        x = self.up_conv1(x)
        x = self.activation(x)

        # upsample the skip features back to input resolution and fuse
        c1 = nn.UpsamplingNearest2d(scale_factor=2)(h1_)
        c2 = nn.UpsamplingNearest2d(scale_factor=4)(h2_)
        c3 = nn.UpsamplingNearest2d(scale_factor=8)(h3)

        x = torch.cat([x, c1, c2, c3], dim=1)

        x = self.fusion_conv(x)
        x = self.final_conv(x)
        # x = self.final_activation(x)

        return x
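
A minimal shape sanity check for the two networks, on CPU with random input (`num_layers=5` matches the setting in train.py; the spatial size must be divisible by 8 because of the three stride-2 stages):

```python
# Minimal sketch: verify tensor shapes through Decom_Net and Enhance_Net.
import torch
from models import Decom_Net, Enhance_Net

dec = Decom_Net(num_layers=5)
enh = Enhance_Net()
S = torch.rand(1, 3, 64, 64)   # stands in for a normalized input image
R, I = dec(S)
print(R.shape, I.shape)        # torch.Size([1, 3, 64, 64]) torch.Size([1, 1, 64, 64])
I_hat = enh(R, I)
print(I_hat.shape)             # torch.Size([1, 1, 64, 64])
```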
--------------------------------------------------------------------------------
/results/enlighten00001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/houze-liu/RetinexNet_pytorch/4f7f631e31c149b7a8ccf694cb607f2fc7d97682/results/enlighten00001.png
--------------------------------------------------------------------------------
/results/enlighten10499.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/houze-liu/RetinexNet_pytorch/4f7f631e31c149b7a8ccf694cb607f2fc7d97682/results/enlighten10499.png
--------------------------------------------------------------------------------
/results/low00001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/houze-liu/RetinexNet_pytorch/4f7f631e31c149b7a8ccf694cb607f2fc7d97682/results/low00001.png
--------------------------------------------------------------------------------
/results/low10499.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/houze-liu/RetinexNet_pytorch/4f7f631e31c149b7a8ccf694cb607f2fc7d97682/results/low10499.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------

import torch
import os
from PIL import Image
from torchvision import transforms

from models import Decom_Net, Enhance_Net


def normalize_img(img):
    # from (-1, 1) to (0, 1)
    return (img + 1.) / 2.


def denormalize_img(img):
    # from (0, 1) back to (-1, 1)
    return (img - 0.5) * 2.


def to_numpy(tensor):
    return tensor.detach().cpu().numpy()


def process_img(img):
    # image transformation: PIL image -> (1, 3, H, W) CUDA tensor in (0, 1)
    img = transforms.ToTensor()(img)
    img = torch.unsqueeze(img, dim=0).cuda()
    return img


def _test(model_name, save_dir, test_data_dir):
    checkpoint = torch.load(model_name)
    Dec.load_state_dict(checkpoint["Dec_model"])
    Enh.load_state_dict(checkpoint["Enh_model"])

    for root, _, img_paths in os.walk(test_data_dir):
        for img_path in img_paths:
            with torch.cuda.device(0):
                img = Image.open(os.path.join(root, img_path)).convert('RGB')
                img = process_img(img)
                img = normalize_img(img)
                # decompose
                R, I = Dec(img)
                # enhance
                I_hat = Enh(R, I)
                # enlighten: recombine reflectance with the enhanced illumination
                S_hat = R.mul(I_hat)
                S_hat = denormalize_img(to_numpy(S_hat))
                # transform the image array from (-1, 1) back to a PIL Image
                S_hat = transforms.ToPILImage()(torch.Tensor(S_hat[0]))
                if not os.path.isdir(save_dir):
                    os.makedirs(save_dir)
                S_hat.save(os.path.join(save_dir, img_path))


if __name__ == "__main__":
    num_layer = 5
    model_name = "./checkpoints/model80.tar"  # change to the desired checkpoint
    test_data_dir = "../test_dataset/testA"
    save_dir = "../test_dataset/resultsA/"
    # init networks
    Dec = torch.nn.DataParallel(Decom_Net(num_layer).cuda(), device_ids=range(torch.cuda.device_count()))
    Enh = torch.nn.DataParallel(Enhance_Net().cuda(), device_ids=range(torch.cuda.device_count()))
    # test
    _test(model_name, save_dir, test_data_dir)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------

from models import Decom_Net, Enhance_Net
import torch
import os
nn = torch.nn

# ---------------------------------- Define Network and Init ----------------------------------
epoch = 200
lr = 1e-4
nums_layer = 5
load_from_check_point = False  # set False to train from scratch; otherwise set an iteration number to resume training
Dec = Decom_Net(nums_layer).cuda()
Enh = Enhance_Net().cuda()
Dec = torch.nn.DataParallel(Dec, device_ids=range(torch.cuda.device_count()))
Enh = torch.nn.DataParallel(Enh, device_ids=range(torch.cuda.device_count()))

opt_Dec = torch.optim.Adam(Dec.parameters(), lr=lr)
opt_Enh = torch.optim.Adam(Enh.parameters(), lr=lr)


def load_check_point(param):
    if not param:
        return
    else:
        model_name = "./checkpoints/model{}.tar".format(param)
        checkpoint = torch.load(model_name)
        Dec.load_state_dict(checkpoint["Dec_model"])
        Enh.load_state_dict(checkpoint["Enh_model"])

load_check_point(load_from_check_point)


def reconst_loss(x, y):
    # L1 reconstruction loss
    return torch.mean(torch.abs(x - y))


def normalize_img(img):
    # from (-1, 1) to (0, 1)
    return (img + 1.) / 2.

# ---------------------------------------- Training ----------------------------------------
def train(epoch):
    from datapipline import Get_paired_dataset
    dataset = Get_paired_dataset(1)
    Dec.train()
    Enh.train()
    flag = True  # alternate between updating Decom_Net and Enhance_Net
    for e in range(epoch):
        # train one epoch
        for data in dataset:
            # get paired data
            with torch.cuda.device(0):
                S_low = data['A'].cuda()
                S_normal = data['B'].cuda()
                S_low = normalize_img(S_low)
                S_normal = normalize_img(S_normal)
                # decompose stage
                R_low, I_low = Dec(S_low)
                R_normal, I_normal = Dec(S_normal)

                # enhance stage
                I_low_hat = Enh(R_low, I_low)

                # --------------------------------- Loss Functions ---------------------------------
                # Decom_Net loss: L_reconst + L_invariable_reflectance
                loss_reconst_dec = reconst_loss(S_low, R_low.mul(I_low)) \
                    + reconst_loss(S_normal, R_normal.mul(I_normal)) \
                    + 0.001 * reconst_loss(S_low, R_normal.mul(I_low)) \
                    + 0.001 * reconst_loss(S_normal, R_low.mul(I_normal))
                loss_ivref = 0.01 * reconst_loss(R_low, R_normal)
                loss_dec = loss_reconst_dec + loss_ivref

                def get_smooth(I, direction):
                    # finite-difference gradient via a fixed 2x2 kernel
                    weights = torch.tensor([[0., 0.],
                                            [-1., 1.]]).cuda()
                    weights_x = weights.view(1, 1, 2, 2).repeat(1, 1, 1, 1)
                    # transpose the 2x2 spatial dims to get the vertical-difference kernel
                    # (transposing dims 0 and 1, as before, was a no-op on singleton dims)
                    weights_y = torch.transpose(weights_x, 2, 3)
                    if direction == 'x':
                        weights = weights_x
                    elif direction == 'y':
                        weights = weights_y

                    F = torch.nn.functional
                    output = torch.abs(F.conv2d(I, weights, stride=1, padding=1))
                    return output

                def avg(R, direction):
                    return nn.AvgPool2d(kernel_size=3, stride=1, padding=1)(get_smooth(R, direction))

                def get_gradients_loss(I, R):
                    # structure-aware smoothness: penalize illumination gradients
                    # except where the reflectance itself has strong gradients
                    R_gray = torch.mean(R, dim=1, keepdim=True)
                    gradients_I_x = get_smooth(I, 'x')
                    gradients_I_y = get_smooth(I, 'y')

                    return torch.mean(gradients_I_x * torch.exp(-10 * avg(R_gray, 'x'))
                                      + gradients_I_y * torch.exp(-10 * avg(R_gray, 'y')))

                smooth_loss_low = get_gradients_loss(I_low, R_low)
                smooth_loss_normal = get_gradients_loss(I_normal, R_normal)
                smooth_loss_low_hat = get_gradients_loss(I_low_hat, R_low)

                loss_dec += 0.1 * smooth_loss_low + 0.1 * smooth_loss_normal
                if flag:
                    opt_Dec.zero_grad()
                    loss_dec.backward()
                    opt_Dec.step()

                elif not flag:
                    # Enhance_Net loss: reconstruction against the normal-light image
                    loss_reconst_enh = reconst_loss(S_normal, R_low.mul(I_low_hat))
                    loss_enh = loss_reconst_enh + 3 * smooth_loss_low_hat

                    opt_Enh.zero_grad()
                    loss_enh.backward()
                    opt_Enh.step()

                flag = not flag

        print("Epoch: {}; Loss_Dec: {}; Loss_Enh: {}".format(e, loss_dec, loss_enh))
        if e % 10 == 0 and e != 0:
            if not os.path.isdir("./checkpoints/"):
                os.makedirs("./checkpoints/")
            torch.save({"Dec_model": Dec.state_dict(),
                        "Enh_model": Enh.state_dict()},
                       "./checkpoints/model{}.tar".format(e))

    torch.save({"Dec_model": Dec.state_dict(),
                "Enh_model": Enh.state_dict()},
               "./checkpoints/model_newest.tar")


if __name__ == '__main__':
    train(epoch)
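
The smoothness term above weights illumination gradients by exp(-10 * |grad R|), so the illumination map is encouraged to be flat except at reflectance edges. A standalone sketch of the same 2x2 difference kernels, on CPU with a hypothetical tensor:

```python
# Standalone sketch of the difference kernels used in get_smooth (CPU).
import torch
import torch.nn.functional as F

weights_x = torch.tensor([[0., 0.], [-1., 1.]]).view(1, 1, 2, 2)  # horizontal difference
weights_y = torch.transpose(weights_x, 2, 3)                      # vertical difference

I = torch.rand(1, 1, 8, 8)  # hypothetical illumination map
grad_x = torch.abs(F.conv2d(I, weights_x, stride=1, padding=1))
grad_y = torch.abs(F.conv2d(I, weights_y, stride=1, padding=1))
print(grad_x.shape, grad_y.shape)  # padding=1 with a 2x2 kernel gives (1, 1, 9, 9)
```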
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------

import os

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def make_dataset(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)

    return images


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
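
A one-line usage sketch (the path is hypothetical; see the README's project layout):

```python
# Hypothetical usage: collect all image paths under a directory tree.
from utils import make_dataset

paths = make_dataset('../final_dataset/trainA')
print(len(paths), paths[:2])
```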
--------------------------------------------------------------------------------