├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── check
├── batch_dist_check.py
└── dist_check.py
├── datasets.zip
├── docs
├── anim.gif
├── network.png
├── overview.png
├── semi_anim.gif
├── semi_anim.mp4
└── viewer.png
├── environment.yml
├── mgcn.py
├── mgcn_wo_gt.py
├── preprocess
└── prepare.py
├── refinement.py
├── requirements.txt
├── sgcn.py
└── util
├── __init__.py
├── datamaker.py
├── loss.py
├── mesh.py
├── meshnet.py
├── models.py
├── networks.py
└── render.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | wandb
3 | datasets/
4 |
5 | *.pt
6 | *.json
7 | *.ply
8 | *.obj
9 |
10 | advancing_front.py
11 | context_fill.py
12 | count_insert_vnum.py
13 | mesh_edit.py
14 | meshfix.py
15 | mgcn_time.py
16 | sgcn_time.py
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.10
2 |
3 | WORKDIR /work
4 | COPY . /work
5 |
6 | RUN apt update
7 | RUN apt install -y libgl1-mesa-dev
8 |
9 | RUN pip install --upgrade pip
10 | RUN pip install -r requirements.txt
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Shota Hattori
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning Self-Prior for Mesh Inpainting using Self-Supervised Graph Convolutional Networks
2 |
3 |
4 | Paper | arXiv
5 |
6 | Accepted by IEEE TVCG 2024
7 |
8 |
9 | (teaser figure omitted; see docs/)
10 |
11 |
12 |
13 |
14 |
15 | Method Overview
16 | (overview figure omitted; see docs/)
17 |
18 |
19 | ## Usage
20 |
21 | ### Environments
22 | ```
23 | python==3.10
24 | torch==1.13.0
25 | torch-geometric==2.2.0
26 | ```
27 |
28 | ### Installation (Docker)
29 |
30 | ```
31 | docker image build -t astaka-pe/semigcn .
32 | docker run -itd --gpus all -p 8081:8081 --name semigcn -v "$(pwd)":/work astaka-pe/semigcn
33 | docker exec -it semigcn /bin/bash
34 | ```
35 |
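If you prefer a local setup instead of Docker, a minimal sketch (assuming a Python 3.10 environment with a CUDA 11.7-compatible GPU, matching `requirements.txt`):

```
pip install --upgrade pip
pip install -r requirements.txt
```

This mirrors what the `Dockerfile` runs inside the image; on Linux, `libgl1-mesa-dev` may also be required (it is installed in the Docker image).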
36 | ### Preparation
37 |
38 | - Unzip `datasets.zip`
39 | - Sample meshes will be placed in `datasets/`
40 | - To use your own mesh, put it in a new folder (any path) as:
41 | - Deficient mesh: `datasets/**/{mesh-name}/{mesh-name}_original.obj`
42 | - Ground truth: `datasets/**/{mesh-name}/{mesh-name}_gt.obj`
43 | - The deficient mesh and the ground-truth mesh need not share the same connectivity, but they must share the same scale. See the example layout below.
44 |
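For example, for a hypothetical mesh named `bunny` placed under a hypothetical category folder, the expected layout is:

```
datasets/my-category/bunny/
├── bunny_original.obj   # deficient mesh (input)
└── bunny_gt.obj         # ground truth (used for evaluation; mgcn_wo_gt.py works without it)
```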
45 | ### Preprocess
46 |
47 | - Specify the path of the deficient mesh
48 | - Create **initial mesh** and **smoothed mesh**
49 |
50 | ```
51 | python3 preprocess/prepare.py -i datasets/**/{mesh-name}/{mesh-name}_original.obj
52 | ```
53 | - options
54 | - `-r {float}`: Target edge length for remeshing, given as a pymeshlab percentage (of the bounding-box diagonal). Higher values produce a coarser mesh, lower values a finer one. `default=0.6`.
55 |
56 | - Computation time: about 30 sec. The expected output files are listed below.
57 |
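A sketch of the files `preprocess/prepare.py` writes next to the input (names inferred from the script):

```
datasets/**/{mesh-name}/
├── {mesh-name}_initial.obj    # hole-filled, remeshed, scale-normalized initial mesh
├── {mesh-name}_initial.pt     # cached Mesh object for faster loading
├── {mesh-name}_smooth.obj     # heavily smoothed copy of the initial mesh
├── {mesh-name}_original.obj   # input mesh, rescaled to the same normalized scale (overwritten in place)
├── {mesh-name}_gt.obj         # ground truth, rescaled likewise (if present)
├── {mesh-name}_vmask.json     # per-vertex mask: true = original vertex, false = inserted by hole filling
├── {mesh-name}_inserted.obj   # point cloud of the inserted (hole-region) vertices
└── tmp/                       # intermediate meshes (fixed / remeshed / pre-normalization copies)
```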
58 | ### Training
59 |
60 | ```
61 | python3 sgcn.py -i datasets/**/{mesh-name} # SGCN
62 | python3 mgcn.py -i datasets/**/{mesh-name} # MGCN
63 | ```
64 |
65 | - options
66 | - `-CAD`: For a CAD model
67 | - `-real`: For a real scan
68 | - `-cache`: For using cache files (for faster computation)
69 | - `-mu` : Weight for refinement
70 |
71 | You can monitor the training progress through the web viewer (default: http://localhost:8081). Output meshes are written under the directories sketched below.
72 |
73 | 
74 |
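Outputs are written under a timestamped directory inside the dataset folder; a sketch of the naming used by `sgcn.py` and `mgcn.py`:

```
datasets/**/{mesh-name}/output/
├── {timestamp}_sgcn/
│   ├── 10_step.obj, 20_step.obj, ...       # intermediate outputs every 10 epochs
│   └── refine.obj                          # final output after the refinement step
└── {timestamp}_mgcn_{output-name}/         # {output-name} defaults to "exp" (-o option)
    ├── 10_step_0.obj, 10_step_1.obj, ...   # per-resolution outputs (suffix 0 = finest mesh)
    ├── train.obj                           # output at the final training epoch
    └── refine.obj                          # final output after the refinement step
```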
75 | ### Evaluation
76 |
77 | - Create `datasets/**/{mesh-name}/comparison` and put the meshes you want to evaluate in it
78 | - A deficient mesh `datasets/**/{mesh-name}/comparison/original.obj` and a ground-truth mesh `datasets/**/{mesh-name}/comparison/gt.obj` are also required; the colorized results are written under `comparison/colored/`, as sketched below
79 |
80 | ```
81 | python3 check/batch_dist_check.py -i datasets/**/{mesh-name}
82 | ```
83 |
84 | - options
85 | - `-real`: For a real scan
86 |
87 |
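The evaluation script writes colorized copies into `comparison/colored/` and encodes the measured errors (mean vertex-to-surface distance, normalized by the ground-truth bounding-box diagonal) in the filenames. A sketch with illustrative numbers:

```
datasets/**/{mesh-name}/comparison/colored/
├── all=0.001234-hole=0.004321-{method}.obj   # whole-mesh error / hole-region error
├── inserted.obj                              # vertices detected as belonging to hole regions
└── max_val.txt                               # upper bound of the color map used for visualization
```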
88 | ### Refinement (Optional)
89 |
90 | - To run only the refinement step:
91 |
92 | ```
93 | # For MGCN, point -dst at .../100_step_0.obj instead of .../100_step.obj
94 | python3 refinement.py \
95 |     -src datasets/**/{mesh-name}/{mesh-name}_initial.obj \
96 |     -dst datasets/**/{mesh-name}/output/**/100_step.obj \
97 |     -vm datasets/**/{mesh-name}/{mesh-name}_vmask.json \
98 |     -ref {arbitrary-output-filename}.obj
99 | ```
100 |
101 | - options
102 | - `-mu`: Weight for refinement
103 | - Choose a weight so that the positions of the remaining vertices of the initial mesh and the shape of the completed regions of the output mesh are both preserved
104 |
105 | ## Running other competing methods
106 |
107 | Please refer to [tinymesh](https://github.com/tatsy/tinymesh).
108 |
109 |
121 |
122 | ## Citation
123 |
124 | ```
125 | @article{hattori2024semigcn,
126 | title={Learning Self-Prior for Mesh Inpainting Using Self-Supervised Graph Convolutional Networks},
127 | author={Hattori, Shota and Yatagawa, Tatsuya and Ohtake, Yutaka and Suzuki, Hiromasa},
128 | journal={IEEE Transactions on Visualization and Computer Graphics},
129 | year={2024},
130 | publisher={IEEE}
131 | }
132 | ```
--------------------------------------------------------------------------------
/check/batch_dist_check.py:
--------------------------------------------------------------------------------
1 | import pymeshlab as ml
2 | import numpy as np
3 | import argparse
4 | import glob
5 | import os
6 | import json
7 | import sys
8 | import trimesh
9 |
10 | #EPS = 0.05 # for simulated hole
11 | EPS = 1.0 # for real hole
12 | COLOR_MAX = 0.005
13 |
14 | def get_parser():
15 | parser = argparse.ArgumentParser(description="calculate hausdorff distances")
16 | parser.add_argument("-i", "--input", type=str, required=True)
17 | parser.add_argument("-real", action="store_true")
18 | args = parser.parse_args()
19 | for k, v in vars(args).items():
20 | print("{:12s}: {}".format(k, v))
21 | return args
22 |
23 | def main():
24 | """ calculate hausdorff distances """
25 | args = get_parser()
26 | if args.real:
27 | EPS = 1.0
28 | else:
29 | EPS = 0.05
30 | i_dir = "{}/comparison".format(args.input)
31 | o_dir = "{}/colored".format(i_dir)
32 | os.makedirs(o_dir, exist_ok=True)
33 | m_name = i_dir.split("/")[-2]
34 | gt_path = "{}/gt.obj".format(i_dir)
35 | org_path = "{}/original.obj".format(i_dir)
36 | all_mesh = glob.glob("{}/*.*".format(i_dir))
37 |
38 | ms = ml.MeshSet()
39 | ms.load_new_mesh(gt_path)
40 |
41 | face_num = ms.current_mesh().face_number()
42 | diag = ms.current_mesh().bounding_box().diagonal()
43 |
44 | ms.load_new_mesh(org_path)
45 | res = ms.apply_filter("distance_from_reference_mesh", measuremesh=0, refmesh=1, signeddist=False)
46 | ms.set_current_mesh(0)
47 | quality = ms.current_mesh().vertex_quality_array()
48 | new_vs = quality > EPS
49 | new_pos = ms.current_mesh().vertex_matrix()[new_vs]
50 | with open("{}/inserted.obj".format(o_dir), "w") as f_obj:
51 | for v in new_pos:
52 | print("v {} {} {}".format(v[0], v[1], v[2]), file=f_obj)
53 |
54 | max_val = diag * COLOR_MAX
55 | with open("{}/max_val.txt".format(o_dir), mode="w") as f:
56 | f.write("{:.7f}".format(max_val))
57 |
58 | for m_path in all_mesh:
59 | if m_path == gt_path or m_path == org_path:
60 | continue
61 | try:
62 | ms.load_new_mesh(m_path)
63 | except:
64 | print("[ERROR] {} has an unknown format".format(m_path))
65 | continue
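# measure vertex-to-surface distances in both directions; the per-vertex distances
# stored on the ground-truth mesh (mesh 0) are what the statistics below are computed from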
66 | res1 = ms.apply_filter("distance_from_reference_mesh", measuremesh=0, refmesh=ms.current_mesh_id())
67 | res1 = ms.apply_filter("distance_from_reference_mesh", measuremesh=ms.current_mesh_id(), refmesh=0)
68 | quality = ms.mesh(0).vertex_quality_array()
69 | quality_hole = quality[new_vs]
70 | hd_all = np.sum(np.abs(quality)) / len(quality) / diag
71 | hd_hole = np.sum(np.abs(quality_hole)) / len(quality_hole) / diag
72 | ms.apply_filter("colorize_by_vertex_quality", minval=0, maxval=max_val, zerosym=True)
73 | out_file = os.path.basename(m_path)
74 | out_path = "{}/all={:.6f}-hole={:.6f}-{}".format(o_dir, hd_all, hd_hole, out_file)
75 | ms.save_current_mesh(out_path)
76 |
77 | if __name__ == "__main__":
78 | main()
--------------------------------------------------------------------------------
/check/dist_check.py:
--------------------------------------------------------------------------------
2 | import pymeshlab as ml
3 | import numpy as np
4 | import argparse
5 | import glob
6 | import os
7 | import json
8 | import sys
9 |
10 | EPS = 0.05
11 | COLOR_MAX = 0.005
12 |
13 | def simple_mesh_distance(out_path, gt_path):
14 | """ calculate hausdorff distance """
15 | ms = ml.MeshSet()
16 | ms.load_new_mesh(gt_path)
17 | face_num = ms.current_mesh().face_number()
18 | diag = ms.current_mesh().bounding_box().diagonal()
19 | max_val = diag * COLOR_MAX
20 |
21 | ms.load_new_mesh(out_path)
22 |
23 | res1 = ms.apply_filter("distance_from_reference_mesh", measuremesh=0, refmesh=ms.current_mesh_id())
24 | res1 = ms.apply_filter("distance_from_reference_mesh", measuremesh=ms.current_mesh_id(), refmesh=0)
25 | quality = ms.mesh(0).vertex_quality_array()
26 | hd_all = np.sum(np.abs(quality)) / len(quality) / diag
27 | ms.apply_filter("colorize_by_vertex_quality", minval=0, maxval=max_val, zerosym=True)
28 | out_dir = os.path.dirname(out_path)
29 | out_file = os.path.basename(out_path)
30 | os.makedirs("{}/colored".format(out_dir), exist_ok=True)
31 | out_path = "{}/colored/all={:.6f}-{}".format(out_dir, hd_all, out_file)
32 | ms.save_current_mesh(out_path)
33 | return hd_all
34 |
35 | def mesh_distance(gt_path, org_path, out_path, real=False):
36 | """ calculate hausdorff distance """
37 | if real:
38 | EPS = 1.0
39 | else:
40 | EPS = 0.05
41 | ms = ml.MeshSet()
42 | ms.load_new_mesh(gt_path)
43 |
44 | face_num = ms.current_mesh().face_number()
45 | diag = ms.current_mesh().bounding_box().diagonal()
46 |
47 | ms.load_new_mesh(org_path)
48 | res = ms.apply_filter("distance_from_reference_mesh", measuremesh=0, refmesh=1, signeddist=False)
49 | ms.set_current_mesh(0)
50 | quality = ms.current_mesh().vertex_quality_array()
51 | new_vs = quality > EPS
52 |
53 | max_val = diag * 0.005
54 |
55 | ms.load_new_mesh(out_path)
56 |
57 | res1 = ms.apply_filter("distance_from_reference_mesh", measuremesh=0, refmesh=ms.current_mesh_id())
58 | res1 = ms.apply_filter("distance_from_reference_mesh", measuremesh=ms.current_mesh_id(), refmesh=0)
59 | quality = ms.mesh(0).vertex_quality_array()
60 | quality_hole = quality[new_vs]
61 | hd_all = np.sum(np.abs(quality)) / len(quality) / diag
62 | hd_hole = np.sum(np.abs(quality_hole)) / len(quality_hole) / diag
63 | ms.apply_filter("colorize_by_vertex_quality", minval=0, maxval=max_val, zerosym=True)
64 | out_dir = os.path.dirname(out_path)
65 | out_file = os.path.basename(out_path)
66 | out_path = "{}/all={:.6f}-hole={:.6f}-{}".format(out_dir, hd_all, hd_hole, out_file)
67 | ms.save_current_mesh(out_path)
--------------------------------------------------------------------------------
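A minimal usage sketch for the two helpers above (paths are hypothetical; run from the repository root, which is how `sgcn.py` and `mgcn.py` import `check.dist_check`):

```python
from check.dist_check import simple_mesh_distance, mesh_distance

# Mean |distance| from each ground-truth vertex to the output surface, normalized by
# the GT bounding-box diagonal; also writes a quality-colorized copy under
# <output-dir>/colored/.
hd = simple_mesh_distance("datasets/sample/bunny/output/run_sgcn/refine.obj",
                          "datasets/sample/bunny/bunny_gt.obj")
print("normalized mean distance:", hd)

# Variant that additionally reports the error restricted to hole regions, which are
# detected by comparing the ground truth against the deficient input mesh.
mesh_distance(gt_path="datasets/sample/bunny/bunny_gt.obj",
              org_path="datasets/sample/bunny/bunny_original.obj",
              out_path="datasets/sample/bunny/output/run_sgcn/refine.obj",
              real=False)
```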
/datasets.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/astaka-pe/SeMIGCN/fe17377db287298c77f325eeeeec5285b1c0f626/datasets.zip
--------------------------------------------------------------------------------
/docs/anim.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/astaka-pe/SeMIGCN/fe17377db287298c77f325eeeeec5285b1c0f626/docs/anim.gif
--------------------------------------------------------------------------------
/docs/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/astaka-pe/SeMIGCN/fe17377db287298c77f325eeeeec5285b1c0f626/docs/network.png
--------------------------------------------------------------------------------
/docs/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/astaka-pe/SeMIGCN/fe17377db287298c77f325eeeeec5285b1c0f626/docs/overview.png
--------------------------------------------------------------------------------
/docs/semi_anim.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/astaka-pe/SeMIGCN/fe17377db287298c77f325eeeeec5285b1c0f626/docs/semi_anim.gif
--------------------------------------------------------------------------------
/docs/semi_anim.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/astaka-pe/SeMIGCN/fe17377db287298c77f325eeeeec5285b1c0f626/docs/semi_anim.mp4
--------------------------------------------------------------------------------
/docs/viewer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/astaka-pe/SeMIGCN/fe17377db287298c77f325eeeeec5285b1c0f626/docs/viewer.png
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: semigcn
2 | channels:
3 | - defaults
4 | - pytorch
5 | - open3d-admin
6 | dependencies:
7 | - python=3.7
8 | - cython
9 | - pytorch=1.7.0
10 | - cudatoolkit=10.2
11 | - numpy
12 | - matplotlib
13 | - tensorflow
14 | - tensorboard
15 | - open3d
16 | - pip
17 | - pip:
18 | - pyyaml
19 | - addict
20 | - plyfile
21 | - opencv-python
22 | - scikit-learn==1.0.2
23 | - --find-links https://pytorch-geometric.com/whl/torch-1.7.1+cu102.html
24 | - torch-scatter==2.0.7
25 | - torch-sparse==0.6.9
26 | - torch-cluster==1.5.9
27 | - torch-spline-conv==1.2.1
28 | - torch-geometric==1.7.1
29 | - pymeshlab==2021.10
30 | - pymeshfix==0.16.1
31 | - tqdm==4.61.1
32 | - wandb==0.12.1
33 |
--------------------------------------------------------------------------------
/mgcn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import datetime
4 | import argparse
5 | import os
6 | import copy
7 | import random
8 | import viser
9 | from tqdm import tqdm
10 |
11 | import util.loss as Loss
12 | import util.models as Models
13 | import util.datamaker as Datamaker
14 | from util.mesh import Mesh
15 | from util.meshnet import MGCN
16 | import check.dist_check as DIST
17 |
18 | def torch_fix_seed(seed=314):
19 | random.seed(seed)
20 | np.random.seed(seed)
21 | torch.manual_seed(seed)
22 | torch.cuda.manual_seed(seed)
23 | torch.backends.cudnn.deterministic = True
24 | torch.use_deterministic_algorithms(True, warn_only=True)  # warn rather than error on ops without deterministic implementations
25 |
26 | def get_parser():
27 | parser = argparse.ArgumentParser(description="Self-supervised Mesh Completion")
28 | parser.add_argument("-i", "--input", type=str, required=True)
29 | parser.add_argument("-o", "--output", type=str, default="exp")
30 | parser.add_argument("-pos_lr", type=float, default=0.01)
31 | parser.add_argument("-iter", type=int, default=100)
32 | parser.add_argument("-k1", type=float, default=4.0)
33 | parser.add_argument("-k2", type=float, default=4.0)
34 | parser.add_argument("-dm_size", type=int, default=40)
35 | parser.add_argument("-kn", type=int, nargs="*", default=[4])
36 | parser.add_argument("-batch", type=int, default=5)
37 | parser.add_argument("-skip", action="store_true")
38 | parser.add_argument("-gpu", type=int, default=0)
39 | parser.add_argument("-cache", action="store_true")
40 | parser.add_argument("-CAD", action="store_true")
41 | parser.add_argument("-real", action="store_true")
42 | parser.add_argument("-mu", type=float, default=1.0)
43 | parser.add_argument("-viewer", action="store_true", default=True)
44 | parser.add_argument("-port", type=int, default=8081)
45 | args = parser.parse_args()
46 |
47 | for k, v in vars(args).items():
48 | print("{:12s}: {}".format(k, v))
49 |
50 | return args
51 |
52 | def main():
53 | args = get_parser()
54 |
55 | if args.viewer:
56 | server = viser.ViserServer(port=args.port)
57 |
58 | """ --- create dataset --- """
59 | mesh_dic, dataset = Datamaker.create_dataset(args.input, dm_size=args.dm_size, kn=args.kn, cache=args.cache)
60 | ini_file, smo_file, v_mask, f_mask, mesh_name = mesh_dic["ini_file"], mesh_dic["smo_file"], mesh_dic["v_mask"], mesh_dic["f_mask"], mesh_dic["mesh_name"]
61 | org_mesh, ini_mesh, smo_mesh, out_mesh = mesh_dic["org_mesh"], mesh_dic["ini_mesh"], mesh_dic["smo_mesh"], mesh_dic["out_mesh"]
62 | rot_mesh = copy.deepcopy(ini_mesh)
63 | dt_now = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
64 |
65 | vmask_dummy = mesh_dic["vmask_dummy"]
66 | fmask_dummy = mesh_dic["fmask_dummy"]
67 |
68 | """ --- create model instance --- """
69 | torch_fix_seed()
70 | device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
71 | posnet = MGCN(device, smo_mesh, ini_mesh, v_mask, skip=args.skip).to(device)
72 | optimizer_pos = torch.optim.Adam(posnet.parameters(), lr=args.pos_lr)
73 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer_pos, step_size=50, gamma=0.5)
74 |
75 | anss = posnet.poss
76 | v_masks = posnet.v_masks
77 | nvs = posnet.nvs
78 | meshes = posnet.meshes
79 | v_masks_list = posnet.v_masks_list
80 | poss_list = posnet.poss_list
81 | nvs_all = [len(meshes[0].vs)] + nvs
82 | pos_weight = [0.35, 0.3, 0.2, 0.15]
83 |
84 | os.makedirs("{}/output/{}_mgcn_{}".format(args.input, dt_now, args.output), exist_ok=True)
85 | scale = 2 / np.max(org_mesh.vs)
86 | if args.viewer:
87 | print("\n\033[42m Viewer at: http://localhost:{} \033[0m\n".format(args.port))
88 | with server.gui.add_folder("Training"):
89 | server.scene.add_mesh_simple(
90 | name="/input",
91 | vertices=org_mesh.vs * scale,
92 | faces=org_mesh.faces,
93 | flat_shading=True,
94 | visible=False,
95 | )
96 | server.scene.add_mesh_simple(
97 | name="/initial",
98 | vertices=ini_mesh.vs * scale,
99 | faces=ini_mesh.faces,
100 | flat_shading=True,
101 | visible=False,
102 | )
103 | gui_counter = server.gui.add_number(
104 | "Epoch",
105 | initial_value=0,
106 | disabled=True,
107 | )
108 |
109 | """ --- learning loop --- """
110 | with tqdm(total=args.iter) as pbar:
111 | """ --- training --- """
112 | for epoch in range(1, args.iter+1):
113 | n_data = vmask_dummy.shape[1]
114 | batch_index = torch.randperm(n_data).reshape(-1, args.batch)
115 | epoch_loss_p = 0.0
116 | epoch_loss_n = 0.0
117 | epoch_loss_r = 0.0
118 | epoch_loss_pos = 0.0
119 | epoch_loss = 0.0
120 |
121 | for batch in batch_index:
122 | """ original dummy mask """
123 | dm_batch = vmask_dummy[:, batch]
124 | posnet.train()
125 | optimizer_pos.zero_grad()
126 | for i, b in enumerate(batch):
127 | """ original dummy mask """
128 | dm = dm_batch[:, i].reshape(-1, 1)
129 | rm = v_mask.reshape(-1, 1).float()
130 | dm = rm * dm
131 |
132 | ini_vs = ini_mesh.vs
133 |
134 | poss = posnet(dataset, dm)
135 | pos = poss[0]
136 |
137 | norm = Models.compute_fn(pos, ini_mesh.faces)
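# accumulate the masked position loss over all resolutions, weighted by pos_weight
# (index 0 is the finest mesh; poss[0] is also used as the final output)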
138 | for mesh_idx, pos_i in enumerate(poss):
139 | if mesh_idx == 0:
140 | loss_p = Loss.mask_pos_rec_loss(pos_i, poss_list[mesh_idx], v_masks_list[mesh_idx].reshape(-1).bool()) * pos_weight[mesh_idx]
141 | epoch_loss_pos += Loss.mask_pos_rec_loss(pos_i, poss_list[0], v_masks_list[0].reshape(-1).bool()).item()
142 | else:
143 | loss_p = loss_p + Loss.mask_pos_rec_loss(pos_i, poss_list[mesh_idx], v_masks_list[mesh_idx].reshape(-1).bool()) * pos_weight[mesh_idx]
144 | # loss_p = Loss.mask_pos_rec_loss(poss, anss, v_masks.reshape(-1).bool())
145 | loss_n = Loss.mask_norm_rec_loss(norm, ini_mesh.fn, f_mask)
146 | if args.CAD:
147 | loss_reg, _ = Loss.fn_bnf_detach_loss(pos, norm, ini_mesh, loop=5)
148 | loss = loss_p + args.k1 * loss_n + args.k2 * loss_reg
149 | epoch_loss_r += loss_reg.item()
150 | else:
151 | # loss_reg = Loss.mesh_laplacian_loss(pos, ini_mesh)
152 | loss = loss_p + args.k1 * loss_n# + 0.0 * loss_reg
153 | epoch_loss_p += loss_p.item()
154 | epoch_loss_n += loss_n.item()
155 |
156 | loss.backward()
157 | epoch_loss += loss.item()
158 |
159 | optimizer_pos.step()
160 | scheduler.step()
161 |
162 | epoch_loss_p /= n_data
163 | epoch_loss_n /= n_data
164 | epoch_loss_r /= n_data
165 | epoch_loss /= n_data
166 | epoch_loss_pos /= n_data
167 |
168 | pbar.set_description("Epoch {}".format(epoch))
169 | pbar.set_postfix({"loss": epoch_loss})
170 |
171 | if epoch == args.iter:
172 | out_path = "{}/output/{}_mgcn_{}/train.obj".format(args.input, dt_now, args.output)
173 | out_mesh.vs = pos.detach().to("cpu").numpy().copy()
174 | Mesh.save(out_mesh, out_path)
175 | DIST.mesh_distance(mesh_dic["gt_file"], mesh_dic["org_file"], out_path, args.real)
176 |
177 | """ --- test --- """
178 | if epoch % 10 == 0:
179 | posnet.eval()
180 | dm = v_mask.reshape(-1, 1).float()
181 | poss = posnet(dataset, dm)
182 | st_nv = 0
183 | for res, mesh in enumerate(meshes):
184 | out_path = "{}/output/{}_mgcn_{}/{}_step_{}.obj".format(args.input, dt_now, args.output, str(epoch), res)
185 | mesh.vs = poss[res].to("cpu").detach().numpy().copy()
186 | st_nv += len(mesh.vs)
187 | Mesh.save(mesh, out_path)
188 | if args.viewer:
189 | server.scene.add_mesh_simple(
190 | name="/output/res-{}".format(res),
191 | vertices=mesh.vs * scale,
192 | faces=mesh.faces,
193 | flat_shading=True,
194 | )
195 | out_path = "{}/output/{}_mgcn_{}/{}_step_0.obj".format(args.input, dt_now, args.output, str(epoch))
196 | if args.viewer:
197 | gui_counter.value = epoch
198 |
199 | pbar.update(1)
200 |
201 | DIST.mesh_distance(mesh_dic["gt_file"], mesh_dic["org_file"], out_path, args.real)
202 |
203 | """ refinement """
204 | posnet.eval()
205 | dm = v_mask.reshape(-1, 1).float()
206 | poss = posnet(dataset, dm)
207 | out_pos = poss[0].to("cpu").detach()
208 | ini_pos = torch.from_numpy(ini_mesh.vs).float()
209 | ref_pos = Mesh.mesh_merge(ini_mesh.Lap, ini_mesh, out_pos, v_mask, w=args.mu)
210 | out_path = "{}/output/{}_mgcn_{}/refine.obj".format(args.input, dt_now, args.output)
211 | out_mesh.vs = ref_pos.detach().numpy().copy()
212 | Mesh.save(out_mesh, out_path)
213 | DIST.mesh_distance(mesh_dic["gt_file"], mesh_dic["org_file"], out_path, args.real)
214 |
215 | if __name__ == "__main__":
216 | main()
--------------------------------------------------------------------------------
/mgcn_wo_gt.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import datetime
4 | import argparse
5 | import os
6 | import copy
7 | import random
8 | from tqdm import tqdm
9 |
10 | import util.loss as Loss
11 | import util.models as Models
12 | import util.datamaker as Datamaker
13 | from util.mesh import Mesh
14 | from util.meshnet import MGCN
15 | import check.dist_check as DIST
16 |
17 | def torch_fix_seed(seed=314):
18 | random.seed(seed)
19 | np.random.seed(seed)
20 | torch.manual_seed(seed)
21 | torch.cuda.manual_seed(seed)
22 | torch.backends.cudnn.deterministic = True
23 | torch.use_deterministic_algorithms(True, warn_only=True)  # warn rather than error on ops without deterministic implementations
24 |
25 | def get_parser():
26 | parser = argparse.ArgumentParser(description="Self-supervised Mesh Completion")
27 | parser.add_argument("-i", "--input", type=str, required=True)
28 | parser.add_argument("-o", "--output", type=str, default="")
29 | parser.add_argument("-pos_lr", type=float, default=0.01)
30 | parser.add_argument("-iter", type=int, default=100)
31 | parser.add_argument("-k1", type=float, default=4.0)
32 | parser.add_argument("-k2", type=float, default=4.0)
33 | parser.add_argument("-dm_size", type=int, default=40)
34 | parser.add_argument("-kn", type=int, nargs="*", default=[4])
35 | parser.add_argument("-batch", type=int, default=5)
36 | parser.add_argument("-skip", action="store_true")
37 | parser.add_argument("-gpu", type=int, default=0)
38 | parser.add_argument("-cache", action="store_true")
39 | parser.add_argument("-CAD", action="store_true")
40 | parser.add_argument("-real", action="store_true")
41 | parser.add_argument("-mu", type=float, default=1.0)
42 | args = parser.parse_args()
43 |
44 | for k, v in vars(args).items():
45 | print("{:12s}: {}".format(k, v))
46 |
47 | return args
48 |
49 | def main():
50 | args = get_parser()
51 | """ --- create dataset --- """
52 | mesh_dic, dataset = Datamaker.create_dataset(args.input, dm_size=args.dm_size, kn=args.kn, cache=args.cache)
53 | ini_file, smo_file, v_mask, f_mask, mesh_name = mesh_dic["ini_file"], mesh_dic["smo_file"], mesh_dic["v_mask"], mesh_dic["f_mask"], mesh_dic["mesh_name"]
54 | ini_mesh, smo_mesh, out_mesh = mesh_dic["ini_mesh"], mesh_dic["smo_mesh"], mesh_dic["out_mesh"]
55 | rot_mesh = copy.deepcopy(ini_mesh)
56 | dt_now = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
57 |
58 | vmask_dummy = mesh_dic["vmask_dummy"]
59 | fmask_dummy = mesh_dic["fmask_dummy"]
60 |
61 | """ --- create model instance --- """
62 | torch_fix_seed()
63 | device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
64 | posnet = MGCN(device, smo_mesh, ini_mesh, v_mask, skip=args.skip).to(device)
65 | optimizer_pos = torch.optim.Adam(posnet.parameters(), lr=args.pos_lr)
66 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer_pos, step_size=50, gamma=0.5)
67 |
68 | anss = posnet.poss
69 | v_masks = posnet.v_masks
70 | nvs = posnet.nvs
71 | meshes = posnet.meshes
72 | v_masks_list = posnet.v_masks_list
73 | poss_list = posnet.poss_list
74 | nvs_all = [len(meshes[0].vs)] + nvs
75 | pos_weight = [0.35, 0.3, 0.2, 0.15]
76 |
77 | os.makedirs("{}/output/{}_mgcn_{}".format(args.input, dt_now, args.output), exist_ok=True)
78 |
79 | """ --- learning loop --- """
80 | with tqdm(total=args.iter) as pbar:
81 | """ --- training --- """
82 | for epoch in range(1, args.iter+1):
83 | n_data = vmask_dummy.shape[1]
84 | batch_index = torch.randperm(n_data).reshape(-1, args.batch)
85 | epoch_loss_p = 0.0
86 | epoch_loss_n = 0.0
87 | epoch_loss_r = 0.0
88 | epoch_loss_pos = 0.0
89 | epoch_loss = 0.0
90 |
91 | for batch in batch_index:
92 | """ original dummy mask """
93 | dm_batch = vmask_dummy[:, batch]
94 | posnet.train()
95 | optimizer_pos.zero_grad()
96 | for i, b in enumerate(batch):
97 | """ original dummy mask """
98 | dm = dm_batch[:, i].reshape(-1, 1)
99 | rm = v_mask.reshape(-1, 1).float()
100 | dm = rm * dm
101 |
102 | ini_vs = ini_mesh.vs
103 |
104 | poss = posnet(dataset, dm)
105 | pos = poss[0]
106 |
107 | norm = Models.compute_fn(pos, ini_mesh.faces)
108 | for mesh_idx, pos_i in enumerate(poss):
109 | if mesh_idx == 0:
110 | loss_p = Loss.mask_pos_rec_loss(pos_i, poss_list[mesh_idx], v_masks_list[mesh_idx].reshape(-1).bool()) * pos_weight[mesh_idx]
111 | epoch_loss_pos += Loss.mask_pos_rec_loss(pos_i, poss_list[0], v_masks_list[0].reshape(-1).bool()).item()
112 | else:
113 | loss_p = loss_p + Loss.mask_pos_rec_loss(pos_i, poss_list[mesh_idx], v_masks_list[mesh_idx].reshape(-1).bool()) * pos_weight[mesh_idx]
114 | # loss_p = Loss.mask_pos_rec_loss(poss, anss, v_masks.reshape(-1).bool())
115 | loss_n = Loss.mask_norm_rec_loss(norm, ini_mesh.fn, f_mask)
116 | if args.CAD:
117 | loss_reg, _ = Loss.fn_bnf_detach_loss(pos, norm, ini_mesh, loop=5)
118 | loss = loss_p + args.k1 * loss_n + args.k2 * loss_reg
119 | epoch_loss_r += loss_reg.item()
120 | else:
121 | # loss_reg = Loss.mesh_laplacian_loss(pos, ini_mesh)
122 | loss = loss_p + args.k1 * loss_n# + 0.0 * loss_reg
123 | epoch_loss_p += loss_p.item()
124 | epoch_loss_n += loss_n.item()
125 |
126 | loss.backward()
127 | epoch_loss += loss.item()
128 |
129 | optimizer_pos.step()
130 | scheduler.step()
131 |
132 | epoch_loss_p /= n_data
133 | epoch_loss_n /= n_data
134 | epoch_loss_r /= n_data
135 | epoch_loss /= n_data
136 | epoch_loss_pos /= n_data
137 |
138 | pbar.set_description("Epoch {}".format(epoch))
139 | pbar.set_postfix({"loss": epoch_loss})
140 |
141 | """ --- test --- """
142 | if epoch % 10 == 0:
143 | posnet.eval()
144 | dm = v_mask.reshape(-1, 1).float()
145 | poss = posnet(dataset, dm)
146 | st_nv = 0
147 | for res, mesh in enumerate(meshes):
148 | out_path = "{}/output/{}_mgcn_{}/{}_step_{}.obj".format(args.input, dt_now, args.output, str(epoch), res)
149 | mesh.vs = poss[res].to("cpu").detach().numpy().copy()
150 | st_nv += len(mesh.vs)
151 | Mesh.save(mesh, out_path)
152 | out_path = "{}/output/{}_mgcn_{}/{}_step_0.obj".format(args.input, dt_now, args.output, str(epoch))
153 |
154 | pbar.update(1)
155 |
156 |
157 | """ refinement """
158 | posnet.eval()
159 | dm = v_mask.reshape(-1, 1).float()
160 | poss = posnet(dataset, dm)
161 | out_pos = poss[0].to("cpu").detach()
162 | ini_pos = torch.from_numpy(ini_mesh.vs).float()
163 | ref_pos = Mesh.mesh_merge(ini_mesh.Lap, ini_mesh, out_pos, v_mask, w=args.mu)
164 | out_path = "{}/output/{}_mgcn_{}/refine.obj".format(args.input, dt_now, args.output)
165 | out_mesh.vs = ref_pos.detach().numpy().copy()
166 | Mesh.save(out_mesh, out_path)
167 |
168 | if __name__ == "__main__":
169 | main()
--------------------------------------------------------------------------------
/preprocess/prepare.py:
--------------------------------------------------------------------------------
1 | import pymeshlab as ml
2 | import pymeshfix as mf
3 | import numpy as np
4 | import torch
5 | import json
6 | import argparse
7 | import sys
8 | import os
9 | sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
10 | from util.mesh import Mesh
11 |
12 | SMOOTH_ITER = 30
13 | MAXHOLESIZE = 1000
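# EPSILON: distance threshold (measured after scale normalization) used in write_mask()
# to decide whether an initial-mesh vertex lies on the original surface or was inserted by hole filling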
14 | EPSILON = 0.2
15 | REMESH_TARGET = ml.Percentage(0.6)
16 |
17 | def get_parse():
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument("-i", "--input", type=str, required=True)
20 | parser.add_argument("-r", "--remesh", type=float, default=0.6)
21 | args = parser.parse_args()
22 |
23 | for k, v in vars(args).items():
24 | print("{:12s}: {}".format(k, v))
25 |
26 | return args
27 |
28 | def mesh_fix(ms, dirname, meshname):
29 | org_vm = ms.current_mesh().vertex_matrix()
30 | org_fm = ms.current_mesh().face_matrix()
31 | meshfix = mf.MeshFix(org_vm, org_fm)
32 | meshfix.repair()
33 | meshfix.save("{}/tmp/{}_fixed.ply".format(dirname, meshname))
34 |
35 | def remesh(ms, dirname, meshname, targetlen):
36 | ms.load_new_mesh("{}/tmp/{}_fixed.ply".format(dirname, meshname))
37 | # [original, fixed]
38 |
39 | ms.apply_filter("remeshing_isotropic_explicit_remeshing", targetlen=targetlen)
40 | # [original, remeshed]
41 |
42 | ms.save_current_mesh("{}/tmp/{}_remeshed.obj".format(dirname, meshname))
43 |
44 | def normalize(ms):
45 | ms.apply_filter("transform_scale_normalize", scalecenter="barycenter", unitflag=True, alllayers=True)
46 | ms.apply_filter("transform_translate_center_set_origin", traslmethod="Center on Layer BBox", alllayers=True)
47 |
48 | def edge_based_scaling(mesh):
49 | edge_vec = mesh.vs[mesh.edges][:, 0, :] - mesh.vs[mesh.edges][:, 1, :]
50 | ave_len = np.sum(np.linalg.norm(edge_vec, axis=1)) / mesh.edges.shape[0]
51 | mesh.vs /= ave_len
52 | return ave_len, mesh
53 |
54 | def normalize_scale(ms, dirname, meshname):
55 | try:
56 | ms.load_new_mesh("{}/{}_gt.obj".format(dirname, meshname))
57 | ms.save_current_mesh("{}/tmp/{}_gt.obj".format(dirname, meshname))
58 | # [original, remeshed, gt<-current]
59 | except:
60 | pass
61 | ms.set_current_mesh(1)
62 | # [original, remeshed<-current, gt]
63 | normalize(ms)
64 |
65 | ms.save_current_mesh("{}/tmp/{}_initial.obj".format(dirname, meshname))
66 | ms.set_current_mesh(0)
67 | ms.save_current_mesh("{}/tmp/{}_original.obj".format(dirname, meshname))
68 | try:
69 | ms.set_current_mesh(2)
70 | ms.save_current_mesh("{}/{}_gt.obj".format(dirname, meshname))
71 | except:
72 | pass
73 | ms.set_current_mesh(1)
74 |
75 | init_mesh = Mesh("{}/tmp/{}_initial.obj".format(dirname, meshname), build_mat=False)
76 | org_mesh = Mesh("{}/tmp/{}_original.obj".format(dirname, meshname), manifold=False)
77 | try:
78 | gt_mesh = Mesh("{}/{}_gt.obj".format(dirname, meshname), manifold=False)
79 | except:
80 | pass
81 | ave_len, init_mesh = edge_based_scaling(init_mesh)
82 | org_mesh.vs /= ave_len
83 | Mesh.save(init_mesh, "{}/{}_initial.obj".format(dirname, meshname))
84 | init_mesh.path = "{}/{}_initial.obj".format(dirname, meshname)
85 | torch.save(init_mesh, "{}/{}_initial.pt".format(dirname, meshname))
86 | Mesh.save(org_mesh, "{}/{}_original.obj".format(dirname, meshname), color=True)
87 | try:
88 | gt_mesh.vs /= ave_len
89 | Mesh.save(gt_mesh, "{}/{}_gt.obj".format(dirname, meshname))
90 | except:
91 | pass
92 |
93 | def write_mask(ms, epsilon, dirname, meshname):
94 | ms.clear()
95 | ms.load_new_mesh("{}/{}_original.obj".format(dirname, meshname))
96 | ms.load_new_mesh("{}/{}_initial.obj".format(dirname, meshname))
97 |
98 | ms.apply_filter("distance_from_reference_mesh", measuremesh=ms.current_mesh_id(), refmesh=0, signeddist=False)
99 | quality = ms.current_mesh().vertex_quality_array()
100 | mask_vs = quality < epsilon
101 |
102 | with open("{}/{}_vmask.json".format(dirname, meshname), "w") as vm:
103 | json.dump(mask_vs.tolist(), vm)
104 |
105 | new_vs = ms.current_mesh().vertex_matrix()[np.logical_not(mask_vs)]
106 | with open("{}/{}_inserted.obj".format(dirname, meshname), "w") as f_obj:
107 | for v in new_vs:
108 | print("v {} {} {}".format(v[0], v[1], v[2]), file=f_obj)
109 |
110 | def smooth(ms, dirname, meshname):
111 | # [original, normalized, initial]
112 | ms.apply_filter("laplacian_smooth", selected=True)
113 | ms.apply_filter("laplacian_smooth", stepsmoothnum=SMOOTH_ITER, cotangentweight=False, selected=False)
114 | ms.save_current_mesh("{}/{}_smooth.obj".format(dirname, meshname))
115 |
116 | def main():
117 | args = get_parse()
118 |
119 | filename = args.input
120 | targetlen = ml.Percentage(args.remesh)
121 | dirname = os.path.dirname(filename)
122 | meshname = dirname.split("/")[-1]
123 |
124 | ms = ml.MeshSet()
125 | ms.load_new_mesh(filename)
126 | # [original]
127 |
128 | os.makedirs("{}/tmp".format(dirname), exist_ok=True)
129 | print("[MeshFix]")
130 | mesh_fix(ms, dirname, meshname)
131 | print("[Remeshing]")
132 | remesh(ms, dirname, meshname, targetlen)
133 | print("[Normalizing scale]")
134 | normalize_scale(ms, dirname, meshname)
135 | write_mask(ms, EPSILON, dirname, meshname)
136 | print("[Smoothing]")
137 | smooth(ms, dirname, meshname)
138 | print("[FINISHED]")
139 | ms.clear()
140 |
141 |
142 | if __name__ == "__main__":
143 | main()
--------------------------------------------------------------------------------
/refinement.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import json
4 | import argparse
5 |
6 | from util.mesh import Mesh
7 | import util.models as Models
8 | import util.datamaker as Datamaker
9 |
10 | def get_parse():
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument("-src", type=str, required=True)
13 | parser.add_argument("-dst", type=str, required=True)
14 | parser.add_argument("-vm", type=str, required=True)
15 | parser.add_argument("-ref", type=str, required=True)
16 | parser.add_argument("-mu", type=float, default=1.0)
17 | args = parser.parse_args()
18 |
19 | for k, v in vars(args).items():
20 | print("{:12s}: {}".format(k, v))
21 |
22 | return args
23 |
24 | def main():
25 | args = get_parse()
26 | dst_path = args.dst
27 | src_path = args.src
28 | vm_path = args.vm
29 | ref_path = args.ref
30 | dst_mesh = Mesh(dst_path, build_mat=False)
31 | src_mesh = Mesh(src_path)
32 | with open(vm_path, "r") as f:
33 | v_mask = np.array(json.load(f))
34 | v_mask = torch.from_numpy(v_mask).reshape(-1)
35 | dst_pos = torch.from_numpy(dst_mesh.vs).float()
36 | ref_pos = Mesh.mesh_merge(src_mesh.Lap, src_mesh, dst_pos, v_mask, w=args.mu)
37 | dst_mesh.vs = ref_pos.detach().numpy().copy()
38 | Mesh.save(dst_mesh, ref_path)
39 |
40 |
41 | if __name__ == "__main__":
42 | main()
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cython
2 | numpy
3 | matplotlib==3.3
4 | open3d
5 | scikit-learn==1.0.2
6 | pymeshlab==2021.10
7 | pymeshfix==0.16.1
8 | tqdm==4.61.1
9 | wandb==0.12.1
10 | opencv-python
11 | protobuf==3.20
12 | viser==0.2.1
13 | torch==1.13.0
14 | --find-links https://pytorch-geometric.com/whl/torch-1.13.0+cu117.html
15 | torch-scatter==2.1.0
16 | torch-sparse==0.6.16
17 | torch-cluster==1.6.0
18 | torch-spline-conv==1.2.1
19 | torch-geometric==2.2.0
--------------------------------------------------------------------------------
/sgcn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import datetime
5 | import argparse
6 | import os
7 | import copy
8 | import random
9 | import viser
10 | from tqdm import tqdm
11 |
12 | import util.loss as Loss
13 | import util.models as Models
14 | import util.datamaker as Datamaker
15 | from util.mesh import Mesh
16 | from util.networks import SingleScaleGCN
17 | import check.dist_check as DIST
18 |
19 | def torch_fix_seed(seed=314):
20 | random.seed(seed)
21 | np.random.seed(seed)
22 | torch.manual_seed(seed)
23 | torch.cuda.manual_seed(seed)
24 | torch.backends.cudnn.deterministic = True
25 | torch.use_deterministic_algorithms(True, warn_only=True)  # warn rather than error on ops without deterministic implementations
26 |
27 | def get_parser():
28 | parser = argparse.ArgumentParser(description="Self-supervised Mesh Completion")
29 | parser.add_argument("-i", "--input", type=str, required=True)
30 | parser.add_argument("-pos_lr", type=float, default=0.01)
31 | parser.add_argument("-iter", type=int, default=100)
32 | parser.add_argument("-iter_refine", type=int, default=1000)
33 | parser.add_argument("-k1", type=float, default=4.0)
34 | parser.add_argument("-k2", type=float, default=4.0)
35 | parser.add_argument("-dm_size", type=int, default=40)
36 | parser.add_argument("-net", type=str, default="single")
37 | parser.add_argument("-activation", type=str, default="lrelu")
38 | parser.add_argument("-ant", type=int, default=0)
39 | parser.add_argument("-kn", type=int, nargs="*", default=[4])
40 | parser.add_argument("-batch", type=int, default=5)
41 | parser.add_argument("-skip", action="store_true")
42 | parser.add_argument("-drop", type=int, default=1)
43 | parser.add_argument("-drop_rate", type=float, default=0.6)
44 | parser.add_argument("-rot", type=int, default=0)
45 | parser.add_argument("-gpu", type=int, default=0)
46 | parser.add_argument("-cache", action="store_true")
47 | parser.add_argument("-CAD", action="store_true")
48 | parser.add_argument("-real", action="store_true")
49 | parser.add_argument("-mu", type=float, default=1.0)
50 | parser.add_argument("-viewer", action="store_true", default=True)
51 | parser.add_argument("-port", type=int, default=8081)
52 | args = parser.parse_args()
53 |
54 | for k, v in vars(args).items():
55 | print("{:12s}: {}".format(k, v))
56 |
57 | return args
58 |
59 | def main():
60 | args = get_parser()
61 |
62 | if args.viewer:
63 | server = viser.ViserServer(port=args.port)
64 |
65 | mesh_dic, dataset = Datamaker.create_dataset(args.input, dm_size=args.dm_size, kn=args.kn, cache=args.cache)
66 | dataset_rot = copy.deepcopy(dataset)
67 | ini_file, smo_file, v_mask, f_mask, mesh_name = mesh_dic["ini_file"], mesh_dic["smo_file"], mesh_dic["v_mask"], mesh_dic["f_mask"], mesh_dic["mesh_name"]
68 | org_mesh, ini_mesh, smo_mesh, out_mesh = mesh_dic["org_mesh"], mesh_dic["ini_mesh"], mesh_dic["smo_mesh"], mesh_dic["out_mesh"]
69 | rot_mesh = copy.deepcopy(ini_mesh)
70 | dt_now = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
71 |
72 | vmask_dummy = mesh_dic["vmask_dummy"]
73 | fmask_dummy = mesh_dic["fmask_dummy"]
74 |
75 | """ --- create model instance --- """
76 | torch_fix_seed()
77 | device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
78 | posnet = SingleScaleGCN(device, skip=args.skip).to(device)
79 | optimizer_pos = torch.optim.Adam(posnet.parameters(), lr=args.pos_lr)
80 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer_pos, step_size=50, gamma=0.5)
81 |
82 | os.makedirs("{}/output/{}_sgcn".format(args.input, dt_now), exist_ok=True)
83 | scale = 2 / np.max(org_mesh.vs)
84 | if args.viewer:
85 | print("\n\033[42m Viewer at: http://localhost:{} \033[0m\n".format(args.port))
86 | with server.gui.add_folder("Training"):
87 | server.scene.add_mesh_simple(
88 | name="/input",
89 | vertices=org_mesh.vs * scale,
90 | faces=org_mesh.faces,
91 | flat_shading=True,
92 | visible=False,
93 | )
94 | server.scene.add_mesh_simple(
95 | name="/initial",
96 | vertices=ini_mesh.vs * scale,
97 | faces=ini_mesh.faces,
98 | flat_shading=True,
99 | visible=False,
100 | )
101 | gui_counter = server.gui.add_number(
102 | "Epoch",
103 | initial_value=0,
104 | disabled=True,
105 | )
106 |
107 | """ --- learning loop --- """
108 | with tqdm(total=args.iter) as pbar:
109 | """ --- training --- """
110 | for epoch in range(1, args.iter+1):
111 | n_data = vmask_dummy.shape[1]
112 | batch_index = torch.randperm(n_data).reshape(-1, args.batch)
113 | epoch_loss_p = 0.0
114 | epoch_loss_n = 0.0
115 | epoch_loss_r = 0.0
116 | epoch_loss = 0.0
117 |
118 | for batch in batch_index:
119 | dm_batch = vmask_dummy[:, batch]
120 | posnet.train()
121 | optimizer_pos.zero_grad()
122 |
123 | for i, b in enumerate(batch):
124 | dm = dm_batch[:, i].reshape(-1, 1)
125 | rm = v_mask.reshape(-1, 1).float()
126 | dm = rm * dm
127 | ini_vs = ini_mesh.vs
128 |
129 | pos = posnet(dataset_rot, dm)
130 | norm = Models.compute_fn(pos, ini_mesh.faces)
131 | loss_p = Loss.mask_pos_rec_loss(pos, ini_vs, v_mask)
132 | loss_n = Loss.mask_norm_rec_loss(norm, rot_mesh.fn, f_mask)
133 | if args.CAD:
134 | loss_bnf, _ = Loss.fn_bnf_detach_loss(pos, norm, ini_mesh, loop=5)
135 | loss = loss_p + args.k1 * loss_n + args.k2 * loss_bnf
136 | epoch_loss_r += loss_bnf.item()
137 | else:
138 | #loss_lap = Loss.mesh_laplacian_loss(pos, ini_mesh)
139 | loss = loss_p + args.k1 * loss_n#+ 0.0 * loss_lap
140 | epoch_loss_p += loss_p.item()
141 | epoch_loss_n += loss_n.item()
142 |
143 | loss.backward()
144 | epoch_loss += loss.item()
145 |
146 | optimizer_pos.step()
147 | scheduler.step()
148 |
149 | epoch_loss_p /= n_data
150 | epoch_loss_n /= n_data
151 | epoch_loss_r /= n_data
152 | epoch_loss /= n_data
153 |
154 | pbar.set_description("Epoch {}".format(epoch))
155 | pbar.set_postfix({"loss": epoch_loss})
156 |
157 | """ --- test --- """
158 | if epoch == args.iter:
159 | out_path = "{}/output/{}_sgcn/{}_step.obj".format(args.input, dt_now, str(epoch))
160 | out_mesh.vs = pos.to("cpu").detach().numpy().copy()
161 | Mesh.save(out_mesh, out_path)
162 |
163 | if epoch % 10 == 0:
164 | out_path = "{}/output/{}_sgcn/{}_step.obj".format(args.input, dt_now, str(epoch))
165 |
166 | posnet.eval()
167 | dm = v_mask.reshape(-1, 1).float()
168 | pos = posnet(dataset, dm)
169 | out_mesh.vs = pos.to("cpu").detach().numpy().copy()
170 | Mesh.save(out_mesh, out_path)
171 | if args.viewer:
172 | server.scene.add_mesh_simple(
173 | name="/output",
174 | vertices=out_mesh.vs * scale,
175 | faces=out_mesh.faces,
176 | flat_shading=True,
177 | )
178 | if args.viewer:
179 | gui_counter.value = epoch
180 |
181 | pbar.update(1)
182 |
183 | DIST.mesh_distance(mesh_dic["gt_file"], mesh_dic["org_file"], out_path, args.real)
184 |
185 | """ refinement """
186 | posnet.eval()
187 | dm = v_mask.reshape(-1, 1).float()
188 | out_pos = posnet(dataset, dm).to("cpu").detach()
189 | ref_pos = Mesh.mesh_merge(ini_mesh.Lap, ini_mesh, out_pos, v_mask, w=args.mu)
190 | out_path = "{}/output/{}_sgcn/refine.obj".format(args.input, dt_now)
191 | out_mesh.vs = ref_pos.detach().numpy().copy()
192 | Mesh.save(out_mesh, out_path)
193 | DIST.mesh_distance(mesh_dic["gt_file"], mesh_dic["org_file"], out_path, args.real)
194 |
195 | if __name__ == "__main__":
196 | main()
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/astaka-pe/SeMIGCN/fe17377db287298c77f325eeeeec5285b1c0f626/util/__init__.py
--------------------------------------------------------------------------------
/util/datamaker.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import numpy as np
3 | import json
4 | import os
5 | from tqdm import tqdm
6 | import torch
7 | import copy
8 | from .mesh import Mesh
9 | from torch_geometric.data import Data
10 | from typing import Tuple
11 | from pathlib import Path
12 |
13 | class Dataset:
14 | def __init__(self, data):
15 | self.keys = data.keys
16 | self.num_nodes = data.num_nodes
17 | self.num_edges = data.num_edges
18 | self.num_node_features = data.num_node_features
19 | self.contains_isolated_nodes = data.has_isolated_nodes()
20 | self.contains_self_loops = data.has_self_loops()
21 | self.z1 = data['z1']
22 | self.x_pos = data['x_pos']
23 | self.x_norm = data['x_norm']
24 | self.edge_index = data['edge_index']
25 | self.face_index = data['face_index']
26 |
27 |
28 | def create_dataset(file_path: str, dm_size=40, kn=[1], cache=False) -> Tuple[dict, Dataset]:
29 | """ create mesh """
30 | mesh_dic = {}
31 | file_path = str(Path(file_path))
32 | mesh_name = Path(file_path).name
33 | ini_file = "{}/{}_initial.obj".format(file_path, mesh_name)
34 | smo_file = "{}/{}_smooth.obj".format(file_path, mesh_name)
35 | gt_file = "{}/{}_gt.obj".format(file_path, mesh_name)
36 | org_file = "{}/{}_original.obj".format(file_path, mesh_name)
37 | vmask_file = "{}/{}_vmask.json".format(file_path, mesh_name)
38 | fmask_file = "{}/{}_fmask.json".format(file_path, mesh_name)
39 |
40 | print("[Loading meshes...]")
41 | try:
42 | ini_mesh = torch.load("{}/{}_initial.pt".format(file_path, mesh_name))
43 | out_mesh = torch.load("{}/{}_initial.pt".format(file_path, mesh_name))
44 | except:
45 | ini_mesh = Mesh(ini_file, build_mat=False)
46 | out_mesh = copy.deepcopy(ini_mesh)
47 | torch.save(ini_mesh, "{}/{}_initial.pt".format(file_path, mesh_name))
48 | org_mesh = Mesh(org_file, manifold=False)
49 | smo_mesh = Mesh(smo_file, manifold=False)
50 | Mesh.copy_attribute(ini_mesh, smo_mesh)
51 |
52 | try:
53 | with open(vmask_file, "r") as vm:
54 | v_mask = np.array(json.load(vm))
55 | except:
56 | v_mask = None
57 |
58 | try:
59 | with open(fmask_file, "r") as fm:
60 | f_mask = np.array(json.load(fm))
61 | color_mask(ini_mesh, f_mask)
62 | except:
63 | if type(v_mask)==np.ndarray:
64 | f_mask = vmask_to_fmask(ini_mesh, v_mask)
65 | color_mask(ini_mesh, f_mask)
66 | else:
67 | f_mask = None
68 |
69 | """ create graph """
70 | z1 = ini_mesh.vs - smo_mesh.vs
71 | z1 = torch.tensor(z1, dtype=torch.float, requires_grad=True)
72 |
73 | x_pos = torch.tensor(smo_mesh.vs, dtype=torch.float)
74 | x_norm = torch.tensor(ini_mesh.fn, dtype=torch.float)
75 |
76 | edge_index = torch.tensor(ini_mesh.edges.T, dtype=torch.long)
77 | edge_index = torch.cat([edge_index, edge_index[[1,0],:]], dim=1)
78 | face_index = torch.from_numpy(ini_mesh.f_edges)
79 |
80 | mesh_dic["ini_file"] = ini_file
81 | mesh_dic["smo_file"] = smo_file
82 | mesh_dic["gt_file"] = gt_file
83 | mesh_dic["org_file"] = org_file
84 | mesh_dic["v_mask"] = torch.from_numpy(v_mask).bool()
85 | mesh_dic["f_mask"] = f_mask
86 | mesh_dic["mesh_name"] = mesh_name
87 | mesh_dic["org_mesh"] = org_mesh
88 | mesh_dic["ini_mesh"] = ini_mesh
89 | mesh_dic["out_mesh"] = out_mesh
90 | mesh_dic["smo_mesh"] = smo_mesh
91 | mesh_dic["vmask_dummy"] = None
92 | mesh_dic["fmask_dummy"] = None
93 |
94 | os.makedirs("{}/dummy_mask/".format(file_path), exist_ok=True)
95 | if cache:
96 | mesh_dic["vmask_dummy"] = torch.load("{}/dummy_mask/vmask_dummy.pt".format(file_path))
97 | mesh_dic["fmask_dummy"] = torch.load("{}/dummy_mask/fmask_dummy.pt".format(file_path))
98 | else:
99 | print("[Creating synthetic occlusion...]")
100 | mesh_dic["vmask_dummy"], mesh_dic["fmask_dummy"] = make_dummy_mask(ini_mesh, dm_size=dm_size, kn=kn, exist_face=f_mask)
101 | torch.save(mesh_dic["vmask_dummy"], "{}/dummy_mask/vmask_dummy.pt".format(file_path))
102 | torch.save(mesh_dic["fmask_dummy"], "{}/dummy_mask/fmask_dummy.pt".format(file_path))
103 |
104 | """ create dataset """
105 | data = Data(x=z1, z1=z1, x_pos=x_pos, x_norm=x_norm, edge_index=edge_index, face_index=face_index, v_mask=mesh_dic["v_mask"], f_mask=mesh_dic["f_mask"], vmask_dummy=mesh_dic["v_mask"], fmask_dummy=mesh_dic["f_mask"])
106 | dataset = Dataset(data)
107 | return mesh_dic, dataset
108 |
109 |
110 | def make_dummy_mask(mesh, dm_size=40, kn=[3, 4, 5], exist_face=None):
111 | valid_idx = []
112 | for i in kn:
113 | valid_idx.extend(np.arange(dm_size*i, dm_size*(i+1)).tolist())
114 |
115 | AI = mesh.AdjI.float()
116 | # p_list = torch.tensor([0.3, 0.06, 0.02, 0.02, 0.014, 0.008, 0.008, 0.0002, 0.001])
117 | #p_list = torch.tensor([0.007, 0.007, 0.007, 0.007, 0.007, 0.007, 0.007, 0.0007, 0.007])
118 | p_list = torch.tensor([0.014, 0.014, 0.014, 0.014, 0.014, 0.014, 0.014, 0.0014, 0.014])
119 | #p_list = torch.tensor([0.021, 0.021, 0.021, 0.021, 0.021, 0.021, 0.021, 0.0021, 0.021])
120 | #p_list = torch.tensor([0.3, 0.06, 0.04, 0.02, 0.006, 0.008, 0.004, 0.0002, 0.001])
121 | vmask = torch.tensor([])
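# for each k, sample seed vertices with probability p_list[k] and dilate the seeds k times
# over the vertex adjacency AI; the dilated regions are set to 0 (masked) in vmask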
122 | bar = tqdm(total=sum(kn))
123 | for k in kn:
124 | Mv0 = np.random.binomial(1, p_list[k], size=[len(mesh.vs), dm_size])
125 | Mv0 = torch.from_numpy(Mv0).float()
126 | for _ in range(k):
127 | Mv1 = (torch.sparse.mm(AI, Mv0) > 0).float()
128 | Mv0 = Mv1
129 | bar.update(1)
130 |
131 | if len(vmask) == 0:
132 | vmask = 1.0 - Mv0
133 | else:
134 | vmask = torch.cat([vmask, 1.0 - Mv0], dim=1)
135 |
136 | fmask = (torch.sparse.mm(mesh.f2v_mat, 1.0 - vmask) == 0).float()
137 | """ write the masked meshes """
138 | for i, k in enumerate(kn):
139 | color = np.ones([len(mesh.faces), 3])
140 | color[:, 0] = 0.332 # 0.75
141 | color[:, 1] = 0.664 # 0.75
142 | color[:, 2] = 1.0 # 0.75
143 | black = fmask[:, dm_size*i] == 0
144 | color[black, 0] = 1.0 # 0
145 | color[black, 1] = 0.664 # 0
146 | color[black, 2] = 0.0 # 0
147 | color[exist_face==0, 0] = 1.0 # 0
148 | color[exist_face==0, 1] = 0.0 # 0.5
149 | color[exist_face==0, 2] = 1.0 # 1
150 | dropv = 100 * torch.sum(vmask[:, dm_size*i] == 0) // len(mesh.vs)
151 | filename = "{}/dummy_mask/{}-neighbor-{}per.ply".format(os.path.dirname(mesh.path), k, dropv)
152 | mesh.save_as_ply(filename, color)
153 |
154 | return vmask, fmask
155 |
156 | def vmask_to_fmask(mesh, vmask):
157 | vmask = torch.from_numpy(vmask).reshape(-1, 1).float()
158 | fmask = (torch.sparse.mm(mesh.f2v_mat, 1.0 - vmask) == 0).bool().reshape(-1)
159 | return fmask
160 |
161 | def color_mask(mesh, fmask):
162 | color = np.ones([len(mesh.faces), 3])
163 | color[:, 0] = 0.332
164 | color[:, 1] = 0.664
165 | color[:, 2] = 1.0
166 | color[fmask==0, 0] = 1.0
167 | color[fmask==0, 1] = 0.0
168 | color[fmask==0, 2] = 1.0
169 | os.makedirs("{}/dummy_mask/".format(os.path.dirname(mesh.path)), exist_ok=True)
170 | filename = "{}/dummy_mask/initial.ply".format(os.path.dirname(mesh.path))
171 | mesh.save_as_ply(filename, color)
--------------------------------------------------------------------------------
/util/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from util.mesh import Mesh
5 | from typing import Union
6 |
7 |
8 | def squared_norm(x, dim=None, keepdim=False):
9 | return torch.sum(x * x, dim=dim, keepdim=keepdim)
10 |
11 | def norm(x, eps=1.0e-12, dim=None, keepdim=False):
12 | return torch.sqrt(squared_norm(x, dim=dim, keepdim=keepdim) + eps)
13 |
14 | def mask_pos_rec_loss(pred_pos: torch.Tensor, real_pos: Union[torch.Tensor, np.ndarray], mask: np.ndarray, ltype="rmse") -> torch.Tensor:
15 | """ reconstructuion error for vertex positions """
16 | if type(real_pos) == np.ndarray:
17 | real_pos = torch.from_numpy(real_pos)
18 | real_pos = real_pos.to(pred_pos.device)
19 |
20 | if ltype == "l1mae":
21 | diff_pos = torch.sum(torch.abs(real_pos[mask] - pred_pos[mask]), dim=1)
22 | loss = torch.sum(diff_pos) / len(diff_pos)
23 |
24 | elif ltype == "rmse":
25 | diff_pos = torch.abs(real_pos[mask] - pred_pos[mask])
26 | diff_pos = diff_pos ** 2
27 | diff_pos = torch.sum(diff_pos.squeeze(), dim=1)
28 | diff_pos = torch.sum(diff_pos) / len(diff_pos)
29 | loss = torch.sqrt(diff_pos + 1.0e-6)
30 | else:
31 | print("[ERROR]: ltype error")
32 | exit()
33 |
34 | return loss
35 |
36 | def pos_rec_loss(pred_pos: Union[torch.Tensor, np.ndarray], real_pos: np.ndarray, ltype="rmse") -> torch.Tensor:
37 | """ reconstructuion error for vertex positions """
38 | if type(pred_pos) == np.ndarray:
39 | pred_pos = torch.from_numpy(pred_pos)
40 | if type(real_pos) == np.ndarray:
41 | real_pos = torch.from_numpy(real_pos)
42 |
43 | real_pos = real_pos.to(pred_pos.device)
44 |
45 | if ltype == "l1mae":
46 | diff_pos = torch.sum(torch.abs(real_pos - pred_pos), dim=1)
47 | loss = torch.sum(diff_pos) / len(diff_pos)
48 |
49 | elif ltype == "rmse":
50 | diff_pos = torch.abs(real_pos - pred_pos)
51 | diff_pos = diff_pos ** 2
52 | diff_pos = torch.sum(diff_pos.squeeze(), dim=1)
53 | diff_pos = torch.sum(diff_pos) / len(diff_pos)
54 | loss = torch.sqrt(diff_pos + 1.0e-6)
55 | else:
56 | print("[ERROR]: ltype error")
57 | exit()
58 | return loss
59 |
60 | def mesh_laplacian_loss(pred_pos: torch.Tensor, mesh: Mesh, ltype="rmse") -> torch.Tensor:
61 | """ simple laplacian for output meshes """
62 | v2v = mesh.Adj.to(pred_pos.device)
63 | v_dims = mesh.v_dims.reshape(-1, 1).to(pred_pos.device)
64 | lap_pos = torch.sparse.mm(v2v, pred_pos) / v_dims
65 | lap_diff = torch.sum((pred_pos - lap_pos) ** 2, dim=1)
66 | if ltype == "mae":
67 | lap_diff = torch.sqrt(lap_diff + 1.0e-12)
68 | lap_loss = torch.sum(lap_diff) / len(lap_diff)
69 | elif ltype == "rmse":
70 | lap_loss = torch.sum(lap_diff) / len(lap_diff)
71 | lap_loss = torch.sqrt(lap_loss + 1.0e-12)
72 | else:
73 | print("[ERROR]: ltype error")
74 | exit()
75 |
76 | return lap_loss
77 |
78 | def mask_norm_rec_loss(pred_norm: Union[torch.Tensor, np.ndarray], real_norm: Union[torch.Tensor, np.ndarray], mask: np.ndarray, ltype="l1mae") -> torch.Tensor:
79 | """ reconstruction loss for (vertex, face) normal """
80 | if type(pred_norm) == np.ndarray:
81 | pred_norm = torch.from_numpy(pred_norm)
82 | if type(real_norm) == np.ndarray:
83 | real_norm = torch.from_numpy(real_norm).to(pred_norm.device)
84 |
85 | if ltype == "l2mae":
86 | norm_diff = torch.sum((pred_norm[mask] - real_norm[mask]) ** 2, dim=1)
87 | loss = torch.sqrt(norm_diff + 1e-12)
88 | loss = torch.sum(loss) / len(loss)
89 | elif ltype == "l1mae":
90 | norm_diff = torch.sum(torch.abs(pred_norm[mask] - real_norm[mask]), dim=1)
91 | loss = torch.sum(norm_diff) / len(norm_diff)
92 | elif ltype == "l2rmse":
93 | norm_diff = torch.sum((pred_norm[mask] - real_norm[mask]) ** 2, dim=1)
94 | loss = torch.sum(norm_diff) / len(norm_diff)
95 | loss = torch.sqrt(loss + 1e-12)
96 | elif ltype == "l1rmse":
97 | norm_diff = torch.sum(torch.abs(pred_norm[mask] - real_norm[mask]), dim=1)
98 | loss = torch.sum(norm_diff ** 2) / len(norm_diff)
99 | loss = torch.sqrt(loss + 1e-12)
100 | elif ltype == "cos":
101 | cos_loss = 1.0 - torch.sum(torch.mul(pred_norm[mask], real_norm[mask]), dim=1)
102 | loss = torch.sum(cos_loss, dim=0) / len(cos_loss)
103 | else:
104 | print("[ERROR]: ltype error")
105 | exit()
106 |
107 | return loss
108 |
109 | def norm_rec_loss(pred_norm: Union[torch.Tensor, np.ndarray], real_norm: Union[torch.Tensor, np.ndarray], ltype="l1mae") -> torch.Tensor:
110 | """ reconstruction loss for (vertex, face) normal """
111 | if type(pred_norm) == np.ndarray:
112 | pred_norm = torch.from_numpy(pred_norm)
113 | if type(real_norm) == np.ndarray:
114 | real_norm = torch.from_numpy(real_norm).to(pred_norm.device)
115 |
116 | if ltype == "l2mae":
117 | norm_diff = torch.sum((pred_norm - real_norm) ** 2, dim=1)
118 | loss = torch.sqrt(norm_diff + 1e-12)
119 | loss = torch.sum(loss) / len(loss)
120 | elif ltype == "l1mae":
121 | norm_diff = torch.sum(torch.abs(pred_norm - real_norm), dim=1)
122 | loss = torch.sum(norm_diff) / len(norm_diff)
123 | elif ltype == "l2rmse":
124 | norm_diff = torch.sum((pred_norm - real_norm) ** 2, dim=1)
125 | loss = torch.sum(norm_diff) / len(norm_diff)
126 | loss = torch.sqrt(loss + 1e-12)
127 | elif ltype == "l1rmse":
128 | norm_diff = torch.sum(torch.abs(pred_norm - real_norm), dim=1)
129 | loss = torch.sum(norm_diff ** 2) / len(norm_diff)
130 | loss = torch.sqrt(loss + 1e-12)
131 | elif ltype == "cos":
132 | cos_loss = 1.0 - torch.sum(torch.mul(pred_norm, real_norm), dim=1)
133 | loss = torch.sum(cos_loss, dim=0) / len(cos_loss)
134 | else:
135 | print("[ERROR]: ltype error")
136 | exit()
137 |
138 | return loss
139 |
140 | def fn_bnf_loss(pos: torch.Tensor, fn: torch.Tensor, mesh: Mesh, ltype="l1mae", loop=5) -> tuple[torch.Tensor, torch.Tensor]:
141 | """ bilateral loss for face normal """
142 | if type(pos) == np.ndarray:
143 | pos = torch.from_numpy(pos).to(fn.device)
144 | else:
145 | pos = pos.detach()
146 | fc = torch.sum(pos[mesh.faces], 1) / 3.0
147 | fa = torch.cross(pos[mesh.faces[:, 1]] - pos[mesh.faces[:, 0]], pos[mesh.faces[:, 2]] - pos[mesh.faces[:, 0]])
148 | fa = 0.5 * torch.sqrt(torch.sum(fa**2, axis=1) + 1.0e-12)
149 |
150 | #fc = torch.from_numpy(mesh.fc).float().to(fn.device)
151 | #fa = torch.from_numpy(mesh.fa).float().to(fn.device)
152 | f2f = torch.from_numpy(mesh.f2f).long().to(fn.device)
153 | no_neig = 1.0 * (f2f != -1)
154 |
155 | neig_fc = fc[f2f]
156 | neig_fa = fa[f2f] * no_neig
157 | fc0_tile = fc.reshape(-1, 1, 3)
158 | fc_dist = squared_norm(neig_fc - fc0_tile, dim=2)
159 | sigma_c = torch.sum(torch.sqrt(fc_dist + 1.0e-12)) / (fc_dist.shape[0] * fc_dist.shape[1])
160 | #sigma_c = 1.0
161 |
162 | new_fn = fn
163 | for i in range(loop):
164 | neig_fn = new_fn[f2f]
165 | fn0_tile = new_fn.reshape(-1, 1, 3)
166 | fn_dist = squared_norm(neig_fn - fn0_tile, dim=2)
167 | sigma_s = 0.3
168 | wc = torch.exp(-1.0 * fc_dist / (2 * (sigma_c ** 2)))
169 | ws = torch.exp(-1.0 * fn_dist / (2 * (sigma_s ** 2)))
170 |
171 | W = torch.stack([wc*ws*neig_fa, wc*ws*neig_fa, wc*ws*neig_fa], dim=2)
172 |
173 | new_fn = torch.sum(W * neig_fn, dim=1)
174 | new_fn = new_fn / (norm(new_fn, dim=1, keepdim=True) + 1.0e-12)
175 |
176 | if ltype == "mae":
177 | bnf_diff = torch.sum((new_fn - fn) ** 2, dim=1)
178 | bnf_diff = torch.sqrt(bnf_diff + 1.0e-12)
179 | loss = torch.sum(bnf_diff) / len(bnf_diff)
180 | elif ltype == "l1mae":
181 | bnf_diff = torch.sum(torch.abs(new_fn - fn), dim=1)
182 | loss = torch.sum(bnf_diff) / len(bnf_diff)
183 | elif ltype == "rmse":
184 | bnf_diff = torch.sum((new_fn - fn) ** 2, dim=1)
185 | loss = torch.sum(bnf_diff) / len(bnf_diff)
186 | loss = torch.sqrt(loss + 1.0e-12)
187 | elif ltype == "l1rmse":
188 | bnf_diff = torch.sum(torch.abs(new_fn - fn), dim=1)
189 | loss = torch.sum(bnf_diff ** 2) / len(bnf_diff)
190 | loss = torch.sqrt(loss + 1.0e-12)
191 | else:
192 | print("[ERROR]: ltype error")
193 | exit()
194 |
195 | return loss, new_fn
196 |
197 | def fn_bnf_detach_loss(pos: torch.Tensor, fn: torch.Tensor, mesh: Mesh, ltype="l1mae", loop=5) -> tuple[torch.Tensor, torch.Tensor]:
198 | """ bilateral loss for face normal """
199 | if type(pos) == np.ndarray:
200 | pos = torch.from_numpy(pos).to(fn.device)
201 | else:
202 | pos = pos.detach()
203 | fc = torch.sum(pos[mesh.faces], 1) / 3.0
204 | fa = torch.cross(pos[mesh.faces[:, 1]] - pos[mesh.faces[:, 0]], pos[mesh.faces[:, 2]] - pos[mesh.faces[:, 0]])
205 | fa = 0.5 * torch.sqrt(torch.sum(fa**2, axis=1) + 1.0e-12)
206 |
207 | #fc = torch.from_numpy(mesh.fc).float().to(fn.device)
208 | #fa = torch.from_numpy(mesh.fa).float().to(fn.device)
209 | f2f = torch.from_numpy(mesh.f2f).long().to(fn.device)
210 | no_neig = 1.0 * (f2f != -1)
211 |
212 | neig_fc = fc[f2f]
213 | neig_fa = fa[f2f] * no_neig
214 | fc0_tile = fc.reshape(-1, 1, 3)
215 | fc_dist = squared_norm(neig_fc - fc0_tile, dim=2)
216 | sigma_c = torch.sum(torch.sqrt(fc_dist + 1.0e-12)) / (fc_dist.shape[0] * fc_dist.shape[1])
217 | #sigma_c = 1.0
218 |
219 | new_fn = fn
220 | for i in range(loop):
221 | neig_fn = new_fn[f2f]
222 | fn0_tile = new_fn.reshape(-1, 1, 3)
223 | fn_dist = squared_norm(neig_fn - fn0_tile, dim=2)
224 | sigma_s = 0.3
225 | wc = torch.exp(-1.0 * fc_dist / (2 * (sigma_c ** 2)))
226 | ws = torch.exp(-1.0 * fn_dist / (2 * (sigma_s ** 2)))
227 |
228 | W = torch.stack([wc*ws*neig_fa, wc*ws*neig_fa, wc*ws*neig_fa], dim=2)
229 |
230 | new_fn = torch.sum(W * neig_fn, dim=1)
231 | new_fn = new_fn / (norm(new_fn, dim=1, keepdim=True) + 1.0e-12)
232 | new_fn = new_fn.detach()
233 |
234 | if ltype == "mae":
235 | bnf_diff = torch.sum((new_fn - fn) ** 2, dim=1)
236 | bnf_diff = torch.sqrt(bnf_diff + 1.0e-12)
237 | loss = torch.sum(bnf_diff) / len(bnf_diff)
238 | elif ltype == "l1mae":
239 | bnf_diff = torch.sum(torch.abs(new_fn - fn), dim=1)
240 | loss = torch.sum(bnf_diff) / len(bnf_diff)
241 | elif ltype == "rmse":
242 | bnf_diff = torch.sum((new_fn - fn) ** 2, dim=1)
243 | loss = torch.sum(bnf_diff) / len(bnf_diff)
244 | loss = torch.sqrt(loss + 1.0e-12)
245 | elif ltype == "l1rmse":
246 | bnf_diff = torch.sum(torch.abs(new_fn - fn), dim=1)
247 | loss = torch.sum(bnf_diff ** 2) / len(bnf_diff)
248 | loss = torch.sqrt(loss + 1.0e-12)
249 | else:
250 | print("[ERROR]: ltype error")
251 | exit()
252 |
253 | return loss, new_fn
--------------------------------------------------------------------------------
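A minimal usage sketch for the reconstruction and smoothness losses above (not part of the repository; the OBJ path and the 0.5 weight are placeholders):

```
import torch
from util.mesh import Mesh
from util.loss import pos_rec_loss, mesh_laplacian_loss

# Hypothetical input mesh; any manifold OBJ under datasets/ works the same way.
mesh = Mesh("datasets/example/example_original.obj", build_mat=False)
real_pos = torch.from_numpy(mesh.vs).float()
pred_pos = real_pos + 0.01 * torch.randn_like(real_pos)  # perturbed "prediction"

rec_loss = pos_rec_loss(pred_pos, real_pos, ltype="rmse")     # RMSE on vertex positions
lap_loss = mesh_laplacian_loss(pred_pos, mesh, ltype="rmse")  # uniform-Laplacian smoothness
total = rec_loss + 0.5 * lap_loss                             # 0.5 is an illustrative weight
```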
/util/mesh.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import torch
4 | from functools import reduce
5 | from collections import Counter
6 | import scipy as sp
7 | import heapq
8 | import copy
9 | from sklearn.preprocessing import normalize
10 |
11 | OPTIM_VALENCE = 6
12 | VALENCE_WEIGHT = 1
13 |
14 | class Mesh:
15 | def __init__(self, path, manifold=True, build_mat=True, build_code=False):
16 | self.path = path
17 | self.vs, self.vc, self.faces = self.fill_from_file(path)
18 | self.compute_face_normals()
19 | self.compute_face_center()
20 | self.device = 'cpu'
21 | self.simp = False
22 |
23 | if manifold:
24 | self.build_gemm() #self.edges, self.ve
25 | self.compute_vert_normals()
26 | self.build_v2v()
27 | self.build_vf()
28 | self.vs_code = None
29 | if build_mat:
30 | self.build_mesh_lap()
31 | if build_code:
32 | self.vs_code = self.eigen_decomposition(self.lapmat, k=512)
33 | self.fc_code = self.eigen_decomposition(self.f_lapmat, k=100)
34 |
35 | def fill_from_file(self, path):
36 | vs, faces, vc = [], [], []
37 | f = open(path)
38 | for line in f:
39 | line = line.strip()
40 | splitted_line = line.split()
41 | if not splitted_line:
42 | continue
43 | elif splitted_line[0] == 'v':
44 | vs.append([float(v) for v in splitted_line[1:4]])
45 | if len(splitted_line) == 7: # colored mesh
46 | vc.append([float(v) for v in splitted_line[4:7]])
47 | elif splitted_line[0] == 'f':
48 | face_vertex_ids = [int(c.split('/')[0]) for c in splitted_line[1:]]
49 | assert len(face_vertex_ids) == 3
50 | face_vertex_ids = [(ind - 1) if (ind >= 0) else (len(vs) + ind) for ind in face_vertex_ids]
51 | faces.append(face_vertex_ids)
52 | f.close()
53 | vs = np.asarray(vs)
54 | vc = np.asarray(vc)
55 | faces = np.asarray(faces, dtype=int)
56 |
57 | assert np.logical_and(faces >= 0, faces < len(vs)).all()
58 | return vs, vc, faces
59 |
60 | def build_gemm(self):
61 | self.ve = [[] for _ in self.vs]
62 | self.vei = [[] for _ in self.vs]
63 | edge_nb = []
64 | sides = []
65 | edge2key = dict()
66 | edges = []
67 | edges_count = 0
68 | nb_count = []
69 | for face_id, face in enumerate(self.faces):
70 | faces_edges = []
71 | for i in range(3):
72 | cur_edge = (face[i], face[(i + 1) % 3])
73 | faces_edges.append(cur_edge)
74 | for idx, edge in enumerate(faces_edges):
75 | edge = tuple(sorted(list(edge)))
76 | faces_edges[idx] = edge
77 | if edge not in edge2key:
78 | edge2key[edge] = edges_count
79 | edges.append(list(edge))
80 | edge_nb.append([-1, -1, -1, -1])
81 | sides.append([-1, -1, -1, -1])
82 | self.ve[edge[0]].append(edges_count)
83 | self.ve[edge[1]].append(edges_count)
84 | self.vei[edge[0]].append(0)
85 | self.vei[edge[1]].append(1)
86 | nb_count.append(0)
87 | edges_count += 1
88 | for idx, edge in enumerate(faces_edges):
89 | edge_key = edge2key[edge]
90 | edge_nb[edge_key][nb_count[edge_key]] = edge2key[faces_edges[(idx + 1) % 3]]
91 | edge_nb[edge_key][nb_count[edge_key] + 1] = edge2key[faces_edges[(idx + 2) % 3]]
92 | nb_count[edge_key] += 2
93 | for idx, edge in enumerate(faces_edges):
94 | edge_key = edge2key[edge]
95 | sides[edge_key][nb_count[edge_key] - 2] = nb_count[edge2key[faces_edges[(idx + 1) % 3]]] - 1
96 | sides[edge_key][nb_count[edge_key] - 1] = nb_count[edge2key[faces_edges[(idx + 2) % 3]]] - 2
97 | self.edges = np.array(edges, dtype=np.int64)
98 | self.gemm_edges = np.array(edge_nb, dtype=np.int64)
99 | self.sides = np.array(sides, dtype=np.int64)
100 | self.edges_count = edges_count
101 | # lots of DS for loss
102 | """
103 | self.nvs, self.nvsi, self.nvsin, self.ve_in = [], [], [], []
104 | for i, e in enumerate(self.ve):
105 | self.nvs.append(len(e))
106 | self.nvsi += len(e) * [i]
107 | self.nvsin += list(range(len(e)))
108 | self.ve_in += e
109 | self.vei = reduce(lambda a, b: a + b, self.vei, [])
110 | self.vei = torch.from_numpy(np.array(self.vei).ravel()).to(self.device).long()
111 | self.nvsi = torch.from_numpy(np.array(self.nvsi).ravel()).to(self.device).long()
112 | self.nvsin = torch.from_numpy(np.array(self.nvsin).ravel()).to(self.device).long()
113 | self.ve_in = torch.from_numpy(np.array(self.ve_in).ravel()).to(self.device).long()
114 |
115 | self.max_nvs = max(self.nvs)
116 | self.nvs = torch.Tensor(self.nvs).to(self.device).float()
117 | self.edge2key = edge2key
118 | """
119 |
120 | def compute_face_normals(self):
121 | face_normals = np.cross(self.vs[self.faces[:, 1]] - self.vs[self.faces[:, 0]], self.vs[self.faces[:, 2]] - self.vs[self.faces[:, 0]])
122 | norm = np.linalg.norm(face_normals, axis=1, keepdims=True) + 1e-24
123 | face_areas = 0.5 * np.sqrt((face_normals**2).sum(axis=1))
124 | face_normals /= norm
125 | self.fn, self.fa = face_normals, face_areas
126 |
127 | def compute_vert_normals(self):
128 | vert_normals = np.zeros((3, len(self.vs)))
129 | face_normals = self.fn
130 | faces = self.faces
131 |
132 | nv = len(self.vs)
133 | nf = len(faces)
134 | mat_rows = faces.reshape(-1)
135 | mat_cols = np.array([[i] * 3 for i in range(nf)]).reshape(-1)
136 | mat_vals = np.ones(len(mat_rows))
137 | f2v_mat = sp.sparse.csr_matrix((mat_vals, (mat_rows, mat_cols)), shape=(nv, nf))
138 | vert_normals = sp.sparse.csr_matrix.dot(f2v_mat, face_normals)
139 | vert_normals = normalize(vert_normals, norm='l2', axis=1)
140 | self.vn = vert_normals
141 |
142 | def compute_face_center(self):
143 | faces = self.faces
144 | vs = self.vs
145 | self.fc = np.sum(vs[faces], 1) / 3.0
146 |
147 | def compute_fn_sphere(self):
148 | fn = self.fn
149 | u = (np.arctan2(fn[:, 1], fn[:, 0]) + np.pi) / (2.0 * np.pi)
150 | v = np.arctan2(np.sqrt(fn[:, 0]**2 + fn[:, 1]**2), fn[:, 2]) / np.pi
151 | self.fn_uv = np.stack([u, v]).T
152 |
153 | def build_uni_lap(self):
154 | """compute uniform laplacian matrix"""
155 | vs = torch.tensor(self.vs.T, dtype=torch.float)
156 | edges = self.edges
157 | ve = self.ve
158 |
159 | sub_mesh_vv = [edges[v_e, :].reshape(-1) for v_e in ve]
160 | sub_mesh_vv = [set(vv.tolist()).difference(set([i])) for i, vv in enumerate(sub_mesh_vv)]
161 |
162 | num_verts = vs.size(1)
163 | mat_rows = [np.array([i] * len(vv), dtype=np.int64) for i, vv in enumerate(sub_mesh_vv)]
164 | mat_rows = np.concatenate(mat_rows)
165 | mat_cols = [np.array(list(vv), dtype=np.int64) for vv in sub_mesh_vv]
166 | mat_cols = np.concatenate(mat_cols)
167 |
168 | mat_rows = torch.from_numpy(mat_rows).long()
169 | mat_cols = torch.from_numpy(mat_cols).long()
170 | mat_vals = torch.ones_like(mat_rows).float() * -1.0
171 | neig_mat = torch.sparse.FloatTensor(torch.stack([mat_rows, mat_cols], dim=0),
172 | mat_vals,
173 | size=torch.Size([num_verts, num_verts]))
174 | vs = vs.T
175 |
176 | sum_count = torch.sparse.mm(neig_mat, torch.ones((num_verts, 1)).type_as(vs))
177 | mat_rows_ident = np.array([i for i in range(num_verts)])
178 | mat_cols_ident = np.array([i for i in range(num_verts)])
179 | mat_ident = np.array([-s for s in sum_count[:, 0]])
180 | mat_rows_ident = torch.from_numpy(mat_rows_ident).long()
181 | mat_cols_ident = torch.from_numpy(mat_cols_ident).long()
182 | mat_ident = torch.from_numpy(mat_ident).float()
183 | mat_rows = torch.cat([mat_rows, mat_rows_ident])
184 | mat_cols = torch.cat([mat_cols, mat_cols_ident])
185 | mat_vals = torch.cat([mat_vals, mat_ident])
186 |
187 | self.lapmat = torch.sparse.FloatTensor(torch.stack([mat_rows, mat_cols], dim=0),
188 | mat_vals,
189 | size=torch.Size([num_verts, num_verts]))
190 |
191 | def build_vf(self):
192 | vf = [set() for _ in range(len(self.vs))]
193 | for i, f in enumerate(self.faces):
194 | vf[f[0]].add(i)
195 | vf[f[1]].add(i)
196 | vf[f[2]].add(i)
197 | self.vf = vf
198 |
199 | """ build vertex-to-face sparse matrix """
200 | v2f_inds = [[] for _ in range(2)]
201 | v2f_vals = []
202 | v2f_areas = [[] for _ in range(len(self.vs))]
203 | for i in range(len(vf)):
204 | v2f_inds[1] += list(vf[i])
205 | v2f_inds[0] += [i] * len(vf[i])
206 | v2f_vals += (self.fc[list(vf[i])] - self.vs[i].reshape(1, -1)).tolist()
207 | v2f_areas[i] = np.sum(self.fa[list(vf[i])])
208 | self.v2f_list = [v2f_inds, v2f_vals, v2f_areas]
209 |
210 | v2f_inds = torch.tensor(v2f_inds).long()
211 | v2f_vals = torch.ones(v2f_inds.shape[1]).float()
212 | self.v2f_mat = torch.sparse.FloatTensor(v2f_inds, v2f_vals, size=torch.Size([len(self.vs), len(self.faces)]))
213 | self.f2v_mat = torch.sparse.FloatTensor(torch.stack([v2f_inds[1], v2f_inds[0]], dim=0), v2f_vals, size=torch.Size([len(self.faces), len(self.vs)]))
214 |
215 | """ build face-to-face (1ring) matrix """
216 | f_edges = np.array([[i] * 3 for i in range(len(self.faces))])
217 | f2f = [[] for _ in range(len(self.faces))]
218 | self.f_edges = [[] for _ in range(2)]
219 | for i, f in enumerate(self.faces):
220 | all_neig = list(vf[f[0]]) + list(vf[f[1]]) + list(vf[f[2]])
221 | one_neig = np.array(list(Counter(all_neig).values())) == 2
222 | f2f_i = np.array(list(Counter(all_neig).keys()))[one_neig].tolist()
223 | self.f_edges[0] += len(f2f_i) * [i]
224 | self.f_edges[1] += f2f_i
225 | f2f[i] = f2f_i + (3 - len(f2f_i)) * [-1]
226 |
227 | self.f2f = np.array(f2f)
228 | self.f_edges = np.array(self.f_edges)
229 | edge_index = torch.tensor(self.edges.T, dtype=torch.long)
230 | self.edge_index = torch.cat([edge_index, edge_index[[1,0],:]], dim=1)
231 | self.face_index = torch.from_numpy(self.f_edges)
232 |
233 | # TODO: change this to correspond to non-watertight mesh
234 | f2f_inds = torch.from_numpy(self.f_edges).long()
235 | f2f_vals = -1.0 * torch.ones(f2f_inds.shape[1]).float()
236 | f2f_mat = torch.sparse.FloatTensor(f2f_inds, f2f_vals, size=torch.Size([len(self.faces), len(self.faces)]))
237 | f_eyes_inds = torch.arange(len(self.faces)).long().repeat(2, 1)
238 | f_dims = torch.ones(len(self.faces)).float() * 3.0 # TODO: change here
239 | f_eyes = torch.sparse.FloatTensor(f_eyes_inds, f_dims, size=torch.Size([len(self.faces), len(self.faces)]))
240 | self.f_lapmat = f2f_mat + f_eyes
241 |
242 | """ build face-to-face (2ring) sparse matrix
243 | self.f2f = np.array(f2f)
244 | f2ring = self.f2f[self.f2f].reshape(-1, 9)
245 | self.f2ring = [set(f) for f in f2ring]
246 | self.f2ring = [list(self.f2ring[i] | set(f)) for i, f in enumerate(self.f2f)]
247 |
248 | self.f_edges = np.concatenate((self.f2f.reshape(1, -1), f_edges.reshape(1, -1)), 0)
249 | mat_inds = torch.from_numpy(self.f_edges).long()
250 | #mat_vals = torch.ones(mat_inds.shape[1]).float()
251 | mat_vals = torch.from_numpy(self.fa[self.f_edges[0]]).float()
252 | self.f2f_mat = torch.sparse.FloatTensor(mat_inds, mat_vals, size=torch.Size([len(self.faces), len(self.faces)]))
253 | """
254 | def build_v2v(self):
255 | v2v = [[] for _ in range(len(self.vs))]
256 | for i, e in enumerate(self.edges):
257 | v2v[e[0]].append(e[1])
258 | v2v[e[1]].append(e[0])
259 | self.v2v = v2v
260 |
261 | """ compute adjacent matrix """
262 | edges = self.edges
263 | v2v_inds = edges.T
264 | v2v_inds = torch.from_numpy(np.concatenate([v2v_inds, v2v_inds[[1, 0]]], axis=1)).long()
265 | v2v_vals = torch.ones(v2v_inds.shape[1]).float()
266 | self.Adj = torch.sparse.FloatTensor(v2v_inds, v2v_vals, size=torch.Size([len(self.vs), len(self.vs)]))
267 | self.v_dims = torch.sum(self.Adj.to_dense(), axis=1)
268 | D_inds = torch.stack([torch.arange(len(self.vs)), torch.arange(len(self.vs))], dim=0).long()
269 | D_vals = 1 / self.v_dims
270 | self.Diag = torch.sparse.FloatTensor(D_inds, D_vals, size=torch.Size([len(self.vs), len(self.vs)]))
271 | I = torch.eye(len(self.vs))
272 | self.AdjI = (I + self.Adj).to_sparse()
273 | Lap = I - torch.sparse.mm(self.Diag, self.Adj.to_dense())
274 | self.Lap = Lap.to_sparse()
275 |
276 | def build_adj_mat(self):
277 | edges = self.edges
278 | v2v_inds = edges.T
279 | v2v_inds = torch.from_numpy(np.concatenate([v2v_inds, v2v_inds[[1, 0]]], axis=1)).long()
280 | v2v_vals = torch.ones(v2v_inds.shape[1]).float()
281 | self.Adj = torch.sparse.FloatTensor(v2v_inds, v2v_vals, size=torch.Size([len(self.vs), len(self.vs)]))
282 | self.v_dims = torch.sum(self.Adj.to_dense(), axis=1)
283 | D_inds = torch.stack([torch.arange(len(self.vs)), torch.arange(len(self.vs))], dim=0).long()
284 | D_vals = 1 / (torch.pow(self.v_dims, 0.5) + 1.0e-12)
285 | self.D_minus_half = torch.sparse.FloatTensor(D_inds, D_vals, size=torch.Size([len(self.vs), len(self.vs)]))
286 |
287 | def build_mesh_lap(self):
288 | self.build_adj_mat()
289 |
290 | vs = self.vs
291 | edges = self.edges
292 | faces = self.faces
293 |
294 | e_dict = {}
295 | for e in edges:
296 | e0, e1 = min(e), max(e)
297 | e_dict[(e0, e1)] = []
298 |
299 | for f in faces:
300 | s = vs[f[1]] - vs[f[0]]
301 | t = vs[f[2]] - vs[f[1]]
302 | u = vs[f[0]] - vs[f[2]]
303 | cos_0 = np.inner(s, -u) / (np.linalg.norm(s) * np.linalg.norm(u))
304 | cos_1 = np.inner(t, -s) / (np.linalg.norm(t) * np.linalg.norm(s))
305 | cos_2 = np.inner(u, -t) / (np.linalg.norm(u) * np.linalg.norm(t))
306 | cot_0 = cos_0 / (np.sqrt(1 - cos_0 ** 2) + 1e-12)
307 | cot_1 = cos_1 / (np.sqrt(1 - cos_1 ** 2) + 1e-12)
308 | cot_2 = cos_2 / (np.sqrt(1 - cos_2 ** 2) + 1e-12)
309 | key_0 = (min(f[1], f[2]), max(f[1], f[2]))
310 | key_1 = (min(f[2], f[0]), max(f[2], f[0]))
311 | key_2 = (min(f[0], f[1]), max(f[0], f[1]))
312 | e_dict[key_0].append(cot_0)
313 | e_dict[key_1].append(cot_1)
314 | e_dict[key_2].append(cot_2)
315 |
316 | for e in e_dict:
317 | e_dict[e] = -0.5 * (e_dict[e][0] + e_dict[e][1])
318 |
319 | C_ind = [[], []]
320 | C_val = []
321 | ident = [0] * len(vs)
322 | for e in e_dict:
323 | C_ind[0].append(e[0])
324 | C_ind[1].append(e[1])
325 | C_ind[0].append(e[1])
326 | C_ind[1].append(e[0])
327 | C_val.append(e_dict[e])
328 | C_val.append(e_dict[e])
329 | ident[e[0]] += -1.0 * e_dict[e]
330 | ident[e[1]] += -1.0 * e_dict[e]
331 | Am_ind = torch.LongTensor(C_ind)
332 | Am_val = -1.0 * torch.FloatTensor(C_val)
333 | self.Am = torch.sparse.FloatTensor(Am_ind, Am_val, torch.Size([len(vs), len(vs)]))
334 |
335 | for i in range(len(vs)):
336 | C_ind[0].append(i)
337 | C_ind[1].append(i)
338 |
339 | C_val = C_val + ident
340 | C_ind = torch.LongTensor(C_ind)
341 | C_val = torch.FloatTensor(C_val)
342 | # cotangent matrix
343 | self.Lm = torch.sparse.FloatTensor(C_ind, C_val, torch.Size([len(vs), len(vs)]))
344 | self.Dm = torch.diag(torch.tensor(ident)).float().to_sparse()
345 | self.Lm_sym = torch.sparse.mm(torch.pow(self.Dm, -0.5), torch.sparse.mm(self.Lm, torch.pow(self.Dm, -0.5).to_dense())).to_sparse()
346 | #self.L = torch.sparse.mm(self.D_minus_half, torch.sparse.mm(C, self.D_minus_half.to_dense()))
347 | self.Am_I = (torch.eye(len(vs)) + self.Am).to_sparse()
348 | Dm_I_diag = torch.sum(self.Am_I.to_dense(), dim=1)
349 | self.Dm_I = torch.diag(Dm_I_diag).to_sparse()
350 | self.meshconvF = torch.sparse.mm(torch.pow(self.Dm_I, -0.5), torch.sparse.mm(self.Am_I, torch.pow(self.Dm_I, -0.5).to_dense())).to_sparse()
351 |
352 | def get_chebconv_coef(self, k=2):
353 | coef_list = []
354 | eig_max = torch.lobpcg(self.Lm_sym, k=1)[0][0]
355 | Lm_hat = -1.0 * torch.eye(len(self.vs)) + 2.0 * self.Lm_sym / eig_max.item()
356 | self.Lm_hat = Lm_hat.to_sparse()
357 | for i in range(k):
358 | if i == 0:
359 | coef = torch.eye(len(self.vs)).to_sparse()
360 | coef_list.append(coef)
361 | elif i == 1:
362 | coef_list.append(self.Lm_hat)
363 | else:
364 | coef = 2.0 * torch.sparse.mm(self.Lm_hat, coef_list[-1].to_dense()) - coef_list[-2]
365 | coef_list.append(coef.to_sparse())
366 | return coef_list
367 |
368 | def eigen_decomposition(self, L, k=100):
369 | L = L.to_dense().numpy()
370 | csr = sp.sparse.csr_matrix(L)
371 | w, v = sp.sparse.linalg.eigs(csr, which="SR", k=k)
372 | #index = np.argsort(np.real(w))[::-1] # LR
373 | index = np.argsort(np.real(w)) # SR
374 | # vs_code can include either vertex code or face code
375 | vs_code = np.real(v[:, index]).astype(np.float64)
376 | vs_code /= np.linalg.norm(vs_code, axis=0, keepdims=True)
377 | # import open3d as o3d
378 | # mesh = o3d.io.read_triangle_mesh(self.path)
379 | # for i in range(0, k, 1):
380 | # colors = np.zeros([len(self.vs), 3])
381 | # r = vs_code[:, i]
382 | # r = (r+1)/2
383 | # """
384 | # r_min, r_max = np.min(r), np.max(r)
385 | # r = (r - r_min) / (r_max-r_min+1.0e-16)
386 | # """
387 | # colors[:, 1] = r
388 | # mesh.vertex_colors = o3d.utility.Vector3dVector(colors)
389 | # o3d.io.write_triangle_mesh("eigen_decom/eigen_{}.obj".format(str(i)), mesh)
390 | # import pdb;pdb.set_trace()
391 |
392 | return vs_code
393 |
394 | def simplification(self, target_v, valence_aware=True, midpoint=True):
395 | vs, vf, fn, fc, edges = self.vs, self.vf, self.fn, self.fc, self.edges
396 |
397 | """ 1. compute Q for each vertex """
398 | Q_s = [[] for _ in range(len(vs))]
399 | E_s = [[] for _ in range(len(vs))]
400 | for i, v in enumerate(vs):
401 | f_s = np.array(list(vf[i]), dtype=np.int64)
402 | fc_s = fc[f_s]
403 | fn_s = fn[f_s]
404 | d_s = - 1.0 * np.sum(fn_s * fc_s, axis=1, keepdims=True)
405 | abcd_s = np.concatenate([fn_s, d_s], axis=1)
406 | Q_s[i] = np.matmul(abcd_s.T, abcd_s)
407 |
408 | v4 = np.concatenate([v, np.array([1])])
409 | E_s[i] = np.matmul(v4, np.matmul(Q_s[i], v4.T))
410 |
411 | """ 2. compute E for every possible pairs and create heapq """
412 | E_heap = []
413 | for i, e in enumerate(edges):
414 | v_0, v_1 = vs[e[0]], vs[e[1]]
415 | Q_0, Q_1 = Q_s[e[0]], Q_s[e[1]]
416 | Q_new = Q_0 + Q_1
417 |
418 | if midpoint:
419 | v_new = 0.5 * (v_0 + v_1)
420 | v4_new = np.concatenate([v_new, np.array([1])])
421 | else:
422 | Q_lp = np.eye(4)
423 | Q_lp[:3] = Q_new[:3]
424 | try:
425 | Q_lp_inv = np.linalg.inv(Q_lp)
426 | v4_new = np.matmul(Q_lp_inv, np.array([[0,0,0,1]]).reshape(-1,1)).reshape(-1)
427 | except:
428 | v_new = 0.5 * (v_0 + v_1)
429 | v4_new = np.concatenate([v_new, np.array([1])])
430 |
431 | valence_penalty = 1
432 | if valence_aware:
433 | merged_faces = vf[e[0]].intersection(vf[e[1]])
434 | valence_new = len(vf[e[0]].union(vf[e[1]]).difference(merged_faces))
435 | valence_penalty = self.valence_weight(valence_new)
436 |
437 | E_new = np.matmul(v4_new, np.matmul(Q_new, v4_new.T)) * valence_penalty
438 | heapq.heappush(E_heap, (E_new, (e[0], e[1])))
439 |
440 | """ 3. collapse minimum-error vertex """
441 | simp_mesh = copy.deepcopy(self)
442 |
443 | vi_mask = np.ones([len(simp_mesh.vs)]).astype(np.bool_)
444 | fi_mask = np.ones([len(simp_mesh.faces)]).astype(np.bool_)
445 |
446 | vert_map = [{i} for i in range(len(simp_mesh.vs))]
447 |
448 | while np.sum(vi_mask) > target_v:
449 | if len(E_heap) == 0:
450 | print("edge cannot be collapsed anymore!")
451 | break
452 |
453 | E_0, (vi_0, vi_1) = heapq.heappop(E_heap)
454 |
455 | if (vi_mask[vi_0] == False) or (vi_mask[vi_1] == False):
456 | continue
457 |
458 | """ edge collapse """
459 | shared_vv = list(set(simp_mesh.v2v[vi_0]).intersection(set(simp_mesh.v2v[vi_1])))
460 | merged_faces = simp_mesh.vf[vi_0].intersection(simp_mesh.vf[vi_1])
461 |
462 | if len(shared_vv) != 2:
463 | """ non-manifold! """
464 | #print("non-manifold can be occured!!" , len(shared_vv))
465 | self.remove_tri_valance(simp_mesh, vi_0, vi_1, shared_vv, merged_faces, vi_mask, fi_mask, vert_map, Q_s, E_heap)
466 | continue
467 |
468 | elif len(merged_faces) != 2:
469 | """ boundary """
470 | #print("boundary edge cannot be collapsed!")
471 | continue
472 |
473 | else:
474 | self.edge_collapse(simp_mesh, vi_0, vi_1, merged_faces, vi_mask, fi_mask, vert_map, Q_s, E_heap, valence_aware=valence_aware)
475 | #print(np.sum(vi_mask), np.sum(fi_mask))
476 |
477 | self.rebuild_mesh(simp_mesh, vi_mask, fi_mask, vert_map)
478 | simp_mesh.simp = True
479 | simp_mesh.org = self
480 | self.build_hash(simp_mesh, vi_mask, vert_map)
481 |
482 | return simp_mesh
483 |
484 | def edge_based_simplification(self, target_v, valence_aware=True):
485 | vs, vf, fn, fc, edges = self.vs, self.vf, self.fn, self.fc, self.edges
486 | edge_len = vs[edges][:,0,:] - vs[edges][:,1,:]
487 | edge_len = np.linalg.norm(edge_len, axis=1)
488 | edge_len_heap = np.stack([edge_len, np.arange(len(edge_len))], axis=1).tolist()
489 | heapq.heapify(edge_len_heap)
490 |
491 | """ 2. compute E for every possible pairs and create heapq """
492 | E_heap = []
493 | for i, e in enumerate(edges):
494 | v_0, v_1 = vs[e[0]], vs[e[1]]
495 | heapq.heappush(E_heap, (edge_len[i], (e[0], e[1])))
496 |
497 | """ 3. collapse minimum-error vertex """
498 | simp_mesh = copy.deepcopy(self)
499 |
500 | vi_mask = np.ones([len(simp_mesh.vs)]).astype(np.bool_)
501 | fi_mask = np.ones([len(simp_mesh.faces)]).astype(np.bool_)
502 |
503 | vert_map = [{i} for i in range(len(simp_mesh.vs))]
504 |
505 | while np.sum(vi_mask) > target_v:
506 | if len(E_heap) == 0:
507 | print("[Warning]: edge cannot be collapsed anymore!")
508 | break
509 |
510 | E_0, (vi_0, vi_1) = heapq.heappop(E_heap)
511 |
512 | if (vi_mask[vi_0] == False) or (vi_mask[vi_1] == False):
513 | continue
514 |
515 | """ edge collapse """
516 | shared_vv = list(set(simp_mesh.v2v[vi_0]).intersection(set(simp_mesh.v2v[vi_1])))
517 | merged_faces = simp_mesh.vf[vi_0].intersection(simp_mesh.vf[vi_1])
518 |
519 | if len(shared_vv) != 2:
520 | """ non-manifold! """
521 | # print("non-manifold can be occured!!" , len(shared_vv))
522 | continue
523 |
524 | elif len(merged_faces) != 2:
525 | """ boundary """
526 | # print("boundary edge cannot be collapsed!")
527 | continue
528 |
529 | else:
530 | self.edge_based_collapse(simp_mesh, vi_0, vi_1, merged_faces, vi_mask, fi_mask, vert_map, E_heap, valence_aware=valence_aware)
531 | # print(np.sum(vi_mask), np.sum(fi_mask))
532 |
533 | self.rebuild_mesh(simp_mesh, vi_mask, fi_mask, vert_map)
534 | simp_mesh.simp = True
535 | self.build_hash(simp_mesh, vi_mask, vert_map)
536 |
537 | return simp_mesh
538 |
539 | @staticmethod
540 | def remove_tri_valance(simp_mesh, vi_0, vi_1, shared_vv, merged_faces, vi_mask, fi_mask, vert_map, Q_s, E_heap):
541 | return
542 |
543 | def edge_collapse(self, simp_mesh, vi_0, vi_1, merged_faces, vi_mask, fi_mask, vert_map, Q_s, E_heap, valence_aware):
544 | shared_vv = list(set(simp_mesh.v2v[vi_0]).intersection(set(simp_mesh.v2v[vi_1])))
545 | new_vi_0 = set(simp_mesh.v2v[vi_0]).union(set(simp_mesh.v2v[vi_1])).difference({vi_0, vi_1})
546 | simp_mesh.vf[vi_0] = simp_mesh.vf[vi_0].union(simp_mesh.vf[vi_1]).difference(merged_faces)
547 | simp_mesh.vf[vi_1] = set()
548 | simp_mesh.vf[shared_vv[0]] = simp_mesh.vf[shared_vv[0]].difference(merged_faces)
549 | simp_mesh.vf[shared_vv[1]] = simp_mesh.vf[shared_vv[1]].difference(merged_faces)
550 |
551 | simp_mesh.v2v[vi_0] = list(new_vi_0)
552 | for v in simp_mesh.v2v[vi_1]:
553 | if v != vi_0:
554 | simp_mesh.v2v[v] = list(set(simp_mesh.v2v[v]).difference({vi_1}).union({vi_0}))
555 | simp_mesh.v2v[vi_1] = []
556 | vi_mask[vi_1] = False
557 |
558 | vert_map[vi_0] = vert_map[vi_0].union(vert_map[vi_1])
559 | vert_map[vi_0] = vert_map[vi_0].union({vi_1})
560 | vert_map[vi_1] = set()
561 |
562 | fi_mask[np.array(list(merged_faces)).astype(np.int64)] = False
563 |
564 | simp_mesh.vs[vi_0] = 0.5 * (simp_mesh.vs[vi_0] + simp_mesh.vs[vi_1])
565 |
566 | """ recompute E """
567 | Q_0 = Q_s[vi_0]
568 | for vv_i in simp_mesh.v2v[vi_0]:
569 | v_mid = 0.5 * (simp_mesh.vs[vi_0] + simp_mesh.vs[vv_i])
570 | Q_1 = Q_s[vv_i]
571 | Q_new = Q_0 + Q_1
572 | v4_mid = np.concatenate([v_mid, np.array([1])])
573 |
574 | valence_penalty = 1
575 | if valence_aware:
576 | merged_faces = simp_mesh.vf[vi_0].intersection(simp_mesh.vf[vv_i])
577 | valence_new = len(simp_mesh.vf[vi_0].union(simp_mesh.vf[vv_i]).difference(merged_faces))
578 | valence_penalty = self.valence_weight(valence_new)
579 |
580 | E_new = np.matmul(v4_mid, np.matmul(Q_new, v4_mid.T)) * valence_penalty
581 | heapq.heappush(E_heap, (E_new, (vi_0, vv_i)))
582 |
583 | def edge_based_collapse(self, simp_mesh, vi_0, vi_1, merged_faces, vi_mask, fi_mask, vert_map, E_heap, valence_aware):
584 | shared_vv = list(set(simp_mesh.v2v[vi_0]).intersection(set(simp_mesh.v2v[vi_1])))
585 | new_vi_0 = set(simp_mesh.v2v[vi_0]).union(set(simp_mesh.v2v[vi_1])).difference({vi_0, vi_1})
586 | simp_mesh.vf[vi_0] = simp_mesh.vf[vi_0].union(simp_mesh.vf[vi_1]).difference(merged_faces)
587 | simp_mesh.vf[vi_1] = set()
588 | simp_mesh.vf[shared_vv[0]] = simp_mesh.vf[shared_vv[0]].difference(merged_faces)
589 | simp_mesh.vf[shared_vv[1]] = simp_mesh.vf[shared_vv[1]].difference(merged_faces)
590 |
591 | simp_mesh.v2v[vi_0] = list(new_vi_0)
592 | for v in simp_mesh.v2v[vi_1]:
593 | if v != vi_0:
594 | simp_mesh.v2v[v] = list(set(simp_mesh.v2v[v]).difference({vi_1}).union({vi_0}))
595 | simp_mesh.v2v[vi_1] = []
596 | vi_mask[vi_1] = False
597 |
598 | vert_map[vi_0] = vert_map[vi_0].union(vert_map[vi_1])
599 | vert_map[vi_0] = vert_map[vi_0].union({vi_1})
600 | vert_map[vi_1] = set()
601 |
602 | fi_mask[np.array(list(merged_faces)).astype(np.int64)] = False
603 |
604 | simp_mesh.vs[vi_0] = 0.5 * (simp_mesh.vs[vi_0] + simp_mesh.vs[vi_1])
605 |
606 | """ recompute E """
607 | for vv_i in simp_mesh.v2v[vi_0]:
608 | v_mid = 0.5 * (simp_mesh.vs[vi_0] + simp_mesh.vs[vv_i])
609 | edge_len = np.linalg.norm(simp_mesh.vs[vi_0] - simp_mesh.vs[vv_i])
610 | valence_penalty = 1
611 | if valence_aware:
612 | merged_faces = simp_mesh.vf[vi_0].intersection(simp_mesh.vf[vv_i])
613 | valence_new = len(simp_mesh.vf[vi_0].union(simp_mesh.vf[vv_i]).difference(merged_faces))
614 | valence_penalty = self.valence_weight(valence_new)
615 | edge_len *= valence_penalty
616 |
617 | heapq.heappush(E_heap, (edge_len, (vi_0, vv_i)))
618 |
619 | @staticmethod
620 | def valence_weight(valence_new):
621 | valence_penalty = abs(valence_new - OPTIM_VALENCE) * VALENCE_WEIGHT + 1
622 | if valence_new == 3:
623 | valence_penalty *= 100000
624 | return valence_penalty
625 |
626 | @staticmethod
627 | def rebuild_mesh(simp_mesh, vi_mask, fi_mask, vert_map):
628 | face_map = dict(zip(np.arange(len(vi_mask)), np.cumsum(vi_mask)-1))
629 | simp_mesh.vs = simp_mesh.vs[vi_mask]
630 |
631 | vert_dict = {}
632 | for i, vm in enumerate(vert_map):
633 | for j in vm:
634 | vert_dict[j] = i
635 |
636 | for i, f in enumerate(simp_mesh.faces):
637 | for j in range(3):
638 | if f[j] in vert_dict:
639 | simp_mesh.faces[i][j] = vert_dict[f[j]]
640 |
641 | simp_mesh.faces = simp_mesh.faces[fi_mask]
642 | for i, f in enumerate(simp_mesh.faces):
643 | for j in range(3):
644 | simp_mesh.faces[i][j] = face_map[f[j]]
645 |
646 | simp_mesh.compute_face_normals()
647 | simp_mesh.compute_face_center()
648 | simp_mesh.build_gemm()
649 | simp_mesh.compute_vert_normals()
650 | simp_mesh.build_v2v()
651 | simp_mesh.build_vf()
652 |
653 | @staticmethod
654 | def build_hash(simp_mesh, vi_mask, vert_map):
655 | pool_hash = {}
656 | unpool_hash = {}
657 | for simp_i, idx in enumerate(np.where(vi_mask)[0]):
658 | if len(vert_map[idx]) == 0:
659 | print("[ERROR] parent node cannot be found!")
660 | return
661 | for org_i in vert_map[idx]:
662 | pool_hash[org_i] = simp_i
663 | unpool_hash[simp_i] = list(vert_map[idx])
664 |
665 | """ check """
666 | vl_sum = 0
667 | for vl in unpool_hash.values():
668 | vl_sum += len(vl)
669 |
670 | if (len(set(pool_hash.keys())) != len(vi_mask)) or (vl_sum != len(vi_mask)):
671 | print("[ERROR] Original vetices cannot be covered!")
672 | return
673 |
674 | pool_hash = sorted(pool_hash.items(), key=lambda x:x[0])
675 | simp_mesh.pool_hash = pool_hash
676 | simp_mesh.unpool_hash = unpool_hash
677 |
678 | @staticmethod
679 | def mesh_merge(lap, org_mesh, new_pos, preserve, w=1, w_b=0):
680 | # TODO: Delete constraint of boundary
681 | org_pos = torch.from_numpy(org_mesh.vs).float()
682 | org_wo_boundary = torch.sparse.mm(org_mesh.AdjI.float(), 1-preserve.float().reshape(-1, 1)) == 0
683 | org_wo_boundary = org_wo_boundary.reshape(-1)
684 | boundary = torch.logical_xor(preserve, org_wo_boundary)
685 | A_org_wo_boundary = torch.eye(len(org_pos))[org_wo_boundary] * w
686 | A_boundary = torch.eye(len(org_pos))[boundary] * w_b
687 | A = torch.cat([lap.to_dense(), A_org_wo_boundary], dim=0)
688 | A = torch.cat([A, A_boundary], dim=0).to_sparse()
689 | b_mix = torch.sparse.mm(lap, new_pos)
690 | b_mix[org_wo_boundary] = torch.sparse.mm(lap, org_pos)[org_wo_boundary]
691 | b_org_wo_boundary = org_pos[org_wo_boundary] * w
692 | b_boundary = org_pos[boundary] * w_b
693 | b = torch.cat([b_mix, b_org_wo_boundary], dim=0)
694 | b = torch.cat([b, b_boundary], dim=0)
695 | AtA = torch.sparse.mm(A.t(), A.to_dense())
696 | Atb = torch.sparse.mm(A.t(), b)
697 | ref_pos = torch.linalg.solve(AtA, Atb)
698 | return ref_pos
699 |
700 | @staticmethod
701 | def copy_attribute(src_mesh, dst_mesh):
702 | dst_mesh.vf, dst_mesh.edges, dst_mesh.ve = src_mesh.vf, src_mesh.edges, src_mesh.ve
703 | dst_mesh.v2v = src_mesh.v2v
704 |
705 | def save(self, filename, color=False):
706 | assert len(self.vs) > 0
707 | vertices = np.array(self.vs, dtype=np.float32).flatten()
708 | indices = np.array(self.faces, dtype=np.uint32).flatten()
709 | v_colors = np.array(self.vc, dtype=np.float32).flatten()
710 |
711 | with open(filename, 'w') as fp:
712 | # Write positions
713 | if len(v_colors) == 0 or color == False:
714 | for i in range(0, vertices.size, 3):
715 | x = vertices[i + 0]
716 | y = vertices[i + 1]
717 | z = vertices[i + 2]
718 | fp.write('v {0:.8f} {1:.8f} {2:.8f}\n'.format(x, y, z))
719 |
720 | else:
721 | for i in range(0, vertices.size, 3):
722 | x = vertices[i + 0]
723 | y = vertices[i + 1]
724 | z = vertices[i + 2]
725 | c1 = v_colors[i + 0]
726 | c2 = v_colors[i + 1]
727 | c3 = v_colors[i + 2]
728 | fp.write('v {0:.8f} {1:.8f} {2:.8f} {3:.8f} {4:.8f} {5:.8f}\n'.format(x, y, z, c1, c2, c3))
729 |
730 | # Write indices
731 | for i in range(0, len(indices), 3):
732 | i0 = indices[i + 0] + 1
733 | i1 = indices[i + 1] + 1
734 | i2 = indices[i + 2] + 1
735 | fp.write('f {0} {1} {2}\n'.format(i0, i1, i2))
736 |
737 | def save_as_ply(self, filename, fn):
738 | assert len(self.vs) > 0
739 | vertices = np.array(self.vs, dtype=np.float32).flatten()
740 | indices = np.array(self.faces, dtype=np.uint32).flatten()
741 | fnormals = np.array(fn, dtype=np.float32).flatten()
742 |
743 | with open(filename, 'w') as fp:
744 | # Write Header
745 | fp.write("ply\nformat ascii 1.0\nelement vertex {}\n".format(len(self.vs)))
746 | fp.write("property float x\nproperty float y\nproperty float z\n")
747 | fp.write("element face {}\n".format(len(self.faces)))
748 | fp.write("property list uchar int vertex_indices\n")
749 | fp.write("property uchar red\nproperty uchar green\nproperty uchar blue\nproperty uchar alpha\n")
750 | fp.write("end_header\n")
751 | for i in range(0, vertices.size, 3):
752 | x = vertices[i + 0]
753 | y = vertices[i + 1]
754 | z = vertices[i + 2]
755 | fp.write("{0:.6f} {1:.6f} {2:.6f}\n".format(x, y, z))
756 |
757 | for i in range(0, len(indices), 3):
758 | i0 = indices[i + 0]
759 | i1 = indices[i + 1]
760 | i2 = indices[i + 2]
761 | c0 = fnormals[i + 0]
762 | c1 = fnormals[i + 1]
763 | c2 = fnormals[i + 2]
764 | c0 = np.clip(int(255 * c0), 0, 255)
765 | c1 = np.clip(int(255 * c1), 0, 255)
766 | c2 = np.clip(int(255 * c2), 0, 255)
767 | c3 = 255
768 | fp.write("3 {0} {1} {2} {3} {4} {5} {6}\n".format(i0, i1, i2, c0, c1, c2, c3))
769 |
770 | def display_face_normals(self, fn):
771 | import open3d as o3d
772 | self.compute_face_center()
773 | pcd = o3d.geometry.PointCloud()
774 | pcd.points = o3d.utility.Vector3dVector(np.asarray(self.fc))
775 | pcd.normals = o3d.utility.Vector3dVector(np.asarray(fn))
776 | o3d.visualization.draw_geometries([pcd])
777 |
--------------------------------------------------------------------------------
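For reference, a short sketch of loading, simplifying, and exporting a mesh with the Mesh class above (the paths are placeholders, not files shipped with the repository):

```
from util.mesh import Mesh

mesh = Mesh("datasets/example/example_original.obj", build_mat=False)  # hypothetical path
print(len(mesh.vs), len(mesh.faces))  # vertex / face counts

# QEM-style, valence-aware simplification to roughly 60% of the vertices,
# followed by export of the coarse mesh as OBJ.
simp = mesh.simplification(target_v=int(0.6 * len(mesh.vs)), midpoint=True)
simp.save("example_simplified.obj")
```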
/util/meshnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import os
5 | import copy
6 | from torch_geometric.nn import GCNConv, ChebConv, Sequential
7 |
8 |
9 | class MeshPool(nn.Module):
10 | def __init__(self, pool_hash):
11 | super(MeshPool, self).__init__()
12 | self.register_buffer("pool_hash", pool_hash)
13 |
14 | def forward(self, input):
15 | v_sum = torch.sum(self.pool_hash.to_dense(), dim=1, keepdim=True)
16 | out = torch.sparse.mm(self.pool_hash, input) / v_sum
17 | return out
18 |
19 |
20 | class MeshUnpool(nn.Module):
21 | def __init__(self, unpool_hash):
22 | super(MeshUnpool, self).__init__()
23 | self.register_buffer("unpool_hash", unpool_hash)
24 |
25 | def forward(self, input):
26 | out = torch.sparse.mm(self.unpool_hash, input)
27 | return out
28 |
29 | #############################################################
30 |
31 | class DownConv(nn.Module):
32 | def __init__(self, in_channels, out_channels, edge_index1, edge_index2, pool_hash, K=3, drop_rate=0.0):
33 | super(DownConv, self).__init__()
34 | self.edge_index1 = edge_index1
35 | self.edge_index2 = edge_index2
36 | conv = "chebconv"
37 | if conv == "chebconv":
38 | """ chebconv """
39 | self.model1 = Sequential("x, edge_index", [
40 | (ChebConv(in_channels, out_channels, K=K), "x, edge_index -> x"),
41 | (nn.BatchNorm1d(out_channels), "x -> x"),
42 | (nn.LeakyReLU(), "x -> x"),
43 |
44 | (ChebConv(out_channels, out_channels, K=K), "x, edge_index -> x"),
45 | (MeshPool(pool_hash), "x -> x"),
46 | (nn.BatchNorm1d(out_channels), "x -> x"),
47 | (nn.LeakyReLU(), "x -> x"),
48 | ])
49 | self.model2 = Sequential("x, edge_index", [
50 | (ChebConv(out_channels, out_channels, K=K), "x, edge_index -> x"),
51 | (nn.BatchNorm1d(out_channels), "x -> x"),
52 | (nn.LeakyReLU(), "x -> x"),
53 |
54 | (ChebConv(out_channels, out_channels, K=K), "x, edge_index -> x"),
55 | (nn.BatchNorm1d(out_channels), "x -> x"),
56 | (nn.LeakyReLU(), "x -> x"),
57 |
58 | (ChebConv(out_channels, out_channels, K=K), "x, edge_index -> x"),
59 | (nn.BatchNorm1d(out_channels), "x -> x"),
60 | (nn.LeakyReLU(), "x -> x"),
61 |
62 | (nn.Dropout(drop_rate), "x -> x"),
63 | ])
64 | else:
65 | """ gcnconv """
66 | self.model1 = Sequential("x, edge_index", [
67 | (GCNConv(in_channels, out_channels), "x, edge_index -> x"),
68 | (nn.BatchNorm1d(out_channels), "x -> x"),
69 | (nn.LeakyReLU(), "x -> x"),
70 |
71 | (GCNConv(out_channels, out_channels), "x, edge_index -> x"),
72 | (MeshPool(pool_hash), "x -> x"),
73 | (nn.BatchNorm1d(out_channels), "x -> x"),
74 | (nn.LeakyReLU(), "x -> x"),
75 | ])
76 | self.model2 = Sequential("x, edge_index", [
77 | (GCNConv(out_channels, out_channels), "x, edge_index -> x"),
78 | (nn.BatchNorm1d(out_channels), "x -> x"),
79 | (nn.LeakyReLU(), "x -> x"),
80 |
81 | (GCNConv(out_channels, out_channels), "x, edge_index -> x"),
82 | (nn.BatchNorm1d(out_channels), "x -> x"),
83 | (nn.LeakyReLU(), "x -> x"),
84 |
85 | (GCNConv(out_channels, out_channels), "x, edge_index -> x"),
86 | (nn.BatchNorm1d(out_channels), "x -> x"),
87 | (nn.LeakyReLU(), "x -> x"),
88 |
89 | (nn.Dropout(drop_rate), "x -> x"),
90 | ])
91 |
92 | def forward(self, input):
93 | out = self.model1(input, self.edge_index1)
94 | out = self.model2(out, self.edge_index2)
95 | return out
96 |
97 |
98 | class UpConv(nn.Module):
99 | def __init__(self, in_channels, out_channels, edge_index1, edge_index2, unpool_hash, K=3, drop_rate=0.0):
100 | super(UpConv, self).__init__()
101 | self.edge_index1 = edge_index1
102 | self.edge_index2 = edge_index2
103 | conv = "chebconv"
104 | if conv == "chebconv":
105 | self.model1 = Sequential("x, edge_index", [
106 | (ChebConv(in_channels, out_channels, K=K), "x, edge_index -> x"),
107 | (MeshUnpool(unpool_hash), "x -> x"),
108 | (nn.BatchNorm1d(out_channels), "x -> x"),
109 | (nn.LeakyReLU(), "x -> x"),
110 | ])
111 | self.model2 = Sequential("x, edge_index", [
112 | (ChebConv(out_channels, out_channels, K=K), "x, edge_index -> x"),
113 | (nn.BatchNorm1d(out_channels), "x -> x"),
114 | (nn.LeakyReLU(), "x -> x"),
115 |
116 | (ChebConv(out_channels, out_channels, K=K), "x, edge_index -> x"),
117 | (nn.BatchNorm1d(out_channels), "x -> x"),
118 | (nn.LeakyReLU(), "x -> x"),
119 |
120 | (ChebConv(out_channels, out_channels, K=K), "x, edge_index -> x"),
121 | (nn.BatchNorm1d(out_channels), "x -> x"),
122 | (nn.LeakyReLU(), "x -> x"),
123 |
124 | (ChebConv(out_channels, out_channels, K=K), "x, edge_index -> x"),
125 | (nn.BatchNorm1d(out_channels), "x -> x"),
126 | (nn.LeakyReLU(), "x -> x"),
127 |
128 | (nn.Dropout(drop_rate), "x -> x"),
129 | ])
130 | else:
131 | self.model1 = Sequential("x, edge_index", [
132 | (GCNConv(in_channels, out_channels), "x, edge_index -> x"),
133 | (MeshUnpool(unpool_hash), "x -> x"),
134 | (nn.BatchNorm1d(out_channels), "x -> x"),
135 | (nn.LeakyReLU(), "x -> x"),
136 | ])
137 | self.model2 = Sequential("x, edge_index", [
138 | (GCNConv(out_channels, out_channels), "x, edge_index -> x"),
139 | (nn.BatchNorm1d(out_channels), "x -> x"),
140 | (nn.LeakyReLU(), "x -> x"),
141 |
142 | (GCNConv(out_channels, out_channels), "x, edge_index -> x"),
143 | (nn.BatchNorm1d(out_channels), "x -> x"),
144 | (nn.LeakyReLU(), "x -> x"),
145 |
146 | (GCNConv(out_channels, out_channels), "x, edge_index -> x"),
147 | (nn.BatchNorm1d(out_channels), "x -> x"),
148 | (nn.LeakyReLU(), "x -> x"),
149 |
150 | (GCNConv(out_channels, out_channels), "x, edge_index -> x"),
151 | (nn.BatchNorm1d(out_channels), "x -> x"),
152 | (nn.LeakyReLU(), "x -> x"),
153 |
154 | (nn.Dropout(drop_rate), "x -> x"),
155 | ])
156 |
157 | def forward(self, input):
158 | out = self.model1(input, self.edge_index1)
159 | out = self.model2(out, self.edge_index2)
160 | return out
161 |
162 |
163 | class MGCN(nn.Module):
164 | def __init__(self, device, smo_mesh, ini_mesh, v_mask, K=3, skip=False):
165 | super(MGCN, self).__init__()
166 | self.device = device
167 | self.skip = skip
168 |
169 | pool_levels = 3
170 | nv = len(smo_mesh.vs)
171 | nvs = [int(nv*(0.6**i)) for i in range(1, pool_levels+1)]
172 | self.nvs = nvs
173 | meshes = [smo_mesh]
174 | p_hashes = []
175 | up_hashes = []
176 | edge_inds = [ini_mesh.edge_index.to(device)]
177 | v_masks_list = [v_mask.reshape(-1, 1).float()]
178 | v_masks = v_mask.reshape(-1, 1).float()
179 | f_masks = (torch.sparse.mm(ini_mesh.f2v_mat, 1.0 - v_mask.reshape(-1,1).float()) == 0).bool().reshape(-1)
180 | f_masks_list = [f_masks]
181 | os.makedirs("{}/pooled".format(os.path.dirname(smo_mesh.path)), exist_ok=True)
182 | for i in range(pool_levels):
183 | # s_mesh = meshes[i].edge_based_simplification(target_v=nvs[i])
184 | s_mesh = meshes[i].simplification(target_v=nvs[i])
185 | meshes.append(s_mesh)
186 | s_pmask = self.pool_hash_to_mask(s_mesh.pool_hash)
187 | s_umask = self.unpool_hash_to_mask(s_mesh.pool_hash)
188 | p_hashes.append(s_pmask.to(device))
189 | up_hashes.append(s_umask.to(device))
190 | edge_inds.append(s_mesh.edge_index.to(device))
191 | vm_i = v_masks_list[-1]
192 | vm_i_inv = torch.logical_not(vm_i).float()
193 | vm_j = torch.logical_not(torch.sparse.mm(s_pmask, vm_i_inv)).float()
194 | v_masks_list.append(vm_j)
195 | v_masks = torch.cat([v_masks, vm_j], dim=0)
196 | fm = (torch.sparse.mm(s_mesh.f2v_mat, 1.0 - vm_j.reshape(-1,1)) == 0).bool().reshape(-1)
197 | f_masks = torch.cat([f_masks, fm], dim=0)
198 | f_masks_list.append(fm)
199 | color = np.ones([len(s_mesh.vs), 3])
200 | s_mesh.vc = color * vm_j.reshape(-1, 1).detach().numpy().copy()
201 | s_mesh.save("{}/pooled/{}_vs.obj".format(os.path.dirname(smo_mesh.path), len(s_mesh.vs)), color=True)
202 |
203 | self.meshes = meshes
204 | self.p_hashes = p_hashes
205 | self.up_hashes = up_hashes
206 | self.edge_inds = edge_inds
207 | self.v_masks = v_masks
208 | self.v_masks_list = v_masks_list
209 | self.f_masks = f_masks
210 | self.f_masks_list = f_masks_list
211 |
212 | self.encoder1 = DownConv(4, 32, edge_inds[0], edge_inds[1], p_hashes[0], K=K, drop_rate=0.0)
213 | self.encoder2 = DownConv(32, 128, edge_inds[1], edge_inds[2], p_hashes[1], K=K, drop_rate=0.2)
214 | self.encoder3 = DownConv(128, 256, edge_inds[2], edge_inds[3], p_hashes[2], K=K, drop_rate=0.2)
215 |
216 | self.decoder3 = UpConv(256, 128, edge_inds[3], edge_inds[2], up_hashes[2], K=K, drop_rate=0.2)
217 | self.decoder2 = UpConv(128, 32, edge_inds[2], edge_inds[1], up_hashes[1], K=K, drop_rate=0.2)
218 | self.decoder1 = nn.Sequential(
219 | UpConv(32, 16, edge_inds[1], edge_inds[0], up_hashes[0], K=K, drop_rate=0.0),
220 | nn.Linear(16, 3),
221 | )
222 |
223 | self.mcnn3 = Sequential("x, edge_index", [
224 | (ChebConv(256, 32, K=K), "x, edge_index -> x"),
225 | # (GCNConv(256, 32), "x, edge_index -> x"),
226 | (nn.BatchNorm1d(32), "x -> x"),
227 | (nn.LeakyReLU(), "x -> x"),
228 | (nn.Linear(32, 3), "x -> x"),
229 | ])
230 |
231 | self.mcnn2 = Sequential("x, edge_index", [
232 | (ChebConv(128, 32, K=K), "x, edge_index -> x"),
233 | # (GCNConv(128, 32), "x, edge_index -> x"),
234 | (nn.BatchNorm1d(32), "x -> x"),
235 | (nn.LeakyReLU(), "x -> x"),
236 | (nn.Linear(32, 3), "x -> x"),
237 | ])
238 |
239 | self.mcnn1 = Sequential("x, edge_index", [
240 | (ChebConv(32, 32, K=K), "x, edge_index -> x"),
241 | # (GCNConv(32, 32, K=K), "x, edge_index -> x"),
242 | (nn.BatchNorm1d(32), "x -> x"),
243 | (nn.LeakyReLU(), "x -> x"),
244 | (nn.Linear(32, 3), "x -> x"),
245 | ])
246 |
247 | self.skip2 = nn.Linear(256, 128)
248 | self.skip1 = nn.Linear(64, 32)
249 |
250 | """"""
251 | org_pos = torch.from_numpy(ini_mesh.vs).float().to(device)
252 | org_smpos = torch.from_numpy(smo_mesh.vs).float().to(device)
253 | self.poss = org_pos
254 | self.poss_list = [org_pos]
255 | self.smposs = org_smpos
256 | self.smposs_list = [org_smpos]
257 | for l in range(pool_levels):
258 | simp_mesh = copy.deepcopy(meshes[l+1])
259 | pool = MeshPool(self.p_hashes[l])
260 | pos = pool(org_pos)
261 | simp_mesh.vs = pos.detach().to("cpu").numpy().copy()
262 | color = np.ones([len(simp_mesh.faces), 3])
263 | color[:, 0] = 1.0
264 | color[:, 1] = 0
265 | color[:, 2] = 1.0
266 | color[self.f_masks_list[l+1].detach().numpy().copy(), 0] = 0.332
267 | color[self.f_masks_list[l+1].detach().numpy().copy(), 1] = 0.664
268 | color[self.f_masks_list[l+1].detach().numpy().copy(), 2] = 1.0
269 | # color = color * self.f_masks_list[l+1].reshape(-1, 1).detach().numpy().copy()
270 | simp_mesh.save_as_ply("{}/pooled/ini_{}_vs.ply".format(os.path.dirname(simp_mesh.path), len(simp_mesh.vs)), color)
271 | sm_pos = torch.from_numpy(meshes[l+1].vs).float().to(device)
272 | self.poss = torch.cat([self.poss, pos], dim=0)
273 | self.poss_list.append(pos)
274 | self.smposs = torch.cat([self.smposs, sm_pos], dim=0)
275 | self.smposs_list.append(sm_pos)
276 | org_pos = pos
277 |
278 | def forward(self, data, dm=None):
279 |
280 | z1, _ = data.z1.to(self.device), data.x_pos.to(self.device)
281 |
282 | z_min, z_max = torch.min(z1, dim=0, keepdim=True)[0], torch.max(z1, dim=0, keepdim=True)[0]
283 | z_sc = torch.max(z_max - z_min)
284 | zc = (z_min + z_max) * 0.5
285 | z1 = (z1 - zc) / z_sc
286 |
287 | if type(dm) != np.ndarray:
288 | dm = torch.ones([z1.shape[0], 1])
289 | else:
290 | dm = torch.from_numpy(dm)
291 | dm = dm.to(self.device)
292 | z1[:, 0:3] = dm * z1[:, 0:3]
293 | z1 = torch.cat([z1, dm], dim=1)
294 |
295 | res1_enc = self.encoder1(z1)
296 | res2_enc = self.encoder2(res1_enc)
297 | res3_bot = self.encoder3(res2_enc)
298 | out3 = self.mcnn3(res3_bot, self.edge_inds[3])
299 |
300 | res2_dec = self.decoder3(res3_bot)
301 | if self.skip:
302 | res2_cat = torch.cat([res2_dec, res2_enc], dim=1)
303 | res2_dec = self.skip2(res2_cat)
304 | out2 = self.mcnn2(res2_dec, self.edge_inds[2])
305 |
306 | res1_dec = self.decoder2(res2_dec)
307 | if self.skip:
308 | res1_cat = torch.cat([res1_dec, res1_enc], dim=1)
309 | res1_dec = self.skip1(res1_cat)
310 | out1 = self.mcnn1(res1_dec, self.edge_inds[1])
311 |
312 | out0 = self.decoder1(res1_dec)
313 |
314 | pos0 = self.smposs_list[0] + out0
315 | pos1 = self.smposs_list[1] + out1
316 | pos2 = self.smposs_list[2] + out2
317 | pos3 = self.smposs_list[3] + out3
318 | return (pos0, pos1, pos2, pos3)
319 |
320 | def pool(self, dx, pool_hash):
321 | pool_hash = pool_hash.to(dx.device)
322 | v_sum = torch.sum(pool_hash.to_dense(), dim=1, keepdim=True)
323 | dx = torch.sparse.mm(pool_hash, dx) / v_sum
324 | return dx
325 |
326 | def unpool(self, dx, unpool_hash):
327 | unpool_hash = unpool_hash.to(dx.device)
328 | dx = torch.sparse.mm(unpool_hash, dx)
329 | return dx
330 |
331 | def pool_hash_to_mask(self, pool_hash):
332 | mask_ind = torch.stack([torch.tensor(pool_hash)[:, 1], torch.tensor(pool_hash)[:, 0]], dim=0)
333 | mask_val = torch.ones(mask_ind.shape[1]).float()
334 | mask_mat = torch.sparse.FloatTensor(mask_ind, mask_val)
335 | return mask_mat
336 |
337 | def unpool_hash_to_mask(self, pool_hash):
338 | mask_ind = torch.tensor(pool_hash).T
339 | mask_val = torch.ones(mask_ind.shape[1]).float()
340 | mask_mat = torch.sparse.FloatTensor(mask_ind, mask_val)
341 | return mask_mat
--------------------------------------------------------------------------------
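A rough construction sketch for the multi-scale network above. The paths and the vertex mask are assumptions; the Data object consumed by forward() (fields z1 and x_pos) is normally built by util/datamaker.py, which is not shown in this excerpt:

```
import torch
from util.mesh import Mesh
from util.meshnet import MGCN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ini_mesh = Mesh("datasets/example/example_original.obj")  # hypothetical deficient mesh
smo_mesh = Mesh("datasets/example/example_smooth.obj")    # hypothetical smoothed copy
v_mask = torch.ones(len(ini_mesh.vs), dtype=torch.bool)   # True = vertex position is trusted

# Building MGCN also builds the three-level pooling hierarchy via mesh simplification
# and writes the pooled meshes next to the smoothed mesh.
model = MGCN(device, smo_mesh, ini_mesh, v_mask, K=3, skip=True).to(device)
# pos0, pos1, pos2, pos3 = model(data)  # data comes from util/datamaker.py
```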
/util/models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from functools import reduce
5 | from collections import Counter
6 | from typing import Union
7 | import scipy as sp
8 | from sklearn.preprocessing import normalize
9 | from util.mesh import Mesh
10 |
11 | def build_div(n_mesh, vn):
12 | vs = n_mesh.vs
13 | faces = n_mesh.faces
14 | fn = n_mesh.fn
15 | fa = n_mesh.fa
16 | vf = n_mesh.vf
17 | grad_b = [[] for _ in range(len(vs))]
18 | for i, v in enumerate(vf): # vf: triangle indices around vertex i
19 | for t in v: # for each triangle index t
20 | f = faces[t] # vertex indices in t
21 | f_n = fn[t] # face normal of t
22 | a = fa[t] # face area of t
23 | """sort vertex indices"""
24 | if f[1] == i:
25 | f = [f[1], f[2], f[0]]
26 | elif f[2] == i:
27 | f = [f[2], f[1], f[0]]
28 | x_kj = vs[f[2]] - vs[f[1]]
29 | x_kj = f_n * np.dot(x_kj, f_n) + np.cross(f_n, x_kj)
30 | x_kj /= 2
31 | grad_b[i].append(x_kj.tolist())
32 |
33 | div_v = [[] for _ in range(len(vn))]
34 | for i in range(len(vn)):
35 | tn = vn[i]
36 | g = np.array(grad_b[i])
37 | div_v[i] = np.dot(np.sum(g, 0), tn)
38 | div = np.array(div_v)
39 | div = np.tile(div, 3).reshape(3, -1).T
40 | return div
41 |
42 | def jacobi(n_mesh, div, iter=10):
43 | # preparation
44 | C = n_mesh.mesh_lap.to_dense()
45 | #C = n_mesh.lapmat.to_dense()
46 | B = div
47 | # boundary condition
48 | C_add = torch.eye(len(n_mesh.vs))
49 | C = torch.cat([C, C_add], dim=0)
50 | B_add = torch.from_numpy(n_mesh.vs)
51 | B = torch.cat([torch.from_numpy(B), B_add], dim=0)
52 | B = torch.matmul(C.T.float(), B.float())
53 | A = torch.matmul(C.T, C)
54 | # solve Ax=b by jacobi
55 | x = torch.from_numpy(n_mesh.vs).float()
56 |
57 | for i in range(iter):
58 | r = B - torch.matmul(A, x)
59 | alpha = torch.diagonal(torch.matmul(r.T, r)) / (torch.diagonal(torch.matmul(torch.matmul(r.T, A), r)) + 1e-12)
60 | x += alpha * r
61 |
62 | return x
63 |
64 | def poisson_mesh_edit(n_mesh, div):
65 | C = n_mesh.mesh_lap.to_dense()
66 | B = div
67 | # boundary condition
68 | C_add = torch.eye(len(n_mesh.vs))
69 | C = torch.cat([C, C_add], dim=0)
70 | B_add = torch.from_numpy(n_mesh.vs)
71 | B = torch.cat([torch.from_numpy(B), B_add], dim=0)
72 | A = torch.matmul(C.T, C)
73 | Ainv = torch.inverse(A)
74 | CtB = torch.matmul(C.T.float(), B.float())
75 | new_vs = torch.matmul(Ainv, CtB)
76 |
77 | return new_vs
78 |
79 | def cg(n_mesh, div, iter=50, a=100):
80 | # preparation
81 | #C = n_mesh.mesh_lap.to_dense()
82 | #C = n_mesh.lapmat.to_dense()
83 | C = n_mesh.cot_mat.to_dense()
84 | B = div
85 | # boundary condition
86 | C_add = torch.eye(len(n_mesh.vs)) * a
87 | C = torch.cat([C, C_add], dim=0)
88 | Ct = C.T.to_sparse()
89 | C = C.to_sparse()
90 | B_add = torch.from_numpy(n_mesh.vs) * a
91 | B = torch.cat([torch.from_numpy(B), B_add], dim=0)
92 | B = torch.sparse.mm(Ct.float(), B.float())
93 | A = torch.sparse.mm(Ct, C.to_dense())
94 | # solve Ax=b by cg
95 | """
96 | x_0 = torch.from_numpy(n_mesh.vs).float()
97 | r_0 = B - torch.matmul(A, x_0)
98 | p_0 = r_0
99 | for i in range(iter):
100 | y_0 = torch.matmul(A, p_0)
101 | alpha = torch.diagonal(torch.matmul(r_0.T, r_0)) / (torch.diagonal(torch.matmul(p_0.T, y_0) + 1e-12))
102 | x_1 = x_0 + alpha * p_0
103 | r_1 = r_0 - alpha * y_0
104 | if torch.sum(torch.norm(r_1, dim=0), dim=0) < 1e-4:
105 | break
106 | beta = torch.diagonal(torch.matmul(r_1.T, r_1)) / (torch.diagonal(torch.matmul(r_0.T, r_0)) + 1e-12)
107 | p_1 = r_1 + beta * p_0
108 |
109 | x_0 = x_1
110 | r_0 = r_1
111 | p_0 = p_1
112 | return x_1
113 | """
114 | A = A.detach().numpy().copy()
115 | x_0 = torch.from_numpy(n_mesh.vs).float()
116 | x_1 = []
117 | for i in range(3):
118 | x_1.append(sp.sparse.linalg.cg(A, B[:,i], x0=x_0[:,i], maxiter=iter)[0].tolist())
119 | return np.array(x_1).T
120 |
121 | def compute_fn(vs: torch.Tensor, faces: np.ndarray) -> torch.Tensor:
122 | """ compute face normals from mesh with Tensor """
123 | face_normals = torch.cross(vs[faces[:, 1]] - vs[faces[:, 0]], vs[faces[:, 2]] - vs[faces[:, 0]])
124 | norm = torch.sqrt(torch.sum(face_normals**2, dim=1))
125 | face_normals = face_normals / norm.repeat(3, 1).T
126 | return face_normals
127 |
128 | def compute_vn(vs: torch.Tensor, fn: torch.Tensor, faces: np.ndarray) -> torch.Tensor:
129 | """ compute vertex normals from mesh with Tensor"""
130 | vert_normals = torch.zeros((3, len(vs)))
131 | face_normals = fn
132 | faces = torch.from_numpy(faces).long().to(vs.device)
133 |
134 | nv = len(vs)
135 | nf = len(faces)
136 | mat_rows = torch.reshape(faces, (-1,)).to(vs.device)
137 | mat_cols = torch.tensor([[i] * 3 for i in range(nf)]).reshape(-1).to(vs.device)
138 | mat_vals = torch.ones(len(mat_rows)).to(vs.device)
139 | f2v_mat = torch.sparse.FloatTensor(torch.stack([mat_rows, mat_cols], dim=0),
140 | mat_vals,
141 | size=torch.Size([nv, nf]))
142 | vert_normals = torch.sparse.mm(f2v_mat, face_normals)
143 | norm = torch.sqrt(torch.sum(vert_normals**2, dim=1))
144 | vert_normals = vert_normals / norm.repeat(3, 1).T
145 | return vert_normals
146 |
147 | def uv2xyz(uv):
148 | u = 2.0 * np.pi * uv[:, 0] - np.pi
149 | v = np.pi * uv[:, 1]
150 | x = torch.sin(v) * torch.cos(u)
151 | y = torch.sin(v) * torch.sin(u)
152 | z = torch.cos(v)
153 | xyz = torch.stack([x, y, z]).T
154 | return xyz
155 |
156 | def compute_nvt(mesh: Mesh, alpha=0.2, beta=0.2, delta=0.3) -> np.ndarray:
157 | f2ring = mesh.f2ring
158 | fa = mesh.fa
159 | fn = mesh.fn
160 | fc = mesh.fc
161 |
162 | #f_group = np.zeros([len(fn), 3])
163 | fec_strength = np.zeros([len(fn), 3])
164 |
165 | for i, f in enumerate(f2ring):
166 | ci = fc[i].reshape(1, -1)
167 | cj = fc[f]
168 | nj = fn[f]
169 | """ (a cross b) cross a = (a dot a)b - (b dot a)a """
170 | a_a = np.sum((cj - ci) ** 2, 1).reshape(-1, 1)
171 | b_a = np.sum((cj - ci) * nj, 1).reshape(-1, 1)
172 | wj = a_a * nj - b_a * (cj - ci)
173 | wj = normalize(wj, norm="l2", axis=1)
174 |
175 | nw = np.sum(nj * wj, 1).reshape(-1, 1)
176 | nj_prime = 2 * nw * wj - nj
177 |
178 | am = np.max(fa[f]) + 1.0e-12
179 | aj = fa[f]
180 | cji_norm = np.linalg.norm(cj - ci, axis=1)
181 | sigma = np.mean(cji_norm) + 1.0e-12
182 | mu = (aj / am * np.exp(-cji_norm / sigma)).reshape(-1, 1)
183 | Ti = np.matmul(nj_prime.T, (nj_prime * mu))
184 | order = np.argsort(np.linalg.eig(Ti)[0])[::-1]
185 | e_vals = np.linalg.eig(Ti)[0][order]
186 | e_vecs = np.linalg.eig(Ti)[1][:, order]
187 | n_ave = np.sum(mu * nj_prime, 0)
188 |
189 | fec_strength[i][0] = (e_vals[0] - e_vals[1]) / np.sum(e_vals)
190 | fec_strength[i][1] = (e_vals[1] - e_vals[2]) / np.sum(e_vals)
191 | fec_strength[i][2] = e_vals[2] / np.sum(e_vals)
192 |
193 | """ create face group
194 | if len(e_vals) != 3:
195 | print("len(e_vals) < 3 !")
196 | elif e_vals[1] < 0.01 and e_vals[2] < 0.001:
197 | f_group[i] = np.array([-1, -1, 1])
198 | elif e_vals[1] > 0.01 and e_vals[2] < 0.1:
199 | f_group[i] = np.array([0, 1, -1])
200 | elif e_vals[2] > 0.1:
201 | f_group[i] = np.array([-1, 1, -1])
202 | else:
203 | f_group[i] = np.array([-1, 0, 1])
204 | """
205 |
206 | return fec_strength
207 |
208 |
209 | def bnf(pos: torch.Tensor, fn: torch.Tensor, mesh: Mesh, loop=1) -> torch.Tensor:
210 | """ bilateral loss for face normal """
211 | fc = torch.sum(pos[mesh.faces], 1) / 3.0
212 | fa = torch.cross(pos[mesh.faces[:, 1]] - pos[mesh.faces[:, 0]], pos[mesh.faces[:, 2]] - pos[mesh.faces[:, 0]], dim=1)
213 | fa = 0.5 * torch.sqrt(torch.sum(fa**2, dim=1) + 1.0e-12)
214 |
215 | f2f = torch.from_numpy(mesh.f2f).long().to(pos.device)
216 | no_neig = 1.0 * (f2f != -1)
217 |
218 | neig_fc = fc[f2f]
219 | neig_fa = fa[f2f] * no_neig
220 | fc0_tile = fc.reshape(-1, 1, 3)
221 | fc_dist = torch.sum((neig_fc-fc0_tile)**2, dim=2)
222 | sigma_c = torch.sum(torch.sqrt(fc_dist + 1.0e-12)) / (fc_dist.shape[0] * fc_dist.shape[1])
223 |
224 | new_fn = fn
225 | for i in range(loop):
226 | neig_fn = new_fn[f2f]
227 | fn0_tile = new_fn.reshape(-1, 1, 3)
228 | fn_dist = torch.sum((neig_fn-fn0_tile)**2, dim=2)
229 | sigma_s = 0.3
230 | wc = torch.exp(-1.0 * fc_dist / (2 * (sigma_c ** 2)))
231 | ws = torch.exp(-1.0 * fn_dist / (2 * (sigma_s ** 2)))
232 |
233 | W = torch.stack([wc*ws*neig_fa, wc*ws*neig_fa, wc*ws*neig_fa], dim=2)
234 |
235 | new_fn = torch.sum(W * neig_fn, dim=1)
236 | new_fn = new_fn / (torch.norm(new_fn, dim=1, keepdim=True) + 1.0e-12)
237 | return new_fn
238 |
239 | def vertex_updating(pos: torch.Tensor, norm: torch.Tensor, mesh: Mesh, loop=10) -> torch.Tensor:
240 | new_pos = pos.detach().clone()
241 | norm = norm.detach().clone()
242 | for iter in range(loop):
243 | fc = torch.sum(new_pos[mesh.faces], 1) / 3.0
244 | for i in range(len(new_pos)):
245 | cis = fc[list(mesh.vf[i])]
246 | nis = norm[list(mesh.vf[i])]
247 | cvis = cis - new_pos[i].reshape(1, -1)
248 | ncvis = torch.sum(nis * cvis, dim=1)
249 | dvi = torch.sum(ncvis.reshape(-1, 1) * nis, dim=0)
250 | dvi /= len(mesh.vf[i])
251 | new_pos[i] += dvi
252 | return new_pos
253 |
254 | def random_rotate(pos: Union[torch.Tensor, np.ndarray], seed=None) -> torch.Tensor:
255 | if type(pos) == np.ndarray:
256 | pos = torch.from_numpy(pos)
257 | if seed is not None:
258 | np.random.seed(seed)
259 | rx, ry, rz = np.random.uniform(size=[3]) * 2 * np.pi
260 | Rx = np.array([[1, 0, 0],
261 | [0, np.cos(rx), -np.sin(rx)],
262 | [0, np.sin(rx), np.cos(rx)]])
263 | Ry = np.array([[np.cos(ry), 0, np.sin(ry)],
264 | [0, 1, 0],
265 | [-np.sin(ry), 0, np.cos(ry)]])
266 | Rz = np.array([[np.cos(rz), -np.sin(rz), 0],
267 | [np.sin(rz), np.cos(rz), 0],
268 | [0, 0, 1]])
269 | Rm = np.dot(Rz, np.dot(Ry, Rx))
270 | Rm = torch.from_numpy(Rm).float()
271 | new_pos = torch.mm(Rm, pos.T.float()).T
272 | return new_pos
273 |
--------------------------------------------------------------------------------
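The face-normal utilities above form a small bilateral-filtering pipeline: `compute_fn` estimates per-face normals, `bnf` smooths them with bilateral weights on face-centroid distance and normal difference, and `vertex_updating` moves vertices so the faces agree with the filtered normals. Below is a minimal sketch of that flow, assuming `compute_fn`, `bnf`, and `vertex_updating` from the utility module above are in scope; the `ToyMesh` container and the tetrahedron data are illustrative stand-ins that only provide the `faces`, `f2f`, and `vf` attributes these functions actually read (the project's own `Mesh` class supplies them for real meshes).

```
import numpy as np
import torch

class ToyMesh:
    """Illustrative stand-in exposing only faces / f2f / vf (not the project's Mesh class)."""
    def __init__(self, vs, faces):
        self.vs, self.faces = vs, faces
        # face-to-face adjacency over shared edges, padded with -1
        edge2face = {}
        for fi, f in enumerate(faces):
            for e in [(f[0], f[1]), (f[1], f[2]), (f[2], f[0])]:
                edge2face.setdefault(tuple(sorted(e)), []).append(fi)
        self.f2f = -np.ones((len(faces), 3), dtype=np.int64)
        for fi, f in enumerate(faces):
            neigh = [fj for e in [(f[0], f[1]), (f[1], f[2]), (f[2], f[0])]
                     for fj in edge2face[tuple(sorted(e))] if fj != fi]
            self.f2f[fi, :len(neigh)] = neigh[:3]
        # vertex -> set of incident face indices
        self.vf = [set(np.where(faces == v)[0]) for v in range(len(vs))]

# unit tetrahedron as toy input
vs = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float64)
faces = np.array([[0, 2, 1], [0, 1, 3], [1, 2, 3], [0, 3, 2]], dtype=np.int64)
mesh = ToyMesh(vs, faces)

pos = torch.from_numpy(vs).float()
fn = compute_fn(pos, faces)                 # raw face normals
fn_filtered = bnf(pos, fn, mesh, loop=3)    # bilateral normal filtering
new_pos = vertex_updating(pos, fn_filtered, mesh, loop=10)
```
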
/util/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from torch_geometric.nn import GCNConv, ChebConv, Sequential
5 |
6 |
7 |
8 | class SingleScaleGCN(nn.Module):
9 | def __init__(self, device, activation="lrelu", skip=False):
10 | super(SingleScaleGCN, self).__init__()
11 | self.device = device
12 | self.skip = skip
13 | conv = "chebconv"
14 |
15 | h = [4, 16, 32, 64, 128, 256, 256, 512, 256, 256, 128, 64, 32, 16, 3]
16 |
17 | activ_dict = {"relu": nn.ReLU(), "lrelu": nn.LeakyReLU()}
18 | activation_func = activ_dict[activation]
19 |
20 | blocks = []
21 |
22 | if conv == "gcnconv":
23 | for i in range(12):
24 | block = Sequential("x, edge_index", [
25 | (GCNConv(h[i], h[i+1]), "x, edge_index -> x"),
26 | nn.BatchNorm1d(h[i+1]),
27 | activation_func,
28 | ])
29 | blocks.append(block)
30 |
31 | block = Sequential("x, edge_index", [
32 | (GCNConv(h[12], h[13]), "x, edge_index -> x"),
33 | nn.BatchNorm1d(h[13]),
34 | activation_func,
35 | (nn.Linear(h[13], h[14]), "x -> x"),
36 | ])
37 | blocks.append(block)
38 |
39 | elif conv == "chebconv":
40 | for i in range(12):
41 | block = Sequential("x, edge_index", [
42 | (ChebConv(h[i], h[i+1], K=3), "x, edge_index -> x"),
43 | nn.BatchNorm1d(h[i+1]),
44 | activation_func,
45 | ])
46 | blocks.append(block)
47 |
48 | block = Sequential("x, edge_index", [
49 | (ChebConv(h[12], h[13], K=3), "x, edge_index -> x"),
50 | nn.BatchNorm1d(h[13]),
51 | activation_func,
52 | (nn.Linear(h[13], h[14]), "x -> x"),
53 | ])
54 | blocks.append(block)
55 |
56 | self.blocks = nn.ModuleList(blocks)
57 |
58 | skip_blocks = []
59 | for i in range(6):
60 | skip_blocks.append(nn.Linear(h[i+1]*2, h[i+1]))
61 | self.skip_blocks = nn.ModuleList(skip_blocks)
62 |
63 | def forward(self, data, dm=None):
64 |
65 | z1, x_pos, edge_index = data.z1.to(self.device), data.x_pos.to(self.device), data.edge_index.to(self.device)
66 |
67 | z_min, z_max = torch.min(z1, dim=0, keepdim=True)[0], torch.max(z1, dim=0, keepdim=True)[0]
68 | z_sc = torch.max(z_max - z_min)
69 | zc = (z_min + z_max) * 0.5
70 | z1 = (z1 - zc) / z_sc
71 |
72 | if type(dm) == np.ndarray:
73 | dm = torch.from_numpy(dm)
74 | elif type(dm) != torch.Tensor:
75 | dm = torch.ones([z1.shape[0], 1])
76 |
77 | dm = dm.to(self.device)
78 | z1 = dm * z1
79 | z1 = torch.cat([z1, dm], dim=1)
80 |
81 | x = z1
82 | skip_in = []
83 | for i, b in enumerate(self.blocks):
84 | if i <= 5:
85 | """ encoder """
86 | y = b(x, edge_index)
87 | x = y
88 | skip_in.append(y)
89 |
90 | elif i <= 7:
91 | """ bottle-neck """
92 | y = b(x, edge_index)
93 | x = y
94 | else:
95 | """ decoder """
96 | if self.skip:
97 | x_src = skip_in[13-i]
98 | x_cat = torch.cat([x_src, x], dim=1)
99 | x = self.skip_blocks[13-i](x_cat)
100 | y = b(x, edge_index)
101 | x = y
102 |
103 | return x_pos + x
104 |
--------------------------------------------------------------------------------
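For reference, `SingleScaleGCN.forward` expects a `torch_geometric.data.Data` object carrying `z1` (the per-vertex input signal), `x_pos` (the positions the predicted displacement is added to), and `edge_index`, plus an optional per-vertex weight `dm` (all ones, the default, disables masking). The smoke test below uses random data purely to illustrate the expected shapes; it is not part of the training pipeline.

```
import torch
from torch_geometric.data import Data
from util.networks import SingleScaleGCN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

n_verts = 100
pos = torch.rand(n_verts, 3)                      # vertex positions
edge_index = torch.randint(0, n_verts, (2, 500))  # random connectivity, illustration only
data = Data(z1=pos.clone(), x_pos=pos, edge_index=edge_index)

model = SingleScaleGCN(device, activation="lrelu", skip=True).to(device)
dm = torch.ones(n_verts, 1)                       # per-vertex weight; all ones = no masking
out = model(data, dm=dm)                          # -> (n_verts, 3) displaced positions
print(out.shape)
```
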
/util/render.py:
--------------------------------------------------------------------------------
1 | import pyvista as pv
2 | import numpy as np
3 | from PIL import Image
4 |
5 |
6 | class Render():
7 | def __init__(self, mesh, gif_name):
8 | self.normalize_to_bb(mesh)
9 | self.pyv_mesh = self.load_mesh(mesh)
10 | self.save_gif(gif_name)
11 |
12 | def load_mesh(self, mesh):
13 | vs = mesh.vs
14 | faces = np.ones([len(mesh.faces), 1]) * 3
15 | faces = np.concatenate([faces, mesh.faces], axis=1).astype(np.int32)
16 | pyv_mesh = pv.PolyData(vs, faces)
17 | return pyv_mesh
18 |
19 | def normalize_to_bb(self, mesh):
20 | vs = mesh.vs
21 | maxs = np.max(vs, axis=0, keepdims=True)
22 | mins = np.min(vs, axis=0, keepdims=True)
23 | ranges = maxs - mins
24 | vs = (vs - mins) / ranges
25 | mesh.vs = vs - 0.5
26 |
27 | def save_img(self, gif_name):
28 | plotter = pv.Plotter(off_screen=True, notebook=False)
29 | plotter.add_mesh(self.pyv_mesh, color="orange")
30 | plotter.show(screenshot=gif_name)
31 |
32 | def save_gif(self, gif_name):
33 | figures = []
34 | plotter = pv.Plotter(off_screen=True, notebook=False, window_size=[400,300])
35 | plotter.set_focus([0,0,0])
36 | plotter.set_position([2,2,2])
37 | for i in range(60):
38 | rot = self.pyv_mesh.rotate_y(6 * i, inplace=False)
39 | plotter = pv.Plotter(off_screen=True, notebook=False, window_size=[400,300])
40 | plotter.set_focus([0,0,0])
41 | plotter.set_position([2,2,2])
42 | plotter.add_mesh(rot, color="orange")
43 | img = plotter.screenshot(filename=None, return_img=True)
44 | figures.append(Image.fromarray(img))
45 | figures[0].save(gif_name, save_all=True, append_images=figures[1:], optimize=True, duration=50, loop=0)
--------------------------------------------------------------------------------
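
`Render` only reads `vs` and `faces` from the mesh it is given, normalizes it to a unit bounding box, and writes a 60-frame turntable GIF on construction. A minimal sketch with a duck-typed stand-in mesh (not the project's `Mesh` class); off-screen rendering must be available:

```
import numpy as np
from types import SimpleNamespace
from util.render import Render

# Any object with .vs of shape (V, 3) and .faces of shape (F, 3) works; a unit tetrahedron here.
vs = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float64)
faces = np.array([[0, 2, 1], [0, 1, 3], [1, 2, 3], [0, 3, 2]], dtype=np.int64)
mesh = SimpleNamespace(vs=vs, faces=faces)

Render(mesh, "tetra_turntable.gif")  # writes a rotating preview GIF
```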