├── README.md
├── crop.py
├── data
│   ├── cropped
│   │   └── 0.jpg
│   └── original
│       └── 0.jpg
├── resnet_unet
│   ├── __init__.py
│   ├── resnet.py
│   └── resnet_unet_model.py
├── teaser.png
└── tester.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Cross-modal Deep Face Normals with Deactivable Skip Connections
Victoria Fernández Abrevaya*, Adnane Boukhayma*, Philip H. S. Torr, Edmond Boyer (*equal contribution).

[CVPR 2020 (Oral)](https://arxiv.org/abs/2003.09691)

![teaser](teaser.png)

## Requirements
+ Python 2.7
+ PyTorch 0.3

## Data preprocessing
Input images are assumed to be fixed-size crops around the face. Using `dlib`, the command below finds the tightest bounding box of edge length `l` containing the detected facial landmarks, then crops the image with a square patch of size `1.2 × l` centered on the face. Input images are read from `data/original` and cropped images are saved to `data/cropped`.

Download the [dlib trained facial shape predictor](http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2) and place the extracted `shape_predictor_68_face_landmarks.dat` file in the `data` directory.
```
python crop.py
```

## Testing
Download the [model weights](https://drive.google.com/file/d/1Qb7CZbM13Zpksa30ywjXEEHHDcVWHju_) and place the `model.pth` file in the `data` directory.

Run the following command to generate a normal map from a cropped RGB image example in `data/cropped`. Results are saved in `data/output` (create this directory first).
```
python tester.py
```

## Citation
```
@InProceedings{Abrevaya_2020_CVPR,
    author = {Abrevaya, Victoria Fernandez and Boukhayma, Adnane and Torr, Philip H.S. and Boyer, Edmond},
    title = {Cross-Modal Deep Face Normals With Deactivable Skip Connections},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month = {June},
    year = {2020}
}
```

## Acknowledgement
This work was partly supported by the ERC grant ERC-2012-AdG 321162-HELIOS, the EPSRC grant Seebibyte EP/M013774/1 and the EPSRC/MURI grant EP/N019474/1.
## License
This work is licensed under a [Creative Commons Attribution-NonCommercial 4.0 International License](https://creativecommons.org/licenses/by-nc/4.0/).

--------------------------------------------------------------------------------
/crop.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2
from imutils import face_utils
import dlib

# Facial landmark predictor and frontal face detector
predictor = dlib.shape_predictor('data/shape_predictor_68_face_landmarks.dat')
detector = dlib.get_frontal_face_detector()

# Load image and convert to grayscale for detection
img = cv2.imread('data/original/0.jpg', 1)
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

# Detect the face and its 68 landmarks
rects = detector(gray, 0)
shape = predictor(gray, rects[0])
shape = face_utils.shape_to_np(shape)
shape = np.round(shape)

# Draw a binary face mask from the convex hull of the landmarks
msk = np.zeros(img.shape, dtype=np.uint8)
cv2.fillPoly(msk, [cv2.convexHull(shape)], (1, 1, 1))

# Tightest box containing the landmarks
umin = np.min(shape[:, 0])
umax = np.max(shape[:, 0])
vmin = np.min(shape[:, 1])
vmax = np.max(shape[:, 1])

umean = np.mean((umin, umax))
vmean = np.mean((vmin, vmax))

# Square crop of edge 1.2x the larger box dimension, clamped to the image size
l = round(1.2 * np.max((umax - umin, vmax - vmin)))

if l > np.min(img.shape[:2]):
    l = np.min(img.shape[:2])

us = round(np.max((0, umean - float(l) / 2)))
ue = us + l

vs = round(np.max((0, vmean - float(l) / 2)))
ve = vs + l

# Shift the crop window back inside the image if it overflows
if ue > img.shape[1]:
    ue = img.shape[1]
    us = img.shape[1] - l

if ve > img.shape[0]:
    ve = img.shape[0]
    vs = img.shape[0] - l

us = int(us)
ue = int(ue)

vs = int(vs)
ve = int(ve)

# Crop and resize to the 256x256 network input resolution
img = cv2.resize(img[vs:ve, us:ue], (256, 256))
msk = cv2.resize(msk[vs:ve, us:ue], (256, 256), interpolation=cv2.INTER_NEAREST)

# Save cropped image and face mask
cv2.imwrite('data/cropped/0.jpg', img)
cv2.imwrite('data/cropped/0_msk.jpg', msk * 255)
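
Note that crop.py processes only the single hard-coded example `data/original/0.jpg`. As a hypothetical extension (not part of the repository), the same steps could be wrapped in a loop over every image in `data/original`; the `crop_face` helper below is an illustrative condensation of the landmark-based square crop above, returning `None` when no face is detected:
```
import os
import cv2
import dlib
import numpy as np
from imutils import face_utils

predictor = dlib.shape_predictor('data/shape_predictor_68_face_landmarks.dat')
detector = dlib.get_frontal_face_detector()

def crop_face(img):
    # Same landmark-based square crop as crop.py (hypothetical helper)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    rects = detector(gray, 0)
    if len(rects) == 0:
        return None
    shape = face_utils.shape_to_np(predictor(gray, rects[0]))
    umin, vmin = shape.min(0)
    umax, vmax = shape.max(0)
    # Square of edge 1.2x the larger landmark-box dimension, clamped to image
    l = int(min(round(1.2 * max(umax - umin, vmax - vmin)), min(img.shape[:2])))
    us = int(min(max(0, (umin + umax) / 2.0 - l / 2.0), img.shape[1] - l))
    vs = int(min(max(0, (vmin + vmax) / 2.0 - l / 2.0), img.shape[0] - l))
    return cv2.resize(img[vs:vs + l, us:us + l], (256, 256))

for name in sorted(os.listdir('data/original')):
    img = cv2.imread(os.path.join('data/original', name), 1)
    crop = crop_face(img)
    if crop is not None:
        cv2.imwrite(os.path.join('data/cropped', name), crop)
```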
--------------------------------------------------------------------------------
/data/cropped/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boukhayma/face_normals/f9018333bd049cc0c58a5ab87a843515386d7f5f/data/cropped/0.jpg

--------------------------------------------------------------------------------
/data/original/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boukhayma/face_normals/f9018333bd049cc0c58a5ab87a843515386d7f5f/data/original/0.jpg

--------------------------------------------------------------------------------
/resnet_unet/__init__.py:
--------------------------------------------------------------------------------
from .resnet_unet_model import ResNetUNet

--------------------------------------------------------------------------------
/resnet_unet/resnet.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo


__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet18(pretrained=True, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
    #model.layer4 = model._make_layer(BasicBlock, 512, 2, stride=2)
    return model


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
    return model


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model
def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
    return model

--------------------------------------------------------------------------------
/resnet_unet/resnet_unet_model.py:
--------------------------------------------------------------------------------
import copy

import torch
import torch.nn as nn

from .resnet import resnet18


def convrelu(in_channels, out_channels, kernel, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
        nn.ReLU(inplace=True),
    )


class ResNetUNet(nn.Module):
    def __init__(self, n_class):
        super(ResNetUNet, self).__init__()

        # Image encoder: ResNet-18 backbone
        self.base_model = resnet18(pretrained=True)
        self.base_layers = list(self.base_model.children())

        self.layer0 = nn.Sequential(*self.base_layers[:3])  # size=(N, 64, H/2, W/2)
        self.layer0_1x1 = convrelu(64, 64, 1, 0)
        self.layer1 = nn.Sequential(*self.base_layers[3:5])  # size=(N, 64, H/4, W/4)
        self.layer1_1x1 = convrelu(64, 64, 1, 0)
        self.layer2 = self.base_layers[5]  # size=(N, 128, H/8, W/8)
        self.layer2_1x1 = convrelu(128, 128, 1, 0)
        self.layer3 = self.base_layers[6]  # size=(N, 256, H/16, W/16)
        self.layer3_1x1 = convrelu(256, 256, 1, 0)
        self.layer4 = self.base_layers[7]  # size=(N, 512, H/32, W/32)
        self.layer4_1x1 = convrelu(512, 256 + 512, 1, 0)

        # Normal encoder: an independent copy of the backbone
        self.layer0_2 = copy.deepcopy(self.layer0)
        self.layer1_2 = copy.deepcopy(self.layer1)
        self.layer2_2 = copy.deepcopy(self.layer2)
        self.layer3_2 = copy.deepcopy(self.layer3)
        self.layer4_2 = copy.deepcopy(self.layer4)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.upsample_2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Normal decoder
        self.conv_up3 = convrelu(256 + 512, 128 + 512, 3, 1)
        self.conv_up2 = convrelu(128 + 512, 64 + 256, 3, 1)
        self.conv_up1 = convrelu(64 + 256, 64 + 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 64 + 128, 3, 1)

        # Image decoder (no skip connections)
        self.conv_up3_2 = convrelu(512, 512, 3, 1)
        self.conv_up2_2 = convrelu(512, 256, 3, 1)
        self.conv_up1_2 = convrelu(256, 256, 3, 1)
        self.conv_up0_2 = convrelu(256, 128, 3, 1)

        self.conv_original_size0 = convrelu(3, 64, 3, 1)
        self.conv_original_size1 = convrelu(64, 64, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)

        self.conv_original_size0_2 = convrelu(3, 64, 3, 1)
        self.conv_original_size1_2 = convrelu(64, 64, 3, 1)
        self.conv_original_size2_2 = convrelu(128, 64, 3, 1)

        self.conv_last = nn.Conv2d(64, n_class, 1)
        self.conv_last_2 = nn.Conv2d(64, n_class, 1)

    def forward(self, input, flag=0):

        if flag == 0:  # 'im-input': RGB image in

            # Image encoder
            x_original = self.conv_original_size0(input)
            x_original = self.conv_original_size1(x_original)

            layer0 = self.layer0(input)
            layer1 = self.layer1(layer0)
            layer2 = self.layer2(layer1)
            layer3 = self.layer3(layer2)
            layer4 = self.layer4(layer3)

            # Normal decoder with active skip connections
            layer4_1 = self.layer4_1x1(layer4)
            x_1 = self.upsample(layer4_1)
            layer3_1 = self.layer3_1x1(layer3)
            x_1 = torch.cat([x_1[:, :512, :, :],
                             torch.max(x_1[:, 512:, :, :], layer3_1)], dim=1)
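            # Deactivable skip connection: the first 512 channels pass
            # through unchanged, while the remaining channels are fused with
            # the encoder feature map by an element-wise max. When flag == 1
            # the same decoder layers run without these fusions, i.e. with
            # the skips deactivated. The pattern repeats at every scale below.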
            x_1 = self.conv_up3(x_1)

            x_1 = self.upsample(x_1)
            layer2_1 = self.layer2_1x1(layer2)
            x_1 = torch.cat([x_1[:, :512, :, :],
                             torch.max(x_1[:, 512:, :, :], layer2_1)], dim=1)
            x_1 = self.conv_up2(x_1)

            x_1 = self.upsample(x_1)
            layer1_1 = self.layer1_1x1(layer1)
            x_1 = torch.cat([x_1[:, :256, :, :],
                             torch.max(x_1[:, 256:, :, :], layer1_1)], dim=1)
            x_1 = self.conv_up1(x_1)

            x_1 = self.upsample(x_1)
            layer0_1 = self.layer0_1x1(layer0)
            x_1 = torch.cat([x_1[:, :256, :, :],
                             torch.max(x_1[:, 256:, :, :], layer0_1)], dim=1)
            x_1 = self.conv_up0(x_1)

            x_1 = self.upsample(x_1)
            x_1 = torch.cat([x_1[:, :128, :, :],
                             torch.max(x_1[:, 128:, :, :], x_original)], dim=1)
            x_1 = self.conv_original_size2(x_1)

            out_1 = self.conv_last(x_1)

            # Image decoder (reconstruction branch, no skip connections)
            x_2 = self.upsample_2(layer4)
            x_2 = self.conv_up3_2(x_2)

            x_2 = self.upsample_2(x_2)
            x_2 = self.conv_up2_2(x_2)

            x_2 = self.upsample_2(x_2)
            x_2 = self.conv_up1_2(x_2)

            x_2 = self.upsample_2(x_2)
            x_2 = self.conv_up0_2(x_2)

            x_2 = self.upsample_2(x_2)
            x_2 = self.conv_original_size2_2(x_2)

            out_2 = self.conv_last_2(x_2)

            return out_1, out_2

        if flag == 1:  # 'norm-input': normal map in

            # Normal encoder
            x_original = self.conv_original_size0_2(input)
            x_original = self.conv_original_size1_2(x_original)

            layer0 = self.layer0_2(input)
            layer1 = self.layer1_2(layer0)
            layer2 = self.layer2_2(layer1)
            layer3 = self.layer3_2(layer2)
            layer4 = self.layer4_2(layer3)

            # Normal decoder with skip connections deactivated
            layer4 = self.layer4_1x1(layer4)

            x_1 = self.upsample(layer4)
            x_1 = self.conv_up3(x_1)

            x_1 = self.upsample(x_1)
            x_1 = self.conv_up2(x_1)

            x_1 = self.upsample(x_1)
            x_1 = self.conv_up1(x_1)

            x_1 = self.upsample(x_1)
            x_1 = self.conv_up0(x_1)

            x_1 = self.upsample(x_1)
            x_1 = self.conv_original_size2(x_1)

            out_1 = self.conv_last(x_1)

            return out_1

--------------------------------------------------------------------------------
/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boukhayma/face_normals/f9018333bd049cc0c58a5ab87a843515386d7f5f/teaser.png

--------------------------------------------------------------------------------
/tester.py:
--------------------------------------------------------------------------------
import torch
from torch.autograd import Variable
from torchvision.transforms import Compose, ToTensor
from resnet_unet import ResNetUNet
import cv2
import numpy as np
from PIL import Image

img_transform = Compose([
    ToTensor()
])

# Load the model and its trained weights
model = ResNetUNet(n_class=3).cuda()
model.load_state_dict(torch.load('data/model.pth'))

model.eval()

# Load the cropped input image as a 1x3x256x256 tensor
img = img_transform(Image.open('data/cropped/0.jpg')).unsqueeze(0)
img = Variable(img.cuda())

# The first output of the image branch is the predicted normal map
outs = model(img)[0]
out = np.array(outs[0].data.permute(1, 2, 0).cpu())

# Normalize predictions to unit-length normals, then map [-1, 1] to [0, 255]
out = out / np.expand_dims(np.sqrt(np.sum(out * out, 2)), 2)
out = 127.5 * (out + 1.0)

# Save the normal map (the data/output directory must exist)
cv2.imwrite('data/output/0.jpg', cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
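
For reference, a minimal, hypothetical sketch (not part of the repository) of how the two forward modes of `ResNetUNet` could be exercised; the inputs here are random placeholder tensors shaped like the 256×256 crops used above, and GPU placement is left to the caller:
```
import torch
from torch.autograd import Variable
from resnet_unet import ResNetUNet

model = ResNetUNet(n_class=3)
model.eval()

# flag=0: RGB image in -> predicted normal map plus reconstructed image
rgb = Variable(torch.randn(1, 3, 256, 256))
normals, recon = model(rgb, flag=0)

# flag=1: normal map in -> normal map out (skip connections deactivated)
nrm = Variable(torch.randn(1, 3, 256, 256))
normals_only = model(nrm, flag=1)
```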
--------------------------------------------------------------------------------