├── config
│   ├── __init__.py
│   ├── config.cfg
│   └── config.py
├── utils
│   ├── __init__.py
│   ├── render.py
│   ├── pytorch3d_extend.py
│   ├── my_dataset.py
│   ├── funcs.py
│   ├── models.py
│   └── completion.py
├── scripts
│   ├── __init__.py
│   ├── face_sample.py
│   ├── face_completion.py
│   └── train.py
├── data
│   └── norm.pt
├── assets
│   ├── image-20240716144300162.png
│   ├── image-20240716144509692.png
│   └── image-20240716144812434.png
├── LICENSE
└── README.md

/config/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/data/norm.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonylee/FaceCom/HEAD/data/norm.pt
--------------------------------------------------------------------------------
/assets/image-20240716144300162.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonylee/FaceCom/HEAD/assets/image-20240716144300162.png
--------------------------------------------------------------------------------
/assets/image-20240716144509692.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonylee/FaceCom/HEAD/assets/image-20240716144509692.png
--------------------------------------------------------------------------------
/assets/image-20240716144812434.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dragonylee/FaceCom/HEAD/assets/image-20240716144812434.png
--------------------------------------------------------------------------------
/config/config.cfg:
--------------------------------------------------------------------------------
1 | [I/O parameters]
2 | dataset_dir = PATH_TO_DATASET
3 | template_file = FILE_TO_TEMPLATE
4 | checkpoint_dir = PATH_TO_CHECKPOINTS
5 | 
6 | [model parameters]
7 | n_layers = 3
8 | z_length = 256
9 | down_sampling_factors = 4, 4, 4
10 | num_features_local = 3, 8, 16, 32
11 | num_features_global = 3, 64, 512, 2048
12 | batch_norm = 1
13 | 
14 | [training parameters]
15 | num_workers = 4
16 | lr = 0.001
17 | batch_size = 32
18 | weight_decay = 0.0005
19 | epoch = 1000
20 | lambda_reg = 0.001
21 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 dragonylee
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/scripts/face_sample.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
5 | 
6 | import argparse
7 | import torch
8 | from utils.completion import generate_face_sample
9 | from utils.funcs import load_generator
10 | from config.config import read_config
11 | from os.path import join
12 | from tqdm import tqdm
13 | 
14 | 
15 | def main():
16 |     parser = argparse.ArgumentParser(description='face sample')
17 | 
18 |     parser.add_argument("--config_file", type=str)
19 |     parser.add_argument("--out_dir", type=str)
20 |     parser.add_argument("--number", type=int, default=1)
21 | 
22 |     args = parser.parse_args()
23 | 
24 |     config = read_config(args.config_file)
25 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
26 |     generator = load_generator(config).to(device)
27 | 
28 |     if not os.path.exists(args.out_dir):
29 |         os.makedirs(args.out_dir)
30 |     for i in tqdm(range(args.number)):
31 |         generate_face_sample(join(args.out_dir, str(i + 1) + ".ply"), config, generator)
32 | 
33 | 
34 | if __name__ == "__main__":
35 |     main()
36 | 
--------------------------------------------------------------------------------
/scripts/face_completion.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
5 | 
6 | import argparse
7 | import torch
8 | from utils.completion import facial_mesh_completion
9 | from utils.funcs import load_generator
10 | from config.config import read_config
11 | 
12 | 
13 | def str2bool(s):
14 |     # argparse's type=bool treats every non-empty string (including "False") as True,
15 |     # so boolean flags are parsed explicitly to make "--rr False" behave as documented
16 |     return str(s).lower() in ("true", "1", "yes")
17 | 
18 | 
19 | def main():
20 |     parser = argparse.ArgumentParser(description='facial shape completion')
21 | 
22 |     parser.add_argument("--config_file", type=str)
23 |     parser.add_argument("--in_file", type=str)
24 |     parser.add_argument("--out_file", type=str)
25 | 
26 |     parser.add_argument("--rr", type=str2bool, default=True)
27 |     parser.add_argument("--lambda_reg", type=float, default=0.1)
28 |     parser.add_argument("--verbose", type=str2bool, default=True)
29 |     parser.add_argument("--dis_percent", type=float, default=None)
30 | 
31 |     args = parser.parse_args()
32 | 
33 |     config = read_config(args.config_file)
34 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
35 |     generator = load_generator(config).to(device)
36 |     facial_mesh_completion(args.in_file, args.out_file, config, generator, args.lambda_reg, args.verbose,
37 |                            args.rr, args.dis_percent)
38 | 
39 | 
40 | if __name__ == "__main__":
41 |     main()
42 | 
--------------------------------------------------------------------------------
/utils/render.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import trimesh
3 | from pytorch3d.structures import Meshes
4 | from pytorch3d.renderer import (
5 |     look_at_view_transform,
6 |     FoVPerspectiveCameras,
7 |     PointLights,
8 |     RasterizationSettings,
9 |     MeshRenderer,
10 |     MeshRasterizer,
11 |     SoftPhongShader,
12 |     TexturesVertex
13 | )
14 | 
15 | 
16 | def render_d(vertices: torch.Tensor, faces: torch.Tensor, image_size=(256, 256)):
17 |     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
18 | 
19 |     R, T = look_at_view_transform(eye=torch.tensor([[0, 0, 250]], dtype=torch.float32),
20 |                                   up=((0, 1, 0),),
21 |                                   at=((0, 0, 0),),
22 |                                   device=device)
23 | 
24 |     # create the camera
25 |     cameras = FoVPerspectiveCameras(
26 |         device=device, R=R, T=T,
27 |         znear=0.1, zfar=1000,
28 |         fov=50,
29 |     )
30 | 
31 |     # rasterization settings
32 |     raster_settings = RasterizationSettings(
33 |         image_size=image_size,
34 |         blur_radius=0.0,
35 |         faces_per_pixel=1,
36 |     )
37 | 
38 |     # light source
39 |     lights = PointLights(device=device, location=[[0.0, 0.0, 100.0]])
40 | 
41 |     # mesh
42 |     p3_mesh = Meshes([vertices],
43 |                      [faces],
44 |                      textures=TexturesVertex(verts_features=[torch.ones(vertices.shape, device=device)]))
45 | 
46 |     # build the renderer
47 |     renderer = MeshRenderer(rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
48 |                             shader=SoftPhongShader(device=device, cameras=cameras, lights=lights))
49 | 
50 |     # render the image
51 |     images = renderer(p3_mesh)
52 |     images = images[0, ..., :3]
53 | 
54 |     return images
55 | 
--------------------------------------------------------------------------------
/utils/pytorch3d_extend.py:
--------------------------------------------------------------------------------
1 | from pytorch3d.loss.point_mesh_distance import *
2 | from pytorch3d.loss import mesh_laplacian_smoothing, chamfer_distance
3 | from pytorch3d.structures import Meshes, Pointclouds
4 | import torch
5 | from torch import Tensor
6 | from trimesh import Trimesh
7 | 
8 | 
9 | def point_mesh_face_distance_single_direction(
10 |         meshes: Meshes,
11 |         pcls: Pointclouds,
12 |         min_triangle_area: float = 1e-6,
13 | ):
14 |     if len(meshes) != len(pcls):
15 |         raise ValueError("meshes and pointclouds must be equal sized batches")
16 |     N = len(meshes)
17 | 
18 |     # packed representation for pointclouds
19 |     points = pcls.points_packed()  # (P, 3)
20 |     points_first_idx = pcls.cloud_to_packed_first_idx()
21 |     max_points = pcls.num_points_per_cloud().max().item()
22 | 
23 |     # packed representation for faces
24 |     verts_packed = meshes.verts_packed()
25 |     faces_packed = meshes.faces_packed()
26 |     tris = verts_packed[faces_packed]  # (T, 3, 3)
27 |     tris_first_idx = meshes.mesh_to_faces_packed_first_idx()
28 |     max_tris = meshes.num_faces_per_mesh().max().item()
29 | 
30 |     # point to face distance: shape (P,)
31 |     point_to_face = point_face_distance(
32 |         points, points_first_idx, tris, tris_first_idx, max_points, min_triangle_area
33 |     )
34 | 
35 |     return point_to_face
36 | 
37 | 
38 | def distance_from_reference_mesh(points: Tensor, mesh_vertices: Tensor, mesh_faces: Tensor):
39 |     """
40 |     Return the squared distance to the mesh for every point in `points`.
41 |     """
42 |     meshes = Meshes([mesh_vertices], [mesh_faces])
43 |     pcs = Pointclouds([points])
44 |     distances = point_mesh_face_distance_single_direction(meshes, pcs)
45 |     return distances
46 | 
47 | 
48 | def smoothness_loss(vertices: Tensor, faces: Tensor):
49 |     meshes = Meshes([vertices], [faces])
50 |     return mesh_laplacian_smoothing(meshes)
51 | 
--------------------------------------------------------------------------------
/config/config.py:
--------------------------------------------------------------------------------
1 | import configparser
2 | 
3 | 
4 | def read_config(file_name):
5 |     config = configparser.RawConfigParser()
6
| config.read(file_name) 7 | 8 | config_parms = {} 9 | 10 | config_parms['dataset_dir'] = config.get('I/O parameters', 'dataset_dir') 11 | config_parms['template_file'] = config.get('I/O parameters', 'template_file') 12 | config_parms['checkpoint_dir'] = config.get('I/O parameters', 'checkpoint_dir') 13 | 14 | config_parms['n_layers'] = config.getint('model parameters', 'n_layers') 15 | config_parms['z_length'] = config.getint('model parameters', 'z_length') 16 | config_parms['down_sampling_factors'] = [int(x) for x in 17 | config.get('model parameters', 'down_sampling_factors').split(',')] 18 | config_parms['num_features_global'] = [int(x) for x in 19 | config.get('model parameters', 'num_features_global').split(',')] 20 | config_parms['num_features_local'] = [int(x) for x in 21 | config.get('model parameters', 'num_features_local').split(',')] 22 | config_parms['batch_norm'] = True if config.getint('model parameters', 'batch_norm') == 1 else False 23 | 24 | config_parms['num_workers'] = config.getint('training parameters', 'num_workers') 25 | config_parms['lr'] = config.getfloat('training parameters', 'lr') 26 | config_parms['batch_size'] = config.getint('training parameters', 'batch_size') 27 | config_parms['weight_decay'] = config.getfloat('training parameters', 'weight_decay') 28 | config_parms['epoch'] = config.getint('training parameters', 'epoch') 29 | config_parms["lambda_reg"] = config.getfloat("training parameters", "lambda_reg") 30 | 31 | assert config_parms['n_layers'] == len( 32 | config_parms['down_sampling_factors']), 'length of down_sampling_factors must equal to n_layers' 33 | assert config_parms['n_layers'] + 1 == len( 34 | config_parms['num_features_global']), 'length of num_features must equal to n_layers + 1' 35 | 36 | return config_parms 37 | -------------------------------------------------------------------------------- /utils/my_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import random 3 | import numpy as np 4 | import torch 5 | from torch_geometric.data import Data, Dataset, InMemoryDataset 6 | from tqdm import tqdm 7 | from trimesh import Trimesh, load_mesh 8 | import os 9 | from utils.funcs import save_ply_explicit, get_edge_index 10 | from concurrent.futures import ProcessPoolExecutor 11 | 12 | 13 | class Normalize(object): 14 | def __init__(self, mean=None, std=None): 15 | self.mean = mean 16 | self.std = std 17 | 18 | def __call__(self, data): 19 | assert self.mean is not None and self.std is not None, ('Initialize mean and std to normalize with') 20 | self.mean = torch.as_tensor(self.mean, dtype=data.x.dtype, device=data.x.device) 21 | self.std = torch.as_tensor(self.std, dtype=data.x.dtype, device=data.x.device) 22 | data.x = (data.x - self.mean) / self.std 23 | data.y = (data.y - self.mean) / self.std 24 | return data 25 | 26 | 27 | def load_mesh_file(file, train_dir): 28 | return load_mesh(osp.join(train_dir, file)) 29 | 30 | 31 | class MyDataset(InMemoryDataset): 32 | 33 | def __init__(self, config, dtype='train'): 34 | assert dtype in ['train', 'eval'], "Invalid dtype!" 
35 | 
36 |         self.config = config
37 |         self.root = config['dataset_dir']
38 | 
39 |         super(MyDataset, self).__init__(self.root)
40 | 
41 |         data_path = self.processed_paths[0]
42 |         if dtype == 'eval':
43 |             data_path = self.processed_paths[1]
44 |         norm_path = self.processed_paths[2]
45 |         edge_index_path = self.processed_paths[3]
46 | 
47 |         norm_dict = torch.load(norm_path)
48 |         self.mean, self.std = norm_dict['mean'], norm_dict['std']
49 |         self.data, self.slices = torch.load(data_path)
50 |         self.edge_index = torch.load(edge_index_path)
51 | 
52 |     @property
53 |     def processed_file_names(self):
54 |         processed_files = ['training.pt', 'eval.pt', 'norm.pt', "edge_index.pt"]
55 |         return processed_files
56 | 
57 |     def process(self):
58 |         meshes = []
59 |         train_data, eval_data = [], []
60 |         train_vertices = []
61 | 
62 |         train_dir = osp.join(self.root, "train")
63 |         files = os.listdir(train_dir)
64 | 
65 |         # for file in tqdm(files):
66 |         #     mesh = load_mesh(osp.join(train_dir, file))
67 |         #     meshes.append(mesh)
68 | 
69 |         with ProcessPoolExecutor(max_workers=8) as executor:  # tune max_workers for best performance
70 |             futures = [executor.submit(load_mesh_file, file, train_dir) for file in files]
71 |             for future in tqdm(futures):
72 |                 meshes.append(future.result())
73 | 
74 |         edge_index = get_edge_index(meshes[0].vertices, meshes[0].faces)
75 | 
76 |         # shuffle
77 |         random.shuffle(meshes)
78 |         count = int(0.8 * len(meshes))
79 |         for i in range(len(meshes)):
80 |             mesh_verts = torch.Tensor(meshes[i].vertices)
81 |             data = Data(x=mesh_verts, y=mesh_verts, edge_index=edge_index)
82 |             if i < count:
83 |                 train_data.append(data)
84 |                 train_vertices.append(mesh_verts)
85 |             else:
86 |                 eval_data.append(data)
87 | 
88 |         mean_train = torch.Tensor(np.mean(train_vertices, axis=0))
89 |         std_train = torch.Tensor(np.std(train_vertices, axis=0))
90 |         norm_dict = {'mean': mean_train, 'std': std_train}
91 | 
92 |         # save template
93 |         mesh = Trimesh(vertices=mean_train, faces=meshes[0].faces)
94 |         save_ply_explicit(mesh, self.config['template_file'])
95 | 
96 |         print("transforming...")
97 |         transform = Normalize(mean_train, std_train)
98 |         train_data = [transform(x) for x in train_data]
99 |         eval_data = [transform(x) for x in eval_data]
100 | 
101 |         # save
102 |         print("saving...")
103 |         torch.save(self.collate(train_data), self.processed_paths[0])
104 |         torch.save(self.collate(eval_data), self.processed_paths[1])
105 |         torch.save(norm_dict, self.processed_paths[2])
106 |         torch.save(edge_index, self.processed_paths[3])
107 |         torch.save(norm_dict, osp.join(self.config['dataset_dir'], "norm.pt"))
--------------------------------------------------------------------------------
/utils/funcs.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import numpy as np
3 | from scipy.spatial import cKDTree
4 | import scipy.sparse as sp
5 | import torch
6 | from .models import FMGenDecoder
7 | from quad_mesh_simplify import simplify_mesh
8 | from torch.nn import Parameter, ParameterList
9 | from trimesh import Trimesh, load_mesh
10 | import os
11 | from os.path import join
12 | import warnings
13 | 
14 | warnings.filterwarnings("ignore")
15 | os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
16 | 
17 | 
18 | def save_ply_explicit(mesh, ply_file_path):
19 |     vertices = mesh.vertices
20 |     faces = mesh.faces
21 |     with open(ply_file_path, 'w') as f:
22 |         f.write('ply\n')
23 |         f.write('format ascii 1.0\n')
24 |         f.write('element vertex {}\n'.format(len(vertices)))
25 |         f.write('property float x\n')
26 |         f.write('property float y\n')
27 |         f.write('property float z\n')
28 |         f.write('element face {}\n'.format(len(faces)))
29 |         f.write('property list uchar int vertex_indices\n')
30 |         f.write('end_header\n')
31 | 
32 |         # Write vertices
33 |         for vertex in vertices:
34 |             f.write('{} {} {}\n'.format(vertex[0], vertex[1], vertex[2]))
35 | 
36 |         # Write faces
37 |         for face in faces:
38 |             f.write('{} '.format(len(face)))
39 |             f.write(' '.join(str(idx) for idx in face))
40 |             f.write('\n')
41 | 
42 | 
43 | def row(A):
44 |     return A.reshape((1, -1))
45 | 
46 | 
47 | def col(A):
48 |     return A.reshape((-1, 1))
49 | 
50 | 
51 | def get_edge_index(vertices, faces):
52 |     """
53 |     Build a (2, E) edge index tensor from a triangle mesh.
54 |     Modified from https://github.com/pixelite1201/pytorch_coma/blob/master/mesh_operations.py
55 | 
56 |     :param vertices:
57 |     :param faces:
58 |     :return:
59 |     """
60 |     vpv = sp.csc_matrix((len(vertices), len(vertices)))
61 | 
62 |     for i in range(3):
63 |         IS = faces[:, i]
64 |         JS = faces[:, (i + 1) % 3]
65 |         data = np.ones(len(IS))
66 |         ij = np.vstack((row(IS.flatten()), row(JS.flatten())))
67 |         mtx = sp.csc_matrix((data, ij), shape=vpv.shape)
68 |         vpv = vpv + mtx + mtx.T
69 | 
70 |     vpv = vpv.tocoo()
71 |     return torch.tensor(np.vstack((vpv.row, vpv.col)), dtype=torch.int)
72 | 
73 | 
74 | def get_transform_matrix(vertices1, vertices2, k=5):
75 |     """
76 |     Calculate the transformation matrix D such that vertices2 = D * vertices1, where each vertex in vertices2
77 |     is formed as a linear combination of the coordinates of its k nearest vertices in vertices1.
78 | 
79 |     :param vertices1:
80 |     :param vertices2:
81 |     :param k:
82 |     :return:
83 |     """
84 |     kdtree = cKDTree(vertices1)
85 |     _, indices = kdtree.query(vertices2, k=k)
86 |     D = np.zeros((vertices2.shape[0], vertices1.shape[0]))
87 | 
88 |     for i in range(vertices2.shape[0]):
89 |         vs = vertices1[indices[i]]
90 |         w = np.matmul(vertices2[i], np.linalg.pinv(vs))
91 |         D[i, indices[i]] = w
92 | 
93 |     return D
94 | 
95 | 
96 | def generate_VFDU(mesh: Trimesh, factors: List[float]):
97 |     V, F, D, U = [], [], [], []
98 | 
99 |     vertices = np.array(mesh.vertices)
100 |     faces = np.array(mesh.faces, dtype=np.uint32)
101 |     V.append(vertices)
102 |     F.append(faces)
103 | 
104 |     for factor in factors:
105 |         # simplify the mesh with QEM
106 |         new_vertices, new_faces = simplify_mesh(vertices, faces, vertices.shape[0] / factor)
107 |         V.append(new_vertices)
108 |         F.append(new_faces)
109 | 
110 |         d = get_transform_matrix(vertices, new_vertices, 9)
111 |         u = get_transform_matrix(new_vertices, vertices, 9)
112 |         D.append(d)
113 |         U.append(u)
114 | 
115 |         vertices = new_vertices
116 |         faces = new_faces
117 | 
118 |     return V, F, D, U
119 | 
120 | 
121 | def generate_transform_matrices_trimesh(mesh: Trimesh, factors: List[float]):
122 |     V, F, D, U = generate_VFDU(mesh, factors)
123 | 
124 |     A = []
125 |     for i in range(len(F)):
126 |         edge_index = get_edge_index(V[i], F[i])
127 |         A.append(edge_index)
128 | 
129 |     V = [torch.tensor(v) for v in V]
130 |     D = [torch.tensor(d, dtype=torch.float32) for d in D]
131 |     U = [torch.tensor(u, dtype=torch.float32) for u in U]
132 | 
133 |     return V, A, D, U
134 | 
135 | 
136 | def get_mesh_matrices(config):
137 |     template_mesh = load_mesh(config["template_file"])
138 |     _, A, D, U = generate_transform_matrices_trimesh(
139 |         template_mesh, config["down_sampling_factors"]
140 |     )
141 |     pA = ParameterList([Parameter(a, requires_grad=False) for a in A])
142 |     pD = ParameterList([Parameter(a, requires_grad=False) for a in D])
143 |     pU = ParameterList([Parameter(a, requires_grad=False) for a in U])
144 | 
145 |     return pA,
pD, pU 146 | 147 | 148 | def load_generator(config): 149 | A, D, U = get_mesh_matrices(config) 150 | model = FMGenDecoder(config, A, U) 151 | model.load_state_dict( 152 | torch.load(os.path.join(config["checkpoint_dir"], "checkpoint_decoder.pt"), map_location='cuda') 153 | ) 154 | for param in model.parameters(): 155 | param.requires_grad = False 156 | model.eval() 157 | return model 158 | 159 | 160 | def load_norm(config): 161 | norm_dict = torch.load(join(config['dataset_dir'], "norm.pt")) 162 | return norm_dict['mean'], norm_dict['std'] 163 | 164 | 165 | def get_random_z(length, requires_grad=True, jitter=False): 166 | z = torch.randn(1, length) 167 | z = z / torch.sqrt(torch.sum(z ** 2)) 168 | if jitter: 169 | p = torch.randn_like(z) / 10 * z 170 | z += p 171 | z.requires_grad = requires_grad 172 | return z 173 | 174 | 175 | def spherical_regularization_loss(z, target_radius=1.0): 176 | z_norm = torch.norm(z, p=2, dim=1) 177 | deviation = z_norm - target_radius 178 | loss = torch.mean(deviation ** 2) 179 | return loss 180 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 5 | 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | from torch_geometric.loader import DataLoader 10 | from tqdm import tqdm 11 | import torch.nn.functional as F 12 | import random 13 | from config.config import read_config 14 | from utils.my_dataset import MyDataset 15 | from utils.models import FMGenModel 16 | import argparse 17 | from torch.nn import Conv1d, Parameter, ParameterList 18 | from trimesh import Trimesh, load_mesh 19 | from utils.funcs import get_mesh_matrices, spherical_regularization_loss 20 | import warnings 21 | 22 | warnings.filterwarnings("ignore") 23 | os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' 24 | 25 | 26 | def load_model(config, load_state_dict): 27 | pA, pD, pU = get_mesh_matrices(config) 28 | model = FMGenModel(config, pA, pD, pU) 29 | if load_state_dict: 30 | model.encoder.load_state_dict(torch.load(os.path.join(config['checkpoint_dir'], 'checkpoint_encoder.pt'))) 31 | model.decoder.load_state_dict(torch.load(os.path.join(config['checkpoint_dir'], 'checkpoint_decoder.pt'))) 32 | 33 | return model 34 | 35 | 36 | def train_epoch(model, train_loader, optimizer, device, size, epoch, lambda_reg=1.0): 37 | model.train() 38 | total_loss_l1 = 0 39 | total_loss_mse = 0 40 | total_loss_reg = 0 41 | 42 | for batch in tqdm(train_loader): 43 | batch = batch.to(device) 44 | optimizer.zero_grad() 45 | 46 | out, z = model(batch) 47 | 48 | loss_mse = F.mse_loss(out, batch.y) 49 | loss_l1 = F.l1_loss(out, batch.y) 50 | loss_reg = spherical_regularization_loss(z) 51 | 52 | total_loss_mse += batch.num_graphs * loss_mse.item() 53 | total_loss_l1 += batch.num_graphs * loss_l1.item() 54 | total_loss_reg += batch.num_graphs * loss_reg.item() 55 | 56 | loss = loss_l1 + loss_mse + lambda_reg * loss_reg 57 | 58 | loss.backward() 59 | optimizer.step() 60 | 61 | return total_loss_l1 / size, total_loss_mse / size, total_loss_reg / size 62 | 63 | 64 | def test_epoch(model, test_loader, device, size): 65 | model.eval() 66 | total_loss_l1 = 0 67 | total_loss_mse = 0 68 | total_loss_reg = 0 69 | 70 | for batch in tqdm(test_loader): 71 | batch = batch.to(device) 72 | 73 | out, z = model(batch) 74 | 75 | loss_mse = F.mse_loss(out, batch.y) 76 | loss_l1 = 
F.l1_loss(out, batch.y) 77 | loss_reg = spherical_regularization_loss(z) 78 | 79 | total_loss_mse += batch.num_graphs * loss_mse.item() 80 | total_loss_l1 += batch.num_graphs * loss_l1.item() 81 | total_loss_reg += batch.num_graphs * loss_reg.item() 82 | 83 | return total_loss_l1 / size, total_loss_mse / size, total_loss_reg / size 84 | 85 | 86 | def train(config): 87 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 88 | 89 | if not os.path.exists(config['checkpoint_dir']): 90 | os.makedirs(config['checkpoint_dir']) 91 | 92 | # #### dataset ##### 93 | print("loading datasets...") 94 | dataset_train = MyDataset(config, 'train') 95 | dataset_test = MyDataset(config, 'eval') 96 | train_loader = DataLoader(dataset_train, batch_size=config['batch_size'], shuffle=True, 97 | num_workers=config['num_workers'], pin_memory=True, persistent_workers=True) 98 | test_loader = DataLoader(dataset_test, batch_size=config['batch_size'], shuffle=False, 99 | num_workers=config['num_workers'], pin_memory=True, persistent_workers=True) 100 | 101 | # #### model ##### 102 | print("loading model...") 103 | model = load_model(config, False) 104 | model.to(device) 105 | 106 | # #### optimization ##### 107 | optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay']) 108 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5) 109 | # scheduler.load_state_dict(torch.load(os.path.join(config['checkpoint_dir'], 'scheduler.pt'))) 110 | 111 | # #### train for epochs ##### 112 | print("start training...") 113 | best_loss_item = float('inf') 114 | 115 | lambda_reg = config['lambda_reg'] 116 | 117 | for epoch in range(scheduler.last_epoch + 1, config['epoch'] + 1): 118 | print("Epoch", epoch, " lr:", scheduler.get_lr()) 119 | 120 | loss_l1, loss_mse, loss_reg = train_epoch(model, train_loader, optimizer, device, len(dataset_train), epoch, 121 | lambda_reg) 122 | print("Train loss: L1:", loss_l1, "MSE:", loss_mse, "REG:", loss_reg) 123 | 124 | loss_l1_test, loss_mse_test, loss_reg_test = test_epoch(model, test_loader, device, len(dataset_test)) 125 | print("Test loss: L1:", loss_l1_test, "MSE:", loss_mse_test, "REG:", loss_reg_test) 126 | 127 | scheduler.step() 128 | 129 | if loss_l1_test + loss_mse_test + lambda_reg * loss_reg_test < best_loss_item: 130 | best_loss_item = loss_l1_test + loss_mse_test + lambda_reg * loss_reg_test 131 | torch.save(model.encoder.state_dict(), os.path.join(config['checkpoint_dir'], 'checkpoint_encoder.pt')) 132 | torch.save(model.decoder.state_dict(), os.path.join(config['checkpoint_dir'], 'checkpoint_decoder.pt')) 133 | torch.save(scheduler.state_dict(), os.path.join(config['checkpoint_dir'], 'scheduler.pt')) 134 | print("\nsave!\n\n") 135 | 136 | 137 | def main(): 138 | parser = argparse.ArgumentParser(description='train') 139 | parser.add_argument("--config_file", type=str) 140 | args = parser.parse_args() 141 | 142 | # args.config_file = "../config/test_config.cfg" 143 | 144 | config = read_config(args.config_file) 145 | train(config) 146 | 147 | 148 | if __name__ == "__main__": 149 | main() 150 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## FaceCom: Towards High-fidelity 3D Facial Shape Completion via Optimization and Inpainting Guidance 2 | 3 | CVPR 2024 4 | 5 |
6 | 
7 | ## Methods and Results
8 | 
9 | #### Method Overview
10 | 
11 | ![image-20240716144509692](./assets/image-20240716144509692.png)
12 | 
13 | #### Fitting results with only the Shape Generator
14 | 
15 | ![image-20240716144300162](./assets/image-20240716144300162.png)
16 | 
17 | #### Shape completion results
18 | 
19 | ![image-20240716144812434](./assets/image-20240716144812434.png)
20 | 
21 | 
22 | 
23 | ## Set-up
24 | 
25 | 1. Clone the repository:
26 | 
27 |    ```
28 |    git clone https://github.com/dragonylee/FaceCom.git
29 |    ```
30 | 
31 | 2. Create and activate a conda environment:
32 | 
33 |    ```
34 |    conda create -n FaceCom python=3.10
35 |    conda activate FaceCom
36 |    ```
37 | 
38 | 3. Install dependencies using `pip` or `conda`:
39 | 
40 |    - [pytorch](https://pytorch.org/get-started/locally/)
41 | 
42 |      ```
43 |      pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
44 |      ```
45 | 
46 |    - [pytorch3d](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md)
47 | 
48 |      ```
49 |      conda install -c fvcore -c iopath -c conda-forge fvcore iopath
50 |      pip install "git+https://github.com/facebookresearch/pytorch3d.git@stable"
51 |      ```
52 | 
53 |    - [torch_geometric](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) & trimesh & [quad_mesh_simplify](https://github.com/jannessm/quadric-mesh-simplification)
54 | 
55 |      ```
56 |      pip install torch_geometric trimesh quad_mesh_simplify
57 |      ```
58 | 
59 | After installation, you may find that only `quad_mesh_simplify-1.1.5.dist-info` exists under the `site-packages` folder of your Python environment. In that case, also copy the `quad_mesh_simplify` folder from the [GitHub repository](https://github.com/jannessm/quadric-mesh-simplification) into `site-packages`.
60 | 
61 | 
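Optionally, a one-line import check (our suggestion, not part of the original instructions) can confirm that all dependencies are importable before moving on:

```
python -c "import torch, pytorch3d, torch_geometric, trimesh, quad_mesh_simplify; print('CUDA available:', torch.cuda.is_available())"
```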
62 | 
63 | ## Data
64 | 
65 | We trained our network using a structured hybrid 3D face dataset, which includes the [Facescape](https://facescape.nju.edu.cn/) and [HeadSpace](https://www-users.york.ac.uk/~np7/research/Headspace/) datasets (used under their respective permissions), as well as our own dataset collected from hospitals. For the time being, the data we collected cannot be made public.
66 | 
67 | You can download our pre-trained model `checkpoint_decoder.pt` ([Google Drive](https://drive.google.com/file/d/1oPfWRPgCXjAffPJWfZyZyZOgd5EYPrHf/view?usp=drive_link) | [Baidu Netdisk](https://pan.baidu.com/s/1SsBW08yieLTCbK9ec6EnwA?pwd=z4vc)) and put it in the `data` folder.
68 | 
70 | 
71 | ## Config
72 | 
73 | After downloading the pre-trained model, you need to modify the paths in the first three lines of `config/config.cfg`
74 | 
75 | ```
76 | dataset_dir = PATH_TO_THE_PROJECT\data
77 | template_file = PATH_TO_THE_PROJECT\data\template.ply
78 | checkpoint_dir = PATH_TO_THE_PROJECT\data
79 | ```
80 | 
81 | to match your own environment. If you create a new config file based on the provided `config.cfg`, these three parameters should respectively satisfy the following conditions:
82 | 
83 | 1. `dataset_dir` should contain the `norm.pt` file (if you intend to train, it should instead contain a `train` folder, with all training data placed inside it).
84 | 2. `template_file` should be the path to the template file.
85 | 3. `checkpoint_dir` should be the folder containing the model parameter files.
86 | 
87 | The provided `config.cfg` file and the corresponding `data` folder work as-is once the pre-trained model described in [Data](#data) has been downloaded.
88 | 
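For reference, `read_config` (defined in `config/config.py`) returns a plain Python dictionary, so the parsed parameters can be reused in your own scripts; a minimal sketch:

```
from config.config import read_config

config = read_config("config/config.cfg")
print(config["checkpoint_dir"])  # folder that holds checkpoint_decoder.pt
print(config["z_length"])        # latent code length, 256 in the provided config
```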
90 | 
91 | ## Usages
92 | 
93 | After setting up the config file, you can try the method end to end using the scripts below.
94 | 
95 | ### Random Sample
96 | 
97 | Randomly generate `--number` 3D face models.
98 | 
99 | ```
100 | python scripts/face_sample.py --config_file config/config.cfg --out_dir sample_out --number 10
101 | ```
102 | 
103 | ### Facial Shape Completion
104 | 
105 | **NOTE**: our method has some assumptions and limitations to be aware of.
106 | 
107 | 1. The face model is assumed to be in millimeters.
108 | 2. The spatial extent of the input face model should preferably not exceed that of the provided `template.ply`; otherwise, add `--dis_percent 0.8` to achieve better results.
109 | 3. We use trimesh's ICP for rigid registration, but are unsure of its accuracy and robustness. You may perform precise rigid registration against `template.ply` yourself first and set `--rr False`.
110 | 
111 | Then, you can run our script to perform shape completion on `--in_file`:
112 | 
113 | ```
114 | python scripts/face_completion.py --config_file config/config.cfg --in_file defect.ply --out_file comp.ply --rr True
115 | ```
116 | 
117 | where `--in_file` is any file that trimesh can read, with no requirements on topology. We provide `defect.ply` for convenience.
118 | 
119 | ### Mesh Fit / Non-rigid Registration
120 | 
121 | When the input is a complete facial model without any defects, the script in the "Facial Shape Completion" section will simply output a fit to the input. Since our method's output always has the same topology, it can also be used for non-rigid registration, as in the example below.
122 | 
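Concretely, a fitting/registration run is just a completion run on a defect-free input; assuming `full_face.ply` is a complete scan (the file name here is only a placeholder):

```
python scripts/face_completion.py --config_file config/config.cfg --in_file full_face.ply --out_file fit.ply --rr True
```

The output `fit.ply` has the template's topology and approximates the input surface.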
124 | 
125 | ## Train
126 | 
127 | After preparing a dataset in which all meshes share the same topology, you can train a shape generator using the provided code. First, choose a dataset folder path `A`, then create or modify the config file, changing the first three lines to
128 | 
129 | ```
130 | dataset_dir = A
131 | template_file = A\template.ply
132 | checkpoint_dir = A
133 | ```
134 | 
135 | You may set aside test data that will not be used for training, and place the remaining training data in the `train` subfolder of folder `A`. That is to say, before training, the directory structure of folder `A` should be as follows:
136 | 
137 | ```
138 | A/
139 | ├── train/
140 | │   ├── training_data_1.ply
141 | │   ├── training_data_2.ply
142 | │   └── ...
143 | ```
144 | 
145 | Then you can start training using the script below:
146 | 
147 | ```
148 | python scripts/train.py --config_file config/config.cfg
149 | ```
150 | 
151 | During training, folder `A` will come to look like this, with the average of the training data saved as the template:
152 | 
153 | ```
154 | A/
155 | ├── train/
156 | │   ├── training_data_1.ply
157 | │   ├── training_data_2.ply
158 | │   └── ...
159 | ├── template.ply
160 | ├── norm.pt
161 | └── checkpoint_decoder.pt
162 | ```
163 | 
164 | These results are sufficient for the usage described in [Usages](#usages).
165 | 
166 | 
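Once `template.ply`, `norm.pt`, and `checkpoint_decoder.pt` exist, the trained generator can also be driven from Python instead of the CLI. The sketch below only rearranges what `scripts/face_sample.py` already does; the config path and output file name are placeholders:

```
import torch
from config.config import read_config
from utils.funcs import load_generator
from utils.completion import generate_face_sample

config = read_config("config/config.cfg")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
generator = load_generator(config).to(device)  # loads checkpoint_decoder.pt and freezes it
generate_face_sample("sample.ply", config, generator)  # writes one random face sample
```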
167 | 
168 | ## Citations
169 | 
170 | If you find our work helpful, please cite us:
171 | 
172 | ```
173 | @inproceedings{li2024facecom,
174 |   title={FaceCom: Towards High-fidelity 3D Facial Shape Completion via Optimization and Inpainting Guidance},
175 |   author={Li, Yinglong and Wu, Hongyu and Wang, Xiaogang and Qin, Qingzhao and Zhao, Yijiao and Wang, Yong and Hao, Aimin},
176 |   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
177 |   pages={2177--2186},
178 |   year={2024}
179 | }
180 | ```
181 | 
182 | or
183 | 
184 | ```
185 | @article{li2024facecom,
186 |   title={FaceCom: Towards High-fidelity 3D Facial Shape Completion via Optimization and Inpainting Guidance},
187 |   author={Li, Yinglong and Wu, Hongyu and Wang, Xiaogang and Qin, Qingzhao and Zhao, Yijiao and Hao, Aimin and others},
188 |   journal={arXiv preprint arXiv:2406.02074},
189 |   year={2024}
190 | }
191 | ```
192 | 
193 | 
--------------------------------------------------------------------------------
/utils/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch_geometric.nn.conv import FeaStConv
4 | from torch_geometric.nn import BatchNorm
5 | from torch_geometric.data.batch import Batch
6 | 
7 | 
8 | class FMGenEncoder(torch.nn.Module):
9 |     def __init__(self, config, A, D):
10 |         super(FMGenEncoder, self).__init__()
11 |         self.A = [torch.tensor(a, requires_grad=False) for a in A]
12 |         self.D = [torch.tensor(a, requires_grad=False) for a in D]
13 | 
14 |         self.batch_norm = config['batch_norm']
15 |         self.n_layers = config['n_layers']
16 |         self.z_length = config['z_length']
17 |         self.num_features_global = config['num_features_global']
18 |         self.num_features_local = config['num_features_local']
19 | 
20 |         # conv layers
21 |         self.encoder_convs_global = torch.nn.ModuleList([
22 |             FeaStConv(in_channels=self.num_features_global[k], out_channels=self.num_features_global[k + 1])
23 |             for k in range(self.n_layers)
24 |         ])
25 |         self.encoder_convs_local = torch.nn.ModuleList([
26 |             FeaStConv(in_channels=self.num_features_local[k], out_channels=self.num_features_local[k + 1])
27 |             for k in range(self.n_layers)
28 |         ])
29 | 
30 |         # bn
31 |         self.encoder_bns_local = torch.nn.ModuleList([
32 |             BatchNorm(in_channels=self.num_features_local[k]) for k in range(self.n_layers + 1)
33 |         ])
34 |         self.encoder_bns_global = torch.nn.ModuleList([
35 |             BatchNorm(in_channels=self.num_features_global[k]) for k in range(self.n_layers + 1)
36 |         ])
37 | 
38 |         # linear layers
39 |         self.encoder_lin = torch.nn.Linear(self.z_length + self.num_features_global[-1], self.z_length)
40 |         self.encoder_local_lin = torch.nn.Linear(self.D[0].shape[1] * self.num_features_local[-1], self.z_length)
41 | 
42 |         self.reset_parameter()
43 | 
44 |     def reset_parameter(self):
45 |         torch.nn.init.normal_(self.encoder_lin.weight, 0, 0.1)
46 |         torch.nn.init.normal_(self.encoder_local_lin.weight, 0, 0.1)
47 | 
48 |     def forward(self, x, batch_size):
49 |         self.A = [a.to(x.device) for a in self.A]
50 |         self.D = [d.to(x.device) for d in self.D]
51 | 
52 |         # split into two branches
53 |         x_global = x
54 |         x_local = x
55 | 
56 |         """
57 |         global
58 |         """
59 |         # x_global: [batch_size * D[0].shape[1], num_features_global[0]]
60 |         for i in range(self.n_layers):
61 |             # graph convolution
62 |             x_global = self.encoder_convs_global[i](x=x_global, edge_index=self.A[i])
63 |             # batch normalization
64 |             if self.batch_norm:
65 |                 x_global = self.encoder_bns_global[i + 1](x_global)
66 |             # activation
67 |             x_global = F.leaky_relu(x_global)
68 |             # downsample
69 |             x_global = x_global.reshape(batch_size, -1, self.num_features_global[i + 1])
70 |             y = torch.zeros(batch_size, self.D[i].shape[0], x_global.shape[2], device=x_global.device)
71 |             for j in range(batch_size):
72 |                 y[j] = torch.mm(self.D[i], x_global[j])
73 |             x_global = y
74 |             x_global = x_global.reshape(-1, self.num_features_global[i + 1])
75 |         # x_global: [batch_size * D[-1].shape[0], num_features_global[-1]]
76 | 
77 |         x_global = x_global.reshape(batch_size, -1, self.num_features_global[-1])
78 |         # x_global: [batch_size, D[-1].shape[0], num_features_global[-1]]
79 | 
80 |         # (mean pool & relu)
81 |         x_global = torch.mean(x_global, dim=1)
82 |         x_global = F.leaky_relu(x_global)
83 |         # x_global: [batch_size, num_features_global[-1]]
84 | 
85 |         """
86 |         local
87 |         """
88 |         # begin x_local: [batch_size * D[0].shape[1], num_features_local[0]]
89 |         for i in range(self.n_layers):
90 |             # graph convolution
91 |             x_local = self.encoder_convs_local[i](x=x_local, edge_index=self.A[0])
92 |             # batch normalization
93 |             if self.batch_norm:
94 |                 x_local = self.encoder_bns_local[i + 1](x_local)
95 |             # activation
96 |             x_local = F.leaky_relu(x_local)
97 |         # x_local: [batch_size * D[0].shape[1], num_features_local[-1]]
98 | 
99 |         x_local = x_local.reshape(batch_size, -1)
100 |         # x_local: [batch_size, D[0].shape[1] * num_features_local[-1]]
101 | 
102 |         # (linear & relu)
103 |         x_local = self.encoder_local_lin(x_local)
104 |         x_local = F.leaky_relu(x_local)
105 |         # x_local: [batch_size, z_length]
106 | 
107 |         """
108 |         get z
109 |         """
110 |         z = torch.concat((x_global, x_local), dim=1)
111 |         z = self.encoder_lin(z)
112 | 
113 |         return z
114 | 
115 | 
116 | class FMGenDecoder(torch.nn.Module):
117 |     def __init__(self, config, A, U):
118 |         super(FMGenDecoder, self).__init__()
119 |         self.A = [torch.tensor(a, requires_grad=False) for a in A]
120 |         self.U = [torch.tensor(u, requires_grad=False) for u in U]
121 | 
122 |         self.batch_norm = config['batch_norm']
123 |         self.n_layers = config['n_layers']
124 |         self.z_length = config['z_length']
125 |         self.num_features_global = config['num_features_global']
126 |         self.num_features_local = config['num_features_local']
127 | 
128 |         # conv layers
129 |         self.decoder_convs_global = torch.nn.ModuleList([
130 |             FeaStConv(in_channels=self.num_features_global[-1 - k], out_channels=self.num_features_global[-2 - k])
131 |             for k in range(self.n_layers)
132 |         ])
133 |         self.decoder_convs_local = torch.nn.ModuleList([
134 |             FeaStConv(in_channels=self.num_features_local[-1 - k], out_channels=self.num_features_local[-2 - k])
135 |             for k in range(self.n_layers)
136 |         ])
137 | 
138 |         # bn
139 |         self.decoder_bns_local = torch.nn.ModuleList([
140 |             BatchNorm(in_channels=self.num_features_local[-1 - k]) for k in range(self.n_layers + 1)
141 |         ])
142 |         self.decoder_bns_global = torch.nn.ModuleList([
143 |             BatchNorm(in_channels=self.num_features_global[-1 - k]) for k in range(self.n_layers + 1)
144 |         ])
145 | 
146 |         # linear layers
147 |         self.decoder_lin = torch.nn.Linear(self.z_length, self.z_length + self.num_features_global[-1])
148 |         self.decoder_local_lin = torch.nn.Linear(self.z_length, self.num_features_local[-1] * self.U[0].shape[0])
149 | 
150 |         self.reset_parameter()
151 | 
152 |         # merge ratio
153 |         self.global_ratio = 0.01
154 |         self.local_ratio = 1 - self.global_ratio
155 | 
156 |     def reset_parameter(self):
157 |         torch.nn.init.normal_(self.decoder_lin.weight, 0, 0.1)
158 |         torch.nn.init.normal_(self.decoder_local_lin.weight, 0, 0.1)
159 | 
160 |     def forward(self, z, batch_size):
161 |         self.A = [a.to(z.device) for a in self.A]
162 |         self.U = [u.to(z.device) for u in self.U]
163 | 
164 |         # decoder linear & split into two branches
165 |         x = self.decoder_lin(z)
166 |         x_global = x[:, :self.num_features_global[-1]]
167 |         x_local = x[:, self.num_features_global[-1]:]
168 | 
169 |         """
170 |         global
171 |         """
172 |         # x_global: [batch_size, num_features_global[-1]]
173 |         x_global = torch.unsqueeze(x_global, dim=1).repeat(1, self.U[-1].shape[1], 1)
174 |         # x_global: [batch_size, U[-1].shape[1], num_features_global[-1]]
175 | 
176 |         for i in range(self.n_layers):
177 |             # upsample
178 |             x_global = x_global.reshape(batch_size, -1, self.num_features_global[-1 - i])
179 |             y = torch.zeros(batch_size, self.U[-1 - i].shape[0], x_global.shape[2], device=x_global.device)
180 |             for j in range(batch_size):
181 |                 y[j] = torch.mm(self.U[-1 - i], x_global[j])
182 |             x_global = y
183 |             x_global = x_global.reshape(-1, self.num_features_global[-1 - i])
184 |             # graph convolution
185 |             x_global = self.decoder_convs_global[i](x=x_global, edge_index=self.A[-2 - i])
186 |             if i < self.n_layers - 1:
187 |                 # batch normalization
188 |                 if self.batch_norm:
189 |                     x_global = self.decoder_bns_global[i + 1](x_global)
190 |                 # activation
191 |                 x_global = F.leaky_relu(x_global)
192 |         # x_global: [batch_size, U[0].shape[0], num_features_global[0]]
193 |         x_global = x_global.reshape(-1, self.num_features_global[0])
194 |         # x_global: [batch_size * U[0].shape[0], num_features_global[0]]
195 | 
196 |         """
197 |         local
198 |         """
199 |         # x_local: [batch_size, z_length]
200 |         x_local = self.decoder_local_lin(x_local)
201 |         # x_local: [batch_size, num_features_local[-1] * U[0].shape[0]]
202 |         x_local = x_local.reshape(-1, self.num_features_local[-1])
203 |         # x_local: [batch_size * U[0].shape[0], num_features_local[-1]]
204 | 
205 |         for i in range(self.n_layers):
206 |             # graph convolution
207 |             x_local = self.decoder_convs_local[i](x=x_local, edge_index=self.A[0])
208 |             if i < self.n_layers - 1:
209 |                 # batch normalization
210 |                 if self.batch_norm:
211 |                     x_local = self.decoder_bns_local[i + 1](x_local)
212 |                 # activation
213 |                 x_local = F.leaky_relu(x_local)
214 |         # x_local: [batch_size * U[0].shape[0], num_features_local[0]]
215 | 
216 |         """
217 |         merge
218 |         """
219 |         x = self.global_ratio * x_global + self.local_ratio * x_local
220 | 
221 |         return x
222 | 
223 | 
224 | class FMGenModel(torch.nn.Module):
225 |     def __init__(self, config, A, D, U):
226 |         super(FMGenModel, self).__init__()
227 | 
228 |         self.encoder = FMGenEncoder(config, A, D)
229 |         self.decoder = FMGenDecoder(config, A, U)
230 | 
231 |     def forward(self, batch: Batch):
232 |         batch_size = batch.num_graphs
233 | 
234 |         z = self.encoder(batch.x, batch_size)
235 |         x = self.decoder(z, batch_size)
236 | 
237 |         return x, z
238 | 
--------------------------------------------------------------------------------
/utils/completion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.utils
3 | 
4 | from config.config import read_config
5 | from trimesh import Trimesh, load_mesh
6 | from .funcs import load_generator, spherical_regularization_loss, save_ply_explicit, get_random_z, load_norm
7 | import torch.nn.functional as F
8 | import os
9 | from os.path import join
10 | import warnings
11 | from tqdm import tqdm
12 | import numpy as np
13 | from queue import Queue
14 | 
15 | import math
16 | from pytorch3d.structures import Meshes
17 | from pytorch3d.loss import chamfer_distance
18 | from .pytorch3d_extend import distance_from_reference_mesh, smoothness_loss
19 | from trimesh.registration import icp
20 | from scipy.spatial import cKDTree
21 | import sys
22 | 
23 | from PIL import Image
24 | import torchvision.transforms as transforms
25 | from .render import render_d
26 | 
27 | warnings.filterwarnings("ignore")
28 | os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
29 | sys.setrecursionlimit(30000)
30 | 
31 | 
32 | def generate_face_sample(out_file, config, generator):
33 |     generator.eval()
34 |     device = generator.parameters().__next__().device
35 |     z = get_random_z(generator.z_length, requires_grad=False)
36 |     mean, std = load_norm(config)
37 |     out = generator(z.to(device), 1).detach().cpu()
38 |     out = out * std + mean
39 |     template_mesh = load_mesh(config["template_file"])
40 |     mesh = Trimesh(out, template_mesh.faces)
41 |     save_ply_explicit(mesh, out_file)
42 | 
43 | 
44 | def rigid_registration(in_mesh, config, verbose=True):
45 |     if verbose:
46 |         print("rigid registration...")
47 | 
48 |     # mesh = in_mesh.copy()
49 |     mesh = in_mesh
50 |     template_mesh = load_mesh(config["template_file"])
51 | 
52 |     centroid = mesh.centroid
53 |     mesh.vertices -= mesh.centroid
54 |     T, _, _ = icp(mesh.vertices, template_mesh.vertices, max_iterations=50)
55 |     mesh.apply_transform(T)
56 | 
57 |     return mesh, T, centroid
58 | 
59 | 
60 | def fit(in_mesh, generator, config, device, max_iters=1000, loss_convergence=1e-6, lambda_reg=None,
61 |         verbose=True, dis_percent=None):
62 |     if verbose:
63 |         sys.stdout.write("\rFitting...")
64 |         sys.stdout.flush()
65 | 
66 |     mesh = in_mesh.copy()
67 |     template_mesh = load_mesh(config["template_file"])
68 | 
69 |     generator.eval()
70 | 
71 |     target_pc = torch.tensor(mesh.vertices, dtype=torch.float).to(device)
72 | 
73 |     z = get_random_z(generator.z_length, requires_grad=True, jitter=True)
74 | 
75 |     mean, std = load_norm(config)
76 |     mean = mean.to(device)
77 |     std = std.to(device)
78 |     faces = torch.tensor(template_mesh.faces).to(device)
79 | 
80 |     optimizer = torch.optim.Adam([z], lr=0.1)
81 |     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)
82 | 
83 |     if not lambda_reg:
84 |         lambda_reg = config['lambda_reg']
85 |     last_loss = math.inf
86 |     iters = 0
87 |     for i in range(max_iters):
88 |         optimizer.zero_grad()
89 | 
90 |         out = generator(z.to(device), 1)
91 |         out = out * std + mean
92 | 
93 |         loss_reg = spherical_regularization_loss(z)
94 |         loss = loss_reg
95 | 
96 |         distance = torch.sqrt(distance_from_reference_mesh(target_pc, out, faces))
97 |         if dis_percent:
98 |             # keep only the given fraction of vertices with the smallest distances
99 |             distance, idx = torch.sort(distance)
100 |             distance = distance[:int(dis_percent * len(distance))]
101 |         loss_dfrm = torch.mean(distance)
102 |         loss = loss_dfrm + lambda_reg * loss_reg
103 | 
104 |         if verbose:
105 |             sys.stdout.write(
106 |                 "\rFitting...\tIter {}, loss_recon: {:.6f}, loss_reg: {:.6f}".format(i + 1,
107 |                                                                                      loss_dfrm.item(),
108 |                                                                                      loss_reg.item()))
109 |             sys.stdout.flush()
110 |         if math.fabs(last_loss - loss.item()) < loss_convergence:
111 |             iters = i
112 |             break
113 | 
114 |         last_loss = loss.item()
115 |         loss.backward()
116 |         optimizer.step()
117 |         scheduler.step()
118 | 
119 |     out = generator(z.to(device), 1)
120 |     out = out * std + mean
121 |     fit_mesh = Trimesh(out.detach().cpu(), template_mesh.faces)
122 |     if verbose:
123 |         print("")
124 | 
125 |     return fit_mesh
126 | 
127 | 
128 | def post_processing(in_mesh_fit, in_mesh_faulty, device, laplacian=True, verbose=True):
129 |     if verbose:
130 |         print("post processing...")
131 | 
132 |     def get_color_mesh(mesh, idx, init_color=True, color=None):
133 |         if color is None:
134 |             color = [255, 0, 0, 255]
135 |         color_mesh = mesh.copy()
136 | 
137 |         if init_color:
138 |             color_array = np.zeros((mesh.vertices.shape[0], 4), dtype=np.uint8)  # RGBA colors
139 |             color_array[idx] = color
140 |             color_mesh.visual.vertex_colors = color_array
141 |         else:
142 |             color_mesh.visual.vertex_colors[idx] = color
143 |         return color_mesh
144 | 
145 |     def extract_connected_components(mesh: Trimesh, idx):
146 |         visited = set()
147 |         components = []
148 | 
149 |         def dfs(vertex, component):
150 |             if vertex in visited:
151 |                 return
152 |             visited.add(vertex)
153 |             component.add(vertex)
154 |             for neighbor in mesh.vertex_neighbors[vertex]:
155 |                 if neighbor in idx:
156 |                     dfs(neighbor, component)
157 | 
158 |         for vertex in idx:
159 |             if vertex not in visited:
160 |                 component = set()
161 |                 dfs(vertex, component)
162 |                 components.append(component)
163 | 
164 |         return components
165 | 
166 |     def expand_connected_component(mesh, component_, distance):
167 |         expanded_component = set()
168 |         component = component_.copy()
169 | 
170 |         for _ in range(distance):
171 |             new_neighbors = set()
172 |             for vertex in component:
173 |                 neighbors = mesh.vertex_neighbors[vertex]
174 |                 for neighbor in neighbors:
175 |                     if neighbor not in component and neighbor not in expanded_component:
176 |                         new_neighbors.add(neighbor)
177 |             expanded_component.update(new_neighbors)
178 |             component.update(new_neighbors)
179 | 
180 |         return expanded_component
181 | 
182 |     def special_point_refinement(mesh: Trimesh):
183 |         vertices = mesh.vertices
184 |         for i in tqdm(range(mesh.vertices.shape[0])):
185 |             neighbor = mesh.vertex_neighbors[i]
186 |             mean_x = np.mean(vertices[neighbor], axis=0)
187 |             x = vertices[i]
188 |             mean_distance = np.mean(np.linalg.norm(mesh.vertices[neighbor] - mean_x, axis=1))
189 |             if np.linalg.norm(mean_x - x) > 0.5 * mean_distance:
190 |                 vertices[i] = mean_x
191 |         return Trimesh(vertices, mesh.faces)
192 | 
193 |     def projection(source_mesh: Trimesh, largest_component_mask, target_mesh: Trimesh, max_iters=1000):
194 |         x = torch.tensor(source_mesh.vertices, dtype=torch.float).to(device)
195 |         normal_vectors = torch.tensor(source_mesh.vertex_normals, dtype=torch.float).to(device)
196 |         ndf = torch.randn(source_mesh.vertices.shape[0]).detach().to(device)
197 |         ndf.requires_grad = True
198 | 
199 |         optimizer = torch.optim.Adam([ndf], lr=0.1)
200 |         scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.9)
201 | 
202 |         last_loss = math.inf
203 |         for i in range(max_iters):
204 |             optimizer.zero_grad()
205 | 
206 |             out = x + normal_vectors * torch.unsqueeze(ndf, 1)
207 |             distance = distance_from_reference_mesh(out[~largest_component_mask],
208 |                                                     torch.tensor(target_mesh.vertices, dtype=torch.float).to(device),
209 |                                                     torch.tensor(target_mesh.faces).to(device))
210 |             distance = torch.sqrt(distance)
211 |             loss_dfrm = torch.mean(distance)
212 | 
213 |             loss_smoothness = smoothness_loss(out, torch.tensor(source_mesh.faces).to(device))
214 |             # keep the projection smooth so that vertices do not get projected onto other surfaces
215 | 
216 |             # loss = loss_dfrm + 1 * loss_smoothness
217 |             loss = loss_dfrm
218 | 
219 |             if verbose:
220 |                 sys.stdout.write("\rProjection... Iter {}, Loss: {}".format(i + 1, loss.item()))
221 |                 sys.stdout.flush()
222 |             if i > 100 and math.fabs(last_loss - loss.item()) < 1e-6:
223 |                 break
224 | 
225 |             last_loss = loss.item()
226 |             loss.backward()
227 |             optimizer.step()
228 |             scheduler.step()
229 | 
230 |         out = x + normal_vectors * torch.unsqueeze(ndf, 1)
231 |         out = out.detach().cpu()
232 | 
233 |         # also restore the vertices covered by largest_component_mask to their original positions
234 |         out[largest_component_mask] = torch.tensor(source_mesh.vertices, dtype=torch.float)[largest_component_mask]
235 | 
236 |         new_mesh = Trimesh(out, source_mesh.faces)
237 |         if verbose:
238 |             print("")
239 | 
240 |         return new_mesh
241 | 
242 |     def find_nearest_vertices(target_vertices, source_vertices, k=1):
243 |         tree = cKDTree(source_vertices)
244 |         distances, indices = tree.query(target_vertices, k=k)
245 |         return indices, distances
246 | 
247 |     # load the meshes
248 |     fit_mesh = in_mesh_fit.copy()
249 |     faulty_mesh = in_mesh_faulty.copy()
250 | 
251 |     # 1. identify the defective ("patched") part of fit_mesh by thresholding
252 |     distance = distance_from_reference_mesh(torch.tensor(fit_mesh.vertices, dtype=torch.float).to(device),
253 |                                             torch.tensor(faulty_mesh.vertices, dtype=torch.float).to(device),
254 |                                             torch.tensor(faulty_mesh.faces).to(device)).cpu().numpy()
255 |     idx = np.where(distance > 4)[0]  # threshold
256 |     color_mesh = get_color_mesh(fit_mesh, idx)
257 |     # color_mesh.export(join(out_path, "color_1.ply"))
258 | 
259 |     # 2. compute the largest connected component (the defective part) and its expanded ring
260 |     connected_components = extract_connected_components(fit_mesh, idx)
261 |     if len(connected_components) == 0:
262 |         return fit_mesh
263 |     largest_component = max(connected_components, key=len)
264 |     expanded_component = expand_connected_component(fit_mesh, largest_component, 2)
265 |     color_mesh = get_color_mesh(fit_mesh, list(largest_component))
266 |     color_mesh = get_color_mesh(color_mesh, list(expanded_component), False, [0, 255, 0, 255])
267 |     # color_mesh.export(join(out_path, 'color_2.ply'))
268 | 
269 |     # 3. optimize the normal displacement field (projection)
270 |     vertex_mask = np.zeros(len(fit_mesh.vertices), dtype=bool)
271 |     vertex_mask[list(largest_component)] = True
272 |     projection_mesh = projection(fit_mesh, vertex_mask, faulty_mesh)
273 |     # projection_mesh.export(join(out_path, "projection.ply"))
274 | 
275 |     # 4. displace each vertex of the largest component by the mean projection displacement of its K nearest vertices in the expanded ring
276 |     vertices_expanded = fit_mesh.vertices[list(expanded_component)]
277 |     normal_displacement = (projection_mesh.vertices - fit_mesh.vertices)[list(expanded_component)]
278 |     indices, distances = find_nearest_vertices(fit_mesh.vertices, vertices_expanded, k=15)
279 |     completion_mesh = projection_mesh.copy()
280 |     for id in largest_component:
281 |         mean_displacement = np.mean(normal_displacement[indices[id]], axis=0)
282 |         completion_mesh.vertices[id] += mean_displacement
283 | 
284 |     def laplacian_smoothing(iterations=2, smoothing_factor=0.5):
285 |         vertices_to_smooth = list(expanded_component)
286 |         vertices_to_smooth.extend(list(largest_component))
287 | 
288 |         # run Laplacian smoothing for the given number of iterations
289 |         for _ in range(iterations):
290 |             smoothed_vertices = []
291 |             for vertex_index in vertices_to_smooth:
292 |                 vertex = completion_mesh.vertices[vertex_index]
293 |                 neighbors = completion_mesh.vertex_neighbors[vertex_index]
294 |                 neighbor_vertices = completion_mesh.vertices[neighbors]
295 |                 smoothed_vertex = vertex + smoothing_factor * np.mean(neighbor_vertices - vertex, axis=0)
296 |                 smoothed_vertices.append(smoothed_vertex)
297 |             # write back the smoothed vertex coordinates
298 |             for i, vertex_index in enumerate(vertices_to_smooth):
299 |                 completion_mesh.vertices[vertex_index] = smoothed_vertices[i]
300 | 
301 |     # 5. smooth the repaired region with Laplacian smoothing
302 |     if laplacian:
303 |         laplacian_smoothing()
304 | 
305 |     # save
306 |     # completion_mesh.export(join(out_path, 'completion.ply'))
307 | 
308 |     # TODO: trimesh blender?
309 |     #### attempt to union the largest connected component with the defective mesh ####
310 |     # indices = list(largest_component)
311 |     # new_vertices = completion_mesh.vertices[indices]
312 |     # index_map = {index: i for i, index in enumerate(indices)}
313 |     # new_faces = []
314 |     # for face in completion_mesh.faces:
315 |     #     new_face = [index_map[index] for index in face if index in index_map]
316 |     #     if len(new_face) >= 3:
317 |     #         new_faces.append(new_face)
318 |     # new_mesh = trimesh.Trimesh(vertices=new_vertices, faces=new_faces)
319 |     # # the defective part
320 |     # new_mesh.export(join(out_path, 'fix_part.ply'))
321 |     # # union
322 |     # union_mesh = faulty_mesh.union(new_mesh)
323 |     # union_mesh.export(join(out_path, "union.ply"))
324 | 
325 |     # refinement
326 |     if verbose:
327 |         print("refinement")
328 |     for i in range(5):
329 |         completion_mesh = special_point_refinement(completion_mesh)
330 |     # completion_mesh.export(join(out_path, 'refinement.ply'))
331 | 
332 |     # done
333 |     print("done!")
334 |     return completion_mesh
335 | 
336 | 
337 | def facial_mesh_completion(in_file, out_file, config, generator, lambda_reg=None, verbose=True, rr=False,
338 |                            dis_percent=None):
339 |     dir = os.path.dirname(in_file)
340 |     device = generator.parameters().__next__().device
341 | 
342 |     mesh_in = load_mesh(in_file)
343 | 
344 |     if rr:
345 |         mesh_in, T, centroid = rigid_registration(mesh_in, config, verbose=verbose)
346 | 
347 |     # save_ply_explicit(mesh_in, "rr.ply")
348 | 
349 |     mesh_fit = fit(mesh_in, generator, config, device, lambda_reg=lambda_reg, verbose=verbose, loss_convergence=1e-7,
350 |                    dis_percent=dis_percent)
351 |     mesh_com = post_processing(mesh_fit, mesh_in, device, verbose=verbose)
352 | 
353 |     if rr:
354 |         mesh_com.apply_transform(np.linalg.inv(T))
355 |         mesh_com.vertices += centroid
356 | 
357 |     save_ply_explicit(mesh_com, out_file)
358 | 
--------------------------------------------------------------------------------