├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── assets └── saved_models │ ├── mpo_state_rgb_test_triplet1 │ ├── saved_model.pb │ └── variables │ │ ├── variables.data-00000-of-00001 │ │ └── variables.index │ ├── mpo_state_rgb_test_triplet2 │ ├── saved_model.pb │ └── variables │ │ ├── variables.data-00000-of-00001 │ │ └── variables.index │ ├── mpo_state_rgb_test_triplet3 │ ├── saved_model.pb │ └── variables │ │ ├── variables.data-00000-of-00001 │ │ └── variables.index │ ├── mpo_state_rgb_test_triplet4 │ ├── saved_model.pb │ └── variables │ │ ├── variables.data-00000-of-00001 │ │ └── variables.index │ └── mpo_state_rgb_test_triplet5 │ ├── saved_model.pb │ └── variables │ ├── variables.data-00000-of-00001 │ └── variables.index ├── doc └── images │ └── rgb_environment.png ├── real_cell_documentation ├── basket_and_camera_assembly.pdf ├── bill_of_materials.pdf ├── cell_assembly.pdf ├── robot_assembly.pdf └── standard_cell_generic_setup.pdf ├── requirements.txt ├── rgb_stacking ├── environment.py ├── environment_test.py ├── main.py ├── physics_utils.py ├── reward_functions.py ├── stack_rewards.py ├── task.py └── utils │ ├── permissive_model.py │ └── policy_loading.py └── run.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | # Pull Requests 4 | 5 | Please send in fixes or feature additions through Pull Requests. 6 | 7 | ## Contributor License Agreement 8 | 9 | Contributions to this project must be accompanied by a Contributor License 10 | Agreement. You (or your employer) retain the copyright to your contribution, 11 | this simply gives us permission to use and redistribute your contributions as 12 | part of the project. Head over to to see 13 | your current agreements on file or to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 
15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RGB-stacking 🛑🟩🔷 for robotic manipulation 2 | ### [BLOG](https://deepmind.com/blog/article/stacking-our-way-to-more-general-robots) | [PAPER][pick_and_place_paper] | [VIDEO](https://youtu.be/BxOKPEtMuZw) 3 | 4 | **Beyond Pick-and-Place: Tackling Robotic Stacking of Diverse Shapes,**\ 5 | Alex X. Lee*, Coline Devin*, Yuxiang Zhou*, Thomas Lampe*, Konstantinos Bousmalis*, Jost Tobias Springenberg*, Arunkumar Byravan, Abbas Abdolmaleki, Nimrod Gileadi, David Khosid, Claudio Fantacci, Jose Enrique Chen, Akhil Raju, Rae Jeong, Michael Neunert, Antoine Laurens, Stefano Saliceti, Federico Casarini, Martin Riedmiller, Raia Hadsell, Francesco Nori.\ 6 | In *Conference on Robot Learning (CoRL)*, 2021. 7 | 8 | The RGB environment 9 | 10 | This repository contains an implementation of the simulation environment 11 | described in the paper 12 | ["Beyond Pick-and-Place: Tackling robotic stacking of diverse shapes"][pick_and_place_paper]. 13 | Note that this is a re-implementation of the environment (to remove dependencies 14 | on internal libraries). As a result, not all the features described in the paper 15 | are available at this point. Noticeably, domain randomization is not included 16 | in this release. We also aim to provide reference performance metrics of 17 | trained policies on this environment in the near future. 18 | 19 | In this environment, the agent controls a robot arm with a parallel gripper 20 | above a basket, which contains three objects — one red, one green, and one blue, 21 | hence the name RGB. The agent's task is to stack the red object on top of the 22 | blue object, within 20 seconds, while the green object serves as an obstacle and 23 | distraction. The agent controls the robot using a 4D Cartesian controller. The 24 | controlled DOFs are x, y, z and rotation around 25 | the z axis. The simulation is a MuJoCo environment built using the 26 | [Modular Manipulation (MoMa) framework](https://github.com/deepmind/robotics/tree/main/py/moma/README.md). 27 | 28 | 29 | ## Corresponding method 30 | The RGB-stacking paper 31 | ["Beyond Pick-and-Place: Tackling robotic stacking of diverse shapes"][pick_and_place_paper] 32 | also contains a description and thorough evaluation of our initial solution to 33 | both the 'Skill Mastery' (training on the 5 designated test triplets and 34 | evaluating on them) and the 'Skill Generalization' (training on triplets of 35 | training objects and evaluating on the 5 test triplets). Our 36 | approach was to first train a state-based policy in simulation via a standard RL 37 | algorithm (we used [MPO](https://arxiv.org/abs/1806.06920)) followed by 38 | interactive distillation of the state-based policy into a vision-based policy (using a 39 | domain randomized version of the environment) that we then deployed to the robot 40 | via zero-shot sim-to-real transfer. 
We finally improved the policy further via
offline RL based on data collected from the sim-to-real policy (we used [CRR](https://arxiv.org/abs/2006.15134)). For details on our method and the results
please consult the paper.

## Released specialist policies

This repository includes state-based policies that were trained on this
environment, which differs slightly from the internal one we used for the paper.
These are 5 specialist policies, each one trained on one test triplet. They
correspond to the Skill Mastery-State teacher in Table 1 of the manuscript and
they achieve 75% stacking success on average. In detail, the stacking success of
each agent over a run of 1000 episodes is (average of 2 seeds):

* Triplet 1: 77.7%
* Triplet 2: 47.4%
* Triplet 3: 83.5%
* Triplet 4: 79.9%
* Triplet 5: 89.5%
* Average: 75.6%

The policy weights in the directory `assets/saved_models` are made available
under the terms of the Creative Commons Attribution 4.0 (CC BY 4.0) license.
You may obtain a copy of the License at
https://creativecommons.org/licenses/by/4.0/legalcode.

## Installing and visualising the environment

Please ensure that you have a working
[MuJoCo200 installation](https://www.roboti.us/download.html) and a valid
[MuJoCo licence](https://www.roboti.us/license.html).

1. Clone this repository:

    ```bash
    git clone https://github.com/deepmind/rgb_stacking.git
    cd rgb_stacking
    ```

2. Prepare a Python 3 environment - venv is recommended.

    ```bash
    python3 -m venv rgb_stacking_venv
    source rgb_stacking_venv/bin/activate
    ```

3. Install dependencies:

    ```bash
    pip install -r requirements.txt
    ```

4. Run the environment viewer:

    ```bash
    python -m rgb_stacking.main
    ```

Steps 2-4 can also be done by running the run.sh script:

```bash
./run.sh
```

By default, this loads the environment with a random test triplet and starts the
viewer for visualisation. Alternatively, the object set can be specified with
`--object_triplet` (see the relevant [section](#specifying-the-object-triplet)
for options).

## Specifying one of the released specialist policies

You can also load the environment along with a specialist policy using the flag
`--policy_object_triplet`. E.g., to execute the respective specialist in the
environment with triplet 4, use the following command:

```bash
python -m rgb_stacking.main --object_triplet=rgb_test_triplet4 --policy_object_triplet=rgb_test_triplet4
```

Executing and visualising a policy in the viewer can be very slow.
Alternatively, passing `--launch_viewer=False` will render the policy and save it
as `rendered_policy.mp4` in the current directory.

```bash
MUJOCO_GL=egl python -m rgb_stacking.main --launch_viewer=False --object_triplet=rgb_test_triplet4 --policy_object_triplet=rgb_test_triplet4
```

## Specifying the object triplet

The default environment will load with a random test triplet (see Sect. 3.2.1 in
the paper).
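For example, the following minimal sketch loads the default environment and runs
one episode with uniformly sampled actions. The random-action loop is purely
illustrative, and it assumes the composed action spec is a single bounded array,
as is standard for `dm_env` environments (see `rgb_stacking/main.py` for a full
policy rollout):

```python
import numpy as np

from rgb_stacking import environment

# Loads a random test triplet by default.
with environment.rgb_stacking() as env:
  spec = env.action_spec()
  timestep = env.reset()
  while not timestep.last():
    # Sample a random action within the bounds of the action spec.
    action = np.random.uniform(
        low=spec.minimum, high=spec.maximum,
        size=spec.shape).astype(spec.dtype)
    timestep = env.step(action)
```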
If you wish to use a different triplet you can use the following
commands:

```python
from rgb_stacking import environment

env = environment.rgb_stacking(object_triplet=NAME_OF_TRIPLET)
```

The possible values of `NAME_OF_TRIPLET` are:

* `rgb_test_triplet{i}` where `i` is one of 1, 2, 3, 4, 5: Loads test triplet `i`.
* `rgb_test_random`: Randomly loads one of the 5 test triplets.
* `rgb_train_random`: Triplet comprised of blocks from the training set.
* `rgb_heldout_random`: Triplet comprised of blocks from the held-out set.

For more information on the blocks and the possible options, please refer to
the [rgb_objects repository](https://github.com/deepmind/dm_robotics/tree/main/py/manipulation/props/rgb_objects/README.md).

## Specifying the observation space

By default, the observations exposed by the environment are only the ones we
used for training our state-based agents. To use another set of observations
please use the following code snippet:

```python
from rgb_stacking import environment

env = environment.rgb_stacking(
    observation_set=environment.ObservationSet.CHOSEN_SET)
```

The possible values of `CHOSEN_SET` are:

* `STATE_ONLY`: Only the state observations, used for training expert policies
  from state in simulation (stage 1).
* `VISION_ONLY`: Only image observations.
* `ALL`: All observations.
* `INTERACTIVE_IMITATION_LEARNING`: Pair of image observations and a subset of
  proprioception observations, used for interactive imitation learning
  (stage 2).
* `OFFLINE_POLICY_IMPROVEMENT`: Pair of image observations and a subset of
  proprioception observations, used for the one-step offline policy
  improvement (stage 3).

## Real RGB-Stacking Environment: CAD models and assembly instructions

The [CAD model](https://deepmind.onshape.com/documents/d0b99322019b124525012b2a/w/6702310a5b51c79efed7c65b/e/4b6eb89dc085468d0fee5e97)
of the setup is available on Onshape.

We also provide the following [documents](https://github.com/deepmind/rgb_stacking/tree/main/real_cell_documentation)
for the assembly of the real cell:

* Assembly instructions for the basket.
* Assembly instructions for the robot.
* Assembly instructions for the cell.
* The bill of materials of all the necessary parts.
* A diagram with the wiring of the cell.

The RGB-objects themselves can be 3D-printed using the STLs available in the [rgb_objects repository](https://github.com/deepmind/dm_robotics/tree/main/py/manipulation/props/rgb_objects/README.md).

## Citing

If you use `rgb_stacking` in your work, please cite the accompanying [paper][pick_and_place_paper]:

```bibtex
@inproceedings{lee2021rgbstacking,
    title={Beyond Pick-and-Place: Tackling Robotic Stacking of Diverse Shapes},
    author={Alex X.
Lee and 198 | Coline Devin and 199 | Yuxiang Zhou and 200 | Thomas Lampe and 201 | Konstantinos Bousmalis and 202 | Jost Tobias Springenberg and 203 | Arunkumar Byravan and 204 | Abbas Abdolmaleki and 205 | Nimrod Gileadi and 206 | David Khosid and 207 | Claudio Fantacci and 208 | Jose Enrique Chen and 209 | Akhil Raju and 210 | Rae Jeong and 211 | Michael Neunert and 212 | Antoine Laurens and 213 | Stefano Saliceti and 214 | Federico Casarini and 215 | Martin Riedmiller and 216 | Raia Hadsell and 217 | Francesco Nori}, 218 | booktitle={Conference on Robot Learning (CoRL)}, 219 | year={2021}, 220 | url={https://openreview.net/forum?id=U0Q8CrtBJxJ} 221 | } 222 | ``` 223 | 224 | 225 | [pick_and_place_paper]: http://arxiv.org/abs/2110.06192 226 | -------------------------------------------------------------------------------- /assets/saved_models/mpo_state_rgb_test_triplet1/saved_model.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet1/saved_model.pb -------------------------------------------------------------------------------- /assets/saved_models/mpo_state_rgb_test_triplet1/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet1/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /assets/saved_models/mpo_state_rgb_test_triplet1/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet1/variables/variables.index -------------------------------------------------------------------------------- /assets/saved_models/mpo_state_rgb_test_triplet2/saved_model.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet2/saved_model.pb -------------------------------------------------------------------------------- /assets/saved_models/mpo_state_rgb_test_triplet2/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet2/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /assets/saved_models/mpo_state_rgb_test_triplet2/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet2/variables/variables.index -------------------------------------------------------------------------------- /assets/saved_models/mpo_state_rgb_test_triplet3/saved_model.pb: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet3/saved_model.pb -------------------------------------------------------------------------------- /assets/saved_models/mpo_state_rgb_test_triplet3/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet3/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /assets/saved_models/mpo_state_rgb_test_triplet3/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet3/variables/variables.index -------------------------------------------------------------------------------- /assets/saved_models/mpo_state_rgb_test_triplet4/saved_model.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet4/saved_model.pb -------------------------------------------------------------------------------- /assets/saved_models/mpo_state_rgb_test_triplet4/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet4/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /assets/saved_models/mpo_state_rgb_test_triplet4/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet4/variables/variables.index -------------------------------------------------------------------------------- /assets/saved_models/mpo_state_rgb_test_triplet5/saved_model.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet5/saved_model.pb -------------------------------------------------------------------------------- /assets/saved_models/mpo_state_rgb_test_triplet5/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet5/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /assets/saved_models/mpo_state_rgb_test_triplet5/variables/variables.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet5/variables/variables.index 
-------------------------------------------------------------------------------- /doc/images/rgb_environment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/doc/images/rgb_environment.png -------------------------------------------------------------------------------- /real_cell_documentation/basket_and_camera_assembly.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/real_cell_documentation/basket_and_camera_assembly.pdf -------------------------------------------------------------------------------- /real_cell_documentation/bill_of_materials.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/real_cell_documentation/bill_of_materials.pdf -------------------------------------------------------------------------------- /real_cell_documentation/cell_assembly.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/real_cell_documentation/cell_assembly.pdf -------------------------------------------------------------------------------- /real_cell_documentation/robot_assembly.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/real_cell_documentation/robot_assembly.pdf -------------------------------------------------------------------------------- /real_cell_documentation/standard_cell_generic_setup.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/real_cell_documentation/standard_cell_generic_setup.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dm-control == 0.0.364896371 2 | dm-robotics-moma == 0.0.4 3 | dm-robotics-manipulation == 0.0.4 4 | tensorflow == 2.7.0.rc0 5 | dm-reverb-nightly == 0.5.0.dev20211104 6 | opencv-python >= 3.4.0 7 | typing-extensions >= 3.7.4 8 | -------------------------------------------------------------------------------- /rgb_stacking/environment.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Builds RGB stacking environment. 
17 | 18 | This file builds the RGB stacking environment used in the paper 19 | "Beyond Pick-and-Place: Tackling robotic stacking of diverse shapes". 20 | The environment is composed of a robot with a parallel gripper. In front 21 | of the robot there is a basket containing 3 objects, one red, one green and one 22 | blue. The goal is for the robot to stack the red object on top of the blue one. 23 | 24 | In this specific file, we build the interface that is exposed to the agent, 25 | namely the action spec, the observation spec along with the reward. 26 | """ 27 | 28 | import enum 29 | from typing import Sequence 30 | 31 | from dm_robotics import agentflow as af 32 | from dm_robotics.agentflow.preprocessors import observation_transforms 33 | from dm_robotics.agentflow.preprocessors import timestep_preprocessor as tsp 34 | from dm_robotics.agentflow.subtasks import subtask_termination 35 | from dm_robotics.manipulation.props.rgb_objects import rgb_object 36 | from dm_robotics.moma import action_spaces 37 | from dm_robotics.moma import subtask_env 38 | from dm_robotics.moma import subtask_env_builder 39 | from dm_robotics.moma.utils import mujoco_collisions 40 | import numpy as np 41 | 42 | from rgb_stacking import reward_functions 43 | from rgb_stacking import stack_rewards 44 | from rgb_stacking import task 45 | 46 | 47 | # The environments provides stacked observations to the agent. The data of the 48 | # previous n steps is stacked and provided as a single observation to the agent. 49 | # We stack different observations a different number of times. 50 | _OBSERVATION_STACK_DEPTH = 3 51 | _ACTION_OBS_STACK_DEPTH = 2 52 | 53 | # Number of steps in an episode 54 | _MAX_STEPS = 400 55 | 56 | # Timestep of the physics simulation. 57 | _PHYSICS_TIMESTEP = 0.0005 58 | 59 | 60 | _STATE_OBSERVATIONS = ( 61 | 'action/environment', 62 | 'gripper/grasp', 63 | 'gripper/joints/angle', 64 | 'gripper/joints/velocity', 65 | 'rgb30_blue/abs_pose', 66 | 'rgb30_blue/to_pinch', 67 | 'rgb30_green/abs_pose', 68 | 'rgb30_green/to_pinch', 69 | 'rgb30_red/abs_pose', 70 | 'rgb30_red/to_pinch', 71 | 'sawyer/joints/angle', 72 | 'sawyer/joints/torque', 73 | 'sawyer/joints/velocity', 74 | 'sawyer/pinch/pose', 75 | 'sawyer/tcp/pose', 76 | 'sawyer/tcp/velocity', 77 | 'wrist/force', 78 | 'wrist/torque' 79 | ) 80 | 81 | _VISION_OBSERVATIONS = ( 82 | 'basket_back_left/pixels', 83 | 'basket_front_left/pixels', 84 | 'basket_front_right/pixels' 85 | ) 86 | 87 | # For interactive imitation learning, the vision-based policy used the 88 | # following proprioception and images observations from the pair of front 89 | # cameras given by the simulated environment. 90 | _INTERACTIVE_IMITATION_LEARNING_OBSERVATIONS = ( 91 | 'sawyer/joints/angle', 92 | 'gripper/joints/angle', 93 | 'sawyer/pinch/pose', 94 | 'sawyer/tcp/pose', 95 | 'basket_front_left/pixels', 96 | 'basket_front_right/pixels' 97 | ) 98 | 99 | 100 | # For the one-step offline policy improvement from real data, the vision-based 101 | # policy used the following proprioception and images observations from the pair 102 | # of front cameras given by the real environment. 
103 | _OFFLINE_POLICY_IMPROVEMENT_OBSERVATIONS = [ 104 | 'sawyer/joints/angle', 105 | 'sawyer/joints/velocity', 106 | 'gripper/grasp', 107 | 'gripper/joints/angle', 108 | 'gripper/joints/velocity', 109 | 'sawyer/pinch/pose', 110 | 'basket_front_left/pixels', 111 | 'basket_front_right/pixels',] 112 | 113 | 114 | class ObservationSet(int, enum.Enum): 115 | """Different possible set of observations that can be exposed.""" 116 | 117 | _observations: Sequence[str] 118 | 119 | def __new__(cls, value: int, observations: Sequence[str]): 120 | obj = int.__new__(cls, value) 121 | obj._value_ = value 122 | obj._observations = observations 123 | return obj 124 | 125 | @property 126 | def observations(self): 127 | return self._observations 128 | 129 | STATE_ONLY = (0, _STATE_OBSERVATIONS) 130 | VISION_ONLY = (1, _VISION_OBSERVATIONS) 131 | ALL = (2, _STATE_OBSERVATIONS + _VISION_OBSERVATIONS) 132 | INTERACTIVE_IMITATION_LEARNING = ( 133 | 3, _INTERACTIVE_IMITATION_LEARNING_OBSERVATIONS) 134 | OFFLINE_POLICY_IMPROVEMENT = (4, _OFFLINE_POLICY_IMPROVEMENT_OBSERVATIONS) 135 | 136 | 137 | def rgb_stacking( 138 | object_triplet: str = 'rgb_test_random', 139 | observation_set: ObservationSet = ObservationSet.STATE_ONLY, 140 | use_sparse_reward: bool = False 141 | ) -> subtask_env.SubTaskEnvironment: 142 | """Returns the environment. 143 | 144 | The relevant groups can be found here: 145 | https://github.com/deepmind/robotics/blob/main/py/manipulation/props/rgb_objects/rgb_object.py 146 | 147 | The valid object triplets can be found under PROP_TRIPLETS in the file. 148 | 149 | Args: 150 | object_triplet: Triplet of RGB objects to use in the environment. 151 | observation_set: Set of observations that en environment should expose. 152 | use_sparse_reward: If true will use sparse reward, which is 1 if the objects 153 | stacked and not touching the robot, and 0 otherwise. 154 | """ 155 | 156 | red_id, green_id, blue_id = rgb_object.PROP_TRIPLETS[object_triplet][1] 157 | rgb_task = task.rgb_task(red_id, green_id, blue_id) 158 | rgb_task.physics_timestep = _PHYSICS_TIMESTEP 159 | 160 | # To speed up simulation we ensure that mujoco will no check contact between 161 | # geoms that cannot collide. 162 | mujoco_collisions.exclude_bodies_based_on_contype_conaffinity( 163 | rgb_task.root_entity.mjcf_model) 164 | 165 | # Build the agent flow subtask. This is where the task logic is defined, 166 | # observations, and rewards. 167 | env_builder = subtask_env_builder.SubtaskEnvBuilder() 168 | env_builder.set_task(rgb_task) 169 | task_env = env_builder.build_base_env() 170 | 171 | # Define the action space, this is used to expose the actuators used in the 172 | # base task. 173 | effectors_action_spec = rgb_task.effectors_action_spec( 174 | physics=task_env.physics) 175 | robot_action_spaces = [] 176 | for rbt in rgb_task.robots: 177 | arm_action_space = action_spaces.ArmJointActionSpace( 178 | af.prefix_slicer(effectors_action_spec, rbt.arm_effector.prefix)) 179 | gripper_action_space = action_spaces.GripperActionSpace( 180 | af.prefix_slicer(effectors_action_spec, rbt.gripper_effector.prefix)) 181 | robot_action_spaces.extend([arm_action_space, gripper_action_space]) 182 | 183 | composite_action_space = af.CompositeActionSpace( 184 | robot_action_spaces) 185 | env_builder.set_action_space(composite_action_space) 186 | 187 | # Cast all the floating point observations to float32. 
188 | env_builder.add_preprocessor( 189 | observation_transforms.DowncastFloatPreprocessor(np.float32)) 190 | 191 | # Concatenate the TCP and wrist site observations. 192 | env_builder.add_preprocessor(observation_transforms.MergeObservations( 193 | obs_to_merge=['robot0_tcp_pos', 'robot0_tcp_quat'], 194 | new_obs='robot0_tcp_pose')) 195 | env_builder.add_preprocessor(observation_transforms.MergeObservations( 196 | obs_to_merge=['robot0_wrist_site_pos', 'robot0_wrist_site_quat'], 197 | new_obs='robot0_wrist_site_pose')) 198 | 199 | # Add in observations to measure the distance from the TCP to the objects. 200 | for color in ('red', 'green', 'blue'): 201 | env_builder.add_preprocessor(observation_transforms.AddObservation( 202 | obs_name=f'{color}_to_pinch', 203 | obs_callable=_distance_delta_obs( 204 | f'rgb_object_{color}_pose', 'robot0_tcp_pose'))) 205 | 206 | # Concatenate the action sent to the robot joints and the gripper actuator. 207 | env_builder.add_preprocessor(observation_transforms.MergeObservations( 208 | obs_to_merge=['robot0_arm_joint_previous_action', 209 | 'robot0_gripper_previous_action'], 210 | new_obs='robot0_previous_action')) 211 | 212 | # Mapping of observation names to match the observation names in the stored 213 | # data. 214 | obs_mapping = { 215 | 'robot0_arm_joint_pos': 'sawyer/joints/angle', 216 | 'robot0_arm_joint_vel': 'sawyer/joints/velocity', 217 | 'robot0_arm_joint_torques': 'sawyer/joints/torque', 218 | 'robot0_tcp_pose': 'sawyer/pinch/pose', 219 | 'robot0_wrist_site_pose': 'sawyer/tcp/pose', 220 | 'robot0_wrist_site_vel_world': 'sawyer/tcp/velocity', 221 | 'robot0_gripper_pos': 'gripper/joints/angle', 222 | 'robot0_gripper_vel': 'gripper/joints/velocity', 223 | 'robot0_gripper_grasp': 'gripper/grasp', 224 | 'robot0_wrist_force': 'wrist/force', 225 | 'robot0_wrist_torque': 'wrist/torque', 226 | 'rgb_object_red_pose': 'rgb30_red/abs_pose', 227 | 'rgb_object_green_pose': 'rgb30_green/abs_pose', 228 | 'rgb_object_blue_pose': 'rgb30_blue/abs_pose', 229 | 'basket_back_left_rgb_img': 'basket_back_left/pixels', 230 | 'basket_front_left_rgb_img': 'basket_front_left/pixels', 231 | 'basket_front_right_rgb_img': 'basket_front_right/pixels', 232 | 'red_to_pinch': 'rgb30_red/to_pinch', 233 | 'blue_to_pinch': 'rgb30_blue/to_pinch', 234 | 'green_to_pinch': 'rgb30_green/to_pinch', 235 | 'robot0_previous_action': 'action/environment', 236 | } 237 | 238 | # Create different subsets of observations. 239 | action_obs = {'action/environment'} 240 | 241 | # These observations only have a single floating point value instead of an 242 | # array. 243 | single_value_obs = {'gripper/joints/angle', 244 | 'gripper/joints/velocity', 245 | 'gripper/grasp'} 246 | 247 | # Rename observations. 248 | env_builder.add_preprocessor(observation_transforms.RenameObservations( 249 | obs_mapping, raise_on_missing=False)) 250 | 251 | if use_sparse_reward: 252 | reward_fn = stack_rewards.get_sparse_reward_fn( 253 | top_object=rgb_task.props[0], 254 | bottom_object=rgb_task.props[2], 255 | get_physics_fn=lambda: task_env.physics) 256 | else: 257 | reward_fn = stack_rewards.get_shaped_stacking_reward() 258 | env_builder.add_preprocessor(reward_functions.RewardPreprocessor(reward_fn)) 259 | 260 | # We concatenate several observations from consecutive timesteps. Depending 261 | # on the observations, we will concatenate a different number of observations. 262 | # - Most observations are stacked 3 times 263 | # - Camera observations are not stacked. 264 | # - The action observation is stacked twice. 
265 | # - When stacking three scalar (i.e. numpy array of shape (1,)) observations, 266 | # we do not add a leading dimension, so the final shape is (3,). 267 | env_builder.add_preprocessor( 268 | observation_transforms.StackObservations( 269 | obs_to_stack=list( 270 | set(_STATE_OBSERVATIONS) - action_obs - single_value_obs), 271 | stack_depth=_OBSERVATION_STACK_DEPTH, 272 | add_leading_dim=True)) 273 | env_builder.add_preprocessor( 274 | observation_transforms.StackObservations( 275 | obs_to_stack=list(single_value_obs), 276 | stack_depth=_OBSERVATION_STACK_DEPTH, 277 | add_leading_dim=False)) 278 | env_builder.add_preprocessor( 279 | observation_transforms.StackObservations( 280 | obs_to_stack=list(action_obs), 281 | stack_depth=_ACTION_OBS_STACK_DEPTH, 282 | add_leading_dim=True)) 283 | 284 | # Only keep the obseravtions that we want to expose to the agent. 285 | env_builder.add_preprocessor(observation_transforms.RetainObservations( 286 | observation_set.observations, raise_on_missing=False)) 287 | 288 | # End episodes after 400 steps. 289 | env_builder.add_preprocessor( 290 | subtask_termination.MaxStepsTermination(_MAX_STEPS)) 291 | 292 | return env_builder.build() 293 | 294 | 295 | def _distance_delta_obs(key1: str, key2: str): 296 | """Returns a callable that returns the difference between two observations.""" 297 | def util(timestep: tsp.PreprocessorTimestep) -> np.ndarray: 298 | return timestep.observation[key1] - timestep.observation[key2] 299 | return util 300 | 301 | -------------------------------------------------------------------------------- /rgb_stacking/environment_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests environment.py.""" 17 | 18 | from absl.testing import absltest 19 | from dm_env import test_utils 20 | 21 | from rgb_stacking import environment 22 | 23 | 24 | class EnvironmentTest(test_utils.EnvironmentTestMixin, absltest.TestCase): 25 | 26 | def make_object_under_test(self): 27 | return environment.rgb_stacking( 28 | object_triplet='rgb_test_triplet1', 29 | observation_set=environment.ObservationSet.ALL) 30 | 31 | 32 | if __name__ == '__main__': 33 | absltest.main() 34 | -------------------------------------------------------------------------------- /rgb_stacking/main.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Script to run a viewer to visualize the rgb stacking environment.""" 17 | 18 | from typing import Sequence 19 | 20 | from absl import app 21 | from absl import flags 22 | from absl import logging 23 | import cv2 24 | from dm_control import viewer 25 | from dm_robotics.manipulation.props.rgb_objects import rgb_object 26 | from dm_robotics.moma import subtask_env 27 | import numpy as np 28 | 29 | from rgb_stacking import environment 30 | from rgb_stacking.utils import policy_loading 31 | 32 | _TEST_OBJECT_TRIPLETS = tuple(rgb_object.PROP_TRIPLETS_TEST.keys()) 33 | 34 | _ALL_OBJECT_TRIPLETS = _TEST_OBJECT_TRIPLETS + ( 35 | 'rgb_train_random', 36 | 'rgb_test_random', 37 | ) 38 | 39 | _POLICY_DIR = ('assets/saved_models') 40 | 41 | _POLICY_PATHS = { 42 | k: f'{_POLICY_DIR}/mpo_state_{k}' for k in _TEST_OBJECT_TRIPLETS 43 | } 44 | 45 | _OBJECT_TRIPLET = flags.DEFINE_enum( 46 | 'object_triplet', 'rgb_test_random', _ALL_OBJECT_TRIPLETS, 47 | 'Triplet of RGB objects to use in the environment.') 48 | _POLICY_OBJECT_TRIPLET = flags.DEFINE_enum( 49 | 'policy_object_triplet', None, _TEST_OBJECT_TRIPLETS, 50 | 'Optional test triplet name indicating to load a policy that was trained on' 51 | ' this triplet.') 52 | _LAUNCH_VIEWER = flags.DEFINE_bool( 53 | 'launch_viewer', True, 54 | 'Optional boolean. If True, will launch the dm_control viewer. If False' 55 | ' will load the policy, run it and save a recording of it as an .mp4.') 56 | 57 | 58 | def run_episode_and_render( 59 | env: subtask_env.SubTaskEnvironment, 60 | policy: policy_loading.Policy 61 | ) -> Sequence[np.ndarray]: 62 | """Saves a gif of the policy running against the environment.""" 63 | rendered_images = [] 64 | logging.info('Starting the rendering of the policy, this might take some' 65 | ' time...') 66 | state = policy.initial_state() 67 | timestep = env.reset() 68 | rendered_images.append(env.physics.render(camera_id='main_camera')) 69 | while not timestep.last(): 70 | (action, _), state = policy.step(timestep, state) 71 | timestep = env.step(action) 72 | rendered_images.append(env.physics.render(camera_id='main_camera')) 73 | logging.info('Done rendering!') 74 | return rendered_images 75 | 76 | 77 | def main(argv: Sequence[str]) -> None: 78 | 79 | del argv 80 | 81 | if not _LAUNCH_VIEWER.value and _POLICY_OBJECT_TRIPLET.value is None: 82 | raise ValueError('To record a video, a policy must be given.') 83 | 84 | # Load the rgb stacking environment. 85 | with environment.rgb_stacking(object_triplet=_OBJECT_TRIPLET.value) as env: 86 | 87 | # Optionally load a policy trained on one of these environments. 88 | if _POLICY_OBJECT_TRIPLET.value is not None: 89 | policy_path = _POLICY_PATHS[_POLICY_OBJECT_TRIPLET.value] 90 | policy = policy_loading.policy_from_path(policy_path) 91 | else: 92 | policy = None 93 | 94 | if _LAUNCH_VIEWER.value: 95 | # The viewer requires a callable as a policy. 
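      # (Note: StatefulPolicyCallable is assumed to adapt the stateful
      # `(timestep, state) -> ((action, extras), state)` interface used in
      # `run_episode_and_render` above into the plain `timestep -> action`
      # callable expected by the dm_control viewer, keeping the recurrent
      # state internally between calls.)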
96 | if policy is not None: 97 | policy = policy_loading.StatefulPolicyCallable(policy) 98 | viewer.launch(env, policy=policy) 99 | else: 100 | 101 | # Render the episode. 102 | rendered_episode = run_episode_and_render(env, policy) 103 | 104 | # Save as mp4 video in current directory. 105 | height, width, _ = rendered_episode[0].shape 106 | out = cv2.VideoWriter( 107 | './rendered_policy.mp4', 108 | cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), 109 | 1.0 / env.task.control_timestep, (width, height)) 110 | for image in rendered_episode: 111 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 112 | out.write(image) 113 | out.release() 114 | 115 | if __name__ == '__main__': 116 | app.run(main) 117 | -------------------------------------------------------------------------------- /rgb_stacking/physics_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """This file contains various helper functions for finding mujoco collisions. 17 | """ 18 | 19 | import itertools 20 | 21 | DEFAULT_OBJECT_COLLISION_MARGIN = 0.0002 22 | DEFAULT_COLLISION_MARGIN = 1e-8 23 | 24 | OBJECT_GEOM_PREFIXES = ['rgb'] 25 | GROUND_GEOM_PREFIXES = ['work_surface', 'ground'] 26 | ROBOT_GEOM_PREFIXES = ['robot'] 27 | 28 | 29 | def has_object_collision(physics, collision_geom_prefix, 30 | margin=DEFAULT_OBJECT_COLLISION_MARGIN): 31 | """Check for collisions between geoms and objects.""" 32 | return has_collision( 33 | physics=physics, 34 | collision_geom_prefix_1=[collision_geom_prefix], 35 | collision_geom_prefix_2=OBJECT_GEOM_PREFIXES, 36 | margin=margin) 37 | 38 | 39 | def has_ground_collision(physics, collision_geom_prefix, 40 | margin=DEFAULT_COLLISION_MARGIN): 41 | """Check for collisions between geoms and the ground.""" 42 | return has_collision( 43 | physics=physics, 44 | collision_geom_prefix_1=[collision_geom_prefix], 45 | collision_geom_prefix_2=GROUND_GEOM_PREFIXES, 46 | margin=margin) 47 | 48 | 49 | def has_robot_collision(physics, collision_geom_prefix, 50 | margin=DEFAULT_COLLISION_MARGIN): 51 | """Check for collisions between geoms and the robot.""" 52 | return has_collision( 53 | physics=physics, 54 | collision_geom_prefix_1=[collision_geom_prefix], 55 | collision_geom_prefix_2=ROBOT_GEOM_PREFIXES, 56 | margin=margin) 57 | 58 | 59 | def has_collision(physics, collision_geom_prefix_1, collision_geom_prefix_2, 60 | margin=DEFAULT_COLLISION_MARGIN): 61 | """Check for collisions between geoms.""" 62 | for contact in physics.data.contact: 63 | if contact.dist > margin: 64 | continue 65 | geom1_name = physics.model.id2name(contact.geom1, 'geom') 66 | geom2_name = physics.model.id2name(contact.geom2, 'geom') 67 | for pair in itertools.product( 68 | collision_geom_prefix_1, collision_geom_prefix_2): 69 | if ((geom1_name.startswith(pair[0]) and 70 | 
geom2_name.startswith(pair[1])) or 71 | (geom2_name.startswith(pair[0]) and 72 | geom1_name.startswith(pair[1]))): 73 | return True 74 | return False 75 | -------------------------------------------------------------------------------- /rgb_stacking/reward_functions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Generic reward functions.""" 17 | 18 | from typing import Callable, Optional, Iterable 19 | from absl import logging 20 | from dm_robotics import agentflow as af 21 | from dm_robotics.agentflow import spec_utils 22 | import numpy as np 23 | 24 | # Value returned by the gripper grasp observation when the gripper is in the 25 | # 'grasp' state (fingers closed and exerting force) 26 | _INWARD_GRASP = 2 27 | 28 | # Minimal value for the position tolerance of the shaped distance rewards. 29 | MINIMUM_POSITION_TOLERANCE = 1e-9 30 | 31 | RewardFunction = Callable[[spec_utils.ObservationValue], float] 32 | 33 | 34 | class RewardPreprocessor(af.TimestepPreprocessor): 35 | """Timestep preprocessor wrapper around a reward function.""" 36 | 37 | def __init__(self, reward_function: RewardFunction): 38 | super().__init__() 39 | self._reward_function = reward_function 40 | 41 | def _process_impl( 42 | self, timestep: af.PreprocessorTimestep) -> af.PreprocessorTimestep: 43 | reward = self._reward_function(timestep.observation) 44 | reward = self._out_spec.reward_spec.dtype.type(reward) 45 | return timestep.replace(reward=reward) 46 | 47 | def _output_spec( 48 | self, input_spec: spec_utils.TimeStepSpec) -> spec_utils.TimeStepSpec: 49 | return input_spec 50 | 51 | 52 | class GraspReward: 53 | """Sparse reward for the gripper grasp status.""" 54 | 55 | def __init__(self, obs_key: str): 56 | """Creates new GraspReward function. 57 | 58 | Args: 59 | obs_key: Key of the grasp observation in the observation spec. 60 | """ 61 | self._obs_key = obs_key 62 | 63 | def __call__(self, obs: spec_utils.ObservationValue): 64 | is_grasped = obs[self._obs_key][0] == _INWARD_GRASP 65 | return float(is_grasped) 66 | 67 | 68 | class StackPair: 69 | """Get two objects to be above each other. 70 | 71 | To determine the expected position the top object should be at, the size of 72 | the objects must be specified. Currently, only objects with 3-axial symmetry 73 | are supported; for other objects, the distance would impose a constraint on 74 | the possible orientations. 75 | 76 | Reward will be given if the top object's center is within a certain distance 77 | of the point above the bottom object, with the distance tolerance being 78 | separately configurable for the horizontal plane and the vertical axis. 79 | 80 | By default, reward is also only given when the robot is not currently grasping 81 | anything, though this can be deactivated. 
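  For example (illustrative sizes): with `obj_top_size = obj_bottom_size = 0.05`,
  the top object's centre is expected to sit 0.05 m above the bottom object's
  centre ((0.05 + 0.05) / 2). A reward of 1.0 is then returned when the
  horizontal offset and the deviation from that expected height are within
  `horizontal_tolerance` and `vertical_tolerance` respectively, and the gripper
  is not reporting an inward grasp.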
82 | """ 83 | 84 | def __init__(self, 85 | obj_top: str, 86 | obj_bottom: str, 87 | obj_top_size: float, 88 | obj_bottom_size: float, 89 | horizontal_tolerance: float, 90 | vertical_tolerance: float, 91 | grasp_key: str = "gripper/grasp"): 92 | """Initialize the module. 93 | 94 | Args: 95 | obj_top: Key in the observation dict containing the top object's 96 | position. 97 | obj_bottom: Key in the observation dict containing the bottom 98 | object's position. 99 | obj_top_size: Height of the top object along the vertical axis, in meters. 100 | obj_bottom_size: Height of the bottom object along the vertical axis, in 101 | meters. 102 | horizontal_tolerance: Max distance from the exact stack position on the 103 | horizontal plane to still generate reward. 104 | vertical_tolerance: Max distance from the exact stack position along the 105 | vertical axis to still generate reward. 106 | grasp_key: Key in the observation dict containing the (buffered) grasp 107 | status. Can be set to `None` to not check the grasp status to return 108 | a reward. 109 | """ 110 | self._top_key = obj_top 111 | self._bottom_key = obj_bottom 112 | self._horizontal_tolerance = horizontal_tolerance 113 | self._vertical_tolerance = vertical_tolerance 114 | self._grasp_key = grasp_key 115 | 116 | if obj_top_size <= 0. or obj_bottom_size <= 0.: 117 | raise ValueError("Object sizes cannot be zero.") 118 | self._expected_dist = (obj_top_size + obj_bottom_size) / 2. 119 | 120 | def __call__(self, obs: spec_utils.ObservationValue): 121 | top = obs[self._top_key] 122 | bottom = obs[self._bottom_key] 123 | 124 | horizontal_dist = np.linalg.norm(top[:2] - bottom[:2]) 125 | if horizontal_dist > self._horizontal_tolerance: 126 | return 0. 127 | 128 | vertical_dist = top[2] - bottom[2] 129 | if np.abs(vertical_dist - self._expected_dist) > self._vertical_tolerance: 130 | return 0. 131 | 132 | if self._grasp_key is not None: 133 | grasp = obs[self._grasp_key] 134 | if grasp == _INWARD_GRASP: 135 | return 0. 136 | 137 | return 1. 138 | 139 | 140 | def tanh_squared(x: np.ndarray, margin: float, loss_at_margin: float = 0.95): 141 | """Returns a sigmoidal shaping loss based on Hafner & Reidmiller (2011). 142 | 143 | Args: 144 | x: A numpy array representing the error. 145 | margin: Margin parameter, a positive `float`. 146 | loss_at_margin: The loss when `l2_norm(x) == margin`. A `float` between 0 147 | and 1. 148 | 149 | Returns: 150 | Shaping loss, a `float` bounded in the half-open interval [0, 1). 151 | 152 | Raises: 153 | ValueError: If the value of `margin` or `loss_at_margin` is invalid. 154 | """ 155 | 156 | if not margin > 0: 157 | raise ValueError("`margin` must be positive.") 158 | if not 0.0 < loss_at_margin < 1.0: 159 | raise ValueError("`loss_at_margin` must be between 0 and 1.") 160 | 161 | error = np.linalg.norm(x) 162 | # Compute weight such that at the margin tanh(w * error) = loss_at_margin 163 | w = np.arctanh(np.sqrt(loss_at_margin)) / margin 164 | s = np.tanh(w * error) 165 | return s * s 166 | 167 | 168 | class DistanceReward: 169 | """Shaped reward based on the distance B-A between two entities A and B.""" 170 | 171 | def __init__( 172 | self, 173 | key_a: str, 174 | key_b: Optional[str], 175 | position_tolerance: Optional[np.ndarray] = None, 176 | shaping_tolerance: float = 0.1, 177 | loss_at_tolerance: float = 0.95, 178 | max_reward: float = 1., 179 | offset: Optional[np.ndarray] = None, 180 | z_min: Optional[float] = None, 181 | dim=3 182 | ): 183 | """Initialize the module. 
184 | 185 | Args: 186 | key_a: Observation dict key to numpy array containing the position of 187 | object A. 188 | key_b: None or observation dict key to numpy array containing the position 189 | of object B. If None, distance simplifies to d = offset - A. 190 | position_tolerance: Vector of length `dim`. If 191 | `distance/position_tolerance < 1`, will return `maximum_reward` 192 | instead of shaped one. Setting this to `None`, or setting any entry 193 | to zero or close to zero, will effectively disable tolerance. 194 | shaping_tolerance: Scalar distance at which the loss is equal to 195 | `loss_at_tolerance`. Must be a positive float or `None`. If `None` 196 | reward is sparse and hence 0 is returned if 197 | `distance > position_tolerance`. 198 | loss_at_tolerance: The loss when `l2_norm(distance) == shaping_tolerance`. 199 | A `float` between 0 and 1. 200 | max_reward: Reward to return when `distance/position_tolerance < 1`. 201 | offset: Vector of length 3 that is added to the distance, i.e. 202 | `distance = B - A + offset`. 203 | z_min: Absolute object height that the object A center has to above be in 204 | order to generate reward. Used for example in hovering rewards. 205 | dim: The dimensionality of the space in which the distance is computed 206 | """ 207 | self._key_a = key_a 208 | self._key_b = key_b 209 | 210 | self._shaping_tolerance = shaping_tolerance 211 | self._loss_at_tolerance = loss_at_tolerance 212 | if max_reward < 1.: 213 | logging.warning("Maximum reward should not be below tanh maximum.") 214 | self._max_reward = max_reward 215 | self._z_min = z_min 216 | self._dim = dim 217 | 218 | if position_tolerance is None: 219 | self._position_tolerance = np.full( 220 | (dim,), fill_value=MINIMUM_POSITION_TOLERANCE) 221 | else: 222 | self._position_tolerance = position_tolerance 223 | self._position_tolerance[self._position_tolerance == 0] = ( 224 | MINIMUM_POSITION_TOLERANCE) 225 | 226 | if offset is None: 227 | self._offset = np.zeros((dim,)) 228 | else: 229 | self._offset = offset 230 | 231 | def __call__(self, obs: spec_utils.ObservationValue) -> float: 232 | 233 | # Check that object A is high enough before computing the reward. 234 | if self._z_min is not None and obs[self._key_a][2] < self._z_min: 235 | return 0. 236 | 237 | self._current_distance = (self._offset - obs[self._key_a][0:self._dim]) 238 | if self._key_b is not None: 239 | self._current_distance += obs[self._key_b][0:self._dim] 240 | 241 | weighted = self._current_distance / self._position_tolerance 242 | if np.linalg.norm(weighted) <= 1.: 243 | return self._max_reward 244 | 245 | if not self._shaping_tolerance: 246 | return 0. 247 | 248 | loss = tanh_squared( 249 | self._current_distance, margin=self._shaping_tolerance, 250 | loss_at_margin=self._loss_at_tolerance) 251 | return 1.0 - loss 252 | 253 | 254 | class LiftShaped: 255 | """Linear shaped reward for lifting, up to a specified height. 256 | 257 | Once the height is above a specified threshold, reward saturates. Shaping can 258 | also be deactivated for a sparse reward. 259 | 260 | Requires an observation `//abs_pose` containing the 261 | pose of the object in question. 262 | """ 263 | 264 | def __init__( 265 | self, 266 | obj_key, 267 | z_threshold, 268 | z_min, 269 | max_reward=1., 270 | shaping=True 271 | ): 272 | """Initialize the module. 273 | 274 | Args: 275 | obj_key: Key in the observation dict containing the object pose. 276 | z_threshold: Absolute object height at which the maximum reward will be 277 | given. 
278 | z_min: Absolute object height that the object center has to above be in 279 | order to generate shaped reward. Ignored if `shaping` is False. 280 | max_reward: Reward given when the object is above the `z_threshold`. 281 | shaping: If true, will give a linear shaped reward when the object height 282 | is above `z_min`, but below `z_threshold`. 283 | Raises: 284 | ValueError: if `z_min` is larger than `z_threshold`. 285 | """ 286 | self._field = obj_key 287 | self._z_threshold = z_threshold 288 | self._shaping = shaping 289 | self._max_reward = max_reward 290 | self._z_min = z_min 291 | if z_min > z_threshold: 292 | raise ValueError("Lower shaping bound cannot be below upper bound.") 293 | 294 | def __call__(self, obs: spec_utils.ObservationValue) -> float: 295 | obj_z = obs[self._field][2] 296 | if obj_z <= self._z_min: 297 | return 0.0 298 | if obj_z >= self._z_threshold: 299 | return self._max_reward 300 | if self._shaping: 301 | r = (obj_z - self._z_min) / (self._z_threshold - self._z_min) 302 | return r 303 | return 0.0 304 | 305 | 306 | class Product: 307 | """Computes the product of a set of rewards.""" 308 | 309 | def __init__(self, terms: Iterable[RewardFunction]): 310 | self._terms = terms 311 | 312 | def __call__(self, obs: spec_utils.ObservationValue) -> float: 313 | r = 1. 314 | for term in self._terms: 315 | r *= term(obs) 316 | return r 317 | 318 | 319 | class _WeightedTermReward: 320 | """Base class for rewards using lists of weighted terms.""" 321 | 322 | def __init__(self, 323 | terms: Iterable[RewardFunction], 324 | weights: Optional[Iterable[float]] = None): 325 | """Initialize the reward instance. 326 | 327 | Args: 328 | terms: List of reward callables to be operated on. Each callable must 329 | take an observation as input, and return a float. 330 | weights: Weight that each reward returned by the callables in `terms` will 331 | be multiplied by. If `None`, will weight all terms with 1.0. 332 | Raises: 333 | ValueError: If `weights` has been specified, but its length differs from 334 | that of `terms`. 335 | """ 336 | self._terms = list(terms) 337 | self._weights = weights or [1.] * len(self._terms) 338 | if len(self._weights) != len(self._terms): 339 | raise ValueError("Number of terms and weights should be same.") 340 | 341 | def _weighted_terms( 342 | self, obs: spec_utils.ObservationValue) -> Iterable[float]: 343 | return [t(obs) * w for t, w in zip(self._terms, self._weights)] 344 | 345 | 346 | class Max(_WeightedTermReward): 347 | """Selects the maximum among a number of weighted rewards.""" 348 | 349 | def __call__(self, obs: spec_utils.ObservationValue) -> float: 350 | return max(self._weighted_terms(obs)) 351 | 352 | 353 | class ConditionalAnd: 354 | """Perform an and operation conditional on term1 exceeding a threshold.""" 355 | 356 | def __init__(self, 357 | term1: RewardFunction, 358 | term2: RewardFunction, 359 | threshold: float): 360 | self._term1 = term1 361 | self._term2 = term2 362 | self._thresh = threshold 363 | 364 | def __call__(self, obs: spec_utils.ObservationValue) -> float: 365 | r1 = self._term1(obs) 366 | r2 = self._term2(obs) 367 | if r1 > self._thresh: 368 | return (0.5 + r2 / 2.) * r1 369 | else: 370 | return r1 * 0.5 371 | 372 | 373 | class Staged: 374 | """Stages the rewards. 375 | 376 | This works by cycling through the terms backwards and using the last reward 377 | that gives a response above the provided threshold + the number of 378 | terms preceding it. 379 | 380 | Rewards must be in [0;1], otherwise they will be clipped. 
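For example (illustrative; `reach`, `lift` and `stack` stand for hypothetical reward
callables): `Staged([reach, lift, stack], threshold=0.1)` returns, for an observation where
`stack` gives 0.05 but `lift` gives 0.7, the value (3 - 2 + 0.7) / 3 ≈ 0.57. If no term
exceeds the threshold, the first term's clipped reward divided by the number of terms is
returned.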
381 | """ 382 | 383 | def __init__( 384 | self, terms: Iterable[RewardFunction], threshold: float): 385 | def make_clipped(term: RewardFunction): 386 | return lambda obs: np.clip(term(obs), 0., 1.) 387 | self._terms = [make_clipped(term) for term in terms] 388 | self._thresh = threshold 389 | 390 | def __call__(self, obs: spec_utils.ObservationValue) -> float: 391 | last_reward = 0. 392 | num_stages = float(len(self._terms)) 393 | for i, term in enumerate(reversed(self._terms)): 394 | last_reward = term(obs) 395 | if last_reward > self._thresh: 396 | # Found a reward above the threshold, add number of preceding terms 397 | # and normalize with the number of terms. 398 | return (len(self._terms) - (i + 1) + last_reward) / num_stages 399 | # Return the accumulated rewards. 400 | return last_reward / num_stages 401 | -------------------------------------------------------------------------------- /rgb_stacking/stack_rewards.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Shaped and sparse reward functions for the RGB stacking task.""" 17 | 18 | from typing import Callable 19 | 20 | from dm_control import mjcf 21 | from dm_robotics.agentflow import spec_utils 22 | from dm_robotics.moma import prop 23 | import numpy as np 24 | 25 | from rgb_stacking import physics_utils 26 | from rgb_stacking import reward_functions 27 | from rgb_stacking import task 28 | 29 | # Heights above basket surface between which shaping for Lift will be computed. 30 | _LIFT_Z_MAX_ABOVE_BASKET = 0.1 31 | _LIFT_Z_MIN_ABOVE_BASKET = 0.055 32 | 33 | # Height for lifting. 34 | _LIFTING_HEIGHT = 0.04 35 | 36 | # Cylinder around (bottom object + object height) in which the top object must 37 | # be for it to be considered stacked. 38 | _STACK_HORIZONTAL_TOLERANCE = 0.03 39 | _STACK_VERTICAL_TOLERANCE = 0.01 40 | 41 | # Target height above bottom object for hover-above reward. 42 | _HOVER_OFFSET = 0.1 43 | 44 | # Distances at which maximum reach reward is given and shaping decays. 45 | _REACH_POSITION_TOLERANCE = np.array((0.02, 0.02, 0.02)) 46 | _REACH_SHAPING_TOLERANCE = 0.15 47 | 48 | # Distances at which maximum hovering reward is given and shaping decays. 49 | _HOVER_POSITION_TOLERANCE = np.array((0.01, 0.01, 0.01)) 50 | _HOVER_SHAPING_TOLERANCE = 0.2 51 | 52 | # Observation keys needed to compute the different rewards. 53 | _PINCH_POSE = 'sawyer/pinch/pose' 54 | _FINGER_ANGLE = 'gripper/joints/angle' 55 | _GRIPPER_GRASP = 'gripper/grasp' 56 | _RED_OBJECT_POSE = 'rgb30_red/abs_pose' 57 | _BLUE_OBJECT_POSE = 'rgb30_blue/abs_pose' 58 | 59 | 60 | # Object closeness threshold for x and y axis. 
61 | _X_Y_CLOSE = 0.05 62 | _ONTOP = 0.02 63 | 64 | 65 | def get_shaped_stacking_reward() -> reward_functions.RewardFunction: 66 | """Returns a callable reward function for stacking the red block on blue.""" 67 | # # First stage: reach and grasp. 68 | reach_red = reward_functions.DistanceReward( 69 | key_a=_PINCH_POSE, 70 | key_b=_RED_OBJECT_POSE, 71 | shaping_tolerance=_REACH_SHAPING_TOLERANCE, 72 | position_tolerance=_REACH_POSITION_TOLERANCE) 73 | close_fingers = reward_functions.DistanceReward( 74 | key_a=_FINGER_ANGLE, 75 | key_b=None, 76 | position_tolerance=None, 77 | shaping_tolerance=255., 78 | loss_at_tolerance=0.95, 79 | max_reward=1., 80 | offset=np.array((255,)), 81 | dim=1) 82 | grasp = reward_functions.Max( 83 | terms=( 84 | close_fingers, reward_functions.GraspReward(obs_key=_GRIPPER_GRASP)), 85 | weights=(0.5, 1.)) 86 | reach_grasp = reward_functions.ConditionalAnd(reach_red, grasp, 0.9) 87 | 88 | # Second stage: grasp and lift. 89 | lift_reward = _get_reward_lift_red() 90 | lift_red = reward_functions.Product([grasp, lift_reward]) 91 | 92 | # Third stage: hover. 93 | top = _RED_OBJECT_POSE 94 | bottom = _BLUE_OBJECT_POSE 95 | place = reward_functions.DistanceReward( 96 | key_a=top, 97 | key_b=bottom, 98 | offset=np.array((0., 0., _LIFTING_HEIGHT)), 99 | position_tolerance=_HOVER_POSITION_TOLERANCE, 100 | shaping_tolerance=_HOVER_SHAPING_TOLERANCE, 101 | z_min=0.1) 102 | 103 | # Fourth stage: stack. 104 | stack = _get_reward_stack_red_on_blue() 105 | 106 | # Final stage: stack-and-leave 107 | stack_leave = reward_functions.Product( 108 | terms=(_get_reward_stack_red_on_blue(), _get_reward_above_red())) 109 | 110 | return reward_functions.Staged( 111 | [reach_grasp, lift_red, place, stack, stack_leave], 0.01) 112 | 113 | 114 | def get_sparse_stacking_reward() ->reward_functions.RewardFunction: 115 | """Sparse stacking reward for red-on-blue with the gripper moved away.""" 116 | return reward_functions.Product( 117 | terms=(_get_reward_stack_red_on_blue(), _get_reward_above_red())) 118 | 119 | 120 | def _get_reward_lift_red() -> reward_functions.RewardFunction: 121 | """Returns a callable reward function for lifting the red block.""" 122 | lift_reward = reward_functions.LiftShaped( 123 | obj_key=_RED_OBJECT_POSE, 124 | z_threshold=task.DEFAULT_BASKET_HEIGHT + _LIFT_Z_MAX_ABOVE_BASKET, 125 | z_min=task.DEFAULT_BASKET_HEIGHT + _LIFT_Z_MIN_ABOVE_BASKET) 126 | # Keep the object inside the area of the base plate. 127 | inside_basket = reward_functions.DistanceReward( 128 | key_a=_RED_OBJECT_POSE, 129 | key_b=None, 130 | position_tolerance=task.WORKSPACE_SIZE / 2., 131 | shaping_tolerance=1e-12, # Practically none. 
132 | loss_at_tolerance=0.95, 133 | max_reward=1., 134 | offset=task.WORKSPACE_CENTER) 135 | return reward_functions.Product([lift_reward, inside_basket]) 136 | 137 | 138 | def _get_reward_stack_red_on_blue() -> reward_functions.RewardFunction: 139 | """Returns a callable reward function for stacking the red block on blue.""" 140 | return reward_functions.StackPair( 141 | obj_top=_RED_OBJECT_POSE, 142 | obj_bottom=_BLUE_OBJECT_POSE, 143 | obj_top_size=_LIFTING_HEIGHT, 144 | obj_bottom_size=_LIFTING_HEIGHT, 145 | horizontal_tolerance=_STACK_HORIZONTAL_TOLERANCE, 146 | vertical_tolerance=_STACK_VERTICAL_TOLERANCE) 147 | 148 | 149 | def _get_reward_above_red() -> reward_functions.RewardFunction: 150 | """Returns a callable reward function for being above the red block.""" 151 | return reward_functions.DistanceReward( 152 | key_a=_PINCH_POSE, 153 | key_b=_RED_OBJECT_POSE, 154 | shaping_tolerance=0.05, 155 | offset=np.array((0., 0., _HOVER_OFFSET)), 156 | position_tolerance=np.array((1., 1., 0.03))) # Anywhere horizontally. 157 | 158 | 159 | class SparseStack(object): 160 | """Sparse stack reward. 161 | 162 | Checks that the two objects are within _X_Y_CLOSE of each other in the 163 | x-y plane, that the top object's center is more than _ONTOP above the 164 | bottom object's, and that the two objects are in contact. Also checks 165 | that the top object touches neither the ground nor the robot, to ensure the robot is not holding the objects in place. 166 | """ 167 | 168 | def __init__(self, 169 | top_object: prop.Prop, 170 | bottom_object: prop.Prop, 171 | get_physics_fn: Callable[[], mjcf.Physics]): 172 | """Initializes the reward. 173 | 174 | Args: 175 | top_object: Composer entity of the top object (red). 176 | bottom_object: Composer entity of the bottom object (blue). 177 | get_physics_fn: Callable that returns the current mjcf physics from the 178 | environment. 179 | """ 180 | self._get_physics_fn = get_physics_fn 181 | self._top_object = top_object 182 | self._bottom_object = bottom_object 183 | 184 | def _align(self, physics): 185 | return np.linalg.norm( 186 | self._top_object.get_pose(physics)[0][:2] - 187 | self._bottom_object.get_pose(physics)[0][:2]) < _X_Y_CLOSE 188 | 189 | def _ontop(self, physics): 190 | 191 | return (self._top_object.get_pose(physics)[0][2] - 192 | self._bottom_object.get_pose(physics)[0][2]) > _ONTOP 193 | 194 | def _pile(self, physics): 195 | geom = '{}/'.format(self._top_object.name) 196 | if physics_utils.has_ground_collision(physics, collision_geom_prefix=geom): 197 | return float(0.0) 198 | if physics_utils.has_robot_collision(physics, collision_geom_prefix=geom): 199 | return float(0.0) 200 | return float(1.0) 201 | 202 | def _collide(self, physics): 203 | collision_geom_prefix_1 = '{}/'.format(self._top_object.name) 204 | collision_geom_prefix_2 = '{}/'.format(self._bottom_object.name) 205 | return physics_utils.has_collision(physics, [collision_geom_prefix_1], 206 | [collision_geom_prefix_2]) 207 | 208 | def __call__(self, obs: spec_utils.ObservationValue): 209 | del obs 210 | physics = self._get_physics_fn() 211 | if self._align(physics) and self._pile(physics) and self._collide( 212 | physics) and self._ontop(physics): 213 | return 1.0 214 | return 0.0 215 | 216 | 217 | def get_sparse_reward_fn( 218 | top_object: prop.Prop, 219 | bottom_object: prop.Prop, 220 | get_physics_fn: Callable[[], mjcf.Physics] 221 | ) -> reward_functions.RewardFunction: 222 | """Sparse stacking reward for stacking two props with no robot contact. 223 | 224 | Args: 225 | top_object: The top object (red). 226 | bottom_object: The bottom object (blue).
227 | get_physics_fn: Callable that returns the current mjcf physics from the 228 | environment. 229 | 230 | Returns: 231 | The sparse stack reward function. 232 | """ 233 | return SparseStack( 234 | top_object=top_object, 235 | bottom_object=bottom_object, 236 | get_physics_fn=get_physics_fn) 237 | -------------------------------------------------------------------------------- /rgb_stacking/task.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """A module for constructing the RGB stacking task. 17 | 18 | This file builds the composer task containing a single sawyer robot facing 19 | 3 objects: a red, a green and a blue one. 20 | 21 | We define: 22 | - All the simulation objects, robot, basket, objects. 23 | - The sensors to measure the state of the environment. 24 | - The effector to control the robot. 25 | - The initialization logic. 26 | 27 | On top of this we can build a MoMa subtask environment. In this subtask 28 | environment we will decide what the reward will be and what observations are 29 | exposed. Thus allowing us to change the goal without changing this environment. 
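For example (illustrative; the RGB object identifiers are placeholders):

  moma_task = rgb_task(red_obj_id=..., green_obj_id=..., blue_obj_id=...)

builds the base task, and the subtask environment built on top of it then chooses the reward
and the exposed observations.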
30 | """ 31 | from typing import Sequence 32 | 33 | from dm_control import composer 34 | from dm_control.composer.variation import distributions 35 | from dm_control.composer.variation import rotations 36 | from dm_robotics.geometry import pose_distribution 37 | from dm_robotics.manipulation.props.rgb_objects import rgb_object 38 | from dm_robotics.manipulation.standard_cell import rgb_basket 39 | from dm_robotics.moma import base_task 40 | from dm_robotics.moma import entity_initializer 41 | from dm_robotics.moma import prop 42 | from dm_robotics.moma import robot as moma_robot 43 | from dm_robotics.moma.effectors import arm_effector as arm_effector_module 44 | from dm_robotics.moma.effectors import cartesian_4d_velocity_effector 45 | from dm_robotics.moma.effectors import cartesian_6d_velocity_effector 46 | from dm_robotics.moma.effectors import default_gripper_effector 47 | from dm_robotics.moma.effectors import min_max_effector 48 | from dm_robotics.moma.models.arenas import empty 49 | from dm_robotics.moma.models.end_effectors.robot_hands import robotiq_2f85 50 | from dm_robotics.moma.models.end_effectors.wrist_sensors import robotiq_fts300 51 | from dm_robotics.moma.models.robots.robot_arms import sawyer 52 | from dm_robotics.moma.models.robots.robot_arms import sawyer_constants 53 | from dm_robotics.moma.sensors import action_sensor 54 | from dm_robotics.moma.sensors import camera_sensor 55 | from dm_robotics.moma.sensors import prop_pose_sensor 56 | from dm_robotics.moma.sensors import robot_arm_sensor 57 | from dm_robotics.moma.sensors import robot_tcp_sensor 58 | from dm_robotics.moma.sensors import robot_wrist_ft_sensor 59 | from dm_robotics.moma.sensors import robotiq_gripper_sensor 60 | from dm_robotics.moma.sensors import site_sensor 61 | import numpy as np 62 | 63 | 64 | # Margin from the joint limits at which to stop the Z rotation when using 4D 65 | # control. Values chosen to match the existing PyRobot-based environments. 66 | WRIST_RANGE = (-np.pi / 2, np.pi / 2) 67 | 68 | 69 | # Position of the basket relative to the attachment point of the robot. 70 | _DEFAULT_BASKET_CENTER = (0.6, 0.) 71 | DEFAULT_BASKET_HEIGHT = 0.0498 72 | _BASKET_ORIGIN = _DEFAULT_BASKET_CENTER + (DEFAULT_BASKET_HEIGHT,) 73 | WORKSPACE_CENTER = np.array(_DEFAULT_BASKET_CENTER + (0.1698,)) 74 | WORKSPACE_SIZE = np.array([0.25, 0.25, 0.2]) 75 | # Maximum linear and angular velocity of the robot's TCP. 76 | _MAX_LIN_VEL = 0.07 77 | _MAX_ANG_VEL = 1.0 78 | 79 | # Limits of the distributions used to sample initial positions (X, Y, Z in [m]) 80 | # for props in the basket. 81 | _PROP_MIN_POSITION_BOUNDS = [0.50, -0.10, 0.12] 82 | _PROP_MAX_POSITION_BOUNDS = [0.70, 0.10, 0.12] 83 | 84 | # Limits of the distributions used to sample initial position for TCP. 85 | _TCP_MIN_POSE_BOUNDS = [0.5, -0.14, 0.22, np.pi, 0, -np.pi / 4] 86 | _TCP_MAX_POSE_BOUNDS = [0.7, 0.14, 0.43, np.pi, 0, np.pi / 4] 87 | 88 | # Control timestep exposed to the agent. 89 | _CONTROL_TIMESTEP = 0.05 90 | 91 | # Joint state used for the nullspace. 92 | _NULLSPACE_JOINT_STATE = [ 93 | 0.0, -0.5186220703125, -0.529384765625, 1.220857421875, 0.40857421875, 94 | 1.07831640625, 0.0] 95 | 96 | # Joint velocity magnitude limits from the Sawyer URDF. 97 | _JOINT_VEL_LIMITS = sawyer_constants.VELOCITY_LIMITS['max'] 98 | 99 | # Identifier for the cameras. The key is the name used for the MoMa camera 100 | # sensor and the value corresponds to the identifier of that camera in the 101 | # mjcf model. 
102 | _CAMERA_IDENTIFIERS = {'basket_back_left': 'base/basket_back_left', 103 | 'basket_front_left': 'base/basket_front_left', 104 | 'basket_front_right': 'base/basket_front_right'} 105 | 106 | # Configuration of the MuJoCo cameras. 107 | _CAMERA_CONFIG = camera_sensor.CameraConfig( 108 | width=128, 109 | height=128, 110 | fovy=30., 111 | has_rgb=True, 112 | has_depth=False, 113 | ) 114 | 115 | 116 | def rgb_task(red_obj_id: str, 117 | green_obj_id: str, 118 | blue_obj_id: str) -> base_task.BaseTask: 119 | """Builds a BaseTask and all dependencies. 120 | 121 | Args: 122 | red_obj_id: The RGB object ID that corresponds to the red object. More 123 | information on this can be found in the RGB Objects file. 124 | green_obj_id: See `red_obj_id` 125 | blue_obj_id: See `red_obj_id` 126 | 127 | Returns: 128 | The modular manipulation (MoMa) base task for the RGB stacking environment. 129 | A robot is placed in front of a basket containing 3 objects: a red, a green 130 | and blue one. 131 | """ 132 | 133 | # Build the composer scene. 134 | arena = _arena() 135 | _workspace(arena) 136 | 137 | robot = _sawyer_robot(robot_name='robot0') 138 | arena.attach(robot.arm) 139 | 140 | # We add a camera with a good point of view for capturing videos. 141 | pos = '1.4 0.0 0.45' 142 | quat = '0.541 0.455 0.456 0.541' 143 | name = 'main_camera' 144 | fovy = '45' 145 | arena.mjcf_model.worldbody.add( 146 | 'camera', name=name, pos=pos, quat=quat, fovy=fovy) 147 | 148 | props = _props(red_obj_id, green_obj_id, blue_obj_id) 149 | for p in props: 150 | frame = arena.add_free_entity(p) 151 | p.set_freejoint(frame.freejoint) 152 | 153 | # Add in the MoMa sensor to get observations from the environment. 154 | extra_sensors = prop_pose_sensor.build_prop_pose_sensors(props) 155 | camera_configurations = { 156 | name: _CAMERA_CONFIG for name in _CAMERA_IDENTIFIERS.keys()} 157 | extra_sensors.extend( 158 | camera_sensor.build_camera_sensors( 159 | camera_configurations, arena.mjcf_model, _CAMERA_IDENTIFIERS)) 160 | 161 | # Initializers to place the TCP and the props in the basket. 
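# Gripper and prop poses are sampled uniformly within the bounds defined above;
# prop placement additionally settles the physics so the objects start each
# episode at rest.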
162 | dynamic_initializer = entity_initializer.TaskEntitiesInitializer( 163 | [_gripper_initializer(robot), _prop_initializers(props)]) 164 | 165 | moma_task = base_task.BaseTask( 166 | task_name='rgb_stacking', 167 | arena=arena, 168 | robots=[robot], 169 | props=props, 170 | extra_sensors=extra_sensors, 171 | extra_effectors=[], 172 | scene_initializer=lambda _: None, 173 | episode_initializer=dynamic_initializer, 174 | control_timestep=_CONTROL_TIMESTEP) 175 | return moma_task 176 | 177 | 178 | def _workspace(arena: composer.Arena) -> rgb_basket.RGBBasket: 179 | """Returns the basket used in the rgb stacking environment.""" 180 | workspace = rgb_basket.RGBBasket() 181 | attachment_site = arena.mjcf_model.worldbody.add( 182 | 'site', pos=_BASKET_ORIGIN, rgba='0 0 0 0', size='0.01') 183 | arena.attach(workspace, attachment_site) 184 | return workspace 185 | 186 | 187 | def _gripper_initializer( 188 | robot: moma_robot.Robot) -> entity_initializer.PoseInitializer: 189 | """Populates components with gripper initializers.""" 190 | 191 | gripper_pose_dist = pose_distribution.UniformPoseDistribution( 192 | min_pose_bounds=_TCP_MIN_POSE_BOUNDS, 193 | max_pose_bounds=_TCP_MAX_POSE_BOUNDS) 194 | return entity_initializer.PoseInitializer(robot.position_gripper, 195 | gripper_pose_dist.sample_pose) 196 | 197 | 198 | def _prop_initializers( 199 | props: Sequence[prop.Prop]) -> entity_initializer.PropPlacer: 200 | """Populates components with prop pose initializers.""" 201 | prop_position = distributions.Uniform(_PROP_MIN_POSITION_BOUNDS, 202 | _PROP_MAX_POSITION_BOUNDS) 203 | prop_quaternion = rotations.UniformQuaternion() 204 | 205 | return entity_initializer.PropPlacer( 206 | props=props, 207 | position=prop_position, 208 | quaternion=prop_quaternion, 209 | settle_physics=True) 210 | 211 | 212 | def _arena() -> composer.Arena: 213 | """Builds an arena Entity.""" 214 | arena = empty.Arena() 215 | arena.mjcf_model.size.nconmax = 5000 216 | arena.mjcf_model.size.njmax = 5000 217 | 218 | return arena 219 | 220 | 221 | def _sawyer_robot(robot_name: str) -> moma_robot.Robot: 222 | """Returns a Sawyer robot with all the sensors and effectors.""" 223 | 224 | arm = sawyer.Sawyer( 225 | name=robot_name, actuation=sawyer_constants.Actuation.INTEGRATED_VELOCITY) 226 | 227 | gripper = robotiq_2f85.Robotiq2F85() 228 | 229 | wrist_ft = robotiq_fts300.RobotiqFTS300() 230 | 231 | wrist_cameras = [] 232 | 233 | # Compose the robot after its model components are constructed. This should 234 | # usually be done early on as some Effectors (and possibly Sensors) can only 235 | # be constructed after the robot components have been composed. 236 | moma_robot.standard_compose( 237 | arm=arm, gripper=gripper, wrist_ft=wrist_ft, wrist_cameras=wrist_cameras) 238 | 239 | # We need to measure the last action sent to the robot and the gripper. 240 | arm_effector, arm_action_sensor = action_sensor.create_sensed_effector( 241 | arm_effector_module.ArmEffector( 242 | arm=arm, action_range_override=None, robot_name=robot_name)) 243 | 244 | # Effector used for the gripper. The gripper is controlled by applying the 245 | # min or max command, this allows the agent to quicky learn how to grasp 246 | # instead of learning how to close the gripper first. 247 | gripper_effector, gripper_action_sensor = action_sensor.create_sensed_effector( 248 | default_gripper_effector.DefaultGripperEffector(gripper, robot_name)) 249 | 250 | # Enable bang bang control for the gripper, this allows the agent to close and 251 | # open the gripper faster. 
252 | gripper_effector = min_max_effector.MinMaxEffector( 253 | base_effector=gripper_effector) 254 | 255 | # Build the 4D cartesian controller, we use a 6D cartesian effector under the 256 | # hood. 257 | effector_model = cartesian_6d_velocity_effector.ModelParams( 258 | element=arm.wrist_site, joints=arm.joints) 259 | effector_control = cartesian_6d_velocity_effector.ControlParams( 260 | control_timestep_seconds=_CONTROL_TIMESTEP, 261 | max_lin_vel=_MAX_LIN_VEL, 262 | max_rot_vel=_MAX_ANG_VEL, 263 | joint_velocity_limits=np.array(_JOINT_VEL_LIMITS), 264 | nullspace_gain=0.025, 265 | nullspace_joint_position_reference=np.array(_NULLSPACE_JOINT_STATE), 266 | regularization_weight=1e-2, 267 | enable_joint_position_limits=True, 268 | minimum_distance_from_joint_position_limit=0.01, 269 | joint_position_limit_velocity_scale=0.95, 270 | max_cartesian_velocity_control_iterations=300, 271 | max_nullspace_control_iterations=300) 272 | 273 | # Don't activate collision avoidance because we are restricted to the virtual 274 | # workspace in the center of the basket. 275 | cart_effector_6d = cartesian_6d_velocity_effector.Cartesian6dVelocityEffector( 276 | robot_name=robot_name, 277 | joint_velocity_effector=arm_effector, 278 | model_params=effector_model, 279 | control_params=effector_control) 280 | cart_effector_4d = cartesian_4d_velocity_effector.Cartesian4dVelocityEffector( 281 | effector_6d=cart_effector_6d, 282 | element=arm.wrist_site, 283 | effector_prefix=f'{robot_name}_cart_4d_vel') 284 | 285 | # Constrain the workspace of the robot. 286 | cart_effector_4d = cartesian_4d_velocity_effector.limit_to_workspace( 287 | cartesian_effector=cart_effector_4d, 288 | element=gripper.tool_center_point, 289 | min_workspace_limits=WORKSPACE_CENTER - WORKSPACE_SIZE / 2, 290 | max_workspace_limits=WORKSPACE_CENTER + WORKSPACE_SIZE / 2, 291 | wrist_joint=arm.joints[-1], 292 | wrist_limits=WRIST_RANGE, 293 | reverse_wrist_range=True) 294 | 295 | robot_sensors = [] 296 | 297 | # Sensor for the joint states (torques, velocities and angles). 298 | robot_sensors.append(robot_arm_sensor.RobotArmSensor( 299 | arm=arm, name=f'{robot_name}_arm', have_torque_sensors=True)) 300 | 301 | # Sensor for the cartesian pose of the tcp site. 302 | robot_sensors.append(robot_tcp_sensor.RobotTCPSensor( 303 | gripper=gripper, name=robot_name)) 304 | 305 | # Sensor for cartesian pose of the wrist site. 306 | robot_sensors.append(site_sensor.SiteSensor( 307 | site=arm.wrist_site, name=f'{robot_name}_wrist_site')) 308 | 309 | # Sensor to measure the state of the gripper (position, velocity and grasp). 310 | robot_sensors.append(robotiq_gripper_sensor.RobotiqGripperSensor( 311 | gripper=gripper, name=f'{robot_name}_gripper')) 312 | 313 | # Sensor for the wrench measured at the wrist sensor. 314 | robot_sensors.append(robot_wrist_ft_sensor.RobotWristFTSensor( 315 | wrist_ft_sensor=wrist_ft, name=f'{robot_name}_wrist')) 316 | 317 | # Sensors to measure the last action sent to the arm joints and the gripper 318 | # actuator. 
319 | robot_sensors.extend([arm_action_sensor, gripper_action_sensor]) 320 | 321 | return moma_robot.StandardRobot( 322 | arm=arm, 323 | arm_base_site_name='base_site', 324 | gripper=gripper, 325 | robot_sensors=robot_sensors, 326 | wrist_cameras=wrist_cameras, 327 | arm_effector=cart_effector_4d, 328 | gripper_effector=gripper_effector, 329 | wrist_ft=wrist_ft, 330 | name=robot_name) 331 | 332 | 333 | def _props(red: str, green: str, blue: str) -> Sequence[prop.Prop]: 334 | """Build task props.""" 335 | objects = ((red, 'red'), (green, 'green'), (blue, 'blue')) 336 | color_set = [ 337 | [1, 0, 0, 1], 338 | [0, 1, 0, 1], 339 | [0, 0, 1, 1], 340 | ] 341 | props = [] 342 | for i, (obj_id, color) in enumerate(objects): 343 | p = rgb_object.RgbObjectProp( 344 | obj_id=obj_id, color=color_set[i], name=f'rgb_object_{color}') 345 | p = prop.WrapperProp(wrapped_entity=p, name=f'rgb_object_{color}') 346 | props.append(p) 347 | 348 | return props 349 | -------------------------------------------------------------------------------- /rgb_stacking/utils/permissive_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """A permissive wrapper for a SavedModel.""" 17 | 18 | import copy 19 | import inspect 20 | 21 | from typing import Any, Callable, NamedTuple 22 | 23 | from absl import logging 24 | import tensorflow as tf 25 | import tree 26 | 27 | 28 | class _Function(NamedTuple): 29 | """Function exposing signature and expected canonical arguments.""" 30 | func: Callable[..., Any] 31 | signature: inspect.Signature 32 | structured_specs: Any 33 | 34 | def __call__(self, *args, **kwargs): 35 | return self.func(*args, **kwargs) 36 | 37 | @property 38 | def canonical_arguments(self) -> inspect.BoundArguments: 39 | args, kwargs = copy.deepcopy(self.structured_specs) 40 | return self.signature.bind(*args, **kwargs) 41 | 42 | 43 | class PermissiveModel: 44 | """A permissive wrapper for a SavedModel.""" 45 | 46 | # Disable pytype attribute error checks. 47 | _HAS_DYNAMIC_ATTRIBUTES = True 48 | 49 | def __init__(self, model): 50 | self.model = model 51 | 52 | self._tables = self.model.function_tables() 53 | self._initialized_tables = {} 54 | 55 | def build_parameters(params): 56 | params = [ 57 | inspect.Parameter(str(param[0], "utf-8"), param[1]) 58 | for param in params 59 | ] 60 | # Always include a VAR_KEYWORD to capture any extraneous arguments. 
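# For example (illustrative), a saved signature (observation, state) is exposed
# as (observation, state, **__unused_kwargs), so callers may pass keyword
# arguments the underlying SavedModel function never declared.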
61 | if all([p.kind != inspect.Parameter.VAR_KEYWORD for p in params]): 62 | params.append( 63 | inspect.Parameter("__unused_kwargs", inspect.Parameter.VAR_KEYWORD)) 64 | return params 65 | 66 | signatures = self.model.function_signatures() 67 | if tf.executing_eagerly(): 68 | signatures = tree.map_structure(lambda x: x.numpy(), signatures) 69 | else: 70 | with tf.compat.v1.Session() as sess: 71 | signatures = sess.run(signatures) 72 | signatures = { 73 | func_name: inspect.Signature(build_parameters(func_params)) 74 | for func_name, func_params in signatures.items() 75 | } 76 | self.signatures = signatures 77 | 78 | # Attach deepfuncs. 79 | for name in self.signatures.keys(): 80 | setattr(self, name, self._make_permissive_function(name)) 81 | 82 | def _maybe_init_tables(self, concrete_func: Any, name: str): 83 | """Initialise all tables for a function if they are not initialised. 84 | 85 | Some functions rely on hash-tables that must be externally initialized. This 86 | method will perform a one-time initialisation of the tables. It does so by 87 | finding the corresponding op that creates the hash-table handles (these will 88 | be different from the ones observed in the initial deepfuncs), and import 89 | the corresponding keys and values. 90 | 91 | Args: 92 | concrete_func: A tf.ConcreteFunction corresponding to a deepfunc. 93 | name: The name of the deepfunc. 94 | """ 95 | if name not in self._tables: 96 | return 97 | 98 | all_nodes = dict( 99 | main={n.name: n for n in concrete_func.graph.as_graph_def().node}) 100 | for func_def in concrete_func.graph.as_graph_def().library.function: 101 | all_nodes[func_def.signature.name] = { 102 | n.name: n for n in func_def.node_def 103 | } 104 | 105 | for table_name, (table_keys, table_values) in self._tables[name].items(): 106 | table_op = None 107 | for nodes in all_nodes.values(): 108 | if table_name in nodes: 109 | if table_op is not None: 110 | raise ValueError(f"Duplicate table op found for {table_name}") 111 | table_op = nodes[table_name] 112 | 113 | logging.info("Initialising table for Op `%s`", table_name) 114 | table_handle_name = table_op.attr["shared_name"].s # pytype: disable=attribute-error 115 | table_handle = tf.raw_ops.HashTableV2( 116 | key_dtype=table_keys.dtype, 117 | value_dtype=table_values.dtype, 118 | shared_name=table_handle_name) 119 | tf.raw_ops.LookupTableImportV2( 120 | table_handle=table_handle, keys=table_keys, values=table_values) 121 | self._initialized_tables[name] = self._tables.pop(name) # Only init once. 122 | 123 | def _make_permissive_function(self, name: str) -> Callable[..., Any]: 124 | """Create a permissive version of a function in the SavedModel.""" 125 | if name not in self.signatures: 126 | raise ValueError(f"No function named {name} in SavedModel, " 127 | "options are {self.signatures}") 128 | 129 | tf_func = getattr(self.model, name) 130 | if hasattr(tf_func, "concrete_functions"): 131 | # tf.RestoredFunction 132 | concrete_func, = tf_func.concrete_functions # Expect only one. 133 | elif hasattr(tf_func, "_list_all_concrete_functions"): 134 | # tf.Function 135 | all_concrete_funcs = tf_func._list_all_concrete_functions() # pylint: disable=protected-access 136 | all_concrete_signatures = [ 137 | f.structured_input_signature for f in all_concrete_funcs 138 | ] 139 | # This is necessary as tf.saved_model.save can force a retrace on 140 | # tf.Function, resulting in another concrete function with identical 141 | # signature. 
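# Flattening each structured signature into (path, spec) tuples makes it
# hashable, so identical retraced signatures collapse when gathered in a set.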
142 | unique_concrete_signatures = set([ 143 | tuple(tree.flatten_with_path(sig)) for sig in all_concrete_signatures 144 | ]) 145 | if len(unique_concrete_signatures) != 1: 146 | raise ValueError( 147 | "Expected exactly one unique concrete_function signature, found " 148 | f"the following: {all_concrete_signatures}") 149 | concrete_func = all_concrete_funcs[0] 150 | else: 151 | raise ValueError(f"No concrete functions found on {tf_func}") 152 | 153 | self._maybe_init_tables(concrete_func, name) 154 | 155 | def func(*args, **kwargs): 156 | bound_args = self.signatures[name].bind(*args, **kwargs) 157 | canonical_args = concrete_func.structured_input_signature 158 | 159 | flat_bound_args = tree.flatten_with_path( 160 | (bound_args.args, bound_args.kwargs)) 161 | flat_canonical_args = tree.flatten_with_path(canonical_args) 162 | 163 | # Specs for error reporting. 164 | bound_specs = tree.map_structure( 165 | lambda x: None if x is None else object(), 166 | (bound_args.args, bound_args.kwargs)) 167 | canonical_specs = tree.map_structure( 168 | lambda x: None if x is None else object(), canonical_args) 169 | 170 | # Check for missing arguments. 171 | flat_bound_args_dict = dict(flat_bound_args) 172 | for arg_path, arg_spec in flat_canonical_args: 173 | if arg_path not in flat_bound_args_dict and arg_spec is not None: 174 | raise ValueError( 175 | f"Missing argument with path {arg_path}, expected canonical args " 176 | f"with structure {canonical_specs}, received {bound_specs}. (All " 177 | "required values have been replaced by object() for brevity.") 178 | 179 | if arg_path in flat_bound_args_dict and arg_spec is None: 180 | arg_value = flat_bound_args_dict[arg_path] 181 | if arg_value is not None: 182 | logging.warning( 183 | "Received unexpected argument `%s` for path %s, replaced with " 184 | "None.", arg_value, arg_path) 185 | flat_bound_args_dict[arg_path] = None 186 | 187 | # Filter out extraneous arguments and dictionary keys. 188 | flat_canonical_args_dict = dict(flat_canonical_args) 189 | filtered_flat_bound_args = { 190 | arg_path: arg_value 191 | for arg_path, arg_value in flat_bound_args_dict.items() 192 | if arg_path in flat_canonical_args_dict 193 | } 194 | full_flat_bound_args = [ 195 | filtered_flat_bound_args.get(arg_path, None) 196 | for arg_path, _ in flat_canonical_args 197 | ] 198 | filtered_args, filtered_kwargs = tree.unflatten_as( 199 | canonical_args, full_flat_bound_args) 200 | 201 | return tf_func(*filtered_args, **filtered_kwargs) 202 | 203 | return _Function( 204 | func, 205 | copy.deepcopy(self.signatures[name]), 206 | copy.deepcopy(concrete_func.structured_input_signature), 207 | ) 208 | -------------------------------------------------------------------------------- /rgb_stacking/utils/policy_loading.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Utility for loading RGB stacking policies saved as TF SavedModel.""" 17 | 18 | from typing import NamedTuple 19 | import dm_env 20 | import numpy as np 21 | # This reverb dependency is needed since otherwise loading a SavedModel throws 22 | # an error when the ReverbDataset op is not found. 23 | import reverb # pylint: disable=unused-import 24 | import tensorflow as tf 25 | import tree 26 | import typing_extensions 27 | 28 | from rgb_stacking.utils import permissive_model 29 | 30 | 31 | class _MPOState(NamedTuple): 32 | counter: tf.Tensor 33 | actor: tree.Structure[tf.Tensor] 34 | critic: tree.Structure[tf.Tensor] 35 | 36 | 37 | @typing_extensions.runtime 38 | class Policy(typing_extensions.Protocol): 39 | 40 | def step(self, timestep: dm_env.TimeStep, state: _MPOState): 41 | pass 42 | 43 | def initial_state(self) -> _MPOState: 44 | pass 45 | 46 | 47 | def policy_from_path(saved_model_path: str) -> Policy: 48 | """Loads policy from stored TF SavedModel.""" 49 | policy = tf.saved_model.load(saved_model_path) 50 | # Relax strict requirement with respect to its expected inputs, e.g. in 51 | # regards to unused arguments. 52 | policy = permissive_model.PermissiveModel(policy) 53 | 54 | # The loaded policy's step function expects batched data. Wrap it so that it 55 | # expects unbatched data. 56 | policy_step_batch_fn = policy.step 57 | 58 | def _expand_batch_dim(x): 59 | return np.expand_dims(x, axis=0) 60 | 61 | def _squeeze_batch_dim(x): 62 | return np.squeeze(x, axis=0) 63 | 64 | def policy_step_fn(timestep: dm_env.TimeStep, state: _MPOState): 65 | timestep_batch = dm_env.TimeStep( 66 | None, None, None, 67 | tree.map_structure(_expand_batch_dim, timestep.observation)) 68 | state_batch = tree.map_structure(_expand_batch_dim, state) 69 | output_batch = policy_step_batch_fn(timestep_batch, state_batch) 70 | output = tree.map_structure(_squeeze_batch_dim, output_batch) 71 | return output 72 | 73 | policy.step = policy_step_fn 74 | return policy 75 | 76 | 77 | class StatefulPolicyCallable: 78 | """Object-oriented policy for directly using in dm_control viewer.""" 79 | 80 | def __init__(self, policy: Policy): 81 | self._policy = policy 82 | self._state = self._policy.initial_state() 83 | 84 | def __call__(self, timestep: dm_env.TimeStep): 85 | if timestep.step_type == dm_env.StepType.FIRST: 86 | self._state = self._policy.initial_state() 87 | (action, _), self._state = self._policy.step(timestep, self._state) 88 | return action 89 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2021 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | # Fail on any error. 
18 | set -e 19 | 20 | # Display commands being run. 21 | set -x 22 | 23 | TMP_DIR=$(mktemp -d) 24 | 25 | python3 -m venv "${TMP_DIR}/rgb_stacking" 26 | source "${TMP_DIR}/rgb_stacking/bin/activate" 27 | 28 | # Install dependencies. 29 | pip install --upgrade -r requirements.txt 30 | 31 | # Run the visualization of the environment. 32 | python -m rgb_stacking.main 33 | 34 | # Clean up. 35 | rm -r "${TMP_DIR}" 36 | --------------------------------------------------------------------------------