├── README.md
├── crop.py
├── data
│   ├── cropped
│   │   └── 0.jpg
│   └── original
│       └── 0.jpg
├── resnet_unet
│   ├── __init__.py
│   ├── resnet.py
│   └── resnet_unet_model.py
├── teaser.png
└── tester.py
/README.md:
--------------------------------------------------------------------------------
1 | # Cross-modal Deep Face Normals with Deactivable Skip Connections
2 | Victoria Fernández Abrevaya*, Adnane Boukhayma*, Philip H. S. Torr, Edmond Boyer (*Equal contrib.).
3 | [CVPR 2020 (Oral)](https://arxiv.org/abs/2003.09691)
4 |
5 | ![teaser](teaser.png)
6 |
7 | ## Requirements
8 | + Python 2.7
9 | + PyTorch 0.3
10 | + `dlib`, `imutils`, OpenCV (`cv2`), Pillow and NumPy (for the preprocessing and testing scripts)
11 | ## Data preprocessing
12 | Input images are assumed to be fixed-size crops around the face. Using `dlib` facial landmarks, the command below finds the tightest bounding box around the face and takes its longest edge as `l`.
13 | Each image is then cropped with a square patch of size `1.2×l` centered on the face, resized to 256×256, and saved together with a convex-hull face mask. Input images are read from `data/original` and cropped results are saved to `data/cropped`.
14 |
15 | Download the [dlib trained facial shape predictor](http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2), extract it, and put `shape_predictor_68_face_landmarks.dat` in the `data` directory.
16 | ```
17 | python crop.py
18 | ```
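For reference, the crop geometry in `crop.py` boils down to the following helper (a minimal sketch for illustration; `square_crop_box` is not part of this repo):
```
import numpy as np

def square_crop_box(landmarks, img_h, img_w, scale=1.2):
    # Tightest bounding box of the 68 landmarks (columns are u = x, v = y).
    umin, vmin = landmarks.min(0)
    umax, vmax = landmarks.max(0)
    # Square side: scale times the longest box edge, clamped to the image.
    l = int(round(scale * max(umax - umin, vmax - vmin)))
    l = min(l, img_h, img_w)
    # Center the square on the box center, then clamp to the image borders.
    us = int(round(max(0, (umin + umax) / 2.0 - l / 2.0)))
    vs = int(round(max(0, (vmin + vmax) / 2.0 - l / 2.0)))
    us = min(us, img_w - l)
    vs = min(vs, img_h - l)
    return us, vs, l  # the crop is img[vs:vs+l, us:us+l]
```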
19 |
20 | ## Testing
21 | Download the [model weights](https://drive.google.com/file/d/1Qb7CZbM13Zpksa30ywjXEEHHDcVWHju_) and put `model.pth` in the `data` directory.
22 |
23 | Run the following command to generate a normal map from the cropped RGB example in `data/cropped`. Results are saved in `data/output`; create this directory first if it does not exist.
24 | ```
25 | python tester.py
26 | ```
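Internally, `tester.py` L2-normalizes the raw 3-channel prediction at every pixel and maps it from [-1, 1] to [0, 255] before saving. Schematically (a sketch assuming a `(H, W, 3)` float array; `normals_to_rgb` is not part of this repo):
```
import numpy as np

def normals_to_rgb(n):
    # Unit-normalize each pixel's normal, then map [-1, 1] -> [0, 255].
    n = n / np.linalg.norm(n, axis=2, keepdims=True)
    return (127.5 * (n + 1.0)).astype(np.uint8)
```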
27 |
28 |
29 | ## Citation
30 | ```
31 | @InProceedings{Abrevaya_2020_CVPR,
32 | author = {Abrevaya, Victoria Fernandez and Boukhayma, Adnane and Torr, Philip H.S. and Boyer, Edmond},
33 | title = {Cross-Modal Deep Face Normals With Deactivable Skip Connections},
34 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
35 | month = {June},
36 | year = {2020}
37 | }
38 | ```
39 |
40 | ## Acknowledgement
41 | This work was partly supported by the ERC grant ERC-2012-AdG 321162-HELIOS, the EPSRC grant Seebibyte EP/M013774/1 and the EPSRC/MURI grant EP/N019474/1.
42 |
43 | ## License
44 | This work is licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.
45 |
--------------------------------------------------------------------------------
/crop.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | from imutils import face_utils
4 | import dlib
5 |
6 | predictor = dlib.shape_predictor('data/shape_predictor_68_face_landmarks.dat')
7 | detector = dlib.get_frontal_face_detector()
8 |
9 | # load image
10 | img = cv2.imread('data/original/0.jpg',1)
11 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
12 |
13 | # detect kp
14 | rects = detector(gray, 0)
15 | shape = predictor(gray, rects[0])
16 | shape = face_utils.shape_to_np(shape)
17 | shape = np.round(shape).astype(np.int32)  # cv2.convexHull/fillPoly expect int32 points
18 |
19 | # draw mask
20 | msk = np.zeros(img.shape, dtype=np.uint8)
21 | cv2.fillPoly(msk, [cv2.convexHull(shape)], (1,1,1))
22 |
23 | # crop & resize
24 | umin = np.min(shape[:,0])
25 | umax = np.max(shape[:,0])
26 | vmin = np.min(shape[:,1])
27 | vmax = np.max(shape[:,1])
28 |
29 | umean = np.mean((umin,umax))
30 | vmean = np.mean((vmin,vmax))
31 |
32 | l = round( 1.2 * np.max((umax-umin,vmax-vmin)))
33 |
34 | if (l > np.min(img.shape[:2])):
35 | l = np.min(img.shape[:2])
36 |
37 | us = round(np.max((0,umean-float(l)/2)))
38 | ue = us + l
39 |
40 | vs = round(np.max((0,vmean-float(l)/2)))
41 | ve = vs + l
42 |
43 | if (ue>img.shape[1]):
44 | ue = img.shape[1]
45 | us = img.shape[1]-l
46 |
47 | if (ve>img.shape[0]):
48 | ve = img.shape[0]
49 | vs = img.shape[0]-l
50 |
51 | us = int(us)
52 | ue = int(ue)
53 |
54 | vs = int(vs)
55 | ve = int(ve)
56 |
57 | img = cv2.resize(img[vs:ve,us:ue],(256,256))
58 | msk = cv2.resize(msk[vs:ve,us:ue],(256,256),interpolation=cv2.INTER_NEAREST)
59 |
60 | # save images
61 | cv2.imwrite('data/cropped/0.jpg', img)
62 | cv2.imwrite('data/cropped/0_msk.jpg', msk * 255)
--------------------------------------------------------------------------------
/data/cropped/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boukhayma/face_normals/f9018333bd049cc0c58a5ab87a843515386d7f5f/data/cropped/0.jpg
--------------------------------------------------------------------------------
/data/original/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boukhayma/face_normals/f9018333bd049cc0c58a5ab87a843515386d7f5f/data/original/0.jpg
--------------------------------------------------------------------------------
/resnet_unet/__init__.py:
--------------------------------------------------------------------------------
1 | from .resnet_unet_model import ResNetUNet
2 |
--------------------------------------------------------------------------------
/resnet_unet/resnet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 | import torch.utils.model_zoo as model_zoo
4 |
5 |
6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
7 | 'resnet152']
8 |
9 |
10 | model_urls = {
11 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
12 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
13 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
14 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
15 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
16 | }
17 |
18 |
19 | def conv3x3(in_planes, out_planes, stride=1):
20 | "3x3 convolution with padding"
21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
22 | padding=1, bias=False)
23 |
24 |
25 | class BasicBlock(nn.Module):
26 | expansion = 1
27 |
28 | def __init__(self, inplanes, planes, stride=1, downsample=None):
29 | super(BasicBlock, self).__init__()
30 | self.conv1 = conv3x3(inplanes, planes, stride)
31 | self.bn1 = nn.BatchNorm2d(planes)
32 | self.relu = nn.ReLU(inplace=True)
33 | self.conv2 = conv3x3(planes, planes)
34 | self.bn2 = nn.BatchNorm2d(planes)
35 | self.downsample = downsample
36 | self.stride = stride
37 |
38 | def forward(self, x):
39 | residual = x
40 |
41 | out = self.conv1(x)
42 | out = self.bn1(out)
43 | out = self.relu(out)
44 |
45 | out = self.conv2(out)
46 | out = self.bn2(out)
47 |
48 | if self.downsample is not None:
49 | residual = self.downsample(x)
50 |
51 | out += residual
52 | out = self.relu(out)
53 |
54 | return out
55 |
56 |
57 | class Bottleneck(nn.Module):
58 | expansion = 4
59 |
60 | def __init__(self, inplanes, planes, stride=1, downsample=None):
61 | super(Bottleneck, self).__init__()
62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
63 | self.bn1 = nn.BatchNorm2d(planes)
64 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
65 | padding=1, bias=False)
66 | self.bn2 = nn.BatchNorm2d(planes)
67 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
68 | self.bn3 = nn.BatchNorm2d(planes * 4)
69 | self.relu = nn.ReLU(inplace=True)
70 | self.downsample = downsample
71 | self.stride = stride
72 |
73 | def forward(self, x):
74 | residual = x
75 |
76 | out = self.conv1(x)
77 | out = self.bn1(out)
78 | out = self.relu(out)
79 |
80 | out = self.conv2(out)
81 | out = self.bn2(out)
82 | out = self.relu(out)
83 |
84 | out = self.conv3(out)
85 | out = self.bn3(out)
86 |
87 | if self.downsample is not None:
88 | residual = self.downsample(x)
89 |
90 | out += residual
91 | out = self.relu(out)
92 |
93 | return out
94 |
95 |
96 | class ResNet(nn.Module):
97 |
98 | def __init__(self, block, layers, num_classes=1000):
99 | self.inplanes = 64
100 | super(ResNet, self).__init__()
101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
102 | bias=False)
103 | self.bn1 = nn.BatchNorm2d(64)
104 | self.relu = nn.ReLU(inplace=True)
105 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
106 | self.layer1 = self._make_layer(block, 64, layers[0])
107 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
108 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
109 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
110 |
111 |
112 | self.avgpool = nn.AvgPool2d(7, stride=1)
113 | self.fc = nn.Linear(512 * block.expansion, num_classes)
114 |
115 | for m in self.modules():
116 | if isinstance(m, nn.Conv2d):
117 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
118 | m.weight.data.normal_(0, math.sqrt(2. / n))
119 | elif isinstance(m, nn.BatchNorm2d):
120 | m.weight.data.fill_(1)
121 | m.bias.data.zero_()
122 |
123 | def _make_layer(self, block, planes, blocks, stride=1):
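# A 1x1 downsampling conv is added to the residual branch whenever the
# spatial resolution or the channel count changes, so that shapes match
# at the element-wise addition.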
124 | downsample = None
125 | if stride != 1 or self.inplanes != planes * block.expansion:
126 | downsample = nn.Sequential(
127 | nn.Conv2d(self.inplanes, planes * block.expansion,
128 | kernel_size=1, stride=stride, bias=False),
129 | nn.BatchNorm2d(planes * block.expansion),
130 | )
131 |
132 | layers = []
133 | layers.append(block(self.inplanes, planes, stride, downsample))
134 | self.inplanes = planes * block.expansion
135 | for i in range(1, blocks):
136 | layers.append(block(self.inplanes, planes))
137 |
138 | return nn.Sequential(*layers)
139 |
140 | def forward(self, x):
141 | x = self.conv1(x)
142 | x = self.bn1(x)
143 | x = self.relu(x)
144 | x = self.maxpool(x)
145 |
146 | x = self.layer1(x)
147 | x = self.layer2(x)
148 | x = self.layer3(x)
149 | x = self.layer4(x)
150 |
151 | x = self.avgpool(x)
152 | x = x.view(x.size(0), -1)
153 | x = self.fc(x)
154 |
155 | return x
156 |
157 |
158 |
159 |
160 |
161 | def resnet18(pretrained=True, **kwargs):
162 | """Constructs a ResNet-18 model.
163 |
164 | Args:
165 | pretrained (bool): If True, returns a model pre-trained on ImageNet
166 | """
167 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
168 | if pretrained:
169 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
170 | #model.layer4 = model._make_layer(BasicBlock, 512, 2, stride=2)
171 | return model
172 |
173 |
174 | def resnet34(pretrained=False, **kwargs):
175 | """Constructs a ResNet-34 model.
176 |
177 | Args:
178 | pretrained (bool): If True, returns a model pre-trained on ImageNet
179 | """
180 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
181 | if pretrained:
182 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
183 | return model
184 |
185 |
186 | def resnet50(pretrained=False, **kwargs):
187 | """Constructs a ResNet-50 model.
188 |
189 | Args:
190 | pretrained (bool): If True, returns a model pre-trained on ImageNet
191 | """
192 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
193 | if pretrained:
194 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
195 | return model
196 |
197 |
198 | def resnet101(pretrained=False, **kwargs):
199 | """Constructs a ResNet-101 model.
200 |
201 | Args:
202 | pretrained (bool): If True, returns a model pre-trained on ImageNet
203 | """
204 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
205 | if pretrained:
206 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
207 | return model
208 |
209 |
210 | def resnet152(pretrained=False, **kwargs):
211 | """Constructs a ResNet-152 model.
212 |
213 | Args:
214 | pretrained (bool): If True, returns a model pre-trained on ImageNet
215 | """
216 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
217 | if pretrained:
218 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
219 | return model
220 |
--------------------------------------------------------------------------------
/resnet_unet/resnet_unet_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .resnet import resnet18
4 |
5 |
6 |
7 | import copy
8 |
9 | def convrelu(in_channels, out_channels, kernel, padding):
10 | return nn.Sequential(
11 | nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
12 | nn.ReLU(inplace=True),
13 | )
14 |
15 |
16 | class ResNetUNet(nn.Module):
17 | def __init__(self, n_class):
18 | super(ResNetUNet,self).__init__()
19 |
20 | self.base_model = resnet18(pretrained=True)
21 | self.base_layers = list(self.base_model.children())
22 |
23 | self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)
24 | self.layer0_1x1 = convrelu(64, 64, 1, 0)
25 | self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4)
26 | self.layer1_1x1 = convrelu(64, 64, 1, 0)
27 | self.layer2 = self.base_layers[5] # size=(N, 128, x.H/8, x.W/8)
28 | self.layer2_1x1 = convrelu(128, 128, 1, 0)
29 | self.layer3 = self.base_layers[6] # size=(N, 256, x.H/16, x.W/16)
30 | self.layer3_1x1 = convrelu(256, 256, 1, 0)
31 | self.layer4 = self.base_layers[7] # size=(N, 512, x.H/32, x.W/32)
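# layer4_1x1 widens the bottleneck to 512 + 256 channels; the extra 256
# channels line up with layer3's features and act as the "skip slot"
# that is max-fused with encoder features in forward() below.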
32 | self.layer4_1x1 = convrelu(512, 256 + 512, 1, 0)
33 |
34 | self.layer0_2 = copy.deepcopy(self.layer0)
35 | self.layer1_2 = copy.deepcopy(self.layer1)
36 | self.layer2_2 = copy.deepcopy(self.layer2)
37 | self.layer3_2 = copy.deepcopy(self.layer3)
38 | self.layer4_2 = copy.deepcopy(self.layer4)
39 |
40 |
41 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
42 |
43 | self.upsample_2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
44 |
45 | self.conv_up3 = convrelu(256 + 512, 128 + 512, 3, 1)
46 | self.conv_up2 = convrelu(128 + 512, 64 + 256, 3, 1)
47 | self.conv_up1 = convrelu( 64 + 256, 64 + 256, 3, 1)
48 | self.conv_up0 = convrelu( 64 + 256, 64 + 128, 3, 1)
49 |
50 | self.conv_up3_2 = convrelu(512, 512, 3, 1)
51 | self.conv_up2_2 = convrelu(512, 256, 3, 1)
52 | self.conv_up1_2 = convrelu(256, 256, 3, 1)
53 | self.conv_up0_2 = convrelu(256, 128, 3, 1)
54 |
55 | self.conv_original_size0 = convrelu(3, 64, 3, 1)
56 | self.conv_original_size1 = convrelu(64, 64, 3, 1)
57 | self.conv_original_size2 = convrelu(64 + 128, 64, 3, 1)
58 |
59 | self.conv_original_size0_2 = convrelu(3, 64, 3, 1)
60 | self.conv_original_size1_2 = convrelu(64, 64, 3, 1)
61 | self.conv_original_size2_2 = convrelu(128, 64, 3, 1)
62 |
63 | self.conv_last = nn.Conv2d(64, n_class, 1)
64 |
65 | self.conv_last_2 = nn.Conv2d(64, n_class, 1)
66 |
67 | def forward(self, input, flag=0):
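# flag == 0: RGB image input. The image encoder feeds two decoders: the
#   normal decoder (skip connections active, fused by element-wise max)
#   and an image decoder (no skips). Returns (normals, image).
# flag == 1: normal-map input. The normal encoder feeds the same normal
#   decoder, this time with the skip connections deactivated. Returns
#   normals only.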
68 |
69 | if (flag == 0): #'im-input'
70 |
71 | # Image Encoder
72 | x_original = self.conv_original_size0(input)
73 | x_original = self.conv_original_size1(x_original)
74 |
75 | layer0 = self.layer0(input)
76 | layer1 = self.layer1(layer0)
77 | layer2 = self.layer2(layer1)
78 | layer3 = self.layer3(layer2)
79 | layer4 = self.layer4(layer3)
80 |
81 | # Normal decoder
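# Deactivable skips: each upsampled feature is split channel-wise; the
# leading chunk passes through unchanged while the trailing chunk is
# fused with the matching encoder feature by an element-wise max.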
82 | layer4_1 = self.layer4_1x1(layer4)
83 | x_1 = self.upsample(layer4_1)
84 | layer3_1 = self.layer3_1x1(layer3)
85 | x_1 = torch.cat([x_1[:,:512,:,:], torch.max(x_1[:,512:,:,:] , layer3_1)], dim=1)
86 | x_1 = self.conv_up3(x_1)
87 |
88 | x_1 = self.upsample(x_1)
89 | layer2_1 = self.layer2_1x1(layer2)
90 | x_1 = torch.cat([x_1[:,:512,:,:], torch.max(x_1[:,512:,:,:] , layer2_1)], dim=1)
91 | x_1 = self.conv_up2(x_1)
92 |
93 | x_1 = self.upsample(x_1)
94 | layer1_1 = self.layer1_1x1(layer1)
95 | x_1 = torch.cat([x_1[:,:256,:,:], torch.max(x_1[:,256:,:,:] , layer1_1)], dim=1)
96 | x_1 = self.conv_up1(x_1)
97 |
98 | x_1 = self.upsample(x_1)
99 | layer0_1 = self.layer0_1x1(layer0)
100 | x_1 = torch.cat([x_1[:,:256,:,:], torch.max(x_1[:,256:,:,:] , layer0_1)], dim=1)
101 | x_1 = self.conv_up0(x_1)
102 |
103 | x_1 = self.upsample(x_1)
104 | x_1 = torch.cat([x_1[:,:128,:,:], torch.max(x_1[:,128:,:,:] , x_original)], dim=1)
105 | x_1 = self.conv_original_size2(x_1)
106 |
107 | out_1 = self.conv_last(x_1)
108 |
109 |
110 | # Image decoder
111 | x_2 = self.upsample_2(layer4)
112 | x_2 = self.conv_up3_2(x_2)
113 |
114 | x_2 = self.upsample_2(x_2)
115 | x_2 = self.conv_up2_2(x_2)
116 |
117 | x_2 = self.upsample_2(x_2)
118 | x_2 = self.conv_up1_2(x_2)
119 |
120 | x_2 = self.upsample_2(x_2)
121 | x_2 = self.conv_up0_2(x_2)
122 |
123 | x_2 = self.upsample_2(x_2)
124 | x_2 = self.conv_original_size2_2(x_2)
125 |
126 | out_2 = self.conv_last_2(x_2)
127 |
128 | return out_1, out_2
129 |
130 |
131 | if (flag == 1): #'norm-input'
132 |
133 | # Normal Encoder
134 | x_original = self.conv_original_size0_2(input)
135 | x_original = self.conv_original_size1_2(x_original)
136 |
137 | layer0 = self.layer0_2(input)
138 | layer1 = self.layer1_2(layer0)
139 | layer2 = self.layer2_2(layer1)
140 | layer3 = self.layer3_2(layer2)
141 | layer4 = self.layer4_2(layer3)
142 |
143 | # Normal decoder
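# Same decoder weights as in the flag == 0 branch, but with the skip
# connections deactivated: no encoder features are fused in.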
144 | layer4 = self.layer4_1x1(layer4)
145 |
146 | x_1 = self.upsample(layer4)
147 | x_1 = self.conv_up3(x_1)
148 |
149 | x_1 = self.upsample(x_1)
150 | x_1 = self.conv_up2(x_1)
151 |
152 | x_1 = self.upsample(x_1)
153 | x_1 = self.conv_up1(x_1)
154 |
155 | x_1 = self.upsample(x_1)
156 | x_1 = self.conv_up0(x_1)
157 |
158 | x_1 = self.upsample(x_1)
159 | x_1 = self.conv_original_size2(x_1)
160 |
161 | out_1 = self.conv_last(x_1)
162 |
163 | return out_1
164 |
165 |
166 |
167 |
168 |
169 |
170 |
--------------------------------------------------------------------------------
/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/boukhayma/face_normals/f9018333bd049cc0c58a5ab87a843515386d7f5f/teaser.png
--------------------------------------------------------------------------------
/tester.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | from torchvision.transforms import Compose, ToTensor
4 | from resnet_unet import ResNetUNet
5 | import cv2
6 | import numpy as np
7 | from PIL import Image
8 |
9 | img_transform = Compose([
10 | ToTensor()
11 | ])
12 |
13 | model = ResNetUNet(n_class = 3).cuda()
14 | model.load_state_dict(torch.load('data/model.pth'))
15 |
16 | model.eval()
17 |
18 | img = img_transform(Image.open('data/cropped/0.jpg')).unsqueeze(0)
19 | img = Variable(img.cuda())
20 |
21 | outs = model(img)[0]
22 | out = np.array(outs[0].data.permute(1,2,0).cpu())
23 | out = out / np.expand_dims(np.sqrt(np.sum(out * out, 2)),2)
24 | out = 127.5 * (out + 1.0)
25 |
26 | cv2.imwrite('data/output/0.jpg', cv2.cvtColor(out.astype(np.uint8), cv2.COLOR_RGB2BGR))
27 |
28 |
29 |
30 |
31 |
32 |
--------------------------------------------------------------------------------