├── .gitignore ├── LICENSE.md ├── README.md ├── core ├── engine │ ├── cloth_simulator_gripper.py │ ├── mpm_simulator_md.py │ └── render.py └── envs │ └── hang_cloth_env.py ├── images ├── ant_agent.gif ├── ant_demo.gif ├── cloth.gif ├── cloth_expert.gif └── results.png └── policy ├── brax_task ├── expert │ ├── acrobot_traj_action.npy │ ├── acrobot_traj_obs.npy │ ├── acrobot_traj_reward.npy │ ├── acrobot_traj_state.npy │ ├── ant_traj_action.npy │ ├── ant_traj_obs.npy │ ├── ant_traj_reward.npy │ ├── ant_traj_state.npy │ ├── hopper_traj_action.npy │ ├── hopper_traj_obs.npy │ ├── hopper_traj_reward.npy │ ├── hopper_traj_state.npy │ ├── humanoid_params.pickle │ ├── humanoid_traj_action.npy │ ├── humanoid_traj_obs.npy │ ├── humanoid_traj_reward.npy │ ├── humanoid_traj_state.npy │ ├── inverted_pendulum_traj_action.npy │ ├── inverted_pendulum_traj_obs.npy │ ├── inverted_pendulum_traj_reward.npy │ ├── inverted_pendulum_traj_state.npy │ ├── reacher_traj_action.npy │ ├── reacher_traj_obs.npy │ ├── reacher_traj_reward.npy │ ├── reacher_traj_state.npy │ ├── swimmer_traj_action.npy │ ├── swimmer_traj_obs.npy │ ├── swimmer_traj_reward.npy │ ├── swimmer_traj_state.npy │ ├── walker2d_traj_action.npy │ ├── walker2d_traj_obs.npy │ ├── walker2d_traj_reward.npy │ └── walker2d_traj_state.npy ├── expert_multi_traj │ ├── acrobot_params.pickle │ ├── acrobot_traj_action.npy │ ├── acrobot_traj_observation.npy │ ├── acrobot_traj_reward.npy │ ├── acrobot_traj_state.npy │ ├── ant_params.pickle │ ├── ant_traj_action.npy │ ├── ant_traj_observation.npy │ ├── ant_traj_reward.npy │ ├── ant_traj_state.npy │ ├── hopper_params.pickle │ ├── hopper_traj_action.npy │ ├── hopper_traj_observation.npy │ ├── hopper_traj_reward.npy │ ├── hopper_traj_state.npy │ ├── humanoid_params.pickle │ ├── humanoid_traj_action.npy │ ├── humanoid_traj_observation.npy │ ├── humanoid_traj_reward.npy │ ├── humanoid_traj_state.npy │ ├── inverted_pendulum_params.pickle │ ├── inverted_pendulum_traj_action.npy │ ├── inverted_pendulum_traj_observation.npy │ ├── inverted_pendulum_traj_reward.npy │ ├── inverted_pendulum_traj_state.npy │ ├── reacher_params.pickle │ ├── reacher_traj_action.npy │ ├── reacher_traj_observation.npy │ ├── reacher_traj_reward.npy │ ├── reacher_traj_state.npy │ ├── swimmer_params.pickle │ ├── swimmer_traj_action.npy │ ├── swimmer_traj_observation.npy │ ├── swimmer_traj_reward.npy │ ├── swimmer_traj_state.npy │ ├── walker2d_params.pickle │ ├── walker2d_traj_action.npy │ ├── walker2d_traj_done.npy │ ├── walker2d_traj_observation.npy │ ├── walker2d_traj_reward.npy │ └── walker2d_traj_state.npy ├── expert_used_in_paper │ ├── acrobot.pickle │ ├── acrobot_params.pickle │ ├── acrobot_traj_action.npy │ ├── acrobot_traj_obs.npy │ ├── acrobot_traj_state.npy │ ├── ant.pickle │ ├── ant_params.pickle │ ├── ant_traj_action.npy │ ├── ant_traj_obs.npy │ ├── ant_traj_state.npy │ ├── halfcheetah.pickle │ ├── halfcheetah_params.pickle │ ├── halfcheetah_traj_action.npy │ ├── halfcheetah_traj_obs.npy │ ├── halfcheetah_traj_state.npy │ ├── hopper.pickle │ ├── hopper_params.pickle │ ├── hopper_traj_action.npy │ ├── hopper_traj_obs.npy │ ├── hopper_traj_state.npy │ ├── humanoid.pickle │ ├── humanoid_params.pickle │ ├── humanoid_traj_action.npy │ ├── humanoid_traj_obs.npy │ ├── humanoid_traj_state.npy │ ├── inverted_pendulum.pickle │ ├── inverted_pendulum_params.pickle │ ├── inverted_pendulum_traj_action.npy │ ├── inverted_pendulum_traj_obs.npy │ ├── inverted_pendulum_traj_state.npy │ ├── reacher.pickle │ ├── reacher_params.pickle │ ├── 
reacher_traj_action.npy │ ├── reacher_traj_obs.npy │ ├── reacher_traj_state.npy │ ├── swimmer.pickle │ ├── swimmer_params.pickle │ ├── swimmer_traj_action.npy │ ├── swimmer_traj_obs.npy │ ├── swimmer_traj_state.npy │ ├── walker2d.pickle │ ├── walker2d_params.pickle │ ├── walker2d_traj_action.npy │ ├── walker2d_traj_obs.npy │ └── walker2d_traj_state.npy ├── train_multi_traj.py └── train_on_policy.py ├── cloth_task ├── expert │ ├── hang_cloth.pickle │ ├── hang_cloth_traj_action.npy │ ├── hang_cloth_traj_obs.npy │ └── hang_cloth_traj_state.npy └── train_on_policy.py └── util ├── ILD_rollout.py └── expert_rollout.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | core/engine/__pycache__/ 3 | 4 | *.pyc 5 | 6 | policy/.DS_Store 7 | 8 | core/.DS_Store 9 | 10 | .idea/ 11 | 12 | .trunk/ 13 | 14 | .isort.cfg 15 | 16 | .markdownlint.yaml 17 | 18 | .flake8 19 | 20 | policy/brax_task/logs_bp/ 21 | policy/brax_task/logs 22 | policy/brax_task/start_train*.sh 23 | *.v2 24 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Imitation Learning via Differentiable Physics 2 | 3 | Existing imitation learning (IL) methods such as inverse reinforcement learning (IRL) usually have a double-loop training process, alternating between learning a reward function and a policy and tend to suffer long training time and high variance. 
In this work, we identify the benefits of differentiable physics simulators and propose a new IL method, i.e., Imitation Learning via Differentiable Physics (ILD), which gets rid of the double-loop design and achieves significant improvements in final performance, convergence speed, and stability. 4 | 5 | [[paper](https://arxiv.org/abs/2206.04873)] [[code](https://github.com/sail-sg/ILD)] 6 | 7 | #### Brax MuJoCo Tasks 8 | 9 | Our ILD agent learns using a single expert demonstration with much less variance and higher performance. 10 | 11 | ![results](images/results.png) 12 | 13 | 14 | 15 | Expert Demo | Learned Policy 16 | :-: | :-: 17 | ![](images/ant_demo.gif) | ![](images/ant_agent.gif) 18 | 19 | 20 | 21 | 22 | 23 | #### Cloth Manipulation Task 24 | 25 | We collect a single expert demonstration in a noise-free environment. Despite the presence of severe control noise in the test environment, our method completes the task and recovers the expert behavior. 26 | 27 | Expert Demo (Noise-free) | Learned Policy (Heavy Noise in Control) 28 | :-: | :-: 29 | ![](images/cloth_expert.gif) | ![](images/cloth.gif) 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | #### Installation 40 | 41 | ~~~ 42 | conda create -n ILD python==3.8 43 | conda activate ILD 44 | 45 | pip install --upgrade pip 46 | pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 47 | pip install brax 48 | pip install streamlit 49 | pip install tensorflow 50 | pip install open3d 51 | ~~~ 52 | 53 | #### Start training 54 | ~~~ 55 | # train with a single demonstration 56 | cd policy/brax_task 57 | CUDA_VISIBLE_DEVICES=0 python train_on_policy.py --env="ant" --seed=1 58 | 59 | # train with multiple demonstrations 60 | cd policy/brax_task 61 | CUDA_VISIBLE_DEVICES=0 python train_multi_traj.py --env="ant" --seed=1 62 | 63 | # train with cloth manipulation task 64 | cd policy/cloth_task 65 | CUDA_VISIBLE_DEVICES=0 python train_on_policy.py --seed=1 66 | ~~~ 67 | -------------------------------------------------------------------------------- /core/engine/cloth_simulator_gripper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from functools import partial 16 | 17 | import jax 18 | import jax.numpy as jnp 19 | import numpy as np 20 | from jax import custom_vjp, random 21 | from jax._src.lax.control_flow import fori_loop, scan 22 | 23 | N = 190 24 | cell_size = 1.0 / N 25 | gravity = 0.5 26 | stiffness = 1600 27 | damping = 2 28 | dt = 2e-3 29 | max_v = 2. 
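# --- Descriptive note (added comment, not in the original file) --------------
# The constants above configure an explicit mass-spring cloth model: each of
# the N x N grid particles is connected to its 8 neighbours (see `links`
# below) by springs of stiffness `stiffness`, velocities decay with `damping`,
# gravity `gravity` acts along -y, the state is advanced with a semi-implicit
# Euler step of size `dt`, and `max_v` caps particle speed for numerical
# stability (see `step()` below).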
30 | 31 | ball_radius = 0.2 32 | num_triangles = (N - 1) * (N - 1) * 2 33 | 34 | links = [[-1, 0], [1, 0], [0, -1], [0, 1], [-1, -1], [1, -1], [-1, 1], [1, 1]] 35 | links = jnp.array(links) 36 | 37 | a, b = jnp.indices((N, N)) 38 | grid_idx = jnp.concatenate((a[..., None], b[..., None]), axis=2) 39 | indices, vertices, cloth_mask = None, None, None 40 | key_global = jax.random.PRNGKey(1) 41 | 42 | 43 | @partial(custom_vjp) 44 | def norm_grad(x): 45 | return x 46 | 47 | 48 | def norm_grad_fwd(x): 49 | return x, () 50 | 51 | 52 | def norm_grad_bwd(x, g): 53 | g /= jnp.linalg.norm(g) 54 | g = jnp.nan_to_num(g) 55 | g /= cloth_mask.sum() 56 | 57 | return g, 58 | 59 | 60 | norm_grad.defvjp(norm_grad_fwd, norm_grad_bwd) 61 | 62 | 63 | def default_collision_func(x, v, idx_i, idx_j): 64 | return x 65 | 66 | 67 | def primitive_collision_func(x, v, action, ps): 68 | # collision with primitive ball 69 | pos, radius = ps[:3], ps[3] 70 | d_v = action[:3].reshape(1, 3) 71 | suction = action[-1] 72 | 73 | # find points on the surface 74 | x_ = x - jnp.array(pos).reshape(1, 3) 75 | dist = jnp.linalg.norm(x_, axis=-1) 76 | mask = dist <= radius 77 | mask = mask[..., None].repeat(3, -1) 78 | v_ = jnp.where(mask, 0, v) 79 | x_ = jnp.where(mask, x + d_v * (1 - suction), x) 80 | 81 | # weight = jnp.exp(-1 * (dist*20 - 1))[..., None] 82 | # v = v - weight * suction * v 83 | # x = x + d_v * weight 84 | 85 | v_mask = jnp.abs(v).max() > max_v 86 | v = jnp.where(v_mask, v, v_) 87 | x = jnp.where(v_mask, x, x_) 88 | 89 | x = norm_grad(x) 90 | v = norm_grad(v) 91 | 92 | return x, v 93 | 94 | 95 | def create_vars(N_, collision_func_, cloth_mask_, key_): 96 | global N, num_triangles, grid_idx, cell_size, collision_func, \ 97 | indices, vertices, cloth_mask, idx_i, idx_j, x_grid 98 | 99 | # set global vars 100 | N = N_ 101 | num_triangles = (N - 1) * (N - 1) * 2 102 | cell_size = 1.0 / N 103 | collision_func = collision_func_ 104 | 105 | # cloth mask 106 | indices = jnp.zeros((num_triangles * 3,)) 107 | vertices = jnp.zeros((N * N, 3)) 108 | cloth_mask = cloth_mask_ 109 | idx_i, idx_j = jnp.nonzero(cloth_mask) 110 | grid_idx = jnp.concatenate([idx_i[:, None], idx_j[:, None]], axis=-1) 111 | 112 | # create x, v 113 | x = np.zeros((N, N, 3)) 114 | for i, j in np.ndindex((N, N)): 115 | x[i, j] = np.array([ 116 | i * cell_size, j * cell_size / np.sqrt(2), 117 | (N - j) * cell_size / np.sqrt(2) + 0.1 118 | ]) 119 | x_grid = jnp.array(x) 120 | v = jnp.zeros((N, N, 3)) 121 | ps0 = jnp.array([0., 0.1, 0.58, 0.01]) 122 | ps1 = jnp.array([0., 0.1, 0.58, 0.01]) 123 | 124 | set_indices() 125 | 126 | # mask x and v 127 | x = x_grid[idx_i, idx_j] 128 | v = v[idx_i, idx_j] 129 | 130 | return x, v, ps0, ps1, key_ 131 | 132 | 133 | def set_indices(): 134 | global indices, cloth_mask 135 | indices = np.array(indices) 136 | cloth_mask = np.array(cloth_mask) 137 | for i, j in np.ndindex((N, N)): 138 | 139 | if i < N - 1 and j < N - 1: 140 | flag = 1 141 | flag *= cloth_mask[i - 1, j - 1] * cloth_mask[i - 1, j] * cloth_mask[i - 1, j + 1] 142 | flag *= cloth_mask[i, j - 1] * cloth_mask[i, j] * cloth_mask[i, j + 1] 143 | flag *= cloth_mask[i + 1, j - 1] * cloth_mask[i + 1, j] * cloth_mask[i + 1, j + 1] 144 | 145 | square_id = (i * (N - 1)) + j 146 | # 1st triangle of the square 147 | indices[square_id * 6 + 0] = i * N + j 148 | indices[square_id * 6 + 1] = (i + 1) * N + j 149 | indices[square_id * 6 + 2] = i * N + (j + 1) 150 | # 2nd triangle of the square 151 | indices[square_id * 6 + 3] = (i + 1) * N + j + 1 152 | indices[square_id * 6 + 4] 
= i * N + (j + 1) 153 | indices[square_id * 6 + 5] = (i + 1) * N + j 154 | 155 | indices[square_id * 6 + 0] *= flag 156 | indices[square_id * 6 + 1] *= flag 157 | indices[square_id * 6 + 2] *= flag 158 | # 2nd triangle of the square 159 | indices[square_id * 6 + 3] *= flag 160 | indices[square_id * 6 + 4] *= flag 161 | indices[square_id * 6 + 5] *= flag 162 | 163 | cloth_mask = jnp.array(cloth_mask) 164 | indices = jnp.array(indices) 165 | indices = indices.reshape((-1, 3)) 166 | indices = indices[indices.sum(1) != 0] 167 | 168 | 169 | def robot_step(action, x, v, ps0, ps1, key): 170 | def step_(i, state): 171 | action, _, _, _, _, _ = state 172 | state_ = step(*state) 173 | return (action,) + state_ 174 | 175 | # normalize speed, 50 sub steps, 20 is a scale factor 176 | action = action.at[:3].set(action[:3].clip(-1, 1) / 50. / 20.) 177 | action = action.at[4:7].set(action[4:7].clip(-1, 1) / 50. / 20.) 178 | 179 | # add uncertainty 180 | key, _ = random.split(key) 181 | action += random.uniform(key, action.shape) * 0.0001 - 0.00005 # randomness 182 | action += random.uniform(key_global, action.shape) * 0.0004 - 0.0002 # fixed bias, as key_global won't change 183 | 184 | state = (action, x, v, ps0, ps1, key) 185 | state = fori_loop(0, 50, step_, state) 186 | 187 | return state[1:] 188 | 189 | 190 | def step(action, x, v, ps0, ps1, key): 191 | v -= jnp.array([0, gravity * dt, 0]) 192 | action = action.clip(-1, 1) 193 | 194 | action = action.at[3].set(0) 195 | action = action.at[7].set(0) 196 | 197 | # mask out invalid area 198 | # v *= cloth_mask.reshape((N, N, 1)) 199 | j_ = grid_idx.reshape((-1, 1, 2)).repeat(len(links), -2) 200 | j_ = j_ + links[None, ...] 201 | j_ = jnp.clip(j_, 0, N - 1) 202 | 203 | i_ = grid_idx.reshape((-1, 1, 2)).repeat(len(links), -2) 204 | original_length = cell_size * jnp.linalg.norm(j_ - i_, axis=-1)[..., None] 205 | ori_len_is_not_0 = (original_length != 0).astype(jnp.int32) 206 | original_length = jnp.clip(original_length, 1e-12, jnp.inf) 207 | 208 | j_x, j_y = j_.reshape((-1, 2))[:, 0], j_.reshape((-1, 2))[:, 1] 209 | i_x, i_y = i_.reshape((-1, 2))[:, 0], i_.reshape((-1, 2))[:, 1] 210 | 211 | x_grid = jnp.zeros((N, N, 3)).at[idx_i, idx_j].set(x) 212 | relative_pos = x_grid[j_x, j_y] - x_grid[i_x, i_y] 213 | # current_length = jnp.linalg.norm(relative_pos, axis=-1) 214 | current_length = jnp.clip((relative_pos ** 2).sum(-1), 1e-12, jnp.inf) ** 0.5 215 | current_length = current_length.reshape((-1, len(links), 1)) 216 | 217 | force = stiffness * relative_pos.reshape((-1, 8, 3)) / current_length * ( 218 | current_length - original_length) / original_length 219 | 220 | force *= ori_len_is_not_0 221 | 222 | # mask out force from invalid area 223 | force *= cloth_mask[j_x, j_y].reshape((-1, 8, 1)) 224 | 225 | force = force.sum(1) 226 | v += force * dt 227 | v *= jnp.exp(-damping * dt) 228 | 229 | # collision 230 | v = collision_func(x, v, idx_i, idx_j) 231 | x, v = primitive_collision_func(x, v, action[:4], ps0) 232 | x, v = primitive_collision_func(x, v, action[4:], ps1) 233 | 234 | v_mask = jnp.abs(v).max() > max_v 235 | ps0_ = ps0.at[:3].add(action[:3]).clip(0, 1) 236 | ps1_ = ps1.at[:3].add(action[4:7]).clip(0, 1) 237 | ps0 = jnp.where(v_mask, ps0, ps0_) 238 | ps1 = jnp.where(v_mask, ps1, ps1_) 239 | 240 | # collision with the ground 241 | x = x.clip(0, 1) 242 | v = v.clip(-max_v, max_v) 243 | 244 | x += dt * v 245 | 246 | x = norm_grad(x) 247 | v = norm_grad(v) 248 | ps0 = norm_grad(ps0) 249 | ps1 = norm_grad(ps1) 250 | 251 | return x, v, ps0, ps1, key 252 | 
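# --- Illustrative sketch (added, not part of the original file) --------------
# Because `robot_step` above is written entirely in JAX, a whole rollout can
# be differentiated end-to-end, which is what a differentiable-physics IL
# method relies on.  The helper below is a hypothetical example only:
# `expert_x` (per-step expert cloth positions) and the squared-error loss are
# illustrative assumptions, not the repository's actual training objective.


def _example_action_gradient(actions, state, expert_x):
    """Hypothetical sketch: gradient of a state-matching loss w.r.t. actions."""

    def rollout_loss(actions_):
        x, v, ps0, ps1, key = state
        loss = 0.0
        for t in range(actions_.shape[0]):  # unrolled differentiable rollout
            x, v, ps0, ps1, key = robot_step(actions_[t], x, v, ps0, ps1, key)
            loss += ((x - expert_x[t]) ** 2).mean()  # match expert cloth state
        return loss

    # Reverse-mode AD flows through every `fori_loop` sub-step of `robot_step`
    # because the trip count (50) is static.
    return jax.grad(rollout_loss)(actions)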
253 | 254 | def get_indices(): 255 | return indices 256 | 257 | 258 | def get_x_grid(x): 259 | x_grid_ = x_grid.at[idx_i, idx_j].set(x) 260 | return x_grid_ 261 | 262 | 263 | collision_func = default_collision_func 264 | -------------------------------------------------------------------------------- /core/engine/mpm_simulator_md.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import jax.numpy as jnp 16 | 17 | 18 | def process_state(state, v_size=0, p_size=0): 19 | """ 20 | :param state: tuple of vars 21 | :return: convert into device arrays with additional batch dim 22 | """ 23 | 24 | state_new = () 25 | for i in range(len(state)): 26 | 27 | var_tmp = jnp.array(state[i]) 28 | 29 | if v_size > 0: 30 | var_tmp = var_tmp[None, ...].repeat(v_size, 0) 31 | 32 | if p_size > 0: 33 | var_tmp = var_tmp[None, ...].repeat(p_size, 0) 34 | 35 | state_new += (var_tmp,) 36 | 37 | return state_new 38 | -------------------------------------------------------------------------------- /core/engine/render.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
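# Added note (not in the original file): this module provides Open3D-based
# visualisers -- `visualize_pc` shows a single point cloud with world axes,
# `BasicRenderer` draws particles together with wireframe containers, and
# `MeshRenderer` renders the cloth as a triangle mesh (used by the hang-cloth
# environment).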
14 | 15 | import copy 16 | import pathlib 17 | 18 | import numpy as np 19 | 20 | 21 | my_path = pathlib.Path(__file__).parent.resolve() 22 | 23 | def draw_xyz(x, y, z, l): 24 | import open3d as o3d 25 | points = [ 26 | [x, y, z], 27 | [x + l, y, z], 28 | [x, y + l, z], 29 | [x, y, z + l], ] 30 | lines = [ 31 | [0, 1], 32 | [0, 2], 33 | [0, 3], 34 | ] 35 | colors = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] 36 | line_set = o3d.geometry.LineSet( 37 | points=o3d.utility.Vector3dVector(points), 38 | lines=o3d.utility.Vector2iVector(lines), 39 | ) 40 | line_set.colors = o3d.utility.Vector3dVector(colors) 41 | return line_set 42 | 43 | 44 | def draw_container(size, pos, rot): 45 | import open3d as o3d 46 | # size = np.array(size) - 0.02 # only draw inner wall 47 | points = [ 48 | [size[0], -size[1], size[2]], 49 | [size[0], -size[1], -size[2]], 50 | [-size[0], -size[1], -size[2]], 51 | [-size[0], -size[1], size[2]], 52 | 53 | [size[0], size[1], size[2]], 54 | [size[0], size[1], -size[2]], 55 | [-size[0], size[1], -size[2]], 56 | [-size[0], size[1], size[2]], 57 | ] 58 | 59 | lines = [ 60 | [0, 1], 61 | [1, 2], 62 | [2, 3], 63 | [3, 0], 64 | 65 | [0, 4], 66 | [1, 5], 67 | [2, 6], 68 | [3, 7], 69 | ] 70 | colors = [[1, 0, 0]] * 8 71 | line_set = o3d.geometry.LineSet( 72 | points=o3d.utility.Vector3dVector(points), 73 | lines=o3d.utility.Vector2iVector(lines), 74 | ) 75 | line_set.colors = o3d.utility.Vector3dVector(colors) 76 | 77 | rot_mat = line_set.get_rotation_matrix_from_quaternion(rot) 78 | line_set = line_set.rotate(rot_mat) 79 | line_set = line_set.translate(pos) 80 | 81 | return line_set 82 | 83 | 84 | def visualize_pc(xyz): 85 | import open3d as o3d 86 | xyz = xyz[xyz[:, 2] > -5] 87 | xyz = xyz[xyz[:, 1] > -5] 88 | xyz = xyz[xyz[:, 0] > -5] 89 | xyz = xyz[xyz[:, 2] < 5] 90 | xyz = xyz[xyz[:, 1] < 5] 91 | xyz = xyz[xyz[:, 0] < 5] 92 | 93 | pcd_o3d = o3d.geometry.PointCloud() 94 | pcd_o3d.points = o3d.utility.Vector3dVector(xyz) 95 | origin_xyz = draw_xyz(0, 0, 0, 1) 96 | o3d.visualization.draw_geometries([pcd_o3d, origin_xyz]) 97 | 98 | 99 | class BasicRenderer: 100 | 101 | def __init__(self, box_sizes, colors): 102 | import open3d as o3d 103 | self.box_sizes = np.array(box_sizes) 104 | self.n_primitives = len(box_sizes) 105 | self.o3d = o3d 106 | self.vis = o3d.visualization.Visualizer() 107 | self.vis.create_window() 108 | 109 | self.pcd_o3d = o3d.geometry.PointCloud() 110 | self.pcd_o3d.colors = o3d.utility.Vector3dVector(colors) 111 | self.vis.add_geometry(self.pcd_o3d) 112 | 113 | origin_xyz = draw_xyz(0, 0, 0, 1) 114 | self.vis.add_geometry(origin_xyz) 115 | 116 | self.containers = [] 117 | self.containers_ = [] 118 | 119 | for i in range(self.n_primitives): 120 | box_size = box_sizes[i] 121 | # mesh_box = o3d.geometry.TriangleMesh.create_box(width=box_size[0] * 2, height=box_size[1] * 2, 122 | # depth=box_size[2] * 2) 123 | # mesh_box.compute_vertex_normals() 124 | # mesh_box.paint_uniform_color([0.1, 0.1, 0.1]) 125 | container = draw_container(box_size, [0, 0, 0], [1, 0, 0, 0]) 126 | self.containers.append(container) 127 | 128 | container_ = copy.deepcopy(container) 129 | self.containers_.append(container_) 130 | self.vis.add_geometry(container_) 131 | 132 | def render(self, i, state): 133 | for i in range(self.n_primitives): 134 | obj_pos = np.array(state[12][i].reshape((-1, 3))[0]) 135 | obj_rot = np.array(state[13][i].reshape((-1, 4))[0]) 136 | obj_rot = self.containers[i].get_rotation_matrix_from_quaternion(obj_rot) 137 | # effector_pos -= self.box_sizes[i] 138 | 
self.vis.remove_geometry(self.containers_[i]) 139 | self.containers_[i] = copy.deepcopy(self.containers[i]).rotate(obj_rot).translate(obj_pos) 140 | self.vis.add_geometry(self.containers_[i]) 141 | 142 | # obj_rot = self.containers[i].get_rotation_matrix_from_quaternion(obj_rot) 143 | # self.containers[i].rotate(default_rot).translate([0, 0, 0]) 144 | # self.containers[i].rotate(obj_rot).translate(obj_pos) 145 | # self.vis.update_geometry(self.containers[i]) 146 | 147 | obj_pos = np.array(state[12].reshape((-1, 3))[0]) 148 | particle_pos = np.array(state[0].reshape((-1, 3))[:state[0].shape[-2]]) 149 | print(i, "position", obj_pos, particle_pos[0]) 150 | self.pcd_o3d.points = self.o3d.utility.Vector3dVector(particle_pos) 151 | self.vis.update_geometry(self.pcd_o3d) 152 | self.vis.poll_events() 153 | self.vis.update_renderer() 154 | self.vis.poll_events() 155 | 156 | 157 | class MeshRenderer: 158 | 159 | def __init__(self): 160 | import open3d as o3d 161 | self.o3d = o3d 162 | self.vis = o3d.visualization.Visualizer() 163 | self.vis.create_window() 164 | render_op = self.vis.get_render_option() 165 | render_op.mesh_show_wireframe = True 166 | render_op.mesh_show_back_face = True 167 | 168 | self.mesh = o3d.geometry.TriangleMesh() 169 | self.vis.add_geometry(self.mesh) 170 | 171 | # origin_xyz = draw_xyz(0, 0, 0, 1) 172 | # self.vis.add_geometry(origin_xyz) 173 | 174 | 175 | 176 | def render(self, i, vertices, indices): 177 | import open3d as o3d 178 | np_vertices = np.array(vertices) 179 | np_triangles = np.array(indices).astype(np.int32).reshape((-1, 3)) 180 | self.mesh.vertices = o3d.utility.Vector3dVector(np_vertices) 181 | self.mesh.triangles = o3d.utility.Vector3iVector(np_triangles) 182 | 183 | self.mesh.compute_vertex_normals() 184 | self.vis.update_geometry(self.mesh) 185 | self.vis.poll_events() 186 | self.vis.update_renderer() 187 | -------------------------------------------------------------------------------- /core/envs/hang_cloth_env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
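# Added note (not in the original file): a gym-style environment built on the
# differentiable gripper cloth simulator.  An 8-dim action drives two
# spherical grippers, and a sparse reward of 1 is given at the end of the
# episode if the cloth ends up draped over a fixed horizontal pole
# (`pole_pos`, `pole_radius` below).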
14 | 15 | from copy import deepcopy 16 | 17 | import jax 18 | import jax.numpy as jnp 19 | import numpy as np 20 | from gym.spaces import Box 21 | from jax import jit, vmap, random 22 | 23 | from core.engine.cloth_simulator_gripper import robot_step, create_vars, get_x_grid, get_indices 24 | from core.engine.mpm_simulator_md import process_state 25 | from core.engine.render import MeshRenderer 26 | 27 | N = 64 28 | size = int(N / 5) 29 | pole_pos = np.array([0., 0.3, 0.15]) 30 | pole_radius = 0.01 31 | 32 | 33 | class HangCloth: 34 | 35 | def __init__(self, robot_step_grad_fun, max_steps, init_state, batch_size, visualize=False): 36 | self.state = None 37 | self.seed_num = 0 38 | self.key = random.PRNGKey(self.seed_num) 39 | self.init_state = init_state 40 | self.batch_size = batch_size 41 | self.step_jax = robot_step_grad_fun 42 | self.max_steps = max_steps 43 | self.cur_step = 0 44 | self.action_size = 8 45 | self.observation_size = 288 * 6 + 8 46 | self.cloth_state_shape = (288, 6) 47 | self.observation_space = Box(low=-1.0, high=1.0, shape=(288 * 6 + 8,), dtype=np.float32) 48 | self.action_space = Box(low=-1.0, high=1.0, shape=(8,), dtype=np.float32) 49 | self.visualize = visualize 50 | self.spec = None 51 | 52 | cloth_mask = create_cloth_mask() 53 | self.idx_i, self.idx_j = jnp.nonzero(cloth_mask) 54 | 55 | if visualize: 56 | import open3d as o3d 57 | renderer = MeshRenderer() 58 | pole = o3d.geometry.TriangleMesh.create_cylinder(radius=pole_radius, height=0.35) 59 | R = pole.get_rotation_matrix_from_xyz((0, np.pi / 2, 0)) 60 | pole.rotate(R, center=(0, 0, 0)) 61 | pole.translate([0.5, pole_pos[1], pole_pos[2]]) 62 | pole.compute_vertex_normals() 63 | pole.paint_uniform_color([0.6, 0.6, 0.6]) 64 | renderer.vis.add_geometry(pole) 65 | 66 | # add grippers 67 | ps0_np = np.array(init_state[2]) 68 | gripper0 = o3d.geometry.TriangleMesh.create_sphere(radius=ps0_np[3]) 69 | gripper0.translate(ps0_np[:3], relative=False) 70 | renderer.vis.add_geometry(gripper0) 71 | 72 | ps1_np = np.array(init_state[3]) 73 | gripper1 = o3d.geometry.TriangleMesh.create_sphere(radius=ps1_np[3]) 74 | gripper1.translate(ps1_np[:3], relative=False) 75 | renderer.vis.add_geometry(gripper1) 76 | 77 | self.renderer = renderer 78 | self.gripper0 = gripper0 79 | self.gripper1 = gripper1 80 | 81 | def seed(self, seed): 82 | self.seed_num = seed 83 | self.key = random.PRNGKey(self.seed_num) 84 | self.init_state = self.init_state[:-1] + (self.key,) 85 | self.reset() 86 | 87 | def step(self, action): 88 | self.state = self.step_jax(action, *self.state) 89 | 90 | obs = jnp.concatenate([self.state[0], self.state[1]], axis=-1) 91 | obs = jnp.concatenate([obs.flatten(), self.state[2], self.state[3]], axis=-1) 92 | # TODO change observations 93 | reward, done, info = 0, self.cur_step > self.max_steps, {} 94 | 95 | if self.cur_step >= self.max_steps - 1: 96 | done = True 97 | x = self.state[0] 98 | if x[:, 1].max() >= pole_pos[1] and x[:, 2].min() <= pole_pos[2] and x[:, 2].max() >= pole_pos[2]: 99 | reward = 1. 
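# Added note: the sparse success check above fires only at the final step --
# some cloth particle must rise above the pole height (y) while the cloth's
# z-extent straddles the pole, i.e. the cloth hangs over it.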
100 | 101 | self.cur_step += 1 102 | 103 | return np.array(obs), reward, done, info 104 | 105 | def reset(self): 106 | cloth_mask = create_cloth_mask() 107 | self.key, _ = jax.random.split(self.key) 108 | state = create_vars(N, collision_func, cloth_mask, self.key) 109 | 110 | actions = jnp.zeros((60, 8)) 111 | for action in actions: 112 | state = self.step_jax(action, *state) 113 | 114 | x, v, ps0, ps1, key = state 115 | ps0 = [4.0625378e-01, -4.9900409e-04, 5.1644766e-01, 0.01] 116 | ps1 = [5.312532e-01, -4.990041e-04, 5.164433e-01, 0.01] 117 | ps0 = jnp.array(ps0) 118 | ps1 = jnp.array(ps1) 119 | self.state = (x, v, ps0, ps1, key) 120 | 121 | obs = jnp.concatenate([self.state[0], self.state[1]], axis=-1) 122 | obs = jnp.concatenate([obs.flatten(), self.state[2], self.state[3]], axis=-1) 123 | 124 | return np.array(obs) 125 | 126 | @staticmethod 127 | def reset_jax(key_envs, step_jax, batch_size): 128 | cloth_mask = create_cloth_mask() 129 | key_envs, _ = jax.random.split(key_envs) 130 | state = create_vars(N, collision_func, cloth_mask, key_envs) 131 | 132 | state = process_state(state, v_size=batch_size, p_size=0) 133 | actions = jnp.zeros((60, batch_size, 8)) if batch_size > 0 else jnp.zeros((60, 8)) 134 | for action in actions: 135 | state = step_jax(action, *state) 136 | 137 | x, v, ps0, ps1, key = state 138 | ps0 = [4.0625378e-01, -4.9900409e-04, 5.1644766e-01, 0.01] 139 | ps1 = [5.312532e-01, -4.990041e-04, 5.164433e-01, 0.01] 140 | ps0 = jnp.array(ps0) if batch_size == 0 else jnp.array([ps0] * batch_size) 141 | ps1 = jnp.array(ps1) if batch_size == 0 else jnp.array([ps1] * batch_size) 142 | state = (x, v, ps0, ps1, key) 143 | 144 | return state 145 | 146 | def render(self): 147 | x, v, ps0, ps1, _ = self.state 148 | indices = get_indices() 149 | x_grid_ = get_x_grid(x) 150 | x_grid_ = x_grid_.reshape((-1, 3)) 151 | 152 | self.gripper0.translate(np.array(ps0)[:3], relative=False) 153 | self.renderer.vis.update_geometry(self.gripper0) 154 | self.gripper1.translate(np.array(ps1)[:3], relative=False) 155 | self.renderer.vis.update_geometry(self.gripper1) 156 | 157 | self.renderer.render(0, vertices=x_grid_, indices=indices) 158 | 159 | 160 | def collision_func(x, v, idx_i, idx_j): 161 | # collision with pole 162 | x = jnp.zeros((N, N, 3)).at[idx_i, idx_j].set(x) 163 | v = jnp.zeros((N, N, 3)).at[idx_i, idx_j].set(v) 164 | 165 | # find points on the surface 166 | x -= jnp.array(pole_pos).reshape(1, 1, 3) 167 | mask = jnp.linalg.norm(x[..., 1:], axis=-1) <= pole_radius 168 | mask = mask[..., None].repeat(3, -1) 169 | 170 | surface_norm = x.at[..., 0].set(0) * -1 # point into the pole 171 | 172 | # calc surface norm at each point 173 | norm_ = jnp.sqrt((surface_norm ** 2).sum(-1))[..., None] 174 | dot_prod = jnp.einsum('ijk,ijk->ij', v, surface_norm).clip(0, jnp.inf)[..., None] 175 | proj_of_v_on_surface = (dot_prod / norm_ ** 2) * surface_norm 176 | v_ = v - proj_of_v_on_surface # prevent from going into the pole 177 | 178 | v_ *= 0.95 # simulate friction 179 | v = jnp.where(mask, v_, v) 180 | 181 | v = v[idx_i, idx_j] 182 | return v 183 | 184 | 185 | def create_cloth_mask(): 186 | cloth_mask = jnp.zeros((N, N)) 187 | cloth_mask = cloth_mask.at[size * 2:size * 3, size * 2:size * 4].set(1) 188 | 189 | return cloth_mask 190 | 191 | 192 | def make_env(batch_size=0, episode_length=80, visualize=False, seed=0): 193 | cloth_mask = create_cloth_mask() 194 | key = random.PRNGKey(seed) 195 | state = create_vars(N, collision_func, cloth_mask, key) 196 | 197 | actions = jnp.zeros((100, 8)) 198 | 
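# Added note: the block below optionally vmaps `robot_step` over the batch
# dimension, jit-compiles it, and runs one warm-up call so XLA compilation
# happens inside `make_env` rather than on the first environment step.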
199 | # compile sim according to conf 200 | state = process_state(state, v_size=batch_size, p_size=0) 201 | robot_step_grad_fun = robot_step 202 | if batch_size > 0: 203 | actions = jnp.array(actions)[:, None, ...].repeat(batch_size, 1) 204 | robot_step_grad_fun = vmap(robot_step_grad_fun) 205 | print("compiling simulation") 206 | robot_step_grad_fun = jit(robot_step_grad_fun) 207 | robot_step_grad_fun(actions[0], *state) # to warm up 208 | 209 | env = HangCloth(robot_step_grad_fun, max_steps=episode_length, 210 | init_state=state, batch_size=batch_size, visualize=visualize) 211 | 212 | return env 213 | 214 | 215 | if __name__ == "__main__": 216 | print(jax.devices()) 217 | env = make_env(batch_size=0, visualize=True) 218 | env.reset() 219 | -------------------------------------------------------------------------------- /images/ant_agent.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/images/ant_agent.gif -------------------------------------------------------------------------------- /images/ant_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/images/ant_demo.gif -------------------------------------------------------------------------------- /images/cloth.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/images/cloth.gif -------------------------------------------------------------------------------- /images/cloth_expert.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/images/cloth_expert.gif -------------------------------------------------------------------------------- /images/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/images/results.png -------------------------------------------------------------------------------- /policy/brax_task/expert/acrobot_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/acrobot_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/acrobot_traj_obs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/acrobot_traj_obs.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/acrobot_traj_reward.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/acrobot_traj_reward.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/acrobot_traj_state.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/acrobot_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/ant_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/ant_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/ant_traj_obs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/ant_traj_obs.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/ant_traj_reward.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/ant_traj_reward.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/ant_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/ant_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/hopper_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/hopper_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/hopper_traj_obs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/hopper_traj_obs.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/hopper_traj_reward.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/hopper_traj_reward.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/hopper_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/hopper_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/humanoid_params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/humanoid_params.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert/humanoid_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/humanoid_traj_action.npy 
-------------------------------------------------------------------------------- /policy/brax_task/expert/humanoid_traj_obs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/humanoid_traj_obs.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/humanoid_traj_reward.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/humanoid_traj_reward.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/humanoid_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/humanoid_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/inverted_pendulum_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/inverted_pendulum_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/inverted_pendulum_traj_obs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/inverted_pendulum_traj_obs.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/inverted_pendulum_traj_reward.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/inverted_pendulum_traj_reward.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/inverted_pendulum_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/inverted_pendulum_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/reacher_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/reacher_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/reacher_traj_obs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/reacher_traj_obs.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/reacher_traj_reward.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/reacher_traj_reward.npy 
-------------------------------------------------------------------------------- /policy/brax_task/expert/reacher_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/reacher_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/swimmer_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/swimmer_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/swimmer_traj_obs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/swimmer_traj_obs.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/swimmer_traj_reward.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/swimmer_traj_reward.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/swimmer_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/swimmer_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/walker2d_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/walker2d_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/walker2d_traj_obs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/walker2d_traj_obs.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/walker2d_traj_reward.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/walker2d_traj_reward.npy -------------------------------------------------------------------------------- /policy/brax_task/expert/walker2d_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert/walker2d_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/acrobot_params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/acrobot_params.pickle 
-------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/acrobot_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/acrobot_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/acrobot_traj_observation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/acrobot_traj_observation.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/acrobot_traj_reward.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/acrobot_traj_reward.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/acrobot_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/acrobot_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/ant_params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/ant_params.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/ant_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/ant_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/ant_traj_observation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/ant_traj_observation.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/ant_traj_reward.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/ant_traj_reward.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/ant_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/ant_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/hopper_params.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/hopper_params.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/hopper_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/hopper_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/hopper_traj_observation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/hopper_traj_observation.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/hopper_traj_reward.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/hopper_traj_reward.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/hopper_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/hopper_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/humanoid_params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/humanoid_params.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/humanoid_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/humanoid_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/humanoid_traj_observation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/humanoid_traj_observation.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/humanoid_traj_reward.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/humanoid_traj_reward.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/humanoid_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/humanoid_traj_state.npy 
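The *_traj_observation.npy and *_traj_action.npy pairs listed here feed the IL bootstrap (il_loss) in the training scripts further below, which regresses policy actions onto expert actions. A minimal sketch of that behavior-cloning objective, assuming the arrays have been downloaded into expert_multi_traj/ as the scripts expect; a linear map stands in for the brax MLP policy used in the actual code:

import jax
import jax.numpy as jnp
import numpy as np

obs = jnp.array(np.load("expert_multi_traj/hopper_traj_observation.npy"))
act = jnp.array(np.load("expert_multi_traj/hopper_traj_action.npy"))

# linear stand-in policy: obs_dim -> act_dim
params = jnp.zeros((obs.shape[-1], act.shape[-1]))

def bc_loss(params, obs, act):
    pred = obs @ params                         # deterministic stand-in for the sampled policy action
    return ((pred - act) ** 2).sum(-1).mean()   # same sum-over-actions, mean-over-steps reduction as il_loss

grad = jax.grad(bc_loss)(params, obs, act)
print(bc_loss(params, obs, act), grad.shape)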
-------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/inverted_pendulum_params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/inverted_pendulum_params.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/inverted_pendulum_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/inverted_pendulum_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/inverted_pendulum_traj_observation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/inverted_pendulum_traj_observation.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/inverted_pendulum_traj_reward.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/inverted_pendulum_traj_reward.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/inverted_pendulum_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/inverted_pendulum_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/reacher_params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/reacher_params.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/reacher_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/reacher_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/reacher_traj_observation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/reacher_traj_observation.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/reacher_traj_reward.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/reacher_traj_reward.npy -------------------------------------------------------------------------------- 
/policy/brax_task/expert_multi_traj/reacher_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/reacher_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/swimmer_params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/swimmer_params.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/swimmer_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/swimmer_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/swimmer_traj_observation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/swimmer_traj_observation.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/swimmer_traj_reward.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/swimmer_traj_reward.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/swimmer_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/swimmer_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/walker2d_params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/walker2d_params.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/walker2d_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/walker2d_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/walker2d_traj_done.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/walker2d_traj_done.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/walker2d_traj_observation.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/walker2d_traj_observation.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/walker2d_traj_reward.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/walker2d_traj_reward.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_multi_traj/walker2d_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_multi_traj/walker2d_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/acrobot.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/acrobot.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/acrobot_params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/acrobot_params.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/acrobot_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/acrobot_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/acrobot_traj_obs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/acrobot_traj_obs.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/acrobot_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/acrobot_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/ant.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/ant.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/ant_params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/ant_params.pickle -------------------------------------------------------------------------------- 
/policy/brax_task/expert_used_in_paper/ant_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/ant_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/ant_traj_obs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/ant_traj_obs.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/ant_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/ant_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/halfcheetah.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/halfcheetah.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/halfcheetah_params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/halfcheetah_params.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/halfcheetah_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/halfcheetah_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/halfcheetah_traj_obs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/halfcheetah_traj_obs.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/halfcheetah_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/halfcheetah_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/hopper.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/hopper.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/hopper_params.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/hopper_params.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/hopper_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/hopper_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/hopper_traj_obs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/hopper_traj_obs.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/hopper_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/hopper_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/humanoid.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/humanoid.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/humanoid_params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/humanoid_params.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/humanoid_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/humanoid_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/humanoid_traj_obs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/humanoid_traj_obs.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/humanoid_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/humanoid_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/inverted_pendulum.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/inverted_pendulum.pickle 
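The *_params.pickle entries are pickled policy parameters. A minimal sketch of inspecting one without assuming its exact layout; the tree-map simply reports leaf shapes:

import pickle
import jax

with open("expert_used_in_paper/hopper_params.pickle", "rb") as f:
    params = pickle.load(f)
# print the pytree structure with array leaves replaced by their shapes
print(jax.tree_map(lambda x: getattr(x, "shape", x), params))

If a file holds the same (normalizer_params, policy_params) tuple that train_multi_traj.py (below) writes to params.pkl, it can be passed directly to the inference_fn returned by make_inference_fn in that script; if the layout differs, only the unpacking changes.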
-------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/inverted_pendulum_params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/inverted_pendulum_params.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/inverted_pendulum_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/inverted_pendulum_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/inverted_pendulum_traj_obs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/inverted_pendulum_traj_obs.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/inverted_pendulum_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/inverted_pendulum_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/reacher.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/reacher.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/reacher_params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/reacher_params.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/reacher_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/reacher_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/reacher_traj_obs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/reacher_traj_obs.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/reacher_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/reacher_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/swimmer.pickle: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/swimmer.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/swimmer_params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/swimmer_params.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/swimmer_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/swimmer_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/swimmer_traj_obs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/swimmer_traj_obs.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/swimmer_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/swimmer_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/walker2d.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/walker2d.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/walker2d_params.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/walker2d_params.pickle -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/walker2d_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/walker2d_traj_action.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/walker2d_traj_obs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/walker2d_traj_obs.npy -------------------------------------------------------------------------------- /policy/brax_task/expert_used_in_paper/walker2d_traj_state.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/brax_task/expert_used_in_paper/walker2d_traj_state.npy -------------------------------------------------------------------------------- /policy/brax_task/train_multi_traj.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import argparse 16 | import pickle 17 | import time 18 | from functools import partial 19 | from typing import Any, Callable, Dict, Optional 20 | 21 | import brax 22 | import flax 23 | import jax 24 | import jax.numpy as jnp 25 | import numpy as np 26 | import optax 27 | import streamlit.components.v1 as components 28 | import tensorflow as tf 29 | from absl import logging 30 | from brax import envs 31 | from brax.io import html 32 | from brax.training import distribution 33 | from brax.training import networks 34 | from brax.training import normalization 35 | from brax.training import pmap 36 | from brax.training.types import PRNGKey 37 | from brax.training.types import Params 38 | from flax import linen 39 | from jax import custom_vjp 40 | 41 | logging.set_verbosity(logging.INFO) 42 | tf.config.experimental.set_visible_devices([], "GPU") 43 | 44 | 45 | @flax.struct.dataclass 46 | class TrainingState: 47 | """Contains training state for the learner.""" 48 | key: PRNGKey 49 | normalizer_params: Params 50 | state_normalizer_params: Params 51 | optimizer_state: optax.OptState 52 | il_optimizer_state: optax.OptState 53 | policy_params: Params 54 | 55 | 56 | def train( 57 | environment_fn: Callable[..., envs.Env], 58 | episode_length: int, 59 | action_repeat: int = 1, 60 | num_envs: int = 1, 61 | num_eval_envs: int = 128, 62 | max_gradient_norm: float = 1e9, 63 | max_devices_per_host: Optional[int] = None, 64 | learning_rate=1e-4, 65 | normalize_observations=False, 66 | seed=0, 67 | log_frequency=10, 68 | progress_fn: Optional[Callable[[int, Dict[str, Any]], None]] = None, 69 | truncation_length: Optional[int] = None, 70 | ): 71 | xt = time.time() 72 | 73 | # prepare expert demos 74 | args.logdir = f"multi_traj_logs/{args.env}/{args.env}_ep_len{args.ep_len}_num_envs{args.num_envs}_lr{args.lr}_trunc_len{args.trunc_len}" \ 75 | f"_max_it{args.max_it}_max_grad_norm{args.max_grad_norm}_ef_{args.entropy_factor}" \ 76 | f"_df_{args.deviation_factor}_acf_{args.action_cf_factor}_l2loss_{args.l2}_il_{args.il}_ILD_{args.ILD}" \ 77 | f"/seed{args.seed}" 78 | demo_traj = np.load(f"expert_multi_traj/{args.env}_traj_state.npy") 79 | demo_traj = jnp.array(demo_traj)[:args.ep_len] 80 | demo_traj_action = np.load(f"expert_multi_traj/{args.env}_traj_action.npy") 81 | demo_traj_action = jnp.array(demo_traj_action)[:args.ep_len] 82 | demo_traj_obs = np.load(f"expert_multi_traj/{args.env}_traj_observation.npy") 83 | demo_traj_obs = jnp.array(demo_traj_obs)[:args.ep_len] 84 | 85 | demo_traj_reward = 
np.load(f"expert_multi_traj/{args.env}_traj_reward.npy") 86 | print("expert reward", demo_traj_reward, "avg", demo_traj_reward.mean()) 87 | 88 | # tensorboard 89 | file_writer = tf.summary.create_file_writer(args.logdir) 90 | file_writer.set_as_default() 91 | 92 | # distributed training setup 93 | process_count = jax.process_count() 94 | process_id = jax.process_index() 95 | local_device_count = jax.local_device_count() 96 | local_devices_to_use = local_device_count 97 | if max_devices_per_host: 98 | local_devices_to_use = min(local_devices_to_use, max_devices_per_host) 99 | logging.info('Device count: %d, process count: %d (id %d), local device count: %d, ' 100 | 'devices to be used count: %d', jax.device_count(), process_count, 101 | process_id, local_device_count, local_devices_to_use) 102 | logging.info('Available devices %s', jax.devices()) 103 | 104 | # seeds 105 | key = jax.random.PRNGKey(seed) 106 | key, key_models, key_env = jax.random.split(key, 3) 107 | key_env = jax.random.split(key_env, process_count)[process_id] 108 | key = jax.random.split(key, process_count)[process_id] 109 | key_debug = jax.random.PRNGKey(seed + 666) 110 | 111 | # envs 112 | core_env = environment_fn( 113 | action_repeat=action_repeat, 114 | batch_size=num_envs // local_devices_to_use // process_count, 115 | episode_length=episode_length) 116 | key_envs = jax.random.split(key_env, local_devices_to_use) 117 | step_fn = jax.jit(core_env.step) 118 | reset_fn = jax.jit(jax.vmap(core_env.reset)) 119 | first_state = reset_fn(key_envs) 120 | 121 | eval_env = environment_fn( 122 | action_repeat=action_repeat, 123 | batch_size=num_eval_envs, 124 | episode_length=episode_length, 125 | eval_metrics=True) 126 | eval_step_fn = jax.jit(eval_env.step) 127 | eval_first_state = jax.jit(eval_env.reset)(key_env) 128 | 129 | # initialize policy 130 | parametric_action_distribution = distribution.NormalTanhDistribution(event_size=core_env.action_size) 131 | policy_model = make_direct_optimization_model(parametric_action_distribution, core_env.observation_size) 132 | 133 | # init optimizer 134 | policy_params = policy_model.init(key_models) 135 | optimizer = optax.adam(learning_rate=learning_rate) 136 | optimizer_state = optimizer.init(policy_params) 137 | il_optimizer_state = optimizer.init(policy_params) 138 | optimizer_state, policy_params, il_optimizer_state = pmap.bcast_local_devices( 139 | (optimizer_state, policy_params, il_optimizer_state), local_devices_to_use) 140 | 141 | # observation normalizer 142 | normalizer_params, obs_normalizer_update_fn, obs_normalizer_apply_fn = ( 143 | normalization.create_observation_normalizer( 144 | core_env.observation_size, 145 | normalize_observations, 146 | num_leading_batch_dims=2, 147 | pmap_to_devices=local_devices_to_use)) 148 | 149 | # state normalizer 150 | state_normalizer_params, state_normalizer_update_fn, state_normalizer_apply_fn = ( 151 | normalization.create_observation_normalizer( 152 | demo_traj.shape[-1], 153 | normalize_observations=True, 154 | num_leading_batch_dims=2, 155 | pmap_to_devices=local_devices_to_use)) 156 | 157 | """ 158 | IL boostrap 159 | """ 160 | 161 | def il_loss(params, normalizer_params, key): 162 | 163 | normalizer_params = obs_normalizer_update_fn(normalizer_params, demo_traj_obs) 164 | normalized_obs = obs_normalizer_apply_fn(normalizer_params, demo_traj_obs) 165 | logits = policy_model.apply(params, normalized_obs) 166 | rollout_actions = parametric_action_distribution.sample(logits, key) 167 | 168 | loss_val = (rollout_actions - 
demo_traj_action) ** 2 169 | loss_val = loss_val.sum(-1).mean() 170 | return loss_val, normalizer_params 171 | 172 | def il_minimize(training_state: TrainingState): 173 | synchro = pmap.is_replicated((training_state.optimizer_state, 174 | training_state.policy_params, 175 | training_state.normalizer_params, 176 | training_state.state_normalizer_params, 177 | training_state.il_optimizer_state), axis_name='i') 178 | key, key_grad = jax.random.split(training_state.key) 179 | 180 | grad, normalizer_params = il_loss_grad(training_state.policy_params, 181 | training_state.normalizer_params, 182 | key_grad) 183 | grad = clip_by_global_norm(grad) 184 | grad = jax.lax.pmean(grad, axis_name='i') 185 | params_update, il_optimizer_state = optimizer.update(grad, training_state.il_optimizer_state) 186 | policy_params = optax.apply_updates(training_state.policy_params, params_update) 187 | 188 | metrics = { 189 | 'grad_norm': optax.global_norm(grad), 190 | 'params_norm': optax.global_norm(policy_params) 191 | } 192 | return TrainingState( 193 | key=key, 194 | optimizer_state=training_state.optimizer_state, 195 | il_optimizer_state=il_optimizer_state, 196 | normalizer_params=normalizer_params, 197 | state_normalizer_params=training_state.state_normalizer_params, 198 | policy_params=policy_params), metrics, synchro 199 | 200 | """ 201 | Evaluation functions 202 | """ 203 | 204 | def do_one_step_eval(carry, unused_target_t): 205 | state, params, normalizer_params, key = carry 206 | key, key_sample = jax.random.split(key) 207 | # TODO: Make this nicer ([0] comes from pmapping). 208 | obs = obs_normalizer_apply_fn( 209 | jax.tree_map(lambda x: x[0], normalizer_params), state.obs) 210 | print(obs.shape) 211 | print(jax.tree_map(lambda x: x.shape, params)) 212 | logits = policy_model.apply(params, obs) 213 | actions = parametric_action_distribution.sample(logits, key_sample) 214 | nstate = eval_step_fn(state, actions) 215 | return (nstate, params, normalizer_params, key), state 216 | 217 | @jax.jit 218 | def run_eval(params, state, normalizer_params, key): 219 | params = jax.tree_map(lambda x: x[0], params) 220 | (state, _, _, key), state_list = jax.lax.scan( 221 | do_one_step_eval, (state, params, normalizer_params, key), (), 222 | length=episode_length // action_repeat) 223 | return state, key, state_list 224 | 225 | def eval_policy(it, key_debug): 226 | if process_id == 0: 227 | eval_state, key_debug, state_list = run_eval(training_state.policy_params, 228 | eval_first_state, 229 | training_state.normalizer_params, 230 | key_debug) 231 | eval_metrics = eval_state.info['eval_metrics'] 232 | eval_metrics.completed_episodes.block_until_ready() 233 | eval_sps = ( 234 | episode_length * eval_first_state.reward.shape[0] / 235 | (time.time() - t)) 236 | avg_episode_length = ( 237 | eval_metrics.completed_episodes_steps / 238 | eval_metrics.completed_episodes) 239 | metrics = dict( 240 | dict({ 241 | f'eval/episode_{name}': value / eval_metrics.completed_episodes 242 | for name, value in eval_metrics.completed_episodes_metrics.items() 243 | }), 244 | **dict({ 245 | 'eval/completed_episodes': eval_metrics.completed_episodes, 246 | 'eval/avg_episode_length': avg_episode_length, 247 | 'speed/sps': sps, 248 | 'speed/eval_sps': eval_sps, 249 | 'speed/training_walltime': training_walltime, 250 | 'speed/timestamp': training_walltime, 251 | 'train/grad_norm': jnp.mean(summary.get('grad_norm', 0)), 252 | 'train/params_norm': jnp.mean(summary.get('params_norm', 0)), 253 | })) 254 | 255 | logging.info(metrics) 256 | if 
progress_fn: 257 | progress_fn(it, metrics) 258 | 259 | if it % 10 == 0: 260 | visualize(state_list) 261 | 262 | tf.summary.scalar('eval_episode_reward', data=np.array(metrics['eval/episode_reward']), 263 | step=it * args.num_envs * args.ep_len) 264 | 265 | """ 266 | Training functions 267 | """ 268 | 269 | @partial(custom_vjp) 270 | def norm_grad(x): 271 | return x 272 | 273 | def norm_grad_fwd(x): 274 | return x, () 275 | 276 | def norm_grad_bwd(x, g): 277 | # g /= jnp.linalg.norm(g) 278 | # g = jnp.nan_to_num(g) 279 | g_norm = optax.global_norm(g) 280 | trigger = g_norm < 1.0 281 | g = jax.tree_multimap( 282 | lambda t: jnp.where(trigger, 283 | jnp.nan_to_num(t), 284 | (jnp.nan_to_num(t) / g_norm) * 1.0), g) 285 | return g, 286 | 287 | norm_grad.defvjp(norm_grad_fwd, norm_grad_bwd) 288 | 289 | def do_one_step(carry, step_index): 290 | state, params, normalizer_params, key = carry 291 | key, key_sample = jax.random.split(key) 292 | normalized_obs = obs_normalizer_apply_fn(normalizer_params, state.obs) 293 | logits = policy_model.apply(params, normalized_obs) 294 | actions = parametric_action_distribution.sample(logits, key_sample) 295 | nstate = step_fn(state, actions) 296 | 297 | actions = norm_grad(actions) 298 | nstate = norm_grad(nstate) 299 | 300 | if truncation_length is not None and truncation_length > 0: 301 | nstate = jax.lax.cond( 302 | jnp.mod(step_index + 1, truncation_length) == 0., 303 | jax.lax.stop_gradient, lambda x: x, nstate) 304 | 305 | return (nstate, params, normalizer_params, key), (nstate.reward, state.obs, state.qp, logits, actions) 306 | 307 | def l2_loss(params, normalizer_params, state_normalizer_params, state, key): 308 | _, (rewards, obs, qp_list, logit_list, action_list) = jax.lax.scan( 309 | do_one_step, (state, params, normalizer_params, key), 310 | (jnp.array(range(episode_length // action_repeat))), 311 | length=episode_length // action_repeat) 312 | 313 | rollout_traj = jnp.concatenate([qp_list.pos.reshape((qp_list.pos.shape[0], qp_list.pos.shape[1], -1)), 314 | qp_list.rot.reshape((qp_list.rot.shape[0], qp_list.rot.shape[1], -1)), 315 | qp_list.vel.reshape((qp_list.vel.shape[0], qp_list.vel.shape[1], -1)), 316 | qp_list.ang.reshape((qp_list.ang.shape[0], qp_list.ang.shape[1], -1))], axis=-1) 317 | 318 | # normalize states 319 | normalizer_params = obs_normalizer_update_fn(normalizer_params, obs) 320 | state_normalizer_params = state_normalizer_update_fn(state_normalizer_params, rollout_traj) 321 | 322 | rollout_traj = state_normalizer_apply_fn(state_normalizer_params, rollout_traj) 323 | demo_traj_ = state_normalizer_apply_fn(state_normalizer_params, demo_traj) 324 | 325 | loss_val = (rollout_traj - demo_traj_) ** 2 326 | loss_val = jnp.sqrt(loss_val.sum(-1)).mean() 327 | 328 | return loss_val, (normalizer_params, state_normalizer_params, obs, 0, 0, 0) 329 | 330 | def loss(params, normalizer_params, state_normalizer_params, state, key): 331 | _, (rewards, obs, qp_list, logit_list, action_list) = jax.lax.scan( 332 | do_one_step, (state, params, normalizer_params, key), 333 | (jnp.array(range(episode_length // action_repeat))), 334 | length=episode_length // action_repeat) 335 | 336 | rollout_traj_raw = jnp.concatenate([qp_list.pos.reshape((qp_list.pos.shape[0], qp_list.pos.shape[1], -1)), 337 | qp_list.rot.reshape((qp_list.rot.shape[0], qp_list.rot.shape[1], -1)), 338 | qp_list.vel.reshape((qp_list.vel.shape[0], qp_list.vel.shape[1], -1)), 339 | qp_list.ang.reshape((qp_list.ang.shape[0], qp_list.ang.shape[1], -1))], 340 | axis=-1) 341 | 342 | # 
normalize states 343 | normalizer_params = obs_normalizer_update_fn(normalizer_params, obs) 344 | state_normalizer_params = state_normalizer_update_fn(state_normalizer_params, rollout_traj_raw) 345 | rollout_traj_raw = state_normalizer_apply_fn(state_normalizer_params, rollout_traj_raw) 346 | 347 | # (num_envs,num_demo,num_step,features) (360,16,128,130) 348 | rollout_traj = rollout_traj_raw.swapaxes(1, 0)[:, None, ...].repeat(demo_traj.shape[1], 1) 349 | demo_traj_ = state_normalizer_apply_fn(state_normalizer_params, demo_traj) 350 | demo_traj_ = demo_traj_.swapaxes(1, 0)[None, ...].repeat(args.num_envs, 0) 351 | 352 | # calc state chamfer loss 353 | # for every state in rollout_traj find closest state in demo 354 | pred = rollout_traj[..., None, :].repeat(args.ep_len, -2) 355 | pred_demo = demo_traj_[..., None, :, :].repeat(args.ep_len, -3) 356 | pred_dis = jnp.sqrt(((pred - pred_demo) ** 2).mean(-1)).min(-1) # (360,16, 128), distance 357 | cf_loss = pred_dis.mean(-1).min(-1).mean() * args.deviation_factor # select the best from k expert demos 358 | 359 | # for every state in demo_traj_ find closest state in rollout_traj 360 | demo = demo_traj_[..., None, :].repeat(args.ep_len, -2) 361 | demo_pred = rollout_traj[..., None, :, :].repeat(args.ep_len, -3) 362 | demo_dis = jnp.sqrt(((demo - demo_pred) ** 2).mean(-1)).min(-1) # (batch, 128, 128), distance 363 | cf_loss += demo_dis.mean(-1).min(-1).mean() 364 | 365 | cf_action_loss, entropy_loss = 0, 0 366 | final_loss = cf_loss + entropy_loss * args.entropy_factor + cf_action_loss * args.action_cf_factor 367 | final_loss = jnp.tanh(final_loss) 368 | 369 | return final_loss, (normalizer_params, state_normalizer_params, obs, cf_loss, entropy_loss, cf_action_loss) 370 | 371 | def _minimize(training_state: TrainingState, state: envs.State): 372 | synchro = pmap.is_replicated((training_state.optimizer_state, 373 | training_state.policy_params, 374 | training_state.normalizer_params, 375 | training_state.state_normalizer_params), axis_name='i') 376 | key, key_grad = jax.random.split(training_state.key) 377 | grad_raw, (normalizer_params, 378 | state_normalizer_params, 379 | obs, cf_loss, entropy_loss, cf_action_loss) = loss_grad(training_state.policy_params, 380 | training_state.normalizer_params, 381 | training_state.state_normalizer_params, 382 | state, key_grad) 383 | grad_raw = jax.tree_multimap(lambda t: jnp.nan_to_num(t), grad_raw) 384 | grad = clip_by_global_norm(grad_raw) 385 | grad = jax.lax.pmean(grad, axis_name='i') 386 | params_update, optimizer_state = optimizer.update(grad, training_state.optimizer_state) 387 | policy_params = optax.apply_updates(training_state.policy_params, params_update) 388 | 389 | metrics = { 390 | 'grad_norm': optax.global_norm(grad_raw), 391 | 'params_norm': optax.global_norm(policy_params), 392 | 'cf_loss': cf_loss, 393 | 'entropy_loss': entropy_loss, 394 | "cf_action_loss": cf_action_loss 395 | } 396 | return TrainingState( 397 | key=key, 398 | optimizer_state=optimizer_state, 399 | il_optimizer_state=training_state.il_optimizer_state, 400 | normalizer_params=normalizer_params, 401 | state_normalizer_params=state_normalizer_params, 402 | policy_params=policy_params), metrics, synchro 403 | 404 | def clip_by_global_norm(updates): 405 | g_norm = optax.global_norm(updates) 406 | trigger = g_norm < max_gradient_norm 407 | updates = jax.tree_multimap( 408 | lambda t: jnp.where(trigger, t, (t / g_norm) * max_gradient_norm), 409 | updates) 410 | 411 | return updates 412 | 413 | # compile training functions 414 | 
il_loss_grad = jax.grad(il_loss, has_aux=True) 415 | 416 | if args.l2 == 1: 417 | loss_grad = jax.grad(l2_loss, has_aux=True) 418 | print("using l2 loss") 419 | else: 420 | loss_grad = jax.grad(loss, has_aux=True) 421 | print("using chamfer loss") 422 | minimize = jax.pmap(_minimize, axis_name='i') 423 | il_minimize = jax.pmap(il_minimize, axis_name='i') 424 | 425 | # prepare training 426 | sps = 0 427 | training_walltime = 0 428 | summary = {'params_norm': optax.global_norm(jax.tree_map(lambda x: x[0], policy_params))} 429 | key = jnp.stack(jax.random.split(key, local_devices_to_use)) 430 | training_state = TrainingState(key=key, optimizer_state=optimizer_state, 431 | il_optimizer_state=il_optimizer_state, 432 | normalizer_params=normalizer_params, 433 | state_normalizer_params=state_normalizer_params, 434 | policy_params=policy_params) 435 | 436 | # IL bootstrap 437 | if args.il: 438 | for it in range(1000): 439 | logging.info('IL bootstrap starting iteration %s %s', it, time.time() - xt) 440 | t = time.time() 441 | 442 | if it % 100 == 0: 443 | eval_policy(it, key_debug) 444 | 445 | # il optimization 446 | training_state, summary, synchro = il_minimize(training_state) 447 | assert synchro[0], (it, training_state) 448 | jax.tree_map(lambda x: x.block_until_ready(), summary) 449 | eval_policy(0, key_debug) 450 | 451 | # main training loop 452 | if args.ILD: 453 | for it in range(log_frequency + 1): 454 | actor_lr = (1e-5 - args.lr) * float(it / log_frequency) + args.lr 455 | optimizer = optax.adam(learning_rate=actor_lr) 456 | print("actor_lr: ", actor_lr) 457 | 458 | logging.info('starting iteration %s %s', it, time.time() - xt) 459 | t = time.time() 460 | 461 | eval_policy(it, key_debug) 462 | if it == log_frequency: 463 | break 464 | 465 | # optimization 466 | t = time.time() 467 | num_steps = it * args.num_envs * args.ep_len 468 | training_state, metrics, synchro = minimize(training_state, first_state) 469 | tf.summary.scalar('cf_loss', data=np.array(metrics['cf_loss'])[0], step=num_steps) 470 | tf.summary.scalar('entropy_loss', data=np.array(metrics['entropy_loss'])[0], step=num_steps) 471 | tf.summary.scalar('cf_action_loss', data=np.array(metrics['cf_action_loss'])[0], step=num_steps) 472 | tf.summary.scalar('grad_norm', data=np.array(metrics['grad_norm'])[0], step=num_steps) 473 | tf.summary.scalar('params_norm', data=np.array(metrics['params_norm'])[0], step=num_steps) 474 | assert synchro[0], (it, training_state) 475 | jax.tree_map(lambda x: x.block_until_ready(), metrics) 476 | sps = (episode_length * num_envs) / (time.time() - t) 477 | training_walltime += time.time() - t 478 | 479 | params = jax.tree_map(lambda x: x[0], training_state.policy_params) 480 | normalizer_params = jax.tree_map(lambda x: x[0], 481 | training_state.normalizer_params) 482 | params = normalizer_params, params 483 | inference = make_inference_fn(core_env.observation_size, core_env.action_size, 484 | normalize_observations) 485 | 486 | # save params in pickle file 487 | with open(args.logdir + '/params.pkl', 'wb') as f: 488 | pickle.dump(params, f) 489 | 490 | pmap.synchronize_hosts() 491 | 492 | 493 | def make_direct_optimization_model(parametric_action_distribution, obs_size): 494 | return networks.make_model( 495 | [512, 256, parametric_action_distribution.param_size], 496 | obs_size, 497 | activation=linen.swish) 498 | 499 | 500 | def make_inference_fn(observation_size, action_size, normalize_observations): 501 | """Creates params and inference function for the direct optimization agent.""" 502 | 
parametric_action_distribution = distribution.NormalTanhDistribution( 503 | event_size=action_size) 504 | _, obs_normalizer_apply_fn = normalization.make_data_and_apply_fn( 505 | observation_size, normalize_observations) 506 | policy_model = make_direct_optimization_model(parametric_action_distribution, 507 | observation_size) 508 | 509 | def inference_fn(params, obs, key): 510 | normalizer_params, params = params 511 | obs = obs_normalizer_apply_fn(normalizer_params, obs) 512 | action = parametric_action_distribution.sample( 513 | policy_model.apply(params, obs), key) 514 | return action 515 | 516 | return inference_fn 517 | 518 | 519 | def visualize(state_list): 520 | environment = args.env 521 | env = envs.create(env_name=environment) 522 | 523 | visual_states = [] 524 | for i in range(state_list.qp.ang.shape[0]): 525 | qp_state = brax.QP(np.array(state_list.qp.pos[i, 0]), 526 | np.array(state_list.qp.rot[i, 0]), 527 | np.array(state_list.qp.vel[i, 0]), 528 | np.array(state_list.qp.ang[i, 0])) 529 | visual_states.append(qp_state) 530 | 531 | html_string = html.render(env.sys, visual_states) 532 | components.html(html_string, height=500) 533 | 534 | 535 | if __name__ == '__main__': 536 | parser = argparse.ArgumentParser() 537 | 538 | parser.add_argument('--env', default="reacher") 539 | parser.add_argument('--ep_len', default=128, type=int) 540 | parser.add_argument('--num_envs', default=64, type=int) 541 | parser.add_argument('--lr', default=1e-3, type=float) 542 | parser.add_argument('--trunc_len', default=10, type=int) 543 | parser.add_argument('--max_it', default=5000, type=int) 544 | parser.add_argument('--max_grad_norm', default=0.3, type=float) 545 | parser.add_argument('--entropy_factor', default=0, type=float) 546 | parser.add_argument('--deviation_factor', default=1.0, type=float) 547 | parser.add_argument('--action_cf_factor', default=0, type=float) 548 | parser.add_argument('--il', default=1, type=float) 549 | parser.add_argument('--l2', default=0, type=float) 550 | parser.add_argument('--seed', default=1, type=int) 551 | parser.add_argument('--ILD', default=1, type=int) 552 | 553 | args = parser.parse_args() 554 | 555 | train(environment_fn=envs.create_fn(args.env), 556 | episode_length=args.ep_len, 557 | num_envs=args.num_envs, 558 | learning_rate=args.lr, 559 | normalize_observations=True, 560 | log_frequency=args.max_it, 561 | truncation_length=args.trunc_len, 562 | max_gradient_norm=args.max_grad_norm, 563 | seed=args.seed) 564 | -------------------------------------------------------------------------------- /policy/brax_task/train_on_policy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
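For reference, a simplified, self-contained restatement of the symmetric Chamfer state-matching term computed in loss() of train_multi_traj.py above; illustrative only, with state normalization, multiple demonstrations, the deviation/entropy/action factors and the final tanh omitted. Each rollout state is matched to its nearest demo state and vice versa, so the two trajectories need not be aligned step by step:

import jax.numpy as jnp

def chamfer_loss(rollout, demo):
    # rollout: (T, F) policy states, demo: (T, F) expert states
    d = jnp.sqrt(((rollout[:, None, :] - demo[None, :, :]) ** 2).mean(-1))  # (T, T) RMS distances
    return d.min(1).mean() + d.min(0).mean()  # rollout->demo plus demo->rollout terms

rollout = jnp.ones((128, 130))   # dummy sizes matching the shape comment in loss()
demo = jnp.zeros((128, 130))
print(chamfer_loss(rollout, demo))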
14 | 15 | import argparse 16 | import pickle 17 | import time 18 | from functools import partial 19 | from typing import Any, Callable, Dict, Optional 20 | 21 | import brax 22 | import flax 23 | import jax 24 | import jax.numpy as jnp 25 | import numpy as np 26 | import optax 27 | import streamlit.components.v1 as components 28 | import tensorflow as tf 29 | from absl import logging 30 | from brax import envs 31 | from brax.io import html 32 | from brax.training import distribution 33 | from brax.training import networks 34 | from brax.training import normalization 35 | from brax.training import pmap 36 | from brax.training.types import PRNGKey 37 | from brax.training.types import Params 38 | from flax import linen 39 | from jax import custom_vjp 40 | 41 | logging.set_verbosity(logging.INFO) 42 | tf.config.experimental.set_visible_devices([], "GPU") 43 | best_reward = 0 44 | 45 | 46 | @flax.struct.dataclass 47 | class TrainingState: 48 | """Contains training state for the learner.""" 49 | key: PRNGKey 50 | normalizer_params: Params 51 | state_normalizer_params: Params 52 | optimizer_state: optax.OptState 53 | il_optimizer_state: optax.OptState 54 | policy_params: Params 55 | 56 | 57 | def train( 58 | environment_fn: Callable[..., envs.Env], 59 | episode_length: int, 60 | action_repeat: int = 1, 61 | num_envs: int = 1, 62 | num_eval_envs: int = 128, 63 | max_gradient_norm: float = 1e9, 64 | max_devices_per_host: Optional[int] = None, 65 | learning_rate=1e-4, 66 | normalize_observations=False, 67 | seed=0, 68 | log_frequency=10, 69 | progress_fn: Optional[Callable[[int, Dict[str, Any]], None]] = None, 70 | truncation_length: Optional[int] = None, 71 | ): 72 | xt = time.time() 73 | 74 | # prepare expert demos 75 | args.logdir = f"logs/{args.env}/{args.env}_ep_len{args.ep_len}_num_envs{args.num_envs}_lr{args.lr}_trunc_len{args.trunc_len}" \ 76 | f"_max_it{args.max_it}_max_grad_norm{args.max_grad_norm}_re_dis{args.reverse_discount}_ef_{args.entropy_factor}" \ 77 | f"_df_{args.deviation_factor}_acf_{args.action_cf_factor}_l2loss_{args.l2}_il_{args.il}_ILD_{args.ILD}" \ 78 | f"/seed{args.seed}" 79 | demo_traj = np.load(f"expert/{args.env}_traj_state.npy") 80 | demo_traj = jnp.array(demo_traj)[:args.ep_len][:, None, ...].repeat(args.num_envs, 1) 81 | demo_traj_action = np.load(f"expert/{args.env}_traj_action.npy") 82 | demo_traj_action = jnp.array(demo_traj_action)[:args.ep_len][:, None, ...].repeat(args.num_envs, 1) 83 | demo_traj_obs = np.load(f"expert/{args.env}_traj_obs.npy") 84 | demo_traj_obs = jnp.array(demo_traj_obs)[:args.ep_len][:, None, ...].repeat(args.num_envs, 1) 85 | reverse_discounts = jnp.array([args.reverse_discount ** i for i in range(args.ep_len, 0, -1)])[None, ...] 
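The reverse_discounts array being constructed here (lines 85-86 of this script) raises args.reverse_discount to exponents running from ep_len down to 1, so with a discount below 1 the per-timestep weights grow toward the end of the episode. A small illustrative sketch with made-up sizes:

import jax.numpy as jnp

ep_len, num_envs, reverse_discount = 8, 2, 0.9  # illustrative values only
w = jnp.array([reverse_discount ** i for i in range(ep_len, 0, -1)])[None, ...]
w = w.repeat(num_envs, 0)
print(w.shape)  # (2, 8)
print(w[0])     # roughly 0.43 ... 0.9, i.e. later timesteps carry larger weight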
86 | reverse_discounts = reverse_discounts.repeat(args.num_envs, 0) 87 | 88 | # tensorboard 89 | file_writer = tf.summary.create_file_writer(args.logdir) 90 | file_writer.set_as_default() 91 | 92 | # distributed training setup 93 | process_count = jax.process_count() 94 | process_id = jax.process_index() 95 | local_device_count = jax.local_device_count() 96 | local_devices_to_use = local_device_count 97 | if max_devices_per_host: 98 | local_devices_to_use = min(local_devices_to_use, max_devices_per_host) 99 | logging.info('Device count: %d, process count: %d (id %d), local device count: %d, ' 100 | 'devices to be used count: %d', jax.device_count(), process_count, 101 | process_id, local_device_count, local_devices_to_use) 102 | logging.info('Available devices %s', jax.devices()) 103 | 104 | # seeds 105 | key = jax.random.PRNGKey(seed) 106 | key, key_models, key_env = jax.random.split(key, 3) 107 | key_env = jax.random.split(key_env, process_count)[process_id] 108 | key = jax.random.split(key, process_count)[process_id] 109 | key_debug = jax.random.PRNGKey(seed + 666) 110 | 111 | # envs 112 | core_env = environment_fn( 113 | action_repeat=action_repeat, 114 | batch_size=num_envs // local_devices_to_use // process_count, 115 | episode_length=episode_length) 116 | key_envs = jax.random.split(key_env, local_devices_to_use) 117 | step_fn = jax.jit(core_env.step) 118 | reset_fn = jax.jit(jax.vmap(core_env.reset)) 119 | first_state = reset_fn(key_envs) 120 | 121 | eval_env = environment_fn( 122 | action_repeat=action_repeat, 123 | batch_size=num_eval_envs, 124 | episode_length=episode_length, 125 | eval_metrics=True) 126 | eval_step_fn = jax.jit(eval_env.step) 127 | eval_first_state = jax.jit(eval_env.reset)(key_env) 128 | 129 | # initialize policy 130 | parametric_action_distribution = distribution.NormalTanhDistribution(event_size=core_env.action_size) 131 | policy_model = make_direct_optimization_model(parametric_action_distribution, core_env.observation_size) 132 | 133 | # init optimizer 134 | policy_params = policy_model.init(key_models) 135 | optimizer = optax.adam(learning_rate=learning_rate) 136 | optimizer_state = optimizer.init(policy_params) 137 | il_optimizer_state = optimizer.init(policy_params) 138 | optimizer_state, policy_params, il_optimizer_state = pmap.bcast_local_devices( 139 | (optimizer_state, policy_params, il_optimizer_state), local_devices_to_use) 140 | 141 | # observation normalizer 142 | normalizer_params, obs_normalizer_update_fn, obs_normalizer_apply_fn = ( 143 | normalization.create_observation_normalizer( 144 | core_env.observation_size, 145 | normalize_observations, 146 | num_leading_batch_dims=2, 147 | pmap_to_devices=local_devices_to_use)) 148 | 149 | # state normalizer 150 | state_normalizer_params, state_normalizer_update_fn, state_normalizer_apply_fn = ( 151 | normalization.create_observation_normalizer( 152 | demo_traj.shape[-1], 153 | normalize_observations=True, 154 | num_leading_batch_dims=2, 155 | pmap_to_devices=local_devices_to_use)) 156 | 157 | """ 158 | IL boostrap 159 | """ 160 | 161 | def il_loss(params, normalizer_params, key): 162 | 163 | normalizer_params = obs_normalizer_update_fn(normalizer_params, demo_traj_obs) 164 | normalized_obs = obs_normalizer_apply_fn(normalizer_params, demo_traj_obs) 165 | logits = policy_model.apply(params, normalized_obs) 166 | rollout_actions = parametric_action_distribution.sample(logits, key) 167 | 168 | loss_val = (rollout_actions - demo_traj_action) ** 2 169 | loss_val = loss_val.sum(-1).mean() 170 | return 
loss_val, normalizer_params 171 | 172 | def il_minimize(training_state: TrainingState): 173 | synchro = pmap.is_replicated((training_state.optimizer_state, 174 | training_state.policy_params, 175 | training_state.normalizer_params, 176 | training_state.state_normalizer_params, 177 | training_state.il_optimizer_state), axis_name='i') 178 | key, key_grad = jax.random.split(training_state.key) 179 | 180 | grad, normalizer_params = il_loss_grad(training_state.policy_params, 181 | training_state.normalizer_params, 182 | key_grad) 183 | grad = clip_by_global_norm(grad) 184 | grad = jax.lax.pmean(grad, axis_name='i') 185 | params_update, il_optimizer_state = optimizer.update(grad, training_state.il_optimizer_state) 186 | policy_params = optax.apply_updates(training_state.policy_params, params_update) 187 | 188 | metrics = { 189 | 'grad_norm': optax.global_norm(grad), 190 | 'params_norm': optax.global_norm(policy_params) 191 | } 192 | return TrainingState( 193 | key=key, 194 | optimizer_state=training_state.optimizer_state, 195 | il_optimizer_state=il_optimizer_state, 196 | normalizer_params=normalizer_params, 197 | state_normalizer_params=training_state.state_normalizer_params, 198 | policy_params=policy_params), metrics, synchro 199 | 200 | """ 201 | Evaluation functions 202 | """ 203 | 204 | def do_one_step_eval(carry, unused_target_t): 205 | state, params, normalizer_params, key = carry 206 | key, key_sample = jax.random.split(key) 207 | # TODO: Make this nicer ([0] comes from pmapping). 208 | obs = obs_normalizer_apply_fn( 209 | jax.tree_map(lambda x: x[0], normalizer_params), state.obs) 210 | print(obs.shape) 211 | print(jax.tree_map(lambda x: x.shape, params)) 212 | logits = policy_model.apply(params, obs) 213 | actions = parametric_action_distribution.sample(logits, key_sample) 214 | nstate = eval_step_fn(state, actions) 215 | return (nstate, params, normalizer_params, key), state 216 | 217 | @jax.jit 218 | def run_eval(params, state, normalizer_params, key): 219 | params = jax.tree_map(lambda x: x[0], params) 220 | (state, _, _, key), state_list = jax.lax.scan( 221 | do_one_step_eval, (state, params, normalizer_params, key), (), 222 | length=episode_length // action_repeat) 223 | return state, key, state_list 224 | 225 | def eval_policy(it, key_debug): 226 | global best_reward 227 | if process_id == 0: 228 | eval_state, key_debug, state_list = run_eval(training_state.policy_params, 229 | eval_first_state, 230 | training_state.normalizer_params, 231 | key_debug) 232 | eval_metrics = eval_state.info['eval_metrics'] 233 | eval_metrics.completed_episodes.block_until_ready() 234 | eval_sps = ( 235 | episode_length * eval_first_state.reward.shape[0] / 236 | (time.time() - t)) 237 | avg_episode_length = ( 238 | eval_metrics.completed_episodes_steps / 239 | eval_metrics.completed_episodes) 240 | metrics = dict( 241 | dict({ 242 | f'eval/episode_{name}': value / eval_metrics.completed_episodes 243 | for name, value in eval_metrics.completed_episodes_metrics.items() 244 | }), 245 | **dict({ 246 | 'eval/completed_episodes': eval_metrics.completed_episodes, 247 | 'eval/avg_episode_length': avg_episode_length, 248 | 'speed/sps': sps, 249 | 'speed/eval_sps': eval_sps, 250 | 'speed/training_walltime': training_walltime, 251 | 'speed/timestamp': training_walltime, 252 | 'train/grad_norm': jnp.mean(summary.get('grad_norm', 0)), 253 | 'train/params_norm': jnp.mean(summary.get('params_norm', 0)), 254 | })) 255 | 256 | logging.info(metrics) 257 | if progress_fn: 258 | progress_fn(it, metrics) 259 | 260 | if 
it % 10 == 0: 261 | visualize(state_list) 262 | 263 | tf.summary.scalar('eval_episode_reward', data=np.array(metrics['eval/episode_reward']), 264 | step=it * args.num_envs * args.ep_len) 265 | 266 | if np.array(metrics['eval/episode_reward']) > best_reward: 267 | best_reward = np.array(metrics['eval/episode_reward']) 268 | # save params in pickle file 269 | print('Saving params with reward', best_reward) 270 | params = jax.tree_map(lambda x: x[0], training_state.policy_params) 271 | normalizer_params = jax.tree_map(lambda x: x[0], training_state.normalizer_params) 272 | params_ = normalizer_params, params 273 | with open(args.logdir + '/params.pkl', 'wb') as f: 274 | pickle.dump(params_, f) 275 | 276 | """ 277 | Training functions 278 | """ 279 | 280 | @partial(custom_vjp) 281 | def norm_grad(x): 282 | return x 283 | 284 | def norm_grad_fwd(x): 285 | return x, () 286 | 287 | def norm_grad_bwd(x, g): 288 | # g /= jnp.linalg.norm(g) 289 | # g = jnp.nan_to_num(g) 290 | g_norm = optax.global_norm(g) 291 | trigger = g_norm < 1.0 292 | g = jax.tree_multimap( 293 | lambda t: jnp.where(trigger, 294 | jnp.nan_to_num(t), 295 | (jnp.nan_to_num(t) / g_norm) * 1.0), g) 296 | return g, 297 | 298 | norm_grad.defvjp(norm_grad_fwd, norm_grad_bwd) 299 | 300 | def do_one_step(carry, step_index): 301 | state, params, normalizer_params, key = carry 302 | key, key_sample = jax.random.split(key) 303 | normalized_obs = obs_normalizer_apply_fn(normalizer_params, state.obs) 304 | logits = policy_model.apply(params, normalized_obs) 305 | actions = parametric_action_distribution.sample(logits, key_sample) 306 | nstate = step_fn(state, actions) 307 | 308 | actions = norm_grad(actions) 309 | nstate = norm_grad(nstate) 310 | 311 | if truncation_length is not None and truncation_length > 0: 312 | nstate = jax.lax.cond( 313 | jnp.mod(step_index + 1, truncation_length) == 0., 314 | jax.lax.stop_gradient, lambda x: x, nstate) 315 | 316 | return (nstate, params, normalizer_params, key), (nstate.reward, state.obs, state.qp, logits, actions) 317 | 318 | def l2_loss(params, normalizer_params, state_normalizer_params, state, key): 319 | _, (rewards, obs, qp_list, logit_list, action_list) = jax.lax.scan( 320 | do_one_step, (state, params, normalizer_params, key), 321 | (jnp.array(range(episode_length // action_repeat))), 322 | length=episode_length // action_repeat) 323 | 324 | rollout_traj = jnp.concatenate([qp_list.pos.reshape((qp_list.pos.shape[0], qp_list.pos.shape[1], -1)), 325 | qp_list.rot.reshape((qp_list.rot.shape[0], qp_list.rot.shape[1], -1)), 326 | qp_list.vel.reshape((qp_list.vel.shape[0], qp_list.vel.shape[1], -1)), 327 | qp_list.ang.reshape((qp_list.ang.shape[0], qp_list.ang.shape[1], -1))], axis=-1) 328 | 329 | # normalize states 330 | normalizer_params = obs_normalizer_update_fn(normalizer_params, obs) 331 | state_normalizer_params = state_normalizer_update_fn(state_normalizer_params, rollout_traj) 332 | 333 | rollout_traj = state_normalizer_apply_fn(state_normalizer_params, rollout_traj) 334 | demo_traj_ = state_normalizer_apply_fn(state_normalizer_params, demo_traj) 335 | 336 | loss_val = (rollout_traj - demo_traj_) ** 2 337 | loss_val = jnp.sqrt(loss_val.sum(-1)).mean() 338 | 339 | return loss_val, (normalizer_params, state_normalizer_params, obs, 0, 0, 0) 340 | 341 | def loss(params, normalizer_params, state_normalizer_params, state, key): 342 | _, (rewards, obs, qp_list, logit_list, action_list) = jax.lax.scan( 343 | do_one_step, (state, params, normalizer_params, key), 344 | (jnp.array(range(episode_length 
// action_repeat))), 345 | length=episode_length // action_repeat) 346 | 347 | rollout_traj = jnp.concatenate([qp_list.pos.reshape((qp_list.pos.shape[0], qp_list.pos.shape[1], -1)), 348 | qp_list.rot.reshape((qp_list.rot.shape[0], qp_list.rot.shape[1], -1)), 349 | qp_list.vel.reshape((qp_list.vel.shape[0], qp_list.vel.shape[1], -1)), 350 | qp_list.ang.reshape((qp_list.ang.shape[0], qp_list.ang.shape[1], -1))], axis=-1) 351 | 352 | # normalize states 353 | normalizer_params = obs_normalizer_update_fn(normalizer_params, obs) 354 | state_normalizer_params = state_normalizer_update_fn(state_normalizer_params, rollout_traj) 355 | 356 | rollout_traj = state_normalizer_apply_fn(state_normalizer_params, rollout_traj) 357 | demo_traj_ = state_normalizer_apply_fn(state_normalizer_params, demo_traj) 358 | 359 | # calc state chamfer loss 360 | # for every state in rollout_traj find closest state in demo 361 | pred = rollout_traj.swapaxes(1, 0)[..., None, :].repeat(args.ep_len, -2) 362 | pred_demo = demo_traj_.swapaxes(1, 0)[..., None, :, :].repeat(args.ep_len, -3) 363 | pred_dis = jnp.sqrt(((pred - pred_demo) ** 2).mean(-1)) # (batch, 128, 128), distance 364 | cf_loss = (pred_dis.min(-1) * reverse_discounts).mean() * args.deviation_factor 365 | 366 | # for every state in demo_traj_ find closest state in rollout_traj 367 | demo = demo_traj_.swapaxes(1, 0)[..., None, :].repeat(args.ep_len, -2) 368 | demo_pred = rollout_traj.swapaxes(1, 0)[..., None, :, :].repeat(args.ep_len, -3) 369 | demo_dis = jnp.sqrt(((demo - demo_pred) ** 2).mean(-1)) # (batch, 128, 128), distance 370 | cf_loss += (demo_dis.min(-1) * reverse_discounts).mean() 371 | 372 | # calc action cf loss 373 | pred_action = action_list.swapaxes(1, 0)[..., None, :].repeat(args.ep_len, -2) 374 | pred_demo_action = demo_traj_action.swapaxes(1, 0)[..., None, :, :].repeat(args.ep_len, -3) 375 | pred_dis_action = jnp.sqrt(((pred_action - pred_demo_action) ** 2).mean(-1)) # (batch, 128, 128), distance 376 | cf_action_loss = (pred_dis_action.min(-1) * reverse_discounts).mean() 377 | 378 | demo_action = demo_traj_action.swapaxes(1, 0)[..., None, :].repeat(args.ep_len, -2) 379 | demo_pred_action = action_list.swapaxes(1, 0)[..., None, :, :].repeat(args.ep_len, -3) 380 | demo_dis_action = jnp.sqrt(((demo_action - demo_pred_action) ** 2).mean(-1)) # (batch, 128, 128), distance 381 | cf_action_loss += (demo_dis_action.min(-1) * reverse_discounts).mean() 382 | 383 | # entropy cost 384 | loc, scale = jnp.split(logit_list, 2, axis=-1) 385 | sigma_list = jax.nn.softplus(scale) + parametric_action_distribution._min_std 386 | entropy_loss = -1 * 0.5 * jnp.log(2 * jnp.pi * sigma_list ** 2) 387 | entropy_loss = entropy_loss.mean(-1).mean() 388 | 389 | final_loss = cf_loss + entropy_loss * args.entropy_factor + cf_action_loss * args.action_cf_factor 390 | final_loss = jnp.tanh(final_loss) 391 | 392 | return final_loss, (normalizer_params, state_normalizer_params, obs, cf_loss, entropy_loss, cf_action_loss) 393 | 394 | def _minimize(training_state: TrainingState, state: envs.State): 395 | synchro = pmap.is_replicated((training_state.optimizer_state, 396 | training_state.policy_params, 397 | training_state.normalizer_params, 398 | training_state.state_normalizer_params), axis_name='i') 399 | key, key_grad = jax.random.split(training_state.key) 400 | grad_raw, (normalizer_params, 401 | state_normalizer_params, 402 | obs, cf_loss, entropy_loss, cf_action_loss) = loss_grad(training_state.policy_params, 403 | training_state.normalizer_params, 404 | 
training_state.state_normalizer_params, 405 | state, key_grad) 406 | grad_raw = jax.tree_multimap(lambda t: jnp.nan_to_num(t), grad_raw) 407 | grad = clip_by_global_norm(grad_raw) 408 | grad = jax.lax.pmean(grad, axis_name='i') 409 | params_update, optimizer_state = optimizer.update(grad, training_state.optimizer_state) 410 | policy_params = optax.apply_updates(training_state.policy_params, params_update) 411 | 412 | metrics = { 413 | 'grad_norm': optax.global_norm(grad_raw), 414 | 'params_norm': optax.global_norm(policy_params), 415 | 'cf_loss': cf_loss, 416 | 'entropy_loss': entropy_loss, 417 | "cf_action_loss": cf_action_loss 418 | } 419 | return TrainingState( 420 | key=key, 421 | optimizer_state=optimizer_state, 422 | il_optimizer_state=training_state.il_optimizer_state, 423 | normalizer_params=normalizer_params, 424 | state_normalizer_params=state_normalizer_params, 425 | policy_params=policy_params), metrics, synchro 426 | 427 | def clip_by_global_norm(updates): 428 | g_norm = optax.global_norm(updates) 429 | trigger = g_norm < max_gradient_norm 430 | updates = jax.tree_multimap( 431 | lambda t: jnp.where(trigger, t, (t / g_norm) * max_gradient_norm), 432 | updates) 433 | 434 | return updates 435 | 436 | # compile training functions 437 | il_loss_grad = jax.grad(il_loss, has_aux=True) 438 | 439 | if args.l2 == 1: 440 | loss_grad = jax.grad(l2_loss, has_aux=True) 441 | print("using l2 loss") 442 | else: 443 | loss_grad = jax.grad(loss, has_aux=True) 444 | print("using chamfer loss") 445 | minimize = jax.pmap(_minimize, axis_name='i') 446 | il_minimize = jax.pmap(il_minimize, axis_name='i') 447 | best_reward = 0 448 | 449 | # prepare training 450 | sps = 0 451 | training_walltime = 0 452 | summary = {'params_norm': optax.global_norm(jax.tree_map(lambda x: x[0], policy_params))} 453 | key = jnp.stack(jax.random.split(key, local_devices_to_use)) 454 | training_state = TrainingState(key=key, optimizer_state=optimizer_state, 455 | il_optimizer_state=il_optimizer_state, 456 | normalizer_params=normalizer_params, 457 | state_normalizer_params=state_normalizer_params, 458 | policy_params=policy_params) 459 | 460 | # IL bootstrap 461 | if args.il: 462 | for it in range(1000): 463 | logging.info('IL bootstrap starting iteration %s %s', it, time.time() - xt) 464 | t = time.time() 465 | 466 | if it % 100 == 0: 467 | eval_policy(it, key_debug) 468 | 469 | # il optimization 470 | training_state, summary, synchro = il_minimize(training_state) 471 | assert synchro[0], (it, training_state) 472 | jax.tree_map(lambda x: x.block_until_ready(), summary) 473 | eval_policy(0, key_debug) 474 | 475 | # main training loop 476 | if args.ILD: 477 | for it in range(log_frequency + 1): 478 | actor_lr = (1e-5 - args.lr) * float(it / log_frequency) + args.lr 479 | optimizer = optax.adam(learning_rate=actor_lr) 480 | print("actor_lr: ", actor_lr) 481 | 482 | logging.info('starting iteration %s %s', it, time.time() - xt) 483 | t = time.time() 484 | 485 | eval_policy(it, key_debug) 486 | if it == log_frequency: 487 | break 488 | 489 | # optimization 490 | t = time.time() 491 | num_steps = it * args.num_envs * args.ep_len 492 | training_state, metrics, synchro = minimize(training_state, first_state) 493 | tf.summary.scalar('cf_loss', data=np.array(metrics['cf_loss'])[0], step=num_steps) 494 | tf.summary.scalar('entropy_loss', data=np.array(metrics['entropy_loss'])[0], step=num_steps) 495 | tf.summary.scalar('cf_action_loss', data=np.array(metrics['cf_action_loss'])[0], step=num_steps) 496 | 
tf.summary.scalar('grad_norm', data=np.array(metrics['grad_norm'])[0], step=num_steps) 497 | tf.summary.scalar('params_norm', data=np.array(metrics['params_norm'])[0], step=num_steps) 498 | assert synchro[0], (it, training_state) 499 | jax.tree_map(lambda x: x.block_until_ready(), metrics) 500 | sps = (episode_length * num_envs) / (time.time() - t) 501 | training_walltime += time.time() - t 502 | 503 | params = jax.tree_map(lambda x: x[0], training_state.policy_params) 504 | normalizer_params = jax.tree_map(lambda x: x[0], 505 | training_state.normalizer_params) 506 | params = normalizer_params, params 507 | inference = make_inference_fn(core_env.observation_size, core_env.action_size, 508 | normalize_observations) 509 | 510 | pmap.synchronize_hosts() 511 | 512 | 513 | def make_direct_optimization_model(parametric_action_distribution, obs_size): 514 | return networks.make_model( 515 | [512, 256, parametric_action_distribution.param_size], 516 | obs_size, 517 | activation=linen.swish) 518 | 519 | 520 | def make_inference_fn(observation_size, action_size, normalize_observations): 521 | """Creates params and inference function for the direct optimization agent.""" 522 | parametric_action_distribution = distribution.NormalTanhDistribution( 523 | event_size=action_size) 524 | _, obs_normalizer_apply_fn = normalization.make_data_and_apply_fn( 525 | observation_size, normalize_observations) 526 | policy_model = make_direct_optimization_model(parametric_action_distribution, 527 | observation_size) 528 | 529 | def inference_fn(params, obs, key): 530 | normalizer_params, params = params 531 | obs = obs_normalizer_apply_fn(normalizer_params, obs) 532 | action = parametric_action_distribution.sample( 533 | policy_model.apply(params, obs), key) 534 | return action 535 | 536 | return inference_fn 537 | 538 | 539 | def visualize(state_list): 540 | environment = args.env 541 | env = envs.create(env_name=environment) 542 | 543 | visual_states = [] 544 | for i in range(state_list.qp.ang.shape[0]): 545 | qp_state = brax.QP(np.array(state_list.qp.pos[i, 0]), 546 | np.array(state_list.qp.rot[i, 0]), 547 | np.array(state_list.qp.vel[i, 0]), 548 | np.array(state_list.qp.ang[i, 0])) 549 | visual_states.append(qp_state) 550 | 551 | html_string = html.render(env.sys, visual_states) 552 | components.html(html_string, height=500) 553 | 554 | 555 | if __name__ == '__main__': 556 | parser = argparse.ArgumentParser() 557 | 558 | parser.add_argument('--env', default="ant") 559 | parser.add_argument('--ep_len', default=128, type=int) 560 | parser.add_argument('--num_envs', default=300, type=int) 561 | parser.add_argument('--lr', default=1e-3, type=float) 562 | parser.add_argument('--trunc_len', default=10, type=int) 563 | parser.add_argument('--max_it', default=5000, type=int) 564 | parser.add_argument('--max_grad_norm', default=0.3, type=float) 565 | parser.add_argument('--reverse_discount', default=1.0, type=float) 566 | parser.add_argument('--entropy_factor', default=0, type=float) 567 | parser.add_argument('--deviation_factor', default=1.0, type=float) 568 | parser.add_argument('--action_cf_factor', default=0, type=float) 569 | parser.add_argument('--il', default=1, type=float) 570 | parser.add_argument('--l2', default=0, type=float) 571 | parser.add_argument('--seed', default=1, type=int) 572 | parser.add_argument('--ILD', default=1, type=int) 573 | 574 | args = parser.parse_args() 575 | 576 | train(environment_fn=envs.create_fn(args.env), 577 | episode_length=args.ep_len, 578 | num_envs=args.num_envs, 579 | 
learning_rate=args.lr, 580 | normalize_observations=True, 581 | log_frequency=args.max_it, 582 | truncation_length=args.trunc_len, 583 | max_gradient_norm=args.max_grad_norm, 584 | seed=args.seed) 585 | -------------------------------------------------------------------------------- /policy/cloth_task/expert/hang_cloth.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/cloth_task/expert/hang_cloth.pickle -------------------------------------------------------------------------------- /policy/cloth_task/expert/hang_cloth_traj_action.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/cloth_task/expert/hang_cloth_traj_action.npy -------------------------------------------------------------------------------- /policy/cloth_task/expert/hang_cloth_traj_obs.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/cloth_task/expert/hang_cloth_traj_obs.npy -------------------------------------------------------------------------------- /policy/cloth_task/expert/hang_cloth_traj_state.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sail-sg/ILD/df699447bfe025a3e7e4e2fda5342b49c3942a0d/policy/cloth_task/expert/hang_cloth_traj_state.npy -------------------------------------------------------------------------------- /policy/cloth_task/train_on_policy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Garena Online Private Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
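# ILD training script for the deformable-object (cloth hanging) task. The recipe mirrors
# policy/brax_task/train_on_policy.py: the policy is first bootstrapped with behavior
# cloning on the expert trajectory (the "IL bootstrap" loop below) and is then refined by
# back-propagating a Chamfer-style deviation loss between rolled-out states/actions and the
# expert demonstration through the differentiable cloth simulator under core/engine.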
14 | 15 | import os 16 | 17 | # os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' 18 | 19 | import argparse 20 | import sys 21 | import time 22 | from functools import partial 23 | from typing import Any, Callable, Dict, Optional 24 | 25 | import flax 26 | import jax 27 | import jax.numpy as jnp 28 | import numpy as np 29 | import optax 30 | import tensorflow as tf 31 | from absl import logging 32 | from brax import envs 33 | from brax.training import distribution, networks 34 | from brax.training import normalization 35 | from brax.training.types import PRNGKey 36 | from brax.training.types import Params 37 | from flax import linen 38 | from jax import custom_vjp 39 | sys.path.append('./../..') 40 | from core.envs.hang_cloth_env import make_env as make_env_hang_cloth, pole_pos 41 | 42 | logging.set_verbosity(logging.INFO) 43 | 44 | @flax.struct.dataclass 45 | class TrainingState: 46 | """Contains training state for the learner.""" 47 | key: PRNGKey 48 | optimizer_state: optax.OptState 49 | il_optimizer_state: optax.OptState 50 | policy_params: Params 51 | 52 | 53 | def group_state(state): 54 | cloth_state = jnp.concatenate([state[0], state[1]], axis=-1) 55 | cloth_state = cloth_state.reshape(cloth_state.shape[:-2] + (-1,)) 56 | 57 | gripper_state = jnp.concatenate([state[2], state[3]], axis=-1) 58 | combo_state = jnp.concatenate([cloth_state, gripper_state], axis=-1) 59 | 60 | return combo_state 61 | 62 | 63 | def train( 64 | environment_fn: Callable[..., envs.Env], 65 | episode_length: int, 66 | action_repeat: int = 1, 67 | num_envs: int = 1, 68 | num_eval_envs: int = 128, 69 | max_gradient_norm: float = 1e9, 70 | max_devices_per_host: Optional[int] = None, 71 | learning_rate=1e-4, 72 | normalize_observations=False, 73 | seed=0, 74 | log_frequency=10, 75 | progress_fn: Optional[Callable[[int, Dict[str, Any]], None]] = None, 76 | truncation_length: Optional[int] = None, 77 | ): 78 | xt = time.time() 79 | 80 | # prepare expert demos 81 | args.logdir = f"logs/{args.env}/{args.env}_ep_len{args.ep_len}_num_envs{args.num_envs}_lr{args.lr}_trunc_len{args.trunc_len}" \ 82 | f"_max_it{args.max_it}_max_grad_norm{args.max_grad_norm}_re_dis{args.reverse_discount}_ef_{args.entropy_factor}" \ 83 | f"_df_{args.deviation_factor}_acf_{args.action_cf_factor}" \ 84 | f"/seed{args.seed}" 85 | demo_traj_raw = np.load(f"expert/{args.env}_traj_state.npy", allow_pickle=True) 86 | demo_traj_raw = jnp.array(demo_traj_raw)[:args.ep_len][:, None, ...].repeat(args.num_envs, 1) 87 | demo_gripper_traj = demo_traj_raw[..., -10:] 88 | demo_state_traj = demo_traj_raw[..., :-10] 89 | 90 | demo_traj_action = np.load(f"expert/{args.env}_traj_action.npy") 91 | demo_traj_action = jnp.array(demo_traj_action)[:args.ep_len][:, None, ...].repeat(args.num_envs, 1) 92 | 93 | reverse_discounts = jnp.array([args.reverse_discount ** i for i in range(args.ep_len, 0, -1)])[None, ...] 
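# The last 10 features of each expert state are split off above as the gripper trajectory
# and the remainder as the flattened cloth state; the Chamfer loss below matches the full
# concatenated state (demo_traj_raw), so cf_gripper_loss currently stays at zero.
# reverse_discounts is completed on the next line exactly as in the Brax script.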
94 | reverse_discounts = reverse_discounts.repeat(args.num_envs, 0) 95 | 96 | # tensorboard 97 | file_writer = tf.summary.create_file_writer(args.logdir) 98 | file_writer.set_as_default() 99 | 100 | # distributed training setup 101 | process_count = jax.process_count() 102 | process_id = jax.process_index() 103 | local_device_count = jax.local_device_count() 104 | local_devices_to_use = local_device_count 105 | if max_devices_per_host: 106 | local_devices_to_use = min(local_devices_to_use, max_devices_per_host) 107 | logging.info('Device count: %d, process count: %d (id %d), local device count: %d, ' 108 | 'devices to be used count: %d', jax.device_count(), process_count, 109 | process_id, local_device_count, local_devices_to_use) 110 | logging.info('Available devices %s', jax.devices()) 111 | 112 | # seeds 113 | key = jax.random.PRNGKey(seed) 114 | key, key_models, key_env = jax.random.split(key, 3) 115 | key_debug = jax.random.PRNGKey(seed + 666) 116 | 117 | # envs 118 | core_env = environment_fn(batch_size=num_envs, episode_length=episode_length) 119 | step_fn = core_env.step_jax 120 | reset_fn = core_env.reset_jax 121 | first_state = reset_fn(key_env, step_fn, args.num_envs) 122 | 123 | eval_env = environment_fn(batch_size=num_eval_envs, episode_length=episode_length) 124 | eval_step_fn = eval_env.step_jax 125 | eval_first_state = eval_env.reset_jax(key_env, eval_step_fn, num_eval_envs) 126 | 127 | visualize_env = environment_fn(batch_size=0, episode_length=episode_length, visualize=False) 128 | visualize_env.step_jax = visualize_env.step_jax 129 | visualize_first_state = visualize_env.reset_jax(key_env, visualize_env.step_jax, batch_size=0) 130 | 131 | # initialize policy 132 | parametric_action_distribution = distribution.NormalTanhDistribution(event_size=core_env.action_size) 133 | policy_model = make_direct_optimization_model(parametric_action_distribution, core_env.observation_size) 134 | 135 | # init optimizer 136 | policy_params = policy_model.init(key_models) 137 | optimizer = optax.adam(learning_rate=learning_rate) 138 | optimizer_state = optimizer.init(policy_params) 139 | il_optimizer_state = optimizer.init(policy_params) 140 | 141 | """ 142 | IL boostrap 143 | """ 144 | 145 | def il_loss(params, key): 146 | 147 | logits = policy_model.apply(params, demo_traj_raw.reshape((-1, demo_traj_raw.shape[-1]))) 148 | logits = logits.reshape(demo_traj_raw.shape[:2] + (-1,)) 149 | rollout_actions = parametric_action_distribution.sample(logits, key) 150 | 151 | loss_val = (rollout_actions - demo_traj_action) ** 2 152 | loss_val = jnp.sqrt(loss_val.sum(-1)).mean() 153 | return loss_val, loss_val 154 | 155 | def il_minimize(training_state: TrainingState): 156 | 157 | grad, loss_val = il_loss_grad(training_state.policy_params, training_state.key) 158 | grad = clip_by_global_norm(grad) 159 | params_update, il_optimizer_state = optimizer.update(grad, training_state.il_optimizer_state) 160 | policy_params = optax.apply_updates(training_state.policy_params, params_update) 161 | 162 | metrics = { 163 | 'grad_norm': optax.global_norm(grad), 164 | 'params_norm': optax.global_norm(policy_params), 165 | 'loss_val': loss_val 166 | } 167 | return TrainingState( 168 | key=key, 169 | optimizer_state=training_state.optimizer_state, 170 | il_optimizer_state=il_optimizer_state, 171 | policy_params=policy_params), metrics, loss_val 172 | 173 | """ 174 | Evaluation functions 175 | """ 176 | 177 | def visualize(params, state, key): 178 | if not visualize_env.visualize: 179 | return 180 | 181 | 
visualize_env.reset() 182 | for i in range(args.ep_len): 183 | key, key_sample = jax.random.split(key) 184 | combo_state = group_state(state)[None, ...] 185 | logits = policy_model.apply(params, combo_state) 186 | logits = logits.reshape((combo_state.shape[0], -1)) 187 | actions = parametric_action_distribution.sample(logits, key_sample) 188 | actions = actions.squeeze() 189 | 190 | state = visualize_env.step_jax(actions, *state) 191 | visualize_env.state = state 192 | visualize_env.render() 193 | 194 | def do_one_step_eval(carry, unused_target_t): 195 | state, params, key = carry 196 | key, key_sample = jax.random.split(key) 197 | combo_state = group_state(state) 198 | 199 | print(jax.tree_map(lambda x: x.shape, params)) 200 | logits = policy_model.apply(params, combo_state) 201 | logits = logits.reshape((combo_state.shape[0], -1)) 202 | actions = parametric_action_distribution.sample(logits, key_sample) 203 | nstate = eval_step_fn(actions, *state) 204 | return (nstate, params, key), state 205 | 206 | @jax.jit 207 | def run_eval(params, state, key): 208 | (state, _, key), state_list = jax.lax.scan( 209 | do_one_step_eval, (state, params, key), (), 210 | length=episode_length // action_repeat) 211 | demo_traj_raw_ = demo_traj_raw[:, 0, :][:, None, :].repeat(num_eval_envs, 1) 212 | 213 | combo_state = group_state(state_list) 214 | combo_state = jnp.nan_to_num(combo_state) 215 | 216 | loss_val = (combo_state - demo_traj_raw_) ** 2 217 | loss_val = loss_val.sum(-1).mean() 218 | 219 | return state, key, state_list, loss_val 220 | 221 | def eval_policy(it, key_debug): 222 | state, key, state_list, loss_val = run_eval(training_state.policy_params, eval_first_state, key_debug) 223 | 224 | x = jnp.nan_to_num(state[0]) 225 | reward = (x[:, :, 1].max(1) >= pole_pos[1]).astype(jnp.float32) 226 | reward *= (x[:, :, 2].min(1) <= pole_pos[2]).astype(jnp.float32) 227 | reward *= (x[:, :, 2].max(1) >= pole_pos[2]).astype(jnp.float32) 228 | 229 | reward = reward.mean() 230 | 231 | metrics = { 232 | 'reward': reward, 233 | 'loss': loss_val, 234 | 'speed/sps': sps, 235 | 'speed/training_walltime': training_walltime, 236 | 'speed/timestamp': training_walltime, 237 | 'train/grad_norm': jnp.mean(summary.get('grad_norm', 0)), 238 | 'train/params_norm': jnp.mean(summary.get('params_norm', 0)), 239 | } 240 | 241 | logging.info(metrics) 242 | if progress_fn: 243 | progress_fn(it, metrics) 244 | 245 | tf.summary.scalar('eval_episode_loss', data=np.array(loss_val), 246 | step=it * args.num_envs * args.ep_len) 247 | tf.summary.scalar('eval_episode_reward', data=np.array(reward), 248 | step=it * args.num_envs * args.ep_len) 249 | 250 | """ 251 | Training functions 252 | """ 253 | 254 | @partial(custom_vjp, nondiff_argnums=(0,)) 255 | def memory_profiler(num_steps, x): 256 | return x 257 | 258 | def memory_profiler_fwd(num_steps, x): 259 | return x, () 260 | 261 | def memory_profiler_bwd(num_steps, _, g): 262 | 263 | jax.profiler.save_device_memory_profile(f'memory_{num_steps}.prof') 264 | return (g,) 265 | 266 | def do_one_step(carry, step_index): 267 | state, params, key = carry 268 | key, key_sample = jax.random.split(key) 269 | combo_state = group_state(state) 270 | logits = policy_model.apply(params, combo_state) 271 | actions = parametric_action_distribution.sample(logits, key_sample) 272 | nstate = step_fn(actions, *state) 273 | 274 | if truncation_length is not None and truncation_length > 0: 275 | nstate = jax.lax.cond( 276 | jnp.mod(step_index + 1, truncation_length) == 0., 277 | jax.lax.stop_gradient, lambda 
x: x, nstate) 278 | 279 | return (nstate, params, key), (nstate, logits, actions) 280 | 281 | def loss(params, state, key): 282 | _, (state_list, logit_list, action_list) = jax.lax.scan( 283 | do_one_step, (state, params, key), 284 | (jnp.array(range(episode_length // action_repeat))), 285 | length=episode_length // action_repeat) 286 | 287 | combo_state = group_state(state_list) 288 | 289 | cf_state_loss, cf_gripper_loss, entropy_loss, cf_action_loss = 0, 0, 0, 0 290 | 291 | pred = combo_state.swapaxes(1, 0)[..., None, :].repeat(args.ep_len, -2) 292 | pred_demo = demo_traj_raw.swapaxes(1, 0)[..., None, :, :].repeat(args.ep_len, -3) 293 | pred_dis = ((pred - pred_demo) ** 2).mean(-1) # (batch, 128, 128), distance 294 | cf_state_loss = (pred_dis.min(-1) * reverse_discounts).mean() * args.deviation_factor 295 | 296 | # for every state in demo_traj_ find closest state in rollout_traj 297 | demo = demo_traj_raw.swapaxes(1, 0)[..., None, :].repeat(args.ep_len, -2) 298 | demo_pred = combo_state.swapaxes(1, 0)[..., None, :, :].repeat(args.ep_len, -3) 299 | demo_dis = ((demo - demo_pred) ** 2).mean(-1) # (batch, 128, 128), distance 300 | cf_state_loss += (demo_dis.min(-1) * reverse_discounts).mean() 301 | 302 | # calc action cf loss 303 | pred_action = action_list.swapaxes(1, 0)[..., None, :].repeat(args.ep_len, -2) 304 | pred_demo_action = demo_traj_action.swapaxes(1, 0)[..., None, :, :].repeat(args.ep_len, -3) 305 | pred_dis_action = ((pred_action - pred_demo_action) ** 2).mean(-1) # (batch, 128, 128), distance 306 | cf_action_loss = (pred_dis_action.min(-1) * reverse_discounts).mean() 307 | 308 | demo_action = demo_traj_action.swapaxes(1, 0)[..., None, :].repeat(args.ep_len, -2) 309 | demo_pred_action = action_list.swapaxes(1, 0)[..., None, :, :].repeat(args.ep_len, -3) 310 | demo_dis_action = ((demo_action - demo_pred_action) ** 2).mean(-1) # (batch, 128, 128), distance 311 | cf_action_loss += (demo_dis_action.min(-1) * reverse_discounts).mean() 312 | 313 | # entropy cost 314 | loc, scale = jnp.split(logit_list, 2, axis=-1) 315 | sigma_list = jax.nn.softplus(scale) + parametric_action_distribution._min_std 316 | entropy_loss = -1 * 0.5 * jnp.log(2 * jnp.pi * sigma_list ** 2) 317 | entropy_loss = entropy_loss.mean(-1).mean() 318 | 319 | final_loss = cf_state_loss + cf_gripper_loss + entropy_loss * args.entropy_factor + cf_action_loss * args.action_cf_factor 320 | 321 | return final_loss, (cf_state_loss, cf_gripper_loss, entropy_loss, cf_action_loss) 322 | 323 | def _minimize(training_state, state): 324 | 325 | def minimize_step(carry, step_idx): 326 | policy_params, state, key = carry 327 | grad_raw, (cf_state_loss, cf_gripper_loss, entropy_loss, cf_action_loss) \ 328 | = loss_grad(policy_params, state, key + step_idx.astype(jnp.uint32)) 329 | 330 | return carry, (grad_raw, cf_state_loss, cf_gripper_loss, entropy_loss, cf_action_loss) 331 | 332 | _, (grad_raw, cf_state_loss, cf_gripper_loss, entropy_loss, cf_action_loss) = jax.lax.scan( 333 | minimize_step, (training_state.policy_params, state, training_state.key), 334 | jnp.array(range(args.vp)), length=args.vp) 335 | 336 | grad_raw = jax.tree_multimap(lambda t: jnp.nan_to_num(t), grad_raw) 337 | grad = clip_by_global_norm(grad_raw) 338 | grad = jax.tree_multimap(lambda t: t.mean(0), grad) 339 | 340 | params_update, optimizer_state = optimizer.update(grad, training_state.optimizer_state) 341 | policy_params = optax.apply_updates(training_state.policy_params, params_update) 342 | 343 | metrics = { 344 | 'grad_norm': optax.global_norm(grad_raw), 
345 | 'params_norm': optax.global_norm(policy_params), 346 | 'cf_state_loss': cf_state_loss.mean(0), 347 | "cf_gripper_loss": cf_gripper_loss.mean(0), 348 | 'entropy_loss': entropy_loss.mean(0), 349 | "cf_action_loss": cf_action_loss.mean(0) 350 | } 351 | return TrainingState( 352 | key=key, 353 | optimizer_state=optimizer_state, 354 | il_optimizer_state=training_state.il_optimizer_state, 355 | policy_params=policy_params), metrics 356 | 357 | def clip_by_global_norm(updates): 358 | g_norm = optax.global_norm(updates) 359 | trigger = g_norm < max_gradient_norm 360 | updates = jax.tree_multimap( 361 | lambda t: jnp.where(trigger, t, (t / g_norm) * max_gradient_norm), 362 | updates) 363 | 364 | return updates 365 | 366 | # compile training functions 367 | il_loss_grad = jax.grad(il_loss, has_aux=True) 368 | loss_grad = jax.grad(loss, has_aux=True) 369 | _minimize = jax.jit(_minimize) 370 | il_minimize = jax.jit(il_minimize) 371 | memory_profiler.defvjp(memory_profiler_fwd, memory_profiler_bwd) 372 | 373 | # prepare training 374 | sps = 0 375 | training_walltime = 0 376 | summary = {'params_norm': optax.global_norm(jax.tree_map(lambda x: x[0], policy_params))} 377 | training_state = TrainingState(key=key, optimizer_state=optimizer_state, 378 | il_optimizer_state=il_optimizer_state, 379 | policy_params=policy_params) 380 | 381 | # IL bootstrap 382 | if args.il: 383 | for it in range(2000): 384 | # il optimization 385 | training_state, summary, loss_val = il_minimize(training_state) 386 | logging.info('IL bootstrap iteration %s, loss %s, elapsed %s', it, np.array(loss_val), time.time() - xt) 387 | 388 | # main training loop 389 | for it in range(log_frequency + 1): 390 | logging.info('starting iteration %s %s', it, time.time() - xt) 391 | t = time.time() 392 | 393 | if it % 5 == 0: 394 | visualize(training_state.policy_params, visualize_first_state, key_debug) 395 | eval_policy(it, key_debug) 396 | 397 | if it == log_frequency: 398 | break 399 | 400 | # optimization 401 | t = time.time() 402 | num_steps = it * args.num_envs * args.ep_len 403 | training_state, metrics = _minimize(training_state, first_state) 404 | 405 | tf.summary.scalar('cf_state_loss', data=np.array(metrics['cf_state_loss']), step=num_steps) 406 | tf.summary.scalar('cf_gripper_loss', data=np.array(metrics['cf_gripper_loss']), step=num_steps) 407 | tf.summary.scalar('entropy_loss', data=np.array(metrics['entropy_loss']), step=num_steps) 408 | tf.summary.scalar('cf_action_loss', data=np.array(metrics['cf_action_loss']), step=num_steps) 409 | tf.summary.scalar('grad_norm', data=np.array(metrics['grad_norm']), step=num_steps) 410 | tf.summary.scalar('params_norm', data=np.array(metrics['params_norm']), step=num_steps) 411 | print("cf_state_loss", np.array(metrics['cf_state_loss']), 412 | "cf_action_loss", np.array(metrics['cf_action_loss']), 413 | "entropy_loss", np.array(metrics['entropy_loss']), 414 | "grad_norm", np.array(metrics['grad_norm'])) 415 | 416 | sps = (episode_length * num_envs) / (time.time() - t) 417 | training_walltime += time.time() - t 418 | 419 | # the cloth TrainingState is not pmapped and keeps no observation-normalizer parameters, 420 | # so the trained policy parameters are exported directly 421 | params = training_state.policy_params 422 | inference = make_inference_fn(core_env.observation_size, core_env.action_size, 423 | normalize_observations) 424 | 425 | 426 | 427 | def make_direct_optimization_model(parametric_action_distribution, obs_size): 428 | 429 | return networks.make_model( 430 | [512, 256, 
parametric_action_distribution.param_size], 431 | obs_size, 432 | activation=linen.swish) 433 | 434 | 435 | def make_inference_fn(observation_size, action_size, normalize_observations): 436 | """Creates params and inference function for the direct optimization agent.""" 437 | parametric_action_distribution = distribution.NormalTanhDistribution( 438 | event_size=action_size) 439 | _, obs_normalizer_apply_fn = normalization.make_data_and_apply_fn( 440 | observation_size, normalize_observations) 441 | policy_model = make_direct_optimization_model(parametric_action_distribution, 442 | observation_size) 443 | 444 | def inference_fn(params, obs, key): 445 | normalizer_params, params = params 446 | obs = obs_normalizer_apply_fn(normalizer_params, obs) 447 | action = parametric_action_distribution.sample( 448 | policy_model.apply(params, obs), key) 449 | return action 450 | 451 | return inference_fn 452 | 453 | 454 | if __name__ == '__main__': 455 | parser = argparse.ArgumentParser() 456 | 457 | parser.add_argument('--env', default="hang_cloth") 458 | parser.add_argument('--ep_len', default=80, type=int) 459 | parser.add_argument('--num_envs', default=10, type=int) 460 | parser.add_argument('--vp', default=5, type=int, help="virtual p size") 461 | parser.add_argument('--lr', default=1e-4, type=float) 462 | parser.add_argument('--trunc_len', default=10, type=int) 463 | parser.add_argument('--max_it', default=3000, type=int) 464 | parser.add_argument('--max_grad_norm', default=0.3, type=float) 465 | parser.add_argument('--reverse_discount', default=1.0, type=float) 466 | parser.add_argument('--entropy_factor', default=0, type=float) 467 | parser.add_argument('--deviation_factor', default=1.0, type=float) 468 | parser.add_argument('--action_cf_factor', default=0, type=float) 469 | parser.add_argument('--il', default=1, type=float) 470 | parser.add_argument('--seed', default=1, type=int) 471 | 472 | args = parser.parse_args() 473 | 474 | envs = { 475 | "hang_cloth": make_env_hang_cloth 476 | } 477 | 478 | train(environment_fn=envs[args.env], 479 | episode_length=args.ep_len, 480 | num_envs=args.num_envs, 481 | num_eval_envs=128, 482 | learning_rate=args.lr, 483 | normalize_observations=True, 484 | log_frequency=args.max_it, 485 | truncation_length=args.trunc_len, 486 | max_gradient_norm=args.max_grad_norm, 487 | seed=args.seed) 488 | -------------------------------------------------------------------------------- /policy/util/ILD_rollout.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import brax 5 | import jax 6 | import numpy as np 7 | from brax import envs 8 | from brax.io import html 9 | from brax.training import distribution, normalization, ppo, sac 10 | from policy.brax_task.train_on_policy import make_direct_optimization_model 11 | import streamlit.components.v1 as components 12 | 13 | my_path = os.path.dirname(os.path.abspath(__file__)) 14 | 15 | 16 | def rollout(env_name, num_steps=128, use_expert=False, seed=1): 17 | env_fn = envs.create_fn(env_name) 18 | env = env_fn(batch_size=1, episode_length=num_steps * 2, auto_reset=False) 19 | env.step = jax.jit(env.step) 20 | 21 | # initialize policy 22 | if not use_expert: 23 | parametric_action_distribution = distribution.NormalTanhDistribution(event_size=env.action_size) 24 | policy_model = make_direct_optimization_model(parametric_action_distribution, env.observation_size) 25 | policy_model.apply = jax.jit(policy_model.apply) 26 | with open(f'{env_name}_params.pkl', 
'rb') as f: 27 | normalizer_params, params = pickle.load(f) 28 | else: 29 | if env_name == "humanoid": 30 | inference = sac.make_inference_fn(env.observation_size, env.action_size, True) 31 | else: 32 | inference = ppo.make_inference_fn(env.observation_size, env.action_size, True) 33 | inference = jax.jit(inference) 34 | with open(f"{my_path}/../brax_task/expert_multi_traj/{env_name}_params.pickle", "rb") as f: 35 | decoded_params = pickle.load(f) 36 | 37 | _, _, obs_normalizer_apply_fn = ( 38 | normalization.create_observation_normalizer( 39 | env.observation_size, 40 | True, 41 | num_leading_batch_dims=2, 42 | pmap_to_devices=1)) 43 | 44 | key = jax.random.PRNGKey(seed) 45 | state = env.reset(jax.random.PRNGKey(seed)) 46 | 47 | def do_one_step_eval(carry, unused_target_t): 48 | state, key = carry 49 | key, key_sample = jax.random.split(key) 50 | 51 | if not use_expert: 52 | normalized_obs = obs_normalizer_apply_fn(normalizer_params, state.obs) 53 | logits = policy_model.apply(params, normalized_obs) 54 | action = parametric_action_distribution.sample(logits, key) 55 | else: 56 | action = inference(decoded_params, state.obs, key) 57 | 58 | nstate = env.step(state, action) 59 | return (nstate, key), state 60 | 61 | _, state_list = jax.lax.scan( 62 | do_one_step_eval, (state, key), (), 63 | length=num_steps) 64 | 65 | print(f'{env_name} reward: {state_list.reward.sum():.2f}') 66 | visualize(state_list, env_name, num_steps) 67 | 68 | 69 | def visualize(state_list, env_name, num_steps): 70 | env = envs.create(env_name=env_name, episode_length=num_steps) 71 | 72 | visual_states = [] 73 | for i in range(state_list.qp.ang.shape[0]): 74 | qp_state = brax.QP(np.array(state_list.qp.pos[i, 0]), 75 | np.array(state_list.qp.rot[i, 0]), 76 | np.array(state_list.qp.vel[i, 0]), 77 | np.array(state_list.qp.ang[i, 0])) 78 | visual_states.append(qp_state) 79 | 80 | html_string = html.render(env.sys, visual_states) 81 | components.html(html_string, height=500) 82 | 83 | 84 | if __name__ == '__main__': 85 | rollout("humanoid", num_steps=128, use_expert=False, seed=0) 86 | -------------------------------------------------------------------------------- /policy/util/expert_rollout.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pickle 4 | 5 | import jax 6 | import numpy as np 7 | 8 | from brax import envs 9 | from brax.training import ppo, sac 10 | 11 | my_path = os.path.dirname(os.path.abspath(__file__)) 12 | logging.getLogger().setLevel(logging.INFO) 13 | 14 | 15 | def train_ppo(env_name, algo="ppo"): 16 | if algo == "ppo": 17 | inference, params, metrics = ppo.train( 18 | envs.create_fn(env_name), 19 | num_timesteps=1e7, 20 | episode_length=128, 21 | num_envs=64, 22 | learning_rate=3e-4, 23 | entropy_cost=1e-2, 24 | discounting=0.95, 25 | unroll_length=5, 26 | batch_size=64, 27 | num_minibatches=8, 28 | num_update_epochs=4, 29 | log_frequency=5, 30 | normalize_observations=True, 31 | reward_scaling=10) 32 | else: 33 | inference, params, metrics = sac.train( 34 | envs.create_fn(env_name), 35 | num_timesteps=5242880 * 2, 36 | episode_length=1000, 37 | action_repeat=1, 38 | num_envs=64, 39 | learning_rate=0.0006, 40 | discounting=0.99, 41 | batch_size=256, 42 | log_frequency=131012, 43 | normalize_observations=True, 44 | reward_scaling=10, 45 | min_replay_size=8192, 46 | max_replay_size=1048576, 47 | grad_updates_per_step=0.125, 48 | seed=2) 49 | 50 | # save paras into pickle 51 | with 
open(f"{my_path}/../brax_task/expert_multi_traj/{env_name}_params.pickle", "wb") as f: 52 | pickle.dump(params, f) 53 | 54 | 55 | def rollout(env_name, num_steps, num_envs, algo="ppo"): 56 | env_fn = envs.create_fn(env_name) 57 | env = env_fn(batch_size=num_envs * 10, episode_length=num_steps * 2) 58 | env.step = jax.jit(env.step) 59 | 60 | algo = "sac" if env_name == "humanoid" else algo 61 | if algo == "ppo": 62 | inference = ppo.make_inference_fn(env.observation_size, env.action_size, True) 63 | else: 64 | inference = sac.make_inference_fn(env.observation_size, env.action_size, True) 65 | inference = jax.jit(inference) 66 | # load pickle file 67 | with open(f"{my_path}/../brax_task/expert_multi_traj/{env_name}_params.pickle", "rb") as f: 68 | decoded_params = pickle.load(f) 69 | 70 | traj_states = [] 71 | traj_actions = [] 72 | traj_obs = [] 73 | traj_rewards = 0 74 | traj_done = 0 75 | 76 | state = env.reset(jax.random.PRNGKey(0)) 77 | for j in range(num_steps): 78 | print(env_name, "step: ", j) 79 | action = inference(decoded_params, state.obs, jax.random.PRNGKey(0)) 80 | state = env.step(state, action) 81 | qp = np.concatenate([state.qp.pos.reshape((state.qp.pos.shape[0], -1)), 82 | state.qp.rot.reshape((state.qp.rot.shape[0], -1)), 83 | state.qp.vel.reshape((state.qp.vel.shape[0], -1)), 84 | state.qp.ang.reshape((state.qp.ang.shape[0], -1))], axis=-1) 85 | 86 | traj_states.append(qp) 87 | traj_actions.append(action) 88 | traj_obs.append(state.obs) 89 | traj_rewards += state.reward 90 | traj_done += state.done 91 | 92 | os.makedirs(f"{my_path}/../brax_task/expert_multi_traj", exist_ok=True) 93 | print(env_name, "traj reward: ", traj_rewards) 94 | print(env_name, "traj done: ", traj_done) 95 | 96 | traj_rewards = np.array(traj_rewards) 97 | traj_states = np.array(traj_states) 98 | traj_actions = np.array(traj_actions) 99 | traj_obs = np.array(traj_obs) 100 | traj_done = np.array(traj_done) 101 | 102 | # assert traj_done.sum() <= 2 103 | # filter by traj done 104 | traj_states = traj_states[:, traj_done == 0] 105 | traj_actions = traj_actions[:, traj_done == 0] 106 | traj_obs = traj_obs[:, traj_done == 0] 107 | traj_rewards = traj_rewards[traj_done == 0] 108 | 109 | # get_idx from top k rewards 110 | top_k_idx = np.argsort(traj_rewards)[-num_envs:] 111 | print(env_name, "top k rewards: ", traj_rewards[top_k_idx]) 112 | 113 | np.save(f"{my_path}/../brax_task/expert_multi_traj/%s_traj_state.npy" % env_name, traj_states[:, top_k_idx]) 114 | np.save(f"{my_path}/../brax_task/expert_multi_traj/%s_traj_action.npy" % env_name, traj_actions[:, top_k_idx]) 115 | np.save(f"{my_path}/../brax_task/expert_multi_traj/%s_traj_observation.npy" % env_name, traj_obs[:, top_k_idx]) 116 | np.save(f"{my_path}/../brax_task/expert_multi_traj/%s_traj_reward.npy" % env_name, traj_rewards[top_k_idx]) 117 | 118 | np.save(f"{my_path}/../brax_task/expert/%s_traj_state.npy" % env_name, traj_states[:, top_k_idx[-1]]) 119 | np.save(f"{my_path}/../brax_task/expert/%s_traj_action.npy" % env_name, traj_actions[:, top_k_idx[-1]]) 120 | np.save(f"{my_path}/../brax_task/expert/%s_traj_obs.npy" % env_name, traj_obs[:, top_k_idx[-1]]) 121 | np.save(f"{my_path}/../brax_task/expert/%s_traj_reward.npy" % env_name, traj_rewards[top_k_idx[-1]]) 122 | return traj_states, traj_actions, traj_obs 123 | 124 | 125 | def print_demonstration_reward(): 126 | # Ant & Hopper & Humanoid & Reacher & Walker2d & Swimmer & Inverted pendulum & Acrobot 127 | env_names = ["ant", "hopper", "humanoid", "reacher", "walker2d", "swimmer", 
"inverted_pendulum", "acrobot"] 128 | line = "" 129 | for env_name in env_names: 130 | print(env_name) 131 | traj_rewards = np.load(f"{my_path}/../brax_task/expert_multi_traj/{env_name}_traj_reward.npy") 132 | line += f" & {traj_rewards.mean():.2f} $\pm$ {traj_rewards.std():.2f}" 133 | print(line) 134 | 135 | 136 | if __name__ == '__main__': 137 | # print_demonstration_reward() 138 | # train_ppo("humanoid", algo="sac") 139 | rollout("humanoid", num_steps=128, num_envs=16) 140 | # env_names = ["ant", "walker2d", "humanoid", "acrobot", "reacher", "hopper", "swimmer", "inverted_pendulum"] 141 | # for env_name in env_names: 142 | # rollout(env_name, num_steps=128, num_envs=16) 143 | --------------------------------------------------------------------------------