├── assets
│   └── teaser.gif
├── faceformer
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── data.py
│   │   └── data_para.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── model.py
│   │   └── model_para.py
│   ├── utils.py
│   ├── post_processing.py
│   ├── config.py
│   ├── embedding.py
│   ├── transformer.py
│   └── trainer.py
├── configs
│   ├── seq2seq.yml
│   ├── seq2seq+coedge.yml
│   ├── ours.yml
│   ├── ours-perspective.yml
│   └── ours-fixed_viewpoint.yml
├── CITATION.cff
├── environment.yml
├── LICENSE
├── split_jsons.py
├── dataset
│   ├── filters
│   │   ├── filter_length.py
│   │   ├── filter_thinness.py
│   │   ├── filter_topology.py
│   │   ├── filter_3view.py
│   │   ├── filter_thickness.py
│   │   └── 3view_render.py
│   ├── reorganize_dataset_dirs.py
│   ├── utils
│   │   ├── read_step_file.py
│   │   ├── Edge.py
│   │   ├── discretize_edge.py
│   │   ├── projection_utils.py
│   │   ├── Face.py
│   │   ├── json_to_svg.py
│   │   └── TopoMapper.py
│   ├── README.md
│   ├── tests
│   │   └── check_faces_enclosed.py
│   └── prepare_data.py
├── main.py
├── README.md
└── reconstruction
    ├── utils.py
    ├── reconstruction_utils.py
    └── reconstruct_to_wireframe.py
/assets/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/manycore-research/faceformer/HEAD/assets/teaser.gif
--------------------------------------------------------------------------------
/faceformer/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .data import ABCDataset
2 | from .data_para import ABCDataset_Parallel
--------------------------------------------------------------------------------
/faceformer/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import SurfaceFormer
2 | from .model_para import SurfaceFormer_Parallel
--------------------------------------------------------------------------------
/configs/seq2seq.yml:
--------------------------------------------------------------------------------
1 | # All faces as a seq, cylinder face cut into 2
2 | 
3 | # dataset generation method:
4 | # - prepare_data.py ... --combine_coedge --order_by_position --random_camera --focus 0
5 | 
6 | root_dir: "test_set/seq2seq"
7 | trainer:
8 |   name: 'SurfaceFormer'
9 |   version: 'seq2seq'
10 |   num_gpus: [0]
11 | 
12 | model:
13 |   num_lines: 110
14 |   label_seq_length: 259
15 |   max_num_faces: 42
16 | 
17 | post_process:
18 |   is_coedge: False
19 | 
--------------------------------------------------------------------------------
/configs/seq2seq+coedge.yml:
--------------------------------------------------------------------------------
1 | # All faces as a seq, cylinder face cut into 2,
2 | # ordered edges within each face, each edge considered twice.
3 | 
4 | # dataset generation method:
5 | # - prepare_data.py ... --no_face_type --random_camera --focus 0
6 | 
7 | root_dir: "test_set/seq2seq+coedge"
8 | trainer:
9 |   name: 'SurfaceFormer'
10 |   version: 'seq2seq+coedge'
11 |   num_gpus: [0]
12 | 
13 | model:
14 |   num_lines: 216
15 |   label_seq_length: 259
16 | 
17 | post_process:
18 |   is_coedge: True
--------------------------------------------------------------------------------
/configs/ours.yml:
--------------------------------------------------------------------------------
1 | # Each face as a seq
2 | 
3 | # dataset generation method:
4 | # python dataset/prepare_data.py ... 
--random_camera --focus 0
5 | 
6 | model_class: 'SurfaceFormer_Parallel'
7 | dataset_class: 'ABCDataset_Parallel'
8 | root_dir: "test_set/ours"
9 | 
10 | batch_size_train: 4
11 | batch_size_valid: 20
12 | 
13 | trainer:
14 |   name: 'SurfaceFormer'
15 |   version: 'ours'
16 |   lr: 1.0e-4
17 |   num_gpus: [0]
18 | 
19 | model:
20 |   num_lines: 216
21 |   max_num_faces: 42
22 |   max_face_length: 37
23 |   token:
24 |     PAD: 0
25 |     face_type_offset: 1
26 |     len: 4
27 | 
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | preferred-citation:
4 |   type: conference-paper
5 |   collection-type: proceedings
6 |   title: "Neural Face Identification in a 2D Wireframe Projection of a Manifold Object"
7 |   authors:
8 |     - family-names: Wang
9 |       given-names: Kehan
10 |     - family-names: Zheng
11 |       given-names: Jia
12 |     - family-names: Zhou
13 |       given-names: Zihan
14 |   collection-title: "Proceedings of IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)"
15 |   start: 1622
16 |   end: 1631
17 |   year: 2022
--------------------------------------------------------------------------------
/configs/ours-perspective.yml:
--------------------------------------------------------------------------------
1 | # Each face as a seq in perspective view
2 | 
3 | # dataset generation method:
4 | # python dataset/prepare_data.py ... --random_camera
5 | 
6 | model_class: 'SurfaceFormer_Parallel'
7 | dataset_class: 'ABCDataset_Parallel'
8 | root_dir: "test_set/ours-perspective"
9 | 
10 | batch_size_train: 4
11 | batch_size_valid: 20
12 | 
13 | trainer:
14 |   name: 'SurfaceFormer'
15 |   version: 'ours-perspective'
16 |   lr: 1.0e-4
17 |   num_gpus: [0]
18 | 
19 | model:
20 |   num_lines: 202
21 |   max_num_faces: 42
22 |   max_face_length: 38
23 |   token:
24 |     PAD: 0
25 |     face_type_offset: 1
26 |     len: 4
--------------------------------------------------------------------------------
/configs/ours-fixed_viewpoint.yml:
--------------------------------------------------------------------------------
1 | # Each face as a seq with fixed viewpoint
2 | 
3 | # dataset generation method:
4 | # python dataset/prepare_data.py ... 
--focus 0
5 | 
6 | model_class: 'SurfaceFormer_Parallel'
7 | dataset_class: 'ABCDataset_Parallel'
8 | root_dir: "test_set/ours-fixed_viewpoint"
9 | 
10 | batch_size_train: 4
11 | batch_size_valid: 20
12 | 
13 | trainer:
14 |   name: 'SurfaceFormer'
15 |   version: 'ours-fixed_viewpoint'
16 |   lr: 1.0e-4
17 |   num_gpus: [0]
18 | 
19 | model:
20 |   num_lines: 186
21 |   max_num_faces: 42
22 |   max_face_length: 33
23 |   token:
24 |     PAD: 0
25 |     face_type_offset: 1
26 |     len: 4
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: faceformer
2 | channels:
3 |   - pytorch
4 |   - conda-forge
5 |   - anaconda
6 |   - defaults
7 | dependencies:
8 |   - cudatoolkit=11.0.221
9 |   - jpeg=9b
10 |   - json5=0.9.6
11 |   - pillow=8.2.0
12 |   - python=3.7.10
13 |   - pythonocc-core=7.4.1
14 |   - pythreejs=2.3.0
15 |   - pytorch=1.7.1
16 |   - pip=21.1.3
17 |   - pip:
18 |     - async-timeout==3.0.1
19 |     - cairosvg==2.5.2
20 |     - cvxpy==1.1.17
21 |     - easydict==1.9
22 |     - h5py==3.1.0
23 |     - html4vision==0.4.3
24 |     - matplotlib==3.4.2
25 |     - numpy==1.19.5
26 |     - numpyencoder==0.3.0
27 |     - open3d==0.13.0
28 |     - opencv-python==4.5.2.54
29 |     - pytorch-lightning==1.3.5
30 |     - pyyaml==5.4.1
31 |     - scikit-learn==0.24.2
32 |     - scipy==1.6.3
33 |     - svgwrite==1.4.1
34 |     - timeout-decorator==0.5.0
35 |     - torchmetrics==0.3.2
36 |     - tqdm==4.61.1
37 |     - trimesh==3.9.20
38 |     - fvcore==0.1.5.post20210617
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2019 Manycore Tech Inc.
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/split_jsons.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import shutil
5 | 
6 | def prepare_splits(args):
7 |     names = []
8 |     os.makedirs(os.path.join(args.root, 'json'), exist_ok=True)
9 |     for name in sorted(os.listdir(args.root)):
10 |         # skip anything that is not a .json file (e.g. the 'json' folder created above)
11 |         if not name.endswith('.json'):
12 |             continue
13 |         names.append(name[:8])
14 |         shutil.move(os.path.join(args.root, name), os.path.join(args.root, "json"))
15 | 
16 |     np.random.seed(args.seed)
17 |     np.random.shuffle(names)
18 |     train_ratio, valid_ratio, test_ratio = args.split
19 |     trainlist, validlist, testlist = np.split(names, [int(
20 |         len(names) * train_ratio), int(len(names) * (train_ratio + valid_ratio))])
21 | 
22 |     np.savetxt(os.path.join(args.root, 'train.txt'), trainlist, fmt="json/%s.json")
23 |     np.savetxt(os.path.join(args.root, 'valid.txt'), validlist, fmt="json/%s.json")
24 |     np.savetxt(os.path.join(args.root, 'test.txt'), testlist, fmt="json/%s.json")
25 | 
26 | if __name__ == '__main__':
27 |     parser = argparse.ArgumentParser()
28 |     parser.add_argument('--root', type=str, default="./ours",
29 |                         help='dataset root; all files under root are .json files.')
30 |     parser.add_argument('--seed', type=int, default=42,
31 |                         help='numpy random seed')
32 |     parser.add_argument('--split', nargs="+", type=float,
33 |                         default=[0.93, 0.02, 0.05],
34 |                         help='train/valid/test split ratio')
35 | 
36 |     args = parser.parse_args()
37 | 
38 |     prepare_splits(args)
--------------------------------------------------------------------------------
/dataset/filters/filter_length.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | 
5 | from tqdm import tqdm
6 | 
7 | 
8 | def main(args):
9 |     if args.clean_start:
10 |         with open("dataset/dataset_gen_logs/filtered_id_list.json", 'r') as f:
11 |             names = json.load(f)
12 |     else:
13 |         names = [os.path.splitext(name)[0] for name in os.listdir(os.path.join(args.root, 'json'))]
14 | 
15 |     filtered_names = []
16 | 
17 |     for name in tqdm(names):
18 |         path = os.path.join(args.root, 'json', f'{name}.json')
19 |         with open(path, 'r') as f:
20 |             data = json.load(f)
21 |         total_len = 0
22 |         for face in data["faces_indices"]:
23 |             total_len += 1+len(face)
24 |         total_len += 1
25 |         if total_len < args.face_seq_max and len(data['edges']) < args.num_edge_max:
26 |             filtered_names.append(name)
27 | 
28 |     with open("dataset/dataset_gen_logs/filtered_id_list.json", 'w') as f:
29 |         json.dump(filtered_names, f)
30 | 
31 | 
32 | if __name__ == "__main__":
33 |     parser = argparse.ArgumentParser()
34 |     parser.add_argument('--root', type=str, default='/home/tianhan/data',
35 |                         help='dataset root')
36 |     parser.add_argument('--face_seq_max', type=int, default=128,
37 |                         help='max length for the constructed face label')
38 |     parser.add_argument('--num_edge_max', type=int, default=64,
39 |                         help='max number of edges in a shape')
40 |     parser.add_argument('--clean_start', action='store_true',
41 |                         help='start from the previously filtered id list')
42 |     args = parser.parse_args()
43 | 
44 |     main(args)
--------------------------------------------------------------------------------
/faceformer/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def info_value_of_dtype(dtype: torch.dtype):
5 |     """
6 |     Returns the `finfo` or `iinfo` object of a given PyTorch data type. Does not allow torch.bool.
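    For example, `info_value_of_dtype(torch.float32)` returns `torch.finfo(torch.float32)`, whose `.min`/`.max` bound the representable range.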
7 | """ 8 | if dtype == torch.bool: 9 | raise TypeError("Does not support torch.bool") 10 | elif dtype.is_floating_point: 11 | return torch.finfo(dtype) 12 | else: 13 | return torch.iinfo(dtype) 14 | 15 | 16 | def min_value_of_dtype(dtype: torch.dtype): 17 | """ 18 | Returns the minimum value of a given PyTorch data type. Does not allow torch.bool. 19 | """ 20 | return info_value_of_dtype(dtype).min 21 | 22 | 23 | def max_value_of_dtype(dtype: torch.dtype): 24 | """ 25 | Returns the maximum value of a given PyTorch data type. Does not allow torch.bool. 26 | """ 27 | return info_value_of_dtype(dtype).max 28 | 29 | 30 | def tiny_value_of_dtype(dtype: torch.dtype): 31 | """ 32 | Returns a moderately tiny value for a given PyTorch data type that is used to avoid numerical 33 | issues such as division by zero. 34 | This is different from `info_value_of_dtype(dtype).tiny` because it causes some NaN bugs. 35 | Only supports floating point dtypes. 36 | """ 37 | if not dtype.is_floating_point: 38 | raise TypeError("Only supports floating point dtypes.") 39 | if dtype == torch.float or dtype == torch.double: 40 | return 1e-13 41 | elif dtype == torch.half: 42 | return 1e-4 43 | else: 44 | raise TypeError("Does not support dtype " + str(dtype)) 45 | 46 | 47 | def flatten_list(l): 48 | """ 49 | Flattens a list of lists. 50 | """ 51 | return [item for sublist in l for item in sublist] 52 | -------------------------------------------------------------------------------- /dataset/reorganize_dataset_dirs.py: -------------------------------------------------------------------------------- 1 | import os, argparse, time 2 | from tqdm import tqdm 3 | 4 | 5 | def main(args): 6 | for name in tqdm(sorted(os.listdir(os.path.join(args.root, args.subdir)))): 7 | dirpath = os.path.join(args.root, args.subdir, name) 8 | # rename file name to 8 digits 9 | if not os.path.isdir(dirpath): 10 | index_name, suffix = os.path.splitext(name) 11 | if len(index_name) != 8: 12 | srcpath = os.path.join(args.root, args.subdir, name) 13 | dstpath = os.path.join(args.root, args.subdir, index_name[:8]+suffix) 14 | os.rename(srcpath, dstpath) 15 | continue 16 | # move file out from their individual folder 17 | filenames = os.listdir(dirpath) 18 | if len(filenames) == 0: 19 | os.rmdir(dirpath) 20 | continue 21 | filename = filenames[0] 22 | suffix = os.path.splitext(filename)[1] 23 | srcpath = os.path.join(args.root, args.subdir, name, filename) 24 | dstpath = os.path.join(args.root, args.subdir, name+suffix) 25 | dirpath = os.path.join(args.root, args.subdir, name) 26 | os.rename(srcpath, dstpath) 27 | os.rmdir(dirpath) 28 | with open('data_processing_log.txt', 'a') as f: 29 | f.write(f"Reorganized folder {args.subdir} to proper structure - " + time.ctime() + '\n') 30 | f.close() 31 | 32 | if __name__ == '__main__': 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--root', type=str, default="./data", 35 | help='dataset root.') 36 | parser.add_argument('--subdir', type=str, default="step", 37 | help='dataset sub-directory to be reorganized.') 38 | args = parser.parse_args() 39 | 40 | main(args) -------------------------------------------------------------------------------- /faceformer/post_processing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from dataset.tests.check_faces_enclosed import is_face_enclosed 4 | from faceformer.utils import flatten_list 5 | 6 | 7 | # For each face, if it is enclosed, sort its loops 8 | def 
filter_faces_by_encloseness(edges, faces, tol): 9 | # find corresponding edges from face indices 10 | filtered_faces = [] 11 | for face_type, face in faces: 12 | all_face_loops = is_face_enclosed(edges, face, tol) 13 | if all_face_loops: 14 | # roll enclosed loops so smallest index is at the front 15 | all_face_loops = [tuple(np.roll(loop, -np.argmin(loop), axis=0).astype(int).tolist()) for loop in all_face_loops] 16 | # loops are ordered by first index 17 | all_face_loops = sorted(all_face_loops, key=lambda x: x[0]) 18 | filtered_faces.append((face_type, tuple(all_face_loops))) 19 | 20 | return filtered_faces 21 | 22 | # Two coedges that represent the same edge should not be used in the same face 23 | def filter_faces_by_coedge(pairings, faces): 24 | filtered_faces = [] 25 | used_indices = set() 26 | for face in faces: 27 | indices = flatten_list(face[1]) 28 | drop_face = False 29 | for index in indices: 30 | if index in pairings: 31 | index = pairings[index] 32 | if index in used_indices: 33 | drop_face = True 34 | break 35 | used_indices.add(index) 36 | if not drop_face: 37 | filtered_faces.append(face) 38 | 39 | return filtered_faces 40 | 41 | def map_coedge_into_edges(pairings, indices): 42 | new_indices = [] 43 | for i in indices: 44 | if str(i) in pairings: 45 | new_indices.append(pairings[str(i)]) 46 | else: 47 | new_indices.append(i) 48 | return new_indices 49 | -------------------------------------------------------------------------------- /dataset/utils/read_step_file.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import timeout_decorator 4 | from OCC.Core.IFSelect import IFSelect_ItemsByEntity, IFSelect_RetDone 5 | from OCC.Core.STEPControl import STEPControl_Reader 6 | from OCC.Extend.TopologyUtils import list_of_shapes_to_compound 7 | 8 | 9 | @timeout_decorator.timeout(5, use_signals=False) 10 | def read_step_file(filename, as_compound=True, verbosity=True, filter_num_shape=10): 11 | """ read the STEP file and returns a compound and number of shapes 12 | filename: the file path 13 | verbosity: optional, False by default. 14 | as_compound: True by default. If there are more than one shape at root, 15 | gather all shapes into one compound. Otherwise returns a list of shapes. 16 | """ 17 | if not os.path.isfile(filename): 18 | raise FileNotFoundError("%s not found." 
% filename) 19 | 20 | step_reader = STEPControl_Reader() 21 | status = step_reader.ReadFile(filename) 22 | 23 | if status == IFSelect_RetDone: # check status 24 | if verbosity: 25 | failsonly = False 26 | step_reader.PrintCheckLoad(failsonly, IFSelect_ItemsByEntity) 27 | step_reader.PrintCheckTransfer(failsonly, IFSelect_ItemsByEntity) 28 | transfer_result = step_reader.TransferRoots() 29 | if not transfer_result: 30 | raise AssertionError("Transfer failed.") 31 | _nbs = step_reader.NbShapes() 32 | if _nbs == 0: 33 | raise AssertionError("No shape to transfer.") 34 | elif _nbs == 1: # most cases 35 | return step_reader.Shape(1), _nbs 36 | elif _nbs > 1: 37 | if _nbs > filter_num_shape: 38 | return None, _nbs 39 | shps = [] 40 | # loop over root shapes 41 | for k in range(1, _nbs + 1): 42 | new_shp = step_reader.Shape(k) 43 | if not new_shp.IsNull(): 44 | shps.append(new_shp) 45 | if as_compound: 46 | compound, result = list_of_shapes_to_compound(shps) 47 | if not result: 48 | print("Warning: all shapes were not added to the compound") 49 | return compound, _nbs 50 | else: 51 | print("Warning, returns a list of shapes.") 52 | return shps, _nbs 53 | else: 54 | raise AssertionError("Error: can't read file.") 55 | return None 56 | -------------------------------------------------------------------------------- /faceformer/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from fvcore.common.config import CfgNode 4 | 5 | CN = CfgNode 6 | 7 | _C = CN() 8 | _C.model_class = 'SurfaceFormer' 9 | _C.dataset_class = 'ABCDataset' 10 | _C.root_dir = "/root/data" 11 | 12 | _C.batch_size_train = 64 13 | _C.batch_size_valid = 128 14 | _C.datasets_train = ['train.txt'] 15 | _C.datasets_valid = ['valid.txt'] 16 | _C.datasets_test = ['test.txt'] 17 | 18 | _C.trainer = CN() 19 | _C.trainer.name = "surfaceformer" 20 | _C.trainer.version = "baseline" 21 | _C.trainer.num_gpus = [0] 22 | _C.trainer.precision = 16 # 16-bit training 23 | _C.trainer.checkpoint_period = 2 24 | _C.trainer.lr = 1e-3 25 | _C.trainer.lr_step = 0 26 | 27 | _C.model = CN() 28 | _C.model.num_points_per_line = 50 29 | _C.model.num_lines = 64 30 | _C.model.point_dim = 2 31 | _C.model.label_seq_length = 128 32 | _C.model.max_num_faces = 42 33 | _C.model.max_face_length = 34 34 | _C.model.num_model = 512 35 | _C.model.num_head = 8 36 | _C.model.num_feedforward = 1024 37 | _C.model.num_encoder_layers = 6 38 | _C.model.num_decoder_layers = 6 39 | _C.model.dropout = 0.2 40 | _C.model.token = CN() 41 | _C.model.token.PAD = 0 42 | _C.model.token.SOS = 1 43 | _C.model.token.SEP = 2 44 | _C.model.token.EOS = 3 45 | _C.model.token.DIR0 = 4 46 | _C.model.token.DIR1 = 5 47 | _C.model.token.len = 4 48 | _C.model.token.face_type_offset = 1 49 | 50 | _C.post_process = CN() 51 | _C.post_process.enclosedness_tol = 2e-4 52 | _C.post_process.is_coedge = True 53 | 54 | def get_parser(): 55 | parser = argparse.ArgumentParser(description="SurfaceFormer Training") 56 | parser.add_argument("--config-file", default="", metavar="FILE", 57 | help="path to config file") 58 | parser.add_argument("--valid_ckpt", default="", 59 | help="path to validation checkpoint") 60 | parser.add_argument("--test_ckpt", default="", 61 | help="path to testing checkpoint") 62 | parser.add_argument("--resume_ckpt", default="", 63 | help="path to training checkpoint, will continue train from here") 64 | parser.add_argument( 65 | "opts", 66 | help="Modify config options using the command-line", 67 | default=None, 68 | 
nargs=argparse.REMAINDER, 69 | ) 70 | return parser 71 | 72 | 73 | def get_cfg(args): 74 | cfg = _C.clone() 75 | if args.config_file: 76 | cfg.merge_from_file(args.config_file) 77 | cfg.merge_from_list(args.opts) 78 | cfg.freeze() 79 | return cfg 80 | -------------------------------------------------------------------------------- /dataset/utils/Edge.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class Edge: 4 | ''' 5 | Edge is unique by the edge's hash 6 | Each edge should have two faces 7 | ''' 8 | 9 | def __init__(self, edge, faces=[], orientations=[], dedge=None, index=None, DiscretizedEdge=None, dedge3d=None): 10 | self.edge = edge 11 | self.edges = [edge] 12 | self.faces = faces 13 | self.orientations = orientations 14 | self.dedge = dedge 15 | self.dedge3d = dedge3d 16 | self.index = index # index among all edges in TopoMapper, for construct faces 17 | self.DiscretizedEdge = DiscretizedEdge 18 | 19 | def add_face(self, face, orientation): 20 | self.faces.append(face) 21 | self.orientations.append(orientation) 22 | assert len(self.faces) <= 2, "Too many faces for one edge" 23 | 24 | def get_oriented_dedge(self, orientation, is_3d=False): 25 | ''' 26 | Discretized edge is saved as normal orientation. 27 | return reversed when orientation is reversed. 28 | ''' 29 | if is_3d: 30 | return self.dedge3d[::-1] if orientation else self.dedge3d 31 | return self.dedge[::-1] if orientation else self.dedge 32 | 33 | def __hash__(self): 34 | return hash(self.edge) 35 | 36 | def __eq__(self, other): 37 | return isinstance(other, Edge) and hash(self) == hash(other) 38 | 39 | def same_orientation(self, other): 40 | dist1 = np.sum(abs(np.array(self.dedge[-1]) - np.array(other.dedge[0]))) 41 | dist2 = np.sum(abs(np.array(other.dedge[-1]) - np.array(self.dedge[0]))) 42 | return dist1 < dist2 43 | 44 | def merge(self, other, topo): 45 | ''' 46 | Merge two edges, considering the faces being merged. 47 | Not assigning self.edge to None so hash of the edge is still available. 48 | Assuming the orientation of dedge is the same. 49 | ''' 50 | assert isinstance(other, Edge), 'Cannot merge edge with non-edge' 51 | # check orientation by looking at start and end of two edges 52 | if self.same_orientation(other): 53 | self.dedge = self.dedge + other.dedge 54 | self.edges = self.edges + other.edges 55 | else: 56 | self.dedge = other.dedge + self.dedge 57 | self.edges = other.edges + self.edges 58 | 59 | # remove other in its faces 60 | for face in other.faces: 61 | i = face.keys.index(hash(other.edge)) 62 | del face.edges[i] 63 | del face.edge_orientations[i] 64 | del face.keys[i] 65 | 66 | # remove other edge from topo 67 | del topo.all_edges[hash(other.edge)] 68 | return self 69 | -------------------------------------------------------------------------------- /dataset/utils/discretize_edge.py: -------------------------------------------------------------------------------- 1 | from functools import cmp_to_key 2 | 3 | import numpy as np 4 | 5 | 6 | class DiscretizedEdge: 7 | def __init__(self, points, smaller_edge=None, edge3d=None): 8 | self.points = points 9 | self.index = None 10 | self.smaller_edge = smaller_edge 11 | self.edge3d = edge3d 12 | 13 | def __eq__(self, obj): 14 | return isinstance(obj, DiscretizedEdge) and obj.points == self.points 15 | 16 | def correct_edge_direction(self, tolerance=1e-10): 17 | """ 18 | Given a discretized_edge 19 | Point edge in the direction of smaller coordinate to larger coordinate. 
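        E.g. an edge stored as [(1, 0), (0, 0)] is reversed to [(0, 0), (1, 0)], since points compare first by x and then by y.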
20 | """ 21 | if self.is_enclosed(tolerance): 22 | self.sort_enclosing_edge() 23 | else: 24 | if comp_points(self.points[0], self.points[-1]) > 0: 25 | # reverse edge 26 | self.points = list(reversed(self.points)) 27 | 28 | # check for enclosed polyline with tolerance 29 | def is_enclosed(self, tolerance): 30 | return abs(self.points[0][0] - self.points[-1][0]) < tolerance and \ 31 | abs(self.points[0][1] - self.points[-1][1]) < tolerance 32 | 33 | # rotate points for an enclosing edge 34 | def sort_enclosing_edge(self): 35 | # take out the repeating start/end 36 | enclosing_edge = self.points[1:] 37 | 38 | # find smallest starting point 39 | edge_array = np.array(enclosing_edge) 40 | d_edge = np.roll( 41 | edge_array, -np.argmin(edge_array[:, 0]), axis=0).tolist() 42 | 43 | # sort direction clock-wise by y-axis 44 | if d_edge[1][1] > d_edge[-1][1]: 45 | d_edge.append(d_edge[0]) 46 | else: 47 | d_edge = [d_edge[0]] + list(reversed(d_edge)) 48 | 49 | self.points = d_edge 50 | 51 | 52 | # rank coordinates first by x, then by y 53 | def comp_points(p1, p2): 54 | if p1[0] == p2[0]: 55 | return p1[1] - p2[1] 56 | return p1[0] - p2[0] 57 | 58 | # rank edges in sequence of the points 59 | # assuming edges themselves are sorted 60 | 61 | 62 | def comp_edges(e1, e2): 63 | e1, e2 = e1.points, e2.points 64 | N = min(len(e1), len(e2)) 65 | for i in range(N): 66 | diff = comp_points(e1[i], e2[i]) 67 | if diff == 0: 68 | continue 69 | return diff 70 | return 0 71 | 72 | 73 | def sort_edges_by_coordinate(edges): 74 | return sorted(edges, key=cmp_to_key(comp_edges)) 75 | 76 | 77 | def comp_face_by_index(f1, f2): 78 | N = min(len(f1), len(f2)) 79 | for i in range(N): 80 | diff = f1[i] - f2[i] 81 | if diff == 0: 82 | continue 83 | return diff 84 | return 0 85 | 86 | 87 | def sort_faces_by_indices(faces): 88 | return sorted(faces, key=cmp_to_key(comp_face_by_index)) 89 | 90 | -------------------------------------------------------------------------------- /dataset/filters/filter_thinness.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from functools import partial 5 | 6 | import numpy as np 7 | import trimesh 8 | import yaml 9 | from tqdm.contrib.concurrent import process_map 10 | 11 | 12 | def scale_to_unit_sphere(mesh): 13 | if isinstance(mesh, trimesh.Scene): 14 | mesh = mesh.dump().sum() 15 | 16 | vertices = mesh.vertices - mesh.bounding_box.centroid 17 | vertices *= 2 / np.linalg.norm(mesh.bounding_box.extents) 18 | 19 | return trimesh.Trimesh(vertices=vertices, faces=mesh.faces, process=False, maintain_order=True) 20 | 21 | 22 | def filter_by_raidus(name, args): 23 | mesh_path = os.path.join(args.root, 'obj', f'{name}.obj') 24 | mesh = trimesh.load_mesh(mesh_path, process=False, maintain_order=True) 25 | 26 | if isinstance(mesh, trimesh.Scene): 27 | mesh = mesh.dump().sum() 28 | 29 | scale = np.linalg.norm(mesh.bounding_box.extents) 30 | 31 | feat_path = os.path.join(args.root, 'feat', f'{name}.yml') 32 | with open(feat_path) as file: 33 | annos = yaml.full_load(file) 34 | 35 | radius_array = [] 36 | 37 | for curve in annos['curves']: 38 | # finder the case with thinner cylinder 39 | if curve['type'] in ['Circle']: 40 | radius = curve['radius'] / scale 41 | 42 | elif curve['type'] in ['Ellipse']: 43 | radius = min(curve['maj_radius'], curve['min_radius']) / scale 44 | 45 | else: 46 | continue 47 | 48 | radius_array.append(radius) 49 | 50 | if len(radius_array) != 0: 51 | with open(os.path.join(args.root, 
'radius', f'{name}.json'), 'w') as f: 52 | json.dump(min(radius_array), f) 53 | 54 | return name 55 | 56 | 57 | def main(args): 58 | with open(os.path.join(args.root, "meta", "filtered_thickness.json"), 'r') as f: 59 | names = json.load(f) 60 | 61 | os.makedirs(os.path.join(args.root, 'radius'), exist_ok=True) 62 | 63 | # preprocess 64 | rets = process_map( 65 | partial(filter_by_raidus, args=args), names, 66 | max_workers=args.num_cores, chunksize=args.num_chunks) 67 | 68 | filtered = [ret for ret in rets if ret is not None] 69 | 70 | with open(os.path.join(args.root, 'meta', 'filtered_thinness.json'), 'w') as f: 71 | json.dump(filtered, f) 72 | 73 | 74 | if __name__ == '__main__': 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument('--root', type=str, default='/root/Datasets/Faceformer', 77 | help='dataset root') 78 | parser.add_argument('--threshold', type=float, default=0.05, 79 | help='threshold for closer edge') 80 | parser.add_argument('--num_cores', type=int, 81 | default=8, help='number of processors.') 82 | parser.add_argument('--num_chunks', type=int, 83 | default=64, help='number of chunk.') 84 | args = parser.parse_args() 85 | 86 | main(args) 87 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import pytorch_lightning as pl 5 | import torch 6 | 7 | from faceformer.config import get_cfg, get_parser 8 | from faceformer.datasets import * 9 | from faceformer.models import * 10 | from faceformer.trainer import Trainer 11 | 12 | 13 | def str_to_class(classname): 14 | return getattr(sys.modules[__name__], classname) 15 | 16 | class CudaClearCacheCallback(pl.Callback): 17 | def on_train_start(self, trainer, pl_module): 18 | torch.cuda.empty_cache() 19 | def on_validation_start(self, trainer, pl_module): 20 | torch.cuda.empty_cache() 21 | def on_validation_end(self, trainer, pl_module): 22 | torch.cuda.empty_cache() 23 | 24 | if __name__ == "__main__": 25 | args = get_parser().parse_args() 26 | 27 | cfg = get_cfg(args) 28 | 29 | model_class = str_to_class(cfg.model_class) 30 | dataset_class = str_to_class(cfg.dataset_class) 31 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 32 | save_last=True, 33 | filename='{epoch:d}-{valid_precision:.2f}', 34 | save_top_k=2, 35 | monitor='valid_precision', 36 | mode='max', 37 | every_n_val_epochs=1) 38 | 39 | logger = pl.loggers.TensorBoardLogger('logs/', name=cfg.trainer.name, version=cfg.trainer.version) 40 | 41 | os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(c) for c in cfg.trainer.num_gpus]) 42 | gpus = list(range(len(cfg.trainer.num_gpus))) 43 | 44 | if args.test_ckpt != '': 45 | # Testing 46 | model = Trainer(cfg, model_class, dataset_class).load_from_checkpoint(args.test_ckpt, model_class=model_class, dataset_class=dataset_class) 47 | trainer = pl.Trainer( 48 | benchmark=True, 49 | gpus=gpus, 50 | precision=cfg.trainer.precision) 51 | trainer.test(model) 52 | elif args.valid_ckpt != '': 53 | # Validation 54 | model = Trainer(cfg, model_class, dataset_class).load_from_checkpoint(args.valid_ckpt, model_class=model_class, dataset_class=dataset_class) 55 | trainer = pl.Trainer( 56 | benchmark=True, 57 | gpus=gpus, 58 | precision=cfg.trainer.precision) 59 | trainer.validate(model) 60 | elif args.resume_ckpt != '': 61 | # Resume Training 62 | model = Trainer(cfg, model_class, dataset_class).load_from_checkpoint(args.resume_ckpt, model_class=model_class, 
dataset_class=dataset_class)
63 |         trainer = pl.Trainer(
64 |             logger=logger,
65 |             benchmark=True,
66 |             gpus=gpus,
67 |             precision=cfg.trainer.precision,
68 |             resume_from_checkpoint=args.resume_ckpt)
69 |         trainer.fit(model)
70 |     else:
71 |         model = Trainer(cfg, model_class, dataset_class)
72 |         trainer = pl.Trainer(
73 |             logger=logger,
74 |             callbacks=[checkpoint_callback, CudaClearCacheCallback()],
75 |             check_val_every_n_epoch=cfg.trainer.checkpoint_period,
76 |             log_every_n_steps=1,
77 |             benchmark=True,
78 |             gpus=gpus,
79 |             precision=cfg.trainer.precision)
80 |         trainer.fit(model)
81 | 
--------------------------------------------------------------------------------
/dataset/filters/filter_topology.py:
--------------------------------------------------------------------------------
1 | """
2 | Topology Filtering
3 | * Group all objects of similar topology in the same bin
4 | """
5 | 
6 | import argparse
7 | import json
8 | import os
9 | 
10 | import yaml
11 | from sklearn.neighbors import NearestNeighbors
12 | from tqdm import tqdm
13 | 
14 | types_of_curves = {
15 |     "Line": 0, "Circle": 1, "Ellipse": 2, "BSpline": 3, "Other": 4}
16 | types_of_surfs = {
17 |     "Plane": 0, "Cylinder": 1, "Cone": 2, "Sphere": 3, "Torus": 4,
18 |     "Revolution": 5, "Extrusion": 6, "BSpline": 7, "Other": 8}
19 | 
20 | 
21 | def main(args):
22 |     # all step files
23 |     names = []
24 |     for name in sorted(os.listdir(os.path.join(args.root, 'stat'))):
25 |         names.append(name[:8])
26 | 
27 |     if os.path.exists(args.error_log):
28 |         # remove shapes that give errors
29 |         with open(args.error_log, 'r') as f:
30 |             lines = f.read().splitlines()
31 | 
32 |         error_names = [line[:8] for line in lines if line[:8].isdigit()]
33 |         id_list = [name for name in names if name not in set(error_names)]
34 |     else:
35 |         print("error log not found")
36 |         id_list = names
37 | 
38 |     # gather topology info for each object as their feature
39 |     names = []
40 |     features = []
41 |     for name in tqdm(id_list):
42 |         names.append(name)
43 |         path = os.path.join(args.root, 'stat', f'{name}.yml')
44 |         with open(path, 'r') as f:
45 |             data = yaml.safe_load(f)
46 | 
47 |         curves = [types_of_curves[curve] for curve in data['curves']]
48 |         surfs = [types_of_surfs[surf] for surf in data['surfs']]
49 | 
50 |         curves_hist = [0] * len(types_of_curves)
51 |         for curve in curves:
52 |             curves_hist[curve] += 1
53 |         surfs_hist = [0] * len(types_of_surfs)
54 |         for surf in surfs:
55 |             surfs_hist[surf] += 1
56 | 
57 |         feature = [data['#edges'], data['#parts'], data['#sharp'],
58 |                    data['#surfs'], *curves_hist, *surfs_hist]
59 |         features.append(feature)
60 | 
61 |     # Use nearest neighbors to find clusterings
62 |     neigh = NearestNeighbors()
63 |     neigh.fit(features)
64 |     dist, indices = neigh.radius_neighbors(features, args.similarity_threshold)
65 |     bins = set([tuple(ind) for ind in indices])
66 |     list_names = []
67 |     for b in bins:
68 |         name_bin = [names[i] for i in b]
69 |         list_names.append(name_bin)
70 |     with open("dataset/dataset_gen_logs/topo_matching_bins.json", 'w') as f:
71 |         json.dump(list_names, f)
72 | 
73 | 
74 | if __name__ == '__main__':
75 |     parser = argparse.ArgumentParser()
76 |     parser.add_argument('--root', type=str, default="./data",
77 |                         help='dataset root.')
78 |     parser.add_argument('--error_log', type=str,
79 |                         default="dataset/dataset_gen_logs/error.txt",
80 |                         help='dataset generation error log.')
81 |     parser.add_argument('--similarity_threshold', type=float,
82 |                         default=0,
83 |                         help="grouping threshold for similarity")
84 |     args = parser.parse_args()
85 | 
86 |     main(args)
87 | 
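88 | # Usage (see dataset/README.md):
89 | #   python dataset/filters/filter_topology.py --root $ABC_ROOT_DIR
90 | # Each shape is featurized as [#edges, #parts, #sharp, #surfs] plus histograms
91 | # over its curve and surface types; shapes whose features lie within
92 | # --similarity_threshold of one another (via radius_neighbors) share a bin.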
-------------------------------------------------------------------------------- /dataset/README.md: -------------------------------------------------------------------------------- 1 | ## Our Dataset 2 | 3 | ### Annotation format 4 | 5 | We save the annotation in a JSON format, including 2D edge points and face loops by edge indices. 6 | 7 | ```json 8 | { 9 | 'edges': [ 10 | [...], # edge 1 11 | [...], # edge 2 12 | ... 13 | ], 14 | 'faces_indices': [ 15 | [...], # face 1 16 | [...], # face 2 17 | ... 18 | ], 19 | } 20 | ``` 21 | 22 | ## Prepare Dataset 23 | 24 | Here, we provide tools to filter and parse data in [ABC dataset](https://archive.nyu.edu/handle/2451/43778). Please download `step`, `stat`, `obj`, and `feat`. 25 | 26 | ## Reorganize ABC dataset directory 27 | 28 | Remove the middle level folder after unzipping ABC dataset. 29 | 30 | ```bash 31 | python reorganize_dataset_dirs.py --root $ABC_ROOT_DIR 32 | ``` 33 | 34 | ```bash 35 | # Original ABC Dataset Structure 36 | root 37 | └── step 38 | └──00000050 39 | └── 00000050.step 40 | # Reorganized ABC Dataset structure 41 | root 42 | └── step 43 | └── 00000050.step 44 | ``` 45 | 46 | ### Dataset Directory Structure 47 | 48 | ``` 49 | root 50 | ├── step 51 | │ └── 00000050.step 52 | ├── json 53 | │ └── 00000050.json 54 | ├── face_png 55 | │ └── 00000050_{face_index}.png 56 | ├── face_svg 57 | │ └── 00000050_{face_index}.svg 58 | ├── png 59 | │ └── 00000050.png 60 | └── svg 61 | └── 00000050.svg 62 | ``` 63 | 64 | ## Command Lines 65 | 66 | #### Data Generation 67 | 68 | In each model's [config](configs), we detail the specific options needed to generate dataset of the correct format. 69 | 70 | ```bash 71 | # parse the entire ABC dataset 72 | python dataset/prepare_data.py --root $ABC_ROOT_DIR --id_list dataset/dataset_gen_logs/filtered_id_list.json > dataset/dataset_gen_logs/error.txt 73 | # parse a specific object (for debugging a single data) 74 | python dataset/prepare_data.py --root $ABC_ROOT_DIR --name $8_DIGIT_ID 75 | ``` 76 | 77 | #### Dataset Filtering 78 | 79 | Filter ABC objects by similarity 80 | ```bash 81 | # 1. filter by topology similarity 82 | python dataset/filters/filter_topology.py --root $ABC_ROOT_DIR 83 | # 2. render the three views of the entire ABC dataset 84 | python dataset/filters/3view_render.py --root $ABC_ROOT_DIR --id_list dataset/dataset_gen_logs/filtered_id_list.json > dataset/dataset_gen_logs/3view_error.txt 85 | # 3. 
filter by three-view similarity 86 | python dataset/filters/filter_3view.py --root $ABC_ROOT_DIR 87 | ``` 88 | 89 | Filter ABC objects by thickness 90 | ```bash 91 | python dataset/filters/filter_thickness.py --root $ABC_ROOT_DIR --save_root $DIR_FOR_TEMP_DATA 92 | ``` 93 | 94 | Filter ABC objects by complexity 95 | ```bash 96 | # By default, $MAX_FACE_SEQ = 128, $MAX_NUM_EDGE = 64 97 | python dataset/filters/filter_length.py --root $ABC_ROOT_DIR --face_seq_max $MAX_FACE_SEQ --num_edge_max $MAX_NUM_EDGE 98 | ``` 99 | 100 | Filter Generated Co-edge Data by Face Encloseness 101 | ```bash 102 | # Assume prepare_data.py has finished and all json files have generated 103 | python dataset/tests/check_faces_enclosed.py --root $ABC_ROOT_DIR --tol 1e-4 --remove 104 | # Regenerate train/valid/test splits 105 | python dataset/prepare_data.py --root $ABC_ROOT_DIR --only_split 106 | ``` -------------------------------------------------------------------------------- /dataset/tests/check_faces_enclosed.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Dataset Generation Integrity Test 3 | - Check if each face is enclosed 4 | ''' 5 | 6 | import json, os, argparse 7 | from functools import partial 8 | from tqdm.contrib.concurrent import process_map 9 | 10 | # check if e1's end meets e2's start 11 | def e1_connects_e2(e1, e2, tol): 12 | return abs(e1[-1][0] - e2[0][0]) < tol and \ 13 | abs(e1[-1][1] - e2[0][1]) < tol 14 | 15 | # face can be composed of multiple enclosed loops 16 | # check if all oriented-edges form loops 17 | # return all loops in list of lists if face is enclosed 18 | def is_face_enclosed(edges, face_indices, tol): 19 | all_loops = [] 20 | curr_loop = [] 21 | to_close = None # the start of an enclosed cycle, to be closed 22 | last_edge = None 23 | for ind in face_indices: 24 | if isinstance(ind, tuple): 25 | i, o = ind 26 | edge = edges[i][::-1] if o else edges[i] 27 | else: 28 | if ind < len(edges): 29 | edge = edges[ind] 30 | else: 31 | continue 32 | if to_close is None: 33 | to_close = edge 34 | else: 35 | # make sure the current edge connects to the last edge 36 | if not e1_connects_e2(last_edge, edge, tol): 37 | return False 38 | 39 | last_edge = edge 40 | curr_loop.append(ind) 41 | if e1_connects_e2(edge, to_close, tol): 42 | # close the current cycle 43 | to_close = None 44 | all_loops.append(curr_loop) 45 | curr_loop = [] 46 | return all_loops if to_close is None else False 47 | 48 | def check_enclosed(name, args): 49 | path = os.path.join(args.root, 'json', f'{name}.json') 50 | with open(path, 'r') as f: 51 | data = json.load(f) 52 | edges = data['edges'] 53 | faces_indices = data['faces_indices'] 54 | 55 | for face_indices in faces_indices: 56 | if not is_face_enclosed(edges, face_indices, args.tol): 57 | if args.remove: 58 | # remove json from dataset 59 | os.remove(path) 60 | print(f"{name} contains unclosed face") 61 | return 62 | 63 | def main(args): 64 | names = [] 65 | for name in sorted(os.listdir(os.path.join(args.root, 'json'))): 66 | names.append(name[:8]) 67 | 68 | process_map( 69 | partial(check_enclosed, args=args), names, 70 | max_workers=args.num_cores, chunksize=args.num_chunks 71 | ) 72 | 73 | 74 | if __name__ == '__main__': 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument('--root', type=str, default="./data", 77 | help='dataset root.') 78 | parser.add_argument('--name', type=str, default=None, 79 | help='filename.') 80 | # default to 3e-4 since the discretization tolerance is 1e-4 81 | 
parser.add_argument('--tol', type=float, 82 | default=3e-4, help='same point tolerance.') 83 | parser.add_argument('--num_cores', type=int, 84 | default=40, help='number of processors.') 85 | parser.add_argument('--num_chunks', type=int, 86 | default=10, help='number of chunk.') 87 | parser.add_argument('--remove', action='store_true') 88 | 89 | args = parser.parse_args() 90 | 91 | if args.name is None: 92 | main(args) 93 | else: 94 | check_enclosed(args.name, args) -------------------------------------------------------------------------------- /dataset/utils/projection_utils.py: -------------------------------------------------------------------------------- 1 | from OCC.Core.gp import gp_Ax2, gp_Dir, gp_Pnt 2 | from OCC.Core.HLRAlgo import HLRAlgo_Projector 3 | from OCC.Core.HLRBRep import HLRBRep_Algo, HLRBRep_HLRToShape 4 | from OCC.Extend.TopologyUtils import TopologyExplorer, discretize_edge 5 | import numpy as np 6 | 7 | def randnum(low, high): 8 | return np.random.rand() * (high - low) + low 9 | 10 | # generate a random camera 11 | def generate_random_camera_pos(seed): 12 | np.random.seed(seed) 13 | focus = randnum(3, 5) 14 | radius = randnum(1.25, 1.5) # distance of camera to origin 15 | phi = randnum(22.5, 67.5) # longitude, elevation of camera 16 | theta = randnum(0, 360) # latitude, rotation around z-axis 17 | return focus, pose_spherical(theta, phi, radius) 18 | 19 | def pose_spherical(theta, phi, radius): 20 | def trans_t(t): return np.array([ 21 | [1, 0, 0, 0], 22 | [0, 1, 0, 0], 23 | [0, 0, 1, t], 24 | [0, 0, 0, 1], 25 | ], dtype=np.float32) 26 | 27 | def rot_phi(phi): return np.array([ 28 | [1, 0, 0, 0], 29 | [0, np.cos(phi), -np.sin(phi), 0], 30 | [0, np.sin(phi), np.cos(phi), 0], 31 | [0, 0, 0, 1], 32 | ], dtype=np.float32) 33 | 34 | def rot_theta(th): return np.array([ 35 | [np.cos(th), -np.sin(th), 0, 0], 36 | [np.sin(th), np.cos(th), 0, 0], 37 | [0, 0, 1, 0], 38 | [0, 0, 0, 1], 39 | ], dtype=np.float32) 40 | 41 | c2w = trans_t(radius) 42 | c2w = rot_phi(np.deg2rad(phi)) @ c2w 43 | c2w = rot_theta(np.deg2rad(theta)) @ c2w 44 | c2w = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) @ c2w 45 | return c2w 46 | 47 | 48 | 49 | def project_shapes(shapes, args): 50 | location = args.location 51 | direction = args.direction 52 | focus = args.focus 53 | 54 | hlr = HLRBRep_Algo() 55 | 56 | if isinstance(shapes, list): 57 | for shape in shapes: 58 | hlr.Add(shape) 59 | else: 60 | hlr.Add(shapes) 61 | ax = gp_Ax2(gp_Pnt(*location), gp_Dir(*direction)) 62 | 63 | if args.pose is not None: 64 | pose = args.pose 65 | ax = gp_Ax2(gp_Pnt(*pose[:3, -1]), gp_Dir(*pose[:3, -2]), gp_Dir(*pose[:3, 0])) 66 | 67 | if focus == 0: 68 | projector = HLRAlgo_Projector(ax) 69 | else: 70 | projector = HLRAlgo_Projector(ax, focus) 71 | 72 | hlr.Projector(projector) 73 | hlr.Update() 74 | 75 | hlr_shapes = HLRBRep_HLRToShape(hlr) 76 | return hlr_shapes 77 | 78 | 79 | def discretize_compound(compound, tol): 80 | """ 81 | Given a compound of edges 82 | Return all edges discretized 83 | """ 84 | return [d3_to_d2(discretize_edge(edge, tol)) for edge in list(TopologyExplorer(compound).edges())] 85 | 86 | 87 | def d3_to_d2(points_3d): 88 | return [tuple(p[:2]) for p in points_3d] 89 | 90 | 91 | # project a list of 3D points 92 | def project_points(points, args): 93 | location = args.location 94 | direction = args.direction 95 | focus = args.focus 96 | 97 | ax = gp_Ax2(gp_Pnt(*location), gp_Dir(*direction)) 98 | 99 | if args.pose is not None: 100 | pose = args.pose 101 | ax = 
gp_Ax2(gp_Pnt(*pose[:3, -1]), gp_Dir(*pose[:3, -2]), gp_Dir(*pose[:3, 0])) 102 | 103 | if focus == 0: 104 | projector = HLRAlgo_Projector(ax) 105 | else: 106 | projector = HLRAlgo_Projector(ax, focus) 107 | 108 | projected = [projector.Project(gp_Pnt(*p)) for p in points] 109 | 110 | return projected -------------------------------------------------------------------------------- /dataset/filters/filter_3view.py: -------------------------------------------------------------------------------- 1 | """ 2 | Three view Filtering 3 | * Further divide each bin after topology filtering 4 | * By the similarity in three-view line drawings 5 | """ 6 | 7 | import argparse 8 | import json 9 | import os 10 | 11 | import cv2 12 | import numpy as np 13 | from sklearn.cluster import AgglomerativeClustering 14 | from sklearn.metrics import pairwise_distances 15 | from tqdm import tqdm 16 | 17 | 18 | def main(args): 19 | # topology bins 20 | with open("dataset/dataset_gen_logs/topo_matching_bins.json", 'r') as f: 21 | list_names = json.load(f) 22 | 23 | # bins of more than 1 objects 24 | multi_bins = [b for b in list_names if len(b) > 1] 25 | 26 | # 3view error list 27 | with open(args.error_log, 'r') as f: 28 | lines = f.read().splitlines() 29 | error_names = [line[:8] for line in lines if line[:8].isdigit()] 30 | error_names = set(error_names) 31 | 32 | # remove error objects from topology bins 33 | filtered_multi_bins = [] 34 | new_bins = [] 35 | for b in multi_bins: 36 | new_bin = [name for name in b if name not in error_names] 37 | if len(new_bin) == 0: 38 | continue 39 | if len(new_bin) == 1: 40 | new_bins.append(new_bin) 41 | else: 42 | filtered_multi_bins.append(new_bin) 43 | 44 | # cluster by jaccard distance in three views 45 | for large_bin in tqdm(filtered_multi_bins): 46 | all_bin_imgs = [] 47 | for name in large_bin: 48 | feature = [] 49 | for i in range(1, 4): 50 | img_path = os.path.join( 51 | args.root, '3view_png', f'{name}-{i}.png') 52 | originalImage = cv2.imread(img_path) 53 | if originalImage is None: 54 | feature.append(np.ones(128*128)) 55 | continue 56 | halfImage = cv2.resize(originalImage, (0, 0), fx=0.5, fy=0.5) 57 | grayImage = cv2.cvtColor(halfImage, cv2.COLOR_BGR2GRAY) 58 | thresh, bin_img = cv2.threshold( 59 | grayImage, 254, 255, cv2.THRESH_BINARY) 60 | feature.append(bin_img) 61 | all_bin_imgs.append(np.array(feature).flatten()) 62 | 63 | X = np.array(all_bin_imgs) == 0 64 | 65 | dist_mat = pairwise_distances(X, metric='jaccard') 66 | clusters = AgglomerativeClustering(n_clusters=None, affinity='precomputed', 67 | distance_threshold=args.similarity_threshold, linkage='single').fit(dist_mat) 68 | classes = clusters.labels_ 69 | new_bin = [[] for _ in range(max(classes)+1)] 70 | for name, c in zip(large_bin, classes): 71 | new_bin[c].append(name) 72 | new_bins += new_bin 73 | 74 | # add bins of single object back 75 | for b in list_names: 76 | if len(b) == 1: 77 | new_bins.append(b) 78 | 79 | # generate a list of valid, unique objects 80 | # always sample the smallest from the bin 81 | extracted_names = sorted([min(b, key=lambda s: int(s)) for b in new_bins]) 82 | 83 | with open("dataset/dataset_gen_logs/filtered_id_list.json", 'w') as f: 84 | json.dump(extracted_names, f) 85 | 86 | 87 | if __name__ == '__main__': 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument('--root', type=str, default="/root/data", 90 | help='dataset root.') 91 | parser.add_argument('--error_log', type=str, 92 | default="dataset/dataset_gen_logs/3view_error.txt", 93 | help='3 view 
rendering error log.')
94 |     parser.add_argument('--similarity_threshold', type=float, default=0.1,
95 |                         help="grouping threshold for Jaccard distance")
96 |     args = parser.parse_args()
97 | 
98 |     main(args)
99 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | # Neural Face Identification in a 2D Wireframe Projection of a Manifold Object
4 | 
5 | 
6 | Kehan Wang
7 | ·
8 | Jia Zheng
9 | ·
10 | Zihan Zhou
11 | 
12 | 
13 | 
14 | IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) 2022
15 | 
16 | 
17 | [![arXiv](http://img.shields.io/badge/arXiv-2203.04229-B31B1B.svg)](https://arxiv.org/abs/2203.04229)
18 | [![Conference](https://img.shields.io/badge/CVPR-2022-4b44ce.svg)](https://openaccess.thecvf.com/content/CVPR2022/html/Wang_Neural_Face_Identification_in_a_2D_Wireframe_Projection_of_a_CVPR_2022_paper.html)
19 | 
20 | ![teaser](assets/teaser.gif)
21 | 
22 | 
23 | 
24 | ## Requirements
25 | 
26 | ```bash
27 | conda env create --file environment.yml
28 | conda activate faceformer
29 | ```
30 | 
31 | ## Download Dataset
32 | 
33 | We use CAD mechanical models from [ABC dataset](https://archive.nyu.edu/handle/2451/43778). In order to reproduce our results, we also release the dataset used in the paper [here](https://drive.google.com/drive/u/2/folders/1ynMD02E5FWlCPmQkWyjHdq4Zhe8DIXE2). If you would like to build the dataset by yourself, please refer to [here](dataset/README.md).
34 | 
35 | ## Evaluation
36 | 
37 | ### Face Identification Model
38 | Trained models can be downloaded [here](https://drive.google.com/drive/u/2/folders/1oEoN_GzS36obLjvOlwFrOpWo0N7oh-fS).
39 | ```bash
40 | python main.py --config-file configs/{MODEL_NAME}.yml --test_ckpt trained_models/{MODEL_NAME}.ckpt
41 | ```
42 | 
43 | Face predictions will be saved to `lightning_logs/version_{LATEST}/json`.
44 | 
45 | ### 3D Reconstruction
46 | 
47 | ```bash
48 | # wireframe reconstruction
49 | python reconstruction/reconstruct_to_wireframe.py --root lightning_logs/version_{LATEST}
50 | # surface reconstruction
51 | python reconstruction/reconstruct_to_mesh.py --root lightning_logs/version_{LATEST}
52 | ```
53 | 
54 | Reconstructed wireframe (*.ply*) or mesh (*.obj*) files will be saved to `lightning_logs/version_{LATEST}/{ply,obj}`.
55 | 
56 | ## Train a Model from Scratch
57 | 
58 | ```bash
59 | python main.py --config-file configs/{MODEL_NAME}.yml
60 | ```
61 | 
62 | ## FAQs
63 | 
64 | - *Why does root_dir not update when I change it in configs/ours.yml?*
65 |   When pytorch_lightning loads the checkpoint, it also restores the old root_dir that the model was trained with.
66 |   To fix: uncomment line 25 of faceformer/trainer.py and set the desired root_dir there.
67 | 
68 | - *How should I use the downloaded json dataset?*
69 |   Assuming we have downloaded *data_ours.tar.gz* and unzipped it to the same directory as [split_jsons.py](https://drive.google.com/drive/folders/1ynMD02E5FWlCPmQkWyjHdq4Zhe8DIXE2) in the outer-most directory, we now have:
70 | 
71 |   ```
72 |   root
73 |   ├── main.py
74 |   ├── split_jsons.py
75 |   └── ours
76 |       ├── 00000050.json
77 |       ├── 00000052.json
78 |       └── ...
79 |   ```
80 | 
81 |   Run `python split_jsons.py` and it should prepare the dataset into the following:
82 | 
83 |   ```
84 |   root
85 |   ├── main.py
86 |   ├── split_jsons.py
87 |   └── ours
88 |       ├── test.txt
89 |       ├── train.txt
90 |       ├── valid.txt
91 |       └── json
92 |           ├── 00000050.json
93 |           ├── 00000052.json
94 |           └── ...
95 |   ```
96 | 
97 |   With this, set the root_dir to "ours" at line 25 of faceformer/trainer.py, and
98 |   ```
99 |   python main.py --config-file configs/ours.yml --test_ckpt trained_models/ours.ckpt
100 |   ```
101 |   should work.
102 | 
103 | 
104 | 
105 | ## Acknowledgement
106 | 
107 | The work was done during Kehan Wang's internship at Manycore Tech Inc.
108 | -------------------------------------------------------------------------------- /faceformer/embedding.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class VanillaEmedding(nn.Module): 8 | def __init__(self, input_dim, num_model, token): 9 | super(VanillaEmedding, self).__init__() 10 | self.num_tokens = token.len 11 | 12 | # embedding for special tokens 13 | self.embedding_token = nn.Embedding(self.num_tokens, num_model) 14 | self.embedding_value = nn.Sequential( 15 | nn.Linear(input_dim, num_model), 16 | nn.ReLU(), 17 | nn.Linear(num_model, num_model) 18 | ) 19 | 20 | def embed_points(self, lines): 21 | return lines.flatten(-2, -1) 22 | 23 | def forward(self, coord): 24 | """ 25 | coord: N x L x P x D, N batch size, L num_edges/lines, P num_points, D num_axes 26 | E: num_model, model input dimension 27 | """ 28 | N = coord.size(0) 29 | 30 | token = torch.arange(self.num_tokens, dtype=torch.long).to(coord.device) 31 | token_embed = self.embedding_token(token) 32 | token_embed = token_embed.unsqueeze(0).expand(N, self.num_tokens, -1) # N x 4 x E 33 | 34 | coord_embed = self.embedding_value(self.embed_points(coord)) # N x L x E 35 | 36 | value_embed = torch.cat((token_embed, coord_embed), dim=1) # N x (4+L) x E 37 | 38 | return value_embed 39 | 40 | 41 | class CoordinateEmbedding(nn.Module): 42 | def __init__(self, num_axes, num_bits, num_embed, num_model, dependent_embed=False): 43 | super(CoordinateEmbedding, self).__init__() 44 | 45 | ntoken = 2**num_bits if dependent_embed else 2**num_bits * num_axes 46 | 47 | # embedding 48 | self.embedding_token = nn.Embedding(3, num_model) 49 | self.embedding_value = nn.Embedding(ntoken, num_embed) 50 | self.linear_proj = nn.Linear(num_axes*num_embed, num_model, bias=False) 51 | 52 | def forward(self, coord): 53 | N, S, _ = coord.shape 54 | 55 | # N x S x E 56 | token = torch.arange(3, dtype=torch.long).to(coord.device) 57 | token = token.unsqueeze(0).expand(N, -1) 58 | 59 | token_embed = self.embedding_token(token) 60 | coord_embed = self.embedding_value(coord) 61 | coord_embed = self.linear_proj(coord_embed.view(N, S, -1)) 62 | 63 | value_embed = torch.cat((token_embed, coord_embed), dim=1) 64 | 65 | return value_embed 66 | 67 | 68 | class PositionalEncoding(nn.Module): 69 | """ 70 | This is a more standard version of the position embedding 71 | used by the Attention is all you need paper. 72 | """ 73 | 74 | def __init__(self, num_model, max_len=5000): 75 | super(PositionalEncoding, self).__init__() 76 | 77 | pe = torch.zeros(max_len, num_model) 78 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 79 | div_term = torch.exp(torch.arange( 80 | 0, num_model, 2).float() * (-math.log(10000.0) / num_model)) 81 | pe[:, 0::2] = torch.sin(position * div_term) 82 | pe[:, 1::2] = torch.cos(position * div_term) 83 | pe = pe.unsqueeze(0) 84 | self.register_buffer('pe', pe) 85 | 86 | def forward(self, x): 87 | return self.pe[:, :x.size(1)] # N x S x E 88 | 89 | 90 | class PositionEmbeddingLearned(nn.Module): 91 | """ 92 | Absolute pos embedding, learned. 
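    Position indices 0..max_len-1 are looked up in an `nn.Embedding` table that is trained with the model (Kaiming-normal initialized).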
93 | """ 94 | 95 | def __init__(self, num_model, max_len=5000): 96 | super().__init__() 97 | 98 | position = torch.arange(0, max_len, dtype=torch.long).unsqueeze(0) 99 | self.register_buffer('position', position) 100 | self.pos_embed = nn.Embedding(max_len, num_model) 101 | self._init_embeddings() 102 | 103 | def _init_embeddings(self): 104 | nn.init.kaiming_normal_(self.pos_embed.weight, mode="fan_in") 105 | 106 | def forward(self, x): 107 | pos = self.position[:, :x.size(1)] # N x L x E 108 | return self.pos_embed(pos) 109 | -------------------------------------------------------------------------------- /faceformer/datasets/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from faceformer.utils import flatten_list 8 | 9 | 10 | # TODO: try using bilinear interpolation for all cases 11 | def sample_points(edge, num_samples=50): 12 | if len(edge) == 2: 13 | return sample_points_on_line(edge, num_samples) 14 | return sample_points_on_curve(edge, num_samples) 15 | 16 | 17 | def sample_points_on_line(line, num_samples): 18 | t = np.linspace(0, 1, num_samples) 19 | x1, y1, x2, y2 = line[0][0], line[0][1], line[1][0], line[1][1] 20 | x = x1 + (x2-x1) * t 21 | y = y1 + (y2-y1) * t 22 | return np.vstack([x, y]).T 23 | 24 | 25 | def sample_points_on_curve(curve, num_samples): 26 | samples = np.linspace(0, len(curve)-1, num_samples).round(0).astype(int) 27 | curve = np.array(curve) 28 | return curve[samples] 29 | 30 | 31 | class ABCDataset(torch.utils.data.Dataset): 32 | 33 | def __init__(self, root_dir, datafile_path, config): 34 | super(ABCDataset, self).__init__() 35 | self.root_dir = root_dir 36 | self.info_files = self.parse_splits_list(datafile_path) 37 | 38 | # input shape L x P x D 39 | self.num_points_per_line = config.num_points_per_line # P 40 | self.num_lines = config.num_lines # L 41 | self.point_dim = config.point_dim # D 42 | # output shape S 43 | self.label_seq_length = config.label_seq_length 44 | 45 | self.token = config.token 46 | 47 | # preload all files 48 | self.raw_datas = [] 49 | for info_file in self.info_files: 50 | with open(os.path.join(self.root_dir, info_file), "r") as f: 51 | self.raw_datas.append(json.loads(f.read())) 52 | 53 | 54 | def __len__(self): 55 | return len(self.info_files) 56 | 57 | def __getitem__(self, index): 58 | raw_data = self.raw_datas[index] 59 | 60 | edges, faces_indices = raw_data['edges'], raw_data['faces_indices'] 61 | 62 | input = np.zeros( 63 | (self.num_lines, self.num_points_per_line, self.point_dim), dtype=np.float32) 64 | for i, edge in enumerate(edges): 65 | input[i, :self.num_points_per_line] = sample_points( 66 | edge, self.num_points_per_line) 67 | 68 | input_mask = np.ones(self.num_lines, dtype=np.bool) 69 | input_mask[:len(edges)] = 0 70 | 71 | label = np.ones(self.label_seq_length, dtype=np.int) * self.token.PAD 72 | label[0] = self.token.SOS 73 | curr_pos = 0 74 | for face in faces_indices: 75 | if not isinstance(face[0], int): 76 | face = flatten_list(face) 77 | curr_pos += 1 78 | label[curr_pos:curr_pos+len(face)] = face 79 | # shift face indices for special tokens 80 | label[curr_pos:curr_pos+len(face)] += self.token.len 81 | curr_pos += len(face) 82 | label[curr_pos] = self.token.SEP 83 | label[curr_pos] = self.token.EOS 84 | label_mask = (label == self.token.PAD) 85 | 86 | data = { 87 | 'id': index, 88 | 'input': input, 89 | 'label': label, 90 | 'num_input': len(edges), 91 | 'num_label': curr_pos+1, 92 | 
'input_mask': input_mask, 93 | 'label_mask': label_mask, 94 | 'name': self.info_files[index] 95 | } 96 | 97 | return data 98 | 99 | def parse_splits_list(self, splits): 100 | """ Returns a list of info_file paths 101 | Args: 102 | splits (list of strings): each item is a path to a .json data file 103 | or a path to a .txt file containing a list of .json's relative paths from root. 104 | """ 105 | if isinstance(splits, str): 106 | splits = splits.split() 107 | info_files = [] 108 | for split in splits: 109 | ext = os.path.splitext(split)[1] 110 | split_path = os.path.join(self.root_dir, split) 111 | # split_path = os.path.join("/root/ablation/polys-test", split) 112 | if ext == '.json': 113 | info_files.append(split_path) 114 | elif ext == '.txt': 115 | info_files += [info_file.rstrip() for info_file in open(split_path, 'r')] 116 | else: 117 | raise NotImplementedError('%s not a valid info_file type' % split) 118 | return info_files 119 | -------------------------------------------------------------------------------- /dataset/utils/Face.py: -------------------------------------------------------------------------------- 1 | from OCC.Core.GeomAbs import GeomAbs_Cylinder, GeomAbs_Plane 2 | import numpy as np 3 | from OCC.Core.BRepAdaptor import BRepAdaptor_Surface 4 | 5 | class Face: 6 | ''' 7 | Face is a collection of non-repeating edges 8 | ''' 9 | def __init__(self, face, topo): 10 | surface = BRepAdaptor_Surface(face) 11 | self.face = face 12 | self.face_type = surface.GetType() 13 | self.topo = topo 14 | self.edges = [] 15 | self.edge_orientations = [] 16 | self.keys = [] 17 | 18 | # get face parametric values 19 | if self.face_type == GeomAbs_Plane: 20 | plane = surface.Surface().Plane() 21 | # plane parameters: Location, XAxis, YAxis, ZAxis, Coefficients 22 | self.parameters = {'Location': self._get_vector_parameters(plane.Location()), 23 | 'XAxis': self._get_axis_parameters(plane.XAxis()), 24 | 'YAxis': self._get_axis_parameters(plane.YAxis()), 25 | 'Normal': self._get_axis_parameters(plane.Axis()), 26 | 'Coefficients': plane.Coefficients()} 27 | elif self.face_type == GeomAbs_Cylinder: 28 | cylinder = surface.Surface().Cylinder() 29 | # cylinder parameters: Location, XAxis, YAxis, ZAxis, Coefficients, Radius 30 | self.parameters = {'Location': self._get_vector_parameters(cylinder.Location()), 31 | 'XAxis': self._get_axis_parameters(cylinder.XAxis()), 32 | 'YAxis': self._get_axis_parameters(cylinder.YAxis()), 33 | 'Normal': self._get_axis_parameters(cylinder.Axis()), 34 | 'Coefficients': cylinder.Coefficients(), 35 | 'Radius': cylinder.Radius()} 36 | else: 37 | self.parameters = None 38 | 39 | # Given an OCC vector 40 | # Return XYZ 41 | def _get_vector_parameters(self, vector): 42 | return vector.X(), vector.Y(), vector.Z() 43 | 44 | # Given an axis 45 | # Return Location(XYZ), Direction(XYZ) 46 | def _get_axis_parameters(self, axis): 47 | location = self._get_vector_parameters(axis.Location()) 48 | direction = self._get_vector_parameters(axis.Direction()) 49 | return location, direction 50 | 51 | def add_edge(self, edge, orientation): 52 | self.edges.append(edge) 53 | self.edge_orientations.append(orientation) 54 | self.keys.append(hash(edge)) 55 | 56 | def remove_edge(self, key): 57 | ind = self.keys.index(key) 58 | del self.keys[ind] 59 | del self.edges[ind] 60 | del self.edge_orientations[ind] 61 | 62 | def get_oriented_dedges(self, is_3d=False): 63 | return [e.get_oriented_dedge(o, is_3d) for e, o in zip(self.edges, self.edge_orientations)] 64 | 65 | def 
get_edge_ind_and_orientation(self): 66 | return [(e.index, o) for e, o in zip(self.edges, self.edge_orientations)] 67 | 68 | def roll(self, n): 69 | self.edges = np.roll(self.edges, -n, axis=0).tolist() 70 | self.edge_orientations = np.roll(self.edge_orientations, -n, axis=0).tolist() 71 | self.keys = np.roll(self.keys, -n, axis=0).tolist() 72 | 73 | def merge(self, other): 74 | ''' 75 | Merge faces on sewn edge. 76 | Assume both faces are rolled properly with sewn edge at the front. 77 | edge[0] is sewn edge. 78 | Return edge merging candidates 79 | ''' 80 | assert isinstance(other, Face), 'Cannot merge face with non-face' 81 | sewn_edge = self.edges[0] 82 | if self == other: 83 | self.edges = self.edges[1:] 84 | self.edge_orientations = self.edge_orientations[1:] 85 | self.keys = self.keys[1:] 86 | key = hash(sewn_edge) 87 | if key in self.keys: 88 | self.remove_edge(key) 89 | 90 | del self.topo.all_edges[hash(sewn_edge)] 91 | return None 92 | 93 | # change faces in edge 94 | for edge in other.edges[1:]: 95 | i = edge.faces.index(other) 96 | edge.faces[i] = self 97 | 98 | # candidate merge edges 99 | candidates = [(self.keys[1], other.keys[-1]), (self.keys[-1], other.keys[1])] 100 | 101 | # merge face 102 | self.edges = self.edges[1:] + other.edges[1:] 103 | self.edge_orientations = self.edge_orientations[1:] + other.edge_orientations[1:] 104 | self.keys = self.keys[1:] + other.keys[1:] 105 | if self.face_type != other.face_type: 106 | self.face_type = 10 # set face type to other 107 | 108 | 109 | # remove sewn edge and other face in topo 110 | del self.topo.all_edges[hash(sewn_edge)] 111 | del self.topo.all_faces[hash(other.face)] 112 | 113 | return candidates -------------------------------------------------------------------------------- /faceformer/datasets/data_para.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def sample_points(edge, num_samples=50): 9 | if len(edge) == 2: 10 | return sample_points_on_line(edge, num_samples) 11 | return sample_points_on_curve(edge, num_samples) 12 | 13 | 14 | def sample_points_on_line(line, num_samples): 15 | t = np.linspace(0, 1, num_samples) 16 | x1, y1, x2, y2 = line[0][0], line[0][1], line[1][0], line[1][1] 17 | x = x1 + (x2-x1) * t 18 | y = y1 + (y2-y1) * t 19 | return np.vstack([x, y]).T 20 | 21 | 22 | def sample_points_on_curve(curve, num_samples): 23 | samples = np.linspace(0, len(curve)-1, num_samples).round(0).astype(int) 24 | curve = np.array(curve) 25 | return curve[samples] 26 | 27 | 28 | class ABCDataset_Parallel(torch.utils.data.Dataset): 29 | 30 | def __init__(self, root_dir, datafile_path, config): 31 | super(ABCDataset_Parallel, self).__init__() 32 | 33 | self.root_dir = root_dir 34 | self.info_files = self.parse_splits_list(datafile_path) 35 | 36 | # input shape L x P x D 37 | self.num_points_per_line = config.num_points_per_line # P 38 | self.num_lines = config.num_lines # L 39 | self.point_dim = config.point_dim # D 40 | 41 | # output shape F x T 42 | self.max_num_faces = config.max_num_faces # F 43 | self.max_face_length = config.max_face_length # T 44 | 45 | self.token = config.token 46 | 47 | # preload all files 48 | self.raw_datas = [] 49 | for info_file in self.info_files: 50 | with open(os.path.join(self.root_dir, info_file), "r") as f: 51 | self.raw_datas.append(json.loads(f.read())) 52 | 53 | def __len__(self): 54 | return len(self.info_files) 55 | 56 | def __getitem__(self, index): 57 | raw_data 
= self.raw_datas[index] 58 | 59 | edges, faces_indices = raw_data['edges'], raw_data['faces_indices'] 60 | 61 | input = np.zeros( 62 | (self.num_lines, self.num_points_per_line, self.point_dim), dtype=np.float32) 63 | for i, edge in enumerate(edges): 64 | input[i, :self.num_points_per_line] = sample_points( 65 | edge, self.num_points_per_line) 66 | 67 | input_mask = np.ones(self.num_lines, dtype=bool) # L 68 | input_mask[:len(edges)] = 0 69 | 70 | # F x T 71 | label = np.ones((self.num_lines, self.max_face_length), dtype=np.int64) * self.token.PAD 72 | ind = 0 73 | # each face: [(loop 1), ..., (loop n)] 74 | for face_with_type in faces_indices: 75 | face_type, face = face_with_type 76 | # only allow Plane - 0, Cylinder - 1, Other - 2 77 | if face_type > 1: 78 | face_type = 2 79 | # type offset set to 1 80 | face_type += self.token.face_type_offset 81 | # each loop rolls itself 82 | for loop in face: 83 | for i in range(len(loop)): 84 | # construct new seq 85 | rotated_loop = np.roll(loop, i, axis=0).tolist() 86 | new_seq = rotated_loop 87 | for other_loop in face: 88 | if other_loop != loop: 89 | new_seq += other_loop 90 | label[ind, :len(new_seq)] = new_seq 91 | # shift face indices for special tokens 92 | label[ind, :len(new_seq)] += self.token.len 93 | label[ind, len(new_seq)] = face_type 94 | ind += 1 95 | for i in range(ind, self.num_lines): 96 | label[i, 0] = self.token.len - 1 # set to Other type of face 97 | label_mask = (label == self.token.PAD) 98 | 99 | data = { 100 | 'id': index, 101 | 'input': input, 102 | 'label': label, 103 | 'num_input': len(edges), 104 | 'num_faces': len(faces_indices), 105 | 'input_mask': input_mask, 106 | 'label_mask': label_mask, 107 | 'name': self.info_files[index] 108 | } 109 | 110 | return data 111 | 112 | def parse_splits_list(self, splits): 113 | """ Returns a list of info_file paths 114 | Args: 115 | splits (list of strings): each item is a path to a .json data file 116 | or a path to a .txt file containing a list of .json's relative paths from root.
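In `__getitem__` above, each face contributes one target row per cyclic rotation of each of its loops: the rotated loop comes first, the face's other loops are appended, all edge indices are shifted by `token.len` to clear the special tokens, and the row is terminated by the face-type token. A toy sketch of that enumeration (values for `token.len` and `face_type_offset` taken from configs/ours.yml; not repo code):

```python
# Toy illustration: enumerate the rotated label rows for a single-loop
# planar face with edge indices [3, 7, 2].
import numpy as np

loop = [3, 7, 2]
token_len = 4        # number of reserved token ids (token.len in configs/ours.yml)
face_type = 0 + 1    # Plane (0) + face_type_offset (1)

for i in range(len(loop)):
    rotated = np.roll(loop, i, axis=0).tolist()
    row = [ind + token_len for ind in rotated] + [face_type]
    print(row)
# [7, 11, 6, 1]
# [6, 7, 11, 1]
# [11, 6, 7, 1]
```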
117 | """ 118 | if isinstance(splits, str): 119 | splits = splits.split() 120 | info_files = [] 121 | for split in splits: 122 | ext = os.path.splitext(split)[1] 123 | split_path = os.path.join(self.root_dir, split) 124 | # split_path = os.path.join("/root/ablation/polys-test", split) 125 | if ext == '.json': 126 | info_files.append(split_path) 127 | elif ext == '.txt': 128 | info_files += [info_file.rstrip() for info_file in open(split_path, 'r')] 129 | else: 130 | raise NotImplementedError('%s not a valid info_file type' % split) 131 | return info_files 132 | -------------------------------------------------------------------------------- /dataset/filters/filter_thickness.py: -------------------------------------------------------------------------------- 1 | """ 2 | Filter closer cases 3 | """ 4 | import argparse, os, trimesh, yaml, time, json 5 | from functools import partial 6 | 7 | import numpy as np 8 | from scipy.spatial.distance import cdist 9 | from tqdm.contrib.concurrent import process_map 10 | 11 | 12 | def scale_to_unit_sphere(mesh): 13 | if isinstance(mesh, trimesh.Scene): 14 | mesh = mesh.dump().sum() 15 | 16 | vertices = mesh.vertices - mesh.bounding_box.centroid 17 | vertices *= 2 / np.linalg.norm(mesh.bounding_box.extents) 18 | 19 | return trimesh.Trimesh(vertices=vertices, faces=mesh.faces, process=False, maintain_order=True) 20 | 21 | 22 | def dist_p2p(vertices, verts_i, verts_j): 23 | dists = cdist(vertices[verts_i], vertices[verts_j]) 24 | return np.mean(np.min(dists, 1)) 25 | 26 | 27 | def dist_p2l(vertices, verts_i, verts_j, EPS=1e-8, MAX_VALUE=10): 28 | edges = np.vstack((verts_j[:-1], verts_j[1:])).T 29 | edge_vector = vertices[edges[:, 1]] - vertices[edges[:, 0]] 30 | edge_length = np.linalg.norm(edge_vector, axis=1, keepdims=True) + EPS 31 | edge_tangent = edge_vector / edge_length 32 | 33 | # Points x Lines x Dim 34 | vector = vertices[verts_i, np.newaxis] - vertices[edges[:, 0]][np.newaxis] 35 | 36 | # Points x Lines 37 | points_prop = np.sum( 38 | vector * edge_tangent[np.newaxis], axis=-1) / edge_length.reshape(1, -1) 39 | points_perp = points_prop[..., np.newaxis] * edge_vector - vector 40 | 41 | # p2l dists within 0 < points_prop < 1 42 | pl_dists = np.linalg.norm(points_perp, axis=-1) 43 | pl_valid = np.logical_and(0 < points_prop, points_prop < 1) 44 | pl_dists[np.logical_not(pl_valid)] = MAX_VALUE 45 | 46 | # p2p dists 47 | pp_dists = cdist(vertices[verts_i], vertices[edges].reshape(-1, 3)) 48 | pp_dists = pp_dists.reshape(-1, len(edges), 2) 49 | pp_dists = np.min(pp_dists, -1) 50 | 51 | dists = np.minimum(pl_dists, pp_dists) 52 | return np.mean(np.min(dists, 1)) 53 | 54 | 55 | def load_and_preprocess(name, args): 56 | if os.path.exists(os.path.join(args.save_root, f'{name}.npy')): 57 | return 58 | 59 | mesh_path = os.path.join(args.root, 'obj', f'{name}.obj') 60 | mesh = trimesh.load_mesh(mesh_path, process=False, maintain_order=True) 61 | 62 | # normalize to a unit sphere 63 | mesh = scale_to_unit_sphere(mesh) 64 | 65 | feat_path = os.path.join(args.root, 'feat', f'{name}.yml') 66 | with open(feat_path) as file: 67 | annos = yaml.full_load(file) 68 | 69 | curve_verts = [] 70 | for curve in annos['curves']: 71 | vert_indices = np.array(curve['vert_indices']).reshape(-1) 72 | curve_verts.append(vert_indices) 73 | 74 | vertices = mesh.vertices.view(np.ndarray) 75 | 76 | num_curves = len(curve_verts) 77 | 78 | with open(os.path.join(args.save_root, f'{name}.npy'), 'wb') as f: 79 | np.save(f, vertices) 80 | np.save(f, num_curves) 81 | for c in 
curve_verts: 82 | np.save(f, c) 83 | 84 | def filter_by_thickness(name, args): 85 | 86 | with open(os.path.join(args.save_root, f'{name}.npy'), 'rb') as f: 87 | vertices = np.load(f) 88 | num_curves = np.load(f) 89 | curve_verts = [] 90 | max_index = 0 91 | for i in range(num_curves): 92 | curve_verts.append(np.load(f)) 93 | max_index = max(curve_verts[-1].max(), max_index) 94 | 95 | if max_index >= len(vertices): 96 | print(f"{name} has vertices don't match {len(vertices)} <= {max_index}") 97 | return None 98 | 99 | # dists = np.zeros((num_curves, num_curves)) 100 | 101 | for i in range(num_curves): 102 | verts_i = curve_verts[i] 103 | 104 | for j in range(i+1, num_curves): 105 | verts_j = curve_verts[j] 106 | 107 | if args.p2p: 108 | dist_1 = dist_p2p(vertices, verts_i, verts_j) 109 | dist_2 = dist_p2p(vertices, verts_j, verts_i) 110 | else: 111 | dist_1 = dist_p2l(vertices, verts_i, verts_j) 112 | dist_2 = dist_p2l(vertices, verts_j, verts_i) 113 | 114 | if dist_1 < args.threshold and dist_2 < args.threshold: 115 | return None 116 | # dists[i, j] = dist_1 117 | # dists[j, i] = dist_2 118 | return name 119 | 120 | def main(args): 121 | with open("dataset/dataset_gen_logs/filtered_id_list.json", 'r') as f: 122 | names = json.load(f) 123 | 124 | # preprocess 125 | process_map( 126 | partial(load_and_preprocess, args=args), names, 127 | max_workers=args.num_cores, chunksize=args.num_chunks) 128 | 129 | rets = process_map( 130 | partial(filter_by_thickness, args=args), names, 131 | max_workers=args.num_cores, chunksize=args.num_chunks) 132 | 133 | filtered = [ret for ret in rets if ret is not None] 134 | 135 | # Filtering by thickness can take a long time. 136 | # Uncomment the following and lines in filter_by_thickness() to save intermediate result. 
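`load_and_preprocess` and `filter_by_thickness` above exchange data through a single `.npy` file by calling `np.save` several times on one open handle; `np.load` then yields the arrays back in write order. A self-contained sketch of that round trip (file name illustrative):

```python
# Sketch of the stacked np.save / np.load pattern used by the thickness cache.
import numpy as np

vertices = np.random.rand(5, 3)
curve_verts = [np.array([0, 1, 2]), np.array([2, 3, 4])]

with open('cache.npy', 'wb') as f:
    np.save(f, vertices)                  # arrays are appended in sequence
    np.save(f, len(curve_verts))
    for c in curve_verts:
        np.save(f, c)

with open('cache.npy', 'rb') as f:
    verts = np.load(f)                    # read back in the same order
    num_curves = int(np.load(f))
    loaded = [np.load(f) for _ in range(num_curves)]

assert verts.shape == (5, 3)
assert all((a == b).all() for a, b in zip(curve_verts, loaded))
```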
137 | 138 | # name_to_dist = {} 139 | # for i in rets: 140 | # if i is None: 141 | # continue 142 | # name, dists = i 143 | # name_to_dist[name] = dists.tolist() 144 | 145 | # with open('all_thickness.json', 'w') as f: 146 | # json.dump(name_to_dist, f) 147 | 148 | with open('filtered_id_list.json', 'w') as f: 149 | json.dump(filtered, f) 150 | 151 | with open('data_processing_log.txt', 'a') as f: 152 | f.write("Thickness id list generation done - " + time.ctime() + '\n') 153 | 154 | 155 | if __name__ == '__main__': 156 | parser = argparse.ArgumentParser() 157 | parser.add_argument('--root', type=str, default='/root/Datasets/FaceFormer', 158 | help='dataset root') 159 | parser.add_argument('--save_root', type=str, default='/root/data/curve_verts', 160 | help='dataset root') 161 | parser.add_argument('--threshold', type=float, default=0.05, 162 | help='threshold for closer edge') 163 | parser.add_argument('--num_cores', type=int, 164 | default=10, help='number of processors.') 165 | parser.add_argument('--num_chunks', type=int, 166 | default=10, help='number of chunk.') 167 | parser.add_argument('--p2p', action='store_true') 168 | args = parser.parse_args() 169 | 170 | main(args) 171 | 172 | -------------------------------------------------------------------------------- /dataset/utils/json_to_svg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from functools import partial 5 | from faceformer.utils import flatten_list 6 | 7 | import numpy as np 8 | import svgwrite 9 | from cairosvg import svg2png 10 | from matplotlib.cm import get_cmap as colormap 11 | from tqdm.contrib.concurrent import process_map 12 | 13 | 14 | def discretized_edge_to_svg_polyline(points): 15 | """ Returns a svgwrite.Path for the edge, and the 2d bounding box 16 | """ 17 | return svgwrite.shapes.Polyline(points, fill="none", class_='vectorEffectClass') 18 | 19 | def save_svg_groups(groups_of_edges, filename, args): 20 | discretized_edges = flatten_list(groups_of_edges) 21 | all_edges = flatten_list(discretized_edges) 22 | 23 | # compute bounding box 24 | min_x, min_y = np.min(all_edges, axis=0) - args.png_padding 25 | max_x, max_y = np.max(all_edges, axis=0) + args.png_padding 26 | width, height = max_x - min_x, max_y - min_y 27 | 28 | 29 | # build the svg drawing 30 | dwg = svgwrite.Drawing(filename, (args.width, args.height), debug=True) 31 | dwg.viewbox(min_x, min_y, width, height) 32 | 33 | # make sure line width stays constant 34 | # https://github.com/mozman/svgwrite/issues/38 35 | dwg.defs.add( 36 | dwg.style(".vectorEffectClass {\nvector-effect: non-scaling-stroke;\n}")) 37 | 38 | n = len(groups_of_edges) + 1 39 | cmap = (colormap('coolwarm')(np.linspace(0, 1, n))[:, :3]*255).astype(np.uint8) 40 | np.random.seed(args.seed) 41 | cmap = cmap[np.random.permutation(n), :] 42 | for index, group in enumerate(groups_of_edges): 43 | color = ",".join([str(c) for c in cmap[index]]) 44 | for edge in group: 45 | polyline = discretized_edge_to_svg_polyline(edge) 46 | polyline.stroke(f"rgb({color})", 47 | width=args.line_width, linecap="round") 48 | dwg.add(polyline) 49 | # export to string or file according to the user choice 50 | dwg.save() 51 | 52 | 53 | def save_svg(discretized_edges, filename, args, color='black'): 54 | # compute polylines for all edges 55 | polylines = [discretized_edge_to_svg_polyline( 56 | edge) for edge in discretized_edges] 57 | 58 | all_edges = flatten_list(discretized_edges) 59 | 60 | # compute bounding box 
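`save_svg_groups` above captures the repo's svgwrite recipe: fit the viewBox to the drawing's bounding box (padded by `args.png_padding`) and register a `vector-effect: non-scaling-stroke` style so stroke width is not scaled along with the viewBox. A minimal self-contained sketch of the same recipe (output name and padding are illustrative):

```python
# Sketch only: one polyline written with a data-fitted viewBox and a
# constant-width stroke, following the pattern in save_svg_groups/save_svg.
import numpy as np
import svgwrite

points = [(0.1, 0.1), (0.9, 0.1), (0.9, 0.9)]
pad = 0.02                                    # stand-in for args.png_padding
min_x, min_y = np.min(points, axis=0) - pad
max_x, max_y = np.max(points, axis=0) + pad

dwg = svgwrite.Drawing('demo.svg', (256, 256))
dwg.viewbox(min_x, min_y, max_x - min_x, max_y - min_y)
dwg.defs.add(dwg.style('.vectorEffectClass {vector-effect: non-scaling-stroke;}'))

polyline = svgwrite.shapes.Polyline(points, fill='none', class_='vectorEffectClass')
polyline.stroke('black', width=3 / 256, linecap='round')
dwg.add(polyline)
dwg.save()
```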
61 | min_x, min_y = np.min(all_edges, axis=0) - args.png_padding 62 | max_x, max_y = np.max(all_edges, axis=0) + args.png_padding 63 | width, height = max_x - min_x, max_y - min_y 64 | 65 | 66 | # build the svg drawing 67 | dwg = svgwrite.Drawing(filename, (args.width, args.height), debug=True) 68 | dwg.viewbox(min_x, min_y, width, height) 69 | 70 | # make sure line width stays constant 71 | # https://github.com/mozman/svgwrite/issues/38 72 | dwg.defs.add( 73 | dwg.style(".vectorEffectClass {\nvector-effect: non-scaling-stroke;\n}")) 74 | 75 | n = len(polylines) + 1 76 | cmap = (colormap('jet')(np.linspace(0, 1, n))[:, :3]*255).astype(np.uint8) 77 | cmap = cmap[np.random.permutation(n), :] 78 | 79 | for index, (dedge, polyline) in enumerate(zip(discretized_edges, polylines)): 80 | if color != 'black': 81 | color = ",".join([str(c) for c in cmap[index]]) 82 | color = f"rgb({color})" 83 | polyline.stroke(color, 84 | width=args.line_width, linecap="round") 85 | dwg.add(polyline) 86 | # add a circle at the beginning of the edge 87 | dwg.add(dwg.circle(dedge[0], r=4/256, fill='black')) 88 | 89 | # export to string or file according to the user choice 90 | dwg.save() 91 | 92 | 93 | def save_png(name, args, prefix=''): 94 | svg2png( 95 | bytestring=open( 96 | os.path.join(args.root, prefix+'svg', f'{name}.svg'), 'rb').read(), 97 | output_width=args.width, 98 | output_height=args.height, 99 | background_color='white', 100 | write_to=os.path.join(args.root, prefix+'png', f'{name}.png') 101 | ) 102 | 103 | 104 | def json_to_svg_png(name, args): 105 | json_filename = os.path.join(args.root, 'json', f'{name}.json') 106 | with open(json_filename, "r") as f: 107 | data = json.loads(f.read()) 108 | edges, faces_indices = data['edges'], data['faces_indices'] 109 | # reconstruct faces 110 | for index, face_indices in enumerate(faces_indices): 111 | face = [edges[ind] for ind in face_indices] 112 | filename = os.path.join( 113 | args.root, args.prefix+'face_svg', f'{name}_{index}.svg') 114 | save_svg(face, filename, args) 115 | 116 | save_svg(edges, os.path.join( 117 | args.root, args.prefix+'svg', f'{name}.svg'), args) 118 | # generate_pngs(name, args, prefix=args.prefix) 119 | 120 | 121 | def main(args): 122 | os.makedirs(os.path.join(args.root, args.prefix+'svg'), exist_ok=True) 123 | os.makedirs(os.path.join(args.root, args.prefix+'face_svg'), exist_ok=True) 124 | os.makedirs(os.path.join(args.root, args.prefix+'png'), exist_ok=True) 125 | os.makedirs(os.path.join(args.root, args.prefix+'face_png'), exist_ok=True) 126 | 127 | names = [] 128 | for name in sorted(os.listdir(os.path.join(args.root, 'step'))): 129 | if not name.endswith('.step'): 130 | continue 131 | 132 | names.append(os.path.splitext(name)[0]) 133 | 134 | process_map( 135 | partial(json_to_svg_png, args=args), names, 136 | max_workers=args.num_cores, chunksize=args.num_chunks 137 | ) 138 | 139 | 140 | if __name__ == '__main__': 141 | parser = argparse.ArgumentParser() 142 | parser.add_argument('--root', type=str, default="./data", 143 | help='dataset root.') 144 | parser.add_argument('--num_cores', type=int, 145 | default=1, help='number of processors.') 146 | parser.add_argument('--num_chunks', type=int, 147 | default=16, help='number of chunk.') 148 | parser.add_argument('--line_width', type=str, 149 | default=str(3/256), help='svg line width.') 150 | parser.add_argument('--name', type=str, default=None, 151 | help='filename.') 152 | parser.add_argument('--width', type=int, 153 | default=256, help='svg width.') 154 | 
parser.add_argument('--height', type=int, 155 | default=256, help='svg height.') 156 | parser.add_argument('--prefix', type=str, default='json_', 157 | help='filename prefix for generated svg and png') 158 | args = parser.parse_args() 159 | 160 | if args.name is None: 161 | main(args) 162 | else: 163 | json_to_svg_png(args.name, args) 164 | -------------------------------------------------------------------------------- /dataset/filters/3view_render.py: -------------------------------------------------------------------------------- 1 | """ 2 | Render three-view line drawings 3 | """ 4 | import argparse 5 | import os 6 | from functools import partial 7 | import numpy as np 8 | import svgwrite 9 | from cairosvg import svg2png 10 | from tqdm.contrib.concurrent import process_map 11 | 12 | from OCC.Core.Bnd import Bnd_Box 13 | from OCC.Core.BRepBndLib import brepbndlib_Add 14 | from OCC.Core.BRepBuilderAPI import BRepBuilderAPI_Transform 15 | from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh 16 | from OCC.Core.gp import gp_Ax2, gp_Dir, gp_Pnt, gp_Trsf, gp_Vec 17 | from OCC.Extend.TopologyUtils import TopologyExplorer 18 | 19 | from dataset.utils.json_to_svg import save_png, save_svg 20 | from dataset.utils.read_step_file import read_step_file 21 | from dataset.utils.projection_utils import project_shapes, discretize_compound 22 | 23 | O = gp_Pnt(0,0,0) 24 | X = gp_Dir(1,0,0) 25 | Y = gp_Dir(0,1,0) 26 | nY = gp_Dir(0,-1,0) 27 | Z = gp_Dir(0,0,1) 28 | 29 | directions = [ 30 | gp_Ax2(O, gp_Dir(1,1,1)), # 45 degree 31 | gp_Ax2(O, nY, X), # front 32 | gp_Ax2(O, X, Y), # right 33 | gp_Ax2(O, Z, X) # top 34 | ] 35 | 36 | def get_boundingbox(shape, tol=1e-6, use_mesh=False): 37 | """ return the bounding box of the TopoDS_Shape `shape` 38 | Parameters 39 | ---------- 40 | shape : TopoDS_Shape or a subclass such as TopoDS_Face 41 | the shape to compute the bounding box from 42 | tol: float 43 | tolerance of the computed boundingbox 44 | use_mesh : bool 45 | a flag that tells whether or not the shape has first to be meshed before the bbox 46 | computation. This produces more accurate results 47 | """ 48 | bbox = Bnd_Box() 49 | bbox.SetGap(tol) 50 | if use_mesh: 51 | mesh = BRepMesh_IncrementalMesh() 52 | mesh.SetParallelDefault(True) 53 | mesh.SetShape(shape) 54 | mesh.Perform() 55 | if not mesh.IsDone(): 56 | raise AssertionError("Mesh not done.") 57 | brepbndlib_Add(shape, bbox, use_mesh) 58 | xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get() 59 | center = (xmax + xmin) / 2, (ymin + ymax) / 2, (zmin + zmax) / 2 60 | extent = abs(xmax-xmin), abs(ymax-ymin), abs(zmax-zmin) 61 | return center, extent 62 | 63 | def get_discretized_edges(name, shape, direction, args): 64 | """ Given a TopologyExplorer topo, and a face on it, 65 | find all edges of the face without sewn edges. 66 | Return discretized edges 67 | 68 | VComponent / HComponent: sharp edges 69 | Rg1LineVCompound / Rg1LineHCompound: smooth edges 70 | RgNLineVCompound / RgNLineHCompound: sewn edges 71 | OutLineVCompound / OutLineHCompound: outlines 72 | """ 73 | # project the face 74 | hlr_shapes = project_shapes(shape, direction) 75 | 76 | discretized_edges = [] 77 | 78 | outline_compound = hlr_shapes.OutLineVCompound() 79 | if outline_compound: 80 | discretized_edges += discretize_compound(outline_compound, args.tol) 81 | 82 | smooth_compound = hlr_shapes.Rg1LineVCompound() 83 | if smooth_compound: 84 | discretized_edges += discretize_compound(smooth_compound, args.tol) 85 | 86 | # project sharp edges from the face, using only edges. 
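The comment block above is the key to this renderer: OCC's hidden-line-removal result exposes each edge category through a separate compound getter. As a compact reference, a hypothetical helper (assuming `hlr_shapes` is the object returned by the repo's `project_shapes`) that gathers the visible categories used here:

```python
# Hypothetical convenience wrapper, not repo code: collect the visible-edge
# compounds by category; any getter may return None when the category is empty.
def collect_visible_compounds(hlr_shapes):
    return {
        'sharp': hlr_shapes.VCompound(),            # visible sharp edges
        'smooth': hlr_shapes.Rg1LineVCompound(),    # visible smooth edges
        'sewn': hlr_shapes.RgNLineVCompound(),      # visible sewn edges
        'outline': hlr_shapes.OutLineVCompound(),   # visible outlines
    }
```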
87 | # (to avoid slicing effects from sewn edge when projecting using face) 88 | sharp_edges_3d = list(TopologyExplorer(shape).edges()) 89 | sharp_edges_compound = project_shapes(sharp_edges_3d, direction).VCompound() 90 | if sharp_edges_compound: 91 | sharp_edges_discretized = discretize_compound(sharp_edges_compound, args.tol) 92 | 93 | # check if there are sewn edges 94 | sewn_compound = hlr_shapes.RgNLineVCompound() 95 | if sewn_compound: 96 | sewn_edges_discretized = discretize_compound(sewn_compound, args.tol) 97 | for sewn_edge in sewn_edges_discretized: 98 | try: 99 | sharp_edges_discretized.remove(sewn_edge) 100 | except ValueError: 101 | print("sewn edge assumption broken", name) 102 | break 103 | discretized_edges += sharp_edges_discretized 104 | 105 | return discretized_edges 106 | 107 | def discretized_edge_to_svg_polyline(points): 108 | """ Returns a svgwrite.Path for the edge, and the 2d bounding box 109 | """ 110 | return svgwrite.shapes.Polyline(points, fill="none", class_='vectorEffectClass') 111 | 112 | def shape_to_svg(shape, name, args): 113 | """ export a single shape to an svg file and json. 114 | shape: the TopoDS_Shape to export 115 | """ 116 | if shape.IsNull(): 117 | raise AssertionError("shape is Null") 118 | 119 | for i, direction in enumerate(directions): 120 | 121 | shape_discretized_edges = get_discretized_edges(name, shape, direction, args) 122 | 123 | save_svg(shape_discretized_edges, os.path.join(args.root, '3view_svg', f'{name}-{i}.svg'), args) 124 | 125 | svg2png( 126 | bytestring=open(os.path.join(args.root, '3view_svg', f'{name}-{i}.svg'), 'rb').read(), 127 | output_width=args.width, 128 | output_height=args.height, 129 | background_color='white', 130 | write_to=os.path.join(args.root, '3view_png', f'{name}-{i}.png') 131 | ) 132 | 133 | def render_3views(name, args): 134 | try: 135 | step_path = os.path.join(args.root, 'step', f'{name}.step') 136 | # step read timeout at 5 seconds 137 | try: 138 | shape, _ = read_step_file(step_path, verbosity=False) 139 | except: 140 | print(f"{name} took too long to read") 141 | return 142 | 143 | if shape is None: 144 | print(f"{name} is NULL shape") 145 | return 146 | 147 | center, extent = get_boundingbox(shape) 148 | 149 | trans, scale = gp_Trsf(), gp_Trsf() 150 | trans.SetTranslation(-gp_Vec(*center)) 151 | scale.SetScale(gp_Pnt(0, 0, 0), 2 / np.linalg.norm(extent)) 152 | brep_trans = BRepBuilderAPI_Transform(shape, scale * trans) 153 | shape = brep_trans.Shape() 154 | 155 | shape_to_svg(shape, name, args) 156 | except Exception as e: 157 | print(f"{name} received unknown error", e) 158 | 159 | def main(args): 160 | # all step files 161 | names = [] 162 | for name in sorted(os.listdir(os.path.join(args.root, 'stat'))): 163 | names.append(name[:8]) 164 | 165 | os.makedirs(os.path.join(args.root, '3view_svg'), exist_ok=True) 166 | os.makedirs(os.path.join(args.root, '3view_png'), exist_ok=True) 167 | 168 | process_map( 169 | partial(render_3views, args=args), names, 170 | max_workers=args.num_cores, chunksize=args.num_chunks 171 | ) 172 | 173 | 174 | if __name__ == '__main__': 175 | parser = argparse.ArgumentParser() 176 | parser.add_argument('--root', type=str, default="./data", 177 | help='dataset root.') 178 | parser.add_argument('--name', type=str, default=None, 179 | help='filename.') 180 | parser.add_argument('--num_cores', type=int, 181 | default=40, help='number of processors.') 182 | parser.add_argument('--num_chunks', type=int, 183 | default=10, help='number of chunk.') 184 | 
parser.add_argument('--width', type=int, 185 | default=256, help='svg width.') 186 | parser.add_argument('--height', type=int, 187 | default=256, help='svg height.') 188 | parser.add_argument('--tol', type=float, 189 | default=1e-4, help='svg discretization tolerance.') 190 | parser.add_argument('--line_width', type=str, 191 | default=str(3/256), help='svg line width.') 192 | parser.add_argument('--filter_num_shapes', type=int, 193 | default=8, help='do not process step files \ 194 | that have more than this number of shapes.') 195 | parser.add_argument('--filter_num_edges', type=int, 196 | default=1000, help='do not process step files \ 197 | that have more than this number of edges.') 198 | 199 | args = parser.parse_args() 200 | 201 | if args.name is None: 202 | main(args) 203 | else: 204 | render_3views(args.name, args) 205 | -------------------------------------------------------------------------------- /reconstruction/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from OCC.Core.BRepBuilderAPI import BRepBuilderAPI_MakeEdge 3 | from OCC.Core.gp import gp_Ax2, gp_Circ, gp_Dir, gp_Pnt, gp_Vec 4 | from OCC.Extend.TopologyUtils import discretize_edge 5 | 6 | 7 | def construct_connected_cylinder(edges, edge_inds, tol=1e-4): 8 | ''' 9 | Given lines and their indices, 10 | Form a loop of edges and return the loop's edges, indices and edge directions(1/-1). 11 | ''' 12 | 13 | # group edges by their intersections 14 | groups = {} 15 | edge_ind_to_intersection = {} 16 | for edge, edge_ind in zip(edges, edge_inds): 17 | start, end = tuple(edge[0]), tuple(edge[-1]) 18 | start_found, end_found = False, False 19 | # find start's group 20 | for intersection in groups: 21 | if dist(start, intersection) < tol: 22 | groups[intersection].append((edge, 1, edge_ind)) 23 | start_found = True 24 | break 25 | if not start_found: 26 | groups[start] = [(edge, 1, edge_ind)] 27 | intersection = start 28 | 29 | if edge_ind not in edge_ind_to_intersection: 30 | edge_ind_to_intersection[edge_ind] = [intersection] 31 | else: 32 | edge_ind_to_intersection[edge_ind].append(intersection) 33 | 34 | # find end's group 35 | for intersection in groups: 36 | if dist(end, intersection) < tol: 37 | groups[intersection].append((edge, -1, edge_ind)) 38 | end_found = True 39 | break 40 | if not end_found: 41 | groups[end] = [(edge, -1, edge_ind)] 42 | intersection = end 43 | 44 | if edge_ind not in edge_ind_to_intersection: 45 | edge_ind_to_intersection[edge_ind] = [intersection] 46 | else: 47 | edge_ind_to_intersection[edge_ind].append(intersection) 48 | # fix one corner to be the origin. 
Generate a circle from the origin 49 | for intersection, edge_inter in groups.items(): 50 | assert len(edge_inter) == 2, "more than two edges intersect at one intersection" 51 | edge1, edge2 = edge_inter[0][0], edge_inter[1][0] 52 | # intersection of a line and a curve is a real intersection 53 | if is_straight_line(edge1) or is_straight_line(edge2): 54 | origin = intersection 55 | break 56 | 57 | # construct the circle 58 | circle = [] 59 | circle_inds = [] 60 | dirs = [] 61 | next_point = origin 62 | count = 0 63 | while True: 64 | for edge, direction, edge_ind in groups[next_point]: 65 | if edge_ind not in circle_inds: 66 | break 67 | circle.append(edge[::direction]) 68 | circle_inds.append(edge_ind) 69 | dirs.append(direction) 70 | # find the next point 71 | for intersection in edge_ind_to_intersection[edge_ind]: 72 | if tuple(next_point) != tuple(intersection): 73 | next_point = intersection 74 | break 75 | if next_point == origin: 76 | break 77 | count += 1 78 | if count >= 10: 79 | print("cylinder construction failed") 80 | break 81 | 82 | # return circle indices in sequence 83 | return circle, circle_inds, dirs 84 | 85 | 86 | def construct_connected_cycle(edges, edge_inds, tol=1e-4): 87 | ''' 88 | Given lines and their indices, 89 | Form a loop of edges and return the loop's edges, indices and edge directions(1/-1). 90 | ''' 91 | 92 | # group edges by their intersections 93 | groups = {} 94 | edge_ind_to_intersection = {} 95 | for edge, edge_ind in zip(edges, edge_inds): 96 | start, end = tuple(edge[0]), tuple(edge[-1]) 97 | start_found, end_found = False, False 98 | # find start's group 99 | for intersection in groups: 100 | if dist(start, intersection) < tol: 101 | groups[intersection].append((edge, 1, edge_ind)) 102 | start_found = True 103 | break 104 | if not start_found: 105 | groups[start] = [(edge, 1, edge_ind)] 106 | intersection = start 107 | 108 | if edge_ind not in edge_ind_to_intersection: 109 | edge_ind_to_intersection[edge_ind] = [intersection] 110 | else: 111 | edge_ind_to_intersection[edge_ind].append(intersection) 112 | 113 | # find end's group 114 | for intersection in groups: 115 | if dist(end, intersection) < tol: 116 | groups[intersection].append((edge, -1, edge_ind)) 117 | end_found = True 118 | break 119 | if not end_found: 120 | groups[end] = [(edge, -1, edge_ind)] 121 | intersection = end 122 | 123 | if edge_ind not in edge_ind_to_intersection: 124 | edge_ind_to_intersection[edge_ind] = [intersection] 125 | else: 126 | edge_ind_to_intersection[edge_ind].append(intersection) 127 | 128 | 129 | # construct circles 130 | all_circles = [] 131 | all_circle_inds = [] 132 | all_dirs = [] 133 | while len(groups) > 0: 134 | origin = list(groups.keys())[0] 135 | circle = [] 136 | circle_inds = [] 137 | dirs = [] 138 | next_point = origin 139 | skip = False 140 | while True: 141 | if next_point not in groups: 142 | skip = True 143 | break 144 | for edge, direction, edge_ind in groups[next_point]: 145 | if edge_ind not in circle_inds: 146 | break 147 | circle.append(edge[::direction]) 148 | circle_inds.append(edge_ind) 149 | dirs.append(direction) 150 | del groups[next_point] 151 | 152 | # find the next point 153 | for intersection in edge_ind_to_intersection[edge_ind]: 154 | if tuple(next_point) != tuple(intersection): 155 | next_point = intersection 156 | break 157 | if next_point == origin: 158 | break 159 | if not skip: 160 | all_circles.append(circle) 161 | all_circle_inds.append(circle_inds) 162 | all_dirs.append(dirs) 163 | # return circle indices in sequence 
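Because `construct_connected_cycle` only matches endpoints (within `tol`) and walks from vertex group to vertex group, it can be exercised on any toy wireframe. A quick check, assuming the repository root is on `PYTHONPATH` and the repo's environment (with pythonocc-core) is active: the four sides of a unit square should come back as one closed cycle traversed in order.

```python
# Toy check of the endpoint-matching cycle tracer above.
from reconstruction.utils import construct_connected_cycle

square = [
    [(0, 0, 0), (1, 0, 0)],
    [(1, 0, 0), (1, 1, 0)],
    [(1, 1, 0), (0, 1, 0)],
    [(0, 1, 0), (0, 0, 0)],
]
cycles, cycle_inds, dirs = construct_connected_cycle(square, [0, 1, 2, 3])
print(len(cycles), cycle_inds[0], dirs[0])  # 1 [0, 1, 2, 3] [1, 1, 1, 1]
```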
164 | return all_circles, all_circle_inds, all_dirs 165 | 166 | 167 | 168 | def check_parallel(v1, v2, tol=1e-10): 169 | return np.abs(np.dot(v1, v2)) > (1 - tol) 170 | 171 | def fit_curve(p1, p2, p3): 172 | ''' 173 | Given three 3D points, fit a circle to the points. 174 | Return the discretized curve between p1-p3-p2 175 | ''' 176 | center, radius, normal = find_circle_center(p1, p2, p3) 177 | 178 | # construct opencascade circle 179 | center = gp_Pnt(center[0], center[1], center[2]) 180 | normal = gp_Vec(normal[0], normal[1], normal[2]) 181 | ax = gp_Ax2(center, gp_Dir(normal)) 182 | circle = gp_Circ(ax, radius) 183 | circle_edge = BRepBuilderAPI_MakeEdge(circle).Edge() 184 | pts = discretize_edge(circle_edge, deflection=1e-5) 185 | return find_curve_between_points(pts, p1, p2, p3) 186 | 187 | def find_circle_center(p1, p2, p3): 188 | # triangle "edges" 189 | t = np.array(p2 - p1) 190 | u = np.array(p3 - p1) 191 | v = np.array(p3 - p2) 192 | 193 | # triangle normal 194 | w = np.cross(t, u) 195 | wsl = w.dot(w) 196 | 197 | # helpers 198 | iwsl2 = 1.0 / (2.0 * wsl) 199 | tt = t.dot(t) 200 | uu = u.dot(u) 201 | 202 | # result circle 203 | center = p1 + (u*tt*(u.dot(v)) - t*uu*(t.dot(v))) * iwsl2 204 | radius = np.sqrt(tt * uu * (v.dot(v)) * iwsl2 / 2) 205 | normal = w / np.sqrt(wsl) 206 | return center, radius, normal 207 | 208 | def find_curve_between_points(pts, p1, p2, p3): 209 | pts = np.array(pts) 210 | p1_ind = np.argmin(np.linalg.norm(pts - p1, axis=1)) 211 | p2_ind = np.argmin(np.linalg.norm(pts - p2, axis=1)) 212 | p1_ind, p2_ind = min(p1_ind, p2_ind), max(p1_ind, p2_ind) 213 | right_direction = p3 - pts[p1_ind] 214 | v1 = pts[(p1_ind+1) % (len(pts)-1)] - pts[p1_ind] 215 | # selecting p1_ind to p2_ind if angle is acute 216 | if np.dot(v1, right_direction) > 0: 217 | pts = pts[p1_ind:p2_ind+1] 218 | # selecting everything else if angle is obtuse 219 | else: 220 | pts = np.vstack([pts[p2_ind:], pts[:p1_ind+1]]) 221 | return pts 222 | 223 | def dist(p1, p2): 224 | return np.linalg.norm(np.array(p1) - np.array(p2)) 225 | 226 | def is_straight_line(line): 227 | return len(line) == 2 228 | -------------------------------------------------------------------------------- /reconstruction/reconstruction_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from OCC.Core.BRepBuilderAPI import BRepBuilderAPI_MakeEdge 3 | from OCC.Core.gp import gp_Ax2, gp_Circ, gp_Dir, gp_Pnt, gp_Vec 4 | from OCC.Extend.TopologyUtils import discretize_edge 5 | 6 | 7 | def construct_connected_cylinder(edges, edge_inds, tol=1e-4): 8 | ''' 9 | Given lines and their indices, 10 | Form a loop of edges and return the loop's edges, indices and edge directions(1/-1). 
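`find_circle_center` above computes the circumcircle of three 3D points from triangle edge vectors; restating it in condensed form makes for a quick standalone sanity check: three points on the unit circle should recover a center at the origin and radius 1.

```python
# Standalone numeric check of the circumcircle formula used by fit_curve
# (condensed restatement, not repo code).
import numpy as np

def circumcircle(p1, p2, p3):
    t, u, v = p2 - p1, p3 - p1, p3 - p2
    w = np.cross(t, u)                     # triangle normal (unnormalized)
    iwsl2 = 1.0 / (2.0 * w.dot(w))
    tt, uu = t.dot(t), u.dot(u)
    center = p1 + (u * tt * u.dot(v) - t * uu * t.dot(v)) * iwsl2
    radius = np.sqrt(tt * uu * v.dot(v) * iwsl2 / 2)
    return center, radius, w / np.linalg.norm(w)

p1, p2, p3 = np.array([1.0, 0, 0]), np.array([0, 1.0, 0]), np.array([-1.0, 0, 0])
center, radius, normal = circumcircle(p1, p2, p3)
print(np.round(center, 6), round(radius, 6), normal)  # [0. 0. 0.] 1.0 [0. 0. 1.]
```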
11 | ''' 12 | 13 | # group edges by their intersections 14 | groups = {} 15 | edge_ind_to_intersection = {} 16 | for edge, edge_ind in zip(edges, edge_inds): 17 | start, end = tuple(edge[0]), tuple(edge[-1]) 18 | start_found, end_found = False, False 19 | # find start's group 20 | for intersection in groups: 21 | if dist(start, intersection) < tol: 22 | groups[intersection].append((edge, 1, edge_ind)) 23 | start_found = True 24 | break 25 | if not start_found: 26 | groups[start] = [(edge, 1, edge_ind)] 27 | intersection = start 28 | 29 | if edge_ind not in edge_ind_to_intersection: 30 | edge_ind_to_intersection[edge_ind] = [intersection] 31 | else: 32 | edge_ind_to_intersection[edge_ind].append(intersection) 33 | 34 | # find end's group 35 | for intersection in groups: 36 | if dist(end, intersection) < tol: 37 | groups[intersection].append((edge, -1, edge_ind)) 38 | end_found = True 39 | break 40 | if not end_found: 41 | groups[end] = [(edge, -1, edge_ind)] 42 | intersection = end 43 | 44 | if edge_ind not in edge_ind_to_intersection: 45 | edge_ind_to_intersection[edge_ind] = [intersection] 46 | else: 47 | edge_ind_to_intersection[edge_ind].append(intersection) 48 | # fix one corner to be the origin. Generate a circle from the origin 49 | for intersection, edge_inter in groups.items(): 50 | assert len(edge_inter) == 2, "more than two edges intersect at one intersection" 51 | edge1, edge2 = edge_inter[0][0], edge_inter[1][0] 52 | # intersection of a line and a curve is a real intersection 53 | if is_straight_line(edge1) or is_straight_line(edge2): 54 | origin = intersection 55 | break 56 | 57 | # construct the circle 58 | circle = [] 59 | circle_inds = [] 60 | dirs = [] 61 | next_point = origin 62 | count = 0 63 | while True: 64 | for edge, direction, edge_ind in groups[next_point]: 65 | if edge_ind not in circle_inds: 66 | break 67 | circle.append(edge[::direction]) 68 | circle_inds.append(edge_ind) 69 | dirs.append(direction) 70 | # find the next point 71 | for intersection in edge_ind_to_intersection[edge_ind]: 72 | if tuple(next_point) != tuple(intersection): 73 | next_point = intersection 74 | break 75 | if next_point == origin: 76 | break 77 | count += 1 78 | if count >= 10: 79 | print("cylinder construction failed") 80 | break 81 | 82 | # return circle indices in sequence 83 | return circle, circle_inds, dirs 84 | 85 | 86 | def construct_connected_cycle(edges, edge_inds, tol=1e-4): 87 | ''' 88 | Given lines and their indices, 89 | Form a loop of edges and return the loop's edges, indices and edge directions(1/-1). 
90 | ''' 91 | 92 | # group edges by their intersections 93 | groups = {} 94 | edge_ind_to_intersection = {} 95 | for edge, edge_ind in zip(edges, edge_inds): 96 | start, end = tuple(edge[0]), tuple(edge[-1]) 97 | start_found, end_found = False, False 98 | # find start's group 99 | for intersection in groups: 100 | if dist(start, intersection) < tol: 101 | groups[intersection].append((edge, 1, edge_ind)) 102 | start_found = True 103 | break 104 | if not start_found: 105 | groups[start] = [(edge, 1, edge_ind)] 106 | intersection = start 107 | 108 | if edge_ind not in edge_ind_to_intersection: 109 | edge_ind_to_intersection[edge_ind] = [intersection] 110 | else: 111 | edge_ind_to_intersection[edge_ind].append(intersection) 112 | 113 | # find end's group 114 | for intersection in groups: 115 | if dist(end, intersection) < tol: 116 | groups[intersection].append((edge, -1, edge_ind)) 117 | end_found = True 118 | break 119 | if not end_found: 120 | groups[end] = [(edge, -1, edge_ind)] 121 | intersection = end 122 | 123 | if edge_ind not in edge_ind_to_intersection: 124 | edge_ind_to_intersection[edge_ind] = [intersection] 125 | else: 126 | edge_ind_to_intersection[edge_ind].append(intersection) 127 | 128 | 129 | # construct circles 130 | all_circles = [] 131 | all_circle_inds = [] 132 | all_dirs = [] 133 | while len(groups) > 0: 134 | origin = list(groups.keys())[0] 135 | circle = [] 136 | circle_inds = [] 137 | dirs = [] 138 | next_point = origin 139 | skip = False 140 | while True: 141 | if next_point not in groups: 142 | skip = True 143 | break 144 | for edge, direction, edge_ind in groups[next_point]: 145 | if edge_ind not in circle_inds: 146 | break 147 | circle.append(edge[::direction]) 148 | circle_inds.append(edge_ind) 149 | dirs.append(direction) 150 | del groups[next_point] 151 | 152 | # find the next point 153 | for intersection in edge_ind_to_intersection[edge_ind]: 154 | if tuple(next_point) != tuple(intersection): 155 | next_point = intersection 156 | break 157 | if next_point == origin: 158 | break 159 | if not skip: 160 | all_circles.append(circle) 161 | all_circle_inds.append(circle_inds) 162 | all_dirs.append(dirs) 163 | # return circle indices in sequence 164 | return all_circles, all_circle_inds, all_dirs 165 | 166 | 167 | 168 | def check_parallel(v1, v2, tol=1e-10): 169 | return np.abs(np.dot(v1, v2)) > (1 - tol) 170 | 171 | def fit_curve(p1, p2, p3): 172 | ''' 173 | Given three 3D points, fit a circle to the points. 
174 | Return the discretized curve between p1-p3-p2 175 | ''' 176 | center, radius, normal = find_circle_center(p1, p2, p3) 177 | 178 | # construct opencascade circle 179 | center = gp_Pnt(center[0], center[1], center[2]) 180 | normal = gp_Vec(normal[0], normal[1], normal[2]) 181 | ax = gp_Ax2(center, gp_Dir(normal)) 182 | circle = gp_Circ(ax, radius) 183 | circle_edge = BRepBuilderAPI_MakeEdge(circle).Edge() 184 | pts = discretize_edge(circle_edge, deflection=1e-5) 185 | return find_curve_between_points(pts, p1, p2, p3) 186 | 187 | def find_circle_center(p1, p2, p3): 188 | # triangle "edges" 189 | t = np.array(p2 - p1) 190 | u = np.array(p3 - p1) 191 | v = np.array(p3 - p2) 192 | 193 | # triangle normal 194 | w = np.cross(t, u) 195 | wsl = w.dot(w) 196 | 197 | # helpers 198 | iwsl2 = 1.0 / (2.0 * wsl) 199 | tt = t.dot(t) 200 | uu = u.dot(u) 201 | 202 | # result circle 203 | center = p1 + (u*tt*(u.dot(v)) - t*uu*(t.dot(v))) * iwsl2 204 | radius = np.sqrt(tt * uu * (v.dot(v)) * iwsl2 / 2) 205 | normal = w / np.sqrt(wsl) 206 | return center, radius, normal 207 | 208 | def find_curve_between_points(pts, p1, p2, p3): 209 | pts = np.array(pts) 210 | p1_ind = np.argmin(np.linalg.norm(pts - p1, axis=1)) 211 | p2_ind = np.argmin(np.linalg.norm(pts - p2, axis=1)) 212 | p1_ind, p2_ind = min(p1_ind, p2_ind), max(p1_ind, p2_ind) 213 | right_direction = p3 - pts[p1_ind] 214 | v1 = pts[(p1_ind+1) % (len(pts)-1)] - pts[p1_ind] 215 | # selecting p1_ind to p2_ind if angle is acute 216 | if np.dot(v1, right_direction) > 0: 217 | pts = pts[p1_ind:p2_ind+1] 218 | # selecting everything else if angle is obtuse 219 | else: 220 | pts = np.vstack([pts[p2_ind:], pts[:p1_ind+1]]) 221 | return pts 222 | 223 | def dist(p1, p2): 224 | return np.linalg.norm(np.array(p1) - np.array(p2)) 225 | 226 | def is_straight_line(line): 227 | return len(line) == 2 228 | -------------------------------------------------------------------------------- /dataset/utils/TopoMapper.py: -------------------------------------------------------------------------------- 1 | from OCC.Extend.TopologyUtils import TopologyExplorer, WireExplorer, discretize_edge 2 | from OCC.Core.BRepFeat import BRepFeat_SplitShape 3 | from OCC.Core.TopTools import TopTools_SequenceOfShape 4 | from OCC.Core.ShapeFix import ShapeFix_ShapeTolerance 5 | 6 | from dataset.utils.projection_utils import d3_to_d2, project_shapes, discretize_compound, project_points 7 | from dataset.utils.Edge import Edge 8 | from dataset.utils.Face import Face 9 | from faceformer.utils import flatten_list 10 | 11 | import numpy as np 12 | 13 | 14 | 15 | class TopoMapper: 16 | def __init__(self, shape, args): 17 | self.shape = shape 18 | self.all_edges = None 19 | self.all_faces = None 20 | self.args = args 21 | self.tol = self.args.tol 22 | 23 | # add outline to shape 24 | outline_edges = self._find_outline_edges() 25 | self.full_topo = self._add_outline_edges(outline_edges) 26 | 27 | # construct all edge-face mappings 28 | self._construct_mapping() 29 | 30 | # project to 2D; each edge has dedge now 31 | self._project(args.discretize_last) 32 | 33 | # remove sewn edges 34 | sewn_edge_keys = self._find_sewn_edges() 35 | self._remove_sewn_edges(sewn_edge_keys) 36 | 37 | 38 | def _find_outline_edges(self): 39 | hlr_shapes = project_shapes(self.shape, self.args) 40 | outline_compound = hlr_shapes.OutLineVCompound3d() 41 | if outline_compound: 42 | return list(TopologyExplorer(outline_compound).edges()) 43 | return [] 44 | 45 | def _num_edges(self, splitshape): 46 | probing_shape = 
splitshape.Shape() 47 | split = BRepFeat_SplitShape(probing_shape) 48 | return split, len(list(TopologyExplorer(probing_shape).edges())) 49 | 50 | def _add_edge(self, split, edge, num_edge): 51 | toptool_seq_shape = TopTools_SequenceOfShape() 52 | toptool_seq_shape.Append(edge) 53 | add_success = split.Add(toptool_seq_shape) 54 | split, curr_num_edge = self._num_edges(split) 55 | add_success = add_success and (curr_num_edge > num_edge) 56 | return split, curr_num_edge, add_success 57 | 58 | def _add_outline_edges(self, outline_edges): 59 | if len(outline_edges) == 0: 60 | return TopologyExplorer(self.shape) 61 | split_edge_num = 0 62 | while True: 63 | # repeated split edge until number of edges converge 64 | split = BRepFeat_SplitShape(self.shape) 65 | split, num_edge = self._num_edges(split) 66 | for edge in outline_edges: 67 | probing_shape = split.Shape() 68 | backup_split, split = BRepFeat_SplitShape(probing_shape), BRepFeat_SplitShape(probing_shape) 69 | split, curr_num_edge, add_success = self._add_edge(split, edge, num_edge) 70 | if not add_success: 71 | # Increase outline tolerance when add fails 72 | # fixed tolerance, may need update 73 | tol = ShapeFix_ShapeTolerance() 74 | tol.SetTolerance(edge, 1) 75 | split, curr_num_edge, add_success = self._add_edge(backup_split, edge, num_edge) 76 | if not add_success: 77 | raise Exception("Fail to add splitting outline") 78 | if split_edge_num == curr_num_edge: 79 | break 80 | split_edge_num = curr_num_edge 81 | 82 | split_shape = split.Shape() 83 | return TopologyExplorer(split_shape) 84 | 85 | def _construct_mapping(self): 86 | ''' 87 | Construct edge-to-face mapping from wireframe. 88 | ''' 89 | all_edges = {} 90 | all_faces = {} 91 | 92 | for face in self.full_topo.faces(): 93 | new_face = Face(face, self) 94 | all_faces[hash(face)] = new_face 95 | 96 | sharp_edges_wires = list(self.full_topo.wires_from_face(face)) 97 | sharp_edges_3d = [] 98 | for wire in sharp_edges_wires: 99 | sharp_edges_3d += list(WireExplorer(wire).ordered_edges()) 100 | 101 | for edge in sharp_edges_3d: 102 | edge_id = hash(edge) # same edge has same hash 103 | 104 | # create edge 105 | if edge_id in all_edges: 106 | new_edge = all_edges[edge_id] 107 | new_edge.add_face(new_face, edge.Orientation()) 108 | else: 109 | new_edge = Edge(edge, faces=[new_face], orientations=[edge.Orientation()]) 110 | all_edges[edge_id] = new_edge 111 | 112 | # add edge to face 113 | new_face.add_edge(new_edge, edge.Orientation()) 114 | 115 | self.all_faces = all_faces 116 | self.all_edges = all_edges 117 | 118 | def _find_sewn_edges(self): 119 | ''' 120 | Any edge that occur in any face twice is sewn edge. 121 | ''' 122 | all_sewn_edge_keys = [] 123 | topo = TopologyExplorer(self.shape) 124 | for face in topo.faces(): 125 | edge_keys = [] 126 | 127 | sharp_edges_wires = list(topo.wires_from_face(face)) 128 | sharp_edges_3d = [] 129 | for wire in sharp_edges_wires: 130 | sharp_edges_3d += list(WireExplorer(wire).ordered_edges()) 131 | 132 | for edge in sharp_edges_3d: 133 | edge_id = hash(edge) # same edge has same hash 134 | 135 | # if edge is used twice in a face, it's a sewn edge 136 | if edge_id in edge_keys: 137 | all_sewn_edge_keys.append(edge_id) 138 | else: 139 | edge_keys.append(edge_id) 140 | 141 | return all_sewn_edge_keys 142 | 143 | def _remove_sewn_edges(self, sewn_edge_keys): 144 | ''' 145 | Remove all sewn edge and combine faces. 
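`_find_sewn_edges` above encodes the rule stated in its docstring: an edge whose hash appears twice within a single face's wires is a sewn edge (for example, the vertical seam of a cylinder's side wall). The rule reduces to a few lines; a toy restatement (not repo code):

```python
# Toy restatement of the sewn-edge rule: a key repeated within one face is sewn.
def find_sewn_keys(faces_to_edge_keys):
    sewn = []
    for keys in faces_to_edge_keys:        # one list of edge keys per face
        seen = set()
        for key in keys:
            if key in seen:
                sewn.append(key)
            else:
                seen.add(key)
    return sewn

# A cylindrical side face visits its seam edge twice:
print(find_sewn_keys([[101, 205, 101, 333]]))  # [101]
```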
146 | ''' 147 | candidate_edges = set() 148 | for key in sewn_edge_keys: 149 | # if key in self.all_edges: 150 | sewn_edge = self.all_edges[key] 151 | # else: 152 | # # sewn edge not found after adding outline 153 | 154 | faces = sewn_edge.faces 155 | # roll edge sequence 156 | for face in faces: 157 | ind = face.keys.index(key) 158 | face.roll(ind) 159 | result_face = faces[0] 160 | for face in faces[1:]: 161 | pairs = result_face.merge(face) 162 | if pairs: 163 | for pair in pairs: 164 | candidate_edges.add(tuple(sorted(pair))) 165 | 166 | # merge candidate edges 167 | for key1, key2 in candidate_edges: 168 | # check if there's a 4th edge connected to this vertex 169 | d1, d2 = np.array(self.all_edges[key1].dedge), np.array(self.all_edges[key2].dedge) 170 | dist = lambda t: np.sum((t[0]-t[1])**2) 171 | p1, p2 = min([(d1[0], d2[0]), (d1[-1], d2[0]), (d1[0], d2[-1]), (d1[-1], d2[-1])], key=dist) 172 | vertex = (p1+p2) / 2 173 | 174 | skip = False 175 | for key in self.all_edges: 176 | if key == key1 or key == key2 or key in sewn_edge_keys: 177 | continue 178 | e = self.all_edges[key] 179 | if dist((vertex, e.dedge[0])) < self.tol or dist((vertex, e.dedge[-1])) < self.tol: 180 | skip = True 181 | break 182 | 183 | if not skip: 184 | self.all_edges[key1].merge(self.all_edges[key2], self) 185 | 186 | 187 | 188 | def _project(self, discretize_last=False): 189 | ''' 190 | Project all edges of the shape 191 | ''' 192 | for edge in list(self.all_edges.values()): 193 | if not discretize_last: 194 | sharp_dedge = discretize_edge(edge.edge, self.args.tol) 195 | edge.dedge3d = project_points(sharp_dedge, self.args) 196 | edge.dedge = d3_to_d2(edge.dedge3d) 197 | continue 198 | sharp_edges_compound = project_shapes(edge.edge, self.args).VCompound() 199 | if sharp_edges_compound is None: 200 | # invalid edge - delete 201 | key = hash(edge.edge) 202 | del self.all_edges[key] 203 | # del in face 204 | for face in edge.faces: 205 | face.remove_edge(key) 206 | continue 207 | 208 | dedge = discretize_compound(sharp_edges_compound, self.tol)[0] 209 | edge.dedge = dedge 210 | 211 | # Given a list of edges (they are from the same edge but broken into pieces) 212 | # project them and return one unified edge 213 | def _raw_project(self, edges, discretize_last=False): 214 | if not discretize_last: 215 | full_2d_dedge = [] 216 | for edge in edges: 217 | dedge = discretize_edge(edge, self.args.tol) 218 | full_2d_dedge += d3_to_d2(project_points(dedge, self.args)) 219 | return full_2d_dedge 220 | sharp_edges_compound = project_shapes(edges, self.args).VCompound() 221 | dedge = flatten_list(discretize_compound(sharp_edges_compound, self.tol)[:len(edges)]) 222 | return dedge 223 | 224 | # return x,y,z directions in camera world 225 | def get_dominant_directions(self): 226 | # project origin, x, y, z 227 | points = [(0,0,0), (1,0,0), (0,1,0), (0,0,1)] 228 | origin, x, y, z = project_points(points, self.args) 229 | origin, x, y, z = [np.array(p) for p in [origin, x, y, z]] 230 | return (x - origin).tolist(), (y - origin).tolist(), (z - origin).tolist() -------------------------------------------------------------------------------- /faceformer/models/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from faceformer.embedding import PositionEmbeddingLearned, VanillaEmedding 5 | from faceformer.transformer import (TransformerDecoder, 6 | TransformerDecoderLayer, 7 | TransformerEncoder, 8 | TransformerEncoderLayer) 9 | from 
faceformer.utils import min_value_of_dtype 10 | 11 | 12 | class SurfaceFormer(nn.Module): 13 | 14 | def __init__(self, num_model=512, num_head=8, num_feedforward=2048, 15 | num_encoder_layers=6, num_decoder_layers=6, 16 | dropout=0.1, activation="relu", normalize_before=True, 17 | num_points_per_line=50, num_lines=1000, point_dim=2, 18 | label_seq_length=2000, token=None, teacher_forcing_ratio=0, **kwargs): 19 | super(SurfaceFormer, self).__init__() 20 | 21 | self.num_model = num_model # E 22 | self.num_labels = label_seq_length # T 23 | self.teacher_forcing_ratio = teacher_forcing_ratio 24 | self.token = token 25 | self.num_token = token.len 26 | 27 | self.val_enc = VanillaEmedding(num_points_per_line * point_dim, num_model, token) 28 | 29 | # position encoding 30 | self.pos_enc = PositionEmbeddingLearned(num_model, max_len=num_lines+self.num_token) 31 | self.query_pos_enc = PositionEmbeddingLearned(num_model, max_len=label_seq_length) 32 | 33 | # vertex transformer encoder 34 | encoder_layers = TransformerEncoderLayer( 35 | num_model, num_head, num_feedforward, dropout, activation, normalize_before) 36 | encoder_norm = nn.LayerNorm(num_model) if normalize_before else None 37 | self.encoder = TransformerEncoder(encoder_layers, num_encoder_layers, encoder_norm) 38 | 39 | # wire transformer decoder 40 | decoder_layers = TransformerDecoderLayer( 41 | num_model, num_head, num_feedforward, dropout, activation, normalize_before) 42 | decoder_norm = nn.LayerNorm(num_model) 43 | self.decoder = TransformerDecoder(decoder_layers, num_decoder_layers, decoder_norm) 44 | 45 | self.project = nn.Linear(num_model, num_model) 46 | 47 | self._reset_parameters() 48 | 49 | def _reset_parameters(self): 50 | for name, param in self.named_parameters(): 51 | if param.dim() > 1: 52 | nn.init.xavier_uniform_(param) 53 | 54 | def get_embeddings(self, input, label): 55 | val_embed = self.val_enc(input) 56 | pos_embed = self.pos_enc(val_embed) 57 | query_pos_embed = self.query_pos_enc(label) 58 | 59 | return val_embed, pos_embed, query_pos_embed 60 | 61 | def process_masks(self, input_mask, tgt_mask=None): 62 | # pad input mask 63 | padding_mask = torch.zeros((len(input_mask), self.num_token), device=input_mask.device).type_as(input_mask) 64 | input_mask = torch.cat([padding_mask, input_mask], dim=1) 65 | if tgt_mask is None: 66 | return input_mask 67 | # tgt is 1 shorter 68 | tgt_mask = tgt_mask[:, :-1].contiguous() 69 | return input_mask, tgt_mask 70 | 71 | def generate_square_subsequent_mask(self, sz): 72 | mask = (1 - torch.tril(torch.ones(sz, sz))) == 1 73 | return mask 74 | 75 | def patch_source(self, src, pos): 76 | src = src.transpose(0, 1) 77 | pos = pos.transpose(0, 1) 78 | return src, pos 79 | 80 | def patch_target(self, tgt, pos): 81 | tgt = tgt.transpose(0, 1) 82 | tgt, label = tgt[:-1].contiguous(), tgt[1:].contiguous() 83 | pos = pos.transpose(0, 1) 84 | pos = pos[:-1].contiguous() 85 | return tgt, label, pos 86 | 87 | def mix_gold_sampled(self, gold_target, sampled_target, prob): 88 | sampled_target = torch.cat((gold_target[0:1], sampled_target[:-1]), dim=0) 89 | 90 | targets = torch.stack((gold_target, sampled_target)) 91 | 92 | random = torch.rand(gold_target.shape, device=targets.device) 93 | index = (random < prob).long().unsqueeze(0) 94 | 95 | new_target = torch.gather(targets, 0, index) 96 | return new_target.squeeze(0) 97 | 98 | def forward_train(self, inputs, scheduled_sampling_ratio=0): 99 | # inputs: N x L x P x D 100 | input, input_mask = inputs['input'], inputs['input_mask'] 101 | label, 
label_mask = inputs['label'], inputs['label_mask'] 102 | 103 | # process masks 104 | input_mask, label_mask = self.process_masks(input_mask, label_mask) 105 | 106 | # embeddings: N x L x E, L+=4 107 | val_embed, pos_embed, query_pos_embed = self.get_embeddings(input, label) 108 | 109 | # prepare data: L x N x E 110 | source, pos_embed = self.patch_source(val_embed, pos_embed) 111 | target, label, query_pos_embed = self.patch_target(label, query_pos_embed) 112 | 113 | # encoder: L x N x E 114 | memory = self.encoder(source, src_key_padding_mask=input_mask, pos=pos_embed) 115 | 116 | # feature gather: T x N x E 117 | tgt_mask = self.generate_square_subsequent_mask(target.size(0)).to(target.device) 118 | 119 | # T x N 120 | gold_target = target 121 | 122 | if scheduled_sampling_ratio > 0: 123 | with torch.no_grad(): 124 | target = target.unsqueeze(-1).repeat(1, 1, self.num_model) 125 | 126 | tgt = torch.gather(memory, 0, target) 127 | 128 | # decoder: T x N x E 129 | pointer = self.decoder(tgt, memory, tgt_mask=tgt_mask, pos=pos_embed, query_pos=query_pos_embed, 130 | tgt_key_padding_mask=label_mask, memory_key_padding_mask=input_mask) 131 | 132 | pointer = self.project(pointer) 133 | 134 | logits = torch.bmm(memory.transpose(0, 1), pointer.permute(1, 2, 0)) 135 | 136 | logits = logits.masked_fill(input_mask.unsqueeze(-1), min_value_of_dtype(logits.dtype)) 137 | 138 | sampled_target = torch.argmax(logits, dim=1).transpose(0, 1) 139 | 140 | target = self.mix_gold_sampled(gold_target, sampled_target, scheduled_sampling_ratio) 141 | 142 | 143 | target = target.unsqueeze(-1).repeat(1, 1, self.num_model) 144 | # memory: L x N x E 145 | # target: T x N x E 146 | # tgt: T x N x E 147 | tgt = torch.gather(memory, 0, target) # selects the targeting edge features 148 | 149 | # decoder: T x N x E 150 | pointer = self.decoder(tgt, memory, tgt_mask=tgt_mask, pos=pos_embed, query_pos=query_pos_embed, 151 | tgt_key_padding_mask=label_mask, memory_key_padding_mask=input_mask) 152 | 153 | pointer = self.project(pointer) 154 | 155 | # outputs 156 | inputs['embedding'] = memory.transpose(0, 1) 157 | inputs['pointer'] = pointer.transpose(0, 1) 158 | inputs['label'] = label.transpose(0, 1) 159 | return inputs 160 | 161 | def select_next(self, embedding, pointer, input_mask): 162 | embedding = embedding.transpose(0, 1) 163 | pointer = pointer.permute(1, 2, 0) 164 | logit = torch.bmm(embedding, pointer[..., -1:]) 165 | logit = logit.masked_fill(input_mask.unsqueeze(-1), min_value_of_dtype(logit.dtype)) 166 | next_token = torch.argmax(logit, dim=1).transpose(0, 1) 167 | return next_token 168 | 169 | def forward_eval(self, inputs): 170 | # inputs 171 | input, input_mask = inputs['input'], inputs['input_mask'] 172 | label = inputs['label'] 173 | 174 | batch_size = input.size(0) 175 | 176 | # process masks 177 | input_mask = self.process_masks(input_mask) 178 | 179 | # vertex embedding: N x L x E, L+=4 180 | val_embed, pos_embed, query_pos_embed = self.get_embeddings(input, label) 181 | 182 | # prepare data: L x N x E 183 | source, pos_embed = self.patch_source(val_embed, pos_embed) 184 | query_pos_embed = query_pos_embed.transpose(0, 1) 185 | 186 | # encoder: L x N x E 187 | memory = self.encoder(source, src_key_padding_mask=input_mask, pos=pos_embed) 188 | 189 | # T x N 190 | predicts = torch.full((1, batch_size), self.token.SOS, dtype=torch.long).to(source.device) 191 | EOS_found = 0 192 | 193 | for step in range(self.num_labels - 1): 194 | target = predicts.unsqueeze(-1).repeat(1, 1, self.num_model) 195 | 196 | tgt 
= torch.gather(memory, 0, target) 197 | 198 | # decoder: T x N x E 199 | pointer = self.decoder(tgt, memory, memory_key_padding_mask=input_mask, 200 | pos=pos_embed, query_pos=query_pos_embed[:step+1]) 201 | 202 | pointer = self.project(pointer) 203 | 204 | next_token = self.select_next(memory, pointer, input_mask) 205 | 206 | predicts = torch.cat((predicts, next_token), dim=0) 207 | EOS_found += next_token.eq(self.token.EOS).sum().item() 208 | 209 | if EOS_found == batch_size: 210 | break 211 | 212 | # pad predict to self.num_labels 213 | predicts = torch.cat((predicts, torch.zeros(self.num_labels - predicts.size(0), predicts.size(1)).type_as(predicts)), dim=0) 214 | 215 | 216 | inputs['embedding'] = memory.transpose(0, 1) 217 | inputs['pointer'] = pointer.transpose(0, 1) 218 | inputs['predict'] = predicts.transpose(0, 1) 219 | return inputs 220 | 221 | def forward(self, inputs): 222 | """ 223 | inputs: 224 | input (N x L x P x D), 225 | label (N x T), 226 | input_mask (N x L), 227 | label_mask (N x T) 228 | outputs: 229 | embedding (N x L x E), 230 | pointer (N x T x E), 231 | label (N x T) 232 | """ 233 | if self.training: 234 | outputs = self.forward_train(inputs) 235 | else: 236 | outputs = self.forward_eval(inputs) 237 | return outputs 238 | -------------------------------------------------------------------------------- /faceformer/models/model_para.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from faceformer.embedding import PositionEmbeddingLearned, VanillaEmedding 5 | from faceformer.transformer import (TransformerDecoder, 6 | TransformerDecoderLayer, 7 | TransformerEncoder, 8 | TransformerEncoderLayer) 9 | from faceformer.utils import min_value_of_dtype 10 | 11 | 12 | class SurfaceFormer_Parallel(nn.Module): 13 | 14 | def __init__(self, num_model=512, num_head=8, num_feedforward=2048, 15 | num_encoder_layers=6, num_decoder_layers=6, 16 | dropout=0.1, activation="relu", normalize_before=True, 17 | num_points_per_line=50, num_lines=64, point_dim=2, 18 | max_face_length=10, token=None, 19 | teacher_forcing_ratio=0, **kwargs): 20 | super(SurfaceFormer_Parallel, self).__init__() 21 | 22 | self.num_model = num_model # E 23 | self.max_face_length = max_face_length # T 24 | self.teacher_forcing_ratio = teacher_forcing_ratio 25 | self.token = token 26 | self.num_token = token.len 27 | 28 | self.val_enc = VanillaEmedding(num_points_per_line * point_dim, num_model, token) 29 | 30 | # position encoding 31 | self.pos_enc = PositionEmbeddingLearned(num_model, max_len=num_lines+self.num_token) 32 | self.query_pos_enc = PositionEmbeddingLearned(num_model, max_len=max_face_length) 33 | 34 | # vertex transformer encoder 35 | encoder_layers = TransformerEncoderLayer( 36 | num_model, num_head, num_feedforward, dropout, activation, normalize_before) 37 | encoder_norm = nn.LayerNorm(num_model) if normalize_before else None 38 | self.encoder = TransformerEncoder(encoder_layers, num_encoder_layers, encoder_norm) 39 | 40 | # wire transformer decoder 41 | decoder_layers = TransformerDecoderLayer( 42 | num_model, num_head, num_feedforward, dropout, activation, normalize_before) 43 | decoder_norm = nn.LayerNorm(num_model) 44 | self.decoder = TransformerDecoder(decoder_layers, num_decoder_layers, decoder_norm) 45 | 46 | self.project = nn.Linear(num_model, num_model) 47 | 48 | self._reset_parameters() 49 | 50 | def _reset_parameters(self): 51 | for _, param in self.named_parameters(): 52 | if param.dim() > 1: 53 | 
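# Xavier init is applied to weight matrices only; 1-D parameters
# (biases, LayerNorm weights) keep their PyTorch defaults.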
nn.init.xavier_uniform_(param) 54 | 55 | def get_embeddings(self, input, label): 56 | val_embed = self.val_enc(input) 57 | pos_embed = self.pos_enc(val_embed) 58 | query_pos_embed = self.query_pos_enc(label.transpose(1, 2)) 59 | 60 | return val_embed, pos_embed, query_pos_embed 61 | 62 | def process_masks(self, input_mask, tgt_mask=None): 63 | # pad input mask 64 | padding_mask = torch.zeros((len(input_mask), self.num_token)).type_as(input_mask) 65 | input_mask = torch.cat([padding_mask, input_mask], dim=1) 66 | if tgt_mask is None: 67 | return input_mask 68 | # tgt is 1 shorter than label (last token dropped) 69 | tgt_mask = tgt_mask[..., :-1].contiguous() 70 | return input_mask, tgt_mask 71 | 72 | def generate_square_subsequent_mask(self, sz): 73 | mask = (1 - torch.tril(torch.ones(sz, sz))) == 1 74 | return mask 75 | 76 | def patch_source(self, src, pos): 77 | src = src.transpose(0, 1) 78 | pos = pos.transpose(0, 1) 79 | return src, pos 80 | 81 | def patch_target(self, tgt, pos): 82 | tgt = tgt.permute(2, 0, 1) 83 | tgt, label = tgt[:-1].contiguous(), tgt[1:].contiguous() 84 | pos = pos.transpose(0, 1) 85 | pos = pos[:-1].contiguous() 86 | return tgt, label, pos 87 | 88 | def mix_gold_sampled(self, gold_target, sampled_target, prob): 89 | sampled_target = torch.cat((gold_target[0:1], sampled_target[:-1]), dim=0) 90 | 91 | targets = torch.stack((gold_target, sampled_target)) 92 | 93 | random = torch.rand(gold_target.shape, device=targets.device) 94 | index = (random < prob).long().unsqueeze(0) 95 | 96 | new_target = torch.gather(targets, 0, index) 97 | return new_target.squeeze(0) 98 | 99 | def forward_train(self, inputs, scheduled_sampling_ratio=0): 100 | # inputs: N x L x P x D 101 | # target: N x F x T 102 | input, input_mask = inputs['input'], inputs['input_mask'] 103 | label, label_mask = inputs['label'], inputs['label_mask'] 104 | max_num_edges = max(inputs['num_input']) 105 | label, label_mask = label[:, :max_num_edges, :], label_mask[:, :max_num_edges, :] 106 | 107 | # process masks 108 | input_mask, label_mask = self.process_masks(input_mask, label_mask) 109 | 110 | # embeddings: N x L x E, L+=4 111 | val_embed, pos_embed, query_pos_embed = self.get_embeddings(input, label) 112 | 113 | # prepare data: L x N x E 114 | source, pos_embed = self.patch_source(val_embed, pos_embed) 115 | # target: T x N x (F) 116 | target, label, query_pos_embed = self.patch_target(label, query_pos_embed) 117 | 118 | # encoder: L x N x E 119 | memory = self.encoder(source, src_key_padding_mask=input_mask, pos=pos_embed) 120 | 121 | # L x NF x E 122 | memory = memory.repeat_interleave(max_num_edges, 1) 123 | 124 | # causal target mask: T x T 125 | tgt_mask = self.generate_square_subsequent_mask(target.size(0)).type_as(input_mask) 126 | 127 | # T x N 128 | gold_target = target 129 | 130 | if scheduled_sampling_ratio > 0: 131 | with torch.no_grad(): 132 | target = target.unsqueeze(-1).repeat(1, 1, self.num_model) 133 | 134 | tgt = torch.gather(memory, 0, target) 135 | 136 | # decoder: T x N x E 137 | pointer = self.decoder(tgt, memory, tgt_mask=tgt_mask, pos=pos_embed, query_pos=query_pos_embed, 138 | tgt_key_padding_mask=label_mask, memory_key_padding_mask=input_mask) 139 | 140 | pointer = self.project(pointer) 141 | 142 | logits = torch.bmm(memory.transpose(0, 1), pointer.permute(1, 2, 0)) 143 | 144 | logits = logits.masked_fill(input_mask.unsqueeze(-1), min_value_of_dtype(logits.dtype)) 145 | 146 | sampled_target = torch.argmax(logits, dim=1).transpose(0, 1) 147 | 148 | target = self.mix_gold_sampled(gold_target, sampled_target,
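# scheduled sampling: each position keeps the gold token with
# probability 1 - ratio, otherwise it takes the model's own greedy
# prediction for the previous step (shifted right in mix_gold_sampled above)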
scheduled_sampling_ratio) 149 | 150 | # T x NF x E 151 | target = target.unsqueeze(-1).repeat(1, 1, 1, self.num_model).flatten(1, 2) 152 | # memory: L x NF x E 153 | # target: T x NF x E 154 | # tgt: T x NF x E 155 | tgt = torch.gather(memory, 0, target) # select the target edge features 156 | # NF x T 157 | label_mask = label_mask.flatten(0, 1) 158 | # NF x L 159 | input_mask = input_mask.repeat_interleave(max_num_edges, 0) 160 | 161 | # decoder: T x NF x E 162 | pointer = self.decoder(tgt, memory, tgt_mask=tgt_mask, pos=pos_embed, query_pos=query_pos_embed, 163 | tgt_key_padding_mask=label_mask, memory_key_padding_mask=input_mask) 164 | 165 | pointer = self.project(pointer) 166 | 167 | # outputs 168 | inputs['embedding'] = memory.transpose(0, 1) 169 | inputs['pointer'] = pointer.transpose(0, 1) 170 | inputs['label'] = label.flatten(1, 2).transpose(0, 1) 171 | return inputs 172 | 173 | def select_next(self, embedding, pointer, input_mask): 174 | embedding = embedding.transpose(0, 1) 175 | pointer = pointer.permute(1, 2, 0) 176 | logit = torch.bmm(embedding, pointer[..., -1:]) 177 | logit = logit.masked_fill(input_mask.unsqueeze(-1), min_value_of_dtype(logit.dtype)) 178 | next_token = torch.argmax(logit, dim=1).transpose(0, 1) 179 | return next_token 180 | 181 | def forward_eval(self, inputs): 182 | # inputs: N x L x P x D 183 | input, input_mask = inputs['input'], inputs['input_mask'] 184 | label = inputs['label'] 185 | 186 | batch_size = input.size(0) # N 187 | max_num_edges = max(inputs['num_input']) # L 188 | 189 | 190 | # process masks 191 | input_mask = self.process_masks(input_mask) 192 | 193 | # vertex embedding: N x L x E, L+=4 194 | val_embed, pos_embed, query_pos_embed = self.get_embeddings(input, label) 195 | 196 | # prepare data: L x N x E 197 | source, pos_embed = self.patch_source(val_embed, pos_embed) 198 | 199 | # use all edges as the first token 200 | # anchors: 1 x N x (F) 201 | anchors = torch.arange(max_num_edges).repeat(1, batch_size, 1).type_as(label) 202 | 203 | # set start tokens of unused face queries to EOS (the 'Other' face type) 204 | for i, num_edges in enumerate(inputs['num_input']): 205 | anchors[:, i, num_edges:] = self.token.len - 1 206 | query_pos_embed = query_pos_embed.transpose(0, 1) 207 | predicts = anchors.flatten(1, 2) # 1 x N(F) 208 | 209 | # encoder: L x N x E 210 | memory = self.encoder(source, src_key_padding_mask=input_mask, pos=pos_embed) 211 | # L x N(F) x E 212 | memory = memory.repeat_interleave(max_num_edges, 1) 213 | # N(F) x L 214 | input_mask = input_mask.repeat_interleave(max_num_edges, 0) 215 | 216 | for step in range(self.max_face_length - 1): 217 | target = predicts.unsqueeze(-1).repeat(1, 1, self.num_model) 218 | 219 | tgt = torch.gather(memory, 0, target) 220 | 221 | # decoder: T x N(F) x E 222 | pointer = self.decoder(tgt, memory, memory_key_padding_mask=input_mask, 223 | pos=pos_embed, query_pos=query_pos_embed[:step+1]) 224 | 225 | pointer = self.project(pointer) 226 | 227 | next_token = self.select_next(memory, pointer, input_mask) 228 | 229 | predicts = torch.cat((predicts, next_token), dim=0) 230 | 231 | # if all face tokens are EOS or PAD, stop decoding 232 | if torch.all(next_token < self.num_token): 233 | break 234 | 235 | # pad predict to self.max_face_length 236 | predicts = torch.cat((predicts, torch.zeros(self.max_face_length - predicts.size(0), predicts.size(1)).type_as(predicts)), dim=0) 237 | 238 | # predicts: T x N(F) 239 | 240 | inputs['predict'] = predicts.transpose(0, 1).view(-1, max_num_edges, self.max_face_length)
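# reshape the flat N(F) x T predictions back to one decoded
# index sequence per face query: N x F x T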
241 | return inputs 242 | 243 | def forward(self, inputs): 244 | """ 245 | inputs: 246 | input (N x L x P x D), 247 | label (N x F x T), 248 | input_mask (N x L), 249 | label_mask (N x F x T) 250 | outputs: 251 | embedding (NF x L x E), 252 | pointer (NF x T x E), 253 | label (N x F x T) 254 | """ 255 | if self.training: 256 | outputs = self.forward_train(inputs) 257 | else: 258 | outputs = self.forward_eval(inputs) 259 | return outputs 260 | -------------------------------------------------------------------------------- /faceformer/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | DETR Transformer class. 4 | 5 | Copy-paste from torch.nn.Transformer with modifications: 6 | * positional encodings are passed in MHattention 7 | * extra LN at the end of encoder is removed 8 | * decoder returns a stack of activations from all decoding layers 9 | """ 10 | import copy 11 | from typing import List, Optional 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch import Tensor, nn 16 | 17 | 18 | class Transformer(nn.Module): 19 | 20 | def __init__(self, num_model=512, num_head=8, num_encoder_layers=6, 21 | num_decoder_layers=6, num_feedforward=2048, dropout=0.1, 22 | activation="relu", normalize_before=False, 23 | return_intermediate_dec=False): 24 | super().__init__() 25 | 26 | encoder_layer = TransformerEncoderLayer(num_model, num_head, num_feedforward, 27 | dropout, activation, normalize_before) 28 | encoder_norm = nn.LayerNorm(num_model) if normalize_before else None 29 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 30 | 31 | decoder_layer = TransformerDecoderLayer(num_model, num_head, num_feedforward, 32 | dropout, activation, normalize_before) 33 | decoder_norm = nn.LayerNorm(num_model) 34 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, 35 | return_intermediate=return_intermediate_dec) 36 | 37 | self._reset_parameters() 38 | 39 | self.num_model = num_model 40 | self.num_head = num_head 41 | 42 | def _reset_parameters(self): 43 | for p in self.parameters(): 44 | if p.dim() > 1: 45 | nn.init.xavier_uniform_(p) 46 | 47 | def forward(self, src, mask, query_embed, pos_embed): 48 | # flatten NxCxHxW to HWxNxC 49 | bs, c, h, w = src.shape 50 | src = src.flatten(2).permute(2, 0, 1) 51 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 52 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 53 | mask = mask.flatten(1) 54 | 55 | tgt = torch.zeros_like(query_embed) 56 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 57 | hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, 58 | pos=pos_embed, query_pos=query_embed) 59 | return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) 60 | 61 | 62 | class TransformerEncoder(nn.Module): 63 | 64 | def __init__(self, encoder_layer, num_layers, norm=None): 65 | super().__init__() 66 | self.layers = _get_clones(encoder_layer, num_layers) 67 | self.num_layers = num_layers 68 | self.norm = norm 69 | 70 | def forward(self, src, 71 | mask: Optional[Tensor] = None, 72 | src_key_padding_mask: Optional[Tensor] = None, 73 | pos: Optional[Tensor] = None): 74 | output = src 75 | 76 | for layer in self.layers: 77 | output = layer(output, src_mask=mask, 78 | src_key_padding_mask=src_key_padding_mask, pos=pos) 79 | 80 | if self.norm is not None: 81 | output = self.norm(output) 82 | 83 | return output 84 | 
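# Illustrative shape check (an added sketch, not part of the original DETR
# code): documents the sequence-first (L x N x E) layout and the boolean
# key-padding masks that the SurfaceFormer models feed into this encoder.
# All sizes below are made up for the example.
def _encoder_shape_example():
    layer = TransformerEncoderLayer(num_model=512, num_head=8, normalize_before=True)
    encoder = TransformerEncoder(layer, num_layers=2, norm=nn.LayerNorm(512))
    src = torch.rand(110, 4, 512)                     # L x N x E edge tokens
    pos = torch.rand(110, 4, 512)                     # learned positions, same shape
    pad_mask = torch.zeros(4, 110, dtype=torch.bool)  # N x L, True marks padding
    memory = encoder(src, src_key_padding_mask=pad_mask, pos=pos)
    assert memory.shape == (110, 4, 512)              # memory stays L x N x E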
85 | 86 | class TransformerDecoder(nn.Module): 87 | 88 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 89 | super().__init__() 90 | self.layers = _get_clones(decoder_layer, num_layers) 91 | self.num_layers = num_layers 92 | self.norm = norm 93 | self.return_intermediate = return_intermediate 94 | 95 | def forward(self, tgt, memory, 96 | tgt_mask: Optional[Tensor] = None, 97 | memory_mask: Optional[Tensor] = None, 98 | tgt_key_padding_mask: Optional[Tensor] = None, 99 | memory_key_padding_mask: Optional[Tensor] = None, 100 | pos: Optional[Tensor] = None, 101 | query_pos: Optional[Tensor] = None): 102 | output = tgt 103 | 104 | intermediate = [] 105 | 106 | for layer in self.layers: 107 | output = layer(output, memory, tgt_mask=tgt_mask, 108 | memory_mask=memory_mask, 109 | tgt_key_padding_mask=tgt_key_padding_mask, 110 | memory_key_padding_mask=memory_key_padding_mask, 111 | pos=pos, query_pos=query_pos) 112 | if self.return_intermediate: 113 | intermediate.append(self.norm(output)) 114 | 115 | if self.norm is not None: 116 | output = self.norm(output) 117 | if self.return_intermediate: 118 | intermediate.pop() 119 | intermediate.append(output) 120 | 121 | if self.return_intermediate: 122 | return torch.stack(intermediate) 123 | 124 | return output 125 | 126 | 127 | class TransformerEncoderLayer(nn.Module): 128 | 129 | def __init__(self, num_model, num_head, num_feedforward=2048, dropout=0.1, 130 | activation="relu", normalize_before=False): 131 | super().__init__() 132 | self.self_attn = nn.MultiheadAttention(num_model, num_head, dropout=dropout) 133 | # Implementation of Feedforward model 134 | self.linear1 = nn.Linear(num_model, num_feedforward) 135 | self.dropout = nn.Dropout(dropout) 136 | self.linear2 = nn.Linear(num_feedforward, num_model) 137 | 138 | self.norm1 = nn.LayerNorm(num_model) 139 | self.norm2 = nn.LayerNorm(num_model) 140 | self.dropout1 = nn.Dropout(dropout) 141 | self.dropout2 = nn.Dropout(dropout) 142 | 143 | self.activation = _get_activation_fn(activation) 144 | self.normalize_before = normalize_before 145 | 146 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 147 | return tensor if pos is None else tensor + pos 148 | 149 | def forward_post(self, 150 | src, 151 | src_mask: Optional[Tensor] = None, 152 | src_key_padding_mask: Optional[Tensor] = None, 153 | pos: Optional[Tensor] = None): 154 | q = k = self.with_pos_embed(src, pos) 155 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 156 | key_padding_mask=src_key_padding_mask)[0] 157 | src = src + self.dropout1(src2) 158 | src = self.norm1(src) 159 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 160 | src = src + self.dropout2(src2) 161 | src = self.norm2(src) 162 | return src 163 | 164 | def forward_pre(self, src, 165 | src_mask: Optional[Tensor] = None, 166 | src_key_padding_mask: Optional[Tensor] = None, 167 | pos: Optional[Tensor] = None): 168 | src2 = self.norm1(src) 169 | q = k = self.with_pos_embed(src2, pos) 170 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 171 | key_padding_mask=src_key_padding_mask)[0] 172 | src = src + self.dropout1(src2) 173 | src2 = self.norm2(src) 174 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 175 | src = src + self.dropout2(src2) 176 | return src 177 | 178 | def forward(self, src, 179 | src_mask: Optional[Tensor] = None, 180 | src_key_padding_mask: Optional[Tensor] = None, 181 | pos: Optional[Tensor] = None): 182 | if self.normalize_before: 183 | return 
self.forward_pre(src, src_mask, src_key_padding_mask, pos) 184 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 185 | 186 | 187 | class TransformerDecoderLayer(nn.Module): 188 | 189 | def __init__(self, num_model, num_head, num_feedforward=2048, dropout=0.1, 190 | activation="relu", normalize_before=False): 191 | super().__init__() 192 | self.self_attn = nn.MultiheadAttention(num_model, num_head, dropout=dropout) 193 | self.multihead_attn = nn.MultiheadAttention(num_model, num_head, dropout=dropout) 194 | # Implementation of Feedforward model 195 | self.linear1 = nn.Linear(num_model, num_feedforward) 196 | self.dropout = nn.Dropout(dropout) 197 | self.linear2 = nn.Linear(num_feedforward, num_model) 198 | 199 | self.norm1 = nn.LayerNorm(num_model) 200 | self.norm2 = nn.LayerNorm(num_model) 201 | self.norm3 = nn.LayerNorm(num_model) 202 | self.dropout1 = nn.Dropout(dropout) 203 | self.dropout2 = nn.Dropout(dropout) 204 | self.dropout3 = nn.Dropout(dropout) 205 | 206 | self.activation = _get_activation_fn(activation) 207 | self.normalize_before = normalize_before 208 | 209 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 210 | return tensor if pos is None else tensor + pos 211 | 212 | def forward_post(self, tgt, memory, 213 | tgt_mask: Optional[Tensor] = None, 214 | memory_mask: Optional[Tensor] = None, 215 | tgt_key_padding_mask: Optional[Tensor] = None, 216 | memory_key_padding_mask: Optional[Tensor] = None, 217 | pos: Optional[Tensor] = None, 218 | query_pos: Optional[Tensor] = None): 219 | q = k = self.with_pos_embed(tgt, query_pos) 220 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 221 | key_padding_mask=tgt_key_padding_mask)[0] 222 | tgt = tgt + self.dropout1(tgt2) 223 | tgt = self.norm1(tgt) 224 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 225 | key=self.with_pos_embed(memory, pos), 226 | value=memory, attn_mask=memory_mask, 227 | key_padding_mask=memory_key_padding_mask)[0] 228 | tgt = tgt + self.dropout2(tgt2) 229 | tgt = self.norm2(tgt) 230 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 231 | tgt = tgt + self.dropout3(tgt2) 232 | tgt = self.norm3(tgt) 233 | return tgt 234 | 235 | def forward_pre(self, tgt, memory, 236 | tgt_mask: Optional[Tensor] = None, 237 | memory_mask: Optional[Tensor] = None, 238 | tgt_key_padding_mask: Optional[Tensor] = None, 239 | memory_key_padding_mask: Optional[Tensor] = None, 240 | pos: Optional[Tensor] = None, 241 | query_pos: Optional[Tensor] = None): 242 | tgt2 = self.norm1(tgt) 243 | q = k = self.with_pos_embed(tgt2, query_pos) 244 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 245 | key_padding_mask=tgt_key_padding_mask)[0] 246 | tgt = tgt + self.dropout1(tgt2) 247 | tgt2 = self.norm2(tgt) 248 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 249 | key=self.with_pos_embed(memory, pos), 250 | value=memory, attn_mask=memory_mask, 251 | key_padding_mask=memory_key_padding_mask)[0] 252 | tgt = tgt + self.dropout2(tgt2) 253 | tgt2 = self.norm3(tgt) 254 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 255 | tgt = tgt + self.dropout3(tgt2) 256 | return tgt 257 | 258 | def forward(self, tgt, memory, 259 | tgt_mask: Optional[Tensor] = None, 260 | memory_mask: Optional[Tensor] = None, 261 | tgt_key_padding_mask: Optional[Tensor] = None, 262 | memory_key_padding_mask: Optional[Tensor] = None, 263 | pos: Optional[Tensor] = None, 264 | query_pos: Optional[Tensor] = None): 265 | if self.normalize_before: 266 
| return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 267 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 268 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 269 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 270 | 271 | 272 | def _get_clones(module, N): 273 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 274 | 275 | 276 | def _get_activation_fn(activation): 277 | """Return an activation function given a string""" 278 | if activation == "relu": 279 | return F.relu 280 | if activation == "gelu": 281 | return F.gelu 282 | if activation == "glu": 283 | return F.glu 284 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 285 | -------------------------------------------------------------------------------- /faceformer/trainer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | from collections import Counter 5 | 6 | import numpy as np 7 | import pytorch_lightning as pl 8 | import torch 9 | import torch.nn.functional as F 10 | from numpyencoder import NumpyEncoder 11 | 12 | from faceformer.post_processing import (filter_faces_by_encloseness, map_coedge_into_edges) 13 | from faceformer.utils import flatten_list 14 | 15 | 16 | class Trainer(pl.LightningModule): 17 | def __init__(self, hparams, model_class, dataset_class): 18 | super().__init__() 19 | self.save_hyperparameters(hparams) 20 | self.model = model_class(**self.hparams.model) 21 | self.dataset_class = dataset_class 22 | self.validation_num = 0 23 | self.time_count = 0 24 | self.total_time = 0 25 | # self.hparams.root_dir = "ours" 26 | 27 | def forward(self, batch): 28 | return self.model(batch) 29 | 30 | def train_dataloader(self): 31 | dataset = self.dataset_class(self.hparams.root_dir, self.hparams.datasets_train, self.hparams.model) 32 | dataloader = torch.utils.data.DataLoader( 33 | dataset, batch_size=self.hparams.batch_size_train, num_workers=4, 34 | shuffle=True, drop_last=True) 35 | return dataloader 36 | 37 | def val_dataloader(self): 38 | name_dir = os.path.join(self.logger.log_dir, self.hparams.trainer.version) 39 | if not os.path.exists(name_dir): 40 | os.mkdir(name_dir) 41 | dataset = self.dataset_class(self.hparams.root_dir, self.hparams.datasets_valid, self.hparams.model) 42 | self.dataset = dataset 43 | dataloader = torch.utils.data.DataLoader( 44 | dataset, batch_size=self.hparams.batch_size_valid, num_workers=4, 45 | shuffle=False, drop_last=False) 46 | return dataloader 47 | 48 | def test_dataloader(self): 49 | dataset = self.dataset_class(self.hparams.root_dir, self.hparams.datasets_test, self.hparams.model) 50 | self.dataset = dataset 51 | self.hparams.batch_size_valid = 1 52 | dataloader = torch.utils.data.DataLoader( 53 | dataset, batch_size=self.hparams.batch_size_valid, num_workers=4, 54 | shuffle=False, drop_last=False) 55 | json_dir = os.path.join(self.logger.log_dir, 'json') 56 | if not os.path.exists(json_dir): 57 | os.mkdir(json_dir) 58 | return dataloader 59 | 60 | def compute_loss(self, outputs): 61 | embedding, pointer, labels = outputs['embedding'], outputs['pointer'], outputs['label'] 62 | 63 | # embedding N x L x E, pointer N x T x E 64 | # logits: N x L x T 65 | logits = torch.bmm(embedding, pointer.transpose(1, 2)) 66 | 67 | #label: N x T 68 | labels = labels.detach().clone() 69 | loss = F.cross_entropy( 70 | logits, labels, ignore_index=self.hparams.model.token.PAD, reduction='sum') 71 | 72 | valid = labels != 
self.hparams.model.token.PAD 73 | valid_sum = valid.sum() 74 | pred = torch.argmax(logits, dim=1) 75 | outputs['predict'] = pred 76 | acc_sum = (valid * (pred == labels)).sum() 77 | cls_acc = float(acc_sum) / (valid_sum + 1e-10) 78 | 79 | loss = loss / valid_sum 80 | return loss, cls_acc 81 | 82 | def training_step(self, batch, batch_idx): 83 | outputs = self.forward(batch) 84 | loss, acc = self.compute_loss(outputs) 85 | self.log('train_loss', loss, logger=True) 86 | self.log('train_cls_acc', acc, prog_bar=True, logger=True) 87 | if torch.isnan(loss): 88 | return None 89 | return loss 90 | 91 | def validation_step(self, batch, batch_idx): 92 | outputs = self.forward(batch) 93 | acc, outputs = self.face_accuracy(outputs) 94 | 95 | self.log('valid_accuracy', np.mean(outputs['accuracy']), logger=True) 96 | self.log('valid_type_acc_coedge_seq', np.mean(outputs['type_acc_coedge_seq']), logger=True) 97 | self.log('valid_precision', np.mean(outputs['precisions']), logger=True) 98 | self.log('valid_recall', np.mean(outputs['recalls']), logger=True) 99 | self.log('type_acc', np.mean(outputs['type_acc']), logger=True) 100 | for pred, label, prec in zip(outputs['predictions'], outputs['labels'], outputs['precisions']): 101 | self.logger.experiment.add_text('result', f'pred: {pred} \n\n label: {label} \n\n precision: {prec}', self.validation_num) 102 | # return outputs 103 | 104 | # saves all necessary info for reconstruction 105 | def test_step(self, batch, batch_idx): 106 | torch.cuda.synchronize() 107 | a = time.time() 108 | outputs = self.forward(batch) 109 | torch.cuda.synchronize() 110 | self.total_time += time.time() - a 111 | self.time_count += 1 112 | print("Avg Time", self.total_time / self.time_count, "seconds.") 113 | acc, outputs = self.face_accuracy(outputs) 114 | 115 | self.log('test_precision', np.mean(outputs['precisions']), logger=True) 116 | self.log('test_recall', np.mean(outputs['recalls']), logger=True) 117 | self.log('test_type_acc', np.mean(outputs['type_acc']), logger=True) 118 | for ind in range(len(outputs['predictions'])): 119 | predict_faces_w_types = outputs['predictions'][ind] 120 | label_faces_w_types = outputs['labels'][ind] 121 | json_name = batch['name'][ind] 122 | name = json_name[5:13] 123 | 124 | with open(os.path.join(self.hparams.root_dir, json_name), "r") as f: 125 | raw_data = json.loads(f.read()) 126 | edges = raw_data['edges'] 127 | 128 | predicted_data = {} 129 | predicted_data['edges'] = edges 130 | predicted_data['dominant_directions'] = raw_data['dominant_directions'] 131 | predicted_data['pred_faces'] = predict_faces_w_types 132 | predicted_data['label_faces'] = label_faces_w_types 133 | 134 | 135 | with open(os.path.join(self.logger.log_dir, 'json', f'{name}.json'), 'w') as f: 136 | json.dump(predicted_data, f, cls=NumpyEncoder) 137 | 138 | def validation_epoch_end(self, outputs): 139 | self.validation_num += 1 140 | 141 | def configure_optimizers(self): 142 | optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.trainer.lr) 143 | if self.hparams.trainer.lr_step == 0: 144 | return optimizer 145 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, self.hparams.trainer.lr_step) 146 | return { 147 | 'optimizer': optimizer, 148 | 'lr_scheduler': scheduler 149 | } 150 | 151 | # Parse single-sequence faces 152 | # Return (0, faces' indices) to be consistent with face_typing task 153 | def parse_faces(self, predicts, labels, num_edges): 154 | # cut off tokens after the [EOS] 155 | label = np.split(labels, np.where(labels == 
self.hparams.model.token.EOS)[0]+1)[0] 156 | predict = np.split(predicts, np.where(predicts == self.hparams.model.token.EOS)[0] + 1)[0] 157 | 158 | label_faces = np.split(label, np.where(label == self.hparams.model.token.SEP)[0]+1) # split by SEP 159 | constructed_label_faces = [] 160 | for face in label_faces: 161 | label = face[:-1] - self.hparams.model.token.len # remove SEP and remove token offset 162 | label = label[label >= 0] 163 | label = label[label < num_edges] 164 | if len(label) > 0: 165 | constructed_label_faces.append((0, tuple(label.tolist()))) 166 | 167 | predict_faces = np.split(predict, np.where(predict == self.hparams.model.token.SEP)[0]+1) # split by SEP 168 | constructed_predict_faces = [] 169 | for face in predict_faces: 170 | if len(face) > 1: 171 | predict = face[:-1] - self.hparams.model.token.len # remove SEP and remove token offset 172 | predict = predict[predict >= 0] 173 | predict = predict[predict < num_edges] 174 | if len(predict) > 0: 175 | constructed_predict_faces.append((0, tuple(predict.tolist()))) 176 | 177 | return constructed_predict_faces, constructed_label_faces 178 | 179 | # Parse multi-sequence faces 180 | # Return faces' indices 181 | def parse_parallel_faces(self, predicts, labels, num_edges): 182 | predict_faces, label_faces = [], [] 183 | # cut off tokens after the [EOS] (Type of face in this case) 184 | for label in labels: 185 | label = np.split(label, np.where((label >= self.hparams.model.token.face_type_offset) & (label < self.hparams.model.token.len))[0]+1)[0] 186 | # extract face type 187 | face_type = label[-1] - self.hparams.model.token.face_type_offset 188 | # remove token offset 189 | label -= self.hparams.model.token.len 190 | # only take the valid indices 191 | label = label[label >= 0] 192 | if len(label) > 0: 193 | # only count the face if face is not empty and not full of paddings 194 | label_faces.append((face_type, tuple(label.tolist()))) 195 | 196 | for predict in predicts: 197 | predict = np.split(predict, np.where((predict >= self.hparams.model.token.face_type_offset) & (predict < self.hparams.model.token.len))[0]+1)[0] 198 | # extract face type 199 | face_type = predict[-1] - self.hparams.model.token.face_type_offset 200 | # remove token offset 201 | predict -= self.hparams.model.token.len 202 | # only take the valid indices 203 | predict = predict[predict >= 0] 204 | predict = predict[predict < num_edges] 205 | if len(predict) > 0: 206 | predict_faces.append((face_type, tuple(predict.tolist()))) 207 | 208 | return predict_faces, label_faces 209 | 210 | def face_accuracy(self, outputs): 211 | labels = outputs['label'].cpu().numpy() # N (x F) x T 212 | predicts = outputs['predict'].cpu().numpy() # N (x L) x T 213 | 214 | outputs.update({'precisions': [], 'labels': [], 'type_acc_coedge_seq': [], \ 215 | 'recalls': [], 'predictions': [], 'accuracy': [], 'type_acc':[]}) 216 | 217 | for ind in range(len(labels)): 218 | edges = self.dataset.raw_datas[outputs['id'][ind]]['edges'] 219 | if len(labels.shape) == 3: 220 | # multi-seq, parallel 221 | predict_faces, label_faces = self.parse_parallel_faces(predicts[ind], labels[ind], len(edges)) 222 | else: 223 | # single-seq 224 | predict_faces, label_faces = self.parse_faces(predicts[ind], labels[ind], len(edges)) 225 | 226 | if self.hparams.post_process.is_coedge: 227 | pairings = self.dataset.raw_datas[outputs['id'][ind]]['pairings'] 228 | 229 | predict_faces = filter_faces_by_encloseness(edges, predict_faces, self.hparams.post_process.enclosedness_tol) 230 | label_faces = 
filter_faces_by_encloseness(edges, label_faces, self.hparams.post_process.enclosedness_tol) 231 | 232 | # calculate accuracy for faces with coedge 233 | # consider accuracy as the percent of predictions made correct 234 | face_tp = 0 235 | type_tp = 0 236 | for pred_type, pred_face in predict_faces: 237 | for label_type, label_face in set(label_faces): 238 | if pred_face == label_face: 239 | face_tp += 1 240 | if pred_type == label_type: 241 | type_tp += 1 242 | break 243 | 244 | if len(predict_faces) == 0: 245 | outputs['accuracy'].append(0) 246 | outputs['type_acc_coedge_seq'].append(0) 247 | else: 248 | outputs['accuracy'].append(face_tp / len(predict_faces)) 249 | if face_tp == 0: 250 | outputs['type_acc_coedge_seq'].append(0) 251 | else: 252 | outputs['type_acc_coedge_seq'].append(type_tp / face_tp) 253 | # map coedge into edges 254 | label_faces = [(ftype, map_coedge_into_edges(pairings, flatten_list(loops))) for ftype, loops in label_faces] 255 | predict_faces = [(ftype, map_coedge_into_edges(pairings, flatten_list(loops))) for ftype, loops in predict_faces] 256 | 257 | # filter duplicate label faces 258 | label_faces_set = list(set([(ftype, tuple(sorted(set(indices)))) for ftype, indices in label_faces])) 259 | 260 | # determine face type by majority vote 261 | predict_unique_faces = {} 262 | for ftype, indices in predict_faces: 263 | face = tuple(sorted(set(indices))) 264 | if face in predict_unique_faces: 265 | predict_unique_faces[face].append(ftype) 266 | else: 267 | predict_unique_faces[face] = [ftype] 268 | 269 | predict_faces_set = [(Counter(ftypes).most_common(1)[0][0], face) for face, ftypes in predict_unique_faces.items()] 270 | 271 | # count TP 272 | face_tp = 0 273 | type_tp = 0 274 | for pred_type, pred_face in predict_faces_set: 275 | for label_type, label_face in label_faces_set: 276 | if pred_face == label_face: 277 | face_tp += 1 278 | if pred_type == label_type: 279 | type_tp += 1 280 | break 281 | 282 | if len(predict_faces_set) == 0 or len(label_faces_set) == 0: 283 | outputs['precisions'].append(0) 284 | outputs['recalls'].append(0) 285 | outputs['type_acc'].append(0) 286 | else: 287 | outputs['precisions'].append(face_tp / len(predict_faces_set)) 288 | outputs['recalls'].append(face_tp / len(label_faces_set)) 289 | if face_tp == 0: 290 | outputs['type_acc'].append(0) 291 | else: 292 | outputs['type_acc'].append(type_tp / face_tp) 293 | outputs['predictions'].append(predict_faces_set) 294 | outputs['labels'].append(label_faces_set) 295 | 296 | # with first token removed, we only look at non-padded elements 297 | valid = labels > self.hparams.model.token.PAD 298 | acc_sum = (valid * (predicts == labels)).sum() 299 | valid_sum = valid.sum() 300 | return acc_sum/valid_sum, outputs 301 | 302 | 303 | -------------------------------------------------------------------------------- /reconstruction/reconstruct_to_wireframe.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import os 5 | from functools import partial 6 | 7 | import cvxpy as cp 8 | import numpy as np 9 | import open3d as o3d 10 | from faceformer.utils import flatten_list 11 | from tqdm.contrib.concurrent import process_map 12 | 13 | from reconstruction.reconstruction_utils import (construct_connected_cylinder, 14 | dist, fit_curve, 15 | is_straight_line) 16 | 17 | INTERMEDIATE_TYPE = 11 # the 4 extra faces added to each cylinder, not used in final reconstruction 18 | 19 | def sample_points_on_line(line, 
sample_dist): 20 | x1, y1, x2, y2 = line[0][0], line[0][1], line[1][0], line[1][1] 21 | num_samples = int(np.sqrt((x1-x2)**2+(y1-y2)**2) / sample_dist) + 1 22 | t = np.linspace(0, 1, num_samples) 23 | x = x1 + (x2-x1) * t 24 | y = y1 + (y2-y1) * t 25 | return np.vstack([x, y]).T 26 | 27 | def reconstruct_file(name, root): 28 | try: 29 | if os.path.exists(os.path.join(root, 'ply', f"{name}.ply")): 30 | return 31 | data = json.load(open(os.path.join(root, 'json', f'{name}.json'))) 32 | num_faces = len(data['pred_faces']) 33 | num_edges = len(data['edges']) 34 | 35 | to_add_new_planes = [] 36 | to_add_new_edges = [] 37 | face_removal_indices = [] 38 | circle_face_to_construct = [] # should contain indices of two outlines 39 | circle_face_to_construct_dir = [] # should contain the direction of the outlines 40 | 41 | dom_directions = [np.array(d[:2]) / np.linalg.norm(d[:2]) for d in data['dominant_directions']] 42 | face_to_normal = {} 43 | 44 | # remove faces predicted as type 'other' (neither plane nor cylinder) 45 | for i, (face_type, loops) in enumerate(data['pred_faces']): 46 | if face_type not in [0, 1]: 47 | face_removal_indices.append(i) 48 | continue 49 | 50 | # for each cylinder face, we construct two new plane faces 51 | # always select the mid point for plane reconstruction 52 | if face_type == 1: 53 | face_removal_indices.append(i) 54 | 55 | # cylinder face should have two curves and two straight lines 56 | all_edge_inds = list(loops) 57 | all_edges = [data['edges'][i] for i in all_edge_inds] 58 | # if more than two straight lines, not a cylinder, skip 59 | count = 0 60 | for edge in all_edges: 61 | if is_straight_line(edge): 62 | count += 1 63 | if count != 2: 64 | print(f"{name} has {count} straight lines, not a cylinder") 65 | continue 66 | try: 67 | all_edges, all_edge_inds, all_dirs = construct_connected_cylinder(all_edges, all_edge_inds) 68 | except: 69 | continue 70 | 71 | # assuming loop has 4 edges as a cylinder face 72 | if len(all_edges) != 4: 73 | # combine neighboring curves 74 | i = 0 75 | while i < len(all_edges): 76 | next_edge_ind = (i+1) % len(all_edges) 77 | if not is_straight_line(all_edges[i]) and not is_straight_line(all_edges[next_edge_ind]): 78 | all_edges[i] += all_edges[next_edge_ind] 79 | all_edges.pop(next_edge_ind) 80 | all_edge_inds.pop(next_edge_ind) 81 | all_dirs.pop(next_edge_ind) 82 | continue 83 | i += 1 84 | if len(all_edges) != 4: 85 | print(f"{name} has {len(all_edges)} edges in a cylinder") 86 | continue 87 | 88 | # assume the face is a wireframe ordered with coedge directions 89 | # if straight line comes first, 90 | # then direction of new constructed line is opposite of the straight line 91 | if is_straight_line(all_edges[0]): 92 | line_ind = all_edge_inds[0] 93 | line = all_edges[0] 94 | line_dir = all_dirs[0] 95 | curve = all_edges[1] 96 | curve_ind = all_edge_inds[1] 97 | other_line_ind = all_edge_inds[2] 98 | other_line = all_edges[2] 99 | other_line_dir = all_dirs[2] 100 | other_curve_ind = all_edge_inds[3] 101 | else: 102 | curve = all_edges[0] 103 | curve_ind = all_edge_inds[0] 104 | other_line = all_edges[1] 105 | other_line_ind = all_edge_inds[1] 106 | other_line_dir = all_dirs[1] 107 | other_curve_ind = all_edge_inds[2] 108 | line = all_edges[3] 109 | line_ind = all_edge_inds[3] 110 | line_dir = all_dirs[3] 111 | 112 | # assume all straight lines of the cylinder have the same length 113 | # displace midpoint of one curve the same amount to generate the middle edge 114 | 115 | direction = np.array(line[0]) - np.array(line[1]) 116 | mid_point = np.array(curve[len(curve) // 2]) 117 | #
next_point = np.array(other_curve[len(other_curve) // 2]) 118 | 119 | next_point = mid_point + direction 120 | mid_point = mid_point.tolist() 121 | next_point = next_point.tolist() 122 | new_mid_edge = [mid_point, next_point] 123 | new_edges = [new_mid_edge, [line[0], next_point], [line[1], mid_point], [other_line[1], next_point], [other_line[0], mid_point]] 124 | ind_offset = len(to_add_new_edges) + num_edges 125 | to_add_new_edges += new_edges 126 | face_1 = (INTERMEDIATE_TYPE, [line_ind, 2+ind_offset, ind_offset, 1+ind_offset]) 127 | face_2 = (INTERMEDIATE_TYPE, [other_line_ind, 3+ind_offset, ind_offset, 4+ind_offset]) 128 | to_add_new_planes += [face_1, face_2] 129 | circle_face_to_construct.append([line_ind, other_line_ind, ind_offset, curve_ind, other_curve_ind]) 130 | circle_face_to_construct_dir.append([line_dir, other_line_dir, 1]) 131 | 132 | # find plane's normal, assuming the axis of the cylinder face is aligned with one of the dominant directions 133 | edge_direction = np.array(line[0]) - np.array(line[1]) 134 | normal_direction_ind = np.argmax([np.abs(np.dot(edge_direction, d)) for d in dom_directions]) 135 | 136 | # find other coedge's circle plane and add normal constraint 137 | for i, (face_type, indices) in enumerate(data['pred_faces']): 138 | if curve_ind in indices or other_curve_ind in indices: 139 | face_to_normal[tuple(indices)] = normal_direction_ind 140 | 141 | # Add in new faces 142 | data['pred_faces'] += to_add_new_planes 143 | data['edges'] += to_add_new_edges 144 | num_faces = len(data['pred_faces']) 145 | num_edges = len(data['edges']) 146 | 147 | 148 | 149 | removed_faces = [] 150 | # remove cylinder faces 151 | for i, ind in enumerate(face_removal_indices): 152 | removed_faces.append(data['pred_faces'].pop(ind-i)) 153 | 154 | P = [] 155 | b = [] 156 | C = [] 157 | 158 | 159 | # create equation for each perpendicular face's normal direction and dominant direction 160 | # check 2D parallelism to one of the dominant directions 161 | # edge parallel to one dominant direction => perpendicular to that direction 162 | 163 | # normalized dom_directions 164 | # # !! 
2d dominant directions need human input 165 | # aa = -np.sum(dom_directions[0]*dom_directions[1]) 166 | # bb = -np.sum(dom_directions[1]*dom_directions[2]) 167 | # cc = -np.sum(dom_directions[0]*dom_directions[2]) 168 | # z3 = np.sqrt(cc*bb / aa) 169 | # z2 = bb / z3 170 | # z1 = cc / z3 171 | 172 | # origin_directions = [dom_directions[0].tolist()+[z1], dom_directions[1].tolist()+[z2], dom_directions[2].tolist()+[z3]] 173 | # origin_directions = [np.array(d) / np.linalg.norm(d) for d in origin_directions] 174 | origin_directions = [np.array(d) / np.linalg.norm(d) for d in data['dominant_directions']] 175 | face_removal_indices = [] 176 | for face_ind, (face_type, indices) in enumerate(data['pred_faces']): 177 | parallel_count_for_dom_directions = [0] * 3 178 | for edge_ind in indices: 179 | edge = data['edges'][edge_ind] 180 | if not is_straight_line(edge): 181 | continue 182 | edge_direction = np.array(edge[0]) - np.array(edge[1]) 183 | edge_direction /= np.linalg.norm(edge_direction) 184 | # check if edge is parallel to one of the dominant directions 185 | 186 | for i, direction in enumerate(dom_directions): 187 | if np.abs(np.dot(edge_direction, direction)) > (1 - 1e-10): 188 | parallel_count_for_dom_directions[i] += 1 189 | 190 | 191 | # cylinder planes have predetermined normals from the outline 192 | if tuple(indices) in face_to_normal: 193 | normal_ind = face_to_normal[tuple(indices)] 194 | for i in range(3): 195 | if i != normal_ind: 196 | parallel_count_for_dom_directions[i] += 1 197 | 198 | if 0 not in parallel_count_for_dom_directions: 199 | # parallel to all dominant directions => wrong face prediction 200 | face_removal_indices.append(face_ind) 201 | continue 202 | 203 | # perpendicular to parallel directions 204 | for ind, count in enumerate(parallel_count_for_dom_directions): 205 | if count != 0: 206 | row = np.zeros(3 * num_faces) 207 | direction_3d = origin_directions[ind] 208 | # account for face removal 209 | face_ind -= len(face_removal_indices) 210 | row[3*face_ind: 3*face_ind+2] = [direction_3d[0], direction_3d[1]] 211 | brow = np.array([direction_3d[2]]) 212 | P.append(row) 213 | b.append(brow) 214 | for i, ind in enumerate(face_removal_indices): 215 | data['pred_faces'].pop(ind-i) 216 | 217 | # find all unique vertices 218 | all_vertices = [] 219 | all_used_edges = set(flatten_list([indices for _, indices in data['pred_faces']])) 220 | for ind in all_used_edges: 221 | all_vertices += data['edges'][ind] 222 | 223 | unique_vertices = [] 224 | tol = 1e-4 225 | for vertex in all_vertices: 226 | dists = np.array([dist(p1, vertex) for p1 in unique_vertices]) 227 | if np.sum(dists < tol) < 1: 228 | # new unique vertex 229 | unique_vertices.append(vertex) 230 | 231 | face_grouped_by_vertex = [[] for _ in range(len(unique_vertices))] 232 | # match faces to vertices 233 | for face_ind, (_, indices) in enumerate(data['pred_faces']): 234 | for edge_ind in indices: 235 | for point in data['edges'][edge_ind]: 236 | dists = np.array([dist(p1, point) for p1 in unique_vertices]) 237 | vertex_ind = np.argmin(dists) 238 | face_grouped_by_vertex[vertex_ind].append(face_ind) 239 | 240 | face_grouped_by_vertex = [list(set(group)) for group in face_grouped_by_vertex] 241 | for vertex, face_group in zip(unique_vertices, face_grouped_by_vertex): 242 | if len(face_group) < 2: 243 | continue 244 | # for each two face joined on 1 vertex, we create one equation for them 245 | for f1, f2 in itertools.combinations(face_group, 2): 246 | row = np.zeros(3 * num_faces) 247 | row[f1*3: f1*3 + 3] = 
[vertex[0], vertex[1], 1] 248 | row[f2*3: f2*3 + 3] = [-vertex[0], -vertex[1], -1] 249 | brow = np.array([0]) 250 | P.append(row) 251 | b.append(brow) 252 | # for each vertex and face, we create one constraint that z > 0 253 | for f in face_group: 254 | row = np.zeros(3 * num_faces) 255 | row[f*3: f*3 + 3] = [-vertex[0], -vertex[1], -1] 256 | C.append(row) 257 | 258 | P = np.array(P) 259 | C = np.array(C) 260 | b = np.array(b) 261 | 262 | n = P.shape[-1] 263 | if n == 0: 264 | return 265 | if C.shape[-1] == 0: 266 | return 267 | 268 | # sample points for 3d reconstruction 269 | pts = [] 270 | pts_label = [] 271 | sample_dist = 5e-3 272 | ind_to_3d_map = {} # contains the start of 3d samples for each edge index 273 | mid_edge_to_remove_start_ind = [] 274 | mid_edge_inds = [] 275 | for face_ind, (face_type, indices) in enumerate(data['pred_faces']): 276 | if face_type == INTERMEDIATE_TYPE: 277 | # only the first line (outline) and the third (mid edge) of an intermediate face need to be reconstructed 278 | sampled_pts = sample_points_on_line(data['edges'][indices[0]], sample_dist) 279 | pts.append(sampled_pts) 280 | ind_to_3d_map[indices[0]] = (len(pts_label), len(sampled_pts)) 281 | pts_label += [face_ind] * len(sampled_pts) 282 | # memorize where the outline's corresponding 3d points are 283 | sampled_pts = sample_points_on_line(data['edges'][indices[2]], sample_dist) 284 | pts.append(sampled_pts) 285 | ind_to_3d_map[indices[2]] = (len(pts_label), len(sampled_pts)) 286 | mid_edge_to_remove_start_ind.append(len(pts_label)) 287 | mid_edge_inds.append(indices[2]) 288 | pts_label += [face_ind] * len(sampled_pts) 289 | continue 290 | for edge_ind in indices: 291 | if is_straight_line(data['edges'][edge_ind]): 292 | sampled_pts = sample_points_on_line(data['edges'][edge_ind], sample_dist) 293 | pts.append(sampled_pts) 294 | ind_to_3d_map[edge_ind] = (len(pts_label), len(sampled_pts)) 295 | pts_label += [face_ind] * len(sampled_pts) 296 | 297 | if len(pts) == 0: 298 | return 299 | pts = np.vstack(pts) 300 | pts_label = np.array(pts_label) 301 | 302 | f = cp.Variable((n, 1)) 303 | try: 304 | objective = cp.Minimize(cp.norm1(P @ f + b)) 305 | constraints = [C @ f >= 0] 306 | prob = cp.Problem(objective, constraints) 307 | 308 | result = prob.solve() 309 | except: 310 | return 311 | params = f.value.reshape(-1, 3) 312 | 313 | N = len(pts) 314 | 315 | pts_one = np.hstack((pts, np.ones((N, 1)))) 316 | 317 | depth = np.sum(params[pts_label] * pts_one, axis=1, keepdims=True) 318 | 319 | xyz = np.hstack((pts, depth)) 320 | 321 | # reconstruct the circle planes 322 | for i in range(len(circle_face_to_construct)): 323 | line_ind, other_line_ind, mid_edge_ind, curve_ind, other_curve_ind = circle_face_to_construct[i] 324 | line_dir, other_line_dir, mid_edge_dir = circle_face_to_construct_dir[i] 325 | # the connection between the two outlines gives us the center of the circle 326 | # the direction of the outline gives us the normal of the plane 327 | start_ind, num_samples = ind_to_3d_map[line_ind] 328 | pts = xyz[start_ind:start_ind+num_samples] 329 | other_start_ind, num_samples = ind_to_3d_map[other_line_ind] 330 | other_pts = xyz[other_start_ind:other_start_ind+num_samples] 331 | mid_edge_start_ind, num_samples = ind_to_3d_map[mid_edge_ind] 332 | mid_edge_pts = xyz[mid_edge_start_ind:mid_edge_start_ind+num_samples] 333 | 334 | p1, p2, p3 = pts[::line_dir][0], other_pts[::other_line_dir][-1], mid_edge_pts[::mid_edge_dir][-1] 335 | curve_pts = fit_curve(p1, p2, p3) 336 | ind_to_3d_map[other_curve_ind] = (len(xyz), len(curve_pts)) 337
| xyz = np.vstack([xyz, curve_pts]) 338 | 339 | p1, p2, p3 = pts[::line_dir][-1], other_pts[::other_line_dir][0], mid_edge_pts[::mid_edge_dir][0] 340 | curve_pts = fit_curve(p1, p2, p3) 341 | ind_to_3d_map[curve_ind] = (len(xyz), len(curve_pts)) 342 | xyz = np.vstack([xyz, curve_pts]) 343 | 344 | 345 | # add back the removed cylinder faces 346 | data['pred_faces'] += removed_faces 347 | 348 | # iterate through the faces and add edges that are not mid-edges 349 | points = [] 350 | edges_drawn = set(mid_edge_inds) 351 | for face_type, indices in data['pred_faces']: 352 | if face_type == INTERMEDIATE_TYPE: 353 | continue 354 | for ind in indices: 355 | if ind in ind_to_3d_map and ind not in edges_drawn: 356 | start_ind, length = ind_to_3d_map[ind] 357 | points.append(xyz[start_ind:start_ind+length]) 358 | edges_drawn.add(ind) 359 | 360 | pcd = o3d.geometry.PointCloud() 361 | pts = np.vstack(points) 362 | pts[:, 1] = -pts[:, 1] 363 | pcd.points = o3d.utility.Vector3dVector(pts) 364 | 365 | o3d.io.write_point_cloud(os.path.join(root, 'ply', f"{name}.ply"), pcd) 366 | except: 367 | print(f"{name} failed") 368 | return 369 | 370 | if __name__ == '__main__': 371 | 372 | parser = argparse.ArgumentParser() 373 | parser.add_argument('--root', type=str, default="/root/data", 374 | help='dataset root.') 375 | parser.add_argument('--name', type=str, default=None, 376 | help='filename.') 377 | parser.add_argument('--num_cores', type=int, 378 | default=10, help='number of processors.') 379 | parser.add_argument('--num_chunks', type=int, 380 | default=5, help='number of chunk.') 381 | 382 | args = parser.parse_args() 383 | 384 | os.makedirs(os.path.join(args.root, 'ply'), exist_ok=True) 385 | 386 | if args.name is not None: 387 | reconstruct_file(args.name, args.root) 388 | else: 389 | all_names = [name[:8] for name in os.listdir(os.path.join(args.root, 'json'))] 390 | process_map(partial(reconstruct_file, root=args.root), all_names, 391 | max_workers=args.num_cores, chunksize=args.num_chunks) 392 | # for name in all_names: 393 | # reconstruct_file(name, args.root) 394 | -------------------------------------------------------------------------------- /dataset/prepare_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Data Prep for ABC step Files 3 | * use HLR to render wireframe 4 | * break surfaces with outlines 5 | * remove seam lines 6 | """ 7 | import argparse 8 | import json 9 | import os 10 | from functools import partial 11 | 12 | import numpy as np 13 | from OCC.Core.Bnd import Bnd_Box 14 | from OCC.Core.BRepBndLib import brepbndlib_Add 15 | from OCC.Core.BRepBuilderAPI import BRepBuilderAPI_Transform 16 | from OCC.Core.gp import gp_Pnt, gp_Trsf, gp_Vec 17 | from OCC.Extend.TopologyUtils import TopologyExplorer 18 | from tqdm.contrib.concurrent import process_map 19 | 20 | from dataset.utils.discretize_edge import (DiscretizedEdge, sort_faces_by_indices, 21 | sort_edges_by_coordinate) 22 | from dataset.utils.json_to_svg import save_png, save_svg, save_svg_groups 23 | from dataset.utils.read_step_file import read_step_file 24 | from dataset.utils.TopoMapper import TopoMapper 25 | 26 | from dataset.tests.check_faces_enclosed import is_face_enclosed 27 | from faceformer.utils import flatten_list 28 | from dataset.utils.projection_utils import generate_random_camera_pos 29 | 30 | def get_boundingbox(shapes, tol=1e-6): 31 | """ return the bounding box of the TopoDS_Shape `shape` 32 | Parameters 33 | ---------- 34 | shape : TopoDS_Shape or a subclass 
such as TopoDS_Face 35 | the shape to compute the bounding box from 36 | tol: float 37 | tolerance of the computed boundingbox 38 | """ 39 | bbox = Bnd_Box() 40 | bbox.SetGap(tol) 41 | for shape in shapes: 42 | brepbndlib_Add(shape, bbox, False) 43 | xmin, ymin, zmin, xmax, ymax, zmax = bbox.Get() 44 | center = (xmax + xmin) / 2, (ymin + ymax) / 2, (zmin + zmax) / 2 45 | extent = abs(xmax-xmin), abs(ymax-ymin), abs(zmax-zmin) 46 | return center, extent 47 | 48 | 49 | def shape_to_svg(shape, name, args): 50 | """ export a single shape to an svg file and json. 51 | shape: the TopoDS_Shape to export 52 | """ 53 | if shape.IsNull(): 54 | raise AssertionError("shape is Null") 55 | 56 | topo = TopoMapper(shape, args) 57 | all_dedges = [] 58 | faces_pointers = [] 59 | face_types = [] 60 | # face_parameters = [] 61 | all_shrinked_dedges = [] 62 | shape_center, _ = get_boundingbox([shape]) 63 | # ind = 2 64 | # replacements = {2:[ind, (3, (ind-2)%3)], 3:[(ind-2)%3, (2, ind)]} 65 | 66 | for index, face in enumerate(topo.all_faces.values()): 67 | discretized_edges = face.get_oriented_dedges() 68 | discretized_edges_3d = face.get_oriented_dedges(is_3d=True) 69 | 70 | # generate smaller edges for visualization of faces 71 | face_edges_lists = [edge.edges for edge in face.edges] 72 | face_edges = flatten_list(face_edges_lists) 73 | center, _ = get_boundingbox(face_edges) 74 | translation = gp_Trsf() 75 | push_vec = np.array([center[0] - shape_center[0], center[1] - shape_center[1], center[2] - shape_center[2]]) * 1.04 76 | # push_vec += push_vec / np.sqrt(np.sum(push_vec**2)) * 0.1 # push a fixed amount 77 | # push the edges out along the line from the center to the face 78 | translation.SetTranslation(gp_Vec(*push_vec)) 79 | 80 | # scale = gp_Trsf() 81 | # scale.SetScale(gp_Pnt(*center), 0.7) 82 | shrinked_dedges = [] 83 | for i, edge_list in enumerate(face_edges_lists): 84 | # if index in replacements and i == replacements[index][0]: 85 | # other_face_index = replacements[index][1][0] 86 | # other_face_edges_lists = [edge.edges for edge in list(topo.all_faces.values())[other_face_index].edges] 87 | # edge_list = other_face_edges_lists[replacements[index][1][1]] 88 | shrinked_edges = [] 89 | for edge in edge_list: 90 | brep_trans = BRepBuilderAPI_Transform(edge, translation) 91 | edge = brep_trans.Shape() 92 | shrinked_edges.append(edge) 93 | shrinked_dedges.append(topo._raw_project(shrinked_edges, args.discretize_last)) 94 | 95 | all_shrinked_dedges.append(shrinked_dedges) 96 | filename = os.path.join(args.root, 'face_svg', f'{name}_{index}.svg') 97 | save_svg(discretized_edges, filename, args) 98 | save_png(f'{name}_{index}', args, prefix='face_') 99 | 100 | 101 | filename = os.path.join(args.root, 'face_shrinked_face_svg', f'{name}_{index}.svg') 102 | save_svg(shrinked_dedges, filename, args) 103 | save_png(f'{name}_{index}', args, prefix='face_shrinked_face_') 104 | 105 | # generate data for face ground truth 106 | if args.combine_coedge: 107 | # combine coedges => all coedges share the same direction 108 | for edge in face.edges: 109 | if edge.DiscretizedEdge is None: 110 | edge.DiscretizedEdge = DiscretizedEdge(edge.dedge) 111 | all_dedges.append(edge.DiscretizedEdge) 112 | face_pointers = [edge.DiscretizedEdge for edge in face.edges] 113 | faces_pointers.append(face_pointers) 114 | else: 115 | assert len(discretized_edges) == len(shrinked_dedges) 116 | assert len(discretized_edges) == len(discretized_edges_3d) 117 | # each edge is represented as two discretized edges in two directions 118 | # 
save 3d points 119 | face_pointers = [DiscretizedEdge(dedge, smaller_edge=shrinked_dedge, edge3d=dedge_3d) \ 120 | for dedge, shrinked_dedge, dedge_3d in zip(discretized_edges, shrinked_dedges, discretized_edges_3d)] 121 | all_dedges += face_pointers 122 | # get face pointers 123 | faces_pointers.append(face_pointers) 124 | 125 | face_types.append(face.face_type) 126 | # face_parameters.append(face.parameters) 127 | 128 | all_dedges = sort_edges_by_coordinate(all_dedges) 129 | # assign index to each dedge 130 | for index, dedge in enumerate(all_dedges): 131 | dedge.index = index 132 | 133 | faces_indices = [] 134 | for face_pointers in faces_pointers: 135 | if args.order_by_position: 136 | faces_indices.append(sorted([dedge.index for dedge in face_pointers])) 137 | else: 138 | faces_indices.append([dedge.index for dedge in face_pointers]) 139 | 140 | save_svg([edge.dedge for edge in topo.all_edges.values()], os.path.join( 141 | args.root, 'svg', f'{name}.svg'), args) 142 | save_png(name, args) 143 | save_svg_groups(all_shrinked_dedges, os.path.join( 144 | args.root, 'face_shrinked_svg', f'{name}.svg'), args) 145 | save_png(name, args, prefix='face_shrinked_') 146 | 147 | if args.combine_coedge: 148 | faces_indices = [np.roll(face, -np.argmin(face), axis=0).tolist() for face in faces_indices] 149 | faces_indices = sort_faces_by_indices(faces_indices) 150 | else: 151 | # check enclosedness here, raise error if not enclosed 152 | # group indices 153 | all_edge_points = [dedge.points for dedge in all_dedges] 154 | sorted_faces_indices = [] 155 | for i, face in enumerate(faces_indices): 156 | all_face_loops = is_face_enclosed(all_edge_points, face, args.tol * 2) 157 | if not all_face_loops: 158 | raise Exception("faces unenclosed") 159 | # roll enclosed loops so smallest index is at the front 160 | all_face_loops = [np.roll(loop, -np.argmin(loop), axis=0).tolist() for loop in all_face_loops] 161 | # loops are ordered by first index 162 | all_face_loops = sorted(all_face_loops, key=lambda x: x[0]) 163 | # sorted_faces_indices.append([face_types[i], all_face_loops, face_parameters[i]]) 164 | if args.no_face_type: 165 | sorted_faces_indices.append(all_face_loops) 166 | else: 167 | sorted_faces_indices.append([face_types[i], all_face_loops]) 168 | 169 | # each face: 170 | # [ 171 | # type, 172 | # [loops], 173 | # [parameters] 174 | # ] 175 | # order faces by first index 176 | if args.no_face_type: 177 | faces_indices = sorted(sorted_faces_indices, key=lambda x: x[0][0]) 178 | else: 179 | faces_indices = sorted(sorted_faces_indices, key=lambda x: x[1][0][0]) 180 | edges_to_json(all_dedges, faces_indices, name, topo.get_dominant_directions()) 181 | 182 | 183 | def shape_to_svg_direction_token(shape, name, args): 184 | """ export a single shape to an svg file and json. 185 | ! 
def edges_to_json(all_dedges, faces_indices, name, dominant_directions, args):
    # write the discretized edges and face ground truth to json
    json_filename = os.path.join(args.root, 'json', f'{name}.json')
    data = {}
    data['edges'] = [dedge.points for dedge in all_dedges]
    data['edges3d'] = [dedge.edge3d for dedge in all_dedges]
    data['shrinked_edges'] = [dedge.smaller_edge for dedge in all_dedges]
    data['faces_indices'] = faces_indices
    data['dominant_directions'] = dominant_directions
    data['pairings'] = {}
    # pair up coedges: edge j is the point-wise reverse of edge i
    for i in range(len(data['edges'])):
        for j in range(i + 1, len(data['edges'])):
            if data['edges'][i] == data['edges'][j][::-1]:
                data['pairings'][i] = j
    with open(json_filename, 'w') as f:
        json.dump(data, f)


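# A sketch of the JSON layout written by edges_to_json (field values are
# illustrative; 'edges3d' and 'shrinked_edges' entries may be null for edges
# built by the combine-coedge paths, which construct DiscretizedEdge without
# those attributes):
#
#     {
#       "edges":    [[[x0, y0], [x1, y1], ...], ...],   # 2D polyline per edge
#       "edges3d":  [...],                              # matching 3D points
#       "shrinked_edges": [...],                        # pushed-out copies for visualization
#       "faces_indices": [[face_type, [[0, 4, 7], ...]], ...],
#       "dominant_directions": [...],
#       "pairings": {"0": 5, ...}                       # edge j is edge i reversed
#     }
#
# Note that json.dump stringifies the integer keys of 'pairings'.

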
def render_shape_and_faces(name, args):
    try:
        # uncomment to skip models that already have a json:
        # if os.path.exists(os.path.join(args.root, 'json', f'{name}.json')):
        #     return
        step_path = os.path.join(args.root, 'step', f'{name}.step')
        # the STEP read times out after 5 seconds
        try:
            shape, num_shapes = read_step_file(step_path, verbosity=False)
        except Exception:
            print(f"{name} took too long to read")
            return

        if shape is None:
            print(f"{name} is a NULL shape")
            return

        if num_shapes > args.filter_num_shapes:
            print(f"{name} has {num_shapes} shapes. Too many!")
            return

        topology_explorer = TopologyExplorer(shape)

        if len(list(topology_explorer.edges())) > args.filter_num_edges:
            print(f"{name} has too many edges.")
            return

        # center the shape and normalize its extent
        center, extent = get_boundingbox([shape])
        trans, scale = gp_Trsf(), gp_Trsf()
        trans.SetTranslation(-gp_Vec(*center))
        scale.SetScale(gp_Pnt(0, 0, 0), 2 / np.linalg.norm(extent))
        brep_trans = BRepBuilderAPI_Transform(shape, scale * trans)
        shape = brep_trans.Shape()

        args.pose = None
        # generate a random camera position
        if args.random_camera:
            # up to 5 attempts at a random viewing angle
            for _ in range(5):
                try:
                    focus, cam_pose = generate_random_camera_pos(args.seed)
                    args.pose = cam_pose
                    # focus == 0 keeps an orthographic projection;
                    # otherwise adopt the sampled focal distance
                    if args.focus != 0:
                        args.focus = focus
                    if args.direction_token:
                        shape_to_svg_direction_token(shape, name, args)
                    else:
                        shape_to_svg(shape, name, args)
                    return
                except Exception:
                    continue

        # no random camera, or every random attempt failed:
        # render with the current pose
        if args.direction_token:
            shape_to_svg_direction_token(shape, name, args)
        else:
            shape_to_svg(shape, name, args)

    except Exception as e:
        print(f"{name} received unknown error", e)


def prepare_splits(args):
    if os.path.exists(args.id_list):
        with open(args.id_list, 'r') as f:
            names = json.load(f)
    else:
        names = []
        for name in sorted(os.listdir(os.path.join(args.root, 'json'))):
            # filenames are '<8-char model id>.json'; keep only the id
            names.append(name[:8])

    np.random.seed(args.seed)
    np.random.shuffle(names)
    train_ratio, valid_ratio, _ = args.split
    trainlist, validlist, testlist = np.split(names, [int(
        len(names) * train_ratio), int(len(names) * (train_ratio + valid_ratio))])

    np.savetxt(os.path.join(args.root, 'train.txt'), trainlist, fmt="json/%s.json")
    np.savetxt(os.path.join(args.root, 'valid.txt'), validlist, fmt="json/%s.json")
    np.savetxt(os.path.join(args.root, 'test.txt'), testlist, fmt="json/%s.json")


def main(args):
    np.random.seed(args.seed)
    # create the output directories
    for subdir in ('svg', 'png', 'face_shrinked_face_svg', 'face_shrinked_face_png',
                   'face_shrinked_svg', 'face_shrinked_png', 'face_svg', 'face_png',
                   'json'):
        os.makedirs(os.path.join(args.root, subdir), exist_ok=True)

    if os.path.exists(args.id_list):
        with open(args.id_list, 'r') as f:
            names = json.load(f)
    else:
        names = []
        for name in sorted(os.listdir(os.path.join(args.root, 'step'))):
            names.append(os.path.splitext(name)[0])

    if not args.only_split:
        process_map(
            partial(render_shape_and_faces, args=args), names,
            max_workers=args.num_cores, chunksize=args.num_chunks
        )

    prepare_splits(args)


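# Example invocations (a sketch; the flags are defined in the argparse block
# below, and '00000042' is a hypothetical model id):
#
#     python dataset/prepare_data.py --root ./data --random_camera --focus 0
#     python dataset/prepare_data.py --root ./data --combine_coedge --order_by_position
#     python dataset/prepare_data.py --root ./data --name 00000042
#
# --focus 0 keeps an orthographic projection; --name renders a single STEP
# file instead of the whole dataset; --only_split regenerates the
# train/valid/test txt files without re-rendering.

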
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, default="./data",
                        help='dataset root.')
    parser.add_argument('--id_list', type=str, default="None",
                        help='path to a JSON list of (similarity-filtered) data ids.')
    parser.add_argument('--name', type=str, default=None,
                        help='render a single model by name instead of the whole dataset.')
    parser.add_argument('--num_cores', type=int,
                        default=5, help='number of worker processes.')
    parser.add_argument('--num_chunks', type=int,
                        default=10, help='chunk size passed to process_map.')
    parser.add_argument('--width', type=int,
                        default=256, help='svg width.')
    parser.add_argument('--height', type=int,
                        default=256, help='svg height.')
    parser.add_argument('--png_padding', type=float,
                        default=0.2, help='padding from content to the edge of the png.')
    parser.add_argument('--tol', type=float,
                        default=1e-4, help='svg discretization tolerance.')
    parser.add_argument('--face_shrink_scale', type=float,
                        default=0.8, help='shrink factor for face visualization.')
    parser.add_argument('--line_width', type=str,
                        default=str(6/256), help='svg line width.')
    parser.add_argument('--filter_num_shapes', type=int,
                        default=1, help='skip step files with more than \
                        this number of shapes.')
    parser.add_argument('--filter_num_edges', type=int,
                        default=64, help='skip step files with more than \
                        this number of edges.')
    parser.add_argument('--location', nargs="+", type=float,
                        default=[1, 1, 1], help='projection location.')
    parser.add_argument('--direction', nargs="+", type=float,
                        default=[1, 1, 1], help='projection direction.')
    parser.add_argument('--focus', type=float,
                        default=3, help='focal distance of the projection camera \
                        (0 = orthographic).')
    parser.add_argument('--split', nargs="+", type=float,
                        default=[0.93, 0.02, 0.05],
                        help='train/valid/test split ratios.')
    parser.add_argument('--only_split', action='store_true')
    parser.add_argument('--combine_coedge', action='store_true')
    parser.add_argument('--order_by_position', action='store_true')
    parser.add_argument('--direction_token', action='store_true')
    parser.add_argument('--random_camera', action='store_true')
    parser.add_argument('--discretize_last', action='store_true')
    parser.add_argument('--no_face_type', action='store_true')
    parser.add_argument('--seed', type=int, default=42,
                        help='numpy random seed.')

    args = parser.parse_args()

    if args.name is None:
        main(args)
    else:
        render_shape_and_faces(args.name, args)
--------------------------------------------------------------------------------