├── .gitignore ├── .gitmodules ├── README.md ├── assets ├── arch2.glb ├── ice_cream.glb ├── jacket.glb ├── potion.glb ├── samesh_examples.png ├── samesh_modalities.png ├── samesh_pipeline.png └── squidward_house.glb ├── configs ├── mesh_segmentation.yaml ├── mesh_segmentation_coseg.yaml ├── mesh_segmentation_princeton.yaml ├── mesh_segmentation_princeton_dynamic.yaml ├── mesh_segmentation_shape_diameter_function.yaml ├── mesh_segmentation_shape_diameter_function_coseg.yaml └── mesh_segmentation_shape_diameter_function_princeton.yaml ├── notebooks ├── mesh_samesh.ipynb ├── mesh_segmentation_annotations.ipynb ├── mesh_segmentation_recolormap.ipynb └── mesh_shape_diameter_function.ipynb ├── pyproject.toml ├── scripts ├── convert_gif2mp4.py ├── convert_mesh2gif.py ├── convert_mesh_formats.py └── generate_dataset_mesh_segmentation_comparison.py └── src └── samesh ├── __init__.py ├── data ├── __init__.py ├── common.py └── loaders.py ├── metrics ├── __init__.py ├── mesh_segmentation.py ├── mesh_segmentation_cut_discrepancy.py └── mesh_segmentation_test.py ├── models ├── __init__.py ├── sam.py ├── sam_mesh.py └── shape_diameter_function.py ├── renderer ├── __init__.py ├── renderer.py ├── renderer_animations.py ├── shader_programs.py └── shaders │ ├── barycentric.frag │ ├── barycentric.vert │ ├── faceid.frag │ ├── faceid.vert │ ├── normal.frag │ └── normal.vert └── utils ├── __init__.py ├── cameras.py ├── math.py ├── mesh.py └── polyhedra.py /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/ 2 | outputs/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 
98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 113 | .pdm.toml 114 | .pdm-python 115 | .pdm-build/ 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 165 | #.idea/ -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/segment-anything-2"] 2 | path = third_party/segment-anything-2 3 | url = https://github.com/facebookresearch/segment-anything-2.git -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Segment Any Mesh 2 | 3 | [Segment Any Mesh](https://arxiv.org/abs/2408.13679) (SAMesh) is a novel zero-shot method for mesh part segmentation that addresses the limitations of traditional shape analysis (e.g. Shape Diameter Function (ShapeDiam)) and learning-based approaches. It operates in two phases: multimodal rendering, where multiview renders of a mesh are processed through Segment Anything 2 (SAM2) to generate 2D masks, and 2D-to-3D lifting, where these masks are combined to produce a detailed 3D segmentation. Compared to other 2D-to-3D lifting methods, SAMesh does not require an input vocabulary, which limits those methods to semantic segmentation as opposed to part segmentation. SAMesh demonstrates good performance on traditional benchmarks and superior generalization on a newly curated dataset of diverse meshes, which we release below. 
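In code, the two phases can be sketched roughly as follows (illustrative pseudocode only — `render`, `sam2_generate_masks`, `lift_masks_to_faces`, and `smooth_and_merge` are hypothetical names standing in for the real implementation in `src/samesh/models/sam_mesh.py`):

```python
# Illustrative sketch of the two SAMesh phases; not the package's actual API.
def samesh(mesh, camera_poses, modes=('norms', 'sdf')):
    # Phase 1: multimodal rendering -- render the mesh from many viewpoints in
    # several modalities, then run SAM2 on each render to get 2D masks. A
    # face-id buffer rendered alongside each view maps pixels back to faces
    # (see src/samesh/renderer/shaders/faceid.frag).
    view_masks = []
    for pose in camera_poses:
        for mode in modes:
            image, faceid = render(mesh, pose, mode)
            view_masks.append((sam2_generate_masks(image), faceid))
    # Phase 2: 2D-to-3D lifting -- vote each 2D mask onto mesh faces through
    # the face-id buffers, then merge and smooth the per-view labelings into
    # a single consistent 3D part segmentation.
    face_labels = lift_masks_to_faces(view_masks, num_faces=len(mesh.faces))
    return smooth_and_merge(face_labels, mesh)
```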
4 | 
5 | Examples of running SAMesh on our curated dataset: 
6 | ![Examples of running SAMesh on our curated dataset](./assets/samesh_examples.png) 
7 | 
8 | Pipeline of SAMesh: 
9 | ![Pipeline of SAMesh](./assets/samesh_pipeline.png) 
10 | 
11 | SAMesh handles untextured meshes; it does so by rendering different modalities before applying Segment Anything (the `mode` parameter in the config). 
12 | ![Rendered Modalities](./assets/samesh_modalities.png) 
13 | 
14 | 
15 | ## Installation 
16 | 
17 | To install SAMesh, use the following command: 
18 | 
19 | ```bash 
20 | pip install -e . 
21 | ``` 
22 | 
23 | Don't forget to initialize the submodules and run `pip install -e .` in each of them as well. We tested SAMesh on Python 3.12 and CUDA 11.8. If you encounter issues building SAM2, try the `--no-build-isolation` flag. If you hit pyrender issues related to ctypes, try installing `PyOpenGL==3.1.7`. 
24 | 
25 | 
26 | ## Getting Started 
27 | 
28 | Download a SAM2 checkpoint as provided in the SAM2 repo. `notebooks/mesh_samesh.ipynb` and `notebooks/mesh_shape_diameter_function.ipynb` detail how to set up and run SAMesh and ShapeDiam, respectively. Some mesh examples from the curated dataset are provided in `assets`. 
29 | 
30 | 
31 | ## Dataset 
32 | 
33 | [Download link](https://drive.google.com/file/d/1qzxZZ-RUShNgUKXBPnpI1-Mlr8MkWekN/view?usp=sharing) 
34 | 
35 | 
36 | ## Parameter Tuning 
37 | `configs/` contains the settings used for our dataset, CoSeg, and the Princeton Mesh Segmentation Benchmark, for both Segment Any Mesh and Shape Diameter Function. Other datasets may need different parameters/settings. For example, PartNet works best with mode `matte`, since many of its meshes are low poly, resulting in subpar normal and shape diameter function scalar renderings. In addition, for meshes where some faces are large (e.g. PartNet), you should add a parameter `connections_threshold: 0` under `sam_mesh` in the config; it controls the minimum number of faces that must be covered by two regions for them to be considered mergeable (see the example snippet at the end of this README). Finally, you can disable the cache directory by commenting out the cache entry in the config, as the cache takes up disk space. 
38 | 
39 | 
40 | ## Contributors 
41 | George Tang*, William Zhao, Logan Ford, David Benhaim, Paul Zhang 
42 | 
43 | *Work done during an internship at Backflip AI.
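As a concrete example of the overrides described under Parameter Tuning, a config can be loaded and modified programmatically with OmegaConf (a sketch; whether you need `connections_threshold` depends on your meshes, as discussed above):

```python
from omegaconf import OmegaConf

# Load one of the provided configs and apply the PartNet-style overrides.
cfg = OmegaConf.load('configs/mesh_segmentation.yaml')
cfg.sam_mesh.connections_threshold = 0  # for meshes with large faces, e.g. PartNet
cfg.pop('cache', None)                  # optional: drop the cache entry to avoid writing a cache dir
```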
44 | -------------------------------------------------------------------------------- /assets/arch2.glb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtangg12/samesh/9417875e1ad43234a0b66d438dd5dae2b2a1edbe/assets/arch2.glb -------------------------------------------------------------------------------- /assets/ice_cream.glb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtangg12/samesh/9417875e1ad43234a0b66d438dd5dae2b2a1edbe/assets/ice_cream.glb -------------------------------------------------------------------------------- /assets/jacket.glb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtangg12/samesh/9417875e1ad43234a0b66d438dd5dae2b2a1edbe/assets/jacket.glb -------------------------------------------------------------------------------- /assets/potion.glb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtangg12/samesh/9417875e1ad43234a0b66d438dd5dae2b2a1edbe/assets/potion.glb -------------------------------------------------------------------------------- /assets/samesh_examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtangg12/samesh/9417875e1ad43234a0b66d438dd5dae2b2a1edbe/assets/samesh_examples.png -------------------------------------------------------------------------------- /assets/samesh_modalities.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtangg12/samesh/9417875e1ad43234a0b66d438dd5dae2b2a1edbe/assets/samesh_modalities.png -------------------------------------------------------------------------------- /assets/samesh_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtangg12/samesh/9417875e1ad43234a0b66d438dd5dae2b2a1edbe/assets/samesh_pipeline.png -------------------------------------------------------------------------------- /assets/squidward_house.glb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtangg12/samesh/9417875e1ad43234a0b66d438dd5dae2b2a1edbe/assets/squidward_house.glb -------------------------------------------------------------------------------- /configs/mesh_segmentation.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | cache: /home/gtangg12/samesh/outputs/mesh_segmentation_cache 3 | cache_overwrite: False 4 | output: /home/gtangg12/samesh/outputs/mesh_segmentation_output 5 | 6 | sam: 7 | sam: 8 | checkpoint: /home/gtangg12/samesh/checkpoints/sam2_hiera_large.pt 9 | model_config: sam2_hiera_l.yaml 10 | auto: True 11 | ground: False 12 | engine_config: 13 | points_per_side: 32 14 | crop_n_layers: 0 15 | #min_mask_region_area: 1024 # sam2 breaks since it uses C connected components 16 | pred_iou_thresh: 0.5 17 | stability_score_thresh: 0.7 18 | stability_score_offset: 1.0 19 | 20 | sam_mesh: 21 | use_modes: ['sdf', 'norms'] 22 | min_area: 1024 23 | connections_bin_resolution: 100 24 | connections_bin_threshold_percentage: 0.125 25 | smoothing_threshold_percentage_size: 0.025 26 | smoothing_threshold_percentage_area: 0.025 27 | smoothing_iterations: 64 28 | repartition_cost: 1 29 | repartition_lambda: 6 30 | repartition_iterations: 1 31 | 
32 | renderer: 33 | target_dim: [1024, 1024] 34 | camera_generation_method: icosahedron #octohedron 35 | renderer_args: 36 | interpolate_norms: True 37 | sampling_args: {radius: 2} 38 | lighting_args: {} -------------------------------------------------------------------------------- /configs/mesh_segmentation_coseg.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | cache: /home/gtangg12/samesh/outputs/mesh_segmentation_cache_coseg 3 | cache_overwrite: False 4 | output: /home/gtangg12/samesh/outputs/mesh_segmentation_output_coseg 5 | 6 | sam: 7 | sam: 8 | checkpoint: /home/gtangg12/samesh/checkpoints/sam2_hiera_large.pt 9 | model_config: sam2_hiera_l.yaml 10 | auto: True 11 | ground: False 12 | engine_config: 13 | points_per_side: 32 14 | crop_n_layers: 0 15 | #min_mask_region_area: 1024 # sam2 breaks since it uses C connected components 16 | pred_iou_thresh: 0.5 17 | stability_score_thresh: 0.7 18 | stability_score_offset: 1.0 19 | 20 | sam_mesh: 21 | use_modes: ['sdf', 'norms'] 22 | min_area: 1024 23 | connections_bin_resolution: 100 24 | connections_bin_threshold_percentage: 0.05 25 | smoothing_threshold_percentage_size: 0.025 26 | smoothing_threshold_percentage_area: 0.025 27 | smoothing_iterations: 64 28 | repartition_cost: 1 29 | repartition_lambda: 5 30 | repartition_iterations: 1 31 | 32 | renderer: 33 | target_dim: [1024, 1024] 34 | camera_generation_method: icosahedron #octohedron 35 | renderer_args: 36 | interpolate_norms: True 37 | sampling_args: {radius: 2} 38 | lighting_args: {} -------------------------------------------------------------------------------- /configs/mesh_segmentation_princeton.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | cache: /home/gtangg12/samesh/outputs/mesh_segmentation_cache_princeton 3 | cache_overwrite: False 4 | output: /home/gtangg12/samesh/outputs/mesh_segmentation_output_princeton 5 | 6 | sam: 7 | sam: 8 | checkpoint: /home/gtangg12/samesh/checkpoints/sam2_hiera_large.pt 9 | model_config: sam2_hiera_l.yaml 10 | auto: True 11 | ground: False 12 | engine_config: 13 | points_per_side: 32 14 | crop_n_layers: 0 15 | #min_mask_region_area: 1024 # sam2 breaks since it uses C connected components 16 | pred_iou_thresh: 0.5 17 | stability_score_thresh: 0.7 18 | stability_score_offset: 1.0 19 | 20 | sam_mesh: 21 | use_modes: ['sdf', 'norms'] 22 | min_area: 1024 23 | connections_bin_resolution: 100 24 | connections_bin_threshold_percentage: 0.125 25 | smoothing_threshold_percentage_size: 0.025 26 | smoothing_threshold_percentage_area: 0.025 27 | smoothing_iterations: 64 28 | repartition_cost: 1 29 | repartition_lambda: 6 30 | repartition_iterations: 1 31 | 32 | renderer: 33 | target_dim: [1024, 1024] 34 | camera_generation_method: icosahedron #octohedron 35 | renderer_args: 36 | interpolate_norms: True 37 | sampling_args: {radius: 2} 38 | lighting_args: {} -------------------------------------------------------------------------------- /configs/mesh_segmentation_princeton_dynamic.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | cache: /home/gtangg12/samesh/outputs/mesh_segmentation_cache_princeton_dynamic 3 | cache_overwrite: False 4 | output: /home/gtangg12/samesh/outputs/mesh_segmentation_output_princeton_dynamic 5 | 6 | sam: 7 | sam: 8 | checkpoint: /home/gtangg12/samesh/checkpoints/sam2_hiera_large.pt 9 | model_config: sam2_hiera_l.yaml 10 | auto: True 11 | ground: False 12 | engine_config: 13 | 
points_per_side: 32
14 | crop_n_layers: 0
15 | #min_mask_region_area: 1024 # sam2 breaks since it uses C connected components
16 | pred_iou_thresh: 0.5
17 | stability_score_thresh: 0.7
18 | stability_score_offset: 1.0
19 | 
20 | sam_mesh:
21 | use_modes: ['sdf', 'norms']
22 | min_area: 1024
23 | connections_bin_resolution: 100
24 | connections_bin_threshold_percentage: 0.125
25 | smoothing_threshold_percentage_size: 0.025
26 | smoothing_threshold_percentage_area: 0.025
27 | smoothing_iterations: 64
28 | repartition_cost: 1
29 | repartition_lambda: 6
30 | repartition_lambda_tolerance: 1
31 | repartition_lambda_lb: 1
32 | repartition_lambda_ub: 15
33 | repartition_iterations: 1
34 | 
35 | renderer:
36 | target_dim: [1024, 1024]
37 | camera_generation_method: icosahedron #octohedron
38 | renderer_args:
39 | interpolate_norms: True
40 | sampling_args: {radius: 2}
41 | lighting_args: {}
--------------------------------------------------------------------------------
/configs/mesh_segmentation_shape_diameter_function.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | output: /home/gtangg12/samesh/outputs/mesh_segmentation_output_shape_diameter_function
3 | 
4 | num_components: 5
5 | repartition_lambda: 15 #6
6 | repartition_iterations: 1
--------------------------------------------------------------------------------
/configs/mesh_segmentation_shape_diameter_function_coseg.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | output: /home/gtangg12/samesh/outputs/mesh_segmentation_output_coseg_shape_diameter_function
3 | 
4 | num_components: 3
5 | repartition_lambda: 15 #6
6 | repartition_iterations: 1
--------------------------------------------------------------------------------
/configs/mesh_segmentation_shape_diameter_function_princeton.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | output: /home/gtangg12/samesh/outputs/mesh_segmentation_output_princeton_shape_diameter_function
3 | 
4 | num_components: 5
5 | repartition_lambda: 15 #6
6 | repartition_iterations: 1
--------------------------------------------------------------------------------
/notebooks/mesh_segmentation_annotations.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "%load_ext autoreload\n",
10 | "%autoreload 2"
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 2,
16 | "metadata": {},
17 | "outputs": [],
18 | "source": [
19 | "import json\n",
20 | "import glob\n",
21 | "import numpy as np\n",
22 | "from pathlib import Path\n",
23 | "from collections import defaultdict"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 59,
29 | "metadata": {},
30 | "outputs": [],
31 | "source": [
32 | "def class2rank(name: str) -> list[int]:\n",
33 | " \"\"\"Convert a ranking label like 'ab' into per-method ranks, e.g. [1, 2].\n",
34 | " \"\"\"\n",
35 | " return [ord(c) - ord('a') + 1 for c in name]\n",
36 | "\n",
37 | "\n",
38 | "def compute_rank_stats(path_annotations: Path | str, path_metadata: Path | str, key: str) -> dict:\n",
39 | " \"\"\"Average the annotator rankings per method across all models.\n",
40 | " \"\"\"\n",
41 | " metadata = json.load(open(path_metadata))\n",
42 | " annotations = {}\n",
43 | " for filename in glob.glob(str(path_annotations / 'annotations/worker-response/iteration-1/*/*.json')):\n",
44 | " modelid = Path(filename).parent.stem\n",
45 | " annotations[int(modelid)] = 
json.load(open(filename))['answers']\n", 46 | " assert len(metadata) == len(annotations)\n", 47 | "\n", 48 | " method2ranks = {}\n", 49 | " for i in range(len(metadata)):\n", 50 | " ranks = []\n", 51 | " for ans in annotations[i]:\n", 52 | " ranks.append(class2rank(ans['answerContent']['crowd-classifier']['label']))\n", 53 | " ranks_average = []\n", 54 | " for j in range(len(ranks[0])):\n", 55 | " ranks_average.append(sum([r[j] for r in ranks]) / len(ranks))\n", 56 | " for method, index in metadata[i].items():\n", 57 | " method2ranks.setdefault(method, []).append(ranks_average[index])\n", 58 | " method2rank_stats = {\n", 59 | " method: {\n", 60 | " 'avg': sum(ranks) / len(ranks),\n", 61 | " 'std': np.std(ranks),\n", 62 | " } \n", 63 | " for method, ranks in method2ranks.items()\n", 64 | " }\n", 65 | " return method2rank_stats" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 60, 71 | "metadata": {}, 72 | "outputs": [ 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "{'combined': {'avg': 1.176, 'std': 0.2712145522152772}, 'shape_diameter_function': {'avg': 1.824, 'std': 0.2712145522152772}}\n" 78 | ] 79 | } 80 | ], 81 | "source": [ 82 | "print(compute_rank_stats(\n", 83 | " Path('/home/ubuntu/meshseg/tests/mesh-segmentation-samesh-v-sdf-annotations'), \n", 84 | " Path('/home/ubuntu/meshseg/tests/mesh_segmentation-samesh-v-sdf/metadata.json'),\n", 85 | " key='mesh-segmentation-high-quality-metadata'\n", 86 | "))" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 61, 92 | "metadata": {}, 93 | "outputs": [ 94 | { 95 | "name": "stdout", 96 | "output_type": "stream", 97 | "text": [ 98 | "{'combined': {'avg': 2.3253333333333335, 'std': 0.5912908665698202}, 'matte': {'avg': 2.5866666666666664, 'std': 0.6946142014736589}, 'norm': {'avg': 2.5653333333333332, 'std': 0.577233247675688}, 'sdf': {'avg': 2.5226666666666664, 'std': 0.7503907130436931}}\n" 99 | ] 100 | } 101 | ], 102 | "source": [ 103 | "print(compute_rank_stats(\n", 104 | " Path('/home/ubuntu/meshseg/tests/mesh-segmentation-samesh-modalities-annotations'),\n", 105 | " Path('/home/ubuntu/meshseg/tests/mesh_segmentation-samesh-modalities/metadata.json'),\n", 106 | " key='mesh-segmentation-ablation-high-quality-metadata'\n", 107 | "))" 108 | ] 109 | } 110 | ], 111 | "metadata": { 112 | "kernelspec": { 113 | "display_name": "meshseg", 114 | "language": "python", 115 | "name": "python3" 116 | }, 117 | "language_info": { 118 | "codemirror_mode": { 119 | "name": "ipython", 120 | "version": 3 121 | }, 122 | "file_extension": ".py", 123 | "mimetype": "text/x-python", 124 | "name": "python", 125 | "nbconvert_exporter": "python", 126 | "pygments_lexer": "ipython3", 127 | "version": "3.12.4" 128 | } 129 | }, 130 | "nbformat": 4, 131 | "nbformat_minor": 2 132 | } 133 | -------------------------------------------------------------------------------- /notebooks/mesh_segmentation_recolormap.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import glob\n", 20 | "import json\n", 21 | "import numpy as np\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "from tqdm import tqdm\n", 24 | "from pathlib import Path\n", 25 | "from 
samesh.data.loaders import read_mesh\n",
26 | "from samesh.utils.mesh import duplicate_verts"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "execution_count": 7,
32 | "metadata": {},
33 | "outputs": [],
34 | "source": [
35 | "MESH_DIRS = [\n",
36 | " '/home/ubuntu/meshseg/tests/mesh_segmentation_output-dynamic-0.125-6-0.5-sdf',\n",
37 | " '/home/ubuntu/meshseg/tests/mesh_segmentation_output-dynamic-0.125-6-0.5-norm',\n",
38 | " '/home/ubuntu/meshseg/tests/mesh_segmentation_output-dynamic-0.125-6-0.5-combined',\n",
39 | " '/home/ubuntu/meshseg/tests/mesh_segmentation_output-dynamic-0.125-6-0.5-matte',\n",
40 | " '/home/ubuntu/meshseg/tests/mesh_segmentation_output_shape_diameter_function-5-15',\n",
41 | "]"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 5,
47 | "metadata": {},
48 | "outputs": [
49 | {
50 | "data": {
51 | "image/png": "...[base64-encoded PNG output of the color-grid plot omitted]...",
52 | "text/plain": [
53 | "<Figure size 1000x1000 with 1 Axes>
" 54 | ] 55 | }, 56 | "metadata": {}, 57 | "output_type": "display_data" 58 | } 59 | ], 60 | "source": [ 61 | "import random\n", 62 | "\n", 63 | "grid_size = 10\n", 64 | "\n", 65 | "COLORS = np.random.rand(grid_size ** 2, 3)\n", 66 | "\n", 67 | "fig, ax = plt.subplots(figsize=(grid_size, grid_size))\n", 68 | "\n", 69 | "for i in range(grid_size):\n", 70 | " for j in range(grid_size):\n", 71 | " idx = i * grid_size + j\n", 72 | " rect = plt.Rectangle((j, grid_size - i - 1), 1, 1, color=COLORS[idx])\n", 73 | " ax.add_patch(rect)\n", 74 | "\n", 75 | "ax.set_xlim(0, grid_size)\n", 76 | "ax.set_ylim(0, grid_size)\n", 77 | "ax.set_aspect('equal')\n", 78 | "ax.axis('off')\n", 79 | "plt.show()" 80 | ] 81 | }, 82 | { 83 | "cell_type": "code", 84 | "execution_count": 8, 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "name": "stdout", 89 | "output_type": "stream", 90 | "text": [ 91 | "Recoloring segmented meshes in /home/ubuntu/meshseg/tests/mesh_segmentation_output-dynamic-0.125-6-0.5-sdf\n" 92 | ] 93 | }, 94 | { 95 | "name": "stderr", 96 | "output_type": "stream", 97 | "text": [ 98 | "0it [00:00, ?it/s]" 99 | ] 100 | }, 101 | { 102 | "name": "stderr", 103 | "output_type": "stream", 104 | "text": [ 105 | "75it [00:08, 8.96it/s]\n" 106 | ] 107 | }, 108 | { 109 | "name": "stdout", 110 | "output_type": "stream", 111 | "text": [ 112 | "Recoloring segmented meshes in /home/ubuntu/meshseg/tests/mesh_segmentation_output-dynamic-0.125-6-0.5-norm\n" 113 | ] 114 | }, 115 | { 116 | "name": "stderr", 117 | "output_type": "stream", 118 | "text": [ 119 | "75it [00:08, 8.94it/s]\n" 120 | ] 121 | }, 122 | { 123 | "name": "stdout", 124 | "output_type": "stream", 125 | "text": [ 126 | "Recoloring segmented meshes in /home/ubuntu/meshseg/tests/mesh_segmentation_output-dynamic-0.125-6-0.5-combined\n" 127 | ] 128 | }, 129 | { 130 | "name": "stderr", 131 | "output_type": "stream", 132 | "text": [ 133 | "75it [00:08, 8.97it/s]\n" 134 | ] 135 | }, 136 | { 137 | "name": "stdout", 138 | "output_type": "stream", 139 | "text": [ 140 | "Recoloring segmented meshes in /home/ubuntu/meshseg/tests/mesh_segmentation_output-dynamic-0.125-6-0.5-matte\n" 141 | ] 142 | }, 143 | { 144 | "name": "stderr", 145 | "output_type": "stream", 146 | "text": [ 147 | "75it [00:08, 8.97it/s]\n" 148 | ] 149 | }, 150 | { 151 | "name": "stdout", 152 | "output_type": "stream", 153 | "text": [ 154 | "Recoloring segmented meshes in /home/ubuntu/meshseg/tests/mesh_segmentation_output_shape_diameter_function-5-15\n" 155 | ] 156 | }, 157 | { 158 | "name": "stderr", 159 | "output_type": "stream", 160 | "text": [ 161 | "75it [00:08, 8.69it/s]\n" 162 | ] 163 | } 164 | ], 165 | "source": [ 166 | "for dir in MESH_DIRS:\n", 167 | " print('Recoloring segmented meshes in ', dir)\n", 168 | " labels = glob.glob(f'{dir}/*/*.json')\n", 169 | " meshes = glob.glob(f'{dir}/*/*segmented.glb') # only recolor original meshes\n", 170 | "\n", 171 | " for filename_face2label, filename_mesh in tqdm(zip(labels, meshes)):\n", 172 | " with open(filename_face2label, 'r') as f:\n", 173 | " face2label = json.load(f)\n", 174 | " face2label = {int(k): v for k, v in face2label.items()}\n", 175 | " face2label = sorted(face2label.items())\n", 176 | " face2label = np.array([v for (_, v) in face2label])\n", 177 | " face2label_renumbered = np.zeros_like(face2label)\n", 178 | " for i, v in enumerate(np.unique(face2label)):\n", 179 | " face2label_renumbered[face2label == v] = i\n", 180 | "\n", 181 | " mesh = read_mesh(filename_mesh)\n", 182 | " mesh = duplicate_verts(mesh) # avoid face color 
interpolation due to OpenGL storing data in vertices\n",
183 | " mesh.visual.face_colors = COLORS[face2label_renumbered[:mesh.faces.shape[0]]]\n",
184 | " mesh.export(filename_mesh.replace('.glb', '_recolored.glb'))"
185 | ]
186 | }
187 | ],
188 | "metadata": {
189 | "kernelspec": {
190 | "display_name": "meshseg",
191 | "language": "python",
192 | "name": "python3"
193 | },
194 | "language_info": {
195 | "codemirror_mode": {
196 | "name": "ipython",
197 | "version": 3
198 | },
199 | "file_extension": ".py",
200 | "mimetype": "text/x-python",
201 | "name": "python",
202 | "nbconvert_exporter": "python",
203 | "pygments_lexer": "ipython3",
204 | "version": "3.12.4"
205 | }
206 | },
207 | "nbformat": 4,
208 | "nbformat_minor": 2
209 | }
210 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 | 
5 | [project]
6 | name = "samesh"
7 | version = "0.0.1"
8 | description = "Segment Any Mesh"
9 | readme = "README.md"
10 | requires-python = ">=3.12.0"
11 | classifiers = [
12 | "Programming Language :: Python",
13 | ]
14 | dependencies = [
15 | "pandas",
16 | "omegaconf",
17 | "igraph",
18 | "networkx",
19 | "pyrender",
20 | "pymeshlab",
21 | "trimesh",
22 | "lightning",
23 | "tqdm",
24 | "scikit-learn",
25 | "natsort",
26 | "numpy==1.26.4",
27 | "torch==2.3.1",
28 | "torchvision==0.18.1",
29 | "torchtyping",
30 | "matplotlib",
31 | "opencv-python",
32 | "transformers",
33 | ]
34 | 
35 | # Install PyOpenGL==3.1.7 if pyrender errors
36 | 
37 | [project.optional-dependencies]
38 | 
39 | # Development dependencies
40 | dev = []
41 | 
42 | # Install SAM2 with --no-build-isolation flag if build errors
43 | 
44 | [project.scripts]
45 | 
46 | [tool.setuptools.packages.find]
47 | where = ["src"]
48 | include = ["samesh*"]
--------------------------------------------------------------------------------
/scripts/convert_gif2mp4.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 | from pathlib import Path
5 | import multiprocessing as mp
6 | 
7 | import numpy as np
8 | import imageio
9 | from PIL import Image
10 | from tqdm import tqdm
11 | 
12 | 
13 | def convert_gif2mp4(ifilename: Path | str, ofilename: Path | str, duration=8):
14 | """Convert a single gif into an mp4 spanning roughly `duration` seconds.
15 | """
16 | print(f'Converting {ifilename} to {ofilename}')
17 | 
18 | with Image.open(ifilename) as gif:
19 | frames = []
20 | try:
21 | while True:
22 | frame = gif.convert('RGB')
23 | frames.append(np.array(frame))
24 | gif.seek(gif.tell() + 1)
25 | except EOFError:
26 | pass
27 | 
28 | imageio.mimsave(ofilename, frames, format='ffmpeg', fps=len(frames) // duration)
29 | 
30 | 
31 | def convert_filenames(idir: Path | str, odir: Path | str):
32 | """Convert every gif in idir into an mp4 in odir, in parallel.
33 | """
34 | os.makedirs(odir, exist_ok=True)
35 | 
36 | chunks = []
37 | filenames = glob.glob(f'{idir}/*.gif')
38 | for filename in tqdm(filenames):
39 | ofilename = os.path.join(odir, os.path.basename(filename).replace('.gif', '.mp4'))
40 | chunks.append((filename, ofilename))
41 | 
42 | with mp.Pool(mp.cpu_count()) as pool:
43 | pool.starmap(convert_gif2mp4, chunks)
44 | 
45 | 
46 | if __name__ == '__main__':
47 | parser = argparse.ArgumentParser(
48 | description='Convert gif to mp4 format'
49 | )
50 | parser.add_argument(
51 | '-id', '--idir', type=str, help='Path to the directory containing the gifs'
52 | )
53 | 
parser.add_argument(
54 | '-od', '--odir', type=str, help='Path to the output directory for the converted mp4s'
55 | )
56 | args = parser.parse_args()
57 | convert_filenames(args.idir, args.odir)
--------------------------------------------------------------------------------
/scripts/convert_mesh2gif.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 | import multiprocessing as mp
5 | from pathlib import Path
6 | from tqdm import tqdm
7 | 
8 | from samesh.data.loaders import read_mesh
9 | from samesh.renderer.renderer_animations import mesh2gif
10 | 
11 | 
12 | def convert_mesh2gif_worker(filename: Path | str, args: argparse.Namespace):
13 | """Render a single segmented mesh to an animated gif.
14 | """
15 | filename_out = Path(args.odir) / Path(filename).name.replace(f'.{args.load_extension}', f'.gif')
16 | source = read_mesh(filename, process=False)
17 | mesh2gif(source, filename_out, fps=args.fps, length=args.length, key=args.key, blend=0.5, background=0)
18 | 
19 | 
20 | def convert_mesh2gif(args: argparse.Namespace):
21 | """Render all recolored segmented meshes under idir to gifs in odir, in parallel.
22 | """
23 | os.makedirs(args.odir, exist_ok=True)
24 | 
25 | filenames = glob.glob(f'{args.idir}/*/*_segmented_recolored.{args.load_extension}')
26 | chunks = [
27 | (filename, args) for filename in filenames
28 | ]
29 | with mp.Pool(mp.cpu_count()) as pool:
30 | pool.starmap(convert_mesh2gif_worker, chunks)
31 | print(f'Converted {len(filenames)} meshes to gifs')
32 | 
33 | 
34 | if __name__ == '__main__':
35 | parser = argparse.ArgumentParser(
36 | description='Convert meshes to rendered gifs'
37 | )
38 | parser.add_argument(
39 | '-id', '--idir', type=str, help='Path to the directory containing the meshes'
40 | )
41 | parser.add_argument(
42 | '-od', '--odir', type=str, help='Path to the output directory for the rendered gifs'
43 | )
44 | parser.add_argument(
45 | '-le', '--load_extension', type=str, default='glb', help='Extension of the meshes to load'
46 | )
47 | parser.add_argument(
48 | '--fps', type=int, default=30, help='Frames per second'
49 | )
50 | parser.add_argument(
51 | '--length', type=int, default=120, help='Number of frames'
52 | )
53 | parser.add_argument(
54 | '--key', type=str, default='face_colors', help='Key to render'
55 | )
56 | args = parser.parse_args()
57 | 
58 | convert_mesh2gif(args)
59 | 
60 | '''
61 | python -m scripts.convert_mesh2gif -id /home/ubuntu/meshseg/tests/mesh_segmentation_output_shape_diameter_function-5-15 -od /home/ubuntu/meshseg/tests/mesh_segmentation_output_shape_diameter_function-5-15-gifs
62 | python -m scripts.convert_mesh2gif -id /home/ubuntu/meshseg/tests/mesh_segmentation_output-dynamic-0.125-6-0.5-sdf/ -od /home/ubuntu/meshseg/tests/mesh_segmentation_output-dynamic-0.125-6-0.5-sdf-gifs
63 | python -m scripts.convert_mesh2gif -id /home/ubuntu/meshseg/tests/mesh_segmentation_output-dynamic-0.125-6-0.5-norm/ -od /home/ubuntu/meshseg/tests/mesh_segmentation_output-dynamic-0.125-6-0.5-norm-gifs
64 | python -m scripts.convert_mesh2gif -id /home/ubuntu/meshseg/tests/mesh_segmentation_output-dynamic-0.125-6-0.5-combined/ -od /home/ubuntu/meshseg/tests/mesh_segmentation_output-dynamic-0.125-6-0.5-combined-gifs
65 | python -m scripts.convert_mesh2gif -id /home/ubuntu/meshseg/tests/mesh_segmentation_output-dynamic-0.125-6-0.5-matte/ -od /home/ubuntu/meshseg/tests/mesh_segmentation_output-dynamic-0.125-6-0.5-matte-gifs
66 | '''
--------------------------------------------------------------------------------
/scripts/convert_mesh_formats.py: 
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 | from pathlib import Path
5 | from tqdm import tqdm
6 | 
7 | import numpy as np
8 | import trimesh
9 | 
10 | 
11 | '''
12 | NOTE Common mesh format:
13 | - glb format
14 | - poses embedded in mesh graph
15 | - textures, if they exist, are expressed as uv coordinates
16 | '''
17 | 
18 | def convert_backflip(filename: str, filename_out: str):
19 | """Export a Backflip mesh/scene to the common format, normalizing where the pose is stored.
20 | """
21 | source = trimesh.load(filename)
22 | if isinstance(source, trimesh.Scene):
23 | data = list(source.graph.transforms.edge_data.values())[0]
24 | name = data['geometry']
25 | pose = data['matrix']
26 | data['matrix'] = np.eye(4)
27 | source.graph[name] = pose
28 | source.export(filename_out)
29 | 
30 | 
31 | def convert_meshseg_benchmark(filename: str, filename_out: str):
32 | """Export a Princeton benchmark mesh to the common format.
33 | """
34 | trimesh.load(filename).export(filename_out)
35 | 
36 | 
37 | CONVERTERS = {
38 | 'backflip' : convert_backflip,
39 | 'meshseg_benchmark': convert_meshseg_benchmark
40 | }
41 | 
42 | 
43 | def convert_formats(idir: Path, odir: Path, load_extension: str, save_extension='glb', origin='backflip'):
44 | """Convert all meshes in idir with load_extension into save_extension files in odir.
45 | """
46 | os.makedirs(odir, exist_ok=True)
47 | 
48 | filenames = glob.glob(str(idir) + f'/*.{load_extension}')
49 | for filename in tqdm(filenames):
50 | filename_out = odir / Path(filename).name.replace(f'.{load_extension}', f'.{save_extension}')
51 | if origin not in CONVERTERS:
52 | raise ValueError(f'Conversion from data source {origin} not supported')
53 | CONVERTERS[origin](filename, filename_out)
54 | 
55 | print(f'Converted {len(filenames)} meshes from {origin} to {save_extension}')
56 | 
57 | 
58 | if __name__ == '__main__':
59 | parser = argparse.ArgumentParser(
60 | description='Convert mesh formats to common format'
61 | )
62 | parser.add_argument(
63 | '-id', '--idir', type=str, help='Path to the directory containing the meshes'
64 | )
65 | parser.add_argument(
66 | '-od', '--odir', type=str, help='Path to the directory containing the processed meshes'
67 | )
68 | parser.add_argument(
69 | '-le', '--load_extension', type=str, help='Extension of the meshes to load'
70 | )
71 | parser.add_argument(
72 | '-se', '--save_extension', type=str, default='glb', help='Extension of the meshes to save'
73 | )
74 | parser.add_argument(
75 | '-o', '--origin', type=str, default='backflip', help='Origin of the meshes'
76 | )
77 | args = parser.parse_args()
78 | 
79 | convert_formats(Path(args.idir), Path(args.odir), args.load_extension, args.save_extension, args.origin)
80 | 
--------------------------------------------------------------------------------
/scripts/generate_dataset_mesh_segmentation_comparison.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 | import json
5 | import re
6 | from pathlib import Path
7 | import multiprocessing as mp
8 | 
9 | import numpy as np
10 | from PIL import Image, ImageSequence
11 | from tqdm import tqdm
12 | 
13 | from samesh.renderer.renderer_animations import images2gif
14 | 
15 | 
16 | def combine(filenames: list[str], output: Path | str, fps=30):
17 | """Concatenate the frames of several gifs side by side into one combined gif.
18 | """
19 | image_sequences = []
20 | for filename in tqdm(filenames):
21 | sequence = []
22 | for frame in ImageSequence.Iterator(Image.open(filename)):
23 | if frame.mode != 'RGB':
24 | frame = frame.convert('RGB')
25 | sequence.append(np.array(frame))
26 | image_sequences.append(sequence)
27 | 
28 | combined = []
29 | for i in range(len(image_sequences[0])):
30 | 
combined_frame = np.concatenate([x[i] for x in image_sequences], axis=1)
31 | combined.append(Image.fromarray(combined_frame))
32 | images2gif(combined, path=output, duration=len(combined) / fps)
33 | 
34 | 
35 | def combine_renders(idirs: list[Path | str], odir: Path | str, shuffle=True):
36 | """Combine per-method renders side by side, optionally shuffling the method order per model.
37 | """
38 | filenames_accum = []
39 | for idir in idirs:
40 | filenames = sorted(list(glob.glob(f'{idir}/*.gif')))
41 | filenames_accum.append(filenames)
42 | filenames_combined = [
43 | [filenames[i] for filenames in filenames_accum]
44 | for i in range(len(filenames_accum[0]))
45 | ]
46 | 
47 | pattern = re.compile('.*(shape_diameter_function|sdf|matte|norm|combined).*')
48 | chunks = []
49 | os.makedirs(odir, exist_ok=True)
50 | metadata = []
51 | for filenames in filenames_combined:
52 | if shuffle:
53 | indices = np.random.permutation(len(filenames))
54 | split2index = {}
55 | for i, index in enumerate(indices):
56 | split = pattern.match(filenames[index]).group(1)
57 | split2index[split] = i
58 | metadata.append(split2index)
59 | filenames = [filenames[index] for index in indices]
60 | else:
61 | metadata.append({
62 | filename: i for i, filename in enumerate(filenames)
63 | })
64 | output_filename = odir / f'{Path(filenames[0]).stem}.gif'
65 | chunks.append((filenames, output_filename))
66 | 
67 | with open(odir / 'metadata.json', 'w') as f:
68 | json.dump(metadata, f)
69 | 
70 | with mp.Pool(mp.cpu_count()) as pool:
71 | pool.starmap(combine, chunks)
72 | 
73 | 
74 | if __name__ == '__main__':
75 | parser = argparse.ArgumentParser(
76 | description='Combine rendered mesh segmentation videos for side-by-side comparison'
77 | )
78 | parser.add_argument(
79 | '-id', '--idir', required=True, nargs='+',
80 | help='List of dirs containing rendered videos of mesh segmentations'
81 | )
82 | parser.add_argument(
83 | '-od', '--odir', type=str, required=True,
84 | help='Output directory of combined mesh segmentation comparison'
85 | )
86 | args = parser.parse_args()
87 | 
88 | combine_renders(list(map(Path, args.idir)), Path(args.odir), shuffle=True)
--------------------------------------------------------------------------------
/src/samesh/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gtangg12/samesh/9417875e1ad43234a0b66d438dd5dae2b2a1edbe/src/samesh/__init__.py
--------------------------------------------------------------------------------
/src/samesh/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gtangg12/samesh/9417875e1ad43234a0b66d438dd5dae2b2a1edbe/src/samesh/data/__init__.py
--------------------------------------------------------------------------------
/src/samesh/data/common.py:
--------------------------------------------------------------------------------
1 | from torchtyping import TensorType
2 | 
3 | 
4 | NumpyTensor = TensorType
5 | TorchTensor = TensorType
--------------------------------------------------------------------------------
/src/samesh/data/loaders.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | 
3 | import numpy as np
4 | import pandas as pd
5 | import trimesh
6 | from trimesh.base import Trimesh, ColorVisuals, Scene
7 | 
8 | from samesh.data.common import NumpyTensor
9 | from samesh.utils.mesh import transform, norm_mesh, norm_scene, handle_pose
10 | 
11 | from trimesh.scene import transforms
12 | 
13 | 
14 | COLORS = {
15 | 'material_diffuse' : np.array([102, 102, 102, 
255], dtype=np.uint8), 16 | 'material_ambient' : np.array([ 64, 64, 64, 255], dtype=np.uint8), 17 | 'material_specular': np.array([197, 197, 197, 255], dtype=np.uint8), 18 | } 19 | 20 | 21 | def remove_texture(source: Trimesh | Scene, material='material_diffuse', visual_kind='face'): 22 | """ 23 | Remove texture from mesh or scene. 24 | """ 25 | def assign(visual, color): 26 | """ 27 | Helper function to assign color to visual given visual kind. 28 | """ 29 | if visual_kind == 'face': 30 | visual.face_colors = color 31 | elif visual_kind == 'vertex': 32 | visual.vertex_colors = color 33 | else: 34 | raise ValueError(f'Invalid visual kind {visual_kind}.') 35 | 36 | if isinstance(source, trimesh.Scene): 37 | for _, geom in source.geometry.items(): 38 | geom.visual = ColorVisuals() 39 | assign(geom.visual, COLORS[material]) 40 | else: 41 | source.visual = ColorVisuals() 42 | assign(source.visual, COLORS[material]) 43 | return source 44 | 45 | 46 | def scene2scene_no_transform(scene: Scene) -> Scene: 47 | """ 48 | 49 | NOTE:: in place operation that consumes scene. 50 | """ 51 | for name, geom in scene.geometry.items(): 52 | if name in scene.graph: 53 | pose, _ = scene.graph[name] 54 | pose = handle_pose(pose) 55 | geom.vertices = transform(pose, geom.vertices) 56 | scene.graph[name] = np.eye(4) 57 | return scene 58 | 59 | 60 | def scene2mesh(scene: Scene, process=True) -> Trimesh: 61 | """ 62 | """ 63 | if len(scene.geometry) == 0: 64 | mesh = None # empty scene 65 | else: 66 | data = [] 67 | for name, geom in scene.geometry.items(): 68 | if name in scene.graph: 69 | pose, _ = scene.graph[name] 70 | pose = handle_pose(pose) 71 | vertices = transform(pose, geom.vertices) 72 | else: 73 | vertices = geom.vertices 74 | # process=True removes duplicate vertices (needed for correct face graph), affecting face indices but not faces.shape 75 | data.append(Trimesh(vertices=vertices, faces=geom.faces, visual=geom.visual, process=process)) 76 | 77 | mesh = trimesh.util.concatenate(data) 78 | mesh = Trimesh(vertices=mesh.vertices, faces=mesh.faces, visual=mesh.visual, process=process) 79 | return mesh 80 | 81 | 82 | def read_mesh(filename: Path, norm=False, process=True) -> Trimesh | None: 83 | """ 84 | Read/convert a possible scene to mesh. 85 | 86 | If conversion occurs, the returned mesh has only vertex and face data i.e. no texture information. 
87 | 
88 | NOTE: process=True can sometimes have unexpected effects, such as causing face colors to become misaligned with faces
89 | """
90 | source = trimesh.load(filename)
91 | 
92 | if isinstance(source, trimesh.Scene):
93 | mesh = scene2mesh(source, process=process)
94 | else:
95 | assert(isinstance(source, trimesh.Trimesh))
96 | mesh = source
97 | if norm:
98 | mesh = norm_mesh(mesh)
99 | return mesh
100 | 
101 | 
102 | def read_scene(filename: Path, norm=False) -> Scene | None:
103 | """Read a scene from file, baking node transforms into the geometry.
104 | """
105 | source = trimesh.load(filename)
106 | source = scene2scene_no_transform(source)
107 | if norm:
108 | source = norm_scene(source)
109 | return source
110 | 
111 | 
112 | if __name__ == '__main__':
113 | mesh = read_mesh('/home/ubuntu/meshseg/tests/examples/0ba4ae3aa97b4298866a2903de4fd1e7.glb')
114 | print(mesh.faces.shape)
115 | print(mesh.vertices.shape)
--------------------------------------------------------------------------------
/src/samesh/metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gtangg12/samesh/9417875e1ad43234a0b66d438dd5dae2b2a1edbe/src/samesh/metrics/__init__.py
--------------------------------------------------------------------------------
/src/samesh/metrics/mesh_segmentation.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | 
3 | import glob
4 | import json
5 | import dataclasses
6 | import multiprocessing as mp
7 | from dataclasses import dataclass
8 | from pathlib import Path
9 | from typing import Mapping, Sequence
10 | 
11 | import numpy as np
12 | import trimesh
13 | from tqdm import tqdm
14 | 
15 | from samesh.data.common import NumpyTensor
16 | from samesh.metrics.mesh_segmentation_cut_discrepancy import compute_cut_discrepancy
17 | 
18 | 
19 | @dataclass
20 | class Metrics:
21 | """
22 | Metrics as described in the MeshsegBenchmark paper (https://segeval.cs.princeton.edu/). 
23 | """ 24 | # Cut Discrepancy according to SegEval implementation 25 | cut_discrepancy: float 26 | 27 | # Hamming distance and Directional Hamming distances 28 | hamming_distance_rm : float 29 | hamming_distance_rf: float 30 | hamming_distance : float 31 | 32 | # Rand Index 33 | inv_rand_index: float 34 | 35 | # Local and global consistency errors 36 | lce: float 37 | gce: float 38 | 39 | @staticmethod 40 | def average(metrics: Sequence[Metrics]) -> Metrics: 41 | """ 42 | Given a sequence of metrics, compute the average of each metric and return a new Metrics object 43 | """ 44 | metrics_dicts = [dataclasses.asdict(m) for m in metrics] 45 | metrics_means = {k: np.mean( 46 | [d[k] for d in metrics_dicts if d[k] is not None] # handle case where cut discrepancy is undefined 47 | ) for k in metrics_dicts[0].keys()} 48 | return Metrics(**metrics_means) 49 | 50 | def check_bounds(self) -> str | None: 51 | for k, v in dataclasses.asdict(self).items(): 52 | if k == 'cut_discrepancy': 53 | if v is None: # handle case where cut discrepancy is undefined 54 | continue 55 | if not 0 <= v: 56 | return f'{k} {v} not in [0, inf)' 57 | else: 58 | if not 0 <= v <= 1: 59 | return f'{k} {v} not in [0, 1]' 60 | 61 | 62 | @dataclass 63 | class SegmentSizes: 64 | """ 65 | """ 66 | total_faces: int 67 | estimated: NumpyTensor['num_estimated'] 68 | reference: NumpyTensor['num_reference'] 69 | intersect: NumpyTensor['num_estimated, num_reference'] 70 | 71 | def check_bounds(self) -> str | None: 72 | if not np.all(self.intersect.sum(axis=1) == self.estimated): 73 | return 'intersect sizes do not sum to estimated sizes' 74 | if not np.all(self.intersect.sum(axis=0) == self.reference): 75 | return 'intersect sizes do not sum to ground truth sizes' 76 | if self.intersect.sum() != self.total_faces: 77 | return f'intersect sizes sum to {self.intersect.sum()} instead of {self.total_faces}' 78 | 79 | if np.any(self.estimated < 0): 80 | return 'Estimated sizes contain negative values' 81 | if np.any(self.reference < 0): 82 | return 'Ground truth sizes contain negative values' 83 | if np.any(self.intersect < 0): 84 | return 'intersect sizes contain negative values' 85 | 86 | 87 | def compute_metrics(mesh: trimesh.Trimesh | None, estimated: NumpyTensor['f'], reference: NumpyTensor['f']) -> Metrics: 88 | """ 89 | """ 90 | metrics = {} 91 | segment_sizes = _compute_segment_sizes(estimated, reference) 92 | rm = ( 93 | _compute_directional_hamming_distance(segment_sizes.reference, segment_sizes.intersect.T) 94 | / segment_sizes.total_faces 95 | ) 96 | rf = ( 97 | _compute_directional_hamming_distance(segment_sizes.estimated, segment_sizes.intersect) 98 | / segment_sizes.total_faces 99 | ) 100 | hamming_distance = (rm + rf) / 2 101 | metrics.update({ 102 | 'hamming_distance_rm': rm, 103 | 'hamming_distance_rf': rf, 104 | 'hamming_distance': hamming_distance, 105 | }) 106 | metrics['inv_rand_index'] = 1 - _compute_rand_index(segment_sizes) 107 | metrics['gce'], metrics['lce'] = _compute_consistency_error(segment_sizes, estimated, reference) 108 | metrics['cut_discrepancy'] = compute_cut_discrepancy(mesh, estimated, reference) 109 | return Metrics(**metrics) 110 | 111 | 112 | def _compute_segment_sizes(estimated: NumpyTensor['f'], reference: NumpyTensor['f']) -> SegmentSizes: 113 | """ 114 | """ 115 | def bincount_check(arr): 116 | sizes = np.bincount(arr) 117 | assert len(sizes) == np.amax(arr) + 1 118 | return sizes, len(sizes) 119 | 120 | estimated_sizes, P_estimated = bincount_check(estimated) 121 | reference_sizes, 
P_reference = bincount_check(reference) 122 | intersect_sizes = np.bincount( 123 | estimated * P_reference + reference, minlength=P_estimated * P_reference 124 | ).reshape((P_estimated, P_reference)) 125 | res = SegmentSizes( 126 | total_faces=len(estimated), 127 | estimated=estimated_sizes, 128 | reference=reference_sizes, 129 | intersect=intersect_sizes, 130 | ) 131 | assert (err := res.check_bounds()) is None, err 132 | return res 133 | 134 | 135 | def _compute_directional_hamming_distance(s2_sizes: NumpyTensor, intersect_sizes: NumpyTensor) -> float: 136 | """ 137 | """ 138 | return sum(s2_sizes) - sum(intersect_sizes.max(axis=1)) 139 | 140 | 141 | def _compute_rand_index(sizes: SegmentSizes) -> float: 142 | """ 143 | """ 144 | def choose_2(n): 145 | return n * (n - 1) / 2 146 | 147 | N2 = choose_2(sizes.total_faces) 148 | s1 = choose_2(sizes.estimated).sum() 149 | s2 = choose_2(sizes.reference).sum() 150 | s12 = choose_2(sizes.intersect).sum() 151 | return (N2 - s1 - s2 + 2 * s12) / N2 152 | 153 | 154 | def _compute_consistency_error(sizes: SegmentSizes, estimated: NumpyTensor['f'], reference: NumpyTensor['f']) -> tuple[float, float]: 155 | """ 156 | """ 157 | R1 = sizes.estimated[estimated] 158 | R2 = sizes.reference[reference] 159 | E12 = (R1 - sizes.intersect[estimated, reference]) / R1 160 | E21 = (R2 - sizes.intersect[estimated, reference]) / R2 161 | assert E12.shape == estimated.shape 162 | assert E21.shape == reference.shape 163 | gce = min(E21.sum(), E12.sum()) / sizes.total_faces 164 | lce = np.sum(np.minimum(E12, E21)) / sizes.total_faces 165 | return gce, lce 166 | 167 | 168 | def seg_from_face2label(filename: Path | str) -> np.ndarray: 169 | """ 170 | """ 171 | face2label = json.loads(Path(filename).read_text()) 172 | face2label = {int(k): int(v) for k, v in face2label.items()} 173 | face2label_items = sorted(face2label.items()) 174 | assert face2label_items[ 0][0] == 0 175 | assert face2label_items[-1][0] == len(face2label) - 1 176 | return np.array([label for _, label in face2label_items], dtype=np.uint32) 177 | 178 | 179 | def benchmark_dataset_princeton_one( 180 | path_meshes : Path | str, 181 | path_segmentations : Path | str, 182 | path_segmentations_reference: Path | str, 183 | filename: str, category=None, load_json=False, 184 | ) -> Mapping[int, Metrics]: 185 | """ 186 | """ 187 | metrics = {} 188 | print(f'Processing {filename} in category {category}') 189 | 190 | mesh = trimesh.load(f'{path_meshes}/{filename}.off') 191 | 192 | if load_json: 193 | segmentation = seg_from_face2label(f'{path_segmentations}/{filename}/{filename}_face2label.json') 194 | else: 195 | with open(f'{path_segmentations}/{filename}.seg', 'r') as f: 196 | segmentation = np.array([int(x) for x in f.readlines()], dtype=np.uint32) 197 | 198 | bench_dir = Path(f'{path_segmentations_reference}/{filename}') 199 | # Compute average metrics over all human segmentations 200 | for bench_path in bench_dir.iterdir(): 201 | with open(bench_path, 'r') as f: 202 | bench = np.array([int(x) for x in f.readlines()], dtype=np.uint32) 203 | metric = compute_metrics(mesh, segmentation, bench) 204 | assert (err := metric.check_bounds()) is None, (metric, filename, err) 205 | metrics.setdefault(category, []).append(metric) 206 | return metrics 207 | 208 | 209 | def benchmark_dataset_coseg_one( 210 | path_meshes : Path | str, 211 | path_segmentations : Path | str, 212 | path_segmentations_reference: Path | str, 213 | filename: str, category=None 214 | ) -> Mapping[int, Metrics]: 215 | """ 216 | """ 217 | metrics = 
{} 218 | print(f'Processing {filename} in category {category}') 219 | 220 | mesh = trimesh.load(f'{path_meshes}/{filename}.off') 221 | 222 | segmentation = seg_from_face2label(f'{path_segmentations}/{filename}/{filename}_face2label.json') 223 | 224 | with open(f'{path_segmentations_reference}/{filename}.seg', 'r') as f: 225 | bench = np.array([int(x) for x in f.readlines()], dtype=np.uint32) 226 | metric = compute_metrics(mesh, segmentation, bench) 227 | assert (err := metric.check_bounds()) is None, (metric, filename, err) 228 | metrics.setdefault(category, []).append(metric) 229 | return metrics 230 | 231 | 232 | def benchmark_dataset_princeton( 233 | path_meshes : Path | str, 234 | path_segmentations : Path | str, 235 | path_segmentations_reference: Path | str, 236 | load_json=False, 237 | ) -> Mapping[int, Metrics]: 238 | """ 239 | """ 240 | extract_category = lambda i: (i - 1) // 20 + 1 241 | 242 | pool = mp.Pool(mp.cpu_count()) 243 | chunks = [ 244 | (path_meshes, path_segmentations, path_segmentations_reference, i, extract_category(i), load_json) 245 | for i in range(1, 401) if extract_category(i) not in [14] #, 4, 8, 13, 17] 246 | ] 247 | metrics_list = pool.starmap(benchmark_dataset_princeton_one, chunks) 248 | metrics = {} 249 | for metrics_one in metrics_list: 250 | if metrics_one is None: 251 | continue 252 | for k, v in metrics_one.items(): 253 | metrics.setdefault(k, []).extend(v) 254 | return { 255 | 'averages': Metrics.average([m for v in metrics.values() for m in v]), 256 | 'averages_by_category': {k: Metrics.average(v) for k, v in metrics.items()} 257 | } 258 | 259 | 260 | def benchmark_dataset_coseg( 261 | path_meshes : Path | str, 262 | path_segmentations : Path | str, 263 | path_segmentations_reference: Path | str, 264 | ) -> Mapping[int, Metrics]: 265 | """ 266 | """ 267 | chunks = [] 268 | categories = ['candelabra', 'chairs', 'fourleg', 'goblets', 'guitars', 'irons', 'lamps', 'vases'] 269 | for cat in categories: 270 | cat_path_meshes = f'{path_meshes}/{cat}' 271 | cat_path_segmentations = f'{path_segmentations}/{cat}' 272 | cat_path_segmentations_reference = f'{path_segmentations_reference}/{cat}_gt' 273 | filenames = glob.glob(f'{cat_path_meshes}/*.off') 274 | chunks.extend([ 275 | (cat_path_meshes, cat_path_segmentations, cat_path_segmentations_reference, i, cat) 276 | for i in [int(Path(f).stem) for f in filenames] 277 | ]) 278 | pool = mp.Pool(mp.cpu_count()) 279 | metrics_list = pool.starmap(benchmark_dataset_coseg_one, chunks) 280 | metrics = {} 281 | for metrics_one in metrics_list: 282 | for k, v in metrics_one.items(): 283 | metrics.setdefault(k, []).extend(v) 284 | return { 285 | 'averages': Metrics.average([m for v in metrics.values() for m in v]), 286 | 'averages_by_category': {k: Metrics.average(v) for k, v in metrics.items()} 287 | } 288 | 289 | 290 | if __name__ == "__main__": 291 | # metrics1 = benchmark_dataset_princeton( 292 | # path_meshes='/home/gtangg12/data/samesh/MeshsegBenchmark-1.0/data/off', 293 | # path_segmentations ='/home/gtangg12/samesh/outputs/mesh_segmentation_output_princeton_shape_diameter_function', 294 | # path_segmentations_reference='/home/gtangg12/data/samesh/MeshsegBenchmark-1.0/data/seg/Benchmark', 295 | # load_json=True 296 | # ) 297 | # metrics2 = benchmark_dataset_princeton( 298 | # path_meshes='/home/gtangg12/data/samesh/MeshsegBenchmark-1.0/data/off', 299 | # path_segmentations ='/home/gtangg12/samesh/outputs/mesh_segmentation_output_princeton', 300 | # 
path_segmentations_reference='/home/gtangg12/data/samesh/MeshsegBenchmark-1.0/data/seg/Benchmark',
301 | # load_json=True
302 | # )
303 | # print(metrics1)
304 | # print(metrics2)
305 | 
306 | metrics1 = benchmark_dataset_princeton(
307 | path_meshes='/home/gtangg12/data/samesh/MeshsegBenchmark-1.0/data/off',
308 | path_segmentations='/home/gtangg12/data/samesh/MeshsegBenchmark-1.0/data/seg/ShapeDiam',
309 | path_segmentations_reference='/home/gtangg12/data/samesh/MeshsegBenchmark-1.0/data/seg/Benchmark',
310 | )
311 | metrics2 = benchmark_dataset_princeton(
312 | path_meshes='/home/gtangg12/data/samesh/MeshsegBenchmark-1.0/data/off',
313 | path_segmentations ='/home/gtangg12/samesh/outputs/mesh_segmentation_output_princeton_dynamic',
314 | path_segmentations_reference='/home/gtangg12/data/samesh/MeshsegBenchmark-1.0/data/seg/Benchmark',
315 | load_json=True
316 | )
317 | print(metrics1)
318 | print(metrics2)
319 | 
320 | # metrics1 = benchmark_dataset_coseg(
321 | # path_meshes='/home/gtangg12/samesh/data/coseg',
322 | # path_segmentations='/home/gtangg12/samesh/outputs/mesh_segmentation_output_coseg_shape_diameter_function',
323 | # path_segmentations_reference='/home/gtangg12/datacoseg',
324 | # )
325 | # metrics2 = benchmark_dataset_coseg(
326 | # path_meshes='/home/gtangg12/samesh/data/coseg',
327 | # path_segmentations='/home/gtangg12/samesh/outputs/mesh_segmentation_output_coseg',
328 | # path_segmentations_reference='/home/gtangg12/datacoseg',
329 | # )
330 | # print(metrics1)
331 | # print(metrics2) -------------------------------------------------------------------------------- /src/samesh/metrics/mesh_segmentation_cut_discrepancy.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping
2 | import trimesh
3 | import numpy as np
4 | import igraph
5 | from samesh.data.common import *
6 | 
7 | 
8 | def compute_cut_discrepancy(mesh: trimesh.Trimesh, s1: NumpyTensor['f'], s2: NumpyTensor['f']) -> float:
9 | """
10 | Cut Discrepancy between two face partitions s1 and s2, normalized by the mesh's approximate average radius. """
11 | cut1 = _get_cut_vertex(mesh, s1)
12 | cut2 = _get_cut_vertex(mesh, s2)
13 | if len(cut1) == 0 or \
14 | len(cut2) == 0: # Undefined for empty cuts
15 | return 0
16 | d12 = _compute_distance_cuts(mesh, cut1, cut2)
17 | d21 = _compute_distance_cuts(mesh, cut2, cut1)
18 | cd = (d12.sum() + d21.sum()) / (len(d12) + len(d21)) # joint mean over both directions rather than the mean of the two directional means; same bug as in SegEval's original code
19 | avg_radius = _approx_average_radius(mesh)
20 | return cd / avg_radius
21 | 
22 | 
23 | def _compute_distance_cuts(mesh: trimesh.Trimesh, cut1: NumpyTensor['f'], cut2: NumpyTensor['f']) -> NumpyTensor | None:
24 | """
25 | Compute, for each vertex in cut1, the distance to the closest vertex in cut2.
26 | Distances are measured along the edge graph of the mesh (i.e. shortest paths through mesh edges). This is consistent
27 | with the SegEval metric (https://segeval.cs.princeton.edu/).
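Implementation note: a virtual source node is connected to every vertex of cut2 with
zero-weight edges, so a single shortest-path query from that node yields, for each vertex
v in cut1, d(v, cut2) = min over u in cut2 of the edge-path distance from v to u, in one pass.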
28 | """ 29 | S2_node = len(mesh.vertices) 30 | 31 | graph = igraph.Graph(directed=False) 32 | graph.add_vertices(len(mesh.vertices) + 1) 33 | graph.add_edges( 34 | mesh.edges, attributes={'weight': np.linalg.norm( 35 | mesh.vertices[mesh.edges[:, 0]] - 36 | mesh.vertices[mesh.edges[:, 1]], axis=1 37 | )}, 38 | ) 39 | graph.add_edges( 40 | [(S2_node, vertex) for vertex in cut2], attributes={ 41 | 'weight': np.zeros(len(cut2)) 42 | }, 43 | ) 44 | shortest_path = np.array(graph.shortest_paths(source=S2_node, target=cut1, weights='weight')) 45 | assert shortest_path.shape == (1, len(cut1)) 46 | return shortest_path[0] 47 | 48 | 49 | def _get_cut_vertex(mesh: trimesh.Trimesh, partition: NumpyTensor['f']) -> set[int]: 50 | """ 51 | Get all vertices along cut boundaries of a segmentation 52 | """ 53 | vpair2face: Mapping[tuple[int, int], list[int]] = {} 54 | for i, (v0, v1, v2) in enumerate(mesh.faces): 55 | vpair2face.setdefault(tuple(sorted((v0, v1))), []).append(i) 56 | vpair2face.setdefault(tuple(sorted((v1, v2))), []).append(i) 57 | vpair2face.setdefault(tuple(sorted((v2, v0))), []).append(i) 58 | cut: set[int] = set() 59 | for vpair, fpair in vpair2face.items(): 60 | if len(fpair) == 1: 61 | continue # this is a boundary edge 62 | #assert len(fpair) == 2 63 | if partition[fpair[0]] != partition[fpair[1]]: 64 | cut.add(vpair[0]) 65 | cut.add(vpair[1]) 66 | return cut 67 | 68 | 69 | def _approx_average_radius(mesh: trimesh.Trimesh) -> float: 70 | """ 71 | Weighted distance from an average face to the centroid of the surface 72 | """ 73 | face_cents = np.mean(mesh.vertices[mesh.faces], axis=1) # (F, 3) 74 | face_areas = np.linalg.norm( 75 | np.cross( 76 | mesh.vertices[mesh.faces[:, 1]] - mesh.vertices[mesh.faces[:, 0]], 77 | mesh.vertices[mesh.faces[:, 2]] - mesh.vertices[mesh.faces[:, 0]], 78 | ), axis=1, 79 | ) 80 | cent = (face_cents * face_areas[:, None]).sum(axis=0) / face_areas.sum() 81 | dist = np.linalg.norm(face_cents - cent[None, :], axis=1) 82 | return (dist * face_areas).sum() / face_areas.sum() -------------------------------------------------------------------------------- /src/samesh/metrics/mesh_segmentation_test.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import trimesh 4 | 5 | from samesh.metrics.mesh_segmentation import compute_metrics 6 | 7 | 8 | def test_combinatorial_metrics(): 9 | mesh = trimesh.Trimesh( 10 | vertices=[ 11 | [ 1, 0, 0], 12 | [-1, 0, 0], 13 | [ 0, 1, 0], 14 | [ 0, -1, 0], 15 | [ 0, 0, 1], 16 | [ 0, 0, -1], 17 | ], 18 | faces=[ 19 | [0, 2, 4], 20 | [0, 4, 3], 21 | [0, 3, 5], 22 | [0, 5, 2], 23 | ], 24 | ) 25 | estimated = np.array([0, 0, 1, 1, 2, 2], dtype=np.uint32) 26 | reference = np.array([0, 1, 3, 2, 2, 2], dtype=np.uint32) 27 | 28 | metrics = compute_metrics(mesh, estimated, reference) 29 | assert math.isclose(metrics.hamming_distance_rm, 1 / 6) 30 | assert math.isclose(metrics.hamming_distance_rf, 1 / 3) 31 | assert math.isclose(metrics.hamming_distance, 1 / 4) 32 | assert math.isclose(metrics.inv_rand_index, 4 / 15) 33 | assert math.isclose(metrics.lce, 1 / 12) 34 | assert math.isclose(metrics.gce, 2 / 9) 35 | 36 | 37 | def test_cut_discrepancy(): 38 | octahedron = trimesh.Trimesh( 39 | vertices=[ 40 | [ 1, 0, 0], 41 | [-1, 0, 0], 42 | [ 0, 1, 0], 43 | [ 0, -1, 0], 44 | [ 0, 0, 1], 45 | [ 0, 0, -1], 46 | ], 47 | faces=[ 48 | [0, 2, 4], 49 | [0, 4, 3], 50 | [0, 3, 5], 51 | [0, 5, 2], 52 | [1, 4, 2], 53 | [1, 3, 4], 54 | [1, 5, 3], 55 | [1, 2, 5], 56 | ], 57 | ) 58 
| estimated = np.array([0, 0, 0, 0, 1, 1, 1, 1], dtype=np.uint32) # right half 0, left half 1
59 | reference = np.array([0, 1, 1, 0, 0, 1, 1, 0], dtype=np.uint32) # top half 0, bottom half 1
60 | metrics = compute_metrics(octahedron, estimated, reference)
61 | assert math.isclose(metrics.cut_discrepancy, math.sqrt(3/2))
62 | 
63 | 
64 | if __name__ == "__main__":
65 | test_combinatorial_metrics()
66 | test_cut_discrepancy()
67 | print("All tests passed!") -------------------------------------------------------------------------------- /src/samesh/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtangg12/samesh/9417875e1ad43234a0b66d438dd5dae2b2a1edbe/src/samesh/models/__init__.py -------------------------------------------------------------------------------- /src/samesh/models/sam.py: -------------------------------------------------------------------------------- 1 | import re
2 | from pathlib import Path
3 | 
4 | import cv2
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | from PIL import Image
9 | from omegaconf import OmegaConf
10 | from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection # detection head needed for Grounding DINO post-processing
11 | 
12 | USE_SAMHQ = False
13 | if USE_SAMHQ:
14 | from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
15 | else:
16 | from sam2.build_sam import build_sam2
17 | from sam2.sam2_image_predictor import SAM2ImagePredictor
18 | from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
19 | 
20 | 
21 | from samesh.data.common import NumpyTensor
22 | 
23 | 
24 | def combine_bmasks(masks: NumpyTensor['n h w'], sort=False) -> NumpyTensor['h w']:
25 | """
26 | Combine a stack of binary masks into one labeled mask; label 0 is background. Sorting by area lets smaller masks overwrite larger ones. """
27 | mask_combined = np.zeros_like(masks[0], dtype=int)
28 | if sort:
29 | masks = sorted(masks, key=lambda x: x.sum(), reverse=True)
30 | for i, mask in enumerate(masks):
31 | mask_combined[mask] = i + 1
32 | return mask_combined
33 | 
34 | 
35 | def decompose_mask(mask: NumpyTensor['h w'], background=0) -> NumpyTensor['n h w']:
36 | """
37 | Split a labeled mask into a stack of binary masks, one per non-background label. """
38 | labels = np.unique(mask)
39 | labels = labels[labels != background]
40 | return mask == labels[:, None, None]
41 | 
42 | 
43 | def remove_artifacts(mask: NumpyTensor['h w'], mode: str, min_area=128) -> NumpyTensor['h w']:
44 | """
45 | Remove small islands from, or fill small holes in, a (possibly multi-label) mask.
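mode='islands' keeps only sufficiently large connected regions of each label, while
mode='holes' fills small enclosed regions; both are applied per label via OpenCV connected
components. Typical usage (mirroring the calls in sam_mesh.py below):
    mask = remove_artifacts(mask, mode='islands', min_area=1024)
    mask = remove_artifacts(mask, mode='holes', min_area=1024)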
46 | """ 47 | assert mode in ['holes', 'islands'] 48 | mode_holes = (mode == 'holes') 49 | 50 | def remove_helper(bmask): 51 | # opencv connected components operates on binary masks only 52 | bmask = (mode_holes ^ bmask).astype(np.uint8) 53 | nregions, regions, stats, _ = cv2.connectedComponentsWithStats(bmask, 8) 54 | sizes = stats[:, -1][1:] # Row 0 corresponds to 0 pixels 55 | fill = [i + 1 for i, s in enumerate(sizes) if s < min_area] + [0] 56 | if not mode_holes: 57 | fill = [i for i in range(nregions) if i not in fill] 58 | return np.isin(regions, fill) 59 | 60 | mask_combined = np.zeros_like(mask) 61 | for label in np.unique(mask): # also process background 62 | mask_combined[remove_helper(mask == label)] = label 63 | return mask_combined 64 | 65 | 66 | def colormap_mask( 67 | mask : NumpyTensor['h w'], 68 | image: NumpyTensor['h w 3']=None, background=np.array([255, 255, 255]), foreground=None, blend=0.25 69 | ) -> Image.Image: 70 | """ 71 | """ 72 | palette = np.random.randint(0, 255, (np.max(mask) + 1, 3)) 73 | palette[0] = background 74 | if foreground is not None: 75 | for i in range(1, len(palette)): 76 | palette[i] = foreground 77 | image_mask = palette[mask.astype(int)] # type conversion for boolean masks 78 | image_blend = image_mask if image is None else image_mask * (1 - blend) + image * blend 79 | image_blend = np.clip(image_blend, 0, 255).astype(np.uint8) 80 | return Image.fromarray(image_blend) 81 | 82 | 83 | def colormap_bmask(bmask: NumpyTensor['h w']) -> Image.Image: 84 | """ 85 | """ 86 | return colormap_mask(bmask, background=np.array([0, 0, 0]), foreground=np.array([255, 255, 255])) 87 | 88 | 89 | def colormap_bmasks( 90 | masks: NumpyTensor['n h w'], 91 | image: NumpyTensor['h w 3']=None, background=np.array([255, 255, 255]), blend=0.25 92 | ) -> Image.Image: 93 | """ 94 | """ 95 | mask = combine_bmasks(masks) 96 | return colormap_mask(mask, image, background=background, blend=blend) 97 | 98 | 99 | def point_grid_from_mask(mask: NumpyTensor['h w'], n: int) -> NumpyTensor['n 2']: 100 | """ 101 | Sample points within valid mask normalized to [0, 1] x [0, 1] 102 | """ 103 | valid = np.argwhere(mask) 104 | if len(valid) == 0: 105 | raise ValueError('No valid points in mask') 106 | 107 | h, w = mask.shape 108 | n = min(n, len(valid)) 109 | indices = np.random.choice(len(valid), n, replace=False) 110 | samples = valid[indices].astype(float) 111 | samples[:, 0] /= h - 1 112 | samples[:, 1] /= w - 1 113 | samples = samples[:, [1, 0]] 114 | samples = samples[np.lexsort((samples[:, 1], samples[:, 0]))] 115 | return samples 116 | 117 | 118 | class SamModel(nn.Module): 119 | """ 120 | """ 121 | def __init__(self, config: OmegaConf, device='cuda'): 122 | """ 123 | """ 124 | super().__init__() 125 | self.config = config 126 | self.device = device 127 | 128 | if config.sam.auto: 129 | self.setup_sam(mode='auto') 130 | else: 131 | if config.sam.ground: 132 | self.setup_grounding_dino() 133 | self.setup_sam(mode='pred') 134 | 135 | def setup_sam(self, mode='auto'): 136 | """ 137 | """ 138 | match = re.search(r'vit_(l|tiny|h)', self.config.sam.checkpoint) 139 | self.sam_model = sam_model_registry[match.group(0)](checkpoint=self.config.sam.checkpoint) 140 | self.sam_model = self.sam_model.to(self.device) 141 | self.sam_model.eval() 142 | self.engine = { 143 | 'pred': SamPredictor, 144 | 'auto': SamAutomaticMaskGenerator, 145 | }[mode](self.sam_model, **self.config.sam.get('engine_config', {})) 146 | 147 | def setup_grounding_dino(self): 148 | """ 149 | """ 150 | 
self.grounding_dino_processor, self.grounding_dino_model = \
151 | AutoProcessor.from_pretrained(self.config.grounding_dino.checkpoint), \
152 | AutoModelForZeroShotObjectDetection.from_pretrained(self.config.grounding_dino.checkpoint).to(self.device)
153 | 
154 | def process_image(self, image: Image, prompt: dict = None) -> NumpyTensor['n h w']:
155 | """
156 | For information on prompt format see:
157 | 
158 | https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/predictor.py#L104
159 | """
160 | image = np.array(image)
161 | 
162 | if self.config.sam.auto:
163 | annotations = self.engine.generate(image)
164 | else:
165 | self.engine.set_image(image)
166 | annotations = self.engine.predict(**prompt)[0]
167 | annotations = [{'segmentation': m, 'area': m.sum().item()} for m in annotations]
168 | 
169 | annotations = sorted(annotations, key=lambda x: x['area'], reverse=True)
170 | masks = np.stack([anno['segmentation'] for anno in annotations])
171 | return masks
172 | 
173 | def process_boxes(self, image: Image, texts: list[str]) -> tuple[
174 | list[NumpyTensor[4]],
175 | list[NumpyTensor[2]]
176 | ]:
177 | """
178 | Ground text prompts to boxes with Grounding DINO; returns boxes and confidence scores for the image. """
179 | texts = '. '.join(texts)
180 | inputs = self.grounding_dino_processor(images=image, text=texts, return_tensors='pt').to(self.device) # processor needs both image and text
181 | with torch.no_grad():
182 | outputs = self.grounding_dino_model(**inputs)
183 | 
184 | results = self.grounding_dino_processor.post_process_grounded_object_detection(
185 | outputs,
186 | inputs.input_ids,
187 | box_threshold=0.4, text_threshold=0.3, target_sizes=[image.size[::-1]]
188 | )
189 | return results[0]['boxes'], results[0]['scores'] # post-processing returns one dict per image
190 | 
191 | def forward(self, image: Image, texts: list[str]=None) -> NumpyTensor['n h w']:
192 | """
193 | Run SAM over the whole image, or over Grounding DINO boxes when grounding is enabled. """
194 | if self.config.sam.auto:
195 | masks = self.process_image(image)
196 | else:
197 | boxes, _ = self.process_boxes(image, texts)
198 | masks = []
199 | for box in boxes:
200 | masks.append(self.process_image(image, {'box': box}))
201 | masks = np.concatenate(masks)
202 | return masks
203 | 
204 | 
205 | class Sam2Model(SamModel):
206 | """
207 | SamModel variant backed by SAM2 checkpoints. """
208 | def setup_sam(self, mode='auto'):
209 | """
210 | """
211 | self.sam_model = build_sam2(self.config.sam.model_config, self.config.sam.checkpoint, device=self.device, apply_postprocessing=False)
212 | self.sam_model.eval()
213 | self.engine = {
214 | 'pred': SAM2ImagePredictor,
215 | 'auto': SAM2AutomaticMaskGenerator,
216 | }[mode](self.sam_model, **self.config.sam.get('engine_config', {}))
217 | 
218 | 
219 | if __name__ == '__main__':
220 | import time
221 | 
222 | device = 'cuda'
223 | image = Image.open('/home/ubuntu/meshseg/tests/examples/goldpot.png')
224 | 
225 | config = OmegaConf.create({
226 | 'sam': {
227 | 'checkpoint': '/home/ubuntu/meshseg/checkpoints/sam_hq_vit_h.pth',
228 | 'auto': True,
229 | 'ground': False,
230 | 'engine_config': {'points_per_side': 32},
231 | },
232 | 'grounding_dino': {
233 | 'checkpoint': 'IDEA-Research/grounding-dino-tiny', # TODO find larger model
234 | },
235 | })
236 | 
237 | sam = SamModel(config, device)
238 | start_time = time.time()
239 | masks = sam(image)
240 | print(f'Elapsed time: {time.time() - start_time:.2f} s')
241 | image = colormap_bmasks(masks, np.array(image))
242 | image.save('test_mask.png')
243 | 
244 | '''
245 | config.sam.auto = False
246 | config.sam.ground = False
247 | sam = SamModel(config, device)
248 | masks = sam.process_image(image, prompt={
249 | 'point_coords': np.array([[image.height // 2, image.width // 2]]),
250 | 'point_labels': np.ones((1,)),
251 | 'multimask_output': False
252 | })
253 | image = 
colormap_bmasks(masks, np.array(image)) 254 | image.save('test_mask_grounded.png') 255 | ''' 256 | 257 | # For RuntimeError: No available kernel. Aborting execution 258 | # https://github.com/facebookresearch/segment-anything-2/issues/48 259 | config2 = OmegaConf.create({ 260 | 'sam': { 261 | 'model_config': 'sam2_hiera_l.yaml', 262 | 'checkpoint' : '/home/ubuntu/meshseg/checkpoints/sam2_hiera_large.pt', 263 | 'auto': True, 264 | 'engine_config': {'points_per_side': 32}, 265 | }, 266 | }) 267 | 268 | sam2 = Sam2Model(config2, device) 269 | start_time = time.time() 270 | masks = sam2(image) 271 | print(f'Elapsed time: {time.time() - start_time:.2f} s') 272 | image = colormap_bmasks(masks, np.array(image)) 273 | image.save('test_mask2.png') -------------------------------------------------------------------------------- /src/samesh/models/sam_mesh.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | import json 4 | import copy 5 | import multiprocessing as mp 6 | from collections import defaultdict, Counter 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import trimesh 13 | import igraph 14 | from PIL import Image 15 | from omegaconf import OmegaConf 16 | from trimesh.base import Trimesh, Scene 17 | from tqdm import tqdm 18 | from natsort import natsorted 19 | 20 | from samesh.data.common import NumpyTensor 21 | from samesh.data.loaders import read_scene, remove_texture, scene2mesh 22 | from samesh.renderer.renderer import Renderer, render_multiview, colormap_faces, colormap_norms 23 | from samesh.models.sam import SamModel, Sam2Model, combine_bmasks, colormap_mask, remove_artifacts, point_grid_from_mask 24 | from samesh.utils.cameras import * 25 | from samesh.utils.mesh import duplicate_verts 26 | from samesh.models.shape_diameter_function import * 27 | 28 | 29 | def colormap_faces_mesh(mesh: Trimesh, face2label: dict[int, int], background=np.array([0, 0, 0])) -> Trimesh: 30 | """ 31 | """ 32 | label_max = max(face2label.values()) 33 | palette = RandomState(0).randint(0, 255, (label_max + 1, 3)) # +1 for unlabeled faces 34 | palette[0] = background 35 | mesh = duplicate_verts(mesh) # needed to prevent face color interpolation 36 | faces_colored = set() 37 | for face, label in face2label.items(): 38 | mesh.visual.face_colors[face, :3] = palette[label] 39 | faces_colored.add(face) 40 | #print(np.unique(mesh.visual.face_colors, axis=0, return_counts=True)) 41 | ''' 42 | for face in range(len(mesh.faces)): 43 | if face not in faces_colored: 44 | mesh.visual.face_colors[face, :3] = background 45 | print('Unlabeled face ', face) 46 | ''' 47 | return mesh 48 | 49 | 50 | def norms_mask(norms: NumpyTensor['h w 3'], cam2world: NumpyTensor['4 4'], threshold=0.0) -> NumpyTensor['h w 3']: 51 | """ 52 | Mask pixels that are directly facing camera 53 | """ 54 | lookat = cam2world[:3, :3] @ np.array([0, 0, 1]) 55 | return np.abs(np.dot(norms, lookat)) > threshold 56 | 57 | 58 | def load_items(path: Path) -> dict[str, list]: 59 | """ 60 | """ 61 | print('Loading items from cache...') 62 | 63 | filenames = list(path.glob('matte_*.png')) 64 | filenames = natsorted(filenames, key=lambda x: int(x.stem.split('_')[-1])) 65 | items = { 66 | 'matte' : list(map(Image.open, filenames)), 67 | 'faces' : [np.load(path / f'faces_{i}.npy') for i in range(len(filenames))], 68 | 'norms' : [np.load(path / f'norms_{i}.npy') for i in range(len(filenames))], 69 | 'bmasks': [np.load(path / 
f'bmask_{i}.npy') for i in range(len(filenames))], 70 | 'cmasks': [np.load(path / f'cmask_{i}.npy') for i in range(len(filenames))], 71 | 'norms_masked': [np.load(path / f'norms_mask_{i}.npy') for i in range(len(filenames))], 72 | 'poses': np.load(path / 'poses.npy') 73 | } 74 | try: 75 | filenames = list(path.glob('sdf_*.png')) 76 | filenames = natsorted(filenames, key=lambda x: int(x.stem.split('_')[-1])) 77 | items['sdf'] = list(map(Image.open, filenames)) 78 | except FileNotFoundError: 79 | pass 80 | return items 81 | 82 | 83 | def save_items(items: dict, path: Path) -> None: 84 | """ 85 | """ 86 | print('Saving items to cache...') 87 | 88 | for i, image in enumerate(items['matte']): 89 | image.save(path / f'matte_{i}.png') 90 | if 'sdf' in items: 91 | for i, image in enumerate(items['sdf']): 92 | image.save(path / f'sdf_{i}.png') 93 | 94 | for i, (faces, bmask, cmask, norms, norms_masked) in enumerate( 95 | zip(items['faces'], items['bmasks'], items['cmasks'], items['norms'], items['norms_masked']) 96 | ): 97 | np.save(path / f'faces_{i}.npy', faces) 98 | np.save(path / f'bmask_{i}.npy', bmask) 99 | np.save(path / f'cmask_{i}.npy', cmask) 100 | np.save(path / f'norms_{i}.npy', norms) 101 | np.save(path / f'norms_mask_{i}.npy', norms_masked) 102 | np.save(path / 'poses.npy', items['poses']) 103 | 104 | 105 | def visualize_items(items: dict, path: Path) -> None: 106 | """ 107 | """ 108 | os.makedirs(path, exist_ok=True) 109 | 110 | for i, image in tqdm(enumerate(items['matte']), 'Visualizing images'): 111 | image.save(f'{path}/matte_{i}.png') 112 | if 'sdf' in items: 113 | for i, image in tqdm(enumerate(items['sdf']), 'Visualizing SDF'): 114 | image.save(f'{path}/sdf_{i}.png') 115 | 116 | for i, faces in tqdm(enumerate(items['faces']), 'Visualizing FaceIDs'): 117 | colormap_faces(faces).save(f'{path}/faces_{i}.png') 118 | for i, cmask in tqdm(enumerate(items['cmasks']), 'Visualizing SAM Masks'): 119 | colormap_mask (cmask).save(f'{path}/masks_{i}.png') 120 | for i, norms in tqdm(enumerate(items['norms']), 'Visualizing Normals'): 121 | colormap_norms(norms).save(f'{path}/norms_{i}.png') 122 | for i, norms_masked in tqdm(enumerate(items['norms_masked']), 'Visualizing Normals Mask'): 123 | colormap_norms(norms_masked).save(f'{path}/norms_mask_{i}.png') 124 | 125 | 126 | """ 127 | Multiprocessing functions SamModelMesh 128 | """ 129 | def compute_face2label( 130 | labels: NumpyTensor['l'], 131 | faceid: NumpyTensor['h w'], 132 | mask : NumpyTensor['h w'], 133 | norms : NumpyTensor['h w 3'], 134 | pose : NumpyTensor['4 4'], 135 | label_sequence_count: int, threshold_counts: int=16 136 | ): 137 | """ 138 | """ 139 | #print(f'Computing face2label starting with {label_sequence_count}') 140 | 141 | normal_mask = norms_mask(norms, pose) 142 | 143 | face2label = defaultdict(Counter) 144 | for j, label in enumerate(labels): 145 | label_sequence = label_sequence_count + j 146 | faces_mask = (mask == label) & normal_mask 147 | faces, counts = np.unique(faceid[faces_mask], return_counts=True) 148 | faces = faces[counts > threshold_counts] 149 | faces = faces[faces != -1] # remove background 150 | for face in faces: 151 | face2label[int(face)][label_sequence] += np.sum(faces_mask & (faceid == face)) 152 | return face2label 153 | 154 | 155 | def compute_connections(i: int, j: int, face2label1: dict, face2label2: dict, counter_threshold=32): 156 | """ 157 | """ 158 | #print(f'Computing partial connection ratios for {i} and {j}') 159 | 160 | connections = defaultdict(Counter) 161 | face2label1_common = 
{face: counter.most_common(1)[0][0] for face, counter in face2label1.items()} 162 | face2label2_common = {face: counter.most_common(1)[0][0] for face, counter in face2label2.items()} 163 | for face1, label1 in face2label1_common.items(): 164 | for face2, label2 in face2label2_common.items(): 165 | if face1 != face2: 166 | continue 167 | connections[label1][label2] += 1 168 | connections[label2][label1] += 1 169 | # remove connections where # overlapping faces is below threshold 170 | for label1, counter in connections.items(): 171 | connections[label1] = {k: v for k, v in counter.items() if v > counter_threshold} 172 | return connections 173 | 174 | 175 | class SamModelMesh(nn.Module): 176 | """ 177 | """ 178 | def __init__(self, config: OmegaConf, device='cuda', use_sam=True): 179 | """ 180 | """ 181 | super().__init__() 182 | self.config = config 183 | self.config.cache = Path(config.cache) if config.cache is not None else None 184 | self.renderer = Renderer(config.renderer) 185 | if use_sam and (self.config.cache is None or not self.config.cache.exists() or self.config.cache_overwrite): 186 | self.sam = Sam2Model(config.sam, device=device) 187 | 188 | def load(self, scene: Scene, mesh_graph=True): 189 | """ 190 | """ 191 | self.renderer.set_object(scene) 192 | self.renderer.set_camera() 193 | 194 | if mesh_graph: 195 | self.mesh_edges = trimesh.graph.face_adjacency(mesh=self.renderer.tmesh) 196 | self.mesh_graph = defaultdict(set) 197 | for face1, face2 in self.mesh_edges: 198 | self.mesh_graph[face1].add(face2) 199 | self.mesh_graph[face2].add(face1) 200 | 201 | def render(self, scene: Scene, visualize_path=None) -> dict[str, NumpyTensor]: 202 | """ 203 | """ 204 | if self.config.cache is not None and self.config.cache.exists(): 205 | if self.config.cache_overwrite: 206 | shutil.rmtree(self.config.cache) 207 | else: 208 | return load_items(self.config.cache) 209 | 210 | def render_func(uv_map=False): 211 | renderer_args = self.config.renderer.renderer_args.copy() 212 | if uv_map: 213 | renderer_args['uv_map'] = True # handle cases like sdf 214 | return render_multiview( 215 | self.renderer, 216 | camera_generation_method=self.config.renderer.camera_generation_method, 217 | renderer_args=renderer_args, 218 | sampling_args=self.config.renderer.sampling_args, 219 | lighting_args=self.config.renderer.lighting_args, 220 | ) 221 | 222 | def compute_norms_masked(norms: NumpyTensor['h w 3'], pose: NumpyTensor['4 4']): 223 | """ 224 | """ 225 | valid = norms_mask(norms, pose) 226 | norms_masked = norms.copy() 227 | norms_masked[~valid] = np.array([1, 1, 1]) 228 | return norms_masked 229 | 230 | renders = render_func() 231 | renders['norms_masked'] = [ 232 | compute_norms_masked(norms, pose) for norms, pose in zip(renders['norms'], renders['poses']) 233 | ] 234 | 235 | def call_sam(image: Image, mask: NumpyTensor['h w']): 236 | """ 237 | """ 238 | self.sam.engine.point_grids = \ 239 | [point_grid_from_mask(mask, self.config.sam.sam.engine_config.points_per_side ** 2)] 240 | return self.sam(image) 241 | 242 | bmasks_list = [] 243 | 244 | if 'norms' in self.config.sam_mesh.use_modes: 245 | images1 = [colormap_norms(norms) for norms in renders['norms']] 246 | bmasks1 = [ 247 | call_sam(image, faces != -1) for image, faces in \ 248 | tqdm(zip(images1, renders['faces']), 'Computing SAM Masks for norms') 249 | ] 250 | bmasks_list.extend(bmasks1) 251 | 252 | if 'sdf' in self.config.sam_mesh.use_modes: 253 | #scene_sdf = remove_texture(scene) 254 | tmesh_sdf = prep_mesh_shape_diameter_function(scene) 
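# For the 'sdf' modality, recolor the mesh by its shape diameter function values (a per-face
# local thickness measure, see shape_diameter_function.py) and re-render, giving SAM a view
# that highlights thickness changes which the matte and normal renders can miss.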
255 | tmesh_sdf = colormap_shape_diameter_function(tmesh_sdf, sdf_values=shape_diameter_function(tmesh_sdf)) 256 | self.load(tmesh_sdf) 257 | renders_sdf = render_func(uv_map=True) 258 | images2 = renders_sdf['matte'] 259 | 260 | bmasks2 = [ 261 | call_sam(image, faces != -1) for image, faces in \ 262 | tqdm(zip(images2, renders['faces']), 'Computing SAM Masks for sdf') 263 | ] 264 | bmasks_list.extend(bmasks2) 265 | renders['sdf'] = renders_sdf['matte'] 266 | 267 | if 'matte' in self.config.sam_mesh.use_modes: # default matte render 268 | images3 = renders['matte'] 269 | bmasks3 = [ 270 | call_sam(image, faces != -1) for image, faces in \ 271 | tqdm(zip(images3, renders['faces']), 'Computing SAM Masks for matte') 272 | ] 273 | bmasks_list.extend(bmasks3) 274 | 275 | self.load(scene) # restore original scene 276 | 277 | n = len(renders['faces']) 278 | m = len(bmasks_list) // n 279 | bmasks = [ 280 | np.concatenate([bmasks_list[j * n + i] for j in range(m)], axis=0) 281 | for i in range(n) 282 | ] 283 | cmasks = [combine_bmasks(masks, sort=True) for masks in bmasks] 284 | # sometimes SAM doesn't separate body from background, so we have extra step to remove background using faceids 285 | for cmask, faces in zip(cmasks, renders['faces']): 286 | cmask += 1 287 | cmask[faces == -1] = 0 288 | min_area = self.config.sam_mesh.get('min_area', 1024) 289 | cmasks = [remove_artifacts(mask, mode='islands', min_area=min_area) for mask in cmasks] 290 | cmasks = [remove_artifacts(mask, mode='holes' , min_area=min_area) for mask in cmasks] 291 | renders['bmasks'] = bmasks 292 | renders['cmasks'] = cmasks 293 | 294 | if self.config.cache is not None: 295 | self.config.cache.mkdir(parents=True) 296 | save_items(renders, self.config.cache) 297 | if visualize_path is not None: 298 | visualize_items(renders, visualize_path) 299 | return renders 300 | 301 | def lift(self, renders: dict[str, NumpyTensor]) -> dict: 302 | """ 303 | """ 304 | be, en = 0, len(renders['faces']) 305 | renders = {k: [v[i] for i in range(be, en) if len(v)] for k, v in renders.items()} 306 | 307 | print('Computing face2label for each view on ', mp.cpu_count(), ' cores') 308 | label_sequence_count = 1 # background is 0 309 | args = [] 310 | for faceid, cmask, norms, pose in zip( 311 | renders['faces'], 312 | renders['cmasks'], 313 | renders['norms'], 314 | renders['poses'], 315 | ): 316 | labels = np.unique(cmask) 317 | labels = labels[labels != 0] # remove background 318 | args.append((labels, faceid, cmask, norms, pose, label_sequence_count, self.config.sam_mesh.get('face2label_threshold', 16))) 319 | label_sequence_count += len(labels) 320 | 321 | with mp.Pool(mp.cpu_count()) as pool: 322 | face2label_views = pool.starmap(compute_face2label, args) 323 | 324 | print('Building match graph on ', mp.cpu_count(), ' cores') 325 | args = [] 326 | for i, face2label1 in enumerate(face2label_views): 327 | for j, face2label2 in enumerate(face2label_views): 328 | if i < j: 329 | args.append((i, j, face2label1, face2label2, self.config.sam_mesh.get('connections_threshold', 32))) 330 | 331 | with mp.Pool(mp.cpu_count()) as pool: 332 | partial_connections = pool.starmap(compute_connections, args) 333 | 334 | connections_ratios = defaultdict(Counter) 335 | for c in partial_connections: 336 | for label1, counter in c.items(): 337 | connections_ratios[label1].update(counter) 338 | 339 | # normalize ratios 340 | for label1, counter in connections_ratios.items(): 341 | total = sum(counter.values()) 342 | connections_ratios[label1] = {k: v / total for k, v 
in counter.items()} 343 | 344 | counter_lens = [len(counter) for counter in connections_ratios.values()] 345 | counter_lens = sorted(counter_lens) 346 | counter_lens_threshold = max(np.percentile(counter_lens, 95), self.config.sam_mesh.get('counter_lens_threshold_min', 16)) 347 | print('Counter lens threshold: ', counter_lens_threshold) 348 | removed = [] 349 | for label, counter in connections_ratios.items(): 350 | if len(counter) > counter_lens_threshold: 351 | removed.append(label) 352 | for label in removed: 353 | connections_ratios.pop(label) 354 | for counter in connections_ratios.values(): 355 | if label in counter: 356 | counter.pop(label) 357 | 358 | ''' 359 | print('Count ratios:') 360 | for label1, counter in connections_ratios.items(): 361 | print(label1) 362 | for label2, count in counter.items(): 363 | print(label2, count) 364 | exit() 365 | ''' 366 | 367 | bins_resolution = self.config.sam_mesh.connections_bin_resolution 368 | bins = np.zeros(bins_resolution + 1) 369 | for label1, counter in connections_ratios.items(): 370 | #print(label1) 371 | for label2, ratio in counter.items(): 372 | #print(label2, ratio) 373 | bins[int(ratio * bins_resolution)] += 1 374 | cutoff = self.config.sam_mesh.connections_bin_threshold_percentage * np.sum(bins) # more connections means higher threshold 375 | accum = 0 376 | accum_bin = 0 377 | while accum < cutoff: 378 | accum += bins[accum_bin] 379 | accum_bin += 1 380 | 381 | ''' 382 | import matplotlib.pyplot as plt 383 | plt.clf() 384 | plt.bar(range(bins_resolution + 1), bins) 385 | plt.xlabel(f'Cutoff bin: {accum_bin}') 386 | plt.axvline(x=accum_bin, color='r') 387 | plt.savefig(f'ratios_{self.config.cache.stem}.png') 388 | ''' 389 | 390 | # construct match graph edges 391 | connections = [] 392 | connections_ratio_threshold = max(accum_bin / bins_resolution, 0.075) 393 | print('Connections ratio threshold: ', connections_ratio_threshold) 394 | for label1, counter in connections_ratios.items(): 395 | for label2, ratio12 in counter.items(): 396 | ratio21 = connections_ratios[label2][label1] 397 | # best buddy match above threshold 398 | if ratio12 > connections_ratio_threshold and \ 399 | ratio21 > connections_ratio_threshold: 400 | connections.append((label1, label2)) 401 | print('Found ', len(connections), ' connections') 402 | 403 | connection_graph = igraph.Graph(edges=connections, directed=False) 404 | connection_graph.simplify() 405 | communities = connection_graph.community_leiden(resolution_parameter=0) 406 | # for comm in communities: 407 | # print(comm) 408 | # exit() 409 | label2label_consistent = {} 410 | comm_count = 0 411 | for comm in communities: 412 | if len(comm) > 1: 413 | label2label_consistent.update({label: comm[0] for label in comm if label != comm[0]}) 414 | comm_count += 1 415 | print('Found ', comm_count, ' communities') 416 | 417 | print('Merging labels') 418 | face2label_combined = defaultdict(Counter) 419 | for face2label in face2label_views: 420 | face2label_combined.update(face2label) 421 | face2label_consistent = {} 422 | for face, labels in face2label_combined.items(): 423 | hook = labels.most_common(1)[0][0] 424 | if hook in label2label_consistent: 425 | hook = label2label_consistent[hook] 426 | face2label_consistent[face] = hook 427 | #print(sorted(face2label_consistent.values())) 428 | return face2label_consistent 429 | 430 | def smooth(self, face2label_consistent: dict) -> dict: 431 | """ 432 | """ 433 | # remove holes 434 | components = self.label_components(face2label_consistent) 435 | 436 | 
threshold_percentage_size = self.config.sam_mesh.smoothing_threshold_percentage_size 437 | threshold_percentage_area = self.config.sam_mesh.smoothing_threshold_percentage_area 438 | components = sorted(components, key=lambda x: len(x), reverse=True) 439 | components_area = [ 440 | sum([float(self.renderer.tmesh.area_faces[face]) for face in comp]) for comp in components 441 | ] 442 | max_size = max([len(comp) for comp in components]) 443 | max_area = max(components_area) 444 | 445 | remove_comp_size = set() 446 | remove_comp_area = set() 447 | for i, comp in enumerate(components): 448 | if len(comp) < max_size * threshold_percentage_size: 449 | remove_comp_size.add(i) 450 | if components_area[i] < max_area * threshold_percentage_area: 451 | remove_comp_area.add(i) 452 | remove_comp = remove_comp_size.intersection(remove_comp_area) 453 | print('Removing ', len(remove_comp), ' components') 454 | for i in remove_comp: 455 | for face in components[i]: 456 | face2label_consistent.pop(face) 457 | 458 | # fill islands 459 | print('Smoothing labels') 460 | smooth_iterations = self.config.sam_mesh.smoothing_iterations 461 | for iteration in range(smooth_iterations): 462 | count = 0 463 | changes = {} 464 | for face in range(len(self.renderer.tmesh.faces)): 465 | if face in face2label_consistent: 466 | continue 467 | labels_adj = Counter() 468 | for adj in self.mesh_graph[face]: 469 | if adj in face2label_consistent: 470 | label = face2label_consistent[adj] 471 | if label != 0: 472 | labels_adj[label] += 1 473 | if len(labels_adj): 474 | count += 1 475 | changes[face] = labels_adj.most_common(1)[0][0] 476 | for face, label in changes.items(): 477 | face2label_consistent[face] = label 478 | #print('Smoothing iteration ', iteration, ' changed ', count, ' faces') 479 | 480 | return face2label_consistent 481 | 482 | def split(self, face2label_consistent: dict) -> dict: 483 | """ 484 | """ 485 | components = self.label_components(face2label_consistent) 486 | 487 | labels_seen = set() 488 | labels_curr = max(face2label_consistent.values()) + 1 489 | labels_orig = labels_curr 490 | for comp in components: 491 | face = comp.pop() 492 | label = face2label_consistent[face] 493 | comp.add(face) 494 | if label == 0 or label in labels_seen: # background or repeated label 495 | face2label_consistent.update({face: labels_curr for face in comp}) 496 | labels_curr += 1 497 | labels_seen.add(label) 498 | print('Split', (labels_curr - labels_orig), 'times') # account for background 499 | 500 | return face2label_consistent 501 | 502 | def smooth_repartition_faces(self, face2label_consistent: dict, target_labels=None) -> dict: 503 | """ 504 | """ 505 | tmesh = self.renderer.tmesh 506 | 507 | partition = np.array([face2label_consistent[face] for face in range(len(tmesh.faces))]) 508 | 509 | cost_data = np.zeros((len(tmesh.faces), np.max(partition) + 1)) 510 | for f in range(len(tmesh.faces)): 511 | for l in np.unique(partition): 512 | cost_data[f, l] = 0 if partition[f] == l else 1 513 | cost_smoothness = -np.log(tmesh.face_adjacency_angles / np.pi + 1e-20) 514 | 515 | lambda_seed = self.config.sam_mesh.repartition_lambda 516 | if target_labels is None: 517 | refined_partition = repartition(tmesh, partition, cost_data, cost_smoothness, self.config.sam_mesh.repartition_iterations, lambda_seed) 518 | return { 519 | face: refined_partition[face] for face in range(len(tmesh.faces)) 520 | } 521 | 522 | lambda_range=( 523 | self.config.sam_mesh.repartition_lambda_lb, 524 | self.config.sam_mesh.repartition_lambda_ub 525 | ) 526 
| lambdas = np.linspace(*lambda_range, num=mp.cpu_count()) 527 | chunks = [ 528 | (tmesh, partition, cost_data, cost_smoothness, self.config.sam_mesh.repartition_iterations, _lambda) 529 | for _lambda in lambdas 530 | ] 531 | with mp.Pool(mp.cpu_count() // 2) as pool: 532 | refined_partitions = pool.starmap(repartition, chunks) 533 | 534 | def compute_cur_labels(part, noise_threshold=10): 535 | """ 536 | """ 537 | values, counts = np.unique(part, return_counts=True) 538 | return values[counts > noise_threshold] 539 | 540 | # lambda crawling algorithm when target_labels is specified i.e. Princeton Mesh Segmentation Benchmark 541 | max_iteration = 8 542 | cur_iteration = 0 543 | cur_lambda_index = np.searchsorted(lambdas, lambda_seed) 544 | cur_labels = len(compute_cur_labels(refined_partitions[cur_lambda_index])) 545 | while not ( 546 | target_labels - self.config.sam_mesh.repartition_lambda_tolerance <= cur_labels and 547 | target_labels + self.config.sam_mesh.repartition_lambda_tolerance >= cur_labels 548 | ) and cur_iteration < max_iteration: 549 | 550 | if cur_labels < target_labels and cur_lambda_index > 0: 551 | # want more labels so decrease lambda 552 | cur_lambda_index -= 1 553 | if cur_labels > target_labels and cur_lambda_index < len(refined_partitions) - 1: 554 | # want less labels so increase lambda 555 | cur_lambda_index += 1 556 | 557 | cur_labels = len(compute_cur_labels(refined_partitions[cur_lambda_index])) 558 | cur_iteration += 1 559 | 560 | print('Repartitioned with ', cur_labels, ' labels aiming for ', target_labels, 'target labels using lambda ', lambdas[cur_lambda_index], ' in ', cur_iteration, ' iterations') 561 | 562 | refined_partition = refined_partitions[cur_lambda_index] 563 | return { 564 | face: refined_partition[face] for face in range(len(tmesh.faces)) 565 | } 566 | 567 | def forward(self, scene: Scene, visualize_path=None, target_labels=None) -> tuple[dict, Trimesh]: 568 | """ 569 | """ 570 | self.load(scene) 571 | renders = self.render(scene, visualize_path=visualize_path) 572 | face2label_consistent = self.lift(renders) 573 | face2label_consistent = self.smooth(face2label_consistent) 574 | # inject unlabeled faces after smoothing 575 | for face in range(len(self.renderer.tmesh.faces)): 576 | if face not in face2label_consistent: 577 | face2label_consistent[face] = 0 578 | face2label_consistent = self.split (face2label_consistent) # needed to label all faces for repartition 579 | face2label_consistent = self.smooth_repartition_faces(face2label_consistent, target_labels=target_labels) 580 | face2label_consistent = {int(k): int(v) for k, v in face2label_consistent.items()} # ensure serialization 581 | assert self.renderer.tmesh.faces.shape[0] == len(face2label_consistent) 582 | return face2label_consistent, self.renderer.tmesh 583 | 584 | def label_components(self, face2label: dict) -> list[set]: 585 | """ 586 | """ 587 | components = [] 588 | visited = set() 589 | 590 | def dfs(source: int): 591 | stack = [source] 592 | components.append({source}) 593 | visited.add(source) 594 | 595 | while stack: 596 | node = stack.pop() 597 | for adj in self.mesh_graph[node]: 598 | if adj not in visited and adj in face2label and face2label[adj] == face2label[node]: 599 | stack.append(adj) 600 | components[-1].add(adj) 601 | visited.add(adj) 602 | 603 | for face in range(len(self.renderer.tmesh.faces)): 604 | if face not in visited and face in face2label: 605 | dfs(face) 606 | return components 607 | 608 | 609 | def segment_mesh(filename: Path | str, config: OmegaConf, 
visualize=False, extension='glb', target_labels=None, texture=False) -> Trimesh: 610 | """ 611 | """ 612 | print('Segmenting mesh with SAMesh: ', filename) 613 | filename = Path(filename) 614 | config = copy.deepcopy(config) 615 | config.cache = Path(config.cache) / filename.stem if "cache" in config else None 616 | config.output = Path(config.output) / filename.stem 617 | 618 | model = SamModelMesh(config) 619 | tmesh = read_mesh(filename, norm=True) 620 | if not texture: 621 | tmesh = remove_texture(tmesh, visual_kind='vertex') 622 | 623 | # run sam grounded mesh and optionally visualize renders 624 | visualize_path = f'{config.output}/{filename.stem}_visualized' if visualize else None 625 | faces2label, _ = model(tmesh, visualize_path=visualize_path, target_labels=target_labels) 626 | 627 | # colormap and save mesh 628 | os.makedirs(config.output, exist_ok=True) 629 | tmesh_colored = colormap_faces_mesh(tmesh, faces2label) 630 | tmesh_colored.export (f'{config.output}/{filename.stem}_segmented.{extension}') 631 | json.dump(faces2label, open(f'{config.output}/{filename.stem}_face2label.json', 'w')) 632 | return tmesh_colored 633 | 634 | 635 | if __name__ == '__main__': 636 | import glob 637 | from natsort import natsorted 638 | 639 | def read_filenames(pattern: str): 640 | """ 641 | """ 642 | filenames = glob.glob(pattern) 643 | filenames = map(Path, filenames) 644 | filenames = natsorted(list(set(filenames))) 645 | print('Segmenting ', len(filenames), ' meshes') 646 | return filenames 647 | 648 | filenames = read_filenames('/home/gtangg12/data/samesh/backflip-benchmark-remeshed-processed/*.glb') 649 | config = OmegaConf.load('/home/gtangg12/samesh/configs/mesh_segmentation.yaml') 650 | for i, filename in enumerate(filenames): 651 | segment_mesh(filename, config, visualize=False) 652 | 653 | config_original = OmegaConf.load('/home/gtangg12/samesh/configs/mesh_segmentation_coseg.yaml') 654 | categories = ['candelabra', 'chairs', 'fourleg', 'goblets', 'guitars', 'irons', 'lamps', 'vases'] 655 | for cat in categories: 656 | filenames = read_filenames(f'/home/gtangg12/data/samesh/coseg/{cat}/*.off') 657 | for filename in filenames: 658 | config = copy.deepcopy(config_original) 659 | config.output = Path(config.output) / cat 660 | if "cache" in config: 661 | config.cache = Path(config.cache) / cat 662 | segment_mesh(filename, config, visualize=False) 663 | 664 | config = OmegaConf.load('/home/gtangg12/samesh/configs/mesh_segmentation_princeton.yaml') 665 | filenames = read_filenames('/home/gtangg12/data/samesh/MeshsegBenchmark-1.0/data/off/*.off') 666 | for i, filename in enumerate(filenames): 667 | name, extension = filename.stem, filename.suffix[1:] 668 | category = (int(name) - 1) // 20 + 1 669 | if category in [14]: #[4, 8, 13, 14, 17]: 670 | continue 671 | segment_mesh(filename, config, visualize=False) 672 | 673 | with open('/home/gtangg12/data/samesh/MeshsegBenchmark-1.0/util/parameters/nSeg-ByModel.txt') as f: 674 | target_labels_dict = {str(i): int(line) for i, line in enumerate(f.readlines(), 1)} 675 | 676 | config = OmegaConf.load('/home/gtangg12/samesh/configs/mesh_segmentation_princeton_dynamic.yaml') 677 | filenames = read_filenames('/home/gtangg12/data/samesh/MeshsegBenchmark-1.0/data/off/*.off') 678 | for i, filename in enumerate(filenames): 679 | name, extension = filename.stem, filename.suffix[1:] 680 | category = (int(name) - 1) // 20 + 1 681 | if category in [14]: #[4, 8, 13, 14, 17]: 682 | continue 683 | segment_mesh(filename, config, visualize=False, 
target_labels=target_labels_dict[name]) -------------------------------------------------------------------------------- /src/samesh/models/shape_diameter_function.py: -------------------------------------------------------------------------------- 1 | import glob
2 | import json
3 | import os
4 | import copy
5 | from pathlib import Path
6 | from collections import defaultdict
7 | 
8 | import numpy as np
9 | import pymeshlab
10 | import trimesh
11 | import networkx as nx
12 | import igraph
13 | from numpy.random import RandomState
14 | from trimesh.base import Trimesh, Scene
15 | from sklearn.mixture import GaussianMixture
16 | from tqdm import tqdm
17 | from omegaconf import OmegaConf
18 | 
19 | from samesh.data.common import NumpyTensor
20 | from samesh.data.loaders import scene2mesh, read_mesh
21 | from samesh.utils.mesh import duplicate_verts
22 | 
23 | 
24 | EPSILON = 1e-20
25 | SCALE = 1e6
26 | 
27 | 
28 | def partition_cost(
29 | mesh : Trimesh,
30 | partition : NumpyTensor['f'],
31 | cost_data : NumpyTensor['f num_components'],
32 | cost_smoothness: NumpyTensor['e']
33 | ) -> float:
34 | """
35 | Total energy of a partition: per-face data costs plus smoothness costs on adjacent faces with different labels. """
36 | cost = 0
37 | for f in range(len(partition)):
38 | cost += cost_data[f, partition[f]]
39 | for i, edge in enumerate(mesh.face_adjacency):
40 | f1, f2 = int(edge[0]), int(edge[1])
41 | if partition[f1] != partition[f2]:
42 | cost += cost_smoothness[i]
43 | return cost
44 | 
45 | 
46 | def construct_expansion_graph(
47 | label : int,
48 | mesh : Trimesh,
49 | partition : NumpyTensor['f'],
50 | cost_data : NumpyTensor['f num_components'],
51 | cost_smoothness: NumpyTensor['e']
52 | ) -> tuple[nx.Graph, dict]:
53 | """
54 | Build the alpha-expansion graph (Boykov-Veksler-Zabih style): terminals 'alpha'/'alpha_complement', one node per face, auxiliary nodes on label-boundary edges; its min s-t cut gives the optimal expansion move for `label`. """
55 | G = nx.Graph() # undirected graph
56 | A = 'alpha'
57 | B = 'alpha_complement'
58 | 
59 | node2index = {}
60 | G.add_node(A)
61 | G.add_node(B)
62 | node2index[A] = 0
63 | node2index[B] = 1
64 | for i in range(len(mesh.faces)):
65 | G.add_node(i)
66 | node2index[i] = 2 + i
67 | 
68 | aux_count = 0
69 | for i, edge in enumerate(mesh.face_adjacency): # auxiliary nodes
70 | f1, f2 = int(edge[0]), int(edge[1])
71 | if partition[f1] != partition[f2]:
72 | a = (f1, f2)
73 | if a in node2index: # duplicate edge
74 | continue
75 | G.add_node(a)
76 | node2index[a] = len(mesh.faces) + 2 + aux_count
77 | aux_count += 1
78 | 
79 | for f in range(len(mesh.faces)):
80 | G.add_edge(A, f, capacity=cost_data[f, label])
81 | G.add_edge(B, f, capacity=float('inf') if partition[f] == label else cost_data[f, partition[f]])
82 | 
83 | for i, edge in enumerate(mesh.face_adjacency):
84 | f1, f2 = int(edge[0]), int(edge[1])
85 | a = (f1, f2)
86 | if partition[f1] == partition[f2]:
87 | if partition[f1] != label:
88 | G.add_edge(f1, f2, capacity=cost_smoothness[i])
89 | else:
90 | G.add_edge(a, B, capacity=cost_smoothness[i])
91 | if partition[f1] != label:
92 | G.add_edge(f1, a, capacity=cost_smoothness[i])
93 | if partition[f2] != label:
94 | G.add_edge(a, f2, capacity=cost_smoothness[i])
95 | 
96 | return G, node2index
97 | 
98 | 
99 | def repartition(
100 | mesh: trimesh.Trimesh,
101 | partition : NumpyTensor['f'],
102 | cost_data : NumpyTensor['f num_components'],
103 | cost_smoothness: NumpyTensor['e'],
104 | smoothing_iterations: int,
105 | _lambda=1.0,
106 | ):
107 | A = 'alpha'
108 | B = 'alpha_complement'
109 | labels = np.unique(partition)
110 | 
111 | cost_smoothness = cost_smoothness * _lambda
112 | 
113 | # networkx broken for float capacities
114 | #cost_data = np.round(cost_data * SCALE).astype(int)
115 | #cost_smoothness = np.round(cost_smoothness * SCALE).astype(int)
116 | 
117 | cost_min = 
partition_cost(mesh, partition, cost_data, cost_smoothness) 118 | 119 | for i in range(smoothing_iterations): 120 | 121 | #print('Repartition iteration ', i) 122 | 123 | for label in tqdm(labels): 124 | G, node2index = construct_expansion_graph(label, mesh, partition, cost_data, cost_smoothness) 125 | index2node = {v: k for k, v in node2index.items()} 126 | 127 | ''' 128 | _, (S, T) = nx.minimum_cut(G, A, B) 129 | assert A in S and B in T 130 | S = np.array([v for v in S if isinstance(v, int)]).astype(int) 131 | T = np.array([v for v in T if isinstance(v, int)]).astype(int) 132 | ''' 133 | 134 | G = igraph.Graph.from_networkx(G) 135 | outputs = G.st_mincut(source=node2index[A], target=node2index[B], capacity='capacity') 136 | S = outputs.partition[0] 137 | T = outputs.partition[1] 138 | assert node2index[A] in S and node2index[B] in T 139 | S = np.array([index2node[v] for v in S if isinstance(index2node[v], int)]).astype(int) 140 | T = np.array([index2node[v] for v in T if isinstance(index2node[v], int)]).astype(int) 141 | 142 | assert (partition[S] == label).sum() == 0 # T consists of those assigned 'alpha' and S 'alpha_complement' (see paper) 143 | partition[T] = label 144 | 145 | cost = partition_cost(mesh, partition, cost_data, cost_smoothness) 146 | if cost > cost_min: 147 | raise ValueError('Cost increased. This should not happen because the graph cut is optimal.') 148 | cost_min = cost 149 | 150 | return partition 151 | 152 | 153 | def prep_mesh_shape_diameter_function(source: Trimesh | Scene) -> Trimesh: 154 | """ 155 | """ 156 | if isinstance(source, trimesh.Scene): 157 | source = scene2mesh(source) 158 | source.merge_vertices(merge_tex=True, merge_norm=True) 159 | return source 160 | 161 | 162 | def colormap_shape_diameter_function(mesh: Trimesh, sdf_values: NumpyTensor['f']) -> Trimesh: 163 | """ 164 | """ 165 | assert len(mesh.faces) == len(sdf_values) 166 | mesh = duplicate_verts(mesh) # needed to prevent face color interpolation 167 | mesh.visual.face_colors = trimesh.visual.interpolate(sdf_values, color_map='jet') 168 | return mesh 169 | 170 | 171 | def colormap_partition(mesh: Trimesh, partition: NumpyTensor['f']) -> Trimesh: 172 | """ 173 | """ 174 | assert len(mesh.faces) == len(partition) 175 | palette = RandomState(0).randint(0, 255, (np.max(partition) + 1, 3)) # must init every time to get same colors 176 | mesh = duplicate_verts(mesh) # needed to prevent face color interpolation 177 | mesh.visual.face_colors = palette[partition] 178 | return mesh 179 | 180 | 181 | def shape_diameter_function(mesh: Trimesh, norm=True, alpha=4, rays=64, cone_amplitude=120) -> NumpyTensor['f']: 182 | """ 183 | """ 184 | mesh = pymeshlab.Mesh(mesh.vertices, mesh.faces) 185 | meshset = pymeshlab.MeshSet() 186 | meshset.add_mesh(mesh) 187 | meshset.compute_scalar_by_shape_diameter_function_per_vertex(rays=rays, cone_amplitude=cone_amplitude) 188 | 189 | sdf_values = meshset.current_mesh().face_scalar_array() 190 | sdf_values[np.isnan(sdf_values)] = 0 191 | if norm: 192 | # normalize and smooth shape diameter function values 193 | min = sdf_values.min() 194 | max = sdf_values.max() 195 | sdf_values = (sdf_values - min) / (max - min) 196 | sdf_values = np.log(sdf_values * alpha + 1) / np.log(alpha + 1) 197 | return sdf_values 198 | 199 | 200 | def partition_faces(mesh: Trimesh, num_components: int, _lambda: float, smooth=True, smoothing_iterations=1, **kwargs) -> NumpyTensor['f']: 201 | """ 202 | """ 203 | sdf_values = shape_diameter_function(mesh, norm=True).reshape(-1, 1) 204 | 205 | # fit 
200 | def partition_faces(mesh: Trimesh, num_components: int, _lambda: float, smooth=True, smoothing_iterations=1, **kwargs) -> NumpyTensor['f']:
201 |     """Partition faces by fitting a GMM to SDF values, then refine the labels with alpha expansion graph cuts.
202 |     """
203 |     sdf_values = shape_diameter_function(mesh, norm=True).reshape(-1, 1)
204 | 
205 |     # fit 1D GMM to shape diameter function values
206 |     gmm = GaussianMixture(num_components)
207 |     gmm.fit(sdf_values)
208 |     probs = gmm.predict_proba(sdf_values)
209 |     if not smooth:
210 |         return np.argmax(probs, axis=1)
211 | 
212 |     # data and smoothness terms
213 |     cost_data = -np.log(probs + EPSILON)
214 |     cost_smoothness = -np.log(mesh.face_adjacency_angles / np.pi + EPSILON)
215 |     cost_smoothness = _lambda * cost_smoothness
216 | 
217 |     # generate initial partition and refine with alpha expansion graph cut
218 |     partition = np.argmin(cost_data, axis=1)
219 |     partition = repartition(mesh, partition, cost_data, cost_smoothness, smoothing_iterations=smoothing_iterations)
220 |     return partition
221 | 
222 | 
223 | def partition2label(mesh: Trimesh, partition: NumpyTensor['f']) -> NumpyTensor['f']:
224 |     """Split each label into connected components so every output label is a contiguous patch of faces.
225 |     """
226 |     edges = trimesh.graph.face_adjacency(mesh=mesh)
227 |     graph = defaultdict(set)
228 |     for face1, face2 in edges:
229 |         graph[face1].add(face2)
230 |         graph[face2].add(face1)
231 |     labels = set(list(np.unique(partition)))
232 | 
233 |     components = []
234 |     visited = set()
235 | 
236 |     def dfs(source: int):
237 |         stack = [source]
238 |         components.append({source})
239 |         visited.add(source)
240 | 
241 |         while stack:
242 |             node = stack.pop()
243 |             for adj in graph[node]:
244 |                 if adj not in visited and partition[adj] == partition[node]:
245 |                     stack.append(adj)
246 |                     components[-1].add(adj)
247 |                     visited.add(adj)
248 | 
249 |     for face in range(len(mesh.faces)):
250 |         if face not in visited:
251 |             dfs(face)
252 | 
253 |     partition_output = np.zeros_like(partition)
254 |     label_total = 0
255 |     for component in components:
256 |         for face in component:
257 |             partition_output[face] = label_total
258 |         label_total += 1
259 |     return partition_output
260 | 
261 | 
262 | def segment_mesh_sdf(filename: Path | str, config: OmegaConf, extension='glb') -> Trimesh:
263 |     """Segment a mesh with the ShapeDiam baseline and export the colored mesh plus a face-to-label json.
264 |     """
265 |     print('Segmenting mesh with Shape Diameter Function: ', filename)
266 |     filename = Path(filename)
267 |     config = copy.deepcopy(config)
268 |     config.output = Path(config.output) / filename.stem
269 | 
270 |     mesh = read_mesh(filename, norm=True)
271 |     mesh = prep_mesh_shape_diameter_function(mesh)
272 |     partition = partition_faces(mesh, config.num_components, config.repartition_lambda, smoothing_iterations=config.repartition_iterations) # keyword needed: positionally this would bind to `smooth`
273 |     partition_disconnected = partition2label(mesh, partition)
274 |     faces2label = {int(i): int(partition_disconnected[i]) for i in range(len(partition_disconnected))}
275 | 
276 |     os.makedirs(config.output, exist_ok=True)
277 |     mesh_colored = colormap_partition(mesh, partition_disconnected)
278 |     mesh_colored.export(f'{config.output}/{filename.stem}_segmented.{extension}')
279 |     json.dump(faces2label, open(f'{config.output}/{filename.stem}_face2label.json', 'w'))
280 |     return mesh_colored
281 | 
282 | 
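A hedged usage sketch for segment_mesh_sdf with an inline config: the keys mirror what the function reads above (output, num_components, repartition_lambda, repartition_iterations), but the values and output path are illustrative placeholders, not the repo's tuned YAML defaults.

```python
from omegaconf import OmegaConf

# Minimal driver for segment_mesh_sdf; values are illustrative, not tuned.
config = OmegaConf.create({
    'output': 'outputs/sdf_demo',  # per-mesh subdirectory is appended automatically
    'num_components': 5,           # GMM mixture components
    'repartition_lambda': 0.5,     # weight on the smoothness term
    'repartition_iterations': 1,   # alpha expansion sweeps
})
segment_mesh_sdf('assets/potion.glb', config)
# -> outputs/sdf_demo/potion/potion_segmented.glb and potion_face2label.json
```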
283 | if __name__ == '__main__':
284 |     import glob
285 |     from natsort import natsorted
286 | 
287 |     def read_filenames(pattern: str):
288 |         """Glob, dedupe, and naturally sort mesh filenames.
289 |         """
290 |         filenames = glob.glob(pattern)
291 |         filenames = map(Path, filenames)
292 |         filenames = natsorted(list(set(filenames)))
293 |         print('Segmenting ', len(filenames), ' meshes')
294 |         return filenames
295 | 
296 |     filenames = read_filenames('/home/gtangg12/data/samesh/backflip-benchmark-remeshed-processed/*.glb')
297 |     #filenames = [Path('/home/gtangg12/data/samesh.backflip-benchmark-remeshed-processed/jacket.glb')]
298 |     config = OmegaConf.load('/home/gtangg12/samesh/configs/mesh_segmentation_shape_diameter_function.yaml')
299 |     for i, filename in enumerate(filenames):
300 |         segment_mesh_sdf(filename, config)
301 | 
302 |     config_original = OmegaConf.load('/home/gtangg12/samesh/configs/mesh_segmentation_shape_diameter_function_coseg.yaml')
303 |     categories = ['candelabra', 'chairs', 'fourleg', 'goblets', 'guitars', 'irons', 'lamps', 'vases']
304 |     for cat in categories:
305 |         filenames = read_filenames(f'/home/gtangg12/data/samesh/coseg/{cat}/*.off')
306 |         for filename in filenames:
307 |             config = copy.deepcopy(config_original)
308 |             config.output = Path(config.output) / cat
309 |             segment_mesh_sdf(filename, config)
310 | 
311 |     config_original = OmegaConf.load('/home/gtangg12/samesh/configs/mesh_segmentation_shape_diameter_function_princeton.yaml')
312 |     filenames = read_filenames('/home/gtangg12/data/samesh/MeshsegBenchmark-1.0/data/off/*.off')
313 |     for i, filename in enumerate(filenames):
314 |         name, extension = filename.stem, filename.suffix[1:]
315 |         category = (int(name) - 1) // 20 + 1
316 |         if category in [14]: #[4, 8, 13, 14, 17]:
317 |             continue
318 |         config = copy.deepcopy(config_original)
319 |         segment_mesh_sdf(filename, config)
--------------------------------------------------------------------------------
/src/samesh/renderer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gtangg12/samesh/9417875e1ad43234a0b66d438dd5dae2b2a1edbe/src/samesh/renderer/__init__.py
--------------------------------------------------------------------------------
/src/samesh/renderer/renderer.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['PYOPENGL_PLATFORM'] = 'egl'
3 | os.environ['EGL_DEVICE_ID'] = '-1' # NOTE: necessary to not create GPU contention
4 | 
5 | ### START VOODOO ###
6 | # Dark incantation for disabling anti-aliasing in pyrender (if needed)
7 | import OpenGL.GL
8 | antialias_active = False
9 | old_gl_enable = OpenGL.GL.glEnable
10 | def new_gl_enable(value):
11 |     if not antialias_active and value == OpenGL.GL.GL_MULTISAMPLE:
12 |         OpenGL.GL.glDisable(value)
13 |     else:
14 |         old_gl_enable(value)
15 | OpenGL.GL.glEnable = new_gl_enable
16 | import pyrender
17 | ### END VOODOO ###
18 | 
19 | import cv2
20 | import numpy as np
21 | import torch
22 | from numpy.random import RandomState
23 | from PIL import Image
24 | from pyrender.shader_program import ShaderProgramCache as DefaultShaderCache
25 | from trimesh import Trimesh, Scene
26 | from omegaconf import OmegaConf
27 | from tqdm import tqdm
28 | 
29 | from samesh.data.common import NumpyTensor
30 | from samesh.data.loaders import scene2mesh
31 | from samesh.utils.cameras import HomogeneousTransform, sample_view_matrices, sample_view_matrices_polyhedra
32 | from samesh.utils.math import range_norm
33 | from samesh.utils.mesh import duplicate_verts
34 | from samesh.renderer.shader_programs import *
35 | 
36 | 
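The multisampling patch above exists because the face-id pass encodes each face index losslessly in RGB; any pixel blending from MSAA would average neighboring ids into ids that never existed. A small self-contained sketch of the 24-bit packing used by faceid.frag and decoded in render_faces below (pure numpy, illustrative only):

```python
import numpy as np

# id -> (r, g, b) exactly as faceid.frag computes it, and back as render_faces
# decodes it; the round trip is exact for ids below 2**24 - 1 (reserved for background).
ids = np.array([0, 1, 255, 256, 65536, 70000])
r, rem = ids // 65536, ids % 65536
g, b = rem // 256, rem % 256
decoded = r * 65536 + g * 256 + b
assert np.array_equal(decoded, ids)
```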
37 | def colormap_faces(faces: NumpyTensor['h w'], background=np.array([255, 255, 255])) -> Image.Image:
38 |     """
39 |     Given a face id map, color each face with a random color.
40 |     """
41 |     #print(np.unique(faces, return_counts=True))
42 |     palette = RandomState(0).randint(0, 255, (np.max(faces) + 2, 3)) # must init every time to get same colors; +2 leaves slot 0 for background
43 |     #print(palette)
44 |     palette[0] = background
45 |     image = palette[faces + 1, :].astype(np.uint8) # shift so background -1 maps to palette[0]
46 |     return Image.fromarray(image)
47 | 
48 | 
49 | def colormap_norms(norms: NumpyTensor['h w'], background=np.array([255, 255, 255])) -> Image.Image:
50 |     """
51 |     Given a normal map, color each normal with a color.
52 |     """
53 |     norms = (norms + 1) / 2
54 |     norms = (norms * 255).astype(np.uint8)
55 |     return Image.fromarray(norms)
56 | 
57 | 
58 | DEFAULT_CAMERA_PARAMS = {'fov': 60, 'znear': 0.01, 'zfar': 16}
59 | 
60 | 
61 | class Renderer:
62 |     """Offscreen pyrender wrapper that renders matte, normal, face id, and barycentric passes.
63 |     """
64 |     def __init__(self, config: OmegaConf):
65 |         """Setup the offscreen renderer and one shader cache per render pass.
66 |         """
67 |         self.config = config
68 |         self.renderer = pyrender.OffscreenRenderer(*config.target_dim)
69 |         self.shaders = {
70 |             'default': DefaultShaderCache(),
71 |             'normals': NormalShaderCache(),
72 |             'faceids': FaceidShaderCache(),
73 |             'barycnt': BarycentricShaderCache(),
74 |         }
75 | 
76 |     def set_object(self, source: Trimesh | Scene, smooth=False):
77 |         """Build pyrender scenes for the source, plus a vertex-duplicated copy for face id rendering.
78 |         """
79 |         if isinstance(source, Scene):
80 |             self.tmesh = scene2mesh(source)
81 |             self.scene = pyrender.Scene(ambient_light=[1.0, 1.0, 1.0]) # RGB no direction
82 |             for name, geom in source.geometry.items():
83 |                 if name in source.graph:
84 |                     pose, _ = source.graph[name]
85 |                 else:
86 |                     pose = None
87 |                 self.scene.add(pyrender.Mesh.from_trimesh(geom, smooth=smooth), pose=pose)
88 | 
89 |         elif isinstance(source, Trimesh):
90 |             self.tmesh = source
91 |             self.scene = pyrender.Scene(ambient_light=[1.0, 1.0, 1.0])
92 |             self.scene.add(pyrender.Mesh.from_trimesh(source, smooth=smooth))
93 | 
94 |         else:
95 |             raise ValueError(f'Invalid source type {type(source)}')
96 | 
97 |         # rearrange mesh for faceid rendering
98 |         self.tmesh_faceid = duplicate_verts(self.tmesh)
99 |         self.scene_faceid = pyrender.Scene(ambient_light=[1.0, 1.0, 1.0])
100 |         self.scene_faceid.add(
101 |             pyrender.Mesh.from_trimesh(self.tmesh_faceid, smooth=smooth)
102 |         )
103 | 
104 |     def set_camera(self, camera_params: dict = None):
105 |         """Create a perspective camera and attach it to both scenes.
106 |         """
107 |         self.camera_params = camera_params or dict(DEFAULT_CAMERA_PARAMS)
108 |         self.camera_params['yfov'] = self.camera_params.pop('fov', self.camera_params.get('yfov')) # accept 'fov' or 'yfov' in degrees; the previous get/pop combo raised KeyError when only 'yfov' was given
109 |         self.camera_params['yfov'] = self.camera_params['yfov'] * np.pi / 180.0
110 |         self.camera = pyrender.PerspectiveCamera(**self.camera_params)
111 | 
112 |         self.camera_node        = self.scene       .add(self.camera)
113 |         self.camera_node_faceid = self.scene_faceid.add(self.camera)
114 | 
115 |     def render(
116 |         self,
117 |         pose: HomogeneousTransform,
118 |         lightdir=np.array([0.0, 0.0, 1.0]), uv_map=False, interpolate_norms=True, blur_matte=False
119 |     ) -> dict:
120 |         """Render all passes from `pose` and return a dict with norms, depth, matte, and face ids.
121 |         """
122 |         self.scene       .set_pose(self.camera_node       , pose)
123 |         self.scene_faceid.set_pose(self.camera_node_faceid, pose)
124 | 
125 |         def render_pass(shader: str, scene): # renamed from `render` to avoid shadowing this method
126 |             """Swap in the shader cache for this pass and render the scene.
127 |             """
128 |             self.renderer._renderer._program_cache = self.shaders[shader]
129 |             return self.renderer.render(scene)
130 | 
131 |         if uv_map:
132 |             raw_color, raw_depth = render_pass('default', self.scene)
133 |         raw_norms, raw_depth = render_pass('normals', self.scene)
134 |         raw_faces, raw_depth = render_pass('faceids', self.scene_faceid)
135 |         raw_bcent, raw_depth = render_pass('barycnt', self.scene_faceid)
136 | 
137 |         def render_norms(norms: NumpyTensor['h w 3']) -> NumpyTensor['h w 3']:
138 |             """Map uint8 normals back to unit range [-1, 1].
139 |             """
140 |             return np.clip((norms / 255.0 - 0.5) * 2, -1, 1)
141 | 
142 |         def render_depth(depth: NumpyTensor['h w'], offset=2.8, alpha=0.8) -> NumpyTensor['h w']:
143 |             """Invert and range-normalize depth; background pixels map to 1.
144 |             """
145 |             return np.where(depth > 0, alpha * (1.0 - range_norm(depth, offset=offset)), 1)
146 | 
147 |         def render_faces(faces: NumpyTensor['h w 3']) -> NumpyTensor['h w']:
148 |             """Decode 24-bit RGB-packed face ids; background becomes -1.
149 |             """
150 |             faces = faces.astype(np.int32)
151 |             faces = faces[:, :, 0] * 65536 + faces[:, :, 1] * 256 + faces[:, :, 2]
152 |             faces[faces == (256 ** 3 - 1)] = -1 # set background to -1
153 |             return faces
154 | 
155 |         def render_bcent(bcent: NumpyTensor['h w 3']) -> NumpyTensor['h w 3']:
156 |             """Map uint8 barycentric coordinates back to [0, 1].
157 |             """
158 |             return np.clip(bcent / 255.0, 0, 1)
159 | 
160 |         def render_matte(
161 |             norms: NumpyTensor['h w 3'],
162 |             depth: NumpyTensor['h w'],
163 |             faces: NumpyTensor['h w'],
164 |             bcent: NumpyTensor['h w 3'],
165 |             alpha=0.5, beta=0.25, gaussian_kernel_width=5, gaussian_sigma=1,
166 |         ) -> NumpyTensor['h w 3']:
167 |             """Shade a grayscale diffuse matte, optionally with barycentric-interpolated vertex normals.
168 |             """
169 |             if interpolate_norms: # NOTE requires process=True
170 |                 verts_index = self.tmesh.faces[faces.reshape(-1)] # (n, 3)
171 |                 verts_norms = self.tmesh.vertex_normals[verts_index] # (n, 3, 3)
172 |                 norms = np.sum(verts_norms * bcent.reshape(-1, 3, 1), axis=1)
173 |                 norms = norms.reshape(bcent.shape)
174 | 
175 |             diffuse = np.sum(norms * lightdir, axis=2)
176 |             diffuse = np.clip(diffuse, -1, 1)
177 |             matte = 255 * (diffuse[:, :, None] * alpha + beta)
178 |             matte = np.where(depth[:, :, None] > 0, matte, 255)
179 |             matte = np.clip(matte, 0, 255).astype(np.uint8)
180 |             matte = np.repeat(matte, 3, axis=2)
181 | 
182 |             if blur_matte:
183 |                 matte = (faces == -1)[:, :, None] * matte + \
184 |                         (faces != -1)[:, :, None] * cv2.GaussianBlur(matte, (gaussian_kernel_width, gaussian_kernel_width), gaussian_sigma)
185 |             return matte
186 | 
187 |         norms = render_norms(raw_norms)
188 |         depth = render_depth(raw_depth)
189 |         faces = render_faces(raw_faces)
190 |         bcent = render_bcent(raw_bcent)
191 |         matte = raw_color if uv_map else render_matte(norms, raw_depth, faces, bcent) # use original depth for matte
192 | 
193 |         return {'norms': norms, 'depth': depth, 'matte': matte, 'faces': faces}
194 | 
195 | 
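render_matte's smooth shading reduces to a weighted sum per pixel. A tiny numeric sketch of that interpolation step, with made-up values (the module's default lightdir is +z):

```python
import numpy as np

# For a pixel covering one face, the normal is the barycentric-weighted sum of
# that face's three vertex normals, then shaded by its dot with the light direction.
verts_norms = np.array([[0., 0., 1.],
                        [0., 1., 0.],
                        [1., 0., 0.]])  # vertex normals of one face
w = np.array([0.2, 0.3, 0.5])           # barycentric weights from the 'barycnt' pass
normal = (verts_norms * w[:, None]).sum(axis=0)
diffuse = np.clip(normal @ np.array([0., 0., 1.]), -1, 1)  # clipped, as in the module
```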
196 | def render_multiview(
197 |     renderer: Renderer,
198 |     camera_generation_method='sphere',
199 |     renderer_args: dict=None,
200 |     sampling_args: dict=None,
201 |     lighting_args: dict=None, # NOTE: currently unused
202 |     lookat_position=np.array([0, 0, 0]),
203 |     verbose=True,
204 | ) -> dict[str, list]: # NOTE: returns a dict of lists keyed by render output name, not list[Image.Image]
205 |     """Render the object from sampled viewpoints; 'matte' entries are PIL Images, and camera poses are returned under 'poses'.
206 |     """
207 |     lookat_position_torch = torch.from_numpy(lookat_position)
208 |     if camera_generation_method == 'sphere':
209 |         views = sample_view_matrices(lookat_position=lookat_position_torch, **sampling_args).numpy()
210 |     else:
211 |         views = sample_view_matrices_polyhedra(camera_generation_method, lookat_position=lookat_position_torch, **sampling_args).numpy()
212 | 
213 |     def compute_lightdir(pose: HomogeneousTransform) -> NumpyTensor[3]:
214 |         """Headlight direction pointing from the lookat target toward the camera.
215 |         """
216 |         lightdir = pose[:3, 3] - lookat_position
217 |         return lightdir / np.linalg.norm(lightdir)
218 | 
219 |     renders = []
220 |     if verbose:
221 |         views = tqdm(views, 'Rendering Multiviews...')
222 |     for pose in views:
223 |         outputs = renderer.render(pose, lightdir=compute_lightdir(pose), **renderer_args)
224 |         outputs['matte'] = Image.fromarray(outputs['matte'])
225 |         outputs['poses'] = pose
226 |         renders.append(outputs)
227 |     return {
228 |         name: [render[name] for render in renders] for name in renders[0].keys()
229 |     }
230 | 
231 | 
232 | if __name__ == '__main__':
233 |     from PIL import Image
234 |     from samesh.data.loaders import read_mesh,
read_scene, remove_texture, scene2mesh 235 | from samesh.models.shape_diameter_function import shape_diameter_function, colormap_shape_diameter_function, prep_mesh_shape_diameter_function 236 | 237 | ''' 238 | NOTE:: if you get ctypes.ArgumentError 239 | 240 | https://github.com/mmatl/pyrender/issues/284 241 | ''' 242 | name = 'potion' 243 | extension = 'glb' 244 | source1 = read_mesh(f'/home/ubuntu/data/BackflipMeshes/{name}.{extension}', norm=True) 245 | #source1 = remove_texture(source1, visual_kind='vertex') 246 | source1 = prep_mesh_shape_diameter_function(source1) 247 | source1 = colormap_shape_diameter_function(source1, shape_diameter_function(source1)) 248 | source1.export('test_mesh1.glb') 249 | source2 = read_scene(f'/home/ubuntu/data/BackflipMeshes/{name}.{extension}', norm=True) 250 | #source2 = remove_texture(source2, visual_kind='vertex') 251 | source2 = prep_mesh_shape_diameter_function(source2) 252 | source2 = colormap_shape_diameter_function(source2, shape_diameter_function(source2)) 253 | source2.export('test_mesh2.glb') 254 | source3 = read_scene('/home/ubuntu/meshseg/tests/examples/valve.glb', norm=True) 255 | source3 = remove_texture(source3, visual_kind='vertex') # remove texture for meshlab to export 256 | source3 = prep_mesh_shape_diameter_function(source3) 257 | source3 = colormap_shape_diameter_function(source3, shape_diameter_function(source3)) 258 | source3.export('test_mesh3.glb') 259 | 260 | pose = np.array([ 261 | [ 1, 0, 0, 0], 262 | [ 0, 1, 0, 0], 263 | [ 0, 0, 1, 2.5], 264 | [ 0, 0, 0, 1], 265 | ]) 266 | 267 | renderer = Renderer(OmegaConf.create({ 268 | 'target_dim': (1024, 1024), 269 | })) 270 | 271 | renderer.set_object(source1) 272 | renderer.set_camera() 273 | renders = renderer.render(pose) 274 | for k, v in renders.items(): 275 | print(k, v.shape) 276 | image = Image.fromarray(renders['matte']) 277 | image.save('test_matte_mesh.png') 278 | image_faceids = colormap_faces(renders['faces']) 279 | image_faceids.save('test_faceids_mesh.png') 280 | image_norms = colormap_norms(renders['norms']) 281 | image_norms.save('test_norms_mesh.png') 282 | 283 | renderer.set_object(source2, smooth=False) 284 | renderer.set_camera() 285 | renders = renderer.render(pose, interpolate_norms=True) 286 | for k, v in renders.items(): 287 | print(k, v.shape) 288 | image = Image.fromarray(renders['matte']) 289 | image.save('test_matte_scene.png') 290 | image_faceids = colormap_faces(renders['faces']) 291 | image_faceids.save('test_faceids_scene.png') 292 | image_norms = colormap_norms(renders['norms']) 293 | image_norms.save('test_norms_scene.png') 294 | 295 | renderer.set_object(source3, smooth=False) 296 | renderer.set_camera() 297 | renders = renderer.render(pose, interpolate_norms=True) 298 | for k, v in renders.items(): 299 | print(k, v.shape) 300 | image = Image.fromarray(renders['matte']) 301 | image.save('test_matte_objaverse.png') 302 | image_faceids = colormap_faces(renders['faces']) 303 | image_faceids.save('test_faceids_objaverse.png') 304 | image_norms = colormap_norms(renders['norms']) 305 | image_norms.save('test_norms_objaverse.png') -------------------------------------------------------------------------------- /src/samesh/renderer/renderer_animations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from trimesh import Trimesh, Scene 4 | from omegaconf import OmegaConf 5 | 6 | from samesh.data.loaders import scene2mesh, read_mesh 7 | from samesh.renderer.renderer import 
Renderer, render_multiview
8 | from samesh.utils.mesh import duplicate_verts
9 | 
10 | 
11 | def images2gif(images: list[Image.Image], path, duration=100, loop=0):
12 |     """Save an image sequence as a looping GIF; `duration` is the per-frame display time in milliseconds.
13 |     """
14 |     images[0].save(path, save_all=True, append_images=images[1:], duration=duration, loop=loop)
15 | 
16 | 
17 | def mesh2gif(
18 |     source: Scene | Trimesh, path: str, fps: int, length: int, size=1024, key='matte', colormap=None, **kwargs
19 | ):
20 |     """Render a swirl of views around the mesh and save them as a GIF.
21 |     """
22 |     if isinstance(source, Scene):
23 |         source = scene2mesh(source)
24 | 
25 |     renderer = Renderer(OmegaConf.create({'target_dim': (size, size)}))
26 |     renderer.set_object(source)
27 |     renderer.set_camera()
28 | 
29 |     duration = length / fps # total animation length in seconds
30 |     print(f'Rendering {length} frames at {fps} fps for {duration} s')
31 |     renders = render_multiview(
32 |         renderer,
33 |         camera_generation_method='swirl',
34 |         renderer_args=kwargs.pop('renderer_args', {}),
35 |         sampling_args=kwargs.pop('sampling_args', {'n': length, 'radius': 3}),
36 |         lighting_args=kwargs.pop('lighting_args', {}),
37 |     )
38 |     blend, background = kwargs.pop('blend', 0), kwargs.pop('background', 0) # pop once, outside the frame loop
39 |     if key == 'face_colors':
40 |         images = []
41 |         from samesh.renderer.renderer import colormap_faces
42 |         for faces, matte in zip(renders['faces'], renders['matte']):
43 |             image = source.visual.face_colors[faces]
44 |             image = (1 - blend) * image[:, :, :3] + blend * np.array(matte) # blend
45 |             image[faces == -1] = background
46 |             image = Image.fromarray(image.astype(np.uint8))
47 |             images.append(image)
48 | 
49 |     elif key == 'vertex_colors':
50 |         raise NotImplementedError
51 | 
52 |     else:
53 |         colormap = colormap or (lambda x: x)
54 |         images = [colormap(image) for image in renders[key]]
55 | 
56 |     images2gif(images, path, duration=1000 / fps) # PIL expects per-frame milliseconds; passing total seconds here broke frame timing
57 | 
58 | 
59 | if __name__ == '__main__':
60 |     source = read_mesh('/home/ubuntu/meshseg/tests/mesh_segmentation_output-0.075-3-0.5/basin/basin_segmented.glb', process=False)
61 |     mesh2gif(source, path='segmented.gif', fps=30, length=120, key='face_colors', blend=0.5)
--------------------------------------------------------------------------------
/src/samesh/renderer/shader_programs.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | import numpy as np
4 | import pyrender
5 | from trimesh import Trimesh
6 | 
7 | from samesh.data.common import NumpyTensor
8 | 
9 | SHADERS_PATH = os.path.join(os.path.dirname(__file__), 'shaders')
10 | 
11 | 
12 | class NormalShaderCache:
13 |     """Lazily compiled shader program that renders world-space normals.
14 |     """
15 |     def __init__(self):
16 |         self.program = None
17 | 
18 |     def get_program(self, vertex_shader, fragment_shader, geometry_shader=None, defines=None):
19 |         """Ignore the requested shaders and return the cached normal program.
20 |         """
21 |         self.program = self.program or pyrender.shader_program.ShaderProgram(
22 |             f'{SHADERS_PATH}/normal.vert',
23 |             f'{SHADERS_PATH}/normal.frag',
24 |             defines=defines
25 |         )
26 |         return self.program
27 | 
28 | 
29 | class BarycentricShaderCache:
30 |     """Lazily compiled shader program that renders per-face barycentric coordinates.
31 |     """
32 |     def __init__(self):
33 |         self.program = None
34 | 
35 |     def get_program(self, vertex_shader, fragment_shader, geometry_shader=None, defines=None):
36 |         """Ignore the requested shaders and return the cached barycentric program.
37 |         """
38 |         self.program = self.program or pyrender.shader_program.ShaderProgram(
39 |             f'{SHADERS_PATH}/barycentric.vert',
40 |             f'{SHADERS_PATH}/barycentric.frag',
41 |             defines=defines
42 |         )
43 |         return self.program
44 | 
45 | 
46 | class FaceidShaderCache:
47 |     """Lazily compiled shader program that renders RGB-packed face ids.
48 |     """
49 |     def __init__(self):
50 |         self.program = None
51 | 
52 |     def get_program(self, vertex_shader, fragment_shader, geometry_shader=None, defines=None):
53 |         """Ignore the requested shaders and return the cached face id program.
54 |         """
55 |         self.program = self.program or
pyrender.shader_program.ShaderProgram( 56 | f'{SHADERS_PATH}/faceid.vert', 57 | f'{SHADERS_PATH}/faceid.frag', 58 | defines=defines 59 | ) 60 | return self.program -------------------------------------------------------------------------------- /src/samesh/renderer/shaders/barycentric.frag: -------------------------------------------------------------------------------- 1 | #version 330 core 2 | 3 | in float frag_x; 4 | in float frag_y; 5 | 6 | out vec4 frag_color; 7 | 8 | void main() 9 | { 10 | vec3 color = vec3(frag_x, frag_y, 1.0 - frag_x - frag_y); 11 | 12 | frag_color = vec4(color, 1); 13 | } -------------------------------------------------------------------------------- /src/samesh/renderer/shaders/barycentric.vert: -------------------------------------------------------------------------------- 1 | #version 330 core 2 | 3 | // Vertex Attributes 4 | layout(location = 0) in vec3 position; 5 | layout(location = INST_M_LOC) in mat4 inst_m; 6 | 7 | // Uniforms 8 | uniform mat4 M; 9 | uniform mat4 V; 10 | uniform mat4 P; 11 | 12 | // Outputs 13 | out float frag_x; 14 | out float frag_y; 15 | 16 | void main() 17 | { 18 | gl_Position = P * V * M * inst_m * vec4(position, 1.0); 19 | 20 | // assumes mesh has been constructed so same face vertices are adjacent and duplicated for each face 21 | frag_x = float(((uint(gl_VertexID) + 1u) % 3u) % 2u); 22 | frag_y = float(((uint(gl_VertexID) + 0u) % 3u) % 2u); 23 | } -------------------------------------------------------------------------------- /src/samesh/renderer/shaders/faceid.frag: -------------------------------------------------------------------------------- 1 | #version 330 core 2 | 3 | flat in int faceid; 4 | 5 | out vec4 frag_color; 6 | 7 | void main() 8 | { 9 | int x = faceid / 65536; 10 | int r = faceid % 65536; 11 | int y = r / 256; 12 | int z = r % 256; 13 | // Create a color based on the normalized ID 14 | vec3 color = vec3(float(x / 255.0), float(y / 255.0), float(z / 255.0)); 15 | 16 | frag_color = vec4(color, 1); 17 | } -------------------------------------------------------------------------------- /src/samesh/renderer/shaders/faceid.vert: -------------------------------------------------------------------------------- 1 | #version 330 core 2 | 3 | // Vertex Attributes 4 | layout(location = 0) in vec3 position; 5 | layout(location = INST_M_LOC) in mat4 inst_m; 6 | 7 | // Uniforms 8 | uniform mat4 M; 9 | uniform mat4 V; 10 | uniform mat4 P; 11 | 12 | // Outputs 13 | flat out int faceid; 14 | 15 | void main() 16 | { 17 | gl_Position = P * V * M * inst_m * vec4(position, 1.0); 18 | 19 | // assumes mesh has been constructed so same face vertices are adjacent and duplicated for each face 20 | faceid = int(gl_VertexID / 3); 21 | } -------------------------------------------------------------------------------- /src/samesh/renderer/shaders/normal.frag: -------------------------------------------------------------------------------- 1 | #version 330 core 2 | 3 | in vec3 frag_position; 4 | in vec3 frag_normal; 5 | 6 | out vec4 frag_color; 7 | 8 | void main() 9 | { 10 | vec3 normal = normalize(frag_normal); 11 | 12 | frag_color = vec4(normal * 0.5 + 0.5, 1.0); 13 | } -------------------------------------------------------------------------------- /src/samesh/renderer/shaders/normal.vert: -------------------------------------------------------------------------------- 1 | #version 330 core 2 | 3 | // Vertex Attributes 4 | layout(location = 0) in vec3 position; 5 | layout(location = NORMAL_LOC) in vec3 normal; 6 | layout(location = 
INST_M_LOC) in mat4 inst_m; 7 | 8 | // Uniforms 9 | uniform mat4 M; 10 | uniform mat4 V; 11 | uniform mat4 P; 12 | 13 | // Outputs 14 | out vec3 frag_position; 15 | out vec3 frag_normal; 16 | 17 | void main() 18 | { 19 | gl_Position = P * V * M * inst_m * vec4(position, 1.0); 20 | frag_position = vec3(M * inst_m * vec4(position, 1.0)); 21 | 22 | mat4 N = transpose(inverse(M * inst_m)); 23 | frag_normal = normalize(vec3(N * vec4(normal, 0.0))); 24 | } -------------------------------------------------------------------------------- /src/samesh/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gtangg12/samesh/9417875e1ad43234a0b66d438dd5dae2b2a1edbe/src/samesh/utils/__init__.py -------------------------------------------------------------------------------- /src/samesh/utils/cameras.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from samesh.data.common import NumpyTensor, TorchTensor 6 | from samesh.utils.polyhedra import * 7 | 8 | 9 | HomogeneousTransform = NumpyTensor['b... 4 4'] | TorchTensor['b... 4 4'] 10 | 11 | 12 | def matrix3x4_to_4x4(matrix3x4: HomogeneousTransform) -> HomogeneousTransform: 13 | """ 14 | Convert a 3x4 transformation matrix to a 4x4 transformation matrix. 15 | """ 16 | bottom = torch.zeros_like(matrix3x4[:, 0, :].unsqueeze(-2)) 17 | bottom[..., -1] = 1 18 | return torch.cat([matrix3x4, bottom], dim=-2) 19 | 20 | 21 | def view_matrix( 22 | camera_position: TorchTensor['n... 3'], 23 | lookat_position: TorchTensor['n... 3'] = torch.tensor([0, 0, 0]), 24 | up : TorchTensor['3'] = torch.tensor([0, 1, 0]), 25 | ) -> HomogeneousTransform: 26 | """ 27 | Given lookat position, camera position, and up vector, compute cam2world poses. 28 | """ 29 | if camera_position.ndim == 1: 30 | camera_position = camera_position.unsqueeze(0) 31 | if lookat_position.ndim == 1: 32 | lookat_position = lookat_position.unsqueeze(0) 33 | camera_position = camera_position.float() 34 | lookat_position = lookat_position.float() 35 | 36 | cam_u = up.unsqueeze(0).repeat(len(lookat_position), 1).float().to(camera_position.device) 37 | 38 | # handle degenerate cases 39 | crossp = torch.abs(torch.cross(lookat_position - camera_position, cam_u, dim=-1)).max(dim=-1).values 40 | camera_position[crossp < 1e-6] += 1e-6 41 | 42 | cam_z = F.normalize((lookat_position - camera_position), dim=-1) 43 | cam_x = F.normalize(torch.cross(cam_z, cam_u, dim=-1), dim=-1) 44 | cam_y = F.normalize(torch.cross(cam_x, cam_z, dim=-1), dim=-1) 45 | poses = torch.stack([cam_x, cam_y, -cam_z, camera_position], dim=-1) # same as nerfstudio convention [right, up, -lookat] 46 | poses = matrix3x4_to_4x4(poses) 47 | return poses 48 | 49 | 50 | def sample_view_matrices(n: int, radius: float, lookat_position: TorchTensor['3']=torch.tensor([0, 0, 0])) -> HomogeneousTransform: 51 | """ 52 | Sample n uniformly distributed view matrices spherically with given radius. 
53 |     """
54 |     tht = torch.rand(n) * torch.pi * 2
55 |     phi = torch.rand(n) * torch.pi
56 |     world_x = radius * torch.sin(phi) * torch.cos(tht)
57 |     world_y = radius * torch.sin(phi) * torch.sin(tht)
58 |     world_z = radius * torch.cos(phi)
59 |     camera_position = torch.stack([world_x, world_y, world_z], dim=-1)
60 |     lookat_position = lookat_position.unsqueeze(0).repeat(n, 1)
61 |     return view_matrix(
62 |         camera_position.to(lookat_position.device),
63 |         lookat_position,
64 |         up=torch.tensor([0, 1, 0], device=lookat_position.device)
65 |     )
66 | 
67 | 
68 | def sample_view_matrices_polyhedra(polygon: str, radius: float, lookat_position: TorchTensor['3']=torch.tensor([0, 0, 0]), **kwargs) -> HomogeneousTransform:
69 |     """
70 |     Sample view matrices at the vertices of the named polyhedron (see samesh.utils.polyhedra), scaled to the given radius.
71 |     """
72 |     camera_position = torch.from_numpy(eval(polygon)(**kwargs)) * radius # resolves a viewpoint generator in samesh.utils.polyhedra by name
73 |     return view_matrix(
74 |         camera_position.to(lookat_position.device) + lookat_position,
75 |         lookat_position,
76 |         up=torch.tensor([0, 1, 0], device=lookat_position.device)
77 |     )
78 | 
79 | 
80 | def cam2world_opengl2pytorch3d(cam2world: HomogeneousTransform) -> HomogeneousTransform:
81 |     """
82 |     Convert OpenGL camera matrix to PyTorch3D camera matrix. Compare view_matrix function with
83 | 
84 |     https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/renderer/cameras.py#L1637
85 | 
86 |     for details regarding convention PyTorch3D uses.
87 |     """
88 |     if isinstance(cam2world, np.ndarray):
89 |         cam2world = torch.from_numpy(cam2world).float()
90 | 
91 |     world2cam = torch.zeros_like(cam2world)
92 |     world2cam[:3, :3] = cam2world[:3, :3]
93 |     world2cam[:3, 0] = -world2cam[:3, 0]
94 |     world2cam[:3, 2] = -world2cam[:3, 2]
95 |     world2cam[:3, 3] = -world2cam[:3, :3].T @ cam2world[:3, 3]
96 |     return world2cam
97 | 
98 | 
99 | if __name__ == '__main__':
100 |     m = view_matrix(
101 |         torch.tensor([0, 0, 1]).unsqueeze(0),
102 |         torch.tensor([0, 0, 0]).unsqueeze(0),
103 |     )
104 |     print(m)
105 | 
106 |     m = view_matrix(
107 |         torch.tensor([0, 0, 1]),
108 |         torch.tensor([0, 0, 0]),
109 |     )
110 |     print(m)
111 | 
112 |     for m in sample_view_matrices(10, 1):
113 |         print(m)
--------------------------------------------------------------------------------
/src/samesh/utils/math.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from samesh.data.common import TorchTensor
4 | 
5 | 
6 | def discretize(t: TorchTensor, lb=-1, ub=1, resolution=128) -> TorchTensor:
7 |     """
8 |     Given tensor of continuous values, return corresponding discrete logits.
9 |     """
10 |     return torch.round((t - lb) / (ub - lb) * resolution).long()
11 | 
12 | 
13 | def undiscretize(t: TorchTensor, lb=-1, ub=1, resolution=128) -> TorchTensor:
14 |     """
15 |     Given tensor of discrete logits, return corresponding continuous values.
16 |     """
17 |     return t.float() / resolution * (ub - lb) + lb
18 | 
19 | 
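discretize and undiscretize are inverses up to quantization error bounded by half a bin width. A quick round-trip sketch using the default bounds and resolution above:

```python
import torch

# Continuous values in [lb, ub] map to integer bins and back; the
# reconstruction error is at most one bin width (2 / resolution here).
t = torch.tensor([-1.0, -0.25, 0.0, 0.7, 1.0])
bins = discretize(t, lb=-1, ub=1, resolution=128)        # integer logits in [0, 128]
approx = undiscretize(bins, lb=-1, ub=1, resolution=128)
assert torch.all((approx - t).abs() <= 2 / 128)
```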
20 | def range_norm(t: TorchTensor, lb=None, ub=None, offset=None, eps=1e-8) -> TorchTensor:
21 |     """
22 |     Given tensor of continuous values, return corresponding range normalized values in [0, 1]; `offset` lowers the inferred minimum.
23 |     """
24 |     if lb is None: lb = t.min() - offset if offset else t.min()
25 |     if ub is None: ub = t.max()
26 |     return (t - lb) / (ub - lb + eps)
--------------------------------------------------------------------------------
/src/samesh/utils/mesh.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import trimesh
3 | from trimesh.base import Trimesh, Scene
4 | 
5 | from samesh.data.common import NumpyTensor, TorchTensor
6 | 
7 | 
8 | def duplicate_verts(mesh: Trimesh) -> Trimesh:
9 |     """
10 |     Call before coloring mesh to avoid face interpolation since openGL stores color attributes per vertex.
11 | 
12 |     ...
13 |     mesh = duplicate_verts(mesh)
14 |     mesh.visual.face_colors = colors
15 |     ...
16 | 
17 |     NOTE: removes visuals for vertices, but preserves for faces.
18 |     """
19 |     verts = mesh.vertices[mesh.faces.reshape(-1), :]
20 |     faces = np.arange(0, verts.shape[0])
21 |     faces = faces.reshape(-1, 3)
22 |     return Trimesh(vertices=verts, faces=faces, face_colors=mesh.visual.face_colors, process=False)
23 | 
24 | 
25 | def handle_pose(pose: NumpyTensor['4 4']) -> NumpyTensor['4 4']:
26 |     """
27 |     Handles common case that results in numerical instability in rendering faceids:
28 | 
29 |     ...
30 |     pose, _ = scene.graph[name]
31 |     pose = handle_pose(pose)
32 |     ...
33 |     """
34 |     identity = np.eye(4)
35 |     if np.allclose(pose, identity, atol=1e-6):
36 |         return identity
37 |     return pose
38 | 
39 | 
40 | def transform(pose: NumpyTensor['4 4'], vertices: NumpyTensor['nv 3']) -> NumpyTensor['nv 3']:
41 |     """Apply a homogeneous transform to an array of vertices.
42 |     """
43 |     homogeneous = np.concatenate([vertices, np.ones((vertices.shape[0], 1))], axis=1)
44 |     return (pose @ homogeneous.T).T[:, :3]
45 | 
46 | 
47 | def concat_scene_vertices(scene: Scene) -> NumpyTensor['nv 3']:
48 |     """Concatenate world-space vertices of all geometries. NOTE:: bakes node poses into geometry in place.
49 |     """
50 |     verts = []
51 |     for name, geom in scene.geometry.items():
52 |         if name in scene.graph:
53 |             pose, _ = scene.graph[name]
54 |             pose = handle_pose(pose)
55 |             geom.vertices = transform(pose, geom.vertices)
56 |         verts.append(geom.vertices)
57 |     return np.concatenate(verts)
58 | 
59 | 
60 | def bounding_box(vertices: NumpyTensor['n 3']) -> NumpyTensor['2 3']:
61 |     """
62 |     Compute bounding box from vertices.
63 |     """
64 |     return np.array([vertices.min(axis=0), vertices.max(axis=0)])
65 | 
66 | 
67 | def bounding_box_centroid(vertices: NumpyTensor['n 3']) -> NumpyTensor['3']:
68 |     """
69 |     Compute bounding box centroid from vertices.
70 |     """
71 |     return bounding_box(vertices).mean(axis=0)
72 | 
73 | 
74 | def norm_mesh(mesh: Trimesh) -> Trimesh:
75 |     """
76 |     Normalize mesh vertices to bounding box [-1, 1].
77 | 
78 |     NOTE:: In place operation that consumes mesh.
79 |     """
80 |     centroid = bounding_box_centroid(mesh.vertices)
81 |     mesh.vertices -= centroid
82 |     mesh.vertices /= np.abs(mesh.vertices).max()
83 |     mesh.vertices *= (1 - 1e-3)
84 |     return mesh
85 | 
86 | 
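A sanity sketch for the two helpers most callers combine: center and scale a primitive into [-1, 1]^3 with norm_mesh, then duplicate vertices so per-face colors render without interpolation. The box primitive is a placeholder:

```python
import numpy as np
import trimesh

mesh = trimesh.creation.box(extents=(4.0, 1.0, 1.0))  # placeholder shape
mesh = norm_mesh(mesh)
assert np.abs(mesh.vertices).max() <= 1.0

mesh.visual.face_colors = np.tile([255, 0, 0, 255], (len(mesh.faces), 1))
mesh = duplicate_verts(mesh)  # 3 unique vertices per face, face colors preserved
assert len(mesh.vertices) == 3 * len(mesh.faces)
```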
87 | def norm_scene(scene: Scene) -> Scene:
88 |     """
89 |     Normalize scene vertices to bounding box [-1, 1].
90 | 
91 |     NOTE:: In place operation that consumes scene.
92 |     """
93 |     centroid = bounding_box_centroid(concat_scene_vertices(scene))
94 |     for geom in scene.geometry.values():
95 |         geom.vertices -= centroid
96 |     extent = np.abs(concat_scene_vertices(scene)).max()
97 |     for geom in scene.geometry.values():
98 |         geom.vertices /= extent
99 |         geom.vertices *= (1 - 1e-3)
100 |     return scene
101 | 
102 | 
103 | if __name__ == "__main__":
104 |     from samesh.data.loaders import read_mesh
105 |     mesh = read_mesh('/home/ubuntu/meshseg/tests/examples/0ba4ae3aa97b4298866a2903de4fd1e7.glb')
106 | 
107 |     mesh.export('/home/ubuntu/meshseg/tests/examples/0ba4ae3aa97b4298866a2903de4fd1e7.obj')
108 |     print(mesh.faces)
109 |     print(mesh.vertices[mesh.faces[:, 0]])
110 |     #mesh = order_faces(mesh) # NOTE: order_faces is not defined in this package; disabled so the demo runs
111 |     print(mesh.faces)
112 |     print(mesh.vertices[mesh.faces[:, 0]])
113 | 
114 |     mesh.export('/home/ubuntu/meshseg/tests/examples/0ba4ae3aa97b4298866a2903de4fd1e7_sorted.obj')
115 |     print(mesh.vertices.max(), mesh.vertices.min())
116 |     mesh = norm_mesh(mesh)
117 |     print(mesh.vertices.max(), mesh.vertices.min())
118 |     mesh.export('/home/ubuntu/meshseg/tests/examples/0ba4ae3aa97b4298866a2903de4fd1e7_norm.obj')
119 | 
120 |     print(mesh.vertices.shape)
121 |     print(mesh.faces.shape)
122 |     print(mesh.vertices[mesh.faces].shape)
--------------------------------------------------------------------------------
/src/samesh/utils/polyhedra.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | 
5 | def golden_ratio():
6 |     return (1 + np.sqrt(5)) / 2
7 | 
8 | 
9 | def tetrahedron():
10 |     return np.array([
11 |         [ 1, 1, 1],
12 |         [-1, -1, 1],
13 |         [-1, 1, -1],
14 |         [ 1, -1, -1],
15 |     ])
16 | 
17 | 
18 | def octohedron(): # [sic] spelling kept: callers resolve these generators by name via sample_view_matrices_polyhedra
19 |     return np.array([
20 |         [ 1, 0, 0],
21 |         [ 0, 0, 1],
22 |         [-1, 0, 0],
23 |         [ 0, 0, -1],
24 |         [ 0, 1, 0],
25 |         [ 0, -1, 0],
26 |     ])
27 | 
28 | 
29 | def cube():
30 |     return np.array([
31 |         [ 1, 1, 1],
32 |         [-1, 1, 1],
33 |         [-1, -1, 1],
34 |         [ 1, -1, 1],
35 |         [ 1, 1, -1],
36 |         [-1, 1, -1],
37 |         [-1, -1, -1],
38 |         [ 1, -1, -1],
39 |     ])
40 | 
41 | 
42 | def icosahedron():
43 |     phi = golden_ratio()
44 |     return np.array([
45 |         [-1, phi, 0],
46 |         [-1, -phi, 0],
47 |         [ 1, phi, 0],
48 |         [ 1, -phi, 0],
49 |         [ 0, -1, phi],
50 |         [ 0, 1, phi],
51 |         [ 0, -1, -phi],
52 |         [ 0, 1, -phi],
53 |         [ phi, 0, -1],
54 |         [ phi, 0, 1],
55 |         [-phi, 0, -1],
56 |         [-phi, 0, 1],
57 |     ]) / np.sqrt(1 + phi ** 2)
58 | 
59 | 
60 | def dodecahedron():
61 |     phi = golden_ratio()
62 |     a, b = 1 / phi, 1 / (phi * phi)
63 |     return np.array([
64 |         [-a, -a, b], [ a, -a, b], [ a, a, b], [-a, a, b],
65 |         [-a, -a, -b], [ a, -a, -b], [ a, a, -b], [-a, a, -b],
66 |         [ b, -a, -a], [ b, a, -a], [ b, a, a], [ b, -a, a],
67 |         [-b, -a, -a], [-b, a, -a], [-b, a, a], [-b, -a, a],
68 |         [-a, b, -a], [ a, b, -a], [ a, b, a], [-a, b, a],
69 |     ]) / np.sqrt(a ** 2 + b ** 2)
70 | 
71 | 
72 | def standard(n=8, elevation=15):
73 |     """Two rings of n viewpoints at +/- elevation degrees, plus the top and bottom poles.
74 |     """
75 |     pphi = elevation * np.pi / 180
76 |     nphi = -elevation * np.pi / 180
77 |     coords = []
78 |     for phi in [pphi, nphi]:
79 |         for theta in np.linspace(0, 2 * np.pi, n, endpoint=False):
80 |             coords.append([
81 |                 np.cos(theta) * np.cos(phi),
82 |                 np.sin(phi),
83 |                 np.sin(theta) * np.cos(phi),
84 |             ])
85 |     coords.append([0, 0, 1])
86 |     coords.append([0, 0, -1])
87 |     return np.array(coords)
88 | 
89 | 
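A sketch of how these vertex sets are consumed: each generator name becomes a set of cam2world poses looking at the origin, since sample_view_matrices_polyhedra resolves generators by string. Shape assertions follow from the definitions above:

```python
from samesh.utils.cameras import sample_view_matrices_polyhedra

views = sample_view_matrices_polyhedra('icosahedron', radius=3.0)
assert views.shape == (12, 4, 4)  # one pose per icosahedron vertex

views = sample_view_matrices_polyhedra('standard', radius=3.0, n=8, elevation=15)
assert views.shape == (18, 4, 4)  # two rings of 8 plus both poles
```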
90 | def swirl(n=120, cycles=1, elevation_range=(-45, 60)):
91 |     """Spiral of n viewpoints whose elevation sweeps elevation_range while the azimuth makes `cycles` full revolutions.
92 |     """
93 |     pphi = elevation_range[0] * np.pi / 180
94 |     nphi = elevation_range[1] * np.pi / 180
95 |     thetas = np.linspace(0, 2 * np.pi, n, endpoint=False)
96 |     coords = []
97 |     for i, phi in enumerate(np.linspace(pphi, nphi, n)):
98 |         coords.append([
99 |             np.cos(cycles * thetas[i]) * np.cos(phi),
100 |             np.sin(phi),
101 |             np.sin(cycles * thetas[i]) * np.cos(phi),
102 |         ])
103 |     return np.array(coords)
--------------------------------------------------------------------------------