├── miche ├── __init__.py ├── michelangelo │ ├── __init__.py │ ├── graphics │ │ ├── __init__.py │ │ └── primitives │ │ │ ├── __init__.py │ │ │ └── volume.py │ ├── models │ │ ├── __init__.py │ │ ├── tsal │ │ │ ├── __init__.py │ │ │ ├── inference_utils.py │ │ │ ├── clip_asl_module.py │ │ │ ├── tsal_base.py │ │ │ ├── loss.py │ │ │ └── asl_pl_module.py │ │ └── modules │ │ │ ├── __init__.py │ │ │ ├── checkpoint.py │ │ │ ├── distributions.py │ │ │ ├── embedder.py │ │ │ └── transformer_blocks.py │ └── utils │ │ ├── __init__.py │ │ └── misc.py ├── shapevae-256.yaml └── encode.py ├── model ├── __init__.py ├── miche_conditioner.py └── data_provider_infer.py ├── silkutils ├── __init__.py ├── meto │ ├── __init__.py │ ├── mathutils.py │ ├── ss_engine.py │ ├── decode_utils.py │ └── decode_utils_fix.py ├── meshdata │ ├── __init__.py │ ├── mesh_color.py │ └── mesh_graph.py ├── dataset_clean │ ├── step3_cleanfix.py │ ├── step2_clean.py │ ├── step5_sample.py │ ├── process_dataset_fix.py │ ├── process_one.py │ ├── step4_datafilter.py │ └── process_dataset.py ├── silksong_tokenization.py ├── silksong_manifold_process.py └── ss_platform.py ├── assets └── teaser_mid_compress.png ├── acc_configs ├── gpu1.yaml ├── gpu8.yaml ├── gpu32.yaml └── gpu16.yaml ├── slurm_jobs └── infer_silksong_obj.sh ├── scripts ├── infer_silksong_obj.sh ├── train_silksong_scratch_gpu16.sh └── train_silksong_ft_gpu16.sh ├── LICENSE ├── .gitignore ├── nonmani_process.md ├── config └── options.py ├── requirements.txt ├── dataset_clean.md ├── README.md ├── infer.py └── train.py /miche/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /silkutils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /silkutils/meto/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /silkutils/meshdata/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /miche/michelangelo/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /miche/michelangelo/graphics/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /miche/michelangelo/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /miche/michelangelo/models/tsal/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /miche/michelangelo/utils/__init__.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .misc import instantiate_from_config 4 | -------------------------------------------------------------------------------- /assets/teaser_mid_compress.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gaochao-s/Mesh-Silksong/HEAD/assets/teaser_mid_compress.png -------------------------------------------------------------------------------- /miche/michelangelo/models/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .checkpoint import checkpoint 4 | -------------------------------------------------------------------------------- /miche/michelangelo/graphics/primitives/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .volume import generate_dense_grid_points 4 | 5 | -------------------------------------------------------------------------------- /acc_configs/gpu1.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: 'NO' 4 | downcast_bf16: 'no' 5 | machine_rank: 0 6 | main_training_function: main 7 | mixed_precision: bf16 8 | num_machines: 1 9 | num_processes: 1 10 | rdzv_backend: static 11 | same_network: true 12 | tpu_env: [] 13 | tpu_use_cluster: false 14 | tpu_use_sudo: false 15 | use_cpu: false 16 | -------------------------------------------------------------------------------- /acc_configs/gpu8.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | machine_rank: 0 6 | main_training_function: main 7 | mixed_precision: bf16 8 | num_machines: 1 9 | num_processes: 8 10 | rdzv_backend: static 11 | same_network: true 12 | tpu_env: [] 13 | tpu_use_cluster: false 14 | tpu_use_sudo: false 15 | use_cpu: false 16 | -------------------------------------------------------------------------------- /acc_configs/gpu32.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | machine_rank: 0 6 | main_training_function: main 7 | mixed_precision: bf16 8 | num_machines: 4 9 | num_processes: 32 10 | rdzv_backend: static 11 | rdzv_endpoint: :29500 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /acc_configs/gpu16.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | machine_rank: 0 6 | main_training_function: main 7 | mixed_precision: bf16 8 | num_machines: 2 9 | num_processes: 16 10 | rdzv_backend: static 11 | same_network: true 12 | main_process_ip: 127.0.0.1 13 | main_process_port: 29500 14 | tpu_env: [] 15 | tpu_use_cluster: false 16 | tpu_use_sudo: false 17 | use_cpu: false 18 | -------------------------------------------------------------------------------- /slurm_jobs/infer_silksong_obj.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=infer_silk # job name 3 | #SBATCH --output=main_workspace/job_logs/infer/infer_%j.log # logs 4 | #SBATCH --nodes=1 # number of nodes 5 | #SBATCH --partition=gpu # partition 6 | #SBATCH --ntasks=1 # number of tasks 7 | #SBATCH --cpus-per-task=8 # CPU cores per task 8 | #SBATCH --time=48:00:00 # time limit 9 | #SBATCH --mem=64G # Memory 10 | #SBATCH --gres=gpu:1 # number of GPUs 11 | 12 | 13 | source /public/opt/conda/etc/profile.d/conda.sh 14 | conda activate silk 15 | export PATH="/public/home/group_gaosh/gaochao/.conda/envs/silk/bin:$PATH" 16 | cd /public/home/group_gaosh/gaochao/main_workspace/MeshSilksong 17 | 18 | sh scripts/infer_silksong_obj.sh 19 | -------------------------------------------------------------------------------- /scripts/infer_silksong_obj.sh: -------------------------------------------------------------------------------- 1 | 2 | # source /public/opt/conda/etc/profile.d/conda.sh 3 | # conda activate silk 4 | # export PATH="/public/home/group_gaosh/gaochao/.conda/envs/silk/bin:$PATH" 5 | # cd main_workspace/MeshSilksong 6 | 7 | # export CUDA_VISIBLE_DEVICES=4 8 | 9 | MSL=10240 10 | REPEAT=1 11 | INFER_BATCH=4 # an 80G H800 is OK; set it smaller if GPU memory < 80G 12 | TEMPERATURE=0.5 13 | MAX_FILTER=0 14 | WORKSPACE="workspace_infer/silksong_output_test" 15 | TEST_INPUT="datasets/sample_test/meshes/test_mix_origin/batch_00" 16 | RESUME="checkpoints/release-100K/model.safetensors" 17 | 18 | 19 | python infer.py \ 20 | --workspace $WORKSPACE \ 21 | --train.resume $RESUME \ 22 | --infer.test_path_input $TEST_INPUT \ 23 | --max_seq_length $MSL \ 24 | --infer.test_repeat $REPEAT \ 25 | --infer.infer_batch $INFER_BATCH \ 26 | --infer.temperature $TEMPERATURE \ 27 | --infer.max_filter $MAX_FILTER 28 | 29 | 30 | -------------------------------------------------------------------------------- /miche/michelangelo/graphics/primitives/volume.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | 5 | # produce dense points 6 | def generate_dense_grid_points(bbox_min: np.ndarray, 7 | bbox_max: np.ndarray, 8 | octree_depth: int, 9 | indexing: str = "ij"): 10 | length = bbox_max - bbox_min 11 | num_cells = np.exp2(octree_depth) 12 | x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) 13 | y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) 14 | z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) 15 | [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) 16 | xyz = np.stack((xs, ys, zs), axis=-1) 17 | xyz = xyz.reshape(-1, 3) 18 | grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] 19 | 20 | return xyz, grid_size, length 21 | 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 gaochao-s 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject
to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /silkutils/meshdata/mesh_color.py: -------------------------------------------------------------------------------- 1 | 2 | def get_layer_rgb_color(index): 3 | 4 | colors =[ 5 | (255, 0, 0), # r 6 | (0, 255, 0), # g 7 | (0, 0, 255), # b 8 | (255, 255, 255), # white 9 | (0, 0, 0), # black 10 | ] 11 | 12 | if index==0: 13 | return colors[3] 14 | return colors[(index - 1) % 3] 15 | 16 | def get_distinct_rgb_color(index): 17 | if not index: 18 | return (0, 0, 0) 19 | if index=='B': 20 | return (0, 0, 0) 21 | if index=='W': 22 | return (255,255,255) 23 | if index=='bl': 24 | return (135, 206, 250) 25 | if index=='R': 26 | return (255, 0, 0) 27 | colors = [ 28 | (135, 206, 250), # light sky blue 29 | (144, 238, 144), # light green 30 | (255, 182, 193), # light pink 31 | (210, 180, 140), # tan 32 | (218, 112, 214), # orchid 33 | (255, 215, 0), # gold 34 | (255, 99, 71), # tomato 35 | (240, 128, 128), # light coral 36 | (173, 216, 230), # pale blue 37 | (152, 251, 152), # pale green 38 | 39 | ] 40 | 41 | return colors[(index - 1) % len(colors)] -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | *.so 5 | .Python 6 | build/ 7 | develop-eggs/ 8 | dist/ 9 | downloads/ 10 | eggs/ 11 | .eggs/ 12 | lib/ 13 | lib64/ 14 | parts/ 15 | sdist/ 16 | var/ 17 | wheels/ 18 | *.egg-info/ 19 | .installed.cfg 20 | *.egg 21 | 22 | .vscode/ 23 | checkpoints/ 24 | config/__pycache__ 25 | datasets/ 26 | miche/shapevae-256.ckpt 27 | miche/__pycache__ 28 | model/__pycache__ 29 | scripts/train_silksong_ft_gpu16_back.sh 30 | scripts/train_silksong_c_gpu16_back.sh 31 | scripts/infer_silksong_obj_back.sh 32 | slurm_jobs/infer_silksong_obj_back.sh 33 | silkutils/__pycache__ 34 | silkutils/dataset_clean/__pycache__ 35 | silkutils/dataset_clean/blender_env 36 | silkutils/dataset_clean/decompress_trellis.py 37 | silkutils/dataset_clean/gobjaverse_280k_index_to_objaverse.json 38 | silkutils/dataset_clean/step5_datafilter_testset.py 39 | silkutils/dataset_clean/vis_statistic.py 40 | silkutils/demo_test/manifold_repair 41 | silkutils/meshdata/__pycache__ 42 | silkutils/meto/__pycache__ 43 | wandb/ 44 | workspace_infer/ 45 | workspace_train/ 46 | debug_infer.py 47 | debug_wandb.py 48 | learn.py 49 | run_docker.txt 50 | train_back.py 51 | infer_back.py 52 | vis_wandb.ipynb 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /nonmani_process.md: -------------------------------------------------------------------------------- 1 | # Nonmanifold Mesh Processing Guidance 2 | 3 | ## 1.
Demo 4 | #### How to Run 5 | Run the code following this template: 6 | ``` 7 | python silkutils/silksong_manifold_process.py \ 8 | --input_file silkutils/demo_test/shapenetv2_03761084_ee5861.obj \ 9 | --quant_resolution 1024 \ 10 | --output_dir silkutils/demo_test/manifold_repair \ 11 | --verbose 12 | ``` 13 | 14 | #### Output Illustration 15 | Based on the input mesh name, the output files have different prefixes: 16 | - `M1_`: Load and normalize the mesh. 17 | - `M2_`: Process the non-manifold mesh into a manifold mesh, and color each connected component with a distinctive color. 18 | - `M2f_`: Further fix face orientation consistency (if possible). 19 | 20 | 21 | ## 2. Limitations 22 | 1. The code is currently pure Python, with no engineering acceleration for meshes with many faces, so it may stall on dense meshes. 23 | 24 | 2. The non-manifold processing is coupled with mesh vertex quantization so as to merge as many redundant faces as possible; if you do not want this, you can set a higher quantization resolution as a trade-off. 25 | 26 | 3. The current algorithm may produce face topology like a "Möbius loop", which can hinder the repair of face orientation consistency. -------------------------------------------------------------------------------- /miche/shapevae-256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: miche.michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule 3 | params: 4 | shape_module_cfg: 5 | target: miche.michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver 6 | params: 7 | num_latents: 256 8 | embed_dim: 64 9 | point_feats: 3 # normal 10 | num_freqs: 8 11 | include_pi: false 12 | heads: 12 13 | width: 768 14 | num_encoder_layers: 8 15 | num_decoder_layers: 16 16 | use_ln_post: true 17 | init_scale: 0.25 18 | qkv_bias: false 19 | use_checkpoint: true 20 | aligned_module_cfg: 21 | target: miche.michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule 22 | params: 23 | clip_model_version: "./checkpoints/clip/clip-vit-large-patch14" 24 | 25 | loss_cfg: 26 | target: miche.michelangelo.models.tsal.loss.ContrastKLNearFar 27 | params: 28 | contrast_weight: 0.1 29 | near_weight: 0.1 30 | kl_weight: 0.001 31 | 32 | optimizer_cfg: 33 | optimizer: 34 | target: torch.optim.AdamW 35 | params: 36 | betas: [0.9, 0.99] 37 | eps: 1.e-6 38 | weight_decay: 1.e-2 39 | 40 | scheduler: 41 | target: miche.michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler 42 | params: 43 | warm_up_steps: 5000 44 | f_start: 1.e-6 45 | f_min: 1.e-3 46 | f_max: 1.0 47 | -------------------------------------------------------------------------------- /scripts/train_silksong_scratch_gpu16.sh: -------------------------------------------------------------------------------- 1 | # set the right conda env path for your setup (the following are examples on a cluster) 2 | source /public/opt/conda/etc/profile.d/conda.sh 3 | conda activate silk 4 | export PATH="/public/home/group_gaosh/gaochao/.conda/envs/silk/bin:$PATH" 5 | cd /public/home/group_gaosh/gaochao/main_workspace/MeshSilksong 6 | # skip SSL verification for wandb 7 | export CURL_SSL_NO_VERIFY=1 8 | export REQUESTS_CA_BUNDLE="" 9 | # set your wandb api key 10 | export WANDB_API_KEY="your wandb api key" 11 | wandb login --relogin your wandb api key 12 | 13 | ########## 14 | # comment out the above lines if you do not use slurm on a cluster; you may activate your conda env and log in to wandb manually. a hedged multi-node launch example follows below.
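# (editor's note, hedged example; not part of the original script) with the
# MACHINE_RANK / MASTER_ADDR parameters defined below, a two-node run launches
# this script once per node, each with its own rank; the IP here is hypothetical:
#   MACHINE_RANK=0 MASTER_ADDR=10.0.0.1 bash scripts/train_silksong_scratch_gpu16.sh   # on node 0
#   MACHINE_RANK=1 MASTER_ADDR=10.0.0.1 bash scripts/train_silksong_scratch_gpu16.sh   # on node 1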
15 | ########## 16 | 17 | # specify these params for multi-node cluster training; the default params are for single-node cluster training or single-machine training 18 | MACHINE_RANK=${MACHINE_RANK:-0} 19 | MASTER_ADDR=${MASTER_ADDR:-"localhost"} 20 | 21 | # the params for main training script 22 | MSL=10240 23 | RESAMPLE=1 24 | WARMUP=0.05 25 | LR=0.0001 26 | BATCH_SIZE=2 27 | FT=0 28 | I_BETA=0.0 29 | WORKSPACE="workspace_train/silk_scratch_multi16" 30 | NUM_EPOCH=200 31 | SAVE_EPOCH=1 32 | DATASET="ss" 33 | XLSX_DIR="datasets/cleaned" 34 | DATA_SUBSETS="gobjaversev1*3dfuture*toys4k*shapenetv2" # use * to separate the datasets you want to train on 35 | DATA_FILTER=11 36 | EVAL_MODE="loss" 37 | 38 | # modify --config_file if you want to use a different number of GPUs 39 | accelerate launch \ 40 | --config_file acc_configs/gpu16.yaml \ 41 | --machine_rank $MACHINE_RANK \ 42 | --main_process_ip $MASTER_ADDR \ 43 | train.py \ 44 | --workspace $WORKSPACE \ 45 | --data.resample $RESAMPLE \ 46 | --data.dataset $DATASET \ 47 | --data.i_beta $I_BETA \ 48 | --train.ft $FT \ 49 | --train.warmup_ratio $WARMUP \ 50 | --data.batch_size $BATCH_SIZE \ 51 | --train.lr $LR \ 52 | --max_seq_length $MSL \ 53 | --data.xlsx_dir $XLSX_DIR \ 54 | --data.data_subsets $DATA_SUBSETS \ 55 | --data.data_filter_cnt $DATA_FILTER \ 56 | --train.num_epochs $NUM_EPOCH \ 57 | --train.eval_mode $EVAL_MODE \ 58 | --train.save_epoch $SAVE_EPOCH 59 | -------------------------------------------------------------------------------- /scripts/train_silksong_ft_gpu16.sh: -------------------------------------------------------------------------------- 1 | # set the right conda env path for your setup (the following are examples on a cluster) 2 | source /public/opt/conda/etc/profile.d/conda.sh 3 | conda activate silk 4 | export PATH="/public/home/group_gaosh/gaochao/.conda/envs/silk/bin:$PATH" 5 | cd /public/home/group_gaosh/gaochao/main_workspace/MeshSilksong 6 | # skip SSL verification for wandb 7 | export CURL_SSL_NO_VERIFY=1 8 | export REQUESTS_CA_BUNDLE="" 9 | # set your wandb api key 10 | export WANDB_API_KEY="your wandb api key" 11 | wandb login --relogin your wandb api key 12 | 13 | ########## 14 | # comment out the above lines if you do not use slurm on a cluster; you may activate your conda env and log in to wandb manually.
15 | ########## 16 | 17 | # specify these params for multi-node cluster training; the default params are for single-node cluster training or single-machine training 18 | MACHINE_RANK=${MACHINE_RANK:-0} 19 | MASTER_ADDR=${MASTER_ADDR:-"localhost"} 20 | 21 | # the params for main training script 22 | MSL=10240 23 | RESAMPLE=1 24 | WARMUP=0.05 25 | LR=0.0001 26 | BATCH_SIZE=2 27 | FT=1 28 | I_BETA=0.0 29 | WORKSPACE="workspace_train/silk_ft_multi16" 30 | RESUME="checkpoints/release-100K/model.safetensors" 31 | NUM_EPOCH=200 32 | SAVE_EPOCH=1 33 | DATASET="ss" 34 | XLSX_DIR="datasets/cleaned" 35 | DATA_SUBSETS="gobjaversev1*3dfuture*toys4k*shapenetv2" # use * to separate the datasets you want to train on 36 | DATA_FILTER=11 37 | EVAL_MODE="loss" 38 | 39 | # modify --config_file if you want to use a different number of GPUs 40 | accelerate launch \ 41 | --config_file acc_configs/gpu16.yaml \ 42 | --machine_rank $MACHINE_RANK \ 43 | --main_process_ip $MASTER_ADDR \ 44 | train.py \ 45 | --workspace $WORKSPACE \ 46 | --data.resample $RESAMPLE \ 47 | --data.dataset $DATASET \ 48 | --data.i_beta $I_BETA \ 49 | --train.ft $FT \ 50 | --train.warmup_ratio $WARMUP \ 51 | --data.batch_size $BATCH_SIZE \ 52 | --train.lr $LR \ 53 | --train.resume $RESUME \ 54 | --max_seq_length $MSL \ 55 | --data.xlsx_dir $XLSX_DIR \ 56 | --data.data_subsets $DATA_SUBSETS \ 57 | --data.data_filter_cnt $DATA_FILTER \ 58 | --train.num_epochs $NUM_EPOCH \ 59 | --train.eval_mode $EVAL_MODE \ 60 | --train.save_epoch $SAVE_EPOCH 61 | -------------------------------------------------------------------------------- /silkutils/dataset_clean/step3_cleanfix.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from ss_platform import get_base_dir_platform 3 | import re 4 | import os 5 | 6 | def find_satisfied_files(dataset_name_in,reso_in,version_in, part_min, part_max): 7 | directory=get_base_dir_platform(dataset_name_in) 8 | pattern = re.compile(r'meta_all_(?P<dataset_name>\w+)_res(?P<reso>\d+)_v(?P<version>\d{2})_p(?P<part_ind>\d{4})_done_b(?P<b_i>\d{3})\.xlsx') 9 | 10 | satisfied_filesdic=[] 11 | 12 | for filename in os.listdir(directory): 13 | match = pattern.match(filename) 14 | if match: 15 | dataset_name = match.group('dataset_name') 16 | reso = int(match.group('reso')) 17 | version = int(match.group('version')) 18 | part_ind = int(match.group('part_ind')) 19 | b_i = int(match.group('b_i')) 20 | if dataset_name_in != dataset_name or part_ind < part_min or part_ind >= part_max: 21 | continue 22 | if reso_in !=reso or version!=version_in: 23 | continue 24 | file_dic={ 25 | 'dataset_name': dataset_name, 26 | 'reso': reso, 27 | 'version': version, 28 | 'part_ind': part_ind, 29 | 'b_i': b_i, 30 | 'file': filename, 31 | } 32 | 33 | satisfied_filesdic.append(file_dic) 34 | return satisfied_filesdic 35 | 36 | def get_command_fix(file_dic, max_workers): 37 | return f'python dataset_clean/process_dataset_fix.py --dataset_name {file_dic["dataset_name"]} --reso {file_dic["reso"]} --version {file_dic["version"]}\ 38 | --part_ind {file_dic["part_ind"]} --b_i {file_dic["b_i"]} --file {file_dic["file"]} --max_workers {max_workers}' 39 | 40 | if __name__ == "__main__": 41 | dataset_name_in='objaversev1' 42 | reso_in=128 43 | version_in=2 44 | part_min=50 45 | part_max=100 46 | max_workers=64 47 | filesdic_list=find_satisfied_files(dataset_name_in,reso_in,version_in, part_min, part_max) 48 | for ind, ele in enumerate(filesdic_list): 49 | print(f"[{ind}] find: {ele['file']}") 50 | 51 | for filedic in filesdic_list: 52 | command=get_command_fix(filedic, max_workers) 53 |
subprocess.run(command, shell=True) -------------------------------------------------------------------------------- /silkutils/meto/mathutils.py: -------------------------------------------------------------------------------- 1 | from math import comb 2 | 3 | def combination_to_index(combination, n): 4 | """ 5 | Given a combination, return its index. 6 | 7 | :param combination: a list of k elements representing the combination. 8 | :param n: element values range from 1 to n. 9 | :return: the index of the combination. 10 | """ 11 | k = len(combination) 12 | index = 0 13 | for i in range(k): 14 | element = combination[i] 15 | if i > 0: 16 | prev_element = combination[i - 1] 17 | else: 18 | prev_element = 0 19 | 20 | for j in range(prev_element + 1, element): 21 | index += comb(n - j, k - i - 1) 22 | 23 | return index 24 | 25 | def index_to_combination(index, k, n): 26 | """ 27 | Given an index, return the combination. 28 | 29 | :param index: the index of the combination. 30 | :param k: the number of elements in the combination. 31 | :param n: element values range from 1 to n. 32 | :return: the corresponding combination. 33 | """ 34 | combination = [] 35 | current_index = index 36 | start = 1 37 | 38 | for i in range(k): 39 | for j in range(start, n + 1): 40 | count = comb(n - j, k - i - 1) 41 | if current_index < count: 42 | combination.append(j) 43 | start = j + 1 44 | break 45 | current_index -= count 46 | 47 | return combination 48 | 49 | def generate_combination_mappings(n, k): 50 | """ 51 | Generate the mappings between combinations and indices. 52 | 53 | :param n: element values range from 1 to n. 54 | :param k: the number of elements in the combination. 55 | :return: (dict from combination to index, list from index to combination) 56 | """ 57 | def generate_combinations(start, k, n, current_combination, all_combinations): 58 | if k == 0: 59 | all_combinations.append(list(current_combination)) 60 | return 61 | for i in range(start, n + 1): 62 | current_combination.append(i) 63 | generate_combinations(i + 1, k - 1, n, current_combination, all_combinations) 64 | current_combination.pop() 65 | 66 | all_combinations = [] 67 | generate_combinations(1, k, n, [], all_combinations) 68 | 69 | combination_to_index_map = {} 70 | index_to_combination_map = [] 71 | 72 | for combination in all_combinations: 73 | index = combination_to_index(combination, n) 74 | combination_to_index_map[tuple(combination)] = index 75 | index_to_combination_map.append(combination) 76 | 77 | return combination_to_index_map, index_to_combination_map 78 | -------------------------------------------------------------------------------- /miche/michelangelo/utils/misc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import importlib 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | 9 | 10 | def get_obj_from_str(string, reload=False): 11 | module, cls = string.rsplit(".", 1) 12 | if reload: 13 | module_imp = importlib.import_module(module) 14 | importlib.reload(module_imp) 15 | return getattr(importlib.import_module(module, package=None), cls) 16 | 17 | 18 | def get_obj_from_config(config): 19 | if "target" not in config: 20 | raise KeyError("Expected key `target` to instantiate.") 21 | 22 | return get_obj_from_str(config["target"]) 23 | 24 | 25 | def instantiate_from_config(config, **kwargs): 26 | if "target" not in config: 27 | raise KeyError("Expected key `target` to instantiate.") 28 | 29 | cls = get_obj_from_str(config["target"]) 30 | 31 | params = config.get("params", dict()) 32 | # params.update(kwargs) 33 | # instance = cls(**params) 34 | kwargs.update(params) 35 | instance = cls(**kwargs) 36 | 37 | return instance 38 | 39 | 40 | def is_dist_avail_and_initialized(): 41 | if not dist.is_available(): 42 | return False 43 | if not dist.is_initialized(): 44 | return False 45 | return True 46 | 47 | 48 |
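# (editor's note) a minimal, hedged usage sketch of instantiate_from_config above;
# the config mirrors the {"target": ..., "params": ...} layout of miche/shapevae-256.yaml,
# while the torch.nn.Linear target is illustrative only, not from this repo:
#
#   cfg = {"target": "torch.nn.Linear", "params": {"in_features": 8, "out_features": 4}}
#   layer = instantiate_from_config(cfg)  # imports torch.nn, returns nn.Linear(8, 4)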
def get_rank(): 49 | if not is_dist_avail_and_initialized(): 50 | return 0 51 | return dist.get_rank() 52 | 53 | 54 | def get_world_size(): 55 | if not is_dist_avail_and_initialized(): 56 | return 1 57 | return dist.get_world_size() 58 | 59 | 60 | def all_gather_batch(tensors): 61 | """ 62 | Performs all_gather operation on the provided tensors. 63 | """ 64 | # Queue the gathered tensors 65 | world_size = get_world_size() 66 | # There is no need for reduction in the single-proc case 67 | if world_size == 1: 68 | return tensors 69 | tensor_list = [] 70 | output_tensor = [] 71 | for tensor in tensors: 72 | tensor_all = [torch.ones_like(tensor) for _ in range(world_size)] 73 | dist.all_gather( 74 | tensor_all, 75 | tensor, 76 | async_op=False # performance opt 77 | ) 78 | 79 | tensor_list.append(tensor_all) 80 | 81 | for tensor_all in tensor_list: 82 | output_tensor.append(torch.cat(tensor_all, dim=0)) 83 | return output_tensor 84 | -------------------------------------------------------------------------------- /miche/michelangelo/models/modules/checkpoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from typing import Callable, Iterable, Sequence, Union 5 | 6 | 7 | def checkpoint( 8 | func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]], 9 | inputs: Sequence[torch.Tensor], 10 | params: Iterable[torch.Tensor], 11 | flag: bool, 12 | use_deepspeed: bool = False 13 | ): 14 | # Evaluate a function without caching intermediate activations, allowing for 15 | # reduced memory at the expense of extra compute in the backward pass. 16 | # :param func: the function to evaluate. 17 | # :param inputs: the argument sequence to pass to `func`. 18 | # :param params: a sequence of parameters `func` depends on but does not 19 | # explicitly take as arguments. 20 | # :param flag: if False, disable gradient checkpointing. 21 | # :param use_deepspeed: if True, use deepspeed 22 | if flag: 23 | if use_deepspeed: 24 | import deepspeed 25 | return deepspeed.checkpointing.checkpoint(func, *inputs) 26 | 27 | args = tuple(inputs) + tuple(params) 28 | return CheckpointFunction.apply(func, len(inputs), *args) 29 | else: 30 | return func(*inputs) 31 | 32 | 33 | class CheckpointFunction(torch.autograd.Function): 34 | @staticmethod 35 | @torch.cuda.amp.custom_fwd 36 | def forward(ctx, run_function, length, *args): 37 | ctx.run_function = run_function 38 | ctx.input_tensors = list(args[:length]) 39 | ctx.input_params = list(args[length:]) 40 | 41 | with torch.no_grad(): 42 | output_tensors = ctx.run_function(*ctx.input_tensors) 43 | return output_tensors 44 | 45 | @staticmethod 46 | @torch.cuda.amp.custom_bwd 47 | def backward(ctx, *output_grads): 48 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 49 | with torch.enable_grad(): 50 | # Fixes a bug where the first op in run_function modifies the 51 | # Tensor storage in place, which is not allowed for detach()'d 52 | # Tensors. 
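# (editor's note) x.view_as(x) returns a shallow copy that shares x's storage but
# is a distinct Tensor object, so run_function may write in place without touching
# the detached inputs; gradients below are still taken w.r.t. ctx.input_tensors.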
53 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 54 | output_tensors = ctx.run_function(*shallow_copies) 55 | input_grads = torch.autograd.grad( 56 | output_tensors, 57 | ctx.input_tensors + ctx.input_params, 58 | output_grads, 59 | allow_unused=True, 60 | ) 61 | del ctx.input_tensors 62 | del ctx.input_params 63 | del output_tensors 64 | return (None, None) + input_grads 65 | -------------------------------------------------------------------------------- /silkutils/dataset_clean/step2_clean.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | def get_command(dataset_name, part_ind, batch_ind, max_workers, resolution, version, head): 4 | command_template=f'python dataset_clean/process_dataset.py --dataset_name {dataset_name} --part_ind {part_ind} --batch_ind {batch_ind} --max_workers {max_workers} --b_l 1000 --resolution {resolution} --version {version} --head {head}' 5 | return command_template 6 | 7 | def part_batch_mapping(dataset_name): 8 | # the start and end indices of the initial xlsx files; you may modify them for partial processing 9 | map_part={ 10 | '3dfuture': (0, 9), # for example, xlsx file part indices 0-8; adjust based on your situation 11 | 'objaverse': (0, 40), # 0-157 12 | 'gobjaversev1': (40, 53), # 0-52 13 | 'abo': (0, 36), 14 | 'toys4k': (0, 1), 15 | 'thingi10k': (0, 5), 16 | 'shapenetv2': (30, 55), # 0-54 17 | 'animal3d': (0, 2), 18 | 'gso': (0, 1), 19 | 'buildingnet': (0, 1), 20 | 'trellis-hssd': (0, 4), 21 | 'trellis-abo': (0, 36), 22 | 'trellis-3dfuture': (0, 5), 23 | 'trellis-objxl-sketchfab': (120, 160), # 0-159 24 | 'trellis-objxl-github': (50, 61) # 0-60 61 end 25 | } 26 | # process files in batches of 1000 because of ulimit on linux. Just keep this > the xlsx's mesh count / 1000 27 | map_batch={ 28 | '3dfuture': 10, 29 | 'objaverse': 8, 30 | 'gobjaversev1': 8, 31 | 'abo': 5, 32 | 'toys4k': 5, 33 | 'thingi10k': 5, 34 | 'shapenetv2': 10, 35 | 'animal3d': 10, 36 | 'gso': 10, 37 | 'buildingnet': 10, 38 | 'trellis-hssd': 5, 39 | 'trellis-abo': 5, 40 | 'trellis-3dfuture': 5, 41 | 'trellis-objxl-sketchfab': 5, 42 | 'trellis-objxl-github': 8, 43 | } 44 | return map_part[dataset_name][0], map_part[dataset_name][1], map_batch[dataset_name] 45 | 46 | 47 | def clean_datasetname(dataset_name, resolution, version, head=-1, max_workers=64): 48 | part_num0, part_num1, batch_num=part_batch_mapping(dataset_name=dataset_name) 49 | for i in range(part_num0, part_num1): 50 | for j in range(batch_num): 51 | command=get_command(dataset_name=dataset_name, part_ind=i, batch_ind=j, max_workers=max_workers, resolution=resolution, version=version, head=head) 52 | print(f'command: {command}') 53 | subprocess.run(command, shell=True) 54 | 55 | if __name__ == "__main__": 56 | datasetname='gobjaversev1' 57 | resolution=128 58 | version=4 59 | head=-1 60 | max_workers=64 61 | clean_datasetname(dataset_name=datasetname, resolution=resolution, version=version, head=head, max_workers=max_workers) 62 | 63 | -------------------------------------------------------------------------------- /miche/michelangelo/models/tsal/inference_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from tqdm import tqdm 5 | from einops import repeat 6 | import numpy as np 7 | from typing import Callable, Tuple, List, Union, Optional 8 | from skimage import measure 9 | 10 | from miche.michelangelo.graphics.primitives import generate_dense_grid_points 11 | 12 | 13 |
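# (editor's sketch, hedged) extract_geometry below evaluates `geometric_func` over a
# dense (2^octree_depth + 1)^3 grid in chunks, then runs marching cubes per batch
# element. Mirroring the commented-out usage in miche/encode.py, a call might look like:
#
#   from functools import partial
#   geometric_func = partial(model.model.shape_model.query_geometry, latents=latents)
#   mesh_v_f, has_surface = extract_geometry(geometric_func, device=latents.device)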
@torch.no_grad() 14 | def extract_geometry(geometric_func: Callable, 15 | device: torch.device, 16 | batch_size: int = 1, 17 | bounds: Union[Tuple[float], List[float], float] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), 18 | octree_depth: int = 7, 19 | num_chunks: int = 10000, 20 | disable: bool = True): 21 | 22 | # Args: 23 | # geometric_func: callable mapping [B, N, 3] query points to logits 24 | # device: torch device the chunked queries are moved to 25 | # bounds: scalar or 6-tuple axis-aligned bounding box 26 | # octree_depth: grid resolution exponent (2^depth cells per axis) 27 | # batch_size: number of shapes evaluated in parallel 28 | # num_chunks: number of query points per forward chunk 29 | # disable: disable the tqdm progress bar 30 | # Returns: list of (vertices, faces) per shape, and a has_surface mask 31 | 32 | if isinstance(bounds, float): 33 | bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] 34 | 35 | bbox_min = np.array(bounds[0:3]) 36 | bbox_max = np.array(bounds[3:6]) 37 | bbox_size = bbox_max - bbox_min 38 | 39 | xyz_samples, grid_size, length = generate_dense_grid_points( 40 | bbox_min=bbox_min, 41 | bbox_max=bbox_max, 42 | octree_depth=octree_depth, 43 | indexing="ij" 44 | ) 45 | xyz_samples = torch.FloatTensor(xyz_samples) 46 | 47 | batch_logits = [] 48 | for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), 49 | desc="Implicit Function:", disable=disable, leave=False): 50 | queries = xyz_samples[start: start + num_chunks, :].to(device) 51 | batch_queries = repeat(queries, "p c -> b p c", b=batch_size) 52 | 53 | logits = geometric_func(batch_queries) 54 | batch_logits.append(logits.cpu()) 55 | 56 | grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).numpy() 57 | 58 | mesh_v_f = [] 59 | has_surface = np.zeros((batch_size,), dtype=np.bool_) 60 | for i in range(batch_size): 61 | try: 62 | vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner") 63 | vertices = vertices / grid_size * bbox_size + bbox_min 64 | # vertices[:, [0, 1]] = vertices[:, [1, 0]] 65 | mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces))) 66 | has_surface[i] = True 67 | 68 | except ValueError: 69 | mesh_v_f.append((None, None)) 70 | has_surface[i] = False 71 | 72 | except RuntimeError: 73 | mesh_v_f.append((None, None)) 74 | has_surface[i] = False 75 | 76 | return mesh_v_f, has_surface 77 | -------------------------------------------------------------------------------- /model/miche_conditioner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from beartype import beartype 4 | from miche.encode import load_model 5 | import os 6 | # helper functions 7 | 8 | def exists(val): 9 | return val is not None 10 | 11 | def default(*values): 12 | for value in values: 13 | if exists(value): 14 | return value 15 | return None 16 | 17 | 18 | # point-cloud encoder from Michelangelo 19 | @beartype 20 | class PointConditioner(torch.nn.Module): 21 | def __init__( 22 | self, 23 | *, 24 | dim_latent = None, 25 | model_name = 'miche-256-feature', 26 | cond_dim = 768, 27 | freeze = True, 28 | ): 29 | super().__init__() 30 | 31 | # open-source version of miche 32 | if model_name == 'miche-256-feature': 33 | ckpt_path = "miche/shapevae-256.ckpt" 34 | if not os.path.exists(ckpt_path): 35 | ckpt_path=None 36 | print('[WARNING] Michelangelo ckpt does not exist, please check if you are training') 37 | config_path = 'miche/shapevae-256.yaml' 38 | 39 | self.feature_dim = 1024 # embedding dimension 40 | self.cond_length = 257 # length of embedding 41 | self.point_encoder = load_model(ckpt_path=ckpt_path, config_path=config_path) 42 | 43 | # additional layers to connect miche and GPT 44 | self.cond_head_proj = nn.Linear(cond_dim, self.feature_dim) 45 | self.cond_proj =
nn.Linear(cond_dim, self.feature_dim) 46 | 47 | else: 48 | raise NotImplementedError 49 | 50 | # whether to finetune the point-cloud encoder 51 | if freeze: 52 | for parameter in self.point_encoder.parameters(): 53 | parameter.requires_grad = False 54 | 55 | self.freeze = freeze 56 | self.model_name = model_name 57 | self.dim_latent = default(dim_latent, self.feature_dim) 58 | 59 | self.register_buffer('_device_param', torch.tensor(0.), persistent = False) 60 | 61 | 62 | @property 63 | def device(self): 64 | return next(self.buffers()).device 65 | 66 | 67 | def embed_pc(self, pc_normal): 68 | # encode point cloud to embeddings 69 | if self.model_name == 'miche-256-feature': 70 | point_feature = self.point_encoder.encode_latents(pc_normal) 71 | pc_embed_head = self.cond_head_proj(point_feature[:, 0:1]) 72 | pc_embed = self.cond_proj(point_feature[:, 1:]) 73 | pc_embed = torch.cat([pc_embed_head, pc_embed], dim=1) 74 | 75 | return pc_embed 76 | 77 | 78 | def forward( 79 | self, 80 | pc = None, 81 | pc_embeds = None, 82 | ): 83 | if pc_embeds is None: 84 | pc_embeds = self.embed_pc(pc.to(next(self.buffers()).dtype)) 85 | 86 | assert not torch.any(torch.isnan(pc_embeds)), 'NAN values in pc embeddings' 87 | 88 | return pc_embeds 89 | 90 | -------------------------------------------------------------------------------- /miche/michelangelo/models/modules/distributions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import numpy as np 5 | from typing import Union, List 6 | 7 | 8 | class DiagonalGaussianDistribution(object): 9 | # Gaussian distribution 10 | def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1): 11 | self.feat_dim = feat_dim 12 | self.parameters = parameters 13 | 14 | if isinstance(parameters, list): 15 | self.mean = parameters[0] 16 | self.logvar = parameters[1] 17 | else: 18 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim) 19 | 20 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 21 | self.deterministic = deterministic 22 | self.std = torch.exp(0.5 * self.logvar) 23 | self.var = torch.exp(self.logvar) 24 | if self.deterministic: 25 | self.var = self.std = torch.zeros_like(self.mean) 26 | 27 | # sample from the gaussian distribution 28 | def sample(self): 29 | x = self.mean + self.std * torch.randn_like(self.mean) 30 | return x 31 | 32 | def kl(self, other=None, dims=(1, 2, 3)): 33 | if self.deterministic: 34 | return torch.Tensor([0.]) 35 | else: 36 | if other is None: 37 | return 0.5 * torch.mean(torch.pow(self.mean, 2) 38 | + self.var - 1.0 - self.logvar, 39 | dim=dims) 40 | else: 41 | return 0.5 * torch.mean( 42 | torch.pow(self.mean - other.mean, 2) / other.var 43 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 44 | dim=dims) 45 | 46 | def nll(self, sample, dims=(1, 2, 3)): 47 | if self.deterministic: 48 | return torch.Tensor([0.]) 49 | logtwopi = np.log(2.0 * np.pi) 50 | return 0.5 * torch.sum( 51 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 52 | dim=dims) 53 | 54 | def mode(self): 55 | return self.mean 56 | 57 | 58 | def normal_kl(mean1, logvar1, mean2, logvar2): 59 | # Compute the KL divergence between two gaussians. 60 | # Shapes are automatically broadcasted, so batches can be compared to 61 | # scalars, among other use cases.
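# (editor's note) closed form per element, matching the return expression below:
#   KL(N(m1, exp(logvar1)) || N(m2, exp(logvar2)))
#     = 0.5 * (logvar2 - logvar1 + exp(logvar1 - logvar2) + (m1 - m2)^2 * exp(-logvar2) - 1)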
62 | 63 | tensor = None 64 | for obj in (mean1, logvar1, mean2, logvar2): 65 | if isinstance(obj, torch.Tensor): 66 | tensor = obj 67 | break 68 | assert tensor is not None, "at least one argument must be a Tensor" 69 | 70 | # Force variances to be Tensors. Broadcasting helps convert scalars to 71 | # Tensors, but it does not work for torch.exp(). 72 | logvar1, logvar2 = [ 73 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 74 | for x in (logvar1, logvar2) 75 | ] 76 | 77 | return 0.5 * ( 78 | -1.0 79 | + logvar2 80 | - logvar1 81 | + torch.exp(logvar1 - logvar2) 82 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 83 | ) 84 | -------------------------------------------------------------------------------- /miche/encode.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import argparse 3 | from omegaconf import OmegaConf 4 | import numpy as np 5 | import torch 6 | from .michelangelo.utils.misc import instantiate_from_config 7 | 8 | def load_surface(fp): 9 | 10 | with np.load(fp) as input_pc: 11 | surface = input_pc['points'] 12 | normal = input_pc['normals'] 13 | 14 | rng = np.random.default_rng() 15 | ind = rng.choice(surface.shape[0], 4096, replace=False) 16 | surface = torch.FloatTensor(surface[ind]) 17 | normal = torch.FloatTensor(normal[ind]) 18 | 19 | surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).cuda() 20 | 21 | return surface 22 | 23 | def reconstruction(args, model, bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), octree_depth=7, num_chunks=10000): 24 | 25 | surface = load_surface(args.pointcloud_path) 26 | # old_surface = surface.clone() 27 | 28 | # surface[0,:,0]*=-1 29 | # surface[0,:,1]*=-1 30 | surface[0,:,2]*=-1 31 | 32 | # encoding 33 | shape_embed, shape_latents = model.model.encode_shape_embed(surface, return_latents=True) 34 | shape_zq, posterior = model.model.shape_model.encode_kl_embed(shape_latents) 35 | 36 | # decoding 37 | latents = model.model.shape_model.decode(shape_zq) 38 | # geometric_func = partial(model.model.shape_model.query_geometry, latents=latents) 39 | 40 | return 0 41 | 42 | def load_model(ckpt_path="miche/shapevae-256.ckpt", config_path="miche/shapevae-256.yaml"): 43 | 44 | model_config = OmegaConf.load(config_path) 45 | # print(model_config) 46 | if hasattr(model_config, "model"): 47 | model_config = model_config.model 48 | 49 | original_load = torch.load 50 | torch.load = lambda path, *args, **kwargs: original_load(path, map_location="cpu", weights_only=False) 51 | try: 52 | model = instantiate_from_config(model_config, ckpt_path=ckpt_path) 53 | model = model.eval() 54 | finally: 55 | # restore the original torch.load 56 | torch.load = original_load 57 | 58 | return model 59 | if __name__ == "__main__": 60 | ''' 61 | 1. Reconstruct point cloud 62 | 2. Image-conditioned generation 63 | 3.
Text-conditioned generation 64 | ''' 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument("--config_path", type=str, required=True) 67 | parser.add_argument("--ckpt_path", type=str, required=True) 68 | parser.add_argument("--pointcloud_path", type=str, default='./example_data/surface.npz', 69 | help='Path to the input point cloud') 70 | parser.add_argument("--image_path", type=str, help='Path to the input image') 71 | parser.add_argument("--text", type=str, 72 | help='Input text within a format: A 3D model of motorcar; Porsche 911.') 73 | parser.add_argument("--output_dir", type=str, default='./output') 74 | parser.add_argument("-s", "--seed", type=int, default=0) 75 | args = parser.parse_args() 76 | 77 | print(f'-----------------------------------------------------------------------------') 78 | print(f'>>> Output directory: {args.output_dir}') 79 | print(f'-----------------------------------------------------------------------------') 80 | 81 | reconstruction(args, load_model(ckpt_path=args.ckpt_path, config_path=args.config_path)) 82 | -------------------------------------------------------------------------------- /silkutils/silksong_tokenization.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 4 | import torch 5 | import numpy as np 6 | import trimesh 7 | from meto.ss_engine import Engine 8 | import logging 9 | 10 | 11 | def get_tokenizer_silksong(resolution=128, ss_mode=4, meta_init_data={}, structure_limit={}, debugging=False): 12 | 13 | tokenizer = Engine(discrete_bins=resolution, mode=ss_mode, debugging=debugging, meta_init_data=meta_init_data, structure_limit=structure_limit) 14 | vocab_size = tokenizer.num_word_table # C U E I0 I1 Out B O 15 | 16 | return tokenizer, vocab_size 17 | 18 | def tokenize_mesh_ss(tokenizer, vertices, faces, non_mani_process=True): 19 | tokens_ori = tokenizer.encode(vertices, faces, non_mani_process) 20 | tokens, meta_data = tokenizer.token_encode(tokens_ori, tokenizer.mode) 21 | return tokens, meta_data 22 | 23 | def detokenize_mesh_ss(tokenizer, tokens, colorful=False, mani_fix=False): 24 | vertices, faces = tokenizer.decode(tokens, tokenizer.discrete_bins, tokenizer.mode, colorful, mani_fix) 25 | return vertices, faces 26 | 27 | def quantize_num_faces_ss(n): 28 | # quantized face-count buckets: 29 | # 0: <=0, uncond 30 | # 1: 0-1000, low-poly; 2: 1000-2000, mid-poly 31 | # 3: 2000-4000, high-poly; 4: 4000-8000, ultra-poly 32 | # 5: 8000-12000; 6: 12000-16000; 7: >16000 33 | 34 | if isinstance(n, int): 35 | if n <= 0: 36 | return 0 37 | elif n <= 1000: 38 | return 1 39 | elif n <= 2000: 40 | return 2 41 | elif n <= 4000: 42 | return 3 43 | elif n <= 8000: 44 | return 4 45 | elif n <= 12000: 46 | return 5 47 | elif n <= 16000: 48 | return 6 49 | else: 50 | return 7 51 | else: # torch tensor 52 | results = torch.zeros_like(n) 53 | # results[n <= 0] = 0 54 | results[(n > 0) & (n <= 1000)] = 1 55 | results[(n > 1000) & (n <= 2000)] = 2 56 | results[(n > 2000) & (n <= 4000)] = 3 57 | results[(n > 4000) & (n <= 8000)] = 4 58 | results[(n > 8000) & (n <= 12000)] = 5 59 | results[(n > 12000) & (n <= 16000)] = 6 60 | results[n > 16000] = 7 61 | return results 62 | 63 | def quantize_num_CC_ss(n): 64 | # quantized connected-component (CC) count buckets: 65 | # 0: <=0, uncond 66 | # 1: exactly 1 CC 67 | # 2: 2-5; 3: 6-10; 4: 11-30 68 | # 5: 31-50; 6: 51-100; 7: >100 69 | 70 | if isinstance(n, int): 71 | if n <= 0: 72 | return 0 73 | elif n <= 1: 74 | return 1 75 | elif n <= 5: 76 | return 2 77 | elif n <= 10: 78 | return 3 79 | elif n <= 30: 80 | return 4 81 | elif n <=
50: 82 | return 5 83 | elif n <= 100: 84 | return 6 85 | else: 86 | return 7 87 | else: # torch tensor 88 | results = torch.zeros_like(n) 89 | # results[n <= 0] = 0 90 | results[(n > 0) & (n <= 1)] = 1 91 | results[(n > 1) & (n <= 5)] = 2 92 | results[(n > 5) & (n <= 10)] = 3 93 | results[(n > 10) & (n <= 30)] = 4 94 | results[(n > 30) & (n <= 50)] = 5 95 | results[(n > 50) & (n <= 100)] = 6 96 | results[n > 100] = 7 97 | return results -------------------------------------------------------------------------------- /silkutils/silksong_manifold_process.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 4 | from meshdata.mesh_io import load_mesh, write_obj 5 | from meshdata.mesh_structure import Mesh 6 | import argparse 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser( 10 | description='config for non-manifold mesh', 11 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 12 | ) 13 | 14 | parser.add_argument('--input_file', 15 | default='silkutils/demo_test/shapenetv2_03761084_ee5861.obj', 16 | type=str, 17 | help='input file path') 18 | 19 | 20 | parser.add_argument('--quant_resolution', 21 | default=1024, 22 | type=int, 23 | help='quantization resolution') 24 | 25 | parser.add_argument('-o', '--output_dir', 26 | default='silkutils/demo_test/manifold_repair', 27 | help='output dir') 28 | 29 | parser.add_argument('--verbose', 30 | action='store_true', 31 | help='show detailed output') 32 | 33 | parser.add_argument('--max_edge_graph_edges', 34 | default=100, 35 | type=int, 36 | help='structure limit: max edges in the non-manifold edge graph') 37 | 38 | parser.add_argument('--max_nonmani_verts', 39 | default=500, 40 | type=int, 41 | help='structure limit: max non-manifold vertices') 42 | 43 | parser.add_argument('--min_CC_face', 44 | default=4, 45 | type=int, 46 | help='structure limit: min faces per connected component') 47 | 48 | parser.add_argument('--max_face_num', 49 | default=16000, 50 | type=int, 51 | help='structure limit: max face count') 52 | 53 | 54 | args = parser.parse_args() 55 | return args 56 | 57 | 58 | if __name__ == '__main__': 59 | args = parse_args() 60 | 61 | structure_limit_kwargs={ 62 | 'NM_max_edge_graph': args.max_edge_graph_edges, 63 | 'NM_max_nonmani_verts': args.max_nonmani_verts, 64 | 'min_CC_face': args.min_CC_face, 65 | 'max_face_num_p': args.max_face_num, 66 | } 67 | 68 | output_dir=args.output_dir 69 | input_file=args.input_file 70 | resolution=args.quant_resolution 71 | verbose=args.verbose 72 | 73 | os.makedirs(output_dir, exist_ok=True) 74 | 75 | save_name=os.path.basename(input_file).split('.')[0] 76 | 77 | # M1: loaded by trimesh, normalized, pre-cleaned by kiui 78 | M1_save_path=os.path.join(output_dir, f"M1_{save_name}.obj") 79 | # M2: mesh quantization, non-manifold processing, colored by connected component 80 | M2_save_path=os.path.join(output_dir, f"M2_{save_name}.obj") 81 | 82 | try: 83 | vertices, faces = load_mesh(input_file, clean=True) 84 | except Exception as e: 85 | raise Exception('[E] loading failed') 86 | 87 | 88 | write_obj(vertices, faces, M1_save_path) 89 | 90 | mesh = Mesh(vertices=vertices, triangles=faces, discrete_bins=resolution, verbose=verbose, debugging=False, non_mani_process=True, NM_max_edge_graph=structure_limit_kwargs["NM_max_edge_graph"], NM_max_nonmani_verts=structure_limit_kwargs["NM_max_nonmani_verts"], min_CC_face=structure_limit_kwargs["min_CC_face"], M2_path=M2_save_path, just_process=True) 91 | 92 | 93 | 94 | 95 | --------------------------------------------------------------------------------
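(Editor's example, hedged) A minimal round trip through the SilkSong tokenizer defined in silkutils/silksong_tokenization.py above; the input path is hypothetical, and load_mesh comes from silkutils/meshdata/mesh_io.py as used in silksong_manifold_process.py:

    from silksong_tokenization import get_tokenizer_silksong, tokenize_mesh_ss, detokenize_mesh_ss
    from meshdata.mesh_io import load_mesh

    vertices, faces = load_mesh("demo.obj", clean=True)                # hypothetical input file
    tokenizer, vocab_size = get_tokenizer_silksong(resolution=128, ss_mode=4)
    tokens, meta = tokenize_mesh_ss(tokenizer, vertices, faces)        # mesh -> token sequence
    verts_out, faces_out = detokenize_mesh_ss(tokenizer, tokens)       # token sequence -> mesh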
/config/options.py: -------------------------------------------------------------------------------- 1 | import tyro 2 | from dataclasses import dataclass, field 3 | from typing import Optional, List, Tuple, Literal 4 | 5 | @dataclass 6 | class MetoConfigs: 7 | """Mesh to token engine config""" 8 | # basic 9 | discrete_bins: int = 128 10 | ss_mode: int = 4 11 | block_size: int = 8 12 | offset_size: int = 16 13 | 14 | 15 | @dataclass 16 | class TrainConfigs: 17 | # basic 18 | lr: float = 1e-4 19 | num_epochs: int = 100 20 | save_epoch: int = 1 21 | eval_mode: str = "loss" # loss / none 22 | warmup_ratio: float = 0.01 23 | use_wandb: int = 1 24 | 25 | # resume 26 | resume: Optional[str] = None 27 | resume_epoch: int = 0 28 | ft: int = 0 29 | 30 | # others 31 | resume_step_ratio: float = 0 32 | gradient_accumulation_steps: int = 1 33 | gradient_clip: float = 1.0 34 | mixed_precision: Literal['no', 'fp8', 'fp16', 'fp32'] = 'fp16' 35 | checkpointing: bool = True # gradient checkpointing 36 | debug_eval: int = 0 37 | 38 | @dataclass 39 | class InferConfigs: 40 | # basic 41 | test_path_input: str = "test_datasets/sb06" 42 | test_repeat: int = 1 43 | infer_batch: int = 1 44 | temperature: float = 0.5 45 | max_filter: int = 0 46 | 47 | 48 | @dataclass 49 | class DataConfigs: 50 | 51 | # dataset 52 | dataset: str = "ss" # debug_one / ss 53 | data_subsets: str = "gobjaversev1" 54 | xlsx_dir: str = "datasets/cleaned" 55 | testset_xlsx_dir: str = "datasets/sample_test/tables" 56 | testset_prefix: str = "testset" 57 | data_filter_cnt: int = 8 58 | 59 | # resample 60 | resample: int = 0 61 | face_delta: int = 100 62 | i_beta: float=0.0 63 | e_beta: float=1.0 64 | 65 | # iter 66 | batch_size: int = 4 67 | num_workers: int = 4 68 | testset_size: int = 32 69 | 70 | # aug 71 | use_scale_aug: int = 1 72 | use_rot_aug: int = 1 73 | use_decimate_aug: int = 0 74 | 75 | 76 | @dataclass 77 | class ModelConfigs: 78 | 79 | # encoder 80 | conditioned_on_pc: int = 1 81 | encoder_name: str = "miche-256-feature" 82 | encoder_freeze: int = 0 83 | pc_num: int = 4096 84 | 85 | # GPT 86 | mode: str = "vertices" 87 | dim: int = 1024 88 | depth: int = 24 89 | attn_dim_head: int = 64 90 | attn_heads: int = 16 91 | dropout: float = 0.0 92 | pad_token_id: int = -1 93 | 94 | 95 | 96 | @dataclass 97 | class LoggingConfigs: 98 | 99 | output_dir: str = "outputs" 100 | log_dir: str = "logs" 101 | 102 | 103 | @dataclass 104 | class AllConfigs: 105 | 106 | train: TrainConfigs = field(default_factory=TrainConfigs) 107 | infer: InferConfigs = field(default_factory=InferConfigs) 108 | data: DataConfigs = field(default_factory=DataConfigs) 109 | model: ModelConfigs = field(default_factory=ModelConfigs) 110 | logging: LoggingConfigs = field(default_factory=LoggingConfigs) 111 | meto: MetoConfigs = field(default_factory=MetoConfigs) 112 | 113 | # global 114 | workspace: str = "workspace_train/" 115 | max_face_length: int = 12000 116 | max_seq_length: int = 10240 117 | seed: int = 0 118 | 119 | def __getattr__(self, name): 120 | 121 | for config in [self.train, self.infer, self.data, self.model, self.logging, self.meto]: 122 | if hasattr(config, name): 123 | return getattr(config, name) 124 | raise AttributeError(f"'AllConfigs' object has no attribute '{name}'") 125 | 126 | def __post_init__(self): 127 | # post process 128 | if self.data.batch_size <= 0: 129 | raise ValueError("batch_size must be positive") 130 | -------------------------------------------------------------------------------- /silkutils/meto/ss_engine.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 4 | import numpy as np 5 | import meto.ss_meto as ss_meto 6 | 7 | class Engine: 8 | def __init__(self, discrete_bins, mode, debugging, meta_init_data, structure_limit): 9 | self.discrete_bins = discrete_bins 10 | self.mode = mode 11 | self.debugging = debugging 12 | 13 | self.num_base_tokens = discrete_bins 14 | 15 | 16 | if self.mode==3: # vertex: 3 tokens for xyz 17 | self.impl = ss_meto.Engine_SilkSong(discrete_bins=discrete_bins, verbose=False, debugging=debugging, meta_init_data=meta_init_data, structure_limit=structure_limit) 18 | self.num_special_tokens = self.impl.OP.OP_NUM 19 | self.num_word_table = self.num_base_tokens + self.num_special_tokens + self.impl.topology_num # word table 20 | self.impl.num_word_table=self.num_word_table 21 | elif self.mode==4: # vertex: 2 tokens for BO 22 | self.impl = ss_meto.Engine_SilkSong(discrete_bins=discrete_bins, verbose=False, debugging=debugging, meta_init_data=meta_init_data, structure_limit=structure_limit) 23 | self.num_special_tokens = self.impl.OP.OP_NUM 24 | self.blocks=8 25 | self.offsets=16 26 | self.num_word_table = self.blocks**3 + self.offsets**3 + self.num_special_tokens + self.impl.topology_num # word table 27 | self.impl.num_word_table=self.num_word_table 28 | 29 | 30 | def get_metaData(self): 31 | return self.impl.get_metaData() 32 | 33 | 34 | def encode(self, vertices, faces, non_mani_process): 35 | # vertices: [N, 3], float 36 | # faces: [M, 3], int 37 | tokens = self.impl.encode(vertices, faces, non_mani_process) 38 | return np.asarray(tokens) 39 | 40 | def token_encode(self, input_tokens, mode): 41 | tokens = self.impl.token_encode(input_tokens, mode) 42 | return np.asarray(tokens), self.impl.get_metaData() 43 | 44 | def decode(self, tokens, discrete_bins, mode, colorful=False, manifix=False): 45 | # tokens: [N], int 46 | vertices, faces = self.impl.decode(tokens, discrete_bins, mode, colorful, manifix) 47 | ret_1=None 48 | ret_2=None 49 | if colorful: 50 | ret_1=[np.asarray(ele) for ele in vertices] 51 | else: 52 | ret_1=np.asarray(vertices) 53 | if manifix: 54 | ret_2=[np.asarray(ele) for ele in faces] 55 | else: 56 | ret_2=np.asarray(faces) 57 | return ret_1, ret_2 58 | 59 | def decode_ori(self, tokens): 60 | # tokens: [N], int 61 | vertices, faces= self.impl.decode_ori(tokens) 62 | return np.asarray(vertices), np.asarray(faces) 63 | 64 | def decode_bfs(self, faces): 65 | # tokens: [N], int 66 | vertices, faces= self.impl.decode_bfs(faces) 67 | return np.asarray(vertices), np.asarray(faces) 68 | 69 | def translate_tokens(self, tokens, discrete_bins, mode): 70 | translated_tokens=self.impl.translate_tokens(tokens, discrete_bins, mode) 71 | return translated_tokens 72 | 73 | def translate_tokens_direct(self, tokens): 74 | translated_tokens=self.impl.translate_tokens_direct(tokens) 75 | return translated_tokens 76 | 77 | def trans_compare_tokens(self, tokens, tokens_gt): 78 | difference_tokens=self.impl.translate_compare_tokens(tokens, tokens_gt) 79 | return difference_tokens 80 | 81 | def get_token_classify(self): 82 | return self.impl.get_token_classify() 83 | 84 | def get_token_map_GPT(self, token_classify_dic, mode): 85 | return self.impl.get_token_map_list_GPT(token_classify_dic, mode) -------------------------------------------------------------------------------- /silkutils/ss_platform.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | ###################### local windows or H800 cluster 3 | def get_platform(): 4 | if os.path.exists('/public/SothisAI'): 5 | platform='H800' 6 | elif os.path.exists('/workspace/3ddatasets'): 7 | platform='local_docker' 8 | else: 9 | platform='H800_docker' 10 | return platform 11 | 12 | def base_dir(): 13 | if get_platform() == "local_docker": 14 | return '/workspace/3ddatasets' 15 | else: 16 | return '/public/home/group_gaosh/gaochao/3ddatasets' 17 | 18 | def get_savedirs(): 19 | save_dirs={ 20 | 'npy': 'process_dir/npy_dir', 21 | 'meta': 'process_dir/meta_dir', 22 | 'debug': 'process_dir/debug_dir', 23 | 'spc': 'process_dir/spc_dir', 24 | 'error': 'process_dir/error_log' 25 | } 26 | return save_dirs 27 | 28 | def get_base_dir_platform(name): 29 | path_prefix=base_dir() 30 | if get_platform()=="H800_docker": 31 | objaverse_fix='_docker' 32 | trellis_fix='_docker' 33 | else: 34 | objaverse_fix='' 35 | trellis_fix='' 36 | base_dir_platform_dic_local={ 37 | '3dcaricshop': f"{path_prefix}/3DCaricShop", 38 | 'shapenetv2' : f"{path_prefix}/shapenet/ShapeNetCore.v2", 39 | 'objaversev1' : f"{path_prefix}/Objaverse{objaverse_fix}", 40 | 'gobjaversev1' : f"{path_prefix}/gObjaverse{objaverse_fix}", 41 | 'abo': f"{path_prefix}/ABO", 42 | '3dfuture': f"{path_prefix}/3DFuture", 43 | 'animal3d': f"{path_prefix}/animal3d", 44 | 'buildingnet' : f"{path_prefix}/BuildingNet", 45 | 'thingi10k' : f"{path_prefix}/Thingi10K/Thingi10K-002/Thingi10K-002/Thingi10K", 46 | 'toys4k' : f"{path_prefix}/TOYS4K", 47 | 'gso' : f"{path_prefix}/GSO", 48 | } 49 | base_dir_platform_dic_h800={ 50 | '3dcaricshop': f"{path_prefix}/3DCaricShop", 51 | 'shapenetv2' : f"{path_prefix}/shapenet/shapenet/ShapeNetCore.v2", 52 | 'objaversev1' : f"{path_prefix}/Objaverse{objaverse_fix}", 53 | 'gobjaversev1' : f"{path_prefix}/gObjaverse{objaverse_fix}", 54 | 'abo': f"{path_prefix}/ABO", 55 | '3dfuture': f"{path_prefix}/3DFuture/3DFuture", 56 | 'animal3d': f"{path_prefix}/animal3d/animal3d", 57 | 'buildingnet' : f"{path_prefix}/BuildingNet/BuildingNet", 58 | 'thingi10k' : f"{path_prefix}/Thingi10K/Thingi10K/Thingi10K", 59 | 'toys4k' : f"{path_prefix}/TOYS4K/TOYS4K", 60 | 'gso' : f"{path_prefix}/GSO/GSO", 61 | 'trellis-hssd': f'{path_prefix}/trellis/HSSD/HSSD/raw', 62 | 'trellis-3dfuture': f'{path_prefix}/trellis/3D-FUTURE/raw', 63 | 'trellis-abo': f'{path_prefix}/trellis/ABO/ABO', 64 | 'trellis-toys4k': '', 65 | 'trellis-objxl-github': f'{path_prefix}/trellis/Objaversexl_github', 66 | 'trellis-objxl-sketchfab': f'{path_prefix}/trellis/Objaversexl_sketchfab/raw/raw', 67 | } 68 | if get_platform()=="local_docker": 69 | return base_dir_platform_dic_local[name] 70 | else: 71 | return base_dir_platform_dic_h800[name] 72 | 73 | def get_base_dir_rel(name): 74 | base_dir_rel_dic={ 75 | '3dcaricshop': 'processedData', 76 | 'shapenetv2' : "ShapeNetCore.v2", 77 | 'objaversev1' : 'hf-objaverse-v1', 78 | 'gobjaversev1' : 'hf-objaverse-v1', 79 | 'abo': '3dmodels', 80 | '3dfuture': '3D-FUTURE-model', 81 | 'animal3d': 'drive-download', 82 | 'buildingnet': 'OBJ_MODELS-001', 83 | 'thingi10k' : 'raw_meshes', 84 | 'toys4k' : 'obj_points', 85 | 'gso' : 'unzipped', 86 | 'trellis-hssd': 'objects', 87 | 'trellis-3dfuture': '3D-FUTURE-model', 88 | 'trellis-abo': 'raw', 89 | 'trellis-toys4k': '', 90 | 'trellis-objxl-github': 'trellis_objxl_github_convertNorm', 91 | 'trellis-objxl-sketchfab': 'hf-objaverse-v1', 92 | 'debug': 'debug_input' 93 | } 94 | return 
base_dir_rel_dic[name] 95 | 96 | 97 | 98 | 99 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.6.0 2 | aiohappyeyeballs==2.6.1 3 | aiohttp==3.11.18 4 | aiosignal==1.3.2 5 | annotated-types==0.7.0 6 | antlr4-python3-runtime==4.9.3 7 | asttokens==3.0.0 8 | async-timeout==5.0.1 9 | attrs==25.3.0 10 | beartype==0.20.2 11 | certifi==2022.12.7 12 | chardet==5.2.0 13 | charset-normalizer==2.1.1 14 | classifier-free-guidance-pytorch==0.5.1 15 | click==8.1.8 16 | conda-pack @ file:///croot/conda-pack_1710258022278/work 17 | contourpy==1.3.0 18 | cycler==0.12.1 19 | decorator==5.2.1 20 | deepspeed==0.16.7 21 | docker-pycreds==0.4.0 22 | docstring_parser==0.16 23 | easydict==1.13 24 | einops==0.8.1 25 | ema-pytorch==0.7.7 26 | et_xmlfile==2.0.0 27 | eval_type_backport==0.2.2 28 | exceptiongroup==1.2.2 29 | executing==2.2.0 30 | filelock==3.13.1 31 | fonttools==4.57.0 32 | freetype-py==2.5.1 33 | frozenlist==1.6.0 34 | fsspec==2024.6.1 35 | ftfy==6.3.1 36 | func_timeout==4.3.5 37 | gateloop-transformer==0.2.5 38 | gitdb==4.0.12 39 | GitPython==3.1.44 40 | hjson==3.1.0 41 | huggingface-hub==0.30.2 42 | hyper-connections==0.1.15 43 | idna==3.4 44 | imageio==2.37.0 45 | importlib_metadata==8.6.1 46 | importlib_resources==6.5.2 47 | inquirerpy==0.3.4 48 | ipdb==0.13.13 49 | ipython==8.18.1 50 | jedi==0.19.2 51 | Jinja2==3.1.4 52 | kiui==0.2.16 53 | kiwisolver==1.4.7 54 | lazy_loader==0.4 55 | lightning-utilities==0.14.3 56 | local-attention==1.11.1 57 | markdown-it-py==3.0.0 58 | MarkupSafe==2.1.5 59 | matplotlib==3.9.4 60 | matplotlib-inline==0.1.7 61 | mdurl==0.1.2 62 | mesh2sdf==1.1.0 63 | meshgpt-pytorch==0.6.7 64 | mpmath==1.3.0 65 | msgpack==1.1.0 66 | multidict==6.4.3 67 | multipledispatch==1.0.0 68 | networkx==3.2.1 69 | ninja==1.11.1.4 70 | numpy==1.26.3 71 | nvidia-cublas-cu12==12.6.4.1 72 | nvidia-cuda-cupti-cu12==12.6.80 73 | nvidia-cuda-nvrtc-cu12==12.6.77 74 | nvidia-cuda-runtime-cu12==12.6.77 75 | nvidia-cudnn-cu12==9.5.1.17 76 | nvidia-cufft-cu12==11.3.0.4 77 | nvidia-cufile-cu12==1.11.1.6 78 | nvidia-curand-cu12==10.3.7.77 79 | nvidia-cusolver-cu12==11.7.1.2 80 | nvidia-cusparse-cu12==12.5.4.2 81 | nvidia-cusparselt-cu12==0.6.3 82 | nvidia-ml-py==12.570.86 83 | nvidia-nccl-cu12==2.26.2 84 | nvidia-nvjitlink-cu12==12.6.85 85 | nvidia-nvtx-cu12==12.6.77 86 | objprint==0.3.0 87 | omegaconf==2.3.0 88 | open3d-python==0.3.0.0 89 | open_clip_torch==2.32.0 90 | opencv-python==4.11.0.86 91 | openpyxl==3.1.5 92 | optree==0.15.0 93 | packaging==25.0 94 | pandas==2.2.3 95 | parso==0.8.4 96 | pexpect==4.9.0 97 | pfzy==0.3.4 98 | pillow==11.0.0 99 | platformdirs==4.3.7 100 | point-cloud-utils==0.34.0 101 | prompt_toolkit==3.0.51 102 | propcache==0.3.1 103 | protobuf==6.30.2 104 | psutil==7.0.0 105 | ptyprocess==0.7.0 106 | pure_eval==0.2.3 107 | py-cpuinfo==9.0.0 108 | pydantic==2.11.3 109 | pydantic_core==2.33.1 110 | pyglet==2.1.5 111 | Pygments==2.19.1 112 | pymeshlab==2023.12.post3 113 | PyOpenGL==3.1.0 114 | pyparsing==3.2.3 115 | pyrender==0.1.45 116 | pyrr==0.10.3 117 | python-dateutil==2.9.0.post0 118 | pytorch-custom-utils==0.0.21 119 | pytorch-lightning==2.5.1 120 | pytorch-warmup==0.2.0 121 | pytz==2025.2 122 | PyYAML==6.0.2 123 | regex==2024.11.6 124 | requests==2.28.1 125 | rich==14.0.0 126 | rotary-embedding-torch==0.8.6 127 | safetensors==0.5.3 128 | scikit-image==0.24.0 129 | scipy==1.13.1 130 | sentry-sdk==2.26.1 131 | 
setproctitle==1.3.5 132 | shtab==1.7.2 133 | six==1.17.0 134 | smmap==5.0.2 135 | stack-data==0.6.3 136 | sympy==1.13.3 137 | tifffile==2024.8.30 138 | timm==1.0.15 139 | tokenizers==0.21.1 140 | tomli==2.2.1 141 | torch==2.7.0 142 | torch-geometric==2.6.1 143 | torchmetrics==1.7.1 144 | torchtyping==0.1.5 145 | torchvision==0.22.0 146 | tqdm==4.67.1 147 | traitlets==5.14.3 148 | transformers==4.51.3 149 | trimesh==4.5.3 150 | triton==3.3.0 151 | typeguard==4.4.2 152 | typing-inspection==0.4.0 153 | typing_extensions==4.13.2 154 | tyro==0.9.19 155 | tzdata==2025.2 156 | urllib3==1.26.13 157 | varname==0.14.0 158 | vector_quantize_pytorch==1.12.8 159 | wandb==0.19.10 160 | wcwidth==0.2.13 161 | x-transformers==1.26.6 162 | yarl==1.20.0 163 | zipp==3.21.0 164 | -------------------------------------------------------------------------------- /dataset_clean.md: -------------------------------------------------------------------------------- 1 | # Dataset Clean Guidance 2 | 3 | ## 1. 3D Datasets Download 4 | The following 3D datasets are publicly available: 5 | 6 | Main 7 | 8 | - [Objaverse(XL)](https://huggingface.co/datasets/allenai/objaverse-xl) 9 | - [3D-FUTURE](https://tianchi.aliyun.com/dataset/98063) 10 | - [Toys4K](https://github.com/rehg-lab/lowshot-shapebias/tree/main/toys4k) 11 | - [gObjaverse](https://github.com/modelscope/richdreamer/tree/main/dataset/gobjaverse) 12 | - [ShapeNet v2](https://shapenet.org/) 13 | - [Trellis](https://github.com/microsoft/TRELLIS/blob/main/DATASET.md) 14 | 15 | More, for any special need 16 | 17 | - [Thingi10K](https://github.com/Thingi10K/Thingi10K) 18 | - [GSO](https://app.gazebosim.org/home) 19 | - [Animal3D](https://xujiacong.github.io/Animal3D/) 20 | - [3DCaricShop](https://github.com/qiuyuda/3DCaricShop) 21 | - [ABO](https://amazon-berkeley-objects.s3.amazonaws.com/index.html#download) 22 | - [BuildingNet](https://github.com/buildingnet/buildingnet_dataset?tab=readme-ov-file) 23 | 24 | 25 | ## 2. Dataset Clean 26 | 1. After downloading the 3D datasets, set the right base paths in `silkutils/ss_platform.py` based on your own platform. In `ss_platform.py`, the initialized .xlsx files described below will be stored under the output path of the function `get_base_dir_platform(name)`. 27 | 28 | If you are processing the gObjaverse dataset, download objaversev1 first, then download `gobjaverse_280k_index_to_objaverse.json` from [here](https://github.com/modelscope/richdreamer/tree/main/dataset/gobjaverse) and put it at 29 | ``` 30 | silkutils/dataset_clean/gobjaverse_280k_index_to_objaverse.json 31 | ``` 32 | 33 | 2. Run `silkutils/dataset_clean/step1_init.py`, modifying the input dataset name in the function; several initialized .xlsx files will be stored. The following steps rely on these .xlsx files for data cleaning. 34 | 35 | 3. Run `silkutils/dataset_clean/step2_clean.py` for multi-threaded processing. The data will be processed by the silksong tokenization algorithm, and the statistics will be stored in a new .xlsx file. In the following step, some data will be filtered based on these statistics (e.g. too many faces). This step may take a long time. 36 | 37 | 4. Run `silkutils/dataset_clean/step3_cleanfix.py` if any accident happened while processing some data (e.g. disk failure or memory problems), so that you need not re-run `step2_clean.py` from scratch. 38 | 39 | 5. Run `silkutils/dataset_clean/step4_datafilter.py`. 40 | - `filtered_xlsx_save_dir`: All processed .xlsx files will be automatically merged into this path. 
During training, `model/data_provider.py` reads its input .xlsx files from here. 41 | - `filter_version`: Modify your data filter rule here; refer to the function `get_filtered_df(df, filter_version)`. Version `11` is recommended for training with a max token length of 10240. You may add a new version if you have any special need (a hedged sketch of such a rule is given at the end of this document). 42 | - The final .xlsx files should be named like `meta_all_{datasetname}_res128_v04_mergeall_filter{filter_version}.xlsx` 43 | 44 | ## 3. Sample from Dataset for Analysing/Testing/Evaluation 45 | 1. Refer to `silkutils/dataset_clean/step5_sample.py` for training data sampling and testing data generation. 46 | 2. Set the right paths: 47 | - `table_dir`: the dir of the input .xlsx files 48 | - `sample_test_dir`: the save dir of the test data you sampled 49 | - `sample_test_table_dir`: the save dir of the sampled test data's .xlsx file 50 | - `sample_train_dir`: the save dir of the training data you sampled 51 | 3. The sampled data will be copied to a new dir (`sample_test_dir` or `sample_train_dir`) for convenient preview. 52 | 4. The data will be copied twice: 53 | - The original data is copied directly as the input for model inference (with postfix `_origin`). 54 | - The original data is normalized and saved to an .obj file for convenient preview (with postfix `_norm`). 55 | 56 | #### 3.1 Sampling training data: 57 | 58 | - Run function `sample_table_specify()`. 59 | - View the results in `sample_train_dir`. 60 | 61 | #### 3.2 Sampling testing data: 62 | 63 | - Run function `sample_and_generate_testset_table()` 64 | - The sampled testset's .xlsx file has the prefix `testset_`, which is recognized and excluded during training in `model/data_provider.py`. You may change the prefix via `DataConfigs.testset_prefix` in `config/options.py` if you have any special need.
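For reference, a hedged sketch of what a new rule inside `get_filtered_df(df, filter_version)` could look like. The `done` and `face_num_process` columns appear elsewhere in this repo (e.g. in `step5_sample.py`); the version number and thresholds below are illustrative assumptions, not the real version `11` rule:
```
def get_filtered_df(df, filter_version):
    # hypothetical version 12: keep successfully processed meshes with a
    # moderate face count; the thresholds are illustrative assumptions
    if filter_version == 12:
        df = df[df['done'] == 1]
        df = df[df['face_num_process'] < 8000]
    return df
```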
-------------------------------------------------------------------------------- /silkutils/meshdata/mesh_graph.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class UnionFind: 4 | def __init__(self, size): 5 | self.parent = list(range(size)) 6 | self.rank = [1] * size 7 | 8 | def find(self, a): 9 | if self.parent[a] != a: 10 | self.parent[a] = self.find(self.parent[a]) # Path compression 11 | return self.parent[a] 12 | 13 | def union(self, a, b): 14 | rootA = self.find(a) 15 | rootB = self.find(b) 16 | 17 | if rootA != rootB: 18 | # Union by rank 19 | if self.rank[rootA] > self.rank[rootB]: 20 | self.parent[rootB] = rootA 21 | elif self.rank[rootA] < self.rank[rootB]: 22 | self.parent[rootA] = rootB 23 | else: 24 | self.parent[rootB] = rootA 25 | self.rank[rootA] += 1 26 | 27 | def extend(self): 28 | cur_size=len(self.parent) 29 | self.parent.append(cur_size) 30 | self.rank.append(1) 31 | return cur_size 32 | 33 | def group_num(self, cc_list): 34 | root_list=[self.find(x) for x in cc_list] 35 | return len(set(root_list)) 36 | 37 | def union_group(self, cc_list): 38 | for i in range(len(cc_list)-1): 39 | a=cc_list[i] 40 | b=cc_list[i+1] 41 | self.union(a, b) 42 | return self.find(cc_list[0]) 43 | 44 | def parent_list(self): 45 | return self.parent, len(set(self.parent)) 46 | 47 | 48 | def edge_key(a, b): 49 | return (a, b) if a < b else (b, a) 50 | 51 | def check_flip(f, FN): 52 | edge_right=[] 53 | edge_fix=[] 54 | next_neighbors=[] 55 | for j in range(FN): 56 | 57 | if f.half_edges[j].o is None: 58 | # this is bound 59 | edge_right.append(True) 60 | edge_fix.append(False) 61 | continue 62 | else: 63 | this_he = f.half_edges[j] 64 | edge_fix.append(this_he.o.t.fix_orientation) 65 | if not this_he.o.t.fix_orientation: 66 | next_neighbors.append(this_he.o.t.i) 67 | if this_he.s.i==this_he.o.s.i and this_he.e.i==this_he.o.e.i: 68 | edge_right.append(False) 69 | else: 70 | edge_right.append(True) 71 | if sum(edge_right)==3: # if we need flip / next_neighbor_faces flip / flip process success? 
72 | return False, next_neighbors, True 73 | else: 74 | if sum(edge_fix) ==0: 75 | # print(f'find free face {f.i}') 76 | if sum(edge_right)<2: 77 | f.flip() 78 | return True, next_neighbors, True 79 | else: 80 | return False, next_neighbors, True 81 | # 1 or 2 or 3 neighbor faces fixed 82 | flag_flip = None 83 | flag_notflip = None 84 | for index_edge_fix in range(3): 85 | 86 | # if edge_fix[index_edge_fix] == True and edge_right[index_edge_fix] == False: 87 | # flag_flip = True 88 | # elif edge_fix[index_edge_fix] == True and edge_right[index_edge_fix] == True: 89 | # flag_notflip = True 90 | 91 | if edge_fix[index_edge_fix] is True: 92 | if edge_right[index_edge_fix] is False: 93 | flag_flip=True 94 | else: 95 | flag_notflip=True 96 | 97 | if flag_flip and flag_notflip: 98 | # ipdb.set_trace() 99 | # raise Exception('[E] face flip wrong') 100 | return False, next_neighbors, False 101 | if flag_flip: 102 | f.flip() 103 | return True, next_neighbors, True 104 | elif flag_notflip: 105 | return False, next_neighbors, True 106 | else: 107 | raise Exception('[E] face flip wrong 2') 108 | 109 | 110 | def triangle_area_3d(p1, p2, p3): 111 | p1 = np.array(p1) 112 | p2 = np.array(p2) 113 | p3 = np.array(p3) 114 | 115 | v1 = p2 - p1 116 | v2 = p3 - p1 117 | 118 | cross_product = np.cross(v1, v2) 119 | area = np.linalg.norm(cross_product) / 2 120 | 121 | return area -------------------------------------------------------------------------------- /miche/michelangelo/models/tsal/clip_asl_module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | from torch import nn 5 | from einops import rearrange 6 | from transformers import CLIPModel 7 | 8 | from miche.michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentModule 9 | 10 | 11 | class CLIPAlignedShapeAsLatentModule(AlignedShapeAsLatentModule): 12 | 13 | def __init__(self, *, 14 | shape_model, 15 | clip_model_version: str = "openai/clip-vit-large-patch14"): 16 | 17 | super().__init__() 18 | 19 | # self.clip_model: CLIPModel = CLIPModel.from_pretrained(clip_model_version) 20 | # for params in self.clip_model.parameters(): 21 | # params.requires_grad = False 22 | self.clip_model = None 23 | self.shape_model = shape_model 24 | self.shape_projection = nn.Parameter(torch.empty(self.shape_model.width, self.shape_model.width)) 25 | # nn.init.normal_(self.shape_projection, std=self.shape_model.width ** -0.5) 26 | 27 | def set_shape_model_only(self): 28 | self.clip_model = None 29 | 30 | def encode_shape_embed(self, surface, return_latents: bool = False): 31 | """ 32 | 33 | Args: 34 | surface (torch.FloatTensor): [bs, n, 3 + c] 35 | return_latents (bool): 36 | 37 | Returns: 38 | x (torch.FloatTensor): [bs, projection_dim] 39 | shape_latents (torch.FloatTensor): [bs, m, d] 40 | """ 41 | 42 | pc = surface[..., 0:3] 43 | feats = surface[..., 3:] 44 | 45 | shape_embed, shape_latents = self.shape_model.encode_latents(pc, feats) 46 | x = shape_embed @ self.shape_projection 47 | 48 | if return_latents: 49 | return x, shape_latents 50 | else: 51 | return x 52 | 53 | def encode_image_embed(self, image): 54 | """ 55 | 56 | Args: 57 | image (torch.FloatTensor): [bs, 3, h, w] 58 | 59 | Returns: 60 | x (torch.FloatTensor): [bs, projection_dim] 61 | """ 62 | 63 | x = self.clip_model.get_image_features(image) 64 | 65 | return x 66 | 67 | def encode_text_embed(self, text): 68 | x = self.clip_model.get_text_features(text) 69 | return x 70 | 71 | def forward(self, surface, image, text): 
72 | """ 73 | 74 | Args: 75 | surface (torch.FloatTensor): 76 | image (torch.FloatTensor): [bs, 3, 224, 224] 77 | text (torch.LongTensor): [bs, num_templates, 77] 78 | 79 | Returns: 80 | embed_outputs (dict): the embedding outputs, and it contains: 81 | - image_embed (torch.FloatTensor): 82 | - text_embed (torch.FloatTensor): 83 | - shape_embed (torch.FloatTensor): 84 | - logit_scale (float): 85 | """ 86 | 87 | # # text embedding 88 | # text_embed_all = [] 89 | # for i in range(text.shape[0]): 90 | # text_for_one_sample = text[i] 91 | # text_embed = self.encode_text_embed(text_for_one_sample) 92 | # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) 93 | # text_embed = text_embed.mean(dim=0) 94 | # text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) 95 | # text_embed_all.append(text_embed) 96 | # text_embed_all = torch.stack(text_embed_all) 97 | 98 | b = text.shape[0] 99 | text_tokens = rearrange(text, "b t l -> (b t) l") 100 | text_embed = self.encode_text_embed(text_tokens) 101 | text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b) 102 | text_embed = text_embed.mean(dim=1) 103 | text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True) 104 | 105 | # image embedding 106 | image_embed = self.encode_image_embed(image) 107 | 108 | # shape embedding 109 | shape_embed, shape_latents = self.encode_shape_embed(surface, return_latents=True) 110 | 111 | embed_outputs = { 112 | "image_embed": image_embed, 113 | "text_embed": text_embed, 114 | "shape_embed": shape_embed, 115 | # "logit_scale": self.clip_model.logit_scale.exp() 116 | } 117 | 118 | return embed_outputs, shape_latents 119 | -------------------------------------------------------------------------------- /miche/michelangelo/models/tsal/tsal_base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch.nn as nn 4 | from typing import Tuple, List, Optional 5 | 6 | # Base class for output of Point to Mesh transformation 7 | class Point2MeshOutput(object): 8 | def __init__(self): 9 | self.mesh_v = None # Vertices of the mesh 10 | self.mesh_f = None # Faces of the mesh 11 | self.center = None # Center of the mesh 12 | self.pc = None # Point cloud data 13 | 14 | 15 | # Base class for output of Latent to Mesh transformation 16 | class Latent2MeshOutput(object): 17 | def __init__(self): 18 | self.mesh_v = None # Vertices of the mesh 19 | self.mesh_f = None # Faces of the mesh 20 | 21 | 22 | # Base class for output of Aligned Mesh transformation 23 | class AlignedMeshOutput(object): 24 | def __init__(self): 25 | self.mesh_v = None # Vertices of the mesh 26 | self.mesh_f = None # Faces of the mesh 27 | self.surface = None # Surface data 28 | self.image = None # Aligned image data 29 | self.text: Optional[str] = None # Aligned text data 30 | self.shape_text_similarity: Optional[float] = None # Similarity between shape and text 31 | self.shape_image_similarity: Optional[float] = None # Similarity between shape and image 32 | 33 | 34 | # Base class for Shape as Latent with Point to Mesh transformation module 35 | class ShapeAsLatentPLModule(nn.Module): 36 | latent_shape: Tuple[int] # Shape of the latent space 37 | 38 | def encode(self, surface, *args, **kwargs): 39 | raise NotImplementedError 40 | 41 | def decode(self, z_q, *args, **kwargs): 42 | raise NotImplementedError 43 | 44 | def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]: 45 | raise NotImplementedError 46 | 47 | def point2mesh(self, *args, **kwargs) -> 
List[Point2MeshOutput]: 48 | raise NotImplementedError 49 | 50 | 51 | # Base class for Shape as Latent module 52 | class ShapeAsLatentModule(nn.Module): 53 | latent_shape: Tuple[int, int] # Shape of the latent space 54 | 55 | def __init__(self, *args, **kwargs): 56 | super().__init__() 57 | 58 | def encode(self, *args, **kwargs): 59 | raise NotImplementedError 60 | 61 | def decode(self, *args, **kwargs): 62 | raise NotImplementedError 63 | 64 | def query_geometry(self, *args, **kwargs): 65 | raise NotImplementedError 66 | 67 | 68 | # Base class for Aligned Shape as Latent with Point to Mesh transformation module 69 | class AlignedShapeAsLatentPLModule(nn.Module): 70 | latent_shape: Tuple[int] # Shape of the latent space 71 | 72 | def set_shape_model_only(self): 73 | raise NotImplementedError 74 | 75 | def encode(self, surface, *args, **kwargs): 76 | raise NotImplementedError 77 | 78 | def decode(self, z_q, *args, **kwargs): 79 | raise NotImplementedError 80 | 81 | def latent2mesh(self, latents, *args, **kwargs) -> List[Latent2MeshOutput]: 82 | raise NotImplementedError 83 | 84 | def point2mesh(self, *args, **kwargs) -> List[Point2MeshOutput]: 85 | raise NotImplementedError 86 | 87 | 88 | # Base class for Aligned Shape as Latent module 89 | class AlignedShapeAsLatentModule(nn.Module): 90 | shape_model: ShapeAsLatentModule # Shape model module 91 | latent_shape: Tuple[int, int] # Shape of the latent space 92 | 93 | 94 | def __init__(self, *args, **kwargs): 95 | super().__init__() 96 | 97 | def set_shape_model_only(self): 98 | raise NotImplementedError 99 | 100 | def encode_image_embed(self, *args, **kwargs): 101 | raise NotImplementedError 102 | 103 | def encode_text_embed(self, *args, **kwargs): 104 | raise NotImplementedError 105 | 106 | def encode_shape_embed(self, *args, **kwargs): 107 | raise NotImplementedError 108 | 109 | # Base class for Textured Shape as Latent module 110 | class TexturedShapeAsLatentModule(nn.Module): 111 | 112 | def __init__(self, *args, **kwargs): 113 | super().__init__() 114 | 115 | def encode(self, *args, **kwargs): 116 | raise NotImplementedError 117 | 118 | def decode(self, *args, **kwargs): 119 | raise NotImplementedError 120 | 121 | def query_geometry(self, *args, **kwargs): 122 | raise NotImplementedError 123 | 124 | def query_color(self, *args, **kwargs): 125 | raise NotImplementedError 126 | -------------------------------------------------------------------------------- /silkutils/meto/decode_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import networkx as nx 3 | 4 | 5 | 6 | def decode_layer_face(inlayer_matlast, inlayer_matcur, outlayer_mat, last_verts, cur_verts): 7 | M, N = outlayer_mat.shape 8 | face_all=[] 9 | for i in range(M): # check inlayer mat last 10 | up_v=cur_verts[i] 11 | down_index_list=[] 12 | for j in range(N): 13 | if outlayer_mat[i][j]==1: 14 | down_index_list.append(j) 15 | N_down=len(down_index_list) 16 | if N_down == 1: 17 | continue 18 | if N_down == 2: 19 | ind_1=down_index_list[0] 20 | ind_2=down_index_list[1] 21 | if abs((ind_2-ind_1)%N) > abs((ind_1-ind_2)%N): 22 | down_index_1=ind_2 23 | down_index_2=ind_1 24 | else: 25 | down_index_1=ind_1 26 | down_index_2=ind_2 27 | if check_inlayer(down_index_1, down_index_2, inlayer_matlast): 28 | face_all.append([up_v, last_verts[down_index_2], last_verts[down_index_1]]) 29 | continue 30 | for k in range(N_down): 31 | down_index_1=down_index_list[k%N_down] 32 | down_index_2=down_index_list[(k+1)%N_down] 33 | if 
check_inlayer(down_index_1, down_index_2, inlayer_matlast): 34 | face_all.append([up_v, last_verts[down_index_2], last_verts[down_index_1]]) 35 | 36 | for j in range(N): 37 | down_v=last_verts[j] 38 | up_index_list=[] 39 | for i in range(M): 40 | if outlayer_mat[i][j]==1: 41 | up_index_list.append(i) 42 | N_up=len(up_index_list) 43 | if N_up == 1: 44 | continue 45 | if N_up == 2: 46 | ind_1=up_index_list[0] 47 | ind_2=up_index_list[1] 48 | if abs((ind_2-ind_1)%M) > abs((ind_1-ind_2)%M): 49 | up_index_1=ind_2 50 | up_index_2=ind_1 51 | else: 52 | up_index_1=ind_1 53 | up_index_2=ind_2 54 | if check_inlayer(up_index_1, up_index_2, inlayer_matcur): 55 | face_all.append([down_v, cur_verts[up_index_1], cur_verts[up_index_2]]) 56 | continue 57 | for k in range(N_up): 58 | up_index_1=up_index_list[k%N_up] 59 | up_index_2=up_index_list[(k+1)%N_up] 60 | if check_inlayer(up_index_1, up_index_2, inlayer_matcur): 61 | face_all.append([down_v, cur_verts[up_index_1], cur_verts[up_index_2]]) 62 | 63 | return face_all 64 | 65 | 66 | def decode_inlayer_connect_faces(inlayer_matrix, vertex_list): 67 | G = nx.Graph() 68 | 69 | all_edges=[] 70 | M=len(inlayer_matrix) 71 | for i in range(M): 72 | for j in range(i+1, M): 73 | if inlayer_matrix[i][j]==1: 74 | all_edges.append((i, j)) 75 | 76 | G.add_edges_from(all_edges) 77 | triangles_index = [tuple(sorted(cycle)) for cycle in nx.cycle_basis(G) if len(cycle) == 3] 78 | triangles_index_fix=[] 79 | for triangle_i in triangles_index: 80 | a, b, c= triangle_i 81 | triangles_index_fix.append([a, c, b]) 82 | triangles = [[vertex_list[triangle[0]], vertex_list[triangle[1]], vertex_list[triangle[2]]] for triangle in triangles_index_fix] 83 | return triangles 84 | 85 | 86 | def check_inlayer(down_index_1, down_index_2, inlayer_mat): 87 | if inlayer_mat is None: 88 | return True 89 | if inlayer_mat[down_index_1][down_index_2]==1: 90 | return True 91 | else: 92 | return False 93 | 94 | def save_matrix_to_txt(mat_in, mat_out, file_name="matrix.txt"): 95 | with open(file_name, "w") as f: 96 | CC_num=len(mat_in) 97 | for i in range(CC_num): 98 | in_list=mat_in[i] 99 | out_list=mat_out[i] 100 | f.write(f"CC {i+1}: \n") 101 | layer_num=len(in_list) 102 | for j in range(1, layer_num): 103 | f.write(f"Matrix layer: {j}\n") 104 | matrix_in=in_list[j] 105 | matrix_out=out_list[j] 106 | matrix_in_str = "\n".join(" " + " ".join(f"{val}" for val in row) for row in matrix_in) 107 | matrix_out_str = "\n".join(" " + " ".join(f"{val}" for val in row) for row in matrix_out) 108 | f.write(matrix_in_str + "\n\n") 109 | f.write(matrix_out_str + "\n\n") 110 | f.write('-------------------\n') -------------------------------------------------------------------------------- /miche/michelangelo/models/tsal/loss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from typing import Optional 7 | 8 | from miche.michelangelo.models.modules.distributions import DiagonalGaussianDistribution 9 | from miche.michelangelo.utils import misc 10 | 11 | 12 | class ContrastKLNearFar(nn.Module): 13 | def __init__(self, 14 | contrast_weight: float = 1.0, 15 | near_weight: float = 0.1, 16 | kl_weight: float = 1.0, 17 | num_near_samples: Optional[int] = None): 18 | 19 | super().__init__() 20 | 21 | self.labels = None 22 | self.last_local_batch_size = None 23 | 24 | self.contrast_weight = contrast_weight 25 | self.near_weight = near_weight 26 | self.kl_weight = kl_weight 
27 | self.num_near_samples = num_near_samples 28 | self.geo_criterion = nn.BCEWithLogitsLoss() 29 | 30 | def forward(self, 31 | shape_embed: torch.FloatTensor, 32 | text_embed: torch.FloatTensor, 33 | image_embed: torch.FloatTensor, 34 | logit_scale: torch.FloatTensor, 35 | posteriors: Optional[DiagonalGaussianDistribution], 36 | shape_logits: torch.FloatTensor, 37 | shape_labels: torch.FloatTensor, 38 | split: Optional[str] = "train", **kwargs): 39 | 40 | # shape_embed: torch.FloatTensor 41 | # text_embed: torch.FloatTensor 42 | # image_embed: torch.FloatTensor 43 | # logit_scale: torch.FloatTensor 44 | # posteriors: Optional[DiagonalGaussianDistribution] 45 | # shape_logits: torch.FloatTensor 46 | # shape_labels: torch.FloatTensor 47 | 48 | local_batch_size = shape_embed.size(0) 49 | 50 | if local_batch_size != self.last_local_batch_size: 51 | self.labels = local_batch_size * misc.get_rank() + torch.arange( 52 | local_batch_size, device=shape_embed.device 53 | ).long() 54 | self.last_local_batch_size = local_batch_size 55 | 56 | # normalized features 57 | shape_embed = F.normalize(shape_embed, dim=-1, p=2) 58 | text_embed = F.normalize(text_embed, dim=-1, p=2) 59 | image_embed = F.normalize(image_embed, dim=-1, p=2) 60 | 61 | # gather features from all GPUs 62 | shape_embed_all, text_embed_all, image_embed_all = misc.all_gather_batch( 63 | [shape_embed, text_embed, image_embed] 64 | ) 65 | 66 | # cosine similarity as logits 67 | logits_per_shape_text = logit_scale * shape_embed @ text_embed_all.t() 68 | logits_per_text_shape = logit_scale * text_embed @ shape_embed_all.t() 69 | logits_per_shape_image = logit_scale * shape_embed @ image_embed_all.t() 70 | logits_per_image_shape = logit_scale * image_embed @ shape_embed_all.t() 71 | contrast_loss = (F.cross_entropy(logits_per_shape_text, self.labels) + 72 | F.cross_entropy(logits_per_text_shape, self.labels)) / 2 + \ 73 | (F.cross_entropy(logits_per_shape_image, self.labels) + 74 | F.cross_entropy(logits_per_image_shape, self.labels)) / 2 75 | 76 | # shape reconstruction 77 | if self.num_near_samples is None: 78 | num_vol = shape_logits.shape[1] // 2 79 | else: 80 | num_vol = shape_logits.shape[1] - self.num_near_samples 81 | 82 | vol_logits = shape_logits[:, 0:num_vol] 83 | vol_labels = shape_labels[:, 0:num_vol] 84 | 85 | near_logits = shape_logits[:, num_vol:] 86 | near_labels = shape_labels[:, num_vol:] 87 | 88 | # occupancy loss 89 | vol_bce = self.geo_criterion(vol_logits.float(), vol_labels.float()) 90 | near_bce = self.geo_criterion(near_logits.float(), near_labels.float()) 91 | 92 | if posteriors is None: 93 | kl_loss = torch.tensor(0.0, dtype=vol_logits.dtype, device=vol_logits.device) 94 | else: 95 | kl_loss = posteriors.kl(dims=(1, 2)) 96 | kl_loss = torch.mean(kl_loss) 97 | 98 | loss = vol_bce + near_bce * self.near_weight + kl_loss * self.kl_weight + contrast_loss * self.contrast_weight 99 | 100 | # compute accuracy 101 | with torch.no_grad(): 102 | pred = torch.argmax(logits_per_shape_text, dim=-1) 103 | correct = pred.eq(self.labels).sum() 104 | shape_text_acc = 100 * correct / local_batch_size 105 | 106 | pred = torch.argmax(logits_per_shape_image, dim=-1) 107 | correct = pred.eq(self.labels).sum() 108 | shape_image_acc = 100 * correct / local_batch_size 109 | 110 | preds = shape_logits >= 0 111 | accuracy = (preds == shape_labels).float() 112 | accuracy = accuracy.mean() 113 | 114 | log = { 115 | "{}/contrast".format(split): contrast_loss.clone().detach(), 116 | "{}/near".format(split): near_bce.detach(), 117 | 
"{}/far".format(split): vol_bce.detach(), 118 | "{}/kl".format(split): kl_loss.detach(), 119 | "{}/shape_text_acc".format(split): shape_text_acc, 120 | "{}/shape_image_acc".format(split): shape_image_acc, 121 | "{}/total_loss".format(split): loss.clone().detach(), 122 | "{}/accuracy".format(split): accuracy, 123 | } 124 | 125 | if posteriors is not None: 126 | log[f"{split}/mean"] = posteriors.mean.mean().detach() 127 | log[f"{split}/std_mean"] = posteriors.std.mean().detach() 128 | log[f"{split}/std_max"] = posteriors.std.max().detach() 129 | 130 | return loss, log 131 | -------------------------------------------------------------------------------- /model/data_provider_infer.py: -------------------------------------------------------------------------------- 1 | import trimesh 2 | import numpy as np 3 | from kiui.mesh_utils import clean_mesh 4 | from x_transformers.autoregressive_wrapper import top_p, top_k 5 | from silkutils.silksong_tokenization import get_tokenizer_silksong, tokenize_mesh_ss 6 | from silkutils.meshdata.mesh_io import load_mesh_nonorm, normalize_mesh 7 | import model.data_provider as data_pro 8 | import torch 9 | import traceback 10 | 11 | class Dataset: 12 | ''' 13 | A toy dataset for inference 14 | ''' 15 | def __init__(self, input_type, input_list): 16 | super().__init__() 17 | self.data = [] 18 | if input_type == 'pc_normal': 19 | for input_path in input_list: 20 | # load npy 21 | cur_data = np.load(input_path) 22 | # sample 4096 23 | assert cur_data.shape[0] >= 4096, "input pc_normal should have at least 4096 points" 24 | idx = np.random.choice(cur_data.shape[0], 4096, replace=False) 25 | cur_data = cur_data[idx] 26 | self.data.append({'pc_normal': cur_data, 'uid': input_path.split('/')[-1].split('.')[0]}) 27 | 28 | elif input_type == 'mesh': 29 | mesh_list, pc_list = [], [] 30 | for input_path in input_list: 31 | # sample point cloud and normal from mesh 32 | ####ss 33 | v, f = load_mesh_nonorm(input_path) 34 | v, f = clean_mesh(v, f, min_f=0, min_d=0, remesh=False, verbose=False) 35 | v = normalize_mesh(v, bound=0.95) 36 | cur_data = trimesh.Trimesh(vertices=v, faces=f) 37 | #### ss - bpt 38 | # cur_data = trimesh.load(input_path, force='mesh') 39 | # cur_data = apply_normalize(cur_data) 40 | #### bpt 41 | mesh_list.append(cur_data) 42 | pc_list.append(sample_pc(cur_data, pc_num=4096, with_normal=True)) 43 | 44 | for input_path, cur_data in zip(input_list, pc_list): 45 | self.data.append({'pc_normal': cur_data, 'uid': input_path.split('/')[-1].split('.')[0]}) 46 | 47 | print(f"dataset total data samples: {len(self.data)}") 48 | 49 | def __len__(self): 50 | return len(self.data) 51 | 52 | def __getitem__(self, idx): 53 | data_dict = {} 54 | data_dict['pc_normal'] = self.data[idx]['pc_normal'] 55 | data_dict['uid'] = self.data[idx]['uid'] 56 | 57 | return data_dict 58 | 59 | class InferDataset: 60 | ''' 61 | A toy dataset for inference 62 | ''' 63 | def __init__(self, input_type, input_list): 64 | super().__init__() 65 | self.data = [] 66 | self.tokenizer, _=get_tokenizer_silksong() 67 | if input_type == 'mesh': 68 | pc_list = [] 69 | gt_mesh = [] 70 | for input_path in input_list: 71 | # sample point cloud and normal from mesh 72 | v, f = load_mesh_nonorm(input_path) 73 | v, f = clean_mesh(v, f, min_f=0, min_d=0, remesh=False, verbose=False) 74 | v = normalize_mesh(v, bound=0.95) 75 | mesh = trimesh.Trimesh(vertices=v, faces=f) 76 | points = data_pro.sample_pc(mesh, pc_num=4096, with_normal=True, aug=False) 77 | try: 78 | gt_token, 
_=tokenize_mesh_ss(tokenizer=self.tokenizer, vertices=v, faces=f) 79 | except Exception as e: 80 | print(f'[DatasetInfer] {input_path}, {str(e)}') 81 | traceback.print_exc() 82 | gt_token=None 83 | pc_list.append(points) 84 | gt_mesh.append({"v":v, "f":f, 'tokens': gt_token}) 85 | print(f'{input_path} Done') 86 | 87 | for input_path, cur_data, gt in zip(input_list, pc_list, gt_mesh): 88 | self.data.append({'pc_normal': cur_data, 'uid': input_path.split('/')[-1].split('.')[0], 'full_path': input_path, 'gt_mesh': gt}) 89 | else: 90 | raise Exception('not implement') 91 | 92 | print(f"infer dataset total data samples: {len(self.data)}") 93 | 94 | def __len__(self): 95 | return len(self.data) 96 | 97 | def __getitem__(self, idx): 98 | data_dict = {} 99 | data_dict['pc_normal'] = self.data[idx]['pc_normal'] 100 | data_dict['uid'] = self.data[idx]['uid'] 101 | data_dict['full_path'] = self.data[idx]['full_path'] 102 | data_dict['gt_mesh'] = self.data[idx]['gt_mesh'] 103 | 104 | return data_dict 105 | 106 | def joint_filter(logits, k = 50, p=0.95): 107 | logits = top_k(logits, k = k) 108 | logits = top_p(logits, thres = p) 109 | return logits 110 | 111 | def max_filter(logits, k = 1): 112 | logits = top_k(logits, k = k) 113 | return logits 114 | 115 | def apply_normalize(mesh): 116 | ''' 117 | normalize mesh to [-1, 1] 118 | ''' 119 | bbox = mesh.bounds 120 | center = (bbox[1] + bbox[0]) / 2 121 | scale = (bbox[1] - bbox[0]).max() 122 | 123 | mesh.apply_translation(-center) 124 | mesh.apply_scale(1 / scale * 2 * 0.95) 125 | 126 | return mesh 127 | 128 | 129 | 130 | def sample_pc(mesh_path, pc_num, with_normal=False): 131 | 132 | mesh = trimesh.load(mesh_path, force='mesh', process=False) 133 | mesh = apply_normalize(mesh) 134 | 135 | if not with_normal: 136 | points, _ = mesh.sample(pc_num, return_index=True) 137 | return points 138 | 139 | points, face_idx = mesh.sample(50000, return_index=True) 140 | normals = mesh.face_normals[face_idx] 141 | pc_normal = np.concatenate([points, normals], axis=-1, dtype=np.float16) 142 | 143 | # random sample point cloud 144 | ind = np.random.choice(pc_normal.shape[0], pc_num, replace=False) 145 | pc_normal = pc_normal[ind] 146 | 147 | return pc_normal 148 | 149 | 150 | def collate_fn_infer(batch): 151 | 152 | # conds 153 | conds = [item['pc_normal'] for item in batch] 154 | 155 | 156 | results = {} 157 | results['pc_normal'] = torch.from_numpy(np.stack(conds, axis=0)).float() 158 | results['uid'] = [item['uid'] for item in batch] 159 | results['full_path']=[item['full_path'] for item in batch] 160 | results['gt_mesh']=[item['gt_mesh'] for item in batch] 161 | 162 | return results -------------------------------------------------------------------------------- /silkutils/dataset_clean/step5_sample.py: -------------------------------------------------------------------------------- 1 | 2 | import pandas as pd 3 | import os 4 | import shutil 5 | import sys 6 | import os 7 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 8 | import ss_platform 9 | from meshdata.mesh_io import load_mesh, write_obj 10 | 11 | # sample_save_dir='debug_provider/infer_input_eachdataset' 12 | sample_save_dir='test_set_A' 13 | sample_save_dirC='test_set_C' 14 | 15 | 16 | table_dir='datasets/cleaned' 17 | sample_test_dir='datasets/sample_test/meshes' 18 | sample_train_dir='datasets/sample_train/meshes' 19 | sample_test_table_dir='datasets/sample_test/tables' 20 | # sample_train_table_dir='datasets/sample_train/tables' 21 | 22 | 23 | def 
copy_and_norm(dataset_name, xlsx_dic, save_dir_norm, save_dir_origin): 24 | # xlsx_dic: sampled_rows[['obj_path','obj_name','id']].to_dict(orient='records') 25 | base_dir_platform=ss_platform.get_base_dir_platform(dataset_name) 26 | os.makedirs(save_dir_norm, exist_ok=True) 27 | os.makedirs(save_dir_origin, exist_ok=True) 28 | for file_dic in xlsx_dic: 29 | file_full_path=os.path.join(base_dir_platform, file_dic['obj_path']) 30 | 31 | # copy origin first 32 | target_file_ext=os.path.splitext(os.path.basename(file_dic['obj_path']))[1] 33 | target_file_savename=f'ori_{dataset_name}_{file_dic["obj_name"]}{target_file_ext}' 34 | target_full_path_origin=os.path.join(save_dir_origin, target_file_savename) 35 | shutil.copy(file_full_path, target_full_path_origin) 36 | print(f'origin mesh copy to {target_full_path_origin}') 37 | 38 | # try to normalize 39 | try: 40 | print('load mesh') 41 | vertices, faces = load_mesh(file_full_path, clean=True) 42 | target_full_path_norm=os.path.join(save_dir_norm, f'norm_{dataset_name}_{file_dic["obj_name"]}.obj') 43 | write_obj(vertices, faces, target_full_path_norm) 44 | print(f'normed mesh saving to {target_full_path_norm}') 45 | except Exception as e: 46 | raise Exception(f'[E] Norm saving Failed, {file_full_path}') 47 | 48 | def sample_table_specify(dataset_name, reso, verison, filter_num, num, split_list): 49 | 50 | xlsx_file=f'{table_dir}/meta_all_{dataset_name}_res{reso}_v{verison:02}_mergeall_filter{filter_num:02}.xlsx' 51 | df = pd.read_excel(xlsx_file) 52 | print(f'read num {len(df)}') 53 | 54 | for index in range(len(split_list)-1): 55 | low=split_list[index] 56 | high=split_list[index+1] 57 | filtered_df=df[(df['face_num_process'] < high) & (df['face_num_process'] > low)] 58 | print(f'{low}-{high} filter num {len(filtered_df)}') 59 | 60 | available_samples = len(filtered_df) 61 | 62 | if available_samples < num: 63 | print(f"[WARNING] {low}-{high}: avail: {available_samples} < {num}") 64 | num = available_samples 65 | 66 | sampled_rows=filtered_df.sample(n=num) 67 | result=sampled_rows[['obj_path','obj_name','id']].to_dict(orient='records') 68 | 69 | save_dir_norm=os.path.join(sample_train_dir, f'train_{dataset_name}_norm', f'face{low:06}-{high:06}') 70 | save_dir_origin=os.path.join(sample_train_dir, f'train_{dataset_name}_origin', f'face{low:06}-{high:06}') 71 | 72 | print(f'sampling {dataset_name}_face{low:06}-{high:06}') 73 | copy_and_norm(dataset_name, result, save_dir_norm, save_dir_origin) 74 | 75 | 76 | def sample_table_left(dataset_name, xlsx_all, xlsx_exclude, num, sample_batch): 77 | 78 | # xlsx_all = f'{table_dir}/meta_all_{dataset_name}_res{reso}_v{v:02}_mergeall_filter02.xlsx' 79 | # xlsx_exclude=f'{table_dir}/meta_all_{dataset_name}_res{reso}_v{v:02}_mergeall_filter{filter:02}.xlsx' 80 | 81 | df_exclude = pd.read_excel(xlsx_exclude) # usually train set 82 | df_all = pd.read_excel(xlsx_all) 83 | 84 | exclude_id_list=df_exclude['id'].tolist() 85 | # filtered_df = df_avail[~df_avail['id'].isin(train_id_list)] 86 | filtered_df = df_all[~df_all['id'].isin(exclude_id_list)] 87 | sampled_rows=filtered_df.sample(n=num) 88 | 89 | result=sampled_rows[['obj_path','obj_name','id']].to_dict(orient='records') 90 | 91 | save_dir_origin=os.path.join(sample_test_dir, f'test_L_{dataset_name}_origin', f'batch_{sample_batch:02}') 92 | save_dir_norm=os.path.join(sample_test_dir, f'test_L_{dataset_name}_norm', f'batch_{sample_batch:02}') 93 | 94 | copy_and_norm(dataset_name=dataset_name, xlsx_dic=result, save_dir_norm=save_dir_norm, 
save_dir_origin=save_dir_origin) 95 | 96 | 97 | def sample_and_generate_testset_table(dataset_name, reso, verison, filter_num, num, sample_batch): 98 | os.makedirs(sample_test_table_dir, exist_ok=True) 99 | xlsx_filter_all=f'{table_dir}/meta_all_{dataset_name}_res{reso}_v{verison:02}_mergeall_filter{filter_num:02}.xlsx' 100 | xlsx_filter_sample_save=f'{sample_test_table_dir}/testset_meta_all_{dataset_name}_res{reso}_v{verison:02}_mergeall_filter{filter_num:02}_sample{num:04}_sb{sample_batch:02}.xlsx' 101 | df = pd.read_excel(xlsx_filter_all) 102 | if num>len(df): 103 | num=len(df) 104 | sampled_rows=df.sample(n=num) 105 | # save sampled 106 | print(f'test set xlsx save to {xlsx_filter_sample_save}') 107 | sampled_rows.to_excel(xlsx_filter_sample_save, index=False) 108 | result=sampled_rows[['obj_path','obj_name','id']].to_dict(orient='records') 109 | 110 | save_dir_origin=os.path.join(sample_test_dir, f'test_A_{dataset_name}_origin', f'batch_{sample_batch:02}') 111 | save_dir_norm=os.path.join(sample_test_dir, f'test_A_{dataset_name}_norm', f'batch_{sample_batch:02}') 112 | copy_and_norm(dataset_name=dataset_name, xlsx_dic=result, save_dir_norm=save_dir_norm, save_dir_origin=save_dir_origin) 113 | 114 | 115 | 116 | if __name__ == "__main__": 117 | 118 | all_dataset=['abo','thingi10k'] 119 | # all_dataset=['3dfuture','shapenetv2'] 120 | # all_dataset=['objaversev1'] 121 | # all_dataset=['trellis-objxl-github'] 122 | all_dataset=['gobjaversev1'] 123 | split_list=[0, 100, 500, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000] 124 | for dataset_name in all_dataset: 125 | # sample and analyze training set 126 | sample_table_specify(dataset_name=dataset_name, reso=128, verison=4, filter_num=11, num=50, split_list=split_list) 127 | 128 | # generate test set, excluded in data_provider 129 | # sample_and_generate_testset_table(dataset_name=dataset_name, reso=128, verison=4, filter_num=11, num=200, sample_batch=0) 130 | 131 | 132 | -------------------------------------------------------------------------------- /silkutils/dataset_clean/process_dataset_fix.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 4 | 5 | import pandas as pd 6 | from tqdm import tqdm 7 | from dataset_clean.process_one import process_one # import the function itself, not the module 8 | from ss_platform import get_savedirs, get_base_dir_platform 9 | from dataset_clean.process_dataset import init_process_one_params, update_item, setup_logger 10 | import concurrent.futures 11 | from typing import Dict, Any 12 | import traceback 13 | 14 | import re 15 | import argparse 16 | 17 | 18 | def process_item_fix(item: Dict[str, Any], log_dir: str, dataset_name: str, resolution: int, version: int, part_num: int) -> Dict[str, Any]: 19 | obj_name = item.get('obj_name') 20 | obj_id = item.get('id') 21 | bug_info=item.get('bug_info') 22 | str_buginfo=str(bug_info) 23 | key_words=['Time', 'loading', 'Unknown', 'timed', 'process', 'thread', 'pickle', 'open'] 24 | if not any(keyword in str_buginfo for keyword in key_words): 25 | # print(f'can not handle buginfo {bug_info}') 26 | return item 27 | 28 | try: 29 | # 30 | print(f'fixing bug: {str_buginfo} of id: {obj_id}') 31 | input_params=init_process_one_params(dataset_name=dataset_name, item=item, resolution=resolution, version=version, part_num=part_num) 32 | result = process_one(**input_params) 33 | updated_item=update_item(result, item) 34 | print(f'fixed {updated_item["bug_info"]} of id: {obj_id}') 35 | 
updated_item['bug_info']=None 36 | return updated_item 37 | # except func_timeout.exceptions.FunctionTimedOut as ee: 38 | # logger = setup_logger(obj_id, obj_name, log_dir) # 39 | # logger.error("An error occurred: %s", str(ee)) 40 | # logger.error(traceback.format_exc()) 41 | # item['bug_info']=f'Time out' 42 | # item['done']=0 43 | # return item 44 | except Exception as e: 45 | 46 | # logger = setup_logger(obj_id, obj_name, log_dir) # set up a logger for each obj_name 47 | # logger.error("An error occurred: %s", str(e)) 48 | # logger.error(traceback.format_exc()) 49 | item['bug_info']=f'{str(e)}' 50 | item['done']=0 51 | return item 52 | 53 | 54 | def process_dataset_fix(xlsx_dic, max_workers=24, head=None): 55 | 56 | dataset_name, resolution, version, part_ind, b_i=xlsx_dic['dataset_name'], xlsx_dic['reso'], xlsx_dic['version'], xlsx_dic['part_ind'], xlsx_dic['b_i'] 57 | 58 | basic_dir_platform=get_base_dir_platform(dataset_name) 59 | xlsx_path_platform=os.path.join(get_base_dir_platform(dataset_name), xlsx_dic['file']) 60 | log_dir=os.path.join(basic_dir_platform, f'p_{part_ind:04}_v{version:02}', get_savedirs()['error']) 61 | with pd.ExcelFile(xlsx_path_platform) as xls: 62 | df = pd.read_excel(xls) 63 | if head is not None: 64 | df = df.head(head) 65 | 66 | key_words_reg='Time|loading|Unknown|timed|process|thread|pickle|open' 67 | 68 | df_filter = df[df['bug_info'].str.contains(key_words_reg, regex=True, na=False)] 69 | 70 | # print(f'fixing {xlsx_path_platform}') 71 | print(f'process {len(df_filter)}/{len(df)} items') 72 | if len(df_filter)==0: 73 | print('------------- no rows need to process!------------- ') 74 | return 75 | save_path_fix_xlsx=xlsx_path_platform 76 | # save_path_fix_xlsx=os.path.join(basic_dir_platform, os.path.splitext(xlsx_dic['file'])[0]+"_debug.xlsx") 77 | if True: 78 | 79 | results=[] 80 | with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor: 81 | future_to_item = {executor.submit(process_item_fix, row.to_dict(), log_dir, dataset_name, resolution, version, part_ind): row for index, row in df.iterrows()} 82 | 83 | 84 | with tqdm(total=len(df), desc=f"{dataset_name}-p_{part_ind:04}-b{b_i}", unit="item") as pbar: 85 | for future in concurrent.futures.as_completed(future_to_item): 86 | item_dic = future_to_item[future].to_dict() 87 | try: 88 | result = future.result() 89 | results.append(result) 90 | # except concurrent.futures.TimeoutError: 91 | # item_dic['bug_info'] = 'Timeout occurred' 92 | # item_dic['done'] = 0 93 | # results.append(item_dic) 94 | 95 | except Exception as e: 96 | item_dic['bug_info']=f'{str(e)}' 97 | item_dic['done']=0 98 | results.append(item_dic) 99 | 100 | # pbar.update(progress_queue.get()) 101 | pbar.update(1) 102 | 103 | 104 | results_df = pd.DataFrame(results) 105 | results_df = results_df.sort_values(by='id') 106 | 107 | with pd.ExcelWriter(save_path_fix_xlsx) as writer: 108 | results_df.to_excel(writer, index=False) 109 | 110 | if __name__ == "__main__": 111 | parser = argparse.ArgumentParser(description='data fixing.') 112 | 113 | parser.add_argument('--dataset_name', type=str, required=True, 114 | help='The name of the dataset to process.') 115 | parser.add_argument('--file', type=str, required=True, 116 | help='The .xlsx file to process.') 117 | parser.add_argument('--reso', type=int, required=True, 118 | help='The resolution to process.') 119 | parser.add_argument('--version', type=int, required=True, 120 | help='The processing version.') 121 | parser.add_argument('--part_ind', 
type=int, required=True, 122 | help='The index of the part to process.') 123 | parser.add_argument('--b_i', type=int, required=True, 124 | help='The index of the batch to process.') 125 | parser.add_argument('--max_workers', type=int, required=True, 126 | help='The number of worker processes.') 127 | 128 | args = parser.parse_args() 129 | 130 | file_dic={ 131 | 'dataset_name': args.dataset_name, 132 | 'reso': args.reso, 133 | 'version': args.version, 134 | 'part_ind': args.part_ind, 135 | 'b_i': args.b_i, 136 | 'file': args.file 137 | } 138 | 139 | print(f'------------- fixing {file_dic["file"]} ------------- ') 140 | process_dataset_fix(file_dic, max_workers=args.max_workers, head=None) 141 | print(f'------------- done {file_dic["file"]} ------------- ') 142 | 143 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mesh Silksong: Auto-Regressive Mesh Generation as Weaving Silk 2 | 3 | 4 | ![img](./assets/teaser_mid_compress.png) 5 | 6 | ## todo 7 | - [ ] Fix geometry-processing bugs like the "mobius loop". 8 | - [ ] Release checkpoint trained on more data. 9 | 10 | ## 1. Environment 11 | #### Conda 12 | ``` 13 | conda create -n silk python=3.9 14 | conda activate silk 15 | pip install -r requirements.txt 16 | ``` 17 | #### Docker (Optional) 18 | 1. pull docker image 19 | ``` 20 | docker pull song21/silk:v1 21 | ``` 22 | 2. start docker container 23 | 24 | For example, if you are running docker on Windows, and this project is stored in `D:\Docker_projs\codes\MeshSilksong`, run the command in Windows cmd like this: 25 | ``` 26 | docker run -it --name silk --gpus all -p 8050:22 -v D:\Docker_projs\codes\MeshSilksong:/workspace/MeshSilksong:rw song21/silk:v1 27 | ``` 28 | 3. get into docker container 29 | 30 | For example, in Windows cmd: 31 | ``` 32 | docker exec -it silk /bin/bash 33 | ``` 34 | 4. activate conda env (in docker container) 35 | ``` 36 | conda activate silk 37 | ``` 38 | 5. If you need sudo in the container: 39 | ``` 40 | su root 41 | passwd: 111111 42 | ``` 43 | 44 | ## 2. Checkpoints Download 45 | 46 | #### Main checkpoint 47 | 48 | Currently we have released a lite-version checkpoint trained on 100K publicly available data items; refer to the training section for reproduction. Checkpoints trained on more data will be released soon. 49 | 50 | To download the model, use huggingface-cli: 51 | 52 | ``` 53 | python3 -m pip install "huggingface_hub[cli]" 54 | mkdir ./checkpoints 55 | huggingface-cli download gcsong/mesh_silksong --local-dir ./checkpoints 56 | ``` 57 | Or directly download from [Huggingface](https://huggingface.co/gcsong/mesh_silksong/tree/main), and put the checkpoint at this path 58 | ``` 59 | ./checkpoints/release-100K/model.safetensors 60 | ``` 61 | 62 | #### Michelangelo checkpoint 63 | 64 | If you want to train the GPT model from scratch, the pretrained [Michelangelo](https://github.com/NeuralCarver/Michelangelo) point-encoder is required for finetuning. Just download `shapevae-256.ckpt` from [here](https://huggingface.co/Maikou/Michelangelo/tree/main/checkpoints/aligned_shape_latents) and put it at this path 65 | ``` 66 | miche/shapevae-256.ckpt 67 | ``` 68 | 69 | ## 3. Inference 70 | Run `sh scripts/infer_silksong_obj.sh` for inference. Key parameters: 71 | - `INFER_BATCH`: the batch size for inference; you may set it to 1 if GPU memory is limited. 72 | - `WORKSPACE`: the save dir of generated meshes. 73 | - `TEST_INPUT`: the input dir of dense meshes/ground truth meshes. 
Point clouds will be sampled from them as the GPT condition. We provide some mesh examples sampled from public datasets [here](https://drive.google.com/drive/folders/1zR7UpC1LJPN2mQC_CfR-Dn2lHRWXG5Eb?usp=sharing). Download them and put them in this path 74 | ``` 75 | datasets/sample_test/meshes/test_mix_origin/batch00/ 76 | ``` 77 | - `RESUME`: main checkpoint path. 78 | 79 | If you have a cluster, run the slurm script: 80 | ``` 81 | sbatch slurm_jobs/infer_silksong_obj.sh 82 | ``` 83 | 84 | ## 4. Training 85 | #### 4.1 Train from scratch 86 | 1. Hardware Requirements 87 | 88 | It is recommended to prepare at least 16 GPUs if the dataset scale is 50K+. Empirically, it may take about 2 weeks for 100K data items on 16 H800 GPUs; the training time will be significantly shorter with more GPUs. 89 | 90 | 2. Point Cloud Encoder 91 | 92 | Download the pretrained [Michelangelo](https://github.com/NeuralCarver/Michelangelo) point-encoder. 93 | 94 | 3. Data Prefilter 95 | 96 | To prevent the data iterator from blocking during training, it is recommended to process the training data via silksong tokenization first to filter out unnecessary items (e.g. meshes with too many faces). 97 | 98 | Refer to [dataset_clean.md](https://github.com/gaochao-s/Mesh-Silksong/blob/main/dataset_clean.md) for data prefiltering and data organization. 99 | 100 | 101 | 4. Data Organization 102 | 103 | Following [dataset_clean.md](https://github.com/gaochao-s/Mesh-Silksong/blob/main/dataset_clean.md), the training data should be organized into one `.xlsx` file per dataset, and the files should be saved to `./datasets/cleaned`. The testing data's `.xlsx` file should be saved in `datasets/sample_test/tables`. During training, 32 items will be selected from the training set for evaluation. Refer to `DataConfigs` in `config/options.py` to modify this. 104 | 105 | 106 | 5. Training 107 | 108 | Refer to the template `scripts/train_silksong_scratch_gpu16.sh` for training from scratch. Follow the guidance below to handle different requirements: 109 | 110 | - If you just want to train on a single node with 8 GPUs: 111 | 112 | Just modify `--config_file acc_configs/gpu16.yaml` to `--config_file acc_configs/gpu8.yaml` 113 | 114 | - If you want to train on 2 nodes with 8 GPUs each: 115 | 116 | Apply for nodes first, supposing `server10` and `server11` are available nodes on the cluster with a working internal network connection between them. Then, connect to the main node `server10`: 117 | ``` 118 | ssh server10 119 | sh MeshSilksong/scripts/train_silksong_scratch_gpu16.sh 120 | ``` 121 | Then, connect to the other node `server11`: 122 | ``` 123 | ssh server11 124 | MACHINE_RANK=1 MASTER_ADDR=server10 sh MeshSilksong/scripts/train_silksong_scratch_gpu16.sh 125 | ``` 126 | After doing this, multi-node training should get started. 127 | 128 | 6. Specify Training Dataset 129 | 130 | Following [dataset_clean.md](https://github.com/gaochao-s/Mesh-Silksong/blob/main/dataset_clean.md), each dataset used for training is stored in a `.xlsx` file. You can specify the datasets in `DATA_SUBSETS` of the training script. 131 | 132 | 133 | #### 4.2 Finetune on your own datasets 134 | Refer to the template `scripts/train_silksong_ft_gpu16.sh` for finetuning on your own dataset. You may follow the template in [dataset_clean.md](https://github.com/gaochao-s/Mesh-Silksong/blob/main/dataset_clean.md) to organize it into a .xlsx file. 135 | 136 | 137 | ## 5. Non-manifold Process 138 | 139 | We also encapsulated the code for non-manifold processing separately. 
137 | ## 5. Non-manifold Process
138 | 
139 | We also encapsulated the non-manifold processing code as a standalone module. Refer to [nonmani_process.md](https://github.com/gaochao-s/Mesh-Silksong/blob/main/nonmani_process.md) for guidance.
140 | 
141 | ## Acknowledgements
142 | The half-edge data structure for geometry processing is borrowed from [EdgeRunner](https://github.com/NVlabs/EdgeRunner)'s C++ implementation.
143 | ```
144 | @article{tang2024edgerunner,
145 |   title={Edgerunner: Auto-regressive auto-encoder for artistic mesh generation},
146 |   author={Tang, Jiaxiang and Li, Zhaoshuo and Hao, Zekun and Liu, Xian and Zeng, Gang and Liu, Ming-Yu and Zhang, Qinsheng},
147 |   journal={arXiv preprint arXiv:2409.18114},
148 |   year={2024}
149 | }
150 | ```
151 | The model architecture is borrowed from [BPT](https://github.com/Tencent-Hunyuan/bpt)'s open-source implementation.
152 | ```
153 | @inproceedings{weng2025scaling,
154 |   title={Scaling mesh generation via compressive tokenization},
155 |   author={Weng, Haohan and Zhao, Zibo and Lei, Biwen and Yang, Xianghui and Liu, Jian and Lai, Zeqiang and Chen, Zhuo and Liu, Yuhong and Jiang, Jie and Guo, Chunchao and others},
156 |   booktitle={Proceedings of the Computer Vision and Pattern Recognition Conference},
157 |   pages={11093--11103},
158 |   year={2025}
159 | }
160 | ```
161 | Thanks to other wonderful works:
162 | - [Michelangelo](https://github.com/NeuralCarver/Michelangelo)
163 | - [TreeMeshGPT](https://github.com/sail-sg/TreeMeshGPT)
164 | - [DeepMesh](https://github.com/zhaorw02/DeepMesh)
165 | - [MeshAnything V2](https://github.com/buaacyw/MeshAnythingV2/tree/main)
166 | 
167 | 
168 | 
169 | ## Citation
170 | 
171 | ```
172 | @article{song2025mesh,
173 |   title={Mesh Silksong: Auto-Regressive Mesh Generation as Weaving Silk},
174 |   author={Song, Gaochao and Zhao, Zibo and Weng, Haohan and Zeng, Jingbo and Jia, Rongfei and Gao, Shenghua},
175 |   journal={arXiv preprint arXiv:2507.02477},
176 |   year={2025}
177 | }
178 | ```
179 | 
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tyro
3 | import glob
4 | import time
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from safetensors.torch import load_file
10 | import ipdb
11 | import kiui
12 | import trimesh
13 | from kiui.op import recenter
14 | from kiui.mesh_utils import clean_mesh
15 | import traceback
16 | from config.options import AllConfigs
17 | from silkutils.meshdata.mesh_io import write_obj, trans_write_tokens, trans_compare_write_tokens, trans_write_tokens_direct, write_ply_fix
18 | from model.model import SSMeshTransformer
19 | from silkutils.silksong_tokenization import get_tokenizer_silksong, detokenize_mesh_ss
20 | from datetime import datetime
21 | from model.data_provider_infer import InferDataset, joint_filter, max_filter, collate_fn_infer
22 | from x_transformers.autoregressive_wrapper import top_p, top_k
23 | 
24 | opt = tyro.cli(AllConfigs)
25 | 
26 | kiui.seed_everything(opt.seed)
27 | # tokenizer
28 | tokenizer, _ = get_tokenizer_silksong()
29 | 
30 | # model
31 | model = SSMeshTransformer(
32 |     dim = opt.model.dim,
33 |     attn_depth = opt.model.depth,
34 |     attn_dim_head = opt.model.attn_dim_head,
35 |     attn_heads = opt.model.attn_heads,
36 |     max_seq_len = opt.max_seq_length,
37 |     dropout = opt.model.dropout,
38 |     mode = opt.mode,
39 |     num_discrete_coors= opt.meto.discrete_bins,
40 |     block_size = opt.meto.block_size,
41 |     offset_size = opt.meto.offset_size,
42 |     conditioned_on_pc =
opt.model.conditioned_on_pc, 43 | encoder_name = opt.model.encoder_name, 44 | encoder_freeze = opt.model.encoder_freeze, 45 | ) 46 | 47 | # resume pretrained checkpoint 48 | if opt.resume is not None: 49 | if opt.resume.endswith('safetensors'): 50 | ckpt = load_file(opt.resume, device='cpu') 51 | else: 52 | ckpt = torch.load(opt.resume, map_location='cpu') 53 | 54 | model.load_state_dict(ckpt, strict=False) 55 | print(f'[INFO] Loaded checkpoint from {opt.resume}') 56 | else: 57 | raise Exception('please set resume path') 58 | 59 | # device 60 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 61 | model = model.half().eval().to(device) 62 | 63 | num_params = sum([param.nelement() for param in model.decoder.parameters()]) 64 | print('Number of parameters: %.2f M' % (num_params / 1e6)) 65 | 66 | assert opt.test_path_input is not None 67 | if os.path.isdir(opt.test_path_input): 68 | file_paths = glob.glob(os.path.join(opt.test_path_input, "*")) 69 | else: 70 | file_paths = [opt.test_path_input] 71 | 72 | now=datetime.now() 73 | formatted_time = now.strftime("%Y-%m-%d-%H:%M:%S") 74 | 75 | ckpt_name=opt.resume 76 | 77 | if 'best' not in ckpt_name: 78 | ep=ckpt_name.split('/')[-2] 79 | exp_name=ckpt_name.split('/')[-3] 80 | else: 81 | ep='best' 82 | exp_name=ckpt_name.split('/')[-2] 83 | fd1=opt.test_path_input.split('/')[-2] 84 | fd2=opt.test_path_input.split('/')[-1] 85 | 86 | max_f=None 87 | if opt.infer.max_filter: 88 | max_f='maxf' 89 | else: 90 | max_f='nomax' 91 | 92 | save_folder=f'{max_f}_{fd1}_{fd2}_{exp_name}_{ep}_{formatted_time}' 93 | # for path in file_paths: 94 | # process(opt, path, save_folder, tokenizer) 95 | os.makedirs(opt.workspace, exist_ok=True) 96 | target_folder=os.path.join(opt.workspace, save_folder) 97 | os.makedirs(target_folder, exist_ok=True) 98 | method_name='silksong' 99 | 100 | infer_dataset=InferDataset(input_type='mesh', input_list=sorted(file_paths)) 101 | 102 | infer_dataloader=torch.utils.data.DataLoader( 103 | infer_dataset, 104 | batch_size=opt.infer.infer_batch, 105 | drop_last = False, 106 | shuffle = False, 107 | collate_fn=collate_fn_infer, 108 | ) 109 | 110 | with torch.no_grad(): 111 | for it, data in enumerate(infer_dataloader): 112 | if opt.infer.max_filter: 113 | codes = model.generate( 114 | batch_size = opt.infer.infer_batch, 115 | temperature = opt.infer.temperature, 116 | pc = data['pc_normal'].cuda().half(), 117 | filter_logits_fn = max_filter, 118 | filter_kwargs = dict(k=1), 119 | return_codes=True, 120 | ) 121 | else: 122 | codes = model.generate( 123 | batch_size = opt.infer.infer_batch, 124 | temperature = opt.infer.temperature, 125 | pc = data['pc_normal'].cuda().half(), 126 | filter_logits_fn = joint_filter, 127 | filter_kwargs = dict(k=50, p=0.95), 128 | return_codes=True, 129 | ) 130 | 131 | coords = [] 132 | 133 | # decoding codes to coordinates 134 | for i in range(len(codes)): 135 | code = codes[i] 136 | full_path = data['full_path'][i] 137 | code = code[code != model.pad_id].cpu().numpy() 138 | try: 139 | verts, faces = detokenize_mesh_ss(tokenizer, code, colorful=True, mani_fix=True) 140 | coords.append({'v':verts, 'f': faces, 'tokens': code}) 141 | except Exception as e: 142 | print(f'path generation failed: {full_path}, {str(e)}') 143 | traceback.print_exc() 144 | coords.append({'tokens': code}) 145 | 146 | # convert coordinates to mesh 147 | for i in range(opt.infer.infer_batch): 148 | uid = data['uid'][i] 149 | pcd = data['pc_normal'][i].cpu().numpy() 150 | gt_v= data['gt_mesh'][i]['v'] 151 | gt_f= 
data['gt_mesh'][i]['f'] 152 | if data['gt_mesh'][i]['tokens'] is not None: 153 | gt_token= data['gt_mesh'][i]['tokens'] 154 | else: 155 | gt_token=None 156 | 157 | pc_save_name=f'{uid}_{method_name}_POINT.ply' 158 | gt_save_name=f'{uid}_{method_name}_GT.obj' 159 | gt_token_save_name=f'{uid}_{method_name}_GT_tokens.txt' 160 | 161 | # save point 162 | point_cloud = trimesh.points.PointCloud(pcd[..., 0:3]) 163 | point_cloud.export(f'{target_folder}/{pc_save_name}', "ply") 164 | 165 | # save gt 166 | gt_mesh=trimesh.Trimesh(vertices=gt_v, faces=gt_f) 167 | gt_mesh.export(os.path.join(target_folder, gt_save_name)) 168 | if gt_token is not None: 169 | trans_write_tokens_direct(tokens=gt_token, filename=os.path.join(target_folder, gt_token_save_name.replace('.txt', f'_len{len(gt_token):05}.txt')), engine=tokenizer) 170 | 171 | # save pred 172 | pred_dic = coords[i] 173 | pred_token_save_name=f'{uid}_{method_name}_gen_tokens.txt' 174 | trans_write_tokens_direct(tokens=pred_dic['tokens'], filename=os.path.join(target_folder, pred_token_save_name.replace('.txt', f'_len{len(pred_dic["tokens"]):05}.txt')), engine=tokenizer) 175 | np.save(os.path.join(target_folder, pred_token_save_name.replace('.txt', f'_len{len(pred_dic["tokens"]):05}.npy')), pred_dic['tokens']) 176 | if gt_token is not None: 177 | trans_compare_write_tokens(tokens=pred_dic['tokens'], tokens_gt=gt_token, filename=os.path.join(target_folder, gt_token_save_name.replace('.txt', f'_len{len(gt_token):05}.txt')), engine=tokenizer) 178 | 179 | if 'v' not in pred_dic: 180 | continue 181 | pred_vert_6_layerColor=pred_dic['v'][0] 182 | pred_vert_6_ccColor=pred_dic['v'][1] 183 | 184 | pred_face=pred_dic['f'][0][:, :3] # decoding from GPT output 185 | pred_face_F=pred_dic['f'][1][:, :3] # auto fix 186 | pred_save_name=f'{uid}_{method_name}_gen.obj' 187 | 188 | 189 | F_dir=target_folder 190 | os.makedirs(F_dir, exist_ok=True) 191 | # the view of connected components 192 | write_obj(pred_vert_6_ccColor, pred_face_F, os.path.join(F_dir, pred_save_name.replace('.obj','_CCcolor.obj'))) 193 | 194 | # if you want to visualize layer view (RGB contour lines) 195 | # write_obj(pred_vert_6_layerColor, pred_face_F, os.path.join(F_dir, pred_save_name.replace('.obj','_layercolor.obj'))) 196 | 197 | 198 | 199 | 200 | -------------------------------------------------------------------------------- /miche/michelangelo/models/modules/embedder.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | 8 | VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"] 9 | 10 | 11 | class FourierEmbedder(nn.Module): 12 | """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts 13 | each feature dimension of `x[..., i]` into: 14 | [ 15 | sin(x[..., i]), 16 | sin(f_1*x[..., i]), 17 | sin(f_2*x[..., i]), 18 | ... 19 | sin(f_N * x[..., i]), 20 | cos(x[..., i]), 21 | cos(f_1*x[..., i]), 22 | cos(f_2*x[..., i]), 23 | ... 24 | cos(f_N * x[..., i]), 25 | x[..., i] # only present if include_input is True. 26 | ], here f_i is the frequency. 27 | 28 | Denote the space is [0 / num_freqs, 1 / num_freqs, 2 / num_freqs, 3 / num_freqs, ..., (num_freqs - 1) / num_freqs]. 29 | If logspace is True, then the frequency f_i is [2^(0 / num_freqs), ..., 2^(i / num_freqs), ...]; 30 | Otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)]. 
31 | 
32 |     Args:
33 |         num_freqs (int): the number of frequencies, default is 6;
34 |         logspace (bool): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
35 |             otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
36 |         input_dim (int): the input dimension, default is 3;
37 |         include_input (bool): include the input tensor or not, default is True.
38 | 
39 |     Attributes:
40 |         frequencies (torch.Tensor): If logspace is True, then the frequency f_i is [..., 2^(i / num_freqs), ...],
41 |             otherwise, the frequencies are linearly spaced between [1.0, 2^(num_freqs - 1)];
42 | 
43 |         out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
44 |             otherwise, it is input_dim * num_freqs * 2.
45 | 
46 |     """
47 | 
48 |     def __init__(self,
49 |                  num_freqs: int = 6,
50 |                  logspace: bool = True,
51 |                  input_dim: int = 3,
52 |                  include_input: bool = True,
53 |                  include_pi: bool = True) -> None:
54 | 
55 |         """The initialization"""
56 | 
57 |         super().__init__()
58 | 
59 |         if logspace:
60 |             frequencies = 2.0 ** torch.arange(
61 |                 num_freqs,
62 |                 dtype=torch.float32
63 |             )
64 |         else:
65 |             frequencies = torch.linspace(
66 |                 1.0,
67 |                 2.0 ** (num_freqs - 1),
68 |                 num_freqs,
69 |                 dtype=torch.float32
70 |             )
71 | 
72 |         if include_pi:
73 |             frequencies *= torch.pi
74 | 
75 |         self.register_buffer("frequencies", frequencies, persistent=False)
76 |         self.include_input = include_input
77 |         self.num_freqs = num_freqs
78 | 
79 |         self.out_dim = self.get_dims(input_dim)
80 | 
81 |     def get_dims(self, input_dim):
82 |         temp = 1 if self.include_input or self.num_freqs == 0 else 0
83 |         out_dim = input_dim * (self.num_freqs * 2 + temp)
84 | 
85 |         return out_dim
86 | 
87 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
88 |         """ Forward process.
89 | 
90 |         Args:
91 |             x: tensor of shape [..., dim]
92 | 
93 |         Returns:
94 |             embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
95 |                 where temp is 1 if include_input is True and 0 otherwise.
96 | """ 97 | 98 | if self.num_freqs > 0: 99 | embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1) 100 | if self.include_input: 101 | return torch.cat((x, embed.sin(), embed.cos()), dim=-1) 102 | else: 103 | return torch.cat((embed.sin(), embed.cos()), dim=-1) 104 | else: 105 | return x 106 | 107 | 108 | class LearnedFourierEmbedder(nn.Module): 109 | """ following @crowsonkb "s lead with learned sinusoidal pos emb """ 110 | """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ 111 | 112 | def __init__(self, in_channels, dim): 113 | super().__init__() 114 | assert (dim % 2) == 0 115 | half_dim = dim // 2 116 | per_channel_dim = half_dim // in_channels 117 | self.weights = nn.Parameter(torch.randn(per_channel_dim)) 118 | 119 | def forward(self, x): 120 | """ 121 | 122 | Args: 123 | x (torch.FloatTensor): [..., c] 124 | 125 | Returns: 126 | x (torch.FloatTensor): [..., d] 127 | """ 128 | 129 | # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d] 130 | freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1) 131 | fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1) 132 | return fouriered 133 | 134 | 135 | class TriplaneLearnedFourierEmbedder(nn.Module): 136 | def __init__(self, in_channels, dim): 137 | super().__init__() 138 | 139 | self.yz_plane_embedder = LearnedFourierEmbedder(in_channels, dim) 140 | self.xz_plane_embedder = LearnedFourierEmbedder(in_channels, dim) 141 | self.xy_plane_embedder = LearnedFourierEmbedder(in_channels, dim) 142 | 143 | self.out_dim = in_channels + dim 144 | 145 | def forward(self, x): 146 | 147 | yz_embed = self.yz_plane_embedder(x) 148 | xz_embed = self.xz_plane_embedder(x) 149 | xy_embed = self.xy_plane_embedder(x) 150 | 151 | embed = yz_embed + xz_embed + xy_embed 152 | 153 | return embed 154 | 155 | 156 | def sequential_pos_embed(num_len, embed_dim): 157 | assert embed_dim % 2 == 0 158 | 159 | pos = torch.arange(num_len, dtype=torch.float32) 160 | omega = torch.arange(embed_dim // 2, dtype=torch.float32) 161 | omega /= embed_dim / 2. 162 | omega = 1. / 10000 ** omega # (D/2,) 163 | 164 | pos = pos.reshape(-1) # (M,) 165 | out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product 166 | 167 | emb_sin = torch.sin(out) # (M, D/2) 168 | emb_cos = torch.cos(out) # (M, D/2) 169 | 170 | embeddings = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) 171 | 172 | return embeddings 173 | 174 | 175 | def timestep_embedding(timesteps, dim, max_period=10000): 176 | """ 177 | Create sinusoidal timestep embeddings. 178 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 179 | These may be fractional. 180 | :param dim: the dimension of the output. 181 | :param max_period: controls the minimum frequency of the embeddings. 182 | :return: an [N x dim] Tensor of positional embeddings. 
183 |     """
184 |     half = dim // 2
185 |     freqs = torch.exp(
186 |         -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
187 |     ).to(device=timesteps.device)
188 |     args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
189 |     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
190 |     if dim % 2:
191 |         embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
192 |     return embedding
193 | 
194 | 
195 | def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, degree=4,
196 |                  num_levels=16, level_dim=2, per_level_scale=2, base_resolution=16,
197 |                  log2_hashmap_size=19, desired_resolution=None):
198 |     if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1):
199 |         return nn.Identity(), input_dim
200 | 
201 |     elif embed_type == "fourier":
202 |         embedder_obj = FourierEmbedder(num_freqs=num_freqs, input_dim=input_dim,
203 |                                        logspace=True, include_input=True)
204 |         return embedder_obj, embedder_obj.out_dim
205 | 
206 |     elif embed_type == "hashgrid":
207 |         raise NotImplementedError
208 | 
209 |     elif embed_type == "sphere_harmonic":
210 |         raise NotImplementedError
211 | 
212 |     else:
213 |         raise ValueError(f"{embed_type} is not valid. Currently only supports {VALID_EMBED_TYPES}")
214 | 
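A short hedged usage sketch (not part of the repo) that makes the embedding width above concrete; the import path follows this repository's layout, and the tensor shape is an arbitrary example:

```
# hedged sketch: check FourierEmbedder's output width via get_embedder
import torch
from miche.michelangelo.models.modules.embedder import get_embedder

embedder, out_dim = get_embedder(embed_type="fourier", num_freqs=6, input_dim=3)
x = torch.rand(4, 1024, 3)             # e.g. a batch of 3D points
feat = embedder(x)                     # raw input plus sin/cos features
assert out_dim == 3 * (2 * 6 + 1)      # 39 = input_dim * (num_freqs * 2 + 1)
assert feat.shape == (4, 1024, out_dim)
```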
--------------------------------------------------------------------------------
/silkutils/meto/decode_utils_fix.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import networkx as nx
3 | 
4 | def edge_key(a, b):
5 |     return (a, b) if a < b else (b, a)
6 | 
7 | def get_edge_to_face_dic(face_list, triangles, vert_index):
8 |     ret_dic={}
9 |     for face_index in face_list:
10 |         edge=[x for x in triangles[face_index] if x !=vert_index]
11 |         ret_dic[edge_key(edge[0], edge[1])]=face_index
12 |     return ret_dic
13 | 
14 | def find_cycle_or_chain(component):
15 |     # component: subgraph
16 |     if not nx.is_tree(component):
17 |         # find all simple cycles
18 |         cycles = list(nx.simple_cycles(component.to_directed()))
19 |         if cycles:
20 |             max_cycle = max(cycles, key=len)
21 |             if len(max_cycle) == 2:  # only the trivial 2-cycles introduced by to_directed()
22 |                 raise Exception('wrong')
23 |             return max_cycle, 'cycle'
24 |             # max_cycle = min(cycles, key=len)
25 |             # return max_cycle, 'cycle'
26 |         raise Exception('can not find cycle ?')
27 | 
28 |     # if there is no cycle, return the longest path
29 |     longest_path = []
30 |     for node in component.nodes():
31 |         # search the paths from this node and keep the longest one
32 |         paths = nx.single_source_dijkstra_path(component, node)
33 |         for path in paths.values():
34 |             if len(path) > len(longest_path):
35 |                 longest_path = path
36 |     return longest_path, 'chain'
37 | 
38 | def find_all_cycle(component):
39 |     # component: subgraph
40 |     if not nx.is_tree(component):
41 |         # find all simple cycles
42 |         cycles = list(nx.simple_cycles(component))
43 | 
44 |         if cycles:
45 |             return cycles, 'cycle'
46 |         raise Exception('can not find cycle ?')
47 | 
48 |     longest_path = []
49 |     return longest_path, 'chain'
50 | 
51 | class VertDecode:
52 |     def __init__(self, v6_bfs, v6_cc, index, layer, cc_id):
53 |         self.v6_bfs=v6_bfs
54 |         self.v6_cc=v6_cc
55 |         self.index=index
56 |         self.cc_id=cc_id
57 |         self.layer=layer
58 |         self.neighbor_faces=[]
59 | 
60 | class MeshDecode:
61 |     def __init__(self, vertices_list, faces):
62 |         self.verts=[VertDecode(ele['v6_bfs'], ele['v6_cc'], ele['index'], ele['layer'], ele['cc_id']) for ele in vertices_list]
63 |         self.faces=faces
64 |         self.vert_num=len(self.verts)
65 |         self.face_num=len(faces)
66 |         self.face_exclude=[]
67 |         self.bound_edges=[]
68 |         self.hole_fix_faces=[]
69 | 
70 |     def exclude_nm_faces(self):
71 |         for face_index, face in enumerate(self.faces):
72 |             self.verts[face[0]].neighbor_faces.append(face_index)
73 |             self.verts[face[1]].neighbor_faces.append(face_index)
74 |             self.verts[face[2]].neighbor_faces.append(face_index)
75 | 
76 |         for vert_index in range(self.vert_num):
77 |             face_list_o=self.verts[vert_index].neighbor_faces
78 |             face_list=[face for face in face_list_o if face not in self.face_exclude]
79 | 
80 |             if not face_list:
81 |                 print(f'[WARNING] vert {vert_index} single')
82 |                 continue
83 |             edge_to_face_dic=get_edge_to_face_dic(face_list, self.faces, vert_index)
84 |             # if any cycle exists, keep only the faces of the largest cycle and exclude every other face (across all components); no boundary edges in that case
85 |             # if there is no cycle, compute the longest chain for each connected component, exclude the non-manifold faces, and mark the boundary edges
86 |             # fix ordinary holes: cycle length < 6 and within the same CC
87 |             # for a watertight fix, xxx
88 |             info_dic=self.check_edge_graph(vert_index, edge_to_face_dic)
89 |             keeped_face=info_dic['keeped_face']
90 |             exclude_faces=[f for f in face_list if f not in keeped_face]
91 |             self.face_exclude+=exclude_faces
92 |             if info_dic['type']=='chain':
93 |                 self.bound_edges+=info_dic['bound_edges']
94 | 
95 |         print(f'about {len(self.face_exclude)} excluded')
96 |         print(f'find {len(self.bound_edges)} bound edges')
97 | 
98 |     def fix_hole(self, max_cycle=6, water_tight=False):
99 |         if water_tight:
100 |             max_cycle=100
101 |         hole_graph=nx.Graph()
102 |         hole_graph.add_edges_from(self.bound_edges)
103 |         components = list(nx.connected_components(hole_graph))
104 |         for cc_i, component_nodes in enumerate(components):
105 |             component = hole_graph.subgraph(component_nodes)
106 |             cycles, type = find_all_cycle(component)
107 |             if type=='chain':
108 |                 continue
109 |             for cycle in cycles:
110 |                 cycle_len=len(cycle)
111 |                 if cycle_len>max_cycle:
112 |                     continue
113 |                 # fix hole
114 |                 self.hole_fix_faces+=self.fix_cycle(cycle)
115 |         print(f'fix hole and add {len(self.hole_fix_faces)} faces')
116 | 
117 |     def fix_cycle(self, cycle):
118 |         if len(cycle)==3:
119 |             a, b, c = cycle
120 |             return [[a,b,c]]
121 |         if len(cycle)==4:
122 |             a,b,c,d=cycle
123 |             return [[a,d,b],[c,b,d]]
124 |         if len(cycle)==5:
125 |             a,b,c,d,e=cycle
126 |             return [[a,c,b],[a,d,c],[a,e,d]]
127 |         if len(cycle)==6:
128 |             a,b,c,d,e,f=cycle
129 |             return [[a,c,b],[a,d,c],[a,f,d],[f,e,d]]
130 |         raise Exception('not implemented')
131 | 
132 | 
133 | 
134 |     def check_edge_graph(self, vert_index, edge_to_face_dic):
135 |         graph=nx.Graph()
136 |         graph.add_edges_from(edge_to_face_dic.keys())
137 |         components=list(nx.connected_components(graph))
138 |         cc_info_list=[]
139 |         cycle_cc_index=[]
140 |         for cc_i, component_nodes in enumerate(components):
141 |             component = graph.subgraph(component_nodes)
142 |             cycle_or_chain, type = find_cycle_or_chain(component)
143 |             cc_dic={}
144 |             cc_dic['type']=type
145 |             if type == 'cycle':
146 | 
147 |                 cc_edges_cycle=[edge_key(cycle_or_chain[i%len(cycle_or_chain)], cycle_or_chain[(i+1)%len(cycle_or_chain)]) for i in range(len(cycle_or_chain))]
148 |                 cycle_len=len(cc_edges_cycle)
149 |                 cc_dic['cycle_edge']=cc_edges_cycle
150 |                 cc_dic['cycle_face']=[edge_to_face_dic[edge] for edge in cc_edges_cycle]
151 |                 cc_dic['cycle_len']=cycle_len
152 |                 if not cycle_cc_index:
153 |                     cycle_cc_index.append([cc_i, cycle_len])
154 |                 else:
155 |                     last_len=cycle_cc_index[0][1]
156 |                     if cycle_len>last_len:
157 |                         cycle_cc_index=[[cc_i, cycle_len]]
158 |             else:
159 |                 cc_edges_chain=[edge_key(cycle_or_chain[i], cycle_or_chain[(i+1)]) for i in range(len(cycle_or_chain)-1)]
160 |                 cc_dic['chain_edge']=cc_edges_chain
161 |                 cc_dic['chain_face']=[edge_to_face_dic[edge] for edge in cc_edges_chain]
162 |                 cc_dic['chain_len']=len(cc_edges_chain)
163 |                 cc_dic['chain_bound']=[cycle_or_chain[0], cycle_or_chain[-1]]
164 | 
165 |             cc_info_list.append(cc_dic)
166 | 
167 |         ret_dic={}
168 |         if cycle_cc_index:
169 |             # a cycle exists and the largest one was found: keep the faces of that cycle and drop every other face
170 |             cc_i, cycle_len = cycle_cc_index[0]
171 |             keeped_face=cc_info_list[cc_i]['cycle_face']
172 |             ret_dic['type']='cycle'
173 |             ret_dic['keeped_face']=keeped_face
174 |         else:
175 |             # every component is a chain
176 |             keeped_face=[]
177 |             bound_edges=[]
178 |             for chain_dic in cc_info_list:
179 |                 keeped_face+=chain_dic['chain_face']
180 |                 bp_1=chain_dic['chain_bound'][0]
181 |                 bp_2=chain_dic['chain_bound'][1]
182 |                 bound_edges.append([vert_index, bp_1])
183 |                 bound_edges.append([vert_index, bp_2])
184 |             ret_dic['type']='chain'
185 |             ret_dic['keeped_face']=keeped_face
186 |             ret_dic['bound_edges']=bound_edges
187 |         return ret_dic
188 | 
189 | 
190 | 
191 | def manifold_fix(vertices_dic_list, faces):
192 |     # vertices: list of per-vertex dicts
193 |     # faces: list of triangles
194 |     mesh_input=MeshDecode(vertices_list=vertices_dic_list, faces=faces)
195 |     mesh_input.exclude_nm_faces()
196 |     mesh_input.fix_hole()
197 |     new_faces=[f for idx, f in enumerate(mesh_input.faces) if idx not in mesh_input.face_exclude]
198 |     new_faces+=mesh_input.hole_fix_faces
199 | 
200 |     return vertices_dic_list, new_faces
201 | 
202 | if __name__ == "__main__":
203 |     pass
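To illustrate what `manifold_fix` does end to end, here is a hedged, self-contained sketch (not from the repo): the vertex dicts carry exactly the fields `VertDecode` reads, the geometry is a hypothetical fan around one non-manifold vertex, and undirected `nx.simple_cycles` requires networkx >= 3.1, which the module above already assumes:

```
# hedged sketch; assumes the repo root is on PYTHONPATH
from silkutils.meto.decode_utils_fix import manifold_fix

# five vertices; the dict fields mirror VertDecode's constructor arguments
verts = [{'v6_bfs': None, 'v6_cc': None, 'index': i, 'layer': 0, 'cc_id': 0} for i in range(5)]
faces = [[0, 1, 2], [0, 2, 3], [0, 3, 1],  # a closed triangle fan around vertex 0
         [0, 1, 4]]                        # an extra face that makes vertex 0 non-manifold
_, faces_out = manifold_fix(verts, faces)
# the dangling face [0, 1, 4] is excluded, and the length-3 boundary cycle (1, 2, 3)
# left behind is re-triangulated by fix_hole, so len(faces_out) == 4
```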
--------------------------------------------------------------------------------
/silkutils/dataset_clean/process_one.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4 | from silksong_tokenization import get_tokenizer_silksong
5 | from meshdata.mesh_io import load_mesh, load_mesh_modify, write_obj, write_tokens_ori, quick_demo, trans_write_tokens, trans_compare_write_tokens, write_ply_fix
6 | from ss_platform import get_base_dir_platform, get_base_dir_rel
7 | import numpy as np
8 | import os
9 | import pandas as pd
10 | from tqdm import tqdm
11 | import kiui
12 | import ipdb
13 | import time
14 | import func_timeout
15 | from func_timeout import func_set_timeout
16 | 
17 | 
18 | 
19 | def get_save_basename(dataset_name, obj_save_name, mode):
20 |     return f'{dataset_name}_{obj_save_name}_v{mode:02}'
21 | 
22 | # @func_set_timeout(100)
23 | def process_one(**kwargs):
24 |     max_face_after_trimesh=kwargs.get('max_face_num', 16000)  # max: 16k
25 | 
26 |     dataset_name=kwargs.get('dataset_name')
27 | 
28 |     obj_path=kwargs.get('obj_path')
29 |     obj_save_name=kwargs.get('obj_save_name')
30 |     save_dir_npy=kwargs.get('save_dir_npy')
31 |     save_dir_meta=kwargs.get('save_dir_meta')
32 |     save_dir_debug=kwargs.get('save_dir_debug')
33 |     resolution=kwargs.get('resolution')
34 |     mode=kwargs.get('mode')
35 |     face_type=kwargs.get('face_type')
36 |     debugging=kwargs.get('debugging')
37 | 
38 |     structure_limit_kwargs={
39 |         'NM_max_edge_graph': kwargs.get('NM_max_edge_graph', 50),
40 |         'NM_max_nonmani_verts': kwargs.get('NM_max_nonmani_verts', 300),
41 |         'min_CC_face': kwargs.get('min_CC_face', 3),
42 |         'max_face_num_p': kwargs.get('max_face_num_p', 12000)
43 |     }
44 | 
45 |     if debugging is None:
46 |         debugging=False
47 |         # print('debugging false')
48 |     elif debugging == False:
49 |         pass
50 |     else:
51 |         debugging=True
52 |         print('debugging true')
53 |     S_time=time.time()
54 | 
55 |     save_base_name=get_save_basename(dataset_name=dataset_name, obj_save_name=obj_save_name, mode=mode)
56 | 
57 |     if debugging:
58 |         save_dir_platform='/workspace/MeshSilksong/silkutils/Debugs/debug_output'
59 |         save_dir_platform=os.path.join(save_dir_platform, save_base_name)
60 |     else:
61 |         save_dir_platform=get_base_dir_platform(dataset_name)
62 | 
63 |     if face_type=='triangle':
64 |         # supports .obj, .glb, .ply ... and whatever else trimesh supports
65 |         obj_path_platform=os.path.join(get_base_dir_platform(dataset_name), obj_path)
66 |         try:
67 |             vertices, faces = load_mesh(obj_path_platform, clean=True)
68 |         except Exception as e:
69 |             raise Exception('[E] loading Failed')
70 |     elif face_type=='multi':
71 |         # supports .obj
72 |         obj_path_platform=os.path.join(get_base_dir_platform(dataset_name), obj_path)
73 |         try:
74 |             vertices, faces = load_mesh_modify(obj_path_platform)
75 |         except Exception as e:
76 |             raise Exception('[E] loading Failed')
77 |     elif face_type=='quick_demo':
78 |         vertices, faces = quick_demo(obj_path)
79 |     else:
80 |         raise Exception('[E] wrong face type')
81 | 
82 |     if len(faces)>max_face_after_trimesh:
83 |         # pass
84 |         raise Exception(f'[E] too many faces {len(faces)} > {max_face_after_trimesh}!')
85 | 
86 | 
87 |     tokensO_filename=f'tokensO_{save_base_name}.txt'  # readable tokens
88 |     tokensE_filename=f'tokensE_{save_base_name}.npy'  # encoded vertex tokens
89 |     tokensT_filename=f'tokensT_{save_base_name}.txt'  # translated back from the encoded tokens
90 |     meta_filename=f'meta_{save_base_name}.txt'
91 |     decode_filename=f'decode_{save_base_name}.obj'
92 | 
93 |     tokensO_filedir=os.path.join(save_dir_platform, save_dir_debug)
94 |     tokensE_filedir=os.path.join(save_dir_platform, save_dir_npy)
95 |     tokensT_filedir=os.path.join(save_dir_platform, save_dir_debug)
96 |     meta_filedir=os.path.join(save_dir_platform, save_dir_meta)
97 |     decode_filedir=os.path.join(save_dir_platform, save_dir_debug)
98 | 
99 |     os.makedirs(tokensO_filedir, exist_ok=True)
100 |     os.makedirs(tokensE_filedir, exist_ok=True)
101 |     os.makedirs(tokensT_filedir, exist_ok=True)
102 |     os.makedirs(meta_filedir, exist_ok=True)
103 |     os.makedirs(decode_filedir, exist_ok=True)
104 | 
105 |     tokensO_path=os.path.join(tokensO_filedir, tokensO_filename)
106 |     tokensE_path=os.path.join(tokensE_filedir, tokensE_filename)
107 |     tokensT_path=os.path.join(tokensT_filedir, tokensT_filename)
108 |     meta_path=os.path.join(meta_filedir, meta_filename)
109 |     # mesh after token decoder
110 |     decode_path=os.path.join(decode_filedir, decode_filename)
111 |     # M1:(GT) read mesh (obj,ply,glb,...)
and trimesh process, normalize, clean by kiui 112 | # M2: mesh processed by silksong: non-mani preprocess, flip fixing, cc classification and coloring 113 | # M3: mesh processed by vertex layering and sorting, colored by layering 114 | M1_path=decode_path.replace('decode','M1') 115 | M2_path=decode_path.replace('decode','M2') 116 | M3_path=decode_path.replace('decode','M3') 117 | 118 | if debugging: 119 | print(f'Mesh M1 saving to {M1_path}') 120 | write_obj(vertices, faces, M1_path) 121 | 122 | 123 | meta_init_kwargs={ 124 | 'version': mode, 125 | 'origin_path': os.path.join(get_base_dir_platform(dataset_name), obj_path), 126 | 'other_path':[None, None], 127 | 'face_type': face_type, 128 | 'resolution': resolution, 129 | 'M1_path': M1_path, 130 | 'M2_path': M2_path, 131 | 'M3_path': M3_path, 132 | } 133 | 134 | tokenizer, _ =get_tokenizer_silksong(resolution=resolution, ss_mode=mode, meta_init_data=meta_init_kwargs, structure_limit=structure_limit_kwargs, debugging=debugging) 135 | 136 | tokensO = tokenizer.encode(vertices, faces, non_mani_process=True) 137 | meta_data_temp=tokenizer.get_metaData() 138 | if debugging: 139 | write_tokens_ori(tokensO, tokensO_path) 140 | meta_data_temp.save_meta(meta_path) 141 | 142 | 143 | tokensE, meta_data_temp=tokenizer.token_encode(input_tokens=tokensO, mode=mode) 144 | E_time=time.time() 145 | 146 | if debugging: 147 | np.save(tokensE_path, tokensE) 148 | meta_data_temp.save_meta(meta_path) 149 | trans_write_tokens(tokensE, tokensT_path, tokenizer, resolution, mode) 150 | load_token=np.load(tokensE_path) 151 | else: 152 | load_token=tokensE 153 | 154 | # if debugging: 155 | # vertices_decode, faces_decode = engine.decode_ori(tokens_ori) 156 | 157 | vertices_decode, faces_decode = tokenizer.decode(load_token, discrete_bins=resolution, mode=mode, colorful=True) 158 | 159 | # save v specified 160 | meta_data_temp=tokenizer.get_metaData() 161 | 162 | if debugging: 163 | write_obj(vertices_decode[0], faces_decode, decode_path.replace('.obj', '_layercolor.obj')) 164 | write_obj(vertices_decode[1], faces_decode, decode_path.replace('.obj', '_CCcolor.obj')) 165 | meta_data_temp.save_meta(meta_path) 166 | 167 | xlsx_line_new={ 168 | 'done': 1, 169 | 'vert_num': meta_data_temp.vert_num, 170 | 'face_num': meta_data_temp.face_num, 171 | 'vert_num_process': meta_data_temp.vert_num_p, 172 | 'face_num_process': meta_data_temp.face_num_p, 173 | 'CC_num': meta_data_temp.CC_num, 174 | 'CC_num_valid': meta_data_temp.CC_num_valid, 175 | 'CC_num_pre': meta_data_temp.CC_num_pre, 176 | 'CC_num_pre_all': meta_data_temp.CC_num_pre_all, 177 | 'max_lv': meta_data_temp.max_lv, 178 | 'max_l': meta_data_temp.max_l, 179 | 'max_edge_num': meta_data_temp.max_edge_num, 180 | 'token_length': meta_data_temp.token_length, 181 | 'compression_rate': meta_data_temp.compression_rate, 182 | 'flipped_face': meta_data_temp.flipped_face_cnt, 183 | '1v2cc_gen': meta_data_temp.new_generated_verts, 184 | 'nonmani_gen': meta_data_temp.non_manifold_new_gen, 185 | 'nonmani_face': meta_data_temp.non_manifold_vert_cnt, 186 | 'nonmani_process_time': meta_data_temp.non_manifold_process_time, 187 | 'not_success_flip_face': meta_data_temp.not_success_flip_face, 188 | 'merge_repeat_verts': meta_data_temp.merge_repeat_verts_num, 189 | 'replace_facevert_num': meta_data_temp.replace_facevert_num, 190 | 'degraded_face_num': meta_data_temp.degraded_face_num, 191 | 'move_repeat_face_num': meta_data_temp.move_repeat_face_num, 192 | 'CC_invalid_vert_num': meta_data_temp.CC_invalid_verts, 193 | 'CC_invalid_face_num': 
meta_data_temp.CC_invalid_faces, 194 | 'encode_time': E_time-S_time, 195 | } 196 | if debugging: 197 | for k, v in xlsx_line_new.items(): 198 | print(f'{k}:{v}') 199 | 200 | return xlsx_line_new 201 | 202 | 203 | -------------------------------------------------------------------------------- /silkutils/dataset_clean/step4_datafilter.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import sys 4 | import os 5 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 6 | import pandas as pd 7 | import re 8 | import ss_platform 9 | 10 | def get_filtered_df(df, filter_version): 11 | if filter_version == 1: 12 | return df[df['bug_info'].isna() & (df['CC_num_valid'] <= 30) & (df['nonmani_process_time'] == 0)] 13 | elif filter_version == 2: # success processed file 14 | return df[df['bug_info'].isna()] 15 | elif filter_version == 0: 16 | return df 17 | elif filter_version == 3: 18 | df['compression_rate'] = df['compression_rate'].str.replace('%', '').astype(float) 19 | # import ipdb;ipdb.set_trace() 20 | filtered_df = df[ 21 | (df['compression_rate'] <= 65.0) & 22 | (df['CC_num_valid'] <= 100) & 23 | (df['token_length'] <= 20000) 24 | ] 25 | return filtered_df 26 | elif filter_version == 4: # num face <= 4k, cc<=5, token_length<40960 27 | filtered_df = df[ 28 | (df['face_num_process'] <= 4000) & 29 | (df['CC_num_valid'] <= 5) & 30 | (df['token_length'] <= 40960) 31 | ] 32 | return filtered_df 33 | elif filter_version == 5: # num face <= 8k, cc<=20, token_length<40960 34 | filtered_df = df[ 35 | (df['face_num_process'] <= 8000) & 36 | (df['CC_num_valid'] <= 20) & 37 | (df['token_length'] <= 40960) 38 | ] 39 | return filtered_df 40 | elif filter_version == 6: # num face <= 12k, cc<=100, token_length<40960 41 | filtered_df = df[ 42 | (df['face_num_process'] <= 12000) & 43 | (df['CC_num_valid'] <= 100) & 44 | (df['token_length'] <= 40960) 45 | ] 46 | return filtered_df 47 | elif filter_version == 7: # 100<=num face <= 4k, cc<=10, token_length<40960 48 | filtered_df = df[ 49 | (df['face_num_process'] <= 4000) & 50 | (df['face_num_process'] >= 100) & 51 | (df['CC_num_valid'] <= 10) & 52 | (df['token_length'] <= 20480) 53 | ] 54 | return filtered_df 55 | elif filter_version == 8: # num face <= 8k, cc<=20, token_length<20480, max_lv <=200 56 | filtered_df = df[ 57 | (df['face_num_process'] <= 8000) & 58 | (df['max_lv'] <= 200) & 59 | (df['CC_num_valid'] <= 20) & 60 | (df['token_length'] <= 20480) 61 | ] 62 | return filtered_df 63 | elif filter_version == 9: # num face <= 8k, cc<=20, token_length<10000, max_lv <=200 64 | filtered_df = df[ 65 | (df['face_num_process'] <= 8000) & 66 | (df['face_num_process'] >= 80) & 67 | (df['max_lv'] <= 200) & 68 | (df['CC_num_valid'] <= 30) & 69 | (df['token_length'] <= 10000) 70 | ] 71 | return filtered_df 72 | elif filter_version == 10: # cc<=100, token_length<40960, max_lv <=200 73 | filtered_df = df[ 74 | (df['max_lv'] <= 200) & 75 | (df['CC_num_valid'] <= 100) & 76 | (df['token_length'] <= 40960) 77 | ] 78 | return filtered_df 79 | elif filter_version == 11: # for trellis more 80 | filtered_df = df[ 81 | (df['max_lv'] <= 200) & 82 | (df['face_num_process'] >= 40) & 83 | (df['CC_num_valid'] <= 50) & 84 | (df['token_length'] <= 10000) 85 | ] 86 | return filtered_df 87 | elif filter_version == 12: # for longer, new archi 88 | filtered_df = df[ 89 | (df['max_lv'] <= 200) & 90 | (df['face_num_process'] >= 40) & 91 | (df['CC_num_valid'] <= 100) & 92 | (df['token_length'] <= 20000) 93 | ] 94 | return 
filtered_df
95 | 
96 | def merge_part_files(folder_path, version_list=None):
97 | 
98 |     merged_files = {}
99 | 
100 |     pattern = r'meta_all_(.+?)_res(\d+)_v(\d{2})_p(\d{4})_done_b(\d{3})\.xlsx'
101 | 
102 |     for filename in os.listdir(folder_path):
103 |         match = re.match(pattern, filename)
104 |         if match:
105 |             dataset_name = match.group(1)
106 |             reso = int(match.group(2))
107 |             version = int(match.group(3))
108 |             part_ind = int(match.group(4))
109 |             b_i = int(match.group(5))
110 |             if version_list:
111 |                 if version not in version_list:
112 |                     continue
113 | 
114 |             key = (dataset_name, reso, version, part_ind)
115 | 
116 | 
117 |             file_path = os.path.join(folder_path, filename)
118 |             print(f'find {file_path}')
119 |             df = pd.read_excel(file_path)
120 | 
121 |             if key not in merged_files:
122 |                 merged_files[key] = []
123 |             merged_files[key].append(df)
124 | 
125 |     for key, dfs in merged_files.items():
126 |         merged_df = pd.concat(dfs, ignore_index=True)
127 | 
128 |         dataset_name, reso, version, part_ind = key
129 |         merged_filename = f'meta_all_{dataset_name}_res{reso}_v{version:02}_p{part_ind:04}_merge.xlsx'
130 | 
131 |         merged_file_path = os.path.join(folder_path, merged_filename)
132 | 
133 |         merged_df = merged_df.sort_values(by='id').reset_index(drop=True)
134 |         merged_df.to_excel(merged_file_path, index=False)
135 | 
136 |         print(f'Merged {len(dfs)} files into {merged_filename}')
137 | 
138 | def merge_all_files(directory, version_list=None):
139 | 
140 |     pattern = re.compile(r'meta_all_(?P<dataset_name>.+?)_res(?P<reso>\d+)_v(?P<version>\d{2})_p(?P<part_ind>\d{4})_merge\.xlsx')
141 | 
142 | 
143 |     files_dict = {}
144 | 
145 | 
146 |     for filename in os.listdir(directory):
147 |         match = pattern.match(filename)
148 |         if match:
149 | 
150 |             dataset_name = match.group('dataset_name')
151 |             reso = int(match.group('reso'))
152 |             version = int(match.group('version'))
153 |             part_ind = int(match.group('part_ind'))
154 |             if version_list:
155 |                 if version not in version_list:
156 |                     continue
157 |             key = (dataset_name, reso, version)
158 |             file_path = os.path.join(directory, filename)
159 | 
160 | 
161 |             if key not in files_dict:
162 |                 files_dict[key] = []
163 |             files_dict[key].append(file_path)
164 | 
165 | 
166 |     for (dataset_name, reso, version), file_paths in files_dict.items():
167 |         # List to hold DataFrames
168 |         dataframes = []
169 | 
170 | 
171 |         for file_path in file_paths:
172 |             df = pd.read_excel(file_path)
173 |             dataframes.append(df)
174 | 
175 | 
176 |         merged_df = pd.concat(dataframes, ignore_index=True)
177 | 
178 | 
179 |         output_filename = f'meta_all_{dataset_name}_res{reso}_v{version:02}_mergeall.xlsx'
180 |         output_path = os.path.join(directory, output_filename)
181 | 
182 | 
183 |         merged_df = merged_df.sort_values(by='id').reset_index(drop=True)
184 |         merged_df.to_excel(output_path, index=False)
185 | 
186 |         print(f'Merged {len(file_paths)} files into {output_path}')
187 | 
188 | 
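Because the named groups in these patterns are easy to corrupt, here is a hedged sanity check (not part of the repo) that the `merge_all_files` pattern parses a filename of the form emitted by `merge_part_files`; the dataset name and resolution below are hypothetical:

```
# hedged sketch: verify the reconstructed named-group pattern round-trips a merge filename
import re

pattern = re.compile(r'meta_all_(?P<dataset_name>.+?)_res(?P<reso>\d+)_v(?P<version>\d{2})_p(?P<part_ind>\d{4})_merge\.xlsx')
m = pattern.match('meta_all_toys4k_res128_v04_p0000_merge.xlsx')  # hypothetical name
assert m is not None
assert m.group('dataset_name') == 'toys4k' and int(m.group('reso')) == 128
```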
189 | def filter_excel_files(in_directory, out_directory, filter_version, version_list=None):
190 | 
191 |     pattern = re.compile(r'meta_all_(?P<dataset_name>.+?)_res(?P<reso>\d+)_v(?P<version>\d{2})_mergeall\.xlsx')
192 | 
193 |     for filename in os.listdir(in_directory):
194 |         match = pattern.match(filename)
195 |         if match:
196 | 
197 |             dataset_name = match.group('dataset_name')
198 |             reso = int(match.group('reso'))
199 |             version = int(match.group('version'))
200 |             file_path = os.path.join(in_directory, filename)
201 |             if version_list:
202 |                 if version not in version_list:
203 |                     continue
204 | 
205 |             df = pd.read_excel(file_path)
206 | 
207 | 
208 |             filtered_df = get_filtered_df(df, filter_version)
209 | 
210 |             output_filename = f'meta_all_{dataset_name}_res{reso}_v{version:02}_mergeall_filter{filter_version:02}.xlsx'
211 |             output_path = os.path.join(out_directory, output_filename)
212 | 
213 |             filtered_df = filtered_df.sort_values(by='id').reset_index(drop=True)
214 |             filtered_df.to_excel(output_path, index=False)
215 | 
216 |             print(f'Filtered version {filter_version:02} ---> {filename} and saved to {output_path}')
217 |             print(f'filter/all {len(filtered_df)}/{len(df)}')
218 | 
219 | def filter_dataset(dataset_name, filtered_xlsx_save_dir, filter_version, version_list=[4], merge=True):
220 |     work_dir=ss_platform.get_base_dir_platform(dataset_name)
221 |     if merge:
222 |         print(f'Merging {dataset_name}')
223 |         merge_part_files(work_dir, version_list=version_list)
224 |         merge_all_files(work_dir, version_list=version_list)
225 |     print(f'Filtering {dataset_name}, filter version {filter_version}')
226 |     filter_excel_files(in_directory=work_dir, out_directory=filtered_xlsx_save_dir, filter_version=filter_version, version_list=version_list)
227 | 
228 | 
229 | if __name__ == "__main__":
230 |     # dataset_name='objaversev1'
231 |     # datasetnames=['3dcaricshop','3dfuture','abo','animal3d','buildingnet','gso', 'thingi10k','toys4k','shapenetv2','objaversev1']
232 |     # datasetnames=['3dfuture','thingi10k','toys4k','shapenetv2','gobjaversev1']
233 |     filtered_xlsx_save_dir='/public/home/group_gaosh/gaochao/main_workspace/MeshSilksong/datasets/cleaned'
234 |     datasetnames=['trellis-objxl-github']
235 |     for dataset_name in datasetnames:
236 |         filter_dataset(dataset_name=dataset_name, filtered_xlsx_save_dir=filtered_xlsx_save_dir, filter_version=0, version_list=[4], merge=False)
237 | 
--------------------------------------------------------------------------------
/silkutils/dataset_clean/process_dataset.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4 | 
5 | import pandas as pd
6 | from tqdm import tqdm
7 | import logging
8 | from dataset_clean.process_one import process_one
9 | from ss_platform import get_savedirs, base_dir, get_base_dir_platform
10 | 
11 | import concurrent.futures
12 | import time
13 | from typing import Dict, Any
14 | import traceback
15 | import re
16 | import random
17 | import argparse
18 | import func_timeout
19 | from func_timeout import func_set_timeout
20 | 
21 | 
22 | def extract_xlsx_param(filename):
23 | 
24 |     pattern = r'meta_all_(?P<type>[^_]+)_res(?P<resolution>\d+)_v(?P<version>\d{2})_p(?P<part_num>\d{4})\.xlsx'
25 | 
26 |     match = re.search(pattern, filename)
27 | 
28 |     if match:
29 | 
30 |         file_info = match.groupdict()
31 |         file_info['version'] = int(file_info['version'])
32 |         file_info['part_num'] = int(file_info['part_num'])
33 |         file_info['resolution'] = int(file_info['resolution'])
34 |         return file_info['type'], file_info['resolution'], file_info['version'], file_info['part_num']
35 |     else:
36 |         raise ValueError("File name does not match the expected pattern.")
37 | 
38 | 
39 | def setup_logging(log_file_path):
40 |     """
41 |     Configure logging.
42 |     :param log_file_path: path of the log file
43 |     """
44 |     logging.basicConfig(
45 |         filename=log_file_path,  # write logs to this file
46 |         level=logging.ERROR,  # only record errors
47 |         format='%(asctime)s - %(levelname)s - %(message)s'  # timestamp - level - message
48 |     )
49 | 
50 | 
51 | def init_process_one_params(dataset_name, item, resolution, version, part_num):
52 |     process_one_params={
53 |         'dataset_name': dataset_name,
54 |         'obj_id': item.get('id'),
55 |         'obj_path': item.get('obj_path'),
56 |         'obj_save_name' : item.get('obj_name'),
57 |         'save_dir_npy' :
os.path.join(f'p_{part_num:04}_v{version:02}', get_savedirs()['npy']), 58 | 'save_dir_meta' : os.path.join(f'p_{part_num:04}_v{version:02}', get_savedirs()['meta']), 59 | 'save_dir_debug' : os.path.join(f'p_{part_num:04}_v{version:02}', get_savedirs()['debug']), 60 | 'resolution': resolution, 61 | 'mode': version, 62 | 'face_type': item.get('face_type'), 63 | } 64 | return process_one_params 65 | 66 | def update_item(xlsx_line_new, item): 67 | for key in xlsx_line_new.keys(): 68 | item[key]=xlsx_line_new[key] 69 | return item 70 | 71 | 72 | def setup_logger(obj_id: int, obj_name: str, log_dir: str) -> logging.Logger: 73 | # 74 | logger = logging.getLogger(f'{obj_name}') 75 | logger.setLevel(logging.ERROR) 76 | 77 | # 78 | os.makedirs(log_dir, exist_ok=True) 79 | 80 | # 81 | log_file_path = os.path.join(log_dir, f"{obj_id:08}_{obj_name}.log") 82 | 83 | file_handler = logging.FileHandler(log_file_path) 84 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 85 | file_handler.setFormatter(formatter) 86 | 87 | logger.addHandler(file_handler) 88 | return logger 89 | 90 | @func_set_timeout(2) 91 | def process_one_tst(**kwargs): 92 | a=random.randint(0, 4) 93 | time.sleep(a) 94 | return {'d':3} 95 | 96 | def process_item(item: Dict[str, Any], log_dir: str, dataset_name: str, resolution: int, version: int, part_num: int) -> Dict[str, Any]: 97 | 98 | 99 | vert_num=item.get('vert_num') 100 | face_num=item.get('face_num') 101 | if vert_num is not None and face_num is not None: 102 | if face_num > 16000: 103 | # logger = setup_logger(obj_id, obj_name, log_dir) # 104 | # logger.error("An error occurred") 105 | # logger.error(traceback.format_exc()) 106 | item['bug_info']=f'too many face at first' 107 | item['done']=0 108 | 109 | return item 110 | 111 | try: 112 | # 113 | input_params=init_process_one_params(dataset_name=dataset_name, item=item, resolution=resolution, version=version, part_num=part_num) 114 | result = process_one(**input_params) 115 | updated_item=update_item(result, item) 116 | 117 | return updated_item 118 | except func_timeout.exceptions.FunctionTimedOut as ee: 119 | # logger = setup_logger(obj_id, obj_name, log_dir) # 120 | # logger.error("An error occurred: %s", str(ee)) 121 | # logger.error(traceback.format_exc()) 122 | item['bug_info']=f'Time out' 123 | item['done']=0 124 | return item 125 | except Exception as e: 126 | # 127 | # logger = setup_logger(obj_id, obj_name, log_dir) # 128 | # logger.error("An error occurred: %s", str(e)) 129 | # logger.error(traceback.format_exc()) 130 | item['bug_info']=f'{str(e)}' 131 | item['done']=0 132 | return item 133 | 134 | 135 | def process_dataset(xlsx_path_platform, max_workers=24, time_out=30, head=None, batch_ind=0, b_l=1000): 136 | basic_dir_platform=os.path.dirname(xlsx_path_platform) 137 | 138 | 139 | dataset_name, resolution, version, part_num=extract_xlsx_param(xlsx_path_platform) 140 | log_dir=os.path.join(basic_dir_platform, f'p_{part_num:04}_v{version:02}', get_savedirs()['error']) 141 | with pd.ExcelFile(xlsx_path_platform) as xls: 142 | df = pd.read_excel(xls) 143 | if head is not None: 144 | df = df.head(head) 145 | 146 | total_items=len(df) 147 | 148 | 149 | results = [] 150 | df_batches = [] 151 | for i in range(0, total_items, b_l): 152 | batch = df.iloc[i:i + b_l] 153 | df_batches.append(batch) 154 | 155 | print(f'Batch num: {len(df_batches)}, batch size: {len(batch)}') 156 | 157 | for b_i, batch in enumerate(df_batches): 158 | if b_i != batch_ind: 159 | continue 160 | 161 | 
done_xlsx_name=os.path.basename(xlsx_path_platform) 162 | done_xlsx_name=os.path.splitext(done_xlsx_name)[0] + f'_done_b{b_i:03}.xlsx' 163 | save_path_done_xlsx=os.path.join(basic_dir_platform, done_xlsx_name) 164 | if os.path.exists(save_path_done_xlsx): 165 | print(f'find {save_path_done_xlsx} exists, skip') 166 | continue 167 | 168 | results=[] 169 | with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor: 170 | future_to_item = {executor.submit(process_item, row.to_dict(), log_dir, dataset_name, resolution, version, part_num): row for index, row in batch.iterrows()} 171 | 172 | 173 | with tqdm(total=len(batch), desc=f"{dataset_name}-p_{part_num:04}-b{b_i}/{len(df_batches)}", unit="item") as pbar: 174 | for future in concurrent.futures.as_completed(future_to_item): 175 | item_dic = future_to_item[future].to_dict() 176 | try: 177 | result = future.result(timeout=time_out) 178 | results.append(result) 179 | # except concurrent.futures.TimeoutError: # 180 | # item_dic['bug_info'] = 'Timeout occurred' 181 | # item_dic['done'] = 0 182 | # results.append(item_dic) 183 | 184 | except Exception as e: 185 | item_dic['bug_info']=f'{str(e)}' 186 | item_dic['done']=0 187 | results.append(item_dic) 188 | 189 | # pbar.update(progress_queue.get()) 190 | pbar.update(1) 191 | 192 | 193 | results_df = pd.DataFrame(results) 194 | results_df = results_df.sort_values(by='id') 195 | 196 | with pd.ExcelWriter(save_path_done_xlsx) as writer: 197 | results_df.to_excel(writer, index=False) 198 | 199 | 200 | if __name__ == "__main__": 201 | 202 | parser = argparse.ArgumentParser(description='Process some data.') 203 | 204 | parser.add_argument('--dataset_name', type=str, required=True, 205 | help='The name of the dataset to process.') 206 | parser.add_argument('--part_ind', type=int, required=True, 207 | help='The index of the part to process.') 208 | parser.add_argument('--batch_ind', type=int, required=True, 209 | help='The index of the batch to process.') 210 | parser.add_argument('--resolution', type=int, required=True, 211 | help='The resolution to process.') 212 | parser.add_argument('--version', type=int, required=True, 213 | help='The version to process.') 214 | parser.add_argument('--head', type=int, required=True, 215 | help='The head to process.') 216 | parser.add_argument('--max_workers', type=int, default=64, 217 | help='The maximum number of workers to use (default: 64).') 218 | parser.add_argument('--b_l', type=int, default=1000, 219 | help='batch_length default 1000') 220 | 221 | args = parser.parse_args() 222 | dataset_name=args.dataset_name 223 | part_ind=args.part_ind 224 | batch_ind=args.batch_ind 225 | max_workers=args.max_workers 226 | reso=args.resolution 227 | version=args.version 228 | head=args.head 229 | if head==-1: 230 | head=None 231 | b_l=args.b_l 232 | 233 | # xlsx_base_dir=base_dir() 234 | 235 | # data_set_dir={ 236 | # # "3dfuture": '3DFuture/3DFuture', # H800 237 | # # 'toys4k': 'TOYS4K/TOYS4K', # H800 238 | # "3dfuture": '3DFuture', # local 239 | # 'toys4k': 'TOYS4K', # local 240 | # 'objaversev1': 'Objaverse', 241 | # 'abo': 'ABO', 242 | # 'thingi10k': 'Thingi10K/Thingi10K/Thingi10K', 243 | # 'shapenetv2': 'shapenet/shapenet/ShapeNetCore.v2', 244 | # 'animal3d': 'animal3d/animal3d', 245 | # '3dcaricshop': '3DCaricShop', 246 | # 'buildingnet': 'BuildingNet/BuildingNet', 247 | # 'gso': 'GSO/GSO' 248 | # } 249 | 250 | xlsx_name=f'meta_all_{dataset_name}_res{reso}_v{version:02}_p{part_ind:04}.xlsx' 251 | # xlsx_full_path=os.path.join(xlsx_base_dir, 
data_set_dir[dataset_name], xlsx_name) 252 | xlsx_full_path=os.path.join(get_base_dir_platform(dataset_name), xlsx_name) 253 | if not os.path.exists(xlsx_full_path): 254 | print(f'[WARNING] can not find {xlsx_full_path}, file part wrong') 255 | else: 256 | process_dataset(xlsx_path_platform=xlsx_full_path, max_workers=max_workers, time_out=30, head=head, batch_ind=batch_ind, b_l=b_l) 257 | -------------------------------------------------------------------------------- /miche/michelangelo/models/modules/transformer_blocks.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from typing import Optional 8 | 9 | from miche.michelangelo.models.modules.checkpoint import checkpoint 10 | 11 | # Initialize linear layers with normal distribution weights and zero biases 12 | def init_linear(l, stddev): 13 | nn.init.normal_(l.weight, std=stddev) 14 | if l.bias is not None: 15 | nn.init.constant_(l.bias, 0.0) 16 | 17 | # Multihead attention module 18 | class MultiheadAttention(nn.Module): 19 | def __init__( 20 | self, 21 | *, 22 | device: torch.device, 23 | dtype: torch.dtype, 24 | n_ctx: int, # Context size 25 | width: int, # Width of the input tensor 26 | heads: int, # Number of attention heads 27 | init_scale: float, # Initialization scale for weights 28 | qkv_bias: bool, # Whether to use bias in QKV layers 29 | flash: bool = False # Whether to use flash attention 30 | ): 31 | super().__init__() 32 | self.n_ctx = n_ctx 33 | self.width = width 34 | self.heads = heads 35 | self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias, device=device, dtype=dtype) 36 | self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) 37 | self.attention = QKVMultiheadAttention(device=device, dtype=dtype, heads=heads, n_ctx=n_ctx, flash=flash) 38 | init_linear(self.c_qkv, init_scale) 39 | init_linear(self.c_proj, init_scale) 40 | 41 | def forward(self, x): 42 | x = self.c_qkv(x) 43 | x = checkpoint(self.attention, (x,), (), True) 44 | x = self.c_proj(x) 45 | return x 46 | 47 | # QKV multihead attention module 48 | class QKVMultiheadAttention(nn.Module): 49 | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, n_ctx: int, flash: bool = False): 50 | super().__init__() 51 | self.device = device 52 | self.dtype = dtype 53 | self.heads = heads 54 | self.n_ctx = n_ctx 55 | self.flash = flash 56 | 57 | def forward(self, qkv): 58 | bs, n_ctx, width = qkv.shape 59 | attn_ch = width // self.heads // 3 60 | scale = 1 / math.sqrt(math.sqrt(attn_ch)) 61 | qkv = qkv.view(bs, n_ctx, self.heads, -1) 62 | q, k, v = torch.split(qkv, attn_ch, dim=-1) 63 | 64 | if self.flash: 65 | out = F.scaled_dot_product_attention(q, k, v) 66 | else: 67 | weight = torch.einsum( 68 | "bthc,bshc->bhts", q * scale, k * scale 69 | ) # More stable with f16 than dividing afterwards 70 | wdtype = weight.dtype 71 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype) 72 | out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 73 | 74 | return out 75 | 76 | # Residual attention block module 77 | class ResidualAttentionBlock(nn.Module): 78 | def __init__( 79 | self, 80 | *, 81 | device: torch.device, 82 | dtype: torch.dtype, 83 | use_checkpoint: bool = False, 84 | n_ctx: int, # Context size 85 | width: int, # Width of the input tensor 86 | heads: int, # Number of attention heads 87 | init_scale: float, # Initialization scale for weights 88 | 
qkv_bias: bool, # Whether to use bias in QKV layers 89 | flash: bool = False # Whether to use flash attention 90 | ): 91 | super().__init__() 92 | 93 | self.use_checkpoint = use_checkpoint 94 | 95 | self.attn = MultiheadAttention( 96 | device=device, 97 | dtype=dtype, 98 | n_ctx=n_ctx, 99 | width=width, 100 | heads=heads, 101 | init_scale=init_scale, 102 | qkv_bias=qkv_bias, 103 | flash=flash 104 | ) 105 | self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) 106 | self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) 107 | self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype) 108 | 109 | def _forward(self, x: torch.Tensor): 110 | x = x + self.attn(self.ln_1(x)) 111 | x = x + self.mlp(self.ln_2(x)) 112 | return x 113 | 114 | def forward(self, x: torch.Tensor): 115 | return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) 116 | 117 | # Multihead cross attention module 118 | class MultiheadCrossAttention(nn.Module): 119 | def __init__( 120 | self, 121 | *, 122 | device: torch.device, 123 | dtype: torch.dtype, 124 | n_data: Optional[int] = None, 125 | data_width: Optional[int] = None, 126 | width: int, # Width of the input tensor 127 | heads: int, # Number of attention heads 128 | init_scale: float, # Initialization scale for weights 129 | qkv_bias: bool, # Whether to use bias in QKV layers 130 | flash: bool = False # Whether to use flash attention 131 | ): 132 | super().__init__() 133 | self.n_data = n_data 134 | self.width = width 135 | self.heads = heads 136 | self.data_width = width if data_width is None else data_width 137 | self.c_q = nn.Linear(width, width, bias=qkv_bias, device=device, dtype=dtype) 138 | self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias, device=device, dtype=dtype) 139 | self.c_proj = nn.Linear(width, width, device=device, dtype=dtype) 140 | self.attention = QKVMultiheadCrossAttention( 141 | device=device, dtype=dtype, heads=heads, n_data=n_data, flash=flash 142 | ) 143 | init_linear(self.c_q, init_scale) 144 | init_linear(self.c_kv, init_scale) 145 | init_linear(self.c_proj, init_scale) 146 | 147 | def forward(self, x, data): 148 | x = self.c_q(x) 149 | data = self.c_kv(data) 150 | x = checkpoint(self.attention, (x, data), (), True) 151 | x = self.c_proj(x) 152 | return x 153 | 154 | # QKV multihead cross attention module 155 | class QKVMultiheadCrossAttention(nn.Module): 156 | def __init__(self, *, device: torch.device, dtype: torch.dtype, heads: int, 157 | flash: bool = False, n_data: Optional[int] = None): 158 | 159 | super().__init__() 160 | self.device = device 161 | self.dtype = dtype 162 | self.heads = heads 163 | self.n_data = n_data 164 | self.flash = flash 165 | 166 | def forward(self, q, kv): 167 | _, n_ctx, _ = q.shape 168 | bs, n_data, width = kv.shape 169 | attn_ch = width // self.heads // 2 170 | scale = 1 / math.sqrt(math.sqrt(attn_ch)) 171 | q = q.view(bs, n_ctx, self.heads, -1) 172 | kv = kv.view(bs, n_data, self.heads, -1) 173 | k, v = torch.split(kv, attn_ch, dim=-1) 174 | 175 | if self.flash: 176 | out = F.scaled_dot_product_attention(q, k, v) 177 | else: 178 | weight = torch.einsum( 179 | "bthc,bshc->bhts", q * scale, k * scale 180 | ) # More stable with f16 than dividing afterwards 181 | wdtype = weight.dtype 182 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype) 183 | out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 184 | 185 | return out 186 | 187 | # Residual cross attention block module 188 | class 
ResidualCrossAttentionBlock(nn.Module): 189 | def __init__( 190 | self, 191 | *, 192 | device: Optional[torch.device], 193 | dtype: Optional[torch.dtype], 194 | n_data: Optional[int] = None, 195 | data_width: Optional[int] = None, 196 | width: int, # Width of the input tensor 197 | heads: int, # Number of attention heads 198 | init_scale: float, # Initialization scale for weights 199 | qkv_bias: bool, # Whether to use bias in QKV layers 200 | flash: bool = False # Whether to use flash attention 201 | ): 202 | super().__init__() 203 | 204 | if data_width is None: 205 | data_width = width 206 | 207 | self.attn = MultiheadCrossAttention( 208 | device=device, 209 | dtype=dtype, 210 | n_data=n_data, 211 | width=width, 212 | heads=heads, 213 | data_width=data_width, 214 | init_scale=init_scale, 215 | qkv_bias=qkv_bias, 216 | flash=flash, 217 | ) 218 | self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype) 219 | self.ln_2 = nn.LayerNorm(data_width, device=device, dtype=dtype) 220 | self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale) 221 | self.ln_3 = nn.LayerNorm(width, device=device, dtype=dtype) 222 | 223 | def forward(self, x: torch.Tensor, data: torch.Tensor): 224 | x = x + self.attn(self.ln_1(x), self.ln_2(data)) 225 | x = x + self.mlp(self.ln_3(x)) 226 | return x 227 | 228 | # MLP Module 229 | class MLP(nn.Module): 230 | def __init__(self, *, 231 | device: Optional[torch.device], 232 | dtype: Optional[torch.dtype], 233 | width: int, 234 | init_scale: float): 235 | super().__init__() 236 | self.width = width 237 | self.c_fc = nn.Linear(width, width * 4, device=device, dtype=dtype) 238 | self.c_proj = nn.Linear(width * 4, width, device=device, dtype=dtype) 239 | self.gelu = nn.GELU() 240 | init_linear(self.c_fc, init_scale) 241 | init_linear(self.c_proj, init_scale) 242 | 243 | def forward(self, x): 244 | return self.c_proj(self.gelu(self.c_fc(x))) 245 | 246 | # Transformer Module 247 | class Transformer(nn.Module): 248 | def __init__( 249 | self, 250 | *, 251 | device: Optional[torch.device], 252 | dtype: Optional[torch.dtype], 253 | layers: int, 254 | use_checkpoint: bool = False, 255 | n_ctx: int, # Context size 256 | width: int, # Width of the input tensor 257 | heads: int, # Number of attention heads 258 | init_scale: float, # Initialization scale for weights 259 | qkv_bias: bool, # Whether to use bias in QKV layers 260 | flash: bool = False # Whether to use flash attention 261 | ): 262 | super().__init__() 263 | self.n_ctx = n_ctx 264 | self.width = width 265 | self.layers = layers 266 | self.resblocks = nn.ModuleList( 267 | [ 268 | ResidualAttentionBlock( 269 | device=device, 270 | dtype=dtype, 271 | n_ctx=n_ctx, 272 | width=width, 273 | heads=heads, 274 | init_scale=init_scale, 275 | qkv_bias=qkv_bias, 276 | flash=flash, 277 | use_checkpoint=use_checkpoint 278 | ) 279 | for _ in range(layers) 280 | ] 281 | ) 282 | 283 | def forward(self, x: torch.Tensor): 284 | for block in self.resblocks: 285 | x = block(x) 286 | return x 287 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | 2 | import ipdb 3 | import os 4 | import tyro 5 | import math 6 | import time 7 | import shutil 8 | from functools import partial 9 | import traceback 10 | import torch 11 | from accelerate import Accelerator, DistributedDataParallelKwargs 12 | # from accelerate.utils import DummyOptim, DummyScheduler 13 | from safetensors.torch import load_file 14 | from 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------

import ipdb
import os
import tyro
import math
import time
import shutil
from functools import partial
import traceback
import torch
from accelerate import Accelerator, DistributedDataParallelKwargs
# from accelerate.utils import DummyOptim, DummyScheduler
from safetensors.torch import load_file
from model.model import SSMeshTransformer
from config.options import AllConfigs
from silkutils.silksong_tokenization import get_tokenizer_silksong
from silkutils.meshdata.mesh_io import init_logger
import kiui
from model.data_provider import SSDataset, DebugOneDataset, collate_fn, ProgressivelyBalancedSampler

# torch.autograd.set_detect_anomaly(True)


def main():
    opt = tyro.cli(AllConfigs)

    if opt.resume:
        print(f'resuming {opt.resume}')
    else:
        print('not resuming')

    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        mixed_precision=opt.mixed_precision,
        gradient_accumulation_steps=opt.gradient_accumulation_steps,
        kwargs_handlers=[ddp_kwargs],
    )

    os.makedirs(opt.workspace, exist_ok=True)
    logfile = os.path.join(opt.workspace, 'log.txt')
    logger = init_logger(logfile)

    # print options
    accelerator.print(opt)

    # tokenizer
    tokenizer, vocab_size = get_tokenizer_silksong(resolution=opt.discrete_bins, ss_mode=opt.ss_mode)

    print(f'---- engine word table size: {vocab_size} ----')

    # model
    model = SSMeshTransformer(
        dim=opt.model.dim,
        attn_depth=opt.model.depth,
        attn_dim_head=opt.model.attn_dim_head,
        attn_heads=opt.model.attn_heads,
        max_seq_len=opt.max_seq_length,
        dropout=opt.model.dropout,
        mode=opt.mode,
        num_discrete_coors=opt.meto.discrete_bins,
        block_size=opt.meto.block_size,
        offset_size=opt.meto.offset_size,
        conditioned_on_pc=opt.model.conditioned_on_pc,
        encoder_name=opt.model.encoder_name,
        encoder_freeze=opt.model.encoder_freeze,
    )

    # resume
    if opt.resume is not None:
        print(f'resuming {opt.resume}')
        if opt.resume.endswith('safetensors'):
            ckpt = load_file(opt.resume, device='cpu')
        else:
            ckpt = torch.load(opt.resume, map_location='cpu')

        # tolerant load: copy only params whose name and shape match; since
        # state_dict() tensors share storage with the module's parameters,
        # copy_() updates the model in place
        state_dict = model.state_dict()
        for k, v in ckpt.items():
            if k in state_dict:
                if state_dict[k].shape == v.shape:
                    state_dict[k].copy_(v)
                else:
                    logger.warning(f'mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.')
                    print(f'[WARNING] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.')
            else:
                logger.warning(f'unexpected param {k}: {v.shape}')
                print(f'[WARNING] unexpected param {k}: {v.shape}')

    # count params
    num_p = sum(p.numel() for p in model.parameters() if p.requires_grad)
    num_decoder_p = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad)
    total_p = sum(p.numel() for p in model.parameters())
    logger.info(f'trainable param num: {num_p/1024/1024:.6f} M, GPT param num: {num_decoder_p/1024/1024:.6f} M, total param num: {total_p/1024/1024:.6f} M')
    # data
    if opt.data.dataset == 'ss':
        train_dataset = SSDataset(opt, training=True, tokenizer=tokenizer)
        test_dataset = SSDataset(opt, training=False, tokenizer=tokenizer)
        logger.info(f'train dataset size: {len(train_dataset)}')
        logger.info(f'test dataset size: {len(test_dataset)}')
    elif opt.data.dataset == 'debug_one':
        train_dataset = DebugOneDataset(opt, training=True, tokenizer=tokenizer)
        test_dataset = DebugOneDataset(opt, training=False, tokenizer=tokenizer)
        logger.info(f'train dataset size: {len(train_dataset)}')
        logger.info(f'test dataset size: {len(test_dataset)}')
    else:
        raise NotImplementedError(f'dataset {opt.data.dataset} is not implemented')

    if opt.data.dataset == 'debug_one':
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=1,
            shuffle=True,
            num_workers=1,
            pin_memory=True,
            drop_last=True,
            collate_fn=partial(collate_fn, opt=opt),
        )
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=1,
            pin_memory=True,
            drop_last=False,
            collate_fn=partial(collate_fn, opt=opt),
        )
    else:
        if opt.data.resample:
            sampler = ProgressivelyBalancedSampler(
                opt,
                train_dataset,
                face_delta=opt.data.face_delta,
                initial_beta=opt.data.i_beta,
                final_beta=opt.data.e_beta,
                epochs=opt.train.num_epochs
            )
            train_dataloader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=opt.data.batch_size,
                sampler=sampler,
                num_workers=opt.data.num_workers,
                pin_memory=True,
                drop_last=True,
                collate_fn=partial(collate_fn, opt=opt),
            )
        else:
            train_dataloader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=opt.data.batch_size,
                shuffle=True,
                num_workers=opt.data.num_workers,
                pin_memory=True,
                drop_last=True,
                collate_fn=partial(collate_fn, opt=opt),
            )
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=opt.data.batch_size,
            shuffle=False,
            num_workers=opt.data.num_workers,
            pin_memory=True,
            drop_last=False,
            collate_fn=partial(collate_fn, opt=opt),
        )

    # optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=opt.lr, weight_decay=0.01, betas=(0.9, 0.99))

    # scheduler: linear warmup for the first warmup_ratio of training, then
    # cosine decay down to min_ratio of the peak learning rate
    total_steps = opt.train.num_epochs * len(train_dataloader) // opt.gradient_accumulation_steps

    def _lr_lambda(current_step, warmup_ratio=opt.warmup_ratio, num_cycles=0.5, min_ratio=0.5):
        progress = current_step / max(1, total_steps)
        if warmup_ratio > 0 and progress < warmup_ratio:
            return progress / warmup_ratio
        progress = (progress - warmup_ratio) / (1 - warmup_ratio)
        return max(min_ratio, min_ratio + (1 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_lr_lambda)
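
    # For reference, the shape of this schedule (hypothetical numbers, assuming
    # total_steps=1000 and warmup_ratio=0.1; values computed from the formula above):
    #   _lr_lambda(0)    -> 0.0   (warmup start)
    #   _lr_lambda(100)  -> 1.0   (warmup end, peak lr)
    #   _lr_lambda(1000) -> 0.5   (cosine floor = min_ratio)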
    if opt.resume is not None and not opt.ft:
        # when resuming (not fine-tuning), also restore optimizer/scheduler state if present
        ckpt_stat_dir = os.path.dirname(opt.resume)
        ckpt_stat_path = os.path.join(ckpt_stat_dir, 'model_state.pth')
        if os.path.exists(ckpt_stat_path):
            logger.info('state path exists, loading optimizer and scheduler')
            checkpoint_stat = torch.load(ckpt_stat_path, map_location='cpu')
            optimizer.load_state_dict(checkpoint_stat['optimizer'])
            scheduler.load_state_dict(checkpoint_stat['scheduler'])
            opt.train.resume_epoch = checkpoint_stat['epoch'] + 1
            logger.info(f'resume epoch: {opt.train.resume_epoch}')
        else:
            logger.warning(f'no optimizer/scheduler state found, check your resume epoch {opt.train.resume_epoch}')

    # accelerate
    model, optimizer, train_dataloader, test_dataloader, scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, test_dataloader, scheduler
    )

    # wandb
    if opt.use_wandb and accelerator.is_main_process:
        import wandb  # set WANDB_API_KEY in env
        wandb.init(entity="your_entity", project='MeshSilksong', name=opt.workspace.replace('workspace_', ''), config=opt)

    # loop
    old_save_dirs = []
    best_loss = 1e9
    for epoch in range(opt.train.resume_epoch, opt.train.num_epochs):
        save_dir = os.path.join(opt.workspace, f'ep{epoch:04d}')  # created on save below
        if opt.data.resample:
            sampler.update_epoch(epoch)
            beta_num = sampler.get_distribution_info()['beta']
            logger.info(f'beta {beta_num} for epoch {epoch}')
        # train
        if not opt.debug_eval:
            model.train()
            # if opt.cond_mode == 'point_miche' and 'multi' in opt.workspace:
            #     model._set_static_graph()
            total_loss = 0
            t_start = time.time()
            for i, data in enumerate(train_dataloader):
                with accelerator.accumulate(model):

                    optimizer.zero_grad()
                    codes = data['tokens']
                    pc = data['conds']

                    loss = model(codes=codes, pc=pc)

                    accelerator.backward(loss)

                    # gradient clipping
                    if accelerator.sync_gradients:
                        accelerator.clip_grad_norm_(model.parameters(), opt.gradient_clip)

                    optimizer.step()
                    scheduler.step()

                    total_loss += loss.detach()

                if accelerator.is_main_process:
                    # logging
                    if i % 10 == 0:
                        mem_free, mem_total = torch.cuda.mem_get_info()
                        log = f"{epoch:03d}:{i}/{len(train_dataloader)} mem: {(mem_total-mem_free)/1024**3:.2f}/{mem_total/1024**3:.2f}G lr: {scheduler.get_last_lr()[0]:.7f} loss: {loss.item():.6f}"
                        logger.info(log)

            total_loss = accelerator.gather_for_metrics(total_loss).mean().item()
            torch.cuda.synchronize()
            t_end = time.time()
            if accelerator.is_main_process:
                total_loss /= len(train_dataloader)
                logger.info(f"Train epoch: {epoch} loss: {total_loss:.6f} time: {(t_end - t_start)/60:.2f}min")

                # wandb
                if opt.use_wandb:
                    wandb.log({'train_loss': total_loss, 'epoch': epoch})
                    wandb.log({'lr': scheduler.get_last_lr()[0], 'epoch': epoch})
                if opt.use_wandb and opt.data.resample:
                    wandb.log({'sampler_beta': sampler.get_distribution_info()['beta'], 'epoch': epoch})

            # checkpoint
            if epoch % opt.save_epoch == 0 or epoch == opt.train.num_epochs - 1:
                os.makedirs(save_dir, exist_ok=True)
                accelerator.wait_for_everyone()
                print(f'epoch {epoch} done, saving')
                accelerator.save_model(model, save_dir)
                model_state = {
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'epoch': epoch,
                }
                if accelerator.is_main_process:
                    # save optimizer/scheduler state alongside the weights
                    state_path = os.path.join(save_dir, 'model_state.pth')
                    torch.save(model_state, state_path)
                    # symlink latest checkpoint on POSIX systems
                    if os.name == 'posix':
                        os.system(f'ln -sf {os.path.join(f"ep{epoch:04d}", "model.safetensors")} {os.path.join(opt.workspace, "model.safetensors")}')
                    # copy best checkpoint
                    if total_loss < best_loss:
                        best_loss = total_loss
                        shutil.copy(os.path.join(save_dir, 'model.safetensors'), os.path.join(opt.workspace, 'best.safetensors'))
                    old_save_dirs.append(save_dir)
                    if len(old_save_dirs) > 3 and opt.data.dataset != 'debug_one':  # keep at most 3 checkpoints
                        shutil.rmtree(old_save_dirs.pop(0))
        else:
            if accelerator.is_main_process:
                logger.info(f"epoch: {epoch} skip training for debug !!!")

        # eval
        print(f'evaluating, eval_mode {opt.eval_mode}')
        if opt.eval_mode == 'loss' and epoch % opt.save_epoch == 0:
            model.eval()
            with torch.no_grad():
                total_loss = 0
                for i, data in enumerate(test_dataloader):
                    codes = data['tokens']
                    pc = data['conds']

                    loss = model(codes=codes, pc=pc)

                    total_loss += loss.detach()

                total_loss = accelerator.gather_for_metrics(total_loss).mean()
                if accelerator.is_main_process:
                    total_loss /= len(test_dataloader)
                    logger.info(f"Eval epoch: {epoch} loss: {total_loss:.6f}")
                    if opt.use_wandb:
                        wandb.log({'eval_loss': total_loss, 'epoch': epoch})

        # else:
        #     if accelerator.is_main_process:
        #         logger.info(f"Eval epoch: {epoch} skip evaluation.")


if __name__ == "__main__":
    main()
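

# Typical launch command (an assumed example -- the project's real entry points
# are scripts/train_silksong_scratch_gpu16.sh, scripts/train_silksong_ft_gpu16.sh
# and the configs under acc_configs/; only --config_file is an actual accelerate
# flag, the trailing options come from config/options.py via tyro):
#   accelerate launch --config_file acc_configs/gpu8.yaml train.py <options>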
--------------------------------------------------------------------------------
/miche/michelangelo/models/tsal/asl_pl_module.py:
--------------------------------------------------------------------------------

# -*- coding: utf-8 -*-

from typing import List, Tuple, Dict, Optional, Union
from functools import partial

from omegaconf import DictConfig

import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import lr_scheduler
import trimesh

from miche.michelangelo.utils import instantiate_from_config
from miche.michelangelo.models.tsal.inference_utils import extract_geometry
from .tsal_base import (
    AlignedShapeAsLatentModule,
    ShapeAsLatentModule,
    Latent2MeshOutput,
    AlignedMeshOutput
)


class AlignedShapeAsLatentPLModule(nn.Module):
    def __init__(self, *,
                 shape_module_cfg,
                 aligned_module_cfg,
                 loss_cfg,
                 optimizer_cfg: Optional[DictConfig] = None,
                 ckpt_path: Optional[str] = None,
                 ignore_keys: Union[Tuple[str], List[str]] = ()):

        super().__init__()

        shape_model: ShapeAsLatentModule = instantiate_from_config(
            shape_module_cfg, device=None, dtype=None
        )
        self.model: AlignedShapeAsLatentModule = instantiate_from_config(
            aligned_module_cfg, shape_model=shape_model
        )

        self.loss = instantiate_from_config(loss_cfg)

        self.optimizer_cfg = optimizer_cfg

        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

    def set_shape_model_only(self):
        self.model.set_shape_model_only()

    @property
    def latent_shape(self):
        return self.model.shape_model.latent_shape

    @property
    def zero_rank(self):
        # a trainer may not be attached when this module is used outside
        # PyTorch Lightning; default to rank-zero behaviour in that case
        if getattr(self, '_trainer', None):
            return self.trainer.local_rank == 0
        return True
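
    # This module is normally instantiated from a yaml config (e.g.
    # miche/shapevae-256.yaml) via instantiate_from_config. A hypothetical
    # sketch -- the actual key layout depends on that yaml:
    #   from omegaconf import OmegaConf
    #   cfg = OmegaConf.load('miche/shapevae-256.yaml')
    #   module = instantiate_from_config(cfg.model)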
    def init_from_ckpt(self, path, ignore_keys=()):
        state_dict = torch.load(path, map_location="cpu")["state_dict"]

        # drop any keys the caller asked to ignore
        keys = list(state_dict.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print("Deleting key {} from state_dict.".format(k))
                    del state_dict[k]

        missing, unexpected = self.load_state_dict(state_dict, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    def configure_optimizers(self) -> Tuple[List, List]:
        # retained from the original PyTorch Lightning module; expects
        # self.learning_rate and self.trainer to be set by the caller
        lr = self.learning_rate

        trainable_parameters = list(self.model.parameters())

        if self.optimizer_cfg is None:
            optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
            schedulers = []
        else:
            optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
            scheduler_func = instantiate_from_config(
                self.optimizer_cfg.scheduler,
                max_decay_steps=self.trainer.max_steps,
                lr_max=lr
            )
            scheduler = {
                "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
                "interval": "step",
                "frequency": 1
            }
            optimizers = [optimizer]
            schedulers = [scheduler]

        return optimizers, schedulers

    def forward(self,
                surface: torch.FloatTensor,
                image: torch.FloatTensor,
                text: torch.FloatTensor,
                volume_queries: torch.FloatTensor):
        """Encode (surface, image, text), then query the decoded latents at volume_queries.

        Returns:
            embed_outputs: aligned shape / image / text embeddings
            logits: geometry logits at volume_queries
            posterior: the KL posterior of the shape latents
        """

        embed_outputs, shape_z = self.model(surface, image, text)

        shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z)
        latents = self.model.shape_model.decode(shape_zq)
        logits = self.model.shape_model.query_geometry(volume_queries, latents)

        return embed_outputs, logits, posterior

    def encode(self, surface: torch.FloatTensor, sample_posterior=True):

        pc = surface[..., 0:3]
        feats = surface[..., 3:6]

        shape_embed, shape_zq, posterior = self.model.shape_model.encode(
            pc=pc, feats=feats, sample_posterior=sample_posterior
        )

        return shape_zq

    def encode_latents(self, surface: torch.FloatTensor):

        pc = surface[..., 0:3]
        feats = surface[..., 3:6]

        shape_embed, shape_latents = self.model.shape_model.encode_latents(
            pc=pc, feats=feats
        )
        shape_embed = shape_embed.unsqueeze(1)
        assert shape_embed.shape[1] == 1 and shape_latents.shape[1] == 256
        cat_latents = torch.cat([shape_embed, shape_latents], dim=1)

        return cat_latents

    def recon(self, surface):
        cat_latents = self.encode_latents(surface)
        shape_latents = cat_latents[:, 1:]
        shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_latents)

        # decoding
        latents = self.model.shape_model.decode(shape_zq)
        geometric_func = partial(self.model.shape_model.query_geometry, latents=latents)

        # reconstruction: march the implicit field over an octree grid
        mesh_v_f, has_surface = extract_geometry(
            geometric_func=geometric_func,
            device=surface.device,
            batch_size=surface.shape[0],
            bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
            octree_depth=7,
            num_chunks=10000,
        )
        recon_mesh = trimesh.Trimesh(mesh_v_f[0][0], mesh_v_f[0][1])

        return recon_mesh
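
    # Hypothetical usage of recon() (sizes are illustrative; surface is expected
    # to carry xyz in [..., 0:3] and per-point features in [..., 3:6]):
    #   surface = torch.randn(1, 4096, 6, device='cuda')
    #   mesh = module.recon(surface)   # returns a trimesh.Trimesh
    #   mesh.export('recon.obj')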
    def to_shape_latents(self, latents):

        shape_zq, posterior = self.model.shape_model.encode_kl_embed(latents, sample_posterior=False)
        return self.model.shape_model.decode(shape_zq)

    def decode(self,
               z_q,
               bounds: Union[Tuple[float], List[float], float] = 1.1,
               octree_depth: int = 7,
               num_chunks: int = 10000) -> List[Latent2MeshOutput]:

        latents = self.model.shape_model.decode(z_q)  # latents: [bs, num_latents, dim]
        outputs = self.latent2mesh(latents, bounds=bounds, octree_depth=octree_depth, num_chunks=num_chunks)

        return outputs
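
    # Hypothetical usage of decode() (latent sizes are placeholders following
    # the [bs, num_latents, dim] convention noted above):
    #   outs = module.decode(z_q, bounds=1.1, octree_depth=7)
    #   v, f = outs[0].mesh_v, outs[0].mesh_f   # an entry is None if no surface was found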
    def training_step(self, batch: Dict[str, torch.FloatTensor],
                      batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
        """Lightning-style training step (retained from the original PL module; relies on self.log_dict).

        Args:
            batch (dict): the batch sample, containing:
                - surface (torch.FloatTensor): [bs, n_surface, (3 + input_dim)]
                - image (torch.FloatTensor): [bs, 3, 224, 224]
                - text (torch.FloatTensor): [bs, num_templates, 77]
                - geo_points (torch.FloatTensor): [bs, n_pts, (3 + 1)]
            batch_idx (int): batch index
            optimizer_idx (int): optimizer index

        Returns:
            loss (torch.FloatTensor)
        """

        surface = batch["surface"]
        image = batch["image"]
        text = batch["text"]

        volume_queries = batch["geo_points"][..., 0:3]
        shape_labels = batch["geo_points"][..., -1]

        embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries)

        aeloss, log_dict_ae = self.loss(
            **embed_outputs,
            posteriors=posteriors,
            shape_logits=shape_logits,
            shape_labels=shape_labels,
            split="train"
        )

        self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0],
                      sync_dist=False, rank_zero_only=True)

        return aeloss

    def validation_step(self, batch: Dict[str, torch.FloatTensor], batch_idx: int) -> torch.FloatTensor:

        surface = batch["surface"]
        image = batch["image"]
        text = batch["text"]

        volume_queries = batch["geo_points"][..., 0:3]
        shape_labels = batch["geo_points"][..., -1]

        embed_outputs, shape_logits, posteriors = self(surface, image, text, volume_queries)

        aeloss, log_dict_ae = self.loss(
            **embed_outputs,
            posteriors=posteriors,
            shape_logits=shape_logits,
            shape_labels=shape_labels,
            split="val"
        )
        self.log_dict(log_dict_ae, prog_bar=True, logger=True, batch_size=shape_logits.shape[0],
                      sync_dist=False, rank_zero_only=True)

        return aeloss

    def visual_alignment(self,
                         surface: torch.FloatTensor,
                         image: torch.FloatTensor,
                         text: torch.FloatTensor,
                         description: Optional[List[str]] = None,
                         bounds: Union[Tuple[float], List[float]] = (-1.25, -1.25, -1.25, 1.25, 1.25, 1.25),
                         octree_depth: int = 7,
                         num_chunks: int = 10000) -> List[AlignedMeshOutput]:
        """Reconstruct meshes and score shape-text / shape-image alignment.

        Args:
            surface: [bs, n_surface, (3 + input_dim)] sampled surface points
            image: [bs, 3, 224, 224] conditioning images
            text: [bs, num_templates, 77] tokenized text
            description: optional raw text descriptions, one per sample
            bounds: axis-aligned bounds of the sampling volume
            octree_depth: resolution of the extraction octree
            num_chunks: number of query points processed per chunk

        Returns:
            mesh_outputs (List[AlignedMeshOutput]): the mesh outputs list.
        """

        outputs = []

        device = surface.device
        bs = surface.shape[0]

        embed_outputs, shape_z = self.model(surface, image, text)

        # calculate the similarity
        image_embed = embed_outputs["image_embed"]
        text_embed = embed_outputs["text_embed"]
        shape_embed = embed_outputs["shape_embed"]

        # normalized features
        shape_embed = F.normalize(shape_embed, dim=-1, p=2)
        text_embed = F.normalize(text_embed, dim=-1, p=2)
        image_embed = F.normalize(image_embed, dim=-1, p=2)

        # B x B similarity matrices (100.0 is the CLIP-style logit scale);
        # the diagonal entries are the matched shape/text and shape/image pairs
        shape_text_similarity = (100.0 * shape_embed @ text_embed.T).softmax(dim=-1)
        shape_image_similarity = (100.0 * shape_embed @ image_embed.T).softmax(dim=-1)

        # 1. encode and decode the shape latents
        shape_zq, posterior = self.model.shape_model.encode_kl_embed(shape_z)
        latents = self.model.shape_model.decode(shape_zq)
        geometric_func = partial(self.model.shape_model.query_geometry, latents=latents)

        # 2. decode geometry
        mesh_v_f, has_surface = extract_geometry(
            geometric_func=geometric_func,
            device=device,
            batch_size=bs,
            bounds=bounds,
            octree_depth=octree_depth,
            num_chunks=num_chunks,
            disable=not self.zero_rank
        )

        # 3. collect mesh outputs with their similarity scores
        for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
            if not is_surface:
                outputs.append(None)
                continue

            out = AlignedMeshOutput()
            out.mesh_v = mesh_v
            out.mesh_f = mesh_f
            out.surface = surface[i].cpu().numpy()
            out.image = image[i].cpu().numpy()
            if description is not None:
                out.text = description[i]
            out.shape_text_similarity = shape_text_similarity[i, i]
            out.shape_image_similarity = shape_image_similarity[i, i]

            outputs.append(out)

        return outputs

    def latent2mesh(self,
                    latents: torch.FloatTensor,
                    bounds: Union[Tuple[float], List[float], float] = 1.1,
                    octree_depth: int = 7,
                    num_chunks: int = 10000) -> List[Latent2MeshOutput]:
        """Extract meshes from decoded latents.

        Args:
            latents: [bs, num_latents, dim]
            bounds: axis-aligned bounds of the sampling volume
            octree_depth: resolution of the extraction octree
            num_chunks: number of query points processed per chunk

        Returns:
            mesh_outputs (List[Latent2MeshOutput]): the mesh outputs list.
        """

        outputs = []

        geometric_func = partial(self.model.shape_model.query_geometry, latents=latents)

        # 1. decode geometry
        device = latents.device
        mesh_v_f, has_surface = extract_geometry(
            geometric_func=geometric_func,
            device=device,
            batch_size=len(latents),
            bounds=bounds,
            octree_depth=octree_depth,
            num_chunks=num_chunks,
            disable=not self.zero_rank
        )

        # 2. collect mesh outputs
        for i, ((mesh_v, mesh_f), is_surface) in enumerate(zip(mesh_v_f, has_surface)):
            if not is_surface:
                outputs.append(None)
                continue

            out = Latent2MeshOutput()
            out.mesh_v = mesh_v
            out.mesh_f = mesh_f

            outputs.append(out)

        return outputs
--------------------------------------------------------------------------------