├── .flake8 ├── .github └── workflows │ ├── release.yml │ └── test.yml ├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── MANIFEST.in ├── README.md ├── autolab_core ├── __init__.py ├── camera_intrinsics.py ├── chessboard_registration.py ├── completer.py ├── constants.py ├── csv_model.py ├── data_stream_recorder.py ├── data_stream_syncer.py ├── detector.py ├── dist_metrics.py ├── dual_quaternion.py ├── experiment_logger.py ├── feature_matcher.py ├── features.py ├── image.py ├── json_serialization.py ├── learning_analysis.py ├── logger.py ├── orthographic_intrinsics.py ├── point_registration.py ├── points.py ├── primitives.py ├── random_variables.py ├── rigid_transformations.py ├── tensor_dataset.py ├── transformations.py ├── utils.py ├── version.py └── yaml_config.py ├── cfg └── tools │ ├── aggregate_tensor_datasets.yaml │ └── compute_dataset_statistics.yaml ├── docs ├── Makefile ├── gh_deploy.sh └── source │ ├── api │ ├── csv_model.rst │ ├── dual_quaternion.rst │ ├── exceptions.rst │ ├── experiment_logger.rst │ ├── json_serialization.rst │ ├── points.rst │ ├── primitives.rst │ ├── random_variables.rst │ ├── rigid_transform.rst │ ├── utils.rst │ └── yaml_config.rst │ ├── conf.py │ ├── index.rst │ └── install │ └── install.rst ├── launch └── rigid_transforms.launch ├── package.xml ├── ros_nodes ├── rigid_transform_listener.py └── rigid_transform_publisher.py ├── setup.py ├── srv ├── RigidTransformListener.srv └── RigidTransformPublisher.srv ├── tests ├── __init__.py ├── constants.py ├── test_dataset.py ├── test_image.py ├── test_points.py ├── test_registration.py └── test_rigid_transform.py └── tools ├── aggregate_tensor_datasets.py ├── compute_dataset_statistics.py ├── convert_legacy_dataset_to_tensor_dataset.py ├── shuffle_tensor_dataset.py ├── split_dataset.py └── subsample_tensor_dataset.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 79 3 | ignore = W503, E203, E231 # conflicts with `black` formatting 4 | per-file-ignores = __init__.py: F401, setup.py: F821 5 | exclude = .git, docs/, srv/, launch/, cfg/, .eggs/ -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release Autolab Core 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | 8 | jobs: 9 | formatting: 10 | name: Check Code Formatting 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | - name: Set up Python ${{ matrix.python-version }} 15 | uses: actions/setup-python@v2 16 | with: 17 | python-version: 3.8 18 | - name: Install Formatting 19 | run: | 20 | python -m pip install --upgrade pip 21 | pip install black flake8 22 | - name: Check Formatting 23 | run: | 24 | flake8 --config=.flake8 . 25 | black -l 79 --check . 26 | 27 | tests: 28 | name: Run Unit Tests 29 | needs: formatting 30 | runs-on: ${{ matrix.os }} 31 | strategy: 32 | fail-fast: false 33 | matrix: 34 | python-version: [3.6, 3.7, 3.8, 3.9] 35 | os: [ubuntu-latest, macos-latest, windows-latest] 36 | 37 | steps: 38 | - uses: actions/checkout@v2 39 | - name: Set up Python ${{ matrix.python-version }} 40 | uses: actions/setup-python@v2 41 | with: 42 | python-version: ${{ matrix.python-version }} 43 | - name: Install Nose2 44 | run: | 45 | python -m pip install --upgrade pip 46 | pip install nose2 47 | - name: Install Autolab Core 48 | run: pip install . 
49 | - name: Run Nose Tests 50 | run: nose2 51 | 52 | pypi: 53 | name: Release To PyPi 54 | needs: tests 55 | runs-on: ubuntu-latest 56 | steps: 57 | - uses: actions/checkout@v2 58 | with: 59 | fetch-depth: 2 60 | - name: Get changed files 61 | id: changed-files 62 | uses: tj-actions/changed-files@v5.1 63 | - name: Set up Python 64 | if: contains(steps.changed-files.outputs.modified_files, 'autolab_core/version.py') 65 | uses: actions/setup-python@v2 66 | with: 67 | python-version: '3.x' 68 | - name: Install publishing dependencies 69 | if: contains(steps.changed-files.outputs.modified_files, 'autolab_core/version.py') 70 | run: | 71 | python -m pip install --upgrade pip 72 | pip install setuptools wheel twine 73 | - name: Build and publish 74 | if: contains(steps.changed-files.outputs.modified_files, 'autolab_core/version.py') 75 | env: 76 | TWINE_USERNAME: __token__ 77 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 78 | run: | 79 | python setup.py sdist bdist_wheel 80 | twine upload dist/* 81 | 82 | release: 83 | name: Create GitHub Release 84 | needs: tests 85 | runs-on: ubuntu-latest 86 | steps: 87 | - name: Checkout code 88 | uses: actions/checkout@master 89 | with: 90 | fetch-depth: 2 91 | - name: Get changed files 92 | id: changed-files 93 | uses: tj-actions/changed-files@v5.1 94 | - name: Tag Version 95 | if: contains(steps.changed-files.outputs.modified_files, 'autolab_core/version.py') 96 | id: set_tag 97 | run: | 98 | export VER=$(python -c "exec(open('autolab_core/version.py','r').read());print(__version__)") 99 | echo "::set-output name=tag_name::${VER}" 100 | - name: Create Release 101 | if: contains(steps.changed-files.outputs.modified_files, 'autolab_core/version.py') 102 | id: create_release 103 | uses: actions/create-release@latest 104 | env: 105 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 106 | with: 107 | tag_name: ${{ steps.set_tag.outputs.tag_name }} 108 | release_name: Release ${{ steps.set_tag.outputs.tag_name }} 109 | draft: false 110 | prerelease: false 111 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Run Tests 2 | 3 | on: [pull_request, workflow_dispatch] 4 | 5 | jobs: 6 | formatting: 7 | name: Check Formatting 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@v2 11 | - name: Set up Python ${{ matrix.python-version }} 12 | uses: actions/setup-python@v2 13 | with: 14 | python-version: 3.8 15 | - name: Install Formatting 16 | run: | 17 | pip install black flake8 18 | - name: Check Formatting 19 | run: | 20 | flake8 --config=.flake8 . 21 | black -l 79 --check . 22 | 23 | tests: 24 | name: Run Unit Tests 25 | runs-on: ${{ matrix.os }} 26 | strategy: 27 | fail-fast: false 28 | matrix: 29 | python-version: [3.6, 3.7, 3.8, 3.9] 30 | os: [ubuntu-latest, windows-latest, macos-latest] 31 | steps: 32 | - uses: actions/checkout@v2 33 | - name: Set up Python ${{ matrix.python-version }} 34 | uses: actions/setup-python@v2 35 | with: 36 | python-version: ${{ matrix.python-version }} 37 | - name: Install Nose2 38 | run: | 39 | python -m pip install --upgrade pip 40 | pip install nose2 41 | - name: Install Autolab Core 42 | run: pip install . 
43 | - name: Run Nose Tests 44 | run: nose2 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | *.log 54 | local_settings.py 55 | 56 | # Flask stuff: 57 | instance/ 58 | .webassets-cache 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # PyBuilder 64 | target/ 65 | 66 | # IPython Notebook 67 | .ipynb_checkpoints 68 | 69 | # pyenv 70 | .python-version 71 | 72 | # celery beat schedule file 73 | celerybeat-schedule 74 | 75 | # dotenv 76 | .env 77 | 78 | # virtualenv 79 | venv/ 80 | ENV/ 81 | 82 | # Spyder project settings 83 | .spyderproject 84 | 85 | # Rope project settings 86 | .ropeproject 87 | 88 | # Temp files 89 | *~ 90 | .#* 91 | #* -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8.3) 2 | project(autolab_core) 3 | 4 | ## Find catkin macros and libraries 5 | ## if COMPONENTS list like find_package(catkin REQUIRED COMPONENTS xyz) 6 | ## is used, also find other catkin packages 7 | find_package(catkin REQUIRED COMPONENTS 8 | rospy 9 | std_msgs 10 | message_generation 11 | ) 12 | 13 | ## System dependencies are found with CMake's conventions 14 | # find_package(Boost REQUIRED COMPONENTS system) 15 | 16 | 17 | ## Uncomment this if the package has a setup.py. This macro ensures 18 | ## modules and global scripts declared therein get installed 19 | ## See http://ros.org/doc/api/catkin/html/user_guide/setup_dot_py.html 20 | catkin_python_setup() 21 | 22 | ################################################ 23 | ## Declare ROS messages, services and actions ## 24 | ################################################ 25 | 26 | ## To declare and build messages, services or actions from within this 27 | ## package, follow these steps: 28 | ## * Let MSG_DEP_SET be the set of packages whose message types you use in 29 | ## your messages/services/actions (e.g. std_msgs, actionlib_msgs, ...). 30 | ## * In the file package.xml: 31 | ## * add a build_depend tag for "message_generation" 32 | ## * add a build_depend and a run_depend tag for each package in MSG_DEP_SET 33 | ## * If MSG_DEP_SET isn't empty the following dependency has been pulled in 34 | ## but can be declared for certainty nonetheless: 35 | ## * add a run_depend tag for "message_runtime" 36 | ## * In this file (CMakeLists.txt): 37 | ## * add "message_generation" and every package in MSG_DEP_SET to 38 | ## find_package(catkin REQUIRED COMPONENTS ...) 
39 | ## * add "message_runtime" and every package in MSG_DEP_SET to 40 | ## catkin_package(CATKIN_DEPENDS ...) 41 | ## * uncomment the add_*_files sections below as needed 42 | ## and list every .msg/.srv/.action file to be processed 43 | ## * uncomment the generate_messages entry below 44 | ## * add every package in MSG_DEP_SET to generate_messages(DEPENDENCIES ...) 45 | 46 | ## Generate messages in the 'msg' folder 47 | # add_message_files( 48 | # FILES 49 | # Message1.msg 50 | # Message2.msg 51 | # ) 52 | 53 | ## Generate services in the 'srv' folder 54 | add_service_files( 55 | FILES 56 | RigidTransformPublisher.srv 57 | RigidTransformListener.srv 58 | ) 59 | 60 | ## Generate actions in the 'action' folder 61 | # add_action_files( 62 | # FILES 63 | # Action1.action 64 | # Action2.action 65 | # ) 66 | 67 | ## Generate added messages and services with any dependencies listed here 68 | generate_messages( 69 | DEPENDENCIES 70 | std_msgs # Or other packages containing msgs 71 | ) 72 | 73 | ################################################ 74 | ## Declare ROS dynamic reconfigure parameters ## 75 | ################################################ 76 | 77 | ## To declare and build dynamic reconfigure parameters within this 78 | ## package, follow these steps: 79 | ## * In the file package.xml: 80 | ## * add a build_depend and a run_depend tag for "dynamic_reconfigure" 81 | ## * In this file (CMakeLists.txt): 82 | ## * add "dynamic_reconfigure" to 83 | ## find_package(catkin REQUIRED COMPONENTS ...) 84 | ## * uncomment the "generate_dynamic_reconfigure_options" section below 85 | ## and list every .cfg file to be processed 86 | 87 | ## Generate dynamic reconfigure parameters in the 'cfg' folder 88 | # generate_dynamic_reconfigure_options( 89 | # cfg/DynReconf1.cfg 90 | # cfg/DynReconf2.cfg 91 | # ) 92 | 93 | ################################### 94 | ## catkin specific configuration ## 95 | ################################### 96 | ## The catkin_package macro generates cmake config files for your package 97 | ## Declare things to be passed to dependent projects 98 | ## INCLUDE_DIRS: uncomment this if you package contains header files 99 | ## LIBRARIES: libraries you create in this project that dependent projects also need 100 | ## CATKIN_DEPENDS: catkin_packages dependent projects also need 101 | ## DEPENDS: system dependencies of this project that dependent projects also need 102 | catkin_package( 103 | # INCLUDE_DIRS include 104 | # LIBRARIES yumipy 105 | # CATKIN_DEPENDS rospy 106 | # DEPENDS system_lib 107 | ) 108 | 109 | ########### 110 | ## Build ## 111 | ########### 112 | 113 | ## Specify additional locations of header files 114 | ## Your package locations should be listed before other locations 115 | # include_directories(include) 116 | include_directories( 117 | ${catkin_INCLUDE_DIRS} 118 | ) 119 | 120 | ## Declare a C++ library 121 | # add_library(yumipy 122 | # src/${PROJECT_NAME}/yumipy.cpp 123 | # ) 124 | 125 | ## Add cmake target dependencies of the library 126 | ## as an example, code may need to be generated before libraries 127 | ## either from message generation or dynamic reconfigure 128 | # add_dependencies(yumipy ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS}) 129 | 130 | ## Declare a C++ executable 131 | # add_executable(yumipy_node src/yumipy_node.cpp) 132 | 133 | ## Add cmake target dependencies of the executable 134 | ## same as for the library above 135 | # add_dependencies(yumipy_node ${${PROJECT_NAME}_EXPORTED_TARGETS} ${catkin_EXPORTED_TARGETS}) 136 | 
137 | ## Specify libraries to link a library or executable target against 138 | # target_link_libraries(yumipy_node 139 | # ${catkin_LIBRARIES} 140 | # ) 141 | 142 | ############# 143 | ## Install ## 144 | ############# 145 | 146 | # all install targets should use catkin DESTINATION variables 147 | # See http://ros.org/doc/api/catkin/html/adv_user_guide/variables.html 148 | 149 | ## Mark executable scripts (Python etc.) for installation 150 | ## in contrast to setup.py, you can choose the destination 151 | # install(PROGRAMS 152 | # scripts/my_python_script 153 | # DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} 154 | # ) 155 | 156 | ## Mark executables and/or libraries for installation 157 | # install(TARGETS yumipy yumipy_node 158 | # ARCHIVE DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} 159 | # LIBRARY DESTINATION ${CATKIN_PACKAGE_LIB_DESTINATION} 160 | # RUNTIME DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION} 161 | # ) 162 | 163 | ## Mark cpp header files for installation 164 | # install(DIRECTORY include/${PROJECT_NAME}/ 165 | # DESTINATION ${CATKIN_PACKAGE_INCLUDE_DESTINATION} 166 | # FILES_MATCHING PATTERN "*.h" 167 | # PATTERN ".svn" EXCLUDE 168 | # ) 169 | 170 | ## Mark other files for installation (e.g. launch and bag files, etc.) 171 | # install(FILES 172 | # # myfile1 173 | # # myfile2 174 | # DESTINATION ${CATKIN_PACKAGE_SHARE_DESTINATION} 175 | # ) 176 | 177 | ############# 178 | ## Testing ## 179 | ############# 180 | 181 | ## Add gtest based cpp test target and link libraries 182 | # catkin_add_gtest(${PROJECT_NAME}-test test/test_yumipy.cpp) 183 | # if(TARGET ${PROJECT_NAME}-test) 184 | # target_link_libraries(${PROJECT_NAME}-test ${PROJECT_NAME}) 185 | # endif() 186 | 187 | ## Add folders to be run by python nosetests 188 | # catkin_add_nosetests(test) 189 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2017 Berkeley AUTOLAB & University of California, Berkeley 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Include the license 2 | include LICENSE 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Berkeley Autolab Core Modules 2 | [![pypi](https://img.shields.io/pypi/v/autolab-core.svg)](https://pypi.org/project/autolab-core/) [![python-versions](https://img.shields.io/pypi/pyversions/autolab-core.svg)](https://pypi.org/project/autolab-core/) [![status](https://github.com/BerkeleyAutomation/autolab_core/workflows/Release%20Autolab%20Core/badge.svg)](https://github.com/BerkeleyAutomation/autolab_core/actions) [![style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 3 | 4 | This module contains a set of useful utilities for robotic tasks. 5 | View the install guide and API documentation [here](https://BerkeleyAutomation.github.io/autolab_core). 6 | 7 | NOTE: As of May 3, 2021, this package has dropped support for Python versions 3.5 and lower, as these versions have reached their EOL. If you wish to use older Python versions, please use the 0.x.x series of tags. Additionally, many modules from the `autolab_perception` package have been migrated here to reduce confusion. 8 | 9 | NOTE: As of June 18, 2017, this package has been renamed from `core` to `autolab_core` to prevent naming conflicts. 10 | If you wish to use the old version named `core`, please checkout the branch `pre_name_change`. 11 | However, no further updates will be pushed to this branch. 12 | -------------------------------------------------------------------------------- /autolab_core/__init__.py: -------------------------------------------------------------------------------- 1 | from .version import __version__ 2 | from .csv_model import CSVModel 3 | from .dual_quaternion import DualQuaternion 4 | from .experiment_logger import ExperimentLogger 5 | from .json_serialization import dump, load 6 | from .points import BagOfPoints, BagOfVectors, Point, Direction, Plane3D 7 | from .points import ( 8 | PointCloud, 9 | NormalCloud, 10 | ImageCoords, 11 | RgbCloud, 12 | RgbPointCloud, 13 | PointNormalCloud, 14 | ) 15 | from .primitives import Box, Contour 16 | from .rigid_transformations import RigidTransform, SimilarityTransform 17 | from .utils import ( 18 | gen_experiment_id, 19 | histogram, 20 | skew, 21 | deskew, 22 | pretty_str_time, 23 | filenames, 24 | sph2cart, 25 | cart2sph, 26 | is_positive_definite, 27 | is_positive_semi_definite, 28 | keyboard_input, 29 | ) 30 | from .yaml_config import YamlConfig 31 | from .dist_metrics import abs_angle_diff, DistMetrics 32 | from .random_variables import ( 33 | RandomVariable, 34 | BernoulliRV, 35 | GaussianRV, 36 | ArtificialRV, 37 | ArtificialSingleRV, 38 | GaussianRigidTransformRandomVariable, 39 | IsotropicGaussianRigidTransformRandomVariable, 40 | ) 41 | from .completer import Completer 42 | from .learning_analysis import ( 43 | ConfusionMatrix, 44 | ClassificationResult, 45 | BinaryClassificationResult, 46 | RegressionResult, 47 | ) 48 | from .tensor_dataset import Tensor, TensorDatapoint, TensorDataset 49 | from .logger import Logger 50 | from .data_stream_syncer import DataStreamSyncer 51 | from .data_stream_recorder import DataStreamRecorder 52 | 53 | 54 | from .features import ( 55 | Feature, 56 | LocalFeature, 57 | 
GlobalFeature,
58 | SHOTFeature,
59 | MVCNNFeature,
60 | BagOfFeatures,
61 | )
62 |
63 | from .feature_matcher import (
64 | Correspondences,
65 | NormalCorrespondences,
66 | FeatureMatcher,
67 | RawDistanceFeatureMatcher,
68 | PointToPlaneFeatureMatcher,
69 | )
70 | from .image import (
71 | Image,
72 | ColorImage,
73 | DepthImage,
74 | IrImage,
75 | GrayscaleImage,
76 | RgbdImage,
77 | GdImage,
78 | SegmentationImage,
79 | BinaryImage,
80 | PointCloudImage,
81 | NormalCloudImage,
82 | )
83 | from .chessboard_registration import (
84 | ChessboardRegistrationResult,
85 | CameraChessboardRegistration,
86 | )
87 | from .point_registration import (
88 | RegistrationResult,
89 | IterativeRegistrationSolver,
90 | PointToPlaneICPSolver,
91 | )
92 | from .detector import (
93 | RgbdDetection,
94 | RgbdDetector,
95 | RgbdForegroundMaskDetector,
96 | RgbdForegroundMaskQueryImageDetector,
97 | PointCloudBoxDetector,
98 | RgbdDetectorFactory,
99 | )
100 | from .camera_intrinsics import CameraIntrinsics
101 | -------------------------------------------------------------------------------- /autolab_core/chessboard_registration.py: -------------------------------------------------------------------------------- 1 | """
2 | Classes for easy chessboard registration
3 | Authors: Jeff Mahler and Jacky Liang
4 | """
5 | import logging
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import time
9 |
10 | from .image import DepthImage
11 | from .points import PointCloud, Point
12 | from .rigid_transformations import RigidTransform
13 |
14 |
15 | class ChessboardRegistrationResult(object):
16 | """Struct to encapsulate results of camera-to-chessboard registration.
17 |
18 | Attributes
19 | ----------
20 | T_camera_cb : :obj:`autolab_core.RigidTransform`
21 | transformation from camera to chessboard frame
22 | cb_points_cam : :obj:`autolab_core.PointCloud`
23 | 3D locations of chessboard corners in the camera frame
24 | """
25 |
26 | def __init__(self, T_camera_cb, cb_points_camera):
27 | self.T_camera_cb = T_camera_cb
28 | self.cb_points_cam = cb_points_camera
29 |
30 |
31 | class CameraChessboardRegistration:
32 | """
33 | Namespace for camera to chessboard registration functions.
34 | """
35 |
36 | @staticmethod
37 | def register(sensor, config):
38 | """
39 | Registers a camera to a chessboard.
40 |
41 | Parameters
42 | ----------
43 | sensor : :obj:`perception.RgbdSensor`
44 | the sensor to register
45 | config : :obj:`autolab_core.YamlConfig` or :obj:`dict`
46 | configuration file for registration
47 |
48 | Returns
49 | -------
50 | :obj:`ChessboardRegistrationResult`
51 | the result of registration
52 |
53 | Notes
54 | -----
55 | The config must have the parameters specified in the
56 | Other Parameters section.
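As a minimal illustrative sketch, a config might look like the following (the key names are the ones this function reads; the values are hypothetical, not recommended defaults). Note that the implementation also reads ``point_order``, ``flip_normal``, and ``debug`` (plus ``scale_amt`` when ``debug`` is true), and that ``y_points_left`` is optional for square corner grids::

    config = {
        "num_transform_avg": 4,
        "num_images": 5,
        "corners_x": 6,
        "corners_y": 9,
        "point_order": "row_major",
        "color_image_rescale_factor": 4.0,
        "flip_normal": False,
        "vis": False,
        "debug": False,
        "scale_amt": 0.1,  # axis length for the debug overlay
    }
    result = CameraChessboardRegistration.register(sensor, config)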
57 | 58 | Other Parameters 59 | ---------------- 60 | num_transform_avg : int 61 | the number of independent registrations to average together 62 | num_images : int 63 | the number of images to read for each independent registration 64 | corners_x : int 65 | the number of chessboard corners in the x-direction 66 | corners_y : int 67 | the number of chessboard corners in the y-direction 68 | color_image_rescale_factor : float 69 | amount to rescale the color image for detection 70 | (numbers around 4-8 are useful) 71 | vis : bool 72 | whether or not to visualize the registration 73 | """ 74 | # read config 75 | num_transform_avg = config["num_transform_avg"] 76 | num_images = config["num_images"] 77 | sx = config["corners_x"] 78 | sy = config["corners_y"] 79 | point_order = config["point_order"] 80 | color_image_rescale_factor = config["color_image_rescale_factor"] 81 | flip_normal = config["flip_normal"] 82 | y_points_left = False 83 | if "y_points_left" in config.keys() and sx == sy: 84 | y_points_left = config["y_points_left"] 85 | num_images = 1 86 | vis = config["vis"] 87 | 88 | # read params from sensor 89 | logging.info("Registering camera %s" % (sensor.frame)) 90 | 91 | # repeat registration multiple times and average results 92 | points_3d_plane = PointCloud( 93 | np.zeros([3, sx * sy]), frame=sensor.ir_frame 94 | ) 95 | 96 | k = 0 97 | while k < num_transform_avg: 98 | # average a bunch of depth images together 99 | depth_ims = None 100 | for i in range(num_images): 101 | start = time.time() 102 | small_color_im, new_depth_im, _ = sensor.frames() 103 | end = time.time() 104 | logging.info("Frames Runtime: %.3f" % (end - start)) 105 | if depth_ims is None: 106 | depth_ims = np.zeros( 107 | [new_depth_im.height, new_depth_im.width, num_images] 108 | ) 109 | depth_ims[:, :, i] = new_depth_im.data 110 | 111 | med_depth_im = np.median(depth_ims, axis=2) 112 | depth_im = DepthImage(med_depth_im, sensor.ir_frame) 113 | 114 | # find the corner pixels in an upsampled version of the color image 115 | big_color_im = small_color_im.resize(color_image_rescale_factor) 116 | corner_px = big_color_im.find_chessboard(sx=sx, sy=sy) 117 | 118 | # Visualize corner detections on big color image if vis==True 119 | if vis and corner_px is not None: 120 | plt.figure() 121 | plt.imshow(big_color_im.data) 122 | for i in range(sx): 123 | plt.scatter(corner_px[i, 0], corner_px[i, 1], s=25, c="b") 124 | plt.show() 125 | 126 | elif corner_px is None: 127 | logging.error( 128 | "No chessboard detected! Check camera exposure settings" 129 | ) 130 | continue 131 | 132 | # convert back to original image 133 | small_corner_px = corner_px / color_image_rescale_factor 134 | 135 | if vis: 136 | plt.figure() 137 | plt.imshow(small_color_im.data) 138 | for i in range(sx): 139 | plt.scatter( 140 | small_corner_px[i, 0], 141 | small_corner_px[i, 1], 142 | s=25, 143 | c="b", 144 | ) 145 | plt.axis("off") 146 | plt.show() 147 | 148 | # project points into 3D 149 | camera_intr = sensor.ir_intrinsics 150 | points_3d = camera_intr.deproject(depth_im) 151 | 152 | # get round chessboard ind 153 | corner_px_round = np.round(small_corner_px).astype(np.uint16) 154 | corner_ind = depth_im.ij_to_linear( 155 | corner_px_round[:, 0], corner_px_round[:, 1] 156 | ) 157 | if corner_ind.shape[0] != sx * sy: 158 | logging.warning("Did not find all corners. 
Discarding...") 159 | continue 160 | 161 | # average 3d points 162 | points_3d_plane = (k * points_3d_plane + points_3d[corner_ind]) / ( 163 | k + 1 164 | ) 165 | logging.info( 166 | "Registration iteration %d of %d" 167 | % (k + 1, config["num_transform_avg"]) 168 | ) 169 | k += 1 170 | 171 | # fit a plane to the chessboard corners 172 | X = np.c_[ 173 | points_3d_plane.x_coords, 174 | points_3d_plane.y_coords, 175 | np.ones(points_3d_plane.num_points), 176 | ] 177 | y = points_3d_plane.z_coords 178 | A = X.T.dot(X) 179 | b = X.T.dot(y) 180 | w = np.linalg.inv(A).dot(b) 181 | n = np.array([w[0], w[1], -1]) 182 | n = n / np.linalg.norm(n) 183 | if flip_normal: 184 | n = -n 185 | mean_point_plane = points_3d_plane.mean() 186 | 187 | # find x-axis of the chessboard coordinates on the fitted plane 188 | T_camera_table = RigidTransform( 189 | translation=-points_3d_plane.mean().data, 190 | from_frame=points_3d_plane.frame, 191 | to_frame="table", 192 | ) 193 | points_3d_centered = T_camera_table * points_3d_plane 194 | 195 | # get points along y 196 | if point_order == "row_major": 197 | coord_pos_x = int(np.floor(sx * sy / 2.0)) 198 | coord_neg_x = int(np.ceil(sx * sy / 2.0)) 199 | 200 | points_pos_x = points_3d_centered[coord_pos_x:] 201 | points_neg_x = points_3d_centered[:coord_neg_x] 202 | x_axis = np.mean(points_pos_x.data, axis=1) - np.mean( 203 | points_neg_x.data, axis=1 204 | ) 205 | x_axis = x_axis - np.vdot(x_axis, n) * n 206 | x_axis = x_axis / np.linalg.norm(x_axis) 207 | y_axis = np.cross(n, x_axis) 208 | else: 209 | coord_pos_y = int(np.floor(sx * (sy - 1) / 2.0)) 210 | coord_neg_y = int(np.ceil(sx * (sy + 1) / 2.0)) 211 | points_pos_y = points_3d_centered[:coord_pos_y] 212 | points_neg_y = points_3d_centered[coord_neg_y:] 213 | y_axis = np.mean(points_pos_y.data, axis=1) - np.mean( 214 | points_neg_y.data, axis=1 215 | ) 216 | y_axis = y_axis - np.vdot(y_axis, n) * n 217 | y_axis = y_axis / np.linalg.norm(y_axis) 218 | x_axis = np.cross(-n, y_axis) 219 | 220 | # produce translation and rotation from plane center and chessboard 221 | # basis 222 | rotation_cb_camera = RigidTransform.rotation_from_axes( 223 | x_axis, y_axis, n 224 | ) 225 | translation_cb_camera = mean_point_plane.data 226 | T_cb_camera = RigidTransform( 227 | rotation=rotation_cb_camera, 228 | translation=translation_cb_camera, 229 | from_frame="cb", 230 | to_frame=sensor.frame, 231 | ) 232 | 233 | if y_points_left and np.abs(T_cb_camera.y_axis[1]) > 0.1: 234 | if T_cb_camera.x_axis[0] > 0: 235 | T_cb_camera.rotation = T_cb_camera.rotation.dot( 236 | RigidTransform.z_axis_rotation(-np.pi / 2).T 237 | ) 238 | else: 239 | T_cb_camera.rotation = T_cb_camera.rotation.dot( 240 | RigidTransform.z_axis_rotation(np.pi / 2).T 241 | ) 242 | T_camera_cb = T_cb_camera.inverse() 243 | 244 | # optionally display cb corners with detected pose in 3d space 245 | if config["debug"]: 246 | # display image with axes overlayed 247 | cb_center_im = camera_intr.project( 248 | Point(T_cb_camera.translation, frame=sensor.ir_frame) 249 | ) 250 | cb_x_im = camera_intr.project( 251 | Point( 252 | T_cb_camera.translation 253 | + T_cb_camera.x_axis * config["scale_amt"], 254 | frame=sensor.ir_frame, 255 | ) 256 | ) 257 | cb_y_im = camera_intr.project( 258 | Point( 259 | T_cb_camera.translation 260 | + T_cb_camera.y_axis * config["scale_amt"], 261 | frame=sensor.ir_frame, 262 | ) 263 | ) 264 | cb_z_im = camera_intr.project( 265 | Point( 266 | T_cb_camera.translation 267 | + T_cb_camera.z_axis * config["scale_amt"], 268 | 
frame=sensor.ir_frame, 269 | ) 270 | ) 271 | x_line = np.array([cb_center_im.data, cb_x_im.data]) 272 | y_line = np.array([cb_center_im.data, cb_y_im.data]) 273 | z_line = np.array([cb_center_im.data, cb_z_im.data]) 274 | 275 | plt.figure() 276 | plt.imshow(small_color_im.data) 277 | plt.scatter(cb_center_im.data[0], cb_center_im.data[1]) 278 | plt.plot(x_line[:, 0], x_line[:, 1], c="r", linewidth=3) 279 | plt.plot(y_line[:, 0], y_line[:, 1], c="g", linewidth=3) 280 | plt.plot(z_line[:, 0], z_line[:, 1], c="b", linewidth=3) 281 | plt.axis("off") 282 | plt.title("Chessboard frame in camera %s" % (sensor.frame)) 283 | plt.show() 284 | 285 | return ChessboardRegistrationResult(T_camera_cb, points_3d_plane) 286 | -------------------------------------------------------------------------------- /autolab_core/completer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Class for autocomplete on prompts. Adapted from 3 | http://stackoverflow.com/questions/5637124/tab-completion-in-pythons-raw-input 4 | 5 | Author: Jeff Mahler 6 | """ 7 | import os 8 | import re 9 | 10 | if os.name == "nt": 11 | import pyreadline as readline 12 | else: 13 | import readline 14 | 15 | RE_SPACE = re.compile(r".*\s+$", re.M) 16 | 17 | 18 | class Completer(object): 19 | """ 20 | Tab completion class for CLI. 21 | """ 22 | 23 | def __init__(self, commands=[]): 24 | """Provide a list of commands""" 25 | self.commands = commands 26 | self.prefix = None 27 | self.words = [] 28 | 29 | def _listdir(self, root): 30 | """List directory 'root' appending the path separator to subdirs.""" 31 | res = [] 32 | for name in os.listdir(root): 33 | path = os.path.join(root, name) 34 | if os.path.isdir(path): 35 | name += os.sep 36 | res.append(name) 37 | return res 38 | 39 | def _complete_path(self, path=None): 40 | """Perform completion of filesystem path.""" 41 | if path is None or path == "": 42 | return self._listdir("./") 43 | dirname, rest = os.path.split(path) 44 | tmp = dirname if dirname else "." 45 | res = [ 46 | os.path.join(dirname, p) 47 | for p in self._listdir(tmp) 48 | if p.startswith(rest) 49 | ] 50 | # more than one match, or single match which does not exist (typo) 51 | if len(res) > 1 or not os.path.exists(path): 52 | return res 53 | # resolved to a single directory, so return list of files below it 54 | if os.path.isdir(path): 55 | return [os.path.join(path, p) for p in self._listdir(path)] 56 | # exact file match terminates this completion 57 | return [path + " "] 58 | 59 | def complete_extra(self, args): 60 | "Completions for the 'extra' command." 61 | # treat the last arg as a path and complete it 62 | if len(args) == 0: 63 | return self._listdir("./") 64 | return self._complete_path(args[-1]) 65 | 66 | def complete(self, text, state): 67 | "Generic readline completion entry point." 
68 |
69 | # First try to complete against the registered word list.
70 | results = [w for w in self.words if w.startswith(text)] + [None]
71 | if results != [None]:
72 | return results[state]
73 |
74 | # Otherwise fall back to filesystem path completion.
75 | buffer = readline.get_line_buffer()
76 | line = buffer.split()
77 |
78 | # account for last argument ending in a space
79 | if RE_SPACE.match(buffer):
80 | line.append("")
81 |
82 | return (self.complete_extra(line) + [None])[state]
83 |
84 | # Sets word list for tab completion
85 | def set_words(self, words):
86 | self.words = [str(w) for w in words]
87 | -------------------------------------------------------------------------------- /autolab_core/constants.py: -------------------------------------------------------------------------------- 1 | """
2 | Copyright ©2017. The Regents of the University of California (Regents).
3 | All Rights Reserved. Permission to use, copy, modify, and distribute this
4 | software and its documentation for educational, research, and not-for-profit
5 | purposes, without fee and without a signed licensing agreement, is hereby
6 | granted, provided that the above copyright notice, this paragraph and the
7 | following two paragraphs appear in all copies, modifications, and
8 | distributions. Contact The Office of Technology Licensing, UC Berkeley,
9 | 2150 Shattuck Avenue, Suite 510, Berkeley, CA 94720-1620, (510) 643-7201,
10 | otl@berkeley.edu, http://ipira.berkeley.edu/industry-info for commercial
11 | licensing opportunities.
12 |
13 | IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT,
14 | SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS,
15 | ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF
16 | REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
17 |
18 | REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
19 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
20 | PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF ANY, PROVIDED
21 | HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION TO PROVIDE
22 | MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
23 | """
24 | # Access levels
25 | READ_ONLY_ACCESS = "READ_ONLY"
26 | READ_WRITE_ACCESS = "READ_WRITE"
27 | WRITE_ACCESS = "WRITE"
28 |
29 | # Formatting
30 | JSON_INDENT = 2
31 |
32 | # Train / Test
33 | TRAIN_ID = 0
34 | TEST_ID = 1
35 |
36 | # Extensions
37 | COLOR_IMAGE_EXTS = [".png", ".jpg"]
38 | INTR_EXTENSION = ".intr"
39 |
40 | # Image constants
41 | MIN_DEPTH = 0.25
42 | MAX_DEPTH = 1.25
43 | MAX_IR = 65535
44 | METERS_TO_MM = 1000.0
45 | MM_TO_METERS = 1.0 / METERS_TO_MM
46 | -------------------------------------------------------------------------------- /autolab_core/data_stream_recorder.py: -------------------------------------------------------------------------------- 1 | """
2 | Class to record streams of data from a given object in a separate process.
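A minimal usage sketch (``read_pose`` and the ``robot`` handle are hypothetical; any callable works as the sampler)::

    from autolab_core import DataStreamRecorder, DataStreamSyncer

    def read_pose():
        return robot.get_pose()  # hypothetical sampler

    recorder = DataStreamRecorder("poses", read_pose, save_every=50)
    syncer = DataStreamSyncer([recorder], frequency=10)  # rate-limit to 10 Hz
    syncer.start()  # begins sampling in a child process
    # ... record for a while, then collect {"poses": [(timestamp, pose), ...]}
    data = syncer.flush()
    syncer.stop()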
3 | Author: Jacky Liang 4 | """ 5 | import os 6 | import sys 7 | import logging 8 | import shutil 9 | import time 10 | from multiprocess import Process, Queue 11 | from joblib import dump, load 12 | from setproctitle import setproctitle 13 | 14 | 15 | def _NULL(): 16 | return None 17 | 18 | 19 | def _caches_to_file(cache_path, start, end, name, cb, concat): 20 | start_time = time.time() 21 | if concat: 22 | all_data = [] 23 | for i in range(start, end): 24 | data = load(os.path.join(cache_path, "{0}.jb".format(i))) 25 | all_data.extend(data) 26 | dump(all_data, name, 3) 27 | else: 28 | target_path = os.path.join(cache_path, name[:-3]) 29 | if not os.path.exists(target_path): 30 | os.makedirs(target_path) 31 | for i in range(start, end): 32 | src_file_path = os.path.join(cache_path, "{0}.jb".format(i)) 33 | 34 | basename = os.path.basename(src_file_path) 35 | target_file_path = os.path.join(target_path, basename) 36 | shutil.move(src_file_path, target_file_path) 37 | 38 | finished_flag = os.path.join(target_path, ".finished") 39 | with open(finished_flag, "a"): 40 | os.utime(finished_flag, None) 41 | 42 | logging.debug( 43 | "Finished saving data to {0}. Took {1}s".format( 44 | name, time.time() - start_time 45 | ) 46 | ) 47 | cb() 48 | 49 | 50 | def _dump_cache(data, filename, name, i): 51 | dump(data, filename, 3) 52 | logging.debug( 53 | "Finished saving cache for {0} block {1} to {2}".format( 54 | name, i, filename 55 | ) 56 | ) 57 | 58 | 59 | def _dump(data, filename, cb): 60 | dump(data, filename, 3) 61 | logging.debug("Finished saving data to {0}".format(filename)) 62 | cb() 63 | 64 | 65 | class DataStreamRecorder(Process): 66 | def __init__( 67 | self, name, data_sampler_method, cache_path=None, save_every=50 68 | ): 69 | """Initializes a DataStreamRecorder 70 | Parameters 71 | ---------- 72 | name : string 73 | User-friendly identifier for this data stream 74 | data_sampler_method : function 75 | Method to call to retrieve data 76 | """ 77 | Process.__init__(self) 78 | self._data_sampler_method = data_sampler_method 79 | 80 | self._has_set_sampler_params = False 81 | self._recording = False 82 | 83 | self._name = name 84 | 85 | self._cmds_q = Queue() 86 | self._data_qs = [Queue()] 87 | self._ok_q = None 88 | self._tokens_q = None 89 | 90 | self._save_every = save_every 91 | self._cache_path = cache_path 92 | self._saving_cache = cache_path is not None 93 | if self._saving_cache: 94 | self._save_path = os.path.join(cache_path, self.name) 95 | if not os.path.exists(self._save_path): 96 | os.makedirs(self._save_path) 97 | 98 | self._start_data_segment = 0 99 | self._cur_data_segment = 0 100 | self._saving_ps = [] 101 | 102 | def run(self): 103 | setproctitle("python.DataStreamRecorder.{0}".format(self._name)) 104 | try: 105 | logging.debug("Starting data recording on {0}".format(self.name)) 106 | self._tokens_q.put(("return", self.name)) 107 | while True: 108 | if not self._cmds_q.empty(): 109 | cmd = self._cmds_q.get() 110 | if cmd[0] == "stop": 111 | break 112 | elif cmd[0] == "pause": 113 | self._recording = False 114 | if self._saving_cache: 115 | self._save_cache(self._cur_data_segment) 116 | self._cur_data_segment += 1 117 | self._data_qs.append(Queue()) 118 | elif cmd[0] == "reset_data_segment": 119 | self._start_data_segment = self._cur_data_segment 120 | elif cmd[0] == "resume": 121 | self._recording = True 122 | elif cmd[0] == "save": 123 | self._save_data(cmd[1], cmd[2], cmd[3]) 124 | elif cmd[0] == "params": 125 | self._args = cmd[1] 126 | self._kwargs = cmd[2] 127 | 128 | 
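# Sampling handshake with the syncer: when the syncer posts a
# timestamp on the ok queue, take a token, draw one sample from the
# sampler method, enqueue (timestamp, data), then return the token so
# the syncer can schedule the next tick.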
if self._recording and not self._ok_q.empty():
129 | timestamp = self._ok_q.get()
130 | self._tokens_q.put(("take", self.name))
131 |
132 | data = self._data_sampler_method(
133 | *self._args, **self._kwargs
134 | )
135 |
136 | cur_data_q = self._data_qs[self._cur_data_segment]
137 | if (
138 | self._saving_cache
139 | and cur_data_q.qsize() == self._save_every
140 | ):
141 | self._save_cache(self._cur_data_segment)
142 | cur_data_q = Queue()
143 | self._data_qs.append(cur_data_q)
144 | self._cur_data_segment += 1
145 | cur_data_q.put((timestamp, data))
146 |
147 | self._tokens_q.put(("return", self.name))
148 |
149 | except KeyboardInterrupt:
150 | logging.debug(
151 | "Shutting down data streamer on {0}".format(self.name)
152 | )
153 | sys.exit(0)
154 |
155 | def _extract_q(self, i):
156 | q = self._data_qs[i]
157 | vals = []
158 | while q.qsize() > 0:
159 | vals.append(q.get())
160 | self._data_qs[i] = None
161 | del q
162 | return vals
163 |
164 | def _save_data(self, path, cb, concat):
165 | if not os.path.exists(path):
166 | os.makedirs(path)
167 | target_filename = os.path.join(path, "{0}.jb".format(self.name))
168 | if self._saving_cache:
169 | while True in [p.is_alive() for p in self._saving_ps]:
170 | time.sleep(1e-3)
171 |
172 | p = Process(
173 | target=_caches_to_file,
174 | args=(
175 | self._save_path,
176 | self._start_data_segment,
177 | self._cur_data_segment,
178 | target_filename,
179 | cb,
180 | concat,
181 | ),
182 | )
183 | p.start()
184 | self._start_data_segment = self._cur_data_segment
185 | else:
186 | data = self._extract_q(0)
187 | p = Process(target=_dump, args=(data, target_filename, cb))
188 | p.start()
189 |
190 | def _save_cache(self, i):
191 | if not self._saving_cache:
192 | raise Exception(
193 | "Cannot save cache if no cache path was specified."
194 | )
195 | logging.debug(
196 | "Saving cache for {0} block {1}".format(
197 | self.name, self._cur_data_segment
198 | )
199 | )
200 | data = self._extract_q(i)
201 | p = Process(
202 | target=_dump_cache,
203 | args=(
204 | data,
205 | os.path.join(
206 | self._save_path, "{0}.jb".format(self._cur_data_segment)
207 | ),
208 | self.name,
209 | self._cur_data_segment,
210 | ),
211 | )
212 | p.start()
213 | self._saving_ps.append(p)
214 |
215 | def _start_recording(self, *args, **kwargs):
216 | """Starts recording
217 | Parameters
218 | ----------
219 | *args : any
220 | Ordinary args used for calling the specified data
221 | sampler method
222 | **kwargs : any
223 | Keyword args used for calling the specified data
224 | sampler method
225 | """
226 | while not self._cmds_q.empty():
227 | self._cmds_q.get_nowait()
228 | while not self._data_qs[self._cur_data_segment].empty():
229 | self._data_qs[self._cur_data_segment].get_nowait()
230 |
231 | self._args = args
232 | self._kwargs = kwargs
233 |
234 | self._recording = True
235 | self.start()
236 |
237 | @property
238 | def name(self):
239 | return self._name
240 |
241 | def _set_qs(self, ok_q, tokens_q):
242 | self._ok_q = ok_q
243 | self._tokens_q = tokens_q
244 |
245 | def _flush(self):
246 | """Returns a list of all current data"""
247 | if self._recording:
248 | raise Exception("Cannot flush data queue while recording!")
249 | if self._saving_cache:
250 | logging.warning(
251 | "Flush when using cache means unsaved data will be lost "
252 | "and not returned!"
253 | )
254 | self._cmds_q.put(("reset_data_segment",))
255 | else:
256 | data = self._extract_q(0)
257 | return data
258 |
259 | def save_data(self, path, cb=_NULL, concat=True):
260 | if self._recording:
261 | raise Exception("Cannot save data while recording!")
262 | self._cmds_q.put(("save", path, cb, concat))
263 |
264 | def _stop(self):
265 | """Stops recording.
266 | Destroys recorder process."""
267 | self._pause()
268 | self._cmds_q.put(("stop",))
269 | try:
270 | self.terminate()
271 | except Exception:
272 | pass
273 | self._recording = False
274 |
275 | def _pause(self):
276 | """Pauses recording"""
277 | self._cmds_q.put(("pause",))
278 | self._recording = False
279 |
280 | def _resume(self):
281 | """Resumes recording"""
282 | self._cmds_q.put(("resume",))
283 | self._recording = True
284 |
285 | def change_data_sampler_params(self, *args, **kwargs):
286 | """Changes args and kwargs for data sampler method
287 | Parameters
288 | ----------
289 | *args : any
290 | Ordinary args used for calling the specified data
291 | sampler method
292 | **kwargs : any
293 | Keyword args used for calling the specified data
294 | sampler method
295 | """
296 | self._cmds_q.put(("params", args, kwargs))
297 | -------------------------------------------------------------------------------- /autolab_core/data_stream_syncer.py: -------------------------------------------------------------------------------- 1 | """
2 | Class to sync and rate limit multiple DataStreamRecorders
3 | Author: Jacky Liang
4 | """
5 | from multiprocess import Process, Queue
6 | from queue import Empty
7 | from time import time
8 | import logging
9 | import sys
10 | from setproctitle import setproctitle
11 |
12 |
13 | class _DataStreamSyncer(Process):
14 | def __init__(self, frequency, ok_qs, cmds_q, tokens_q):
15 | Process.__init__(self)
16 | self._cmds_q = cmds_q
17 | self._tokens_q = tokens_q
18 |
19 | self._ok_qs = ok_qs
20 | self._tokens = {key: False for key in self._ok_qs.keys()}
21 |
22 | self._T = 1.0 / frequency if frequency > 0 else 0
23 | self._ok_start_time = None
24 | self._pause = False
25 |
26 | def run(self):
27 | self._session_start_time = time()
28 | setproctitle("python._DataStreamSyncer")
29 | try:
30 | while True:
31 | if not self._cmds_q.empty():
32 | cmd = self._cmds_q.get()
33 | if cmd[0] == "reset_time":
34 | self._session_start_time = time()
35 | elif cmd[0] == "pause":
36 | self._pause = True
37 | self._pause_time = time()
38 | self._take_oks()
39 | elif cmd[0] == "resume":
40 | self._pause = False
41 | if cmd[1]:
42 | self._session_start_time = time()
43 | else:
44 | self._session_start_time += (
45 | time() - self._pause_time
46 | )
47 | elif cmd[0] == "stop":
48 | break
49 | if not self._tokens_q.empty():
50 | motion, name = self._tokens_q.get()
51 | if motion == "take":
52 | self._tokens[name] = False
53 | elif motion == "return":
54 | self._tokens[name] = True
55 | self._try_ok()
56 | except KeyboardInterrupt:
57 | logging.debug("Exiting DataStreamSyncer")
58 | sys.exit(0)
59 |
60 | def _send_oks(self):
61 | self._ok_start_time = time()
62 | t = self._ok_start_time - self._session_start_time
63 | for ok_q in self._ok_qs.values():
64 | ok_q.put(t)
65 |
66 | def _take_oks(self):
67 | for ok_q in self._ok_qs.values():
68 | while ok_q.qsize() > 0:
69 | try:
70 | ok_q.get_nowait()
71 | except Empty:
72 | pass
73 |
74 | def _try_ok(self):
75 | if self._pause:
76 | return
77 | if False in self._tokens.values():
78 | if (
79 |
self._ok_start_time is not None
80 | and time() - self._ok_start_time > self._T
81 | ):
82 | timeout_names = []
83 | for name, returned in self._tokens.items():
84 | if not returned:
85 | timeout_names.append(name)
86 | logging.warning(
87 | f"Potential timeout! {timeout_names} not yet returned "
88 | "within desired period!"
89 | )
90 | return
91 | if (
92 | self._T <= 0
93 | or self._ok_start_time is None
94 | or time() - self._ok_start_time > self._T
95 | ):
96 | self._send_oks()
97 |
98 |
99 | class DataStreamSyncer:
100 | def __init__(self, data_stream_recorders, frequency=0):
101 | """
102 | Instantiates a new DataStreamSyncer
103 |
104 | Parameters
105 | ----------
106 | data_stream_recorders : list of DataStreamRecorders to sync
107 | frequency : float, optional
108 | Frequency in Hz used for rate limiting. If set to 0
109 | or less, will not rate limit. Defaults to 0.
110 | """
111 | self._cmds_q = Queue()
112 | self._tokens_q = Queue()
113 |
114 | self._data_stream_recorders = data_stream_recorders
115 | ok_qs = {}
116 | for data_stream_recorder in self._data_stream_recorders:
117 | ok_q = Queue()
118 | name = data_stream_recorder.name
119 | if name in ok_qs:
120 | raise ValueError(
121 | "Data Stream Recorders must have unique names! "
122 | f"{name} is a duplicate!"
123 | )
124 | ok_qs[name] = ok_q
125 | data_stream_recorder._set_qs(ok_q, self._tokens_q)
126 |
127 | self._syncer = _DataStreamSyncer(
128 | frequency, ok_qs, self._cmds_q, self._tokens_q
129 | )
130 | self._syncer.start()
131 |
132 | def start(self):
133 | """Starts syncer operations"""
134 | for recorder in self._data_stream_recorders:
135 | recorder._start_recording()
136 |
137 | def stop(self):
138 | """Stops syncer operations. Destroys syncer process."""
139 | self._cmds_q.put(("stop",))
140 | for recorder in self._data_stream_recorders:
141 | recorder._stop()
142 | try:
143 | self._syncer.terminate()
144 | except Exception:
145 | pass
146 |
147 | def pause(self):
148 | self._cmds_q.put(("pause",))
149 | for recorder in self._data_stream_recorders:
150 | recorder._pause()
151 |
152 | def resume(self, reset_time=False):
153 | self._cmds_q.put(("resume", reset_time))
154 | for recorder in self._data_stream_recorders:
155 | recorder._resume()
156 |
157 | def reset_time(self):
158 | self._cmds_q.put(("reset_time",))
159 |
160 | def flush(self):
161 | data = {}
162 | for recorder in self._data_stream_recorders:
163 | data[recorder.name] = recorder._flush()
164 | return data
165 | -------------------------------------------------------------------------------- /autolab_core/dist_metrics.py: -------------------------------------------------------------------------------- 1 | """
2 | Custom distance metrics.
3 | Author: Jeff Mahler
4 | """
5 | import numpy as np
6 |
7 |
8 | def abs_angle_diff(v_i, v_j):
9 | """
10 | Returns the absolute value of the angle between two 3D unit vectors.
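For unit vectors the clamped dot product is the cosine of the enclosed angle, so, for example, orthogonal axes give pi / 2::

    >>> import numpy as np
    >>> abs_angle_diff(np.array([1.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0]))
    1.5707963267948966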
--------------------------------------------------------------------------------
/autolab_core/dist_metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | Custom distance metrics.
3 | Author: Jeff Mahler
4 | """
5 | import numpy as np
6 | 
7 | 
8 | def abs_angle_diff(v_i, v_j):
9 |     """
10 |     Returns the absolute value of the angle between two 3D vectors.
11 | 
12 |     Parameters
13 |     ----------
14 |     v_i : :obj:`numpy.ndarray`
15 |         the first 3D vector
16 |     v_j : :obj:`numpy.ndarray`
17 |         the second 3D vector
18 |     """
19 |     # compute angle distance
20 |     dot_prod = min(max(v_i.dot(v_j), -1), 1)
21 |     angle_diff = np.arccos(dot_prod)
22 |     return np.abs(angle_diff)
23 | 
24 | 
25 | # dictionary of distance functions
26 | DistMetrics = {"abs_angle_diff": abs_angle_diff}
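A quick sanity check for abs_angle_diff; the inputs are assumed to be unit vectors, since the dot product is clipped rather than normalized:

import numpy as np
from autolab_core.dist_metrics import abs_angle_diff

v_i = np.array([1.0, 0.0, 0.0])
v_j = np.array([0.0, 1.0, 0.0])
assert np.isclose(abs_angle_diff(v_i, v_j), np.pi / 2)  # perpendicular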
" 94 | f"Got {qd_wxyz[0]}" 95 | ) 96 | qd_xyzw = np.roll(qd_wxyz, -1) 97 | self._qd = qd_xyzw 98 | 99 | @property 100 | def conjugate(self): 101 | """:obj:`DualQuaternion`: The conjugate of this quaternion.""" 102 | qr_c_xyzw = quaternion_conjugate(self._qr) 103 | qd_c_xyzw = quaternion_conjugate(self._qd) 104 | 105 | qr_c_wxyz = np.roll(qr_c_xyzw, 1) 106 | qd_c_wxyz = np.roll(qd_c_xyzw, 1) 107 | return DualQuaternion(qr_c_wxyz, qd_c_wxyz) 108 | 109 | @property 110 | def norm(self): 111 | """:obj:`tuple` of :obj:`numpy.ndarray`: The normalized vectors for qr 112 | and qd, respectively.""" 113 | qr_c = quaternion_conjugate(self._qr) 114 | qd_c = quaternion_conjugate(self._qd) 115 | 116 | qr_norm = np.linalg.norm(quaternion_multiply(self._qr, qr_c)) 117 | qd_norm = np.linalg.norm( 118 | quaternion_multiply(self._qr, qd_c) 119 | + quaternion_multiply(self._qd, qr_c) 120 | ) 121 | 122 | return (qr_norm, qd_norm) 123 | 124 | @property 125 | def normalized(self): 126 | """:obj:`DualQuaternion`: This quaternion with qr normalized.""" 127 | qr = self.qr / 1.0 / np.linalg.norm(self.qr) 128 | return DualQuaternion(qr, self.qd, True) 129 | 130 | def copy(self): 131 | """Return a copy of this quaternion. 132 | 133 | Returns 134 | ------- 135 | :obj:`DualQuaternion` 136 | The copied DualQuaternion. 137 | """ 138 | return DualQuaternion(self.qr.copy(), self.qd.copy()) 139 | 140 | @staticmethod 141 | def interpolate(dq0, dq1, t): 142 | """Return the interpolation of two DualQuaternions. 143 | 144 | This uses the Dual Quaternion Linear Blending Method as described by 145 | Matthew Smith's 'Applications of Dual Quaternions in Three Dimensional 146 | Transformation and Interpolation' 147 | https://www.cosc.canterbury.ac.nz/research/reports/HonsReps/2013/hons_1305.pdf 148 | 149 | Parameters 150 | ---------- 151 | dq0 : :obj:`DualQuaternion` 152 | The first DualQuaternion. 153 | 154 | dq1 : :obj:`DualQuaternion` 155 | The second DualQuaternion. 156 | 157 | t : float 158 | The interpolation step in [0,1]. When t=0, this returns dq0, and 159 | when t=1, this returns dq1. 160 | 161 | Returns 162 | ------- 163 | :obj:`DualQuaternion` 164 | The interpolated DualQuaternion. 165 | 166 | Raises 167 | ------ 168 | ValueError 169 | If t isn't in [0,1]. 170 | """ 171 | if not 0 <= t <= 1: 172 | raise ValueError( 173 | "Interpolation step must be between 0 and 1! Got {0}".format(t) 174 | ) 175 | 176 | dqt = dq0 * (1 - t) + dq1 * t 177 | return dqt.normalized 178 | 179 | def __mul__(self, val): 180 | """Multiplies the dual quaternion by another dual quaternion or a 181 | scalar. 182 | 183 | Parameters 184 | ---------- 185 | val : :obj:`DualQuaternion` or number 186 | The value by which to multiply this dual quaternion. 187 | 188 | Returns 189 | ------- 190 | :obj:`DualQuaternion` 191 | A new DualQuaternion that results from the multiplication. 192 | 193 | Raises 194 | ------ 195 | ValueError 196 | If val is not a DualQuaternion or Number. 
197 | """ 198 | if isinstance(val, DualQuaternion): 199 | new_qr_xyzw = quaternion_multiply(self._qr, val._qr) 200 | new_qd_xyzw = quaternion_multiply( 201 | self._qr, val._qd 202 | ) + quaternion_multiply(self._qd, val._qr) 203 | 204 | new_qr_wxyz = np.roll(new_qr_xyzw, 1) 205 | new_qd_wxyz = np.roll(new_qd_xyzw, 1) 206 | 207 | return DualQuaternion(new_qr_wxyz, new_qd_wxyz) 208 | elif isinstance(val, Number): 209 | new_qr_wxyz = val * self.qr 210 | new_qd_wxyz = val * self.qd 211 | 212 | return DualQuaternion(new_qr_wxyz, new_qd_wxyz, False) 213 | 214 | raise ValueError( 215 | "Cannot multiply dual quaternion with object of type {0}".format( 216 | type(val) 217 | ) 218 | ) 219 | 220 | def __add__(self, val): 221 | """Adds the dual quaternion to another dual quaternion. 222 | 223 | Parameters 224 | ---------- 225 | val : :obj:`DualQuaternion` 226 | The DualQuaternion to add to this one. 227 | 228 | Returns 229 | ------- 230 | :obj:`DualQuaternion` 231 | A new DualQuaternion that results from the addition.. 232 | 233 | Raises 234 | ------ 235 | ValueError 236 | If val is not a DualQuaternion. 237 | """ 238 | if not isinstance(val, DualQuaternion): 239 | raise ValueError( 240 | "Cannot add dual quaternion with object of type {0}".format( 241 | type(val) 242 | ) 243 | ) 244 | 245 | new_qr_wxyz = self.qr + val.qr 246 | new_qd_wxyz = self.qd + val.qd 247 | new_qr_wxyz = new_qr_wxyz / np.linalg.norm(new_qr_wxyz) 248 | 249 | return DualQuaternion(new_qr_wxyz, new_qd_wxyz, False) 250 | 251 | def __str__(self): 252 | return "{0}+{1}e".format(self.qr, self.qd) 253 | 254 | def __repr__(self): 255 | return "DualQuaternion({0},{1})".format(repr(self.qr), repr(self.qd)) 256 | -------------------------------------------------------------------------------- /autolab_core/experiment_logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Class to handle experiment logging. 3 | Authors: Jeff, Jacky 4 | """ 5 | from abc import ABCMeta, abstractmethod 6 | import os 7 | import shutil 8 | import subprocess 9 | import logging 10 | 11 | from .csv_model import CSVModel 12 | from .utils import gen_experiment_id 13 | 14 | 15 | class ExperimentLogger: 16 | """Abstract class for experiment logging. 17 | 18 | Experiments are logged to CSV files, which are encapsulated with the 19 | :obj:`CSVModel` class. 20 | """ 21 | 22 | __metaclass__ = ABCMeta 23 | 24 | _MASTER_RECORD_FILENAME = "experiment_record.csv" 25 | 26 | def __init__( 27 | self, 28 | experiment_root_path, 29 | experiment_tag="experiment", 30 | log_to_file=True, 31 | sub_experiment_dirs=True, 32 | ): 33 | """Initialize an ExperimentLogger. 34 | 35 | Parameters 36 | ---------- 37 | experiment_root_path : :obj:`str` 38 | The root directory in which to save experiment files. 
--------------------------------------------------------------------------------
/autolab_core/experiment_logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Class to handle experiment logging.
3 | Authors: Jeff, Jacky
4 | """
5 | from abc import ABCMeta, abstractmethod
6 | import os
7 | import shutil
8 | import subprocess
9 | import logging
10 | 
11 | from .csv_model import CSVModel
12 | from .utils import gen_experiment_id
13 | 
14 | 
15 | class ExperimentLogger:
16 |     """Abstract class for experiment logging.
17 | 
18 |     Experiments are logged to CSV files, which are encapsulated with the
19 |     :obj:`CSVModel` class.
20 |     """
21 | 
22 |     __metaclass__ = ABCMeta
23 | 
24 |     _MASTER_RECORD_FILENAME = "experiment_record.csv"
25 | 
26 |     def __init__(
27 |         self,
28 |         experiment_root_path,
29 |         experiment_tag="experiment",
30 |         log_to_file=True,
31 |         sub_experiment_dirs=True,
32 |     ):
33 |         """Initialize an ExperimentLogger.
34 | 
35 |         Parameters
36 |         ----------
37 |         experiment_root_path : :obj:`str`
38 |             The root directory in which to save experiment files.
39 |         experiment_tag : :obj:`str`
40 |             The tag to use when prefixing new experiments
41 |         log_to_file : bool, optional
42 |             Default: True
43 |             If True will log all logging statements to a log file
44 |         sub_experiment_dirs : bool, optional
45 |             Default: True
46 |             If True will make sub directories corresponding to generated
47 |             experiment name
48 |         """
49 |         self.experiment_root_path = experiment_root_path
50 | 
51 |         # open the master record
52 |         self.master_record_filepath = os.path.join(
53 |             self.experiment_root_path, ExperimentLogger._MASTER_RECORD_FILENAME
54 |         )
55 |         self.master_record = CSVModel.get_or_create(
56 |             self.master_record_filepath, self.experiment_meta_headers
57 |         )
58 | 
59 |         # add new experiment to the master record
60 |         self.id = ExperimentLogger.gen_experiment_ref(experiment_tag)
61 |         self._master_record_uid = self.master_record.insert(
62 |             self.experiment_meta_data
63 |         )
64 | 
65 |         # make experiment output dir
66 |         if sub_experiment_dirs:
67 |             self.experiment_path = os.path.join(
68 |                 self.experiment_root_path, self.id
69 |             )
70 |             if not os.path.exists(self.experiment_path):
71 |                 os.makedirs(self.experiment_path)
72 |         else:
73 |             self.experiment_path = self.experiment_root_path
74 | 
75 |         if log_to_file:
76 |             # redirect logging statements to a file
77 |             if not sub_experiment_dirs:
78 |                 self.log_path = os.path.join(self.experiment_root_path, "logs")
79 |             else:
80 |                 self.log_path = self.experiment_path
81 |             if not os.path.exists(self.log_path):
82 |                 os.makedirs(self.log_path)
83 |             experiment_log = os.path.join(self.log_path, "%s.log" % (self.id))
84 |             formatter = logging.Formatter(
85 |                 "%(asctime)s %(levelname)s: %(message)s"
86 |             )
87 |             hdlr = logging.FileHandler(experiment_log)
88 |             hdlr.setFormatter(formatter)
89 |             logging.getLogger().addHandler(hdlr)
90 | 
91 |         # internal dir struct
92 |         self._dirs = {}
93 | 
94 |     @staticmethod
95 |     def gen_experiment_ref(experiment_tag, n=10):
96 |         """Generate a random string for naming.
97 | 
98 |         Parameters
99 |         ----------
100 |         experiment_tag : :obj:`str`
101 |             tag to prefix name with
102 |         n : int
103 |             number of random chars to use
104 | 
105 |         Returns
106 |         -------
107 |         :obj:`str`
108 |             string experiment ref
109 |         """
110 |         experiment_id = gen_experiment_id(n=n)
111 |         return "{0}_{1}".format(experiment_tag, experiment_id)
112 | 
113 |     def update_master_record(self, data):
114 |         """Update this experiment's row of the master record CSV with
115 |         the given data.
116 | 
117 |         Parameters
118 |         ----------
119 |         data : :obj:`dict`
120 |             A dictionary mapping keys (header strings) to values, which
121 |             represents the new row for this experiment's master record
122 |             entry (looked up by its UID).
123 |         """
124 |         self.master_record.update_by_uid(self._master_record_uid, data)
125 | 
126 | 
127 |     @abstractmethod
128 |     def experiment_meta_headers(self):
129 |         """Returns list of two-tuples of header names and types of meta
130 |         information for the experiments
131 | 
132 |         Returns
133 |         -------
134 |         :obj:`tuple`
135 |             The metadata for this experiment.
136 |         """
137 |         pass
138 | 
139 |     @abstractmethod
140 |     def experiment_meta_data(self):
141 |         """Returns the dict of header names and value of meta information for
142 |         the experiments
143 | 
144 |         Returns
145 |         -------
146 |         :obj:`dict`
147 |             The metadata for this experiment.
148 | """ 149 | pass 150 | 151 | @property 152 | def dirs(self): 153 | return self._dirs.copy() 154 | 155 | def construct_internal_dirs(self, dirs, realize=False): 156 | cur_dir = self._dirs 157 | for dir in dirs: 158 | if dir not in cur_dir: 159 | cur_dir[dir] = {} 160 | cur_dir = cur_dir[dir] 161 | if realize: 162 | self._realize_dirs(dirs) 163 | 164 | def construct_internal_dirs_group(self, group_dirs): 165 | for dirs in group_dirs: 166 | self.construct_internal_dirs(dirs) 167 | 168 | def has_internal_dirs(self, dirs): 169 | cur_dir = self.dirs 170 | for dir in dirs: 171 | if dir not in cur_dir: 172 | return False 173 | cur_dir = cur_dir[dir] 174 | return True 175 | 176 | def dirs_to_path(self, dirs): 177 | rel_path = "/".join(dirs) 178 | abs_path = os.path.join(self.experiment_path, rel_path) 179 | return abs_path 180 | 181 | def _realize_dirs(self, dirs): 182 | if not self.has_internal_dirs(dirs): 183 | raise Exception( 184 | "Directory has not been constructed internally! {0}".format( 185 | dirs 186 | ) 187 | ) 188 | abs_path = self.dirs_to_path(dirs) 189 | if not os.path.exists(abs_path): 190 | os.makedirs(abs_path) 191 | return abs_path 192 | 193 | def remove_dirs(self, dirs): 194 | if not self.has_internal_dirs(dirs): 195 | raise Exception( 196 | "Directory has not been construted internally! {0}".format( 197 | dirs 198 | ) 199 | ) 200 | 201 | path = self.dirs_to_path(dirs) 202 | if os.path.exists(path): 203 | subprocess.call(["trash", "-r", path]) 204 | 205 | # remove the deepest node 206 | cur_dir = self.dirs 207 | for dir in dirs[:-1]: 208 | cur_dir = cur_dir[dir] 209 | cur_dir.pop(dirs[-1]) 210 | 211 | for i in range(len(dirs) - 1): 212 | cur_dir = self._dirs 213 | depth = len(dirs) - i - 2 214 | for j in range(depth): 215 | cur_dir = cur_dir[dirs[j]] 216 | 217 | dir_to_remove = dirs[depth] 218 | if not cur_dir[dir_to_remove]: 219 | cur_dir.pop(dir_to_remove) 220 | else: 221 | break 222 | 223 | def copy_to_dir(self, src_file_path, target_dirs): 224 | abs_path = self._realize_dirs(target_dirs) 225 | basename = os.path.basename(src_file_path) 226 | target_file_path = os.path.join(abs_path, basename) 227 | 228 | logging.debug( 229 | "Copying {0} to {1}".format(src_file_path, target_file_path) 230 | ) 231 | shutil.copyfile(src_file_path, target_file_path) 232 | 233 | def copy_dirs(self, src_dirs_path, target_dirs): 234 | if not self.has_internal_dirs(target_dirs): 235 | raise Exception( 236 | "Directory has not been constructed internally! {0}".format( 237 | target_dirs 238 | ) 239 | ) 240 | 241 | target_dirs_path = self.dirs_to_path(target_dirs) 242 | if os.path.exists(target_dirs_path): 243 | if len(os.listdir(target_dirs_path)) > 0: 244 | raise Exception( 245 | "Target path for copying directories is not empty! 
" 246 | "Got: {target_dirs_path}" 247 | ) 248 | else: 249 | os.rmdir(target_dirs_path) 250 | shutil.copytree(src_dirs_path, target_dirs_path) 251 | 252 | @staticmethod 253 | def pretty_str_time(dt): 254 | return "{0}_{1}_{2}_{3}:{4}".format( 255 | dt.year, dt.month, dt.day, dt.hour, dt.minute 256 | ) 257 | -------------------------------------------------------------------------------- /autolab_core/feature_matcher.py: -------------------------------------------------------------------------------- 1 | """ 2 | Classes for feature matching between point sets for registration 3 | Author: Jeff Mahler 4 | """ 5 | from abc import ABCMeta, abstractmethod 6 | import numpy as np 7 | from scipy import spatial 8 | import scipy.spatial.distance as ssd 9 | 10 | from .features import BagOfFeatures 11 | 12 | 13 | class Correspondences: 14 | """Wrapper for point-set correspondences. 15 | 16 | Attributes 17 | ---------- 18 | index_map : :obj:`list` of int 19 | maps list indices (source points) to target point indices 20 | source_points : Nx3 :obj:`numpy.ndarray` 21 | set of source points for registration 22 | target_points : Nx3 :obj:`numpy.ndarray` 23 | set of target points for registration 24 | num_matches : int 25 | the total number of matches 26 | """ 27 | 28 | def __init__(self, index_map, source_points, target_points): 29 | self.index_map_ = index_map 30 | self.source_points_ = source_points 31 | self.target_points_ = target_points 32 | self.num_matches_ = source_points.shape[0] 33 | 34 | @property 35 | def index_map(self): 36 | return self.index_map_ 37 | 38 | @property 39 | def source_points(self): 40 | return self.source_points_ 41 | 42 | @property 43 | def target_points(self): 44 | return self.target_points_ 45 | 46 | @property 47 | def num_matches(self): 48 | return self.num_matches_ 49 | 50 | # Functions to iterate through matches like "for source_corr, target_corr 51 | # in correspondences" 52 | def __iter__(self): 53 | self.iter_count_ = 0 54 | 55 | def next(self): 56 | if self.iter_count_ >= len(self.num_matches_): 57 | raise StopIteration 58 | else: 59 | return ( 60 | self.source_points_[self.iter_count, :], 61 | self.target_points_[self.iter_count, :], 62 | ) 63 | 64 | 65 | class NormalCorrespondences(Correspondences): 66 | """Wrapper for point-set correspondences with surface normals. 
--------------------------------------------------------------------------------
/autolab_core/feature_matcher.py:
--------------------------------------------------------------------------------
1 | """
2 | Classes for feature matching between point sets for registration
3 | Author: Jeff Mahler
4 | """
5 | from abc import ABCMeta, abstractmethod
6 | import numpy as np
7 | from scipy import spatial
8 | import scipy.spatial.distance as ssd
9 | 
10 | from .features import BagOfFeatures
11 | 
12 | 
13 | class Correspondences:
14 |     """Wrapper for point-set correspondences.
15 | 
16 |     Attributes
17 |     ----------
18 |     index_map : :obj:`list` of int
19 |         maps list indices (source points) to target point indices
20 |     source_points : Nx3 :obj:`numpy.ndarray`
21 |         set of source points for registration
22 |     target_points : Nx3 :obj:`numpy.ndarray`
23 |         set of target points for registration
24 |     num_matches : int
25 |         the total number of matches
26 |     """
27 | 
28 |     def __init__(self, index_map, source_points, target_points):
29 |         self.index_map_ = index_map
30 |         self.source_points_ = source_points
31 |         self.target_points_ = target_points
32 |         self.num_matches_ = source_points.shape[0]
33 | 
34 |     @property
35 |     def index_map(self):
36 |         return self.index_map_
37 | 
38 |     @property
39 |     def source_points(self):
40 |         return self.source_points_
41 | 
42 |     @property
43 |     def target_points(self):
44 |         return self.target_points_
45 | 
46 |     @property
47 |     def num_matches(self):
48 |         return self.num_matches_
49 | 
50 |     # Functions to iterate through matches like "for source_corr, target_corr
51 |     # in correspondences"
52 |     def __iter__(self):
53 |         self.iter_count_ = 0
54 |         return self
55 | 
56 |     def __next__(self):
57 |         if self.iter_count_ >= self.num_matches_:
58 |             raise StopIteration
59 |         i = self.iter_count_
60 |         self.iter_count_ += 1
61 |         return (self.source_points_[i, :],
62 |                 self.target_points_[i, :])
63 | 
64 | 
65 | class NormalCorrespondences(Correspondences):
66 |     """Wrapper for point-set correspondences with surface normals.
67 | 
68 |     Attributes
69 |     ----------
70 |     index_map : :obj:`list` of int
71 |         maps list indices (source points) to target point indices
72 |     source_points : Nx3 :obj:`numpy.ndarray`
73 |         set of source points for registration
74 |     target_points : Nx3 :obj:`numpy.ndarray`
75 |         set of target points for registration
76 |     source_normals : normalized Nx3 :obj:`numpy.ndarray`
77 |         set of source normals for registration
78 |     target_normals : normalized Nx3 :obj:`numpy.ndarray`
79 |         set of target normals for registration
80 |     num_matches : int
81 |         the total number of matches
82 |     """
83 | 
84 |     def __init__(
85 |         self,
86 |         index_map,
87 |         source_points,
88 |         target_points,
89 |         source_normals,
90 |         target_normals,
91 |     ):
92 |         self.source_normals_ = source_normals
93 |         self.target_normals_ = target_normals
94 |         Correspondences.__init__(self, index_map, source_points, target_points)
95 | 
96 |     @property
97 |     def source_normals(self):
98 |         return self.source_normals_
99 | 
100 |     @property
101 |     def target_normals(self):
102 |         return self.target_normals_
103 | 
104 |     # Functions to iterate through matches like "for source_corr, target_corr
105 |     # in correspondences"
106 |     def __iter__(self):
107 |         self.iter_count_ = 0
108 |         return self
109 | 
110 |     def __next__(self):
111 |         if self.iter_count_ >= self.num_matches_:
112 |             raise StopIteration
113 |         i = self.iter_count_
114 |         self.iter_count_ += 1
115 |         return (
116 |             self.source_points_[i, :], self.target_points_[i, :],
117 |             self.source_normals_[i, :], self.target_normals_[i, :],
118 |         )
119 | 
120 | 
121 | class FeatureMatcher:
122 |     """
123 |     Generic feature matching between local features on a source and
124 |     target object using nearest neighbors.
125 |     """
126 | 
127 |     __metaclass__ = ABCMeta
128 | 
129 |     def __init__(self):
130 |         pass
131 | 
132 |     @staticmethod
133 |     def get_point_index(point, all_points, eps=1e-4):
134 |         """Get the index of a point in an array"""
135 |         inds = np.where(np.linalg.norm(point - all_points, axis=1) < eps)
136 |         if inds[0].shape[0] == 0:
137 |             return -1
138 |         return inds[0][0]
139 | 
140 |     @abstractmethod
141 |     def match(self, source_obj, target_obj):
142 |         """
143 |         Matches features between a source and target object. Source and target
144 |         object types depend on subclass implementation.
145 |         """
146 |         pass
147 | 
148 | 
149 | class RawDistanceFeatureMatcher(FeatureMatcher):
150 |     def match(self, source_obj_features, target_obj_features):
151 |         """
152 |         Matches features between two graspable objects based on
153 |         a full distance matrix.
154 | 
155 |         Parameters
156 |         ----------
157 |         source_obj_features : :obj:`BagOfFeatures`
158 |             bag of the source object's features
159 |         target_obj_features : :obj:`BagOfFeatures`
160 |             bag of the target object's features
161 | 
162 |         Returns
163 |         -------
164 |         corrs : :obj:`Correspondences`
165 |             the correspondences between source and target
166 |         """
167 |         if not isinstance(source_obj_features, BagOfFeatures):
168 |             raise ValueError("Must supply source bag of object features")
169 |         if not isinstance(target_obj_features, BagOfFeatures):
170 |             raise ValueError("Must supply target bag of object features")
171 | 
172 |         # source feature descriptors and keypoints
173 |         source_descriptors = source_obj_features.descriptors
174 |         target_descriptors = target_obj_features.descriptors
175 |         source_keypoints = source_obj_features.keypoints
176 |         target_keypoints = target_obj_features.keypoints
177 | 
178 |         # calculate distance between this model's descriptors and each of the
179 |         # other_model's descriptors
180 |         dists = spatial.distance.cdist(source_descriptors, target_descriptors)
181 | 
182 |         # calculate the indices of the target_model that minimize the distance
183 |         # to the descriptors in this model
184 |         source_closest_descriptors = dists.argmin(axis=1)
185 |         target_closest_descriptors = dists.argmin(axis=0)
186 |         match_indices = []
187 |         source_matched_points = np.zeros((0, 3))
188 |         target_matched_points = np.zeros((0, 3))
189 | 
190 |         # calculate which points/indices the closest descriptors correspond to
191 |         for i, j in enumerate(source_closest_descriptors):
192 |             # for now, only keep correspondences that are a 2-way match
193 |             if target_closest_descriptors[j] == i:
194 |                 match_indices.append(j)
195 |                 source_matched_points = np.r_[
196 |                     source_matched_points, source_keypoints[i : i + 1, :]
197 |                 ]
198 |                 target_matched_points = np.r_[
199 |                     target_matched_points, target_keypoints[j : j + 1, :]
200 |                 ]
201 |             else:
202 |                 match_indices.append(-1)
203 | 
204 |         return Correspondences(
205 |             match_indices, source_matched_points, target_matched_points
206 |         )
207 | 
208 | 
209 | class PointToPlaneFeatureMatcher(FeatureMatcher):
210 |     """Match points using a point to plane criterion with thresholding.
211 | 
212 |     Attributes
213 |     ----------
214 |     dist_thresh : float
215 |         threshold distance to consider a match valid
216 |     norm_thresh : float
217 |         threshold cosine distance alignment between normals
218 |         to consider a match valid
219 |     """
220 | 
221 |     def __init__(self, dist_thresh=0.05, norm_thresh=0.75):
222 |         self.dist_thresh_ = dist_thresh
223 |         self.norm_thresh_ = norm_thresh
224 |         FeatureMatcher.__init__(self)
225 | 
226 |     def match(
227 |         self, source_points, target_points, source_normals, target_normals
228 |     ):
229 |         """
230 |         Matches points between two point-normal sets. Uses the closest
231 |         inner product to choose matches, with distance for thresholding only.
232 | 
233 |         Parameters
234 |         ----------
235 |         source_points : Nx3 :obj:`numpy.ndarray`
236 |             source object points
237 |         target_points : Nx3 :obj:`numpy.ndarray`
238 |             target object points
239 |         source_normals : Nx3 :obj:`numpy.ndarray`
240 |             source object outward-pointing normals
241 |         target_normals : Nx3 :obj:`numpy.ndarray`
242 |             target object outward-pointing normals
243 | 
244 |         Returns
245 |         -------
246 |         :obj:`NormalCorrespondences`
247 |             the correspondences between source and target
248 |         """
249 |         # compute the distances and inner products between the point sets
250 |         dists = ssd.cdist(source_points, target_points, "euclidean")
251 |         ip = source_normals.dot(
252 |             target_normals.T
253 |         )  # cosine alignment between unit normals
254 |         source_ip = source_points.dot(target_normals.T)
255 |         target_ip = target_points.dot(target_normals.T)
256 |         target_ip = np.diag(target_ip)
257 |         target_ip = np.tile(target_ip, [source_points.shape[0], 1])
258 |         abs_diff = np.abs(
259 |             source_ip - target_ip
260 |         )  # difference in inner products
261 | 
262 |         # mark invalid correspondences
263 |         invalid_dists = np.where(dists > self.dist_thresh_)
264 |         abs_diff[invalid_dists[0], invalid_dists[1]] = np.inf
265 |         invalid_norms = np.where(ip < self.norm_thresh_)
266 |         abs_diff[invalid_norms[0], invalid_norms[1]] = np.inf
267 | 
268 |         # choose the closest matches
269 |         match_indices = np.argmin(abs_diff, axis=1)
270 |         match_vals = np.min(abs_diff, axis=1)
271 |         invalid_matches = np.where(match_vals == np.inf)
272 |         match_indices[invalid_matches[0]] = -1
273 | 
274 |         return NormalCorrespondences(
275 |             match_indices,
276 |             source_points,
277 |             target_points,
278 |             source_normals,
279 |             target_normals,
280 |         )
281 | 
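A small sketch of point-to-plane matching on synthetic data: the target is a jittered copy of the source with shared unit normals, so nearly every source point should be matched (the jitter scale is illustrative):

import numpy as np
from autolab_core.feature_matcher import PointToPlaneFeatureMatcher

source = np.random.rand(100, 3)
target = source + 1e-3 * np.random.randn(100, 3)  # jittered copy
normals = np.tile([0.0, 0.0, 1.0], (100, 1))  # shared unit normals

matcher = PointToPlaneFeatureMatcher(dist_thresh=0.05, norm_thresh=0.75)
corrs = matcher.match(source, target, normals, normals)
print(np.sum(corrs.index_map >= 0), "matches")  # unmatched entries are -1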
--------------------------------------------------------------------------------
/autolab_core/features.py:
--------------------------------------------------------------------------------
1 | """
2 | Classes for features of a 3D object surface.
3 | Author: Jeff Mahler
4 | """
5 | from abc import ABCMeta
6 | 
7 | import numpy as np
8 | 
9 | 
10 | class Feature:
11 |     """Abstract class for features"""
12 | 
13 |     __metaclass__ = ABCMeta
14 | 
15 |     def __init__(self):
16 |         pass
17 | 
18 | 
19 | class LocalFeature(Feature):
20 |     """Local (e.g. pointwise) features on shape surfaces.
21 | 
22 |     Attributes
23 |     ----------
24 |     descriptor : :obj:`numpy.ndarray`
25 |         vector to describe the point
26 |     reference_frame : :obj:`numpy.ndarray`
27 |         reference frame of the descriptor, as an array
28 |     point : :obj:`numpy.ndarray`
29 |         3D point on shape surface that descriptor corresponds to
30 |     normal : :obj:`numpy.ndarray`
31 |         3D surface normal on shape surface at corresponding point
32 |     """
33 | 
34 |     __metaclass__ = ABCMeta
35 | 
36 |     def __init__(self, descriptor, rf, point, normal):
37 |         self.descriptor_ = descriptor
38 |         self.rf_ = rf
39 |         self.point_ = point
40 |         self.normal_ = normal
41 | 
42 |     @property
43 |     def descriptor(self):
44 |         return self.descriptor_
45 | 
46 |     @property
47 |     def reference_frame(self):
48 |         return self.rf_
49 | 
50 |     @property
51 |     def keypoint(self):
52 |         return self.point_
53 | 
54 |     @property
55 |     def normal(self):
56 |         return self.normal_
57 | 
58 | 
59 | class GlobalFeature(Feature):
60 |     """Global features of a full shape surface.
61 | 
62 |     Attributes
63 |     ----------
64 |     key : :obj:`str`
65 |         object key in database that descriptor corresponds to
66 |     descriptor : :obj:`numpy.ndarray`
67 |         vector to describe the object
68 |     pose : :obj:`autolab_core.RigidTransform`
69 |         pose of object for the descriptor, if relevant
70 |     """
71 | 
72 |     __metaclass__ = ABCMeta
73 | 
74 |     def __init__(self, key, descriptor, pose=None):
75 |         self.key_ = key
76 |         self.descriptor_ = descriptor
77 |         self.pose_ = pose
78 | 
79 |     @property
80 |     def key(self):
81 |         return self.key_
82 | 
83 |     @property
84 |     def descriptor(self):
85 |         return self.descriptor_
86 | 
87 |     @property
88 |     def pose(self):
89 |         return self.pose_
90 | 
91 | 
92 | class SHOTFeature(LocalFeature):
93 |     """Signature of Histograms of Orientations (SHOT) features"""
94 | 
95 |     def __init__(self, descriptor, rf, point, normal):
96 |         LocalFeature.__init__(self, descriptor, rf, point, normal)
97 | 
98 | 
99 | class MVCNNFeature(GlobalFeature):
100 |     """Multi-View Convolutional Neural Network (MV-CNN) descriptor"""
101 | 
102 |     def __init__(self, key, descriptor, pose=None):
103 |         GlobalFeature.__init__(self, key, descriptor, pose)
104 | 
105 | 
106 | class BagOfFeatures:
107 |     """Wrapper for a list of features, created for the sake
108 |     of future bag-of-words reps.
109 | 
110 |     Attributes
111 |     ----------
112 |     features : :obj:`list` of :obj:`Feature`
113 |         list of feature objects
114 |     """
115 | 
116 |     def __init__(self, features=None):
117 |         self.features_ = features
118 |         if self.features_ is None:
119 |             self.features_ = []
120 | 
121 |         self.num_features_ = len(self.features_)
122 | 
123 |     def add(self, feature):
124 |         """Add a new feature to the bag.
125 | 
126 |         Parameters
127 |         ----------
128 |         feature : :obj:`Feature`
129 |             feature to add
130 |         """
131 |         self.features_.append(feature)
132 |         self.num_features_ = len(self.features_)
133 | 
134 |     def extend(self, features):
135 |         """Add a list of features to the bag.
136 | 
137 |         Parameters
138 |         ----------
139 |         features : :obj:`list` of :obj:`Feature`
140 |             features to add
141 |         """
142 |         self.features_.extend(features)
143 |         self.num_features_ = len(self.features_)
144 | 
145 |     def feature(self, index):
146 |         """Returns a feature.
147 | 
148 |         Parameters
149 |         ----------
150 |         index : int
151 |             index of feature in list
152 | 
153 |         Returns
154 |         -------
155 |         :obj:`Feature`
156 |         """
157 |         if index < 0 or index >= self.num_features_:
158 |             raise ValueError("Index %d out of range" % (index))
159 |         return self.features_[index]
160 | 
161 |     def feature_subset(self, indices):
162 |         """Returns some subset of the features.
163 | 
164 |         Parameters
165 |         ----------
166 |         indices : :obj:`list` of :obj:`int`
167 |             indices of the features in the list
168 | 
169 |         Returns
170 |         -------
171 |         :obj:`list` of :obj:`Feature`
172 |         """
173 |         if isinstance(indices, np.ndarray):
174 |             indices = indices.tolist()
175 |         if not isinstance(indices, list):
176 |             raise ValueError("Can only index with lists")
177 |         return [self.features_[i] for i in indices]
178 | 
179 |     @property
180 |     def num_features(self):
181 |         return self.num_features_
182 | 
183 |     @property
184 |     def descriptors(self):
185 |         """Make a nice array of the descriptors"""
186 |         return np.array([f.descriptor for f in self.features_])
187 | 
188 |     @property
189 |     def reference_frames(self):
190 |         """Make a nice array of the reference frames"""
191 |         return np.array([f.reference_frame for f in self.features_])
192 | 
193 |     @property
194 |     def keypoints(self):
195 |         """Make a nice array of the keypoints"""
196 |         return np.array([f.keypoint for f in self.features_])
197 | 
198 |     @property
199 |     def normals(self):
200 |         """Make a nice array of the normals"""
201 |         return np.array([f.normal for f in self.features_])
202 | 
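A sketch of building a BagOfFeatures from a few synthetic SHOT features (SHOT descriptors are conventionally 352-dimensional; the values here are random placeholders):

import numpy as np
from autolab_core.features import BagOfFeatures, SHOTFeature

bag = BagOfFeatures()
for _ in range(3):
    bag.add(
        SHOTFeature(
            descriptor=np.random.rand(352),  # placeholder 352-D descriptor
            rf=np.eye(3),
            point=np.random.rand(3),
            normal=np.array([0.0, 0.0, 1.0]),
        )
    )
print(bag.num_features, bag.descriptors.shape)  # 3 (3, 352)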
--------------------------------------------------------------------------------
/autolab_core/json_serialization.py:
--------------------------------------------------------------------------------
1 | """
2 | Module to serialize and deserialize JSON with numpy arrays. Adapted from
3 | http://stackoverflow.com/a/24375113/723090 so that arrays are human-readable.
4 | 
5 | Author: Brian Hou
6 | """
7 | 
8 | import json as _json
9 | import numpy as np
10 | 
11 | 
12 | class NumpyEncoder(_json.JSONEncoder):
13 |     """A numpy array to json encoder."""
14 | 
15 |     def default(self, obj):
16 |         """Converts an ndarray into a dictionary for efficient serialization.
17 | 
18 |         The dict has three keys:
19 |         - dtype : The datatype of the array as a string.
20 |         - shape : The shape of the array as a tuple.
21 |         - __ndarray__ : The data of the array as a list.
22 | 
23 |         Parameters
24 |         ----------
25 |         obj : :obj:`numpy.ndarray`
26 |             The ndarray to encode.
27 | 
28 |         Returns
29 |         -------
30 |         :obj:`dict`
31 |             The dictionary serialization of obj.
32 | 
33 |         Raises
34 |         ------
35 |         TypeError
36 |             If obj isn't an ndarray.
37 |         """
38 |         if isinstance(obj, np.ndarray):
39 |             return dict(
40 |                 __ndarray__=obj.tolist(), dtype=str(obj.dtype), shape=obj.shape
41 |             )
42 |         # Let the base class default method raise the TypeError
43 |         return _json.JSONEncoder.default(self, obj)
44 | 
45 | 
46 | def json_numpy_obj_hook(dct):
47 |     """Decodes a previously encoded numpy ndarray with proper shape and dtype.
48 | 
49 |     Parameters
50 |     ----------
51 |     dct : :obj:`dict`
52 |         The encoded dictionary.
53 | 
54 |     Returns
55 |     -------
56 |     :obj:`numpy.ndarray`
57 |         The ndarray that `dct` was encoding.
58 |     """
59 |     if isinstance(dct, dict) and "__ndarray__" in dct:
60 |         data = np.asarray(dct["__ndarray__"], dtype=dct["dtype"])
61 |         return data.reshape(dct["shape"])
62 |     return dct
63 | 
64 | 
65 | def dump(*args, **kwargs):
66 |     """Dump a numpy.ndarray to file stream.
67 | 
68 |     This works exactly like the usual `json.dump()` function,
69 |     but it uses our custom serializer.
70 |     """
71 |     kwargs.update(
72 |         dict(
73 |             cls=NumpyEncoder, sort_keys=True, indent=4, separators=(",", ": ")
74 |         )
75 |     )
76 |     return _json.dump(*args, **kwargs)
77 | 
78 | 
79 | def load(*args, **kwargs):
80 |     """Load a numpy.ndarray from a file stream.
81 | 
82 |     This works exactly like the usual `json.load()` function,
83 |     but it uses our custom deserializer.
84 |     """
85 |     kwargs.update(dict(object_hook=json_numpy_obj_hook))
86 |     return _json.load(*args, **kwargs)
87 | 
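A round-trip sketch for the two helpers above; dtype and shape survive because they are stored alongside the data (the file path is illustrative):

import numpy as np
from autolab_core import json_serialization as json

arr = np.arange(6, dtype=np.float32).reshape(2, 3)
with open("arr.json", "w") as f:
    json.dump({"arr": arr}, f)
with open("arr.json") as f:
    loaded = json.load(f)
assert np.allclose(loaded["arr"], arr)
assert loaded["arr"].dtype == np.float32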
--------------------------------------------------------------------------------
/autolab_core/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility class for logging.
3 | 
4 | Author: Vishal Satish
5 | """
6 | import logging
7 | import sys
8 | 
9 | import colorlog
10 | 
11 | ROOT_LOG_LEVEL = logging.INFO
12 | ROOT_LOG_STREAM = sys.stdout
13 | 
14 | 
15 | def configure_root():
16 |     """Configure the root logger."""
17 |     root_logger = logging.getLogger()
18 | 
19 |     # clear any existing handles to streams because we don't want
20 |     # duplicate logs. NOTE: we assume that any stream handles we find
21 |     # are to ROOT_LOG_STREAM, which is usually the case (because it is
22 |     # stdout). This is fine because we will be re-creating that handle.
23 |     # Otherwise we might be deleting a handle that won't be re-created,
24 |     # which could result in dropped logs.
25 |     for hdlr in root_logger.handlers:
26 |         if isinstance(hdlr, logging.StreamHandler):
27 |             root_logger.removeHandler(hdlr)
28 | 
29 |     # configure the root logger
30 |     root_logger.setLevel(ROOT_LOG_LEVEL)
31 |     hdlr = logging.StreamHandler(ROOT_LOG_STREAM)
32 |     formatter = colorlog.ColoredFormatter(
33 |         "%(purple)s%(name)-10s "
34 |         "%(log_color)s%(levelname)-8s%(reset)s "
35 |         "%(white)s%(message)s",
36 |         reset=True,
37 |         log_colors={
38 |             "DEBUG": "cyan",
39 |             "INFO": "green",
40 |             "WARNING": "yellow",
41 |             "ERROR": "red",
42 |             "CRITICAL": "red,bg_white",
43 |         },
44 |     )
45 |     hdlr.setFormatter(formatter)
46 |     root_logger.addHandler(hdlr)
47 | 
48 | 
49 | def add_root_log_file(log_file):
50 |     """
51 |     Add a log file to the root logger.
52 | 
53 |     Parameters
54 |     ----------
55 |     log_file :obj:`str`
56 |         The path to the log file.
57 |     """
58 |     root_logger = logging.getLogger()
59 | 
60 |     # add a file handle to the root logger
61 |     hdlr = logging.FileHandler(log_file)
62 |     formatter = logging.Formatter(
63 |         "%(asctime)s %(name)-10s %(levelname)-8s %(message)s",
64 |         datefmt="%m-%d %H:%M:%S",
65 |     )
66 |     hdlr.setFormatter(formatter)
67 |     root_logger.addHandler(hdlr)
68 |     root_logger.info("Root logger now logging to {}".format(log_file))
69 | 
70 | 
71 | class Logger(object):
72 |     ROOT_CONFIGURED = False
73 | 
74 |     @staticmethod
75 |     def reconfigure_root():
76 |         """Reconfigure the root logger."""
77 |         configure_root()
78 | 
79 |     @staticmethod
80 |     def get_logger(
81 |         name,
82 |         log_level=logging.INFO,
83 |         log_file=None,
84 |         global_log_file=False,
85 |         silence=False,
86 |     ):
87 |         """
88 |         Build a logger. All logs will be propagated up to the root logger
89 |         if not silenced. If log_file is provided, logs will be written out
90 |         to that file. If global_log_file is true, log_file will be handed
91 |         to the root logger, otherwise it will only be used by this
92 |         particular logger.
93 | 
94 |         Parameters
95 |         ----------
96 |         name :obj:`str`
97 |             The name of the logger to be built.
98 |         log_level : `int`
99 |             The log level. See the python logging module documentation
100 |             for possible enum values.
101 |         log_file :obj:`str`
102 |             The path to the log file to log to.
103 |         global_log_file :obj:`bool`
104 |             Whether or not to use the given log_file for this particular
105 |             logger or for the root logger.
106 |         silence :obj:`bool`
107 |             Whether or not to silence this logger. If it is silenced, the
108 |             only way to get output from this logger is through a
109 |             non-global log file.
110 | 
111 |         Returns
112 |         -------
113 |         :obj:`logging.Logger`
114 |             A custom logger.
115 |         """
116 |         no_op = False
117 |         # some checks for silencing/no-op logging
118 |         if silence and global_log_file:
119 |             raise ValueError(
120 |                 "You can't silence a logger and log to a global log file!"
121 |             )
122 |         if silence and log_file is None:
123 |             logging.warning("You are creating a no-op logger!")
124 |             no_op = True
125 | 
126 |         # configure the root logger if it hasn't been already
127 |         if not Logger.ROOT_CONFIGURED:
128 |             configure_root()
129 |             Logger.ROOT_CONFIGURED = True
130 | 
131 |         # build a logger
132 |         logger = logging.getLogger(name)
133 |         logger.setLevel(log_level)
134 | 
135 |         # silence the logger by preventing it from propagating upwards
136 |         # to the root
137 |         logger.propagate = not silence
138 | 
139 |         # configure the log file stream
140 |         if log_file is not None:
141 |             # if the log file is global, add it to the root logger
142 |             if global_log_file:
143 |                 add_root_log_file(log_file)
144 |             # otherwise add it to this particular logger
145 |             else:
146 |                 hdlr = logging.FileHandler(log_file)
147 |                 formatter = logging.Formatter(
148 |                     "%(asctime)s %(name)-10s %(levelname)-8s %(message)s",
149 |                     datefmt="%m-%d %H:%M:%S",
150 |                 )
151 |                 hdlr.setFormatter(formatter)
152 |                 logger.addHandler(hdlr)
153 | 
154 |         # add a no-op handler to suppress warnings about
155 |         # there being no handlers
156 |         if no_op:
157 |             logger.addHandler(logging.NullHandler())
158 |         return logger
159 | 
160 |     @staticmethod
161 |     def add_log_file(logger, log_file, global_log_file=False):
162 |         """
163 |         Add a log file to this logger. If global_log_file is true,
164 |         log_file will be handed to the root logger, otherwise it will
165 |         only be used by this particular logger.
166 | 
167 |         Parameters
168 |         ----------
169 |         logger :obj:`logging.Logger`
170 |             The logger.
171 |         log_file :obj:`str`
172 |             The path to the log file to log to.
173 |         global_log_file :obj:`bool`
174 |             Whether or not to use the given log_file for this particular
175 |             logger or for the root logger.
176 |         """
177 | 
178 |         if global_log_file:
179 |             add_root_log_file(log_file)
180 |         else:
181 |             hdlr = logging.FileHandler(log_file)
182 |             formatter = logging.Formatter(
183 |                 "%(asctime)s %(name)-10s %(levelname)-8s %(message)s",
184 |                 datefmt="%m-%d %H:%M:%S",
185 |             )
186 |             hdlr.setFormatter(formatter)
187 |             logger.addHandler(hdlr)
188 | 
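Typical setups for Logger.get_logger: a console logger that propagates to the colorized root handler, and one that also writes to its own file (names and paths are illustrative):

import logging
from autolab_core.logger import Logger

logger = Logger.get_logger("my_module", log_level=logging.DEBUG)
logger.info("hello")  # colorized via the root handler

file_logger = Logger.get_logger("trainer", log_file="/tmp/train.log")
file_logger.warning("also written to /tmp/train.log")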
36 | """ 37 | 38 | def __init__(self, min_pt, max_pt, frame="unspecified"): 39 | """Initialize a box. 40 | 41 | Parameters 42 | ---------- 43 | min_pt : :obj:`numpy.ndarray` of float 44 | The minimum x, y, and (optionally) z points. 45 | 46 | max_pt : :obj:`numpy.ndarray` of float 47 | The maximum x, y, and (optionally) z points. 48 | 49 | frame : :obj:`str` 50 | The frame in which this box is placed. 51 | 52 | Raises 53 | ------ 54 | ValueError 55 | If max_pt is not strictly larger than min_pt in all dims. 56 | """ 57 | if np.any((max_pt - min_pt) < 0): 58 | raise ValueError("Min point must be smaller than max point") 59 | self._min_pt = min_pt 60 | self._max_pt = max_pt 61 | self._frame = frame 62 | 63 | @property 64 | def dims(self): 65 | """:obj:`numpy.ndarray` of float: Maximal extent in x, y, 66 | and (optionally) z""" 67 | return self._max_pt - self._min_pt 68 | 69 | @property 70 | def width(self): 71 | """float: Maximal extent in x.""" 72 | return int(np.round(self.dims[1])) 73 | 74 | @property 75 | def height(self): 76 | """float: Maximal extent in y.""" 77 | return int(np.round(self.dims[0])) 78 | 79 | @property 80 | def area(self): 81 | """float: Area of projection onto xy plane.""" 82 | return self.width * self.height 83 | 84 | @property 85 | def min_pt(self): 86 | """:obj:`numpy.ndarray` of float: The minimum x, y, and 87 | (optionally) z points.""" 88 | return self._min_pt 89 | 90 | @property 91 | def max_pt(self): 92 | """:obj:`numpy.ndarray` of float: The maximum x, y, and 93 | (optionally) z points.""" 94 | return self._max_pt 95 | 96 | @property 97 | def center(self): 98 | """:obj:`numpy.ndarray` of float: The center of the box in 2D or 99 | 3D coords.""" 100 | return self.min_pt + self.dims / 2.0 101 | 102 | @property 103 | def ci(self): 104 | """float value of center i coordinate""" 105 | return self.center[0] 106 | 107 | @property 108 | def cj(self): 109 | """float value of center j coordinate""" 110 | return self.center[1] 111 | 112 | @property 113 | def frame(self): 114 | """:obj:`str`: The frame in which this box is placed.""" 115 | return self._frame 116 | 117 | 118 | class Contour(object): 119 | """A set of pixels forming the boundary of an object of interest in an image. 
class Contour(object):
    """A set of pixels forming the boundary of an object of interest
    in an image.
120 | 
121 |     Attributes
122 |     ----------
123 |     boundary_pixels : :obj:`numpy.ndarray`
124 |         Nx2 array of pixel coordinates on the boundary of a contour
125 |     bounding_box : :obj:`Box`
126 |         smallest box containing the contour
127 |     area : float
128 |         area of the contour
129 |     num_pixels : int
130 |         number of pixels along the boundary
131 |     """
132 | 
133 |     def __init__(self, boundary_pixels, area=0.0, frame="unspecified"):
134 |         self.boundary_pixels = boundary_pixels.squeeze()
135 |         self.bounding_box = Box(
136 |             np.min(self.boundary_pixels, axis=0),
137 |             np.max(self.boundary_pixels, axis=0),
138 |             frame,
139 |         )
140 |         self.area = area
141 | 
142 |     @property
143 |     def num_pixels(self):
144 |         return self.boundary_pixels.shape[0]
145 | 
--------------------------------------------------------------------------------
/autolab_core/random_variables.py:
--------------------------------------------------------------------------------
1 | """
2 | Generic Random Variable wrapper classes
3 | Author: Jeff Mahler
4 | """
5 | from abc import ABCMeta, abstractmethod
6 | 
7 | import numpy as np
8 | import scipy.linalg  # needed for scipy.linalg.expm in sampling below
9 | import scipy.stats
10 | 
11 | from .rigid_transformations import RigidTransform
12 | from .utils import skew, is_positive_semi_definite
13 | 
14 | 
15 | class RandomVariable(object):
16 |     """Abstract base class for random variables."""
17 | 
18 |     __metaclass__ = ABCMeta
19 | 
20 |     def __init__(self, num_prealloc_samples=0):
21 |         """Initialize a random variable with optional pre-sampling.
22 | 
23 |         Parameters
24 |         ----------
25 |         num_prealloc_samples : int
26 |             The number of samples to pre-allocate.
27 |         """
28 |         self.num_prealloc_samples_ = num_prealloc_samples
29 |         if self.num_prealloc_samples_ > 0:
30 |             self._preallocate_samples()
31 | 
32 |     def _preallocate_samples(self):
33 |         """Preallocate samples for faster adaptive sampling."""
34 |         self.prealloc_samples_ = []
35 |         for _ in range(self.num_prealloc_samples_):
36 |             self.prealloc_samples_.append(self.sample())
37 | 
38 |     @abstractmethod
39 |     def sample(self, size=1):
40 |         """Generate samples of the random variable.
41 | 
42 |         Parameters
43 |         ----------
44 |         size : int
45 |             The number of samples to generate.
46 | 
47 |         Returns
48 |         -------
49 |         :obj:`numpy.ndarray` of float or int
50 |             The samples of the random variable. If `size == 1`, then
51 |             the returned value will not be wrapped in an array.
52 |         """
53 |         pass
54 | 
55 |     def rvs(self, size=1, iteration=1):
56 |         """Sample the random variable, using the preallocated samples if
57 |         possible.
58 | 
59 |         Parameters
60 |         ----------
61 |         size : int
62 |             The number of samples to generate.
63 | 
64 |         iteration : int
65 |             The location in the preallocated sample array to start sampling
66 |             from.
67 | 
68 |         Returns
69 |         -------
70 |         :obj:`numpy.ndarray` of float or int
71 |             The samples of the random variable. If `size == 1`, then
72 |             the returned value will not be wrapped in an array.
73 |         """
74 |         if self.num_prealloc_samples_ > 0:
75 |             samples = []
76 |             for i in range(size):
77 |                 samples.append(
78 |                     self.prealloc_samples_[
79 |                         (iteration + i) % self.num_prealloc_samples_
80 |                     ]
81 |                 )
82 |             if size == 1:
83 |                 return samples[0]
84 |             return samples
85 |         # generate a new sample
86 |         return self.sample(size=size)
87 | 
88 | class BernoulliRV(RandomVariable):
89 |     """A Bernoulli random variable."""
90 | 
91 |     def __init__(self, p, *args, **kwargs):
92 |         """Initialize a Bernoulli random variable with probability p.
93 | 
94 |         Parameters
95 |         ----------
96 |         p : float
97 |             The probability that the random variable takes the value 1.
98 |         """
99 |         self.p = p
100 |         super(BernoulliRV, self).__init__(*args, **kwargs)
101 | 
102 |     def sample(self, size=1):
103 |         """Generate samples of the random variable.
104 | 
105 |         Parameters
106 |         ----------
107 |         size : int
108 |             The number of samples to generate.
109 | 
110 |         Returns
111 |         -------
112 |         :obj:`numpy.ndarray` of int or int
113 |             The samples of the random variable. If `size == 1`, then
114 |             the returned value will not be wrapped in an array.
115 |         """
116 |         samples = scipy.stats.bernoulli.rvs(self.p, size=size)
117 |         if size == 1:
118 |             return samples[0]
119 |         return samples
120 | 
121 | 
122 | class GaussianRV(RandomVariable):
123 |     """A Gaussian random variable."""
124 | 
125 |     def __init__(self, mu, sigma, *args, **kwargs):
126 |         """Initialize a Gaussian random variable.
127 | 
128 |         Parameters
129 |         ----------
130 |         mu : float
131 |             The mean of the Gaussian.
132 | 
133 |         sigma : float
134 |             The covariance of the Gaussian (passed to
135 |             scipy.stats.multivariate_normal).
136 |         """
137 |         self.mu = mu
138 |         self.sigma = sigma
139 |         super(GaussianRV, self).__init__(*args, **kwargs)
140 | 
141 |     def sample(self, size=1):
142 |         """Generate samples of the random variable.
143 | 
144 |         Parameters
145 |         ----------
146 |         size : int
147 |             The number of samples to generate.
148 | 
149 |         Returns
150 |         -------
151 |         :obj:`numpy.ndarray` of float
152 |             The samples of the random variable.
153 |         """
154 |         samples = scipy.stats.multivariate_normal.rvs(
155 |             self.mu, self.sigma, size=size
156 |         )
157 |         return samples
158 | 
159 | 
160 | class ArtificialRV(RandomVariable):
161 |     """A fake RV that deterministically returns the given object."""
162 | 
163 |     def __init__(self, obj, *args, **kwargs):
164 |         """Initialize an artificial RV.
165 | 
166 |         Parameters
167 |         ----------
168 |         obj : item
169 |             The item to always return.
170 | 
171 |         num_prealloc_samples : int
172 |             The number of samples to pre-allocate.
173 |         """
174 |         self.obj_ = obj
175 |         super(ArtificialRV, self).__init__(*args, **kwargs)
176 | 
177 |     def sample(self, size=1):
178 |         """Generate copies of the artificial RV.
179 | 
180 |         Parameters
181 |         ----------
182 |         size : int
183 |             The number of samples to generate.
184 | 
185 |         Returns
186 |         -------
187 |         :obj:`numpy.ndarray` of item
188 |             The copies of the fake RV.
189 |         """
190 |         return np.array([self.obj_] * size)
191 | 
192 | 
193 | class ArtificialSingleRV(ArtificialRV):
194 |     """A single ArtificialRV."""
195 | 
196 |     def sample(self, size=None):
197 |         """Generate a single copy of the artificial RV.
198 | 
199 |         Returns
200 |         -------
201 |         item
202 |             The wrapped object.
203 |         """
204 |         return self.obj_
205 | 
206 | 
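# Example (sketch): sampling the basic random variables defined above.
#
#     >>> coin = BernoulliRV(0.5, num_prealloc_samples=100)
#     >>> flips = coin.rvs(size=10)   # served from the preallocated pool
#     >>> gauss = GaussianRV(0.0, 1.0)
#     >>> x = gauss.sample(size=5)    # five N(0, 1) draws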
207 | class GaussianRigidTransformRandomVariable(RandomVariable):
208 |     """Random variable for sampling RigidTransformations with
209 |     a Gaussian distribution over pose variables.
210 | 
211 |     We assume no correlation between translation and rotation, so
212 |     their values are sampled independently.
213 | 
214 |     To sample rotations, we use the method described on page 7 here:
215 |     http://ethaneade.com/lie.pdf
216 | 
217 |     Attributes
218 |     ----------
219 |     mu_tra : :obj:`numpy.ndarray` of float or int
220 |         Mean translation
221 |     mu_rot : :obj:`numpy.ndarray` of float or int
222 |         Mean rotation
223 |     sigma_tra : :obj:`numpy.ndarray` of float or int
224 |         Covariance of translation.
225 |     sigma_rot : :obj:`numpy.ndarray` of float or int
226 |         Covariance of rotation
227 |     from_frame : str
228 |     to_frame : str
229 | 
230 |     Raises
231 |     ------
232 |     ValueError
233 |         If mu_rot is not a valid rotation, or if either sigma_tra or sigma_rot
234 |         is not positive semi-definite.
235 |     """
236 | 
237 |     def __init__(
238 |         self,
239 |         mu_tra=np.zeros(3),
240 |         mu_rot=np.eye(3),
241 |         sigma_tra=np.eye(3),
242 |         sigma_rot=np.eye(3),
243 |         from_frame="world",
244 |         to_frame="world",
245 |         *args,
246 |         **kwargs
247 |     ):
248 |         if np.abs(np.linalg.det(mu_rot) - 1.0) > 1e-3:
249 |             raise ValueError("Illegal rotation. Must have determinant == 1.0")
250 |         if not is_positive_semi_definite(sigma_tra):
251 |             raise ValueError(
252 |                 "Translation covariance is not positive semi-definite!"
253 |             )
254 |         if not is_positive_semi_definite(sigma_rot):
255 |             raise ValueError(
256 |                 "Rotation covariance is not positive semi-definite!"
257 |             )
258 | 
259 |         # read params
260 |         self._mu_tra = mu_tra.copy()
261 |         self._mu_rot = mu_rot.copy()
262 |         self._sigma_tra = sigma_tra.copy()
263 |         self._sigma_rot = sigma_rot.copy()
264 | 
265 |         diag_idx = np.diag_indices(3)
266 |         self._sigma_tra[diag_idx] = np.clip(
267 |             np.diag(self._sigma_tra), 1e-10, np.inf
268 |         )
269 |         self._sigma_rot[diag_idx] = np.clip(
270 |             np.diag(self._sigma_rot), 1e-10, np.inf
271 |         )
272 | 
273 |         self._from_frame = from_frame
274 |         self._to_frame = to_frame
275 | 
276 |         # setup random variables
277 |         self._t_rv = scipy.stats.multivariate_normal(
278 |             self._mu_tra, self._sigma_tra
279 |         )
280 |         self._r_xi_rv = scipy.stats.multivariate_normal(
281 |             np.zeros(3), self._sigma_rot
282 |         )
283 |         super(GaussianRigidTransformRandomVariable, self).__init__(
284 |             *args, **kwargs
285 |         )
286 | 
287 |     def sample(self, size=1):
288 |         """Sample rigid transform random variables.
289 | 
290 |         Parameters
291 |         ----------
292 |         size : int
293 |             number of samples to take
294 | 
295 |         Returns
296 |         -------
297 |         :obj:`list` of :obj:`RigidTransform`
298 |             sampled rigid transformations
299 |         """
300 |         samples = []
301 |         for _ in range(size):
302 |             # sample random pose
303 |             xi = self._r_xi_rv.rvs(size=1)
304 |             S_xi = skew(xi)
305 |             R_sample = scipy.linalg.expm(S_xi).dot(self._mu_rot)
306 | 
307 |             t_sample = self._t_rv.rvs(size=1)
308 | 
309 |             samples.append(
310 |                 RigidTransform(
311 |                     rotation=R_sample,
312 |                     translation=t_sample,
313 |                     from_frame=self._from_frame,
314 |                     to_frame=self._to_frame,
315 |                 )
316 |             )
317 | 
318 |         # not a list if only 1 sample
319 |         if size == 1 and len(samples) > 0:
320 |             return samples[0]
321 |         return samples
322 | 
323 | 
324 | class IsotropicGaussianRigidTransformRandomVariable(
325 |     GaussianRigidTransformRandomVariable
326 | ):
327 |     """Random variable for sampling RigidTransformations with
328 |     a zero-mean isotropic Gaussian distribution over pose variables.
329 | 
330 |     Attributes
331 |     ----------
332 |     sigma_trans : float
333 |         variance for translation
334 |     sigma_rot : float
335 |         variance for rotation
336 |     from_frame : str
337 |     to_frame : str
338 | 
339 |     """
340 | 
341 |     def __init__(
342 |         self,
343 |         sigma_trans,
344 |         sigma_rot,
345 |         from_frame="world",
346 |         to_frame="world",
347 |         *args,
348 |         **kwargs
349 |     ):
350 |         super(IsotropicGaussianRigidTransformRandomVariable, self).__init__(
351 |             sigma_tra=max(1e-10, sigma_trans) * np.eye(3),
352 |             sigma_rot=max(1e-10, sigma_rot) * np.eye(3),
353 |             from_frame=from_frame,
354 |             to_frame=to_frame,
355 |             *args,
356 |             **kwargs
357 |         )
358 | 
--------------------------------------------------------------------------------
/autolab_core/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Commonly used helper functions
3 | Author: Jeff Mahler
4 | """
5 | import logging
6 | import os
7 | import numpy as np
8 | 
9 | 
10 | def gen_experiment_id(n=10):
11 |     """Generate a random string with n characters.
12 | 
13 |     Parameters
14 |     ----------
15 |     n : int
16 |         The length of the string to be generated.
17 | 
18 |     Returns
19 |     -------
20 |     :obj:`str`
21 |         A string with only alphabetic characters.
22 |     """
23 |     chrs = "abcdefghijklmnopqrstuvwxyz"
24 |     inds = np.random.randint(0, len(chrs), size=n)
25 |     return "".join([chrs[i] for i in inds])
26 | 
27 | 
28 | def get_elapsed_time(time_in_seconds):
29 |     """Helper function to get elapsed time in human-readable format.
30 | 
31 |     Parameters
32 |     ----------
33 |     time_in_seconds : float
34 |         runtime, in seconds
35 | 
36 |     Returns
37 |     -------
38 |     str
39 |         formatted human-readable string describing the time
40 |     """
41 |     if time_in_seconds < 60:
42 |         return "%.1f seconds" % (time_in_seconds)
43 |     elif time_in_seconds < 3600:
44 |         return "%.1f minutes" % (time_in_seconds / 60)
45 |     else:
46 |         return "%.1f hours" % (time_in_seconds / 3600)
47 | 
48 | 
49 | def mkdir_safe(path):
50 |     """Creates a directory if it does not already exist.
51 | 
52 |     Parameters
53 |     ----------
54 |     path : str
55 |         path to the directory to create
56 | 
57 |     Returns
58 |     -------
59 |     bool
60 |         True if the directory was created, False otherwise
61 |     """
62 |     if os.path.exists(path):
63 |         return False
64 |     os.mkdir(path)
65 |     return True
66 | def histogram(
67 |     values, num_bins, bounds, normalized=True, plot=False, color="b"
68 | ):
69 |     """Generate a histogram plot.
70 | 
71 |     Parameters
72 |     ----------
73 |     values : :obj:`numpy.ndarray`
74 |         An array of values to put in the histogram.
75 | 
76 |     num_bins : int
77 |         The number of equal-width bins in the histogram.
78 | 
79 |     bounds : :obj:`tuple` of float
80 |         Two floats - a min and a max - that define the lower and upper
81 |         ranges of the histogram, respectively.
82 | 
83 |     normalized : bool
84 |         If True, the bins will show the percentage of elements they contain
85 |         rather than raw counts.
86 | 
87 |     plot : bool
88 |         If True, this function uses pyplot to plot the histogram.
89 | 
90 |     color : :obj:`str`
91 |         The color identifier for the plotted bins.
92 | 
93 |     Returns
94 |     -------
95 |     :obj:`tuple` of :obj:`numpy.ndarray`
96 |         The values of the histogram and the bin edges as ndarrays.
97 | """ 98 | hist, bins = np.histogram(values, bins=num_bins, range=bounds) 99 | width = bins[1] - bins[0] 100 | if normalized: 101 | if np.sum(hist) > 0: 102 | hist = hist.astype(np.float32) / np.sum(hist) 103 | if plot: 104 | import matplotlib.pyplot as plt 105 | 106 | plt.bar(bins[:-1], hist, width=width, color=color) 107 | return hist, bins 108 | 109 | 110 | def skew(xi): 111 | """Return the skew-symmetric matrix that can be used to calculate 112 | cross-products with vector xi. 113 | 114 | Multiplying this matrix by a vector `v` gives the same result 115 | as `xi x v`. 116 | 117 | Parameters 118 | ---------- 119 | xi : :obj:`numpy.ndarray` of float 120 | A 3-entry vector. 121 | 122 | Returns 123 | ------- 124 | :obj:`numpy.ndarray` of float 125 | The 3x3 skew-symmetric cross product matrix for the vector. 126 | """ 127 | S = np.array( 128 | [[0, -xi[2], xi[1]], [xi[2], 0, -xi[0]], [-xi[1], xi[0], 0]], 129 | dtype=np.float, 130 | ) 131 | return S 132 | 133 | 134 | def deskew(S): 135 | """Converts a skew-symmetric cross-product matrix to its corresponding 136 | vector. Only works for 3x3 matrices. 137 | 138 | Parameters 139 | ---------- 140 | S : :obj:`numpy.ndarray` of float 141 | A 3x3 skew-symmetric matrix. 142 | 143 | Returns 144 | ------- 145 | :obj:`numpy.ndarray` of float 146 | A 3-entry vector that corresponds to the given cross product matrix. 147 | """ 148 | x = np.zeros(3) 149 | x[0] = S[2, 1] 150 | x[1] = S[0, 2] 151 | x[2] = S[1, 0] 152 | return x 153 | 154 | 155 | def reverse_dictionary(d): 156 | """Reverses the key value pairs for a given dictionary. 157 | 158 | Parameters 159 | ---------- 160 | d : :obj:`dict` 161 | dictionary to reverse 162 | 163 | Returns 164 | ------- 165 | :obj:`dict` 166 | dictionary with keys and values swapped 167 | """ 168 | rev_d = {} 169 | [rev_d.update({v: k}) for k, v in d.items()] 170 | return rev_d 171 | 172 | 173 | def pretty_str_time(dt): 174 | """Get a pretty string for the given datetime object. 175 | 176 | Parameters 177 | ---------- 178 | dt : :obj:`datetime` 179 | A datetime object to format. 180 | 181 | Returns 182 | ------- 183 | :obj:`str` 184 | The `datetime` formatted as {year}_{month}_{day}_{hour}_{minute}. 185 | """ 186 | return "{0}_{1}_{2}_{3}:{4}".format( 187 | dt.year, dt.month, dt.day, dt.hour, dt.minute 188 | ) 189 | 190 | 191 | def filenames(directory, tag="", sorted=False, recursive=False): 192 | """Reads in all filenames from a directory that contain a specified substring. 193 | 194 | Parameters 195 | ---------- 196 | directory : :obj:`str` 197 | the directory to read from 198 | tag : :obj:`str` 199 | optional tag to match in the filenames 200 | sorted : bool 201 | whether or not to sort the filenames 202 | recursive : bool 203 | whether or not to search for the files recursively 204 | 205 | Returns 206 | ------- 207 | :obj:`list` of :obj:`str` 208 | filenames to read from 209 | """ 210 | if recursive: 211 | f = [ 212 | os.path.join(directory, f) 213 | for directory, _, filename in os.walk(directory) 214 | for f in filename 215 | if f.find(tag) > -1 216 | ] 217 | else: 218 | f = [ 219 | os.path.join(directory, f) 220 | for f in os.listdir(directory) 221 | if f.find(tag) > -1 222 | ] 223 | if sorted: 224 | f.sort() 225 | return f 226 | 227 | 228 | def sph2cart(r, az, elev): 229 | """Convert spherical to cartesian coordinates. 

    Parameters
    ----------
    r : float
        radius
    az : float
        azimuth (angle about the z axis)
    elev : float
        polar angle measured from the z axis

    Returns
    -------
    float
        x-coordinate
    float
        y-coordinate
    float
        z-coordinate
    """
    x = r * np.cos(az) * np.sin(elev)
    y = r * np.sin(az) * np.sin(elev)
    z = r * np.cos(elev)
    return x, y, z


def cart2sph(x, y, z):
    """Convert cartesian to spherical coordinates.

    Parameters
    ----------
    x : float
        x-coordinate
    y : float
        y-coordinate
    z : float
        z-coordinate

    Returns
    -------
    float
        radius
    float
        azimuth
    float
        polar angle measured from the z axis
    """
    r = np.sqrt(x**2 + y**2 + z**2)
    if x > 0 and y > 0:
        az = np.arctan(y / x)
    elif x > 0 and y < 0:
        az = 2 * np.pi - np.arctan(-y / x)
    elif x < 0 and y > 0:
        az = np.pi - np.arctan(-y / x)
    elif x < 0 and y < 0:
        az = np.pi + np.arctan(y / x)
    elif x == 0 and y > 0:
        az = np.pi / 2
    elif x == 0 and y < 0:
        az = 3 * np.pi / 2
    elif y == 0 and x > 0:
        az = 0
    elif y == 0 and x < 0:
        az = np.pi
    else:
        # x == y == 0: the azimuth is undefined, so default to 0
        az = 0
    # guard against division by zero at the origin
    elev = np.arccos(z / r) if r > 0 else 0
    return r, az, elev


def keyboard_input(message, yesno=False):
    """Get keyboard input from a human, optionally reasking for valid
    yes or no input.

    Parameters
    ----------
    message : :obj:`str`
        the message to display to the user
    yesno : :obj:`bool`
        whether or not to enforce yes or no inputs

    Returns
    -------
    :obj:`str`
        string input by the human
    """
    # add space for readability
    message += " "

    # add yes or no to message
    if yesno:
        message += "[y/n] "

    # ask human
    human_input = input(message)
    if yesno:
        while human_input.lower() != "n" and human_input.lower() != "y":
            logging.info("Did not understand input. Please answer 'y' or 'n'")
            human_input = input(message)
    return human_input


def sqrt_ceil(n):
    """Computes the square root of a number rounded up to the nearest
    integer. Very useful for plotting.

    Parameters
    ----------
    n : int
        number to sqrt

    Returns
    -------
    int
        the sqrt rounded up to the nearest integer
    """
    return int(np.ceil(np.sqrt(n)))


def is_positive_definite(A):
    """Checks if a given matrix is positive definite.

    See https://stackoverflow.com/a/16266736 for details.

    Parameters
    ----------
    A : :obj:`numpy.ndarray` of float or int
        The square matrix of interest

    Returns
    -------
    bool
        whether or not A is positive definite
    """
    is_pd = True

    try:
        np.linalg.cholesky(A)
    except np.linalg.LinAlgError:
        is_pd = False

    return is_pd


def is_positive_semi_definite(A):
    """Checks if a given matrix is positive semi-definite.

    Parameters
    ----------
    A : :obj:`numpy.ndarray` of float or int
        The square matrix of interest

    Returns
    -------
    bool
        whether or not A is positive semi-definite
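
    Examples
    --------
    A quick, illustrative check (the zero matrix is positive
    semi-definite but not positive definite)::

        >>> import numpy as np
        >>> is_positive_semi_definite(np.zeros((3, 3)))
        True
        >>> is_positive_semi_definite(-np.eye(3))
        False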
    """
    return is_positive_definite(A + np.eye(len(A)) * 1e-20)
-------------------------------------------------------------------------------- /autolab_core/version.py: --------------------------------------------------------------------------------
__version__ = "1.1.1"
-------------------------------------------------------------------------------- /autolab_core/yaml_config.py: --------------------------------------------------------------------------------
"""
YAML Configuration Parser
Author: Jeff Mahler
"""
import os
import ruamel.yaml as yaml
import re
from collections import OrderedDict


class YamlConfig(object):
    """Class to load a configuration file and parse it into a dictionary.

    Attributes
    ----------
    config : :obj:`dict`
        A dictionary that contains the contents of the configuration.
    """

    def __init__(self, filename=None):
        """Initialize a YamlConfig by loading it from the given file.

        Parameters
        ----------
        filename : :obj:`str`
            The filename of the .yaml file that contains the configuration.
        """
        self.config = {}
        if filename:
            self._load_config(filename)

    def keys(self):
        """Return the keys of the config dictionary.

        Returns
        -------
        :obj:`list` of :obj:`Object`
            A list of the keys in the config dictionary.
        """
        return self.config.keys()

    def update(self, d):
        """Update the config with a dictionary of parameters.

        Parameters
        ----------
        d : :obj:`dict`
            dictionary of parameters
        """
        self.config.update(d)

    def get(self, key, default=None):
        """Allows for get method like python dict."""
        return self.config.get(key, default)

    def __contains__(self, key):
        """Overrides 'in' operator."""
        return key in self.config.keys()

    def __getitem__(self, key):
        """Overrides the key access operator []."""
        return self.config[key]

    def __setitem__(self, key, val):
        """Overrides the keyed setting operator []."""
        self.config[key] = val

    def iteritems(self):
        """Returns an iterator over (key, value) pairs of the config dict."""
        # dict.iteritems() no longer exists in Python 3; use items() instead
        return iter(self.config.items())

    def save(self, filename):
        """Save a YamlConfig to disk."""
        y = yaml.YAML()
        with open(filename, "w") as f:
            y.dump(self.config, f)

    def _load_config(self, filename):
        """Loads a yaml configuration file from the given filename.

        Parameters
        ----------
        filename : :obj:`str`
            The filename of the .yaml file that contains the configuration.
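
        Examples
        --------
        ``!include`` directives are expanded in place before parsing, so a
        config can pull in the contents of another file (the paths here are
        illustrative)::

            # main.yaml
            num_bins: 25
            defaults: !include defaults.yaml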
        """
        # Read entire file for metadata
        with open(filename, "r") as fh:
            self.file_contents = fh.read()

        # Replace !include directives with content
        config_dir = os.path.split(filename)[0]
        include_re = re.compile(r"^(.*)!include\s+(.*)$", re.MULTILINE)

        def recursive_load(matchobj, path):
            first_spacing = matchobj.group(1)
            other_spacing = first_spacing.replace("-", " ")
            fname = os.path.join(path, matchobj.group(2).rstrip())
            new_path, _ = os.path.split(fname)
            new_path = os.path.realpath(new_path)
            text = ""
            with open(fname) as f:
                text = f.read()
            text = first_spacing + text
            text = text.replace(
                "\n", "\n{}".format(other_spacing), text.count("\n") - 1
            )
            return re.sub(
                include_re, lambda m: recursive_load(m, new_path), text
            )

        self.file_contents = re.sub(
            include_re,
            lambda m: recursive_load(m, config_dir),
            self.file_contents,
        )

        # Read in dictionary
        self.config = self.__ordered_load(self.file_contents)

        # Convert functions of other params to true expressions
        for k in self.config.keys():
            self.config[k] = YamlConfig.__convert_key(self.config[k])

        # Return the parsed configuration
        return self.config

    @staticmethod
    def __convert_key(expression):
        """Converts keys in YAML that reference other keys."""
        if (
            type(expression) is str
            and len(expression) > 2
            and expression[1] == "!"
        ):
            expression = eval(expression[2:-1])
        return expression

    def __ordered_load(
        self, stream, Loader=yaml.Loader, object_pairs_hook=OrderedDict
    ):
        """Load an ordered dictionary from a yaml file.

        Note
        ----
        Borrowed from John Schulman.
        http://stackoverflow.com/questions/5121931/in-python-how-can-you-load-yaml-mappings-as-ordereddicts/21048064
        """

        class OrderedLoader(Loader):
            pass

        OrderedLoader.add_constructor(
            yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG,
            lambda loader, node: object_pairs_hook(
                loader.construct_pairs(node)
            ),
        )
        return yaml.load(stream, OrderedLoader)

    def __iter__(self):
        # Converting to a `list` will have a higher memory overhead,
        # but realistically there should not be *that* many keys.
176 | self._keys = list(self.config.keys()) 177 | return self 178 | 179 | def __next__(self): 180 | try: 181 | return self._keys.pop(0) 182 | except IndexError: 183 | raise StopIteration 184 | -------------------------------------------------------------------------------- /cfg/tools/aggregate_tensor_datasets.yaml: -------------------------------------------------------------------------------- 1 | input_datasets: 2 | - /mnt/data/datasets/knapp/large_suction/grasps 3 | - /mnt/data/datasets/knapp/large_suction_no_segmask/grasps 4 | 5 | output_dataset: /mnt/data/datasets/knapp/large_suction_big 6 | 7 | exclude_fields: 8 | - color_ims 9 | 10 | display_rate: 1 11 | -------------------------------------------------------------------------------- /cfg/tools/compute_dataset_statistics.yaml: -------------------------------------------------------------------------------- 1 | analysis_fields: 2 | - rewards 3 | 4 | num_percentiles: 10 5 | thresholds: 6 | - 0.0 7 | log_rate: 100 8 | 9 | font_size: 15 10 | line_width: 5 11 | dpi: 100 12 | num_bins: 25 13 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = build 9 | GH_PAGES_SOURCES = docs autolab_core 10 | 11 | # User-friendly check for sphinx-build 12 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 13 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 14 | endif 15 | 16 | # Internal variables. 
PAPEROPT_a4     = -D latex_paper_size=a4
PAPEROPT_letter = -D latex_paper_size=letter
ALLSPHINXOPTS   = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source
# the i18n builder cannot share the environment and doctrees with the others
I18NSPHINXOPTS  = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source

.PHONY: help
help:
	@echo "Please use \`make <target>' where <target> is one of"
	@echo "  html       to make standalone HTML files"
	@echo "  dirhtml    to make HTML files named index.html in directories"
	@echo "  singlehtml to make a single large HTML file"
	@echo "  pickle     to make pickle files"
	@echo "  json       to make JSON files"
	@echo "  htmlhelp   to make HTML files and a HTML help project"
	@echo "  qthelp     to make HTML files and a qthelp project"
	@echo "  applehelp  to make an Apple Help Book"
	@echo "  devhelp    to make HTML files and a Devhelp project"
	@echo "  epub       to make an epub"
	@echo "  latex      to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
	@echo "  latexpdf   to make LaTeX files and run them through pdflatex"
	@echo "  latexpdfja to make LaTeX files and run them through platex/dvipdfmx"
	@echo "  text       to make text files"
	@echo "  man        to make manual pages"
	@echo "  texinfo    to make Texinfo files"
	@echo "  info       to make Texinfo files and run them through makeinfo"
	@echo "  gettext    to make PO message catalogs"
	@echo "  changes    to make an overview of all changed/added/deprecated items"
	@echo "  xml        to make Docutils-native XML files"
	@echo "  pseudoxml  to make pseudoxml-XML files for display purposes"
	@echo "  linkcheck  to check all external links for integrity"
	@echo "  doctest    to run all doctests embedded in the documentation (if enabled)"
	@echo "  coverage   to run coverage check of the documentation (if enabled)"

.PHONY: clean
clean:
	rm -rf $(BUILDDIR)/*

.PHONY: gh-pages
gh-pages:
	git checkout gh-pages && \
	cd .. && \
	git rm -rf . && git clean -fxd && \
	git checkout master $(GH_PAGES_SOURCES) && \
	git reset HEAD && \
	cd docs && \
	make html && \
	cd .. && \
	mv -fv docs/build/html/* ./ && \
	touch .nojekyll && \
	rm -rf $(GH_PAGES_SOURCES) && \
	git add -A && \
	git commit -m "Generated gh-pages for `git log master -1 --pretty=short --abbrev-commit`" && \
	git push origin --delete gh-pages && \
	git push origin gh-pages ; \
	git checkout master

.PHONY: html
html:
	$(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
	@echo
	@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."

.PHONY: dirhtml
dirhtml:
	$(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
	@echo
	@echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."

.PHONY: singlehtml
singlehtml:
	$(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
	@echo
	@echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."

.PHONY: pickle
pickle:
	$(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
	@echo
	@echo "Build finished; now you can process the pickle files."

.PHONY: json
json:
	$(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
	@echo
	@echo "Build finished; now you can process the JSON files."
103 | 104 | .PHONY: htmlhelp 105 | htmlhelp: 106 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 107 | @echo 108 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 109 | ".hhp project file in $(BUILDDIR)/htmlhelp." 110 | 111 | .PHONY: qthelp 112 | qthelp: 113 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 114 | @echo 115 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 116 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 117 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/autolab_core.qhcp" 118 | @echo "To view the help file:" 119 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/autolab_core.qhc" 120 | 121 | .PHONY: applehelp 122 | applehelp: 123 | $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp 124 | @echo 125 | @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." 126 | @echo "N.B. You won't be able to view it unless you put it in" \ 127 | "~/Library/Documentation/Help or install it in your application" \ 128 | "bundle." 129 | 130 | .PHONY: devhelp 131 | devhelp: 132 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 133 | @echo 134 | @echo "Build finished." 135 | @echo "To view the help file:" 136 | @echo "# mkdir -p $$HOME/.local/share/devhelp/autolab_core" 137 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/autolab_core" 138 | @echo "# devhelp" 139 | 140 | .PHONY: epub 141 | epub: 142 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 143 | @echo 144 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 145 | 146 | .PHONY: latex 147 | latex: 148 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 149 | @echo 150 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 151 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 152 | "(use \`make latexpdf' here to do that automatically)." 153 | 154 | .PHONY: latexpdf 155 | latexpdf: 156 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 157 | @echo "Running LaTeX files through pdflatex..." 158 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 159 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 160 | 161 | .PHONY: latexpdfja 162 | latexpdfja: 163 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 164 | @echo "Running LaTeX files through platex and dvipdfmx..." 165 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 166 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 167 | 168 | .PHONY: text 169 | text: 170 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 171 | @echo 172 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 173 | 174 | .PHONY: man 175 | man: 176 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 177 | @echo 178 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 179 | 180 | .PHONY: texinfo 181 | texinfo: 182 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 183 | @echo 184 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 185 | @echo "Run \`make' in that directory to run these through makeinfo" \ 186 | "(use \`make info' here to do that automatically)." 187 | 188 | .PHONY: info 189 | info: 190 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 191 | @echo "Running Texinfo files through makeinfo..." 192 | make -C $(BUILDDIR)/texinfo info 193 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 

.PHONY: gettext
gettext:
	$(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale
	@echo
	@echo "Build finished. The message catalogs are in $(BUILDDIR)/locale."

.PHONY: changes
changes:
	$(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes
	@echo
	@echo "The overview file is in $(BUILDDIR)/changes."

.PHONY: linkcheck
linkcheck:
	$(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck
	@echo
	@echo "Link check complete; look for any errors in the above output " \
	      "or in $(BUILDDIR)/linkcheck/output.txt."

.PHONY: doctest
doctest:
	$(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest
	@echo "Testing of doctests in the sources finished, look at the " \
	      "results in $(BUILDDIR)/doctest/output.txt."

.PHONY: coverage
coverage:
	$(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage
	@echo "Testing of coverage in the sources finished, look at the " \
	      "results in $(BUILDDIR)/coverage/python.txt."

.PHONY: xml
xml:
	$(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml
	@echo
	@echo "Build finished. The XML files are in $(BUILDDIR)/xml."

.PHONY: pseudoxml
pseudoxml:
	$(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml
	@echo
	@echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml."
-------------------------------------------------------------------------------- /docs/gh_deploy.sh: --------------------------------------------------------------------------------
#!/bin/sh
make gh-pages
cd ../docs
-------------------------------------------------------------------------------- /docs/source/api/csv_model.rst: --------------------------------------------------------------------------------
CSV Processing
==============

CSVModel
~~~~~~~~
A model for processing CSV files, including reading, writing, and
table structuring.

.. autoclass:: autolab_core.CSVModel
-------------------------------------------------------------------------------- /docs/source/api/dual_quaternion.rst: --------------------------------------------------------------------------------
Quaternions
===========

DualQuaternion
~~~~~~~~~~~~~~
A class for representing and manipulating dual quaternions.

.. autoclass:: autolab_core.DualQuaternion

-------------------------------------------------------------------------------- /docs/source/api/exceptions.rst: --------------------------------------------------------------------------------
Exceptions
==========

.. autoclass:: autolab_core.TerminateException
-------------------------------------------------------------------------------- /docs/source/api/experiment_logger.rst: --------------------------------------------------------------------------------
Experiment Logging
==================
Classes for logging experimental data to files in a well-specified format.

ExperimentLogger
~~~~~~~~~~~~~~~~
.. 
autoclass:: autolab_core.ExperimentLogger 8 | -------------------------------------------------------------------------------- /docs/source/api/json_serialization.rst: -------------------------------------------------------------------------------- 1 | JSON Manipulation 2 | ================= 3 | Functions for JSON serialization and deserialization that use custom hooks 4 | for properly processing numpy arrays. 5 | 6 | .. autofunction:: autolab_core.dump 7 | 8 | .. autofunction:: autolab_core.load 9 | 10 | -------------------------------------------------------------------------------- /docs/source/api/points.rst: -------------------------------------------------------------------------------- 1 | Points and Point Clouds 2 | ======================= 3 | A set of useful classes for 2D and 3D points, vectors, 4 | planes, point clouds, and normal clouds. 5 | 6 | BagOfPoints 7 | ~~~~~~~~~~~~~~ 8 | .. autoclass:: autolab_core.BagOfPoints 9 | 10 | BagOfVectors 11 | ~~~~~~~~~~~~~~ 12 | .. autoclass:: autolab_core.BagOfVectors 13 | 14 | Point 15 | ~~~~~~~~~~~~~~ 16 | .. autoclass:: autolab_core.Point 17 | 18 | Direction 19 | ~~~~~~~~~~~~~~ 20 | .. autoclass:: autolab_core.Direction 21 | 22 | Plane3D 23 | ~~~~~~~~~~~~~~ 24 | .. autoclass:: autolab_core.Plane3D 25 | 26 | PointCloud 27 | ~~~~~~~~~~~~~~ 28 | .. autoclass:: autolab_core.PointCloud 29 | 30 | NormalCloud 31 | ~~~~~~~~~~~~~~ 32 | .. autoclass:: autolab_core.NormalCloud 33 | 34 | ImageCoords 35 | ~~~~~~~~~~~~~~ 36 | .. autoclass:: autolab_core.ImageCoords 37 | 38 | RgbCloud 39 | ~~~~~~~~~~~~~~ 40 | .. autoclass:: autolab_core.RgbCloud 41 | 42 | RgbPointCloud 43 | ~~~~~~~~~~~~~~ 44 | .. autoclass:: autolab_core.RgbPointCloud 45 | 46 | PointNormalCloud 47 | ~~~~~~~~~~~~~~~~ 48 | .. autoclass:: autolab_core.PointNormalCloud 49 | 50 | -------------------------------------------------------------------------------- /docs/source/api/primitives.rst: -------------------------------------------------------------------------------- 1 | Geometric Primitives 2 | ==================== 3 | A set of geometric primitives. 4 | 5 | Box 6 | ~~~ 7 | 8 | .. autoclass:: autolab_core.Box 9 | 10 | Contour 11 | ~~~~~~~ 12 | .. autoclass:: autolab_core.Contour 13 | -------------------------------------------------------------------------------- /docs/source/api/random_variables.rst: -------------------------------------------------------------------------------- 1 | Random Variables 2 | ================ 3 | A set of random variable classes with sampling methods. 4 | 5 | RandomVariable 6 | ~~~~~~~~~~~~~~ 7 | .. autoclass:: autolab_core.RandomVariable 8 | 9 | BernoulliRV 10 | ~~~~~~~~~~~~~~ 11 | .. autoclass:: autolab_core.BernoulliRV 12 | 13 | GaussianRV 14 | ~~~~~~~~~~~~~~ 15 | .. autoclass:: autolab_core.GaussianRV 16 | 17 | ArtificialRV 18 | ~~~~~~~~~~~~~~ 19 | .. autoclass:: autolab_core.ArtificialRV 20 | 21 | ArtificialSingleRV 22 | ~~~~~~~~~~~~~~~~~~ 23 | .. autoclass:: autolab_core.ArtificialSingleRV 24 | 25 | -------------------------------------------------------------------------------- /docs/source/api/rigid_transform.rst: -------------------------------------------------------------------------------- 1 | Rigid Transformations 2 | ===================== 3 | 4 | RigidTransform 5 | ~~~~~~~~~~~~~~ 6 | .. autoclass:: autolab_core.RigidTransform 7 | 8 | SimilarityTransform 9 | ~~~~~~~~~~~~~~~~~~~ 10 | .. 
autoclass:: autolab_core.SimilarityTransform 11 | 12 | -------------------------------------------------------------------------------- /docs/source/api/utils.rst: -------------------------------------------------------------------------------- 1 | Utility Functions 2 | ================= 3 | General utility functions. 4 | 5 | .. autofunction:: autolab_core.gen_experiment_id 6 | .. autofunction:: autolab_core.histogram 7 | .. autofunction:: autolab_core.skew 8 | .. autofunction:: autolab_core.deskew 9 | .. autofunction:: autolab_core.pretty_str_time 10 | .. autofunction:: autolab_core.filenames 11 | .. autofunction:: autolab_core.sph2cart 12 | .. autofunction:: autolab_core.cart2sph 13 | -------------------------------------------------------------------------------- /docs/source/api/yaml_config.rst: -------------------------------------------------------------------------------- 1 | YAML Config Processor 2 | ===================== 3 | A class for loading and writing to YAML configuration files. 4 | 5 | .. autoclass:: autolab_core.YamlConfig 6 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # core documentation build configuration file, created by 4 | # sphinx-quickstart on Sun Oct 16 14:33:48 2016. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | 15 | import sys 16 | import os 17 | import sphinx_rtd_theme 18 | from autolab_core import __version__ 19 | 20 | # If extensions (or modules to document with autodoc) are in another directory, 21 | # add these directories to sys.path here. If the directory is relative to the 22 | # documentation root, use os.path.abspath to make it absolute, like shown here. 23 | sys.path.insert(0, os.path.abspath("../../")) 24 | 25 | # -- General configuration ------------------------------------------------ 26 | 27 | # If your documentation needs a minimal Sphinx version, state it here. 28 | # needs_sphinx = '1.0' 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 33 | extensions = ["sphinx.ext.autodoc", "sphinxcontrib.napoleon"] 34 | autoclass_content = "class" 35 | autodoc_member_order = "bysource" 36 | autodoc_default_flags = ["members", "show-inheritance"] 37 | napoleon_include_special_with_doc = True 38 | napoleon_include_init_with_doc = True 39 | 40 | # Add any paths that contain templates here, relative to this directory. 41 | templates_path = ["_templates"] 42 | 43 | # The suffix(es) of source filenames. 44 | # You can specify multiple suffix as a list of string: 45 | # source_suffix = ['.rst', '.md'] 46 | source_suffix = ".rst" 47 | 48 | # The encoding of source files. 49 | # source_encoding = 'utf-8-sig' 50 | 51 | # The master toctree document. 52 | master_doc = "index" 53 | 54 | # General information about the project. 
55 | project = "autolab_core" 56 | copyright = "2016, Jeff Mahler" 57 | author = "Jeff Mahler" 58 | 59 | # The version info for the project you're documenting, acts as replacement for 60 | # |version| and |release|, also used in various other places throughout the 61 | # built documents. 62 | # 63 | # The short X.Y version. 64 | version = __version__ 65 | # The full version, including alpha/beta/rc tags. 66 | release = __version__ 67 | 68 | # The language for content autogenerated by Sphinx. Refer to documentation 69 | # for a list of supported languages. 70 | # 71 | # This is also used if you do content translation via gettext catalogs. 72 | # Usually you set "language" from the command line for these cases. 73 | language = None 74 | 75 | # There are two options for replacing |today|: either, you set today to some 76 | # non-false value, then it is used: 77 | # today = '' 78 | # Else, today_fmt is used as the format for a strftime call. 79 | # today_fmt = '%B %d, %Y' 80 | 81 | # List of patterns, relative to source directory, that match files and 82 | # directories to ignore when looking for source files. 83 | exclude_patterns = [] 84 | 85 | # The reST default role (used for this markup: `text`) to use for all 86 | # documents. 87 | # default_role = None 88 | 89 | # If true, '()' will be appended to :func: etc. cross-reference text. 90 | # add_function_parentheses = True 91 | 92 | # If true, the current module name will be prepended to all description 93 | # unit titles (such as .. function::). 94 | # add_module_names = True 95 | 96 | # If true, sectionauthor and moduleauthor directives will be shown in the 97 | # output. They are ignored by default. 98 | # show_authors = False 99 | 100 | # The name of the Pygments (syntax highlighting) style to use. 101 | pygments_style = "sphinx" 102 | 103 | # A list of ignored prefixes for module index sorting. 104 | # modindex_common_prefix = [] 105 | 106 | # If true, keep warnings as "system message" paragraphs in the built documents. 107 | # keep_warnings = False 108 | 109 | # If true, `todo` and `todoList` produce output, else they produce nothing. 110 | todo_include_todos = False 111 | 112 | 113 | # -- Options for HTML output ---------------------------------------------- 114 | 115 | # The theme to use for HTML and HTML Help pages. See the documentation for 116 | # a list of builtin themes. 117 | html_theme = "sphinx_rtd_theme" 118 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 119 | 120 | # Theme options are theme-specific and customize the look and feel of a theme 121 | # further. For a list of options available for each theme, see the 122 | # documentation. 123 | # html_theme_options = {} 124 | 125 | # Add any paths that contain custom themes here, relative to this directory. 126 | # html_theme_path = [] 127 | 128 | # The name for this set of Sphinx documents. If None, it defaults to 129 | # " v documentation". 130 | # html_title = None 131 | 132 | # A shorter title for the navigation bar. Default is the same as html_title. 133 | # html_short_title = None 134 | 135 | # The name of an image file (relative to this directory) to place at the top 136 | # of the sidebar. 137 | # html_logo = None 138 | 139 | # The name of an image file (relative to this directory) to use as a favicon of 140 | # the docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 141 | # pixels large. 
142 | # html_favicon = None 143 | 144 | # Add any paths that contain custom static files (such as style sheets) here, 145 | # relative to this directory. They are copied after the builtin static files, 146 | # so a file named "default.css" will overwrite the builtin "default.css". 147 | html_static_path = ["_static"] 148 | 149 | # Add any extra paths that contain custom files (such as robots.txt or 150 | # .htaccess) here, relative to this directory. These files are copied 151 | # directly to the root of the documentation. 152 | # html_extra_path = [] 153 | 154 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 155 | # using the given strftime format. 156 | # html_last_updated_fmt = '%b %d, %Y' 157 | 158 | # If true, SmartyPants will be used to convert quotes and dashes to 159 | # typographically correct entities. 160 | # html_use_smartypants = True 161 | 162 | # Custom sidebar templates, maps document names to template names. 163 | # html_sidebars = {} 164 | 165 | # Additional templates that should be rendered to pages, maps page names to 166 | # template names. 167 | # html_additional_pages = {} 168 | 169 | # If false, no module index is generated. 170 | # html_domain_indices = True 171 | 172 | # If false, no index is generated. 173 | # html_use_index = True 174 | 175 | # If true, the index is split into individual pages for each letter. 176 | # html_split_index = False 177 | 178 | # If true, links to the reST sources are added to the pages. 179 | # html_show_sourcelink = True 180 | 181 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 182 | # html_show_sphinx = True 183 | 184 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 185 | # html_show_copyright = True 186 | 187 | # If true, an OpenSearch description file will be output, and all pages will 188 | # contain a tag referring to it. The value of this option must be the 189 | # base URL from which the finished HTML is served. 190 | # html_use_opensearch = '' 191 | 192 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 193 | # html_file_suffix = None 194 | 195 | # Language to be used for generating the HTML full-text search index. 196 | # Sphinx supports the following languages: 197 | # 'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja' 198 | # 'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr' 199 | # html_search_language = 'en' 200 | 201 | # A dictionary with options for the search language support, empty by default. 202 | # Now only 'ja' uses this config value 203 | # html_search_options = {'type': 'default'} 204 | 205 | # The name of a javascript file (relative to the configuration directory) that 206 | # implements a search results scorer. If empty, the default will be used. 207 | # html_search_scorer = 'scorer.js' 208 | 209 | # Output file base name for HTML help builder. 210 | htmlhelp_basename = "coredoc" 211 | 212 | # -- Options for LaTeX output --------------------------------------------- 213 | 214 | latex_elements = { 215 | # The paper size ('letterpaper' or 'a4paper'). 216 | #'papersize': 'letterpaper', 217 | # The font size ('10pt', '11pt' or '12pt'). 218 | #'pointsize': '10pt', 219 | # Additional stuff for the LaTeX preamble. 220 | #'preamble': '', 221 | # Latex figure (float) alignment 222 | #'figure_align': 'htbp', 223 | } 224 | 225 | # Grouping the document tree into LaTeX files. List of tuples 226 | # (source start file, target name, title, 227 | # author, documentclass [howto, manual, or own class]). 
latex_documents = [
    (
        master_doc,
        "autolab_core.tex",
        "autolab_core Documentation",
        "Jeff Mahler",
        "manual",
    ),
]

# The name of an image file (relative to this directory) to place at the top of
# the title page.
# latex_logo = None

# For "manual" documents, if this is true, then toplevel headings are parts,
# not chapters.
# latex_use_parts = False

# If true, show page references after internal links.
# latex_show_pagerefs = False

# If true, show URL addresses after external links.
# latex_show_urls = False

# Documents to append as an appendix to all manuals.
# latex_appendices = []

# If false, no module index is generated.
# latex_domain_indices = True


# -- Options for manual page output ---------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
    (master_doc, "autolab_core", "autolab_core Documentation", [author], 1)
]

# If true, show URL addresses after external links.
# man_show_urls = False


# -- Options for Texinfo output -------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (
        master_doc,
        "autolab_core",
        "autolab_core Documentation",
        author,
        "autolab_core",
        "Core utilities for the Berkeley AutoLab.",
        "Miscellaneous",
    ),
]

# Documents to append as an appendix to all manuals.
# texinfo_appendices = []

# If false, no module index is generated.
# texinfo_domain_indices = True

# How to display URL addresses: 'footnote', 'no', or 'inline'.
# texinfo_show_urls = 'footnote'

# If true, do not generate a @detailmenu in the "Top" node's menu.
# texinfo_no_detailmenu = False
-------------------------------------------------------------------------------- /docs/source/index.rst: --------------------------------------------------------------------------------
.. core documentation master file, created by
   sphinx-quickstart on Sun Oct 16 14:33:48 2016.
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

Berkeley AutoLab Core Documentation
===================================
Welcome to the documentation for the Berkeley AutoLab's `autolab_core` module!
This module is designed to be useful in a broad set of robotics tasks.

.. toctree::
   :maxdepth: 2
   :caption: Installation Guide

   install/install.rst

.. toctree::
   :maxdepth: 2
   :caption: API Documentation
   :glob:

   api/*

Indices and tables
==================

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`

-------------------------------------------------------------------------------- /docs/source/install/install.rst: --------------------------------------------------------------------------------
Python Installation
~~~~~~~~~~~~~~~~~~~

Option 1: Pip
"""""""""""""
This package is now installable via `pip` ::

    $ pip install autolab_core

Use this option if you aren't interested in using our ROS nodes.

Option 2: Install from Source for ROS
"""""""""""""""""""""""""""""""""""""
The `autolab_core` library can also be used with ROS, as our `RigidTransform`_ class can be used to wrap rigid transformations accessed through `tf`_.
This provides a convenient override of the multiplication operator ::

    T_b_a = RigidTransform.rigid_transform_from_ros(from_frame='a', to_frame='b')
    T_c_b = RigidTransform.rigid_transform_from_ros(from_frame='b', to_frame='c')
    T_c_a = T_c_b * T_b_a

The `RigidTransform`_ class also does automatic checking of frame name compatibility to help prevent bugs.

See the static methods `publish_to_ros`_, `delete_from_ros`_, and `rigid_transform_from_ros`_ of `RigidTransform`_ for more information.
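For example, a transform can be written to TF and read back later. This is
only a sketch -- the exact signatures are documented in the `RigidTransform`_
API docs, and it assumes the listener/publisher nodes from this package are
running ::

    from autolab_core import RigidTransform

    # identity transform from frame 'a' to frame 'b'
    T_b_a = RigidTransform(from_frame='a', to_frame='b')
    T_b_a.publish_to_ros()

    # ... later, possibly from another process
    T_b_a = RigidTransform.rigid_transform_from_ros(from_frame='a', to_frame='b')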
.. _RigidTransform: ../api/rigid_transform.html
.. _tf: http://wiki.ros.org/tf
.. _publish_to_ros: ../api/rigid_transform.html#autolab_core.RigidTransform.publish_to_ros
.. _delete_from_ros: ../api/rigid_transform.html#autolab_core.RigidTransform.delete_from_ros
.. _rigid_transform_from_ros: ../api/rigid_transform.html#autolab_core.RigidTransform.rigid_transform_from_ros

Start by cloning or downloading our source code from `Github`_. ::

    $ cd {PATH_TO_YOUR_CATKIN_WORKSPACE}/src
    $ git clone https://github.com/BerkeleyAutomation/autolab_core.git

.. _Github: https://github.com/BerkeleyAutomation/autolab_core


Change directories into the `autolab_core` repository and run ::

    $ python setup.py install

Finally, run `catkin_make` ::

    $ cd {PATH_TO_YOUR_CATKIN_WORKSPACE}
    $ catkin_make

Then re-source devel/setup.bash for the module to be available through Python.

Documentation
~~~~~~~~~~~~~

Building
""""""""
Building `autolab_core`'s documentation requires a few extra dependencies --
specifically, `sphinx`_ and a few plugins.

.. _sphinx: http://www.sphinx-doc.org/en/1.4.8/

To install the dependencies required, simply change directories into the `autolab_core` source and run ::

    $ pip install .[docs]

Then, go to the `docs` directory and run ``make`` with the appropriate target.
For example, ::

    $ cd docs/
    $ make html

will generate a set of web pages. Any documentation files
generated in this manner can be found in `docs/build`.

Deploying
"""""""""
To deploy documentation to the Github Pages site for the repository,
simply push any changes to the documentation source to master
and then run ::

    $ . gh_deploy.sh

from the `docs` folder. This script will automatically checkout the
``gh-pages`` branch, build the documentation from source, and push it
to Github.

-------------------------------------------------------------------------------- /launch/rigid_transforms.launch: --------------------------------------------------------------------------------
-------------------------------------------------------------------------------- /package.xml: --------------------------------------------------------------------------------
<?xml version="1.0"?>
<package>
  <name>autolab_core</name>
  <version>1.1.0</version>
  <description>The autolab_core package</description>

  <maintainer email="mdanielczuk@berkeley.edu">Mike Danielczuk</maintainer>

  <license>Apache v2.0</license>

  <author email="jmahler@berkeley.edu">Jeff Mahler</author>

  <buildtool_depend>catkin</buildtool_depend>
  <build_depend>rospy</build_depend>
  <build_depend>message_generation</build_depend>
  <run_depend>rospy</run_depend>
  <run_depend>message_runtime</run_depend>

  <export>
  </export>
</package>
-------------------------------------------------------------------------------- /ros_nodes/rigid_transform_listener.py: --------------------------------------------------------------------------------
#!/usr/bin/env python

""" Buffers ROS TF and provides a service to get transforms
"""

import rospy
import tf2_ros

try:
    from autolab_core.srv import (
        RigidTransformListener,
        RigidTransformListenerResponse,
    )
except ImportError:
    raise RuntimeError(
        "rigid_transform_ros_listener service unavailable outside of "
        "catkin package"
    )

if __name__ == "__main__":
    rospy.init_node("rigid_transform_listener")

    tfBuffer = tf2_ros.Buffer()
    listener = tf2_ros.TransformListener(tfBuffer)

    def handle_request(req):
        trans = tfBuffer.lookup_transform(
            req.from_frame, req.to_frame, rospy.Time()
        )
        return RigidTransformListenerResponse(
            trans.transform.translation.x,
            trans.transform.translation.y,
            trans.transform.translation.z,
            trans.transform.rotation.w,
            trans.transform.rotation.x,
            trans.transform.rotation.y,
            trans.transform.rotation.z,
        )

    s = rospy.Service(
        "rigid_transform_listener", RigidTransformListener, handle_request
    )

    rospy.spin()
-------------------------------------------------------------------------------- /ros_nodes/rigid_transform_publisher.py: --------------------------------------------------------------------------------
#!/usr/bin/env python
""" Publisher service that takes transforms and publishes
them to ROS TF periodically
"""

import rospy
import tf2_ros
import tf2_msgs.msg
import geometry_msgs.msg

try:
    from autolab_core.srv import (
        RigidTransformPublisher,
        RigidTransformPublisherResponse,
    )
except ImportError:
    raise RuntimeError(
        "rigid_transform_publisher service is unavailable "
        "outside of catkin package"
    )

if __name__ == "__main__":
    to_publish = {}

    def handle_request(req):
        mode = req.mode.lower()
        transform_key = frozenset((req.from_frame, req.to_frame))
        if mode == "delete":
            if transform_key in to_publish:
                del to_publish[transform_key]
        # compare the lowercased mode in both branches
        elif mode == "frame" or mode == "transform":
            t = geometry_msgs.msg.TransformStamped()

            t.header.stamp = rospy.Time.now()
            t.header.frame_id = req.from_frame
            t.child_frame_id = req.to_frame

            t.transform.translation.x = req.x_trans
            t.transform.translation.y = req.y_trans
            t.transform.translation.z = req.z_trans
            t.transform.rotation.w = req.w_rot
t.transform.rotation.x = req.x_rot 43 | t.transform.rotation.y = req.y_rot 44 | t.transform.rotation.z = req.z_rot 45 | 46 | to_publish[transform_key] = (t, mode) 47 | else: 48 | raise RuntimeError("mode {0} is not supported".format(req.mode)) 49 | return RigidTransformPublisherResponse() 50 | 51 | rospy.init_node("rigid_transform_publisher") 52 | s = rospy.Service( 53 | "rigid_transform_publisher", RigidTransformPublisher, handle_request 54 | ) 55 | 56 | publisher = rospy.Publisher("/tf", tf2_msgs.msg.TFMessage, queue_size=1) 57 | broadcaster = tf2_ros.TransformBroadcaster() 58 | 59 | rate_keeper = rospy.Rate(10) 60 | while not rospy.is_shutdown(): 61 | for transform, mode in to_publish.values(): 62 | transform.header.stamp = rospy.Time.now() 63 | if mode == "frame": 64 | publisher.publish(tf2_msgs.msg.TFMessage([transform])) 65 | elif mode == "transform": 66 | broadcaster.sendTransform(transform) 67 | rate_keeper.sleep() 68 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Setup of core python codebase 3 | Author: Jeff Mahler 4 | """ 5 | import os 6 | from setuptools import setup 7 | 8 | requirements = [ 9 | "numpy", 10 | "scipy", 11 | "scikit-image", 12 | "scikit-learn", 13 | "ruamel.yaml", 14 | "matplotlib", 15 | "multiprocess", 16 | "setproctitle", 17 | "opencv-python", 18 | "Pillow", 19 | "joblib", 20 | "colorlog", 21 | "pyreadline; platform_system=='Windows'", 22 | ] 23 | 24 | # load __version__ without importing anything 25 | version_file = os.path.join( 26 | os.path.dirname(__file__), "autolab_core/version.py" 27 | ) 28 | with open(version_file, "r") as f: 29 | # use eval to get a clean string of version from file 30 | __version__ = eval(f.read().strip().split("=")[-1]) 31 | 32 | setup( 33 | name="autolab_core", 34 | version=__version__, 35 | description="Core utilities for the Berkeley AutoLab", 36 | long_description=( 37 | "Core utilities for the Berkeley AutoLab. " 38 | "Includes rigid transformations, loggers, and 3D data wrappers." 
39 | ), 40 | author="Jeff Mahler", 41 | author_email="jmahler@berkeley.edu", 42 | maintainer="Mike Danielczuk", 43 | maintainer_email="mdanielczuk@berkeley.edu", 44 | license="Apache Software License", 45 | url="https://github.com/BerkeleyAutomation/autolab_core", 46 | keywords="robotics grasping transformations", 47 | classifiers=[ 48 | "Development Status :: 4 - Beta", 49 | "License :: OSI Approved :: Apache Software License", 50 | "Programming Language :: Python", 51 | "Programming Language :: Python :: 3.6", 52 | "Programming Language :: Python :: 3.7", 53 | "Programming Language :: Python :: 3.8", 54 | "Programming Language :: Python :: 3.9", 55 | "Natural Language :: English", 56 | "Topic :: Scientific/Engineering", 57 | ], 58 | packages=["autolab_core"], 59 | install_requires=requirements, 60 | extras_require={ 61 | "docs": ["sphinx", "sphinxcontrib-napoleon", "sphinx_rtd_theme"], 62 | "ros": ["rospkg", "catkin_pkg", "empy"], 63 | }, 64 | ) 65 | -------------------------------------------------------------------------------- /srv/RigidTransformListener.srv: -------------------------------------------------------------------------------- 1 | string from_frame 2 | string to_frame 3 | --- 4 | float64 x_trans 5 | float64 y_trans 6 | float64 z_trans 7 | float64 w_rot 8 | float64 x_rot 9 | float64 y_rot 10 | float64 z_rot -------------------------------------------------------------------------------- /srv/RigidTransformPublisher.srv: -------------------------------------------------------------------------------- 1 | float64 x_trans 2 | float64 y_trans 3 | float64 z_trans 4 | float64 w_rot 5 | float64 x_rot 6 | float64 y_rot 7 | float64 z_rot 8 | string from_frame 9 | string to_frame 10 | string mode 11 | --- -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BerkeleyAutomation/autolab_core/04a75b55d1e8cf51c21a9f84f9ce813cf840351a/tests/__init__.py -------------------------------------------------------------------------------- /tests/constants.py: -------------------------------------------------------------------------------- 1 | IM_HEIGHT = 100 2 | IM_WIDTH = 100 3 | NUM_POINTS = 100 4 | NUM_ITERS = 500 5 | BINARY_THRESH = 127 6 | COLOR_IM_FILEROOT = "tests/data/test_color" 7 | -------------------------------------------------------------------------------- /tests/test_points.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright ©2017. The Regents of the University of California (Regents). 3 | All Rights Reserved. Permission to use, copy, modify, and distribute this 4 | software and its documentation for educational, research, and not-for-profit 5 | purposes, without fee and without a signed licensing agreement, is hereby 6 | granted, provided that the above copyright notice, this paragraph and the 7 | following two paragraphs appear in all copies, modifications, and 8 | distributions. Contact The Office of Technology Licensing, UC Berkeley, 9 | 2150 Shattuck Avenue, Suite 510, Berkeley, CA 94720-1620, (510) 643-7201, 10 | otl@berkeley.edu, http://ipira.berkeley.edu/industry-info for commercial 11 | licensing opportunities. 
12 | 13 | IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, 14 | SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, 15 | ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF 16 | REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 17 | 18 | REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 19 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 20 | PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF ANY, PROVIDED 21 | HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION TO PROVIDE 22 | MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. 23 | 24 | Test basic functionality of the point classes 25 | Authors: Jeff Mahler 26 | """ 27 | import numpy as np 28 | import unittest 29 | 30 | from autolab_core import ( 31 | ImageCoords, 32 | Point, 33 | PointCloud, 34 | RgbCloud, 35 | ) 36 | 37 | 38 | class PointsTest(unittest.TestCase): 39 | def test_inits(self, num_points=10): 40 | # basic init 41 | data = np.random.rand(3, num_points) 42 | p_a = PointCloud(data, "a") 43 | self.assertTrue( 44 | np.abs(data.shape[0] - p_a.shape[0]) < 1e-5, 45 | msg="BagOfPoints has incorrect shape", 46 | ) 47 | self.assertTrue( 48 | np.abs(data.shape[1] - p_a.shape[1]) < 1e-5, 49 | msg="BagOfPoints has incorrect shape", 50 | ) 51 | self.assertTrue( 52 | np.sum(np.abs(data - p_a.data)) < 1e-5, 53 | msg="BagOfPoints has incorrect data", 54 | ) 55 | self.assertTrue( 56 | np.abs(data.shape[0] - p_a.dim) < 1e-5, 57 | msg="BagOfPoints has incorrect dim", 58 | ) 59 | self.assertTrue( 60 | np.abs(data.shape[1] - p_a.num_points) < 1e-5, 61 | msg="BagOfPoints has incorrect num points", 62 | ) 63 | self.assertEqual("a", p_a.frame, msg="BagOfPoints has incorrect frame") 64 | 65 | # point init with multiple points 66 | caught_bad_init = False 67 | try: 68 | data = np.random.rand(3, num_points) 69 | p_a = Point(data, "a") 70 | except ValueError: 71 | caught_bad_init = True 72 | self.assertTrue( 73 | caught_bad_init, 74 | msg="Failed to catch point init with more than one point", 75 | ) 76 | 77 | # point init with bad dim 78 | caught_bad_init = False 79 | try: 80 | data = np.random.rand(3, 3) 81 | p_a = Point(data, "a") 82 | except ValueError: 83 | caught_bad_init = True 84 | self.assertTrue( 85 | caught_bad_init, msg="Failed to catch point init with 3x3" 86 | ) 87 | 88 | # point cloud with bad shape 89 | caught_bad_init = False 90 | try: 91 | data = np.random.rand(3, 3, 3) 92 | p_a = PointCloud(data, "a") 93 | except ValueError: 94 | caught_bad_init = True 95 | self.assertTrue( 96 | caught_bad_init, msg="Failed to catch point cloud init with 3x3x3" 97 | ) 98 | 99 | # point cloud with bad dim 100 | caught_bad_init = False 101 | try: 102 | data = np.random.rand(4, num_points) 103 | p_a = PointCloud(data, "a") 104 | except ValueError: 105 | caught_bad_init = True 106 | self.assertTrue( 107 | caught_bad_init, 108 | msg="Failed to catch point cloud init with 4x%d" % (num_points), 109 | ) 110 | 111 | # point cloud with bad type 112 | caught_bad_init = False 113 | try: 114 | data = 100 * np.random.rand(3, num_points).astype(np.uint8) 115 | p_a = PointCloud(data, "a") 116 | except ValueError: 117 | caught_bad_init = True 118 | self.assertTrue( 119 | caught_bad_init, 120 | msg="Failed to catch point cloud init with uint type", 121 | ) 122 | 123 | # image coords with bad type 124 | caught_bad_init = False 125 | try: 126 | data = np.random.rand(2, num_points) 127 | p_a = ImageCoords(data, "a") 128 | except 
ValueError:
            caught_bad_init = True
        self.assertTrue(
            caught_bad_init,
            msg="Failed to catch image coords init with float type",
        )

        # image coords with bad dim
        caught_bad_init = False
        try:
            data = 100 * np.random.rand(3, num_points).astype(np.uint16)
            p_a = ImageCoords(data, "a")
        except ValueError:
            caught_bad_init = True
        self.assertTrue(
            caught_bad_init,
            msg="Failed to catch image coords init with 3xN array",
        )

        # rgb cloud with bad type
        caught_bad_init = False
        try:
            data = np.random.rand(3, num_points)
            p_a = RgbCloud(data, "a")
        except ValueError:
            caught_bad_init = True
        self.assertTrue(
            caught_bad_init,
            msg="Failed to catch rgb cloud init with float type",
        )

        # rgb cloud with bad dim
        caught_bad_init = False
        try:
            data = 100 * np.random.rand(4, num_points).astype(np.uint16)
            p_a = RgbCloud(data, "a")
        except ValueError:
            caught_bad_init = True
        self.assertTrue(
            caught_bad_init,
            msg="Failed to catch rgb cloud init with 4xN array",
        )

    def test_divs(self, num_points=10):
        data = np.random.rand(3, num_points)
        p_a = PointCloud(data, "a")
        p_b = Point(np.random.rand(3), "b")

        # div on left
        p_a_int = p_a / 5
        assert np.allclose(p_a_int._data, p_a._data / 5)
        p_a_float = p_a / 2.5
        assert np.allclose(p_a_float._data, p_a._data / 2.5)
        p_b_int = p_b / 5
        assert np.allclose(p_b_int._data, p_b._data / 5)
        p_b_float = p_b / 2.5
        assert np.allclose(p_b_float._data, p_b._data / 2.5)

        # div on right
        p_a_int = 5 / p_a
        assert np.allclose(p_a_int._data, 5 / p_a._data)
        p_a_float = 2.5 / p_a
        assert np.allclose(p_a_float._data, 2.5 / p_a._data)
        p_b_int = 5 / p_b
        assert np.allclose(p_b_int._data, 5 / p_b._data)
        p_b_float = 2.5 / p_b
        assert np.allclose(p_b_float._data, 2.5 / p_b._data)


if __name__ == "__main__":
    unittest.main()
-------------------------------------------------------------------------------- /tests/test_registration.py: --------------------------------------------------------------------------------
"""
Tests registration functionality.
3 | Author: Jeff Mahler 4 | """ 5 | import unittest 6 | 7 | import numpy as np 8 | 9 | from .constants import NUM_POINTS, NUM_ITERS 10 | from autolab_core import ( 11 | RigidTransform, 12 | PointCloud, 13 | NormalCloud, 14 | PointToPlaneICPSolver, 15 | PointToPlaneFeatureMatcher, 16 | ) 17 | 18 | 19 | class TestRegistration(unittest.TestCase): 20 | def test_registration(self): 21 | np.random.seed(101) 22 | 23 | source_points = np.random.rand(3, NUM_POINTS).astype(np.float32) 24 | source_normals = np.random.rand(3, NUM_POINTS).astype(np.float32) 25 | source_normals = source_normals / np.tile( 26 | np.linalg.norm(source_normals, axis=0)[np.newaxis, :], [3, 1] 27 | ) 28 | 29 | source_point_cloud = PointCloud(source_points, frame="world") 30 | source_normal_cloud = NormalCloud(source_normals, frame="world") 31 | 32 | matcher = PointToPlaneFeatureMatcher() 33 | solver = PointToPlaneICPSolver(sample_size=NUM_POINTS) 34 | 35 | # 3d registration 36 | tf = RigidTransform( 37 | rotation=RigidTransform.random_rotation(), 38 | translation=RigidTransform.random_translation(), 39 | from_frame="world", 40 | to_frame="world", 41 | ) 42 | tf = RigidTransform( 43 | from_frame="world", to_frame="world" 44 | ).interpolate_with(tf, 0.01) 45 | target_point_cloud = tf * source_point_cloud 46 | target_normal_cloud = tf * source_normal_cloud 47 | 48 | result = solver.register( 49 | source_point_cloud, 50 | target_point_cloud, 51 | source_normal_cloud, 52 | target_normal_cloud, 53 | matcher, 54 | num_iterations=NUM_ITERS, 55 | ) 56 | 57 | self.assertTrue( 58 | np.allclose(tf.matrix, result.T_source_target.matrix, atol=1e-3) 59 | ) 60 | 61 | # 2d registration 62 | theta = 0.1 * np.random.rand() 63 | t = 0.005 * np.random.rand(3, 1) 64 | t[2] = 0 65 | R = np.array( 66 | [ 67 | [np.cos(theta), -np.sin(theta), 0], 68 | [np.sin(theta), np.cos(theta), 0], 69 | [0, 0, 1], 70 | ] 71 | ) 72 | tf = RigidTransform(R, t, from_frame="world", to_frame="world") 73 | target_point_cloud = tf * source_point_cloud 74 | target_normal_cloud = tf * source_normal_cloud 75 | 76 | result = solver.register_2d( 77 | source_point_cloud, 78 | target_point_cloud, 79 | source_normal_cloud, 80 | target_normal_cloud, 81 | matcher, 82 | num_iterations=NUM_ITERS, 83 | ) 84 | 85 | self.assertTrue( 86 | np.allclose(tf.matrix, result.T_source_target.matrix, atol=1e-3) 87 | ) 88 | 89 | 90 | if __name__ == "__main__": 91 | unittest.main() 92 | -------------------------------------------------------------------------------- /tests/test_rigid_transform.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright ©2017. The Regents of the University of California (Regents). 3 | All Rights Reserved. Permission to use, copy, modify, and distribute this 4 | software and its documentation for educational, research, and not-for-profit 5 | purposes, without fee and without a signed licensing agreement, is hereby 6 | granted, provided that the above copyright notice, this paragraph and the 7 | following two paragraphs appear in all copies, modifications, and 8 | distributions. Contact The Office of Technology Licensing, UC Berkeley, 9 | 2150 Shattuck Avenue, Suite 510, Berkeley, CA 94720-1620, (510) 643-7201, 10 | otl@berkeley.edu, http://ipira.berkeley.edu/industry-info for commercial 11 | licensing opportunities. 
12 | 13 | IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, 14 | SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, 15 | ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF 16 | REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 17 | 18 | REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 19 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 20 | PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF ANY, PROVIDED 21 | HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION TO PROVIDE 22 | MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. 23 | 24 | Test correct functionality of the rigid transform class 25 | Authors: Jeff Mahler 26 | """ 27 | import numpy as np 28 | import unittest 29 | 30 | from autolab_core import Point, PointCloud, Direction 31 | from autolab_core import RigidTransform, SimilarityTransform 32 | 33 | 34 | class RigidTransformTest(unittest.TestCase): 35 | def test_init(self): 36 | R = RigidTransform.random_rotation() 37 | t = RigidTransform.random_translation() 38 | from_frame = "a" 39 | to_frame = "b" 40 | T_a_b = RigidTransform(R, t, from_frame, to_frame) 41 | self.assertTrue(np.sum(np.abs(R - T_a_b.rotation)) < 1e-5) 42 | self.assertTrue(np.sum(np.abs(t - T_a_b.translation)) < 1e-5) 43 | 44 | def test_bad_inits(self): 45 | # test bad rotation dim 46 | R = np.random.rand(3) 47 | caught_bad_rotation = False 48 | try: 49 | RigidTransform(R) 50 | except ValueError: 51 | caught_bad_rotation = True 52 | self.assertTrue( 53 | caught_bad_rotation, msg="Failed to catch 3x1 rotation matrix" 54 | ) 55 | 56 | # test bad rotation dim 57 | R = np.random.rand(3, 3, 3) 58 | caught_bad_rotation = False 59 | try: 60 | RigidTransform(R) 61 | except ValueError: 62 | caught_bad_rotation = True 63 | self.assertTrue( 64 | caught_bad_rotation, msg="Failed to catch 3x3x3 rotation matrix" 65 | ) 66 | 67 | # determinant not equal to one 68 | R = np.random.rand(3, 3) 69 | caught_bad_rotation = False 70 | try: 71 | RigidTransform(R) 72 | except ValueError: 73 | caught_bad_rotation = True 74 | self.assertTrue( 75 | caught_bad_rotation, msg="Failed to catch rotation with det != 1" 76 | ) 77 | 78 | # translation with illegal dimensions 79 | t = np.random.rand(3, 3) 80 | caught_bad_translation = False 81 | try: 82 | RigidTransform(translation=t) 83 | except ValueError: 84 | caught_bad_translation = True 85 | self.assertTrue( 86 | caught_bad_translation, msg="Failed to catch 3x3 translation" 87 | ) 88 | 89 | # translation with illegal dimensions 90 | t = np.random.rand(2) 91 | caught_bad_translation = False 92 | try: 93 | RigidTransform(translation=t) 94 | except ValueError: 95 | caught_bad_translation = True 96 | self.assertTrue( 97 | caught_bad_translation, msg="Failed to catch 2x1 translation" 98 | ) 99 | 100 | def test_inverse(self): 101 | R_a_b = RigidTransform.random_rotation() 102 | t_a_b = RigidTransform.random_translation() 103 | T_a_b = RigidTransform(R_a_b, t_a_b, "a", "b") 104 | T_b_a = T_a_b.inverse() 105 | 106 | # multiply with numpy arrays 107 | M_a_b = np.r_[np.c_[R_a_b, t_a_b], [[0, 0, 0, 1]]] 108 | M_b_a = np.linalg.inv(M_a_b) 109 | 110 | self.assertTrue( 111 | np.sum(np.abs(T_b_a.matrix - M_b_a)) < 1e-5, 112 | msg="Inverse gave incorrect transformation", 113 | ) 114 | 115 | # check frames 116 | self.assertEqual( 117 | T_b_a.from_frame, "b", msg="Inverse has incorrect input frame" 118 | ) 119 | self.assertEqual( 120 | T_b_a.to_frame, "a", msg="Inverse
has incorrect output frame" 121 | ) 122 | 123 | def test_composition(self): 124 | R_a_b = RigidTransform.random_rotation() 125 | t_a_b = RigidTransform.random_translation() 126 | R_b_c = RigidTransform.random_rotation() 127 | t_b_c = RigidTransform.random_translation() 128 | T_a_b = RigidTransform(R_a_b, t_a_b, "a", "b") 129 | T_b_c = RigidTransform(R_b_c, t_b_c, "b", "c") 130 | 131 | # multiply with numpy arrays 132 | M_a_b = np.r_[np.c_[R_a_b, t_a_b], [[0, 0, 0, 1]]] 133 | M_b_c = np.r_[np.c_[R_b_c, t_b_c], [[0, 0, 0, 1]]] 134 | M_a_c = M_b_c.dot(M_a_b) 135 | 136 | # use multiplication operator 137 | T_a_c = T_b_c * T_a_b 138 | 139 | self.assertTrue( 140 | np.sum(np.abs(T_a_c.matrix - M_a_c)) < 1e-5, 141 | msg="Composition gave incorrect transformation", 142 | ) 143 | 144 | # check frames 145 | self.assertEqual( 146 | T_a_c.from_frame, "a", msg="Composition has incorrect input frame" 147 | ) 148 | self.assertEqual( 149 | T_a_c.to_frame, "c", msg="Composition has incorrect output frame" 150 | ) 151 | 152 | def test_point_transformation(self): 153 | R_a_b = RigidTransform.random_rotation() 154 | t_a_b = RigidTransform.random_translation() 155 | T_a_b = RigidTransform(R_a_b, t_a_b, "a", "b") 156 | 157 | x_a = np.random.rand(3) 158 | p_a = Point(x_a, "a") 159 | 160 | # multiply with numpy arrays 161 | x_b = R_a_b.dot(x_a) + t_a_b 162 | 163 | # use multiplication operator 164 | p_b = T_a_b * p_a 165 | 166 | self.assertTrue( 167 | np.sum(np.abs(p_b.vector - x_b)) < 1e-5, 168 | msg="Point transformation incorrect: Expected {}, Got {}".format( 169 | x_b, p_b.data 170 | ), 171 | ) 172 | 173 | # check frames 174 | self.assertEqual( 175 | p_b.frame, "b", msg="Transformed point has incorrect frame" 176 | ) 177 | 178 | def test_point_cloud_transformation(self, num_points=10): 179 | R_a_b = RigidTransform.random_rotation() 180 | t_a_b = RigidTransform.random_translation() 181 | T_a_b = RigidTransform(R_a_b, t_a_b, "a", "b") 182 | 183 | x_a = np.random.rand(3, num_points) 184 | pc_a = PointCloud(x_a, "a") 185 | 186 | # multiply with numpy arrays 187 | x_b = R_a_b.dot(x_a) + np.tile(t_a_b.reshape(3, 1), [1, num_points]) 188 | 189 | # use multiplication operator 190 | pc_b = T_a_b * pc_a 191 | 192 | self.assertTrue( 193 | np.sum(np.abs(pc_b.data - x_b)) < 1e-5, 194 | msg="Point cloud transformation incorrect:\n" 195 | "Expected:\n{}\nGot:\n{}".format(x_b, pc_b.data), 196 | ) 197 | 198 | # check frames 199 | self.assertEqual( 200 | pc_b.frame, "b", msg="Transformed point cloud has incorrect frame" 201 | ) 202 | 203 | def test_bad_transformation(self, num_points=10): 204 | R_a_b = RigidTransform.random_rotation() 205 | t_a_b = RigidTransform.random_translation() 206 | T_a_b = RigidTransform(R_a_b, t_a_b, "a", "b") 207 | 208 | # bad point frame 209 | caught_bad_frame = False 210 | try: 211 | x_c = np.random.rand(3) 212 | p_c = Point(x_c, "c") 213 | T_a_b * p_c 214 | except ValueError: 215 | caught_bad_frame = True 216 | self.assertTrue( 217 | caught_bad_frame, msg="Failed to catch bad point frame" 218 | ) 219 | 220 | # bad point cloud frame 221 | caught_bad_frame = False 222 | try: 223 | x_c = np.random.rand(3, num_points) 224 | pc_c = PointCloud(x_c, "c") 225 | T_a_b * pc_c 226 | except ValueError: 227 | caught_bad_frame = True 228 | self.assertTrue( 229 | caught_bad_frame, msg="Failed to catch bad point cloud frame" 230 | ) 231 | 232 | # illegal input 233 | caught_bad_input = False 234 | try: 235 | x_a = np.random.rand(3, num_points) 236 | T_a_b * x_a 237 | except ValueError: 238 | caught_bad_input = True 
239 | self.assertTrue( 240 | caught_bad_input, msg="Failed to catch numpy array input" 241 | ) 242 | 243 | def test_similarity_transformation(self): 244 | R_a_b = RigidTransform.random_rotation() 245 | t_a_b = RigidTransform.random_translation() 246 | s_a_b = 2 * np.random.rand() 247 | R_b_c = RigidTransform.random_rotation() 248 | t_b_c = RigidTransform.random_translation() 249 | s_b_c = 2 * np.random.rand() 250 | T_a_b = SimilarityTransform(R_a_b, t_a_b, s_a_b, "a", "b") 251 | T_b_c = SimilarityTransform(R_b_c, t_b_c, s_b_c, "b", "c") 252 | 253 | T_b_a = T_a_b.inverse() 254 | 255 | x_a = np.random.rand(3) 256 | p_a = Point(x_a, "a") 257 | p_a2 = T_b_a * T_a_b * p_a 258 | self.assertTrue(np.allclose(p_a.data, p_a2.data)) 259 | 260 | p_b = T_a_b * p_a 261 | p_b2 = s_a_b * (R_a_b.dot(p_a.data)) + t_a_b 262 | self.assertTrue(np.allclose(p_b.data, p_b2)) 263 | 264 | p_c = T_b_c * T_a_b * p_a 265 | p_c2 = s_b_c * (R_b_c.dot(p_b2)) + t_b_c 266 | self.assertTrue(np.allclose(p_c.data, p_c2)) 267 | 268 | v_a = np.random.rand(3) 269 | v_a = v_a / np.linalg.norm(v_a) 270 | v_a = Direction(v_a, "a") 271 | v_b = T_a_b * v_a 272 | v_b2 = R_a_b.dot(v_a.data) 273 | self.assertTrue(np.allclose(v_b.data, v_b2)) 274 | 275 | def test_linear_trajectory(self): 276 | R_a = RigidTransform.random_rotation() 277 | t_a = RigidTransform.random_translation() 278 | R_b = RigidTransform.random_rotation() 279 | t_b = RigidTransform.random_translation() 280 | T_a = RigidTransform(R_a, t_a, "w", "a") 281 | T_b = RigidTransform(R_b, t_b, "w", "b") 282 | 283 | for i in range(10): 284 | traj = T_a.linear_trajectory_to(T_b, i) 285 | self.assertEqual(len(traj), i, "Trajectory has incorrect length") 286 | 287 | 288 | if __name__ == "__main__": 289 | unittest.main() 290 | -------------------------------------------------------------------------------- /tools/aggregate_tensor_datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright ©2017. The Regents of the University of California (Regents). 3 | All Rights Reserved. Permission to use, copy, modify, and distribute this 4 | software and its documentation for educational, research, and not-for-profit 5 | purposes, without fee and without a signed licensing agreement, is hereby 6 | granted, provided that the above copyright notice, this paragraph and the 7 | following two paragraphs appear in all copies, modifications, and 8 | distributions. Contact The Office of Technology Licensing, UC Berkeley, 9 | 2150 Shattuck Avenue, Suite 510, Berkeley, CA 94720-1620, (510) 643-7201, 10 | otl@berkeley.edu, http://ipira.berkeley.edu/industry-info for commercial 11 | licensing opportunities. 12 | 13 | IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, 14 | SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, 15 | ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF 16 | REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 17 | 18 | REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 19 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 20 | PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF ANY, PROVIDED 21 | HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION TO PROVIDE 22 | MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. 
23 | 24 | Aggregates a set of tensor datasets, merging them into a single dataset 25 | Author: Jeff Mahler 26 | """ 27 | import argparse 28 | import copy 29 | import logging 30 | import numpy as np 31 | import os 32 | import shutil 33 | 34 | from autolab_core import TensorDataset, YamlConfig 35 | import autolab_core.utils as utils 36 | 37 | if __name__ == "__main__": 38 | # set up logger 39 | logging.getLogger().setLevel(logging.INFO) 40 | 41 | # parse args 42 | parser = argparse.ArgumentParser( 43 | description="Merges a set of tensor datasets" 44 | ) 45 | parser.add_argument( 46 | "--config_filename", 47 | type=str, 48 | default="cfg/tools/aggregate_tensor_datasets.yaml", 49 | help="configuration file to use", 50 | ) 51 | args = parser.parse_args() 52 | config_filename = args.config_filename 53 | 54 | # open config file 55 | cfg = YamlConfig(config_filename) 56 | input_dataset_names = cfg["input_datasets"] 57 | output_dataset_name = cfg["output_dataset"] 58 | display_rate = cfg["display_rate"] 59 | 60 | # modify list of dataset names 61 | all_input_dataset_names = [] 62 | for dataset_name in input_dataset_names: 63 | tensor_dir = os.path.join(dataset_name, "tensors") 64 | if os.path.exists(tensor_dir): 65 | all_input_dataset_names.append(dataset_name) 66 | else: 67 | dataset_subdirs = utils.filenames(dataset_name, tag="dataset_") 68 | all_input_dataset_names.extend(dataset_subdirs) 69 | 70 | # open tensor dataset 71 | dataset = TensorDataset.open(all_input_dataset_names[0]) 72 | tensor_config = copy.deepcopy(dataset.config) 73 | for field_name in cfg["exclude_fields"]: 74 | if field_name in tensor_config["fields"].keys(): 75 | del tensor_config["fields"][field_name] 76 | field_names = tensor_config["fields"].keys() 77 | alt_field_names = [ 78 | f if f != "rewards" else "grasp_metrics" for f in field_names 79 | ] 80 | 81 | # init tensor dataset 82 | output_dataset = TensorDataset(output_dataset_name, tensor_config) 83 | 84 | # copy config 85 | out_config_filename = os.path.join( 86 | output_dataset_name, "merge_config.yaml" 87 | ) 88 | shutil.copyfile(config_filename, out_config_filename) 89 | 90 | # incrementally add points to the new dataset 91 | obj_id = 0 92 | obj_ids = {"unknown": 0} 93 | for dataset_name in all_input_dataset_names: 94 | dataset = TensorDataset.open(dataset_name) 95 | if "obj_ids" in dataset.metadata.keys(): 96 | dataset_obj_ids = dataset.metadata["obj_ids"] 97 | logging.info("Aggregating data from dataset %s" % (dataset_name)) 98 | for i in range(dataset.num_datapoints): 99 | try: 100 | datapoint = dataset.datapoint(i, field_names=field_names) 101 | except IndexError: 102 | datapoint = dataset.datapoint(i, field_names=alt_field_names) 103 | datapoint["rewards"] = datapoint["grasp_metrics"] 104 | del datapoint["grasp_metrics"] 105 | 106 | if i % display_rate == 0: 107 | logging.info( 108 | "Datapoint: %d of %d" % (i + 1, dataset.num_datapoints) 109 | ) 110 | 111 | if "obj_ids" in dataset.metadata.keys(): 112 | # modify object ids 113 | dataset_obj_ids = dataset.metadata["obj_ids"] 114 | for k in range(datapoint["obj_ids"].shape[0]): 115 | dataset_obj_id = datapoint["obj_ids"][k] 116 | if dataset_obj_id != np.iinfo(np.uint32).max: 117 | dataset_obj_key = dataset_obj_ids[str(dataset_obj_id)] 118 | if dataset_obj_key not in obj_ids.keys(): 119 | obj_ids[dataset_obj_key] = obj_id 120 | obj_id += 1 121 | datapoint["obj_ids"][k] = obj_ids[dataset_obj_key] 122 | 123 | # modify grasped obj id 124 | dataset_grasped_obj_id = datapoint["grasped_obj_ids"] 125 | grasped_obj_key
= dataset_obj_ids[str(dataset_grasped_obj_id)] 126 | datapoint["grasped_obj_ids"] = obj_ids[grasped_obj_key] 127 | 128 | # add datapoint 129 | output_dataset.add(datapoint) 130 | 131 | # set metadata 132 | obj_ids = utils.reverse_dictionary(obj_ids) 133 | output_dataset.add_metadata("obj_ids", obj_ids) 134 | for field_name, field_data in dataset.metadata.items(): 135 | if field_name not in ["obj_ids"]: 136 | output_dataset.add_metadata(field_name, field_data) 137 | 138 | # flush to disk 139 | output_dataset.flush() 140 | -------------------------------------------------------------------------------- /tools/compute_dataset_statistics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright ©2017. The Regents of the University of California (Regents). 3 | All Rights Reserved. Permission to use, copy, modify, and distribute this 4 | software and its documentation for educational, research, and not-for-profit 5 | purposes, without fee and without a signed licensing agreement, is hereby 6 | granted, provided that the above copyright notice, this paragraph and the 7 | following two paragraphs appear in all copies, modifications, and 8 | distributions. Contact The Office of Technology Licensing, UC Berkeley, 9 | 2150 Shattuck Avenue, Suite 510, Berkeley, CA 94720-1620, (510) 643-7201, 10 | otl@berkeley.edu, http://ipira.berkeley.edu/industry-info for commercial 11 | licensing opportunities. 12 | 13 | IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, 14 | SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, 15 | ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF 16 | REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 17 | 18 | REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 19 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 20 | PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF ANY, PROVIDED 21 | HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION TO PROVIDE 22 | MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
23 | 24 | Computes statistics of select fields of a TensorDataset 25 | Author: Jeff Mahler 26 | """ 27 | import argparse 28 | import json 29 | import logging 30 | import numpy as np 31 | import os 32 | import random 33 | 34 | from autolab_core import TensorDataset, YamlConfig 35 | import autolab_core.utils as utils 36 | 37 | from visualization import Visualizer2D as vis2d 38 | 39 | SEED = 2345 40 | 41 | 42 | def compute_dataset_statistics(dataset_path, output_path, config): 43 | """ 44 | Compute the statistics of fields of a TensorDataset 45 | 46 | Parameters 47 | ---------- 48 | dataset_path : str 49 | path to the dataset 50 | output_path : str 51 | where to save the data 52 | config : :obj:`YamlConfig` 53 | parameters for the analysis 54 | """ 55 | # parse config 56 | analysis_fields = config["analysis_fields"] 57 | num_percentiles = config["num_percentiles"] 58 | thresholds = config["thresholds"] 59 | log_rate = config["log_rate"] 60 | 61 | num_bins = config["num_bins"] 62 | font_size = config["font_size"] 63 | dpi = config["dpi"] 64 | 65 | # open the dataset to analyze 66 | dataset = TensorDataset.open(dataset_path) 67 | num_datapoints = dataset.num_datapoints 68 | 69 | # allocate buffers 70 | analysis_data = {} 71 | for field in analysis_fields: 72 | analysis_data[field] = [] 73 | 74 | # loop through dataset 75 | for i in range(num_datapoints): 76 | if i % log_rate == 0: 77 | logging.info( 78 | "Reading datapoint %d of %d" % (i + 1, num_datapoints) 79 | ) 80 | 81 | # read datapoint 82 | datapoint = dataset.datapoint(i, analysis_fields) 83 | for key, value in datapoint.items(): 84 | analysis_data[key].append(value) 85 | 86 | # define headers for the output statistics 87 | stats_headers = { 88 | "name": "str", 89 | "mean": "float", 90 | "median": "float", 91 | "std": "float", 92 | } 93 | for i in range(num_percentiles): 94 | pctile = int((100.0 / num_percentiles) * i) 95 | field = "%d_pctile" % (pctile) 96 | stats_headers[field] = "float" 97 | for t in thresholds: 98 | field = "pct_above_%.3f" % (t) 99 | stats_headers[field] = "float" 100 | 101 | # analyze statistics 102 | for field, data in analysis_data.items(): 103 | # init arrays 104 | data = np.array(data) 105 | 106 | # init filename 107 | stats_filename = os.path.join(output_path, "%s_stats.json" % (field)) 108 | if os.path.exists(stats_filename): 109 | logging.warning("Statistics file %s exists!"
% (stats_filename)) 110 | 111 | # stats 112 | mean = np.mean(data) 113 | median = np.median(data) 114 | std = np.std(data) 115 | stats = { 116 | "name": str(field), 117 | "mean": float(mean), 118 | "median": float(median), 119 | "std": float(std), 120 | } 121 | for i in range(num_percentiles): 122 | pctile = int((100.0 / num_percentiles) * i) 123 | pctile_field = "%d_pctile" % (pctile) 124 | stats[pctile_field] = float(np.percentile(data, pctile)) 125 | for t in thresholds: 126 | t_field = "pct_above_%.3f" % (t) 127 | stats[t_field] = float(np.mean(1 * (data > t))) 128 | json.dump(stats, open(stats_filename, "w"), indent=2, sort_keys=True) 129 | 130 | # histogram 131 | num_unique = np.unique(data).shape[0] 132 | nb = min(num_bins, data.shape[0], num_unique) 133 | bounds = (np.min(data), np.max(data)) 134 | vis2d.figure() 135 | utils.histogram(data, nb, bounds, normalized=False, plot=True) 136 | vis2d.xlabel(field, fontsize=font_size) 137 | vis2d.ylabel("Count", fontsize=font_size) 138 | data_filename = os.path.join(output_path, "histogram_%s.pdf" % (field)) 139 | vis2d.show(data_filename, dpi=dpi) 140 | 141 | 142 | if __name__ == "__main__": 143 | # initialize logging 144 | logging.getLogger().setLevel(logging.INFO) 145 | 146 | # parse args 147 | parser = argparse.ArgumentParser( 148 | description="Compute statistics of select fields of a tensor dataset" 149 | ) 150 | parser.add_argument( 151 | "dataset_path", 152 | type=str, 153 | default=None, 154 | help="path to an experiment dataset", 155 | ) 156 | parser.add_argument( 157 | "--output_path", 158 | type=str, 159 | default=None, 160 | help="path to save dataset statistics", 161 | ) 162 | parser.add_argument( 163 | "--debug", 164 | type=bool, 165 | default=True, 166 | help="whether to set the random seed", 167 | ) 168 | parser.add_argument( 169 | "--config_filename", 170 | type=str, 171 | default=None, 172 | help="configuration file to use", 173 | ) 174 | args = parser.parse_args() 175 | dataset_path = args.dataset_path 176 | output_path = args.output_path 177 | debug = args.debug 178 | config_filename = args.config_filename 179 | 180 | # auto-save in dataset 181 | if output_path is None: 182 | output_path = os.path.join(dataset_path, "stats") 183 | 184 | # create output dir 185 | if not os.path.exists(output_path): 186 | os.mkdir(output_path) 187 | 188 | # set random seed 189 | if debug: 190 | np.random.seed(SEED) 191 | random.seed(SEED) 192 | 193 | # handle config filename 194 | if config_filename is None: 195 | config_filename = os.path.join( 196 | os.path.dirname(os.path.realpath(__file__)), 197 | "..", 198 | "cfg/tools/compute_dataset_statistics.yaml", 199 | ) 200 | 201 | # turn relative paths absolute 202 | if not os.path.isabs(config_filename): 203 | config_filename = os.path.join(os.getcwd(), config_filename) 204 | 205 | # load config 206 | config = YamlConfig(config_filename) 207 | 208 | # run analysis 209 | compute_dataset_statistics(dataset_path, output_path, config) 210 | -------------------------------------------------------------------------------- /tools/convert_legacy_dataset_to_tensor_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright ©2017. The Regents of the University of California (Regents). 3 | All Rights Reserved. 
Permission to use, copy, modify, and distribute this 4 | software and its documentation for educational, research, and not-for-profit 5 | purposes, without fee and without a signed licensing agreement, is hereby 6 | granted, provided that the above copyright notice, this paragraph and the 7 | following two paragraphs appear in all copies, modifications, and 8 | distributions. Contact The Office of Technology Licensing, UC Berkeley, 9 | 2150 Shattuck Avenue, Suite 510, Berkeley, CA 94720-1620, (510) 643-7201, 10 | otl@berkeley.edu, http://ipira.berkeley.edu/industry-info for commercial 11 | licensing opportunities. 12 | 13 | IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, 14 | SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, 15 | ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF 16 | REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 17 | 18 | REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 19 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 20 | PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF ANY, PROVIDED 21 | HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION TO PROVIDE 22 | MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. 23 | 24 | Converts datasets from the old format to a format readable by TensorDataset 25 | Author: Jeff Mahler 26 | """ 27 | import argparse 28 | import json 29 | import numpy as np 30 | import os 31 | import shutil 32 | import logging 33 | 34 | from autolab_core.constants import JSON_INDENT 35 | import autolab_core.utils as utils 36 | 37 | if __name__ == "__main__": 38 | # initialize logging 39 | logging.getLogger().setLevel(logging.INFO) 40 | 41 | # parse args 42 | parser = argparse.ArgumentParser( 43 | description="Convert a legacy dataset to TensorDataset (in-place)" 44 | ) 45 | parser.add_argument( 46 | "dataset_dir", type=str, default=None, help="path to the legacy dataset to convert" 47 | ) 48 | args = parser.parse_args() 49 | dataset_dir = args.dataset_dir 50 | 51 | # read filenames 52 | filenames = utils.filenames(dataset_dir) 53 | 54 | # create config file 55 | datapoints_per_file = None 56 | field_config = {} 57 | for filename in filenames: 58 | _, f = os.path.split(filename) 59 | _, ext = os.path.splitext(f) 60 | if ext != ".npz": 61 | continue 62 | 63 | u_ind = f.rfind("_") 64 | field_name = f[:u_ind] 65 | 66 | if field_name not in field_config.keys(): 67 | field_config[field_name] = {} 68 | data = np.load(filename)["arr_0"] 69 | if datapoints_per_file is None: 70 | datapoints_per_file = data.shape[0] 71 | dtype = str(data.dtype) 72 | field_config[field_name]["dtype"] = dtype 73 | if len(data.shape) > 1: 74 | height = data.shape[1] 75 | field_config[field_name]["height"] = height 76 | if len(data.shape) > 2: 77 | width = data.shape[2] 78 | field_config[field_name]["width"] = width 79 | if len(data.shape) > 3: 80 | channels = data.shape[3] 81 | field_config[field_name]["channels"] = channels 82 | 83 | # write tensor dataset headers 84 | tensor_config = { 85 | "datapoints_per_file": datapoints_per_file, 86 | "fields": field_config, 87 | } 88 | 89 | config_filename = os.path.join(dataset_dir, "config.json") 90 | json.dump( 91 | tensor_config, 92 | open(config_filename, "w"), 93 | indent=JSON_INDENT, 94 | sort_keys=True, 95 | ) 96 | 97 | metadata_filename = os.path.join(dataset_dir, "metadata.json") 98 | json.dump( 99 | {}, open(metadata_filename, "w"), indent=JSON_INDENT, sort_keys=True 100 | ) 101 | 102 | tensor_dir
= os.path.join(dataset_dir, "tensors") 103 | if not os.path.exists(tensor_dir): 104 | os.mkdir(tensor_dir) 105 | 106 | # move each individual file 107 | for filename in filenames: 108 | logging.info("Moving file {}".format(filename)) 109 | if filename != tensor_dir: 110 | _, f = os.path.split(filename) 111 | new_filename = os.path.join(tensor_dir, f) 112 | shutil.move(filename, new_filename) 113 | -------------------------------------------------------------------------------- /tools/shuffle_tensor_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright ©2017. The Regents of the University of California (Regents). 3 | All Rights Reserved. Permission to use, copy, modify, and distribute this 4 | software and its documentation for educational, research, and not-for-profit 5 | purposes, without fee and without a signed licensing agreement, is hereby 6 | granted, provided that the above copyright notice, this paragraph and the 7 | following two paragraphs appear in all copies, modifications, and 8 | distributions. Contact The Office of Technology Licensing, UC Berkeley, 9 | 2150 Shattuck Avenue, Suite 510, Berkeley, CA 94720-1620, (510) 643-7201, 10 | otl@berkeley.edu, http://ipira.berkeley.edu/industry-info for commercial 11 | licensing opportunities. 12 | 13 | IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, 14 | SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, 15 | ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF 16 | REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 17 | 18 | REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 19 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 20 | PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF ANY, PROVIDED 21 | HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION TO PROVIDE 22 | MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. 23 | 24 | Shuffles a TensorDataset.
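A typical invocation might look like the following (the paths are hypothetical placeholders for the two positional arguments parsed below): python tools/shuffle_tensor_dataset.py data/my_dataset data/my_dataset_shuffled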
25 | Author: Jeff Mahler 26 | """ 27 | import argparse 28 | import logging 29 | import numpy as np 30 | 31 | from autolab_core import TensorDataset 32 | 33 | if __name__ == "__main__": 34 | # initialize logging 35 | logging.getLogger().setLevel(logging.INFO) 36 | 37 | # parse args 38 | parser = argparse.ArgumentParser(description="Shuffles a dataset") 39 | parser.add_argument( 40 | "dataset_path", 41 | type=str, 42 | default=None, 43 | help="directory of the dataset to shuffle", 44 | ) 45 | parser.add_argument( 46 | "output_path", 47 | type=str, 48 | default=None, 49 | help="directory to store the shuffled dataset", 50 | ) 51 | args = parser.parse_args() 52 | dataset_path = args.dataset_path 53 | output_path = args.output_path 54 | 55 | dataset = TensorDataset.open(dataset_path) 56 | out_dataset = TensorDataset(output_path, dataset.config) 57 | 58 | ind = np.arange(dataset.num_datapoints) 59 | np.random.shuffle(ind) 60 | 61 | for i, j in enumerate(ind): 62 | logging.info("Saving datapoint %d" % (i)) 63 | datapoint = dataset[j] 64 | out_dataset.add(datapoint) 65 | out_dataset.flush() 66 | 67 | for split_name in dataset.split_names: 68 | _, val_indices, _ = dataset.split(split_name) 69 | new_val_indices = [] 70 | for i in range(ind.shape[0]): 71 | if ind[i] in val_indices: 72 | new_val_indices.append(i) 73 | 74 | out_dataset.make_split(split_name, val_indices=new_val_indices) 75 | -------------------------------------------------------------------------------- /tools/split_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright ©2017. The Regents of the University of California (Regents). 3 | All Rights Reserved. Permission to use, copy, modify, and distribute this 4 | software and its documentation for educational, research, and not-for-profit 5 | purposes, without fee and without a signed licensing agreement, is hereby 6 | granted, provided that the above copyright notice, this paragraph and the 7 | following two paragraphs appear in all copies, modifications, and 8 | distributions. Contact The Office of Technology Licensing, UC Berkeley, 9 | 2150 Shattuck Avenue, Suite 510, Berkeley, CA 94720-1620, (510) 643-7201, 10 | otl@berkeley.edu, http://ipira.berkeley.edu/industry-info for commercial 11 | licensing opportunities. 12 | 13 | IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, 14 | SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, 15 | ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF 16 | REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 17 | 18 | REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 19 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 20 | PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF ANY, PROVIDED 21 | HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION TO PROVIDE 22 | MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
23 | 24 | Makes a new split for a TensorDataset 25 | 26 | Author 27 | ------ 28 | Jeff Mahler 29 | """ 30 | import argparse 31 | import logging 32 | 33 | from autolab_core import TensorDataset 34 | 35 | 36 | if __name__ == "__main__": 37 | # setup logger 38 | logging.getLogger().setLevel(logging.INFO) 39 | 40 | # parse args 41 | parser = argparse.ArgumentParser( 42 | description="Split a training TensorDataset based on an attribute" 43 | ) 44 | parser.add_argument( 45 | "dataset_dir", 46 | type=str, 47 | default=None, 48 | help="path to the dataset to use for training and validation", 49 | ) 50 | parser.add_argument( 51 | "split_name", type=str, default=None, help="name to use for the split" 52 | ) 53 | parser.add_argument( 54 | "--train_pct", 55 | type=float, 56 | default=0.8, 57 | help="percent of data to use for training", 58 | ) 59 | parser.add_argument( 60 | "--field_name", 61 | type=str, 62 | default=None, 63 | help="name of the field to split on", 64 | ) 65 | args = parser.parse_args() 66 | dataset_dir = args.dataset_dir 67 | split_name = args.split_name 68 | train_pct = args.train_pct 69 | field_name = args.field_name 70 | 71 | # create split 72 | dataset = TensorDataset.open(dataset_dir) 73 | train_indices, val_indices = dataset.make_split( 74 | split_name, train_pct=train_pct, field_name=field_name 75 | ) 76 | -------------------------------------------------------------------------------- /tools/subsample_tensor_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright ©2017. The Regents of the University of California (Regents). 3 | All Rights Reserved. Permission to use, copy, modify, and distribute this 4 | software and its documentation for educational, research, and not-for-profit 5 | purposes, without fee and without a signed licensing agreement, is hereby 6 | granted, provided that the above copyright notice, this paragraph and the 7 | following two paragraphs appear in all copies, modifications, and 8 | distributions. Contact The Office of Technology Licensing, UC Berkeley, 9 | 2150 Shattuck Avenue, Suite 510, Berkeley, CA 94720-1620, (510) 643-7201, 10 | otl@berkeley.edu, http://ipira.berkeley.edu/industry-info for commercial 11 | licensing opportunities. 12 | 13 | IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, 14 | SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, 15 | ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF 16 | REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 17 | 18 | REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 19 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 20 | PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF ANY, PROVIDED 21 | HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION TO PROVIDE 22 | MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. 23 | 24 | Subsamples a TensorDataset. 
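A typical invocation might look like the following (the paths and datapoint count are hypothetical placeholders for the three positional arguments parsed below): python tools/subsample_tensor_dataset.py data/my_dataset data/my_dataset_small 1000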
25 | Author: Jeff Mahler 26 | """ 27 | import argparse 28 | import logging 29 | import numpy as np 30 | 31 | from autolab_core import TensorDataset 32 | 33 | if __name__ == "__main__": 34 | # initialize logging 35 | logging.getLogger().setLevel(logging.INFO) 36 | 37 | # parse args 38 | parser = argparse.ArgumentParser(description="Subsamples a dataset") 39 | parser.add_argument( 40 | "dataset_path", 41 | type=str, 42 | default=None, 43 | help="directory of the dataset to subsample", 44 | ) 45 | parser.add_argument( 46 | "output_path", 47 | type=str, 48 | default=None, 49 | help="directory to store the subsampled dataset", 50 | ) 51 | parser.add_argument( 52 | "num_datapoints", 53 | type=int, 54 | default=None, 55 | help="number of datapoints to subsample", 56 | ) 57 | args = parser.parse_args() 58 | dataset_path = args.dataset_path 59 | output_path = args.output_path 60 | num_datapoints = args.num_datapoints 61 | 62 | dataset = TensorDataset.open(dataset_path) 63 | out_dataset = TensorDataset(output_path, dataset.config) 64 | 65 | num_datapoints = min(num_datapoints, dataset.num_datapoints) 66 | 67 | ind = np.arange(dataset.num_datapoints) 68 | np.random.shuffle(ind) 69 | ind = ind[:num_datapoints] 70 | ind = np.sort(ind) 71 | 72 | for i in ind: 73 | logging.info("Saving datapoint %d" % (i)) 74 | datapoint = dataset[i] 75 | out_dataset.add(datapoint) 76 | 77 | out_dataset.flush() 78 | --------------------------------------------------------------------------------
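The tools above all follow the same TensorDataset read/write pattern: open an existing dataset, create a new one from the same config, copy datapoints across, and flush to disk. The sketch below condenses that pattern for reference; the dataset paths are hypothetical placeholders, and only calls that appear in the scripts above (TensorDataset.open, indexing, add, add_metadata, flush) are used.

# Minimal sketch of the read/write pattern shared by the tools above.
# Paths are hypothetical; the calls mirror those used in the scripts.
from autolab_core import TensorDataset

in_dataset = TensorDataset.open("data/my_dataset")  # read-only source
out_dataset = TensorDataset("data/my_dataset_copy", in_dataset.config)

for i in range(in_dataset.num_datapoints):
    datapoint = in_dataset[i]   # read one datapoint with all fields
    out_dataset.add(datapoint)  # buffer it in the output dataset

# carry metadata over, then write any buffered tensors to disk
for name, value in in_dataset.metadata.items():
    out_dataset.add_metadata(name, value)
out_dataset.flush()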