├── assets
│   └── teaser.gif
├── faceformer
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── data.py
│   │   └── data_para.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── model.py
│   │   └── model_para.py
│   ├── utils.py
│   ├── post_processing.py
│   ├── config.py
│   ├── embedding.py
│   ├── transformer.py
│   └── trainer.py
├── configs
│   ├── seq2seq.yml
│   ├── seq2seq+coedge.yml
│   ├── ours.yml
│   ├── ours-perspective.yml
│   └── ours-fixed_viewpoint.yml
├── CITATION.cff
├── environment.yml
├── LICENSE
├── split_jsons.py
├── dataset
│   ├── filters
│   │   ├── filter_length.py
│   │   ├── filter_thinness.py
│   │   ├── filter_topology.py
│   │   ├── filter_3view.py
│   │   ├── filter_thickness.py
│   │   └── 3view_render.py
│   ├── reorganize_dataset_dirs.py
│   ├── utils
│   │   ├── read_step_file.py
│   │   ├── Edge.py
│   │   ├── discretize_edge.py
│   │   ├── projection_utils.py
│   │   ├── Face.py
│   │   ├── json_to_svg.py
│   │   └── TopoMapper.py
│   ├── README.md
│   ├── tests
│   │   └── check_faces_enclosed.py
│   └── prepare_data.py
├── main.py
├── README.md
└── reconstruction
    ├── utils.py
    ├── reconstruction_utils.py
    └── reconstruct_to_wireframe.py
/assets/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/manycore-research/faceformer/HEAD/assets/teaser.gif
--------------------------------------------------------------------------------
/faceformer/datasets/__init__.py:
--------------------------------------------------------------------------------
from .data import ABCDataset
from .data_para import ABCDataset_Parallel
--------------------------------------------------------------------------------
/faceformer/models/__init__.py:
--------------------------------------------------------------------------------
from .model import SurfaceFormer
from .model_para import SurfaceFormer_Parallel
--------------------------------------------------------------------------------
/configs/seq2seq.yml:
--------------------------------------------------------------------------------
# All faces as a seq, cylinder face cut into 2

# dataset generation method:
# - prepare_data.py ... --combine_coedge --order_by_position --random_camera --focus 0

root_dir: "test_set/seq2seq"
trainer:
  name: 'SurfaceFormer'
  version: 'seq2seq'
  num_gpus: [0]

model:
  num_lines: 110
  label_seq_length: 259
  max_num_faces: 42

post_process:
  is_coedge: False
--------------------------------------------------------------------------------
/configs/seq2seq+coedge.yml:
--------------------------------------------------------------------------------
# All faces as a seq, cylinder face cut into 2,
# ordered edges within each face, each edge considered twice.

# dataset generation method:
# - prepare_data.py ... --no_face_type --random_camera --focus 0

root_dir: "test_set/seq2seq+coedge"
trainer:
  name: 'SurfaceFormer'
  version: 'seq2seq+coedge'
  num_gpus: [0]

model:
  num_lines: 216
  label_seq_length: 259

post_process:
  is_coedge: True
--------------------------------------------------------------------------------
/configs/ours.yml:
--------------------------------------------------------------------------------
# Each face as a seq

# dataset generation method:
# python dataset/prepare_data.py ... --random_camera --focus 0

model_class: 'SurfaceFormer_Parallel'
dataset_class: 'ABCDataset_Parallel'
root_dir: "test_set/ours"

batch_size_train: 4
batch_size_valid: 20

trainer:
  name: 'SurfaceFormer'
  version: 'ours'
  lr: 1.0e-4
  num_gpus: [0]

model:
  num_lines: 216
  max_num_faces: 42
  max_face_length: 37
  token:
    PAD: 0
    face_type_offset: 1
    len: 4
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
preferred-citation:
  type: conference-paper
  collection-type: proceedings
  title: "Neural Face Identification in a 2D Wireframe Projection of a Manifold Object"
  authors:
    - family-names: Wang
      given-names: Kehan
    - family-names: Zheng
      given-names: Jia
    - family-names: Zhou
      given-names: Zihan
  collection-title: "Proceedings of IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)"
  start: 1622
  end: 1631
  year: 2022
--------------------------------------------------------------------------------
/configs/ours-perspective.yml:
--------------------------------------------------------------------------------
# Each face as a seq in perspective view

# dataset generation method:
# python dataset/prepare_data.py ... --random_camera

model_class: 'SurfaceFormer_Parallel'
dataset_class: 'ABCDataset_Parallel'
root_dir: "test_set/ours-perspective"

batch_size_train: 4
batch_size_valid: 20

trainer:
  name: 'SurfaceFormer'
  version: 'ours-perspective'
  lr: 1.0e-4
  num_gpus: [0]

model:
  num_lines: 202
  max_num_faces: 42
  max_face_length: 38
  token:
    PAD: 0
    face_type_offset: 1
    len: 4
--------------------------------------------------------------------------------
/configs/ours-fixed_viewpoint.yml:
--------------------------------------------------------------------------------
# Each face as a seq with fixed viewpoint

# dataset generation method:
# python dataset/prepare_data.py ... --focus 0

model_class: 'SurfaceFormer_Parallel'
dataset_class: 'ABCDataset_Parallel'
root_dir: "test_set/ours-fixed_viewpoint"

batch_size_train: 4
batch_size_valid: 20

trainer:
  name: 'SurfaceFormer'
  version: 'ours-fixed_viewpoint'
  lr: 1.0e-4
  num_gpus: [0]

model:
  num_lines: 186
  max_num_faces: 42
  max_face_length: 33
  token:
    PAD: 0
    face_type_offset: 1
    len: 4
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: faceformer
channels:
  - pytorch
  - conda-forge
  - anaconda
  - defaults
dependencies:
  - cudatoolkit=11.0.221
  - jpeg=9b
  - json5=0.9.6
  - pillow=8.2.0
  - python=3.7.10
  - pythonocc-core=7.4.1
  - pythreejs=2.3.0
  - pytorch=1.7.1
  - pip=21.1.3
  - pip:
    - async-timeout==3.0.1
    - cairosvg==2.5.2
    - cvxpy==1.1.17
    - easydict==1.9
    - h5py==3.1.0
    - html4vision==0.4.3
    - matplotlib==3.4.2
    - numpy==1.19.5
    - numpyencoder==0.3.0
    - open3d==0.13.0
    - opencv-python==4.5.2.54
    - pytorch-lightning==1.3.5
    - pyyaml==5.4.1
    - scikit-learn==0.24.2
    - scipy==1.6.3
    - svgwrite==1.4.1
    - timeout-decorator==0.5.0
    - torchmetrics==0.3.2
    - tqdm==4.61.1
    - trimesh==3.9.20
    - fvcore==0.1.5.post20210617
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Manycore Tech Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/split_jsons.py:
--------------------------------------------------------------------------------
import argparse
import os
import shutil

import numpy as np


def prepare_splits(args):
    names = []
    os.makedirs(os.path.join(args.root, 'json'), exist_ok=True)
    for name in sorted(os.listdir(args.root)):
        # skip the freshly created 'json' directory and any non-json files
        if not name.endswith('.json'):
            continue
        names.append(name[:8])
        shutil.move(os.path.join(args.root, name), os.path.join(args.root, "json"))

    np.random.seed(args.seed)
    np.random.shuffle(names)
    train_ratio, valid_ratio, test_ratio = args.split
    trainlist, validlist, testlist = np.split(names, [
        int(len(names) * train_ratio),
        int(len(names) * (train_ratio + valid_ratio))])

    np.savetxt(os.path.join(args.root, 'train.txt'), trainlist, fmt="json/%s.json")
    np.savetxt(os.path.join(args.root, 'valid.txt'), validlist, fmt="json/%s.json")
    np.savetxt(os.path.join(args.root, 'test.txt'), testlist, fmt="json/%s.json")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, default="./ours",
                        help='dataset root; all files under root are .json files.')
    parser.add_argument('--seed', type=int, default=42,
                        help='numpy random seed')
    parser.add_argument('--split', nargs="+", type=float,
                        default=[0.93, 0.02, 0.05],
                        help='train/valid/test split ratio')

    args = parser.parse_args()

    prepare_splits(args)
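# Usage sketch (hedged): run from the directory that contains the unzipped JSONs,
#   python split_jsons.py --root ./ours --seed 42 --split 0.93 0.02 0.05
# which moves all JSONs into ./ours/json and writes train/valid/test .txt split files.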
--------------------------------------------------------------------------------
/dataset/filters/filter_length.py:
--------------------------------------------------------------------------------
import argparse
import json
import os

from tqdm import tqdm


def main(args):
    if args.clean_start:
        with open("dataset/dataset_gen_logs/filtered_id_list.json", 'r') as f:
            names = json.load(f)
    else:
        names = [os.path.splitext(name)[0] for name in os.listdir(os.path.join(args.root, 'json'))]

    filtered_names = []

    for name in tqdm(names):
        path = os.path.join(args.root, 'json', f'{name}.json')
        with open(path, 'r') as f:
            data = json.load(f)
        # label length: each face contributes its edge indices plus one separator,
        # plus one final start/end token
        total_len = 0
        for face in data["faces_indices"]:
            total_len += 1 + len(face)
        total_len += 1
        if total_len < args.face_seq_max and len(data['edges']) < args.num_edge_max:
            filtered_names.append(name)

    with open("dataset/dataset_gen_logs/filtered_id_list.json", 'w') as f:
        json.dump(filtered_names, f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, default='/home/tianhan/data',
                        help='dataset root')
    parser.add_argument('--face_seq_max', type=int, default=128,
                        help='max length for the constructed face label')
    parser.add_argument('--num_edge_max', type=int, default=64,
                        help='max number of edges in a shape')
    parser.add_argument('--clean_start', action='store_true',
                        help='start from a clean id list')
    args = parser.parse_args()

    main(args)
--------------------------------------------------------------------------------
/faceformer/utils.py:
--------------------------------------------------------------------------------
import torch


def info_value_of_dtype(dtype: torch.dtype):
    """
    Returns the `finfo` or `iinfo` object of a given PyTorch data type. Does not allow torch.bool.
    """
    if dtype == torch.bool:
        raise TypeError("Does not support torch.bool")
    elif dtype.is_floating_point:
        return torch.finfo(dtype)
    else:
        return torch.iinfo(dtype)


def min_value_of_dtype(dtype: torch.dtype):
    """
    Returns the minimum value of a given PyTorch data type. Does not allow torch.bool.
    """
    return info_value_of_dtype(dtype).min


def max_value_of_dtype(dtype: torch.dtype):
    """
    Returns the maximum value of a given PyTorch data type. Does not allow torch.bool.
    """
    return info_value_of_dtype(dtype).max


def tiny_value_of_dtype(dtype: torch.dtype):
    """
    Returns a moderately tiny value for a given PyTorch data type that is used to avoid numerical
    issues such as division by zero.
    This is different from `info_value_of_dtype(dtype).tiny` because that value can itself cause
    NaN bugs.
    Only supports floating point dtypes.
    """
    if not dtype.is_floating_point:
        raise TypeError("Only supports floating point dtypes.")
    if dtype == torch.float or dtype == torch.double:
        return 1e-13
    elif dtype == torch.half:
        return 1e-4
    else:
        raise TypeError("Does not support dtype " + str(dtype))


def flatten_list(l):
    """
    Flattens a list of lists.
    """
    return [item for sublist in l for item in sublist]
--------------------------------------------------------------------------------
/dataset/reorganize_dataset_dirs.py:
--------------------------------------------------------------------------------
import argparse
import os
import time

from tqdm import tqdm


def main(args):
    for name in tqdm(sorted(os.listdir(os.path.join(args.root, args.subdir)))):
        dirpath = os.path.join(args.root, args.subdir, name)
        # rename file name to 8 digits
        if not os.path.isdir(dirpath):
            index_name, suffix = os.path.splitext(name)
            if len(index_name) != 8:
                srcpath = os.path.join(args.root, args.subdir, name)
                dstpath = os.path.join(args.root, args.subdir, index_name[:8] + suffix)
                os.rename(srcpath, dstpath)
            continue
        # move the file out of its individual folder
        filenames = os.listdir(dirpath)
        if len(filenames) == 0:
            os.rmdir(dirpath)
            continue
        filename = filenames[0]
        suffix = os.path.splitext(filename)[1]
        srcpath = os.path.join(args.root, args.subdir, name, filename)
        dstpath = os.path.join(args.root, args.subdir, name + suffix)
        dirpath = os.path.join(args.root, args.subdir, name)
        os.rename(srcpath, dstpath)
        os.rmdir(dirpath)
    with open('data_processing_log.txt', 'a') as f:
        f.write(f"Reorganized folder {args.subdir} to proper structure - " + time.ctime() + '\n')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, default="./data",
                        help='dataset root.')
    parser.add_argument('--subdir', type=str, default="step",
                        help='dataset sub-directory to be reorganized.')
    args = parser.parse_args()

    main(args)
--------------------------------------------------------------------------------
/faceformer/post_processing.py:
--------------------------------------------------------------------------------
import numpy as np

from dataset.tests.check_faces_enclosed import is_face_enclosed
from faceformer.utils import flatten_list


# For each face, if it is enclosed, sort its loops
def filter_faces_by_encloseness(edges, faces, tol):
    # find corresponding edges from face indices
    filtered_faces = []
    for face_type, face in faces:
        all_face_loops = is_face_enclosed(edges, face, tol)
        if all_face_loops:
            # roll enclosed loops so the smallest index is at the front
            all_face_loops = [tuple(np.roll(loop, -np.argmin(loop), axis=0).astype(int).tolist())
                              for loop in all_face_loops]
            # loops are ordered by first index
            all_face_loops = sorted(all_face_loops, key=lambda x: x[0])
            filtered_faces.append((face_type, tuple(all_face_loops)))

    return filtered_faces


# Two coedges that represent the same edge should not be used in the same face
def filter_faces_by_coedge(pairings, faces):
    filtered_faces = []
    used_indices = set()
    for face in faces:
        indices = flatten_list(face[1])
        drop_face = False
        for index in indices:
            if index in pairings:
                index = pairings[index]
            if index in used_indices:
                drop_face = True
                break
            used_indices.add(index)
        if not drop_face:
            filtered_faces.append(face)

    return filtered_faces


def map_coedge_into_edges(pairings, indices):
    new_indices = []
    for i in indices:
        if str(i) in pairings:
            new_indices.append(pairings[str(i)])
        else:
            new_indices.append(i)
    return new_indices
--------------------------------------------------------------------------------
/dataset/utils/read_step_file.py:
--------------------------------------------------------------------------------
import os

import timeout_decorator
from OCC.Core.IFSelect import IFSelect_ItemsByEntity, IFSelect_RetDone
from OCC.Core.STEPControl import STEPControl_Reader
from OCC.Extend.TopologyUtils import list_of_shapes_to_compound


@timeout_decorator.timeout(5, use_signals=False)
def read_step_file(filename, as_compound=True, verbosity=True, filter_num_shape=10):
    """ read the STEP file and returns a compound and number of shapes
    filename: the file path
    verbosity: optional, True by default.
    as_compound: True by default. If there is more than one shape at root,
    gather all shapes into one compound. Otherwise returns a list of shapes.
    """
    if not os.path.isfile(filename):
        raise FileNotFoundError("%s not found." % filename)

    step_reader = STEPControl_Reader()
    status = step_reader.ReadFile(filename)

    if status == IFSelect_RetDone:  # check status
        if verbosity:
            failsonly = False
            step_reader.PrintCheckLoad(failsonly, IFSelect_ItemsByEntity)
            step_reader.PrintCheckTransfer(failsonly, IFSelect_ItemsByEntity)
        transfer_result = step_reader.TransferRoots()
        if not transfer_result:
            raise AssertionError("Transfer failed.")
        _nbs = step_reader.NbShapes()
        if _nbs == 0:
            raise AssertionError("No shape to transfer.")
        elif _nbs == 1:  # most cases
            return step_reader.Shape(1), _nbs
        elif _nbs > 1:
            if _nbs > filter_num_shape:
                return None, _nbs
            shps = []
            # loop over root shapes
            for k in range(1, _nbs + 1):
                new_shp = step_reader.Shape(k)
                if not new_shp.IsNull():
                    shps.append(new_shp)
            if as_compound:
                compound, result = list_of_shapes_to_compound(shps)
                if not result:
                    print("Warning: not all shapes were added to the compound")
                return compound, _nbs
            else:
                print("Warning: returning a list of shapes.")
                return shps, _nbs
    else:
        raise AssertionError("Error: can't read file.")
    return None
--------------------------------------------------------------------------------
/faceformer/config.py:
--------------------------------------------------------------------------------
import argparse

from fvcore.common.config import CfgNode

CN = CfgNode

_C = CN()
_C.model_class = 'SurfaceFormer'
_C.dataset_class = 'ABCDataset'
_C.root_dir = "/root/data"

_C.batch_size_train = 64
_C.batch_size_valid = 128
_C.datasets_train = ['train.txt']
_C.datasets_valid = ['valid.txt']
_C.datasets_test = ['test.txt']

_C.trainer = CN()
_C.trainer.name = "surfaceformer"
_C.trainer.version = "baseline"
_C.trainer.num_gpus = [0]
_C.trainer.precision = 16  # 16-bit training
_C.trainer.checkpoint_period = 2
_C.trainer.lr = 1e-3
_C.trainer.lr_step = 0

_C.model = CN()
_C.model.num_points_per_line = 50
_C.model.num_lines = 64
_C.model.point_dim = 2
_C.model.label_seq_length = 128
_C.model.max_num_faces = 42
_C.model.max_face_length = 34
_C.model.num_model = 512
_C.model.num_head = 8
_C.model.num_feedforward = 1024
_C.model.num_encoder_layers = 6
_C.model.num_decoder_layers = 6
_C.model.dropout = 0.2
_C.model.token = CN()
_C.model.token.PAD = 0
_C.model.token.SOS = 1
_C.model.token.SEP = 2
_C.model.token.EOS = 3
_C.model.token.DIR0 = 4
_C.model.token.DIR1 = 5
_C.model.token.len = 4
_C.model.token.face_type_offset = 1

_C.post_process = CN()
_C.post_process.enclosedness_tol = 2e-4
_C.post_process.is_coedge = True


def get_parser():
    parser = argparse.ArgumentParser(description="SurfaceFormer Training")
    parser.add_argument("--config-file", default="", metavar="FILE",
                        help="path to config file")
    parser.add_argument("--valid_ckpt", default="",
                        help="path to validation checkpoint")
    parser.add_argument("--test_ckpt", default="",
                        help="path to testing checkpoint")
    parser.add_argument("--resume_ckpt", default="",
                        help="path to training checkpoint; training will continue from here")
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser


def get_cfg(args):
    cfg = _C.clone()
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    return cfg
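# Usage sketch (hedged): the trailing positional `opts` are merged by fvcore's
# CfgNode.merge_from_list, which consumes KEY VALUE pairs, e.g.
#   python main.py --config-file configs/ours.yml trainer.lr 5.0e-5 model.dropout 0.1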
--------------------------------------------------------------------------------
/dataset/utils/Edge.py:
--------------------------------------------------------------------------------
import numpy as np


class Edge:
    '''
    Edge is unique by the edge's hash.
    Each edge should have two faces.
    '''

    def __init__(self, edge, faces=None, orientations=None, dedge=None, index=None,
                 DiscretizedEdge=None, dedge3d=None):
        self.edge = edge
        self.edges = [edge]
        # use None defaults to avoid sharing one mutable list across instances
        self.faces = faces if faces is not None else []
        self.orientations = orientations if orientations is not None else []
        self.dedge = dedge
        self.dedge3d = dedge3d
        self.index = index  # index among all edges in TopoMapper, for constructing faces
        self.DiscretizedEdge = DiscretizedEdge

    def add_face(self, face, orientation):
        self.faces.append(face)
        self.orientations.append(orientation)
        assert len(self.faces) <= 2, "Too many faces for one edge"

    def get_oriented_dedge(self, orientation, is_3d=False):
        '''
        The discretized edge is saved in normal orientation;
        return it reversed when the orientation is reversed.
        '''
        if is_3d:
            return self.dedge3d[::-1] if orientation else self.dedge3d
        return self.dedge[::-1] if orientation else self.dedge

    def __hash__(self):
        return hash(self.edge)

    def __eq__(self, other):
        return isinstance(other, Edge) and hash(self) == hash(other)

    def same_orientation(self, other):
        dist1 = np.sum(abs(np.array(self.dedge[-1]) - np.array(other.dedge[0])))
        dist2 = np.sum(abs(np.array(other.dedge[-1]) - np.array(self.dedge[0])))
        return dist1 < dist2

    def merge(self, other, topo):
        '''
        Merge two edges, considering the faces being merged.
        self.edge is not set to None so the hash of the edge is still available.
        Assumes the orientation of dedge is the same.
        '''
        assert isinstance(other, Edge), 'Cannot merge edge with non-edge'
        # check orientation by looking at the start and end of the two edges
        if self.same_orientation(other):
            self.dedge = self.dedge + other.dedge
            self.edges = self.edges + other.edges
        else:
            self.dedge = other.dedge + self.dedge
            self.edges = other.edges + self.edges

        # remove other from its faces
        for face in other.faces:
            i = face.keys.index(hash(other.edge))
            del face.edges[i]
            del face.edge_orientations[i]
            del face.keys[i]

        # remove other edge from topo
        del topo.all_edges[hash(other.edge)]
        return self
--------------------------------------------------------------------------------
/dataset/utils/discretize_edge.py:
--------------------------------------------------------------------------------
from functools import cmp_to_key

import numpy as np


class DiscretizedEdge:
    def __init__(self, points, smaller_edge=None, edge3d=None):
        self.points = points
        self.index = None
        self.smaller_edge = smaller_edge
        self.edge3d = edge3d

    def __eq__(self, obj):
        return isinstance(obj, DiscretizedEdge) and obj.points == self.points

    def correct_edge_direction(self, tolerance=1e-10):
        """
        Given a discretized edge,
        point the edge in the direction from the smaller coordinate to the larger coordinate.
        """
        if self.is_enclosed(tolerance):
            self.sort_enclosing_edge()
        else:
            if comp_points(self.points[0], self.points[-1]) > 0:
                # reverse edge
                self.points = list(reversed(self.points))

    # check for enclosed polyline with tolerance
    def is_enclosed(self, tolerance):
        return abs(self.points[0][0] - self.points[-1][0]) < tolerance and \
            abs(self.points[0][1] - self.points[-1][1]) < tolerance

    # rotate points for an enclosing edge
    def sort_enclosing_edge(self):
        # take out the repeating start/end
        enclosing_edge = self.points[1:]

        # find smallest starting point
        edge_array = np.array(enclosing_edge)
        d_edge = np.roll(
            edge_array, -np.argmin(edge_array[:, 0]), axis=0).tolist()

        # sort direction clockwise by y-axis
        if d_edge[1][1] > d_edge[-1][1]:
            d_edge.append(d_edge[0])
        else:
            d_edge = [d_edge[0]] + list(reversed(d_edge))

        self.points = d_edge


# rank coordinates first by x, then by y
def comp_points(p1, p2):
    if p1[0] == p2[0]:
        return p1[1] - p2[1]
    return p1[0] - p2[0]


# rank edges in sequence of the points,
# assuming the edges themselves are sorted
def comp_edges(e1, e2):
    e1, e2 = e1.points, e2.points
    N = min(len(e1), len(e2))
    for i in range(N):
        diff = comp_points(e1[i], e2[i])
        if diff == 0:
            continue
        return diff
    return 0


def sort_edges_by_coordinate(edges):
    return sorted(edges, key=cmp_to_key(comp_edges))


def comp_face_by_index(f1, f2):
    N = min(len(f1), len(f2))
    for i in range(N):
        diff = f1[i] - f2[i]
        if diff == 0:
            continue
        return diff
    return 0


def sort_faces_by_indices(faces):
    return sorted(faces, key=cmp_to_key(comp_face_by_index))
--------------------------------------------------------------------------------
/dataset/filters/filter_thinness.py:
--------------------------------------------------------------------------------
import argparse
import json
import os
from functools import partial

import numpy as np
import trimesh
import yaml
from tqdm.contrib.concurrent import process_map


def scale_to_unit_sphere(mesh):
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.dump().sum()

    vertices = mesh.vertices - mesh.bounding_box.centroid
    vertices *= 2 / np.linalg.norm(mesh.bounding_box.extents)

    return trimesh.Trimesh(vertices=vertices, faces=mesh.faces, process=False, maintain_order=True)


def filter_by_radius(name, args):
    mesh_path = os.path.join(args.root, 'obj', f'{name}.obj')
    mesh = trimesh.load_mesh(mesh_path, process=False, maintain_order=True)

    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.dump().sum()

    scale = np.linalg.norm(mesh.bounding_box.extents)

    feat_path = os.path.join(args.root, 'feat', f'{name}.yml')
    with open(feat_path) as file:
        annos = yaml.full_load(file)

    radius_array = []

    for curve in annos['curves']:
        # find cases with thin cylinders
        if curve['type'] in ['Circle']:
            radius = curve['radius'] / scale

        elif curve['type'] in ['Ellipse']:
            radius = min(curve['maj_radius'], curve['min_radius']) / scale

        else:
            continue

        radius_array.append(radius)

    if len(radius_array) != 0:
        with open(os.path.join(args.root, 'radius', f'{name}.json'), 'w') as f:
            json.dump(min(radius_array), f)

    return name


def main(args):
    with open(os.path.join(args.root, "meta", "filtered_thickness.json"), 'r') as f:
        names = json.load(f)

    os.makedirs(os.path.join(args.root, 'radius'), exist_ok=True)

    # preprocess
    rets = process_map(
        partial(filter_by_radius, args=args), names,
        max_workers=args.num_cores, chunksize=args.num_chunks)

    filtered = [ret for ret in rets if ret is not None]

    with open(os.path.join(args.root, 'meta', 'filtered_thinness.json'), 'w') as f:
        json.dump(filtered, f)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, default='/root/Datasets/Faceformer',
                        help='dataset root')
    parser.add_argument('--threshold', type=float, default=0.05,
                        help='threshold for closer edge')
    parser.add_argument('--num_cores', type=int,
                        default=8, help='number of processors.')
    parser.add_argument('--num_chunks', type=int,
                        default=64, help='number of chunks.')
    args = parser.parse_args()

    main(args)
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import sys

import pytorch_lightning as pl
import torch

from faceformer.config import get_cfg, get_parser
from faceformer.datasets import *
from faceformer.models import *
from faceformer.trainer import Trainer


def str_to_class(classname):
    return getattr(sys.modules[__name__], classname)


class CudaClearCacheCallback(pl.Callback):
    def on_train_start(self, trainer, pl_module):
        torch.cuda.empty_cache()

    def on_validation_start(self, trainer, pl_module):
        torch.cuda.empty_cache()

    def on_validation_end(self, trainer, pl_module):
        torch.cuda.empty_cache()


if __name__ == "__main__":
    args = get_parser().parse_args()

    cfg = get_cfg(args)

    model_class = str_to_class(cfg.model_class)
    dataset_class = str_to_class(cfg.dataset_class)
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        save_last=True,
        filename='{epoch:d}-{valid_precision:.2f}',
        save_top_k=2,
        monitor='valid_precision',
        mode='max',
        every_n_val_epochs=1)

    logger = pl.loggers.TensorBoardLogger('logs/', name=cfg.trainer.name, version=cfg.trainer.version)

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(c) for c in cfg.trainer.num_gpus])
    gpus = list(range(len(cfg.trainer.num_gpus)))

    if args.test_ckpt != '':
        # Testing
        model = Trainer(cfg, model_class, dataset_class).load_from_checkpoint(
            args.test_ckpt, model_class=model_class, dataset_class=dataset_class)
        trainer = pl.Trainer(
            benchmark=True,
            gpus=gpus,
            precision=cfg.trainer.precision)
        trainer.test(model)
    elif args.valid_ckpt != '':
        # Validation
        model = Trainer(cfg, model_class, dataset_class).load_from_checkpoint(
            args.valid_ckpt, model_class=model_class, dataset_class=dataset_class)
        trainer = pl.Trainer(
            benchmark=True,
            gpus=gpus,
            precision=cfg.trainer.precision)
        trainer.validate(model)
    elif args.resume_ckpt != '':
        # Resume Training
        model = Trainer(cfg, model_class, dataset_class).load_from_checkpoint(
            args.resume_ckpt, model_class=model_class, dataset_class=dataset_class)
        trainer = pl.Trainer(
            logger=logger,
            benchmark=True,
            gpus=gpus,
            precision=cfg.trainer.precision,
            resume_from_checkpoint=args.resume_ckpt)
        trainer.fit(model)
    else:
        # Training from scratch
        model = Trainer(cfg, model_class, dataset_class)
        trainer = pl.Trainer(
            logger=logger,
            callbacks=[checkpoint_callback, CudaClearCacheCallback()],
            check_val_every_n_epoch=cfg.trainer.checkpoint_period,
            log_every_n_steps=1,
            benchmark=True,
            gpus=gpus,
            precision=cfg.trainer.precision)
        trainer.fit(model)
--------------------------------------------------------------------------------
/dataset/filters/filter_topology.py:
--------------------------------------------------------------------------------
1 | """
2 | Topology Filtering
3 | * Group all objects of similar topology in the same bin
4 | """
5 |
6 | import argparse
7 | import json
8 | import os
9 |
10 | import yaml
11 | from sklearn.neighbors import NearestNeighbors
12 | from tqdm import tqdm
13 |
14 | types_of_curves = {
15 | "Line": 0, "Circle": 1, "Ellipse": 2, "BSpline": 3, "Other": 4}
16 | types_of_surfs = {
17 | "Plane": 0, "Cylinder": 1, "Cone": 2, "Sphere": 3, "Torus": 4,
18 | "Revolution": 5, "Extrusion": 6, "BSpline": 7, "Other": 8}
19 |
20 |
21 | def main(args):
22 | # all step files
23 | names = []
24 | for name in sorted(os.listdir(os.path.join(args.root, 'stat'))):
25 | names.append(name[:8])
26 |
27 | if os.path.exists(args.error_log):
28 | # remove shapes that give errors
29 | with open(args.error_log, 'r') as f:
30 | lines = f.read().splitlines()
31 |
32 | error_names = [line[:8] for line in lines if line[:8].isdigit()]
33 | id_list = [name for name in names if name not in set(error_names)]
34 | else:
35 | print("error log not found")
36 | id_list = names
37 |
38 | # gather topology info for each object as their feature
39 | names = []
40 | features = []
41 | for name in tqdm(id_list):
42 | names.append(name)
43 | path = os.path.join(args.root, 'stat', f'{name}.yml')
44 | with open(path, 'r') as f:
45 | data = yaml.safe_load(f)
46 |
47 | curves = [types_of_curves[curve] for curve in data['curves']]
48 | surfs = [types_of_surfs[surf] for surf in data['surfs']]
49 |
50 | curves_hist = [0] * len(types_of_curves)
51 | for curve in curves:
52 | curves_hist[curve] += 1
53 | surfs_hist = [0] * len(types_of_surfs)
54 | for surf in surfs:
55 | surfs_hist[surf] += 1
56 |
57 | feature = [data['#edges'], data['#parts'], data['#sharp'],
58 | data['#surfs'], *curves_hist, *surfs_hist]
59 | features.append(feature)
60 |
61 | # Use nearest neighbors to find clusterings
62 | neigh = NearestNeighbors()
63 | neigh.fit(features)
64 | dist, indices = neigh.radius_neighbors(features, args.similarity_threshold)
65 | bins = set([tuple(ind) for ind in indices])
66 | list_names = []
67 | for b in bins:
68 | name_bin = [names[i] for i in b]
69 | list_names.append(name_bin)
70 | with open("dataset/dataset_gen_logs/topo_matching_bins.json", 'w') as f:
71 | json.dump(list_names, f)
72 |
73 |
74 | if __name__ == '__main__':
75 | parser = argparse.ArgumentParser()
76 | parser.add_argument('--root', type=str, default="./data",
77 | help='dataset root.')
78 | parser.add_argument('--error_log', type=str,
79 | default="dataset/dataset_gen_logs/error.txt",
80 | help='dataset generation error log.')
81 | parser.add_argument('--similarity_threshold', type=float,
82 | default=0,
83 | help="grouping threhold for similarity")
84 | args = parser.parse_args()
85 |
86 | main(args)
87 |
--------------------------------------------------------------------------------
/dataset/README.md:
--------------------------------------------------------------------------------
## Our Dataset

### Annotation format

We save the annotation in JSON format, including 2D edge points and face loops given by edge indices.

```json
{
  "edges": [
    [...], # edge 1
    [...], # edge 2
    ...
  ],
  "faces_indices": [
    [...], # face 1
    [...], # face 2
    ...
  ]
}
```
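
Below is a minimal sketch of reading one annotation file (the 8-digit id is a placeholder; depending on the generation options, each entry of `faces_indices` is either a flat list of edge indices or a list of loops, as handled in `faceformer/datasets/data.py`):

```python
import json

# hypothetical example id; any annotation under root/json works
with open("root/json/00000050.json") as f:
    anno = json.load(f)

# each edge is a polyline: a list of 2D points
print(len(anno["edges"]), "edges")

# each face references its boundary edges by index
for i, face in enumerate(anno["faces_indices"]):
    print(f"face {i}: {face}")
```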

## Prepare Dataset

Here, we provide tools to filter and parse data in the [ABC dataset](https://archive.nyu.edu/handle/2451/43778). Please download `step`, `stat`, `obj`, and `feat`.

## Reorganize ABC dataset directory

Remove the middle-level folders after unzipping the ABC dataset.

```bash
python reorganize_dataset_dirs.py --root $ABC_ROOT_DIR
```

```bash
# Original ABC Dataset Structure
root
└── step
    └── 00000050
        └── 00000050.step
# Reorganized ABC Dataset Structure
root
└── step
    └── 00000050.step
```

### Dataset Directory Structure

```
root
├── step
│   └── 00000050.step
├── json
│   └── 00000050.json
├── face_png
│   └── 00000050_{face_index}.png
├── face_svg
│   └── 00000050_{face_index}.svg
├── png
│   └── 00000050.png
└── svg
    └── 00000050.svg
```

## Command Lines

#### Data Generation

In each model's [config](configs), we detail the specific options needed to generate the dataset in the correct format.

```bash
# parse the entire ABC dataset
python dataset/prepare_data.py --root $ABC_ROOT_DIR --id_list dataset/dataset_gen_logs/filtered_id_list.json > dataset/dataset_gen_logs/error.txt
# parse a specific object (for debugging a single datum)
python dataset/prepare_data.py --root $ABC_ROOT_DIR --name $8_DIGIT_ID
```

#### Dataset Filtering

Filter ABC objects by similarity:
```bash
# 1. filter by topology similarity
python dataset/filters/filter_topology.py --root $ABC_ROOT_DIR
# 2. render the three views of the entire ABC dataset
python dataset/filters/3view_render.py --root $ABC_ROOT_DIR --id_list dataset/dataset_gen_logs/filtered_id_list.json > dataset/dataset_gen_logs/3view_error.txt
# 3. filter by three-view similarity
python dataset/filters/filter_3view.py --root $ABC_ROOT_DIR
```

Filter ABC objects by thickness:
```bash
python dataset/filters/filter_thickness.py --root $ABC_ROOT_DIR --save_root $DIR_FOR_TEMP_DATA
```

Filter ABC objects by complexity:
```bash
# By default, $MAX_FACE_SEQ = 128, $MAX_NUM_EDGE = 64
python dataset/filters/filter_length.py --root $ABC_ROOT_DIR --face_seq_max $MAX_FACE_SEQ --num_edge_max $MAX_NUM_EDGE
```

Filter generated co-edge data by face encloseness:
```bash
# Assumes prepare_data.py has finished and all json files have been generated
python dataset/tests/check_faces_enclosed.py --root $ABC_ROOT_DIR --tol 1e-4 --remove
# Regenerate train/valid/test splits
python dataset/prepare_data.py --root $ABC_ROOT_DIR --only_split
```
--------------------------------------------------------------------------------
/dataset/tests/check_faces_enclosed.py:
--------------------------------------------------------------------------------
'''
Dataset Generation Integrity Test
- Check if each face is enclosed
'''

import argparse
import json
import os
from functools import partial

from tqdm.contrib.concurrent import process_map


# check if e1's end meets e2's start
def e1_connects_e2(e1, e2, tol):
    return abs(e1[-1][0] - e2[0][0]) < tol and \
        abs(e1[-1][1] - e2[0][1]) < tol


# a face can be composed of multiple enclosed loops;
# check if all oriented edges form loops and
# return all loops as a list of lists if the face is enclosed
# (e.g. a square face [0, 1, 2, 3] whose edges chain head-to-tail yields [[0, 1, 2, 3]])
def is_face_enclosed(edges, face_indices, tol):
    all_loops = []
    curr_loop = []
    to_close = None  # the start of an enclosed cycle, to be closed
    last_edge = None
    for ind in face_indices:
        if isinstance(ind, tuple):
            i, o = ind
            edge = edges[i][::-1] if o else edges[i]
        else:
            if ind < len(edges):
                edge = edges[ind]
            else:
                continue
        if to_close is None:
            to_close = edge
        else:
            # make sure the current edge connects to the last edge
            if not e1_connects_e2(last_edge, edge, tol):
                return False

        last_edge = edge
        curr_loop.append(ind)
        if e1_connects_e2(edge, to_close, tol):
            # close the current cycle
            to_close = None
            all_loops.append(curr_loop)
            curr_loop = []
    return all_loops if to_close is None else False


def check_enclosed(name, args):
    path = os.path.join(args.root, 'json', f'{name}.json')
    with open(path, 'r') as f:
        data = json.load(f)
    edges = data['edges']
    faces_indices = data['faces_indices']

    for face_indices in faces_indices:
        if not is_face_enclosed(edges, face_indices, args.tol):
            if args.remove:
                # remove json from dataset
                os.remove(path)
            print(f"{name} contains unclosed face")
            return


def main(args):
    names = []
    for name in sorted(os.listdir(os.path.join(args.root, 'json'))):
        names.append(name[:8])

    process_map(
        partial(check_enclosed, args=args), names,
        max_workers=args.num_cores, chunksize=args.num_chunks
    )


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, default="./data",
                        help='dataset root.')
    parser.add_argument('--name', type=str, default=None,
                        help='filename.')
    # default to 3e-4 since the discretization tolerance is 1e-4
    parser.add_argument('--tol', type=float,
                        default=3e-4, help='same point tolerance.')
    parser.add_argument('--num_cores', type=int,
                        default=40, help='number of processors.')
    parser.add_argument('--num_chunks', type=int,
                        default=10, help='number of chunks.')
    parser.add_argument('--remove', action='store_true')

    args = parser.parse_args()

    if args.name is None:
        main(args)
    else:
        check_enclosed(args.name, args)
--------------------------------------------------------------------------------
/dataset/utils/projection_utils.py:
--------------------------------------------------------------------------------
import numpy as np
from OCC.Core.gp import gp_Ax2, gp_Dir, gp_Pnt
from OCC.Core.HLRAlgo import HLRAlgo_Projector
from OCC.Core.HLRBRep import HLRBRep_Algo, HLRBRep_HLRToShape
from OCC.Extend.TopologyUtils import TopologyExplorer, discretize_edge


def randnum(low, high):
    return np.random.rand() * (high - low) + low


# generate a random camera
def generate_random_camera_pos(seed):
    np.random.seed(seed)
    focus = randnum(3, 5)
    radius = randnum(1.25, 1.5)  # distance of camera to origin
    phi = randnum(22.5, 67.5)    # latitude, elevation of camera
    theta = randnum(0, 360)      # longitude, rotation around z-axis
    return focus, pose_spherical(theta, phi, radius)


def pose_spherical(theta, phi, radius):
    def trans_t(t): return np.array([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],
        [0, 0, 0, 1],
    ], dtype=np.float32)

    def rot_phi(phi): return np.array([
        [1, 0, 0, 0],
        [0, np.cos(phi), -np.sin(phi), 0],
        [0, np.sin(phi), np.cos(phi), 0],
        [0, 0, 0, 1],
    ], dtype=np.float32)

    def rot_theta(th): return np.array([
        [np.cos(th), -np.sin(th), 0, 0],
        [np.sin(th), np.cos(th), 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ], dtype=np.float32)

    c2w = trans_t(radius)
    c2w = rot_phi(np.deg2rad(phi)) @ c2w
    c2w = rot_theta(np.deg2rad(theta)) @ c2w
    c2w = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) @ c2w
    return c2w


def project_shapes(shapes, args):
    location = args.location
    direction = args.direction
    focus = args.focus

    hlr = HLRBRep_Algo()

    if isinstance(shapes, list):
        for shape in shapes:
            hlr.Add(shape)
    else:
        hlr.Add(shapes)
    ax = gp_Ax2(gp_Pnt(*location), gp_Dir(*direction))

    if args.pose is not None:
        pose = args.pose
        ax = gp_Ax2(gp_Pnt(*pose[:3, -1]), gp_Dir(*pose[:3, -2]), gp_Dir(*pose[:3, 0]))

    if focus == 0:
        # focus == 0 selects an orthographic projection; otherwise perspective
        projector = HLRAlgo_Projector(ax)
    else:
        projector = HLRAlgo_Projector(ax, focus)

    hlr.Projector(projector)
    hlr.Update()

    hlr_shapes = HLRBRep_HLRToShape(hlr)
    return hlr_shapes


def discretize_compound(compound, tol):
    """
    Given a compound of edges,
    return all edges discretized.
    """
    return [d3_to_d2(discretize_edge(edge, tol)) for edge in list(TopologyExplorer(compound).edges())]


def d3_to_d2(points_3d):
    return [tuple(p[:2]) for p in points_3d]


# project a list of 3D points
def project_points(points, args):
    location = args.location
    direction = args.direction
    focus = args.focus

    ax = gp_Ax2(gp_Pnt(*location), gp_Dir(*direction))

    if args.pose is not None:
        pose = args.pose
        ax = gp_Ax2(gp_Pnt(*pose[:3, -1]), gp_Dir(*pose[:3, -2]), gp_Dir(*pose[:3, 0]))

    if focus == 0:
        projector = HLRAlgo_Projector(ax)
    else:
        projector = HLRAlgo_Projector(ax, focus)

    projected = [projector.Project(gp_Pnt(*p)) for p in points]

    return projected
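
# Usage sketch (hedged; prepare_data.py is assumed to pass the camera through `args`):
#   focus, pose = generate_random_camera_pos(seed=0)
#   args.focus, args.pose = focus, pose
#   hlr_shapes = project_shapes(shapes, args)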
--------------------------------------------------------------------------------
/dataset/filters/filter_3view.py:
--------------------------------------------------------------------------------
1 | """
2 | Three view Filtering
3 | * Further divide each bin after topology filtering
4 | * By the similarity in three-view line drawings
5 | """
6 |
7 | import argparse
8 | import json
9 | import os
10 |
11 | import cv2
12 | import numpy as np
13 | from sklearn.cluster import AgglomerativeClustering
14 | from sklearn.metrics import pairwise_distances
15 | from tqdm import tqdm
16 |
17 |
18 | def main(args):
19 | # topology bins
20 | with open("dataset/dataset_gen_logs/topo_matching_bins.json", 'r') as f:
21 | list_names = json.load(f)
22 |
23 | # bins of more than 1 objects
24 | multi_bins = [b for b in list_names if len(b) > 1]
25 |
26 | # 3view error list
27 | with open(args.error_log, 'r') as f:
28 | lines = f.read().splitlines()
29 | error_names = [line[:8] for line in lines if line[:8].isdigit()]
30 | error_names = set(error_names)
31 |
32 | # remove error objects from topology bins
33 | filtered_multi_bins = []
34 | new_bins = []
35 | for b in multi_bins:
36 | new_bin = [name for name in b if name not in error_names]
37 | if len(new_bin) == 0:
38 | continue
39 | if len(new_bin) == 1:
40 | new_bins.append(new_bin)
41 | else:
42 | filtered_multi_bins.append(new_bin)
43 |
44 | # cluster by jaccard distance in three views
45 | for large_bin in tqdm(filtered_multi_bins):
46 | all_bin_imgs = []
47 | for name in large_bin:
48 | feature = []
49 | for i in range(1, 4):
50 | img_path = os.path.join(
51 | args.root, '3view_png', f'{name}-{i}.png')
52 | originalImage = cv2.imread(img_path)
53 | if originalImage is None:
54 | feature.append(np.ones(128*128))
55 | continue
56 | halfImage = cv2.resize(originalImage, (0, 0), fx=0.5, fy=0.5)
57 | grayImage = cv2.cvtColor(halfImage, cv2.COLOR_BGR2GRAY)
58 | thresh, bin_img = cv2.threshold(
59 | grayImage, 254, 255, cv2.THRESH_BINARY)
60 | feature.append(bin_img)
61 | all_bin_imgs.append(np.array(feature).flatten())
62 |
63 | X = np.array(all_bin_imgs) == 0
64 |
65 | dist_mat = pairwise_distances(X, metric='jaccard')
66 | clusters = AgglomerativeClustering(n_clusters=None, affinity='precomputed',
67 | distance_threshold=args.similarity_threshold, linkage='single').fit(dist_mat)
68 | classes = clusters.labels_
69 | new_bin = [[] for _ in range(max(classes)+1)]
70 | for name, c in zip(large_bin, classes):
71 | new_bin[c].append(name)
72 | new_bins += new_bin
73 |
74 | # add bins of single object back
75 | for b in list_names:
76 | if len(b) == 1:
77 | new_bins.append(b)
78 |
79 | # generate a list of valid, unique objects
80 | # always sample the smallest from the bin
81 | extracted_names = sorted([min(b, key=lambda s: int(s)) for b in new_bins])
82 |
83 | with open("dataset/dataset_gen_logs/filtered_id_list.json", 'w') as f:
84 | json.dump(extracted_names, f)
85 |
86 |
87 | if __name__ == '__main__':
88 | parser = argparse.ArgumentParser()
89 | parser.add_argument('--root', type=str, default="/root/data",
90 | help='dataset root.')
91 | parser.add_argument('--error_log', type=str,
92 | default="dataset/dataset_gen_logs/3view_error.txt",
93 | help='3 view rendering error log.')
94 | parser.add_argument('--similarity_threshold', type=float, default=0.1,
95 | help="grouping threhold for jaccard distance")
96 | args = parser.parse_args()
97 |
98 | main(args)
99 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Neural Face Identification in a 2D Wireframe Projection of a Manifold Object

IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) 2022

[arXiv](https://arxiv.org/abs/2203.04229)
[Paper](https://openaccess.thecvf.com/content/CVPR2022/html/Wang_Neural_Face_Identification_in_a_2D_Wireframe_Projection_of_a_CVPR_2022_paper.html)

![teaser](assets/teaser.gif)

## Requirements

```bash
conda env create --file environment.yml
conda activate faceformer
```

## Download Dataset

We use CAD mechanical models from the [ABC dataset](https://archive.nyu.edu/handle/2451/43778). In order to reproduce our results, we also release the dataset used in the paper [here](https://drive.google.com/drive/u/2/folders/1ynMD02E5FWlCPmQkWyjHdq4Zhe8DIXE2). If you would like to build the dataset by yourself, please refer to [dataset/README.md](dataset/README.md).

## Evaluation

### Face Identification Model

Trained models can be downloaded [here](https://drive.google.com/drive/u/2/folders/1oEoN_GzS36obLjvOlwFrOpWo0N7oh-fS).
```bash
python main.py --config-file configs/{MODEL_NAME}.yml --test_ckpt trained_models/{MODEL_NAME}.ckpt
```

Face predictions will be saved to `lightning_logs/version_{LATEST}/json`.

### 3D Reconstruction

```bash
# wireframe reconstruction
python reconstruction/reconstruct_to_wireframe.py --root lightning_logs/version_{LATEST}
# surface reconstruction
python reconstruction/reconstruct_to_mesh.py --root lightning_logs/version_{LATEST}
```

Reconstructed wireframe (*.ply*) or mesh (*.obj*) files will be saved to `lightning_logs/version_{LATEST}/{ply,obj}`.

## Train a Model from Scratch

```bash
python main.py --config-file configs/{MODEL_NAME}.yml
```

## FAQs

- *Why does root_dir not update when I change it in configs/ours.yml?*
  When pytorch_lightning loads a checkpoint, it also restores the old root_dir that the model was trained with.
  To fix this, uncomment line 25 of faceformer/trainer.py and set the desired root_dir there.

- *How should I use the downloaded json dataset?*
  Assuming we have downloaded *data_ours.tar.gz* [from the release](https://drive.google.com/drive/folders/1ynMD02E5FWlCPmQkWyjHdq4Zhe8DIXE2) and unzipped it next to `split_jsons.py` in the outer-most directory, we now have:

  ```
  root
  ├── main.py
  ├── split_jsons.py
  └── ours
      ├── 00000050.json
      ├── 00000052.json
      └── ...
  ```

  Run `python split_jsons.py` and it should prepare the dataset into the following:

  ```
  root
  ├── main.py
  ├── split_jsons.py
  └── ours
      ├── train.txt
      ├── valid.txt
      ├── test.txt
      └── json
          ├── 00000050.json
          ├── 00000052.json
          └── ...
  ```

  With this, set the root_dir to "ours" at line 25 of faceformer/trainer.py, and
  ```
  python main.py --config-file configs/ours.yml --test_ckpt trained_models/ours.ckpt
  ```
  should work.

## Acknowledgement

The work was done during Kehan Wang's internship at Manycore Tech Inc.
--------------------------------------------------------------------------------
/faceformer/embedding.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.nn as nn


class VanillaEmedding(nn.Module):
    def __init__(self, input_dim, num_model, token):
        super(VanillaEmedding, self).__init__()
        self.num_tokens = token.len

        # embedding for special tokens
        self.embedding_token = nn.Embedding(self.num_tokens, num_model)
        self.embedding_value = nn.Sequential(
            nn.Linear(input_dim, num_model),
            nn.ReLU(),
            nn.Linear(num_model, num_model)
        )

    def embed_points(self, lines):
        return lines.flatten(-2, -1)

    def forward(self, coord):
        """
        coord: N x L x P x D, N batch size, L num_edges/lines, P num_points, D num_axes
        E: num_model, model input dimension
        """
        N = coord.size(0)

        token = torch.arange(self.num_tokens, dtype=torch.long).to(coord.device)
        token_embed = self.embedding_token(token)
        token_embed = token_embed.unsqueeze(0).expand(N, self.num_tokens, -1)  # N x 4 x E

        coord_embed = self.embedding_value(self.embed_points(coord))  # N x L x E

        value_embed = torch.cat((token_embed, coord_embed), dim=1)  # N x (4+L) x E

        return value_embed


class CoordinateEmbedding(nn.Module):
    def __init__(self, num_axes, num_bits, num_embed, num_model, dependent_embed=False):
        super(CoordinateEmbedding, self).__init__()

        ntoken = 2**num_bits if dependent_embed else 2**num_bits * num_axes

        # embedding
        self.embedding_token = nn.Embedding(3, num_model)
        self.embedding_value = nn.Embedding(ntoken, num_embed)
        self.linear_proj = nn.Linear(num_axes*num_embed, num_model, bias=False)

    def forward(self, coord):
        N, S, _ = coord.shape

        # N x S x E
        token = torch.arange(3, dtype=torch.long).to(coord.device)
        token = token.unsqueeze(0).expand(N, -1)

        token_embed = self.embedding_token(token)
        coord_embed = self.embedding_value(coord)
        coord_embed = self.linear_proj(coord_embed.view(N, S, -1))

        value_embed = torch.cat((token_embed, coord_embed), dim=1)

        return value_embed


class PositionalEncoding(nn.Module):
    """
    The standard sinusoidal position embedding
    used by the "Attention Is All You Need" paper.
    """

    def __init__(self, num_model, max_len=5000):
        super(PositionalEncoding, self).__init__()

        pe = torch.zeros(max_len, num_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(
            0, num_model, 2).float() * (-math.log(10000.0) / num_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return self.pe[:, :x.size(1)]  # N x S x E


class PositionEmbeddingLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """

    def __init__(self, num_model, max_len=5000):
        super().__init__()

        position = torch.arange(0, max_len, dtype=torch.long).unsqueeze(0)
        self.register_buffer('position', position)
        self.pos_embed = nn.Embedding(max_len, num_model)
        self._init_embeddings()

    def _init_embeddings(self):
        nn.init.kaiming_normal_(self.pos_embed.weight, mode="fan_in")

    def forward(self, x):
        pos = self.position[:, :x.size(1)]  # N x L x E
        return self.pos_embed(pos)
--------------------------------------------------------------------------------
/faceformer/datasets/data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from faceformer.utils import flatten_list
8 |
9 |
10 | # TODO: try using bilinear interpolation for all cases
11 | def sample_points(edge, num_samples=50):
12 | if len(edge) == 2:
13 | return sample_points_on_line(edge, num_samples)
14 | return sample_points_on_curve(edge, num_samples)
15 |
16 |
17 | def sample_points_on_line(line, num_samples):
18 | t = np.linspace(0, 1, num_samples)
19 | x1, y1, x2, y2 = line[0][0], line[0][1], line[1][0], line[1][1]
20 | x = x1 + (x2-x1) * t
21 | y = y1 + (y2-y1) * t
22 | return np.vstack([x, y]).T
23 |
24 |
25 | def sample_points_on_curve(curve, num_samples):
26 | samples = np.linspace(0, len(curve)-1, num_samples).round(0).astype(int)
27 | curve = np.array(curve)
28 | return curve[samples]
29 |
30 |
31 | class ABCDataset(torch.utils.data.Dataset):
32 |
33 | def __init__(self, root_dir, datafile_path, config):
34 | super(ABCDataset, self).__init__()
35 | self.root_dir = root_dir
36 | self.info_files = self.parse_splits_list(datafile_path)
37 |
38 | # input shape L x P x D
39 | self.num_points_per_line = config.num_points_per_line # P
40 | self.num_lines = config.num_lines # L
41 | self.point_dim = config.point_dim # D
42 | # output shape S
43 | self.label_seq_length = config.label_seq_length
44 |
45 | self.token = config.token
46 |
47 | # preload all files
48 | self.raw_datas = []
49 | for info_file in self.info_files:
50 | with open(os.path.join(self.root_dir, info_file), "r") as f:
51 | self.raw_datas.append(json.loads(f.read()))
52 |
53 |
54 | def __len__(self):
55 | return len(self.info_files)
56 |
57 | def __getitem__(self, index):
58 | raw_data = self.raw_datas[index]
59 |
60 | edges, faces_indices = raw_data['edges'], raw_data['faces_indices']
61 |
62 | input = np.zeros(
63 | (self.num_lines, self.num_points_per_line, self.point_dim), dtype=np.float32)
64 | for i, edge in enumerate(edges):
65 | input[i, :self.num_points_per_line] = sample_points(
66 | edge, self.num_points_per_line)
67 |
68 |         input_mask = np.ones(self.num_lines, dtype=bool)  # np.bool alias was removed in NumPy 1.24
69 | input_mask[:len(edges)] = 0
70 |
71 |         label = np.ones(self.label_seq_length, dtype=np.int64) * self.token.PAD
72 | label[0] = self.token.SOS
73 | curr_pos = 0
74 | for face in faces_indices:
75 | if not isinstance(face[0], int):
76 | face = flatten_list(face)
77 | curr_pos += 1
78 | label[curr_pos:curr_pos+len(face)] = face
79 | # shift face indices for special tokens
80 | label[curr_pos:curr_pos+len(face)] += self.token.len
81 | curr_pos += len(face)
82 |             label[curr_pos] = self.token.SEP  # each face ends with a SEP
83 |         label[curr_pos] = self.token.EOS  # the trailing SEP becomes EOS
84 | label_mask = (label == self.token.PAD)
85 |
86 | data = {
87 | 'id': index,
88 | 'input': input,
89 | 'label': label,
90 | 'num_input': len(edges),
91 | 'num_label': curr_pos+1,
92 | 'input_mask': input_mask,
93 | 'label_mask': label_mask,
94 | 'name': self.info_files[index]
95 | }
96 |
97 | return data
98 |
99 | def parse_splits_list(self, splits):
100 | """ Returns a list of info_file paths
101 | Args:
102 | splits (list of strings): each item is a path to a .json data file
103 | or a path to a .txt file containing a list of .json's relative paths from root.
104 | """
105 | if isinstance(splits, str):
106 | splits = splits.split()
107 | info_files = []
108 | for split in splits:
109 | ext = os.path.splitext(split)[1]
110 | split_path = os.path.join(self.root_dir, split)
111 | # split_path = os.path.join("/root/ablation/polys-test", split)
112 | if ext == '.json':
113 | info_files.append(split_path)
114 | elif ext == '.txt':
115 | info_files += [info_file.rstrip() for info_file in open(split_path, 'r')]
116 | else:
117 | raise NotImplementedError('%s not a valid info_file type' % split)
118 | return info_files
119 |
--------------------------------------------------------------------------------
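
A worked example of the label layout that __getitem__ builds above, with
hypothetical token ids SOS=0, EOS=1, SEP=2, PAD=3 (so token.len == 4):

    # faces_indices = [[5, 6], [7]] produces
    # label = [SOS, 5+4, 6+4, SEP, 7+4, EOS, PAD, PAD, ...]
    #       = [0,   9,   10,  2,   11,  1,   3,   3,  ...]
    # Edge indices are shifted by token.len so they cannot collide with the
    # special tokens; each face ends with SEP, the trailing SEP is replaced
    # by EOS, and num_label == curr_pos + 1 == 6.
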
/dataset/utils/Face.py:
--------------------------------------------------------------------------------
1 | from OCC.Core.GeomAbs import GeomAbs_Cylinder, GeomAbs_Plane
2 | import numpy as np
3 | from OCC.Core.BRepAdaptor import BRepAdaptor_Surface
4 |
5 | class Face:
6 | '''
7 | Face is a collection of non-repeating edges
8 | '''
9 | def __init__(self, face, topo):
10 | surface = BRepAdaptor_Surface(face)
11 | self.face = face
12 | self.face_type = surface.GetType()
13 | self.topo = topo
14 | self.edges = []
15 | self.edge_orientations = []
16 | self.keys = []
17 |
18 | # get face parametric values
19 | if self.face_type == GeomAbs_Plane:
20 | plane = surface.Surface().Plane()
21 | # plane parameters: Location, XAxis, YAxis, ZAxis, Coefficients
22 | self.parameters = {'Location': self._get_vector_parameters(plane.Location()),
23 | 'XAxis': self._get_axis_parameters(plane.XAxis()),
24 | 'YAxis': self._get_axis_parameters(plane.YAxis()),
25 | 'Normal': self._get_axis_parameters(plane.Axis()),
26 | 'Coefficients': plane.Coefficients()}
27 | elif self.face_type == GeomAbs_Cylinder:
28 | cylinder = surface.Surface().Cylinder()
29 | # cylinder parameters: Location, XAxis, YAxis, ZAxis, Coefficients, Radius
30 | self.parameters = {'Location': self._get_vector_parameters(cylinder.Location()),
31 | 'XAxis': self._get_axis_parameters(cylinder.XAxis()),
32 | 'YAxis': self._get_axis_parameters(cylinder.YAxis()),
33 | 'Normal': self._get_axis_parameters(cylinder.Axis()),
34 | 'Coefficients': cylinder.Coefficients(),
35 | 'Radius': cylinder.Radius()}
36 | else:
37 | self.parameters = None
38 |
39 | # Given an OCC vector
40 | # Return XYZ
41 | def _get_vector_parameters(self, vector):
42 | return vector.X(), vector.Y(), vector.Z()
43 |
44 | # Given an axis
45 | # Return Location(XYZ), Direction(XYZ)
46 | def _get_axis_parameters(self, axis):
47 | location = self._get_vector_parameters(axis.Location())
48 | direction = self._get_vector_parameters(axis.Direction())
49 | return location, direction
50 |
51 | def add_edge(self, edge, orientation):
52 | self.edges.append(edge)
53 | self.edge_orientations.append(orientation)
54 | self.keys.append(hash(edge))
55 |
56 | def remove_edge(self, key):
57 | ind = self.keys.index(key)
58 | del self.keys[ind]
59 | del self.edges[ind]
60 | del self.edge_orientations[ind]
61 |
62 | def get_oriented_dedges(self, is_3d=False):
63 | return [e.get_oriented_dedge(o, is_3d) for e, o in zip(self.edges, self.edge_orientations)]
64 |
65 | def get_edge_ind_and_orientation(self):
66 | return [(e.index, o) for e, o in zip(self.edges, self.edge_orientations)]
67 |
68 | def roll(self, n):
69 | self.edges = np.roll(self.edges, -n, axis=0).tolist()
70 | self.edge_orientations = np.roll(self.edge_orientations, -n, axis=0).tolist()
71 | self.keys = np.roll(self.keys, -n, axis=0).tolist()
72 |
73 | def merge(self, other):
74 | '''
75 | Merge faces on sewn edge.
76 | Assume both faces are rolled properly with sewn edge at the front.
77 | edge[0] is sewn edge.
78 | Return edge merging candidates
79 | '''
80 | assert isinstance(other, Face), 'Cannot merge face with non-face'
81 | sewn_edge = self.edges[0]
82 | if self == other:
83 | self.edges = self.edges[1:]
84 | self.edge_orientations = self.edge_orientations[1:]
85 | self.keys = self.keys[1:]
86 | key = hash(sewn_edge)
87 | if key in self.keys:
88 | self.remove_edge(key)
89 |
90 | del self.topo.all_edges[hash(sewn_edge)]
91 | return None
92 |
93 | # change faces in edge
94 | for edge in other.edges[1:]:
95 | i = edge.faces.index(other)
96 | edge.faces[i] = self
97 |
98 | # candidate merge edges
99 | candidates = [(self.keys[1], other.keys[-1]), (self.keys[-1], other.keys[1])]
100 |
101 | # merge face
102 | self.edges = self.edges[1:] + other.edges[1:]
103 | self.edge_orientations = self.edge_orientations[1:] + other.edge_orientations[1:]
104 | self.keys = self.keys[1:] + other.keys[1:]
105 | if self.face_type != other.face_type:
106 |             self.face_type = 10  # surface types differ: mark merged face as 'other'
107 |
108 |
109 | # remove sewn edge and other face in topo
110 | del self.topo.all_edges[hash(sewn_edge)]
111 | del self.topo.all_faces[hash(other.face)]
112 |
113 | return candidates
--------------------------------------------------------------------------------
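
A small sketch of the roll convention that merge relies on above: rolling a
face by the sewn edge's index brings that edge to the front, so both faces
agree that edges[0] is the sewn edge (stand-in string edges, not OCC objects):

    import numpy as np

    edges = ['a', 'sewn', 'b', 'c']
    n = edges.index('sewn')
    rolled = np.roll(edges, -n, axis=0).tolist()
    assert rolled == ['sewn', 'b', 'c', 'a']
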
/faceformer/datasets/data_para.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import numpy as np
5 | import torch
6 |
7 |
8 | def sample_points(edge, num_samples=50):
9 | if len(edge) == 2:
10 | return sample_points_on_line(edge, num_samples)
11 | return sample_points_on_curve(edge, num_samples)
12 |
13 |
14 | def sample_points_on_line(line, num_samples):
15 | t = np.linspace(0, 1, num_samples)
16 | x1, y1, x2, y2 = line[0][0], line[0][1], line[1][0], line[1][1]
17 | x = x1 + (x2-x1) * t
18 | y = y1 + (y2-y1) * t
19 | return np.vstack([x, y]).T
20 |
21 |
22 | def sample_points_on_curve(curve, num_samples):
23 | samples = np.linspace(0, len(curve)-1, num_samples).round(0).astype(int)
24 | curve = np.array(curve)
25 | return curve[samples]
26 |
27 |
28 | class ABCDataset_Parallel(torch.utils.data.Dataset):
29 |
30 | def __init__(self, root_dir, datafile_path, config):
31 | super(ABCDataset_Parallel, self).__init__()
32 |
33 | self.root_dir = root_dir
34 | self.info_files = self.parse_splits_list(datafile_path)
35 |
36 | # input shape L x P x D
37 | self.num_points_per_line = config.num_points_per_line # P
38 | self.num_lines = config.num_lines # L
39 | self.point_dim = config.point_dim # D
40 |
41 | # output shape F x T
42 | self.max_num_faces = config.max_num_faces # F
43 | self.max_face_length = config.max_face_length # T
44 |
45 | self.token = config.token
46 |
47 | # preload all files
48 | self.raw_datas = []
49 | for info_file in self.info_files:
50 | with open(os.path.join(self.root_dir, info_file), "r") as f:
51 | self.raw_datas.append(json.loads(f.read()))
52 |
53 | def __len__(self):
54 | return len(self.info_files)
55 |
56 | def __getitem__(self, index):
57 | raw_data = self.raw_datas[index]
58 |
59 | edges, faces_indices = raw_data['edges'], raw_data['faces_indices']
60 |
61 | input = np.zeros(
62 | (self.num_lines, self.num_points_per_line, self.point_dim), dtype=np.float32)
63 | for i, edge in enumerate(edges):
64 | input[i, :self.num_points_per_line] = sample_points(
65 | edge, self.num_points_per_line)
66 |
67 |         input_mask = np.ones(self.num_lines, dtype=bool)  # L; np.bool alias was removed in NumPy 1.24
68 | input_mask[:len(edges)] = 0
69 |
70 | # F x T
71 |         label = np.ones((self.num_lines, self.max_face_length), dtype=np.int64) * self.token.PAD
72 | ind = 0
73 | # each face: [(loop 1), ..., (loop n)]
74 | for face_with_type in faces_indices:
75 | type, face = face_with_type
76 | # only allow Plane - 0, Cylinder - 1, Other - 2
77 | if type > 1:
78 | type = 2
79 |             # shift the type by the face-type token offset
80 | type += self.token.face_type_offset
81 | # each loop rolls itself
82 | for loop in face:
83 | for i in range(len(loop)):
84 | # construct new seq
85 | rotated_loop = np.roll(loop, i, axis=0).tolist()
86 | new_seq = rotated_loop
87 | for other_loop in face:
88 | if other_loop != loop:
89 | new_seq += other_loop
90 |                     label[ind, :len(new_seq)] = new_seq
91 |                     # shift face indices for special tokens
92 |                     label[ind, :len(new_seq)] += self.token.len
93 |                     label[ind, len(new_seq)] = type
94 |                     ind += 1
95 | for i in range(ind, self.num_lines):
96 | label[i, 0] = self.token.len - 1 # set to Other type of face
97 | label_mask = (label == self.token.PAD)
98 |
99 | data = {
100 | 'id': index,
101 | 'input': input,
102 | 'label': label,
103 | 'num_input': len(edges),
104 | 'num_faces': len(faces_indices),
105 | 'input_mask': input_mask,
106 | 'label_mask': label_mask,
107 | 'name': self.info_files[index]
108 | }
109 |
110 | return data
111 |
112 | def parse_splits_list(self, splits):
113 | """ Returns a list of info_file paths
114 | Args:
115 | splits (list of strings): each item is a path to a .json data file
116 | or a path to a .txt file containing a list of .json's relative paths from root.
117 | """
118 | if isinstance(splits, str):
119 | splits = splits.split()
120 | info_files = []
121 | for split in splits:
122 | ext = os.path.splitext(split)[1]
123 | split_path = os.path.join(self.root_dir, split)
124 | # split_path = os.path.join("/root/ablation/polys-test", split)
125 | if ext == '.json':
126 | info_files.append(split_path)
127 | elif ext == '.txt':
128 | info_files += [info_file.rstrip() for info_file in open(split_path, 'r')]
129 | else:
130 | raise NotImplementedError('%s not a valid info_file type' % split)
131 | return info_files
132 |
--------------------------------------------------------------------------------
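
A concrete trace of the per-face rotation augmentation in __getitem__ above,
before the special-token offset is applied (made-up edge indices; each loop
contributes one label row per rotation, followed by the face's other loops):

    import numpy as np

    face = [[3, 1, 2], [8, 9]]  # one face with two loops
    rows = []
    for loop in face:
        for i in range(len(loop)):
            seq = np.roll(loop, i).tolist()
            for other in face:
                if other != loop:
                    seq += other
            rows.append(seq)

    assert rows == [[3, 1, 2, 8, 9], [2, 3, 1, 8, 9], [1, 2, 3, 8, 9],
                    [8, 9, 3, 1, 2], [9, 8, 3, 1, 2]]
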
/dataset/filters/filter_thickness.py:
--------------------------------------------------------------------------------
1 | """
2 | Filter out models whose feature curves lie too close together (thin parts).
3 | """
4 | import argparse, os, trimesh, yaml, time, json
5 | from functools import partial
6 |
7 | import numpy as np
8 | from scipy.spatial.distance import cdist
9 | from tqdm.contrib.concurrent import process_map
10 |
11 |
12 | def scale_to_unit_sphere(mesh):
13 | if isinstance(mesh, trimesh.Scene):
14 | mesh = mesh.dump().sum()
15 |
16 | vertices = mesh.vertices - mesh.bounding_box.centroid
17 | vertices *= 2 / np.linalg.norm(mesh.bounding_box.extents)
18 |
19 | return trimesh.Trimesh(vertices=vertices, faces=mesh.faces, process=False, maintain_order=True)
20 |
21 |
22 | def dist_p2p(vertices, verts_i, verts_j):
23 | dists = cdist(vertices[verts_i], vertices[verts_j])
24 | return np.mean(np.min(dists, 1))
25 |
26 |
27 | def dist_p2l(vertices, verts_i, verts_j, EPS=1e-8, MAX_VALUE=10):
28 | edges = np.vstack((verts_j[:-1], verts_j[1:])).T
29 | edge_vector = vertices[edges[:, 1]] - vertices[edges[:, 0]]
30 | edge_length = np.linalg.norm(edge_vector, axis=1, keepdims=True) + EPS
31 | edge_tangent = edge_vector / edge_length
32 |
33 | # Points x Lines x Dim
34 | vector = vertices[verts_i, np.newaxis] - vertices[edges[:, 0]][np.newaxis]
35 |
36 | # Points x Lines
37 | points_prop = np.sum(
38 | vector * edge_tangent[np.newaxis], axis=-1) / edge_length.reshape(1, -1)
39 | points_perp = points_prop[..., np.newaxis] * edge_vector - vector
40 |
41 | # p2l dists within 0 < points_prop < 1
42 | pl_dists = np.linalg.norm(points_perp, axis=-1)
43 | pl_valid = np.logical_and(0 < points_prop, points_prop < 1)
44 | pl_dists[np.logical_not(pl_valid)] = MAX_VALUE
45 |
46 | # p2p dists
47 | pp_dists = cdist(vertices[verts_i], vertices[edges].reshape(-1, 3))
48 | pp_dists = pp_dists.reshape(-1, len(edges), 2)
49 | pp_dists = np.min(pp_dists, -1)
50 |
51 | dists = np.minimum(pl_dists, pp_dists)
52 | return np.mean(np.min(dists, 1))
53 |
54 |
55 | def load_and_preprocess(name, args):
56 | if os.path.exists(os.path.join(args.save_root, f'{name}.npy')):
57 | return
58 |
59 | mesh_path = os.path.join(args.root, 'obj', f'{name}.obj')
60 | mesh = trimesh.load_mesh(mesh_path, process=False, maintain_order=True)
61 |
62 | # normalize to a unit sphere
63 | mesh = scale_to_unit_sphere(mesh)
64 |
65 | feat_path = os.path.join(args.root, 'feat', f'{name}.yml')
66 | with open(feat_path) as file:
67 | annos = yaml.full_load(file)
68 |
69 | curve_verts = []
70 | for curve in annos['curves']:
71 | vert_indices = np.array(curve['vert_indices']).reshape(-1)
72 | curve_verts.append(vert_indices)
73 |
74 | vertices = mesh.vertices.view(np.ndarray)
75 |
76 | num_curves = len(curve_verts)
77 |
78 | with open(os.path.join(args.save_root, f'{name}.npy'), 'wb') as f:
79 | np.save(f, vertices)
80 | np.save(f, num_curves)
81 | for c in curve_verts:
82 | np.save(f, c)
83 |
84 | def filter_by_thickness(name, args):
85 |
86 | with open(os.path.join(args.save_root, f'{name}.npy'), 'rb') as f:
87 | vertices = np.load(f)
88 | num_curves = np.load(f)
89 | curve_verts = []
90 | max_index = 0
91 | for i in range(num_curves):
92 | curve_verts.append(np.load(f))
93 | max_index = max(curve_verts[-1].max(), max_index)
94 |
95 | if max_index >= len(vertices):
96 | print(f"{name} has vertices don't match {len(vertices)} <= {max_index}")
97 | return None
98 |
99 | # dists = np.zeros((num_curves, num_curves))
100 |
101 | for i in range(num_curves):
102 | verts_i = curve_verts[i]
103 |
104 | for j in range(i+1, num_curves):
105 | verts_j = curve_verts[j]
106 |
107 | if args.p2p:
108 | dist_1 = dist_p2p(vertices, verts_i, verts_j)
109 | dist_2 = dist_p2p(vertices, verts_j, verts_i)
110 | else:
111 | dist_1 = dist_p2l(vertices, verts_i, verts_j)
112 | dist_2 = dist_p2l(vertices, verts_j, verts_i)
113 |
114 | if dist_1 < args.threshold and dist_2 < args.threshold:
115 | return None
116 | # dists[i, j] = dist_1
117 | # dists[j, i] = dist_2
118 | return name
119 |
120 | def main(args):
121 | with open("dataset/dataset_gen_logs/filtered_id_list.json", 'r') as f:
122 | names = json.load(f)
123 |
124 | # preprocess
125 | process_map(
126 | partial(load_and_preprocess, args=args), names,
127 | max_workers=args.num_cores, chunksize=args.num_chunks)
128 |
129 | rets = process_map(
130 | partial(filter_by_thickness, args=args), names,
131 | max_workers=args.num_cores, chunksize=args.num_chunks)
132 |
133 | filtered = [ret for ret in rets if ret is not None]
134 |
135 | # Filtering by thickness can take a long time.
136 | # Uncomment the following and lines in filter_by_thickness() to save intermediate result.
137 |
138 | # name_to_dist = {}
139 | # for i in rets:
140 | # if i is None:
141 | # continue
142 | # name, dists = i
143 | # name_to_dist[name] = dists.tolist()
144 |
145 | # with open('all_thickness.json', 'w') as f:
146 | # json.dump(name_to_dist, f)
147 |
148 | with open('filtered_id_list.json', 'w') as f:
149 | json.dump(filtered, f)
150 |
151 | with open('data_processing_log.txt', 'a') as f:
152 | f.write("Thickness id list generation done - " + time.ctime() + '\n')
153 |
154 |
155 | if __name__ == '__main__':
156 | parser = argparse.ArgumentParser()
157 | parser.add_argument('--root', type=str, default='/root/Datasets/FaceFormer',
158 | help='dataset root')
159 | parser.add_argument('--save_root', type=str, default='/root/data/curve_verts',
160 | help='dataset root')
161 | parser.add_argument('--threshold', type=float, default=0.05,
162 | help='threshold for closer edge')
163 | parser.add_argument('--num_cores', type=int,
164 | default=10, help='number of processors.')
165 | parser.add_argument('--num_chunks', type=int,
166 | default=10, help='number of chunk.')
167 | parser.add_argument('--p2p', action='store_true')
168 | args = parser.parse_args()
169 |
170 | main(args)
171 |
172 |
--------------------------------------------------------------------------------
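
A toy sanity check for dist_p2p above (made-up vertices and curve indices):
two parallel curves 0.04 apart score below the default --threshold of 0.05 in
both directions, so filter_by_thickness would drop the model as too thin.

    import numpy as np
    from scipy.spatial.distance import cdist

    vertices = np.array([[0, 0, 0], [1, 0, 0],
                         [0, 0.04, 0], [1, 0.04, 0]], dtype=float)
    verts_i, verts_j = np.array([0, 1]), np.array([2, 3])

    dists = cdist(vertices[verts_i], vertices[verts_j])
    assert np.isclose(np.mean(np.min(dists, 1)), 0.04)  # < 0.05 threshold
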
/dataset/utils/json_to_svg.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | from functools import partial
5 | from faceformer.utils import flatten_list
6 |
7 | import numpy as np
8 | import svgwrite
9 | from cairosvg import svg2png
10 | from matplotlib.cm import get_cmap as colormap
11 | from tqdm.contrib.concurrent import process_map
12 |
13 |
14 | def discretized_edge_to_svg_polyline(points):
15 | """ Returns a svgwrite.Path for the edge, and the 2d bounding box
16 | """
17 | return svgwrite.shapes.Polyline(points, fill="none", class_='vectorEffectClass')
18 |
19 | def save_svg_groups(groups_of_edges, filename, args):
20 | discretized_edges = flatten_list(groups_of_edges)
21 | all_edges = flatten_list(discretized_edges)
22 |
23 | # compute bounding box
24 | min_x, min_y = np.min(all_edges, axis=0) - args.png_padding
25 | max_x, max_y = np.max(all_edges, axis=0) + args.png_padding
26 | width, height = max_x - min_x, max_y - min_y
27 |
28 |
29 | # build the svg drawing
30 | dwg = svgwrite.Drawing(filename, (args.width, args.height), debug=True)
31 | dwg.viewbox(min_x, min_y, width, height)
32 |
33 | # make sure line width stays constant
34 | # https://github.com/mozman/svgwrite/issues/38
35 | dwg.defs.add(
36 | dwg.style(".vectorEffectClass {\nvector-effect: non-scaling-stroke;\n}"))
37 |
38 | n = len(groups_of_edges) + 1
39 | cmap = (colormap('coolwarm')(np.linspace(0, 1, n))[:, :3]*255).astype(np.uint8)
40 | np.random.seed(args.seed)
41 | cmap = cmap[np.random.permutation(n), :]
42 | for index, group in enumerate(groups_of_edges):
43 | color = ",".join([str(c) for c in cmap[index]])
44 | for edge in group:
45 | polyline = discretized_edge_to_svg_polyline(edge)
46 | polyline.stroke(f"rgb({color})",
47 | width=args.line_width, linecap="round")
48 | dwg.add(polyline)
49 | # export to string or file according to the user choice
50 | dwg.save()
51 |
52 |
53 | def save_svg(discretized_edges, filename, args, color='black'):
54 | # compute polylines for all edges
55 | polylines = [discretized_edge_to_svg_polyline(
56 | edge) for edge in discretized_edges]
57 |
58 | all_edges = flatten_list(discretized_edges)
59 |
60 | # compute bounding box
61 | min_x, min_y = np.min(all_edges, axis=0) - args.png_padding
62 | max_x, max_y = np.max(all_edges, axis=0) + args.png_padding
63 | width, height = max_x - min_x, max_y - min_y
64 |
65 |
66 | # build the svg drawing
67 | dwg = svgwrite.Drawing(filename, (args.width, args.height), debug=True)
68 | dwg.viewbox(min_x, min_y, width, height)
69 |
70 | # make sure line width stays constant
71 | # https://github.com/mozman/svgwrite/issues/38
72 | dwg.defs.add(
73 | dwg.style(".vectorEffectClass {\nvector-effect: non-scaling-stroke;\n}"))
74 |
75 | n = len(polylines) + 1
76 | cmap = (colormap('jet')(np.linspace(0, 1, n))[:, :3]*255).astype(np.uint8)
77 | cmap = cmap[np.random.permutation(n), :]
78 |
79 | for index, (dedge, polyline) in enumerate(zip(discretized_edges, polylines)):
80 | if color != 'black':
81 | color = ",".join([str(c) for c in cmap[index]])
82 | color = f"rgb({color})"
83 | polyline.stroke(color,
84 | width=args.line_width, linecap="round")
85 | dwg.add(polyline)
86 | # add a circle at the beginning of the edge
87 | dwg.add(dwg.circle(dedge[0], r=4/256, fill='black'))
88 |
89 | # export to string or file according to the user choice
90 | dwg.save()
91 |
92 |
93 | def save_png(name, args, prefix=''):
94 | svg2png(
95 | bytestring=open(
96 | os.path.join(args.root, prefix+'svg', f'{name}.svg'), 'rb').read(),
97 | output_width=args.width,
98 | output_height=args.height,
99 | background_color='white',
100 | write_to=os.path.join(args.root, prefix+'png', f'{name}.png')
101 | )
102 |
103 |
104 | def json_to_svg_png(name, args):
105 | json_filename = os.path.join(args.root, 'json', f'{name}.json')
106 | with open(json_filename, "r") as f:
107 | data = json.loads(f.read())
108 | edges, faces_indices = data['edges'], data['faces_indices']
109 | # reconstruct faces
110 | for index, face_indices in enumerate(faces_indices):
111 | face = [edges[ind] for ind in face_indices]
112 | filename = os.path.join(
113 | args.root, args.prefix+'face_svg', f'{name}_{index}.svg')
114 | save_svg(face, filename, args)
115 |
116 | save_svg(edges, os.path.join(
117 | args.root, args.prefix+'svg', f'{name}.svg'), args)
118 |     # save_png(name, args, prefix=args.prefix)
119 |
120 |
121 | def main(args):
122 | os.makedirs(os.path.join(args.root, args.prefix+'svg'), exist_ok=True)
123 | os.makedirs(os.path.join(args.root, args.prefix+'face_svg'), exist_ok=True)
124 | os.makedirs(os.path.join(args.root, args.prefix+'png'), exist_ok=True)
125 | os.makedirs(os.path.join(args.root, args.prefix+'face_png'), exist_ok=True)
126 |
127 | names = []
128 | for name in sorted(os.listdir(os.path.join(args.root, 'step'))):
129 | if not name.endswith('.step'):
130 | continue
131 |
132 | names.append(os.path.splitext(name)[0])
133 |
134 | process_map(
135 | partial(json_to_svg_png, args=args), names,
136 | max_workers=args.num_cores, chunksize=args.num_chunks
137 | )
138 |
139 |
140 | if __name__ == '__main__':
141 | parser = argparse.ArgumentParser()
142 | parser.add_argument('--root', type=str, default="./data",
143 | help='dataset root.')
144 | parser.add_argument('--num_cores', type=int,
145 | default=1, help='number of processors.')
146 | parser.add_argument('--num_chunks', type=int,
147 | default=16, help='number of chunk.')
148 | parser.add_argument('--line_width', type=str,
149 | default=str(3/256), help='svg line width.')
150 | parser.add_argument('--name', type=str, default=None,
151 | help='filename.')
152 | parser.add_argument('--width', type=int,
153 | default=256, help='svg width.')
154 | parser.add_argument('--height', type=int,
155 | default=256, help='svg height.')
156 | parser.add_argument('--prefix', type=str, default='json_',
157 |                         help='filename prefix for generated svg and png')
158 |     # save_svg/save_svg_groups read these from args; the defaults are assumed
159 |     parser.add_argument('--png_padding', type=float, default=0.01,
160 |                         help='padding added around the drawing bounding box.')
161 |     parser.add_argument('--seed', type=int, default=0,
162 |                         help='random seed for face-group colors.')
163 |     args = parser.parse_args()
164 |
165 |     if args.name is None:
166 |         main(args)
167 |     else:
168 |         json_to_svg_png(args.name, args)
169 |
--------------------------------------------------------------------------------
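
A minimal standalone sketch of how the polyline helper above feeds an svgwrite
drawing (the filename and points are made up; the vectorEffectClass style keeps
stroke width constant under viewbox scaling, exactly as save_svg does):

    import svgwrite

    points = [(0.0, 0.0), (0.5, 0.25), (1.0, 1.0)]
    polyline = svgwrite.shapes.Polyline(points, fill='none',
                                        class_='vectorEffectClass')
    polyline.stroke('black', width=3/256, linecap='round')

    dwg = svgwrite.Drawing('example.svg', (256, 256))
    dwg.viewbox(0, 0, 1, 1)
    dwg.defs.add(dwg.style('.vectorEffectClass {vector-effect: non-scaling-stroke;}'))
    dwg.add(polyline)
    dwg.save()
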
/dataset/filters/3view_render.py:
--------------------------------------------------------------------------------
1 | """
2 | Render three-view line drawings (front, right, top, plus an isometric view)
3 | """
4 | import argparse
5 | import os
6 | from functools import partial
7 | import numpy as np
8 | import svgwrite
9 | from cairosvg import svg2png
10 | from tqdm.contrib.concurrent import process_map
11 |
12 | from OCC.Core.Bnd import Bnd_Box
13 | from OCC.Core.BRepBndLib import brepbndlib_Add
14 | from OCC.Core.BRepBuilderAPI import BRepBuilderAPI_Transform
15 | from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
16 | from OCC.Core.gp import gp_Ax2, gp_Dir, gp_Pnt, gp_Trsf, gp_Vec
17 | from OCC.Extend.TopologyUtils import TopologyExplorer
18 |
19 | from dataset.utils.json_to_svg import save_png, save_svg
20 | from dataset.utils.read_step_file import read_step_file
21 | from dataset.utils.projection_utils import project_shapes, discretize_compound
22 |
23 | O = gp_Pnt(0,0,0)
24 | X = gp_Dir(1,0,0)
25 | Y = gp_Dir(0,1,0)
26 | nY = gp_Dir(0,-1,0)
27 | Z = gp_Dir(0,0,1)
28 |
29 | directions = [
30 |     gp_Ax2(O, gp_Dir(1,1,1)), # isometric (45 degree) view
31 | gp_Ax2(O, nY, X), # front
32 | gp_Ax2(O, X, Y), # right
33 | gp_Ax2(O, Z, X) # top
34 | ]
35 |
36 | def get_boundingbox(shape, tol=1e-6, use_mesh=False):
37 | """ return the bounding box of the TopoDS_Shape `shape`
38 | Parameters
39 | ----------
40 | shape : TopoDS_Shape or a subclass such as TopoDS_Face
41 | the shape to compute the bounding box from
42 | tol: float
43 | tolerance of the computed boundingbox
44 | use_mesh : bool
45 |         whether the shape should first be meshed before the bbox
46 |         computation; meshing produces more accurate results
47 | """
48 | bbox = Bnd_Box()
49 | bbox.SetGap(tol)
50 | if use_mesh:
51 | mesh = BRepMesh_IncrementalMesh()
52 | mesh.SetParallelDefault(True)
53 | mesh.SetShape(shape)
54 | mesh.Perform()
55 | if not mesh.IsDone():
56 | raise AssertionError("Mesh not done.")
57 | brepbndlib_Add(shape, bbox, use_mesh)
58 | xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get()
59 | center = (xmax + xmin) / 2, (ymin + ymax) / 2, (zmin + zmax) / 2
60 | extent = abs(xmax-xmin), abs(ymax-ymin), abs(zmax-zmin)
61 | return center, extent
62 |
63 | def get_discretized_edges(name, shape, direction, args):
64 | """ Given a TopologyExplorer topo, and a face on it,
65 | find all edges of the face without sewn edges.
66 | Return discretized edges
67 |
68 | VComponent / HComponent: sharp edges
69 | Rg1LineVCompound / Rg1LineHCompound: smooth edges
70 | RgNLineVCompound / RgNLineHCompound: sewn edges
71 | OutLineVCompound / OutLineHCompound: outlines
72 | """
73 | # project the face
74 | hlr_shapes = project_shapes(shape, direction)
75 |
76 | discretized_edges = []
77 |
78 | outline_compound = hlr_shapes.OutLineVCompound()
79 | if outline_compound:
80 | discretized_edges += discretize_compound(outline_compound, args.tol)
81 |
82 | smooth_compound = hlr_shapes.Rg1LineVCompound()
83 | if smooth_compound:
84 | discretized_edges += discretize_compound(smooth_compound, args.tol)
85 |
86 | # project sharp edges from the face, using only edges.
87 | # (to avoid slicing effects from sewn edge when projecting using face)
88 | sharp_edges_3d = list(TopologyExplorer(shape).edges())
89 | sharp_edges_compound = project_shapes(sharp_edges_3d, direction).VCompound()
90 | if sharp_edges_compound:
91 | sharp_edges_discretized = discretize_compound(sharp_edges_compound, args.tol)
92 |
93 | # check if there are sewn edges
94 | sewn_compound = hlr_shapes.RgNLineVCompound()
95 | if sewn_compound:
96 | sewn_edges_discretized = discretize_compound(sewn_compound, args.tol)
97 | for sewn_edge in sewn_edges_discretized:
98 | try:
99 | sharp_edges_discretized.remove(sewn_edge)
100 | except ValueError:
101 | print("sewn edge assumption broken", name)
102 | break
103 | discretized_edges += sharp_edges_discretized
104 |
105 | return discretized_edges
106 |
107 | def discretized_edge_to_svg_polyline(points):
108 | """ Returns a svgwrite.Path for the edge, and the 2d bounding box
109 | """
110 | return svgwrite.shapes.Polyline(points, fill="none", class_='vectorEffectClass')
111 |
112 | def shape_to_svg(shape, name, args):
113 | """ export a single shape to an svg file and json.
114 | shape: the TopoDS_Shape to export
115 | """
116 | if shape.IsNull():
117 | raise AssertionError("shape is Null")
118 |
119 | for i, direction in enumerate(directions):
120 |
121 | shape_discretized_edges = get_discretized_edges(name, shape, direction, args)
122 |
123 | save_svg(shape_discretized_edges, os.path.join(args.root, '3view_svg', f'{name}-{i}.svg'), args)
124 |
125 | svg2png(
126 | bytestring=open(os.path.join(args.root, '3view_svg', f'{name}-{i}.svg'), 'rb').read(),
127 | output_width=args.width,
128 | output_height=args.height,
129 | background_color='white',
130 | write_to=os.path.join(args.root, '3view_png', f'{name}-{i}.png')
131 | )
132 |
133 | def render_3views(name, args):
134 | try:
135 | step_path = os.path.join(args.root, 'step', f'{name}.step')
136 | # step read timeout at 5 seconds
137 | try:
138 | shape, _ = read_step_file(step_path, verbosity=False)
139 |         except Exception:
140 | print(f"{name} took too long to read")
141 | return
142 |
143 | if shape is None:
144 | print(f"{name} is NULL shape")
145 | return
146 |
147 | center, extent = get_boundingbox(shape)
148 |
149 | trans, scale = gp_Trsf(), gp_Trsf()
150 | trans.SetTranslation(-gp_Vec(*center))
151 | scale.SetScale(gp_Pnt(0, 0, 0), 2 / np.linalg.norm(extent))
152 | brep_trans = BRepBuilderAPI_Transform(shape, scale * trans)
153 | shape = brep_trans.Shape()
154 |
155 | shape_to_svg(shape, name, args)
156 | except Exception as e:
157 | print(f"{name} received unknown error", e)
158 |
159 | def main(args):
160 | # all step files
161 | names = []
162 | for name in sorted(os.listdir(os.path.join(args.root, 'stat'))):
163 | names.append(name[:8])
164 |
165 | os.makedirs(os.path.join(args.root, '3view_svg'), exist_ok=True)
166 | os.makedirs(os.path.join(args.root, '3view_png'), exist_ok=True)
167 |
168 | process_map(
169 | partial(render_3views, args=args), names,
170 | max_workers=args.num_cores, chunksize=args.num_chunks
171 | )
172 |
173 |
174 | if __name__ == '__main__':
175 | parser = argparse.ArgumentParser()
176 | parser.add_argument('--root', type=str, default="./data",
177 | help='dataset root.')
178 | parser.add_argument('--name', type=str, default=None,
179 | help='filename.')
180 | parser.add_argument('--num_cores', type=int,
181 | default=40, help='number of processors.')
182 | parser.add_argument('--num_chunks', type=int,
183 | default=10, help='number of chunk.')
184 | parser.add_argument('--width', type=int,
185 | default=256, help='svg width.')
186 | parser.add_argument('--height', type=int,
187 | default=256, help='svg height.')
188 | parser.add_argument('--tol', type=float,
189 | default=1e-4, help='svg discretization tolerance.')
190 | parser.add_argument('--line_width', type=str,
191 | default=str(3/256), help='svg line width.')
192 | parser.add_argument('--filter_num_shapes', type=int,
193 | default=8, help='do not process step files \
194 | that have more than this number of shapes.')
195 | parser.add_argument('--filter_num_edges', type=int,
196 | default=1000, help='do not process step files \
197 | that have more than this number of edges.')
198 |     # save_svg reads this from args; the default is assumed
199 |     parser.add_argument('--png_padding', type=float, default=0.01,
200 |                         help='padding added around the drawing bounding box.')
201 |
202 |     args = parser.parse_args()
203 |
204 |     if args.name is None:
205 |         main(args)
206 |     else:
207 |         render_3views(args.name, args)
208 |
--------------------------------------------------------------------------------
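
The directions list above is straightforward to extend; a sketch (assuming
pythonocc's gp_Ax2(P, N, Vx) constructor: origin, view normal, X direction)
that would add a back view mirroring the 'front' entry:

    from OCC.Core.gp import gp_Ax2, gp_Dir, gp_Pnt

    back = gp_Ax2(gp_Pnt(0, 0, 0), gp_Dir(0, 1, 0), gp_Dir(-1, 0, 0))
    directions.append(back)
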
/reconstruction/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from OCC.Core.BRepBuilderAPI import BRepBuilderAPI_MakeEdge
3 | from OCC.Core.gp import gp_Ax2, gp_Circ, gp_Dir, gp_Pnt, gp_Vec
4 | from OCC.Extend.TopologyUtils import discretize_edge
5 |
6 |
7 | def construct_connected_cylinder(edges, edge_inds, tol=1e-4):
8 | '''
9 |     Given the edges of a face and their indices,
10 |     form a single loop and return the loop's edges, indices and edge directions (1/-1).
11 | '''
12 |
13 | # group edges by their intersections
14 | groups = {}
15 | edge_ind_to_intersection = {}
16 | for edge, edge_ind in zip(edges, edge_inds):
17 | start, end = tuple(edge[0]), tuple(edge[-1])
18 | start_found, end_found = False, False
19 | # find start's group
20 | for intersection in groups:
21 | if dist(start, intersection) < tol:
22 | groups[intersection].append((edge, 1, edge_ind))
23 | start_found = True
24 | break
25 | if not start_found:
26 | groups[start] = [(edge, 1, edge_ind)]
27 | intersection = start
28 |
29 | if edge_ind not in edge_ind_to_intersection:
30 | edge_ind_to_intersection[edge_ind] = [intersection]
31 | else:
32 | edge_ind_to_intersection[edge_ind].append(intersection)
33 |
34 | # find end's group
35 | for intersection in groups:
36 | if dist(end, intersection) < tol:
37 | groups[intersection].append((edge, -1, edge_ind))
38 | end_found = True
39 | break
40 | if not end_found:
41 | groups[end] = [(edge, -1, edge_ind)]
42 | intersection = end
43 |
44 | if edge_ind not in edge_ind_to_intersection:
45 | edge_ind_to_intersection[edge_ind] = [intersection]
46 | else:
47 | edge_ind_to_intersection[edge_ind].append(intersection)
48 | # fix one corner to be the origin. Generate a circle from the origin
49 | for intersection, edge_inter in groups.items():
50 | assert len(edge_inter) == 2, "more than two edges intersect at one intersection"
51 | edge1, edge2 = edge_inter[0][0], edge_inter[1][0]
52 | # intersection of a line and a curve is a real intersection
53 | if is_straight_line(edge1) or is_straight_line(edge2):
54 | origin = intersection
55 | break
56 |
57 | # construct the circle
58 | circle = []
59 | circle_inds = []
60 | dirs = []
61 | next_point = origin
62 | count = 0
63 | while True:
64 | for edge, direction, edge_ind in groups[next_point]:
65 | if edge_ind not in circle_inds:
66 | break
67 | circle.append(edge[::direction])
68 | circle_inds.append(edge_ind)
69 | dirs.append(direction)
70 | # find the next point
71 | for intersection in edge_ind_to_intersection[edge_ind]:
72 | if tuple(next_point) != tuple(intersection):
73 | next_point = intersection
74 | break
75 | if next_point == origin:
76 | break
77 | count += 1
78 | if count >= 10:
79 | print("cylinder construction failed")
80 | break
81 |
82 | # return circle indices in sequence
83 | return circle, circle_inds, dirs
84 |
85 |
86 | def construct_connected_cycle(edges, edge_inds, tol=1e-4):
87 | '''
88 |     Given edges and their indices,
89 |     form closed loops and return each loop's edges, indices and edge directions (1/-1).
90 | '''
91 |
92 | # group edges by their intersections
93 | groups = {}
94 | edge_ind_to_intersection = {}
95 | for edge, edge_ind in zip(edges, edge_inds):
96 | start, end = tuple(edge[0]), tuple(edge[-1])
97 | start_found, end_found = False, False
98 | # find start's group
99 | for intersection in groups:
100 | if dist(start, intersection) < tol:
101 | groups[intersection].append((edge, 1, edge_ind))
102 | start_found = True
103 | break
104 | if not start_found:
105 | groups[start] = [(edge, 1, edge_ind)]
106 | intersection = start
107 |
108 | if edge_ind not in edge_ind_to_intersection:
109 | edge_ind_to_intersection[edge_ind] = [intersection]
110 | else:
111 | edge_ind_to_intersection[edge_ind].append(intersection)
112 |
113 | # find end's group
114 | for intersection in groups:
115 | if dist(end, intersection) < tol:
116 | groups[intersection].append((edge, -1, edge_ind))
117 | end_found = True
118 | break
119 | if not end_found:
120 | groups[end] = [(edge, -1, edge_ind)]
121 | intersection = end
122 |
123 | if edge_ind not in edge_ind_to_intersection:
124 | edge_ind_to_intersection[edge_ind] = [intersection]
125 | else:
126 | edge_ind_to_intersection[edge_ind].append(intersection)
127 |
128 |
129 | # construct circles
130 | all_circles = []
131 | all_circle_inds = []
132 | all_dirs = []
133 | while len(groups) > 0:
134 | origin = list(groups.keys())[0]
135 | circle = []
136 | circle_inds = []
137 | dirs = []
138 | next_point = origin
139 | skip = False
140 | while True:
141 | if next_point not in groups:
142 | skip = True
143 | break
144 | for edge, direction, edge_ind in groups[next_point]:
145 | if edge_ind not in circle_inds:
146 | break
147 | circle.append(edge[::direction])
148 | circle_inds.append(edge_ind)
149 | dirs.append(direction)
150 | del groups[next_point]
151 |
152 | # find the next point
153 | for intersection in edge_ind_to_intersection[edge_ind]:
154 | if tuple(next_point) != tuple(intersection):
155 | next_point = intersection
156 | break
157 | if next_point == origin:
158 | break
159 | if not skip:
160 | all_circles.append(circle)
161 | all_circle_inds.append(circle_inds)
162 | all_dirs.append(dirs)
163 | # return circle indices in sequence
164 | return all_circles, all_circle_inds, all_dirs
165 |
166 |
167 |
168 | def check_parallel(v1, v2, tol=1e-10):
169 |     return np.abs(np.dot(v1, v2)) > (1 - tol)  # assumes unit vectors
170 |
171 | def fit_curve(p1, p2, p3):
172 | '''
173 | Given three 3D points, fit a circle to the points.
174 | Return the discretized curve between p1-p3-p2
175 | '''
176 | center, radius, normal = find_circle_center(p1, p2, p3)
177 |
178 | # construct opencascade circle
179 | center = gp_Pnt(center[0], center[1], center[2])
180 | normal = gp_Vec(normal[0], normal[1], normal[2])
181 | ax = gp_Ax2(center, gp_Dir(normal))
182 | circle = gp_Circ(ax, radius)
183 | circle_edge = BRepBuilderAPI_MakeEdge(circle).Edge()
184 | pts = discretize_edge(circle_edge, deflection=1e-5)
185 | return find_curve_between_points(pts, p1, p2, p3)
186 |
187 | def find_circle_center(p1, p2, p3):
188 | # triangle "edges"
189 | t = np.array(p2 - p1)
190 | u = np.array(p3 - p1)
191 | v = np.array(p3 - p2)
192 |
193 | # triangle normal
194 | w = np.cross(t, u)
195 | wsl = w.dot(w)
196 |
197 | # helpers
198 | iwsl2 = 1.0 / (2.0 * wsl)
199 | tt = t.dot(t)
200 | uu = u.dot(u)
201 |
202 | # result circle
203 | center = p1 + (u*tt*(u.dot(v)) - t*uu*(t.dot(v))) * iwsl2
204 | radius = np.sqrt(tt * uu * (v.dot(v)) * iwsl2 / 2)
205 | normal = w / np.sqrt(wsl)
206 | return center, radius, normal
207 |
208 | def find_curve_between_points(pts, p1, p2, p3):
209 | pts = np.array(pts)
210 | p1_ind = np.argmin(np.linalg.norm(pts - p1, axis=1))
211 | p2_ind = np.argmin(np.linalg.norm(pts - p2, axis=1))
212 | p1_ind, p2_ind = min(p1_ind, p2_ind), max(p1_ind, p2_ind)
213 | right_direction = p3 - pts[p1_ind]
214 | v1 = pts[(p1_ind+1) % (len(pts)-1)] - pts[p1_ind]
215 | # selecting p1_ind to p2_ind if angle is acute
216 | if np.dot(v1, right_direction) > 0:
217 | pts = pts[p1_ind:p2_ind+1]
218 | # selecting everything else if angle is obtuse
219 | else:
220 | pts = np.vstack([pts[p2_ind:], pts[:p1_ind+1]])
221 | return pts
222 |
223 | def dist(p1, p2):
224 | return np.linalg.norm(np.array(p1) - np.array(p2))
225 |
226 | def is_straight_line(line):
227 | return len(line) == 2
228 |
--------------------------------------------------------------------------------
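
A quick sanity check of find_circle_center above: it returns the circumcircle
of the triangle p1-p2-p3. For three points on the unit circle:

    import numpy as np

    p1 = np.array([1., 0., 0.])
    p2 = np.array([-1., 0., 0.])
    p3 = np.array([0., 1., 0.])
    center, radius, normal = find_circle_center(p1, p2, p3)

    assert np.allclose(center, [0, 0, 0])
    assert np.isclose(radius, 1.0)
    assert np.allclose(normal, [0, 0, -1])  # sign follows cross(p2-p1, p3-p1)
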
/reconstruction/reconstruction_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from OCC.Core.BRepBuilderAPI import BRepBuilderAPI_MakeEdge
3 | from OCC.Core.gp import gp_Ax2, gp_Circ, gp_Dir, gp_Pnt, gp_Vec
4 | from OCC.Extend.TopologyUtils import discretize_edge
5 |
6 |
7 | def construct_connected_cylinder(edges, edge_inds, tol=1e-4):
8 | '''
9 |     Given the edges of a face and their indices,
10 |     form a single loop and return the loop's edges, indices and edge directions (1/-1).
11 | '''
12 |
13 | # group edges by their intersections
14 | groups = {}
15 | edge_ind_to_intersection = {}
16 | for edge, edge_ind in zip(edges, edge_inds):
17 | start, end = tuple(edge[0]), tuple(edge[-1])
18 | start_found, end_found = False, False
19 | # find start's group
20 | for intersection in groups:
21 | if dist(start, intersection) < tol:
22 | groups[intersection].append((edge, 1, edge_ind))
23 | start_found = True
24 | break
25 | if not start_found:
26 | groups[start] = [(edge, 1, edge_ind)]
27 | intersection = start
28 |
29 | if edge_ind not in edge_ind_to_intersection:
30 | edge_ind_to_intersection[edge_ind] = [intersection]
31 | else:
32 | edge_ind_to_intersection[edge_ind].append(intersection)
33 |
34 | # find end's group
35 | for intersection in groups:
36 | if dist(end, intersection) < tol:
37 | groups[intersection].append((edge, -1, edge_ind))
38 | end_found = True
39 | break
40 | if not end_found:
41 | groups[end] = [(edge, -1, edge_ind)]
42 | intersection = end
43 |
44 | if edge_ind not in edge_ind_to_intersection:
45 | edge_ind_to_intersection[edge_ind] = [intersection]
46 | else:
47 | edge_ind_to_intersection[edge_ind].append(intersection)
48 | # fix one corner to be the origin. Generate a circle from the origin
49 | for intersection, edge_inter in groups.items():
50 | assert len(edge_inter) == 2, "more than two edges intersect at one intersection"
51 | edge1, edge2 = edge_inter[0][0], edge_inter[1][0]
52 | # intersection of a line and a curve is a real intersection
53 | if is_straight_line(edge1) or is_straight_line(edge2):
54 | origin = intersection
55 | break
56 |
57 | # construct the circle
58 | circle = []
59 | circle_inds = []
60 | dirs = []
61 | next_point = origin
62 | count = 0
63 | while True:
64 | for edge, direction, edge_ind in groups[next_point]:
65 | if edge_ind not in circle_inds:
66 | break
67 | circle.append(edge[::direction])
68 | circle_inds.append(edge_ind)
69 | dirs.append(direction)
70 | # find the next point
71 | for intersection in edge_ind_to_intersection[edge_ind]:
72 | if tuple(next_point) != tuple(intersection):
73 | next_point = intersection
74 | break
75 | if next_point == origin:
76 | break
77 | count += 1
78 | if count >= 10:
79 | print("cylinder construction failed")
80 | break
81 |
82 | # return circle indices in sequence
83 | return circle, circle_inds, dirs
84 |
85 |
86 | def construct_connected_cycle(edges, edge_inds, tol=1e-4):
87 | '''
88 |     Given edges and their indices,
89 |     form closed loops and return each loop's edges, indices and edge directions (1/-1).
90 | '''
91 |
92 | # group edges by their intersections
93 | groups = {}
94 | edge_ind_to_intersection = {}
95 | for edge, edge_ind in zip(edges, edge_inds):
96 | start, end = tuple(edge[0]), tuple(edge[-1])
97 | start_found, end_found = False, False
98 | # find start's group
99 | for intersection in groups:
100 | if dist(start, intersection) < tol:
101 | groups[intersection].append((edge, 1, edge_ind))
102 | start_found = True
103 | break
104 | if not start_found:
105 | groups[start] = [(edge, 1, edge_ind)]
106 | intersection = start
107 |
108 | if edge_ind not in edge_ind_to_intersection:
109 | edge_ind_to_intersection[edge_ind] = [intersection]
110 | else:
111 | edge_ind_to_intersection[edge_ind].append(intersection)
112 |
113 | # find end's group
114 | for intersection in groups:
115 | if dist(end, intersection) < tol:
116 | groups[intersection].append((edge, -1, edge_ind))
117 | end_found = True
118 | break
119 | if not end_found:
120 | groups[end] = [(edge, -1, edge_ind)]
121 | intersection = end
122 |
123 | if edge_ind not in edge_ind_to_intersection:
124 | edge_ind_to_intersection[edge_ind] = [intersection]
125 | else:
126 | edge_ind_to_intersection[edge_ind].append(intersection)
127 |
128 |
129 | # construct circles
130 | all_circles = []
131 | all_circle_inds = []
132 | all_dirs = []
133 | while len(groups) > 0:
134 | origin = list(groups.keys())[0]
135 | circle = []
136 | circle_inds = []
137 | dirs = []
138 | next_point = origin
139 | skip = False
140 | while True:
141 | if next_point not in groups:
142 | skip = True
143 | break
144 | for edge, direction, edge_ind in groups[next_point]:
145 | if edge_ind not in circle_inds:
146 | break
147 | circle.append(edge[::direction])
148 | circle_inds.append(edge_ind)
149 | dirs.append(direction)
150 | del groups[next_point]
151 |
152 | # find the next point
153 | for intersection in edge_ind_to_intersection[edge_ind]:
154 | if tuple(next_point) != tuple(intersection):
155 | next_point = intersection
156 | break
157 | if next_point == origin:
158 | break
159 | if not skip:
160 | all_circles.append(circle)
161 | all_circle_inds.append(circle_inds)
162 | all_dirs.append(dirs)
163 | # return circle indices in sequence
164 | return all_circles, all_circle_inds, all_dirs
165 |
166 |
167 |
168 | def check_parallel(v1, v2, tol=1e-10):
169 |     return np.abs(np.dot(v1, v2)) > (1 - tol)  # assumes unit vectors
170 |
171 | def fit_curve(p1, p2, p3):
172 | '''
173 | Given three 3D points, fit a circle to the points.
174 | Return the discretized curve between p1-p3-p2
175 | '''
176 | center, radius, normal = find_circle_center(p1, p2, p3)
177 |
178 | # construct opencascade circle
179 | center = gp_Pnt(center[0], center[1], center[2])
180 | normal = gp_Vec(normal[0], normal[1], normal[2])
181 | ax = gp_Ax2(center, gp_Dir(normal))
182 | circle = gp_Circ(ax, radius)
183 | circle_edge = BRepBuilderAPI_MakeEdge(circle).Edge()
184 | pts = discretize_edge(circle_edge, deflection=1e-5)
185 | return find_curve_between_points(pts, p1, p2, p3)
186 |
187 | def find_circle_center(p1, p2, p3):
188 | # triangle "edges"
189 | t = np.array(p2 - p1)
190 | u = np.array(p3 - p1)
191 | v = np.array(p3 - p2)
192 |
193 | # triangle normal
194 | w = np.cross(t, u)
195 | wsl = w.dot(w)
196 |
197 | # helpers
198 | iwsl2 = 1.0 / (2.0 * wsl)
199 | tt = t.dot(t)
200 | uu = u.dot(u)
201 |
202 | # result circle
203 | center = p1 + (u*tt*(u.dot(v)) - t*uu*(t.dot(v))) * iwsl2
204 | radius = np.sqrt(tt * uu * (v.dot(v)) * iwsl2 / 2)
205 | normal = w / np.sqrt(wsl)
206 | return center, radius, normal
207 |
208 | def find_curve_between_points(pts, p1, p2, p3):
209 | pts = np.array(pts)
210 | p1_ind = np.argmin(np.linalg.norm(pts - p1, axis=1))
211 | p2_ind = np.argmin(np.linalg.norm(pts - p2, axis=1))
212 | p1_ind, p2_ind = min(p1_ind, p2_ind), max(p1_ind, p2_ind)
213 | right_direction = p3 - pts[p1_ind]
214 | v1 = pts[(p1_ind+1) % (len(pts)-1)] - pts[p1_ind]
215 | # selecting p1_ind to p2_ind if angle is acute
216 | if np.dot(v1, right_direction) > 0:
217 | pts = pts[p1_ind:p2_ind+1]
218 | # selecting everything else if angle is obtuse
219 | else:
220 | pts = np.vstack([pts[p2_ind:], pts[:p1_ind+1]])
221 | return pts
222 |
223 | def dist(p1, p2):
224 | return np.linalg.norm(np.array(p1) - np.array(p2))
225 |
226 | def is_straight_line(line):
227 | return len(line) == 2
228 |
--------------------------------------------------------------------------------
/dataset/utils/TopoMapper.py:
--------------------------------------------------------------------------------
1 | from OCC.Extend.TopologyUtils import TopologyExplorer, WireExplorer, discretize_edge
2 | from OCC.Core.BRepFeat import BRepFeat_SplitShape
3 | from OCC.Core.TopTools import TopTools_SequenceOfShape
4 | from OCC.Core.ShapeFix import ShapeFix_ShapeTolerance
5 |
6 | from dataset.utils.projection_utils import d3_to_d2, project_shapes, discretize_compound, project_points
7 | from dataset.utils.Edge import Edge
8 | from dataset.utils.Face import Face
9 | from faceformer.utils import flatten_list
10 |
11 | import numpy as np
12 |
13 |
14 |
15 | class TopoMapper:
16 | def __init__(self, shape, args):
17 | self.shape = shape
18 | self.all_edges = None
19 | self.all_faces = None
20 | self.args = args
21 | self.tol = self.args.tol
22 |
23 | # add outline to shape
24 | outline_edges = self._find_outline_edges()
25 | self.full_topo = self._add_outline_edges(outline_edges)
26 |
27 | # construct all edge-face mappings
28 | self._construct_mapping()
29 |
30 | # project to 2D; each edge has dedge now
31 | self._project(args.discretize_last)
32 |
33 | # remove sewn edges
34 | sewn_edge_keys = self._find_sewn_edges()
35 | self._remove_sewn_edges(sewn_edge_keys)
36 |
37 |
38 | def _find_outline_edges(self):
39 | hlr_shapes = project_shapes(self.shape, self.args)
40 | outline_compound = hlr_shapes.OutLineVCompound3d()
41 | if outline_compound:
42 | return list(TopologyExplorer(outline_compound).edges())
43 | return []
44 |
45 | def _num_edges(self, splitshape):
46 | probing_shape = splitshape.Shape()
47 | split = BRepFeat_SplitShape(probing_shape)
48 | return split, len(list(TopologyExplorer(probing_shape).edges()))
49 |
50 | def _add_edge(self, split, edge, num_edge):
51 | toptool_seq_shape = TopTools_SequenceOfShape()
52 | toptool_seq_shape.Append(edge)
53 | add_success = split.Add(toptool_seq_shape)
54 | split, curr_num_edge = self._num_edges(split)
55 | add_success = add_success and (curr_num_edge > num_edge)
56 | return split, curr_num_edge, add_success
57 |
58 | def _add_outline_edges(self, outline_edges):
59 | if len(outline_edges) == 0:
60 | return TopologyExplorer(self.shape)
61 | split_edge_num = 0
62 | while True:
63 |             # repeatedly split edges until the number of edges converges
64 | split = BRepFeat_SplitShape(self.shape)
65 | split, num_edge = self._num_edges(split)
66 | for edge in outline_edges:
67 | probing_shape = split.Shape()
68 | backup_split, split = BRepFeat_SplitShape(probing_shape), BRepFeat_SplitShape(probing_shape)
69 | split, curr_num_edge, add_success = self._add_edge(split, edge, num_edge)
70 | if not add_success:
71 | # Increase outline tolerance when add fails
72 | # fixed tolerance, may need update
73 | tol = ShapeFix_ShapeTolerance()
74 | tol.SetTolerance(edge, 1)
75 | split, curr_num_edge, add_success = self._add_edge(backup_split, edge, num_edge)
76 | if not add_success:
77 | raise Exception("Fail to add splitting outline")
78 | if split_edge_num == curr_num_edge:
79 | break
80 | split_edge_num = curr_num_edge
81 |
82 | split_shape = split.Shape()
83 | return TopologyExplorer(split_shape)
84 |
85 | def _construct_mapping(self):
86 | '''
87 | Construct edge-to-face mapping from wireframe.
88 | '''
89 | all_edges = {}
90 | all_faces = {}
91 |
92 | for face in self.full_topo.faces():
93 | new_face = Face(face, self)
94 | all_faces[hash(face)] = new_face
95 |
96 | sharp_edges_wires = list(self.full_topo.wires_from_face(face))
97 | sharp_edges_3d = []
98 | for wire in sharp_edges_wires:
99 | sharp_edges_3d += list(WireExplorer(wire).ordered_edges())
100 |
101 | for edge in sharp_edges_3d:
102 | edge_id = hash(edge) # same edge has same hash
103 |
104 | # create edge
105 | if edge_id in all_edges:
106 | new_edge = all_edges[edge_id]
107 | new_edge.add_face(new_face, edge.Orientation())
108 | else:
109 | new_edge = Edge(edge, faces=[new_face], orientations=[edge.Orientation()])
110 | all_edges[edge_id] = new_edge
111 |
112 | # add edge to face
113 | new_face.add_edge(new_edge, edge.Orientation())
114 |
115 | self.all_faces = all_faces
116 | self.all_edges = all_edges
117 |
118 | def _find_sewn_edges(self):
119 | '''
120 |         Any edge that occurs twice within a face is a sewn edge.
121 | '''
122 | all_sewn_edge_keys = []
123 | topo = TopologyExplorer(self.shape)
124 | for face in topo.faces():
125 | edge_keys = []
126 |
127 | sharp_edges_wires = list(topo.wires_from_face(face))
128 | sharp_edges_3d = []
129 | for wire in sharp_edges_wires:
130 | sharp_edges_3d += list(WireExplorer(wire).ordered_edges())
131 |
132 | for edge in sharp_edges_3d:
133 | edge_id = hash(edge) # same edge has same hash
134 |
135 | # if edge is used twice in a face, it's a sewn edge
136 | if edge_id in edge_keys:
137 | all_sewn_edge_keys.append(edge_id)
138 | else:
139 | edge_keys.append(edge_id)
140 |
141 | return all_sewn_edge_keys
142 |
143 | def _remove_sewn_edges(self, sewn_edge_keys):
144 | '''
145 | Remove all sewn edge and combine faces.
146 | '''
147 | candidate_edges = set()
148 | for key in sewn_edge_keys:
149 |             # assumes the sewn edge is still present in self.all_edges
150 |             # after the outline-splitting step
151 |             sewn_edge = self.all_edges[key]
152 |             faces = sewn_edge.faces
153 |
154 |             # roll each face so the sewn edge sits at the front
155 |             for face in faces:
156 |                 ind = face.keys.index(key)
157 |                 face.roll(ind)
158 |             # merge all faces that share the sewn edge into the first
159 |             result_face = faces[0]
160 |             for face in faces[1:]:
161 |                 pairs = result_face.merge(face)
162 |                 if pairs:
163 |                     for pair in pairs:
164 |                         candidate_edges.add(tuple(sorted(pair)))
165 |
166 | # merge candidate edges
167 | for key1, key2 in candidate_edges:
168 | # check if there's a 4th edge connected to this vertex
169 | d1, d2 = np.array(self.all_edges[key1].dedge), np.array(self.all_edges[key2].dedge)
170 | dist = lambda t: np.sum((t[0]-t[1])**2)
171 | p1, p2 = min([(d1[0], d2[0]), (d1[-1], d2[0]), (d1[0], d2[-1]), (d1[-1], d2[-1])], key=dist)
172 | vertex = (p1+p2) / 2
173 |
174 | skip = False
175 | for key in self.all_edges:
176 | if key == key1 or key == key2 or key in sewn_edge_keys:
177 | continue
178 | e = self.all_edges[key]
179 | if dist((vertex, e.dedge[0])) < self.tol or dist((vertex, e.dedge[-1])) < self.tol:
180 | skip = True
181 | break
182 |
183 | if not skip:
184 | self.all_edges[key1].merge(self.all_edges[key2], self)
185 |
186 |
187 |
188 | def _project(self, discretize_last=False):
189 | '''
190 | Project all edges of the shape
191 | '''
192 | for edge in list(self.all_edges.values()):
193 | if not discretize_last:
194 | sharp_dedge = discretize_edge(edge.edge, self.args.tol)
195 | edge.dedge3d = project_points(sharp_dedge, self.args)
196 | edge.dedge = d3_to_d2(edge.dedge3d)
197 | continue
198 | sharp_edges_compound = project_shapes(edge.edge, self.args).VCompound()
199 | if sharp_edges_compound is None:
200 | # invalid edge - delete
201 | key = hash(edge.edge)
202 | del self.all_edges[key]
203 | # del in face
204 | for face in edge.faces:
205 | face.remove_edge(key)
206 | continue
207 |
208 | dedge = discretize_compound(sharp_edges_compound, self.tol)[0]
209 | edge.dedge = dedge
210 |
211 | # Given a list of edges (they are from the same edge but broken into pieces)
212 | # project them and return one unified edge
213 | def _raw_project(self, edges, discretize_last=False):
214 | if not discretize_last:
215 | full_2d_dedge = []
216 | for edge in edges:
217 | dedge = discretize_edge(edge, self.args.tol)
218 | full_2d_dedge += d3_to_d2(project_points(dedge, self.args))
219 | return full_2d_dedge
220 | sharp_edges_compound = project_shapes(edges, self.args).VCompound()
221 | dedge = flatten_list(discretize_compound(sharp_edges_compound, self.tol)[:len(edges)])
222 | return dedge
223 |
224 | # return x,y,z directions in camera world
225 | def get_dominant_directions(self):
226 | # project origin, x, y, z
227 | points = [(0,0,0), (1,0,0), (0,1,0), (0,0,1)]
228 | origin, x, y, z = project_points(points, self.args)
229 | origin, x, y, z = [np.array(p) for p in [origin, x, y, z]]
230 | return (x - origin).tolist(), (y - origin).tolist(), (z - origin).tolist()
--------------------------------------------------------------------------------
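
The sewn-edge rule in _find_sewn_edges above boils down to counting hash
occurrences per face; a stripped-down sketch with plain integer keys:

    from collections import Counter

    def sewn_keys(face_edge_keys):
        # an edge key appearing twice within one face marks a sewn edge
        return [k for k, c in Counter(face_edge_keys).items() if c > 1]

    assert sewn_keys([11, 22, 11, 33]) == [11]
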
/faceformer/models/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from faceformer.embedding import PositionEmbeddingLearned, VanillaEmedding
5 | from faceformer.transformer import (TransformerDecoder,
6 | TransformerDecoderLayer,
7 | TransformerEncoder,
8 | TransformerEncoderLayer)
9 | from faceformer.utils import min_value_of_dtype
10 |
11 |
12 | class SurfaceFormer(nn.Module):
13 |
14 | def __init__(self, num_model=512, num_head=8, num_feedforward=2048,
15 | num_encoder_layers=6, num_decoder_layers=6,
16 | dropout=0.1, activation="relu", normalize_before=True,
17 | num_points_per_line=50, num_lines=1000, point_dim=2,
18 | label_seq_length=2000, token=None, teacher_forcing_ratio=0, **kwargs):
19 | super(SurfaceFormer, self).__init__()
20 |
21 | self.num_model = num_model # E
22 | self.num_labels = label_seq_length # T
23 | self.teacher_forcing_ratio = teacher_forcing_ratio
24 | self.token = token
25 | self.num_token = token.len
26 |
27 | self.val_enc = VanillaEmedding(num_points_per_line * point_dim, num_model, token)
28 |
29 | # position encoding
30 | self.pos_enc = PositionEmbeddingLearned(num_model, max_len=num_lines+self.num_token)
31 | self.query_pos_enc = PositionEmbeddingLearned(num_model, max_len=label_seq_length)
32 |
33 | # vertex transformer encoder
34 | encoder_layers = TransformerEncoderLayer(
35 | num_model, num_head, num_feedforward, dropout, activation, normalize_before)
36 | encoder_norm = nn.LayerNorm(num_model) if normalize_before else None
37 | self.encoder = TransformerEncoder(encoder_layers, num_encoder_layers, encoder_norm)
38 |
39 | # wire transformer decoder
40 | decoder_layers = TransformerDecoderLayer(
41 | num_model, num_head, num_feedforward, dropout, activation, normalize_before)
42 | decoder_norm = nn.LayerNorm(num_model)
43 | self.decoder = TransformerDecoder(decoder_layers, num_decoder_layers, decoder_norm)
44 |
45 | self.project = nn.Linear(num_model, num_model)
46 |
47 | self._reset_parameters()
48 |
49 | def _reset_parameters(self):
50 | for name, param in self.named_parameters():
51 | if param.dim() > 1:
52 | nn.init.xavier_uniform_(param)
53 |
54 | def get_embeddings(self, input, label):
55 | val_embed = self.val_enc(input)
56 | pos_embed = self.pos_enc(val_embed)
57 | query_pos_embed = self.query_pos_enc(label)
58 |
59 | return val_embed, pos_embed, query_pos_embed
60 |
61 | def process_masks(self, input_mask, tgt_mask=None):
62 | # pad input mask
63 | padding_mask = torch.zeros((len(input_mask), self.num_token), device=input_mask.device).type_as(input_mask)
64 | input_mask = torch.cat([padding_mask, input_mask], dim=1)
65 | if tgt_mask is None:
66 | return input_mask
67 |         # the target sequence is shifted by one token, so its mask is one shorter
68 | tgt_mask = tgt_mask[:, :-1].contiguous()
69 | return input_mask, tgt_mask
70 |
71 | def generate_square_subsequent_mask(self, sz):
72 | mask = (1 - torch.tril(torch.ones(sz, sz))) == 1
73 | return mask
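   |         # A sketch of the causal mask for sz == 3 (True = position is blocked):
   |         #   [[False,  True,  True],
   |         #    [False, False,  True],
   |         #    [False, False, False]]
   |         # so decoding step t can only attend to steps <= t.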
74 |
75 | def patch_source(self, src, pos):
76 | src = src.transpose(0, 1)
77 | pos = pos.transpose(0, 1)
78 | return src, pos
79 |
80 | def patch_target(self, tgt, pos):
81 | tgt = tgt.transpose(0, 1)
82 | tgt, label = tgt[:-1].contiguous(), tgt[1:].contiguous()
83 | pos = pos.transpose(0, 1)
84 | pos = pos[:-1].contiguous()
85 | return tgt, label, pos
86 |
87 | def mix_gold_sampled(self, gold_target, sampled_target, prob):
88 | sampled_target = torch.cat((gold_target[0:1], sampled_target[:-1]), dim=0)
89 |
90 | targets = torch.stack((gold_target, sampled_target))
91 |
92 | random = torch.rand(gold_target.shape, device=targets.device)
93 | index = (random < prob).long().unsqueeze(0)
94 |
95 | new_target = torch.gather(targets, 0, index)
96 | return new_target.squeeze(0)
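   |         # Scheduled sampling: sampled_target is shifted right by one (the gold
   |         # start token is kept), then each position independently takes the
   |         # model's own prediction with probability `prob` and the gold token
   |         # otherwise.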
97 |
98 | def forward_train(self, inputs, scheduled_sampling_ratio=0):
99 | # inputs: N x L x P x D
100 | input, input_mask = inputs['input'], inputs['input_mask']
101 | label, label_mask = inputs['label'], inputs['label_mask']
102 |
103 | # process masks
104 | input_mask, label_mask = self.process_masks(input_mask, label_mask)
105 |
106 | # embeddings: N x L x E, L+=4
107 | val_embed, pos_embed, query_pos_embed = self.get_embeddings(input, label)
108 |
109 | # prepare data: L x N x E
110 | source, pos_embed = self.patch_source(val_embed, pos_embed)
111 | target, label, query_pos_embed = self.patch_target(label, query_pos_embed)
112 |
113 | # encoder: L x N x E
114 | memory = self.encoder(source, src_key_padding_mask=input_mask, pos=pos_embed)
115 |
116 | # feature gather: T x N x E
117 | tgt_mask = self.generate_square_subsequent_mask(target.size(0)).to(target.device)
118 |
119 | # T x N
120 | gold_target = target
121 |
122 | if scheduled_sampling_ratio > 0:
123 | with torch.no_grad():
124 | target = target.unsqueeze(-1).repeat(1, 1, self.num_model)
125 |
126 | tgt = torch.gather(memory, 0, target)
127 |
128 | # decoder: T x N x E
129 | pointer = self.decoder(tgt, memory, tgt_mask=tgt_mask, pos=pos_embed, query_pos=query_pos_embed,
130 | tgt_key_padding_mask=label_mask, memory_key_padding_mask=input_mask)
131 |
132 | pointer = self.project(pointer)
133 |
134 | logits = torch.bmm(memory.transpose(0, 1), pointer.permute(1, 2, 0))
135 |
136 | logits = logits.masked_fill(input_mask.unsqueeze(-1), min_value_of_dtype(logits.dtype))
137 |
138 | sampled_target = torch.argmax(logits, dim=1).transpose(0, 1)
139 |
140 | target = self.mix_gold_sampled(gold_target, sampled_target, scheduled_sampling_ratio)
141 |
142 |
143 | target = target.unsqueeze(-1).repeat(1, 1, self.num_model)
144 | # memory: L x N x E
145 | # target: T x N x E
146 | # tgt: T x N x E
147 | tgt = torch.gather(memory, 0, target) # selects the targeting edge features
148 |
149 | # decoder: T x N x E
150 | pointer = self.decoder(tgt, memory, tgt_mask=tgt_mask, pos=pos_embed, query_pos=query_pos_embed,
151 | tgt_key_padding_mask=label_mask, memory_key_padding_mask=input_mask)
152 |
153 | pointer = self.project(pointer)
154 |
155 | # outputs
156 | inputs['embedding'] = memory.transpose(0, 1)
157 | inputs['pointer'] = pointer.transpose(0, 1)
158 | inputs['label'] = label.transpose(0, 1)
159 | return inputs
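   |         # Pointer scoring, for reference: logits[n, l, t] = <memory[l, n], pointer[t, n]>,
   |         # so each decoding step t points at one of the L encoded input lines;
   |         # padded lines are filled with the dtype minimum before argmax/cross-entropy.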
160 |
161 | def select_next(self, embedding, pointer, input_mask):
162 | embedding = embedding.transpose(0, 1)
163 | pointer = pointer.permute(1, 2, 0)
164 | logit = torch.bmm(embedding, pointer[..., -1:])
165 | logit = logit.masked_fill(input_mask.unsqueeze(-1), min_value_of_dtype(logit.dtype))
166 | next_token = torch.argmax(logit, dim=1).transpose(0, 1)
167 | return next_token
168 |
169 | def forward_eval(self, inputs):
170 | # inputs
171 | input, input_mask = inputs['input'], inputs['input_mask']
172 | label = inputs['label']
173 |
174 | batch_size = input.size(0)
175 |
176 | # process masks
177 | input_mask = self.process_masks(input_mask)
178 |
179 | # vertex embedding: N x L x E, L+=4
180 | val_embed, pos_embed, query_pos_embed = self.get_embeddings(input, label)
181 |
182 | # prepare data: L x N x E
183 | source, pos_embed = self.patch_source(val_embed, pos_embed)
184 | query_pos_embed = query_pos_embed.transpose(0, 1)
185 |
186 | # encoder: L x N x E
187 | memory = self.encoder(source, src_key_padding_mask=input_mask, pos=pos_embed)
188 |
189 | # T x N
190 | predicts = torch.full((1, batch_size), self.token.SOS, dtype=torch.long).to(source.device)
191 | EOS_found = 0
192 |
193 | for step in range(self.num_labels - 1):
194 | target = predicts.unsqueeze(-1).repeat(1, 1, self.num_model)
195 |
196 | tgt = torch.gather(memory, 0, target)
197 |
198 | # decoder: T x N x E
199 | pointer = self.decoder(tgt, memory, memory_key_padding_mask=input_mask,
200 | pos=pos_embed, query_pos=query_pos_embed[:step+1])
201 |
202 | pointer = self.project(pointer)
203 |
204 | next_token = self.select_next(memory, pointer, input_mask)
205 |
206 | predicts = torch.cat((predicts, next_token), dim=0)
207 | EOS_found += next_token.eq(self.token.EOS).sum().item()
208 |
209 | if EOS_found == batch_size:
210 | break
211 |
212 | # pad predict to self.num_labels
213 | predicts = torch.cat((predicts, torch.zeros(self.num_labels - predicts.size(0), predicts.size(1)).type_as(predicts)), dim=0)
214 |
215 |
216 | inputs['embedding'] = memory.transpose(0, 1)
217 | inputs['pointer'] = pointer.transpose(0, 1)
218 | inputs['predict'] = predicts.transpose(0, 1)
219 | return inputs
220 |
221 | def forward(self, inputs):
222 | """
223 | inputs:
224 | input (N x L x P x D),
225 | label (N x T),
226 | input_mask (N x L),
227 | label_mask (N x T)
228 | outputs:
229 | embedding (N x L x E),
230 | pointer (N x T x E),
231 | label (N x T)
232 | """
233 | if self.training:
234 | outputs = self.forward_train(inputs)
235 | else:
236 | outputs = self.forward_eval(inputs)
237 | return outputs
238 |
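   | # Minimal usage sketch (hypothetical tensors and token object, for illustration):
   | #   model = SurfaceFormer(num_lines=110, label_seq_length=259, token=token)
   | #   batch = {'input': lines, 'input_mask': line_pad_mask,
   | #            'label': label_seq, 'label_mask': label_pad_mask}
   | #   outputs = model(batch)  # training fills 'pointer'/'label' for the loss;
   | #                           # eval greedily decodes into outputs['predict']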
--------------------------------------------------------------------------------
/faceformer/models/model_para.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from faceformer.embedding import PositionEmbeddingLearned, VanillaEmedding
5 | from faceformer.transformer import (TransformerDecoder,
6 | TransformerDecoderLayer,
7 | TransformerEncoder,
8 | TransformerEncoderLayer)
9 | from faceformer.utils import min_value_of_dtype
10 |
11 |
12 | class SurfaceFormer_Parallel(nn.Module):
13 |
14 | def __init__(self, num_model=512, num_head=8, num_feedforward=2048,
15 | num_encoder_layers=6, num_decoder_layers=6,
16 | dropout=0.1, activation="relu", normalize_before=True,
17 | num_points_per_line=50, num_lines=64, point_dim=2,
18 | max_face_length=10, token=None,
19 | teacher_forcing_ratio=0, **kwargs):
20 | super(SurfaceFormer_Parallel, self).__init__()
21 |
22 | self.num_model = num_model # E
23 | self.max_face_length = max_face_length # T
24 | self.teacher_forcing_ratio = teacher_forcing_ratio
25 | self.token = token
26 | self.num_token = token.len
27 |
28 | self.val_enc = VanillaEmedding(num_points_per_line * point_dim, num_model, token)
29 |
30 | # position encoding
31 | self.pos_enc = PositionEmbeddingLearned(num_model, max_len=num_lines+self.num_token)
32 | self.query_pos_enc = PositionEmbeddingLearned(num_model, max_len=max_face_length)
33 |
34 | # vertex transformer encoder
35 | encoder_layers = TransformerEncoderLayer(
36 | num_model, num_head, num_feedforward, dropout, activation, normalize_before)
37 | encoder_norm = nn.LayerNorm(num_model) if normalize_before else None
38 | self.encoder = TransformerEncoder(encoder_layers, num_encoder_layers, encoder_norm)
39 |
40 | # wire transformer decoder
41 | decoder_layers = TransformerDecoderLayer(
42 | num_model, num_head, num_feedforward, dropout, activation, normalize_before)
43 | decoder_norm = nn.LayerNorm(num_model)
44 | self.decoder = TransformerDecoder(decoder_layers, num_decoder_layers, decoder_norm)
45 |
46 | self.project = nn.Linear(num_model, num_model)
47 |
48 | self._reset_parameters()
49 |
50 | def _reset_parameters(self):
51 | for _, param in self.named_parameters():
52 | if param.dim() > 1:
53 | nn.init.xavier_uniform_(param)
54 |
55 | def get_embeddings(self, input, label):
56 | val_embed = self.val_enc(input)
57 | pos_embed = self.pos_enc(val_embed)
58 | query_pos_embed = self.query_pos_enc(label.transpose(1, 2))
59 |
60 | return val_embed, pos_embed, query_pos_embed
61 |
62 | def process_masks(self, input_mask, tgt_mask=None):
63 | # pad input mask
64 | padding_mask = torch.zeros((len(input_mask), self.num_token)).type_as(input_mask)
65 | input_mask = torch.cat([padding_mask, input_mask], dim=1)
66 | if tgt_mask is None:
67 | return input_mask
68 |         # the target sequence is shifted by one token, so its mask is one shorter
69 | tgt_mask = tgt_mask[..., :-1].contiguous()
70 | return input_mask, tgt_mask
71 |
72 | def generate_square_subsequent_mask(self, sz):
73 | mask = (1 - torch.tril(torch.ones(sz, sz))) == 1
74 | return mask
75 |
76 | def patch_source(self, src, pos):
77 | src = src.transpose(0, 1)
78 | pos = pos.transpose(0, 1)
79 | return src, pos
80 |
81 | def patch_target(self, tgt, pos):
82 | tgt = tgt.permute(2, 0, 1)
83 | tgt, label = tgt[:-1].contiguous(), tgt[1:].contiguous()
84 | pos = pos.transpose(0, 1)
85 | pos = pos[:-1].contiguous()
86 | return tgt, label, pos
87 |
88 | def mix_gold_sampled(self, gold_target, sampled_target, prob):
89 | sampled_target = torch.cat((gold_target[0:1], sampled_target[:-1]), dim=0)
90 |
91 | targets = torch.stack((gold_target, sampled_target))
92 |
93 | random = torch.rand(gold_target.shape, device=targets.device)
94 | index = (random < prob).long().unsqueeze(0)
95 |
96 | new_target = torch.gather(targets, 0, index)
97 | return new_target.squeeze(0)
98 |
99 | def forward_train(self, inputs, scheduled_sampling_ratio=0):
100 | # inputs: N x L x P x D
101 | # target: N x F x T
102 | input, input_mask = inputs['input'], inputs['input_mask']
103 | label, label_mask = inputs['label'], inputs['label_mask']
104 | max_num_edges = max(inputs['num_input'])
105 | label, label_mask = label[:, :max_num_edges, :], label_mask[:, :max_num_edges, :]
106 |
107 | # process masks
108 | input_mask, label_mask = self.process_masks(input_mask, label_mask)
109 |
110 | # embeddings: N x L x E, L+=4
111 | val_embed, pos_embed, query_pos_embed = self.get_embeddings(input, label)
112 |
113 | # prepare data: L x N x E
114 | source, pos_embed = self.patch_source(val_embed, pos_embed)
115 | # target: T x N x (F)
116 | target, label, query_pos_embed = self.patch_target(label, query_pos_embed)
117 |
118 | # encoder: L x N x E
119 | memory = self.encoder(source, src_key_padding_mask=input_mask, pos=pos_embed)
120 |
121 | # L x NF x E
122 | memory = memory.repeat_interleave(max_num_edges, 1)
123 |
124 | # feature gather: T x N x E
125 | tgt_mask = self.generate_square_subsequent_mask(target.size(0)).type_as(input_mask)
126 |
127 | # T x N
128 | gold_target = target
129 |
130 | if scheduled_sampling_ratio > 0:
131 | with torch.no_grad():
132 | target = target.unsqueeze(-1).repeat(1, 1, self.num_model)
133 |
134 | tgt = torch.gather(memory, 0, target)
135 |
136 | # decoder: T x N x E
137 | pointer = self.decoder(tgt, memory, tgt_mask=tgt_mask, pos=pos_embed, query_pos=query_pos_embed,
138 | tgt_key_padding_mask=label_mask, memory_key_padding_mask=input_mask)
139 |
140 | pointer = self.project(pointer)
141 |
142 | logits = torch.bmm(memory.transpose(0, 1), pointer.permute(1, 2, 0))
143 |
144 | logits = logits.masked_fill(input_mask.unsqueeze(-1), min_value_of_dtype(logits.dtype))
145 |
146 | sampled_target = torch.argmax(logits, dim=1).transpose(0, 1)
147 |
148 | target = self.mix_gold_sampled(gold_target, sampled_target, scheduled_sampling_ratio)
149 |
150 | # T x NF x E
151 | target = target.unsqueeze(-1).repeat(1, 1, 1, self.num_model).flatten(1, 2)
152 | # memory: L x NF x E
153 | # target: T x NF x E
154 | # tgt: T x NF x E
155 | tgt = torch.gather(memory, 0, target) # selects the targeting edge features
156 | # NF x T
157 | label_mask = label_mask.flatten(0, 1)
158 | # NF x L
159 | input_mask = input_mask.repeat_interleave(max_num_edges, 0)
160 |
161 | # decoder: T x NF x E
162 | pointer = self.decoder(tgt, memory, tgt_mask=tgt_mask, pos=pos_embed, query_pos=query_pos_embed,
163 | tgt_key_padding_mask=label_mask, memory_key_padding_mask=input_mask)
164 |
165 | pointer = self.project(pointer)
166 |
167 | # outputs
168 | inputs['embedding'] = memory.transpose(0, 1)
169 | inputs['pointer'] = pointer.transpose(0, 1)
170 | inputs['label'] = label.flatten(1, 2).transpose(0, 1)
171 | return inputs
172 |
173 | def select_next(self, embedding, pointer, input_mask):
174 | embedding = embedding.transpose(0, 1)
175 | pointer = pointer.permute(1, 2, 0)
176 | logit = torch.bmm(embedding, pointer[..., -1:])
177 | logit = logit.masked_fill(input_mask.unsqueeze(-1), min_value_of_dtype(logit.dtype))
178 | next_token = torch.argmax(logit, dim=1).transpose(0, 1)
179 | return next_token
180 |
181 | def forward_eval(self, inputs):
182 | # inputs: N x L x P x D
183 | input, input_mask = inputs['input'], inputs['input_mask']
184 | label = inputs['label']
185 |
186 | batch_size = input.size(0) # N
187 | max_num_edges = max(inputs['num_input']) # L
188 |
189 |
190 | # process masks
191 | input_mask = self.process_masks(input_mask)
192 |
193 | # vertex embedding: N x L x E, L+=4
194 | val_embed, pos_embed, query_pos_embed = self.get_embeddings(input, label)
195 |
196 | # prepare data: L x N x E
197 | source, pos_embed = self.patch_source(val_embed, pos_embed)
198 |
199 | # use all edges as the first token
200 | # anchors: 1 x N x (F)
201 | anchors = torch.arange(max_num_edges).repeat(1, batch_size, 1).type_as(label)
202 |
203 |         # mark the start token of unused face slots as EOS (the "Other" face type)
204 | for i, num_edges in enumerate(inputs['num_input']):
205 | anchors[:, i, num_edges:] = self.token.len - 1
206 | query_pos_embed = query_pos_embed.transpose(0, 1)
207 | predicts = anchors.flatten(1, 2) # 1 x N(F)
208 |
209 | # encoder: L x N x E
210 | memory = self.encoder(source, src_key_padding_mask=input_mask, pos=pos_embed)
211 | # L x N(F) x E
212 | memory = memory.repeat_interleave(max_num_edges, 1)
213 | # N(F) x L
214 | input_mask = input_mask.repeat_interleave(max_num_edges, 0)
215 |
216 | for step in range(self.max_face_length - 1):
217 | target = predicts.unsqueeze(-1).repeat(1, 1, self.num_model)
218 |
219 | tgt = torch.gather(memory, 0, target)
220 |
221 | # decoder: T x N(F) x E
222 | pointer = self.decoder(tgt, memory, memory_key_padding_mask=input_mask,
223 | pos=pos_embed, query_pos=query_pos_embed[:step+1])
224 |
225 | pointer = self.project(pointer)
226 |
227 | next_token = self.select_next(memory, pointer, input_mask)
228 |
229 | predicts = torch.cat((predicts, next_token), dim=0)
230 |
231 | # if all face tokens are EOS or PAD, stop decoding
232 | if torch.all(next_token < self.num_token):
233 | break
234 |
235 | # pad predict to self.max_face_length
236 | predicts = torch.cat((predicts, torch.zeros(self.max_face_length - predicts.size(0), predicts.size(1)).type_as(predicts)), dim=0)
237 |
238 | #predicts: T x N(F)
239 |
240 | inputs['predict'] = predicts.transpose(0, 1).view(-1, max_num_edges, self.max_face_length)
241 | return inputs
242 |
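   |     # Note: unlike SurfaceFormer above, forward_eval decodes one sequence per
   |     # anchor edge: memory is repeated F times, so N*F face sequences are
   |     # decoded in parallel, each starting from its own edge token.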
243 | def forward(self, inputs):
244 | """
245 | inputs:
246 | input (N x L x P x D),
247 | label (N x F x T),
248 | input_mask (N x L),
249 | label_mask (N x F x T)
250 | outputs:
251 | embedding (NF x L x E),
252 | pointer (NF x T x E),
253 | label (N x F x T)
254 | """
255 | if self.training:
256 | outputs = self.forward_train(inputs)
257 | else:
258 | outputs = self.forward_eval(inputs)
259 | return outputs
260 |
--------------------------------------------------------------------------------
/faceformer/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | DETR Transformer class.
4 |
5 | Copy-paste from torch.nn.Transformer with modifications:
6 | * positional encodings are passed in MHattention
7 | * extra LN at the end of encoder is removed
8 | * decoder returns a stack of activations from all decoding layers
9 | """
10 | import copy
11 | from typing import List, Optional
12 |
13 | import torch
14 | import torch.nn.functional as F
15 | from torch import Tensor, nn
16 |
17 |
18 | class Transformer(nn.Module):
19 |
20 | def __init__(self, num_model=512, num_head=8, num_encoder_layers=6,
21 | num_decoder_layers=6, num_feedforward=2048, dropout=0.1,
22 | activation="relu", normalize_before=False,
23 | return_intermediate_dec=False):
24 | super().__init__()
25 |
26 | encoder_layer = TransformerEncoderLayer(num_model, num_head, num_feedforward,
27 | dropout, activation, normalize_before)
28 | encoder_norm = nn.LayerNorm(num_model) if normalize_before else None
29 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
30 |
31 | decoder_layer = TransformerDecoderLayer(num_model, num_head, num_feedforward,
32 | dropout, activation, normalize_before)
33 | decoder_norm = nn.LayerNorm(num_model)
34 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
35 | return_intermediate=return_intermediate_dec)
36 |
37 | self._reset_parameters()
38 |
39 | self.num_model = num_model
40 | self.num_head = num_head
41 |
42 | def _reset_parameters(self):
43 | for p in self.parameters():
44 | if p.dim() > 1:
45 | nn.init.xavier_uniform_(p)
46 |
47 | def forward(self, src, mask, query_embed, pos_embed):
48 | # flatten NxCxHxW to HWxNxC
49 | bs, c, h, w = src.shape
50 | src = src.flatten(2).permute(2, 0, 1)
51 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
52 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
53 | mask = mask.flatten(1)
54 |
55 | tgt = torch.zeros_like(query_embed)
56 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
57 | hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
58 | pos=pos_embed, query_pos=query_embed)
59 | return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
60 |
61 |
62 | class TransformerEncoder(nn.Module):
63 |
64 | def __init__(self, encoder_layer, num_layers, norm=None):
65 | super().__init__()
66 | self.layers = _get_clones(encoder_layer, num_layers)
67 | self.num_layers = num_layers
68 | self.norm = norm
69 |
70 | def forward(self, src,
71 | mask: Optional[Tensor] = None,
72 | src_key_padding_mask: Optional[Tensor] = None,
73 | pos: Optional[Tensor] = None):
74 | output = src
75 |
76 | for layer in self.layers:
77 | output = layer(output, src_mask=mask,
78 | src_key_padding_mask=src_key_padding_mask, pos=pos)
79 |
80 | if self.norm is not None:
81 | output = self.norm(output)
82 |
83 | return output
84 |
85 |
86 | class TransformerDecoder(nn.Module):
87 |
88 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
89 | super().__init__()
90 | self.layers = _get_clones(decoder_layer, num_layers)
91 | self.num_layers = num_layers
92 | self.norm = norm
93 | self.return_intermediate = return_intermediate
94 |
95 | def forward(self, tgt, memory,
96 | tgt_mask: Optional[Tensor] = None,
97 | memory_mask: Optional[Tensor] = None,
98 | tgt_key_padding_mask: Optional[Tensor] = None,
99 | memory_key_padding_mask: Optional[Tensor] = None,
100 | pos: Optional[Tensor] = None,
101 | query_pos: Optional[Tensor] = None):
102 | output = tgt
103 |
104 | intermediate = []
105 |
106 | for layer in self.layers:
107 | output = layer(output, memory, tgt_mask=tgt_mask,
108 | memory_mask=memory_mask,
109 | tgt_key_padding_mask=tgt_key_padding_mask,
110 | memory_key_padding_mask=memory_key_padding_mask,
111 | pos=pos, query_pos=query_pos)
112 | if self.return_intermediate:
113 | intermediate.append(self.norm(output))
114 |
115 | if self.norm is not None:
116 | output = self.norm(output)
117 | if self.return_intermediate:
118 | intermediate.pop()
119 | intermediate.append(output)
120 |
121 | if self.return_intermediate:
122 | return torch.stack(intermediate)
123 |
124 | return output
125 |
126 |
127 | class TransformerEncoderLayer(nn.Module):
128 |
129 | def __init__(self, num_model, num_head, num_feedforward=2048, dropout=0.1,
130 | activation="relu", normalize_before=False):
131 | super().__init__()
132 | self.self_attn = nn.MultiheadAttention(num_model, num_head, dropout=dropout)
133 | # Implementation of Feedforward model
134 | self.linear1 = nn.Linear(num_model, num_feedforward)
135 | self.dropout = nn.Dropout(dropout)
136 | self.linear2 = nn.Linear(num_feedforward, num_model)
137 |
138 | self.norm1 = nn.LayerNorm(num_model)
139 | self.norm2 = nn.LayerNorm(num_model)
140 | self.dropout1 = nn.Dropout(dropout)
141 | self.dropout2 = nn.Dropout(dropout)
142 |
143 | self.activation = _get_activation_fn(activation)
144 | self.normalize_before = normalize_before
145 |
146 | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
147 | return tensor if pos is None else tensor + pos
148 |
149 | def forward_post(self,
150 | src,
151 | src_mask: Optional[Tensor] = None,
152 | src_key_padding_mask: Optional[Tensor] = None,
153 | pos: Optional[Tensor] = None):
154 | q = k = self.with_pos_embed(src, pos)
155 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
156 | key_padding_mask=src_key_padding_mask)[0]
157 | src = src + self.dropout1(src2)
158 | src = self.norm1(src)
159 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
160 | src = src + self.dropout2(src2)
161 | src = self.norm2(src)
162 | return src
163 |
164 | def forward_pre(self, src,
165 | src_mask: Optional[Tensor] = None,
166 | src_key_padding_mask: Optional[Tensor] = None,
167 | pos: Optional[Tensor] = None):
168 | src2 = self.norm1(src)
169 | q = k = self.with_pos_embed(src2, pos)
170 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
171 | key_padding_mask=src_key_padding_mask)[0]
172 | src = src + self.dropout1(src2)
173 | src2 = self.norm2(src)
174 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
175 | src = src + self.dropout2(src2)
176 | return src
177 |
178 | def forward(self, src,
179 | src_mask: Optional[Tensor] = None,
180 | src_key_padding_mask: Optional[Tensor] = None,
181 | pos: Optional[Tensor] = None):
182 | if self.normalize_before:
183 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
184 | return self.forward_post(src, src_mask, src_key_padding_mask, pos)
185 |
186 |
187 | class TransformerDecoderLayer(nn.Module):
188 |
189 | def __init__(self, num_model, num_head, num_feedforward=2048, dropout=0.1,
190 | activation="relu", normalize_before=False):
191 | super().__init__()
192 | self.self_attn = nn.MultiheadAttention(num_model, num_head, dropout=dropout)
193 | self.multihead_attn = nn.MultiheadAttention(num_model, num_head, dropout=dropout)
194 | # Implementation of Feedforward model
195 | self.linear1 = nn.Linear(num_model, num_feedforward)
196 | self.dropout = nn.Dropout(dropout)
197 | self.linear2 = nn.Linear(num_feedforward, num_model)
198 |
199 | self.norm1 = nn.LayerNorm(num_model)
200 | self.norm2 = nn.LayerNorm(num_model)
201 | self.norm3 = nn.LayerNorm(num_model)
202 | self.dropout1 = nn.Dropout(dropout)
203 | self.dropout2 = nn.Dropout(dropout)
204 | self.dropout3 = nn.Dropout(dropout)
205 |
206 | self.activation = _get_activation_fn(activation)
207 | self.normalize_before = normalize_before
208 |
209 | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
210 | return tensor if pos is None else tensor + pos
211 |
212 | def forward_post(self, tgt, memory,
213 | tgt_mask: Optional[Tensor] = None,
214 | memory_mask: Optional[Tensor] = None,
215 | tgt_key_padding_mask: Optional[Tensor] = None,
216 | memory_key_padding_mask: Optional[Tensor] = None,
217 | pos: Optional[Tensor] = None,
218 | query_pos: Optional[Tensor] = None):
219 | q = k = self.with_pos_embed(tgt, query_pos)
220 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
221 | key_padding_mask=tgt_key_padding_mask)[0]
222 | tgt = tgt + self.dropout1(tgt2)
223 | tgt = self.norm1(tgt)
224 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
225 | key=self.with_pos_embed(memory, pos),
226 | value=memory, attn_mask=memory_mask,
227 | key_padding_mask=memory_key_padding_mask)[0]
228 | tgt = tgt + self.dropout2(tgt2)
229 | tgt = self.norm2(tgt)
230 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
231 | tgt = tgt + self.dropout3(tgt2)
232 | tgt = self.norm3(tgt)
233 | return tgt
234 |
235 | def forward_pre(self, tgt, memory,
236 | tgt_mask: Optional[Tensor] = None,
237 | memory_mask: Optional[Tensor] = None,
238 | tgt_key_padding_mask: Optional[Tensor] = None,
239 | memory_key_padding_mask: Optional[Tensor] = None,
240 | pos: Optional[Tensor] = None,
241 | query_pos: Optional[Tensor] = None):
242 | tgt2 = self.norm1(tgt)
243 | q = k = self.with_pos_embed(tgt2, query_pos)
244 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
245 | key_padding_mask=tgt_key_padding_mask)[0]
246 | tgt = tgt + self.dropout1(tgt2)
247 | tgt2 = self.norm2(tgt)
248 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
249 | key=self.with_pos_embed(memory, pos),
250 | value=memory, attn_mask=memory_mask,
251 | key_padding_mask=memory_key_padding_mask)[0]
252 | tgt = tgt + self.dropout2(tgt2)
253 | tgt2 = self.norm3(tgt)
254 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
255 | tgt = tgt + self.dropout3(tgt2)
256 | return tgt
257 |
258 | def forward(self, tgt, memory,
259 | tgt_mask: Optional[Tensor] = None,
260 | memory_mask: Optional[Tensor] = None,
261 | tgt_key_padding_mask: Optional[Tensor] = None,
262 | memory_key_padding_mask: Optional[Tensor] = None,
263 | pos: Optional[Tensor] = None,
264 | query_pos: Optional[Tensor] = None):
265 | if self.normalize_before:
266 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
267 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
268 | return self.forward_post(tgt, memory, tgt_mask, memory_mask,
269 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
270 |
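   | # Note: normalize_before=True selects the pre-norm variant (LayerNorm applied
   | # before attention/FFN), which tends to train more stably for deep stacks;
   | # the SurfaceFormer models in this repo default to normalize_before=True.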
271 |
272 | def _get_clones(module, N):
273 |     return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
274 |
275 |
276 | def _get_activation_fn(activation):
277 | """Return an activation function given a string"""
278 | if activation == "relu":
279 | return F.relu
280 | if activation == "gelu":
281 | return F.gelu
282 | if activation == "glu":
283 | return F.glu
284 |     raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
285 |
--------------------------------------------------------------------------------
/faceformer/trainer.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import time
4 | from collections import Counter
5 |
6 | import numpy as np
7 | import pytorch_lightning as pl
8 | import torch
9 | import torch.nn.functional as F
10 | from numpyencoder import NumpyEncoder
11 |
12 | from faceformer.post_processing import (filter_faces_by_encloseness, map_coedge_into_edges)
13 | from faceformer.utils import flatten_list
14 |
15 |
16 | class Trainer(pl.LightningModule):
17 | def __init__(self, hparams, model_class, dataset_class):
18 | super().__init__()
19 | self.save_hyperparameters(hparams)
20 | self.model = model_class(**self.hparams.model)
21 | self.dataset_class = dataset_class
22 | self.validation_num = 0
23 | self.time_count = 0
24 | self.total_time = 0
25 | # self.hparams.root_dir = "ours"
26 |
27 | def forward(self, batch):
28 | return self.model(batch)
29 |
30 | def train_dataloader(self):
31 | dataset = self.dataset_class(self.hparams.root_dir, self.hparams.datasets_train, self.hparams.model)
32 | dataloader = torch.utils.data.DataLoader(
33 | dataset, batch_size=self.hparams.batch_size_train, num_workers=4,
34 | shuffle=True, drop_last=True)
35 | return dataloader
36 |
37 | def val_dataloader(self):
38 | name_dir = os.path.join(self.logger.log_dir, self.hparams.trainer.version)
39 | if not os.path.exists(name_dir):
40 | os.mkdir(name_dir)
41 | dataset = self.dataset_class(self.hparams.root_dir, self.hparams.datasets_valid, self.hparams.model)
42 | self.dataset = dataset
43 | dataloader = torch.utils.data.DataLoader(
44 | dataset, batch_size=self.hparams.batch_size_valid, num_workers=4,
45 | shuffle=False, drop_last=False)
46 | return dataloader
47 |
48 | def test_dataloader(self):
49 | dataset = self.dataset_class(self.hparams.root_dir, self.hparams.datasets_test, self.hparams.model)
50 | self.dataset = dataset
51 | self.hparams.batch_size_valid = 1
52 | dataloader = torch.utils.data.DataLoader(
53 | dataset, batch_size=self.hparams.batch_size_valid, num_workers=4,
54 | shuffle=False, drop_last=False)
55 | json_dir = os.path.join(self.logger.log_dir, 'json')
56 | if not os.path.exists(json_dir):
57 | os.mkdir(json_dir)
58 | return dataloader
59 |
60 | def compute_loss(self, outputs):
61 | embedding, pointer, labels = outputs['embedding'], outputs['pointer'], outputs['label']
62 |
63 | # embedding N x L x E, pointer N x T x E
64 | # logits: N x L x T
65 | logits = torch.bmm(embedding, pointer.transpose(1, 2))
66 |
67 | #label: N x T
68 | labels = labels.detach().clone()
69 | loss = F.cross_entropy(
70 | logits, labels, ignore_index=self.hparams.model.token.PAD, reduction='sum')
71 |
72 | valid = labels != self.hparams.model.token.PAD
73 | valid_sum = valid.sum()
74 | pred = torch.argmax(logits, dim=1)
75 | outputs['predict'] = pred
76 | acc_sum = (valid * (pred == labels)).sum()
77 | cls_acc = float(acc_sum) / (valid_sum + 1e-10)
78 |
79 | loss = loss / valid_sum
80 | return loss, cls_acc
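   |         # A tiny worked example (hypothetical numbers): with L = 3 candidate
   |         # lines and one decoding step, logits[n, :, 0] might be [2.1, -0.3, 0.7];
   |         # the cross-entropy target is the index of the correct input line
   |         # (offset by the special tokens), not a vocabulary id.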
81 |
82 | def training_step(self, batch, batch_idx):
83 | outputs = self.forward(batch)
84 | loss, acc = self.compute_loss(outputs)
85 | self.log('train_loss', loss, logger=True)
86 | self.log('train_cls_acc', acc, prog_bar=True, logger=True)
87 | if torch.isnan(loss):
88 | return None
89 | return loss
90 |
91 | def validation_step(self, batch, batch_idx):
92 | outputs = self.forward(batch)
93 | acc, outputs = self.face_accuracy(outputs)
94 |
95 | self.log('valid_accuracy', np.mean(outputs['accuracy']), logger=True)
96 | self.log('valid_type_acc_coedge_seq', np.mean(outputs['type_acc_coedge_seq']), logger=True)
97 | self.log('valid_precision', np.mean(outputs['precisions']), logger=True)
98 | self.log('valid_recall', np.mean(outputs['recalls']), logger=True)
99 | self.log('type_acc', np.mean(outputs['type_acc']), logger=True)
100 | for pred, label, prec in zip(outputs['predictions'], outputs['labels'], outputs['precisions']):
101 | self.logger.experiment.add_text('result', f'pred: {pred} \n\n label: {label} \n\n precision: {prec}', self.validation_num)
102 | # return outputs
103 |
104 | # saves all necessary info for reconstruction
105 | def test_step(self, batch, batch_idx):
106 | torch.cuda.synchronize()
107 | a = time.time()
108 | outputs = self.forward(batch)
109 | torch.cuda.synchronize()
110 | self.total_time += time.time() - a
111 | self.time_count += 1
112 | print("Avg Time", self.total_time / self.time_count, "seconds.")
113 | acc, outputs = self.face_accuracy(outputs)
114 |
115 | self.log('test_precision', np.mean(outputs['precisions']), logger=True)
116 | self.log('test_recall', np.mean(outputs['recalls']), logger=True)
117 | self.log('test_type_acc', np.mean(outputs['type_acc']), logger=True)
118 | for ind in range(len(outputs['predictions'])):
119 | predict_faces_w_types = outputs['predictions'][ind]
120 | label_faces_w_types = outputs['labels'][ind]
121 | json_name = batch['name'][ind]
122 | name = json_name[5:13]
123 |
124 | with open(os.path.join(self.hparams.root_dir, json_name), "r") as f:
125 | raw_data = json.loads(f.read())
126 | edges = raw_data['edges']
127 |
128 | predicted_data = {}
129 | predicted_data['edges'] = edges
130 | predicted_data['dominant_directions'] = raw_data['dominant_directions']
131 | predicted_data['pred_faces'] = predict_faces_w_types
132 | predicted_data['label_faces'] = label_faces_w_types
133 |
134 |
135 | with open(os.path.join(self.logger.log_dir, 'json', f'{name}.json'), 'w') as f:
136 | json.dump(predicted_data, f, cls=NumpyEncoder)
137 |
138 | def validation_epoch_end(self, outputs):
139 | self.validation_num += 1
140 |
141 | def configure_optimizers(self):
142 | optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.trainer.lr)
143 | if self.hparams.trainer.lr_step == 0:
144 | return optimizer
145 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, self.hparams.trainer.lr_step)
146 | return {
147 | 'optimizer': optimizer,
148 | 'lr_scheduler': scheduler
149 | }
150 |
151 | # Parse single-sequence faces
152 | # Return (0, faces' indices) to be consistent with face_typing task
153 | def parse_faces(self, predicts, labels, num_edges):
154 | # cut off tokens after the [EOS]
155 | label = np.split(labels, np.where(labels == self.hparams.model.token.EOS)[0]+1)[0]
156 | predict = np.split(predicts, np.where(predicts == self.hparams.model.token.EOS)[0] + 1)[0]
157 |
158 | label_faces = np.split(label, np.where(label == self.hparams.model.token.SEP)[0]+1) # split by SEP
159 | constructed_label_faces = []
160 | for face in label_faces:
161 | label = face[:-1] - self.hparams.model.token.len # remove SEP and remove token offset
162 | label = label[label >= 0]
163 | label = label[label < num_edges]
164 | if len(label) > 0:
165 | constructed_label_faces.append((0, tuple(label.tolist())))
166 |
167 | predict_faces = np.split(predict, np.where(predict == self.hparams.model.token.SEP)[0]+1) # split by SEP
168 | constructed_predict_faces = []
169 | for face in predict_faces:
170 | if len(face) > 1:
171 | predict = face[:-1] - self.hparams.model.token.len # remove SEP and remove token offset
172 | predict = predict[predict >= 0]
173 | predict = predict[predict < num_edges]
174 | if len(predict) > 0:
175 | constructed_predict_faces.append((0, tuple(predict.tolist())))
176 |
177 | return constructed_predict_faces, constructed_label_faces
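   |         # Example (hypothetical ids, assuming token.len == 4): the sequence
   |         # [e0+4, e1+4, e2+4, SEP, e3+4, e4+4, SEP, EOS] parses into two faces,
   |         # (0, (e0, e1, e2)) and (0, (e3, e4)); the leading 0 is a dummy type.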
178 |
179 | # Parse multi-sequence faces
180 | # Return faces' indices
181 | def parse_parallel_faces(self, predicts, labels, num_edges):
182 | predict_faces, label_faces = [], []
183 | # cut off tokens after the [EOS] (Type of face in this case)
184 | for label in labels:
185 | label = np.split(label, np.where((label >= self.hparams.model.token.face_type_offset) & (label < self.hparams.model.token.len))[0]+1)[0]
186 | # extract face type
187 | face_type = label[-1] - self.hparams.model.token.face_type_offset
188 | # remove token offset
189 | label -= self.hparams.model.token.len
190 | # only take the valid indices
191 | label = label[label >= 0]
192 | if len(label) > 0:
193 | # only count the face if face is not empty and not full of paddings
194 | label_faces.append((face_type, tuple(label.tolist())))
195 |
196 | for predict in predicts:
197 | predict = np.split(predict, np.where((predict >= self.hparams.model.token.face_type_offset) & (predict < self.hparams.model.token.len))[0]+1)[0]
198 | # extract face type
199 | face_type = predict[-1] - self.hparams.model.token.face_type_offset
200 | # remove token offset
201 | predict -= self.hparams.model.token.len
202 | # only take the valid indices
203 | predict = predict[predict >= 0]
204 | predict = predict[predict < num_edges]
205 | if len(predict) > 0:
206 | predict_faces.append((face_type, tuple(predict.tolist())))
207 |
208 | return predict_faces, label_faces
209 |
210 | def face_accuracy(self, outputs):
211 | labels = outputs['label'].cpu().numpy() # N (x F) x T
212 | predicts = outputs['predict'].cpu().numpy() # N (x L) x T
213 |
214 |         outputs.update({'precisions': [], 'labels': [], 'type_acc_coedge_seq': [],
215 |                         'recalls': [], 'predictions': [], 'accuracy': [], 'type_acc': []})
216 |
217 | for ind in range(len(labels)):
218 | edges = self.dataset.raw_datas[outputs['id'][ind]]['edges']
219 | if len(labels.shape) == 3:
220 | # multi-seq, parallel
221 | predict_faces, label_faces = self.parse_parallel_faces(predicts[ind], labels[ind], len(edges))
222 | else:
223 | # single-seq
224 | predict_faces, label_faces = self.parse_faces(predicts[ind], labels[ind], len(edges))
225 |
226 | if self.hparams.post_process.is_coedge:
227 | pairings = self.dataset.raw_datas[outputs['id'][ind]]['pairings']
228 |
229 | predict_faces = filter_faces_by_encloseness(edges, predict_faces, self.hparams.post_process.enclosedness_tol)
230 | label_faces = filter_faces_by_encloseness(edges, label_faces, self.hparams.post_process.enclosedness_tol)
231 |
232 |             # calculate accuracy for coedge faces:
233 |             # accuracy is the fraction of predicted faces that match a label face
234 | face_tp = 0
235 | type_tp = 0
236 | for pred_type, pred_face in predict_faces:
237 | for label_type, label_face in set(label_faces):
238 | if pred_face == label_face:
239 | face_tp += 1
240 | if pred_type == label_type:
241 | type_tp += 1
242 | break
243 |
244 | if len(predict_faces) == 0:
245 | outputs['accuracy'].append(0)
246 | outputs['type_acc_coedge_seq'].append(0)
247 | else:
248 | outputs['accuracy'].append(face_tp / len(predict_faces))
249 | if face_tp == 0:
250 | outputs['type_acc_coedge_seq'].append(0)
251 | else:
252 | outputs['type_acc_coedge_seq'].append(type_tp / face_tp)
253 | # map coedge into edges
254 | label_faces = [(ftype, map_coedge_into_edges(pairings, flatten_list(loops))) for ftype, loops in label_faces]
255 | predict_faces = [(ftype, map_coedge_into_edges(pairings, flatten_list(loops))) for ftype, loops in predict_faces]
256 |
257 | # filter duplicate label faces
258 | label_faces_set = list(set([(ftype, tuple(sorted(set(indices)))) for ftype, indices in label_faces]))
259 |
260 | # determine face type by majority vote
261 | predict_unique_faces = {}
262 | for ftype, indices in predict_faces:
263 | face = tuple(sorted(set(indices)))
264 | if face in predict_unique_faces:
265 | predict_unique_faces[face].append(ftype)
266 | else:
267 | predict_unique_faces[face] = [ftype]
268 |
269 | predict_faces_set = [(Counter(ftypes).most_common(1)[0][0], face) for face, ftypes in predict_unique_faces.items()]
270 |
271 | # count TP
272 | face_tp = 0
273 | type_tp = 0
274 | for pred_type, pred_face in predict_faces_set:
275 | for label_type, label_face in label_faces_set:
276 | if pred_face == label_face:
277 | face_tp += 1
278 | if pred_type == label_type:
279 | type_tp += 1
280 | break
281 |
282 | if len(predict_faces_set) == 0 or len(label_faces_set) == 0:
283 | outputs['precisions'].append(0)
284 | outputs['recalls'].append(0)
285 | outputs['type_acc'].append(0)
286 | else:
287 | outputs['precisions'].append(face_tp / len(predict_faces_set))
288 | outputs['recalls'].append(face_tp / len(label_faces_set))
289 | if face_tp == 0:
290 | outputs['type_acc'].append(0)
291 | else:
292 | outputs['type_acc'].append(type_tp / face_tp)
293 | outputs['predictions'].append(predict_faces_set)
294 | outputs['labels'].append(label_faces_set)
295 |
296 |         # the first token is already removed; only non-padded elements count toward accuracy
297 | valid = labels > self.hparams.model.token.PAD
298 | acc_sum = (valid * (predicts == labels)).sum()
299 | valid_sum = valid.sum()
300 | return acc_sum/valid_sum, outputs
301 |
302 |
303 |
--------------------------------------------------------------------------------
/reconstruction/reconstruct_to_wireframe.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import json
4 | import os
5 | from functools import partial
6 |
7 | import cvxpy as cp
8 | import numpy as np
9 | import open3d as o3d
10 | from faceformer.utils import flatten_list
11 | from tqdm.contrib.concurrent import process_map
12 |
13 | from reconstruction.reconstruction_utils import (construct_connected_cylinder,
14 | dist, fit_curve,
15 | is_straight_line)
16 |
17 | INTERMEDIATE_TYPE = 11 # type tag for the extra helper faces added to each cylinder, not used in the final reconstruction
18 |
19 | def sample_points_on_line(line, sample_dist):
20 | x1, y1, x2, y2 = line[0][0], line[0][1], line[1][0], line[1][1]
21 | num_samples = int(np.sqrt((x1-x2)**2+(y1-y2)**2) / sample_dist) + 1
22 | t = np.linspace(0, 1, num_samples)
23 | x = x1 + (x2-x1) * t
24 | y = y1 + (y2-y1) * t
25 | return np.vstack([x, y]).T
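   | # e.g. sample_points_on_line([[0, 0], [1, 0]], 0.5) -> [[0, 0], [0.5, 0], [1, 0]]
   | # (int(length / sample_dist) + 1 evenly spaced samples along the segment).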
26 |
27 | def reconstruct_file(name, root):
28 | try:
29 | if os.path.exists(os.path.join(root, 'ply', f"{name}.ply")):
30 | return
31 | data = json.load(open(os.path.join(root, 'json', f'{name}.json')))
32 | num_faces = len(data['pred_faces'])
33 | num_edges = len(data['edges'])
34 |
35 | to_add_new_planes = []
36 | to_add_new_edges = []
37 | face_removal_indices = []
38 | circle_face_to_construct = [] # should contain indices of two outlines
39 | circle_face_to_construct_dir = [] # should contain the direction of the outlines
40 |
41 | dom_directions = [np.array(d[:2]) / np.linalg.norm(d[:2]) for d in data['dominant_directions']]
42 | face_to_normal = {}
43 |
44 |         # drop faces whose predicted type is neither plane (0) nor cylinder (1)
45 | for i, (face_type, loops) in enumerate(data['pred_faces']):
46 | if face_type not in [0, 1]:
47 | face_removal_indices.append(i)
48 | continue
49 |
50 |             # for each cylinder face, we construct two new plane faces
51 |             # always select the midpoint of the curve for plane reconstruction
52 | if face_type == 1:
53 | face_removal_indices.append(i)
54 |
55 | # cylinder face should have two curves and two straight lines
56 | all_edge_inds = list(loops)
57 | all_edges = [data['edges'][i] for i in all_edge_inds]
58 |             # a cylinder face must contain exactly two straight lines; otherwise skip
59 | count = 0
60 | for edge in all_edges:
61 | if is_straight_line(edge):
62 | count += 1
63 | if count != 2:
64 | print(f"{name} has {count} straight lines, not a cylinder")
65 | continue
66 | try:
67 | all_edges, all_edge_inds, all_dirs = construct_connected_cylinder(all_edges, all_edge_inds)
68 |             except Exception:
69 | continue
70 |
71 |             # a cylinder face loop is expected to have 4 edges (2 lines + 2 curves)
72 | if len(all_edges) != 4:
73 |                 # merge adjacent curve segments
74 | i = 0
75 | while i < len(all_edges):
76 | next_edge_ind = (i+1) % len(all_edges)
77 | if not is_straight_line(all_edges[i]) and not is_straight_line(all_edges[next_edge_ind]):
78 | all_edges[i] += all_edges[next_edge_ind]
79 | all_edges.pop(next_edge_ind)
80 | all_edge_inds.pop(next_edge_ind)
81 | all_dirs.pop(next_edge_ind)
82 | continue
83 | i += 1
84 | if len(all_edges) != 4:
85 | print(f"{name} has {len(all_edges)} edges in a cylinder")
86 | continue
87 |
88 |             # the face is assumed to be a wireframe loop with coedge directions;
89 |             # if the straight line comes first, the newly constructed mid edge
90 |             # runs opposite to that straight line
91 | if is_straight_line(all_edges[0]):
92 | line_ind = all_edge_inds[0]
93 | line = all_edges[0]
94 | line_dir = all_dirs[0]
95 | curve = all_edges[1]
96 | curve_ind = all_edge_inds[1]
97 | other_line_ind = all_edge_inds[2]
98 | other_line = all_edges[2]
99 | other_line_dir = all_dirs[2]
100 | other_curve_ind = all_edge_inds[3]
101 | else:
102 | curve = all_edges[0]
103 | curve_ind = all_edge_inds[0]
104 | other_line = all_edges[1]
105 | other_line_ind = all_edge_inds[1]
106 | other_line_dir = all_dirs[1]
107 | other_curve_ind = all_edge_inds[2]
108 | line = all_edges[3]
109 | line_ind = all_edge_inds[3]
110 | line_dir = all_dirs[3]
111 |
112 |             # the two straight lines of a cylinder are assumed to have equal length
113 |             # displace the midpoint of one curve by the same amount to generate the middle edge
114 |
115 | direction = np.array(line[0]) - np.array(line[1])
116 | mid_point = np.array(curve[len(curve) // 2])
117 | # next_point = np.array(other_curve[len(other_curve) // 2])
118 |
119 | next_point = mid_point + direction
120 | mid_point = mid_point.tolist()
121 | next_point = next_point.tolist()
122 | new_mid_edge = [mid_point, next_point]
123 | new_edges = [new_mid_edge, [line[0], next_point], [line[1], mid_point], [other_line[1], next_point], [other_line[0], mid_point]]
124 | ind_offset = len(to_add_new_edges) + num_edges
125 | to_add_new_edges += new_edges
126 | face_1 = (INTERMEDIATE_TYPE, [line_ind, 2+ind_offset, ind_offset, 1+ind_offset])
127 | face_2 = (INTERMEDIATE_TYPE, [other_line_ind, 3+ind_offset, ind_offset, 4+ind_offset])
128 | to_add_new_planes += [face_1, face_2]
129 | circle_face_to_construct.append([line_ind, other_line_ind, ind_offset, curve_ind, other_curve_ind])
130 | circle_face_to_construct_dir.append([line_dir, other_line_dir, 1])
131 |
132 | # find plane's normal, assuming the axis of the cylinder face is aligned with one of the dominant directions
133 | edge_direction = np.array(line[0]) - np.array(line[1])
134 | normal_direction_ind = np.argmax([np.abs(np.dot(edge_direction, d)) for d in dom_directions])
135 |
136 | # find other coedge's circle plane and add normal constraint
137 | for i, (face_type, indices) in enumerate(data['pred_faces']):
138 | if curve_ind in indices or other_curve_ind in indices:
139 | face_to_normal[tuple(indices)] = normal_direction_ind
140 |
141 | # Add in new faces
142 | data['pred_faces'] += to_add_new_planes
143 | data['edges'] += to_add_new_edges
144 | num_faces = len(data['pred_faces'])
145 | num_edges = len(data['edges'])
146 |
147 |
148 |
149 | removed_faces = []
150 | # remove cylinder faces
151 | for i, ind in enumerate(face_removal_indices):
152 | removed_faces.append(data['pred_faces'].pop(ind-i))
153 |
154 | P = []
155 | b = []
156 | C = []
157 |
158 |
159 |         # create one equation per face relating its normal to each dominant direction:
160 |         # check 2D parallelism to one of the dominant directions;
161 |         # an edge parallel to a dominant direction lies in the face, so the face normal must be perpendicular to that direction
162 |
163 | # normalized dom_directions
164 | # # !! 2d dominant directions need human input
165 | # aa = -np.sum(dom_directions[0]*dom_directions[1])
166 | # bb = -np.sum(dom_directions[1]*dom_directions[2])
167 | # cc = -np.sum(dom_directions[0]*dom_directions[2])
168 | # z3 = np.sqrt(cc*bb / aa)
169 | # z2 = bb / z3
170 | # z1 = cc / z3
171 |
172 | # origin_directions = [dom_directions[0].tolist()+[z1], dom_directions[1].tolist()+[z2], dom_directions[2].tolist()+[z3]]
173 | # origin_directions = [np.array(d) / np.linalg.norm(d) for d in origin_directions]
174 | origin_directions = [np.array(d) / np.linalg.norm(d) for d in data['dominant_directions']]
175 | face_removal_indices = []
176 | for face_ind, (face_type, indices) in enumerate(data['pred_faces']):
177 | parallel_count_for_dom_directions = [0] * 3
178 | for edge_ind in indices:
179 | edge = data['edges'][edge_ind]
180 | if not is_straight_line(edge):
181 | continue
182 | edge_direction = np.array(edge[0]) - np.array(edge[1])
183 | edge_direction /= np.linalg.norm(edge_direction)
184 | # check if edge is parallel to one of the dominant directions
185 |
186 | for i, direction in enumerate(dom_directions):
187 | if np.abs(np.dot(edge_direction, direction)) > (1 - 1e-10):
188 | parallel_count_for_dom_directions[i] += 1
189 |
190 |
191 | # cylinder planes have predetermined normals from the outline
192 | if tuple(indices) in face_to_normal:
193 | normal_ind = face_to_normal[tuple(indices)]
194 | for i in range(3):
195 | if i != normal_ind:
196 | parallel_count_for_dom_directions[i] += 1
197 |
198 | if 0 not in parallel_count_for_dom_directions:
199 |                 # edges parallel to all three dominant directions leave no valid normal => wrong face prediction
200 | face_removal_indices.append(face_ind)
201 | continue
202 |
203 |             # the face normal must be perpendicular to each parallel direction
204 |             # adjust the row index once for the faces removed so far
205 |             adj_face_ind = face_ind - len(face_removal_indices)
206 |             for ind, count in enumerate(parallel_count_for_dom_directions):
207 |                 if count != 0:
208 |                     row = np.zeros(3 * num_faces)
209 |                     direction_3d = origin_directions[ind]
210 |                     row[3*adj_face_ind: 3*adj_face_ind+2] = [direction_3d[0], direction_3d[1]]
211 |                     brow = np.array([direction_3d[2]])
212 |                     P.append(row)
213 |                     b.append(brow)
214 | for i, ind in enumerate(face_removal_indices):
215 | data['pred_faces'].pop(ind-i)
216 |
217 | # find all unique vertices
218 | all_vertices = []
219 | all_used_edges = set(flatten_list([indices for _, indices in data['pred_faces']]))
220 | for ind in all_used_edges:
221 | all_vertices += data['edges'][ind]
222 |
223 | unique_vertices = []
224 | tol = 1e-4
225 | for vertex in all_vertices:
226 | dists = np.array([dist(p1, vertex) for p1 in unique_vertices])
227 | if np.sum(dists < tol) < 1:
228 | # new unique vertex
229 | unique_vertices.append(vertex)
230 |
231 | face_grouped_by_vertex = [[] for _ in range(len(unique_vertices))]
232 | # match faces to vertices
233 | for face_ind, (_, indices) in enumerate(data['pred_faces']):
234 | for edge_ind in indices:
235 | for point in data['edges'][edge_ind]:
236 | dists = np.array([dist(p1, point) for p1 in unique_vertices])
237 | vertex_ind = np.argmin(dists)
238 | face_grouped_by_vertex[vertex_ind].append(face_ind)
239 |
240 | face_grouped_by_vertex = [list(set(group)) for group in face_grouped_by_vertex]
241 | for vertex, face_group in zip(unique_vertices, face_grouped_by_vertex):
242 | if len(face_group) < 2:
243 | continue
244 |             # for each pair of faces joined at this vertex, create one equation forcing equal depth
245 | for f1, f2 in itertools.combinations(face_group, 2):
246 | row = np.zeros(3 * num_faces)
247 | row[f1*3: f1*3 + 3] = [vertex[0], vertex[1], 1]
248 | row[f2*3: f2*3 + 3] = [-vertex[0], -vertex[1], -1]
249 | brow = np.array([0])
250 | P.append(row)
251 | b.append(brow)
252 |             # for each face at this vertex, constrain the recovered depth to one side of the image plane
253 | for f in face_group:
254 | row = np.zeros(3 * num_faces)
255 | row[f*3: f*3 + 3] = [-vertex[0], -vertex[1], -1]
256 | C.append(row)
257 |
258 | P = np.array(P)
259 | C = np.array(C)
260 | b = np.array(b)
261 |
262 | n = P.shape[-1]
263 | if n == 0:
264 | return
265 | if C.shape[-1] == 0:
266 | return
267 |
268 | # sample points for 3d reconstruction
269 | pts = []
270 | pts_label = []
271 | sample_dist = 5e-3
272 | ind_to_3d_map = {} # contains the start of 3d samples for each edge index
273 | mid_edge_to_remove_start_ind = []
274 | mid_edge_inds = []
275 | for face_ind, (face_type, indices) in enumerate(data['pred_faces']):
276 | if face_type == INTERMEDIATE_TYPE:
277 |                 # only the first edge (outline) and the third (mid edge) of an intermediate face need to be reconstructed
278 | sampled_pts = sample_points_on_line(data['edges'][indices[0]], sample_dist)
279 | pts.append(sampled_pts)
280 | ind_to_3d_map[indices[0]] = (len(pts_label), len(sampled_pts))
281 | pts_label += [face_ind] * len(sampled_pts)
282 | # memorize where the outline's corresponding 3d points are
283 | sampled_pts = sample_points_on_line(data['edges'][indices[2]], sample_dist)
284 | pts.append(sampled_pts)
285 | ind_to_3d_map[indices[2]] = (len(pts_label), len(sampled_pts))
286 | mid_edge_to_remove_start_ind.append(len(pts_label))
287 | mid_edge_inds.append(indices[2])
288 | pts_label += [face_ind] * len(sampled_pts)
289 | continue
290 | for edge_ind in indices:
291 | if is_straight_line(data['edges'][edge_ind]):
292 | sampled_pts = sample_points_on_line(data['edges'][edge_ind], sample_dist)
293 | pts.append(sampled_pts)
294 | ind_to_3d_map[edge_ind] = (len(pts_label), len(sampled_pts))
295 | pts_label += [face_ind] * len(sampled_pts)
296 |
297 | if len(pts) == 0:
298 | return
299 | pts = np.vstack(pts)
300 | pts_label = np.array(pts_label)
301 |
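   |         # Depth recovery as a small convex program (a sketch of the setup):
   |         # each face i is a plane z = a_i*x + b_i*y + c_i and f stacks the
   |         # (a_i, b_i, c_i); P @ f + b collects the perpendicularity and
   |         # shared-vertex equations built above, minimized in L1 norm so a few
   |         # wrong faces are tolerated, while C @ f >= 0 keeps every recovered
   |         # depth on one side of the image plane.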
302 | f = cp.Variable((n, 1))
303 | try:
304 | objective = cp.Minimize(cp.norm1(P @ f + b))
305 | constraints = [C @ f >= 0]
306 | prob = cp.Problem(objective, constraints)
307 |
308 | result = prob.solve()
309 |         except Exception:
310 | return
311 | params = f.value.reshape(-1, 3)
312 |
313 | N = len(pts)
314 |
315 | pts_one = np.hstack((pts, np.ones((N, 1))))
316 |
317 | depth = np.sum(params[pts_label] * pts_one, axis=1, keepdims=True)
318 |
319 | xyz = np.hstack((pts, depth))
320 |
321 | # reconstruct the circle planes
322 | for i in range(len(circle_face_to_construct)):
323 | line_ind, other_line_ind, mid_edge_ind, curve_ind, other_curve_ind = circle_face_to_construct[i]
324 | line_dir, other_line_dir, mid_edge_dir = circle_face_to_construct_dir[i]
325 |         # the connection between the two outlines gives the center of the circle
326 |         # the direction of the outlines gives the normal of the plane
327 | start_ind, num_samples = ind_to_3d_map[line_ind]
328 | pts = xyz[start_ind:start_ind+num_samples]
329 | other_start_ind, num_samples = ind_to_3d_map[other_line_ind]
330 | other_pts = xyz[other_start_ind:other_start_ind+num_samples]
331 | mid_edge_start_ind, num_samples = ind_to_3d_map[mid_edge_ind]
332 | mid_edge_pts = xyz[mid_edge_start_ind:mid_edge_start_ind+num_samples]
333 |
334 | p1, p2, p3 = pts[::line_dir][0], other_pts[::other_line_dir][-1], mid_edge_pts[::mid_edge_dir][-1]
335 | curve_pts = fit_curve(p1, p2, p3)
336 | ind_to_3d_map[other_curve_ind] = (len(xyz), len(curve_pts))
337 | xyz = np.vstack([xyz, curve_pts])
338 |
339 | p1, p2, p3 = pts[::line_dir][-1], other_pts[::other_line_dir][0], mid_edge_pts[::mid_edge_dir][0]
340 | curve_pts = fit_curve(p1, p2, p3)
341 | ind_to_3d_map[curve_ind] = (len(xyz), len(curve_pts))
342 | xyz = np.vstack([xyz, curve_pts])
343 |
344 |
345 | # add back the removed cylinder faces
346 | data['pred_faces'] += removed_faces
347 |
348 | # iterate through the faces and add edges that are not mid-edges
349 | points = []
350 | edges_drawn = set(mid_edge_inds)
351 | for face_type, indices in data['pred_faces']:
352 | if face_type == INTERMEDIATE_TYPE:
353 | continue
354 | for ind in indices:
355 | if ind in ind_to_3d_map and ind not in edges_drawn:
356 | start_ind, length = ind_to_3d_map[ind]
357 | points.append(xyz[start_ind:start_ind+length])
358 | edges_drawn.add(ind)
359 |
360 | pcd = o3d.geometry.PointCloud()
361 | pts = np.vstack(points)
362 | pts[:, 1] = -pts[:, 1]
363 | pcd.points = o3d.utility.Vector3dVector(pts)
364 |
365 | o3d.io.write_point_cloud(os.path.join(root, 'ply', f"{name}.ply"), pcd)
366 |     except Exception as e:
367 |         print(f"{name} failed: {e}")
368 | return
369 |
370 | if __name__ == '__main__':
371 |
372 | parser = argparse.ArgumentParser()
373 | parser.add_argument('--root', type=str, default="/root/data",
374 | help='dataset root.')
375 | parser.add_argument('--name', type=str, default=None,
376 | help='filename.')
377 | parser.add_argument('--num_cores', type=int,
378 | default=10, help='number of processors.')
379 | parser.add_argument('--num_chunks', type=int,
380 |                         default=5, help='chunksize passed to process_map.')
381 |
382 | args = parser.parse_args()
383 |
384 | os.makedirs(os.path.join(args.root, 'ply'), exist_ok=True)
385 |
386 | if args.name is not None:
387 | reconstruct_file(args.name, args.root)
388 | else:
389 | all_names = [name[:8] for name in os.listdir(os.path.join(args.root, 'json'))]
390 | process_map(partial(reconstruct_file, root=args.root), all_names,
391 | max_workers=args.num_cores, chunksize=args.num_chunks)
392 | # for name in all_names:
393 | # reconstruct_file(name, args.root)
394 |
--------------------------------------------------------------------------------
/dataset/prepare_data.py:
--------------------------------------------------------------------------------
1 | """
2 | Data Prep for ABC step Files
3 | * use HLR to render wireframe
4 | * break surfaces with outlines
5 | * remove seam lines
6 | """
7 | import argparse
8 | import json
9 | import os
10 | from functools import partial
11 |
12 | import numpy as np
13 | from OCC.Core.Bnd import Bnd_Box
14 | from OCC.Core.BRepBndLib import brepbndlib_Add
15 | from OCC.Core.BRepBuilderAPI import BRepBuilderAPI_Transform
16 | from OCC.Core.gp import gp_Pnt, gp_Trsf, gp_Vec
17 | from OCC.Extend.TopologyUtils import TopologyExplorer
18 | from tqdm.contrib.concurrent import process_map
19 |
20 | from dataset.utils.discretize_edge import (DiscretizedEdge, sort_faces_by_indices,
21 | sort_edges_by_coordinate)
22 | from dataset.utils.json_to_svg import save_png, save_svg, save_svg_groups
23 | from dataset.utils.read_step_file import read_step_file
24 | from dataset.utils.TopoMapper import TopoMapper
25 |
26 | from dataset.tests.check_faces_enclosed import is_face_enclosed
27 | from faceformer.utils import flatten_list
28 | from dataset.utils.projection_utils import generate_random_camera_pos
29 |
30 | def get_boundingbox(shapes, tol=1e-6):
31 | """ return the bounding box of the TopoDS_Shape `shape`
32 | Parameters
33 | ----------
34 | shape : TopoDS_Shape or a subclass such as TopoDS_Face
35 | the shape to compute the bounding box from
36 | tol: float
37 | tolerance of the computed boundingbox
38 | """
39 | bbox = Bnd_Box()
40 | bbox.SetGap(tol)
41 | for shape in shapes:
42 | brepbndlib_Add(shape, bbox, False)
43 | xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get()
44 | center = (xmax + xmin) / 2, (ymin + ymax) / 2, (zmin + zmax) / 2
45 | extent = abs(xmax-xmin), abs(ymax-ymin), abs(zmax-zmin)
46 | return center, extent
47 |
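# Usage sketch (mirrors render_shape_and_faces below): normalize a shape by
# centering it at the origin and scaling its bounding-box diagonal to 2:
#   center, extent = get_boundingbox([shape])
#   trans, scale = gp_Trsf(), gp_Trsf()
#   trans.SetTranslation(-gp_Vec(*center))
#   scale.SetScale(gp_Pnt(0, 0, 0), 2 / np.linalg.norm(extent))
#   shape = BRepBuilderAPI_Transform(shape, scale * trans).Shape()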
48 |
49 | def shape_to_svg(shape, name, args):
50 | """ export a single shape to an svg file and json.
51 | shape: the TopoDS_Shape to export
52 | """
53 | if shape.IsNull():
54 | raise AssertionError("shape is Null")
55 |
56 | topo = TopoMapper(shape, args)
57 | all_dedges = []
58 | faces_pointers = []
59 | face_types = []
60 | # face_parameters = []
61 | all_shrinked_dedges = []
62 | shape_center, _ = get_boundingbox([shape])
63 | # ind = 2
64 | # replacements = {2:[ind, (3, (ind-2)%3)], 3:[(ind-2)%3, (2, ind)]}
65 |
66 | for index, face in enumerate(topo.all_faces.values()):
67 | discretized_edges = face.get_oriented_dedges()
68 | discretized_edges_3d = face.get_oriented_dedges(is_3d=True)
69 |
70 | # generate smaller edges for visualization of faces
71 | face_edges_lists = [edge.edges for edge in face.edges]
72 | face_edges = flatten_list(face_edges_lists)
73 | center, _ = get_boundingbox(face_edges)
74 | translation = gp_Trsf()
75 | push_vec = np.array([center[0] - shape_center[0], center[1] - shape_center[1], center[2] - shape_center[2]]) * 1.04
76 | # push_vec += push_vec / np.sqrt(np.sum(push_vec**2)) * 0.1 # push a fixed amount
77 | # push the edges out along the line from the center to the face
78 | translation.SetTranslation(gp_Vec(*push_vec))
79 |
80 | # scale = gp_Trsf()
81 | # scale.SetScale(gp_Pnt(*center), 0.7)
82 | shrinked_dedges = []
83 | for i, edge_list in enumerate(face_edges_lists):
84 | # if index in replacements and i == replacements[index][0]:
85 | # other_face_index = replacements[index][1][0]
86 | # other_face_edges_lists = [edge.edges for edge in list(topo.all_faces.values())[other_face_index].edges]
87 | # edge_list = other_face_edges_lists[replacements[index][1][1]]
88 | shrinked_edges = []
89 | for edge in edge_list:
90 | brep_trans = BRepBuilderAPI_Transform(edge, translation)
91 | edge = brep_trans.Shape()
92 | shrinked_edges.append(edge)
93 | shrinked_dedges.append(topo._raw_project(shrinked_edges, args.discretize_last))
94 |
95 | all_shrinked_dedges.append(shrinked_dedges)
96 | filename = os.path.join(args.root, 'face_svg', f'{name}_{index}.svg')
97 | save_svg(discretized_edges, filename, args)
98 | save_png(f'{name}_{index}', args, prefix='face_')
99 |
100 |
101 | filename = os.path.join(args.root, 'face_shrinked_face_svg', f'{name}_{index}.svg')
102 | save_svg(shrinked_dedges, filename, args)
103 | save_png(f'{name}_{index}', args, prefix='face_shrinked_face_')
104 |
105 | # generate data for face ground truth
106 | if args.combine_coedge:
107 | # combine coedges => all coedges share the same direction
108 | for edge in face.edges:
109 | if edge.DiscretizedEdge is None:
110 | edge.DiscretizedEdge = DiscretizedEdge(edge.dedge)
111 | all_dedges.append(edge.DiscretizedEdge)
112 | face_pointers = [edge.DiscretizedEdge for edge in face.edges]
113 | faces_pointers.append(face_pointers)
114 | else:
115 | assert len(discretized_edges) == len(shrinked_dedges)
116 | assert len(discretized_edges) == len(discretized_edges_3d)
117 | # each edge is represented as two discretized edges in two directions
118 | # save 3d points
119 | face_pointers = [DiscretizedEdge(dedge, smaller_edge=shrinked_dedge, edge3d=dedge_3d) \
120 | for dedge, shrinked_dedge, dedge_3d in zip(discretized_edges, shrinked_dedges, discretized_edges_3d)]
121 | all_dedges += face_pointers
122 | # get face pointers
123 | faces_pointers.append(face_pointers)
124 |
125 | face_types.append(face.face_type)
126 | # face_parameters.append(face.parameters)
127 |
128 | all_dedges = sort_edges_by_coordinate(all_dedges)
129 | # assign index to each dedge
130 | for index, dedge in enumerate(all_dedges):
131 | dedge.index = index
132 |
133 | faces_indices = []
134 | for face_pointers in faces_pointers:
135 | if args.order_by_position:
136 | faces_indices.append(sorted([dedge.index for dedge in face_pointers]))
137 | else:
138 | faces_indices.append([dedge.index for dedge in face_pointers])
139 |
140 | save_svg([edge.dedge for edge in topo.all_edges.values()], os.path.join(
141 | args.root, 'svg', f'{name}.svg'), args)
142 | save_png(name, args)
143 | save_svg_groups(all_shrinked_dedges, os.path.join(
144 | args.root, 'face_shrinked_svg', f'{name}.svg'), args)
145 | save_png(name, args, prefix='face_shrinked_')
146 |
147 | if args.combine_coedge:
148 | faces_indices = [np.roll(face, -np.argmin(face), axis=0).tolist() for face in faces_indices]
149 | faces_indices = sort_faces_by_indices(faces_indices)
150 | else:
151 | # check enclosedness here, raise error if not enclosed
152 | # group indices
153 | all_edge_points = [dedge.points for dedge in all_dedges]
154 | sorted_faces_indices = []
155 | for i, face in enumerate(faces_indices):
156 | all_face_loops = is_face_enclosed(all_edge_points, face, args.tol * 2)
157 | if not all_face_loops:
158 | raise ValueError(f"face {i} is not enclosed")
159 | # roll enclosed loops so the smallest index is at the front, e.g. [7, 3, 5] -> [3, 5, 7]
160 | all_face_loops = [np.roll(loop, -np.argmin(loop), axis=0).tolist() for loop in all_face_loops]
161 | # loops are ordered by first index
162 | all_face_loops = sorted(all_face_loops, key=lambda x: x[0])
163 | # sorted_faces_indices.append([face_types[i], all_face_loops, face_parameters[i]])
164 | if args.no_face_type:
165 | sorted_faces_indices.append(all_face_loops)
166 | else:
167 | sorted_faces_indices.append([face_types[i], all_face_loops])
168 |
169 | # each face:
170 | # [
171 | # type,
172 | # [loops],
173 | # [parameters]
174 | # ]
175 | # order faces by first index
176 | if args.no_face_type:
177 | faces_indices = sorted(sorted_faces_indices, key=lambda x: x[0][0])
178 | else:
179 | faces_indices = sorted(sorted_faces_indices, key=lambda x: x[1][0][0])
180 | edges_to_json(all_dedges, faces_indices, name, topo.get_dominant_directions())
181 |
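# Sketch of one resulting faces_indices entry (indices hypothetical):
#   with face types:     [face_type, [[3, 5, 7], [10, 12, 14]]]  # one inner list per loop
#   with --no_face_type:             [[3, 5, 7], [10, 12, 14]]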
182 |
183 | def shape_to_svg_direction_token(shape, name, args):
184 | """ export a single shape to an svg file and json.
185 | ! coedges are combined; each edge index within a face carries a 0/1 direction indicator
186 | shape: the TopoDS_Shape to export
187 | """
188 | if shape.IsNull():
189 | raise AssertionError("shape is Null")
190 |
191 | topo = TopoMapper(shape, args)
192 | all_dedges = []
193 | faces_pointers = []
194 | face_types = []
195 |
196 | for index, face in enumerate(topo.all_faces.values()):
197 | # save shape visualization
198 | discretized_edges = face.get_oriented_dedges()
199 | filename = os.path.join(args.root, 'face_svg', f'{name}_{index}.svg')
200 | save_svg(discretized_edges, filename, args)
201 | save_png(f'{name}_{index}', args, prefix='face_')
202 |
203 | # generate data for face ground truth
204 | for edge in face.edges:
205 | if edge.DiscretizedEdge is None:
206 | edge.DiscretizedEdge = DiscretizedEdge(edge.dedge)
207 | all_dedges.append(edge.DiscretizedEdge)
208 | # e-> edge, o-> orientation
209 | face_pointers = [(e.DiscretizedEdge, o) for e, o in zip(face.edges, face.edge_orientations)]
210 | faces_pointers.append(face_pointers)
211 | face_types.append(face.face_type)
212 |
213 | # save face visualization
214 | save_svg([edge.dedge for edge in topo.all_edges.values()], os.path.join(
215 | args.root, 'svg', f'{name}.svg'), args)
216 | save_png(name, args)
217 |
218 | # generate data for face ground truth
219 | all_dedges = sort_edges_by_coordinate(all_dedges)
220 | # assign index to each dedge
221 | for index, dedge in enumerate(all_dedges):
222 | dedge.index = index
223 |
224 | faces_indices = []
225 | for face_pointers in faces_pointers:
226 | # o-> orientation
227 | faces_indices.append([(dedge.index, o) for dedge, o in face_pointers])
228 |
229 | # check enclosedness here, raise error if not enclosed
230 | # group indices
231 | all_edge_points = [dedge.points for dedge in all_dedges]
232 | sorted_faces_indices = []
233 | for face in faces_indices:
234 | all_face_loops = is_face_enclosed(all_edge_points, face, args.tol * 2)
235 | if not all_face_loops:
236 | raise ValueError("face is not enclosed")
237 | # roll each loop so the (edge_index, orientation) pair with the smallest edge index is at the front
238 | all_face_loops = [np.roll(loop, -np.argmin([t[0] for t in loop]), axis=0).tolist() for loop in all_face_loops]
239 | # loops are ordered by first index
240 | all_face_loops = sorted(all_face_loops, key=lambda x: x[0][0])
241 | sorted_faces_indices.append(all_face_loops)
242 |
243 | # order faces by first index
244 | faces_indices = sorted(sorted_faces_indices, key=lambda x: x[0][0][0])
245 | edges_to_json(all_dedges, faces_indices, name, topo.get_dominant_directions())
246 |
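# Sketch of one face entry in the direction-token variant (indices
# hypothetical): [[3, 0], [5, 1], [7, 0]], i.e. (edge_index, orientation)
# pairs, rolled so the smallest edge index leads.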
247 |
248 |
249 | def edges_to_json(all_dedges, faces_indices, name, dominant_directions):
250 | # write to json
251 | json_filename = os.path.join(args.root, 'json', f'{name}.json')  # relies on the module-level args
252 | data = {}
253 | data['edges'] = [dedge.points for dedge in all_dedges]
254 | data['edges3d'] = [dedge.edge3d for dedge in all_dedges]
255 | data['shrinked_edges'] = [dedge.smaller_edge for dedge in all_dedges]
256 | data['faces_indices'] = faces_indices
257 | data['dominant_directions'] = dominant_directions
258 | data['pairings'] = {}
259 | # find all pairings of indices
260 | for i in range(len(data['edges'])):
261 | for j in range(i+1, len(data['edges'])):
262 | if data['edges'][i] == data['edges'][j][::-1]:  # edge j is edge i traversed in reverse
263 | data['pairings'][i] = j
264 | with open(json_filename, 'w') as f:
265 | json.dump(data, f)
266 |
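# Sketch of the emitted json (keys exactly as written above):
# {
#   "edges":               2D polyline per discretized edge,
#   "edges3d":             matching 3D polylines,
#   "shrinked_edges":      offset copies used only for visualization,
#   "faces_indices":       per-face loop structure (see shape_to_svg),
#   "dominant_directions": from TopoMapper.get_dominant_directions(),
#   "pairings":            {i: j} where edge j is edge i reversed
# }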
267 |
268 | def render_shape_and_faces(name, args):
269 | try:
270 | # if os.path.exists(os.path.join(args.root, 'json', f'{name}.json')):
271 | # return
272 | step_path = os.path.join(args.root, 'step', f'{name}.step')
273 | # step read timeout at 5 seconds
274 | try:
275 | shape, num_shapes = read_step_file(step_path, verbosity=False)
276 | except Exception:
277 | print(f"{name} could not be read (timed out or failed to parse)")
278 | return
279 |
280 | if shape is None:
281 | print(f"{name} is NULL shape")
282 | return
283 |
284 | if num_shapes > args.filter_num_shapes:
285 | print(f"{name} has {num_shapes} shapes. Too many!")
286 | return
287 |
288 | topology_explorer = TopologyExplorer(shape)
289 |
290 | if len(list(topology_explorer.edges())) > args.filter_num_edges:
291 | print(f"{name} has too many edges.")
292 | return
293 |
294 | center, extent = get_boundingbox([shape])
295 |
296 | trans, scale = gp_Trsf(), gp_Trsf()
297 | trans.SetTranslation(-gp_Vec(*center))
298 | scale.SetScale(gp_Pnt(0, 0, 0), 2 / np.linalg.norm(extent))
299 | brep_trans = BRepBuilderAPI_Transform(shape, scale * trans)
300 | shape = brep_trans.Shape()
301 |
302 | args.pose = None
303 | # generate random camera position
304 | if args.random_camera:
305 | # up to 5 attempts with a random camera; if every attempt fails, fall through to the unconditional render below
306 | for _ in range(5):
307 | try:
308 | focus, cam_pose = generate_random_camera_pos(args.seed)
309 | args.pose = cam_pose
310 | # focus == 0 keeps the projection orthographic; otherwise adopt the sampled focus
311 | if args.focus != 0:
312 | args.focus = focus
313 | if args.direction_token:
314 | shape_to_svg_direction_token(shape, name, args)
315 | else:
316 | shape_to_svg(shape, name, args)
317 | return
318 | except Exception:
319 | continue
320 |
321 | if args.direction_token:
322 | shape_to_svg_direction_token(shape, name, args)
323 | else:
324 | shape_to_svg(shape, name, args)
325 |
326 | except Exception as e:
327 | print(f"{name} received unknown error", e)
328 |
329 | def prepare_splits(args):
330 | if os.path.exists(args.id_list):
331 | with open(args.id_list, 'r') as f:
332 | names = json.load(f)
333 | else:
334 | names = []
335 | for name in sorted(os.listdir(os.path.join(args.root, 'json'))):
336 | names.append(name[:8])
337 |
338 | np.random.seed(args.seed)
339 | np.random.shuffle(names)
340 | train_ratio, valid_ratio, test_ratio = args.split
341 | trainlist, validlist, testlist = np.split(names, [int(
342 | len(names) * train_ratio), int(len(names) * (train_ratio + valid_ratio))])
343 |
344 | np.savetxt(os.path.join(args.root, 'train.txt'), trainlist, fmt="json/%s.json")
345 | np.savetxt(os.path.join(args.root, 'valid.txt'), validlist, fmt="json/%s.json")
346 | np.savetxt(os.path.join(args.root, 'test.txt'), testlist, fmt="json/%s.json")
347 |
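# The split files contain one relative json path per line, matching the fmt
# string above, e.g. (hypothetical id) "json/00000001.json".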
348 |
349 | def main(args):
350 | np.random.seed(args.seed)
351 | os.makedirs(os.path.join(args.root, 'svg'), exist_ok=True)
352 | os.makedirs(os.path.join(args.root, 'png'), exist_ok=True)
353 | os.makedirs(os.path.join(args.root, 'face_shrinked_face_svg'), exist_ok=True)
354 | os.makedirs(os.path.join(args.root, 'face_shrinked_face_png'), exist_ok=True)
355 | os.makedirs(os.path.join(args.root, 'face_shrinked_svg'), exist_ok=True)
356 | os.makedirs(os.path.join(args.root, 'face_shrinked_png'), exist_ok=True)
357 | os.makedirs(os.path.join(args.root, 'face_svg'), exist_ok=True)
358 | os.makedirs(os.path.join(args.root, 'face_png'), exist_ok=True)
359 | os.makedirs(os.path.join(args.root, 'json'), exist_ok=True)
360 |
361 | if os.path.exists(args.id_list):
362 | with open(args.id_list, 'r') as f:
363 | names = json.load(f)
364 | else:
365 | names = []
366 | for name in sorted(os.listdir(os.path.join(args.root, 'step'))):
367 | names.append(os.path.splitext(name)[0])
368 |
369 | if not args.only_split:
370 | process_map(
371 | partial(render_shape_and_faces, args=args), names,
372 | max_workers=args.num_cores, chunksize=args.num_chunks
373 | )
374 |
375 | prepare_splits(args)
376 |
377 |
378 | if __name__ == '__main__':
379 | parser = argparse.ArgumentParser()
380 | parser.add_argument('--root', type=str, default="./data",
381 | help='dataset root.')
382 | parser.add_argument('--id_list', type=str, default="None",
383 | help='optional JSON list of filtered (by similarity) data ids; a non-existent path (the default "None") is ignored.')
384 | parser.add_argument('--name', type=str, default=None,
385 | help='single file id to render; if omitted, the full pipeline runs over <root>/step.')
386 | parser.add_argument('--num_cores', type=int,
387 | default=5, help='number of worker processes.')
388 | parser.add_argument('--num_chunks', type=int,
389 | default=10, help='chunksize passed to process_map.')
390 | parser.add_argument('--width', type=int,
391 | default=256, help='svg width.')
392 | parser.add_argument('--height', type=int,
393 | default=256, help='svg height.')
394 | parser.add_argument('--png_padding', type=float,
395 | default=0.2, help='padding from content to the edge of png.')
396 | parser.add_argument('--tol', type=float,
397 | default=1e-4, help='svg discretization tolerance.')
398 | parser.add_argument('--face_shrink_scale', type=float,
399 | default=0.8, help='shrinking face for visualization.')
400 | parser.add_argument('--line_width', type=str,
401 | default=str(6/256), help='svg line width.')
402 | parser.add_argument('--filter_num_shapes', type=int,
403 | default=1, help='do not process step files \
404 | that have more than this number of shapes.')
405 | parser.add_argument('--filter_num_edges', type=int,
406 | default=64, help='do not process step files \
407 | that have more than this number of edges.')
408 | parser.add_argument('--location', nargs="+", type=float,
409 | default=[1, 1, 1], help='projection location')
410 | parser.add_argument('--direction', nargs="+", type=float,
411 | default=[1, 1, 1], help='projection direction')
412 | parser.add_argument('--focus', type=float,
413 | default=3, help='focus of the projection camera; 0 keeps an orthographic projection.')
414 | parser.add_argument('--split', nargs="+", type=float,
415 | default=[0.93, 0.02, 0.05],
416 | help='train/valid/test split ratios.')
417 | parser.add_argument('--only_split', action='store_true')
418 | parser.add_argument('--combine_coedge', action='store_true')
419 | parser.add_argument('--order_by_position', action='store_true')
420 | parser.add_argument('--direction_token', action='store_true')
421 | parser.add_argument('--random_camera', action='store_true')
422 | parser.add_argument('--discretize_last', action='store_true')
423 | parser.add_argument('--no_face_type', action='store_true')
424 | parser.add_argument('--seed', type=int, default=42,
425 | help='numpy random seed')
426 |
427 | args = parser.parse_args()
428 |
429 | if args.name is None:
430 | main(args)
431 | else:
432 | render_shape_and_faces(args.name, args)
433 |
--------------------------------------------------------------------------------