├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── check ├── batch_dist_check.py └── dist_check.py ├── datasets.zip ├── docs ├── anim.gif ├── network.png ├── overview.png ├── semi_anim.gif ├── semi_anim.mp4 └── viewer.png ├── environment.yml ├── mgcn.py ├── mgcn_wo_gt.py ├── preprocess └── prepare.py ├── refinement.py ├── requirements.txt ├── sgcn.py └── util ├── __init__.py ├── datamaker.py ├── loss.py ├── mesh.py ├── meshnet.py ├── models.py ├── networks.py └── render.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | wandb 3 | datasets/ 4 | 5 | *.pt 6 | *.json 7 | *.ply 8 | *.obj 9 | 10 | advancing_front.py 11 | context_fill.py 12 | count_insert_vnum.py 13 | mesh_edit.py 14 | meshfix.py 15 | mgcn_time.py 16 | sgcn_time.py -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10 2 | 3 | WORKDIR /work 4 | COPY . 
/work 5 | 6 | RUN apt-get update && apt-get install -y libgl1-mesa-dev && rm -rf /var/lib/apt/lists/* 7 | 8 | 9 | RUN pip install --upgrade pip 10 | RUN pip install -r requirements.txt -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Shota Hattori 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Self-Prior for Mesh Inpainting using Self-Supervised Graph Convolutional Networks 2 | 3 |
4 | 5 |

6 | Paper | 7 | arXiv 8 |

9 | Accepted by IEEE TVCG 2024 10 |

11 |
12 | 13 |
14 | gif
15 |

Method Overview

16 | overview
17 |
18 | 19 | ## Usage 20 | 21 | ### Environments 22 | ``` 23 | python==3.10 24 | torch==1.13.0 25 | torch-geometric==2.2.0 26 | ``` 27 | 28 | ### Installation (Docker) 29 | 30 | ``` 31 | docker image build -t astaka-pe/semigcn . 32 | docker run -itd --gpus all -p 8081:8081 --name semigcn -v .:/work astaka-pe/semigcn 33 | docker exec -it semigcn /bin/bash 34 | ``` 35 | 36 | ### Preparation 37 | 38 | - Unzip `datasets.zip` 39 | - Sample meshes will be placed in `datasets/` 40 | - Put your own mesh in a new folder (any name) as follows: 41 | - Deficient mesh: `datasets/**/{mesh-name}/{mesh-name}_original.obj` 42 | - Ground truth: `datasets/**/{mesh-name}/{mesh-name}_gt.obj` 43 | - The deficient and ground-truth meshes need not share the same connectivity, but they must share the same scale 44 | 45 | ### Preprocess 46 | 47 | - Specify the path of the deficient mesh 48 | - Creates the **initial mesh** and the **smoothed mesh** 49 | 50 | ``` 51 | python3 preprocess/prepare.py -i datasets/**/{mesh-name}/{mesh-name}_original.obj 52 | ``` 53 | - options 54 | - `-r {float}`: Target edge length for remeshing (a percentage value). Larger values give a coarser mesh; smaller values give a finer one. `default=0.6`. 55 | 56 | - Computation time: 30 sec 57 | 58 | ### Training 59 | 60 | ``` 61 | python3 sgcn.py -i datasets/**/{mesh-name} # SGCN 62 | python3 mgcn.py -i datasets/**/{mesh-name} # MGCN 63 | ``` 64 | 65 | - options 66 | - `-CAD`: For a CAD model 67 | - `-real`: For a real scan 68 | - `-cache`: Use cached files (faster computation) 69 | - `-mu`: Weight for refinement 70 | 71 | You can monitor the training progress through the web viewer. (Default: http://localhost:8081) 72 | 73 | viewer
74 | 75 | ### Evaluation 76 | 77 | - Create `datasets/**/{mesh-name}/comparison` and put the meshes to be evaluated in it 78 | - A deficient mesh `datasets/**/{mesh-name}/comparison/original.obj` and a ground truth mesh `datasets/**/{mesh-name}/comparison/gt.obj` are needed for evaluation 79 | 80 | ``` 81 | python3 check/batch_dist_check.py -i datasets/**/{mesh-name} 82 | ``` 83 | 84 | - options 85 | - `-real`: For a real scan 86 | 87 | 88 | ### Refinement (Optional) 89 | 90 | - To perform only the refinement step, run 91 | 92 | ``` 93 | python3 refinement.py \\ 94 | -src datasets/**/{mesh-name}/{mesh-name}_initial.obj \\ 95 | -dst datasets/**/{mesh-name}/output/**/100_step/.obj \\ # SGCN 96 | # -dst datasets/**/{mesh-name}/output/**/100_step_0.obj \\ # MGCN 97 | -vm datasets/**/{mesh-name}/{mesh-name}_vmask.json \\ 98 | -ref {arbitrary-output-filename}.obj \\ 99 | ``` 100 | 101 | - options 102 | - `-mu`: Weight for refinement 103 | - Choose a weight such that both the remaining vertex positions of the initial mesh and the shape of the missing regions in the output mesh are preserved 104 | 105 | ## Run other competitive methods 106 | 107 | Please refer to [tinymesh](https://github.com/tatsy/tinymesh).
108 | 109 | 121 | 122 | ## Citation 123 | 124 | ``` 125 | @article{hattori2024semigcn, 126 | title={Learning Self-Prior for Mesh Inpainting Using Self-Supervised Graph Convolutional Networks}, 127 | author={Hattori, Shota and Yatagawa, Tatsuya and Ohtake, Yutaka and Suzuki, Hiromasa}, 128 | journal={IEEE Transactions on Visualization and Computer Graphics}, 129 | year={2024}, 130 | publisher={IEEE} 131 | } 132 | ``` -------------------------------------------------------------------------------- /check/batch_dist_check.py: -------------------------------------------------------------------------------- 1 | import pymeshlab as ml 2 | import numpy as np 3 | import argparse 4 | import glob 5 | import os 6 | import json 7 | import sys 8 | import trimesh 9 | 10 | #EPS = 0.05 # for simulated hole 11 | EPS = 1.0 # for real hole 12 | COLOR_MAX = 0.005 13 | 14 | def get_parser(): 15 | parser = argparse.ArgumentParser(description="calculate hausdorff distances") 16 | parser.add_argument("-i", "--input", type=str, required=True) 17 | parser.add_argument("-real", action="store_true") 18 | args = parser.parse_args() 19 | for k, v in vars(args).items(): 20 | print("{:12s}: {}".format(k, v)) 21 | return args 22 | 23 | def main(): 24 | """ calculate hausdorff distances """ 25 | args = get_parser() 26 | if args.real: 27 | EPS = 1.0 28 | else: 29 | EPS = 0.05 30 | i_dir = "{}/comparison".format(args.input) 31 | o_dir = "{}/colored".format(i_dir) 32 | os.makedirs(o_dir, exist_ok=True) 33 | m_name = i_dir.split("/")[-2] 34 | gt_path = "{}/gt.obj".format(i_dir, m_name) 35 | org_path = "{}/original.obj".format(i_dir, m_name) 36 | all_mesh = glob.glob("{}/*.*".format(i_dir)) 37 | 38 | ms = ml.MeshSet() 39 | ms.load_new_mesh(gt_path) 40 | 41 | face_num = ms.current_mesh().face_number() 42 | diag = ms.current_mesh().bounding_box().diagonal() 43 | 44 | ms.load_new_mesh(org_path) 45 | res = ms.apply_filter("distance_from_reference_mesh", measuremesh=0, refmesh=1, signeddist=False) 46 |
ms.set_current_mesh(0) 47 | quality = ms.current_mesh().vertex_quality_array() 48 | new_vs = quality > EPS 49 | new_pos = ms.current_mesh().vertex_matrix()[new_vs] 50 | with open("{}/inserted.obj".format(o_dir), "w") as f_obj: 51 | for v in new_pos: 52 | print("v {} {} {}".format(v[0], v[1], v[2]), file=f_obj) 53 | 54 | max_val = diag * COLOR_MAX 55 | with open("{}/max_val.txt".format(o_dir), mode="w") as f: 56 | f.write("{:.7f}".format(max_val)) 57 | 58 | for m_path in all_mesh: 59 | if m_path == gt_path or m_path == org_path: 60 | continue 61 | try: 62 | ms.load_new_mesh(m_path) 63 | except Exception: 64 | print("[ERROR] {} is unknown format".format(m_path)) 65 | continue 66 | res1 = ms.apply_filter("distance_from_reference_mesh", measuremesh=0, refmesh=ms.current_mesh_id()) 67 | res1 = ms.apply_filter("distance_from_reference_mesh", measuremesh=ms.current_mesh_id(), refmesh=0) 68 | quality = ms.mesh(0).vertex_quality_array() 69 | quality_hole = quality[new_vs] 70 | hd_all = np.sum(np.abs(quality)) / len(quality) / diag 71 | hd_hole = np.sum(np.abs(quality_hole)) / len(quality_hole) / diag 72 | ms.apply_filter("colorize_by_vertex_quality", minval=0, maxval=max_val, zerosym=True) 73 | out_file = os.path.basename(m_path) 74 | out_path = "{}/all={:.6f}-hole={:.6f}-{}".format(o_dir, hd_all, hd_hole, out_file) 75 | ms.save_current_mesh(out_path) 76 | 77 | if __name__ == "__main__": 78 | main() -------------------------------------------------------------------------------- /check/dist_check.py: -------------------------------------------------------------------------------- 1 | """Mesh distance evaluation helpers (pymeshlab); imported by mgcn.py as DIST.""" 2 | import pymeshlab as ml 3 | import numpy as np 4 | import argparse 5 | import glob 6 | import os 7 | import json 8 | import sys 9 | 10 | EPS = 0.05 11 | COLOR_MAX = 0.005 12 | 13 | def simple_mesh_distance(out_path, gt_path): 14 | """ calculate hausdorff distance """ 15 | ms = ml.MeshSet() 16 | ms.load_new_mesh(gt_path) 17 | face_num = ms.current_mesh().face_number()
18 | diag = ms.current_mesh().bounding_box().diagonal() 19 | max_val = diag * COLOR_MAX 20 | 21 | ms.load_new_mesh(out_path) 22 | 23 | res1 = ms.apply_filter("distance_from_reference_mesh", measuremesh=0, refmesh=ms.current_mesh_id()) 24 | res1 = ms.apply_filter("distance_from_reference_mesh", measuremesh=ms.current_mesh_id(), refmesh=0) 25 | quality = ms.mesh(0).vertex_quality_array() 26 | hd_all = np.sum(np.abs(quality)) / len(quality) / diag 27 | ms.apply_filter("colorize_by_vertex_quality", minval=0, maxval=max_val, zerosym=True) 28 | out_dir = os.path.dirname(out_path) 29 | out_file = os.path.basename(out_path) 30 | os.makedirs("{}/colored".format(out_dir), exist_ok=True) 31 | out_path = "{}/colored/all={:.6f}-{}".format(out_dir, hd_all, out_file) 32 | ms.save_current_mesh(out_path) 33 | return hd_all 34 | 35 | def mesh_distance(gt_path, org_path, out_path, real=False): 36 | """ calculate hausdorff distance """ 37 | if real: 38 | EPS = 1.0 39 | else: 40 | EPS = 0.05 41 | ms = ml.MeshSet() 42 | ms.load_new_mesh(gt_path) 43 | 44 | face_num = ms.current_mesh().face_number() 45 | diag = ms.current_mesh().bounding_box().diagonal() 46 | 47 | ms.load_new_mesh(org_path) 48 | res = ms.apply_filter("distance_from_reference_mesh", measuremesh=0, refmesh=1, signeddist=False) 49 | ms.set_current_mesh(0) 50 | quality = ms.current_mesh().vertex_quality_array() 51 | new_vs = quality > EPS 52 | 53 | max_val = diag * COLOR_MAX 54 | 55 | ms.load_new_mesh(out_path) 56 | 57 | res1 = ms.apply_filter("distance_from_reference_mesh", measuremesh=0, refmesh=ms.current_mesh_id()) 58 | res1 = ms.apply_filter("distance_from_reference_mesh", measuremesh=ms.current_mesh_id(), refmesh=0) 59 | quality = ms.mesh(0).vertex_quality_array() 60 | quality_hole = quality[new_vs] 61 | hd_all = np.sum(np.abs(quality)) / len(quality) / diag 62 | hd_hole = np.sum(np.abs(quality_hole)) / len(quality_hole) / diag 63 | ms.apply_filter("colorize_by_vertex_quality", minval=0, maxval=max_val,
zerosym=True) 64 | out_dir = os.path.dirname(out_path) 65 | out_file = os.path.basename(out_path) 66 | out_path = "{}/all={:.6f}-hole={:.6f}-{}".format(out_dir, hd_all, hd_hole, out_file) 67 | ms.save_current_mesh(out_path) 68 | return hd_all, hd_hole -------------------------------------------------------------------------------- /datasets.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/astaka-pe/SeMIGCN/fe17377db287298c77f325eeeeec5285b1c0f626/datasets.zip -------------------------------------------------------------------------------- /docs/anim.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/astaka-pe/SeMIGCN/fe17377db287298c77f325eeeeec5285b1c0f626/docs/anim.gif -------------------------------------------------------------------------------- /docs/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/astaka-pe/SeMIGCN/fe17377db287298c77f325eeeeec5285b1c0f626/docs/network.png -------------------------------------------------------------------------------- /docs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/astaka-pe/SeMIGCN/fe17377db287298c77f325eeeeec5285b1c0f626/docs/overview.png -------------------------------------------------------------------------------- /docs/semi_anim.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/astaka-pe/SeMIGCN/fe17377db287298c77f325eeeeec5285b1c0f626/docs/semi_anim.gif -------------------------------------------------------------------------------- /docs/semi_anim.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/astaka-pe/SeMIGCN/fe17377db287298c77f325eeeeec5285b1c0f626/docs/semi_anim.mp4
-------------------------------------------------------------------------------- /docs/viewer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/astaka-pe/SeMIGCN/fe17377db287298c77f325eeeeec5285b1c0f626/docs/viewer.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: semigcn 2 | channels: 3 | - defaults 4 | - pytorch 5 | - open3d-admin 6 | dependencies: 7 | - python=3.7 8 | - cython 9 | - pytorch=1.7.0 10 | - cudatoolkit=10.2 11 | - numpy 12 | - matplotlib 13 | - tensorflow 14 | - tensorboard 15 | - open3d 16 | - pip 17 | - pip: 18 | - pyyaml 19 | - addict 20 | - plyfile 21 | - opencv-python 22 | - scikit-learn==1.0.2 23 | - --find-links https://pytorch-geometric.com/whl/torch-1.7.1+cu102.html 24 | - torch-scatter==2.0.7 25 | - torch-sparse==0.6.9 26 | - torch-cluster==1.5.9 27 | - torch-spline-conv==1.2.1 28 | - torch-geometric==1.7.1 29 | - pymeshlab==2021.10 30 | - pymeshfix==0.16.1 31 | - tqdm==4.61.1 32 | - wandb==0.12.1 33 | -------------------------------------------------------------------------------- /mgcn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import datetime 4 | import argparse 5 | import os 6 | import copy 7 | import random 8 | import viser 9 | from tqdm import tqdm 10 | 11 | import util.loss as Loss 12 | import util.models as Models 13 | import util.datamaker as Datamaker 14 | from util.mesh import Mesh 15 | from util.meshnet import MGCN 16 | import check.dist_check as DIST 17 | 18 | def torch_fix_seed(seed=314): 19 | random.seed(seed) 20 | np.random.seed(seed) 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed(seed) 23 | torch.backends.cudnn.deterministic = True 24 | torch.use_deterministic_algorithms(True, warn_only=True) 25 | 26 | def get_parser(): 27 | parser =
argparse.ArgumentParser(description="Self-supervised Mesh Completion") 28 | parser.add_argument("-i", "--input", type=str, required=True) 29 | parser.add_argument("-o", "--output", type=str, default="exp") 30 | parser.add_argument("-pos_lr", type=float, default=0.01) 31 | parser.add_argument("-iter", type=int, default=100) 32 | parser.add_argument("-k1", type=float, default=4.0) 33 | parser.add_argument("-k2", type=float, default=4.0) 34 | parser.add_argument("-dm_size", type=int, default=40) 35 | parser.add_argument("-kn", type=int, nargs="*", default=[4]) 36 | parser.add_argument("-batch", type=int, default=5) 37 | parser.add_argument("-skip", action="store_true") 38 | parser.add_argument("-gpu", type=int, default=0) 39 | parser.add_argument("-cache", action="store_true") 40 | parser.add_argument("-CAD", action="store_true") 41 | parser.add_argument("-real", action="store_true") 42 | parser.add_argument("-mu", type=float, default=1.0) 43 | parser.add_argument("-viewer", action="store_true", default=True) 44 | parser.add_argument("-port", type=int, default=8081) 45 | args = parser.parse_args() 46 | 47 | for k, v in vars(args).items(): 48 | print("{:12s}: {}".format(k, v)) 49 | 50 | return args 51 | 52 | def main(): 53 | args = get_parser() 54 | 55 | if args.viewer: 56 | server = viser.ViserServer(port=args.port) 57 | 58 | """ --- create dataset --- """ 59 | mesh_dic, dataset = Datamaker.create_dataset(args.input, dm_size=args.dm_size, kn=args.kn, cache=args.cache) 60 | ini_file, smo_file, v_mask, f_mask, mesh_name = mesh_dic["ini_file"], mesh_dic["smo_file"], mesh_dic["v_mask"], mesh_dic["f_mask"], mesh_dic["mesh_name"] 61 | org_mesh, ini_mesh, smo_mesh, out_mesh = mesh_dic["org_mesh"], mesh_dic["ini_mesh"], mesh_dic["smo_mesh"], mesh_dic["out_mesh"] 62 | rot_mesh = copy.deepcopy(ini_mesh) 63 | dt_now = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 64 | 65 | vmask_dummy = mesh_dic["vmask_dummy"] 66 | fmask_dummy = mesh_dic["fmask_dummy"] 67 | 68 | """ --- 
create model instance --- """ 69 | torch_fix_seed() 70 | device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu") 71 | posnet = MGCN(device, smo_mesh, ini_mesh, v_mask, skip=args.skip).to(device) 72 | optimizer_pos = torch.optim.Adam(posnet.parameters(), lr=args.pos_lr) 73 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer_pos, step_size=50, gamma=0.5) 74 | 75 | anss = posnet.poss 76 | v_masks = posnet.v_masks 77 | nvs = posnet.nvs 78 | meshes = posnet.meshes 79 | v_masks_list = posnet.v_masks_list 80 | poss_list = posnet.poss_list 81 | nvs_all = [len(meshes[0].vs)] + nvs 82 | pos_weight = [0.35, 0.3, 0.2, 0.15] 83 | 84 | os.makedirs("{}/output/{}_mgcn_{}".format(args.input, dt_now, args.output), exist_ok=True) 85 | scale = 2 / np.max(org_mesh.vs) 86 | if args.viewer: 87 | print("\n\033[42m Viewer at: http://localhost:{} \033[0m\n".format(args.port)) 88 | with server.gui.add_folder("Training"): 89 | server.scene.add_mesh_simple( 90 | name="/input", 91 | vertices=org_mesh.vs * scale, 92 | faces=org_mesh.faces, 93 | flat_shading=True, 94 | visible=False, 95 | ) 96 | server.scene.add_mesh_simple( 97 | name="/initial", 98 | vertices=ini_mesh.vs * scale, 99 | faces=ini_mesh.faces, 100 | flat_shading=True, 101 | visible=False, 102 | ) 103 | gui_counter = server.gui.add_number( 104 | "Epoch", 105 | initial_value=0, 106 | disabled=True, 107 | ) 108 | 109 | """ --- learning loop --- """ 110 | with tqdm(total=args.iter) as pbar: 111 | """ --- training --- """ 112 | for epoch in range(1, args.iter+1): 113 | n_data = vmask_dummy.shape[1] 114 | batch_index = torch.randperm(n_data).reshape(-1, args.batch) 115 | epoch_loss_p = 0.0 116 | epoch_loss_n = 0.0 117 | epoch_loss_r = 0.0 118 | epoch_loss_pos = 0.0 119 | epoch_loss = 0.0 120 | 121 | for batch in batch_index: 122 | """ original dummy mask """ 123 | dm_batch = vmask_dummy[:, batch] 124 | posnet.train() 125 | optimizer_pos.zero_grad() 126 | for i, b in enumerate(batch): 127 | """ original 
dummy mask """ 128 | dm = dm_batch[:, i].reshape(-1, 1) 129 | rm = v_mask.reshape(-1, 1).float() 130 | dm = rm * dm 131 | 132 | ini_vs = ini_mesh.vs 133 | 134 | poss = posnet(dataset, dm) 135 | pos = poss[0] 136 | 137 | norm = Models.compute_fn(pos, ini_mesh.faces) 138 | for mesh_idx, pos_i in enumerate(poss): 139 | if mesh_idx == 0: 140 | loss_p = Loss.mask_pos_rec_loss(pos_i, poss_list[mesh_idx], v_masks_list[mesh_idx].reshape(-1).bool()) * pos_weight[mesh_idx] 141 | epoch_loss_pos += Loss.mask_pos_rec_loss(pos_i, poss_list[0], v_masks_list[0].reshape(-1).bool()).item() 142 | else: 143 | loss_p = loss_p + Loss.mask_pos_rec_loss(pos_i, poss_list[mesh_idx], v_masks_list[mesh_idx].reshape(-1).bool()) * pos_weight[mesh_idx] 144 | # loss_p = Loss.mask_pos_rec_loss(poss, anss, v_masks.reshape(-1).bool()) 145 | loss_n = Loss.mask_norm_rec_loss(norm, ini_mesh.fn, f_mask) 146 | if args.CAD: 147 | loss_reg, _ = Loss.fn_bnf_detach_loss(pos, norm, ini_mesh, loop=5) 148 | loss = loss_p + args.k1 * loss_n + args.k2 * loss_reg 149 | epoch_loss_r += loss_reg.item() 150 | else: 151 | # loss_reg = Loss.mesh_laplacian_loss(pos, ini_mesh) 152 | loss = loss_p + args.k1 * loss_n# + 0.0 * loss_reg 153 | epoch_loss_p += loss_p.item() 154 | epoch_loss_n += loss_n.item() 155 | 156 | loss.backward() 157 | epoch_loss += loss.item() 158 | 159 | optimizer_pos.step() 160 | scheduler.step() 161 | 162 | epoch_loss_p /= n_data 163 | epoch_loss_n /= n_data 164 | epoch_loss_r /= n_data 165 | epoch_loss /= n_data 166 | epoch_loss_pos /= n_data 167 | 168 | pbar.set_description("Epoch {}".format(epoch)) 169 | pbar.set_postfix({"loss": epoch_loss}) 170 | 171 | if epoch == args.iter: 172 | out_path = "{}/output/{}_mgcn_{}/train.obj".format(args.input, dt_now, args.output) 173 | out_mesh.vs = pos.detach().to("cpu").numpy().copy() 174 | Mesh.save(out_mesh, out_path) 175 | DIST.mesh_distance(mesh_dic["gt_file"], mesh_dic["org_file"], out_path, args.real) 176 | 177 | """ --- test --- """ 178 | if epoch % 10 
== 0: 179 | posnet.eval() 180 | dm = v_mask.reshape(-1, 1).float() 181 | poss = posnet(dataset, dm) 182 | st_nv = 0 183 | for res, mesh in enumerate(meshes): 184 | out_path = "{}/output/{}_mgcn_{}/{}_step_{}.obj".format(args.input, dt_now, args.output, str(epoch), res) 185 | mesh.vs = poss[res].to("cpu").detach().numpy().copy() 186 | st_nv += len(mesh.vs) 187 | Mesh.save(mesh, out_path) 188 | if args.viewer: 189 | server.scene.add_mesh_simple( 190 | name="/output/res-{}".format(res), 191 | vertices=mesh.vs * scale, 192 | faces=mesh.faces, 193 | flat_shading=True, 194 | ) 195 | out_path = "{}/output/{}_mgcn_{}/{}_step_0.obj".format(args.input, dt_now, args.output, str(epoch)) 196 | if args.viewer: 197 | gui_counter.value = epoch 198 | 199 | pbar.update(1) 200 | 201 | DIST.mesh_distance(mesh_dic["gt_file"], mesh_dic["org_file"], out_path, args.real) 202 | 203 | """ refinement """ 204 | posnet.eval() 205 | dm = v_mask.reshape(-1, 1).float() 206 | poss = posnet(dataset, dm) 207 | out_pos = poss[0].to("cpu").detach() 208 | ini_pos = torch.from_numpy(ini_mesh.vs).float() 209 | ref_pos = Mesh.mesh_merge(ini_mesh.Lap, ini_mesh, out_pos, v_mask, w=args.mu) 210 | out_path = "{}/output/{}_mgcn_{}/refine.obj".format(args.input, dt_now, args.output) 211 | out_mesh.vs = ref_pos.detach().numpy().copy() 212 | Mesh.save(out_mesh, out_path) 213 | DIST.mesh_distance(mesh_dic["gt_file"], mesh_dic["org_file"], out_path, args.real) 214 | 215 | if __name__ == "__main__": 216 | main() -------------------------------------------------------------------------------- /mgcn_wo_gt.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import datetime 4 | import argparse 5 | import os 6 | import copy 7 | import random 8 | from tqdm import tqdm 9 | 10 | import util.loss as Loss 11 | import util.models as Models 12 | import util.datamaker as Datamaker 13 | from util.mesh import Mesh 14 | from util.meshnet import MGCN 15 | 
import check.dist_check as DIST 16 | 17 | def torch_fix_seed(seed=314): 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | torch.cuda.manual_seed(seed) 22 | torch.backends.cudnn.deterministic = True 23 | torch.use_deterministic_algorithms(True, warn_only=True) 24 | 25 | def get_parser(): 26 | parser = argparse.ArgumentParser(description="Self-supervised Mesh Completion") 27 | parser.add_argument("-i", "--input", type=str, required=True) 28 | parser.add_argument("-o", "--output", type=str, default="") 29 | parser.add_argument("-pos_lr", type=float, default=0.01) 30 | parser.add_argument("-iter", type=int, default=100) 31 | parser.add_argument("-k1", type=float, default=4.0) 32 | parser.add_argument("-k2", type=float, default=4.0) 33 | parser.add_argument("-dm_size", type=int, default=40) 34 | parser.add_argument("-kn", type=int, nargs="*", default=[4]) 35 | parser.add_argument("-batch", type=int, default=5) 36 | parser.add_argument("-skip", action="store_true") 37 | parser.add_argument("-gpu", type=int, default=0) 38 | parser.add_argument("-cache", action="store_true") 39 | parser.add_argument("-CAD", action="store_true") 40 | parser.add_argument("-real", action="store_true") 41 | parser.add_argument("-mu", type=float, default=1.0) 42 | args = parser.parse_args() 43 | 44 | for k, v in vars(args).items(): 45 | print("{:12s}: {}".format(k, v)) 46 | 47 | return args 48 | 49 | def main(): 50 | args = get_parser() 51 | """ --- create dataset --- """ 52 | mesh_dic, dataset = Datamaker.create_dataset(args.input, dm_size=args.dm_size, kn=args.kn, cache=args.cache) 53 | ini_file, smo_file, v_mask, f_mask, mesh_name = mesh_dic["ini_file"], mesh_dic["smo_file"], mesh_dic["v_mask"], mesh_dic["f_mask"], mesh_dic["mesh_name"] 54 | ini_mesh, smo_mesh, out_mesh = mesh_dic["ini_mesh"], mesh_dic["smo_mesh"], mesh_dic["out_mesh"] 55 | rot_mesh = copy.deepcopy(ini_mesh) 56 | dt_now = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 57 | 58 | vmask_dummy =
mesh_dic["vmask_dummy"] 59 | fmask_dummy = mesh_dic["fmask_dummy"] 60 | 61 | """ --- create model instance --- """ 62 | torch_fix_seed() 63 | device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu") 64 | posnet = MGCN(device, smo_mesh, ini_mesh, v_mask, skip=args.skip).to(device) 65 | optimizer_pos = torch.optim.Adam(posnet.parameters(), lr=args.pos_lr) 66 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer_pos, step_size=50, gamma=0.5) 67 | 68 | anss = posnet.poss 69 | v_masks = posnet.v_masks 70 | nvs = posnet.nvs 71 | meshes = posnet.meshes 72 | v_masks_list = posnet.v_masks_list 73 | poss_list = posnet.poss_list 74 | nvs_all = [len(meshes[0].vs)] + nvs 75 | pos_weight = [0.35, 0.3, 0.2, 0.15] 76 | 77 | os.makedirs("{}/output/{}_mgcn_{}".format(args.input, dt_now, args.output), exist_ok=True) 78 | 79 | """ --- learning loop --- """ 80 | with tqdm(total=args.iter) as pbar: 81 | """ --- training --- """ 82 | for epoch in range(1, args.iter+1): 83 | n_data = vmask_dummy.shape[1] 84 | batch_index = torch.randperm(n_data).reshape(-1, args.batch) 85 | epoch_loss_p = 0.0 86 | epoch_loss_n = 0.0 87 | epoch_loss_r = 0.0 88 | epoch_loss_pos = 0.0 89 | epoch_loss = 0.0 90 | 91 | for batch in batch_index: 92 | """ original dummy mask """ 93 | dm_batch = vmask_dummy[:, batch] 94 | posnet.train() 95 | optimizer_pos.zero_grad() 96 | for i, b in enumerate(batch): 97 | """ original dummy mask """ 98 | dm = dm_batch[:, i].reshape(-1, 1) 99 | rm = v_mask.reshape(-1, 1).float() 100 | dm = rm * dm 101 | 102 | ini_vs = ini_mesh.vs 103 | 104 | poss = posnet(dataset, dm) 105 | pos = poss[0] 106 | 107 | norm = Models.compute_fn(pos, ini_mesh.faces) 108 | for mesh_idx, pos_i in enumerate(poss): 109 | if mesh_idx == 0: 110 | loss_p = Loss.mask_pos_rec_loss(pos_i, poss_list[mesh_idx], v_masks_list[mesh_idx].reshape(-1).bool()) * pos_weight[mesh_idx] 111 | epoch_loss_pos += Loss.mask_pos_rec_loss(pos_i, poss_list[0], 
v_masks_list[0].reshape(-1).bool()).item() 112 | else: 113 | loss_p = loss_p + Loss.mask_pos_rec_loss(pos_i, poss_list[mesh_idx], v_masks_list[mesh_idx].reshape(-1).bool()) * pos_weight[mesh_idx] 114 | # loss_p = Loss.mask_pos_rec_loss(poss, anss, v_masks.reshape(-1).bool()) 115 | loss_n = Loss.mask_norm_rec_loss(norm, ini_mesh.fn, f_mask) 116 | if args.CAD: 117 | loss_reg, _ = Loss.fn_bnf_detach_loss(pos, norm, ini_mesh, loop=5) 118 | loss = loss_p + args.k1 * loss_n + args.k2 * loss_reg 119 | epoch_loss_r += loss_reg.item() 120 | else: 121 | # loss_reg = Loss.mesh_laplacian_loss(pos, ini_mesh) 122 | loss = loss_p + args.k1 * loss_n# + 0.0 * loss_reg 123 | epoch_loss_p += loss_p.item() 124 | epoch_loss_n += loss_n.item() 125 | 126 | loss.backward() 127 | epoch_loss += loss.item() 128 | 129 | optimizer_pos.step() 130 | scheduler.step() 131 | 132 | epoch_loss_p /= n_data 133 | epoch_loss_n /= n_data 134 | epoch_loss_r /= n_data 135 | epoch_loss /= n_data 136 | epoch_loss_pos /= n_data 137 | 138 | pbar.set_description("Epoch {}".format(epoch)) 139 | pbar.set_postfix({"loss": epoch_loss}) 140 | 141 | """ --- test --- """ 142 | if epoch % 10 == 0: 143 | posnet.eval() 144 | dm = v_mask.reshape(-1, 1).float() 145 | poss = posnet(dataset, dm) 146 | st_nv = 0 147 | for res, mesh in enumerate(meshes): 148 | out_path = "{}/output/{}_mgcn_{}/{}_step_{}.obj".format(args.input, dt_now, args.output, str(epoch), res) 149 | mesh.vs = poss[res].to("cpu").detach().numpy().copy() 150 | st_nv += len(mesh.vs) 151 | Mesh.save(mesh, out_path) 152 | out_path = "{}/output/{}_mgcn_{}/{}_step_0.obj".format(args.input, dt_now, args.output, str(epoch)) 153 | 154 | pbar.update(1) 155 | 156 | 157 | """ refinement """ 158 | posnet.eval() 159 | dm = v_mask.reshape(-1, 1).float() 160 | poss = posnet(dataset, dm) 161 | out_pos = poss[0].to("cpu").detach() 162 | ini_pos = torch.from_numpy(ini_mesh.vs).float() 163 | ref_pos = Mesh.mesh_merge(ini_mesh.Lap, ini_mesh, out_pos, v_mask, w=args.mu) 164 | 
out_path = "{}/output/{}_mgcn_{}/refine.obj".format(args.input, dt_now, args.output) 165 | out_mesh.vs = ref_pos.detach().numpy().copy() 166 | Mesh.save(out_mesh, out_path) 167 | 168 | if __name__ == "__main__": 169 | main() -------------------------------------------------------------------------------- /preprocess/prepare.py: -------------------------------------------------------------------------------- 1 | import pymeshlab as ml 2 | import pymeshfix as mf 3 | import numpy as np 4 | import torch 5 | import json 6 | import argparse 7 | import sys 8 | import os 9 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 10 | from util.mesh import Mesh 11 | 12 | SMOOTH_ITER = 30 13 | MAXHOLESIZE = 1000 14 | EPSILON = 0.2 15 | REMESH_TARGET = ml.Percentage(0.6) 16 | 17 | def get_parse(): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("-i", "--input", type=str, required=True) 20 | parser.add_argument("-r", "--remesh", type=float, default=0.6) 21 | args = parser.parse_args() 22 | 23 | for k, v in vars(args).items(): 24 | print("{:12s}: {}".format(k, v)) 25 | 26 | return args 27 | 28 | def mesh_fix(ms, dirname, meshname): 29 | org_vm = ms.current_mesh().vertex_matrix() 30 | org_fm = ms.current_mesh().face_matrix() 31 | meshfix = mf.MeshFix(org_vm, org_fm) 32 | meshfix.repair() 33 | meshfix.save("{}/tmp/{}_fixed.ply".format(dirname, meshname)) 34 | 35 | def remesh(ms, dirname, meshname, targetlen): 36 | ms.load_new_mesh("{}/tmp/{}_fixed.ply".format(dirname, meshname)) 37 | # [original, fixed] 38 | 39 | ms.apply_filter("remeshing_isotropic_explicit_remeshing", targetlen=targetlen) 40 | # [original, remeshed] 41 | 42 | ms.save_current_mesh("{}/tmp/{}_remeshed.obj".format(dirname, meshname)) 43 | 44 | def normalize(ms): 45 | ms.apply_filter("transform_scale_normalize", scalecenter="barycenter", unitflag=True, alllayers=True) 46 | ms.apply_filter("transform_translate_center_set_origin", traslmethod="Center on Layer BBox", alllayers=True) 47 | 48 | 
def edge_based_scaling(mesh): 49 | edge_vec = mesh.vs[mesh.edges][:, 0, :] - mesh.vs[mesh.edges][:, 1, :] 50 | ave_len = np.sum(np.linalg.norm(edge_vec, axis=1)) / mesh.edges.shape[0] 51 | mesh.vs /= ave_len 52 | return ave_len, mesh 53 | 54 | def normalize_scale(ms, dirname, meshname): 55 | try: 56 | ms.load_new_mesh("{}/{}_gt.obj".format(dirname, meshname)) 57 | ms.save_current_mesh("{}/tmp/{}_gt.obj".format(dirname, meshname)) 58 | # [original, remeshes, gt<-current] 59 | except: 60 | pass 61 | ms.set_current_mesh(1) 62 | # [original, remeshed<-current, gt] 63 | normalize(ms) 64 | 65 | ms.save_current_mesh("{}/tmp/{}_initial.obj".format(dirname, meshname)) 66 | ms.set_current_mesh(0) 67 | ms.save_current_mesh("{}/tmp/{}_original.obj".format(dirname, meshname)) 68 | try: 69 | ms.set_current_mesh(2) 70 | ms.save_current_mesh("{}/{}_gt.obj".format(dirname, meshname)) 71 | except: 72 | pass 73 | ms.set_current_mesh(1) 74 | 75 | init_mesh = Mesh("{}/tmp/{}_initial.obj".format(dirname, meshname), build_mat=False) 76 | org_mesh = Mesh("{}/tmp/{}_original.obj".format(dirname, meshname), manifold=False) 77 | try: 78 | gt_mesh = Mesh("{}/{}_gt.obj".format(dirname, meshname), manifold=False) 79 | except: 80 | pass 81 | ave_len, init_mesh = edge_based_scaling(init_mesh) 82 | org_mesh.vs /= ave_len 83 | Mesh.save(init_mesh, "{}/{}_initial.obj".format(dirname, meshname)) 84 | init_mesh.path = "{}/{}_initial.obj".format(dirname, meshname) 85 | torch.save(init_mesh, "{}/{}_initial.pt".format(dirname, meshname)) 86 | Mesh.save(org_mesh, "{}/{}_original.obj".format(dirname, meshname), color=True) 87 | try: 88 | gt_mesh.vs /= ave_len 89 | Mesh.save(gt_mesh, "{}/{}_gt.obj".format(dirname, meshname)) 90 | except: 91 | pass 92 | 93 | def write_mask(ms, epsilon, dirname, meshname): 94 | ms.clear() 95 | ms.load_new_mesh("{}/{}_original.obj".format(dirname, meshname)) 96 | ms.load_new_mesh("{}/{}_initial.obj".format(dirname, meshname)) 97 | 98 | 
ms.apply_filter("distance_from_reference_mesh", measuremesh=ms.current_mesh_id(), refmesh=0, signeddist=False) 99 | quality = ms.current_mesh().vertex_quality_array() 100 | mask_vs = quality < epsilon 101 | 102 | with open("{}/{}_vmask.json".format(dirname, meshname), "w") as vm: 103 | json.dump(mask_vs.tolist(), vm) 104 | 105 | new_vs = ms.current_mesh().vertex_matrix()[np.logical_not(mask_vs)] 106 | with open("{}/{}_inserted.obj".format(dirname, meshname), "w") as f_obj: 107 | for v in new_vs: 108 | print("v {} {} {}".format(v[0], v[1], v[2]), file=f_obj) 109 | 110 | def smooth(ms, dirname, meshname): 111 | # [original, normalized, initial] 112 | ms.apply_filter("laplacian_smooth", selected=True) 113 | ms.apply_filter("laplacian_smooth", stepsmoothnum=SMOOTH_ITER, cotangentweight=False, selected=False) 114 | ms.save_current_mesh("{}/{}_smooth.obj".format(dirname, meshname)) 115 | 116 | def main(): 117 | args = get_parse() 118 | 119 | filename = args.input 120 | targetlen = ml.Percentage(args.remesh) 121 | dirname = os.path.dirname(filename) 122 | meshname = dirname.split("/")[-1] 123 | 124 | ms = ml.MeshSet() 125 | ms.load_new_mesh(filename) 126 | # [original] 127 | 128 | os.makedirs("{}/tmp".format(dirname), exist_ok=True) 129 | print("[MeshFix]") 130 | mesh_fix(ms, dirname, meshname) 131 | print("[Remeshing]") 132 | remesh(ms, dirname, meshname, targetlen) 133 | print("[Normalizing scale]") 134 | normalize_scale(ms, dirname, meshname) 135 | write_mask(ms, EPSILON, dirname, meshname) 136 | print("[Smoothing]") 137 | smooth(ms, dirname, meshname) 138 | print("[FINISHED]") 139 | ms.clear() 140 | 141 | 142 | if __name__ == "__main__": 143 | main() -------------------------------------------------------------------------------- /refinement.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import json 4 | import argparse 5 | 6 | from util.mesh import Mesh 7 | import util.models as Models 8 
| import util.datamaker as Datamaker 9 | 10 | def get_parse(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("-src", type=str, required=True) 13 | parser.add_argument("-dst", type=str, required=True) 14 | parser.add_argument("-vm", type=str, required=True) 15 | parser.add_argument("-ref", type=str, required=True) 16 | parser.add_argument("-mu", type=float, default=1.0) 17 | args = parser.parse_args() 18 | 19 | for k, v in vars(args).items(): 20 | print("{:12s}: {}".format(k, v)) 21 | 22 | return args 23 | 24 | def main(): 25 | args = get_parse() 26 | dst_path = args.dst 27 | src_path = args.src 28 | vm_path = args.vm 29 | ref_path = args.ref 30 | dst_mesh = Mesh(dst_path, build_mat=False) 31 | src_mesh = Mesh(src_path) 32 | with open(vm_path, "r") as f: 33 | v_mask = np.array(json.load(f)) 34 | v_mask = torch.from_numpy(v_mask).reshape(-1) 35 | dst_pos = torch.from_numpy(dst_mesh.vs).float() 36 | ref_pos = Mesh.mesh_merge(src_mesh.Lap, src_mesh, dst_pos, v_mask, w=args.mu) 37 | dst_mesh.vs = ref_pos.detach().numpy().copy() 38 | Mesh.save(dst_mesh, ref_path) 39 | 40 | 41 | if __name__ == "__main__": 42 | main() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cython 2 | numpy 3 | matplotlib==3.3 4 | open3d 5 | scikit-learn==1.0.2 6 | pymeshlab==2021.10 7 | pymeshfix==0.16.1 8 | tqdm==4.61.1 9 | wandb==0.12.1 10 | opencv-python 11 | protobuf==3.20 12 | viser==0.2.1 13 | torch==1.13.0 14 | --find-links https://pytorch-geometric.com/whl/torch-1.13.0+cu117.html 15 | torch-scatter==2.1.0 16 | torch-sparse==0.6.16 17 | torch-cluster==1.6.0 18 | torch-spline-conv==1.2.1 19 | torch-geometric==2.2.0 -------------------------------------------------------------------------------- /sgcn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import 
torch.nn as nn 4 | import datetime 5 | import argparse 6 | import os 7 | import copy 8 | import random 9 | import viser 10 | from tqdm import tqdm 11 | 12 | import util.loss as Loss 13 | import util.models as Models 14 | import util.datamaker as Datamaker 15 | from util.mesh import Mesh 16 | from util.networks import SingleScaleGCN 17 | import check.dist_check as DIST 18 | 19 | def torch_fix_seed(seed=314): 20 | random.seed(seed) 21 | np.random.seed(seed) 22 | torch.manual_seed(seed) 23 | torch.cuda.manual_seed(seed) 24 | torch.backends.cudnn.deterministic = True 25 | torch.use_deterministic_algorithms = True 26 | 27 | def get_parser(): 28 | parser = argparse.ArgumentParser(description="Self-supervised Mesh Completion") 29 | parser.add_argument("-i", "--input", type=str, required=True) 30 | parser.add_argument("-pos_lr", type=float, default=0.01) 31 | parser.add_argument("-iter", type=int, default=100) 32 | parser.add_argument("-iter_refine", type=int, default=1000) 33 | parser.add_argument("-k1", type=float, default=4.0) 34 | parser.add_argument("-k2", type=float, default=4.0) 35 | parser.add_argument("-dm_size", type=int, default=40) 36 | parser.add_argument("-net", type=str, default="single") 37 | parser.add_argument("-activation", type=str, default="lrelu") 38 | parser.add_argument("-ant", type=int, default=0) 39 | parser.add_argument("-kn", type=int, nargs="*", default=[4]) 40 | parser.add_argument("-batch", type=int, default=5) 41 | parser.add_argument("-skip", action="store_true") 42 | parser.add_argument("-drop", type=int, default=1) 43 | parser.add_argument("-drop_rate", type=float, default=0.6) 44 | parser.add_argument("-rot", type=int, default=0) 45 | parser.add_argument("-gpu", type=int, default=0) 46 | parser.add_argument("-cache", action="store_true") 47 | parser.add_argument("-CAD", action="store_true") 48 | parser.add_argument("-real", action="store_true") 49 | parser.add_argument("-mu", type=float, default=1.0) 50 | parser.add_argument("-viewer", 
action="store_true", default=True) 51 | parser.add_argument("-port", type=int, default=8081) 52 | args = parser.parse_args() 53 | 54 | for k, v in vars(args).items(): 55 | print("{:12s}: {}".format(k, v)) 56 | 57 | return args 58 | 59 | def main(): 60 | args = get_parser() 61 | 62 | if args.viewer: 63 | server = viser.ViserServer(port=args.port) 64 | 65 | mesh_dic, dataset = Datamaker.create_dataset(args.input, dm_size=args.dm_size, kn=args.kn, cache=args.cache) 66 | dataset_rot = copy.deepcopy(dataset) 67 | ini_file, smo_file, v_mask, f_mask, mesh_name = mesh_dic["ini_file"], mesh_dic["smo_file"], mesh_dic["v_mask"], mesh_dic["f_mask"], mesh_dic["mesh_name"] 68 | org_mesh, ini_mesh, smo_mesh, out_mesh = mesh_dic["org_mesh"], mesh_dic["ini_mesh"], mesh_dic["smo_mesh"], mesh_dic["out_mesh"] 69 | rot_mesh = copy.deepcopy(ini_mesh) 70 | dt_now = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 71 | 72 | vmask_dummy = mesh_dic["vmask_dummy"] 73 | fmask_dummy = mesh_dic["fmask_dummy"] 74 | 75 | """ --- create model instance --- """ 76 | torch_fix_seed() 77 | device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu") 78 | posnet = SingleScaleGCN(device, skip=args.skip).to(device) 79 | optimizer_pos = torch.optim.Adam(posnet.parameters(), lr=args.pos_lr) 80 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer_pos, step_size=50, gamma=0.5) 81 | 82 | os.makedirs("{}/output/{}_sgcn".format(args.input, dt_now), exist_ok=True) 83 | scale = 2 / np.max(org_mesh.vs) 84 | if args.viewer: 85 | print("\n\033[42m Viewer at: http://localhost:{} \033[0m\n".format(args.port)) 86 | with server.gui.add_folder("Training"): 87 | server.scene.add_mesh_simple( 88 | name="/input", 89 | vertices=org_mesh.vs * scale, 90 | faces=org_mesh.faces, 91 | flat_shading=True, 92 | visible=False, 93 | ) 94 | server.scene.add_mesh_simple( 95 | name="/initial", 96 | vertices=ini_mesh.vs * scale, 97 | faces=ini_mesh.faces, 98 | flat_shading=True, 99 | visible=False, 100 | 
) 101 | gui_counter = server.gui.add_number( 102 | "Epoch", 103 | initial_value=0, 104 | disabled=True, 105 | ) 106 | 107 | """ --- learning loop --- """ 108 | with tqdm(total=args.iter) as pbar: 109 | """ --- training --- """ 110 | for epoch in range(1, args.iter+1): 111 | n_data = vmask_dummy.shape[1] 112 | batch_index = torch.randperm(n_data).reshape(-1, args.batch) 113 | epoch_loss_p = 0.0 114 | epoch_loss_n = 0.0 115 | epoch_loss_r = 0.0 116 | epoch_loss = 0.0 117 | 118 | for batch in batch_index: 119 | dm_batch = vmask_dummy[:, batch] 120 | posnet.train() 121 | optimizer_pos.zero_grad() 122 | 123 | for i, b in enumerate(batch): 124 | dm = dm_batch[:, i].reshape(-1, 1) 125 | rm = v_mask.reshape(-1, 1).float() 126 | dm = rm * dm 127 | ini_vs = ini_mesh.vs 128 | 129 | pos = posnet(dataset_rot, dm) 130 | norm = Models.compute_fn(pos, ini_mesh.faces) 131 | loss_p = Loss.mask_pos_rec_loss(pos, ini_vs, v_mask) 132 | loss_n = Loss.mask_norm_rec_loss(norm, rot_mesh.fn, f_mask) 133 | if args.CAD: 134 | loss_bnf, _ = Loss.fn_bnf_detach_loss(pos, norm, ini_mesh, loop=5) 135 | loss = loss_p + args.k1 * loss_n + args.k2 * loss_bnf 136 | epoch_loss_r += loss_bnf.item() 137 | else: 138 | #loss_lap = Loss.mesh_laplacian_loss(pos, ini_mesh) 139 | loss = loss_p + args.k1 * loss_n#+ 0.0 * loss_lap 140 | epoch_loss_p += loss_p.item() 141 | epoch_loss_n += loss_n.item() 142 | 143 | loss.backward() 144 | epoch_loss += loss.item() 145 | 146 | optimizer_pos.step() 147 | scheduler.step() 148 | 149 | epoch_loss_p /= n_data 150 | epoch_loss_n /= n_data 151 | epoch_loss_r /= n_data 152 | epoch_loss /= n_data 153 | 154 | pbar.set_description("Epoch {}".format(epoch)) 155 | pbar.set_postfix({"loss": epoch_loss}) 156 | 157 | """ --- test --- """ 158 | if epoch == args.iter: 159 | out_path = "{}/output/{}_sgcn/{}_step.obj".format(args.input, dt_now, str(epoch)) 160 | out_mesh.vs = pos.to("cpu").detach().numpy().copy() 161 | Mesh.save(out_mesh, out_path) 162 | 163 | if epoch % 10 == 0: 164 | 
out_path = "{}/output/{}_sgcn/{}_step.obj".format(args.input, dt_now, str(epoch)) 165 | 166 | posnet.eval() 167 | dm = v_mask.reshape(-1, 1).float() 168 | pos = posnet(dataset, dm) 169 | out_mesh.vs = pos.to("cpu").detach().numpy().copy() 170 | Mesh.save(out_mesh, out_path) 171 | if args.viewer: 172 | server.scene.add_mesh_simple( 173 | name="/output", 174 | vertices=out_mesh.vs * scale, 175 | faces=out_mesh.faces, 176 | flat_shading=True, 177 | ) 178 | if args.viewer: 179 | gui_counter.value = epoch 180 | 181 | pbar.update(1) 182 | 183 | DIST.mesh_distance(mesh_dic["gt_file"], mesh_dic["org_file"], out_path, args.real) 184 | 185 | """ refinement """ 186 | posnet.eval() 187 | dm = v_mask.reshape(-1, 1).float() 188 | out_pos = posnet(dataset, dm).to("cpu").detach() 189 | ref_pos = Mesh.mesh_merge(ini_mesh.Lap, ini_mesh, out_pos, v_mask, w=args.mu) 190 | out_path = "{}/output/{}_sgcn/refine.obj".format(args.input, dt_now, args.net) 191 | out_mesh.vs = ref_pos.detach().numpy().copy() 192 | Mesh.save(out_mesh, out_path) 193 | DIST.mesh_distance(mesh_dic["gt_file"], mesh_dic["org_file"], out_path, args.real) 194 | 195 | if __name__ == "__main__": 196 | main() -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/astaka-pe/SeMIGCN/fe17377db287298c77f325eeeeec5285b1c0f626/util/__init__.py -------------------------------------------------------------------------------- /util/datamaker.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import numpy as np 3 | import json 4 | import os 5 | from tqdm import tqdm 6 | import torch 7 | import copy 8 | from .mesh import Mesh 9 | from torch_geometric.data import Data 10 | from typing import Tuple 11 | from pathlib import Path 12 | 13 | class Dataset: 14 | def __init__(self, data): 15 | self.keys = data.keys 16 | 
self.num_nodes = data.num_nodes 17 | self.num_edges = data.num_edges 18 | self.num_node_features = data.num_node_features 19 | self.contains_isolated_nodes = data.has_isolated_nodes() 20 | self.contains_self_loops = data.has_self_loops() 21 | self.z1 = data['z1'] 22 | self.x_pos = data['x_pos'] 23 | self.x_norm = data['x_norm'] 24 | self.edge_index = data['edge_index'] 25 | self.face_index = data['face_index'] 26 | 27 | 28 | def create_dataset(file_path: str, dm_size=40, kn=[1], cache=False) -> Tuple[dict, Dataset]: 29 | """ create mesh """ 30 | mesh_dic = {} 31 | file_path = str(Path(file_path)) 32 | mesh_name = Path(file_path).name 33 | ini_file = "{}/{}_initial.obj".format(file_path, mesh_name) 34 | smo_file = "{}/{}_smooth.obj".format(file_path, mesh_name) 35 | gt_file = "{}/{}_gt.obj".format(file_path, mesh_name) 36 | org_file = "{}/{}_original.obj".format(file_path, mesh_name) 37 | vmask_file = "{}/{}_vmask.json".format(file_path, mesh_name) 38 | fmask_file = "{}/{}_fmask.json".format(file_path, mesh_name) 39 | 40 | print("[Loading meshes...]") 41 | try: 42 | ini_mesh = torch.load("{}/{}_initial.pt".format(file_path, mesh_name)) 43 | out_mesh = torch.load("{}/{}_initial.pt".format(file_path, mesh_name)) 44 | except: 45 | ini_mesh = Mesh(ini_file, build_mat=False) 46 | out_mesh = copy.deepcopy(ini_mesh) 47 | torch.save(ini_mesh, "{}/{}_initial.pt".format(file_path, mesh_name)) 48 | org_mesh = Mesh(org_file, manifold=False) 49 | smo_mesh = Mesh(smo_file, manifold=False) 50 | Mesh.copy_attribute(ini_mesh, smo_mesh) 51 | 52 | try: 53 | with open(vmask_file, "r") as vm: 54 | v_mask = np.array(json.load(vm)) 55 | except: 56 | v_mask = None 57 | 58 | try: 59 | with open(fmask_file, "r") as fm: 60 | f_mask = np.array(json.load(fm)) 61 | color_mask(ini_mesh, f_mask) 62 | except: 63 | if type(v_mask)==np.ndarray: 64 | f_mask = vmask_to_fmask(ini_mesh, v_mask) 65 | color_mask(ini_mesh, f_mask) 66 | else: 67 | f_mask = None 68 | 69 | """ create graph """ 70 | z1 = 
ini_mesh.vs - smo_mesh.vs 71 | z1 = torch.tensor(z1, dtype=torch.float, requires_grad=True) 72 | 73 | x_pos = torch.tensor(smo_mesh.vs, dtype=torch.float) 74 | x_norm = torch.tensor(ini_mesh.fn, dtype=torch.float) 75 | 76 | edge_index = torch.tensor(ini_mesh.edges.T, dtype=torch.long) 77 | edge_index = torch.cat([edge_index, edge_index[[1,0],:]], dim=1) 78 | face_index = torch.from_numpy(ini_mesh.f_edges) 79 | 80 | mesh_dic["ini_file"] = ini_file 81 | mesh_dic["smo_file"] = smo_file 82 | mesh_dic["gt_file"] = gt_file 83 | mesh_dic["org_file"] = org_file 84 | mesh_dic["v_mask"] = torch.from_numpy(v_mask).bool() 85 | mesh_dic["f_mask"] = f_mask 86 | mesh_dic["mesh_name"] = mesh_name 87 | mesh_dic["org_mesh"] = org_mesh 88 | mesh_dic["ini_mesh"] = ini_mesh 89 | mesh_dic["out_mesh"] = out_mesh 90 | mesh_dic["smo_mesh"] = smo_mesh 91 | mesh_dic["vmask_dummy"] = None 92 | mesh_dic["fmask_dummy"] = None 93 | 94 | os.makedirs("{}/dummy_mask/".format(file_path, mesh_name), exist_ok=True) 95 | if cache: 96 | mesh_dic["vmask_dummy"] = torch.load("{}/dummy_mask/vmask_dummy.pt".format(file_path)) 97 | mesh_dic["fmask_dummy"] = torch.load("{}/dummy_mask/fmask_dummy.pt".format(file_path)) 98 | else: 99 | print("[Creating synthetic occlusion...]") 100 | mesh_dic["vmask_dummy"], mesh_dic["fmask_dummy"] = make_dummy_mask(ini_mesh, dm_size=dm_size, kn=kn, exist_face=f_mask) 101 | torch.save(mesh_dic["vmask_dummy"], "{}/dummy_mask/vmask_dummy.pt".format(file_path)) 102 | torch.save(mesh_dic["fmask_dummy"], "{}/dummy_mask/fmask_dummy.pt".format(file_path)) 103 | 104 | """ create dataset """ 105 | data = Data(x=z1, z1=z1, x_pos=x_pos, x_norm=x_norm, edge_index=edge_index, face_index=face_index, v_mask=mesh_dic["v_mask"], f_mask=mesh_dic["f_mask"], vmask_dummy=mesh_dic["v_mask"], fmask_dummy=mesh_dic["f_mask"]) 106 | dataset = Dataset(data) 107 | return mesh_dic, dataset 108 | 109 | 110 | def make_dummy_mask(mesh, dm_size=40, kn=[3, 4, 5], exist_face=None): 111 | valid_idx = [] 112 | for 
i in kn: 113 | valid_idx.extend(np.arange(dm_size*i, dm_size*(i+1)).tolist()) 114 | 115 | AI = mesh.AdjI.float() 116 | # p_list = torch.tensor([0.3, 0.06, 0.02, 0.02, 0.014, 0.008, 0.008, 0.0002, 0.001]) 117 | #p_list = torch.tensor([0.007, 0.007, 0.007, 0.007, 0.007, 0.007, 0.007, 0.0007, 0.007]) 118 | p_list = torch.tensor([0.014, 0.014, 0.014, 0.014, 0.014, 0.014, 0.014, 0.0014, 0.014]) 119 | #p_list = torch.tensor([0.021, 0.021, 0.021, 0.021, 0.021, 0.021, 0.021, 0.0021, 0.021]) 120 | #p_list = torch.tensor([0.3, 0.06, 0.04, 0.02, 0.006, 0.008, 0.004, 0.0002, 0.001]) 121 | vmask = torch.tensor([]) 122 | bar = tqdm(total=sum(kn)) 123 | for k in kn: 124 | Mv0 = np.random.binomial(1, p_list[k], size=[len(mesh.vs), dm_size]) 125 | Mv0 = torch.from_numpy(Mv0).float() 126 | for _ in range(k): 127 | Mv1 = (torch.sparse.mm(AI, Mv0) > 0).float() 128 | Mv0 = Mv1 129 | bar.update(1) 130 | 131 | if len(vmask) == 0: 132 | vmask = 1.0 - Mv0 133 | else: 134 | vmask = torch.cat([vmask, 1.0 - Mv0], dim=1) 135 | 136 | fmask = (torch.sparse.mm(mesh.f2v_mat, 1.0 - vmask) == 0).float() 137 | """ write the masked meshes """ 138 | for i, k in enumerate(kn): 139 | color = np.ones([len(mesh.faces), 3]) 140 | color[:, 0] = 0.332 # 0.75 141 | color[:, 1] = 0.664 # 0.75 142 | color[:, 2] = 1.0 # 0.75 143 | black = fmask[:, dm_size*i] == 0 144 | color[black, 0] = 1.0 # 0 145 | color[black, 1] = 0.664 # 0 146 | color[black, 2] = 0.0 # 0 147 | color[exist_face==0, 0] = 1.0 # 0 148 | color[exist_face==0, 1] = 0.0 # 0.5 149 | color[exist_face==0, 2] = 1.0 # 1 150 | dropv = 100 * torch.sum(vmask[:, dm_size*i] == 0) // len(mesh.vs) 151 | filename = "{}/dummy_mask/{}-neighbor-{}per.ply".format(os.path.dirname(mesh.path), k, dropv) 152 | mesh.save_as_ply(filename, color) 153 | 154 | return vmask, fmask 155 | 156 | def vmask_to_fmask(mesh, vmask): 157 | vmask = torch.from_numpy(vmask).reshape(-1, 1).float() 158 | fmask = (torch.sparse.mm(mesh.f2v_mat, 1.0 - vmask) == 0).bool().reshape(-1) 159 | 
return fmask 160 | 161 | def color_mask(mesh, fmask): 162 | color = np.ones([len(mesh.faces), 3]) 163 | color[:, 0] = 0.332 164 | color[:, 1] = 0.664 165 | color[:, 2] = 1.0 166 | color[fmask==0, 0] = 1.0 167 | color[fmask==0, 1] = 0.0 168 | color[fmask==0, 2] = 1.0 169 | os.makedirs("{}/dummy_mask/".format(os.path.dirname(mesh.path)), exist_ok=True) 170 | filename = "{}/dummy_mask/initial.ply".format(os.path.dirname(mesh.path)) 171 | mesh.save_as_ply(filename, color) -------------------------------------------------------------------------------- /util/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from util.mesh import Mesh 5 | from typing import Union 6 | 7 | 8 | def squared_norm(x, dim=None, keepdim=False): 9 | return torch.sum(x * x, dim=dim, keepdim=keepdim) 10 | 11 | def norm(x, eps=1.0e-12, dim=None, keepdim=False): 12 | return torch.sqrt(squared_norm(x, dim=dim, keepdim=keepdim) + eps) 13 | 14 | def mask_pos_rec_loss(pred_pos: torch.Tensor, real_pos: Union[torch.Tensor, np.ndarray], mask: np.ndarray, ltype="rmse") -> torch.Tensor: 15 | """ reconstructuion error for vertex positions """ 16 | if type(real_pos) == np.ndarray: 17 | real_pos = torch.from_numpy(real_pos) 18 | real_pos = real_pos.to(pred_pos.device) 19 | 20 | if ltype == "l1mae": 21 | diff_pos = torch.sum(torch.abs(real_pos[mask] - pred_pos[mask]), dim=1) 22 | loss = torch.sum(diff_pos) / len(diff_pos) 23 | 24 | elif ltype == "rmse": 25 | diff_pos = torch.abs(real_pos[mask] - pred_pos[mask]) 26 | diff_pos = diff_pos ** 2 27 | diff_pos = torch.sum(diff_pos.squeeze(), dim=1) 28 | diff_pos = torch.sum(diff_pos) / len(diff_pos) 29 | loss = torch.sqrt(diff_pos + 1.0e-6) 30 | else: 31 | print("[ERROR]: ltype error") 32 | exit() 33 | 34 | return loss 35 | 36 | def pos_rec_loss(pred_pos: Union[torch.Tensor, np.ndarray], real_pos: np.ndarray, ltype="rmse") -> torch.Tensor: 37 | """ reconstructuion error for vertex 
positions """ 38 | if type(pred_pos) == np.ndarray: 39 | pred_pos = torch.from_numpy(pred_pos) 40 | if type(real_pos) == np.ndarray: 41 | real_pos = torch.from_numpy(real_pos) 42 | 43 | real_pos = real_pos.to(pred_pos.device) 44 | 45 | if ltype == "l1mae": 46 | diff_pos = torch.sum(torch.abs(real_pos - pred_pos), dim=1) 47 | loss = torch.sum(diff_pos) / len(diff_pos) 48 | 49 | elif ltype == "rmse": 50 | diff_pos = torch.abs(real_pos - pred_pos) 51 | diff_pos = diff_pos ** 2 52 | diff_pos = torch.sum(diff_pos.squeeze(), dim=1) 53 | diff_pos = torch.sum(diff_pos) / len(diff_pos) 54 | loss = torch.sqrt(diff_pos + 1.0e-6) 55 | else: 56 | print("[ERROR]: ltype error") 57 | exit() 58 | return loss 59 | 60 | def mesh_laplacian_loss(pred_pos: torch.Tensor, mesh: Mesh, ltype="rmse") -> torch.Tensor: 61 | """ simple laplacian for output meshes """ 62 | v2v = mesh.Adj.to(pred_pos.device) 63 | v_dims = mesh.v_dims.reshape(-1, 1).to(pred_pos.device) 64 | lap_pos = torch.sparse.mm(v2v, pred_pos) / v_dims 65 | lap_diff = torch.sum((pred_pos - lap_pos) ** 2, dim=1) 66 | if ltype == "mae": 67 | lap_diff = torch.sqrt(lap_diff + 1.0e-12) 68 | lap_loss = torch.sum(lap_diff) / len(lap_diff) 69 | elif ltype == "rmse": 70 | lap_loss = torch.sum(lap_diff) / len(lap_diff) 71 | lap_loss = torch.sqrt(lap_loss + 1.0e-12) 72 | else: 73 | print("[ERROR]: ltype error") 74 | exit() 75 | 76 | return lap_loss 77 | 78 | def mask_norm_rec_loss(pred_norm: Union[torch.Tensor, np.ndarray], real_norm: Union[torch.Tensor, np.ndarray], mask: np.ndarray, ltype="l1mae") -> torch.Tensor: 79 | """ reconstruction loss for (vertex, face) normal """ 80 | if type(pred_norm) == np.ndarray: 81 | pred_norm = torch.from_numpy(pred_norm) 82 | if type(real_norm) == np.ndarray: 83 | real_norm = torch.from_numpy(real_norm).to(pred_norm.device) 84 | 85 | if ltype == "l2mae": 86 | norm_diff = torch.sum((pred_norm[mask] - real_norm[mask]) ** 2, dim=1) 87 | loss = torch.sqrt(norm_diff + 1e-12) 88 | loss = torch.sum(loss) / 
len(loss) 89 | elif ltype == "l1mae": 90 | norm_diff = torch.sum(torch.abs(pred_norm[mask] - real_norm[mask]), dim=1) 91 | loss = torch.sum(norm_diff) / len(norm_diff) 92 | elif ltype == "l2rmse": 93 | norm_diff = torch.sum((pred_norm[mask] - real_norm[mask]) ** 2, dim=1) 94 | loss = torch.sum(norm_diff) / len(norm_diff) 95 | loss = torch.sqrt(loss + 1e-12) 96 | elif ltype == "l1rmse": 97 | norm_diff = torch.sum(torch.abs(pred_norm[mask] - real_norm[mask]), dim=1) 98 | loss = torch.sum(norm_diff ** 2) / len(norm_diff) 99 | loss = torch.sqrt(loss + 1e-12) 100 | elif ltype == "cos": 101 | cos_loss = 1.0 - torch.sum(torch.mul(pred_norm[mask], real_norm[mask]), dim=1) 102 | loss = torch.sum(cos_loss, dim=0) / len(cos_loss) 103 | else: 104 | print("[ERROR]: ltype error") 105 | exit() 106 | 107 | return loss 108 | 109 | def norm_rec_loss(pred_norm: Union[torch.Tensor, np.ndarray], real_norm: Union[torch.Tensor, np.ndarray], ltype="l1mae") -> torch.Tensor: 110 | """ reconstruction loss for (vertex, face) normal """ 111 | if type(pred_norm) == np.ndarray: 112 | pred_norm = torch.from_numpy(pred_norm) 113 | if type(real_norm) == np.ndarray: 114 | real_norm = torch.from_numpy(real_norm).to(pred_norm.device) 115 | 116 | if ltype == "l2mae": 117 | norm_diff = torch.sum((pred_norm - real_norm) ** 2, dim=1) 118 | loss = torch.sqrt(norm_diff + 1e-12) 119 | loss = torch.sum(loss) / len(loss) 120 | elif ltype == "l1mae": 121 | norm_diff = torch.sum(torch.abs(pred_norm - real_norm), dim=1) 122 | loss = torch.sum(norm_diff) / len(norm_diff) 123 | elif ltype == "l2rmse": 124 | norm_diff = torch.sum((pred_norm - real_norm) ** 2, dim=1) 125 | loss = torch.sum(norm_diff) / len(norm_diff) 126 | loss = torch.sqrt(loss + 1e-12) 127 | elif ltype == "l1rmse": 128 | norm_diff = torch.sum(torch.abs(pred_norm - real_norm), dim=1) 129 | loss = torch.sum(norm_diff ** 2) / len(norm_diff) 130 | loss = torch.sqrt(loss + 1e-12) 131 | elif ltype == "cos": 132 | cos_loss = 1.0 - 
torch.sum(torch.mul(pred_norm, real_norm), dim=1) 133 | loss = torch.sum(cos_loss, dim=0) / len(cos_loss) 134 | else: 135 | print("[ERROR]: ltype error") 136 | exit() 137 | 138 | return loss 139 | 140 | def fn_bnf_loss(pos: torch.Tensor, fn: torch.Tensor, mesh: Mesh, ltype="l1mae", loop=5) -> torch.Tensor: 141 | """ bilateral loss for face normal """ 142 | if type(pos) == np.ndarray: 143 | pos = torch.from_numpy(pos).to(fn.device) 144 | else: 145 | pos = pos.detach() 146 | fc = torch.sum(pos[mesh.faces], 1) / 3.0 147 | fa = torch.cross(pos[mesh.faces[:, 1]] - pos[mesh.faces[:, 0]], pos[mesh.faces[:, 2]] - pos[mesh.faces[:, 0]]) 148 | fa = 0.5 * torch.sqrt(torch.sum(fa**2, axis=1) + 1.0e-12) 149 | 150 | #fc = torch.from_numpy(mesh.fc).float().to(fn.device) 151 | #fa = torch.from_numpy(mesh.fa).float().to(fn.device) 152 | f2f = torch.from_numpy(mesh.f2f).long().to(fn.device) 153 | no_neig = 1.0 * (f2f != -1) 154 | 155 | neig_fc = fc[f2f] 156 | neig_fa = fa[f2f] * no_neig 157 | fc0_tile = fc.reshape(-1, 1, 3) 158 | fc_dist = squared_norm(neig_fc - fc0_tile, dim=2) 159 | sigma_c = torch.sum(torch.sqrt(fc_dist + 1.0e-12)) / (fc_dist.shape[0] * fc_dist.shape[1]) 160 | #sigma_c = 1.0 161 | 162 | new_fn = fn 163 | for i in range(loop): 164 | neig_fn = new_fn[f2f] 165 | fn0_tile = new_fn.reshape(-1, 1, 3) 166 | fn_dist = squared_norm(neig_fn - fn0_tile, dim=2) 167 | sigma_s = 0.3 168 | wc = torch.exp(-1.0 * fc_dist / (2 * (sigma_c ** 2))) 169 | ws = torch.exp(-1.0 * fn_dist / (2 * (sigma_s ** 2))) 170 | 171 | W = torch.stack([wc*ws*neig_fa, wc*ws*neig_fa, wc*ws*neig_fa], dim=2) 172 | 173 | new_fn = torch.sum(W * neig_fn, dim=1) 174 | new_fn = new_fn / (norm(new_fn, dim=1, keepdim=True) + 1.0e-12) 175 | 176 | if ltype == "mae": 177 | bnf_diff = torch.sum((new_fn - fn) ** 2, dim=1) 178 | bnf_diff = torch.sqrt(bnf_diff + 1.0e-12) 179 | loss = torch.sum(bnf_diff) / len(bnf_diff) 180 | elif ltype == "l1mae": 181 | bnf_diff = torch.sum(torch.abs(new_fn - fn), dim=1) 182 | loss = 
torch.sum(bnf_diff) / len(bnf_diff) 183 | elif ltype == "rmse": 184 | bnf_diff = torch.sum((new_fn - fn) ** 2, dim=1) 185 | loss = torch.sum(bnf_diff) / len(bnf_diff) 186 | loss = torch.sqrt(loss + 1.0e-12) 187 | elif ltype == "l1rmse": 188 | bnf_diff = torch.sum(torch.abs(new_fn - fn), dim=1) 189 | loss = torch.sum(bnf_diff ** 2) / len(bnf_diff) 190 | loss = torch.sqrt(loss ** 2 + 1.0e-12) 191 | else: 192 | print("[ERROR]: ltype error") 193 | exit() 194 | 195 | return loss, new_fn 196 | 197 | def fn_bnf_detach_loss(pos: torch.Tensor, fn: torch.Tensor, mesh: Mesh, ltype="l1mae", loop=5) -> torch.Tensor: 198 | """ bilateral loss for face normal """ 199 | if type(pos) == np.ndarray: 200 | pos = torch.from_numpy(pos).to(fn.device) 201 | else: 202 | pos = pos.detach() 203 | fc = torch.sum(pos[mesh.faces], 1) / 3.0 204 | fa = torch.cross(pos[mesh.faces[:, 1]] - pos[mesh.faces[:, 0]], pos[mesh.faces[:, 2]] - pos[mesh.faces[:, 0]]) 205 | fa = 0.5 * torch.sqrt(torch.sum(fa**2, axis=1) + 1.0e-12) 206 | 207 | #fc = torch.from_numpy(mesh.fc).float().to(fn.device) 208 | #fa = torch.from_numpy(mesh.fa).float().to(fn.device) 209 | f2f = torch.from_numpy(mesh.f2f).long().to(fn.device) 210 | no_neig = 1.0 * (f2f != -1) 211 | 212 | neig_fc = fc[f2f] 213 | neig_fa = fa[f2f] * no_neig 214 | fc0_tile = fc.reshape(-1, 1, 3) 215 | fc_dist = squared_norm(neig_fc - fc0_tile, dim=2) 216 | sigma_c = torch.sum(torch.sqrt(fc_dist + 1.0e-12)) / (fc_dist.shape[0] * fc_dist.shape[1]) 217 | #sigma_c = 1.0 218 | 219 | new_fn = fn 220 | for i in range(loop): 221 | neig_fn = new_fn[f2f] 222 | fn0_tile = new_fn.reshape(-1, 1, 3) 223 | fn_dist = squared_norm(neig_fn - fn0_tile, dim=2) 224 | sigma_s = 0.3 225 | wc = torch.exp(-1.0 * fc_dist / (2 * (sigma_c ** 2))) 226 | ws = torch.exp(-1.0 * fn_dist / (2 * (sigma_s ** 2))) 227 | 228 | W = torch.stack([wc*ws*neig_fa, wc*ws*neig_fa, wc*ws*neig_fa], dim=2) 229 | 230 | new_fn = torch.sum(W * neig_fn, dim=1) 231 | new_fn = new_fn / (norm(new_fn, dim=1, 
keepdim=True) + 1.0e-12) 232 | new_fn = new_fn.detach() 233 | 234 | if ltype == "mae": 235 | bnf_diff = torch.sum((new_fn - fn) ** 2, dim=1) 236 | bnf_diff = torch.sqrt(bnf_diff + 1.0e-12) 237 | loss = torch.sum(bnf_diff) / len(bnf_diff) 238 | elif ltype == "l1mae": 239 | bnf_diff = torch.sum(torch.abs(new_fn - fn), dim=1) 240 | loss = torch.sum(bnf_diff) / len(bnf_diff) 241 | elif ltype == "rmse": 242 | bnf_diff = torch.sum((new_fn - fn) ** 2, dim=1) 243 | loss = torch.sum(bnf_diff) / len(bnf_diff) 244 | loss = torch.sqrt(loss + 1.0e-12) 245 | elif ltype == "l1rmse": 246 | bnf_diff = torch.sum(torch.abs(new_fn - fn), dim=1) 247 | loss = torch.sum(bnf_diff ** 2) / len(bnf_diff) 248 | loss = torch.sqrt(loss ** 2 + 1.0e-12) 249 | else: 250 | print("[ERROR]: ltype error") 251 | exit() 252 | 253 | return loss, new_fn -------------------------------------------------------------------------------- /util/mesh.py: -------------------------------------------------------------------------------- 1 | from turtle import pd 2 | import numpy as np 3 | import torch 4 | from functools import reduce 5 | from collections import Counter 6 | import scipy as sp 7 | import heapq 8 | import copy 9 | from sklearn.preprocessing import normalize 10 | 11 | OPTIM_VALENCE = 6 12 | VALENCE_WEIGHT = 1 13 | 14 | class Mesh: 15 | def __init__(self, path, manifold=True, build_mat=True, build_code=False): 16 | self.path = path 17 | self.vs, self.vc, self.faces = self.fill_from_file(path) 18 | self.compute_face_normals() 19 | self.compute_face_center() 20 | self.device = 'cpu' 21 | self.simp = False 22 | 23 | if manifold: 24 | self.build_gemm() #self.edges, self.ve 25 | self.compute_vert_normals() 26 | self.build_v2v() 27 | self.build_vf() 28 | self.vs_code = None 29 | if build_mat: 30 | self.build_mesh_lap() 31 | if build_code: 32 | self.vs_code = self.eigen_decomposition(self.lapmat, k=512) 33 | self.fc_code = self.eigen_decomposition(self.f_lapmat, k=100) 34 | 35 | def fill_from_file(self, path): 
36 | vs, faces, vc = [], [], [] 37 | f = open(path) 38 | for line in f: 39 | line = line.strip() 40 | splitted_line = line.split() 41 | if not splitted_line: 42 | continue 43 | elif splitted_line[0] == 'v': 44 | vs.append([float(v) for v in splitted_line[1:4]]) 45 | if len(splitted_line) == 7: # colored mesh 46 | vc.append([float(v) for v in splitted_line[4:7]]) 47 | elif splitted_line[0] == 'f': 48 | face_vertex_ids = [int(c.split('/')[0]) for c in splitted_line[1:]] 49 | assert len(face_vertex_ids) == 3 50 | face_vertex_ids = [(ind - 1) if (ind >= 0) else (len(vs) + ind) for ind in face_vertex_ids] 51 | faces.append(face_vertex_ids) 52 | f.close() 53 | vs = np.asarray(vs) 54 | vc = np.asarray(vc) 55 | faces = np.asarray(faces, dtype=int) 56 | 57 | assert np.logical_and(faces >= 0, faces < len(vs)).all() 58 | return vs, vc, faces 59 | 60 | def build_gemm(self): 61 | self.ve = [[] for _ in self.vs] 62 | self.vei = [[] for _ in self.vs] 63 | edge_nb = [] 64 | sides = [] 65 | edge2key = dict() 66 | edges = [] 67 | edges_count = 0 68 | nb_count = [] 69 | for face_id, face in enumerate(self.faces): 70 | faces_edges = [] 71 | for i in range(3): 72 | cur_edge = (face[i], face[(i + 1) % 3]) 73 | faces_edges.append(cur_edge) 74 | for idx, edge in enumerate(faces_edges): 75 | edge = tuple(sorted(list(edge))) 76 | faces_edges[idx] = edge 77 | if edge not in edge2key: 78 | edge2key[edge] = edges_count 79 | edges.append(list(edge)) 80 | edge_nb.append([-1, -1, -1, -1]) 81 | sides.append([-1, -1, -1, -1]) 82 | self.ve[edge[0]].append(edges_count) 83 | self.ve[edge[1]].append(edges_count) 84 | self.vei[edge[0]].append(0) 85 | self.vei[edge[1]].append(1) 86 | nb_count.append(0) 87 | edges_count += 1 88 | for idx, edge in enumerate(faces_edges): 89 | edge_key = edge2key[edge] 90 | edge_nb[edge_key][nb_count[edge_key]] = edge2key[faces_edges[(idx + 1) % 3]] 91 | edge_nb[edge_key][nb_count[edge_key] + 1] = edge2key[faces_edges[(idx + 2) % 3]] 92 | nb_count[edge_key] += 2 93 | for 
idx, edge in enumerate(faces_edges): 94 | edge_key = edge2key[edge] 95 | sides[edge_key][nb_count[edge_key] - 2] = nb_count[edge2key[faces_edges[(idx + 1) % 3]]] - 1 96 | sides[edge_key][nb_count[edge_key] - 1] = nb_count[edge2key[faces_edges[(idx + 2) % 3]]] - 2 97 | self.edges = np.array(edges, dtype=np.int64) 98 | self.gemm_edges = np.array(edge_nb, dtype=np.int64) 99 | self.sides = np.array(sides, dtype=np.int64) 100 | self.edges_count = edges_count 101 | # lots of DS for loss 102 | """ 103 | self.nvs, self.nvsi, self.nvsin, self.ve_in = [], [], [], [] 104 | for i, e in enumerate(self.ve): 105 | self.nvs.append(len(e)) 106 | self.nvsi += len(e) * [i] 107 | self.nvsin += list(range(len(e))) 108 | self.ve_in += e 109 | self.vei = reduce(lambda a, b: a + b, self.vei, []) 110 | self.vei = torch.from_numpy(np.array(self.vei).ravel()).to(self.device).long() 111 | self.nvsi = torch.from_numpy(np.array(self.nvsi).ravel()).to(self.device).long() 112 | self.nvsin = torch.from_numpy(np.array(self.nvsin).ravel()).to(self.device).long() 113 | self.ve_in = torch.from_numpy(np.array(self.ve_in).ravel()).to(self.device).long() 114 | 115 | self.max_nvs = max(self.nvs) 116 | self.nvs = torch.Tensor(self.nvs).to(self.device).float() 117 | self.edge2key = edge2key 118 | """ 119 | 120 | def compute_face_normals(self): 121 | face_normals = np.cross(self.vs[self.faces[:, 1]] - self.vs[self.faces[:, 0]], self.vs[self.faces[:, 2]] - self.vs[self.faces[:, 0]]) 122 | norm = np.linalg.norm(face_normals, axis=1, keepdims=True) + 1e-24 123 | face_areas = 0.5 * np.sqrt((face_normals**2).sum(axis=1)) 124 | face_normals /= norm 125 | self.fn, self.fa = face_normals, face_areas 126 | 127 | def compute_vert_normals(self): 128 | vert_normals = np.zeros((3, len(self.vs))) 129 | face_normals = self.fn 130 | faces = self.faces 131 | 132 | nv = len(self.vs) 133 | nf = len(faces) 134 | mat_rows = faces.reshape(-1) 135 | mat_cols = np.array([[i] * 3 for i in range(nf)]).reshape(-1) 136 | mat_vals = 
np.ones(len(mat_rows)) 137 | f2v_mat = sp.sparse.csr_matrix((mat_vals, (mat_rows, mat_cols)), shape=(nv, nf)) 138 | vert_normals = sp.sparse.csr_matrix.dot(f2v_mat, face_normals) 139 | vert_normals = normalize(vert_normals, norm='l2', axis=1) 140 | self.vn = vert_normals 141 | 142 | def compute_face_center(self): 143 | faces = self.faces 144 | vs = self.vs 145 | self.fc = np.sum(vs[faces], 1) / 3.0 146 | 147 | def compute_fn_sphere(self): 148 | fn = self.fn 149 | u = (np.arctan2(fn[:, 1], fn[:, 0]) + np.pi) / (2.0 * np.pi) 150 | v = np.arctan2(np.sqrt(fn[:, 0]**2 + fn[:, 1]**2), fn[:, 2]) / np.pi 151 | self.fn_uv = np.stack([u, v]).T 152 | 153 | def build_uni_lap(self): 154 | """compute uniform laplacian matrix""" 155 | vs = torch.tensor(self.vs.T, dtype=torch.float) 156 | edges = self.edges 157 | ve = self.ve 158 | 159 | sub_mesh_vv = [edges[v_e, :].reshape(-1) for v_e in ve] 160 | sub_mesh_vv = [set(vv.tolist()).difference(set([i])) for i, vv in enumerate(sub_mesh_vv)] 161 | 162 | num_verts = vs.size(1) 163 | mat_rows = [np.array([i] * len(vv), dtype=np.int64) for i, vv in enumerate(sub_mesh_vv)] 164 | mat_rows = np.concatenate(mat_rows) 165 | mat_cols = [np.array(list(vv), dtype=np.int64) for vv in sub_mesh_vv] 166 | mat_cols = np.concatenate(mat_cols) 167 | 168 | mat_rows = torch.from_numpy(mat_rows).long() 169 | mat_cols = torch.from_numpy(mat_cols).long() 170 | mat_vals = torch.ones_like(mat_rows).float() * -1.0 171 | neig_mat = torch.sparse.FloatTensor(torch.stack([mat_rows, mat_cols], dim=0), 172 | mat_vals, 173 | size=torch.Size([num_verts, num_verts])) 174 | vs = vs.T 175 | 176 | sum_count = torch.sparse.mm(neig_mat, torch.ones((num_verts, 1)).type_as(vs)) 177 | mat_rows_ident = np.array([i for i in range(num_verts)]) 178 | mat_cols_ident = np.array([i for i in range(num_verts)]) 179 | mat_ident = np.array([-s for s in sum_count[:, 0]]) 180 | mat_rows_ident = torch.from_numpy(mat_rows_ident).long() 181 | mat_cols_ident = 
torch.from_numpy(mat_cols_ident).long() 182 | mat_ident = torch.from_numpy(mat_ident).long() 183 | mat_rows = torch.cat([mat_rows, mat_rows_ident]) 184 | mat_cols = torch.cat([mat_cols, mat_cols_ident]) 185 | mat_vals = torch.cat([mat_vals, mat_ident]) 186 | 187 | self.lapmat = torch.sparse.FloatTensor(torch.stack([mat_rows, mat_cols], dim=0), 188 | mat_vals, 189 | size=torch.Size([num_verts, num_verts])) 190 | 191 | def build_vf(self): 192 | vf = [set() for _ in range(len(self.vs))] 193 | for i, f in enumerate(self.faces): 194 | vf[f[0]].add(i) 195 | vf[f[1]].add(i) 196 | vf[f[2]].add(i) 197 | self.vf = vf 198 | 199 | """ build vertex-to-face sparse matrix """ 200 | v2f_inds = [[] for _ in range(2)] 201 | v2f_vals = [] 202 | v2f_areas = [[] for _ in range(len(self.vs))] 203 | for i in range(len(vf)): 204 | v2f_inds[1] += list(vf[i]) 205 | v2f_inds[0] += [i] * len(vf[i]) 206 | v2f_vals += (self.fc[list(vf[i])] - self.vs[i].reshape(1, -1)).tolist() 207 | v2f_areas[i] = np.sum(self.fa[list(vf[i])]) 208 | self.v2f_list = [v2f_inds, v2f_vals, v2f_areas] 209 | 210 | v2f_inds = torch.tensor(v2f_inds).long() 211 | v2f_vals = torch.ones(v2f_inds.shape[1]).float() 212 | self.v2f_mat = torch.sparse.FloatTensor(v2f_inds, v2f_vals, size=torch.Size([len(self.vs), len(self.faces)])) 213 | self.f2v_mat = torch.sparse.FloatTensor(torch.stack([v2f_inds[1], v2f_inds[0]], dim=0), v2f_vals, size=torch.Size([len(self.faces), len(self.vs)])) 214 | 215 | """ build face-to-face (1ring) matrix """ 216 | f_edges = np.array([[i] * 3 for i in range(len(self.faces))]) 217 | f2f = [[] for _ in range(len(self.faces))] 218 | self.f_edges = [[] for _ in range(2)] 219 | for i, f in enumerate(self.faces): 220 | all_neig = list(vf[f[0]]) + list(vf[f[1]]) + list(vf[f[2]]) 221 | one_neig = np.array(list(Counter(all_neig).values())) == 2 222 | f2f_i = np.array(list(Counter(all_neig).keys()))[one_neig].tolist() 223 | self.f_edges[0] += len(f2f_i) * [i] 224 | self.f_edges[1] += f2f_i 225 | f2f[i] = f2f_i 
+ (3 - len(f2f_i)) * [-1] 226 | 227 | self.f2f = np.array(f2f) 228 | self.f_edges = np.array(self.f_edges) 229 | edge_index = torch.tensor(self.edges.T, dtype=torch.long) 230 | self.edge_index = torch.cat([edge_index, edge_index[[1,0],:]], dim=1) 231 | self.face_index = torch.from_numpy(self.f_edges) 232 | 233 | # TODO: change this to correspond to non-watertight mesh 234 | f2f_inds = torch.from_numpy(self.f_edges).long() 235 | f2f_vals = -1.0 * torch.ones(f2f_inds.shape[1]).float() 236 | f2f_mat = torch.sparse.FloatTensor(f2f_inds, f2f_vals, size=torch.Size([len(self.faces), len(self.faces)])) 237 | f_eyes_inds = torch.arange(len(self.faces)).long().repeat(2, 1) 238 | f_dims = torch.ones(len(self.faces)).float() * 3.0 # TODO: change here 239 | f_eyes = torch.sparse.FloatTensor(f_eyes_inds, f_dims, size=torch.Size([len(self.faces), len(self.faces)])) 240 | self.f_lapmat = f2f_mat + f_eyes 241 | 242 | """ build face-to-face (2ring) sparse matrix 243 | self.f2f = np.array(f2f) 244 | f2ring = self.f2f[self.f2f].reshape(-1, 9) 245 | self.f2ring = [set(f) for f in f2ring] 246 | self.f2ring = [list(self.f2ring[i] | set(f)) for i, f in enumerate(self.f2f)] 247 | 248 | self.f_edges = np.concatenate((self.f2f.reshape(1, -1), f_edges.reshape(1, -1)), 0) 249 | mat_inds = torch.from_numpy(self.f_edges).long() 250 | #mat_vals = torch.ones(mat_inds.shape[1]).float() 251 | mat_vals = torch.from_numpy(self.fa[self.f_edges[0]]).float() 252 | self.f2f_mat = torch.sparse.FloatTensor(mat_inds, mat_vals, size=torch.Size([len(self.faces), len(self.faces)])) 253 | """ 254 | def build_v2v(self): 255 | v2v = [[] for _ in range(len(self.vs))] 256 | for i, e in enumerate(self.edges): 257 | v2v[e[0]].append(e[1]) 258 | v2v[e[1]].append(e[0]) 259 | self.v2v = v2v 260 | 261 | """ compute adjacent matrix """ 262 | edges = self.edges 263 | v2v_inds = edges.T 264 | v2v_inds = torch.from_numpy(np.concatenate([v2v_inds, v2v_inds[[1, 0]]], axis=1)).long() 265 | v2v_vals = 
torch.ones(v2v_inds.shape[1]).float() 266 | self.Adj = torch.sparse.FloatTensor(v2v_inds, v2v_vals, size=torch.Size([len(self.vs), len(self.vs)])) 267 | self.v_dims = torch.sum(self.Adj.to_dense(), axis=1) 268 | D_inds = torch.stack([torch.arange(len(self.vs)), torch.arange(len(self.vs))], dim=0).long() 269 | D_vals = 1 / self.v_dims 270 | self.Diag = torch.sparse.FloatTensor(D_inds, D_vals, size=torch.Size([len(self.vs), len(self.vs)])) 271 | I = torch.eye(len(self.vs)) 272 | self.AdjI = (I + self.Adj).to_sparse() 273 | Lap = I - torch.sparse.mm(self.Diag, self.Adj.to_dense()) 274 | self.Lap = Lap.to_sparse() 275 | 276 | def build_adj_mat(self): 277 | edges = self.edges 278 | v2v_inds = edges.T 279 | v2v_inds = torch.from_numpy(np.concatenate([v2v_inds, v2v_inds[[1, 0]]], axis=1)).long() 280 | v2v_vals = torch.ones(v2v_inds.shape[1]).float() 281 | self.Adj = torch.sparse.FloatTensor(v2v_inds, v2v_vals, size=torch.Size([len(self.vs), len(self.vs)])) 282 | self.v_dims = torch.sum(self.Adj.to_dense(), axis=1) 283 | D_inds = torch.stack([torch.arange(len(self.vs)), torch.arange(len(self.vs))], dim=0).long() 284 | D_vals = 1 / (torch.pow(self.v_dims, 0.5) + 1.0e-12) 285 | self.D_minus_half = torch.sparse.FloatTensor(D_inds, D_vals, size=torch.Size([len(self.vs), len(self.vs)])) 286 | 287 | def build_mesh_lap(self): 288 | self.build_adj_mat() 289 | 290 | vs = self.vs 291 | edges = self.edges 292 | faces = self.faces 293 | 294 | e_dict = {} 295 | for e in edges: 296 | e0, e1 = min(e), max(e) 297 | e_dict[(e0, e1)] = [] 298 | 299 | for f in faces: 300 | s = vs[f[1]] - vs[f[0]] 301 | t = vs[f[2]] - vs[f[1]] 302 | u = vs[f[0]] - vs[f[2]] 303 | cos_0 = np.inner(s, -u) / (np.linalg.norm(s) * np.linalg.norm(u)) 304 | cos_1 = np.inner(t, -s) / (np.linalg.norm(t) * np.linalg.norm(s)) 305 | cos_2 = np.inner(u, -t) / (np.linalg.norm(u) * np.linalg.norm(t)) 306 | cot_0 = cos_0 / (np.sqrt(1 - cos_0 ** 2) + 1e-12) 307 | cot_1 = cos_1 / (np.sqrt(1 - cos_1 ** 2) + 1e-12) 308 | cot_2 = 
cos_2 / (np.sqrt(1 - cos_2 ** 2) + 1e-12) 309 | key_0 = (min(f[1], f[2]), max(f[1], f[2])) 310 | key_1 = (min(f[2], f[0]), max(f[2], f[0])) 311 | key_2 = (min(f[0], f[1]), max(f[0], f[1])) 312 | e_dict[key_0].append(cot_0) 313 | e_dict[key_1].append(cot_1) 314 | e_dict[key_2].append(cot_2) 315 | 316 | for e in e_dict: 317 | e_dict[e] = -0.5 * (e_dict[e][0] + e_dict[e][1]) 318 | 319 | C_ind = [[], []] 320 | C_val = [] 321 | ident = [0] * len(vs) 322 | for e in e_dict: 323 | C_ind[0].append(e[0]) 324 | C_ind[1].append(e[1]) 325 | C_ind[0].append(e[1]) 326 | C_ind[1].append(e[0]) 327 | C_val.append(e_dict[e]) 328 | C_val.append(e_dict[e]) 329 | ident[e[0]] += -1.0 * e_dict[e] 330 | ident[e[1]] += -1.0 * e_dict[e] 331 | Am_ind = torch.LongTensor(C_ind) 332 | Am_val = -1.0 * torch.FloatTensor(C_val) 333 | self.Am = torch.sparse.FloatTensor(Am_ind, Am_val, torch.Size([len(vs), len(vs)])) 334 | 335 | for i in range(len(vs)): 336 | C_ind[0].append(i) 337 | C_ind[1].append(i) 338 | 339 | C_val = C_val + ident 340 | C_ind = torch.LongTensor(C_ind) 341 | C_val = torch.FloatTensor(C_val) 342 | # cotangent matrix 343 | self.Lm = torch.sparse.FloatTensor(C_ind, C_val, torch.Size([len(vs), len(vs)])) 344 | self.Dm = torch.diag(torch.tensor(ident)).float().to_sparse() 345 | self.Lm_sym = torch.sparse.mm(torch.pow(self.Dm, -0.5), torch.sparse.mm(self.Lm, torch.pow(self.Dm, -0.5).to_dense())).to_sparse() 346 | #self.L = torch.sparse.mm(self.D_minus_half, torch.sparse.mm(C, self.D_minus_half.to_dense())) 347 | self.Am_I = (torch.eye(len(vs)) + self.Am).to_sparse() 348 | Dm_I_diag = torch.sum(self.Am_I.to_dense(), dim=1) 349 | self.Dm_I = torch.diag(Dm_I_diag).to_sparse() 350 | self.meshconvF = torch.sparse.mm(torch.pow(self.Dm_I, -0.5), torch.sparse.mm(self.Am_I, torch.pow(self.Dm_I, -0.5).to_dense())).to_sparse() 351 | 352 | def get_chebconv_coef(self, k=2): 353 | coef_list = [] 354 | eig_max = torch.lobpcg(self.Lm_sym, k=1)[0][0] 355 | Lm_hat = -1.0 * torch.eye(len(self.vs)) + 2.0 
* self.Lm_sym / eig_max.item() 356 | self.Lm_hat = Lm_hat.to_sparse() 357 | for i in range(k): 358 | if i == 0: 359 | coef = torch.eye(len(self.vs)).to_sparse() 360 | coef_list.append(coef) 361 | elif i == 1: 362 | coef_list.append(self.Lm_hat) 363 | else: 364 | coef = 2.0 * torch.sparse.mm(self.Lm_hat, coef_list[-1].to_dense()) - coef_list[-2] 365 | coef_list.append(coef.to_sparse()) 366 | return coef_list 367 | 368 | def eigen_decomposition(self, L, k=100): 369 | L = L.to_dense().numpy() 370 | csr = sp.sparse.csr_matrix(L) 371 | w, v = sp.sparse.linalg.eigs(csr, which="SR", k=k) 372 | #index = np.argsort(np.real(w))[::-1] # LR 373 | index = np.argsort(np.real(w)) # SR 374 | # vs_code can include either vertex code or face code 375 | vs_code = np.real(v[:, index]).astype(np.float) 376 | vs_code /= np.linalg.norm(vs_code, axis=0, keepdims=True) 377 | # import open3d as o3d 378 | # mesh = o3d.io.read_triangle_mesh(self.path) 379 | # for i in range(0, k, 1): 380 | # colors = np.zeros([len(self.vs), 3]) 381 | # r = vs_code[:, i] 382 | # r = (r+1)/2 383 | # """ 384 | # r_min, r_max = np.min(r), np.max(r) 385 | # r = (r - r_min) / (r_max-r_min+1.0e-16) 386 | # """ 387 | # colors[:, 1] = r 388 | # mesh.vertex_colors = o3d.utility.Vector3dVector(colors) 389 | # o3d.io.write_triangle_mesh("eigen_decom/eigen_{}.obj".format(str(i)), mesh) 390 | # import pdb;pdb.set_trace() 391 | 392 | return vs_code 393 | 394 | def simplification(self, target_v, valence_aware=True, midpoint=True): 395 | vs, vf, fn, fc, edges = self.vs, self.vf, self.fn, self.fc, self.edges 396 | 397 | """ 1. 
compute Q for each vertex """ 398 | Q_s = [[] for _ in range(len(vs))] 399 | E_s = [[] for _ in range(len(vs))] 400 | for i, v in enumerate(vs): 401 | f_s = np.array(list(vf[i]), dtype=np.int64) 402 | fc_s = fc[f_s] 403 | fn_s = fn[f_s] 404 | d_s = - 1.0 * np.sum(fn_s * fc_s, axis=1, keepdims=True) 405 | abcd_s = np.concatenate([fn_s, d_s], axis=1) 406 | Q_s[i] = np.matmul(abcd_s.T, abcd_s) 407 | 408 | v4 = np.concatenate([v, np.array([1])]) 409 | E_s[i] = np.matmul(v4, np.matmul(Q_s[i], v4.T)) 410 | 411 | """ 2. compute E for every possible pairs and create heapq """ 412 | E_heap = [] 413 | for i, e in enumerate(edges): 414 | v_0, v_1 = vs[e[0]], vs[e[1]] 415 | Q_0, Q_1 = Q_s[e[0]], Q_s[e[1]] 416 | Q_new = Q_0 + Q_1 417 | 418 | if midpoint: 419 | v_new = 0.5 * (v_0 + v_1) 420 | v4_new = np.concatenate([v_new, np.array([1])]) 421 | else: 422 | Q_lp = np.eye(4) 423 | Q_lp[:3] = Q_new[:3] 424 | try: 425 | Q_lp_inv = np.linalg.inv(Q_lp) 426 | v4_new = np.matmul(Q_lp_inv, np.array([[0,0,0,1]]).reshape(-1,1)).reshape(-1) 427 | except: 428 | v_new = 0.5 * (v_0 + v_1) 429 | v4_new = np.concatenate([v_new, np.array([1])]) 430 | 431 | valence_penalty = 1 432 | if valence_aware: 433 | merged_faces = vf[e[0]].intersection(vf[e[1]]) 434 | valence_new = len(vf[e[0]].union(vf[e[1]]).difference(merged_faces)) 435 | valence_penalty = self.valence_weight(valence_new) 436 | 437 | E_new = np.matmul(v4_new, np.matmul(Q_new, v4_new.T)) * valence_penalty 438 | heapq.heappush(E_heap, (E_new, (e[0], e[1]))) 439 | 440 | """ 3. 
collapse minimum-error vertex """ 441 | simp_mesh = copy.deepcopy(self) 442 | 443 | vi_mask = np.ones([len(simp_mesh.vs)]).astype(np.bool_) 444 | fi_mask = np.ones([len(simp_mesh.faces)]).astype(np.bool_) 445 | 446 | vert_map = [{i} for i in range(len(simp_mesh.vs))] 447 | 448 | while np.sum(vi_mask) > target_v: 449 | if len(E_heap) == 0: 450 | print("edge cannot be collapsed anymore!") 451 | break 452 | 453 | E_0, (vi_0, vi_1) = heapq.heappop(E_heap) 454 | 455 | if (vi_mask[vi_0] == False) or (vi_mask[vi_1] == False): 456 | continue 457 | 458 | """ edge collapse """ 459 | shared_vv = list(set(simp_mesh.v2v[vi_0]).intersection(set(simp_mesh.v2v[vi_1]))) 460 | merged_faces = simp_mesh.vf[vi_0].intersection(simp_mesh.vf[vi_1]) 461 | 462 | if len(shared_vv) != 2: 463 | """ non-manifold! """ 464 | #print("non-manifold can be occured!!" , len(shared_vv)) 465 | self.remove_tri_valance(simp_mesh, vi_0, vi_1, shared_vv, merged_faces, vi_mask, fi_mask, vert_map, Q_s, E_heap) 466 | continue 467 | 468 | elif len(merged_faces) != 2: 469 | """ boundary """ 470 | #print("boundary edge cannot be collapsed!") 471 | continue 472 | 473 | else: 474 | self.edge_collapse(simp_mesh, vi_0, vi_1, merged_faces, vi_mask, fi_mask, vert_map, Q_s, E_heap, valence_aware=valence_aware) 475 | #print(np.sum(vi_mask), np.sum(fi_mask)) 476 | 477 | self.rebuild_mesh(simp_mesh, vi_mask, fi_mask, vert_map) 478 | simp_mesh.simp = True 479 | simp_mesh.org = self 480 | self.build_hash(simp_mesh, vi_mask, vert_map) 481 | 482 | return simp_mesh 483 | 484 | def edge_based_simplification(self, target_v, valence_aware=True): 485 | vs, vf, fn, fc, edges = self.vs, self.vf, self.fn, self.fc, self.edges 486 | edge_len = vs[edges][:,0,:] - vs[edges][:,1,:] 487 | edge_len = np.linalg.norm(edge_len, axis=1) 488 | edge_len_heap = np.stack([edge_len, np.arange(len(edge_len))], axis=1).tolist() 489 | heapq.heapify(edge_len_heap) 490 | 491 | """ 2. 
compute E for every possible pairs and create heapq """ 492 | E_heap = [] 493 | for i, e in enumerate(edges): 494 | v_0, v_1 = vs[e[0]], vs[e[1]] 495 | heapq.heappush(E_heap, (edge_len[i], (e[0], e[1]))) 496 | 497 | """ 3. collapse minimum-error vertex """ 498 | simp_mesh = copy.deepcopy(self) 499 | 500 | vi_mask = np.ones([len(simp_mesh.vs)]).astype(np.bool_) 501 | fi_mask = np.ones([len(simp_mesh.faces)]).astype(np.bool_) 502 | 503 | vert_map = [{i} for i in range(len(simp_mesh.vs))] 504 | 505 | while np.sum(vi_mask) > target_v: 506 | if len(E_heap) == 0: 507 | print("[Warning]: edge cannot be collapsed anymore!") 508 | break 509 | 510 | E_0, (vi_0, vi_1) = heapq.heappop(E_heap) 511 | 512 | if (vi_mask[vi_0] == False) or (vi_mask[vi_1] == False): 513 | continue 514 | 515 | """ edge collapse """ 516 | shared_vv = list(set(simp_mesh.v2v[vi_0]).intersection(set(simp_mesh.v2v[vi_1]))) 517 | merged_faces = simp_mesh.vf[vi_0].intersection(simp_mesh.vf[vi_1]) 518 | 519 | if len(shared_vv) != 2: 520 | """ non-manifold! """ 521 | # print("non-manifold can be occured!!" 
, len(shared_vv)) 522 | continue 523 | 524 | elif len(merged_faces) != 2: 525 | """ boundary """ 526 | # print("boundary edge cannot be collapsed!") 527 | continue 528 | 529 | else: 530 | self.edge_based_collapse(simp_mesh, vi_0, vi_1, merged_faces, vi_mask, fi_mask, vert_map, E_heap, valence_aware=valence_aware) 531 | # print(np.sum(vi_mask), np.sum(fi_mask)) 532 | 533 | self.rebuild_mesh(simp_mesh, vi_mask, fi_mask, vert_map) 534 | simp_mesh.simp = True 535 | self.build_hash(simp_mesh, vi_mask, vert_map) 536 | 537 | return simp_mesh 538 | 539 | @staticmethod 540 | def remove_tri_valance(simp_mesh, vi_0, vi_1, shared_vv, merged_faces, vi_mask, fi_mask, vert_map, Q_s, E_heap): 541 | return 542 | 543 | def edge_collapse(self, simp_mesh, vi_0, vi_1, merged_faces, vi_mask, fi_mask, vert_map, Q_s, E_heap, valence_aware): 544 | shared_vv = list(set(simp_mesh.v2v[vi_0]).intersection(set(simp_mesh.v2v[vi_1]))) 545 | new_vi_0 = set(simp_mesh.v2v[vi_0]).union(set(simp_mesh.v2v[vi_1])).difference({vi_0, vi_1}) 546 | simp_mesh.vf[vi_0] = simp_mesh.vf[vi_0].union(simp_mesh.vf[vi_1]).difference(merged_faces) 547 | simp_mesh.vf[vi_1] = set() 548 | simp_mesh.vf[shared_vv[0]] = simp_mesh.vf[shared_vv[0]].difference(merged_faces) 549 | simp_mesh.vf[shared_vv[1]] = simp_mesh.vf[shared_vv[1]].difference(merged_faces) 550 | 551 | simp_mesh.v2v[vi_0] = list(new_vi_0) 552 | for v in simp_mesh.v2v[vi_1]: 553 | if v != vi_0: 554 | simp_mesh.v2v[v] = list(set(simp_mesh.v2v[v]).difference({vi_1}).union({vi_0})) 555 | simp_mesh.v2v[vi_1] = [] 556 | vi_mask[vi_1] = False 557 | 558 | vert_map[vi_0] = vert_map[vi_0].union(vert_map[vi_1]) 559 | vert_map[vi_0] = vert_map[vi_0].union({vi_1}) 560 | vert_map[vi_1] = set() 561 | 562 | fi_mask[np.array(list(merged_faces)).astype(np.int64)] = False 563 | 564 | simp_mesh.vs[vi_0] = 0.5 * (simp_mesh.vs[vi_0] + simp_mesh.vs[vi_1]) 565 | 566 | """ recompute E """ 567 | Q_0 = Q_s[vi_0] 568 | for vv_i in simp_mesh.v2v[vi_0]: 569 | v_mid = 0.5 * 
(simp_mesh.vs[vi_0] + simp_mesh.vs[vv_i]) 570 | Q_1 = Q_s[vv_i] 571 | Q_new = Q_0 + Q_1 572 | v4_mid = np.concatenate([v_mid, np.array([1])]) 573 | 574 | valence_penalty = 1 575 | if valence_aware: 576 | merged_faces = simp_mesh.vf[vi_0].intersection(simp_mesh.vf[vv_i]) 577 | valence_new = len(simp_mesh.vf[vi_0].union(simp_mesh.vf[vv_i]).difference(merged_faces)) 578 | valence_penalty = self.valence_weight(valence_new) 579 | 580 | E_new = np.matmul(v4_mid, np.matmul(Q_new, v4_mid.T)) * valence_penalty 581 | heapq.heappush(E_heap, (E_new, (vi_0, vv_i))) 582 | 583 | def edge_based_collapse(self, simp_mesh, vi_0, vi_1, merged_faces, vi_mask, fi_mask, vert_map, E_heap, valence_aware): 584 | shared_vv = list(set(simp_mesh.v2v[vi_0]).intersection(set(simp_mesh.v2v[vi_1]))) 585 | new_vi_0 = set(simp_mesh.v2v[vi_0]).union(set(simp_mesh.v2v[vi_1])).difference({vi_0, vi_1}) 586 | simp_mesh.vf[vi_0] = simp_mesh.vf[vi_0].union(simp_mesh.vf[vi_1]).difference(merged_faces) 587 | simp_mesh.vf[vi_1] = set() 588 | simp_mesh.vf[shared_vv[0]] = simp_mesh.vf[shared_vv[0]].difference(merged_faces) 589 | simp_mesh.vf[shared_vv[1]] = simp_mesh.vf[shared_vv[1]].difference(merged_faces) 590 | 591 | simp_mesh.v2v[vi_0] = list(new_vi_0) 592 | for v in simp_mesh.v2v[vi_1]: 593 | if v != vi_0: 594 | simp_mesh.v2v[v] = list(set(simp_mesh.v2v[v]).difference({vi_1}).union({vi_0})) 595 | simp_mesh.v2v[vi_1] = [] 596 | vi_mask[vi_1] = False 597 | 598 | vert_map[vi_0] = vert_map[vi_0].union(vert_map[vi_1]) 599 | vert_map[vi_0] = vert_map[vi_0].union({vi_1}) 600 | vert_map[vi_1] = set() 601 | 602 | fi_mask[np.array(list(merged_faces)).astype(np.int64)] = False 603 | 604 | simp_mesh.vs[vi_0] = 0.5 * (simp_mesh.vs[vi_0] + simp_mesh.vs[vi_1]) 605 | 606 | """ recompute E """ 607 | for vv_i in simp_mesh.v2v[vi_0]: 608 | v_mid = 0.5 * (simp_mesh.vs[vi_0] + simp_mesh.vs[vv_i]) 609 | edge_len = np.linalg.norm(simp_mesh.vs[vi_0] - simp_mesh.vs[vv_i]) 610 | valence_penalty = 1 611 | if valence_aware: 612 | 
merged_faces = simp_mesh.vf[vi_0].intersection(simp_mesh.vf[vv_i]) 613 | valence_new = len(simp_mesh.vf[vi_0].union(simp_mesh.vf[vv_i]).difference(merged_faces)) 614 | valence_penalty = self.valence_weight(valence_new) 615 | edge_len *= valence_penalty 616 | 617 | heapq.heappush(E_heap, (edge_len, (vi_0, vv_i))) 618 | 619 | @staticmethod 620 | def valence_weight(valence_new): 621 | valence_penalty = abs(valence_new - OPTIM_VALENCE) * VALENCE_WEIGHT + 1 622 | if valence_new == 3: 623 | valence_penalty *= 100000 624 | return valence_penalty 625 | 626 | @staticmethod 627 | def rebuild_mesh(simp_mesh, vi_mask, fi_mask, vert_map): 628 | face_map = dict(zip(np.arange(len(vi_mask)), np.cumsum(vi_mask)-1)) 629 | simp_mesh.vs = simp_mesh.vs[vi_mask] 630 | 631 | vert_dict = {} 632 | for i, vm in enumerate(vert_map): 633 | for j in vm: 634 | vert_dict[j] = i 635 | 636 | for i, f in enumerate(simp_mesh.faces): 637 | for j in range(3): 638 | if f[j] in vert_dict: 639 | simp_mesh.faces[i][j] = vert_dict[f[j]] 640 | 641 | simp_mesh.faces = simp_mesh.faces[fi_mask] 642 | for i, f in enumerate(simp_mesh.faces): 643 | for j in range(3): 644 | simp_mesh.faces[i][j] = face_map[f[j]] 645 | 646 | simp_mesh.compute_face_normals() 647 | simp_mesh.compute_face_center() 648 | simp_mesh.build_gemm() 649 | simp_mesh.compute_vert_normals() 650 | simp_mesh.build_v2v() 651 | simp_mesh.build_vf() 652 | 653 | @staticmethod 654 | def build_hash(simp_mesh, vi_mask, vert_map): 655 | pool_hash = {} 656 | unpool_hash = {} 657 | for simp_i, idx in enumerate(np.where(vi_mask)[0]): 658 | if len(vert_map[idx]) == 0: 659 | print("[ERROR] parent node cannot be found!") 660 | return 661 | for org_i in vert_map[idx]: 662 | pool_hash[org_i] = simp_i 663 | unpool_hash[simp_i] = list(vert_map[idx]) 664 | 665 | """ check """ 666 | vl_sum = 0 667 | for vl in unpool_hash.values(): 668 | vl_sum += len(vl) 669 | 670 | if (len(set(pool_hash.keys())) != len(vi_mask)) or (vl_sum != len(vi_mask)): 671 | print("[ERROR] 
Original vetices cannot be covered!") 672 | return 673 | 674 | pool_hash = sorted(pool_hash.items(), key=lambda x:x[0]) 675 | simp_mesh.pool_hash = pool_hash 676 | simp_mesh.unpool_hash = unpool_hash 677 | 678 | @staticmethod 679 | def mesh_merge(lap, org_mesh, new_pos, preserve, w=1, w_b=0): 680 | # TODO: Delete constraint of boundary 681 | org_pos = torch.from_numpy(org_mesh.vs).float() 682 | org_wo_boundary = torch.sparse.mm(org_mesh.AdjI.float(), 1-preserve.float().reshape(-1, 1)) == 0 683 | org_wo_boundary = org_wo_boundary.reshape(-1) 684 | boundary = torch.logical_xor(preserve, org_wo_boundary) 685 | A_org_wo_boundary = torch.eye(len(org_pos))[org_wo_boundary] * w 686 | A_boundary = torch.eye(len(org_pos))[boundary] * w_b 687 | A = torch.cat([lap.to_dense(), A_org_wo_boundary], dim=0) 688 | A = torch.cat([A, A_boundary], dim=0).to_sparse() 689 | b_mix = torch.sparse.mm(lap, new_pos) 690 | b_mix[org_wo_boundary] = torch.sparse.mm(lap, org_pos)[org_wo_boundary] 691 | b_org_wo_boundary = org_pos[org_wo_boundary] * w 692 | b_boundary = org_pos[boundary] * w_b 693 | b = torch.cat([b_mix, b_org_wo_boundary], dim=0) 694 | b = torch.cat([b, b_boundary], dim=0) 695 | AtA = torch.sparse.mm(A.t(), A.to_dense()) 696 | Atb = torch.sparse.mm(A.t(), b) 697 | ref_pos = torch.linalg.solve(AtA, Atb) 698 | return ref_pos 699 | 700 | @staticmethod 701 | def copy_attribute(src_mesh, dst_mesh): 702 | dst_mesh.vf, dst_mesh.edges, dst_mesh.ve = src_mesh.vf, src_mesh.edges, src_mesh.ve 703 | dst_mesh.v2v = src_mesh.v2v 704 | 705 | def save(self, filename, color=False): 706 | assert len(self.vs) > 0 707 | vertices = np.array(self.vs, dtype=np.float32).flatten() 708 | indices = np.array(self.faces, dtype=np.uint32).flatten() 709 | v_colors = np.array(self.vc, dtype=np.float32).flatten() 710 | 711 | with open(filename, 'w') as fp: 712 | # Write positions 713 | if len(v_colors) == 0 or color == False: 714 | for i in range(0, vertices.size, 3): 715 | x = vertices[i + 0] 716 | y = 
vertices[i + 1] 717 | z = vertices[i + 2] 718 | fp.write('v {0:.8f} {1:.8f} {2:.8f}\n'.format(x, y, z)) 719 | 720 | else: 721 | for i in range(0, vertices.size, 3): 722 | x = vertices[i + 0] 723 | y = vertices[i + 1] 724 | z = vertices[i + 2] 725 | c1 = v_colors[i + 0] 726 | c2 = v_colors[i + 1] 727 | c3 = v_colors[i + 2] 728 | fp.write('v {0:.8f} {1:.8f} {2:.8f} {3:.8f} {4:.8f} {5:.8f}\n'.format(x, y, z, c1, c2, c3)) 729 | 730 | # Write indices 731 | for i in range(0, len(indices), 3): 732 | i0 = indices[i + 0] + 1 733 | i1 = indices[i + 1] + 1 734 | i2 = indices[i + 2] + 1 735 | fp.write('f {0} {1} {2}\n'.format(i0, i1, i2)) 736 | 737 | def save_as_ply(self, filename, fn): 738 | assert len(self.vs) > 0 739 | vertices = np.array(self.vs, dtype=np.float32).flatten() 740 | indices = np.array(self.faces, dtype=np.uint32).flatten() 741 | fnormals = np.array(fn, dtype=np.float32).flatten() 742 | 743 | with open(filename, 'w') as fp: 744 | # Write Header 745 | fp.write("ply\nformat ascii 1.0\nelement vertex {}\n".format(len(self.vs))) 746 | fp.write("property float x\nproperty float y\nproperty float z\n") 747 | fp.write("element face {}\n".format(len(self.faces))) 748 | fp.write("property list uchar int vertex_indices\n") 749 | fp.write("property uchar red\nproperty uchar green\nproperty uchar blue\nproperty uchar alpha\n") 750 | fp.write("end_header\n") 751 | for i in range(0, vertices.size, 3): 752 | x = vertices[i + 0] 753 | y = vertices[i + 1] 754 | z = vertices[i + 2] 755 | fp.write("{0:.6f} {1:.6f} {2:.6f}\n".format(x, y, z)) 756 | 757 | for i in range(0, len(indices), 3): 758 | i0 = indices[i + 0] 759 | i1 = indices[i + 1] 760 | i2 = indices[i + 2] 761 | c0 = fnormals[i + 0] 762 | c1 = fnormals[i + 1] 763 | c2 = fnormals[i + 2] 764 | c0 = np.clip(int(255 * c0), 0, 255) 765 | c1 = np.clip(int(255 * c1), 0, 255) 766 | c2 = np.clip(int(255 * c2), 0, 255) 767 | c3 = 255 768 | fp.write("3 {0} {1} {2} {3} {4} {5} {6}\n".format(i0, i1, i2, c0, c1, c2, c3)) 769 | 770 | 
# --- util/mesh.py (continued) ----------------------------------------------------
# In the repository this is the last method of class Mesh; this dump does not
# preserve indentation, so it is shown here at top level.
def display_face_normals(self, fn):
    """Show per-face normals in an interactive Open3D window.

    The mesh's face centres (``self.fc``, refreshed by
    ``self.compute_face_center()``) become the points of a point cloud and
    ``fn`` supplies one normal per point, so ``fn`` presumably holds one
    3-vector per face -- confirm against callers.

    Args:
        fn: normals, anything ``np.asarray`` can convert.
    """
    # Local import: open3d is only needed when this viewer is actually used.
    import open3d as o3d
    self.compute_face_center()
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(np.asarray(self.fc))
    pcd.normals = o3d.utility.Vector3dVector(np.asarray(fn))
    o3d.visualization.draw_geometries([pcd])


# --- /util/meshnet.py --------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
import os
import copy
from torch_geometric.nn import GCNConv, ChebConv, Sequential


def _row_sums(sparse_mat):
    """Return the row sums of a 2-D sparse COO matrix as a dense (rows, 1) tensor."""
    return torch.sparse.sum(sparse_mat, dim=1).to_dense().reshape(-1, 1)


def _graph_conv(in_channels, out_channels, conv, K):
    """Return a PyG ``Sequential`` entry for one graph convolution.

    ``conv`` selects ``ChebConv`` (``"chebconv"``, filter size ``K``) or
    ``GCNConv`` (``"gcnconv"``, ``K`` ignored); anything else raises ValueError.
    """
    if conv == "chebconv":
        layer = ChebConv(in_channels, out_channels, K=K)
    elif conv == "gcnconv":
        layer = GCNConv(in_channels, out_channels)
    else:
        raise ValueError("unknown conv type: {!r}".format(conv))
    return (layer, "x, edge_index -> x")


def _norm_act(channels):
    """Return the BatchNorm1d + LeakyReLU entries that follow each convolution."""
    return [
        (nn.BatchNorm1d(channels), "x -> x"),
        (nn.LeakyReLU(), "x -> x"),
    ]


def _conv_stack(channels, depth, conv, K, drop_rate):
    """Return ``depth`` x (conv -> BatchNorm1d -> LeakyReLU) at constant width, then Dropout."""
    layers = []
    for _ in range(depth):
        layers.append(_graph_conv(channels, channels, conv, K))
        layers.extend(_norm_act(channels))
    layers.append((nn.Dropout(drop_rate), "x -> x"))
    return layers


class MeshPool(nn.Module):
    """Average vertex features onto the vertices of a coarser mesh.

    ``pool_hash`` is a sparse matrix with one row per output (coarse) vertex;
    the output row is ``pool_hash @ input`` divided by that row's sum, i.e. the
    mean of the contributing input rows when all entries are 1 (as built by
    ``MGCN.pool_hash_to_mask``).
    """

    def __init__(self, pool_hash):
        super(MeshPool, self).__init__()
        self.register_buffer("pool_hash", pool_hash)

    def forward(self, input):
        # Sparse row reduction gives the same sums as summing the dense matrix
        # without materialising an n_coarse x n_fine tensor on every call.
        v_sum = _row_sums(self.pool_hash)
        return torch.sparse.mm(self.pool_hash, input) / v_sum


class MeshUnpool(nn.Module):
    """Copy coarse-vertex features back onto fine vertices: ``unpool_hash @ input``."""

    def __init__(self, unpool_hash):
        super(MeshUnpool, self).__init__()
        self.register_buffer("unpool_hash", unpool_hash)

    def forward(self, input):
        return torch.sparse.mm(self.unpool_hash, input)

#############################################################

class DownConv(nn.Module):
    """Encoder stage of MGCN: convolve, pool onto the coarser mesh, convolve again.

    ``model1`` applies two graph convolutions over ``edge_index1`` and pools
    with ``pool_hash`` right after the second one; ``model2`` applies three
    more convolutions over ``edge_index2`` and ends with dropout.

    Args:
        in_channels, out_channels: feature widths of input and output.
        edge_index1: edge index used by ``model1`` (the input mesh graph).
        edge_index2: edge index used by ``model2`` (the pooled mesh graph).
        pool_hash: sparse pooling matrix wrapped in a ``MeshPool``.
        K: Chebyshev filter size (unused for ``"gcnconv"``).
        drop_rate: dropout probability at the end of ``model2``.
        conv: ``"chebconv"`` (default, what MGCN uses) or ``"gcnconv"``.
    """

    def __init__(self, in_channels, out_channels, edge_index1, edge_index2, pool_hash, K=3, drop_rate=0.0,
                 conv="chebconv"):
        super(DownConv, self).__init__()
        # Plain attributes, not buffers: .to(device) does not move them, so the
        # caller must pass them on the target device (MGCN does).
        self.edge_index1 = edge_index1
        self.edge_index2 = edge_index2
        # Layers are created in the original order so any position-based
        # module names (and saved state_dicts) stay unchanged.
        self.model1 = Sequential("x, edge_index", [
            _graph_conv(in_channels, out_channels, conv, K),
            *_norm_act(out_channels),
            _graph_conv(out_channels, out_channels, conv, K),
            (MeshPool(pool_hash), "x -> x"),
            *_norm_act(out_channels),
        ])
        self.model2 = Sequential("x, edge_index", _conv_stack(out_channels, 3, conv, K, drop_rate))

    def forward(self, input):
        out = self.model1(input, self.edge_index1)
        out = self.model2(out, self.edge_index2)
        return out


class UpConv(nn.Module):
    """Decoder stage of MGCN: convolve, unpool onto the finer mesh, convolve four more times.

    Args:
        in_channels, out_channels: feature widths of input and output.
        edge_index1: edge index used by ``model1`` (the coarse input graph).
        edge_index2: edge index used by ``model2`` (the unpooled, finer graph).
        unpool_hash: sparse unpooling matrix wrapped in a ``MeshUnpool``.
        K: Chebyshev filter size (unused for ``"gcnconv"``).
        drop_rate: dropout probability at the end of ``model2``.
        conv: ``"chebconv"`` (default, what MGCN uses) or ``"gcnconv"``.
    """

    def __init__(self, in_channels, out_channels, edge_index1, edge_index2, unpool_hash, K=3, drop_rate=0.0,
                 conv="chebconv"):
        super(UpConv, self).__init__()
        # Plain attributes, not buffers: see DownConv.__init__.
        self.edge_index1 = edge_index1
        self.edge_index2 = edge_index2
        self.model1 = Sequential("x, edge_index", [
            _graph_conv(in_channels, out_channels, conv, K),
            (MeshUnpool(unpool_hash), "x -> x"),
            *_norm_act(out_channels),
        ])
        self.model2 = Sequential("x, edge_index", _conv_stack(out_channels, 4, conv, K, drop_rate))

    def forward(self, input):
        out = self.model1(input, self.edge_index1)
        out = self.model2(out, self.edge_index2)
        return out


class MGCN(nn.Module):
    """Multi-scale GCN: a 3-level mesh U-Net predicting vertex offsets at every level.

    ``__init__`` simplifies ``smo_mesh`` three times (targets of 60%, 36% and
    21.6% of its vertex count), builds sparse pool/unpool matrices and
    per-level vertex/face masks from ``v_mask``, and writes diagnostic meshes
    to ``<dirname(smo_mesh.path)>/pooled/``.

    Args:
        device: device for the graph structure and position tensors.
        smo_mesh: mesh that is simplified to build the hierarchy; the predicted
            offsets are added to its (per-level) positions.
        ini_mesh: mesh providing the level-0 graph and the positions that are
            pooled into ``self.poss``.
        v_mask: per-vertex mask for ``ini_mesh``; a face is marked when all its
            vertices have mask value 1 (assuming a 0/1 mask -- confirm).
        K: Chebyshev filter size.
        skip: when True, concatenate encoder features into the decoder.
    """

    def __init__(self, device, smo_mesh, ini_mesh, v_mask, K=3, skip=False):
        super(MGCN, self).__init__()
        self.device = device
        self.skip = skip

        pool_levels = 3
        nv = len(smo_mesh.vs)
        nvs = [int(nv * (0.6 ** i)) for i in range(1, pool_levels + 1)]
        self.nvs = nvs
        meshes = [smo_mesh]
        p_hashes = []
        up_hashes = []
        edge_inds = [ini_mesh.edge_index.to(device)]
        v_mask_col = v_mask.reshape(-1, 1).float()
        v_masks_list = [v_mask_col]
        v_masks = v_mask_col
        # Face flag: the (1 - mask) values of the face's vertices sum to zero.
        f_masks = (torch.sparse.mm(ini_mesh.f2v_mat, 1.0 - v_mask_col) == 0).reshape(-1)
        f_masks_list = [f_masks]
        pooled_dir = "{}/pooled".format(os.path.dirname(smo_mesh.path))
        os.makedirs(pooled_dir, exist_ok=True)
        for i in range(pool_levels):
            s_mesh = meshes[i].simplification(target_v=nvs[i])
            meshes.append(s_mesh)
            s_pmask = self.pool_hash_to_mask(s_mesh.pool_hash)
            s_umask = self.unpool_hash_to_mask(s_mesh.pool_hash)
            p_hashes.append(s_pmask.to(device))
            up_hashes.append(s_umask.to(device))
            edge_inds.append(s_mesh.edge_index.to(device))
            # A coarse vertex keeps mask 1 only if no fine vertex pooled into it has mask 0.
            vm_i_inv = torch.logical_not(v_masks_list[-1]).float()
            vm_j = torch.logical_not(torch.sparse.mm(s_pmask, vm_i_inv)).float()
            v_masks_list.append(vm_j)
            v_masks = torch.cat([v_masks, vm_j], dim=0)
            fm = (torch.sparse.mm(s_mesh.f2v_mat, 1.0 - vm_j.reshape(-1, 1)) == 0).reshape(-1)
            f_masks = torch.cat([f_masks, fm], dim=0)
            f_masks_list.append(fm)
            s_mesh.vc = np.ones([len(s_mesh.vs), 3]) * vm_j.reshape(-1, 1).detach().numpy().copy()
            s_mesh.save("{}/{}_vs.obj".format(pooled_dir, len(s_mesh.vs)), color=True)

        self.meshes = meshes
        self.p_hashes = p_hashes
        self.up_hashes = up_hashes
        self.edge_inds = edge_inds
        self.v_masks = v_masks
        self.v_masks_list = v_masks_list
        self.f_masks = f_masks
        self.f_masks_list = f_masks_list

        self.encoder1 = DownConv(4, 32, edge_inds[0], edge_inds[1], p_hashes[0], K=K, drop_rate=0.0)
        self.encoder2 = DownConv(32, 128, edge_inds[1], edge_inds[2], p_hashes[1], K=K, drop_rate=0.2)
        self.encoder3 = DownConv(128, 256, edge_inds[2], edge_inds[3], p_hashes[2], K=K, drop_rate=0.2)

        self.decoder3 = UpConv(256, 128, edge_inds[3], edge_inds[2], up_hashes[2], K=K, drop_rate=0.2)
        self.decoder2 = UpConv(128, 32, edge_inds[2], edge_inds[1], up_hashes[1], K=K, drop_rate=0.2)
        self.decoder1 = nn.Sequential(
            UpConv(32, 16, edge_inds[1], edge_inds[0], up_hashes[0], K=K, drop_rate=0.0),
            nn.Linear(16, 3),
        )

        def offset_head(in_channels):
            # ChebConv -> BatchNorm -> LeakyReLU -> Linear to a 3-D offset per vertex.
            return Sequential("x, edge_index", [
                (ChebConv(in_channels, 32, K=K), "x, edge_index -> x"),
                (nn.BatchNorm1d(32), "x -> x"),
                (nn.LeakyReLU(), "x -> x"),
                (nn.Linear(32, 3), "x -> x"),
            ])

        self.mcnn3 = offset_head(256)
        self.mcnn2 = offset_head(128)
        self.mcnn1 = offset_head(32)

        self.skip2 = nn.Linear(256, 128)
        self.skip1 = nn.Linear(64, 32)

        org_pos = torch.from_numpy(ini_mesh.vs).float().to(device)
        org_smpos = torch.from_numpy(smo_mesh.vs).float().to(device)
        self.poss = org_pos
        self.poss_list = [org_pos]
        self.smposs = org_smpos
        self.smposs_list = [org_smpos]
        for l in range(pool_levels):
            simp_mesh = copy.deepcopy(meshes[l + 1])
            pos = MeshPool(self.p_hashes[l])(org_pos)
            simp_mesh.vs = pos.detach().to("cpu").numpy().copy()
            # Diagnostic face colours: flagged faces (0.332, 0.664, 1.0), others (1, 0, 1).
            face_flag = self.f_masks_list[l + 1].detach().numpy().copy()
            color = np.ones([len(simp_mesh.faces), 3])
            color[:, 1] = 0
            color[face_flag] = [0.332, 0.664, 1.0]
            simp_mesh.save_as_ply("{}/pooled/ini_{}_vs.ply".format(os.path.dirname(simp_mesh.path), len(simp_mesh.vs)), color)
            sm_pos = torch.from_numpy(meshes[l + 1].vs).float().to(device)
            self.poss = torch.cat([self.poss, pos], dim=0)
            self.poss_list.append(pos)
            self.smposs = torch.cat([self.smposs, sm_pos], dim=0)
            self.smposs_list.append(sm_pos)
            org_pos = pos

    def forward(self, data, dm=None):
        """Predict vertex positions at all four levels of the hierarchy.

        Args:
            data: object with a per-vertex feature tensor ``z1``.
            dm: optional (N, 1) per-vertex weight as a numpy array or tensor;
                defaults to ones. It scales the first three feature columns and
                is appended as an extra input channel.

        Returns:
            ``(pos0, pos1, pos2, pos3)``: ``self.smposs_list[l]`` plus the offset
            predicted at level ``l`` (level 0 is the full-resolution mesh).
        """
        z1 = data.z1.to(self.device)

        # Centre the features and scale by the largest bounding-box extent.
        z_min = torch.min(z1, dim=0, keepdim=True)[0]
        z_max = torch.max(z1, dim=0, keepdim=True)[0]
        z_sc = torch.max(z_max - z_min)
        zc = (z_min + z_max) * 0.5
        z1 = (z1 - zc) / z_sc

        # Accept numpy arrays and tensors alike, and match z1's dtype so a
        # float64 array does not promote the network input to float64.
        if isinstance(dm, np.ndarray):
            dm = torch.from_numpy(dm)
        elif not isinstance(dm, torch.Tensor):
            dm = torch.ones([z1.shape[0], 1])
        dm = dm.to(device=self.device, dtype=z1.dtype)
        z1[:, 0:3] = dm * z1[:, 0:3]
        z1 = torch.cat([z1, dm], dim=1)

        res1_enc = self.encoder1(z1)
        res2_enc = self.encoder2(res1_enc)
        res3_bot = self.encoder3(res2_enc)
        out3 = self.mcnn3(res3_bot, self.edge_inds[3])

        res2_dec = self.decoder3(res3_bot)
        if self.skip:
            res2_dec = self.skip2(torch.cat([res2_dec, res2_enc], dim=1))
        out2 = self.mcnn2(res2_dec, self.edge_inds[2])

        res1_dec = self.decoder2(res2_dec)
        if self.skip:
            res1_dec = self.skip1(torch.cat([res1_dec, res1_enc], dim=1))
        out1 = self.mcnn1(res1_dec, self.edge_inds[1])

        out0 = self.decoder1(res1_dec)

        pos0 = self.smposs_list[0] + out0
        pos1 = self.smposs_list[1] + out1
        pos2 = self.smposs_list[2] + out2
        pos3 = self.smposs_list[3] + out3
        return (pos0, pos1, pos2, pos3)

    def pool(self, dx, pool_hash):
        """Row-normalised pooling of ``dx`` with ``pool_hash`` (same computation as MeshPool)."""
        pool_hash = pool_hash.to(dx.device)
        return torch.sparse.mm(pool_hash, dx) / _row_sums(pool_hash)

    def unpool(self, dx, unpool_hash):
        """Return ``unpool_hash @ dx`` on ``dx``'s device."""
        unpool_hash = unpool_hash.to(dx.device)
        return torch.sparse.mm(unpool_hash, dx)

    def pool_hash_to_mask(self, pool_hash):
        """Build the sparse pooling matrix from an (M, 2) table of index pairs.

        Entry ``(pool_hash[k, 1], pool_hash[k, 0])`` is 1 for every row ``k``,
        so rows come from column 1 and columns from column 0.
        """
        pairs = torch.as_tensor(pool_hash).long()
        mask_ind = torch.stack([pairs[:, 1], pairs[:, 0]], dim=0)
        mask_val = torch.ones(mask_ind.shape[1]).float()
        return torch.sparse_coo_tensor(mask_ind, mask_val)

    def unpool_hash_to_mask(self, pool_hash):
        """Build the sparse unpooling matrix: the transpose pattern of ``pool_hash_to_mask``."""
        mask_ind = torch.as_tensor(pool_hash).long().T
        mask_val = torch.ones(mask_ind.shape[1]).float()
        return torch.sparse_coo_tensor(mask_ind, mask_val)


# --- /util/models.py -------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
from functools import reduce
from collections import Counter
from typing import Union
import scipy as sp
from sklearn.preprocessing import normalize
from util.mesh import Mesh
# --- /util/models.py (continued) -------------------------------------------------
def build_div(n_mesh, vn):
    """Per-vertex divergence of the vertex field ``vn`` (one row of ``vn`` per vertex).

    For each face ``t`` around vertex ``i`` (``n_mesh.vf``) the face is rotated
    so that ``i`` comes first; the edge opposite ``i`` is mapped to
    ``(n (n . e) + n x e) / 2`` with ``n`` the face normal. The value at ``i`` is
    the sum of these vectors dotted with ``vn[i]``.

    Returns:
        (len(vn), 3) array whose three columns all hold the per-vertex value.
    """
    vs = n_mesh.vs
    faces = n_mesh.faces
    fn = n_mesh.fn
    grad_b = [[] for _ in range(len(vs))]
    for i, incident in enumerate(n_mesh.vf):
        for t in incident:
            f = faces[t]
            f_n = fn[t]
            # Rotate (never reflect) the face so vertex i is first: this keeps
            # the winding, and hence the direction of the opposite edge, the
            # same in all three cases.
            if f[1] == i:
                f = [f[1], f[2], f[0]]
            elif f[2] == i:
                f = [f[2], f[0], f[1]]
            x_kj = vs[f[2]] - vs[f[1]]
            x_kj = f_n * np.dot(x_kj, f_n) + np.cross(f_n, x_kj)
            grad_b[i].append(x_kj / 2)

    div = np.empty(len(vn))
    for i in range(len(vn)):
        # reshape(-1, 3) keeps a vertex without incident faces at a zero vector.
        g = np.asarray(grad_b[i], dtype=float).reshape(-1, 3)
        div[i] = np.dot(g.sum(axis=0), vn[i])
    return np.tile(div, 3).reshape(3, -1).T

def jacobi(n_mesh, div, iter=10):
    """Iteratively solve the stacked least-squares system for new vertex positions.

    The system is ``C = [mesh_lap; I]``, ``B = [div; vs]``, solved through the
    normal equations ``A x = C^T B`` with ``A = C^T C``, starting from the
    current positions. Despite the name, each step is steepest descent with an
    exact line search per coordinate column, not a Jacobi sweep.
    """
    C = n_mesh.mesh_lap.to_dense()
    C = torch.cat([C, torch.eye(len(n_mesh.vs))], dim=0)
    B = torch.cat([torch.from_numpy(div), torch.from_numpy(n_mesh.vs)], dim=0)
    B = torch.matmul(C.T.float(), B.float())
    A = torch.matmul(C.T, C)
    x = torch.from_numpy(n_mesh.vs).float()
    for _ in range(iter):
        r = B - torch.matmul(A, x)
        # Step size r^T r / r^T A r, one per column (x, y, z).
        alpha = torch.diagonal(r.T @ r) / (torch.diagonal(r.T @ A @ r) + 1e-12)
        x += alpha * r
    return x

def poisson_mesh_edit(n_mesh, div):
    """Solve the same stacked system as ``jacobi`` exactly via its normal equations."""
    C = n_mesh.mesh_lap.to_dense()
    C = torch.cat([C, torch.eye(len(n_mesh.vs))], dim=0)
    B = torch.cat([torch.from_numpy(div), torch.from_numpy(n_mesh.vs)], dim=0)
    A = torch.matmul(C.T, C)
    CtB = torch.matmul(C.T.float(), B.float())
    # A linear solve is cheaper and numerically better than forming inverse(A).
    return torch.linalg.solve(A.float(), CtB)

def cg(n_mesh, div, iter=50, a=100):
    """Solve the cotangent-Laplacian system with SciPy's conjugate gradient.

    ``n_mesh.cot_mat`` is stacked on ``a * I`` (weight ``a`` pulling towards the
    current positions) and the normal equations are solved per column
    (x, y, z), starting from the current positions, for at most ``iter``
    iterations each.

    Returns:
        (V, 3) float64 numpy array of new positions.
    """
    # Explicit submodule import: ``import scipy`` alone does not guarantee that
    # ``scipy.sparse.linalg`` is loaded on every SciPy version.
    import scipy.sparse.linalg

    C = n_mesh.cot_mat.to_dense()
    C = torch.cat([C, torch.eye(len(n_mesh.vs)) * a], dim=0)
    Ct = C.T.to_sparse()
    B = torch.cat([torch.from_numpy(div), torch.from_numpy(n_mesh.vs) * a], dim=0)
    B = torch.sparse.mm(Ct.float(), B.float()).detach().numpy()
    A = torch.sparse.mm(Ct, C).detach().numpy().copy()
    x_0 = torch.from_numpy(n_mesh.vs).float().numpy()
    columns = [scipy.sparse.linalg.cg(A, B[:, i], x0=x_0[:, i], maxiter=iter)[0] for i in range(3)]
    return np.array([col.tolist() for col in columns]).T

def compute_fn(vs: torch.Tensor, faces: np.ndarray) -> torch.Tensor:
    """Unit face normals (right-hand rule on each face's vertex order).

    Zero-area faces produce NaN rows (0 / 0), as before.
    """
    e1 = vs[faces[:, 1]] - vs[faces[:, 0]]
    e2 = vs[faces[:, 2]] - vs[faces[:, 0]]
    # dim=1 must be explicit: without it torch.cross uses the first dimension
    # of size 3, which is dim 0 when the mesh has exactly three faces.
    face_normals = torch.cross(e1, e2, dim=1)
    norm = torch.sqrt(torch.sum(face_normals ** 2, dim=1, keepdim=True))
    return face_normals / norm

def compute_vn(vs: torch.Tensor, fn: torch.Tensor, faces: np.ndarray) -> torch.Tensor:
    """Unit vertex normals: normalised sum of the normals of incident faces."""
    faces = torch.from_numpy(faces).long().to(vs.device)
    nv = len(vs)
    nf = len(faces)
    # Sparse (nv, nf) incidence matrix with a 1 where face f uses vertex v.
    mat_rows = faces.reshape(-1)
    mat_cols = torch.arange(nf, device=vs.device).repeat_interleave(3)
    mat_vals = torch.ones(len(mat_rows), dtype=fn.dtype, device=vs.device)
    f2v_mat = torch.sparse_coo_tensor(torch.stack([mat_rows, mat_cols], dim=0), mat_vals, size=(nv, nf))
    vert_normals = torch.sparse.mm(f2v_mat, fn)
    norm = torch.sqrt(torch.sum(vert_normals ** 2, dim=1, keepdim=True))
    return vert_normals / norm

def uv2xyz(uv):
    """Map an (N, 2) tensor of (u, v) to unit vectors.

    Longitude is ``2*pi*u - pi`` and colatitude ``pi*v``, so uv in [0, 1]^2 covers
    the sphere (range assumed, not checked).
    """
    u = 2.0 * np.pi * uv[:, 0] - np.pi
    v = np.pi * uv[:, 1]
    x = torch.sin(v) * torch.cos(u)
    y = torch.sin(v) * torch.sin(u)
    z = torch.cos(v)
    return torch.stack([x, y, z]).T

def compute_nvt(mesh: Mesh, alpha=0.2, beta=0.2, delta=0.3) -> np.ndarray:
    """Per-face feature strengths from an eigen-analysis of a voted normal tensor.

    For face ``i`` each ring neighbour's normal ``n_j`` (``mesh.f2ring``) is
    reflected about ``w_j = normalise(((c_j - c_i) x n_j) x (c_j - c_i))``,
    weighted by ``a_j / max(a) * exp(-|c_j - c_i| / mean|c_j - c_i|)`` and
    accumulated into ``T_i = sum_j mu_j n'_j n'_j^T``. With eigenvalues
    l1 >= l2 >= l3 and s = l1 + l2 + l3, row ``i`` of the result is
    ``[(l1 - l2) / s, (l2 - l3) / s, l3 / s]``.

    ``alpha``, ``beta`` and ``delta`` are accepted but not used.
    """
    fa = mesh.fa
    fn = mesh.fn
    fc = mesh.fc
    fec_strength = np.zeros([len(fn), 3])

    for i, ring in enumerate(mesh.f2ring):
        ci = fc[i].reshape(1, -1)
        d = fc[ring] - ci
        nj = fn[ring]
        # (a x b) x a = (a . a) b - (b . a) a   with a = c_j - c_i, b = n_j
        a_a = np.sum(d ** 2, 1).reshape(-1, 1)
        b_a = np.sum(d * nj, 1).reshape(-1, 1)
        wj = normalize(a_a * nj - b_a * d, norm="l2", axis=1)
        nw = np.sum(nj * wj, 1).reshape(-1, 1)
        nj_prime = 2 * nw * wj - nj

        aj = fa[ring]
        am = np.max(aj) + 1.0e-12
        cji_norm = np.linalg.norm(d, axis=1)
        sigma = np.mean(cji_norm) + 1.0e-12
        mu = (aj / am * np.exp(-cji_norm / sigma)).reshape(-1, 1)
        Ti = np.matmul(nj_prime.T, nj_prime * mu)

        # Ti is symmetric, so eigvalsh is the right (real-valued) solver; its
        # output is ascending, reverse it to get l1 >= l2 >= l3.
        e_vals = np.linalg.eigvalsh(Ti)[::-1]
        total = np.sum(e_vals)
        fec_strength[i] = [
            (e_vals[0] - e_vals[1]) / total,
            (e_vals[1] - e_vals[2]) / total,
            e_vals[2] / total,
        ]

    return fec_strength


def bnf(pos: torch.Tensor, fn: torch.Tensor, mesh: Mesh, loop=1) -> torch.Tensor:
    """Bilateral filtering of face normals over each face's neighbours (``mesh.f2f``).

    Each of ``loop`` passes replaces every normal by the normalised sum of its
    neighbours' normals, weighted by neighbour area, a Gaussian of the
    face-centre distance (sigma = mean neighbour distance over all ``f2f``
    slots, missing ones included) and a Gaussian of the normal difference
    (sigma = 0.3). ``-1`` entries of ``mesh.f2f`` get zero weight.
    """
    faces = mesh.faces
    fc = torch.sum(pos[faces], 1) / 3.0
    # dim=1 explicit: the default would pick dim 0 for a mesh with three faces.
    cross = torch.cross(pos[faces[:, 1]] - pos[faces[:, 0]], pos[faces[:, 2]] - pos[faces[:, 0]], dim=1)
    fa = 0.5 * torch.sqrt(torch.sum(cross ** 2, dim=1) + 1.0e-12)

    f2f = torch.from_numpy(mesh.f2f).long().to(pos.device)
    has_neig = 1.0 * (f2f != -1)
    neig_fc = fc[f2f]
    neig_fa = fa[f2f] * has_neig
    fc_dist = torch.sum((neig_fc - fc.reshape(-1, 1, 3)) ** 2, dim=2)
    sigma_c = torch.sum(torch.sqrt(fc_dist + 1.0e-12)) / (fc_dist.shape[0] * fc_dist.shape[1])
    sigma_s = 0.3
    # The spatial weight does not change between passes.
    wc = torch.exp(-1.0 * fc_dist / (2 * (sigma_c ** 2)))

    new_fn = fn
    for _ in range(loop):
        neig_fn = new_fn[f2f]
        fn_dist = torch.sum((neig_fn - new_fn.reshape(-1, 1, 3)) ** 2, dim=2)
        ws = torch.exp(-1.0 * fn_dist / (2 * (sigma_s ** 2)))
        w = (wc * ws * neig_fa).unsqueeze(2)
        new_fn = torch.sum(w * neig_fn, dim=1)
        new_fn = new_fn / (torch.norm(new_fn, dim=1, keepdim=True) + 1.0e-12)
    return new_fn

def vertex_updating(pos: torch.Tensor, norm: torch.Tensor, mesh: Mesh, loop=10) -> torch.Tensor:
    """Move vertices towards agreement with the target face normals ``norm``.

    In each pass, face centres are recomputed once, then every vertex ``v`` is
    shifted by the mean over its incident faces (``mesh.vf``) of
    ``n_f (n_f . (c_f - v))``. Vertices with no incident face are left in
    place. The inputs are not modified.
    """
    new_pos = pos.detach().clone()
    norm = norm.detach().clone()
    for _ in range(loop):
        fc = torch.sum(new_pos[mesh.faces], 1) / 3.0
        for i in range(len(new_pos)):
            faces_i = list(mesh.vf[i])
            if not faces_i:
                # Averaging over zero faces would turn this vertex into NaN.
                continue
            cis = fc[faces_i]
            nis = norm[faces_i]
            cvis = cis - new_pos[i].reshape(1, -1)
            ncvis = torch.sum(nis * cvis, dim=1)
            dvi = torch.sum(ncvis.reshape(-1, 1) * nis, dim=0) / len(faces_i)
            new_pos[i] += dvi
    return new_pos

def random_rotate(pos: Union[torch.Tensor, np.ndarray], seed: Union[int, None] = None) -> torch.Tensor:
    """Rotate (N, 3) points by a random rotation ``Rz @ Ry @ Rx``.

    The three angles are drawn uniformly from [0, 2*pi) with ``np.random``,
    which is globally re-seeded first when ``seed`` is not None (0 included).

    Returns:
        (N, 3) float32 tensor of rotated points.
    """
    if isinstance(pos, np.ndarray):
        pos = torch.from_numpy(pos)
    if seed is not None:
        np.random.seed(seed)
    rx, ry, rz = np.random.uniform(size=[3]) * 2 * np.pi
    Rx = np.array([[1, 0, 0],
                   [0, np.cos(rx), -np.sin(rx)],
                   [0, np.sin(rx), np.cos(rx)]])
    Ry = np.array([[np.cos(ry), 0, np.sin(ry)],
                   [0, 1, 0],
                   [-np.sin(ry), 0, np.cos(ry)]])
    Rz = np.array([[np.cos(rz), -np.sin(rz), 0],
                   [np.sin(rz), np.cos(rz), 0],
                   [0, 0, 1]])
    # Previously Rz @ Ry @ Rz: Rx was computed but never applied.
    Rm = torch.from_numpy(Rz @ Ry @ Rx).float()
    return torch.mm(Rm, pos.T.float()).T


# --- /util/networks.py -------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
from torch_geometric.nn import GCNConv, ChebConv, Sequential



class SingleScaleGCN(nn.Module):
    """Single-resolution encoder/bottleneck/decoder GCN predicting per-vertex offsets.

    Thirteen blocks with widths ``h``: blocks 0-5 are the encoder (their
    outputs feed the optional skip connections), 6-7 the bottleneck, and
    8-12 the decoder; the last block ends with a Linear layer to 3 channels.
    ``forward`` returns ``data.x_pos`` plus that output.

    Args:
        device: device the inputs are moved to.
        activation: ``"lrelu"`` (default) or ``"relu"``.
        skip: when True, concatenate encoder outputs into the decoder.
        conv: ``"chebconv"`` (default, K=3) or ``"gcnconv"``.
    """

    def __init__(self, device, activation="lrelu", skip=False, conv="chebconv"):
        super(SingleScaleGCN, self).__init__()
        self.device = device
        self.skip = skip

        h = [4, 16, 32, 64, 128, 256, 256, 512, 256, 256, 128, 64, 32, 16, 3]

        activ_dict = {"relu": nn.ReLU(), "lrelu": nn.LeakyReLU()}
        activation_func = activ_dict[activation]

        def graph_conv(c_in, c_out):
            # One PyG Sequential entry for the selected convolution type.
            if conv == "chebconv":
                return (ChebConv(c_in, c_out, K=3), "x, edge_index -> x")
            if conv == "gcnconv":
                return (GCNConv(c_in, c_out), "x, edge_index -> x")
            raise ValueError("unknown conv type: {!r}".format(conv))

        blocks = []
        for i in range(12):
            blocks.append(Sequential("x, edge_index", [
                graph_conv(h[i], h[i + 1]),
                nn.BatchNorm1d(h[i + 1]),
                activation_func,
            ]))
        blocks.append(Sequential("x, edge_index", [
            graph_conv(h[12], h[13]),
            nn.BatchNorm1d(h[13]),
            activation_func,
            (nn.Linear(h[13], h[14]), "x -> x"),
        ]))
        self.blocks = nn.ModuleList(blocks)

        self.skip_blocks = nn.ModuleList([nn.Linear(h[i + 1] * 2, h[i + 1]) for i in range(6)])

    def forward(self, data, dm=None):
        """Return ``data.x_pos`` plus the predicted per-vertex offset.

        ``dm`` is an optional (N, 1) weight (numpy array or tensor, default
        ones) that scales the normalised features and is appended as a channel.
        """
        z1 = data.z1.to(self.device)
        x_pos = data.x_pos.to(self.device)
        edge_index = data.edge_index.to(self.device)

        # Centre the features and scale by the largest bounding-box extent.
        z_min = torch.min(z1, dim=0, keepdim=True)[0]
        z_max = torch.max(z1, dim=0, keepdim=True)[0]
        z_sc = torch.max(z_max - z_min)
        zc = (z_min + z_max) * 0.5
        z1 = (z1 - zc) / z_sc

        if isinstance(dm, np.ndarray):
            dm = torch.from_numpy(dm)
        elif not isinstance(dm, torch.Tensor):
            dm = torch.ones([z1.shape[0], 1])
        # Match z1's dtype so a float64 mask does not promote the network input.
        dm = dm.to(device=self.device, dtype=z1.dtype)
        x = torch.cat([dm * z1, dm], dim=1)

        skip_in = []
        for i, block in enumerate(self.blocks):
            if i >= 8 and self.skip:
                # Decoder block i pairs with encoder output 13 - i (i = 8..12 -> 5..1).
                x = self.skip_blocks[13 - i](torch.cat([skip_in[13 - i], x], dim=1))
            x = block(x, edge_index)
            if i <= 5:
                skip_in.append(x)

        return x_pos + x


# --- /util/render.py ---------------------------------------------------------------
import pyvista as pv
import numpy as np
from PIL import Image


class Render():
    """Render a mesh into a turntable GIF with PyVista.

    Construction normalises ``mesh.vs`` in place, converts the mesh to
    ``pv.PolyData`` and immediately writes the GIF to ``gif_name``.
    """

    def __init__(self, mesh, gif_name):
        self.normalize_to_bb(mesh)
        self.pyv_mesh = self.load_mesh(mesh)
        self.save_gif(gif_name)

    def load_mesh(self, mesh):
        """Convert ``mesh`` (``vs`` plus triangle ``faces``) to ``pv.PolyData``."""
        # PyVista's face array prefixes every face with its vertex count (3).
        counts = np.ones([len(mesh.faces), 1]) * 3
        faces = np.concatenate([counts, mesh.faces], axis=1).astype(np.int32)
        return pv.PolyData(mesh.vs, faces)

    def normalize_to_bb(self, mesh):
        """Map ``mesh.vs`` in place into [-0.5, 0.5] on every axis.

        Axes are scaled independently (NOTE(review): this distorts the aspect
        ratio of non-cubic meshes -- confirm that is intended). An axis with
        zero extent is centred at 0 instead of producing NaNs.
        """
        vs = mesh.vs
        maxs = np.max(vs, axis=0, keepdims=True)
        mins = np.min(vs, axis=0, keepdims=True)
        ranges = maxs - mins
        ranges = np.where(ranges == 0, 1, ranges)
        mesh.vs = (vs - (maxs + mins) * 0.5) / ranges

    def save_img(self, gif_name):
        """Render a single still image of the mesh to ``gif_name``."""
        plotter = pv.Plotter(off_screen=True, notebook=False)
        plotter.add_mesh(self.pyv_mesh, color="orange")
        plotter.show(screenshot=gif_name)

    def save_gif(self, gif_name):
        """Save 60 frames, rotating 6 degrees about y per frame, as a looping GIF (50 ms/frame)."""
        figures = []
        for i in range(60):
            rot = self.pyv_mesh.rotate_y(6 * i, inplace=False)
            plotter = pv.Plotter(off_screen=True, notebook=False, window_size=[400, 300])
            try:
                plotter.set_focus([0, 0, 0])
                plotter.set_position([2, 2, 2])
                plotter.add_mesh(rot, color="orange")
                img = plotter.screenshot(filename=None, return_img=True)
            finally:
                # Release the render window; previously one leaked per frame.
                plotter.close()
            figures.append(Image.fromarray(img))
        figures[0].save(gif_name, save_all=True, append_images=figures[1:], optimize=True, duration=50, loop=0)
# ----------------------------------------------------------------------------------