├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── assets
│   └── saved_models
│       ├── mpo_state_rgb_test_triplet1
│       │   ├── saved_model.pb
│       │   └── variables
│       │       ├── variables.data-00000-of-00001
│       │       └── variables.index
│       ├── mpo_state_rgb_test_triplet2
│       │   ├── saved_model.pb
│       │   └── variables
│       │       ├── variables.data-00000-of-00001
│       │       └── variables.index
│       ├── mpo_state_rgb_test_triplet3
│       │   ├── saved_model.pb
│       │   └── variables
│       │       ├── variables.data-00000-of-00001
│       │       └── variables.index
│       ├── mpo_state_rgb_test_triplet4
│       │   ├── saved_model.pb
│       │   └── variables
│       │       ├── variables.data-00000-of-00001
│       │       └── variables.index
│       └── mpo_state_rgb_test_triplet5
│           ├── saved_model.pb
│           └── variables
│               ├── variables.data-00000-of-00001
│               └── variables.index
├── doc
│   └── images
│       └── rgb_environment.png
├── real_cell_documentation
│   ├── basket_and_camera_assembly.pdf
│   ├── bill_of_materials.pdf
│   ├── cell_assembly.pdf
│   ├── robot_assembly.pdf
│   └── standard_cell_generic_setup.pdf
├── requirements.txt
├── rgb_stacking
│   ├── environment.py
│   ├── environment_test.py
│   ├── main.py
│   ├── physics_utils.py
│   ├── reward_functions.py
│   ├── stack_rewards.py
│   ├── task.py
│   └── utils
│       ├── permissive_model.py
│       └── policy_loading.py
└── run.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Distribution / packaging
7 | .Python
8 | build/
9 | develop-eggs/
10 | dist/
11 | downloads/
12 | eggs/
13 | .eggs/
14 | lib/
15 | lib64/
16 | parts/
17 | sdist/
18 | var/
19 | wheels/
20 | share/python-wheels/
21 | *.egg-info/
22 | .installed.cfg
23 | *.egg
24 | MANIFEST
25 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | ## Pull Requests
4 |
5 | Please send in fixes or feature additions through Pull Requests.
6 |
7 | ## Contributor License Agreement
8 |
9 | Contributions to this project must be accompanied by a Contributor License
10 | Agreement. You (or your employer) retain the copyright to your contribution;
11 | this simply gives us permission to use and redistribute your contributions as
12 | part of the project. Head over to <https://cla.developers.google.com/> to see
13 | your current agreements on file or to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RGB-stacking 🛑🟩🔷 for robotic manipulation
2 | ### [BLOG](https://deepmind.com/blog/article/stacking-our-way-to-more-general-robots) | [PAPER][pick_and_place_paper] | [VIDEO](https://youtu.be/BxOKPEtMuZw)
3 |
4 | **Beyond Pick-and-Place: Tackling Robotic Stacking of Diverse Shapes,**\
5 | Alex X. Lee*, Coline Devin*, Yuxiang Zhou*, Thomas Lampe*, Konstantinos Bousmalis*, Jost Tobias Springenberg*, Arunkumar Byravan, Abbas Abdolmaleki, Nimrod Gileadi, David Khosid, Claudio Fantacci, Jose Enrique Chen, Akhil Raju, Rae Jeong, Michael Neunert, Antoine Laurens, Stefano Saliceti, Federico Casarini, Martin Riedmiller, Raia Hadsell, Francesco Nori.\
6 | In *Conference on Robot Learning (CoRL)*, 2021.
7 |
8 |
9 |
10 | This repository contains an implementation of the simulation environment
11 | described in the paper
12 | ["Beyond Pick-and-Place: Tackling robotic stacking of diverse shapes"][pick_and_place_paper].
13 | Note that this is a re-implementation of the environment (to remove dependencies
14 | on internal libraries). As a result, not all the features described in the paper
15 | are available at this point. Notably, domain randomization is not included
16 | in this release. We also aim to provide reference performance metrics of
17 | trained policies on this environment in the near future.
18 |
19 | In this environment, the agent controls a robot arm with a parallel gripper
20 | above a basket, which contains three objects — one red, one green, and one blue,
21 | hence the name RGB. The agent's task is to stack the red object on top of the
22 | blue object, within 20 seconds, while the green object serves as an obstacle and
23 | distraction. The agent controls the robot using a 4D Cartesian controller. The
24 | controlled DOFs are x, y, z and rotation around
25 | the z axis. The simulation is a MuJoCo environment built using the
26 | [Modular Manipulation (MoMa) framework](https://github.com/deepmind/robotics/tree/main/py/moma/README.md).
27 |
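Once the environment is installed (see the instructions below), a quick way to check what the agent actually sees and controls is to inspect the environment's dm_env specs. The following is a minimal sketch using only the public `rgb_stacking.environment` module; the exact spec shapes depend on the chosen observation set.

```python
from rgb_stacking import environment

# Minimal sanity check: build the environment with a random test triplet and
# print the dm_env action and observation specs exposed to the agent.
with environment.rgb_stacking(object_triplet='rgb_test_random') as env:
  print(env.action_spec())
  print(env.observation_spec())
```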
28 |
29 | ## Corresponding method
30 | The RGB-stacking paper
31 | ["Beyond Pick-and-Place: Tackling robotic stacking of diverse shapes"][pick_and_place_paper]
32 | also contains a description and thorough evaluation of our initial solution to
33 | both the 'Skill Mastery' task (training on the 5 designated test triplets and
34 | evaluating on them) and the 'Skill Generalization' task (training on triplets
35 | of training objects and evaluating on the 5 test triplets). Our
36 | approach was to first train a state-based policy in simulation via a standard RL
37 | algorithm (we used [MPO](https://arxiv.org/abs/1806.06920)) followed by
38 | interactive distillation of the state-based policy into a vision-based policy (using a
39 | domain randomized version of the environment) that we then deployed to the robot
40 | via zero-shot sim-to-real transfer. We finally improved the policy further via
41 | offline RL based on data collected from the sim-to-real policy (we used [CRR](https://arxiv.org/abs/2006.15134)). For details on our method and the results
42 | please consult the paper.
43 |
44 | ## Released specialist policies
45 |
46 | This repository includes state-based policies that were trained on this
47 | environment, which differs slightly from the internal one we used for the paper.
48 | These are 5 specialist policies, each one trained on one test triplet. They
49 | correspond to the Skill Mastery-State teacher in Table 1 of the manuscript and
50 | they achieve 75.6% stacking success on average. In detail, the stacking success of
51 | each agent over a run of 1000 episodes is (average of 2 seeds):
52 |
53 | * Triplet 1: 77.7%
54 | * Triplet 2: 47.4%
55 | * Triplet 3: 83.5%
56 | * Triplet 4: 79.9%
57 | * Triplet 5: 89.5%
58 | * Average: 75.6%
59 |
60 | The policy weights in the directory `assets/saved_models` are made available
61 | under the terms of the Creative Commons Attribution 4.0 (CC BY 4.0) license.
62 | You may obtain a copy of the License at
63 | https://creativecommons.org/licenses/by/4.0/legalcode.
64 |
65 | ## Installing and visualising the environment
66 |
67 | Please ensure that you have a working
68 | [MuJoCo200 installation](https://www.roboti.us/download.html) and a valid
69 | [MuJoCo licence](https://www.roboti.us/license.html).
70 |
71 | 1. Clone this repository:
72 |
73 | ```bash
74 | git clone https://github.com/deepmind/rgb_stacking.git
75 | cd rgb_stacking
76 | ```
77 |
78 | 2. Prepare a Python 3 environment - venv is recommended.
79 |
80 | ```bash
81 | python3 -m venv rgb_stacking_venv
82 | source rgb_stacking_venv/bin/activate
83 | ```
84 |
85 | 3. Install dependencies:
86 |
87 | ```bash
88 | pip install -r requirements.txt
89 | ```
90 |
91 | 4. Run the environment viewer:
92 |
93 | ```bash
94 | python -m rgb_stacking.main
95 | ```
96 |
97 | Steps 2-4 can also be performed by running the `run.sh` script:
98 |
99 | ```bash
100 | ./run.sh
101 | ```
102 |
103 | By default, this loads the environment with a random test triplet and starts the
104 | viewer for visualisation. Alternatively, the object set can be specified with
105 | `--object_triplet` (see the relevant [section](#specifying-the-object-triplet)
106 | for options).
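
For instance, to launch the viewer with a randomly sampled training triplet instead of a test triplet:

```bash
python -m rgb_stacking.main --object_triplet=rgb_train_random
```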
107 |
108 | ## Specifying one of the released specialist policies
109 |
110 | You can also load the environment along with a specialist policy using the flag
111 | `--policy_object_triplet`. For example, to execute the respective specialist in
112 | the environment with triplet 4, use the following command:
113 |
114 | ```bash
115 | python -m rgb_stacking.main --object_triplet=rgb_test_triplet4 --policy_object_triplet=rgb_test_triplet4
116 | ```
117 |
118 | Executing and visualising a policy in the viewer can be very slow.
119 | Alternatively, passing `--launch_viewer=False` will render the policy and save it
120 | as `rendered_policy.mp4` in the current directory.
121 |
122 | ```bash
123 | MUJOCO_GL=egl python -m rgb_stacking.main --launch_viewer=False --object_triplet=rgb_test_triplet4 --policy_object_triplet=rgb_test_triplet4
124 | ```
125 |
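The released policies can also be loaded and stepped programmatically. The sketch below mirrors the episode loop in `rgb_stacking/main.py` and assumes it is run from the repository root so that the relative policy path resolves:

```python
from rgb_stacking import environment
from rgb_stacking.utils import policy_loading

# Load the specialist trained on test triplet 4 and run a single episode,
# mirroring run_episode_and_render() in rgb_stacking/main.py.
policy = policy_loading.policy_from_path(
    'assets/saved_models/mpo_state_rgb_test_triplet4')

with environment.rgb_stacking(object_triplet='rgb_test_triplet4') as env:
  state = policy.initial_state()
  timestep = env.reset()
  while not timestep.last():
    (action, _), state = policy.step(timestep, state)
    timestep = env.step(action)
```
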
126 | ## Specifying the object triplet
127 |
128 | The default environment will load with a random test triplet (see Sect. 3.2.1 in
129 | the paper). If you wish to use a different triplet, you can use the following
130 | code snippet:
131 |
132 | ```python
133 | from rgb_stacking import environment
134 |
135 | env = environment.rgb_stacking(object_triplet=NAME_OF_TRIPLET)
136 | ```
137 |
138 | The possible values of `NAME_OF_TRIPLET` are:
139 |
140 | * `rgb_test_triplet{i}` where `i` is one of 1, 2, 3, 4, 5: Loads test triplet `i`.
141 | * `rgb_test_random`: Randomly loads one of the 5 test triplets.
142 | * `rgb_train_random`: Triplet composed of blocks from the training set.
143 | * `rgb_heldout_random`: Triplet composed of blocks from the held-out set.
144 |
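As a small end-to-end example (a sketch; it assumes the composite action spec is a single bounded array, as built in `rgb_stacking/environment.py`), you can load a specific triplet and step the environment with a fixed action:

```python
import numpy as np
from rgb_stacking import environment

# Load test triplet 1 and step the environment with a fixed zero action,
# clipped to the action bounds and cast to the spec's dtype.
with environment.rgb_stacking(object_triplet='rgb_test_triplet1') as env:
  spec = env.action_spec()
  action = np.clip(
      np.zeros(spec.shape), spec.minimum, spec.maximum).astype(spec.dtype)
  timestep = env.reset()
  for _ in range(5):
    timestep = env.step(action)
    print(timestep.reward)
```
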
145 | For more information on the blocks and the possible options, please refer to
146 | the [rgb_objects repository](https://github.com/deepmind/dm_robotics/tree/main/py/manipulation/props/rgb_objects/README.md).
147 |
148 | ## Specifying the observation space
149 |
150 | By default, the observations exposed by the environment are only the ones we
151 | used for training our state-based agents. To use another set of observations,
152 | please use the following code snippet:
153 |
154 | ```python
155 | from rgb_stacking import environment
156 |
157 | env = environment.rgb_stacking(
158 |     observation_set=environment.ObservationSet.CHOSEN_SET)
159 | ```
160 |
161 | The possible values of `CHOSEN_SET` are:
162 |
163 | * `STATE_ONLY`: Only the state observations, used for training expert policies
164 | from state in simulation (stage 1).
165 | * `VISION_ONLY`: Only image observations.
166 | * `ALL`: All observations.
167 | * `INTERACTIVE_IMITATION_LEARNING`: Pair of image observations and a subset of
168 | proprioception observations, used for interactive imitation learning
169 | (stage 2).
170 | * `OFFLINE_POLICY_IMPROVEMENT`: Pair of image observations and a subset of
171 | proprioception observations, used for the one-step offline policy
172 | improvement (stage 3).
173 |
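To see which observation names a given set actually exposes, you can inspect the environment's observation spec (this assumes, as is the dm_env convention used here, that the spec is a dict keyed by observation name):

```python
from rgb_stacking import environment

# List the observations exposed when only vision observations are requested.
with environment.rgb_stacking(
    observation_set=environment.ObservationSet.VISION_ONLY) as env:
  print(sorted(env.observation_spec().keys()))
```
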
174 | ## Real RGB-Stacking Environment: CAD models and assembly instructions
175 |
176 | The [CAD model](https://deepmind.onshape.com/documents/d0b99322019b124525012b2a/w/6702310a5b51c79efed7c65b/e/4b6eb89dc085468d0fee5e97)
177 | of the setup is available on Onshape.
178 |
179 | We also provide the following [documents](https://github.com/deepmind/rgb_stacking/tree/main/real_cell_documentation)
180 | for the assembly of the real cell:
181 |
182 | * Assembly instructions for the basket.
183 | * Assembly instructions for the robot.
184 | * Assembly instructions for the cell.
185 | * The bill of materials of all the necessary parts.
186 | * A diagram with the wiring of the cell.
187 |
188 | The RGB-objects themselves can be 3D-printed using the STLs available in the [rgb_objects repository](https://github.com/deepmind/dm_robotics/tree/main/py/manipulation/props/rgb_objects/README.md).
189 |
190 | ## Citing
191 |
192 | If you use `rgb_stacking` in your work, please cite the accompanying [paper][pick_and_place_paper]:
193 |
194 | ```bibtex
195 | @inproceedings{lee2021rgbstacking,
196 | title={Beyond Pick-and-Place: Tackling Robotic Stacking of Diverse Shapes},
197 | author={Alex X. Lee and
198 | Coline Devin and
199 | Yuxiang Zhou and
200 | Thomas Lampe and
201 | Konstantinos Bousmalis and
202 | Jost Tobias Springenberg and
203 | Arunkumar Byravan and
204 | Abbas Abdolmaleki and
205 | Nimrod Gileadi and
206 | David Khosid and
207 | Claudio Fantacci and
208 | Jose Enrique Chen and
209 | Akhil Raju and
210 | Rae Jeong and
211 | Michael Neunert and
212 | Antoine Laurens and
213 | Stefano Saliceti and
214 | Federico Casarini and
215 | Martin Riedmiller and
216 | Raia Hadsell and
217 | Francesco Nori},
218 | booktitle={Conference on Robot Learning (CoRL)},
219 | year={2021},
220 | url={https://openreview.net/forum?id=U0Q8CrtBJxJ}
221 | }
222 | ```
223 |
224 |
225 | [pick_and_place_paper]: http://arxiv.org/abs/2110.06192
226 |
--------------------------------------------------------------------------------
/assets/saved_models/mpo_state_rgb_test_triplet1/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet1/saved_model.pb
--------------------------------------------------------------------------------
/assets/saved_models/mpo_state_rgb_test_triplet1/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet1/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/assets/saved_models/mpo_state_rgb_test_triplet1/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet1/variables/variables.index
--------------------------------------------------------------------------------
/assets/saved_models/mpo_state_rgb_test_triplet2/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet2/saved_model.pb
--------------------------------------------------------------------------------
/assets/saved_models/mpo_state_rgb_test_triplet2/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet2/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/assets/saved_models/mpo_state_rgb_test_triplet2/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet2/variables/variables.index
--------------------------------------------------------------------------------
/assets/saved_models/mpo_state_rgb_test_triplet3/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet3/saved_model.pb
--------------------------------------------------------------------------------
/assets/saved_models/mpo_state_rgb_test_triplet3/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet3/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/assets/saved_models/mpo_state_rgb_test_triplet3/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet3/variables/variables.index
--------------------------------------------------------------------------------
/assets/saved_models/mpo_state_rgb_test_triplet4/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet4/saved_model.pb
--------------------------------------------------------------------------------
/assets/saved_models/mpo_state_rgb_test_triplet4/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet4/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/assets/saved_models/mpo_state_rgb_test_triplet4/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet4/variables/variables.index
--------------------------------------------------------------------------------
/assets/saved_models/mpo_state_rgb_test_triplet5/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet5/saved_model.pb
--------------------------------------------------------------------------------
/assets/saved_models/mpo_state_rgb_test_triplet5/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet5/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/assets/saved_models/mpo_state_rgb_test_triplet5/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/assets/saved_models/mpo_state_rgb_test_triplet5/variables/variables.index
--------------------------------------------------------------------------------
/doc/images/rgb_environment.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/doc/images/rgb_environment.png
--------------------------------------------------------------------------------
/real_cell_documentation/basket_and_camera_assembly.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/real_cell_documentation/basket_and_camera_assembly.pdf
--------------------------------------------------------------------------------
/real_cell_documentation/bill_of_materials.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/real_cell_documentation/bill_of_materials.pdf
--------------------------------------------------------------------------------
/real_cell_documentation/cell_assembly.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/real_cell_documentation/cell_assembly.pdf
--------------------------------------------------------------------------------
/real_cell_documentation/robot_assembly.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/real_cell_documentation/robot_assembly.pdf
--------------------------------------------------------------------------------
/real_cell_documentation/standard_cell_generic_setup.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/rgb_stacking/794c3434060acd58057c09f9a255b003cb1e3311/real_cell_documentation/standard_cell_generic_setup.pdf
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | dm-control == 0.0.364896371
2 | dm-robotics-moma == 0.0.4
3 | dm-robotics-manipulation == 0.0.4
4 | tensorflow == 2.7.0.rc0
5 | dm-reverb-nightly == 0.5.0.dev20211104
6 | opencv-python >= 3.4.0
7 | typing-extensions >= 3.7.4
8 |
--------------------------------------------------------------------------------
/rgb_stacking/environment.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Builds RGB stacking environment.
17 |
18 | This file builds the RGB stacking environment used in the paper
19 | "Beyond Pick-and-Place: Tackling robotic stacking of diverse shapes".
20 | The environment is composed of a robot with a parallel gripper. In front
21 | of the robot there is a basket containing 3 objects, one red, one green and one
22 | blue. The goal is for the robot to stack the red object on top of the blue one.
23 |
24 | In this specific file, we build the interface that is exposed to the agent,
25 | namely the action spec, the observation spec along with the reward.
26 | """
27 |
28 | import enum
29 | from typing import Sequence
30 |
31 | from dm_robotics import agentflow as af
32 | from dm_robotics.agentflow.preprocessors import observation_transforms
33 | from dm_robotics.agentflow.preprocessors import timestep_preprocessor as tsp
34 | from dm_robotics.agentflow.subtasks import subtask_termination
35 | from dm_robotics.manipulation.props.rgb_objects import rgb_object
36 | from dm_robotics.moma import action_spaces
37 | from dm_robotics.moma import subtask_env
38 | from dm_robotics.moma import subtask_env_builder
39 | from dm_robotics.moma.utils import mujoco_collisions
40 | import numpy as np
41 |
42 | from rgb_stacking import reward_functions
43 | from rgb_stacking import stack_rewards
44 | from rgb_stacking import task
45 |
46 |
47 | # The environment provides stacked observations to the agent. The data of the
48 | # previous n steps is stacked and provided as a single observation to the agent.
49 | # Different observations are stacked a different number of times.
50 | _OBSERVATION_STACK_DEPTH = 3
51 | _ACTION_OBS_STACK_DEPTH = 2
52 |
53 | # Number of steps in an episode
54 | _MAX_STEPS = 400
55 |
56 | # Timestep of the physics simulation.
57 | _PHYSICS_TIMESTEP = 0.0005
58 |
59 |
60 | _STATE_OBSERVATIONS = (
61 | 'action/environment',
62 | 'gripper/grasp',
63 | 'gripper/joints/angle',
64 | 'gripper/joints/velocity',
65 | 'rgb30_blue/abs_pose',
66 | 'rgb30_blue/to_pinch',
67 | 'rgb30_green/abs_pose',
68 | 'rgb30_green/to_pinch',
69 | 'rgb30_red/abs_pose',
70 | 'rgb30_red/to_pinch',
71 | 'sawyer/joints/angle',
72 | 'sawyer/joints/torque',
73 | 'sawyer/joints/velocity',
74 | 'sawyer/pinch/pose',
75 | 'sawyer/tcp/pose',
76 | 'sawyer/tcp/velocity',
77 | 'wrist/force',
78 | 'wrist/torque'
79 | )
80 |
81 | _VISION_OBSERVATIONS = (
82 | 'basket_back_left/pixels',
83 | 'basket_front_left/pixels',
84 | 'basket_front_right/pixels'
85 | )
86 |
87 | # For interactive imitation learning, the vision-based policy used the
88 | # following proprioception and image observations from the pair of front
89 | # cameras given by the simulated environment.
90 | _INTERACTIVE_IMITATION_LEARNING_OBSERVATIONS = (
91 | 'sawyer/joints/angle',
92 | 'gripper/joints/angle',
93 | 'sawyer/pinch/pose',
94 | 'sawyer/tcp/pose',
95 | 'basket_front_left/pixels',
96 | 'basket_front_right/pixels'
97 | )
98 |
99 |
100 | # For the one-step offline policy improvement from real data, the vision-based
101 | # policy used the following proprioception and image observations from the pair
102 | # of front cameras given by the real environment.
103 | _OFFLINE_POLICY_IMPROVEMENT_OBSERVATIONS = [
104 | 'sawyer/joints/angle',
105 | 'sawyer/joints/velocity',
106 | 'gripper/grasp',
107 | 'gripper/joints/angle',
108 | 'gripper/joints/velocity',
109 | 'sawyer/pinch/pose',
110 | 'basket_front_left/pixels',
111 | 'basket_front_right/pixels',]
112 |
113 |
114 | class ObservationSet(int, enum.Enum):
115 | """Different possible set of observations that can be exposed."""
116 |
117 | _observations: Sequence[str]
118 |
119 | def __new__(cls, value: int, observations: Sequence[str]):
120 | obj = int.__new__(cls, value)
121 | obj._value_ = value
122 | obj._observations = observations
123 | return obj
124 |
125 | @property
126 | def observations(self):
127 | return self._observations
128 |
129 | STATE_ONLY = (0, _STATE_OBSERVATIONS)
130 | VISION_ONLY = (1, _VISION_OBSERVATIONS)
131 | ALL = (2, _STATE_OBSERVATIONS + _VISION_OBSERVATIONS)
132 | INTERACTIVE_IMITATION_LEARNING = (
133 | 3, _INTERACTIVE_IMITATION_LEARNING_OBSERVATIONS)
134 | OFFLINE_POLICY_IMPROVEMENT = (4, _OFFLINE_POLICY_IMPROVEMENT_OBSERVATIONS)
135 |
136 |
137 | def rgb_stacking(
138 | object_triplet: str = 'rgb_test_random',
139 | observation_set: ObservationSet = ObservationSet.STATE_ONLY,
140 | use_sparse_reward: bool = False
141 | ) -> subtask_env.SubTaskEnvironment:
142 | """Returns the environment.
143 |
144 | The relevant groups can be found here:
145 | https://github.com/deepmind/robotics/blob/main/py/manipulation/props/rgb_objects/rgb_object.py
146 |
147 | The valid object triplets can be found under PROP_TRIPLETS in the file.
148 |
149 | Args:
150 | object_triplet: Triplet of RGB objects to use in the environment.
151 |     observation_set: Set of observations that the environment should expose.
152 |     use_sparse_reward: If true, uses a sparse reward, which is 1 if the objects
153 |       are stacked and not touching the robot, and 0 otherwise.
154 | """
155 |
156 | red_id, green_id, blue_id = rgb_object.PROP_TRIPLETS[object_triplet][1]
157 | rgb_task = task.rgb_task(red_id, green_id, blue_id)
158 | rgb_task.physics_timestep = _PHYSICS_TIMESTEP
159 |
160 |   # To speed up simulation we ensure that MuJoCo will not check contacts between
161 | # geoms that cannot collide.
162 | mujoco_collisions.exclude_bodies_based_on_contype_conaffinity(
163 | rgb_task.root_entity.mjcf_model)
164 |
165 |   # Build the agentflow subtask. This is where the task logic, observations,
166 |   # and rewards are defined.
167 | env_builder = subtask_env_builder.SubtaskEnvBuilder()
168 | env_builder.set_task(rgb_task)
169 | task_env = env_builder.build_base_env()
170 |
171 |   # Define the action space; this is used to expose the actuators used in the
172 |   # base task.
173 | effectors_action_spec = rgb_task.effectors_action_spec(
174 | physics=task_env.physics)
175 | robot_action_spaces = []
176 | for rbt in rgb_task.robots:
177 | arm_action_space = action_spaces.ArmJointActionSpace(
178 | af.prefix_slicer(effectors_action_spec, rbt.arm_effector.prefix))
179 | gripper_action_space = action_spaces.GripperActionSpace(
180 | af.prefix_slicer(effectors_action_spec, rbt.gripper_effector.prefix))
181 | robot_action_spaces.extend([arm_action_space, gripper_action_space])
182 |
183 | composite_action_space = af.CompositeActionSpace(
184 | robot_action_spaces)
185 | env_builder.set_action_space(composite_action_space)
186 |
187 | # Cast all the floating point observations to float32.
188 | env_builder.add_preprocessor(
189 | observation_transforms.DowncastFloatPreprocessor(np.float32))
190 |
191 | # Concatenate the TCP and wrist site observations.
192 | env_builder.add_preprocessor(observation_transforms.MergeObservations(
193 | obs_to_merge=['robot0_tcp_pos', 'robot0_tcp_quat'],
194 | new_obs='robot0_tcp_pose'))
195 | env_builder.add_preprocessor(observation_transforms.MergeObservations(
196 | obs_to_merge=['robot0_wrist_site_pos', 'robot0_wrist_site_quat'],
197 | new_obs='robot0_wrist_site_pose'))
198 |
199 | # Add in observations to measure the distance from the TCP to the objects.
200 | for color in ('red', 'green', 'blue'):
201 | env_builder.add_preprocessor(observation_transforms.AddObservation(
202 | obs_name=f'{color}_to_pinch',
203 | obs_callable=_distance_delta_obs(
204 | f'rgb_object_{color}_pose', 'robot0_tcp_pose')))
205 |
206 | # Concatenate the action sent to the robot joints and the gripper actuator.
207 | env_builder.add_preprocessor(observation_transforms.MergeObservations(
208 | obs_to_merge=['robot0_arm_joint_previous_action',
209 | 'robot0_gripper_previous_action'],
210 | new_obs='robot0_previous_action'))
211 |
212 | # Mapping of observation names to match the observation names in the stored
213 | # data.
214 | obs_mapping = {
215 | 'robot0_arm_joint_pos': 'sawyer/joints/angle',
216 | 'robot0_arm_joint_vel': 'sawyer/joints/velocity',
217 | 'robot0_arm_joint_torques': 'sawyer/joints/torque',
218 | 'robot0_tcp_pose': 'sawyer/pinch/pose',
219 | 'robot0_wrist_site_pose': 'sawyer/tcp/pose',
220 | 'robot0_wrist_site_vel_world': 'sawyer/tcp/velocity',
221 | 'robot0_gripper_pos': 'gripper/joints/angle',
222 | 'robot0_gripper_vel': 'gripper/joints/velocity',
223 | 'robot0_gripper_grasp': 'gripper/grasp',
224 | 'robot0_wrist_force': 'wrist/force',
225 | 'robot0_wrist_torque': 'wrist/torque',
226 | 'rgb_object_red_pose': 'rgb30_red/abs_pose',
227 | 'rgb_object_green_pose': 'rgb30_green/abs_pose',
228 | 'rgb_object_blue_pose': 'rgb30_blue/abs_pose',
229 | 'basket_back_left_rgb_img': 'basket_back_left/pixels',
230 | 'basket_front_left_rgb_img': 'basket_front_left/pixels',
231 | 'basket_front_right_rgb_img': 'basket_front_right/pixels',
232 | 'red_to_pinch': 'rgb30_red/to_pinch',
233 | 'blue_to_pinch': 'rgb30_blue/to_pinch',
234 | 'green_to_pinch': 'rgb30_green/to_pinch',
235 | 'robot0_previous_action': 'action/environment',
236 | }
237 |
238 | # Create different subsets of observations.
239 | action_obs = {'action/environment'}
240 |
241 | # These observations only have a single floating point value instead of an
242 | # array.
243 | single_value_obs = {'gripper/joints/angle',
244 | 'gripper/joints/velocity',
245 | 'gripper/grasp'}
246 |
247 | # Rename observations.
248 | env_builder.add_preprocessor(observation_transforms.RenameObservations(
249 | obs_mapping, raise_on_missing=False))
250 |
251 | if use_sparse_reward:
252 | reward_fn = stack_rewards.get_sparse_reward_fn(
253 | top_object=rgb_task.props[0],
254 | bottom_object=rgb_task.props[2],
255 | get_physics_fn=lambda: task_env.physics)
256 | else:
257 | reward_fn = stack_rewards.get_shaped_stacking_reward()
258 | env_builder.add_preprocessor(reward_functions.RewardPreprocessor(reward_fn))
259 |
260 |   # We concatenate several observations from consecutive timesteps. Depending
261 |   # on the observation, we concatenate a different number of timesteps:
262 |   # - Most observations are stacked 3 times.
263 | # - Camera observations are not stacked.
264 | # - The action observation is stacked twice.
265 | # - When stacking three scalar (i.e. numpy array of shape (1,)) observations,
266 | # we do not add a leading dimension, so the final shape is (3,).
267 | env_builder.add_preprocessor(
268 | observation_transforms.StackObservations(
269 | obs_to_stack=list(
270 | set(_STATE_OBSERVATIONS) - action_obs - single_value_obs),
271 | stack_depth=_OBSERVATION_STACK_DEPTH,
272 | add_leading_dim=True))
273 | env_builder.add_preprocessor(
274 | observation_transforms.StackObservations(
275 | obs_to_stack=list(single_value_obs),
276 | stack_depth=_OBSERVATION_STACK_DEPTH,
277 | add_leading_dim=False))
278 | env_builder.add_preprocessor(
279 | observation_transforms.StackObservations(
280 | obs_to_stack=list(action_obs),
281 | stack_depth=_ACTION_OBS_STACK_DEPTH,
282 | add_leading_dim=True))
283 |
284 |   # Only keep the observations that we want to expose to the agent.
285 | env_builder.add_preprocessor(observation_transforms.RetainObservations(
286 | observation_set.observations, raise_on_missing=False))
287 |
288 | # End episodes after 400 steps.
289 | env_builder.add_preprocessor(
290 | subtask_termination.MaxStepsTermination(_MAX_STEPS))
291 |
292 | return env_builder.build()
293 |
294 |
295 | def _distance_delta_obs(key1: str, key2: str):
296 | """Returns a callable that returns the difference between two observations."""
297 | def util(timestep: tsp.PreprocessorTimestep) -> np.ndarray:
298 | return timestep.observation[key1] - timestep.observation[key2]
299 | return util
300 |
301 |
--------------------------------------------------------------------------------
/rgb_stacking/environment_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Tests environment.py."""
17 |
18 | from absl.testing import absltest
19 | from dm_env import test_utils
20 |
21 | from rgb_stacking import environment
22 |
23 |
24 | class EnvironmentTest(test_utils.EnvironmentTestMixin, absltest.TestCase):
25 |
26 | def make_object_under_test(self):
27 | return environment.rgb_stacking(
28 | object_triplet='rgb_test_triplet1',
29 | observation_set=environment.ObservationSet.ALL)
30 |
31 |
32 | if __name__ == '__main__':
33 | absltest.main()
34 |
--------------------------------------------------------------------------------
/rgb_stacking/main.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Script to run a viewer to visualize the rgb stacking environment."""
17 |
18 | from typing import Sequence
19 |
20 | from absl import app
21 | from absl import flags
22 | from absl import logging
23 | import cv2
24 | from dm_control import viewer
25 | from dm_robotics.manipulation.props.rgb_objects import rgb_object
26 | from dm_robotics.moma import subtask_env
27 | import numpy as np
28 |
29 | from rgb_stacking import environment
30 | from rgb_stacking.utils import policy_loading
31 |
32 | _TEST_OBJECT_TRIPLETS = tuple(rgb_object.PROP_TRIPLETS_TEST.keys())
33 |
34 | _ALL_OBJECT_TRIPLETS = _TEST_OBJECT_TRIPLETS + (
35 | 'rgb_train_random',
36 | 'rgb_test_random',
37 | )
38 |
39 | _POLICY_DIR = ('assets/saved_models')
40 |
41 | _POLICY_PATHS = {
42 | k: f'{_POLICY_DIR}/mpo_state_{k}' for k in _TEST_OBJECT_TRIPLETS
43 | }
44 |
45 | _OBJECT_TRIPLET = flags.DEFINE_enum(
46 | 'object_triplet', 'rgb_test_random', _ALL_OBJECT_TRIPLETS,
47 | 'Triplet of RGB objects to use in the environment.')
48 | _POLICY_OBJECT_TRIPLET = flags.DEFINE_enum(
49 | 'policy_object_triplet', None, _TEST_OBJECT_TRIPLETS,
50 | 'Optional test triplet name indicating to load a policy that was trained on'
51 | ' this triplet.')
52 | _LAUNCH_VIEWER = flags.DEFINE_bool(
53 | 'launch_viewer', True,
54 | 'Optional boolean. If True, will launch the dm_control viewer. If False'
55 | ' will load the policy, run it and save a recording of it as an .mp4.')
56 |
57 |
58 | def run_episode_and_render(
59 | env: subtask_env.SubTaskEnvironment,
60 | policy: policy_loading.Policy
61 | ) -> Sequence[np.ndarray]:
62 | """Saves a gif of the policy running against the environment."""
63 | rendered_images = []
64 | logging.info('Starting the rendering of the policy, this might take some'
65 | ' time...')
66 | state = policy.initial_state()
67 | timestep = env.reset()
68 | rendered_images.append(env.physics.render(camera_id='main_camera'))
69 | while not timestep.last():
70 | (action, _), state = policy.step(timestep, state)
71 | timestep = env.step(action)
72 | rendered_images.append(env.physics.render(camera_id='main_camera'))
73 | logging.info('Done rendering!')
74 | return rendered_images
75 |
76 |
77 | def main(argv: Sequence[str]) -> None:
78 |
79 | del argv
80 |
81 | if not _LAUNCH_VIEWER.value and _POLICY_OBJECT_TRIPLET.value is None:
82 | raise ValueError('To record a video, a policy must be given.')
83 |
84 | # Load the rgb stacking environment.
85 | with environment.rgb_stacking(object_triplet=_OBJECT_TRIPLET.value) as env:
86 |
87 | # Optionally load a policy trained on one of these environments.
88 | if _POLICY_OBJECT_TRIPLET.value is not None:
89 | policy_path = _POLICY_PATHS[_POLICY_OBJECT_TRIPLET.value]
90 | policy = policy_loading.policy_from_path(policy_path)
91 | else:
92 | policy = None
93 |
94 | if _LAUNCH_VIEWER.value:
95 | # The viewer requires a callable as a policy.
96 | if policy is not None:
97 | policy = policy_loading.StatefulPolicyCallable(policy)
98 | viewer.launch(env, policy=policy)
99 | else:
100 |
101 | # Render the episode.
102 | rendered_episode = run_episode_and_render(env, policy)
103 |
104 | # Save as mp4 video in current directory.
105 | height, width, _ = rendered_episode[0].shape
106 | out = cv2.VideoWriter(
107 | './rendered_policy.mp4',
108 | cv2.VideoWriter_fourcc('m', 'p', '4', 'v'),
109 | 1.0 / env.task.control_timestep, (width, height))
110 | for image in rendered_episode:
111 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
112 | out.write(image)
113 | out.release()
114 |
115 | if __name__ == '__main__':
116 | app.run(main)
117 |
--------------------------------------------------------------------------------
/rgb_stacking/physics_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """This file contains various helper functions for finding mujoco collisions.
17 | """
18 |
19 | import itertools
20 |
21 | DEFAULT_OBJECT_COLLISION_MARGIN = 0.0002
22 | DEFAULT_COLLISION_MARGIN = 1e-8
23 |
24 | OBJECT_GEOM_PREFIXES = ['rgb']
25 | GROUND_GEOM_PREFIXES = ['work_surface', 'ground']
26 | ROBOT_GEOM_PREFIXES = ['robot']
27 |
28 |
29 | def has_object_collision(physics, collision_geom_prefix,
30 | margin=DEFAULT_OBJECT_COLLISION_MARGIN):
31 | """Check for collisions between geoms and objects."""
32 | return has_collision(
33 | physics=physics,
34 | collision_geom_prefix_1=[collision_geom_prefix],
35 | collision_geom_prefix_2=OBJECT_GEOM_PREFIXES,
36 | margin=margin)
37 |
38 |
39 | def has_ground_collision(physics, collision_geom_prefix,
40 | margin=DEFAULT_COLLISION_MARGIN):
41 | """Check for collisions between geoms and the ground."""
42 | return has_collision(
43 | physics=physics,
44 | collision_geom_prefix_1=[collision_geom_prefix],
45 | collision_geom_prefix_2=GROUND_GEOM_PREFIXES,
46 | margin=margin)
47 |
48 |
49 | def has_robot_collision(physics, collision_geom_prefix,
50 | margin=DEFAULT_COLLISION_MARGIN):
51 | """Check for collisions between geoms and the robot."""
52 | return has_collision(
53 | physics=physics,
54 | collision_geom_prefix_1=[collision_geom_prefix],
55 | collision_geom_prefix_2=ROBOT_GEOM_PREFIXES,
56 | margin=margin)
57 |
58 |
59 | def has_collision(physics, collision_geom_prefix_1, collision_geom_prefix_2,
60 | margin=DEFAULT_COLLISION_MARGIN):
61 | """Check for collisions between geoms."""
62 | for contact in physics.data.contact:
63 | if contact.dist > margin:
64 | continue
65 | geom1_name = physics.model.id2name(contact.geom1, 'geom')
66 | geom2_name = physics.model.id2name(contact.geom2, 'geom')
67 | for pair in itertools.product(
68 | collision_geom_prefix_1, collision_geom_prefix_2):
69 | if ((geom1_name.startswith(pair[0]) and
70 | geom2_name.startswith(pair[1])) or
71 | (geom2_name.startswith(pair[0]) and
72 | geom1_name.startswith(pair[1]))):
73 | return True
74 | return False
75 |
--------------------------------------------------------------------------------
/rgb_stacking/reward_functions.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Generic reward functions."""
17 |
18 | from typing import Callable, Optional, Iterable
19 | from absl import logging
20 | from dm_robotics import agentflow as af
21 | from dm_robotics.agentflow import spec_utils
22 | import numpy as np
23 |
24 | # Value returned by the gripper grasp observation when the gripper is in the
25 | # 'grasp' state (fingers closed and exerting force)
26 | _INWARD_GRASP = 2
27 |
28 | # Minimal value for the position tolerance of the shaped distance rewards.
29 | MINIMUM_POSITION_TOLERANCE = 1e-9
30 |
31 | RewardFunction = Callable[[spec_utils.ObservationValue], float]
32 |
33 |
34 | class RewardPreprocessor(af.TimestepPreprocessor):
35 | """Timestep preprocessor wrapper around a reward function."""
36 |
37 | def __init__(self, reward_function: RewardFunction):
38 | super().__init__()
39 | self._reward_function = reward_function
40 |
41 | def _process_impl(
42 | self, timestep: af.PreprocessorTimestep) -> af.PreprocessorTimestep:
43 | reward = self._reward_function(timestep.observation)
44 | reward = self._out_spec.reward_spec.dtype.type(reward)
45 | return timestep.replace(reward=reward)
46 |
47 | def _output_spec(
48 | self, input_spec: spec_utils.TimeStepSpec) -> spec_utils.TimeStepSpec:
49 | return input_spec
50 |
51 |
52 | class GraspReward:
53 | """Sparse reward for the gripper grasp status."""
54 |
55 | def __init__(self, obs_key: str):
56 | """Creates new GraspReward function.
57 |
58 | Args:
59 | obs_key: Key of the grasp observation in the observation spec.
60 | """
61 | self._obs_key = obs_key
62 |
63 | def __call__(self, obs: spec_utils.ObservationValue):
64 | is_grasped = obs[self._obs_key][0] == _INWARD_GRASP
65 | return float(is_grasped)
66 |
67 |
68 | class StackPair:
69 | """Get two objects to be above each other.
70 |
71 | To determine the expected position the top object should be at, the size of
72 | the objects must be specified. Currently, only objects with 3-axial symmetry
73 | are supported; for other objects, the distance would impose a constraint on
74 | the possible orientations.
75 |
76 | Reward will be given if the top object's center is within a certain distance
77 | of the point above the bottom object, with the distance tolerance being
78 | separately configurable for the horizontal plane and the vertical axis.
79 |
80 | By default, reward is also only given when the robot is not currently grasping
81 | anything, though this can be deactivated.
82 | """
83 |
84 | def __init__(self,
85 | obj_top: str,
86 | obj_bottom: str,
87 | obj_top_size: float,
88 | obj_bottom_size: float,
89 | horizontal_tolerance: float,
90 | vertical_tolerance: float,
91 |                grasp_key: Optional[str] = "gripper/grasp"):
92 | """Initialize the module.
93 |
94 | Args:
95 | obj_top: Key in the observation dict containing the top object's
96 | position.
97 | obj_bottom: Key in the observation dict containing the bottom
98 | object's position.
99 | obj_top_size: Height of the top object along the vertical axis, in meters.
100 | obj_bottom_size: Height of the bottom object along the vertical axis, in
101 | meters.
102 | horizontal_tolerance: Max distance from the exact stack position on the
103 | horizontal plane to still generate reward.
104 | vertical_tolerance: Max distance from the exact stack position along the
105 | vertical axis to still generate reward.
106 | grasp_key: Key in the observation dict containing the (buffered) grasp
107 |         status. Can be set to `None` to skip the grasp-status check when
108 |         returning a reward.
109 | """
110 | self._top_key = obj_top
111 | self._bottom_key = obj_bottom
112 | self._horizontal_tolerance = horizontal_tolerance
113 | self._vertical_tolerance = vertical_tolerance
114 | self._grasp_key = grasp_key
115 |
116 | if obj_top_size <= 0. or obj_bottom_size <= 0.:
117 | raise ValueError("Object sizes cannot be zero.")
118 | self._expected_dist = (obj_top_size + obj_bottom_size) / 2.
119 |
120 | def __call__(self, obs: spec_utils.ObservationValue):
121 | top = obs[self._top_key]
122 | bottom = obs[self._bottom_key]
123 |
124 | horizontal_dist = np.linalg.norm(top[:2] - bottom[:2])
125 | if horizontal_dist > self._horizontal_tolerance:
126 | return 0.
127 |
128 | vertical_dist = top[2] - bottom[2]
129 | if np.abs(vertical_dist - self._expected_dist) > self._vertical_tolerance:
130 | return 0.
131 |
132 | if self._grasp_key is not None:
133 | grasp = obs[self._grasp_key]
134 | if grasp == _INWARD_GRASP:
135 | return 0.
136 |
137 | return 1.
138 |
139 |
140 | def tanh_squared(x: np.ndarray, margin: float, loss_at_margin: float = 0.95):
141 | """Returns a sigmoidal shaping loss based on Hafner & Reidmiller (2011).
142 |
143 | Args:
144 | x: A numpy array representing the error.
145 | margin: Margin parameter, a positive `float`.
146 | loss_at_margin: The loss when `l2_norm(x) == margin`. A `float` between 0
147 | and 1.
148 |
149 | Returns:
150 | Shaping loss, a `float` bounded in the half-open interval [0, 1).
151 |
152 | Raises:
153 | ValueError: If the value of `margin` or `loss_at_margin` is invalid.
154 | """
155 |
156 | if not margin > 0:
157 | raise ValueError("`margin` must be positive.")
158 | if not 0.0 < loss_at_margin < 1.0:
159 | raise ValueError("`loss_at_margin` must be between 0 and 1.")
160 |
161 | error = np.linalg.norm(x)
162 | # Compute weight such that at the margin tanh(w * error) = loss_at_margin
163 | w = np.arctanh(np.sqrt(loss_at_margin)) / margin
164 | s = np.tanh(w * error)
165 | return s * s
166 |
167 |
168 | class DistanceReward:
169 | """Shaped reward based on the distance B-A between two entities A and B."""
170 |
171 | def __init__(
172 | self,
173 | key_a: str,
174 | key_b: Optional[str],
175 | position_tolerance: Optional[np.ndarray] = None,
176 | shaping_tolerance: float = 0.1,
177 | loss_at_tolerance: float = 0.95,
178 | max_reward: float = 1.,
179 | offset: Optional[np.ndarray] = None,
180 | z_min: Optional[float] = None,
181 | dim=3
182 | ):
183 | """Initialize the module.
184 |
185 | Args:
186 | key_a: Observation dict key to numpy array containing the position of
187 | object A.
188 | key_b: None or observation dict key to numpy array containing the position
189 | of object B. If None, distance simplifies to d = offset - A.
190 |       position_tolerance: Vector of length `dim`. If
191 |         `norm(distance / position_tolerance) <= 1`, returns `max_reward`
192 |         instead of the shaped reward. Setting this to `None`, or setting any
193 |         entry to zero or close to zero, effectively disables the tolerance.
194 |       shaping_tolerance: Scalar distance at which the loss is equal to
195 |         `loss_at_tolerance`. Must be a positive float or `None`. If `None`,
196 |         the reward is sparse and 0 is returned whenever
197 |         `norm(distance / position_tolerance) > 1`.
198 | loss_at_tolerance: The loss when `l2_norm(distance) == shaping_tolerance`.
199 | A `float` between 0 and 1.
200 |       max_reward: Reward to return when `norm(distance / position_tolerance) <= 1`.
201 | offset: Vector of length 3 that is added to the distance, i.e.
202 | `distance = B - A + offset`.
203 |       z_min: Absolute height that the center of object A has to be above in
204 |         order to generate reward. Used, for example, in hovering rewards.
205 |       dim: The dimensionality of the space in which the distance is computed.
206 | """
207 | self._key_a = key_a
208 | self._key_b = key_b
209 |
210 | self._shaping_tolerance = shaping_tolerance
211 | self._loss_at_tolerance = loss_at_tolerance
212 | if max_reward < 1.:
213 | logging.warning("Maximum reward should not be below tanh maximum.")
214 | self._max_reward = max_reward
215 | self._z_min = z_min
216 | self._dim = dim
217 |
218 | if position_tolerance is None:
219 | self._position_tolerance = np.full(
220 | (dim,), fill_value=MINIMUM_POSITION_TOLERANCE)
221 | else:
222 | self._position_tolerance = position_tolerance
223 | self._position_tolerance[self._position_tolerance == 0] = (
224 | MINIMUM_POSITION_TOLERANCE)
225 |
226 | if offset is None:
227 | self._offset = np.zeros((dim,))
228 | else:
229 | self._offset = offset
230 |
231 | def __call__(self, obs: spec_utils.ObservationValue) -> float:
232 |
233 | # Check that object A is high enough before computing the reward.
234 | if self._z_min is not None and obs[self._key_a][2] < self._z_min:
235 | return 0.
236 |
237 | self._current_distance = (self._offset - obs[self._key_a][0:self._dim])
238 | if self._key_b is not None:
239 | self._current_distance += obs[self._key_b][0:self._dim]
240 |
241 | weighted = self._current_distance / self._position_tolerance
242 | if np.linalg.norm(weighted) <= 1.:
243 | return self._max_reward
244 |
245 | if not self._shaping_tolerance:
246 | return 0.
247 |
248 | loss = tanh_squared(
249 | self._current_distance, margin=self._shaping_tolerance,
250 | loss_at_margin=self._loss_at_tolerance)
251 | return 1.0 - loss
252 |
253 |
254 | class LiftShaped:
255 | """Linear shaped reward for lifting, up to a specified height.
256 |
257 | Once the height is above a specified threshold, reward saturates. Shaping can
258 | also be deactivated for a sparse reward.
259 |
260 |   Requires an observation `<object name>/abs_pose` containing the
261 | pose of the object in question.
262 | """
263 |
264 | def __init__(
265 | self,
266 | obj_key,
267 | z_threshold,
268 | z_min,
269 | max_reward=1.,
270 | shaping=True
271 | ):
272 | """Initialize the module.
273 |
274 | Args:
275 | obj_key: Key in the observation dict containing the object pose.
276 | z_threshold: Absolute object height at which the maximum reward will be
277 | given.
278 |       z_min: Absolute height that the object center has to be above in
279 | order to generate shaped reward. Ignored if `shaping` is False.
280 | max_reward: Reward given when the object is above the `z_threshold`.
281 | shaping: If true, will give a linear shaped reward when the object height
282 | is above `z_min`, but below `z_threshold`.
283 | Raises:
284 | ValueError: if `z_min` is larger than `z_threshold`.
285 | """
286 | self._field = obj_key
287 | self._z_threshold = z_threshold
288 | self._shaping = shaping
289 | self._max_reward = max_reward
290 | self._z_min = z_min
291 | if z_min > z_threshold:
292 | raise ValueError("Lower shaping bound cannot be below upper bound.")
293 |
294 | def __call__(self, obs: spec_utils.ObservationValue) -> float:
295 | obj_z = obs[self._field][2]
296 | if obj_z <= self._z_min:
297 | return 0.0
298 | if obj_z >= self._z_threshold:
299 | return self._max_reward
300 | if self._shaping:
301 | r = (obj_z - self._z_min) / (self._z_threshold - self._z_min)
302 | return r
303 | return 0.0
304 |
305 |
306 | class Product:
307 | """Computes the product of a set of rewards."""
308 |
309 | def __init__(self, terms: Iterable[RewardFunction]):
310 | self._terms = terms
311 |
312 | def __call__(self, obs: spec_utils.ObservationValue) -> float:
313 | r = 1.
314 | for term in self._terms:
315 | r *= term(obs)
316 | return r
317 |
318 |
319 | class _WeightedTermReward:
320 | """Base class for rewards using lists of weighted terms."""
321 |
322 | def __init__(self,
323 | terms: Iterable[RewardFunction],
324 | weights: Optional[Iterable[float]] = None):
325 | """Initialize the reward instance.
326 |
327 | Args:
328 | terms: List of reward callables to be operated on. Each callable must
329 | take an observation as input, and return a float.
330 | weights: Weight that each reward returned by the callables in `terms` will
331 | be multiplied by. If `None`, will weight all terms with 1.0.
332 | Raises:
333 | ValueError: If `weights` has been specified, but its length differs from
334 | that of `terms`.
335 | """
336 | self._terms = list(terms)
337 | self._weights = weights or [1.] * len(self._terms)
338 | if len(self._weights) != len(self._terms):
339 | raise ValueError("Number of terms and weights should be same.")
340 |
341 | def _weighted_terms(
342 | self, obs: spec_utils.ObservationValue) -> Iterable[float]:
343 | return [t(obs) * w for t, w in zip(self._terms, self._weights)]
344 |
345 |
346 | class Max(_WeightedTermReward):
347 | """Selects the maximum among a number of weighted rewards."""
348 |
349 | def __call__(self, obs: spec_utils.ObservationValue) -> float:
350 | return max(self._weighted_terms(obs))
351 |
352 |
353 | class ConditionalAnd:
354 | """Perform an and operation conditional on term1 exceeding a threshold."""
355 |
356 | def __init__(self,
357 | term1: RewardFunction,
358 | term2: RewardFunction,
359 | threshold: float):
360 | self._term1 = term1
361 | self._term2 = term2
362 | self._thresh = threshold
363 |
364 | def __call__(self, obs: spec_utils.ObservationValue) -> float:
365 | r1 = self._term1(obs)
366 | r2 = self._term2(obs)
367 | if r1 > self._thresh:
368 | return (0.5 + r2 / 2.) * r1
369 | else:
370 | return r1 * 0.5
371 |
372 |
373 | class Staged:
374 | """Stages the rewards.
375 |
376 |   This works by iterating through the terms from last to first and using the
377 |   first reward found to exceed the provided threshold, plus the number of
378 |   terms preceding it; the sum is normalized by the number of terms.
379 |
380 |   Rewards must be in [0, 1], otherwise they will be clipped.
381 | """
382 |
383 | def __init__(
384 | self, terms: Iterable[RewardFunction], threshold: float):
385 | def make_clipped(term: RewardFunction):
386 | return lambda obs: np.clip(term(obs), 0., 1.)
387 | self._terms = [make_clipped(term) for term in terms]
388 | self._thresh = threshold
389 |
390 | def __call__(self, obs: spec_utils.ObservationValue) -> float:
391 | last_reward = 0.
392 | num_stages = float(len(self._terms))
393 | for i, term in enumerate(reversed(self._terms)):
394 | last_reward = term(obs)
395 | if last_reward > self._thresh:
396 | # Found a reward above the threshold, add number of preceding terms
397 | # and normalize with the number of terms.
398 | return (len(self._terms) - (i + 1) + last_reward) / num_stages
399 |     # No term exceeded the threshold; return the normalized first-stage reward.
400 | return last_reward / num_stages
401 |
--------------------------------------------------------------------------------
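For reference, `tanh_squared` above computes loss = tanh(w * ||x||)^2 with w = arctanh(sqrt(loss_at_margin)) / margin, so the loss equals `loss_at_margin` exactly when ||x|| == margin. The sketch below shows one way the primitives in reward_functions.py can be composed; the observation keys and tolerance values are placeholders chosen for illustration, not the ones used by the actual stacking task:

# Hypothetical composition of the reward primitives defined above. Observation
# keys and tolerances are illustrative only.
import numpy as np
from rgb_stacking import reward_functions

reach = reward_functions.DistanceReward(
    key_a='gripper/pinch_pos',    # assumed observation key
    key_b='object_a/abs_pose',    # assumed observation key
    position_tolerance=np.array([0.02, 0.02, 0.02]),
    shaping_tolerance=0.15)
grasp = reward_functions.GraspReward(obs_key='gripper/grasp')
lift = reward_functions.LiftShaped(
    obj_key='object_a/abs_pose', z_threshold=0.15, z_min=0.06)

# Stage the sub-rewards; the most advanced stage whose reward exceeds the
# threshold determines the overall reward, normalized by the number of stages.
staged = reward_functions.Staged([reach, grasp, lift], threshold=0.9)

# The result is an ordinary callable from observations to a float, so it can
# be wrapped into an agentflow timestep preprocessor.
preprocessor = reward_functions.RewardPreprocessor(staged)
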
/rgb_stacking/stack_rewards.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Shaped and sparse reward functions for the RGB stacking task."""
17 |
18 | from typing import Callable
19 |
20 | from dm_control import mjcf
21 | from dm_robotics.agentflow import spec_utils
22 | from dm_robotics.moma import prop
23 | import numpy as np
24 |
25 | from rgb_stacking import physics_utils
26 | from rgb_stacking import reward_functions
27 | from rgb_stacking import task
28 |
29 | # Heights above basket surface between which shaping for Lift will be computed.
30 | _LIFT_Z_MAX_ABOVE_BASKET = 0.1
31 | _LIFT_Z_MIN_ABOVE_BASKET = 0.055
32 |
33 | # Height for lifting.
34 | _LIFTING_HEIGHT = 0.04
35 |
36 | # Cylinder around (bottom object + object height) in which the top object must
37 | # be for it to be considered stacked.
38 | _STACK_HORIZONTAL_TOLERANCE = 0.03
39 | _STACK_VERTICAL_TOLERANCE = 0.01
40 |
41 | # Target height above bottom object for hover-above reward.
42 | _HOVER_OFFSET = 0.1
43 |
44 | # Distances at which maximum reach reward is given and shaping decays.
45 | _REACH_POSITION_TOLERANCE = np.array((0.02, 0.02, 0.02))
46 | _REACH_SHAPING_TOLERANCE = 0.15
47 |
48 | # Distances at which maximum hovering reward is given and shaping decays.
49 | _HOVER_POSITION_TOLERANCE = np.array((0.01, 0.01, 0.01))
50 | _HOVER_SHAPING_TOLERANCE = 0.2
51 |
52 | # Observation keys needed to compute the different rewards.
53 | _PINCH_POSE = 'sawyer/pinch/pose'
54 | _FINGER_ANGLE = 'gripper/joints/angle'
55 | _GRIPPER_GRASP = 'gripper/grasp'
56 | _RED_OBJECT_POSE = 'rgb30_red/abs_pose'
57 | _BLUE_OBJECT_POSE = 'rgb30_blue/abs_pose'
58 |
59 |
60 | # Object closeness threshold for x and y axis.
61 | _X_Y_CLOSE = 0.05
62 | _ONTOP = 0.02
63 |
64 |
65 | def get_shaped_stacking_reward() -> reward_functions.RewardFunction:
66 | """Returns a callable reward function for stacking the red block on blue."""
67 |   # First stage: reach and grasp.
68 | reach_red = reward_functions.DistanceReward(
69 | key_a=_PINCH_POSE,
70 | key_b=_RED_OBJECT_POSE,
71 | shaping_tolerance=_REACH_SHAPING_TOLERANCE,
72 | position_tolerance=_REACH_POSITION_TOLERANCE)
73 | close_fingers = reward_functions.DistanceReward(
74 | key_a=_FINGER_ANGLE,
75 | key_b=None,
76 | position_tolerance=None,
77 | shaping_tolerance=255.,
78 | loss_at_tolerance=0.95,
79 | max_reward=1.,
80 | offset=np.array((255,)),
81 | dim=1)
82 | grasp = reward_functions.Max(
83 | terms=(
84 | close_fingers, reward_functions.GraspReward(obs_key=_GRIPPER_GRASP)),
85 | weights=(0.5, 1.))
86 | reach_grasp = reward_functions.ConditionalAnd(reach_red, grasp, 0.9)
87 |
88 | # Second stage: grasp and lift.
89 | lift_reward = _get_reward_lift_red()
90 | lift_red = reward_functions.Product([grasp, lift_reward])
91 |
92 | # Third stage: hover.
93 | top = _RED_OBJECT_POSE
94 | bottom = _BLUE_OBJECT_POSE
95 | place = reward_functions.DistanceReward(
96 | key_a=top,
97 | key_b=bottom,
98 | offset=np.array((0., 0., _LIFTING_HEIGHT)),
99 | position_tolerance=_HOVER_POSITION_TOLERANCE,
100 | shaping_tolerance=_HOVER_SHAPING_TOLERANCE,
101 | z_min=0.1)
102 |
103 | # Fourth stage: stack.
104 | stack = _get_reward_stack_red_on_blue()
105 |
106 | # Final stage: stack-and-leave
107 | stack_leave = reward_functions.Product(
108 | terms=(_get_reward_stack_red_on_blue(), _get_reward_above_red()))
109 |
110 | return reward_functions.Staged(
111 | [reach_grasp, lift_red, place, stack, stack_leave], 0.01)
112 |
113 |
114 | def get_sparse_stacking_reward() -> reward_functions.RewardFunction:
115 | """Sparse stacking reward for red-on-blue with the gripper moved away."""
116 | return reward_functions.Product(
117 | terms=(_get_reward_stack_red_on_blue(), _get_reward_above_red()))
118 |
119 |
120 | def _get_reward_lift_red() -> reward_functions.RewardFunction:
121 | """Returns a callable reward function for lifting the red block."""
122 | lift_reward = reward_functions.LiftShaped(
123 | obj_key=_RED_OBJECT_POSE,
124 | z_threshold=task.DEFAULT_BASKET_HEIGHT + _LIFT_Z_MAX_ABOVE_BASKET,
125 | z_min=task.DEFAULT_BASKET_HEIGHT + _LIFT_Z_MIN_ABOVE_BASKET)
126 | # Keep the object inside the area of the base plate.
127 | inside_basket = reward_functions.DistanceReward(
128 | key_a=_RED_OBJECT_POSE,
129 | key_b=None,
130 | position_tolerance=task.WORKSPACE_SIZE / 2.,
131 | shaping_tolerance=1e-12, # Practically none.
132 | loss_at_tolerance=0.95,
133 | max_reward=1.,
134 | offset=task.WORKSPACE_CENTER)
135 | return reward_functions.Product([lift_reward, inside_basket])
136 |
137 |
138 | def _get_reward_stack_red_on_blue() -> reward_functions.RewardFunction:
139 | """Returns a callable reward function for stacking the red block on blue."""
140 | return reward_functions.StackPair(
141 | obj_top=_RED_OBJECT_POSE,
142 | obj_bottom=_BLUE_OBJECT_POSE,
143 | obj_top_size=_LIFTING_HEIGHT,
144 | obj_bottom_size=_LIFTING_HEIGHT,
145 | horizontal_tolerance=_STACK_HORIZONTAL_TOLERANCE,
146 | vertical_tolerance=_STACK_VERTICAL_TOLERANCE)
147 |
148 |
149 | def _get_reward_above_red() -> reward_functions.RewardFunction:
150 | """Returns a callable reward function for being above the red block."""
151 | return reward_functions.DistanceReward(
152 | key_a=_PINCH_POSE,
153 | key_b=_RED_OBJECT_POSE,
154 | shaping_tolerance=0.05,
155 | offset=np.array((0., 0., _HOVER_OFFSET)),
156 | position_tolerance=np.array((1., 1., 0.03))) # Anywhere horizontally.
157 |
158 |
159 | class SparseStack(object):
160 | """Sparse stack reward.
161 |
162 |   Checks that the two objects are within _X_Y_CLOSE of each other in the x-y
163 |   plane, that the top object sits above the bottom one, and that the two
164 |   objects are in contact. Also checks that the top object is not in contact
165 |   with the ground or the robot, to ensure the robot is not holding it in place.
166 | """
167 |
168 | def __init__(self,
169 | top_object: prop.Prop,
170 | bottom_object: prop.Prop,
171 | get_physics_fn: Callable[[], mjcf.Physics]):
172 | """Initializes the reward.
173 |
174 | Args:
175 | top_object: Composer entity of the top object (red).
176 | bottom_object: Composer entity of the bottom object (blue).
177 | get_physics_fn: Callable that returns the current mjc physics from the
178 | environment.
179 | """
180 | self._get_physics_fn = get_physics_fn
181 | self._top_object = top_object
182 | self._bottom_object = bottom_object
183 |
184 | def _align(self, physics):
185 | return np.linalg.norm(
186 | self._top_object.get_pose(physics)[0][:2] -
187 | self._bottom_object.get_pose(physics)[0][:2]) < _X_Y_CLOSE
188 |
189 | def _ontop(self, physics):
190 |
191 | return (self._top_object.get_pose(physics)[0][2] -
192 | self._bottom_object.get_pose(physics)[0][2]) > _ONTOP
193 |
194 | def _pile(self, physics):
195 | geom = '{}/'.format(self._top_object.name)
196 | if physics_utils.has_ground_collision(physics, collision_geom_prefix=geom):
197 | return float(0.0)
198 | if physics_utils.has_robot_collision(physics, collision_geom_prefix=geom):
199 | return float(0.0)
200 | return float(1.0)
201 |
202 | def _collide(self, physics):
203 | collision_geom_prefix_1 = '{}/'.format(self._top_object.name)
204 | collision_geom_prefix_2 = '{}/'.format(self._bottom_object.name)
205 | return physics_utils.has_collision(physics, [collision_geom_prefix_1],
206 | [collision_geom_prefix_2])
207 |
208 | def __call__(self, obs: spec_utils.ObservationValue):
209 | del obs
210 | physics = self._get_physics_fn()
211 | if self._align(physics) and self._pile(physics) and self._collide(
212 | physics) and self._ontop(physics):
213 | return 1.0
214 | return 0.0
215 |
216 |
217 | def get_sparse_reward_fn(
218 | top_object: prop.Prop,
219 | bottom_object: prop.Prop,
220 | get_physics_fn: Callable[[], mjcf.Physics]
221 | ) -> reward_functions.RewardFunction:
222 | """Sparse stacking reward for stacking two props with no robot contact.
223 |
224 | Args:
225 |     top_object: The top object (red).
226 |     bottom_object: The bottom object (blue).
227 | get_physics_fn: Callable that returns the current mjcf physics from the
228 | environment.
229 |
230 | Returns:
231 | The sparse stack reward function.
232 | """
233 | return SparseStack(
234 | top_object=top_object,
235 | bottom_object=bottom_object,
236 | get_physics_fn=get_physics_fn)
237 |
--------------------------------------------------------------------------------
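A minimal sketch of how the factories above might be wired into a subtask follows. How the props and a physics accessor are obtained depends on the environment construction code (environment.py, not shown here); the `env.physics` accessor and the prop variables are assumptions for illustration:

# Hypothetical wiring of the reward factories defined above.
from rgb_stacking import reward_functions
from rgb_stacking import stack_rewards

# The shaped reward is a plain callable over observations.
shaped_reward = stack_rewards.get_shaped_stacking_reward()
shaped_preprocessor = reward_functions.RewardPreprocessor(shaped_reward)

def make_sparse_reward(red_prop, blue_prop, env):
  # `red_prop`/`blue_prop` are the MoMa props for the red and blue objects;
  # `env` is assumed to expose the live mjcf physics as `env.physics`.
  return stack_rewards.get_sparse_reward_fn(
      top_object=red_prop,
      bottom_object=blue_prop,
      get_physics_fn=lambda: env.physics)
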
/rgb_stacking/task.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """A module for constructing the RGB stacking task.
17 |
18 | This file builds the composer task containing a single sawyer robot facing
19 | 3 objects: a red, a green and a blue one.
20 |
21 | We define:
22 | - All the simulation objects, robot, basket, objects.
23 | - The sensors to measure the state of the environment.
24 | - The effector to control the robot.
25 | - The initialization logic.
26 |
27 | On top of this we can build a MoMa subtask environment. In this subtask
28 | environment we will decide what the reward will be and what observations are
29 | exposed, thus allowing us to change the goal without changing this environment.
30 | """
31 | from typing import Sequence
32 |
33 | from dm_control import composer
34 | from dm_control.composer.variation import distributions
35 | from dm_control.composer.variation import rotations
36 | from dm_robotics.geometry import pose_distribution
37 | from dm_robotics.manipulation.props.rgb_objects import rgb_object
38 | from dm_robotics.manipulation.standard_cell import rgb_basket
39 | from dm_robotics.moma import base_task
40 | from dm_robotics.moma import entity_initializer
41 | from dm_robotics.moma import prop
42 | from dm_robotics.moma import robot as moma_robot
43 | from dm_robotics.moma.effectors import arm_effector as arm_effector_module
44 | from dm_robotics.moma.effectors import cartesian_4d_velocity_effector
45 | from dm_robotics.moma.effectors import cartesian_6d_velocity_effector
46 | from dm_robotics.moma.effectors import default_gripper_effector
47 | from dm_robotics.moma.effectors import min_max_effector
48 | from dm_robotics.moma.models.arenas import empty
49 | from dm_robotics.moma.models.end_effectors.robot_hands import robotiq_2f85
50 | from dm_robotics.moma.models.end_effectors.wrist_sensors import robotiq_fts300
51 | from dm_robotics.moma.models.robots.robot_arms import sawyer
52 | from dm_robotics.moma.models.robots.robot_arms import sawyer_constants
53 | from dm_robotics.moma.sensors import action_sensor
54 | from dm_robotics.moma.sensors import camera_sensor
55 | from dm_robotics.moma.sensors import prop_pose_sensor
56 | from dm_robotics.moma.sensors import robot_arm_sensor
57 | from dm_robotics.moma.sensors import robot_tcp_sensor
58 | from dm_robotics.moma.sensors import robot_wrist_ft_sensor
59 | from dm_robotics.moma.sensors import robotiq_gripper_sensor
60 | from dm_robotics.moma.sensors import site_sensor
61 | import numpy as np
62 |
63 |
64 | # Margin from the joint limits at which to stop the Z rotation when using 4D
65 | # control. Values chosen to match the existing PyRobot-based environments.
66 | WRIST_RANGE = (-np.pi / 2, np.pi / 2)
67 |
68 |
69 | # Position of the basket relative to the attachment point of the robot.
70 | _DEFAULT_BASKET_CENTER = (0.6, 0.)
71 | DEFAULT_BASKET_HEIGHT = 0.0498
72 | _BASKET_ORIGIN = _DEFAULT_BASKET_CENTER + (DEFAULT_BASKET_HEIGHT,)
73 | WORKSPACE_CENTER = np.array(_DEFAULT_BASKET_CENTER + (0.1698,))
74 | WORKSPACE_SIZE = np.array([0.25, 0.25, 0.2])
75 | # Maximum linear and angular velocity of the robot's TCP.
76 | _MAX_LIN_VEL = 0.07
77 | _MAX_ANG_VEL = 1.0
78 |
79 | # Limits of the distributions used to sample initial positions (X, Y, Z in [m])
80 | # for props in the basket.
81 | _PROP_MIN_POSITION_BOUNDS = [0.50, -0.10, 0.12]
82 | _PROP_MAX_POSITION_BOUNDS = [0.70, 0.10, 0.12]
83 |
84 | # Limits of the distributions used to sample initial position for TCP.
85 | _TCP_MIN_POSE_BOUNDS = [0.5, -0.14, 0.22, np.pi, 0, -np.pi / 4]
86 | _TCP_MAX_POSE_BOUNDS = [0.7, 0.14, 0.43, np.pi, 0, np.pi / 4]
87 |
88 | # Control timestep exposed to the agent.
89 | _CONTROL_TIMESTEP = 0.05
90 |
91 | # Joint state used for the nullspace.
92 | _NULLSPACE_JOINT_STATE = [
93 | 0.0, -0.5186220703125, -0.529384765625, 1.220857421875, 0.40857421875,
94 | 1.07831640625, 0.0]
95 |
96 | # Joint velocity magnitude limits from the Sawyer URDF.
97 | _JOINT_VEL_LIMITS = sawyer_constants.VELOCITY_LIMITS['max']
98 |
99 | # Identifier for the cameras. The key is the name used for the MoMa camera
100 | # sensor and the value corresponds to the identifier of that camera in the
101 | # mjcf model.
102 | _CAMERA_IDENTIFIERS = {'basket_back_left': 'base/basket_back_left',
103 | 'basket_front_left': 'base/basket_front_left',
104 | 'basket_front_right': 'base/basket_front_right'}
105 |
106 | # Configuration of the MuJoCo cameras.
107 | _CAMERA_CONFIG = camera_sensor.CameraConfig(
108 | width=128,
109 | height=128,
110 | fovy=30.,
111 | has_rgb=True,
112 | has_depth=False,
113 | )
114 |
115 |
116 | def rgb_task(red_obj_id: str,
117 | green_obj_id: str,
118 | blue_obj_id: str) -> base_task.BaseTask:
119 | """Builds a BaseTask and all dependencies.
120 |
121 | Args:
122 | red_obj_id: The RGB object ID that corresponds to the red object. More
123 | information on this can be found in the RGB Objects file.
124 |     green_obj_id: See `red_obj_id`.
125 |     blue_obj_id: See `red_obj_id`.
126 |
127 | Returns:
128 | The modular manipulation (MoMa) base task for the RGB stacking environment.
129 | A robot is placed in front of a basket containing 3 objects: a red, a green
130 |     and a blue one.
131 | """
132 |
133 | # Build the composer scene.
134 | arena = _arena()
135 | _workspace(arena)
136 |
137 | robot = _sawyer_robot(robot_name='robot0')
138 | arena.attach(robot.arm)
139 |
140 | # We add a camera with a good point of view for capturing videos.
141 | pos = '1.4 0.0 0.45'
142 | quat = '0.541 0.455 0.456 0.541'
143 | name = 'main_camera'
144 | fovy = '45'
145 | arena.mjcf_model.worldbody.add(
146 | 'camera', name=name, pos=pos, quat=quat, fovy=fovy)
147 |
148 | props = _props(red_obj_id, green_obj_id, blue_obj_id)
149 | for p in props:
150 | frame = arena.add_free_entity(p)
151 | p.set_freejoint(frame.freejoint)
152 |
153 | # Add in the MoMa sensor to get observations from the environment.
154 | extra_sensors = prop_pose_sensor.build_prop_pose_sensors(props)
155 | camera_configurations = {
156 | name: _CAMERA_CONFIG for name in _CAMERA_IDENTIFIERS.keys()}
157 | extra_sensors.extend(
158 | camera_sensor.build_camera_sensors(
159 | camera_configurations, arena.mjcf_model, _CAMERA_IDENTIFIERS))
160 |
161 | # Initializers to place the TCP and the props in the basket.
162 | dynamic_initializer = entity_initializer.TaskEntitiesInitializer(
163 | [_gripper_initializer(robot), _prop_initializers(props)])
164 |
165 | moma_task = base_task.BaseTask(
166 | task_name='rgb_stacking',
167 | arena=arena,
168 | robots=[robot],
169 | props=props,
170 | extra_sensors=extra_sensors,
171 | extra_effectors=[],
172 | scene_initializer=lambda _: None,
173 | episode_initializer=dynamic_initializer,
174 | control_timestep=_CONTROL_TIMESTEP)
175 | return moma_task
176 |
177 |
178 | def _workspace(arena: composer.Arena) -> rgb_basket.RGBBasket:
179 | """Returns the basket used in the rgb stacking environment."""
180 | workspace = rgb_basket.RGBBasket()
181 | attachment_site = arena.mjcf_model.worldbody.add(
182 | 'site', pos=_BASKET_ORIGIN, rgba='0 0 0 0', size='0.01')
183 | arena.attach(workspace, attachment_site)
184 | return workspace
185 |
186 |
187 | def _gripper_initializer(
188 | robot: moma_robot.Robot) -> entity_initializer.PoseInitializer:
189 | """Populates components with gripper initializers."""
190 |
191 | gripper_pose_dist = pose_distribution.UniformPoseDistribution(
192 | min_pose_bounds=_TCP_MIN_POSE_BOUNDS,
193 | max_pose_bounds=_TCP_MAX_POSE_BOUNDS)
194 | return entity_initializer.PoseInitializer(robot.position_gripper,
195 | gripper_pose_dist.sample_pose)
196 |
197 |
198 | def _prop_initializers(
199 | props: Sequence[prop.Prop]) -> entity_initializer.PropPlacer:
200 | """Populates components with prop pose initializers."""
201 | prop_position = distributions.Uniform(_PROP_MIN_POSITION_BOUNDS,
202 | _PROP_MAX_POSITION_BOUNDS)
203 | prop_quaternion = rotations.UniformQuaternion()
204 |
205 | return entity_initializer.PropPlacer(
206 | props=props,
207 | position=prop_position,
208 | quaternion=prop_quaternion,
209 | settle_physics=True)
210 |
211 |
212 | def _arena() -> composer.Arena:
213 | """Builds an arena Entity."""
214 | arena = empty.Arena()
215 | arena.mjcf_model.size.nconmax = 5000
216 | arena.mjcf_model.size.njmax = 5000
217 |
218 | return arena
219 |
220 |
221 | def _sawyer_robot(robot_name: str) -> moma_robot.Robot:
222 | """Returns a Sawyer robot with all the sensors and effectors."""
223 |
224 | arm = sawyer.Sawyer(
225 | name=robot_name, actuation=sawyer_constants.Actuation.INTEGRATED_VELOCITY)
226 |
227 | gripper = robotiq_2f85.Robotiq2F85()
228 |
229 | wrist_ft = robotiq_fts300.RobotiqFTS300()
230 |
231 | wrist_cameras = []
232 |
233 | # Compose the robot after its model components are constructed. This should
234 | # usually be done early on as some Effectors (and possibly Sensors) can only
235 | # be constructed after the robot components have been composed.
236 | moma_robot.standard_compose(
237 | arm=arm, gripper=gripper, wrist_ft=wrist_ft, wrist_cameras=wrist_cameras)
238 |
239 | # We need to measure the last action sent to the robot and the gripper.
240 | arm_effector, arm_action_sensor = action_sensor.create_sensed_effector(
241 | arm_effector_module.ArmEffector(
242 | arm=arm, action_range_override=None, robot_name=robot_name))
243 |
244 | # Effector used for the gripper. The gripper is controlled by applying the
245 |   # min or max command; this allows the agent to quickly learn how to grasp
246 | # instead of learning how to close the gripper first.
247 | gripper_effector, gripper_action_sensor = action_sensor.create_sensed_effector(
248 | default_gripper_effector.DefaultGripperEffector(gripper, robot_name))
249 |
250 |   # Enable bang-bang control for the gripper; this allows the agent to close and
251 | # open the gripper faster.
252 | gripper_effector = min_max_effector.MinMaxEffector(
253 | base_effector=gripper_effector)
254 |
255 |   # Build the 4D cartesian controller; we use a 6D cartesian effector under the
256 | # hood.
257 | effector_model = cartesian_6d_velocity_effector.ModelParams(
258 | element=arm.wrist_site, joints=arm.joints)
259 | effector_control = cartesian_6d_velocity_effector.ControlParams(
260 | control_timestep_seconds=_CONTROL_TIMESTEP,
261 | max_lin_vel=_MAX_LIN_VEL,
262 | max_rot_vel=_MAX_ANG_VEL,
263 | joint_velocity_limits=np.array(_JOINT_VEL_LIMITS),
264 | nullspace_gain=0.025,
265 | nullspace_joint_position_reference=np.array(_NULLSPACE_JOINT_STATE),
266 | regularization_weight=1e-2,
267 | enable_joint_position_limits=True,
268 | minimum_distance_from_joint_position_limit=0.01,
269 | joint_position_limit_velocity_scale=0.95,
270 | max_cartesian_velocity_control_iterations=300,
271 | max_nullspace_control_iterations=300)
272 |
273 | # Don't activate collision avoidance because we are restricted to the virtual
274 | # workspace in the center of the basket.
275 | cart_effector_6d = cartesian_6d_velocity_effector.Cartesian6dVelocityEffector(
276 | robot_name=robot_name,
277 | joint_velocity_effector=arm_effector,
278 | model_params=effector_model,
279 | control_params=effector_control)
280 | cart_effector_4d = cartesian_4d_velocity_effector.Cartesian4dVelocityEffector(
281 | effector_6d=cart_effector_6d,
282 | element=arm.wrist_site,
283 | effector_prefix=f'{robot_name}_cart_4d_vel')
284 |
285 | # Constrain the workspace of the robot.
286 | cart_effector_4d = cartesian_4d_velocity_effector.limit_to_workspace(
287 | cartesian_effector=cart_effector_4d,
288 | element=gripper.tool_center_point,
289 | min_workspace_limits=WORKSPACE_CENTER - WORKSPACE_SIZE / 2,
290 | max_workspace_limits=WORKSPACE_CENTER + WORKSPACE_SIZE / 2,
291 | wrist_joint=arm.joints[-1],
292 | wrist_limits=WRIST_RANGE,
293 | reverse_wrist_range=True)
294 |
295 | robot_sensors = []
296 |
297 | # Sensor for the joint states (torques, velocities and angles).
298 | robot_sensors.append(robot_arm_sensor.RobotArmSensor(
299 | arm=arm, name=f'{robot_name}_arm', have_torque_sensors=True))
300 |
301 | # Sensor for the cartesian pose of the tcp site.
302 | robot_sensors.append(robot_tcp_sensor.RobotTCPSensor(
303 | gripper=gripper, name=robot_name))
304 |
305 | # Sensor for cartesian pose of the wrist site.
306 | robot_sensors.append(site_sensor.SiteSensor(
307 | site=arm.wrist_site, name=f'{robot_name}_wrist_site'))
308 |
309 | # Sensor to measure the state of the gripper (position, velocity and grasp).
310 | robot_sensors.append(robotiq_gripper_sensor.RobotiqGripperSensor(
311 | gripper=gripper, name=f'{robot_name}_gripper'))
312 |
313 | # Sensor for the wrench measured at the wrist sensor.
314 | robot_sensors.append(robot_wrist_ft_sensor.RobotWristFTSensor(
315 | wrist_ft_sensor=wrist_ft, name=f'{robot_name}_wrist'))
316 |
317 | # Sensors to measure the last action sent to the arm joints and the gripper
318 | # actuator.
319 | robot_sensors.extend([arm_action_sensor, gripper_action_sensor])
320 |
321 | return moma_robot.StandardRobot(
322 | arm=arm,
323 | arm_base_site_name='base_site',
324 | gripper=gripper,
325 | robot_sensors=robot_sensors,
326 | wrist_cameras=wrist_cameras,
327 | arm_effector=cart_effector_4d,
328 | gripper_effector=gripper_effector,
329 | wrist_ft=wrist_ft,
330 | name=robot_name)
331 |
332 |
333 | def _props(red: str, green: str, blue: str) -> Sequence[prop.Prop]:
334 | """Build task props."""
335 | objects = ((red, 'red'), (green, 'green'), (blue, 'blue'))
336 | color_set = [
337 | [1, 0, 0, 1],
338 | [0, 1, 0, 1],
339 | [0, 0, 1, 1],
340 | ]
341 | props = []
342 | for i, (obj_id, color) in enumerate(objects):
343 | p = rgb_object.RgbObjectProp(
344 | obj_id=obj_id, color=color_set[i], name=f'rgb_object_{color}')
345 | p = prop.WrapperProp(wrapped_entity=p, name=f'rgb_object_{color}')
346 | props.append(p)
347 |
348 | return props
349 |
--------------------------------------------------------------------------------
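A minimal sketch of constructing the task above and resetting it as a composer environment; the RGB-object IDs are placeholders (valid identifiers come from the rgb_object module), and passing the MoMa BaseTask directly to composer.Environment is an assumption about the dm_robotics/dm_control versions in use:

# Hypothetical construction of the task defined above.
from dm_control import composer
from rgb_stacking import task

moma_task = task.rgb_task(
    red_obj_id='r3', green_obj_id='g2', blue_obj_id='b1')  # placeholder IDs

env = composer.Environment(task=moma_task)
timestep = env.reset()
print(sorted(timestep.observation.keys()))
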
/rgb_stacking/utils/permissive_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """A permissive wrapper for a SavedModel."""
17 |
18 | import copy
19 | import inspect
20 |
21 | from typing import Any, Callable, NamedTuple
22 |
23 | from absl import logging
24 | import tensorflow as tf
25 | import tree
26 |
27 |
28 | class _Function(NamedTuple):
29 | """Function exposing signature and expected canonical arguments."""
30 | func: Callable[..., Any]
31 | signature: inspect.Signature
32 | structured_specs: Any
33 |
34 | def __call__(self, *args, **kwargs):
35 | return self.func(*args, **kwargs)
36 |
37 | @property
38 | def canonical_arguments(self) -> inspect.BoundArguments:
39 | args, kwargs = copy.deepcopy(self.structured_specs)
40 | return self.signature.bind(*args, **kwargs)
41 |
42 |
43 | class PermissiveModel:
44 | """A permissive wrapper for a SavedModel."""
45 |
46 | # Disable pytype attribute error checks.
47 | _HAS_DYNAMIC_ATTRIBUTES = True
48 |
49 | def __init__(self, model):
50 | self.model = model
51 |
52 | self._tables = self.model.function_tables()
53 | self._initialized_tables = {}
54 |
55 | def build_parameters(params):
56 | params = [
57 | inspect.Parameter(str(param[0], "utf-8"), param[1])
58 | for param in params
59 | ]
60 | # Always include a VAR_KEYWORD to capture any extraneous arguments.
61 | if all([p.kind != inspect.Parameter.VAR_KEYWORD for p in params]):
62 | params.append(
63 | inspect.Parameter("__unused_kwargs", inspect.Parameter.VAR_KEYWORD))
64 | return params
65 |
66 | signatures = self.model.function_signatures()
67 | if tf.executing_eagerly():
68 | signatures = tree.map_structure(lambda x: x.numpy(), signatures)
69 | else:
70 | with tf.compat.v1.Session() as sess:
71 | signatures = sess.run(signatures)
72 | signatures = {
73 | func_name: inspect.Signature(build_parameters(func_params))
74 | for func_name, func_params in signatures.items()
75 | }
76 | self.signatures = signatures
77 |
78 | # Attach deepfuncs.
79 | for name in self.signatures.keys():
80 | setattr(self, name, self._make_permissive_function(name))
81 |
82 | def _maybe_init_tables(self, concrete_func: Any, name: str):
83 | """Initialise all tables for a function if they are not initialised.
84 |
85 | Some functions rely on hash-tables that must be externally initialized. This
86 | method will perform a one-time initialisation of the tables. It does so by
87 | finding the corresponding op that creates the hash-table handles (these will
88 |     be different from the ones observed in the initial deepfuncs), and importing
89 |     the corresponding keys and values.
90 |
91 | Args:
92 | concrete_func: A tf.ConcreteFunction corresponding to a deepfunc.
93 | name: The name of the deepfunc.
94 | """
95 | if name not in self._tables:
96 | return
97 |
98 | all_nodes = dict(
99 | main={n.name: n for n in concrete_func.graph.as_graph_def().node})
100 | for func_def in concrete_func.graph.as_graph_def().library.function:
101 | all_nodes[func_def.signature.name] = {
102 | n.name: n for n in func_def.node_def
103 | }
104 |
105 | for table_name, (table_keys, table_values) in self._tables[name].items():
106 | table_op = None
107 | for nodes in all_nodes.values():
108 | if table_name in nodes:
109 | if table_op is not None:
110 | raise ValueError(f"Duplicate table op found for {table_name}")
111 | table_op = nodes[table_name]
112 |
113 | logging.info("Initialising table for Op `%s`", table_name)
114 | table_handle_name = table_op.attr["shared_name"].s # pytype: disable=attribute-error
115 | table_handle = tf.raw_ops.HashTableV2(
116 | key_dtype=table_keys.dtype,
117 | value_dtype=table_values.dtype,
118 | shared_name=table_handle_name)
119 | tf.raw_ops.LookupTableImportV2(
120 | table_handle=table_handle, keys=table_keys, values=table_values)
121 | self._initialized_tables[name] = self._tables.pop(name) # Only init once.
122 |
123 | def _make_permissive_function(self, name: str) -> Callable[..., Any]:
124 | """Create a permissive version of a function in the SavedModel."""
125 | if name not in self.signatures:
126 | raise ValueError(f"No function named {name} in SavedModel, "
127 |                        f"options are {self.signatures}")
128 |
129 | tf_func = getattr(self.model, name)
130 | if hasattr(tf_func, "concrete_functions"):
131 | # tf.RestoredFunction
132 | concrete_func, = tf_func.concrete_functions # Expect only one.
133 | elif hasattr(tf_func, "_list_all_concrete_functions"):
134 | # tf.Function
135 | all_concrete_funcs = tf_func._list_all_concrete_functions() # pylint: disable=protected-access
136 | all_concrete_signatures = [
137 | f.structured_input_signature for f in all_concrete_funcs
138 | ]
139 | # This is necessary as tf.saved_model.save can force a retrace on
140 | # tf.Function, resulting in another concrete function with identical
141 | # signature.
142 | unique_concrete_signatures = set([
143 | tuple(tree.flatten_with_path(sig)) for sig in all_concrete_signatures
144 | ])
145 | if len(unique_concrete_signatures) != 1:
146 | raise ValueError(
147 | "Expected exactly one unique concrete_function signature, found "
148 | f"the following: {all_concrete_signatures}")
149 | concrete_func = all_concrete_funcs[0]
150 | else:
151 | raise ValueError(f"No concrete functions found on {tf_func}")
152 |
153 | self._maybe_init_tables(concrete_func, name)
154 |
155 | def func(*args, **kwargs):
156 | bound_args = self.signatures[name].bind(*args, **kwargs)
157 | canonical_args = concrete_func.structured_input_signature
158 |
159 | flat_bound_args = tree.flatten_with_path(
160 | (bound_args.args, bound_args.kwargs))
161 | flat_canonical_args = tree.flatten_with_path(canonical_args)
162 |
163 | # Specs for error reporting.
164 | bound_specs = tree.map_structure(
165 | lambda x: None if x is None else object(),
166 | (bound_args.args, bound_args.kwargs))
167 | canonical_specs = tree.map_structure(
168 | lambda x: None if x is None else object(), canonical_args)
169 |
170 | # Check for missing arguments.
171 | flat_bound_args_dict = dict(flat_bound_args)
172 | for arg_path, arg_spec in flat_canonical_args:
173 | if arg_path not in flat_bound_args_dict and arg_spec is not None:
174 | raise ValueError(
175 | f"Missing argument with path {arg_path}, expected canonical args "
176 | f"with structure {canonical_specs}, received {bound_specs}. (All "
177 |             "required values have been replaced by object() for brevity.)")
178 |
179 | if arg_path in flat_bound_args_dict and arg_spec is None:
180 | arg_value = flat_bound_args_dict[arg_path]
181 | if arg_value is not None:
182 | logging.warning(
183 | "Received unexpected argument `%s` for path %s, replaced with "
184 | "None.", arg_value, arg_path)
185 | flat_bound_args_dict[arg_path] = None
186 |
187 | # Filter out extraneous arguments and dictionary keys.
188 | flat_canonical_args_dict = dict(flat_canonical_args)
189 | filtered_flat_bound_args = {
190 | arg_path: arg_value
191 | for arg_path, arg_value in flat_bound_args_dict.items()
192 | if arg_path in flat_canonical_args_dict
193 | }
194 | full_flat_bound_args = [
195 | filtered_flat_bound_args.get(arg_path, None)
196 | for arg_path, _ in flat_canonical_args
197 | ]
198 | filtered_args, filtered_kwargs = tree.unflatten_as(
199 | canonical_args, full_flat_bound_args)
200 |
201 | return tf_func(*filtered_args, **filtered_kwargs)
202 |
203 | return _Function(
204 | func,
205 | copy.deepcopy(self.signatures[name]),
206 | copy.deepcopy(concrete_func.structured_input_signature),
207 | )
208 |
--------------------------------------------------------------------------------
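The point of the wrapper above is that callers may pass arguments the underlying concrete function does not expect; such entries are logged and replaced with None instead of raising, while missing required arguments still produce a clear error. A minimal usage sketch (the SavedModel path and the presence of a `step` signature are assumptions for illustration):

# Hypothetical usage of PermissiveModel.
import tensorflow as tf
from rgb_stacking.utils import permissive_model

loaded = tf.saved_model.load('/path/to/saved_model')  # placeholder path
model = permissive_model.PermissiveModel(loaded)

# Each exported function (e.g. `step`, if present) is attached to the wrapper;
# its expected argument structure can be inspected before calling:
# bound = model.step.canonical_arguments
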
/rgb_stacking/utils/policy_loading.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utility for loading RGB stacking policies saved as TF SavedModel."""
17 |
18 | from typing import NamedTuple
19 | import dm_env
20 | import numpy as np
21 | # This reverb import is needed; without it, loading a SavedModel fails with an
22 | # error saying that the ReverbDataset op is not found.
23 | import reverb # pylint: disable=unused-import
24 | import tensorflow as tf
25 | import tree
26 | import typing_extensions
27 |
28 | from rgb_stacking.utils import permissive_model
29 |
30 |
31 | class _MPOState(NamedTuple):
32 | counter: tf.Tensor
33 | actor: tree.Structure[tf.Tensor]
34 | critic: tree.Structure[tf.Tensor]
35 |
36 |
37 | @typing_extensions.runtime
38 | class Policy(typing_extensions.Protocol):
39 |
40 | def step(self, timestep: dm_env.TimeStep, state: _MPOState):
41 | pass
42 |
43 | def initial_state(self) -> _MPOState:
44 | pass
45 |
46 |
47 | def policy_from_path(saved_model_path: str) -> Policy:
48 | """Loads policy from stored TF SavedModel."""
49 | policy = tf.saved_model.load(saved_model_path)
50 |   # Relax the strict requirements on the policy's expected inputs, e.g. with
51 |   # regard to unused arguments.
52 | policy = permissive_model.PermissiveModel(policy)
53 |
54 | # The loaded policy's step function expects batched data. Wrap it so that it
55 | # expects unbatched data.
56 | policy_step_batch_fn = policy.step
57 |
58 | def _expand_batch_dim(x):
59 | return np.expand_dims(x, axis=0)
60 |
61 | def _squeeze_batch_dim(x):
62 | return np.squeeze(x, axis=0)
63 |
64 | def policy_step_fn(timestep: dm_env.TimeStep, state: _MPOState):
65 | timestep_batch = dm_env.TimeStep(
66 | None, None, None,
67 | tree.map_structure(_expand_batch_dim, timestep.observation))
68 | state_batch = tree.map_structure(_expand_batch_dim, state)
69 | output_batch = policy_step_batch_fn(timestep_batch, state_batch)
70 | output = tree.map_structure(_squeeze_batch_dim, output_batch)
71 | return output
72 |
73 | policy.step = policy_step_fn
74 | return policy
75 |
76 |
77 | class StatefulPolicyCallable:
78 | """Object-oriented policy for directly using in dm_control viewer."""
79 |
80 | def __init__(self, policy: Policy):
81 | self._policy = policy
82 | self._state = self._policy.initial_state()
83 |
84 | def __call__(self, timestep: dm_env.TimeStep):
85 | if timestep.step_type == dm_env.StepType.FIRST:
86 | self._state = self._policy.initial_state()
87 | (action, _), self._state = self._policy.step(timestep, self._state)
88 | return action
89 |
--------------------------------------------------------------------------------
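A minimal sketch of loading a released policy and driving it interactively; the SavedModel path is a placeholder and `make_environment()` stands in for whatever function in environment.py builds the matching dm_env environment (both are assumptions for illustration):

# Hypothetical usage of the policy-loading utilities above.
from dm_control import viewer
from rgb_stacking.utils import policy_loading

policy = policy_loading.policy_from_path('/path/to/saved_model')  # placeholder
policy_callable = policy_loading.StatefulPolicyCallable(policy)

# env = make_environment()                    # build the matching environment
# viewer.launch(env, policy=policy_callable)  # step the policy interactively
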
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2021 DeepMind Technologies Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 | # Fail on any error.
18 | set -e
19 |
20 | # Display commands being run.
21 | set -x
22 |
23 | TMP_DIR=$(mktemp -d)
24 |
25 | python3 -m venv "${TMP_DIR}/rgb_stacking"
26 | source "${TMP_DIR}/rgb_stacking/bin/activate"
27 |
28 | # Install dependencies.
29 | pip install --upgrade -r requirements.txt
30 |
31 | # Run the visualization of the environment.
32 | python -m rgb_stacking.main
33 |
34 | # Clean up.
35 | rm -r "${TMP_DIR}"
36 |
--------------------------------------------------------------------------------