├── .github ├── 021_bleach_cleanser.gif ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── interactive_filter.gif ├── live_demo_mug.gif ├── logo.svg ├── meta_ai.jpeg ├── power_drill_manual_slide.png ├── power_drill_train_data.png ├── power_drill_ycb_slide.png └── rpl.png ├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── download_assets.sh ├── environment.yml ├── midastouch ├── __init__.py ├── bash │ ├── generate_codebooks.sh │ └── run_filter.sh ├── config │ ├── config.yaml │ ├── expt │ │ ├── mcmaster.yaml │ │ └── ycb.yaml │ ├── tcn │ │ └── default.yaml │ └── tdn │ │ └── default.yaml ├── contrib │ ├── tcn_minkloc │ │ ├── LICENSE │ │ ├── README.md │ │ ├── __init__.py │ │ ├── minkfpn.py │ │ ├── minkloc.py │ │ ├── resnet.py │ │ ├── tcn.py │ │ └── utils.py │ └── tdn_fcrn │ │ ├── README.md │ │ ├── __init__.py │ │ ├── config │ │ ├── test.yaml │ │ └── train.yaml │ │ ├── data │ │ ├── data_to_txt.py │ │ └── data_to_txt_real.py │ │ ├── data_loader.py │ │ ├── fcrn.py │ │ ├── flow_transforms.py │ │ ├── tdn.py │ │ ├── test.py │ │ ├── train.py │ │ └── weights.py ├── data_gen │ ├── README.md │ ├── __init__.py │ ├── config │ │ ├── config.yaml │ │ └── method │ │ │ ├── manual_slide.yaml │ │ │ ├── train_data.yaml │ │ │ └── ycb_slide.yaml │ ├── generate_data.py │ ├── touch_simulator.py │ └── utils.py ├── eval │ ├── __init__.py │ ├── compute_contact_area.py │ ├── compute_surface_area.py │ ├── decimate_meshes.py │ ├── single_touch_test.py │ └── viz_codebook.py ├── filter │ ├── __init__.py │ ├── filter.py │ ├── filter_real.py │ └── live_demo.py ├── modules │ ├── __init__.py │ ├── mesh.py │ ├── misc.py │ ├── objects.py │ ├── particle_filter.py │ └── pose.py ├── render │ ├── __init__.py │ └── digit_renderer.py ├── tactile_tree │ ├── __init__.py │ ├── build_codebook.py │ ├── process_codebook.py │ ├── tactile_tree.py │ └── test_codebook.py └── viz │ ├── __init__.py │ ├── demo_visualizer.py │ ├── helpers.py │ └── visualizer.py └── setup.py /.github/021_bleach_cleanser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/MidasTouch/9afb6ff72b837ab69f140695d38bf95c498b30c3/.github/021_bleach_cleanser.gif -------------------------------------------------------------------------------- /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to MidasTouch 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to MidasTouch, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 
-------------------------------------------------------------------------------- /.github/interactive_filter.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/MidasTouch/9afb6ff72b837ab69f140695d38bf95c498b30c3/.github/interactive_filter.gif -------------------------------------------------------------------------------- /.github/live_demo_mug.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/MidasTouch/9afb6ff72b837ab69f140695d38bf95c498b30c3/.github/live_demo_mug.gif -------------------------------------------------------------------------------- /.github/meta_ai.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/MidasTouch/9afb6ff72b837ab69f140695d38bf95c498b30c3/.github/meta_ai.jpeg -------------------------------------------------------------------------------- /.github/power_drill_manual_slide.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/MidasTouch/9afb6ff72b837ab69f140695d38bf95c498b30c3/.github/power_drill_manual_slide.png -------------------------------------------------------------------------------- /.github/power_drill_train_data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/MidasTouch/9afb6ff72b837ab69f140695d38bf95c498b30c3/.github/power_drill_train_data.png -------------------------------------------------------------------------------- /.github/power_drill_ycb_slide.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/MidasTouch/9afb6ff72b837ab69f140695d38bf95c498b30c3/.github/power_drill_ycb_slide.png -------------------------------------------------------------------------------- /.github/rpl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/MidasTouch/9afb6ff72b837ab69f140695d38bf95c498b30c3/.github/rpl.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### folders ### 2 | env/* 3 | .vscode/* 4 | .ipynb_checkpoints 5 | outputs/ 6 | multirun/ 7 | midastouch/tactile_tree/data 8 | midastouch/model_weights 9 | midastouch/contrib/tdn_fcrn/data/*.txt 10 | midastouch/contrib/tdn_fcrn/train_log/ 11 | midastouch/deprecated 12 | train_data/ 13 | *.code-workspace 14 | *.diff 15 | 16 | core* 17 | 18 | ### dev chosen ### 19 | # *.gif 20 | # *.mat 21 | *.zip 22 | *.pyc 23 | *.jpeg 24 | *.png 25 | *.pdf 26 | __pycache__/ 27 | *.egg-info 28 | *.prof 29 | 30 | ### Git ### 31 | *.orig 32 | *.csv 33 | *.gif 34 | ### Linux ### 35 | *~ 36 | 37 | # temporary files which can be created if a process still has a handle open of a deleted file 38 | .fuse_hidden* 39 | 40 | # KDE directory preferences 41 | .directory 42 | 43 | # Linux trash folder which might appear on any partition or disk 44 | .Trash-* 45 | 46 | # .nfs files are created when an open file is removed but is still being accessed 47 | .nfs* 48 | 49 | ### macOS ### 50 | *.DS_Store 51 | .AppleDouble 52 | .LSOverride 53 | 54 | # Icon must end with two \r 55 | Icon 56 | 57 | # Thumbnails 58 | ._* 59 | 60 | 
# Files that might appear in the root of a volume 61 | .DocumentRevisions-V100 62 | .fseventsd 63 | .Spotlight-V100 64 | .TemporaryItems 65 | .Trashes 66 | .VolumeIcon.icns 67 | .com.apple.timemachine.donotpresent 68 | 69 | # Directories potentially created on remote AFP share 70 | .AppleDB 71 | .AppleDesktop 72 | Network Trash Folder 73 | Temporary Items 74 | .apdisk 75 | 76 | ### Matlab ### 77 | ##--------------------------------------------------- 78 | ## Remove autosaves generated by the Matlab editor 79 | ## We have git for backups! 80 | ##--------------------------------------------------- 81 | 82 | # Windows default autosave extension 83 | *.asv 84 | 85 | # OSX / *nix default autosave extension 86 | *.m~ 87 | 88 | # Compiled MEX binaries (all platforms) 89 | *.mex* 90 | 91 | # Simulink Code Generation 92 | slprj/ 93 | 94 | # Session info 95 | octave-workspace 96 | 97 | # Simulink autosave extension 98 | *.autosave 99 | 100 | ### Windows ### 101 | # Windows thumbnail cache files 102 | Thumbs.db 103 | ehthumbs.db 104 | ehthumbs_vista.db 105 | 106 | # Folder config file 107 | Desktop.ini 108 | 109 | # Recycle Bin used on file shares 110 | $RECYCLE.BIN/ 111 | 112 | # Windows Installer files 113 | *.cab 114 | *.msi 115 | *.msm 116 | *.msp 117 | 118 | # Windows shortcuts 119 | *.lnk 120 | 121 | # pre-trained models 122 | *.tar 123 | *.pth -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "YCB-Slide"] 2 | path = YCB-Slide 3 | url = https://github.com/rpl-cmu/YCB-Slide.git 4 | branch = master 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | MIT License 3 | 4 | Copyright (c) Meta Platforms, Inc. and affiliates. 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

3 | MidasTouch 4 |

5 |

6 | Monte-Carlo inference over distributions across sliding touch 7 |

8 | 9 | 10 |
11 | Sudharshan Suresh  •  12 | Zilin Si  •  13 | Stuart Anderson  •  14 | Michael Kaess  •  15 | Mustafa Mukadam 16 |
17 | 6th Annual Conference on Robot Learning (CoRL) 2022 18 |
19 | 20 |

21 | Website  •  22 | Paper  •  23 | Presentation  •  24 | YCB-Slide 25 |

26 | 27 |
28 | TL;DR: We track the pose distribution of a robot finger on an
29 | object's surface using geometry captured by a tactile sensor 30 |

31 | 32 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)   [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)        Meta-AI    rpl 33 |
34 | 35 | [cc-by-sa]: http://creativecommons.org/licenses/by-sa/4.0/ 36 | [cc-by-sa-shield]: https://img.shields.io/badge/License-CC%20BY--SA%204.0-lightgrey.svg 37 | 38 | 39 | MidasTouch performs online global localization of a vision-based touch sensor on an object surface during sliding interactions. For details and further results, refer to our website and paper. 40 | 41 |
42 | 44 |
45 | 46 | 47 | 48 | ## Setup 49 | 50 | ### 1. Clone repository 51 | 52 | ```bash 53 | git clone git@github.com:facebookresearch/MidasTouch.git 54 | git submodule update --init --recursive 55 | ``` 56 | ### 2. Download YCB-Slide dataset 57 | ```bash 58 | cd YCB-Slide 59 | chmod +x download_dataset.sh && ./download_dataset.sh # requires gdown 60 | cd .. 61 | ``` 62 | ### 3. Download weights/codebooks 63 | ```bash 64 | chmod +x download_assets.sh && ./download_assets.sh 65 | ``` 66 | ### 4. Setup midastouch conda env 67 | ```bash 68 | sudo apt install build-essential python3-dev libopenblas-dev 69 | conda create -n midastouch 70 | conda activate midastouch 71 | conda env update --file environment.yml --prune 72 | conda install pytorch torchvision torchaudio cudatoolkit pytorch-cuda=11.7 -c pytorch -c nvidia # install torch 73 | conda install -c conda-forge cudatoolkit-dev 74 | pip install theseus-ai 75 | pip install -e . 76 | ``` 77 | 78 | 79 | #### Known issues 80 | ```ImportError: cannot import name 'gcd' from 'fractions' (/private/home/suddhu/.conda/envs/midastouch/lib/python3.9/fractions.py)``` 81 | 82 | ```bash 83 | conda install -c conda-forge networkx=2.5 84 | ``` 85 | ### 5. Install PyTorch and the MinkowskiEngine 86 | 87 |       Follow [the conda instructions](https://github.com/NVIDIA/MinkowskiEngine#anaconda) from the NVIDIA MinkowskiEngine webpage 88 | 89 | 90 | 91 | ## Run MidasTouch 92 | 93 | Run interactive filtering experiments with our YCB-Slide data from both the simulated and real-world tactile interactions. 94 | 95 |
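Before launching a filter run, it can help to confirm that the stack from setup steps 4-5 imports cleanly; this quick check is a suggestion and not part of the original instructions:

```bash
conda activate midastouch
# both imports should succeed, and CUDA should be visible for GPU filtering
python -c "import torch, MinkowskiEngine as ME; print(torch.__version__, ME.__version__, torch.cuda.is_available())"
```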
96 | 98 |
99 | 100 | 101 | ### TACTO simulation trajectories 102 | ```python 103 | python midastouch/filter/filter.py expt=ycb # default: 004_sugar_box log 0 104 | python midastouch/filter/filter.py expt.obj_model=035_power_drill expt.log_id=3 # 035_power_drill log 3 105 | python midastouch/filter/filter.py expt.off_screen=True # disable visualization 106 | python midastouch/filter/filter.py expt=mcmaster # small parts: cotter-pin log 0 107 | ``` 108 | 109 | ### Real-world trajectories 110 | 111 | ```python 112 | python midastouch/filter/filter_real.py expt=ycb # default: 004_sugar_box log 0 113 | python midastouch/filter/filter_real.py expt.obj_model=021_bleach_cleanser expt.log_id=2 # 021_bleach_cleanser log 2 114 | ``` 115 | 116 | 117 | 118 | ## Codebook live demo 119 | 120 | With your own [DIGIT](https://digit.ml/), you can simply plug in the sensor and experiment with the image-to-3D and tactile code visualizer. 121 | 122 | 123 | ```python 124 | python midastouch/filter/live_demo.py expt.obj_model=025_mug 125 | ``` 126 | 127 |
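If the live demo cannot find your sensor, a minimal sketch using the `digit-interface` package (installed via `environment.yml`) can verify the video stream first; the serial number below is a placeholder for your own device:

```python
from digit_interface import Digit

digit = Digit("D20001")    # placeholder serial; use your DIGIT's serial number
digit.connect()            # open the USB video stream
frame = digit.get_frame()  # a single frame as a numpy array
print(frame.shape)
digit.disconnect()
```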
128 | 130 |
131 | 132 | 133 | 134 | ## Folder structure 135 | ```bash 136 | midastouch 137 | ├── bash # bash scripts for filtering, codebook generation 138 | ├── config # hydra config files 139 | ├── contrib # modified third-party code for TDN, TCN 140 | ├── data_gen # Generate tactile simulation data for training/eval 141 | ├── eval # select evaluation scripts 142 | ├── filter # filtering and live demo scripts 143 | ├── modules # helper functions and classes 144 | ├── render # DIGIT tactile rendering class 145 | ├── tactile_tree # codebook scripts 146 | └── viz # pyvista visualization 147 | ``` 148 | 149 | ## Data generation scripts 150 | 151 | - To generate your own tactile simulation data on object meshes, refer to the `midastouch/data_gen/` scripts. 152 | - To collect tactile data in the real-world, refer to our experimental scripts in the [YCB-Slide repository](https://github.com/rpl-cmu/YCB-Slide). 153 | 154 |
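As a starting point, a hypothetical invocation of the data generation entry point is sketched below; the `method` group names come from the configs in `midastouch/data_gen/config/method/`, and any further overrides should follow those files rather than this sketch:

```bash
# sketch only: pick a method config (manual_slide, train_data, ycb_slide)
python midastouch/data_gen/generate_data.py method=ycb_slide
```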
155 |             157 | 159 |
160 | 161 | ## Bibtex 162 | 163 | ``` 164 | @inproceedings{suresh2022midastouch, 165 | title={{M}idas{T}ouch: {M}onte-{C}arlo inference over distributions across sliding touch}, 166 | author={Suresh, Sudharshan and Si, Zilin and Anderson, Stuart and Kaess, Michael and Mukadam, Mustafa}, 167 | booktitle = {Proc. Conf. on Robot Learning, CoRL}, 168 | address = {Auckland, NZ}, 169 | month = dec, 170 | year = {2022} 171 | } 172 | ``` 173 | 174 | 175 | ## License 176 | 177 | The majority of MidasTouch is licensed under MIT license, however portions of the project are available under separate license terms: MinkLoc3D is licensed under the MIT license; FCRN-DepthPrediction is licensed under the BSD 2-clause license; pytorch3d is licensed under the BSD 3-clause license. Please see the [LICENSE](LICENSE) file for more information. 178 | 179 | 180 | 181 | ## Contributing 182 | 183 | We actively welcome your pull requests! Please see [CONTRIBUTING.md](.github/CONTRIBUTING.md) and [CODE_OF_CONDUCT.md](.github/CODE_OF_CONDUCT.md) for more info. -------------------------------------------------------------------------------- /download_assets.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | cd midastouch 9 | echo "Downloading the TDN/TCN model weights" 10 | gdown https://drive.google.com/drive/folders/1Zy1yFJl3-3Q3Ms0NWb2aTZDXMdEj6dW9?usp=sharing --folder 11 | cd tactile_tree 12 | echo "Downloading the YCB tactile codebooks" 13 | gdown --fuzzy https://drive.google.com/file/d/165Bj9eqVJ0As5vitIoPAOU-9Jij01jpB/view?usp=sharing 14 | unzip codebooks.zip 15 | rm codebooks.zip 16 | cd ../.. 17 | echo "Done!" 
18 | 19 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: midastouch 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - configargparse 9 | - imageio 10 | - imageio-ffmpeg 11 | - matplotlib 12 | - openblas-devel 13 | - pip 14 | - pyglet 15 | - pyvista 16 | - scikit-image 17 | - scikit-learn 18 | - tensorboard 19 | - tqdm 20 | - transforms3d 21 | - suitesparse 22 | - pip: 23 | - addict 24 | - antlr4-python3-runtime 25 | - anyio 26 | - argon2-cffi-bindings 27 | - argon2-cffi 28 | - asttokens 29 | - babel 30 | - backcall 31 | - beautifulsoup4 32 | - bleach 33 | - debugpy 34 | - defusedxml 35 | - deprecation 36 | - digit-interface 37 | - dill 38 | - entrypoints 39 | - executing 40 | - fastjsonschema 41 | - ffmpeg-python 42 | - freetype-py 43 | - gdown 44 | - git+https://github.com/suddhu/tacto.git@master 45 | - git+https://github.com/u1234x1234/pynanoflann.git@0.0.9 46 | - gitpython 47 | - gputil 48 | - hydra-core 49 | - ipykernel 50 | - ipython-genutils 51 | - ipython 52 | - ipywidgets 53 | - jedi 54 | - jinja2 55 | - json5 56 | - jsonschema 57 | - jupyter-client 58 | - jupyter-core 59 | - jupyter-packaging 60 | - jupyter-server 61 | - jupyterlab-pygments 62 | - jupyterlab-server 63 | - jupyterlab-widgets 64 | - jupyterlab 65 | - lxml 66 | - markupsafe 67 | - matplotlib-inline 68 | - mesh_to_sdf 69 | - mistune 70 | - nbclassic 71 | - nbclient 72 | - nbconvert 73 | - nbformat 74 | - nest-asyncio 75 | - networkx 76 | - ninja 77 | - notebook-shim 78 | - notebook 79 | - omegaconf 80 | - open3d 81 | - opencv-python 82 | - pandas 83 | - pandocfilters 84 | - parso 85 | - pexpect 86 | - pickleshare 87 | - pillow 88 | - potpourri3d 89 | - prometheus-client 90 | - prompt-toolkit 91 | - psutil 92 | - ptyprocess 93 | - pure-eval 94 | - pybind11 95 | - pycollada 96 | - pygments 97 | - pyopengl 98 | - PyQt5 99 | - pyquaternion 100 | - pyrender 101 | - pyrsistent 102 | - pytz 103 | - pyvistaqt 104 | - pyzmq 105 | - qtpy 106 | - rtree 107 | - seaborn 108 | - send2trash 109 | - sniffio 110 | - soupsieve 111 | - stack-data 112 | - terminado 113 | - tinycss2 114 | - tomlkit 115 | - traitlets 116 | - trimesh 117 | - urdfpy 118 | - wcwidth 119 | - webencodings 120 | - websocket-client 121 | - widgetsnbextension 122 | - yappi -------------------------------------------------------------------------------- /midastouch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/MidasTouch/9afb6ff72b837ab69f140695d38bf95c498b30c3/midastouch/__init__.py -------------------------------------------------------------------------------- /midastouch/bash/generate_codebooks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | # Generate tactile codebook for all YCB objects 9 | 10 | 11 | declare -a objModels=("004_sugar_box" "005_tomato_soup_can" "006_mustard_bottle" "021_bleach_cleanser" "025_mug" "035_power_drill" "037_scissors" "042_adjustable_wrench" "048_hammer" "055_baseball") 12 | 13 | for obj in ${objModels[@]}; do 14 | python midastouch/tactile_tree/build_codebook.py expt.obj_model=$obj tdn.render.pen.max=0.001 expt.codebook_size=50000 15 | done 16 | -------------------------------------------------------------------------------- /midastouch/bash/run_filter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | # Run filtering experiments for all YCB objects 9 | 10 | declare -a objModels=("004_sugar_box" "005_tomato_soup_can" "006_mustard_bottle" "021_bleach_cleanser" "025_mug" "035_power_drill" "037_scissors" "042_adjustable_wrench" "048_hammer" "055_baseball") 11 | # declare -a objModels=("cotter-pin" "steel-nail" "eyebolt") 12 | 13 | for log in {0..4}; do 14 | for obj in ${objModels[@]}; do 15 | python midastouch/filter/filter.py expt.obj_model=$obj expt.log_id=$log 16 | # python midastouch/filter/filter_real.py expt.obj_model=$obj expt.log_id=$log 17 | done 18 | done 19 | -------------------------------------------------------------------------------- /midastouch/config/config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # root config for MidasTouch 7 | 8 | defaults: 9 | - expt: ycb # main experimental params 10 | - tcn: default # tactile code network params 11 | - tdn: default # tactile depth network params -------------------------------------------------------------------------------- /midastouch/config/expt/mcmaster.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # config file for McMaster dataset objects 7 | 8 | obj_model : cotter-pin # cotter-pin, steel-nail, eyebolt 9 | log_id : 0 # Log ID 10 | ablation : False 11 | frame_rate : 10 12 | off_screen : False 13 | render : True 14 | max_length : None 15 | codebook_size : 50000 16 | 17 | params: 18 | num_particles : 5000 19 | noise_r : 0.5 20 | noise_t : 1e-4 21 | noise_ratio : 1.0 22 | interval : 5 23 | -------------------------------------------------------------------------------- /midastouch/config/expt/ycb.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | # Config for MidasTouch experimental setup 7 | 8 | obj_model : 004_sugar_box # Object model 9 | log_id : 0 # Log ID 10 | ablation : False 11 | frame_rate : 10 12 | off_screen : False 13 | render : True 14 | max_length : None 15 | codebook_size : 50000 16 | 17 | params: 18 | num_particles : 50000 19 | noise_r : 20 | sim: 0.5 21 | real: 0.5 22 | noise_t : 23 | sim: 2e-4 24 | real: 2e-4 25 | noise_ratio : 1.0 -------------------------------------------------------------------------------- /midastouch/config/tcn/default.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Config for tactile codes network 7 | 8 | model: 9 | tcn_weights : tcn_weights.pth.tar 10 | model : MinkFPN 11 | num_points : 4096 12 | batch_size : 100 13 | mink_quantization_size : 0.001 14 | planes : 32,64,64 15 | layers : 1,1,1 16 | num_top_down : 1 17 | conv0_kernel_size : 5 18 | feature_size : 256 19 | output_dim : 256 20 | 21 | train: 22 | num_workers : 8 23 | batch_size : 8 24 | val_batch_size : 64 25 | batch_size_limit : 64 26 | batch_expansion_rate : 1.4 27 | batch_expansion_th : 0.7 28 | max_batches : 1000 29 | # final_block : fc 30 | 31 | lr : 1e-7 32 | image_lr : 1e-4 33 | epochs : 100 34 | 35 | scheduler_milestones : 30, 50, 70 36 | 37 | scheduler : MultiStepLR 38 | min_lr : 1e-6 39 | optimizer : Adam 40 | 41 | aug_mode : 1 42 | weight_decay : 1e-4 43 | 44 | loss : BatchHardTripletMarginLoss 45 | weights : 1.0, 0.0, 0.0 46 | normalize_embeddings : True 47 | margin : 0.2 48 | 49 | pos_margin : 0.2 50 | neg_margin : 0.65 51 | 52 | train_file : train_sets_tacto_40.pickle 53 | val_file : val_sets_tacto_40.pickle 54 | 55 | dataset_folder : /mnt/sda/suddhu/minkloc/minkloc_data 56 | val_folder : /mnt/sda/suddhu/minkloc/minkloc_val 57 | eval_folder : /mnt/sda/suddhu/fcrn/fcrn_eval -------------------------------------------------------------------------------- /midastouch/config/tdn/default.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | # Config for tactile depth network 7 | 8 | tdn_weights : tdn_weights.pth.tar 9 | 10 | render: 11 | pixmm : 0.03 12 | width : 240 13 | height : 320 14 | cam_dist : 0.022 15 | shear_mag : 5.0 16 | pen : 17 | min : 0.0005 18 | max : 0.002 19 | 20 | fcrn: 21 | real: 22 | blend_sz : 10 23 | border : 10 24 | ratio : 0.9 25 | clip : 5 26 | batch_size : 1 27 | 28 | sim: 29 | blend_sz : 0 30 | border : 1 31 | ratio : 0.2 32 | clip : 5 33 | batch_size : 1 34 | -------------------------------------------------------------------------------- /midastouch/contrib/tcn_minkloc/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Warsaw University of Technology 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /midastouch/contrib/tcn_minkloc/README.md: -------------------------------------------------------------------------------- 1 | # TCN: Tactile code network 2 | Tactile code network based on the third-party [MinkLoc3D](https://github.com/jac99/MinkLoc3D). 3 | -------------------------------------------------------------------------------- /midastouch/contrib/tcn_minkloc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/MidasTouch/9afb6ff72b837ab69f140695d38bf95c498b30c3/midastouch/contrib/tcn_minkloc/__init__.py -------------------------------------------------------------------------------- /midastouch/contrib/tcn_minkloc/minkfpn.py: -------------------------------------------------------------------------------- 1 | # Author: Jacek Komorowski 2 | # Warsaw University of Technology 3 | 4 | # Original source: https://github.com/jac99/MinkLoc3D 5 | # Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 6 | 7 | import torch.nn as nn 8 | import MinkowskiEngine as ME 9 | from MinkowskiEngine.modules.resnet_block import BasicBlock 10 | from .resnet import ResNetBase 11 | 12 | 13 | class MinkFPN(ResNetBase): 14 | # Feature Pyramid Network (FPN) architecture implementation using Minkowski ResNet building blocks 15 | def __init__( 16 | self, 17 | in_channels, 18 | out_channels, 19 | num_top_down=1, 20 | conv0_kernel_size=5, 21 | block=BasicBlock, 22 | layers=(1, 1, 1), 23 | planes=(32, 64, 64), 24 | ): 25 | assert len(layers) == len(planes) 26 | assert 1 <= len(layers) 27 | assert 0 <= num_top_down <= len(layers) 28 | self.num_bottom_up = len(layers) 29 | self.num_top_down = num_top_down 30 | self.conv0_kernel_size = conv0_kernel_size 31 | self.block = block 32 | self.layers = layers 33 | self.planes = planes 34 | self.lateral_dim = out_channels 35 | self.init_dim = planes[0] 36 | ResNetBase.__init__(self, in_channels, out_channels, D=3) 37 | 38 | def network_initialization(self, in_channels, out_channels, D): 39 | assert len(self.layers) == len(self.planes) 40 | assert len(self.planes) == self.num_bottom_up 41 | 42 | self.convs = nn.ModuleList() # Bottom-up convolutional blocks with stride=2 43 | self.bn = nn.ModuleList() # Bottom-up BatchNorms 44 | self.blocks = nn.ModuleList() # Bottom-up blocks 45 | self.tconvs = nn.ModuleList() # Top-down tranposed convolutions 46 | self.conv1x1 = nn.ModuleList() # 1x1 convolutions in lateral connections 47 | 48 | # The first convolution is special case, with kernel size = 5 49 | self.inplanes = self.planes[0] 50 | self.conv0 = ME.MinkowskiConvolution( 51 | in_channels, self.inplanes, kernel_size=self.conv0_kernel_size, dimension=D 52 | ) 53 | self.bn0 = ME.MinkowskiBatchNorm(self.inplanes) 54 | 55 | for plane, layer in zip(self.planes, self.layers): 56 | self.convs.append( 57 | ME.MinkowskiConvolution( 58 | self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D 59 | ) 60 | ) 61 | self.bn.append(ME.MinkowskiBatchNorm(self.inplanes)) 62 | self.blocks.append(self._make_layer(self.block, plane, layer)) 63 | 64 | # Lateral connections 65 | for i in range(self.num_top_down): 66 | self.conv1x1.append( 67 | ME.MinkowskiConvolution( 68 | self.planes[-1 - i], 69 | self.lateral_dim, 70 | kernel_size=1, 71 | stride=1, 72 | dimension=D, 73 | ) 74 | ) 75 | self.tconvs.append( 76 | ME.MinkowskiConvolutionTranspose( 77 | self.lateral_dim, 78 | self.lateral_dim, 79 | kernel_size=2, 80 | stride=2, 81 | dimension=D, 82 | ) 83 | ) 84 | # There's one more lateral connection than top-down TConv blocks 85 | if self.num_top_down < self.num_bottom_up: 86 | # Lateral connection from Conv block 1 or above 87 | self.conv1x1.append( 88 | ME.MinkowskiConvolution( 89 | self.planes[-1 - self.num_top_down], 90 | self.lateral_dim, 91 | kernel_size=1, 92 | stride=1, 93 | dimension=D, 94 | ) 95 | ) 96 | else: 97 | # Lateral connection from Con0 block 98 | self.conv1x1.append( 99 | ME.MinkowskiConvolution( 100 | self.planes[0], 101 | self.lateral_dim, 102 | kernel_size=1, 103 | stride=1, 104 | dimension=D, 105 | ) 106 | ) 107 | 108 | self.relu = ME.MinkowskiReLU(inplace=True) 109 | 110 | def forward(self, x): 111 | # *** BOTTOM-UP PASS *** 112 | # First bottom-up convolution is special (with bigger stride) 113 | feature_maps = [] 114 | x = self.conv0(x) 115 | x = self.bn0(x) 116 | x = self.relu(x) 117 | if self.num_top_down == self.num_bottom_up: 118 | feature_maps.append(x) 119 | 120 | # BOTTOM-UP PASS 121 | for ndx, (conv, bn, block) in 
enumerate(zip(self.convs, self.bn, self.blocks)): 122 | x = conv(x) # Decreases spatial resolution (conv stride=2) 123 | x = bn(x) 124 | x = self.relu(x) 125 | x = block(x) 126 | if self.num_bottom_up - 1 - self.num_top_down <= ndx < len(self.convs) - 1: 127 | feature_maps.append(x) 128 | 129 | assert len(feature_maps) == self.num_top_down 130 | 131 | x = self.conv1x1[0](x) 132 | 133 | # TOP-DOWN PASS 134 | for ndx, tconv in enumerate(self.tconvs): 135 | x = tconv(x) # Upsample using transposed convolution 136 | x = x + self.conv1x1[ndx + 1](feature_maps[-ndx - 1]) 137 | 138 | return x 139 | -------------------------------------------------------------------------------- /midastouch/contrib/tcn_minkloc/minkloc.py: -------------------------------------------------------------------------------- 1 | # Author: Jacek Komorowski 2 | # Warsaw University of Technology 3 | 4 | # Original source: https://github.com/jac99/MinkLoc3D 5 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 6 | 7 | import torch.nn as nn 8 | import MinkowskiEngine as ME 9 | 10 | from .minkfpn import MinkFPN 11 | import torch 12 | import torch.nn as nn 13 | 14 | 15 | class MinkLoc(nn.Module): 16 | def __init__( 17 | self, 18 | in_channels, 19 | feature_size, 20 | output_dim, 21 | planes, 22 | layers, 23 | num_top_down, 24 | conv0_kernel_size, 25 | ): 26 | 27 | super().__init__() 28 | self.in_channels = in_channels 29 | self.feature_size = feature_size # Size of local features produced by local feature extraction block 30 | self.output_dim = output_dim # Dimensionality of the global descriptor produced by pooling layer 31 | self.backbone = MinkFPN( 32 | in_channels=in_channels, 33 | out_channels=self.feature_size, 34 | num_top_down=num_top_down, 35 | conv0_kernel_size=conv0_kernel_size, 36 | layers=layers, 37 | planes=planes, 38 | ) 39 | self.n_backbone_features = output_dim 40 | assert ( 41 | self.feature_size == self.output_dim 42 | ), "output_dim must be the same as feature_size" 43 | self.pooling = GeM() 44 | 45 | def forward(self, batch): 46 | # Coords must be on CPU, features can be on GPU - see MinkowskiEngine documentation 47 | x = ME.SparseTensor(batch["features"], coordinates=batch["coords"]) 48 | x = self.backbone(x) 49 | 50 | # x is (num_points, n_features) tensor 51 | assert ( 52 | x.shape[1] == self.feature_size 53 | ), "Backbone output tensor has: {} channels. Expected: {}".format( 54 | x.shape[1], self.feature_size 55 | ) 56 | x = self.pooling(x) 57 | assert ( 58 | x.dim() == 2 59 | ), "Expected 2-dimensional tensor (batch_size,output_dim). Got {} dimensions.".format( 60 | x.dim() 61 | ) 62 | assert ( 63 | x.shape[1] == self.output_dim 64 | ), "Output tensor has: {} channels. 
Expected: {}".format( 65 | x.shape[1], self.output_dim 66 | ) 67 | # x is (batch_size, output_dim) tensor 68 | return x 69 | 70 | def print_info(self): 71 | print("Model class: MinkLoc") 72 | n_params = sum([param.nelement() for param in self.parameters()]) 73 | print("Total parameters: {}".format(n_params)) 74 | n_params = sum([param.nelement() for param in self.backbone.parameters()]) 75 | print("Backbone parameters: {}".format(n_params)) 76 | n_params = sum([param.nelement() for param in self.pooling.parameters()]) 77 | print("Aggregation parameters: {}".format(n_params)) 78 | if hasattr(self.backbone, "print_info"): 79 | self.backbone.print_info() 80 | if hasattr(self.pooling, "print_info"): 81 | self.pooling.print_info() 82 | 83 | 84 | class GeM(nn.Module): 85 | def __init__(self, p=3, eps=1e-6): 86 | super(GeM, self).__init__() 87 | self.p = nn.Parameter(torch.ones(1) * p) 88 | self.eps = eps 89 | self.f = ME.MinkowskiGlobalAvgPooling() 90 | 91 | def forward(self, x: ME.SparseTensor): 92 | # This implicitly applies ReLU on x (clamps negative values) 93 | temp = ME.SparseTensor(x.F.clamp(min=self.eps).pow(self.p), coordinates=x.C) 94 | temp = self.f(temp) # Apply ME.MinkowskiGlobalAvgPooling 95 | return temp.F.pow(1.0 / self.p) # Return (batch_size, n_features) tensor 96 | -------------------------------------------------------------------------------- /midastouch/contrib/tcn_minkloc/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | # of the Software, and to permit persons to whom the Software is furnished to do 8 | # so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | # 21 | # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural 22 | # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part 23 | # of the code. 
24 | 25 | import torch.nn as nn 26 | 27 | import MinkowskiEngine as ME 28 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck 29 | 30 | 31 | class ResNetBase(nn.Module): 32 | block = None 33 | layers = () 34 | init_dim = 64 35 | planes = (64, 128, 256, 512) 36 | 37 | def __init__(self, in_channels, out_channels, D=3): 38 | nn.Module.__init__(self) 39 | self.D = D 40 | assert self.block is not None 41 | 42 | self.network_initialization(in_channels, out_channels, D) 43 | self.weight_initialization() 44 | 45 | def network_initialization(self, in_channels, out_channels, D): 46 | self.inplanes = self.init_dim 47 | self.conv1 = ME.MinkowskiConvolution( 48 | in_channels, self.inplanes, kernel_size=5, stride=2, dimension=D 49 | ) 50 | 51 | self.bn1 = ME.MinkowskiBatchNorm(self.inplanes) 52 | self.relu = ME.MinkowskiReLU(inplace=True) 53 | 54 | self.pool = ME.MinkowskiAvgPooling(kernel_size=2, stride=2, dimension=D) 55 | 56 | self.layer1 = self._make_layer( 57 | self.block, self.planes[0], self.layers[0], stride=2 58 | ) 59 | self.layer2 = self._make_layer( 60 | self.block, self.planes[1], self.layers[1], stride=2 61 | ) 62 | self.layer3 = self._make_layer( 63 | self.block, self.planes[2], self.layers[2], stride=2 64 | ) 65 | self.layer4 = self._make_layer( 66 | self.block, self.planes[3], self.layers[3], stride=2 67 | ) 68 | 69 | self.conv5 = ME.MinkowskiConvolution( 70 | self.inplanes, self.inplanes, kernel_size=3, stride=3, dimension=D 71 | ) 72 | self.bn5 = ME.MinkowskiBatchNorm(self.inplanes) 73 | 74 | self.glob_avg = ME.MinkowskiGlobalMaxPooling() 75 | 76 | self.final = ME.MinkowskiLinear(self.inplanes, out_channels, bias=True) 77 | 78 | def weight_initialization(self): 79 | for m in self.modules(): 80 | if isinstance(m, ME.MinkowskiConvolution): 81 | ME.utils.kaiming_normal_(m.kernel, mode="fan_out", nonlinearity="relu") 82 | 83 | if isinstance(m, ME.MinkowskiBatchNorm): 84 | nn.init.constant_(m.bn.weight, 1) 85 | nn.init.constant_(m.bn.bias, 0) 86 | 87 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, bn_momentum=0.1): 88 | downsample = None 89 | if stride != 1 or self.inplanes != planes * block.expansion: 90 | downsample = nn.Sequential( 91 | ME.MinkowskiConvolution( 92 | self.inplanes, 93 | planes * block.expansion, 94 | kernel_size=1, 95 | stride=stride, 96 | dimension=self.D, 97 | ), 98 | ME.MinkowskiBatchNorm(planes * block.expansion), 99 | ) 100 | layers = [] 101 | layers.append( 102 | block( 103 | self.inplanes, 104 | planes, 105 | stride=stride, 106 | dilation=dilation, 107 | downsample=downsample, 108 | dimension=self.D, 109 | ) 110 | ) 111 | self.inplanes = planes * block.expansion 112 | for i in range(1, blocks): 113 | layers.append( 114 | block( 115 | self.inplanes, planes, stride=1, dilation=dilation, dimension=self.D 116 | ) 117 | ) 118 | 119 | return nn.Sequential(*layers) 120 | 121 | def forward(self, x): 122 | x = self.conv1(x) 123 | x = self.bn1(x) 124 | x = self.relu(x) 125 | x = self.pool(x) 126 | 127 | x = self.layer1(x) 128 | x = self.layer2(x) 129 | x = self.layer3(x) 130 | x = self.layer4(x) 131 | 132 | x = self.conv5(x) 133 | x = self.bn5(x) 134 | x = self.relu(x) 135 | 136 | x = self.glob_avg(x) 137 | return self.final(x) 138 | 139 | 140 | class ResNet14(ResNetBase): 141 | BLOCK = BasicBlock 142 | LAYERS = (1, 1, 1, 1) 143 | 144 | 145 | class ResNet18(ResNetBase): 146 | BLOCK = BasicBlock 147 | LAYERS = (2, 2, 2, 2) 148 | 149 | 150 | class ResNet34(ResNetBase): 151 | BLOCK = BasicBlock 152 | LAYERS = (3, 4, 6, 3) 153 | 154 | 
155 | class ResNet50(ResNetBase): 156 | BLOCK = Bottleneck 157 | LAYERS = (3, 4, 6, 3) 158 | 159 | 160 | class ResNet101(ResNetBase): 161 | BLOCK = Bottleneck 162 | LAYERS = (3, 4, 23, 3) 163 | -------------------------------------------------------------------------------- /midastouch/contrib/tcn_minkloc/tcn.py: -------------------------------------------------------------------------------- 1 | # Author: Jacek Komorowski 2 | # Warsaw University of Technology 3 | # Evaluation code adapted from PointNetVlad code: https://github.com/mikacuy/pointnetvlad 4 | 5 | # Original source: https://github.com/jac99/MinkLoc3D 6 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 7 | 8 | from os import path as osp 9 | 10 | import torch 11 | import MinkowskiEngine as ME 12 | from .utils import MinkLocParams 13 | from .minkloc import MinkLoc 14 | from omegaconf import DictConfig 15 | from midastouch.modules.misc import DIRS, get_device 16 | 17 | 18 | class TCN: 19 | def __init__(self, cfg: DictConfig): 20 | in_channels = 1 21 | 22 | self.params = MinkLocParams(cfg) # Load MinkLoc3d params 23 | self.batch_size = cfg.model.batch_size 24 | if "MinkFPN" in self.params.model_params.model: 25 | self.model = MinkLoc( 26 | in_channels=in_channels, 27 | feature_size=self.params.model_params.feature_size, 28 | output_dim=self.params.model_params.output_dim, 29 | planes=self.params.model_params.planes, 30 | layers=self.params.model_params.layers, 31 | num_top_down=self.params.model_params.num_top_down, 32 | conv0_kernel_size=self.params.model_params.conv0_kernel_size, 33 | ) 34 | else: 35 | raise NotImplementedError( 36 | "Model not implemented: {}".format(self.params.model_params.model) 37 | ) 38 | 39 | tcn_weights = osp.join(DIRS["weights"], cfg.model.tcn_weights) 40 | self.load_weights(tcn_weights) 41 | 42 | def load_weights(self, weights): 43 | device = get_device(cpu=False, verbose=False) 44 | # Load MinkLoc weights 45 | weights = torch.load(weights, map_location=device) 46 | if type(weights) is dict: 47 | self.model.load_state_dict(weights["state_dict"]) 48 | else: 49 | self.model.load_state_dict(weights) 50 | self.model.to(device) 51 | 52 | def cloud_to_tactile_code(self, tac_render, heightmaps, masks): 53 | # Adapted from original PointNetVLAD code 54 | self.model.eval() 55 | 56 | if type(heightmaps) is not list: 57 | heightmaps = [heightmaps] 58 | masks = [masks] 59 | 60 | device = next(self.model.parameters()).device 61 | 62 | embeddings_l = [None] * len(heightmaps) 63 | 64 | with torch.no_grad(): 65 | numSamples = len(heightmaps) 66 | num_batches = numSamples // self.batch_size 67 | 68 | num_batches = 1 if num_batches == 0 else num_batches 69 | 70 | for i in range(num_batches): 71 | i_range = ( 72 | torch.IntTensor(range(i * self.batch_size, numSamples)) 73 | if (i == num_batches - 1) 74 | else torch.IntTensor( 75 | range(i * self.batch_size, (i + 1) * self.batch_size) 76 | ) 77 | ) 78 | batch_clouds = [None] * len(i_range) 79 | for j, (h, c) in enumerate( 80 | zip( 81 | heightmaps[i_range[0] : i_range[-1] + 1], 82 | masks[i_range[0] : i_range[-1] + 1], 83 | ) 84 | ): 85 | batch_clouds[j] = tac_render.heightmap2Pointcloud(h, c) 86 | 87 | n_points = self.params.num_points 88 | for j, batch_cloud in enumerate(batch_clouds): 89 | if batch_cloud.shape[0] == 0: 90 | batch_cloud = torch.repeat_interleave( 91 | torch.Tensor([[0, 0, 0]]).to(batch_cloud.device), 92 | n_points, 93 | dim=0, 94 | ) 95 | else: 96 | idxs = torch.arange( 97 | batch_cloud.shape[0], 98 | device=batch_cloud.device, 99 | 
dtype=torch.float, 100 | ) 101 | if n_points > batch_cloud.shape[0]: 102 | downsampleIDs = torch.multinomial( 103 | idxs, num_samples=n_points, replacement=True 104 | ) 105 | else: 106 | downsampleIDs = torch.multinomial( 107 | idxs, num_samples=n_points, replacement=False 108 | ) 109 | batch_cloud = batch_cloud[downsampleIDs, :] 110 | 111 | batch_clouds[j] = ( 112 | 2.0 113 | * (batch_cloud - torch.min(batch_cloud)) 114 | / (torch.max(batch_cloud) - torch.min(batch_cloud)) 115 | - 1 116 | ) # scale [-1, 1] 117 | 118 | batch_clouds = torch.stack( 119 | batch_clouds, dim=0 120 | ) # Produces (batch_size, n_points, 3) tensor 121 | batch = {} 122 | 123 | # coords are (n_clouds, num_points, channels) tensor 124 | coords = [ 125 | ME.utils.sparse_quantize( 126 | coordinates=e, 127 | quantization_size=self.params.model_params.mink_quantization_size, 128 | ) 129 | for e in batch_clouds 130 | ] 131 | coords = ME.utils.batched_coordinates(coords, device=device) 132 | # Assign a dummy feature equal to 1 to each point 133 | feats = torch.ones( 134 | (coords.shape[0], 1), dtype=torch.float32, device=device 135 | ) 136 | batch["coords"], batch["features"] = coords, feats 137 | 138 | embedding = self.model(batch) 139 | 140 | if self.params.normalize_embeddings: 141 | embedding = torch.nn.functional.normalize( 142 | embedding, p=2, dim=1 143 | ) # Normalize embeddings 144 | # embedding = embedding.detach().cpu().numpy() 145 | embeddings_l[i_range[0] : i_range[-1] + 1] = embedding 146 | 147 | embeddings_l = torch.vstack(embeddings_l) # list to array (set_sz, output_dim) 148 | return embeddings_l.double() # double for precision 149 | -------------------------------------------------------------------------------- /midastouch/contrib/tcn_minkloc/utils.py: -------------------------------------------------------------------------------- 1 | # Author: Jacek Komorowski 2 | # Warsaw University of Technology 3 | 4 | # Original source: https://github.com/jac99/MinkLoc3D 5 | # Copyright (c) Meta Platforms, Inc. 
All Rights Reserved 6 | 7 | import os 8 | import time 9 | 10 | 11 | class ModelParams: 12 | def __init__(self, mode_config): 13 | params = mode_config 14 | 15 | self.model = params.model 16 | self.output_dim = params.output_dim # Size of the final descriptor 17 | 18 | # Add gating as the last step 19 | if "vlad" in self.model.lower(): 20 | self.cluster_size = params.cluster_size # Size of NetVLAD cluster 21 | self.gating = params.gating # Use gating after the NetVlad 22 | 23 | ####################################################################### 24 | # Model dependent 25 | ####################################################################### 26 | 27 | if "MinkFPN" in self.model: 28 | # Models using MinkowskiEngine 29 | self.mink_quantization_size = params.mink_quantization_size 30 | # Size of the local features from backbone network (only for MinkNet based models) 31 | # For PointNet-based models we always use 1024 intermediary features 32 | self.feature_size = params.feature_size 33 | if params.planes: 34 | self.planes = [int(e) for e in params.planes.split(",")] 35 | else: 36 | self.planes = [32, 64, 64] 37 | 38 | if params.layers: 39 | self.layers = [int(e) for e in params.layers.split(",")] 40 | else: 41 | self.layers = [1, 1, 1] 42 | 43 | self.num_top_down = params.num_top_down 44 | self.conv0_kernel_size = params.conv0_kernel_size 45 | 46 | def print(self): 47 | print("Model parameters:") 48 | param_dict = vars(self) 49 | for e in param_dict: 50 | print("{}: {}".format(e, param_dict[e])) 51 | 52 | print("") 53 | 54 | 55 | def get_datetime(): 56 | return time.strftime("%Y%m%d_%H%M") 57 | 58 | 59 | class MinkLocParams: 60 | """ 61 | Params for training MinkLoc models 62 | """ 63 | 64 | def __init__(self, config): 65 | """ 66 | Configuration files 67 | :param path: configuration file 68 | """ 69 | self.num_points = config.model.num_points 70 | self.dataset_folder = config.train.dataset_folder 71 | self.val_folder = config.train.val_folder 72 | self.eval_folder = config.train.eval_folder 73 | 74 | self.num_workers = config.train.num_workers 75 | self.batch_size = config.train.batch_size 76 | self.val_batch_size = config.train.val_batch_size 77 | self.max_batches = config.train.max_batches 78 | 79 | # Set batch_expansion_th to turn on dynamic batch sizing 80 | # When number of non-zero triplets falls below batch_expansion_th, expand batch size 81 | self.batch_expansion_th = config.train.batch_expansion_th 82 | if self.batch_expansion_th is not None: 83 | assert ( 84 | 0.0 < self.batch_expansion_th < 1.0 85 | ), "batch_expansion_th must be between 0 and 1" 86 | self.batch_size_limit = config.train.batch_size_limit 87 | # Batch size expansion rate 88 | self.batch_expansion_rate = config.train.batch_expansion_rate 89 | assert ( 90 | self.batch_expansion_rate > 1.0 91 | ), "batch_expansion_rate must be greater than 1" 92 | else: 93 | self.batch_size_limit = self.batch_size 94 | self.batch_expansion_rate = None 95 | 96 | self.lr = config.train.lr 97 | 98 | self.scheduler = config.train.scheduler 99 | if self.scheduler is not None: 100 | if self.scheduler == "CosineAnnealingLR": 101 | self.min_lr = config.train.min_lr 102 | elif self.scheduler == "MultiStepLR": 103 | scheduler_milestones = config.train.scheduler_milestones 104 | self.scheduler_milestones = [ 105 | int(e) for e in scheduler_milestones.split(",") 106 | ] 107 | else: 108 | raise NotImplementedError( 109 | "Unsupported LR scheduler: {}".format(self.scheduler) 110 | ) 111 | 112 | self.epochs = config.train.epochs 113 | 
self.weight_decay = config.train.weight_decay 114 | self.normalize_embeddings = ( 115 | config.train.normalize_embeddings 116 | ) # Normalize embeddings during training and evaluation 117 | self.loss = config.train.loss 118 | 119 | if "Contrastive" in self.loss: 120 | self.pos_margin = config.train.pos_margin 121 | self.neg_margin = config.train.neg_margin 122 | elif "Triplet" in self.loss: 123 | self.margin = config.train.margin # Margin used in loss function 124 | else: 125 | raise "Unsupported loss function: {}".format(self.loss) 126 | 127 | self.aug_mode = config.train.aug_mode # Augmentation mode (1 is default) 128 | 129 | self.train_file = config.train.train_file 130 | self.val_file = config.train.val_file 131 | 132 | # Read model parameters 133 | self.model_params = ModelParams(config.model) 134 | # self._check_params() 135 | 136 | def _check_params(self): 137 | assert os.path.exists(self.dataset_folder), "Cannot access dataset: {}".format( 138 | self.dataset_folder 139 | ) 140 | 141 | def print(self): 142 | print("Parameters:") 143 | param_dict = vars(self) 144 | for e in param_dict: 145 | if e != "model_params": 146 | print("{}: {}".format(e, param_dict[e])) 147 | 148 | self.model_params.print() 149 | print("") 150 | -------------------------------------------------------------------------------- /midastouch/contrib/tdn_fcrn/README.md: -------------------------------------------------------------------------------- 1 | # TDN: Tactile depth network 2 | Tactile image to heightmap network, based on [FCRN-DepthPrediction](https://github.com/iro-cp/FCRN-DepthPrediction) and its PyTorch [implementation](https://github.com/XPFly1989/FCRN). 3 | 4 | ## Write a data loading file 5 | - `data/data_to_txt.py` generates dataloader for training/validation/testing given paired image and depth data. To generate this data refer to the `midastouch/data_gen/` folder 6 | 7 | ## Train and test model 8 | - `train.py` Trains with the heightmaps and contact masks 9 | - `test_dataset.py` is testing on cpu. Remember to change the data loading path and result saving path. 10 | -------------------------------------------------------------------------------- /midastouch/contrib/tdn_fcrn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/MidasTouch/9afb6ff72b837ab69f140695d38bf95c498b30c3/midastouch/contrib/tdn_fcrn/__init__.py -------------------------------------------------------------------------------- /midastouch/contrib/tdn_fcrn/config/test.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | # Config for tactile depth network evaluation 7 | 8 | real : false 9 | tdn : ../config/tdn/default.yaml 10 | -------------------------------------------------------------------------------- /midastouch/contrib/tdn_fcrn/config/train.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | # Config for tactile depth network training 7 | 8 | batch_size : 50 9 | lr : 1.0e-4 10 | max_epochs : 100 11 | resume_from_file : True 12 | data_file_path : data 13 | test_results_path : /mnt/sda/suddhu/fcrn/fcrn-testing 14 | init_weights : NYU_ResNet-UpProj.npy 15 | checkpoint_weights : tdn_weights.pth.tar -------------------------------------------------------------------------------- /midastouch/contrib/tdn_fcrn/data/data_to_txt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | Load train/test data structure and save to text file, preprocessing for depth training 8 | """ 9 | 10 | import numpy as np 11 | import os 12 | from os import path as osp 13 | 14 | # change the data_root to the root directory of dataset with all objects 15 | abspath = osp.abspath(__file__) 16 | dname = osp.dirname(abspath) 17 | os.chdir(dname) 18 | 19 | # write training/validation/testing data loader files 20 | train_data_file = open("train_data.txt", "w") 21 | train_label_file = open("train_label.txt", "w") 22 | 23 | dev_data_file = open("dev_data.txt", "w") 24 | dev_label_file = open("dev_label.txt", "w") 25 | 26 | test_data_file = open("test_data.txt", "w") 27 | test_label_file = open("test_label.txt", "w") 28 | 29 | global_train_idx = 0 30 | global_dev_idx = 0 31 | global_test_idx = 0 32 | 33 | # different tactile background models: add multiple folders here, below is placeholder 34 | data_root_paths = ["/mnt/sda/fcrn/fcrn_data"] 35 | 36 | for data_root_path in data_root_paths: 37 | object_folders = sorted(os.listdir(data_root_path)) 38 | for object in object_folders: 39 | if object == ".DS_Store": 40 | continue 41 | _, ext = os.path.splitext(object) 42 | if ext == ".pickle": 43 | continue 44 | print("Object: ", object) 45 | 46 | # load in tactile images and ground truth height maps 47 | tactile_path = osp.join(data_root_path, object, "tactile_images") 48 | gt_heightmap_path = osp.join(data_root_path, object, "gt_heightmaps") 49 | gt_contactmask_path = osp.join(data_root_path, object, "gt_contactmasks") 50 | 51 | num_imgs = len(os.listdir(tactile_path)) 52 | all_random_idx = np.random.permutation(num_imgs) 53 | num_train = int(0.8 * num_imgs) 54 | num_dev = int(0.1 * num_imgs) 55 | num_test = int(0.1 * num_imgs) 56 | 57 | train_idx = all_random_idx[0:num_train] 58 | dev_idx = all_random_idx[num_train : num_train + num_dev] 59 | test_idx = all_random_idx[num_train + num_dev : num_train + num_dev + num_test] 60 | 61 | for idx in train_idx: 62 | train_data_file.write( 63 | str(global_train_idx) 64 | + "," 65 | + tactile_path 66 | + "/" 67 | + str(idx) 68 | + ".jpg" 69 | + "\n" 70 | ) 71 | train_label_file.write( 72 | str(global_train_idx) 73 | + "," 74 | + gt_heightmap_path 75 | + "/" 76 | + str(idx) 77 | + ".jpg" 78 | + "," 79 | + gt_contactmask_path 80 | + "/" 81 | + str(idx) 82 | + ".jpg" 83 | + "\n" 84 | ) 85 | global_train_idx += 1 86 | 87 | for idx in dev_idx: 88 | dev_data_file.write( 89 | str(global_dev_idx) 90 | + "," 91 | + tactile_path 92 | + "/" 93 | + str(idx) 94 | + ".jpg" 95 | + "\n" 96 | ) 97 | dev_label_file.write( 98 | str(global_dev_idx) 99 | + "," 100 | + gt_heightmap_path 101 | + "/" 102 | + str(idx) 103 | + ".jpg" 104 | + "," 105 | + gt_contactmask_path 106 | + "/" 107 | + str(idx) 108 | + ".jpg" 109 | + "\n" 110 | ) 111 | global_dev_idx += 1 
112 | 113 | for idx in test_idx: 114 | test_data_file.write( 115 | str(global_test_idx) 116 | + "," 117 | + tactile_path 118 | + "/" 119 | + str(idx) 120 | + ".jpg" 121 | + "\n" 122 | ) 123 | test_label_file.write( 124 | str(global_test_idx) 125 | + "," 126 | + gt_heightmap_path 127 | + "/" 128 | + str(idx) 129 | + ".jpg" 130 | + "," 131 | + gt_contactmask_path 132 | + "/" 133 | + str(idx) 134 | + ".jpg" 135 | + "\n" 136 | ) 137 | global_test_idx += 1 138 | 139 | print( 140 | "Train size: {}, Val size: {}, test size: {}".format( 141 | global_train_idx, global_dev_idx, global_test_idx 142 | ) 143 | ) 144 | train_data_file.close() 145 | train_label_file.close() 146 | dev_data_file.close() 147 | dev_label_file.close() 148 | test_data_file.close() 149 | test_label_file.close() 150 | -------------------------------------------------------------------------------- /midastouch/contrib/tdn_fcrn/data/data_to_txt_real.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | Load real test data structure and save to text file, preprocessing for depth test 8 | """ 9 | 10 | import os 11 | from os import path as osp 12 | import random 13 | 14 | # change the data_root to the root directory of dataset with all objects 15 | abspath = osp.abspath(__file__) 16 | dname = osp.dirname(abspath) 17 | os.chdir(dname) 18 | 19 | data_root_path = "/home/robospare/suddhu/midastouch/data/real/" 20 | objects = sorted(os.listdir(data_root_path)) 21 | 22 | # write training/validation/testing data loader files 23 | test_data_file = open("test_data_real.txt", "w") 24 | 25 | global_test_idx = 0 26 | for object in objects: 27 | obj_path = osp.join(data_root_path, object) 28 | if not osp.isdir(obj_path): 29 | continue 30 | datasets = sorted(os.listdir(obj_path)) 31 | print("dataset: ", object) 32 | 33 | for dataset in datasets: 34 | dataset_path = osp.join(obj_path, dataset) 35 | if dataset == "bg" or not osp.isdir(dataset_path): 36 | continue 37 | # load in tactile images from real sensor 38 | tactile_path = osp.join(dataset_path, "frames") 39 | imgs = sorted(os.listdir(tactile_path)) 40 | imgs = [x for x in imgs if ".jpg" in x] 41 | 42 | if len(imgs) > 10: 43 | imgs = random.sample(imgs, 10) 44 | for i, img in enumerate(imgs): 45 | test_data_file.write(str(i) + "," + tactile_path + "/" + img + "\n") 46 | global_test_idx += 1 47 | 48 | print("Real test data size: {}".format(global_test_idx)) 49 | test_data_file.close() 50 | -------------------------------------------------------------------------------- /midastouch/contrib/tdn_fcrn/data_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from os import path as osp 7 | import numpy as np 8 | from PIL import Image 9 | import torch.utils.data as data 10 | import torchvision.transforms as transforms 11 | import sys 12 | 13 | sys.path.append("..") 14 | from . 
import flow_transforms 15 | import cv2 16 | import warnings 17 | 18 | warnings.filterwarnings("ignore") 19 | 20 | 21 | class data_loader(data.Dataset): 22 | def __init__(self, data_path, label_path): 23 | self.data_path = data_path # to txt file 24 | self.label_path = label_path # to txt file 25 | 26 | self.data = open(osp.join(data_path), "r") 27 | self.label = open(osp.join(label_path), "r") 28 | 29 | self.data_content = self.data.read() 30 | self.data_list = self.data_content.split("\n") 31 | 32 | self.label_content = self.label.read() 33 | self.label_list = self.label_content.split("\n") 34 | 35 | def __getitem__(self, index): 36 | img_path = self.data_list[index].split(",")[1] 37 | depth_path = self.label_list[index].split(",")[1] 38 | 39 | image, depth = None, None 40 | with Image.open(img_path) as im: 41 | image = np.asarray(im) 42 | image = cv2.normalize( 43 | image, None, alpha=0, beta=200, norm_type=cv2.NORM_MINMAX 44 | ) 45 | # image = image[:,:,::-1] # RGB -> BGR 46 | with Image.open(depth_path) as dp: 47 | depth = np.asarray(dp).astype(np.int64) 48 | 49 | input_transform = transforms.Compose( 50 | [flow_transforms.Scale(240), flow_transforms.ArrayToTensor()] 51 | ) 52 | target_depth_transform = transforms.Compose( 53 | [flow_transforms.Scale_Single(240), flow_transforms.ArrayToTensor()] 54 | ) 55 | image = input_transform(image) 56 | depth = target_depth_transform(depth) 57 | 58 | return image, depth 59 | 60 | def __len__(self): 61 | return len(self.label_list) - 1 62 | 63 | 64 | class real_data_loader(data.Dataset): 65 | def __init__(self, data_path): 66 | self.data_path = data_path # to txt file 67 | 68 | self.data = open(osp.join(data_path), "r") 69 | self.data_content = self.data.read() 70 | self.data_list = self.data_content.split("\n") 71 | 72 | def __getitem__(self, index): 73 | img_path = self.data_list[index].split(",")[1] 74 | 75 | image = None 76 | with Image.open(img_path) as im: 77 | image = np.asarray(im) 78 | image = cv2.normalize( 79 | image, None, alpha=0, beta=200, norm_type=cv2.NORM_MINMAX 80 | ) 81 | # image = image[:,:,::-1] # RGB -> BGR 82 | 83 | input_transform = transforms.Compose( 84 | [flow_transforms.Scale(240), flow_transforms.ArrayToTensor()] 85 | ) 86 | image = input_transform(image) 87 | return image 88 | 89 | def __len__(self): 90 | return len(self.data_list) - 1 91 | -------------------------------------------------------------------------------- /midastouch/contrib/tdn_fcrn/fcrn.py: -------------------------------------------------------------------------------- 1 | # Original source: https://github.com/XPFly1989/FCRN 2 | # A Pytorch implementation of Laina, Iro, et al. "Deeper depth prediction with fully convolutional residual networks." 3 | # 3D Vision (3DV), 2016 Fourth International Conference on. IEEE, 2016. 4 | 5 | # Copyright (c) 2016, Iro Laina 6 | # All rights reserved. 7 | 8 | # Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 9 | # Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 10 | # Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 
11 | 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 13 | 14 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 15 | 16 | import torch 17 | import torch.nn as nn 18 | import math 19 | 20 | 21 | class Bottleneck(nn.Module): 22 | expansion = 4 23 | 24 | def __init__(self, inplanes, planes, stride=1, downsample=None): 25 | super(Bottleneck, self).__init__() 26 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 27 | self.bn1 = nn.BatchNorm2d(planes) 28 | 29 | self.conv2 = nn.Conv2d( 30 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 31 | ) 32 | self.bn2 = nn.BatchNorm2d(planes) 33 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 34 | self.bn3 = nn.BatchNorm2d(planes * 4) 35 | self.relu = nn.ReLU(inplace=True) 36 | self.downsample = downsample 37 | self.stride = stride 38 | 39 | def forward(self, x): 40 | residual = x 41 | 42 | out = self.conv1(x) 43 | out = self.bn1(out) 44 | out = self.relu(out) 45 | 46 | out = self.conv2(out) 47 | out = self.bn2(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv3(out) 51 | out = self.bn3(out) 52 | 53 | if self.downsample is not None: 54 | residual = self.downsample(x) 55 | 56 | out += residual 57 | out = self.relu(out) 58 | 59 | return out 60 | 61 | 62 | class UpProject(nn.Module): 63 | def __init__(self, in_channels, out_channels, batch_size): 64 | super(UpProject, self).__init__() 65 | self.batch_size = batch_size 66 | 67 | self.conv1_1 = nn.Conv2d(in_channels, out_channels, 3) 68 | self.conv1_2 = nn.Conv2d(in_channels, out_channels, (2, 3)) 69 | self.conv1_3 = nn.Conv2d(in_channels, out_channels, (3, 2)) 70 | self.conv1_4 = nn.Conv2d(in_channels, out_channels, 2) 71 | 72 | self.conv2_1 = nn.Conv2d(in_channels, out_channels, 3) 73 | self.conv2_2 = nn.Conv2d(in_channels, out_channels, (2, 3)) 74 | self.conv2_3 = nn.Conv2d(in_channels, out_channels, (3, 2)) 75 | self.conv2_4 = nn.Conv2d(in_channels, out_channels, 2) 76 | 77 | self.bn1_1 = nn.BatchNorm2d(out_channels) 78 | self.bn1_2 = nn.BatchNorm2d(out_channels) 79 | 80 | self.relu = nn.ReLU(inplace=True) 81 | 82 | self.conv3 = nn.Conv2d(out_channels, out_channels, 3, padding=1) 83 | 84 | self.bn2 = nn.BatchNorm2d(out_channels) 85 | 86 | def forward(self, x): 87 | out1_1 = self.conv1_1(nn.functional.pad(x, (1, 1, 1, 1))) 88 | # out1_2 = self.conv1_2(nn.functional.pad(x, (1, 1, 0, 1)))#right interleaving padding 89 | out1_2 = self.conv1_2( 90 | nn.functional.pad(x, (1, 1, 1, 0)) 91 | ) # author's interleaving pading in github 92 | # out1_3 = self.conv1_3(nn.functional.pad(x, (0, 1, 1, 1)))#right interleaving padding 93 | out1_3 = self.conv1_3( 94 | nn.functional.pad(x, (1, 0, 1, 1)) 95 | ) # author's interleaving pading in github 96 | # out1_4 = self.conv1_4(nn.functional.pad(x, (0, 1, 0, 1)))#right interleaving padding 97 | 
out1_4 = self.conv1_4( 98 | nn.functional.pad(x, (1, 0, 1, 0)) 99 | ) # author's interleaving pading in github 100 | 101 | out2_1 = self.conv2_1(nn.functional.pad(x, (1, 1, 1, 1))) 102 | # out2_2 = self.conv2_2(nn.functional.pad(x, (1, 1, 0, 1)))#right interleaving padding 103 | out2_2 = self.conv2_2( 104 | nn.functional.pad(x, (1, 1, 1, 0)) 105 | ) # author's interleaving pading in github 106 | # out2_3 = self.conv2_3(nn.functional.pad(x, (0, 1, 1, 1)))#right interleaving padding 107 | out2_3 = self.conv2_3( 108 | nn.functional.pad(x, (1, 0, 1, 1)) 109 | ) # author's interleaving pading in github 110 | # out2_4 = self.conv2_4(nn.functional.pad(x, (0, 1, 0, 1)))#right interleaving padding 111 | out2_4 = self.conv2_4( 112 | nn.functional.pad(x, (1, 0, 1, 0)) 113 | ) # author's interleaving pading in github 114 | 115 | height = out1_1.size()[2] 116 | width = out1_1.size()[3] 117 | 118 | out1_1_2 = ( 119 | torch.stack((out1_1, out1_2), dim=-3) 120 | .permute(0, 1, 3, 4, 2) 121 | .contiguous() 122 | .view(self.batch_size, -1, height, width * 2) 123 | ) 124 | out1_3_4 = ( 125 | torch.stack((out1_3, out1_4), dim=-3) 126 | .permute(0, 1, 3, 4, 2) 127 | .contiguous() 128 | .view(self.batch_size, -1, height, width * 2) 129 | ) 130 | 131 | out1_1234 = ( 132 | torch.stack((out1_1_2, out1_3_4), dim=-3) 133 | .permute(0, 1, 3, 2, 4) 134 | .contiguous() 135 | .view(self.batch_size, -1, height * 2, width * 2) 136 | ) 137 | 138 | out2_1_2 = ( 139 | torch.stack((out2_1, out2_2), dim=-3) 140 | .permute(0, 1, 3, 4, 2) 141 | .contiguous() 142 | .view(self.batch_size, -1, height, width * 2) 143 | ) 144 | out2_3_4 = ( 145 | torch.stack((out2_3, out2_4), dim=-3) 146 | .permute(0, 1, 3, 4, 2) 147 | .contiguous() 148 | .view(self.batch_size, -1, height, width * 2) 149 | ) 150 | 151 | out2_1234 = ( 152 | torch.stack((out2_1_2, out2_3_4), dim=-3) 153 | .permute(0, 1, 3, 2, 4) 154 | .contiguous() 155 | .view(self.batch_size, -1, height * 2, width * 2) 156 | ) 157 | 158 | out1 = self.bn1_1(out1_1234) 159 | out1 = self.relu(out1) 160 | out1 = self.conv3(out1) 161 | out1 = self.bn2(out1) 162 | 163 | out2 = self.bn1_2(out2_1234) 164 | 165 | out = out1 + out2 166 | out = self.relu(out) 167 | 168 | return out 169 | 170 | 171 | import torch.jit as jit 172 | 173 | 174 | class FCRN_net(jit.ScriptModule): 175 | # class FCRN_net(nn.Module): 176 | 177 | def __init__(self, batch_size, bottleneck=False): 178 | super(FCRN_net, self).__init__() 179 | self.inplanes = 64 180 | self.batch_size = batch_size 181 | self.bottleneck = bottleneck 182 | 183 | # ResNet with out avrgpool & fc 184 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 185 | self.bn1 = nn.BatchNorm2d(64) 186 | self.relu = nn.ReLU(inplace=True) 187 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 188 | self.layer1 = self._make_layer(Bottleneck, 64, 3) 189 | self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2) 190 | self.layer3 = self._make_layer(Bottleneck, 256, 6, stride=2) 191 | self.layer4 = self._make_layer(Bottleneck, 512, 3, stride=2) 192 | 193 | # Up-Conv layers 194 | self.conv2 = nn.Conv2d(2048, 1024, kernel_size=1, bias=False) 195 | self.bn2 = nn.BatchNorm2d(1024) 196 | 197 | self.up1 = self._make_upproj_layer(UpProject, 1024, 512, self.batch_size) 198 | self.up2 = self._make_upproj_layer(UpProject, 512, 256, self.batch_size) 199 | self.up3 = self._make_upproj_layer(UpProject, 256, 128, self.batch_size) 200 | self.up4 = self._make_upproj_layer(UpProject, 128, 64, self.batch_size) 201 | 202 | self.drop = 
nn.Dropout2d() 203 | 204 | self.conv3 = nn.Conv2d(64, 1, 3, padding=1) 205 | 206 | self.upsample = nn.Upsample((320, 240), mode="bilinear", align_corners=False) 207 | 208 | # initialize 209 | if True: 210 | for m in self.modules(): 211 | if isinstance(m, nn.Conv2d): 212 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 213 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 214 | elif isinstance(m, nn.BatchNorm2d): 215 | m.weight.data.fill_(1) 216 | m.bias.data.zero_() 217 | 218 | def _make_layer(self, block, planes, blocks, stride=1): 219 | downsample = None 220 | if stride != 1 or self.inplanes != planes * block.expansion: 221 | downsample = nn.Sequential( 222 | nn.Conv2d( 223 | self.inplanes, 224 | planes * block.expansion, 225 | kernel_size=1, 226 | stride=stride, 227 | bias=False, 228 | ), 229 | nn.BatchNorm2d(planes * block.expansion), 230 | ) 231 | 232 | layers = [] 233 | layers.append(block(self.inplanes, planes, stride, downsample)) 234 | self.inplanes = planes * block.expansion 235 | for i in range(1, blocks): 236 | layers.append(block(self.inplanes, planes)) 237 | 238 | return nn.Sequential(*layers) 239 | 240 | def _make_upproj_layer(self, block, in_channels, out_channels, batch_size): 241 | return block(in_channels, out_channels, batch_size) 242 | 243 | @jit.script_method 244 | def forward(self, x): 245 | x = self.conv1(x) 246 | x = self.bn1(x) 247 | x = self.relu(x) 248 | x = self.maxpool(x) 249 | 250 | x = self.layer1(x) 251 | x = self.layer2(x) 252 | x = self.layer3(x) 253 | x = self.layer4(x) 254 | 255 | x = self.conv2(x) 256 | x = self.bn2(x) 257 | 258 | if self.bottleneck: 259 | return x # feature vector 260 | 261 | x = self.up1(x) 262 | x = self.up2(x) 263 | x = self.up3(x) 264 | x = self.up4(x) 265 | 266 | x = self.drop(x) 267 | 268 | x = self.conv3(x) 269 | x = self.relu(x) 270 | 271 | x = self.upsample(x) 272 | return x 273 | -------------------------------------------------------------------------------- /midastouch/contrib/tdn_fcrn/tdn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | 4 | """ 5 | Tactile depth network that converts tactile images to heightmaps/masks via a fully convolutional residual networks [Laina et. al. 
2016] 6 | """ 7 | 8 | import torch 9 | import torch 10 | import torch.nn.functional 11 | from .fcrn import FCRN_net 12 | import numpy as np 13 | 14 | from midastouch.render.digit_renderer import digit_renderer 15 | from midastouch.viz.visualizer import Viz 16 | 17 | from PIL import Image 18 | import collections 19 | from midastouch.modules.misc import view_subplots, DIRS, get_device 20 | from midastouch.modules.pose import transform_pc, extract_poses_sim 21 | import cv2 22 | import os 23 | from os import path as osp 24 | import hydra 25 | from omegaconf import DictConfig 26 | 27 | 28 | class TDN: 29 | def __init__( 30 | self, 31 | cfg: DictConfig, 32 | bg: np.ndarray = None, 33 | bottleneck: bool = False, 34 | real: bool = False, 35 | ): 36 | 37 | tdn_weights = osp.join(DIRS["weights"], cfg.tdn_weights) 38 | 39 | fcrn_config = cfg.fcrn.real if real else cfg.fcrn.sim 40 | self.b, self.r, self.clip = ( 41 | fcrn_config.border, 42 | fcrn_config.ratio, 43 | fcrn_config.clip, 44 | ) 45 | self.batch_size = fcrn_config.batch_size 46 | self.params = {"batch_size": self.batch_size, "shuffle": False} 47 | 48 | self.model = FCRN_net(self.batch_size, bottleneck=bottleneck) 49 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 50 | checkpoint = torch.load(tdn_weights, map_location=self.device) 51 | self.model.load_state_dict(checkpoint["state_dict"]) 52 | self.model.eval() 53 | self.model.to(self.device) 54 | 55 | self.blend_sz = fcrn_config.blend_sz 56 | self.heightmap_window = collections.deque([]) 57 | if bg is not None: 58 | self.bg = torch.Tensor(bg).to(self.device) 59 | 60 | def blend_heightmaps(self, heightmap: torch.Tensor) -> torch.Tensor: 61 | """Exponentially weighted heightmap blending. 62 | 63 | Args: 64 | heightmap: input heightmap 65 | 66 | Returns: 67 | blended_heightmap: output heightmap blended over self.heightmap_window 68 | 69 | """ 70 | 71 | if not self.blend_sz: 72 | return heightmap 73 | 74 | if len(self.heightmap_window) >= self.blend_sz: 75 | self.heightmap_window.popleft() 76 | 77 | self.heightmap_window.append(heightmap) 78 | n = len(self.heightmap_window) 79 | 80 | weights = torch.tensor( 81 | [x / n for x in range(1, n + 1)], device=heightmap.device 82 | ) # exponentially weighted time series costs 83 | 84 | weights = torch.exp(weights) / torch.sum(torch.exp(weights)) 85 | 86 | all_heightmaps = torch.stack(list(self.heightmap_window)) 87 | blended_heightmap = torch.sum( 88 | (all_heightmaps * weights[:, None, None]) / weights.sum(), dim=0 89 | ) # weighted average 90 | 91 | # view_subplots([heightmap, blended_heightmap], [["heightmap", "blended_heightmap"]]) 92 | return blended_heightmap 93 | 94 | def image2heightmap(self, image: np.ndarray) -> torch.Tensor: 95 | """Passes tactile image through FCRN and returns (blended) heightmap 96 | 97 | Args: 98 | image: single tactile image 99 | 100 | Returns: 101 | blended_output: resulting heightmap from FCRN + blending 102 | 103 | """ 104 | 105 | assert ( 106 | self.model.bottleneck is False 107 | ), "Bottleneck feature is enabled, can't carry out image2heightmap" 108 | image = cv2.normalize(image, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX) 109 | with torch.no_grad(): 110 | image = torch.from_numpy(image).permute(2, 0, 1).to(self.device).float() 111 | output = self.model(image[None, :])[ 112 | 0 113 | ].squeeze() # .data.cpu().squeeze().numpy() 114 | blended_output = self.blend_heightmaps(output) 115 | return blended_output 116 | 117 | def image2embedding(self, image: np.ndarray) -> torch.Tensor: 
118 | """Passes tactile image through FCRN and returns bottleneck embedding of size 10 * 8 * 1024 119 | 120 | Args: 121 | image: single tactile image 122 | 123 | Returns: 124 | feature: feature tensor (10 * 8 * 1024, 1) 125 | 126 | """ 127 | 128 | if self.model.bottleneck is False: 129 | print("Bottleneck feature extraction not enabled, switching") 130 | self.model.bottleneck = True 131 | image = cv2.normalize(image, None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX) 132 | with torch.no_grad(): 133 | image = torch.from_numpy(image).permute(2, 0, 1).to(self.device).float() 134 | output = self.model(image[None, :])[0].squeeze() 135 | feature = output.reshape((-1, 10 * 8 * 1024)) 136 | feature = feature / torch.norm(feature, axis=1).reshape(-1, 1) 137 | return feature 138 | 139 | def heightmap2mask( 140 | self, heightmap: torch.tensor, small_parts: bool = False 141 | ) -> torch.Tensor: 142 | """Thresholds heightmap to return binary contact mask 143 | 144 | Args: 145 | heightmap: single tactile image 146 | 147 | Returns: 148 | padded_contact_mask: contact mask [True: is_contact, False: no_contact] 149 | 150 | """ 151 | heightmap = heightmap[self.b : -self.b, self.b : -self.b] 152 | init_height = self.bg[self.b : -self.b, self.b : -self.b] 153 | diff_heights = heightmap - init_height 154 | diff_heights[diff_heights < self.clip] = 0 155 | 156 | contact_mask = diff_heights > torch.quantile(diff_heights, 0.8) * self.r 157 | padded_contact_mask = torch.zeros_like(self.bg, dtype=bool) 158 | 159 | total_area = contact_mask.shape[0] * contact_mask.shape[1] 160 | atleast_area = 0.01 * total_area if small_parts else 0.1 * total_area 161 | 162 | if torch.count_nonzero(contact_mask) < atleast_area: 163 | return padded_contact_mask 164 | padded_contact_mask[self.b : -self.b, self.b : -self.b] = contact_mask 165 | return padded_contact_mask 166 | 167 | 168 | @hydra.main(config_path="../config", config_name="config") 169 | def main(cfg: DictConfig): 170 | expt_cfg, tcn_cfg, tdn_cfg = cfg.expt, cfg.tcn, cfg.tdn 171 | device = get_device(cpu=False) # get GPU 172 | 173 | obj_model = expt_cfg.obj_model 174 | log_id = str(expt_cfg.log_id).zfill(2) 175 | 176 | data_path = osp.join(DIRS["data"], "sim", obj_model, log_id) 177 | obj_path = osp.join(DIRS["obj_models"], obj_model, "nontextured.stl") 178 | 179 | image_path, pose_path = osp.join(data_path, "tactile_images"), osp.join( 180 | data_path, "tactile_data.pkl" 181 | ) 182 | heightmap_path, contactmask_path = osp.join(data_path, "gt_heightmaps"), osp.join( 183 | data_path, "gt_contactmasks" 184 | ) 185 | 186 | viz = Viz(off_screen=False) 187 | 188 | tac_render = digit_renderer(cfg=tdn_cfg.render, obj_path=obj_path) 189 | digit_tdn = TDN(tdn_cfg, bg=tac_render.get_background(frame="gel")) 190 | 191 | # load images and ground truth depthmaps 192 | images = sorted(os.listdir(image_path), key=lambda y: int(y.split(".")[0])) 193 | heightmaps = sorted(os.listdir(heightmap_path), key=lambda y: int(y.split(".")[0])) 194 | contact_masks = sorted( 195 | os.listdir(contactmask_path), key=lambda y: int(y.split(".")[0]) 196 | ) 197 | 198 | # poses 199 | camposes, gelposes, _ = extract_poses_sim( 200 | osp.join(data_path, "tactile_data.pkl"), device=device 201 | ) # poses : (N , 4, 4) 202 | 203 | N = len(images) 204 | for i in range(N): 205 | # Open images 206 | image = np.array(Image.open(osp.join(image_path, images[i]))) 207 | gt_heightmap = np.array( 208 | Image.open(osp.join(heightmap_path, heightmaps[i])) 209 | ).astype(np.int64) 210 | contactmask = np.array( 211 | 
Image.open(osp.join(contactmask_path, contact_masks[i])) 212 | ).astype(bool) 213 | 214 | # Convert image to heightmap via lookup 215 | est_heightmap = digit_tdn.image2heightmap(image) 216 | est_contactmask = digit_tdn.heightmap2mask(est_heightmap) 217 | # Get pixelwise RMSE in mm, and IoU of the contact masks 218 | error_heightmap = np.abs(est_heightmap - gt_heightmap) * tac_render.pixmm 219 | heightmap_rmse = np.sqrt(np.mean(error_heightmap**2)) 220 | intersection = np.sum(np.logical_and(contactmask, est_contactmask)) 221 | contact_mask_iou = intersection / ( 222 | np.sum(contactmask) + np.sum(est_contactmask) - intersection 223 | ) 224 | 225 | # Visualize heightmaps 226 | print( 227 | "Heightmap RMSE: {:.4f} mm, Contact mask IoU: {:.4f}".format( 228 | heightmap_rmse, contact_mask_iou 229 | ) 230 | ) 231 | view_subplots( 232 | [ 233 | image / 255.0, 234 | gt_heightmap, 235 | contactmask, 236 | est_heightmap, 237 | est_contactmask, 238 | error_heightmap, 239 | ], 240 | [ 241 | ["Tactile image", "GT heightmap", "GT contact mask"], 242 | ["Est. heightmap", "Est. contact mask", "Heightmap Error (mm"], 243 | ], 244 | ) 245 | 246 | # Convert heightmaps to 3D 247 | gt_cloud = tac_render.heightmap2Pointcloud(gt_heightmap, contactmask) 248 | gt_cloud_w = transform_pc(gt_cloud.copy(), camposes[i]) 249 | est_cloud = tac_render.heightmap2Pointcloud(est_heightmap, est_contactmask) 250 | est_cloud_w = transform_pc(est_cloud.copy(), camposes[i]) 251 | 252 | 253 | if __name__ == "__main__": 254 | main() 255 | -------------------------------------------------------------------------------- /midastouch/contrib/tdn_fcrn/test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | """ 4 | Loads real/sim test dataset and evaluates TDN accuracy 5 | """ 6 | 7 | import torch 8 | import torch.utils.data 9 | from .data_loader import real_data_loader, data_loader 10 | from .tdn import TDN 11 | import numpy as np 12 | import os 13 | from os import path as osp 14 | from torch.autograd import Variable 15 | import matplotlib 16 | 17 | matplotlib.use("Agg") 18 | import matplotlib.pyplot as plot 19 | from tqdm import tqdm 20 | from midastouch.render.digit_renderer import digit_renderer, pixmm 21 | from midastouch.modules.misc import DIRS 22 | import hydra 23 | from omegaconf import DictConfig 24 | 25 | dtype = torch.cuda.FloatTensor 26 | 27 | 28 | @hydra.main(config_path="config", config_name="test") 29 | def test(cfg: DictConfig): 30 | abspath = osp.abspath(__file__) 31 | dname = osp.dirname(abspath) 32 | os.chdir(dname) 33 | 34 | tac_render = digit_renderer(obj_path=None) 35 | digit_tdn = TDN(cfg.tdn, bg=tac_render.get_background(frame="gel")) 36 | results_path = osp.join(DIRS["debug"], "tdn_test") 37 | if not osp.exists(results_path): 38 | os.makedirs(results_path) 39 | 40 | if cfg.real: 41 | test_file = osp.join("data", "test_data_real.txt") 42 | test_loader = torch.utils.data.DataLoader( 43 | real_data_loader(test_file), batch_size=50, shuffle=False, drop_last=True 44 | ) 45 | 46 | # test on real dataset 47 | print("Testing on real data") 48 | with torch.no_grad(): 49 | count = 0 50 | pbar = tqdm(total=len(test_loader)) 51 | for input in test_loader: 52 | input_var = Variable(input.type(dtype)) 53 | for i in range(len(input_var)): 54 | input_rgb_image = ( 55 | input_var[i] 56 | .data.permute(1, 2, 0) 57 | .cpu() 58 | .numpy() 59 | .astype(np.uint8) 60 | ) 61 | est_h = digit_tdn.image2heightmap(input_rgb_image) 62 | est_c = 
digit_tdn.heightmap2mask(est_h) 63 | # pred_image /= np.max(pred_image) 64 | plot.imsave( 65 | osp.join(results_path, f"{count}_input.png"), input_rgb_image 66 | ) 67 | plot.imsave( 68 | osp.join(results_path, f"{count}_pred_heightmap.png"), 69 | est_h, 70 | cmap="viridis", 71 | ) 72 | plot.imsave(osp.join(results_path, f"{count}_pred_mask.png"), est_c) 73 | count += 1 74 | pbar.update(1) 75 | pbar.close() 76 | return 77 | else: 78 | test_file = osp.join("data", "test_data.txt") 79 | label_file = osp.join("data", "test_label.txt") 80 | test_loader = torch.utils.data.DataLoader( 81 | data_loader(test_file, label_file), 82 | batch_size=50, 83 | shuffle=False, 84 | drop_last=True, 85 | ) 86 | 87 | heightmap_rmse, contact_mask_iou = [], [] 88 | 89 | # test on real dataset 90 | print("Testing on sim data") 91 | with torch.no_grad(): 92 | count = 0 93 | pbar = tqdm(total=len(test_loader)) 94 | for input, depth in test_loader: 95 | input_var = Variable(input.type(dtype)) 96 | gt_var = Variable(depth.type(dtype)) 97 | 98 | for i in range(len(input_var)): 99 | input_rgb_image = ( 100 | input_var[i] 101 | .data.permute(1, 2, 0) 102 | .cpu() 103 | .numpy() 104 | .astype(np.uint8) 105 | ) 106 | gt_c = gt_var[i].data.squeeze().cpu().numpy().astype(np.float32) 107 | 108 | est_h = digit_tdn.image2heightmap(input_rgb_image) 109 | est_c = digit_tdn.heightmap2mask(est_h) 110 | 111 | error_heightmap = np.abs(est_h - gt_c) * pixmm 112 | heightmap_rmse.append(np.sqrt(np.mean(error_heightmap**2))) 113 | intersection = np.sum(np.logical_and(gt_c, est_c)) 114 | contact_mask_iou.append( 115 | intersection / (np.sum(est_c) + np.sum(gt_c) - intersection) 116 | ) 117 | count += 1 118 | pbar.update(1) 119 | pbar.close() 120 | 121 | heightmap_rmse = [x for x in heightmap_rmse if str(x) != "nan"] 122 | contact_mask_iou = [x for x in contact_mask_iou if str(x) != "nan"] 123 | heightmap_rmse = sum(heightmap_rmse) / len(heightmap_rmse) 124 | contact_mask_iou = sum(contact_mask_iou) / len(contact_mask_iou) 125 | error_file = open(osp.join(results_path, "tdn_error.txt"), "w") 126 | error_file.write(str(heightmap_rmse) + "," + str(contact_mask_iou) + "\n") 127 | error_file.close() 128 | 129 | 130 | if __name__ == "__main__": 131 | test() 132 | -------------------------------------------------------------------------------- /midastouch/contrib/tdn_fcrn/train.py: -------------------------------------------------------------------------------- 1 | # Original source: https://github.com/XPFly1989/FCRN 2 | # A Pytorch implementation of Laina, Iro, et al. "Deeper depth prediction with fully convolutional residual networks." 3 | # 3D Vision (3DV), 2016 Fourth International Conference on. IEEE, 2016. 4 | 5 | # Copyright (c) 2016, Iro Laina 6 | # All rights reserved. 7 | 8 | # Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 9 | # Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 10 | # Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 11 | 12 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 13 | 14 | # Copyright (c) Meta Platforms, Inc. All Rights Reserved 15 | 16 | """ 17 | Loads TACTO training data and trains the TDN 18 | """ 19 | 20 | import torch 21 | import torch.utils.data 22 | from midastouch.contrib.tdn_fcrn.data_loader import data_loader, real_data_loader 23 | import numpy as np 24 | import os 25 | from os import path as osp 26 | from torch.autograd import Variable 27 | import matplotlib 28 | 29 | matplotlib.use("Agg") 30 | import matplotlib.pyplot as plot 31 | from torch.utils.tensorboard import SummaryWriter 32 | from weights import load_weights 33 | from tqdm import tqdm 34 | from midastouch.modules.misc import DIRS 35 | from midastouch.contrib.tdn_fcrn.fcrn import FCRN_net 36 | import hydra 37 | from omegaconf import DictConfig 38 | import time 39 | 40 | dtype = torch.cuda.FloatTensor 41 | pixmm = 0.03 42 | 43 | 44 | @hydra.main(config_path="config", config_name="train") 45 | def main(cfg: DictConfig) -> None: 46 | abspath = osp.abspath(__file__) 47 | dname = osp.dirname(abspath) 48 | os.chdir(dname) 49 | 50 | batch_size, learning_rate, num_epochs = cfg.batch_size, cfg.lr, cfg.max_epochs 51 | resume_from_file = cfg.resume_from_file 52 | checkpoint_path = osp.join(DIRS["weights"], cfg.checkpoint_weights) 53 | checkpoint_save_path = osp.join( 54 | DIRS["weights"], time.strftime("%Y%m%d_%H") + "_" + cfg.checkpoint_weights 55 | ) 56 | print(f"Saving checkpoint to {checkpoint_save_path}") 57 | # momentum, weight_decay = 0.9, 0.0005 58 | 59 | print(f"Batch size: {batch_size}, Learning rate: {learning_rate}") 60 | results_path = osp.join(DIRS["debug"], "tdn_train") 61 | 62 | data_file_path = cfg.data_file_path 63 | train_data_file = osp.join(data_file_path, "train_data.txt") 64 | dev_data_file = osp.join(data_file_path, "dev_data.txt") 65 | train_label_file = osp.join(data_file_path, "train_label.txt") 66 | dev_label_file = osp.join(data_file_path, "dev_label.txt") 67 | test_data_file = osp.join(data_file_path, "test_data.txt") 68 | test_label_file = osp.join(data_file_path, "test_label.txt") 69 | 70 | print(f"Loading data, Resume training: {resume_from_file}") 71 | train_loader = torch.utils.data.DataLoader( 72 | data_loader(train_data_file, train_label_file), 73 | batch_size=batch_size, 74 | shuffle=True, 75 | drop_last=True, 76 | ) 77 | val_loader = torch.utils.data.DataLoader( 78 | data_loader(dev_data_file, dev_label_file), 79 | batch_size=batch_size, 80 | shuffle=False, 81 | drop_last=True, 82 | ) 83 | test_loader = torch.utils.data.DataLoader( 84 | data_loader(test_data_file, test_label_file), 85 | batch_size=batch_size, 86 | shuffle=False, 87 | drop_last=True, 88 | ) 89 | ## test with real data 90 | # test_real_file = osp.join(data_file_path,'test_data_real.txt') 91 | # test_loader = torch.utils.data.DataLoader(real_data_loader(test_real_file), batch_size=batch_size, shuffle=False, drop_last=True) 92 | 93 | print("Loading model...") 94 | model = FCRN_net(batch_size=batch_size) 95 | model = model.cuda() 96 | 97 | loss_fn = torch.nn.MSELoss().cuda() 98 | 99 | 
input_path = osp.join(results_path, "input") 100 | gt_path = osp.join(results_path, "gt") 101 | pred_path = osp.join(results_path, "pred") 102 | 103 | if not osp.exists(input_path): 104 | os.makedirs(input_path) 105 | if not osp.exists(gt_path): 106 | os.makedirs(gt_path) 107 | if not osp.exists(pred_path): 108 | os.makedirs(pred_path) 109 | 110 | writer = SummaryWriter("train_log") 111 | 112 | start_epoch = 0 113 | if resume_from_file: 114 | if os.path.isfile(checkpoint_path): 115 | print("=> loading checkpoint '{}'".format(checkpoint_path)) 116 | checkpoint = torch.load(checkpoint_path, map_location="cuda:0") 117 | start_epoch = checkpoint["epoch"] 118 | model.load_state_dict(checkpoint["state_dict"]) 119 | print( 120 | "=> loaded checkpoint '{}' (epoch {})".format( 121 | checkpoint_path, checkpoint["epoch"] 122 | ) 123 | ) 124 | else: 125 | print("=> no checkpoint found at '{}'".format(checkpoint_path)) 126 | else: 127 | # curl -O http://campar.in.tum.de/files/rupprecht/depthpred/NYU_ResNet-UpProj.npy 128 | weights_file = osp.join(DIRS["weights"], "NYU_ResNet-UpProj.npy") 129 | print("=> loading pre-trained NYU weights'{}'".format(weights_file)) 130 | model.load_state_dict(load_weights(model, weights_file, dtype)) 131 | 132 | # validate 133 | print("Validating on sim data") 134 | model.eval() 135 | num_samples, loss_local = 0, 0 136 | with torch.no_grad(): 137 | pbar = tqdm(total=len(val_loader)) 138 | for input, depth in val_loader: 139 | input_var = Variable(input.type(dtype)) 140 | gt_var = Variable(depth.type(dtype)) 141 | output = model(input_var) 142 | loss_local += loss_fn(output, gt_var) 143 | num_samples += 1 144 | pbar.update(1) 145 | pbar.close() 146 | 147 | best_val_err = np.sqrt(float(loss_local) / num_samples) 148 | print("Before train error: {:.3f} pixel RMSE".format(best_val_err)) 149 | 150 | for epoch in range(num_epochs): 151 | optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 152 | # optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum) 153 | # optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay) 154 | 155 | print("Starting train epoch %d / %d" % (start_epoch + epoch + 1, num_epochs)) 156 | model.train() 157 | running_loss, count, epoch_loss = 0, 0, 0 158 | 159 | pbar = tqdm(total=len(train_loader)) 160 | # for i, (input, depth) in enumerate(train_loader): 161 | for input, depth in train_loader: 162 | input_var = Variable(input.type(dtype)) 163 | gt_var = Variable(depth.type(dtype)) 164 | 165 | output = model(input_var) 166 | loss = loss_fn(output, gt_var) 167 | 168 | # print('loss:', loss.item()) 169 | # input_img = input_var.squeeze().permute(1, 2, 0).detach().cpu().numpy() 170 | # output_img = output.squeeze().detach().cpu().numpy() 171 | running_loss += loss.data.cpu().numpy() 172 | count += 1 173 | optimizer.zero_grad() 174 | loss.backward() 175 | optimizer.step() 176 | pbar.update(1) 177 | pbar.set_description( 178 | "RMSE pixel loss: {:.2f}".format(np.sqrt(running_loss / count)) 179 | ) 180 | pbar.close() 181 | 182 | # TODO: tensorboard 183 | epoch_loss = np.sqrt(running_loss / count) 184 | print("Epoch error: {:.3f} pixel RMSE".format(epoch_loss)) 185 | 186 | writer.add_scalar("train_loss", epoch_loss, start_epoch + epoch + 1) 187 | 188 | # validate 189 | print("Validating on sim data") 190 | model.eval() 191 | num_samples, loss_local = 0, 0 192 | with torch.no_grad(): 193 | pbar = tqdm(total=len(val_loader)) 194 | for input, depth in val_loader: 195 | 
input_var = Variable(input.type(dtype)) 196 | gt_var = Variable(depth.type(dtype)) 197 | 198 | output = model(input_var) 199 | loss_local += loss_fn(output, gt_var) 200 | num_samples += 1 201 | pbar.update(1) 202 | pbar.close() 203 | 204 | err = np.sqrt(float(loss_local) / num_samples) 205 | print( 206 | "Validation error: {:.3f} pixel RMSE, Best validation error: {:.3f} pixel RMSE".format( 207 | err, best_val_err 208 | ) 209 | ) 210 | writer.add_scalar("val_loss", err, start_epoch + epoch + 1) 211 | 212 | if err < best_val_err: 213 | print("Saving new checkpoint: {}".format(checkpoint_save_path)) 214 | best_val_err = err 215 | torch.save( 216 | { 217 | "epoch": start_epoch + epoch + 1, 218 | "state_dict": model.state_dict(), 219 | "optimizer": optimizer.state_dict(), 220 | }, 221 | checkpoint_save_path, 222 | ) 223 | else: 224 | learning_rate = learning_rate * 0.6 225 | print( 226 | "No reduction of validation error, dropping learning rate to {}".format( 227 | learning_rate 228 | ) 229 | ) 230 | 231 | if (epoch > 0) and (epoch % 10 == 0): 232 | learning_rate = learning_rate * 0.6 233 | print("10 epochs, dropping learning rate to {}".format(learning_rate)) 234 | 235 | print("Testing on sim data") 236 | model.eval() 237 | num_samples, loss_local = 0, 0 238 | # make local IoU 239 | with torch.no_grad(): 240 | pbar = tqdm(total=len(test_loader)) 241 | for input, depth in test_loader: 242 | input_var = Variable(input.type(dtype)) 243 | gt_var = Variable(depth.type(dtype)) 244 | 245 | output = model(input_var) 246 | 247 | if num_samples == 0: 248 | input_rgb_image = ( 249 | input_var[0] 250 | .data.permute(1, 2, 0) 251 | .cpu() 252 | .numpy() 253 | .astype(np.uint8) 254 | ) 255 | gt_image = gt_var[0].data.squeeze().cpu().numpy().astype(np.float32) 256 | pred_image = ( 257 | output[0].data.squeeze().cpu().numpy().astype(np.float32) 258 | ) 259 | gt_image /= np.max(gt_image) 260 | pred_image /= np.max(pred_image) 261 | 262 | plot.imsave( 263 | osp.join( 264 | input_path, 265 | "input_epoch_{}.png".format(start_epoch + epoch + 1), 266 | ), 267 | input_rgb_image, 268 | ) 269 | plot.imsave( 270 | osp.join( 271 | gt_path, "gt_epoch_{}.png".format(start_epoch + epoch + 1) 272 | ), 273 | gt_image, 274 | cmap="viridis", 275 | ) 276 | plot.imsave( 277 | osp.join( 278 | pred_path, 279 | "pred_epoch_{}.png".format(start_epoch + epoch + 1), 280 | ), 281 | pred_image, 282 | cmap="viridis", 283 | ) 284 | loss_local += loss_fn(output, gt_var) 285 | num_samples += 1 286 | pbar.update(1) 287 | pbar.close() 288 | err = np.sqrt(float(loss_local) / num_samples) * pixmm 289 | print(f"Test error: {err:.3f} mm RMSE") 290 | writer.add_scalar("test_loss", err, start_epoch + epoch + 1) 291 | writer.flush() 292 | writer.close() 293 | 294 | 295 | if __name__ == "__main__": 296 | main() 297 | -------------------------------------------------------------------------------- /midastouch/data_gen/README.md: -------------------------------------------------------------------------------- 1 | # Data generation: YCB-Slide and training corpus 2 | 3 | The subfolder consists of scripts to generate [TACTO](https://github.com/facebookresearch/tacto) simulated interactions with YCB-objects. These can be done in three ways: 4 | 5 | ## 1. Densely sample object meshes 6 | This is used to generate tactile training data, with DIGIT background and lighting augmentation. 7 | ```python 8 | python midastouch/data_gen/generate_data.py method=train_data 9 | ``` 10 |
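Since these configs are composed with Hydra under the `method` group, individual settings can also be overridden from the command line instead of editing the `.yaml` files. A small illustrative example (the override values and output folder below are arbitrary, not project defaults):

```python
# fewer touches per object, custom output folder, no background/lighting randomization
python midastouch/data_gen/generate_data.py method=train_data \
    method.num_samples=5000 method.save_path=./my_train_data method.render.randomize=false
```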
11 | <!-- figure: example output of the train_data sampling --> 12 | 13 |
14 | 15 | ## 2. Generate random sliding trajectories 16 | This is used to recreate the trajectories from the YCB-Slide dataset. 17 | 18 | ```python 19 | python midastouch/data_gen/generate_data.py method=ycb_slide 20 | ``` 21 | 22 |
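Each generated log stores its camera/gel poses and the pose-noise parameters in `tactile_data.pkl`. A minimal sketch for inspecting a log (the object name and path below are assumptions; point it at your own Hydra output folder):

```python
import dill as pickle  # touch_simulator.py pickles with dill

with open("ycb_slide/035_power_drill/tactile_data.pkl", "rb") as f:
    log = pickle.load(f)

camposes = log["camposes"]            # (N, 7) camera poses, xyz + quaternion
gelposes = log["gelposes"]            # (N, 7) ground-truth gel poses
gelposes_meas = log["gelposes_meas"]  # (N, 7) noise-corrupted gel poses
print(camposes.shape, log["mNoise"])
```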
23 | <!-- figure: example random sliding trajectory (YCB-Slide) --> 24 | 25 |
26 | 27 | ## 3. Manually generate custom sliding trajectories 28 | This is used for custom trajectories entirely defined by user generated waypoints. 29 | 30 | ```python 31 | python midastouch/data_gen/generate_data.py method=manual_slide 32 | ``` 33 | 34 |
35 | <!-- figure: example manually defined sliding trajectory --> 36 | 37 |
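Whichever method is chosen, each object log is written with the same on-disk layout (folder and file names as used in `touch_simulator.py`; the numbered `.jpg` naming is an assumption based on how the loaders read the data):

```
<save_path>/<obj_model>/
├── tactile_images/     # rendered DIGIT images (0.jpg, 1.jpg, ...)
├── gt_heightmaps/      # ground-truth heightmaps
├── gt_contactmasks/    # ground-truth contact masks
└── tactile_data.pkl    # camera/gel poses (ground truth + noisy) and noise params
```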
38 | 39 | --- 40 | 41 | ## Fine-grained settings 42 | 43 | Modify the hydra `.yaml` parameters in the `./config` folder for more fine-grained settings: 44 | 45 | ```yaml 46 | obj_class: ycb_test # ycb_test: 10 evaluation objects in the MidasTouch paper, ycb_train: 40 training objects for the depth network training 47 | obj_model: 035_power_drill 48 | sampling : traj # choose between (random, random+edges, traj, manual) 49 | save_path : ./ycb_slide # all data is saved in the hydra output folder 50 | noise: # noise to corrupt poses 51 | sig_r: 1 # degrees 52 | sig_t: 5e-4 # m 53 | 54 | num_samples : 2000 # number of interactions 55 | total_length : 0.5 # maximum path length for geodesic paths 56 | 57 | render: 58 | pixmm : 0.03 # pix to mm conversion for the DIGIT 59 | width : 240 # width of images 60 | height : 320 # height of images 61 | cam_dist : 0.022 # distance between gel and camera 62 | shear_mag : 5.0 # shear angle noise 63 | pen : 64 | min : 0.0005 # minimum penetration into object 65 | max : 0.001 # maximum penetration into object 66 | randomize : False # randomize the template backgrounds and lighting augmentation 67 | ``` 68 | 69 | --- 70 | ## Other notes 71 | 72 | - The generated data is saved in timestamped folders in the hydra `output` directory 73 | - To use this data to train your tactile depth network, refer to the `contrib/tdn_fcrn` folder 74 | - To use this data in filtering experiments, refer to the `filter/filter.py` script -------------------------------------------------------------------------------- /midastouch/data_gen/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/MidasTouch/9afb6ff72b837ab69f140695d38bf95c498b30c3/midastouch/data_gen/__init__.py -------------------------------------------------------------------------------- /midastouch/data_gen/config/config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | defaults: 7 | - method: train_data # train_data, manual_slide, ycb_slide -------------------------------------------------------------------------------- /midastouch/data_gen/config/method/manual_slide.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | obj_class: ycb_test 7 | obj_model: 035_power_drill 8 | sampling : manual # random, random_edges, traj, manual 9 | save_path : ./manual_slide 10 | noise: 11 | sig_r: 1 # degrees 12 | sig_t: 5e-4 # m 13 | 14 | num_samples : 2000 15 | total_length : 0.5 16 | 17 | render: 18 | pixmm : 0.03 19 | width : 240 20 | height : 320 21 | cam_dist : 0.022 22 | shear_mag : 5.0 23 | pen : 24 | min : 0.0005 25 | max : 0.001 26 | randomize : False -------------------------------------------------------------------------------- /midastouch/data_gen/config/method/train_data.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | obj_class: ycb_train 7 | obj_model: 035_power_drill 8 | sampling : random # random, random+edges, traj, manual 9 | save_path : ./train_data 10 | noise: 11 | sig_r: 1 # degrees 12 | sig_t: 5e-4 # m 13 | 14 | num_samples : 10000 15 | total_length : 0.5 16 | 17 | render: 18 | pixmm : 0.03 19 | width : 240 20 | height : 320 21 | cam_dist : 0.022 22 | shear_mag : 5.0 23 | pen : 24 | min : 0.0001 25 | max : 0.001 26 | randomize : True -------------------------------------------------------------------------------- /midastouch/data_gen/config/method/ycb_slide.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | obj_class: ycb_test 7 | obj_model: 035_power_drill 8 | sampling : traj # random, random+edges, traj, manual 9 | save_path : ./ycb_slide 10 | noise: 11 | sig_r: 1 # degrees 12 | sig_t: 5e-4 # m 13 | 14 | num_samples : 2000 15 | total_length : 0.5 16 | 17 | render: 18 | pixmm : 0.03 19 | width : 240 20 | height : 320 21 | cam_dist : 0.022 22 | shear_mag : 5.0 23 | pen : 24 | min : 0.0005 25 | max : 0.001 26 | randomize : False -------------------------------------------------------------------------------- /midastouch/data_gen/generate_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | Collect data from YCB objects, choose between [random, traj, manual] methods 8 | """ 9 | 10 | from midastouch.data_gen.touch_simulator import touch_simulator 11 | from midastouch.modules.objects import ycb_test, ycb_train 12 | import hydra 13 | from omegaconf import DictConfig 14 | 15 | 16 | @hydra.main(version_base=None, config_path="./config", config_name="config") 17 | def main(cfg: DictConfig): 18 | cfg = cfg.method 19 | obj_class = ycb_test if cfg.obj_class == "ycb_test" else ycb_train 20 | for obj_model in obj_class: 21 | cfg.obj_model = obj_model 22 | touch_simulator(cfg=cfg) 23 | 24 | 25 | if __name__ == "__main__": 26 | main() 27 | -------------------------------------------------------------------------------- /midastouch/data_gen/touch_simulator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | """ 7 | Simulates tactile interaction on object meshes, choose between [random, random+edges, traj, manual] 8 | """ 9 | 10 | import os 11 | from os import path as osp 12 | import numpy as np 13 | 14 | from midastouch.viz.helpers import viz_poses_pointclouds_on_mesh 15 | from midastouch.render.digit_renderer import digit_renderer 16 | from midastouch.modules.misc import ( 17 | remove_and_mkdir, 18 | DIRS, 19 | save_contactmasks, 20 | save_heightmaps, 21 | save_images, 22 | ) 23 | from midastouch.modules.mesh import sample_poses_on_mesh 24 | from midastouch.modules.pose import transform_pc, xyzquat_to_tf_numpy 25 | from midastouch.data_gen.utils import random_geodesic_poses, random_manual_poses 26 | import dill as pickle 27 | import trimesh 28 | import hydra 29 | from omegaconf import DictConfig 30 | from tqdm import tqdm 31 | 32 | 33 | def touch_simulator(cfg: DictConfig): 34 | """Tactile simulator function""" 35 | render_cfg = cfg.render 36 | obj_model = cfg.obj_model 37 | sampling = cfg.sampling 38 | num_samples = cfg.num_samples 39 | total_length = cfg.total_length 40 | save_path = cfg.save_path 41 | randomize = render_cfg.randomize 42 | headless = False 43 | 44 | # make paths 45 | if save_path is None: 46 | data_path = osp.join(DIRS["data"], "sim", obj_model) 47 | file_idx = 0 48 | while osp.exists( 49 | osp.join(data_path, str(file_idx).zfill(2), "tactile_data.pkl") 50 | ): 51 | file_idx += 1 52 | data_path = osp.join(data_path, str(file_idx).zfill(2)) 53 | else: 54 | data_path = osp.join(save_path, obj_model) 55 | 56 | remove_and_mkdir(data_path) 57 | 58 | image_path = osp.join(data_path, "tactile_images") 59 | heightmap_path = osp.join(data_path, "gt_heightmaps") 60 | contactmasks_path = osp.join(data_path, "gt_contactmasks") 61 | pose_path = osp.join(data_path, "tactile_data.pkl") 62 | 63 | os.makedirs(image_path) 64 | os.makedirs(heightmap_path) 65 | os.makedirs(contactmasks_path) 66 | 67 | print(f"object: {obj_model}, data_path: {data_path} sampling: {sampling}\n") 68 | 69 | obj_path = osp.join(DIRS["obj_models"], obj_model, "nontextured.stl") 70 | 71 | mesh = trimesh.load(obj_path) 72 | 73 | # get poses depending on the method 74 | if "random" in sampling: 75 | # random independent samples 76 | print(f"Generating {num_samples} random samples") 77 | sample_poses = sample_poses_on_mesh( 78 | mesh=mesh, 79 | num_samples=num_samples, 80 | edges=True if (sampling == "random+edges") else False, 81 | ) 82 | elif "traj" in sampling: 83 | # random geodesic trajectory 84 | print(f"Generating random geodesic trajectory") 85 | sample_poses = None 86 | while sample_poses is None: 87 | sample_poses = random_geodesic_poses( 88 | mesh, 89 | shear_mag=render_cfg.shear_mag, 90 | total_length=total_length, 91 | N=num_samples, 92 | ) 93 | elif "manual" in sampling: 94 | # manually selected waypoints trajectory 95 | print(f"Generating manual waypoint trajectory") 96 | sample_poses = random_manual_poses( 97 | mesh_path=obj_path, shear_mag=render_cfg.shear_mag, lc=0.001 98 | ) 99 | else: 100 | print("Invalid sampling routine, exiting!") 101 | return 102 | 103 | # start renderer 104 | tac_render = digit_renderer(cfg=render_cfg, obj_path=obj_path, headless=headless) 105 | 106 | # remove NaNs 107 | batch_size = 1000 108 | traj_sz = sample_poses.shape[0] 109 | num_batches = traj_sz // batch_size 110 | num_batches = num_batches if (num_batches != 0) else 1 111 | 112 | gelposes, camposes, gelposes_meas = ( 113 | np.empty((0, 7)), 114 | np.empty((0, 7)), 115 | np.empty((0, 7)), 116 | ) 117 | 
gt_heightmaps, gt_masks, tactile_images = [], [], [] 118 | for i in tqdm(range(num_batches)): 119 | if randomize: 120 | tac_render = digit_renderer( 121 | cfg=render_cfg, obj_path=obj_path, randomize=randomize, headless=headless 122 | ) 123 | i_range = ( 124 | np.array(range(i * batch_size, traj_sz)) 125 | if (i == num_batches - 1) 126 | else np.array(range(i * batch_size, (i + 1) * batch_size)) 127 | ) 128 | ( 129 | hm, 130 | cm, 131 | image, 132 | campose, 133 | gelpose, 134 | gelpose_meas, 135 | ) = tac_render.render_sensor_trajectory( 136 | p=sample_poses[i_range, :], mNoise=cfg.noise 137 | ) 138 | gelposes = np.append(gelposes, gelpose, axis=0) 139 | camposes = np.append(camposes, campose, axis=0) 140 | gelposes_meas = np.append(gelposes_meas, gelpose_meas, axis=0) 141 | tactile_images = tactile_images + image 142 | gt_heightmaps = gt_heightmaps + hm 143 | gt_masks = gt_masks + cm 144 | 145 | # Save ground-truth pointclouds and tactile images 146 | print( 147 | f"Saving data: \nHeightmaps: {heightmap_path} \nContact masks: {contactmasks_path} \nTactile Images: {image_path}" 148 | ) 149 | save_heightmaps(gt_heightmaps, heightmap_path) 150 | save_contactmasks(gt_masks, contactmasks_path) 151 | save_images(tactile_images, image_path) 152 | 153 | pointclouds, pointclouds_world = [None] * traj_sz, [None] * traj_sz 154 | for i, (h, c, p) in enumerate(zip(gt_heightmaps, gt_masks, camposes)): 155 | pointclouds[i] = tac_render.heightmap2Pointcloud(h, c) 156 | pointclouds_world = transform_pc(pointclouds, camposes) 157 | 158 | save_dict = { 159 | "gelposes": gelposes, 160 | "camposes": camposes, 161 | "gelposes_meas": gelposes_meas, 162 | "mNoise": cfg.noise, 163 | } 164 | 165 | print("Saving data to path: {}".format(pose_path)) 166 | with open(pose_path, "wb") as file: 167 | pickle.dump(save_dict, file) 168 | 169 | if not headless: 170 | viz_gelposes = xyzquat_to_tf_numpy(gelposes) 171 | viz_gelposes_meas = xyzquat_to_tf_numpy(gelposes_meas) 172 | 173 | print("Visualizing data") 174 | if len(pointclouds_world) > 2500: 175 | pointclouds_world = pointclouds_world[::10] 176 | viz_gelposes, viz_gelposes_meas = ( 177 | viz_gelposes[::10, :], 178 | viz_gelposes_meas[::10, :], 179 | ) 180 | 181 | viz_poses_pointclouds_on_mesh( 182 | mesh_path=obj_path, 183 | poses=viz_gelposes, 184 | pointclouds=pointclouds_world, 185 | save_path=osp.join(data_path, "tactile_data"), 186 | decimation_factor=10, 187 | ) 188 | viz_poses_pointclouds_on_mesh( 189 | mesh_path=obj_path, 190 | poses=viz_gelposes_meas, 191 | pointclouds=pointclouds_world, 192 | save_path=osp.join(data_path, "tactile_data_noisy"), 193 | decimation_factor=10, 194 | ) 195 | return 196 | 197 | 198 | @hydra.main(config_path="./config", config_name="config") 199 | def main(cfg: DictConfig): 200 | touch_simulator(cfg=cfg.method) 201 | 202 | 203 | if __name__ == "__main__": 204 | main() 205 | -------------------------------------------------------------------------------- /midastouch/data_gen/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
5 | 6 | """ 7 | Pose and mesh utilities for data generation 8 | """ 9 | 10 | import numpy as np 11 | import pyvista as pv 12 | import open3d as o3d 13 | import potpourri3d as pp3d 14 | import random 15 | from scipy.spatial import KDTree 16 | import time 17 | import math 18 | import trimesh 19 | from midastouch.modules.pose import ( 20 | pose_from_vertex_normal, 21 | tf_to_xyzquat, 22 | ) 23 | 24 | 25 | def get_geodesic_path(mesh_path: str, start_point, end_point): 26 | """ 27 | Get geodesic path along a mesh given start and end vertices (pyvista geodesic_distance) 28 | """ 29 | mesh = pv.read(mesh_path) 30 | start_point_idx = np.argmin(np.linalg.norm(start_point - mesh.points, axis=1)) 31 | end_point_idx = np.argmin(np.linalg.norm(end_point - mesh.points, axis=1)) 32 | path_pts = mesh.geodesic(start_point_idx, end_point_idx) 33 | path_distance = mesh.geodesic_distance(start_point_idx, end_point_idx) 34 | return path_pts.points, path_distance 35 | 36 | 37 | def random_geodesic_poses(mesh, shear_mag, total_length=0.5, N=2000): 38 | """Generate random points and compute geodesic trajectory""" 39 | 40 | cumm_length = 0.0 41 | num_waypoints = 1 42 | seg_length = total_length / float(num_waypoints) 43 | while seg_length > 0.25 * mesh.scale: 44 | num_waypoints += 1 45 | seg_length = total_length / float(num_waypoints) 46 | 47 | V, F = np.array(mesh.vertices), np.array(mesh.faces) 48 | print(f"num_waypoints: {num_waypoints}") 49 | solver = pp3d.MeshHeatMethodDistanceSolver(V, F) 50 | path_solver = pp3d.EdgeFlipGeodesicSolver( 51 | V, F 52 | ) # shares precomputation for repeated solves 53 | 54 | sample_points, sample_normals = np.empty((0, 3)), np.empty((0, 3)) 55 | seg_start = random.randint(0, V.shape[0]) 56 | 57 | waypoints = V[seg_start, None] 58 | start_time = time.time() 59 | tree = KDTree(mesh.vertices) 60 | 61 | for _ in range(num_waypoints): 62 | geo_dist = solver.compute_distance(seg_start) 63 | candidates = np.argsort(np.abs(geo_dist - seg_length)) 64 | 65 | waypoint_dist = np.linalg.norm(V[candidates, :, None] - waypoints.T, axis=1) 66 | waypoint_dist = np.amin(waypoint_dist, axis=1) 67 | mask = waypoint_dist < 0.01 68 | candidates = np.ma.MaskedArray(candidates, mask=mask) 69 | candidates = candidates.compressed() 70 | seg_end = candidates[0] 71 | seg_dist = geo_dist[seg_end] 72 | waypoints = np.concatenate((waypoints, V[seg_end, None]), axis=0) 73 | 74 | seg_points = path_solver.find_geodesic_path(v_start=seg_start, v_end=seg_end) 75 | # subsample path (spline) to 0.1mm per odom 76 | 77 | # length-based sampling 78 | # segmentSpline = pv.Spline(seg_points, n_segment) 79 | # segmentSpline = np.array(segmentSpline.points) 80 | 81 | _, ii = tree.query(seg_points, k=1) 82 | segmentNormals = mesh.vertex_normals[ii, :] 83 | 84 | sample_points = np.concatenate((sample_points, seg_points), axis=0) 85 | sample_normals = np.concatenate((sample_normals, segmentNormals), axis=0) 86 | 87 | cumm_length += seg_dist 88 | 89 | seg_start = seg_end 90 | if time.time() - start_time > 120: 91 | print("Timeout, trying again!") 92 | return None 93 | 94 | # interval-based sampling 95 | n_interval = math.ceil(len(sample_points) / N) 96 | n_interval = 1 if n_interval == 0 else n_interval 97 | sample_points = sample_points[::n_interval, :] 98 | sample_normals = sample_normals[::n_interval, :] 99 | 100 | delta = np.zeros(sample_points.shape[0]) 101 | 102 | a = 1 103 | for i in range(1, sample_points.shape[0]): 104 | if i % int(sample_points.shape[0] / num_waypoints) == 0: 105 | a = -a 106 | delta[i] = delta[i - 
1] + np.radians(np.random.normal(loc=a, scale=0.01)) 107 | T = pose_from_vertex_normal(sample_points, sample_normals, shear_mag, delta) 108 | print( 109 | f"Dataset path length: {cumm_length:.4f} m, Num poses: {sample_points.shape[0]}, Time taken: {time.time() - start_time}" 110 | ) 111 | return T 112 | 113 | 114 | def random_manual_poses(mesh_path, shear_mag, lc=0.001): 115 | """Pick points and sample trajectory""" 116 | 117 | mesh = trimesh.load(mesh_path) 118 | tree = KDTree(mesh.vertices) 119 | 120 | """Get points from user input""" 121 | cumm_length = 0 122 | traj_points = pick_points(mesh_path) 123 | 124 | """Generate point and normals""" 125 | if traj_points.shape[0] == 1: 126 | sample_points = traj_points 127 | else: 128 | sample_points, seg_dist = get_geodesic_path( 129 | mesh_path, traj_points[0, :], traj_points[1, :] 130 | ) 131 | n_segment = int(seg_dist / lc) 132 | n_interval = int(len(sample_points) / n_segment) 133 | n_interval = 1 if (n_interval == 0) else n_interval 134 | sample_points = sample_points[::n_interval, :] 135 | cumm_length += seg_dist 136 | _, ii = tree.query(sample_points, k=1) 137 | sample_normals = mesh.vertex_normals[ii, :] 138 | 139 | for i in range(1, traj_points.shape[0] - 1): 140 | temp, seg_dist = get_geodesic_path( 141 | mesh_path, traj_points[i, :], traj_points[i + 1, :] 142 | ) 143 | n_segment = int(seg_dist / lc) 144 | n_interval = int(len(temp) / n_segment) 145 | temp = temp[::n_interval, :] 146 | sample_points = np.concatenate((sample_points, temp), axis=0) 147 | cumm_length += seg_dist 148 | _, ii = tree.query(temp, k=1) 149 | sample_normals = np.concatenate( 150 | (sample_normals, mesh.vertex_normals[ii, :]), axis=0 151 | ) 152 | 153 | """visualize path""" 154 | # path_visual = trimesh.load_path(sample_points) 155 | # scene = trimesh.Scene([path_visual, mesh]) 156 | # scene.show() 157 | 158 | """Convert point and normals to poses""" 159 | # varying delta over trajectory 160 | delta = np.zeros(sample_points.shape[0]) 161 | a = 1 162 | for i in range(1, sample_points.shape[0]): 163 | if i % int(sample_points.shape[0] / 5) == 0: 164 | a = -a 165 | delta[i] = delta[i - 1] + np.radians(np.random.normal(loc=a, scale=0.01)) 166 | 167 | T = pose_from_vertex_normal(sample_points, sample_normals, shear_mag, delta) 168 | print( 169 | f"Dataset path length: {cumm_length:.4f} m, Num poses: {sample_points.shape[0]}" 170 | ) 171 | 172 | return T 173 | 174 | 175 | def pick_points(mesh_path): 176 | """ 177 | http://www.open3d.org/docs/latest/tutorial/visualization/interactive_visualization.html 178 | """ 179 | print("") 180 | print("1) Please pick waypoints using [shift + left click]") 181 | print(" Press [shift + right click] to undo point picking") 182 | print("2) After picking points, press 'Q' to close the window") 183 | mesh = o3d.io.read_triangle_mesh(mesh_path) 184 | 185 | pcd = o3d.geometry.PointCloud() 186 | pcd.points = o3d.utility.Vector3dVector(mesh.vertices) 187 | vis = o3d.visualization.VisualizerWithEditing() 188 | vis.create_window() 189 | vis.add_geometry(pcd) 190 | vis.run() # user picks points 191 | vis.destroy_window() 192 | return np.asarray(pcd.points)[vis.get_picked_points(), :] 193 | -------------------------------------------------------------------------------- /midastouch/eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/MidasTouch/9afb6ff72b837ab69f140695d38bf95c498b30c3/midastouch/eval/__init__.py 
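`get_geodesic_path` is also handy on its own when debugging trajectories. A small usage sketch, with a placeholder mesh path and two arbitrary vertices chosen as waypoints:

```python
# Illustrative only: compute one geodesic segment with the helper above.
import trimesh
from midastouch.data_gen.utils import get_geodesic_path

mesh_path = "nontextured.stl"                    # placeholder object model
mesh = trimesh.load(mesh_path)
start_point, end_point = mesh.vertices[0], mesh.vertices[-1]

path_points, path_length = get_geodesic_path(mesh_path, start_point, end_point)
print(f"{len(path_points)} points, geodesic length {path_length:.4f} (mesh units)")
```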
-------------------------------------------------------------------------------- /midastouch/eval/compute_contact_area.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | Compute estimated average contact area per-object 8 | """ 9 | 10 | import os 11 | from os import path as osp 12 | 13 | from midastouch.render.digit_renderer import digit_renderer 14 | from midastouch.contrib.tdn_fcrn import TDN 15 | from midastouch.modules.misc import DIRS, load_images 16 | import tqdm as tqdm 17 | import numpy as np 18 | 19 | import hydra 20 | from omegaconf import DictConfig 21 | 22 | 23 | def compute_contact_area(cfg: DictConfig, real=False): 24 | expt_cfg, tdn_cfg = cfg.expt, cfg.tdn 25 | obj_model = expt_cfg.obj_model 26 | 27 | # make paths 28 | if real: 29 | data_path = osp.join(DIRS["data"], "sim", obj_model) 30 | else: 31 | data_path = osp.join(DIRS["data"], "real", obj_model) 32 | 33 | print(f"compute_contact_area \n Object: {obj_model}\n") 34 | 35 | all_datasets = sorted(os.listdir(data_path)) 36 | print(f"datasets: {all_datasets}") 37 | obj_path = osp.join(DIRS["obj_models"], obj_model, "nontextured.stl") 38 | 39 | tac_render = digit_renderer(cfg=tdn_cfg.render, obj_path=obj_path) 40 | digit_tdn = TDN(tdn_cfg, bg=tac_render.get_background(frame="gel")) 41 | 42 | for dataset in all_datasets: 43 | if dataset == "bg" or not osp.isdir(osp.join(data_path, dataset)): 44 | continue 45 | 46 | dataset_path = osp.join(data_path, dataset) 47 | if real: 48 | image_path = osp.join(dataset_path, "frames") 49 | else: 50 | image_path = osp.join(dataset_path, "tactile_images") 51 | 52 | images = load_images(image_path) 53 | 54 | traj_sz = len(images) 55 | 56 | digit_area_cm_sq = 0.02 * 0.03 * (10**4) 57 | pbar = tqdm(total=traj_sz) 58 | areas = [] 59 | for j, image in enumerate(images): 60 | est_h = digit_tdn.image2heightmap(image) 61 | est_c = digit_tdn.heightmap2mask(est_h) 62 | ratio = est_c.sum() / est_c.size 63 | areas.append(digit_area_cm_sq * ratio) 64 | pbar.update(1) 65 | pbar.close() 66 | 67 | avg_area = np.vstack(areas).mean() 68 | print(f"avg_contact_area: {avg_area}") 69 | np.save(osp.join(dataset_path, "avg_contact_area.npy"), avg_area) 70 | return 71 | 72 | 73 | @hydra.main(config_path="../config", config_name="config") 74 | def main(cfg: DictConfig): 75 | compute_contact_area(cfg=cfg) 76 | 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /midastouch/eval/compute_surface_area.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | Compute surface area of object v.s. 
sensor size 8 | """ 9 | 10 | from os import path as osp 11 | import trimesh 12 | from midastouch.modules.misc import DIRS 13 | from midastouch.modules.objects import ycb_test 14 | 15 | 16 | def compute_surface_area(obj_model): 17 | obj_path = osp.join(DIRS["obj_models"], obj_model, "nontextured.stl") 18 | mesh = trimesh.load(obj_path) 19 | 20 | mesh_area_cm_sq = mesh.area * (10**4) 21 | digit_area_cm_sq = 0.02 * 0.03 * (10**4) 22 | ratio = mesh_area_cm_sq / digit_area_cm_sq 23 | print(f"{obj_model} surface area: {mesh_area_cm_sq:.3f}, ratio: {ratio:.1f}") 24 | return 25 | 26 | 27 | if __name__ == "__main__": 28 | obj_models = ycb_test 29 | for obj_model in obj_models: 30 | compute_surface_area(obj_model) 31 | -------------------------------------------------------------------------------- /midastouch/eval/decimate_meshes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """Downsample meshes for faster rendering""" 7 | 8 | from os import path as osp 9 | import os 10 | import pyvista as pv 11 | import trimesh 12 | from midastouch.modules.misc import DIRS 13 | 14 | obj_paths = osp.join(DIRS["obj_models"]) 15 | objects = sorted(os.listdir(obj_paths)) 16 | for object in objects: 17 | stl_path = osp.join(obj_paths, object, "nontextured.stl") 18 | mesh_trimesh = trimesh.load(stl_path) 19 | mesh_pv_deci = pv.wrap( 20 | mesh_trimesh.simplify_quadratic_decimation( 21 | face_count=int(mesh_trimesh.vertices.shape[0] / 10) 22 | ) 23 | ) # decimated pyvista object 24 | stl_path = stl_path.replace("nontextured", "nontextured_decimated") 25 | print(f"Saving: {stl_path}") 26 | mesh_pv_deci.save(stl_path) 27 | -------------------------------------------------------------------------------- /midastouch/eval/single_touch_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
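For context on the ratios printed by `compute_surface_area` above: the DIGIT gel footprint assumed throughout the repo is 20 mm × 30 mm (6 cm²), so the ratio is simply how many sensor-sized patches tile the object's surface. A quick worked check with a made-up object area:

```python
# Hypothetical numbers; only the 20 mm x 30 mm sensor footprint comes from the code above.
digit_area_cm_sq = 0.02 * 0.03 * (10**4)       # 6 cm^2 gel footprint
mesh_area_m_sq = 0.05                          # placeholder object surface area (m^2)
mesh_area_cm_sq = mesh_area_m_sq * (10**4)     # 500 cm^2
print(f"ratio: {mesh_area_cm_sq / digit_area_cm_sq:.1f}")   # ~83 sensor footprints
```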
5 | 6 | """Generate statistic for single-shot localization of objects""" 7 | 8 | import os 9 | from os import path as osp 10 | import numpy as np 11 | 12 | from midastouch.modules.objects import ycb_test 13 | from midastouch.modules.misc import change_to_dir, DIRS 14 | from midastouch.viz.helpers import viz_embedding_TSNE 15 | import dill as pickle 16 | from sklearn.metrics.pairwise import cosine_similarity 17 | 18 | from midastouch.tactile_tree.tactile_tree import R3_SE3 19 | import seaborn as sns 20 | import pandas as pd 21 | import matplotlib 22 | import matplotlib.pyplot as plt 23 | from tqdm import tqdm 24 | 25 | plt.rc("pdf", fonttype=42) 26 | plt.rc("ps", fonttype=42) 27 | plt.rc("font", family="serif") 28 | 29 | plt.rc("xtick", labelsize="small") 30 | plt.rc("ytick", labelsize="small") 31 | 32 | NUM_NEIGHBORS = 25 33 | 34 | 35 | def top_n_error(embeddings, poses, n=NUM_NEIGHBORS): 36 | """ 37 | Get best embedding error from top-N poses 38 | """ 39 | N = poses.shape[0] 40 | 41 | top_n_error = np.zeros(N) 42 | 43 | batch_size = 5000 44 | num_batches = N // batch_size 45 | num_batches = 1 if num_batches == 0 else num_batches 46 | 47 | C = np.zeros((N, N)) 48 | for i in tqdm(range(num_batches)): 49 | i_range = ( 50 | np.array(range(i * batch_size, N)) 51 | if (i == num_batches - 1) 52 | else np.array(range(i * batch_size, (i + 1) * batch_size)) 53 | ) 54 | for j in range(num_batches): 55 | j_range = ( 56 | np.array(range(j * batch_size, N)) 57 | if (j == num_batches - 1) 58 | else np.array(range(j * batch_size, (j + 1) * batch_size)) 59 | ) 60 | C[i_range[:, None], j_range] = cosine_similarity( 61 | np.atleast_2d(embeddings[i_range, :]), 62 | np.atleast_2d(embeddings[j_range, :]), 63 | ).squeeze() 64 | 65 | np.fill_diagonal(C, 0) 66 | for i in range(N): 67 | best_idxs = np.argpartition(C[i, :], -(n))[-n:] 68 | predicted_poses = poses[best_idxs, :] 69 | gt_pose = poses[i, :] 70 | e_t = np.linalg.norm(predicted_poses - gt_pose, axis=1) 71 | top_n_error[i] = np.min(e_t) 72 | 73 | return top_n_error 74 | 75 | 76 | def get_random_error(poses, n=NUM_NEIGHBORS): 77 | """ 78 | Get random pose error 79 | """ 80 | N = poses.shape[0] 81 | 82 | rand_error = np.zeros(N) 83 | for i in range(N): 84 | pred_idxs = np.random.choice(N, size=n) 85 | predicted_poses = poses[pred_idxs, :] 86 | gt_pose = poses[i, :] 87 | e_t = np.linalg.norm(predicted_poses - gt_pose, axis=1) 88 | rand_error[i] = np.min(e_t) 89 | return np.mean(rand_error) 90 | 91 | 92 | def plot_violin(df, method="pointcloud"): 93 | """ 94 | Image or pointcloud violin plot 95 | """ 96 | 97 | change_to_dir(osp.abspath(__file__)) 98 | 99 | results_path = "single_touch" 100 | if not os.path.exists(results_path): 101 | os.makedirs(results_path) 102 | 103 | df = pd.read_pickle(osp.join(results_path, f"error_{method}.pkl")) 104 | 105 | savepath = osp.join(results_path, f"violin_{method}.pdf") 106 | 107 | fig = plt.figure() 108 | 109 | df = df.sort_values(by=["median_error"], ascending=True) 110 | # sns.set_theme(style="whitegrid") 111 | palette = sns.color_palette("vlag", n_colors=len(df["key"].unique())) 112 | sns.violinplot( 113 | x="key", 114 | y="error", 115 | data=df, 116 | palette=palette, 117 | cut=0, 118 | gridsize=10000, 119 | saturation=1, 120 | linewidth=0.5, 121 | ) 122 | # df.reset_index(level=0, inplace=True) 123 | # ax = sns.lineplot(x = "key", y = "median_error", data=df, color="black", legend=False, linewidth=0.5) 124 | # ax.lines[0].set_linestyle("--") 125 | 126 | plt.xlabel("YCB object models", fontsize=12) 127 | 
plt.ylabel(f"Normalized Top-{NUM_NEIGHBORS} pose error", fontsize=12) 128 | 129 | ax = plt.gca() 130 | 131 | plt.ylim([0, 1.5]) 132 | plt.axhline(y=1.0, linestyle="--", linewidth=0.3, color=(0, 0, 0, 0.75)) 133 | 134 | figure = plt.gcf() 135 | figure.set_size_inches(12, 4) 136 | plt.savefig(savepath, transparent=True, bbox_inches="tight", pad_inches=0) 137 | print("saved to ", savepath) 138 | plt.close() 139 | 140 | 141 | def plot_split_violin(): 142 | """ 143 | Image and pointcloud violin split plot 144 | """ 145 | 146 | print("Plotting split violin") 147 | change_to_dir(osp.abspath(__file__)) 148 | 149 | matplotlib.use("TkAgg") 150 | 151 | results_path = "single_touch" 152 | if not os.path.exists(results_path): 153 | os.makedirs(results_path) 154 | 155 | cloud_df = pd.read_pickle(osp.join(results_path, "error_cloud.pkl")) 156 | image_df = pd.read_pickle(osp.join(results_path, "error_image.pkl")) 157 | 158 | get_overall_median(cloud_df, method="pointcloud") 159 | get_overall_median(image_df, method="image") 160 | 161 | savepath = osp.join(results_path, "violin_split.pdf") 162 | 163 | df = pd.DataFrame() 164 | df = df.append(cloud_df) 165 | df = df.append(image_df) 166 | 167 | fig = plt.figure() 168 | # df = df.sort_values(by=["median_error"], ascending=True) 169 | # sns.set_theme(style="whitegrid") 170 | # palette = sns.color_palette("vlag", n_colors = len(df["key"].unique())) 171 | muted_pal = sns.color_palette("colorblind") 172 | my_pal = {"cloud": muted_pal[0], "image": muted_pal[-1]} 173 | ax = sns.violinplot( 174 | x="key", 175 | y="error", 176 | data=df, 177 | hue="method", 178 | split=True, 179 | inner="quart", 180 | linewidth=0.5, 181 | palette=my_pal, 182 | cut=0, 183 | gridsize=1000, 184 | saturation=1, 185 | ) 186 | ax.legend_.remove() 187 | 188 | sns.despine(left=True) 189 | 190 | plt.xlabel("YCB object models", fontsize=12) 191 | plt.ylabel(f"Normalized Top-{NUM_NEIGHBORS} pose error", fontsize=12) 192 | 193 | ax = plt.gca() 194 | 195 | plt.ylim([0, 1.5]) 196 | plt.axhline(y=1.0, linestyle="--", linewidth=0.3, color=(0, 0, 0, 0.75)) 197 | 198 | figure = plt.gcf() 199 | figure.set_size_inches(12, 4) 200 | plt.savefig(savepath, transparent=True, bbox_inches="tight", pad_inches=0) 201 | print("saved to ", savepath) 202 | plt.close() 203 | 204 | return 205 | 206 | 207 | def benchmark_embeddings(obj_models, method="pointcloud"): 208 | """ 209 | Benchmark embedding error 210 | """ 211 | 212 | change_to_dir(osp.abspath(__file__)) 213 | 214 | results_path = "single_touch" 215 | if not os.path.exists(results_path): 216 | os.makedirs(results_path) 217 | 218 | df = pd.DataFrame() 219 | error_file = open(osp.join(results_path, f"error_{method}.txt"), "w") 220 | 221 | for obj_model in obj_models: 222 | 223 | if method == "image": 224 | pickle_path = osp.join(DIRS["trees"], obj_model, "image_codebook.pkl") 225 | else: 226 | pickle_path = osp.join(DIRS["trees"], obj_model, "codebook.pkl") 227 | 228 | print(f"Loading tree {pickle_path}") 229 | with open(pickle_path, "rb") as pickle_file: 230 | tactile_tree = pickle.load(pickle_file) 231 | 232 | poses, _ = tactile_tree.get_poses() 233 | poses = R3_SE3(poses) 234 | # poses = poses[:, :3] 235 | 236 | embeddings = tactile_tree.get_embeddings() 237 | obj_path = osp.join(DIRS["obj_models"], obj_model, "nontextured.stl") 238 | print(f"Getting top-{NUM_NEIGHBORS} {method} error") 239 | 240 | error_n = top_n_error(embeddings, poses, n=NUM_NEIGHBORS) 241 | # error_1 = top_n_error(embeddings, poses, n = 1) 242 | random_error = get_random_error(poses, 
n=NUM_NEIGHBORS) 243 | 244 | # normalized pose error 245 | # error_1 /= random_error 246 | error_n /= random_error 247 | 248 | print(error_n) 249 | 250 | topN = pd.DataFrame( 251 | { 252 | "error": error_n.tolist(), 253 | "median_error": [np.median(error_n)] * len(error_n), 254 | "key": [obj_model[:3]] * len(error_n), 255 | "method": [method] * len(error_n), 256 | } 257 | ) 258 | df = df.append(topN) 259 | 260 | print( 261 | f"{obj_model} : Median norm. pose RMSE Top {NUM_NEIGHBORS}: {np.median(error_n):.4f}" 262 | ) 263 | error_file.write( 264 | f"{obj_model} : Median norm. pose RMSE Top {NUM_NEIGHBORS}: {np.median(error_n):.4f}\n" 265 | ) 266 | 267 | viz_embedding_TSNE( 268 | mesh_path=obj_path, 269 | samples=poses.copy(), 270 | clusters=error_n, 271 | save_path=osp.join(results_path, f"{obj_model}_cloud_error"), 272 | nPoints=None, 273 | radius_factor=50.0, 274 | ) 275 | 276 | df.to_pickle(osp.join(results_path, f"error_{method}.pkl")) 277 | error_file.close() 278 | return df 279 | 280 | 281 | def get_overall_median(df, method="pointcloud"): 282 | all_median = df["median_error"].median() 283 | print(f"Overall median error (method: {method}) = {all_median}") 284 | return 285 | 286 | 287 | if __name__ == "__main__": 288 | config_file = "scripts/midastouch/config/config.ini" 289 | df = benchmark_embeddings(ycb_test, method="pointcloud") 290 | plot_violin(df=df) 291 | # plot_split_violin() 292 | -------------------------------------------------------------------------------- /midastouch/eval/viz_codebook.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | Visualize TSNE of tactile embeddings 8 | """ 9 | 10 | from os import path as osp 11 | from midastouch.viz.helpers import viz_embedding_TSNE 12 | import dill as pickle 13 | from midastouch.modules.misc import DIRS, get_device, confusion_matrix, color_tsne 14 | from midastouch.modules.objects import ycb_test 15 | 16 | 17 | def viz_codebook(obj_model): 18 | print("model: ", obj_model) 19 | 20 | device = get_device(cpu=False) 21 | 22 | tree_path = osp.join(DIRS["trees"], obj_model, "codebook.pkl") 23 | obj_path = osp.join(DIRS["obj_models"], obj_model, "nontextured.stl") 24 | 25 | codebook = pickle.load(open(tree_path, "rb")) 26 | codebook.to_device(device) 27 | 28 | poses, _ = codebook.get_poses() 29 | embeddings = codebook.get_embeddings() 30 | sz = len(codebook) 31 | print("Visualize tree of size: {}".format(sz)) 32 | 33 | # euclidean TSNE is proportional to cosine distance is the features are normalized, 34 | # so we can skip the confusion matrix computation 35 | print(f"Generating feature embedding scores {embeddings.shape[1]}") 36 | if embeddings.shape[1] > 256: 37 | C = confusion_matrix(embeddings.detach().cpu().numpy(), sz) 38 | TSNE = color_tsne(C, "pca") 39 | else: 40 | TSNE = color_tsne(embeddings.detach().cpu().numpy(), "pca") 41 | 42 | print("Viz. 
TSNE") 43 | viz_embedding_TSNE( 44 | mesh_path=obj_path, 45 | samples=poses.detach().cpu().numpy(), 46 | clusters=TSNE, 47 | save_path=osp.join(DIRS["trees"], obj_model, f"tsne_{sz}"), 48 | nPoints=500, 49 | radius_factor=80.0, 50 | off_screen=False, 51 | ) 52 | return 53 | 54 | 55 | if __name__ == "__main__": 56 | obj_models = ycb_test 57 | for obj_model in obj_models: 58 | viz_codebook(obj_model) 59 | -------------------------------------------------------------------------------- /midastouch/filter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/MidasTouch/9afb6ff72b837ab69f140695d38bf95c498b30c3/midastouch/filter/__init__.py -------------------------------------------------------------------------------- /midastouch/filter/filter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | Run MidasTouch on simulated YCB-Slide data 8 | """ 9 | 10 | import os 11 | from os import path as osp 12 | import numpy as np 13 | import torch 14 | from midastouch.modules.particle_filter import particle_filter, particle_rmse 15 | from midastouch.modules.misc import ( 16 | DIRS, 17 | remove_and_mkdir, 18 | get_time, 19 | load_images, 20 | get_device, 21 | images_to_video, 22 | ) 23 | from midastouch.modules.pose import extract_poses_sim 24 | import dill as pickle 25 | 26 | import yappi 27 | import hydra 28 | from hydra.utils import get_original_cwd 29 | from omegaconf import DictConfig, OmegaConf 30 | 31 | from midastouch.viz.visualizer import Viz 32 | from midastouch.render.digit_renderer import digit_renderer 33 | from midastouch.contrib.tdn_fcrn.tdn import TDN 34 | from midastouch.contrib.tcn_minkloc.tcn import TCN 35 | from midastouch.modules.objects import ycb_test 36 | import time 37 | from tqdm import tqdm 38 | 39 | import threading 40 | 41 | 42 | def filter(cfg: DictConfig, viz: Viz) -> None: 43 | """Filtering for tactile simulation data""" 44 | expt_cfg, tcn_cfg, tdn_cfg = cfg.expt, cfg.tcn, cfg.tdn 45 | 46 | device = get_device(cpu=False) 47 | 48 | # print('\n----------------------------------------\n') 49 | # print(OmegaConf.to_yaml(cfg)) 50 | # print('----------------------------------------\n') 51 | 52 | init_particles = expt_cfg.params.num_particles 53 | obj_model = expt_cfg.obj_model 54 | small_parts = False if obj_model in ycb_test else True 55 | log_id = str(expt_cfg.log_id).zfill(2) 56 | 57 | noise_ratio = expt_cfg.params.noise_ratio 58 | frame_rate = expt_cfg.frame_rate 59 | 60 | # Results saved in "output" folder 61 | results_path = osp.join(os.getcwd(), obj_model, log_id) 62 | trial_id = 0 63 | while osp.exists(osp.join(results_path, f"trial_{str(trial_id).zfill(2)}")): 64 | trial_id += 1 65 | results_path = osp.join(results_path, f"trial_{str(trial_id).zfill(2)}") 66 | if expt_cfg.ablation: 67 | results_path = osp.join(results_path, f"{noise_ratio}") 68 | remove_and_mkdir(results_path) 69 | 70 | # Loading data 71 | print("Loading dataset...") 72 | data_path = osp.join(DIRS["data"], "sim", obj_model, log_id) 73 | gt_p_cam, gt_p, meas_p = extract_poses_sim( 74 | osp.join(data_path, "tactile_data.pkl"), device=device 75 | ) # poses : (N , 4, 4) 76 | image_path = osp.join(data_path, "tactile_images") 77 | tactile_images = load_images(image_path, 
N=expt_cfg.max_length) # (N, 3, H, W) 78 | traj_size = len(tactile_images) 79 | 80 | # Init pf and rendering classes 81 | obj_path = osp.join(DIRS["obj_models"], obj_model, "nontextured.stl") 82 | pf = particle_filter(cfg, obj_path, noise_ratio) 83 | tac_render = digit_renderer(cfg=tdn_cfg.render, obj_path=obj_path) 84 | 85 | digit_tcn = TCN(tcn_cfg) 86 | digit_tdn = TDN(tdn_cfg, bg=tac_render.get_background(frame="gel")) 87 | 88 | # load tactile codebook 89 | codebok_path = osp.join(DIRS["trees"], obj_model, "codebook.pkl") 90 | codebook = pickle.load(open(codebok_path, "rb")) 91 | codebook.to_device(device) 92 | heatmap_poses, _ = codebook.get_poses() 93 | heatmap_embeddings = codebook.get_embeddings() 94 | 95 | pbar = tqdm(total=traj_size, desc="processing") 96 | timer = dict.fromkeys(["tactile", "motion", "meas"]) 97 | avg_timer = {"tactile": [], "motion": [], "meas": []} 98 | 99 | filter_stats = { 100 | "rmse_t": [], 101 | "rmse_r": [], 102 | "time": [], 103 | "traj_size": traj_size, 104 | "avg_time": None, 105 | "total_time": 0, 106 | "cluster_poses": [], 107 | "cluster_stds": [], 108 | "obj_name": obj_model, 109 | "tree_size": len(codebook), 110 | "noise_ratio": noise_ratio, 111 | "init_noise": pf.init_noise, 112 | "init_particles": init_particles, 113 | "num_particles": [], 114 | "log_id": log_id, 115 | "trial_id": trial_id, 116 | } 117 | msg = "" 118 | 119 | pbar.set_description(msg + "Opening visualizer...") 120 | if viz: 121 | viz.init_variables( 122 | obj_model=obj_model, 123 | mesh_path=obj_path, 124 | gt_pose=gt_p, 125 | n_particles=init_particles, 126 | ) 127 | 128 | prev_idx, count = 0, 0 129 | 130 | # run filter 131 | while True: 132 | while viz.pause: 133 | time.sleep(0.01) 134 | current_time = filter_stats["total_time"] 135 | idx = int(frame_rate * current_time) 136 | diff = idx - prev_idx 137 | 138 | if idx >= traj_size: 139 | break 140 | image = tactile_images[idx] 141 | 142 | start_time = time.time() 143 | # image to heightmap 144 | heightmap = digit_tdn.image2heightmap(image) # expensive 145 | mask = digit_tdn.heightmap2mask(heightmap, small_parts=small_parts) 146 | # heightmap to code 147 | tactile_code = digit_tcn.cloud_to_tactile_code(tac_render, heightmap, mask) 148 | timer["tactile"] = get_time(start_time) 149 | 150 | # motion model 151 | start_time = time.time() 152 | if prev_idx > 0: 153 | # t > 0 Propagate motion model 154 | odom = torch.inverse(meas_p[prev_idx, :]) @ meas_p[idx, :] # noisy 155 | particles = pf.motionModel(particles, odom, multiplier=1.0) 156 | timer["motion"] = get_time(start_time) 157 | else: 158 | # t = 0 Intialize particles 159 | particles = pf.init_filter(gt_p[idx, :], init_particles) 160 | particles.poses, _, _ = codebook.SE3_NN(particles.poses) 161 | timer["motion"] = get_time(start_time) 162 | 163 | # compute RMSE 164 | rmse_t, rmse_r = particle_rmse(particles, gt_p[idx, :]) 165 | filter_stats["rmse_t"].append(rmse_t.item()) 166 | filter_stats["rmse_r"].append(rmse_r.item()) 167 | 168 | # get similarity from codebook 169 | start_time = time.time() 170 | _, _, nn_tactile_codes = codebook.SE3_NN(particles.poses) 171 | particles.weights = pf.get_similarity( 172 | tactile_code, nn_tactile_codes, softmax=True 173 | ) 174 | 175 | # prune drifted particles 176 | particles, drifted = pf.remove_invalid_particles(particles) 177 | if drifted: 178 | pbar.set_description("All particles have drifted, re-projecting to surface") 179 | particles.poses, _, _ = codebook.SE3_NN(particles.poses) 180 | 181 | # cluster particles 182 | if count % 50 == 0: 
183 | particles = pf.cluster_particles(particles) 184 | cluster_poses, cluster_stds = pf.get_cluster_centers( 185 | particles, method="quat_avg" 186 | ) 187 | 188 | # anneal and resample 189 | particles = pf.annealing(particles, torch.mean(cluster_stds)) 190 | particles = pf.resampler(particles) 191 | 192 | # save stats 193 | timer["meas"] = get_time(start_time) 194 | filter_stats["cluster_poses"].append(cluster_poses) 195 | filter_stats["cluster_stds"].append(cluster_stds) 196 | filter_stats["num_particles"].append(len(particles)) 197 | 198 | iteration_time = sum(timer.values()) 199 | filter_stats["time"].append(iteration_time) 200 | 201 | msg = ( 202 | f'[RMSE: {1000 * filter_stats["rmse_t"][-1]:.1f} mm, {filter_stats["rmse_r"][-1]:.0f} deg, ' 203 | f"{len(cluster_stds)} cluster(s) with {torch.mean(cluster_stds):.3f} sigma, P: {len(particles)}, " 204 | f'rate: {(1.0/filter_stats["time"][-1]):.2f} Hz] ' 205 | ) 206 | 207 | for key in timer: 208 | avg_timer[key].append(timer[key]) 209 | 210 | if viz is not None: 211 | # Update visualizer 212 | pbar.set_description(msg + " Visualizing results") 213 | heatmap_weights = pf.get_similarity( 214 | tactile_code, heatmap_embeddings, softmax=False 215 | ) 216 | viz.update( 217 | particles, 218 | cluster_poses, 219 | cluster_stds, 220 | gt_p_cam[idx, :], 221 | heatmap_poses, 222 | heatmap_weights, 223 | image, 224 | heightmap, 225 | mask, 226 | idx, 227 | image_savepath=osp.join(results_path, f"{idx}.png"), 228 | ) 229 | 230 | prev_idx = idx 231 | count += 1 232 | filter_stats["total_time"] = sum(filter_stats["time"]) 233 | pbar.update(diff) 234 | 235 | # save stats and data 236 | if viz is not None: 237 | pbar.set_description("End of sequence: saving data") 238 | viz.close() 239 | 240 | for key in avg_timer: 241 | avg_timer[key] = np.average(avg_timer[key]) 242 | 243 | filter_stats["avg_time"] = sum(filter_stats["time"]) / len(filter_stats["time"]) 244 | print( 245 | f"Total time: {filter_stats['total_time']:.3f}, Per iteration time: {filter_stats['avg_time']:.3f}" 246 | ) 247 | print( 248 | f'Avg time: tactile: {avg_timer["tactile"]:.2f}, motion : {avg_timer["motion"]:.2f}, meas : {avg_timer["meas"]:.2f} ' 249 | ) 250 | 251 | print("---------------------------------------------------------\n\n") 252 | np.save(osp.join(results_path, "filter_stats.npy"), filter_stats) 253 | pbar.set_description("Generating video from images") 254 | images_to_video(results_path) # convert saved images to .mp4 255 | pbar.close() 256 | return 257 | 258 | 259 | @hydra.main(config_path="../config", config_name="config") 260 | def main(cfg: DictConfig, viz=None, profile=False): 261 | 262 | if profile: 263 | yappi.set_clock_type("wall") # profiling 264 | yappi.start(builtins=True) 265 | 266 | if cfg.expt.render: 267 | viz = Viz(off_screen=cfg.expt.off_screen, zoom=1.0, window_size=0.25) 268 | 269 | t = threading.Thread(name="filter", target=filter, args=(cfg, viz)) 270 | t.start() 271 | if viz: 272 | viz.plotter.app.exec_() 273 | t.join() 274 | 275 | if profile: 276 | stats = yappi.get_func_stats() 277 | stats.save(osp.join(get_original_cwd(), "filter.prof"), type="pstat") 278 | 279 | 280 | if __name__ == "__main__": 281 | main() 282 | -------------------------------------------------------------------------------- /midastouch/filter/filter_real.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
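The motion model in the filter above is driven by the relative odometry between consecutive (noisy) measured poses, `inverse(T_prev) @ T_curr`. A stand-alone sketch of that composition; the 1 mm slide is just an illustrative input:

```python
# Sketch of the odometry increment fed to pf.motionModel in filter.py.
import torch

def relative_odometry(T_prev: torch.Tensor, T_curr: torch.Tensor) -> torch.Tensor:
    """Relative SE(3) transform between two 4x4 homogeneous poses."""
    return torch.inverse(T_prev) @ T_curr

T_prev = torch.eye(4)
T_curr = torch.eye(4)
T_curr[:3, 3] = torch.tensor([0.001, 0.0, 0.0])   # hypothetical 1 mm slide along x
print(relative_odometry(T_prev, T_curr))          # the slide reappears in the relative pose
```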
2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | Run MidasTouch on real YCB-Slide data 8 | """ 9 | 10 | import os 11 | from os import path as osp 12 | import numpy as np 13 | import torch 14 | from midastouch.modules.particle_filter import particle_filter, particle_rmse 15 | from midastouch.modules.misc import ( 16 | DIRS, 17 | remove_and_mkdir, 18 | get_time, 19 | load_images, 20 | get_device, 21 | images_to_video, 22 | ) 23 | 24 | from midastouch.modules.pose import ( 25 | extract_poses_real, 26 | euler_angles_to_matrix, 27 | ) 28 | import dill as pickle 29 | 30 | import yappi 31 | import hydra 32 | from hydra.utils import get_original_cwd 33 | from omegaconf import DictConfig, OmegaConf 34 | 35 | from midastouch.viz.visualizer import Viz 36 | from midastouch.render.digit_renderer import digit_renderer 37 | from midastouch.contrib.tdn_fcrn.tdn import TDN 38 | from midastouch.contrib.tcn_minkloc.tcn import TCN 39 | from midastouch.modules.objects import ycb_test 40 | import time 41 | from tqdm import tqdm 42 | 43 | import threading 44 | 45 | update_freq = 1 46 | 47 | 48 | def filter_real(cfg: DictConfig, viz: Viz) -> None: 49 | """Filtering for tactile real-world data""" 50 | expt_cfg, tcn_cfg, tdn_cfg = cfg.expt, cfg.tcn, cfg.tdn 51 | 52 | device = get_device(cpu=False) 53 | 54 | # print('\n----------------------------------------\n') 55 | # print(OmegaConf.to_yaml(cfg)) 56 | # print('----------------------------------------\n') 57 | 58 | init_particles = expt_cfg.params.num_particles 59 | obj_model = expt_cfg.obj_model 60 | log_id = f"dataset_{expt_cfg.log_id}" 61 | 62 | noise_ratio = expt_cfg.params.noise_ratio 63 | frame_rate = expt_cfg.frame_rate 64 | 65 | # Results saved in "output" folder 66 | results_path = osp.join(os.getcwd(), obj_model, log_id) 67 | trial_id = 0 68 | while osp.exists(osp.join(results_path, f"trial_{str(trial_id).zfill(2)}")): 69 | trial_id += 1 70 | results_path = osp.join(results_path, f"trial_{str(trial_id).zfill(2)}") 71 | if expt_cfg.ablation: 72 | results_path = osp.join(results_path, f"{noise_ratio}") 73 | remove_and_mkdir(results_path) 74 | 75 | tree_path = osp.join(DIRS["trees"], obj_model, "codebook.pkl") 76 | 77 | # Loading data 78 | print("Loading dataset...") 79 | data_path = osp.join(DIRS["data"], "real", obj_model, log_id) 80 | subsample = 2 81 | gt_p_cam, gt_p = extract_poses_real( 82 | pose_file=osp.join(data_path, "synced_data.npy"), 83 | alignment_file=osp.join(data_path, "..", "alignment.npy"), 84 | obj_model=obj_model, 85 | device=device, 86 | subsample=subsample, 87 | ) # poses : (N , 4, 4) 88 | 89 | image_path = osp.join(data_path, "frames") 90 | tactile_images = load_images(image_path, N=expt_cfg.max_length) # (N, 3, H, W) 91 | 92 | traj_size = gt_p_cam.shape[0] 93 | tactile_images = tactile_images[::subsample] 94 | 95 | # Init pf and rendering classes 96 | obj_path = osp.join(DIRS["obj_models"], obj_model, "nontextured.stl") 97 | pf = particle_filter(cfg, obj_path, noise_ratio, real=True) 98 | tac_render = digit_renderer(cfg=tdn_cfg.render, obj_path=obj_path) 99 | 100 | digit_tcn = TCN(tcn_cfg) 101 | digit_tdn = TDN(tdn_cfg, bg=tac_render.get_background(frame="gel"), real=True) 102 | 103 | # load tactile codebook 104 | codebook = pickle.load(open(tree_path, "rb")) 105 | codebook.to_device(device) 106 | heatmap_poses, _ = codebook.get_poses() 107 | heatmap_embeddings = codebook.get_embeddings() 108 | 109 | pbar = tqdm(total=traj_size, 
desc="processing") 110 | timer = dict.fromkeys(["tactile", "motion", "meas"]) 111 | avg_timer = {"tactile": [], "motion": [], "meas": []} 112 | 113 | filter_stats = { 114 | "rmse_t": [], 115 | "rmse_r": [], 116 | "time": [], 117 | "traj_size": traj_size, 118 | "avg_time": None, 119 | "total_time": 0, 120 | "cluster_poses": [], 121 | "cluster_stds": [], 122 | "obj_name": obj_model, 123 | "tree_size": len(codebook), 124 | "noise_ratio": noise_ratio, 125 | "init_noise": pf.init_noise, 126 | "init_particles": init_particles, 127 | "num_particles": [], 128 | "log_id": log_id, 129 | "trial_id": trial_id, 130 | } 131 | msg = "" 132 | 133 | # add measurement noise 134 | mNoise = None 135 | if mNoise is not None and gt_p.shape[0] > 1: 136 | N = gt_p.shape[0] 137 | tn = torch.normal( 138 | mean=0.0, 139 | std=mNoise["sig_t"], 140 | size=(N, 3), 141 | ).to(gt_p.device) 142 | rotNoise = torch.normal( 143 | mean=0.0, 144 | std=mNoise["sig_r"], 145 | size=(N, 3), 146 | ).to(gt_p.device) 147 | Rn = euler_angles_to_matrix(torch.deg2rad(rotNoise), "ZYX") 148 | Tn = torch.zeros_like(gt_p) 149 | Tn[:, :3, :3], Tn[:, :3, 3], Tn[:, 3, 3] = Rn, tn, 1 150 | meas_p = gt_p @ Tn 151 | else: 152 | meas_p = gt_p 153 | 154 | pbar.set_description(msg + "Opening visualizer...") 155 | if viz: 156 | viz.init_variables( 157 | obj_model=obj_model, 158 | mesh_path=obj_path, 159 | gt_pose=gt_p, 160 | n_particles=init_particles, 161 | ) 162 | 163 | prev_idx, count = 0, 0 164 | 165 | # run filter 166 | while True: 167 | while viz.pause: 168 | time.sleep(0.01) 169 | current_time = filter_stats["total_time"] 170 | idx = int(frame_rate * current_time) 171 | diff = idx - prev_idx 172 | 173 | if idx >= traj_size: 174 | break 175 | image = tactile_images[idx] 176 | 177 | start_time = time.time() 178 | # image to heightmap 179 | heightmap = digit_tdn.image2heightmap(image) # expensive 180 | mask = digit_tdn.heightmap2mask(heightmap) 181 | # heightmap to code 182 | tactile_code = digit_tcn.cloud_to_tactile_code(tac_render, heightmap, mask) 183 | timer["tactile"] = get_time(start_time) 184 | 185 | # motion model 186 | start_time = time.time() 187 | if prev_idx > 0: 188 | # t > 0 Propagate motion model 189 | odom = torch.inverse(meas_p[prev_idx, :]) @ meas_p[idx, :] # noisy 190 | particles = pf.motionModel(particles, odom, multiplier=1.0) 191 | timer["motion"] = get_time(start_time) 192 | else: 193 | # t = 0 Intialize particles 194 | particles = pf.init_filter(gt_p[idx, :], init_particles) 195 | particles.poses, _, _ = codebook.SE3_NN(particles.poses) 196 | timer["motion"] = get_time(start_time) 197 | 198 | # compute RMSE 199 | rmse_t, rmse_r = particle_rmse(particles, gt_p[idx, :]) 200 | filter_stats["rmse_t"].append(rmse_t.item()) 201 | filter_stats["rmse_r"].append(rmse_r.item()) 202 | 203 | start_time = time.time() 204 | # apply measurement model every N frames 205 | if count % update_freq == 0: 206 | # get similarity from codebook 207 | _, _, nn_tactile_codes = codebook.SE3_NN(particles.poses) 208 | particles.weights = pf.get_similarity( 209 | tactile_code, nn_tactile_codes, softmax=False 210 | ) 211 | else: 212 | particles.weights = torch.ones(len(particles), device=device) 213 | 214 | # prune drifted particles 215 | particles, drifted = pf.remove_invalid_particles(particles) 216 | if drifted: 217 | pbar.set_description("All particles have drifted, re-projecting to surface") 218 | particles.poses, _, _ = codebook.SE3_NN(particles.poses) 219 | 220 | # cluster particles 221 | if count % 50 == 0: 222 | particles = 
pf.cluster_particles(particles) 223 | cluster_poses, cluster_stds = pf.get_cluster_centers( 224 | particles, method="quat_avg" 225 | ) 226 | 227 | # anneal and resample 228 | particles = pf.annealing(particles, torch.mean(cluster_stds), floor=10000) 229 | particles = pf.resampler(particles) 230 | 231 | # save stats 232 | timer["meas"] = get_time(start_time) 233 | filter_stats["cluster_poses"].append(cluster_poses) 234 | filter_stats["cluster_stds"].append(cluster_stds) 235 | filter_stats["num_particles"].append(len(particles)) 236 | 237 | iteration_time = sum(timer.values()) 238 | filter_stats["time"].append(iteration_time) 239 | 240 | msg = ( 241 | f'[RMSE: {1000 * filter_stats["rmse_t"][-1]:.1f} mm, {filter_stats["rmse_r"][-1]:.0f} deg, ' 242 | f"{len(cluster_stds)} cluster(s) with {torch.mean(cluster_stds):.3f} sigma, P: {len(particles)}, " 243 | f'rate: {(1.0/filter_stats["time"][-1]):.2f} Hz] ' 244 | ) 245 | 246 | for key in timer: 247 | avg_timer[key].append(timer[key]) 248 | 249 | if viz is not None: 250 | # Update visualizer 251 | pbar.set_description(msg + " Visualizing results") 252 | heatmap_weights = pf.get_similarity( 253 | tactile_code, heatmap_embeddings, softmax=False 254 | ) 255 | viz.update( 256 | particles, 257 | cluster_poses, 258 | cluster_stds, 259 | gt_p_cam[idx, :], 260 | heatmap_poses, 261 | heatmap_weights, 262 | image, 263 | heightmap, 264 | mask, 265 | idx, 266 | image_savepath=osp.join(results_path, f"{idx}.png"), 267 | ) 268 | 269 | prev_idx = idx 270 | count += 1 271 | filter_stats["total_time"] = sum(filter_stats["time"]) 272 | pbar.update(diff) 273 | 274 | # save stats and data 275 | if viz is not None: 276 | pbar.set_description("End of sequence: saving data") 277 | viz.close() 278 | 279 | for key in avg_timer: 280 | avg_timer[key] = np.average(avg_timer[key]) 281 | 282 | filter_stats["avg_time"] = sum(filter_stats["time"]) / len(filter_stats["time"]) 283 | print( 284 | f"Total time: {filter_stats['total_time']:.3f}, Per iteration time: {filter_stats['avg_time']:.3f}" 285 | ) 286 | print( 287 | f'Avg time: tactile: {avg_timer["tactile"]:.2f}, motion : {avg_timer["motion"]:.2f}, meas : {avg_timer["meas"]:.2f} ' 288 | ) 289 | 290 | print("---------------------------------------------------------\n\n") 291 | np.save(osp.join(results_path, "filter_stats.npy"), filter_stats) 292 | pbar.set_description("Generating video from images") 293 | images_to_video(results_path) # convert saved images to .mp4 294 | pbar.close() 295 | return 296 | 297 | 298 | @hydra.main(config_path="../config", config_name="config") 299 | def main(cfg: DictConfig, viz=None, profile=False): 300 | 301 | if profile: 302 | yappi.set_clock_type("wall") # profiling 303 | yappi.start(builtins=True) 304 | 305 | if cfg.expt.render: 306 | viz = Viz(off_screen=cfg.expt.off_screen, zoom=1.0, window_size=0.25) 307 | 308 | t = threading.Thread(name="filter_real", target=filter_real, args=(cfg, viz)) 309 | t.start() 310 | if viz: 311 | viz.plotter.app.exec_() 312 | t.join() 313 | 314 | if profile: 315 | stats = yappi.get_func_stats() 316 | stats.save(osp.join(get_original_cwd(), "filter_real.prof"), type="pstat") 317 | 318 | 319 | if __name__ == "__main__": 320 | main() 321 | -------------------------------------------------------------------------------- /midastouch/filter/live_demo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
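`filter_real` can optionally perturb the ground-truth poses before using them as odometry (the `mNoise` block above). The same idea as a small self-contained helper; `sig_r` is in degrees (converted with `deg2rad`, as in the script) and `sig_t` is in the pose's translation units:

```python
# Sketch of the measurement-noise injection used in filter_real.py.
import torch
from midastouch.modules.pose import euler_angles_to_matrix

def perturb_poses(gt_p: torch.Tensor, sig_t: float, sig_r: float) -> torch.Tensor:
    """Right-multiply each (N, 4, 4) pose by a random body-frame SE(3) perturbation."""
    N = gt_p.shape[0]
    tn = torch.normal(mean=0.0, std=sig_t, size=(N, 3), device=gt_p.device)
    rot_noise = torch.normal(mean=0.0, std=sig_r, size=(N, 3), device=gt_p.device)
    Rn = euler_angles_to_matrix(torch.deg2rad(rot_noise), "ZYX")
    Tn = torch.zeros_like(gt_p)
    Tn[:, :3, :3], Tn[:, :3, 3], Tn[:, 3, 3] = Rn, tn, 1
    return gt_p @ Tn
```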
2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """Main script for particle filtering with tactile embeddings""" 7 | 8 | from os import path as osp 9 | import torch 10 | from midastouch.modules.particle_filter import particle_filter 11 | from midastouch.modules.misc import ( 12 | DIRS, 13 | get_device, 14 | ) 15 | import dill as pickle 16 | 17 | import hydra 18 | import sys 19 | 20 | from omegaconf import DictConfig 21 | 22 | from midastouch.viz.demo_visualizer import Viz 23 | from midastouch.render.digit_renderer import digit_renderer 24 | from midastouch.contrib.tdn_fcrn.tdn import TDN 25 | from midastouch.contrib.tcn_minkloc.tcn import TCN 26 | from midastouch.modules.objects import ycb_test 27 | import time 28 | from tqdm import tqdm 29 | from digit_interface import Digit, DigitHandler 30 | 31 | import threading 32 | 33 | """Initialize the DIGIT capture""" 34 | 35 | 36 | def connectDigit(resolution="QVGA"): 37 | try: 38 | connected_digit = DigitHandler.list_digits()[0] 39 | except: 40 | print("No DIGIT found!") 41 | sys.exit(1) 42 | digit = Digit(connected_digit["serial"]) # Unique serial number 43 | digit.connect() 44 | digit.set_resolution(Digit.STREAMS[resolution]) 45 | digit.set_fps(30) 46 | print(digit.info()) 47 | # print("Collecting data from DIGIT {}".format(digit.serial)) 48 | return digit 49 | 50 | 51 | def live_demo(cfg: DictConfig, viz: Viz) -> None: 52 | """ 53 | Filtering pose for simulated data 54 | """ 55 | 56 | digit = connectDigit() 57 | expt_cfg, tcn_cfg, tdn_cfg = cfg.expt, cfg.tcn, cfg.tdn 58 | 59 | device = get_device(cpu=False) # get GPU 60 | 61 | # print('\n----------------------------------------\n') 62 | # print(OmegaConf.to_yaml(cfg)) 63 | # print('----------------------------------------\n') 64 | 65 | obj_model = expt_cfg.obj_model 66 | small_parts = False if obj_model in ycb_test else True 67 | 68 | tree_path = osp.join(DIRS["trees"], obj_model, "tree.pkl") 69 | obj_path = osp.join(DIRS["obj_models"], obj_model, "nontextured.stl") 70 | 71 | pf = particle_filter(cfg, obj_path, 1.0, real=True) 72 | tac_render = digit_renderer(cfg=tdn_cfg.render, obj_path=obj_path) 73 | 74 | digit_tcn = TCN(tcn_cfg) 75 | digit_tdn = TDN(tdn_cfg, bg=tac_render.get_background(frame="gel"), real=True) 76 | codebook = pickle.load(open(tree_path, "rb")) 77 | codebook.to_device(device) 78 | heatmap_poses, _ = codebook.get_poses() 79 | heatmap_embeddings = codebook.get_embeddings() 80 | 81 | viz.init_variables(mesh_path=obj_path) 82 | 83 | count = 0 84 | for _ in tqdm(range(10)): 85 | # grab a few frames for stability (10 secs) 86 | time.sleep(0.1) 87 | digit.get_frame() 88 | 89 | while True: 90 | image = digit.get_frame() 91 | image = image[:, :, ::-1] # BGR -> RGB 92 | if count == 0: 93 | for _ in range(20): 94 | digit_tdn.bg = digit_tdn.image2heightmap(image) 95 | 96 | ### 1. 
TDN + TCN: convert image to heightmap and compress to tactile_code 97 | heightmap = digit_tdn.image2heightmap(image) # expensive 98 | mask = digit_tdn.heightmap2mask(heightmap, small_parts=small_parts) 99 | # view_subplots([image/255.0, heightmap.detach().cpu().numpy(), mask.detach().cpu().numpy()], [["image", "heightmap", "mask"]]) 100 | 101 | tactile_code = digit_tcn.cloud_to_tactile_code(tac_render, heightmap, mask) 102 | 103 | cluster_poses, cluster_stds = None, None 104 | if not torch.sum(mask): 105 | heatmap_weights = torch.zeros(heatmap_embeddings.shape[0]) 106 | else: 107 | heatmap_weights = pf.get_similarity( 108 | tactile_code, heatmap_embeddings, softmax=True 109 | ) 110 | 111 | viz.update( 112 | heatmap_poses, 113 | heatmap_weights, 114 | cluster_poses, 115 | cluster_stds, 116 | image, 117 | heightmap, 118 | mask, 119 | count, 120 | ) 121 | 122 | count += 1 123 | return 124 | 125 | 126 | @hydra.main(config_path="../config", config_name="config") 127 | def main(cfg: DictConfig, viz=None): 128 | if cfg.expt.render: 129 | viz = Viz(off_screen=cfg.expt.off_screen, zoom=1.0) 130 | t = threading.Thread(name="live_demo", target=live_demo, args=(cfg, viz)) 131 | t.start() 132 | if viz: 133 | viz.plotter.app.exec_() 134 | t.join() 135 | 136 | 137 | if __name__ == "__main__": 138 | main() 139 | -------------------------------------------------------------------------------- /midastouch/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/MidasTouch/9afb6ff72b837ab69f140695d38bf95c498b30c3/midastouch/modules/__init__.py -------------------------------------------------------------------------------- /midastouch/modules/mesh.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
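The live demo depends on a physically connected DIGIT; the capture boilerplate can be exercised on its own before launching the full filter. A minimal sketch using only the `digit_interface` calls that appear in `connectDigit` and the demo loop:

```python
# Connect to the first DIGIT found and grab a few frames (the camera returns BGR).
import sys
from digit_interface import Digit, DigitHandler

digits = DigitHandler.list_digits()
if not digits:
    print("No DIGIT found!")
    sys.exit(1)

digit = Digit(digits[0]["serial"])
digit.connect()
digit.set_resolution(Digit.STREAMS["QVGA"])
digit.set_fps(30)

for _ in range(10):
    frame = digit.get_frame()
    image = frame[:, :, ::-1]      # BGR -> RGB, as in live_demo.py
print(image.shape)
```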
5 | 6 | """ 7 | Mesh processing utilities 8 | """ 9 | 10 | import numpy as np 11 | import trimesh 12 | import pyvista as pv 13 | from scipy.spatial import KDTree 14 | from midastouch.modules.pose import pose_from_vertex_normal 15 | from typing import Tuple 16 | 17 | # mesh utils 18 | def sample_mesh( 19 | mesh: trimesh.base.Trimesh, num_samples: int, method: str = "even" 20 | ) -> Tuple[np.ndarray, np.ndarray]: 21 | """ 22 | Sample mesh and return point/normals 23 | """ 24 | sampled_points, faces = np.empty((0, 3)), np.array([], dtype=int) 25 | # https://github.com/mikedh/trimesh/issues/558 : trimesh.sample.sample_surface_even gives wrong number of samples 26 | while True: 27 | if method == "even": 28 | sP, f = trimesh.sample.sample_surface_even(mesh, count=num_samples) 29 | else: 30 | sP, f = trimesh.sample.sample_surface(mesh, count=num_samples) 31 | sampled_points = np.vstack([sampled_points, sP]) 32 | faces = np.concatenate([faces, f]) 33 | if len(sampled_points) <= num_samples: 34 | continue 35 | else: 36 | sampled_points, faces = sampled_points[:num_samples, :], faces[:num_samples] 37 | break 38 | 39 | sampled_normals = mesh.face_normals[faces, :] 40 | sampled_normals = sampled_normals / np.linalg.norm(sampled_normals, axis=1).reshape( 41 | -1, 1 42 | ) 43 | return sampled_points, sampled_normals 44 | 45 | 46 | def extract_edges( 47 | mesh: trimesh.base.Trimesh, num_samples: int 48 | ) -> Tuple[np.ndarray, np.ndarray, int]: 49 | """ 50 | Extract mesh edges via pyvista 51 | """ 52 | mesh = pv.wrap(mesh) 53 | edges = mesh.extract_feature_edges(10) 54 | edges.compute_normals(inplace=True) # this activates the normals as well 55 | 56 | tree = KDTree(mesh.points) 57 | _, ii = tree.query(edges.points, k=1) 58 | edgePoints, edgeNormals = edges.points, mesh.point_normals[ii, :] 59 | 60 | if edgePoints.shape[0] < num_samples: 61 | num_samples = edgePoints.shape[0] 62 | 63 | # https://stackoverflow.com/a/14262743/8314238 64 | indices = np.random.choice(edgePoints.shape[0], num_samples, replace=False) 65 | edgePoints = edgePoints[indices, :] 66 | edgeNormals = edgeNormals[indices, :] / np.linalg.norm( 67 | edgeNormals[indices, :], axis=1 68 | ).reshape(-1, 1) 69 | return edgePoints, edgeNormals, num_samples 70 | 71 | 72 | def sample_mesh_edges( 73 | mesh: trimesh.base.Trimesh, num_samples: int 74 | ) -> Tuple[np.ndarray, np.ndarray]: 75 | """ 76 | Sample only mesh edges 77 | """ 78 | sampled_edge_points, sampled_edge_normals, num_samples = extract_edges( 79 | mesh, num_samples 80 | ) 81 | return sampled_edge_points, sampled_edge_normals 82 | 83 | 84 | def sample_poses_on_mesh( 85 | mesh: trimesh.base.Trimesh, 86 | num_samples: int, 87 | edges: bool = True, 88 | constraint: np.ndarray = None, 89 | r: float = None, 90 | shear_mag: float = 5.0, 91 | ) -> np.ndarray: 92 | """ 93 | Sample mesh and generates candidate sensor poses 94 | """ 95 | if constraint is not None: 96 | constrainedSampledPoints, constrainedSampledNormals = np.empty( 97 | (0, 3) 98 | ), np.empty((0, 3)) 99 | box = trimesh.creation.box(extents=[2 * r, 2 * r, 2 * r]) 100 | box.apply_translation(constraint) 101 | constrainedMesh = mesh.slice_plane(box.facets_origin, -box.facets_normal) 102 | while constrainedSampledPoints.shape[0] < num_samples: 103 | sP, sN = sample_mesh(constrainedMesh, num_samples * 100, method="even") 104 | dist = (np.linalg.norm(sP - constraint, axis=1)).squeeze() 105 | constrainedSampledPoints = np.append( 106 | constrainedSampledPoints, sP[np.less(dist, r), :], axis=0 107 | ) 108 | constrainedSampledNormals 
= np.append( 109 | constrainedSampledNormals, sN[np.less(dist, r), :], axis=0 110 | ) 111 | idxs = np.random.choice(constrainedSampledPoints.shape[0], num_samples) 112 | sampled_points, sampled_normals = ( 113 | constrainedSampledPoints[idxs, :], 114 | constrainedSampledNormals[idxs, :], 115 | ) 116 | elif edges: 117 | numSamplesEdges = int(0.3 * num_samples) 118 | sampled_edge_points, sampled_edge_normals, numSamplesEdges = extract_edges( 119 | mesh, numSamplesEdges 120 | ) 121 | numSamplesEven = num_samples - numSamplesEdges 122 | sampledPointsEven, sampledNormalsEven = sample_mesh(mesh, numSamplesEven) 123 | sampled_points, sampled_normals = np.concatenate( 124 | (sampledPointsEven, sampled_edge_points), axis=0 125 | ), np.concatenate((sampledNormalsEven, sampled_edge_normals), axis=0) 126 | else: 127 | sampled_points, sampled_normals = sample_mesh( 128 | mesh, num_samples, method="normal" 129 | ) 130 | 131 | # apply random pen into manifold 132 | shear_mag = np.radians(shear_mag) 133 | delta = np.random.uniform(low=0.0, high=2 * np.pi, size=(num_samples,)) 134 | T = pose_from_vertex_normal(sampled_points, sampled_normals, shear_mag, delta) 135 | return T 136 | 137 | 138 | def sample_poses_on_mesh_minkloc( 139 | mesh: trimesh.base.Trimesh, 140 | num_samples: int, 141 | edges: bool = True, 142 | num_angles: int = 1, 143 | shear_mag: float = 5.0, 144 | ) -> np.ndarray: 145 | """ 146 | Sample mesh and generates candidate sensor poses, custom for minkloc data 147 | """ 148 | if edges: 149 | numSamplesEdges = int(0.3 * num_samples) 150 | sampled_edge_points, sampled_edge_normals, numSamplesEdges = extract_edges( 151 | mesh, numSamplesEdges 152 | ) 153 | numSamplesEven = num_samples - numSamplesEdges 154 | sampledPointsEven, sampledNormalsEven = sample_mesh(mesh, numSamplesEven) 155 | sampled_points, sampled_normals = np.concatenate( 156 | (sampledPointsEven, sampled_edge_points), axis=0 157 | ), np.concatenate((sampledNormalsEven, sampled_edge_normals), axis=0) 158 | else: 159 | sampled_points, sampled_normals = sample_mesh(mesh, num_samples) 160 | 161 | sampled_points = np.repeat(sampled_points, num_angles, axis=0) 162 | sampled_normals = np.repeat(sampled_normals, num_angles, axis=0) 163 | 164 | # apply random pen into manifold 165 | delta = np.random.uniform(low=0.0, high=2 * np.pi, size=(num_samples * num_angles,)) 166 | T = pose_from_vertex_normal(sampled_points, sampled_normals, shear_mag, delta) 167 | return T 168 | -------------------------------------------------------------------------------- /midastouch/modules/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
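The samplers above are what `touch_simulator` calls to generate candidate DIGIT poses. A short usage sketch with a placeholder mesh; `edges=True` mixes roughly 30% edge samples into the evenly spaced surface samples, as in the function body:

```python
# Hypothetical usage of sample_poses_on_mesh from midastouch.modules.mesh.
import trimesh
from midastouch.modules.mesh import sample_poses_on_mesh

mesh = trimesh.load("nontextured.stl")     # placeholder object model
T = sample_poses_on_mesh(mesh=mesh, num_samples=100, edges=True, shear_mag=5.0)
print(T.shape)                             # one candidate sensor pose per sample
```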
5 | 6 | """ 7 | Miscellaneous functions 8 | """ 9 | 10 | import numpy as np 11 | 12 | np.seterr(divide="ignore", invalid="ignore") 13 | from sklearn.manifold import TSNE 14 | from sklearn.metrics.pairwise import cosine_similarity 15 | import torch 16 | import GPUtil 17 | import time 18 | import cv2 19 | 20 | import os 21 | from os import path as osp 22 | import shutil 23 | import ffmpeg 24 | 25 | import matplotlib.pyplot as plt 26 | import git 27 | from PIL import Image 28 | from typing import List, Tuple 29 | 30 | plt.rc("pdf", fonttype=42) 31 | plt.rc("ps", fonttype=42) 32 | plt.rc("font", family="serif") 33 | plt.rc("xtick", labelsize="small") 34 | plt.rc("ytick", labelsize="small") 35 | 36 | # quicklink to the root and folder directories 37 | root = git.Repo(".", search_parent_directories=True).working_tree_dir 38 | DIRS = { 39 | "root": root, 40 | "weights": osp.join(root, "midastouch", "model_weights"), 41 | "trees": osp.join(root, "midastouch", "tactile_tree", "data"), 42 | "data": osp.join(root, "YCB-Slide", "dataset"), 43 | "obj_models": osp.join(root, "YCB-Slide", "dataset", "obj_models"), 44 | "debug": osp.join(root, "debug"), 45 | } 46 | 47 | 48 | def get_device(cpu: bool = False, verbose: bool = True) -> str: 49 | """ 50 | Check GPU utilization and return device for torch 51 | """ 52 | if cpu: 53 | device = "cpu" 54 | if verbose: 55 | print("Override, using device:", device) 56 | else: 57 | try: 58 | deviceID = GPUtil.getFirstAvailable( 59 | order="first", 60 | maxLoad=0.8, 61 | maxMemory=0.8, 62 | attempts=5, 63 | interval=1, 64 | verbose=False, 65 | ) 66 | device = torch.device( 67 | "cuda:" + str(deviceID[0]) if torch.cuda.is_available() else "cpu" 68 | ) 69 | if verbose: 70 | print("Using device:", torch.cuda.get_device_name(deviceID[0])) 71 | except: 72 | device = "cpu" 73 | if verbose: 74 | print("Using device:", device) 75 | return device 76 | 77 | 78 | def confusion_matrix( 79 | embeddings: np.ndarray, sz: int, batch_size: int = 100 80 | ) -> np.ndarray: 81 | """ 82 | get pairwise cosine_similarity for embeddings and generate confusion matrix 83 | """ 84 | C = np.nan * np.zeros((sz, sz)) 85 | num_batches = sz // batch_size 86 | 87 | if num_batches == 0: 88 | C = cosine_similarity(embeddings, embeddings).squeeze() 89 | else: 90 | for i in range(num_batches): 91 | i_range = ( 92 | np.array(range(i * batch_size, sz)) 93 | if (i == num_batches - 1) 94 | else np.array(range(i * batch_size, (i + 1) * batch_size)) 95 | ) 96 | for j in range(num_batches): 97 | j_range = ( 98 | np.array(range(j * batch_size, sz)) 99 | if (j == num_batches - 1) 100 | else np.array(range(j * batch_size, (j + 1) * batch_size)) 101 | ) 102 | C[i_range[:, None], j_range] = cosine_similarity( 103 | embeddings[i_range, :], embeddings[j_range, :] 104 | ).squeeze() 105 | 106 | # scale 0 to 1 107 | C = (C - np.min(C)) / np.ptp(C) # scale [0, 1] 108 | return C 109 | 110 | 111 | def color_tsne(C: np.ndarray, TSNE_init: str) -> np.ndarray: 112 | """ 113 | Project high-dimensional data via TSNE and colormap 114 | """ 115 | tsne = TSNE( 116 | n_components=1, 117 | verbose=1, 118 | perplexity=40, 119 | init=TSNE_init, 120 | random_state=0, 121 | ) 122 | 123 | C = np.nan_to_num(C) 124 | tsne_encoding = tsne.fit_transform(C) 125 | tsne_encoding = np.squeeze(tsne_encoding) 126 | tsne_min, tsne_max = np.min(tsne_encoding), np.max(tsne_encoding) 127 | tsne_encoding = (tsne_encoding - tsne_min) / (tsne_max - tsne_min) 128 | colors = plt.cm.Spectral(tsne_encoding)[:, :3] 129 | return colors 130 | 131 | 132 | def 
get_time(start_time, units="sec"): 133 | """ 134 | Get difference in time since start_time 135 | """ 136 | elapsedInSeconds = time.time() - start_time 137 | if units == "sec": 138 | return elapsedInSeconds 139 | if units == "min": 140 | return elapsedInSeconds / 60 141 | if units == "hour": 142 | return elapsedInSeconds / (60 * 60) 143 | 144 | 145 | def view_subplots(image_data: List[np.ndarray], image_mosaic: List[str]) -> None: 146 | """ 147 | Make subplot mosaic from image data 148 | """ 149 | fig, axes = plt.subplot_mosaic(mosaic=image_mosaic, constrained_layout=True) 150 | for j, (label, ax) in enumerate(axes.items()): 151 | ax.imshow(image_data[j]) 152 | ax.axis("off") 153 | ax.set_title(label) 154 | plt.show() 155 | 156 | 157 | def remove_and_mkdir(results_path: str) -> None: 158 | """ 159 | Remove directory (if exists) and create 160 | """ 161 | if osp.exists(results_path): 162 | shutil.rmtree(results_path) 163 | os.makedirs(results_path) 164 | 165 | 166 | def change_to_dir(abspath: str, rel_dir: str) -> None: 167 | """ 168 | Change path to specified directory 169 | """ 170 | dname = osp.dirname(abspath) 171 | os.chdir(dname) 172 | os.chdir(rel_dir) # root 173 | 174 | 175 | def load_heightmap_mask( 176 | heightmapFile: str, maskFile: str 177 | ) -> Tuple[np.ndarray, np.ndarray]: 178 | """ 179 | Load heightmap and contact mask from file 180 | """ 181 | try: 182 | heightmap = cv2.imread(heightmapFile, 0).astype(np.int64) 183 | contactmask = cv2.imread(maskFile, 0).astype(np.int64) 184 | except AttributeError: 185 | heightmap = np.zeros(heightmap.shape).astype(np.int64) 186 | contactmask = np.zeros(contactmask.shape).astype(np.int64) 187 | return heightmap, contactmask > 255 / 2 188 | 189 | 190 | def load_heightmaps_masks( 191 | heightmapFolder: str, contactmaskFolder: str 192 | ) -> Tuple[List[np.ndarray], List[np.ndarray]]: 193 | """ 194 | Load heightmaps and contact masks from folder 195 | """ 196 | 197 | heightmapFiles = sorted( 198 | os.listdir(heightmapFolder), key=lambda y: int(y.split("_")[0]) 199 | ) 200 | contactmaskFiles = sorted( 201 | os.listdir(contactmaskFolder), key=lambda y: int(y.split("_")[0]) 202 | ) 203 | heightmaps, contactmasks = [], [] 204 | 205 | for heightmapFile, contactmaskFile in zip(heightmapFiles, contactmaskFiles): 206 | heightmap, mask = load_heightmap_mask( 207 | os.path.join(heightmapFolder, heightmapFile), 208 | os.path.join(contactmaskFolder, contactmaskFile), 209 | ) 210 | heightmaps.append(heightmap) 211 | contactmasks.append(mask) 212 | return heightmaps, contactmasks 213 | 214 | 215 | def load_images(imageFolder: str, N: int = None) -> np.ndarray: 216 | """ 217 | Load tactile images from folder (returns N images if specified) 218 | """ 219 | 220 | try: 221 | imageFiles = sorted(os.listdir(imageFolder), key=lambda y: int(y.split(".")[0])) 222 | except: 223 | imageFiles = sorted(os.listdir(imageFolder)) 224 | images = [] 225 | for imageFile in imageFiles: 226 | if imageFile.endswith(".mp4"): 227 | continue 228 | im = Image.open(os.path.join(imageFolder, imageFile)) 229 | images.append(np.array(im)) 230 | if N is not None and len(images) == N: 231 | return np.stack(images) 232 | return np.stack(images) # (N, 3, H, W) 233 | 234 | 235 | def save_image(tactileImage: np.ndarray, i: int, save_path: str) -> None: 236 | """ 237 | Save tactile image as .jpg file 238 | """ 239 | tactileImage = Image.fromarray(tactileImage.astype("uint8"), "RGB") 240 | tactileImage.save("{path}/{p_i}.jpg".format(path=save_path, p_i=i)) 241 | 242 | 243 | def 
save_images(tactileImages: List[np.ndarray], save_path: str) -> None: 244 | """ 245 | Save tactile images as .jpg files 246 | """ 247 | for i, tactileImage in enumerate(tactileImages): 248 | save_image(tactileImage, i, save_path) 249 | 250 | 251 | def save_heightmap(heightmap: np.ndarray, i: int, save_path: str) -> None: 252 | """ 253 | Save heightmap as .jpg file 254 | """ 255 | cv2.imwrite( 256 | "{path}/{p_i}.jpg".format(path=save_path, p_i=i), heightmap.astype("float32") 257 | ) 258 | 259 | 260 | def save_heightmaps(heightmaps: List[np.ndarray], save_path: str) -> None: 261 | """ 262 | Save heightmaps as .jpg files 263 | """ 264 | for i, heightmap in enumerate(heightmaps): 265 | save_heightmap(heightmap, i, save_path) 266 | 267 | 268 | def save_contactmask(contactMask: np.ndarray, i: int, save_path: str) -> None: 269 | """ 270 | Save contact mask as .jpg file 271 | """ 272 | cv2.imwrite( 273 | "{path}/{p_i}.jpg".format(path=save_path, p_i=i), 274 | 255 * contactMask.astype("uint8"), 275 | ) 276 | 277 | 278 | def save_contactmasks(contactMasks: List[np.ndarray], save_path: str) -> None: 279 | """ 280 | Save contact masks as .jpg files 281 | """ 282 | for i, contactMask in enumerate(contactMasks): 283 | save_contactmask(contactMask, i, save_path) 284 | 285 | 286 | def index_of(val: int, in_list: list) -> int: 287 | """ 288 | https://stackoverflow.com/a/49522958 : returns index of element in list 289 | """ 290 | try: 291 | return in_list.index(val) 292 | except ValueError: 293 | return -1 294 | 295 | 296 | def get_int(file: str) -> int: 297 | """ 298 | Extract numeric value from file name 299 | """ 300 | return int(file.split(".")[0]) 301 | 302 | 303 | def images_to_video(path: str) -> None: 304 | """ 305 | https://stackoverflow.com/a/67152804 : list of images to .mp4 306 | """ 307 | images = os.listdir(path) 308 | images = [im for im in images if im.endswith(".png")] 309 | images = sorted(images, key=get_int) 310 | 311 | # Execute FFmpeg sub-process, with stdin pipe as input, and jpeg_pipe input format 312 | process = ( 313 | ffmpeg.input("pipe:", r="10") 314 | .output(osp.join(path, "video.mp4"), vcodec="libx264") 315 | .global_args("-loglevel", "error") 316 | .global_args("-y") 317 | .overwrite_output() 318 | .run_async(pipe_stdin=True) 319 | ) 320 | 321 | for image in images: 322 | image_path = osp.join(path, image) 323 | with open(image_path, "rb") as f: 324 | # Read the JPEG file content to jpeg_data (bytes array) 325 | jpeg_data = f.read() 326 | # Write JPEG data to stdin pipe of FFmpeg process 327 | process.stdin.write(jpeg_data) 328 | 329 | # Close stdin pipe - FFmpeg fininsh encoding the output file. 330 | process.stdin.close() 331 | process.wait() 332 | -------------------------------------------------------------------------------- /midastouch/modules/objects.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
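# A hedged sketch of how these object sets pair with the YCB-Slide mesh layout
# (the nontextured.stl naming follows build_codebook.py / test_codebook.py; the loop
# itself is illustrative, not a fixed entry point):
#
#   from os import path as osp
#   from midastouch.modules.misc import DIRS
#   from midastouch.modules.objects import ycb_test
#
#   for obj in ycb_test:
#       obj_path = osp.join(DIRS["obj_models"], obj, "nontextured.stl")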
5 | 6 | """ 7 | Different object sets 8 | """ 9 | 10 | ycb_test = [ 11 | "004_sugar_box", 12 | "005_tomato_soup_can", 13 | "006_mustard_bottle", 14 | "021_bleach_cleanser", 15 | "025_mug", 16 | "035_power_drill", 17 | "037_scissors", 18 | "042_adjustable_wrench", 19 | "048_hammer", 20 | "055_baseball", 21 | ] 22 | 23 | ycb_train = [ 24 | "002_master_chef_can", 25 | "003_cracker_box", 26 | "007_tuna_fish_can", 27 | "008_pudding_box", 28 | "009_gelatin_box", 29 | "010_potted_meat_can", 30 | "011_banana", 31 | "012_strawberry", 32 | "013_apple", 33 | "014_lemon", 34 | "015_peach", 35 | "016_pear", 36 | "017_orange", 37 | "018_plum", 38 | "019_pitcher_base", 39 | "024_bowl", 40 | "026_sponge", 41 | "029_plate", 42 | "030_fork", 43 | "031_spoon", 44 | "032_knife", 45 | "033_spatula", 46 | "036_wood_block", 47 | "040_large_marker", 48 | "044_flat_screwdriver", 49 | "050_medium_clamp", 50 | "051_large_clamp", 51 | "052_extra_large_clamp", 52 | "053_mini_soccer_ball", 53 | "054_softball", 54 | "056_tennis_ball", 55 | "057_racquetball", 56 | "058_golf_ball", 57 | "061_foam_brick", 58 | "062_dice", 59 | "065-a_cups", 60 | "065-b_cups", 61 | "070-a_colored_wood_blocks", 62 | "072-a_toy_airplane", 63 | "077_rubiks_cube", 64 | ] 65 | 66 | mcmaster_models = ["cotter-pin", "steel-nail", "eyebolt"] 67 | 68 | misc_obj_models = ["cube", "octahedron", "sphere", "bunny"] 69 | -------------------------------------------------------------------------------- /midastouch/render/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/MidasTouch/9afb6ff72b837ab69f140695d38bf95c498b30c3/midastouch/render/__init__.py -------------------------------------------------------------------------------- /midastouch/tactile_tree/__init__.py: -------------------------------------------------------------------------------- 1 | from . import tactile_tree 2 | -------------------------------------------------------------------------------- /midastouch/tactile_tree/build_codebook.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
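# Once this script has run, the saved codebook can be reloaded and inspected; a hedged
# sketch (the object name is an example, and the path mirrors tree_path below):
#
#   import dill as pickle
#   from os import path as osp
#   from midastouch.modules.misc import DIRS
#
#   with open(osp.join(DIRS["trees"], "005_tomato_soup_can", "codebook.pkl"), "rb") as f:
#       codebook = pickle.load(f)
#   gel_poses, cam_poses = codebook.get_poses()   # (N, 4, 4) sensor and camera poses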
5 | 6 | """ 7 | Generates codebook from scratch by randomly sampling object meshes 8 | Run: python midastouch/tactile_tree/build_codebook.py expt.obj_model=005_tomato_soup_can expt.codebook_size=50000 9 | """ 10 | 11 | import os 12 | 13 | os.environ["PYOPENGL_PLATFORM"] = "egl" 14 | from os import path as osp 15 | import numpy as np 16 | 17 | from midastouch.contrib.tdn_fcrn.tdn import TDN 18 | from midastouch.contrib.tcn_minkloc.tcn import TCN 19 | import hydra 20 | from omegaconf import DictConfig 21 | from midastouch.modules.misc import DIRS, get_device 22 | from midastouch.modules.mesh import sample_poses_on_mesh 23 | from midastouch.render.digit_renderer import digit_renderer 24 | from tqdm import tqdm 25 | import trimesh 26 | import torch 27 | 28 | from midastouch.tactile_tree.tactile_tree import tactile_tree 29 | import dill as pickle 30 | 31 | 32 | @hydra.main(config_path="../config", config_name="config") 33 | def build_codebook(cfg: DictConfig, image_embedding=False): 34 | expt_cfg, tcn_cfg, tdn_cfg = cfg.expt, cfg.tcn, cfg.tdn 35 | 36 | num_samples = expt_cfg.codebook_size 37 | obj_model = expt_cfg.obj_model 38 | 39 | print( 40 | f"object: {obj_model}, codebook size: {num_samples}, image embedding: {image_embedding}" 41 | ) 42 | 43 | obj_path = osp.join(DIRS["obj_models"], obj_model, "nontextured.stl") 44 | 45 | if not image_embedding: 46 | tree_path = osp.join(DIRS["trees"], obj_model, "codebook.pkl") 47 | else: 48 | tree_path = osp.join(DIRS["trees"], obj_model, "image_codebook.pkl") 49 | 50 | tac_render = digit_renderer(cfg=tdn_cfg.render, obj_path=obj_path, randomize=True) 51 | digit_tcn = TCN(tcn_cfg) 52 | digit_tdn = TDN(tdn_cfg, bg=tac_render.get_background(frame="gel")) 53 | print(f"Using TDN weights: {tdn_cfg.tdn_weights}") 54 | device = get_device(cpu=False) 55 | 56 | mesh = trimesh.load(obj_path) 57 | 58 | # Generate tree samples 59 | print("Generating {} samples".format(num_samples)) 60 | samples = sample_poses_on_mesh(mesh=mesh, num_samples=num_samples, edges=False) 61 | 62 | """Get multimodal embeddings""" 63 | ##################################### 64 | batch_size = 100 65 | num_batches = num_samples // batch_size 66 | num_batches = 1 if num_batches == 0 else num_batches 67 | 68 | pbar = tqdm(total=num_batches) 69 | gelposes, camposes, embeddings = ( 70 | torch.zeros((num_samples, 4, 4)), 71 | torch.zeros((num_samples, 4, 4)), 72 | torch.zeros( 73 | (num_samples, digit_tcn.params.model_params.output_dim), dtype=torch.double 74 | ), 75 | ) 76 | # heightmaps, masks, images = [None] * num_samples, [None] * num_samples, [None] * num_samples 77 | 78 | # heightmap_rmse, contact_mask_iou = [], [] 79 | for i in range(num_batches): 80 | i_range = ( 81 | np.array(range(i * batch_size, num_samples)) 82 | if (i == num_batches - 1) 83 | else np.array(range(i * batch_size, (i + 1) * batch_size)) 84 | ) 85 | h, cm, tactileImages, campose, gelpose = tac_render.render_sensor_poses( 86 | samples[i_range, :, :], num_depths=1 87 | ) 88 | gelposes[i_range, :] = torch.from_numpy(gelpose).float() 89 | camposes[i_range, :] = torch.from_numpy(campose).float() 90 | 91 | est_heightmaps, est_masks = [], [] 92 | 93 | if not image_embedding: 94 | for j, image in enumerate(tactileImages): 95 | est_h = digit_tdn.image2heightmap(image) # expensive 96 | est_c = digit_tdn.heightmap2mask(est_h) 97 | est_heightmaps.append(est_h) 98 | est_masks.append(est_c) 99 | 100 | # gt_h, gt_c = h[j], cm[j] 101 | # error_heightmap = np.abs(est_h - gt_h) * pixmm # Get pixelwise RMSE in mm, and IoU of the contact 
masks 102 | # heightmap_rmse.append(np.sqrt(np.mean(error_heightmap**2))) 103 | # intersection = np.sum(np.logical_and(gt_c, est_c)) 104 | # contact_mask_iou.append(intersection/(np.sum(est_c) + np.sum(gt_c) - intersection)) 105 | tactile_code = digit_tcn.cloud_to_tactile_code( 106 | tac_render, est_heightmaps, est_masks 107 | ) 108 | else: 109 | tactile_code = torch.zeros( 110 | (len(tactileImages), digit_tcn.params.model_params.output_dim), 111 | dtype=torch.double, 112 | ) 113 | for j, image in enumerate(tactileImages): 114 | tactile_code[j, :] = digit_tdn.image2embedding(image) 115 | 116 | embeddings[i_range, :] = tactile_code.cpu() 117 | pbar.update(1) 118 | pbar.close() 119 | 120 | # heightmap RMSE (mm), contact mask IoU [0 - 1] 121 | # heightmap_rmse = [x for x in heightmap_rmse if str(x) != 'nan'] 122 | # contact_mask_iou = [x for x in contact_mask_iou if str(x) != 'nan'] 123 | # heightmap_rmse = sum(heightmap_rmse) / len(heightmap_rmse) 124 | # contact_mask_iou = sum(contact_mask_iou) / len(contact_mask_iou) 125 | # error_file = open(osp.join(tree_path,'{}_error.txt'.format(method)),'w') 126 | # error_file.write(str(heightmap_rmse) + "," + str(contact_mask_iou) + "\n") 127 | # error_file.close() 128 | 129 | ##################################### 130 | codebook = tactile_tree( 131 | poses=gelposes, 132 | cam_poses=camposes, 133 | embeddings=embeddings, 134 | ) 135 | print("Saving data to path: {}".format(tree_path)) 136 | with open(tree_path, "wb") as file: 137 | pickle.dump(codebook, file) 138 | return 139 | 140 | 141 | if __name__ == "__main__": 142 | build_codebook() 143 | -------------------------------------------------------------------------------- /midastouch/tactile_tree/process_codebook.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """ 7 | Script to modify contents of tactile codebook and overwrite pickle file 8 | """ 9 | 10 | import dill as pickle 11 | from os import path as osp 12 | from midastouch.tactile_tree.tactile_tree import tactile_tree 13 | from midastouch.modules.misc import DIRS 14 | import os 15 | 16 | 17 | def main(): 18 | codebooks_path = osp.join(DIRS["trees"]) 19 | 20 | objects = sorted(os.listdir(codebooks_path)) 21 | 22 | for object in objects: 23 | try: 24 | pickle_path = osp.join(codebooks_path, object, "codebook.pkl") 25 | with open(pickle_path, "rb") as pickle_file: 26 | T = pickle.load(pickle_file) 27 | 28 | poses, cam_poses = T.poses, T.cam_poses 29 | 30 | """ 31 | add intermediate processes here 32 | """ 33 | 34 | new_T = tactile_tree( 35 | poses=poses, cam_poses=cam_poses, embeddings=T.embeddings 36 | ) 37 | pickle_path = osp.join(codebooks_path, object, "codebook.pkl") 38 | print("Saving data to path: {}".format(pickle_path)) 39 | with open(pickle_path, "wb") as file: 40 | pickle.dump(new_T, file) 41 | except: 42 | print(f"Error building tree: {object}") 43 | 44 | 45 | if __name__ == "__main__": 46 | main() 47 | -------------------------------------------------------------------------------- /midastouch/tactile_tree/tactile_tree.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
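# The tactile_tree class below wraps a pynanoflann KD-tree over weighted
# [translation | SO(3) log-map] features (see R3_SE3 at the bottom of the file).
# A hedged query sketch, assuming `codebook` is an unpickled tactile_tree and
# `obj_path` points at the matching mesh:
#
#   import torch, trimesh
#   from midastouch.modules.mesh import sample_poses_on_mesh
#
#   query = sample_poses_on_mesh(mesh=trimesh.load(obj_path), num_samples=5, edges=False)
#   poses, cam_poses, embeddings = codebook.SE3_NN(torch.from_numpy(query), nn=1)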
5 | 6 | import numpy as np 7 | import pynanoflann 8 | from midastouch.modules.pose import get_logmap_from_matrix 9 | import torch 10 | from torch import nn 11 | 12 | 13 | class tactile_tree(nn.Module): 14 |     def __init__(self, poses, cam_poses, embeddings): 15 |         super(tactile_tree, self).__init__() 16 |         self.poses = poses.float() 17 |         self.logmap_pose = R3_SE3(self.poses.clone()) 18 |         self.cam_poses, self.embeddings = cam_poses.float(), embeddings 19 |         self.tree, self.tree_size = None, 0 20 |         self.init_tree() 21 | 22 |     def __len__(self): 23 |         return self.tree_size 24 | 25 |     def __repr__(self): 26 |         return "tactile Tree of size: {}".format(len(self)) 27 | 28 |     def to_device(self, device): 29 |         self.poses = self.poses.to(device) 30 |         self.logmap_pose = self.logmap_pose.to(device) 31 |         self.cam_poses = self.cam_poses.to(device) 32 |         self.embeddings = self.embeddings.to(device) 33 | 34 |     def init_tree(self): 35 |         range = np.max( 36 |             np.ptp(self.logmap_pose.cpu().numpy(), axis=0) 37 |         ) # maximum range of KDTree data 38 |         self.tree = pynanoflann.KDTree(metric="L2", radius=range) 39 |         self.tree.fit(self.logmap_pose.cpu().numpy()) 40 |         self.tree_size = self.poses.shape[0] 41 |         return 42 | 43 |     def SE3_NN(self, _query, nn=1): 44 |         """ 45 |         Get best SE3 match based on R3 and logmap_SO3 distances from a set of candidates 46 |         """ 47 |         query = _query.clone() 48 |         query = torch.atleast_3d(query) 49 |         query = R3_SE3(query) 50 |         _, indices_p = self.tree.kneighbors( 51 |             query.cpu().numpy(), n_neighbors=nn, n_jobs=16 52 |         ) 53 |         indices_p = torch.tensor(indices_p.squeeze().astype(dtype=np.int64)) 54 |         return ( 55 |             self.poses[indices_p, :, :], 56 |             self.cam_poses[indices_p, :, :], 57 |             self.embeddings[indices_p, :], 58 |         ) 59 | 60 |     def get_poses(self): 61 |         return self.poses, self.cam_poses 62 | 63 |     def get_pose(self, idx): 64 |         return self.poses[idx, :] 65 | 66 |     def get_embeddings(self): 67 |         return self.embeddings 68 | 69 |     def get_embedding(self, idx): 70 |         return self.embeddings[idx, :] 71 | 72 | 73 | def R3_SE3(poses, w=0.01): 74 |     return torch.cat( 75 |         ((1.0 - w) * poses[:, :3, 3], w * get_logmap_from_matrix(poses[:, :3, :3])), 76 |         axis=1, 77 |     ) 78 | -------------------------------------------------------------------------------- /midastouch/tactile_tree/test_codebook.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree.
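# Typical invocation mirrors build_codebook.py (a hedged example; the codebook for the
# chosen object must already exist under midastouch/tactile_tree/data):
#
#   python midastouch/tactile_tree/test_codebook.py expt.obj_model=005_tomato_soup_can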
5 | 6 | """ 7 | Loads object codebook and performs nearest neighbor queries 8 | """ 9 | 10 | from os import path as osp 11 | import dill as pickle 12 | import torch 13 | 14 | from midastouch.viz.helpers import viz_query_target_poses_on_mesh 15 | import trimesh 16 | from midastouch.modules.mesh import sample_poses_on_mesh 17 | from midastouch.modules.misc import DIRS 18 | import hydra 19 | from omegaconf import DictConfig 20 | 21 | 22 | @hydra.main(config_path="../config", config_name="config") 23 | def main(cfg: DictConfig): 24 | expt_cfg, tcn_cfg, tdn_cfg = cfg.expt, cfg.tcn, cfg.tdn 25 | obj_model = expt_cfg.obj_model 26 | codebook_path = osp.join(DIRS["trees"], obj_model, "codebook.pkl") 27 | obj_path = osp.join(DIRS["obj_models"], obj_model, "nontextured.stl") 28 | 29 | if osp.exists(codebook_path): 30 | with open(codebook_path, "rb") as pickle_file: 31 | codebook = pickle.load(pickle_file) 32 | 33 | mesh = trimesh.load(obj_path) 34 | 35 | num_pose = 5 36 | query_poses = sample_poses_on_mesh(mesh=mesh, num_samples=num_pose, edges=False) 37 | query_poses = torch.from_numpy(query_poses) 38 | 39 | target_poses, _, _ = codebook.SE3_NN(query_poses) 40 | viz_query_target_poses_on_mesh( 41 | mesh_path=obj_path, query_pose=query_poses, target_poses=target_poses 42 | ) 43 | 44 | 45 | if __name__ == "__main__": 46 | main() 47 | -------------------------------------------------------------------------------- /midastouch/viz/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/MidasTouch/9afb6ff72b837ab69f140695d38bf95c498b30c3/midastouch/viz/__init__.py -------------------------------------------------------------------------------- /midastouch/viz/demo_visualizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
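# A minimal driver sketch for the Viz class below (illustrative; in practice the demo /
# filter scripts feed it, and the pose tensors come from the particle filter):
#
#   viz = Viz(off_screen=False, zoom=1.0)
#   viz.init_variables(mesh_path=obj_path)   # expects the nontextured .stl path
#   viz.update(heatmap_poses, heatmap_weights, cluster_poses, cluster_stds,
#              image, heightmap, mask, frame=i)   # queued, drawn by the periodic update_viz callback
#   viz.close()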
5 | 6 | """ 7 | visualizer class for demo script 8 | """ 9 | 10 | import numpy as np 11 | import pyvista as pv 12 | from matplotlib import cm 13 | from pyvistaqt import BackgroundPlotter 14 | import torch 15 | import copy 16 | from os import path as osp 17 | import queue 18 | from PIL import Image 19 | import tkinter as tk 20 | 21 | from midastouch.modules.misc import DIRS 22 | 23 | pv.set_plot_theme("document") 24 | 25 | 26 | class Viz: 27 | def __init__( 28 | self, off_screen: bool = False, zoom: float = 1.0, window_size: int = 0.5 29 | ): 30 | 31 | pv.global_theme.multi_rendering_splitting_position = 0.7 32 | """ 33 | subplot(0, 0) main viz 34 | subplot(0, 1): tactile image viz 35 | subplot(1, 1): tactile codebook viz 36 | """ 37 | shape, row_weights, col_weights = (2, 2), [0.6, 0.4], [0.6, 0.4] 38 | groups = [(np.s_[:], 0), (0, 1), (1, 1)] 39 | 40 | w, h = tk.Tk().winfo_screenwidth(), tk.Tk().winfo_screenheight() 41 | 42 | self.plotter = BackgroundPlotter( 43 | title="MidasTouch", 44 | lighting="three lights", 45 | window_size=(int(w * window_size), int(h * window_size)), 46 | off_screen=off_screen, 47 | shape=shape, 48 | row_weights=row_weights, 49 | col_weights=col_weights, 50 | groups=groups, 51 | border_color="white", 52 | toolbar=False, 53 | menu_bar=False, 54 | auto_update=True, 55 | ) 56 | self.zoom = zoom 57 | 58 | self.viz_queue = queue.Queue(1) 59 | self.plotter.add_callback(self.update_viz, interval=10) 60 | 61 | def set_camera(self, position="yz", azimuth=45, elevation=20, zoom=None): 62 | ( 63 | self.plotter.camera_position, 64 | self.plotter.camera.azimuth, 65 | self.plotter.camera.elevation, 66 | ) = (position, azimuth, elevation) 67 | if zoom is None: 68 | self.plotter.camera.Zoom(self.zoom) 69 | else: 70 | self.plotter.camera.Zoom(zoom) 71 | self.plotter.camera_set = True 72 | 73 | def reset_vis(self, flag): 74 | self.plotter.subplot(0, 0) 75 | self.set_camera() 76 | self.reset_widget.value = not flag 77 | 78 | def init_variables(self, mesh_path: str): 79 | self.mesh_pv_deci = pv.read( 80 | mesh_path.replace("nontextured", "nontextured_decimated") 81 | ) # decimated pyvista object 82 | 83 | self.moving_sensor = pv.read( 84 | osp.join(DIRS["obj_models"], "digit", "digit.STL") 85 | ) # plotted gt sensor 86 | self.init_sensor = copy.deepcopy(self.moving_sensor) # sensor @ origin 87 | 88 | # Heatmap window 89 | self.plotter.subplot(0, 0) 90 | widget_size, pos = 20, self.plotter.window_size[1] - 40 91 | self.reset_widget = self.plotter.add_checkbox_button_widget( 92 | self.reset_vis, 93 | value=True, 94 | color_off="grey", 95 | color_on="grey", 96 | position=(10, pos - widget_size - 5), 97 | size=widget_size, 98 | ) 99 | self.plotter.add_text( 100 | "Reset camera", 101 | position=(15 + widget_size, pos - widget_size - 5), 102 | color="black", 103 | font="times", 104 | font_size=8, 105 | ) 106 | self.set_camera() 107 | self.plotter.add_text( 108 | "Tactile codebook output", 109 | position="bottom", 110 | color="black", 111 | shadow=True, 112 | font="times", 113 | font_size=10, 114 | name="Codebook text", 115 | ) 116 | 117 | dargs = dict( 118 | color="tan", 119 | ambient=0.0, 120 | opacity=0.7, 121 | smooth_shading=True, 122 | show_edges=False, 123 | specular=1.0, 124 | show_scalar_bar=False, 125 | render=False, 126 | ) 127 | self.plotter.add_mesh(self.moving_sensor, **dargs) 128 | 129 | # Tactile window 130 | self.plotter.subplot(0, 1) 131 | self.plotter.camera.Zoom(1) 132 | self.plotter.add_text( 133 | "Tactile image and heightmap", 134 | position="bottom", 135 | 
color="black", 136 | shadow=True, 137 | font="times", 138 | font_size=10, 139 | name="Tactile text", 140 | ) 141 | 142 | self.viz_count = 0 143 | self.image_plane, self.heightmap_plane = None, None 144 | 145 | def update_viz( 146 | self, 147 | ): 148 | if self.viz_queue.qsize(): 149 | ( 150 | heatmap_poses, 151 | heatmap_weights, 152 | cluster_poses, 153 | cluster_stds, 154 | image, 155 | heightmap, 156 | mask, 157 | frame, 158 | ) = self.viz_queue.get() 159 | self.viz_heatmap( 160 | heatmap_poses, heatmap_weights, cluster_poses, cluster_stds 161 | ) 162 | self.viz_tactile_image(image, heightmap, mask) 163 | self.plotter.add_text( 164 | f"\nFrame {frame} ", 165 | position="upper_right", 166 | color="black", 167 | shadow=True, 168 | font="times", 169 | font_size=12, 170 | name="frame text", 171 | render=True, 172 | ) 173 | self.viz_queue.task_done() 174 | 175 | def update( 176 | self, 177 | heatmap_poses: torch.Tensor, 178 | heatmap_weights: torch.Tensor, 179 | cluster_poses: torch.Tensor, 180 | cluster_stds: torch.Tensor, 181 | image: np.ndarray, 182 | heightmap: np.ndarray, 183 | mask: np.ndarray, 184 | frame: int, 185 | ) -> None: 186 | 187 | if self.viz_queue.full(): 188 | self.viz_queue.get() 189 | self.viz_queue.put( 190 | ( 191 | heatmap_poses, 192 | heatmap_weights, 193 | cluster_poses, 194 | cluster_stds, 195 | image, 196 | heightmap, 197 | mask, 198 | frame, 199 | ), 200 | block=False, 201 | ) 202 | 203 | def viz_heatmap( 204 | self, 205 | heatmap_poses: torch.Tensor, 206 | heatmap_weights: torch.Tensor, 207 | cluster_poses, 208 | cluster_stds, 209 | ) -> None: 210 | self.plotter.subplot(0, 0) 211 | 212 | heatmap_poses, heatmap_weights = ( 213 | heatmap_poses.cpu().numpy(), 214 | heatmap_weights.cpu().numpy(), 215 | ) 216 | heatmap_points = heatmap_poses[:, :3, 3] 217 | 218 | if cluster_poses is not None: 219 | assert ( 220 | cluster_poses.shape[0] == cluster_stds.shape[0] 221 | ), "dimensions must be equal" 222 | 223 | cluster_poses, cluster_stds = ( 224 | cluster_poses.cpu().numpy(), 225 | cluster_stds.cpu().numpy(), 226 | ) 227 | idx = np.argmin(cluster_stds.squeeze()) 228 | 229 | try: 230 | transformed_gelsight_mesh = self.init_sensor.transform( 231 | cluster_poses[idx, :, :], inplace=False 232 | ) 233 | self.moving_sensor.shallow_copy(transformed_gelsight_mesh) 234 | except: 235 | print(cluster_poses.shape, cluster_stds, idx) 236 | pass 237 | 238 | check = np.where(heatmap_weights < np.percentile(heatmap_weights, 90))[0] 239 | heatmap_weights = np.delete(heatmap_weights, check) 240 | heatmap_points = np.delete(heatmap_points, check, axis=0) 241 | 242 | # print(heatmap_weights.min(), heatmap_weights.mean(), heatmap_weights.max()) 243 | heatmap_weights = np.exp(heatmap_weights) 244 | heatmap_weights = (heatmap_weights - np.min(heatmap_weights)) / ( 245 | np.max(heatmap_weights) - np.min(heatmap_weights) 246 | ) 247 | 248 | # print(heatmap_weights.min(), np.percentile(heatmap_weights, 95), heatmap_weights.max()) 249 | heatmap_weights = np.nan_to_num(heatmap_weights) 250 | heatmap_cloud = pv.PolyData(heatmap_points) 251 | heatmap_cloud["similarity"] = heatmap_weights 252 | 253 | if self.viz_count: 254 | m = self.mesh_pv_deci.interpolate( 255 | heatmap_cloud, 256 | strategy="null_value", 257 | radius=self.mesh_pv_deci.length / 50, 258 | ) 259 | self.plotter.update_scalar_bar_range( 260 | clim=[np.percentile(heatmap_weights, 95), heatmap_weights.max()] 261 | ) 262 | self.plotter.update_scalars( 263 | mesh=self.heatmap_mesh, scalars=m["similarity"], render=False 264 | ) 265 | else: 
266 | self.heatmap_mesh = self.mesh_pv_deci.interpolate( 267 | heatmap_cloud, 268 | strategy="null_value", 269 | radius=self.mesh_pv_deci.length / 50, 270 | ) 271 | dargs = dict( 272 | cmap=cm.get_cmap("viridis"), 273 | scalars="similarity", 274 | interpolate_before_map=True, 275 | ambient=0.5, 276 | opacity=1.0, 277 | show_scalar_bar=False, 278 | silhouette=True, 279 | ) 280 | self.plotter.add_mesh(self.heatmap_mesh, **dargs) 281 | self.plotter.set_focus(self.heatmap_mesh.center) 282 | ( 283 | self.plotter.camera_position, 284 | self.plotter.camera.azimuth, 285 | self.plotter.camera.elevation, 286 | ) = ("yz", 45, 20) 287 | self.plotter.camera.Zoom(1.0) 288 | self.plotter.camera_set = True 289 | self.viz_count += 1 290 | return 291 | 292 | def viz_tactile_image( 293 | self, 294 | image: np.ndarray, 295 | heightmap: torch.Tensor, 296 | mask: torch.Tensor, 297 | s: float = 1.8e-3, 298 | ) -> None: 299 | if self.image_plane is None: 300 | self.image_plane = pv.Plane( 301 | i_size=image.shape[1] * s, 302 | j_size=image.shape[0] * s, 303 | i_resolution=image.shape[1] - 1, 304 | j_resolution=image.shape[0] - 1, 305 | ) 306 | self.image_plane.points[:, -1] = 0.25 307 | self.heightmap_plane = copy.deepcopy(self.image_plane) 308 | 309 | # visualize gelsight image 310 | self.plotter.subplot(0, 1) 311 | heightmap, mask = heightmap.cpu().numpy(), mask.cpu().numpy() 312 | image_tex = pv.numpy_to_texture(image) 313 | 314 | heightmap_tex = pv.numpy_to_texture(-heightmap * mask.astype(np.float32)) 315 | self.heightmap_plane.points[:, -1] = ( 316 | np.flip(heightmap * mask.astype(np.float32), axis=0).ravel() * (0.5 * s) 317 | - 0.15 318 | ) 319 | self.plotter.add_mesh( 320 | self.image_plane, 321 | texture=image_tex, 322 | smooth_shading=False, 323 | show_scalar_bar=False, 324 | name="image", 325 | render=False, 326 | ) 327 | self.plotter.add_mesh( 328 | self.heightmap_plane, 329 | texture=heightmap_tex, 330 | cmap=cm.get_cmap("plasma"), 331 | show_scalar_bar=False, 332 | name="heightmap", 333 | render=False, 334 | ) 335 | 336 | def close(self): 337 | if len(self.images): 338 | for (im, path) in zip(self.images["im"], self.images["path"]): 339 | im = Image.fromarray(im.astype("uint8"), "RGB") 340 | im.save(path) 341 | 342 | self.plotter.close() 343 | -------------------------------------------------------------------------------- /midastouch/viz/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 
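# Example call pattern for the plotting helpers below (a hedged sketch; `obj_path`,
# `poses` as an (N, 4, 4) array of homogeneous transforms, and `clouds` as a list of
# (M, 3) point arrays are assumptions supplied by the caller):
#
#   from midastouch.viz.helpers import viz_poses_pointclouds_on_mesh
#
#   viz_poses_pointclouds_on_mesh(mesh_path=obj_path, poses=poses,
#                                 pointclouds=clouds, save_path="debug/poses")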
5 | 6 | """ 7 | Helper functions for visualizer 8 | """ 9 | 10 | import numpy as np 11 | import pyvista as pv 12 | from scipy.spatial.transform import Rotation as R 13 | import matplotlib.pyplot as plt 14 | import matplotlib.animation as animation 15 | from matplotlib import cm 16 | 17 | 18 | def viz_poses_pointclouds_on_mesh( 19 | mesh_path, poses, pointclouds, save_path=None, decimation_factor=5 20 | ): 21 | if type(pointclouds) is not list: 22 | temp = pointclouds 23 | pointclouds = [None] * 1 24 | pointclouds[0] = temp 25 | 26 | plotter = pv.Plotter(window_size=[2000, 2000], off_screen=True) 27 | 28 | mesh = pv.read(mesh_path) # pyvista object 29 | dargs = dict( 30 | color="grey", 31 | ambient=0.6, 32 | opacity=0.5, 33 | smooth_shading=True, 34 | specular=1.0, 35 | show_scalar_bar=False, 36 | render=False, 37 | ) 38 | plotter.add_mesh(mesh, **dargs) 39 | draw_poses(plotter, mesh, poses, quiver_size=0.05) 40 | 41 | if poses.ndim == 2: 42 | spline = pv.lines_from_points(poses[:, :3]) 43 | plotter.add_mesh(spline, line_width=3, color="k") 44 | 45 | final_pc = np.empty((0, 3)) 46 | for i, pointcloud in enumerate(pointclouds): 47 | if pointcloud.shape[0] == 0: 48 | continue 49 | if decimation_factor is not None: 50 | downpcd = pointcloud[ 51 | np.random.choice( 52 | pointcloud.shape[0], 53 | pointcloud.shape[0] // decimation_factor, 54 | replace=False, 55 | ), 56 | :, 57 | ] 58 | else: 59 | downpcd = pointcloud 60 | final_pc = np.append(final_pc, downpcd) 61 | 62 | if final_pc.shape[0]: 63 | pc = pv.PolyData(final_pc) 64 | plotter.add_points( 65 | pc, render_points_as_spheres=True, color="#26D701", point_size=3 66 | ) 67 | 68 | if save_path: 69 | plotter.show(screenshot=save_path) 70 | print(f"Save path: {save_path}.png") 71 | else: 72 | plotter.show() 73 | plotter.close() 74 | pv.close_all() 75 | 76 | 77 | def viz_query_target_poses_on_mesh(mesh_path, query_pose, target_poses): 78 | plotter = pv.Plotter(window_size=[2000, 2000]) 79 | 80 | mesh = pv.read(mesh_path) # pyvista object 81 | dargs = dict( 82 | color="grey", 83 | ambient=0.6, 84 | opacity=0.5, 85 | smooth_shading=True, 86 | specular=1.0, 87 | show_scalar_bar=False, 88 | render=False, 89 | ) 90 | plotter.add_mesh(mesh, **dargs) 91 | 92 | draw_poses(plotter, mesh, target_poses, opacity=0.7) 93 | draw_poses(plotter, mesh, query_pose) 94 | dargs = dict( 95 | color="grey", 96 | ambient=0.6, 97 | opacity=0.6, 98 | smooth_shading=True, 99 | show_edges=False, 100 | specular=1.0, 101 | show_scalar_bar=False, 102 | ) 103 | plotter.add_mesh(mesh, **dargs) 104 | plotter.show() 105 | plotter.close() 106 | pv.close_all() 107 | 108 | 109 | def draw_poses( 110 | plotter: pv.Plotter, 111 | mesh: pv.DataSet, 112 | cluster_poses: np.ndarray, 113 | opacity: float = 1.0, 114 | quiver_size=0.1, 115 | ) -> None: 116 | """ 117 | Draw pose RGB coordinate axes for pose set in pyvista visualizer 118 | """ 119 | quivers = pose2quiver(cluster_poses, quiver_size * mesh.length) 120 | quivers = [quivers["xvectors"]] + [quivers["yvectors"]] + [quivers["zvectors"]] 121 | names = ["xvectors", "yvectors", "zvectors"] 122 | colors = ["r", "g", "b"] 123 | cluster_centers = cluster_poses[:, :3, 3] 124 | for (q, c, n) in zip(quivers, colors, names): 125 | plotter.add_arrows( 126 | cluster_centers, 127 | q, 128 | color=c, 129 | opacity=opacity, 130 | show_scalar_bar=False, 131 | render=False, 132 | name=n, 133 | ) 134 | 135 | 136 | def draw_graph(x, y, savepath, delay, flag="t"): 137 | fig, ax = plt.subplots() 138 | 139 | plt.xlabel("Timestep", fontsize=12) 140 | if 
flag == "t": 141 | plt.ylabel("Avg. translation RMSE (mm)", fontsize=12) 142 | y = [y_ * 1000.0 for y_ in y] 143 | elif flag == "r": 144 | plt.ylabel("Avg. rotation RMSE (deg)", fontsize=12) 145 | 146 | # rolling avg. over 10 timesteps 147 | import pandas as pd 148 | 149 | df = pd.DataFrame() 150 | N = 50 151 | df["y"] = y 152 | df_smooth = df.rolling(N).mean() 153 | df_smooth["y"][0 : N - 1] = y[0 : N - 1] # first 10 readings are as-is 154 | y = df_smooth["y"] 155 | 156 | N, maxy = len(x), max(y) 157 | (line,) = ax.plot(x, y, color="k") 158 | 159 | def update(num, x, y, line): 160 | line.set_data(x[:num], y[:num]) 161 | line.axes.axis([0, N, 0, maxy]) 162 | return (line,) 163 | 164 | ani = animation.FuncAnimation( 165 | fig, update, len(x), fargs=[x, y, line], interval=delay, blit=True 166 | ) 167 | ani.save(savepath + ".mp4", writer="ffmpeg", codec="h264") 168 | fig.savefig(savepath + ".pdf", transparent=True, bbox_inches="tight", pad_inches=0) 169 | 170 | 171 | def pose2quiver(poses, sz): 172 | """ 173 | Convert pose to quiver object (RGB) 174 | """ 175 | poses = np.atleast_3d(poses) 176 | quivers = pv.PolyData(poses[:, :3, 3]) # (N, 3) [x, y, z] 177 | x, y, z = np.array([1, 0, 0]), np.array([0, 1, 0]), np.array([0, 0, 1]) 178 | r = R.from_matrix(poses[:, 0:3, 0:3]) 179 | quivers["xvectors"], quivers["yvectors"], quivers["zvectors"] = ( 180 | r.apply(x) * sz, 181 | r.apply(y) * sz, 182 | r.apply(z) * sz, 183 | ) 184 | return quivers 185 | 186 | 187 | def viz_embedding_TSNE( 188 | mesh_path, 189 | samples, 190 | clusters, 191 | save_path, 192 | nPoints=500, 193 | radius_factor=80.0, 194 | off_screen=False, 195 | ): 196 | samples = np.atleast_2d(samples) 197 | samplePoints = pv.PolyData(samples[:, :3, 3]) 198 | samplePoints["similarity"] = clusters 199 | 200 | mesh_pv = pv.read(mesh_path) # pyvista object 201 | 202 | mesh = mesh_pv.interpolate( 203 | samplePoints, 204 | strategy="mask_points", 205 | radius=mesh_pv.length / radius_factor, 206 | ) 207 | p = pv.Plotter(off_screen=off_screen, window_size=[1000, 1000]) 208 | 209 | # replace black with gray 210 | if clusters.ndim == 2: 211 | null_idx = np.all(mesh["similarity"] == np.array([0.0, 0.0, 0.0]), axis=1) 212 | mesh["similarity"][null_idx, :] = np.array([189 / 256, 189 / 256, 189 / 256]) 213 | 214 | # Open a gif 215 | if clusters.ndim == 2: 216 | dargs = dict( 217 | scalars="similarity", 218 | rgb=True, 219 | interpolate_before_map=False, 220 | opacity=1, 221 | smooth_shading=True, 222 | show_scalar_bar=False, 223 | silhouette=True, 224 | ) 225 | else: 226 | dargs = dict( 227 | scalars="similarity", 228 | cmap=cm.get_cmap("plasma"), 229 | interpolate_before_map=False, 230 | opacity=1, 231 | smooth_shading=True, 232 | show_scalar_bar=False, 233 | silhouette=True, 234 | ) 235 | p.add_mesh(mesh, **dargs) 236 | 237 | if nPoints is not None: 238 | p.show(screenshot=save_path, auto_close=not off_screen) 239 | viewup = [0.5, 0.5, 1] 240 | path = p.generate_orbital_path( 241 | factor=8.0, 242 | viewup=viewup, 243 | n_points=nPoints, 244 | shift=mesh.length / (np.sqrt(3)), 245 | ) 246 | p.open_movie(save_path + ".mp4") 247 | p.orbit_on_path( 248 | path, write_frames=True, viewup=[0, 0, 1], step=0.01, progress_bar=True 249 | ) 250 | else: 251 | p.show(screenshot=save_path) 252 | p.close() 253 | pv.close_all() 254 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | #!/usr/bin/env python 7 | 8 | from setuptools import setup 9 | import shlex 10 | import subprocess 11 | 12 | 13 | def git_version(): 14 | cmd = 'git log --format="%h" -n 1' 15 | return subprocess.check_output(shlex.split(cmd)).decode() 16 | 17 | 18 | version = git_version() 19 | setup( 20 | name="midastouch", 21 | version=version, 22 | author="Sudharshan Suresh", 23 | author_email="suddhus@gmail.com", 24 | packages=["midastouch"], 25 | ) 26 | --------------------------------------------------------------------------------
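# A hedged note on setup.py above: the package is intended for an editable, from-source
# install (an assumption based on the minimal setup() metadata and the git-hash
# versioning), e.g. from the repository root:
#
#   pip install -e .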