├── .gitignore
├── README.md
├── dataset.py
├── ds
│   ├── input
│   │   └── img19.jpg
│   └── output
│       └── img19.jpg
├── metrics.py
├── model.py
├── ops
│   ├── __init__.py
│   ├── cuda
│   │   ├── bilateral_slice.cu.cc
│   │   └── bilateral_slice.h
│   └── ops.py
├── requirements.txt
├── slice.py
├── test.py
├── train.py
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
algo/
ch/
ds2/
out.png

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Deep Bilateral Learning for Real-Time Image Enhancement
Unofficial PyTorch implementation of 'Deep Bilateral Learning for Real-Time Image Enhancement', SIGGRAPH 2017: https://groups.csail.mit.edu/graphics/hdrnet/

Requires Python 3.6.

### Dependencies

To install the Python dependencies, run:

    pip install -r requirements.txt

## Datasets
Adobe FiveK - https://data.csail.mit.edu/graphics/fivek/

## Usage

To train a model, run the following command:

    python train.py --test-image=./DSC_1177.jpg --dataset=/dataset_path --lr=0.0001

To list all training parameters, run:

    python train.py -h

To enhance an image with a trained checkpoint, run:

    python test.py --checkpoint=./ch/ckpt_0_4000.pth --input=./DSC_1177.jpg --output=out.png
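The `--dataset` directory must contain paired `input`/`output` folders whose files share names (see `HDRDataset` in `dataset.py`); the bundled `ds` folder shows the expected layout:

    ds
    ├── input
    │   └── img19.jpg
    └── output
        └── img19.jpg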
## Known issues

* Only the pointwise network (HDRPointwiseNN) is implemented so far
* The dataset has no augmentation, which makes training more difficult
* No raw HDR input

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import os
from PIL import Image

import torch
from torchvision import transforms
from torch.utils.data import Dataset

class HDRDataset(Dataset):
    def __init__(self, image_path, params=None, suffix='', aug=False):
        self.image_path = image_path
        self.suffix = suffix
        self.aug = aug
        self.in_files = self.list_files(os.path.join(image_path, 'input'+suffix))
        ls = params['net_input_size']
        fs = params['net_output_size']
        self.low = transforms.Compose([
            transforms.Resize((ls,ls), Image.BICUBIC),
            transforms.ToTensor()
        ])
        self.correction = transforms.Compose([
            transforms.ColorJitter(brightness=0.5, contrast=0.2, saturation=0.2, hue=0),
        ])
        self.out = transforms.Compose([
            transforms.Resize((fs,fs), Image.BICUBIC),
            transforms.ToTensor()
        ])
        self.full = transforms.Compose([
            transforms.Resize((fs,fs), Image.BICUBIC),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.in_files)

    def __getitem__(self, idx):
        fname = os.path.split(self.in_files[idx])[-1]
        imagein = Image.open(self.in_files[idx]).convert('RGB')
        imageout = Image.open(os.path.join(self.image_path, 'output'+self.suffix, fname)).convert('RGB')
        if self.aug:
            imagein = self.correction(imagein)
        imagein_low = self.low(imagein)
        imagein_full = self.full(imagein)
        imageout = self.out(imageout)

        return imagein_low, imagein_full, imageout

    def list_files(self, in_path):
        files = []
        for (dirpath, dirnames, filenames) in os.walk(in_path):
            files.extend(filenames)
            break
        files = sorted([os.path.join(in_path, x) for x in files])
        return files
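
# Minimal usage sketch (illustrative): the dataset path and the parameter
# values below are placeholders mirroring the defaults in train.py.
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    params = {'net_input_size': 256, 'net_output_size': 512}
    ds = HDRDataset('./ds', params=params)
    loader = DataLoader(ds, batch_size=1, shuffle=True)
    low, full, target = next(iter(loader))
    # low: (1, 3, 256, 256); full/target: (1, 3, 512, 512)
    print(low.shape, full.shape, target.shape)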
--------------------------------------------------------------------------------
/ds/input/img19.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/creotiv/hdrnet-pytorch/5335ce7e0e32c4c7416f562bf81175963865e93d/ds/input/img19.jpg

--------------------------------------------------------------------------------
/ds/output/img19.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/creotiv/hdrnet-pytorch/5335ce7e0e32c4c7416f562bf81175963865e93d/ds/output/img19.jpg

--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
import torch
import numpy as np


def psnr(target, prediction):
    x = (target-prediction)**2
    x = x.view(x.shape[0], -1)
    p = torch.mean((-10/np.log(10))*torch.log(torch.mean(x, 1)))
    return p

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict
from slice import bilateral_slice


class L2LOSS(nn.Module):

    def forward(self, x, y):
        return torch.mean((x-y)**2)

class ConvBlock(nn.Module):
    def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, use_bias=True, activation=nn.ReLU, batch_norm=False):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(int(inc), int(outc), kernel_size, padding=padding, stride=stride, bias=use_bias)
        self.activation = activation() if activation else None
        self.bn = nn.BatchNorm2d(outc) if batch_norm else None

        if use_bias and not batch_norm:
            self.conv.bias.data.fill_(0.00)
        # aka TF variance_scaling_initializer
        torch.nn.init.kaiming_uniform_(self.conv.weight)#, mode='fan_out',nonlinearity='relu')

    def forward(self, x):
        x = self.conv(x)
        if self.bn:
            x = self.bn(x)
        if self.activation:
            x = self.activation(x)
        return x

class FC(nn.Module):
    def __init__(self, inc, outc, activation=nn.ReLU, batch_norm=False):
        super(FC, self).__init__()
        self.fc = nn.Linear(int(inc), int(outc), bias=(not batch_norm))
        self.activation = activation() if activation else None
        self.bn = nn.BatchNorm1d(outc) if batch_norm else None

        if not batch_norm:
            self.fc.bias.data.fill_(0.00)
        # aka TF variance_scaling_initializer
        torch.nn.init.kaiming_uniform_(self.fc.weight)#, mode='fan_out',nonlinearity='relu')

    def forward(self, x):
        x = self.fc(x)
        if self.bn:
            x = self.bn(x)
        if self.activation:
            x = self.activation(x)
        return x

class Slice(nn.Module):
    def __init__(self):
        super(Slice, self).__init__()

    def forward(self, bilateral_grid, guidemap):
        bilateral_grid = bilateral_grid.permute(0,3,4,2,1)
        guidemap = guidemap.squeeze(1)
        # grid: The bilateral grid with shape (gh, gw, gd, gc).
        # guide: A guide image with shape (h, w). Values must be in the range [0, 1].
        coeffs = bilateral_slice(bilateral_grid, guidemap).permute(0,3,1,2)
        return coeffs
        # Nx12x8x16x16
        # device = bilateral_grid.get_device()
        # N, _, H, W = guidemap.shape
        # hg, wg = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)]) # [0,511] HxW
        # if device >= 0:
        #     hg = hg.to(device)
        #     wg = wg.to(device)
        # hg = hg.float().repeat(N, 1, 1).unsqueeze(3) / (H-1) # * 2 - 1 # norm to [-1,1] NxHxWx1
        # wg = wg.float().repeat(N, 1, 1).unsqueeze(3) / (W-1) # * 2 - 1 # norm to [-1,1] NxHxWx1
        # guidemap = guidemap.permute(0,2,3,1).contiguous()
        # guidemap_guide = torch.cat([hg, wg, guidemap], dim=3).unsqueeze(1) # Nx1xHxWx3
        # # When mode='bilinear' and the input is 5-D, the interpolation mode used internally will actually be trilinear.
        # coeff = F.grid_sample(bilateral_grid, guidemap_guide, 'bilinear')#, align_corners=True)
        # return coeff.squeeze(2)

class ApplyCoeffs(nn.Module):
    def __init__(self):
        super(ApplyCoeffs, self).__init__()

    def forward(self, coeff, full_res_input):

        '''
        Affine:
        r = a11*r + a12*g + a13*b + a14
        g = a21*r + a22*g + a23*b + a24
        ...
        '''

        # out_channels = []
        # for chan in range(n_out):
        #     ret = scale[:, :, :, chan, 0]*input_image[:, :, :, 0]
        #     for chan_i in range(1, n_in):
        #         ret += scale[:, :, :, chan, chan_i]*input_image[:, :, :, chan_i]
        #     if has_affine_term:
        #         ret += offset[:, :, :, chan]
        #     ret = tf.expand_dims(ret, 3)
        #     out_channels.append(ret)
        # ret = tf.concat(out_channels, 3)
        """
        R = r1[0]*r2 + r1[1]*g2 + r1[2]*b2 + r1[3]
        """

        # R = torch.sum(full_res_input * coeff[:, 0:3, :, :], dim=1, keepdim=True) + coeff[:, 3:4, :, :]
        # G = torch.sum(full_res_input * coeff[:, 4:7, :, :], dim=1, keepdim=True) + coeff[:, 7:8, :, :]
        # B = torch.sum(full_res_input * coeff[:, 8:11, :, :], dim=1, keepdim=True) + coeff[:, 11:12, :, :]
        R = torch.sum(full_res_input * coeff[:, 0:3, :, :], dim=1, keepdim=True) + coeff[:, 9:10, :, :]
        G = torch.sum(full_res_input * coeff[:, 3:6, :, :], dim=1, keepdim=True) + coeff[:, 10:11, :, :]
        B = torch.sum(full_res_input * coeff[:, 6:9, :, :], dim=1, keepdim=True) + coeff[:, 11:12, :, :]

        return torch.cat([R, G, B], dim=1)
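
# Shape sanity check (illustrative): the 12 sliced coefficients per pixel
# are a 3x3 color matrix (channels 0-8) plus 3 offsets (channels 9-11).
# coeff = torch.randn(1, 12, 64, 64)
# image = torch.randn(1, 3, 64, 64)
# out = ApplyCoeffs()(coeff, image)   # -> torch.Size([1, 3, 64, 64])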

class GuideNN(nn.Module):
    def __init__(self, params=None):
        super(GuideNN, self).__init__()
        self.params = params
        self.conv1 = ConvBlock(3, params['guide_complexity'], kernel_size=1, padding=0, batch_norm=True)
        self.conv2 = ConvBlock(params['guide_complexity'], 1, kernel_size=1, padding=0, activation=nn.Sigmoid) #nn.Tanh nn.Sigmoid

    def forward(self, x):
        return self.conv2(self.conv1(x))#.squeeze(1)

class Coeffs(nn.Module):

    def __init__(self, nin=4, nout=3, params=None):
        super(Coeffs, self).__init__()
        self.params = params
        self.nin = nin
        self.nout = nout

        lb = params['luma_bins']
        cm = params['channel_multiplier']
        sb = params['spatial_bin']
        bn = params['batch_norm']
        nsize = params['net_input_size']

        self.relu = nn.ReLU()

        # splat features
        n_layers_splat = int(np.log2(nsize/sb))
        self.splat_features = nn.ModuleList()
        prev_ch = 3
        for i in range(n_layers_splat):
            use_bn = bn if i > 0 else False
            self.splat_features.append(ConvBlock(prev_ch, cm*(2**i)*lb, 3, stride=2, batch_norm=use_bn))
            prev_ch = splat_ch = cm*(2**i)*lb

        # global features
        n_layers_global = int(np.log2(sb/4))
        self.global_features_conv = nn.ModuleList()
        self.global_features_fc = nn.ModuleList()
        for i in range(n_layers_global):
            self.global_features_conv.append(ConvBlock(prev_ch, cm*8*lb, 3, stride=2, batch_norm=bn))
            prev_ch = cm*8*lb

        n_total = n_layers_splat + n_layers_global
        prev_ch = int(prev_ch * (nsize/2**n_total)**2)
        self.global_features_fc.append(FC(prev_ch, 32*cm*lb, batch_norm=bn))
        self.global_features_fc.append(FC(32*cm*lb, 16*cm*lb, batch_norm=bn))
        self.global_features_fc.append(FC(16*cm*lb, 8*cm*lb, activation=None, batch_norm=bn))

        # local features
        self.local_features = nn.ModuleList()
        self.local_features.append(ConvBlock(splat_ch, 8*cm*lb, 3, batch_norm=bn))
        self.local_features.append(ConvBlock(8*cm*lb, 8*cm*lb, 3, activation=None, use_bias=False))

        # prediction
        self.conv_out = ConvBlock(8*cm*lb, lb*nout*nin, 1, padding=0, activation=None)#,batch_norm=True)


    def forward(self, lowres_input):
        params = self.params
        bs = lowres_input.shape[0]
        lb = params['luma_bins']
        cm = params['channel_multiplier']
        sb = params['spatial_bin']

        x = lowres_input
        for layer in self.splat_features:
            x = layer(x)
        splat_features = x

        for layer in self.global_features_conv:
            x = layer(x)
        x = x.view(bs, -1)
        for layer in self.global_features_fc:
            x = layer(x)
        global_features = x

        x = splat_features
        for layer in self.local_features:
            x = layer(x)
        local_features = x

        fusion_grid = local_features
        fusion_global = global_features.view(bs, 8*cm*lb, 1, 1)
        fusion = self.relu(fusion_grid + fusion_global)

        x = self.conv_out(fusion)
        s = x.shape
        y = torch.stack(torch.split(x, self.nin*self.nout, 1), 2)
        # y = torch.stack(torch.split(y, self.nin, 1),3)
        # x = x.view(bs,self.nin*self.nout,lb,sb,sb) # B x Coefs x Luma x Spatial x Spatial
        return y
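
# With the train.py defaults (luma_bins=8, channel_multiplier=1,
# spatial_bin=16, nin=4, nout=3) conv_out emits 8*3*4 = 96 channels on a
# 16x16 map, which forward() regroups into a bilateral grid of shape
# (bs, 12, 8, 16, 16): 12 affine coefficients per luma bin and spatial cell.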

class HDRPointwiseNN(nn.Module):

    def __init__(self, params):
        super(HDRPointwiseNN, self).__init__()
        self.coeffs = Coeffs(params=params)
        self.guide = GuideNN(params=params)
        self.slice = Slice()
        self.apply_coeffs = ApplyCoeffs()
        # self.bsa = bsa.BilateralSliceApply()

    def forward(self, lowres, fullres):
        coeffs = self.coeffs(lowres)
        guide = self.guide(fullres)
        slice_coeffs = self.slice(coeffs, guide)
        out = self.apply_coeffs(slice_coeffs, fullres)
        # out = bsa.bsa(coeffs,guide,fullres)
        return out


#########################################################################################################
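
# Quick smoke test (illustrative; requires a CUDA device because slice.py
# traces the slicing op on the GPU at import time). Parameter values mirror
# the train.py defaults:
# params = dict(luma_bins=8, channel_multiplier=1, spatial_bin=16,
#               guide_complexity=16, batch_norm=False,
#               net_input_size=256, net_output_size=512)
# net = HDRPointwiseNN(params=params).cuda()
# low = torch.rand(1, 3, 256, 256).cuda()
# full = torch.rand(1, 3, 512, 512).cuda()
# print(net(low, full).shape)   # -> torch.Size([1, 3, 512, 512])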

--------------------------------------------------------------------------------
/ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/creotiv/hdrnet-pytorch/5335ce7e0e32c4c7416f562bf81175963865e93d/ops/__init__.py

--------------------------------------------------------------------------------
/ops/cuda/bilateral_slice.cu.cc:
--------------------------------------------------------------------------------
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <THC/THC.h>
#include <cuda_runtime.h>
#include "math.h"

extern THCState *state;

__device__ float diff_abs(float x) {
    float eps = 1e-8;
    return sqrt(x*x+eps);
}

__device__ float d_diff_abs(float x) {
    float eps = 1e-8;
    return x/sqrt(x*x+eps);
}

__device__ float weight_z(float x) {
    float abx = diff_abs(x);
    return max(1.0f-abx, 0.0f);
}

__device__ float d_weight_z(float x) {
    float abx = diff_abs(x);
    if(abx > 1.0f) {
        return 0.0f;
        // return abx;
    } else {
        return d_diff_abs(x);
    }
}

__global__ void BilateralSliceApplyKernel(
    int64_t nthreads,
    const float* grid, const float* guide, const float* input,
    const int bs, const int h, const int w,
    const int gh, const int gw, const int gd,
    const int input_chans, const int output_chans,
    float* out)
{
    // - Samples centered at 0.5.
    // - Repeating boundary conditions

    int grid_chans = (input_chans+1)*output_chans;
    int coeff_stride = input_chans+1;

    const int64_t idx = blockIdx.x*blockDim.x + threadIdx.x;
    if(idx < nthreads) {
        int x = idx % w;
        int y = (idx / w) % h;
        int out_c = (idx / (w*h)) % output_chans;
        int b = (idx / (output_chans*w*h));

        float gx = (x+0.5f)*gw/(1.0f*w);
        float gy = (y+0.5f)*gh/(1.0f*h);
        float gz = guide[x + w*(y + h*b)]*gd;

        int fx = static_cast<int>(floor(gx-0.5f));
        int fy = static_cast<int>(floor(gy-0.5f));
        int fz = static_cast<int>(floor(gz-0.5f));

        // Grid strides
        int sx = 1;
        int sy = gw;
        int sz = gw*gh;
        int sc = gw*gh*gd;
        int sb = grid_chans*gd*gw*gh;

        float value = 0.0f;
        for (int in_c = 0; in_c < coeff_stride; ++in_c) {
            float coeff_sample = 0.0f;
            for (int xx = fx; xx < fx+2; ++xx) {
                int x_ = max(min(xx, gw-1), 0);
                float wx = max(1.0f-abs(xx+0.5-gx), 0.0f);
                for (int yy = fy; yy < fy+2; ++yy)
                {
                    int y_ = max(min(yy, gh-1), 0);
                    float wy = max(1.0f-abs(yy+0.5-gy), 0.0f);
                    for (int zz = fz; zz < fz+2; ++zz)
                    {
                        int z_ = max(min(zz, gd-1), 0);
                        float wz = weight_z(zz+0.5-gz);
                        int grid_idx =
                            sc*(coeff_stride*out_c + in_c) + sz*z_ + sx*x_ + sy*y_ + sb*b;
                        coeff_sample += grid[grid_idx]*wx*wy*wz;
                    }
                }
            } // Grid trilinear interpolation
            if(in_c < input_chans) {
                int input_idx = x + w*(y + input_chans*(in_c + h*b));
                value += coeff_sample*input[input_idx];
            } else { // Offset term
                value += coeff_sample;
            }
        }
        out[idx] = value;
    }
}


__global__ void BilateralSliceApplyGridGradKernel(
    int64_t nthreads,
    const float* grid, const float* guide, const float* input, const float* d_output,
    const int bs, const int h, const int w,
    const int gh, const int gw, const int gd,
    const int input_chans, const int output_chans,
    float* out)
{
    int grid_chans = (input_chans+1)*output_chans;
    int coeff_stride = input_chans+1;

    const int64_t idx = blockIdx.x*blockDim.x + threadIdx.x;
    if(idx < nthreads) {
        int gx = idx % gw;
        int gy = (idx / gw) % gh;
        int gz = (idx / (gh*gw)) % gd;
        int c = (idx / (gd*gh*gw)) % grid_chans;
        int b = (idx / (grid_chans*gd*gw*gh));

        float scale_w = w*1.0/gw;
        float scale_h = h*1.0/gh;

        int left_x = static_cast<int>(floor(scale_w*(gx+0.5-1)));
        int right_x = static_cast<int>(ceil(scale_w*(gx+0.5+1)));
        int left_y = static_cast<int>(floor(scale_h*(gy+0.5-1)));
        int right_y = static_cast<int>(ceil(scale_h*(gy+0.5+1)));

        // Strides in the output
        int sx = 1;
        int sy = w;
        int sc = h*w;
        int sb = output_chans*w*h;

        // Strides in the input
        int isx = 1;
        int isy = w;
        int isc = h*w;
        int isb = output_chans*w*h;

        int out_c = c / coeff_stride;
        int in_c = c % coeff_stride;

        float value = 0.0f;
        for (int x = left_x; x < right_x; ++x)
        {
            int x_ = x;

            // mirror boundary
            if (x_ < 0) x_ = -x_-1;
            if (x_ >= w) x_ = 2*w-1-x_;

            float gx2 = (x+0.5f)/scale_w;
            float wx = max(1.0f-abs(gx+0.5-gx2), 0.0f);

            for (int y = left_y; y < right_y; ++y)
            {
                int y_ = y;

                // mirror boundary
                if (y_ < 0) y_ = -y_-1;
                if (y_ >= h) y_ = 2*h-1-y_;

                float gy2 = (y+0.5f)/scale_h;
                float wy = max(1.0f-abs(gy+0.5-gy2), 0.0f);

                int guide_idx = x_ + w*y_ + h*w*b;
                float gz2 = guide[guide_idx]*gd;
                float wz = weight_z(gz+0.5f-gz2);
                if ((gz==0 && gz2<0.5f) || (gz==gd-1 && gz2>gd-0.5f)) {
                    wz = 1.0f;
                }

                int back_idx = sc*out_c + sx*x_ + sy*y_ + sb*b;
                if (in_c < input_chans) {
                    int input_idx = isc*in_c + isx*x_ + isy*y_ + isb*b;
                    value += wz*wx*wy*d_output[back_idx]*input[input_idx];
                } else { // offset term
                    value += wz*wx*wy*d_output[back_idx];
                }
            }
        }
        out[idx] = value;
    }
}


__global__ void BilateralSliceApplyGuideGradKernel(
    int64_t nthreads,
    const float* grid, const float* guide, const float* input, const float* d_output,
    const int bs, const int h, const int w,
    const int gh, const int gw, const int gd,
    const int input_chans, const int output_chans,
    float* out)
{
    int grid_chans = (input_chans+1)*output_chans;
    int coeff_stride = input_chans+1;

    const int64_t idx = blockIdx.x*blockDim.x + threadIdx.x;
    if(idx < nthreads) {
        int x = idx % w;
        int y = (idx / w) % h;
        int b = (idx / (w*h));

        float gx = (x+0.5f)*gw/(1.0f*w);
        float gy = (y+0.5f)*gh/(1.0f*h);
        float gz = guide[x + w*(y + h*b)]*gd;

        int fx = static_cast<int>(floor(gx-0.5f));
        int fy = static_cast<int>(floor(gy-0.5f));
        int fz = static_cast<int>(floor(gz-0.5f));

        // Grid stride
        int sx = 1;
        int sy = gw;
        int sz = gw*gh;
        int sc = gw*gh*gd;
        int sb = grid_chans*gd*gw*gh;

        float out_sum = 0.0f;
        for (int out_c = 0; out_c < output_chans; ++out_c) {

            float in_sum = 0.0f;
            for (int in_c = 0; in_c < coeff_stride; ++in_c) {

                float grid_sum = 0.0f;
                for (int xx = fx; xx < fx+2; ++xx) {
                    int x_ = max(min(xx, gw-1), 0);
                    float wx = max(1.0f-abs(xx+0.5-gx), 0.0f);
                    for (int yy = fy; yy < fy+2; ++yy)
                    {
                        int y_ = max(min(yy, gh-1), 0);
                        float wy = max(1.0f-abs(yy+0.5-gy), 0.0f);
                        for (int zz = fz; zz < fz+2; ++zz)
                        {
                            int z_ = max(min(zz, gd-1), 0);
                            float dwz = gd*d_weight_z(zz+0.5-gz);

                            int grid_idx = sc*(coeff_stride*out_c + in_c) + sz*z_ + sx*x_ + sy*y_ + sb*b;
                            grid_sum += grid[grid_idx]*wx*wy*dwz;
                        } // z
                    } // y
                } // x, grid trilinear interp

                if(in_c < input_chans) {
                    in_sum += grid_sum*input[input_chans*(x + w*(y + h*(in_c + input_chans*b)))];
                } else { // offset term
                    in_sum += grid_sum;
                }
            } // in_c

            out_sum += in_sum*d_output[x + w*(y + h*(out_c + output_chans*b))];
        } // out_c

        out[idx] = out_sum;
    }
}


__global__ void BilateralSliceApplyInputGradKernel(
    int64_t nthreads,
    const float* grid, const float* guide, const float* input, const float* d_output,
    const int bs, const int h, const int w,
    const int gh, const int gw, const int gd,
    const int input_chans, const int output_chans,
    float* out)
{
    int grid_chans = (input_chans+1)*output_chans;
    int coeff_stride = input_chans+1;

    const int64_t idx = blockIdx.x*blockDim.x + threadIdx.x;
    if(idx < nthreads) {
        int x = idx % w;
        int y = (idx / w) % h;
        int in_c = (idx / (w*h)) % input_chans;
        int b = (idx / (input_chans*w*h));

        float gx = (x+0.5f)*gw/(1.0f*w);
        float gy = (y+0.5f)*gh/(1.0f*h);
        float gz = guide[x + w*(y + h*b)]*gd;

        int fx = static_cast<int>(floor(gx-0.5f));
        int fy = static_cast<int>(floor(gy-0.5f));
        int fz = static_cast<int>(floor(gz-0.5f));

        // Grid stride
        int sx = 1;
        int sy = gw;
        int sz = gw*gh;
        int sc = gw*gh*gd;
        int sb = grid_chans*gd*gw*gh;

        float value = 0.0f;
        for (int out_c = 0; out_c < output_chans; ++out_c) {
            float chan_val = 0.0f;
            for (int xx = fx; xx < fx+2; ++xx) {
                int x_ = max(min(xx, gw-1), 0);
                float wx = max(1.0f-abs(xx+0.5-gx), 0.0f);
                for (int yy = fy; yy < fy+2; ++yy)
                {
                    int y_ = max(min(yy, gh-1), 0);
                    float wy = max(1.0f-abs(yy+0.5-gy), 0.0f);
                    for (int zz = fz; zz < fz+2; ++zz)
                    {
                        int z_ = max(min(zz, gd-1), 0);
                        float wz = weight_z(zz+0.5-gz);

                        int grid_idx = sc*(coeff_stride*out_c + in_c) + sz*z_ + sx*x_ + sy*y_ + sb*b;
                        chan_val += grid[grid_idx]*wx*wy*wz;
                    } // z
                } // y
            } // x, grid trilinear interp

            value += chan_val*d_output[x + w*(y + h*(out_c + output_chans*b))];
        } // out_c
        out[idx] = value;
    }
}


// -- KERNEL LAUNCHERS ---------------------------------------------------------
void BilateralSliceApplyKernelLauncher(
    int bs, int gh, int gw, int gd,
    int input_chans, int output_chans,
    int h, int w,
    const float* const grid, const float* const guide, const float* const input,
    float* const out)
{
    int total_count = bs*h*w*output_chans;
    const int64_t block_sz = 512;
    const int64_t nblocks = (total_count + block_sz - 1) / block_sz;
    if (total_count > 0) {
        BilateralSliceApplyKernel<<<nblocks, block_sz>>>(
            total_count, grid, guide, input,
            bs, h, w, gh, gw, gd, input_chans, output_chans,
            out);
        THCudaCheck(cudaPeekAtLastError());
    }
}


void BilateralSliceApplyGradKernelLauncher(
    int bs, int gh, int gw, int gd,
    int input_chans, int output_chans, int h, int w,
    const float* grid, const float* guide, const float* input,
    const float* d_output,
    float* d_grid, float* d_guide, float* d_input)
{
    int64_t coeff_chans = (input_chans+1)*output_chans;
    const int64_t block_sz = 512;
    int64_t grid_count = bs*gh*gw*gd*coeff_chans;
    if (grid_count > 0) {
        const int64_t nblocks = (grid_count + block_sz - 1) / block_sz;
        BilateralSliceApplyGridGradKernel<<<nblocks, block_sz>>>(
            grid_count, grid, guide, input, d_output,
            bs, h, w, gh, gw, gd,
            input_chans, output_chans,
            d_grid);
    }

    int64_t guide_count = bs*h*w;
    if (guide_count > 0) {
        const int64_t nblocks = (guide_count + block_sz - 1) / block_sz;
        BilateralSliceApplyGuideGradKernel<<<nblocks, block_sz>>>(
            guide_count, grid, guide, input, d_output,
            bs, h, w, gh, gw, gd,
            input_chans, output_chans,
            d_guide);
    }

    int64_t input_count = bs*h*w*input_chans;
    if (input_count > 0) {
        const int64_t nblocks = (input_count + block_sz - 1) / block_sz;
        BilateralSliceApplyInputGradKernel<<<nblocks, block_sz>>>(
            input_count, grid, guide, input, d_output,
            bs, h, w, gh, gw, gd,
            input_chans, output_chans,
            d_input);
    }
}

--------------------------------------------------------------------------------
/ops/cuda/bilateral_slice.h:
--------------------------------------------------------------------------------
#ifndef BILATERAL_SLICE_H_SZ3NVCCJ
#define BILATERAL_SLICE_H_SZ3NVCCJ

// #ifdef __cplusplus
// extern "C" {
// #endif

void BilateralSliceApplyKernelLauncher(
    int bs, int gh, int gw, int gd,
    int input_chans, int output_chans, int h, int w,
    const float* const grid, const float* const guide, const float* const input,
    float* const out);

void BilateralSliceApplyGradKernelLauncher(
    int bs, int gh, int gw, int gd,
    int input_chans, int output_chans, int h, int w,
    const float* grid, const float* guide, const float* input,
    const float* d_output,
    float* d_grid, float* d_guide, float* d_input);

// #ifdef __cplusplus
// }
// #endif

#endif /* end of include guard: BILATERAL_SLICE_H_SZ3NVCCJ */

--------------------------------------------------------------------------------
/ops/ops.py:
--------------------------------------------------------------------------------
"""Wrap our operator (and gradient) in autograd."""

# We need to import torch before loading the custom modules
import torch
import bilateral_slice_apply as ops

class BilateralSliceApplyFunction(torch.autograd.Function):
    # grid (x,y,z,c,n)
    # guide (x,y,n)
    # input (x,y,c,n)

    @staticmethod
    def forward(ctx, grid, guide, image):
        grid = grid.permute(4,3,2,1,0).contiguous()
        guide = guide.permute(3,2,1,0).contiguous()
        image = image.permute(3,2,1,0).contiguous()
        out = image.new()
        out.resize_(image.shape)
        ops.bilateral_slice_apply_cuda_float32(grid, guide, image, out)
        ctx.save_for_backward(grid, guide, image)

        return out.permute(3,2,1,0).contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        grid = ctx.saved_tensors[0]
        guide = ctx.saved_tensors[1]
        image = ctx.saved_tensors[2]
        d_grid = grid.new()
        d_grid.resize_(grid.shape)
        d_guide = guide.new()
        d_guide.resize_(guide.shape)
        d_image = image.new()
        d_image.resize_(image.shape)

        grad_output = grad_output.clone()
        ops.bilateral_slice_apply_cuda_float32_grad(grid, guide, image, grad_output, d_grid, d_guide)
        return d_grid, d_guide, None


class BilateralSliceApply(torch.nn.Module):
    def __init__(self):
        super(BilateralSliceApply, self).__init__()

    def forward(self, grid, guide, image):
        return BilateralSliceApplyFunction.apply(grid, guide, image)
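
# Usage sketch (illustrative; requires the compiled `bilateral_slice_apply`
# CUDA extension, which is not built by this repo's requirements). Shapes
# follow the (x, y, z, c, n) layout assumed in the comments above:
# bsa = BilateralSliceApply()
# grid = torch.rand(16, 16, 8, 12, 1).cuda()   # (x, y, z, c, n)
# guide = torch.rand(512, 512, 1).cuda()       # (x, y, n)
# image = torch.rand(512, 512, 3, 1).cuda()    # (x, y, c, n)
# out = bsa(grid, guide, image)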
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
h5py==2.9.0
imageio==2.5.0
numpy==1.17.0
opencv-python==4.1.0.25
Pillow==7.1.0
scikit-image==0.15.0
torch==1.1.0
torchsummary==1.5.1
torchvision==0.3.0

--------------------------------------------------------------------------------
/slice.py:
--------------------------------------------------------------------------------
import torch

def lerp_weight(x, xs):
    """Linear interpolation weight from a sample at x to xs.

    Returns the linear interpolation weight of a "query point" at coordinate `x`
    with respect to a "sample" at coordinate `xs`.

    The integer coordinates `x` are at pixel centers.
    The floating point coordinates `xs` are at pixel edges.
    (OpenGL convention).

    Args:
      x: "Query" point position.
      xs: "Sample" position.

    Returns:
      - 1 when x = xs.
      - 0 when |x - xs| > 1.
    """
    dx = x - xs
    abs_dx = abs(dx)
    return torch.maximum(torch.tensor(1.0).to(x.device) - abs_dx, torch.tensor(0.0).to(x.device))
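
# e.g. lerp_weight(torch.tensor(2.0), torch.tensor(1.25)) -> tensor(0.25):
# the query at 2.0 sits 0.75 away from the sample at 1.25, so 1 - 0.75 = 0.25.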


def smoothed_abs(x, eps):
    """A smoothed version of |x| with improved numerical stability."""
    return torch.sqrt(torch.multiply(x, x) + eps)


def smoothed_lerp_weight(x, xs):
    """Smoothed version of `LerpWeight` with gradients more suitable for backprop.

    Let f(x, xs) = LerpWeight(x, xs)
                 = max(1 - |x - xs|, 0)
                 = max(1 - |dx|, 0)

    f is not smooth when:
    - |dx| is close to 0. We smooth this by replacing |dx| with
      SmoothedAbs(dx, eps) = sqrt(dx * dx + eps), which has derivative
      dx / sqrt(dx * dx + eps).
    - |dx| = 1. When smoothed, this happens when dx = sqrt(1 - eps). Like ReLU,
      we just ignore this (in the implementation below, when the floats are
      exactly equal, we choose the SmoothedAbsGrad path since it is more useful
      than returning a 0 gradient).

    Args:
      x: "Query" point position.
      xs: "Sample" position.
      eps: a small number.

    Returns:
      max(1 - |dx|, 0) where |dx| is smoothed_abs(dx).
    """
    eps = torch.tensor(1e-8).to(torch.float32).to(x.device)
    dx = x - xs
    abs_dx = smoothed_abs(dx, eps)
    return torch.maximum(torch.tensor(1.0).to(x.device) - abs_dx, torch.tensor(0.0).to(x.device))

def _bilateral_slice(grid, guide):
    """Slices a bilateral grid using a guide image.

    Args:
      grid: The bilateral grid with shape (gh, gw, gd, gc).
      guide: A guide image with shape (h, w). Values must be in the range [0, 1].

    Returns:
      sliced: An image with shape (h, w, gc), computed by trilinearly
        interpolating for each grid channel c the grid at 3D position
        [(i + 0.5) * gh / h,
         (j + 0.5) * gw / w,
         guide(i, j) * gd]
    """
    dev = grid.device
    ii, jj = torch.meshgrid(
        [torch.arange(guide.shape[0]).to(dev), torch.arange(guide.shape[1]).to(dev)], indexing='ij')

    scale_i = grid.shape[0] / guide.shape[0]
    scale_j = grid.shape[1] / guide.shape[1]

    gif = (ii + 0.5) * scale_i
    gjf = (jj + 0.5) * scale_j
    gkf = guide * grid.shape[2]

    # Compute trilinear interpolation weights without clamping.
    gi0 = torch.floor(gif - 0.5).to(torch.int32)
    gj0 = torch.floor(gjf - 0.5).to(torch.int32)
    gk0 = torch.floor(gkf - 0.5).to(torch.int32)
    gi1 = gi0 + 1
    gj1 = gj0 + 1
    gk1 = gk0 + 1

    wi0 = lerp_weight(gi0 + 0.5, gif)
    wi1 = lerp_weight(gi1 + 0.5, gif)
    wj0 = lerp_weight(gj0 + 0.5, gjf)
    wj1 = lerp_weight(gj1 + 0.5, gjf)
    wk0 = smoothed_lerp_weight(gk0 + 0.5, gkf)
    wk1 = smoothed_lerp_weight(gk1 + 0.5, gkf)

    w_000 = wi0 * wj0 * wk0
    w_001 = wi0 * wj0 * wk1
    w_010 = wi0 * wj1 * wk0
    w_011 = wi0 * wj1 * wk1
    w_100 = wi1 * wj0 * wk0
    w_101 = wi1 * wj0 * wk1
    w_110 = wi1 * wj1 * wk0
    w_111 = wi1 * wj1 * wk1

    # But clip when indexing into `grid`.
    gi0c = gi0.clip(0, grid.shape[0] - 1).to(torch.long)
    gj0c = gj0.clip(0, grid.shape[1] - 1).to(torch.long)
    gk0c = gk0.clip(0, grid.shape[2] - 1).to(torch.long)

    gi1c = (gi0 + 1).clip(0, grid.shape[0] - 1).to(torch.long)
    gj1c = (gj0 + 1).clip(0, grid.shape[1] - 1).to(torch.long)
    gk1c = (gk0 + 1).clip(0, grid.shape[2] - 1).to(torch.long)

    # ijk: 0 means floor, 1 means ceil.
    grid_val_000 = grid[gi0c, gj0c, gk0c, :]
    grid_val_001 = grid[gi0c, gj0c, gk1c, :]
    grid_val_010 = grid[gi0c, gj1c, gk0c, :]
    grid_val_011 = grid[gi0c, gj1c, gk1c, :]
    grid_val_100 = grid[gi1c, gj0c, gk0c, :]
    grid_val_101 = grid[gi1c, gj0c, gk1c, :]
    grid_val_110 = grid[gi1c, gj1c, gk0c, :]
    grid_val_111 = grid[gi1c, gj1c, gk1c, :]

    # Append a singleton "channels" dimension.
    w_000, w_001, w_010, w_011, w_100, w_101, w_110, w_111 = torch.atleast_3d(
        w_000, w_001, w_010, w_011, w_100, w_101, w_110, w_111)

    # TODO(jiawen): Cache intermediates and pass them in.
    # Just pass out w_ijk and the same ones multiplied by dwk.
    return (torch.multiply(w_000, grid_val_000) +
            torch.multiply(w_001, grid_val_001) +
            torch.multiply(w_010, grid_val_010) +
            torch.multiply(w_011, grid_val_011) +
            torch.multiply(w_100, grid_val_100) +
            torch.multiply(w_101, grid_val_101) +
            torch.multiply(w_110, grid_val_110) +
            torch.multiply(w_111, grid_val_111))

@torch.jit.script
def batch_bilateral_slice(grid, guide):
    res = []
    for i in range(grid.shape[0]):
        res.append(_bilateral_slice(grid[i], guide[i]).unsqueeze(0))
    return torch.concat(res, 0)

def trace_bilateral_slice(grid, guide):
    return batch_bilateral_slice(grid, guide)
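
# Note: the warm-up trace below runs on the GPU, so importing this module
# requires CUDA. For CPU-only experimentation one could drop the .cuda()
# calls below (an untested assumption, not something this repo ships).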

# grid: The bilateral grid with shape (gh, gw, gd, gc).
# guide: A guide image with shape (h, w). Values must be in the range [0, 1].

grid = torch.rand(1, 3, 3, 8, 12).cuda()
guide = torch.rand(1, 16, 16).cuda()

bilateral_slice = torch.jit.trace(
    trace_bilateral_slice, (grid, guide))

bilateral_slice(grid, guide)

--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import os
import sys
import cv2
import numpy as np
import skimage.exposure
import torch
from torchvision import transforms

from model import HDRPointwiseNN
from utils import load_image, resize, load_params
import matplotlib.pyplot as plt

def test(ckpt, args={}):
    state_dict = torch.load(ckpt)
    state_dict, params = load_params(state_dict)
    params.update(args)

    device = torch.device("cuda")
    tensor = transforms.Compose([
        transforms.ToTensor(),
    ])
    low = tensor(resize(load_image(params['test_image']),params['net_input_size'],strict=True).astype(np.float32)).repeat(1,1,1,1)/255
    full = tensor(load_image(params['test_image']).astype(np.float32)).repeat(1,1,1,1)/255

    low = low.to(device)
    full = full.to(device)
    with torch.no_grad():
        model = HDRPointwiseNN(params=params)
        model.load_state_dict(state_dict)
        model.eval()
        model.to(device)
        img = model(low, full)
        print('MIN:', torch.min(img), 'MAX:', torch.max(img))
        img = (img.cpu().detach().numpy()).transpose(0,2,3,1)[0]
        img = skimage.exposure.rescale_intensity(img, out_range=(0.0,255.0)).astype(np.uint8)
        cv2.imwrite(params['test_out'], img[...,::-1])

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='HDRNet Inference')
    parser.add_argument('--checkpoint', type=str, help='model state path')
    parser.add_argument('--input', type=str, dest="test_image", help='image path')
    parser.add_argument('--output', type=str, dest="test_out", help='output image path')

    args = vars(parser.parse_args())

    test(args['checkpoint'], args)

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import sys
from test import test

import cv2
import random
import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
from torch.optim import SGD, Adam, RAdam, RMSprop
from torch.utils.data import DataLoader
import adabound

from dataset import HDRDataset
from metrics import psnr
from model import HDRPointwiseNN, L2LOSS
from utils import load_image, save_params, get_latest_ckpt, load_params

torch.manual_seed(13)
random.seed(13)


def train(params=None):
    os.makedirs(params['ckpt_path'], exist_ok=True)

    device = torch.device("cuda")

    train_dataset = HDRDataset(params['dataset'], params=params, suffix=params['dataset_suffix'])
    train_loader = DataLoader(train_dataset, batch_size=params['batch_size'], shuffle=True)

    model = HDRPointwiseNN(params=params)
    ckpt = get_latest_ckpt(params['ckpt_path'])
    if ckpt:
        print('Loading previous state:', ckpt)
        state_dict = torch.load(ckpt)
        state_dict, _ = load_params(state_dict)
        model.load_state_dict(state_dict)
    model.to(device)

    mseloss = torch.nn.SmoothL1Loss() #L2LOSS() #torch.nn.MSELoss()
    optimizer = Adam(model.parameters(), params['lr'], eps=1e-7)#, weight_decay=1e-5)
    # optimizer = SGD(model.parameters(), params['lr'], momentum=0.9)
    # optimizer = adabound.AdaBound(model.parameters(), lr=params['lr'], final_lr=0.1)

    count = 0
    for e in range(params['epochs']):
        model.train()
        for i, (low, full, target) in enumerate(train_loader):
            optimizer.zero_grad()

            low = low.to(device)
            full = full.to(device)
            t = target.to(device)
            res = model(low, full)

            total_loss = mseloss(t, res)
            total_loss.backward()

            if (count+1) % params['log_interval'] == 0:
                _psnr = psnr(res, t).item()
                loss = total_loss.item()
                print(e, count, loss, _psnr)

            optimizer.step()
            if (count+1) % params['ckpt_interval'] == 0:
                print('@@ MIN:', torch.min(res), 'MAX:', torch.max(res))
                model.eval().cpu()
                ckpt_model_filename = "ckpt_" + str(e) + '_' + str(count) + ".pth"
                ckpt_model_path = os.path.join(params['ckpt_path'], ckpt_model_filename)
                state = save_params(model.state_dict(), params)
                torch.save(state, ckpt_model_path)
                test(ckpt_model_path)
                model.to(device).train()
            count += 1

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='HDRNet Training')
    parser.add_argument('--ckpt-path', type=str, default='./ch', help='Model checkpoint path')
    parser.add_argument('--test-image', type=str, dest="test_image", help='Test image path')
    parser.add_argument('--test-out', type=str, default='out.png', dest="test_out", help='Output test image path')

    parser.add_argument('--luma-bins', type=int, default=8)
    parser.add_argument('--channel-multiplier', default=1, type=int)
    parser.add_argument('--spatial-bin', type=int, default=16)
    parser.add_argument('--guide-complexity', type=int, default=16)
    parser.add_argument('--batch-norm', action='store_true', help='If set, use batch norm')
    parser.add_argument('--net-input-size', type=int, default=256, help='Size of low-res input')
    parser.add_argument('--net-output-size', type=int, default=512, help='Size of full-res input/output')

    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--batch-size', type=int, default=6)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--log-interval', type=int, default=10)
    parser.add_argument('--ckpt-interval', type=int, default=100)
    parser.add_argument('--dataset', type=str, default='', help='Dataset path with input/output dirs', required=True)
    parser.add_argument('--dataset-suffix', type=str, default='', help='Add suffix to input/output dirs. Useful when training on different dataset image sizes')

    params = vars(parser.parse_args())

    print('PARAMS:')
    print(params)

    train(params=params)
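
# Example invocation (illustrative; paths are placeholders, the bundled ds
# folder only holds a single sample image):
#   python train.py --dataset=./ds --test-image=./ds/input/img19.jpg --epochs=1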
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2
import os
import glob
import random

def resize(img, size=512, strict=False):
    short = min(img.shape[:2])
    scale = size/short
    if not strict:
        img = cv2.resize(img, (round(
            img.shape[1]*scale), round(img.shape[0]*scale)), interpolation=cv2.INTER_NEAREST)
    else:
        img = cv2.resize(img, (size,size), interpolation=cv2.INTER_NEAREST)
    return img


def crop(img, size=512):
    try:
        y, x = random.randint(
            0, img.shape[0]-size), random.randint(0, img.shape[1]-size)
    except Exception as e:
        y, x = 0, 0
    return img[y:y+size, x:x+size, :]


def load_image(filename, size=None, use_crop=False):
    img = cv2.imread(filename, cv2.IMREAD_COLOR)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    if size:
        img = resize(img, size=size)
    if use_crop:
        img = crop(img, size)
    return img

def get_latest_ckpt(path):
    try:
        list_of_files = glob.glob(os.path.join(path,'*'))
        latest_file = max(list_of_files, key=os.path.getctime)
        return latest_file
    except ValueError:
        return None

def save_params(state, params):
    state['model_params'] = params
    return state

def load_params(state):
    params = state['model_params']
    del state['model_params']
    return state, params
--------------------------------------------------------------------------------