├── README.md
└── hdrnet.py

/README.md:
--------------------------------------------------------------------------------
# hdrnet

A PyTorch re-implementation of "Deep Bilateral Learning for Real-Time Image Enhancement", SIGGRAPH 2017.

## Network

The slice operation is implemented with `torch.nn.functional.grid_sample`.
--------------------------------------------------------------------------------
/hdrnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict


class conv_block(nn.Module):
    '''Conv2d, optionally followed by BatchNorm and an activation.'''
    def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, use_bias=True,
                 activation=nn.ReLU(inplace=True), is_BN=False):
        super(conv_block, self).__init__()
        layers = [("conv", nn.Conv2d(inc, outc, kernel_size, padding=padding, stride=stride, bias=use_bias))]
        if is_BN:
            layers.append(("bn", nn.BatchNorm2d(outc)))
        if activation is not None:
            layers.append(("act", activation))
        self.conv = nn.Sequential(OrderedDict(layers))

    def forward(self, x):
        return self.conv(x)


class fc(nn.Module):
    '''Linear layer, optionally followed by BatchNorm and an activation.'''
    def __init__(self, inc, outc, activation=None, is_BN=False):
        super(fc, self).__init__()
        layers = [("fc", nn.Linear(inc, outc))]
        if is_BN:
            layers.append(("bn", nn.BatchNorm1d(outc)))
        if activation is not None:
            layers.append(("act", activation))
        self.fc = nn.Sequential(OrderedDict(layers))

    def forward(self, x):
        return self.fc(x)


class Guide(nn.Module):
    '''
    Produces the single-channel guide map used to slice into the bilateral grid.
    "PointwiseNN" is a small pointwise (1x1 conv) network; the "PointwiseCurve"
    variant (color correction matrix + per-channel curves) is unfinished.
    '''
    def __init__(self, mode="PointwiseNN"):
        super(Guide, self).__init__()
        self.mode = mode
        if mode == "PointwiseNN":
            self.conv1 = conv_block(3, 16, kernel_size=1, padding=0, is_BN=True)
            self.conv2 = conv_block(16, 1, kernel_size=1, padding=0, activation=nn.Tanh())

        elif mode == "PointwiseCurve":
            # ccm: color correction matrix, initialized close to identity
            self.ccm = nn.Conv2d(3, 3, kernel_size=1)
            ccm_weight = torch.FloatTensor([1, 0, 0, 0, 1, 0, 0, 0, 1]) + torch.randn(9) * 1e-4
            ccm_bias = torch.FloatTensor([0, 0, 0])
            self.ccm.weight.data.copy_(ccm_weight.view(3, 3, 1, 1))
            self.ccm.bias.data.copy_(ccm_bias)

            # TODO: per-channel curves

            # project the corrected color down to a single-channel guide map
            self.conv1x1 = nn.Conv2d(3, 1, kernel_size=1)

    def forward(self, x):
        if self.mode == "PointwiseNN":
            return self.conv2(self.conv1(x))
        raise NotImplementedError("only the PointwiseNN guide is implemented")


class Slice(nn.Module):
    '''
    Slice the bilateral grid with the guide map: each full-resolution pixel
    fetches its coefficients by trilinear interpolation at (x, y, guide),
    implemented with torch.nn.functional.grid_sample on a 5-D input.
    '''
    def __init__(self):
        super(Slice, self).__init__()

    def forward(self, bilateral_grid, guidemap):
        # bilateral_grid: (N, 12, D, Hg, Wg); guidemap: (N, 1, H, W) in [-1, 1]
        N, _, H, W = guidemap.shape
        hg, wg = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)])
        # normalize the spatial coordinates to [-1, 1] as grid_sample expects
        hg = hg.to(guidemap.device).float().repeat(N, 1, 1).unsqueeze(3) / (H - 1) * 2 - 1
        wg = wg.to(guidemap.device).float().repeat(N, 1, 1).unsqueeze(3) / (W - 1) * 2 - 1
        guidemap = guidemap.permute(0, 2, 3, 1).contiguous()
        # grid_sample's last dimension is ordered (x, y, z): x indexes the grid
        # width, y the grid height, and z the guide (depth) bins
        guidemap_guide = torch.cat([wg, hg, guidemap], dim=3).unsqueeze(1)

        coeff = F.grid_sample(bilateral_grid, guidemap_guide, align_corners=True)

        return coeff.squeeze(2)  # (N, 12, 1, H, W) -> (N, 12, H, W)
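
# Minimal shape check for Slice (hypothetical helper, not part of the original
# script; the shapes mirror the ones HDRNet uses below). It runs on CPU since
# Slice takes its device from the guide map.
def _slice_shape_check():
    slice_func = Slice()
    bilateral_grid = torch.randn(4, 12, 8, 16, 16)  # 12 coeffs x 8 depth bins on a 16x16 grid
    guidemap = torch.rand(4, 1, 256, 256) * 2 - 1   # tanh-range guide in [-1, 1]
    coeff = slice_func(bilateral_grid, guidemap)
    assert coeff.shape == (4, 12, 256, 256)         # 12 coefficients per output pixel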


class Transform(nn.Module):
    '''Apply the sliced per-pixel affine color transform.'''
    def __init__(self):
        super(Transform, self).__init__()

    def forward(self, coeff, full_res_input):
        # coeff: (N, 12, H, W), the rows of a 3x4 affine matrix per pixel
        R = torch.sum(full_res_input * coeff[:, 0:3, :, :], dim=1, keepdim=True) + coeff[:, 3:4, :, :]
        G = torch.sum(full_res_input * coeff[:, 4:7, :, :], dim=1, keepdim=True) + coeff[:, 7:8, :, :]
        B = torch.sum(full_res_input * coeff[:, 8:11, :, :], dim=1, keepdim=True) + coeff[:, 11:12, :, :]

        return torch.cat([R, G, B], dim=1)


class HDRNet(nn.Module):
    def __init__(self, inc=3, outc=3):
        super(HDRNet, self).__init__()
        self.inc = inc
        self.outc = outc

        # fixed 256x256 low-resolution input for the coefficient branch
        self.downsample = nn.Upsample(size=(256, 256), mode='bilinear', align_corners=True)
        self.activation = nn.ReLU(inplace=True)

        # splat: four stride-2 convs, 256x256x3 -> 16x16x64 ---------------------
        splat_layers = []
        for i in range(4):
            if i == 0:
                splat_layers.append(conv_block(self.inc, (2 ** i) * 8, kernel_size=3, padding=1, stride=2, activation=self.activation, is_BN=False))
            else:
                splat_layers.append(conv_block((2 ** (i - 1)) * 8, (2 ** i) * 8, kernel_size=3, padding=1, stride=2, activation=self.activation, is_BN=True))

        self.splat_conv = nn.Sequential(*splat_layers)

        # global path: two stride-2 convs (16x16 -> 4x4), then fully connected --
        global_conv_layers = [
            conv_block(64, 64, stride=2, activation=self.activation, is_BN=True),
            conv_block(64, 64, stride=2, activation=self.activation, is_BN=True),
        ]
        self.global_conv = nn.Sequential(*global_conv_layers)

        global_fc_layers = [
            fc(1024, 256, activation=self.activation, is_BN=True),  # 64 * 4 * 4 = 1024
            fc(256, 128, activation=self.activation, is_BN=True),
            fc(128, 64)
        ]
        self.global_fc = nn.Sequential(*global_fc_layers)

        # local path: two 3x3 convs at 16x16 ------------------------------------
        local_layers = [
            conv_block(64, 64, activation=self.activation, is_BN=True),
            conv_block(64, 64, use_bias=False, activation=None, is_BN=False),
        ]
        self.local_conv = nn.Sequential(*local_layers)

        # fusion output: 96 = 12 coefficients x 8 grid depth bins ---------------
        self.linear = nn.Conv2d(64, 96, kernel_size=1)

        self.guide_func = Guide()
        self.slice_func = Slice()
        self.transform_func = Transform()

    def forward(self, full_res_input):
        low_res_input = self.downsample(full_res_input)
        bs, _, _, _ = low_res_input.size()

        splat_fea = self.splat_conv(low_res_input)

        local_fea = self.local_conv(splat_fea)

        global_fea = self.global_conv(splat_fea)
        global_fea = self.global_fc(global_fea.view(bs, -1))

        # broadcast the 64-d global feature over the 16x16 local feature map
        fused = self.activation(global_fea.view(-1, 64, 1, 1) + local_fea)
        fused = self.linear(fused)

        # unpack the 96 channels into a (12, 8, 16, 16) bilateral grid
        bilateral_grid = fused.view(-1, 12, 8, 16, 16)

        guidemap = self.guide_func(full_res_input)
        coeff = self.slice_func(bilateral_grid, guidemap)
        output = self.transform_func(coeff, full_res_input)

        return output
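

# Hypothetical helper, not part of the original script: the per-pixel affine
# transform that Transform applies, written out for one pixel. The 12 sliced
# coefficients are the rows of a 3x4 matrix [A | b] acting on the RGB vector.
def _affine_single_pixel(coeff_vec, rgb):
    # coeff_vec: tensor of shape (12,); rgb: tensor of shape (3,)
    A = coeff_vec.view(3, 4)
    return A[:, :3] @ rgb + A[:, 3]  # output pixel, shape (3,)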


if __name__ == "__main__":
    from torchsummary import summary

    net = HDRNet().cuda()
    summary(net, (3, 960, 540))
    print(net)
    print('done')
    # slice_func = Slice()
    # bilateral_grid = torch.randn(4, 12, 8, 16, 16)
    # guide = torch.randn(4, 1, 256, 256)
    # slice_func(bilateral_grid, guide)
--------------------------------------------------------------------------------