├── .clang-format ├── .github └── workflows │ └── python-package.yml ├── .gitignore ├── CITATION.cff ├── LICENSE.txt ├── README.md ├── pyproject.toml ├── src └── pytorch_kinematics │ ├── __init__.py │ ├── cfg.py │ ├── chain.py │ ├── frame.py │ ├── ik.py │ ├── jacobian.py │ ├── mjcf.py │ ├── sdf.py │ ├── transforms │ ├── __init__.py │ ├── math.py │ ├── perturbation.py │ ├── rotation_conversions.py │ ├── so3.py │ └── transform3d.py │ ├── urdf.py │ └── urdf_parser_py │ ├── __init__.py │ ├── sdf.py │ ├── urdf.py │ └── xml_reflection │ ├── __init__.py │ ├── basics.py │ └── core.py └── tests ├── __init__.py ├── ant.xml ├── gen_fk_perf.py ├── hopper.xml ├── humanoid.xml ├── joint_limit_robot.urdf ├── joint_no_limit_robot.urdf ├── kuka_iiwa.urdf ├── kuka_iiwa.xml ├── link_0.stl ├── link_1.stl ├── link_2.stl ├── link_3.stl ├── link_4.stl ├── link_5.stl ├── link_6.stl ├── link_7.stl ├── meshes ├── cone.mtl ├── cone.obj ├── link_0.obj ├── link_1.obj ├── link_2.obj ├── link_3.obj ├── link_4.obj ├── link_5.obj ├── link_6.obj └── link_7.obj ├── prismatic_robot.urdf ├── simple_arm.sdf ├── simple_y_arm.urdf ├── test_attributes.py ├── test_inverse_kinematics.py ├── test_jacobian.py ├── test_kinematics.py ├── test_menagerie.py ├── test_rotation_conversions.py ├── test_serial_chain_creation.py ├── test_transform.py ├── ur5.urdf ├── val.xml ├── viz_fk_perf.ipynb └── widowx ├── README.md ├── interbotix_black.png ├── meshes_wx250s ├── WXSA-250-M-1-Base.mtl ├── WXSA-250-M-1-Base.obj ├── WXSA-250-M-1-Base.stl ├── WXSA-250-M-10-Finger.mtl ├── WXSA-250-M-10-Finger.obj ├── WXSA-250-M-10-Finger.stl ├── WXSA-250-M-2-Shoulder.mtl ├── WXSA-250-M-2-Shoulder.obj ├── WXSA-250-M-2-Shoulder.stl ├── WXSA-250-M-3-UA.mtl ├── WXSA-250-M-3-UA.obj ├── WXSA-250-M-3-UA.stl ├── WXSA-250-M-4-UF.mtl ├── WXSA-250-M-4-UF.obj ├── WXSA-250-M-4-UF.stl ├── WXSA-250-M-5-LF.mtl ├── WXSA-250-M-5-LF.obj ├── WXSA-250-M-5-LF.stl ├── WXSA-250-M-6-Wrist.mtl ├── WXSA-250-M-6-Wrist.obj ├── WXSA-250-M-6-Wrist.stl ├── 
WXSA-250-M-7-Gripper.mtl ├── WXSA-250-M-7-Gripper.obj ├── WXSA-250-M-7-Gripper.stl ├── WXSA-250-M-8-Gripper-Prop.mtl ├── WXSA-250-M-8-Gripper-Prop.obj ├── WXSA-250-M-8-Gripper-Prop.stl ├── WXSA-250-M-9-Gripper-Bar.mtl ├── WXSA-250-M-9-Gripper-Bar.obj └── WXSA-250-M-9-Gripper-Bar.stl ├── wx250s.srdf └── wx250s.urdf /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | # BasedOnStyle: Google 4 | AccessModifierOffset: -1 5 | AlignAfterOpenBracket: Align 6 | AlignConsecutiveMacros: false 7 | AlignConsecutiveAssignments: false 8 | AlignConsecutiveDeclarations: false 9 | AlignEscapedNewlines: Left 10 | AlignOperands: true 11 | AlignTrailingComments: true 12 | AllowAllArgumentsOnNextLine: true 13 | AllowAllConstructorInitializersOnNextLine: true 14 | AllowAllParametersOfDeclarationOnNextLine: true 15 | AllowShortBlocksOnASingleLine: Never 16 | AllowShortCaseLabelsOnASingleLine: false 17 | AllowShortFunctionsOnASingleLine: All 18 | AllowShortLambdasOnASingleLine: All 19 | AllowShortIfStatementsOnASingleLine: WithoutElse 20 | AllowShortLoopsOnASingleLine: true 21 | AlwaysBreakAfterDefinitionReturnType: None 22 | AlwaysBreakAfterReturnType: None 23 | AlwaysBreakBeforeMultilineStrings: true 24 | AlwaysBreakTemplateDeclarations: Yes 25 | BinPackArguments: true 26 | BinPackParameters: true 27 | BraceWrapping: 28 | AfterCaseLabel: false 29 | AfterClass: false 30 | AfterControlStatement: false 31 | AfterEnum: false 32 | AfterFunction: false 33 | AfterNamespace: false 34 | AfterObjCDeclaration: false 35 | AfterStruct: false 36 | AfterUnion: false 37 | AfterExternBlock: false 38 | BeforeCatch: false 39 | BeforeElse: false 40 | IndentBraces: false 41 | SplitEmptyFunction: true 42 | SplitEmptyRecord: true 43 | SplitEmptyNamespace: true 44 | BreakBeforeBinaryOperators: None 45 | BreakBeforeBraces: Attach 46 | BreakBeforeInheritanceComma: false 47 | BreakInheritanceList: BeforeColon 48 | 
BreakBeforeTernaryOperators: true 49 | BreakConstructorInitializersBeforeComma: false 50 | BreakConstructorInitializers: BeforeColon 51 | BreakAfterJavaFieldAnnotations: false 52 | BreakStringLiterals: true 53 | ColumnLimit: 120 54 | CommentPragmas: '^ IWYU pragma:' 55 | CompactNamespaces: false 56 | ConstructorInitializerAllOnOneLineOrOnePerLine: true 57 | ConstructorInitializerIndentWidth: 4 58 | ContinuationIndentWidth: 4 59 | Cpp11BracedListStyle: true 60 | DeriveLineEnding: true 61 | DerivePointerAlignment: true 62 | DisableFormat: false 63 | ExperimentalAutoDetectBinPacking: false 64 | FixNamespaceComments: true 65 | ForEachMacros: 66 | - foreach 67 | - Q_FOREACH 68 | - BOOST_FOREACH 69 | IncludeBlocks: Regroup 70 | IncludeCategories: 71 | - Regex: '^' 72 | Priority: 2 73 | SortPriority: 0 74 | - Regex: '^<.*\.h>' 75 | Priority: 1 76 | SortPriority: 0 77 | - Regex: '^<.*' 78 | Priority: 2 79 | SortPriority: 0 80 | - Regex: '.*' 81 | Priority: 3 82 | SortPriority: 0 83 | IncludeIsMainRegex: '([-_](test|unittest))?$' 84 | IncludeIsMainSourceRegex: '' 85 | IndentCaseLabels: true 86 | IndentGotoLabels: true 87 | IndentPPDirectives: None 88 | IndentWidth: 2 89 | IndentWrappedFunctionNames: false 90 | JavaScriptQuotes: Leave 91 | JavaScriptWrapImports: true 92 | KeepEmptyLinesAtTheStartOfBlocks: false 93 | MacroBlockBegin: '' 94 | MacroBlockEnd: '' 95 | MaxEmptyLinesToKeep: 1 96 | NamespaceIndentation: None 97 | ObjCBinPackProtocolList: Never 98 | ObjCBlockIndentWidth: 2 99 | ObjCSpaceAfterProperty: false 100 | ObjCSpaceBeforeProtocolList: true 101 | PenaltyBreakAssignment: 2 102 | PenaltyBreakBeforeFirstCallParameter: 1 103 | PenaltyBreakComment: 300 104 | PenaltyBreakFirstLessLess: 120 105 | PenaltyBreakString: 1000 106 | PenaltyBreakTemplateDeclaration: 10 107 | PenaltyExcessCharacter: 1000000 108 | PenaltyReturnTypeOnItsOwnLine: 200 109 | PointerAlignment: Left 110 | RawStringFormats: 111 | - Language: Cpp 112 | Delimiters: 113 | - cc 114 | - CC 115 | - cpp 116 
| - Cpp 117 | - CPP 118 | - 'c++' 119 | - 'C++' 120 | CanonicalDelimiter: '' 121 | BasedOnStyle: google 122 | - Language: TextProto 123 | Delimiters: 124 | - pb 125 | - PB 126 | - proto 127 | - PROTO 128 | EnclosingFunctions: 129 | - EqualsProto 130 | - EquivToProto 131 | - PARSE_PARTIAL_TEXT_PROTO 132 | - PARSE_TEST_PROTO 133 | - PARSE_TEXT_PROTO 134 | - ParseTextOrDie 135 | - ParseTextProtoOrDie 136 | CanonicalDelimiter: '' 137 | BasedOnStyle: google 138 | ReflowComments: true 139 | SortIncludes: true 140 | SortUsingDeclarations: true 141 | SpaceAfterCStyleCast: false 142 | SpaceAfterLogicalNot: false 143 | SpaceAfterTemplateKeyword: true 144 | SpaceBeforeAssignmentOperators: true 145 | SpaceBeforeCpp11BracedList: false 146 | SpaceBeforeCtorInitializerColon: true 147 | SpaceBeforeInheritanceColon: true 148 | SpaceBeforeParens: ControlStatements 149 | SpaceBeforeRangeBasedForLoopColon: true 150 | SpaceInEmptyBlock: false 151 | SpaceInEmptyParentheses: false 152 | SpacesBeforeTrailingComments: 2 153 | SpacesInAngles: false 154 | SpacesInConditionalStatement: false 155 | SpacesInContainerLiterals: true 156 | SpacesInCStyleCastParentheses: false 157 | SpacesInParentheses: false 158 | SpacesInSquareBrackets: false 159 | SpaceBeforeSquareBrackets: false 160 | Standard: Auto 161 | StatementMacros: 162 | - Q_UNUSED 163 | - QT_REQUIRE_VERSION 164 | TabWidth: 8 165 | UseCRLF: false 166 | UseTab: Never 167 | ... 
168 | 169 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ "master" ] 9 | pull_request: 10 | branches: [ "master" ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: [ "3.8", "3.9", "3.10" ] 20 | 21 | steps: 22 | - uses: actions/checkout@v3 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v3 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install .[test] 31 | python -m pip install flake8 pytest 32 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 33 | pip install mujoco 34 | - name: Clone mujoco_menagerie repository into the tests/ folder 35 | run: | 36 | git clone https://github.com/google-deepmind/mujoco_menagerie 37 | working-directory: ${{ runner.workspace }}/pytorch_kinematics/tests 38 | - name: Lint with flake8 39 | run: | 40 | # stop the build if there are Python syntax errors or undefined names 41 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 42 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 43 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 44 | - name: Test with pytest 45 | run: | 46 | pytest --ignore=tests/mujoco_menagerie 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.mp4 3 | *.so 4 | *.pkl 5 | *.egg-info 6 | __pycache__ 7 | temp* 8 | build 9 | dist 10 | *.pyc 11 | # These are cloned/generated when testing with mujoco 12 | tests/MUJOCO_LOG.TXT 13 | tests/mujoco_menagerie/ 14 | .ipynb_checkpoints 15 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.1.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: Zhong 5 | given-names: Sheng 6 | orcid: https://orcid.org/0000-0002-8658-3061 7 | - family-names: Power 8 | given-names: Thomas 9 | orcid: https://orcid.org/0000-0002-2439-3262 10 | - family-names: Gupta 11 | given-names: Ashwin 12 | - family-names: Mitrano 13 | given-names: Peter 14 | orcid: https://orcid.org/0000-0002-8701-9809 15 | title: PyTorch Kinematics 16 | doi: 10.5281/zenodo.7700587 17 | version: v0.7.1 18 | date-released: 2024-07-08 19 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023 University of Michigan ARM Lab 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | this software and associated documentation files (the "Software"), to deal in 5 | the Software without restriction, including without limitation the rights to 6 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | of the Software, and to permit persons to whom the Software is furnished to do 8 | so, 
subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 20 | 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Robot Kinematics 2 | - Parallel and differentiable forward kinematics (FK), Jacobian calculation, and damped least squares inverse kinematics (IK) 3 | - Load robot description from URDF, SDF, and MJCF formats 4 | - SDF queries batched across configurations and points via [pytorch-volumetric](https://github.com/UM-ARM-Lab/pytorch_volumetric) 5 | 6 | # Installation 7 | ```shell 8 | pip install pytorch-kinematics 9 | ``` 10 | 11 | For development, clone repository somewhere, then `pip3 install -e .` to install in editable mode. 12 | 13 | ## Reference 14 | [![DOI](https://zenodo.org/badge/331721571.svg)](https://zenodo.org/badge/latestdoi/331721571) 15 | 16 | If you use this package in your research, consider citing 17 | ``` 18 | @software{Zhong_PyTorch_Kinematics_2024, 19 | author = {Zhong, Sheng and Power, Thomas and Gupta, Ashwin and Mitrano, Peter}, 20 | doi = {10.5281/zenodo.7700587}, 21 | month = feb, 22 | title = {{PyTorch Kinematics}}, 23 | version = {v0.7.1}, 24 | year = {2024} 25 | } 26 | ``` 27 | 28 | # Usage 29 | 30 | See `tests` for code samples; some are also shown here. 
31 | 32 | ## Loading Robots 33 | ```python 34 | import pytorch_kinematics as pk 35 | 36 | urdf = "widowx/wx250s.urdf" 37 | # there are multiple natural end effector links so it's not a serial chain 38 | chain = pk.build_chain_from_urdf(open(urdf, mode="rb").read()) 39 | # visualize the frames (the string is also returned) 40 | chain.print_tree() 41 | """ 42 | base_link 43 | └── shoulder_link 44 | └── upper_arm_link 45 | └── upper_forearm_link 46 | └── lower_forearm_link 47 | └── wrist_link 48 | └── gripper_link 49 | └── ee_arm_link 50 | ├── gripper_prop_link 51 | └── gripper_bar_link 52 | └── fingers_link 53 | ├── left_finger_link 54 | ├── right_finger_link 55 | └── ee_gripper_link 56 | """ 57 | 58 | # extract a specific serial chain such as for inverse kinematics 59 | serial_chain = pk.SerialChain(chain, "ee_gripper_link", "base_link") 60 | serial_chain.print_tree() 61 | """ 62 | base_link 63 | └── shoulder_link 64 | └── upper_arm_link 65 | └── upper_forearm_link 66 | └── lower_forearm_link 67 | └── wrist_link 68 | └── gripper_link 69 | └── ee_arm_link 70 | └── gripper_bar_link 71 | └── fingers_link 72 | └── ee_gripper_link 73 | """ 74 | 75 | # you can also extract a serial chain with a different root than the original chain 76 | serial_chain = pk.SerialChain(chain, "ee_gripper_link", "gripper_link") 77 | serial_chain.print_tree() 78 | """ 79 | gripper_link 80 | └── ee_arm_link 81 | └── gripper_bar_link 82 | └── fingers_link 83 | └── ee_gripper_link 84 | """ 85 | ``` 86 | 87 | ## Forward Kinematics (FK) 88 | ```python 89 | import math 90 | import pytorch_kinematics as pk 91 | 92 | # load robot description from URDF and specify end effector link 93 | chain = pk.build_serial_chain_from_urdf(open("kuka_iiwa.urdf").read(), "lbr_iiwa_link_7") 94 | # prints out the (nested) tree of links 95 | print(chain) 96 | # prints out list of joint names 97 | print(chain.get_joint_parameter_names()) 98 | 99 | # specify joint values (can do so in many forms) 100 | th = [0.0, 
-math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0] 101 | # do forward kinematics and get transform objects; end_only=False gives a dictionary of transforms for all links 102 | ret = chain.forward_kinematics(th, end_only=False) 103 | # look up the transform for a specific link 104 | tg = ret['lbr_iiwa_link_7'] 105 | # get transform matrix (1,4,4), then convert to separate position and unit quaternion 106 | m = tg.get_matrix() 107 | pos = m[:, :3, 3] 108 | rot = pk.matrix_to_quaternion(m[:, :3, :3]) 109 | ``` 110 | 111 | We can parallelize FK by passing in 2D joint values, and also use CUDA if available 112 | ```python 113 | import torch 114 | import pytorch_kinematics as pk 115 | 116 | d = "cuda" if torch.cuda.is_available() else "cpu" 117 | dtype = torch.float64 118 | 119 | chain = pk.build_serial_chain_from_urdf(open("kuka_iiwa.urdf").read(), "lbr_iiwa_link_7") 120 | chain = chain.to(dtype=dtype, device=d) 121 | 122 | N = 1000 123 | th_batch = torch.rand(N, len(chain.get_joint_parameter_names()), dtype=dtype, device=d) 124 | 125 | # order of magnitudes faster when doing FK in parallel 126 | # elapsed 0.008678913116455078s for N=1000 when parallel 127 | # (N,4,4) transform matrix; only the one for the end effector is returned since end_only=True by default 128 | tg_batch = chain.forward_kinematics(th_batch) 129 | 130 | # elapsed 8.44686508178711s for N=1000 when serial 131 | for i in range(N): 132 | tg = chain.forward_kinematics(th_batch[i]) 133 | ``` 134 | 135 | We can compute gradients through the FK 136 | ```python 137 | import torch 138 | import math 139 | import pytorch_kinematics as pk 140 | 141 | chain = pk.build_serial_chain_from_urdf(open("kuka_iiwa.urdf").read(), "lbr_iiwa_link_7") 142 | 143 | # require gradient through the input joint values 144 | th = torch.tensor([0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0], requires_grad=True) 145 | tg = chain.forward_kinematics(th) 146 | m = tg.get_matrix() 147 | pos = m[:, :3, 3] 148 
| pos.norm().backward() 149 | # now th.grad is populated 150 | ``` 151 | 152 | We can load SDF and MJCF descriptions too, and pass in joint values via a dictionary (unspecified joints get th=0) for non-serial chains 153 | ```python 154 | import math 155 | import torch 156 | import pytorch_kinematics as pk 157 | 158 | chain = pk.build_chain_from_sdf(open("simple_arm.sdf").read()) 159 | ret = chain.forward_kinematics({'arm_elbow_pan_joint': math.pi / 2.0, 'arm_wrist_lift_joint': -0.5}) 160 | # recall that we specify joint values and get link transforms 161 | tg = ret['arm_wrist_roll'] 162 | 163 | # can also do this in parallel 164 | N = 100 165 | ret = chain.forward_kinematics({'arm_elbow_pan_joint': torch.rand(N, 1), 'arm_wrist_lift_joint': torch.rand(N, 1)}) 166 | # (N, 4, 4) transform object 167 | tg = ret['arm_wrist_roll'] 168 | 169 | # building the robot from a MJCF file 170 | chain = pk.build_chain_from_mjcf(open("ant.xml").read()) 171 | print(chain) 172 | print(chain.get_joint_parameter_names()) 173 | th = {'hip_1': 1.0, 'ankle_1': 1} 174 | ret = chain.forward_kinematics(th) 175 | 176 | chain = pk.build_chain_from_mjcf(open("humanoid.xml").read()) 177 | print(chain) 178 | print(chain.get_joint_parameter_names()) 179 | th = {'left_knee': 0.0, 'right_knee': 0.0} 180 | ret = chain.forward_kinematics(th) 181 | ``` 182 | 183 | ## Jacobian calculation 184 | The Jacobian (in the kinematics context) is a matrix describing how the end effector changes with respect to joint value changes 185 | (where ![dx](https://latex.codecogs.com/png.latex?%5Cinline%20%5Cdot%7Bx%7D) is the twist, or stacked velocity and angular velocity): 186 | ![jacobian](https://latex.codecogs.com/png.latex?%5Cinline%20%5Cdot%7Bx%7D%3DJ%5Cdot%7Bq%7D) 187 | 188 | For `SerialChain` we provide a differentiable and parallelizable method for computing the Jacobian with respect to the base frame. 
189 | ```python 190 | import math 191 | import torch 192 | import pytorch_kinematics as pk 193 | 194 | # can convert Chain to SerialChain by choosing end effector frame 195 | chain = pk.build_chain_from_sdf(open("simple_arm.sdf").read()) 196 | # print(chain) to see the available links for use as end effector 197 | # note that any link can be chosen; it doesn't have to be a link with no children 198 | chain = pk.SerialChain(chain, "arm_wrist_roll_frame") 199 | 200 | chain = pk.build_serial_chain_from_urdf(open("kuka_iiwa.urdf").read(), "lbr_iiwa_link_7") 201 | th = torch.tensor([0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0]) 202 | # (1,6,7) tensor, with 7 corresponding to the DOF of the robot 203 | J = chain.jacobian(th) 204 | 205 | # get Jacobian in parallel and use CUDA if available 206 | N = 1000 207 | d = "cuda" if torch.cuda.is_available() else "cpu" 208 | dtype = torch.float64 209 | 210 | chain = chain.to(dtype=dtype, device=d) 211 | # Jacobian calculation is differentiable 212 | th = torch.rand(N, 7, dtype=dtype, device=d, requires_grad=True) 213 | # (N,6,7) 214 | J = chain.jacobian(th) 215 | 216 | # can get Jacobian at a point offset from the end effector (location is specified in EE link frame) 217 | # by default location is at the origin of the EE frame 218 | loc = torch.rand(N, 3, dtype=dtype, device=d) 219 | J = chain.jacobian(th, locations=loc) 220 | ``` 221 | 222 | The Jacobian can be used to do inverse kinematics. See [IK survey](https://www.math.ucsd.edu/~sbuss/ResearchWeb/ikmethods/iksurvey.pdf) 223 | for a survey of ways to do so. Note that IK may be better performed through other means (but doing it through the Jacobian can give an end-to-end differentiable method). 224 | 225 | ## Inverse Kinematics (IK) 226 | Inverse kinematics is available via damped least squares (iterative steps with Jacobian pseudo-inverse damped to avoid oscillation near singularities). 
227 | Compared to other IK libraries, the typical advantages are: 228 | - not ROS dependent (many IK libraries need the robot description on the ROS parameter server) 229 | - batched in both goal specification and retries from different starting configurations 230 | - goal orientation in addition to goal position 231 | 232 | ![IK](https://i.imgur.com/QgaUME9.gif) 233 | 234 | See `tests/test_inverse_kinematics.py` for usage, but generally what you need is below: 235 | ```python 236 | full_urdf = os.path.join(search_path, urdf) 237 | chain = pk.build_serial_chain_from_urdf(open(full_urdf).read(), "lbr_iiwa_link_7") 238 | 239 | # goals are specified as Transform3d poses in the **robot frame** 240 | # so if you have the goals specified in the world frame, you also need the robot frame in the world frame 241 | pos = torch.tensor([0.0, 0.0, 0.0], device=device) 242 | rot = torch.tensor([0.0, 0.0, 0.0], device=device) 243 | rob_tf = pk.Transform3d(pos=pos, rot=rot, device=device) 244 | 245 | # specify goals as Transform3d poses in world frame 246 | goal_in_world_frame_tf = ... 
247 | # convert to robot frame (skip if you have it specified in robot frame already, or if world = robot frame) 248 | goal_in_rob_frame_tf = rob_tf.inverse().compose(goal_in_world_frame_tf) 249 | 250 | # get robot joint limits 251 | lim = torch.tensor(chain.get_joint_limits(), device=device) 252 | 253 | # create the IK object 254 | # see the constructor for more options and their explanations, such as convergence tolerances 255 | ik = pk.PseudoInverseIK(chain, max_iterations=30, num_retries=10, 256 | joint_limits=lim.T, 257 | early_stopping_any_converged=True, 258 | early_stopping_no_improvement="all", 259 | debug=False, 260 | lr=0.2) 261 | # solve IK 262 | sol = ik.solve(goal_in_rob_frame_tf) 263 | # num goals x num retries x DOF tensor of joint angles; if not converged, best solution found so far 264 | print(sol.solutions) 265 | # num goals x num retries; can check for the convergence of each run 266 | print(sol.converged) 267 | # num goals x num retries; can look at errors directly 268 | print(sol.err_pos) 269 | print(sol.err_rot) 270 | ``` 271 | 272 | ## SDF Queries 273 | See [pytorch-volumetric](https://github.com/UM-ARM-Lab/pytorch_volumetric) for the latest details; some instructions are pasted here: 274 | 275 | For many applications such as collision checking, it is useful to have the 276 | SDF of a multi-link robot in certain configurations. 277 | First, we create the robot model (loaded from URDF, SDF, MJCF, ...) with 278 | [pytorch kinematics](https://github.com/UM-ARM-Lab/pytorch_kinematics). 
279 | For example, we will be using the KUKA 7 DOF arm model from pybullet data: 280 | 281 | ```python 282 | import os 283 | import torch 284 | import pybullet_data 285 | import pytorch_kinematics as pk 286 | import pytorch_volumetric as pv 287 | 288 | urdf = "kuka_iiwa/model.urdf" 289 | search_path = pybullet_data.getDataPath() 290 | full_urdf = os.path.join(search_path, urdf) 291 | chain = pk.build_serial_chain_from_urdf(open(full_urdf).read(), "lbr_iiwa_link_7") 292 | d = "cuda" if torch.cuda.is_available() else "cpu" 293 | 294 | chain = chain.to(device=d) 295 | # paths to the link meshes are specified with their relative path inside the URDF 296 | # we need to give them the path prefix as we need their absolute path to load 297 | s = pv.RobotSDF(chain, path_prefix=os.path.join(search_path, "kuka_iiwa")) 298 | ``` 299 | 300 | By default, each link will have a `MeshSDF`. To use `CachedSDF` instead for faster queries: 301 | 302 | ```python 303 | s = pv.RobotSDF(chain, path_prefix=os.path.join(search_path, "kuka_iiwa"), 304 | link_sdf_cls=pv.cache_link_sdf_factory(resolution=0.02, padding=1.0, device=d)) 305 | ``` 306 | 307 | When the `y=0.02` slice of this SDF is visualized, it looks like: 308 | 309 | ![sdf slice](https://i.imgur.com/Putw72A.png) 310 | 311 | The corresponding surface points are: 312 | 313 | ![wireframe](https://i.imgur.com/L3atG9h.png) 314 | ![solid](https://i.imgur.com/XiAks7a.png) 315 | 316 | Queries on this SDF depend on the joint configuration (by default all zero). 317 | **Queries are batched across configurations and query points**. 
For example, we have a batch of 318 | joint configurations to query: 319 | 320 | ```python 321 | th = torch.tensor([0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0], device=d) 322 | N = 200 323 | th_perturbation = torch.randn(N - 1, 7, device=d) * 0.1 324 | # N x 7 joint values 325 | th = torch.cat((th.view(1, -1), th_perturbation + th)) 326 | ``` 327 | 328 | And also a batch of points to query (same points for each configuration): 329 | 330 | ```python 331 | y = 0.02 332 | query_range = np.array([ 333 | [-1, 0.5], 334 | [y, y], 335 | [-0.2, 0.8], 336 | ]) 337 | # M x 3 points 338 | coords, pts = pv.get_coordinates_and_points_in_grid(0.01, query_range, device=s.device) 339 | ``` 340 | 341 | We set the batch of joint configurations and query: 342 | 343 | ```python 344 | s.set_joint_configuration(th) 345 | # N x M SDF value 346 | # N x M x 3 SDF gradient 347 | sdf_val, sdf_grad = s(pts) 348 | ``` 349 | 350 | 351 | # Credits 352 | - `pytorch_kinematics/transforms` is extracted from [pytorch3d](https://github.com/facebookresearch/pytorch3d) with minor extensions. 353 | This was done instead of including `pytorch3d` as a dependency because it is hard to install and most of its code is unrelated. 354 | An important difference is that we use left hand multiplied transforms as is the convention in robotics (T * pt) instead of their 355 | right hand multiplied transforms. 356 | - `pytorch_kinematics/urdf_parser_py` and `pytorch_kinematics/mjcf_parser` are extracted from [kinpy](https://github.com/neka-nat/kinpy), as well as the FK logic. 357 | This repository ports the logic to pytorch, parallelizes it, and provides some extensions. 
358 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "pytorch_kinematics" 3 | version = "0.7.5" 4 | description = "Robot kinematics implemented in pytorch" 5 | readme = "README.md" # Optional 6 | 7 | # Specify which Python versions you support. In contrast to the 8 | # 'Programming Language' classifiers above, 'pip install' will check this 9 | # and refuse to install the project if the version does not match. See 10 | # https://packaging.python.org/guides/distributing-packages-using-setuptools/#python-requires 11 | requires-python = ">=3.6" 12 | 13 | # This is either text indicating the license for the distribution, or a file 14 | # that contains the license 15 | # https://packaging.python.org/en/latest/specifications/core-metadata/#license 16 | license = { file = "LICENSE.txt" } 17 | 18 | # This field adds keywords for your project which will appear on the 19 | # project page. What does your project relate to? 20 | # 21 | # Note that this is a list of additional keywords, separated 22 | # by commas, to be used to assist searching for the distribution in a 23 | # larger catalog. 24 | keywords = ["kinematics", "pytorch", "ik", "fk", "robotics"] # Optional 25 | authors = [ 26 | { name = "Sheng Zhong", email = "zhsh@umich.edu" } # Optional 27 | ] 28 | maintainers = [ 29 | { name = "Sheng Zhong", email = "zhsh@umich.edu" } # Optional 30 | ] 31 | 32 | # Classifiers help users find your project by categorizing it. 
33 | # For a list of valid classifiers, see https://pypi.org/classifiers/ 34 | classifiers = [# Optional 35 | "Development Status :: 4 - Beta", 36 | "Intended Audience :: Developers", 37 | "License :: OSI Approved :: MIT License", 38 | "Programming Language :: Python :: 3", 39 | "Programming Language :: Python :: 3 :: Only", 40 | ] 41 | 42 | dependencies = [ 43 | 'absl-py', 44 | 'lxml', 45 | 'numpy<2', # pybullet requires numpy<2 for testing; for future versions this may be relaxed 46 | 'pyyaml', 47 | 'torch', 48 | 'matplotlib', 49 | 'pytorch_seed', 50 | 'arm_pytorch_utilities', 51 | ] 52 | 53 | [project.optional-dependencies] 54 | test = [ 55 | "pytest", 56 | "pybullet", 57 | ] 58 | 59 | [project.urls] 60 | "Homepage" = "https://github.com/UM-ARM-Lab/pytorch_kinematics" 61 | "Bug Reports" = "https://github.com/UM-ARM-Lab/pytorch_kinematics/issues" 62 | "Source" = "https://github.com/UM-ARM-Lab/pytorch_kinematics" 63 | 64 | # The following would provide a command line executable called `sample` 65 | # which executes the function `main` from this package when invoked. 66 | #[project.scripts] # Optional 67 | #sample = "sample:main" 68 | 69 | # This is configuration specific to the `setuptools` build backend. 70 | # If you are using a different build backend, you will need to change this. 71 | [tool.setuptools] 72 | # If there are data files included in your packages that need to be 73 | # installed, specify them here. 74 | 75 | [build-system] 76 | # Including torch and ninja here are needed to build the native code. 77 | # They will be installed as dependencies during the build, which can take a while the first time. 
78 | requires = ["setuptools>=60.0.0", "wheel"] 79 | build-backend= "setuptools.build_meta" 80 | -------------------------------------------------------------------------------- /src/pytorch_kinematics/__init__.py: -------------------------------------------------------------------------------- 1 | from pytorch_kinematics.sdf import * 2 | from pytorch_kinematics.urdf import * 3 | 4 | try: 5 | from pytorch_kinematics.mjcf import * 6 | except ImportError: 7 | pass 8 | from pytorch_kinematics.transforms import * 9 | from pytorch_kinematics.chain import * 10 | from pytorch_kinematics.ik import * 11 | -------------------------------------------------------------------------------- /src/pytorch_kinematics/cfg.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')) 4 | TEST_DIR = os.path.join(ROOT_DIR, 'tests') 5 | -------------------------------------------------------------------------------- /src/pytorch_kinematics/frame.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import pytorch_kinematics.transforms as tf 4 | from pytorch_kinematics.transforms import axis_and_angle_to_matrix_33 5 | 6 | 7 | class Visual(object): 8 | TYPES = ['box', 'cylinder', 'sphere', 'capsule', 'mesh'] 9 | 10 | def __init__(self, offset=None, geom_type=None, geom_param=None): 11 | if offset is None: 12 | self.offset = None 13 | else: 14 | self.offset = offset 15 | self.geom_type = geom_type 16 | self.geom_param = geom_param 17 | 18 | def __repr__(self): 19 | return "Visual(offset={0}, geom_type='{1}', geom_param={2})".format(self.offset, 20 | self.geom_type, 21 | self.geom_param) 22 | 23 | 24 | class Link(object): 25 | def __init__(self, name=None, offset=None, visuals=()): 26 | if offset is None: 27 | self.offset = None 28 | else: 29 | self.offset = offset 30 | self.name = name 31 | self.visuals = visuals 
32 | 33 | def to(self, *args, **kwargs): 34 | if self.offset is not None: 35 | self.offset = self.offset.to(*args, **kwargs) 36 | return self 37 | 38 | def __repr__(self): 39 | return "Link(name='{0}', offset={1}, visuals={2})".format(self.name, 40 | self.offset, 41 | self.visuals) 42 | 43 | 44 | class Joint(object): 45 | TYPES = ['fixed', 'revolute', 'prismatic'] 46 | 47 | def __init__(self, name=None, offset=None, joint_type='fixed', axis=(0.0, 0.0, 1.0), 48 | dtype=torch.float32, device="cpu", limits=None, 49 | velocity_limits=None, effort_limits=None): 50 | if offset is None: 51 | self.offset = None 52 | else: 53 | self.offset = offset 54 | self.name = name 55 | if joint_type not in self.TYPES: 56 | raise RuntimeError("joint specified as {} type not, but we only support {}".format(joint_type, self.TYPES)) 57 | self.joint_type = joint_type 58 | if axis is None: 59 | self.axis = torch.tensor([0.0, 0.0, 1.0], dtype=dtype, device=device) 60 | else: 61 | if torch.is_tensor(axis): 62 | self.axis = axis.clone().detach().to(dtype=dtype, device=device) 63 | else: 64 | self.axis = torch.tensor(axis, dtype=dtype, device=device) 65 | # normalize axis to have norm 1 (needed for correct representation scaling with theta) 66 | self.axis = self.axis / self.axis.norm() 67 | 68 | self.limits = limits 69 | self.velocity_limits = velocity_limits 70 | self.effort_limits = effort_limits 71 | 72 | def to(self, *args, **kwargs): 73 | self.axis = self.axis.to(*args, **kwargs) 74 | if self.offset is not None: 75 | self.offset = self.offset.to(*args, **kwargs) 76 | return self 77 | 78 | def clamp(self, joint_position): 79 | if self.limits is None: 80 | return joint_position 81 | else: 82 | return torch.clamp(joint_position, self.limits[0], self.limits[1]) 83 | 84 | def __repr__(self): 85 | return "Joint(name='{0}', offset={1}, joint_type='{2}', axis={3})".format(self.name, 86 | self.offset, 87 | self.joint_type, 88 | self.axis) 89 | 90 | 91 | # prefix components: 92 | space = ' ' 93 | 
class Frame(object):
    """
    Node of a kinematic tree: a named frame holding one joint, one link and
    a list of child frames.

    Args:
        name: frame name; stored as the string 'None' when omitted.
        link: link of this frame; a default ``Link()`` is created when omitted.
        joint: joint of this frame; a default ``Joint()`` is created when omitted.
        children: optional iterable of child frames. It is copied into a new
            list, so later ``add_child`` calls do not modify the caller's object.
    """

    def __init__(self, name=None, link=None, joint=None, children=None):
        self.name = 'None' if name is None else name
        self.link = link if link is not None else Link()
        self.joint = joint if joint is not None else Joint()
        # Always define self.children. Previously it was only assigned when
        # `children` was None, so passing children left the attribute missing.
        self.children = [] if children is None else list(children)

    def __str__(self, prefix='', root=True):
        """
        Render the subtree below this frame as an indented text tree.

        Uses the module-level `tee`/`last` pointers and `branch`/`space` prefix
        pieces. `prefix` and `root` are used by the recursion; the frame's own
        name is only emitted when `root` is True.
        """
        # every child gets a tee pointer except the final one, which gets `last`
        pointers = [tee] * (len(self.children) - 1) + [last]
        if root:
            ret = prefix + self.name + "\n"
        else:
            ret = ""
        for pointer, child in zip(pointers, self.children):
            ret += prefix + pointer + child.name + "\n"
            if child.children:
                extension = branch if pointer == tee else space
                # i.e. space because last, └── , above so no more |
                ret += child.__str__(prefix=prefix + extension, root=False)
        return ret

    def to(self, *args, **kwargs):
        """Apply ``.to(*args, **kwargs)`` to the joint, the link and all children; returns self."""
        self.joint = self.joint.to(*args, **kwargs)
        self.link = self.link.to(*args, **kwargs)
        self.children = [c.to(*args, **kwargs) for c in self.children]
        return self

    def add_child(self, child):
        """Append ``child`` to this frame's children."""
        self.children.append(child)

    def is_end(self):
        """Return True when this frame has no children."""
        return not self.children

    def get_transform(self, theta):
        """
        Build the joint transform for joint value(s) ``theta``.

        A revolute joint rotates about the joint axis by ``theta``, a prismatic
        joint translates along the axis by ``theta``, and a fixed joint gives
        the default transform with batch size ``theta.shape[0]``. When the
        joint has an offset, the result is ``joint.offset.compose(t)``.

        Raises:
            ValueError: if the joint type is none of the three above.
        """
        dtype = self.joint.axis.dtype
        d = self.joint.axis.device
        if self.joint.joint_type == 'revolute':
            rot = axis_and_angle_to_matrix_33(self.joint.axis, theta)
            t = tf.Transform3d(rot=rot, dtype=dtype, device=d)
        elif self.joint.joint_type == 'prismatic':
            # (N, 1) * (3,) broadcasts to one translation per joint value
            pos = theta.unsqueeze(1) * self.joint.axis
            t = tf.Transform3d(pos=pos, dtype=dtype, device=d)
        elif self.joint.joint_type == 'fixed':
            t = tf.Transform3d(default_batch_size=theta.shape[0], dtype=dtype, device=d)
        else:
            raise ValueError("Unsupported joint type %s." % self.joint.joint_type)
        if self.joint.offset is None:
            return t
        return self.joint.offset.compose(t)
p_joint = curr_transform @ p_eef 41 | axis_in_eef = cur_transform[:, :3, :3].transpose(1, 2) @ f.joint.axis 42 | eef2joint_pos_in_joint = cur_transform[:, :3, 3].unsqueeze(2) 43 | joint2eef_rot = cur_transform[:, :3, :3].transpose(1, 2) # transpose of rotation is inverse 44 | eef2joint_pos_in_eef = joint2eef_rot @ eef2joint_pos_in_joint 45 | position_jacobian = torch.cross(axis_in_eef, eef2joint_pos_in_eef.squeeze(2), dim=1) 46 | j_eef[:, :, -cnt] = torch.cat((position_jacobian, axis_in_eef), dim=-1) 47 | elif f.joint.joint_type == "prismatic": 48 | cnt += 1 49 | j_eef[:, :3, -cnt] = (f.joint.axis.repeat(N, 1, 1) @ cur_transform[:, :3, :3])[:, 0, :] 50 | cur_frame_transform = f.get_transform(th[:, -cnt]).get_matrix() 51 | cur_transform = cur_frame_transform @ cur_transform 52 | 53 | # currently j_eef is Jacobian in end-effector frame, convert to base/world frame 54 | pose = serial_chain.forward_kinematics(th).get_matrix() 55 | rotation = pose[:, :3, :3] 56 | j_tr = torch.zeros((N, 6, 6), dtype=serial_chain.dtype, device=serial_chain.device) 57 | j_tr[:, :3, :3] = rotation 58 | j_tr[:, 3:, 3:] = rotation 59 | j_w = j_tr @ j_eef 60 | if ret_eef_pose: 61 | return j_w, pose 62 | return j_w 63 | -------------------------------------------------------------------------------- /src/pytorch_kinematics/mjcf.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import mujoco 4 | from mujoco._structs import _MjModelBodyViews as MjModelBodyViews 5 | 6 | import pytorch_kinematics.transforms as tf 7 | from . import chain 8 | from . 
def body_to_geoms(m: mujoco.MjModel, body: MjModelBodyViews):
    """
    Collect one ``frame.Visual`` per geom of ``m`` that has ``body`` as parent.

    Every geom index in ``range(m.ngeom)`` is visited; a geom is kept when its
    ``bodyid`` equals ``body.id``. The Visual's offset is built from the geom's
    ``quat`` and ``pos``; its type and parameters are the geom's ``type`` and ``size``.
    """
    candidates = (m.geom(idx) for idx in range(m.ngeom))
    return [
        frame.Visual(offset=tf.Transform3d(rot=g.quat, pos=g.pos), geom_type=g.type,
                     geom_param=g.size)
        for g in candidates
        if g.bodyid == body.id
    ]
def build_chain_from_mjcf(data, body: Union[None, str, int] = None):
    """
    Build a Chain object from MJCF data.

    Parameters
    ----------
    data : str
        MJCF string data.
    body : str or int, optional
        The name or index of the body to use as the root of the chain. If None, body idx=0 is used.

    Returns
    -------
    chain.Chain
        Chain object created from MJCF.
    """
    model = mujoco.MjModel.from_xml_string(data)
    root_body = model.body(0 if body is None else body)

    # the root frame carries the root body's own pose as link offset and a default joint
    root_offset = tf.Transform3d(rot=root_body.quat, pos=root_body.pos)
    root_link = frame.Link(root_body.name, offset=root_offset)
    root_frame = frame.Frame(root_body.name, link=root_link, joint=frame.Joint())

    # attaches child bodies and sites below the root (mutates root_frame)
    _build_chain_recurse(model, root_frame, root_body)
    return chain.Chain(root_frame)
def _convert_transform(pose):
    """
    Turn an SDF pose sequence into a ``tf.Transform3d``.

    ``None`` gives a default ``Transform3d()``. Otherwise ``pose[:3]`` is passed
    as the position and ``pose[3:]`` is converted to a rotation matrix with
    ``euler_angles_to_matrix`` using the "ZYX" convention.
    """
    if pose is None:
        return tf.Transform3d()
    euler_angles = torch.tensor(pose[3:])
    rotation = tf.euler_angles_to_matrix(euler_angles, "ZYX")
    return tf.Transform3d(rot=rotation, pos=pose[:3])
def build_chain_from_sdf(data):
    """
    Build a Chain object from SDF data.

    The root link is the parent of the first joint (in file order) whose
    parent link is not the child of any joint.

    Parameters
    ----------
    data : str
        SDF string data.

    Returns
    -------
    chain.Chain
        Chain object created from SDF.

    Raises
    ------
    ValueError
        If the model has no joints, or every joint's parent is also some
        joint's child, so no root link can be determined.
    """
    sdf = SDF.from_xml_string(data)
    robot = sdf.model
    lmap = robot.link_map
    joints = robot.joints

    # One pass to collect all child link names, then a single scan for the
    # first joint whose parent is never a child: that parent is the tree root.
    child_names = {j.child for j in joints}
    root_joint = next((j for j in joints if j.parent not in child_names), None)
    if root_joint is None:
        # previously this surfaced as an UnboundLocalError on `root_link`
        raise ValueError("Could not determine a root link: the SDF model has no joints "
                         "or its joints do not form a tree")
    root_link = lmap[root_joint.parent]

    root_frame = frame.Frame(root_link.name)
    # the root link's pose goes on the root joint offset; the link offset itself stays identity
    root_frame.joint = frame.Joint(offset=_convert_transform(root_link.pose))
    root_frame.link = frame.Link(root_link.name, tf.Transform3d(),
                                 _convert_visuals(root_link.visuals))
    root_frame.children = _build_chain_recurse(root_frame, lmap, joints)
    return chain.Chain(root_frame)
2 | 3 | from pytorch_kinematics.transforms.perturbation import sample_perturbations 4 | from .rotation_conversions import ( 5 | axis_angle_to_quaternion, 6 | euler_angles_to_matrix, 7 | matrix_to_axis_angle, 8 | matrix_to_euler_angles, 9 | matrix_to_quaternion, 10 | matrix_to_rotation_6d, 11 | quaternion_apply, 12 | quaternion_invert, 13 | quaternion_multiply, 14 | quaternion_raw_multiply, 15 | quaternion_to_matrix, 16 | quaternion_from_euler, 17 | quaternion_to_axis_angle, 18 | random_quaternions, 19 | random_rotation, 20 | random_rotations, 21 | rotation_6d_to_matrix, 22 | standardize_quaternion, 23 | axis_and_angle_to_matrix_33, 24 | axis_and_d_to_pris_matrix, 25 | wxyz_to_xyzw, 26 | xyzw_to_wxyz, 27 | matrix44_to_se3_9d, 28 | se3_9d_to_matrix44, 29 | pos_rot_to_matrix, 30 | matrix_to_pos_rot, 31 | ) 32 | from .so3 import ( 33 | so3_exp_map, 34 | so3_log_map, 35 | so3_relative_angle, 36 | so3_rotation_angle, 37 | ) 38 | from .transform3d import Rotate, RotateAxisAngle, Scale, Transform3d, Translate 39 | from pytorch_kinematics.transforms.math import ( 40 | quaternion_angular_distance, 41 | acos_linear_extrapolation, 42 | quaternion_close, 43 | quaternion_slerp, 44 | ) 45 | 46 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 47 | -------------------------------------------------------------------------------- /src/pytorch_kinematics/transforms/math.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the BSD-style license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import Tuple, Union 9 | 10 | import torch 11 | 12 | 13 | def quaternion_angular_distance(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor: 14 | """ 15 | Computes the angular distance between two quaternions. 
def quaternion_slerp(q1: torch.Tensor, q2: torch.Tensor, t: Union[float, torch.Tensor]) -> torch.Tensor:
    """
    Spherical linear interpolation between two quaternions.

    The inputs are not modified. Pairs that are (nearly) identical are
    interpolated linearly, element by element, so a batch that mixes close
    and distant pairs no longer divides by a vanishing ``sin(theta)``.

    Args:
        q1: First quaternion (assume normalized), shape (..., 4).
        q2: Second quaternion (assume normalized), same shape as `q1`.
        t: Interpolation parameter; a float or a tensor broadcastable to the
            batch shape of the inputs.
    Returns:
        Interpolated quaternion with the same shape as the inputs.
    """
    # Compute the cosine of the angle between the two quaternions
    cos_theta = torch.sum(q1 * q2, dim=-1)

    # Reverse the direction of q2 where q1 and q2 are not in the same
    # hemisphere. torch.where builds a new tensor, so the caller's q2 is
    # left untouched (the previous masked assignment modified it in place).
    flip = cos_theta < 0
    q2 = torch.where(flip.unsqueeze(-1), -q2, q2)
    cos_theta = torch.where(flip, -cos_theta, cos_theta)

    # pairs this close are interpolated linearly
    near = cos_theta > 1.0 - 1e-6

    # Linear weights, expanded to the batch shape whether t is a float or a tensor
    lin_w2 = t * torch.ones_like(cos_theta)
    lin_w1 = 1.0 - lin_w2

    # For near pairs substitute a harmless cosine (0 -> theta = pi/2, sin = 1)
    # so acos and the division below stay finite; those entries are then
    # replaced by the linear weights.
    safe_cos = torch.where(near, torch.zeros_like(cos_theta), cos_theta)
    theta = torch.acos(safe_cos)
    sin_theta = torch.sin(theta)

    w1 = torch.where(near, lin_w1, torch.sin((1.0 - t) * theta) / sin_theta)
    w2 = torch.where(near, lin_w2, torch.sin(t * theta) / sin_theta)

    # unsqueeze(-1) instead of [:, None] so unbatched (4,) inputs work too
    return w1.unsqueeze(-1) * q1 + w2.unsqueeze(-1) * q2
def _acos_linear_approximation(x: torch.Tensor, x0: float) -> torch.Tensor:
    """
    Evaluate the first-order Taylor expansion of `arccos` around `x0` at `x`,
    i.e. `acos(x0) + dacos/dx(x0) * (x - x0)`.
    """
    offset = x - x0
    slope = _dacos_dx(x0)
    value_at_x0 = math.acos(x0)
    return offset * slope + value_at_x0
def sample_perturbations(T, num_perturbations, radian_sigma, translation_sigma, axis_of_rotation=None,
                         translation_perpendicular_to_axis_of_rotation=True):
    """
    Draw random rigid perturbations of the transform ``T``.

    Translation offsets and rotation angles are drawn independently from zero-mean
    gaussians. Without a given axis, each rotation axis is a normalized gaussian sample.
    :param T: given transform to perturb around
    :param num_perturbations: number of perturbations to sample
    :param radian_sigma: standard deviation of the gaussian angular perturbation in radians
    :param translation_sigma: standard deviation of the gaussian translation perturbation in meters / T units
    :param axis_of_rotation: if not None, the axis of rotation to sample the perturbations around
    :param translation_perpendicular_to_axis_of_rotation: if True and the axis_of_rotation is not None, the translation
    perturbations will be perpendicular to the axis of rotation
    :return: perturbed transforms; may not include the original transform
    """
    tensor_opts = {"dtype": T.dtype, "device": T.device}
    perturbed = torch.eye(4, **tensor_opts).repeat(num_perturbations, 1, 1)

    # sampling order (translation, angle, then axis) is kept so seeded runs reproduce
    delta_t = torch.randn((num_perturbations, 3), **tensor_opts) * translation_sigma
    # consider sampling from the Bingham distribution
    theta = torch.randn(num_perturbations, **tensor_opts) * radian_sigma

    if axis_of_rotation is None:
        random_axes = torch.randn((num_perturbations, 3), **tensor_opts)
        # normalize to unit length
        axis_angle = random_axes / random_axes.norm(dim=1, keepdim=True)
    else:
        axis_angle = axis_of_rotation
        if translation_perpendicular_to_axis_of_rotation:
            # remove the component of delta_t along the axis_of_rotation
            along_axis = (delta_t * axis_of_rotation).sum(dim=1, keepdim=True)
            delta_t -= along_axis * axis_of_rotation

    delta_R = axis_and_angle_to_matrix_33(axis_angle, theta)
    perturbed[:, :3, :3] = delta_R @ T[..., :3, :3]
    perturbed[:, :3, 3] = T[..., :3, 3] + delta_t

    return perturbed
def so3_rotation_angle(
    R: torch.Tensor,
    eps: float = 1e-4,
    cos_angle: bool = False,
    cos_bound: float = 1e-4,
) -> torch.Tensor:
    """
    Compute the rotation angle (in radians) of every matrix in the batch `R`
    via `angle = acos(0.5 * (Trace(R)-1))`. Traces outside `[-1-eps, 3+eps]`
    are rejected; `eps` absorbs small machine-precision errors.
    Args:
        R: Batch of rotation matrices of shape `(minibatch, 3, 3)`.
        eps: Tolerance for the valid trace check.
        cos_angle: If==True return cosine of the rotation angles rather than
            the angle itself. This can avoid the unstable
            calculation of `acos`.
        cos_bound: If positive, the angle is computed with
            `acos_linear_extrapolation(phi_cos, 1.0 - cos_bound)`;
            otherwise plain `torch.acos` is used. Only relevant when
            `cos_angle==False`.
    Returns:
        Corresponding rotation angles of shape `(minibatch,)`.
        If `cos_angle==True`, returns the cosine of the angles.
    Raises:
        ValueError if `R` is of incorrect shape.
        ValueError if `R` has an unexpected trace.
    """
    _, n_rows, n_cols = R.shape
    if (n_rows, n_cols) != (3, 3):
        raise ValueError("Input has to be a batch of 3x3 Tensors.")

    trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]

    too_small = trace < -1.0 - eps
    too_large = trace > 3.0 + eps
    if (too_small + too_large).any():
        raise ValueError("A matrix has trace outside valid range [-1-eps,3+eps].")

    # cosine of the rotation angle
    cos_phi = (trace - 1.0) * 0.5

    if cos_angle:
        return cos_phi
    if cos_bound > 0.0:
        return acos_linear_extrapolation(cos_phi, 1.0 - cos_bound)
    return torch.acos(cos_phi)
def so3_log_map(R, eps: float = 0.0001):
    """
    Map a batch of 3x3 rotation matrices `R` to their 3-dimensional
    matrix logarithms. The singularity around `(R=I)` is handled by
    clamping the sine of the rotation angle, controlled by `eps`.

    Args:
        R: batch of rotation matrices of shape `(minibatch, 3, 3)`.
        eps: A float constant handling the conversion singularity.

    Returns:
        Batch of logarithms of input rotation matrices
        of shape `(minibatch, 3)`.

    Raises:
        ValueError if `R` is of incorrect shape.
        ValueError if `R` has an unexpected trace.
    """
    _, n_rows, n_cols = R.shape
    if not (n_rows == 3 and n_cols == 3):
        raise ValueError("Input has to be a batch of 3x3 Tensors.")

    angle = so3_rotation_angle(R)
    sin_angle = angle.sin()

    # keep |sin| at least eps while preserving its sign; sign() is 0 for an
    # exactly-zero sine, so that case receives eps through the second term
    signed_clamped = torch.clamp(sin_angle.abs(), eps) * sin_angle.sign()
    zero_correction = (sin_angle == 0).type_as(angle) * eps
    denom = signed_clamped + zero_correction

    scale = angle / (2.0 * denom)
    antisymmetric = R - R.permute(0, 2, 1)
    return hat_inv(scale[:, None, None] * antisymmetric)
def hat(v):
    """
    Apply the Hat operator [1] to a batch of 3D vectors.

    Args:
        v: Batch of vectors of shape `(minibatch , 3)`.

    Returns:
        Batch of skew-symmetric matrices of shape
        `(minibatch, 3 , 3)` where each matrix is of the form:
            `[    0  -v_z   v_y ]
             [  v_z     0  -v_x ]
             [ -v_y   v_x     0 ]`

    Raises:
        ValueError if `v` is of incorrect shape.

    [1] https://en.wikipedia.org/wiki/Hat_operator
    """
    batch, width = v.shape
    if width != 3:
        raise ValueError("Input vectors have to be 3-dimensional.")

    x, y, z = v.unbind(1)
    zero = torch.zeros_like(x)

    # lay the nine entries out row by row, then fold them into 3x3 matrices
    entries = (
        zero, -z, y,
        z, zero, -x,
        -y, x, zero,
    )
    return torch.stack(entries, dim=1).reshape(batch, 3, 3)
def _convert_visual(visual):
    """
    Translate a parsed URDF visual into a `frame.Visual`.

    A missing visual, or one without geometry, yields a default
    `frame.Visual()`. Otherwise the visual's origin is converted to a
    transform and the geometry is described by a type tag plus parameters:
    mesh -> (filename, scale), cylinder -> (radius, length), box -> size,
    sphere -> radius. Any other geometry gets `None` for both.
    """
    if visual is None or visual.geometry is None:
        return frame.Visual()

    offset = _convert_transform(visual.origin)
    geometry = visual.geometry

    if isinstance(geometry, Mesh):
        kind, params = "mesh", (geometry.filename, geometry.scale)
    elif isinstance(geometry, Cylinder):
        kind, params = "cylinder", (geometry.radius, geometry.length)
    elif isinstance(geometry, Box):
        kind, params = "box", geometry.size
    elif isinstance(geometry, Sphere):
        kind, params = "sphere", geometry.radius
    else:
        # Unrecognised geometry: keep the transform but carry no shape info.
        kind, params = None, None

    return frame.Visual(offset, kind, params)
def build_chain_from_urdf(data):
    """
    Build a Chain object from URDF data.

    The root link is the parent of the first joint (in document order)
    whose parent link is not the child of any joint.

    Parameters
    ----------
    data : str
        URDF string data.

    Returns
    -------
    chain.Chain
        Chain object created from URDF.

    Raises
    ------
    ValueError
        If no root link can be determined, i.e. the URDF declares no joints
        or every joint's parent is itself the child of some joint.

    Example
    -------
    >>> import pytorch_kinematics as pk
    >>> data = '''<robot name="test_robot">
    ... <link name="link1" />
    ... <link name="link2" />
    ... <joint name="joint1" type="revolute">
    ...   <parent link="link1"/>
    ...   <child link="link2"/>
    ... </joint>
    ... </robot>'''
    >>> chain = pk.build_chain_from_urdf(data)
    >>> print(chain)
    link1_frame
        link2_frame

    """
    robot = URDF.from_xml_string(data)
    lmap = robot.link_map
    joints = robot.joints

    # One pass to collect every link that is some joint's child, instead of
    # comparing all joint pairs.
    child_links = {j.child for j in joints}
    root_joint = next((j for j in joints if j.parent not in child_links), None)
    if root_joint is None:
        # Previously this case fell through to an UnboundLocalError on `root_link`.
        raise ValueError(
            "Could not determine the root link of the URDF: "
            "it has no joints, or its joints form a cycle."
        )

    root_link = lmap[root_joint.parent]
    root_frame = frame.Frame(root_link.name)
    root_frame.joint = frame.Joint()
    root_frame.link = frame.Link(root_link.name, _convert_transform(root_link.origin),
                                 [_convert_visual(root_link.visual)])
    root_frame.children = _build_chain_recurse(root_frame, lmap, joints)
    return chain.Chain(root_frame)
class Inertia(xmlr.Object):
    """Six independent entries of a symmetric 3x3 inertia tensor."""

    KEYS = ['ixx', 'ixy', 'ixz', 'iyy', 'iyz', 'izz']

    def __init__(self, ixx=0.0, ixy=0.0, ixz=0.0, iyy=0.0, iyz=0.0, izz=0.0):
        # Assigned left to right, i.e. in the same order as KEYS.
        self.ixx, self.ixy, self.ixz = ixx, ixy, ixz
        self.iyy, self.iyz, self.izz = iyy, iyz, izz

    def to_matrix(self):
        """Return the full symmetric tensor as a nested 3x3 list."""
        row_x = [self.ixx, self.ixy, self.ixz]
        row_y = [self.ixy, self.iyy, self.iyz]
        row_z = [self.ixz, self.iyz, self.izz]
        return [row_x, row_y, row_z]
class GeometricType(xmlr.ValueType):
    """
    Value type for a geometry element that wraps exactly one shape child
    (`box`, `cylinder`, `sphere` or `mesh`), delegating to a factory keyed
    by the child's tag name.
    """

    def __init__(self):
        shapes = {
            'box': Box,
            'cylinder': Cylinder,
            'sphere': Sphere,
            'mesh': Mesh
        }
        self.factory = xmlr.FactoryType('geometric', shapes)

    def from_xml(self, node, path):
        """Parse the single shape child of `node` through the factory."""
        shape_nodes = xml_children(node)
        assert len(shape_nodes) == 1, 'One element only for geometric'
        only_shape = shape_nodes[0]
        return self.factory.from_xml(only_shape, path=path)

    def write_xml(self, node, obj):
        """Add a child named after `obj`'s shape type and let `obj` fill it in."""
        shape_node = node_add(node, self.factory.get_name(obj))
        obj.write_xml(shape_node)
class Visual(xmlr.Object):
    """
    SDF `visual` element: a named geometry with an optional material and pose.

    The reflection registered for this class declares `name`, `geometry`,
    `material` and `pose`, so all four are initialised here.
    """

    def __init__(self, name=None, geometry=None, pose=None, material=None):
        self.name = name
        self.geometry = geometry
        # `material` is declared as an optional element in the reflection for
        # this class but was never set by the constructor, so hand-built
        # instances lacked the attribute that parsed instances carry.
        # NOTE(review): the reflection layer presumably reads every declared
        # param by attribute name when writing XML - confirm against
        # xml_reflection.core.
        self.material = material
        self.pose = pose
class Model(xmlr.Object):
    """
    SDF `model` element holding links and joints, plus lookup tables that
    are kept up to date as links and joints are added.
    """

    def __init__(self, name=None, pose=None):
        self.aggregate_init()
        self.name = name
        self.pose = pose
        self.links = []
        self.joints = []
        # name -> object lookups
        self.joint_map = {}
        self.link_map = {}

        # child link name -> (joint name, parent link name)
        self.parent_map = {}
        # parent link name -> list of (joint name, child link name)
        self.child_map = {}

    def add_aggregate(self, typeName, elem):
        """Register `elem` with the base class and index it by kind."""
        xmlr.Object.add_aggregate(self, typeName, elem)

        if typeName == 'joint':
            self.joint_map[elem.name] = elem
            self.parent_map[elem.child] = (elem.name, elem.parent)
            siblings = self.child_map.setdefault(elem.parent, [])
            siblings.append((elem.name, elem.child))
        elif typeName == 'link':
            self.link_map[elem.name] = elem

    def add_link(self, link):
        """Add a link to the model."""
        self.add_aggregate('link', link)

    def add_joint(self, joint):
        """Add a joint to the model."""
        self.add_aggregate('joint', joint)
def to_yaml(obj):
    """
    Simplify yaml representation for pretty printing.

    Recursively reduces `obj` to plain Python values:

    * `None` and strings become `str(obj)`;
    * exact `int` / `float` / `bool` values are returned unchanged;
    * objects with a `to_yaml` method delegate to it;
    * lxml elements are serialised with `etree.tostring`;
    * plain dicts are converted value by value, with string keys;
    * objects with `tolist` (e.g. numpy arrays) are converted via that list;
    * any other iterable becomes a list of converted items;
    * everything else falls back to `str(obj)`.
    """
    # Is there a better way to do this by adding a representation with
    # yaml.Dumper?
    # Ordered dict: http://pyyaml.org/ticket/29#comment:11

    # The `collections.Iterable` alias was removed in Python 3.10; the ABC
    # lives in `collections.abc`. Imported locally so this function does not
    # depend on which `collections` submodules are already loaded.
    from collections.abc import Iterable

    if obj is None or isstring(obj):
        return str(obj)
    # Exact type checks are deliberate: subclasses (e.g. numpy scalars, which
    # subclass float) must fall through to the `tolist` branch below.
    if type(obj) in [int, float, bool]:
        return obj
    if hasattr(obj, 'to_yaml'):
        return obj.to_yaml()
    if isinstance(obj, etree._Element):
        return etree.tostring(obj, pretty_print=True)
    if type(obj) == dict:
        return {str(var): to_yaml(value) for (var, value) in obj.items()}
    if hasattr(obj, 'tolist'):
        # For numpy objects
        return to_yaml(obj.tolist())
    if isinstance(obj, Iterable):
        return [to_yaml(item) for item in obj]
    return str(obj)
def main():
    """
    Time batched forward kinematics for several robot models across devices,
    dtypes and batch sizes, print each timing, and pickle
    `[headers, data]` to `fk_perf.pkl` for the visualization notebook.

    Model files (`val.xml`, `kuka_iiwa.urdf`) are read from the current
    working directory.
    """
    np.set_printoptions(precision=3, suppress=True, linewidth=220)
    torch.set_printoptions(precision=3, sci_mode=False, linewidth=220)

    def _read(path):
        # Close the file handle deterministically instead of leaking it.
        with open(path) as f:
            return f.read()

    val_xml = _read('val.xml')
    chains = {
        'val': pk.build_chain_from_mjcf(val_xml),
        'val_serial': pk.build_serial_chain_from_mjcf(val_xml, end_link_name='left_tool'),
        'kuka_iiwa': pk.build_serial_chain_from_urdf(_read('kuka_iiwa.urdf'), end_link_name='lbr_iiwa_link_7'),
    }

    # Only benchmark CUDA when it is actually present, so the script also
    # runs on CPU-only machines.
    devices = ['cpu']
    if torch.cuda.is_available():
        devices.append('cuda')
    dtypes = [torch.float32, torch.float64]
    batch_sizes = [1, 10, 100, 1_000, 10_000, 100_000]
    number = 100

    # iterate over all combinations and store in a pandas dataframe
    headers = ['method', 'chain', 'device', 'dtype', 'batch_size', 'time']
    data = []
    # Rows previously had five fields against six headers (no 'method' value).
    # NOTE(review): 'fk' is an assumed label for the forward-kinematics
    # method - confirm against what the notebook expects.
    method = 'fk'

    for name, chain in chains.items():
        for device in devices:
            for dtype in dtypes:
                # Moving the chain does not depend on the batch size.
                typed_chain = chain.to(dtype=dtype, device=device)
                for batch_size in batch_sizes:
                    th = torch.zeros(batch_size, typed_chain.n_joints, dtype=dtype, device=device)

                    # NOTE(review): no explicit CUDA synchronization is done,
                    # so GPU timings may mostly reflect kernel launch cost.
                    dt = timeit.timeit(lambda: typed_chain.forward_kinematics(th), number=number)
                    data.append([method, name, device, dtype, batch_size, dt / number])
                    print(f"{name=} {device=} {dtype=} {batch_size=} {dt / number:.4f}")

    # pickle the data for visualization in jupyter notebook
    import pickle
    with open('fk_perf.pkl', 'wb') as f:
        pickle.dump([headers, data], f)
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 8 | 9 | 10 | 11 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 23 | 25 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | -------------------------------------------------------------------------------- /tests/joint_limit_robot.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /tests/joint_no_limit_robot.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /tests/kuka_iiwa.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 
51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | -------------------------------------------------------------------------------- /tests/kuka_iiwa.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /tests/link_0.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/UM-ARM-Lab/pytorch_kinematics/fed888983095f8bb1dd8ed023258232c23928fa4/tests/link_0.stl -------------------------------------------------------------------------------- /tests/link_1.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UM-ARM-Lab/pytorch_kinematics/fed888983095f8bb1dd8ed023258232c23928fa4/tests/link_1.stl -------------------------------------------------------------------------------- /tests/link_2.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UM-ARM-Lab/pytorch_kinematics/fed888983095f8bb1dd8ed023258232c23928fa4/tests/link_2.stl -------------------------------------------------------------------------------- /tests/link_3.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UM-ARM-Lab/pytorch_kinematics/fed888983095f8bb1dd8ed023258232c23928fa4/tests/link_3.stl -------------------------------------------------------------------------------- /tests/link_4.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UM-ARM-Lab/pytorch_kinematics/fed888983095f8bb1dd8ed023258232c23928fa4/tests/link_4.stl -------------------------------------------------------------------------------- /tests/link_5.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UM-ARM-Lab/pytorch_kinematics/fed888983095f8bb1dd8ed023258232c23928fa4/tests/link_5.stl -------------------------------------------------------------------------------- /tests/link_6.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UM-ARM-Lab/pytorch_kinematics/fed888983095f8bb1dd8ed023258232c23928fa4/tests/link_6.stl 
-------------------------------------------------------------------------------- /tests/link_7.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UM-ARM-Lab/pytorch_kinematics/fed888983095f8bb1dd8ed023258232c23928fa4/tests/link_7.stl -------------------------------------------------------------------------------- /tests/meshes/cone.mtl: -------------------------------------------------------------------------------- 1 | # Blender 3.6.4 MTL File: 'ycb.blend' 2 | # www.blender.org 3 | -------------------------------------------------------------------------------- /tests/meshes/cone.obj: -------------------------------------------------------------------------------- 1 | # Blender 3.6.4 2 | # www.blender.org 3 | mtllib cone.mtl 4 | o Cone 5 | v 0.097681 0.000000 -0.097681 6 | v 0.095804 -0.019057 -0.097681 7 | v 0.090246 -0.037381 -0.097681 8 | v 0.081219 -0.054269 -0.097681 9 | v 0.069071 -0.069071 -0.097681 10 | v 0.054269 -0.081219 -0.097681 11 | v 0.037381 -0.090246 -0.097681 12 | v 0.019057 -0.095804 -0.097681 13 | v 0.000000 -0.097681 -0.097681 14 | v -0.019057 -0.095804 -0.097681 15 | v -0.037381 -0.090246 -0.097681 16 | v -0.054269 -0.081219 -0.097681 17 | v -0.069071 -0.069071 -0.097681 18 | v -0.081219 -0.054269 -0.097681 19 | v -0.090246 -0.037381 -0.097681 20 | v -0.095804 -0.019057 -0.097681 21 | v -0.097681 0.000000 -0.097681 22 | v -0.095804 0.019057 -0.097681 23 | v -0.090246 0.037381 -0.097681 24 | v -0.081219 0.054269 -0.097681 25 | v -0.069071 0.069071 -0.097681 26 | v -0.054269 0.081219 -0.097681 27 | v -0.037381 0.090246 -0.097681 28 | v -0.019057 0.095804 -0.097681 29 | v 0.000000 0.097681 -0.097681 30 | v 0.019057 0.095804 -0.097681 31 | v 0.037381 0.090246 -0.097681 32 | v 0.054269 0.081219 -0.097681 33 | v 0.069071 0.069071 -0.097681 34 | v 0.081219 0.054269 -0.097681 35 | v 0.090246 0.037381 -0.097681 36 | v 0.095804 0.019057 -0.097681 37 | v 0.000000 0.000000 0.097681 38 | 
vn 0.8910 -0.0878 0.4455 39 | vn 0.8567 -0.2599 0.4455 40 | vn 0.7896 -0.4220 0.4455 41 | vn 0.6921 -0.5680 0.4455 42 | vn 0.5680 -0.6921 0.4455 43 | vn 0.4220 -0.7896 0.4455 44 | vn 0.2599 -0.8567 0.4455 45 | vn 0.0878 -0.8910 0.4455 46 | vn -0.0878 -0.8910 0.4455 47 | vn -0.2599 -0.8567 0.4455 48 | vn -0.4220 -0.7896 0.4455 49 | vn -0.5680 -0.6921 0.4455 50 | vn -0.6921 -0.5680 0.4455 51 | vn -0.7896 -0.4220 0.4455 52 | vn -0.8567 -0.2599 0.4455 53 | vn -0.8910 -0.0878 0.4455 54 | vn -0.8910 0.0878 0.4455 55 | vn -0.8567 0.2599 0.4455 56 | vn -0.7896 0.4220 0.4455 57 | vn -0.6921 0.5680 0.4455 58 | vn -0.5680 0.6921 0.4455 59 | vn -0.4220 0.7896 0.4455 60 | vn -0.2599 0.8567 0.4455 61 | vn -0.0878 0.8910 0.4455 62 | vn 0.0878 0.8910 0.4455 63 | vn 0.2599 0.8567 0.4455 64 | vn 0.4220 0.7896 0.4455 65 | vn 0.5680 0.6921 0.4455 66 | vn 0.6921 0.5680 0.4455 67 | vn 0.7896 0.4220 0.4455 68 | vn -0.0000 -0.0000 -1.0000 69 | vn 0.8567 0.2599 0.4455 70 | vn 0.8910 0.0878 0.4455 71 | vt 0.250000 0.490000 72 | vt 0.250000 0.250000 73 | vt 0.296822 0.485388 74 | vt 0.341844 0.471731 75 | vt 0.383337 0.449553 76 | vt 0.419706 0.419706 77 | vt 0.449553 0.383337 78 | vt 0.471731 0.341844 79 | vt 0.485388 0.296822 80 | vt 0.490000 0.250000 81 | vt 0.485388 0.203178 82 | vt 0.471731 0.158156 83 | vt 0.449553 0.116663 84 | vt 0.419706 0.080294 85 | vt 0.383337 0.050447 86 | vt 0.341844 0.028269 87 | vt 0.296822 0.014612 88 | vt 0.250000 0.010000 89 | vt 0.203178 0.014612 90 | vt 0.158156 0.028269 91 | vt 0.116663 0.050447 92 | vt 0.080294 0.080294 93 | vt 0.050447 0.116663 94 | vt 0.028269 0.158156 95 | vt 0.014612 0.203178 96 | vt 0.010000 0.250000 97 | vt 0.014612 0.296822 98 | vt 0.028269 0.341844 99 | vt 0.050447 0.383337 100 | vt 0.080294 0.419706 101 | vt 0.116663 0.449553 102 | vt 0.158156 0.471731 103 | vt 0.750000 0.490000 104 | vt 0.796822 0.485388 105 | vt 0.841844 0.471731 106 | vt 0.883337 0.449553 107 | vt 0.919706 0.419706 108 | vt 0.949553 0.383337 109 | vt 
0.971731 0.341844 110 | vt 0.985388 0.296822 111 | vt 0.990000 0.250000 112 | vt 0.985388 0.203178 113 | vt 0.971731 0.158156 114 | vt 0.949553 0.116663 115 | vt 0.919706 0.080294 116 | vt 0.883337 0.050447 117 | vt 0.841844 0.028269 118 | vt 0.796822 0.014612 119 | vt 0.750000 0.010000 120 | vt 0.703178 0.014612 121 | vt 0.658156 0.028269 122 | vt 0.616663 0.050447 123 | vt 0.580294 0.080294 124 | vt 0.550447 0.116663 125 | vt 0.528269 0.158156 126 | vt 0.514612 0.203178 127 | vt 0.510000 0.250000 128 | vt 0.514612 0.296822 129 | vt 0.528269 0.341844 130 | vt 0.550447 0.383337 131 | vt 0.580294 0.419706 132 | vt 0.616663 0.449553 133 | vt 0.658156 0.471731 134 | vt 0.703178 0.485388 135 | vt 0.203178 0.485388 136 | s 0 137 | f 1/1/1 33/2/1 2/3/1 138 | f 2/3/2 33/2/2 3/4/2 139 | f 3/4/3 33/2/3 4/5/3 140 | f 4/5/4 33/2/4 5/6/4 141 | f 5/6/5 33/2/5 6/7/5 142 | f 6/7/6 33/2/6 7/8/6 143 | f 7/8/7 33/2/7 8/9/7 144 | f 8/9/8 33/2/8 9/10/8 145 | f 9/10/9 33/2/9 10/11/9 146 | f 10/11/10 33/2/10 11/12/10 147 | f 11/12/11 33/2/11 12/13/11 148 | f 12/13/12 33/2/12 13/14/12 149 | f 13/14/13 33/2/13 14/15/13 150 | f 14/15/14 33/2/14 15/16/14 151 | f 15/16/15 33/2/15 16/17/15 152 | f 16/17/16 33/2/16 17/18/16 153 | f 17/18/17 33/2/17 18/19/17 154 | f 18/19/18 33/2/18 19/20/18 155 | f 19/20/19 33/2/19 20/21/19 156 | f 20/21/20 33/2/20 21/22/20 157 | f 21/22/21 33/2/21 22/23/21 158 | f 22/23/22 33/2/22 23/24/22 159 | f 23/24/23 33/2/23 24/25/23 160 | f 24/25/24 33/2/24 25/26/24 161 | f 25/26/25 33/2/25 26/27/25 162 | f 26/27/26 33/2/26 27/28/26 163 | f 27/28/27 33/2/27 28/29/27 164 | f 28/29/28 33/2/28 29/30/28 165 | f 29/30/29 33/2/29 30/31/29 166 | f 30/31/30 33/2/30 31/32/30 167 | f 1/33/31 2/34/31 3/35/31 4/36/31 5/37/31 6/38/31 7/39/31 8/40/31 9/41/31 10/42/31 11/43/31 12/44/31 13/45/31 14/46/31 15/47/31 16/48/31 17/49/31 18/50/31 19/51/31 20/52/31 21/53/31 22/54/31 23/55/31 24/56/31 25/57/31 26/58/31 27/59/31 28/60/31 29/61/31 30/62/31 31/63/31 32/64/31 168 | f 31/32/32 
33/2/32 32/65/32 169 | f 32/65/33 33/2/33 1/1/33 170 | -------------------------------------------------------------------------------- /tests/prismatic_robot.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /tests/simple_arm.sdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 0 0 0.00099 0 0 0 7 | 8 | 1.11 9 | 0 10 | 0 11 | 100.11 12 | 0 13 | 1.01 14 | 15 | 101.0 16 | 17 | 18 | 0 0 0.05 0 0 0 19 | 20 | 21 | 1.0 1.0 0.1 22 | 23 | 24 | 25 | 26 | 0 0 0.05 0 0 0 27 | 28 | 29 | 1.0 1.0 0.1 30 | 31 | 32 | 33 | 37 | 38 | 39 | 40 | 0 0 0.6 0 0 0 41 | 42 | 43 | 0.05 44 | 1.0 45 | 46 | 47 | 48 | 49 | 0 0 0.6 0 0 0 50 | 51 | 52 | 0.05 53 | 1.0 54 | 55 | 56 | 57 | 61 | 62 | 63 | 64 | 65 | 0 0 1.1 0 0 0 66 | 67 | 0.045455 0 0 0 0 0 68 | 69 | 0.011 70 | 0 71 | 0 72 | 0.0225 73 | 0 74 | 0.0135 75 | 76 | 1.1 77 | 78 | 79 | 0 0 0.05 0 0 0 80 | 81 | 82 | 0.05 83 | 0.1 84 | 85 | 86 | 87 | 88 | 0 0 0.05 0 0 0 89 | 90 | 91 | 0.05 92 | 0.1 93 | 94 | 95 | 96 | 100 | 101 | 102 | 103 | 0.55 0 0.05 0 0 0 104 | 105 | 106 | 1.0 0.05 0.1 107 | 108 | 109 | 110 | 111 | 0.55 0 0.05 0 0 0 112 | 113 | 114 | 1.0 0.05 0.1 115 | 116 | 117 | 118 | 122 | 123 | 124 | 125 | 126 | 1.05 0 1.1 0 0 0 127 | 128 | 0.0875 0 0.083333 0 0 0 129 | 130 | 0.031 131 | 0 132 | 0.005 133 | 0.07275 134 | 0 135 | 0.04475 136 | 137 | 1.2 138 | 139 | 140 | 0 0 0.1 0 0 0 141 | 142 | 143 | 0.05 144 | 0.2 145 | 146 | 147 | 148 | 149 | 0 0 0.1 0 0 0 150 | 151 | 152 | 0.05 153 | 0.2 154 | 155 | 156 | 157 | 161 | 162 | 163 | 164 | 0.3 0 0.15 0 0 0 165 | 166 | 167 | 0.5 0.03 0.1 168 | 169 | 170 | 171 | 172 | 0.3 0 0.15 0 0 0 173 | 174 | 175 | 0.5 0.03 0.1 176 | 177 | 178 | 179 | 183 | 184 | 185 | 186 | 0.55 0 0.15 0 0 
0 187 | 188 | 189 | 0.05 190 | 0.3 191 | 192 | 193 | 194 | 195 | 0.55 0 0.15 0 0 0 196 | 197 | 198 | 0.05 199 | 0.3 200 | 201 | 202 | 203 | 207 | 208 | 209 | 210 | 211 | 1.6 0 1.05 0 0 0 212 | 213 | 0 0 0 0 0 0 214 | 215 | 0.01 216 | 0 217 | 0 218 | 0.01 219 | 0 220 | 0.001 221 | 222 | 0.1 223 | 224 | 225 | 0 0 0.5 0 0 0 226 | 227 | 228 | 0.03 229 | 1.0 230 | 231 | 232 | 233 | 234 | 0 0 0.5 0 0 0 235 | 236 | 237 | 0.03 238 | 1.0 239 | 240 | 241 | 242 | 246 | 247 | 248 | 249 | 250 | 1.6 0 1.0 0 0 0 251 | 252 | 0 0 0 0 0 0 253 | 254 | 0.01 255 | 0 256 | 0 257 | 0.01 258 | 0 259 | 0.001 260 | 261 | 0.1 262 | 263 | 264 | 0 0 0.025 0 0 0 265 | 266 | 267 | 0.05 268 | 0.05 269 | 270 | 271 | 272 | 273 | 0 0 0.025 0 0 0 274 | 275 | 276 | 0.05 277 | 0.05 278 | 279 | 280 | 281 | 285 | 286 | 287 | 288 | 289 | arm_base 290 | arm_shoulder_pan 291 | 292 | 293 | 1.000000 294 | 0.000000 295 | 296 | 0 0 1 297 | true 298 | 299 | 300 | 301 | arm_shoulder_pan 302 | arm_elbow_pan 303 | 304 | 305 | 1.000000 306 | 0.000000 307 | 308 | 0 0 1 309 | true 310 | 311 | 312 | 313 | arm_elbow_pan 314 | arm_wrist_lift 315 | 316 | 317 | 1.000000 318 | 0.000000 319 | 320 | 321 | -0.8 322 | 0.1 323 | 324 | 0 0 1 325 | true 326 | 327 | 328 | 329 | arm_wrist_lift 330 | arm_wrist_roll 331 | 332 | 333 | 1.000000 334 | 0.000000 335 | 336 | 337 | -2.999994 338 | 2.999994 339 | 340 | 0 0 1 341 | true 342 | 343 | 344 | 357 | 358 | 359 | -------------------------------------------------------------------------------- /tests/simple_y_arm.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /tests/test_attributes.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 
def test_limits():
    """Check joint, velocity and effort limits parsed from URDF files.

    Two robots are exercised:

    * the KUKA iiwa serial chain, where every joint is expected to carry the
      velocity limits ``(-10, 10)`` and effort limits ``(-300, 300)``, and the
      chain-level ``get_joint_limits()`` must agree with the per-joint values;
    * ``joint_limit_robot.urdf``, whose joints encode their own index in the
      limit values (see the comment in the loop below).
    """
    # Use context managers so the URDF file handles are closed deterministically
    # instead of being left for the garbage collector.
    with open(os.path.join(TEST_DIR, "kuka_iiwa.urdf")) as urdf_file:
        chain = pk.build_serial_chain_from_urdf(urdf_file.read(), "lbr_iiwa_link_7")

    iiwa_low_individual = []
    iiwa_high_individual = []
    for joint in chain.get_joints():
        # Default velocity and effort limits for the iiwa arm
        assert joint.velocity_limits == (-10, 10)
        assert joint.effort_limits == (-300, 300)
        iiwa_low_individual.append(joint.limits[0])
        iiwa_high_individual.append(joint.limits[1])
    iiwa_low, iiwa_high = chain.get_joint_limits()
    assert iiwa_low == iiwa_low_individual
    assert iiwa_high == iiwa_high_individual

    with open(os.path.join(TEST_DIR, "joint_limit_robot.urdf")) as urdf_file:
        chain = pk.build_chain_from_urdf(urdf_file.read())

    nums = []
    for joint in chain.get_joints():
        # Take the character right after the "joint" prefix as the joint number
        num = int(joint.name[5])
        nums.append(num)
        # This robot is defined specifically to test joint limits. For joint
        # number `num`, it sets lower, upper, velocity, and effort limits to
        # `num+8`, `num+9`, `num`, and `num+4` respectively
        assert joint.limits == (num + 8, num + 9)
        assert joint.velocity_limits == (-num, num)
        assert joint.effort_limits == (-(num + 4), num + 4)

    # The chain-level accessors must return the same values, joint by joint.
    low, high = chain.get_joint_limits()
    v_low, v_high = chain.get_joint_velocity_limits()
    e_low, e_high = chain.get_joint_effort_limits()
    assert low == [x + 8 for x in nums]
    assert high == [x + 9 for x in nums]
    assert v_low == [-x for x in nums]
    assert v_high == [x for x in nums]
    assert e_low == [-(x + 4) for x in nums]
    assert e_high == [x + 4 for x in nums]
def create_test_chain(robot="kuka_iiwa", device="cpu"):
    """Build the serial chain used by the IK tests.

    Args:
        robot: ``"kuka_iiwa"`` (URDF taken from the ``pybullet_data`` data
            directory) or ``"widowx"`` (URDF read from the relative path
            ``widowx/wx250s.urdf``).
        device: device passed to ``chain.to(device=...)``.

    Returns:
        ``(chain, urdf)`` where ``urdf`` is the relative URDF path of the
        selected robot.

    Raises:
        NotImplementedError: if ``robot`` is not one of the supported names.
    """
    if robot == "kuka_iiwa":
        urdf = "kuka_iiwa/model.urdf"
        full_urdf = os.path.join(pybullet_data.getDataPath(), urdf)
        read_mode = "r"
        end_link = "lbr_iiwa_link_7"
    elif robot == "widowx":
        urdf = "widowx/wx250s.urdf"
        full_urdf = urdf
        # The widowx URDF is read as bytes, exactly as before.
        read_mode = "rb"
        end_link = "ee_gripper_link"
    else:
        raise NotImplementedError(f"Robot {robot} not implemented")

    # Close the file handle as soon as the description has been read.
    with open(full_urdf, read_mode) as urdf_file:
        chain = pk.build_serial_chain_from_urdf(urdf_file.read(), end_link)
    chain = chain.to(device=device)
    return chain, urdf
print("IK solved %d / %d goals" % (sol.converged_any.sum(), M)) 87 | 88 | # check that solving again produces the same solutions 89 | sol_again = ik.solve(goal_in_rob_frame_tf) 90 | assert torch.allclose(sol.solutions, sol_again.solutions) 91 | assert torch.allclose(sol.converged, sol_again.converged) 92 | 93 | # visualize everything 94 | if visualize: 95 | p.connect(p.GUI) 96 | p.setRealTimeSimulation(False) 97 | p.configureDebugVisualizer(p.COV_ENABLE_GUI, 0) 98 | p.setAdditionalSearchPath(search_path) 99 | 100 | yaw = 90 101 | pitch = -65 102 | # dist = 1. 103 | dist = 2.4 104 | target = np.array([2., 1.5, 0]) 105 | p.resetDebugVisualizerCamera(dist, yaw, pitch, target) 106 | 107 | plane_id = p.loadURDF("plane.urdf", [0, 0, 0], useFixedBase=True) 108 | p.changeVisualShape(plane_id, -1, rgbaColor=[0.3, 0.3, 0.3, 1]) 109 | 110 | # make 1 per retry with positional offsets 111 | robots = [] 112 | num_robots = 16 113 | # 4x4 grid position offset 114 | offset = 1.0 115 | m = rob_tf.get_matrix() 116 | pos = m[0, :3, 3] 117 | rot = m[0, :3, :3] 118 | quat = pk.matrix_to_quaternion(rot) 119 | pos = pos.cpu().numpy() 120 | rot = pk.wxyz_to_xyzw(quat).cpu().numpy() 121 | 122 | for i in range(num_robots): 123 | this_offset = np.array([i % 4 * offset, i // 4 * offset, 0]) 124 | armId = p.loadURDF(urdf, basePosition=pos + this_offset, baseOrientation=rot, useFixedBase=True) 125 | # _make_robot_translucent(armId, alpha=0.6) 126 | robots.append({"id": armId, "offset": this_offset, "pos": pos}) 127 | 128 | show_max_num_retries_per_goal = 10 129 | 130 | goals = [] 131 | # draw cone to indicate pose instead of sphere 132 | visId = p.createVisualShape(p.GEOM_MESH, fileName="meshes/cone.obj", meshScale=1.0, 133 | rgbaColor=[0., 1., 0., 0.5]) 134 | for i in range(num_robots): 135 | goals.append(p.createMultiBody(baseMass=0, baseVisualShapeIndex=visId)) 136 | 137 | try: 138 | import window_recorder 139 | with window_recorder.WindowRecorder(save_dir="."): 140 | # batch over goals with 
num_robots 141 | for j in range(0, M, num_robots): 142 | this_selection = slice(j, j + num_robots) 143 | r = goal_rot[this_selection] 144 | xyzw = pk.wxyz_to_xyzw(pk.matrix_to_quaternion(pk.euler_angles_to_matrix(r, "XYZ"))) 145 | 146 | solutions = sol.solutions[this_selection, :, :] 147 | converged = sol.converged[this_selection, :] 148 | 149 | # print how many retries converged for this one 150 | print("Goal %d to %d converged %d / %d" % (j, j + num_robots, converged.sum(), converged.numel())) 151 | 152 | # outer loop over retries, inner loop over goals (for each robot shown in parallel) 153 | for ii in range(num_retries): 154 | if ii > show_max_num_retries_per_goal: 155 | break 156 | for jj in range(num_robots): 157 | p.resetBasePositionAndOrientation(goals[jj], 158 | goal_pos[j + jj].cpu().numpy() + robots[jj]["offset"], 159 | xyzw[jj].cpu().numpy()) 160 | armId = robots[jj]["id"] 161 | q = solutions[jj, ii, :] 162 | for dof in range(q.shape[0]): 163 | p.resetJointState(armId, dof, q[dof]) 164 | 165 | input("Press enter to continue") 166 | except ImportError: 167 | print("pip install window_recorder") 168 | 169 | while True: 170 | p.stepSimulation() 171 | 172 | 173 | def test_ik_in_place_no_err(robot="kuka_iiwa"): 174 | pytorch_seed.seed(2) 175 | device = "cuda" if torch.cuda.is_available() else "cpu" 176 | # device = "cpu" 177 | chain, urdf = create_test_chain(robot=robot, device=device) 178 | # robot frame 179 | pos = torch.tensor([0.0, 0.0, 0.0], device=device) 180 | rot = torch.tensor([0.0, 0.0, 0.0], device=device) 181 | rob_tf = pk.Transform3d(pos=pos, rot=rot, device=device) 182 | 183 | # goal equal to current configuration 184 | lim = torch.tensor(chain.get_joint_limits(), device=device) 185 | cur_q = torch.rand(lim.shape[1], device=device) * (lim[1] - lim[0]) + lim[0] 186 | M = 1 187 | goal_q = cur_q.unsqueeze(0).repeat(M, 1) 188 | 189 | # get ee pose (in robot frame) 190 | goal_in_rob_frame_tf = chain.forward_kinematics(goal_q) 191 | 192 | # transform 
def test_correctness():
    """Compare ``chain.jacobian`` against hard-coded reference Jacobians.

    The KUKA iiwa Jacobian at a fixed joint configuration is asserted against
    ``J_expected``. A second Jacobian is computed for the ``simple_arm.sdf``
    serial chain ending at ``arm_wrist_roll``.
    """
    # Context managers ensure the model files are closed after reading.
    with open(os.path.join(TEST_DIR, "kuka_iiwa.urdf")) as urdf_file:
        chain = pk.build_serial_chain_from_urdf(urdf_file.read(), "lbr_iiwa_link_7")
    th = torch.tensor([0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0])
    J = chain.jacobian(th)

    J_expected = torch.tensor([[[0, 1.41421356e-02, 0, 2.82842712e-01, 0, 0, 0],
                                [-6.60827561e-01, 0, -4.57275649e-01, 0, 5.72756493e-02, 0, 0],
                                [0, 6.60827561e-01, 0, -3.63842712e-01, 0, 8.10000000e-02, 0],
                                [0, 0, -7.07106781e-01, 0, -7.07106781e-01, 0, -1],
                                [0, 1, 0, -1, 0, 1, 0],
                                [1, 0, 7.07106781e-01, 0, -7.07106781e-01, 0, 0]]])
    assert torch.allclose(J, J_expected, atol=1e-7)

    with open(os.path.join(TEST_DIR, "simple_arm.sdf")) as sdf_file:
        chain = pk.build_chain_from_sdf(sdf_file.read())
    chain = pk.SerialChain(chain, "arm_wrist_roll")
    th = torch.tensor([0.8, 0.2, -0.5, -0.3])
    J = chain.jacobian(th)
    # NOTE(review): the result of this comparison is discarded, so this check
    # can never fail. It is deliberately left non-asserting here: the reference
    # matrix has an all-zero first column, which looks stale for this arm —
    # TODO confirm/regenerate the expected values, then turn this into an assert.
    torch.allclose(J, torch.tensor([[[0., -1.51017878, -0.46280904, 0.],
                                     [0., 0.37144033, 0.29716627, 0.],
                                     [0., 0., 0., 0.],
                                     [0., 0., 0., 0.],
                                     [0., 0., 0., 0.],
                                     [0., 1., 1., 1.]]]))
def test_parallel():
    """Check that a batched Jacobian equals the per-configuration Jacobians.

    The batch holds one fixed configuration followed by ``N`` random ones, so
    it has ``N + 1`` rows; every row is compared against the Jacobian computed
    for that row alone.
    """
    N = 100
    with open(os.path.join(TEST_DIR, "kuka_iiwa.urdf")) as urdf_file:
        chain = pk.build_serial_chain_from_urdf(urdf_file.read(), "lbr_iiwa_link_7")
    th = torch.cat(
        (torch.tensor([[0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0]]), torch.rand(N, 7)))
    J = chain.jacobian(th)
    # Iterate over the actual batch size: the previous `range(N)` skipped the
    # last of the N + 1 rows.
    for i in range(th.shape[0]):
        J_i = chain.jacobian(th[i])
        assert torch.allclose(J[i], J_i)
def test_comparison_to_autograd():
    """Compare the positional part of ``chain.jacobian`` with autograd.

    The positional Jacobian of the end-effector origin is computed with
    ``torch.autograd.functional.jacobian`` and, when available, with
    ``torch.vmap(functorch.jacrev(...))``; both are compared against the
    first three rows of ``chain.jacobian``. Timings are printed.
    """
    with open(os.path.join(TEST_DIR, "kuka_iiwa.urdf")) as urdf_file:
        chain = pk.build_serial_chain_from_urdf(urdf_file.read(), "lbr_iiwa_link_7")
    d = "cuda" if torch.cuda.is_available() else "cpu"
    chain = chain.to(device=d)

    def get_pt(th):
        # position of the end-effector frame origin for joint configuration(s) th
        return chain.forward_kinematics(th).transform_points(
            torch.zeros((1, 3), device=th.device, dtype=th.dtype)).squeeze(1)

    # compare the time taken
    N = 1000
    ths = (torch.tensor([[0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0]], device=d),
           torch.rand(N - 1, 7, device=d))
    th = torch.cat(ths)

    autograd_start = timer()
    j1 = torch.autograd.functional.jacobian(get_pt, inputs=th, vectorize=True)
    # get_pt will produce N x 3
    # jacobian will compute the jacobian of the N x 3 points with respect to each of the N x DOF inputs
    # so j1 is N x 3 x N x DOF (3 since it only considers the position change)
    # however, we know the ith point only has a non-zero jacobian with the ith input
    j1_ = j1[range(N), :, range(N)]
    pk_start = timer()
    j2 = chain.jacobian(th)
    pk_end = timer()
    # we can only compare the positional parts
    assert torch.allclose(j1_, j2[:, :3], atol=1e-6)
    print(f"for N={N} on {d} autograd:{(pk_start - autograd_start) * 1000}ms")
    print(f"for N={N} on {d} pytorch-kinematics:{(pk_end - pk_start) * 1000}ms")

    # The functorch comparison is best-effort: it is skipped when functorch /
    # torch.vmap is unavailable or cannot trace the computation. Only those
    # availability problems are caught; the assertions run outside the
    # handler so a real mismatch is no longer silently swallowed.
    try:
        import functorch
        ft_start = timer()
        grad_func = torch.vmap(functorch.jacrev(get_pt))
        j3 = grad_func(th).squeeze(1)
        ft_end = timer()
    except (ImportError, AttributeError, RuntimeError) as err:
        print(f"skipping functorch comparison: {err}")
    else:
        assert torch.allclose(j1_, j3, atol=1e-6)
        assert torch.allclose(j3, j2[:, :3], atol=1e-6)
        print(f"for N={N} on {d} functorch:{(ft_end - ft_start) * 1000}ms")
def test_fk_serial_mjcf():
    """Check forward kinematics of the serial chain to ``front_left_foot``.

    The chain is built from ``ant.xml`` and evaluated at joint values
    ``[1.0, 1.0]``; the resulting pose is compared with reference values.
    """
    # Read the MJCF model through a context manager so the file is closed.
    with open(os.path.join(TEST_DIR, "ant.xml")) as mjcf_file:
        chain = pk.build_serial_chain_from_mjcf(mjcf_file.read(), 'front_left_foot')
    chain = chain.to(dtype=torch.float64)

    tg = chain.forward_kinematics([1.0, 1.0])
    pos, rot = quat_pos_from_transform3d(tg)

    expected_rot = torch.tensor([0.77015115, -0.4600326, 0.13497724, 0.42073549], dtype=torch.float64)
    expected_pos = torch.tensor([0.13976626, 0.47635466, 0.75], dtype=torch.float64)
    assert quaternion_close(rot, expected_rot)
    assert torch.allclose(pos, expected_pos)
def test_urdf():
    """Check forward kinematics of the full KUKA iiwa chain parsed from URDF.

    The pose of ``lbr_iiwa_link_7`` at a fixed joint configuration is compared
    with reference position and quaternion values.
    """
    with open(os.path.join(TEST_DIR, "kuka_iiwa.urdf")) as urdf_file:
        chain = pk.build_chain_from_urdf(urdf_file.read())
    # Keep the chain returned by `.to(...)`, as the other tests in this file
    # do, instead of discarding the return value.
    chain = chain.to(dtype=torch.float64)
    th = [0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0]
    ret = chain.forward_kinematics(th)
    tg = ret['lbr_iiwa_link_7']
    pos, rot = quat_pos_from_transform3d(tg)
    assert quaternion_close(rot, torch.tensor([7.07106781e-01, 0, -7.07106781e-01, 0], dtype=torch.float64))
    assert torch.allclose(pos, torch.tensor([-6.60827561e-01, 0, 3.74142136e-01], dtype=torch.float64), atol=1e-6)
# test robot with prismatic and fixed joints
def test_fk_simple_arm():
    """Check forward kinematics of the ``simple_arm.sdf`` chain.

    A single named joint configuration is checked against a reference pose of
    ``arm_wrist_roll``; then a batch of ``N`` random values per joint is
    passed as a dict and the output matrix shape is checked.
    """
    with open(os.path.join(TEST_DIR, "simple_arm.sdf")) as sdf_file:
        chain = pk.build_chain_from_sdf(sdf_file.read())
    chain = chain.to(dtype=torch.float64)

    joint_values = {
        'arm_shoulder_pan_joint': 0.,
        'arm_elbow_pan_joint': math.pi / 2.0,
        'arm_wrist_lift_joint': -0.5,
        'arm_wrist_roll_joint': 0.,
    }
    ret = chain.forward_kinematics(joint_values)
    tg = ret['arm_wrist_roll']
    pos, rot = quat_pos_from_transform3d(tg)
    assert quaternion_close(rot, torch.tensor([0.70710678, 0., 0., 0.70710678], dtype=torch.float64))
    assert torch.allclose(pos, torch.tensor([1.05, 0.55, 0.5], dtype=torch.float64))

    # batched dict input: one tensor of N values per joint name
    N = 100
    ret = chain.forward_kinematics({k: torch.rand(N) for k in chain.get_joint_parameter_names()})
    tg = ret['arm_wrist_roll']
    assert list(tg.get_matrix().shape) == [N, 4, 4]
pk.build_serial_chain_from_sdf(open(os.path.join(TEST_DIR, "simple_arm.sdf")).read(), 'arm_wrist_roll') 170 | chain = chain.to(dtype=torch.float64) 171 | tg = chain.forward_kinematics([0., math.pi / 2.0, -0.5, 0.]) 172 | pos, rot = quat_pos_from_transform3d(tg) 173 | assert quaternion_close(rot, torch.tensor([0.70710678, 0., 0., 0.70710678], dtype=torch.float64)) 174 | assert torch.allclose(pos, torch.tensor([1.05, 0.55, 0.5], dtype=torch.float64)) 175 | 176 | 177 | def test_cuda(): 178 | if torch.cuda.is_available(): 179 | d = "cuda" 180 | dtype = torch.float64 181 | chain = pk.build_chain_from_sdf(open(os.path.join(TEST_DIR, "simple_arm.sdf")).read()) 182 | # noinspection PyUnusedLocal 183 | chain = chain.to(dtype=dtype, device=d) 184 | 185 | # NOTE: do it twice because we previously had an issue with default arguments 186 | # like joint=Joint() causing spooky behavior 187 | chain = pk.build_chain_from_sdf(open(os.path.join(TEST_DIR, "simple_arm.sdf")).read()) 188 | chain = chain.to(dtype=dtype, device=d) 189 | 190 | ret = chain.forward_kinematics({ 191 | 'arm_shoulder_pan_joint': 0, 192 | 'arm_elbow_pan_joint': math.pi / 2.0, 193 | 'arm_wrist_lift_joint': -0.5, 194 | 'arm_wrist_roll_joint': 0, 195 | }) 196 | tg = ret['arm_wrist_roll'] 197 | pos, rot = quat_pos_from_transform3d(tg) 198 | assert quaternion_close(rot, torch.tensor([0.70710678, 0., 0., 0.70710678], dtype=dtype, device=d)) 199 | assert torch.allclose(pos, torch.tensor([1.05, 0.55, 0.5], dtype=dtype, device=d)) 200 | 201 | data = '' \ 202 | '' \ 203 | '' \ 204 | '' \ 205 | '' \ 206 | '' \ 207 | '' \ 208 | '' \ 209 | '' \ 210 | '' \ 211 | '' \ 212 | '' \ 213 | '' \ 214 | '' \ 215 | '' 216 | chain = pk.build_serial_chain_from_urdf(data, 'link3') 217 | chain = chain.to(dtype=dtype, device=d) 218 | N = 20 219 | th_batch = torch.rand(N, 2).to(device=d, dtype=dtype) 220 | tg_batch = chain.forward_kinematics(th_batch) 221 | m = tg_batch.get_matrix() 222 | for i in range(N): 223 | tg = 
def test_fk_val():
    """Check batched forward kinematics of the ``val.xml`` MJCF model.

    All joints are set to zero for a batch of 1000 configurations and the pose
    of the ``drive45`` frame is compared with reference values.
    """
    with open(os.path.join(TEST_DIR, "val.xml")) as mjcf_file:
        chain = pk.build_chain_from_mjcf(mjcf_file.read())
    chain = chain.to(dtype=torch.float64)

    zero_config = torch.zeros([1000, chain.n_joints], dtype=torch.float64)
    ret = chain.forward_kinematics(zero_config)
    tg = ret['drive45']
    pos, rot = quat_pos_from_transform3d(tg)
    # Changes torch's global print formatting; kept as in the original test.
    torch.set_printoptions(precision=6, sci_mode=False)
    assert quaternion_close(rot, torch.tensor([0.5, 0.5, -0.5, 0.5], dtype=torch.float64))
    assert torch.allclose(pos, torch.tensor([-0.225692, 0.259045, 0.262139], dtype=torch.float64))
dtype=torch.float64), 265 | 'joint43': torch.zeros([1000], dtype=torch.float64), 266 | 'joint44': torch.zeros([1000], dtype=torch.float64), 267 | 'joint45': torch.zeros([1000], dtype=torch.float64), 268 | 'joint46': torch.zeros([1000], dtype=torch.float64), 269 | 'joint47': torch.zeros([1000], dtype=torch.float64), 270 | } 271 | chain = chain.to(dtype=torch.float64) 272 | tg = chain.forward_kinematics(th) 273 | 274 | 275 | def test_fk_partial_batched(): 276 | # Test that you can pass in dict of batched joint configs for a subset of the joints 277 | chain = pk.build_serial_chain_from_mjcf(open(os.path.join(TEST_DIR, "val.xml")).read(), 'left_tool') 278 | th = torch.zeros([1000, 9], dtype=torch.float64) 279 | chain = chain.to(dtype=torch.float64) 280 | tg = chain.forward_kinematics(th) 281 | 282 | 283 | def test_ur5_fk(): 284 | urdf = os.path.join(TEST_DIR, "ur5.urdf") 285 | pk_chain = pk.build_serial_chain_from_urdf(open(urdf).read(), 'ee_link', 'base_link') 286 | th = [0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0] 287 | 288 | try: 289 | import ikpy.chain 290 | ik_chain = ikpy.chain.Chain.from_urdf_file(urdf, 291 | active_links_mask=[False, True, True, True, True, True, True, False]) 292 | ik_ret = ik_chain.forward_kinematics([0, *th, 0]) 293 | except ImportError: 294 | ik_ret = [[-6.44330720e-18, 3.58979314e-09, -1.00000000e+00, 5.10955359e-01], 295 | [1.00000000e+00, 1.79489651e-09, 0.00000000e+00, 1.91450000e-01], 296 | [1.79489651e-09, -1.00000000e+00, -3.58979312e-09, 6.00114361e-01], 297 | [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.00000000e+00]] 298 | 299 | ret = pk_chain.forward_kinematics(th, end_only=True) 300 | print(ret.get_matrix()) 301 | ik_ret = torch.tensor(ik_ret, dtype=ret.dtype) 302 | print(ik_ret) 303 | assert torch.allclose(ik_ret, ret.get_matrix(), atol=1e-6) 304 | 305 | 306 | if __name__ == "__main__": 307 | test_fk_partial_batched() 308 | test_fk_partial_batched_dict() 309 | test_fk_val() 310 | test_sdf_serial_chain() 
311 | test_urdf_serial() 312 | test_fkik() 313 | test_fk_simple_arm() 314 | test_fk_mjcf() 315 | test_cuda() 316 | test_urdf() 317 | # test_fk_mjcf_humanoid() 318 | test_mjcf_slide_joint_parsing() 319 | test_ur5_fk() 320 | -------------------------------------------------------------------------------- /tests/test_menagerie.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | 4 | import numpy as np 5 | 6 | import pytorch_kinematics as pk 7 | 8 | # Find all files named "scene*.xml" in the "mujoco_menagerie" directory 9 | _MENAGERIE_ROOT = pathlib.Path(__file__).parent / 'mujoco_menagerie' 10 | _XMLS_AND_BODIES = { 11 | # 'agility_cassie/scene.xml': 'cassie-pelvis', # not supported because it has a ball joint 12 | 'anybotics_anymal_b/scene.xml': 'base', 13 | 'anybotics_anymal_c/scene.xml': 'base', 14 | 'franka_emika_panda/scene.xml': 'link0', 15 | 'google_barkour_v0/scene.xml': 'chassis', 16 | 'google_barkour_v0/scene_barkour.xml': 'chassis', 17 | # 'hello_robot_stretch/scene.xml': 'base_link', # not supported because it has composite joints 18 | 'kuka_iiwa_14/scene.xml': 'base', 19 | 'rethink_robotics_sawyer/scene.xml': 'base', 20 | 'robotiq_2f85/scene.xml': 'base_mount', 21 | 'robotis_op3/scene.xml': 'body_link', 22 | 'shadow_hand/scene_left.xml': 'lh_forearm', 23 | 'shadow_hand/scene_right.xml': 'rh_forearm', 24 | 'ufactory_xarm7/scene.xml': 'link_base', 25 | 'unitree_a1/scene.xml': 'trunk', 26 | 'unitree_go1/scene.xml': 'trunk', 27 | 'universal_robots_ur5e/scene.xml': 'base', 28 | 'wonik_allegro/scene_left.xml': 'palm', 29 | 'wonik_allegro/scene_right.xml': 'palm', 30 | } 31 | 32 | 33 | def test_menagerie(): 34 | for xml_filename, body in _XMLS_AND_BODIES.items(): 35 | xml_filename = _MENAGERIE_ROOT / xml_filename 36 | xml_dir = xml_filename.parent 37 | # Menagerie files assume the current working directory is the directory of the scene.xml 38 | os.chdir(xml_dir) 39 | with xml_filename.open('r') 
as f: 40 | xml = f.read() 41 | chain = pk.build_chain_from_mjcf(xml, body) 42 | print(xml_filename) 43 | print("=" * 32) 44 | print(f"\t {chain.get_frame_names()}") 45 | print(f"\t {chain.get_joint_parameter_names()}") 46 | th = np.zeros(len(chain.get_joint_parameter_names())) 47 | fk_dict = chain.forward_kinematics(th) 48 | 49 | 50 | if __name__ == '__main__': 51 | test_menagerie() 52 | -------------------------------------------------------------------------------- /tests/test_rotation_conversions.py: -------------------------------------------------------------------------------- 1 | import timeit 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from pytorch_kinematics.transforms.math import quaternion_close 7 | from pytorch_kinematics.transforms.rotation_conversions import axis_and_angle_to_matrix_33, axis_angle_to_matrix, \ 8 | pos_rot_to_matrix, matrix_to_pos_rot, random_rotations, quaternion_from_euler 9 | 10 | 11 | def test_axis_angle_to_matrix_perf(): 12 | number = 100 13 | N = 1_000 14 | 15 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 16 | axis_angle = torch.randn([N, 3], device=device, dtype=torch.float64) 17 | axis_1d = torch.tensor([1., 0, 0], device=device, dtype=torch.float64) # in the FK code this is NOT batched! 
18 | theta = axis_angle.norm(dim=1, keepdim=True) 19 | 20 | dt1 = timeit.timeit(lambda: axis_angle_to_matrix(axis_angle), number=number) 21 | print(f'Old method: {dt1:.5f}') 22 | 23 | dt2 = timeit.timeit(lambda: axis_and_angle_to_matrix_33(axis=axis_1d, theta=theta), number=number) 24 | print(f'New method: {dt2:.5f}') 25 | 26 | 27 | def test_quaternion_not_close(): 28 | # ensure it returns false for quaternions that are far apart 29 | q1 = torch.tensor([1., 0, 0, 0]) 30 | q2 = torch.tensor([0., 1, 0, 0]) 31 | assert not quaternion_close(q1, q2) 32 | 33 | 34 | def test_quaternion_from_euler(): 35 | q = quaternion_from_euler(torch.tensor([0., 0, 0])) 36 | assert quaternion_close(q, torch.tensor([1., 0, 0, 0])) 37 | root2_over_2 = np.sqrt(2) / 2 38 | 39 | q = quaternion_from_euler(torch.tensor([0, 0, np.pi / 2])) 40 | assert quaternion_close(q, torch.tensor([root2_over_2, 0, 0, root2_over_2], dtype=q.dtype)) 41 | 42 | q = quaternion_from_euler(torch.tensor([-np.pi / 2, 0, 0])) 43 | assert quaternion_close(q, torch.tensor([root2_over_2, -root2_over_2, 0, 0], dtype=q.dtype)) 44 | 45 | q = quaternion_from_euler(torch.tensor([0, np.pi / 2, 0])) 46 | assert quaternion_close(q, torch.tensor([root2_over_2, 0, root2_over_2, 0], dtype=q.dtype)) 47 | 48 | # Test batched 49 | b = 32 50 | rpy = torch.tensor([0, np.pi / 2, 0]) 51 | rpy_batch = torch.tile(rpy[None], (b, 1)) 52 | q_batch = quaternion_from_euler(rpy_batch) 53 | q_expected = torch.tensor([root2_over_2, 0, root2_over_2, 0], dtype=q.dtype) 54 | q_expected_batch = torch.tile(q_expected[None], (b, 1)) 55 | assert quaternion_close(q_batch, q_expected_batch) 56 | 57 | 58 | def test_pos_rot_conversion(): 59 | N = 1000 60 | R = random_rotations(N) 61 | t = torch.randn((N, 3), dtype=R.dtype, device=R.device) 62 | T = torch.eye(4, dtype=R.dtype, device=R.device).repeat(N, 1, 1) 63 | T[:, :3, 3] = t 64 | T[:, :3, :3] = R 65 | pos, rot = matrix_to_pos_rot(T) 66 | TT = pos_rot_to_matrix(pos, rot) 67 | assert torch.allclose(T, TT, 
atol=1e-6) 68 | 69 | 70 | if __name__ == '__main__': 71 | test_axis_angle_to_matrix_perf() 72 | test_pos_rot_conversion() 73 | -------------------------------------------------------------------------------- /tests/test_serial_chain_creation.py: -------------------------------------------------------------------------------- 1 | import os 2 | from timeit import default_timer as timer 3 | 4 | import torch 5 | 6 | import pytorch_kinematics as pk 7 | 8 | TEST_DIR = os.path.dirname(__file__) 9 | 10 | 11 | def test_extract_serial_chain_from_tree(): 12 | urdf = "widowx/wx250s.urdf" 13 | full_urdf = os.path.join(TEST_DIR, urdf) 14 | chain = pk.build_chain_from_urdf(open(full_urdf, mode="rb").read()) 15 | # full frames 16 | full_frame_expected = """ 17 | base_link 18 | └── shoulder_link 19 | └── upper_arm_link 20 | └── upper_forearm_link 21 | └── lower_forearm_link 22 | └── wrist_link 23 | └── gripper_link 24 | └── ee_arm_link 25 | ├── gripper_prop_link 26 | └── gripper_bar_link 27 | └── fingers_link 28 | ├── left_finger_link 29 | ├── right_finger_link 30 | └── ee_gripper_link 31 | """ 32 | full_frame = chain.print_tree() 33 | assert full_frame_expected.strip() == full_frame.strip() 34 | 35 | serial_chain = pk.SerialChain(chain, "ee_gripper_link", "base_link") 36 | serial_frame_expected = """ 37 | base_link 38 | └── shoulder_link 39 | └── upper_arm_link 40 | └── upper_forearm_link 41 | └── lower_forearm_link 42 | └── wrist_link 43 | └── gripper_link 44 | └── ee_arm_link 45 | └── gripper_bar_link 46 | └── fingers_link 47 | └── ee_gripper_link 48 | """ 49 | serial_frame = serial_chain.print_tree() 50 | assert serial_frame_expected.strip() == serial_frame.strip() 51 | 52 | # full chain should have DOF = 8, however since we are creating just a serial chain to ee_gripper_link, should be 6 53 | assert chain.n_joints == 8 54 | assert serial_chain.n_joints == 6 55 | 56 | serial_chain = pk.SerialChain(chain, "gripper_prop_link", "base_link") 57 | serial_frame_expected = """ 58 | 
base_link 59 | └── shoulder_link 60 | └── upper_arm_link 61 | └── upper_forearm_link 62 | └── lower_forearm_link 63 | └── wrist_link 64 | └── gripper_link 65 | └── ee_arm_link 66 | └── gripper_prop_link 67 | """ 68 | serial_frame = serial_chain.print_tree() 69 | assert serial_frame_expected.strip() == serial_frame.strip() 70 | 71 | serial_chain = pk.SerialChain(chain, "ee_gripper_link", "gripper_link") 72 | serial_frame_expected = """ 73 | gripper_link 74 | └── ee_arm_link 75 | └── gripper_bar_link 76 | └── fingers_link 77 | └── ee_gripper_link 78 | """ 79 | serial_frame = serial_chain.print_tree() 80 | assert serial_frame_expected.strip() == serial_frame.strip() 81 | # only gripper_link is the parent frame of a joint in this serial chain 82 | assert serial_chain.n_joints == 1 83 | 84 | 85 | if __name__ == "__main__": 86 | test_extract_serial_chain_from_tree() 87 | -------------------------------------------------------------------------------- /tests/test_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import pytorch_kinematics.transforms as tf 4 | import pytorch_kinematics as pk 5 | 6 | 7 | def test_transform(): 8 | N = 20 9 | mats = tf.random_rotations(N, dtype=torch.float64, device="cpu", requires_grad=True) 10 | assert list(mats.shape) == [N, 3, 3] 11 | # test batch conversions 12 | quat = tf.matrix_to_quaternion(mats) 13 | assert list(quat.shape) == [N, 4] 14 | mats_recovered = tf.quaternion_to_matrix(quat) 15 | assert torch.allclose(mats, mats_recovered) 16 | 17 | quat_identity = tf.quaternion_multiply(quat, tf.quaternion_invert(quat)) 18 | assert torch.allclose(tf.quaternion_to_matrix(quat_identity), torch.eye(3, dtype=torch.float64).repeat(N, 1, 1)) 19 | 20 | 21 | def test_translations(): 22 | t = tf.Translate(1, 2, 3) 23 | points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view( 24 | 1, 3, 3 25 | ) 26 | points_out = t.transform_points(points) 27 | 
points_out_expected = torch.tensor( 28 | [[2.0, 2.0, 3.0], [1.0, 3.0, 3.0], [1.5, 2.5, 3.0]] 29 | ).view(1, 3, 3) 30 | assert torch.allclose(points_out, points_out_expected) 31 | 32 | N = 20 33 | points = torch.randn((N, N, 3)) 34 | translation = torch.randn((N, 3)) 35 | transforms = tf.Transform3d(pos=translation) 36 | translated_points = transforms.transform_points(points) 37 | assert torch.allclose(translated_points, translation.repeat(N, 1, 1).transpose(0, 1) + points) 38 | returned_points = transforms.inverse().transform_points(translated_points) 39 | assert torch.allclose(returned_points, points, atol=1e-6) 40 | 41 | 42 | def test_rotate_axis_angle(): 43 | t = tf.Transform3d().rotate_axis_angle(90.0, axis="Z") 44 | points = torch.tensor([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 1.0]]).view( 45 | 1, 3, 3 46 | ) 47 | normals = torch.tensor( 48 | [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]] 49 | ).view(1, 3, 3) 50 | points_out = t.transform_points(points) 51 | normals_out = t.transform_normals(normals) 52 | points_out_expected = torch.tensor( 53 | [[0.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 1.0]] 54 | ).view(1, 3, 3) 55 | normals_out_expected = torch.tensor( 56 | [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]] 57 | ).view(1, 3, 3) 58 | assert torch.allclose(points_out, points_out_expected) 59 | assert torch.allclose(normals_out, normals_out_expected) 60 | 61 | 62 | def test_rotate(): 63 | R = tf.so3_exp_map(torch.randn((1, 3))) 64 | t = tf.Transform3d().rotate(R) 65 | points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view( 66 | 1, 3, 3 67 | ) 68 | normals = torch.tensor( 69 | [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]] 70 | ).view(1, 3, 3) 71 | points_out = t.transform_points(points) 72 | normals_out = t.transform_normals(normals) 73 | points_out_expected = torch.bmm(points, R.transpose(-1, -2)) 74 | normals_out_expected = torch.bmm(normals, R.transpose(-1, -2)) 75 | assert torch.allclose(points_out, 
points_out_expected, atol=1e-7) 76 | assert torch.allclose(normals_out, normals_out_expected, atol=1e-7) 77 | for i in range(3): 78 | assert torch.allclose(points_out[0, i], R @ points[0, i], atol=1e-7) 79 | assert torch.allclose(normals_out[0, i], R @ normals[0, i], atol=1e-7) 80 | 81 | 82 | def test_transform_combined(): 83 | R = tf.so3_exp_map(torch.randn((1, 3))) 84 | tr = torch.randn((1, 3)) 85 | t = tf.Transform3d(rot=R, pos=tr) 86 | N = 10 87 | points = torch.randn((N, 3)) 88 | normals = torch.randn((N, 3)) 89 | points_out = t.transform_points(points) 90 | normals_out = t.transform_normals(normals) 91 | for i in range(N): 92 | assert torch.allclose(points_out[i], R @ points[i] + tr, atol=1e-7) 93 | assert torch.allclose(normals_out[i], R @ normals[i], atol=1e-7) 94 | 95 | 96 | def test_euler(): 97 | euler_angles = torch.tensor([1, 0, 0.5]) 98 | t = tf.Transform3d(rot=euler_angles) 99 | sxyz_matrix = torch.tensor([[0.87758256, -0.47942554, 0., 0., ], 100 | [0.25903472, 0.47415988, -0.84147098, 0.], 101 | [0.40342268, 0.73846026, 0.54030231, 0.], 102 | [0., 0., 0., 1.]]) 103 | # from tf.transformations import euler_matrix 104 | # print(euler_matrix(*euler_angles, "rxyz")) 105 | # print(t.get_matrix()) 106 | assert torch.allclose(sxyz_matrix, t.get_matrix()) 107 | 108 | 109 | def test_quaternions(): 110 | import pytorch_seed 111 | pytorch_seed.seed(0) 112 | 113 | n = 10 114 | q = tf.random_quaternions(n) 115 | q_tf = tf.wxyz_to_xyzw(q) 116 | assert torch.allclose(q, tf.xyzw_to_wxyz(q_tf)) 117 | 118 | qq = pk.standardize_quaternion(q) 119 | assert torch.allclose(qq.norm(dim=-1), torch.ones(n)) 120 | 121 | # random quaternions should already be unit quaternions 122 | assert torch.allclose(q, qq) 123 | 124 | # distances to themselves should be zero 125 | d = pk.quaternion_angular_distance(q, q) 126 | assert torch.allclose(d, torch.zeros(n)) 127 | # q = -q 128 | d = pk.quaternion_angular_distance(q, -q) 129 | assert torch.allclose(d, torch.zeros(n)) 130 | 131 | 
axis = torch.tensor([0.0, 0.5, 0.5]) 132 | axis = axis / axis.norm() 133 | magnitudes = torch.tensor([2.32, 1.56, -0.52, 0.1]) 134 | n = len(magnitudes) 135 | aa_1 = axis.repeat(n, 1) 136 | aa_2 = axis * magnitudes[:, None] 137 | q1 = pk.axis_angle_to_quaternion(aa_1) 138 | q2 = pk.axis_angle_to_quaternion(aa_2) 139 | d = pk.quaternion_angular_distance(q1, q2) 140 | expected_d = (magnitudes - 1).abs() 141 | assert torch.allclose(d, expected_d, atol=1e-4) 142 | 143 | 144 | def test_compose(): 145 | import torch 146 | theta = 1.5707 147 | a2b = tf.Transform3d(pos=[0.1, 0, 0]) # joint.offset 148 | b2j = tf.Transform3d(rot=tf.axis_angle_to_quaternion(theta * torch.tensor([0.0, 0, 1]))) # joint.axis 149 | j2c = tf.Transform3d(pos=[0.1, 0, 0]) # link.offset ? 150 | a2c = a2b.compose(b2j, j2c) 151 | m = a2c.get_matrix() 152 | print(m) 153 | print(a2c.transform_points(torch.zeros([1, 3]))) 154 | 155 | 156 | def test_quaternion_slerp(): 157 | q = tf.random_quaternions(20) 158 | q1 = q[:10] 159 | q2 = q[10:] 160 | t = torch.rand(10) 161 | q_interp = pk.quaternion_slerp(q1, q2, t) 162 | # check the distance between them is consistent 163 | full_dist = pk.quaternion_angular_distance(q1, q2) 164 | interp_dist = pk.quaternion_angular_distance(q1, q_interp) 165 | # print(f"full_dist: {full_dist} interp_dist: {interp_dist} t: {t}") 166 | assert torch.allclose(full_dist * t, interp_dist, atol=1e-5) 167 | 168 | 169 | if __name__ == "__main__": 170 | test_compose() 171 | test_transform() 172 | test_translations() 173 | test_rotate_axis_angle() 174 | test_rotate() 175 | test_euler() 176 | test_quaternions() 177 | test_quaternion_slerp() 178 | -------------------------------------------------------------------------------- /tests/ur5.urdf: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 
| 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | transmission_interface/SimpleTransmission 252 | 253 | PositionJointInterface 254 | 255 | 256 | 1 257 | 258 | 259 | 260 | transmission_interface/SimpleTransmission 261 | 262 | PositionJointInterface 263 | 264 | 265 | 1 266 | 267 | 268 | 269 | transmission_interface/SimpleTransmission 270 | 271 | PositionJointInterface 272 | 273 | 274 | 1 275 | 276 | 277 | 278 | transmission_interface/SimpleTransmission 279 | 280 | PositionJointInterface 281 | 282 | 283 | 1 284 | 285 | 286 | 287 | transmission_interface/SimpleTransmission 288 | 289 | PositionJointInterface 290 | 291 | 292 | 1 293 | 294 | 295 | 296 | transmission_interface/SimpleTransmission 297 | 298 | PositionJointInterface 299 | 300 | 301 | 1 302 | 303 | 304 | 305 | 306 | -------------------------------------------------------------------------------- 
/tests/widowx/README.md: -------------------------------------------------------------------------------- 1 | # WidowX 250 S Robot Description (URDF) 2 | 3 | The robot model here is based on the real2sim project: https://github.com/simpler-env/SimplerEnv 4 | 5 | 6 | -------------------------------------------------------------------------------- /tests/widowx/interbotix_black.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UM-ARM-Lab/pytorch_kinematics/fed888983095f8bb1dd8ed023258232c23928fa4/tests/widowx/interbotix_black.png -------------------------------------------------------------------------------- /tests/widowx/meshes_wx250s/WXSA-250-M-1-Base.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'None' 2 | # Material Count: 1 3 | 4 | newmtl Material.018 5 | Ns 250.000000 6 | Ka 1.000000 1.000000 1.000000 7 | Kd 0.800000 0.800000 0.800000 8 | Ks 0.500000 0.500000 0.500000 9 | Ke 0.000000 0.000000 0.000000 10 | Ni 1.450000 11 | d 1.000000 12 | illum 2 13 | map_Kd ../interbotix_black.png 14 | -------------------------------------------------------------------------------- /tests/widowx/meshes_wx250s/WXSA-250-M-1-Base.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UM-ARM-Lab/pytorch_kinematics/fed888983095f8bb1dd8ed023258232c23928fa4/tests/widowx/meshes_wx250s/WXSA-250-M-1-Base.stl -------------------------------------------------------------------------------- /tests/widowx/meshes_wx250s/WXSA-250-M-10-Finger.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'None' 2 | # Material Count: 1 3 | 4 | newmtl Material.027 5 | Ns 250.000000 6 | Ka 1.000000 1.000000 1.000000 7 | Kd 0.800000 0.800000 0.800000 8 | Ks 0.500000 0.500000 0.500000 9 | Ke 0.000000 0.000000 0.000000 10 | Ni 1.450000 11 | d 1.000000 12 | 
illum 2 13 | map_Kd ../interbotix_black.png 14 | -------------------------------------------------------------------------------- /tests/widowx/meshes_wx250s/WXSA-250-M-10-Finger.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UM-ARM-Lab/pytorch_kinematics/fed888983095f8bb1dd8ed023258232c23928fa4/tests/widowx/meshes_wx250s/WXSA-250-M-10-Finger.stl -------------------------------------------------------------------------------- /tests/widowx/meshes_wx250s/WXSA-250-M-2-Shoulder.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'None' 2 | # Material Count: 1 3 | 4 | newmtl Material.019 5 | Ns 250.000000 6 | Ka 1.000000 1.000000 1.000000 7 | Kd 0.800000 0.800000 0.800000 8 | Ks 0.500000 0.500000 0.500000 9 | Ke 0.000000 0.000000 0.000000 10 | Ni 1.450000 11 | d 1.000000 12 | illum 2 13 | map_Kd ../interbotix_black.png 14 | -------------------------------------------------------------------------------- /tests/widowx/meshes_wx250s/WXSA-250-M-2-Shoulder.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UM-ARM-Lab/pytorch_kinematics/fed888983095f8bb1dd8ed023258232c23928fa4/tests/widowx/meshes_wx250s/WXSA-250-M-2-Shoulder.stl -------------------------------------------------------------------------------- /tests/widowx/meshes_wx250s/WXSA-250-M-3-UA.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'None' 2 | # Material Count: 1 3 | 4 | newmtl Material.020 5 | Ns 250.000000 6 | Ka 1.000000 1.000000 1.000000 7 | Kd 0.800000 0.800000 0.800000 8 | Ks 0.500000 0.500000 0.500000 9 | Ke 0.000000 0.000000 0.000000 10 | Ni 1.450000 11 | d 1.000000 12 | illum 2 13 | map_Kd ../interbotix_black.png 14 | -------------------------------------------------------------------------------- 
/tests/widowx/meshes_wx250s/WXSA-250-M-3-UA.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UM-ARM-Lab/pytorch_kinematics/fed888983095f8bb1dd8ed023258232c23928fa4/tests/widowx/meshes_wx250s/WXSA-250-M-3-UA.stl -------------------------------------------------------------------------------- /tests/widowx/meshes_wx250s/WXSA-250-M-4-UF.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'None' 2 | # Material Count: 1 3 | 4 | newmtl Material.021 5 | Ns 250.000000 6 | Ka 1.000000 1.000000 1.000000 7 | Kd 0.800000 0.800000 0.800000 8 | Ks 0.500000 0.500000 0.500000 9 | Ke 0.000000 0.000000 0.000000 10 | Ni 1.450000 11 | d 1.000000 12 | illum 2 13 | map_Kd ../interbotix_black.png 14 | -------------------------------------------------------------------------------- /tests/widowx/meshes_wx250s/WXSA-250-M-4-UF.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UM-ARM-Lab/pytorch_kinematics/fed888983095f8bb1dd8ed023258232c23928fa4/tests/widowx/meshes_wx250s/WXSA-250-M-4-UF.stl -------------------------------------------------------------------------------- /tests/widowx/meshes_wx250s/WXSA-250-M-5-LF.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'None' 2 | # Material Count: 1 3 | 4 | newmtl Material.022 5 | Ns 250.000000 6 | Ka 1.000000 1.000000 1.000000 7 | Kd 0.800000 0.800000 0.800000 8 | Ks 0.500000 0.500000 0.500000 9 | Ke 0.000000 0.000000 0.000000 10 | Ni 1.450000 11 | d 1.000000 12 | illum 2 13 | map_Kd ../interbotix_black.png 14 | -------------------------------------------------------------------------------- /tests/widowx/meshes_wx250s/WXSA-250-M-5-LF.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/UM-ARM-Lab/pytorch_kinematics/fed888983095f8bb1dd8ed023258232c23928fa4/tests/widowx/meshes_wx250s/WXSA-250-M-5-LF.stl -------------------------------------------------------------------------------- /tests/widowx/meshes_wx250s/WXSA-250-M-6-Wrist.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'None' 2 | # Material Count: 1 3 | 4 | newmtl Material.023 5 | Ns 250.000000 6 | Ka 1.000000 1.000000 1.000000 7 | Kd 0.800000 0.800000 0.800000 8 | Ks 0.500000 0.500000 0.500000 9 | Ke 0.000000 0.000000 0.000000 10 | Ni 1.450000 11 | d 1.000000 12 | illum 2 13 | map_Kd ../interbotix_black.png 14 | -------------------------------------------------------------------------------- /tests/widowx/meshes_wx250s/WXSA-250-M-6-Wrist.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UM-ARM-Lab/pytorch_kinematics/fed888983095f8bb1dd8ed023258232c23928fa4/tests/widowx/meshes_wx250s/WXSA-250-M-6-Wrist.stl -------------------------------------------------------------------------------- /tests/widowx/meshes_wx250s/WXSA-250-M-7-Gripper.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'None' 2 | # Material Count: 1 3 | 4 | newmtl Material.024 5 | Ns 250.000000 6 | Ka 1.000000 1.000000 1.000000 7 | Kd 0.800000 0.800000 0.800000 8 | Ks 0.500000 0.500000 0.500000 9 | Ke 0.000000 0.000000 0.000000 10 | Ni 1.450000 11 | d 1.000000 12 | illum 2 13 | map_Kd ../interbotix_black.png 14 | -------------------------------------------------------------------------------- /tests/widowx/meshes_wx250s/WXSA-250-M-7-Gripper.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UM-ARM-Lab/pytorch_kinematics/fed888983095f8bb1dd8ed023258232c23928fa4/tests/widowx/meshes_wx250s/WXSA-250-M-7-Gripper.stl 
-------------------------------------------------------------------------------- /tests/widowx/meshes_wx250s/WXSA-250-M-8-Gripper-Prop.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'None' 2 | # Material Count: 1 3 | 4 | newmtl Material.025 5 | Ns 250.000000 6 | Ka 1.000000 1.000000 1.000000 7 | Kd 0.800000 0.800000 0.800000 8 | Ks 0.500000 0.500000 0.500000 9 | Ke 0.000000 0.000000 0.000000 10 | Ni 1.450000 11 | d 1.000000 12 | illum 2 13 | map_Kd ../interbotix_black.png 14 | -------------------------------------------------------------------------------- /tests/widowx/meshes_wx250s/WXSA-250-M-8-Gripper-Prop.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UM-ARM-Lab/pytorch_kinematics/fed888983095f8bb1dd8ed023258232c23928fa4/tests/widowx/meshes_wx250s/WXSA-250-M-8-Gripper-Prop.stl -------------------------------------------------------------------------------- /tests/widowx/meshes_wx250s/WXSA-250-M-9-Gripper-Bar.mtl: -------------------------------------------------------------------------------- 1 | # Blender MTL File: 'None' 2 | # Material Count: 1 3 | 4 | newmtl Material.026 5 | Ns 250.000000 6 | Ka 1.000000 1.000000 1.000000 7 | Kd 0.800000 0.800000 0.800000 8 | Ks 0.500000 0.500000 0.500000 9 | Ke 0.000000 0.000000 0.000000 10 | Ni 1.450000 11 | d 1.000000 12 | illum 2 13 | map_Kd ../interbotix_black.png 14 | -------------------------------------------------------------------------------- /tests/widowx/meshes_wx250s/WXSA-250-M-9-Gripper-Bar.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UM-ARM-Lab/pytorch_kinematics/fed888983095f8bb1dd8ed023258232c23928fa4/tests/widowx/meshes_wx250s/WXSA-250-M-9-Gripper-Bar.stl -------------------------------------------------------------------------------- /tests/widowx/wx250s.srdf: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | --------------------------------------------------------------------------------