├── 4D LUT_overview.jpg
├── Identity4DLUT17.txt
├── Identity4DLUT33.txt
├── LICENSE
├── README.md
├── datasets.py
├── models_x.py
├── quadrilinear_cpp
│   ├── setup.py
│   ├── setup.sh
│   └── src
│       ├── quadrilinear4d.cpp
│       ├── quadrilinear4d.h
│       ├── quadrilinear4d_cuda.cpp
│       ├── quadrilinear4d_cuda.h
│       ├── quadrilinear4d_kernel.cu
│       └── quadrilinear4d_kernel.h
└── train.py
/4D LUT_overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChengxuLiu/4DLUT/7bfccd19c0170a8cd1a5c8c5755c19bd667b29ac/4D LUT_overview.jpg
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Chengxu Liu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [TIP 2023] 4D LUT
2 | This is the official PyTorch implementation of the paper [4D LUT: Learnable Context-Aware 4D Lookup Table for Image Enhancement](https://arxiv.org/abs/2209.01749).
3 |
4 | ## Overview
5 | <img src="./4D LUT_overview.jpg">
6 |
7 | ## Contribution
8 | * We propose a novel learnable context-aware 4-dimensional lookup table (4D LUT), which is the first to extend the lookup table architecture into a 4-dimensional space, achieving content-dependent image enhancement without a significant increase in computational cost.
9 | * Extensive experiments demonstrate that the proposed 4D LUT obtains more accurate results and significantly outperforms existing SOTA methods on three widely used image enhancement benchmarks.
10 |
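For intuition: at a pixel with context value $c$ and color $(r, g, b)$, the 4D lookup interpolates quadrilinearly over the $2^4 = 16$ corners of the enclosing LUT cell. In the notation below (a summary of the kernel in `quadrilinear_cpp/src`), $c_0, r_0, g_0, b_0$ are the cell indices and $d$ denotes each coordinate's fractional offset within its cell:

$$O = \sum_{i,j,k,l \in \{0,1\}} w_c^{(i)}\, w_r^{(j)}\, w_g^{(k)}\, w_b^{(l)}\; \mathrm{LUT}\big[c_0{+}i,\; r_0{+}j,\; g_0{+}k,\; b_0{+}l\big], \qquad w^{(1)} = d,\ w^{(0)} = 1 - d.$$
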
11 | ## Requirements and dependencies
12 | * CUDA 11.4
13 | * GCC 7.5.0
14 | * python 3.8 (recommend to use [Anaconda](https://www.anaconda.com/))
15 | * pytorch == 1.9.0
16 | * torchvision == 0.10.0
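
For example, a possible environment setup (a sketch only; `cu111` is the closest official PyTorch wheel build to CUDA 11.4):
```
conda create -n 4dlut python=3.8
conda activate 4dlut
pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 -f https://download.pytorch.org/whl/torch_stable.html
```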
17 |
18 | ## Train
19 | 1. Clone this github repo
20 | ```
21 | git clone https://github.com/ChengxuLiu/4DLUT.git
22 | cd 4DLUT
23 | ```
24 | 2. Build:
25 | ```
26 | cd quadrilinear_cpp
27 | sh setup.sh
28 | ```
29 | 3. Prepare the training dataset and modify the dataset path in `./datasets.py`
30 | 4. Run training
31 | ```
32 | python train.py
33 | ```
34 | 5. The models are saved in `./saved_models`
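
As a quick sanity check after step 2, the following minimal sketch (assuming the `quadrilinear4d` extension installed successfully, a CUDA device is available, and `Identity4DLUT17.txt` is in the working directory) applies the identity 4D LUT to a random 4-channel tensor:
```python
import torch
from models_x import Generator4DLUT_identity

lut = Generator4DLUT_identity(dim=17).cuda()
x = torch.rand(1, 4, 64, 64).cuda()  # channel 0: context map, channels 1-3: RGB
out = lut(x)                         # quadrilinear lookup through the identity 4D LUT
print(out.shape)                     # expected: torch.Size([1, 3, 64, 64])
```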
35 |
36 |
37 | ## Citation
38 | If you find the code useful for your research, please consider citing our paper. :blush:
39 | ```
40 | @article{liu20234d,
41 | title={4D LUT: learnable context-aware 4d lookup table for image enhancement},
42 | author={Liu, Chengxu and Yang, Huan and Fu, Jianlong and Qian, Xueming},
43 | journal={IEEE Transactions on Image Processing},
44 | volume={32},
45 | pages={4742--4756},
46 | year={2023},
47 | publisher={IEEE}
48 | }
49 | ```
50 |
51 | ## Contact
52 | If you encounter any problems, please open an issue or contact:
53 | * Chengxu Liu:
54 |
55 | ## Acknowledgement
56 | The code of 4DLUT is built upon [Image-Adaptive-3DLUT](https://github.com/HuiZeng/Image-Adaptive-3DLUT), and we express our gratitude to this awesome project.
57 |
58 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import random
3 | import os
4 | import numpy as np
5 | import torch
6 | import cv2
7 |
8 | from torch.utils.data import Dataset
9 | from PIL import Image
10 | import torchvision.transforms as transforms
11 | import torchvision.transforms.functional as TF
12 |
13 |
14 | class ImageDataset_sRGB(Dataset):
15 | def __init__(self, root, mode="train", unpaird_data="fiveK", combined=False):
16 | self.mode = mode
17 | self.unpaird_data = unpaird_data
18 |
19 | file = open(os.path.join(root,'images_train.txt'),'r') #for DPE
20 | set1_input_files = sorted(file.readlines())
21 | self.set1_input_files = list()
22 | self.set1_expert_files = list()
23 | for i in range(len(set1_input_files)):
24 | self.set1_input_files.append(os.path.join(root,"input","InputAsShotZero",set1_input_files[i][:-1] + ".png"))
25 | self.set1_expert_files.append(os.path.join(root,"output","Export_C_512",set1_input_files[i][:-1] + ".png"))
26 |
27 | file = open(os.path.join(root,'images_test.txt'),'r')
28 | test_input_files = sorted(file.readlines())
29 | self.test_input_files = list()
30 | self.test_expert_files = list()
31 | for i in range(len(test_input_files)):
32 | self.test_input_files.append(os.path.join(root,"input","InputAsShotZero",test_input_files[i][:-1] + ".png"))
33 | self.test_expert_files.append(os.path.join(root,"output","Export_C_512",test_input_files[i][:-1] + ".png"))
34 |
35 | def __getitem__(self, index):
36 | if self.mode == "train":
37 | img_name = os.path.split(self.set1_input_files[index % len(self.set1_input_files)])[-1]
38 | img_input = Image.open(self.set1_input_files[index % len(self.set1_input_files)])
39 | img_exptC = Image.open(self.set1_expert_files[index % len(self.set1_expert_files)])
40 |
41 | elif self.mode == "test":
42 | img_name = os.path.split(self.test_input_files[index % len(self.test_input_files)])[-1]
43 | img_input = Image.open(self.test_input_files[index % len(self.test_input_files)])
44 | img_exptC = Image.open(self.test_expert_files[index % len(self.test_expert_files)])
45 |
46 | if self.mode == "train":
47 |
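# Paired augmentation: the random crop (60-100% per side) and horizontal flip are
# shared with the target; brightness/saturation jitter is applied to the input only.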
48 | ratio_H = np.random.uniform(0.6,1.0)
49 | ratio_W = np.random.uniform(0.6,1.0)
50 | W, H = img_input.size
51 | crop_h = round(H*ratio_H)
52 | crop_w = round(W*ratio_W)
53 |
54 | i, j, h, w = transforms.RandomCrop.get_params(img_input, output_size=(crop_h, crop_w))
55 | img_input = TF.crop(img_input, i, j, h, w)
56 | img_exptC = TF.crop(img_exptC, i, j, h, w)
57 |
58 | if np.random.random() > 0.5:
59 | img_input = TF.hflip(img_input)
60 | img_exptC = TF.hflip(img_exptC)
61 |
62 | a = np.random.uniform(0.8,1.2)
63 | img_input = TF.adjust_brightness(img_input,a)
64 |
65 | a = np.random.uniform(0.8,1.2)
66 | img_input = TF.adjust_saturation(img_input,a)
67 |
68 | img_input = TF.to_tensor(img_input)
69 | img_exptC = TF.to_tensor(img_exptC)
70 | return {"A_input": img_input, "A_exptC": img_exptC, "input_name": img_name}
71 |
72 | def __len__(self):
73 | if self.mode == "train":
74 | return len(self.set1_input_files)
75 | elif self.mode == "test":
76 | return len(self.test_input_files)
77 |
--------------------------------------------------------------------------------
/models_x.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torchvision.models as models
4 | import torchvision.transforms as transforms
5 | from torch.autograd import Variable
6 | import torch
7 | import numpy as np
8 | import math
9 | import quadrilinear4d
10 |
11 | def weights_init_normal_generator(m):
12 | classname = m.__class__.__name__
13 | if classname.find("Conv2d") != -1:
14 | torch.nn.init.xavier_normal_(m.weight.data)
15 |
16 | elif classname.find("BatchNorm2d") != -1 or classname.find("InstanceNorm2d") != -1:
17 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
18 | torch.nn.init.constant_(m.bias.data, 0.0)
19 |
20 |
21 | def discriminator_block(in_filters, out_filters, normalization=False):
22 | """Returns downsampling layers of each discriminator block"""
23 | layers = [nn.Conv2d(in_filters, out_filters, 3, stride=2, padding=1)]
24 | layers.append(nn.LeakyReLU(0.2))
25 | if normalization:
26 | layers.append(nn.InstanceNorm2d(out_filters, affine=True))
27 |
28 | return layers
29 |
30 |
31 | class Generator_for_bias(nn.Module):
32 | def __init__(self, in_channels=3):
33 | super(Generator_for_bias, self).__init__()
34 |
35 | self.model = nn.Sequential(
36 | nn.Upsample(size=(256,256),mode='bilinear'),
37 | nn.Conv2d(3, 16, 3, stride=2, padding=1),
38 | nn.LeakyReLU(0.2),
39 | nn.InstanceNorm2d(16, affine=True),
40 | *discriminator_block(16, 32, normalization=True),
41 | *discriminator_block(32, 64, normalization=True),
42 | *discriminator_block(64, 128, normalization=True),
43 | *discriminator_block(128, 128),
44 | nn.Dropout(p=0.5),
45 | nn.Conv2d(128, 12, 8, padding=0),
46 | )
47 |
48 | def forward(self, img_input):
49 | return self.model(img_input)
50 |
51 |
52 | def generator_block(in_filters, out_filters, normalization=False):
53 | """Returns downsampling layers of each discriminator block"""
54 | layers = [nn.Conv2d(in_filters, out_filters, 3, stride=1, padding=1)]
55 | layers.append(nn.LeakyReLU(0.2))
56 | if normalization:
57 | layers.append(nn.InstanceNorm2d(out_filters, affine=True))
58 | #layers.append(nn.BatchNorm2d(out_filters))
59 | return layers
60 |
61 |
62 | class Generator_for_info(nn.Module):
63 | def __init__(self, in_channels=3):
64 | super(Generator_for_info, self).__init__()
65 |
66 | self.input_layer = nn.Sequential(
67 | nn.Conv2d(in_channels, 16, 3, stride=1, padding=1),
68 | nn.LeakyReLU(0.2),
69 | nn.InstanceNorm2d(16, affine=True),
70 | )
71 |
72 | self.mid_layer = nn.Sequential(
73 | *generator_block(16, 16, normalization=True),
74 | *generator_block(16, 16, normalization=True),
75 | *generator_block(16, 16, normalization=True),
76 | )
77 |
78 | self.output_layer = nn.Sequential(
79 | nn.Dropout(p=0.5),
80 | nn.Conv2d(16, 1, 3, stride=1, padding=1),
81 | nn.Sigmoid()
82 | )
83 |
84 |
85 | def forward(self, img_input):
86 | x = self.input_layer(img_input)
87 | identity = x
88 | out = self.mid_layer(x)
89 | out += identity
90 | out = self.output_layer(out)
91 | return out
92 |
93 |
94 |
95 | class Generator4DLUT_identity(nn.Module):
96 | def __init__(self, dim=17):
97 | super(Generator4DLUT_identity, self).__init__()
98 | if dim == 17:
99 | file = open("Identity4DLUT17.txt", 'r')
100 | elif dim == 33:
101 | file = open("Identity4DLUT33.txt", 'r')
102 | lines = file.readlines()
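# Buffer layout: (3 output channels, 2 context bins, dim x dim x dim color bins);
# each line of the identity file holds one "r g b" triplet for one LUT entry.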
103 | buffer = np.zeros((3,2,dim,dim,dim), dtype=np.float32)
104 | for p in range(0,2):
105 | for i in range(0,dim):
106 | for j in range(0,dim):
107 | for k in range(0,dim):
108 | n = p * dim*dim*dim + i * dim*dim + j*dim + k
109 | x = lines[n].split()
110 | buffer[0,p,i,j,k] = float(x[0])
111 | buffer[1,p,i,j,k] = float(x[1])
112 | buffer[2,p,i,j,k] = float(x[2])
113 | self.LUT_en = nn.Parameter(torch.from_numpy(buffer).requires_grad_(True))
114 | self.QuadrilinearInterpolation_4D = QuadrilinearInterpolation_4D()
115 |
116 | def forward(self, x):
117 | _, output = self.QuadrilinearInterpolation_4D(self.LUT_en, x)
118 | return output
119 |
120 |
121 |
122 |
123 | class QuadrilinearInterpolation_Function(torch.autograd.Function):
124 | @staticmethod
125 | def forward(ctx, lut, x):
126 | x = x.contiguous()
127 | output = x.new(x.size()[0],3,x.size()[2],x.size()[3])
128 | dim = lut.size()[-1]
129 | shift = 2 * dim ** 3
130 | binsize = 1.000001 / (dim-1)
131 | W = x.size(2)
132 | H = x.size(3)
133 | batch = x.size(0)
134 | assert 1 == quadrilinear4d.forward(lut,
135 | x,
136 | output,
137 | dim,
138 | shift,
139 | binsize,
140 | W,
141 | H,
142 | batch)
143 | int_package = torch.IntTensor([dim, shift, W, H, batch])
144 | float_package = torch.FloatTensor([binsize])
145 | variables = [lut, x, int_package, float_package]
146 | ctx.save_for_backward(*variables)
147 |
148 | return lut, output
149 |
150 | @staticmethod
151 | def backward(ctx, lut_grad, x_grad):
153 | x_grad = x_grad.contiguous()
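# x_grad is d(loss)/d(output) over the 3 RGB output channels; widen it to 4
# channels so the extension can also accumulate a gradient for the context channel.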
154 | output_grad = x_grad.new(x_grad.size()[0],4,x_grad.size()[2],x_grad.size()[3]).fill_(0)
155 | output_grad[:,1:,:,:] = x_grad
156 | lut, x, int_package, float_package = ctx.saved_tensors
157 | dim, shift, W, H, batch = int_package
158 | dim, shift, W, H, batch = int(dim), int(shift), int(W), int(H), int(batch)
159 | binsize = float(float_package[0])
160 |
161 | assert 1 == quadrilinear4d.backward(x,
162 | output_grad,
163 | lut,
164 | lut_grad,
165 | dim,
166 | shift,
167 | binsize,
168 | W,
169 | H,
170 | batch)
171 | return lut_grad, output_grad
172 |
173 |
174 | class QuadrilinearInterpolation_4D(torch.nn.Module):
175 | def __init__(self):
176 | super(QuadrilinearInterpolation_4D, self).__init__()
177 |
178 | def forward(self, lut, x):
179 | return QuadrilinearInterpolation_Function.apply(lut, x)
180 |
181 |
182 | class TV_4D(nn.Module):
183 | def __init__(self, dim=17):
184 | super(TV_4D,self).__init__()
185 |
186 | self.weight_r = torch.ones(3,2,dim,dim,dim-1, dtype=torch.float)
187 | self.weight_r[:,:,:,:,(0,dim-2)] *= 2.0
188 | self.weight_g = torch.ones(3,2,dim,dim-1,dim, dtype=torch.float)
189 | self.weight_g[:,:,:,(0,dim-2),:] *= 2.0
190 | self.weight_b = torch.ones(3,2,dim-1,dim,dim, dtype=torch.float)
191 | self.weight_b[:,:,(0,dim-2),:,:] *= 2.0
192 | self.relu = torch.nn.ReLU()
193 |
194 | def forward(self, LUT):
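# tv: squared differences between neighboring LUT entries (smoothness prior);
# mn: penalizes entries that decrease along any axis, encouraging monotonicity.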
195 | dif_context = LUT.LUT_en[:,:-1,:,:,:] - LUT.LUT_en[:,1:,:,:,:]
196 | dif_r = LUT.LUT_en[:,:,:,:,:-1] - LUT.LUT_en[:,:,:,:,1:]
197 | dif_g = LUT.LUT_en[:,:,:,:-1,:] - LUT.LUT_en[:,:,:,1:,:]
198 | dif_b = LUT.LUT_en[:,:,:-1,:,:] - LUT.LUT_en[:,:,1:,:,:]
199 | tv = torch.mean(torch.mul((dif_r ** 2),self.weight_r)) + torch.mean(torch.mul((dif_g ** 2),self.weight_g)) + torch.mean(torch.mul((dif_b ** 2),self.weight_b))
200 | mn = torch.mean(self.relu(dif_r)) + torch.mean(self.relu(dif_g)) + torch.mean(self.relu(dif_b)) \
201 | + torch.mean(self.relu(dif_context))
202 | return tv, mn
203 |
204 |
--------------------------------------------------------------------------------
/quadrilinear_cpp/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | import torch
3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension
4 |
5 | if torch.cuda.is_available():
7 | print('Including CUDA code.')
8 | setup(
9 | name='quadrilinear4d',
10 | ext_modules=[
11 | CUDAExtension('quadrilinear4d', [
12 | 'src/quadrilinear4d_cuda.cpp',
13 | 'src/quadrilinear4d_kernel.cu',
14 | ])
15 | ],
16 | cmdclass={
17 | 'build_ext': BuildExtension
18 | })
19 | else:
20 | print('NO CUDA is found. Fall back to CPU.')
21 | setup(name='quadrilinear4d',
22 | ext_modules=[CppExtension(name = 'quadrilinear4d',
23 | sources= ['src/quadrilinear4d.cpp'],
24 | extra_compile_args=['-fopenmp'])],
25 | cmdclass={'build_ext': BuildExtension})
26 |
--------------------------------------------------------------------------------
/quadrilinear_cpp/setup.sh:
--------------------------------------------------------------------------------
1 | export CUDA_HOME=/usr/local/cuda && python3 setup.py install
2 |
--------------------------------------------------------------------------------
/quadrilinear_cpp/src/quadrilinear4d.cpp:
--------------------------------------------------------------------------------
1 | #include "quadrilinear.h"
2 |
3 |
4 | void QuadriLinearForwardCpu(const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int channels);
5 |
6 | void QuadriLinearBackwardCpu(const float* image, float* image_grad,const float* lut, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int channels);
7 |
8 | int quadrilinear_forward(torch::Tensor lut, torch::Tensor image, torch::Tensor output,
9 | int lut_dim, int shift, float binsize, int width, int height, int batch)
10 | {
11 | // Grab the input tensor
12 | float * lut_flat = lut.data<float>();
13 | float * image_flat = image.data<float>();
14 | float * output_flat = output.data<float>();
15 |
16 | // whether color image
17 | auto image_size = image.sizes();
18 | int channels = image_size[1];
19 |
20 | QuadriLinearForwardCpu(lut_flat, image_flat, output_flat, lut_dim, shift, binsize, width, height, channels);
21 |
22 | return 1;
23 | }
24 |
25 | int quadrilinear_backward(torch::Tensor image, torch::Tensor image_grad, torch::Tensor lut, torch::Tensor lut_grad,
26 | int lut_dim, int shift, float binsize, int width, int height, int batch)
27 | {
28 | // Grab the input tensor
29 | float * image_grad_flat = image_grad.data<float>();
30 | float * lut_flat = lut.data<float>();
31 | float * image_flat = image.data<float>();
32 | float * lut_grad_flat = lut_grad.data<float>();
33 |
34 | // whether color image
35 | auto image_size = image.sizes();
36 | int channels = image_size[1];
37 | if (channels != 3)
38 | {
39 | return 0;
40 | }
41 |
42 | QuadriLinearBackwardCpu(image_flat, image_grad_flat, lut_flat, lut_grad_flat, lut_dim, shift, binsize, width, height, channels);
43 |
44 | return 1;
45 | }
46 |
47 | void QuadriLinearForwardCpu(const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int channels)
48 | {
49 | const int output_size = height * width;
50 |
51 | int index = 0;
52 |
53 | #pragma omp parallel
54 | #pragma omp for
55 | for (index = 0; index < output_size; ++index)
56 | {
57 | float context = image[index];
58 | float r = image[index + width * height];
59 | float g = image[index + width * height * 2];
60 | float b = image[index + width * height * 3];
61 |
62 | int r_id = floor(r / binsize);
63 | int g_id = floor(g / binsize);
64 | int b_id = floor(b / binsize);
65 | int context_id = floor(context / binsize);
66 |
67 | float r_d = fmod(r,binsize) / binsize;
68 | float g_d = fmod(g,binsize) / binsize;
69 | float b_d = fmod(b,binsize) / binsize;
70 | float context_d = fmod(context,binsize) / binsize;
71 |
72 | int id0000 = context_id + r_id * dim + g_id * dim * dim + b_id * dim * dim * dim;
73 | int id1000 = context_id + 1 + r_id * dim + g_id * dim * dim + b_id * dim * dim * dim;
74 | int id0100 = context_id + (r_id + 1) * dim + g_id * dim * dim + b_id * dim * dim * dim;
75 | int id0010 = context_id + r_id * dim + (g_id + 1) * dim * dim + b_id * dim * dim * dim;
76 | int id0001 = context_id + r_id * dim + g_id * dim * dim + (b_id + 1) * dim * dim * dim;
77 | int id1100 = context_id + 1 + (r_id + 1) * dim + g_id * dim * dim + b_id * dim * dim * dim;
78 | int id0110 = context_id + (r_id + 1) * dim + (g_id + 1) * dim * dim + b_id * dim * dim * dim;
79 | int id0011 = context_id + r_id * dim + (g_id + 1) * dim * dim + (b_id + 1) * dim * dim * dim;
80 | int id1010 = context_id + 1 + r_id * dim + (g_id + 1) * dim * dim + b_id * dim * dim * dim;
81 | int id1001 = context_id + 1 + r_id * dim + g_id * dim * dim + (b_id + 1) * dim * dim * dim;
82 | int id0101 = context_id + (r_id + 1) * dim + g_id * dim * dim + (b_id + 1) * dim * dim * dim;
83 | int id1110 = context_id + 1 + (r_id + 1) * dim + (g_id + 1) * dim * dim + b_id * dim * dim * dim;
84 | int id1011 = context_id + 1 + r_id * dim + (g_id + 1) * dim * dim + (b_id + 1) * dim * dim * dim;
85 | int id1101 = context_id + 1 + (r_id + 1) * dim + g_id * dim * dim + (b_id + 1) * dim * dim * dim;
86 | int id0111 = context_id + (r_id + 1) * dim + (g_id + 1) * dim * dim + (b_id + 1) * dim * dim * dim;
87 | int id1111 = context_id + 1 + (r_id + 1) * dim + (g_id + 1) * dim * dim + (b_id + 1) * dim * dim * dim;
88 |
89 |
90 | float w0000 = (1-context_d)*(1-r_d)*(1-g_d)*(1-b_d);
91 | float w1000 = context_d*(1-r_d)*(1-g_d)*(1-b_d);
92 | float w0100 = (1-context_d)*r_d*(1-g_d)*(1-b_d);
93 | float w0010 = (1-context_d)*(1-r_d)*g_d*(1-b_d);
94 | float w0001 = (1-context_d)*(1-r_d)*(1-g_d)*b_d;
95 | float w1100 = context_d*r_d*(1-g_d)*(1-b_d);
96 | float w0110 = (1-context_d)*r_d*g_d*(1-b_d);
97 | float w0011 = (1-context_d)*(1-r_d)*g_d*b_d;
98 | float w1010 = context_d*(1-r_d)*g_d*(1-b_d);
99 | float w1001 = context_d*(1-r_d)*(1-g_d)*b_d;
100 | float w0101 = (1-context_d)*r_d*(1-g_d)*b_d;
101 | float w1110 = context_d*r_d*g_d*(1-b_d);
102 | float w0111 = (1-context_d)*r_d*g_d*b_d;
103 | float w1101 = context_d*r_d*(1-g_d)*b_d;
104 | float w1011 = context_d*(1-r_d)*g_d*b_d;
105 | float w1111 = context_d*r_d*g_d*b_d;
106 |
107 | output[index] = w0000 * lut[id0000] + w1000 * lut[id1000] + w0100 * lut[id0100] + w0010 * lut[id0010] +
108 | w0001 * lut[id0001] + w1100 * lut[id1100] + w0110 * lut[id0110] + w0011 * lut[id0011] +
109 | w1010 * lut[id1010] + w1001 * lut[id1001] + w0101 * lut[id0101] + w1110 * lut[id1110] +
110 | w0111 * lut[id0111] + w1101 * lut[id1101] + w1011 * lut[id1011] + w1111 * lut[id1111];
111 |
112 | output[index + width * height] = w0000 * lut[id0000 + shift] + w1000 * lut[id1000 + shift] + w0100 * lut[id0100 + shift] + w0010 * lut[id0010 + shift] +
113 | w0001 * lut[id0001 + shift] + w1100 * lut[id1100 + shift] + w0110 * lut[id0110 + shift] + w0011 * lut[id0011 + shift] +
114 | w1010 * lut[id1010 + shift] + w1001 * lut[id1001 + shift] + w0101 * lut[id0101 + shift] + w1110 * lut[id1110 + shift] +
115 | w0111 * lut[id0111 + shift] + w1101 * lut[id1101 + shift] + w1011 * lut[id1011 + shift] + w1111 * lut[id1111 + shift];
116 |
117 | output[index + width * height * 2] = w0000 * lut[id0000 + shift * 2] + w1000 * lut[id1000 + shift * 2] + w0100 * lut[id0100 + shift * 2] + w0010 * lut[id0010 + shift * 2] +
118 | w0001 * lut[id0001 + shift * 2] + w1100 * lut[id1100 + shift * 2] + w0110 * lut[id0110 + shift * 2] + w0011 * lut[id0011 + shift * 2] +
119 | w1010 * lut[id1010 + shift * 2] + w1001 * lut[id1001 + shift * 2] + w0101 * lut[id0101 + shift * 2] + w1110 * lut[id1110 + shift * 2] +
120 | w0111 * lut[id0111 + shift * 2] + w1101 * lut[id1101 + shift * 2] + w1011 * lut[id1011 + shift * 2] + w1111 * lut[id1111 + shift * 2];
121 | }
122 | }
123 |
124 | void QuadriLinearBackwardCpu(const float* image, float* image_grad, const float* lut, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int channels)
125 | {
126 | const int output_size = height * width;
127 |
128 | int index = 0;
129 | #pragma omp parallel
130 | #pragma omp for
131 | for (index = 0; index < output_size; ++index)
132 | {
133 | float context = image[index];
134 | float r = image[index + width * height];
135 | float g = image[index + width * height * 2];
136 | float b = image[index + width * height * 3];
137 |
138 | int r_id = floor(r / binsize);
139 | int g_id = floor(g / binsize);
140 | int b_id = floor(b / binsize);
141 | int context_id = floor(context / binsize);
142 |
143 | float r_d = fmod(r,binsize) / binsize;
144 | float g_d = fmod(g,binsize) / binsize;
145 | float b_d = fmod(b,binsize) / binsize;
146 | float context_d = fmod(context,binsize) / binsize;
147 |
148 | int id0000 = context_id + r_id * dim + g_id * dim * dim + b_id * dim * dim * dim;
149 | int id1000 = context_id + 1 + r_id * dim + g_id * dim * dim + b_id * dim * dim * dim;
150 | int id0100 = context_id + (r_id + 1) * dim + g_id * dim * dim + b_id * dim * dim * dim;
151 | int id0010 = context_id + r_id * dim + (g_id + 1) * dim * dim + b_id * dim * dim * dim;
152 | int id0001 = context_id + r_id * dim + g_id * dim * dim + (b_id + 1) * dim * dim * dim;
153 | int id1100 = context_id + 1 + (r_id + 1) * dim + g_id * dim * dim + b_id * dim * dim * dim;
154 | int id0110 = context_id + (r_id + 1) * dim + (g_id + 1) * dim * dim + b_id * dim * dim * dim;
155 | int id0011 = context_id + r_id * dim + (g_id + 1) * dim * dim + (b_id + 1) * dim * dim * dim;
156 | int id1010 = context_id + 1 + r_id * dim + (g_id + 1) * dim * dim + b_id * dim * dim * dim;
157 | int id1001 = context_id + 1 + r_id * dim + g_id * dim * dim + (b_id + 1) * dim * dim * dim;
158 | int id0101 = context_id + (r_id + 1) * dim + g_id * dim * dim + (b_id + 1) * dim * dim * dim;
159 | int id1110 = context_id + 1 + (r_id + 1) * dim + (g_id + 1) * dim * dim + b_id * dim * dim * dim;
160 | int id1011 = context_id + 1 + r_id * dim + (g_id + 1) * dim * dim + (b_id + 1) * dim * dim * dim;
161 | int id1101 = context_id + 1 + (r_id + 1) * dim + g_id * dim * dim + (b_id + 1) * dim * dim * dim;
162 | int id0111 = context_id + (r_id + 1) * dim + (g_id + 1) * dim * dim + (b_id + 1) * dim * dim * dim;
163 | int id1111 = context_id + 1 + (r_id + 1) * dim + (g_id + 1) * dim * dim + (b_id + 1) * dim * dim * dim;
164 |
165 |
166 | float w0000 = (1-context_d)*(1-r_d)*(1-g_d)*(1-b_d);
167 | float w1000 = context_d*(1-r_d)*(1-g_d)*(1-b_d);
168 | float w0100 = (1-context_d)*r_d*(1-g_d)*(1-b_d);
169 | float w0010 = (1-context_d)*(1-r_d)*g_d*(1-b_d);
170 | float w0001 = (1-context_d)*(1-r_d)*(1-g_d)*b_d;
171 | float w1100 = context_d*r_d*(1-g_d)*(1-b_d);
172 | float w0110 = (1-context_d)*r_d*g_d*(1-b_d);
173 | float w0011 = (1-context_d)*(1-r_d)*g_d*b_d;
174 | float w1010 = context_d*(1-r_d)*g_d*(1-b_d);
175 | float w1001 = context_d*(1-r_d)*(1-g_d)*b_d;
176 | float w0101 = (1-context_d)*r_d*(1-g_d)*b_d;
177 | float w1110 = context_d*r_d*g_d*(1-b_d);
178 | float w0111 = (1-context_d)*r_d*g_d*b_d;
179 | float w1101 = context_d*r_d*(1-g_d)*b_d;
180 | float w1011 = context_d*(1-r_d)*g_d*b_d;
181 | float w1111 = context_d*r_d*g_d*b_d;
182 |
183 |
184 | lut_grad[id0000 ] += w0000 * image_grad[index + width * height];
185 | lut_grad[id1000 ] += w1000 * image_grad[index + width * height];
186 | lut_grad[id0100 ] += w0100 * image_grad[index + width * height];
187 | lut_grad[id0010 ] += w0010 * image_grad[index + width * height];
188 | lut_grad[id0001 ] += w0001 * image_grad[index + width * height];
189 | lut_grad[id1100 ] += w1100 * image_grad[index + width * height];
190 | lut_grad[id0110 ] += w0110 * image_grad[index + width * height];
191 | lut_grad[id0011 ] += w0011 * image_grad[index + width * height];
192 | lut_grad[id1010 ] += w1010 * image_grad[index + width * height];
193 | lut_grad[id1001 ] += w1001 * image_grad[index + width * height];
194 | lut_grad[id0101 ] += w0101 * image_grad[index + width * height];
195 | lut_grad[id1110 ] += w1110 * image_grad[index + width * height];
196 | lut_grad[id0111 ] += w0111 * image_grad[index + width * height];
197 | lut_grad[id1101 ] += w1101 * image_grad[index + width * height];
198 | lut_grad[id1011 ] += w1011 * image_grad[index + width * height];
199 | lut_grad[id1111 ] += w1111 * image_grad[index + width * height];
200 |
201 | lut_grad[id0000 + shift] += w0000 * image_grad[index + width * height * 2];
202 | lut_grad[id1000 + shift] += w1000 * image_grad[index + width * height * 2];
203 | lut_grad[id0100 + shift] += w0100 * image_grad[index + width * height * 2];
204 | lut_grad[id0010 + shift] += w0010 * image_grad[index + width * height * 2];
205 | lut_grad[id0001 + shift] += w0001 * image_grad[index + width * height * 2];
206 | lut_grad[id1100 + shift] += w1100 * image_grad[index + width * height * 2];
207 | lut_grad[id0110 + shift] += w0110 * image_grad[index + width * height * 2];
208 | lut_grad[id0011 + shift] += w0011 * image_grad[index + width * height * 2];
209 | lut_grad[id1010 + shift] += w1010 * image_grad[index + width * height * 2];
210 | lut_grad[id1001 + shift] += w1001 * image_grad[index + width * height * 2];
211 | lut_grad[id0101 + shift] += w0101 * image_grad[index + width * height * 2];
212 | lut_grad[id1110 + shift] += w1110 * image_grad[index + width * height * 2];
213 | lut_grad[id0111 + shift] += w0111 * image_grad[index + width * height * 2];
214 | lut_grad[id1101 + shift] += w1101 * image_grad[index + width * height * 2];
215 | lut_grad[id1011 + shift] += w1011 * image_grad[index + width * height * 2];
216 | lut_grad[id1111 + shift] += w1111 * image_grad[index + width * height * 2];
217 |
218 | lut_grad[id0000 + shift* 3] += w0000 * image_grad[index + width * height * 3];
219 | lut_grad[id1000 + shift* 3] += w1000 * image_grad[index + width * height * 3];
220 | lut_grad[id0100 + shift* 3] += w0100 * image_grad[index + width * height * 3];
221 | lut_grad[id0010 + shift* 3] += w0010 * image_grad[index + width * height * 3];
222 | lut_grad[id0001 + shift* 3] += w0001 * image_grad[index + width * height * 3];
223 | lut_grad[id1100 + shift* 3] += w1100 * image_grad[index + width * height * 3];
224 | lut_grad[id0110 + shift* 3] += w0110 * image_grad[index + width * height * 3];
225 | lut_grad[id0011 + shift* 3] += w0011 * image_grad[index + width * height * 3];
226 | lut_grad[id1010 + shift* 3] += w1010 * image_grad[index + width * height * 3];
227 | lut_grad[id1001 + shift* 3] += w1001 * image_grad[index + width * height * 3];
228 | lut_grad[id0101 + shift* 3] += w0101 * image_grad[index + width * height * 3];
229 | lut_grad[id1110 + shift* 3] += w1110 * image_grad[index + width * height * 3];
230 | lut_grad[id0111 + shift* 3] += w0111 * image_grad[index + width * height * 3];
231 | lut_grad[id1101 + shift* 3] += w1101 * image_grad[index + width * height * 3];
232 | lut_grad[id1011 + shift* 3] += w1011 * image_grad[index + width * height * 3];
233 | lut_grad[id1111 + shift* 3] += w1111 * image_grad[index + width * height * 3];
234 |
235 |
236 | }
237 | }
238 |
239 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
240 | m.def("forward", &quadrilinear_forward, "Quadrilinear forward");
241 | m.def("backward", &quadrilinear_backward, "Quadrilinear backward");
242 | }
243 |
--------------------------------------------------------------------------------
/quadrilinear_cpp/src/quadrilinear4d.h:
--------------------------------------------------------------------------------
1 | #ifndef QUADRILINEAR4D_H
2 | #define QUADRILINEAR4D_H
3 |
4 | #include <torch/extension.h>
5 |
6 | int quadrilinear_forward(torch::Tensor lut, torch::Tensor image, torch::Tensor output,
7 | int lut_dim, int shift, float binsize, int width, int height, int batch);
8 |
9 | int quadrilinear_backward(torch::Tensor image, torch::Tensor image_grad,torch::Tensor lut, torch::Tensor lut_grad,
10 | int lut_dim, int shift, float binsize, int width, int height, int batch);
11 |
12 | #endif
13 |
--------------------------------------------------------------------------------
/quadrilinear_cpp/src/quadrilinear4d_cuda.cpp:
--------------------------------------------------------------------------------
1 | #include "quadrilinear4d_kernel.h"
2 | #include <torch/extension.h>
3 | #include <ATen/cuda/CUDAContext.h>
4 |
5 | int quadrilinear4d_forward_cuda(torch::Tensor lut, torch::Tensor image, torch::Tensor output,
6 | int lut_dim, int shift, float binsize, int width, int height, int batch)
7 | {
8 | // Grab the input tensor
9 | float * lut_flat = lut.data<float>();
10 | float * image_flat = image.data<float>();
11 | float * output_flat = output.data<float>();
12 |
13 | QuadriLinearForwardLaucher(lut_flat, image_flat, output_flat, lut_dim, shift, binsize, width, height, batch, at::cuda::getCurrentCUDAStream());
14 |
15 | return 1;
16 | }
17 |
18 | int quadrilinear4d_backward_cuda(torch::Tensor image, torch::Tensor image_grad, torch::Tensor lut, torch::Tensor lut_grad,
19 | int lut_dim, int shift, float binsize, int width, int height, int batch)
20 | {
21 | // Grab the input tensor
22 | float * image_grad_flat = image_grad.data<float>();
23 | float * image_flat = image.data<float>();
24 | float * lut_flat = lut.data<float>();
25 | float * lut_grad_flat = lut_grad.data<float>();
26 |
27 | QuadriLinearBackwardLaucher(image_flat, image_grad_flat, lut_flat, lut_grad_flat, lut_dim, shift, binsize, width, height, batch, at::cuda::getCurrentCUDAStream());
28 |
29 | return 1;
30 | }
31 |
32 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
33 | m.def("forward", &quadrilinear4d_forward_cuda, "Quadrilinear forward");
34 | m.def("backward", &quadrilinear4d_backward_cuda, "Quadrilinear backward");
35 | }
36 |
37 |
--------------------------------------------------------------------------------
/quadrilinear_cpp/src/quadrilinear4d_cuda.h:
--------------------------------------------------------------------------------
1 | #ifndef QUADRILINEAR4D_CUDA_H
2 | #define QUADRILINEAR4D_CUDA_H
3 |
4 | #include <torch/extension.h>
5 |
6 | int quadrilinear4d_forward_cuda(torch::Tensor lut, torch::Tensor image, torch::Tensor output,
7 | int lut_dim, int shift, float binsize, int width, int height, int batch);
8 |
9 | int quadrilinear4d_backward_cuda(torch::Tensor image, torch::Tensor image_grad, torch::Tensor lut, torch::Tensor lut_grad,
10 | int lut_dim, int shift, float binsize, int width, int height, int batch);
11 |
12 | #endif
13 |
--------------------------------------------------------------------------------
/quadrilinear_cpp/src/quadrilinear4d_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <math.h>
2 | #include <float.h>
3 | #include "quadrilinear4d_kernel.h"
4 |
5 | #define CUDA_1D_KERNEL_LOOP(i, n) \
6 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
7 | i += blockDim.x * gridDim.x)
8 |
9 |
10 | __global__ void QuadriLinearForward(const int nthreads, const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int batch) {
11 | CUDA_1D_KERNEL_LOOP(index, nthreads) {
12 |
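// Input layout: four planes of width*height*batch floats (context, R, G, B);
// the output holds three planes (R, G, B) at the same stride.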
13 | float context = image[index];
14 | float r = image[index + width * height * batch];
15 | float g = image[index + width * height * batch * 2];
16 | float b = image[index + width * height * batch * 3];
17 |
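// The context dimension has only two samples spanning [0, 1], so the cell index
// is fixed at 0 and the fractional offset is the raw context value.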
18 | int context_id = 0;
19 | int r_id = floor(r / binsize);
20 | int g_id = floor(g / binsize);
21 | int b_id = floor(b / binsize);
22 |
23 | float context_d = context;
24 | float r_d = fmod(r,binsize) / binsize;
25 | float g_d = fmod(g,binsize) / binsize;
26 | float b_d = fmod(b,binsize) / binsize;
27 |
28 | int id0000 = context_id * dim * dim * dim + r_id + g_id * dim + b_id * dim * dim;
29 | int id0100 = context_id * dim * dim * dim + (r_id + 1) + g_id * dim + b_id * dim * dim;
30 | int id0010 = context_id * dim * dim * dim + r_id + (g_id + 1) * dim + b_id * dim * dim;
31 | int id0001 = context_id * dim * dim * dim + r_id + g_id * dim + (b_id + 1) * dim * dim;
32 | int id0110 = context_id * dim * dim * dim + (r_id + 1) + (g_id + 1) * dim + b_id * dim * dim;
33 | int id0011 = context_id * dim * dim * dim + r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim;
34 | int id0101 = context_id * dim * dim * dim + (r_id + 1) + g_id * dim + (b_id + 1) * dim * dim;
35 | int id0111 = context_id * dim * dim * dim + (r_id + 1) + (g_id + 1) * dim + (b_id + 1) * dim * dim;
36 |
37 | int id1000 = (context_id + 1) * dim * dim * dim + r_id + g_id * dim + b_id * dim * dim;
38 | int id1100 = (context_id + 1) * dim * dim * dim + (r_id + 1) + g_id * dim + b_id * dim * dim;
39 | int id1010 = (context_id + 1) * dim * dim * dim + r_id + (g_id + 1) * dim + b_id * dim * dim;
40 | int id1001 = (context_id + 1) * dim * dim * dim + r_id + g_id * dim + (b_id + 1) * dim * dim;
41 | int id1110 = (context_id + 1) * dim * dim * dim + (r_id + 1) + (g_id + 1) * dim + b_id * dim * dim;
42 | int id1011 = (context_id + 1) * dim * dim * dim + r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim;
43 | int id1101 = (context_id + 1) * dim * dim * dim + (r_id + 1) + g_id * dim + (b_id + 1) * dim * dim;
44 | int id1111 = (context_id + 1) * dim * dim * dim + (r_id + 1) + (g_id + 1) * dim + (b_id + 1) * dim * dim;
45 |
46 |
47 | float w0000 = (1-context_d)*(1-r_d)*(1-g_d)*(1-b_d);
48 | float w0100 = (1-context_d)*r_d*(1-g_d)*(1-b_d);
49 | float w0010 = (1-context_d)*(1-r_d)*g_d*(1-b_d);
50 | float w0001 = (1-context_d)*(1-r_d)*(1-g_d)*b_d;
51 | float w0110 = (1-context_d)*r_d*g_d*(1-b_d);
52 | float w0011 = (1-context_d)*(1-r_d)*g_d*b_d;
53 | float w0101 = (1-context_d)*r_d*(1-g_d)*b_d;
54 | float w0111 = (1-context_d)*r_d*g_d*b_d;
55 |
56 | float w1000 = context_d*(1-r_d)*(1-g_d)*(1-b_d);
57 | float w1100 = context_d*r_d*(1-g_d)*(1-b_d);
58 | float w1010 = context_d*(1-r_d)*g_d*(1-b_d);
59 | float w1001 = context_d*(1-r_d)*(1-g_d)*b_d;
60 | float w1110 = context_d*r_d*g_d*(1-b_d);
61 | float w1011 = context_d*(1-r_d)*g_d*b_d;
62 | float w1101 = context_d*r_d*(1-g_d)*b_d;
63 | float w1111 = context_d*r_d*g_d*b_d;
64 |
65 |
66 |
67 | output[index] = w0000 * lut[id0000] + w0100 * lut[id0100] + w0010 * lut[id0010] +
68 | w0001 * lut[id0001] + w0110 * lut[id0110] + w0011 * lut[id0011] +
69 | w0101 * lut[id0101] + w0111 * lut[id0111] +
70 | w1000 * lut[id1000] + w1100 * lut[id1100] + w1010 * lut[id1010] +
71 | w1001 * lut[id1001] + w1110 * lut[id1110] + w1011 * lut[id1011] +
72 | w1101 * lut[id1101] + w1111 * lut[id1111];
73 |
74 | output[index + width * height * batch] = w0000 * lut[id0000 + shift] + w0100 * lut[id0100 + shift] + w0010 * lut[id0010 + shift] +
75 | w0001 * lut[id0001 + shift] + w0110 * lut[id0110 + shift] + w0011 * lut[id0011 + shift] +
76 | w0101 * lut[id0101 + shift] + w0111 * lut[id0111 + shift] +
77 | w1000 * lut[id1000 + shift] + w1100 * lut[id1100 + shift] + w1010 * lut[id1010 + shift] +
78 | w1001 * lut[id1001 + shift] + w1110 * lut[id1110 + shift] + w1011 * lut[id1011 + shift] +
79 | w1101 * lut[id1101 + shift] + w1111 * lut[id1111 + shift];
80 |
81 | output[index + width * height * batch * 2] = w0000 * lut[id0000 + shift * 2] + w0100 * lut[id0100 + shift * 2] + w0010 * lut[id0010 + shift * 2] +
82 | w0001 * lut[id0001 + shift * 2] + w0110 * lut[id0110 + shift * 2] + w0011 * lut[id0011 + shift * 2] +
83 | w0101 * lut[id0101 + shift * 2] + w0111 * lut[id0111 + shift * 2] +
84 | w1000 * lut[id1000 + shift * 2] + w1100 * lut[id1100 + shift * 2] + w1010 * lut[id1010 + shift * 2] +
85 | w1001 * lut[id1001 + shift * 2] + w1110 * lut[id1110 + shift * 2] + w1011 * lut[id1011 + shift * 2] +
86 | w1101 * lut[id1101 + shift * 2] + w1111 * lut[id1111 + shift * 2];
87 |
88 | }
89 | }
90 |
91 |
92 | int QuadriLinearForwardLaucher(const float* lut, const float* image, float* output, const int lut_dim, const int shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream) {
93 | const int kThreadsPerBlock = 1024;
94 | const int output_size = height * width * batch;
95 | cudaError_t err;
96 |
97 |
98 | QuadriLinearForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(output_size, lut, image, output, lut_dim, shift, binsize, width, height, batch);
99 |
100 | err = cudaGetLastError();
101 | if(cudaSuccess != err) {
102 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
103 | exit( -1 );
104 | }
105 |
106 | return 1;
107 | }
108 |
109 |
110 | __global__ void QuadriLinearBackward(const int nthreads, const float* image, float* image_grad,const float* lut, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int batch) {
111 | CUDA_1D_KERNEL_LOOP(index, nthreads) {
112 |
113 | float context = image[index];
114 | float r = image[index + width * height * batch];
115 | float g = image[index + width * height * batch * 2];
116 | float b = image[index + width * height * batch * 3];
117 |
118 | int context_id = 0;
119 | int r_id = floor(r / binsize);
120 | int g_id = floor(g / binsize);
121 | int b_id = floor(b / binsize);
122 |
123 | float context_d = context;
124 | float r_d = fmod(r,binsize) / binsize;
125 | float g_d = fmod(g,binsize) / binsize;
126 | float b_d = fmod(b,binsize) / binsize;
127 |
128 |
129 | int id0000 = context_id * dim * dim * dim + r_id + g_id * dim + b_id * dim * dim;
130 | int id0100 = context_id * dim * dim * dim + (r_id + 1) + g_id * dim + b_id * dim * dim;
131 | int id0010 = context_id * dim * dim * dim + r_id + (g_id + 1) * dim + b_id * dim * dim;
132 | int id0001 = context_id * dim * dim * dim + r_id + g_id * dim + (b_id + 1) * dim * dim;
133 | int id0110 = context_id * dim * dim * dim + (r_id + 1) + (g_id + 1) * dim + b_id * dim * dim;
134 | int id0011 = context_id * dim * dim * dim + r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim;
135 | int id0101 = context_id * dim * dim * dim + (r_id + 1) + g_id * dim + (b_id + 1) * dim * dim;
136 | int id0111 = context_id * dim * dim * dim + (r_id + 1) + (g_id + 1) * dim + (b_id + 1) * dim * dim;
137 |
138 | int id1000 = (context_id + 1) * dim * dim * dim + r_id + g_id * dim + b_id * dim * dim;
139 | int id1100 = (context_id + 1) * dim * dim * dim + (r_id + 1) + g_id * dim + b_id * dim * dim;
140 | int id1010 = (context_id + 1) * dim * dim * dim + r_id + (g_id + 1) * dim + b_id * dim * dim;
141 | int id1001 = (context_id + 1) * dim * dim * dim + r_id + g_id * dim + (b_id + 1) * dim * dim;
142 | int id1110 = (context_id + 1) * dim * dim * dim + (r_id + 1) + (g_id + 1) * dim + b_id * dim * dim;
143 | int id1011 = (context_id + 1) * dim * dim * dim + r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim;
144 | int id1101 = (context_id + 1) * dim * dim * dim + (r_id + 1) + g_id * dim + (b_id + 1) * dim * dim;
145 | int id1111 = (context_id + 1) * dim * dim * dim + (r_id + 1) + (g_id + 1) * dim + (b_id + 1) * dim * dim;
146 |
147 |
148 | float w0000 = (1-context_d)*(1-r_d)*(1-g_d)*(1-b_d);
149 | float w0100 = (1-context_d)*r_d*(1-g_d)*(1-b_d);
150 | float w0010 = (1-context_d)*(1-r_d)*g_d*(1-b_d);
151 | float w0001 = (1-context_d)*(1-r_d)*(1-g_d)*b_d;
152 | float w0110 = (1-context_d)*r_d*g_d*(1-b_d);
153 | float w0011 = (1-context_d)*(1-r_d)*g_d*b_d;
154 | float w0101 = (1-context_d)*r_d*(1-g_d)*b_d;
155 | float w0111 = (1-context_d)*r_d*g_d*b_d;
156 |
157 | float w1000 = context_d*(1-r_d)*(1-g_d)*(1-b_d);
158 | float w1100 = context_d*r_d*(1-g_d)*(1-b_d);
159 | float w1010 = context_d*(1-r_d)*g_d*(1-b_d);
160 | float w1001 = context_d*(1-r_d)*(1-g_d)*b_d;
161 | float w1110 = context_d*r_d*g_d*(1-b_d);
162 | float w1011 = context_d*(1-r_d)*g_d*b_d;
163 | float w1101 = context_d*r_d*(1-g_d)*b_d;
164 | float w1111 = context_d*r_d*g_d*b_d;
165 |
166 |
167 |
168 | atomicAdd(lut_grad + id0000, image_grad[index + width * height * batch] * w0000);
169 | atomicAdd(lut_grad + id0100, image_grad[index + width * height * batch] * w0100);
170 | atomicAdd(lut_grad + id0010, image_grad[index + width * height * batch] * w0010);
171 | atomicAdd(lut_grad + id0001, image_grad[index + width * height * batch] * w0001);
172 | atomicAdd(lut_grad + id0110, image_grad[index + width * height * batch] * w0110);
173 | atomicAdd(lut_grad + id0011, image_grad[index + width * height * batch] * w0011);
174 | atomicAdd(lut_grad + id0101, image_grad[index + width * height * batch] * w0101);
175 | atomicAdd(lut_grad + id0111, image_grad[index + width * height * batch] * w0111);
176 |
177 | atomicAdd(lut_grad + id1000, image_grad[index + width * height * batch] * w1000);
178 | atomicAdd(lut_grad + id1100, image_grad[index + width * height * batch] * w1100);
179 | atomicAdd(lut_grad + id1010, image_grad[index + width * height * batch] * w1010);
180 | atomicAdd(lut_grad + id1001, image_grad[index + width * height * batch] * w1001);
181 | atomicAdd(lut_grad + id1110, image_grad[index + width * height * batch] * w1110);
182 | atomicAdd(lut_grad + id1011, image_grad[index + width * height * batch] * w1011);
183 | atomicAdd(lut_grad + id1101, image_grad[index + width * height * batch] * w1101);
184 | atomicAdd(lut_grad + id1111, image_grad[index + width * height * batch] * w1111);
185 |
186 | atomicAdd(lut_grad + id0000 + shift, image_grad[index + width * height * batch * 2] * w0000);
187 | atomicAdd(lut_grad + id0100 + shift, image_grad[index + width * height * batch * 2] * w0100);
188 | atomicAdd(lut_grad + id0010 + shift, image_grad[index + width * height * batch * 2] * w0010);
189 | atomicAdd(lut_grad + id0001 + shift, image_grad[index + width * height * batch * 2] * w0001);
190 | atomicAdd(lut_grad + id0110 + shift, image_grad[index + width * height * batch * 2] * w0110);
191 | atomicAdd(lut_grad + id0011 + shift, image_grad[index + width * height * batch * 2] * w0011);
192 | atomicAdd(lut_grad + id0101 + shift, image_grad[index + width * height * batch * 2] * w0101);
193 | atomicAdd(lut_grad + id0111 + shift, image_grad[index + width * height * batch * 2] * w0111);
194 |
195 | atomicAdd(lut_grad + id1000 + shift, image_grad[index + width * height * batch * 2] * w1000);
196 | atomicAdd(lut_grad + id1100 + shift, image_grad[index + width * height * batch * 2] * w1100);
197 | atomicAdd(lut_grad + id1010 + shift, image_grad[index + width * height * batch * 2] * w1010);
198 | atomicAdd(lut_grad + id1001 + shift, image_grad[index + width * height * batch * 2] * w1001);
199 | atomicAdd(lut_grad + id1110 + shift, image_grad[index + width * height * batch * 2] * w1110);
200 | atomicAdd(lut_grad + id1011 + shift, image_grad[index + width * height * batch * 2] * w1011);
201 | atomicAdd(lut_grad + id1101 + shift, image_grad[index + width * height * batch * 2] * w1101);
202 | atomicAdd(lut_grad + id1111 + shift, image_grad[index + width * height * batch * 2] * w1111);
203 |
204 | atomicAdd(lut_grad + id0000 + shift * 2, image_grad[index + width * height * batch * 3] * w0000);
205 | atomicAdd(lut_grad + id0100 + shift * 2, image_grad[index + width * height * batch * 3] * w0100);
206 | atomicAdd(lut_grad + id0010 + shift * 2, image_grad[index + width * height * batch * 3] * w0010);
207 | atomicAdd(lut_grad + id0001 + shift * 2, image_grad[index + width * height * batch * 3] * w0001);
208 | atomicAdd(lut_grad + id0110 + shift * 2, image_grad[index + width * height * batch * 3] * w0110);
209 | atomicAdd(lut_grad + id0011 + shift * 2, image_grad[index + width * height * batch * 3] * w0011);
210 | atomicAdd(lut_grad + id0101 + shift * 2, image_grad[index + width * height * batch * 3] * w0101);
211 | atomicAdd(lut_grad + id0111 + shift * 2, image_grad[index + width * height * batch * 3] * w0111);
212 |
213 | atomicAdd(lut_grad + id1000 + shift * 2, image_grad[index + width * height * batch * 3] * w1000);
214 | atomicAdd(lut_grad + id1100 + shift * 2, image_grad[index + width * height * batch * 3] * w1100);
215 | atomicAdd(lut_grad + id1010 + shift * 2, image_grad[index + width * height * batch * 3] * w1010);
216 | atomicAdd(lut_grad + id1001 + shift * 2, image_grad[index + width * height * batch * 3] * w1001);
217 | atomicAdd(lut_grad + id1110 + shift * 2, image_grad[index + width * height * batch * 3] * w1110);
218 | atomicAdd(lut_grad + id1011 + shift * 2, image_grad[index + width * height * batch * 3] * w1011);
219 | atomicAdd(lut_grad + id1101 + shift * 2, image_grad[index + width * height * batch * 3] * w1101);
220 | atomicAdd(lut_grad + id1111 + shift * 2, image_grad[index + width * height * batch * 3] * w1111);
221 |
222 | // atomicAdd(image_grad + index, (image_grad[index + width * height * batch] + image_grad[index + width * height * batch * 2] + image_grad[index + width * height * batch * 3]) / 3);
223 |
224 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] *
225 | // ( (-1)*(1-r_d)*(1-g_d)*(1-b_d) * lut[id0000] +
226 | // 1*(1-r_d)*(1-g_d)*(1-b_d) * lut[id1000] +
227 | // (-1)*r_d*(1-g_d)*(1-b_d) * lut[id0100] +
228 | // (-1)*(1-r_d)*g_d*(1-b_d) * lut[id0010] +
229 | // (-1)*(1-r_d)*(1-g_d)*b_d * lut[id0001] +
230 | // 1*r_d*(1-g_d)*(1-b_d) * lut[id1100] +
231 | // (-1)*r_d*g_d*(1-b_d) * lut[id0110] +
232 | // (-1)*(1-r_d)*g_d*b_d * lut[id0011] +
233 | // 1*(1-r_d)*g_d*(1-b_d) * lut[id1010] +
234 | // 1*(1-r_d)*(1-g_d)*b_d * lut[id1001] +
235 | // (-1)*r_d*(1-g_d)*b_d * lut[id0101] +
236 | // 1*r_d*g_d*(1-b_d) * lut[id1110] +
237 | // (-1)*r_d*g_d*b_d * lut[id0111] +
238 | // 1*r_d*(1-g_d)*b_d * lut[id1101] +
239 | // 1*(1-r_d)*g_d*b_d * lut[id1011] +
240 | // 1*r_d*g_d*b_d * lut[id1111]
241 | // )
242 | // );
243 |
244 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * ((-1)*(1-r_d)*(1-g_d)*(1-b_d)*lut[id0000]));
245 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * (1*(1-r_d)*(1-g_d)*(1-b_d) * lut[id1000]));
246 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * ((-1)*r_d*(1-g_d)*(1-b_d) * lut[id0100]));
247 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * ((-1)*(1-r_d)*g_d*(1-b_d) * lut[id0010]));
248 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * ((-1)*(1-r_d)*(1-g_d)*b_d * lut[id0001]));
249 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * (1*r_d*(1-g_d)*(1-b_d) * lut[id1100]));
250 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * ((-1)*r_d*g_d*(1-b_d) * lut[id0110]));
251 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * ((-1)*(1-r_d)*g_d*b_d * lut[id0011]));
252 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * (1*(1-r_d)*g_d*(1-b_d) * lut[id1010]));
253 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * (1*(1-r_d)*(1-g_d)*b_d * lut[id1001]));
254 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * ((-1)*r_d*(1-g_d)*b_d * lut[id0101]));
255 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * (1*r_d*g_d*(1-b_d) * lut[id1110]));
256 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * ((-1)*r_d*g_d*b_d * lut[id0111]));
257 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * (1*r_d*(1-g_d)*b_d * lut[id1101]));
258 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * (1*(1-r_d)*g_d*b_d * lut[id1011]));
259 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch] * (1*r_d*g_d*b_d * lut[id1111]));
260 |
261 |
262 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch * 2] *
263 | // ( (-1)*(1-r_d)*(1-g_d)*(1-b_d) * lut[id0000] +
264 | // 1*(1-r_d)*(1-g_d)*(1-b_d) * lut[id1000] +
265 | // (-1)*r_d*(1-g_d)*(1-b_d) * lut[id0100] +
266 | // (-1)*(1-r_d)*g_d*(1-b_d) * lut[id0010] +
267 | // (-1)*(1-r_d)*(1-g_d)*b_d * lut[id0001] +
268 | // 1*r_d*(1-g_d)*(1-b_d) * lut[id1100] +
269 | // (-1)*r_d*g_d*(1-b_d) * lut[id0110] +
270 | // (-1)*(1-r_d)*g_d*b_d * lut[id0011] +
271 | // 1*(1-r_d)*g_d*(1-b_d) * lut[id1010] +
272 | // 1*(1-r_d)*(1-g_d)*b_d * lut[id1001] +
273 | // (-1)*r_d*(1-g_d)*b_d * lut[id0101] +
274 | // 1*r_d*g_d*(1-b_d) * lut[id1110] +
275 | // (-1)*r_d*g_d*b_d * lut[id0111] +
276 | // 1*r_d*(1-g_d)*b_d * lut[id1101] +
277 | // 1*(1-r_d)*g_d*b_d * lut[id1011] +
278 | // 1*r_d*g_d*b_d * lut[id1111]
279 | // )
280 | // );
281 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch * 3] *
282 | // ( (-1)*(1-r_d)*(1-g_d)*(1-b_d) * lut[id0000] +
283 | // 1*(1-r_d)*(1-g_d)*(1-b_d) * lut[id1000] +
284 | // (-1)*r_d*(1-g_d)*(1-b_d) * lut[id0100] +
285 | // (-1)*(1-r_d)*g_d*(1-b_d) * lut[id0010] +
286 | // (-1)*(1-r_d)*(1-g_d)*b_d * lut[id0001] +
287 | // 1*r_d*(1-g_d)*(1-b_d) * lut[id1100] +
288 | // (-1)*r_d*g_d*(1-b_d) * lut[id0110] +
289 | // (-1)*(1-r_d)*g_d*b_d * lut[id0011] +
290 | // 1*(1-r_d)*g_d*(1-b_d) * lut[id1010] +
291 | // 1*(1-r_d)*(1-g_d)*b_d * lut[id1001] +
292 | // (-1)*r_d*(1-g_d)*b_d * lut[id0101] +
293 | // 1*r_d*g_d*(1-b_d) * lut[id1110] +
294 | // (-1)*r_d*g_d*b_d * lut[id0111] +
295 | // 1*r_d*(1-g_d)*b_d * lut[id1101] +
296 | // 1*(1-r_d)*g_d*b_d * lut[id1011] +
297 | // 1*r_d*g_d*b_d * lut[id1111]
298 | // )
299 | // );
300 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch * 2]);
301 | // atomicAdd(image_grad + index, image_grad[index + width * height * batch * 3]);
302 |
303 |
304 |
305 | // float i000 = lut[id1000]-lut[id0000];
306 | // float i100 = lut[id1100]-lut[id0100];
307 | // float i010 = lut[id1010]-lut[id0010];
308 | // float i001 = lut[id1001]-lut[id0001];
309 | // float i110 = lut[id1110]-lut[id0110];
310 | // float i011 = lut[id1011]-lut[id0011];
311 | // float i101 = lut[id1101]-lut[id0101];
312 | // float i111 = lut[id1111]-lut[id0111];
313 |
314 | float w000 = (1-r_d)*(1-g_d)*(1-b_d);
315 | float w100 = r_d*(1-g_d)*(1-b_d);
316 | float w010 = (1-r_d)*g_d*(1-b_d);
317 | float w001 = (1-r_d)*(1-g_d)*b_d;
318 | float w110 = r_d*g_d*(1-b_d);
319 | float w011 = (1-r_d)*g_d*b_d;
320 | float w101 = r_d*(1-g_d)*b_d;
321 | float w111 = r_d*g_d*b_d;
322 |
323 | atomicAdd(image_grad + index, image_grad[index + width * height * batch] * w000 * binsize);
324 | atomicAdd(image_grad + index, image_grad[index + width * height * batch] * w100 * binsize);
325 | atomicAdd(image_grad + index, image_grad[index + width * height * batch] * w010 * binsize);
326 | atomicAdd(image_grad + index, image_grad[index + width * height * batch] * w001 * binsize);
327 | atomicAdd(image_grad + index, image_grad[index + width * height * batch] * w110 * binsize);
328 | atomicAdd(image_grad + index, image_grad[index + width * height * batch] * w011 * binsize);
329 | atomicAdd(image_grad + index, image_grad[index + width * height * batch] * w101 * binsize);
330 | atomicAdd(image_grad + index, image_grad[index + width * height * batch] * w111 * binsize);
331 |
332 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 2] * w000 * binsize);
333 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 2] * w100 * binsize);
334 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 2] * w010 * binsize);
335 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 2] * w001 * binsize);
336 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 2] * w110 * binsize);
337 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 2] * w011 * binsize);
338 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 2] * w101 * binsize);
339 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 2] * w111 * binsize);
340 |
341 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 3] * w000 * binsize);
342 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 3] * w100 * binsize);
343 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 3] * w010 * binsize);
344 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 3] * w001 * binsize);
345 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 3] * w110 * binsize);
346 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 3] * w011 * binsize);
347 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 3] * w101 * binsize);
348 | atomicAdd(image_grad + index, image_grad[index + width * height * batch * 3] * w111 * binsize);
349 | }
350 | }
351 |
352 | int QuadriLinearBackwardLaucher(const float* image, float* image_grad, const float* lut, float* lut_grad, const int lut_dim, const int shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream) {
353 | const int kThreadsPerBlock = 1024;
354 | const int output_size = height * width * batch;
355 | cudaError_t err;
356 |
357 | QuadriLinearBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(output_size, image, image_grad, lut, lut_grad, lut_dim, shift, binsize, width, height, batch);
358 |
359 | err = cudaGetLastError();
360 | if(cudaSuccess != err) {
361 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) );
362 | exit( -1 );
363 | }
364 |
365 | return 1;
366 | }
367 |
--------------------------------------------------------------------------------
/quadrilinear_cpp/src/quadrilinear4d_kernel.h:
--------------------------------------------------------------------------------
1 | #ifndef _QUADRILINEAR4D_KERNEL
2 | #define _QUADRILINEAR4D_KERNEL
3 |
4 | #include <cuda_runtime.h>
5 |
6 | __global__ void QuadriLinearForward(const int nthreads, const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int batch);
7 |
8 | int QuadriLinearForwardLaucher(const float* lut, const float* image, float* output, const int lut_dim, const int shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream);
9 |
10 | __global__ void QuadriLinearBackward(const int nthreads, const float* image, float* image_grad, const float* lut, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int batch);
11 |
12 | int QuadriLinearBackwardLaucher(const float* image, float* image_grad, const float* lut, float* lut_grad, const int lut_dim, const int shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream);
13 |
14 |
15 | #endif
16 |
17 |
--------------------------------------------------------------------------------
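These launchers are reached from Python through a custom `torch.autograd.Function`; in this repo the actual binding goes through `quadrilinear4d_cuda.cpp` and is wrapped as `QuadrilinearInterpolation_4D` in `models_x.py` (which `train.py` calls and unpacks as a tuple). The sketch below only shows the general shape of such a wrapper; the extension module name `quadrilinear4d_ext` and its `forward`/`backward` entry points are hypothetical, not taken from this repo:
```
import torch
import quadrilinear4d_ext  # hypothetical name for the compiled extension


class QuadrilinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, lut, image):
        # image: (B, 4, H, W) context + RGB; output: (B, 3, H, W) enhanced RGB
        output = image.new_zeros(image.size(0), 3, image.size(2), image.size(3))
        quadrilinear4d_ext.forward(lut, image, output)  # assumed entry point
        ctx.save_for_backward(lut, image)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        lut, image = ctx.saved_tensors
        lut_grad = torch.zeros_like(lut)
        image_grad = torch.zeros_like(image)
        # assumed entry point: fills lut_grad and image_grad from grad_output
        quadrilinear4d_ext.backward(grad_output.contiguous(), image, image_grad, lut, lut_grad)
        return lut_grad, image_grad
```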
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import math
5 | import itertools
6 | import time
7 | import datetime
8 | import sys
9 | from skimage.metrics import structural_similarity as ssim
10 | import torchvision.transforms as transforms
11 | from torchvision.utils import save_image
12 |
13 | from torch.utils.data import DataLoader
14 | from torchvision import datasets
15 | from torch.autograd import Variable
16 |
17 | from models_x import *
18 | from datasets import *
19 |
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | import torch
23 |
24 | # select the GPU with the CUDA_VISIBLE_DEVICES environment variable
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from, 0 starts from scratch, >0 starts from saved checkpoints")
27 | parser.add_argument("--n_epochs", type=int, default=1000, help="total number of epochs of training")
28 | parser.add_argument("--dataset_name", type=str, default="fiveK", help="name of the dataset")
29 | parser.add_argument("--input_color_space", type=str, default="sRGB", help="input color space: sRGB or XYZ")
30 | parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
31 | parser.add_argument("--lr", type=float, default=0.0001, help="adam: learning rate")
32 | parser.add_argument("--b1", type=float, default=0.9, help="adam: decay of first order momentum of gradient")
33 | parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
34 | parser.add_argument("--lambda_smooth", type=float, default=0.0001, help="smooth regularization")
35 | parser.add_argument("--lambda_monotonicity", type=float, default=10.0, help="monotonicity regularization")
36 | parser.add_argument("--n_cpu", type=int, default=1, help="number of cpu threads to use during batch generation")
37 | parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between model checkpoints")
38 | parser.add_argument("--output_dir", type=str, default="LUTs/fiveK", help="path to save model")
39 | opt = parser.parse_args()
40 |
41 | opt.output_dir = opt.output_dir + '_' + opt.input_color_space
42 | print(opt)
43 |
44 | os.makedirs("saved_models/%s" % opt.output_dir, exist_ok=True)
45 |
46 | cuda = torch.cuda.is_available()
47 | # Tensor type
48 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
49 |
50 | # Loss functions
51 | criterion_pixelwise = torch.nn.MSELoss()
52 |
53 | # Initialize the 4D LUT and the generator networks
54 | LUT_enhancement = Generator4DLUT_identity()
55 | Generator_bias = Generator_for_bias()
56 | Generator_context = Generator_for_info()
57 | TV4 = TV_4D()
58 | quadrilinear_enhancement_ = QuadrilinearInterpolation_4D()
59 |
60 | if cuda:
61 | LUT_enhancement = LUT_enhancement.cuda()
62 | Generator_bias = Generator_bias.cuda()
63 | Generator_context = Generator_context.cuda()
64 | criterion_pixelwise.cuda()
65 | TV4.cuda()
66 | TV4.weight_r = TV4.weight_r.type(Tensor)
67 | TV4.weight_g = TV4.weight_g.type(Tensor)
68 | TV4.weight_b = TV4.weight_b.type(Tensor)
69 |
70 | if opt.epoch != 0:
71 | LUT_enhancements = torch.load("saved_models/%s/4DLUTs_enhancement_%d.pth" % (opt.output_dir, opt.epoch))
72 | LUT_enhancement.load_state_dict(LUT_enhancements)
73 | Generator_bias.load_state_dict(torch.load("saved_models/%s/generator_bias_%d.pth" % (opt.output_dir, opt.epoch)))
74 | Generator_context.load_state_dict(torch.load("saved_models/%s/generator_context_%d.pth" % (opt.output_dir, opt.epoch)))
75 | else:
76 | # Initialize weights
77 | Generator_bias.apply(weights_init_normal_generator)
78 | torch.nn.init.constant_(Generator_bias.model[16].bias.data, 1.0)
79 |
80 | # Optimizers
81 | optimizer_G = torch.optim.Adam(itertools.chain(Generator_bias.parameters(), Generator_context.parameters(), LUT_enhancement.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2))
82 |
83 | if opt.input_color_space == 'sRGB':
84 | dataloader = DataLoader(
85 | ImageDataset_sRGB("/fivek_dataset/MIT-Adobe5k-UPE/", mode="train"),
86 | batch_size=opt.batch_size,
87 | shuffle=True,
88 | num_workers=opt.n_cpu,
89 | )
90 |
91 | psnr_dataloader = DataLoader(
92 | ImageDataset_sRGB("/fivek_dataset/MIT-Adobe5k-UPE/", mode="test"),
93 | batch_size=1,
94 | shuffle=False,
95 | num_workers=opt.n_cpu,
96 | )
97 |
98 |
99 | def generator_train(img):
100 |
101 | context = Generator_context(img)
102 | pred = Generator_bias(img)
103 |
104 | context = context.new(context.size())
105 |
106 | context = Variable(context.fill_(0).type(Tensor))
107 |
108 | pred = pred.squeeze(2).squeeze(2)
109 | combine = torch.cat([context,img],1)
110 |
111 | gen_A0 = LUT_enhancement(combine)
112 |
113 | weights_norm = torch.mean(pred ** 2)
114 |
115 | combine_A = img.new(img.size())
116 | for b in range(img.size(0)):
117 | combine_A[b,0,:,:] = pred[b,0] * gen_A0[b,0,:,:] + pred[b,1] * gen_A0[b,1,:,:] + pred[b,2] * gen_A0[b,2,:,:] + pred[b,9]
118 | combine_A[b,1,:,:] = pred[b,3] * gen_A0[b,0,:,:] + pred[b,4] * gen_A0[b,1,:,:] + pred[b,5] * gen_A0[b,2,:,:] + pred[b,10]
119 | combine_A[b,2,:,:] = pred[b,6] * gen_A0[b,0,:,:] + pred[b,7] * gen_A0[b,1,:,:] + pred[b,8] * gen_A0[b,2,:,:] + pred[b,11]
120 |
121 | return combine_A, weights_norm
122 |
123 | def generator_eval(img):
124 | context = Generator_context(img)
125 | pred = Generator_bias(img)
126 |
127 | context = context.new(context.size())
128 | context = Variable(context.fill_(0).type(Tensor))
129 |
130 | pred = pred.squeeze(2).squeeze(2).squeeze(0)
131 |
132 | combine = torch.cat([context,img],1)
133 |
134 | new_LUT_enhancement = LUT_enhancement.LUT_en.new(LUT_enhancement.LUT_en.size())
135 | new_LUT_enhancement[0] = pred[0] * LUT_enhancement.LUT_en[0] + pred[1] * LUT_enhancement.LUT_en[1] + pred[2] * LUT_enhancement.LUT_en[2] + pred[9]
136 | new_LUT_enhancement[1] = pred[3] * LUT_enhancement.LUT_en[0] + pred[4] * LUT_enhancement.LUT_en[1] + pred[5] * LUT_enhancement.LUT_en[2] + pred[10]
137 | new_LUT_enhancement[2] = pred[6] * LUT_enhancement.LUT_en[0] + pred[7] * LUT_enhancement.LUT_en[1] + pred[8] * LUT_enhancement.LUT_en[2] + pred[11]
138 |
139 | weights_norm = torch.mean(pred[0] ** 2)
140 | combine_A = img.new(img.size())
141 | _, combine_A = quadrilinear_enhancement_(new_LUT_enhancement,combine)
142 |
143 | return combine_A, weights_norm
144 |
145 | def calculate_psnr():
146 | Generator_bias.eval()
147 | Generator_context.eval()
148 | avg_psnr = 0
149 | for i, batch in enumerate(psnr_dataloader):
150 | real_A = Variable(batch["A_input"].type(Tensor))
151 | real_B = Variable(batch["A_exptC"].type(Tensor))
152 | fake_B, _ = generator_eval(real_A)
153 | fake_B = torch.clamp(fake_B,0.0,1.0)
154 | fake_B = torch.round(fake_B*255)
155 | real_B = torch.round(real_B*255)
156 | mse = criterion_pixelwise(fake_B, real_B)
157 | psnr = 10 * math.log10(255.0 * 255.0 / mse.cpu().detach().item())
158 | avg_psnr += psnr
159 | return avg_psnr / len(psnr_dataloader)
160 |
161 | def calculate_ssim():
162 | Generator_bias.eval()
163 | Generator_context.eval()
164 | avg_ssim = 0
165 | for i, batch in enumerate(psnr_dataloader):
166 | real_A = Variable(batch["A_input"].type(Tensor))
167 | real_B = Variable(batch["A_exptC"].type(Tensor))
168 | fake_B, _ = generator_eval(real_A)
169 | fake_B = torch.clamp(fake_B,0.0,1.0)
170 | fake_B = fake_B.squeeze(0).cpu().detach().numpy()
171 | real_B = real_B.squeeze(0).cpu().detach().numpy()
172 | fake_B = np.swapaxes(np.swapaxes(fake_B, 0, 2), 0, 1)
173 | real_B = np.swapaxes(np.swapaxes(real_B, 0, 2), 0, 1)
174 | fake_B = fake_B.astype(np.float32)
175 | real_B = real_B.astype(np.float32)
176 | ssim_val = ssim(real_B, fake_B, data_range=real_B.max() - real_B.min(), multichannel=True, gaussian_weights=True, win_size=11)
177 | avg_ssim += ssim_val
178 | return avg_ssim / len(psnr_dataloader)
179 |
180 |
181 |
182 | def visualize_result(epoch):
183 | """Saves a generated sample from the validation set"""
184 | Generator_bias.eval()
185 | Generator_context.eval()
186 | os.makedirs("images/%s/" % opt.output_dir +str(epoch), exist_ok=True)
187 | for i, batch in enumerate(psnr_dataloader):
188 | real_A = Variable(batch["A_input"].type(Tensor))
189 | real_B = Variable(batch["A_exptC"].type(Tensor))
190 | img_name = batch["input_name"]
191 | fake_B, _ = generator_eval(real_A)
192 | fake_B = torch.clamp(fake_B,0.0,1.0)
193 | img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -1)
194 | fake_B = torch.round(fake_B*255)
195 | real_B = torch.round(real_B*255)
196 | mse = criterion_pixelwise(fake_B, real_B)
197 | psnr = 10 * math.log10(255.0 * 255.0 / mse.item())
198 | save_image(img_sample, "images/%s/%s/%s.jpg" % (opt.output_dir,epoch, img_name[0]+'_'+str(psnr)[:5]), nrow=3, normalize=False)
199 |
200 | # ----------
201 | # Training
202 | # ----------
203 | prev_time = time.time()
204 | max_psnr = 0
205 | max_epoch = 0
206 | for epoch in range(opt.epoch, opt.n_epochs):
207 | mse_avg = 0
208 | psnr_avg = 0
209 | Generator_bias.train()
210 | Generator_context.train()
211 | for i, batch in enumerate(dataloader):
212 | # Model inputs
213 | real_A = Variable(batch["A_input"].type(Tensor))
214 | real_B = Variable(batch["A_exptC"].type(Tensor))
215 | # ------------------
216 | # Train Generators
217 | # ------------------
218 |
219 | optimizer_G.zero_grad()
220 |
221 | fake_B, weights_norm = generator_train(real_A)
222 |
223 | # Pixel-wise loss
224 | mse = criterion_pixelwise(fake_B, real_B)
225 |
226 | tv_enhancement, mn_enhancement = TV4(LUT_enhancement)
227 |
228 | tv_cons = tv_enhancement
229 | mn_cons = mn_enhancement
230 |
231 | # loss = mse
232 | loss = mse + opt.lambda_smooth * (weights_norm + tv_cons) + opt.lambda_monotonicity * mn_cons
233 | psnr_avg += 10 * math.log10(1 / mse.item())
234 |
235 | mse_avg += mse.item()
236 |
237 | loss.backward()
238 |
239 | optimizer_G.step()
240 |
241 |
242 | # --------------
243 | # Log Progress
244 | # --------------
245 |
246 | # Determine approximate time left
247 | batches_done = epoch * len(dataloader) + i
248 | batches_left = opt.n_epochs * len(dataloader) - batches_done
249 | time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
250 | prev_time = time.time()
251 |
252 | # Print log
253 | if i % 500 == 0:
254 | sys.stdout.write(
255 | "\r[Epoch %d/%d] [Batch %d/%d] [psnr: %f, tv: %f, mn: %f] ETA: %s"
256 | % (epoch,opt.n_epochs,i,len(dataloader),psnr_avg / (i+1),tv_cons, mn_cons, time_left,
257 | )
258 | )
265 | avg_ssim = calculate_ssim()
266 | avg_psnr = calculate_psnr()
267 | if avg_psnr > max_psnr:
268 | max_psnr = avg_psnr
269 | max_epoch = epoch
270 |
271 | LUTs_enhancement = LUT_enhancement.state_dict()
272 | torch.save(LUTs_enhancement, "saved_models/%s/4DLUTs_enhancement_%d.pth" % (opt.output_dir, epoch))
273 | torch.save(Generator_bias.state_dict(), "saved_models/%s/generator_bias_%d.pth" % (opt.output_dir, epoch))
274 | torch.save(Generator_context.state_dict(), "saved_models/%s/generator_context_%d.pth" % (opt.output_dir, epoch))
275 | file = open('saved_models/%s/result.txt' % opt.output_dir,'a')
276 | file.write(" [PSNR: %f , SSIM: %f] [epoch: %d]\n"% (max_psnr,avg_ssim, max_epoch))
277 | file.close()
278 |
279 | sys.stdout.write(" [PSNR: %f, SSIM: %f] [max PSNR: %f, epoch: %d]\n"% (avg_psnr, avg_ssim, max_psnr, max_epoch))
281 |
282 | if (epoch+1) % 100 == 0:
283 | visualize_result(epoch+1)
284 |
285 |
--------------------------------------------------------------------------------
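As a standalone sanity check of the 8-bit PSNR convention used by `calculate_psnr()` in `train.py` (clamp to [0, 1], scale and round to 255 levels, then `10 * log10(255^2 / MSE)`), here is a minimal sketch on synthetic tensors:
```
import math
import torch
import torch.nn.functional as F

# Two nearly identical images: a uniform offset of 0.01 (~2.55 of 255 levels).
fake_B = torch.rand(1, 3, 64, 64)
real_B = (fake_B + 0.01).clamp(0.0, 1.0)

fake_q = torch.round(torch.clamp(fake_B, 0.0, 1.0) * 255)
real_q = torch.round(real_B * 255)
mse = F.mse_loss(fake_q, real_q)
psnr = 10 * math.log10(255.0 * 255.0 / mse.item())
print("PSNR: %.2f dB" % psnr)  # roughly 40 dB for this offset
```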