├── utils
│   ├── __init__.py
│   ├── geometry
│   │   ├── curve.py
│   │   ├── line.py
│   │   ├── circle.py
│   │   ├── arc.py
│   │   ├── geom_utils.py
│   │   ├── obj_utils.py
│   │   └── obj_parser.py
│   ├── parse_obj2seq.py
│   ├── data.py
│   ├── sample_points.py
│   ├── cad_img.py
│   ├── visual_obj.py
│   ├── convert.py
│   ├── eval_dclip.py
│   ├── eval_cad.py
│   ├── directional_clip_score.py
│   ├── obj_reconverter.py
│   └── parse_seq2obj.py
├── data
│   ├── processed.zip
│   ├── merge.py
│   ├── filter_sequence.py
│   ├── pair.py
│   ├── caption_sequence.py
│   └── caption_image.py
├── CODE_OF_CONDUCT.md
├── finetune
│   ├── ds_config.yaml
│   ├── create_mask.py
│   ├── llama_sample.py
│   └── llama_finetune.py
├── environment.yaml
├── LICENSE
├── SUPPORT.md
├── SECURITY.md
├── .gitignore
├── prompt
│   ├── gpt4_basic.py
│   └── gpt4_fs3.py
├── hnc-cad
│   └── ac_gen.py
└── README.md
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/processed.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/CAD-Editor/HEAD/data/processed.zip
--------------------------------------------------------------------------------
/utils/geometry/curve.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 |
4 | class Curve():
5 |     def __init__(self, point_indices, point_data):
6 |         self.point_indices = point_indices
7 |         self.point_geom = point_data[point_indices, 0:2]
8 |
9 |     def verts_to_bbox(self, verts):
10 |         xs = [v[0] for v in verts]
11 |         ys = [v[1] for v in verts]
12 |         bbox = [min(xs), max(xs), min(ys), max(ys)]
13 |         return bbox
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /finetune/ds_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | gradient_accumulation_steps: 1 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: none 7 | offload_param_device: none 8 | zero3_init_flag: true 9 | zero_stage: 2 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: fp16 15 | num_machines: 1 16 | num_processes: 4 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false -------------------------------------------------------------------------------- /utils/geometry/line.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from geometry.curve import Curve 3 | 4 | class Line(Curve): 5 | def __init__(self, point_indices, point_data, is_outer): 6 | assert len(point_indices) == 2, "Line must be defined by two points" 7 | assert point_data is not None 8 | super(Line, self).__init__(point_indices, point_data) 9 | pt0 = self.point_geom[0, :] 10 | pt1 = self.point_geom[1, :] 11 | self.type = 'line' 12 | self.start = pt0 13 | self.end = pt1 14 | self.start_idx = point_indices[0] 15 | self.end_idx = point_indices[1] 16 | self.is_outer = is_outer 17 | 18 | self.bbox = self.verts_to_bbox(np.vstack([pt0, pt1])) 19 | self.bottom_left = np.array([self.bbox[0], self.bbox[2]]) 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /utils/parse_obj2seq.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from data import SE 4 | import pickle 5 | 6 | NUM_TRHEADS = 36 7 | NUM_FOLDERS = 100 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--input", type=str, required=True, help="Input folder of the CAD obj (after normalization)") 12 | parser.add_argument("--bit", type=int, required=True, help='Number of bits for quantization') 13 | parser.add_argument("--output", type=str, required=True, help="Output file path to save the data") 14 | args = parser.parse_args() 15 | 16 | # Create output directory if it doesn't exist 17 | output_dir = os.path.dirname(args.output) 18 | if output_dir and not os.path.exists(output_dir): 19 | os.makedirs(output_dir, exist_ok=True) 20 | 21 | # Start creating dataset 22 | parser = SE(start=0, end=NUM_FOLDERS, datapath=args.input, bit=args.bit, threads=NUM_TRHEADS) # number of threads in your pc 23 | train_samples = parser.load_all_obj() 24 | 25 | # Save to file 26 | with open(args.output, "wb") as tf: 27 | pickle.dump(train_samples, tf) -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: cad-editor 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | 7 | dependencies: 8 | - python=3.10 9 | - numpy=1.26.4 10 | - scipy=1.13.1 
11 | - opencv
12 | - pythonocc-core=7.7.2
13 | - pip
14 |
15 | - pip:
16 |   - accelerate==0.28.0
17 |   - azure-identity==1.21.0
18 |   - git+https://github.com/openai/CLIP.git
19 |   - git+https://github.com/otaheri/chamfer_distance@dc9987dcf70888d387d96893ba1fb9ba9a333992
20 |   - datasets==2.18.0
21 |   - deepspeed==0.15.4
22 |   - diffusers==0.20.0
23 |   - gitdb==4.0.11
24 |   - gitpython==3.1.42
25 |   - huggingface-hub==0.30.0
26 |   - multiprocess==0.70.16
27 |   - ninja==1.11.1.1
28 |   - openai==1.59.7
29 |   - peft==0.9.0
30 |   - pillow==10.3.0
31 |   - plyfile==1.1
32 |   - pytz==2024.1
33 |   - requests==2.32.4
34 |   - safetensors==0.4.3
35 |   - torch==2.5.1
36 |   - torchvision==0.20.1
37 |   - torchaudio==2.5.1
38 |   - tensorboard==2.16.2
39 |   - tqdm==4.66.3
40 |   - transformers==4.52.1
41 |   - trimesh==4.6.0
42 |   - wandb==0.16.4
43 |
--------------------------------------------------------------------------------
/data/merge.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 |
4 | def merge_json_files(args):
5 |     """
6 |     Merge two JSON files
7 |     """
8 |     with open(args.file1, 'r', encoding='utf-8') as f:
9 |         data1 = json.load(f)
10 |
11 |     with open(args.file2, 'r', encoding='utf-8') as f:
12 |         data2 = json.load(f)
13 |
14 |     merged_data = data1 + data2
15 |
16 |     with open(args.output, 'w', encoding='utf-8') as f:
17 |         json.dump(merged_data, f, ensure_ascii=False, indent=2)
18 |
19 |     print(f"Merge completed!")
20 |     print(f"File 1: {len(data1)} items")
21 |     print(f"File 2: {len(data2)} items")
22 |     print(f"Merged: {len(merged_data)} items")
23 |
24 | if __name__ == "__main__":
25 |     parser = argparse.ArgumentParser(description="Merge two JSON files")
26 |     parser.add_argument("--file1", type=str, required=True, help="Path to first JSON file")
27 |     parser.add_argument("--file2", type=str, required=True, help="Path to second JSON file")
28 |     parser.add_argument("--output", type=str, required=True, help="Output file path")
29 |
30 |     args = parser.parse_args()
31 |     merge_json_files(args)
--------------------------------------------------------------------------------
/utils/geometry/circle.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from geometry.curve import Curve
3 | import pdb
4 |
5 | class Circle(Curve):
6 |     def __init__(self, point_indices, point_data, is_outer):
7 |         assert len(point_indices) == 2, "Circle must be defined by 2 point indices (center and radius)"
8 |         assert point_data is not None
9 |         super(Circle, self).__init__(point_indices, point_data)
10 |         self.type = 'circle'
11 |         self.center = self.point_geom[0, :]
12 |         self.radius = self.point_geom[1, 0]
13 |         self.center_idx = point_indices[0]
14 |         self.radius_idx = point_indices[1]
15 |         self.is_outer = is_outer
16 |
17 |         self.pt1 = np.array([self.center[0], self.center[1]+self.radius])
18 |         self.pt2 = np.array([self.center[0], self.center[1]-self.radius])
19 |         self.pt3 = np.array([self.center[0]+self.radius, self.center[1]])
20 |         self.pt4 = np.array([self.center[0]-self.radius, self.center[1]])
21 |         self.bbox = self.verts_to_bbox(np.vstack([self.pt1, self.pt2, self.pt3, self.pt4]))
22 |         self.bottom_left = np.array([self.bbox[0], self.bbox[2]])
23 |
24 |
25 |
26 |
27 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/utils/geometry/arc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | from geometry.curve import Curve
4 |
5 |
6 | class Arc(Curve):
7 |     def __init__(self, point_indices, point_data, is_outer):
8 |         assert len(point_indices) == 4, "Arc must be defined by 4 point indices (start, mid, center, end)"
9 |         assert point_data is not None
10 |         super(Arc, self).__init__(point_indices, point_data)
11 |         self.type = 'arc'
12 |         self.is_outer = is_outer
13 |         self.start = self.point_geom[0, :]
14 |         self.mid = self.point_geom[1, :]
15 |         self.center = self.point_geom[2, :]
16 |         self.end = self.point_geom[3, :]
17 |
18 |         self.r1 = math.sqrt( (self.start[0] - self.center[0])**2 + (self.start[1] - self.center[1])**2 )
19 |         self.r2 = math.sqrt( (self.end[0] - self.center[0])**2 + (self.end[1] - self.center[1])**2 )
20 |         self.radius = (self.r1+self.r2)/2
21 |
22 |         self.start_idx = point_indices[0]
23 |         self.mid_idx = point_indices[1]
24 |         self.center_idx = point_indices[2]
25 |         self.end_idx = point_indices[3]
26 |
27 |         self.bbox = self.verts_to_bbox(np.vstack([self.start, self.end, self.mid]))
28 |         self.bottom_left = np.array([self.bbox[0], self.bbox[2]])
29 |
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # TODO: The maintainer of this repo has not yet edited this file
2 |
3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
4 |
5 | - **No CSS support:** Fill out this template with information about how to file issues and get help.
6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps.
7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide.
8 |
9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
10 |
11 | # Support
12 |
13 | ## How to file issues and get help
14 |
15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
17 | feature request as a new Issue.
18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /data/filter_sequence.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | 4 | def filter(args): 5 | with open(args.in_path, 'r') as file: 6 | data = json.load(file) 7 | 8 | filtered_items = [] 9 | 10 | for item in data: 11 | original_sequence = item['original_sequence'] 12 | edited_sequence = item['edited_sequence'] 13 | 14 | original_extrude_count = original_sequence.count('') 15 | 16 | edited_extrude_count = edited_sequence.count('') 17 | 18 | # Check if an instruction contains 3 or more edits 19 | instruction = item['instruction'].lower() 20 | punctuation_count = instruction.count(',') + instruction.count(';') 21 | 22 | wordlist = ["no transformation", "are identical"] 23 | # Check if any word in wordlist is in the instruction 24 | contains_wordlist = any(word in instruction for word in wordlist) 25 | 26 | if (original_extrude_count <= 3 and edited_extrude_count <= 3 and 27 | punctuation_count <= 2 and not contains_wordlist): 28 | filtered_items.append(item) 29 | 30 | with open(args.out_path, 'w') as file: 31 | json.dump(filtered_items, file, indent=4) 32 | 33 | 34 | if __name__ == "__main__": 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument("--in_path", type=str, required=True) 37 | parser.add_argument("--out_path", type=str, required=True) 38 | args = parser.parse_args() 39 | 40 | filter(args) 41 | -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | from util import process_obj_se 2 | from tqdm import tqdm 3 | from multiprocessing import Pool 4 | import json 5 | from pathlib import Path 6 | from glob import glob 7 | import itertools 8 | 9 | 10 | class SE(): 11 | """ sketch-extrude dataset """ 12 | def __init__(self, start, end, datapath, bit, threads=16): 13 | self.start = start 14 | self.end = end 15 | self.datapath = datapath 16 | self.threads = threads 17 | self.bit = bit 18 | 19 | def load_all_obj(self): 20 | print("Loading obj data...") 21 | 22 | # with open('../data/train_val_test_split.json') as f: 23 | # data_split = json.load(f) 24 | 25 | project_folders = [] 26 | cur_dir = Path(self.datapath) 27 | # print(cur_dir) 28 | project_folders += glob(str(cur_dir)+'/*/') 29 | print(project_folders) 30 | # Parallel loader 31 | iter_data = zip( 32 | project_folders, 33 | itertools.repeat(self.bit), 34 | ) 35 | samples = [] 36 | load_iter = Pool(self.threads).imap(process_obj_se, iter_data) 37 | for data_sample in tqdm(load_iter, total=len(project_folders)): 38 | samples += data_sample 39 | 40 | print('Splitting data...') 41 | train_samples = [] 42 | 43 | for data in tqdm(samples): 44 | train_samples.append(data) # put into training if no match 45 | 46 | print(f"Data Summary") 47 | print(f"\tTraining data: {len(train_samples)}") 48 | # print(f"\tValidation data: {len(val_samples)}") 49 | # print(f"\tTest data: {len(test_samples)}") 50 | return train_samples 51 | 52 | 
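# ---------------------------------------------------------------------------
# Editor's note (hedged sketch, not part of the original repository): the
# pickle written by utils/parse_obj2seq.py via SE.load_all_obj() is a list of
# per-model dicts produced by util.process_obj_se, which is not shown in this
# dump. The field names below are inferred from how utils/convert.py consumes
# each sample; the example values are hypothetical.
#
#   sample = {
#       "name": "0000/00000007",   # model identifier (hypothetical value)
#       "num_se": 2,               # number of sketch-extrude pairs
#       "se_xy": [...],            # per-pair quantized point coordinates (offset by COORD_PAD)
#       "se_cmd": [...],           # per-pair curve commands (offset by CMD_PAD)
#       "se_ext": [...],           # per-pair 18-parameter extrude vectors
#   }
#
# utils/convert.py subtracts COORD_PAD / CMD_PAD before serializing each
# sample into the textual sketch-and-extrude sequence used downstream for
# captioning and fine-tuning.
# ---------------------------------------------------------------------------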
-------------------------------------------------------------------------------- /data/pair.py: -------------------------------------------------------------------------------- 1 | import json 2 | from collections import defaultdict 3 | from itertools import combinations 4 | import argparse 5 | 6 | def pair_cad(args): 7 | with open(args.in_path, "r", encoding="utf-8") as f: 8 | data = json.load(f) 9 | 10 | # Group items by the first 8 characters of their name 11 | groups = defaultdict(list) 12 | for item in data: 13 | key = item["name"][:8] 14 | groups[key].append(item) 15 | 16 | result = [] 17 | 18 | for key, items in groups.items(): 19 | if len(items) >= 2: 20 | temp_result = [] 21 | 22 | for item1, item2 in combinations(items, 2): 23 | # Determine the type of change 24 | def determine_type(original_name, edited_name): 25 | if "origInput" in original_name: 26 | return "add" 27 | elif "origInput" in edited_name: 28 | return "delete" 29 | else: 30 | return "modify" 31 | 32 | type1 = determine_type(item1["name"], item2["name"]) 33 | type2 = determine_type(item2["name"], item1["name"]) 34 | 35 | # Forward combination 36 | temp_result.append({ 37 | "original_pic_name": item1["name"], 38 | "edited_pic_name": item2["name"], 39 | "original_sequence": item1["original_sequence"], 40 | "edited_sequence": item2["original_sequence"], 41 | "type": type1 42 | }) 43 | 44 | # Reverse combination 45 | temp_result.append({ 46 | "original_pic_name": item2["name"], 47 | "edited_pic_name": item1["name"], 48 | "original_sequence": item2["original_sequence"], 49 | "edited_sequence": item1["original_sequence"], 50 | "type": type2 51 | }) 52 | 53 | # Stop if 56 results have been generated for this group 54 | if len(temp_result) >= 56: 55 | break 56 | 57 | result.extend(temp_result) 58 | 59 | with open(args.out_path, "w", encoding="utf-8") as f: 60 | json.dump(result, f, ensure_ascii=False, indent=4) 61 | 62 | if __name__ == "__main__": 63 | parser = argparse.ArgumentParser() 64 | parser.add_argument("--in_path", type=str, required=True, help="Path to the input JSON file") 65 | parser.add_argument("--out_path", type=str, required=True, help="Path to the output JSON file") 66 | args = parser.parse_args() 67 | 68 | pair_cad(args) 69 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). 
If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /utils/sample_points.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import ntpath 4 | from tqdm import tqdm 5 | from multiprocessing import Pool 6 | from pathlib import Path 7 | from glob import glob 8 | import trimesh 9 | from trimesh.sample import sample_surface 10 | from plyfile import PlyData, PlyElement 11 | import numpy as np 12 | 13 | NUM_TRHEADS = 36 14 | 15 | def write_ply(points, filename, text=False): 16 | """ input: Nx3, write points to filename as PLY format. """ 17 | points = [(points[i,0], points[i,1], points[i,2]) for i in range(points.shape[0])] 18 | vertex = np.array(points, dtype=[('x', 'f4'), ('y', 'f4'),('z', 'f4')]) 19 | el = PlyElement.describe(vertex, 'vertex', comments=['vertices']) 20 | with open(filename, mode='wb') as f: 21 | PlyData([el], text=text).write(f) 22 | 23 | 24 | def find_files(folder, extension): 25 | return sorted([Path(os.path.join(folder, f)) for f in os.listdir(folder) if f.endswith(extension)]) 26 | 27 | class SamplePoints: 28 | """ 29 | Perform sampleing of points. 30 | """ 31 | 32 | def __init__(self): 33 | """ 34 | Constructor. 35 | """ 36 | parser = self.get_parser() 37 | self.options = parser.parse_args() 38 | 39 | 40 | def get_parser(self): 41 | """ 42 | Get parser of tool. 
43 | 44 | :return: parser 45 | """ 46 | parser = argparse.ArgumentParser(description='Scale a set of meshes stored as OFF files.') 47 | parser.add_argument('--in_dir', type=str, help='Path to input directory.') 48 | parser.add_argument('--out_dir', type=str, help='Path to output directory; files within are overwritten!') 49 | return parser 50 | 51 | 52 | def run_parallel(self, project_folder): 53 | out_folder = os.path.join(project_folder, self.options.out_dir) 54 | if not os.path.exists(out_folder): 55 | os.makedirs(out_folder) 56 | 57 | files = find_files(project_folder, 'final.stl') 58 | 59 | for filepath in files: 60 | N_POINTS = 2000 61 | try: 62 | out_mesh = trimesh.load(str(filepath)) 63 | out_pc, _ = sample_surface(out_mesh, N_POINTS) 64 | save_path = os.path.join(out_folder, ntpath.basename(filepath)[:-4]+'_pcd.ply') 65 | write_ply(out_pc, save_path) 66 | 67 | except Exception as ex: 68 | return project_folder 69 | return 70 | 71 | 72 | def run(self): 73 | """ 74 | Run simplification. 75 | """ 76 | project_folders = sorted(glob(self.options.in_dir+'/*/')) 77 | convert_iter = Pool(NUM_TRHEADS).imap(self.run_parallel, project_folders) 78 | for _ in tqdm(convert_iter, total=len(project_folders)): 79 | pass 80 | 81 | 82 | if __name__ == '__main__': 83 | app = SamplePoints() 84 | app.run() 85 | -------------------------------------------------------------------------------- /utils/geometry/geom_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | def angle_from_vector_to_x(vec): 5 | assert vec.size == 2 6 | # We need to find a unit vector 7 | angle = 0.0 8 | 9 | l = np.linalg.norm(vec) 10 | uvec = vec/l 11 | 12 | # 2 | 1 13 | #------- 14 | # 3 | 4 15 | if uvec[0] >=0: 16 | if uvec[1] >= 0: 17 | # Qadrant 1 18 | angle = math.asin(uvec[1]) 19 | else: 20 | # Qadrant 4 21 | angle = 2.0*math.pi - math.asin(-uvec[1]) 22 | else: 23 | if vec[1] >= 0: 24 | # Qadrant 2 25 | angle = math.pi - math.asin(uvec[1]) 26 | else: 27 | # Qadrant 3 28 | angle = math.pi + math.asin(-uvec[1]) 29 | return angle 30 | 31 | 32 | def convert_angle_to_1to360_range(angle_rad): 33 | """ 34 | Converts the given angle in radians into 1-360 degrees range 35 | """ 36 | angle = math.degrees(angle_rad) 37 | # Lifted from: https://stackoverflow.com/questions/12234574/calculating-if-an-angle-is-between-two-angles 38 | angle=(int(angle) % 360) + (angle-math.trunc(angle)) # converts angle to range -360 + 360 39 | if angle > 0.0: 40 | return angle 41 | else: 42 | return angle + 360.0 43 | 44 | 45 | def angle_is_between(angle_rad, a_rad, b_rad): 46 | """ 47 | Checks if angle is in between the range of a and b 48 | (All angles must be given in radians) 49 | """ 50 | angle = convert_angle_to_1to360_range(angle_rad) 51 | a = convert_angle_to_1to360_range(a_rad) 52 | b = convert_angle_to_1to360_range(b_rad) 53 | if a < b: 54 | return a <= angle and angle <= b 55 | return a <= angle or angle <= b 56 | 57 | 58 | def quantize_verts(verts, n_bits=8): 59 | """Convert vertices in [-1., 1.] 
to discrete values in [0, n_bits**2 - 1].""" 60 | min_range = -0.5 61 | max_range = 0.5 62 | range_quantize = 2 ** n_bits - 1 63 | verts_quantize = (verts - min_range) * range_quantize / (max_range - min_range) 64 | return verts_quantize.astype("int32") 65 | 66 | 67 | def dequantize_verts(verts, n_bits=8, add_noise=False): 68 | """Convert quantized vertices to floats.""" 69 | min_range = -0.5 70 | max_range = 0.5 71 | range_quantize = 2 ** n_bits - 1 72 | verts = verts.astype("float32") 73 | verts = verts * (max_range - min_range) / range_quantize + min_range 74 | if add_noise: 75 | verts += np.random.uniform(size=verts.shape) * (1 / range_quantize) 76 | return verts 77 | 78 | 79 | def center_vertices(vertices): 80 | """Translate the vertices so that bounding box is centered at zero.""" 81 | vert_min = vertices.min(axis=0) 82 | vert_max = vertices.max(axis=0) 83 | vert_center = 0.5 * (vert_min + vert_max) 84 | return vertices - vert_center, vert_center 85 | 86 | 87 | def scale_vertices(vertices): 88 | """Scale the vertices so that the long diagonal of the bounding box is one.""" 89 | vert_min = vertices.min(axis=0) 90 | vert_max = vertices.max(axis=0) 91 | extents = vert_max - vert_min 92 | scale = np.sqrt(np.sum(extents ** 2)) 93 | return vertices / scale, scale 94 | 95 | 96 | -------------------------------------------------------------------------------- /utils/geometry/obj_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | from pathlib import Path 4 | import pdb 5 | 6 | 7 | def read_wire_obj(obj_path): 8 | """Read vertices and lines from .obj file defining a wire body.""" 9 | vertex_list = [] 10 | loops = [] 11 | 12 | # Read vertice and curves 13 | with open(obj_path) as obj_file: 14 | 15 | for line in obj_file: 16 | tokens = line.split() 17 | if not tokens: 18 | continue 19 | 20 | line_type = tokens[0] 21 | 22 | if line_type == "v": 23 | vertex_list.append([float(x) for x in tokens[1:]]) 24 | 25 | if line_type == "g": 26 | pdb.set_trace() 27 | 28 | 29 | 30 | 31 | # Read meta data 32 | meta_data = line.strip('# ').strip(' \n').split(' ') 33 | meta_name = meta_data[0] 34 | if meta_name == 'Extrude': 35 | extrude_values= [float(x) for x in meta_data[1:]] 36 | elif meta_name == 'T_origin': 37 | t_orig = [float(x) for x in meta_data[1:]] 38 | elif meta_name == 'T_xaxis': 39 | t_x = [float(x) for x in meta_data[1:]] 40 | elif meta_name == 'T_yaxis': 41 | t_y = [float(x) for x in meta_data[1:]] 42 | elif meta_name == 'T_zaxis': 43 | t_z = [float(x) for x in meta_data[1:]] 44 | elif meta_name == 'ExtrudeOperation:': 45 | set_op = meta_data[1] 46 | 47 | 48 | vertices = np.array(vertex_list) 49 | 50 | 51 | 52 | meta_info = {'extrude_value': extrude_values, 53 | 'set_op': set_op, 54 | 't_orig': t_orig, 55 | 't_x': t_x, 56 | 't_y': t_y, 57 | 't_z': t_z} 58 | 59 | total_in_outs.append(in_outs) 60 | 61 | return np.array(flat_vertices_list, dtype=np.float32), flat_hyperedge, total_in_outs, meta_info 62 | 63 | 64 | def write_wire_obj(vertices, faces, file_path, transpose=True, scale=1.0): 65 | """Write vertices and hyperedges to obj.""" 66 | vertex_dimension = vertices.shape[1] 67 | assert vertex_dimension in (2, 3) 68 | if transpose and vertex_dimension == 3: 69 | # Permute 3D vertices where z comes first followed by x and y 70 | vertices = vertices[:, [1, 2, 0]] 71 | vertices *= scale 72 | if faces is not None: 73 | if len(faces) > 0: 74 | if min(min(faces)) == 0: 75 | f_add = 1 76 | else: 77 | f_add = 0 78 | with 
open(file_path, "w") as f: 79 | for v in vertices: 80 | if vertex_dimension == 2: 81 | f.write("v {} {} {}\n".format(v[0], v[1], 0.0)) 82 | else: 83 | f.write("v {} {} {}\n".format(v[0], v[1], v[2])) 84 | for face in faces: 85 | line = "l" 86 | for i in face: 87 | # Pradeep: always adding 1 to the face index makes sense to me. Not sure why 88 | # PolyGen does this conditionally (see L95 above) 89 | # Something to note. 90 | line += " {}".format(i + 1) 91 | line += "\n" 92 | f.write(line) 93 | 94 | -------------------------------------------------------------------------------- /utils/cad_img.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from OCC.Core.Graphic3d import * 3 | from OCC.Display.OCCViewer import Viewer3d 4 | from OCC.Extend.DataExchange import read_step_file 5 | from OCC.Extend.TopologyUtils import TopologyExplorer 6 | from OCC.Core.Quantity import Quantity_Color, Quantity_TOC_RGB, Quantity_NOC_WHITE 7 | from OCC.Core.V3d import V3d_DirectionalLight 8 | from OCC.Core.gp import gp_Dir 9 | from glob import glob 10 | import pathlib 11 | from tqdm import tqdm 12 | 13 | 14 | def render(shape, filename, width=1024, height=768, face_color_rgb=(0.2, 0.2, 0.2), edge_color_rgb=(0, 0, 0), show_face_boundary=True): 15 | viewer = Viewer3d() 16 | viewer.Create(phong_shading=True, create_default_lights=True) 17 | viewer.set_bg_gradient_color([255, 255, 255], [255, 255, 255]) 18 | viewer.SetModeShaded() 19 | viewer.hide_triedron() 20 | viewer.EnableAntiAliasing() 21 | dir_light = V3d_DirectionalLight(gp_Dir(0, 0.5, -1), Quantity_Color(Quantity_NOC_WHITE)) 22 | dir_light.SetEnabled(True) 23 | dir_light.SetIntensity(500.0) 24 | viewer.Viewer.AddLight(dir_light) 25 | viewer.Viewer.SetLightOn() 26 | 27 | viewer.default_drawer.EnableDrawHiddenLine() 28 | viewer.default_drawer.SetFaceBoundaryDraw(show_face_boundary) 29 | ais_context = viewer.GetContext() 30 | dc = ais_context.DeviationCoefficient() 31 | da = ais_context.DeviationAngle() 32 | factor = 10 33 | ais_context.SetDeviationCoefficient(dc / factor) 34 | ais_context.SetDeviationAngle(da / factor) 35 | topexp = TopologyExplorer(shape) 36 | for face in topexp.faces(): 37 | if face is not None: 38 | viewer.DisplayShape(face, color=Quantity_Color(*face_color_rgb, Quantity_TOC_RGB)) 39 | for edge in topexp.edges(): 40 | if edge is not None: 41 | viewer.DisplayShape(edge, color=Quantity_Color(*edge_color_rgb, Quantity_TOC_RGB)) 42 | viewer.FitAll() 43 | # Set complementary viewing angle: view model from bottom-left-rear 44 | # viewer.View.SetProj(-1, -1, -1) 45 | viewer.SetSize(width, height) 46 | viewer.View.Dump(str(filename)) 47 | 48 | 49 | def main(): 50 | p = argparse.ArgumentParser() 51 | p.add_argument("--input_dir", type=str, required=True, help="Input folder of STP/STEP files") 52 | p.add_argument("--output_dir", type=str, required=True, help="Output folder of PNG files") 53 | p.add_argument("--width", type=int, default=1024, help="Width of image") 54 | p.add_argument("--height", type=int, default=768, help="Height of image") 55 | 56 | args = p.parse_args() 57 | 58 | files = [] 59 | cad_folders = sorted(glob(args.input_dir+'/*/')) 60 | for folder in cad_folders: 61 | input_path = pathlib.Path(folder) 62 | files += list(input_path.glob("*.st*p")) 63 | print(len(files)) 64 | output_path = pathlib.Path(args.output_dir) 65 | if not output_path.exists(): 66 | output_path.mkdir(parents=True, exist_ok=True) 67 | 68 | i = 0 69 | j = 0 70 | for fn in tqdm(files): 71 | j += 1 72 | shape = 
read_step_file(str(fn)) 73 | try: 74 | render(shape, output_path.joinpath(fn.stem + ".png"), args.width, args.height) 75 | except: 76 | i += 1 77 | print("error") 78 | continue 79 | print("error number: ", i) 80 | print("total number: ", j) 81 | 82 | if __name__ == "__main__": 83 | main() 84 | -------------------------------------------------------------------------------- /utils/visual_obj.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from pathlib import Path 4 | from tqdm import tqdm 5 | from multiprocessing import Pool 6 | from glob import glob 7 | from obj_reconverter import OBJReconverter 8 | from OCC.Core.BRepCheck import BRepCheck_Analyzer 9 | from geometry.obj_parser import OBJParser 10 | from util import write_stl_file 11 | from OCC.Extend.DataExchange import write_step_file 12 | 13 | import signal 14 | from contextlib import contextmanager 15 | @contextmanager 16 | def timeout(time): 17 | # Register a function to raise a TimeoutError on the signal. 18 | signal.signal(signal.SIGALRM, raise_timeout) 19 | # Schedule the signal to be sent after ``time``. 20 | signal.alarm(time) 21 | try: 22 | yield 23 | except TimeoutError: 24 | raise Exception("time out") 25 | finally: 26 | # Unregister the signal so it won't be triggered 27 | # if the timeout is not reached. 28 | signal.signal(signal.SIGALRM, signal.SIG_IGN) 29 | def raise_timeout(signum, frame): 30 | raise TimeoutError 31 | 32 | NUM_TRHEADS = 36 33 | 34 | def find_files(folder, extension): 35 | return sorted([Path(os.path.join(folder, f)) for f in os.listdir(folder) if f.endswith(extension)]) 36 | 37 | 38 | def run_parallel(project_folder): 39 | output_folder = project_folder 40 | 41 | param_objs = find_files(project_folder, 'param.obj') 42 | 43 | cur_solid = None 44 | extrude_idx = 0 45 | for obj in param_objs: 46 | try: 47 | with timeout(30): 48 | parser = OBJParser(obj) 49 | _, faces, meta_info = parser.parse_file(1.0) 50 | converter = OBJReconverter() 51 | ext_solid, _, _ = converter.parse_obj(faces, meta_info) 52 | set_op = meta_info["set_op"] 53 | if set_op == "NewBodyFeatureOperation" or set_op == "JoinFeatureOperation": 54 | if cur_solid is None: 55 | cur_solid = ext_solid 56 | else: 57 | cur_solid = converter.my_op(cur_solid, ext_solid, 'fuse') 58 | elif set_op == "CutFeatureOperation": 59 | cur_solid = converter.my_op(cur_solid, ext_solid, 'cut') 60 | elif set_op == "IntersectFeatureOperation": 61 | cur_solid = converter.my_op(cur_solid, ext_solid, 'common') 62 | else: 63 | raise Exception("Unknown operation type") 64 | 65 | analyzer = BRepCheck_Analyzer(cur_solid) 66 | if not analyzer.IsValid(): 67 | raise Exception("brep check failed") 68 | 69 | extrude_idx += 1 70 | 71 | except Exception as ex: 72 | msg = [project_folder, str(ex)[:100]] 73 | return None 74 | try: 75 | with timeout(30): 76 | stl_name = Path(output_folder).stem + '_'+ str(extrude_idx).zfill(3) + "_final.stl" 77 | output_path = os.path.join(output_folder, stl_name) 78 | write_stl_file(cur_solid, output_path, linear_deflection=0.001, angular_deflection=0.5) 79 | 80 | step_name = Path(output_folder).stem + '_'+ str(extrude_idx).zfill(3) + "_final.step" 81 | output_path = os.path.join(output_folder, step_name) 82 | write_step_file(cur_solid, output_path) 83 | 84 | except Exception as ex: 85 | msg = [project_folder, str(ex)[:500]] 86 | return None 87 | 88 | return cur_solid 89 | 90 | 91 | if __name__ == "__main__": 92 | parser = argparse.ArgumentParser() 93 | 
parser.add_argument("--data_folder", type=str, required=True) 94 | args = parser.parse_args() 95 | 96 | solids = [] 97 | cad_folders = sorted(glob(args.data_folder+'/*/')) 98 | 99 | success_count = 0 100 | convert_iter = Pool(NUM_TRHEADS).imap(run_parallel, cad_folders) 101 | 102 | for solid in tqdm(convert_iter, total=len(cad_folders)): 103 | if solid is not None: 104 | success_count += 1 105 | 106 | print(f"\n✅ Successfully generated STEP files: {success_count}/{len(cad_folders)}") -------------------------------------------------------------------------------- /utils/convert.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import argparse 3 | import json 4 | # hyperparameters from SkexGen project 5 | SKETCH_R = 1 6 | RADIUS_R = 1 7 | EXTRUDE_R = 1.0 8 | SCALE_R = 1.4 9 | OFFSET_R = 0.9 10 | PIX_PAD = 4 11 | CMD_PAD = 3 12 | COORD_PAD = 4 13 | EXT_PAD = 1 14 | EXTRA_PAD = 1 15 | R_PAD = 2 16 | 17 | 18 | def create_curve_str(se_xy, se_cmd): 19 | curve_str = "" 20 | xy_offset = 0 21 | if se_cmd == 0: # line 22 | curve_str = " line," + ",".join(str(x) for x in se_xy[0]) 23 | xy_offset = 2 24 | elif se_cmd == 1: # arc 25 | curve_str = " arc," + ",".join(str(x) for x in se_xy[0:2].flatten()) 26 | xy_offset = 3 27 | elif se_cmd == 2: # circle 28 | curve_str = " circle," + ",".join(str(x) for x in se_xy[0:4].flatten()) 29 | xy_offset = 5 30 | curve_str += " " 31 | return curve_str, xy_offset 32 | 33 | 34 | def create_sketch_str(se_xy, se_cmd): 35 | sketch_str = "" 36 | len_xy, len_cmd = len(se_xy), len(se_cmd) 37 | xy_idx = 0 38 | for cmd_item in se_cmd: # for each command 39 | if 0 <= cmd_item <= 2: # curve 40 | curve_str, xy_offset = create_curve_str(se_xy[xy_idx:], cmd_item) 41 | sketch_str += curve_str 42 | xy_idx += xy_offset 43 | elif cmd_item == -1: # loop 44 | sketch_str += " " 45 | xy_idx += 1 46 | elif cmd_item == -2: # face 47 | sketch_str += " " 48 | xy_idx += 1 49 | elif cmd_item == -3: # sketch 50 | sketch_str += " " 51 | xy_idx += 1 52 | else: 53 | raise ValueError("Invalid command: " + str(cmd_item)) 54 | if xy_idx != len_xy: 55 | raise ValueError("xy_idx != len_xy") 56 | return sketch_str 57 | 58 | 59 | def create_extrude_str(se_ext): 60 | extrude_str = "" 61 | # extrude operation 62 | if se_ext[14] == 1: 63 | extrude_str += "add" 64 | elif se_ext[14] == 2: 65 | extrude_str += "cut" 66 | elif se_ext[14] == 3: 67 | extrude_str += "intersect" 68 | else: 69 | raise ValueError("Invalid extrude operation: " + str(se_ext[14])) 70 | # other extrude parameters 71 | extrude_str = ( 72 | extrude_str + "," + ",".join(str(x - EXT_PAD) for x in se_ext[0:5]) 73 | ) # ext_v, ext_T 74 | extrude_str = ( 75 | extrude_str + "," + ",".join(str(x - R_PAD) for x in se_ext[5:14]) 76 | ) # ext_R 77 | extrude_str = ( 78 | extrude_str + "," + ",".join(str(x - EXT_PAD) for x in se_ext[15:18]) 79 | ) # scale, offset 80 | # extrude end 81 | extrude_str += " " 82 | return extrude_str 83 | 84 | 85 | def convert(in_path, out_path): 86 | with open(in_path, "rb") as f: 87 | data = pickle.load(f) 88 | print("Data loaded: " + str(len(data)) + " samples") 89 | 90 | results = [] 91 | for item in data: # for each data 92 | se_str = "" 93 | num_se = item["num_se"] 94 | for se_idx in range(num_se): # for each sketch-extrude 95 | xy, cmd, ext = ( 96 | item["se_xy"][se_idx] - COORD_PAD, 97 | item["se_cmd"][se_idx] - CMD_PAD, 98 | item["se_ext"][se_idx], 99 | ) 100 | se_str = se_str + " " + create_sketch_str(xy, cmd).strip() 101 | se_str = se_str + " " + 
create_extrude_str(ext).strip() 102 | name = item['name'] 103 | result={ 104 | 'name': name, 105 | 'original_sequence': se_str.strip() 106 | } 107 | results.append(result) 108 | 109 | with open(out_path, "w") as f: 110 | json.dump(results, f) 111 | print("Data converted: " + str(len(results)) + " samples") 112 | 113 | 114 | if __name__ == "__main__": 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument("--in_path", type=str, required=True) 117 | parser.add_argument("--out_path", type=str, required=True) 118 | args = parser.parse_args() 119 | 120 | convert(args.in_path, args.out_path) 121 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | **/dataset/** 176 | **/model/** 177 | -------------------------------------------------------------------------------- /utils/eval_dclip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | from PIL import Image 4 | from directional_clip_score import CLIPLoss 5 | import json 6 | import os 7 | from collections import defaultdict 8 | import heapq 9 | import argparse 10 | 11 | def preprocess_image(pil_image): 12 | # Define a transform to convert the PIL image to a tensor 13 | transform = transforms.Compose([ 14 | transforms.ToTensor(), # Converts a PIL Image or numpy.ndarray (H x W x C) to a FloatTensor of shape (C x H x W) 15 | ]) 16 | 17 | # Apply the transform to the PIL image 18 | tensor_image = transform(pil_image) 19 | return tensor_image 20 | 21 | 22 | def find_files_with_prefix_source(directory, prefix): 23 | files = sorted(os.listdir(directory)) 24 | 25 | matching_files = [ 26 | file_name for file_name in files 27 | if file_name[5:].startswith(prefix) 28 | ] 29 | return matching_files 30 | 31 | 32 | def find_files_with_prefix(directory, prefix): 33 | files = sorted(os.listdir(directory)) 34 | 35 | matching_files = [ 36 | file_name for file_name in files 37 | if file_name.startswith(prefix) 38 | ] 39 | return matching_files 40 | 41 | 42 | def cal_dclip(args): 43 | source_files = sorted(os.listdir(args.source_dir)) 44 | 45 | with open(args.instruction_path, 'rb') as file: 46 | ins_data = json.load(file) 47 | 48 | # Collect all DCLIP scores for average calculation 49 | all_dclip_scores = [] 50 | total_comparisons = 0 51 | 52 | for idx, item in enumerate(ins_data): 53 | prefix_source = str(idx).zfill(5) + '_' 54 | prefix_edit = str(idx).zfill(5) + '_' 55 | 56 | source_files = find_files_with_prefix(args.source_dir, prefix_source) 57 | matching_files = find_files_with_prefix(args.edit_dir, prefix_edit) 58 | print(prefix_edit, matching_files) 59 | 60 | for 
source_file in source_files: 61 | source_path = os.path.join(args.source_dir, source_file) 62 | for matching_file in matching_files: 63 | matching_path = os.path.join(args.edit_dir, matching_file) 64 | print(matching_path) 65 | 66 | try: 67 | src_img = Image.open(source_path) 68 | target_img = Image.open(matching_path) 69 | src_img = preprocess_image(src_img) 70 | target_img = preprocess_image(target_img) 71 | 72 | device = "cuda" if torch.cuda.is_available() else "cpu" 73 | 74 | clip_loss_module = CLIPLoss(device=device) 75 | 76 | src_img_tensor = clip_loss_module.preprocess(src_img).unsqueeze(0).to(device) 77 | target_img_tensor = clip_loss_module.preprocess(target_img).unsqueeze(0).to(device) 78 | 79 | src_text = "This is a 3D shape. " 80 | instruction = item['instruction'] 81 | target_text = src_text + instruction 82 | 83 | directional_loss = clip_loss_module.clip_directional_loss(src_img_tensor, src_text, target_img_tensor, target_text) 84 | dclip_score = directional_loss.item() 85 | 86 | # Collect score for average calculation 87 | all_dclip_scores.append(dclip_score) 88 | total_comparisons += 1 89 | 90 | item['dclip'] = dclip_score 91 | 92 | # Write the results as a JSON list 93 | with open(args.out_path, 'a', encoding='utf-8') as f: 94 | json.dump(item, f, ensure_ascii=False) 95 | f.write(',\n') 96 | 97 | except Exception as e: 98 | print(f"Error processing {source_path} -> {matching_path}: {e}") 99 | continue 100 | 101 | # Calculate and print statistics 102 | if all_dclip_scores: 103 | average_dclip = sum(all_dclip_scores) / len(all_dclip_scores) 104 | max_dclip = max(all_dclip_scores) 105 | min_dclip = min(all_dclip_scores) 106 | 107 | print(f"\n{'='*50}") 108 | print(f"DCLIP Score Statistics:") 109 | print(f"{'='*50}") 110 | print(f"Total comparisons: {total_comparisons}") 111 | print(f"Average DCLIP score: {average_dclip:.6f}") 112 | print(f"Maximum DCLIP score: {max_dclip:.6f}") 113 | print(f"Minimum DCLIP score: {min_dclip:.6f}") 114 | print(f"{'='*50}") 115 | 116 | else: 117 | print("No DCLIP scores were calculated!") 118 | 119 | return all_dclip_scores 120 | 121 | 122 | if __name__ == "__main__": 123 | parser = argparse.ArgumentParser() 124 | parser.add_argument("--source_dir", type=str, required=True) 125 | parser.add_argument("--edit_dir", type=str, required=True) 126 | parser.add_argument("--instruction_path", type=str, required=True) 127 | parser.add_argument("--out_path", type=str, default="cad_dclip.json") 128 | 129 | args = parser.parse_args() 130 | 131 | cal_dclip(args) -------------------------------------------------------------------------------- /finetune/create_mask.py: -------------------------------------------------------------------------------- 1 | import re 2 | from difflib import SequenceMatcher 3 | from copy import deepcopy 4 | import json 5 | import argparse 6 | 7 | def parse_token(token): 8 | """ 9 | Parse a token into its components. 10 | """ 11 | return token.split() # split by space 12 | 13 | def merge_consecutive_masks(tokens): 14 | """ 15 | Merge consecutive tokens into a single . 16 | """ 17 | merged_tokens = [] 18 | for token in tokens: 19 | if token == "": 20 | if not merged_tokens or merged_tokens[-1] != "": 21 | merged_tokens.append("") 22 | else: 23 | merged_tokens.append(token) 24 | return merged_tokens 25 | 26 | def compare_tokens(token1, token2): 27 | """ 28 | Compare two tokens at a finer granularity. 29 | If they are partially similar (e.g., 'line,14,14' vs 'line,13,13'), 30 | preserve the common part and mask only the differences. 
31 | """ 32 | # Parse tokens into components 33 | components1 = parse_token(token1) 34 | components2 = parse_token(token2) 35 | 36 | # If the base command (e.g., 'line') is different, mask the entire token 37 | if components1[0] != components2[0]: 38 | return "" 39 | 40 | # If the base command is the same, compare the rest of the components 41 | result = [components1[0]] # Start with the base command 42 | for comp1, comp2 in zip(components1[1:], components2[1:]): 43 | if comp1 == comp2: 44 | result.append(comp1) 45 | else: 46 | result.append("") 47 | 48 | # Reconstruct the token 49 | return ','.join(result) 50 | 51 | def generate_mask_lcs_with_partial_matching(original_sequence, edited_sequence): 52 | """ 53 | Generate the masked sequence using LCS with partial matching for tokens. 54 | """ 55 | original_tokens = original_sequence.split() 56 | edited_tokens = edited_sequence.split() 57 | 58 | # Find the longest common subsequence (LCS) 59 | matcher = SequenceMatcher(None, original_tokens, edited_tokens) 60 | lcs = matcher.get_matching_blocks() # Get matching blocks of tokens 61 | 62 | masked_sequence = [] 63 | original_idx = 0 64 | edited_idx = 0 65 | 66 | for match in lcs: 67 | # Handle tokens in original_sequence that are not in LCS (deletions) 68 | while original_idx < match.a: 69 | masked_sequence.append("") 70 | original_idx += 1 71 | 72 | # Handle tokens in edited_sequence that are not in LCS (additions) 73 | while edited_idx < match.b: 74 | masked_sequence.append("") 75 | edited_idx += 1 76 | 77 | # Add the matching tokens (LCS part), with partial matching for differences 78 | for i in range(match.size): 79 | token1 = original_tokens[original_idx] 80 | token2 = edited_tokens[edited_idx] 81 | if token1 == token2: 82 | masked_sequence.append(token1) 83 | else: 84 | masked_sequence.append(compare_tokens(token1, token2)) 85 | original_idx += 1 86 | edited_idx += 1 87 | 88 | # Handle remaining tokens in original_sequence (deletions) 89 | while original_idx < len(original_tokens): 90 | masked_sequence.append("") 91 | original_idx += 1 92 | 93 | # Handle remaining tokens in edited_sequence (additions) 94 | while edited_idx < len(edited_tokens): 95 | masked_sequence.append("") 96 | edited_idx += 1 97 | 98 | # Merge consecutive tokens 99 | merged_sequence = merge_consecutive_masks(masked_sequence) 100 | return " ".join(merged_sequence) 101 | 102 | def process_dataset_with_partial_matching(dataset): 103 | """ 104 | Process the dataset to add `original_sequence_mask` for each entry using partial matching. 
105 | """ 106 | processed_data = [] 107 | 108 | for entry in dataset: 109 | # Skip entries without both original_sequence and edited_sequence 110 | if 'original_sequence' not in entry or 'edited_sequence' not in entry: 111 | continue 112 | 113 | original_sequence = entry['original_sequence'] 114 | edited_sequence = entry['edited_sequence'] 115 | 116 | # Generate the masked sequence using partial matching 117 | original_sequence_mask = generate_mask_lcs_with_partial_matching(original_sequence, edited_sequence) 118 | 119 | # Add the masked sequence to the entry 120 | new_entry = deepcopy(entry) 121 | new_entry['masked_sequence'] = original_sequence_mask 122 | processed_data.append(new_entry) 123 | 124 | return processed_data 125 | 126 | 127 | def main(args): 128 | # load and process a JSON dataset 129 | with open(args.input_path, "r") as f: 130 | dataset = json.load(f) 131 | 132 | updated_dataset = process_dataset_with_partial_matching(dataset) 133 | 134 | with open(args.output_path, "w") as f: 135 | json.dump(updated_dataset, f, indent=4) 136 | 137 | print(f"Processed {len(updated_dataset)} entries from {args.input_path} and saved to {args.output_path}") 138 | 139 | 140 | if __name__ == "__main__": 141 | parser = argparse.ArgumentParser(description='Process dataset to create masked sequences') 142 | parser.add_argument('--input_path', type=str, default='raw_train.json', 143 | help='Path to input JSON dataset file (default: raw_train.json)') 144 | parser.add_argument('--output_path', type=str, default='train.json', 145 | help='Path to output JSON dataset file (default: train.json)') 146 | 147 | args = parser.parse_args() 148 | 149 | main(args) 150 | 151 | 152 | -------------------------------------------------------------------------------- /data/caption_sequence.py: -------------------------------------------------------------------------------- 1 | import json 2 | import transformers 3 | import torch 4 | import argparse 5 | 6 | def sample(args): 7 | model_path = "meta-llama/Meta-Llama-3-70B-Instruct" 8 | 9 | pipeline = transformers.pipeline( 10 | "text-generation", 11 | model=model_path, 12 | model_kwargs={"torch_dtype": torch.bfloat16}, 13 | device_map="auto", 14 | ) 15 | 16 | with open(args.in_path, 'r', encoding='utf-8') as file: 17 | data = json.load(file) 18 | 19 | for idx, item in enumerate(data): 20 | prompt = f"""## Task 21 | You are a senior CAD engineer. Your task is to provide a clear and concise editing instruction (10 words or fewer) for editing a sketch-and-extrude CAD model. Your response should include: 22 | 1. Description of the Original CAD Model: Analyze the CAD operation sequence and describe the resulting geometry. Include element types (e.g., cylinder, prism, hole), quantities, proportions, spatial relationships, and any notable details. 23 | 2. Description of the Edited CAD Model: Analyze the CAD operation sequence and describe the resulting geometry. Include element types (e.g., cylinder, prism, hole), quantities, proportions, spatial relationships, and any notable details. 24 | 3. Change Analysis: 25 | - Geometric Changes: Describe added, removed, or modified elements, including types (e.g., cylinder, prism, hole) and quantities (e.g., two rectangles). Use spatial or geometric features (e.g., "upper triangular face", "smaller rectangular prism", "central circular hole") instead of unintuitive terms like "first" or "second." 26 | - Proportions and Dimensions: Note changes in size, scaling, or relative proportions. 
27 | - Positional Relationships: Explain spatial alignment and relationships between elements. 28 | - Other Notable Details: Highlight any additional observations. 29 | - Purpose: Suggest the intent behind the edit (e.g., "add a central hole", "remove the smaller prism", or "increase length by 8 units"). 30 | 4. Editing Instruction: Provide a concise instruction (max 10 words) describing the modification. 31 | 32 | ## Sketch-and-Extrude Model Overview 33 | An "extruded-sketch" is a 3D volume, formed by extruding a sketch. A "sketch-and-extrude" model is formed by multiple extruded-sketches via Boolean operations (i.e., add, cut, and intersect). 34 | # Sketch 35 | - A "sketch" is formed by one or multiple faces. 36 | - A "face" is a 2D area bounded by loops. 37 | - A "loop" is a closed path, consisting of one (i.e., circle) or multiple curves (e.g., line-arc-line). 38 | - A "curve" (i.e., line, arc, or circle) is the lowest-level primitive. 39 | - A circle is defined by four points. 40 | - An arc is defined by two points, with the third point specified by the next curve (or the first curve when a loop is closed). 41 | - A line is defined by start point. 42 | - A point is represented by two integers which stands for the x and y coordinate, respectively. 43 | - A loop with a circle can not contain additional curves since it is already a closed path. 44 | - When a face consists of multiple loops, the first loop defines the external boundary, and the remaining loops define internal loops (i.e., holes). 45 | - An end-primitive token appears at the end of each primitive (curve, line, face, loop or sketch). 46 | # Extrude 47 | Each sketch will be followed by an extrude, which is represented by 18 parameters: BVVTTTRRRRRRRRRSOO 48 | - B represents one of the three Boolean operations: add, cut or intersect. It occupies 1 parameter. 49 | - V indicates the displacements of the top and the bottom planes from the reference plane in which a sketch is extruded to form a solid. It occupies 2 parameters. 50 | - T represents 3D translation applied to the extruded solid. It occupies 3 parameters. 51 | - R represents 3D rotation of the extrusion direction. It occupies 9 parameters. 52 | - S represents the uniform scaling factor. It occupies 1 parameter. 53 | - O represents the center of scaling as a 2D coordinate. It occupies 2 parameters. 54 | # Note 55 | - Note that every number is an integer. 56 | 57 | ## Your Task 58 | Original CAD Sequence: 59 | {item['original_sequence']} 60 | Edited CAD Sequence: 61 | {item['edited_sequence']} 62 | Let's think step by step. Your output should be of the following json format: 63 | {{ 64 | "Description of the Original CAD Model": your description here. 65 | "Description of the Edited CAD Model": your description here. 66 | "Change Analysis": your change analysis here. 67 | "Editing Instruction": the final editing instruction here (10 words maximum). 68 | }} 69 | """ 70 | messages = [ 71 | {"role": "system", "content": "You are an assistant trained to evaluate the semantic relevance between a Query and a Title. 
"}, 72 | {"role": "user", "content": prompt}, 73 | ] 74 | 75 | terminators = [ 76 | pipeline.tokenizer.eos_token_id, 77 | pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>") 78 | ] 79 | 80 | outputs = pipeline( 81 | messages, 82 | max_new_tokens=512, 83 | eos_token_id=terminators, 84 | do_sample=True, 85 | temperature=0.6, 86 | top_p=0.9, 87 | batch_size=512 88 | ) 89 | 90 | with open(f'{args.out_path}', 'a', encoding='utf-8') as f: 91 | item['instruction'] = outputs[0]["generated_text"][-1]['content'] 92 | item['method'] = "sequence" 93 | f.write(json.dumps(item, ensure_ascii=False) + ",\n") 94 | 95 | 96 | if __name__ == "__main__": 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument("--in_path", type=str, required=True) 99 | parser.add_argument("--out_path", type=str, required=True) 100 | args = parser.parse_args() 101 | sample(args) -------------------------------------------------------------------------------- /prompt/gpt4_basic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.feature_extraction.text import TfidfVectorizer 3 | from sklearn.metrics.pairwise import cosine_similarity 4 | import os 5 | from openai import AzureOpenAI 6 | import json 7 | import requests 8 | import time 9 | from azure.identity import DefaultAzureCredential, get_bearer_token_provider 10 | import argparse 11 | 12 | # Configuration 13 | endpoint = 'YOUR_AZURE_OPENAI_ENDPOINT' 14 | token_provider = get_bearer_token_provider( 15 | DefaultAzureCredential(), 16 | "https://cognitiveservices.azure.com/.default" 17 | ) 18 | deployment_name = 'gpt-4o' 19 | 20 | client = AzureOpenAI( 21 | azure_ad_token_provider=token_provider, 22 | azure_endpoint=endpoint, 23 | api_version='2024-02-01' 24 | ) 25 | 26 | def call_gpt4_1(prompt): 27 | output = None 28 | message_text = [ 29 | {"role":"system","content":"You are an AI assistant that helps people find information."}, 30 | {"role":"user","content":prompt} 31 | ] 32 | while output is None: 33 | try: 34 | time.sleep(0.5) 35 | completion = client.chat.completions.create( 36 | model = deployment_name, 37 | messages = message_text, 38 | ) 39 | output = completion.choices[0].message.content 40 | except Exception as e: 41 | print("API call exception:", e) 42 | output = None 43 | return output 44 | 45 | def cad_basic(args): 46 | with open(args.in_path, 'r', encoding='utf-8') as file: 47 | data = json.load(file) 48 | for idx, item in enumerate(data): 49 | instruction = item['instruction'] 50 | for _ in range(5): 51 | output = None 52 | output_json = None 53 | time.sleep(0.5) 54 | prompt = f"""## Task 55 | You are a senior Computer-Aided Design (CAD) engineer. Your task is to provide clear, natural language editing instructions to a junior CAD designer for editing a sketch-and-extrude CAD model. Focus on geometric properties, including: 56 | - The type and number of elements. 57 | - Proportions and dimensions. 58 | - Positional relationships between elements. 59 | - Any other notable details. 60 | ## Sketch-and-Extrude Model Overview 61 | An "extruded-sketch" is a 3D volume, formed by extruding a sketch. A "sketch-and-extrude" model is formed by multiple extruded-sketches via Boolean operations (i.e., add, cut, and intersect). 62 | # Sketch 63 | - A "sketch" is formed by one or multiple faces. 64 | - A "face" is a 2D area bounded by loops. 65 | - A "loop" is a closed path, consisting of one (i.e., circle) or multiple curves (e.g., line-arc-line). 
66 | - A "curve" (i.e., line, arc, or circle) is the lowest-level primitive. 67 | - A circle is defined by four points. 68 | - An arc is defined by three points, but only two are given; the third point is specified by the next curve (or the first curve when a loop is closed). 69 | - A line is defined by its start point. 70 | - A point is represented by two integers which stand for the x and y coordinates, respectively. 71 | - A loop with a circle cannot contain additional curves since it is already a closed path. 72 | - When a face consists of multiple loops, the first loop defines the external boundary, and the remaining loops define internal loops (i.e., holes). 73 | - An end-primitive token appears at the end of each primitive (curve, line, face, loop or sketch). 74 | # Extrude 75 | Each sketch will be followed by an extrude, which is represented by 18 parameters: BVVTTTRRRRRRRRRSOO 76 | - B represents one of the three Boolean operations: add, cut or intersect. It occupies 1 parameter. 77 | - V indicates the displacements of the top and the bottom planes from the reference plane in which a sketch is extruded to form a solid. It occupies 2 parameters. 78 | - T represents 3D translation applied to the extruded solid. It occupies 3 parameters. 79 | - R represents 3D rotation of the extrusion direction. It occupies 9 parameters. 80 | - S represents the uniform scaling factor. It occupies 1 parameter. 81 | - O represents the center of scaling as a 2D coordinate. It occupies 2 parameters. 82 | # Note 83 | - Note that every number is an integer. 84 | 85 | ## Your task 86 | Original CAD Command Sequence: 87 | {item['original_sequence']} 88 | Instruction: 89 | {instruction} 90 | Your output should be of the following json format: 91 | {{ 92 | "modified sequence": your Modified CAD Command Sequence here.
93 | }} 94 | """ 95 | 96 | while output is None: 97 | try: 98 | output = call_gpt4_1(prompt) 99 | try: 100 | output_json = json.loads(output) 101 | except: 102 | output_json_lines = output.strip().splitlines()[1:-1] 103 | output_json = "\n".join(output_json_lines).strip() 104 | output_json = json.loads(output_json) 105 | except Exception as e: 106 | print("error: ", e) 107 | time.sleep(1) 108 | output = None 109 | 110 | item['output'] = output_json.get("modified sequence", None) 111 | 112 | with open(args.out_path, 'a', encoding='utf-8') as f: 113 | json.dump(item, f, ensure_ascii=False, indent=4) 114 | f.write(',\n') 115 | 116 | if __name__ == "__main__": 117 | parser = argparse.ArgumentParser() 118 | parser.add_argument("--in_path", type=str, required=True) 119 | parser.add_argument("--out_path", type=str, required=True) 120 | args = parser.parse_args() 121 | cad_basic(args) -------------------------------------------------------------------------------- /prompt/gpt4_fs3.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.feature_extraction.text import TfidfVectorizer 3 | from sklearn.metrics.pairwise import cosine_similarity 4 | import os 5 | from openai import AzureOpenAI 6 | import json 7 | import requests 8 | import time 9 | from azure.identity import DefaultAzureCredential, get_bearer_token_provider 10 | import argparse 11 | 12 | # Configuration 13 | endpoint = 'YOUR_AZURE_OPENAI_ENDPOINT' 14 | token_provider = get_bearer_token_provider( 15 | DefaultAzureCredential(), 16 | "https://cognitiveservices.azure.com/.default" 17 | ) 18 | deployment_name = 'gpt-4o' 19 | 20 | client = AzureOpenAI( 21 | azure_ad_token_provider=token_provider, 22 | azure_endpoint=endpoint, 23 | api_version='2024-02-01' 24 | ) 25 | 26 | def call_gpt4_1(prompt): 27 | output = None 28 | message_text = [ 29 | {"role":"system","content":"You are an AI assistant that helps people find information."}, 30 | {"role":"user","content":prompt} 31 | ] 32 | while output is None: 33 | try: 34 | time.sleep(0.5) 35 | completion = client.chat.completions.create( 36 | model = deployment_name, 37 | messages = message_text, 38 | ) 39 | output = completion.choices[0].message.content 40 | except Exception as e: 41 | print("API call error:", e) 42 | output = None 43 | return output 44 | 45 | def cad_fs3(args): 46 | with open(args.example_path, 'r', encoding='utf-8') as file: 47 | example_data = json.load(file) 48 | 49 | with open(args.in_path, 'r', encoding='utf-8') as file: 50 | data = json.load(file) 51 | for idx, item in enumerate(data): 52 | for _ in range(5): 53 | output = None 54 | output_json = None 55 | time.sleep(1) 56 | descriptions = [entry['instruction'] for entry in example_data] 57 | vectorizer = TfidfVectorizer() 58 | tfidf_matrix = vectorizer.fit_transform(descriptions + [item['instruction']]) 59 | cosine_similarities = cosine_similarity(tfidf_matrix[-1], tfidf_matrix[:-1]) 60 | top_indices = np.argsort(cosine_similarities[0])[::-1][:6] 61 | top_entries = [example_data[index] for index in top_indices] 62 | 63 | prompt = f"""Modify the original Computer-Aided Design(CAD) command sequence according to the instruction:\n' 64 | 65 | ## Instructions for sketch-and-extrude model 66 | A sketch-and-extrude model consists of multiple extruded-sketches. 67 | # Sketch 68 | - A sketch consists of multiple faces 69 | - A face consists of multiple loops. 70 | - A loop consists of multiple curves. 71 | - A curve is either a line, an arc, or a circle. 
72 | - A circle is defined by four points with four geometry tokens. 73 | - An arc is defined by three points but with only two tokens; the third point is specified by the next curve (or the first curve when a loop is closed). 74 | - A line is defined by its start point. 75 | - A point is represented by two integers which stand for the x and y coordinates, respectively. 76 | - A loop with a circle cannot contain additional curves since it is already a closed path. 77 | - When a face consists of multiple loops, the first loop defines the external boundary, and the remaining loops define internal loops (i.e., holes). 78 | - An end-primitive token appears at the end of each primitive (curve, line, face, loop or sketch). 79 | # Extrude 80 | Each sketch will be followed by an extrude, which is represented by 18 parameters: BVVTTTRRRRRRRRRSOO 81 | - B represents one of the three Boolean operations: add, cut or intersect. It occupies 1 parameter. 82 | - V indicates the displacements of the top and the bottom planes from the reference plane in which a sketch is extruded to form a solid. It occupies 2 parameters. T represents 3D translation applied to the extruded solid. It occupies 3 parameters. 83 | - R represents 3D rotation of the extrusion direction. It occupies 9 parameters. 84 | - S represents the uniform scaling factor. It occupies 1 parameter. 85 | - O represents the center of scaling as a 2D coordinate. It occupies 2 parameters. 86 | # Note 87 | - Note that every number is an integer. 88 | 89 | ## Examples for editing sketch-and-extrude model 90 | Example 1: 91 | Original CAD Command Sequence: 92 | {top_entries[0]['original_sequence']} 93 | Instruction: 94 | {top_entries[0]['instruction']} 95 | Modified CAD Command Sequence: 96 | {top_entries[0]['edited_sequence']} 97 | Example 2: 98 | Original CAD Command Sequence: 99 | {top_entries[1]['original_sequence']} 100 | Instruction: 101 | {top_entries[1]['instruction']} 102 | Modified CAD Command Sequence: 103 | {top_entries[1]['edited_sequence']} 104 | Example 3: 105 | Original CAD Command Sequence: 106 | {top_entries[2]['original_sequence']} 107 | Instruction: 108 | {top_entries[2]['instruction']} 109 | Modified CAD Command Sequence: 110 | {top_entries[2]['edited_sequence']} 111 | 112 | 113 | ## Your task 114 | Original CAD Command Sequence: 115 | {item['original_sequence']} 116 | Instruction: 117 | {item['instruction']} 118 | Your output should be of the following json format: 119 | {{ 120 | "modified sequence": your Modified CAD Command Sequence here.
121 | }} 122 | """ 123 | 124 | while output is None: 125 | try: 126 | output = call_gpt4_1(prompt) 127 | try: 128 | output_json = json.loads(output) 129 | except: 130 | output_json_lines = output.strip().splitlines()[1:-1] 131 | output_json = "\n".join(output_json_lines).strip() 132 | output_json = json.loads(output_json) 133 | except Exception as e: 134 | print("error: ", e) 135 | time.sleep(1) 136 | output = None 137 | 138 | item['output'] = output_json.get("modified sequence", None) 139 | 140 | with open(args.out_path, 'a', encoding='utf-8') as f: 141 | json.dump(item, f, ensure_ascii=False, indent=4) 142 | f.write(',\n') 143 | 144 | if __name__ == "__main__": 145 | parser = argparse.ArgumentParser() 146 | parser.add_argument("--in_path", type=str, required=True) 147 | parser.add_argument("--example_path", type=str, required=True) 148 | parser.add_argument("--out_path", type=str, required=True) 149 | args = parser.parse_args() 150 | cad_fs3(args) -------------------------------------------------------------------------------- /hnc-cad/ac_gen.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | from config import * 5 | from hashlib import sha256 6 | import numpy as np 7 | from dataset import CADData 8 | from utils import CADparser, write_obj_sample 9 | from model.encoder import SketchEncoder, ExtEncoder 10 | from model.decoder import SketchDecoder, ExtDecoder, CodeDecoder 11 | 12 | 13 | def raster_cad(coord, ext): 14 | parser = CADparser(CAD_BIT) 15 | parsed_data = parser.perform(coord, ext) 16 | return parsed_data 17 | 18 | 19 | def pad_code(total_code): 20 | keys = np.ones(len(total_code)) 21 | padding = np.zeros(MAX_CODE-len(total_code)).astype(int) 22 | total_code = np.concatenate([total_code, padding], axis=0) 23 | seq_mask = 1-np.concatenate([keys, padding]) == 1 24 | return total_code, seq_mask 25 | 26 | 27 | def hash_sketch(sketch, ext): 28 | hash_str = sha256(np.ascontiguousarray(sketch).flatten()).hexdigest() +'_'+\ 29 | sha256(np.ascontiguousarray(ext).flatten()).hexdigest() 30 | return hash_str 31 | 32 | 33 | @torch.inference_mode() 34 | def sample(args): 35 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 36 | dataset = CADData(CAD_TRAIN_PATH, args.solid_code, args.profile_code, args.loop_code, args.mode, is_training=False) 37 | dataloader = torch.utils.data.DataLoader(dataset, 38 | shuffle=False, 39 | batch_size=1, 40 | num_workers=1) 41 | code_size = dataset.solid_unique_num + dataset.profile_unique_num + dataset.loop_unique_num 42 | 43 | # Load model weights 44 | sketch_enc = SketchEncoder() 45 | sketch_enc.load_state_dict(torch.load(os.path.join(args.weight, 'sketch_enc_epoch_250.pt'))) 46 | sketch_enc.cuda().eval() 47 | 48 | sketch_dec = SketchDecoder(args.mode, num_code=9390) 49 | sketch_dec.load_state_dict(torch.load(os.path.join(args.weight, 'sketch_dec_epoch_250.pt'))) 50 | sketch_dec.cuda().eval() 51 | 52 | ext_enc = ExtEncoder() 53 | ext_enc.load_state_dict(torch.load(os.path.join(args.weight, 'ext_enc_epoch_250.pt'))) 54 | ext_enc.cuda().eval() 55 | 56 | ext_dec = ExtDecoder(args.mode, num_code=9390) 57 | ext_dec.load_state_dict(torch.load(os.path.join(args.weight, 'ext_dec_epoch_250.pt'))) 58 | ext_dec.cuda().eval() 59 | 60 | code_dec = CodeDecoder(args.mode, 9390) 61 | code_dec.load_state_dict(torch.load(os.path.join(args.weight, 'code_dec_epoch_250.pt'))) 62 | code_dec.cuda().eval() 63 | 64 | # Random sampling 65 | code_bsz = 10 # every partial input samples this many 
neural codes 66 | count = 0 67 | for name, pixel_p, coord_p, sketch_mask_p, ext_p, ext_mask_p, _, _, _, _, _, _, _ in dataloader: 68 | pixel_p = pixel_p.cuda() 69 | coord_p = coord_p.cuda() 70 | sketch_mask_p = sketch_mask_p.cuda() 71 | ext_p = ext_p.cuda() 72 | ext_mask_p = ext_mask_p.cuda() 73 | 74 | # encode partial CAD model 75 | latent_sketch = sketch_enc(pixel_p, coord_p, sketch_mask_p) 76 | latent_extrude = ext_enc(ext_p, ext_mask_p) 77 | 78 | # generate the neural code tree 79 | latent_z = torch.cat([latent_sketch, latent_extrude], 1) 80 | latent_mask = torch.cat([sketch_mask_p, ext_mask_p], 1) 81 | code_sample = code_dec.sample(n_samples=code_bsz, latent_z=latent_z.repeat(code_bsz, 1, 1), 82 | latent_mask=latent_mask.repeat(code_bsz, 1), top_k=0, top_p=0.95) 83 | print("\ncode_sample:",code_sample) 84 | # filter code, only keep unique code 85 | if len(code_sample)<3: 86 | continue 87 | code_unique = {} 88 | for ii in range(len(code_sample)): 89 | if len(torch.where(code_sample[ii]==0)[0])==0: 90 | continue 91 | code = (code_sample[ii][:torch.where(code_sample[ii]==0)[0][0]+1]).detach().cpu().numpy() 92 | code_uid = code.tobytes() 93 | if code_uid not in code_unique: 94 | code_unique[code_uid] = code 95 | total_code = [] 96 | total_code_mask = [] 97 | print("code_unique: ", code_unique) 98 | for _, code in code_unique.items(): 99 | _code_, _code_mask_ = dataset.pad_code(code) 100 | total_code.append(_code_) 101 | total_code_mask.append(_code_mask_) 102 | total_code = np.vstack(np.vstack(total_code)) 103 | total_code_mask = np.vstack(total_code_mask) 104 | total_code = torch.LongTensor(total_code).cuda() 105 | total_code_mask = torch.BoolTensor(total_code_mask).cuda() 106 | 107 | # generate the full CAD model 108 | latent_z = latent_z.repeat(len(total_code), 1, 1) 109 | latent_mask = latent_mask.repeat(len(total_code), 1) 110 | xy_samples, _code_, _code_mask_, _latent_z_, _latent_mask_ = sketch_dec.sample(total_code, total_code_mask, latent_z, latent_mask, top_k=1, top_p=0) 111 | cad_samples = ext_dec.sample(xy_samples, _code_, _code_mask_, _latent_z_, _latent_mask_, top_k=1, top_p=0) 112 | 113 | # raster user input cad 114 | try: 115 | print("try1") 116 | cad_obj = raster_cad(coord_p.detach().cpu().numpy()[0], ext_p.detach().cpu().numpy()[0]) 117 | save_folder = os.path.join(result_folder, str(name[0]) + '_' + str(count).zfill(6)+'_origInput') 118 | if not os.path.exists(save_folder): 119 | os.makedirs(save_folder) 120 | write_obj_sample(save_folder, cad_obj) 121 | except Exception as error_msg: 122 | print("error1: ", error_msg) 123 | continue 124 | 125 | # raster auto-completed cad 126 | for ii, sample in enumerate(cad_samples): 127 | try: 128 | print("try2") 129 | cad_obj = raster_cad(sample[0], sample[1]) 130 | save_folder = os.path.join(result_folder, str(name[0]) + '_' + str(count).zfill(6) +'_postAC'+str(ii)) 131 | if not os.path.exists(save_folder): 132 | os.makedirs(save_folder) 133 | write_obj_sample(save_folder, cad_obj) 134 | except Exception as error_msg: 135 | print("error2: ", error_msg) 136 | continue 137 | count += 1 138 | print("count1: ",count) 139 | 140 | 141 | 142 | if __name__ == "__main__": 143 | parser = argparse.ArgumentParser() 144 | parser.add_argument("--weight", type=str, help="Pretrained CAD model", required=True) 145 | parser.add_argument("--output", type=str, help="Output folder to save the data", required=True) 146 | parser.add_argument("--device", type=str, help="CUDA Device Index", required=True) 147 | parser.add_argument("--mode", type=str, 
required=True, help="eval | sample") 148 | parser.add_argument("--solid_code", type=str, required=True) 149 | parser.add_argument("--profile_code", type=str, required=True) 150 | parser.add_argument("--loop_code", type=str, required=True) 151 | args = parser.parse_args() 152 | 153 | result_folder = args.output 154 | if not os.path.exists(result_folder): 155 | os.makedirs(result_folder) 156 | 157 | sample(args) 158 | -------------------------------------------------------------------------------- /finetune/llama_sample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from tqdm import tqdm 4 | import transformers 5 | from peft import PeftModel 6 | from pathlib import Path 7 | import torch 8 | from transformers import LlamaForCausalLM, LlamaTokenizer, TrainingArguments 9 | 10 | 11 | DEFAULT_PAD_TOKEN = "[PAD]" 12 | DEFAULT_EOS_TOKEN = "" 13 | DEFAULT_BOS_TOKEN = "" 14 | DEFAULT_UNK_TOKEN = "" 15 | MAX_LENGTH = 1024 16 | 17 | def prepare_model_and_tokenizer(args): 18 | model_id= 'meta-llama/Meta-Llama-3-8B-Instruct' 19 | pipeline = transformers.pipeline("text2text-generation", 20 | model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device_map='auto') 21 | tokenizer = pipeline.tokenizer 22 | model = pipeline.model 23 | 24 | model.eval() 25 | 26 | special_tokens_dict = dict() 27 | if tokenizer.pad_token is None: 28 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN 29 | if tokenizer.eos_token is None: 30 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN 31 | if tokenizer.bos_token is None: 32 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN 33 | if tokenizer.unk_token is None: 34 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN 35 | 36 | smart_tokenizer_and_embedding_resize( 37 | special_tokens_dict=special_tokens_dict, 38 | llama_tokenizer=tokenizer, 39 | model=model, 40 | ) 41 | 42 | model = PeftModel.from_pretrained(model, args.model_path, device_map="auto") 43 | # merge 44 | # model.merge_and_unload() 45 | 46 | return model, tokenizer 47 | 48 | 49 | def smart_tokenizer_and_embedding_resize( 50 | special_tokens_dict, 51 | llama_tokenizer, 52 | model, 53 | ): 54 | """Resize tokenizer and embedding. 55 | 56 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 57 | """ 58 | num_new_tokens = llama_tokenizer.add_special_tokens(special_tokens_dict) 59 | model.resize_token_embeddings(len(llama_tokenizer)) 60 | 61 | if num_new_tokens > 0: 62 | input_embeddings = model.get_input_embeddings().weight.data 63 | output_embeddings = model.get_output_embeddings().weight.data 64 | 65 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( 66 | dim=0, keepdim=True 67 | ) 68 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( 69 | dim=0, keepdim=True 70 | ) 71 | 72 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 73 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 74 | 75 | 76 | def get_prompt_template(task_type, item, instruction): 77 | """Get prompt template based on task type""" 78 | if task_type == "mask": 79 | return f"""Below is a Computer-Aided Design (CAD) operation sequence, replace the parts that need to be modified with the string "" according to the editing instruction. 
80 | Original CAD Operation Sequence: 81 | {item['original_sequence']} 82 | Editing Instruction: 83 | {instruction} 84 | Masked CAD Operation Sequence: 85 | """ 86 | elif task_type == "infill": 87 | return f"""Below is the original Computer-Aided Design (CAD) operation sequence. 88 | Original CAD Operation Sequence: 89 | {item['original_sequence']} 90 | 91 | The parts that need to be modified according to the editing instruction have been replaced by the string "". 92 | Editing Instruction: 93 | {instruction} 94 | Masked CAD Operation Sequence: 95 | {item['output_mask']} 96 | 97 | Based on the original CAD sequence, the editing instruction, and the masked sequence, generate the complete edited CAD sequence by replacing "" with the appropriate content: 98 | """ 99 | else: 100 | raise ValueError(f"Unknown task: {task_type}") 101 | 102 | 103 | def get_output_field_name(task_type): 104 | """Get output field name based on task type""" 105 | if task_type == "mask": 106 | return "output_mask" 107 | elif task_type == "infill": 108 | return "output_infill" 109 | else: 110 | raise ValueError(f"Unknown task: {task_type}") 111 | 112 | 113 | def conditional_sample(args): 114 | model, tokenizer = prepare_model_and_tokenizer(args) 115 | with open(args.data_path, 'r', encoding='utf-8') as file: 116 | content = file.read().strip() 117 | try: 118 | data = json.loads(content) 119 | except json.JSONDecodeError: 120 | try: 121 | data =[] 122 | for line in content.split('\n'): 123 | line = line.strip() 124 | if line: 125 | data.append(json.loads(line)) 126 | except json.JSONDecodeError: 127 | raise ValueError(f"Failed to parse JSON from {args.data_path}. Please check the file format.") 128 | 129 | output_field = get_output_field_name(args.task_type) 130 | 131 | for idx, item in enumerate(data): 132 | instruction = item['instruction'] 133 | 134 | # Check if infill task requires output_mask field 135 | if args.task_type == "infill" and 'output_mask' not in item: 136 | print(f"Warning: Item {idx} missing 'output_mask' field required for infill task. 
Skipping.") 137 | continue 138 | 139 | prompts = [] 140 | for _ in range(args.num_samples): 141 | prompt = get_prompt_template(args.task_type, item, instruction) 142 | prompts.append(prompt) 143 | 144 | outputs = [] 145 | 146 | while len(outputs) < args.num_samples: 147 | batch_prompts = prompts[len(outputs) : len(outputs) + args.batch_size] 148 | 149 | batch = tokenizer( 150 | list(batch_prompts), 151 | return_tensors="pt", 152 | ) 153 | batch = {k: v.cuda() for k, v in batch.items()} 154 | 155 | generate_ids = model.generate( 156 | **batch, 157 | do_sample=True, 158 | max_new_tokens=MAX_LENGTH, 159 | temperature=args.temperature, 160 | top_p=args.top_p, 161 | repetition_penalty=1.3, 162 | ) 163 | 164 | gen_strs = tokenizer.batch_decode( 165 | generate_ids, 166 | skip_special_tokens=True, 167 | clean_up_tokenization_spaces=False, 168 | ) 169 | 170 | outputs.extend(gen_strs) 171 | print(f"Generated {len(outputs)}/{args.num_samples} samples.") 172 | 173 | with open(args.out_path, "a+") as f: 174 | for prompt, output in zip(prompts, outputs): 175 | item[output_field] = output[len(prompt):] 176 | f.write(json.dumps(item) + "\n") 177 | 178 | 179 | if __name__ == "__main__": 180 | parser = argparse.ArgumentParser() 181 | parser.add_argument("--task_type", type=str, required=True, choices=["mask", "infill"], 182 | help="Task to perform: 'mask' for masking CAD sequences, 'infill' for infilling from masked sequences") 183 | parser.add_argument("--model_path", type=str, required=True) 184 | parser.add_argument("--data_path", type=str, required=True) 185 | parser.add_argument("--num_samples", type=int, default=10) 186 | parser.add_argument("--batch_size", type=int, default=32, 187 | help="Batch size (default: 32)") 188 | parser.add_argument("--out_path", type=str, default="cad_samples.jsonl") 189 | parser.add_argument("--temperature", type=float, default=0.9) 190 | parser.add_argument("--top_p", type=float, default=0.9) 191 | args = parser.parse_args() 192 | 193 | print(f"Running {args.task_type} task with batch size {args.batch_size}") 194 | conditional_sample(args) -------------------------------------------------------------------------------- /utils/geometry/obj_parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | 5 | from geometry.arc import Arc 6 | from geometry.circle import Circle 7 | from geometry.line import Line 8 | 9 | from geometry import geom_utils 10 | import pdb 11 | 12 | 13 | class OBJParser: 14 | """ 15 | A class to read an OBJ file containing the sketch data 16 | and hand it back in a form which is easy to work with. 17 | """ 18 | def __init__(self, pathname=None): 19 | self.pathname = pathname 20 | 21 | 22 | def convert_vertices(self, vertices): 23 | """Convert all the vertices to .obj format""" 24 | vertex_strings = "" 25 | for pt in vertices: 26 | # e.g. 
v 0.123 0.234 0.345 1.0 27 | vertex_string = f"v {pt[0]} {pt[1]}\n" 28 | vertex_strings += vertex_string 29 | return vertex_strings 30 | 31 | 32 | def convert_curves(self, faces): 33 | curve_strings = "" 34 | total_curve = 0 35 | 36 | # Faces (multiple closed regions) 37 | for group_idx, loops in enumerate(faces): 38 | curve_strings += f"\nface\n" 39 | # Multiple loops (inner and outer) 40 | for loop in loops: 41 | if loop[0].is_outer: 42 | curve_strings += f"out\n" 43 | else: 44 | curve_strings += f"in\n" 45 | # All curves in one loop 46 | for curve in loop: 47 | total_curve += 1 48 | if curve.type == 'line': 49 | curve_strings += f"l {curve.start_idx} {curve.end_idx}\n" 50 | elif curve.type == 'circle': 51 | curve_strings += f"c {curve.center_idx} {curve.radius_idx}\n" 52 | elif curve.type == 'arc': 53 | curve_strings += f"a {curve.start_idx} {curve.mid_idx} {curve.center_idx} {curve.end_idx}\n" 54 | 55 | return curve_strings, total_curve 56 | 57 | 58 | def parse3d(self, point3d): 59 | x = point3d[0] 60 | y = point3d[1] 61 | z = point3d[2] 62 | return str(x)+' '+str(y)+' '+str(z) 63 | 64 | 65 | def write_obj2(self, file, vertices, faces, meta_info, scale=None): 66 | """ Write to .obj file """ 67 | vertex_strings = self.convert_vertices(vertices) 68 | curve_strings, total_curve = self.convert_curves(faces) 69 | 70 | with open(file, "w") as fh: 71 | # Write Meta info 72 | fh.write("# WaveFront *.obj file\n") 73 | fh.write(f"# Vertices: {len(vertices)}\n") 74 | fh.write(f"# Curves: {total_curve}\n") 75 | fh.write("\n") 76 | 77 | # Write vertex and curve 78 | fh.write(vertex_strings) 79 | fh.write("\n") 80 | fh.write(curve_strings) 81 | fh.write("\n") 82 | 83 | #Write extrude value 84 | fh.write("ExtrudeOperation: " + meta_info['set_op']+"\n") 85 | extrude_string = 'Extrude ' 86 | for value in meta_info['extrude_value']: 87 | extrude_string += str(value)+' ' 88 | fh.write(extrude_string) 89 | fh.write("\n") 90 | 91 | #Write refe plane transformation 92 | p_orig = self.parse3d(meta_info['t_orig']) 93 | x_axis = self.parse3d(meta_info['t_x']) 94 | y_axis = self.parse3d(meta_info['t_y']) 95 | z_axis = self.parse3d(meta_info['t_z']) 96 | fh.write('T_origin '+p_orig) 97 | fh.write("\n") 98 | fh.write('T_xaxis '+x_axis) 99 | fh.write("\n") 100 | fh.write('T_yaxis '+y_axis) 101 | fh.write("\n") 102 | fh.write('T_zaxis '+z_axis) 103 | fh.write("\n") 104 | 105 | # Normalized object 106 | if scale is not None: 107 | fh.write('Scale '+str(scale)) 108 | 109 | 110 | def write_obj(self, file, curve_strings, total_curve, vertex_strings, total_v, meta_info, scale=None): 111 | """ Write to .obj file """ 112 | #vertex_strings = self.convert_vertices(vertices) 113 | #curve_strings, total_curve = self.convert_curves(faces) 114 | 115 | with open(file, "w") as fh: 116 | # Write Meta info 117 | fh.write("# WaveFront *.obj file\n") 118 | fh.write(f"# Vertices: {total_v}\n") 119 | fh.write(f"# Curves: {total_curve}\n") 120 | fh.write("\n") 121 | 122 | # Write vertex and curve 123 | fh.write(vertex_strings) 124 | fh.write("\n") 125 | fh.write(curve_strings) 126 | fh.write("\n") 127 | 128 | #Write extrude value 129 | fh.write("ExtrudeOperation: " + meta_info['set_op']+"\n") 130 | extrude_string = 'Extrude ' 131 | for value in meta_info['extrude_value']: 132 | extrude_string += str(value)+' ' 133 | fh.write(extrude_string) 134 | fh.write("\n") 135 | 136 | #Write refe plane transformation 137 | p_orig = self.parse3d(meta_info['t_orig']) 138 | x_axis = self.parse3d(meta_info['t_x']) 139 | y_axis = 
self.parse3d(meta_info['t_y']) 140 | z_axis = self.parse3d(meta_info['t_z']) 141 | fh.write('T_origin '+p_orig) 142 | fh.write("\n") 143 | fh.write('T_xaxis '+x_axis) 144 | fh.write("\n") 145 | fh.write('T_yaxis '+y_axis) 146 | fh.write("\n") 147 | fh.write('T_zaxis '+z_axis) 148 | fh.write("\n") 149 | 150 | # Normalized object 151 | if scale is not None: 152 | fh.write('Scale '+str(scale)) 153 | 154 | 155 | def parse_file(self, scale=1.0): 156 | """ 157 | Parse obj file 158 | Return 159 | vertex 2D location numpy 160 | curve list (geometry class) 161 | extrude parameters 162 | """ 163 | 164 | assert self.pathname is not None, "File is None" 165 | assert self.pathname.exists(), "No such file" 166 | 167 | # Parse file 168 | vertex_list = [] 169 | loops = [] 170 | closed_loop = [] 171 | 172 | # Read vertice 173 | with open(self.pathname) as obj_file: 174 | for line in obj_file: 175 | tokens = line.split() 176 | if not tokens: 177 | continue 178 | line_type = tokens[0] 179 | # Vertex 180 | if line_type == "v": 181 | vertex_list.append([float(x) for x in tokens[1:]]) 182 | vertices = np.array(vertex_list, dtype=np.float64) * scale 183 | 184 | # Read curves 185 | faces = [] 186 | loops = [] 187 | loop = [] 188 | 189 | # Read in all lines 190 | lines = [] 191 | with open(self.pathname) as obj_file: 192 | for line in obj_file: 193 | lines.append(line) 194 | 195 | # Parse all lines 196 | faces = [] 197 | for str_idx, line in enumerate(lines): 198 | tokens = line.split() 199 | if not tokens: 200 | continue 201 | line_type = tokens[0] 202 | 203 | # Start of a new face 204 | if line_type == "face": 205 | faces.append(self.read_face(lines, str_idx+1, vertices)) 206 | 207 | # Read meta data 208 | meta_data = line.strip('# ').strip(' \n').split(' ') 209 | meta_name = meta_data[0] 210 | 211 | if meta_name == 'Extrude': 212 | extrude_values = [float(x) for x in meta_data[1:]] 213 | extrude_values = [x*scale for x in extrude_values] 214 | elif meta_name == 'T_origin': 215 | t_orig = [float(x) for x in meta_data[1:]] 216 | t_orig = [x*scale for x in t_orig] 217 | elif meta_name == 'T_xaxis': 218 | t_x = [float(x) for x in meta_data[1:]] 219 | elif meta_name == 'T_yaxis': 220 | t_y = [float(x) for x in meta_data[1:]] 221 | elif meta_name == 'T_zaxis': 222 | t_z = [float(x) for x in meta_data[1:]] 223 | elif meta_name == 'ExtrudeOperation:': 224 | set_op = meta_data[1] 225 | 226 | meta_info = {'extrude_value': extrude_values, 227 | 'set_op': set_op, 228 | 't_orig': t_orig, 229 | 't_x': t_x, 230 | 't_y': t_y, 231 | 't_z': t_z, 232 | } 233 | 234 | return vertices, faces, meta_info 235 | 236 | 237 | 238 | def read_face(self, lines, str_idx, vertices): 239 | loops = [] 240 | loop = [] 241 | for line in lines[str_idx:]: 242 | tokens = line.split() 243 | if not tokens: 244 | continue 245 | line_type = tokens[0] 246 | 247 | if line_type == 'face': 248 | break 249 | 250 | # Start of a new loop 251 | if line_type == "out" or line_type == "in": 252 | if len(loop) > 0: 253 | loops.append(loop) 254 | loop = [] 255 | is_outer = (line_type == 'out') 256 | 257 | # Line 258 | if line_type == 'l': 259 | c_tok = tokens[1:] 260 | curve = Line([int(c_tok[0]), int(c_tok[1])], vertices, is_outer=is_outer) 261 | loop.append(curve) 262 | 263 | # Arc 264 | if line_type == 'a': 265 | c_tok = tokens[1:] 266 | curve = Arc([int(c_tok[0]), int(c_tok[1]), int(c_tok[2]), int(c_tok[3])], vertices, is_outer=is_outer) 267 | loop.append(curve) 268 | 269 | # Circle 270 | if line_type == 'c': 271 | c_tok = tokens[1:] 272 | curve = 
Circle([int(c_tok[0]), int(c_tok[1])], vertices, is_outer=is_outer) 273 | loop.append(curve) 274 | 275 | loops.append(loop) 276 | return loops 277 | -------------------------------------------------------------------------------- /utils/eval_cad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import os 4 | import numpy as np 5 | from tqdm import tqdm 6 | import random 7 | import warnings 8 | from glob import glob 9 | from scipy.stats import entropy 10 | from sklearn.neighbors import NearestNeighbors 11 | from plyfile import PlyData 12 | from pathlib import Path 13 | from multiprocessing import Pool 14 | from chamfer_distance import ChamferDistance 15 | 16 | N_POINTS = 2000 17 | NUM_TRHEADS = 36 18 | 19 | 20 | def find_files(folder, extension): 21 | return sorted([Path(os.path.join(folder, f)) for f in os.listdir(folder) if f.endswith(extension)]) 22 | 23 | 24 | def read_ply(path): 25 | with open(path, 'rb') as f: 26 | plydata = PlyData.read(f) 27 | x = np.array(plydata['vertex']['x']) 28 | y = np.array(plydata['vertex']['y']) 29 | z = np.array(plydata['vertex']['z']) 30 | vertex = np.stack([x, y, z], axis=1) 31 | return vertex 32 | 33 | 34 | def _pairwise_CD(sample_pcs, ref_pcs, batch_size): 35 | N_sample = sample_pcs.shape[0] 36 | N_ref = ref_pcs.shape[0] 37 | all_cd = [] 38 | chamfer_dist = ChamferDistance() 39 | pbar = tqdm(range(N_sample), desc="Computing Chamfer Distances") 40 | 41 | for i in pbar: 42 | sample = sample_pcs[i] # (N, 3) 43 | cd_list = [] 44 | for j in range(0, N_ref, batch_size): 45 | ref_batch = ref_pcs[j:min(j + batch_size, N_ref)] # (B, N, 3) 46 | bsz = ref_batch.size(0) 47 | sample_expand = sample.unsqueeze(0).expand(bsz, -1, -1).contiguous() # (B, N, 3) 48 | 49 | dl, dr, _, _ = chamfer_dist(sample_expand, ref_batch) 50 | cd = dl.mean(dim=1) + dr.mean(dim=1) # (B,) 51 | cd_list.append(cd) 52 | 53 | cd_list = torch.cat(cd_list, dim=0) # (N_ref,) 54 | all_cd.append(cd_list.unsqueeze(0)) # (1, N_ref) 55 | 56 | all_cd = torch.cat(all_cd, dim=0) # (N_sample, N_ref) 57 | return all_cd 58 | 59 | 60 | def compute_avg_cd(sample_pcs, ref_pcs, batch_size): 61 | all_dist = _pairwise_CD(sample_pcs, ref_pcs, batch_size) # (N_sample, N_ref) 62 | min_cd_per_sample, _ = torch.min(all_dist, dim=1) # (N_sample,) 63 | 64 | avg_cd = min_cd_per_sample.mean() 65 | median_cd = min_cd_per_sample.median() 66 | 67 | return { 68 | 'Avg-CD': avg_cd.item(), 69 | 'Median-CD': median_cd.item() 70 | } 71 | 72 | 73 | def jsd_between_point_cloud_sets(sample_pcs, ref_pcs, in_unit_sphere, resolution=28): 74 | sample_grid_var = entropy_of_occupancy_grid(sample_pcs, resolution, in_unit_sphere)[1] 75 | ref_grid_var = entropy_of_occupancy_grid(ref_pcs, resolution, in_unit_sphere)[1] 76 | return jensen_shannon_divergence(sample_grid_var, ref_grid_var) 77 | 78 | 79 | def entropy_of_occupancy_grid(pclouds, grid_resolution, in_sphere=False): 80 | epsilon = 10e-4 81 | bound = 1 + epsilon 82 | if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound: 83 | warnings.warn('Point-clouds are not in unit cube.') 84 | 85 | if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound: 86 | warnings.warn('Point-clouds are not in unit sphere.') 87 | 88 | grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere) 89 | grid_coordinates = grid_coordinates.reshape(-1, 3) 90 | grid_counters = np.zeros(len(grid_coordinates)) 91 | grid_bernoulli_rvars = np.zeros(len(grid_coordinates)) 92 | nn = 
NearestNeighbors(n_neighbors=1).fit(grid_coordinates) 93 | 94 | for pc in pclouds: 95 | _, indices = nn.kneighbors(pc) 96 | indices = np.squeeze(indices) 97 | for i in indices: 98 | grid_counters[i] += 1 99 | indices = np.unique(indices) 100 | for i in indices: 101 | grid_bernoulli_rvars[i] += 1 102 | 103 | acc_entropy = 0.0 104 | n = float(len(pclouds)) 105 | for g in grid_bernoulli_rvars: 106 | if g > 0: 107 | p = float(g) / n 108 | acc_entropy += entropy([p, 1.0 - p]) 109 | 110 | return acc_entropy / len(grid_counters), grid_counters 111 | 112 | 113 | def unit_cube_grid_point_cloud(resolution, clip_sphere=False): 114 | grid = np.ndarray((resolution, resolution, resolution, 3), np.float32) 115 | spacing = 1.0 / float(resolution - 1) * 2 116 | for i in range(resolution): 117 | for j in range(resolution): 118 | for k in range(resolution): 119 | grid[i, j, k, 0] = i * spacing - 0.5 * 2 120 | grid[i, j, k, 1] = j * spacing - 0.5 * 2 121 | grid[i, j, k, 2] = k * spacing - 0.5 * 2 122 | 123 | if clip_sphere: 124 | grid = grid.reshape(-1, 3) 125 | grid = grid[np.linalg.norm(grid, axis=1) <= 0.5] 126 | 127 | return grid, spacing 128 | 129 | 130 | def jensen_shannon_divergence(P, Q): 131 | if np.any(P < 0) or np.any(Q < 0): 132 | raise ValueError('Negative values.') 133 | if len(P) != len(Q): 134 | raise ValueError('Non equal size.') 135 | 136 | P_ = P / np.sum(P) 137 | Q_ = Q / np.sum(Q) 138 | 139 | e1 = entropy(P_, base=2) 140 | e2 = entropy(Q_, base=2) 141 | e_sum = entropy((P_ + Q_) / 2.0, base=2) 142 | return e_sum - ((e1 + e2) / 2.0) 143 | 144 | 145 | def downsample_pc(points, n): 146 | sample_idx = random.sample(list(range(points.shape[0])), n) 147 | return points[sample_idx] 148 | 149 | 150 | def normalize_pc(points): 151 | scale = np.max(np.abs(points)) 152 | points = points / scale 153 | return points 154 | 155 | 156 | def collect_pc(cad_folder): 157 | pc_path = find_files(os.path.join(cad_folder, 'pcd'), 'final_pcd.ply') 158 | if len(pc_path) == 0: 159 | return [] 160 | pc_path = pc_path[-1] 161 | pc = read_ply(pc_path) 162 | if pc.shape[0] > N_POINTS: 163 | pc = downsample_pc(pc, N_POINTS) 164 | pc = normalize_pc(pc) 165 | return pc 166 | 167 | 168 | def compute_pairwise_cd(sample_pcs, ref_pcs, batch_size): 169 | """ 170 | Compute one-to-one Chamfer Distance between sample and reference point clouds. 171 | Assumes sample_pcs and ref_pcs have the same number of point clouds and correspond to each other. 
172 | 173 | Args: 174 | sample_pcs: Generated/reconstructed point clouds (N, num_points, 3) 175 | ref_pcs: Ground truth point clouds (N, num_points, 3) 176 | batch_size: Batch size for computation 177 | 178 | Returns: 179 | Dictionary with CD metrics 180 | """ 181 | assert sample_pcs.shape[0] == ref_pcs.shape[0], "Sample and reference must have same number of point clouds" 182 | 183 | N = sample_pcs.shape[0] 184 | chamfer_dist = ChamferDistance() 185 | all_cd = [] 186 | 187 | pbar = tqdm(range(0, N, batch_size), desc="Computing Pairwise Chamfer Distances") 188 | 189 | for i in pbar: 190 | end_idx = min(i + batch_size, N) 191 | sample_batch = sample_pcs[i:end_idx] # (B, num_points, 3) 192 | ref_batch = ref_pcs[i:end_idx] # (B, num_points, 3) 193 | 194 | dl, dr, _, _ = chamfer_dist(sample_batch, ref_batch) 195 | cd = dl.mean(dim=1) + dr.mean(dim=1) # (B,) 196 | all_cd.append(cd) 197 | 198 | all_cd = torch.cat(all_cd, dim=0) # (N,) 199 | 200 | avg_cd = all_cd.mean() 201 | median_cd = all_cd.median() 202 | std_cd = all_cd.std() 203 | 204 | return { 205 | 'Pairwise-Avg-CD': avg_cd.item(), 206 | 'Pairwise-Median-CD': median_cd.item(), 207 | 'Pairwise-Std-CD': std_cd.item(), 208 | 'All-CD': all_cd.cpu().numpy() # Return all individual CD values for further analysis 209 | } 210 | 211 | 212 | def main(): 213 | parser = argparse.ArgumentParser() 214 | parser.add_argument("--fake", type=str) 215 | parser.add_argument("--real", type=str) 216 | parser.add_argument("--output", type=str) 217 | parser.add_argument("--n_test", type=int, default=1988) 218 | parser.add_argument("--multi", type=int, default=3) 219 | parser.add_argument("--times", type=int, default=3) 220 | parser.add_argument("--batch_size", type=int, default=64) 221 | args = parser.parse_args() 222 | 223 | print("n_test: {}, multiplier: {}, repeat times: {}".format(args.n_test, args.multi, args.times)) 224 | if args.output is None: 225 | args.output = args.fake + '_cad_results.txt' 226 | 227 | # Load reference point clouds 228 | ref_pcs = [] 229 | project_folders = sorted(glob(args.real + '/*/')) 230 | load_iter = Pool(NUM_TRHEADS).imap(collect_pc, project_folders) 231 | for pc in tqdm(load_iter, total=len(project_folders), desc="Loading real point clouds"): 232 | if len(pc) > 0: 233 | ref_pcs.append(pc) 234 | ref_pcs = np.stack(ref_pcs, axis=0) 235 | print("Loaded real point clouds:", ref_pcs.shape) 236 | 237 | # Load generated point clouds 238 | sample_pcs = [] 239 | project_folders = sorted(glob(args.fake + '/*/')) 240 | load_iter = Pool(NUM_TRHEADS).imap(collect_pc, project_folders) 241 | for pc in tqdm(load_iter, total=len(project_folders), desc="Loading fake point clouds"): 242 | if len(pc) > 0: 243 | sample_pcs.append(pc) 244 | sample_pcs = np.stack(sample_pcs, axis=0) 245 | print("Loaded fake point clouds:", sample_pcs.shape) 246 | 247 | # Evaluation 248 | fp = open(args.output, "w") 249 | result_list = [] 250 | for i in range(args.times): 251 | print(f"Iteration {i}...") 252 | select_idx = random.sample(list(range(len(sample_pcs))), int(args.multi * args.n_test)) 253 | rand_sample_pcs = sample_pcs[select_idx] 254 | 255 | select_idx = random.sample(list(range(len(ref_pcs))), args.n_test) 256 | rand_ref_pcs = ref_pcs[select_idx] 257 | 258 | jsd = jsd_between_point_cloud_sets(rand_sample_pcs, rand_ref_pcs, in_unit_sphere=False) 259 | 260 | with torch.no_grad(): 261 | rand_sample_pcs = torch.tensor(rand_sample_pcs).cuda() 262 | rand_ref_pcs = torch.tensor(rand_ref_pcs).cuda() 263 | result = compute_avg_cd(rand_sample_pcs, rand_ref_pcs, 
batch_size=args.batch_size) 264 | 265 | result.update({"JSD": jsd}) 266 | 267 | print(result) 268 | print(result, file=fp) 269 | result_list.append(result) 270 | 271 | # Average results 272 | avg_result = {} 273 | for k in result_list[0].keys(): 274 | avg_result["avg-" + k] = np.mean([x[k] for x in result_list]) 275 | print("Average result:") 276 | print(avg_result) 277 | print(avg_result, file=fp) 278 | fp.close() 279 | 280 | 281 | if __name__ == '__main__': 282 | main() 283 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CAD-Editor 2 | 3 | Official implementation of **[ICML 2025] CAD-Editor: A Locate-then-Infill Framework with Automated Training Data Synthesis for Text-Based CAD Editing** by *Yu Yuan, Shizhao Sun, Qi Liu, Jiang Bian*. 4 | 5 | 📄 [Paper](https://arxiv.org/abs/2502.03997) | 🤗 [Model](https://huggingface.co/microsoft/CAD-Editor) | 🏠 [Project Page](https://cad-editor.github.io) 6 | 7 | 8 | ## Installation 9 | 10 | ```bash 11 | conda env create -f environment.yaml 12 | conda activate cad-editor 13 | ``` 14 | 15 | ## Data preparation 16 | 17 | We provide the complete data generation pipeline below for those who wish to generate their own dataset. 18 | We also share the data processed by us under `data/processed.zip`. 19 | 20 | ### 1. Paired CAD Generation 21 | 22 | **Step 1**: Generate design variations using hnc-cad. 23 | 24 | - Clone the [hnc-cad](https://github.com/samxuxiang/hnc-cad) repo. 25 | - Replace `gen/ac_gen.py` in the cloned hnc-cad repo with `hnc-cad/ac_gen.py` from this repo. Our updated version includes CAD model IDs (i.e., picture names) for pairing. 26 | - Follow the steps in the [hnc-cad](https://github.com/samxuxiang/hnc-cad) repo (especially `scripts/sample_cond.sh`) to generate design variations of a CAD model. 27 | 28 | 29 | **Step 2**: Convert generated `.obj` files to CAD sequences: 30 | 31 | ```python 32 | # Under utils folder: 33 | # Parse obj to primitive sequence 34 | python parse_obj2seq.py --input data \ 35 | --output data/dataset/train.pkl \ 36 | --bit 6 37 | 38 | # Convert to our sequence format 39 | python convert.py --in_path data/dataset/train.pkl \ 40 | --out_path data/dataset/train_converted.json 41 | ``` 42 | 43 | **Step 3**: Pair CAD sequences: 44 | 45 | ```python 46 | python data/pair.py --in_path data/dataset/train_converted.json \ 47 | --out_path data/dataset/train_converted_pair.json 48 | ``` 49 | 50 | ### 2. Editing Instruction Generation 51 | 52 | **Visual Level** 53 | 54 | (1) Render CAD objects to images. 55 | 56 | ```python 57 | timeout 180 python utils/visual_obj.py --data_folder 58 | 59 | python utils/cad_img.py --input_dir \ 60 | --output_dir 61 | ``` 62 | 63 | (2) Generate captions. Please update the OpenAI endpoint information in `data/caption_image.py` before running. 64 | 65 | ```python 66 | python data/caption_image.py --sequence_dir data/dataset/train_converted_pair.json \ 67 | --image_dir data/dataset/train_img \ 68 | --caption_path data/dataset/train_caption_image.json 69 | ``` 70 | 71 | **Sequence Level** 72 | 73 | (1) Generate captions. 74 | 75 | ```python 76 | python data/caption_sequence.py --in_path data/dataset/train_converted_pair_2.json \ 77 | --out_path data/dataset/train_caption_sequence.json 78 | ``` 79 | 80 | ### 3. Merge and filter long-tailed sequences. 
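For reference, a merged entry pairs the two CAD sequences with the caption produced at either the visual or the sequence level. A minimal sketch of the expected structure is shown below (field names follow `data/caption_image.py` and `data/caption_sequence.py`; the sequence strings and the instruction are illustrative placeholders, not real data):

```python
# Illustrative only: one merged training entry (all values are placeholders).
merged_entry = {
    "original_sequence": "<flattened sketch-and-extrude token sequence>",
    "edited_sequence": "<flattened sketch-and-extrude token sequence>",
    "instruction": "Add a central circular hole",  # concise editing instruction (10 words max)
    "method": "image",  # "image" (visual-level) or "sequence" (sequence-level) captioning
    # Image-captioned entries additionally carry fields such as "original_pic_name" and "edited_pic_name".
}
```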
81 | 82 | ```python 83 | python data/merge.py --file1 data/dataset/train_caption_image.json \ 84 | --file2 data/dataset/train_caption_sequence.json \ 85 | --output data/dataset/train_all.json 86 | 87 | python data/filter_sequence.py --in_path data/dataset/train_all.json \ 88 | --out_path data/dataset/train.json 89 | ``` 90 | 91 | 92 | 93 | ## Training 94 | 95 | ### 1. Locating Stage 96 | 97 | **Step 1**: Create ground-truth masked CAD sequences: 98 | 99 | ```python 100 | # All training and inference are performed under the finetune folder: 101 | python create_mask.py --input_path \ 102 | --output_path 103 | ``` 104 | 105 | **Step 2**: Run locate training with multiple GPUs. Change `num_processes` in `ds_config.yaml` to specify how many GPUs will be used. 106 | 107 | ```python 108 | CUDA_VISIBLE_DEVICES= accelerate launch --config_file ds_config.yaml llama_finetune.py --task_type mask \ 109 | --run_name \ 110 | --data_folder \ 111 | --eval_freq 1000000 \ 112 | --save_freq 10000 113 | ``` 114 | 115 | ### 2. Infilling Stage 116 | 117 | **Step 1**: Train infilling model: 118 | 119 | ```python 120 | CUDA_VISIBLE_DEVICES= accelerate launch --config_file ds_config.yaml finetune/llama_finetune.py --task_type infill \ 121 | --run_name \ 122 | --data_folder \ 123 | --eval_freq 1000000 \ 124 | --save_freq 10000 125 | ``` 126 | 127 | **Step 2.** Enhanced training with selective data. Set `model_path` to the pretrained model from Step 1. Change `data_folder` to the folder of your selective data. 128 | 129 | ```python 130 | CUDA_VISIBLE_DEVICES= accelerate launch --config_file ds_config.yaml finetune/llama_finetune.py --task_type infill_selective \ 131 | --run_name \ 132 | --pretrained_model_path \ 133 | --data_folder \ 134 | --eval_freq 1000000 \ 135 | --save_freq 10000 136 | ``` 137 | 138 | ## Inference 139 | 140 | Download our trained model checkpoints from [HuggingFace](https://huggingface.co/microsoft/CAD-Editor) to your ``````. 141 | 142 | ### 1. Locating Stage 143 | 144 | Generate masked sequences. Set the `` as ``. Set the `` as the path of `test.json` after unzip `data/processed.zip`. 145 | 146 | ```python 147 | CUDA_VISIBLE_DEVICES= python llama_sample.py \ 148 | --task_type mask \ 149 | --model_path \ 150 | --data_path \ 151 | --out_path \ 152 | --num_samples 153 | ``` 154 | 155 | ### 2. Infilling Stage 156 | 157 | Generate final edited CAD sequences. Set the `` as ``. Set the `` the same as the `out_path` of the locating stage. 158 | 159 | ```python 160 | CUDA_VISIBLE_DEVICES= python llama_sample.py \ 161 | --task_type infill \ 162 | --model_path \ 163 | --data_path \ 164 | --out_path \ 165 | --num_samples 166 | ``` 167 | 168 | ## Evaluation 169 | 170 | - Validity. 171 | 172 | ```python 173 | # Step 1: Parse the generated string to CAD obj. The in_path should be set the same as the out_path in the inference. 174 | python utils/parse_seq2obj.py --in_path \ 175 | --out_path \ 176 | --type infill 177 | 178 | # Step 2: Convert generated CAD obj to stl format. Use timeout command to prevent occ hanging. The data_folder should be set the same as the out_path in Step 1. 179 | timeout 180 python utils/visual_obj.py --data_folder 180 | 181 | # Step 3: Render and visualize to images. The input_dir should be set the same as the data_folder in Step 2. Use the number of successful generated images here to calculate the validity. 182 | python utils/cad_img.py --input_dir \ 183 | --output_dir 184 | ``` 185 | 186 | - 3D metrics (after running `visual_obj.py`). 
187 | 188 | ```python 189 | # Under utils folder: 190 | # Uniformly sample points. Note that the generated CAD models and the ground truth test CAD models should be sampled respectively. 191 | python sample_points.py --in_dir \ 192 | --out_dir pcd 193 | 194 | # Evaluate performance. 195 | python eval_cad.py --fake \ 196 | --real 197 | ``` 198 | 199 | - Directional Clip Score ( Ensure you have run `cad_img.py` to render both the original and edited CAD sequences). 200 | 201 | ```python 202 | python eval_dclip.py --source_dir \ 203 | --edit_dir \ 204 | --instruction_path \ 205 | --out_path 206 | ``` 207 | 208 | ## Prompting-based Baselines 209 | 210 | We provide implementations of prompting-based baselines (including zero-shot and fewshot GPT-4o) under the `prompt/` folder. 211 | 212 | 213 | ## Citation 214 | 215 | If you find our work useful, please cite the following paper: 216 | 217 | ``` 218 | @article{yuan2025cad, 219 | title={CAD-Editor: A Locate-then-Infill Framework with Automated Training Data Synthesis for Text-Based CAD Editing}, 220 | author={Yuan, Yu and Sun, Shizhao and Liu, Qi and Bian, Jiang}, 221 | journal={Forty-Second International Conference on Machine Learning}, 222 | year={2025} 223 | } 224 | ``` 225 | 226 | ## Acknowledgements 227 | 228 | We would like to thank and acknowledge referenced codes from [hnc-cad](https://github.com/samxuxiang/hnc-cad), [SkexGen](https://github.com/samxuxiang/SkexGen) and [StyleGAN-nada](https://github.com/rinongal/StyleGAN-nada). 229 | 230 | 231 | ## Contributing 232 | 233 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 234 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 235 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 236 | 237 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 238 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 239 | provided by the bot. You will only need to do this once across all repos using our CLA. 240 | 241 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 242 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 243 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 244 | 245 | ## Trademarks 246 | 247 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 248 | trademarks or logos is subject to and must follow 249 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 250 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 251 | Any use of third-party trademarks or logos are subject to those third-party's policies. 
-------------------------------------------------------------------------------- /data/caption_image.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import base64 4 | import json 5 | import time 6 | import argparse 7 | from mimetypes import guess_type 8 | from openai import AzureOpenAI 9 | from azure.identity import DefaultAzureCredential, get_bearer_token_provider 10 | 11 | 12 | def local_image_to_data_url(image_path): 13 | # Encode a local image into data URL 14 | mime_type, _ = guess_type(image_path) 15 | if mime_type is None: 16 | mime_type = 'application/octet-stream' 17 | with open(image_path, "rb") as image_file: 18 | base64_encoded_data = base64.b64encode(image_file.read()).decode('utf-8') 19 | return f"data:{mime_type};base64,{base64_encoded_data}" 20 | 21 | # Configuration - IMPORTANT: Update these values before running 22 | # Replace with your actual Azure OpenAI endpoint 23 | endpoint = 'https://your-azure-openai-endpoint.openai.azure.com/' 24 | token_provider = get_bearer_token_provider( 25 | DefaultAzureCredential(), 26 | "https://cognitiveservices.azure.com/.default" 27 | ) 28 | deployment_name = 'gpt-4o' 29 | 30 | client = AzureOpenAI( 31 | azure_ad_token_provider=token_provider, 32 | azure_endpoint=endpoint, 33 | api_version='2024-02-01' 34 | ) 35 | 36 | def call_gpt4o_1(prompt, image_path): 37 | output = None 38 | message_text = [ 39 | {"role":"system","content":"You are an AI assistant that helps people find information."}, 40 | {"role":"user","content":[ 41 | { 42 | "type": "text", 43 | "text": prompt 44 | }, 45 | { 46 | "type": "image_url", 47 | "image_url": {"url": local_image_to_data_url(image_path)} 48 | } 49 | ]} 50 | ] 51 | while output is None: 52 | try: 53 | time.sleep(0.5) 54 | completion = client.chat.completions.create( 55 | model = deployment_name, 56 | messages = message_text, 57 | ) 58 | output = completion.choices[0].message.content 59 | except Exception as e: 60 | print("API call exceptions:", e) 61 | output = None 62 | return output 63 | 64 | def call_gpt4o_2(prompt1, image_path1, output1, prompt2, image_path2): 65 | output = None 66 | message_text = [ 67 | {"role":"system","content":"You are an AI assistant that helps people find information."}, 68 | {"role":"user","content":[ 69 | { 70 | "type": "text", 71 | "text": prompt1 72 | }, 73 | { 74 | "type": "image_url", 75 | "image_url": {"url": local_image_to_data_url(image_path1)} 76 | } 77 | ]}, 78 | {"role":"assistant","content":output1}, 79 | {"role":"user","content":[ 80 | { 81 | "type": "text", 82 | "text": prompt2 83 | }, 84 | { 85 | "type": "image_url", 86 | "image_url": {"url": local_image_to_data_url(image_path2)} 87 | } 88 | ]} 89 | ] 90 | while output is None: 91 | try: 92 | time.sleep(0.5) 93 | completion = client.chat.completions.create( 94 | model = deployment_name, 95 | messages = message_text, 96 | ) 97 | output = completion.choices[0].message.content 98 | except Exception as e: 99 | print("API call exceptions:", e) 100 | time.sleep(1) 101 | output = None 102 | return output 103 | 104 | 105 | def call_gpt4o_3(prompt1, image_path1, output1, prompt2, image_path2, output2, prompt3): 106 | output = None 107 | message_text = [ 108 | {"role":"system","content":"You are an AI assistant that helps people find information."}, 109 | {"role":"user","content":[ 110 | { 111 | "type": "text", 112 | "text": prompt1 113 | }, 114 | { 115 | "type": "image_url", 116 | "image_url": {"url": local_image_to_data_url(image_path1)} 117 | } 
118 | ]}, 119 | {"role":"assistant","content":output1}, 120 | {"role":"user","content":[ 121 | { 122 | "type": "text", 123 | "text": prompt2 124 | }, 125 | { 126 | "type": "image_url", 127 | "image_url": {"url": local_image_to_data_url(image_path2)} 128 | } 129 | ]}, 130 | {"role":"assistant","content":output2}, 131 | {"role":"user","content":prompt3} 132 | ] 133 | while output is None: 134 | try: 135 | time.sleep(0.5) 136 | completion = client.chat.completions.create( 137 | model = deployment_name, 138 | messages = message_text, 139 | ) 140 | output = completion.choices[0].message.content 141 | except Exception as e: 142 | print("API call exceptions:", e) 143 | time.sleep(1) 144 | output = None 145 | return output 146 | 147 | 148 | def call_gpt4o_4(prompt1, image_path1, output1, prompt2, image_path2, output2, prompt3, output3, prompt4): 149 | output = None 150 | message_text = [ 151 | {"role":"system","content":"You are an AI assistant that helps people find information."}, 152 | {"role":"user","content":[ 153 | { 154 | "type": "text", 155 | "text": prompt1 156 | }, 157 | { 158 | "type": "image_url", 159 | "image_url": {"url": local_image_to_data_url(image_path1)} 160 | } 161 | ]}, 162 | {"role":"assistant","content":output1}, 163 | {"role":"user","content":[ 164 | { 165 | "type": "text", 166 | "text": prompt2 167 | }, 168 | { 169 | "type": "image_url", 170 | "image_url": {"url": local_image_to_data_url(image_path2)} 171 | } 172 | ]}, 173 | {"role":"assistant","content":output2}, 174 | {"role":"user","content":prompt3}, 175 | {"role":"assistant","content":output3}, 176 | {"role":"user","content":prompt4} 177 | ] 178 | while output is None: 179 | try: 180 | time.sleep(0.5) 181 | completion = client.chat.completions.create( 182 | model = deployment_name, 183 | messages = message_text, 184 | ) 185 | output = completion.choices[0].message.content 186 | except Exception as e: 187 | print("API call exceptions:", e) 188 | time.sleep(1) 189 | output = None 190 | return output 191 | 192 | 193 | def find_files_with_prefix(directory, prefix): 194 | files = sorted(os.listdir(directory)) 195 | matching_files = [file_name for file_name in files if file_name.startswith(prefix)] 196 | 197 | return matching_files 198 | 199 | def multi_level_captioning(args): 200 | with open(args.sequence_dir, 'r', encoding='utf-8') as file: 201 | data = json.load(file) 202 | 203 | for idx, item in enumerate(data): 204 | original_prefix = item['original_pic_name']+'_' 205 | original_file = find_files_with_prefix(args.image_dir, original_prefix)[0] 206 | source_path = os.path.join(args.image_dir, original_file) 207 | edit_prefix = item['edited_pic_name']+'_' 208 | edit_file = find_files_with_prefix(args.image_dir, edit_prefix)[0] 209 | edit_path = os.path.join(args.image_dir, edit_file) 210 | 211 | print(original_file, edit_file) 212 | 213 | time.sleep(0.5) 214 | output1 = None 215 | output2 = None 216 | output3 = None 217 | output4 = None 218 | 219 | prompt1 = """Please take a look at the first of two 3D shapes we'll be examining. Please provide a detailed description, focusing on its geometric properties, including the type and number of elements it features, the proportions of its size, its positional relationships between elements, and any additional details that stand out.""" 220 | prompt2 = """Now, let's turn our attention to the second 3D shape. 
Please provide a detailed description, focusing on its geometric properties, including the type and number of elements it features, the proportions of its size, its positional relationships between elements, and any additional details that stand out.""" 221 | prompt3 = """Please provide concise instructions for transforming the first 3D shape into the second. """ 222 | prompt4 = """Condense your instructions to one sentence, 10 words maximum.""" 223 | 224 | while output1 is None or str(output1).startswith("I'm sorry"): 225 | try: 226 | output1 = call_gpt4o_1(prompt1, source_path) 227 | except requests.RequestException as e: 228 | print(f"Request failed: {e}") 229 | time.sleep(0.5) 230 | output1 = None 231 | while output2 is None or str(output2).startswith("I'm sorry"): 232 | try: 233 | output2 = call_gpt4o_2(prompt1, source_path, output1, prompt2, edit_path) 234 | except requests.RequestException as e: 235 | print(f"Request failed: {e}") 236 | time.sleep(0.5) 237 | output2 = None 238 | while output3 is None or str(output3).startswith("I'm sorry"): 239 | try: 240 | output3 = call_gpt4o_3(prompt1, source_path, output1, prompt2, edit_path, output2, prompt3) 241 | except requests.RequestException as e: 242 | print(f"Request failed: {e}") 243 | time.sleep(0.5) 244 | output3 = None 245 | while output4 is None or str(output4).startswith("I'm sorry"): 246 | try: 247 | output4 = call_gpt4o_4(prompt1, source_path, output1, prompt2, edit_path, output2, prompt3, output3, prompt4) 248 | except requests.RequestException as e: 249 | print(f"Request failed: {e}") 250 | time.sleep(0.5) 251 | output4 = None 252 | 253 | result = { 254 | "original_pic_name": original_file, 255 | "edited_pic_name": edit_file, 256 | "original_sequence": item['original_sequence'], 257 | "edited_sequence": item['edited_sequence'], 258 | "type": item["type"], 259 | "method": "image", 260 | "description_original": output1, 261 | "description_edited":output2, 262 | "detailed_instruction": output3, 263 | "instruction":output4, 264 | } 265 | 266 | with open(args.caption_path, 'a', encoding='utf-8') as f: 267 | json.dump(result, f, ensure_ascii=False, indent=4) 268 | f.write(',\n') 269 | 270 | 271 | if __name__ == "__main__": 272 | parser = argparse.ArgumentParser() 273 | parser.add_argument("--sequence_dir", type=str, required=True) 274 | parser.add_argument("--image_dir", type=str, required=True) 275 | parser.add_argument("--caption_path", type=str, required=True) 276 | args = parser.parse_args() 277 | 278 | multi_level_captioning(args) -------------------------------------------------------------------------------- /finetune/llama_finetune.py: -------------------------------------------------------------------------------- 1 | from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel 2 | from torch.utils.data import Dataset 3 | from transformers import Trainer, TrainingArguments 4 | import transformers 5 | from dataclasses import dataclass 6 | import json 7 | import pickle 8 | from pathlib import Path 9 | import numpy as np 10 | import warnings 11 | import random 12 | import torch 13 | import argparse 14 | import glob 15 | import os 16 | # from diffusers.training_utils import cast_training_params 17 | 18 | 19 | IGNORE_INDEX = -100 20 | MAX_LENGTH = 1024 21 | DEFAULT_PAD_TOKEN = "[PAD]" 22 | DEFAULT_EOS_TOKEN = "" 23 | DEFAULT_BOS_TOKEN = "" 24 | DEFAULT_UNK_TOKEN = "" 25 | 26 | 27 | class CADDataset(Dataset): 28 | def __init__(self, json_fn, task_type="mask", llama_tokenizer=None): 29 | if not 
os.path.exists(json_fn): 30 | raise ValueError(f"{json_fn} does not exist") 31 | self.inputs = json.load(open(json_fn, "r")) 32 | self.llama_tokenizer = llama_tokenizer 33 | self.task_type = task_type # 'mask', 'infill', or 'infill_selective' 34 | 35 | def __len__(self): 36 | return len(self.inputs) 37 | 38 | def __getitem__(self, index): 39 | item = self.inputs[index] 40 | src_seq = item['original_sequence'] 41 | mask_seq = item['masked_sequence'] 42 | instruction = item['instruction'] 43 | 44 | # For infilly tasks, also get edited sequence 45 | if (self.task_type == "infill" or self.task_type == "infill_selective") and 'edited_sequence' in item: 46 | edit_seq = item['edited_sequence'] 47 | val = self.tokenize(src_seq, instruction, mask_seq, edit_seq) 48 | else: 49 | val = self.tokenize(src_seq, instruction, mask_seq) 50 | 51 | return val 52 | 53 | def tokenize(self, src_seq, instruction, mask_seq, edit_seq=None): 54 | if self.task_type == "infill" or self.task_type == "infill_selective": 55 | tokens, prompt_length = self.conditional_generation_infill( 56 | src_seq, instruction, mask_seq, edit_seq) 57 | else: 58 | tokens, prompt_length = self.conditional_generation_mask( 59 | src_seq, instruction, mask_seq) 60 | 61 | input_ids = tokens.input_ids[0] 62 | labels = tokens.input_ids[0].clone() 63 | # Set the labels for the prompt part to IGNORE_INDEX so they are ignored in loss calculation 64 | labels[:prompt_length] = IGNORE_INDEX 65 | input_id_lens = label_lens = ( 66 | tokens.input_ids.ne(self.llama_tokenizer.pad_token_id).sum().item() 67 | ) 68 | return dict( 69 | input_ids=input_ids, 70 | input_id_lens=input_id_lens, 71 | labels=labels, 72 | label_lens=label_lens, 73 | ) 74 | 75 | def conditional_generation_mask(self, src_seq, instruction, mask_seq): 76 | prompt = f"""Below is a Computer-Aided Design (CAD) operation sequence, replace the parts that need to be modified with the string "" according to the editing instruction. 77 | Original CAD Operation Sequence: 78 | {src_seq} 79 | Editing Instruction: 80 | {instruction} 81 | Masked CAD Operation Sequence: 82 | """ 83 | 84 | full_text = prompt + mask_seq + self.llama_tokenizer.eos_token 85 | tokens = self.llama_tokenizer( 86 | full_text, 87 | max_length=MAX_LENGTH, 88 | return_tensors="pt", 89 | truncation=True, 90 | ) 91 | prompt_length = len(self.llama_tokenizer(prompt)['input_ids']) 92 | return tokens, prompt_length 93 | 94 | def conditional_generation_infill(self, src_seq, instruction, mask_seq, edit_seq): 95 | prompt = f"""Below is the original Computer-Aided Design (CAD) operation sequence. 96 | Original CAD Operation Sequence: 97 | {src_seq} 98 | 99 | The parts that need to be modified according to the editing instruction have been replaced by the string "". 
100 | Editing Instruction: 101 | {instruction} 102 | Masked CAD Operation Sequence: 103 | {mask_seq} 104 | 105 | Based on the original CAD sequence, the editing instruction, and the masked sequence, generate the complete edited CAD sequence by replacing "" with the appropriate content: 106 | """ 107 | 108 | full_text = prompt + edit_seq + self.llama_tokenizer.eos_token 109 | tokens = self.llama_tokenizer( 110 | full_text, 111 | max_length=MAX_LENGTH, 112 | return_tensors="pt", 113 | truncation=True, 114 | ) 115 | prompt_length = len(self.llama_tokenizer(prompt)['input_ids']) 116 | return tokens, prompt_length 117 | 118 | 119 | @dataclass 120 | class DataCollatorForSupervisedDataset(object): 121 | """Collate examples for supervised fine-tuning.""" 122 | 123 | tokenizer: transformers.PreTrainedTokenizer 124 | 125 | def __call__(self, instances): 126 | input_ids, labels = tuple( 127 | [instance[key].clone().detach() for instance in instances] 128 | for key in ("input_ids", "labels") 129 | ) 130 | input_ids = torch.nn.utils.rnn.pad_sequence( 131 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 132 | ) 133 | labels = torch.nn.utils.rnn.pad_sequence( 134 | labels, batch_first=True, padding_value=IGNORE_INDEX 135 | ) 136 | return dict( 137 | input_ids=input_ids, 138 | labels=labels, 139 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 140 | ) 141 | 142 | 143 | def setup_datasets(args, llama_tokenizer, transform_args={}): 144 | train_file = "train.json" 145 | val_file = "test.json" 146 | 147 | datasets = { 148 | "train": CADDataset( 149 | str(args.data_folder / train_file), 150 | task_type=args.task_type, 151 | llama_tokenizer=llama_tokenizer, 152 | ), 153 | "val": CADDataset( 154 | str(args.data_folder / val_file), 155 | task_type=args.task_type, 156 | llama_tokenizer=llama_tokenizer, 157 | ), 158 | } 159 | 160 | return datasets 161 | 162 | 163 | def setup_training_args(args): 164 | output_dir = args.expdir / args.run_name 165 | output_dir.mkdir(parents=True, exist_ok=True) 166 | 167 | if args.debug: 168 | report_to = "none" 169 | else: 170 | report_to = "wandb" 171 | os.environ["ACCELERATE_MIXED_PRECISION"] = "no" 172 | training_args = TrainingArguments( 173 | fsdp=False, 174 | fp16=not args.fp8, 175 | bf16=False, 176 | gradient_checkpointing=False, 177 | ddp_find_unused_parameters=False, 178 | num_train_epochs=args.num_epochs, 179 | eval_steps=args.eval_freq, 180 | save_steps=args.save_freq, 181 | logging_steps=10, 182 | eval_strategy="steps", # Use modern parameter name 183 | per_device_train_batch_size=args.batch_size, 184 | per_device_eval_batch_size=args.batch_size, 185 | learning_rate=args.lr, 186 | lr_scheduler_type=args.lr_scheduler, 187 | warmup_steps=args.num_warmup_steps, 188 | weight_decay=args.weight_decay, 189 | gradient_accumulation_steps=args.grad_accum, 190 | output_dir=output_dir, 191 | run_name=args.run_name, 192 | report_to=report_to, 193 | dataloader_num_workers=8, 194 | remove_unused_columns=False, 195 | # this is just to get trainer to behave how I want 196 | label_names=["cad_ids"], 197 | ) 198 | return training_args 199 | 200 | 201 | def smart_tokenizer_and_embedding_resize( 202 | special_tokens_dict, 203 | llama_tokenizer, 204 | model, 205 | ): 206 | """Resize tokenizer and embedding. 207 | 208 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
209 | """ 210 | num_new_tokens = llama_tokenizer.add_special_tokens(special_tokens_dict) 211 | model.resize_token_embeddings(len(llama_tokenizer)) 212 | 213 | if num_new_tokens > 0: 214 | input_embeddings = model.get_input_embeddings().weight.data 215 | output_embeddings = model.get_output_embeddings().weight.data 216 | 217 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( 218 | dim=0, keepdim=True 219 | ) 220 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( 221 | dim=0, keepdim=True 222 | ) 223 | 224 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 225 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 226 | 227 | 228 | def setup_model(args, rank): 229 | model_id = "meta-llama/Meta-Llama-3-8B-Instruct" 230 | pipeline = transformers.pipeline("text2text-generation", 231 | model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}) 232 | llama_tokenizer = pipeline.tokenizer 233 | base_model = pipeline.model 234 | 235 | device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu") 236 | base_model.to(device) 237 | 238 | lora_config = LoraConfig( 239 | r=args.lora_rank, 240 | lora_alpha=args.lora_alpha, 241 | lora_dropout=args.lora_dropout, 242 | bias="none", 243 | task_type="CAUSAL_LM", 244 | ) 245 | 246 | # For selective infilly, load a pre-trained infilly model checkpoint 247 | if args.task_type == "infill_selective" and args.pretrained_model_path: 248 | print(f"Loading pre-trained model from {args.pretrained_model_path}") 249 | peft_model = PeftModel.from_pretrained(base_model, args.pretrained_model_path, device_map="auto") 250 | peft_model.to(device) 251 | original_state_dict = {f"{k}": v for k, v in peft_model.state_dict().items()} 252 | 253 | model = get_peft_model(base_model, lora_config) 254 | model.load_state_dict(original_state_dict, strict=True) 255 | else: 256 | # For mask or initial infilly training 257 | model = get_peft_model(base_model, lora_config) 258 | 259 | model.print_trainable_parameters() 260 | 261 | special_tokens_dict = dict() 262 | if llama_tokenizer.pad_token is None: 263 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN 264 | if llama_tokenizer.eos_token is None: 265 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN 266 | if llama_tokenizer.bos_token is None: 267 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN 268 | if llama_tokenizer.unk_token is None: 269 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN 270 | 271 | smart_tokenizer_and_embedding_resize( 272 | special_tokens_dict=special_tokens_dict, 273 | llama_tokenizer=llama_tokenizer, 274 | model=model, 275 | ) 276 | 277 | return model, llama_tokenizer 278 | 279 | 280 | def setup_trainer(args): 281 | training_args = setup_training_args(args) 282 | model, llama_tokenizer = setup_model(args, training_args.local_rank) 283 | 284 | datasets = setup_datasets(args, llama_tokenizer) 285 | 286 | data_collator = DataCollatorForSupervisedDataset( 287 | tokenizer=llama_tokenizer, 288 | ) 289 | 290 | trainer = Trainer( 291 | model=model, 292 | args=training_args, 293 | train_dataset=datasets["train"], 294 | eval_dataset=datasets["val"], 295 | data_collator=data_collator, 296 | ) 297 | 298 | return trainer 299 | 300 | 301 | def main(args): 302 | trainer = setup_trainer(args) 303 | 304 | if args.resume_dir is not None: 305 | train_result = trainer.train(resume_from_checkpoint=args.resume_dir) 306 | else: 307 | train_result = trainer.train() 308 | 309 | print(train_result) 310 | trainer.save_state() 311 | trainer.save_model() 312 
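# ---------------------------------------------------------------------------
# Example invocation (illustrative only; the launcher and concrete values are
# assumptions, while the flags themselves come from the argparse definitions
# in __main__ below). A LoRA "mask" run might look like:
#
#   accelerate launch --config_file <accelerate/DeepSpeed config> \
#       finetune/llama_finetune.py \
#       --task_type mask \
#       --run_name mask-stage \
#       --data_folder /data/dataset/ \
#       --batch_size 4 --lr 1e-4 --num_epochs 100
#
# --data_folder must contain train.json and test.json (see setup_datasets).
# The "infill_selective" stage additionally needs --pretrained_model_path set
# to a previously trained infill checkpoint, which is enforced in __main__.
# ---------------------------------------------------------------------------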
| 313 | 314 | if __name__ == "__main__": 315 | parser = argparse.ArgumentParser() 316 | parser.add_argument("--task_type", type=str, choices=["mask", "infill", "infill_selective"], 317 | default="mask", help="Task type: 'mask' for masking parts, 'infill' for infilly training, 'infill_selective' for selective infilly training") 318 | parser.add_argument("--run_name", type=str, required=True) 319 | parser.add_argument("--pretrained_model_path", type=str, default=None, 320 | help="Path to pretrained model checkpoint (required for infill_selective)") 321 | parser.add_argument("--expdir", type=Path, default="model/") 322 | parser.add_argument("--fp8", action="store_true", default=False) 323 | parser.add_argument("--lora_rank", type=int, default=32) 324 | parser.add_argument("--lora_alpha", type=int, default=32) 325 | parser.add_argument("--lora_dropout", type=float, default=0.05) 326 | parser.add_argument("--data_folder", type=Path, default="/data/dataset/") 327 | parser.add_argument("--num_epochs", type=int, default=100) 328 | parser.add_argument("--batch_size", type=int, default=4) 329 | parser.add_argument("--grad_accum", type=int, default=1) 330 | parser.add_argument("--lr", type=float, default=1e-4) 331 | parser.add_argument("--lr_scheduler", type=str, default="cosine") 332 | parser.add_argument("--num_warmup_steps", type=int, default=100) 333 | parser.add_argument("--weight_decay", type=float, default=0.0) 334 | parser.add_argument("--eval_freq", default=100000000, type=int) 335 | parser.add_argument("--save_freq", default=1000, type=int) 336 | parser.add_argument("--resume_dir", type=Path, default=None) 337 | parser.add_argument("--debug", action="store_true", default=False) 338 | 339 | args = parser.parse_args() 340 | 341 | # Validate arguments 342 | if args.task_type == "infill_selective" and not args.pretrained_model_path: 343 | raise ValueError("For infill_selective training, pretrained_model_path must be provided") 344 | 345 | # Set WANDB project name 346 | if not args.debug: 347 | os.environ["WANDB_PROJECT"] = "CAD-Editor" 348 | print(args) 349 | main(args) -------------------------------------------------------------------------------- /utils/directional_clip_score.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import torch.nn.functional as F 4 | 5 | import numpy as np 6 | 7 | import clip 8 | from PIL import Image 9 | 10 | 11 | class DirectionLoss(torch.nn.Module): 12 | 13 | def __init__(self, loss_type='mse'): 14 | super(DirectionLoss, self).__init__() 15 | 16 | self.loss_type = loss_type 17 | 18 | self.loss_func = { 19 | 'mse': torch.nn.MSELoss, 20 | 'cosine': torch.nn.CosineSimilarity, 21 | 'mae': torch.nn.L1Loss 22 | }[loss_type]() 23 | 24 | def forward(self, x, y): 25 | if self.loss_type == "cosine": 26 | return self.loss_func(x, y) 27 | 28 | return self.loss_func(x, y) 29 | 30 | class CLIPLoss(torch.nn.Module): 31 | def __init__(self, device, lambda_direction=1., lambda_patch=0., lambda_global=0., lambda_manifold=0., lambda_texture=0., patch_loss_type='mae', direction_loss_type='cosine', clip_model='ViT-B/32'): 32 | super(CLIPLoss, self).__init__() 33 | 34 | self.device = device 35 | self.model, clip_preprocess = clip.load(clip_model, device=self.device) 36 | 37 | self.clip_preprocess = clip_preprocess 38 | 39 | self.preprocess = transforms.Compose([transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] + # Un-normalize from [-1.0, 1.0] (GAN output) to [0, 
1]. 40 | clip_preprocess.transforms[:2] + # to match CLIP input scale assumptions 41 | clip_preprocess.transforms[4:]) # + skip convert PIL to tensor 42 | 43 | self.target_direction = None 44 | self.patch_text_directions = None 45 | 46 | self.patch_loss = DirectionLoss(patch_loss_type) 47 | self.direction_loss = DirectionLoss(direction_loss_type) 48 | self.patch_direction_loss = torch.nn.CosineSimilarity(dim=2) 49 | 50 | self.lambda_global = lambda_global 51 | self.lambda_patch = lambda_patch 52 | self.lambda_direction = lambda_direction 53 | self.lambda_manifold = lambda_manifold 54 | self.lambda_texture = lambda_texture 55 | 56 | self.src_text_features = None 57 | self.target_text_features = None 58 | self.angle_loss = torch.nn.L1Loss() 59 | 60 | self.model_cnn, preprocess_cnn = clip.load("RN50", device=self.device) 61 | self.preprocess_cnn = transforms.Compose([transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0])] + # Un-normalize from [-1.0, 1.0] (GAN output) to [0, 1]. 62 | preprocess_cnn.transforms[:2] + # to match CLIP input scale assumptions 63 | preprocess_cnn.transforms[4:]) # + skip convert PIL to tensor 64 | 65 | self.model.requires_grad_(False) 66 | self.model_cnn.requires_grad_(False) 67 | 68 | self.texture_loss = torch.nn.MSELoss() 69 | 70 | def tokenize(self, strings: list): 71 | return clip.tokenize(strings).to(self.device) 72 | 73 | def encode_text(self, tokens: list) -> torch.Tensor: 74 | return self.model.encode_text(tokens) 75 | 76 | def encode_images(self, images: torch.Tensor) -> torch.Tensor: 77 | images = self.preprocess(images).to(self.device) 78 | return self.model.encode_image(images) 79 | 80 | def encode_images_with_cnn(self, images: torch.Tensor) -> torch.Tensor: 81 | images = self.preprocess_cnn(images).to(self.device) 82 | return self.model_cnn.encode_image(images) 83 | 84 | 85 | def get_text_features(self, instruction_str: str, norm: bool = True) -> torch.Tensor: 86 | 87 | tokens = clip.tokenize(instruction_str).to(self.device) 88 | 89 | text_features = self.encode_text(tokens).detach() 90 | 91 | if norm: 92 | text_features /= text_features.norm(dim=-1, keepdim=True) 93 | 94 | return text_features 95 | 96 | def get_image_features(self, img: torch.Tensor, norm: bool = True) -> torch.Tensor: 97 | image_features = self.encode_images(img) 98 | 99 | if norm: 100 | image_features /= image_features.clone().norm(dim=-1, keepdim=True) 101 | 102 | return image_features 103 | 104 | def compute_text_direction(self, source_class: str, target_class: str) -> torch.Tensor: 105 | source_features = self.get_text_features(source_class) 106 | target_features = self.get_text_features(target_class) 107 | 108 | text_direction = (target_features - source_features).mean(axis=0, keepdim=True) 109 | text_direction /= text_direction.norm(dim=-1, keepdim=True) 110 | 111 | return text_direction 112 | 113 | def compute_text_direction_instruction(self, instruction: str) -> torch.Tensor: 114 | text_direction = self.get_text_features(instruction) 115 | text_direction /= text_direction.norm(dim=-1, keepdim=True) 116 | 117 | return text_direction 118 | 119 | 120 | def compute_img2img_direction(self, source_images: torch.Tensor, target_images: list) -> torch.Tensor: 121 | with torch.no_grad(): 122 | 123 | src_encoding = self.get_image_features(source_images) 124 | src_encoding = src_encoding.mean(dim=0, keepdim=True) 125 | 126 | target_encodings = [] 127 | for target_img in target_images: 128 | 129 | preprocessed = 
self.clip_preprocess(Image.open(target_img)).unsqueeze(0).to(self.device) 130 | 131 | encoding = self.model.encode_image(preprocessed) 132 | encoding /= encoding.norm(dim=-1, keepdim=True) 133 | 134 | target_encodings.append(encoding) 135 | 136 | target_encoding = torch.cat(target_encodings, axis=0) 137 | target_encoding = target_encoding.mean(dim=0, keepdim=True) 138 | 139 | direction = target_encoding - src_encoding 140 | direction /= direction.norm(dim=-1, keepdim=True) 141 | 142 | return direction 143 | 144 | def set_text_features(self, source_class: str, target_class: str) -> None: 145 | source_features = self.get_text_features(source_class).mean(axis=0, keepdim=True) 146 | self.src_text_features = source_features / source_features.norm(dim=-1, keepdim=True) 147 | 148 | target_features = self.get_text_features(target_class).mean(axis=0, keepdim=True) 149 | self.target_text_features = target_features / target_features.norm(dim=-1, keepdim=True) 150 | 151 | def clip_angle_loss(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str) -> torch.Tensor: 152 | if self.src_text_features is None: 153 | self.set_text_features(source_class, target_class) 154 | 155 | cos_text_angle = self.target_text_features @ self.src_text_features.T 156 | text_angle = torch.acos(cos_text_angle) 157 | 158 | src_img_features = self.get_image_features(src_img).unsqueeze(2) 159 | target_img_features = self.get_image_features(target_img).unsqueeze(1) 160 | 161 | cos_img_angle = torch.clamp(target_img_features @ src_img_features, min=-1.0, max=1.0) 162 | img_angle = torch.acos(cos_img_angle) 163 | 164 | text_angle = text_angle.unsqueeze(0).repeat(img_angle.size()[0], 1, 1) 165 | cos_text_angle = cos_text_angle.unsqueeze(0).repeat(img_angle.size()[0], 1, 1) 166 | 167 | return self.angle_loss(cos_img_angle, cos_text_angle) 168 | 169 | def clip_directional_loss(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str) -> torch.Tensor: 170 | 171 | if self.target_direction is None: 172 | self.target_direction = self.compute_text_direction(source_class, target_class) 173 | 174 | src_encoding = self.get_image_features(src_img) 175 | target_encoding = self.get_image_features(target_img) 176 | 177 | edit_direction = (target_encoding - src_encoding) 178 | if edit_direction.sum() == 0: 179 | target_encoding = self.get_image_features(target_img + 1e-6) 180 | edit_direction = (target_encoding - src_encoding) 181 | 182 | edit_direction /= (edit_direction.clone().norm(dim=-1, keepdim=True)) 183 | 184 | return self.direction_loss(edit_direction, self.target_direction).mean() 185 | 186 | def clip_directional_loss_instruction(self, src_img: torch.Tensor, instruction: str, target_img: torch.Tensor) -> torch.Tensor: 187 | 188 | if self.target_direction is None: 189 | self.target_direction = self.compute_text_direction_instruction(instruction) 190 | 191 | src_encoding = self.get_image_features(src_img) 192 | target_encoding = self.get_image_features(target_img) 193 | 194 | edit_direction = (target_encoding - src_encoding) 195 | if edit_direction.sum() == 0: 196 | target_encoding = self.get_image_features(target_img + 1e-6) 197 | edit_direction = (target_encoding - src_encoding) 198 | 199 | edit_direction /= (edit_direction.clone().norm(dim=-1, keepdim=True)) 200 | 201 | return self.direction_loss(edit_direction, self.target_direction).mean() 202 | 203 | def global_clip_loss(self, img: torch.Tensor, text) -> torch.Tensor: 204 | if not isinstance(text, list): 205 | 
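# Wrap a single caption string in a list so clip.tokenize below always
# receives a batch of prompts.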
text = [text] 206 | 207 | tokens = clip.tokenize(text).to(self.device) 208 | image = self.preprocess(img) 209 | 210 | logits_per_image, _ = self.model(image, tokens) 211 | 212 | return (1. - logits_per_image / 100).mean() 213 | 214 | def random_patch_centers(self, img_shape, num_patches, size): 215 | batch_size, channels, height, width = img_shape 216 | 217 | half_size = size // 2 218 | patch_centers = np.concatenate([np.random.randint(half_size, width - half_size, size=(batch_size * num_patches, 1)), 219 | np.random.randint(half_size, height - half_size, size=(batch_size * num_patches, 1))], axis=1) 220 | 221 | return patch_centers 222 | 223 | def generate_patches(self, img: torch.Tensor, patch_centers, size): 224 | batch_size = img.shape[0] 225 | num_patches = len(patch_centers) // batch_size 226 | half_size = size // 2 227 | 228 | patches = [] 229 | 230 | for batch_idx in range(batch_size): 231 | for patch_idx in range(num_patches): 232 | 233 | center_x = patch_centers[batch_idx * num_patches + patch_idx][0] 234 | center_y = patch_centers[batch_idx * num_patches + patch_idx][1] 235 | 236 | patch = img[batch_idx:batch_idx+1, :, center_y - half_size:center_y + half_size, center_x - half_size:center_x + half_size] 237 | 238 | patches.append(patch) 239 | 240 | patches = torch.cat(patches, axis=0) 241 | 242 | return patches 243 | 244 | def patch_scores(self, img: torch.Tensor, class_str: str, patch_centers, patch_size: int) -> torch.Tensor: 245 | 246 | parts = self.compose_text_with_templates(class_str, part_templates) 247 | tokens = clip.tokenize(parts).to(self.device) 248 | text_features = self.encode_text(tokens).detach() 249 | 250 | patches = self.generate_patches(img, patch_centers, patch_size) 251 | image_features = self.get_image_features(patches) 252 | 253 | similarity = image_features @ text_features.T 254 | 255 | return similarity 256 | 257 | def clip_patch_similarity(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str) -> torch.Tensor: 258 | patch_size = 196 #TODO remove magic number 259 | 260 | patch_centers = self.random_patch_centers(src_img.shape, 4, patch_size) #TODO remove magic number 261 | 262 | src_scores = self.patch_scores(src_img, source_class, patch_centers, patch_size) 263 | target_scores = self.patch_scores(target_img, target_class, patch_centers, patch_size) 264 | 265 | return self.patch_loss(src_scores, target_scores) 266 | 267 | def patch_directional_loss(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str) -> torch.Tensor: 268 | 269 | if self.patch_text_directions is None: 270 | src_part_classes = self.compose_text_with_templates(source_class, part_templates) 271 | target_part_classes = self.compose_text_with_templates(target_class, part_templates) 272 | 273 | parts_classes = list(zip(src_part_classes, target_part_classes)) 274 | 275 | self.patch_text_directions = torch.cat([self.compute_text_direction(pair[0], pair[1]) for pair in parts_classes], dim=0) 276 | 277 | patch_size = 510 # TODO remove magic numbers 278 | 279 | patch_centers = self.random_patch_centers(src_img.shape, 1, patch_size) 280 | 281 | patches = self.generate_patches(src_img, patch_centers, patch_size) 282 | src_features = self.get_image_features(patches) 283 | 284 | patches = self.generate_patches(target_img, patch_centers, patch_size) 285 | target_features = self.get_image_features(patches) 286 | 287 | edit_direction = (target_features - src_features) 288 | edit_direction /= edit_direction.clone().norm(dim=-1, 
keepdim=True) 289 | 290 | cosine_dists = 1. - self.patch_direction_loss(edit_direction.unsqueeze(1), self.patch_text_directions.unsqueeze(0)) 291 | 292 | patch_class_scores = cosine_dists * (edit_direction @ self.patch_text_directions.T).softmax(dim=-1) 293 | 294 | return patch_class_scores.mean() 295 | 296 | def cnn_feature_loss(self, src_img: torch.Tensor, target_img: torch.Tensor) -> torch.Tensor: 297 | src_features = self.encode_images_with_cnn(src_img) 298 | target_features = self.encode_images_with_cnn(target_img) 299 | 300 | return self.texture_loss(src_features, target_features) 301 | 302 | def forward(self, src_img: torch.Tensor, source_class: str, target_img: torch.Tensor, target_class: str, texture_image: torch.Tensor = None): 303 | clip_loss = 0.0 304 | 305 | if self.lambda_global: 306 | clip_loss += self.lambda_global * self.global_clip_loss(target_img, [f"a {target_class}"]) 307 | 308 | if self.lambda_patch: 309 | clip_loss += self.lambda_patch * self.patch_directional_loss(src_img, source_class, target_img, target_class) 310 | 311 | if self.lambda_direction: 312 | clip_loss += self.lambda_direction * self.clip_directional_loss(src_img, source_class, target_img, target_class) 313 | 314 | if self.lambda_manifold: 315 | clip_loss += self.lambda_manifold * self.clip_angle_loss(src_img, source_class, target_img, target_class) 316 | 317 | if self.lambda_texture and (texture_image is not None): 318 | clip_loss += self.lambda_texture * self.cnn_feature_loss(texture_image, target_img) 319 | 320 | return clip_loss -------------------------------------------------------------------------------- /utils/obj_reconverter.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import OrderedDict 3 | from util import create_point, create_unit_vec, get_transform, create_sketch_plane 4 | 5 | # OCC 6 | from OCC.Core.BRepCheck import BRepCheck_Analyzer 7 | from OCC.Core.GC import GC_MakeArcOfCircle 8 | from OCC.Core.BRepBuilderAPI import ( 9 | BRepBuilderAPI_MakeFace, 10 | BRepBuilderAPI_MakeWire, 11 | BRepBuilderAPI_MakeEdge, 12 | ) 13 | from OCC.Core.BRepAlgoAPI import BRepAlgoAPI_Fuse, BRepAlgoAPI_Cut, BRepAlgoAPI_Common 14 | from OCC.Core.BRepPrimAPI import BRepPrimAPI_MakePrism 15 | from OCC.Core.BRepAdaptor import BRepAdaptor_Surface 16 | from OCC.Core.BRepGProp import brepgprop_VolumeProperties, brepgprop_SurfaceProperties 17 | from OCC.Core.GProp import GProp_GProps 18 | from OCC.Core.ShapeFix import ShapeFix_Face, ShapeFix_Wire 19 | from OCC.Core.gp import gp_Vec, gp_Ax2, gp_Dir, gp_Circ 20 | from OCC.Extend.DataExchange import write_stl_file 21 | 22 | 23 | class OBJReconverter: 24 | """OBJ Data Reconverter""" 25 | 26 | def __init__(self): 27 | self.vertex_dict = OrderedDict() 28 | self.PRECISION = 1e-5 29 | self.eps = 1e-7 30 | self.x_axis = gp_Dir(1.0, 0.0, 0.0) 31 | 32 | def convert_curve(self, curve): 33 | """ 34 | convert to json dict format 35 | """ 36 | json_curve = {} 37 | 38 | if curve.type == "circle": 39 | json_curve["type"] = "Circle3D" 40 | json_curve["center_point"] = { 41 | "x": curve.center[0], 42 | "y": curve.center[1], 43 | "z": 0, 44 | } 45 | json_curve["radius"] = curve.radius 46 | 47 | if curve.type == "line": 48 | json_curve["type"] = "Line3D" 49 | json_curve["start_point"] = { 50 | "x": curve.start[0], 51 | "y": curve.start[1], 52 | "z": 0, 53 | } 54 | json_curve["end_point"] = {"x": curve.end[0], "y": curve.end[1], "z": 0} 55 | 56 | if curve.type == "arc": 57 | json_curve["type"] = "Arc3D" 
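# Arc3D entries carry start, end, mid and center points; z is fixed to 0
# because profile curves are written in 2D sketch coordinates and are only
# lifted to 3D when the sketch transform is applied later (see parse_curve).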
58 | json_curve["start_point"] = { 59 | "x": curve.start[0], 60 | "y": curve.start[1], 61 | "z": 0, 62 | } 63 | json_curve["end_point"] = {"x": curve.end[0], "y": curve.end[1], "z": 0} 64 | json_curve["mid_point"] = {"x": curve.mid[0], "y": curve.mid[1], "z": 0} 65 | json_curve["center_point"] = { 66 | "x": curve.center[0], 67 | "y": curve.center[1], 68 | "z": 0, 69 | } 70 | 71 | json_curve["is_outer"] = curve.is_outer 72 | return json_curve 73 | 74 | def convert_vertices(self): 75 | """Convert all the vertices to .obj format""" 76 | vertex_strings = "" 77 | for pt in self.vertex_dict.values(): 78 | # e.g. v 0.123 0.234 0.345 1.0 79 | vertex_string = f"v {pt[0]} {pt[1]}\n" 80 | vertex_strings += vertex_string 81 | return vertex_strings 82 | 83 | def parse_obj(self, faces, meta_info): 84 | """ 85 | reconstruct brep from obj file 86 | """ 87 | # At least one needs to match 88 | for face in faces: 89 | for loop in face: 90 | if len(loop) > 1: 91 | for idx, curve in enumerate(loop[:-1]): 92 | next_curve = np.vstack([loop[idx + 1].start, loop[idx + 1].end]) 93 | diff1 = np.sum(np.abs(curve.start - next_curve), 1) 94 | diff2 = np.sum(np.abs(curve.end - next_curve), 1) 95 | 96 | if min(diff2) == 0 or min(diff1) == 0: 97 | continue # edge connected 98 | 99 | assert ( 100 | min(diff1) < 1e-3 or min(diff2) < 1e-3 101 | ) # difference should be small 102 | 103 | if min(diff1) > min(diff2): 104 | min_idx = np.argmin(diff2) 105 | if min_idx == 0: 106 | loop[idx + 1].start_idx = curve.end_idx 107 | loop[idx + 1].start = curve.end 108 | else: 109 | loop[idx + 1].end_idx = curve.end_idx 110 | loop[idx + 1].end = curve.end 111 | else: 112 | min_idx = np.argmin(diff1) 113 | if min_idx == 0: 114 | loop[idx + 1].start_idx = curve.start_idx 115 | loop[idx + 1].start = curve.start 116 | else: 117 | loop[idx + 1].end_idx = curve.start_idx 118 | loop[idx + 1].end = curve.start 119 | 120 | # Solve start / end connection 121 | shared_idx = list( 122 | set([loop[-2].start_idx, loop[-2].end_idx]).intersection( 123 | set([loop[-1].start_idx, loop[-1].end_idx]) 124 | ) 125 | ) 126 | 127 | assert len(shared_idx) >= 1 128 | 129 | if len(shared_idx) == 2: 130 | assert len(loop) == 2 # do nothing 131 | else: 132 | if shared_idx[0] == loop[-1].start_idx: 133 | do_start = False 134 | else: 135 | do_start = True 136 | start_curve = np.vstack([loop[0].start, loop[0].end]) 137 | 138 | if do_start: 139 | diff = np.sum(np.abs(loop[-1].start - start_curve), 1) 140 | else: 141 | diff = np.sum(np.abs(loop[-1].end - start_curve), 1) 142 | assert min(diff) < 1e-3 143 | 144 | min_idx = np.argmin(diff) 145 | if min_idx == 0: 146 | if do_start: 147 | loop[-1].start_idx = loop[0].start_idx 148 | loop[-1].start = loop[0].start 149 | else: 150 | loop[-1].end_idx = loop[0].start_idx 151 | loop[-1].end = loop[0].start 152 | else: 153 | if do_start: 154 | loop[-1].start_idx = loop[0].end_idx 155 | loop[-1].start = loop[0].end 156 | else: 157 | loop[-1].end_idx = loop[0].end_idx 158 | loop[-1].end = loop[0].end 159 | 160 | # Parse groups to json loop/curve profile 161 | extrusion = {} 162 | extrusion["profiles"] = [] 163 | for face in faces: 164 | profile = {} 165 | profile["loops"] = [] 166 | for loop in face: 167 | pl = {} 168 | pl["profile_curves"] = [] 169 | for curve in loop: 170 | # convert to json format 171 | pl["profile_curves"].append(self.convert_curve(curve)) 172 | profile["loops"].append(pl) 173 | extrusion["profiles"].append(profile) 174 | 175 | # Parse transform 176 | sketch = {} 177 | transform = {} 178 | transform["origin"] = { 
179 | "x": meta_info["t_orig"][0], 180 | "y": meta_info["t_orig"][1], 181 | "z": meta_info["t_orig"][2], 182 | } 183 | transform["x_axis"] = { 184 | "x": meta_info["t_x"][0], 185 | "y": meta_info["t_x"][1], 186 | "z": meta_info["t_x"][2], 187 | } 188 | transform["y_axis"] = { 189 | "x": meta_info["t_y"][0], 190 | "y": meta_info["t_y"][1], 191 | "z": meta_info["t_y"][2], 192 | } 193 | transform["z_axis"] = { 194 | "x": meta_info["t_z"][0], 195 | "y": meta_info["t_z"][1], 196 | "z": meta_info["t_z"][2], 197 | } 198 | sketch["transform"] = transform 199 | 200 | # Parse extrude 201 | extrude_params = {} 202 | extrude_params["extrude_type"] = meta_info["set_op"] 203 | extrude_params["extrude_values"] = meta_info["extrude_value"] 204 | 205 | # Create sketch 206 | all_faces = [] 207 | curve_strings = "" 208 | curve_count = 0 209 | for profile in extrusion["profiles"]: 210 | ref_face, face, curve_string, c_count = self.parse_sketch(sketch, profile) 211 | curve_strings += curve_string 212 | curve_count += c_count 213 | all_faces.append(face) 214 | 215 | # Merge all faces in the same plane 216 | plane_face = all_faces[0] 217 | for face in all_faces[1:]: 218 | plane_face = self.my_op(plane_face, face, "fuse") 219 | solid = self.extrude_face(ref_face, plane_face, extrude_params) 220 | return solid, curve_strings, curve_count 221 | 222 | def my_op(self, big, small, op_name): 223 | if op_name == "cut": 224 | op = BRepAlgoAPI_Cut(big, small) 225 | elif op_name == "fuse": 226 | op = BRepAlgoAPI_Fuse(big, small) 227 | elif op_name == "common": 228 | op = BRepAlgoAPI_Common(big, small) 229 | op.SetFuzzyValue(self.PRECISION) 230 | op.Build() 231 | return op.Shape() 232 | 233 | def build_body(self, face, normal, value): 234 | extrusion_vec = gp_Vec(normal).Multiplied(value) 235 | make_prism = BRepPrimAPI_MakePrism(face, extrusion_vec) 236 | make_prism.Build() 237 | prism = make_prism.Prism() 238 | return prism.Shape() 239 | 240 | def extrudeBasedOnType(self, face, normal, distance): 241 | # Extrude based on the two bound values 242 | if not (distance[0] < distance[1]): 243 | raise Exception("incorrect distance") 244 | large_value = distance[1] 245 | small_value = distance[0] 246 | 247 | if large_value == 0: 248 | return self.build_body(face, -normal, -small_value) 249 | elif small_value == 0: 250 | return self.build_body(face, normal, large_value) 251 | elif np.sign(large_value) == np.sign(small_value): 252 | if large_value < 0: 253 | body1 = self.build_body(face, -normal, -small_value) 254 | body2 = self.build_body(face, -normal, -large_value) 255 | return self.my_op(body1, body2, "cut") 256 | else: 257 | assert large_value > 0 258 | body1 = self.build_body(face, normal, small_value) 259 | body2 = self.build_body(face, normal, large_value) 260 | return self.my_op(body2, body1, "cut") 261 | else: 262 | assert np.sign(large_value) != np.sign(small_value) 263 | body1 = self.build_body(face, normal, large_value) 264 | body2 = self.build_body(face, -normal, -small_value) 265 | return self.my_op(body1, body2, "fuse") 266 | 267 | def extrude_face(self, ref_face, face, extrude_params): 268 | distance = extrude_params["extrude_values"] 269 | surf = BRepAdaptor_Surface(ref_face).Plane() 270 | normal = surf.Axis().Direction() 271 | extruded_shape = self.extrudeBasedOnType(face, normal, distance) 272 | return extruded_shape 273 | 274 | def parse_sketch(self, sketch, profile): 275 | """ 276 | Sketch in one closed loop (one out, multiple ins) 277 | """ 278 | # Transformation from local to global xyz coord 279 | transform 
= get_transform(sketch["transform"]) 280 | 281 | # Create face region (automatically infer from all wires) 282 | outer_facelist = [] 283 | inner_facelist = [] 284 | curve_count = 0 285 | outer_string = [] 286 | inner_string = [] 287 | plane = create_sketch_plane(sketch["transform"]) 288 | 289 | for idx, pl in enumerate(profile["loops"]): 290 | # Create loop 291 | loop, curve_string, num_curve = self.parse_loop( 292 | pl["profile_curves"], transform 293 | ) 294 | # Create face 295 | face_builder = BRepBuilderAPI_MakeFace(plane, loop) 296 | if not face_builder.IsDone(): 297 | raise Exception("face builder not done") 298 | face = face_builder.Face() 299 | # Fix face 300 | fixer = ShapeFix_Face(face) 301 | fixer.SetPrecision(self.PRECISION) 302 | fixer.FixOrientation() 303 | 304 | analyzer = BRepCheck_Analyzer(fixer.Face()) 305 | if not analyzer.IsValid(): 306 | raise Exception("face check failed") 307 | 308 | curve_count += num_curve 309 | 310 | if pl["profile_curves"][0]["is_outer"]: 311 | outer_facelist.append(fixer.Face()) 312 | outer_string.append(curve_string) 313 | else: 314 | inner_facelist.append(fixer.Face()) 315 | inner_string.append(curve_string) 316 | 317 | # Create final closed loop face 318 | assert len(outer_facelist) > 0 319 | final_face = outer_facelist[0] 320 | for face in outer_facelist[1:]: 321 | final_face = self.my_op(final_face, face, "fuse") 322 | for face in inner_facelist: 323 | final_face = self.my_op(final_face, face, "cut") 324 | 325 | # Append inner outer information to string 326 | assert len(outer_string) == 1 327 | out_str = "" 328 | in_str = "" 329 | for c_str in outer_string: 330 | out_str += "out\n" + c_str + "\n" 331 | for c_str in inner_string: 332 | in_str += "in\n" + c_str + "\n" 333 | final_str = "face\n" + out_str + in_str 334 | 335 | return outer_facelist[0], final_face, final_str, curve_count 336 | 337 | def parse_loop(self, profile_loop, transform): 338 | """Create face in one closed loop""" 339 | topo_wire = BRepBuilderAPI_MakeWire() 340 | curve_strings = "" 341 | curve_count = 0 342 | 343 | # Loop through all the curves in one loop 344 | for profile_curve in profile_loop: 345 | curve_edge, curve_string = self.parse_curve(profile_curve, transform) 346 | topo_wire.Add(curve_edge) 347 | if not topo_wire.IsDone(): 348 | raise Exception("wire builder not done") 349 | 350 | curve_string += "\n" 351 | curve_count += 1 352 | curve_strings += curve_string 353 | 354 | fixer = ShapeFix_Wire() 355 | fixer.Load(topo_wire.Wire()) 356 | fixer.SetPrecision(self.PRECISION) 357 | fixer.FixClosed() 358 | fixer.Perform() 359 | return fixer.Wire(), curve_strings, curve_count 360 | 361 | def parse_curve(self, curve, transform): 362 | if curve["type"] == "Line3D": 363 | return self.create_line(curve, transform) 364 | elif curve["type"] == "Circle3D": 365 | return self.create_circle(curve, transform) 366 | elif curve["type"] == "Arc3D": 367 | return self.create_arc(curve, transform) 368 | else: 369 | raise Exception("unknown curve type") 370 | 371 | def create_line(self, line, transform): 372 | start = create_point(line["start_point"], transform) 373 | end = create_point(line["end_point"], transform) 374 | if start.Distance(end) == 0: 375 | raise Exception("start/end point same location") 376 | topo_edge = BRepBuilderAPI_MakeEdge(start, end) 377 | 378 | # Save pre-transform 379 | star_idx = self.save_vertex( 380 | line["start_point"]["x"] + 0.0, line["start_point"]["y"] + 0.0, "p" 381 | ) 382 | end_idx = self.save_vertex( 383 | line["end_point"]["x"] + 0.0, 
line["end_point"]["y"] + 0.0, "p" 384 | ) 385 | curve_string = f"l {star_idx} {end_idx}" 386 | return topo_edge.Edge(), curve_string 387 | 388 | def create_arc(self, arc, transform): 389 | start = create_point(arc["start_point"], transform) 390 | mid = create_point(arc["mid_point"], transform) 391 | end = create_point(arc["end_point"], transform) 392 | arc_occ = GC_MakeArcOfCircle(start, mid, end).Value() 393 | topo_edge = BRepBuilderAPI_MakeEdge(arc_occ) 394 | 395 | # Save pre-transform 396 | start_idx = self.save_vertex( 397 | arc["start_point"]["x"] + 0.0, arc["start_point"]["y"] + 0.0, "p" 398 | ) 399 | end_idx = self.save_vertex( 400 | arc["end_point"]["x"] + 0.0, arc["end_point"]["y"] + 0.0, "p" 401 | ) 402 | center_idx = self.save_vertex( 403 | arc["center_point"]["x"] + 0.0, arc["center_point"]["y"] + 0.0, "p" 404 | ) 405 | mid_idx = self.save_vertex( 406 | arc["mid_point"]["x"] + 0.0, arc["mid_point"]["y"] + 0.0, "p" 407 | ) 408 | curve_string = f"a {start_idx} {mid_idx} {center_idx} {end_idx}" 409 | return topo_edge.Edge(), curve_string 410 | 411 | def create_circle(self, circle, transform): 412 | center = create_point(circle["center_point"], transform) 413 | radius = circle["radius"] 414 | normal = create_unit_vec({"x": 0.0, "y": 0.0, "z": 1.0}, transform) 415 | ref_vector3d = self.x_axis.Transformed(transform) 416 | axis = gp_Ax2(center, normal, ref_vector3d) 417 | gp_circle = gp_Circ(axis, abs(float(radius))) 418 | topo_edge = BRepBuilderAPI_MakeEdge(gp_circle) 419 | 420 | center_idx = self.save_vertex( 421 | circle["center_point"]["x"] + 0.0, circle["center_point"]["y"] + 0.0, "p" 422 | ) 423 | radius_idx = self.save_vertex(abs(float(radius)) + 0.0, 0, "r") 424 | curve_string = f"c {center_idx} {radius_idx}" 425 | return topo_edge.Edge(), curve_string 426 | 427 | def save_vertex(self, h_x, h_y, text): 428 | unique_key = f"{text}:x{h_x}y{h_y}" 429 | index = 0 430 | for key in self.vertex_dict.keys(): 431 | # Vertex location already exist in dict 432 | if unique_key == key: 433 | return index 434 | index += 1 435 | # Vertex location does not exist in dict 436 | self.vertex_dict[unique_key] = [h_x, h_y] 437 | return index 438 | -------------------------------------------------------------------------------- /utils/parse_seq2obj.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import OrderedDict 3 | import re 4 | from pathlib import Path 5 | import argparse 6 | import os 7 | import json 8 | import math 9 | 10 | # hyperparameters from SkexGen project 11 | SKETCH_R = 1 12 | RADIUS_R = 1 13 | EXTRUDE_R = 1.0 14 | SCALE_R = 1.4 15 | OFFSET_R = 0.9 16 | PIX_PAD = 4 17 | CMD_PAD = 3 18 | COORD_PAD = 4 19 | EXT_PAD = 1 20 | EXTRA_PAD = 1 21 | R_PAD = 2 22 | 23 | 24 | class CADparser: 25 | """Parse CAD sequence to CAD object.""" 26 | 27 | def __init__(self, bit): 28 | self.vertex_dict = OrderedDict() 29 | self.bit = bit 30 | 31 | def perform(self, cad_seq): 32 | # Check for floating point numbers or non-integers in the input string 33 | if re.search(r'\d+\.\d+', cad_seq): 34 | print("contain float nums!") 35 | return None 36 | # divide into sketch and extrude 37 | sketches, extrudes = self.get_SE(cad_seq) 38 | if sketches is None or extrudes is None: 39 | print("sketches is None or extrudes is None!") 40 | return None 41 | # sequentially parse each pair of SE into obj 42 | se_datas = [] 43 | for sketch, extrude in zip(sketches, extrudes): 44 | extrude_param, scale, offset = self.parse_extrude(extrude) 45 | if 
extrude_param is None or scale is None or offset is None: 46 | print("sketches is None or extrudes is None!") 47 | return None 48 | vertex_str, se_str = self.parse_sketch(sketch, scale, offset) 49 | if vertex_str is None or se_str is None: 50 | return None 51 | se_datas.append( 52 | {"vertex": vertex_str, "curve": se_str, "extrude": extrude_param} 53 | ) 54 | self.vertex_dict.clear() 55 | 56 | return se_datas 57 | 58 | def parse_sketch(self, sketch, scale, offset): 59 | faces = self.get_faces(sketch) 60 | if len(faces) == 0: 61 | print("face is None") 62 | return None, None 63 | se_str = "" 64 | for face_idx, face in enumerate(faces): # each face 65 | face_str = "face\n" 66 | loops = self.get_loops(face) 67 | if len(loops) == 0: 68 | print("loop is None") 69 | return None, None 70 | for loop_idx, loop in enumerate(loops): # each loop 71 | curves = self.get_curves(loop) 72 | if len(curves) == 0: 73 | print("curve is None") 74 | return None, None 75 | next_curves = curves[1:] 76 | next_curves += curves[:1] 77 | cur_str = [] 78 | for curve, next_curve in zip(curves, next_curves): # each curve 79 | if not self.obj_curve(curve, next_curve, cur_str, scale, offset): 80 | return None, None 81 | loop_str = "" 82 | for c in cur_str: 83 | loop_str += f"{c}\n" 84 | if loop_idx == 0: 85 | face_str += f"out\n{loop_str}\n" 86 | else: 87 | face_str += f"in\n{loop_str}\n" 88 | se_str += face_str 89 | vertex_str = self.convert_vertices() 90 | return vertex_str, se_str 91 | 92 | def parse_extrude(self, extrude): 93 | ext = extrude.split(",") 94 | if len(ext) != 18: 95 | print("extrude para != 18") 96 | return None, None, None 97 | 98 | # operation str to int 99 | ext_op = {"add": 1, "cut": 2, "intersect": 3}.get(ext[0], None) 100 | if ext_op is None: 101 | return None, None, None 102 | # dequantize ext_v, ext_T, scale and offset 103 | ext_v, ext_T, scale, offset = self.dequantize_extrude_params(ext) 104 | # get ext_R 105 | ext_R = np.array(ext[6:15], dtype=int) 106 | 107 | extrude_param = {"value": ext_v, "T": ext_T, "R": ext_R, "op": ext_op} 108 | return extrude_param, scale, offset 109 | 110 | def obj_curve(self, curve, next_curve, cur_str, scale, offset): 111 | cur = curve.split(",") 112 | next_cur = next_curve.split(",") 113 | if cur[0] == "circle": 114 | if len(cur) != 9: 115 | print("circle para num != 9!\n") 116 | return False 117 | p1, p2, p3, p4 = self.dequantize_circle_points( 118 | cur, next_cur, scale, offset) 119 | center = np.asarray([0.5 * (p1[0] + p2[0]), 0.5 * (p3[1] + p4[1])]) 120 | radius = (np.linalg.norm(p1 - p2) + np.linalg.norm(p3 - p4)) / 4.0 121 | 122 | center = center * scale + offset 123 | radius = radius * scale 124 | 125 | center_idx = self.save_vertex(center[0], center[1], "p") 126 | radius_idx = self.save_vertex(radius, 0.0, "r") 127 | cur_str.append(f"c {center_idx} {radius_idx}") 128 | elif cur[0] == "arc": 129 | if len(cur) != 5: 130 | print("arc para num != 5!\n") 131 | return False 132 | if ( 133 | cur[1:3] == cur[3:5] 134 | or cur[1:3] == next_cur[1:3] 135 | or cur[3:5] == next_cur[3:5] 136 | ): # invalid arc 137 | print("invalid arc!\n") 138 | return False 139 | start_v, mid_v, end_v = self.dequantize_arc_points( 140 | cur, next_cur, scale, offset 141 | ) 142 | try: 143 | center, _, _, _ = find_arc_geometry(start_v, mid_v, end_v) 144 | except Exception: 145 | return False 146 | start_v = start_v * scale + offset 147 | mid_v = mid_v * scale + offset 148 | end_v = end_v * scale + offset 149 | center = center * scale + offset 150 | 151 | center_idx = 
self.save_vertex(center[0], center[1], "p") 152 | start_idx = self.save_vertex(start_v[0], start_v[1], "p") 153 | mid_idx = self.save_vertex(mid_v[0], mid_v[1], "p") 154 | end_idx = self.save_vertex(end_v[0], end_v[1], "p") 155 | cur_str.append(f"a {start_idx} {mid_idx} {center_idx} {end_idx}") 156 | elif cur[0] == "line": 157 | if len(cur) != 3: 158 | print("line para num != 3!\n") 159 | return False 160 | if cur[1:3] == next_cur[1:3]: 161 | return False 162 | start_v, end_v = self.dequantize_line_points( 163 | cur, next_cur, scale, offset) 164 | start_v = start_v * scale + offset 165 | end_v = end_v * scale + offset 166 | 167 | start_idx = self.save_vertex(start_v[0], start_v[1], "p") 168 | end_idx = self.save_vertex(end_v[0], end_v[1], "p") 169 | cur_str.append(f"l {start_idx} {end_idx}") 170 | else: 171 | print("invalid curve name!\n") 172 | return False 173 | return True 174 | 175 | def get_SE(self, cad_seq): 176 | # sketches: 1) between sequence start and sketch_end, 177 | sketches_from_start = re.findall(r"^(.+?)(?=)", cad_seq) 178 | # sketches: 2) between extrude_end and sketch_end 179 | sketches_after_extrude = re.findall( 180 | r"(?<=)(.+?)(?=)", cad_seq 181 | ) 182 | sketches = [x.strip() for x in sketches_from_start] + [ 183 | x.strip() for x in sketches_after_extrude 184 | ] 185 | # extrudes: between sketch_end and extrude_end 186 | extrudes = [ 187 | x.strip() for x in re.findall(r"(.+?)", cad_seq) 188 | ] 189 | if len(sketches) != len(extrudes): 190 | print("len(sketches) != len(extrudes)") 191 | return None, None 192 | return sketches, extrudes 193 | 194 | def get_faces(self, sketch): 195 | faces = sketch.split("") 196 | return [x.strip() for x in faces if x.strip() != ""] 197 | 198 | def get_loops(self, face): 199 | loops = face.split("") 200 | return [x.strip() for x in loops if x.strip() != ""] 201 | 202 | def get_curves(self, loop): 203 | curves = loop.split("") 204 | return [x.strip() for x in curves if x.strip() != ""] 205 | 206 | def dequantize_circle_points(self, curve, next_curve, scale, offset): 207 | p1 = dequantize_verts( 208 | np.array(curve[1:3], dtype=int), 209 | n_bits=self.bit, 210 | min_range=-SKETCH_R, 211 | max_range=SKETCH_R, 212 | add_noise=False, 213 | ) 214 | p2 = dequantize_verts( 215 | np.array(curve[3:5], dtype=int), 216 | n_bits=self.bit, 217 | min_range=-SKETCH_R, 218 | max_range=SKETCH_R, 219 | add_noise=False, 220 | ) 221 | p3 = dequantize_verts( 222 | np.array(curve[5:7], dtype=int), 223 | n_bits=self.bit, 224 | min_range=-SKETCH_R, 225 | max_range=SKETCH_R, 226 | add_noise=False, 227 | ) 228 | p4 = dequantize_verts( 229 | np.array(curve[7:9], dtype=int), 230 | n_bits=self.bit, 231 | min_range=-SKETCH_R, 232 | max_range=SKETCH_R, 233 | add_noise=False, 234 | ) 235 | return p1, p2, p3, p4 236 | 237 | def dequantize_arc_points(self, curve, next_curve, scale, offset): 238 | start_v = dequantize_verts( 239 | np.array(curve[1:3], dtype=int), 240 | n_bits=self.bit, 241 | min_range=-SKETCH_R, 242 | max_range=SKETCH_R, 243 | add_noise=False, 244 | ) 245 | mid_v = dequantize_verts( 246 | np.array(curve[3:5], dtype=int), 247 | n_bits=self.bit, 248 | min_range=-SKETCH_R, 249 | max_range=SKETCH_R, 250 | add_noise=False, 251 | ) 252 | end_v = dequantize_verts( 253 | np.array(next_curve[1:3], dtype=int), 254 | n_bits=self.bit, 255 | min_range=-SKETCH_R, 256 | max_range=SKETCH_R, 257 | add_noise=False, 258 | ) 259 | return start_v, mid_v, end_v 260 | 261 | def dequantize_line_points(self, curve, next_curve, scale, offset): 262 | start_v = dequantize_verts( 
263 | np.array(curve[1:3], dtype=int), 264 | n_bits=self.bit, 265 | min_range=-SKETCH_R, 266 | max_range=SKETCH_R, 267 | add_noise=False, 268 | ) 269 | try: 270 | point = np.array(next_curve[1:3], dtype=int) 271 | end_v = dequantize_verts( 272 | point, 273 | n_bits=self.bit, 274 | min_range=-SKETCH_R, 275 | max_range=SKETCH_R, 276 | add_noise=False, 277 | ) 278 | except ValueError as e: 279 | print(f"Invalid data in next_curve: {next_curve[1:3]}") 280 | return None, None 281 | return start_v, end_v 282 | 283 | def dequantize_extrude_params(self, extrude): 284 | ext_v = dequantize_verts( 285 | np.array(extrude[1:3], dtype=int), 286 | n_bits=self.bit, 287 | min_range=-EXTRUDE_R, 288 | max_range=EXTRUDE_R, 289 | add_noise=False, 290 | ) 291 | ext_T = dequantize_verts( 292 | np.array(extrude[3:6], dtype=int), 293 | n_bits=self.bit, 294 | min_range=-EXTRUDE_R, 295 | max_range=EXTRUDE_R, 296 | add_noise=False, 297 | ) 298 | scale = dequantize_verts( 299 | np.array(extrude[15], dtype=int), 300 | n_bits=self.bit, 301 | min_range=0.0, 302 | max_range=SCALE_R, 303 | add_noise=False, 304 | ) 305 | offset = dequantize_verts( 306 | np.array(extrude[16:18], dtype=int), 307 | n_bits=self.bit, 308 | min_range=-OFFSET_R, 309 | max_range=OFFSET_R, 310 | add_noise=False, 311 | ) 312 | return ext_v, ext_T, scale, offset 313 | 314 | def save_vertex(self, h_x, h_y, text): 315 | unique_key = f"{text}:x{h_x}y{h_y}" 316 | index = 0 317 | for key in self.vertex_dict.keys(): 318 | # Vertex location already exist in dict 319 | if unique_key == key: 320 | return index 321 | index += 1 322 | # Vertex location does not exist in dict 323 | self.vertex_dict[unique_key] = [h_x, h_y] 324 | return index 325 | 326 | def convert_vertices(self): 327 | """Convert all the vertices to .obj format""" 328 | vertex_strings = "" 329 | for pt in self.vertex_dict.values(): 330 | # e.g. 
331 | vertex_string = f"v {pt[0]} {pt[1]}\n" 332 | vertex_strings += vertex_string 333 | return vertex_strings 334 | 335 | 336 | def find_arc_geometry(a, b, c): 337 | A = b[0] - a[0] 338 | B = b[1] - a[1] 339 | C = c[0] - a[0] 340 | D = c[1] - a[1] 341 | 342 | E = A*(a[0] + b[0]) + B*(a[1] + b[1]) 343 | F = C*(a[0] + c[0]) + D*(a[1] + c[1]) 344 | 345 | G = 2.0*(A*(c[1] - b[1])-B*(c[0] - b[0])) 346 | 347 | if G == 0: 348 | raise Exception("zero G") 349 | 350 | p_0 = (D*E - B*F) / G 351 | p_1 = (A*F - C*E) / G 352 | 353 | center = np.array([p_0, p_1]) 354 | radius = np.linalg.norm(center - a) 355 | 356 | angles = [] 357 | for xx in [a, b, c]: 358 | angle = angle_from_vector_to_x(xx - center) 359 | angles.append(angle) 360 | 361 | ab = b-a 362 | ac = c-a 363 | cp = np.cross(ab, ac) 364 | if cp >= 0: 365 | start_angle_rads = angles[0] 366 | end_angle_rads = angles[2] 367 | else: 368 | start_angle_rads = angles[2] 369 | end_angle_rads = angles[0] 370 | 371 | return center, radius, start_angle_rads, end_angle_rads 372 | 373 | 374 | def angle_from_vector_to_x(vec): 375 | assert vec.size == 2 376 | # Normalize to a unit vector before computing the angle 377 | angle = 0.0 378 | 379 | l = np.linalg.norm(vec) 380 | uvec = vec/l 381 | 382 | # 2 | 1 383 | # ------- 384 | # 3 | 4 385 | if uvec[0] >= 0: 386 | if uvec[1] >= 0: 387 | # Quadrant 1 388 | angle = math.asin(uvec[1]) 389 | else: 390 | # Quadrant 4 391 | angle = 2.0*math.pi - math.asin(-uvec[1]) 392 | else: 393 | if uvec[1] >= 0: 394 | # Quadrant 2 395 | angle = math.pi - math.asin(uvec[1]) 396 | else: 397 | # Quadrant 3 398 | angle = math.pi + math.asin(-uvec[1]) 399 | return angle 400 | 401 | 402 | def dequantize_verts(verts, n_bits=8, min_range=-0.5, max_range=0.5, add_noise=False): 403 | """Convert quantized vertices to floats.""" 404 | range_quantize = 2**n_bits - 1 405 | verts = verts.astype("float32") 406 | verts = verts * (max_range - min_range) / range_quantize + min_range 407 | return verts 408 | 409 | 410 | def write_obj_sample(save_folder, data): 411 | for idx, write_data in enumerate(data): 412 | obj_name = Path(save_folder).stem + "_" + \ 413 | str(idx).zfill(3) + "_param.obj" 414 | obj_file = Path(save_folder) / obj_name 415 | extrude_param = write_data["extrude"] 416 | vertex_strings = write_data["vertex"] 417 | curve_strings = write_data["curve"] 418 | 419 | # Write an .obj file with the curves and verts 420 | if extrude_param["op"] == 1: # 'add' 421 | set_op = "NewBodyFeatureOperation" 422 | elif extrude_param["op"] == 2: # 'cut' 423 | set_op = "CutFeatureOperation" 424 | elif extrude_param["op"] == 3: # 'intersect' 425 | set_op = "IntersectFeatureOperation" 426 | 427 | with open(obj_file, "w") as fh: 428 | # Write Meta info 429 | fh.write("# WaveFront *.obj file\n") 430 | fh.write("# ExtrudeOperation: " + set_op + "\n") 431 | fh.write("\n") 432 | 433 | # Write vertex and curve 434 | fh.write(vertex_strings) 435 | fh.write("\n") 436 | fh.write(curve_strings) 437 | fh.write("\n") 438 | 439 | # Write extrude value 440 | extrude_string = "Extrude " 441 | for value in extrude_param["value"]: 442 | extrude_string += str(value) + " " 443 | fh.write(extrude_string) 444 | fh.write("\n") 445 | 446 | # Write reference plane value 447 | p_orig = parse3d_sample(extrude_param["T"]) 448 | x_axis = parse3d_sample(extrude_param["R"][0:3]) 449 | y_axis = parse3d_sample(extrude_param["R"][3:6]) 450 | z_axis = parse3d_sample(extrude_param["R"][6:9]) 451 | fh.write("T_origin " + p_orig) 452 | fh.write("\n") 453 | fh.write("T_xaxis " + x_axis) 454 | fh.write("\n")
455 | fh.write("T_yaxis " + y_axis) 456 | fh.write("\n") 457 | fh.write("T_zaxis " + z_axis) 458 | 459 | 460 | def parse3d_sample(point3d): 461 | x = point3d[0] 462 | y = point3d[1] 463 | z = point3d[2] 464 | return str(x) + " " + str(y) + " " + str(z) 465 | 466 | 467 | if __name__ == "__main__": 468 | parser = argparse.ArgumentParser() 469 | parser.add_argument("--in_path", type=str, required=True) 470 | parser.add_argument("--out_path", type=str, required=True) 471 | parser.add_argument("--type", type=str, required=True, choices=["original", "edit", "infill", "gpt", "direct"]) 472 | args = parser.parse_args() 473 | 474 | with open(args.in_path, 'r') as file: 475 | content = file.read().strip() 476 | try: 477 | data = json.loads(content) 478 | except json.JSONDecodeError: 479 | try: 480 | data = [] 481 | for line in content.split('\n'): 482 | line = line.strip() 483 | if line: 484 | data.append(json.loads(line)) 485 | except json.JSONDecodeError: 486 | raise ValueError(f"Failed to parse JSON from {args.in_path}. Please check the file format.") 487 | 488 | num_valid_str = 0 489 | for idx, item in enumerate(data): 490 | try: 491 | cad_parser = CADparser(bit=6) 492 | if args.type == "original": 493 | parsed_data = cad_parser.perform(item["original_sequence"]) 494 | elif args.type == "edit": 495 | parsed_data = cad_parser.perform(item["edited_sequence"]) 496 | elif args.type == "infill": 497 | parsed_data = cad_parser.perform(item["output_infill"]) 498 | elif args.type == "gpt": 499 | parsed_data = cad_parser.perform(item["output"]) 500 | elif args.type == "direct": 501 | parsed_data = cad_parser.perform(item["output_sequence"]) 502 | 503 | out_path = os.path.join(args.out_path, str(idx).zfill(5) + '_' + str(item['original_pic_name']) + '_' + str(item['edited_pic_name'])) 504 | #out_path = os.path.join(args.out_path, str(idx).zfill(5)) 505 | 506 | os.makedirs(out_path, exist_ok=True) 507 | if parsed_data is not None: 508 | num_valid_str += 1 509 | write_obj_sample(out_path, parsed_data) 510 | except Exception as e: 511 | print(f"Error processing item {idx}: {e}") 512 | continue # Skip this item if an error occurs 513 | 514 | 515 | print(f"Number of valid CAD strings: {num_valid_str}/{len(data)}") 516 | --------------------------------------------------------------------------------
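For quick reference, a minimal standalone sketch of the dequantization step used by the reconverter above. It copies the arithmetic of dequantize_verts verbatim; the sketch range of [-1.0, 1.0] is only an illustrative assumption standing in for the SKETCH_R constant, which is defined elsewhere in the repository.

import numpy as np

def dequantize_verts(verts, n_bits=8, min_range=-0.5, max_range=0.5, add_noise=False):
    # Map integer codes in [0, 2**n_bits - 1] linearly onto [min_range, max_range]
    range_quantize = 2 ** n_bits - 1
    verts = verts.astype("float32")
    return verts * (max_range - min_range) / range_quantize + min_range

# With bit=6 (as used by CADparser above) there are 64 codes; assuming a range of [-1, 1]:
# code 0 -> -1.0, code 63 -> 1.0, code 32 -> ~0.016
print(dequantize_verts(np.array([0, 32, 63]), n_bits=6, min_range=-1.0, max_range=1.0))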