├── config
│   ├── __init__.py
│   ├── config.cfg
│   └── config.py
├── utils
│   ├── __init__.py
│   ├── render.py
│   ├── pytorch3d_extend.py
│   ├── my_dataset.py
│   ├── funcs.py
│   ├── models.py
│   └── completion.py
├── scripts
│   ├── __init__.py
│   ├── face_sample.py
│   ├── face_completion.py
│   └── train.py
├── data
│   └── norm.pt
├── assets
│   ├── image-20240716144300162.png
│   ├── image-20240716144509692.png
│   └── image-20240716144812434.png
├── LICENSE
└── README.md
/config/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/norm.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonylee/FaceCom/HEAD/data/norm.pt
--------------------------------------------------------------------------------
/assets/image-20240716144300162.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonylee/FaceCom/HEAD/assets/image-20240716144300162.png
--------------------------------------------------------------------------------
/assets/image-20240716144509692.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonylee/FaceCom/HEAD/assets/image-20240716144509692.png
--------------------------------------------------------------------------------
/assets/image-20240716144812434.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonylee/FaceCom/HEAD/assets/image-20240716144812434.png
--------------------------------------------------------------------------------
/config/config.cfg:
--------------------------------------------------------------------------------
1 | [I/O parameters]
2 | dataset_dir = PATH_TO_DATASET
3 | template_file = FILE_TO_TEMPLATE
4 | checkpoint_dir = PATH_TO_CHECKPOINTS
5 |
6 | [model parameters]
7 | n_layers = 3
8 | z_length = 256
9 | down_sampling_factors = 4, 4, 4
10 | num_features_local = 3, 8, 16, 32
11 | num_features_global = 3, 64, 512, 2048
12 | batch_norm = 1
13 |
14 | [training parameters]
15 | num_workers = 4
16 | lr = 0.001
17 | batch_size = 32
18 | weight_decay = 0.0005
19 | epoch = 1000
20 | lambda_reg = 0.001
21 |
22 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 dragonylee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/scripts/face_sample.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
5 |
6 | import argparse
7 | import torch
8 | from utils.completion import generate_face_sample
9 | from utils.funcs import load_generator
10 | from config.config import read_config
11 | from os.path import join
12 | from tqdm import tqdm
13 |
14 |
15 | def main():
16 | parser = argparse.ArgumentParser(description='face sample')
17 |
18 | parser.add_argument("--config_file", type=str)
19 | parser.add_argument("--out_dir", type=str)
20 | parser.add_argument("--number", type=int, default=1)
21 |
22 | args = parser.parse_args()
23 |
24 | config = read_config(args.config_file)
25 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
26 | generator = load_generator(config).to(device)
27 |
28 | # create the output directory (including missing parents) if needed
29 | os.makedirs(args.out_dir, exist_ok=True)
30 | for i in tqdm(range(args.number)):
31 | generate_face_sample(join(args.out_dir, str(i + 1) + ".ply"), config, generator)
32 |
33 |
34 | if __name__ == "__main__":
35 | main()
36 |
--------------------------------------------------------------------------------
/scripts/face_completion.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
5 |
6 | import argparse
7 | import torch
8 | from utils.completion import facial_mesh_completion
9 | from utils.funcs import load_generator
10 | from config.config import read_config
11 |
12 |
13 | def main():
14 | parser = argparse.ArgumentParser(description='facial shape completion')
15 |
16 | parser.add_argument("--config_file", type=str)
17 | parser.add_argument("--in_file", type=str)
18 | parser.add_argument("--out_file", type=str)
19 |
20 | parser.add_argument("--rr", type=bool, default=True)
21 | parser.add_argument("--lambda_reg", type=float, default=0.1)
22 | parser.add_argument("--verbose", type=bool, default=True)
23 | parser.add_argument("--dis_percent", type=float, default=None)
24 |
25 | args = parser.parse_args()
26 |
27 | config = read_config(args.config_file)
28 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
29 | generator = load_generator(config).to(device)
30 | facial_mesh_completion(args.in_file, args.out_file, config, generator, args.lambda_reg, args.verbose,
31 | args.rr, args.dis_percent)
32 |
33 |
34 | if __name__ == "__main__":
35 | main()
36 |
--------------------------------------------------------------------------------
/utils/render.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import trimesh
3 | from pytorch3d.structures import Meshes
4 | from pytorch3d.renderer import (
5 | look_at_view_transform,
6 | FoVPerspectiveCameras,
7 | PointLights,
8 | RasterizationSettings,
9 | MeshRenderer,
10 | MeshRasterizer,
11 | SoftPhongShader,
12 | TexturesVertex
13 | )
14 |
15 |
16 | def render_d(vertices: torch.Tensor, faces: torch.Tensor, image_size=(256, 256)):
17 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18 |
19 | R, T = look_at_view_transform(eye=torch.tensor([[0, 0, 250]], dtype=torch.float32),
20 | up=((0, 1, 0),),
21 | at=((0, 0, 0),),
22 | device=device)
23 |
24 | # create the camera
25 | cameras = FoVPerspectiveCameras(
26 | device=device, R=R, T=T,
27 | znear=0.1, zfar=1000,
28 | fov=50,
29 | )
30 |
31 | # rasterization settings
32 | raster_settings = RasterizationSettings(
33 | image_size=image_size,
34 | blur_radius=0.0,
35 | faces_per_pixel=1,
36 | )
37 |
38 | # light source
39 | lights = PointLights(device=device, location=[[0.0, 0.0, 100.0]])
40 |
41 | # mesh
42 | p3_mesh = Meshes([vertices],
43 | [faces],
44 | textures=TexturesVertex(verts_features=[torch.ones(vertices.shape, device=device)]))
45 |
46 | # build the renderer
47 | renderer = MeshRenderer(rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
48 | shader=SoftPhongShader(device=device, cameras=cameras, lights=lights))
49 |
50 | # render the image
51 | images = renderer(p3_mesh)
52 | images = images[0, ..., :3]
53 |
54 | return images
55 |
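56 | # Example usage (a minimal sketch; "face.ply" is a placeholder path, and the
57 | # tensors should live on the device render_d selects, i.e. CUDA if available):
58 | #   import trimesh
59 | #   device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
60 | #   mesh = trimesh.load_mesh("face.ply")
61 | #   verts = torch.tensor(mesh.vertices, dtype=torch.float32, device=device)
62 | #   faces = torch.tensor(mesh.faces, dtype=torch.int64, device=device)
63 | #   image = render_d(verts, faces)  # (256, 256, 3) RGB image tensor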
--------------------------------------------------------------------------------
/utils/pytorch3d_extend.py:
--------------------------------------------------------------------------------
1 | from pytorch3d.loss.point_mesh_distance import *
2 | from pytorch3d.loss import mesh_laplacian_smoothing, chamfer_distance
3 | from pytorch3d.structures import Meshes, Pointclouds
4 | import torch
5 | from torch import Tensor
6 | from trimesh import Trimesh
7 |
8 |
9 | def point_mesh_face_distance_single_direction(
10 | meshes: Meshes,
11 | pcls: Pointclouds,
12 | min_triangle_area: float = 1e-6,
13 | ):
14 | if len(meshes) != len(pcls):
15 | raise ValueError("meshes and pointclouds must be equal sized batches")
16 | N = len(meshes)
17 |
18 | # packed representation for pointclouds
19 | points = pcls.points_packed() # (P, 3)
20 | points_first_idx = pcls.cloud_to_packed_first_idx()
21 | max_points = pcls.num_points_per_cloud().max().item()
22 |
23 | # packed representation for faces
24 | verts_packed = meshes.verts_packed()
25 | faces_packed = meshes.faces_packed()
26 | tris = verts_packed[faces_packed] # (T, 3, 3)
27 | tris_first_idx = meshes.mesh_to_faces_packed_first_idx()
28 | max_tris = meshes.num_faces_per_mesh().max().item()
29 |
30 | # point to face distance: shape (P,)
31 | point_to_face = point_face_distance(
32 | points, points_first_idx, tris, tris_first_idx, max_points, min_triangle_area
33 | )
34 |
35 | return point_to_face
36 |
37 |
38 | def distance_from_reference_mesh(points: Tensor, mesh_vertices: Tensor, mesh_faces: Tensor):
39 | """
40 | return distance^2 from mesh for every point in points
41 | """
42 | meshes = Meshes([mesh_vertices], [mesh_faces])
43 | pcs = Pointclouds([points])
44 | distances = point_mesh_face_distance_single_direction(meshes, pcs)
45 | return distances
46 |
47 |
48 | def smoothness_loss(vertices: Tensor, faces: Tensor):
49 | meshes = Meshes([vertices], [faces])
50 | return mesh_laplacian_smoothing(meshes)
51 |
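52 | # Example usage (a minimal sketch with random data; real inputs would be a
53 | # scanned point cloud and a reference mesh):
54 | #   verts = torch.rand(100, 3)
55 | #   faces = torch.randint(0, 100, (200, 3))
56 | #   points = torch.rand(50, 3)
57 | #   sq_dists = distance_from_reference_mesh(points, verts, faces)  # (50,) squared distances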
--------------------------------------------------------------------------------
/config/config.py:
--------------------------------------------------------------------------------
1 | import configparser
2 |
3 |
4 | def read_config(file_name):
5 | config = configparser.RawConfigParser()
6 | config.read(file_name)
7 |
8 | config_parms = {}
9 |
10 | config_parms['dataset_dir'] = config.get('I/O parameters', 'dataset_dir')
11 | config_parms['template_file'] = config.get('I/O parameters', 'template_file')
12 | config_parms['checkpoint_dir'] = config.get('I/O parameters', 'checkpoint_dir')
13 |
14 | config_parms['n_layers'] = config.getint('model parameters', 'n_layers')
15 | config_parms['z_length'] = config.getint('model parameters', 'z_length')
16 | config_parms['down_sampling_factors'] = [int(x) for x in
17 | config.get('model parameters', 'down_sampling_factors').split(',')]
18 | config_parms['num_features_global'] = [int(x) for x in
19 | config.get('model parameters', 'num_features_global').split(',')]
20 | config_parms['num_features_local'] = [int(x) for x in
21 | config.get('model parameters', 'num_features_local').split(',')]
22 | config_parms['batch_norm'] = config.getint('model parameters', 'batch_norm') == 1
23 |
24 | config_parms['num_workers'] = config.getint('training parameters', 'num_workers')
25 | config_parms['lr'] = config.getfloat('training parameters', 'lr')
26 | config_parms['batch_size'] = config.getint('training parameters', 'batch_size')
27 | config_parms['weight_decay'] = config.getfloat('training parameters', 'weight_decay')
28 | config_parms['epoch'] = config.getint('training parameters', 'epoch')
29 | config_parms["lambda_reg"] = config.getfloat("training parameters", "lambda_reg")
30 |
31 | assert config_parms['n_layers'] == len(
32 | config_parms['down_sampling_factors']), 'length of down_sampling_factors must equal n_layers'
33 | assert config_parms['n_layers'] + 1 == len(
34 | config_parms['num_features_global']), 'length of num_features_global must equal n_layers + 1'
35 |
36 | return config_parms
37 |
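38 | # Example usage (a minimal sketch; the path below is a placeholder):
39 | #   config = read_config("config/config.cfg")
40 | #   print(config["n_layers"], config["down_sampling_factors"], config["z_length"])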
--------------------------------------------------------------------------------
/utils/my_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import random
3 | import numpy as np
4 | import torch
5 | from torch_geometric.data import Data, Dataset, InMemoryDataset
6 | from tqdm import tqdm
7 | from trimesh import Trimesh, load_mesh
8 | import os
9 | from utils.funcs import save_ply_explicit, get_edge_index
10 | from concurrent.futures import ProcessPoolExecutor
11 |
12 |
13 | class Normalize(object):
14 | def __init__(self, mean=None, std=None):
15 | self.mean = mean
16 | self.std = std
17 |
18 | def __call__(self, data):
19 | assert self.mean is not None and self.std is not None, ('Initialize mean and std to normalize with')
20 | self.mean = torch.as_tensor(self.mean, dtype=data.x.dtype, device=data.x.device)
21 | self.std = torch.as_tensor(self.std, dtype=data.x.dtype, device=data.x.device)
22 | data.x = (data.x - self.mean) / self.std
23 | data.y = (data.y - self.mean) / self.std
24 | return data
25 |
26 |
27 | def load_mesh_file(file, train_dir):
28 | return load_mesh(osp.join(train_dir, file))
29 |
30 |
31 | class MyDataset(InMemoryDataset):
32 |
33 | def __init__(self, config, dtype='train'):
34 | assert dtype in ['train', 'eval'], "Invalid dtype!"
35 |
36 | self.config = config
37 | self.root = config['dataset_dir']
38 |
39 | super(MyDataset, self).__init__(self.root)
40 |
41 | data_path = self.processed_paths[0]
42 | if dtype == 'eval':
43 | data_path = self.processed_paths[1]
44 | norm_path = self.processed_paths[2]
45 | edge_index_path = self.processed_paths[3]
46 |
47 | norm_dict = torch.load(norm_path)
48 | self.mean, self.std = norm_dict['mean'], norm_dict['std']
49 | self.data, self.slices = torch.load(data_path)
50 | self.edge_index = torch.load(edge_index_path)
51 |
52 | @property
53 | def processed_file_names(self):
54 | processed_files = ['training.pt', 'eval.pt', 'norm.pt', "edge_index.pt"]
55 | return processed_files
56 |
57 | def process(self):
58 | meshes = []
59 | train_data, eval_data = [], []
60 | train_vertices = []
61 |
62 | train_dir = osp.join(self.root, "train")
63 | files = os.listdir(train_dir)
64 |
65 | # for file in tqdm(files):
66 | # mesh = load_mesh(osp.join(train_dir, file))
67 | # meshes.append(mesh)
68 |
69 | with ProcessPoolExecutor(max_workers=8) as executor:  # tune max_workers for the best performance on your machine
70 | futures = [executor.submit(load_mesh_file, file, train_dir) for file in files]
71 | for future in tqdm(futures):
72 | meshes.append(future.result())
73 |
74 | edge_index = get_edge_index(meshes[0].vertices, meshes[0].faces)
75 |
76 | # shuffle
77 | random.shuffle(meshes)
78 | count = int(0.8 * len(meshes))
79 | for i in range(len(meshes)):
80 | mesh_verts = torch.Tensor(meshes[i].vertices)
81 | data = Data(x=mesh_verts, y=mesh_verts, edge_index=edge_index)
82 | if i < count:
83 | train_data.append(data)
84 | train_vertices.append(mesh_verts)
85 | else:
86 | eval_data.append(data)
87 |
88 | mean_train = torch.Tensor(np.mean(train_vertices, axis=0))
89 | std_train = torch.Tensor(np.std(train_vertices, axis=0))
90 | norm_dict = {'mean': mean_train, 'std': std_train}
91 |
92 | # save template
93 | mesh = Trimesh(vertices=mean_train, faces=meshes[0].faces)
94 | save_ply_explicit(mesh, self.config['template_file'])
95 |
96 | print("transforming...")
97 | transform = Normalize(mean_train, std_train)
98 | train_data = [transform(x) for x in train_data]
99 | eval_data = [transform(x) for x in eval_data]
100 |
101 | # save
102 | print("saving...")
103 | torch.save(self.collate(train_data), self.processed_paths[0])
104 | torch.save(self.collate(eval_data), self.processed_paths[1])
105 | torch.save(norm_dict, self.processed_paths[2])
106 | torch.save(edge_index, self.processed_paths[3])
107 | torch.save(norm_dict, osp.join(self.config['dataset_dir'], "norm.pt"))
108 |
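109 | # Example usage (a minimal sketch; assumes config["dataset_dir"] contains a
110 | # "train" folder of meshes that share the template topology):
111 | #   from config.config import read_config
112 | #   config = read_config("config/config.cfg")
113 | #   train_set = MyDataset(config, dtype="train")
114 | #   eval_set = MyDataset(config, dtype="eval")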
--------------------------------------------------------------------------------
/utils/funcs.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import numpy as np
3 | from scipy.spatial import cKDTree
4 | import scipy.sparse as sp
5 | import torch
6 | from .models import FMGenDecoder
7 | from quad_mesh_simplify import simplify_mesh
8 | from torch.nn import Parameter, ParameterList
9 | from trimesh import Trimesh, load_mesh
10 | import os
11 | from os.path import join
12 | import warnings
13 |
14 | warnings.filterwarnings("ignore")
15 | os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
16 |
17 |
18 | def save_ply_explicit(mesh, ply_file_path):
19 | vertices = mesh.vertices
20 | faces = mesh.faces
21 | with open(ply_file_path, 'w') as f:
22 | f.write('ply\n')
23 | f.write('format ascii 1.0\n')
24 | f.write('element vertex {}\n'.format(len(vertices)))
25 | f.write('property float x\n')
26 | f.write('property float y\n')
27 | f.write('property float z\n')
28 | f.write('element face {}\n'.format(len(faces)))
29 | f.write('property list uchar int vertex_indices\n')
30 | f.write('end_header\n')
31 |
32 | # Write vertices
33 | for vertex in vertices:
34 | f.write('{} {} {}\n'.format(vertex[0], vertex[1], vertex[2]))
35 |
36 | # Write faces
37 | for face in faces:
38 | f.write('{} '.format(len(face)))
39 | f.write(' '.join(str(idx) for idx in face))
40 | f.write('\n')
41 |
42 |
43 | def row(A):
44 | return A.reshape((1, -1))
45 |
46 |
47 | def col(A):
48 | return A.reshape((-1, 1))
49 |
50 |
51 | def get_edge_index(vertices, faces):
52 | """
53 |
54 | Modified from https://github.com/pixelite1201/pytorch_coma/blob/master/mesh_operations.py
55 |
56 | :param vertices:
57 | :param faces:
58 | :return:
59 | """
60 | vpv = sp.csc_matrix((len(vertices), len(vertices)))
61 |
62 | for i in range(3):
63 | IS = faces[:, i]
64 | JS = faces[:, (i + 1) % 3]
65 | data = np.ones(len(IS))
66 | ij = np.vstack((row(IS.flatten()), row(JS.flatten())))
67 | mtx = sp.csc_matrix((data, ij), shape=vpv.shape)
68 | vpv = vpv + mtx + mtx.T
69 |
70 | vpv = vpv.tocoo()
71 | return torch.tensor(np.vstack((vpv.row, vpv.col)), dtype=torch.long)  # PyG expects edge_index to be torch.long
72 |
73 |
74 | def get_transform_matrix(vertices1, vertices2, k=5):
75 | """
76 | Calculate the transformation matrix D such that vertices2 = D * vertices1, where for each vertex in vertices2,
77 | it is formed as a linear combination of coordinates of the k nearest vertices in vertices1.
78 |
79 | :param vertices1:
80 | :param vertices2:
81 | :param k:
82 | :return:
83 | """
84 | kdtree = cKDTree(vertices1)
85 | _, indices = kdtree.query(vertices2, k=k)
86 | D = np.zeros((vertices2.shape[0], vertices1.shape[0]))
87 |
88 | for i in range(vertices2.shape[0]):
89 | vs = vertices1[indices[i]]
90 | w = np.matmul(vertices2[i], np.linalg.pinv(vs))
91 | D[i, indices[i]] = w
92 |
93 | return D
94 |
95 |
96 | def generate_VFDU(mesh: Trimesh, factors: List[float]):
97 | V, F, D, U = [], [], [], []
98 |
99 | vertices = np.array(mesh.vertices)
100 | faces = np.array(mesh.faces, dtype=np.uint32)
101 | V.append(vertices)
102 | F.append(faces)
103 |
104 | for factor in factors:
105 | # simplify the mesh with QEM (quadric error metrics)
106 | new_vertices, new_faces = simplify_mesh(vertices, faces, vertices.shape[0] / factor)
107 | V.append(new_vertices)
108 | F.append(new_faces)
109 |
110 | d = get_transform_matrix(vertices, new_vertices, 9)
111 | u = get_transform_matrix(new_vertices, vertices, 9)
112 | D.append(d)
113 | U.append(u)
114 |
115 | vertices = new_vertices
116 | faces = new_faces
117 |
118 | return V, F, D, U
119 |
120 |
121 | def generate_transform_matrices_trimesh(mesh: Trimesh, factors: List[float]):
122 | V, F, D, U = generate_VFDU(mesh, factors)
123 |
124 | A = []
125 | for i in range(len(F)):
126 | edge_index = get_edge_index(V[i], F[i])
127 | A.append(edge_index)
128 |
129 | V = [torch.tensor(v) for v in V]
130 | D = [torch.tensor(d, dtype=torch.float32) for d in D]
131 | U = [torch.tensor(u, dtype=torch.float32) for u in U]
132 |
133 | return V, A, D, U
134 |
135 |
136 | def get_mesh_matrices(config):
137 | template_mesh = load_mesh(config["template_file"])
138 | _, A, D, U = generate_transform_matrices_trimesh(
139 | template_mesh, config["down_sampling_factors"]
140 | )
141 | pA = ParameterList([Parameter(a, requires_grad=False) for a in A])
142 | pD = ParameterList([Parameter(a, requires_grad=False) for a in D])
143 | pU = ParameterList([Parameter(a, requires_grad=False) for a in U])
144 |
145 | return pA, pD, pU
146 |
147 |
148 | def load_generator(config):
149 | A, D, U = get_mesh_matrices(config)
150 | model = FMGenDecoder(config, A, U)
151 | model.load_state_dict(
153 | torch.load(os.path.join(config["checkpoint_dir"], "checkpoint_decoder.pt"), map_location="cpu")  # load to CPU; callers move the model to their device
153 | )
154 | for param in model.parameters():
155 | param.requires_grad = False
156 | model.eval()
157 | return model
158 |
159 |
160 | def load_norm(config):
161 | norm_dict = torch.load(join(config['dataset_dir'], "norm.pt"))
162 | return norm_dict['mean'], norm_dict['std']
163 |
164 |
165 | def get_random_z(length, requires_grad=True, jitter=False):
166 | z = torch.randn(1, length)
167 | z = z / torch.sqrt(torch.sum(z ** 2))
168 | if jitter:
169 | p = torch.randn_like(z) / 10 * z
170 | z += p
171 | z.requires_grad = requires_grad
172 | return z
173 |
174 |
175 | def spherical_regularization_loss(z, target_radius=1.0):
176 | z_norm = torch.norm(z, p=2, dim=1)
177 | deviation = z_norm - target_radius
178 | loss = torch.mean(deviation ** 2)
179 | return loss
180 |
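181 | # Example usage (a minimal sketch; assumes checkpoint_dir in the config holds a
182 | # trained checkpoint_decoder.pt and dataset_dir holds norm.pt):
183 | #   from config.config import read_config
184 | #   config = read_config("config/config.cfg")
185 | #   generator = load_generator(config)
186 | #   z = get_random_z(config["z_length"], requires_grad=False)
187 | #   verts = generator(z, 1)  # normalized vertices; de-normalize with load_norm(config)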
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
5 |
6 | import torch
7 | import torch.nn as nn
8 | import numpy as np
9 | from torch_geometric.loader import DataLoader
10 | from tqdm import tqdm
11 | import torch.nn.functional as F
12 | import random
13 | from config.config import read_config
14 | from utils.my_dataset import MyDataset
15 | from utils.models import FMGenModel
16 | import argparse
17 | from torch.nn import Conv1d, Parameter, ParameterList
18 | from trimesh import Trimesh, load_mesh
19 | from utils.funcs import get_mesh_matrices, spherical_regularization_loss
20 | import warnings
21 |
22 | warnings.filterwarnings("ignore")
23 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
24 |
25 |
26 | def load_model(config, load_state_dict):
27 | pA, pD, pU = get_mesh_matrices(config)
28 | model = FMGenModel(config, pA, pD, pU)
29 | if load_state_dict:
30 | model.encoder.load_state_dict(torch.load(os.path.join(config['checkpoint_dir'], 'checkpoint_encoder.pt')))
31 | model.decoder.load_state_dict(torch.load(os.path.join(config['checkpoint_dir'], 'checkpoint_decoder.pt')))
32 |
33 | return model
34 |
35 |
36 | def train_epoch(model, train_loader, optimizer, device, size, epoch, lambda_reg=1.0):
37 | model.train()
38 | total_loss_l1 = 0
39 | total_loss_mse = 0
40 | total_loss_reg = 0
41 |
42 | for batch in tqdm(train_loader):
43 | batch = batch.to(device)
44 | optimizer.zero_grad()
45 |
46 | out, z = model(batch)
47 |
48 | loss_mse = F.mse_loss(out, batch.y)
49 | loss_l1 = F.l1_loss(out, batch.y)
50 | loss_reg = spherical_regularization_loss(z)
51 |
52 | total_loss_mse += batch.num_graphs * loss_mse.item()
53 | total_loss_l1 += batch.num_graphs * loss_l1.item()
54 | total_loss_reg += batch.num_graphs * loss_reg.item()
55 |
56 | loss = loss_l1 + loss_mse + lambda_reg * loss_reg
57 |
58 | loss.backward()
59 | optimizer.step()
60 |
61 | return total_loss_l1 / size, total_loss_mse / size, total_loss_reg / size
62 |
63 |
64 | def test_epoch(model, test_loader, device, size):
65 | model.eval()
66 | total_loss_l1 = 0
67 | total_loss_mse = 0
68 | total_loss_reg = 0
69 |
70 | for batch in tqdm(test_loader):
71 | batch = batch.to(device)
72 |
73 | out, z = model(batch)
74 |
75 | loss_mse = F.mse_loss(out, batch.y)
76 | loss_l1 = F.l1_loss(out, batch.y)
77 | loss_reg = spherical_regularization_loss(z)
78 |
79 | total_loss_mse += batch.num_graphs * loss_mse.item()
80 | total_loss_l1 += batch.num_graphs * loss_l1.item()
81 | total_loss_reg += batch.num_graphs * loss_reg.item()
82 |
83 | return total_loss_l1 / size, total_loss_mse / size, total_loss_reg / size
84 |
85 |
86 | def train(config):
87 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
88 |
89 | if not os.path.exists(config['checkpoint_dir']):
90 | os.makedirs(config['checkpoint_dir'])
91 |
92 | # #### dataset #####
93 | print("loading datasets...")
94 | dataset_train = MyDataset(config, 'train')
95 | dataset_test = MyDataset(config, 'eval')
96 | train_loader = DataLoader(dataset_train, batch_size=config['batch_size'], shuffle=True,
97 | num_workers=config['num_workers'], pin_memory=True, persistent_workers=True)
98 | test_loader = DataLoader(dataset_test, batch_size=config['batch_size'], shuffle=False,
99 | num_workers=config['num_workers'], pin_memory=True, persistent_workers=True)
100 |
101 | # #### model #####
102 | print("loading model...")
103 | model = load_model(config, False)
104 | model.to(device)
105 |
106 | # #### optimization #####
107 | optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])
108 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
109 | # scheduler.load_state_dict(torch.load(os.path.join(config['checkpoint_dir'], 'scheduler.pt')))
110 |
111 | # #### train for epochs #####
112 | print("start training...")
113 | best_loss_item = float('inf')
114 |
115 | lambda_reg = config['lambda_reg']
116 |
117 | for epoch in range(scheduler.last_epoch + 1, config['epoch'] + 1):
118 | print("Epoch", epoch, " lr:", scheduler.get_lr())
119 |
120 | loss_l1, loss_mse, loss_reg = train_epoch(model, train_loader, optimizer, device, len(dataset_train), epoch,
121 | lambda_reg)
122 | print("Train loss: L1:", loss_l1, "MSE:", loss_mse, "REG:", loss_reg)
123 |
124 | loss_l1_test, loss_mse_test, loss_reg_test = test_epoch(model, test_loader, device, len(dataset_test))
125 | print("Test loss: L1:", loss_l1_test, "MSE:", loss_mse_test, "REG:", loss_reg_test)
126 |
127 | scheduler.step()
128 |
129 | if loss_l1_test + loss_mse_test + lambda_reg * loss_reg_test < best_loss_item:
130 | best_loss_item = loss_l1_test + loss_mse_test + lambda_reg * loss_reg_test
131 | torch.save(model.encoder.state_dict(), os.path.join(config['checkpoint_dir'], 'checkpoint_encoder.pt'))
132 | torch.save(model.decoder.state_dict(), os.path.join(config['checkpoint_dir'], 'checkpoint_decoder.pt'))
133 | torch.save(scheduler.state_dict(), os.path.join(config['checkpoint_dir'], 'scheduler.pt'))
134 | print("\nsave!\n\n")
135 |
136 |
137 | def main():
138 | parser = argparse.ArgumentParser(description='train')
139 | parser.add_argument("--config_file", type=str)
140 | args = parser.parse_args()
141 |
142 | # args.config_file = "../config/test_config.cfg"
143 |
144 | config = read_config(args.config_file)
145 | train(config)
146 |
147 |
148 | if __name__ == "__main__":
149 | main()
150 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## FaceCom: Towards High-fidelity 3D Facial Shape Completion via Optimization and Inpainting Guidance
2 |
3 | CVPR 2024
4 |
5 |
6 |
7 | ## Methods and Results
8 |
9 | #### Method Overview
10 |
11 | ![image-20240716144300162](assets/image-20240716144300162.png)
12 |
13 | #### Fitting results with only Shape Generator
14 |
15 | 
16 |
17 | #### Shape completion results
18 |
19 | 
20 |
21 |
22 |
23 | ## Set-up
24 |
25 | 1. Download code :
26 |
27 | ```
28 | git clone https://github.com/dragonylee/FaceCom.git
29 | ```
30 |
31 | 2. Create and activate a conda environment :
32 |
33 | ```
34 | conda create -n FaceCom python=3.10
35 | conda activate FaceCom
36 | ```
37 |
38 | 3. Install dependencies using `pip` or `conda` :
39 |
40 | - [pytorch](https://pytorch.org/get-started/locally/)
41 |
42 | ```
43 | pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
44 | ```
45 |
46 | - [pytorch3d](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md)
47 |
48 | ```
49 | conda install -c fvcore -c iopath -c conda-forge fvcore iopath
50 | pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
51 | ```
52 |
53 | - [torch_geometric](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) & trimesh & [quad_mesh_simplify](https://github.com/jannessm/quadric-mesh-simplification)
54 |
55 | ```
56 | pip install torch_geometric trimesh quad_mesh_simplify
57 | ```
58 |
59 | After installation, you may find that only `quad_mesh_simplify-1.1.5.dist-info` is present under the `site-packages` folder of your Python environment. In that case, you also need to copy the `quad_mesh_simplify` folder from the [GitHub repository](https://github.com/jannessm/quadric-mesh-simplification) into `site-packages`.
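
A minimal sketch of that manual copy, assuming you clone the repository next to FaceCom (`PATH_TO_SITE_PACKAGES` is a placeholder for your environment's `site-packages` directory):

```
git clone https://github.com/jannessm/quadric-mesh-simplification.git
cp -r quadric-mesh-simplification/quad_mesh_simplify PATH_TO_SITE_PACKAGES/
```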
60 |
61 |
62 |
63 | ## Data
64 |
65 | We trained our network on a structured hybrid 3D face dataset, which includes the [Facescape](https://facescape.nju.edu.cn/) and [HeadSpace](https://www-users.york.ac.uk/~np7/research/Headspace/) datasets (used with permission), as well as our own dataset collected from hospitals. The data we collected ourselves cannot currently be made public.
66 |
67 | You can download our pre-trained model `checkpoint_decoder.pt` ([Google Drive](https://drive.google.com/file/d/1oPfWRPgCXjAffPJWfZyZyZOgd5EYPrHf/view?usp=drive_link)|[百度网盘](https://pan.baidu.com/s/1SsBW08yieLTCbK9ec6EnwA?pwd=z4vc)) and put it in the `data` folder.
68 |
69 |
70 |
71 | ## Config
72 |
73 | After downloading the pre-trained model, you need to modify the paths in the first three lines of `config/config.cfg`
74 |
75 | ```
76 | dataset_dir = PATH_TO_THE_PROJECT\data
77 | template_file = PATH_TO_THE_PROJECT\data\template.ply
78 | checkpoint_dir = PATH_TO_THE_PROJECT\data
79 | ```
80 |
81 | to match your own environment. If you create a new config file using the provided `config.cfg`, these three parameters should respectively satisfy the following conditions:
82 |
83 | 1. `dataset_dir` should contain the `norm.pt` file (if you intend to train, it should include a `train` folder instead, with all training data placed inside the `train` folder).
84 | 2. `template_file` should be the path to the template file.
85 | 3. `checkpoint_dir` should be the folder containing the model parameter files.
86 |
87 | The provided `config.cfg` file and the corresponding `data` folder can be used normally after downloading the pre-trained model described in [Data](#data).
88 |
89 |
90 |
91 | ## Usages
92 |
93 | After setting up the config file, you can test the pre-trained model with the scripts described below.
94 |
95 | ### Random Sample
96 |
97 | Randomly generate `--number` 3D face models.
98 |
99 | ```
100 | python scripts/face_sample.py --config_file config/config.cfg --out_dir sample_out --number 10
101 | ```
102 |
103 | ### Facial Shape Completion
104 |
105 | **NOTE** that our method has some limitations and caveats to be aware of.
106 | 
107 | 1. Face models are assumed to be in millimeters.
108 | 2. The facial region of the input should preferably not extend beyond the `template.ply` we provide; otherwise, add `--dis_percent 0.8` for better results.
109 | 3. We use trimesh's ICP for rigid registration, but have not verified its accuracy and robustness. You may instead perform precise rigid registration against `template.ply` yourself and set `--rr False`.
110 |
111 | Then, you can run our script to perform shape completion on `--in_file`,
112 |
113 | ```
114 | python scripts/face_completion.py --config_file config/config.cfg --in_file defect.ply --out_file comp.ply --rr True
115 | ```
116 |
117 | where `--in_file` is any mesh file that trimesh can read, with no constraints on topology. We provide `defect.ply` for convenience.
118 |
119 | ### Mesh Fit / Non-rigid Registration
120 |
121 | When the input is a complete facial model without any defects, the script from the "Facial Shape Completion" section effectively outputs a fit to the input. Since our method's output always has the same topology, it can also be used for non-rigid registration.
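
For example (a sketch; `full_face.ply` is a placeholder for a complete scan):

```
python scripts/face_completion.py --config_file config/config.cfg --in_file full_face.ply --out_file registered.ply
```

The output `registered.ply` shares the topology of `template.ply` while fitting the input, which is what non-rigid registration requires.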
122 |
123 |
124 |
125 | ## Train
126 |
127 | After preparing the dataset with unified topology, you can train a shape generator using the code we provided. First, determine a dataset folder path `A`, then create or modify the config file, changing the first three lines to
128 |
129 | ```
130 | dataset_dir = A
131 | template_file = A\template.ply
132 | checkpoint_dir = A
133 | ```
134 |
135 | First set aside the test data that will not be used for training, then place the remaining training data in the `train` subfolder of folder `A`. That is, before training, the directory structure of `A` should be as follows:
136 |
137 | ```
138 | A/
139 | ├── train/
140 | │ ├── training_data_1.ply
141 | │ ├── training_data_2.ply
142 | │ └── ...
143 | ```
144 |
145 | Then you can start training using the script below:
146 |
147 | ```
148 | python scripts/train.py --config_file config/config.cfg
149 | ```
150 |
151 | During training, the average of the training data is generated as the template, and the structure of folder `A` will look like this:
152 |
153 | ```
154 | A/
155 | ├── train/
156 | │ ├── training_data_1.ply
157 | │ ├── training_data_2.ply
158 | │ └── ...
159 | ├── template.ply
160 | ├── norm.pt
161 | └── checkpoint_decoder.pt
162 | ```
163 |
164 | These results are sufficient for the usage described in [Usages](#usages).
165 |
166 |
167 |
168 | ## Citations
169 |
170 | If you find our work helpful, please cite us:
171 |
172 | ```
173 | @inproceedings{li2024facecom,
174 | title={FaceCom: Towards High-fidelity 3D Facial Shape Completion via Optimization and Inpainting Guidance},
175 | author={Li, Yinglong and Wu, Hongyu and Wang, Xiaogang and Qin, Qingzhao and Zhao, Yijiao and Wang, Yong and Hao, Aimin},
176 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
177 | pages={2177--2186},
178 | year={2024}
179 | }
180 | ```
181 |
182 | or
183 |
184 | ```
185 | @article{li2024facecom,
186 | title={FaceCom: Towards High-fidelity 3D Facial Shape Completion via Optimization and Inpainting Guidance},
187 | author={Li, Yinglong and Wu, Hongyu and Wang, Xiaogang and Qin, Qingzhao and Zhao, Yijiao and Hao, Aimin and others},
188 | journal={arXiv preprint arXiv:2406.02074},
189 | year={2024}
190 | }
191 | ```
192 |
193 |
--------------------------------------------------------------------------------
/utils/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch_geometric.nn.conv import FeaStConv
4 | from torch_geometric.nn import BatchNorm
5 | from torch_geometric.data.batch import Batch
6 |
7 |
8 | class FMGenEncoder(torch.nn.Module):
9 | def __init__(self, config, A, D):
10 | super(FMGenEncoder, self).__init__()
11 | self.A = [a.clone().detach() for a in A]
12 | self.D = [a.clone().detach() for a in D]
13 |
14 | self.batch_norm = config['batch_norm']
15 | self.n_layers = config['n_layers']
16 | self.z_length = config['z_length']
17 | self.num_features_global = config['num_features_global']
18 | self.num_features_local = config['num_features_local']
19 |
20 | # conv layers
21 | self.encoder_convs_global = torch.nn.ModuleList([
22 | FeaStConv(in_channels=self.num_features_global[k], out_channels=self.num_features_global[k + 1])
23 | for k in range(self.n_layers)
24 | ])
25 | self.encoder_convs_local = torch.nn.ModuleList([
26 | FeaStConv(in_channels=self.num_features_local[k], out_channels=self.num_features_local[k + 1])
27 | for k in range(self.n_layers)
28 | ])
29 |
30 | # bn
31 | self.encoder_bns_local = torch.nn.ModuleList([
32 | BatchNorm(in_channels=self.num_features_local[k]) for k in range(self.n_layers + 1)
33 | ])
34 | self.encoder_bns_global = torch.nn.ModuleList([
35 | BatchNorm(in_channels=self.num_features_global[k]) for k in range(self.n_layers + 1)
36 | ])
37 |
38 | # linear layers
39 | self.encoder_lin = torch.nn.Linear(self.z_length + self.num_features_global[-1], self.z_length)
40 | self.encoder_local_lin = torch.nn.Linear(self.D[0].shape[1] * self.num_features_local[-1], self.z_length)
41 |
42 | self.reset_parameter()
43 |
44 | def reset_parameter(self):
45 | torch.nn.init.normal_(self.encoder_lin.weight, 0, 0.1)
46 | torch.nn.init.normal_(self.encoder_local_lin.weight, 0, 0.1)
47 |
48 | def forward(self, x, batch_size):
49 | self.A = [a.to(x.device) for a in self.A]
50 | self.D = [d.to(x.device) for d in self.D]
51 |
52 | # split into global and local branches
53 | x_global = x
54 | x_local = x
55 |
56 | """
57 | global
58 | """
59 | # x_global: [batch_size * D[0].shape[1], num_features_global[0]]
60 | for i in range(self.n_layers):
61 | # graph convolution
62 | x_global = self.encoder_convs_global[i](x=x_global, edge_index=self.A[i])
63 | # batch normalization
64 | if self.batch_norm:
65 | x_global = self.encoder_bns_global[i + 1](x_global)
66 | # activation
67 | x_global = F.leaky_relu(x_global)
68 | # downsample with the precomputed matrix D[i]
69 | x_global = x_global.reshape(batch_size, -1, self.num_features_global[i + 1])
70 | y = torch.zeros(batch_size, self.D[i].shape[0], x_global.shape[2], device=x_global.device)
71 | for j in range(batch_size):
72 | y[j] = torch.mm(self.D[i], x_global[j])
73 | x_global = y
74 | x_global = x_global.reshape(-1, self.num_features_global[i + 1])
75 | # x_global: [batch_size * D[-1].shape[0], num_features_global[-1]]
76 |
77 | x_global = x_global.reshape(batch_size, -1, self.num_features_global[-1])
78 | # x_global: [batch_size, D[-1].shape[0], num_features_global[-1]]
79 |
80 | # (mean pool & relu)
81 | x_global = torch.mean(x_global, dim=1)
82 | x_global = F.leaky_relu(x_global)
83 | # x_global: [batch_size, num_features_global[-1]]
84 |
85 | """
86 | local
87 | """
88 | # begin x_local: [batch_size * D[0].shape[1], num_features_local[0]]
89 | for i in range(self.n_layers):
90 | # graph convolution
91 | x_local = self.encoder_convs_local[i](x=x_local, edge_index=self.A[0])
92 | # batch normalization
93 | if self.batch_norm:
94 | x_local = self.encoder_bns_local[i + 1](x_local)
95 | # activation
96 | x_local = F.leaky_relu(x_local)
97 | # x_local: [batch_size * D[0].shape[1], num_features_local[-1]]
98 |
99 | x_local = x_local.reshape(batch_size, -1)
100 | # x_local: [batch_size, D[0].shape[1] * num_features_local[-1]]
101 |
102 | # (linear & relu)
103 | x_local = self.encoder_local_lin(x_local)
104 | x_local = F.leaky_relu(x_local)
105 | # x_local: [batch_size, z_length]
106 |
107 | """
108 | get z
109 | """
110 | z = torch.concat((x_global, x_local), dim=1)
111 | z = self.encoder_lin(z)
112 |
113 | return z
114 |
115 |
116 | class FMGenDecoder(torch.nn.Module):
117 | def __init__(self, config, A, U):
118 | super(FMGenDecoder, self).__init__()
119 | self.A = [a.clone().detach() for a in A]
120 | self.U = [u.clone().detach() for u in U]
121 |
122 | self.batch_norm = config['batch_norm']
123 | self.n_layers = config['n_layers']
124 | self.z_length = config['z_length']
125 | self.num_features_global = config['num_features_global']
126 | self.num_features_local = config['num_features_local']
127 |
128 | # conv layers
129 | self.decoder_convs_global = torch.nn.ModuleList([
130 | FeaStConv(in_channels=self.num_features_global[-1 - k], out_channels=self.num_features_global[-2 - k])
131 | for k in range(self.n_layers)
132 | ])
133 | self.decoder_convs_local = torch.nn.ModuleList([
134 | FeaStConv(in_channels=self.num_features_local[-1 - k], out_channels=self.num_features_local[-2 - k])
135 | for k in range(self.n_layers)
136 | ])
137 |
138 | # bn
139 | self.decoder_bns_local = torch.nn.ModuleList([
140 | BatchNorm(in_channels=self.num_features_local[-1 - k]) for k in range(self.n_layers + 1)
141 | ])
142 | self.decoder_bns_global = torch.nn.ModuleList([
143 | BatchNorm(in_channels=self.num_features_global[-1 - k]) for k in range(self.n_layers + 1)
144 | ])
145 |
146 | # linear layers
147 | self.decoder_lin = torch.nn.Linear(self.z_length, self.z_length + self.num_features_global[-1])
148 | self.decoder_local_lin = torch.nn.Linear(self.z_length, self.num_features_local[-1] * self.U[0].shape[0])
149 |
150 | self.reset_parameter()
151 |
152 | # merge ratio
153 | self.global_ratio = 0.01
154 | self.local_ratio = 1 - self.global_ratio
155 |
156 | def reset_parameter(self):
157 | torch.nn.init.normal_(self.decoder_lin.weight, 0, 0.1)
158 | torch.nn.init.normal_(self.decoder_local_lin.weight, 0, 0.1)
159 |
160 | def forward(self, z, batch_size):
161 | self.A = [a.to(z.device) for a in self.A]
162 | self.U = [u.to(z.device) for u in self.U]
163 |
164 | # decoder linear layer, then split into global and local parts
165 | x = self.decoder_lin(z)
166 | x_global = x[:, :self.num_features_global[-1]]
167 | x_local = x[:, self.num_features_global[-1]:]
168 |
169 | """
170 | global
171 | """
172 | # x_global: [batch_size, num_features_global[-1]]
173 | x_global = torch.unsqueeze(x_global, dim=1).repeat(1, self.U[-1].shape[1], 1)
174 | # x_global: [batch_size, U[-1].shape[1], num_features_global[-1]]
175 |
176 | for i in range(self.n_layers):
177 | # upsample with the precomputed matrix U[-1 - i]
178 | x_global = x_global.reshape(batch_size, -1, self.num_features_global[-1 - i])
179 | y = torch.zeros(batch_size, self.U[-1 - i].shape[0], x_global.shape[2], device=x_global.device)
180 | for j in range(batch_size):
181 | y[j] = torch.mm(self.U[-1 - i], x_global[j])
182 | x_global = y
183 | x_global = x_global.reshape(-1, self.num_features_global[-1 - i])
184 | # graph convolution
185 | x_global = self.decoder_convs_global[i](x=x_global, edge_index=self.A[-2 - i])
186 | if i < self.n_layers - 1:
187 | # batch normalization
188 | if self.batch_norm:
189 | x_global = self.decoder_bns_global[i + 1](x_global)
190 | # activation
191 | x_global = F.leaky_relu(x_global)
192 | # x_global: [batch_size, U[0].shape[0], num_features_global[0]]
193 | x_global = x_global.reshape(-1, self.num_features_global[0])
194 | # x_global: [batch_size * U[0].shape[0], num_features_global[0]]
195 |
196 | """
197 | local
198 | """
199 | # x_local: [batch_size, z_length]
200 | x_local = self.decoder_local_lin(x_local)
201 | # x_local: [batch_size, num_features_local[-1] * U[0].shape[0]]
202 | x_local = x_local.reshape(-1, self.num_features_local[-1])
203 | # x_local: [batch_size * U[0].shape[0], num_features_local[-1]]
204 |
205 | for i in range(self.n_layers):
206 | # graph convolution
207 | x_local = self.decoder_convs_local[i](x=x_local, edge_index=self.A[0])
208 | if i < self.n_layers - 1:
209 | # batch normalization
210 | if self.batch_norm:
211 | x_local = self.decoder_bns_local[i + 1](x_local)
212 | # activation
213 | x_local = F.leaky_relu(x_local)
214 | # x_local: [batch_size * U[0].shape[0], num_features_local[0]]
215 |
216 | """
217 | merge
218 | """
219 | x = self.global_ratio * x_global + self.local_ratio * x_local
220 |
221 | return x
222 |
223 |
224 | class FMGenModel(torch.nn.Module):
225 | def __init__(self, config, A, D, U):
226 | super(FMGenModel, self).__init__()
227 |
228 | self.encoder = FMGenEncoder(config, A, D)
229 | self.decoder = FMGenDecoder(config, A, U)
230 |
231 | def forward(self, batch: Batch):
232 | batch_size = batch.num_graphs
233 |
234 | z = self.encoder(batch.x, batch_size)
235 | x = self.decoder(z, batch_size)
236 |
237 | return x, z
238 |
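239 | # Example usage (a minimal sketch; `config` comes from config.config.read_config
240 | # and `batch` is a torch_geometric Batch built as in utils.my_dataset):
241 | #   from utils.funcs import get_mesh_matrices
242 | #   pA, pD, pU = get_mesh_matrices(config)
243 | #   model = FMGenModel(config, pA, pD, pU)
244 | #   out, z = model(batch)  # reconstructed vertices and latent codes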
--------------------------------------------------------------------------------
/utils/completion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.utils
3 |
4 | from config.config import read_config
5 | from trimesh import Trimesh, load_mesh
6 | from .funcs import load_generator, spherical_regularization_loss, save_ply_explicit, get_random_z, load_norm
7 | import torch.nn.functional as F
8 | import os
9 | from os.path import join
10 | import warnings
11 | from tqdm import tqdm
12 | import numpy as np
13 | from queue import Queue
14 |
15 | import math
16 | from pytorch3d.structures import Meshes
17 | from pytorch3d.loss import chamfer_distance
18 | from .pytorch3d_extend import distance_from_reference_mesh, smoothness_loss
19 | from trimesh.registration import icp
20 | from scipy.spatial import cKDTree
21 | import sys
22 |
23 | from PIL import Image
24 | import torchvision.transforms as transforms
25 | from .render import render_d
26 |
27 | warnings.filterwarnings("ignore")
28 | os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
29 | sys.setrecursionlimit(30000)
30 |
31 |
32 | def generate_face_sample(out_file, config, generator):
33 | generator.eval()
34 | device = next(generator.parameters()).device
35 | z = get_random_z(generator.z_length, requires_grad=False)
36 | mean, std = load_norm(config)
37 | out = generator(z.to(device), 1).detach().cpu()
38 | out = out * std + mean
39 | template_mesh = load_mesh(config["template_file"])
40 | mesh = Trimesh(out, template_mesh.faces)
41 | save_ply_explicit(mesh, out_file)
42 |
43 |
44 | def rigid_registration(in_mesh, config, verbose=True):
45 | if verbose:
46 | print("rigid registration...")
47 |
48 | # mesh = in_mesh.copy()
49 | mesh = in_mesh
50 | template_mesh = load_mesh(config["template_file"])
51 |
52 | centroid = mesh.centroid
53 | mesh.vertices -= mesh.centroid
54 | T, _, _ = icp(mesh.vertices, template_mesh.vertices, max_iterations=50)
55 | mesh.apply_transform(T)
56 |
57 | return mesh, T, centroid
58 |
59 |
60 | def fit(in_mesh, generator, config, device, max_iters=1000, loss_convergence=1e-6, lambda_reg=None,
61 | verbose=True, dis_percent=None):
62 | if verbose:
63 | sys.stdout.write("\rFitting...")
64 | sys.stdout.flush()
65 |
66 | mesh = in_mesh.copy()
67 | template_mesh = load_mesh(config["template_file"])
68 |
69 | generator.eval()
70 |
71 | target_pc = torch.tensor(mesh.vertices, dtype=torch.float).to(device)
72 |
73 | z = get_random_z(generator.z_length, requires_grad=True, jitter=True)
74 |
75 | mean, std = load_norm(config)
76 | mean = mean.to(device)
77 | std = std.to(device)
78 | faces = torch.tensor(template_mesh.faces).to(device)
79 |
80 | optimizer = torch.optim.Adam([z], lr=0.1)
81 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)
82 |
83 | if not lambda_reg:
84 | lambda_reg = config['lambda_reg']
85 | last_loss = math.inf
86 | iters = 0
87 | for i in range(max_iters):
88 | optimizer.zero_grad()
89 |
90 | out = generator(z.to(device), 1)
91 | out = out * std + mean
92 |
93 | loss_reg = spherical_regularization_loss(z)
94 | loss = loss_reg
95 |
96 | distance = torch.sqrt(distance_from_reference_mesh(target_pc, out, faces))
97 | if dis_percent:
98 | # keep only the dis_percent fraction of vertices with the smallest distances
99 | distance, idx = torch.sort(distance)
100 | distance = distance[:int(dis_percent * len(distance))]
101 | loss_dfrm = torch.mean(distance)
102 | loss = loss_dfrm + lambda_reg * loss_reg
103 |
104 | if verbose:
105 | sys.stdout.write(
106 | "\rFitting...\tIter {}, loss_recon: {:.6f}, loss_reg: {:.6f}".format(i + 1,
107 | loss_dfrm.item(),
108 | loss_reg.item()))
109 | sys.stdout.flush()
110 | if math.fabs(last_loss - loss.item()) < loss_convergence:
111 | iters = i
112 | break
113 |
114 | last_loss = loss.item()
115 | loss.backward()
116 | optimizer.step()
117 | scheduler.step()
118 |
119 | out = generator(z.to(device), 1)
120 | out = out * std + mean
121 | fit_mesh = Trimesh(out.detach().cpu(), template_mesh.faces)
122 | if verbose:
123 | print("")
124 |
125 | return fit_mesh
126 |
127 |
128 | def post_processing(in_mesh_fit, in_mesh_faulty, device, laplacian=True, verbose=True):
129 | if verbose:
130 | print("post processing...")
131 |
132 | def get_color_mesh(mesh, idx, init_color=True, color=None):
133 | if color is None:
134 | color = [255, 0, 0, 255]
135 | color_mesh = mesh.copy()
136 |
137 | if init_color:
138 | color_array = np.zeros((mesh.vertices.shape[0], 4), dtype=np.uint8)  # RGBA colors
139 | color_array[idx] = color
140 | color_mesh.visual.vertex_colors = color_array
141 | else:
142 | color_mesh.visual.vertex_colors[idx] = color
143 | return color_mesh
144 |
145 | def extract_connected_components(mesh: Trimesh, idx):
146 | visited = set()
147 | components = []
148 |
149 | def dfs(vertex, component):
150 | if vertex in visited:
151 | return
152 | visited.add(vertex)
153 | component.add(vertex)
154 | for neighbor in mesh.vertex_neighbors[vertex]:
155 | if neighbor in idx:
156 | dfs(neighbor, component)
157 |
158 | for vertex in idx:
159 | if vertex not in visited:
160 | component = set()
161 | dfs(vertex, component)
162 | components.append(component)
163 |
164 | return components
165 |
166 | def expand_connected_component(mesh, component_, distance):
167 | expanded_component = set()
168 | component = component_.copy()
169 |
170 | for _ in range(distance):
171 | new_neighbors = set()
172 | for vertex in component:
173 | neighbors = mesh.vertex_neighbors[vertex]
174 | for neighbor in neighbors:
175 | if neighbor not in component and neighbor not in expanded_component:
176 | new_neighbors.add(neighbor)
177 | expanded_component.update(new_neighbors)
178 | component.update(new_neighbors)
179 |
180 | return expanded_component
181 |
182 | def special_point_refinement(mesh: Trimesh):
183 | vertices = mesh.vertices
184 | for i in tqdm(range(mesh.vertices.shape[0])):
185 | neighbor = mesh.vertex_neighbors[i]
186 | mean_x = np.mean(vertices[neighbor], axis=0)
187 | x = vertices[i]
188 | mean_distance = np.mean(np.linalg.norm(mesh.vertices[neighbor] - mean_x, axis=1))
189 | if np.linalg.norm(mean_x - x) > 0.5 * mean_distance:
190 | vertices[i] = mean_x
191 | return Trimesh(vertices, mesh.faces)
192 |
193 | def projection(source_mesh: Trimesh, largest_component_mask, target_mesh: Trimesh, max_iters=1000):
194 | x = torch.tensor(source_mesh.vertices, dtype=torch.float).to(device)
195 | normal_vectors = torch.tensor(source_mesh.vertex_normals, dtype=torch.float).to(device)
196 | ndf = torch.randn(source_mesh.vertices.shape[0]).detach().to(device)
197 | ndf.requires_grad = True
198 |
199 | optimizer = torch.optim.Adam([ndf], lr=0.1)
200 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.9)
201 |
202 | last_loss = math.inf
203 | for i in range(max_iters):
204 | optimizer.zero_grad()
205 |
206 | out = x + normal_vectors * torch.unsqueeze(ndf, 1)
207 | distance = distance_from_reference_mesh(out[~largest_component_mask],
208 | torch.tensor(target_mesh.vertices, dtype=torch.float).to(device),
209 | torch.tensor(target_mesh.faces).to(device))
210 | distance = torch.sqrt(distance)
211 | loss_dfrm = torch.mean(distance)
212 |
213 | loss_smoothness = smoothness_loss(out, torch.tensor(source_mesh.faces).to(device))
214 | # keep the projection smooth so that vertices do not land on unrelated surfaces
215 |
216 | # loss = loss_dfrm + 1 * loss_smoothness
217 | loss = loss_dfrm
218 |
219 | if verbose:
220 | sys.stdout.write("\rProjection... Iter {}, Loss: {}".format(i + 1, loss.item()))
221 | sys.stdout.flush()
222 | if i > 100 and math.fabs(last_loss - loss.item()) < 1e-6:
223 | break
224 |
225 | last_loss = loss.item()
226 | loss.backward()
227 | optimizer.step()
228 | scheduler.step()
229 |
230 | out = x + normal_vectors * torch.unsqueeze(ndf, 1)
231 | out = out.detach().cpu()
232 |
233 | # restore the vertices covered by largest_component_mask to their original positions
234 | out[largest_component_mask] = torch.tensor(source_mesh.vertices, dtype=torch.float)[largest_component_mask]
235 |
236 | new_mesh = Trimesh(out, source_mesh.faces)
237 | if verbose:
238 | print("")
239 |
240 | return new_mesh
241 |
242 | def find_nearest_vertices(target_vertices, source_vertices, k=1):
243 | tree = cKDTree(source_vertices)
244 | distances, indices = tree.query(target_vertices, k=k)
245 | return indices, distances
246 |
247 | # copy the input meshes
248 | fit_mesh = in_mesh_fit.copy()
249 | faulty_mesh = in_mesh_faulty.copy()
250 |
251 | # 1. identify the defective ("patched") region of fit_mesh by distance thresholding
252 | distance = distance_from_reference_mesh(torch.tensor(fit_mesh.vertices, dtype=torch.float).to(device),
253 | torch.tensor(faulty_mesh.vertices, dtype=torch.float).to(device),
254 | torch.tensor(faulty_mesh.faces).to(device)).cpu().numpy()
255 | idx = np.where(distance > 4)[0]  # squared-distance threshold
256 | color_mesh = get_color_mesh(fit_mesh, idx)
257 | # color_mesh.export(join(out_path, "color_1.ply"))
258 |
259 | # 2. find the largest connected component (the defect region) and its expansion
260 | connected_components = extract_connected_components(fit_mesh, idx)
261 | if len(connected_components) == 0:
262 | return fit_mesh
263 | largest_component = max(connected_components, key=len)
264 | expanded_component = expand_connected_component(fit_mesh, largest_component, 2)
265 | color_mesh = get_color_mesh(fit_mesh, list(largest_component))
266 | color_mesh = get_color_mesh(color_mesh, list(expanded_component), False, [0, 255, 0, 255])
267 | # color_mesh.export(join(out_path, 'color_2.ply'))
268 |
269 | # 3. optimize a normal displacement field (projection)
270 | vertex_mask = np.zeros(len(fit_mesh.vertices), dtype=bool)
271 | vertex_mask[list(largest_component)] = True
272 | projection_mesh = projection(fit_mesh, vertex_mask, faulty_mesh)
273 | # projection_mesh.export(join(out_path, "projection.ply"))
274 |
275 | # 4. displace each vertex of the largest component by the mean projection displacement of its K nearest vertices in the expanded region
276 | vertices_expanded = fit_mesh.vertices[list(expanded_component)]
277 | normal_displacement = (projection_mesh.vertices - fit_mesh.vertices)[list(expanded_component)]
278 | indices, distances = find_nearest_vertices(fit_mesh.vertices, vertices_expanded, k=15)
279 | completion_mesh = projection_mesh.copy()
280 | for id in largest_component:
281 | mean_displacement = np.mean(normal_displacement[indices[id]], axis=0)
282 | completion_mesh.vertices[id] += mean_displacement
283 |
284 | def laplacian_smoothing(iterations=2, smoothing_factor=0.5):
285 | vertices_to_smooth = list(expanded_component)
286 | vertices_to_smooth.extend(list(largest_component))
287 |
288 | # run several iterations of Laplacian smoothing
289 | for _ in range(iterations):
290 | smoothed_vertices = []
291 | for vertex_index in vertices_to_smooth:
292 | vertex = completion_mesh.vertices[vertex_index]
293 | neighbors = completion_mesh.vertex_neighbors[vertex_index]
294 | neighbor_vertices = completion_mesh.vertices[neighbors]
295 | smoothed_vertex = vertex + smoothing_factor * np.mean(neighbor_vertices - vertex, axis=0)
296 | smoothed_vertices.append(smoothed_vertex)
297 | # write back the smoothed vertex coordinates
298 | for i, vertex_index in enumerate(vertices_to_smooth):
299 | completion_mesh.vertices[vertex_index] = smoothed_vertices[i]
300 |
301 | # 5. Laplacian smoothing of the repaired region
302 | if laplacian:
303 | laplacian_smoothing()
304 |
305 | # save (optional export, disabled)
306 | # completion_mesh.export(join(out_path, 'completion.ply'))
307 |
308 | # TODO: trimesh blender?
309 | #### experiment: merge the largest connected component with the defective input ####
310 | # indices = list(largest_component)
311 | # new_vertices = completion_mesh.vertices[indices]
312 | # index_map = {index: i for i, index in enumerate(indices)}
313 | # new_faces = []
314 | # for face in completion_mesh.faces:
315 | # new_face = [index_map[index] for index in face if index in index_map]
316 | # if len(new_face) >= 3:
317 | # new_faces.append(new_face)
318 | # new_mesh = trimesh.Trimesh(vertices=new_vertices, faces=new_faces)
319 | # # 缺损部分
320 | # new_mesh.export(join(out_path, 'fix_part.ply'))
321 | # # union
322 | # union_mesh = faulty_mesh.union(new_mesh)
323 | # union_mesh.export(join(out_path, "union.ply"))
324 |
325 | # refinement
326 | if verbose:
327 | print("refinement")
328 | for i in range(5):
329 | completion_mesh = special_point_refinement(completion_mesh)
330 | # completion_mesh.export(join(out_path, 'refinement.ply'))
331 |
332 | # done
333 | print("done!")
334 | return completion_mesh
335 |
336 |
337 | def facial_mesh_completion(in_file, out_file, config, generator, lambda_reg=None, verbose=True, rr=False,
338 | dis_percent=None):
339 | in_dir = os.path.dirname(in_file)  # currently unused
340 | device = next(generator.parameters()).device
341 |
342 | mesh_in = load_mesh(in_file)
343 |
344 | if rr:
345 | mesh_in, T, centroid = rigid_registration(mesh_in, config, verbose=verbose)
346 |
347 | # save_ply_explicit(mesh_in, "rr.ply")
348 |
349 | mesh_fit = fit(mesh_in, generator, config, device, lambda_reg=lambda_reg, verbose=verbose, loss_convergence=1e-7,
350 | dis_percent=dis_percent)
351 | mesh_com = post_processing(mesh_fit, mesh_in, device, verbose=verbose)
352 |
353 | if rr:
354 | mesh_com.apply_transform(np.linalg.inv(T))
355 | mesh_com.vertices += centroid
356 |
357 | save_ply_explicit(mesh_com, out_file)
358 |
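359 | # Example usage (a minimal sketch mirroring scripts/face_completion.py; the
360 | # file names are placeholders):
361 | #   from config.config import read_config
362 | #   from utils.funcs import load_generator
363 | #   config = read_config("config/config.cfg")
364 | #   generator = load_generator(config)
365 | #   facial_mesh_completion("defect.ply", "comp.ply", config, generator, rr=True)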
--------------------------------------------------------------------------------