├── .pre-commit-config.yaml ├── requirements.txt ├── .gitignore ├── Makefile ├── LICENSE ├── .github └── workflows │ └── test.yml ├── pyproject.toml ├── convert.py ├── README.md ├── train.py └── train.ipynb /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/kynan/nbstripout 3 | rev: 0.8.1 4 | hooks: 5 | - id: nbstripout 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # requirements.txt 2 | 3 | # Training dependencies 4 | ksim==0.1.3 5 | xax==0.3.0 6 | mujoco-scenes 7 | 8 | # Inference dependencies 9 | kinfer[jax] 10 | 11 | # Jupyter Notebook dependencies 12 | jupyter 13 | ipykernel 14 | pre-commit 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # .gitignore 2 | 3 | # Python 4 | *.py[oc] 5 | __pycache__/ 6 | *.egg-info 7 | .eggs/ 8 | .mypy_cache/* 9 | .pyre/ 10 | .pytest_cache/ 11 | .ruff_cache/ 12 | .dmypy.json 13 | 14 | # Jupyter 15 | .ipynb_checkpoints/ 16 | 17 | # Databases 18 | *.db 19 | 20 | # Logs 21 | rollouts/ 22 | 23 | # Build artifacts 24 | build/ 25 | dist/ 26 | *.so 27 | out*/ 28 | **/run_*/ 29 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Makefile 2 | 3 | py-files := $(shell find . 
-name '*.py' -not -path "*/run_*/*" -not -path "*/build/*") 4 | 5 | install: 6 | @pip install --upgrade --upgrade-strategy eager -r requirements.txt 7 | .PHONY: install 8 | 9 | install-dev: 10 | @pip install ruff mypy 11 | .PHONY: install-dev 12 | 13 | format: 14 | @ruff format $(py-files) 15 | @ruff check --fix $(py-files) 16 | .PHONY: format 17 | 18 | # Runs ruff lint checks and mypy type checks (invoked by the CI workflow). 19 | static-checks: 20 | @mkdir -p .mypy_cache 21 | @ruff check $(py-files) 22 | @mypy --install-types --non-interactive $(py-files) 23 | .PHONY: static-checks 24 | 25 | notebook: 26 | jupyter notebook --ip=0.0.0.0 --port=8888 --no-browser 27 | .PHONY: notebook 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 K-Scale Labs 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Python Checks 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | types: 11 | - opened 12 | - reopened 13 | - synchronize 14 | - ready_for_review 15 | 16 | concurrency: 17 | group: tests-${{ github.head_ref || github.run_id }} 18 | cancel-in-progress: true 19 | 20 | jobs: 21 | run-base-tests: 22 | timeout-minutes: 10 23 | runs-on: ubuntu-latest 24 | steps: 25 | - name: Check out repository 26 | uses: actions/checkout@v4 27 | 28 | - name: Set up Python 29 | uses: actions/setup-python@v4 30 | with: 31 | python-version: "3.12" 32 | 33 | - name: Restore cache 34 | id: restore-cache 35 | uses: actions/cache/restore@v3 36 | with: 37 | path: | 38 | ${{ env.pythonLocation }} 39 | .mypy_cache/ 40 | key: python-requirements-${{ env.pythonLocation }}-${{ github.event.pull_request.base.sha || github.sha }} 41 | restore-keys: | 42 | python-requirements-${{ env.pythonLocation }} 43 | python-requirements- 44 | 45 | - name: Install package 46 | run: | 47 | make install 48 | make install-dev 49 | 50 | - name: Run static checks 51 | run: | 52 | make static-checks 53 | 54 | - name: Save cache 55 | uses: actions/cache/save@v3 56 | if: github.ref == 'refs/heads/master' 57 | with: 58 | path: | 59 | ${{ env.pythonLocation }} 60 | .mypy_cache/ 61 | key: ${{ steps.restore-cache.outputs.cache-primary-key }} 62 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.pytest.ini_options] 2 | 3 | addopts = "-rx -rf -x -q --full-trace" 4 | testpaths = ["tests"] 5 | 6 | markers = [ 7 | "slow: Marks test as being slow", 8 | ] 9 | 10 | [tool.mypy] 11 | 12 | pretty = true 13 | show_column_numbers = true 14 | 
show_error_context = true 15 | show_error_codes = true 16 | show_traceback = true 17 | disallow_untyped_defs = true 18 | strict_equality = true 19 | allow_redefinition = true 20 | 21 | warn_unused_ignores = true 22 | warn_redundant_casts = true 23 | 24 | incremental = true 25 | namespace_packages = false 26 | 27 | [[tool.mypy.overrides]] 28 | 29 | module = [ 30 | "_pytest.*", 31 | "distrax.*", 32 | "equinox.*", 33 | "glfw.*", 34 | "mujoco.*", 35 | "optax.*", 36 | "scipy.*", 37 | "bvhio.*", 38 | "glm.*", 39 | "imageio.*", 40 | "tensorflow.*" 41 | ] 42 | ignore_missing_imports = true 43 | 44 | [tool.isort] 45 | 46 | profile = "black" 47 | 48 | [tool.ruff] 49 | 50 | line-length = 120 51 | target-version = "py310" 52 | 53 | [tool.ruff.format] 54 | 55 | quote-style = "double" 56 | docstring-code-format = true 57 | 58 | [tool.ruff.lint] 59 | 60 | select = ["ANN", "D", "E", "F", "G", "I", "N", "PGH", "PLC", "PLE", "PLR", "PLW", "W"] 61 | 62 | ignore = [ 63 | "D101", "D102", "D103", "D104", "D105", "D106", "D107", 64 | "N812", "N817", 65 | "PLR0911", "PLR0912", "PLR0913", "PLR0915", "PLR2004", 66 | "PLW0603", "PLW2901", 67 | ] 68 | 69 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 70 | 71 | [tool.ruff.lint.per-file-ignores] 72 | 73 | "__init__.py" = ["E402", "F401", "F403", "F811"] 74 | 75 | [tool.ruff.lint.isort] 76 | 77 | known-first-party = ["benchmark", "tests"] 78 | combine-as-imports = true 79 | 80 | [tool.ruff.lint.pydocstyle] 81 | 82 | convention = "google" 83 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | """Converts a checkpoint to a deployable model.""" 2 | 3 | import argparse 4 | from pathlib import Path 5 | 6 | import jax 7 | import jax.numpy as jnp 8 | import ksim 9 | from jaxtyping import Array 10 | from kinfer.export.jax import export_fn 11 | from kinfer.export.serialize import pack 12 | from 
kinfer.rust_bindings import PyModelMetadata 13 | 14 | from train import HumanoidWalkingTask, Model 15 | 16 | 17 | def main() -> None: 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("checkpoint_path", type=str) 20 | parser.add_argument("output_path", type=str) 21 | args = parser.parse_args() 22 | 23 | if not (ckpt_path := Path(args.checkpoint_path)).exists(): 24 | raise FileNotFoundError(f"Checkpoint path {ckpt_path} does not exist") 25 | 26 | task: HumanoidWalkingTask = HumanoidWalkingTask.load_task(ckpt_path) 27 | model: Model = task.load_ckpt(ckpt_path, part="model")[0] 28 | 29 | # Loads the Mujoco model and gets the joint names. 30 | mujoco_model = task.get_mujoco_model() 31 | joint_names = ksim.get_joint_names_in_order(mujoco_model)[1:] # Removes the root joint. 32 | 33 | # Constant values. 34 | carry_shape = (task.config.depth, task.config.hidden_size) 35 | 36 | metadata = PyModelMetadata( 37 | joint_names=joint_names, 38 | num_commands=None, 39 | carry_size=carry_shape, 40 | ) 41 | 42 | @jax.jit 43 | def init_fn() -> Array: 44 | return jnp.zeros(carry_shape) 45 | 46 | @jax.jit 47 | def step_fn( 48 | joint_angles: Array, 49 | joint_angular_velocities: Array, 50 | projected_gravity: Array, 51 | accelerometer: Array, 52 | gyroscope: Array, 53 | time: Array, 54 | carry: Array, 55 | ) -> tuple[Array, Array]: 56 | obs = jnp.concatenate( 57 | [ 58 | jnp.sin(time), 59 | jnp.cos(time), 60 | joint_angles, 61 | joint_angular_velocities, 62 | projected_gravity, 63 | accelerometer, 64 | gyroscope, 65 | ], 66 | axis=-1, 67 | ) 68 | dist, carry = model.actor.forward(obs, carry) 69 | return dist.mode(), carry 70 | 71 | init_onnx = export_fn( 72 | model=init_fn, 73 | metadata=metadata, 74 | ) 75 | 76 | step_onnx = export_fn( 77 | model=step_fn, 78 | metadata=metadata, 79 | ) 80 | 81 | kinfer_model = pack( 82 | init_fn=init_onnx, 83 | step_fn=step_onnx, 84 | metadata=metadata, 85 | ) 86 | 87 | # Saves the resulting model. 
88 | (output_path := Path(args.output_path)).parent.mkdir(parents=True, exist_ok=True) 89 | with open(output_path, "wb") as f: 90 | f.write(kinfer_model) 91 | 92 | 93 | if __name__ == "__main__": 94 | main() 95 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

K-Sim Gym

3 |

Train and deploy your own humanoid robot controller in 700 lines of Python

4 |

5 | Tutorial · 6 | Leaderboard · 7 | Documentation 8 |
9 | K-Sim Examples · 10 | Joystick Example 11 |

12 | 13 | https://github.com/user-attachments/assets/82e5e998-1d62-43e2-ae52-864af6e72629 14 | 15 |
16 | 17 | ## Getting Started 18 | 19 | You can use this repository as a GitHub template or run it in Google Colab. 20 | 21 | ### Google Colab 22 | 23 | You can quickly try out the humanoid benchmark by running the [training notebook](https://colab.research.google.com/github/kscalelabs/ksim-gym/blob/master/train.ipynb) in Google Colab. 24 | 25 | ### On your own GPU 26 | 27 | 1. Read through the [current leaderboard](https://www.kscale.dev/benchmarks) submissions and the [ksim examples](https://github.com/kscalelabs/ksim/tree/master/examples). 28 | 2. Create a new repository from this template by clicking [here](https://github.com/new?template_name=ksim-gym&template_owner=kscalelabs). 29 | 3. Clone the new repository you created from this template: 30 | 31 | ```bash 32 | git clone git@github.com:<YOUR_GITHUB_USERNAME>/ksim-gym.git 33 | cd ksim-gym 34 | ``` 35 | 36 | 4. Create a new Python environment (we require Python 3.11 or later and recommend using [conda](https://docs.conda.io/projects/conda/en/stable/user-guide/getting-started.html)). 37 | 5. Install the package with its dependencies: 38 | 39 | ```bash 40 | pip install -r requirements.txt 41 | pip install 'jax[cuda12]' # On a GPU machine, install the JAX CUDA libraries 42 | python -c "import jax; print(jax.default_backend())" # Should print "gpu" 43 | ``` 44 | 45 | 6. Train a policy: 46 | - Your robot should be walking within ~80 training steps, which takes 30 minutes on an RTX 4090 GPU. 47 | - Training runs indefinitely unless you set the `max_steps` argument. You can also use `Ctrl+C` to stop it. 48 | - Click on the TensorBoard link in the terminal to visualize the current run's training logs and videos. 49 | - List all the available arguments with `python -m train --help`. 50 | ```bash 51 | python -m train 52 | ``` 53 | ```bash 54 | # You can override default arguments like this 55 | python -m train max_steps=100 56 | ``` 57 | 7.
To see the TensorBoard logs for all your runs: 58 | ```bash 59 | tensorboard --logdir humanoid_walking_task 60 | ``` 61 | 8. To view your trained checkpoint in the interactive viewer: 62 | - Use the mouse to move the camera around. 63 | - Hold `Ctrl` and double-click to select a body on the robot, then left- or right-click to apply forces to it. 64 | ```bash 65 | python -m train run_mode=view load_from_ckpt_path=humanoid_walking_task/run_<number>/checkpoints/ckpt.bin 66 | ``` 67 | 68 | 9. Convert your trained checkpoint to a `kinfer` model, which can be deployed on a real robot: 69 | 70 | ```bash 71 | python -m convert /path/to/ckpt.bin /path/to/model.kinfer 72 | ``` 73 | 74 | 10. Visualize the converted model in [`kinfer-sim`](https://docs.kscale.dev/docs/k-infer): 75 | 76 | ```bash 77 | kinfer-sim /path/to/model.kinfer kbot --start-height 1.2 --save-video video.mp4 78 | ``` 79 | 80 | 11. Commit the `kinfer` model and the recorded video to this repository. 81 | 12. Push your code and model to your repository, and make sure the repository is public (you may need to use [Git LFS](https://git-lfs.com)). 82 | 13. Post a message with a link to your repository in the "【🧠】submissions" channel on our [Discord](https://url.kscale.dev/discord). 83 | 14. Wait for one of us to run it on the real robot — this should take about a day, but if we are dragging our feet, please message us on Discord. 84 | 15. Voilà! Your name will now appear on our [leaderboard](https://url.kscale.dev/leaderboard). 85 | 86 | ## Troubleshooting 87 | 88 | If you encounter issues, please consult the [ksim documentation](https://docs.kscale.dev/docs/ksim#/) or reach out to us on [Discord](https://url.kscale.dev/discord).
89 | 90 | ## Tips and Tricks 91 | 92 | To see all the available command line arguments, use the command: 93 | 94 | ```bash 95 | python -m train --help 96 | ``` 97 | 98 | To visualize running your model without using `kinfer-sim`, use the command: 99 | 100 | ```bash 101 | python -m train run_mode=view 102 | ``` 103 | 104 | To see an example of a locomotion task with more complex reward tuning, see our [kbot-joystick](https://github.com/kscalelabs/kbot-joystick) task which was generated from this template. It also contains a pretrained checkpoint that you can initialize training from by running 105 | 106 | ```bash 107 | python -m train load_from_ckpt_path=assets/ckpt.bin 108 | ``` 109 | 110 | You can also visualize the pre-trained model by combining these two commands: 111 | 112 | ```bash 113 | python -m train load_from_ckpt_path=assets/ckpt.bin run_mode=view 114 | ``` 115 | 116 | If you want to use the Jupyter notebook and don't want to commit your training logs, we suggest using [pre-commit](https://pre-commit.com/) to clean the notebook before committing: 117 | 118 | ```bash 119 | pip install pre-commit 120 | pre-commit install 121 | ``` 122 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """Defines simple task for training a walking policy for the default humanoid.""" 2 | 3 | import asyncio 4 | import functools 5 | import math 6 | from dataclasses import dataclass 7 | from typing import Self 8 | 9 | import attrs 10 | import distrax 11 | import equinox as eqx 12 | import jax 13 | import jax.numpy as jnp 14 | import ksim 15 | import mujoco 16 | import mujoco_scenes 17 | import mujoco_scenes.mjcf 18 | import optax 19 | import xax 20 | from jaxtyping import Array, PRNGKeyArray 21 | 22 | # These are in the order of the neural network outputs. 
23 | ZEROS: list[tuple[str, float]] = [ 24 | ("dof_right_shoulder_pitch_03", 0.0), 25 | ("dof_right_shoulder_roll_03", math.radians(-10.0)), 26 | ("dof_right_shoulder_yaw_02", 0.0), 27 | ("dof_right_elbow_02", math.radians(90.0)), 28 | ("dof_right_wrist_00", 0.0), 29 | ("dof_left_shoulder_pitch_03", 0.0), 30 | ("dof_left_shoulder_roll_03", math.radians(10.0)), 31 | ("dof_left_shoulder_yaw_02", 0.0), 32 | ("dof_left_elbow_02", math.radians(-90.0)), 33 | ("dof_left_wrist_00", 0.0), 34 | ("dof_right_hip_pitch_04", math.radians(-20.0)), 35 | ("dof_right_hip_roll_03", math.radians(-0.0)), 36 | ("dof_right_hip_yaw_03", 0.0), 37 | ("dof_right_knee_04", math.radians(-50.0)), 38 | ("dof_right_ankle_02", math.radians(30.0)), 39 | ("dof_left_hip_pitch_04", math.radians(20.0)), 40 | ("dof_left_hip_roll_03", math.radians(0.0)), 41 | ("dof_left_hip_yaw_03", 0.0), 42 | ("dof_left_knee_04", math.radians(50.0)), 43 | ("dof_left_ankle_02", math.radians(-30.0)), 44 | ] 45 | 46 | 47 | @dataclass 48 | class HumanoidWalkingTaskConfig(ksim.PPOConfig): 49 | """Config for the humanoid walking task.""" 50 | 51 | # Model parameters. 52 | hidden_size: int = xax.field( 53 | value=128, 54 | help="The hidden size for the MLPs.", 55 | ) 56 | depth: int = xax.field( 57 | value=5, 58 | help="The depth for the MLPs.", 59 | ) 60 | num_mixtures: int = xax.field( 61 | value=5, 62 | help="The number of mixtures for the actor.", 63 | ) 64 | var_scale: float = xax.field( 65 | value=0.5, 66 | help="The scale for the standard deviations of the actor.", 67 | ) 68 | use_acc_gyro: bool = xax.field( 69 | value=True, 70 | help="Whether to use the IMU acceleration and gyroscope observations.", 71 | ) 72 | 73 | # Optimizer parameters. 
74 | learning_rate: float = xax.field( 75 | value=3e-4, 76 | help="Learning rate for PPO.", 77 | ) 78 | adam_weight_decay: float = xax.field( 79 | value=1e-5, 80 | help="Weight decay for the Adam optimizer.", 81 | ) 82 | 83 | 84 | @attrs.define(frozen=True, kw_only=True) 85 | class JointPositionPenalty(ksim.JointDeviationPenalty): 86 | @classmethod 87 | def create_from_names( 88 | cls, 89 | names: list[str], 90 | physics_model: ksim.PhysicsModel, 91 | scale: float = -1.0, 92 | scale_by_curriculum: bool = False, 93 | ) -> Self: 94 | zeros = {k: v for k, v in ZEROS} 95 | joint_targets = [zeros[name] for name in names] 96 | 97 | return cls.create( 98 | physics_model=physics_model, 99 | joint_names=tuple(names), 100 | joint_targets=tuple(joint_targets), 101 | scale=scale, 102 | scale_by_curriculum=scale_by_curriculum, 103 | ) 104 | 105 | 106 | @attrs.define(frozen=True, kw_only=True) 107 | class BentArmPenalty(JointPositionPenalty): 108 | @classmethod 109 | def create_penalty( 110 | cls, 111 | physics_model: ksim.PhysicsModel, 112 | scale: float = -1.0, 113 | scale_by_curriculum: bool = False, 114 | ) -> Self: 115 | return cls.create_from_names( 116 | names=[ 117 | "dof_right_shoulder_pitch_03", 118 | "dof_right_shoulder_roll_03", 119 | "dof_right_shoulder_yaw_02", 120 | "dof_right_elbow_02", 121 | "dof_right_wrist_00", 122 | "dof_left_shoulder_pitch_03", 123 | "dof_left_shoulder_roll_03", 124 | "dof_left_shoulder_yaw_02", 125 | "dof_left_elbow_02", 126 | "dof_left_wrist_00", 127 | ], 128 | physics_model=physics_model, 129 | scale=scale, 130 | scale_by_curriculum=scale_by_curriculum, 131 | ) 132 | 133 | 134 | @attrs.define(frozen=True, kw_only=True) 135 | class StraightLegPenalty(JointPositionPenalty): 136 | @classmethod 137 | def create_penalty( 138 | cls, 139 | physics_model: ksim.PhysicsModel, 140 | scale: float = -1.0, 141 | scale_by_curriculum: bool = False, 142 | ) -> Self: 143 | return cls.create_from_names( 144 | names=[ 145 | "dof_left_hip_roll_03", 146 | 
"dof_left_hip_yaw_03", 147 | "dof_right_hip_roll_03", 148 | "dof_right_hip_yaw_03", 149 | ], 150 | physics_model=physics_model, 151 | scale=scale, 152 | scale_by_curriculum=scale_by_curriculum, 153 | ) 154 | 155 | 156 | class Actor(eqx.Module): 157 | """Actor for the walking task.""" 158 | 159 | input_proj: eqx.nn.Linear 160 | rnns: tuple[eqx.nn.GRUCell, ...] 161 | output_proj: eqx.nn.Linear 162 | num_inputs: int = eqx.static_field() 163 | num_outputs: int = eqx.static_field() 164 | num_mixtures: int = eqx.static_field() 165 | min_std: float = eqx.static_field() 166 | max_std: float = eqx.static_field() 167 | var_scale: float = eqx.static_field() 168 | 169 | def __init__( 170 | self, 171 | key: PRNGKeyArray, 172 | *, 173 | num_inputs: int, 174 | num_outputs: int, 175 | min_std: float, 176 | max_std: float, 177 | var_scale: float, 178 | hidden_size: int, 179 | num_mixtures: int, 180 | depth: int, 181 | ) -> None: 182 | # Project input to hidden size 183 | key, input_proj_key = jax.random.split(key) 184 | self.input_proj = eqx.nn.Linear( 185 | in_features=num_inputs, 186 | out_features=hidden_size, 187 | key=input_proj_key, 188 | ) 189 | 190 | # Create RNN layer 191 | key, rnn_key = jax.random.split(key) 192 | rnn_keys = jax.random.split(rnn_key, depth) 193 | self.rnns = tuple( 194 | [ 195 | eqx.nn.GRUCell( 196 | input_size=hidden_size, 197 | hidden_size=hidden_size, 198 | key=rnn_key, 199 | ) 200 | for rnn_key in rnn_keys 201 | ] 202 | ) 203 | 204 | # Project to output 205 | self.output_proj = eqx.nn.Linear( 206 | in_features=hidden_size, 207 | out_features=num_outputs * 3 * num_mixtures, 208 | key=key, 209 | ) 210 | 211 | self.num_inputs = num_inputs 212 | self.num_outputs = num_outputs 213 | self.num_mixtures = num_mixtures 214 | self.min_std = min_std 215 | self.max_std = max_std 216 | self.var_scale = var_scale 217 | 218 | def forward(self, obs_n: Array, carry: Array) -> tuple[distrax.Distribution, Array]: 219 | x_n = self.input_proj(obs_n) 220 | out_carries = 
[] 221 | for i, rnn in enumerate(self.rnns): 222 | x_n = rnn(x_n, carry[i]) 223 | out_carries.append(x_n) 224 | out_n = self.output_proj(x_n) 225 | 226 | # Reshape the output to be a mixture of gaussians. 227 | slice_len = self.num_outputs * self.num_mixtures 228 | mean_nm = out_n[..., :slice_len].reshape(self.num_outputs, self.num_mixtures) 229 | std_nm = out_n[..., slice_len : slice_len * 2].reshape(self.num_outputs, self.num_mixtures) 230 | logits_nm = out_n[..., slice_len * 2 :].reshape(self.num_outputs, self.num_mixtures) 231 | 232 | # Softplus and clip to ensure positive standard deviations. 233 | std_nm = jnp.clip((jax.nn.softplus(std_nm) + self.min_std) * self.var_scale, max=self.max_std) 234 | 235 | # Apply bias to the means. 236 | mean_nm = mean_nm + jnp.array([v for _, v in ZEROS])[:, None] 237 | 238 | dist_n = ksim.MixtureOfGaussians(means_nm=mean_nm, stds_nm=std_nm, logits_nm=logits_nm) 239 | 240 | return dist_n, jnp.stack(out_carries, axis=0) 241 | 242 | 243 | class Critic(eqx.Module): 244 | """Critic for the walking task.""" 245 | 246 | input_proj: eqx.nn.Linear 247 | rnns: tuple[eqx.nn.GRUCell, ...] 
248 | output_proj: eqx.nn.Linear 249 | num_inputs: int = eqx.static_field() 250 | 251 | def __init__( 252 | self, 253 | key: PRNGKeyArray, 254 | *, 255 | num_inputs: int, 256 | hidden_size: int, 257 | depth: int, 258 | ) -> None: 259 | num_outputs = 1 260 | 261 | # Project input to hidden size 262 | key, input_proj_key = jax.random.split(key) 263 | self.input_proj = eqx.nn.Linear( 264 | in_features=num_inputs, 265 | out_features=hidden_size, 266 | key=input_proj_key, 267 | ) 268 | 269 | # Create RNN layer 270 | key, rnn_key = jax.random.split(key) 271 | rnn_keys = jax.random.split(rnn_key, depth) 272 | self.rnns = tuple( 273 | [ 274 | eqx.nn.GRUCell( 275 | input_size=hidden_size, 276 | hidden_size=hidden_size, 277 | key=rnn_key, 278 | ) 279 | for rnn_key in rnn_keys 280 | ] 281 | ) 282 | 283 | # Project to output 284 | self.output_proj = eqx.nn.Linear( 285 | in_features=hidden_size, 286 | out_features=num_outputs, 287 | key=key, 288 | ) 289 | 290 | self.num_inputs = num_inputs 291 | 292 | def forward(self, obs_n: Array, carry: Array) -> tuple[Array, Array]: 293 | x_n = self.input_proj(obs_n) 294 | out_carries = [] 295 | for i, rnn in enumerate(self.rnns): 296 | x_n = rnn(x_n, carry[i]) 297 | out_carries.append(x_n) 298 | out_n = self.output_proj(x_n) 299 | 300 | return out_n, jnp.stack(out_carries, axis=0) 301 | 302 | 303 | class Model(eqx.Module): 304 | actor: Actor 305 | critic: Critic 306 | 307 | def __init__( 308 | self, 309 | key: PRNGKeyArray, 310 | *, 311 | num_actor_inputs: int, 312 | num_actor_outputs: int, 313 | num_critic_inputs: int, 314 | min_std: float, 315 | max_std: float, 316 | var_scale: float, 317 | hidden_size: int, 318 | num_mixtures: int, 319 | depth: int, 320 | ) -> None: 321 | actor_key, critic_key = jax.random.split(key) 322 | self.actor = Actor( 323 | actor_key, 324 | num_inputs=num_actor_inputs, 325 | num_outputs=num_actor_outputs, 326 | min_std=min_std, 327 | max_std=max_std, 328 | var_scale=var_scale, 329 | hidden_size=hidden_size, 330 
| num_mixtures=num_mixtures, 331 | depth=depth, 332 | ) 333 | self.critic = Critic( 334 | critic_key, 335 | hidden_size=hidden_size, 336 | depth=depth, 337 | num_inputs=num_critic_inputs, 338 | ) 339 | 340 | 341 | class HumanoidWalkingTask(ksim.PPOTask[HumanoidWalkingTaskConfig]): 342 | def get_optimizer(self) -> optax.GradientTransformation: 343 | return ( 344 | optax.adam(self.config.learning_rate) 345 | if self.config.adam_weight_decay == 0.0 346 | else optax.adamw(self.config.learning_rate, weight_decay=self.config.adam_weight_decay) 347 | ) 348 | 349 | def get_mujoco_model(self) -> mujoco.MjModel: 350 | mjcf_path = asyncio.run(ksim.get_mujoco_model_path("kbot", name="robot")) 351 | return mujoco_scenes.mjcf.load_mjmodel(mjcf_path, scene="smooth") 352 | 353 | def get_mujoco_model_metadata(self, mj_model: mujoco.MjModel) -> ksim.Metadata: 354 | metadata = asyncio.run(ksim.get_mujoco_model_metadata("kbot")) 355 | if metadata.joint_name_to_metadata is None: 356 | raise ValueError("Joint metadata is not available") 357 | if metadata.actuator_type_to_metadata is None: 358 | raise ValueError("Actuator metadata is not available") 359 | return metadata 360 | 361 | def get_actuators( 362 | self, 363 | physics_model: ksim.PhysicsModel, 364 | metadata: ksim.Metadata | None = None, 365 | ) -> ksim.Actuators: 366 | assert metadata is not None, "Metadata is required" 367 | return ksim.PositionActuators( 368 | physics_model=physics_model, 369 | metadata=metadata, 370 | ) 371 | 372 | def get_physics_randomizers(self, physics_model: ksim.PhysicsModel) -> list[ksim.PhysicsRandomizer]: 373 | return [ 374 | ksim.StaticFrictionRandomizer(), 375 | ksim.ArmatureRandomizer(), 376 | ksim.AllBodiesMassMultiplicationRandomizer(scale_lower=0.95, scale_upper=1.05), 377 | ksim.JointDampingRandomizer(), 378 | ksim.JointZeroPositionRandomizer(scale_lower=math.radians(-2), scale_upper=math.radians(2)), 379 | ] 380 | 381 | def get_events(self, physics_model: ksim.PhysicsModel) -> 
list[ksim.Event]: 382 | return [ 383 | ksim.PushEvent( 384 | x_force=1.0, 385 | y_force=1.0, 386 | z_force=0.3, 387 | force_range=(0.5, 1.0), 388 | x_angular_force=0.0, 389 | y_angular_force=0.0, 390 | z_angular_force=0.0, 391 | interval_range=(0.5, 4.0), 392 | ), 393 | ] 394 | 395 | def get_resets(self, physics_model: ksim.PhysicsModel) -> list[ksim.Reset]: 396 | return [ 397 | ksim.RandomJointPositionReset.create(physics_model, {k: v for k, v in ZEROS}, scale=0.1), 398 | ksim.RandomJointVelocityReset(), 399 | ] 400 | 401 | def get_observations(self, physics_model: ksim.PhysicsModel) -> list[ksim.Observation]: 402 | return [ 403 | ksim.TimestepObservation(), 404 | ksim.JointPositionObservation(noise=math.radians(2)), 405 | ksim.JointVelocityObservation(noise=math.radians(10)), 406 | ksim.ActuatorForceObservation(), 407 | ksim.CenterOfMassInertiaObservation(), 408 | ksim.CenterOfMassVelocityObservation(), 409 | ksim.BasePositionObservation(), 410 | ksim.BaseOrientationObservation(), 411 | ksim.BaseLinearVelocityObservation(), 412 | ksim.BaseAngularVelocityObservation(), 413 | ksim.BaseLinearAccelerationObservation(), 414 | ksim.BaseAngularAccelerationObservation(), 415 | ksim.ActuatorAccelerationObservation(), 416 | ksim.ProjectedGravityObservation.create( 417 | physics_model=physics_model, 418 | framequat_name="imu_site_quat", 419 | lag_range=(0.0, 0.1), 420 | noise=math.radians(1), 421 | ), 422 | ksim.SensorObservation.create( 423 | physics_model=physics_model, 424 | sensor_name="imu_acc", 425 | noise=1.0, 426 | ), 427 | ksim.SensorObservation.create( 428 | physics_model=physics_model, 429 | sensor_name="imu_gyro", 430 | noise=math.radians(10), 431 | ), 432 | ] 433 | 434 | def get_commands(self, physics_model: ksim.PhysicsModel) -> list[ksim.Command]: 435 | return [] 436 | 437 | def get_rewards(self, physics_model: ksim.PhysicsModel) -> list[ksim.Reward]: 438 | return [ 439 | # Standard rewards. 
440 | ksim.NaiveForwardReward(clip_max=1.25, in_robot_frame=False, scale=3.0), 441 | ksim.NaiveForwardOrientationReward(scale=1.0), 442 | ksim.StayAliveReward(scale=1.0), 443 | ksim.UprightReward(scale=0.5), 444 | # Avoid movement penalties. 445 | ksim.AngularVelocityPenalty(index=("x", "y"), scale=-0.1), 446 | ksim.LinearVelocityPenalty(index=("z"), scale=-0.1), 447 | # Normalization penalties. 448 | ksim.AvoidLimitsPenalty.create(physics_model, scale=-0.01), 449 | ksim.JointAccelerationPenalty(scale=-0.01, scale_by_curriculum=True), 450 | ksim.JointJerkPenalty(scale=-0.01, scale_by_curriculum=True), 451 | ksim.LinkAccelerationPenalty(scale=-0.01, scale_by_curriculum=True), 452 | ksim.LinkJerkPenalty(scale=-0.01, scale_by_curriculum=True), 453 | ksim.ActionAccelerationPenalty(scale=-0.01, scale_by_curriculum=True), 454 | # Bespoke rewards. 455 | BentArmPenalty.create_penalty(physics_model, scale=-0.1), 456 | StraightLegPenalty.create_penalty(physics_model, scale=-0.1), 457 | ] 458 | 459 | def get_terminations(self, physics_model: ksim.PhysicsModel) -> list[ksim.Termination]: 460 | return [ 461 | ksim.BadZTermination(unhealthy_z_lower=0.6, unhealthy_z_upper=1.2), 462 | ksim.FarFromOriginTermination(max_dist=10.0), 463 | ] 464 | 465 | def get_curriculum(self, physics_model: ksim.PhysicsModel) -> ksim.Curriculum: 466 | return ksim.DistanceFromOriginCurriculum( 467 | min_level_steps=5, 468 | ) 469 | 470 | def get_model(self, key: PRNGKeyArray) -> Model: 471 | return Model( 472 | key, 473 | num_actor_inputs=51 if self.config.use_acc_gyro else 45, 474 | num_actor_outputs=len(ZEROS), 475 | num_critic_inputs=446, 476 | min_std=0.001, 477 | max_std=1.0, 478 | var_scale=self.config.var_scale, 479 | hidden_size=self.config.hidden_size, 480 | num_mixtures=self.config.num_mixtures, 481 | depth=self.config.depth, 482 | ) 483 | 484 | def run_actor( 485 | self, 486 | model: Actor, 487 | observations: xax.FrozenDict[str, Array], 488 | commands: xax.FrozenDict[str, Array], 489 | 
carry: Array, 490 | ) -> tuple[distrax.Distribution, Array]: 491 | time_1 = observations["timestep_observation"] 492 | joint_pos_n = observations["joint_position_observation"] 493 | joint_vel_n = observations["joint_velocity_observation"] 494 | proj_grav_3 = observations["projected_gravity_observation"] 495 | imu_acc_3 = observations["sensor_observation_imu_acc"] 496 | imu_gyro_3 = observations["sensor_observation_imu_gyro"] 497 | 498 | obs = [ 499 | jnp.sin(time_1), 500 | jnp.cos(time_1), 501 | joint_pos_n, # NUM_JOINTS 502 | joint_vel_n, # NUM_JOINTS 503 | proj_grav_3, # 3 504 | ] 505 | if self.config.use_acc_gyro: 506 | obs += [ 507 | imu_acc_3, # 3 508 | imu_gyro_3, # 3 509 | ] 510 | 511 | obs_n = jnp.concatenate(obs, axis=-1) 512 | action, carry = model.forward(obs_n, carry) 513 | 514 | return action, carry 515 | 516 | def run_critic( 517 | self, 518 | model: Critic, 519 | observations: xax.FrozenDict[str, Array], 520 | commands: xax.FrozenDict[str, Array], 521 | carry: Array, 522 | ) -> tuple[Array, Array]: 523 | time_1 = observations["timestep_observation"] 524 | dh_joint_pos_j = observations["joint_position_observation"] 525 | dh_joint_vel_j = observations["joint_velocity_observation"] 526 | com_inertia_n = observations["center_of_mass_inertia_observation"] 527 | com_vel_n = observations["center_of_mass_velocity_observation"] 528 | imu_acc_3 = observations["sensor_observation_imu_acc"] 529 | imu_gyro_3 = observations["sensor_observation_imu_gyro"] 530 | proj_grav_3 = observations["projected_gravity_observation"] 531 | act_frc_obs_n = observations["actuator_force_observation"] 532 | base_pos_3 = observations["base_position_observation"] 533 | base_quat_4 = observations["base_orientation_observation"] 534 | 535 | obs_n = jnp.concatenate( 536 | [ 537 | jnp.sin(time_1), 538 | jnp.cos(time_1), 539 | dh_joint_pos_j, # NUM_JOINTS 540 | dh_joint_vel_j / 10.0, # NUM_JOINTS 541 | com_inertia_n, # 160 542 | com_vel_n, # 96 543 | imu_acc_3, # 3 544 | imu_gyro_3, # 3 
545 | proj_grav_3, # 3 546 | act_frc_obs_n / 100.0, # NUM_JOINTS 547 | base_pos_3, # 3 548 | base_quat_4, # 4 549 | ], 550 | axis=-1, 551 | ) 552 | 553 | return model.forward(obs_n, carry) 554 | 555 | def _model_scan_fn( 556 | self, 557 | actor_critic_carry: tuple[Array, Array], 558 | xs: tuple[ksim.Trajectory, PRNGKeyArray], 559 | model: Model, 560 | ) -> tuple[tuple[Array, Array], ksim.PPOVariables]: 561 | transition, rng = xs 562 | 563 | actor_carry, critic_carry = actor_critic_carry 564 | actor_dist, next_actor_carry = self.run_actor( 565 | model=model.actor, 566 | observations=transition.obs, 567 | commands=transition.command, 568 | carry=actor_carry, 569 | ) 570 | 571 | # Gets the log probabilities of the action. 572 | log_probs = actor_dist.log_prob(transition.action) 573 | assert isinstance(log_probs, Array) 574 | 575 | value, next_critic_carry = self.run_critic( 576 | model=model.critic, 577 | observations=transition.obs, 578 | commands=transition.command, 579 | carry=critic_carry, 580 | ) 581 | 582 | transition_ppo_variables = ksim.PPOVariables( 583 | log_probs=log_probs, 584 | values=value.squeeze(-1), 585 | ) 586 | 587 | next_carry = jax.tree.map( 588 | lambda x, y: jnp.where(transition.done, x, y), 589 | self.get_initial_model_carry(rng), 590 | (next_actor_carry, next_critic_carry), 591 | ) 592 | 593 | return next_carry, transition_ppo_variables 594 | 595 | def get_ppo_variables( 596 | self, 597 | model: Model, 598 | trajectory: ksim.Trajectory, 599 | model_carry: tuple[Array, Array], 600 | rng: PRNGKeyArray, 601 | ) -> tuple[ksim.PPOVariables, tuple[Array, Array]]: 602 | scan_fn = functools.partial(self._model_scan_fn, model=model) 603 | next_model_carry, ppo_variables = xax.scan( 604 | scan_fn, 605 | model_carry, 606 | (trajectory, jax.random.split(rng, len(trajectory.done))), 607 | jit_level=4, 608 | ) 609 | return ppo_variables, next_model_carry 610 | 611 | def get_initial_model_carry(self, rng: PRNGKeyArray) -> tuple[Array, Array]: 612 | return ( 
613 | jnp.zeros(shape=(self.config.depth, self.config.hidden_size)), 614 | jnp.zeros(shape=(self.config.depth, self.config.hidden_size)), 615 | ) 616 | 617 | def sample_action( 618 | self, 619 | model: Model, 620 | model_carry: tuple[Array, Array], 621 | physics_model: ksim.PhysicsModel, 622 | physics_state: ksim.PhysicsState, 623 | observations: xax.FrozenDict[str, Array], 624 | commands: xax.FrozenDict[str, Array], 625 | rng: PRNGKeyArray, 626 | argmax: bool, 627 | ) -> ksim.Action: 628 | actor_carry_in, critic_carry_in = model_carry 629 | action_dist_j, actor_carry = self.run_actor( 630 | model=model.actor, 631 | observations=observations, 632 | commands=commands, 633 | carry=actor_carry_in, 634 | ) 635 | action_j = action_dist_j.mode() if argmax else action_dist_j.sample(seed=rng) 636 | return ksim.Action(action=action_j, carry=(actor_carry, critic_carry_in)) 637 | 638 | 639 | if __name__ == "__main__": 640 | HumanoidWalkingTask.launch( 641 | HumanoidWalkingTaskConfig( 642 | # Training parameters. 643 | num_envs=2048, 644 | batch_size=256, 645 | num_passes=4, 646 | epochs_per_log_step=1, 647 | rollout_length_seconds=8.0, 648 | global_grad_clip=2.0, 649 | # Simulation parameters. 650 | dt=0.002, 651 | ctrl_dt=0.02, 652 | iterations=8, 653 | ls_iterations=8, 654 | action_latency_range=(0.003, 0.01), # Simulate 3-10ms of latency. 655 | drop_action_prob=0.05, # Drop 5% of commands. 656 | # Visualization parameters. 657 | render_track_body_id=0, 658 | # Checkpointing parameters. 659 | save_every_n_seconds=60, 660 | ), 661 | ) 662 | -------------------------------------------------------------------------------- /train.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "0", 6 | "metadata": { 7 | "id": "0" 8 | }, 9 | "source": [ 10 | "# K-Scale Humanoid Benchmark\n", 11 | "\n", 12 | "Welcome to the K-Scale Humanoid Benchmark! 
This notebook will walk you through training your own reinforcement learning policy, which you can then use to control a K-Scale robot.\n", 13 | "\n", 14 | "*Note:* The Just-In-Time compilation may take a while and cause your Colab instance to appear to disconnect. However, your training cell may actually still be running. Make sure to check before restarting!\n", 15 | "\n", 16 | "*Last updated: 2025/05/15*" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "id": "1", 22 | "metadata": { 23 | "id": "GcUyV2BhHCxS" 24 | }, 25 | "source": [ 26 | "## Dependencies and Config\n", 27 | "\n", 28 | "The K-Scale Humanoid Benchmark uses K-Scale's open-source RL framework [K-Sim](https://github.com/kscalelabs/ksim) for training and the [K-Scale API](https://github.com/kscalelabs/kscale) for asset management." 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "2", 35 | "metadata": { 36 | "colab": { 37 | "base_uri": "https://localhost:8080/" 38 | }, 39 | "id": "X9GR-PWjgynB", 40 | "outputId": "8b5227f7-e06e-465e-ce65-ed19f75d1191" 41 | }, 42 | "outputs": [], 43 | "source": [ 44 | "# Install packages\n", 45 | "\n", 46 | "!pip install ksim==0.1.2 xax==0.3.0 mujoco-scenes" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": null, 52 | "id": "3", 53 | "metadata": { 54 | "colab": { 55 | "base_uri": "https://localhost:8080/" 56 | }, 57 | "id": "19e07786", 58 | "outputId": "78e173ed-afd3-45fc-a518-2a4a2ba31b8f" 59 | }, 60 | "outputs": [], 61 | "source": [ 62 | "# Set up environment variables\n", 63 | "%env MUJOCO_GL=egl" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "id": "4", 70 | "metadata": { 71 | "id": "1" 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "import asyncio\n", 76 | "import functools\n", 77 | "import math\n", 78 | "from dataclasses import dataclass\n", 79 | "from typing import Self\n", 80 | "\n", 81 | "import attrs\n", 82 | "import distrax\n", 83 | "import equinox as 
eqx\n", 84 | "import jax\n", 85 | "import jax.numpy as jnp\n", 86 | "import ksim\n", 87 | "import mujoco\n", 88 | "import mujoco_scenes\n", 89 | "import mujoco_scenes.mjcf\n", 90 | "import nest_asyncio\n", 91 | "import optax\n", 92 | "import xax\n", 93 | "from jaxtyping import Array, PRNGKeyArray\n", 94 | "\n", 95 | "nest_asyncio.apply()" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": null, 101 | "id": "5", 102 | "metadata": { 103 | "id": "2" 104 | }, 105 | "outputs": [], 106 | "source": [ 107 | "# These are in the order of the neural network outputs.\n", 108 | "ZEROS: list[tuple[str, float]] = [\n", 109 | " (\"dof_right_shoulder_pitch_03\", 0.0),\n", 110 | " (\"dof_right_shoulder_roll_03\", math.radians(-10.0)),\n", 111 | " (\"dof_right_shoulder_yaw_02\", 0.0),\n", 112 | " (\"dof_right_elbow_02\", math.radians(90.0)),\n", 113 | " (\"dof_right_wrist_00\", 0.0),\n", 114 | " (\"dof_left_shoulder_pitch_03\", 0.0),\n", 115 | " (\"dof_left_shoulder_roll_03\", math.radians(10.0)),\n", 116 | " (\"dof_left_shoulder_yaw_02\", 0.0),\n", 117 | " (\"dof_left_elbow_02\", math.radians(-90.0)),\n", 118 | " (\"dof_left_wrist_00\", 0.0),\n", 119 | " (\"dof_right_hip_pitch_04\", math.radians(-20.0)),\n", 120 | " (\"dof_right_hip_roll_03\", math.radians(-0.0)),\n", 121 | " (\"dof_right_hip_yaw_03\", 0.0),\n", 122 | " (\"dof_right_knee_04\", math.radians(-50.0)),\n", 123 | " (\"dof_right_ankle_02\", math.radians(30.0)),\n", 124 | " (\"dof_left_hip_pitch_04\", math.radians(20.0)),\n", 125 | " (\"dof_left_hip_roll_03\", math.radians(0.0)),\n", 126 | " (\"dof_left_hip_yaw_03\", 0.0),\n", 127 | " (\"dof_left_knee_04\", math.radians(50.0)),\n", 128 | " (\"dof_left_ankle_02\", math.radians(-30.0)),\n", 129 | "]" 130 | ] 131 | }, 132 | { 133 | "cell_type": "markdown", 134 | "id": "6", 135 | "metadata": { 136 | "id": "3" 137 | }, 138 | "source": [ 139 | "## Rewards\n", 140 | "\n", 141 | "When training a reinforcement learning agent, the most important thing to define 
is what reward you want the agent to maximize. `ksim` includes a number of useful default rewards for training walking agents, but it is often a good idea to define new rewards to encourage specific types of behavior. The cell below shows an example of how to define a custom reward. A similar pattern can be used to define custom objectives, events, observations, and more." 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": null, 147 | "id": "7", 148 | "metadata": { 149 | "id": "4" 150 | }, 151 | "outputs": [], 152 | "source": [ 153 | "@attrs.define(frozen=True, kw_only=True)\n", 154 | "class JointPositionPenalty(ksim.JointDeviationPenalty):\n", 155 | " @classmethod\n", 156 | " def create_from_names(\n", 157 | " cls,\n", 158 | " names: list[str],\n", 159 | " physics_model: ksim.PhysicsModel,\n", 160 | " scale: float = -1.0,\n", 161 | " scale_by_curriculum: bool = False,\n", 162 | " ) -> Self:\n", 163 | " zeros = {k: v for k, v in ZEROS}\n", 164 | " joint_targets = [zeros[name] for name in names]\n", 165 | "\n", 166 | " return cls.create(\n", 167 | " physics_model=physics_model,\n", 168 | " joint_names=tuple(names),\n", 169 | " joint_targets=tuple(joint_targets),\n", 170 | " scale=scale,\n", 171 | " scale_by_curriculum=scale_by_curriculum,\n", 172 | " )\n", 173 | "\n", 174 | "\n", 175 | "@attrs.define(frozen=True, kw_only=True)\n", 176 | "class BentArmPenalty(JointPositionPenalty):\n", 177 | " @classmethod\n", 178 | " def create_penalty(\n", 179 | " cls,\n", 180 | " physics_model: ksim.PhysicsModel,\n", 181 | " scale: float = -1.0,\n", 182 | " scale_by_curriculum: bool = False,\n", 183 | " ) -> Self:\n", 184 | " return cls.create_from_names(\n", 185 | " names=[\n", 186 | " \"dof_right_shoulder_pitch_03\",\n", 187 | " \"dof_right_shoulder_roll_03\",\n", 188 | " \"dof_right_shoulder_yaw_02\",\n", 189 | " \"dof_right_elbow_02\",\n", 190 | " \"dof_right_wrist_00\",\n", 191 | " \"dof_left_shoulder_pitch_03\",\n", 192 | "
\"dof_left_shoulder_roll_03\",\n", 193 | " \"dof_left_shoulder_yaw_02\",\n", 194 | " \"dof_left_elbow_02\",\n", 195 | " \"dof_left_wrist_00\",\n", 196 | " ],\n", 197 | " physics_model=physics_model,\n", 198 | " scale=scale,\n", 199 | " scale_by_curriculum=scale_by_curriculum,\n", 200 | " )\n", 201 | "\n", 202 | "\n", 203 | "@attrs.define(frozen=True, kw_only=True)\n", 204 | "class StraightLegPenalty(JointPositionPenalty):\n", 205 | " @classmethod\n", 206 | " def create_penalty(\n", 207 | " cls,\n", 208 | " physics_model: ksim.PhysicsModel,\n", 209 | " scale: float = -1.0,\n", 210 | " scale_by_curriculum: bool = False,\n", 211 | " ) -> Self:\n", 212 | " return cls.create_from_names(\n", 213 | " names=[\n", 214 | " \"dof_left_hip_roll_03\",\n", 215 | " \"dof_left_hip_yaw_03\",\n", 216 | " \"dof_right_hip_roll_03\",\n", 217 | " \"dof_right_hip_yaw_03\",\n", 218 | " ],\n", 219 | " physics_model=physics_model,\n", 220 | " scale=scale,\n", 221 | " scale_by_curriculum=scale_by_curriculum,\n", 222 | " )" 223 | ] 224 | }, 225 | { 226 | "cell_type": "markdown", 227 | "id": "8", 228 | "metadata": { 229 | "id": "5" 230 | }, 231 | "source": [ 232 | "## Actor-Critic Model\n", 233 | "\n", 234 | "We train our reinforcement learning agent using an RNN-based actor and critic, which we define below." 
235 | ] 236 | }, 237 | { 238 | "cell_type": "code", 239 | "execution_count": null, 240 | "id": "9", 241 | "metadata": { 242 | "id": "6" 243 | }, 244 | "outputs": [], 245 | "source": [ 246 | "class Actor(eqx.Module):\n", 247 | " \"\"\"Actor for the walking task.\"\"\"\n", 248 | "\n", 249 | " input_proj: eqx.nn.Linear\n", 250 | " rnns: tuple[eqx.nn.GRUCell, ...]\n", 251 | " output_proj: eqx.nn.Linear\n", 252 | " num_inputs: int = eqx.static_field()\n", 253 | " num_outputs: int = eqx.static_field()\n", 254 | " num_mixtures: int = eqx.static_field()\n", 255 | " min_std: float = eqx.static_field()\n", 256 | " max_std: float = eqx.static_field()\n", 257 | " var_scale: float = eqx.static_field()\n", 258 | "\n", 259 | " def __init__(\n", 260 | " self,\n", 261 | " key: PRNGKeyArray,\n", 262 | " *,\n", 263 | " num_inputs: int,\n", 264 | " num_outputs: int,\n", 265 | " min_std: float,\n", 266 | " max_std: float,\n", 267 | " var_scale: float,\n", 268 | " hidden_size: int,\n", 269 | " num_mixtures: int,\n", 270 | " depth: int,\n", 271 | " ) -> None:\n", 272 | " # Project input to hidden size\n", 273 | " key, input_proj_key = jax.random.split(key)\n", 274 | " self.input_proj = eqx.nn.Linear(\n", 275 | " in_features=num_inputs,\n", 276 | " out_features=hidden_size,\n", 277 | " key=input_proj_key,\n", 278 | " )\n", 279 | "\n", 280 | " # Create RNN layer\n", 281 | " key, rnn_key = jax.random.split(key)\n", 282 | " rnn_keys = jax.random.split(rnn_key, depth)\n", 283 | " self.rnns = tuple(\n", 284 | " [\n", 285 | " eqx.nn.GRUCell(\n", 286 | " input_size=hidden_size,\n", 287 | " hidden_size=hidden_size,\n", 288 | " key=rnn_key,\n", 289 | " )\n", 290 | " for rnn_key in rnn_keys\n", 291 | " ]\n", 292 | " )\n", 293 | "\n", 294 | " # Project to output\n", 295 | " self.output_proj = eqx.nn.Linear(\n", 296 | " in_features=hidden_size,\n", 297 | " out_features=num_outputs * 3 * num_mixtures,\n", 298 | " key=key,\n", 299 | " )\n", 300 | "\n", 301 | " self.num_inputs = num_inputs\n", 302 | " 
self.num_outputs = num_outputs\n", 303 | " self.num_mixtures = num_mixtures\n", 304 | " self.min_std = min_std\n", 305 | " self.max_std = max_std\n", 306 | " self.var_scale = var_scale\n", 307 | "\n", 308 | " def forward(self, obs_n: Array, carry: Array) -> tuple[distrax.Distribution, Array]:\n", 309 | " x_n = self.input_proj(obs_n)\n", 310 | " out_carries = []\n", 311 | " for i, rnn in enumerate(self.rnns):\n", 312 | " x_n = rnn(x_n, carry[i])\n", 313 | " out_carries.append(x_n)\n", 314 | " out_n = self.output_proj(x_n)\n", 315 | "\n", 316 | " # Reshape the output to be a mixture of gaussians.\n", 317 | " slice_len = self.num_outputs * self.num_mixtures\n", 318 | " mean_nm = out_n[..., :slice_len].reshape(self.num_outputs, self.num_mixtures)\n", 319 | " std_nm = out_n[..., slice_len : slice_len * 2].reshape(self.num_outputs, self.num_mixtures)\n", 320 | " logits_nm = out_n[..., slice_len * 2 :].reshape(self.num_outputs, self.num_mixtures)\n", 321 | "\n", 322 | " # Softplus and clip to ensure positive standard deviations.\n", 323 | " std_nm = jnp.clip((jax.nn.softplus(std_nm) + self.min_std) * self.var_scale, max=self.max_std)\n", 324 | "\n", 325 | " # Apply bias to the means.\n", 326 | " mean_nm = mean_nm + jnp.array([v for _, v in ZEROS])[:, None]\n", 327 | "\n", 328 | " dist_n = ksim.MixtureOfGaussians(means_nm=mean_nm, stds_nm=std_nm, logits_nm=logits_nm)\n", 329 | "\n", 330 | " return dist_n, jnp.stack(out_carries, axis=0)\n", 331 | "\n", 332 | "\n", 333 | "class Critic(eqx.Module):\n", 334 | " \"\"\"Critic for the walking task.\"\"\"\n", 335 | "\n", 336 | " input_proj: eqx.nn.Linear\n", 337 | " rnns: tuple[eqx.nn.GRUCell, ...]\n", 338 | " output_proj: eqx.nn.Linear\n", 339 | " num_inputs: int = eqx.static_field()\n", 340 | "\n", 341 | " def __init__(\n", 342 | " self,\n", 343 | " key: PRNGKeyArray,\n", 344 | " *,\n", 345 | " num_inputs: int,\n", 346 | " hidden_size: int,\n", 347 | " depth: int,\n", 348 | " ) -> None:\n", 349 | " num_outputs = 1\n", 350 | 
"\n", 351 | " # Project input to hidden size\n", 352 | " key, input_proj_key = jax.random.split(key)\n", 353 | " self.input_proj = eqx.nn.Linear(\n", 354 | " in_features=num_inputs,\n", 355 | " out_features=hidden_size,\n", 356 | " key=input_proj_key,\n", 357 | " )\n", 358 | "\n", 359 | " # Create RNN layer\n", 360 | " key, rnn_key = jax.random.split(key)\n", 361 | " rnn_keys = jax.random.split(rnn_key, depth)\n", 362 | " self.rnns = tuple(\n", 363 | " [\n", 364 | " eqx.nn.GRUCell(\n", 365 | " input_size=hidden_size,\n", 366 | " hidden_size=hidden_size,\n", 367 | " key=rnn_key,\n", 368 | " )\n", 369 | " for rnn_key in rnn_keys\n", 370 | " ]\n", 371 | " )\n", 372 | "\n", 373 | " # Project to output\n", 374 | " self.output_proj = eqx.nn.Linear(\n", 375 | " in_features=hidden_size,\n", 376 | " out_features=num_outputs,\n", 377 | " key=key,\n", 378 | " )\n", 379 | "\n", 380 | " self.num_inputs = num_inputs\n", 381 | "\n", 382 | " def forward(self, obs_n: Array, carry: Array) -> tuple[Array, Array]:\n", 383 | " x_n = self.input_proj(obs_n)\n", 384 | " out_carries = []\n", 385 | " for i, rnn in enumerate(self.rnns):\n", 386 | " x_n = rnn(x_n, carry[i])\n", 387 | " out_carries.append(x_n)\n", 388 | " out_n = self.output_proj(x_n)\n", 389 | "\n", 390 | " return out_n, jnp.stack(out_carries, axis=0)\n", 391 | "\n", 392 | "\n", 393 | "class Model(eqx.Module):\n", 394 | " actor: Actor\n", 395 | " critic: Critic\n", 396 | "\n", 397 | " def __init__(\n", 398 | " self,\n", 399 | " key: PRNGKeyArray,\n", 400 | " *,\n", 401 | " num_actor_inputs: int,\n", 402 | " num_actor_outputs: int,\n", 403 | " num_critic_inputs: int,\n", 404 | " min_std: float,\n", 405 | " max_std: float,\n", 406 | " var_scale: float,\n", 407 | " hidden_size: int,\n", 408 | " num_mixtures: int,\n", 409 | " depth: int,\n", 410 | " ) -> None:\n", 411 | " actor_key, critic_key = jax.random.split(key)\n", 412 | " self.actor = Actor(\n", 413 | " actor_key,\n", 414 | " num_inputs=num_actor_inputs,\n", 415 | " 
num_outputs=num_actor_outputs,\n", 416 | " min_std=min_std,\n", 417 | " max_std=max_std,\n", 418 | " var_scale=var_scale,\n", 419 | " hidden_size=hidden_size,\n", 420 | " num_mixtures=num_mixtures,\n", 421 | " depth=depth,\n", 422 | " )\n", 423 | " self.critic = Critic(\n", 424 | " critic_key,\n", 425 | " hidden_size=hidden_size,\n", 426 | " depth=depth,\n", 427 | " num_inputs=num_critic_inputs,\n", 428 | " )" 429 | ] 430 | }, 431 | { 432 | "cell_type": "markdown", 433 | "id": "10", 434 | "metadata": { 435 | "id": "7" 436 | }, 437 | "source": [ 438 | "## Config\n", 439 | "\n", 440 | "The [ksim framework](https://github.com/kscalelabs/ksim) is based on [xax](https://github.com/kscalelabs/xax), a JAX training library built by K-Scale. To provide configuration options, xax uses a Config dataclass to parse command-line options. We define the config here." 441 | ] 442 | }, 443 | { 444 | "cell_type": "code", 445 | "execution_count": null, 446 | "id": "11", 447 | "metadata": { 448 | "id": "8" 449 | }, 450 | "outputs": [], 451 | "source": [ 452 | "@dataclass\n", 453 | "class HumanoidWalkingTaskConfig(ksim.PPOConfig):\n", 454 | " \"\"\"Config for the humanoid walking task.\"\"\"\n", 455 | "\n", 456 | " # Model parameters.\n", 457 | " hidden_size: int = xax.field(\n", 458 | " value=128,\n", 459 | " help=\"The hidden size for the MLPs.\",\n", 460 | " )\n", 461 | " depth: int = xax.field(\n", 462 | " value=5,\n", 463 | " help=\"The depth for the MLPs.\",\n", 464 | " )\n", 465 | " num_mixtures: int = xax.field(\n", 466 | " value=5,\n", 467 | " help=\"The number of mixtures for the actor.\",\n", 468 | " )\n", 469 | " var_scale: float = xax.field(\n", 470 | " value=0.5,\n", 471 | " help=\"The scale for the standard deviations of the actor.\",\n", 472 | " )\n", 473 | " use_acc_gyro: bool = xax.field(\n", 474 | " value=True,\n", 475 | " help=\"Whether to use the IMU acceleration and gyroscope observations.\",\n", 476 | " )\n", 477 | "\n", 478 | " # Curriculum parameters.\n", 479 | 
" num_curriculum_levels: int = xax.field(\n", 480 | " value=100,\n", 481 | " help=\"The number of curriculum levels to use.\",\n", 482 | " )\n", 483 | " increase_threshold: float = xax.field(\n", 484 | " value=5.0,\n", 485 | " help=\"Increase the curriculum level when the mean trajectory length is above this threshold.\",\n", 486 | " )\n", 487 | " decrease_threshold: float = xax.field(\n", 488 | " value=1.0,\n", 489 | " help=\"Decrease the curriculum level when the mean trajectory length is below this threshold.\",\n", 490 | " )\n", 491 | " min_level_steps: int = xax.field(\n", 492 | " value=1,\n", 493 | " help=\"The minimum number of steps to wait before changing the curriculum level.\",\n", 494 | " )\n", 495 | "\n", 496 | " # Optimizer parameters.\n", 497 | " learning_rate: float = xax.field(\n", 498 | " value=3e-4,\n", 499 | " help=\"Learning rate for PPO.\",\n", 500 | " )\n", 501 | " adam_weight_decay: float = xax.field(\n", 502 | " value=1e-5,\n", 503 | " help=\"Weight decay for the Adam optimizer.\",\n", 504 | " )" 505 | ] 506 | }, 507 | { 508 | "cell_type": "markdown", 509 | "id": "12", 510 | "metadata": { 511 | "id": "9" 512 | }, 513 | "source": [ 514 | "## Task\n", 515 | "\n", 516 | "The meat-and-potatoes of our training code is the task. This defines the observations, rewards, model calling logic, and everything else needed by `ksim` to train our reinforcement learning agent." 
517 | ] 518 | }, 519 | { 520 | "cell_type": "code", 521 | "execution_count": null, 522 | "id": "13", 523 | "metadata": { 524 | "id": "10" 525 | }, 526 | "outputs": [], 527 | "source": [ 528 | "class HumanoidWalkingTask(ksim.PPOTask[HumanoidWalkingTaskConfig]):\n", 529 | " def get_optimizer(self) -> optax.GradientTransformation:\n", 530 | " return (\n", 531 | " optax.adam(self.config.learning_rate)\n", 532 | " if self.config.adam_weight_decay == 0.0\n", 533 | " else optax.adamw(self.config.learning_rate, weight_decay=self.config.adam_weight_decay)\n", 534 | " )\n", 535 | "\n", 536 | " def get_mujoco_model(self) -> mujoco.MjModel:\n", 537 | " mjcf_path = asyncio.run(ksim.get_mujoco_model_path(\"kbot\", name=\"robot\"))\n", 538 | " return mujoco_scenes.mjcf.load_mjmodel(mjcf_path, scene=\"smooth\")\n", 539 | "\n", 540 | " def get_mujoco_model_metadata(self, mj_model: mujoco.MjModel) -> ksim.Metadata:\n", 541 | " metadata = asyncio.run(ksim.get_mujoco_model_metadata(\"kbot\"))\n", 542 | " if metadata.joint_name_to_metadata is None:\n", 543 | " raise ValueError(\"Joint metadata is not available\")\n", 544 | " if metadata.actuator_type_to_metadata is None:\n", 545 | " raise ValueError(\"Actuator metadata is not available\")\n", 546 | " return metadata\n", 547 | "\n", 548 | " def get_actuators(\n", 549 | " self,\n", 550 | " physics_model: ksim.PhysicsModel,\n", 551 | " metadata: ksim.Metadata | None = None,\n", 552 | " ) -> ksim.Actuators:\n", 553 | " assert metadata is not None, \"Metadata is required\"\n", 554 | " return ksim.PositionActuators(\n", 555 | " physics_model=physics_model,\n", 556 | " metadata=metadata,\n", 557 | " )\n", 558 | "\n", 559 | " def get_physics_randomizers(self, physics_model: ksim.PhysicsModel) -> list[ksim.PhysicsRandomizer]:\n", 560 | " return [\n", 561 | " ksim.StaticFrictionRandomizer(),\n", 562 | " ksim.ArmatureRandomizer(),\n", 563 | " ksim.AllBodiesMassMultiplicationRandomizer(scale_lower=0.95, scale_upper=1.05),\n", 564 | " 
ksim.JointDampingRandomizer(),\n", 565 | " ksim.JointZeroPositionRandomizer(scale_lower=math.radians(-2), scale_upper=math.radians(2)),\n", 566 | " ]\n", 567 | "\n", 568 | " def get_events(self, physics_model: ksim.PhysicsModel) -> list[ksim.Event]:\n", 569 | " return [\n", 570 | " ksim.PushEvent(\n", 571 | " x_force=1.0,\n", 572 | " y_force=1.0,\n", 573 | " z_force=0.3,\n", 574 | " force_range=(0.5, 1.0),\n", 575 | " x_angular_force=0.0,\n", 576 | " y_angular_force=0.0,\n", 577 | " z_angular_force=0.0,\n", 578 | " interval_range=(0.5, 4.0),\n", 579 | " ),\n", 580 | " ]\n", 581 | "\n", 582 | " def get_resets(self, physics_model: ksim.PhysicsModel) -> list[ksim.Reset]:\n", 583 | " return [\n", 584 | " ksim.RandomJointPositionReset.create(physics_model, {k: v for k, v in ZEROS}, scale=0.1),\n", 585 | " ksim.RandomJointVelocityReset(),\n", 586 | " ]\n", 587 | "\n", 588 | " def get_observations(self, physics_model: ksim.PhysicsModel) -> list[ksim.Observation]:\n", 589 | " return [\n", 590 | " ksim.TimestepObservation(),\n", 591 | " ksim.JointPositionObservation(noise=math.radians(2)),\n", 592 | " ksim.JointVelocityObservation(noise=math.radians(10)),\n", 593 | " ksim.ActuatorForceObservation(),\n", 594 | " ksim.CenterOfMassInertiaObservation(),\n", 595 | " ksim.CenterOfMassVelocityObservation(),\n", 596 | " ksim.BasePositionObservation(),\n", 597 | " ksim.BaseOrientationObservation(),\n", 598 | " ksim.BaseLinearVelocityObservation(),\n", 599 | " ksim.BaseAngularVelocityObservation(),\n", 600 | " ksim.BaseLinearAccelerationObservation(),\n", 601 | " ksim.BaseAngularAccelerationObservation(),\n", 602 | " ksim.ActuatorAccelerationObservation(),\n", 603 | " ksim.ProjectedGravityObservation.create(\n", 604 | " physics_model=physics_model,\n", 605 | " framequat_name=\"imu_site_quat\",\n", 606 | " lag_range=(0.0, 0.1),\n", 607 | " noise=math.radians(1),\n", 608 | " ),\n", 609 | " ksim.SensorObservation.create(\n", 610 | " physics_model=physics_model,\n", 611 | " 
sensor_name=\"imu_acc\",\n", 612 | " noise=1.0,\n", 613 | " ),\n", 614 | " ksim.SensorObservation.create(\n", 615 | " physics_model=physics_model,\n", 616 | " sensor_name=\"imu_gyro\",\n", 617 | " noise=math.radians(10),\n", 618 | " ),\n", 619 | " ]\n", 620 | "\n", 621 | " def get_commands(self, physics_model: ksim.PhysicsModel) -> list[ksim.Command]:\n", 622 | " return []\n", 623 | "\n", 624 | " def get_rewards(self, physics_model: ksim.PhysicsModel) -> list[ksim.Reward]:\n", 625 | " return [\n", 626 | " # Standard rewards.\n", 627 | " ksim.NaiveForwardReward(clip_max=1.25, in_robot_frame=False, scale=3.0),\n", 628 | " ksim.NaiveForwardOrientationReward(scale=1.0),\n", 629 | " ksim.StayAliveReward(scale=1.0),\n", 630 | " ksim.UprightReward(scale=0.5),\n", 631 | " # Avoid movement penalties.\n", 632 | " ksim.AngularVelocityPenalty(index=(\"x\", \"y\"), scale=-0.1),\n", 633 | " ksim.LinearVelocityPenalty(index=(\"z\"), scale=-0.1),\n", 634 | " # Normalization penalties.\n", 635 | " ksim.AvoidLimitsPenalty.create(physics_model, scale=-0.01),\n", 636 | " ksim.JointAccelerationPenalty(scale=-0.01, scale_by_curriculum=True),\n", 637 | " ksim.JointJerkPenalty(scale=-0.01, scale_by_curriculum=True),\n", 638 | " ksim.LinkAccelerationPenalty(scale=-0.01, scale_by_curriculum=True),\n", 639 | " ksim.LinkJerkPenalty(scale=-0.01, scale_by_curriculum=True),\n", 640 | " ksim.ActionAccelerationPenalty(scale=-0.01, scale_by_curriculum=True),\n", 641 | " # Bespoke rewards.\n", 642 | " BentArmPenalty.create_penalty(physics_model, scale=-0.1),\n", 643 | " StraightLegPenalty.create_penalty(physics_model, scale=-0.1),\n", 644 | " ]\n", 645 | "\n", 646 | " def get_terminations(self, physics_model: ksim.PhysicsModel) -> list[ksim.Termination]:\n", 647 | " return [\n", 648 | " ksim.BadZTermination(unhealthy_z_lower=0.6, unhealthy_z_upper=1.2),\n", 649 | " ksim.FarFromOriginTermination(max_dist=10.0),\n", 650 | " ]\n", 651 | "\n", 652 | " def get_curriculum(self, physics_model: 
ksim.PhysicsModel) -> ksim.Curriculum:\n", 653 | " return ksim.DistanceFromOriginCurriculum(\n", 654 | " min_level_steps=5,\n", 655 | " )\n", 656 | "\n", 657 | " def get_model(self, key: PRNGKeyArray) -> Model:\n", 658 | " return Model(\n", 659 | " key,\n", 660 | " num_actor_inputs=51 if self.config.use_acc_gyro else 45,\n", 661 | " num_actor_outputs=len(ZEROS),\n", 662 | " num_critic_inputs=446,\n", 663 | " min_std=0.001,\n", 664 | " max_std=1.0,\n", 665 | " var_scale=self.config.var_scale,\n", 666 | " hidden_size=self.config.hidden_size,\n", 667 | " num_mixtures=self.config.num_mixtures,\n", 668 | " depth=self.config.depth,\n", 669 | " )\n", 670 | "\n", 671 | " def run_actor(\n", 672 | " self,\n", 673 | " model: Actor,\n", 674 | " observations: xax.FrozenDict[str, Array],\n", 675 | " commands: xax.FrozenDict[str, Array],\n", 676 | " carry: Array,\n", 677 | " ) -> tuple[distrax.Distribution, Array]:\n", 678 | " time_1 = observations[\"timestep_observation\"]\n", 679 | " joint_pos_n = observations[\"joint_position_observation\"]\n", 680 | " joint_vel_n = observations[\"joint_velocity_observation\"]\n", 681 | " proj_grav_3 = observations[\"projected_gravity_observation\"]\n", 682 | " imu_acc_3 = observations[\"sensor_observation_imu_acc\"]\n", 683 | " imu_gyro_3 = observations[\"sensor_observation_imu_gyro\"]\n", 684 | "\n", 685 | " obs = [\n", 686 | " jnp.sin(time_1),\n", 687 | " jnp.cos(time_1),\n", 688 | " joint_pos_n, # NUM_JOINTS\n", 689 | " joint_vel_n, # NUM_JOINTS\n", 690 | " proj_grav_3, # 3\n", 691 | " ]\n", 692 | " if self.config.use_acc_gyro:\n", 693 | " obs += [\n", 694 | " imu_acc_3, # 3\n", 695 | " imu_gyro_3, # 3\n", 696 | " ]\n", 697 | "\n", 698 | " obs_n = jnp.concatenate(obs, axis=-1)\n", 699 | " action, carry = model.forward(obs_n, carry)\n", 700 | "\n", 701 | " return action, carry\n", 702 | "\n", 703 | " def run_critic(\n", 704 | " self,\n", 705 | " model: Critic,\n", 706 | " observations: xax.FrozenDict[str, Array],\n", 707 | " commands: 
xax.FrozenDict[str, Array],\n", 708 | " carry: Array,\n", 709 | " ) -> tuple[Array, Array]:\n", 710 | " time_1 = observations[\"timestep_observation\"]\n", 711 | " dh_joint_pos_j = observations[\"joint_position_observation\"]\n", 712 | " dh_joint_vel_j = observations[\"joint_velocity_observation\"]\n", 713 | " com_inertia_n = observations[\"center_of_mass_inertia_observation\"]\n", 714 | " com_vel_n = observations[\"center_of_mass_velocity_observation\"]\n", 715 | " imu_acc_3 = observations[\"sensor_observation_imu_acc\"]\n", 716 | " imu_gyro_3 = observations[\"sensor_observation_imu_gyro\"]\n", 717 | " proj_grav_3 = observations[\"projected_gravity_observation\"]\n", 718 | " act_frc_obs_n = observations[\"actuator_force_observation\"]\n", 719 | " base_pos_3 = observations[\"base_position_observation\"]\n", 720 | " base_quat_4 = observations[\"base_orientation_observation\"]\n", 721 | "\n", 722 | " obs_n = jnp.concatenate(\n", 723 | " [\n", 724 | " jnp.sin(time_1),\n", 725 | " jnp.cos(time_1),\n", 726 | " dh_joint_pos_j, # NUM_JOINTS\n", 727 | " dh_joint_vel_j / 10.0, # NUM_JOINTS\n", 728 | " com_inertia_n, # 160\n", 729 | " com_vel_n, # 96\n", 730 | " imu_acc_3, # 3\n", 731 | " imu_gyro_3, # 3\n", 732 | " proj_grav_3, # 3\n", 733 | " act_frc_obs_n / 100.0, # NUM_JOINTS\n", 734 | " base_pos_3, # 3\n", 735 | " base_quat_4, # 4\n", 736 | " ],\n", 737 | " axis=-1,\n", 738 | " )\n", 739 | "\n", 740 | " return model.forward(obs_n, carry)\n", 741 | "\n", 742 | " def _model_scan_fn(\n", 743 | " self,\n", 744 | " actor_critic_carry: tuple[Array, Array],\n", 745 | " xs: tuple[ksim.Trajectory, PRNGKeyArray],\n", 746 | " model: Model,\n", 747 | " ) -> tuple[tuple[Array, Array], ksim.PPOVariables]:\n", 748 | " transition, rng = xs\n", 749 | "\n", 750 | " actor_carry, critic_carry = actor_critic_carry\n", 751 | " actor_dist, next_actor_carry = self.run_actor(\n", 752 | " model=model.actor,\n", 753 | " observations=transition.obs,\n", 754 | " commands=transition.command,\n", 755 
| " carry=actor_carry,\n", 756 | " )\n", 757 | "\n", 758 | " # Gets the log probabilities of the action.\n", 759 | " log_probs = actor_dist.log_prob(transition.action)\n", 760 | " assert isinstance(log_probs, Array)\n", 761 | "\n", 762 | " value, next_critic_carry = self.run_critic(\n", 763 | " model=model.critic,\n", 764 | " observations=transition.obs,\n", 765 | " commands=transition.command,\n", 766 | " carry=critic_carry,\n", 767 | " )\n", 768 | "\n", 769 | " transition_ppo_variables = ksim.PPOVariables(\n", 770 | " log_probs=log_probs,\n", 771 | " values=value.squeeze(-1),\n", 772 | " )\n", 773 | "\n", 774 | " next_carry = jax.tree.map(\n", 775 | " lambda x, y: jnp.where(transition.done, x, y),\n", 776 | " self.get_initial_model_carry(rng),\n", 777 | " (next_actor_carry, next_critic_carry),\n", 778 | " )\n", 779 | "\n", 780 | " return next_carry, transition_ppo_variables\n", 781 | "\n", 782 | " def get_ppo_variables(\n", 783 | " self,\n", 784 | " model: Model,\n", 785 | " trajectory: ksim.Trajectory,\n", 786 | " model_carry: tuple[Array, Array],\n", 787 | " rng: PRNGKeyArray,\n", 788 | " ) -> tuple[ksim.PPOVariables, tuple[Array, Array]]:\n", 789 | " scan_fn = functools.partial(self._model_scan_fn, model=model)\n", 790 | " next_model_carry, ppo_variables = xax.scan(\n", 791 | " scan_fn,\n", 792 | " model_carry,\n", 793 | " (trajectory, jax.random.split(rng, len(trajectory.done))),\n", 794 | " jit_level=4,\n", 795 | " )\n", 796 | " return ppo_variables, next_model_carry\n", 797 | "\n", 798 | " def get_initial_model_carry(self, rng: PRNGKeyArray) -> tuple[Array, Array]:\n", 799 | " return (\n", 800 | " jnp.zeros(shape=(self.config.depth, self.config.hidden_size)),\n", 801 | " jnp.zeros(shape=(self.config.depth, self.config.hidden_size)),\n", 802 | " )\n", 803 | "\n", 804 | " def sample_action(\n", 805 | " self,\n", 806 | " model: Model,\n", 807 | " model_carry: tuple[Array, Array],\n", 808 | " physics_model: ksim.PhysicsModel,\n", 809 | " physics_state: 
ksim.PhysicsState,\n", 810 | " observations: xax.FrozenDict[str, Array],\n", 811 | " commands: xax.FrozenDict[str, Array],\n", 812 | " rng: PRNGKeyArray,\n", 813 | " argmax: bool,\n", 814 | " ) -> ksim.Action:\n", 815 | " actor_carry_in, critic_carry_in = model_carry\n", 816 | " action_dist_j, actor_carry = self.run_actor(\n", 817 | " model=model.actor,\n", 818 | " observations=observations,\n", 819 | " commands=commands,\n", 820 | " carry=actor_carry_in,\n", 821 | " )\n", 822 | " action_j = action_dist_j.mode() if argmax else action_dist_j.sample(seed=rng)\n", 823 | " return ksim.Action(action=action_j, carry=(actor_carry, critic_carry_in))" 824 | ] 825 | }, 826 | { 827 | "cell_type": "markdown", 828 | "id": "4dcf85a7", 829 | "metadata": {}, 830 | "source": [ 831 | "# Launch TensorBoard\n", 832 | "\n", 833 | "The cell below launches TensorBoard to visualize training progress.\n", 834 | "\n", 835 | "After launching an experiment, wait about 5 minutes for the task to start running, then click the reload button in the top-right corner of the TensorBoard page. You can also open the settings and enable the \"Reload data\" option to have TensorBoard reload automatically." 836 | ] 837 | }, 838 | { 839 | "cell_type": "code", 840 | "execution_count": null, 841 | "id": "14", 842 | "metadata": {}, 843 | "outputs": [], 844 | "source": [ 845 | "# Launch TensorBoard\n", 846 | "%load_ext tensorboard\n", 847 | "%tensorboard --logdir humanoid_walking_task" 848 | ] 849 | }, 850 | { 851 | "cell_type": "markdown", 852 | "id": "15", 853 | "metadata": { 854 | "id": "11" 855 | }, 856 | "source": [ 857 | "## Launching an Experiment\n", 858 | "\n", 859 | "To launch an experiment with `xax`, you can use `Task.launch(config)`.
Note that this is usually intended to be called from the command line, so by default it will attempt to parse additional command-line arguments unless `use_cli=False` is set.\n", 860 | "\n", 861 | "By default, runs will be logged to a directory called `run_[x]` in the task directory (`/content/humanoid_walking_task/` in Colab). From there, you can download the checkpoint `.bin` files and the TensorBoard logs.\n", 862 | "\n", 863 | "Also note that since this is a Jupyter notebook, the task will be unable to locate its training code and will emit a warning like \"Could not resolve task path for <class '__main__.HumanoidWalkingTask'>, returning current working directory\". You can safely ignore this warning." 864 | ] 865 | }, 866 | { 867 | "cell_type": "code", 868 | "execution_count": null, 869 | "id": "16", 870 | "metadata": { 871 | "colab": { 872 | "base_uri": "https://localhost:8080/" 873 | }, 874 | "id": "12", 875 | "outputId": "b37b0dca-fc55-445b-f175-4f4d533bd22b" 876 | }, 877 | "outputs": [], 878 | "source": [ 879 | "if __name__ == \"__main__\":\n", 880 | " HumanoidWalkingTask.launch(\n", 881 | " HumanoidWalkingTaskConfig(\n", 882 | " # Training parameters.\n", 883 | " num_envs=2048,\n", 884 | " batch_size=256,\n", 885 | " num_passes=4,\n", 886 | " epochs_per_log_step=1,\n", 887 | " rollout_length_seconds=8.0,\n", 888 | " global_grad_clip=2.0,\n", 889 | " # Simulation parameters.\n", 890 | " dt=0.002,\n", 891 | " ctrl_dt=0.02,\n", 892 | " iterations=8,\n", 893 | " ls_iterations=8,\n", 894 | " action_latency_range=(0.003, 0.01), # Simulate 3-10ms of latency.\n", 895 | " drop_action_prob=0.05, # Drop 5% of commands.\n", 896 | " # Visualization parameters\n", 897 | " render_track_body_id=0,\n", 898 | " # Checkpointing parameters.\n", 899 | " save_every_n_seconds=60,\n", 900 | " ),\n", 901 | " use_cli=False,\n", 902 | " )" 903 | ] 904 | } 905 | ], 906 | "metadata": { 907 | "accelerator": "GPU", 908 | "colab": { 909 | "gpuType": "T4", 910 | "provenance": [] 911 | }, 912 | "kernelspec": { 913 | "display_name":
"Python 3", 914 | "name": "python3" 915 | }, 916 | "language_info": { 917 | "codemirror_mode": { 918 | "name": "ipython", 919 | "version": 3 920 | }, 921 | "file_extension": ".py", 922 | "mimetype": "text/x-python", 923 | "name": "python", 924 | "nbconvert_exporter": "python", 925 | "pygments_lexer": "ipython3", 926 | "version": "3.11.11" 927 | } 928 | }, 929 | "nbformat": 4, 930 | "nbformat_minor": 5 931 | } 932 | --------------------------------------------------------------------------------