├── .gitignore
├── README.md
├── data
│   ├── example_light
│   │   ├── rotate_light_00.txt
│   │   ├── rotate_light_01.txt
│   │   ├── rotate_light_02.txt
│   │   ├── rotate_light_03.txt
│   │   ├── rotate_light_04.txt
│   │   ├── rotate_light_05.txt
│   │   └── rotate_light_06.txt
│   ├── obama.jpg
│   ├── test.lst
│   ├── train.lst
│   └── val.lst
├── model
│   ├── defineHourglass_1024_gray_skip_matchFeature.py
│   └── defineHourglass_512_gray_skip.py
├── result
│   ├── light_00.png
│   ├── light_01.png
│   ├── light_02.png
│   ├── light_03.png
│   ├── light_04.png
│   ├── light_05.png
│   ├── light_06.png
│   ├── obama_00.jpg
│   ├── obama_01.jpg
│   ├── obama_02.jpg
│   ├── obama_03.jpg
│   ├── obama_04.jpg
│   ├── obama_05.jpg
│   └── obama_06.jpg
├── testNetwork_demo_1024.py
├── testNetwork_demo_512.py
├── trained_model
│   ├── trained_model_03.t7
│   └── trained_model_1024_03.t7
└── utils
    ├── utils_SH.py
    ├── utils_normal.py
    └── utils_shtools.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.sh
3 | log_test/
4 | result_1024
5 |
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## Deep Single-Image Portrait Relighting [[Project Page]](http://zhhoper.github.io/dpr.html)
3 | Hao Zhou, Sunil Hadap, Kalyan Sunkavalli, David W. Jacobs. In ICCV, 2019
4 | 
5 | ### Dependencies
6 | pytorch >= 1.0.0
7 | opencv >= 4.0.0
8 | shtools: https://shtools.oca.eu/shtools/ (optional)
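A quick sanity check of the installed dependency versions before running testNetwork_demo_512.py or testNetwork_demo_1024.py (a minimal sketch; shtools/pyshtools is only needed for utils_normal.py, as noted below):

```python
import torch
import cv2

print(torch.__version__)  # expected >= 1.0.0
print(cv2.__version__)    # expected >= 4.0.0
```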
9 | 
10 | ### Notes
11 | We include an example image and seven example lightings in data. Note that different methods may use different coordinate systems for Spherical Harmonics (SH), so you may need to change the coordinate system if you use SH lighting from other sources. The coordinate system of our method is in accordance with shtools; we provide utils_normal.py in utils to help you transfer the coordinate system from [bip2017](https://gravis.dmi.unibas.ch/PMM/data/bip/) and [sfsNet](https://senguptaumd.github.io/SfSNet/) to our coordinate system. To use utils_normal.py you need to install shtools. The code is for research purposes only.
12 | 
13 | ### Data Preparation
14 | We publish the code for data preparation; please find it at https://github.com/zhhoper/RI_render_DPR.
15 | 
16 | ### Citation
17 | If you use this code for your research, please consider citing:
18 | ```
19 | @InProceedings{DPR,
20 |   title={Deep Single-Image Portrait Relighting},
21 |   author = {Hao Zhou and Sunil Hadap and Kalyan Sunkavalli and David W. Jacobs},
22 |   booktitle={International Conference on Computer Vision (ICCV)},
23 |   year={2019}
24 | }
25 | ```
26 | 
--------------------------------------------------------------------------------
/data/example_light/rotate_light_00.txt:
--------------------------------------------------------------------------------
1 | 1.084125496282453138e+00
2 | -4.642676300617166185e-01
3 | 2.837846795150648915e-02
4 | 6.765292733937575687e-01
5 | -3.594067725393816914e-01
6 | 4.790996460111427574e-02
7 | -2.280054643781863066e-01
8 | -8.125983081159608712e-02
9 | 2.881082012687687932e-01
10 | 
--------------------------------------------------------------------------------
/data/example_light/rotate_light_01.txt:
--------------------------------------------------------------------------------
1 | 1.084125496282453138e+00
2 | -4.642676300617170626e-01
3 | 5.466255701105990905e-01
4 | 3.996219229512094628e-01
5 | -2.615439760463462715e-01
6 | -2.511241554473071513e-01
7 | 6.495694866016435420e-02
8 | 3.510322039081858470e-01
9 | 1.189662732386344152e-01
10 | 
--------------------------------------------------------------------------------
/data/example_light/rotate_light_02.txt:
--------------------------------------------------------------------------------
1 | 1.084125496282453138e+00
2 | -4.642676300617179508e-01
3 | 6.532524688468428486e-01
4 | -1.782088862752457814e-01
5 | 3.326676893441832261e-02
6 | -3.610566644446819295e-01
7 | 3.647561777790956361e-01
8 | -7.496419691318900735e-02
9 | -5.412289239602386531e-02
10 | 
--------------------------------------------------------------------------------
/data/example_light/rotate_light_03.txt:
--------------------------------------------------------------------------------
1 | 1.084125496282453138e+00
2 | -4.642676300617186724e-01
3 | 2.679669346194941126e-01
4 | -6.218447693376460972e-01
5 | 3.030269583891490037e-01
6 | -1.991061409014726058e-01
7 | -6.162944418511027977e-02
8 | -3.176699976873690878e-01
9 | 1.920509612235956343e-01
10 | 
--------------------------------------------------------------------------------
/data/example_light/rotate_light_04.txt:
--------------------------------------------------------------------------------
1 | 1.084125496282453138e+00
2 | -4.642676300617186724e-01
3 | -3.191031669056417219e-01
4 | -5.972188577671910803e-01
5 | 3.446016675533919993e-01
6 | 1.127753677656503223e-01
7 | -1.716692196540034188e-01
8 | 2.163406460637767315e-01
9 | 2.555824552121269688e-01
10 | 
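Each rotate_light_*.txt file above stores nine second-degree SH coefficients, one per line; the demo scripts scale them by 0.7 before feeding them to the network. A minimal sketch of loading such a file and evaluating the corresponding shading for a few unit normals, assuming the script is run from the repository root so that get_shading from utils/utils_SH.py (shown later in this listing) can be imported:

```python
import sys
import numpy as np

sys.path.append('utils')          # assumes we run from the repository root
from utils_SH import get_shading  # defined in utils/utils_SH.py

sh = np.loadtxt('data/example_light/rotate_light_00.txt')[:9]  # 9 SH coefficients
normals = np.array([[0.0, 0.0, 1.0],
                    [0.0, -1.0, 0.0]])                         # Nx3 unit normals
shading = get_shading(normals, sh.reshape(9, 1))               # Nx1 shading values
print(shading)
```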
-------------------------------------------------------------------------------- /data/example_light/rotate_light_05.txt: -------------------------------------------------------------------------------- 1 | 1.084125496282453138e+00 2 | -4.642676300617178398e-01 3 | -6.658820752324799974e-01 4 | -1.228749652534838893e-01 5 | 1.266842924569576145e-01 6 | 3.397347243069742673e-01 7 | 3.036887095295650041e-01 8 | 2.213893524577207617e-01 9 | -1.886557316342868038e-02 10 | -------------------------------------------------------------------------------- /data/example_light/rotate_light_06.txt: -------------------------------------------------------------------------------- 1 | 1.084125496282453138e+00 2 | -4.642676300617169516e-01 3 | -5.112381993903207800e-01 4 | 4.439962822886048266e-01 5 | -1.866289387481862572e-01 6 | 3.108669041197227867e-01 7 | 2.021743042675238355e-01 8 | -3.148681770175290051e-01 9 | 3.974379604123656762e-02 10 | -------------------------------------------------------------------------------- /data/obama.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhhoper/DPR/737efadaf09a2f0f2f8b1e7f8372c6734dfaa5bc/data/obama.jpg -------------------------------------------------------------------------------- /model/defineHourglass_1024_gray_skip_matchFeature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import sys 6 | import numpy as np 7 | import time 8 | 9 | # we define Hour Glass network based on the paper 10 | # Stacked Hourglass Networks for Human Pose Estimation 11 | # Alejandro Newell, Kaiyu Yang, and Jia Deng 12 | # the code is adapted from 13 | # https://github.com/umich-vl/pose-hg-train/blob/master/src/models/hg.lua 14 | 15 | 16 | def conv3X3(in_planes, out_planes, stride=1): 17 | """3x3 convolution with padding""" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 19 | padding=1, bias=False) 20 | # define the network 21 | class BasicBlock(nn.Module): 22 | def __init__(self, inplanes, outplanes, batchNorm_type=0, stride=1, downsample=None): 23 | super(BasicBlock, self).__init__() 24 | # batchNorm_type 0 means batchnormalization 25 | # 1 means instance normalization 26 | self.inplanes = inplanes 27 | self.outplanes = outplanes 28 | self.conv1 = conv3X3(inplanes, outplanes, 1) 29 | self.conv2 = conv3X3(outplanes, outplanes, 1) 30 | if batchNorm_type == 0: 31 | self.bn1 = nn.BatchNorm2d(outplanes) 32 | self.bn2 = nn.BatchNorm2d(outplanes) 33 | else: 34 | self.bn1 = nn.InstanceNorm2d(outplanes) 35 | self.bn2 = nn.InstanceNorm2d(outplanes) 36 | 37 | self.shortcuts = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False) 38 | 39 | def forward(self, x): 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = F.relu(out) 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | 46 | if self.inplanes != self.outplanes: 47 | out += self.shortcuts(x) 48 | else: 49 | out += x 50 | 51 | out = F.relu(out) 52 | return out 53 | 54 | class HourglassBlock(nn.Module): 55 | ''' 56 | define a basic block for hourglass neetwork 57 | ^-------------------------upper conv------------------- 58 | | | 59 | | V 60 | input------>downsample-->low1-->middle-->low2-->upsample-->+-->output 61 | NOTE about output: 62 | Since we need the lighting from the inner most layer, 63 | let's also output the results from middel layer 64 | ''' 65 
| def __init__(self, inplane, mid_plane, middleNet, skipLayer=True): 66 | super(HourglassBlock, self).__init__() 67 | # upper branch 68 | self.skipLayer = True 69 | self.upper = BasicBlock(inplane, inplane, batchNorm_type=1) 70 | 71 | # lower branch 72 | self.downSample = nn.MaxPool2d(kernel_size=2, stride=2) 73 | self.upSample = nn.Upsample(scale_factor=2, mode='nearest') 74 | self.low1 = BasicBlock(inplane, mid_plane) 75 | self.middle = middleNet 76 | self.low2 = BasicBlock(mid_plane, inplane, batchNorm_type=1) 77 | 78 | def forward(self, x, light, count, skip_count): 79 | # we use count to indicate wich layer we are in 80 | # max_count indicates the from which layer, we would use skip connections 81 | out_upper = self.upper(x) 82 | out_lower = self.downSample(x) 83 | out_lower = self.low1(out_lower) 84 | out_lower, out_feat, out_middle = self.middle(out_lower, light, count+1, skip_count) 85 | out_lower = self.low2(out_lower) 86 | out_lower = self.upSample(out_lower) 87 | 88 | if count >= skip_count and self.skipLayer: 89 | # withSkip is true, then we use skip layer 90 | # easy for analysis 91 | out = out_lower + out_upper 92 | else: 93 | out = out_lower 94 | #out = out_upper 95 | return out, out_feat, out_middle 96 | 97 | class lightingNet(nn.Module): 98 | ''' 99 | define lighting network 100 | ''' 101 | def __init__(self, ncInput, ncOutput, ncMiddle): 102 | super(lightingNet, self).__init__() 103 | self.ncInput = ncInput 104 | self.ncOutput = ncOutput 105 | self.ncMiddle = ncMiddle 106 | 107 | # basic idea is to compute the average of the channel corresponding to lighting 108 | # using fully connected layers to get the lighting 109 | # then fully connected layers to get back to the output size 110 | 111 | self.predict_FC1 = nn.Conv2d(self.ncInput, self.ncMiddle, kernel_size=1, stride=1, bias=False) 112 | self.predict_relu1 = nn.PReLU() 113 | self.predict_FC2 = nn.Conv2d(self.ncMiddle, self.ncOutput, kernel_size=1, stride=1, bias=False) 114 | 115 | self.post_FC1 = nn.Conv2d(self.ncOutput, self.ncMiddle, kernel_size=1, stride=1, bias=False) 116 | self.post_relu1 = nn.PReLU() 117 | self.post_FC2 = nn.Conv2d(self.ncMiddle, self.ncInput, kernel_size=1, stride=1, bias=False) 118 | self.post_relu2 = nn.ReLU() # to be consistance with the original feature 119 | 120 | def forward(self, innerFeat, target_light, count, skip_count): 121 | x = innerFeat[:,0:self.ncInput,:,:] # lighting feature 122 | _, _, row, col = x.shape 123 | 124 | # predict lighting 125 | feat = x.mean(dim=(2,3), keepdim=True) 126 | light = self.predict_relu1(self.predict_FC1(feat)) 127 | light = self.predict_FC2(light) 128 | 129 | # get back the feature space 130 | upFeat = self.post_relu1(self.post_FC1(target_light)) 131 | upFeat = self.post_relu2(self.post_FC2(upFeat)) 132 | upFeat = upFeat.repeat((1,1,row, col)) 133 | innerFeat[:,0:self.ncInput,:,:] = upFeat 134 | return innerFeat, innerFeat[:, self.ncInput:, :, :], light 135 | 136 | 137 | class HourglassNet(nn.Module): 138 | ''' 139 | basic idea: low layers are shared, upper layers are different 140 | lighting should be estimated from the inner most layer 141 | NOTE: we split the bottle neck layer into albedo, normal and lighting 142 | ''' 143 | def __init__(self, baseFilter = 16, gray=True): 144 | super(HourglassNet, self).__init__() 145 | 146 | self.ncLight = 27 # number of channels for input to lighting network 147 | self.baseFilter = baseFilter 148 | 149 | # number of channles for output of lighting network 150 | if gray: 151 | self.ncOutLight = 9 # gray: channel is 1 
152 | else: 153 | self.ncOutLight = 27 # color: channel is 3 154 | 155 | self.ncPre = self.baseFilter # number of channels for pre-convolution 156 | 157 | # number of channels 158 | self.ncHG3 = self.baseFilter 159 | self.ncHG2 = 2*self.baseFilter 160 | self.ncHG1 = 4*self.baseFilter 161 | self.ncHG0 = 8*self.baseFilter + self.ncLight 162 | 163 | self.pre_conv = nn.Conv2d(1, self.ncPre, kernel_size=5, stride=1, padding=2) 164 | self.pre_bn = nn.BatchNorm2d(self.ncPre) 165 | 166 | self.light = lightingNet(self.ncLight, self.ncOutLight, 128) 167 | self.HG0 = HourglassBlock(self.ncHG1, self.ncHG0, self.light) 168 | self.HG1 = HourglassBlock(self.ncHG2, self.ncHG1, self.HG0) 169 | self.HG2 = HourglassBlock(self.ncHG3, self.ncHG2, self.HG1) 170 | self.HG3 = HourglassBlock(self.ncPre, self.ncHG3, self.HG2) 171 | 172 | self.conv_1 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=3, stride=1, padding=1) 173 | self.bn_1 = nn.BatchNorm2d(self.ncPre) 174 | self.conv_2 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=1, stride=1, padding=0) 175 | self.bn_2 = nn.BatchNorm2d(self.ncPre) 176 | self.conv_3 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=1, stride=1, padding=0) 177 | self.bn_3 = nn.BatchNorm2d(self.ncPre) 178 | 179 | self.output = nn.Conv2d(self.ncPre, 1, kernel_size=1, stride=1, padding=0) 180 | 181 | def forward(self, x, target_light, skip_count, oriImg=None): 182 | #feat = self.pre_conv(x) 183 | #feat = F.relu(self.pre_bn(feat)) 184 | feat = x 185 | # get the inner most features 186 | feat, out_feat, out_light = self.HG3(feat, target_light, 0, skip_count) 187 | #feat = F.relu(self.bn_1(self.conv_1(feat))) 188 | #feat = F.relu(self.bn_2(self.conv_2(feat))) 189 | #feat = F.relu(self.bn_3(self.conv_3(feat))) 190 | #out_img = self.output(feat) 191 | #out_img = torch.sigmoid(out_img) 192 | 193 | # for training, we need the original image 194 | # to supervise the bottle neck layer feature 195 | out_feat_ori = None 196 | if not oriImg is None: 197 | _, out_feat_ori, _ = self.HG3(oriImg, target_light, 0, skip_count) 198 | 199 | return out_feat, out_light, out_feat_ori, feat 200 | 201 | class HourglassNet_1024(nn.Module): 202 | ''' 203 | basic idea: low layers are shared, upper layers are different 204 | lighting should be estimated from the inner most layer 205 | NOTE: we split the bottle neck layer into albedo, normal and lighting 206 | ''' 207 | def __init__(self, model_512, baseFilter = 16, gray=True): 208 | super(HourglassNet_1024, self).__init__() 209 | self.model_512 = model_512 210 | self.ncLight = 27 # number of channels for input to lighting network 211 | self.baseFilter = baseFilter 212 | 213 | self.ncPre = self.baseFilter # number of channels for pre-convolution 214 | 215 | self.pre_conv = nn.Conv2d(1, self.ncPre, kernel_size=5, stride=1, padding=2) 216 | self.pre_bn = nn.BatchNorm2d(self.ncPre) 217 | self.downSample = nn.MaxPool2d(kernel_size=2, stride=2) 218 | 219 | self.upSample = nn.Upsample(scale_factor=2, mode='nearest') 220 | 221 | self.conv_1 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=3, stride=1, padding=1) 222 | self.bn_1 = nn.BatchNorm2d(self.ncPre) 223 | self.conv_2 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=1, stride=1, padding=0) 224 | self.bn_2 = nn.BatchNorm2d(self.ncPre) 225 | self.conv_3 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=1, stride=1, padding=0) 226 | self.bn_3 = nn.BatchNorm2d(self.ncPre) 227 | 228 | self.output = nn.Conv2d(self.ncPre, 1, kernel_size=1, stride=1, padding=0) 229 | def forward(self, x, target_light, skip_count, oriImg=None): 230 
| 231 | feat = self.pre_conv(x) 232 | feat = F.relu(self.pre_bn(feat)) 233 | feat = self.downSample(feat) 234 | 235 | if not oriImg is None: 236 | feat_ori = self.pre_conv(oriImg) 237 | feat_ori = F.relu(self.pre_bn(feat_ori)) 238 | oriImg = self.downSample(feat_ori) 239 | 240 | out_feat, out_light, out_feat_ori, feat = self.model_512(feat, target_light, skip_count, oriImg) 241 | feat = self.upSample(feat) 242 | 243 | feat = F.relu(self.bn_1(self.conv_1(feat))) 244 | feat = F.relu(self.bn_2(self.conv_2(feat))) 245 | feat = F.relu(self.bn_3(self.conv_3(feat))) 246 | out_img = self.output(feat) 247 | out_img = torch.sigmoid(out_img) 248 | 249 | return out_img, out_feat, out_light, out_feat_ori 250 | 251 | 252 | if __name__ == '__main__': 253 | pass 254 | -------------------------------------------------------------------------------- /model/defineHourglass_512_gray_skip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import sys 6 | import numpy as np 7 | import time 8 | 9 | # we define Hour Glass network based on the paper 10 | # Stacked Hourglass Networks for Human Pose Estimation 11 | # Alejandro Newell, Kaiyu Yang, and Jia Deng 12 | # the code is adapted from 13 | # https://github.com/umich-vl/pose-hg-train/blob/master/src/models/hg.lua 14 | 15 | 16 | def conv3X3(in_planes, out_planes, stride=1): 17 | """3x3 convolution with padding""" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 19 | padding=1, bias=False) 20 | # define the network 21 | class BasicBlock(nn.Module): 22 | def __init__(self, inplanes, outplanes, batchNorm_type=0, stride=1, downsample=None): 23 | super(BasicBlock, self).__init__() 24 | # batchNorm_type 0 means batchnormalization 25 | # 1 means instance normalization 26 | self.inplanes = inplanes 27 | self.outplanes = outplanes 28 | self.conv1 = conv3X3(inplanes, outplanes, 1) 29 | self.conv2 = conv3X3(outplanes, outplanes, 1) 30 | if batchNorm_type == 0: 31 | self.bn1 = nn.BatchNorm2d(outplanes) 32 | self.bn2 = nn.BatchNorm2d(outplanes) 33 | else: 34 | self.bn1 = nn.InstanceNorm2d(outplanes) 35 | self.bn2 = nn.InstanceNorm2d(outplanes) 36 | 37 | self.shortcuts = nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False) 38 | 39 | def forward(self, x): 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = F.relu(out) 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | 46 | if self.inplanes != self.outplanes: 47 | out += self.shortcuts(x) 48 | else: 49 | out += x 50 | 51 | out = F.relu(out) 52 | return out 53 | 54 | class HourglassBlock(nn.Module): 55 | ''' 56 | define a basic block for hourglass neetwork 57 | ^-------------------------upper conv------------------- 58 | | | 59 | | V 60 | input------>downsample-->low1-->middle-->low2-->upsample-->+-->output 61 | NOTE about output: 62 | Since we need the lighting from the inner most layer, 63 | let's also output the results from middel layer 64 | ''' 65 | def __init__(self, inplane, mid_plane, middleNet, skipLayer=True): 66 | super(HourglassBlock, self).__init__() 67 | # upper branch 68 | self.skipLayer = True 69 | self.upper = BasicBlock(inplane, inplane, batchNorm_type=1) 70 | 71 | # lower branch 72 | self.downSample = nn.MaxPool2d(kernel_size=2, stride=2) 73 | self.upSample = nn.Upsample(scale_factor=2, mode='nearest') 74 | self.low1 = BasicBlock(inplane, mid_plane) 75 | self.middle = middleNet 76 | self.low2 = 
BasicBlock(mid_plane, inplane, batchNorm_type=1) 77 | 78 | def forward(self, x, light, count, skip_count): 79 | # we use count to indicate wich layer we are in 80 | # max_count indicates the from which layer, we would use skip connections 81 | out_upper = self.upper(x) 82 | out_lower = self.downSample(x) 83 | out_lower = self.low1(out_lower) 84 | out_lower, out_middle = self.middle(out_lower, light, count+1, skip_count) 85 | out_lower = self.low2(out_lower) 86 | out_lower = self.upSample(out_lower) 87 | 88 | if count >= skip_count and self.skipLayer: 89 | # withSkip is true, then we use skip layer 90 | # easy for analysis 91 | out = out_lower + out_upper 92 | else: 93 | out = out_lower 94 | #out = out_upper 95 | return out, out_middle 96 | 97 | class lightingNet(nn.Module): 98 | ''' 99 | define lighting network 100 | ''' 101 | def __init__(self, ncInput, ncOutput, ncMiddle): 102 | super(lightingNet, self).__init__() 103 | self.ncInput = ncInput 104 | self.ncOutput = ncOutput 105 | self.ncMiddle = ncMiddle 106 | 107 | # basic idea is to compute the average of the channel corresponding to lighting 108 | # using fully connected layers to get the lighting 109 | # then fully connected layers to get back to the output size 110 | 111 | self.predict_FC1 = nn.Conv2d(self.ncInput, self.ncMiddle, kernel_size=1, stride=1, bias=False) 112 | self.predict_relu1 = nn.PReLU() 113 | self.predict_FC2 = nn.Conv2d(self.ncMiddle, self.ncOutput, kernel_size=1, stride=1, bias=False) 114 | 115 | self.post_FC1 = nn.Conv2d(self.ncOutput, self.ncMiddle, kernel_size=1, stride=1, bias=False) 116 | self.post_relu1 = nn.PReLU() 117 | self.post_FC2 = nn.Conv2d(self.ncMiddle, self.ncInput, kernel_size=1, stride=1, bias=False) 118 | self.post_relu2 = nn.ReLU() # to be consistance with the original feature 119 | 120 | def forward(self, innerFeat, target_light, count, skip_count): 121 | x = innerFeat[:,0:self.ncInput,:,:] # lighting feature 122 | _, _, row, col = x.shape 123 | 124 | # predict lighting 125 | feat = x.mean(dim=(2,3), keepdim=True) 126 | light = self.predict_relu1(self.predict_FC1(feat)) 127 | light = self.predict_FC2(light) 128 | 129 | # get back the feature space 130 | upFeat = self.post_relu1(self.post_FC1(target_light)) 131 | upFeat = self.post_relu2(self.post_FC2(upFeat)) 132 | upFeat = upFeat.repeat((1,1,row, col)) 133 | innerFeat[:,0:self.ncInput,:,:] = upFeat 134 | return innerFeat, light 135 | 136 | 137 | class HourglassNet(nn.Module): 138 | ''' 139 | basic idea: low layers are shared, upper layers are different 140 | lighting should be estimated from the inner most layer 141 | NOTE: we split the bottle neck layer into albedo, normal and lighting 142 | ''' 143 | def __init__(self, baseFilter = 16, gray=True): 144 | super(HourglassNet, self).__init__() 145 | 146 | self.ncLight = 27 # number of channels for input to lighting network 147 | self.baseFilter = baseFilter 148 | 149 | # number of channles for output of lighting network 150 | if gray: 151 | self.ncOutLight = 9 # gray: channel is 1 152 | else: 153 | self.ncOutLight = 27 # color: channel is 3 154 | 155 | self.ncPre = self.baseFilter # number of channels for pre-convolution 156 | 157 | # number of channels 158 | self.ncHG3 = self.baseFilter 159 | self.ncHG2 = 2*self.baseFilter 160 | self.ncHG1 = 4*self.baseFilter 161 | self.ncHG0 = 8*self.baseFilter + self.ncLight 162 | 163 | self.pre_conv = nn.Conv2d(1, self.ncPre, kernel_size=5, stride=1, padding=2) 164 | self.pre_bn = nn.BatchNorm2d(self.ncPre) 165 | 166 | self.light = lightingNet(self.ncLight, 
self.ncOutLight, 128) 167 | self.HG0 = HourglassBlock(self.ncHG1, self.ncHG0, self.light) 168 | self.HG1 = HourglassBlock(self.ncHG2, self.ncHG1, self.HG0) 169 | self.HG2 = HourglassBlock(self.ncHG3, self.ncHG2, self.HG1) 170 | self.HG3 = HourglassBlock(self.ncPre, self.ncHG3, self.HG2) 171 | 172 | self.conv_1 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=3, stride=1, padding=1) 173 | self.bn_1 = nn.BatchNorm2d(self.ncPre) 174 | self.conv_2 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=1, stride=1, padding=0) 175 | self.bn_2 = nn.BatchNorm2d(self.ncPre) 176 | self.conv_3 = nn.Conv2d(self.ncPre, self.ncPre, kernel_size=1, stride=1, padding=0) 177 | self.bn_3 = nn.BatchNorm2d(self.ncPre) 178 | 179 | self.output = nn.Conv2d(self.ncPre, 1, kernel_size=1, stride=1, padding=0) 180 | 181 | def forward(self, x, target_light, skip_count): 182 | feat = self.pre_conv(x) 183 | feat = F.relu(self.pre_bn(feat)) 184 | # get the inner most features 185 | feat, out_light = self.HG3(feat, target_light, 0, skip_count) 186 | feat = F.relu(self.bn_1(self.conv_1(feat))) 187 | feat = F.relu(self.bn_2(self.conv_2(feat))) 188 | feat = F.relu(self.bn_3(self.conv_3(feat))) 189 | out_img = self.output(feat) 190 | out_img = torch.sigmoid(out_img) 191 | return out_img, out_light 192 | 193 | if __name__ == '__main__': 194 | pass 195 | -------------------------------------------------------------------------------- /result/light_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhhoper/DPR/737efadaf09a2f0f2f8b1e7f8372c6734dfaa5bc/result/light_00.png -------------------------------------------------------------------------------- /result/light_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhhoper/DPR/737efadaf09a2f0f2f8b1e7f8372c6734dfaa5bc/result/light_01.png -------------------------------------------------------------------------------- /result/light_02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhhoper/DPR/737efadaf09a2f0f2f8b1e7f8372c6734dfaa5bc/result/light_02.png -------------------------------------------------------------------------------- /result/light_03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhhoper/DPR/737efadaf09a2f0f2f8b1e7f8372c6734dfaa5bc/result/light_03.png -------------------------------------------------------------------------------- /result/light_04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhhoper/DPR/737efadaf09a2f0f2f8b1e7f8372c6734dfaa5bc/result/light_04.png -------------------------------------------------------------------------------- /result/light_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhhoper/DPR/737efadaf09a2f0f2f8b1e7f8372c6734dfaa5bc/result/light_05.png -------------------------------------------------------------------------------- /result/light_06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhhoper/DPR/737efadaf09a2f0f2f8b1e7f8372c6734dfaa5bc/result/light_06.png -------------------------------------------------------------------------------- /result/obama_00.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zhhoper/DPR/737efadaf09a2f0f2f8b1e7f8372c6734dfaa5bc/result/obama_00.jpg -------------------------------------------------------------------------------- /result/obama_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhhoper/DPR/737efadaf09a2f0f2f8b1e7f8372c6734dfaa5bc/result/obama_01.jpg -------------------------------------------------------------------------------- /result/obama_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhhoper/DPR/737efadaf09a2f0f2f8b1e7f8372c6734dfaa5bc/result/obama_02.jpg -------------------------------------------------------------------------------- /result/obama_03.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhhoper/DPR/737efadaf09a2f0f2f8b1e7f8372c6734dfaa5bc/result/obama_03.jpg -------------------------------------------------------------------------------- /result/obama_04.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhhoper/DPR/737efadaf09a2f0f2f8b1e7f8372c6734dfaa5bc/result/obama_04.jpg -------------------------------------------------------------------------------- /result/obama_05.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhhoper/DPR/737efadaf09a2f0f2f8b1e7f8372c6734dfaa5bc/result/obama_05.jpg -------------------------------------------------------------------------------- /result/obama_06.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhhoper/DPR/737efadaf09a2f0f2f8b1e7f8372c6734dfaa5bc/result/obama_06.jpg -------------------------------------------------------------------------------- /testNetwork_demo_1024.py: -------------------------------------------------------------------------------- 1 | ''' 2 | this is a simple test file 3 | ''' 4 | import sys 5 | sys.path.append('model') 6 | sys.path.append('utils') 7 | 8 | from utils_SH import * 9 | 10 | # other modules 11 | import os 12 | import numpy as np 13 | 14 | from torch.autograd import Variable 15 | from torchvision.utils import make_grid 16 | import torch 17 | import time 18 | import cv2 19 | 20 | # ---------------- create normal for rendering half sphere ------ 21 | img_size = 256 22 | x = np.linspace(-1, 1, img_size) 23 | z = np.linspace(1, -1, img_size) 24 | x, z = np.meshgrid(x, z) 25 | 26 | mag = np.sqrt(x**2 + z**2) 27 | valid = mag <=1 28 | y = -np.sqrt(1 - (x*valid)**2 - (z*valid)**2) 29 | x = x * valid 30 | y = y * valid 31 | z = z * valid 32 | normal = np.concatenate((x[...,None], y[...,None], z[...,None]), axis=2) 33 | normal = np.reshape(normal, (-1, 3)) 34 | #----------------------------------------------------------------- 35 | 36 | modelFolder = 'trained_model/' 37 | 38 | # load model 39 | from defineHourglass_1024_gray_skip_matchFeature import * 40 | my_network_512 = HourglassNet(16) 41 | my_network = HourglassNet_1024(my_network_512, 16) 42 | my_network.load_state_dict(torch.load(os.path.join(modelFolder, 'trained_model_1024_03.t7'))) 43 | my_network.cuda() 44 | my_network.train(False) 45 | 46 | 47 | lightFolder = 'data/example_light/' 48 | saveFolder = 'result_1024' 49 | if not os.path.exists(saveFolder): 50 | os.makedirs(saveFolder) 51 | 52 | 53 | img = cv2.imread('data/obama.jpg') 54 | row, col, _ = 
img.shape 55 | img = cv2.resize(img, (1024, 1024)) 56 | Lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) 57 | 58 | inputL = Lab[:,:,0] 59 | inputL = inputL.astype(np.float32)/255.0 60 | inputL = inputL.transpose((0,1)) 61 | inputL = inputL[None,None,...] 62 | inputL = Variable(torch.from_numpy(inputL).cuda()) 63 | 64 | for i in range(7): 65 | sh = np.loadtxt(os.path.join(lightFolder, 'rotate_light_{:02d}.txt'.format(i))) 66 | sh = sh[0:9] 67 | sh = sh * 0.7 68 | 69 | # rendering half-sphere 70 | sh = np.squeeze(sh) 71 | shading = get_shading(normal, sh) 72 | value = np.percentile(shading, 95) 73 | ind = shading > value 74 | shading[ind] = value 75 | shading = (shading - np.min(shading))/(np.max(shading) - np.min(shading)) 76 | shading = (shading *255.0).astype(np.uint8) 77 | shading = np.reshape(shading, (256, 256)) 78 | shading = shading * valid 79 | cv2.imwrite(os.path.join(saveFolder, \ 80 | 'light_{:02d}.png'.format(i)), shading) 81 | 82 | #---------------------------------------------- 83 | # rendering images using the network 84 | #---------------------------------------------- 85 | sh = np.reshape(sh, (1,9,1,1)).astype(np.float32) 86 | sh = Variable(torch.from_numpy(sh).cuda()) 87 | outputImg, _, outputSH, _ = my_network(inputL, sh, 0) 88 | outputImg = outputImg[0].cpu().data.numpy() 89 | outputImg = outputImg.transpose((1,2,0)) 90 | outputImg = np.squeeze(outputImg) 91 | outputImg = (outputImg*255.0).astype(np.uint8) 92 | Lab[:,:,0] = outputImg 93 | resultLab = cv2.cvtColor(Lab, cv2.COLOR_LAB2BGR) 94 | resultLab = cv2.resize(resultLab, (col, row)) 95 | cv2.imwrite(os.path.join(saveFolder, \ 96 | 'obama_{:02d}.jpg'.format(i)), resultLab) 97 | -------------------------------------------------------------------------------- /testNetwork_demo_512.py: -------------------------------------------------------------------------------- 1 | ''' 2 | this is a simple test file 3 | ''' 4 | import sys 5 | sys.path.append('model') 6 | sys.path.append('utils') 7 | 8 | from utils_SH import * 9 | 10 | # other modules 11 | import os 12 | import numpy as np 13 | 14 | from torch.autograd import Variable 15 | from torchvision.utils import make_grid 16 | import torch 17 | import time 18 | import cv2 19 | 20 | # ---------------- create normal for rendering half sphere ------ 21 | img_size = 256 22 | x = np.linspace(-1, 1, img_size) 23 | z = np.linspace(1, -1, img_size) 24 | x, z = np.meshgrid(x, z) 25 | 26 | mag = np.sqrt(x**2 + z**2) 27 | valid = mag <=1 28 | y = -np.sqrt(1 - (x*valid)**2 - (z*valid)**2) 29 | x = x * valid 30 | y = y * valid 31 | z = z * valid 32 | normal = np.concatenate((x[...,None], y[...,None], z[...,None]), axis=2) 33 | normal = np.reshape(normal, (-1, 3)) 34 | #----------------------------------------------------------------- 35 | 36 | modelFolder = 'trained_model/' 37 | 38 | # load model 39 | from defineHourglass_512_gray_skip import * 40 | my_network = HourglassNet() 41 | my_network.load_state_dict(torch.load(os.path.join(modelFolder, 'trained_model_03.t7'))) 42 | my_network.cuda() 43 | my_network.train(False) 44 | 45 | lightFolder = 'data/example_light/' 46 | 47 | saveFolder = 'result' 48 | if not os.path.exists(saveFolder): 49 | os.makedirs(saveFolder) 50 | 51 | img = cv2.imread('data/obama.jpg') 52 | row, col, _ = img.shape 53 | img = cv2.resize(img, (512, 512)) 54 | Lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) 55 | 56 | inputL = Lab[:,:,0] 57 | inputL = inputL.astype(np.float32)/255.0 58 | inputL = inputL.transpose((0,1)) 59 | inputL = inputL[None,None,...] 
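# note: at this point inputL is a 1 x 1 x 512 x 512 float32 array holding the
# normalized L (luminance) channel of the Lab image in [0, 1]; only this channel
# is relit by the network below, while the original a/b chroma channels are kept
# and recombined with the new L channel before converting back to BGR.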
60 | inputL = Variable(torch.from_numpy(inputL).cuda()) 61 | 62 | for i in range(7): 63 | sh = np.loadtxt(os.path.join(lightFolder, 'rotate_light_{:02d}.txt'.format(i))) 64 | sh = sh[0:9] 65 | sh = sh * 0.7 66 | 67 | #-------------------------------------------------- 68 | # rendering half-sphere 69 | sh = np.squeeze(sh) 70 | shading = get_shading(normal, sh) 71 | value = np.percentile(shading, 95) 72 | ind = shading > value 73 | shading[ind] = value 74 | shading = (shading - np.min(shading))/(np.max(shading) - np.min(shading)) 75 | shading = (shading *255.0).astype(np.uint8) 76 | shading = np.reshape(shading, (256, 256)) 77 | shading = shading * valid 78 | cv2.imwrite(os.path.join(saveFolder, \ 79 | 'light_{:02d}.png'.format(i)), shading) 80 | #-------------------------------------------------- 81 | 82 | #---------------------------------------------- 83 | # rendering images using the network 84 | sh = np.reshape(sh, (1,9,1,1)).astype(np.float32) 85 | sh = Variable(torch.from_numpy(sh).cuda()) 86 | outputImg, outputSH = my_network(inputL, sh, 0) 87 | outputImg = outputImg[0].cpu().data.numpy() 88 | outputImg = outputImg.transpose((1,2,0)) 89 | outputImg = np.squeeze(outputImg) 90 | outputImg = (outputImg*255.0).astype(np.uint8) 91 | Lab[:,:,0] = outputImg 92 | resultLab = cv2.cvtColor(Lab, cv2.COLOR_LAB2BGR) 93 | resultLab = cv2.resize(resultLab, (col, row)) 94 | cv2.imwrite(os.path.join(saveFolder, \ 95 | 'obama_{:02d}.jpg'.format(i)), resultLab) 96 | #---------------------------------------------- 97 | -------------------------------------------------------------------------------- /trained_model/trained_model_03.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhhoper/DPR/737efadaf09a2f0f2f8b1e7f8372c6734dfaa5bc/trained_model/trained_model_03.t7 -------------------------------------------------------------------------------- /trained_model/trained_model_1024_03.t7: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhhoper/DPR/737efadaf09a2f0f2f8b1e7f8372c6734dfaa5bc/trained_model/trained_model_1024_03.t7 -------------------------------------------------------------------------------- /utils/utils_SH.py: -------------------------------------------------------------------------------- 1 | ''' 2 | construct shading using sh basis 3 | ''' 4 | import numpy as np 5 | def SH_basis(normal): 6 | ''' 7 | get SH basis based on normal 8 | normal is a Nx3 matrix 9 | return a Nx9 matrix 10 | The order of SH here is: 11 | 1, Y, Z, X, YX, YZ, 3Z^2-1, XZ, X^2-y^2 12 | ''' 13 | numElem = normal.shape[0] 14 | 15 | norm_X = normal[:,0] 16 | norm_Y = normal[:,1] 17 | norm_Z = normal[:,2] 18 | 19 | sh_basis = np.zeros((numElem, 9)) 20 | att= np.pi*np.array([1, 2.0/3.0, 1/4.0]) 21 | sh_basis[:,0] = 0.5/np.sqrt(np.pi)*att[0] 22 | 23 | sh_basis[:,1] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_Y*att[1] 24 | sh_basis[:,2] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_Z*att[1] 25 | sh_basis[:,3] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_X*att[1] 26 | 27 | sh_basis[:,4] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_Y*norm_X*att[2] 28 | sh_basis[:,5] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_Y*norm_Z*att[2] 29 | sh_basis[:,6] = np.sqrt(5)/4/np.sqrt(np.pi)*(3*norm_Z**2-1)*att[2] 30 | sh_basis[:,7] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_X*norm_Z*att[2] 31 | sh_basis[:,8] = np.sqrt(15)/4/np.sqrt(np.pi)*(norm_X**2-norm_Y**2)*att[2] 32 | return sh_basis 33 | 34 | def SH_basis_noAtt(normal): 35 | ''' 36 | get SH basis based on normal 
37 | normal is a Nx3 matrix 38 | return a Nx9 matrix 39 | The order of SH here is: 40 | 1, Y, Z, X, YX, YZ, 3Z^2-1, XZ, X^2-y^2 41 | ''' 42 | numElem = normal.shape[0] 43 | 44 | norm_X = normal[:,0] 45 | norm_Y = normal[:,1] 46 | norm_Z = normal[:,2] 47 | 48 | sh_basis = np.zeros((numElem, 9)) 49 | sh_basis[:,0] = 0.5/np.sqrt(np.pi) 50 | 51 | sh_basis[:,1] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_Y 52 | sh_basis[:,2] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_Z 53 | sh_basis[:,3] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_X 54 | 55 | sh_basis[:,4] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_Y*norm_X 56 | sh_basis[:,5] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_Y*norm_Z 57 | sh_basis[:,6] = np.sqrt(5)/4/np.sqrt(np.pi)*(3*norm_Z**2-1) 58 | sh_basis[:,7] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_X*norm_Z 59 | sh_basis[:,8] = np.sqrt(15)/4/np.sqrt(np.pi)*(norm_X**2-norm_Y**2) 60 | return sh_basis 61 | 62 | def get_shading(normal, SH): 63 | ''' 64 | get shading based on normals and SH 65 | normal is Nx3 matrix 66 | SH: 9 x m vector 67 | return Nxm vector, where m is the number of returned images 68 | ''' 69 | sh_basis = SH_basis(normal) 70 | shading = np.matmul(sh_basis, SH) 71 | #shading = np.matmul(np.reshape(sh_basis, (-1, 9)), SH) 72 | #shading = np.reshape(shading, normal.shape[0:2]) 73 | return shading 74 | 75 | def SH_basis_debug(normal): 76 | ''' 77 | get SH basis based on normal 78 | normal is a Nx3 matrix 79 | return a Nx9 matrix 80 | The order of SH here is: 81 | 1, Y, Z, X, YX, YZ, 3Z^2-1, XZ, X^2-y^2 82 | ''' 83 | numElem = normal.shape[0] 84 | 85 | norm_X = normal[:,0] 86 | norm_Y = normal[:,1] 87 | norm_Z = normal[:,2] 88 | 89 | sh_basis = np.zeros((numElem, 9)) 90 | att= np.pi*np.array([1, 2.0/3.0, 1/4.0]) 91 | # att = [1,1,1] 92 | sh_basis[:,0] = 0.5/np.sqrt(np.pi)*att[0] 93 | 94 | sh_basis[:,1] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_Y*att[1] 95 | sh_basis[:,2] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_Z*att[1] 96 | sh_basis[:,3] = np.sqrt(3)/2/np.sqrt(np.pi)*norm_X*att[1] 97 | 98 | sh_basis[:,4] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_Y*norm_X*att[2] 99 | sh_basis[:,5] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_Y*norm_Z*att[2] 100 | sh_basis[:,6] = np.sqrt(5)/4/np.sqrt(np.pi)*(3*norm_Z**2-1)*att[2] 101 | sh_basis[:,7] = np.sqrt(15)/2/np.sqrt(np.pi)*norm_X*norm_Z*att[2] 102 | sh_basis[:,8] = np.sqrt(15)/4/np.sqrt(np.pi)*(norm_X**2-norm_Y**2)*att[2] 103 | return sh_basis 104 | 105 | def get_shading_debug(normal, SH): 106 | ''' 107 | get shading based on normals and SH 108 | normal is Nx3 matrix 109 | SH: 9 x m vector 110 | return Nxm vector, where m is the number of returned images 111 | ''' 112 | sh_basis = SH_basis_debug(normal) 113 | shading = np.matmul(sh_basis, SH) 114 | #shading = sh_basis*SH[0] 115 | return shading 116 | -------------------------------------------------------------------------------- /utils/utils_normal.py: -------------------------------------------------------------------------------- 1 | ''' 2 | adjust normals according to which SH we want to use 3 | ''' 4 | import numpy as np 5 | import sys 6 | from utils_shtools import * 7 | from pyshtools.rotate import djpi2, SHRotateRealCoef 8 | 9 | class sh_cvt(): 10 | ''' 11 | the normal direction we get from projection is: 12 | 13 | > z 14 | | / 15 | | / 16 | |/ 17 | --------------------------> x 18 | | 19 | | 20 | v y 21 | 22 | the x, y, z direction of SH from SHtools is 23 | ^ z > y 24 | | / 25 | | / 26 | |/ 27 | --------------------------> x 28 | | 29 | | 30 | 31 | the bip lighting coordinate is 32 | > z 33 | | / 34 | | / 35 | |/ 36 | <-------------------------- 37 | x | 
38 | | 39 | v y 40 | 41 | the sfs lighting coordinate is 42 | | 43 | | 44 | --------------------------> y 45 | / | 46 | / | 47 | z / v x 48 | ''' 49 | def __init__(self): 50 | self.SH_DEGREE = 2 51 | self.dj = djpi2(self.SH_DEGREE) 52 | 53 | 54 | def cvt2shtools(self, normalImages): 55 | ''' 56 | align coordinates of normal with shtools 57 | ''' 58 | newNormals = normalImages.copy() 59 | # new y is the old z 60 | newNormals[:,:,1] = normalImages[:,:,2] 61 | # new z is the negative old y 62 | newNormals[:,:,2] = -1*normalImages[:,:,1] 63 | return newNormals 64 | 65 | def bip2shtools(self, lighting): 66 | ''' 67 | lighting is n x 9 matrix of bip lighting, we want to convert it 68 | to the coordinate of shtools so we can use the same coordinate 69 | --we use shtools to rotate the coordinate: 70 | we use shtools to rotate the object: 71 | we need to use x convention, 72 | alpha_x = -pi (contour clock-wise rotate along z by pi) 73 | beta_x = -pi/2 (contour clock-wise rotate along new x by pi/2) 74 | gamma_x = 0 75 | then y convention is: 76 | alpha_y = alpha_x - pi/2 = 0 77 | beta_y = beta_x = -pi/2 78 | gamma_y = gamma_x + pi/2 = pi/2 79 | reference: https://shtools.oca.eu/shtools/pyshrotaterealcoef.html 80 | ''' 81 | new_lighting = np.zeros(lighting.shape) 82 | n = lighting.shape[0] 83 | for i in range(n): 84 | shMatrix = shtools_sh2matrix(lighting[i,:], self.SH_DEGREE) 85 | # rotate coordinate 86 | shMatrix = SHRotateRealCoef(shMatrix, np.array([0, -np.pi/2, np.pi/2]), self.dj) 87 | # rotate object 88 | #shMatrix = SHRotateRealCoef(shMatrix, np.array([-np.pi/2, np.pi/2, -np.pi/2]), self.dj) 89 | new_lighting[i,:] = shtools_matrix2vec(shMatrix) 90 | return new_lighting 91 | 92 | def sfs2shtools(self, lighting): 93 | ''' 94 | convert sfs SH to shtools 95 | --we use shtools to rotate the coordinate: 96 | we use shtools to rotate the object: 97 | 98 | we need to use x convention, 99 | we use shtools to rotate the coordinate: 100 | we need to use x convention, 101 | alpha_x = pi/2 (clock-wise rotate along z axis by pi/2) 102 | beta_x = -pi/2 (contour clock-wise rotate along new x by pi/2) 103 | gamma_x = 0 104 | then y convention is: 105 | alpha_y = alpha_x - pi/2 = 0 106 | beta_y = beta_x = -pi/2 107 | gamma_y = gamma_x + pi/2 = pi/2 108 | reference: https://shtools.oca.eu/shtools/pyshrotaterealcoef.html 109 | ''' 110 | new_lighting = np.zeros(lighting.shape) 111 | n = lighting.shape[0] 112 | for i in range(n): 113 | shMatrix = shtools_sh2matrix(lighting[i,:], self.SH_DEGREE) 114 | # rotate coordinate 115 | shMatrix = SHRotateRealCoef(shMatrix, np.array([0, -np.pi/2, np.pi/2]), self.dj) 116 | # rotate object 117 | #shMatrix = SHRotateRealCoef(shMatrix, np.array([np.pi/2, -np.pi/2, 0]), self.dj) 118 | new_lighting[i,:] = shtools_matrix2vec(shMatrix) 119 | return new_lighting 120 | -------------------------------------------------------------------------------- /utils/utils_shtools.py: -------------------------------------------------------------------------------- 1 | ''' 2 | define some helper functions for shtools 3 | ''' 4 | import pyshtools 5 | from pyshtools.expand import MakeGridDH 6 | import numpy as np 7 | 8 | def shtools_matrix2vec(SH_matrix): 9 | ''' 10 | for the sh matrix created by sh tools, 11 | we create the vector of the sh 12 | ''' 13 | numOrder = SH_matrix.shape[1] 14 | vec_SH = np.zeros(numOrder**2) 15 | count = 0 16 | for i in range(numOrder): 17 | for j in range(i,0,-1): 18 | vec_SH[count] = SH_matrix[1,i,j] 19 | count = count + 1 20 | for j in range(0,i+1): 21 | 
vec_SH[count] = SH_matrix[0, i, j]
22 |             count = count + 1
23 |     return vec_SH
24 | 
25 | def shtools_sh2matrix(coefficients, degree):
26 |     '''
27 |     convert vector of sh to matrix
28 |     '''
29 |     coeffs_matrix = np.zeros((2, degree + 1, degree + 1))
30 |     current_zero_index = 0
31 |     for l in range(0, degree + 1):
32 |         coeffs_matrix[0, l, 0] = coefficients[current_zero_index]
33 |         for m in range(1, l + 1):
34 |             coeffs_matrix[0, l, m] = coefficients[current_zero_index + m]
35 |             coeffs_matrix[1, l, m] = coefficients[current_zero_index - m]
36 |         current_zero_index += 2*(l+1)
37 |     return coeffs_matrix
38 | 
39 | def shtools_getSH(envMap, order=5):
40 |     '''
41 |     get SH based on the envmap
42 |     '''
43 |     SH = pyshtools.expand.SHExpandDH(envMap, sampling=2, lmax_calc=order, norm=4)
44 |     return SH
45 | 
--------------------------------------------------------------------------------
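The README's Notes section points to utils_normal.py for rotating SH lighting from the bip2017 or SfSNet conventions into the shtools-aligned convention this model expects. A minimal usage sketch, assuming pyshtools is installed and the script runs from the repository root; the input and output file names are hypothetical placeholders:

```python
import sys
import numpy as np

sys.path.append('utils')
from utils_normal import sh_cvt   # requires pyshtools (shtools)

# hypothetical n x 9 SH lighting in the bip2017 convention
bip_lighting = np.loadtxt('bip_lighting_example.txt').reshape(1, 9)

cvt = sh_cvt()
dpr_lighting = cvt.bip2shtools(bip_lighting)  # rotated into this repo's SH coordinate system
np.savetxt('converted_lighting.txt', dpr_lighting.ravel())
```

The converted nine coefficients can then be used in place of one of the rotate_light_*.txt files when running the demo scripts; use sfs2shtools instead of bip2shtools for SfSNet-style lighting.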