├── utils.py
├── dataset.py
├── loss.py
├── model.py
└── train.py

/utils.py:
--------------------------------------------------------------------------------
import numpy as np
from PIL import Image


def data_augmentation(image, mode):
    if mode == 0:
        # original
        return image
    elif mode == 1:
        # flip up and down
        return np.flipud(image)
    elif mode == 2:
        # rotate counterclockwise 90 degrees
        return np.rot90(image)
    elif mode == 3:
        # rotate 90 degrees and flip up and down
        image = np.rot90(image)
        return np.flipud(image)
    elif mode == 4:
        # rotate 180 degrees
        return np.rot90(image, k=2)
    elif mode == 5:
        # rotate 180 degrees and flip
        image = np.rot90(image, k=2)
        return np.flipud(image)
    elif mode == 6:
        # rotate 270 degrees
        return np.rot90(image, k=3)
    elif mode == 7:
        # rotate 270 degrees and flip
        image = np.rot90(image, k=3)
        return np.flipud(image)


def load_images(file):
    im = Image.open(file)
    return np.array(im, dtype="float32") / 255.0


def save_images(filepath, result_1, result_2=None):
    result_1 = np.squeeze(result_1)

    # check for None explicitly: np.squeeze(None) and .any() misbehave when
    # result_2 is absent, and .any() also misfires on an all-black image
    if result_2 is None:
        cat_image = result_1
    else:
        result_2 = np.squeeze(result_2)
        cat_image = np.concatenate([result_1, result_2], axis=1)

    im = Image.fromarray(np.clip(cat_image * 255.0, 0, 255.0).astype('uint8'))
    im.save(filepath, 'png')
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import random
import numpy as np
import torch.utils.data as data
from glob import glob
from utils import load_images, data_augmentation


# folder structure
# data
# |train
# | |low
# | | |*.png
# | |high
# | | |*.png
# |eval
# | |low
# | | |*.png
def get_dataset_len(route, phase):
    if phase == 'train':
        low_data_names = glob(route + phase + '/low/*.png')
        high_data_names = glob(route + phase + '/high/*.png')
        low_data_names.sort()
        high_data_names.sort()
        assert len(low_data_names) == len(high_data_names)
        return len(low_data_names), [low_data_names, high_data_names]
    elif phase == 'eval':
        low_data_names = glob(route + phase + '/low/*.png')
        return len(low_data_names), low_data_names
    else:
        return 0, []


def getitem(phase, data_names, item, patch_size):
    if phase == 'train':
        low_im = load_images(data_names[0][item])
        high_im = load_images(data_names[1][item])

        # crop the same random patch and apply the same augmentation to both images
        h, w, _ = low_im.shape
        x = random.randint(0, h - patch_size)
        y = random.randint(0, w - patch_size)
        rand_mode = random.randint(0, 7)

        low_im = data_augmentation(low_im[x:x + patch_size, y:y + patch_size, :], rand_mode)
        high_im = data_augmentation(high_im[x:x + patch_size, y:y + patch_size, :], rand_mode)

        # np.flipud/np.rot90 return views with negative strides, which torch
        # cannot convert directly; copy to contiguous arrays before collation
        return np.ascontiguousarray(low_im), np.ascontiguousarray(high_im)
    elif phase == 'eval':
        low_im = load_images(data_names[item])
        return low_im


class TheDataset(data.Dataset):

    def __init__(self, route='./data/', phase='train', patch_size=48):
        self.route = route
        self.phase = phase
        self.patch_size = patch_size
        self.len, self.data_names = get_dataset_len(route, phase)

    def __len__(self):
        return self.len

    def __getitem__(self, item):
        return getitem(self.phase, self.data_names, item, self.patch_size)
--------------------------------------------------------------------------------
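A quick sanity check of the dataset (an illustrative sketch, not one of the repo files; it assumes ./data/train/low and ./data/train/high are populated with matching PNGs):

import torch
from dataset import TheDataset

loader = torch.utils.data.DataLoader(TheDataset(), batch_size=4, shuffle=True)
low, high = next(iter(loader))
# both batches are NHWC float32 patches: torch.Size([4, 48, 48, 3])
print(low.shape, high.shape)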
/loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


def gradient(input_tensor, direction):
    input_tensor = input_tensor.permute(0, 3, 1, 2)
    h, w = input_tensor.size()[2], input_tensor.size()[3]

    # build the kernels on the same device as the input so GPU training works
    smooth_kernel_x = torch.reshape(torch.Tensor([[0., 0.], [-1., 1.]]), (1, 1, 2, 2)).to(input_tensor.device)
    smooth_kernel_y = smooth_kernel_x.permute(0, 1, 3, 2)

    assert direction in ['x', 'y']
    if direction == 'x':
        kernel = smooth_kernel_x
    else:
        kernel = smooth_kernel_y

    out = F.conv2d(input_tensor, kernel, padding=(1, 1))
    out = torch.abs(out[:, :, 0:h, 0:w])

    return out.permute(0, 2, 3, 1)


def ave_gradient(input_tensor, direction):
    return F.avg_pool2d(
        gradient(input_tensor, direction).permute(0, 3, 1, 2), 3, stride=1, padding=1
    ).permute(0, 2, 3, 1)


def smooth(input_l, input_r):
    # reduce the reflectance to grayscale before taking its gradients
    rgb_weights = torch.Tensor([0.2989, 0.5870, 0.1140]).to(input_r.device)
    input_r = torch.tensordot(input_r, rgb_weights, dims=([-1], [-1]))
    input_r = torch.unsqueeze(input_r, -1)

    return torch.mean(
        gradient(input_l, 'x') * torch.exp(-10 * ave_gradient(input_r, 'x')) +
        gradient(input_l, 'y') * torch.exp(-10 * ave_gradient(input_r, 'y'))
    )


class DecomLoss(nn.Module):

    def __init__(self):
        super(DecomLoss, self).__init__()

    def forward(self, r_low, l_low, r_high, l_high, input_low, input_high):
        l_low_3 = torch.cat((l_low, l_low, l_low), -1)
        l_high_3 = torch.cat((l_high, l_high, l_high), -1)

        recon_loss_low = torch.mean(torch.abs(r_low * l_low_3 - input_low))
        recon_loss_high = torch.mean(torch.abs(r_high * l_high_3 - input_high))
        recon_loss_mutual_low = torch.mean(torch.abs(r_high * l_low_3 - input_low))
        recon_loss_mutual_high = torch.mean(torch.abs(r_low * l_high_3 - input_high))
        equal_r_loss = torch.mean(torch.abs(r_low - r_high))

        ismooth_loss_low = smooth(l_low, r_low)
        ismooth_loss_high = smooth(l_high, r_high)

        return \
            recon_loss_low + recon_loss_high + \
            0.001 * recon_loss_mutual_low + 0.001 * recon_loss_mutual_high + \
            0.1 * ismooth_loss_low + 0.1 * ismooth_loss_high + \
            0.01 * equal_r_loss


class RelightLoss(nn.Module):

    def __init__(self):
        super(RelightLoss, self).__init__()

    def forward(self, l_delta, r_low, input_high):
        l_delta_3 = torch.cat((l_delta, l_delta, l_delta), -1)

        relight_loss = torch.mean(torch.abs(r_low * l_delta_3 - input_high))

        ismooth_loss_delta = smooth(l_delta, r_low)

        return relight_loss + 3 * ismooth_loss_delta


if __name__ == '__main__':
    tensor = torch.rand(1, 300, 400, 1)
    out_data = smooth(tensor, torch.rand(1, 300, 400, 3))
    print(out_data)
--------------------------------------------------------------------------------
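The losses operate on NHWC tensors throughout: reflectance maps have 3 channels, illumination maps 1. A minimal shape check with random tensors (illustrative, not part of the repo):

import torch
from loss import DecomLoss, RelightLoss

r_low, r_high = torch.rand(2, 64, 64, 3), torch.rand(2, 64, 64, 3)
l_low, l_high = torch.rand(2, 64, 64, 1), torch.rand(2, 64, 64, 1)
low, high = torch.rand(2, 64, 64, 3), torch.rand(2, 64, 64, 3)

print(DecomLoss()(r_low, l_low, r_high, l_high, low, high))  # scalar loss
print(RelightLoss()(l_low, r_low, high))                     # scalar loss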
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class DecomNet(nn.Module):

    def __init__(self, layer_num=5, channel=64, kernel_size=3):
        super(DecomNet, self).__init__()
        self.layer_num = layer_num
        # 9x9 conv (kernel_size*3) over the 4-channel input: RGB plus the channel-wise max
        self.conv0 = nn.Conv2d(4, channel, kernel_size * 3, padding=4)
        feature_conv = []
        for idx in range(layer_num):
            feature_conv.append(nn.Sequential(
                nn.Conv2d(channel, channel, kernel_size, padding=1),
                nn.ReLU()
            ))
        self.conv = nn.ModuleList(feature_conv)
        self.conv1 = nn.Conv2d(channel, 4, kernel_size, padding=1)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        x_max, _ = torch.max(x, dim=3, keepdim=True)
        x = torch.cat((x, x_max), dim=3)
        x = x.permute(0, 3, 1, 2)

        out = self.conv0(x)
        for idx in range(self.layer_num):
            out = self.conv[idx](out)
        out = self.conv1(out)
        out = self.sig(out)

        out = out.permute(0, 2, 3, 1)
        r_part = out[:, :, :, 0:3]   # reflectance
        l_part = out[:, :, :, 3:4]   # illumination

        return out, r_part, l_part


class RelightNet(nn.Module):

    def __init__(self, channel=64, kernel_size=3):
        super(RelightNet, self).__init__()
        self.conv0 = nn.Conv2d(4, channel, kernel_size, padding=1)
        self.conv1 = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size, stride=2, padding=1),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size, stride=2, padding=1),
            nn.ReLU()
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size, stride=2, padding=1),
            nn.ReLU()
        )
        self.deconv1 = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size, padding=1),
            nn.ReLU()
        )
        self.deconv2 = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size, padding=1),
            nn.ReLU()
        )
        self.deconv3 = nn.Sequential(
            nn.Conv2d(channel, channel, kernel_size, padding=1),
            nn.ReLU()
        )
        self.feature_fusion = nn.Conv2d(channel * 3, channel, 1)
        self.output = nn.Conv2d(channel, 1, kernel_size, padding=1)

    def forward(self, x):
        x = x.permute(0, 3, 1, 2)

        conv0 = self.conv0(x)
        conv1 = self.conv1(conv0)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)

        up1 = F.interpolate(conv3, scale_factor=2)
        deconv1 = self.deconv1(up1) + conv2
        up2 = F.interpolate(deconv1, scale_factor=2)
        deconv2 = self.deconv2(up2) + conv1
        up3 = F.interpolate(deconv2, scale_factor=2)
        deconv3 = self.deconv3(up3) + conv0

        deconv1_resize = F.interpolate(deconv1, scale_factor=4)
        deconv2_resize = F.interpolate(deconv2, scale_factor=2)

        out = torch.cat((deconv1_resize, deconv2_resize, deconv3), dim=1)
        out = self.feature_fusion(out)
        out = self.output(out)

        return out.permute(0, 2, 3, 1)


if __name__ == '__main__':
    net = DecomNet()
    relight_net = RelightNet()
    data_in = torch.rand(1, 600, 400, 3)
    out_sum, r_low, l_low = net(data_in)
    out_S = relight_net(out_sum)
    print(out_S.size())
--------------------------------------------------------------------------------
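One constraint worth noting: RelightNet downsamples three times with stride 2 and upsamples back with scale_factor=2, so the skip additions (deconv1 + conv2 and so on) only line up when the input height and width are divisible by 8. A quick check (illustrative):

import torch
from model import DecomNet, RelightNet

decom, relight = DecomNet(), RelightNet()
out, _, _ = decom(torch.rand(1, 600, 400, 3))  # 600 and 400 are divisible by 8
print(relight(out).shape)  # torch.Size([1, 600, 400, 1])
# a 100x100 input would fail: 100 -> 50 -> 25 -> 13, and upsampling gives 26 != 25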
/train.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import argparse
from model import DecomNet, RelightNet
from dataset import TheDataset
from loss import DecomLoss, RelightLoss
import tqdm


parser = argparse.ArgumentParser(description='RetinexNet args setting')

parser.add_argument('--use_gpu', dest='use_gpu', type=int, default=1, help='gpu flag, 1 for GPU and 0 for CPU')
parser.add_argument('--gpu_idx', dest='gpu_idx', default='0', help='GPU idx')
parser.add_argument('--phase', dest='phase', default='train', help='train or test')

parser.add_argument('--epoch', dest='epoch', type=int, default=100, help='number of total epochs')
parser.add_argument('--batch_size', dest='batch_size', type=int, default=16, help='number of samples in one batch')
parser.add_argument('--patch_size', dest='patch_size', type=int, default=48, help='patch size')
parser.add_argument('--workers', dest='workers', type=int, default=8, help='num workers of dataloader')
parser.add_argument('--start_lr', dest='start_lr', type=float, default=0.001, help='initial learning rate for adam')
parser.add_argument('--save_interval', dest='save_interval', type=int, default=20, help='save model every # epochs')

parser.add_argument('--checkpoint_dir', dest='ckpt_dir', default='./checkpoint', help='directory for checkpoints')
parser.add_argument('--save_dir', dest='save_dir', default='./test_results', help='directory for testing outputs')
parser.add_argument('--test_dir', dest='test_dir', default='./data/test/low', help='directory for testing inputs')
parser.add_argument('--decom', dest='decom', type=int, default=0,
                    help='decom flag, 0 for enhanced results only and 1 for decomposition results')

args = parser.parse_args()

if not os.path.exists(args.ckpt_dir):
    os.makedirs(args.ckpt_dir)

decom_net = DecomNet()
relight_net = RelightNet()

if args.use_gpu:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_idx
    decom_net = decom_net.cuda()
    relight_net = relight_net.cuda()
    cudnn.benchmark = True
    cudnn.enabled = True

# constant learning rate for the first 20 epochs, then one-tenth of it afterwards
lr = args.start_lr * np.ones([args.epoch])
lr[20:] = lr[0] / 10.0

decom_optim = torch.optim.Adam(decom_net.parameters(), lr=args.start_lr)
relight_optim = torch.optim.Adam(relight_net.parameters(), lr=args.start_lr)

train_set = TheDataset()

decom_criterion = DecomLoss()
relight_criterion = RelightLoss()


def train():
    decom_net.train()
    relight_net.train()

    # stage 1: train DecomNet alone
    for epoch in range(args.epoch):
        times_per_epoch, sum_loss = 0, 0.

        dataloader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True,
                                                 num_workers=args.workers, pin_memory=True)
        decom_optim.param_groups[0]['lr'] = lr[epoch]

        for data in tqdm.tqdm(dataloader):
            times_per_epoch += 1
            low_im, high_im = data
            if args.use_gpu:
                low_im, high_im = low_im.cuda(), high_im.cuda()

            decom_optim.zero_grad()
            _, r_low, l_low = decom_net(low_im)
            _, r_high, l_high = decom_net(high_im)
            loss = decom_criterion(r_low, l_low, r_high, l_high, low_im, high_im)
            loss.backward()
            decom_optim.step()

            # accumulate a Python float, not the tensor, to avoid holding the graph
            sum_loss += loss.item()

        print('epoch: ' + str(epoch) + ' | loss: ' + str(sum_loss / times_per_epoch))
        if (epoch + 1) % args.save_interval == 0:
            torch.save(decom_net.state_dict(), args.ckpt_dir + '/decom_' + str(epoch) + '.pth')

    torch.save(decom_net.state_dict(), args.ckpt_dir + '/decom_final.pth')

    # stage 2: train RelightNet with DecomNet frozen (its outputs are detached)
    for epoch in range(args.epoch):
        times_per_epoch, sum_loss = 0, 0.
        dataloader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True,
                                                 num_workers=args.workers, pin_memory=True)
        relight_optim.param_groups[0]['lr'] = lr[epoch]

        for data in tqdm.tqdm(dataloader):
            times_per_epoch += 1
            low_im, high_im = data
            if args.use_gpu:
                low_im, high_im = low_im.cuda(), high_im.cuda()

            relight_optim.zero_grad()
            # decom_net's first return value is its full 4-channel output
            # (R concatenated with L), which is exactly what RelightNet expects
            lr_low, r_low, _ = decom_net(low_im)
            l_delta = relight_net(lr_low.detach())
            loss = relight_criterion(l_delta, r_low.detach(), high_im)
            loss.backward()
            relight_optim.step()

            sum_loss += loss.item()

        print('epoch: ' + str(epoch) + ' | loss: ' + str(sum_loss / times_per_epoch))
        if (epoch + 1) % args.save_interval == 0:
            torch.save(relight_net.state_dict(), args.ckpt_dir + '/relight_' + str(epoch) + '.pth')

    torch.save(relight_net.state_dict(), args.ckpt_dir + '/relight_final.pth')


if __name__ == '__main__':
    train()
--------------------------------------------------------------------------------
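train.py only implements the training phase; the --phase, --test_dir, --save_dir, and --decom flags suggest a companion test script. A minimal sketch of what it might look like, assuming the enhanced image is composed as r_low * l_delta (the usual RetinexNet reconstruction) and that test images have dimensions divisible by 8:

import os
import torch
from glob import glob
from model import DecomNet, RelightNet
from utils import load_images, save_images

decom_net, relight_net = DecomNet(), RelightNet()
decom_net.load_state_dict(torch.load('./checkpoint/decom_final.pth', map_location='cpu'))
relight_net.load_state_dict(torch.load('./checkpoint/relight_final.pth', map_location='cpu'))
decom_net.eval()
relight_net.eval()

os.makedirs('./test_results', exist_ok=True)
with torch.no_grad():
    for name in sorted(glob('./data/test/low/*.png')):
        low_im = torch.from_numpy(load_images(name)).unsqueeze(0)  # (1, H, W, 3)
        out, r_low, _ = decom_net(low_im)
        l_delta = relight_net(out)
        # Retinex composition: reflectance times the adjusted illumination
        enhanced = r_low * torch.cat((l_delta, l_delta, l_delta), -1)
        save_images(os.path.join('./test_results', os.path.basename(name)), enhanced.numpy())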