├── .cursorrules
├── .editorconfig
├── .github
│   └── workflows
│       └── ci.yml
├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── neural-mesh-simplification.iml
│   └── vcs.xml
├── .python-version
├── LICENSE
├── README.md
├── configs
│   ├── default.yaml
│   └── local.yaml
├── examples
│   └── example.py
├── pyproject.toml
├── pytest.ini
├── requirements.txt
├── scripts
│   ├── download_test_meshes.py
│   ├── evaluate.py
│   ├── infer.py
│   ├── preprocess_data.py
│   └── train.py
├── src
│   └── neural_mesh_simplification
│       ├── __init__.py
│       ├── api
│       │   ├── __init__.py
│       │   └── neural_mesh_simplifier.py
│       ├── data
│       │   ├── __init__.py
│       │   └── dataset.py
│       ├── losses
│       │   ├── __init__.py
│       │   ├── chamfer_distance_loss.py
│       │   ├── combined_loss.py
│       │   ├── edge_crossing_loss.py
│       │   ├── overlapping_triangles_loss.py
│       │   ├── surface_distance_loss.py
│       │   └── triangle_collision_loss.py
│       ├── metrics
│       │   ├── __init__.py
│       │   ├── chamfer_distance.py
│       │   ├── edge_preservation.py
│       │   ├── hausdorff_distance.py
│       │   └── normal_consistency.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── edge_predictor.py
│       │   ├── face_classifier.py
│       │   ├── layers
│       │   │   ├── __init__.py
│       │   │   ├── devconv.py
│       │   │   └── triconv.py
│       │   ├── neural_mesh_simplification.py
│       │   └── point_sampler.py
│       ├── trainer
│       │   ├── __init__.py
│       │   ├── resource_monitor.py
│       │   └── trainer.py
│       └── utils
│           ├── __init__.py
│           └── mesh_operations.py
├── tests
│   ├── __init__.py
│   ├── conftest.py
│   ├── losses
│   │   ├── test_edge_crossings_loss.py
│   │   ├── test_overlapping_triangles_loss.py
│   │   ├── test_proba_chamfer_distance.py
│   │   ├── test_proba_surface_distance.py
│   │   └── test_triangle_collision_loss.py
│   ├── mesh_data
│   │   ├── cube.obj
│   │   ├── rounded_cube.obj
│   │   └── sharp_cube.obj
│   ├── test_dataset.py
│   ├── test_edge_predictor.py
│   ├── test_face_classifier.py
│   ├── test_mesh_operations.py
│   ├── test_metrics.py
│   ├── test_model.py
│   ├── test_model_layers.py
│   ├── test_point_sampler.py
│   └── test_trimesh.py
├── train.ipynb
└── uv.lock
/.cursorrules:
--------------------------------------------------------------------------------
1 | I have thousands of 3D objects as meshes in various formats and levels of detail. I want to use PyTorch to train a model that converts a high-fidelity mesh to a lower level of detail with minimal loss.
2 |
3 | We will work in an iterative manner, please try to follow these instructions and expectations to our process.
4 |
5 | 0. "The combined_output.txt holds code for the current code-base, of the important components. Take the current implementation into account when suggesting changes. Especially when a change in one file leads to a need to change another class or function. Suggesting new content for an existing class, should consider the current implementation and not just burst out based on assumptions what it should be."
6 |
7 | 1. "Analyze research papers thoroughly, breaking down complex concepts into clear, step-by-step explanations. Focus on architectural details, mathematical formulations, and key algorithms."
8 |
9 | 2. "When implementing new components, start with a high-level class structure. Outline necessary imports, initialization methods, and key function signatures before diving into detailed implementations."
10 |
11 | 3. "For each significant component or method, suggest comprehensive unit tests. Cover normal operations, edge cases, and potential failure modes."
12 |
13 | 4. "When addressing code errors, provide a detailed analysis of the error message and traceback. Explain the root cause and reasoning behind proposed solutions."
14 |
15 | 5. "Maintain consistency with the original research paper or design document. Regularly cross-reference implementations against the source material."
16 |
17 | 6. "Consider compatibility with relevant libraries and frameworks (e.g., PyTorch, PyTorch Geometric) when suggesting code implementations."
18 |
19 | 7. "Offer clear, intuitive explanations for complex mathematical concepts, especially those central to the project's core algorithms."
20 |
21 | 8. "Suggest visualization techniques for intermediate results to aid in debugging and understanding model behavior."
22 |
23 | 9. "Keep track of the overall project structure. Recommend refactoring or reorganization to maintain clean, modular, and efficient code."
24 |
25 | 10. "When implementing loss functions or evaluation metrics, explain the intuition and mathematical basis behind each component."
26 |
27 | 11. "Provide guidance on performance optimization, including suggestions for efficient data handling, model architecture improvements, and computational optimizations."
28 |
29 | 12. "Offer insights on potential extensions or modifications to the current project, based on recent advancements in the field or related research."
30 |
31 | 13. "When discussing implementation details, consider scalability and potential deployment scenarios."
32 |
33 | 14. "Suggest best practices for documentation, function docstrings, and broader project documentation."
34 |
35 | 15. "Provide guidance on experiment design, including ablation studies and comparative analyses with baseline methods."
36 |
37 | 16. "For the generated code, please do not add trivial comments, or comments that merely describe what changed relative to the current code."
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | root = true
2 |
3 | [*]
4 | charset = utf-8
5 | end_of_line = lf
6 | insert_final_newline = true
7 | trim_trailing_whitespace = true
8 |
9 | [*.py]
10 | indent_style = space
11 | indent_size = 4
12 | max_line_length = 180
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: Python CI
2 |
3 | on:
4 | push:
5 | branches: [ "main" ]
6 | pull_request:
7 | branches: [ "main" ]
8 |
9 | permissions:
10 | contents: read
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 |
17 | steps:
18 | - uses: actions/checkout@v4
19 | - name: Set up Python 3.12
20 |       uses: actions/setup-python@v5
21 | with:
22 | python-version: "3.12"
23 | - name: Install system dependencies
24 | run: |
25 | sudo apt-get update
26 | sudo apt-get install -y libspatialindex-dev
27 | - name: Install dependencies
28 | run: |
29 | python -m pip install --upgrade pip
30 | pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cpu
31 | pip install torch_cluster==1.6.3 torch_geometric==2.5.3 torch_scatter==2.1.2 torch_sparse==0.6.18 -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
32 | pip install -r requirements.txt
33 | - name: Install this package
34 | run: |
35 | pip install -e .
36 | - name: Test with pytest
37 | run: |
38 | pytest
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | data/**/*.ply
3 | data/**/*.obj
4 | data/**/*.stl
5 | examples/data
6 | checkpoints
7 |
8 | file_structure.txt
9 | combined_output.txt
10 |
11 | # Byte-compiled / optimized / DLL files
12 | __pycache__/
13 | *.py[cod]
14 | *$py.class
15 |
16 | # C extensions
17 | *.so
18 |
19 | # Distribution / packaging
20 | .Python
21 | build/
22 | develop-eggs/
23 | dist/
24 | downloads/
25 | eggs/
26 | .eggs/
27 | lib/
28 | lib64/
29 | parts/
30 | sdist/
31 | var/
32 | wheels/
33 | share/python-wheels/
34 | *.egg-info/
35 | .installed.cfg
36 | *.egg
37 | MANIFEST
38 |
39 | # PyInstaller
40 | # Usually these files are written by a python script from a template
41 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
42 | *.manifest
43 | *.spec
44 |
45 | # Installer logs
46 | pip-log.txt
47 | pip-delete-this-directory.txt
48 |
49 | # Unit test / coverage reports
50 | htmlcov/
51 | .tox/
52 | .nox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *.cover
59 | *.py,cover
60 | .hypothesis/
61 | .pytest_cache/
62 | cover/
63 |
64 | # Translations
65 | *.mo
66 | *.pot
67 |
68 | # PyBuilder
69 | .pybuilder/
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # IPython
76 | profile_default/
77 | ipython_config.py
78 |
79 | # pyenv
80 | # For a library or package, you might want to ignore these files since the code is
81 | # intended to run in multiple environments; otherwise, check them in:
82 | # .python-version
83 |
84 | # pipenv
85 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
86 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
87 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
88 | # install all needed dependencies.
89 | #Pipfile.lock
90 |
91 | # poetry
92 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
93 | # This is especially recommended for binary packages to ensure reproducibility, and is more
94 | # commonly ignored for libraries.
95 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
96 | #poetry.lock
97 |
98 | # pdm
99 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
100 | #pdm.lock
101 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
102 | # in version control.
103 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
104 | .pdm.toml
105 | .pdm-python
106 | .pdm-build/
107 |
108 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
109 | __pypackages__/
110 |
111 | # Celery stuff
112 | celerybeat-schedule
113 | celerybeat.pid
114 |
115 | # SageMath parsed files
116 | *.sage.py
117 |
118 | # Environments
119 | .env
120 | .venv
121 | env/
122 | venv/
123 | ENV/
124 | env.bak/
125 | venv.bak/
126 |
127 | # Spyder project settings
128 | .spyderproject
129 | .spyproject
130 |
131 | # Rope project settings
132 | .ropeproject
133 |
134 | # mkdocs documentation
135 | /site
136 |
137 | # mypy
138 | .mypy_cache/
139 | .dmypy.json
140 | dmypy.json
141 |
142 | # Pyre type checker
143 | .pyre/
144 |
145 | # pytype static type analyzer
146 | .pytype/
147 |
148 | # Cython debug symbols
149 | cython_debug/
150 |
151 | # PyCharm
152 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
153 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
154 | # and can be added to the global gitignore or merged into this file. For a more nuclear
155 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
156 | #.idea/
157 | .idea/codestyles
158 |
159 | # Log files
160 | log/
161 |
162 | # Generated files from tests
163 | tests/mesh_data
164 |
165 | # Example output
166 | examples/data/simplified
167 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/neural-mesh-simplification.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.12
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Martin Høst Normark
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural Mesh Simplification
2 |
3 | Implementation of the
4 | [Neural Mesh Simplification](https://openaccess.thecvf.com/content/CVPR2022/papers/Potamias_Neural_Mesh_Simplification_CVPR_2022_paper.pdf) paper
5 | by Potamias et al. (CVPR 2022), including the updated details shared
6 | in the [supplementary material](https://openaccess.thecvf.com/content/CVPR2022/supplemental/Potamias_Neural_Mesh_Simplification_CVPR_2022_supplemental.pdf).
7 |
8 | This Python package provides a fast, learnable method for mesh simplification that generates simplified meshes in
9 | real-time.
10 |
11 | ### Overview
12 |
13 | Neural Mesh Simplification is a novel approach to reduce the resolution of 3D meshes while preserving their appearance.
14 | Unlike traditional simplification methods that collapse edges in a greedy iterative manner, this method simplifies a
15 | given mesh in one pass using deep learning techniques.
16 |
17 | The method consists of three main steps (see the sketch after this list):
18 |
19 | 1. Sampling a subset of input vertices using a sophisticated extension of random sampling.
20 | 2. Training a sparse attention network to propose candidate triangles based on the edge connectivity of sampled
21 | vertices.
22 | 3. Using a classification network to estimate the probability that a candidate triangle will be included in the final
23 | mesh.
24 |
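A minimal sketch of how these three steps surface in this package, assuming it is installed; the constructor arguments and output keys are taken from the code elsewhere in this repo, and the model here is untrained, so the point is the shapes, not the quality:

```python
import trimesh

from neural_mesh_simplification import NeuralMeshSimplifier
from neural_mesh_simplification.data.dataset import mesh_to_tensor, preprocess_mesh

simplifier = NeuralMeshSimplifier(
    input_dim=3, hidden_dim=128, edge_hidden_dim=64,
    num_layers=3, k=15, edge_k=15, target_ratio=0.5,
)
mesh = trimesh.creation.icosphere(subdivisions=2)
data = mesh_to_tensor(preprocess_mesh(mesh))
out = simplifier.model(data)  # untrained weights; illustration only

print(out["sampled_vertices"].shape)  # step 1: sampled subset of input vertices
print(out["edge_index"].shape)        # step 2: proposed connectivity among sampled vertices
print(out["face_probs"].shape)        # step 3: inclusion probability per candidate triangle
```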
25 | ### Features
26 |
27 | - Fast and scalable mesh simplification
28 | - One-pass simplification process
29 | - Preservation of mesh appearance
30 | - Lightweight and differentiable implementation
31 | - Suitable for integration into learnable pipelines
32 |
33 | ### Installation
34 |
35 | ```bash
36 | conda create -n neural-mesh-simplification python=3.12
37 | conda activate neural-mesh-simplification
38 | conda install pip
39 | ```
40 |
41 | Depending on whether you are using PyTorch on a CPU or a GPU,
42 | you'll have to use the correct binaries for PyTorch and the PyTorch Geometric libraries. You can install them via:
43 |
44 | ```bash
45 | pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cpu
46 | pip install torch_cluster==1.6.3 torch_geometric==2.5.3 torch_scatter==2.1.2 torch_sparse==0.6.18 -f https://data.pyg.org/whl/torch-2.4.0+cpu.html
47 | ```
48 |
49 | Replace `cpu` with `cu121` or the appropriate CUDA version for your system. If you don't know your CUDA version,
50 | run `nvidia-smi`.
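For example, for a CUDA 12.1 build, the same commands become:

```bash
pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121
pip install torch_cluster==1.6.3 torch_geometric==2.5.3 torch_scatter==2.1.2 torch_sparse==0.6.18 -f https://data.pyg.org/whl/torch-2.4.0+cu121.html
```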
51 |
52 | After that, you can install the remaining requirements:
53 |
54 | ```bash
55 | pip install -r requirements.txt
56 | pip install -e .
57 | ```
58 |
59 | ### Example Usage / Playground
60 |
61 | 1. Drop your meshes as `.obj` files to the `examples/data` folder
62 | 2. Run the following command
63 |
64 | ```bash
65 | python examples/example.py
66 | ```
67 |
68 | 3. Collect the simplified meshes in `examples/data/simplified`. The simplified mesh files are the ones prefixed
69 |    with `simplified_`.
70 |
71 | ### Data Preparation
72 |
73 | If you don't have a dataset for training and evaluation, you can use a mesh collection from the ppsurf dataset on
74 | Hugging Face. See https://huggingface.co/datasets/perler/ppsurf for more information.
75 |
76 | Run the following script to download the meshes via the Hugging Face Hub API:
77 |
78 | ```bash
79 | python scripts/download_test_meshes.py
80 | ```
81 |
82 | Data will be downloaded in the `data/raw` folder at the root of the project.
83 | You can use `--target-folder` to specify a different folder.
84 |
85 | Once you have some data, you should preprocess it using the following script:
86 |
87 | ```bash
88 | python scripts/preprocess_data.py
89 | ```
90 |
91 | The script reads meshes from `data/raw` and writes the preprocessed results to the `data/processed` folder.
92 |
93 | ### Training
94 |
95 | To train the model on your own dataset with the prepared data:
96 |
97 | ```bash
98 | python ./scripts/train.py
99 | ```
100 |
101 | By default, the training config at `configs/default.yaml` will be used. You can override it with
102 | `--config /path/to/your/config.yaml`.\
103 | You can override the following config parameters:
104 |
105 | * the checkpoint directory specified in the config file (where the model will be saved) with
106 | `--checkpoint-dir`.
107 | * the data path with `--data-path` if you have your data in a different location.
108 |
109 | If the training was interrupted, you can resume it by specifying the path to the previously created checkpoint directory
110 | with `--resume`.\
111 | Use `--debug` to see DEBUG logging.
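For example, a full invocation combining these flags might look like this (paths are illustrative):

```bash
python ./scripts/train.py --config configs/local.yaml --data-path data/processed \
  --checkpoint-dir data/checkpoints --resume data/checkpoints/training_state.pth --debug
```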
112 |
113 | ### Evaluation
114 |
115 | To evaluate the model on a test set:
116 |
117 | ```bash
118 | python ./scripts/evaluate.py --eval-data-path /path/to/test/set --checkpoint /path/to/checkpoint.pth
119 | ```
120 |
121 | By default, the training config at `configs/default.yaml` will be used. You can override it with
122 | `--config /path/to/your/config.yaml`.
123 |
124 | ### Inference
125 |
126 | To simplify a mesh using the trained model:
127 |
128 | ```bash
129 | python ./scripts/infer.py --input-file /path/to/your/mesh.obj --output-file /path/to/output/mesh.obj --model-checkpoint /path/to/checkpoint.pth --device cpu
130 | ```
131 |
132 | The default feature dimension for the point sampler and face classifier is 128, but it can be configured with
133 | `--hidden-dim`.\
134 | The default feature dimension for the edge predictor is 64, but it can be configured with `--edge-hidden-dim`.\
135 | If you have a CUDA-compatible GPU, you can specify `--device cuda` to use it for inference.
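For example, to run inference on a GPU with the default dimensions spelled out (file names are illustrative):

```bash
python ./scripts/infer.py --input-file mesh.obj --output-file simplified_mesh.obj \
  --model-checkpoint data/checkpoints/checkpoint.pth --hidden-dim 128 --edge-hidden-dim 64 --device cuda
```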
136 |
137 | ### Citation
138 |
139 | If you use this code in your research, please cite the original paper:
140 |
141 | ```
142 | @InProceedings{Potamias_2022_CVPR,
143 | author = {Potamias, Rolandos Alexandros and Ploumpis, Stylianos and Zafeiriou, Stefanos},
144 | title = {Neural Mesh Simplification},
145 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
146 | month = {June},
147 | year = {2022},
148 | pages = {18583-18592}
149 | }
150 | ```
151 |
152 | ## License
153 |
154 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
156 |
--------------------------------------------------------------------------------
/configs/default.yaml:
--------------------------------------------------------------------------------
1 | # Model parameters
2 | model:
3 | input_dim: 3
4 | hidden_dim: 128 # feature dimension for point sampler and face classifier (as per paper)
5 | edge_hidden_dim: 64 # feature dimension for edge predictor (as per paper)
6 | num_layers: 3 # number of convolutional layers (as per paper)
7 | k: 15 # number of neighbors for graph construction (as per paper)
8 | edge_k: 15 # number of neighbors for edge features (as per paper)
9 | target_ratio: 0.5 # mesh simplification ratio
10 |
11 | # Training parameters
12 | training:
13 | learning_rate: 1.0e-5
14 | weight_decay: 0.99 # weight decay per epoch (as per paper)
15 | batch_size: 2
16 | num_epochs: 20 # total training epochs
17 | early_stopping_patience: 15 # epochs before early stopping
18 | checkpoint_dir: data/checkpoints # model save directory
19 |
20 | # Data parameters
21 | data:
22 | data_dir: data/processed
23 | val_split: 0.2
24 |
25 | # Loss weights
26 | loss:
27 | lambda_c: 1.0 # chamfer distance weight
28 | lambda_e: 1.0 # edge preservation weight
29 | lambda_o: 1.0 # normal consistency weight
30 |
--------------------------------------------------------------------------------
/configs/local.yaml:
--------------------------------------------------------------------------------
1 | # Model parameters
2 | model:
3 | input_dim: 3
4 |   hidden_dim: 64 # feature dimension for point sampler and face classifier
5 |   edge_hidden_dim: 128 # feature dimension for edge predictor
6 | num_layers: 3 # number of convolutional layers (as per paper)
7 | k: 15 # number of neighbors for graph construction (as per paper)
8 | edge_k: 15 # number of neighbors for edge features (as per paper)
9 | target_ratio: 0.5 # mesh simplification ratio
10 |
11 | # Training parameters
12 | training:
13 | learning_rate: 1.0e-5
14 | weight_decay: 0.99 # weight decay per epoch (as per paper)
15 | batch_size: 2
16 | accumulation_steps: 4
17 | num_epochs: 20 # total training epochs
18 | early_stopping_patience: 10 # epochs before early stopping
19 | checkpoint_dir: data/checkpoints # model save directory
20 |
21 | # Data parameters
22 | data:
23 | data_dir: data/processed
24 | val_split: 0.4
25 |
26 | # Loss weights
27 | loss:
28 |   lambda_c: 1.0 # triangle collision weight
29 |   lambda_e: 1.0 # edge crossing weight
30 |   lambda_o: 1.0 # overlapping triangles weight
31 |
--------------------------------------------------------------------------------
/examples/example.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import trimesh
4 | from trimesh import Scene, Trimesh
5 |
6 | from neural_mesh_simplification import NeuralMeshSimplifier
7 | from neural_mesh_simplification.data.dataset import load_mesh
8 |
9 | script_dir = os.path.dirname(os.path.abspath(__file__))
10 | data_dir = os.path.join(script_dir, "data")
11 | default_config_path = os.path.join(script_dir, "../configs/default.yaml")
12 |
13 |
14 | def load_config(config_path):
15 | import yaml
16 |
17 | with open(config_path, "r") as file:
18 | config = yaml.safe_load(file)
19 | return config
20 |
21 |
22 | def save_mesh_to_file(mesh: trimesh.Geometry, file_name: str):
23 | """
24 | Save the simplified mesh to file in the simplified folder.
25 | """
26 |     simplified_dir = os.path.join(data_dir, "simplified")
27 | os.makedirs(simplified_dir, exist_ok=True)
28 | output_path = os.path.join(simplified_dir, file_name)
29 | mesh.export(output_path)
30 |
31 | print(f"Mesh saved to: {output_path}")
32 |
33 |
34 | def cube_example(simplifier: NeuralMeshSimplifier):
35 | print(f"Creating cube mesh")
36 | file = "cube.obj"
37 | mesh = trimesh.creation.box(extents=[2, 2, 2])
38 | simplified_mesh = simplifier.simplify(mesh)
39 | save_mesh_to_file(mesh, file)
40 | save_mesh_to_file(simplified_mesh, f"simplified_{file}")
41 |
42 |
43 | def sphere_example(simplifier: NeuralMeshSimplifier):
44 | print(f"Creating sphere mesh")
45 | file = "sphere.obj"
46 | mesh = trimesh.creation.icosphere(subdivisions=2, radius=2)
47 | simplified_mesh = simplifier.simplify(mesh)
48 | save_mesh_to_file(mesh, file)
49 | save_mesh_to_file(simplified_mesh, f"simplified_{file}")
50 |
51 |
52 | def cylinder_example(simplifier: NeuralMeshSimplifier):
53 | print(f"Creating cylinder mesh")
54 | file = "cylinder.obj"
55 | mesh = trimesh.creation.cylinder(radius=1, height=2)
56 | simplified_mesh = simplifier.simplify(mesh)
57 | save_mesh_to_file(mesh, file)
58 | save_mesh_to_file(simplified_mesh, f"simplified_{file}")
59 |
60 |
61 | def mesh_dropbox_example(simplifier: NeuralMeshSimplifier):
62 | print(f"Loading all meshes of type '.obj' in folder '{data_dir}'")
63 | mesh_files = [f for f in os.listdir(data_dir) if f.endswith(".obj")]
64 |
65 | for file_name in mesh_files:
66 | mesh_path = os.path.join(data_dir, file_name)
67 |
68 | original_mesh = load_mesh(mesh_path)
69 |
70 | print("Loaded mesh at file" + mesh_path)
71 |
72 | # Create a new scene to hold the simplified meshes
73 | simplified_scene = Scene()
74 |
75 | if isinstance(original_mesh, Trimesh):
76 | print(
77 | "Original: ",
78 | original_mesh.vertices.shape,
79 | original_mesh.edges.shape,
80 | original_mesh.faces.shape,
81 | )
82 | simplified_geom = simplifier.simplify(original_mesh)
83 | print(
84 | "Simplified: ",
85 | simplified_geom.vertices.shape,
86 | simplified_geom.edges.shape,
87 | simplified_geom.faces.shape,
88 | )
89 |
90 | simplified_scene = simplified_geom
91 |
92 | elif isinstance(original_mesh, Scene):
93 | # Iterate through the original mesh geometry
94 | for name, geom in original_mesh.geometry.items():
95 | print("Original: ", geom)
96 | # Simplify each Trimesh object
97 | simplified_geom = simplifier.simplify(geom)
98 | print("Simplified: ", simplified_geom)
99 | # Add the simplified geometry to the new scene
100 | simplified_scene.add_geometry(simplified_geom, geom_name=name)
101 | else:
102 | raise ValueError(
103 | "Invalid mesh type (expected Trimesh or Scene):", type(original_mesh)
104 | )
105 |
106 | # Save the simplified mesh to file
107 | save_mesh_to_file(simplified_scene, f"simplified_{file_name}")
108 |
109 |
110 | def main():
111 | # Initialize the simplifier
112 | config = load_config(config_path=default_config_path)
113 | simplifier = NeuralMeshSimplifier(
114 | input_dim=config["model"]["input_dim"],
115 | hidden_dim=config["model"]["hidden_dim"],
116 | edge_hidden_dim=config["model"]["edge_hidden_dim"],
117 | num_layers=config["model"]["num_layers"],
118 | k=config["model"]["k"],
119 | edge_k=config["model"]["edge_k"],
120 | target_ratio=config["model"]["target_ratio"],
121 | )
122 |
123 | # cube_example(simplifier)
124 | # sphere_example(simplifier)
125 | # cylinder_example(simplifier)
126 | mesh_dropbox_example(simplifier)
127 |
128 |
129 | if __name__ == "__main__":
130 | main()
131 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=45", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "neural-mesh-simplification"
7 | version = "0.1.0"
8 | description = "A neural network-based approach to mesh simplification"
9 | authors = [{ name = "Martin Normark", email = "m@martinnormark.com" }, { name = "Gennaro Frazzingaro", email = "gennaro@thinair.ar" }]
10 | license = { file = "LICENSE" }
11 | readme = "README.md"
12 | requires-python = ">=3.12"
13 | classifiers = [
14 | "Programming Language :: Python :: 3",
15 | "License :: OSI Approved :: MIT License",
16 | "Operating System :: OS Independent",
17 | ]
18 |
19 | dependencies = [
20 | "numpy",
21 | "torch",
22 | "trimesh",
23 | "scipy",
24 | "matplotlib",
25 | "tqdm"
26 | ]
27 |
28 | [tool.setuptools.packages.find]
29 | where = ["src"]
30 | include = ["neural_mesh_simplification*"]
31 | namespaces = false
32 |
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | markers =
3 | trimesh: Run only these tests when specified
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohappyeyeballs==2.3.5
2 | aiohttp==3.10.1
3 | aiosignal==1.3.1
4 | attrs==24.2.0
5 | certifi==2024.7.4
6 | charset-normalizer==3.3.2
7 | datasets==2.20.0
8 | dill==0.3.8
9 | filelock==3.15.4
10 | frozenlist==1.4.1
11 | fsspec==2024.5.0
12 | huggingface-hub==0.24.5
13 | idna==3.7
14 | iniconfig==2.0.0
15 | Jinja2==3.1.4
16 | joblib==1.4.2
17 | MarkupSafe==2.1.5
18 | mpmath==1.3.0
19 | multidict==6.0.5
20 | multiprocess==0.70.16
21 | networkx==3.3
22 | numpy==2.0.1
23 | packaging==24.1
24 | pandas==2.2.2
25 | pluggy==1.5.0
26 | psutil==6.0.0
27 | pyarrow==17.0.0
28 | pyarrow-hotfix==0.6
29 | pyparsing==3.1.2
30 | pytest==8.3.2
31 | python-dateutil==2.9.0.post0
32 | pytz==2024.1
33 | PyYAML==6.0.2
34 | requests==2.32.3
35 | Rtree==1.3.0
36 | scikit-learn==1.5.1
37 | scipy==1.14.0
38 | setuptools==72.1.0
39 | six==1.16.0
40 | sympy==1.13.1
41 | threadpoolctl==3.5.0
42 | tqdm==4.66.5
43 | trimesh==4.4.4
44 | typing_extensions==4.12.2
45 | tzdata==2024.1
46 | urllib3==2.2.2
47 | wheel==0.43.0
48 | xxhash==3.4.1
49 | yarl==1.9.4
50 |
--------------------------------------------------------------------------------
/scripts/download_test_meshes.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 |
5 | from huggingface_hub import snapshot_download
6 |
7 | # abc_extra_noisy is the smaller collection;
8 | # abc_train is really large (5k+ meshes)
9 | folder_patterns = ["abc_extra_noisy/03_meshes/*.ply", "abc_train/03_meshes/*.ply"]
10 |
11 |
12 | def download_meshes(target_folder, folder_pattern=folder_patterns[0]):
13 | wip_folder = os.path.join(target_folder, "wip")
14 | os.makedirs(wip_folder, exist_ok=True)
15 |
16 | snapshot_download(
17 | repo_id="perler/ppsurf",
18 | repo_type="dataset",
19 | cache_dir=wip_folder,
20 | allow_patterns=folder_pattern,
21 | )
22 |
23 | # Move files from wip folder to target folder
24 | for root, _, files in os.walk(wip_folder):
25 | for file in files:
26 | if file.endswith(".ply"):
27 | src_file = os.path.join(root, file)
28 | dest_file = os.path.join(target_folder, file)
29 | shutil.copy2(src_file, dest_file)
30 | os.remove(src_file)
31 |
32 | # Remove the wip folder
33 | shutil.rmtree(wip_folder)
34 |
35 |
36 | def main():
37 | parser = argparse.ArgumentParser(
38 | description="Download test meshes from Hugging Face Hub."
39 | )
40 | parser.add_argument(
41 | "--target-folder",
42 | type=str,
43 | required=False,
44 | help="The target folder path where the meshes will be downloaded.",
45 | )
46 | parser.add_argument(
47 | "--dataset-size",
48 | type=str,
49 | required=False,
50 | default="small",
51 | choices=["small", "large"],
52 | help="The size of the dataset to download. Choose 'small' or 'large'.",
53 | )
54 |
55 | args = parser.parse_args()
56 |
57 | target_folder = args.target_folder if args.target_folder else "data/raw"
58 | if not os.path.exists(target_folder):
59 | os.makedirs(target_folder)
60 |
61 | download_meshes(
62 | target_folder,
63 | folder_patterns[0] if args.dataset_size == "small" else folder_patterns[1],
64 | )
65 |
66 |
67 | if __name__ == "__main__":
68 | main()
69 |
--------------------------------------------------------------------------------
/scripts/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from neural_mesh_simplification.trainer import Trainer
4 |
5 |
6 | def parse_args():
7 | parser = argparse.ArgumentParser(
8 | description="Evaluate the Neural Mesh Simplification model."
9 | )
10 | parser.add_argument(
11 | "--eval-data-path",
12 | type=str,
13 | required=True,
14 | help="Path to the evaluation data directory.",
15 | )
16 | parser.add_argument(
17 | "--config",
18 | type=str,
19 |         default="configs/default.yaml",
20 | help="Path to the evaluation configuration file.",
21 | )
22 | parser.add_argument(
23 | "--checkpoint", type=str, required=True, help="Path to the model checkpoint."
24 | )
25 | return parser.parse_args()
26 |
27 |
28 | def load_config(config_path):
29 | import yaml
30 |
31 | with open(config_path, "r") as file:
32 | config = yaml.safe_load(file)
33 | return config
34 |
35 |
36 | def main():
37 | args = parse_args()
38 | config = load_config(args.config)
39 | config["data"]["eval_data_path"] = args.eval_data_path
40 |
41 | trainer = Trainer(config)
42 | trainer.load_checkpoint(args.checkpoint)
43 |
44 | evaluation_metrics = trainer.evaluate(trainer.val_loader)
45 | for metric, value in evaluation_metrics.items():
46 | print(f"{metric}: {value:.4f}")
47 |
48 |
49 | if __name__ == "__main__":
50 | main()
51 |
--------------------------------------------------------------------------------
/scripts/infer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from trimesh import Trimesh
4 |
5 | from neural_mesh_simplification import (
6 |     NeuralMeshSimplifier,
7 | )
8 | from neural_mesh_simplification.data.dataset import load_mesh
9 |
10 |
11 | def parse_args():
12 | parser = argparse.ArgumentParser(
13 | description="Simplify a 3D mesh using a trained model."
14 | )
15 | parser.add_argument(
16 | "--input-file", type=str, required=True, help="Path to the input mesh file."
17 | )
18 | parser.add_argument(
19 | "--output-file",
20 | type=str,
21 | required=True,
22 | help="Where to save the simplified mesh.",
23 | )
24 | parser.add_argument(
25 | "--hidden-dim",
26 | type=int,
27 | required=False,
28 | default=128,
29 | help="Feature dimension for point sampler and face classifier",
30 | )
31 | parser.add_argument(
32 | "--edge-hidden-dim",
33 | type=int,
34 | required=False,
35 | default=64,
36 | help="Feature dimension for edge predictor",
37 | )
38 | parser.add_argument(
39 | "--model-checkpoint",
40 | type=str,
41 | required=True,
42 | help="Path to the trained model checkpoint.",
43 | )
44 | parser.add_argument(
45 | "--device",
46 | type=str,
47 | default="cpu",
48 | help="Device to use for inference (`cpu` or `cuda`).",
49 | )
50 |
51 | return parser.parse_args()
52 |
53 |
54 | def simplify_mesh(
55 | input_file: str,
56 | output_file: str,
57 | model_checkpoint: str,
58 | hidden_dim: int,
59 | edge_hidden_dim: int,
60 | device="cpu",
61 | ):
62 | """
63 | Simplifies a 3D mesh using a trained model.
64 |
65 |     Args:
66 |         input_file (str): Path to the high-resolution input mesh file.
67 |         output_file (str): Where to save the simplified mesh.
68 |         model_checkpoint (str): Path to the trained model checkpoint.
69 |         hidden_dim (int) / edge_hidden_dim (int): Feature dimensions for the point sampler/face classifier and the edge predictor.
70 |         device (str): Device to use for inference (`cpu` or `cuda`).
71 | """
72 | # Load the trained model
73 | print(f"Loading model from {model_checkpoint}...")
74 | simplifier = NeuralMeshSimplifier.using_model(
75 | model_checkpoint,
76 | hidden_dim=hidden_dim,
77 | edge_hidden_dim=edge_hidden_dim,
78 | map_location=device,
79 | )
80 | simplifier.model.to(device)
81 | simplifier.model.eval()
82 |
83 | # Load the input mesh
84 | print(f"Loading input mesh from {input_file}...")
85 | original_mesh = load_mesh(input_file)
86 |
87 | if not isinstance(original_mesh, Trimesh):
88 | raise ValueError("Invalid format for input mesh.")
89 |
90 | simplified_mesh = simplifier.simplify(original_mesh)
91 |
92 | # Save the simplified mesh
93 | print(f"Saving simplified mesh to {output_file}...")
94 | simplified_mesh.export(output_file)
95 | print("Simplification complete.")
96 |
97 |
98 | def load_config(config_path):
99 | import yaml
100 |
101 | with open(config_path, "r") as file:
102 | config = yaml.safe_load(file)
103 | return config
104 |
105 |
106 | def main():
107 | args = parse_args()
108 |
109 | simplify_mesh(
110 | input_file=args.input_file,
111 | output_file=args.output_file,
112 | model_checkpoint=args.model_checkpoint,
113 | hidden_dim=args.hidden_dim,
114 | edge_hidden_dim=args.edge_hidden_dim,
115 | device=args.device,
116 | )
117 |
118 |
119 | if __name__ == "__main__":
120 | main()
121 |
--------------------------------------------------------------------------------
/scripts/preprocess_data.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import networkx as nx
4 | import trimesh
5 | from tqdm import tqdm
6 |
7 | from neural_mesh_simplification.data import MeshSimplificationDataset
8 | from neural_mesh_simplification.data.dataset import load_mesh, preprocess_mesh
9 |
10 |
11 | def preprocess_dataset(
12 | input_dir,
13 | output_dir,
14 | pre_process=True,
15 | min_components=1,
16 | max_components=1,
17 | print_stats=False,
18 | ):
19 | dataset = MeshSimplificationDataset(data_dir=input_dir)
20 |
21 | for idx in tqdm(range(len(dataset)), desc="Processing meshes"):
22 | file_path = os.path.join(dataset.data_dir, dataset.file_list[idx])
23 |
24 | mesh = load_mesh(file_path)
25 |
26 | if pre_process:
27 | mesh = preprocess_mesh(mesh)
28 |
29 | if mesh is not None:
30 | face_adjacency = trimesh.graph.face_adjacency(mesh.faces)
31 |
32 | G = nx.Graph()
33 | G.add_edges_from(face_adjacency)
34 |
35 | components = list(nx.connected_components(G))
36 |
37 | num_components = len(components)
38 | num_vertices = len(mesh.vertices)
39 | num_faces = len(mesh.faces)
40 |
41 | if num_components < min_components or num_components > max_components:
42 | print(f"Skipping mesh {idx}: {dataset.file_list[idx]}")
43 | print(f" Connected components: {num_components}")
44 | print()
45 | continue
46 |
47 | if print_stats:
48 | print(f"Mesh {idx}: {dataset.file_list[idx]}")
49 | print(f" Connected components: {num_components}")
50 | print(f" Vertices: {num_vertices}")
51 | print(f" Faces: {num_faces}")
52 | print(f" Is watertight: {mesh.is_watertight}")
53 | print(f" Volume: {mesh.volume}")
54 | print(f" Surface area: {mesh.area}")
55 |
56 |                 edge_groups = trimesh.grouping.group_rows(mesh.edges_sorted)  # group identical edges
57 |                 print(f"  Number of non-manifold edges: {sum(1 for g in edge_groups if len(g) > 2)}")
58 | print()
59 |
60 | output_file = os.path.join(output_dir, dataset.file_list[idx])
61 | mesh.export(output_file.replace(".ply", ".stl"))
62 | else:
63 | print(f"Failed to load mesh {idx}: {dataset.file_list[idx]}")
64 | print()
65 |
66 | print("Finished processing all meshes.")
67 |
68 |
69 | if __name__ == "__main__":
70 | if not os.path.exists("data/raw"):
71 | raise FileNotFoundError(
72 | "The 'data/raw' directory does not exist. Please download the dataset first."
73 | )
74 |
75 | os.makedirs("data/processed", exist_ok=True)
76 | preprocess_dataset("data/raw", "data/processed")
77 |
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 |
5 | from neural_mesh_simplification.trainer import Trainer
6 |
7 | script_dir = os.path.dirname(os.path.abspath(__file__))
8 | default_config_path = os.path.join(script_dir, "../configs/default.yaml")
9 |
10 |
11 | def parse_args():
12 | parser = argparse.ArgumentParser(
13 | description="Train the Neural Mesh Simplification model."
14 | )
15 | parser.add_argument(
16 | "--data-path",
17 | type=str,
18 | required=False,
19 | help="Path to the training data directory.",
20 | )
21 | parser.add_argument(
22 | "--config",
23 | type=str,
24 | required=False,
25 | help="Path to the training configuration file.",
26 | )
27 | parser.add_argument(
28 | "--checkpoint-dir",
29 | type=str,
30 | default="checkpoints",
31 | help="Directory to save model checkpoints.",
32 | )
33 | parser.add_argument(
34 | "--resume",
35 | type=str,
36 | default=None,
37 | help="Path to a checkpoint to resume training from.",
38 | )
39 | parser.add_argument("--debug", action="store_true", help="Show debug logs")
40 | parser.add_argument(
41 | "--monitor", action="store_true", help="Monitor CPU and memory usage"
42 | )
43 | return parser.parse_args()
44 |
45 |
46 | def load_config(config_path):
47 | import yaml
48 |
49 | with open(config_path, "r") as file:
50 | config = yaml.safe_load(file)
51 | return config
52 |
53 |
54 | def main():
55 | args = parse_args()
56 |
57 | config_path = args.config if args.config else default_config_path
58 |
59 | config = load_config(config_path)
60 |
61 | if args.data_path:
62 | config["data"]["data_dir"] = args.data_path
63 |
64 | if args.checkpoint_dir:
65 | config["training"]["checkpoint_dir"] = args.checkpoint_dir
66 |
67 | if not os.path.exists(args.checkpoint_dir):
68 | os.makedirs(args.checkpoint_dir)
69 |
70 | if args.debug:
71 | logging.basicConfig(level=logging.DEBUG)
72 | else:
73 | logging.basicConfig(level=logging.INFO)
74 |
75 | if args.monitor:
76 | config["monitor_resources"] = True
77 |
78 | trainer = Trainer(config)
79 |
80 | if args.resume:
81 | trainer.load_checkpoint(args.resume)
82 |
83 | try:
84 | trainer.train()
85 | except Exception as e:
86 | trainer.handle_error(e)
87 | trainer.save_training_state(
88 | os.path.join(config["training"]["checkpoint_dir"], "training_state.pth")
89 | )
90 |
91 |
92 | if __name__ == "__main__":
93 | main()
94 |
--------------------------------------------------------------------------------
/src/neural_mesh_simplification/__init__.py:
--------------------------------------------------------------------------------
1 | from .api.neural_mesh_simplifier import NeuralMeshSimplifier
2 |
--------------------------------------------------------------------------------
/src/neural_mesh_simplification/api/__init__.py:
--------------------------------------------------------------------------------
1 | from .neural_mesh_simplifier import NeuralMeshSimplifier
2 |
--------------------------------------------------------------------------------
/src/neural_mesh_simplification/api/neural_mesh_simplifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import trimesh
3 | from torch_geometric.data import Data
4 |
5 | from ..data.dataset import preprocess_mesh, mesh_to_tensor
6 | from ..models import NeuralMeshSimplification
7 |
8 |
9 | class NeuralMeshSimplifier:
10 | def __init__(
11 | self,
12 | input_dim,
13 | hidden_dim,
14 | edge_hidden_dim, # Separate hidden dim for edge predictor
15 | num_layers,
16 | k,
17 | edge_k,
18 | target_ratio,
19 | ):
20 | self.input_dim = input_dim
21 | self.hidden_dim = hidden_dim
22 | self.edge_hidden_dim = edge_hidden_dim
23 | self.num_layers = num_layers
24 | self.k = k
25 | self.edge_k = edge_k
26 | self.target_ratio = target_ratio
27 | self.model = self._build_model()
28 |
29 | @classmethod
30 | def using_model(
31 | cls, at_path: str, hidden_dim: int, edge_hidden_dim: int, map_location: str
32 | ):
33 | instance = cls(
34 | input_dim=3,
35 | hidden_dim=hidden_dim,
36 | edge_hidden_dim=edge_hidden_dim,
37 | num_layers=3,
38 | k=15,
39 | edge_k=15,
40 | target_ratio=0.5,
41 | )
42 | instance._load_model(at_path, map_location)
43 | return instance
44 |
45 | def _build_model(self):
46 | return NeuralMeshSimplification(
47 | input_dim=self.input_dim,
48 | hidden_dim=self.hidden_dim,
49 | edge_hidden_dim=self.edge_hidden_dim,
50 | num_layers=self.num_layers,
51 | k=self.k,
52 | edge_k=self.edge_k,
53 | target_ratio=self.target_ratio,
54 | )
55 |
56 | def _load_model(self, checkpoint_path: str, map_location: str):
57 | self.model.load_state_dict(
58 | torch.load(checkpoint_path, map_location=map_location)
59 | )
60 |
61 | def simplify(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
62 | # Preprocess the mesh (e.g. normalize, center)
63 | preprocessed_mesh: trimesh.Trimesh = preprocess_mesh(mesh)
64 |
65 | # Convert to a tensor
66 | tensor: Data = mesh_to_tensor(preprocessed_mesh)
67 | model_output = self.model(tensor)
68 |
69 |         vertices = model_output["sampled_vertices"].detach().cpu().numpy()
70 |         faces = model_output["simplified_faces"].cpu().numpy()
71 |         edges = model_output["edge_index"].t().cpu().numpy()  # Transpose to get (n, 2) shape
72 |
73 | return trimesh.Trimesh(vertices=vertices, faces=faces, edges=edges)
74 |
--------------------------------------------------------------------------------
/src/neural_mesh_simplification/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import MeshSimplificationDataset
2 |
--------------------------------------------------------------------------------
/src/neural_mesh_simplification/data/dataset.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 | from typing import Optional
4 |
5 | import numpy as np
6 | import torch
7 | import trimesh
8 | from torch.utils.data import Dataset
9 | from torch_geometric.data import Data
10 | from trimesh import Geometry, Trimesh
11 |
12 | from ..utils import build_graph_from_mesh
13 |
14 |
15 | class MeshSimplificationDataset(Dataset):
16 | def __init__(
17 | self,
18 | data_dir: str,
19 | preprocess: bool = False,
20 | transform: Optional[callable] = None,
21 | ):
22 | self.data_dir = data_dir
23 | self.preprocess = preprocess
24 | self.transform = transform
25 | self.file_list = self._get_file_list()
26 |
27 | def _get_file_list(self):
28 | return [
29 | f
30 | for f in os.listdir(self.data_dir)
31 | if f.endswith(".ply") or f.endswith(".obj") or f.endswith(".stl")
32 | ]
33 |
34 | def __len__(self):
35 | return len(self.file_list)
36 |
37 | def __getitem__(self, idx):
38 | file_path = os.path.join(self.data_dir, self.file_list[idx])
39 | mesh = load_mesh(file_path)
40 |
41 | if self.preprocess:
42 | mesh = preprocess_mesh(mesh)
43 |
44 | if self.transform:
45 | mesh = self.transform(mesh)
46 |
47 | data = mesh_to_tensor(mesh)
48 | gc.collect()
49 | return data
50 |
51 |
52 | def load_mesh(file_path: str) -> Geometry | list[Geometry] | None:
53 | """Load a mesh from file."""
54 | try:
55 | mesh = trimesh.load(file_path)
56 | return mesh
57 | except Exception as e:
58 | print(f"Error loading mesh {file_path}: {e}")
59 | return None
60 |
61 |
62 | def preprocess_mesh(mesh: trimesh.Trimesh) -> Trimesh | None:
63 | """Preprocess a mesh (e.g., normalize, center)."""
64 | if mesh is None:
65 | return None
66 |
67 | # Center the mesh
68 | mesh.vertices -= mesh.vertices.mean(axis=0)
69 |
70 | # Scale to unit cube
71 | max_dim = np.max(mesh.vertices.max(axis=0) - mesh.vertices.min(axis=0))
72 | mesh.vertices /= max_dim
73 |
74 | return mesh
75 |
76 |
77 | def augment_mesh(mesh: trimesh.Trimesh) -> Trimesh | None:
78 | """Apply data augmentation to a mesh."""
79 | if mesh is None:
80 | return None
81 |
82 | # Example: Random rotation
83 | rotation = trimesh.transformations.random_rotation_matrix()
84 | mesh.apply_transform(rotation)
85 |
86 | return mesh
87 |
88 |
89 | def mesh_to_tensor(mesh: trimesh.Trimesh) -> Data:
90 | """Convert a mesh to tensor representation including graph structure."""
91 | if mesh is None:
92 | return None
93 |
94 | # Convert vertices and faces to tensors
95 | vertices_tensor = torch.tensor(mesh.vertices, dtype=torch.float32)
96 | faces_tensor = torch.tensor(mesh.faces, dtype=torch.long).t()
97 |
98 | # Build graph structure
99 | G = build_graph_from_mesh(mesh)
100 |
101 | # Create edge index tensor
102 | edge_index = torch.tensor(list(G.edges), dtype=torch.long).t().contiguous()
103 |
104 | # Create Data object
105 | data = Data(
106 | x=vertices_tensor,
107 | pos=vertices_tensor,
108 | edge_index=edge_index,
109 | face=faces_tensor,
110 | num_nodes=len(mesh.vertices),
111 | )
112 |
113 | return data
114 |
--------------------------------------------------------------------------------
/src/neural_mesh_simplification/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .chamfer_distance_loss import ProbabilisticChamferDistanceLoss
2 | from .surface_distance_loss import ProbabilisticSurfaceDistanceLoss
3 | from .triangle_collision_loss import TriangleCollisionLoss
4 | from .edge_crossing_loss import EdgeCrossingLoss
5 | from .overlapping_triangles_loss import OverlappingTrianglesLoss
6 | from .combined_loss import CombinedMeshSimplificationLoss
7 |
--------------------------------------------------------------------------------
/src/neural_mesh_simplification/losses/chamfer_distance_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class ProbabilisticChamferDistanceLoss(nn.Module):
6 | def __init__(self):
7 | super(ProbabilisticChamferDistanceLoss, self).__init__()
8 |
9 | def forward(self, P, Ps, probabilities):
10 | """
11 | Compute the Probabilistic Chamfer Distance loss.
12 |
13 | Args:
14 | P (torch.Tensor): Original point cloud, shape (N, 3)
15 | Ps (torch.Tensor): Sampled point cloud, shape (M, 3)
16 | probabilities (torch.Tensor): Sampling probabilities for Ps, shape (M,)
17 |
18 | Returns:
19 | torch.Tensor: Scalar loss value
20 | """
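        # Sketch of the computation below: with p_j the sampling probability of
        # Ps[j] and nn(i) the index of the Ps point nearest to P[i],
        #   loss = sum_j p_j * min_i ||Ps[j] - P[i]|| + sum_i p_nn(i) * min_j ||P[i] - Ps[j]||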
21 | if P.size(0) == 0 or Ps.size(0) == 0:
22 | return torch.tensor(0.0, device=P.device, requires_grad=True)
23 |
24 | # Ensure inputs are on the same device
25 | Ps = Ps.to(P.device)
26 | probabilities = probabilities.to(P.device)
27 |
28 | # Compute distances from Ps to P
29 | dist_s_to_o = self.compute_minimum_distances(Ps, P)
30 |
31 | # Compute distances from P to Ps
32 | dist_o_to_s, min_indices = self.compute_minimum_distances(
33 | P, Ps, return_indices=True
34 | )
35 |
36 | # Weight distances by probabilities
37 | weighted_dist_s_to_o = dist_s_to_o * probabilities
38 | weighted_dist_o_to_s = dist_o_to_s * probabilities[min_indices]
39 |
40 | # Sum up the weighted distances
41 | loss = weighted_dist_s_to_o.sum() + weighted_dist_o_to_s.sum()
42 |
43 | return loss
44 |
45 | def compute_minimum_distances(self, source, target, return_indices=False):
46 | """
47 | Compute the minimum distances from each point in source to target.
48 |
49 | Args:
50 | source (torch.Tensor): Source point cloud, shape (N, 3)
51 | target (torch.Tensor): Target point cloud, shape (M, 3)
52 | return_indices (bool): If True, also return indices of minimum distances
53 |
54 | Returns:
55 | torch.Tensor: Minimum distances, shape (N,)
56 | torch.Tensor: Indices of minimum distances (if return_indices is True)
57 | """
58 | # Compute pairwise distances
59 | distances = torch.cdist(source, target)
60 |
61 | # Find minimum distances
62 | min_distances, min_indices = distances.min(dim=1)
63 |
64 | if return_indices:
65 | return min_distances, min_indices
66 | else:
67 | return min_distances
68 |
--------------------------------------------------------------------------------
/src/neural_mesh_simplification/losses/combined_loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch import device
3 |
4 | from . import (
5 | ProbabilisticChamferDistanceLoss,
6 | ProbabilisticSurfaceDistanceLoss,
7 | TriangleCollisionLoss,
8 | EdgeCrossingLoss,
9 | OverlappingTrianglesLoss,
10 | )
11 |
12 |
13 | class CombinedMeshSimplificationLoss(nn.Module):
14 | def __init__(
15 | self,
16 | lambda_c: float = 1.0,
17 | lambda_e: float = 1.0,
18 | lambda_o: float = 1.0,
19 | device=device("cpu"),
20 | ):
21 | super().__init__()
22 | self.device = device
23 | self.prob_chamfer_loss = ProbabilisticChamferDistanceLoss().to(self.device)
24 | self.prob_surface_loss = ProbabilisticSurfaceDistanceLoss().to(self.device)
25 | self.collision_loss = TriangleCollisionLoss().to(self.device)
26 | self.edge_crossing_loss = EdgeCrossingLoss().to(self.device)
27 | self.overlapping_triangles_loss = OverlappingTrianglesLoss().to(self.device)
28 | self.lambda_c = lambda_c
29 | self.lambda_e = lambda_e
30 | self.lambda_o = lambda_o
31 |
32 | def forward(self, original_data, simplified_data):
33 | original_x = (
34 | original_data["pos"] if "pos" in original_data else original_data["x"]
35 | ).to(self.device)
36 | original_face = original_data["face"].to(self.device)
37 |
38 | sampled_vertices = simplified_data["sampled_vertices"].to(self.device)
39 | sampled_probs = simplified_data["sampled_probs"].to(self.device)
40 | sampled_faces = simplified_data["simplified_faces"].to(self.device)
41 | face_probs = simplified_data["face_probs"].to(self.device)
42 |
43 | chamfer_loss = self.prob_chamfer_loss(
44 | original_x, sampled_vertices, sampled_probs
45 | )
46 |
47 | del sampled_probs
48 |
49 | surface_loss = self.prob_surface_loss(
50 | original_x,
51 | original_face,
52 | sampled_vertices,
53 | sampled_faces,
54 | face_probs,
55 | )
56 |
57 | del original_x
58 | del original_face
59 |
60 | collision_loss = self.collision_loss(
61 | sampled_vertices,
62 | sampled_faces,
63 | face_probs,
64 | )
65 | edge_crossing_loss = self.edge_crossing_loss(
66 | sampled_vertices, sampled_faces, face_probs
67 | )
68 |
69 | del face_probs
70 |
71 | overlapping_triangles_loss = self.overlapping_triangles_loss(
72 | sampled_vertices, sampled_faces
73 | )
74 |
75 | del sampled_vertices
76 | del sampled_faces
77 |
78 | total_loss = (
79 | chamfer_loss
80 | + surface_loss
81 | + self.lambda_c * collision_loss
82 | + self.lambda_e * edge_crossing_loss
83 | + self.lambda_o * overlapping_triangles_loss
84 | )
85 |
86 | return total_loss
87 |
--------------------------------------------------------------------------------
/src/neural_mesh_simplification/losses/edge_crossing_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch_cluster import knn
4 |
5 |
6 | class EdgeCrossingLoss(nn.Module):
7 | def __init__(self, k: int = 20):
8 | super().__init__()
9 | self.k = k # Number of nearest triangles to consider
10 |
11 | def forward(
12 | self, vertices: torch.Tensor, faces: torch.Tensor, face_probs: torch.Tensor
13 | ) -> torch.Tensor:
14 | # If no faces, return zero loss
15 | if faces.shape[0] == 0:
16 | return torch.tensor(0.0, device=vertices.device)
17 | # Ensure face_probs matches the number of faces
18 | if face_probs.shape[0] > faces.shape[0]:
19 | face_probs = face_probs[: faces.shape[0]]
20 | elif face_probs.shape[0] < faces.shape[0]:
21 | # Pad with zeros if we have fewer probabilities than faces
22 | padding = torch.zeros(
23 | faces.shape[0] - face_probs.shape[0], device=face_probs.device
24 | )
25 | face_probs = torch.cat([face_probs, padding])
26 |
27 | # 1. Find k-nearest triangles for each triangle
28 | nearest_triangles = self.find_nearest_triangles(vertices, faces)
29 |
30 | # 2. Detect edge crossings between nearby triangles
31 | crossings = self.detect_edge_crossings(vertices, faces, nearest_triangles)
32 |
33 | # 3. Calculate loss
34 | loss = self.calculate_loss(crossings, face_probs)
35 |
36 | return loss
37 |
38 | def find_nearest_triangles(
39 | self, vertices: torch.Tensor, faces: torch.Tensor
40 | ) -> torch.Tensor:
41 | # Compute triangle centroids
42 | centroids = vertices[faces].mean(dim=1)
43 |
44 | # Use knn to find nearest triangles
45 | k = min(
46 | self.k, centroids.shape[0]
47 | ) # Ensure k is not larger than the number of centroids
48 | _, indices = knn(centroids, centroids, k=k)
49 |
50 | # Reshape indices to [num_faces, k]
51 | indices = indices.view(centroids.shape[0], k)
52 |
53 | # Remove self-connections (triangles cannot be their own neighbor)
54 | nearest = []
55 | for i in range(indices.shape[0]):
56 | neighbors = indices[i][indices[i] != i]
57 | if len(neighbors) == 0:
58 | nearest.append(torch.empty(0, dtype=torch.long))
59 | else:
60 | nearest.append(neighbors[: self.k - 1])
61 |
62 | # Return tensor with consistent shape
63 | if len(nearest) > 0 and all(len(n) == 0 for n in nearest):
64 | nearest = torch.empty((len(nearest), 0), dtype=torch.long)
65 | else:
66 | nearest = torch.stack(nearest)
67 | return nearest
68 |
69 | def detect_edge_crossings(
70 | self,
71 | vertices: torch.Tensor,
72 | faces: torch.Tensor,
73 | nearest_triangles: torch.Tensor,
74 | ) -> torch.Tensor:
75 | def edge_vectors(triangles):
76 | # Extracts the edges from a triangle defined by vertex indices
77 | return vertices[triangles[:, [1, 2, 0]]] - vertices[triangles]
78 |
79 | edges = edge_vectors(faces)
80 | crossings = torch.zeros(faces.shape[0], device=vertices.device)
81 |
82 |         for i in range(faces.shape[0]):
83 |             neighbor_faces = faces[nearest_triangles[i]]
84 |             neighbor_starts = vertices[neighbor_faces]  # [M, 3, 3] edge start points
85 |             neighbor_edges = edge_vectors(neighbor_faces)  # [M, 3, 3] edge vectors
86 |             for j in range(3):
87 |                 d1 = edges[i, j].expand_as(neighbor_edges)  # edge j of face i
88 |                 diff = neighbor_starts - vertices[faces[i, j]]  # starts relative to edge j's origin
89 |                 n = torch.cross(d1, neighbor_edges, dim=-1)
90 |                 n_sq = torch.sum(n * n, dim=-1)
91 |                 # Closest-approach parameters t (on edge j) and u (on the neighbor edge)
92 |                 t = torch.sum(torch.cross(diff, neighbor_edges, dim=-1) * n, dim=-1) / (n_sq + 1e-8)
93 |                 u = torch.sum(torch.cross(diff, d1, dim=-1) * n, dim=-1) / (n_sq + 1e-8)
94 |                 coplanar = torch.sum(diff * n, dim=-1).abs() < 1e-6 * (n_sq.sqrt() + 1e-8)
95 |                 mask = (n_sq > 1e-12) & coplanar & (t > 1e-6) & (t < 1 - 1e-6) & (u > 1e-6) & (u < 1 - 1e-6)
96 |                 crossings[i] += mask.sum()
97 |
98 | return crossings
99 |
100 | def calculate_loss(
101 | self, crossings: torch.Tensor, face_probs: torch.Tensor
102 | ) -> torch.Tensor:
103 | # Weighted sum of crossings by triangle probabilities
104 | num_faces = face_probs.shape[0]
105 | return torch.sum(face_probs * crossings, dtype=torch.float32) / num_faces
106 |
--------------------------------------------------------------------------------
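
A minimal usage sketch for EdgeCrossingLoss; the toy mesh and probabilities below are made up for illustration. Two triangles that merely share an edge are collinear along that edge rather than crossing, so they should contribute essentially no crossings:

import torch

from neural_mesh_simplification.losses import EdgeCrossingLoss

loss_fn = EdgeCrossingLoss(k=2)
vertices = torch.tensor(
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
)
faces = torch.tensor([[0, 1, 2], [1, 3, 2]])
face_probs = torch.tensor([0.9, 0.8])

loss = loss_fn(vertices, faces, face_probs)  # 0-dim tensor, differentiable w.r.t. face_probs
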
/src/neural_mesh_simplification/losses/overlapping_triangles_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class OverlappingTrianglesLoss(nn.Module):
6 | def __init__(self, num_samples: int = 10, k: int = 5):
7 | """
8 | Initializes the OverlappingTrianglesLoss.
9 |
10 | Args:
11 | num_samples (int): The number of points to sample from each triangle.
12 | k (int): The number of nearest triangles to consider for overlap checking.
13 | """
14 | super().__init__()
15 | self.num_samples = num_samples # Number of points to sample from each triangle
16 | self.k = k # Number of nearest triangles to consider
17 |
18 | def forward(self, vertices: torch.Tensor, faces: torch.Tensor):
19 |
20 | # If no faces, return zero loss
21 | if faces.shape[0] == 0:
22 | return torch.tensor(0.0, device=vertices.device)
23 |
24 | # 1. Sample points from each triangle
25 | sampled_points, point_face_indices = self.sample_points_from_triangles(
26 | vertices, faces
27 | )
28 |
29 | # 2. Find k-nearest triangles for each point
30 | nearest_triangles = self.find_nearest_triangles(sampled_points, vertices, faces)
31 |
32 | # 3. Detect overlaps and calculate the loss
33 | overlap_penalty = self.calculate_overlap_loss(
34 | sampled_points, vertices, faces, nearest_triangles, point_face_indices
35 | )
36 |
37 | return overlap_penalty
38 |
39 | def sample_points_from_triangles(
40 | self, vertices: torch.Tensor, faces: torch.Tensor
41 |     ) -> tuple[torch.Tensor, torch.Tensor]:
42 | """
43 | Samples points from each triangle in the mesh.
44 |
45 | Args:
46 | vertices (torch.Tensor): The vertex positions (V x 3).
47 | faces (torch.Tensor): The indices of the vertices that make up each triangle (F x 3).
48 |
49 | Returns:
50 | torch.Tensor: Sampled points (F * num_samples x 3).
51 |             torch.Tensor: Index mapping from points to their original faces.
52 | """
53 | # Get vertices for all faces at once
54 | v0, v1, v2 = vertices[faces].unbind(1)
55 |
56 | # Generate random barycentric coordinates
57 | rand_shape = (faces.shape[0], self.num_samples, 1)
58 | u = torch.rand(rand_shape, device=vertices.device)
59 | v = torch.rand(rand_shape, device=vertices.device)
60 |
61 | # Adjust coordinates that sum > 1
62 | mask = (u + v) > 1
63 | u = torch.where(mask, 1 - u, u)
64 | v = torch.where(mask, 1 - v, v)
65 | w = 1 - u - v
66 |
67 | # Calculate the coordinates of the sampled points
68 | points = v0.unsqueeze(1) * w + v1.unsqueeze(1) * u + v2.unsqueeze(1) * v
69 |
70 | # Create index mapping from points to their original faces
71 | point_face_indices = torch.arange(faces.shape[0], device=vertices.device)
72 | point_face_indices = point_face_indices.repeat_interleave(self.num_samples)
73 |
74 | # Reshape to a (F * num_samples x 3) tensor
75 | points = points.reshape(-1, 3)
76 |
77 | return points, point_face_indices
78 |
79 | def find_nearest_triangles(
80 | self, sampled_points: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
81 | ) -> torch.Tensor:
82 | """
83 | Finds the k-nearest triangles for each sampled point.
84 |
85 | Args:
86 | sampled_points (torch.Tensor): Sampled points from triangles (N x 3).
87 | vertices (torch.Tensor): The vertex positions (V x 3).
88 | faces (torch.Tensor): The indices of the vertices that make up each triangle (F x 3).
89 |
90 | Returns:
91 | torch.Tensor: Indices of the k-nearest triangles for each sampled point (N x k).
92 | """
93 | # Compute triangle centroids
94 | centroids = vertices[faces].mean(dim=1)
95 |
96 | # Adjust k to be no larger than the number of triangles
97 | k = min(self.k, faces.shape[0])
98 | if k == 0:
99 | # Return empty tensor if no triangles
100 | return torch.empty(
101 | (sampled_points.shape[0], 0),
102 | dtype=torch.long,
103 | device=sampled_points.device,
104 | )
105 |
106 |         # Find the k nearest triangle centroids via pairwise distances (cdist + topk)
107 | distances = torch.cdist(sampled_points, centroids)
108 | _, indices = distances.topk(k, dim=1, largest=False)
109 |
110 | return indices
111 |
112 | def calculate_overlap_loss(
113 | self,
114 | sampled_points: torch.Tensor,
115 | vertices: torch.Tensor,
116 | faces: torch.Tensor,
117 | nearest_triangles: torch.Tensor,
118 | point_face_indices: torch.Tensor,
119 | ) -> torch.Tensor:
120 | """
121 | Calculates the overlap loss by checking if sampled points belong to multiple triangles.
122 |
123 | Args:
124 | sampled_points (torch.Tensor): Sampled points from triangles (N x 3).
125 | vertices (torch.Tensor): The vertex positions (V x 3).
126 | faces (torch.Tensor): The indices of the vertices that make up each triangle (F x 3).
127 | nearest_triangles (torch.Tensor): Indices of the k-nearest triangles for each sampled point (N x k).
128 | point_face_indices (torch.Tensor): Index mapping from points to their original faces.
129 |
130 | Returns:
131 | torch.Tensor: The overlap penalty loss.
132 | """
133 |
134 | # Reshape for broadcasting
135 | points_expanded = sampled_points.unsqueeze(1) # [N, 1, 3]
136 | nearest_faces = faces[nearest_triangles] # [N, K, 3]
137 |
138 | # Get vertices for all nearest triangles
139 | v0 = vertices[nearest_faces[..., 0]] # [N, K, 3]
140 | v1 = vertices[nearest_faces[..., 1]]
141 | v2 = vertices[nearest_faces[..., 2]]
142 |
143 | # Calculate edges
144 | edge1 = v1 - v0 # [N, K, 3]
145 | edge2 = v2 - v0
146 |
147 | # Calculate normals
148 | normals = torch.linalg.cross(edge1, edge2) # [N, K, 3]
149 | normal_lengths = torch.norm(normals, dim=2, keepdim=True)
150 | normals = normals / (normal_lengths + 1e-8)
151 |
152 |         del normal_lengths  # edge1/edge2 are still needed below
153 |
154 |         # Calculate barycentric coordinates for all points at once
155 |         p_v0 = points_expanded - v0  # [N, K, 3]
156 |
157 |         # Edge dot products for the standard barycentric point-in-triangle test
158 |         dot00 = torch.sum(edge1 * edge1, dim=2)  # [N, K]
159 |         dot01 = torch.sum(edge1 * edge2, dim=2)
160 |         dot11 = torch.sum(edge2 * edge2, dim=2)
161 |         dot0p = torch.sum(edge1 * p_v0, dim=2)
162 |         dot1p = torch.sum(edge2 * p_v0, dim=2)
163 |         del p_v0, normals, edge1, edge2
164 |
165 |         # Barycentric coordinates of each point's projection onto each triangle
166 |         denom = dot00 * dot11 - dot01 * dot01
167 |         u = (dot11 * dot0p - dot01 * dot1p) / (denom + 1e-8)
168 |         v = (dot00 * dot1p - dot01 * dot0p) / (denom + 1e-8)
169 |
170 |         del dot00, dot01, dot11, dot0p, dot1p, denom
171 |
172 | # Check if points are inside triangles
173 | inside_mask = (u >= 0) & (v >= 0) & (u + v <= 1)
174 |
175 | # Don't count overlap with source triangle
176 | source_mask = nearest_triangles == point_face_indices.unsqueeze(1)
177 | inside_mask = inside_mask & ~source_mask
178 |
179 | # Calculate areas only for inside points
180 | areas = torch.zeros_like(inside_mask, dtype=torch.float32)
181 | where_inside = torch.where(inside_mask)
182 |
183 | if where_inside[0].numel() > 0:
184 | # Calculate areas only for points inside triangles
185 | relevant_v0 = v0[where_inside]
186 | relevant_v1 = v1[where_inside]
187 | relevant_v2 = v2[where_inside]
188 |
189 | # Calculate areas using cross product
190 | cross_prod = torch.linalg.cross(
191 | relevant_v1 - relevant_v0, relevant_v2 - relevant_v0
192 | )
193 | areas[where_inside] = 0.5 * torch.norm(cross_prod, dim=1)
194 |
195 | del cross_prod, relevant_v0, relevant_v1, relevant_v2
196 |
197 | del v0, v1, v2, inside_mask, source_mask
198 |
199 | # Sum up the overlap penalty
200 | overlap_penalty = areas.sum()
201 |
202 | return overlap_penalty
203 |
--------------------------------------------------------------------------------
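
A quick sanity sketch for OverlappingTrianglesLoss (assuming the class is exported from neural_mesh_simplification.losses like the other losses): duplicating a face guarantees that every point sampled from one copy lies inside the other, so the penalty should be strictly positive:

import torch

from neural_mesh_simplification.losses import OverlappingTrianglesLoss

loss_fn = OverlappingTrianglesLoss(num_samples=10, k=2)
vertices = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
faces = torch.tensor([[0, 1, 2], [0, 1, 2]])  # same triangle twice

penalty = loss_fn(vertices, faces)  # positive: each sample overlaps the duplicate face
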
/src/neural_mesh_simplification/losses/surface_distance_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch_cluster import knn
4 |
5 |
6 | class ProbabilisticSurfaceDistanceLoss(nn.Module):
7 | def __init__(self, k: int = 3, num_samples: int = 100, epsilon: float = 1e-8):
8 | super().__init__()
9 | self.k = k
10 | self.num_samples = num_samples
11 | self.epsilon = epsilon
12 |
13 | def forward(
14 | self,
15 | original_vertices: torch.Tensor,
16 | original_faces: torch.Tensor,
17 | simplified_vertices: torch.Tensor,
18 | simplified_faces: torch.Tensor,
19 | face_probabilities: torch.Tensor,
20 | ) -> torch.Tensor:
21 | if original_vertices.shape[0] == 0 or simplified_vertices.shape[0] == 0:
22 | return torch.tensor(0.0, device=original_vertices.device)
23 |
24 | # Pad face probabilities once for both terms
25 | face_probabilities = torch.nn.functional.pad(
26 | face_probabilities,
27 | (0, max(0, simplified_faces.shape[0] - face_probabilities.shape[0])),
28 | )[: simplified_faces.shape[0]]
29 |
30 | forward_term = self.compute_forward_term(
31 | original_vertices,
32 | original_faces,
33 | simplified_vertices,
34 | simplified_faces,
35 | face_probabilities,
36 | )
37 |
38 | reverse_term = self.compute_reverse_term(
39 | original_vertices,
40 | original_faces,
41 | simplified_vertices,
42 | simplified_faces,
43 | face_probabilities,
44 | )
45 |
46 | total_loss = forward_term + reverse_term
47 | return total_loss
48 |
49 | def compute_forward_term(
50 | self,
51 | original_vertices: torch.Tensor,
52 | original_faces: torch.Tensor,
53 | simplified_vertices: torch.Tensor,
54 | simplified_faces: torch.Tensor,
55 | face_probabilities: torch.Tensor,
56 | ) -> torch.Tensor:
57 | # If there are no faces, return zero loss
58 | if simplified_faces.shape[0] == 0:
59 | return torch.tensor(0.0, device=original_vertices.device)
60 |
61 | simplified_barycenters = self.compute_barycenters(
62 | simplified_vertices, simplified_faces
63 | )
64 | original_barycenters = self.compute_barycenters(
65 | original_vertices, original_faces
66 | )
67 |
68 | distances = self.compute_squared_distances(
69 | simplified_barycenters, original_barycenters
70 | )
71 |
72 | min_distances, _ = distances.min(dim=1)
73 |
74 | del distances # Free memory
75 |
76 | # Compute total loss with probability penalty
77 | total_loss = (face_probabilities * min_distances).sum()
78 | probability_penalty = 1e-4 * (1.0 - face_probabilities).sum()
79 |
80 | del min_distances # Free memory
81 |
82 | return total_loss + probability_penalty
83 |
84 | def compute_reverse_term(
85 | self,
86 | original_vertices: torch.Tensor,
87 | original_faces: torch.Tensor,
88 | simplified_vertices: torch.Tensor,
89 | simplified_faces: torch.Tensor,
90 | face_probabilities: torch.Tensor,
91 | ) -> torch.Tensor:
92 | # If there are no faces, return zero loss
93 | if simplified_faces.shape[0] == 0:
94 | return torch.tensor(0.0, device=original_vertices.device)
95 |
96 | # If meshes are identical, reverse term should be zero
97 | if torch.equal(original_vertices, simplified_vertices) and torch.equal(
98 | original_faces, simplified_faces
99 | ):
100 | return torch.tensor(0.0, device=original_vertices.device)
101 |
102 | # Step 1: Sample points from the simplified mesh
103 | sampled_points = self.sample_points_from_triangles(
104 | simplified_vertices, simplified_faces, self.num_samples
105 | )
106 |
107 | # Step 2: Compute the minimum distance from each sampled point to the original mesh
108 | distances = self.compute_min_distances_to_original(
109 | sampled_points, original_vertices
110 | )
111 |
112 | # Normalize and scale distances
113 | max_dist = distances.max() + self.epsilon
114 | scaled_distances = (distances / max_dist) * 0.1
115 |
116 | del distances # Free memory
117 |
118 | # Reshape face probabilities to match the sampled points
119 | face_probs_expanded = face_probabilities.repeat_interleave(self.num_samples)
120 |
121 | # Compute weighted distances
122 | reverse_term = (face_probs_expanded * scaled_distances).sum()
123 |
124 | return reverse_term
125 |
126 | def sample_points_from_triangles(
127 | self, vertices: torch.Tensor, faces: torch.Tensor, num_samples: int
128 | ) -> torch.Tensor:
129 | """Vectorized point sampling from triangles"""
130 | num_faces = faces.shape[0]
131 | face_vertices = vertices[faces]
132 |
133 | # Generate random values for all samples at once
134 | sqrt_r1 = torch.sqrt(
135 | torch.rand(num_faces, num_samples, 1, device=vertices.device)
136 | )
137 | r2 = torch.rand(num_faces, num_samples, 1, device=vertices.device)
138 |
139 | # Compute barycentric coordinates
140 | a = 1 - sqrt_r1
141 | b = sqrt_r1 * (1 - r2)
142 | c = sqrt_r1 * r2
143 |
144 | # Compute samples using broadcasting
145 | samples = (
146 | a * face_vertices[:, None, 0]
147 | + b * face_vertices[:, None, 1]
148 | + c * face_vertices[:, None, 2]
149 | )
150 |
151 | del a, b, c, sqrt_r1, r2, face_vertices # Free memory
152 |
153 | return samples.reshape(-1, 3)
154 |
155 | def compute_min_distances_to_original(
156 | self, sampled_points: torch.Tensor, target_vertices: torch.Tensor
157 | ) -> torch.Tensor:
158 | """Efficient batch distance computation using KNN"""
159 | # Convert to float32 for KNN
160 | sp_float = sampled_points.float()
161 | tv_float = target_vertices.float()
162 |
163 |         # torch_cluster.knn returns (query, neighbor) index pairs, not distances
164 |         query_idx, neighbor_idx = knn(tv_float, sp_float, k=1)
165 |         distances = torch.norm(sp_float[query_idx] - tv_float[neighbor_idx], dim=1)
166 |         del sp_float, tv_float  # Free memory
167 |
168 |         return distances.view(-1).float()
169 |
170 | @staticmethod
171 | def compute_squared_distances(
172 | points1: torch.Tensor, points2: torch.Tensor
173 | ) -> torch.Tensor:
174 | """Compute squared distances efficiently using torch.cdist"""
175 |         return torch.cdist(points1, points2, p=2).pow(2).float()
176 |
177 | def compute_barycenters(
178 | self, vertices: torch.Tensor, faces: torch.Tensor
179 | ) -> torch.Tensor:
180 | return vertices[faces].mean(dim=1)
181 |
--------------------------------------------------------------------------------
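
A sanity-check sketch for ProbabilisticSurfaceDistanceLoss: with an identical "simplified" mesh and face probabilities of 1.0, the reverse term short-circuits to zero, the forward barycenter distances are zero, and the probability penalty vanishes, so the total should be near zero:

import torch

from neural_mesh_simplification.losses import ProbabilisticSurfaceDistanceLoss

loss_fn = ProbabilisticSurfaceDistanceLoss(k=3, num_samples=100)
vertices = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
faces = torch.tensor([[0, 1, 2]])
face_probs = torch.ones(1)

loss = loss_fn(vertices, faces, vertices, faces, face_probs)  # ~0 for identical meshes
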
/src/neural_mesh_simplification/losses/triangle_collision_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class TriangleCollisionLoss(nn.Module):
6 | def __init__(
7 | self, epsilon=1e-8, k=50, collision_threshold=1e-10, normal_threshold=0.99
8 | ):
9 | super().__init__()
10 | self.epsilon = epsilon
11 | self.k = k
12 | self.collision_threshold = collision_threshold
13 | self.normal_threshold = normal_threshold
14 |
15 | def forward(self, vertices, faces, face_probabilities):
16 | num_faces = faces.shape[0]
17 |
18 | if num_faces == 0:
19 | return torch.tensor(0.0, device=vertices.device)
20 |
21 | # Ensure face_probabilities matches the number of faces
22 | face_probabilities = torch.nn.functional.pad(
23 | face_probabilities, (0, max(0, num_faces - face_probabilities.shape[0]))
24 | )[:num_faces]
25 |
26 | v0, v1, v2 = vertices[faces].unbind(1)
27 |
28 | # Calculate face normals more efficiently
29 | edges1 = v1 - v0
30 | edges2 = v2 - v0
31 | face_normals = torch.linalg.cross(edges1, edges2)
32 | face_normal_lengths = torch.norm(face_normals, dim=1, keepdim=True)
33 | face_normals = face_normals / (face_normal_lengths + self.epsilon)
34 |
35 | del edges1, edges2, face_normal_lengths # No longer needed
36 |
37 | # Calculate centroids
38 | centroids = (v0 + v1 + v2) / 3
39 |
40 | # Find k nearest neighbors using squared distances
41 | diffs = centroids.unsqueeze(1) - centroids.unsqueeze(0)
42 | distances = torch.sum(diffs * diffs, dim=-1)
43 | del diffs # Large tensor no longer needed
44 |
45 | k = min(self.k, num_faces - 1)
46 | _, neighbors = torch.topk(distances, k=k + 1, largest=False)
47 | del distances # Large matrix no longer needed
48 | neighbors = neighbors[:, 1:]
49 |
50 | collision_count = torch.zeros(num_faces, device=vertices.device)
51 |
52 | for i in range(num_faces):
53 | nearby_faces = neighbors[i]
54 | nearby_v0, nearby_v1, nearby_v2 = (
55 | v0[nearby_faces],
56 | v1[nearby_faces],
57 | v2[nearby_faces],
58 | )
59 |
60 | collisions = self.check_triangle_intersection(
61 | v0[i],
62 | v1[i],
63 | v2[i],
64 | face_normals[i],
65 | nearby_v0,
66 | nearby_v1,
67 | nearby_v2,
68 | face_normals[nearby_faces],
69 | faces[i],
70 | faces[nearby_faces],
71 | centroids[i],
72 | centroids[nearby_faces],
73 | )
74 | collision_count[i] += collisions.sum()
75 |
76 | total_loss = torch.sum(face_probabilities * collision_count)
77 | return total_loss
78 |
79 | def check_triangle_intersection(
80 | self,
81 | v0,
82 | v1,
83 | v2,
84 | normal,
85 | nearby_v0,
86 | nearby_v1,
87 | nearby_v2,
88 | nearby_normals,
89 | face,
90 | nearby_faces,
91 | centroid,
92 | nearby_centroids,
93 | ):
94 |         # Check if triangle planes are nearly parallel (relaxed coplanarity test)
95 | normal_dot = torch.abs(torch.sum(normal * nearby_normals, dim=1))
96 | coplanar = normal_dot > self.normal_threshold
97 |
98 | # Check for triangle-triangle intersections
99 | intersections = torch.zeros(
100 | len(nearby_faces), dtype=torch.bool, device=v0.device
101 | )
102 |
103 | for i, (nv0, nv1, nv2) in enumerate(zip(nearby_v0, nearby_v1, nearby_v2)):
104 | if coplanar[i]:
105 | # For coplanar triangles, check distance between centroids
106 | dist = torch.norm(centroid - nearby_centroids[i])
107 | intersections[i] = dist < self.collision_threshold
108 | else:
109 | intersections[i] = self.check_triangle_triangle_intersection(
110 | v0, v1, v2, nv0, nv1, nv2
111 | )
112 |
113 | # Check if triangles are adjacent
114 | adjacent = torch.tensor(
115 | [len(set(face.tolist()) & set(nf.tolist())) >= 2 for nf in nearby_faces],
116 | dtype=torch.bool,
117 | device=v0.device,
118 | )
119 |
120 | collisions = intersections & ~adjacent
121 |
122 | return collisions
123 |
124 | def check_triangle_triangle_intersection(self, v0, v1, v2, w0, w1, w2):
125 | def triangle_plane_intersection(t0, t1, t2, p0, p1, p2):
126 | n = torch.linalg.cross(t1 - t0, t2 - t0)
127 | d0, d1, d2 = (
128 | self.dist_dot(p0, t0, n),
129 | self.dist_dot(p1, t0, n),
130 | self.dist_dot(p2, t0, n),
131 | )
132 | return (d0 * d1 <= 0) or (d0 * d2 <= 0) or (d1 * d2 <= 0)
133 |
134 | return triangle_plane_intersection(
135 | v0, v1, v2, w0, w1, w2
136 | ) and triangle_plane_intersection(w0, w1, w2, v0, v1, v2)
137 |
138 | @staticmethod
139 | def dist_dot(p, t0, n):
140 | return torch.dot(p - t0, n)
141 |
--------------------------------------------------------------------------------
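
A small sketch for TriangleCollisionLoss with default thresholds (toy geometry made up for illustration): the second triangle pierces the plane of the first and vice versa, and the two share no vertices, so the mutual plane-crossing test registers a collision for each face:

import torch

from neural_mesh_simplification.losses import TriangleCollisionLoss

loss_fn = TriangleCollisionLoss()
vertices = torch.tensor(
    [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0],   # triangle in the z=0 plane
     [0.2, 0.2, -0.5], [0.8, 0.2, 0.5], [0.2, 0.8, 0.5]]  # triangle piercing that plane
)
faces = torch.tensor([[0, 1, 2], [3, 4, 5]])
face_probs = torch.ones(2)

loss = loss_fn(vertices, faces, face_probs)
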
/src/neural_mesh_simplification/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .normal_consistency import normal_consistency
2 | from .chamfer_distance import chamfer_distance
3 | from .edge_preservation import edge_preservation
4 | from .hausdorff_distance import hausdorff_distance
5 |
--------------------------------------------------------------------------------
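
The four metrics operate on trimesh meshes rather than tensors. A hedged example comparing a sphere against a decimated copy of itself; note that simplify_quadric_decimation is a trimesh method whose availability and exact signature depend on the installed trimesh version and its optional dependencies:

import trimesh

from neural_mesh_simplification.metrics import (
    chamfer_distance,
    edge_preservation,
    hausdorff_distance,
    normal_consistency,
)

original = trimesh.creation.icosphere(subdivisions=3)
simplified = original.simplify_quadric_decimation(face_count=200)

print("chamfer:", chamfer_distance(original, simplified, samples=2000))
print("hausdorff:", hausdorff_distance(original, simplified, samples=2000))
print("normal consistency:", normal_consistency(simplified, samples=2000))
print("edge preservation:", edge_preservation(original, simplified))
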
/src/neural_mesh_simplification/metrics/chamfer_distance.py:
--------------------------------------------------------------------------------
1 | import trimesh
2 | import numpy as np
3 |
4 |
5 | def chamfer_distance(
6 | mesh1: trimesh.Trimesh, mesh2: trimesh.Trimesh, samples: int = 10000
7 | ):
8 | """
9 | Calculate the Chamfer distance between two meshes.
10 |
11 | Parameters:
12 | mesh1 (trimesh.Trimesh): The first input mesh.
13 | mesh2 (trimesh.Trimesh): The second input mesh.
14 | samples (int): The number of samples to use for the calculation.
15 |
16 | Returns:
17 | float: The Chamfer distance metric
18 | """
19 | points1 = mesh1.sample(samples)
20 | points2 = mesh2.sample(samples)
21 |
22 | _, distances1, _ = trimesh.proximity.closest_point(mesh2, points1)
23 | _, distances2, _ = trimesh.proximity.closest_point(mesh1, points2)
24 |
25 | chamfer_dist = np.mean(distances1) + np.mean(distances2)
26 |
27 | return chamfer_dist
28 |
--------------------------------------------------------------------------------
/src/neural_mesh_simplification/metrics/edge_preservation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.spatial import cKDTree
3 |
4 |
5 | def calculate_dihedral_angles(mesh):
6 | """Calculate dihedral angles for all edges in the mesh."""
7 | face_adjacency = mesh.face_adjacency
8 | face_adjacency_angles = mesh.face_adjacency_angles
9 |
10 | edge_to_angle = {}
11 | for (face1, face2), angle in zip(face_adjacency, face_adjacency_angles):
12 | edge = tuple(sorted(set(mesh.faces[face1]) & set(mesh.faces[face2])))
13 | edge_to_angle[edge] = angle
14 |
15 | dihedral_angles = np.zeros(len(mesh.edges))
16 | for i, edge in enumerate(mesh.edges):
17 | edge = tuple(sorted(edge))
18 | dihedral_angles[i] = edge_to_angle.get(edge, 0) # 0 for boundary edges
19 |
20 | return dihedral_angles
21 |
22 |
23 | def edge_preservation(
24 | original_mesh, simplified_mesh, angle_threshold=30, important_edge_factor=2.0
25 | ):
26 | """
27 | Calculate the edge preservation metric between the original and simplified meshes.
28 |
29 | Parameters:
30 | original_mesh (trimesh.Trimesh): The original high-resolution mesh.
31 | simplified_mesh (trimesh.Trimesh): The simplified mesh.
32 | angle_threshold (float): The dihedral angle threshold for important edges (in degrees).
33 | important_edge_factor (float): Factor to increase weight for important edges.
34 |
35 | Returns:
36 | float: The edge preservation metric.
37 | """
38 | original_dihedral = calculate_dihedral_angles(original_mesh)
39 | important_original = original_dihedral > np.radians(angle_threshold)
40 |
41 | # Calculate edge midpoints
42 | original_midpoints = original_mesh.vertices[original_mesh.edges].mean(axis=1)
43 | simplified_midpoints = simplified_mesh.vertices[simplified_mesh.edges].mean(axis=1)
44 |
45 | tree = cKDTree(simplified_midpoints)
46 |
47 | # Find closest simplified edge for each original edge
48 | distances, _ = tree.query(original_midpoints)
49 |
50 | # Calculate weighted preservation
51 | weights = np.exp(original_dihedral)
52 | weights[important_original] *= important_edge_factor
53 | weighted_distances = distances * weights
54 |
55 | max_distance = np.max(original_mesh.bounds) - np.min(original_mesh.bounds)
56 | preservation_metric = 1 - np.mean(weighted_distances) / max_distance
57 |
58 | return preservation_metric
59 |
--------------------------------------------------------------------------------
/src/neural_mesh_simplification/metrics/hausdorff_distance.py:
--------------------------------------------------------------------------------
1 | import trimesh
2 | import numpy as np
3 |
4 |
5 | def hausdorff_distance(
6 | mesh1: trimesh.Trimesh, mesh2: trimesh.Trimesh, samples: int = 10000
7 | ):
8 | """
9 | Calculate the Hausdorff distance between two meshes.
10 |
11 | Parameters:
12 | mesh1 (trimesh.Trimesh): The first input mesh.
13 | mesh2 (trimesh.Trimesh): The second input mesh.
14 | samples (int): The number of samples to use for the calculation.
15 |
16 | Returns:
17 | float: The Hausdorff distance between the two meshes.
18 | """
19 | # Sample points from both meshes
20 | points1 = mesh1.sample(samples)
21 | points2 = mesh2.sample(samples)
22 |
23 | # Calculate distances from points1 to mesh2
24 | _, distances1, _ = trimesh.proximity.closest_point(mesh2, points1)
25 |
26 | # Calculate distances from points2 to mesh1
27 | _, distances2, _ = trimesh.proximity.closest_point(mesh1, points2)
28 |
29 | # Hausdorff distance is the maximum of all minimum distances
30 | return max(np.max(distances1), np.max(distances2))
31 |
--------------------------------------------------------------------------------
/src/neural_mesh_simplification/metrics/normal_consistency.py:
--------------------------------------------------------------------------------
1 | import trimesh
2 | import numpy as np
3 |
4 |
5 | def normal_consistency(mesh, samples=10000):
6 | """
7 | Calculate the normal vector consistency of a mesh.
8 |
9 | Parameters:
10 | mesh (trimesh.Trimesh): The input mesh.
11 | samples (int): The number of samples to use for the calculation.
12 |
13 | Returns:
14 | float: The normal vector consistency metric.
15 | """
16 | points = mesh.sample(samples)
17 |
18 | _, _, face_indices = trimesh.proximity.closest_point(mesh, points)
19 | face_normals = mesh.face_normals[face_indices]
20 |
21 |     # trimesh computes vertex normals lazily on first attribute access below
22 |
23 | closest_vertices = mesh.nearest.vertex(points)[1]
24 | vertex_normals = mesh.vertex_normals[closest_vertices]
25 |
26 | consistency = np.abs(np.sum(vertex_normals * face_normals, axis=1))
27 |
28 | return np.mean(consistency)
29 |
--------------------------------------------------------------------------------
/src/neural_mesh_simplification/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .point_sampler import PointSampler
2 | from .edge_predictor import EdgePredictor
3 | from .face_classifier import FaceClassifier
4 | from .neural_mesh_simplification import NeuralMeshSimplification
5 |
--------------------------------------------------------------------------------
/src/neural_mesh_simplification/models/edge_predictor.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch_geometric.nn import knn_graph
6 | from torch_scatter import scatter_softmax
7 | from torch_sparse import SparseTensor
8 |
9 | from .layers.devconv import DevConv
10 |
11 | warnings.filterwarnings("ignore", message="Sparse CSR tensor support is in beta state")
12 |
13 |
14 | class EdgePredictor(nn.Module):
15 | def __init__(self, in_channels, hidden_channels, k):
16 | super(EdgePredictor, self).__init__()
17 | self.k = k
18 | self.devconv = DevConv(in_channels, hidden_channels)
19 |
20 | # Self-attention components
21 | self.W_q = nn.Linear(hidden_channels, hidden_channels, bias=False)
22 | self.W_k = nn.Linear(hidden_channels, hidden_channels, bias=False)
23 |
24 | def forward(self, x, edge_index):
25 | if edge_index.numel() == 0:
26 | raise ValueError("Edge index is empty")
27 |
28 | # Step 1: Extend original mesh connectivity with k-nearest neighbors
29 | knn_edges = knn_graph(x, k=self.k, flow="target_to_source")
30 |
31 | # Ensure knn_edges indices are within bounds
32 | max_idx = x.size(0) - 1
33 | valid_edges = (knn_edges[0] <= max_idx) & (knn_edges[1] <= max_idx)
34 | knn_edges = knn_edges[:, valid_edges]
35 |
36 | # Combine original edges with knn edges
37 | if edge_index.numel() > 0:
38 | extended_edges = torch.cat([edge_index, knn_edges], dim=1)
39 | # Remove duplicate edges
40 | extended_edges = torch.unique(extended_edges, dim=1)
41 | else:
42 | extended_edges = knn_edges
43 |
44 | # Step 2: Apply DevConv
45 | features = self.devconv(x, extended_edges)
46 |
47 | # Step 3: Apply sparse self-attention
48 | attention_scores = self.compute_attention_scores(features, edge_index)
49 |
50 | # Step 4: Compute simplified adjacency matrix
51 | simplified_adj_indices, simplified_adj_values = (
52 | self.compute_simplified_adjacency(attention_scores, edge_index)
53 | )
54 |
55 | return simplified_adj_indices, simplified_adj_values
56 |
57 | def compute_attention_scores(self, features, edges):
58 | if edges.numel() == 0:
59 | raise ValueError("Edge index is empty")
60 |
61 | row, col = edges
62 | q = self.W_q(features)
63 | k = self.W_k(features)
64 |
65 | # Compute (W_q f_j)^T (W_k f_i)
66 | attention = (q[row] * k[col]).sum(dim=-1)
67 |
68 | # Apply softmax for each source node
69 | attention_scores = scatter_softmax(attention, row, dim=0)
70 |
71 | return attention_scores
72 |
73 | def compute_simplified_adjacency(self, attention_scores, edge_index):
74 | if edge_index.numel() == 0:
75 | raise ValueError("Edge index is empty")
76 |
77 | num_nodes = edge_index.max().item() + 1
78 | row, col = edge_index
79 |
80 | # Ensure indices are within bounds
81 | if row.numel() > 0:
82 | assert torch.all(row < num_nodes) and torch.all(
83 | row >= 0
84 | ), f"Row indices out of bounds: min={row.min()}, max={row.max()}, num_nodes={num_nodes}"
85 | if col.numel() > 0:
86 | assert torch.all(col < num_nodes) and torch.all(
87 | col >= 0
88 | ), f"Column indices out of bounds: min={col.min()}, max={col.max()}, num_nodes={num_nodes}"
89 |
90 | # Create sparse attention matrix
91 | S = SparseTensor(
92 | row=row,
93 | col=col,
94 | value=attention_scores,
95 | sparse_sizes=(num_nodes, num_nodes),
96 | trust_data=True, # Since we verified the indices above
97 | )
98 |
99 | # Create original adjacency matrix
100 | A = SparseTensor(
101 | row=row,
102 | col=col,
103 | value=torch.ones(edge_index.size(1), device=edge_index.device),
104 | sparse_sizes=(num_nodes, num_nodes),
105 | trust_data=True, # Since we verified the indices above
106 | )
107 |
108 | # Compute A_s = S * A * S^T using coalesced sparse tensors
109 | A_s = S.matmul(A).matmul(S.t())
110 |
111 | # Convert to COO format
112 | row, col, value = A_s.coo()
113 | indices = torch.stack([row, col], dim=0)
114 |
115 | return indices, value
116 |
--------------------------------------------------------------------------------
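
A usage sketch for EdgePredictor on a random stand-in graph (sizes and hyperparameters below are arbitrary). The module returns the COO indices and values of the simplified adjacency A_s = S A S^T, not per-edge probabilities directly:

import torch

from neural_mesh_simplification.models import EdgePredictor

predictor = EdgePredictor(in_channels=3, hidden_channels=16, k=4)
x = torch.rand(20, 3)                       # vertex positions doubling as features
edge_index = torch.randint(0, 20, (2, 60))  # made-up mesh connectivity

adj_indices, adj_values = predictor(x, edge_index)
print(adj_indices.shape, adj_values.shape)  # [2, E'], [E']
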
/src/neural_mesh_simplification/models/face_classifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .layers import TriConv
5 |
6 |
7 | class FaceClassifier(nn.Module):
8 | def __init__(self, input_dim, hidden_dim, num_layers, k):
9 | super(FaceClassifier, self).__init__()
10 | self.k = k
11 | self.num_layers = num_layers
12 |
13 | self.triconv_layers = nn.ModuleList(
14 | [
15 | TriConv(input_dim if i == 0 else hidden_dim, hidden_dim)
16 | for i in range(num_layers)
17 | ]
18 | )
19 |
20 | self.final_layer = nn.Linear(hidden_dim, 1)
21 |
22 | def forward(self, x, pos, batch=None):
23 | # Handle empty input
24 | if x.size(0) == 0 or pos.size(0) == 0:
25 | return torch.tensor([], device=x.device)
26 |
27 | # If pos is 3D (num_faces, 3, 3), compute centroids
28 | if pos.dim() == 3:
29 | pos = pos.mean(dim=1) # Average vertex positions to get face centers
30 |
31 | # Construct k-nn graph based on triangle centers
32 | edge_index = self.custom_knn_graph(pos, self.k, batch)
33 |
34 | # Apply TriConv layers
35 | for i in range(self.num_layers):
36 | x = self.triconv_layers[i](x, pos, edge_index)
37 | x = torch.relu(x)
38 |
39 | # Final classification
40 | x = self.final_layer(x)
41 | logits = x.squeeze(-1) # Remove last dimension
42 |
43 | # Apply softmax normalization per batch
44 | if batch is None:
45 | # Global normalization using softmax
46 | probs = torch.softmax(logits, dim=0)
47 | else:
48 | # Per-batch normalization
49 | probs = torch.zeros_like(logits)
50 | for b in range(int(batch.max().item()) + 1):
51 | mask = batch == b
52 | probs[mask] = torch.softmax(logits[mask], dim=0)
53 |
54 | return probs
55 |
56 | def custom_knn_graph(self, x, k, batch=None):
57 | if x.size(0) == 0:
58 | return torch.empty((2, 0), dtype=torch.long, device=x.device)
59 |
60 | batch_size = 1 if batch is None else int(batch.max().item()) + 1
61 | edge_index = []
62 |
63 | for b in range(batch_size):
64 | if batch is None:
65 | x_batch = x
66 | else:
67 | mask = batch == b
68 | x_batch = x[mask]
69 |
70 | if x_batch.size(0) > 1:
71 | distances = torch.cdist(x_batch, x_batch)
72 | distances.fill_diagonal_(float("inf"))
73 | _, indices = distances.topk(min(k, x_batch.size(0) - 1), largest=False)
74 |
75 | source = (
76 | torch.arange(x_batch.size(0), device=x.device)
77 | .view(-1, 1)
78 | .expand(-1, indices.size(1))
79 | )
80 | edge_index.append(
81 | torch.stack([source.reshape(-1), indices.reshape(-1)])
82 | )
83 |
84 | if edge_index:
85 | edge_index = torch.cat(edge_index, dim=1)
86 |
87 | # Make the graph symmetric
88 | edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
89 | edge_index = torch.unique(edge_index, dim=1)
90 | else:
91 | edge_index = torch.empty((2, 0), dtype=torch.long, device=x.device)
92 |
93 | return edge_index
94 |
--------------------------------------------------------------------------------
/src/neural_mesh_simplification/models/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .devconv import DevConv
2 | from .triconv import TriConv
3 |
--------------------------------------------------------------------------------
/src/neural_mesh_simplification/models/layers/devconv.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch_scatter import scatter_max
3 |
4 |
5 | class DevConv(nn.Module):
6 | def __init__(self, in_channels, out_channels):
7 | super(DevConv, self).__init__()
8 | self.W_theta = nn.Linear(in_channels, out_channels)
9 | self.W_phi = nn.Linear(in_channels, out_channels)
10 |
11 | def forward(self, x, edge_index):
12 | row, col = edge_index
13 | x_i, x_j = x[row], x[col]
14 |
15 | rel_pos = x_i - x_j
16 | rel_pos_transformed = self.W_theta(rel_pos) # [num_edges, out_channels]
17 |
18 | x_transformed = self.W_phi(x) # [num_nodes, out_channels]
19 |
20 | # Aggregate using max pooling
21 | aggr_out = scatter_max(rel_pos_transformed, col, dim=0, dim_size=x.size(0))[0]
22 |
23 | return x_transformed + aggr_out
24 |
--------------------------------------------------------------------------------
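
For orientation: DevConv computes out[i] = W_phi(x[i]) + max over incoming edges (j -> i) of W_theta(x[j] - x[i]), with the max taken per channel by scatter_max. A shape-check sketch on a directed 5-cycle:

import torch

from neural_mesh_simplification.models.layers import DevConv

conv = DevConv(in_channels=3, out_channels=8)
x = torch.rand(5, 3)
edge_index = torch.tensor([[0, 1, 2, 3, 4],
                           [1, 2, 3, 4, 0]])  # edges run row -> col

out = conv(x, edge_index)
print(out.shape)  # torch.Size([5, 8])
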
/src/neural_mesh_simplification/models/layers/triconv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch_scatter import scatter_max, scatter_add
4 |
5 |
6 | class TriConv(nn.Module):
7 | def __init__(self, in_channels, out_channels):
8 | super(TriConv, self).__init__()
9 | self.in_channels = in_channels
10 | self.out_channels = out_channels
11 |
12 | # Calculate the correct input dimension for the MLP
13 | mlp_input_dim = in_channels + 9 # 9 is from the relative position encoding
14 |
15 | self.mlp = nn.Sequential(
16 | nn.Linear(mlp_input_dim, out_channels),
17 | nn.ReLU(),
18 | nn.Linear(out_channels, out_channels),
19 | )
20 | self.last_edge_index = None
21 |
22 | def forward(self, x, pos, edge_index):
23 | self.last_edge_index = edge_index
24 | row, col = edge_index
25 |
26 | rel_pos = self.compute_relative_position_encoding(pos, row, col)
27 | x_diff = x[row] - x[col]
28 | mlp_input = torch.cat([rel_pos, x_diff], dim=-1)
29 |
30 | mlp_output = self.mlp(mlp_input)
31 | out = scatter_add(mlp_output, col, dim=0, dim_size=x.size(0))
32 |
33 | return out
34 |
35 | def compute_relative_position_encoding(self, pos, row, col):
36 | edge_vec = pos[row] - pos[col]
37 |
38 | t_max, _ = scatter_max(edge_vec.abs(), col, dim=0, dim_size=pos.size(0))
39 | t_min, _ = scatter_max(-edge_vec.abs(), col, dim=0, dim_size=pos.size(0))
40 | t_min = -t_min
41 |
42 | barycenter = pos.mean(dim=-1, keepdim=True) if pos.dim() == 3 else pos
43 | barycenter_diff = barycenter[row] - barycenter[col]
44 |
45 | t_max_diff = t_max[row] - t_max[col]
46 | t_min_diff = t_min[row] - t_min[col]
47 | barycenter_diff = barycenter_diff.expand_as(t_max_diff)
48 |
49 | rel_pos = torch.cat([t_max_diff, t_min_diff, barycenter_diff], dim=-1)
50 |
51 | return rel_pos
52 |
--------------------------------------------------------------------------------
/src/neural_mesh_simplification/models/neural_mesh_simplification.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch_geometric
4 | from torch_geometric.data import Data
5 |
6 | from ..models import PointSampler, EdgePredictor, FaceClassifier
7 |
8 |
9 | class NeuralMeshSimplification(nn.Module):
10 | def __init__(
11 | self,
12 | input_dim,
13 | hidden_dim,
14 | edge_hidden_dim, # Separate hidden dim for edge predictor
15 | num_layers,
16 | k,
17 | edge_k,
18 | target_ratio,
19 | device=torch.device("cpu"),
20 | ):
21 | super(NeuralMeshSimplification, self).__init__()
22 | self.device = device
23 | self.point_sampler = PointSampler(input_dim, hidden_dim, num_layers).to(
24 | self.device
25 | )
26 | self.edge_predictor = EdgePredictor(
27 | input_dim,
28 | hidden_channels=edge_hidden_dim,
29 | k=edge_k,
30 | ).to(self.device)
31 | self.face_classifier = FaceClassifier(input_dim, hidden_dim, num_layers, k).to(
32 | self.device
33 | )
34 | self.k = k
35 | self.target_ratio = target_ratio
36 |
37 | def forward(self, data: Data):
38 | x, edge_index = data.x, data.edge_index
39 | num_nodes = x.size(0)
40 |
41 | sampled_indices, sampled_probs = self.sample_points(data)
42 |
43 | sampled_x = x[sampled_indices].to(self.device)
44 | sampled_pos = (
45 | data.pos[sampled_indices]
46 | if hasattr(data, "pos") and data.pos is not None
47 | else sampled_x
48 | ).to(self.device)
49 |
50 | sampled_vertices = sampled_pos # Use sampled_pos directly as vertices
51 |
52 | # Update edge_index to reflect the new indices
53 | sampled_edge_index, _ = torch_geometric.utils.subgraph(
54 | sampled_indices, edge_index, relabel_nodes=True, num_nodes=num_nodes
55 | )
56 |
57 | # Predict edges
58 | sampled_edge_index = sampled_edge_index.to(self.device)
59 | edge_index_pred, edge_probs = self.edge_predictor(sampled_x, sampled_edge_index)
60 |
61 | # Generate candidate triangles
62 | candidate_triangles, triangle_probs = self.generate_candidate_triangles(
63 | edge_index_pred, edge_probs
64 | )
65 |
66 | # Classify faces
67 | if candidate_triangles.shape[0] > 0:
68 | # Create triangle features by averaging vertex features
69 | triangle_features = torch.zeros(
70 | (candidate_triangles.shape[0], sampled_x.shape[1]),
71 | device=self.device,
72 | )
73 | for i in range(3):
74 | triangle_features += sampled_x[candidate_triangles[:, i]]
75 | triangle_features /= 3
76 |
77 | # Calculate triangle centers
78 | triangle_centers = torch.zeros(
79 | (candidate_triangles.shape[0], sampled_pos.shape[1]),
80 | device=self.device,
81 | )
82 | for i in range(3):
83 | triangle_centers += sampled_pos[candidate_triangles[:, i]]
84 | triangle_centers /= 3
85 |
86 | face_probs = self.face_classifier(
87 | triangle_features, triangle_centers, batch=None
88 | )
89 | else:
90 | face_probs = torch.empty(0, device=self.device)
91 |
92 | if candidate_triangles.shape[0] == 0:
93 | simplified_faces = torch.empty((0, 3), dtype=torch.long, device=self.device)
94 | else:
95 | threshold = torch.quantile(
96 | face_probs, 1 - self.target_ratio
97 | ) # Use a dynamic threshold
98 | simplified_faces = candidate_triangles[face_probs > threshold]
99 |
100 | return {
101 | "sampled_indices": sampled_indices,
102 | "sampled_probs": sampled_probs,
103 | "sampled_vertices": sampled_vertices,
104 | "edge_index": edge_index_pred,
105 | "edge_probs": edge_probs,
106 | "candidate_triangles": candidate_triangles,
107 | "triangle_probs": triangle_probs,
108 | "face_probs": face_probs,
109 | "simplified_faces": simplified_faces,
110 | }
111 |
112 | def sample_points(self, data: Data):
113 | x, edge_index = data.x, data.edge_index
114 | num_nodes = x.size(0)
115 |
116 | target_nodes = min(
117 | max(int(self.target_ratio * num_nodes), 1),
118 | num_nodes,
119 | )
120 |
121 | # Sample points
122 | x = x.to(self.device)
123 | edge_index = edge_index.to(self.device)
124 | sampled_probs = self.point_sampler(x, edge_index)
125 | sampled_indices = self.point_sampler.sample(
126 | sampled_probs, num_samples=target_nodes
127 | )
128 |
129 | return sampled_indices, sampled_probs[sampled_indices]
130 |
131 | def generate_candidate_triangles(self, edge_index, edge_probs):
132 |
133 | # Handle the case when edge_index is empty
134 | if edge_index.numel() == 0:
135 | return (
136 | torch.empty((0, 3), dtype=torch.long, device=self.device),
137 | torch.empty(0, device=self.device),
138 | )
139 |
140 | num_nodes = edge_index.max().item() + 1
141 |
142 | # Create an adjacency matrix from the edge index
143 | adj_matrix = torch.zeros(num_nodes, num_nodes, device=self.device)
144 |
145 | # Check if edge_probs is a tuple or a tensor
146 | if isinstance(edge_probs, tuple):
147 | edge_indices, edge_values = edge_probs
148 | adj_matrix[edge_indices[0], edge_indices[1]] = edge_values
149 | else:
150 | adj_matrix[edge_index[0], edge_index[1]] = edge_probs
151 |
152 | # Adjust k based on the number of nodes
153 | k = min(self.k, num_nodes - 1)
154 |
155 | # Find k-nearest neighbors for each node
156 | _, knn_indices = torch.topk(adj_matrix, k=k, dim=1)
157 |
158 | # Generate candidate triangles
159 | triangles = []
160 | triangle_probs = []
161 |
162 | for i in range(num_nodes):
163 | neighbors = knn_indices[i]
164 | for j in range(k):
165 | for l in range(j + 1, k):
166 | n1, n2 = neighbors[j], neighbors[l]
167 | if adj_matrix[n1, n2] > 0: # Check if the third edge exists
168 | triangle = torch.tensor([i, n1, n2], device=self.device)
169 | triangles.append(triangle)
170 |
171 | # Calculate triangle probability
172 | prob = (
173 | adj_matrix[i, n1] * adj_matrix[i, n2] * adj_matrix[n1, n2]
174 | ) ** (1 / 3)
175 | triangle_probs.append(prob)
176 |
177 | if triangles:
178 | triangles = torch.stack(triangles)
179 | triangle_probs = torch.tensor(triangle_probs, device=self.device)
180 | else:
181 | triangles = torch.empty((0, 3), dtype=torch.long, device=self.device)
182 | triangle_probs = torch.empty(0, device=self.device)
183 |
184 | return triangles, triangle_probs
185 |
--------------------------------------------------------------------------------
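
An end-to-end forward sketch with made-up hyperparameters and a random graph standing in for a real mesh (output sizes vary because point sampling is stochastic):

import torch
from torch_geometric.data import Data

from neural_mesh_simplification.models import NeuralMeshSimplification

model = NeuralMeshSimplification(
    input_dim=3, hidden_dim=64, edge_hidden_dim=64,
    num_layers=3, k=15, edge_k=15, target_ratio=0.5,
)

pos = torch.rand(100, 3)
edge_index = torch.randint(0, 100, (2, 300))
data = Data(x=pos, pos=pos, edge_index=edge_index)

out = model(data)
print(out["sampled_vertices"].shape, out["simplified_faces"].shape)
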
/src/neural_mesh_simplification/models/point_sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .layers.devconv import DevConv
5 |
6 |
7 | class PointSampler(nn.Module):
8 | def __init__(self, in_channels, out_channels, num_layers):
9 | super(PointSampler, self).__init__()
10 | self.num_layers = num_layers
11 |
12 | # Stack of DevConv layers
13 | self.convs = nn.ModuleList()
14 | self.convs.append(DevConv(in_channels, out_channels))
15 | for _ in range(num_layers - 1):
16 | self.convs.append(DevConv(out_channels, out_channels))
17 |
18 | # Final output layer to produce a single score per vertex
19 | self.output_layer = nn.Linear(out_channels, 1)
20 |
21 | # Activation functions
22 | self.activation = nn.ReLU()
23 | self.sigmoid = nn.Sigmoid()
24 |
25 | def forward(self, x, edge_index):
26 | # x: Node features [num_nodes, in_channels]
27 | # edge_index: Graph connectivity [2, num_edges]
28 |
29 | # Apply DevConv layers
30 | for conv in self.convs:
31 | x = conv(x, edge_index)
32 | x = self.activation(x)
33 |
34 | # Generate inclusion scores
35 | scores = self.output_layer(x).squeeze(-1)
36 |
37 | # Convert scores to probabilities
38 | probabilities = self.sigmoid(scores)
39 |
40 | return probabilities
41 |
42 | def sample(self, probabilities, num_samples):
43 | max_samples = probabilities.shape[0]
44 | if num_samples > max_samples:
45 | raise ValueError(
46 | f"num_samples ({num_samples}) cannot be larger than number of vertices ({max_samples})"
47 | )
48 |
49 | # Multinomial sampling based on probabilities
50 | sampled_indices = torch.multinomial(
51 | probabilities, num_samples, replacement=False
52 | )
53 |
54 | return sampled_indices
55 |
56 | def forward_and_sample(self, x, edge_index, num_samples):
57 | # Combine forward pass and sampling in one step
58 | probabilities = self.forward(x, edge_index)
59 | sampled_indices = self.sample(probabilities, num_samples)
60 | return sampled_indices, probabilities
61 |
--------------------------------------------------------------------------------
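
The two-step API in one call, on a toy graph: sampled indices are drawn without replacement, weighted by the learned inclusion probabilities, while the probability vector covers all vertices:

import torch

from neural_mesh_simplification.models import PointSampler

sampler = PointSampler(in_channels=3, out_channels=16, num_layers=2)
x = torch.rand(10, 3)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])

indices, probs = sampler.forward_and_sample(x, edge_index, num_samples=5)
print(indices.shape, probs.shape)  # torch.Size([5]), torch.Size([10])
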
/src/neural_mesh_simplification/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainer import Trainer
2 |
--------------------------------------------------------------------------------
/src/neural_mesh_simplification/trainer/resource_monitor.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import psutil
4 | import torch
5 |
6 |
7 | def monitor_resources(stop_event, main_pid):
8 | main_process = psutil.Process(main_pid)
9 |
10 | interval_idx = 0
11 |
12 | while not stop_event.is_set():
13 | try:
14 | # Get CPU and memory usage for the main process
15 | cpu_percent = main_process.cpu_percent(interval=1)
16 | memory_info = main_process.memory_info()
17 | total_memory_rss = memory_info.rss # Start with main process memory usage
18 |
19 | # Include children processes
20 | for child in main_process.children(recursive=True):
21 | try:
22 | cpu_percent += child.cpu_percent(interval=None)
23 | child_memory = child.memory_info()
24 | total_memory_rss += (
25 | child_memory.rss
26 | ) # Add child process memory usage
27 | except psutil.NoSuchProcess:
28 | pass # Child process no longer exists
29 |
30 | memory_usage_mb = total_memory_rss / (1024 * 1024) # Convert to MB
31 |
32 | interval_idx += 1
33 |
34 | output = f"\rCPU: {cpu_percent:.1f}% | Memory: {memory_usage_mb:.2f} MB"
35 |
36 | # Get GPU and GPU memory usage
37 | if torch.cuda.is_available():
38 | gpu_info = []
39 | for i in range(torch.cuda.device_count()):
40 | gpu_util = torch.cuda.utilization(i) # GPU Utilization
41 | mem_alloc = torch.cuda.memory_allocated(i) / (
42 | 1024 * 1024
43 | ) # Convert to MB
44 | mem_total = torch.cuda.get_device_properties(i).total_memory / (
45 | 1024 * 1024
46 | ) # Total VRAM
47 | gpu_info.append(
48 | f"GPU {i}: {gpu_util:.1f}% | Mem: {mem_alloc:.2f}/{mem_total:.2f} MB"
49 | )
50 |
51 | output += " | " + " | ".join(gpu_info)
52 |
53 | if interval_idx % 3 == 0:
54 | output += "\n"
55 |
56 | print(output, end="", flush=True)
57 |
58 | except psutil.NoSuchProcess:
59 | print("\nMain process has terminated.")
60 | break
61 | time.sleep(1)
62 |
--------------------------------------------------------------------------------
/src/neural_mesh_simplification/trainer/trainer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from multiprocessing import Event, Process
4 | from typing import Dict, Any
5 |
6 | import torch
7 | from torch.optim import Adam
8 | from torch.optim.lr_scheduler import ReduceLROnPlateau
9 | from torch.utils.data import random_split
10 | from torch_geometric.loader import DataLoader
11 |
12 | from .resource_monitor import monitor_resources
13 | from ..data import MeshSimplificationDataset
14 | from ..losses import CombinedMeshSimplificationLoss
15 | from ..metrics import (
16 | chamfer_distance,
17 | normal_consistency,
18 | edge_preservation,
19 | hausdorff_distance,
20 | )
21 | from ..models import NeuralMeshSimplification
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 | class Trainer:
27 | def __init__(self, config: Dict[str, Any]):
28 | self.config = config
29 | logger.info("Initializing trainer...")
30 |
31 | if torch.cuda.is_available():
32 | self.device = torch.device("cuda")
33 | else:
34 | self.device = torch.device("cpu")
35 | logger.info(f"Using device: {self.device}")
36 |
37 | logger.info("Initializing model...")
38 | self.model = NeuralMeshSimplification(
39 | input_dim=config["model"]["input_dim"],
40 | hidden_dim=config["model"]["hidden_dim"],
41 | edge_hidden_dim=config["model"]["edge_hidden_dim"],
42 | num_layers=config["model"]["num_layers"],
43 | k=config["model"]["k"],
44 | edge_k=config["model"]["edge_k"],
45 | target_ratio=config["model"]["target_ratio"],
46 | device=self.device,
47 | )
48 |
49 | logger.debug("Setting up optimizer and loss...")
50 | self.optimizer = Adam(
51 | self.model.parameters(),
52 | lr=config["training"]["learning_rate"],
53 | weight_decay=config["training"]["weight_decay"],
54 | )
55 | self.scheduler = ReduceLROnPlateau(
56 | self.optimizer, mode="min", factor=0.1, patience=10, verbose=True
57 | )
58 | self.criterion = CombinedMeshSimplificationLoss(
59 | lambda_c=config["loss"]["lambda_c"],
60 | lambda_e=config["loss"]["lambda_e"],
61 | lambda_o=config["loss"]["lambda_o"],
62 | device=self.device,
63 | )
64 | self.early_stopping_patience = config["training"]["early_stopping_patience"]
65 | self.best_val_loss = float("inf")
66 | self.early_stopping_counter = 0
67 | self.checkpoint_dir = config["training"]["checkpoint_dir"]
68 | os.makedirs(self.checkpoint_dir, exist_ok=True)
69 |
70 | logger.debug("Preparing data loaders...")
71 | self.train_loader, self.val_loader = self._prepare_data_loaders()
72 | logger.info("Trainer initialization complete.")
73 |
74 | if "monitor_resources" in config:
75 | logger.info("Monitoring resource usage")
76 | self.monitor_resources = True
77 | self.stop_event = Event()
78 | self.monitor_process = None
79 | else:
80 | self.monitor_resources = False
81 |
82 | def _prepare_data_loaders(self):
83 | logger.info(f"Loading dataset from {self.config['data']['data_dir']}")
84 | dataset = MeshSimplificationDataset(
85 | data_dir=self.config["data"]["data_dir"],
86 | preprocess=False, # Can be False, because data has been prepared before training
87 | )
88 | logger.debug(f"Dataset size: {len(dataset)}")
89 |
90 | val_size = int(len(dataset) * self.config["data"]["val_split"])
91 | train_size = len(dataset) - val_size
92 | logger.info(f"Splitting dataset: {train_size} train, {val_size} validation")
93 |
94 | train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
95 | assert (
96 | len(val_dataset) > 0
97 | ), f"There is not enough data to define an evaluation set. len(dataset)={len(dataset)}, train_size={train_size}, val_size={val_size}"
98 |
99 | num_workers = self.config["training"].get("num_workers", os.cpu_count())
100 | logger.info(f"Using {num_workers} workers for data loading")
101 |
102 | train_loader = DataLoader(
103 | train_dataset,
104 | batch_size=self.config["training"]["batch_size"],
105 | shuffle=True,
106 | num_workers=num_workers,
107 | follow_batch=["x", "pos"],
108 | )
109 |
110 | val_loader = DataLoader(
111 | val_dataset,
112 | batch_size=self.config["training"]["batch_size"],
113 | shuffle=False,
114 | num_workers=num_workers,
115 | follow_batch=["x", "pos"],
116 | )
117 | logger.info("Data loaders prepared successfully")
118 |
119 | return train_loader, val_loader
120 |
121 | def train(self):
122 | if self.monitor_resources:
123 | main_pid = os.getpid()
124 | self.monitor_process = Process(
125 | target=monitor_resources, args=(self.stop_event, main_pid)
126 | )
127 | self.monitor_process.start()
128 |
129 | try:
130 | for epoch in range(self.config["training"]["num_epochs"]):
131 |                 avg_loss = self._train_one_epoch(epoch)  # already averaged over batches
132 |
133 |                 logging.info(
134 |                     f"Epoch [{epoch + 1}/{self.config['training']['num_epochs']}], Loss: {avg_loss}"
135 |                 )
136 |
137 | val_loss = self._validate()
138 | logging.info(
139 | f"Epoch [{epoch + 1}/{self.config['training']['num_epochs']}], Validation Loss: {val_loss}"
140 | )
141 |
142 | self.scheduler.step(val_loss)
143 |
144 | # Save the checkpoint
145 | self._save_checkpoint(epoch, val_loss)
146 |
147 | if torch.cuda.is_available():
148 | torch.cuda.empty_cache()
149 |
150 | # Early stop as required
151 | if self._early_stopping(val_loss):
152 | logging.info("Early stopping triggered.")
153 | break
154 | except Exception as e:
155 | logger.error(f"{str(e)}")
156 | finally:
157 | if self.monitor_resources:
158 | self.stop_event.set()
159 | self.monitor_process.join()
160 | print() # New line after monitoring output
161 |
162 | def _train_one_epoch(self, epoch: int) -> float:
163 | self.model.train()
164 | running_loss = 0.0
165 | logger.debug(f"Starting epoch {epoch + 1}")
166 |
167 | for batch_idx, batch in enumerate(self.train_loader):
168 | logger.debug(f"Processing batch {batch_idx + 1}")
169 | self.optimizer.zero_grad()
170 | output = self.model(batch)
171 | loss = self.criterion(batch, output)
172 |
173 | del batch
174 | del output
175 |
176 | loss.backward()
177 | self.optimizer.step()
178 | running_loss += loss.item()
179 |
180 | return running_loss / len(self.train_loader)
181 |
182 | def _validate(self) -> float:
183 | self.model.eval()
184 | val_loss = 0.0
185 | with torch.no_grad():
186 | for batch in self.val_loader:
187 | output = self.model(batch)
188 | loss = self.criterion(batch, output)
189 | val_loss += loss.item()
190 |
191 | return val_loss / len(self.val_loader)
192 |
193 | def _save_checkpoint(self, epoch: int, val_loss: float):
194 | checkpoint_path = os.path.join(
195 | self.checkpoint_dir, f"checkpoint_epoch_{epoch + 1}.pth"
196 | )
197 |
198 | logging.debug(f"Saving checkpoint to {checkpoint_path}")
199 | torch.save(
200 | {
201 | "epoch": epoch + 1,
202 | "model_state_dict": self.model.state_dict(),
203 | "optimizer_state_dict": self.optimizer.state_dict(),
204 | "val_loss": val_loss,
205 | },
206 | checkpoint_path,
207 | )
208 | if val_loss < self.best_val_loss:
209 | self.best_val_loss = val_loss
210 | best_model_path = os.path.join(self.checkpoint_dir, "best_model.pth")
211 | logging.debug(f"Saving best model to {best_model_path}")
212 | torch.save(self.model.state_dict(), best_model_path)
213 |
214 | # Remove old checkpoints to save space
215 | for old_checkpoint in os.listdir(self.checkpoint_dir):
216 | if (
217 | old_checkpoint.startswith("checkpoint_")
218 | and old_checkpoint != f"checkpoint_epoch_{epoch + 1}.pth"
219 | ):
220 | os.remove(os.path.join(self.checkpoint_dir, old_checkpoint))
221 |
222 | def _early_stopping(self, val_loss: float) -> bool:
223 |         if val_loss <= self.best_val_loss:  # best_val_loss was already updated in _save_checkpoint
224 | self.early_stopping_counter = 0
225 | else:
226 | self.early_stopping_counter += 1
227 | return self.early_stopping_counter >= self.early_stopping_patience
228 |
229 | def load_checkpoint(self, checkpoint_path: str):
230 | checkpoint = torch.load(checkpoint_path)
231 | self.model.load_state_dict(checkpoint["model_state_dict"])
232 | self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
233 | self.best_val_loss = checkpoint["val_loss"]
234 | logging.info(
235 | f"Loaded checkpoint from {checkpoint_path} (epoch {checkpoint['epoch']})"
236 | )
237 |
238 | def log_metrics(self, metrics: Dict[str, float], epoch: int):
239 | log_message = f"Epoch [{epoch + 1}/{self.config['training']['num_epochs']}], "
240 | log_message += ", ".join(
241 | [f"{key}: {value:.4f}" for key, value in metrics.items()]
242 | )
243 | logging.info(log_message)
244 |
245 | def evaluate(self, data_loader: DataLoader) -> Dict[str, float]:
246 | self.model.eval()
247 | metrics = {
248 | "chamfer_distance": 0.0,
249 | "normal_consistency": 0.0,
250 | "edge_preservation": 0.0,
251 | "hausdorff_distance": 0.0,
252 | }
253 | with torch.no_grad():
254 | for batch in data_loader:
255 | output = self.model(batch)
256 |
257 | # TODO: Define methods that can operate on a batch instead of a trimesh object
258 |
259 | metrics["chamfer_distance"] += chamfer_distance(batch, output)
260 | metrics["normal_consistency"] += normal_consistency(batch, output)
261 | metrics["edge_preservation"] += edge_preservation(batch, output)
262 | metrics["hausdorff_distance"] += hausdorff_distance(batch, output)
263 | for key in metrics:
264 | metrics[key] /= len(data_loader)
265 | return metrics
266 |
267 | def handle_error(self, error: Exception):
268 | logging.error(f"An error occurred: {error}")
269 | if isinstance(error, RuntimeError) and "out of memory" in str(error):
270 | logging.error("Out of memory error. Attempting to recover.")
271 | torch.cuda.empty_cache()
272 | else:
273 | raise error
274 |
275 | def save_training_state(self, state_path: str):
276 | torch.save(
277 | {
278 | "model_state_dict": self.model.state_dict(),
279 | "optimizer_state_dict": self.optimizer.state_dict(),
280 | "best_val_loss": self.best_val_loss,
281 | "early_stopping_counter": self.early_stopping_counter,
282 | },
283 | state_path,
284 | )
285 |
286 | def load_training_state(self, state_path: str):
287 | state = torch.load(state_path)
288 | self.model.load_state_dict(state["model_state_dict"])
289 | self.optimizer.load_state_dict(state["optimizer_state_dict"])
290 | self.best_val_loss = state["best_val_loss"]
291 | self.early_stopping_counter = state["early_stopping_counter"]
292 | logging.info(f"Loaded training state from {state_path}")
293 |
--------------------------------------------------------------------------------
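A note on the TODO in `evaluate` above: it still feeds raw batches to the trimesh-based metrics. A minimal tensor-native sketch of one such metric, assuming `(N, 3)` point-set tensors have already been extracted from the batch (the helper name `chamfer_distance_tensors` is hypothetical, not part of the code base):

    import torch

    def chamfer_distance_tensors(p1: torch.Tensor, p2: torch.Tensor) -> torch.Tensor:
        """Symmetric Chamfer distance between point sets p1 (N, 3) and p2 (M, 3)."""
        d = torch.cdist(p1, p2)  # (N, M) pairwise Euclidean distances
        return d.min(dim=1).values.mean() + d.min(dim=0).values.mean()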
/src/neural_mesh_simplification/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .mesh_operations import build_graph_from_mesh
2 |
--------------------------------------------------------------------------------
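For reference, a usage sketch of the exported helper (defined in mesh_operations.py below); the printed counts hold for trimesh's default box, which triangulates into 12 faces:

    import trimesh
    from neural_mesh_simplification.utils import build_graph_from_mesh

    mesh = trimesh.creation.box()
    graph = build_graph_from_mesh(mesh)
    print(graph.number_of_nodes())  # 8 vertices
    print(graph.number_of_edges())  # 18 unique edges (12 cube edges + 6 face diagonals)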
/src/neural_mesh_simplification/utils/mesh_operations.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | import trimesh
3 |
4 |
5 | def simplify_mesh(mesh, target_faces):
6 | # Simplify a mesh to a target number of faces
7 | pass
8 |
9 |
10 | def calculate_mesh_features(mesh):
11 | # Calculate relevant features of a mesh (e.g., curvature)
12 | pass
13 |
14 |
15 | def align_meshes(mesh1, mesh2):
16 | # Align two meshes (useful for comparison)
17 | pass
18 |
19 |
20 | def compare_meshes(mesh1, mesh2):
21 | # Compare two meshes (e.g., Hausdorff distance)
22 | pass
23 |
24 |
25 | def build_graph_from_mesh(mesh: trimesh.Trimesh) -> nx.Graph:
26 | """Build a graph structure from a mesh."""
27 | G = nx.Graph()
28 |
29 | # Add nodes (vertices)
30 | for i, vertex in enumerate(mesh.vertices):
31 | G.add_node(i, pos=vertex)
32 |
33 | # Add edges
34 | for face in mesh.faces:
35 | G.add_edge(face[0], face[1])
36 | G.add_edge(face[1], face[2])
37 | G.add_edge(face[2], face[0])
38 |
39 | return G
40 |
--------------------------------------------------------------------------------
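The first four functions above are unimplemented placeholders. As one hedged sketch of what `compare_meshes` could look like, assuming SciPy is available (the project's own `hausdorff_distance` metric in `metrics/` may be the intended implementation):

    from scipy.spatial.distance import directed_hausdorff

    def compare_meshes(mesh1, mesh2) -> float:
        # Symmetric Hausdorff distance over the two vertex sets.
        d12 = directed_hausdorff(mesh1.vertices, mesh2.vertices)[0]
        d21 = directed_hausdorff(mesh2.vertices, mesh1.vertices)[0]
        return max(d12, d21)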
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/martinnormark/neural-mesh-simplification/2ada79100932189cfd28a56093ba3d72e24f3ec7/tests/__init__.py
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import trimesh
3 |
4 |
5 | @pytest.fixture
6 | def sample_mesh():
7 | # Create a simple cube mesh for testing
8 | mesh = trimesh.creation.box()
9 | return mesh
10 |
--------------------------------------------------------------------------------
/tests/losses/test_edge_crossings_loss.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from neural_mesh_simplification.losses import EdgeCrossingLoss
5 |
6 |
7 | @pytest.fixture
8 | def loss_fn():
9 | return EdgeCrossingLoss(k=2)
10 |
11 |
12 | @pytest.fixture
13 | def sample_data():
14 | vertices = torch.tensor(
15 | [
16 | [0, 0, 0],
17 | [1, 0, 0],
18 | [0, 1, 0],
19 | [1, 1, 0],
20 | [0, 0, 1],
21 | [1, 0, 1],
22 | [0, 1, 1],
23 | [1, 1, 1],
24 | ],
25 | dtype=torch.float,
26 | )
27 | faces = torch.tensor([[0, 1, 2], [1, 3, 2], [4, 5, 6], [5, 7, 6]], dtype=torch.long)
28 | face_probs = torch.tensor([0.8, 0.6, 0.7, 0.9], dtype=torch.float)
29 | return vertices, faces, face_probs
30 |
31 |
32 | def test_find_nearest_triangles(loss_fn):
33 | vertices = torch.tensor(
34 | [[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [1, 0, 1]],
35 | dtype=torch.float,
36 | )
37 | faces = torch.tensor([[0, 1, 2], [1, 3, 2], [0, 4, 1], [1, 4, 5]], dtype=torch.long)
38 |
39 | nearest = loss_fn.find_nearest_triangles(vertices, faces)
40 | assert nearest.shape[0] == faces.shape[0]
41 | assert nearest.shape[1] == 1 # k-1 = 1, since k=2
42 |
43 |
44 | def test_detect_edge_crossings(loss_fn, sample_data):
45 | vertices, faces, _ = sample_data
46 |
47 | nearest_triangles = torch.tensor([[1], [0], [3], [2]], dtype=torch.long)
48 | crossings = loss_fn.detect_edge_crossings(vertices, faces, nearest_triangles)
49 |
50 | # Test expected number of crossings
51 | # Here, you should check for specific cases based on the vertex configuration
52 | # Example: You expect 0 crossings if triangles are separate
53 | assert torch.all(crossings == 0) # Modify this based on your actual expectations
54 |
55 |
56 | def test_calculate_loss(loss_fn, sample_data):
57 | _, _, face_probs = sample_data
58 |
59 | crossings = torch.tensor([1.0, 0.0, 2.0, 1.0], dtype=torch.float)
60 | loss = loss_fn.calculate_loss(crossings, face_probs)
61 |
62 | num_faces = face_probs.shape[0]
63 | expected_loss = torch.sum(face_probs * crossings, dtype=torch.float32) / num_faces
64 | assert torch.isclose(
65 |         loss, expected_loss
66 |     ), f"Expected loss: {expected_loss.item()}, but got {loss.item()}"
67 |
68 |
69 | def test_edge_crossing_loss_full(loss_fn, sample_data):
70 | vertices, faces, face_probs = sample_data
71 |
72 | # Run the forward pass of the loss function
73 | loss = loss_fn(vertices, faces, face_probs)
74 |
75 | # Check that the loss value is as expected
76 | # Ensure expected_loss is a floating-point tensor
77 | expected_loss = torch.tensor(
78 | 0.0, dtype=torch.float
79 | ) # Modify this based on your setup
80 | assert torch.isclose(
81 | loss, expected_loss
82 | ), f"Expected loss: {expected_loss.item()}, but got {loss.item()}"
83 |
--------------------------------------------------------------------------------
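For reference, `test_calculate_loss` above pins the reduction in `EdgeCrossingLoss.calculate_loss` to the probability-weighted mean crossing count, L = (1 / |F|) * sum_f p_f * c_f, where p_f is the probability of face f and c_f its detected crossing count.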
/tests/losses/test_overlapping_triangles_loss.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from neural_mesh_simplification.losses import OverlappingTrianglesLoss
5 |
6 |
7 | @pytest.fixture
8 | def loss_fn():
9 | return OverlappingTrianglesLoss(num_samples=5, k=3)
10 |
11 |
12 | @pytest.fixture
13 | def sample_data():
14 | vertices = torch.tensor(
15 | [
16 | [0.0, 0.0, 0.0],
17 | [1.0, 0.0, 0.0],
18 | [0.0, 1.0, 0.0],
19 | [1.0, 1.0, 0.0],
20 | [0.0, 0.0, 1.0],
21 | [1.0, 0.0, 1.0],
22 | [0.0, 1.0, 1.0],
23 | [1.0, 1.0, 1.0],
24 | ],
25 | dtype=torch.float,
26 | )
27 | faces = torch.tensor([[0, 1, 2], [1, 3, 2], [4, 5, 6], [5, 7, 6]], dtype=torch.long)
28 | return vertices, faces
29 |
30 |
31 | def test_sample_points_from_triangles(loss_fn, sample_data):
32 | vertices, faces = sample_data
33 |
34 | sampled_points, point_face_indices = loss_fn.sample_points_from_triangles(
35 | vertices, faces
36 | )
37 |
38 | assert sampled_points.shape[0] == point_face_indices.shape[0]
39 |
40 | # Check the shape of the sampled points
41 | expected_shape = (faces.shape[0] * loss_fn.num_samples, 3)
42 | assert (
43 | sampled_points.shape == expected_shape
44 | ), f"Expected shape {expected_shape}, got {sampled_points.shape}"
45 |
46 | # Check that points lie within the bounding box of the mesh
47 | assert torch.all(
48 | sampled_points >= vertices.min(dim=0).values
49 | ), "Sampled points are outside the mesh bounds"
50 | assert torch.all(
51 | sampled_points <= vertices.max(dim=0).values
52 | ), "Sampled points are outside the mesh bounds"
53 |
54 |
55 | def test_find_nearest_triangles(loss_fn, sample_data):
56 | vertices, faces = sample_data
57 |
58 | sampled_points, _ = loss_fn.sample_points_from_triangles(vertices, faces)
59 | nearest_triangles = loss_fn.find_nearest_triangles(sampled_points, vertices, faces)
60 |
61 | # Check the shape of the nearest triangles tensor
62 | expected_shape = (sampled_points.shape[0], loss_fn.k)
63 | assert (
64 | nearest_triangles.shape == expected_shape
65 | ), f"Expected shape {expected_shape}, got {nearest_triangles.shape}"
66 |
67 | # Check that the indices are within the range of faces
68 | assert torch.all(nearest_triangles >= 0) and torch.all(
69 | nearest_triangles < faces.shape[0]
70 | ), "Invalid triangle indices"
71 |
72 |
73 | def test_calculate_overlap_loss(loss_fn, sample_data):
74 | vertices, faces = sample_data
75 |
76 | sampled_points, point_face_indices = loss_fn.sample_points_from_triangles(
77 | vertices, faces
78 | )
79 | nearest_triangles = loss_fn.find_nearest_triangles(sampled_points, vertices, faces)
80 | overlap_loss = loss_fn.calculate_overlap_loss(
81 | sampled_points, vertices, faces, nearest_triangles, point_face_indices
82 | )
83 |
84 | # Check that the overlap loss is a scalar
85 | assert (
86 | isinstance(overlap_loss, torch.Tensor) and overlap_loss.dim() == 0
87 | ), "Overlap loss should be a scalar"
88 |
89 | # For this simple test case, the overlap should be minimal or zero
90 | assert overlap_loss.item() >= 0, "Overlap loss should be non-negative"
91 |
92 |
93 | def test_overlapping_triangles_loss_full(loss_fn, sample_data):
94 | vertices, faces = sample_data
95 |
96 | # Run the forward pass of the loss function
97 | loss = loss_fn(vertices, faces)
98 |
99 | # Check that the loss is a scalar
100 | assert isinstance(loss, torch.Tensor) and loss.dim() == 0, "Loss should be a scalar"
101 |
102 | # Expected behavior: the loss should be non-negative
103 | assert loss.item() >= 0, "Overlap loss should be non-negative"
104 |
--------------------------------------------------------------------------------
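A note on `sample_points_from_triangles`: the bounding-box assertions above hold for any convex combination of triangle vertices. One standard way to draw uniform samples, which the repository's implementation may or may not follow, is square-root reparameterized barycentric sampling (the function name here is hypothetical):

    import torch

    def sample_uniform_on_triangles(v0, v1, v2, num_samples):
        # v0, v1, v2: (F, 3) vertex positions of F triangles.
        r1 = torch.rand(v0.shape[0], num_samples, 1).sqrt()  # sqrt reparameterization
        r2 = torch.rand(v0.shape[0], num_samples, 1)
        # With r1 = sqrt(u1): p = (1 - r1) v0 + r1 (1 - r2) v1 + r1 r2 v2 is uniform on each triangle.
        return (1 - r1) * v0[:, None] + r1 * (1 - r2) * v1[:, None] + r1 * r2 * v2[:, None]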
/tests/losses/test_proba_chamfer_distance.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytest
3 | from neural_mesh_simplification.losses import ProbabilisticChamferDistanceLoss
4 |
5 |
6 | @pytest.fixture
7 | def pcd_loss():
8 | return ProbabilisticChamferDistanceLoss()
9 |
10 |
11 | def test_pcd_loss_zero_distance(pcd_loss):
12 | P = torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=torch.float32)
13 | Ps = P.clone()
14 | probs = torch.tensor([0.5, 0.5], dtype=torch.float32)
15 |
16 | loss = pcd_loss(P, Ps, probs)
17 | assert torch.isclose(
18 | loss, torch.tensor(0.0)
19 | ), f"Expected loss to be 0, but got {loss.item()}"
20 |
21 |
22 | def test_pcd_loss_nonzero_distance(pcd_loss):
23 | P = torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=torch.float32)
24 | Ps = torch.tensor([[0, 0, 1], [1, 1, 0]], dtype=torch.float32)
25 | probs = torch.tensor([0.5, 0.5], dtype=torch.float32)
26 |
27 | loss = pcd_loss(P, Ps, probs)
28 | assert loss > 0, f"Expected loss to be positive, but got {loss.item()}"
29 | assert torch.isfinite(loss), f"Expected loss to be finite, but got {loss.item()}"
30 |
31 |
32 | def test_pcd_loss_empty_input(pcd_loss):
33 | P = torch.empty((0, 3))
34 | Ps = torch.empty((0, 3))
35 | probs = torch.empty(0)
36 |
37 | loss = pcd_loss(P, Ps, probs)
38 | assert loss == 0, f"Expected loss to be 0 for empty input, but got {loss.item()}"
39 |
40 |
41 | def test_pcd_loss_different_sizes(pcd_loss):
42 | P = torch.rand((100, 3))
43 | Ps = torch.rand((50, 3))
44 | probs = torch.rand(50)
45 |
46 | loss = pcd_loss(P, Ps, probs)
47 | assert torch.isfinite(loss), f"Expected loss to be finite, but got {loss.item()}"
48 | assert loss >= 0, f"Expected non-negative loss, but got {loss.item()}"
49 |
50 |
51 | def test_pcd_loss_gradient(pcd_loss):
52 | P = torch.randn((10, 3), requires_grad=True)
53 | Ps = torch.randn((5, 3), requires_grad=True)
54 | probs = torch.rand(5, requires_grad=True)
55 |
56 | loss = pcd_loss(P, Ps, probs)
57 | loss.backward()
58 |
59 | assert P.grad is not None, "Gradient for P should not be None"
60 | assert Ps.grad is not None, "Gradient for Ps should not be None"
61 | assert probs.grad is not None, "Gradient for probabilities should not be None"
62 | assert torch.isfinite(P.grad).all(), "Gradient for P should be finite"
63 | assert torch.isfinite(Ps.grad).all(), "Gradient for Ps should be finite"
64 | assert torch.isfinite(
65 | probs.grad
66 | ).all(), "Gradient for probabilities should be finite"
67 |
--------------------------------------------------------------------------------
/tests/losses/test_proba_surface_distance.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytest
3 | from neural_mesh_simplification.losses.surface_distance_loss import (
4 | ProbabilisticSurfaceDistanceLoss,
5 | )
6 |
7 |
8 | @pytest.fixture
9 | def loss_fn():
10 | return ProbabilisticSurfaceDistanceLoss(k=3, num_samples=100)
11 |
12 |
13 | @pytest.fixture
14 | def simple_cube_data():
15 | vertices = torch.tensor(
16 | [
17 | [0, 0, 0],
18 | [1, 0, 0],
19 | [0, 1, 0],
20 | [1, 1, 0],
21 | [0, 0, 1],
22 | [1, 0, 1],
23 | [0, 1, 1],
24 | [1, 1, 1],
25 | ],
26 | dtype=torch.float32,
27 | )
28 |
29 | faces = torch.tensor(
30 | [
31 | [0, 1, 2],
32 | [1, 3, 2],
33 | [4, 5, 6],
34 | [5, 7, 6],
35 | [0, 4, 1],
36 | [1, 4, 5],
37 | [2, 3, 6],
38 | [3, 7, 6],
39 | [0, 2, 4],
40 | [2, 6, 4],
41 | [1, 5, 3],
42 | [3, 5, 7],
43 | ],
44 | dtype=torch.long,
45 | )
46 |
47 | return vertices, faces
48 |
49 |
50 | def test_loss_zero_for_identical_meshes(loss_fn, simple_cube_data):
51 | vertices, faces = simple_cube_data
52 | face_probs = torch.ones(faces.shape[0], dtype=torch.float32)
53 |
54 | loss = loss_fn(vertices, faces, vertices, faces, face_probs)
55 | print(f"Loss for identical meshes: {loss.item()}")
56 | assert loss.item() < 1e-5
57 |
58 |
59 | def test_loss_increases_with_vertex_displacement(loss_fn, simple_cube_data):
60 | vertices, faces = simple_cube_data
61 | face_probs = torch.ones(faces.shape[0], dtype=torch.float32)
62 |
63 | displaced_vertices = vertices.clone()
64 | displaced_vertices[0] += torch.tensor([0.1, 0.1, 0.1])
65 |
66 | loss_original = loss_fn(vertices, faces, vertices, faces, face_probs)
67 | loss_displaced = loss_fn(vertices, faces, displaced_vertices, faces, face_probs)
68 |
69 | print(
70 | f"Original loss: {loss_original.item()}, Displaced loss: {loss_displaced.item()}"
71 | )
72 | assert loss_displaced > loss_original
73 | assert not torch.isclose(loss_displaced, loss_original, atol=1e-5)
74 |
75 |
76 | def test_loss_increases_with_lower_face_probabilities(loss_fn, simple_cube_data):
77 | vertices, faces = simple_cube_data
78 | high_probs = torch.ones(faces.shape[0], dtype=torch.float32)
79 | low_probs = torch.full((faces.shape[0],), 0.5, dtype=torch.float32)
80 |
81 | loss_high = loss_fn(vertices, faces, vertices, faces, high_probs)
82 | loss_low = loss_fn(vertices, faces, vertices, faces, low_probs)
83 |
84 | assert loss_low > loss_high
85 |
86 |
87 | def test_loss_handles_empty_meshes(loss_fn):
88 | empty_vertices = torch.empty((0, 3), dtype=torch.float32)
89 | empty_faces = torch.empty((0, 3), dtype=torch.long)
90 | empty_probs = torch.empty(0, dtype=torch.float32)
91 |
92 | loss = loss_fn(
93 | empty_vertices, empty_faces, empty_vertices, empty_faces, empty_probs
94 | )
95 | assert loss.item() == 0.0
96 |
97 |
98 | def test_loss_is_symmetric(loss_fn, simple_cube_data):
99 |     vertices, faces = simple_cube_data
100 |     face_probs = torch.ones(faces.shape[0], dtype=torch.float32)
101 |     displaced_vertices = vertices + 0.05  # distinct second mesh so the swap is meaningful
102 |     loss_forward = loss_fn(vertices, faces, displaced_vertices, faces, face_probs)
103 |     loss_reverse = loss_fn(displaced_vertices, faces, vertices, faces, face_probs)
104 |
105 |     assert torch.isclose(loss_forward, loss_reverse, atol=1e-6)
106 |
107 |
108 | def test_loss_gradients(loss_fn, simple_cube_data):
109 | vertices, faces = simple_cube_data
110 | vertices.requires_grad = True
111 | face_probs = torch.ones(faces.shape[0], dtype=torch.float32)
112 | face_probs.requires_grad = True
113 |
114 | loss = loss_fn(vertices, faces, vertices, faces, face_probs)
115 | loss.backward()
116 |
117 | assert vertices.grad is not None
118 | assert face_probs.grad is not None
119 | assert not torch.isnan(vertices.grad).any()
120 | assert not torch.isnan(face_probs.grad).any()
121 |
--------------------------------------------------------------------------------
/tests/losses/test_triangle_collision_loss.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from neural_mesh_simplification.losses.triangle_collision_loss import (
4 | TriangleCollisionLoss,
5 | )
6 |
7 |
8 | @pytest.fixture
9 | def collision_loss():
10 | return TriangleCollisionLoss(k=1, collision_threshold=1e-6)
11 |
12 |
13 | @pytest.fixture
14 | def test_meshes():
15 | return {
16 | "non_penetrating": {
17 | "vertices": torch.tensor(
18 | [
19 | [0, 0, 0],
20 | [1, 0, 0],
21 | [0, 1, 0], # Triangle 1
22 | [1, 1, 0], # Additional vertex for Triangle 2
23 | ],
24 | dtype=torch.float32,
25 | ),
26 | "faces": torch.tensor([[0, 1, 2], [1, 3, 2]], dtype=torch.long),
27 | },
28 | "edge_penetrating": {
29 | "vertices": torch.tensor(
30 | [
31 | [0, 0, 0],
32 | [1, 0, 0],
33 | [0.5, 1, 0], # Triangle 1
34 | [0.25, 0.25, -0.5],
35 | [0.75, 0.25, 0.5],
36 | [0.5, 0.75, 0], # Triangle 2
37 | ],
38 | dtype=torch.float32,
39 | ),
40 | "faces": torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.long),
41 | },
42 | }
43 |
44 |
45 | def test_collision_detection(collision_loss, test_meshes):
46 | for name, data in test_meshes.items():
47 | face_probabilities = torch.ones(data["faces"].shape[0], dtype=torch.float32)
48 | loss = collision_loss(data["vertices"], data["faces"], face_probabilities)
49 |
50 | print(f"{name} - Loss: {loss.item()}")
51 |
52 | if name == "non_penetrating":
53 | assert (
54 | loss.item() == 0
55 | ), f"Non-zero loss computed for non-penetrating case: {name}"
56 | else:
57 | assert loss.item() > 0, f"Zero loss computed for penetrating case: {name}"
58 |
59 |
60 | def test_collision_loss_values(collision_loss, test_meshes):
61 | data = test_meshes["edge_penetrating"]
62 | face_probabilities = torch.ones(data["faces"].shape[0], dtype=torch.float32)
63 | loss = collision_loss(data["vertices"], data["faces"], face_probabilities)
64 |
65 | assert (
66 | loss.item() > 0
67 | ), f"Expected positive loss for edge penetrating case, got {loss.item()}"
68 | print(f"Edge penetrating loss: {loss.item()}")
69 |
70 |
71 | def test_collision_loss_with_probabilities(collision_loss, test_meshes):
72 | data = test_meshes["edge_penetrating"]
73 | face_probabilities = torch.tensor([0.5, 0.7], dtype=torch.float32)
74 | loss = collision_loss(data["vertices"], data["faces"], face_probabilities)
75 |
76 | assert (
77 | loss.item() > 0
78 | ), f"Expected positive loss for edge penetrating case with probabilities, got {loss.item()}"
79 | print(f"Edge penetrating loss with probabilities: {loss.item()}")
80 |
81 |
82 | def test_empty_mesh(collision_loss):
83 | empty_vertices = torch.empty((0, 3), dtype=torch.float32)
84 | empty_faces = torch.empty((0, 3), dtype=torch.long)
85 | empty_probabilities = torch.empty(0, dtype=torch.float32)
86 |
87 | loss = collision_loss(empty_vertices, empty_faces, empty_probabilities)
88 | assert loss.item() == 0, f"Expected zero loss for empty mesh, got {loss.item()}"
89 |
90 |
91 | @pytest.fixture
92 | def complex_mesh():
93 | vertices = torch.tensor(
94 | [
95 | [0, 0, 0],
96 | [1, 0, 0],
97 | [0, 1, 0],
98 | [1, 1, 0],
99 | [0, 0, 1],
100 | [1, 0, 1],
101 | [0, 1, 1],
102 | [1, 1, 1],
103 | ],
104 | dtype=torch.float32,
105 | )
106 | faces = torch.tensor(
107 | [
108 | [0, 1, 2],
109 | [1, 3, 2],
110 | [4, 5, 6],
111 | [5, 7, 6],
112 | [0, 4, 1],
113 | [1, 4, 5],
114 | [2, 3, 6],
115 | [3, 7, 6],
116 | ],
117 | dtype=torch.long,
118 | )
119 | return {"vertices": vertices, "faces": faces}
120 |
121 |
122 | def test_complex_mesh(collision_loss, complex_mesh):
123 | face_probabilities = torch.ones(complex_mesh["faces"].shape[0], dtype=torch.float32)
124 | loss = collision_loss(
125 | complex_mesh["vertices"], complex_mesh["faces"], face_probabilities
126 | )
127 | print(f"Complex mesh - Loss: {loss.item()}")
128 | assert loss.item() >= 0, "Negative loss computed for complex mesh"
129 |
130 |
131 | def test_collision_detection_edge_cases(collision_loss):
132 | vertices = torch.tensor(
133 | [
134 | [0, 0, 0],
135 | [1, 0, 0],
136 | [0, 1, 0], # Triangle 1
137 | [0, 0, 1e-7],
138 | [1, 0, 1e-7],
139 | [0, 1, 1e-7], # Triangle 2 (very close but not intersecting)
140 | [0.5, 0.5, -0.5],
141 | [1.5, 0.5, -0.5],
142 | [0.5, 1.5, -0.5], # Triangle 3 (intersecting)
143 | ],
144 | dtype=torch.float32,
145 | )
146 | faces = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype=torch.long)
147 | face_probabilities = torch.ones(faces.shape[0], dtype=torch.float32)
148 |
149 | # Calculate and print face normals
150 | v0, v1, v2 = vertices[faces[:, 0]], vertices[faces[:, 1]], vertices[faces[:, 2]]
151 | face_normals = torch.linalg.cross(v1 - v0, v2 - v0)
152 | face_normals = face_normals / (torch.norm(face_normals, dim=1, keepdim=True) + 1e-6)
153 | print("Face normals:")
154 | print(face_normals)
155 |
156 | # Print triangle information
157 | for i, face in enumerate(faces):
158 | print(f"Triangle {i}:")
159 | print(f" Vertices: {vertices[face].tolist()}")
160 | print(f" Normal: {face_normals[i].tolist()}")
161 | print(f" Centroid: {vertices[face].mean(dim=0).tolist()}")
162 |
163 | loss = collision_loss(vertices, faces, face_probabilities)
164 | print(f"Edge case loss: {loss.item()}")
165 |
166 | assert loss.item() > 0, "Should detect the intersecting triangle"
167 | assert (
168 | loss.item() < 3
169 | ), "Should not detect collision for very close but non-intersecting triangles"
170 | assert (
171 | loss.item() == 2
172 | ), "Should detect exactly two collisions (Triangle 3 intersects both Triangle 1 and 2)"
173 |
--------------------------------------------------------------------------------
/tests/mesh_data/cube.obj:
--------------------------------------------------------------------------------
1 | # https://github.com/mikedh/trimesh
2 | v -1.00000000 -1.00000000 -1.00000000
3 | v -1.00000000 -1.00000000 1.00000000
4 | v -1.00000000 1.00000000 -1.00000000
5 | v -1.00000000 1.00000000 1.00000000
6 | v 1.00000000 -1.00000000 -1.00000000
7 | v 1.00000000 -1.00000000 1.00000000
8 | v 1.00000000 1.00000000 -1.00000000
9 | v 1.00000000 1.00000000 1.00000000
10 | f 2 4 1
11 | f 5 2 1
12 | f 1 4 3
13 | f 3 5 1
14 | f 2 8 4
15 | f 6 2 5
16 | f 6 8 2
17 | f 4 8 3
18 | f 7 5 3
19 | f 3 8 7
20 | f 7 6 5
21 | f 8 6 7
22 |
23 |
--------------------------------------------------------------------------------
/tests/mesh_data/rounded_cube.obj:
--------------------------------------------------------------------------------
1 | # Blender 4.2.0
2 | # www.blender.org
3 | o Cube
4 | v 0.966276 1.000000 -0.966276
5 | v 1.000000 0.966276 -1.000000
6 | v 0.966276 1.000000 0.966276
7 | v 1.000000 0.966276 1.000000
8 | v -0.966276 1.000000 -0.966276
9 | v -1.000000 0.966276 -1.000000
10 | v -0.966276 1.000000 0.966276
11 | v -1.000000 0.966276 1.000000
12 | v 1.000000 -0.950029 -1.000000
13 | v 0.950029 -1.000000 -0.950029
14 | v 0.950029 -1.000000 0.950029
15 | v 1.000000 -0.950029 1.000000
16 | v -1.000000 -0.950029 -1.000000
17 | v -0.950029 -1.000000 -0.950029
18 | v -0.950029 -1.000000 0.950029
19 | v -1.000000 -0.950029 1.000000
20 | vn -1.0000 -0.0000 -0.0000
21 | vn 1.0000 -0.0000 -0.0000
22 | vn -0.0000 1.0000 -0.0000
23 | vn -0.0000 -0.0000 1.0000
24 | vn -0.0000 0.7071 0.7071
25 | vn 0.7071 0.7071 -0.0000
26 | vn -0.7071 0.7071 -0.0000
27 | vn -0.0000 0.7071 -0.7071
28 | vn -0.0000 -1.0000 -0.0000
29 | vn -0.7071 -0.7071 -0.0000
30 | vn -0.0000 -0.7071 -0.7071
31 | vn -0.0000 -0.7071 0.7071
32 | vn 0.7071 -0.7071 -0.0000
33 | vn -0.0000 -0.0000 -1.0000
34 | s 0
35 | f 16//1 8//1 6//1 13//1
36 | f 9//2 2//2 4//2 12//2
37 | f 1//3 5//3 7//3 3//3
38 | f 12//4 4//4 8//4 16//4
39 | f 3//5 7//5 8//5 4//5
40 | f 1//6 3//6 4//6 2//6
41 | f 7//7 5//7 6//7 8//7
42 | f 5//8 1//8 2//8 6//8
43 | f 14//9 10//9 11//9 15//9
44 | f 14//10 15//10 16//10 13//10
45 | f 10//11 14//11 13//11 9//11
46 | f 15//12 11//12 12//12 16//12
47 | f 11//13 10//13 9//13 12//13
48 | f 13//14 6//14 2//14 9//14
49 |
--------------------------------------------------------------------------------
/tests/mesh_data/sharp_cube.obj:
--------------------------------------------------------------------------------
1 | # Blender 4.2.0
2 | # www.blender.org
3 | o Cube
4 | v 1.000000 1.000000 -1.000000
5 | v 1.000000 -1.000000 -1.000000
6 | v 1.000000 1.000000 1.000000
7 | v 1.000000 -1.000000 1.000000
8 | v -1.000000 1.000000 -1.000000
9 | v -1.000000 -1.000000 -1.000000
10 | v -1.000000 1.000000 1.000000
11 | v -1.000000 -1.000000 1.000000
12 | vn -0.0000 1.0000 -0.0000
13 | vn -0.0000 -0.0000 1.0000
14 | vn -1.0000 -0.0000 -0.0000
15 | vn -0.0000 -1.0000 -0.0000
16 | vn 1.0000 -0.0000 -0.0000
17 | vn -0.0000 -0.0000 -1.0000
18 | s 0
19 | f 1//1 5//1 7//1 3//1
20 | f 4//2 3//2 7//2 8//2
21 | f 8//3 7//3 5//3 6//3
22 | f 6//4 2//4 4//4 8//4
23 | f 2//5 1//5 3//5 4//5
24 | f 6//6 5//6 1//6 2//6
25 |
--------------------------------------------------------------------------------
/tests/test_dataset.py:
--------------------------------------------------------------------------------
1 | import trimesh
2 | import numpy as np
3 | import networkx as nx
4 | from torch_geometric.data import Data
5 | from neural_mesh_simplification.data.dataset import (
6 | MeshSimplificationDataset,
7 | preprocess_mesh,
8 | mesh_to_tensor,
9 | load_mesh,
10 | )
11 |
12 |
13 | def test_load_mesh(tmp_path):
14 | # Create a temporary mesh file
15 | mesh = trimesh.creation.box()
16 | file_path = tmp_path / "test_mesh.obj"
17 | mesh.export(file_path)
18 |
19 | loaded_mesh = load_mesh(str(file_path))
20 | assert isinstance(loaded_mesh, trimesh.Trimesh)
21 | assert np.allclose(loaded_mesh.vertices, mesh.vertices)
22 | assert np.array_equal(loaded_mesh.faces, mesh.faces)
23 |
24 |
25 | def test_preprocess_mesh_centered(sample_mesh):
26 | processed_mesh = preprocess_mesh(sample_mesh)
27 | # Check that the mesh is centered
28 | assert np.allclose(
29 | processed_mesh.vertices.mean(axis=0), np.zeros(3)
30 | ), "Mesh is not centered"
31 |
32 |
33 | def test_preprocess_mesh_scaled(sample_mesh):
34 | processed_mesh = preprocess_mesh(sample_mesh)
35 |
36 | max_dim = np.max(
37 | processed_mesh.vertices.max(axis=0) - processed_mesh.vertices.min(axis=0)
38 | )
39 | assert np.isclose(max_dim, 1.0), "Mesh is not scaled to unit cube"
40 |
41 |
42 | def test_mesh_to_tensor(sample_mesh: trimesh.Trimesh):
43 | data = mesh_to_tensor(sample_mesh)
44 | assert isinstance(data, Data)
45 | assert data.num_nodes == len(sample_mesh.vertices)
46 | assert data.face.shape[1] == len(sample_mesh.faces)
47 | assert data.edge_index.shape[0] == 2
48 | assert data.edge_index.max() < data.num_nodes
49 |
50 |
51 | def test_graph_structure_in_data(sample_mesh):
52 | data = mesh_to_tensor(sample_mesh)
53 |
54 | # Check number of nodes
55 | assert data.num_nodes == len(sample_mesh.vertices)
56 |
57 | # Check edge_index
58 | assert data.edge_index.shape[0] == 2
59 | assert data.edge_index.max() < data.num_nodes
60 |
61 | # Reconstruct graph from edge_index
62 | G = nx.Graph()
63 | edge_list = data.edge_index.t().tolist()
64 | G.add_edges_from(edge_list)
65 |
66 | # Check reconstructed graph properties
67 | assert len(G.nodes) == len(sample_mesh.vertices)
68 | assert len(G.edges) == (3 * len(sample_mesh.faces) - len(sample_mesh.edges_unique))
69 |
70 | # Check connectivity
71 | assert nx.is_connected(G)
72 |
73 | # Check degree distribution
74 | degrees = [d for n, d in G.degree()]
75 | assert min(degrees) >= 3 # Each vertex should be connected to at least 3 others
76 |
77 | # Check if the graph is manifold-like (each edge should be shared by at most two faces)
78 | edge_face_count = {}
79 | for face in sample_mesh.faces:
80 | for i in range(3):
81 | edge = tuple(sorted([face[i], face[(i + 1) % 3]]))
82 | edge_face_count[edge] = edge_face_count.get(edge, 0) + 1
83 | assert all(count <= 2 for count in edge_face_count.values())
84 |
85 |
86 | def test_dataset(tmp_path):
87 | # Create a few temporary mesh files
88 | for i in range(3):
89 | mesh = trimesh.creation.box()
90 | file_path = tmp_path / f"test_mesh_{i}.obj"
91 | mesh.export(file_path)
92 |
93 | dataset = MeshSimplificationDataset(str(tmp_path))
94 | assert len(dataset) == 3
95 |
96 | sample = dataset[0]
97 | assert isinstance(sample, Data)
98 | assert sample.num_nodes > 0
99 | assert sample.face.shape[1] > 0
100 |
--------------------------------------------------------------------------------
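The edge-count assertion in `test_graph_structure_in_data` (mirrored in test_mesh_operations.py) holds because the sample box mesh is closed: every edge is shared by exactly two triangles, so 3F = 2E, and 3F - E therefore equals E = len(edges_unique).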
/tests/test_edge_predictor.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import torch.nn as nn
4 | from torch_geometric.data import Data
5 | from torch_geometric.nn import knn_graph
6 |
7 | from neural_mesh_simplification.models.edge_predictor import EdgePredictor
8 | from neural_mesh_simplification.models.layers.devconv import DevConv
9 |
10 |
11 | @pytest.fixture
12 | def sample_mesh_data():
13 | x = torch.tensor(
14 | [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]],
15 | dtype=torch.float,
16 | )
17 | edge_index = torch.tensor(
18 | [[0, 0, 1, 1, 2, 2, 3, 3], [1, 2, 0, 3, 0, 3, 1, 2]], dtype=torch.long
19 | )
20 | return Data(x=x, edge_index=edge_index)
21 |
22 |
23 | def test_edge_predictor_initialization():
24 | edge_predictor = EdgePredictor(in_channels=3, hidden_channels=64, k=15)
25 | assert isinstance(edge_predictor.devconv, DevConv)
26 | assert isinstance(edge_predictor.W_q, nn.Linear)
27 | assert isinstance(edge_predictor.W_k, nn.Linear)
28 | assert edge_predictor.k == 15
29 |
30 |
31 | def test_edge_predictor_forward(sample_mesh_data):
32 | edge_predictor = EdgePredictor(in_channels=3, hidden_channels=64, k=2)
33 | simplified_adj_indices, simplified_adj_values = edge_predictor(
34 | sample_mesh_data.x, sample_mesh_data.edge_index
35 | )
36 |
37 | assert isinstance(simplified_adj_indices, torch.Tensor)
38 | assert isinstance(simplified_adj_values, torch.Tensor)
39 | assert simplified_adj_indices.shape[0] == 2 # 2 rows for source and target indices
40 | assert (
41 | simplified_adj_values.shape[0] == simplified_adj_indices.shape[1]
42 | ) # Same number of values as edges
43 |
44 |
45 | def test_edge_predictor_output_range(sample_mesh_data):
46 | edge_predictor = EdgePredictor(in_channels=3, hidden_channels=64, k=2)
47 | _, simplified_adj_values = edge_predictor(
48 | sample_mesh_data.x, sample_mesh_data.edge_index
49 | )
50 |
51 | assert (simplified_adj_values >= 0).all() # Values should be non-negative
52 |
53 |
54 | def test_edge_predictor_symmetry(sample_mesh_data):
55 | edge_predictor = EdgePredictor(in_channels=3, hidden_channels=64, k=2)
56 | simplified_adj_indices, simplified_adj_values = edge_predictor(
57 | sample_mesh_data.x, sample_mesh_data.edge_index
58 | )
59 |
60 | # Create a sparse tensor from the output
61 | n = sample_mesh_data.x.shape[0]
62 | adj_matrix = torch.sparse_coo_tensor(
63 | simplified_adj_indices, simplified_adj_values, (n, n)
64 | )
65 | dense_adj = adj_matrix.to_dense()
66 |
67 | assert torch.allclose(dense_adj, dense_adj.t(), atol=1e-6)
68 |
69 |
70 | def test_edge_predictor_connectivity(sample_mesh_data):
71 | edge_predictor = EdgePredictor(in_channels=3, hidden_channels=64, k=2)
72 | simplified_adj_indices, _ = edge_predictor(
73 | sample_mesh_data.x, sample_mesh_data.edge_index
74 | )
75 |
76 | # Check if all nodes are connected
77 | unique_nodes = torch.unique(simplified_adj_indices)
78 | assert len(unique_nodes) == sample_mesh_data.x.shape[0]
79 |
80 |
81 | def test_edge_predictor_different_input_sizes():
82 | edge_predictor = EdgePredictor(in_channels=3, hidden_channels=64, k=5)
83 |
84 | # Test with a larger graph
85 | x = torch.rand(10, 3)
86 | edge_index = torch.randint(0, 10, (2, 30))
87 | simplified_adj_indices, simplified_adj_values = edge_predictor(x, edge_index)
88 |
89 | assert simplified_adj_indices.shape[0] == 2
90 | assert simplified_adj_values.shape[0] == simplified_adj_indices.shape[1]
91 | assert torch.max(simplified_adj_indices) < 10
92 |
93 |
94 | def test_attention_scores_shape(sample_mesh_data):
95 | edge_predictor = EdgePredictor(in_channels=3, hidden_channels=64, k=2)
96 |
97 | # Get intermediate features
98 | knn_edges = knn_graph(sample_mesh_data.x, k=2, flow="target_to_source")
99 | extended_edges = torch.cat([sample_mesh_data.edge_index, knn_edges], dim=1)
100 | features = edge_predictor.devconv(sample_mesh_data.x, extended_edges)
101 |
102 |     # Attention scores are softmax-normalized per source node, so their total equals the number of unique sources
103 | attention_scores = edge_predictor.compute_attention_scores(
104 | features, sample_mesh_data.edge_index
105 | )
106 |
107 | assert attention_scores.shape[0] == sample_mesh_data.edge_index.shape[1]
108 | assert torch.allclose(
109 | attention_scores.sum(),
110 | torch.tensor(
111 | len(torch.unique(sample_mesh_data.edge_index[0])), dtype=torch.float32
112 | ),
113 | )
114 |
115 |
116 | def test_simplified_adjacency_shapes():
117 | # Create a simple graph
118 | x = torch.rand(5, 3)
119 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
120 | attention_scores = torch.rand(edge_index.shape[1])
121 |
122 | edge_predictor = EdgePredictor(in_channels=3, hidden_channels=64, k=15)
123 | indices, values = edge_predictor.compute_simplified_adjacency(
124 | attention_scores, edge_index
125 | )
126 |
127 | assert indices.shape[0] == 2
128 | assert indices.shape[1] == values.shape[0]
129 | assert torch.max(indices) < 5
130 |
131 |
132 | def test_empty_input_handling():
133 | edge_predictor = EdgePredictor(in_channels=3, hidden_channels=64, k=15)
134 | x = torch.rand(5, 3)
135 | empty_edge_index = torch.empty((2, 0), dtype=torch.long)
136 |
137 | # Test forward pass with empty edge_index
138 | with pytest.raises(ValueError, match="Edge index is empty"):
139 | indices, values = edge_predictor(x, empty_edge_index)
140 |
141 | # Test compute_simplified_adjacency with empty edge_index
142 | empty_attention_scores = torch.empty(0)
143 | with pytest.raises(ValueError, match="Edge index is empty"):
144 | indices, values = edge_predictor.compute_simplified_adjacency(
145 | empty_attention_scores, empty_edge_index
146 | )
147 |
148 |
149 | def test_feature_transformation():
150 | edge_predictor = EdgePredictor(in_channels=3, hidden_channels=64, k=2)
151 | x = torch.rand(5, 3)
152 | edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
153 |
154 | # Get intermediate features
155 | knn_edges = knn_graph(x, k=2, flow="target_to_source")
156 | extended_edges = torch.cat([edge_index, knn_edges], dim=1)
157 | features = edge_predictor.devconv(x, extended_edges)
158 |
159 | # Check feature dimensions
160 | assert features.shape == (5, 64) # [num_nodes, hidden_channels]
161 |
162 | # Check transformed features through attention layers
163 | q = edge_predictor.W_q(features)
164 | k = edge_predictor.W_k(features)
165 | assert q.shape == (5, 64)
166 | assert k.shape == (5, 64)
167 |
--------------------------------------------------------------------------------
/tests/test_face_classifier.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from neural_mesh_simplification.models.face_classifier import FaceClassifier
5 |
6 |
7 | @pytest.fixture
8 | def face_classifier():
9 | return FaceClassifier(input_dim=16, hidden_dim=32, num_layers=3, k=20)
10 |
11 |
12 | def test_face_classifier_initialization(face_classifier):
13 | assert len(face_classifier.triconv_layers) == 3
14 | assert isinstance(face_classifier.final_layer, torch.nn.Linear)
15 |
16 |
17 | def test_face_classifier_forward(face_classifier):
18 | num_faces = 100
19 | x = torch.randn(num_faces, 16)
20 | pos = torch.randn(num_faces, 3)
21 |
22 | out = face_classifier(x, pos)
23 | assert out.shape == (num_faces,)
24 | assert torch.all(out >= 0) and torch.all(out <= 1)
25 | assert torch.isclose(out.sum(), torch.tensor(1.0), atol=1e-6)
26 |
27 |
28 | def test_face_classifier_gradient(face_classifier):
29 | num_faces = 100
30 | x = torch.randn(num_faces, 16, requires_grad=True)
31 | pos = torch.randn(num_faces, 3, requires_grad=True)
32 |
33 | out = face_classifier(x, pos)
34 | loss = out.sum()
35 | loss.backward()
36 |
37 | assert x.grad is not None
38 | assert pos.grad is not None
39 | assert all(p.grad is not None for p in face_classifier.parameters())
40 |
41 |
42 | def test_face_classifier_with_batch(face_classifier):
43 | num_faces = 100
44 | batch_size = 2
45 | x = torch.randn(num_faces, 16)
46 | pos = torch.randn(num_faces, 3)
47 | batch = torch.cat(
48 | [torch.full((num_faces // batch_size,), i) for i in range(batch_size)]
49 | )
50 |
51 | out = face_classifier(x, pos, batch)
52 | assert out.shape == (num_faces,)
53 | assert torch.all(out >= 0) and torch.all(out <= 1)
54 |
55 | # Check if the sum of probabilities for each batch is close to 1
56 | for i in range(batch_size):
57 | batch_sum = out[batch == i].sum()
58 | assert torch.isclose(batch_sum, torch.tensor(1.0), atol=1e-6)
59 |
60 |
61 | def test_face_classifier_knn_graph(face_classifier):
62 | num_faces = 100
63 | x = torch.randn(num_faces, 16)
64 | pos = torch.randn(num_faces, 3, 3) # 3 vertices per face
65 |
66 | # Call the forward method to construct the k-nn graph
67 | _ = face_classifier(x, pos)
68 |
69 | # Get the constructed edge_index from the first TriConv layer
70 | edge_index = face_classifier.triconv_layers[0].last_edge_index
71 |
72 | # Check the number of neighbors for each face
73 | for i in range(num_faces):
74 | actual_neighbors = edge_index[1][edge_index[0] == i]
75 | assert (
76 | len(actual_neighbors) >= face_classifier.k
77 | ), f"Face {i} has {len(actual_neighbors)} neighbors, which is less than {face_classifier.k}"
78 |
79 | # Verify that the graph is symmetric
80 | symmetric_diff = set(map(tuple, edge_index.t().tolist())) ^ set(
81 | map(tuple, edge_index.flip(0).t().tolist())
82 | )
83 | assert len(symmetric_diff) == 0, "The k-nn graph is not symmetric"
84 |
85 | # Verify that there are no self-loops
86 | assert torch.all(
87 | edge_index[0] != edge_index[1]
88 | ), "The k-nn graph contains self-loops"
89 |
--------------------------------------------------------------------------------
/tests/test_mesh_operations.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | import numpy as np
3 |
4 | from neural_mesh_simplification.utils import build_graph_from_mesh
5 |
6 |
7 | def test_build_graph_from_mesh(sample_mesh):
8 | graph = build_graph_from_mesh(sample_mesh)
9 |
10 | # Check number of nodes and edges
11 | assert len(graph.nodes) == len(sample_mesh.vertices)
12 | assert len(graph.edges) == (
13 | 3 * len(sample_mesh.faces) - len(sample_mesh.edges_unique)
14 | )
15 |
16 | # Check node attributes
17 | for i, pos in enumerate(sample_mesh.vertices):
18 | assert i in graph.nodes
19 | assert np.allclose(graph.nodes[i]["pos"], pos)
20 |
21 | # Check edge connectivity
22 | for face in sample_mesh.faces:
23 | assert graph.has_edge(face[0], face[1])
24 | assert graph.has_edge(face[1], face[2])
25 | assert graph.has_edge(face[2], face[0])
26 |
27 | # Check graph connectivity
28 | assert nx.is_connected(graph)
29 |
--------------------------------------------------------------------------------
/tests/test_metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import trimesh
3 | from trimesh import Trimesh
4 |
5 | from neural_mesh_simplification.metrics.chamfer_distance import chamfer_distance
6 | from neural_mesh_simplification.metrics.edge_preservation import edge_preservation
7 | from neural_mesh_simplification.metrics.hausdorff_distance import hausdorff_distance
8 | from neural_mesh_simplification.metrics.normal_consistency import normal_consistency
9 |
10 |
11 | def create_cube_mesh(scale=1.0):
12 | vertices = (
13 | np.array(
14 | [
15 | [0, 0, 0],
16 | [1, 0, 0],
17 | [1, 1, 0],
18 | [0, 1, 0],
19 | [0, 0, 1],
20 | [1, 0, 1],
21 | [1, 1, 1],
22 | [0, 1, 1],
23 | ]
24 | )
25 | * scale
26 | )
27 | faces = np.array(
28 | [
29 | [0, 1, 2],
30 | [0, 2, 3],
31 | [4, 5, 6],
32 | [4, 6, 7],
33 | [0, 1, 5],
34 | [0, 5, 4],
35 | [2, 3, 7],
36 | [2, 7, 6],
37 | [1, 2, 6],
38 | [1, 6, 5],
39 | [0, 3, 7],
40 | [0, 7, 4],
41 | ]
42 | )
43 |
44 | return Trimesh(vertices=vertices, faces=faces)
45 |
46 |
47 | def test_chamfer_distance_identical_meshes():
48 | mesh1 = create_cube_mesh()
49 | mesh2 = create_cube_mesh()
50 |
51 | dist = chamfer_distance(mesh1, mesh2)
52 |
53 | assert np.isclose(
54 | dist, 0.0
55 | ), f"Chamfer distance for identical meshes should be 0, got {dist}"
56 |
57 |
58 | def test_chamfer_distance_different_meshes():
59 | mesh1 = create_cube_mesh()
60 | mesh2 = create_cube_mesh(scale=2.0) # Scale the second cube to be twice as large
61 |
62 | dist = chamfer_distance(mesh1, mesh2)
63 |
64 | assert (
65 | dist > 0
66 | ), f"Chamfer distance for different meshes should be greater than 0, got {dist}"
67 |
68 |
69 | def test_normal_consistency():
70 | mesh = create_cube_mesh()
71 | consistency = normal_consistency(mesh)
72 |
73 |     expected_consistency = 0.577350269189626  # 1 / sqrt(3)
74 |
75 | assert np.isclose(
76 | consistency, expected_consistency
77 | ), f"Normal consistency should be {expected_consistency}, got {consistency}"
78 |
79 |
80 | def test_edge_preservation():
81 | original_mesh = trimesh.load("./tests/mesh_data/rounded_cube.obj")
82 | simplified_mesh = trimesh.load("./tests/mesh_data/sharp_cube.obj")
83 |
84 | # original_mesh = trimesh.creation.icosphere(subdivisions=3, radius=2)
85 | # simplified_mesh = trimesh.creation.icosphere(subdivisions=2, radius=2)
86 |
87 | preservation_metric = edge_preservation(original_mesh, simplified_mesh)
88 |
89 | assert (
90 | preservation_metric < 1.0
91 | ), f"Edge preservation metric should be less than 1.0, got {preservation_metric}"
92 | assert (
93 | preservation_metric > 0.0
94 | ), f"Edge preservation metric should be greater than 0.0, got {preservation_metric}"
95 |
96 |
97 | def test_hausdorff_distance_identical_meshes():
98 | mesh1 = create_cube_mesh()
99 | mesh2 = create_cube_mesh()
100 |
101 | dist = hausdorff_distance(mesh1, mesh2)
102 |
103 | assert np.isclose(
104 | dist, 0.0
105 | ), f"Hausdorff distance for identical meshes should be 0, got {dist}"
106 |
107 |
108 | def test_hausdorff_distance_different_meshes():
109 | mesh1 = create_cube_mesh()
110 | mesh2 = trimesh.creation.icosphere(subdivisions=2, radius=2)
111 |
112 | dist = hausdorff_distance(mesh1, mesh2)
113 |
114 | assert (
115 | dist > 1.99
116 |     ), f"Hausdorff distance between these meshes should be about 2, got {dist}"
117 |
--------------------------------------------------------------------------------
/tests/test_model.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import torch_geometric.utils
4 | from torch_geometric.data import Data
5 |
6 | from neural_mesh_simplification.models import NeuralMeshSimplification
7 |
8 |
9 | @pytest.fixture
10 | def sample_data() -> Data:
11 | num_nodes = 10
12 | x = torch.randn(num_nodes, 3)
13 | # Create a more densely connected edge index
14 | edge_index = torch.tensor(
15 | [
16 | [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4],
17 | [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 6, 7],
18 | ],
19 | dtype=torch.long,
20 | )
21 | pos = torch.randn(num_nodes, 3)
22 | return Data(x=x, edge_index=edge_index, pos=pos)
23 |
24 |
25 | def test_neural_mesh_simplification_forward(sample_data: Data):
26 | # Set a fixed random seed for reproducibility
27 | torch.manual_seed(42)
28 |
29 | model = NeuralMeshSimplification(
30 | input_dim=3,
31 | hidden_dim=64,
32 | edge_hidden_dim=64,
33 | num_layers=3,
34 | k=3, # Reduce k to avoid too many edges in the test
35 | edge_k=15,
36 | target_ratio=0.5, # Ensure we sample roughly half the vertices
37 | )
38 |
39 | # First test point sampling
40 | sampled_indices, sampled_probs = model.sample_points(sample_data)
41 | assert sampled_indices.numel() > 0, "No points were sampled"
42 | assert sampled_indices.numel() <= sample_data.num_nodes, "Too many points sampled"
43 |
44 | # Get the subgraph for sampled points
45 | sampled_edge_index, _ = torch_geometric.utils.subgraph(
46 | sampled_indices,
47 | sample_data.edge_index,
48 | relabel_nodes=True,
49 | num_nodes=sample_data.num_nodes,
50 | )
51 | assert sampled_edge_index.numel() > 0, "No edges in sampled subgraph"
52 |
53 | # Now test the full forward pass
54 | output = model(sample_data)
55 |
56 | # Add assertions to check the output structure and shapes
57 | assert isinstance(output, dict)
58 | assert "sampled_indices" in output
59 | assert "sampled_probs" in output
60 | assert "sampled_vertices" in output
61 | assert "edge_index" in output
62 | assert "edge_probs" in output
63 | assert "candidate_triangles" in output
64 | assert "triangle_probs" in output
65 | assert "face_probs" in output
66 | assert "simplified_faces" in output
67 |
68 | # Check shapes
69 | assert output["sampled_indices"].dim() == 1
70 | # sampled_probs should match the number of sampled vertices
71 | assert output["sampled_probs"].shape == output["sampled_indices"].shape
72 | assert output["sampled_vertices"].shape[1] == 3 # 3D coordinates
73 |
74 | if output["edge_index"].numel() > 0: # Only check if we have edges
75 | assert output["edge_index"].shape[0] == 2 # Source and target nodes
76 | assert (
77 | len(output["edge_probs"]) == output["edge_index"].shape[1]
78 | ) # One prob per edge
79 |
80 | # Check that edge indices are valid
81 | num_sampled_vertices = output["sampled_vertices"].shape[0]
82 | assert torch.all(output["edge_index"] >= 0)
83 | assert torch.all(output["edge_index"] < num_sampled_vertices)
84 |
85 | if output["candidate_triangles"].numel() > 0: # Only check if we have triangles
86 | assert output["candidate_triangles"].shape[1] == 3 # Triangle indices
87 | assert len(output["triangle_probs"]) == len(output["candidate_triangles"])
88 | assert len(output["face_probs"]) == len(output["candidate_triangles"])
89 |
90 | # Additional checks
91 | assert output["sampled_indices"].shape[0] <= sample_data.num_nodes
92 | assert output["sampled_vertices"].shape[0] == output["sampled_indices"].shape[0]
93 |
94 | # Check that sampled_vertices correspond to a subset of original vertices
95 | original_vertices = sample_data.pos
96 | sampled_vertices = output["sampled_vertices"]
97 |
98 | # For each sampled vertex, check if it exists in original vertices
99 | for sv in sampled_vertices:
100 | # Check if this vertex exists in original vertices (within numerical precision)
101 | exists = torch.any(torch.all(torch.abs(original_vertices - sv) < 1e-6, dim=1))
102 | assert exists, "Sampled vertex not found in original vertices"
103 |
104 | # Check that simplified_faces only contain valid indices if not empty
105 | if output["simplified_faces"].numel() > 0:
106 | max_index = output["sampled_vertices"].shape[0] - 1
107 | assert torch.all(output["simplified_faces"] >= 0)
108 | assert torch.all(output["simplified_faces"] <= max_index)
109 |
110 | # Check the relationship between face_probs and simplified_faces
111 | if output["face_probs"].numel() > 0:
112 | assert output["simplified_faces"].shape[0] <= output["face_probs"].shape[0]
113 |
114 |
115 | def test_generate_candidate_triangles():
116 | model = NeuralMeshSimplification(
117 | input_dim=3,
118 | hidden_dim=64,
119 | edge_hidden_dim=64,
120 | num_layers=3,
121 | k=5,
122 | edge_k=15,
123 | target_ratio=0.5,
124 | )
125 | edge_index = torch.tensor(
126 | [[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]], dtype=torch.long
127 | )
128 | edge_probs = torch.tensor([0.9, 0.9, 0.8, 0.8, 0.7, 0.7])
129 |
130 | triangles, triangle_probs = model.generate_candidate_triangles(
131 | edge_index, edge_probs
132 | )
133 |
134 | assert triangles.shape[1] == 3
135 | assert triangle_probs.shape[0] == triangles.shape[0]
136 | assert torch.all(triangles >= 0)
137 | assert torch.all(triangles < edge_index.max() + 1)
138 | assert torch.all(triangle_probs >= 0) and torch.all(triangle_probs <= 1)
139 |
140 | max_possible_triangles = edge_index.max().item() + 1 # num_nodes
141 |     max_possible_triangles = (  # C(n, 3): all unordered vertex triples
142 | max_possible_triangles
143 | * (max_possible_triangles - 1)
144 | * (max_possible_triangles - 2)
145 | // 6
146 | )
147 | assert triangles.shape[0] <= max_possible_triangles
148 |
--------------------------------------------------------------------------------
/tests/test_model_layers.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import pytest
3 | import torch
4 | from torch import nn
5 | from neural_mesh_simplification.models.layers import DevConv, TriConv
6 |
7 |
8 | def create_graph_data():
9 | x = torch.tensor(
10 | [
11 | [0.0, 0.0, 0.0], # Node 0
12 | [1.0, 0.0, 0.0], # Node 1
13 | [0.0, 1.0, 0.0], # Node 2
14 | [1.0, 1.0, 0.0], # Node 3
15 | ],
16 | dtype=torch.float,
17 | )
18 |
19 | edge_index = torch.tensor(
20 | [
21 | [0, 0, 1, 1, 2, 2, 3, 3], # Source nodes
22 | [1, 2, 0, 3, 0, 3, 1, 2], # Target nodes
23 | ],
24 | dtype=torch.long,
25 | )
26 |
27 | return x, edge_index
28 |
29 |
30 | @pytest.fixture
31 | def graph_data():
32 | return create_graph_data()
33 |
34 |
35 | def test_devconv(graph_data):
36 | x, edge_index = graph_data
37 |
38 | devconv = DevConv(in_channels=3, out_channels=4)
39 | output = devconv(x, edge_index)
40 |
41 | assert output.shape == (4, 4) # 4 nodes, 4 output channels
42 |
43 | if "-s" in sys.argv:
44 | print("Input shape:", x.shape)
45 | print("Output shape:", output.shape)
46 | print("Output:\n", output)
47 |
48 | analyze_feature_differences(x, edge_index)
49 |
50 |
51 | def analyze_feature_differences(x, edge_index):
52 | devconv = DevConv(in_channels=3, out_channels=3)
53 | output = devconv(x, edge_index)
54 |
55 | for i in range(x.shape[0]):
56 | neighbors = edge_index[1][edge_index[0] == i]
57 | print(f"Node {i}:")
58 | print(f" Input feature: {x[i]}")
59 | print(f" Output feature: {output[i]}")
60 | print(" Neighbor differences:")
61 | for j in neighbors:
62 | print(f" Node {j}: {x[i] - x[j]}")
63 | print()
64 |
65 |
66 | @pytest.fixture
67 | def triconv_layer():
68 | return TriConv(in_channels=16, out_channels=32)
69 |
70 |
71 | def test_triconv_initialization(triconv_layer):
72 | assert triconv_layer.in_channels == 16
73 | assert triconv_layer.out_channels == 32
74 | assert isinstance(triconv_layer.mlp, nn.Sequential)
75 | assert triconv_layer.mlp[0].in_features == 25 # 16 + 9
76 |
77 |
78 | def test_triconv_forward(triconv_layer):
79 | num_nodes = 10
80 | x = torch.randn(num_nodes, 16)
81 | pos = torch.randn(num_nodes, 3)
82 | edge_index = torch.randint(0, num_nodes, (2, 20))
83 |
84 | out = triconv_layer(x, pos, edge_index)
85 | assert out.shape == (num_nodes, 32)
86 |
87 |
88 | def test_relative_position_encoding(triconv_layer):
89 | num_nodes = 5
90 | pos = torch.randn(num_nodes, 3)
91 | edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
92 |
93 | rel_pos = triconv_layer.compute_relative_position_encoding(
94 | pos, edge_index[0], edge_index[1]
95 | )
96 | assert rel_pos.shape == (4, 9) # 4 edges, 9-dimensional encoding
97 |
98 |
99 | def test_triconv_gradient(triconv_layer):
100 | num_nodes = 10
101 | x = torch.randn(num_nodes, 16, requires_grad=True)
102 | pos = torch.randn(num_nodes, 3, requires_grad=True)
103 | edge_index = torch.randint(0, num_nodes, (2, 20))
104 |
105 | out = triconv_layer(x, pos, edge_index)
106 | loss = out.sum()
107 | loss.backward()
108 |
109 | assert x.grad is not None
110 | assert pos.grad is not None
111 | assert all(p.grad is not None for p in triconv_layer.parameters())
112 |
113 |
114 | def test_last_edge_index(triconv_layer):
115 | num_nodes = 10
116 | x = torch.randn(num_nodes, 16)
117 | pos = torch.randn(num_nodes, 3)
118 | edge_index = torch.randint(0, num_nodes, (2, 20))
119 |
120 | triconv_layer(x, pos, edge_index)
121 | assert torch.all(triconv_layer.last_edge_index == edge_index)
122 |
--------------------------------------------------------------------------------
/tests/test_point_sampler.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torch import nn
4 | from neural_mesh_simplification.models.layers.devconv import DevConv
5 | from neural_mesh_simplification.models.point_sampler import PointSampler
6 |
7 |
8 | @pytest.fixture
9 | def sample_graph_data():
10 | x = torch.tensor(
11 | [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]],
12 | dtype=torch.float,
13 | )
14 | edge_index = torch.tensor(
15 | [[0, 0, 1, 1, 2, 2, 3, 3], [1, 2, 0, 3, 0, 3, 1, 2]], dtype=torch.long
16 | )
17 | return x, edge_index
18 |
19 |
20 | def test_point_sampler_initialization():
21 | sampler = PointSampler(in_channels=3, out_channels=64, num_layers=3)
22 | assert len(sampler.convs) == 3
23 | assert isinstance(sampler.convs[0], DevConv)
24 | assert isinstance(sampler.output_layer, nn.Linear)
25 |
26 |
27 | def test_point_sampler_forward(sample_graph_data):
28 | x, edge_index = sample_graph_data
29 | sampler = PointSampler(in_channels=3, out_channels=64, num_layers=3)
30 | probabilities = sampler(x, edge_index)
31 | assert probabilities.shape == (4,) # 4 input vertices
32 | assert (probabilities >= 0).all() and (probabilities <= 1).all()
33 |
34 |
35 | def test_point_sampler_sampling(sample_graph_data):
36 | x, edge_index = sample_graph_data
37 | sampler = PointSampler(in_channels=3, out_channels=64, num_layers=3)
38 | probabilities = sampler(x, edge_index)
39 | sampled_indices = sampler.sample(probabilities, num_samples=2)
40 | assert sampled_indices.shape == (2,)
41 | assert len(torch.unique(sampled_indices)) == 2 # All indices should be unique
42 |
43 |
44 | def test_point_sampler_forward_and_sample(sample_graph_data):
45 | x, edge_index = sample_graph_data
46 | sampler = PointSampler(in_channels=3, out_channels=64, num_layers=3)
47 | sampled_indices, probabilities = sampler.forward_and_sample(
48 | x, edge_index, num_samples=2
49 | )
50 | assert sampled_indices.shape == (2,)
51 | assert probabilities.shape == (4,)
52 | assert len(torch.unique(sampled_indices)) == 2
53 |
54 |
55 | def test_point_sampler_deterministic_behavior(sample_graph_data):
56 | x, edge_index = sample_graph_data
57 | sampler = PointSampler(in_channels=3, out_channels=64, num_layers=3)
58 |
59 | torch.manual_seed(42)
60 | indices1, _ = sampler.forward_and_sample(x, edge_index, num_samples=2)
61 |
62 | torch.manual_seed(42)
63 | indices2, _ = sampler.forward_and_sample(x, edge_index, num_samples=2)
64 |
65 | assert torch.equal(indices1, indices2)
66 |
67 |
68 | def test_point_sampler_different_input_sizes():
69 | sampler = PointSampler(in_channels=3, out_channels=64, num_layers=3)
70 |
71 | x1 = torch.rand(10, 3)
72 | edge_index1 = torch.randint(0, 10, (2, 20))
73 | prob1 = sampler(x1, edge_index1)
74 | assert prob1.shape == (10,)
75 |
76 | x2 = torch.rand(20, 3)
77 | edge_index2 = torch.randint(0, 20, (2, 40))
78 | prob2 = sampler(x2, edge_index2)
79 | assert prob2.shape == (20,)
80 |
--------------------------------------------------------------------------------
/tests/test_trimesh.py:
--------------------------------------------------------------------------------
1 | import trimesh
2 | import pytest
3 |
4 |
5 | @pytest.mark.trimesh
6 | class TestMeshCreation:
7 | def test_sphere(self):
8 | print("Creating sphere")
9 | mesh = trimesh.creation.icosphere(subdivisions=2, radius=2)
10 | mesh.export("./tests/mesh_data/sphere_2.obj")
11 |
12 | def test_cube(self):
13 | mesh = trimesh.creation.box(extents=[2, 2, 2])
14 | mesh.export("./tests/mesh_data/cube_2.obj")
15 |
--------------------------------------------------------------------------------
/train.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "eef23eb0-4fc4-4f75-bc46-de36bdceeb1b",
6 | "metadata": {},
7 | "source": [
8 | "# Train the Neural Mesh Simplification model"
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "68548e86-1a2b-4ab7-b89e-5c93314f9345",
14 | "metadata": {},
15 | "source": [
16 | "## Set up the environment"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "id": "13c589dd-aa96-4d93-be4e-6cfdd181f10f",
22 | "metadata": {},
23 | "source": [
24 | "### [*only required for remote runs*] Remote environment setup\n",
25 | "\n",
26 | "If you are running this notebook remotely (e.g. Google Colab), you'll want to set up the environment by\n",
27 | "* Downloading the repository from GitHub\n",
28 | "* Setting up the python environment\n",
29 | "\n",
30 |     "If you are opening this notebook locally, by running `jupyter lab` from the repository root with the right conda environment activated, the step above is not required."
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "id": "179368da-e6fc-4beb-ae82-ca18137de974",
36 | "metadata": {},
37 | "source": [
38 | "#### Step 1. Check out the repo\n",
39 | "That's where the source code for mesh simplification, along with its dependency definitions and other utilities, lives."
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "id": "0bdf37ac-a233-44f9-8f6c-f929566a0bbd",
46 | "metadata": {},
47 | "outputs": [],
48 | "source": [
49 | "!git clone https://github.com/gennarinoos/neural-mesh-simplification.git neural-mesh-simplification"
50 | ]
51 | },
52 | {
53 | "cell_type": "markdown",
54 | "id": "90db897c-f4e0-416e-b92c-392716730e07",
55 | "metadata": {},
56 | "source": [
57 | "#### Step 2. Install python version 3.12 using apt-get"
58 | ]
59 | },
60 | {
61 | "cell_type": "markdown",
62 | "id": "5e1795ff",
63 | "metadata": {},
64 | "source": [
65 | "Check the current python version by running the following command. This notebook requires Python 3.12 to run. Either install it via your Notebook environment settings and jump to Step 6 or follow all the steps below."
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "id": "aaa8c207",
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "!python --version"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "id": "ef6be6e0-41b0-43ca-9fda-508a8cb198be",
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "!sudo apt-get update\n",
86 | "!sudo apt-get install python3.12"
87 | ]
88 | },
89 | {
90 | "cell_type": "markdown",
91 | "id": "71e15c47-fcff-4e5b-8692-ff62fe0cc764",
92 | "metadata": {},
93 | "source": [
94 | "#### Step 3. Update alternatives to use the new Python version"
95 | ]
96 | },
97 | {
98 | "cell_type": "code",
99 | "execution_count": null,
100 | "id": "6c36faa2-8839-4476-b74d-ffbf4f2f086f",
101 | "metadata": {},
102 | "outputs": [],
103 | "source": [
104 | "!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.12 1\n",
105 | "!sudo update-alternatives --config python3"
106 | ]
107 | },
108 | {
109 | "cell_type": "markdown",
110 | "id": "d1c5e4ec-3074-4041-bdee-5246987b4ae0",
111 | "metadata": {},
112 | "source": [
113 | "#### Step 4. Install pip and the required packages for the new Python version."
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": null,
119 | "id": "9d11a192-478c-4797-b9ba-42c41f9e9d9e",
120 | "metadata": {},
121 | "outputs": [],
122 | "source": [
123 | "!rm -f get-pip*.py\n",
124 | "!wget https://bootstrap.pypa.io/get-pip.py\n",
125 | "!python get-pip.py\n",
126 | "!python -m pip install ipykernel\n",
127 | "!python -m ipykernel install --user --name python3.12 --display-name \"Python 3.12\""
128 | ]
129 | },
130 | {
131 | "cell_type": "markdown",
132 | "id": "e1c3d62d-f255-4f2e-830f-09080b41d364",
133 | "metadata": {},
134 | "source": [
135 | "#### Step 5. Restart and verify\n",
136 | "At this point you may need to restart the session, after which you want to verify that `python` is at the right version (`3.12`)"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": null,
142 | "id": "0e29e1e1-5e0d-4dbf-b4cd-c5a26e5092b6",
143 | "metadata": {},
144 | "outputs": [],
145 | "source": [
146 | "!python --version"
147 | ]
148 | },
149 | {
150 | "cell_type": "markdown",
151 | "id": "2dd709a2-ebec-48fb-b5b7-2e75a9fbb129",
152 | "metadata": {},
153 | "source": [
154 | "#### Step 6. Upgrade pip and setuptools"
155 | ]
156 | },
157 | {
158 | "cell_type": "code",
159 | "execution_count": null,
160 | "id": "915962f0-9dff-40d3-a4f3-cbc1cf69fcbe",
161 | "metadata": {},
162 | "outputs": [],
163 | "source": [
164 | "!pip install --upgrade pip setuptools wheel\n",
165 | "!pip install --upgrade build"
166 | ]
167 | },
168 | {
169 | "cell_type": "markdown",
170 | "id": "4f5242f0-c6de-4c1c-9922-eaf4ee161e35",
171 | "metadata": {},
172 | "source": [
173 | "### Set repository as the working directory \n",
174 | "CD into the repository downloaded above"
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": null,
180 | "id": "a70ca79a-83e8-418b-85e4-364ca8c6b9d3",
181 | "metadata": {},
182 | "outputs": [],
183 | "source": [
184 | "%cd neural-mesh-simplification"
185 | ]
186 | },
187 | {
188 | "cell_type": "markdown",
189 | "id": "d6b48728-47c7-4441-8043-258e71f2a8d1",
190 | "metadata": {},
191 | "source": [
192 | "### Package requirements"
193 | ]
194 | },
195 | {
196 | "cell_type": "markdown",
197 | "id": "40ccef76-a2ed-47a5-8c7b-bc7c14ba6770",
198 | "metadata": {},
199 | "source": [
200 | "Depending on whether you are using PyTorch on a CPU or a GPU,\n",
201 | "you'll have to use the correct binaries for PyTorch and the PyTorch Geometric libraries. You can install them via:"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": null,
207 | "id": "2141322d-334d-4957-b371-0660a7f7dbfc",
208 | "metadata": {},
209 | "outputs": [],
210 | "source": [
211 | "!pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121\n",
212 | "!pip install torch_cluster==1.6.3 torch_geometric==2.5.3 torch_scatter==2.1.2 torch_sparse==0.6.18 -f https://data.pyg.org/whl/torch-2.4.0+cu121.html"
213 | ]
214 | },
215 | {
216 | "cell_type": "markdown",
217 | "id": "cc1d234c-1c03-445a-b2f7-d56fd42e4f94",
218 | "metadata": {},
219 | "source": [
220 | "Replace “cu121” with the appropriate CUDA version for your system. If you don't know what is your cuda version, run `nvidia-smi`"
221 | ]
222 | },
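223 |   {
224 |    "cell_type": "markdown",
225 |    "id": "added-nvidia-smi-note",
226 |    "metadata": {},
227 |    "source": [
228 |     "*(Added convenience cell, not in the original notebook.)* The following cell assumes an NVIDIA GPU with `nvidia-smi` on the PATH; it prints the driver and the highest CUDA version it supports, which helps pick the right wheel index above."
229 |    ]
230 |   },
231 |   {
232 |    "cell_type": "code",
233 |    "execution_count": null,
234 |    "id": "added-nvidia-smi-cell",
235 |    "metadata": {},
236 |    "outputs": [],
237 |    "source": [
238 |     "# Added sketch: assumes an NVIDIA driver is installed\n",
239 |     "!nvidia-smi"
240 |    ]
241 |   },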
223 | {
224 | "cell_type": "markdown",
225 | "id": "db20093e-689c-40ab-8f2e-c86337dbc466",
226 | "metadata": {},
227 | "source": [
228 | "Only then you can install the requirements via pip:"
229 | ]
230 | },
231 | {
232 | "cell_type": "code",
233 | "execution_count": null,
234 | "id": "bf9fa547-f157-4679-a6c1-f0304ac59aaf",
235 | "metadata": {},
236 | "outputs": [],
237 | "source": [
238 | "!pip install -r requirements.txt\n",
239 | "!pip uninstall -y neural-mesh-simplification\n",
240 | "!pip install ."
241 | ]
242 | },
243 | {
244 | "cell_type": "markdown",
245 | "id": "76d21495-9bd4-46df-a4c8-a31b573e3d79",
246 | "metadata": {
247 | "jp-MarkdownHeadingCollapsed": true
248 | },
249 | "source": [
250 | "---\n",
251 | "## Download the training data\n",
252 | "We can use the Hugging Face API to download some mesh data to use for training and evaluation."
253 | ]
254 | },
255 | {
256 | "cell_type": "code",
257 | "execution_count": null,
258 | "id": "42801141-20d8-40da-8f59-b7b25338ceef",
259 | "metadata": {},
260 | "outputs": [],
261 | "source": [
262 | "import os\n",
263 | "import shutil\n",
264 | "from huggingface_hub import snapshot_download\n",
265 | "\n",
266 | "target_folder = \"data/raw\"\n",
267 | "wip_folder = os.path.join(target_folder, \"wip\")\n",
268 | "os.makedirs(wip_folder, exist_ok=True)\n",
269 | "\n",
270 | "# abc_train is really large (+5k meshes), so download just a sample\n",
271 | "folder_patterns = [\"abc_extra_noisy/03_meshes/*.ply\", \"abc_train/03_meshes/*.ply\"]\n",
272 | "\n",
273 | "# Download\n",
274 | "snapshot_download(\n",
275 | " repo_id=\"perler/ppsurf\",\n",
276 | " repo_type=\"dataset\",\n",
277 | " cache_dir=wip_folder,\n",
278 | " allow_patterns=folder_patterns[0],\n",
279 | ")\n",
280 | "\n",
281 | "# Move files from wip folder to target folder\n",
282 | "for root, _, files in os.walk(wip_folder):\n",
283 | " for file in files:\n",
284 | " if file.endswith(\".ply\"):\n",
285 | " src_file = os.path.join(root, file)\n",
286 | " dest_file = os.path.join(target_folder, file)\n",
287 | " shutil.copy2(src_file, dest_file)\n",
288 | " os.remove(src_file)\n",
289 | "\n",
290 | "# Remove the wip folder\n",
291 | "shutil.rmtree(wip_folder)"
292 | ]
293 | },
294 | {
295 | "cell_type": "markdown",
296 | "id": "c6df29f3-ee23-47d3-b987-7b3f8e0cfc66",
297 | "metadata": {},
298 | "source": [
299 | "## Prepare the data\n",
300 | "The downloaded data needs to be prepapared for training. We can use a script in the repository we checked out for that."
301 | ]
302 | },
303 | {
304 | "cell_type": "code",
305 | "execution_count": null,
306 | "id": "bebe14c3-7ca1-4a47-80ab-a5d049def2bd",
307 | "metadata": {},
308 | "outputs": [],
309 | "source": [
310 | "!mkdir -p data/processed\n",
311 | "!python scripts/preprocess_data.py"
312 | ]
313 | },
314 | {
315 | "cell_type": "markdown",
316 | "id": "b1fe1228-5a40-4f44-9e58-5607beabd5a5",
317 | "metadata": {},
318 | "source": [
319 | "---\n",
320 | "## Model Training"
321 | ]
322 | },
323 | {
324 | "cell_type": "markdown",
325 | "id": "7b3af882-fe98-48d4-bd34-cf16912cc5d4",
326 | "metadata": {},
327 | "source": [
328 | "When using a GPU, ensure the training is happening on the GPU, and the environment is configured properly."
329 | ]
330 | },
331 | {
332 | "cell_type": "code",
333 | "execution_count": null,
334 | "id": "486cf962-81a4-41b4-a62a-9431d0ed6cf8",
335 | "metadata": {},
336 | "outputs": [],
337 | "source": [
338 | "import torch\n",
339 | "print(torch.cuda.is_available())\n",
340 | "\n",
341 | "!nvcc --version"
342 | ]
343 | },
344 | {
345 | "cell_type": "markdown",
346 | "id": "2fad8fd1-10a8-410c-b826-848c91a227bc",
347 | "metadata": {},
348 | "source": [
349 | "### Start the training"
350 | ]
351 | },
352 | {
353 | "cell_type": "code",
354 | "execution_count": null,
355 | "id": "ef712306-3971-421b-a3ee-d063759c926d",
356 | "metadata": {},
357 | "outputs": [],
358 | "source": [
359 | "!python scripts/train.py --data-path data/processed --config configs/default.yaml"
360 | ]
361 | }
362 | ],
363 | "metadata": {
364 | "kernelspec": {
365 | "display_name": "Python 3 (ipykernel)",
366 | "language": "python",
367 | "name": "python3"
368 | },
369 | "language_info": {
370 | "codemirror_mode": {
371 | "name": "ipython",
372 | "version": 3
373 | },
374 | "file_extension": ".py",
375 | "mimetype": "text/x-python",
376 | "name": "python",
377 | "nbconvert_exporter": "python",
378 | "pygments_lexer": "ipython3",
379 | "version": "3.12.8"
380 | }
381 | },
382 | "nbformat": 4,
383 | "nbformat_minor": 5
384 | }
385 |
--------------------------------------------------------------------------------
/uv.lock:
--------------------------------------------------------------------------------
1 | version = 1
2 | requires-python = ">=3.12"
3 |
4 | [[package]]
5 | name = "neural-mesh-simplification"
6 | version = "0.1.0"
7 | source = { virtual = "." }
--------------------------------------------------------------------------------