├── .cursorrules ├── .editorconfig ├── .github └── workflows │ └── ci.yml ├── .gitignore ├── .idea ├── .gitignore ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── neural-mesh-simplification.iml └── vcs.xml ├── .python-version ├── LICENSE ├── README.md ├── configs ├── default.yaml └── local.yaml ├── examples └── example.py ├── pyproject.toml ├── pytest.ini ├── requirements.txt ├── scripts ├── download_test_meshes.py ├── evaluate.py ├── infer.py ├── preprocess_data.py └── train.py ├── src └── neural_mesh_simplification │ ├── __init__.py │ ├── api │ ├── __init__.py │ └── neural_mesh_simplifier.py │ ├── data │ ├── __init__.py │ └── dataset.py │ ├── losses │ ├── __init__.py │ ├── chamfer_distance_loss.py │ ├── combined_loss.py │ ├── edge_crossing_loss.py │ ├── overlapping_triangles_loss.py │ ├── surface_distance_loss.py │ └── triangle_collision_loss.py │ ├── metrics │ ├── __init__.py │ ├── chamfer_distance.py │ ├── edge_preservation.py │ ├── hausdorff_distance.py │ └── normal_consistency.py │ ├── models │ ├── __init__.py │ ├── edge_predictor.py │ ├── face_classifier.py │ ├── layers │ │ ├── __init__.py │ │ ├── devconv.py │ │ └── triconv.py │ ├── neural_mesh_simplification.py │ └── point_sampler.py │ ├── trainer │ ├── __init__.py │ ├── resource_monitor.py │ └── trainer.py │ └── utils │ ├── __init__.py │ └── mesh_operations.py ├── tests ├── __init__.py ├── conftest.py ├── losses │ ├── test_edge_crossings_loss.py │ ├── test_overlapping_triangles_loss.py │ ├── test_proba_chamfer_distance.py │ ├── test_proba_surface_distance.py │ └── test_triangle_collision_loss.py ├── mesh_data │ ├── cube.obj │ ├── rounded_cube.obj │ └── sharp_cube.obj ├── test_dataset.py ├── test_edge_predictor.py ├── test_face_classifier.py ├── test_mesh_operations.py ├── test_metrics.py ├── test_model.py ├── test_model_layers.py ├── test_point_sampler.py └── test_trimesh.py ├── train.ipynb └── uv.lock /.cursorrules: 
-------------------------------------------------------------------------------- 1 | I have thousands of 3D objects in various mesh formats and levels of detail. I want to use PyTorch to train a model that converts a high-fidelity mesh to a lower level of detail with minimal loss. 2 | 3 | We will work iteratively; please follow these instructions and expectations for our process. 4 | 5 | 0. "The combined_output.txt file holds the code of the important components of the current code base. Take the current implementation into account when suggesting changes, especially when a change in one file requires changes to another class or function. When suggesting new content for an existing class, build on the current implementation rather than on assumptions about what it should be." 6 | 7 | 1. "Analyze research papers thoroughly, breaking down complex concepts into clear, step-by-step explanations. Focus on architectural details, mathematical formulations, and key algorithms." 8 | 9 | 2. "When implementing new components, start with a high-level class structure. Outline necessary imports, initialization methods, and key function signatures before diving into detailed implementations." 10 | 11 | 3. "For each significant component or method, suggest comprehensive unit tests. Cover normal operations, edge cases, and potential failure modes." 12 | 13 | 4. "When addressing code errors, provide a detailed analysis of the error message and traceback. Explain the root cause and reasoning behind proposed solutions." 14 | 15 | 5. "Maintain consistency with the original research paper or design document. Regularly cross-reference implementations against the source material." 16 | 17 | 6. "Consider compatibility with relevant libraries and frameworks (e.g., PyTorch, PyTorch Geometric) when suggesting code implementations." 18 | 19 | 7.
"Offer clear, intuitive explanations for complex mathematical concepts, especially those central to the project's core algorithms." 20 | 21 | 8. "Suggest visualization techniques for intermediate results to aid in debugging and understanding model behavior." 22 | 23 | 9. "Keep track of the overall project structure. Recommend refactoring or reorganization to maintain clean, modular, and efficient code." 24 | 25 | 10. "When implementing loss functions or evaluation metrics, explain the intuition and mathematical basis behind each component." 26 | 27 | 11. "Provide guidance on performance optimization, including suggestions for efficient data handling, model architecture improvements, and computational optimizations." 28 | 29 | 12. "Offer insights on potential extensions or modifications to the current project, based on recent advancements in the field or related research." 30 | 31 | 13. "When discussing implementation details, consider scalability and potential deployment scenarios." 32 | 33 | 14. "Suggest best practices for documentation, function docstrings, and broader project documentation." 34 | 35 | 15. "Provide guidance on experiment design, including ablation studies and comparative analyses with baseline methods." 36 | 37 | 16. 
"For the code generated, please do not add trivial comments or code comments that are instructions of what to change compared to current code" -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | charset = utf-8 5 | end_of_line = lf 6 | insert_final_newline = true 7 | trim_trailing_whitespace = true 8 | 9 | [*.py] 10 | indent_style = space 11 | indent_size = 4 12 | max_line_length = 180 -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Python CI 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | permissions: 10 | contents: read 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v4 19 | - name: Set up Python 3.12 20 | uses: actions/setup-python@v3 21 | with: 22 | python-version: "3.12" 23 | - name: Install system dependencies 24 | run: | 25 | sudo apt-get update 26 | sudo apt-get install -y libspatialindex-dev 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cpu 31 | pip install torch_cluster==1.6.3 torch_geometric==2.5.3 torch_scatter==2.1.2 torch_sparse==0.6.18 -f https://data.pyg.org/whl/torch-2.4.0+cpu.html 32 | pip install -r requirements.txt 33 | - name: Install this package 34 | run: | 35 | pip install -e . 
36 | - name: Test with pytest 37 | run: | 38 | pytest 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | data/**/*.ply 3 | data/**/*.obj 4 | data/**/*.stl 5 | examples/data 6 | checkpoints 7 | 8 | file_structure.txt 9 | combined_output.txt 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | cover/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # PyBuilder 69 | .pybuilder/ 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # IPython 76 | profile_default/ 77 | ipython_config.py 78 | 79 | # pyenv 80 | # For a library or package, you might want to ignore these files since the code is 81 | # intended to run in multiple environments; otherwise, check them in: 82 | # .python-version 83 | 84 | # pipenv 85 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
86 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 87 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 88 | # install all needed dependencies. 89 | #Pipfile.lock 90 | 91 | # poetry 92 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 93 | # This is especially recommended for binary packages to ensure reproducibility, and is more 94 | # commonly ignored for libraries. 95 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 96 | #poetry.lock 97 | 98 | # pdm 99 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 100 | #pdm.lock 101 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 102 | # in version control. 103 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 104 | .pdm.toml 105 | .pdm-python 106 | .pdm-build/ 107 | 108 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 109 | __pypackages__/ 110 | 111 | # Celery stuff 112 | celerybeat-schedule 113 | celerybeat.pid 114 | 115 | # SageMath parsed files 116 | *.sage.py 117 | 118 | # Environments 119 | .env 120 | .venv 121 | env/ 122 | venv/ 123 | ENV/ 124 | env.bak/ 125 | venv.bak/ 126 | 127 | # Spyder project settings 128 | .spyderproject 129 | .spyproject 130 | 131 | # Rope project settings 132 | .ropeproject 133 | 134 | # mkdocs documentation 135 | /site 136 | 137 | # mypy 138 | .mypy_cache/ 139 | .dmypy.json 140 | dmypy.json 141 | 142 | # Pyre type checker 143 | .pyre/ 144 | 145 | # pytype static type analyzer 146 | .pytype/ 147 | 148 | # Cython debug symbols 149 | cython_debug/ 150 | 151 | # PyCharm 152 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 153 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 154 | # and can be added to the global gitignore or merged into this file. For a more nuclear 155 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
156 | #.idea/ 157 | .idea/codestyles 158 | 159 | # Log files 160 | log/ 161 | 162 | # Generated files from tests 163 | tests/mesh_data 164 | 165 | # Example output 166 | examples/data/simplified 167 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/neural-mesh-simplification.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 14 | 15 | 17 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.12 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Martin Høst Normark 4 | 5 | Permission is 
hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Mesh Simplification 2 | 3 | Implementation of the 4 | paper [Neural Mesh Simplification paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Potamias_Neural_Mesh_Simplification_CVPR_2022_paper.pdf) 5 | by Potamias et al. (CVPR 2022) and the updated info shared 6 | in [supplementary material](https://openaccess.thecvf.com/content/CVPR2022/supplemental/Potamias_Neural_Mesh_Simplification_CVPR_2022_supplemental.pdf). 7 | 8 | This Python package provides a fast, learnable method for mesh simplification that generates simplified meshes in 9 | real-time. 10 | 11 | ### Overview 12 | 13 | Neural Mesh Simplification is a novel approach to reduce the resolution of 3D meshes while preserving their appearance. 
14 | Unlike traditional simplification methods that collapse edges in a greedy iterative manner, this method simplifies a 15 | given mesh in one pass using deep learning techniques. 16 | 17 | The method consists of three main steps: 18 | 19 | 1. Sampling a subset of input vertices using a sophisticated extension of random sampling. 20 | 2. Training a sparse attention network to propose candidate triangles based on the edge connectivity of sampled 21 | vertices. 22 | 3. Using a classification network to estimate the probability that a candidate triangle will be included in the final 23 | mesh. 24 | 25 | ### Features 26 | 27 | - Fast and scalable mesh simplification 28 | - One-pass simplification process 29 | - Preservation of mesh appearance 30 | - Lightweight and differentiable implementation 31 | - Suitable for integration into learnable pipelines 32 | 33 | ### Installation 34 | 35 | ```bash 36 | conda create -n neural-mesh-simplification python=3.12 37 | conda activate neural-mesh-simplification 38 | conda install pip 39 | ``` 40 | 41 | Depending on whether you are using PyTorch on a CPU or a GPU, 42 | you'll have to use the correct binaries for PyTorch and the PyTorch Geometric libraries. You can install them via: 43 | 44 | ```bash 45 | pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cpu 46 | pip install torch_cluster==1.6.3 torch_geometric==2.5.3 torch_scatter==2.1.2 torch_sparse==0.6.18 -f https://data.pyg.org/whl/torch-2.4.0+cpu.html 47 | ``` 48 | 49 | Replace “cpu” with “cu121” or the appropriate CUDA version for your system. If you don't know your CUDA version, 50 | run `nvidia-smi`. 51 | 52 | After that, install the remaining requirements: 53 | 54 | ```bash 55 | pip install -r requirements.txt 56 | pip install -e . 57 | ``` 58 | 59 | ### Example Usage / Playground 60 | 61 | 1. Drop your meshes as `.obj` files into the `examples/data` folder. 62 | 2.
Run the following command: 63 | 64 | ```bash 65 | python examples/example.py 66 | ``` 67 | 68 | 3. Collect the simplified meshes from the output subfolder under `examples/data`; their file names are prefixed 69 | with `simplified_`. 70 | 71 | ### Data Preparation 72 | 73 | If you don't have a dataset for training and evaluation, you can use a collection of meshes from the `perler/ppsurf` dataset on Hugging Face. 74 | See https://huggingface.co/datasets/perler/ppsurf for more information. 75 | 76 | Run the following script to download the meshes via the Hugging Face API: 77 | 78 | ```bash 79 | python scripts/download_test_meshes.py 80 | ``` 81 | 82 | Data will be downloaded to the `data/raw` folder at the root of the project. 83 | You can use `--target-folder` to specify a different folder, and `--dataset-size large` to download the much larger `abc_train` subset (5k+ meshes) instead of the default small one. 84 | 85 | Once you have some data, you should preprocess it using the following script: 86 | 87 | ```bash 88 | python scripts/preprocess_data.py 89 | ``` 90 | 91 | You can use the `--data_path` argument to specify the path to the dataset. The script will create a `data/processed` folder. 92 | 93 | ### Training 94 | 95 | To train the model on your own dataset with the prepared data: 96 | 97 | ```bash 98 | python ./scripts/train.py 99 | ``` 100 | 101 | By default, the training config at `configs/default.yaml` is used. You can override it with 102 | `--config /path/to/your/config.yaml`.\ 103 | You can override the following config parameters: 104 | 105 | * the checkpoint directory specified in the config file (where the model will be saved) with 106 | `--checkpoint-dir`. 107 | * the data path with `--data-path` if you have your data in a different location. 108 | 109 | If training was interrupted, you can resume it by specifying the path to the previously created checkpoint directory 110 | with `--resume`.\ 111 | Use `--debug` to see DEBUG logging.
112 | 113 | ### Evaluation 114 | 115 | To evaluate the model on a test set: 116 | 117 | ```bash 118 | python ./scripts/evaluate.py --eval-data-path /path/to/test/set --checkpoint /path/to/checkpoint.pth 119 | ``` 120 | 121 | By default, the training config at `configs/default.yaml` is used. You can override it with 122 | `--config /path/to/your/config.yaml`. 123 | 124 | ### Inference 125 | 126 | To simplify a mesh using the trained model: 127 | 128 | ```bash 129 | python ./scripts/infer.py --input-file /path/to/your/mesh.obj --output-file /path/to/simplified_mesh.obj --model-checkpoint /path/to/checkpoint.pth --device cpu 130 | ``` 131 | 132 | The default feature dimension for the point sampler and face classifier is 128, but it can be configured with 133 | `--hidden-dim <dim>`.\ 134 | The default feature dimension for the edge predictor is 64, but it can be configured with `--edge-hidden-dim <dim>`.\ 135 | If you have a CUDA-compatible GPU, you can specify `--device cuda` to use it for inference. 136 | 137 | ### Citation 138 | 139 | If you use this code in your research, please cite the original paper: 140 | 141 | ``` 142 | @InProceedings{Potamias_2022_CVPR, 143 | author = {Potamias, Rolandos Alexandros and Ploumpis, Stylianos and Zafeiriou, Stefanos}, 144 | title = {Neural Mesh Simplification}, 145 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 146 | month = {June}, 147 | year = {2022}, 148 | pages = {18583-18592} 149 | } 150 | ``` 151 | 152 | ## License 153 | 154 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
155 | 156 | -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | # Model parameters 2 | model: 3 | input_dim: 3 4 | hidden_dim: 128 # feature dimension for point sampler and face classifier (as per paper) 5 | edge_hidden_dim: 64 # feature dimension for edge predictor (as per paper) 6 | num_layers: 3 # number of convolutional layers (as per paper) 7 | k: 15 # number of neighbors for graph construction (as per paper) 8 | edge_k: 15 # number of neighbors for edge features (as per paper) 9 | target_ratio: 0.5 # mesh simplification ratio 10 | 11 | # Training parameters 12 | training: 13 | learning_rate: 1.0e-5 14 | weight_decay: 0.99 # weight decay per epoch (as per paper) 15 | batch_size: 2 16 | num_epochs: 20 # total training epochs 17 | early_stopping_patience: 15 # epochs before early stopping 18 | checkpoint_dir: data/checkpoints # model save directory 19 | 20 | # Data parameters 21 | data: 22 | data_dir: data/processed 23 | val_split: 0.2 24 | 25 | # Loss weights 26 | loss: 27 | lambda_c: 1.0 # chamfer distance weight 28 | lambda_e: 1.0 # edge preservation weight 29 | lambda_o: 1.0 # normal consistency weight 30 | -------------------------------------------------------------------------------- /configs/local.yaml: -------------------------------------------------------------------------------- 1 | # Model parameters 2 | model: 3 | input_dim: 3 4 | hidden_dim: 64 # feature dimension for point sampler and face classifier (paper uses 128) 5 | edge_hidden_dim: 128 # feature dimension for edge predictor (paper uses 64) 6 | num_layers: 3 # number of convolutional layers (as per paper) 7 | k: 15 # number of neighbors for graph construction (as per paper) 8 | edge_k: 15 # number of neighbors for edge features (as per paper) 9 | target_ratio: 0.5 # mesh simplification ratio 10 | 11 | # Training parameters 12 | training: 13 | learning_rate:
1.0e-5 14 | weight_decay: 0.99 # weight decay per epoch (as per paper) 15 | batch_size: 2 16 | accumulation_steps: 4 17 | num_epochs: 20 # total training epochs 18 | early_stopping_patience: 10 # epochs before early stopping 19 | checkpoint_dir: data/checkpoints # model save directory 20 | 21 | # Data parameters 22 | data: 23 | data_dir: data/processed 24 | val_split: 0.4 25 | 26 | # Loss weights 27 | loss: 28 | lambda_c: 1.0 # chamfer distance weight 29 | lambda_e: 1.0 # edge preservation weight 30 | lambda_o: 1.0 # normal consistency weight 31 | -------------------------------------------------------------------------------- /examples/example.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import trimesh 4 | from trimesh import Scene, Trimesh 5 | 6 | from neural_mesh_simplification import NeuralMeshSimplifier 7 | from neural_mesh_simplification.data.dataset import load_mesh 8 | 9 | script_dir = os.path.dirname(os.path.abspath(__file__)) 10 | data_dir = os.path.join(script_dir, "data") 11 | default_config_path = os.path.join(script_dir, "../configs/default.yaml") 12 | 13 | 14 | def load_config(config_path): 15 | import yaml 16 | 17 | with open(config_path, "r") as file: 18 | config = yaml.safe_load(file) 19 | return config 20 | 21 | 22 | def save_mesh_to_file(mesh: trimesh.Geometry, file_name: str): 23 | """ 24 | Save the simplified mesh to file in the simplified folder. 
25 | """ 26 | simplified_dir = os.path.join(data_dir, "processed") 27 | os.makedirs(simplified_dir, exist_ok=True) 28 | output_path = os.path.join(simplified_dir, file_name) 29 | mesh.export(output_path) 30 | 31 | print(f"Mesh saved to: {output_path}") 32 | 33 | 34 | def cube_example(simplifier: NeuralMeshSimplifier): 35 | print(f"Creating cube mesh") 36 | file = "cube.obj" 37 | mesh = trimesh.creation.box(extents=[2, 2, 2]) 38 | simplified_mesh = simplifier.simplify(mesh) 39 | save_mesh_to_file(mesh, file) 40 | save_mesh_to_file(simplified_mesh, f"simplified_{file}") 41 | 42 | 43 | def sphere_example(simplifier: NeuralMeshSimplifier): 44 | print(f"Creating sphere mesh") 45 | file = "sphere.obj" 46 | mesh = trimesh.creation.icosphere(subdivisions=2, radius=2) 47 | simplified_mesh = simplifier.simplify(mesh) 48 | save_mesh_to_file(mesh, file) 49 | save_mesh_to_file(simplified_mesh, f"simplified_{file}") 50 | 51 | 52 | def cylinder_example(simplifier: NeuralMeshSimplifier): 53 | print(f"Creating cylinder mesh") 54 | file = "cylinder.obj" 55 | mesh = trimesh.creation.cylinder(radius=1, height=2) 56 | simplified_mesh = simplifier.simplify(mesh) 57 | save_mesh_to_file(mesh, file) 58 | save_mesh_to_file(simplified_mesh, f"simplified_{file}") 59 | 60 | 61 | def mesh_dropbox_example(simplifier: NeuralMeshSimplifier): 62 | print(f"Loading all meshes of type '.obj' in folder '{data_dir}'") 63 | mesh_files = [f for f in os.listdir(data_dir) if f.endswith(".obj")] 64 | 65 | for file_name in mesh_files: 66 | mesh_path = os.path.join(data_dir, file_name) 67 | 68 | original_mesh = load_mesh(mesh_path) 69 | 70 | print("Loaded mesh at file" + mesh_path) 71 | 72 | # Create a new scene to hold the simplified meshes 73 | simplified_scene = Scene() 74 | 75 | if isinstance(original_mesh, Trimesh): 76 | print( 77 | "Original: ", 78 | original_mesh.vertices.shape, 79 | original_mesh.edges.shape, 80 | original_mesh.faces.shape, 81 | ) 82 | simplified_geom = 
simplifier.simplify(original_mesh) 83 | print( 84 | "Simplified: ", 85 | simplified_geom.vertices.shape, 86 | simplified_geom.edges.shape, 87 | simplified_geom.faces.shape, 88 | ) 89 | 90 | simplified_scene = simplified_geom 91 | 92 | elif isinstance(original_mesh, Scene): 93 | # Iterate through the original mesh geometry 94 | for name, geom in original_mesh.geometry.items(): 95 | print("Original: ", geom) 96 | # Simplify each Trimesh object 97 | simplified_geom = simplifier.simplify(geom) 98 | print("Simplified: ", simplified_geom) 99 | # Add the simplified geometry to the new scene 100 | simplified_scene.add_geometry(simplified_geom, geom_name=name) 101 | else: 102 | raise ValueError( 103 | "Invalid mesh type (expected Trimesh or Scene):", type(original_mesh) 104 | ) 105 | 106 | # Save the simplified mesh to file 107 | save_mesh_to_file(simplified_scene, f"simplified_{file_name}") 108 | 109 | 110 | def main(): 111 | # Initialize the simplifier 112 | config = load_config(config_path=default_config_path) 113 | simplifier = NeuralMeshSimplifier( 114 | input_dim=config["model"]["input_dim"], 115 | hidden_dim=config["model"]["hidden_dim"], 116 | edge_hidden_dim=config["model"]["edge_hidden_dim"], 117 | num_layers=config["model"]["num_layers"], 118 | k=config["model"]["k"], 119 | edge_k=config["model"]["edge_k"], 120 | target_ratio=config["model"]["target_ratio"], 121 | ) 122 | 123 | # cube_example(simplifier) 124 | # sphere_example(simplifier) 125 | # cylinder_example(simplifier) 126 | mesh_dropbox_example(simplifier) 127 | 128 | 129 | if __name__ == "__main__": 130 | main() 131 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=45", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "neural-mesh-simplification" 7 | version = "0.1.0" 8 | description = "A neural 
network-based approach to mesh simplification" 9 | authors = [{ name = "Martin Normark", email = "m@martinnormark.com" }, { name = "Gennaro Frazzingaro", email = "gennaro@thinair.ar" }] 10 | license = { file = "LICENSE" } 11 | readme = "README.md" 12 | requires-python = ">=3.12" 13 | classifiers = [ 14 | "Programming Language :: Python :: 3", 15 | "License :: OSI Approved :: MIT License", 16 | "Operating System :: OS Independent", 17 | ] 18 | 19 | dependencies = [ 20 | "numpy", 21 | "torch", 22 | "trimesh", 23 | "scipy", 24 | "matplotlib", 25 | "tqdm" 26 | ] 27 | 28 | [tool.setuptools.packages.find] 29 | where = ["src"] 30 | include = ["neural_mesh_simplification*"] 31 | namespaces = false 32 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | trimesh: Run only these tests when specified 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohappyeyeballs==2.3.5 2 | aiohttp==3.10.1 3 | aiosignal==1.3.1 4 | attrs==24.2.0 5 | certifi==2024.7.4 6 | charset-normalizer==3.3.2 7 | datasets==2.20.0 8 | dill==0.3.8 9 | filelock==3.15.4 10 | frozenlist==1.4.1 11 | fsspec==2024.5.0 12 | huggingface-hub==0.24.5 13 | idna==3.7 14 | iniconfig==2.0.0 15 | Jinja2==3.1.4 16 | joblib==1.4.2 17 | MarkupSafe==2.1.5 18 | mpmath==1.3.0 19 | multidict==6.0.5 20 | multiprocess==0.70.16 21 | networkx==3.3 22 | numpy==2.0.1 23 | packaging==24.1 24 | pandas==2.2.2 25 | pluggy==1.5.0 26 | psutil==6.0.0 27 | pyarrow==17.0.0 28 | pyarrow-hotfix==0.6 29 | pyparsing==3.1.2 30 | pytest==8.3.2 31 | python-dateutil==2.9.0.post0 32 | pytz==2024.1 33 | PyYAML==6.0.2 34 | requests==2.32.3 35 | Rtree==1.3.0 36 | scikit-learn==1.5.1 37 | scipy==1.14.0 38 | setuptools==72.1.0 39 | six==1.16.0 40 | sympy==1.13.1 41 | 
threadpoolctl==3.5.0 42 | tqdm==4.66.5 43 | trimesh==4.4.4 44 | typing_extensions==4.12.2 45 | tzdata==2024.1 46 | urllib3==2.2.2 47 | wheel==0.43.0 48 | xxhash==3.4.1 49 | yarl==1.9.4 50 | -------------------------------------------------------------------------------- /scripts/download_test_meshes.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | 5 | from huggingface_hub import snapshot_download 6 | 7 | # abc_train is really large (+5k meshes) 8 | global folder_patterns 9 | folder_patterns = ["abc_extra_noisy/03_meshes/*.ply", "abc_train/03_meshes/*.ply"] 10 | 11 | 12 | def download_meshes(target_folder, folder_pattern=folder_patterns[0]): 13 | wip_folder = os.path.join(target_folder, "wip") 14 | os.makedirs(wip_folder, exist_ok=True) 15 | 16 | snapshot_download( 17 | repo_id="perler/ppsurf", 18 | repo_type="dataset", 19 | cache_dir=wip_folder, 20 | allow_patterns=folder_pattern, 21 | ) 22 | 23 | # Move files from wip folder to target folder 24 | for root, _, files in os.walk(wip_folder): 25 | for file in files: 26 | if file.endswith(".ply"): 27 | src_file = os.path.join(root, file) 28 | dest_file = os.path.join(target_folder, file) 29 | shutil.copy2(src_file, dest_file) 30 | os.remove(src_file) 31 | 32 | # Remove the wip folder 33 | shutil.rmtree(wip_folder) 34 | 35 | 36 | def main(): 37 | parser = argparse.ArgumentParser( 38 | description="Download test meshes from Hugging Face Hub." 39 | ) 40 | parser.add_argument( 41 | "--target-folder", 42 | type=str, 43 | required=False, 44 | help="The target folder path where the meshes will be downloaded.", 45 | ) 46 | parser.add_argument( 47 | "--dataset-size", 48 | type=str, 49 | required=False, 50 | default="small", 51 | choices=["small", "large"], 52 | help="The size of the dataset to download. 
Choose 'small' or 'large'.", 53 | ) 54 | 55 | args = parser.parse_args() 56 | 57 | target_folder = args.target_folder if args.target_folder else "data/raw" 58 | if not os.path.exists(target_folder): 59 | os.makedirs(target_folder) 60 | 61 | download_meshes( 62 | target_folder, 63 | folder_patterns[0] if args.dataset_size == "small" else folder_patterns[1], 64 | ) 65 | 66 | 67 | if __name__ == "__main__": 68 | main() 69 | -------------------------------------------------------------------------------- /scripts/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from neural_mesh_simplification.trainer import Trainer 4 | 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser( 8 | description="Evaluate the Neural Mesh Simplification model." 9 | ) 10 | parser.add_argument( 11 | "--eval-data-path", 12 | type=str, 13 | required=True, 14 | help="Path to the evaluation data directory.", 15 | ) 16 | parser.add_argument( 17 | "--config", 18 | type=str, 19 | required=False, 20 | help="Path to the evaluation configuration file.", 21 | ) 22 | parser.add_argument( 23 | "--checkpoint", type=str, required=True, help="Path to the model checkpoint." 
24 | ) 25 | return parser.parse_args() 26 | 27 | 28 | def load_config(config_path): 29 | import yaml 30 | 31 | with open(config_path, "r") as file: 32 | config = yaml.safe_load(file) 33 | return config 34 | 35 | 36 | def main(): 37 | args = parse_args() 38 | config = load_config(args.config) 39 | config["data"]["eval_data_path"] = args.eval_data_path 40 | 41 | trainer = Trainer(config) 42 | trainer.load_checkpoint(args.checkpoint) 43 | 44 | evaluation_metrics = trainer.evaluate(trainer.val_loader) 45 | for metric, value in evaluation_metrics.items(): 46 | print(f"{metric}: {value:.4f}") 47 | 48 | 49 | if __name__ == "__main__": 50 | main() 51 | -------------------------------------------------------------------------------- /scripts/infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from trimesh import Trimesh 4 | 5 | from neural_mesh_simplification import ( 6 | NeuralMeshSimplifier, 7 | ) # Assuming the model class is named MeshSimplifier 8 | from neural_mesh_simplification.data.dataset import load_mesh 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser( 13 | description="Simplify a 3D mesh using a trained model." 14 | ) 15 | parser.add_argument( 16 | "--input-file", type=str, required=True, help="Path to the input mesh file." 
17 | ) 18 | parser.add_argument( 19 | "--output-file", 20 | type=str, 21 | required=True, 22 | help="Where to save the simplified mesh.", 23 | ) 24 | parser.add_argument( 25 | "--hidden-dim", 26 | type=int, 27 | required=False, 28 | default=128, 29 | help="Feature dimension for point sampler and face classifier", 30 | ) 31 | parser.add_argument( 32 | "--edge-hidden-dim", 33 | type=int, 34 | required=False, 35 | default=64, 36 | help="Feature dimension for edge predictor", 37 | ) 38 | parser.add_argument( 39 | "--model-checkpoint", 40 | type=str, 41 | required=True, 42 | help="Path to the trained model checkpoint.", 43 | ) 44 | parser.add_argument( 45 | "--device", 46 | type=str, 47 | default="cpu", 48 | help="Device to use for inference (`cpu` or `cuda`).", 49 | ) 50 | 51 | return parser.parse_args() 52 | 53 | 54 | def simplify_mesh( 55 | input_file: str, 56 | output_file: str, 57 | model_checkpoint: str, 58 | hidden_dim: int, 59 | edge_hidden_dim: int, 60 | device="cpu", 61 | ): 62 | """ 63 | Simplifies a 3D mesh using a trained model. 64 | 65 | Args: 66 | input_file (str): Path to the high-resolution input `.obj` file. 67 | model_checkpoint (str): Path to the trained model checkpoint. 68 | hidden_dim (int): Feature dimension for point sampler and face classifier 69 | edge_hidden_dim (int): Feature dimension for edge predictor 70 | device (str): Device to use for inference (`cpu` or `cuda`). 
71 | """ 72 | # Load the trained model 73 | print(f"Loading model from {model_checkpoint}...") 74 | simplifier = NeuralMeshSimplifier.using_model( 75 | model_checkpoint, 76 | hidden_dim=hidden_dim, 77 | edge_hidden_dim=edge_hidden_dim, 78 | map_location=device, 79 | ) 80 | simplifier.model.to(device) 81 | simplifier.model.eval() 82 | 83 | # Load the input mesh 84 | print(f"Loading input mesh from {input_file}...") 85 | original_mesh = load_mesh(input_file) 86 | 87 | if not isinstance(original_mesh, Trimesh): 88 | raise ValueError("Invalid format for input mesh.") 89 | 90 | simplified_mesh = simplifier.simplify(original_mesh) 91 | 92 | # Save the simplified mesh 93 | print(f"Saving simplified mesh to {output_file}...") 94 | simplified_mesh.export(output_file) 95 | print("Simplification complete.") 96 | 97 | 98 | def load_config(config_path): 99 | import yaml 100 | 101 | with open(config_path, "r") as file: 102 | config = yaml.safe_load(file) 103 | return config 104 | 105 | 106 | def main(): 107 | args = parse_args() 108 | 109 | simplify_mesh( 110 | input_file=args.input_file, 111 | output_file=args.output_file, 112 | model_checkpoint=args.model_checkpoint, 113 | hidden_dim=args.hidden_dim, 114 | edge_hidden_dim=args.edge_hidden_dim, 115 | device=args.device, 116 | ) 117 | 118 | 119 | if __name__ == "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /scripts/preprocess_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import networkx as nx 4 | import trimesh 5 | from tqdm import tqdm 6 | 7 | from neural_mesh_simplification.data import MeshSimplificationDataset 8 | from neural_mesh_simplification.data.dataset import load_mesh, preprocess_mesh 9 | 10 | 11 | def preprocess_dataset( 12 | input_dir, 13 | output_dir, 14 | pre_process=True, 15 | min_components=1, 16 | max_components=1, 17 | print_stats=False, 18 | ): 19 | dataset = 
MeshSimplificationDataset(data_dir=input_dir) 20 | 21 | for idx in tqdm(range(len(dataset)), desc="Processing meshes"): 22 | file_path = os.path.join(dataset.data_dir, dataset.file_list[idx]) 23 | 24 | mesh = load_mesh(file_path) 25 | 26 | if pre_process: 27 | mesh = preprocess_mesh(mesh) 28 | 29 | if mesh is not None: 30 | face_adjacency = trimesh.graph.face_adjacency(mesh.faces) 31 | 32 | G = nx.Graph() 33 | G.add_edges_from(face_adjacency) 34 | 35 | components = list(nx.connected_components(G)) 36 | 37 | num_components = len(components) 38 | num_vertices = len(mesh.vertices) 39 | num_faces = len(mesh.faces) 40 | 41 | if num_components < min_components or num_components > max_components: 42 | print(f"Skipping mesh {idx}: {dataset.file_list[idx]}") 43 | print(f" Connected components: {num_components}") 44 | print() 45 | continue 46 | 47 | if print_stats: 48 | print(f"Mesh {idx}: {dataset.file_list[idx]}") 49 | print(f" Connected components: {num_components}") 50 | print(f" Vertices: {num_vertices}") 51 | print(f" Faces: {num_faces}") 52 | print(f" Is watertight: {mesh.is_watertight}") 53 | print(f" Volume: {mesh.volume}") 54 | print(f" Surface area: {mesh.area}") 55 | 56 | non_manifold_edges = mesh.edges_unique[mesh.edges_unique_length > 2] 57 | print(f" Number of non-manifold edges: {len(non_manifold_edges)}") 58 | print() 59 | 60 | output_file = os.path.join(output_dir, dataset.file_list[idx]) 61 | mesh.export(output_file.replace(".ply", ".stl")) 62 | else: 63 | print(f"Failed to load mesh {idx}: {dataset.file_list[idx]}") 64 | print() 65 | 66 | print("Finished processing all meshes.") 67 | 68 | 69 | if __name__ == "__main__": 70 | if not os.path.exists("data/raw"): 71 | raise FileNotFoundError( 72 | "The 'data/raw' directory does not exist. Please download the dataset first." 
73 | ) 74 | 75 | os.makedirs("data/processed", exist_ok=True) 76 | preprocess_dataset("data/raw", "data/processed") 77 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | 5 | from neural_mesh_simplification.trainer import Trainer 6 | 7 | script_dir = os.path.dirname(os.path.abspath(__file__)) 8 | default_config_path = os.path.join(script_dir, "../configs/default.yaml") 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser( 13 | description="Train the Neural Mesh Simplification model." 14 | ) 15 | parser.add_argument( 16 | "--data-path", 17 | type=str, 18 | required=False, 19 | help="Path to the training data directory.", 20 | ) 21 | parser.add_argument( 22 | "--config", 23 | type=str, 24 | required=False, 25 | help="Path to the training configuration file.", 26 | ) 27 | parser.add_argument( 28 | "--checkpoint-dir", 29 | type=str, 30 | default="checkpoints", 31 | help="Directory to save model checkpoints.", 32 | ) 33 | parser.add_argument( 34 | "--resume", 35 | type=str, 36 | default=None, 37 | help="Path to a checkpoint to resume training from.", 38 | ) 39 | parser.add_argument("--debug", action="store_true", help="Show debug logs") 40 | parser.add_argument( 41 | "--monitor", action="store_true", help="Monitor CPU and memory usage" 42 | ) 43 | return parser.parse_args() 44 | 45 | 46 | def load_config(config_path): 47 | import yaml 48 | 49 | with open(config_path, "r") as file: 50 | config = yaml.safe_load(file) 51 | return config 52 | 53 | 54 | def main(): 55 | args = parse_args() 56 | 57 | config_path = args.config if args.config else default_config_path 58 | 59 | config = load_config(config_path) 60 | 61 | if args.data_path: 62 | config["data"]["data_dir"] = args.data_path 63 | 64 | if args.checkpoint_dir: 65 | config["training"]["checkpoint_dir"] = 
args.checkpoint_dir 66 | 67 | if not os.path.exists(args.checkpoint_dir): 68 | os.makedirs(args.checkpoint_dir) 69 | 70 | if args.debug: 71 | logging.basicConfig(level=logging.DEBUG) 72 | else: 73 | logging.basicConfig(level=logging.INFO) 74 | 75 | if args.monitor: 76 | config["monitor_resources"] = True 77 | 78 | trainer = Trainer(config) 79 | 80 | if args.resume: 81 | trainer.load_checkpoint(args.resume) 82 | 83 | try: 84 | trainer.train() 85 | except Exception as e: 86 | trainer.handle_error(e) 87 | trainer.save_training_state( 88 | os.path.join(config["training"]["checkpoint_dir"], "training_state.pth") 89 | ) 90 | 91 | 92 | if __name__ == "__main__": 93 | main() 94 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/__init__.py: -------------------------------------------------------------------------------- 1 | from .api.neural_mesh_simplifier import NeuralMeshSimplifier 2 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/api/__init__.py: -------------------------------------------------------------------------------- 1 | from .neural_mesh_simplifier import NeuralMeshSimplifier 2 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/api/neural_mesh_simplifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import trimesh 3 | from torch_geometric.data import Data 4 | 5 | from ..data.dataset import preprocess_mesh, mesh_to_tensor 6 | from ..models import NeuralMeshSimplification 7 | 8 | 9 | class NeuralMeshSimplifier: 10 | def __init__( 11 | self, 12 | input_dim, 13 | hidden_dim, 14 | edge_hidden_dim, # Separate hidden dim for edge predictor 15 | num_layers, 16 | k, 17 | edge_k, 18 | target_ratio, 19 | ): 20 | self.input_dim = input_dim 21 | self.hidden_dim = hidden_dim 22 | self.edge_hidden_dim = edge_hidden_dim 23 | 
self.num_layers = num_layers 24 | self.k = k 25 | self.edge_k = edge_k 26 | self.target_ratio = target_ratio 27 | self.model = self._build_model() 28 | 29 | @classmethod 30 | def using_model( 31 | cls, at_path: str, hidden_dim: int, edge_hidden_dim: int, map_location: str 32 | ): 33 | instance = cls( 34 | input_dim=3, 35 | hidden_dim=hidden_dim, 36 | edge_hidden_dim=edge_hidden_dim, 37 | num_layers=3, 38 | k=15, 39 | edge_k=15, 40 | target_ratio=0.5, 41 | ) 42 | instance._load_model(at_path, map_location) 43 | return instance 44 | 45 | def _build_model(self): 46 | return NeuralMeshSimplification( 47 | input_dim=self.input_dim, 48 | hidden_dim=self.hidden_dim, 49 | edge_hidden_dim=self.edge_hidden_dim, 50 | num_layers=self.num_layers, 51 | k=self.k, 52 | edge_k=self.edge_k, 53 | target_ratio=self.target_ratio, 54 | ) 55 | 56 | def _load_model(self, checkpoint_path: str, map_location: str): 57 | self.model.load_state_dict( 58 | torch.load(checkpoint_path, map_location=map_location) 59 | ) 60 | 61 | def simplify(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh: 62 | # Preprocess the mesh (e.g. 
normalize, center) 63 | preprocessed_mesh: trimesh.Trimesh = preprocess_mesh(mesh) 64 | 65 | # Convert to a tensor 66 | tensor: Data = mesh_to_tensor(preprocessed_mesh) 67 | model_output = self.model(tensor) 68 | 69 | vertices = model_output["sampled_vertices"].detach().numpy() 70 | faces = model_output["simplified_faces"].numpy() 71 | edges = model_output["edge_index"].t().numpy() # Transpose to get (n, 2) shape 72 | 73 | return trimesh.Trimesh(vertices=vertices, faces=faces, edges=edges) 74 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import MeshSimplificationDataset 2 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/data/dataset.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | from typing import Optional 4 | 5 | import numpy as np 6 | import torch 7 | import trimesh 8 | from torch.utils.data import Dataset 9 | from torch_geometric.data import Data 10 | from trimesh import Geometry, Trimesh 11 | 12 | from ..utils import build_graph_from_mesh 13 | 14 | 15 | class MeshSimplificationDataset(Dataset): 16 | def __init__( 17 | self, 18 | data_dir: str, 19 | preprocess: bool = False, 20 | transform: Optional[callable] = None, 21 | ): 22 | self.data_dir = data_dir 23 | self.preprocess = preprocess 24 | self.transform = transform 25 | self.file_list = self._get_file_list() 26 | 27 | def _get_file_list(self): 28 | return [ 29 | f 30 | for f in os.listdir(self.data_dir) 31 | if f.endswith(".ply") or f.endswith(".obj") or f.endswith(".stl") 32 | ] 33 | 34 | def __len__(self): 35 | return len(self.file_list) 36 | 37 | def __getitem__(self, idx): 38 | file_path = os.path.join(self.data_dir, self.file_list[idx]) 39 | mesh = load_mesh(file_path) 40 | 41 | if 
self.preprocess: 42 | mesh = preprocess_mesh(mesh) 43 | 44 | if self.transform: 45 | mesh = self.transform(mesh) 46 | 47 | data = mesh_to_tensor(mesh) 48 | gc.collect() 49 | return data 50 | 51 | 52 | def load_mesh(file_path: str) -> Geometry | list[Geometry] | None: 53 | """Load a mesh from file.""" 54 | try: 55 | mesh = trimesh.load(file_path) 56 | return mesh 57 | except Exception as e: 58 | print(f"Error loading mesh {file_path}: {e}") 59 | return None 60 | 61 | 62 | def preprocess_mesh(mesh: trimesh.Trimesh) -> Trimesh | None: 63 | """Preprocess a mesh (e.g., normalize, center).""" 64 | if mesh is None: 65 | return None 66 | 67 | # Center the mesh 68 | mesh.vertices -= mesh.vertices.mean(axis=0) 69 | 70 | # Scale to unit cube 71 | max_dim = np.max(mesh.vertices.max(axis=0) - mesh.vertices.min(axis=0)) 72 | mesh.vertices /= max_dim 73 | 74 | return mesh 75 | 76 | 77 | def augment_mesh(mesh: trimesh.Trimesh) -> Trimesh | None: 78 | """Apply data augmentation to a mesh.""" 79 | if mesh is None: 80 | return None 81 | 82 | # Example: Random rotation 83 | rotation = trimesh.transformations.random_rotation_matrix() 84 | mesh.apply_transform(rotation) 85 | 86 | return mesh 87 | 88 | 89 | def mesh_to_tensor(mesh: trimesh.Trimesh) -> Data: 90 | """Convert a mesh to tensor representation including graph structure.""" 91 | if mesh is None: 92 | return None 93 | 94 | # Convert vertices and faces to tensors 95 | vertices_tensor = torch.tensor(mesh.vertices, dtype=torch.float32) 96 | faces_tensor = torch.tensor(mesh.faces, dtype=torch.long).t() 97 | 98 | # Build graph structure 99 | G = build_graph_from_mesh(mesh) 100 | 101 | # Create edge index tensor 102 | edge_index = torch.tensor(list(G.edges), dtype=torch.long).t().contiguous() 103 | 104 | # Create Data object 105 | data = Data( 106 | x=vertices_tensor, 107 | pos=vertices_tensor, 108 | edge_index=edge_index, 109 | face=faces_tensor, 110 | num_nodes=len(mesh.vertices), 111 | ) 112 | 113 | return data 114 | 
-------------------------------------------------------------------------------- /src/neural_mesh_simplification/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .chamfer_distance_loss import ProbabilisticChamferDistanceLoss 2 | from .surface_distance_loss import ProbabilisticSurfaceDistanceLoss 3 | from .triangle_collision_loss import TriangleCollisionLoss 4 | from .edge_crossing_loss import EdgeCrossingLoss 5 | from .overlapping_triangles_loss import OverlappingTrianglesLoss 6 | from .combined_loss import CombinedMeshSimplificationLoss 7 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/losses/chamfer_distance_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ProbabilisticChamferDistanceLoss(nn.Module): 6 | def __init__(self): 7 | super(ProbabilisticChamferDistanceLoss, self).__init__() 8 | 9 | def forward(self, P, Ps, probabilities): 10 | """ 11 | Compute the Probabilistic Chamfer Distance loss. 
12 | 13 | Args: 14 | P (torch.Tensor): Original point cloud, shape (N, 3) 15 | Ps (torch.Tensor): Sampled point cloud, shape (M, 3) 16 | probabilities (torch.Tensor): Sampling probabilities for Ps, shape (M,) 17 | 18 | Returns: 19 | torch.Tensor: Scalar loss value 20 | """ 21 | if P.size(0) == 0 or Ps.size(0) == 0: 22 | return torch.tensor(0.0, device=P.device, requires_grad=True) 23 | 24 | # Ensure inputs are on the same device 25 | Ps = Ps.to(P.device) 26 | probabilities = probabilities.to(P.device) 27 | 28 | # Compute distances from Ps to P 29 | dist_s_to_o = self.compute_minimum_distances(Ps, P) 30 | 31 | # Compute distances from P to Ps 32 | dist_o_to_s, min_indices = self.compute_minimum_distances( 33 | P, Ps, return_indices=True 34 | ) 35 | 36 | # Weight distances by probabilities 37 | weighted_dist_s_to_o = dist_s_to_o * probabilities 38 | weighted_dist_o_to_s = dist_o_to_s * probabilities[min_indices] 39 | 40 | # Sum up the weighted distances 41 | loss = weighted_dist_s_to_o.sum() + weighted_dist_o_to_s.sum() 42 | 43 | return loss 44 | 45 | def compute_minimum_distances(self, source, target, return_indices=False): 46 | """ 47 | Compute the minimum distances from each point in source to target. 
48 | 49 | Args: 50 | source (torch.Tensor): Source point cloud, shape (N, 3) 51 | target (torch.Tensor): Target point cloud, shape (M, 3) 52 | return_indices (bool): If True, also return indices of minimum distances 53 | 54 | Returns: 55 | torch.Tensor: Minimum distances, shape (N,) 56 | torch.Tensor: Indices of minimum distances (if return_indices is True) 57 | """ 58 | # Compute pairwise distances 59 | distances = torch.cdist(source, target) 60 | 61 | # Find minimum distances 62 | min_distances, min_indices = distances.min(dim=1) 63 | 64 | if return_indices: 65 | return min_distances, min_indices 66 | else: 67 | return min_distances 68 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/losses/combined_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch import device 3 | 4 | from . import ( 5 | ProbabilisticChamferDistanceLoss, 6 | ProbabilisticSurfaceDistanceLoss, 7 | TriangleCollisionLoss, 8 | EdgeCrossingLoss, 9 | OverlappingTrianglesLoss, 10 | ) 11 | 12 | 13 | class CombinedMeshSimplificationLoss(nn.Module): 14 | def __init__( 15 | self, 16 | lambda_c: float = 1.0, 17 | lambda_e: float = 1.0, 18 | lambda_o: float = 1.0, 19 | device=device("cpu"), 20 | ): 21 | super().__init__() 22 | self.device = device 23 | self.prob_chamfer_loss = ProbabilisticChamferDistanceLoss().to(self.device) 24 | self.prob_surface_loss = ProbabilisticSurfaceDistanceLoss().to(self.device) 25 | self.collision_loss = TriangleCollisionLoss().to(self.device) 26 | self.edge_crossing_loss = EdgeCrossingLoss().to(self.device) 27 | self.overlapping_triangles_loss = OverlappingTrianglesLoss().to(self.device) 28 | self.lambda_c = lambda_c 29 | self.lambda_e = lambda_e 30 | self.lambda_o = lambda_o 31 | 32 | def forward(self, original_data, simplified_data): 33 | original_x = ( 34 | original_data["pos"] if "pos" in original_data else original_data["x"] 35 | 
).to(self.device) 36 | original_face = original_data["face"].to(self.device) 37 | 38 | sampled_vertices = simplified_data["sampled_vertices"].to(self.device) 39 | sampled_probs = simplified_data["sampled_probs"].to(self.device) 40 | sampled_faces = simplified_data["simplified_faces"].to(self.device) 41 | face_probs = simplified_data["face_probs"].to(self.device) 42 | 43 | chamfer_loss = self.prob_chamfer_loss( 44 | original_x, sampled_vertices, sampled_probs 45 | ) 46 | 47 | del sampled_probs 48 | 49 | surface_loss = self.prob_surface_loss( 50 | original_x, 51 | original_face, 52 | sampled_vertices, 53 | sampled_faces, 54 | face_probs, 55 | ) 56 | 57 | del original_x 58 | del original_face 59 | 60 | collision_loss = self.collision_loss( 61 | sampled_vertices, 62 | sampled_faces, 63 | face_probs, 64 | ) 65 | edge_crossing_loss = self.edge_crossing_loss( 66 | sampled_vertices, sampled_faces, face_probs 67 | ) 68 | 69 | del face_probs 70 | 71 | overlapping_triangles_loss = self.overlapping_triangles_loss( 72 | sampled_vertices, sampled_faces 73 | ) 74 | 75 | del sampled_vertices 76 | del sampled_faces 77 | 78 | total_loss = ( 79 | chamfer_loss 80 | + surface_loss 81 | + self.lambda_c * collision_loss 82 | + self.lambda_e * edge_crossing_loss 83 | + self.lambda_o * overlapping_triangles_loss 84 | ) 85 | 86 | return total_loss 87 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/losses/edge_crossing_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_cluster import knn 4 | 5 | 6 | class EdgeCrossingLoss(nn.Module): 7 | def __init__(self, k: int = 20): 8 | super().__init__() 9 | self.k = k # Number of nearest triangles to consider 10 | 11 | def forward( 12 | self, vertices: torch.Tensor, faces: torch.Tensor, face_probs: torch.Tensor 13 | ) -> torch.Tensor: 14 | # If no faces, return zero loss 15 | if 
faces.shape[0] == 0: 16 | return torch.tensor(0.0, device=vertices.device) 17 | # Ensure face_probs matches the number of faces 18 | if face_probs.shape[0] > faces.shape[0]: 19 | face_probs = face_probs[: faces.shape[0]] 20 | elif face_probs.shape[0] < faces.shape[0]: 21 | # Pad with zeros if we have fewer probabilities than faces 22 | padding = torch.zeros( 23 | faces.shape[0] - face_probs.shape[0], device=face_probs.device 24 | ) 25 | face_probs = torch.cat([face_probs, padding]) 26 | 27 | # 1. Find k-nearest triangles for each triangle 28 | nearest_triangles = self.find_nearest_triangles(vertices, faces) 29 | 30 | # 2. Detect edge crossings between nearby triangles 31 | crossings = self.detect_edge_crossings(vertices, faces, nearest_triangles) 32 | 33 | # 3. Calculate loss 34 | loss = self.calculate_loss(crossings, face_probs) 35 | 36 | return loss 37 | 38 | def find_nearest_triangles( 39 | self, vertices: torch.Tensor, faces: torch.Tensor 40 | ) -> torch.Tensor: 41 | # Compute triangle centroids 42 | centroids = vertices[faces].mean(dim=1) 43 | 44 | # Use knn to find nearest triangles 45 | k = min( 46 | self.k, centroids.shape[0] 47 | ) # Ensure k is not larger than the number of centroids 48 | _, indices = knn(centroids, centroids, k=k) 49 | 50 | # Reshape indices to [num_faces, k] 51 | indices = indices.view(centroids.shape[0], k) 52 | 53 | # Remove self-connections (triangles cannot be their own neighbor) 54 | nearest = [] 55 | for i in range(indices.shape[0]): 56 | neighbors = indices[i][indices[i] != i] 57 | if len(neighbors) == 0: 58 | nearest.append(torch.empty(0, dtype=torch.long)) 59 | else: 60 | nearest.append(neighbors[: self.k - 1]) 61 | 62 | # Return tensor with consistent shape 63 | if len(nearest) > 0 and all(len(n) == 0 for n in nearest): 64 | nearest = torch.empty((len(nearest), 0), dtype=torch.long) 65 | else: 66 | nearest = torch.stack(nearest) 67 | return nearest 68 | 69 | def detect_edge_crossings( 70 | self, 71 | vertices: torch.Tensor, 72 
| faces: torch.Tensor, 73 | nearest_triangles: torch.Tensor, 74 | ) -> torch.Tensor: 75 | def edge_vectors(triangles): 76 | # Extracts the edges from a triangle defined by vertex indices 77 | return vertices[triangles[:, [1, 2, 0]]] - vertices[triangles] 78 | 79 | edges = edge_vectors(faces) 80 | crossings = torch.zeros(faces.shape[0], device=vertices.device) 81 | 82 | for i in range(faces.shape[0]): 83 | neighbor_edges = edge_vectors(faces[nearest_triangles[i]]) 84 | for j in range(3): 85 | edge = edges[i, j].unsqueeze(0).unsqueeze(0) 86 | cross_product = torch.cross( 87 | edge.expand(neighbor_edges.shape), neighbor_edges, dim=-1 88 | ) 89 | t = torch.sum(cross_product * neighbor_edges, dim=-1) / torch.sum( 90 | cross_product * edge.expand(neighbor_edges.shape), dim=-1 91 | ) 92 | u = torch.sum( 93 | cross_product * edges[i].unsqueeze(0), dim=-1 94 | ) / torch.sum(cross_product * edge.expand(neighbor_edges.shape), dim=-1) 95 | mask = (t >= 0) & (t <= 1) & (u >= 0) & (u <= 1) 96 | crossings[i] += mask.sum() 97 | 98 | return crossings 99 | 100 | def calculate_loss( 101 | self, crossings: torch.Tensor, face_probs: torch.Tensor 102 | ) -> torch.Tensor: 103 | # Weighted sum of crossings by triangle probabilities 104 | num_faces = face_probs.shape[0] 105 | return torch.sum(face_probs * crossings, dtype=torch.float32) / num_faces 106 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/losses/overlapping_triangles_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class OverlappingTrianglesLoss(nn.Module): 6 | def __init__(self, num_samples: int = 10, k: int = 5): 7 | """ 8 | Initializes the OverlappingTrianglesLoss. 9 | 10 | Args: 11 | num_samples (int): The number of points to sample from each triangle. 12 | k (int): The number of nearest triangles to consider for overlap checking. 
13 | """ 14 | super().__init__() 15 | self.num_samples = num_samples # Number of points to sample from each triangle 16 | self.k = k # Number of nearest triangles to consider 17 | 18 | def forward(self, vertices: torch.Tensor, faces: torch.Tensor): 19 | 20 | # If no faces, return zero loss 21 | if faces.shape[0] == 0: 22 | return torch.tensor(0.0, device=vertices.device) 23 | 24 | # 1. Sample points from each triangle 25 | sampled_points, point_face_indices = self.sample_points_from_triangles( 26 | vertices, faces 27 | ) 28 | 29 | # 2. Find k-nearest triangles for each point 30 | nearest_triangles = self.find_nearest_triangles(sampled_points, vertices, faces) 31 | 32 | # 3. Detect overlaps and calculate the loss 33 | overlap_penalty = self.calculate_overlap_loss( 34 | sampled_points, vertices, faces, nearest_triangles, point_face_indices 35 | ) 36 | 37 | return overlap_penalty 38 | 39 | def sample_points_from_triangles( 40 | self, vertices: torch.Tensor, faces: torch.Tensor 41 | ) -> (torch.Tensor, torch.Tensor): 42 | """ 43 | Samples points from each triangle in the mesh. 44 | 45 | Args: 46 | vertices (torch.Tensor): The vertex positions (V x 3). 47 | faces (torch.Tensor): The indices of the vertices that make up each triangle (F x 3). 48 | 49 | Returns: 50 | torch.Tensor: Sampled points (F * num_samples x 3). 51 | torch.Tensor: index mapping from points to their original faces. 
52 | """ 53 | # Get vertices for all faces at once 54 | v0, v1, v2 = vertices[faces].unbind(1) 55 | 56 | # Generate random barycentric coordinates 57 | rand_shape = (faces.shape[0], self.num_samples, 1) 58 | u = torch.rand(rand_shape, device=vertices.device) 59 | v = torch.rand(rand_shape, device=vertices.device) 60 | 61 | # Adjust coordinates that sum > 1 62 | mask = (u + v) > 1 63 | u = torch.where(mask, 1 - u, u) 64 | v = torch.where(mask, 1 - v, v) 65 | w = 1 - u - v 66 | 67 | # Calculate the coordinates of the sampled points 68 | points = v0.unsqueeze(1) * w + v1.unsqueeze(1) * u + v2.unsqueeze(1) * v 69 | 70 | # Create index mapping from points to their original faces 71 | point_face_indices = torch.arange(faces.shape[0], device=vertices.device) 72 | point_face_indices = point_face_indices.repeat_interleave(self.num_samples) 73 | 74 | # Reshape to a (F * num_samples x 3) tensor 75 | points = points.reshape(-1, 3) 76 | 77 | return points, point_face_indices 78 | 79 | def find_nearest_triangles( 80 | self, sampled_points: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor 81 | ) -> torch.Tensor: 82 | """ 83 | Finds the k-nearest triangles for each sampled point. 84 | 85 | Args: 86 | sampled_points (torch.Tensor): Sampled points from triangles (N x 3). 87 | vertices (torch.Tensor): The vertex positions (V x 3). 88 | faces (torch.Tensor): The indices of the vertices that make up each triangle (F x 3). 89 | 90 | Returns: 91 | torch.Tensor: Indices of the k-nearest triangles for each sampled point (N x k). 
92 | """ 93 | # Compute triangle centroids 94 | centroids = vertices[faces].mean(dim=1) 95 | 96 | # Adjust k to be no larger than the number of triangles 97 | k = min(self.k, faces.shape[0]) 98 | if k == 0: 99 | # Return empty tensor if no triangles 100 | return torch.empty( 101 | (sampled_points.shape[0], 0), 102 | dtype=torch.long, 103 | device=sampled_points.device, 104 | ) 105 | 106 | # Use knn to find nearest triangles for each sampled point 107 | distances = torch.cdist(sampled_points, centroids) 108 | _, indices = distances.topk(k, dim=1, largest=False) 109 | 110 | return indices 111 | 112 | def calculate_overlap_loss( 113 | self, 114 | sampled_points: torch.Tensor, 115 | vertices: torch.Tensor, 116 | faces: torch.Tensor, 117 | nearest_triangles: torch.Tensor, 118 | point_face_indices: torch.Tensor, 119 | ) -> torch.Tensor: 120 | """ 121 | Calculates the overlap loss by checking if sampled points belong to multiple triangles. 122 | 123 | Args: 124 | sampled_points (torch.Tensor): Sampled points from triangles (N x 3). 125 | vertices (torch.Tensor): The vertex positions (V x 3). 126 | faces (torch.Tensor): The indices of the vertices that make up each triangle (F x 3). 127 | nearest_triangles (torch.Tensor): Indices of the k-nearest triangles for each sampled point (N x k). 128 | point_face_indices (torch.Tensor): Index mapping from points to their original faces. 129 | 130 | Returns: 131 | torch.Tensor: The overlap penalty loss. 
132 | """ 133 | 134 | # Reshape for broadcasting 135 | points_expanded = sampled_points.unsqueeze(1) # [N, 1, 3] 136 | nearest_faces = faces[nearest_triangles] # [N, K, 3] 137 | 138 | # Get vertices for all nearest triangles 139 | v0 = vertices[nearest_faces[..., 0]] # [N, K, 3] 140 | v1 = vertices[nearest_faces[..., 1]] 141 | v2 = vertices[nearest_faces[..., 2]] 142 | 143 | # Calculate edges 144 | edge1 = v1 - v0 # [N, K, 3] 145 | edge2 = v2 - v0 146 | 147 | # Calculate normals 148 | normals = torch.linalg.cross(edge1, edge2) # [N, K, 3] 149 | normal_lengths = torch.norm(normals, dim=2, keepdim=True) 150 | normals = normals / (normal_lengths + 1e-8) 151 | 152 | del edge1, edge2, normal_lengths 153 | 154 | # Calculate barycentric coordinates for all points at once 155 | p_v0 = points_expanded - v0 # [N, K, 3] 156 | 157 | # Compute dot products for barycentric coordinates 158 | dot00 = torch.sum(normals * normals, dim=2) # [N, K] 159 | dot01 = torch.sum(normals * (v1 - v0), dim=2) 160 | dot02 = torch.sum(normals * (v2 - v0), dim=2) 161 | dot0p = torch.sum(normals * p_v0, dim=2) 162 | 163 | del p_v0, normals 164 | 165 | # Calculate barycentric coordinates 166 | denom = dot00 * dot00 - dot01 * dot01 167 | u = (dot00 * dot0p - dot01 * dot02) / (denom + 1e-8) 168 | v = (dot00 * dot02 - dot01 * dot0p) / (denom + 1e-8) 169 | 170 | del dot00, dot01, dot02, dot0p, denom 171 | 172 | # Check if points are inside triangles 173 | inside_mask = (u >= 0) & (v >= 0) & (u + v <= 1) 174 | 175 | # Don't count overlap with source triangle 176 | source_mask = nearest_triangles == point_face_indices.unsqueeze(1) 177 | inside_mask = inside_mask & ~source_mask 178 | 179 | # Calculate areas only for inside points 180 | areas = torch.zeros_like(inside_mask, dtype=torch.float32) 181 | where_inside = torch.where(inside_mask) 182 | 183 | if where_inside[0].numel() > 0: 184 | # Calculate areas only for points inside triangles 185 | relevant_v0 = v0[where_inside] 186 | relevant_v1 = 
v1[where_inside] 187 | relevant_v2 = v2[where_inside] 188 | 189 | # Calculate areas using cross product 190 | cross_prod = torch.linalg.cross( 191 | relevant_v1 - relevant_v0, relevant_v2 - relevant_v0 192 | ) 193 | areas[where_inside] = 0.5 * torch.norm(cross_prod, dim=1) 194 | 195 | del cross_prod, relevant_v0, relevant_v1, relevant_v2 196 | 197 | del v0, v1, v2, inside_mask, source_mask 198 | 199 | # Sum up the overlap penalty 200 | overlap_penalty = areas.sum() 201 | 202 | return overlap_penalty 203 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/losses/surface_distance_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_cluster import knn 4 | 5 | 6 | class ProbabilisticSurfaceDistanceLoss(nn.Module): 7 | def __init__(self, k: int = 3, num_samples: int = 100, epsilon: float = 1e-8): 8 | super().__init__() 9 | self.k = k 10 | self.num_samples = num_samples 11 | self.epsilon = epsilon 12 | 13 | def forward( 14 | self, 15 | original_vertices: torch.Tensor, 16 | original_faces: torch.Tensor, 17 | simplified_vertices: torch.Tensor, 18 | simplified_faces: torch.Tensor, 19 | face_probabilities: torch.Tensor, 20 | ) -> torch.Tensor: 21 | if original_vertices.shape[0] == 0 or simplified_vertices.shape[0] == 0: 22 | return torch.tensor(0.0, device=original_vertices.device) 23 | 24 | # Pad face probabilities once for both terms 25 | face_probabilities = torch.nn.functional.pad( 26 | face_probabilities, 27 | (0, max(0, simplified_faces.shape[0] - face_probabilities.shape[0])), 28 | )[: simplified_faces.shape[0]] 29 | 30 | forward_term = self.compute_forward_term( 31 | original_vertices, 32 | original_faces, 33 | simplified_vertices, 34 | simplified_faces, 35 | face_probabilities, 36 | ) 37 | 38 | reverse_term = self.compute_reverse_term( 39 | original_vertices, 40 | original_faces, 41 | simplified_vertices, 42 | 
simplified_faces, 43 | face_probabilities, 44 | ) 45 | 46 | total_loss = forward_term + reverse_term 47 | return total_loss 48 | 49 | def compute_forward_term( 50 | self, 51 | original_vertices: torch.Tensor, 52 | original_faces: torch.Tensor, 53 | simplified_vertices: torch.Tensor, 54 | simplified_faces: torch.Tensor, 55 | face_probabilities: torch.Tensor, 56 | ) -> torch.Tensor: 57 | # If there are no faces, return zero loss 58 | if simplified_faces.shape[0] == 0: 59 | return torch.tensor(0.0, device=original_vertices.device) 60 | 61 | simplified_barycenters = self.compute_barycenters( 62 | simplified_vertices, simplified_faces 63 | ) 64 | original_barycenters = self.compute_barycenters( 65 | original_vertices, original_faces 66 | ) 67 | 68 | distances = self.compute_squared_distances( 69 | simplified_barycenters, original_barycenters 70 | ) 71 | 72 | min_distances, _ = distances.min(dim=1) 73 | 74 | del distances # Free memory 75 | 76 | # Compute total loss with probability penalty 77 | total_loss = (face_probabilities * min_distances).sum() 78 | probability_penalty = 1e-4 * (1.0 - face_probabilities).sum() 79 | 80 | del min_distances # Free memory 81 | 82 | return total_loss + probability_penalty 83 | 84 | def compute_reverse_term( 85 | self, 86 | original_vertices: torch.Tensor, 87 | original_faces: torch.Tensor, 88 | simplified_vertices: torch.Tensor, 89 | simplified_faces: torch.Tensor, 90 | face_probabilities: torch.Tensor, 91 | ) -> torch.Tensor: 92 | # If there are no faces, return zero loss 93 | if simplified_faces.shape[0] == 0: 94 | return torch.tensor(0.0, device=original_vertices.device) 95 | 96 | # If meshes are identical, reverse term should be zero 97 | if torch.equal(original_vertices, simplified_vertices) and torch.equal( 98 | original_faces, simplified_faces 99 | ): 100 | return torch.tensor(0.0, device=original_vertices.device) 101 | 102 | # Step 1: Sample points from the simplified mesh 103 | sampled_points = 
self.sample_points_from_triangles( 104 | simplified_vertices, simplified_faces, self.num_samples 105 | ) 106 | 107 | # Step 2: Compute the minimum distance from each sampled point to the original mesh 108 | distances = self.compute_min_distances_to_original( 109 | sampled_points, original_vertices 110 | ) 111 | 112 | # Normalize and scale distances 113 | max_dist = distances.max() + self.epsilon 114 | scaled_distances = (distances / max_dist) * 0.1 115 | 116 | del distances # Free memory 117 | 118 | # Reshape face probabilities to match the sampled points 119 | face_probs_expanded = face_probabilities.repeat_interleave(self.num_samples) 120 | 121 | # Compute weighted distances 122 | reverse_term = (face_probs_expanded * scaled_distances).sum() 123 | 124 | return reverse_term 125 | 126 | def sample_points_from_triangles( 127 | self, vertices: torch.Tensor, faces: torch.Tensor, num_samples: int 128 | ) -> torch.Tensor: 129 | """Vectorized point sampling from triangles""" 130 | num_faces = faces.shape[0] 131 | face_vertices = vertices[faces] 132 | 133 | # Generate random values for all samples at once 134 | sqrt_r1 = torch.sqrt( 135 | torch.rand(num_faces, num_samples, 1, device=vertices.device) 136 | ) 137 | r2 = torch.rand(num_faces, num_samples, 1, device=vertices.device) 138 | 139 | # Compute barycentric coordinates 140 | a = 1 - sqrt_r1 141 | b = sqrt_r1 * (1 - r2) 142 | c = sqrt_r1 * r2 143 | 144 | # Compute samples using broadcasting 145 | samples = ( 146 | a * face_vertices[:, None, 0] 147 | + b * face_vertices[:, None, 1] 148 | + c * face_vertices[:, None, 2] 149 | ) 150 | 151 | del a, b, c, sqrt_r1, r2, face_vertices # Free memory 152 | 153 | return samples.reshape(-1, 3) 154 | 155 | def compute_min_distances_to_original( 156 | self, sampled_points: torch.Tensor, target_vertices: torch.Tensor 157 | ) -> torch.Tensor: 158 | """Efficient batch distance computation using KNN""" 159 | # Convert to float32 for KNN 160 | sp_float = sampled_points.float() 161 | 
tv_float = target_vertices.float() 162 | 163 | # Compute KNN distances 164 | distances, _ = knn(tv_float, sp_float, k=1) 165 | 166 | del sp_float, tv_float # Free memory 167 | 168 | return distances.view(-1).float() 169 | 170 | @staticmethod 171 | def compute_squared_distances( 172 | points1: torch.Tensor, points2: torch.Tensor 173 | ) -> torch.Tensor: 174 | """Compute squared distances efficiently using torch.cdist""" 175 | return torch.cdist(points1, points2, p=2).float() 176 | 177 | def compute_barycenters( 178 | self, vertices: torch.Tensor, faces: torch.Tensor 179 | ) -> torch.Tensor: 180 | return vertices[faces].mean(dim=1) 181 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/losses/triangle_collision_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class TriangleCollisionLoss(nn.Module): 6 | def __init__( 7 | self, epsilon=1e-8, k=50, collision_threshold=1e-10, normal_threshold=0.99 8 | ): 9 | super().__init__() 10 | self.epsilon = epsilon 11 | self.k = k 12 | self.collision_threshold = collision_threshold 13 | self.normal_threshold = normal_threshold 14 | 15 | def forward(self, vertices, faces, face_probabilities): 16 | num_faces = faces.shape[0] 17 | 18 | if num_faces == 0: 19 | return torch.tensor(0.0, device=vertices.device) 20 | 21 | # Ensure face_probabilities matches the number of faces 22 | face_probabilities = torch.nn.functional.pad( 23 | face_probabilities, (0, max(0, num_faces - face_probabilities.shape[0])) 24 | )[:num_faces] 25 | 26 | v0, v1, v2 = vertices[faces].unbind(1) 27 | 28 | # Calculate face normals more efficiently 29 | edges1 = v1 - v0 30 | edges2 = v2 - v0 31 | face_normals = torch.linalg.cross(edges1, edges2) 32 | face_normal_lengths = torch.norm(face_normals, dim=1, keepdim=True) 33 | face_normals = face_normals / (face_normal_lengths + self.epsilon) 34 | 35 | del edges1, 
edges2, face_normal_lengths # No longer needed 36 | 37 | # Calculate centroids 38 | centroids = (v0 + v1 + v2) / 3 39 | 40 | # Find k nearest neighbors using squared distances 41 | diffs = centroids.unsqueeze(1) - centroids.unsqueeze(0) 42 | distances = torch.sum(diffs * diffs, dim=-1) 43 | del diffs # Large tensor no longer needed 44 | 45 | k = min(self.k, num_faces - 1) 46 | _, neighbors = torch.topk(distances, k=k + 1, largest=False) 47 | del distances # Large matrix no longer needed 48 | neighbors = neighbors[:, 1:] 49 | 50 | collision_count = torch.zeros(num_faces, device=vertices.device) 51 | 52 | for i in range(num_faces): 53 | nearby_faces = neighbors[i] 54 | nearby_v0, nearby_v1, nearby_v2 = ( 55 | v0[nearby_faces], 56 | v1[nearby_faces], 57 | v2[nearby_faces], 58 | ) 59 | 60 | collisions = self.check_triangle_intersection( 61 | v0[i], 62 | v1[i], 63 | v2[i], 64 | face_normals[i], 65 | nearby_v0, 66 | nearby_v1, 67 | nearby_v2, 68 | face_normals[nearby_faces], 69 | faces[i], 70 | faces[nearby_faces], 71 | centroids[i], 72 | centroids[nearby_faces], 73 | ) 74 | collision_count[i] += collisions.sum() 75 | 76 | total_loss = torch.sum(face_probabilities * collision_count) 77 | return total_loss 78 | 79 | def check_triangle_intersection( 80 | self, 81 | v0, 82 | v1, 83 | v2, 84 | normal, 85 | nearby_v0, 86 | nearby_v1, 87 | nearby_v2, 88 | nearby_normals, 89 | face, 90 | nearby_faces, 91 | centroid, 92 | nearby_centroids, 93 | ): 94 | # Check if triangles are coplanar (relaxed condition) 95 | normal_dot = torch.abs(torch.sum(normal * nearby_normals, dim=1)) 96 | coplanar = normal_dot > self.normal_threshold 97 | 98 | # Check for triangle-triangle intersections 99 | intersections = torch.zeros( 100 | len(nearby_faces), dtype=torch.bool, device=v0.device 101 | ) 102 | 103 | for i, (nv0, nv1, nv2) in enumerate(zip(nearby_v0, nearby_v1, nearby_v2)): 104 | if coplanar[i]: 105 | # For coplanar triangles, check distance between centroids 106 | dist = 
torch.norm(centroid - nearby_centroids[i]) 107 | intersections[i] = dist < self.collision_threshold 108 | else: 109 | intersections[i] = self.check_triangle_triangle_intersection( 110 | v0, v1, v2, nv0, nv1, nv2 111 | ) 112 | 113 | # Check if triangles are adjacent 114 | adjacent = torch.tensor( 115 | [len(set(face.tolist()) & set(nf.tolist())) >= 2 for nf in nearby_faces], 116 | dtype=torch.bool, 117 | device=v0.device, 118 | ) 119 | 120 | collisions = intersections & ~adjacent 121 | 122 | return collisions 123 | 124 | def check_triangle_triangle_intersection(self, v0, v1, v2, w0, w1, w2): 125 | def triangle_plane_intersection(t0, t1, t2, p0, p1, p2): 126 | n = torch.linalg.cross(t1 - t0, t2 - t0) 127 | d0, d1, d2 = ( 128 | self.dist_dot(p0, t0, n), 129 | self.dist_dot(p1, t0, n), 130 | self.dist_dot(p2, t0, n), 131 | ) 132 | return (d0 * d1 <= 0) or (d0 * d2 <= 0) or (d1 * d2 <= 0) 133 | 134 | return triangle_plane_intersection( 135 | v0, v1, v2, w0, w1, w2 136 | ) and triangle_plane_intersection(w0, w1, w2, v0, v1, v2) 137 | 138 | @staticmethod 139 | def dist_dot(p, t0, n): 140 | return torch.dot(p - t0, n) 141 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .normal_consistency import normal_consistency 2 | from .chamfer_distance import chamfer_distance 3 | from .edge_preservation import edge_preservation 4 | from .hausdorff_distance import hausdorff_distance 5 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/metrics/chamfer_distance.py: -------------------------------------------------------------------------------- 1 | import trimesh 2 | import numpy as np 3 | 4 | 5 | def chamfer_distance( 6 | mesh1: trimesh.Trimesh, mesh2: trimesh.Trimesh, samples: int = 10000 7 | ): 8 | """ 9 | Calculate the Chamfer distance between 
two meshes. 10 | 11 | Parameters: 12 | mesh1 (trimesh.Trimesh): The first input mesh. 13 | mesh2 (trimesh.Trimesh): The second input mesh. 14 | samples (int): The number of samples to use for the calculation. 15 | 16 | Returns: 17 | float: The Chamfer distance metric 18 | """ 19 | points1 = mesh1.sample(samples) 20 | points2 = mesh2.sample(samples) 21 | 22 | _, distances1, _ = trimesh.proximity.closest_point(mesh2, points1) 23 | _, distances2, _ = trimesh.proximity.closest_point(mesh1, points2) 24 | 25 | chamfer_dist = np.mean(distances1) + np.mean(distances2) 26 | 27 | return chamfer_dist 28 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/metrics/edge_preservation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial import cKDTree 3 | 4 | 5 | def calculate_dihedral_angles(mesh): 6 | """Calculate dihedral angles for all edges in the mesh.""" 7 | face_adjacency = mesh.face_adjacency 8 | face_adjacency_angles = mesh.face_adjacency_angles 9 | 10 | edge_to_angle = {} 11 | for (face1, face2), angle in zip(face_adjacency, face_adjacency_angles): 12 | edge = tuple(sorted(set(mesh.faces[face1]) & set(mesh.faces[face2]))) 13 | edge_to_angle[edge] = angle 14 | 15 | dihedral_angles = np.zeros(len(mesh.edges)) 16 | for i, edge in enumerate(mesh.edges): 17 | edge = tuple(sorted(edge)) 18 | dihedral_angles[i] = edge_to_angle.get(edge, 0) # 0 for boundary edges 19 | 20 | return dihedral_angles 21 | 22 | 23 | def edge_preservation( 24 | original_mesh, simplified_mesh, angle_threshold=30, important_edge_factor=2.0 25 | ): 26 | """ 27 | Calculate the edge preservation metric between the original and simplified meshes. 28 | 29 | Parameters: 30 | original_mesh (trimesh.Trimesh): The original high-resolution mesh. 31 | simplified_mesh (trimesh.Trimesh): The simplified mesh. 
32 | angle_threshold (float): The dihedral angle threshold for important edges (in degrees). 33 | important_edge_factor (float): Factor to increase weight for important edges. 34 | 35 | Returns: 36 | float: The edge preservation metric. 37 | """ 38 | original_dihedral = calculate_dihedral_angles(original_mesh) 39 | important_original = original_dihedral > np.radians(angle_threshold) 40 | 41 | # Calculate edge midpoints 42 | original_midpoints = original_mesh.vertices[original_mesh.edges].mean(axis=1) 43 | simplified_midpoints = simplified_mesh.vertices[simplified_mesh.edges].mean(axis=1) 44 | 45 | tree = cKDTree(simplified_midpoints) 46 | 47 | # Find closest simplified edge for each original edge 48 | distances, _ = tree.query(original_midpoints) 49 | 50 | # Calculate weighted preservation 51 | weights = np.exp(original_dihedral) 52 | weights[important_original] *= important_edge_factor 53 | weighted_distances = distances * weights 54 | 55 | max_distance = np.max(original_mesh.bounds) - np.min(original_mesh.bounds) 56 | preservation_metric = 1 - np.mean(weighted_distances) / max_distance 57 | 58 | return preservation_metric 59 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/metrics/hausdorff_distance.py: -------------------------------------------------------------------------------- 1 | import trimesh 2 | import numpy as np 3 | 4 | 5 | def hausdorff_distance( 6 | mesh1: trimesh.Trimesh, mesh2: trimesh.Trimesh, samples: int = 10000 7 | ): 8 | """ 9 | Calculate the Hausdorff distance between two meshes. 10 | 11 | Parameters: 12 | mesh1 (trimesh.Trimesh): The first input mesh. 13 | mesh2 (trimesh.Trimesh): The second input mesh. 14 | samples (int): The number of samples to use for the calculation. 15 | 16 | Returns: 17 | float: The Hausdorff distance between the two meshes. 
18 | """ 19 | # Sample points from both meshes 20 | points1 = mesh1.sample(samples) 21 | points2 = mesh2.sample(samples) 22 | 23 | # Calculate distances from points1 to mesh2 24 | _, distances1, _ = trimesh.proximity.closest_point(mesh2, points1) 25 | 26 | # Calculate distances from points2 to mesh1 27 | _, distances2, _ = trimesh.proximity.closest_point(mesh1, points2) 28 | 29 | # Hausdorff distance is the maximum of all minimum distances 30 | return max(np.max(distances1), np.max(distances2)) 31 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/metrics/normal_consistency.py: -------------------------------------------------------------------------------- 1 | import trimesh 2 | import numpy as np 3 | 4 | 5 | def normal_consistency(mesh, samples=10000): 6 | """ 7 | Calculate the normal vector consistency of a mesh. 8 | 9 | Parameters: 10 | mesh (trimesh.Trimesh): The input mesh. 11 | samples (int): The number of samples to use for the calculation. 12 | 13 | Returns: 14 | float: The normal vector consistency metric. 
15 | """ 16 | points = mesh.sample(samples) 17 | 18 | _, _, face_indices = trimesh.proximity.closest_point(mesh, points) 19 | face_normals = mesh.face_normals[face_indices] 20 | 21 | mesh.vertex_normals 22 | 23 | closest_vertices = mesh.nearest.vertex(points)[1] 24 | vertex_normals = mesh.vertex_normals[closest_vertices] 25 | 26 | consistency = np.abs(np.sum(vertex_normals * face_normals, axis=1)) 27 | 28 | return np.mean(consistency) 29 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .point_sampler import PointSampler 2 | from .edge_predictor import EdgePredictor 3 | from .face_classifier import FaceClassifier 4 | from .neural_mesh_simplification import NeuralMeshSimplification 5 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/models/edge_predictor.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch_geometric.nn import knn_graph 6 | from torch_scatter import scatter_softmax 7 | from torch_sparse import SparseTensor 8 | 9 | from .layers.devconv import DevConv 10 | 11 | warnings.filterwarnings("ignore", message="Sparse CSR tensor support is in beta state") 12 | 13 | 14 | class EdgePredictor(nn.Module): 15 | def __init__(self, in_channels, hidden_channels, k): 16 | super(EdgePredictor, self).__init__() 17 | self.k = k 18 | self.devconv = DevConv(in_channels, hidden_channels) 19 | 20 | # Self-attention components 21 | self.W_q = nn.Linear(hidden_channels, hidden_channels, bias=False) 22 | self.W_k = nn.Linear(hidden_channels, hidden_channels, bias=False) 23 | 24 | def forward(self, x, edge_index): 25 | if edge_index.numel() == 0: 26 | raise ValueError("Edge index is empty") 27 | 28 | # Step 1: Extend original mesh 
connectivity with k-nearest neighbors 29 | knn_edges = knn_graph(x, k=self.k, flow="target_to_source") 30 | 31 | # Ensure knn_edges indices are within bounds 32 | max_idx = x.size(0) - 1 33 | valid_edges = (knn_edges[0] <= max_idx) & (knn_edges[1] <= max_idx) 34 | knn_edges = knn_edges[:, valid_edges] 35 | 36 | # Combine original edges with knn edges 37 | if edge_index.numel() > 0: 38 | extended_edges = torch.cat([edge_index, knn_edges], dim=1) 39 | # Remove duplicate edges 40 | extended_edges = torch.unique(extended_edges, dim=1) 41 | else: 42 | extended_edges = knn_edges 43 | 44 | # Step 2: Apply DevConv 45 | features = self.devconv(x, extended_edges) 46 | 47 | # Step 3: Apply sparse self-attention 48 | attention_scores = self.compute_attention_scores(features, edge_index) 49 | 50 | # Step 4: Compute simplified adjacency matrix 51 | simplified_adj_indices, simplified_adj_values = ( 52 | self.compute_simplified_adjacency(attention_scores, edge_index) 53 | ) 54 | 55 | return simplified_adj_indices, simplified_adj_values 56 | 57 | def compute_attention_scores(self, features, edges): 58 | if edges.numel() == 0: 59 | raise ValueError("Edge index is empty") 60 | 61 | row, col = edges 62 | q = self.W_q(features) 63 | k = self.W_k(features) 64 | 65 | # Compute (W_q f_j)^T (W_k f_i) 66 | attention = (q[row] * k[col]).sum(dim=-1) 67 | 68 | # Apply softmax for each source node 69 | attention_scores = scatter_softmax(attention, row, dim=0) 70 | 71 | return attention_scores 72 | 73 | def compute_simplified_adjacency(self, attention_scores, edge_index): 74 | if edge_index.numel() == 0: 75 | raise ValueError("Edge index is empty") 76 | 77 | num_nodes = edge_index.max().item() + 1 78 | row, col = edge_index 79 | 80 | # Ensure indices are within bounds 81 | if row.numel() > 0: 82 | assert torch.all(row < num_nodes) and torch.all( 83 | row >= 0 84 | ), f"Row indices out of bounds: min={row.min()}, max={row.max()}, num_nodes={num_nodes}" 85 | if col.numel() > 0: 86 | assert 
torch.all(col < num_nodes) and torch.all( 87 | col >= 0 88 | ), f"Column indices out of bounds: min={col.min()}, max={col.max()}, num_nodes={num_nodes}" 89 | 90 | # Create sparse attention matrix 91 | S = SparseTensor( 92 | row=row, 93 | col=col, 94 | value=attention_scores, 95 | sparse_sizes=(num_nodes, num_nodes), 96 | trust_data=True, # Since we verified the indices above 97 | ) 98 | 99 | # Create original adjacency matrix 100 | A = SparseTensor( 101 | row=row, 102 | col=col, 103 | value=torch.ones(edge_index.size(1), device=edge_index.device), 104 | sparse_sizes=(num_nodes, num_nodes), 105 | trust_data=True, # Since we verified the indices above 106 | ) 107 | 108 | # Compute A_s = S * A * S^T using coalesced sparse tensors 109 | A_s = S.matmul(A).matmul(S.t()) 110 | 111 | # Convert to COO format 112 | row, col, value = A_s.coo() 113 | indices = torch.stack([row, col], dim=0) 114 | 115 | return indices, value 116 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/models/face_classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .layers import TriConv 5 | 6 | 7 | class FaceClassifier(nn.Module): 8 | def __init__(self, input_dim, hidden_dim, num_layers, k): 9 | super(FaceClassifier, self).__init__() 10 | self.k = k 11 | self.num_layers = num_layers 12 | 13 | self.triconv_layers = nn.ModuleList( 14 | [ 15 | TriConv(input_dim if i == 0 else hidden_dim, hidden_dim) 16 | for i in range(num_layers) 17 | ] 18 | ) 19 | 20 | self.final_layer = nn.Linear(hidden_dim, 1) 21 | 22 | def forward(self, x, pos, batch=None): 23 | # Handle empty input 24 | if x.size(0) == 0 or pos.size(0) == 0: 25 | return torch.tensor([], device=x.device) 26 | 27 | # If pos is 3D (num_faces, 3, 3), compute centroids 28 | if pos.dim() == 3: 29 | pos = pos.mean(dim=1) # Average vertex positions to get face centers 30 | 31 | # Construct k-nn 
graph based on triangle centers 32 | edge_index = self.custom_knn_graph(pos, self.k, batch) 33 | 34 | # Apply TriConv layers 35 | for i in range(self.num_layers): 36 | x = self.triconv_layers[i](x, pos, edge_index) 37 | x = torch.relu(x) 38 | 39 | # Final classification 40 | x = self.final_layer(x) 41 | logits = x.squeeze(-1) # Remove last dimension 42 | 43 | # Apply softmax normalization per batch 44 | if batch is None: 45 | # Global normalization using softmax 46 | probs = torch.softmax(logits, dim=0) 47 | else: 48 | # Per-batch normalization 49 | probs = torch.zeros_like(logits) 50 | for b in range(int(batch.max().item()) + 1): 51 | mask = batch == b 52 | probs[mask] = torch.softmax(logits[mask], dim=0) 53 | 54 | return probs 55 | 56 | def custom_knn_graph(self, x, k, batch=None): 57 | if x.size(0) == 0: 58 | return torch.empty((2, 0), dtype=torch.long, device=x.device) 59 | 60 | batch_size = 1 if batch is None else int(batch.max().item()) + 1 61 | edge_index = [] 62 | 63 | for b in range(batch_size): 64 | if batch is None: 65 | x_batch = x 66 | else: 67 | mask = batch == b 68 | x_batch = x[mask] 69 | 70 | if x_batch.size(0) > 1: 71 | distances = torch.cdist(x_batch, x_batch) 72 | distances.fill_diagonal_(float("inf")) 73 | _, indices = distances.topk(min(k, x_batch.size(0) - 1), largest=False) 74 | 75 | source = ( 76 | torch.arange(x_batch.size(0), device=x.device) 77 | .view(-1, 1) 78 | .expand(-1, indices.size(1)) 79 | ) 80 | edge_index.append( 81 | torch.stack([source.reshape(-1), indices.reshape(-1)]) 82 | ) 83 | 84 | if edge_index: 85 | edge_index = torch.cat(edge_index, dim=1) 86 | 87 | # Make the graph symmetric 88 | edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1) 89 | edge_index = torch.unique(edge_index, dim=1) 90 | else: 91 | edge_index = torch.empty((2, 0), dtype=torch.long, device=x.device) 92 | 93 | return edge_index 94 | -------------------------------------------------------------------------------- 
/src/neural_mesh_simplification/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .devconv import DevConv 2 | from .triconv import TriConv 3 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/models/layers/devconv.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch_scatter import scatter_max 3 | 4 | 5 | class DevConv(nn.Module): 6 | def __init__(self, in_channels, out_channels): 7 | super(DevConv, self).__init__() 8 | self.W_theta = nn.Linear(in_channels, out_channels) 9 | self.W_phi = nn.Linear(in_channels, out_channels) 10 | 11 | def forward(self, x, edge_index): 12 | row, col = edge_index 13 | x_i, x_j = x[row], x[col] 14 | 15 | rel_pos = x_i - x_j 16 | rel_pos_transformed = self.W_theta(rel_pos) # [num_edges, out_channels] 17 | 18 | x_transformed = self.W_phi(x) # [num_nodes, out_channels] 19 | 20 | # Aggregate using max pooling 21 | aggr_out = scatter_max(rel_pos_transformed, col, dim=0, dim_size=x.size(0))[0] 22 | 23 | return x_transformed + aggr_out 24 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/models/layers/triconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_scatter import scatter_max, scatter_add 4 | 5 | 6 | class TriConv(nn.Module): 7 | def __init__(self, in_channels, out_channels): 8 | super(TriConv, self).__init__() 9 | self.in_channels = in_channels 10 | self.out_channels = out_channels 11 | 12 | # Calculate the correct input dimension for the MLP 13 | mlp_input_dim = in_channels + 9 # 9 is from the relative position encoding 14 | 15 | self.mlp = nn.Sequential( 16 | nn.Linear(mlp_input_dim, out_channels), 17 | nn.ReLU(), 18 | nn.Linear(out_channels, out_channels), 19 | ) 20 | 
self.last_edge_index = None 21 | 22 | def forward(self, x, pos, edge_index): 23 | self.last_edge_index = edge_index 24 | row, col = edge_index 25 | 26 | rel_pos = self.compute_relative_position_encoding(pos, row, col) 27 | x_diff = x[row] - x[col] 28 | mlp_input = torch.cat([rel_pos, x_diff], dim=-1) 29 | 30 | mlp_output = self.mlp(mlp_input) 31 | out = scatter_add(mlp_output, col, dim=0, dim_size=x.size(0)) 32 | 33 | return out 34 | 35 | def compute_relative_position_encoding(self, pos, row, col): 36 | edge_vec = pos[row] - pos[col] 37 | 38 | t_max, _ = scatter_max(edge_vec.abs(), col, dim=0, dim_size=pos.size(0)) 39 | t_min, _ = scatter_max(-edge_vec.abs(), col, dim=0, dim_size=pos.size(0)) 40 | t_min = -t_min 41 | 42 | barycenter = pos.mean(dim=-1, keepdim=True) if pos.dim() == 3 else pos 43 | barycenter_diff = barycenter[row] - barycenter[col] 44 | 45 | t_max_diff = t_max[row] - t_max[col] 46 | t_min_diff = t_min[row] - t_min[col] 47 | barycenter_diff = barycenter_diff.expand_as(t_max_diff) 48 | 49 | rel_pos = torch.cat([t_max_diff, t_min_diff, barycenter_diff], dim=-1) 50 | 51 | return rel_pos 52 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/models/neural_mesh_simplification.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch_geometric 4 | from torch_geometric.data import Data 5 | 6 | from ..models import PointSampler, EdgePredictor, FaceClassifier 7 | 8 | 9 | class NeuralMeshSimplification(nn.Module): 10 | def __init__( 11 | self, 12 | input_dim, 13 | hidden_dim, 14 | edge_hidden_dim, # Separate hidden dim for edge predictor 15 | num_layers, 16 | k, 17 | edge_k, 18 | target_ratio, 19 | device=torch.device("cpu"), 20 | ): 21 | super(NeuralMeshSimplification, self).__init__() 22 | self.device = device 23 | self.point_sampler = PointSampler(input_dim, hidden_dim, num_layers).to( 24 | self.device 25 | ) 
26 | self.edge_predictor = EdgePredictor( 27 | input_dim, 28 | hidden_channels=edge_hidden_dim, 29 | k=edge_k, 30 | ).to(self.device) 31 | self.face_classifier = FaceClassifier(input_dim, hidden_dim, num_layers, k).to( 32 | self.device 33 | ) 34 | self.k = k 35 | self.target_ratio = target_ratio 36 | 37 | def forward(self, data: Data): 38 | x, edge_index = data.x, data.edge_index 39 | num_nodes = x.size(0) 40 | 41 | sampled_indices, sampled_probs = self.sample_points(data) 42 | 43 | sampled_x = x[sampled_indices].to(self.device) 44 | sampled_pos = ( 45 | data.pos[sampled_indices] 46 | if hasattr(data, "pos") and data.pos is not None 47 | else sampled_x 48 | ).to(self.device) 49 | 50 | sampled_vertices = sampled_pos # Use sampled_pos directly as vertices 51 | 52 | # Update edge_index to reflect the new indices 53 | sampled_edge_index, _ = torch_geometric.utils.subgraph( 54 | sampled_indices, edge_index, relabel_nodes=True, num_nodes=num_nodes 55 | ) 56 | 57 | # Predict edges 58 | sampled_edge_index = sampled_edge_index.to(self.device) 59 | edge_index_pred, edge_probs = self.edge_predictor(sampled_x, sampled_edge_index) 60 | 61 | # Generate candidate triangles 62 | candidate_triangles, triangle_probs = self.generate_candidate_triangles( 63 | edge_index_pred, edge_probs 64 | ) 65 | 66 | # Classify faces 67 | if candidate_triangles.shape[0] > 0: 68 | # Create triangle features by averaging vertex features 69 | triangle_features = torch.zeros( 70 | (candidate_triangles.shape[0], sampled_x.shape[1]), 71 | device=self.device, 72 | ) 73 | for i in range(3): 74 | triangle_features += sampled_x[candidate_triangles[:, i]] 75 | triangle_features /= 3 76 | 77 | # Calculate triangle centers 78 | triangle_centers = torch.zeros( 79 | (candidate_triangles.shape[0], sampled_pos.shape[1]), 80 | device=self.device, 81 | ) 82 | for i in range(3): 83 | triangle_centers += sampled_pos[candidate_triangles[:, i]] 84 | triangle_centers /= 3 85 | 86 | face_probs = self.face_classifier( 87 | 
triangle_features, triangle_centers, batch=None 88 | ) 89 | else: 90 | face_probs = torch.empty(0, device=self.device) 91 | 92 | if candidate_triangles.shape[0] == 0: 93 | simplified_faces = torch.empty((0, 3), dtype=torch.long, device=self.device) 94 | else: 95 | threshold = torch.quantile( 96 | face_probs, 1 - self.target_ratio 97 | ) # Use a dynamic threshold 98 | simplified_faces = candidate_triangles[face_probs > threshold] 99 | 100 | return { 101 | "sampled_indices": sampled_indices, 102 | "sampled_probs": sampled_probs, 103 | "sampled_vertices": sampled_vertices, 104 | "edge_index": edge_index_pred, 105 | "edge_probs": edge_probs, 106 | "candidate_triangles": candidate_triangles, 107 | "triangle_probs": triangle_probs, 108 | "face_probs": face_probs, 109 | "simplified_faces": simplified_faces, 110 | } 111 | 112 | def sample_points(self, data: Data): 113 | x, edge_index = data.x, data.edge_index 114 | num_nodes = x.size(0) 115 | 116 | target_nodes = min( 117 | max(int(self.target_ratio * num_nodes), 1), 118 | num_nodes, 119 | ) 120 | 121 | # Sample points 122 | x = x.to(self.device) 123 | edge_index = edge_index.to(self.device) 124 | sampled_probs = self.point_sampler(x, edge_index) 125 | sampled_indices = self.point_sampler.sample( 126 | sampled_probs, num_samples=target_nodes 127 | ) 128 | 129 | return sampled_indices, sampled_probs[sampled_indices] 130 | 131 | def generate_candidate_triangles(self, edge_index, edge_probs): 132 | 133 | # Handle the case when edge_index is empty 134 | if edge_index.numel() == 0: 135 | return ( 136 | torch.empty((0, 3), dtype=torch.long, device=self.device), 137 | torch.empty(0, device=self.device), 138 | ) 139 | 140 | num_nodes = edge_index.max().item() + 1 141 | 142 | # Create an adjacency matrix from the edge index 143 | adj_matrix = torch.zeros(num_nodes, num_nodes, device=self.device) 144 | 145 | # Check if edge_probs is a tuple or a tensor 146 | if isinstance(edge_probs, tuple): 147 | edge_indices, edge_values = 
edge_probs 148 | adj_matrix[edge_indices[0], edge_indices[1]] = edge_values 149 | else: 150 | adj_matrix[edge_index[0], edge_index[1]] = edge_probs 151 | 152 | # Adjust k based on the number of nodes 153 | k = min(self.k, num_nodes - 1) 154 | 155 | # Find k-nearest neighbors for each node 156 | _, knn_indices = torch.topk(adj_matrix, k=k, dim=1) 157 | 158 | # Generate candidate triangles 159 | triangles = [] 160 | triangle_probs = [] 161 | 162 | for i in range(num_nodes): 163 | neighbors = knn_indices[i] 164 | for j in range(k): 165 | for l in range(j + 1, k): 166 | n1, n2 = neighbors[j], neighbors[l] 167 | if adj_matrix[n1, n2] > 0: # Check if the third edge exists 168 | triangle = torch.tensor([i, n1, n2], device=self.device) 169 | triangles.append(triangle) 170 | 171 | # Calculate triangle probability 172 | prob = ( 173 | adj_matrix[i, n1] * adj_matrix[i, n2] * adj_matrix[n1, n2] 174 | ) ** (1 / 3) 175 | triangle_probs.append(prob) 176 | 177 | if triangles: 178 | triangles = torch.stack(triangles) 179 | triangle_probs = torch.tensor(triangle_probs, device=self.device) 180 | else: 181 | triangles = torch.empty((0, 3), dtype=torch.long, device=self.device) 182 | triangle_probs = torch.empty(0, device=self.device) 183 | 184 | return triangles, triangle_probs 185 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/models/point_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .layers.devconv import DevConv 5 | 6 | 7 | class PointSampler(nn.Module): 8 | def __init__(self, in_channels, out_channels, num_layers): 9 | super(PointSampler, self).__init__() 10 | self.num_layers = num_layers 11 | 12 | # Stack of DevConv layers 13 | self.convs = nn.ModuleList() 14 | self.convs.append(DevConv(in_channels, out_channels)) 15 | for _ in range(num_layers - 1): 16 | self.convs.append(DevConv(out_channels, out_channels)) 
17 | 18 | # Final output layer to produce a single score per vertex 19 | self.output_layer = nn.Linear(out_channels, 1) 20 | 21 | # Activation functions 22 | self.activation = nn.ReLU() 23 | self.sigmoid = nn.Sigmoid() 24 | 25 | def forward(self, x, edge_index): 26 | # x: Node features [num_nodes, in_channels] 27 | # edge_index: Graph connectivity [2, num_edges] 28 | 29 | # Apply DevConv layers 30 | for conv in self.convs: 31 | x = conv(x, edge_index) 32 | x = self.activation(x) 33 | 34 | # Generate inclusion scores 35 | scores = self.output_layer(x).squeeze(-1) 36 | 37 | # Convert scores to probabilities 38 | probabilities = self.sigmoid(scores) 39 | 40 | return probabilities 41 | 42 | def sample(self, probabilities, num_samples): 43 | max_samples = probabilities.shape[0] 44 | if num_samples > max_samples: 45 | raise ValueError( 46 | f"num_samples ({num_samples}) cannot be larger than number of vertices ({max_samples})" 47 | ) 48 | 49 | # Multinomial sampling based on probabilities 50 | sampled_indices = torch.multinomial( 51 | probabilities, num_samples, replacement=False 52 | ) 53 | 54 | return sampled_indices 55 | 56 | def forward_and_sample(self, x, edge_index, num_samples): 57 | # Combine forward pass and sampling in one step 58 | probabilities = self.forward(x, edge_index) 59 | sampled_indices = self.sample(probabilities, num_samples) 60 | return sampled_indices, probabilities 61 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer 2 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/trainer/resource_monitor.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import psutil 4 | import torch 5 | 6 | 7 | def monitor_resources(stop_event, main_pid): 
8 | main_process = psutil.Process(main_pid) 9 | 10 | interval_idx = 0 11 | 12 | while not stop_event.is_set(): 13 | try: 14 | # Get CPU and memory usage for the main process 15 | cpu_percent = main_process.cpu_percent(interval=1) 16 | memory_info = main_process.memory_info() 17 | total_memory_rss = memory_info.rss # Start with main process memory usage 18 | 19 | # Include children processes 20 | for child in main_process.children(recursive=True): 21 | try: 22 | cpu_percent += child.cpu_percent(interval=None) 23 | child_memory = child.memory_info() 24 | total_memory_rss += ( 25 | child_memory.rss 26 | ) # Add child process memory usage 27 | except psutil.NoSuchProcess: 28 | pass # Child process no longer exists 29 | 30 | memory_usage_mb = total_memory_rss / (1024 * 1024) # Convert to MB 31 | 32 | interval_idx += 1 33 | 34 | output = f"\rCPU: {cpu_percent:.1f}% | Memory: {memory_usage_mb:.2f} MB" 35 | 36 | # Get GPU and GPU memory usage 37 | if torch.cuda.is_available(): 38 | gpu_info = [] 39 | for i in range(torch.cuda.device_count()): 40 | gpu_util = torch.cuda.utilization(i) # GPU Utilization 41 | mem_alloc = torch.cuda.memory_allocated(i) / ( 42 | 1024 * 1024 43 | ) # Convert to MB 44 | mem_total = torch.cuda.get_device_properties(i).total_memory / ( 45 | 1024 * 1024 46 | ) # Total VRAM 47 | gpu_info.append( 48 | f"GPU {i}: {gpu_util:.1f}% | Mem: {mem_alloc:.2f}/{mem_total:.2f} MB" 49 | ) 50 | 51 | output += " | " + " | ".join(gpu_info) 52 | 53 | if interval_idx % 3 == 0: 54 | output += "\n" 55 | 56 | print(output, end="", flush=True) 57 | 58 | except psutil.NoSuchProcess: 59 | print("\nMain process has terminated.") 60 | break 61 | time.sleep(1) 62 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from multiprocessing import Event, Process 4 | from typing import 
Dict, Any 5 | 6 | import torch 7 | from torch.optim import Adam 8 | from torch.optim.lr_scheduler import ReduceLROnPlateau 9 | from torch.utils.data import random_split 10 | from torch_geometric.loader import DataLoader 11 | 12 | from .resource_monitor import monitor_resources 13 | from ..data import MeshSimplificationDataset 14 | from ..losses import CombinedMeshSimplificationLoss 15 | from ..metrics import ( 16 | chamfer_distance, 17 | normal_consistency, 18 | edge_preservation, 19 | hausdorff_distance, 20 | ) 21 | from ..models import NeuralMeshSimplification 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class Trainer: 27 | def __init__(self, config: Dict[str, Any]): 28 | self.config = config 29 | logger.info("Initializing trainer...") 30 | 31 | if torch.cuda.is_available(): 32 | self.device = torch.device("cuda") 33 | else: 34 | self.device = torch.device("cpu") 35 | logger.info(f"Using device: {self.device}") 36 | 37 | logger.info("Initializing model...") 38 | self.model = NeuralMeshSimplification( 39 | input_dim=config["model"]["input_dim"], 40 | hidden_dim=config["model"]["hidden_dim"], 41 | edge_hidden_dim=config["model"]["edge_hidden_dim"], 42 | num_layers=config["model"]["num_layers"], 43 | k=config["model"]["k"], 44 | edge_k=config["model"]["edge_k"], 45 | target_ratio=config["model"]["target_ratio"], 46 | device=self.device, 47 | ) 48 | 49 | logger.debug("Setting up optimizer and loss...") 50 | self.optimizer = Adam( 51 | self.model.parameters(), 52 | lr=config["training"]["learning_rate"], 53 | weight_decay=config["training"]["weight_decay"], 54 | ) 55 | self.scheduler = ReduceLROnPlateau( 56 | self.optimizer, mode="min", factor=0.1, patience=10, verbose=True 57 | ) 58 | self.criterion = CombinedMeshSimplificationLoss( 59 | lambda_c=config["loss"]["lambda_c"], 60 | lambda_e=config["loss"]["lambda_e"], 61 | lambda_o=config["loss"]["lambda_o"], 62 | device=self.device, 63 | ) 64 | self.early_stopping_patience = 
config["training"]["early_stopping_patience"] 65 | self.best_val_loss = float("inf") 66 | self.early_stopping_counter = 0 67 | self.checkpoint_dir = config["training"]["checkpoint_dir"] 68 | os.makedirs(self.checkpoint_dir, exist_ok=True) 69 | 70 | logger.debug("Preparing data loaders...") 71 | self.train_loader, self.val_loader = self._prepare_data_loaders() 72 | logger.info("Trainer initialization complete.") 73 | 74 | if "monitor_resources" in config: 75 | logger.info("Monitoring resource usage") 76 | self.monitor_resources = True 77 | self.stop_event = Event() 78 | self.monitor_process = None 79 | else: 80 | self.monitor_resources = False 81 | 82 | def _prepare_data_loaders(self): 83 | logger.info(f"Loading dataset from {self.config['data']['data_dir']}") 84 | dataset = MeshSimplificationDataset( 85 | data_dir=self.config["data"]["data_dir"], 86 | preprocess=False, # Can be False, because data has been prepared before training 87 | ) 88 | logger.debug(f"Dataset size: {len(dataset)}") 89 | 90 | val_size = int(len(dataset) * self.config["data"]["val_split"]) 91 | train_size = len(dataset) - val_size 92 | logger.info(f"Splitting dataset: {train_size} train, {val_size} validation") 93 | 94 | train_dataset, val_dataset = random_split(dataset, [train_size, val_size]) 95 | assert ( 96 | len(val_dataset) > 0 97 | ), f"There is not enough data to define an evaluation set. 
len(dataset)={len(dataset)}, train_size={train_size}, val_size={val_size}" 98 | 99 | num_workers = self.config["training"].get("num_workers", os.cpu_count()) 100 | logger.info(f"Using {num_workers} workers for data loading") 101 | 102 | train_loader = DataLoader( 103 | train_dataset, 104 | batch_size=self.config["training"]["batch_size"], 105 | shuffle=True, 106 | num_workers=num_workers, 107 | follow_batch=["x", "pos"], 108 | ) 109 | 110 | val_loader = DataLoader( 111 | val_dataset, 112 | batch_size=self.config["training"]["batch_size"], 113 | shuffle=False, 114 | num_workers=num_workers, 115 | follow_batch=["x", "pos"], 116 | ) 117 | logger.info("Data loaders prepared successfully") 118 | 119 | return train_loader, val_loader 120 | 121 | def train(self): 122 | if self.monitor_resources: 123 | main_pid = os.getpid() 124 | self.monitor_process = Process( 125 | target=monitor_resources, args=(self.stop_event, main_pid) 126 | ) 127 | self.monitor_process.start() 128 | 129 | try: 130 | for epoch in range(self.config["training"]["num_epochs"]): 131 | loss = self._train_one_epoch(epoch) 132 | 133 | logging.info( 134 | f"Epoch [{epoch + 1}/{self.config['training']['num_epochs']}], Loss: {loss / len(self.train_loader)}" 135 | ) 136 | 137 | val_loss = self._validate() 138 | logging.info( 139 | f"Epoch [{epoch + 1}/{self.config['training']['num_epochs']}], Validation Loss: {val_loss}" 140 | ) 141 | 142 | self.scheduler.step(val_loss) 143 | 144 | # Save the checkpoint 145 | self._save_checkpoint(epoch, val_loss) 146 | 147 | if torch.cuda.is_available(): 148 | torch.cuda.empty_cache() 149 | 150 | # Early stop as required 151 | if self._early_stopping(val_loss): 152 | logging.info("Early stopping triggered.") 153 | break 154 | except Exception as e: 155 | logger.error(f"{str(e)}") 156 | finally: 157 | if self.monitor_resources: 158 | self.stop_event.set() 159 | self.monitor_process.join() 160 | print() # New line after monitoring output 161 | 162 | def _train_one_epoch(self, 
epoch: int) -> float: 163 | self.model.train() 164 | running_loss = 0.0 165 | logger.debug(f"Starting epoch {epoch + 1}") 166 | 167 | for batch_idx, batch in enumerate(self.train_loader): 168 | logger.debug(f"Processing batch {batch_idx + 1}") 169 | self.optimizer.zero_grad() 170 | output = self.model(batch) 171 | loss = self.criterion(batch, output) 172 | 173 | del batch 174 | del output 175 | 176 | loss.backward() 177 | self.optimizer.step() 178 | running_loss += loss.item() 179 | 180 | return running_loss / len(self.train_loader) 181 | 182 | def _validate(self) -> float: 183 | self.model.eval() 184 | val_loss = 0.0 185 | with torch.no_grad(): 186 | for batch in self.val_loader: 187 | output = self.model(batch) 188 | loss = self.criterion(batch, output) 189 | val_loss += loss.item() 190 | 191 | return val_loss / len(self.val_loader) 192 | 193 | def _save_checkpoint(self, epoch: int, val_loss: float): 194 | checkpoint_path = os.path.join( 195 | self.checkpoint_dir, f"checkpoint_epoch_{epoch + 1}.pth" 196 | ) 197 | 198 | logging.debug(f"Saving checkpoint to {checkpoint_path}") 199 | torch.save( 200 | { 201 | "epoch": epoch + 1, 202 | "model_state_dict": self.model.state_dict(), 203 | "optimizer_state_dict": self.optimizer.state_dict(), 204 | "val_loss": val_loss, 205 | }, 206 | checkpoint_path, 207 | ) 208 | if val_loss < self.best_val_loss: 209 | self.best_val_loss = val_loss 210 | best_model_path = os.path.join(self.checkpoint_dir, "best_model.pth") 211 | logging.debug(f"Saving best model to {best_model_path}") 212 | torch.save(self.model.state_dict(), best_model_path) 213 | 214 | # Remove old checkpoints to save space 215 | for old_checkpoint in os.listdir(self.checkpoint_dir): 216 | if ( 217 | old_checkpoint.startswith("checkpoint_") 218 | and old_checkpoint != f"checkpoint_epoch_{epoch + 1}.pth" 219 | ): 220 | os.remove(os.path.join(self.checkpoint_dir, old_checkpoint)) 221 | 222 | def _early_stopping(self, val_loss: float) -> bool: 223 | if val_loss < 
self.best_val_loss: 224 | self.early_stopping_counter = 0 225 | else: 226 | self.early_stopping_counter += 1 227 | return self.early_stopping_counter >= self.early_stopping_patience 228 | 229 | def load_checkpoint(self, checkpoint_path: str): 230 | checkpoint = torch.load(checkpoint_path) 231 | self.model.load_state_dict(checkpoint["model_state_dict"]) 232 | self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 233 | self.best_val_loss = checkpoint["val_loss"] 234 | logging.info( 235 | f"Loaded checkpoint from {checkpoint_path} (epoch {checkpoint['epoch']})" 236 | ) 237 | 238 | def log_metrics(self, metrics: Dict[str, float], epoch: int): 239 | log_message = f"Epoch [{epoch + 1}/{self.config['training']['num_epochs']}], " 240 | log_message += ", ".join( 241 | [f"{key}: {value:.4f}" for key, value in metrics.items()] 242 | ) 243 | logging.info(log_message) 244 | 245 | def evaluate(self, data_loader: DataLoader) -> Dict[str, float]: 246 | self.model.eval() 247 | metrics = { 248 | "chamfer_distance": 0.0, 249 | "normal_consistency": 0.0, 250 | "edge_preservation": 0.0, 251 | "hausdorff_distance": 0.0, 252 | } 253 | with torch.no_grad(): 254 | for batch in data_loader: 255 | output = self.model(batch) 256 | 257 | # TODO: Define methods that can operate on a batch instead of a trimesh object 258 | 259 | metrics["chamfer_distance"] += chamfer_distance(batch, output) 260 | metrics["normal_consistency"] += normal_consistency(batch, output) 261 | metrics["edge_preservation"] += edge_preservation(batch, output) 262 | metrics["hausdorff_distance"] += hausdorff_distance(batch, output) 263 | for key in metrics: 264 | metrics[key] /= len(data_loader) 265 | return metrics 266 | 267 | def handle_error(self, error: Exception): 268 | logging.error(f"An error occurred: {error}") 269 | if isinstance(error, RuntimeError) and "out of memory" in str(error): 270 | logging.error("Out of memory error. 
Attempting to recover.") 271 | torch.cuda.empty_cache() 272 | else: 273 | raise error 274 | 275 | def save_training_state(self, state_path: str): 276 | torch.save( 277 | { 278 | "model_state_dict": self.model.state_dict(), 279 | "optimizer_state_dict": self.optimizer.state_dict(), 280 | "best_val_loss": self.best_val_loss, 281 | "early_stopping_counter": self.early_stopping_counter, 282 | }, 283 | state_path, 284 | ) 285 | 286 | def load_training_state(self, state_path: str): 287 | state = torch.load(state_path) 288 | self.model.load_state_dict(state["model_state_dict"]) 289 | self.optimizer.load_state_dict(state["optimizer_state_dict"]) 290 | self.best_val_loss = state["best_val_loss"] 291 | self.early_stopping_counter = state["early_stopping_counter"] 292 | logging.info(f"Loaded training state from {state_path}") 293 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .mesh_operations import build_graph_from_mesh 2 | -------------------------------------------------------------------------------- /src/neural_mesh_simplification/utils/mesh_operations.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import trimesh 3 | 4 | 5 | def simplify_mesh(mesh, target_faces): 6 | # Simplify a mesh to a target number of faces 7 | pass 8 | 9 | 10 | def calculate_mesh_features(mesh): 11 | # Calculate relevant features of a mesh (e.g., curvature) 12 | pass 13 | 14 | 15 | def align_meshes(mesh1, mesh2): 16 | # Align two meshes (useful for comparison) 17 | pass 18 | 19 | 20 | def compare_meshes(mesh1, mesh2): 21 | # Compare two meshes (e.g., Hausdorff distance) 22 | pass 23 | 24 | 25 | def build_graph_from_mesh(mesh: trimesh.Trimesh) -> nx.Graph: 26 | """Build a graph structure from a mesh.""" 27 | G = nx.Graph() 28 | 29 | # Add nodes (vertices) 30 
| for i, vertex in enumerate(mesh.vertices): 31 | G.add_node(i, pos=vertex) 32 | 33 | # Add edges 34 | for face in mesh.faces: 35 | G.add_edge(face[0], face[1]) 36 | G.add_edge(face[1], face[2]) 37 | G.add_edge(face[2], face[0]) 38 | 39 | return G 40 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinnormark/neural-mesh-simplification/2ada79100932189cfd28a56093ba3d72e24f3ec7/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import trimesh 3 | 4 | 5 | @pytest.fixture 6 | def sample_mesh(): 7 | # Create a simple cube mesh for testing 8 | mesh = trimesh.creation.box() 9 | return mesh 10 | -------------------------------------------------------------------------------- /tests/losses/test_edge_crossings_loss.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from neural_mesh_simplification.losses import EdgeCrossingLoss 5 | 6 | 7 | @pytest.fixture 8 | def loss_fn(): 9 | return EdgeCrossingLoss(k=2) 10 | 11 | 12 | @pytest.fixture 13 | def sample_data(): 14 | vertices = torch.tensor( 15 | [ 16 | [0, 0, 0], 17 | [1, 0, 0], 18 | [0, 1, 0], 19 | [1, 1, 0], 20 | [0, 0, 1], 21 | [1, 0, 1], 22 | [0, 1, 1], 23 | [1, 1, 1], 24 | ], 25 | dtype=torch.float, 26 | ) 27 | faces = torch.tensor([[0, 1, 2], [1, 3, 2], [4, 5, 6], [5, 7, 6]], dtype=torch.long) 28 | face_probs = torch.tensor([0.8, 0.6, 0.7, 0.9], dtype=torch.float) 29 | return vertices, faces, face_probs 30 | 31 | 32 | def test_find_nearest_triangles(loss_fn): 33 | vertices = torch.tensor( 34 | [[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [1, 0, 1]], 35 | dtype=torch.float, 36 | ) 37 | faces = 
torch.tensor([[0, 1, 2], [1, 3, 2], [0, 4, 1], [1, 4, 5]], dtype=torch.long) 38 | 39 | nearest = loss_fn.find_nearest_triangles(vertices, faces) 40 | assert nearest.shape[0] == faces.shape[0] 41 | assert nearest.shape[1] == 1 # k-1 = 1, since k=2 42 | 43 | 44 | def test_detect_edge_crossings(loss_fn, sample_data): 45 | vertices, faces, _ = sample_data 46 | 47 | nearest_triangles = torch.tensor([[1], [0], [3], [2]], dtype=torch.long) 48 | crossings = loss_fn.detect_edge_crossings(vertices, faces, nearest_triangles) 49 | 50 | # Test expected number of crossings 51 | # Here, you should check for specific cases based on the vertex configuration 52 | # Example: You expect 0 crossings if triangles are separate 53 | assert torch.all(crossings == 0) # Modify this based on your actual expectations 54 | 55 | 56 | def test_calculate_loss(loss_fn, sample_data): 57 | _, _, face_probs = sample_data 58 | 59 | crossings = torch.tensor([1.0, 0.0, 2.0, 1.0], dtype=torch.float) 60 | loss = loss_fn.calculate_loss(crossings, face_probs) 61 | 62 | num_faces = face_probs.shape[0] 63 | expected_loss = torch.sum(face_probs * crossings, dtype=torch.float32) / num_faces 64 | assert torch.isclose( 65 | loss, torch.tensor(expected_loss) 66 | ), f"Expected loss: {expected_loss}, but got {loss.item()}" 67 | 68 | 69 | def test_edge_crossing_loss_full(loss_fn, sample_data): 70 | vertices, faces, face_probs = sample_data 71 | 72 | # Run the forward pass of the loss function 73 | loss = loss_fn(vertices, faces, face_probs) 74 | 75 | # Check that the loss value is as expected 76 | # Ensure expected_loss is a floating-point tensor 77 | expected_loss = torch.tensor( 78 | 0.0, dtype=torch.float 79 | ) # Modify this based on your setup 80 | assert torch.isclose( 81 | loss, expected_loss 82 | ), f"Expected loss: {expected_loss.item()}, but got {loss.item()}" 83 | -------------------------------------------------------------------------------- /tests/losses/test_overlapping_triangles_loss.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from neural_mesh_simplification.losses import OverlappingTrianglesLoss 5 | 6 | 7 | @pytest.fixture 8 | def loss_fn(): 9 | return OverlappingTrianglesLoss(num_samples=5, k=3) 10 | 11 | 12 | @pytest.fixture 13 | def sample_data(): 14 | vertices = torch.tensor( 15 | [ 16 | [0.0, 0.0, 0.0], 17 | [1.0, 0.0, 0.0], 18 | [0.0, 1.0, 0.0], 19 | [1.0, 1.0, 0.0], 20 | [0.0, 0.0, 1.0], 21 | [1.0, 0.0, 1.0], 22 | [0.0, 1.0, 1.0], 23 | [1.0, 1.0, 1.0], 24 | ], 25 | dtype=torch.float, 26 | ) 27 | faces = torch.tensor([[0, 1, 2], [1, 3, 2], [4, 5, 6], [5, 7, 6]], dtype=torch.long) 28 | return vertices, faces 29 | 30 | 31 | def test_sample_points_from_triangles(loss_fn, sample_data): 32 | vertices, faces = sample_data 33 | 34 | sampled_points, point_face_indices = loss_fn.sample_points_from_triangles( 35 | vertices, faces 36 | ) 37 | 38 | assert sampled_points.shape[0] == point_face_indices.shape[0] 39 | 40 | # Check the shape of the sampled points 41 | expected_shape = (faces.shape[0] * loss_fn.num_samples, 3) 42 | assert ( 43 | sampled_points.shape == expected_shape 44 | ), f"Expected shape {expected_shape}, got {sampled_points.shape}" 45 | 46 | # Check that points lie within the bounding box of the mesh 47 | assert torch.all( 48 | sampled_points >= vertices.min(dim=0).values 49 | ), "Sampled points are outside the mesh bounds" 50 | assert torch.all( 51 | sampled_points <= vertices.max(dim=0).values 52 | ), "Sampled points are outside the mesh bounds" 53 | 54 | 55 | def test_find_nearest_triangles(loss_fn, sample_data): 56 | vertices, faces = sample_data 57 | 58 | sampled_points, _ = loss_fn.sample_points_from_triangles(vertices, faces) 59 | nearest_triangles = loss_fn.find_nearest_triangles(sampled_points, vertices, faces) 60 | 61 | # Check the shape of the nearest triangles tensor 62 | expected_shape = (sampled_points.shape[0], loss_fn.k) 63 | assert ( 64 | 
nearest_triangles.shape == expected_shape 65 | ), f"Expected shape {expected_shape}, got {nearest_triangles.shape}" 66 | 67 | # Check that the indices are within the range of faces 68 | assert torch.all(nearest_triangles >= 0) and torch.all( 69 | nearest_triangles < faces.shape[0] 70 | ), "Invalid triangle indices" 71 | 72 | 73 | def test_calculate_overlap_loss(loss_fn, sample_data): 74 | vertices, faces = sample_data 75 | 76 | sampled_points, point_face_indices = loss_fn.sample_points_from_triangles( 77 | vertices, faces 78 | ) 79 | nearest_triangles = loss_fn.find_nearest_triangles(sampled_points, vertices, faces) 80 | overlap_loss = loss_fn.calculate_overlap_loss( 81 | sampled_points, vertices, faces, nearest_triangles, point_face_indices 82 | ) 83 | 84 | # Check that the overlap loss is a scalar 85 | assert ( 86 | isinstance(overlap_loss, torch.Tensor) and overlap_loss.dim() == 0 87 | ), "Overlap loss should be a scalar" 88 | 89 | # For this simple test case, the overlap should be minimal or zero 90 | assert overlap_loss.item() >= 0, "Overlap loss should be non-negative" 91 | 92 | 93 | def test_overlapping_triangles_loss_full(loss_fn, sample_data): 94 | vertices, faces = sample_data 95 | 96 | # Run the forward pass of the loss function 97 | loss = loss_fn(vertices, faces) 98 | 99 | # Check that the loss is a scalar 100 | assert isinstance(loss, torch.Tensor) and loss.dim() == 0, "Loss should be a scalar" 101 | 102 | # Expected behavior: the loss should be non-negative 103 | assert loss.item() >= 0, "Overlap loss should be non-negative" 104 | -------------------------------------------------------------------------------- /tests/losses/test_proba_chamfer_distance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | from neural_mesh_simplification.losses import ProbabilisticChamferDistanceLoss 4 | 5 | 6 | @pytest.fixture 7 | def pcd_loss(): 8 | return ProbabilisticChamferDistanceLoss() 9 | 10 
| 11 | def test_pcd_loss_zero_distance(pcd_loss): 12 | P = torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=torch.float32) 13 | Ps = P.clone() 14 | probs = torch.tensor([0.5, 0.5], dtype=torch.float32) 15 | 16 | loss = pcd_loss(P, Ps, probs) 17 | assert torch.isclose( 18 | loss, torch.tensor(0.0) 19 | ), f"Expected loss to be 0, but got {loss.item()}" 20 | 21 | 22 | def test_pcd_loss_nonzero_distance(pcd_loss): 23 | P = torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=torch.float32) 24 | Ps = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=torch.float32) 25 | probs = torch.tensor([0.5, 0.5], dtype=torch.float32) 26 | 27 | loss = pcd_loss(P, Ps, probs) 28 | assert loss > 0, f"Expected loss to be positive, but got {loss.item()}" 29 | assert torch.isfinite(loss), f"Expected loss to be finite, but got {loss.item()}" 30 | 31 | 32 | def test_pcd_loss_empty_input(pcd_loss): 33 | P = torch.empty((0, 3)) 34 | Ps = torch.empty((0, 3)) 35 | probs = torch.empty(0) 36 | 37 | loss = pcd_loss(P, Ps, probs) 38 | assert loss == 0, f"Expected loss to be 0 for empty input, but got {loss.item()}" 39 | 40 | 41 | def test_pcd_loss_different_sizes(pcd_loss): 42 | P = torch.rand((100, 3)) 43 | Ps = torch.rand((50, 3)) 44 | probs = torch.rand(50) 45 | 46 | loss = pcd_loss(P, Ps, probs) 47 | assert torch.isfinite(loss), f"Expected loss to be finite, but got {loss.item()}" 48 | assert loss >= 0, f"Expected non-negative loss, but got {loss.item()}" 49 | 50 | 51 | def test_pcd_loss_gradient(pcd_loss): 52 | P = torch.randn((10, 3), requires_grad=True) 53 | Ps = torch.randn((5, 3), requires_grad=True) 54 | probs = torch.rand(5, requires_grad=True) 55 | 56 | loss = pcd_loss(P, Ps, probs) 57 | loss.backward() 58 | 59 | assert P.grad is not None, "Gradient for P should not be None" 60 | assert Ps.grad is not None, "Gradient for Ps should not be None" 61 | assert probs.grad is not None, "Gradient for probabilities should not be None" 62 | assert torch.isfinite(P.grad).all(), "Gradient for P should be finite" 63 | 
assert torch.isfinite(Ps.grad).all(), "Gradient for Ps should be finite" 64 | assert torch.isfinite( 65 | probs.grad 66 | ).all(), "Gradient for probabilities should be finite" 67 | -------------------------------------------------------------------------------- /tests/losses/test_proba_surface_distance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pytest 3 | from neural_mesh_simplification.losses.surface_distance_loss import ( 4 | ProbabilisticSurfaceDistanceLoss, 5 | ) 6 | 7 | 8 | @pytest.fixture 9 | def loss_fn(): 10 | return ProbabilisticSurfaceDistanceLoss(k=3, num_samples=100) 11 | 12 | 13 | @pytest.fixture 14 | def simple_cube_data(): 15 | vertices = torch.tensor( 16 | [ 17 | [0, 0, 0], 18 | [1, 0, 0], 19 | [0, 1, 0], 20 | [1, 1, 0], 21 | [0, 0, 1], 22 | [1, 0, 1], 23 | [0, 1, 1], 24 | [1, 1, 1], 25 | ], 26 | dtype=torch.float32, 27 | ) 28 | 29 | faces = torch.tensor( 30 | [ 31 | [0, 1, 2], 32 | [1, 3, 2], 33 | [4, 5, 6], 34 | [5, 7, 6], 35 | [0, 4, 1], 36 | [1, 4, 5], 37 | [2, 3, 6], 38 | [3, 7, 6], 39 | [0, 2, 4], 40 | [2, 6, 4], 41 | [1, 5, 3], 42 | [3, 5, 7], 43 | ], 44 | dtype=torch.long, 45 | ) 46 | 47 | return vertices, faces 48 | 49 | 50 | def test_loss_zero_for_identical_meshes(loss_fn, simple_cube_data): 51 | vertices, faces = simple_cube_data 52 | face_probs = torch.ones(faces.shape[0], dtype=torch.float32) 53 | 54 | loss = loss_fn(vertices, faces, vertices, faces, face_probs) 55 | print(f"Loss for identical meshes: {loss.item()}") 56 | assert loss.item() < 1e-5 57 | 58 | 59 | def test_loss_increases_with_vertex_displacement(loss_fn, simple_cube_data): 60 | vertices, faces = simple_cube_data 61 | face_probs = torch.ones(faces.shape[0], dtype=torch.float32) 62 | 63 | displaced_vertices = vertices.clone() 64 | displaced_vertices[0] += torch.tensor([0.1, 0.1, 0.1]) 65 | 66 | loss_original = loss_fn(vertices, faces, vertices, faces, face_probs) 67 | loss_displaced = loss_fn(vertices, 
faces, displaced_vertices, faces, face_probs) 68 | 69 | print( 70 | f"Original loss: {loss_original.item()}, Displaced loss: {loss_displaced.item()}" 71 | ) 72 | assert loss_displaced > loss_original 73 | assert not torch.isclose(loss_displaced, loss_original, atol=1e-5) 74 | 75 | 76 | def test_loss_increases_with_lower_face_probabilities(loss_fn, simple_cube_data): 77 | vertices, faces = simple_cube_data 78 | high_probs = torch.ones(faces.shape[0], dtype=torch.float32) 79 | low_probs = torch.full((faces.shape[0],), 0.5, dtype=torch.float32) 80 | 81 | loss_high = loss_fn(vertices, faces, vertices, faces, high_probs) 82 | loss_low = loss_fn(vertices, faces, vertices, faces, low_probs) 83 | 84 | assert loss_low > loss_high 85 | 86 | 87 | def test_loss_handles_empty_meshes(loss_fn): 88 | empty_vertices = torch.empty((0, 3), dtype=torch.float32) 89 | empty_faces = torch.empty((0, 3), dtype=torch.long) 90 | empty_probs = torch.empty(0, dtype=torch.float32) 91 | 92 | loss = loss_fn( 93 | empty_vertices, empty_faces, empty_vertices, empty_faces, empty_probs 94 | ) 95 | assert loss.item() == 0.0 96 | 97 | 98 | def test_loss_is_symmetric(loss_fn, simple_cube_data): 99 | vertices, faces = simple_cube_data 100 | face_probs = torch.ones(faces.shape[0], dtype=torch.float32) 101 | 102 | loss_forward = loss_fn(vertices, faces, vertices, faces, face_probs) 103 | loss_reverse = loss_fn(vertices, faces, vertices, faces, face_probs) 104 | 105 | assert torch.isclose(loss_forward, loss_reverse, atol=1e-6) 106 | 107 | 108 | def test_loss_gradients(loss_fn, simple_cube_data): 109 | vertices, faces = simple_cube_data 110 | vertices.requires_grad = True 111 | face_probs = torch.ones(faces.shape[0], dtype=torch.float32) 112 | face_probs.requires_grad = True 113 | 114 | loss = loss_fn(vertices, faces, vertices, faces, face_probs) 115 | loss.backward() 116 | 117 | assert vertices.grad is not None 118 | assert face_probs.grad is not None 119 | assert not torch.isnan(vertices.grad).any() 120 | 
assert not torch.isnan(face_probs.grad).any() 121 | -------------------------------------------------------------------------------- /tests/losses/test_triangle_collision_loss.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from neural_mesh_simplification.losses.triangle_collision_loss import ( 4 | TriangleCollisionLoss, 5 | ) 6 | 7 | 8 | @pytest.fixture 9 | def collision_loss(): 10 | return TriangleCollisionLoss(k=1, collision_threshold=1e-6) 11 | 12 | 13 | @pytest.fixture 14 | def test_meshes(): 15 | return { 16 | "non_penetrating": { 17 | "vertices": torch.tensor( 18 | [ 19 | [0, 0, 0], 20 | [1, 0, 0], 21 | [0, 1, 0], # Triangle 1 22 | [1, 1, 0], # Additional vertex for Triangle 2 23 | ], 24 | dtype=torch.float32, 25 | ), 26 | "faces": torch.tensor([[0, 1, 2], [1, 3, 2]], dtype=torch.long), 27 | }, 28 | "edge_penetrating": { 29 | "vertices": torch.tensor( 30 | [ 31 | [0, 0, 0], 32 | [1, 0, 0], 33 | [0.5, 1, 0], # Triangle 1 34 | [0.25, 0.25, -0.5], 35 | [0.75, 0.25, 0.5], 36 | [0.5, 0.75, 0], # Triangle 2 37 | ], 38 | dtype=torch.float32, 39 | ), 40 | "faces": torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.long), 41 | }, 42 | } 43 | 44 | 45 | def test_collision_detection(collision_loss, test_meshes): 46 | for name, data in test_meshes.items(): 47 | face_probabilities = torch.ones(data["faces"].shape[0], dtype=torch.float32) 48 | loss = collision_loss(data["vertices"], data["faces"], face_probabilities) 49 | 50 | print(f"{name} - Loss: {loss.item()}") 51 | 52 | if name == "non_penetrating": 53 | assert ( 54 | loss.item() == 0 55 | ), f"Non-zero loss computed for non-penetrating case: {name}" 56 | else: 57 | assert loss.item() > 0, f"Zero loss computed for penetrating case: {name}" 58 | 59 | 60 | def test_collision_loss_values(collision_loss, test_meshes): 61 | data = test_meshes["edge_penetrating"] 62 | face_probabilities = torch.ones(data["faces"].shape[0], dtype=torch.float32) 63 | 
loss = collision_loss(data["vertices"], data["faces"], face_probabilities) 64 | 65 | assert ( 66 | loss.item() > 0 67 | ), f"Expected positive loss for edge penetrating case, got {loss.item()}" 68 | print(f"Edge penetrating loss: {loss.item()}") 69 | 70 | 71 | def test_collision_loss_with_probabilities(collision_loss, test_meshes): 72 | data = test_meshes["edge_penetrating"] 73 | face_probabilities = torch.tensor([0.5, 0.7], dtype=torch.float32) 74 | loss = collision_loss(data["vertices"], data["faces"], face_probabilities) 75 | 76 | assert ( 77 | loss.item() > 0 78 | ), f"Expected positive loss for edge penetrating case with probabilities, got {loss.item()}" 79 | print(f"Edge penetrating loss with probabilities: {loss.item()}") 80 | 81 | 82 | def test_empty_mesh(collision_loss): 83 | empty_vertices = torch.empty((0, 3), dtype=torch.float32) 84 | empty_faces = torch.empty((0, 3), dtype=torch.long) 85 | empty_probabilities = torch.empty(0, dtype=torch.float32) 86 | 87 | loss = collision_loss(empty_vertices, empty_faces, empty_probabilities) 88 | assert loss.item() == 0, f"Expected zero loss for empty mesh, got {loss.item()}" 89 | 90 | 91 | @pytest.fixture 92 | def complex_mesh(): 93 | vertices = torch.tensor( 94 | [ 95 | [0, 0, 0], 96 | [1, 0, 0], 97 | [0, 1, 0], 98 | [1, 1, 0], 99 | [0, 0, 1], 100 | [1, 0, 1], 101 | [0, 1, 1], 102 | [1, 1, 1], 103 | ], 104 | dtype=torch.float32, 105 | ) 106 | faces = torch.tensor( 107 | [ 108 | [0, 1, 2], 109 | [1, 3, 2], 110 | [4, 5, 6], 111 | [5, 7, 6], 112 | [0, 4, 1], 113 | [1, 4, 5], 114 | [2, 3, 6], 115 | [3, 7, 6], 116 | ], 117 | dtype=torch.long, 118 | ) 119 | return {"vertices": vertices, "faces": faces} 120 | 121 | 122 | def test_complex_mesh(collision_loss, complex_mesh): 123 | face_probabilities = torch.ones(complex_mesh["faces"].shape[0], dtype=torch.float32) 124 | loss = collision_loss( 125 | complex_mesh["vertices"], complex_mesh["faces"], face_probabilities 126 | ) 127 | print(f"Complex mesh - Loss: {loss.item()}") 
128 | assert loss.item() >= 0, "Negative loss computed for complex mesh" 129 | 130 | 131 | def test_collision_detection_edge_cases(collision_loss): 132 | vertices = torch.tensor( 133 | [ 134 | [0, 0, 0], 135 | [1, 0, 0], 136 | [0, 1, 0], # Triangle 1 137 | [0, 0, 1e-7], 138 | [1, 0, 1e-7], 139 | [0, 1, 1e-7], # Triangle 2 (very close but not intersecting) 140 | [0.5, 0.5, -0.5], 141 | [1.5, 0.5, -0.5], 142 | [0.5, 1.5, -0.5], # Triangle 3 (intersecting) 143 | ], 144 | dtype=torch.float32, 145 | ) 146 | faces = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.long) 147 | face_probabilities = torch.ones(faces.shape[0], dtype=torch.float32) 148 | 149 | # Calculate and print face normals 150 | v0, v1, v2 = vertices[faces[:, 0]], vertices[faces[:, 1]], vertices[faces[:, 2]] 151 | face_normals = torch.linalg.cross(v1 - v0, v2 - v0) 152 | face_normals = face_normals / (torch.norm(face_normals, dim=1, keepdim=True) + 1e-6) 153 | print("Face normals:") 154 | print(face_normals) 155 | 156 | # Print triangle information 157 | for i, face in enumerate(faces): 158 | print(f"Triangle {i}:") 159 | print(f" Vertices: {vertices[face].tolist()}") 160 | print(f" Normal: {face_normals[i].tolist()}") 161 | print(f" Centroid: {vertices[face].mean(dim=0).tolist()}") 162 | 163 | loss = collision_loss(vertices, faces, face_probabilities) 164 | print(f"Edge case loss: {loss.item()}") 165 | 166 | assert loss.item() > 0, "Should detect the intersecting triangle" 167 | assert ( 168 | loss.item() < 3 169 | ), "Should not detect collision for very close but non-intersecting triangles" 170 | assert ( 171 | loss.item() == 2 172 | ), "Should detect exactly two collisions (Triangle 3 intersects both Triangle 1 and 2)" 173 | -------------------------------------------------------------------------------- /tests/mesh_data/cube.obj: -------------------------------------------------------------------------------- 1 | # https://github.com/mikedh/trimesh 2 | v -1.00000000 -1.00000000 
-1.00000000 3 | v -1.00000000 -1.00000000 1.00000000 4 | v -1.00000000 1.00000000 -1.00000000 5 | v -1.00000000 1.00000000 1.00000000 6 | v 1.00000000 -1.00000000 -1.00000000 7 | v 1.00000000 -1.00000000 1.00000000 8 | v 1.00000000 1.00000000 -1.00000000 9 | v 1.00000000 1.00000000 1.00000000 10 | f 2 4 1 11 | f 5 2 1 12 | f 1 4 3 13 | f 3 5 1 14 | f 2 8 4 15 | f 6 2 5 16 | f 6 8 2 17 | f 4 8 3 18 | f 7 5 3 19 | f 3 8 7 20 | f 7 6 5 21 | f 8 6 7 22 | 23 | -------------------------------------------------------------------------------- /tests/mesh_data/rounded_cube.obj: -------------------------------------------------------------------------------- 1 | # Blender 4.2.0 2 | # www.blender.org 3 | o Cube 4 | v 0.966276 1.000000 -0.966276 5 | v 1.000000 0.966276 -1.000000 6 | v 0.966276 1.000000 0.966276 7 | v 1.000000 0.966276 1.000000 8 | v -0.966276 1.000000 -0.966276 9 | v -1.000000 0.966276 -1.000000 10 | v -0.966276 1.000000 0.966276 11 | v -1.000000 0.966276 1.000000 12 | v 1.000000 -0.950029 -1.000000 13 | v 0.950029 -1.000000 -0.950029 14 | v 0.950029 -1.000000 0.950029 15 | v 1.000000 -0.950029 1.000000 16 | v -1.000000 -0.950029 -1.000000 17 | v -0.950029 -1.000000 -0.950029 18 | v -0.950029 -1.000000 0.950029 19 | v -1.000000 -0.950029 1.000000 20 | vn -1.0000 -0.0000 -0.0000 21 | vn 1.0000 -0.0000 -0.0000 22 | vn -0.0000 1.0000 -0.0000 23 | vn -0.0000 -0.0000 1.0000 24 | vn -0.0000 0.7071 0.7071 25 | vn 0.7071 0.7071 -0.0000 26 | vn -0.7071 0.7071 -0.0000 27 | vn -0.0000 0.7071 -0.7071 28 | vn -0.0000 -1.0000 -0.0000 29 | vn -0.7071 -0.7071 -0.0000 30 | vn -0.0000 -0.7071 -0.7071 31 | vn -0.0000 -0.7071 0.7071 32 | vn 0.7071 -0.7071 -0.0000 33 | vn -0.0000 -0.0000 -1.0000 34 | s 0 35 | f 16//1 8//1 6//1 13//1 36 | f 9//2 2//2 4//2 12//2 37 | f 1//3 5//3 7//3 3//3 38 | f 12//4 4//4 8//4 16//4 39 | f 3//5 7//5 8//5 4//5 40 | f 1//6 3//6 4//6 2//6 41 | f 7//7 5//7 6//7 8//7 42 | f 5//8 1//8 2//8 6//8 43 | f 14//9 10//9 11//9 15//9 44 | f 14//10 15//10 16//10 
13//10 45 | f 10//11 14//11 13//11 9//11 46 | f 15//12 11//12 12//12 16//12 47 | f 11//13 10//13 9//13 12//13 48 | f 13//14 6//14 2//14 9//14 49 | -------------------------------------------------------------------------------- /tests/mesh_data/sharp_cube.obj: -------------------------------------------------------------------------------- 1 | # Blender 4.2.0 2 | # www.blender.org 3 | o Cube 4 | v 1.000000 1.000000 -1.000000 5 | v 1.000000 -1.000000 -1.000000 6 | v 1.000000 1.000000 1.000000 7 | v 1.000000 -1.000000 1.000000 8 | v -1.000000 1.000000 -1.000000 9 | v -1.000000 -1.000000 -1.000000 10 | v -1.000000 1.000000 1.000000 11 | v -1.000000 -1.000000 1.000000 12 | vn -0.0000 1.0000 -0.0000 13 | vn -0.0000 -0.0000 1.0000 14 | vn -1.0000 -0.0000 -0.0000 15 | vn -0.0000 -1.0000 -0.0000 16 | vn 1.0000 -0.0000 -0.0000 17 | vn -0.0000 -0.0000 -1.0000 18 | s 0 19 | f 1//1 5//1 7//1 3//1 20 | f 4//2 3//2 7//2 8//2 21 | f 8//3 7//3 5//3 6//3 22 | f 6//4 2//4 4//4 8//4 23 | f 2//5 1//5 3//5 4//5 24 | f 6//6 5//6 1//6 2//6 25 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | import trimesh 2 | import numpy as np 3 | import networkx as nx 4 | from torch_geometric.data import Data 5 | from neural_mesh_simplification.data.dataset import ( 6 | MeshSimplificationDataset, 7 | preprocess_mesh, 8 | mesh_to_tensor, 9 | load_mesh, 10 | ) 11 | 12 | 13 | def test_load_mesh(tmp_path): 14 | # Create a temporary mesh file 15 | mesh = trimesh.creation.box() 16 | file_path = tmp_path / "test_mesh.obj" 17 | mesh.export(file_path) 18 | 19 | loaded_mesh = load_mesh(str(file_path)) 20 | assert isinstance(loaded_mesh, trimesh.Trimesh) 21 | assert np.allclose(loaded_mesh.vertices, mesh.vertices) 22 | assert np.array_equal(loaded_mesh.faces, mesh.faces) 23 | 24 | 25 | def test_preprocess_mesh_centered(sample_mesh): 26 | processed_mesh = 
preprocess_mesh(sample_mesh) 27 | # Check that the mesh is centered 28 | assert np.allclose( 29 | processed_mesh.vertices.mean(axis=0), np.zeros(3) 30 | ), "Mesh is not centered" 31 | 32 | 33 | def test_preprocess_mesh_scaled(sample_mesh): 34 | processed_mesh = preprocess_mesh(sample_mesh) 35 | 36 | max_dim = np.max( 37 | processed_mesh.vertices.max(axis=0) - processed_mesh.vertices.min(axis=0) 38 | ) 39 | assert np.isclose(max_dim, 1.0), "Mesh is not scaled to unit cube" 40 | 41 | 42 | def test_mesh_to_tensor(sample_mesh: trimesh.Trimesh): 43 | data = mesh_to_tensor(sample_mesh) 44 | assert isinstance(data, Data) 45 | assert data.num_nodes == len(sample_mesh.vertices) 46 | assert data.face.shape[1] == len(sample_mesh.faces) 47 | assert data.edge_index.shape[0] == 2 48 | assert data.edge_index.max() < data.num_nodes 49 | 50 | 51 | def test_graph_structure_in_data(sample_mesh): 52 | data = mesh_to_tensor(sample_mesh) 53 | 54 | # Check number of nodes 55 | assert data.num_nodes == len(sample_mesh.vertices) 56 | 57 | # Check edge_index 58 | assert data.edge_index.shape[0] == 2 59 | assert data.edge_index.max() < data.num_nodes 60 | 61 | # Reconstruct graph from edge_index 62 | G = nx.Graph() 63 | edge_list = data.edge_index.t().tolist() 64 | G.add_edges_from(edge_list) 65 | 66 | # Check reconstructed graph properties 67 | assert len(G.nodes) == len(sample_mesh.vertices) 68 | assert len(G.edges) == (3 * len(sample_mesh.faces) - len(sample_mesh.edges_unique)) 69 | 70 | # Check connectivity 71 | assert nx.is_connected(G) 72 | 73 | # Check degree distribution 74 | degrees = [d for n, d in G.degree()] 75 | assert min(degrees) >= 3 # Each vertex should be connected to at least 3 others 76 | 77 | # Check if the graph is manifold-like (each edge should be shared by at most two faces) 78 | edge_face_count = {} 79 | for face in sample_mesh.faces: 80 | for i in range(3): 81 | edge = tuple(sorted([face[i], face[(i + 1) % 3]])) 82 | edge_face_count[edge] = 
edge_face_count.get(edge, 0) + 1 83 | assert all(count <= 2 for count in edge_face_count.values()) 84 | 85 | 86 | def test_dataset(tmp_path): 87 | # Create a few temporary mesh files 88 | for i in range(3): 89 | mesh = trimesh.creation.box() 90 | file_path = tmp_path / f"test_mesh_{i}.obj" 91 | mesh.export(file_path) 92 | 93 | dataset = MeshSimplificationDataset(str(tmp_path)) 94 | assert len(dataset) == 3 95 | 96 | sample = dataset[0] 97 | assert isinstance(sample, Data) 98 | assert sample.num_nodes > 0 99 | assert sample.face.shape[1] > 0 100 | -------------------------------------------------------------------------------- /tests/test_edge_predictor.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn as nn 4 | from torch_geometric.data import Data 5 | from torch_geometric.nn import knn_graph 6 | 7 | from neural_mesh_simplification.models.edge_predictor import EdgePredictor 8 | from neural_mesh_simplification.models.layers.devconv import DevConv 9 | 10 | 11 | @pytest.fixture 12 | def sample_mesh_data(): 13 | x = torch.tensor( 14 | [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]], 15 | dtype=torch.float, 16 | ) 17 | edge_index = torch.tensor( 18 | [[0, 0, 1, 1, 2, 2, 3, 3], [1, 2, 0, 3, 0, 3, 1, 2]], dtype=torch.long 19 | ) 20 | return Data(x=x, edge_index=edge_index) 21 | 22 | 23 | def test_edge_predictor_initialization(): 24 | edge_predictor = EdgePredictor(in_channels=3, hidden_channels=64, k=15) 25 | assert isinstance(edge_predictor.devconv, DevConv) 26 | assert isinstance(edge_predictor.W_q, nn.Linear) 27 | assert isinstance(edge_predictor.W_k, nn.Linear) 28 | assert edge_predictor.k == 15 29 | 30 | 31 | def test_edge_predictor_forward(sample_mesh_data): 32 | edge_predictor = EdgePredictor(in_channels=3, hidden_channels=64, k=2) 33 | simplified_adj_indices, simplified_adj_values = edge_predictor( 34 | sample_mesh_data.x, sample_mesh_data.edge_index 35 | 
) 36 | 37 | assert isinstance(simplified_adj_indices, torch.Tensor) 38 | assert isinstance(simplified_adj_values, torch.Tensor) 39 | assert simplified_adj_indices.shape[0] == 2 # 2 rows for source and target indices 40 | assert ( 41 | simplified_adj_values.shape[0] == simplified_adj_indices.shape[1] 42 | ) # Same number of values as edges 43 | 44 | 45 | def test_edge_predictor_output_range(sample_mesh_data): 46 | edge_predictor = EdgePredictor(in_channels=3, hidden_channels=64, k=2) 47 | _, simplified_adj_values = edge_predictor( 48 | sample_mesh_data.x, sample_mesh_data.edge_index 49 | ) 50 | 51 | assert (simplified_adj_values >= 0).all() # Values should be non-negative 52 | 53 | 54 | def test_edge_predictor_symmetry(sample_mesh_data): 55 | edge_predictor = EdgePredictor(in_channels=3, hidden_channels=64, k=2) 56 | simplified_adj_indices, simplified_adj_values = edge_predictor( 57 | sample_mesh_data.x, sample_mesh_data.edge_index 58 | ) 59 | 60 | # Create a sparse tensor from the output 61 | n = sample_mesh_data.x.shape[0] 62 | adj_matrix = torch.sparse_coo_tensor( 63 | simplified_adj_indices, simplified_adj_values, (n, n) 64 | ) 65 | dense_adj = adj_matrix.to_dense() 66 | 67 | assert torch.allclose(dense_adj, dense_adj.t(), atol=1e-6) 68 | 69 | 70 | def test_edge_predictor_connectivity(sample_mesh_data): 71 | edge_predictor = EdgePredictor(in_channels=3, hidden_channels=64, k=2) 72 | simplified_adj_indices, _ = edge_predictor( 73 | sample_mesh_data.x, sample_mesh_data.edge_index 74 | ) 75 | 76 | # Check if all nodes are connected 77 | unique_nodes = torch.unique(simplified_adj_indices) 78 | assert len(unique_nodes) == sample_mesh_data.x.shape[0] 79 | 80 | 81 | def test_edge_predictor_different_input_sizes(): 82 | edge_predictor = EdgePredictor(in_channels=3, hidden_channels=64, k=5) 83 | 84 | # Test with a larger graph 85 | x = torch.rand(10, 3) 86 | edge_index = torch.randint(0, 10, (2, 30)) 87 | simplified_adj_indices, simplified_adj_values = edge_predictor(x, 
edge_index) 88 | 89 | assert simplified_adj_indices.shape[0] == 2 90 | assert simplified_adj_values.shape[0] == simplified_adj_indices.shape[1] 91 | assert torch.max(simplified_adj_indices) < 10 92 | 93 | 94 | def test_attention_scores_shape(sample_mesh_data): 95 | edge_predictor = EdgePredictor(in_channels=3, hidden_channels=64, k=2) 96 | 97 | # Get intermediate features 98 | knn_edges = knn_graph(sample_mesh_data.x, k=2, flow="target_to_source") 99 | extended_edges = torch.cat([sample_mesh_data.edge_index, knn_edges], dim=1) 100 | features = edge_predictor.devconv(sample_mesh_data.x, extended_edges) 101 | 102 | # Test attention scores 103 | attention_scores = edge_predictor.compute_attention_scores( 104 | features, sample_mesh_data.edge_index 105 | ) 106 | 107 | assert attention_scores.shape[0] == sample_mesh_data.edge_index.shape[1] 108 | assert torch.allclose( 109 | attention_scores.sum(), 110 | torch.tensor( 111 | len(torch.unique(sample_mesh_data.edge_index[0])), dtype=torch.float32 112 | ), 113 | ) 114 | 115 | 116 | def test_simplified_adjacency_shapes(): 117 | # Create a simple graph 118 | x = torch.rand(5, 3) 119 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long) 120 | attention_scores = torch.rand(edge_index.shape[1]) 121 | 122 | edge_predictor = EdgePredictor(in_channels=3, hidden_channels=64, k=15) 123 | indices, values = edge_predictor.compute_simplified_adjacency( 124 | attention_scores, edge_index 125 | ) 126 | 127 | assert indices.shape[0] == 2 128 | assert indices.shape[1] == values.shape[0] 129 | assert torch.max(indices) < 5 130 | 131 | 132 | def test_empty_input_handling(): 133 | edge_predictor = EdgePredictor(in_channels=3, hidden_channels=64, k=15) 134 | x = torch.rand(5, 3) 135 | empty_edge_index = torch.empty((2, 0), dtype=torch.long) 136 | 137 | # Test forward pass with empty edge_index 138 | with pytest.raises(ValueError, match="Edge index is empty"): 139 | indices, values = edge_predictor(x, empty_edge_index) 140 
| 141 | # Test compute_simplified_adjacency with empty edge_index 142 | empty_attention_scores = torch.empty(0) 143 | with pytest.raises(ValueError, match="Edge index is empty"): 144 | indices, values = edge_predictor.compute_simplified_adjacency( 145 | empty_attention_scores, empty_edge_index 146 | ) 147 | 148 | 149 | def test_feature_transformation(): 150 | edge_predictor = EdgePredictor(in_channels=3, hidden_channels=64, k=2) 151 | x = torch.rand(5, 3) 152 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long) 153 | 154 | # Get intermediate features 155 | knn_edges = knn_graph(x, k=2, flow="target_to_source") 156 | extended_edges = torch.cat([edge_index, knn_edges], dim=1) 157 | features = edge_predictor.devconv(x, extended_edges) 158 | 159 | # Check feature dimensions 160 | assert features.shape == (5, 64) # [num_nodes, hidden_channels] 161 | 162 | # Check transformed features through attention layers 163 | q = edge_predictor.W_q(features) 164 | k = edge_predictor.W_k(features) 165 | assert q.shape == (5, 64) 166 | assert k.shape == (5, 64) 167 | -------------------------------------------------------------------------------- /tests/test_face_classifier.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from neural_mesh_simplification.models.face_classifier import FaceClassifier 5 | 6 | 7 | @pytest.fixture 8 | def face_classifier(): 9 | return FaceClassifier(input_dim=16, hidden_dim=32, num_layers=3, k=20) 10 | 11 | 12 | def test_face_classifier_initialization(face_classifier): 13 | assert len(face_classifier.triconv_layers) == 3 14 | assert isinstance(face_classifier.final_layer, torch.nn.Linear) 15 | 16 | 17 | def test_face_classifier_forward(face_classifier): 18 | num_faces = 100 19 | x = torch.randn(num_faces, 16) 20 | pos = torch.randn(num_faces, 3) 21 | 22 | out = face_classifier(x, pos) 23 | assert out.shape == (num_faces,) 24 | assert torch.all(out >= 0) 
and torch.all(out <= 1) 25 | assert torch.isclose(out.sum(), torch.tensor(1.0), atol=1e-6) 26 | 27 | 28 | def test_face_classifier_gradient(face_classifier): 29 | num_faces = 100 30 | x = torch.randn(num_faces, 16, requires_grad=True) 31 | pos = torch.randn(num_faces, 3, requires_grad=True) 32 | 33 | out = face_classifier(x, pos) 34 | loss = out.sum() 35 | loss.backward() 36 | 37 | assert x.grad is not None 38 | assert pos.grad is not None 39 | assert all(p.grad is not None for p in face_classifier.parameters()) 40 | 41 | 42 | def test_face_classifier_with_batch(face_classifier): 43 | num_faces = 100 44 | batch_size = 2 45 | x = torch.randn(num_faces, 16) 46 | pos = torch.randn(num_faces, 3) 47 | batch = torch.cat( 48 | [torch.full((num_faces // batch_size,), i) for i in range(batch_size)] 49 | ) 50 | 51 | out = face_classifier(x, pos, batch) 52 | assert out.shape == (num_faces,) 53 | assert torch.all(out >= 0) and torch.all(out <= 1) 54 | 55 | # Check if the sum of probabilities for each batch is close to 1 56 | for i in range(batch_size): 57 | batch_sum = out[batch == i].sum() 58 | assert torch.isclose(batch_sum, torch.tensor(1.0), atol=1e-6) 59 | 60 | 61 | def test_face_classifier_knn_graph(face_classifier): 62 | num_faces = 100 63 | x = torch.randn(num_faces, 16) 64 | pos = torch.randn(num_faces, 3, 3) # 3 vertices per face 65 | 66 | # Call the forward method to construct the k-nn graph 67 | _ = face_classifier(x, pos) 68 | 69 | # Get the constructed edge_index from the first TriConv layer 70 | edge_index = face_classifier.triconv_layers[0].last_edge_index 71 | 72 | # Check the number of neighbors for each face 73 | for i in range(num_faces): 74 | actual_neighbors = edge_index[1][edge_index[0] == i] 75 | assert ( 76 | len(actual_neighbors) >= face_classifier.k 77 | ), f"Face {i} has {len(actual_neighbors)} neighbors, which is less than {face_classifier.k}" 78 | 79 | # Verify that the graph is symmetric 80 | symmetric_diff = set(map(tuple, 
edge_index.t().tolist())) ^ set( 81 | map(tuple, edge_index.flip(0).t().tolist()) 82 | ) 83 | assert len(symmetric_diff) == 0, "The k-nn graph is not symmetric" 84 | 85 | # Verify that there are no self-loops 86 | assert torch.all( 87 | edge_index[0] != edge_index[1] 88 | ), "The k-nn graph contains self-loops" 89 | -------------------------------------------------------------------------------- /tests/test_mesh_operations.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import numpy as np 3 | 4 | from neural_mesh_simplification.utils import build_graph_from_mesh 5 | 6 | 7 | def test_build_graph_from_mesh(sample_mesh): 8 | graph = build_graph_from_mesh(sample_mesh) 9 | 10 | # Check number of nodes and edges 11 | assert len(graph.nodes) == len(sample_mesh.vertices) 12 | assert len(graph.edges) == ( 13 | 3 * len(sample_mesh.faces) - len(sample_mesh.edges_unique) 14 | ) 15 | 16 | # Check node attributes 17 | for i, pos in enumerate(sample_mesh.vertices): 18 | assert i in graph.nodes 19 | assert np.allclose(graph.nodes[i]["pos"], pos) 20 | 21 | # Check edge connectivity 22 | for face in sample_mesh.faces: 23 | assert graph.has_edge(face[0], face[1]) 24 | assert graph.has_edge(face[1], face[2]) 25 | assert graph.has_edge(face[2], face[0]) 26 | 27 | # Check graph connectivity 28 | assert nx.is_connected(graph) 29 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import trimesh 3 | from trimesh import Trimesh 4 | 5 | from neural_mesh_simplification.metrics.chamfer_distance import chamfer_distance 6 | from neural_mesh_simplification.metrics.edge_preservation import edge_preservation 7 | from neural_mesh_simplification.metrics.hausdorff_distance import hausdorff_distance 8 | from neural_mesh_simplification.metrics.normal_consistency import 
normal_consistency 9 | 10 | 11 | def create_cube_mesh(scale=1.0): 12 | vertices = ( 13 | np.array( 14 | [ 15 | [0, 0, 0], 16 | [1, 0, 0], 17 | [1, 1, 0], 18 | [0, 1, 0], 19 | [0, 0, 1], 20 | [1, 0, 1], 21 | [1, 1, 1], 22 | [0, 1, 1], 23 | ] 24 | ) 25 | * scale 26 | ) 27 | faces = np.array( 28 | [ 29 | [0, 1, 2], 30 | [0, 2, 3], 31 | [4, 5, 6], 32 | [4, 6, 7], 33 | [0, 1, 5], 34 | [0, 5, 4], 35 | [2, 3, 7], 36 | [2, 7, 6], 37 | [1, 2, 6], 38 | [1, 6, 5], 39 | [0, 3, 7], 40 | [0, 7, 4], 41 | ] 42 | ) 43 | 44 | return Trimesh(vertices=vertices, faces=faces) 45 | 46 | 47 | def test_chamfer_distance_identical_meshes(): 48 | mesh1 = create_cube_mesh() 49 | mesh2 = create_cube_mesh() 50 | 51 | dist = chamfer_distance(mesh1, mesh2) 52 | 53 | assert np.isclose( 54 | dist, 0.0 55 | ), f"Chamfer distance for identical meshes should be 0, got {dist}" 56 | 57 | 58 | def test_chamfer_distance_different_meshes(): 59 | mesh1 = create_cube_mesh() 60 | mesh2 = create_cube_mesh(scale=2.0) # Scale the second cube to be twice as large 61 | 62 | dist = chamfer_distance(mesh1, mesh2) 63 | 64 | assert ( 65 | dist > 0 66 | ), f"Chamfer distance for different meshes should be greater than 0, got {dist}" 67 | 68 | 69 | def test_normal_consistency(): 70 | mesh = create_cube_mesh() 71 | consistency = normal_consistency(mesh) 72 | 73 | expected_consistency = 0.577350269189626 74 | 75 | assert np.isclose( 76 | consistency, expected_consistency 77 | ), f"Normal consistency should be {expected_consistency}, got {consistency}" 78 | 79 | 80 | def test_edge_preservation(): 81 | original_mesh = trimesh.load("./tests/mesh_data/rounded_cube.obj") 82 | simplified_mesh = trimesh.load("./tests/mesh_data/sharp_cube.obj") 83 | 84 | # original_mesh = trimesh.creation.icosphere(subdivisions=3, radius=2) 85 | # simplified_mesh = trimesh.creation.icosphere(subdivisions=2, radius=2) 86 | 87 | preservation_metric = edge_preservation(original_mesh, simplified_mesh) 88 | 89 | assert ( 90 | preservation_metric < 1.0 
91 | ), f"Edge preservation metric should be less than 1.0, got {preservation_metric}" 92 | assert ( 93 | preservation_metric > 0.0 94 | ), f"Edge preservation metric should be greater than 0.0, got {preservation_metric}" 95 | 96 | 97 | def test_hausdorff_distance_identical_meshes(): 98 | mesh1 = create_cube_mesh() 99 | mesh2 = create_cube_mesh() 100 | 101 | dist = hausdorff_distance(mesh1, mesh2) 102 | 103 | assert np.isclose( 104 | dist, 0.0 105 | ), f"Hausdorff distance for identical meshes should be 0, got {dist}" 106 | 107 | 108 | def test_hausdorff_distance_different_meshes(): 109 | mesh1 = create_cube_mesh() 110 | mesh2 = trimesh.creation.icosphere(subdivisions=2, radius=2) 111 | 112 | dist = hausdorff_distance(mesh1, mesh2) 113 | 114 | assert ( 115 | dist > 1.99 116 | ), f"Hausdorff distance for identical meshes should be 2, got {dist}" 117 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch_geometric.utils 4 | from torch_geometric.data import Data 5 | 6 | from neural_mesh_simplification.models import NeuralMeshSimplification 7 | 8 | 9 | @pytest.fixture 10 | def sample_data() -> Data: 11 | num_nodes = 10 12 | x = torch.randn(num_nodes, 3) 13 | # Create a more densely connected edge index 14 | edge_index = torch.tensor( 15 | [ 16 | [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4], 17 | [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 6, 7], 18 | ], 19 | dtype=torch.long, 20 | ) 21 | pos = torch.randn(num_nodes, 3) 22 | return Data(x=x, edge_index=edge_index, pos=pos) 23 | 24 | 25 | def test_neural_mesh_simplification_forward(sample_data: Data): 26 | # Set a fixed random seed for reproducibility 27 | torch.manual_seed(42) 28 | 29 | model = NeuralMeshSimplification( 30 | input_dim=3, 31 | hidden_dim=64, 32 | edge_hidden_dim=64, 33 | num_layers=3, 34 | k=3, # Reduce k to avoid too many 
edges in the test 35 | edge_k=15, 36 | target_ratio=0.5, # Ensure we sample roughly half the vertices 37 | ) 38 | 39 | # First test point sampling 40 | sampled_indices, sampled_probs = model.sample_points(sample_data) 41 | assert sampled_indices.numel() > 0, "No points were sampled" 42 | assert sampled_indices.numel() <= sample_data.num_nodes, "Too many points sampled" 43 | 44 | # Get the subgraph for sampled points 45 | sampled_edge_index, _ = torch_geometric.utils.subgraph( 46 | sampled_indices, 47 | sample_data.edge_index, 48 | relabel_nodes=True, 49 | num_nodes=sample_data.num_nodes, 50 | ) 51 | assert sampled_edge_index.numel() > 0, "No edges in sampled subgraph" 52 | 53 | # Now test the full forward pass 54 | output = model(sample_data) 55 | 56 | # Add assertions to check the output structure and shapes 57 | assert isinstance(output, dict) 58 | assert "sampled_indices" in output 59 | assert "sampled_probs" in output 60 | assert "sampled_vertices" in output 61 | assert "edge_index" in output 62 | assert "edge_probs" in output 63 | assert "candidate_triangles" in output 64 | assert "triangle_probs" in output 65 | assert "face_probs" in output 66 | assert "simplified_faces" in output 67 | 68 | # Check shapes 69 | assert output["sampled_indices"].dim() == 1 70 | # sampled_probs should match the number of sampled vertices 71 | assert output["sampled_probs"].shape == output["sampled_indices"].shape 72 | assert output["sampled_vertices"].shape[1] == 3 # 3D coordinates 73 | 74 | if output["edge_index"].numel() > 0: # Only check if we have edges 75 | assert output["edge_index"].shape[0] == 2 # Source and target nodes 76 | assert ( 77 | len(output["edge_probs"]) == output["edge_index"].shape[1] 78 | ) # One prob per edge 79 | 80 | # Check that edge indices are valid 81 | num_sampled_vertices = output["sampled_vertices"].shape[0] 82 | assert torch.all(output["edge_index"] >= 0) 83 | assert torch.all(output["edge_index"] < num_sampled_vertices) 84 | 85 | if 
output["candidate_triangles"].numel() > 0: # Only check if we have triangles 86 | assert output["candidate_triangles"].shape[1] == 3 # Triangle indices 87 | assert len(output["triangle_probs"]) == len(output["candidate_triangles"]) 88 | assert len(output["face_probs"]) == len(output["candidate_triangles"]) 89 | 90 | # Additional checks 91 | assert output["sampled_indices"].shape[0] <= sample_data.num_nodes 92 | assert output["sampled_vertices"].shape[0] == output["sampled_indices"].shape[0] 93 | 94 | # Check that sampled_vertices correspond to a subset of original vertices 95 | original_vertices = sample_data.pos 96 | sampled_vertices = output["sampled_vertices"] 97 | 98 | # For each sampled vertex, check if it exists in original vertices 99 | for sv in sampled_vertices: 100 | # Check if this vertex exists in original vertices (within numerical precision) 101 | exists = torch.any(torch.all(torch.abs(original_vertices - sv) < 1e-6, dim=1)) 102 | assert exists, "Sampled vertex not found in original vertices" 103 | 104 | # Check that simplified_faces only contain valid indices if not empty 105 | if output["simplified_faces"].numel() > 0: 106 | max_index = output["sampled_vertices"].shape[0] - 1 107 | assert torch.all(output["simplified_faces"] >= 0) 108 | assert torch.all(output["simplified_faces"] <= max_index) 109 | 110 | # Check the relationship between face_probs and simplified_faces 111 | if output["face_probs"].numel() > 0: 112 | assert output["simplified_faces"].shape[0] <= output["face_probs"].shape[0] 113 | 114 | 115 | def test_generate_candidate_triangles(): 116 | model = NeuralMeshSimplification( 117 | input_dim=3, 118 | hidden_dim=64, 119 | edge_hidden_dim=64, 120 | num_layers=3, 121 | k=5, 122 | edge_k=15, 123 | target_ratio=0.5, 124 | ) 125 | edge_index = torch.tensor( 126 | [[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]], dtype=torch.long 127 | ) 128 | edge_probs = torch.tensor([0.9, 0.9, 0.8, 0.8, 0.7, 0.7]) 129 | 130 | triangles, triangle_probs = 
model.generate_candidate_triangles( 131 | edge_index, edge_probs 132 | ) 133 | 134 | assert triangles.shape[1] == 3 135 | assert triangle_probs.shape[0] == triangles.shape[0] 136 | assert torch.all(triangles >= 0) 137 | assert torch.all(triangles < edge_index.max() + 1) 138 | assert torch.all(triangle_probs >= 0) and torch.all(triangle_probs <= 1) 139 | 140 | max_possible_triangles = edge_index.max().item() + 1 # num_nodes 141 | max_possible_triangles = ( 142 | max_possible_triangles 143 | * (max_possible_triangles - 1) 144 | * (max_possible_triangles - 2) 145 | // 6 146 | ) 147 | assert triangles.shape[0] <= max_possible_triangles 148 | -------------------------------------------------------------------------------- /tests/test_model_layers.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import pytest 3 | import torch 4 | from torch import nn 5 | from neural_mesh_simplification.models.layers import DevConv, TriConv 6 | 7 | 8 | def create_graph_data(): 9 | x = torch.tensor( 10 | [ 11 | [0.0, 0.0, 0.0], # Node 0 12 | [1.0, 0.0, 0.0], # Node 1 13 | [0.0, 1.0, 0.0], # Node 2 14 | [1.0, 1.0, 0.0], # Node 3 15 | ], 16 | dtype=torch.float, 17 | ) 18 | 19 | edge_index = torch.tensor( 20 | [ 21 | [0, 0, 1, 1, 2, 2, 3, 3], # Source nodes 22 | [1, 2, 0, 3, 0, 3, 1, 2], # Target nodes 23 | ], 24 | dtype=torch.long, 25 | ) 26 | 27 | return x, edge_index 28 | 29 | 30 | @pytest.fixture 31 | def graph_data(): 32 | return create_graph_data() 33 | 34 | 35 | def test_devconv(graph_data): 36 | x, edge_index = graph_data 37 | 38 | devconv = DevConv(in_channels=3, out_channels=4) 39 | output = devconv(x, edge_index) 40 | 41 | assert output.shape == (4, 4) # 4 nodes, 4 output channels 42 | 43 | if "-s" in sys.argv: 44 | print("Input shape:", x.shape) 45 | print("Output shape:", output.shape) 46 | print("Output:\n", output) 47 | 48 | analyze_feature_differences(x, edge_index) 49 | 50 | 51 | def analyze_feature_differences(x, 
edge_index): 52 | devconv = DevConv(in_channels=3, out_channels=3) 53 | output = devconv(x, edge_index) 54 | 55 | for i in range(x.shape[0]): 56 | neighbors = edge_index[1][edge_index[0] == i] 57 | print(f"Node {i}:") 58 | print(f" Input feature: {x[i]}") 59 | print(f" Output feature: {output[i]}") 60 | print(" Neighbor differences:") 61 | for j in neighbors: 62 | print(f" Node {j}: {x[i] - x[j]}") 63 | print() 64 | 65 | 66 | @pytest.fixture 67 | def triconv_layer(): 68 | return TriConv(in_channels=16, out_channels=32) 69 | 70 | 71 | def test_triconv_initialization(triconv_layer): 72 | assert triconv_layer.in_channels == 16 73 | assert triconv_layer.out_channels == 32 74 | assert isinstance(triconv_layer.mlp, nn.Sequential) 75 | assert triconv_layer.mlp[0].in_features == 25 # 16 + 9 76 | 77 | 78 | def test_triconv_forward(triconv_layer): 79 | num_nodes = 10 80 | x = torch.randn(num_nodes, 16) 81 | pos = torch.randn(num_nodes, 3) 82 | edge_index = torch.randint(0, num_nodes, (2, 20)) 83 | 84 | out = triconv_layer(x, pos, edge_index) 85 | assert out.shape == (num_nodes, 32) 86 | 87 | 88 | def test_relative_position_encoding(triconv_layer): 89 | num_nodes = 5 90 | pos = torch.randn(num_nodes, 3) 91 | edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]) 92 | 93 | rel_pos = triconv_layer.compute_relative_position_encoding( 94 | pos, edge_index[0], edge_index[1] 95 | ) 96 | assert rel_pos.shape == (4, 9) # 4 edges, 9-dimensional encoding 97 | 98 | 99 | def test_triconv_gradient(triconv_layer): 100 | num_nodes = 10 101 | x = torch.randn(num_nodes, 16, requires_grad=True) 102 | pos = torch.randn(num_nodes, 3, requires_grad=True) 103 | edge_index = torch.randint(0, num_nodes, (2, 20)) 104 | 105 | out = triconv_layer(x, pos, edge_index) 106 | loss = out.sum() 107 | loss.backward() 108 | 109 | assert x.grad is not None 110 | assert pos.grad is not None 111 | assert all(p.grad is not None for p in triconv_layer.parameters()) 112 | 113 | 114 | def 
test_last_edge_index(triconv_layer): 115 | num_nodes = 10 116 | x = torch.randn(num_nodes, 16) 117 | pos = torch.randn(num_nodes, 3) 118 | edge_index = torch.randint(0, num_nodes, (2, 20)) 119 | 120 | triconv_layer(x, pos, edge_index) 121 | assert torch.all(triconv_layer.last_edge_index == edge_index) 122 | -------------------------------------------------------------------------------- /tests/test_point_sampler.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import nn 4 | from neural_mesh_simplification.models.layers.devconv import DevConv 5 | from neural_mesh_simplification.models.point_sampler import PointSampler 6 | 7 | 8 | @pytest.fixture 9 | def sample_graph_data(): 10 | x = torch.tensor( 11 | [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]], 12 | dtype=torch.float, 13 | ) 14 | edge_index = torch.tensor( 15 | [[0, 0, 1, 1, 2, 2, 3, 3], [1, 2, 0, 3, 0, 3, 1, 2]], dtype=torch.long 16 | ) 17 | return x, edge_index 18 | 19 | 20 | def test_point_sampler_initialization(): 21 | sampler = PointSampler(in_channels=3, out_channels=64, num_layers=3) 22 | assert len(sampler.convs) == 3 23 | assert isinstance(sampler.convs[0], DevConv) 24 | assert isinstance(sampler.output_layer, nn.Linear) 25 | 26 | 27 | def test_point_sampler_forward(sample_graph_data): 28 | x, edge_index = sample_graph_data 29 | sampler = PointSampler(in_channels=3, out_channels=64, num_layers=3) 30 | probabilities = sampler(x, edge_index) 31 | assert probabilities.shape == (4,) # 4 input vertices 32 | assert (probabilities >= 0).all() and (probabilities <= 1).all() 33 | 34 | 35 | def test_point_sampler_sampling(sample_graph_data): 36 | x, edge_index = sample_graph_data 37 | sampler = PointSampler(in_channels=3, out_channels=64, num_layers=3) 38 | probabilities = sampler(x, edge_index) 39 | sampled_indices = sampler.sample(probabilities, num_samples=2) 40 | assert sampled_indices.shape == (2,) 41 
| assert len(torch.unique(sampled_indices)) == 2 # All indices should be unique 42 | 43 | 44 | def test_point_sampler_forward_and_sample(sample_graph_data): 45 | x, edge_index = sample_graph_data 46 | sampler = PointSampler(in_channels=3, out_channels=64, num_layers=3) 47 | sampled_indices, probabilities = sampler.forward_and_sample( 48 | x, edge_index, num_samples=2 49 | ) 50 | assert sampled_indices.shape == (2,) 51 | assert probabilities.shape == (4,) 52 | assert len(torch.unique(sampled_indices)) == 2 53 | 54 | 55 | def test_point_sampler_deterministic_behavior(sample_graph_data): 56 | x, edge_index = sample_graph_data 57 | sampler = PointSampler(in_channels=3, out_channels=64, num_layers=3) 58 | 59 | torch.manual_seed(42) 60 | indices1, _ = sampler.forward_and_sample(x, edge_index, num_samples=2) 61 | 62 | torch.manual_seed(42) 63 | indices2, _ = sampler.forward_and_sample(x, edge_index, num_samples=2) 64 | 65 | assert torch.equal(indices1, indices2) 66 | 67 | 68 | def test_point_sampler_different_input_sizes(): 69 | sampler = PointSampler(in_channels=3, out_channels=64, num_layers=3) 70 | 71 | x1 = torch.rand(10, 3) 72 | edge_index1 = torch.randint(0, 10, (2, 20)) 73 | prob1 = sampler(x1, edge_index1) 74 | assert prob1.shape == (10,) 75 | 76 | x2 = torch.rand(20, 3) 77 | edge_index2 = torch.randint(0, 20, (2, 40)) 78 | prob2 = sampler(x2, edge_index2) 79 | assert prob2.shape == (20,) 80 | -------------------------------------------------------------------------------- /tests/test_trimesh.py: -------------------------------------------------------------------------------- 1 | import trimesh 2 | import pytest 3 | 4 | 5 | @pytest.mark.trimesh 6 | class TestMeshCreation: 7 | def test_sphere(self): 8 | print("Creating sphere") 9 | mesh = trimesh.creation.icosphere(subdivisions=2, radius=2) 10 | mesh.export("./tests/mesh_data/sphere_2.obj") 11 | 12 | def test_cube(self): 13 | mesh = trimesh.creation.box(extents=[2, 2, 2]) 14 | 
mesh.export("./tests/mesh_data/cube_2.obj") 15 | -------------------------------------------------------------------------------- /train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "eef23eb0-4fc4-4f75-bc46-de36bdceeb1b", 6 | "metadata": {}, 7 | "source": [ 8 | "# Train the Neural Mesh Simplification model" 9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "68548e86-1a2b-4ab7-b89e-5c93314f9345", 14 | "metadata": {}, 15 | "source": [ 16 | "## Set up the environment" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "13c589dd-aa96-4d93-be4e-6cfdd181f10f", 22 | "metadata": {}, 23 | "source": [ 24 | "### [*only required for remote runs*] Remote environment setup\n", 25 | "\n", 26 | "If you are running this notebook remotely (e.g. Google Colab), you'll want to set up the environment by\n", 27 | "* Downloading the repository from GitHub\n", 28 | "* Setting up the Python environment\n", 29 | "\n", 30 | "If you are opening this notebook locally (by running `jupyter lab` from the repository root with the right conda environment activated), the above steps are not required." 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "id": "179368da-e6fc-4beb-ae82-ca18137de974", 36 | "metadata": {}, 37 | "source": [ 38 | "#### Step 1. Check out the repo\n", 39 | "That's where the source code for mesh simplification, along with its dependency definitions and other utilities, lives." 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "id": "0bdf37ac-a233-44f9-8f6c-f929566a0bbd", 46 | "metadata": {}, 47 | "outputs": [], 48 | "source": [ 49 | "!git clone https://github.com/gennarinoos/neural-mesh-simplification.git neural-mesh-simplification" 50 | ] 51 | }, 52 | { 53 | "cell_type": "markdown", 54 | "id": "90db897c-f4e0-416e-b92c-392716730e07", 55 | "metadata": {}, 56 | "source": [ 57 | "#### Step 2.
Install python version 3.12 using apt-get" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "id": "5e1795ff", 63 | "metadata": {}, 64 | "source": [ 65 | "Check the current python version by running the following command. This notebook requires Python 3.12 to run. Either install it via your Notebook environment settings and jump to Step 6 or follow all the steps below." 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "id": "aaa8c207", 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "!python --version" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "id": "ef6be6e0-41b0-43ca-9fda-508a8cb198be", 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "!sudo apt-get update\n", 86 | "!sudo apt-get install python3.12" 87 | ] 88 | }, 89 | { 90 | "cell_type": "markdown", 91 | "id": "71e15c47-fcff-4e5b-8692-ff62fe0cc764", 92 | "metadata": {}, 93 | "source": [ 94 | "#### Step 3. Update alternatives to use the new Python version" 95 | ] 96 | }, 97 | { 98 | "cell_type": "code", 99 | "execution_count": null, 100 | "id": "6c36faa2-8839-4476-b74d-ffbf4f2f086f", 101 | "metadata": {}, 102 | "outputs": [], 103 | "source": [ 104 | "!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1\n", 105 | "!sudo update-alternatives --config python3" 106 | ] 107 | }, 108 | { 109 | "cell_type": "markdown", 110 | "id": "d1c5e4ec-3074-4041-bdee-5246987b4ae0", 111 | "metadata": {}, 112 | "source": [ 113 | "#### Step 4. Install pip and the required packages for the new Python version." 
114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "id": "9d11a192-478c-4797-b9ba-42c41f9e9d9e", 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "!rm -f get-pip*.py\n", 124 | "!wget https://bootstrap.pypa.io/get-pip.py\n", 125 | "!python get-pip.py\n", 126 | "!python -m pip install ipykernel\n", 127 | "!python -m ipykernel install --user --name python3.12 --display-name \"Python 3.12\"" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "id": "e1c3d62d-f255-4f2e-830f-09080b41d364", 133 | "metadata": {}, 134 | "source": [ 135 | "#### Step 5. Restart and verify\n", 136 | "At this point you may need to restart the session, after which you want to verify that `python` is at the right version (`3.12`)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": null, 142 | "id": "0e29e1e1-5e0d-4dbf-b4cd-c5a26e5092b6", 143 | "metadata": {}, 144 | "outputs": [], 145 | "source": [ 146 | "!python --version" 147 | ] 148 | }, 149 | { 150 | "cell_type": "markdown", 151 | "id": "2dd709a2-ebec-48fb-b5b7-2e75a9fbb129", 152 | "metadata": {}, 153 | "source": [ 154 | "#### Step 6. 
Upgrade pip and setuptools" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "id": "915962f0-9dff-40d3-a4f3-cbc1cf69fcbe", 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "!pip install --upgrade pip setuptools wheel\n", 165 | "!pip install --upgrade build" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "id": "4f5242f0-c6de-4c1c-9922-eaf4ee161e35", 171 | "metadata": {}, 172 | "source": [ 173 | "### Set repository as the working directory \n", 174 | "CD into the repository downloaded above" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": null, 180 | "id": "a70ca79a-83e8-418b-85e4-364ca8c6b9d3", 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "%cd neural-mesh-simplification" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "id": "d6b48728-47c7-4441-8043-258e71f2a8d1", 190 | "metadata": {}, 191 | "source": [ 192 | "### Package requirements" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "id": "40ccef76-a2ed-47a5-8c7b-bc7c14ba6770", 198 | "metadata": {}, 199 | "source": [ 200 | "Depending on whether you are using PyTorch on a CPU or a GPU,\n", 201 | "you'll have to use the correct binaries for PyTorch and the PyTorch Geometric libraries. 
You can install them via:" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "id": "2141322d-334d-4957-b371-0660a7f7dbfc", 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "!pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121\n", 212 | "!pip install torch_cluster==1.6.3 torch_geometric==2.5.3 torch_scatter==2.1.2 torch_sparse==0.6.18 -f https://data.pyg.org/whl/torch-2.4.0+cu121.html" 213 | ] 214 | }, 215 | { 216 | "cell_type": "markdown", 217 | "id": "cc1d234c-1c03-445a-b2f7-d56fd42e4f94", 218 | "metadata": {}, 219 | "source": [ 220 | "Replace `cu121` (in both commands above) with the tag matching your system's CUDA version. If you don't know your CUDA version, run `nvidia-smi`." 221 | ] 222 | }, 223 | { 224 | "cell_type": "markdown", 225 | "id": "db20093e-689c-40ab-8f2e-c86337dbc466", 226 | "metadata": {}, 227 | "source": [ 228 | "Only then can you install the requirements via pip:" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "id": "bf9fa547-f157-4679-a6c1-f0304ac59aaf", 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "!pip install -r requirements.txt\n", 239 | "!pip uninstall -y neural-mesh-simplification\n", 240 | "!pip install ." 241 | ] 242 | }, 243 | { 244 | "cell_type": "markdown", 245 | "id": "76d21495-9bd4-46df-a4c8-a31b573e3d79", 246 | "metadata": { 247 | "jp-MarkdownHeadingCollapsed": true 248 | }, 249 | "source": [ 250 | "---\n", 251 | "## Download the training data\n", 252 | "We can use the Hugging Face API to download some mesh data to use for training and evaluation."
253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": null, 258 | "id": "42801141-20d8-40da-8f59-b7b25338ceef", 259 | "metadata": {}, 260 | "outputs": [], 261 | "source": [ 262 | "import os\n", 263 | "import shutil\n", 264 | "from huggingface_hub import snapshot_download\n", 265 | "\n", 266 | "target_folder = \"data/raw\"\n", 267 | "wip_folder = os.path.join(target_folder, \"wip\")\n", 268 | "os.makedirs(wip_folder, exist_ok=True)\n", 269 | "\n", 270 | "# abc_train is really large (5k+ meshes), so only abc_extra_noisy (folder_patterns[0]) is downloaded as a sample\n", 271 | "folder_patterns = [\"abc_extra_noisy/03_meshes/*.ply\", \"abc_train/03_meshes/*.ply\"]\n", 272 | "\n", 273 | "# Download\n", 274 | "snapshot_download(\n", 275 | " repo_id=\"perler/ppsurf\",\n", 276 | " repo_type=\"dataset\",\n", 277 | " cache_dir=wip_folder,\n", 278 | " allow_patterns=folder_patterns[0],\n", 279 | ")\n", 280 | "\n", 281 | "# Move files from wip folder to target folder\n", 282 | "for root, _, files in os.walk(wip_folder):\n", 283 | " for file in files:\n", 284 | " if file.endswith(\".ply\"):\n", 285 | " src_file = os.path.join(root, file)\n", 286 | " dest_file = os.path.join(target_folder, file)\n", 287 | " shutil.copy2(src_file, dest_file)\n", 288 | " os.remove(src_file)\n", 289 | "\n", 290 | "# Remove the wip folder\n", 291 | "shutil.rmtree(wip_folder)" 292 | ] 293 | }, 294 | { 295 | "cell_type": "markdown", 296 | "id": "c6df29f3-ee23-47d3-b987-7b3f8e0cfc66", 297 | "metadata": {}, 298 | "source": [ 299 | "## Prepare the data\n", 300 | "The downloaded data needs to be prepared for training. We can use a script from the repository we checked out for that."
301 | ] 302 | }, 303 | { 304 | "cell_type": "code", 305 | "execution_count": null, 306 | "id": "bebe14c3-7ca1-4a47-80ab-a5d049def2bd", 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 310 | "!mkdir -p data/processed\n", 311 | "!python scripts/preprocess_data.py" 312 | ] 313 | }, 314 | { 315 | "cell_type": "markdown", 316 | "id": "b1fe1228-5a40-4f44-9e58-5607beabd5a5", 317 | "metadata": {}, 318 | "source": [ 319 | "---\n", 320 | "## Model Training" 321 | ] 322 | }, 323 | { 324 | "cell_type": "markdown", 325 | "id": "7b3af882-fe98-48d4-bd34-cf16912cc5d4", 326 | "metadata": {}, 327 | "source": [ 328 | "When using a GPU, ensure the training is happening on the GPU, and the environment is configured properly." 329 | ] 330 | }, 331 | { 332 | "cell_type": "code", 333 | "execution_count": null, 334 | "id": "486cf962-81a4-41b4-a62a-9431d0ed6cf8", 335 | "metadata": {}, 336 | "outputs": [], 337 | "source": [ 338 | "import torch\n", 339 | "print(torch.cuda.is_available())\n", 340 | "\n", 341 | "!nvcc --version" 342 | ] 343 | }, 344 | { 345 | "cell_type": "markdown", 346 | "id": "2fad8fd1-10a8-410c-b826-848c91a227bc", 347 | "metadata": {}, 348 | "source": [ 349 | "### Start the training" 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "execution_count": null, 355 | "id": "ef712306-3971-421b-a3ee-d063759c926d", 356 | "metadata": {}, 357 | "outputs": [], 358 | "source": [ 359 | "!python scripts/train.py --data-path data/processed --config configs/default.yaml" 360 | ] 361 | } 362 | ], 363 | "metadata": { 364 | "kernelspec": { 365 | "display_name": "Python 3 (ipykernel)", 366 | "language": "python", 367 | "name": "python3" 368 | }, 369 | "language_info": { 370 | "codemirror_mode": { 371 | "name": "ipython", 372 | "version": 3 373 | }, 374 | "file_extension": ".py", 375 | "mimetype": "text/x-python", 376 | "name": "python", 377 | "nbconvert_exporter": "python", 378 | "pygments_lexer": "ipython3", 379 | "version": "3.12.8" 380 | } 381 | }, 382 | "nbformat": 4, 
383 | "nbformat_minor": 5 384 | } 385 | -------------------------------------------------------------------------------- /uv.lock: -------------------------------------------------------------------------------- 1 | version = 1 2 | requires-python = ">=3.12" 3 | 4 | [[package]] 5 | name = "neural-mesh-simplification" 6 | version = "0.1.0" 7 | source = { virtual = "." } --------------------------------------------------------------------------------