├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── environment.yml
├── examples
│   ├── 3dprint.jpg
│   ├── jupyter_gif.gif
│   └── sample.jpg
├── pix2vertex
│   ├── __init__.py
│   ├── constants.py
│   ├── detector.py
│   ├── models
│   │   ├── __init__.py
│   │   └── pix2pix.py
│   ├── reconstructor.py
│   └── utils.py
├── reconstruct_pipeline.ipynb
├── requirements.txt
├── setup.py
├── tests
│   └── test_pix2vertex.py
└── weights
    └── README.md
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints
2 | build
3 | dist
4 | *__pycache__*
5 | *.egg-info*
6 |
7 | *.bz2
8 | *.dat
9 | *.pth
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Elad Richardson and Matan Sela
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements.txt
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Unrestricted Facial Geometry Reconstruction Using Image-to-Image Translation - Official PyTorch Implementation
2 |
3 | [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/eladrich/pix2vertex.pytorch/mybinder?filepath=reconstruct_pipeline.ipynb)
4 | [![PyPI version](https://badge.fury.io/py/pix2vertex.svg)](https://badge.fury.io/py/pix2vertex)
5 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
6 |
7 |
8 | [[arXiv]](https://arxiv.org/pdf/1703.10131.pdf) [[Video]](https://www.youtube.com/watch?v=6lUdSVcBB-k)
9 |
10 |
11 | Evaluation code for Unrestricted Facial Geometry Reconstruction Using Image-to-Image Translation. Finally ported to PyTorch!
12 |
13 |
14 |
15 |
16 |
17 | ## Recent Updates
18 |
19 | **`2020.10.27`**: Added STL support
20 |
21 | **`2020.05.07`**: Added a wheel package!
22 |
23 | **`2020.05.06`**: Added [myBinder](https://mybinder.org/v2/gh/eladrich/pix2vertex.pytorch/mybinder?filepath=reconstruct_pipeline.ipynb) version for quick testing of the model
24 |
25 | **`2020.04.30`**: Initial PyTorch release
26 |
27 | # What's in this release?
28 |
29 | The [original pix2vertex repo](https://github.com/matansel/pix2vertex) was composed of three parts:
30 | - A network that maps an input image to depth + correspondence maps, trained on synthetic facial data
31 | - A non-rigid ICP scheme for converting the output maps to a full 3D mesh
32 | - A shape-from-shading scheme for adding fine mesoscopic details
33 |
34 |
35 | This repo currently contains our image-to-image network, ported to `PyTorch` together with its trained weights, and a simple `python` post-processing scheme.
36 | - The released network was trained on a combination of synthetic images and unlabeled real images for some extra robustness :)
37 |
38 | ## Installation
39 | Installation from PyPi
40 | ```bash
41 | $ pip install pix2vertex
42 | ```
43 | Installation from source
44 | ```bash
45 | $ git clone https://github.com/eladrich/pix2vertex.pytorch.git
46 | $ cd pix2vertex.pytorch
47 | $ python setup.py install
48 | ```
49 | ## Usage
50 | The quickest way to try `p2v` is to run the `reconstruct` method on an input image, followed by visualization or STL creation.
51 | ```python
52 | import pix2vertex as p2v
53 | from imageio import imread
54 |
55 | image = imread('examples/sample.jpg')
56 | result, crop = p2v.reconstruct(image)
57 |
58 | # Interactive visualization in a notebook
59 | p2v.vis_depth_interactive(result['Z_surface'])
60 |
61 | # Static visualization using matplotlib
62 | p2v.vis_depth_matplotlib(crop, result['Z_surface'])
63 |
64 | # Export to STL
65 | p2v.save2stl(result['Z_surface'], 'res.stl')
66 | ```
67 | For a more complete example see the `reconstruct_pipeline` notebook. You can give it a try without any installations using our [binder port](https://mybinder.org/v2/gh/eladrich/pix2vertex.pytorch/mybinder?filepath=reconstruct_pipeline.ipynb).
68 |
69 | ### Pretrained Model
70 | Models can be downloaded from these links:
71 | - [pix2vertex model](https://drive.google.com/open?id=1op5_zyH4CWm_JFDdCUPZM4X-A045ETex)
72 | - [dlib landmark predictor](http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2) - note that the dlib model has its own license.
73 |
74 | If no model path is specified, the package automagically downloads the required models.
75 |
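Both classes also accept explicit paths if you prefer to manage the files yourself. A minimal sketch, assuming the downloaded files were placed under `weights/` at the repository root:

```python
import pix2vertex as p2v

detector = p2v.Detector(predictor_path='weights/shape_predictor_68_face_landmarks.dat')
reconstructor = p2v.Reconstructor(weights_path='weights/faces_hybrid_and_rotated_2.pth',
                                  detector=detector)
result, crop = reconstructor.run('examples/sample.jpg')
```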
76 |
77 | ## TODOs
78 | - [x] Port Torch model to PyTorch
79 | - [x] Release an inference notebook (using [K3D](https://github.com/K3D-tools/K3D-jupyter))
80 | - [x] Add requirements
81 | - [x] Pack as wheel
82 | - [x] Port to MyBinder
83 | - [x] Add a simple method to export an STL file for printing
84 | - [ ] Port the Shape-from-Shading method from the original MATLAB implementation
85 | - [ ] Write a short blog about the revised training scheme
86 |
87 | ## Citation
88 | If you use this code for your research, please cite our paper Unrestricted Facial Geometry Reconstruction Using Image-to-Image Translation:
89 |
90 | ```
91 | @article{sela2017unrestricted,
92 | title={Unrestricted Facial Geometry Reconstruction Using Image-to-Image Translation},
93 | author={Sela, Matan and Richardson, Elad and Kimmel, Ron},
94 |   journal={arXiv preprint arXiv:1703.10131},
95 | year={2017}
96 | }
97 | ```
98 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: py36
2 | channels:
3 | - menpo
4 | - defaults
5 | dependencies:
6 | - python=3.6
7 | - dlib
8 | - pip
9 | - pip:
10 | - matplotlib
11 | - scikit-image
12 | - torch
13 | - colormap
14 | - easydev
15 | - k3d==2.7.4
16 | - tqdm
17 | - six
18 | - requests
--------------------------------------------------------------------------------
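The environment above can be materialized with the standard conda workflow (`py36` is the environment name declared on the file's first line):

```bash
conda env create -f environment.yml
conda activate py36
```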
/examples/3dprint.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladrich/pix2vertex.pytorch/1d2bb61c91584af0ca1495b97a6a6de847df2cec/examples/3dprint.jpg
--------------------------------------------------------------------------------
/examples/jupyter_gif.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladrich/pix2vertex.pytorch/1d2bb61c91584af0ca1495b97a6a6de847df2cec/examples/jupyter_gif.gif
--------------------------------------------------------------------------------
/examples/sample.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladrich/pix2vertex.pytorch/1d2bb61c91584af0ca1495b97a6a6de847df2cec/examples/sample.jpg
--------------------------------------------------------------------------------
/pix2vertex/__init__.py:
--------------------------------------------------------------------------------
1 | from .reconstructor import Reconstructor
2 | from .detector import Detector
3 | from .utils import vis_net_result, vis_depth_interactive, vis_pcloud_interactive, vis_depth_matplotlib, save2stl
4 |
5 | reconstructor = None
6 | def reconstruct(image=None, verbose=False):
7 | global reconstructor
8 | if reconstructor is None:
9 | reconstructor = Reconstructor()
10 | if image is None:
11 | import os
12 | from .constants import sample_image
13 | image = os.path.join(os.path.dirname(__file__), sample_image)
14 | print('No image specified, using {} as default input image'.format(image))
15 | return reconstructor.run(image, verbose)
16 |
--------------------------------------------------------------------------------
/pix2vertex/constants.py:
--------------------------------------------------------------------------------
1 | predictor_url = 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2'
2 | predictor_file = '../weights/shape_predictor_68_face_landmarks.dat'
3 | p2v_model_gdrive_id = '1op5_zyH4CWm_JFDdCUPZM4X-A045ETex'
4 | sample_image = '../examples/sample.jpg'
5 |
--------------------------------------------------------------------------------
/pix2vertex/detector.py:
--------------------------------------------------------------------------------
1 | import os
2 | import dlib
3 | import numpy as np
4 | import math
5 | from skimage.transform import resize
6 | from .utils import download_url, extract_file
7 |
8 |
9 | class Detector:
10 | def __init__(self, predictor_path=None):
11 | self.detector = dlib.get_frontal_face_detector()
12 | self.set_predictor(predictor_path)
13 |
14 | def set_predictor(self, predictor_path):
15 | if predictor_path is None:
16 | from .constants import predictor_file
17 | predictor_path = os.path.join(os.path.dirname(__file__), predictor_file)
18 | print('Loading default detector weights from {}'.format(predictor_path))
19 | if not os.path.exists(predictor_path):
20 | from .constants import predictor_url
21 | os.makedirs(os.path.dirname(predictor_path), exist_ok=True)
22 | print('\tDownloading weights from {}...'.format(predictor_url))
23 | download_url(predictor_url, save_path=os.path.dirname(predictor_path))
24 | print('\tExtracting weights...')
25 | extract_file(predictor_path + '.bz2', os.path.dirname(predictor_path))
26 | print('\tDone!')
27 | self.predictor = dlib.shape_predictor(predictor_path)
28 |
29 | def detect_and_crop(self, img, img_size=512):
30 |         dets = self.detector(img, 1)  # the second argument upsamples the image once to help find faces
31 | for k, d in enumerate(dets):
32 | print("Detection {}: Left: {} Top: {} Right: {} Bottom: {}".format(
33 | k, d.left(), d.top(), d.right(), d.bottom()))
34 |         if len(dets) == 0:
35 |             raise ValueError('No face detected in the input image')
36 |         dets = dets[0]  # keep only the first detection
35 | shape = self.predictor(img, dets)
36 | points = shape.parts()
37 |
38 | pts = np.array([[p.x, p.y] for p in points])
39 | min_x = np.min(pts[:, 0])
40 | min_y = np.min(pts[:, 1])
41 | max_x = np.max(pts[:, 0])
42 | max_y = np.max(pts[:, 1])
43 | box_width = (max_x - min_x) * 1.2
44 | box_height = (max_y - min_y) * 1.2
45 |         bbox = np.array([min_y - box_height * 0.3, min_x, box_height, box_width]).astype(int)
46 |
47 | img_crop = Detector.adjust_box_and_crop(img, bbox, crop_percent=150, img_size=img_size)
49 | return img_crop
50 |
51 | @staticmethod
52 | def adjust_box_and_crop(img, bbox, crop_percent=100, img_size=None):
53 | w_ext = math.floor(bbox[2])
54 | h_ext = math.floor(bbox[3])
55 | bbox_center = np.round(np.array([bbox[0] + 0.5 * bbox[2], bbox[1] + 0.5 * bbox[3]]))
56 | max_ext = np.round(crop_percent / 100 * max(w_ext, h_ext) / 2)
57 | top = max(1, bbox_center[0] - max_ext)
58 | left = max(1, bbox_center[1] - max_ext)
59 | bottom = min(img.shape[0], bbox_center[0] + max_ext)
60 | right = min(img.shape[1], bbox_center[1] + max_ext)
61 | height = bottom - top
62 | width = right - left
63 | # make the frame as square as possible
64 | if height < width:
65 | diff = width - height
66 | top_pad = int(max(0, np.floor(diff / 2) - top + 1))
67 | top = max(1, top - np.floor(diff / 2))
68 | bottom_pad = int(max(0, bottom + np.ceil(diff / 2) - img.shape[0]))
69 | bottom = min(img.shape[0], bottom + np.ceil(diff / 2))
70 | left_pad = 0
71 | right_pad = 0
72 | else:
73 | diff = height - width
74 | left_pad = int(max(0, np.floor(diff / 2) - left + 1))
75 | left = max(1, left - np.floor(diff / 2))
76 | right_pad = int(max(0, right + np.ceil(diff / 2) - img.shape[1]))
77 | right = min(img.shape[1], right + np.ceil(diff / 2))
78 | top_pad = 0
79 | bottom_pad = 0
80 |
81 | # crop the image
82 | img_crop = img[int(top):int(bottom), int(left):int(right), :]
83 | # pad the image
84 | img_crop = np.pad(img_crop, ((top_pad, bottom_pad), (left_pad, right_pad), (0, 0)), 'constant')
85 | if img_size is not None:
86 | img_crop = resize(img_crop, (img_size, img_size))*255
87 | img_crop = img_crop.astype(np.uint8)
88 | return img_crop
89 |
--------------------------------------------------------------------------------
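A minimal sketch of driving `Detector` on its own, assuming it runs from the repository root so the bundled sample image resolves:

```python
from imageio import imread
from pix2vertex import Detector

detector = Detector()  # downloads the dlib landmark predictor on first use
img = imread('examples/sample.jpg')
crop = detector.detect_and_crop(img, img_size=512)  # square uint8 crop around the detected face
print(crop.shape)  # (512, 512, 3)
```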
/pix2vertex/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eladrich/pix2vertex.pytorch/1d2bb61c91584af0ca1495b97a6a6de847df2cec/pix2vertex/models/__init__.py
--------------------------------------------------------------------------------
/pix2vertex/models/pix2pix.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | def conv_block(in_channels, out_channels):
5 | return nn.Sequential(
6 | nn.LeakyReLU(0.2,inplace=True),
7 | nn.Conv2d(in_channels, out_channels, 4, stride=2,padding=1),
8 | nn.BatchNorm2d(out_channels)
9 | )
10 |
11 |
12 | def deconv_block(in_channels, out_channels,use_dropout=False):
13 | layers = [
14 | nn.ReLU(inplace=True),
15 | nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2,padding=1),
16 | nn.BatchNorm2d(out_channels)
17 | ]
18 | if use_dropout:
19 | layers.append(nn.Dropout(0.5))
20 | return nn.Sequential(*layers)
21 |
22 | class UNet(nn.Module):
23 |
24 | def __init__(self):
25 | super().__init__()
26 |
27 | self.conv_down1 = nn.Conv2d(3,64,4,stride=2,padding=1)
28 |
29 | self.conv_down2 = conv_block(64,128)
30 |
31 | self.conv_down3 = conv_block(128,256)
32 |
33 | self.conv_down4 = conv_block(256,512)
34 |
35 | self.conv_down5 = conv_block(512,512)
36 |
37 | self.conv_down6 = conv_block(512,512)
38 |
39 | self.conv_down7 = conv_block(512,512)
40 |
41 | self.conv_down8 = conv_block(512,512)
42 |
43 | self.conv_up1 = deconv_block(512,512,use_dropout=True)
44 |
45 | self.conv_up2 = deconv_block(1024,512,use_dropout=True)
46 |
47 | self.conv_up3 = deconv_block(1024,512,use_dropout=True)
48 |
49 | self.conv_up4 = deconv_block(1024,512)
50 |
51 | self.conv_up5 = deconv_block(1024,256)
52 |
53 | self.conv_up6 = deconv_block(512,128)
54 |
55 | self.conv_up7 = deconv_block(256,64)
56 |
57 | self.conv_up8 = deconv_block(128,64)
58 |
59 | self.conv_up9 = nn.Sequential(
60 | nn.ReLU(inplace=True),
61 | nn.ConvTranspose2d(64, 64, 3, stride=1,padding=1),
62 | nn.BatchNorm2d(64)
63 | )
64 |
71 | self.conv_up10 = nn.Sequential(
72 | nn.ReLU(inplace=True),
73 | nn.ConvTranspose2d(64, 32, 3, stride=1,padding=1),
74 | nn.BatchNorm2d(32)
75 | )
76 |
77 | self.conv_up11 = nn.Sequential(
78 | nn.ReLU(inplace=True),
79 | nn.ConvTranspose2d(32, 7, 3, stride=1,padding=1)
80 | )
81 |
82 | # TODO: rewrite nicely
83 |
84 | def forward(self, x):
85 | down1 = self.conv_down1(x)
86 | down2 = self.conv_down2(down1)
87 | down3 = self.conv_down3(down2)
88 | down4 = self.conv_down4(down3)
89 | down5 = self.conv_down5(down4)
90 | down6 = self.conv_down6(down5)
91 | down7 = self.conv_down7(down6)
92 | down8 = self.conv_down8(down7)
93 |
94 | up1 = self.conv_up1(down8)
95 | up2 = self.conv_up2(torch.cat([up1, down7], dim=1))
96 | up3 = self.conv_up3(torch.cat([up2, down6], dim=1))
97 | up4 = self.conv_up4(torch.cat([up3, down5], dim=1))
98 | up5 = self.conv_up5(torch.cat([up4, down4], dim=1))
99 | up6 = self.conv_up6(torch.cat([up5, down3], dim=1))
100 | up7 = self.conv_up7(torch.cat([up6, down2], dim=1))
101 | up8 = self.conv_up8(torch.cat([up7, down1], dim=1))
102 | up9 = self.conv_up9(up8)
103 | up10 = self.conv_up10(up9)
104 | up11 = self.conv_up11(up10)
105 |
106 | return up11
107 |
--------------------------------------------------------------------------------
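For orientation: the encoder halves the resolution eight times (512 → 2) and the decoder mirrors it back through skip connections, so a 512×512 RGB crop maps to a 7-channel 512×512 output; `Reconstructor.run_net` reads channels 0-2 as the PNCC map and 3-5 as depth. A quick shape check, as a sketch:

```python
import torch
from pix2vertex.models.pix2pix import UNet

net = UNet()
net.train()  # kept in train mode so BatchNorm behaves like instance normalization
with torch.no_grad():
    out = net(torch.zeros(1, 3, 512, 512))  # a dummy 512x512 RGB crop
print(out.shape)  # torch.Size([1, 7, 512, 512])
```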
/pix2vertex/reconstructor.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import numpy as np
4 | import torch
5 | from imageio import imread
6 |
7 | from .models import pix2pix
8 | from .detector import Detector
9 |
10 | class Reconstructor:
11 | def __init__(self, weights_path=None, detector=None):
12 | if detector is None:
13 | detector = Detector()
14 | self.detector = detector
15 | self.unet = pix2pix.UNet()
16 | self.set_initial_weights(weights_path)
17 |         self.unet.train()  # as in the original pix2pix, train mode makes BatchNorm act like InstanceNormalization
18 |
19 | def set_initial_weights(self, weights_path):
20 | if weights_path is None:
21 | weights_path = os.path.join(os.path.dirname(__file__),
22 | '../weights/faces_hybrid_and_rotated_2.pth')
23 | print('loading default reconstructor weights from {}'.format(weights_path))
24 | if not os.path.exists(weights_path):
25 | from .utils import download_from_gdrive
26 | from .constants import p2v_model_gdrive_id
27 | os.makedirs(os.path.dirname(weights_path), exist_ok=True)
28 | print('\tDownloading weights...')
29 | download_from_gdrive(p2v_model_gdrive_id, weights_path)
30 | print('\tDone!')
31 |         self.initial_weights = torch.load(weights_path, map_location='cpu')
32 |
33 | def run(self, image, verbose=False):
34 |         if isinstance(image, str):
35 | image = imread(image)
36 | image_cropped = self.detector.detect_and_crop(image)
37 | net_res = self.run_net(image_cropped)
38 | final_res = self.post_process(net_res)
39 | if verbose:
40 | from . import vis_depth_interactive
41 | vis_depth_interactive(final_res['Z_surface'])
42 | return final_res, image_cropped
43 |
44 | def run_net(self, img):
45 |         # BatchNorm runs in train mode (acting as instance normalization), so its running stats drift; restore the original weights before every forward pass
46 | self.unet.load_state_dict(copy.deepcopy(self.initial_weights), strict=True)
47 |
48 |         # Forward pass (input scaled to [-1, 1], output scaled back to [0, 255])
49 |         net_input = torch.from_numpy(img.transpose()).float()
50 |         net_input = net_input.unsqueeze(0)
51 |         net_input = net_input.transpose(2, 3)
52 |         net_input = net_input.div(255.0).mul(2).add(-1)
53 |         output = self.unet(net_input)
54 | output = output.add(1).div(2).mul(255)
55 |
56 | # Post Processing
57 | im_both = output.squeeze(0).detach().numpy().transpose().swapaxes(0, 1).copy()
58 | im_pncc = im_both[:, :, 0:3]
59 | im_depth = im_both[:, :, 3:6]
60 | im_depth[np.logical_and(im_depth < 10, im_depth > -10)] = 0
61 | im_pncc[np.logical_and(im_pncc < 10, im_pncc > -10)] = 0
62 |
63 |         return {'pncc': im_pncc, 'depth': im_depth}
64 |
65 | def post_process(self, net_res):
66 |         im_pncc = net_res['pncc'].astype(np.float64)
67 | im_depth = net_res['depth'].astype(np.float64)
68 | net_X = im_depth[:, :, 0] * (1.3674) / 255 - 0.6852
69 | net_Y = im_depth[:, :, 1] * (1.8401) / 255 - 0.9035
70 | net_Z = im_depth[:, :, 2] * (0.7542) / 255 - 0.2997
71 | mask = np.any(im_depth, axis=2) * np.all(im_pncc, axis=2)
72 |
73 | X = np.tile(np.linspace(-1, 1, im_depth.shape[1]), (im_depth.shape[0], 1))
74 | Y = np.tile(np.linspace(1, -1, im_depth.shape[0]).reshape(-1, 1), (1, im_depth.shape[1]))
75 |
76 | # Normalize fixed grid according to the network result, as X,Y are actually redundant
77 | X = (X - np.mean(X[mask])) / np.std(X[mask]) * np.std(net_X[mask]) + np.mean(net_X[mask])
78 | Y = (Y - np.mean(Y[mask])) / np.std(Y[mask]) * np.std(net_Y[mask]) + np.mean(net_Y[mask])
79 |
80 | Z = net_Z * 2 # Due to image resizing
81 |
82 | f = 1 / (X[0, 1] - X[0, 0])
83 |
84 | Z_surface = Z * f
85 |         Z_surface[~mask] = np.nan
86 |         Z[~mask] = np.nan
87 |
88 | return {'Z': Z, 'X': X, 'Y': Y, 'Z_surface': Z_surface}
89 |
--------------------------------------------------------------------------------
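The two stages can also be run separately, which is what the bundled notebook does. A sketch, assuming the sample image from this repo:

```python
from imageio import imread
from pix2vertex import Reconstructor

reconstructor = Reconstructor()  # downloads the default weights if missing
crop = reconstructor.detector.detect_and_crop(imread('examples/sample.jpg'))
net_res = reconstructor.run_net(crop)            # {'pncc': ..., 'depth': ...}
final_res = reconstructor.post_process(net_res)  # {'Z', 'X', 'Y', 'Z_surface'}
```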
/pix2vertex/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import struct
3 | import tarfile
4 | import zipfile
5 | from itertools import product
6 |
7 | import matplotlib.pyplot as plt
8 | import numpy as np
9 | from tqdm import tqdm
7 |
8 | ASCII_FACET = """ facet normal {face[0]:e} {face[1]:e} {face[2]:e}
9 | outer loop
10 | vertex {face[3]:e} {face[4]:e} {face[5]:e}
11 | vertex {face[6]:e} {face[7]:e} {face[8]:e}
12 | vertex {face[9]:e} {face[10]:e} {face[11]:e}
13 | endloop
14 | endfacet"""
15 |
16 | BINARY_HEADER = "80sI"
17 | BINARY_FACET = "12fH"
18 |
19 |
20 | # Saving to STL is based on https://github.com/thearn/stl_tools/
21 |
22 | def vis_depth_matplotlib(img, Z, elevation=60, azimuth=45, stride=5):
23 |     from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the '3d' projection)
24 |     from matplotlib.colors import LightSource
25 |
26 |     fig = plt.figure()
27 |     ax = fig.add_subplot(projection='3d')  # fig.gca(projection='3d') is deprecated
28 |
29 |     # Create X and Y grids matching the depth map dimensions
30 |     x = np.arange(0, Z.shape[1], 1)
31 |     y = np.arange(0, Z.shape[0], 1)
32 |     X, Y = np.meshgrid(x, y)
33 |
34 |     ls = LightSource(azdeg=0, altdeg=90)
35 |     # Shade the input image with the depth data, creating an RGB array
36 |     rgb = ls.shade_rgb(img.astype('float') / 255, Z, blend_mode='overlay')
37 |
38 |     ax.plot_surface(X, Y, Z, cstride=stride, rstride=stride, linewidth=0, antialiased=False, facecolors=rgb)
39 |     ax.view_init(elev=elevation, azim=azimuth)
40 |     plt.show()
46 |
47 |
48 | def vis_depth_interactive(Z):
49 | import k3d
50 |     xmin, xmax = 0, Z.shape[0]
51 |     ymin, ymax = 0, Z.shape[1]
57 | plot = k3d.plot(grid_auto_fit=True, camera_auto_fit=False)
58 | plt_surface = k3d.surface(-Z.astype(np.float32), color=0xb2ccff,
59 | bounds=[xmin, xmax, ymin, ymax]) # -Z for mirroring
60 | plot += plt_surface
61 | plot.display()
62 | plot.camera = [242.57934019166004, 267.50948550191197, -406.62328311352337, 256, 256, -8.300323486328125,
63 | -0.13796270478729053, -0.987256298362836, -0.07931767413815752]
64 | return plot
65 |
66 |
67 | def vis_pcloud_interactive(res, img):
68 | import k3d
69 | from colormap import rgb2hex
70 | color_vals = np.zeros((img.shape[0], img.shape[1]))
71 | for i in range(img.shape[0]):
72 | for j in range(img.shape[1]):
73 | color_vals[i, j] = int(rgb2hex(img[i, j, 0], img[i, j, 1], img[i, j, 2]).replace('#', '0x'), 0)
74 |     colors = color_vals.flatten().astype(np.uint32)  # k3d expects uint32 colors
75 |
76 | points = np.stack((res['X'].flatten(), res['Y'].flatten(), res['Z'].flatten()), axis=1)
77 |
78 |     invalid_inds = np.any(np.isnan(points), axis=1)
79 |     points_valid = points[~invalid_inds]
80 |     colors_valid = colors[~invalid_inds]
81 | plot = k3d.plot(grid_auto_fit=True, camera_auto_fit=False)
82 | plot += k3d.points(points_valid, colors_valid, point_size=0.01, compression_level=9, shader='flat')
83 | plot.display()
84 | plot.camera = [-0.3568942548181382, -0.12775125650240726, 3.5390732533009452, 0.33508163690567017,
85 | 0.3904658555984497, -0.0499117374420166, 0.11033077266672488, 0.9696364582197756, 0.2182481603445357]
86 | return plot
87 |
88 |
89 | def vis_net_result(img, net_result):
90 | plt.figure()
91 | plt.subplot(1, 3, 1)
92 | plt.imshow(img)
93 | plt.title('Input Image')
94 | plt.subplot(1, 3, 2)
95 |     plt.imshow(net_result['pncc'].astype(np.uint8))
96 | plt.title('PNCC Visualization')
97 | plt.subplot(1, 3, 3)
98 | plt.imshow(net_result['depth'][:, :, 2].astype(np.uint8), cmap='gray')
99 | plt.title('Depth Visualization')
100 | plt.show()
101 |
102 |
103 | class TqdmUpTo(tqdm):
104 | """Provides `update_to(n)` which uses `tqdm.update(delta_n)`."""
105 |
106 | def update_to(self, b=1, bsize=1, tsize=None):
107 | """
108 | b : int, optional
109 | Number of blocks transferred so far [default: 1].
110 | bsize : int, optional
111 | Size of each block (in tqdm units) [default: 1].
112 | tsize : int, optional
113 | Total size (in tqdm units). If [default: None] remains unchanged.
114 | """
115 | if tsize is not None:
116 | self.total = tsize
117 | self.update(b * bsize - self.n) # will also set self.n = b * bsize
118 |
119 |
120 | def download_url(url, save_path):
121 | from six.moves import urllib
122 | save_path = os.path.expanduser(save_path)
123 |     if not os.path.exists(save_path):
124 |         os.makedirs(save_path, exist_ok=True)
125 |
126 | filename = url.rpartition('/')[2]
127 | filepath = os.path.join(save_path, filename)
128 |
129 | try:
130 | with TqdmUpTo(unit='B', unit_scale=True, unit_divisor=1024, miniters=1,
131 | desc=url.split('/')[-1]) as t: # all optional kwargs
132 | urllib.request.urlretrieve(url, filepath, reporthook=t.update_to)
133 | t.total = t.n
134 | except ValueError:
135 | raise Exception('Failed to download! Check URL: ' + url +
136 | ' and local path: ' + save_path)
137 |
138 |
139 | def extract_file(path, to_directory=None):
140 | path = os.path.expanduser(path)
141 | if path.endswith('.zip'):
142 | opener, mode = zipfile.ZipFile, 'r'
143 | elif path.endswith(('.tar.gz', '.tgz')):
144 | opener, mode = tarfile.open, 'r:gz'
145 | elif path.endswith(('tar.bz2', '.tbz')):
146 | opener, mode = tarfile.open, 'r:bz2'
147 | elif path.endswith('.bz2'):
148 | import bz2
149 | opener, mode = bz2.BZ2File, 'rb'
150 | with open(path[:-4], 'wb') as fp_out, opener(path, 'rb') as fp_in:
151 | for data in iter(lambda: fp_in.read(100 * 1024), b''):
152 | fp_out.write(data)
153 | return
154 |     else:
155 |         raise ValueError(
156 |             "Could not extract `{}` as no extractor is found!".format(path))
157 |
158 | if to_directory is None:
159 | to_directory = os.path.abspath(os.path.join(path, os.path.pardir))
160 | cwd = os.getcwd()
161 | os.chdir(to_directory)
162 |
163 | try:
164 | file = opener(path, mode)
165 | try:
166 | file.extractall()
167 | finally:
168 | file.close()
169 | finally:
170 | os.chdir(cwd)
171 |
172 |
173 | def download_from_gdrive(id, destination):
174 | import requests
175 | URL = "https://docs.google.com/uc?export=download"
176 |
177 | session = requests.Session()
178 |
179 | response = session.get(URL, params={'id': id}, stream=True)
180 | token = get_confirm_token(response)
181 |
182 | if token:
183 | params = {'id': id, 'confirm': token}
184 | response = session.get(URL, params=params, stream=True)
185 |
186 | save_response_content(response, destination)
187 |
188 |
189 | def get_confirm_token(response):
190 | for key, value in response.cookies.items():
191 | if key.startswith('download_warning'):
192 | return value
193 |
194 | return None
195 |
196 |
197 | def save_response_content(response, destination):
198 | CHUNK_SIZE = 32768
199 | t = tqdm(unit='B', unit_scale=True, miniters=1)
200 |
201 | with open(destination, "wb") as f:
202 | for chunk in response.iter_content(CHUNK_SIZE):
203 | if chunk: # filter out keep-alive new chunks
204 | t.update(len(chunk))
205 | f.write(chunk)
206 |
207 |
208 | def _build_binary_stl(facets):
209 | """returns a string of binary binary data for the stl file"""
210 |
211 | lines = [struct.pack(BINARY_HEADER, b'Binary STL Writer', len(facets)), ]
212 | for facet in facets:
213 | facet = list(facet)
214 |         facet.append(0)  # pad the end with an unsigned short (the attribute byte count)
215 | lines.append(struct.pack(BINARY_FACET, *facet))
216 | return lines
217 |
218 |
219 | def _build_ascii_stl(facets):
220 | """returns a list of ascii lines for the stl file """
221 |
222 | lines = ['solid ffd_geom', ]
223 | for facet in facets:
224 | lines.append(ASCII_FACET.format(face=facet))
225 | lines.append('endsolid ffd_geom')
226 | return lines
227 |
228 |
229 | def writeSTL(facets, file_name, ascii=False):
230 |     """writes an ASCII or binary STL file"""
231 |
232 |     with open(file_name, 'wb') as f:
233 |         if ascii:
234 |             lines = _build_ascii_stl(facets)
235 |             f.write("\n".join(lines).encode("UTF-8"))
236 |         else:
237 |             f.write(b"".join(_build_binary_stl(facets)))
243 |
244 | def roll2d(image, shifts):
245 |     """2D roll along both axes (used when computing the solid-geometry edge mask)"""
246 |     return np.roll(np.roll(image, shifts[0], axis=0), shifts[1], axis=1)
247 |
248 |
245 | def save2stl(A, fn, scale=1, mask_val=None, ascii=False,
246 | max_width=235.,
247 | max_depth=140.,
248 | max_height=150.,
249 | solid=False,
250 | rotate=True,
251 | min_thickness_percent=0.1):
252 | """
253 | Reads a numpy array, and outputs an STL file
254 | Inputs:
255 | A (ndarray) - an 'm' by 'n' 2D numpy array
256 | fn (string) - filename to use for STL file
257 | Optional input:
258 | scale (float) - scales the height (surface) of the
259 | resulting STL mesh. Tune to match needs
260 | mask_val (float) - any element of the inputted array that is less
261 | than this value will not be included in the mesh.
262 | default renders all vertices (x > -inf for all float x)
263 | ascii (bool) - sets the STL format to ascii or binary (default)
264 | max_width, max_depth, max_height (floats) - maximum size of the stl
265 | object (in mm). Match this to
266 | the dimensions of a 3D printer
267 | platform
268 | solid (bool): sets whether to create a solid geometry (with sides and
269 | a bottom) or not.
270 | min_thickness_percent (float) : when creating the solid bottom face, this
271 | multiplier sets the minimum thickness in
272 | the final geometry (shallowest interior
273 | point to bottom face), as a percentage of
274 | the thickness of the model computed up to
275 | that point.
276 |     rotate (bool) - when True (default), the array is rotated so its longer
277 |                     side lies along the platform width.
278 |     Returns: (None)
277 | """
278 |
279 | # Remove Nans, set their values as the minimal one
280 | A = A.copy()
281 | A[np.isnan(A)] = A[~np.isnan(A)].min()
282 |
283 | m, n = A.shape
284 | if n >= m and rotate:
285 | # rotate to best fit a printing platform
286 | A = np.rot90(A, k=3)
287 | m, n = n, m
288 | A = scale * (A - A.min())
289 |
290 |     if mask_val is None:
291 |         mask_val = A.min()
292 |
293 | facets = []
294 | mask = np.zeros((m, n))
295 | print("Creating top mesh...")
296 | for i, k in product(range(m - 1), range(n - 1)):
297 | this_pt = np.array([i - m / 2., k - n / 2., A[i, k]])
298 | top_right = np.array([i - m / 2., k + 1 - n / 2., A[i, k + 1]])
299 | bottom_left = np.array([i + 1. - m / 2., k - n / 2., A[i + 1, k]])
300 | bottom_right = np.array(
301 | [i + 1. - m / 2., k + 1 - n / 2., A[i + 1, k + 1]])
302 |
303 | n1, n2 = np.zeros(3), np.zeros(3)
304 | is_a, is_b = False, False
305 | if (this_pt[-1] > mask_val and top_right[-1] > mask_val and
306 | bottom_left[-1] > mask_val):
307 | facet = np.concatenate([n1, top_right, this_pt, bottom_right])
308 | mask[i, k] = 1
309 | mask[i, k + 1] = 1
310 | mask[i + 1, k] = 1
311 | facets.append(facet)
312 |
313 | if (this_pt[-1] > mask_val and bottom_right[-1] > mask_val and
314 | bottom_left[-1] > mask_val):
315 | facet = np.concatenate(
316 | [n2, bottom_right, this_pt, bottom_left])
317 | facets.append(facet)
318 | mask[i, k] = 1
319 | mask[i + 1, k + 1] = 1
320 | mask[i + 1, k] = 1
321 |
322 | print('\t', len(facets), 'facets')
323 | facets = np.array(facets)
324 |
325 | if solid:
326 | print("Computed edges...")
327 | edge_mask = np.sum([roll2d(mask, (i, k))
328 | for i, k in product([-1, 0, 1], repeat=2)],
329 | axis=0)
330 | edge_mask[np.where(edge_mask == 9.)] = 0.
331 | edge_mask[np.where(edge_mask != 0.)] = 1.
332 | edge_mask[0::m - 1, :] = 1.
333 | edge_mask[:, 0::n - 1] = 1.
334 | X, Y = np.where(edge_mask == 1.)
335 |         locs = set(zip(X - m / 2., Y - n / 2.))  # a set, since membership is tested repeatedly below
336 |
337 | zvals = facets[:, 5::3]
338 | zmin, zthickness = zvals.min(), zvals.ptp()
339 |
340 | minval = zmin - min_thickness_percent * zthickness
341 |
342 | bottom = []
343 | print("Extending edges, creating bottom...")
344 | for i, facet in enumerate(facets):
345 | if (facet[3], facet[4]) in locs:
346 | facets[i][5] = minval
347 | if (facet[6], facet[7]) in locs:
348 | facets[i][8] = minval
349 | if (facet[9], facet[10]) in locs:
350 | facets[i][11] = minval
351 | this_bottom = np.concatenate(
352 | [facet[:3], facet[6:8], [minval], facet[3:5], [minval],
353 | facet[9:11], [minval]])
354 | bottom.append(this_bottom)
355 |
356 | facets = np.concatenate([facets, bottom])
357 |
358 | xsize = facets[:, 3::3].ptp()
359 | if xsize > max_width:
360 | facets = facets * float(max_width) / xsize
361 |
362 | ysize = facets[:, 4::3].ptp()
363 | if ysize > max_depth:
364 | facets = facets * float(max_depth) / ysize
365 |
366 | zsize = facets[:, 5::3].ptp()
367 | if zsize > max_height:
368 | facets = facets * float(max_height) / zsize
369 |
370 | print('Writing STL...')
371 | writeSTL(facets, fn, ascii=ascii)
372 | print('Done!')
373 |
--------------------------------------------------------------------------------
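As a usage note for `save2stl`, only the height map is required; the remaining arguments have printer-friendly defaults. An illustrative sketch, assuming `result` is the dictionary returned by `p2v.reconstruct` (NaNs outside the face region are handled internally):

```python
from pix2vertex.utils import save2stl

save2stl(result['Z_surface'], 'res.stl',
         scale=1,      # height multiplier, tune to taste
         solid=True,   # add sides and a bottom so the mesh is watertight
         ascii=False)  # binary STL (smaller file)
```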
/reconstruct_pipeline.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2\n",
11 | "\n",
12 | "import matplotlib\n",
13 | "import matplotlib.pyplot as plt\n",
14 | "import numpy as np\n",
15 | "from skimage import io\n",
16 | "import pix2vertex as p2v\n",
17 | "\n",
18 | "matplotlib.rcParams['figure.figsize'] = (13,7)"
19 | ]
20 | },
21 | {
22 | "cell_type": "markdown",
23 | "metadata": {},
24 | "source": [
25 | "## Initializations"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": null,
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "detector = p2v.Detector()\n",
35 | "reconstructor = p2v.Reconstructor(detector=detector)"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "metadata": {},
42 | "outputs": [],
43 | "source": [
44 | "im_path = 'examples/sample.jpg' # im_path can be a URL as well!\n",
45 | "img = io.imread(im_path) \n",
46 | "fig = plt.figure()\n",
47 | "plt.imshow(img)\n",
48 | "plt.show()"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {},
54 | "source": [
55 | "## Inference"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": null,
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "img_crop = detector.detect_and_crop(img)\n",
65 | "fig = plt.figure()\n",
66 | "plt.imshow(img_crop)\n",
67 | "plt.show()"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": null,
73 | "metadata": {},
74 | "outputs": [],
75 | "source": [
76 | "net_res = reconstructor.run_net(img_crop)\n",
77 | "p2v.vis_net_result(img_crop,net_res)\n",
78 | "final_res = reconstructor.post_process(net_res)"
79 | ]
80 | },
81 | {
82 | "cell_type": "markdown",
83 | "metadata": {},
84 | "source": [
85 | "## Interactive Visualizations"
86 | ]
87 | },
88 | {
89 | "cell_type": "code",
90 | "execution_count": null,
91 | "metadata": {},
92 | "outputs": [],
93 | "source": [
94 | "plot = p2v.vis_depth_interactive(final_res['Z_surface'])"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "metadata": {},
101 | "outputs": [],
102 | "source": [
103 | "plot = p2v.vis_pcloud_interactive(final_res,img_crop)"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": null,
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "# Fallback matplotlib visualization\n",
113 | "p2v.vis_depth_matplotlib(img_crop,final_res['Z_surface'])"
114 | ]
115 | },
116 | {
117 | "cell_type": "markdown",
118 | "metadata": {},
119 | "source": [
120 | "## Saving Result"
121 | ]
122 | },
123 | {
124 | "cell_type": "code",
125 | "execution_count": null,
126 | "metadata": {},
127 | "outputs": [],
128 | "source": [
129 | "p2v.save2stl(final_res['Z_surface'],'res.stl')"
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "metadata": {},
135 | "source": [
136 | "Create link to make accessible from notebook"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": null,
142 | "metadata": {},
143 | "outputs": [],
144 | "source": [
145 | "from IPython.display import FileLink\n",
146 | "FileLink('res.stl')"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": null,
152 | "metadata": {},
153 | "outputs": [],
154 | "source": []
155 | }
156 | ],
157 | "metadata": {
158 | "kernelspec": {
159 | "display_name": "Python 3",
160 | "language": "python",
161 | "name": "python3"
162 | },
163 | "language_info": {
164 | "codemirror_mode": {
165 | "name": "ipython",
166 | "version": 3
167 | },
168 | "file_extension": ".py",
169 | "mimetype": "text/x-python",
170 | "name": "python",
171 | "nbconvert_exporter": "python",
172 | "pygments_lexer": "ipython3",
173 | "version": "3.7.6"
174 | }
175 | },
176 | "nbformat": 4,
177 | "nbformat_minor": 4
178 | }
179 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | k3d==2.7.4
2 | easydev==0.9.38
3 | colormap==1.0.3
4 | torch
5 | imageio
6 | dlib
7 | scikit-image
8 | matplotlib
9 | tqdm
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # MIT License
4 | #
5 | # Copyright (c) 2020 Elad Richardson and Matan Sela
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | import os
26 | import setuptools
27 |
28 | with open(os.path.join(os.path.dirname(__file__), "README.md"), "r") as fh:
29 | long_description = fh.read()
30 |
31 | with open(os.path.join(os.path.dirname(__file__), "requirements.txt")) as f:
32 | required = f.read().splitlines()
33 |
34 | setuptools.setup(
35 | name="pix2vertex", # Replace with your own username
36 | version="1.0.4",
37 | author="Elad Richardson, Matan Sela",
38 | author_email="elad.richardson@gmail.com, matansel@gmail.com",
39 | description="3D face reconstruction from a single image",
40 | long_description=long_description,
41 | long_description_content_type="text/markdown",
42 | url="https://github.com/eladrich/pix2vertex.pytorch",
43 | packages=setuptools.find_packages(exclude=["tests.*", "tests"]),
44 | classifiers=[
45 | "Environment :: Console",
46 | "Intended Audience :: Developers",
47 | "Intended Audience :: Science/Research",
48 | "Intended Audience :: Education",
49 | "Programming Language :: Python :: 3",
50 | "License :: OSI Approved :: MIT License",
51 | "Operating System :: OS Independent",
52 | ],
53 | install_requires=required,
54 | include_package_data=True,
55 | keywords="pix2vertex face reconstruction 3d pytorch pip package",
56 | python_requires='>=3.4',
57 | zip_safe=False
58 | )
59 |
--------------------------------------------------------------------------------
/tests/test_pix2vertex.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from imageio import imread
3 | import pix2vertex as p2v
4 |
5 | from unittest import TestCase
6 |
7 | class TestPix2Vertex(TestCase):
8 | def test_reconstruct(self):
9 | image = imread('examples/sample.jpg')
10 | results = p2v.reconstruct(image)
11 | self.assertEqual(len(results), 2)
12 |
--------------------------------------------------------------------------------
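The test resolves `examples/sample.jpg` relative to the working directory, so run it from the repository root. A typical invocation with `pytest` (an assumption here, as it is not pinned in `requirements.txt`):

```bash
python -m pytest tests
```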
/weights/README.md:
--------------------------------------------------------------------------------
1 | Place the dlib and pix2vertex models here, or simply let the package download them automatically on first use.
2 |
--------------------------------------------------------------------------------