├── docs └── img │ ├── Teaser.png │ ├── UVGrid.png │ └── MessagePassing.png ├── environment.yml ├── process ├── README.md ├── solid_to_pointcloud.py ├── solid_to_rendermesh.py ├── visualize.py ├── solid_to_graph.py └── visualize_uvgrid_graph.py ├── LICENSE ├── .gitignore ├── datasets ├── mfcad.py ├── base.py ├── solidletters.py ├── fusiongallery.py └── util.py ├── classification.py ├── segmentation.py ├── README.md └── uvnet ├── encoders.py └── models.py /docs/img/Teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutodeskAILab/UV-Net/HEAD/docs/img/Teaser.png -------------------------------------------------------------------------------- /docs/img/UVGrid.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutodeskAILab/UV-Net/HEAD/docs/img/UVGrid.png -------------------------------------------------------------------------------- /docs/img/MessagePassing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AutodeskAILab/UV-Net/HEAD/docs/img/MessagePassing.png -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: uv_net 2 | channels: 3 | - pytorch 4 | - lambouj 5 | - dglteam 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - python=3.9 10 | - occwl=1.0 11 | - pytorch=1.8 12 | - pytorch-lightning=1.3 13 | - torchmetrics=0.3 14 | - dgl-cuda11.0=0.6.1=py39_0 15 | - joblib=1.0.1=pyhd3eb1b0_0 16 | - matplotlib=3.4.2 17 | - matplotlib-base=3.4.2 18 | - scikit-learn 19 | - torchmetrics=0.3.2 20 | - tqdm=4.59.0 21 | - trimesh=3.9.18 -------------------------------------------------------------------------------- /process/README.md: -------------------------------------------------------------------------------- 1 | # Processing your own data 2 | 3 | We provide scripts to process your own STEP file data into the DGL bin format that UV-Net consumes, point clouds in NPZ format and render meshes (non-watertight meshes) in STL format. 4 | 5 | Example usage: 6 | 7 | ``` 8 | cd /path/to/uv_net 9 | python -m process.solid_to_graph /path/to/input/step_files /path/to/output/bin_graphs 10 | ``` 11 | 12 | Other scripts can be run similarly. For more details, run the script with the `--help` argument. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Autodesk, Inc 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | results/* 132 | .vscode/ 133 | run_scripts/ 134 | segmentation_checkpoints/ 135 | segmentation_logs/ 136 | segmentation_pretrained/ -------------------------------------------------------------------------------- /datasets/mfcad.py: -------------------------------------------------------------------------------- 1 | from datasets.base import BaseDataset 2 | import pathlib 3 | import torch 4 | import json 5 | 6 | 7 | class MFCADDataset(BaseDataset): 8 | @staticmethod 9 | def num_classes(): 10 | return 16 11 | 12 | def __init__( 13 | self, root_dir, split="train", center_and_scale=True, random_rotate=False, 14 | ): 15 | """ 16 | Load the MFCAD dataset from: 17 | Weijuan Cao, Trevor Robinson, Yang Hua, Flavien Boussuge, 18 | Andrew R. Colligan, and Wanbin Pan. "Graph representation 19 | of 3d cad models for machining feature recognition with deep 20 | learning." In Proceedings of the ASME 2020 International 21 | Design Engineering Technical Conferences and Computers 22 | and Information in Engineering Conference, IDETC-CIE. 23 | ASME, 2020. 24 | 25 | Args: 26 | root_dir (str): Root path of dataset 27 | split (str, optional): Data split to load. Defaults to "train". 28 | center_and_scale (bool, optional): Whether to center and scale the solid. Defaults to True. 29 | random_rotate (bool, optional): Whether to apply random rotations to the solid in 90 degree increments. Defaults to False. 
30 | """ 31 | path = pathlib.Path(root_dir) 32 | self.path = path 33 | assert split in ("train", "val", "test") 34 | 35 | with open(str(str(path.joinpath("split.json"))), "r") as read_file: 36 | filelist = json.load(read_file) 37 | 38 | if split == "train": 39 | split_filelist = filelist["train"] 40 | elif split == "val": 41 | split_filelist = filelist["validation"] 42 | else: 43 | split_filelist = filelist["test"] 44 | 45 | self.random_rotate = random_rotate 46 | 47 | all_files = [] 48 | for fn in split_filelist: 49 | all_files.append(path.joinpath("graph").joinpath(fn + ".bin")) 50 | 51 | # Load graphs 52 | print(f"Loading {split} data...") 53 | self.load_graphs(all_files, center_and_scale) 54 | print("Done loading {} files".format(len(self.data))) 55 | 56 | def load_one_graph(self, file_path): 57 | # Load the graph using base class method 58 | sample = super().load_one_graph(file_path) 59 | # Additionally load the label and store it as node data 60 | label_file = self.path.joinpath("labels").joinpath(file_path.stem + "_ids.json") 61 | with open(str(label_file), "r") as read_file: 62 | labels_data = json.load(read_file) 63 | label = [] 64 | for face in labels_data["body"]["faces"]: 65 | index = face["segment"]["index"] 66 | label.append(index) 67 | sample["graph"].ndata["y"] = torch.tensor(label).long() 68 | return sample 69 | -------------------------------------------------------------------------------- /datasets/base.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset, DataLoader 2 | from torch import FloatTensor 3 | import dgl 4 | from dgl.data.utils import load_graphs 5 | from datasets import util 6 | from tqdm import tqdm 7 | from abc import abstractmethod 8 | 9 | 10 | class BaseDataset(Dataset): 11 | @staticmethod 12 | @abstractmethod 13 | def num_classes(): 14 | pass 15 | 16 | def load_graphs(self, file_paths, center_and_scale=True): 17 | self.data = [] 18 | for fn in tqdm(file_paths): 19 | if not fn.exists(): 20 | continue 21 | sample = self.load_one_graph(fn) 22 | if sample is None: 23 | continue 24 | if sample["graph"].edata["x"].size(0) == 0: 25 | # Catch the case of graphs with no edges 26 | continue 27 | self.data.append(sample) 28 | if center_and_scale: 29 | self.center_and_scale() 30 | self.convert_to_float32() 31 | 32 | def load_one_graph(self, file_path): 33 | graph = load_graphs(str(file_path))[0][0] 34 | sample = {"graph": graph, "filename": file_path.stem} 35 | return sample 36 | 37 | def center_and_scale(self): 38 | for i in range(len(self.data)): 39 | self.data[i]["graph"].ndata["x"], center, scale = util.center_and_scale_uvgrid( 40 | self.data[i]["graph"].ndata["x"], return_center_scale=True 41 | ) 42 | self.data[i]["graph"].edata["x"][..., :3] -= center 43 | self.data[i]["graph"].edata["x"][..., :3] *= scale 44 | 45 | def convert_to_float32(self): 46 | for i in range(len(self.data)): 47 | self.data[i]["graph"].ndata["x"] = self.data[i]["graph"].ndata["x"].type(FloatTensor) 48 | self.data[i]["graph"].edata["x"] = self.data[i]["graph"].edata["x"].type(FloatTensor) 49 | 50 | def __len__(self): 51 | return len(self.data) 52 | 53 | def __getitem__(self, idx): 54 | sample = self.data[idx] 55 | if self.random_rotate: 56 | rotation = util.get_random_rotation() 57 | sample["graph"].ndata["x"] = util.rotate_uvgrid(sample["graph"].ndata["x"], rotation) 58 | sample["graph"].edata["x"] = util.rotate_uvgrid(sample["graph"].edata["x"], rotation) 59 | return sample 60 | 61 | def _collate(self, batch): 62 
| batched_graph = dgl.batch([sample["graph"] for sample in batch]) 63 | batched_filenames = [sample["filename"] for sample in batch] 64 | return {"graph": batched_graph, "filename": batched_filenames} 65 | 66 | def get_dataloader(self, batch_size=128, shuffle=True, num_workers=0): 67 | return DataLoader( 68 | self, 69 | batch_size=batch_size, 70 | shuffle=shuffle, 71 | collate_fn=self._collate, 72 | num_workers=num_workers, # Can be set to non-zero on Linux 73 | drop_last=True, 74 | ) 75 | -------------------------------------------------------------------------------- /datasets/solidletters.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import string 3 | 4 | import torch 5 | from sklearn.model_selection import train_test_split 6 | 7 | from datasets.base import BaseDataset 8 | 9 | 10 | def _get_filenames(root_dir, filelist): 11 | with open(str(root_dir / f"{filelist}"), "r") as f: 12 | file_list = [x.strip() for x in f.readlines()] 13 | 14 | files = list( 15 | x 16 | for x in root_dir.rglob(f"*.bin") 17 | if x.stem in file_list 18 | #if util.valid_font(x) and x.stem in file_list 19 | ) 20 | return files 21 | 22 | 23 | CHAR2LABEL = {char: i for (i, char) in enumerate(string.ascii_lowercase)} 24 | 25 | 26 | def _char_to_label(char): 27 | return CHAR2LABEL[char.lower()] 28 | 29 | 30 | class SolidLetters(BaseDataset): 31 | @staticmethod 32 | def num_classes(): 33 | return 26 34 | 35 | def __init__( 36 | self, 37 | root_dir, 38 | split="train", 39 | center_and_scale=True, 40 | random_rotate=False, 41 | ): 42 | """ 43 | Load the SolidLetters dataset 44 | 45 | Args: 46 | root_dir (str): Root path to the dataset 47 | split (str, optional): Split (train, val, or test) to load. Defaults to "train". 48 | center_and_scale (bool, optional): Whether to center and scale the solid. Defaults to True. 49 | random_rotate (bool, optional): Whether to apply random rotations to the solid in 90 degree increments. Defaults to False. 
50 | """ 51 | assert split in ("train", "val", "test") 52 | path = pathlib.Path(root_dir) 53 | 54 | self.random_rotate = random_rotate 55 | 56 | if split in ("train", "val"): 57 | file_paths = _get_filenames(path, filelist="train.txt") 58 | # The first character of filename must be the alphabet 59 | labels = [_char_to_label(fn.stem[0]) for fn in file_paths] 60 | train_files, val_files = train_test_split( 61 | file_paths, test_size=0.2, random_state=42, stratify=labels, 62 | ) 63 | if split == "train": 64 | file_paths = train_files 65 | elif split == "val": 66 | file_paths = val_files 67 | elif split == "test": 68 | file_paths = _get_filenames(path, filelist="test.txt") 69 | 70 | print(f"Loading {split} data...") 71 | self.load_graphs(file_paths, center_and_scale) 72 | print("Done loading {} files".format(len(self.data))) 73 | 74 | def load_one_graph(self, file_path): 75 | # Load the graph using base class method 76 | sample = super().load_one_graph(file_path) 77 | # Additionally get the label from the filename and store it in the sample dict 78 | sample["label"] = torch.tensor([_char_to_label(file_path.stem[0])]).long() 79 | return sample 80 | 81 | def _collate(self, batch): 82 | collated = super()._collate(batch) 83 | collated["label"] = torch.cat([x["label"] for x in batch], dim=0) 84 | return collated 85 | -------------------------------------------------------------------------------- /datasets/fusiongallery.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from datasets.base import BaseDataset 3 | import pathlib 4 | import torch 5 | import json 6 | from sklearn.model_selection import train_test_split 7 | 8 | 9 | class FusionGalleryDataset(BaseDataset): 10 | @staticmethod 11 | def num_classes(): 12 | return 8 13 | 14 | def __init__( 15 | self, root_dir, split="train", center_and_scale=True, random_rotate=False, 16 | ): 17 | """ 18 | Load the Fusion Gallery dataset from: 19 | Joseph G. Lambourne, Karl D. D. Willis, Pradeep Kumar Jayaraman, Aditya Sanghi, 20 | Peter Meltzer, Hooman Shayani. "BRepNet: A topological message passing system 21 | for solid models," CVPR 2021. 22 | 23 | Args: 24 | root_dir (str): Root path of dataset 25 | split (str, optional): Data split to load. Defaults to "train". 26 | center_and_scale (bool, optional): Whether to center and scale the solid. Defaults to True. 27 | random_rotate (bool, optional): Whether to apply random rotations to the solid in 90 degree increments. Defaults to False. 28 | """ 29 | path = pathlib.Path(root_dir) 30 | self.path = path 31 | 32 | # Locate the labels directory. In s1.0.0 this would be self.path / "breps" 33 | # but in s2.0.0 this is self.path / "breps/seg" 34 | self.seg_path = self.path / "breps/seg" 35 | if not self.seg_path.exists(): 36 | self.seg_path = self.path / "breps" 37 | 38 | assert split in ("train", "val", "test") 39 | 40 | with open(str(path.joinpath("train_test.json")), "r") as read_file: 41 | filelist = json.load(read_file) 42 | 43 | # NOTE: Using a held out validation set may be better. 44 | # But it's not easy to perform stratified sampling on some rare classes 45 | # which only show up on a few solids. 
46 | if split in ("train", "val"): 47 | full_train_filelist = filelist["train"] 48 | train_filesplit, val_filesplit = train_test_split( 49 | full_train_filelist, test_size=0.2, random_state=42 50 | ) 51 | if split == "train": 52 | split_filelist = train_filesplit 53 | else: 54 | split_filelist = val_filesplit 55 | else: 56 | split_filelist = filelist["test"] 57 | 58 | self.random_rotate = random_rotate 59 | 60 | # Call base class method to load all graphs 61 | print(f"Loading {split} data...") 62 | all_files = [path.joinpath("graph").joinpath(fn + ".bin") for fn in split_filelist] 63 | self.load_graphs(all_files, center_and_scale) 64 | print("Done loading {} files".format(len(self.data))) 65 | 66 | def load_one_graph(self, file_path): 67 | # Load the graph using base class method 68 | sample = super().load_one_graph(file_path) 69 | # Additionally load the label and store it as node data 70 | label = np.loadtxt( 71 | self.seg_path.joinpath(file_path.stem + ".seg"), dtype=np.int, ndmin=1 72 | ) 73 | if sample["graph"].number_of_nodes() != label.shape[0]: 74 | return None 75 | sample["graph"].ndata["y"] = torch.tensor(label).long() 76 | return sample 77 | -------------------------------------------------------------------------------- /process/solid_to_pointcloud.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from itertools import repeat 3 | from multiprocessing.pool import Pool 4 | import pathlib 5 | import signal 6 | 7 | import numpy as np 8 | import trimesh 9 | from occwl.compound import Compound 10 | from tqdm import tqdm 11 | 12 | from process.solid_to_rendermesh import triangulate_with_face_mapping 13 | 14 | 15 | def process_one_file(arguments): 16 | fn, args = arguments 17 | if fn.stat().st_size == 0: 18 | return None 19 | fn_stem = fn.stem 20 | output_path = pathlib.Path(args.output) 21 | if not output_path.exists(): 22 | output_path.mkdir(parents=True, exist_ok=True) 23 | try: 24 | solid = Compound.load_from_step(fn) 25 | except Exception as e: 26 | print(e) 27 | return 28 | 29 | verts, tris, tri_mapping = triangulate_with_face_mapping(solid) 30 | 31 | mesh = trimesh.Trimesh(vertices=verts, faces=tris) 32 | points, face_indices = trimesh.sample.sample_surface(mesh, args.num_points) 33 | points_to_face_mapping = tri_mapping[face_indices] 34 | 35 | # import matplotlib.pyplot as plt 36 | # from matplotlib.colors import Normalize 37 | # from matplotlib.cm import tab20 38 | # from mpl_toolkits.mplot3d import Axes3D 39 | 40 | # fig = plt.figure() 41 | # ax = fig.gca(projection="3d") 42 | # colors = tab20(points_to_face_mapping) 43 | # norm = Normalize( 44 | # vmin=np.amin(points_to_face_mapping), vmax=np.amax(points_to_face_mapping) 45 | # ) 46 | # ax.scatter( 47 | # points[:, 0], points[:, 1], points[:, 2], c=colors, norm=norm, 48 | # ) 49 | # plt.show() 50 | 51 | # Write to numpy compressed archive 52 | np.savez( 53 | str(output_path.joinpath(fn_stem + ".npz")), 54 | points=points, 55 | point_mapping=points_to_face_mapping, 56 | ) 57 | 58 | 59 | def initializer(): 60 | """Ignore CTRL+C in the worker process.""" 61 | signal.signal(signal.SIGINT, signal.SIG_IGN) 62 | 63 | 64 | def process(args): 65 | input_path = pathlib.Path(args.input) 66 | output_path = pathlib.Path(args.output) 67 | if not output_path.exists(): 68 | output_path.mkdir(parents=True, exist_ok=True) 69 | step_files = list(input_path.glob("*.st*p")) 70 | # for fn in tqdm(step_files): 71 | # process_one_file(fn, args) 72 | pool = Pool(processes=args.num_processes, 
initializer=initializer) 73 | try: 74 | results = list(tqdm(pool.imap(process_one_file, zip(step_files, repeat(args))), total=len(step_files))) 75 | except KeyboardInterrupt: 76 | pool.terminate() 77 | pool.join() 78 | print(f"Processed {len(results)} files.") 79 | 80 | 81 | def main(): 82 | parser = argparse.ArgumentParser("Convert solid models to point clouds") 83 | parser.add_argument("input", type=str, help="Input folder of STEP files") 84 | parser.add_argument( 85 | "output", type=str, help="Output folder of NPZ point cloud files" 86 | ) 87 | parser.add_argument( 88 | "--num_points", 89 | type=int, 90 | default=2048, 91 | help="Number of points in the point cloud", 92 | ) 93 | parser.add_argument( 94 | "--num_processes", 95 | type=int, 96 | default=8, 97 | help="Number of processes to use", 98 | ) 99 | args = parser.parse_args() 100 | process(args) 101 | 102 | 103 | if __name__ == "__main__": 104 | main() 105 | -------------------------------------------------------------------------------- /process/solid_to_rendermesh.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pathlib 3 | 4 | import numpy as np 5 | from occwl.entity_mapper import EntityMapper 6 | from occwl.compound import Compound 7 | from tqdm import tqdm 8 | import trimesh 9 | 10 | 11 | def triangulate_with_face_mapping(solid, triangle_face_tol=0.01, angle_tol_rads=0.1): 12 | # Triangulate faces 13 | solid.triangulate_all_faces( 14 | triangle_face_tol=triangle_face_tol, angle_tol_rads=angle_tol_rads 15 | ) 16 | 17 | verts = [] 18 | tris = [] 19 | # Store B-rep face index along with triangles 20 | mapper = EntityMapper(solid) 21 | tri_mapping = [] 22 | vert_counter = 0 23 | for face in solid.faces(): 24 | face_index = mapper.face_index(face) 25 | face_verts, face_tris = face.get_triangles() 26 | if len(face_tris) == 0: 27 | continue 28 | face_tris += vert_counter 29 | vert_counter += face_verts.shape[0] 30 | face_mapping = np.ones(face_tris.shape[0]) * face_index 31 | verts.append(face_verts) 32 | tris.append(face_tris) 33 | tri_mapping.append(face_mapping) 34 | if len(verts) == 0: 35 | return None, None, None 36 | verts = np.concatenate(verts, axis=0).astype(np.float32) 37 | if len(tris) == 0: 38 | return None, None, None 39 | tris = np.concatenate(tris, axis=0).astype(np.int32) 40 | tri_mapping = np.concatenate(tri_mapping, axis=-1).astype(np.int32) 41 | return verts, tris, tri_mapping 42 | 43 | 44 | def process_one_file(fn, args): 45 | fn_stem = fn.stem 46 | output_path = pathlib.Path(args.output) 47 | output_filename = output_path.joinpath(fn_stem + ".stl") 48 | if output_filename.exists(): 49 | return 50 | solid = Compound.load_from_step(fn) 51 | 52 | verts, tris, tri_mapping = triangulate_with_face_mapping( 53 | solid, args.triangle_face_tol, args.angle_tol_rads 54 | ) 55 | 56 | # from mpl_toolkits.mplot3d import Axes3D 57 | # import matplotlib.pyplot as plt 58 | # fig = plt.figure() 59 | # ax = fig.gca(projection='3d') 60 | # ax.plot_trisurf(verts[:,0], verts[:,1], verts[:,2], triangles = tris, alpha=0.8) 61 | # plt.show() 62 | 63 | # Write to numpy compressed archive 64 | # np.savez( 65 | # str(output_path.joinpath(fn_stem + ".npz")), 66 | # vertices=verts, 67 | # triangles=tris, 68 | # triangle_mapping=tri_mapping, 69 | # ) 70 | trimesh.Trimesh(vertices=verts, faces=tris).export(str(output_filename)) 71 | 72 | 73 | def process(args): 74 | input_path = pathlib.Path(args.input) 75 | step_files = list(input_path.glob(f"{args.filename_pattern}.st*p")) 76 
| for fn in tqdm(step_files): 77 | process_one_file(fn, args) 78 | 79 | 80 | def main(): 81 | parser = argparse.ArgumentParser( 82 | "Convert solid models to render (non-watertight) meshes" 83 | ) 84 | parser.add_argument("input", type=str, help="Input folder of STEP files") 85 | parser.add_argument("output", type=str, help="Output folder of NPZ mesh files") 86 | parser.add_argument( 87 | "--triangle_face_tol", 88 | type=float, 89 | default=0.01, 90 | help="Tolerance between triangle and surface relative to each B-rep face", 91 | ) 92 | parser.add_argument( 93 | "--angle_tol_rads", 94 | type=float, 95 | default=0.1, 96 | help="Tolerance angle between normals/tangents at triangle vertices (in radians)", 97 | ) 98 | parser.add_argument( 99 | "--filename_pattern", 100 | type=str, 101 | default="*", 102 | help="Filename regex pattern to filter input files. Defaults to '*'", 103 | ) 104 | 105 | args = parser.parse_args() 106 | process(args) 107 | 108 | 109 | if __name__ == "__main__": 110 | main() 111 | -------------------------------------------------------------------------------- /process/visualize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from occwl.viewer import Viewer 4 | from occwl.io import load_step 5 | from occwl.edge import Edge 6 | from occwl.solid import Solid 7 | 8 | import torch 9 | import dgl 10 | from dgl.data.utils import load_graphs 11 | 12 | 13 | def draw_face_uvgrids(solid, graph, viewer): 14 | face_uvgrids = graph.ndata["x"].view(-1, 7) 15 | points = [] 16 | normals = [] 17 | for idx in range(face_uvgrids.shape[0]): 18 | # Don't draw points outside trimming loop 19 | if face_uvgrids[idx, -1] == 0: 20 | continue 21 | points.append(face_uvgrids[idx, :3].cpu().numpy()) 22 | normals.append(face_uvgrids[idx, 3:6].cpu().numpy()) 23 | 24 | points = np.asarray(points, dtype=np.float32) 25 | normals = np.asarray(normals, dtype=np.float32) 26 | 27 | bbox = solid.box() 28 | max_length = max(bbox.x_length(), bbox.y_length(), bbox.z_length()) 29 | 30 | # Draw the points 31 | viewer.display_points( 32 | points, color=(51.0 / 255.0, 0, 1), marker="point", scale=2*max_length 33 | ) 34 | 35 | # Draw the normals 36 | for pt, nor in zip(points, normals): 37 | viewer.display(Edge.make_line_from_points(pt, pt + nor * 0.05 * max_length), color=(51.0 / 255.0, 0, 1)) 38 | 39 | 40 | def draw_edge_uvgrids(solid, graph, viewer): 41 | edge_uvgrids = graph.edata["x"].view(-1, 6) 42 | points = [] 43 | tangents = [] 44 | for idx in range(edge_uvgrids.shape[0]): 45 | points.append(edge_uvgrids[idx, :3].cpu().numpy()) 46 | tangents.append(edge_uvgrids[idx, 3:6].cpu().numpy()) 47 | 48 | points = np.asarray(points, dtype=np.float32) 49 | tangents = np.asarray(tangents, dtype=np.float32) 50 | 51 | bbox = solid.box() 52 | max_length = max(bbox.x_length(), bbox.y_length(), bbox.z_length()) 53 | 54 | # Draw the points 55 | viewer.display_points(points, color=(1, 0, 1), marker="point", scale=2*max_length) 56 | 57 | # Draw the tangents 58 | for pt, tgt in zip(points, tangents): 59 | viewer.display(Edge.make_line_from_points(pt, pt + tgt * 0.1 * max_length), color=(1, 0, 1)) 60 | 61 | 62 | def draw_graph_edges(solid, graph, viewer): 63 | src, dst = graph.edges() 64 | num_u = graph.ndata["x"].shape[1] 65 | num_v = graph.ndata["x"].shape[2] 66 | bbox = solid.box() 67 | max_length = max(bbox.x_length(), bbox.y_length(), bbox.z_length()) 68 | 69 | for s, d in zip(src, dst): 70 | src_pt = graph.ndata["x"][s, num_u // 2, num_v // 
2, :3].cpu().numpy() 71 | dst_pt = graph.ndata["x"][d, num_u // 2, num_v // 2, :3].cpu().numpy() 72 | # Make a cylinder for each edge connecting a pair of faces 73 | up_dir = dst_pt - src_pt 74 | height = np.linalg.norm(up_dir) 75 | if height > 1e-3: 76 | v.display( 77 | Solid.make_cylinder( 78 | radius=0.01 * max_length, height=height, base_point=src_pt, up_dir=up_dir 79 | ), 80 | color="BLACK", 81 | ) 82 | 83 | 84 | if __name__ == "__main__": 85 | parser = argparse.ArgumentParser( 86 | "Visualize UV-grids and face adj graphs for testing" 87 | ) 88 | parser.add_argument("solid", type=str, help="Solid STEP file") 89 | parser.add_argument("graph", type=str, help="Graph BIN file") 90 | args = parser.parse_args() 91 | 92 | solid = load_step(args.solid)[0] 93 | graph = load_graphs(args.graph)[0][0] 94 | 95 | v = Viewer(backend="wx") 96 | # Draw the solid 97 | v.display(solid, transparency=0.5, color=(0.2, 0.2, 0.2)) 98 | # Draw the face UV-grids 99 | draw_face_uvgrids(solid, graph, viewer=v) 100 | # Draw the edge UV-grids 101 | draw_edge_uvgrids(solid, graph, viewer=v) 102 | # Draw face-adj graph edges 103 | draw_graph_edges(solid, graph, viewer=v) 104 | 105 | v.fit() 106 | v.show() 107 | -------------------------------------------------------------------------------- /datasets/util.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from scipy.spatial.transform import Rotation 6 | 7 | 8 | def bounding_box_uvgrid(inp: torch.Tensor): 9 | pts = inp[..., :3].reshape((-1, 3)) 10 | mask = inp[..., 6].reshape(-1) 11 | point_indices_inside_faces = mask == 1 12 | pts = pts[point_indices_inside_faces, :] 13 | return bounding_box_pointcloud(pts) 14 | 15 | 16 | def bounding_box_pointcloud(pts: torch.Tensor): 17 | x = pts[:, 0] 18 | y = pts[:, 1] 19 | z = pts[:, 2] 20 | box = [[x.min(), y.min(), z.min()], [x.max(), y.max(), z.max()]] 21 | return torch.tensor(box) 22 | 23 | 24 | def center_and_scale_uvgrid(inp: torch.Tensor, return_center_scale=False): 25 | bbox = bounding_box_uvgrid(inp) 26 | diag = bbox[1] - bbox[0] 27 | scale = 2.0 / max(diag[0], diag[1], diag[2]) 28 | center = 0.5 * (bbox[0] + bbox[1]) 29 | inp[..., :3] -= center 30 | inp[..., :3] *= scale 31 | if return_center_scale: 32 | return inp, center, scale 33 | return inp 34 | 35 | 36 | def get_random_rotation(): 37 | """Get a random rotation in 90 degree increments along the canonical axes""" 38 | axes = [ 39 | np.array([1, 0, 0]), 40 | np.array([0, 1, 0]), 41 | np.array([0, 0, 1]), 42 | ] 43 | angles = [0.0, 90.0, 180.0, 270.0] 44 | axis = random.choice(axes) 45 | angle_radians = np.radians(random.choice(angles)) 46 | return Rotation.from_rotvec(angle_radians * axis) 47 | 48 | 49 | def rotate_uvgrid(inp, rotation): 50 | """Rotate the node features in the graph by a given rotation""" 51 | Rmat = torch.tensor(rotation.as_matrix()).float() 52 | orig_size = inp[..., :3].size() 53 | inp[..., :3] = torch.mm(inp[..., :3].view(-1, 3), Rmat).view( 54 | orig_size 55 | ) # Points 56 | inp[..., 3:6] = torch.mm(inp[..., 3:6].view(-1, 3), Rmat).view( 57 | orig_size 58 | ) # Normals/tangents 59 | return inp 60 | 61 | 62 | INVALID_FONTS = [ 63 | "Bokor", 64 | "Lao Muang Khong", 65 | "Lao Sans Pro", 66 | "MS Outlook", 67 | "Catamaran Black", 68 | "Dubai", 69 | "HoloLens MDL2 Assets", 70 | "Lao Muang Don", 71 | "Oxanium Medium", 72 | "Rounded Mplus 1c", 73 | "Moul Pali", 74 | "Noto Sans Tamil", 75 | "Webdings", 76 | "Armata", 77 | "Koulen", 78 | "Yinmar", 79 | 
"Ponnala", 80 | "Noto Sans Tamil", 81 | "Chenla", 82 | "Lohit Devanagari", 83 | "Metal", 84 | "MS Office Symbol", 85 | "Cormorant Garamond Medium", 86 | "Chiller", 87 | "Give You Glory", 88 | "Hind Vadodara Light", 89 | "Libre Barcode 39 Extended", 90 | "Myanmar Sans Pro", 91 | "Scheherazade", 92 | "Segoe MDL2 Assets", 93 | "Siemreap", 94 | "Signika SemiBold" "Taprom", 95 | "Times New Roman TUR", 96 | "Playfair Display SC Black", 97 | "Poppins Thin", 98 | "Raleway Dots", 99 | "Raleway Thin", 100 | "Segoe MDL2 Assets", 101 | "Segoe MDL2 Assets", 102 | "Spectral SC ExtraLight", 103 | "Txt", 104 | "Uchen", 105 | "Yinmar", 106 | "Almarai ExtraBold", 107 | "Fasthand", 108 | "Exo", 109 | "Freckle Face", 110 | "Montserrat Light", 111 | "Inter", 112 | "MS Reference Specialty", 113 | "MS Outlook", 114 | "Preah Vihear", 115 | "Sitara", 116 | "Barkerville Old Face", 117 | "Bodoni MT" "Bokor", 118 | "Fasthand", 119 | "HoloLens MDL2 Assests", 120 | "Libre Barcode 39", 121 | "Lohit Tamil", 122 | "Marlett", 123 | "MS outlook", 124 | "MS office Symbol Semilight", 125 | "MS office symbol regular", 126 | "Ms office symbol extralight", 127 | "Ms Reference speciality", 128 | "Segoe MDL2 Assets", 129 | "Siemreap", 130 | "Sitara", 131 | "Symbol", 132 | "Wingdings", 133 | "Metal", 134 | "Ponnala", 135 | "Webdings", 136 | "Souliyo Unicode", 137 | "Aguafina Script", 138 | "Yantramanav Black", 139 | # "Yaldevi", 140 | # Taprom, 141 | # "Zhi Mang Xing", 142 | # "Taviraj", 143 | # "SeoulNamsan EB", 144 | ] 145 | 146 | 147 | def valid_font(filename): 148 | for name in INVALID_FONTS: 149 | if name.lower() in str(filename).lower(): 150 | return False 151 | return True 152 | -------------------------------------------------------------------------------- /classification.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pathlib 3 | import time 4 | 5 | from pytorch_lightning import Trainer 6 | from pytorch_lightning.callbacks import ModelCheckpoint 7 | from pytorch_lightning.loggers import TensorBoardLogger 8 | from pytorch_lightning.utilities.seed import seed_everything 9 | 10 | from datasets.solidletters import SolidLetters 11 | from uvnet.models import Classification 12 | 13 | parser = argparse.ArgumentParser("UV-Net solid model classification") 14 | parser.add_argument( 15 | "traintest", choices=("train", "test"), help="Whether to train or test" 16 | ) 17 | parser.add_argument("--dataset", choices=("solidletters",), help="Dataset to train on") 18 | parser.add_argument("--dataset_path", type=str, help="Path to dataset") 19 | parser.add_argument("--batch_size", type=int, default=64, help="Batch size") 20 | parser.add_argument( 21 | "--num_workers", 22 | type=int, 23 | default=0, 24 | help="Number of workers for the dataloader. 
NOTE: set this to 0 on Windows, any other value leads to poor performance", 25 | ) 26 | parser.add_argument( 27 | "--checkpoint", 28 | type=str, 29 | default=None, 30 | help="Checkpoint file to load weights from for testing", 31 | ) 32 | parser.add_argument( 33 | "--experiment_name", 34 | type=str, 35 | default="classification", 36 | help="Experiment name (used to create folder inside ./results/ to save logs and checkpoints)", 37 | ) 38 | 39 | parser = Trainer.add_argparse_args(parser) 40 | args = parser.parse_args() 41 | 42 | results_path = ( 43 | pathlib.Path(__file__).parent.joinpath("results").joinpath(args.experiment_name) 44 | ) 45 | if not results_path.exists(): 46 | results_path.mkdir(parents=True, exist_ok=True) 47 | 48 | # Define a path to save the results based date and time. E.g. 49 | # results/args.experiment_name/0430/123103 50 | month_day = time.strftime("%m%d") 51 | hour_min_second = time.strftime("%H%M%S") 52 | checkpoint_callback = ModelCheckpoint( 53 | monitor="val_loss", 54 | dirpath=str(results_path.joinpath(month_day, hour_min_second)), 55 | filename="best", 56 | save_last=True, 57 | ) 58 | 59 | trainer = Trainer.from_argparse_args( 60 | args, 61 | callbacks=[checkpoint_callback], 62 | logger=TensorBoardLogger( 63 | str(results_path), name=month_day, version=hour_min_second, 64 | ), 65 | ) 66 | 67 | if args.dataset == "solidletters": 68 | Dataset = SolidLetters 69 | else: 70 | raise ValueError("Unsupported dataset") 71 | 72 | if args.traintest == "train": 73 | # Train/val 74 | seed_everything(workers=True) 75 | print( 76 | f""" 77 | ----------------------------------------------------------------------------------- 78 | UV-Net Classification 79 | ----------------------------------------------------------------------------------- 80 | Logs written to results/{args.experiment_name}/{month_day}/{hour_min_second} 81 | 82 | To monitor the logs, run: 83 | tensorboard --logdir results/{args.experiment_name}/{month_day}/{hour_min_second} 84 | 85 | The trained model with the best validation loss will be written to: 86 | results/{args.experiment_name}/{month_day}/{hour_min_second}/best.ckpt 87 | ----------------------------------------------------------------------------------- 88 | """ 89 | ) 90 | model = Classification(num_classes=Dataset.num_classes()) 91 | train_data = Dataset(root_dir=args.dataset_path, split="train") 92 | val_data = Dataset(root_dir=args.dataset_path, split="val") 93 | train_loader = train_data.get_dataloader( 94 | batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers 95 | ) 96 | val_loader = val_data.get_dataloader( 97 | batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers 98 | ) 99 | trainer.fit(model, train_loader, val_loader) 100 | else: 101 | # Test 102 | assert ( 103 | args.checkpoint is not None 104 | ), "Expected the --checkpoint argument to be provided" 105 | test_data = Dataset(root_dir=args.dataset_path, split="test") 106 | test_loader = test_data.get_dataloader( 107 | batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers 108 | ) 109 | model = Classification.load_from_checkpoint(args.checkpoint) 110 | results = trainer.test(model=model, test_dataloaders=[test_loader], verbose=False) 111 | print( 112 | f"Classification accuracy (%) on test set: {results[0]['test_acc_epoch'] * 100.0}" 113 | ) 114 | -------------------------------------------------------------------------------- /segmentation.py: -------------------------------------------------------------------------------- 1 | import 
argparse 2 | import pathlib 3 | import time 4 | 5 | from pytorch_lightning import Trainer 6 | from pytorch_lightning.callbacks import ModelCheckpoint 7 | from pytorch_lightning.loggers import TensorBoardLogger 8 | from pytorch_lightning.utilities.seed import seed_everything 9 | 10 | from datasets.fusiongallery import FusionGalleryDataset 11 | from datasets.mfcad import MFCADDataset 12 | from uvnet.models import Segmentation 13 | 14 | parser = argparse.ArgumentParser("UV-Net solid model face segmentation") 15 | parser.add_argument( 16 | "traintest", choices=("train", "test"), help="Whether to train or test" 17 | ) 18 | parser.add_argument( 19 | "--dataset", choices=("mfcad", "fusiongallery"), help="Segmentation dataset" 20 | ) 21 | parser.add_argument("--dataset_path", type=str, help="Path to dataset") 22 | parser.add_argument("--batch_size", type=int, default=64, help="Batch size") 23 | parser.add_argument( 24 | "--num_workers", 25 | type=int, 26 | default=0, 27 | help="Number of workers for the dataloader. NOTE: set this to 0 on Windows, any other value leads to poor performance", 28 | ) 29 | parser.add_argument( 30 | "--random_rotate", 31 | action="store_true", 32 | help="Whether to randomly rotate the solids in 90 degree increments along the canonical axes", 33 | ) 34 | parser.add_argument( 35 | "--crv_in_channels", 36 | type=int, 37 | default=6, 38 | help="Number of channels for curve input", 39 | ) 40 | parser.add_argument( 41 | "--checkpoint", 42 | type=str, 43 | default=None, 44 | help="Checkpoint file to load weights from for testing", 45 | ) 46 | parser.add_argument( 47 | "--experiment_name", 48 | type=str, 49 | default="segmentation", 50 | help="Experiment name (used to create folder inside ./results/ to save logs and checkpoints)", 51 | ) 52 | 53 | parser = Trainer.add_argparse_args(parser) 54 | args = parser.parse_args() 55 | 56 | results_path = ( 57 | pathlib.Path(__file__).parent.joinpath("results").joinpath(args.experiment_name) 58 | ) 59 | if not results_path.exists(): 60 | results_path.mkdir(parents=True, exist_ok=True) 61 | 62 | # Define a path to save the results based date and time. E.g. 
63 | # results/args.experiment_name/0430/123103 64 | month_day = time.strftime("%m%d") 65 | hour_min_second = time.strftime("%H%M%S") 66 | checkpoint_callback = ModelCheckpoint( 67 | monitor="val_loss", 68 | dirpath=str(results_path.joinpath(month_day, hour_min_second)), 69 | filename="best", 70 | save_last=True, 71 | ) 72 | trainer = Trainer.from_argparse_args( 73 | args, 74 | callbacks=[checkpoint_callback], 75 | logger=TensorBoardLogger( 76 | str(results_path), name=month_day, version=hour_min_second, 77 | ), 78 | ) 79 | 80 | if args.dataset == "mfcad": 81 | Dataset = MFCADDataset 82 | elif args.dataset == "fusiongallery": 83 | Dataset = FusionGalleryDataset 84 | 85 | if args.traintest == "train": 86 | # Train/val 87 | seed_everything(workers=True) 88 | print( 89 | f""" 90 | ----------------------------------------------------------------------------------- 91 | UV-Net Segmentation 92 | ----------------------------------------------------------------------------------- 93 | Logs written to results/{args.experiment_name}/{month_day}/{hour_min_second} 94 | 95 | To monitor the logs, run: 96 | tensorboard --logdir results/{args.experiment_name}/{month_day}/{hour_min_second} 97 | 98 | The trained model with the best validation loss will be written to: 99 | results/{args.experiment_name}/{month_day}/{hour_min_second}/best.ckpt 100 | ----------------------------------------------------------------------------------- 101 | """ 102 | ) 103 | model = Segmentation( 104 | num_classes=Dataset.num_classes(), 105 | crv_in_channels=args.crv_in_channels 106 | ) 107 | train_data = Dataset( 108 | root_dir=args.dataset_path, split="train", random_rotate=args.random_rotate 109 | ) 110 | val_data = Dataset(root_dir=args.dataset_path, split="val") 111 | train_loader = train_data.get_dataloader( 112 | batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers 113 | ) 114 | val_loader = val_data.get_dataloader( 115 | batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers 116 | ) 117 | trainer.fit(model, train_loader, val_loader) 118 | else: 119 | # Test 120 | assert ( 121 | args.checkpoint is not None 122 | ), "Expected the --checkpoint argument to be provided" 123 | model = Segmentation.load_from_checkpoint(args.checkpoint) 124 | test_data = Dataset( 125 | root_dir=args.dataset_path, split="test", random_rotate=args.random_rotate 126 | ) 127 | test_loader = test_data.get_dataloader( 128 | batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers 129 | ) 130 | results = trainer.test(model=model, test_dataloaders=[test_loader], verbose=False) 131 | print(f"Segmentation IoU (%) on test set: {results[0]['test_iou'] * 100.0}") 132 | -------------------------------------------------------------------------------- /process/solid_to_graph.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pathlib 3 | 4 | import dgl 5 | import numpy as np 6 | import torch 7 | from occwl.graph import face_adjacency 8 | from occwl.io import load_step 9 | from occwl.uvgrid import ugrid, uvgrid 10 | from tqdm import tqdm 11 | from multiprocessing.pool import Pool 12 | from itertools import repeat 13 | import signal 14 | 15 | 16 | def build_graph(solid, curv_num_u_samples, surf_num_u_samples, surf_num_v_samples): 17 | # Build face adjacency graph with B-rep entities as node and edge features 18 | graph = face_adjacency(solid) 19 | 20 | # Compute the UV-grids for faces 21 | graph_face_feat = [] 22 | for face_idx in graph.nodes: 23 | # 
Get the B-rep face 24 | face = graph.nodes[face_idx]["face"] 25 | # Compute UV-grids 26 | points = uvgrid( 27 | face, method="point", num_u=surf_num_u_samples, num_v=surf_num_v_samples 28 | ) 29 | normals = uvgrid( 30 | face, method="normal", num_u=surf_num_u_samples, num_v=surf_num_v_samples 31 | ) 32 | visibility_status = uvgrid( 33 | face, method="visibility_status", num_u=surf_num_u_samples, num_v=surf_num_v_samples 34 | ) 35 | mask = np.logical_or(visibility_status == 0, visibility_status == 2) # 0: Inside, 1: Outside, 2: On boundary 36 | # Concatenate channel-wise to form face feature tensor 37 | face_feat = np.concatenate((points, normals, mask), axis=-1) 38 | graph_face_feat.append(face_feat) 39 | graph_face_feat = np.asarray(graph_face_feat) 40 | 41 | # Compute the U-grids for edges 42 | graph_edge_feat = [] 43 | for edge_idx in graph.edges: 44 | # Get the B-rep edge 45 | edge = graph.edges[edge_idx]["edge"] 46 | # Ignore dgenerate edges, e.g. at apex of cone 47 | if not edge.has_curve(): 48 | continue 49 | # Compute U-grids 50 | points = ugrid(edge, method="point", num_u=curv_num_u_samples) 51 | tangents = ugrid(edge, method="tangent", num_u=curv_num_u_samples) 52 | # Concatenate channel-wise to form edge feature tensor 53 | edge_feat = np.concatenate((points, tangents), axis=-1) 54 | graph_edge_feat.append(edge_feat) 55 | graph_edge_feat = np.asarray(graph_edge_feat) 56 | 57 | # Convert face-adj graph to DGL format 58 | edges = list(graph.edges) 59 | src = [e[0] for e in edges] 60 | dst = [e[1] for e in edges] 61 | dgl_graph = dgl.graph((src, dst), num_nodes=len(graph.nodes)) 62 | dgl_graph.ndata["x"] = torch.from_numpy(graph_face_feat) 63 | dgl_graph.edata["x"] = torch.from_numpy(graph_edge_feat) 64 | return dgl_graph 65 | 66 | 67 | def process_one_file(arguments): 68 | fn, args = arguments 69 | fn_stem = fn.stem 70 | output_path = pathlib.Path(args.output) 71 | solid = load_step(fn)[0] # Assume there's one solid per file 72 | graph = build_graph( 73 | solid, args.curv_u_samples, args.surf_u_samples, args.surf_v_samples 74 | ) 75 | dgl.data.utils.save_graphs(str(output_path.joinpath(fn_stem + ".bin")), [graph]) 76 | 77 | 78 | def initializer(): 79 | """Ignore CTRL+C in the worker process.""" 80 | signal.signal(signal.SIGINT, signal.SIG_IGN) 81 | 82 | 83 | def process(args): 84 | input_path = pathlib.Path(args.input) 85 | output_path = pathlib.Path(args.output) 86 | if not output_path.exists(): 87 | output_path.mkdir(parents=True, exist_ok=True) 88 | step_files = list(input_path.glob("*.st*p")) 89 | # for fn in tqdm(step_files): 90 | # process_one_file(fn, args) 91 | pool = Pool(processes=args.num_processes, initializer=initializer) 92 | try: 93 | results = list(tqdm(pool.imap(process_one_file, zip(step_files, repeat(args))), total=len(step_files))) 94 | except KeyboardInterrupt: 95 | pool.terminate() 96 | pool.join() 97 | print(f"Processed {len(results)} files.") 98 | 99 | 100 | def main(): 101 | parser = argparse.ArgumentParser( 102 | "Convert solid models to face-adjacency graphs with UV-grid features" 103 | ) 104 | parser.add_argument("input", type=str, help="Input folder of STEP files") 105 | parser.add_argument("output", type=str, help="Output folder of DGL graph BIN files") 106 | parser.add_argument( 107 | "--curv_u_samples", type=int, default=10, help="Number of samples on each curve" 108 | ) 109 | parser.add_argument( 110 | "--surf_u_samples", 111 | type=int, 112 | default=10, 113 | help="Number of samples on each surface along the u-direction", 114 | ) 115 | 
parser.add_argument( 116 | "--surf_v_samples", 117 | type=int, 118 | default=10, 119 | help="Number of samples on each surface along the v-direction", 120 | ) 121 | parser.add_argument( 122 | "--num_processes", 123 | type=int, 124 | default=8, 125 | help="Number of processes to use", 126 | ) 127 | args = parser.parse_args() 128 | process(args) 129 | 130 | 131 | if __name__ == "__main__": 132 | main() 133 | -------------------------------------------------------------------------------- /process/visualize_uvgrid_graph.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import pathlib 4 | import os 5 | import os.path as osp 6 | import matplotlib 7 | 8 | # matplotlib.use('Agg') 9 | import matplotlib.pyplot as plt 10 | from mpl_toolkits.mplot3d import Axes3D 11 | import torch 12 | import dgl 13 | from dgl.data.utils import load_graphs 14 | 15 | 16 | def bounding_box_pointcloud(pts: torch.Tensor): 17 | x = pts[:, 0] 18 | y = pts[:, 1] 19 | z = pts[:, 2] 20 | box = [[x.min(), y.min(), z.min()], [x.max(), y.max(), z.max()]] 21 | return torch.tensor(box) 22 | 23 | 24 | def bounding_box_uvsolid(inp: torch.Tensor): 25 | pts = inp[:, :, :, :3].reshape((-1, 3)) 26 | mask = inp[:, :, :, 6].reshape(-1) 27 | point_indices_inside_faces = mask == 1 28 | pts = pts[point_indices_inside_faces, :] 29 | return bounding_box_pointcloud(pts) 30 | 31 | 32 | def plot_uvsolid(uvsolid: torch.Tensor, ax, normals=False): 33 | """ 34 | Plot the loaded UV solid features to a MPL 3D Axes 35 | :param uvsolid torch.Tensor: Features loaded from *.feat file of shape [#faces, #u, #v, 10] 36 | :param ax matplotlib Axes3D: 3D Axes object for plotting 37 | """ 38 | assert len(uvsolid.shape) == 4 # faces x #u x #v x 10 39 | bbox = bounding_box_uvsolid(uvsolid) 40 | bbox_diag = torch.norm(bbox[1] - bbox[0]).item() 41 | num_faces = uvsolid.size(0) 42 | for i in range(num_faces): 43 | pts = uvsolid[i, :, :, :3].cpu().detach().numpy().reshape((-1, 3)) 44 | nor = uvsolid[i, :, :, 3:6].cpu().detach().numpy().reshape((-1, 3)) 45 | mask = uvsolid[i, :, :, 6].cpu().detach().numpy().reshape(-1) 46 | point_indices_inside_faces = mask == 1 47 | pts = pts[point_indices_inside_faces, :] 48 | if normals: 49 | nor = nor[point_indices_inside_faces, :] 50 | ax.quiver( 51 | pts[:, 0], 52 | pts[:, 1], 53 | pts[:, 2], 54 | nor[:, 0], 55 | nor[:, 1], 56 | nor[:, 2], 57 | length=0.075 * bbox_diag, 58 | ) 59 | ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2]) 60 | 61 | 62 | def plot_uvsolid_edges(graph: dgl.DGLGraph, ax, tangents=False): 63 | """ 64 | Plot the loaded UV solid's edge features to a MPL 3D Axes 65 | :param graph: dgl.DGLGraph: DGL Graph containing the graph with UV-grids as node features and 1D UV-grids as edge features 66 | :param ax matplotlib Axes3D: 3D Axes object for plotting 67 | """ 68 | face_feat = graph.ndata["x"] 69 | if graph.edata.get("x") is None: 70 | print("Edge features not found") 71 | return 72 | edge_feat = graph.edata["x"] 73 | if edge_feat.shape[0] == 0: 74 | return 75 | assert edge_feat.shape[2] in (3, 6), edge_feat.shape # edges x #u x 3/6 76 | bbox = bounding_box_uvsolid(face_feat) 77 | bbox_diag = torch.norm(bbox[1] - bbox[0]).item() 78 | num_edges = graph.edata["x"].size(0) 79 | for i in range(num_edges): 80 | pts = graph.edata["x"][i, :, :3].cpu().detach().numpy().reshape((-1, 3)) 81 | if tangents: 82 | tgt = graph.edata["x"][i, :, 3:6].cpu().detach().numpy().reshape((-1, 3)) 83 | ax.quiver( 84 | pts[:, 0], 85 | pts[:, 1], 86 | pts[:, 2], 87 
| tgt[:, 0], 88 | tgt[:, 1], 89 | tgt[:, 2], 90 | length=0.075 * bbox_diag, 91 | ) 92 | ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2]) 93 | 94 | 95 | def plot_faceadj_graph(graph: dgl.DGLGraph, ax): 96 | """ 97 | Plot the face-adj graph to a MPL 3D Axes 98 | :param graph: dgl.DGLGraph: DGL Graph containing the graph with UV-grids as node features 99 | :param ax matplotlib Axes3D: 3D Axes object for plotting 100 | """ 101 | assert len(graph.ndata["x"].shape) == 4 # faces x #u x #v x 10 102 | src, dst = graph.edges() 103 | for i in range(src.size(0)): 104 | center_idx = graph.ndata["x"].size(1) // 2 105 | src_pt = graph.ndata["x"][src[i], center_idx, center_idx, :3] 106 | dst_pt = graph.ndata["x"][dst[i], center_idx, center_idx, :3] 107 | ax.plot( 108 | (src_pt[0], dst_pt[0]), 109 | (src_pt[1], dst_pt[1]), 110 | zs=(src_pt[2], dst_pt[2]), 111 | color="k", 112 | linewidth=2, 113 | marker="o", 114 | ) 115 | 116 | 117 | if __name__ == "__main__": 118 | parser = argparse.ArgumentParser( 119 | "Visualize UV-grids and face adj graphs for testing" 120 | ) 121 | parser.add_argument("dir", type=str, default=None, help="Directory of bin files") 122 | parser.add_argument( 123 | "--hide_plots", 124 | action="store_true", 125 | help="Whether to hide the plots, and only save them", 126 | ) 127 | parser.add_argument( 128 | "--plot_face_normals", 129 | action="store_true", 130 | help="Whether to plot face normals", 131 | ) 132 | parser.add_argument( 133 | "--plot_edge_tangents", 134 | action="store_true", 135 | help="Whether to plot edge tangents", 136 | ) 137 | args, _ = parser.parse_known_args() 138 | 139 | if args.dir is None: 140 | raise ValueError("Expected a valid directory to be provided") 141 | folder = pathlib.Path(args.dir) 142 | bin_files = folder.glob("*.bin") 143 | 144 | for f in bin_files: 145 | graph = load_graphs(str(f))[0][0] 146 | 147 | fig = plt.figure() 148 | ax = fig.add_subplot(111, projection="3d") 149 | plt.gca().view_init(35, 90) 150 | ax.auto_scale_xyz([-1, 1], [-1, 1], [-1, 1]) 151 | 152 | plot_uvsolid( 153 | graph.ndata["x"], 154 | ax, 155 | normals=args.plot_face_normals, 156 | ) 157 | plot_faceadj_graph( 158 | graph, 159 | ax, 160 | ) 161 | plot_uvsolid_edges(graph, ax, tangents=args.plot_edge_tangents) 162 | plt.savefig(folder.joinpath(f.stem + ".jpg")) 163 | if not args.hide_plots: 164 | plt.show() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UV-Net: Learning from Boundary Representations 2 | 3 | This repository contains code for the paper: 4 | 5 | ["UV-Net: Learning from Boundary Representations."](https://arxiv.org/abs/2006.10211) Pradeep Kumar Jayaraman, Aditya Sanghi, Joseph G. Lambourne, Karl D.D. Willis, Thomas Davies, Hooman Shayani, Nigel Morris. CVPR 2021. 6 | 7 | ![Teaser](docs/img/Teaser.png) 8 | 9 | UV-Net is a neural network designed to operate directly on Boundary representation (B-rep) data from 3D CAD models. The B-rep format is widely used in the design, simulation and manufacturing industries to enable sophisticated and precise CAD modeling operations. However, B-rep data presents some unique challenges when used with neural networks due to the complexity of the data structure and its support for various disparate geometric and topological entities. 
10 | 11 | In UV-Net, we represent the geometry stored in the edges (curves) and faces (surfaces) of the B-rep using 1D and 2D UV-grids, a structured set of points sampled by taking uniform steps in the parameter domain. 1D and 2D convolutional neural networks can be applied on these UV-grids to encode the edge and face geometry. 12 | 13 | ![UVGrid](docs/img/UVGrid.png) 14 | 15 | The topology is represented using a face-adjacency graph where features from the face UV-grids are stored as node features, and features from the edge UV-grids are stored as edge features. A graph neural network is then used to message pass these features to obtain embeddings for faces, edges and the entire solid model. 16 | 17 | ![MessagePassing](docs/img/MessagePassing.png) 18 | 19 | ## Environment setup 20 | 21 | ``` 22 | conda env create -f environment.yml 23 | conda activate uv_net 24 | ``` 25 | 26 | For CPU-only environments, the CPU version of DGL has to be installed manually: 27 | ``` 28 | conda install -c dglteam dgl=0.6.1=py39_0 29 | ``` 30 | 31 | ## Training 32 | 33 | The classification model can be trained using: 34 | ``` 35 | python classification.py train --dataset solidletters --dataset_path /path/to/solidletters --max_epochs 100 --batch_size 64 --experiment_name classification 36 | ``` 37 | 38 | Only the SolidLetters dataset is currently supported for classification. 39 | 40 | The segmentation model can be trained similarly: 41 | ``` 42 | python segmentation.py train --dataset mfcad --dataset_path /path/to/mfcad --max_epochs 100 --batch_size 64 --experiment_name segmentation 43 | ``` 44 | 45 | The MFCAD and Fusion 360 Gallery segmentation datasets are supported for segmentation. 46 | 47 | The logs and checkpoints will be stored in a folder called `results/classification` or `results/segmentation` based on the experiment name and timestamp, and can be monitored with Tensorboard: 48 | 49 | ``` 50 | tensorboard --logdir results/ 51 | ``` 52 | 53 | ## Testing 54 | The best checkpoints based on the smallest validation loss are saved in the results folder. The checkpoints can be used to test the model as follows: 55 | 56 | ``` 57 | python segmentation.py test --dataset mfcad --dataset_path /path/to/mfcad/ --checkpoint ./results/segmentation/best.ckpt 58 | ``` 59 | 60 | ## Data 61 | The network consumes [DGL](https://dgl.ai/)-based face-adjacency graphs, where each B-rep face is mapped to a node, and each B-rep edge is mapped to an edge. The face UV-grids are expected as node features and edge UV-grids as edge features. For example, the UV-grid features from our face-adjacency graph representation can be accessed as follows: 62 | 63 | ```python 64 | from dgl.data.utils import load_graphs 65 | 66 | graph = load_graphs(filename)[0][0] 67 | graph.ndata["x"] # num_facesx10x10x7 face UV-grids (we use 10 samples along the u- and v-directions of the surface) 68 | # The first three channels are the point coordinates, next three channels are the surface normals, and 69 | # the last channel is a trimming mask set to 1 if the point is in the visible part of the face and 0 otherwise 70 | graph.edata["x"] # num_edgesx10x6 edge UV-grids (we use 10 samples along the u-direction of the curve) 71 | # The first three channels are the point coordinates, next three channels are the curve tangents 72 | ``` 73 |
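The trimming mask in the last channel can be used to keep only the samples that lie inside the visible region of each face, for example to recover an approximate point cloud of the solid from a graph. Below is a minimal sketch (mirroring how `datasets/util.py` filters points when computing bounding boxes); `example.bin` is a placeholder for any graph file produced by `process/solid_to_graph.py`:

```python
from dgl.data.utils import load_graphs

graph = load_graphs("example.bin")[0][0]
face_grids = graph.ndata["x"]                  # num_faces x 10 x 10 x 7
points = face_grids[..., :3].reshape(-1, 3)    # sampled surface points on all faces
normals = face_grids[..., 3:6].reshape(-1, 3)  # surface normals at those samples
mask = face_grids[..., 6].reshape(-1) == 1     # trimming mask: 1 = inside the visible region
visible_points = points[mask]                  # approximate point cloud of the solid
visible_normals = normals[mask]
```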
74 | ### SolidLetters 75 | 76 | SolidLetters is a synthetic dataset of ~96k solids created by extruding and filleting fonts. It has class labels (alphabets), and style labels (font name and upper/lower case) for each solid. 77 | 78 | The dataset can be downloaded from here: https://uv-net-data.s3.us-west-2.amazonaws.com/SolidLetters.zip 79 | 80 | To train the UV-Net classification model on the data: 81 | 82 | 1. Extract it to a folder, say `/path/to/solidletters/`. Please refer to the license in `/path/to/solidletters/SolidLetters Dataset License.pdf`. 83 | 84 | 2. There should be three files: 85 | 86 | - `/path/to/solidletters/smt.7z` contains the solid models in `.smt` format that can be read by a proprietary Autodesk solid modeling kernel and the Fusion 360 software. 87 | - `/path/to/solidletters/step.zip` contains the solid models in `.step` format that can be read with OpenCascade and its Python bindings [pythonOCC](https://github.com/tpaviot/pythonocc-core). 88 | - `/path/to/solidletters/graph.7z` contains the derived face-adjacency graphs in DGL's `.bin` format with UV-grids stored as node and edge features. This is the data that gets passed to UV-Net for training and testing. Extract this file. 89 | 90 | 3. Pass the `/path/to/solidletters/` folder to the `--dataset_path` argument in the classification script and set `--dataset` to `solidletters`. 91 | 92 | ### MFCAD 93 | 94 | The original solid model data is available here in STEP format: [github.com/hducg/MFCAD](https://github.com/hducg/MFCAD). 95 | 96 | We provide pre-processed DGL graphs in `.bin` format to train UV-Net on this dataset. 97 | 98 | 1. Download and extract the data to a folder, say `/path/to/mfcad/`, from here: https://uv-net-data.s3.us-west-2.amazonaws.com/MFCADDataset.zip 99 | 100 | 2. Pass the `/path/to/mfcad/` folder to the `--dataset_path` argument in the segmentation script and set `--dataset` to `mfcad`. 101 | 102 | ### Fusion 360 Gallery segmentation 103 | 104 | We provide pre-processed DGL graphs in `.bin` format to train UV-Net on the [Fusion 360 Gallery](https://github.com/AutodeskAILab/Fusion360GalleryDataset) segmentation task. 105 | 106 | 1. Download and extract the dataset to a folder, say `/path/to/fusiongallery/`, from here: https://uv-net-data.s3.us-west-2.amazonaws.com/Fusion360GallerySegmentationDataset.zip 107 | 108 | 2. Pass the `/path/to/fusiongallery/` folder to the `--dataset_path` argument in the segmentation script and set `--dataset` to `fusiongallery`. 109 | 110 | 111 | ## Processing your own data 112 | Refer to our [guide](process/README.md) to process your own solid model data (in STEP format) into the `.bin` format that is understood by UV-Net, and to convert STEP files to meshes and point clouds. 113 | 114 | ## Citation 115 | 116 | ``` 117 | @inproceedings{jayaraman2021uvnet, 118 | title = {UV-Net: Learning from Boundary Representations}, 119 | author = {Pradeep Kumar Jayaraman and Aditya Sanghi and Joseph G. Lambourne and Karl D.D. 
74 | ### SolidLetters 75 | 76 | SolidLetters is a synthetic dataset of ~96k solids created by extruding and filleting fonts. It has class labels (letters of the alphabet) and style labels (font name and upper/lower case) for each solid. 77 | 78 | The dataset can be downloaded from here: https://uv-net-data.s3.us-west-2.amazonaws.com/SolidLetters.zip 79 | 80 | To train the UV-Net classification model on the data: 81 | 82 | 1. Extract it to a folder, say `/path/to/solidletters/`. Please refer to the license in `/path/to/solidletters/SolidLetters Dataset License.pdf`. 83 | 84 | 2. There should be three files: 85 | 86 | - `/path/to/solidletters/smt.7z` contains the solid models in `.smt` format that can be read by a proprietary Autodesk solid modeling kernel and the Fusion 360 software. 87 | - `/path/to/solidletters/step.zip` contains the solid models in `.step` format that can be read with OpenCascade and its Python bindings [pythonOCC](https://github.com/tpaviot/pythonocc-core). 88 | - `/path/to/solidletters/graph.7z` contains the derived face-adjacency graphs in DGL's `.bin` format with UV-grids stored as node and edge features. This is the data that gets passed to UV-Net for training and testing. Extract this file. 89 | 90 | 3. Pass the `/path/to/solidletters/` folder to the `--dataset_path` argument in the classification script and set `--dataset` to `solidletters`. 91 | 92 | ### MFCAD 93 | 94 | The original solid model data is available here in STEP format: [github.com/hducg/MFCAD](https://github.com/hducg/MFCAD). 95 | 96 | We provide pre-processed DGL graphs in `.bin` format to train UV-Net on this dataset. 97 | 98 | 1. Download and extract the data to a folder, say `/path/to/mfcad/` from here: https://uv-net-data.s3.us-west-2.amazonaws.com/MFCADDataset.zip 99 | 100 | 2. Pass the `/path/to/mfcad/` folder to the `--dataset_path` argument in the segmentation script and set `--dataset` to `mfcad`. 101 | 102 | ### Fusion 360 Gallery segmentation 103 | 104 | We provide pre-processed DGL graphs in `.bin` format to train UV-Net on the [Fusion 360 Gallery](https://github.com/AutodeskAILab/Fusion360GalleryDataset) segmentation task. 105 | 106 | 1. Download and extract the dataset to a folder, say `/path/to/fusiongallery/` from here: https://uv-net-data.s3.us-west-2.amazonaws.com/Fusion360GallerySegmentationDataset.zip 107 | 108 | 2. Pass the `/path/to/fusiongallery/` folder to the `--dataset_path` argument in the segmentation script and set `--dataset` to `fusiongallery`. 109 | 110 | 111 | ## Processing your own data 112 | Refer to our [guide](process/README.md) to process your own solid model data (in STEP format) into the `.bin` format understood by UV-Net, and to convert STEP files into meshes and point clouds. 113 | 114 | ## Citation 115 | 116 | ``` 117 | @inproceedings{jayaraman2021uvnet, 118 | title = {UV-Net: Learning from Boundary Representations}, 119 | author = {Pradeep Kumar Jayaraman and Aditya Sanghi and Joseph G. Lambourne and Karl D.D. Willis and Thomas Davies and Hooman Shayani and Nigel Morris}, 120 | eprint = {2006.10211}, 121 | eprinttype = {arXiv}, 122 | eprintclass = {cs.CV}, 123 | booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 124 | year = {2021} 125 | } 126 | ``` 127 | -------------------------------------------------------------------------------- /uvnet/encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from dgl.nn.pytorch.conv import NNConv 5 | from dgl.nn.pytorch.glob import MaxPooling 6 | 7 | def _conv1d(in_channels, out_channels, kernel_size=3, padding=0, bias=False): 8 | """ 9 | Helper function to create a 1D convolutional layer with batchnorm and LeakyReLU activation 10 | 11 | Args: 12 | in_channels (int): Input channels 13 | out_channels (int): Output channels 14 | kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3. 15 | padding (int, optional): Padding size on each side. Defaults to 0. 16 | bias (bool, optional): Whether bias is used. Defaults to False. 17 | 18 | Returns: 19 | nn.Sequential: Sequential containing the Conv1d, BatchNorm1d and LeakyReLU layers 20 | """ 21 | return nn.Sequential( 22 | nn.Conv1d( 23 | in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias 24 | ), 25 | nn.BatchNorm1d(out_channels), 26 | nn.LeakyReLU(), 27 | ) 28 | 29 | 30 | def _conv2d(in_channels, out_channels, kernel_size, padding=0, bias=False): 31 | """ 32 | Helper function to create a 2D convolutional layer with batchnorm and LeakyReLU activation 33 | 34 | Args: 35 | in_channels (int): Input channels 36 | out_channels (int): Output channels 37 | kernel_size (int): Size of the convolutional kernel. 38 | padding (int, optional): Padding size on each side. Defaults to 0. 39 | bias (bool, optional): Whether bias is used. Defaults to False. 40 | 41 | Returns: 42 | nn.Sequential: Sequential containing the Conv2d, BatchNorm2d and LeakyReLU layers 43 | """ 44 | return nn.Sequential( 45 | nn.Conv2d( 46 | in_channels, 47 | out_channels, 48 | kernel_size=kernel_size, 49 | padding=padding, 50 | bias=bias, 51 | ), 52 | nn.BatchNorm2d(out_channels), 53 | nn.LeakyReLU(), 54 | ) 55 | 56 | 57 | def _fc(in_features, out_features, bias=False): 58 | return nn.Sequential( 59 | nn.Linear(in_features, out_features, bias=bias), 60 | nn.BatchNorm1d(out_features), 61 | nn.LeakyReLU(), 62 | ) 63 | 
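# Note: all three helpers above pair the Conv/Linear layer with a BatchNorm layer,
# which is why bias defaults to False -- BatchNorm's learned shift would make a
# separate bias parameter redundant.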
64 | 65 | class _MLP(nn.Module): 66 | """Multi-layer perceptron used by the message-passing layers below.""" 67 | 68 | def __init__(self, num_layers, input_dim, hidden_dim, output_dim): 69 | """ 70 | MLP with linear output 71 | Args: 72 | num_layers (int): The number of linear layers in the MLP 73 | input_dim (int): Input feature dimension 74 | hidden_dim (int): Hidden feature dimensions for all hidden layers 75 | output_dim (int): Output feature dimension 76 | 77 | Raises: 78 | ValueError: If the given number of layers is <1 79 | """ 80 | super(_MLP, self).__init__() 81 | self.linear_or_not = True # default is linear model 82 | self.num_layers = num_layers 83 | self.output_dim = output_dim 84 | 85 | if num_layers < 1: 86 | raise ValueError("Number of layers should be positive!") 87 | elif num_layers == 1: 88 | # Linear model 89 | self.linear = nn.Linear(input_dim, output_dim) 90 | else: 91 | # Multi-layer model 92 | self.linear_or_not = False 93 | self.linears = torch.nn.ModuleList() 94 | self.batch_norms = torch.nn.ModuleList() 95 | 96 | self.linears.append(nn.Linear(input_dim, hidden_dim)) 97 | for layer in range(num_layers - 2): 98 | self.linears.append(nn.Linear(hidden_dim, hidden_dim)) 99 | self.linears.append(nn.Linear(hidden_dim, output_dim)) 100 | 101 | # TODO: this could move inside the above loop 102 | for layer in range(num_layers - 1): 103 | self.batch_norms.append(nn.BatchNorm1d((hidden_dim))) 104 | 105 | def forward(self, x): 106 | if self.linear_or_not: 107 | # If linear model 108 | return self.linear(x) 109 | else: 110 | # If MLP 111 | h = x 112 | for i in range(self.num_layers - 1): 113 | h = F.relu(self.batch_norms[i](self.linears[i](h))) 114 | return self.linears[-1](h) 115 | 116 | 117 | class UVNetCurveEncoder(nn.Module): 118 | def __init__(self, in_channels=6, output_dims=64): 119 | """ 120 | This is the 1D convolutional network that extracts features from the B-rep edge 121 | geometry described as 1D UV-grids (see Section 3.2, Curve & surface convolution 122 | in the paper) 123 | 124 | Args: 125 | in_channels (int, optional): Number of channels in the edge UV-grids. By default 126 | we expect 3 channels for point coordinates and 3 for 127 | curve tangents. Defaults to 6. 128 | output_dims (int, optional): Output curve embedding dimension. Defaults to 64. 129 | """ 130 | super(UVNetCurveEncoder, self).__init__() 131 | self.in_channels = in_channels 132 | self.conv1 = _conv1d(in_channels, 64, kernel_size=3, padding=1, bias=False) 133 | self.conv2 = _conv1d(64, 128, kernel_size=3, padding=1, bias=False) 134 | self.conv3 = _conv1d(128, 256, kernel_size=3, padding=1, bias=False) 135 | self.final_pool = nn.AdaptiveAvgPool1d(1) 136 | self.fc = _fc(256, output_dims, bias=False) 137 | 138 | for m in self.modules(): 139 | self.weights_init(m) 140 | 141 | def weights_init(self, m): 142 | if isinstance(m, (nn.Linear, nn.Conv1d)): 143 | torch.nn.init.kaiming_uniform_(m.weight.data) 144 | if m.bias is not None: 145 | m.bias.data.fill_(0.0) 146 | 147 | def forward(self, x): 148 | assert x.size(1) == self.in_channels 149 | batch_size = x.size(0) 150 | x = self.conv1(x) 151 | x = self.conv2(x) 152 | x = self.conv3(x) 153 | x = self.final_pool(x) 154 | x = x.view(batch_size, -1) 155 | x = self.fc(x) 156 | return x 157 | 
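# Shape check (illustrative): with the default settings the curve encoder maps a
# channels-first batch of edge UV-grids, e.g. torch.rand(num_edges, 6, 10), to a
# (num_edges, 64) tensor of edge embeddings. The surface encoder below follows the
# same pattern for face UV-grids of shape (num_faces, 7, 10, 10).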
129 | """ 130 | super(UVNetCurveEncoder, self).__init__() 131 | self.in_channels = in_channels 132 | self.conv1 = _conv1d(in_channels, 64, kernel_size=3, padding=1, bias=False) 133 | self.conv2 = _conv1d(64, 128, kernel_size=3, padding=1, bias=False) 134 | self.conv3 = _conv1d(128, 256, kernel_size=3, padding=1, bias=False) 135 | self.final_pool = nn.AdaptiveAvgPool1d(1) 136 | self.fc = _fc(256, output_dims, bias=False) 137 | 138 | for m in self.modules(): 139 | self.weights_init(m) 140 | 141 | def weights_init(self, m): 142 | if isinstance(m, (nn.Linear, nn.Conv1d)): 143 | torch.nn.init.kaiming_uniform_(m.weight.data) 144 | if m.bias is not None: 145 | m.bias.data.fill_(0.0) 146 | 147 | def forward(self, x): 148 | assert x.size(1) == self.in_channels 149 | batch_size = x.size(0) 150 | x = self.conv1(x) 151 | x = self.conv2(x) 152 | x = self.conv3(x) 153 | x = self.final_pool(x) 154 | x = x.view(batch_size, -1) 155 | x = self.fc(x) 156 | return x 157 | 158 | 159 | class UVNetSurfaceEncoder(nn.Module): 160 | def __init__( 161 | self, 162 | in_channels=7, 163 | output_dims=64, 164 | ): 165 | """ 166 | This is the 2D convolutional network that extracts features from the B-rep face 167 | geometry described as 2D UV-grids (see Section 3.2, Curve & surface convolution 168 | in paper) 169 | 170 | Args: 171 | in_channels (int, optional): Number of channels in the edge UV-grids. By default 172 | we expect 3 channels for point coordinates and 3 for 173 | surface normals and 1 for the trimming mask. Defaults 174 | to 7. 175 | output_dims (int, optional): Output surface embedding dimension. Defaults to 64. 176 | """ 177 | super(UVNetSurfaceEncoder, self).__init__() 178 | self.in_channels = in_channels 179 | self.conv1 = _conv2d(in_channels, 64, 3, padding=1, bias=False) 180 | self.conv2 = _conv2d(64, 128, 3, padding=1, bias=False) 181 | self.conv3 = _conv2d(128, 256, 3, padding=1, bias=False) 182 | self.final_pool = nn.AdaptiveAvgPool2d(1) 183 | self.fc = _fc(256, output_dims, bias=False) 184 | for m in self.modules(): 185 | self.weights_init(m) 186 | 187 | def weights_init(self, m): 188 | if isinstance(m, (nn.Linear, nn.Conv2d)): 189 | torch.nn.init.kaiming_uniform_(m.weight.data) 190 | if m.bias is not None: 191 | m.bias.data.fill_(0.0) 192 | 193 | def forward(self, x): 194 | assert x.size(1) == self.in_channels 195 | batch_size = x.size(0) 196 | x = self.conv1(x) 197 | x = self.conv2(x) 198 | x = self.conv3(x) 199 | x = self.final_pool(x) 200 | x = x.view(batch_size, -1) 201 | x = self.fc(x) 202 | return x 203 | 204 | 205 | 206 | class _EdgeConv(nn.Module): 207 | def __init__( 208 | self, 209 | edge_feats, 210 | out_feats, 211 | node_feats, 212 | num_mlp_layers=2, 213 | hidden_mlp_dim=64, 214 | ): 215 | """ 216 | This module implements Eq. 2 from the paper where the edge features are 217 | updated using the node features at the endpoints. 218 | 219 | Args: 220 | edge_feats (int): Input edge feature dimension 221 | out_feats (int): Output feature deimension 222 | node_feats (int): Input node feature dimension 223 | num_mlp_layers (int, optional): Number of layers used in the MLP. Defaults to 2. 224 | hidden_mlp_dim (int, optional): Hidden feature dimension in the MLP. Defaults to 64. 
225 | """ 226 | super(_EdgeConv, self).__init__() 227 | self.proj = _MLP(1, node_feats, hidden_mlp_dim, edge_feats) 228 | self.mlp = _MLP(num_mlp_layers, edge_feats, hidden_mlp_dim, out_feats) 229 | self.batchnorm = nn.BatchNorm1d(out_feats) 230 | self.eps = torch.nn.Parameter(torch.FloatTensor([0.0])) 231 | 232 | def forward(self, graph, nfeat, efeat): 233 | src, dst = graph.edges() 234 | proj1, proj2 = self.proj(nfeat[src]), self.proj(nfeat[dst]) 235 | agg = proj1 + proj2 236 | h = self.mlp((1 + self.eps) * efeat + agg) 237 | h = F.leaky_relu(self.batchnorm(h)) 238 | return h 239 | 240 | 241 | class _NodeConv(nn.Module): 242 | def __init__( 243 | self, 244 | node_feats, 245 | out_feats, 246 | edge_feats, 247 | num_mlp_layers=2, 248 | hidden_mlp_dim=64, 249 | ): 250 | """ 251 | This module implements Eq. 1 from the paper where the node features are 252 | updated using the neighboring node and edge features. 253 | 254 | Args: 255 | node_feats (int): Input edge feature dimension 256 | out_feats (int): Output feature deimension 257 | node_feats (int): Input node feature dimension 258 | num_mlp_layers (int, optional): Number of layers used in the MLP. Defaults to 2. 259 | hidden_mlp_dim (int, optional): Hidden feature dimension in the MLP. Defaults to 64. 260 | """ 261 | super(_NodeConv, self).__init__() 262 | self.gconv = NNConv( 263 | in_feats=node_feats, 264 | out_feats=out_feats, 265 | edge_func=nn.Linear(edge_feats, node_feats * out_feats), 266 | aggregator_type="sum", 267 | bias=False, 268 | ) 269 | self.batchnorm = nn.BatchNorm1d(out_feats) 270 | self.mlp = _MLP(num_mlp_layers, node_feats, hidden_mlp_dim, out_feats) 271 | self.eps = torch.nn.Parameter(torch.FloatTensor([0.0])) 272 | 273 | def forward(self, graph, nfeat, efeat): 274 | h = (1 + self.eps) * nfeat 275 | h = self.gconv(graph, h, efeat) 276 | h = self.mlp(h) 277 | h = F.leaky_relu(self.batchnorm(h)) 278 | return h 279 | 280 | 281 | class UVNetGraphEncoder(nn.Module): 282 | def __init__( 283 | self, 284 | input_dim, 285 | input_edge_dim, 286 | output_dim, 287 | hidden_dim=64, 288 | learn_eps=True, 289 | num_layers=3, 290 | num_mlp_layers=2, 291 | ): 292 | """ 293 | This is the graph neural network used for message-passing features in the 294 | face-adjacency graph. (see Section 3.2, Message passing in paper) 295 | 296 | Args: 297 | input_dim ([type]): [description] 298 | input_edge_dim ([type]): [description] 299 | output_dim ([type]): [description] 300 | hidden_dim (int, optional): [description]. Defaults to 64. 301 | learn_eps (bool, optional): [description]. Defaults to True. 302 | num_layers (int, optional): [description]. Defaults to 3. 303 | num_mlp_layers (int, optional): [description]. Defaults to 2. 
304 | """ 305 | super(UVNetGraphEncoder, self).__init__() 306 | self.num_layers = num_layers 307 | self.learn_eps = learn_eps 308 | 309 | # List of layers for node and edge feature message passing 310 | self.node_conv_layers = torch.nn.ModuleList() 311 | self.edge_conv_layers = torch.nn.ModuleList() 312 | 313 | for layer in range(self.num_layers - 1): 314 | node_feats = input_dim if layer == 0 else hidden_dim 315 | edge_feats = input_edge_dim if layer == 0 else hidden_dim 316 | self.node_conv_layers.append( 317 | _NodeConv( 318 | node_feats=node_feats, 319 | out_feats=hidden_dim, 320 | edge_feats=edge_feats, 321 | num_mlp_layers=num_mlp_layers, 322 | hidden_mlp_dim=hidden_dim, 323 | ), 324 | ) 325 | self.edge_conv_layers.append( 326 | _EdgeConv( 327 | edge_feats=edge_feats, 328 | out_feats=hidden_dim, 329 | node_feats=node_feats, 330 | num_mlp_layers=num_mlp_layers, 331 | hidden_mlp_dim=hidden_dim, 332 | ) 333 | ) 334 | 335 | # Linear function for graph poolings of output of each layer 336 | # which maps the output of different layers into a prediction score 337 | self.linears_prediction = torch.nn.ModuleList() 338 | 339 | for layer in range(num_layers): 340 | if layer == 0: 341 | self.linears_prediction.append(nn.Linear(input_dim, output_dim)) 342 | else: 343 | self.linears_prediction.append(nn.Linear(hidden_dim, output_dim)) 344 | 345 | self.drop1 = nn.Dropout(0.3) 346 | self.drop = nn.Dropout(0.5) 347 | self.pool = MaxPooling() 348 | 349 | def forward(self, g, h, efeat): 350 | hidden_rep = [h] 351 | he = efeat 352 | 353 | for i in range(self.num_layers - 1): 354 | # Update node features 355 | h = self.node_conv_layers[i](g, h, he) 356 | # Update edge features 357 | he = self.edge_conv_layers[i](g, h, he) 358 | hidden_rep.append(h) 359 | 360 | out = hidden_rep[-1] 361 | out = self.drop1(out) 362 | score_over_layer = 0 363 | 364 | # Perform pooling over all nodes in each graph in every layer 365 | for i, h in enumerate(hidden_rep): 366 | pooled_h = self.pool(g, h) 367 | score_over_layer += self.drop(self.linears_prediction[i](pooled_h)) 368 | 369 | return out, score_over_layer 370 | -------------------------------------------------------------------------------- /uvnet/models.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torchmetrics 3 | import torch 4 | from torch import nn 5 | import torch.nn.functional as F 6 | import uvnet.encoders 7 | 8 | 9 | class _NonLinearClassifier(nn.Module): 10 | def __init__(self, input_dim, num_classes, dropout=0.3): 11 | """ 12 | A 3-layer MLP with linear outputs 13 | 14 | Args: 15 | input_dim (int): Dimension of the input tensor 16 | num_classes (int): Dimension of the output logits 17 | dropout (float, optional): Dropout used after each linear layer. Defaults to 0.3. 
18 | """ 19 | super().__init__() 20 | self.linear1 = nn.Linear(input_dim, 512, bias=False) 21 | self.bn1 = nn.BatchNorm1d(512) 22 | self.dp1 = nn.Dropout(p=dropout) 23 | self.linear2 = nn.Linear(512, 256, bias=False) 24 | self.bn2 = nn.BatchNorm1d(256) 25 | self.dp2 = nn.Dropout(p=dropout) 26 | self.linear3 = nn.Linear(256, num_classes) 27 | 28 | for m in self.modules(): 29 | self.weights_init(m) 30 | 31 | def weights_init(self, m): 32 | if isinstance(m, nn.Linear): 33 | torch.nn.init.kaiming_uniform_(m.weight.data) 34 | if m.bias is not None: 35 | m.bias.data.fill_(0.0) 36 | 37 | def forward(self, inp): 38 | """ 39 | Forward pass 40 | 41 | Args: 42 | inp (torch.tensor): Inputs features to be mapped to logits 43 | (batch_size x input_dim) 44 | 45 | Returns: 46 | torch.tensor: Logits (batch_size x num_classes) 47 | """ 48 | x = F.relu(self.bn1(self.linear1(inp))) 49 | x = self.dp1(x) 50 | x = F.relu(self.bn2(self.linear2(x))) 51 | x = self.dp2(x) 52 | x = self.linear3(x) 53 | return x 54 | 55 | 56 | ############################################################################### 57 | # Classification model 58 | ############################################################################### 59 | 60 | 61 | class UVNetClassifier(nn.Module): 62 | """ 63 | UV-Net solid classification model 64 | """ 65 | 66 | def __init__( 67 | self, 68 | num_classes, 69 | crv_emb_dim=64, 70 | srf_emb_dim=64, 71 | graph_emb_dim=128, 72 | dropout=0.3, 73 | ): 74 | """ 75 | Initialize the UV-Net solid classification model 76 | 77 | Args: 78 | num_classes (int): Number of classes to output 79 | crv_emb_dim (int, optional): Embedding dimension for the 1D edge UV-grids. Defaults to 64. 80 | srf_emb_dim (int, optional): Embedding dimension for the 2D face UV-grids. Defaults to 64. 81 | graph_emb_dim (int, optional): Embedding dimension for the graph. Defaults to 128. 82 | dropout (float, optional): Dropout for the final non-linear classifier. Defaults to 0.3. 83 | """ 84 | super().__init__() 85 | self.curv_encoder = uvnet.encoders.UVNetCurveEncoder( 86 | in_channels=6, output_dims=crv_emb_dim 87 | ) 88 | self.surf_encoder = uvnet.encoders.UVNetSurfaceEncoder( 89 | in_channels=7, output_dims=srf_emb_dim 90 | ) 91 | self.graph_encoder = uvnet.encoders.UVNetGraphEncoder( 92 | srf_emb_dim, crv_emb_dim, graph_emb_dim, 93 | ) 94 | self.clf = _NonLinearClassifier(graph_emb_dim, num_classes, dropout) 95 | 96 | def forward(self, batched_graph): 97 | """ 98 | Forward pass 99 | 100 | Args: 101 | batched_graph (dgl.Graph): A batched DGL graph containing the face 2D UV-grids in node features 102 | (ndata['x']) and 1D edge UV-grids in the edge features (edata['x']). 103 | 104 | Returns: 105 | torch.tensor: Logits (batch_size x num_classes) 106 | """ 107 | # Input features 108 | input_crv_feat = batched_graph.edata["x"] 109 | input_srf_feat = batched_graph.ndata["x"] 110 | # Compute hidden edge and face features 111 | hidden_crv_feat = self.curv_encoder(input_crv_feat) 112 | hidden_srf_feat = self.surf_encoder(input_srf_feat) 113 | # Message pass and compute per-face(node) and global embeddings 114 | # Per-face embeddings are ignored during solid classification 115 | _, graph_emb = self.graph_encoder( 116 | batched_graph, hidden_srf_feat, hidden_crv_feat 117 | ) 118 | # Map to logits 119 | out = self.clf(graph_emb) 120 | return out 121 | 122 | 123 | class Classification(pl.LightningModule): 124 | """ 125 | PyTorch Lightning module to train/test the classifier. 
126 | """ 127 | 128 | def __init__(self, num_classes): 129 | """ 130 | Args: 131 | num_classes (int): Number of per-solid classes in the dataset 132 | """ 133 | super().__init__() 134 | self.save_hyperparameters() 135 | self.model = UVNetClassifier(num_classes=num_classes) 136 | self.train_acc = torchmetrics.Accuracy() 137 | self.val_acc = torchmetrics.Accuracy() 138 | self.test_acc = torchmetrics.Accuracy() 139 | 140 | def forward(self, batched_graph): 141 | logits = self.model(batched_graph) 142 | return logits 143 | 144 | def training_step(self, batch, batch_idx): 145 | inputs = batch["graph"].to(self.device) 146 | labels = batch["label"].to(self.device) 147 | inputs.ndata["x"] = inputs.ndata["x"].permute(0, 3, 1, 2) 148 | inputs.edata["x"] = inputs.edata["x"].permute(0, 2, 1) 149 | logits = self.model(inputs) 150 | loss = F.cross_entropy(logits, labels, reduction="mean") 151 | self.log("train_loss", loss, on_step=False, on_epoch=True, sync_dist=True) 152 | preds = F.softmax(logits, dim=-1) 153 | self.log("train_acc", self.train_acc(preds, labels), on_step=False, on_epoch=True, sync_dist=True) 154 | return loss 155 | 156 | def validation_step(self, batch, batch_idx): 157 | inputs = batch["graph"].to(self.device) 158 | labels = batch["label"].to(self.device) 159 | inputs.ndata["x"] = inputs.ndata["x"].permute(0, 3, 1, 2) 160 | inputs.edata["x"] = inputs.edata["x"].permute(0, 2, 1) 161 | logits = self.model(inputs) 162 | loss = F.cross_entropy(logits, labels, reduction="mean") 163 | self.log("val_loss", loss, on_step=False, on_epoch=True, sync_dist=True) 164 | preds = F.softmax(logits, dim=-1) 165 | self.log("val_acc", self.val_acc(preds, labels), on_step=False, on_epoch=True, sync_dist=True) 166 | return loss 167 | 168 | def test_step(self, batch, batch_idx): 169 | inputs = batch["graph"].to(self.device) 170 | labels = batch["label"].to(self.device) 171 | inputs.ndata["x"] = inputs.ndata["x"].permute(0, 3, 1, 2) 172 | inputs.edata["x"] = inputs.edata["x"].permute(0, 2, 1) 173 | logits = self.model(inputs) 174 | loss = F.cross_entropy(logits, labels, reduction="mean") 175 | self.log("test_loss", loss, on_step=False, on_epoch=True, sync_dist=True) 176 | preds = F.softmax(logits, dim=-1) 177 | self.log("test_acc", self.test_acc(preds, labels), on_step=False, on_epoch=True, sync_dist=True) 178 | 179 | def configure_optimizers(self): 180 | optimizer = torch.optim.Adam(self.parameters()) 181 | return optimizer 182 | 183 | 184 | ############################################################################### 185 | # Segmentation model 186 | ############################################################################### 187 | 188 | 189 | class UVNetSegmenter(nn.Module): 190 | """ 191 | UV-Net solid face segmentation model 192 | """ 193 | 194 | def __init__( 195 | self, 196 | num_classes, 197 | crv_in_channels=6, 198 | crv_emb_dim=64, 199 | srf_emb_dim=64, 200 | graph_emb_dim=128, 201 | dropout=0.3, 202 | ): 203 | """ 204 | Initialize the UV-Net solid face segmentation model 205 | 206 | Args: 207 | num_classes (int): Number of classes to output per-face 208 | crv_in_channels (int, optional): Number of input channels for the 1D edge UV-grids 209 | crv_emb_dim (int, optional): Embedding dimension for the 1D edge UV-grids. Defaults to 64. 210 | srf_emb_dim (int, optional): Embedding dimension for the 2D face UV-grids. Defaults to 64. 211 | graph_emb_dim (int, optional): Embedding dimension for the graph. Defaults to 128. 212 | dropout (float, optional): Dropout for the final non-linear classifier. 
Defaults to 0.3. 213 | """ 214 | super().__init__() 215 | # A 1D convolutional network to encode B-rep edge geometry represented as 1D UV-grids 216 | self.curv_encoder = uvnet.encoders.UVNetCurveEncoder( 217 | in_channels=crv_in_channels, output_dims=crv_emb_dim 218 | ) 219 | # A 2D convolutional network to encode B-rep face geometry represented as 2D UV-grids 220 | self.surf_encoder = uvnet.encoders.UVNetSurfaceEncoder( 221 | in_channels=7, output_dims=srf_emb_dim 222 | ) 223 | # A graph neural network that message passes face and edge features 224 | self.graph_encoder = uvnet.encoders.UVNetGraphEncoder( 225 | srf_emb_dim, crv_emb_dim, graph_emb_dim, 226 | ) 227 | # A non-linear classifier that maps face embeddings to face logits 228 | self.seg = _NonLinearClassifier( 229 | graph_emb_dim + srf_emb_dim, num_classes, dropout=dropout 230 | ) 231 | 232 | def forward(self, batched_graph): 233 | """ 234 | Forward pass 235 | 236 | Args: 237 | batched_graph (dgl.Graph): A batched DGL graph containing the face 2D UV-grids in node features 238 | (ndata['x']) and 1D edge UV-grids in the edge features (edata['x']). 239 | 240 | Returns: 241 | torch.tensor: Logits (total_nodes_in_batch x num_classes) 242 | """ 243 | # Input features 244 | input_crv_feat = batched_graph.edata["x"] 245 | input_srf_feat = batched_graph.ndata["x"] 246 | # Compute hidden edge and face features 247 | hidden_crv_feat = self.curv_encoder(input_crv_feat) 248 | hidden_srf_feat = self.surf_encoder(input_srf_feat) 249 | # Message pass and compute per-face(node) and global embeddings 250 | node_emb, graph_emb = self.graph_encoder( 251 | batched_graph, hidden_srf_feat, hidden_crv_feat 252 | ) 253 | # Repeat the global graph embedding so that it can be 254 | # concatenated to the per-node embeddings 255 | num_nodes_per_graph = batched_graph.batch_num_nodes().to(graph_emb.device) 256 | graph_emb = graph_emb.repeat_interleave(num_nodes_per_graph, dim=0).to(graph_emb.device) 257 | local_global_feat = torch.cat((node_emb, graph_emb), dim=1) 258 | # Map to logits 259 | out = self.seg(local_global_feat) 260 | return out 261 | 262 | 263 | class Segmentation(pl.LightningModule): 264 | """ 265 | PyTorch Lightning module to train/test the segmenter (per-face classifier). 
266 | """ 267 | 268 | def __init__(self, num_classes, crv_in_channels=6): 269 | """ 270 | Args: 271 | num_classes (int): Number of per-face classes in the dataset 272 | """ 273 | super().__init__() 274 | self.save_hyperparameters() 275 | self.model = UVNetSegmenter(num_classes, crv_in_channels=crv_in_channels) 276 | # Setting compute_on_step = False to compute "part IoU" 277 | # This is because we want to compute the IoU on the entire dataset 278 | # at the end to account for rare classes, rather than within each batch 279 | self.train_iou = torchmetrics.IoU( 280 | num_classes=num_classes, compute_on_step=False 281 | ) 282 | self.val_iou = torchmetrics.IoU(num_classes=num_classes, compute_on_step=False) 283 | self.test_iou = torchmetrics.IoU(num_classes=num_classes, compute_on_step=False) 284 | 285 | self.train_accuracy = torchmetrics.Accuracy( 286 | num_classes=num_classes, compute_on_step=False 287 | ) 288 | self.val_accuracy = torchmetrics.Accuracy( 289 | num_classes=num_classes, compute_on_step=False 290 | ) 291 | self.test_accuracy = torchmetrics.Accuracy( 292 | num_classes=num_classes, compute_on_step=False 293 | ) 294 | 295 | def forward(self, batched_graph): 296 | logits = self.model(batched_graph) 297 | return logits 298 | 299 | def training_step(self, batch, batch_idx): 300 | inputs = batch["graph"].to(self.device) 301 | inputs.ndata["x"] = inputs.ndata["x"].permute(0, 3, 1, 2) 302 | inputs.edata["x"] = inputs.edata["x"].permute(0, 2, 1) 303 | labels = inputs.ndata["y"] 304 | logits = self.model(inputs) 305 | loss = F.cross_entropy(logits, labels, reduction="mean") 306 | self.log("train_loss", loss, on_step=False, on_epoch=True, sync_dist=True) 307 | preds = F.softmax(logits, dim=-1) 308 | self.train_iou(preds, labels) 309 | self.train_accuracy(preds, labels) 310 | return loss 311 | 312 | def training_epoch_end(self, outs): 313 | self.log("train_iou", self.train_iou.compute()) 314 | self.log("train_accuracy", self.train_accuracy.compute()) 315 | 316 | def validation_step(self, batch, batch_idx): 317 | inputs = batch["graph"].to(self.device) 318 | inputs.ndata["x"] = inputs.ndata["x"].permute(0, 3, 1, 2) 319 | inputs.edata["x"] = inputs.edata["x"].permute(0, 2, 1) 320 | labels = inputs.ndata["y"] 321 | logits = self.model(inputs) 322 | loss = F.cross_entropy(logits, labels, reduction="mean") 323 | self.log("val_loss", loss, on_step=False, on_epoch=True, sync_dist=True) 324 | preds = F.softmax(logits, dim=-1) 325 | self.val_iou(preds, labels) 326 | self.val_accuracy(preds, labels) 327 | return loss 328 | 329 | def validation_epoch_end(self, outs): 330 | self.log("val_iou", self.val_iou.compute()) 331 | self.log("val_accuracy", self.val_accuracy.compute()) 332 | 333 | def test_step(self, batch, batch_idx): 334 | inputs = batch["graph"].to(self.device) 335 | inputs.ndata["x"] = inputs.ndata["x"].permute(0, 3, 1, 2) 336 | inputs.edata["x"] = inputs.edata["x"].permute(0, 2, 1) 337 | labels = inputs.ndata["y"] 338 | logits = self.model(inputs) 339 | loss = F.cross_entropy(logits, labels, reduction="mean") 340 | self.log("test_loss", loss, on_step=False, on_epoch=True, sync_dist=True) 341 | preds = F.softmax(logits, dim=-1) 342 | self.test_iou(preds, labels) 343 | self.test_accuracy(preds, labels) 344 | 345 | def test_epoch_end(self, outs): 346 | self.log("test_iou", self.test_iou.compute()) 347 | self.log("test_accuracy", self.test_accuracy.compute()) 348 | 349 | def configure_optimizers(self): 350 | optimizer = torch.optim.Adam(self.parameters()) 351 | return optimizer 352 | 
--------------------------------------------------------------------------------