├── 4D LUT_overview.jpg
├── Identity4DLUT17.txt
├── Identity4DLUT33.txt
├── LICENSE
├── README.md
├── datasets.py
├── models_x.py
├── quadrilinear_cpp
│   ├── setup.py
│   ├── setup.sh
│   └── src
│       ├── quadrilinear4d.cpp
│       ├── quadrilinear4d.h
│       ├── quadrilinear4d_cuda.cpp
│       ├── quadrilinear4d_cuda.h
│       ├── quadrilinear4d_kernel.cu
│       └── quadrilinear4d_kernel.h
└── train.py
/4D LUT_overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChengxuLiu/4DLUT/7bfccd19c0170a8cd1a5c8c5755c19bd667b29ac/4D LUT_overview.jpg
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Chengxu Liu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [TIP 2023] 4D LUT
2 | This is the official PyTorch implementation of the paper [4D LUT: Learnable Context-Aware 4D Lookup Table for Image Enhancement](https://arxiv.org/abs/2209.01749).
3 |
4 | ## Overview
5 |
6 |
7 | ## Contribution
8 | * We propose a novel learnable context-aware 4-dimensional lookup table (4D LUT), which first extends the lookup table architecture into a 4-dimensional space and achieves content-dependent image enhancement without a significant increase in computational cost.
9 | * Extensive experiments demonstrate that the proposed 4D LUT obtains more accurate results and significantly outperforms existing SOTA methods on three widely-used image enhancement benchmarks.
10 |
11 | ## Requirements and dependencies
12 | * CUDA 11.4
13 | * GCC 7.5.0
14 | * python 3.8 (we recommend using [Anaconda](https://www.anaconda.com/))
15 | * pytorch == 1.9.0
16 | * torchvision == 0.10.0
17 |
18 | ## Train
19 | 1. Clone this GitHub repo
20 | ```
21 | git clone https://github.com/ChengxuLiu/4DLUT.git
22 | cd 4DLUT
23 | ```
24 | 2. Build the quadrilinear interpolation extension:
25 | ```
26 | cd quadrilinear_cpp
27 | sh setup.sh
28 | ```
29 | 3. Prepare the training dataset and modify the dataset path in `./datasets.py`
30 | 4. Run training
31 | ```
32 | python train.py
33 | ```
34 | 5. The trained models are saved in `./saved_models`
35 |
36 |
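Once the extension is built, a quick sanity check is possible with a minimal sketch like the one below (illustrative only, not part of the repo; it assumes a CUDA build and uses the bundled 17-bin identity LUT, which should leave the image approximately unchanged when the context channel is zero):
```
import torch
from models_x import Generator4DLUT_identity

lut4d = Generator4DLUT_identity(dim=17).cuda()
img = torch.rand(1, 3, 64, 64).cuda()        # RGB input in [0, 1]
context = torch.zeros(1, 1, 64, 64).cuda()   # context channel
out = lut4d(torch.cat([context, img], 1))    # -> (1, 3, 64, 64)
print(torch.allclose(out, img, atol=1e-3))   # identity LUT ~= no-op
```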
37 | ## Citation
38 | If you find the code useful for your research, please consider citing our paper. :blush:
39 | ```
40 | @article{liu20234d,
41 |   title={4D LUT: learnable context-aware 4d lookup table for image enhancement},
42 |   author={Liu, Chengxu and Yang, Huan and Fu, Jianlong and Qian, Xueming},
43 |   journal={IEEE Transactions on Image Processing},
44 |   volume={32},
45 |   pages={4742--4756},
46 |   year={2023},
47 |   publisher={IEEE}
48 | }
49 | ```
50 |
51 | ## Contact
52 | If you encounter any problems, please describe them in the issues or contact:
53 | * Chengxu Liu:
54 |
55 | ## Acknowledgement
56 | The code of 4DLUT is built upon [Image-Adaptive-3DLUT](https://github.com/HuiZeng/Image-Adaptive-3DLUT), and we express our gratitude to this awesome project.
57 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import random
3 | import os
4 | import numpy as np
5 | import torch
6 | import cv2
7 |
8 | from torch.utils.data import Dataset
9 | from PIL import Image
10 | import torchvision.transforms as transforms
11 | import torchvision.transforms.functional as TF
12 |
13 |
14 | class ImageDataset_sRGB(Dataset):
15 |     def __init__(self, root, mode="train", unpaired_data="fiveK", combined=False):
16 |         self.mode = mode
17 |         self.unpaired_data = unpaired_data
18 |
19 |         file = open(os.path.join(root,'images_train.txt'),'r') #for DPE
20 |         set1_input_files = sorted(file.readlines())
21 |         self.set1_input_files = list()
22 |         self.set1_expert_files = list()
23 |         for i in range(len(set1_input_files)):
24 |             self.set1_input_files.append(os.path.join(root,"input","InputAsShotZero",set1_input_files[i][:-1] + ".png"))
25 |             self.set1_expert_files.append(os.path.join(root,"output","Export_C_512",set1_input_files[i][:-1] + ".png"))
26 |
27 |         file = open(os.path.join(root,'images_test.txt'),'r')
28 |         test_input_files = sorted(file.readlines())
29 |         self.test_input_files = list()
30 |         self.test_expert_files = list()
31 |         for i in range(len(test_input_files)):
32 |             self.test_input_files.append(os.path.join(root,"input","InputAsShotZero",test_input_files[i][:-1] + ".png"))
33 |             self.test_expert_files.append(os.path.join(root,"output","Export_C_512",test_input_files[i][:-1] + ".png"))
34 |
35 |     def __getitem__(self, index):
36 |         if self.mode == "train":
37 |             img_name = os.path.split(self.set1_input_files[index % len(self.set1_input_files)])[-1]
38 |             img_input = Image.open(self.set1_input_files[index % len(self.set1_input_files)])
39 |             img_exptC = Image.open(self.set1_expert_files[index % len(self.set1_expert_files)])
40 |
41 |         elif self.mode == "test":
42 |             img_name = os.path.split(self.test_input_files[index % len(self.test_input_files)])[-1]
43 |             img_input = Image.open(self.test_input_files[index % len(self.test_input_files)])
44 |             img_exptC = Image.open(self.test_expert_files[index % len(self.test_expert_files)])
45 |
46 |         if self.mode == "train":
47 |
48 |             ratio_H = np.random.uniform(0.6,1.0)
49 |             ratio_W = np.random.uniform(0.6,1.0)
50 |             W,H = img_input.size
51 |             crop_h = round(H*ratio_H)
52 |             crop_w = round(W*ratio_W)
53 |
54 |             i, j, h, w = transforms.RandomCrop.get_params(img_input, output_size=(crop_h, crop_w))
55 |             img_input = TF.crop(img_input, i, j, h, w)
56 |             img_exptC = TF.crop(img_exptC, i, j, h, w)
57 |
58 |             if np.random.random() > 0.5:
59 |                 img_input = TF.hflip(img_input)
60 |                 img_exptC = TF.hflip(img_exptC)
61 |
62 |             a = np.random.uniform(0.8,1.2)
63 |             img_input = TF.adjust_brightness(img_input,a)
64 |
65 |             a = np.random.uniform(0.8,1.2)
66 |             img_input = TF.adjust_saturation(img_input,a)
67 |
68 |         img_input = TF.to_tensor(img_input)
69 |         img_exptC = TF.to_tensor(img_exptC)
70 |         return {"A_input": img_input, "A_exptC": img_exptC, "input_name": img_name}
71 |
72 |     def __len__(self):
73 |         if self.mode == "train":
74 |             return len(self.set1_input_files)
75 |         elif self.mode == "test":
76 |             return len(self.test_input_files)
77 |
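# Usage sketch (illustrative, not part of the original file). It assumes the
# FiveK layout referenced in train.py: images_train.txt / images_test.txt under
# the dataset root, inputs in input/InputAsShotZero, targets in output/Export_C_512.
#
#   from torch.utils.data import DataLoader
#   ds = ImageDataset_sRGB("/fivek_dataset/MIT-Adobe5k-UPE/", mode="train")
#   loader = DataLoader(ds, batch_size=1, shuffle=True, num_workers=1)
#   batch = next(iter(loader))
#   batch["A_input"].shape, batch["A_exptC"].shape   # (1, 3, H, W) tensors in [0, 1]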
--------------------------------------------------------------------------------
/models_x.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torchvision.models as models
4 | import torchvision.transforms as transforms
5 | from torch.autograd import Variable
6 | import torch
7 | import numpy as np
8 | import math
9 | import quadrilinear4d
10 |
11 | def weights_init_normal_generator(m):
12 |     classname = m.__class__.__name__
13 |     if classname.find("Conv2d") != -1:
14 |         torch.nn.init.xavier_normal_(m.weight.data)
15 |
16 |     elif classname.find("BatchNorm2d") != -1 or classname.find("InstanceNorm2d") != -1:
17 |         torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
18 |         torch.nn.init.constant_(m.bias.data, 0.0)
19 |
20 |
21 | def discriminator_block(in_filters, out_filters, normalization=False):
22 |     """Returns downsampling layers of each discriminator block"""
23 |     layers = [nn.Conv2d(in_filters, out_filters, 3, stride=2, padding=1)]
24 |     layers.append(nn.LeakyReLU(0.2))
25 |     if normalization:
26 |         layers.append(nn.InstanceNorm2d(out_filters, affine=True))
27 |
28 |     return layers
29 |
30 |
31 | class Generator_for_bias(nn.Module):
32 |     def __init__(self, in_channels=3):
33 |         super(Generator_for_bias, self).__init__()
34 |
35 |         self.model = nn.Sequential(
36 |             nn.Upsample(size=(256,256),mode='bilinear'),
37 |             nn.Conv2d(3, 16, 3, stride=2, padding=1),
38 |             nn.LeakyReLU(0.2),
39 |             nn.InstanceNorm2d(16, affine=True),
40 |             *discriminator_block(16, 32, normalization=True),
41 |             *discriminator_block(32, 64, normalization=True),
42 |             *discriminator_block(64, 128, normalization=True),
43 |             *discriminator_block(128, 128),
44 |             nn.Dropout(p=0.5),
45 |             nn.Conv2d(128, 12, 8, padding=0),
46 |         )
47 |
48 |     def forward(self, img_input):
49 |         return self.model(img_input)
50 |
51 |
52 | def generator_block(in_filters, out_filters, normalization=False):
53 |     """Returns stride-1 convolutional layers of each generator block"""
54 |     layers = [nn.Conv2d(in_filters, out_filters, 3, stride=1, padding=1)]
55 |     layers.append(nn.LeakyReLU(0.2))
56 |     if normalization:
57 |         layers.append(nn.InstanceNorm2d(out_filters, affine=True))
58 |         #layers.append(nn.BatchNorm2d(out_filters))
59 |     return layers
60 |
61 |
62 | class Generator_for_info(nn.Module):
63 |     def __init__(self, in_channels=3):
64 |         super(Generator_for_info, self).__init__()
65 |
66 |         self.input_layer = nn.Sequential(
67 |             nn.Conv2d(in_channels, 16, 3, stride=1, padding=1),
68 |             nn.LeakyReLU(0.2),
69 |             nn.InstanceNorm2d(16, affine=True),
70 |         )
71 |
72 |         self.mid_layer = nn.Sequential(
73 |             *generator_block(16, 16, normalization=True),
74 |             *generator_block(16, 16, normalization=True),
75 |             *generator_block(16, 16, normalization=True),
76 |         )
77 |
78 |         self.output_layer = nn.Sequential(
79 |             nn.Dropout(p=0.5),
80 |             nn.Conv2d(16, 1, 3, stride=1, padding=1),
81 |             nn.Sigmoid()
82 |         )
83 |
84 |
85 |     def forward(self, img_input):
86 |         x = self.input_layer(img_input)
87 |         identity = x
88 |         out = self.mid_layer(x)
89 |         out += identity
90 |         out = self.output_layer(out)
91 |         return out
92 |
93 |
94 |
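# Note (added for clarity): Generator_for_info emits a single-channel context
# map in [0, 1] (sigmoid output) at the input resolution. In train.py this map
# is concatenated with the RGB image along the channel axis, producing the
# 4-channel (context, r, g, b) tensor that indexes the 4D LUT defined below.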
95 | class Generator4DLUT_identity(nn.Module):
96 |     def __init__(self, dim=17):
97 |         super(Generator4DLUT_identity, self).__init__()
98 |         if dim == 17:
99 |             file = open("Identity4DLUT17.txt", 'r')
100 |         elif dim == 33:
101 |             file = open("Identity4DLUT33.txt", 'r')
102 |         lines = file.readlines()
103 |         buffer = np.zeros((3,2,dim,dim,dim), dtype=np.float32)
104 |         for p in range(0,2):
105 |             for i in range(0,dim):
106 |                 for j in range(0,dim):
107 |                     for k in range(0,dim):
108 |                         n = p * dim*dim*dim + i * dim*dim + j*dim + k
109 |                         x = lines[n].split()
110 |                         buffer[0,p,i,j,k] = float(x[0])
111 |                         buffer[1,p,i,j,k] = float(x[1])
112 |                         buffer[2,p,i,j,k] = float(x[2])
113 |         self.LUT_en = nn.Parameter(torch.from_numpy(buffer).requires_grad_(True))
114 |         self.QuadrilinearInterpolation_4D = QuadrilinearInterpolation_4D()
115 |
116 |     def forward(self, x):
117 |         _, output = self.QuadrilinearInterpolation_4D(self.LUT_en, x)
118 |         return output
119 |
120 |
121 |
122 |
123 | class QuadrilinearInterpolation_Function(torch.autograd.Function):
124 |     @staticmethod
125 |     def forward(ctx, lut, x):
126 |         x = x.contiguous()
127 |         output = x.new(x.size()[0],3,x.size()[2],x.size()[3])
128 |         dim = lut.size()[-1]
129 |         shift = 2 * dim ** 3
130 |         binsize = 1.000001 / (dim-1)
131 |         W = x.size(2)
132 |         H = x.size(3)
133 |         batch = x.size(0)
134 |         assert 1 == quadrilinear4d.forward(lut,
135 |                                            x,
136 |                                            output,
137 |                                            dim,
138 |                                            shift,
139 |                                            binsize,
140 |                                            W,
141 |                                            H,
142 |                                            batch)
143 |         int_package = torch.IntTensor([dim, shift, W, H, batch])
144 |         float_package = torch.FloatTensor([binsize])
145 |         variables = [lut, x, int_package, float_package]
146 |         ctx.save_for_backward(*variables)
147 |
148 |         return lut, output
149 |
150 |     @staticmethod
151 |     def backward(ctx, lut_grad, x_grad):
152 |
153 |         x_grad = x_grad.contiguous()
154 |         output_grad = x_grad.new(x_grad.size()[0],4,x_grad.size()[2],x_grad.size()[3]).fill_(0)
155 |         output_grad[:,1:,:,:] = x_grad
156 |         lut, x, int_package, float_package = ctx.saved_tensors
157 |         dim, shift, W, H, batch = int_package
158 |         dim, shift, W, H, batch = int(dim), int(shift), int(W), int(H), int(batch)
159 |         binsize = float(float_package[0])
160 |
161 |         assert 1 == quadrilinear4d.backward(x,
162 |                                             output_grad,
163 |                                             lut,
164 |                                             lut_grad,
165 |                                             dim,
166 |                                             shift,
167 |                                             binsize,
168 |                                             W,
169 |                                             H,
170 |                                             batch)
171 |         return lut_grad, output_grad
172 |
173 |
174 | class QuadrilinearInterpolation_4D(torch.nn.Module):
175 |     def __init__(self):
176 |         super(QuadrilinearInterpolation_4D, self).__init__()
177 |
178 |     def forward(self, lut, x):
179 |         return QuadrilinearInterpolation_Function.apply(lut, x)
180 |
181 |
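# Layout notes (added for clarity): LUT_en has shape (3, 2, dim, dim, dim),
# i.e. three output channels, each stored as a flattened grid of 2 * dim**3
# entries over the (context, r, g, b) axes; this is why the Function above
# passes shift = 2 * dim ** 3 as the per-channel offset into the flat LUT
# buffer. Returning the pair (lut, output) from forward lets backward route
# gradients to both the LUT entries and the 4-channel input tensor.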
182 | class TV_4D(nn.Module):
183 |     def __init__(self, dim=17):
184 |         super(TV_4D,self).__init__()
185 |
186 |         self.weight_r = torch.ones(3,2,dim,dim,dim-1, dtype=torch.float)
187 |         self.weight_r[:,:,:,:,(0,dim-2)] *= 2.0
188 |         self.weight_g = torch.ones(3,2,dim,dim-1,dim, dtype=torch.float)
189 |         self.weight_g[:,:,:,(0,dim-2),:] *= 2.0
190 |         self.weight_b = torch.ones(3,2,dim-1,dim,dim, dtype=torch.float)
191 |         self.weight_b[:,:,(0,dim-2),:,:] *= 2.0
192 |         self.relu = torch.nn.ReLU()
193 |
194 |     def forward(self, LUT):
195 |         dif_context = LUT.LUT_en[:,:-1,:,:,:] - LUT.LUT_en[:,1:,:,:,:]
196 |         dif_r = LUT.LUT_en[:,:,:,:,:-1] - LUT.LUT_en[:,:,:,:,1:]
197 |         dif_g = LUT.LUT_en[:,:,:,:-1,:] - LUT.LUT_en[:,:,:,1:,:]
198 |         dif_b = LUT.LUT_en[:,:,:-1,:,:] - LUT.LUT_en[:,:,1:,:,:]
199 |         tv = torch.mean(torch.mul((dif_r ** 2),self.weight_r)) + torch.mean(torch.mul((dif_g ** 2),self.weight_g)) + torch.mean(torch.mul((dif_b ** 2),self.weight_b))
200 |         mn = torch.mean(self.relu(dif_r)) + torch.mean(self.relu(dif_g)) + torch.mean(self.relu(dif_b)) \
201 |              + torch.mean(self.relu(dif_context))
202 |         return tv, mn
203 |
204 |
--------------------------------------------------------------------------------
/quadrilinear_cpp/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | import torch
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension
4 |
5 | if torch.cuda.is_available():
6 | # if False:
7 |     print('Including CUDA code.')
8 |     setup(
9 |         name='quadrilinear4d',
10 |         ext_modules=[
11 |             CUDAExtension('quadrilinear4d', [
12 |                 'src/quadrilinear4d_cuda.cpp',
13 |                 'src/quadrilinear4d_kernel.cu',
14 |             ])
15 |         ],
16 |         cmdclass={
17 |             'build_ext': BuildExtension
18 |         })
19 | else:
20 |     print('No CUDA found. Falling back to CPU.')
21 |     setup(name='quadrilinear4d',
22 |           ext_modules=[CppExtension(name = 'quadrilinear4d',
23 |                                     sources= ['src/quadrilinear4d.cpp'],
24 |                                     extra_compile_args=['-fopenmp'])],
25 |           cmdclass={'build_ext': BuildExtension})
--------------------------------------------------------------------------------
/quadrilinear_cpp/setup.sh:
--------------------------------------------------------------------------------
1 | export CUDA_HOME=/usr/local/cuda && python3 setup.py install
--------------------------------------------------------------------------------
/quadrilinear_cpp/src/quadrilinear4d.cpp:
--------------------------------------------------------------------------------
1 | #include "quadrilinear4d.h"
2 |
3 |
4 | void QuadriLinearForwardCpu(const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int channels);
5 |
6 | void QuadriLinearBackwardCpu(const float* image, float* image_grad, const float* lut, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int channels);
7 |
8 | int quadrilinear_forward(torch::Tensor lut, torch::Tensor image, torch::Tensor output,
9 |                          int lut_dim, int shift, float binsize, int width, int height, int batch)
10 | {
11 |     // Grab the input tensors as flat float buffers
12 |     float * lut_flat = lut.data<float>();
13 |     float * image_flat = image.data<float>();
14 |     float * output_flat = output.data<float>();
15 |
16 |     // number of input channels (context + RGB)
17 |     auto image_size = image.sizes();
18 |     int channels = image_size[1];
19 |
20 |     QuadriLinearForwardCpu(lut_flat, image_flat, output_flat, lut_dim, shift, binsize, width, height, channels);
21 |
22 |     return 1;
23 | }
24 |
25 | int quadrilinear_backward(torch::Tensor image, torch::Tensor image_grad, torch::Tensor lut, torch::Tensor lut_grad,
26 |                           int lut_dim, int shift, float binsize, int width, int height, int batch)
27 | {
28 |     // Grab the input tensors as flat float buffers
29 |     float * image_grad_flat = image_grad.data<float>();
30 |     float * lut_flat = lut.data<float>();
31 |     float * image_flat = image.data<float>();
32 |     float * lut_grad_flat = lut_grad.data<float>();
33 |
34 |     // the backward pass expects the 4-channel (context, r, g, b) input
35 |     auto image_size = image.sizes();
36 |     int channels = image_size[1];
37 |     if (channels != 4)
38 |     {
39 |         return 0;
40 |     }
41 |
42 |     QuadriLinearBackwardCpu(image_flat, image_grad_flat, lut_flat, lut_grad_flat, lut_dim, shift, binsize, width, height, channels);
43 |
44 |     return 1;
45 | }
46 |
47 | void QuadriLinearForwardCpu(const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int
width, const int height, const int channels) 48 | { 49 | const int output_size = height * width;; 50 | 51 | int index = 0; 52 | 53 | #pragma omp parallel 54 | #pragma omp for 55 | for (index = 0; index < output_size; ++index) 56 | { 57 | float context = image[index]; 58 | float r = image[index + width * height]; 59 | float g = image[index + width * height * 2]; 60 | float b = image[index + width * height * 3]; 61 | 62 | int r_id = floor(r / binsize); 63 | int g_id = floor(g / binsize); 64 | int b_id = floor(b / binsize); 65 | int context_id = floor(context / binsize); 66 | 67 | float r_d = fmod(r,binsize) / binsize; 68 | float g_d = fmod(g,binsize) / binsize; 69 | float b_d = fmod(b,binsize) / binsize; 70 | float context_d = fmod(context,binsize) / binsize; 71 | 72 | int id0000 = context_id + r_id * dim + g_id * dim * dim + b_id * dim * dim * dim; 73 | int id1000 = context_id + 1 + r_id * dim + g_id * dim * dim + b_id * dim * dim * dim; 74 | int id0100 = context_id + (r_id + 1) * dim + g_id * dim * dim + b_id * dim * dim * dim; 75 | int id0010 = context_id + r_id * dim + (g_id + 1) * dim * dim + b_id * dim * dim * dim; 76 | int id0001 = context_id + r_id * dim + g_id * dim * dim + (b_id + 1) * dim * dim * dim; 77 | int id1100 = context_id + 1 + (r_id + 1) * dim + g_id * dim * dim + b_id * dim * dim * dim; 78 | int id0110 = context_id + (r_id + 1) * dim + (g_id + 1) * dim * dim + b_id * dim * dim * dim; 79 | int id0011 = context_id + r_id * dim + (g_id + 1) * dim * dim + (b_id + 1) * dim * dim * dim; 80 | int id1010 = context_id + 1 + r_id * dim + (g_id + 1) * dim * dim + b_id * dim * dim * dim; 81 | int id1001 = context_id + 1 + r_id * dim + g_id * dim * dim + (b_id + 1) * dim * dim * dim; 82 | int id0101 = context_id + (r_id + 1) * dim + g_id * dim * dim + (b_id + 1) * dim * dim * dim; 83 | int id1110 = context_id + 1 + (r_id + 1) * dim + (g_id + 1) * dim * dim + b_id * dim * dim * dim; 84 | int id1011 = context_id + 1 + r_id * dim + (g_id + 1) * dim * dim + (b_id + 1) * dim * dim * dim; 85 | int id1101 = context_id + 1 + (r_id + 1) * dim + g_id * dim * dim + (b_id + 1) * dim * dim * dim; 86 | int id0111 = context_id + (r_id + 1) * dim + (g_id + 1) * dim * dim + (b_id + 1) * dim * dim * dim; 87 | int id1111 = context_id + 1 + (r_id + 1) * dim + (g_id + 1) * dim * dim + (b_id + 1) * dim * dim * dim; 88 | 89 | 90 | float w0000 = (1-context_d)*(1-r_d)*(1-g_d)*(1-b_d); 91 | float w1000 = context_d*(1-r_d)*(1-g_d)*(1-b_d); 92 | float w0100 = (1-context_d)*r_d*(1-g_d)*(1-b_d); 93 | float w0010 = (1-context_d)*(1-r_d)*g_d*(1-b_d); 94 | float w0001 = (1-context_d)*(1-r_d)*(1-g_d)*b_d; 95 | float w1100 = context_d*r_d*(1-g_d)*(1-b_d); 96 | float w0110 = (1-context_d)*r_d*g_d*(1-b_d); 97 | float w0011 = (1-context_d)*(1-r_d)*g_d*b_d; 98 | float w1010 = context_d*(1-r_d)*g_d*(1-b_d); 99 | float w1001 = context_d*(1-r_d)*(1-g_d)*b_d; 100 | float w0101 = (1-context_d)*r_d*(1-g_d)*b_d; 101 | float w1110 = context_d*r_d*g_d*(1-b_d); 102 | float w0111 = (1-context_d)*r_d*g_d*b_d; 103 | float w1101 = context_d*r_d*(1-g_d)*b_d; 104 | float w1011 = context_d*(1-r_d)*g_d*b_d; 105 | float w1111 = context_d*r_d*g_d*b_d; 106 | 107 | output[index] = w0000 * lut[id0000] + w1000 * lut[id1000] + w0100 * lut[id0100] + w0010 * lut[id0010] + 108 | w0001 * lut[id0001] + w1100 * lut[id1100] + w0110 * lut[id0110] + w0011 * lut[id0011] + 109 | w1010 * lut[id1010] + w1001 * lut[id1001] + w0101 * lut[id0101] + w1110 * lut[id1110] + 110 | w0111 * lut[id0111] + w1101 * lut[id1101] + w1011 * lut[id1011] + w1111 * lut[id1111]; 
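        // Quadrilinear interpolation (explanatory note): each output value is a
        // convex combination of the 16 corners of the (context, r, g, b) hypercube,
        //     out = sum over (p,i,j,k) in {0,1}^4 of w_pijk * lut[id_pijk],
        // where every w_pijk is a product of context_d or (1-context_d), r_d or
        // (1-r_d), g_d or (1-g_d), b_d or (1-b_d), so the 16 weights sum to 1.
        // The two statements below reuse the same weights for the remaining output
        // channels, offsetting the LUT base address by `shift` per channel.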
111 | 112 | output[index + width * height] = w0000 * lut[id0000 + shift] + w1000 * lut[id1000 + shift] + w0100 * lut[id0100 + shift] + w0010 * lut[id0010 + shift] + 113 | w0001 * lut[id0001 + shift] + w1100 * lut[id1100 + shift] + w0110 * lut[id0110 + shift] + w0011 * lut[id0011 + shift] + 114 | w1010 * lut[id1010 + shift] + w1001 * lut[id1001 + shift] + w0101 * lut[id0101 + shift] + w1110 * lut[id1110 + shift] + 115 | w0111 * lut[id0111 + shift] + w1101 * lut[id1101 + shift] + w1011 * lut[id1011 + shift] + w1111 * lut[id1111 + shift]; 116 | 117 | output[index + width * height * 2] = w0000 * lut[id0000 + shift * 2] + w1000 * lut[id1000 + shift * 2] + w0100 * lut[id0100 + shift * 2] + w0010 * lut[id0010 + shift * 2] + 118 | w0001 * lut[id0001 + shift * 2] + w1100 * lut[id1100 + shift * 2] + w0110 * lut[id0110 + shift * 2] + w0011 * lut[id0011 + shift * 2] + 119 | w1010 * lut[id1010 + shift * 2] + w1001 * lut[id1001 + shift * 2] + w0101 * lut[id0101 + shift * 2] + w1110 * lut[id1110 + shift * 2] + 120 | w0111 * lut[id0111 + shift * 2] + w1101 * lut[id1101 + shift * 2] + w1011 * lut[id1011 + shift * 2] + w1111 * lut[id1111 + shift * 2]; 121 | } 122 | } 123 | 124 | void QuadriLinearBackwardCpu(const float* image, float* image_grad, const float* lut, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int channels) 125 | { 126 | const int output_size = height * width; 127 | 128 | int index = 0; 129 | #pragma omp parallel 130 | #pragma omp for 131 | for (index = 0; index < output_size; ++index) 132 | { 133 | float context = image[index]; 134 | float r = image[index + width * height]; 135 | float g = image[index + width * height * 2]; 136 | float b = image[index + width * height * 3]; 137 | 138 | int r_id = floor(r / binsize); 139 | int g_id = floor(g / binsize); 140 | int b_id = floor(b / binsize); 141 | int context_id = floor(context / binsize); 142 | 143 | float r_d = fmod(r,binsize) / binsize; 144 | float g_d = fmod(g,binsize) / binsize; 145 | float b_d = fmod(b,binsize) / binsize; 146 | float context_d = fmod(context,binsize) / binsize; 147 | 148 | int id0000 = context_id + r_id * dim + g_id * dim * dim + b_id * dim * dim * dim; 149 | int id1000 = context_id + 1 + r_id * dim + g_id * dim * dim + b_id * dim * dim * dim; 150 | int id0100 = context_id + (r_id + 1) * dim + g_id * dim * dim + b_id * dim * dim * dim; 151 | int id0010 = context_id + r_id * dim + (g_id + 1) * dim * dim + b_id * dim * dim * dim; 152 | int id0001 = context_id + r_id * dim + g_id * dim * dim + (b_id + 1) * dim * dim * dim; 153 | int id1100 = context_id + 1 + (r_id + 1) * dim + g_id * dim * dim + b_id * dim * dim * dim; 154 | int id0110 = context_id + (r_id + 1) * dim + (g_id + 1) * dim * dim + b_id * dim * dim * dim; 155 | int id0011 = context_id + r_id * dim + (g_id + 1) * dim * dim + (b_id + 1) * dim * dim * dim; 156 | int id1010 = context_id + 1 + r_id * dim + (g_id + 1) * dim * dim + b_id * dim * dim * dim; 157 | int id1001 = context_id + 1 + r_id * dim + g_id * dim * dim + (b_id + 1) * dim * dim * dim; 158 | int id0101 = context_id + (r_id + 1) * dim + g_id * dim * dim + (b_id + 1) * dim * dim * dim; 159 | int id1110 = context_id + 1 + (r_id + 1) * dim + (g_id + 1) * dim * dim + b_id * dim * dim * dim; 160 | int id1011 = context_id + 1 + r_id * dim + (g_id + 1) * dim * dim + (b_id + 1) * dim * dim * dim; 161 | int id1101 = context_id + 1 + (r_id + 1) * dim + g_id * dim * dim + (b_id + 1) * dim * dim * dim; 162 | int id0111 = context_id + (r_id + 1) * dim + 
(g_id + 1) * dim * dim + (b_id + 1) * dim * dim * dim; 163 | int id1111 = context_id + 1 + (r_id + 1) * dim + (g_id + 1) * dim * dim + (b_id + 1) * dim * dim * dim; 164 | 165 | 166 | float w0000 = (1-context_d)*(1-r_d)*(1-g_d)*(1-b_d); 167 | float w1000 = context_d*(1-r_d)*(1-g_d)*(1-b_d); 168 | float w0100 = (1-context_d)*r_d*(1-g_d)*(1-b_d); 169 | float w0010 = (1-context_d)*(1-r_d)*g_d*(1-b_d); 170 | float w0001 = (1-context_d)*(1-r_d)*(1-g_d)*b_d; 171 | float w1100 = context_d*r_d*(1-g_d)*(1-b_d); 172 | float w0110 = (1-context_d)*r_d*g_d*(1-b_d); 173 | float w0011 = (1-context_d)*(1-r_d)*g_d*b_d; 174 | float w1010 = context_d*(1-r_d)*g_d*(1-b_d); 175 | float w1001 = context_d*(1-r_d)*(1-g_d)*b_d; 176 | float w0101 = (1-context_d)*r_d*(1-g_d)*b_d; 177 | float w1110 = context_d*r_d*g_d*(1-b_d); 178 | float w0111 = (1-context_d)*r_d*g_d*b_d; 179 | float w1101 = context_d*r_d*(1-g_d)*b_d; 180 | float w1011 = context_d*(1-r_d)*g_d*b_d; 181 | float w1111 = context_d*r_d*g_d*b_d; 182 | 183 | 184 | lut_grad[id0000 ] += w0000 * image_grad[index + width * height]; 185 | lut_grad[id1000 ] += w1000 * image_grad[index + width * height]; 186 | lut_grad[id0100 ] += w0100 * image_grad[index + width * height]; 187 | lut_grad[id0010 ] += w0010 * image_grad[index + width * height]; 188 | lut_grad[id0001 ] += w0001 * image_grad[index + width * height]; 189 | lut_grad[id1100 ] += w1100 * image_grad[index + width * height]; 190 | lut_grad[id0110 ] += w0110 * image_grad[index + width * height]; 191 | lut_grad[id0011 ] += w0011 * image_grad[index + width * height]; 192 | lut_grad[id1010 ] += w1010 * image_grad[index + width * height]; 193 | lut_grad[id1001 ] += w1001 * image_grad[index + width * height]; 194 | lut_grad[id0101 ] += w0101 * image_grad[index + width * height]; 195 | lut_grad[id1110 ] += w1110 * image_grad[index + width * height]; 196 | lut_grad[id0111 ] += w0111 * image_grad[index + width * height]; 197 | lut_grad[id1101 ] += w1101 * image_grad[index + width * height]; 198 | lut_grad[id1011 ] += w1011 * image_grad[index + width * height]; 199 | lut_grad[id1111 ] += w1111 * image_grad[index + width * height]; 200 | 201 | lut_grad[id0000 + shift] += w0000 * image_grad[index + width * height * 2]; 202 | lut_grad[id1000 + shift] += w1000 * image_grad[index + width * height * 2]; 203 | lut_grad[id0100 + shift] += w0100 * image_grad[index + width * height * 2]; 204 | lut_grad[id0010 + shift] += w0010 * image_grad[index + width * height * 2]; 205 | lut_grad[id0001 + shift] += w0001 * image_grad[index + width * height * 2]; 206 | lut_grad[id1100 + shift] += w1100 * image_grad[index + width * height * 2]; 207 | lut_grad[id0110 + shift] += w0110 * image_grad[index + width * height * 2]; 208 | lut_grad[id0011 + shift] += w0011 * image_grad[index + width * height * 2]; 209 | lut_grad[id1010 + shift] += w1010 * image_grad[index + width * height * 2]; 210 | lut_grad[id1001 + shift] += w1001 * image_grad[index + width * height * 2]; 211 | lut_grad[id0101 + shift] += w0101 * image_grad[index + width * height * 2]; 212 | lut_grad[id1110 + shift] += w1110 * image_grad[index + width * height * 2]; 213 | lut_grad[id0111 + shift] += w0111 * image_grad[index + width * height * 2]; 214 | lut_grad[id1101 + shift] += w1101 * image_grad[index + width * height * 2]; 215 | lut_grad[id1011 + shift] += w1011 * image_grad[index + width * height * 2]; 216 | lut_grad[id1111 + shift] += w1111 * image_grad[index + width * height * 2]; 217 | 218 | lut_grad[id0000 + shift* 3] += w0000 * image_grad[index + width * height * 3]; 219 
| lut_grad[id1000 + shift* 3] += w1000 * image_grad[index + width * height * 3]; 220 | lut_grad[id0100 + shift* 3] += w0100 * image_grad[index + width * height * 3]; 221 | lut_grad[id0010 + shift* 3] += w0010 * image_grad[index + width * height * 3]; 222 | lut_grad[id0001 + shift* 3] += w0001 * image_grad[index + width * height * 3]; 223 | lut_grad[id1100 + shift* 3] += w1100 * image_grad[index + width * height * 3]; 224 | lut_grad[id0110 + shift* 3] += w0110 * image_grad[index + width * height * 3]; 225 | lut_grad[id0011 + shift* 3] += w0011 * image_grad[index + width * height * 3]; 226 | lut_grad[id1010 + shift* 3] += w1010 * image_grad[index + width * height * 3]; 227 | lut_grad[id1001 + shift* 3] += w1001 * image_grad[index + width * height * 3]; 228 | lut_grad[id0101 + shift* 3] += w0101 * image_grad[index + width * height * 3]; 229 | lut_grad[id1110 + shift* 3] += w1110 * image_grad[index + width * height * 3]; 230 | lut_grad[id0111 + shift* 3] += w0111 * image_grad[index + width * height * 3]; 231 | lut_grad[id1101 + shift* 3] += w1101 * image_grad[index + width * height * 3]; 232 | lut_grad[id1011 + shift* 3] += w1011 * image_grad[index + width * height * 3]; 233 | lut_grad[id1111 + shift* 3] += w1111 * image_grad[index + width * height * 3]; 234 | 235 | 236 | } 237 | } 238 | 239 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 240 | m.def("forward", &quadrilinear_forward, "Quadrilinear forward"); 241 | m.def("backward", &quadrilinear_backward, "Quadrilinear backward"); 242 | } 243 | -------------------------------------------------------------------------------- /quadrilinear_cpp/src/quadrilinear4d.h: -------------------------------------------------------------------------------- 1 | #ifndef TRILINEAR_H 2 | #define TRILINEAR_H 3 | 4 | #include 5 | 6 | int quadrilinear_forward(torch::Tensor lut, torch::Tensor image, torch::Tensor output, 7 | int lut_dim, int shift, float binsize, int width, int height, int batch); 8 | 9 | int quadrilinear_backward(torch::Tensor image, torch::Tensor image_grad,torch::Tensor lut, torch::Tensor lut_grad, 10 | int lut_dim, int shift, float binsize, int width, int height, int batch); 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /quadrilinear_cpp/src/quadrilinear4d_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include "quadrilinear4d_kernel.h" 2 | #include 3 | #include 4 | 5 | int quadrilinear4d_forward_cuda(torch::Tensor lut, torch::Tensor image, torch::Tensor output, 6 | int lut_dim, int shift, float binsize, int width, int height, int batch) 7 | { 8 | // Grab the input tensor 9 | float * lut_flat = lut.data(); 10 | float * image_flat = image.data(); 11 | float * output_flat = output.data(); 12 | 13 | QuadriLinearForwardLaucher(lut_flat, image_flat, output_flat, lut_dim, shift, binsize, width, height, batch, at::cuda::getCurrentCUDAStream()); 14 | 15 | return 1; 16 | } 17 | 18 | int quadrilinear4d_backward_cuda(torch::Tensor image, torch::Tensor image_grad, torch::Tensor lut, torch::Tensor lut_grad, 19 | int lut_dim, int shift, float binsize, int width, int height, int batch) 20 | { 21 | // Grab the input tensor 22 | float * image_grad_flat = image_grad.data(); 23 | float * image_flat = image.data(); 24 | float * lut_flat = lut.data(); 25 | float * lut_grad_flat = lut_grad.data(); 26 | 27 | QuadriLinearBackwardLaucher(image_flat, image_grad_flat, lut_flat, lut_grad_flat, lut_dim, shift, binsize, width, height, batch, 
at::cuda::getCurrentCUDAStream()); 28 | 29 | return 1; 30 | } 31 | 32 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 33 | m.def("forward", &quadrilinear4d_forward_cuda, "Quadrilinear forward"); 34 | m.def("backward", &quadrilinear4d_backward_cuda, "Quadrilinear backward"); 35 | } 36 | 37 | -------------------------------------------------------------------------------- /quadrilinear_cpp/src/quadrilinear4d_cuda.h: -------------------------------------------------------------------------------- 1 | #ifndef TRILINEAR_CUDA_H 2 | #define TRILINEAR_CUDA_H 3 | 4 | #import 5 | 6 | int quadrilinear4d_forward_cuda(torch::Tensor lut, torch::Tensor image, torch::Tensor output, 7 | int lut_dim, int shift, float binsize, int width, int height, int batch); 8 | 9 | int quadrilinear4d_backward_cuda(torch::Tensor image, torch::Tensor image_grad, torch::Tensor lut, torch::Tensor lut_grad, 10 | int lut_dim, int shift, float binsize, int width, int height, int batch); 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /quadrilinear_cpp/src/quadrilinear4d_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "quadrilinear4d_kernel.h" 4 | 5 | #define CUDA_1D_KERNEL_LOOP(i, n) \ 6 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ 7 | i += blockDim.x * gridDim.x) 8 | 9 | 10 | __global__ void QuadriLinearForward(const int nthreads, const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int batch) { 11 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 12 | 13 | float context = image[index]; 14 | float r = image[index + width * height * batch]; 15 | float g = image[index + width * height * batch * 2]; 16 | float b = image[index + width * height * batch * 3]; 17 | 18 | int context_id = 0; 19 | int r_id = floor(r / binsize); 20 | int g_id = floor(g / binsize); 21 | int b_id = floor(b / binsize); 22 | 23 | float context_d = context; 24 | float r_d = fmod(r,binsize) / binsize; 25 | float g_d = fmod(g,binsize) / binsize; 26 | float b_d = fmod(b,binsize) / binsize; 27 | 28 | int id0000 = context_id * dim * dim * dim + r_id + g_id * dim + b_id * dim * dim; 29 | int id0100 = context_id * dim * dim * dim + (r_id + 1) + g_id * dim + b_id * dim * dim; 30 | int id0010 = context_id * dim * dim * dim + r_id + (g_id + 1) * dim + b_id * dim * dim; 31 | int id0001 = context_id * dim * dim * dim + r_id + g_id * dim + (b_id + 1) * dim * dim; 32 | int id0110 = context_id * dim * dim * dim + (r_id + 1) + (g_id + 1) * dim + b_id * dim * dim; 33 | int id0011 = context_id * dim * dim * dim + r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim; 34 | int id0101 = context_id * dim * dim * dim + (r_id + 1) + g_id * dim + (b_id + 1) * dim * dim; 35 | int id0111 = context_id * dim * dim * dim + (r_id + 1) + (g_id + 1) * dim + (b_id + 1) * dim * dim; 36 | 37 | int id1000 = (context_id + 1) * dim * dim * dim + r_id + g_id * dim + b_id * dim * dim; 38 | int id1100 = (context_id + 1) * dim * dim * dim + (r_id + 1) + g_id * dim + b_id * dim * dim; 39 | int id1010 = (context_id + 1) * dim * dim * dim + r_id + (g_id + 1) * dim + b_id * dim * dim; 40 | int id1001 = (context_id + 1) * dim * dim * dim + r_id + g_id * dim + (b_id + 1) * dim * dim; 41 | int id1110 = (context_id + 1) * dim * dim * dim + (r_id + 1) + (g_id + 1) * dim + b_id * dim * dim; 42 | int id1011 = (context_id + 1) * dim * dim * dim + r_id + (g_id + 1) * dim + (b_id + 1) 
* dim * dim; 43 | int id1101 = (context_id + 1) * dim * dim * dim + (r_id + 1) + g_id * dim + (b_id + 1) * dim * dim; 44 | int id1111 = (context_id + 1) * dim * dim * dim + (r_id + 1) + (g_id + 1) * dim + (b_id + 1) * dim * dim; 45 | 46 | 47 | float w0000 = (1-context_d)*(1-r_d)*(1-g_d)*(1-b_d); 48 | float w0100 = (1-context_d)*r_d*(1-g_d)*(1-b_d); 49 | float w0010 = (1-context_d)*(1-r_d)*g_d*(1-b_d); 50 | float w0001 = (1-context_d)*(1-r_d)*(1-g_d)*b_d; 51 | float w0110 = (1-context_d)*r_d*g_d*(1-b_d); 52 | float w0011 = (1-context_d)*(1-r_d)*g_d*b_d; 53 | float w0101 = (1-context_d)*r_d*(1-g_d)*b_d; 54 | float w0111 = (1-context_d)*r_d*g_d*b_d; 55 | 56 | float w1000 = context_d*(1-r_d)*(1-g_d)*(1-b_d); 57 | float w1100 = context_d*r_d*(1-g_d)*(1-b_d); 58 | float w1010 = context_d*(1-r_d)*g_d*(1-b_d); 59 | float w1001 = context_d*(1-r_d)*(1-g_d)*b_d; 60 | float w1110 = context_d*r_d*g_d*(1-b_d); 61 | float w1011 = context_d*(1-r_d)*g_d*b_d; 62 | float w1101 = context_d*r_d*(1-g_d)*b_d; 63 | float w1111 = context_d*r_d*g_d*b_d; 64 | 65 | 66 | 67 | output[index] = w0000 * lut[id0000] + w0100 * lut[id0100] + w0010 * lut[id0010] + 68 | w0001 * lut[id0001] + w0110 * lut[id0110] + w0011 * lut[id0011] + 69 | w0101 * lut[id0101] + w0111 * lut[id0111] + 70 | w1000 * lut[id1000] + w1100 * lut[id1100] + w1010 * lut[id1010] + 71 | w1001 * lut[id1001] + w1110 * lut[id1110] + w1011 * lut[id1011] + 72 | w1101 * lut[id1101] + w1111 * lut[id1111]; 73 | 74 | output[index + width * height * batch] = w0000 * lut[id0000 + shift] + w0100 * lut[id0100 + shift] + w0010 * lut[id0010 + shift] + 75 | w0001 * lut[id0001 + shift] + w0110 * lut[id0110 + shift] + w0011 * lut[id0011 + shift] + 76 | w0101 * lut[id0101 + shift] + w0111 * lut[id0111 + shift] + 77 | w1000 * lut[id1000 + shift] + w1100 * lut[id1100 + shift] + w1010 * lut[id1010 + shift] + 78 | w1001 * lut[id1001 + shift] + w1110 * lut[id1110 + shift] + w1011 * lut[id1011 + shift] + 79 | w1101 * lut[id1101 + shift] + w1111 * lut[id1111 + shift]; 80 | 81 | output[index + width * height * batch * 2] = w0000 * lut[id0000 + shift * 2] + w0100 * lut[id0100 + shift * 2] + w0010 * lut[id0010 + shift * 2] + 82 | w0001 * lut[id0001 + shift * 2] + w0110 * lut[id0110 + shift * 2] + w0011 * lut[id0011 + shift * 2] + 83 | w0101 * lut[id0101 + shift * 2] + w0111 * lut[id0111 + shift * 2] + 84 | w1000 * lut[id1000 + shift * 2] + w1100 * lut[id1100 + shift * 2] + w1010 * lut[id1010 + shift * 2] + 85 | w1001 * lut[id1001 + shift * 2] + w1110 * lut[id1110 + shift * 2] + w1011 * lut[id1011 + shift * 2] + 86 | w1101 * lut[id1101 + shift * 2] + w1111 * lut[id1111 + shift * 2]; 87 | 88 | } 89 | } 90 | 91 | 92 | int QuadriLinearForwardLaucher(const float* lut, const float* image, float* output, const int lut_dim, const int shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream) { 93 | const int kThreadsPerBlock = 1024; 94 | const int output_size = height * width * batch; 95 | cudaError_t err; 96 | 97 | 98 | QuadriLinearForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(output_size, lut, image, output, lut_dim, shift, binsize, width, height, batch); 99 | 100 | err = cudaGetLastError(); 101 | if(cudaSuccess != err) { 102 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 103 | exit( -1 ); 104 | } 105 | 106 | return 1; 107 | } 108 | 109 | 110 | __global__ void QuadriLinearBackward(const int nthreads, const float* image, float* image_grad,const float* lut, float* 
lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int batch) { 111 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 112 | 113 | float context = image[index]; 114 | float r = image[index + width * height * batch]; 115 | float g = image[index + width * height * batch * 2]; 116 | float b = image[index + width * height * batch * 3]; 117 | 118 | int context_id = 0; 119 | int r_id = floor(r / binsize); 120 | int g_id = floor(g / binsize); 121 | int b_id = floor(b / binsize); 122 | 123 | float context_d = context; 124 | float r_d = fmod(r,binsize) / binsize; 125 | float g_d = fmod(g,binsize) / binsize; 126 | float b_d = fmod(b,binsize) / binsize; 127 | 128 | 129 | int id0000 = context_id * dim * dim * dim + r_id + g_id * dim + b_id * dim * dim; 130 | int id0100 = context_id * dim * dim * dim + (r_id + 1) + g_id * dim + b_id * dim * dim; 131 | int id0010 = context_id * dim * dim * dim + r_id + (g_id + 1) * dim + b_id * dim * dim; 132 | int id0001 = context_id * dim * dim * dim + r_id + g_id * dim + (b_id + 1) * dim * dim; 133 | int id0110 = context_id * dim * dim * dim + (r_id + 1) + (g_id + 1) * dim + b_id * dim * dim; 134 | int id0011 = context_id * dim * dim * dim + r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim; 135 | int id0101 = context_id * dim * dim * dim + (r_id + 1) + g_id * dim + (b_id + 1) * dim * dim; 136 | int id0111 = context_id * dim * dim * dim + (r_id + 1) + (g_id + 1) * dim + (b_id + 1) * dim * dim; 137 | 138 | int id1000 = (context_id + 1) * dim * dim * dim + r_id + g_id * dim + b_id * dim * dim; 139 | int id1100 = (context_id + 1) * dim * dim * dim + (r_id + 1) + g_id * dim + b_id * dim * dim; 140 | int id1010 = (context_id + 1) * dim * dim * dim + r_id + (g_id + 1) * dim + b_id * dim * dim; 141 | int id1001 = (context_id + 1) * dim * dim * dim + r_id + g_id * dim + (b_id + 1) * dim * dim; 142 | int id1110 = (context_id + 1) * dim * dim * dim + (r_id + 1) + (g_id + 1) * dim + b_id * dim * dim; 143 | int id1011 = (context_id + 1) * dim * dim * dim + r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim; 144 | int id1101 = (context_id + 1) * dim * dim * dim + (r_id + 1) + g_id * dim + (b_id + 1) * dim * dim; 145 | int id1111 = (context_id + 1) * dim * dim * dim + (r_id + 1) + (g_id + 1) * dim + (b_id + 1) * dim * dim; 146 | 147 | 148 | float w0000 = (1-context_d)*(1-r_d)*(1-g_d)*(1-b_d); 149 | float w0100 = (1-context_d)*r_d*(1-g_d)*(1-b_d); 150 | float w0010 = (1-context_d)*(1-r_d)*g_d*(1-b_d); 151 | float w0001 = (1-context_d)*(1-r_d)*(1-g_d)*b_d; 152 | float w0110 = (1-context_d)*r_d*g_d*(1-b_d); 153 | float w0011 = (1-context_d)*(1-r_d)*g_d*b_d; 154 | float w0101 = (1-context_d)*r_d*(1-g_d)*b_d; 155 | float w0111 = (1-context_d)*r_d*g_d*b_d; 156 | 157 | float w1000 = context_d*(1-r_d)*(1-g_d)*(1-b_d); 158 | float w1100 = context_d*r_d*(1-g_d)*(1-b_d); 159 | float w1010 = context_d*(1-r_d)*g_d*(1-b_d); 160 | float w1001 = context_d*(1-r_d)*(1-g_d)*b_d; 161 | float w1110 = context_d*r_d*g_d*(1-b_d); 162 | float w1011 = context_d*(1-r_d)*g_d*b_d; 163 | float w1101 = context_d*r_d*(1-g_d)*b_d; 164 | float w1111 = context_d*r_d*g_d*b_d; 165 | 166 | 167 | 168 | atomicAdd(lut_grad + id0000, image_grad[index + width * height * batch] * w0000); 169 | atomicAdd(lut_grad + id0100, image_grad[index + width * height * batch] * w0100); 170 | atomicAdd(lut_grad + id0010, image_grad[index + width * height * batch] * w0010); 171 | atomicAdd(lut_grad + id0001, image_grad[index + width * height * batch] * w0001); 172 | atomicAdd(lut_grad + id0110, 
image_grad[index + width * height * batch] * w0110); 173 | atomicAdd(lut_grad + id0011, image_grad[index + width * height * batch] * w0011); 174 | atomicAdd(lut_grad + id0101, image_grad[index + width * height * batch] * w0101); 175 | atomicAdd(lut_grad + id0111, image_grad[index + width * height * batch] * w0111); 176 | 177 | atomicAdd(lut_grad + id1000, image_grad[index + width * height * batch] * w1000); 178 | atomicAdd(lut_grad + id1100, image_grad[index + width * height * batch] * w1100); 179 | atomicAdd(lut_grad + id1010, image_grad[index + width * height * batch] * w1010); 180 | atomicAdd(lut_grad + id1001, image_grad[index + width * height * batch] * w1001); 181 | atomicAdd(lut_grad + id1110, image_grad[index + width * height * batch] * w1110); 182 | atomicAdd(lut_grad + id1011, image_grad[index + width * height * batch] * w1011); 183 | atomicAdd(lut_grad + id1101, image_grad[index + width * height * batch] * w1101); 184 | atomicAdd(lut_grad + id1111, image_grad[index + width * height * batch] * w1111); 185 | 186 | atomicAdd(lut_grad + id0000 + shift, image_grad[index + width * height * batch * 2] * w0000); 187 | atomicAdd(lut_grad + id0100 + shift, image_grad[index + width * height * batch * 2] * w0100); 188 | atomicAdd(lut_grad + id0010 + shift, image_grad[index + width * height * batch * 2] * w0010); 189 | atomicAdd(lut_grad + id0001 + shift, image_grad[index + width * height * batch * 2] * w0001); 190 | atomicAdd(lut_grad + id0110 + shift, image_grad[index + width * height * batch * 2] * w0110); 191 | atomicAdd(lut_grad + id0011 + shift, image_grad[index + width * height * batch * 2] * w0011); 192 | atomicAdd(lut_grad + id0101 + shift, image_grad[index + width * height * batch * 2] * w0101); 193 | atomicAdd(lut_grad + id0111 + shift, image_grad[index + width * height * batch * 2] * w0111); 194 | 195 | atomicAdd(lut_grad + id1000 + shift, image_grad[index + width * height * batch * 2] * w1000); 196 | atomicAdd(lut_grad + id1100 + shift, image_grad[index + width * height * batch * 2] * w1100); 197 | atomicAdd(lut_grad + id1010 + shift, image_grad[index + width * height * batch * 2] * w1010); 198 | atomicAdd(lut_grad + id1001 + shift, image_grad[index + width * height * batch * 2] * w1001); 199 | atomicAdd(lut_grad + id1110 + shift, image_grad[index + width * height * batch * 2] * w1110); 200 | atomicAdd(lut_grad + id1011 + shift, image_grad[index + width * height * batch * 2] * w1011); 201 | atomicAdd(lut_grad + id1101 + shift, image_grad[index + width * height * batch * 2] * w1101); 202 | atomicAdd(lut_grad + id1111 + shift, image_grad[index + width * height * batch * 2] * w1111); 203 | 204 | atomicAdd(lut_grad + id0000 + shift * 2, image_grad[index + width * height * batch * 3] * w0000); 205 | atomicAdd(lut_grad + id0100 + shift * 2, image_grad[index + width * height * batch * 3] * w0100); 206 | atomicAdd(lut_grad + id0010 + shift * 2, image_grad[index + width * height * batch * 3] * w0010); 207 | atomicAdd(lut_grad + id0001 + shift * 2, image_grad[index + width * height * batch * 3] * w0001); 208 | atomicAdd(lut_grad + id0110 + shift * 2, image_grad[index + width * height * batch * 3] * w0110); 209 | atomicAdd(lut_grad + id0011 + shift * 2, image_grad[index + width * height * batch * 3] * w0011); 210 | atomicAdd(lut_grad + id0101 + shift * 2, image_grad[index + width * height * batch * 3] * w0101); 211 | atomicAdd(lut_grad + id0111 + shift * 2, image_grad[index + width * height * batch * 3] * w0111); 212 | 213 | atomicAdd(lut_grad + id1000 + shift * 2, image_grad[index + width 
* height * batch * 3] * w1000); 214 | atomicAdd(lut_grad + id1100 + shift * 2, image_grad[index + width * height * batch * 3] * w1100); 215 | atomicAdd(lut_grad + id1010 + shift * 2, image_grad[index + width * height * batch * 3] * w1010); 216 | atomicAdd(lut_grad + id1001 + shift * 2, image_grad[index + width * height * batch * 3] * w1001); 217 | atomicAdd(lut_grad + id1110 + shift * 2, image_grad[index + width * height * batch * 3] * w1110); 218 | atomicAdd(lut_grad + id1011 + shift * 2, image_grad[index + width * height * batch * 3] * w1011); 219 | atomicAdd(lut_grad + id1101 + shift * 2, image_grad[index + width * height * batch * 3] * w1101); 220 | atomicAdd(lut_grad + id1111 + shift * 2, image_grad[index + width * height * batch * 3] * w1111); 221 | 222 | // atomicAdd(image_grad + index, (image_grad[index + width * height * batch] + image_grad[index + width * height * batch * 2] + image_grad[index + width * height * batch * 3]) / 3); 223 | 224 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * 225 | // ( (-1)*(1-r_d)*(1-g_d)*(1-b_d) * lut[id0000] + 226 | // 1*(1-r_d)*(1-g_d)*(1-b_d) * lut[id1000] + 227 | // (-1)*r_d*(1-g_d)*(1-b_d) * lut[id0100] + 228 | // (-1)*(1-r_d)*g_d*(1-b_d) * lut[id0010] + 229 | // (-1)*(1-r_d)*(1-g_d)*b_d * lut[id0001] + 230 | // 1*r_d*(1-g_d)*(1-b_d) * lut[id1100] + 231 | // (-1)*r_d*g_d*(1-b_d) * lut[id0110] + 232 | // (-1)*(1-r_d)*g_d*b_d * lut[id0011] + 233 | // 1*(1-r_d)*g_d*(1-b_d) * lut[id1010] + 234 | // 1*(1-r_d)*(1-g_d)*b_d * lut[id1001] + 235 | // (-1)*r_d*(1-g_d)*b_d * lut[id0101] + 236 | // 1*r_d*g_d*(1-b_d) * lut[id1110] + 237 | // (-1)*r_d*g_d*b_d * lut[id0111] + 238 | // 1*r_d*(1-g_d)*b_d * lut[id1101] + 239 | // 1*(1-r_d)*g_d*b_d * lut[id1011] + 240 | // 1*r_d*g_d*b_d * lut[id1111] 241 | // ) 242 | // ); 243 | 244 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * ((-1)*(1-r_d)*(1-g_d)*(1-b_d)*lut[id0000])); 245 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * (1*(1-r_d)*(1-g_d)*(1-b_d) * lut[id1000])); 246 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * ((-1)*r_d*(1-g_d)*(1-b_d) * lut[id0100])); 247 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * ((-1)*(1-r_d)*g_d*(1-b_d) * lut[id0010])); 248 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * ((-1)*(1-r_d)*(1-g_d)*b_d * lut[id0001])); 249 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * (1*r_d*(1-g_d)*(1-b_d) * lut[id1100])); 250 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * ((-1)*r_d*g_d*(1-b_d) * lut[id0110])); 251 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * ((-1)*(1-r_d)*g_d*b_d * lut[id0011])); 252 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * (1*(1-r_d)*g_d*(1-b_d) * lut[id1010])); 253 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * (1*(1-r_d)*(1-g_d)*b_d * lut[id1001])); 254 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * ((-1)*r_d*(1-g_d)*b_d * lut[id0101])); 255 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * (1*r_d*g_d*(1-b_d) * lut[id1110])); 256 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * ((-1)*r_d*g_d*b_d * lut[id0111])); 257 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * (1*r_d*(1-g_d)*b_d * lut[id1101])); 
258 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * (1*(1-r_d)*g_d*b_d * lut[id1011])); 259 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * (1*r_d*g_d*b_d * lut[id1111])); 260 | 261 | 262 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch * 2] * 263 | // ( (-1)*(1-r_d)*(1-g_d)*(1-b_d) * lut[id0000] + 264 | // 1*(1-r_d)*(1-g_d)*(1-b_d) * lut[id1000] + 265 | // (-1)*r_d*(1-g_d)*(1-b_d) * lut[id0100] + 266 | // (-1)*(1-r_d)*g_d*(1-b_d) * lut[id0010] + 267 | // (-1)*(1-r_d)*(1-g_d)*b_d * lut[id0001] + 268 | // 1*r_d*(1-g_d)*(1-b_d) * lut[id1100] + 269 | // (-1)*r_d*g_d*(1-b_d) * lut[id0110] + 270 | // (-1)*(1-r_d)*g_d*b_d * lut[id0011] + 271 | // 1*(1-r_d)*g_d*(1-b_d) * lut[id1010] + 272 | // 1*(1-r_d)*(1-g_d)*b_d * lut[id1001] + 273 | // (-1)*r_d*(1-g_d)*b_d * lut[id0101] + 274 | // 1*r_d*g_d*(1-b_d) * lut[id1110] + 275 | // (-1)*r_d*g_d*b_d * lut[id0111] + 276 | // 1*r_d*(1-g_d)*b_d * lut[id1101] + 277 | // 1*(1-r_d)*g_d*b_d * lut[id1011] + 278 | // 1*r_d*g_d*b_d * lut[id1111] 279 | // ) 280 | // ); 281 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch * 3] * 282 | // ( (-1)*(1-r_d)*(1-g_d)*(1-b_d) * lut[id0000] + 283 | // 1*(1-r_d)*(1-g_d)*(1-b_d) * lut[id1000] + 284 | // (-1)*r_d*(1-g_d)*(1-b_d) * lut[id0100] + 285 | // (-1)*(1-r_d)*g_d*(1-b_d) * lut[id0010] + 286 | // (-1)*(1-r_d)*(1-g_d)*b_d * lut[id0001] + 287 | // 1*r_d*(1-g_d)*(1-b_d) * lut[id1100] + 288 | // (-1)*r_d*g_d*(1-b_d) * lut[id0110] + 289 | // (-1)*(1-r_d)*g_d*b_d * lut[id0011] + 290 | // 1*(1-r_d)*g_d*(1-b_d) * lut[id1010] + 291 | // 1*(1-r_d)*(1-g_d)*b_d * lut[id1001] + 292 | // (-1)*r_d*(1-g_d)*b_d * lut[id0101] + 293 | // 1*r_d*g_d*(1-b_d) * lut[id1110] + 294 | // (-1)*r_d*g_d*b_d * lut[id0111] + 295 | // 1*r_d*(1-g_d)*b_d * lut[id1101] + 296 | // 1*(1-r_d)*g_d*b_d * lut[id1011] + 297 | // 1*r_d*g_d*b_d * lut[id1111] 298 | // ) 299 | // ); 300 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch * 2]); 301 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch * 3]); 302 | 303 | 304 | 305 | // float i000 = lut[id1000]-lut[id0000]; 306 | // float i100 = lut[id1100]-lut[id0100]; 307 | // float i010 = lut[id1010]-lut[id0010]; 308 | // float i001 = lut[id1001]-lut[id0001]; 309 | // float i110 = lut[id1110]-lut[id0110]; 310 | // float i011 = lut[id1011]-lut[id0011]; 311 | // float i101 = lut[id1101]-lut[id0101]; 312 | // float i111 = lut[id1111]-lut[id0111]; 313 | 314 | float w000 = (1-r_d)*(1-g_d)*(1-b_d); 315 | float w100 = r_d*(1-g_d)*(1-b_d); 316 | float w010 = (1-r_d)*g_d*(1-b_d); 317 | float w001 = (1-r_d)*(1-g_d)*b_d; 318 | float w110 = r_d*g_d*(1-b_d); 319 | float w011 = (1-r_d)*g_d*b_d; 320 | float w101 = r_d*(1-g_d)*b_d; 321 | float w111 = r_d*g_d*b_d; 322 | 323 | atomicAdd(image_grad + index, image_grad[index + width * height * batch] * w000 * binsize); 324 | atomicAdd(image_grad + index, image_grad[index + width * height * batch] * w100 * binsize); 325 | atomicAdd(image_grad + index, image_grad[index + width * height * batch] * w010 * binsize); 326 | atomicAdd(image_grad + index, image_grad[index + width * height * batch] * w001 * binsize); 327 | atomicAdd(image_grad + index, image_grad[index + width * height * batch] * w110 * binsize); 328 | atomicAdd(image_grad + index, image_grad[index + width * height * batch] * w011 * binsize); 329 | atomicAdd(image_grad + index, image_grad[index + width * height * batch] * w101 * binsize); 330 | 
atomicAdd(image_grad + index, image_grad[index + width * height * batch] * w111 * binsize); 331 | 332 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 2] * w000 * binsize); 333 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 2] * w100 * binsize); 334 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 2] * w010 * binsize); 335 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 2] * w001 * binsize); 336 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 2] * w110 * binsize); 337 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 2] * w011 * binsize); 338 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 2] * w101 * binsize); 339 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 2] * w111 * binsize); 340 | 341 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 3] * w000 * binsize); 342 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 3] * w100 * binsize); 343 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 3] * w010 * binsize); 344 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 3] * w001 * binsize); 345 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 3] * w110 * binsize); 346 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 3] * w011 * binsize); 347 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 3] * w101 * binsize); 348 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 3] * w111 * binsize); 349 | } 350 | } 351 | 352 | int QuadriLinearBackwardLaucher(const float* image, float* image_grad, const float* lut, float* lut_grad, const int lut_dim, const int shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream) { 353 | const int kThreadsPerBlock = 1024; 354 | const int output_size = height * width * batch; 355 | cudaError_t err; 356 | 357 | QuadriLinearBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(output_size, image, image_grad, lut, lut_grad, lut_dim, shift, binsize, width, height, batch); 358 | 359 | err = cudaGetLastError(); 360 | if(cudaSuccess != err) { 361 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 362 | exit( -1 ); 363 | } 364 | 365 | return 1; 366 | } 367 | -------------------------------------------------------------------------------- /quadrilinear_cpp/src/quadrilinear4d_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _TRILINEAR_KERNEL 2 | #define _TRILINEAR_KERNEL 3 | 4 | #include 5 | 6 | __global__ void QuadriLinearForward(const int nthreads, const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int batch); 7 | 8 | int QuadriLinearForwardLaucher(const float* lut, const float* image, float* output, const int lut_dim, const int shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream); 9 | 10 | __global__ void QuadriLinearBackward(const int nthreads, const float* image, float* image_grad, const float* lut, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int batch); 11 | 12 | int 
QuadriLinearBackwardLaucher(const float* image, float* image_grad, const float* lut, float* lut_grad, const int lut_dim, const int shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream);
13 |
14 |
15 | #endif
16 |
17 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import math
5 | import itertools
6 | import time
7 | import datetime
8 | import sys
9 | from skimage.metrics import structural_similarity as ssim
10 | import torchvision.transforms as transforms
11 | from torchvision.utils import save_image
12 |
13 | from torch.utils.data import DataLoader
14 | from torchvision import datasets
15 | from torch.autograd import Variable
16 |
17 | from models_x import *
18 | from datasets import *
19 |
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | import torch
23 |
24 | # CUDA_VISIBLE_DEVICES
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from, 0 starts from scratch, >0 starts from saved checkpoints")
27 | parser.add_argument("--n_epochs", type=int, default=1000, help="total number of epochs of training")
28 | parser.add_argument("--dataset_name", type=str, default="fiveK", help="name of the dataset")
29 | parser.add_argument("--input_color_space", type=str, default="sRGB", help="input color space: sRGB or XYZ")
30 | parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
31 | parser.add_argument("--lr", type=float, default=0.0001, help="adam: learning rate")
32 | parser.add_argument("--b1", type=float, default=0.9, help="adam: decay of first order momentum of gradient")
33 | parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
34 | parser.add_argument("--lambda_smooth", type=float, default=0.0001, help="smooth regularization")
35 | parser.add_argument("--lambda_monotonicity", type=float, default=10.0, help="monotonicity regularization")
36 | parser.add_argument("--n_cpu", type=int, default=1, help="number of cpu threads to use during batch generation")
37 | parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between model checkpoints")
38 | parser.add_argument("--output_dir", type=str, default="LUTs/fiveK", help="path to save model")
39 | opt = parser.parse_args()
40 |
41 | opt.output_dir = opt.output_dir + '_' + opt.input_color_space
42 | print(opt)
43 |
44 | os.makedirs("saved_models/%s" % opt.output_dir, exist_ok=True)
45 |
46 | cuda = True if torch.cuda.is_available() else False
47 | # Tensor type
48 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
49 |
50 | # Loss functions
51 | criterion_pixelwise = torch.nn.MSELoss()
52 |
53 | # Initialize the generators and the learnable 4D LUT
54 | LUT_enhancement = Generator4DLUT_identity()
55 | Generator_bias = Generator_for_bias()
56 | Generator_context = Generator_for_info()
57 | TV4 = TV_4D()
58 | quadrilinear_enhancement_ = QuadrilinearInterpolation_4D()
59 |
60 | if cuda:
61 |     LUT_enhancement = LUT_enhancement.cuda()
62 |     Generator_bias = Generator_bias.cuda()
63 |     Generator_context = Generator_context.cuda()
64 |     criterion_pixelwise.cuda()
65 |     TV4.cuda()
66 |     TV4.weight_r = TV4.weight_r.type(Tensor)
67 |     TV4.weight_g = TV4.weight_g.type(Tensor)
68 |     TV4.weight_b = TV4.weight_b.type(Tensor)
69 |
70
70 | if opt.epoch != 0:
71 |     # Resume from saved checkpoints
72 |     LUT_enhancement.load_state_dict(torch.load("saved_models/%s/4DLUTs_enhancement_%d.pth" % (opt.output_dir, opt.epoch)))
73 |     Generator_bias.load_state_dict(torch.load("saved_models/%s/generator_bias_%d.pth" % (opt.output_dir, opt.epoch)))
74 |     Generator_context.load_state_dict(torch.load("saved_models/%s/generator_context_%d.pth" % (opt.output_dir, opt.epoch)))
75 | else:
76 |     # Initialize weights
77 |     Generator_bias.apply(weights_init_normal_generator)
78 |     torch.nn.init.constant_(Generator_bias.model[16].bias.data, 1.0)
79 |
80 | # One optimizer over all learnable components
81 | optimizer_G = torch.optim.Adam(itertools.chain(Generator_bias.parameters(), Generator_context.parameters(), LUT_enhancement.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2))
82 |
83 | if opt.input_color_space == 'sRGB':
84 |     dataloader = DataLoader(
85 |         ImageDataset_sRGB("/fivek_dataset/MIT-Adobe5k-UPE/", mode="train"),
86 |         batch_size=opt.batch_size,
87 |         shuffle=True,
88 |         num_workers=opt.n_cpu,
89 |     )
90 |
91 |     psnr_dataloader = DataLoader(
92 |         ImageDataset_sRGB("/fivek_dataset/MIT-Adobe5k-UPE/", mode="test"),
93 |         batch_size=1,
94 |         shuffle=False,
95 |         num_workers=opt.n_cpu,
96 |     )
97 |
98 |
99 | def generator_train(img):
100 |
101 |     context = Generator_context(img)
102 |     pred = Generator_bias(img)
103 |
104 |     # the single-channel context map serves as the 4th LUT coordinate
105 |
106 |     context = context.type(Tensor)
107 |
108 |     pred = pred.squeeze(2).squeeze(2)
109 |     combine = torch.cat([context, img], 1)
110 |
111 |     gen_A0 = LUT_enhancement(combine)
112 |
113 |     weights_norm = torch.mean(pred ** 2)
114 |
115 |     combine_A = img.new(img.size())
116 |     for b in range(img.size(0)):
117 |         combine_A[b,0,:,:] = pred[b,0] * gen_A0[b,0,:,:] + pred[b,1] * gen_A0[b,1,:,:] + pred[b,2] * gen_A0[b,2,:,:] + pred[b,9]
118 |         combine_A[b,1,:,:] = pred[b,3] * gen_A0[b,0,:,:] + pred[b,4] * gen_A0[b,1,:,:] + pred[b,5] * gen_A0[b,2,:,:] + pred[b,10]
119 |         combine_A[b,2,:,:] = pred[b,6] * gen_A0[b,0,:,:] + pred[b,7] * gen_A0[b,1,:,:] + pred[b,8] * gen_A0[b,2,:,:] + pred[b,11]
120 |
121 |     return combine_A, weights_norm
122 |
123 | def generator_eval(img):
124 |     context = Generator_context(img)
125 |     pred = Generator_bias(img)
126 |
127 |     # keep the learned context map as the 4th LUT coordinate
128 |     context = context.type(Tensor)
129 |
130 |     pred = pred.squeeze(2).squeeze(2).squeeze(0)
131 |
132 |     combine = torch.cat([context, img], 1)
133 |
134 |     new_LUT_enhancement = LUT_enhancement.LUT_en.new(LUT_enhancement.LUT_en.size())
135 |     new_LUT_enhancement[0] = pred[0] * LUT_enhancement.LUT_en[0] + pred[1] * LUT_enhancement.LUT_en[1] + pred[2] * LUT_enhancement.LUT_en[2] + pred[9]
136 |     new_LUT_enhancement[1] = pred[3] * LUT_enhancement.LUT_en[0] + pred[4] * LUT_enhancement.LUT_en[1] + pred[5] * LUT_enhancement.LUT_en[2] + pred[10]
137 |     new_LUT_enhancement[2] = pred[6] * LUT_enhancement.LUT_en[0] + pred[7] * LUT_enhancement.LUT_en[1] + pred[8] * LUT_enhancement.LUT_en[2] + pred[11]
138 |
139 |     weights_norm = torch.mean(pred ** 2)
140 |     # quadrilinear interpolation through the fused LUT yields the enhanced image
141 |     _, combine_A = quadrilinear_enhancement_(new_LUT_enhancement, combine)
142 |
143 |     return combine_A, weights_norm
144 |
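The per-sample loop in `generator_train` and the channel-wise assignments in `generator_eval` apply the same 3x4 affine color transform, whose twelve coefficients come from `Generator_bias`: nine mixing weights (`pred[0..8]`) and three biases (`pred[9..11]`). A minimal vectorized sketch of that operation follows; the helper name `affine_recombine` is ours, not part of the repository:
```
import torch

def affine_recombine(gen_A0: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    """Per-image 3x4 affine color transform.

    gen_A0: (B, 3, H, W) output of the 4D LUT
    pred:   (B, 12) coefficients from Generator_bias
    Returns (B, 3, H, W), equal to the explicit per-channel sums above.
    """
    M = pred[:, :9].view(-1, 3, 3)      # 3x3 mixing matrix per image
    t = pred[:, 9:].view(-1, 3, 1, 1)   # per-channel bias
    return torch.einsum('bij,bjhw->bihw', M, gen_A0) + t
```
`generator_eval` instead folds the same coefficients into the LUT before interpolating; the two orders are equivalent because quadrilinear interpolation is linear in the LUT entries and its weights sum to one, so the bias passes through unchanged.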
145 | def calculate_psnr():
146 |     Generator_bias.eval()
147 |     Generator_context.eval()
148 |     avg_psnr = 0
149 |     for i, batch in enumerate(psnr_dataloader):
150 |         real_A = Variable(batch["A_input"].type(Tensor))
151 |         real_B = Variable(batch["A_exptC"].type(Tensor))
152 |         fake_B, _ = generator_eval(real_A)
153 |         fake_B = torch.clamp(fake_B, 0.0, 1.0)
154 |         fake_B = torch.round(fake_B * 255)
155 |         real_B = torch.round(real_B * 255)
156 |         mse = criterion_pixelwise(fake_B, real_B)
157 |         psnr = 10 * math.log10(255.0 * 255.0 / mse.cpu().detach().item())
158 |         avg_psnr += psnr
159 |     return avg_psnr / len(psnr_dataloader)
160 |
161 | def calculate_ssim():
162 |     Generator_bias.eval()
163 |     Generator_context.eval()
164 |     avg_ssim = 0
165 |     for i, batch in enumerate(psnr_dataloader):
166 |         real_A = Variable(batch["A_input"].type(Tensor))
167 |         real_B = Variable(batch["A_exptC"].type(Tensor))
168 |         fake_B, _ = generator_eval(real_A)
169 |         fake_B = torch.clamp(fake_B, 0.0, 1.0)
170 |         fake_B = fake_B.squeeze(0).cpu().detach().numpy()
171 |         real_B = real_B.squeeze(0).cpu().detach().numpy()
172 |         fake_B = np.swapaxes(np.swapaxes(fake_B, 0, 2), 0, 1)  # CHW -> HWC
173 |         real_B = np.swapaxes(np.swapaxes(real_B, 0, 2), 0, 1)  # CHW -> HWC
174 |         fake_B = fake_B.astype(np.float32)
175 |         real_B = real_B.astype(np.float32)
176 |         ssim_val = ssim(real_B, fake_B, data_range=real_B.max() - real_B.min(), multichannel=True, gaussian_weights=True, win_size=11)
177 |         avg_ssim += ssim_val
178 |     return avg_ssim / len(psnr_dataloader)
179 |
180 |
181 |
182 | def visualize_result(epoch):
183 |     """Save generated samples from the validation set"""
184 |     Generator_bias.eval()
185 |     Generator_context.eval()
186 |     os.makedirs("images/%s/%s" % (opt.output_dir, epoch), exist_ok=True)
187 |     for i, batch in enumerate(psnr_dataloader):
188 |         real_A = Variable(batch["A_input"].type(Tensor))
189 |         real_B = Variable(batch["A_exptC"].type(Tensor))
190 |         img_name = batch["input_name"]
191 |         fake_B, _ = generator_eval(real_A)
192 |         fake_B = torch.clamp(fake_B, 0.0, 1.0)
193 |         img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -1)
194 |         fake_B = torch.round(fake_B * 255)
195 |         real_B = torch.round(real_B * 255)
196 |         mse = criterion_pixelwise(fake_B, real_B)
197 |         psnr = 10 * math.log10(255.0 * 255.0 / mse.item())
198 |         save_image(img_sample, "images/%s/%s/%s.jpg" % (opt.output_dir, epoch, img_name[0] + '_' + str(psnr)[:5]), nrow=3, normalize=False)
199 |
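`calculate_psnr` above and the running average in the training loop below report PSNR at different scales: the former quantizes to 8-bit values, so the peak signal is 255, while the in-loop average keeps images in [0, 1], where the peak is 1. Both reduce to 10 * log10(peak^2 / MSE) and agree for the same image pair, as this standalone check (not repository code) illustrates:
```
import math

def psnr_from_mse(mse: float, peak: float = 255.0) -> float:
    """PSNR in dB for a given mean squared error and peak signal value."""
    return 10 * math.log10(peak * peak / mse)

# MSE on [0, 255] data is 255^2 times the MSE on [0, 1] data (up to rounding),
# so both conventions report the same PSNR, here 30 dB:
assert math.isclose(psnr_from_mse(65.025, peak=255.0), psnr_from_mse(0.001, peak=1.0))
```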
200 | # ----------
201 | #  Training
202 | # ----------
203 | prev_time = time.time()
204 | max_psnr = 0
205 | max_epoch = 0
206 | for epoch in range(opt.epoch, opt.n_epochs):
207 |     mse_avg = 0
208 |     psnr_avg = 0
209 |     Generator_bias.train()
210 |     Generator_context.train()
211 |     for i, batch in enumerate(dataloader):
212 |         # Model inputs
213 |         real_A = Variable(batch["A_input"].type(Tensor))
214 |         real_B = Variable(batch["A_exptC"].type(Tensor))
215 |         # ------------------
216 |         #  Train Generators
217 |         # ------------------
218 |
219 |         optimizer_G.zero_grad()
220 |
221 |         fake_B, weights_norm = generator_train(real_A)
222 |
223 |         # Pixel-wise loss
224 |         mse = criterion_pixelwise(fake_B, real_B)
225 |
226 |         tv_enhancement, mn_enhancement = TV4(LUT_enhancement)
227 |
228 |         tv_cons = tv_enhancement
229 |         mn_cons = mn_enhancement
230 |
231 |         # MSE plus smoothness and monotonicity regularization
232 |         loss = mse + opt.lambda_smooth * (weights_norm + tv_cons) + opt.lambda_monotonicity * mn_cons
233 |         psnr_avg += 10 * math.log10(1 / mse.item())
234 |
235 |         mse_avg += mse.item()
236 |
237 |         loss.backward()
238 |
239 |         optimizer_G.step()
240 |
241 |
242 |         # --------------
243 |         #  Log Progress
244 |         # --------------
245 |
246 |         # Determine approximate time left
247 |         batches_done = epoch * len(dataloader) + i
248 |         batches_left = opt.n_epochs * len(dataloader) - batches_done
249 |         time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
250 |         prev_time = time.time()
251 |
252 |         # Print log
253 |         if i % 500 == 0:
254 |             print(
255 |                 "\r[Epoch %d/%d] [Batch %d/%d] [psnr: %f, tv: %f, mn: %f] ETA: %s"
256 |                 % (epoch, opt.n_epochs, i, len(dataloader), psnr_avg / (i + 1), tv_cons, mn_cons, time_left)
257 |             )
258 |
259 |     avg_ssim = calculate_ssim()
260 |     avg_psnr = calculate_psnr()
261 |     if avg_psnr > max_psnr:
262 |         max_psnr = avg_psnr
263 |         max_epoch = epoch
264 |
265 |     if epoch % opt.checkpoint_interval == 0:
266 |         torch.save(LUT_enhancement.state_dict(), "saved_models/%s/4DLUTs_enhancement_%d.pth" % (opt.output_dir, epoch))
267 |         torch.save(Generator_bias.state_dict(), "saved_models/%s/generator_bias_%d.pth" % (opt.output_dir, epoch))
268 |         torch.save(Generator_context.state_dict(), "saved_models/%s/generator_context_%d.pth" % (opt.output_dir, epoch))
269 |     file = open('saved_models/%s/result.txt' % opt.output_dir, 'a')
270 |     file.write(" [epoch: %d] [PSNR: %f, SSIM: %f] [best PSNR: %f at epoch %d]\n" % (epoch, avg_psnr, avg_ssim, max_psnr, max_epoch))
271 |     file.close()
272 |
273 |     print(" [PSNR: %f, SSIM: %f] [max PSNR: %f, epoch: %d]" % (avg_psnr, avg_ssim, max_psnr, max_epoch))
274 |
275 |     if (epoch + 1) % 100 == 0:
276 |         visualize_result(epoch + 1)
277 |
--------------------------------------------------------------------------------
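The checkpoints saved each epoch can also be reloaded outside this script. The sketch below mirrors `generator_eval` for single-image inference; the checkpoint directory `saved_models/LUTs/fiveK_sRGB`, the epoch number 100, and the `enhance` helper itself are assumptions for illustration, not part of the repository:
```
import torch
from models_x import Generator4DLUT_identity, Generator_for_bias, Generator_for_info, QuadrilinearInterpolation_4D

# Assumed checkpoint location and epoch; adjust to your own run
ckpt = "saved_models/LUTs/fiveK_sRGB/%s_100.pth"
LUT_enhancement = Generator4DLUT_identity().cuda()
Generator_bias = Generator_for_bias().cuda().eval()
Generator_context = Generator_for_info().cuda().eval()
LUT_enhancement.load_state_dict(torch.load(ckpt % "4DLUTs_enhancement"))
Generator_bias.load_state_dict(torch.load(ckpt % "generator_bias"))
Generator_context.load_state_dict(torch.load(ckpt % "generator_context"))
quadrilinear = QuadrilinearInterpolation_4D()

@torch.no_grad()
def enhance(img):
    """img: (1, 3, H, W) float tensor in [0, 1] on the GPU; mirrors generator_eval."""
    context = Generator_context(img)                       # single-channel context map
    pred = Generator_bias(img).squeeze(2).squeeze(2).squeeze(0)
    combine = torch.cat([context, img], 1)                 # 4D input: (context, R, G, B)
    LUT = LUT_enhancement.LUT_en
    new_LUT = LUT.new(LUT.size())                          # fuse coefficients into the LUT
    new_LUT[0] = pred[0] * LUT[0] + pred[1] * LUT[1] + pred[2] * LUT[2] + pred[9]
    new_LUT[1] = pred[3] * LUT[0] + pred[4] * LUT[1] + pred[5] * LUT[2] + pred[10]
    new_LUT[2] = pred[6] * LUT[0] + pred[7] * LUT[1] + pred[8] * LUT[2] + pred[11]
    _, out = quadrilinear(new_LUT, combine)
    return torch.clamp(out, 0.0, 1.0)
```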