├── .pre-commit-config.yaml
├── requirements.txt
├── .gitignore
├── Makefile
├── LICENSE
├── .github
│   └── workflows
│       └── test.yml
├── pyproject.toml
├── convert.py
├── README.md
├── train.py
└── train.ipynb
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/kynan/nbstripout
3 | rev: 0.8.1
4 | hooks:
5 | - id: nbstripout
6 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # requirements.txt
2 |
3 | # Training dependencies
4 | ksim==0.1.3
5 | xax==0.3.0
6 | mujoco-scenes
7 |
8 | # Inference dependencies
9 | kinfer[jax]
10 |
11 | # Jupyter Notebook dependencies
12 | jupyter
13 | ipykernel
14 | pre-commit
15 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # .gitignore
2 |
3 | # Python
4 | *.py[oc]
5 | __pycache__/
6 | *.egg-info
7 | .eggs/
8 | .mypy_cache/*
9 | .pyre/
10 | .pytest_cache/
11 | .ruff_cache/
12 | .dmypy.json
13 |
14 | # Jupyter
15 | .ipynb_checkpoints/
16 |
17 | # Databases
18 | *.db
19 |
20 | # Logs
21 | rollouts/
22 |
23 | # Build artifacts
24 | build/
25 | dist/
26 | *.so
27 | out*/
28 | **/run_*/
29 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
# Makefile
#
# Developer task runner. Every target below is a command rather than a file,
# so each one is declared .PHONY right after its recipe.

# All Python sources, skipping training run directories and build output.
py-files := $(shell find . -name '*.py' -not -path "*/run_*/*" -not -path "*/build/*")

# Install (and eagerly upgrade) the dependencies listed in requirements.txt.
install:
	@pip install --upgrade --upgrade-strategy eager -r requirements.txt
.PHONY: install

# Install the tools used by the `format` and `static-checks` targets.
install-dev:
	@pip install ruff mypy
.PHONY: install-dev

# Reformat the sources and apply ruff's automatic fixes.
format:
	@ruff format $(py-files)
	@ruff check --fix $(py-files)
.PHONY: format

# Lint and type-check the sources.
static-checks:
	@mkdir -p .mypy_cache
	@ruff check $(py-files)
	@mypy --install-types --non-interactive $(py-files)
# The declaration must name this target; it previously named a non-existent
# `lint` target, which left `static-checks` undeclared.
.PHONY: static-checks

# Serve Jupyter on all interfaces, port 8888, without opening a browser.
notebook:
	jupyter notebook --ip=0.0.0.0 --port=8888 --no-browser
.PHONY: notebook
27 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 K-Scale Labs
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Python Checks
2 |
3 | on:
4 | push:
5 | branches:
6 | - master
7 | pull_request:
8 | branches:
9 | - master
10 | types:
11 | - opened
12 | - reopened
13 | - synchronize
14 | - ready_for_review
15 |
16 | concurrency:
17 | group: tests-${{ github.head_ref || github.run_id }}
18 | cancel-in-progress: true
19 |
20 | jobs:
21 | run-base-tests:
22 | timeout-minutes: 10
23 | runs-on: ubuntu-latest
24 | steps:
25 | - name: Check out repository
26 | uses: actions/checkout@v4
27 |
28 | - name: Set up Python
29 | uses: actions/setup-python@v4
30 | with:
31 | python-version: "3.12"
32 |
33 | - name: Restore cache
34 | id: restore-cache
35 | uses: actions/cache/restore@v3
36 | with:
37 | path: |
38 | ${{ env.pythonLocation }}
39 | .mypy_cache/
40 | key: python-requirements-${{ env.pythonLocation }}-${{ github.event.pull_request.base.sha || github.sha }}
41 | restore-keys: |
42 | python-requirements-${{ env.pythonLocation }}
43 | python-requirements-
44 |
45 | - name: Install package
46 | run: |
47 | make install
48 | make install-dev
49 |
50 | - name: Run static checks
51 | run: |
52 | make static-checks
53 |
54 | - name: Save cache
55 | uses: actions/cache/save@v3
56 | if: github.ref == 'refs/heads/master'
57 | with:
58 | path: |
59 | ${{ env.pythonLocation }}
60 | .mypy_cache/
61 | key: ${{ steps.restore-cache.outputs.cache-primary-key }}
62 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.pytest.ini_options]
2 |
3 | addopts = "-rx -rf -x -q --full-trace"
4 | testpaths = ["tests"]
5 |
6 | markers = [
7 | "slow: Marks test as being slow",
8 | ]
9 |
10 | [tool.mypy]
11 |
12 | pretty = true
13 | show_column_numbers = true
14 | show_error_context = true
15 | show_error_codes = true
16 | show_traceback = true
17 | disallow_untyped_defs = true
18 | strict_equality = true
19 | allow_redefinition = true
20 |
21 | warn_unused_ignores = true
22 | warn_redundant_casts = true
23 |
24 | incremental = true
25 | namespace_packages = false
26 |
27 | [[tool.mypy.overrides]]
28 |
29 | module = [
30 | "_pytest.*",
31 | "distrax.*",
32 | "equinox.*",
33 | "glfw.*",
34 | "mujoco.*",
35 | "optax.*",
36 | "scipy.*",
37 | "bvhio.*",
38 | "glm.*",
39 | "imageio.*",
40 | "tensorflow.*"
41 | ]
42 | ignore_missing_imports = true
43 |
44 | [tool.isort]
45 |
46 | profile = "black"
47 |
[tool.ruff]

line-length = 120
# train.py imports `typing.Self`, which is only available on Python 3.11+.
target-version = "py311"
52 |
53 | [tool.ruff.format]
54 |
55 | quote-style = "double"
56 | docstring-code-format = true
57 |
58 | [tool.ruff.lint]
59 |
60 | select = ["ANN", "D", "E", "F", "G", "I", "N", "PGH", "PLC", "PLE", "PLR", "PLW", "W"]
61 |
62 | ignore = [
63 | "D101", "D102", "D103", "D104", "D105", "D106", "D107",
64 | "N812", "N817",
65 | "PLR0911", "PLR0912", "PLR0913", "PLR0915", "PLR2004",
66 | "PLW0603", "PLW2901",
67 | ]
68 |
69 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
70 |
71 | [tool.ruff.lint.per-file-ignores]
72 |
73 | "__init__.py" = ["E402", "F401", "F403", "F811"]
74 |
75 | [tool.ruff.lint.isort]
76 |
77 | known-first-party = ["benchmark", "tests"]
78 | combine-as-imports = true
79 |
80 | [tool.ruff.lint.pydocstyle]
81 |
82 | convention = "google"
83 |
--------------------------------------------------------------------------------
/convert.py:
--------------------------------------------------------------------------------
1 | """Converts a checkpoint to a deployable model."""
2 |
3 | import argparse
4 | from pathlib import Path
5 |
6 | import jax
7 | import jax.numpy as jnp
8 | import ksim
9 | from jaxtyping import Array
10 | from kinfer.export.jax import export_fn
11 | from kinfer.export.serialize import pack
12 | from kinfer.rust_bindings import PyModelMetadata
13 |
14 | from train import HumanoidWalkingTask, Model
15 |
16 |
def main() -> None:
    """Exports a trained checkpoint as a packed ``kinfer`` model file.

    Command-line arguments:
        checkpoint_path: Path of the training checkpoint to load.
        output_path: Destination file; missing parent directories are created.

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint_path", type=str)
    parser.add_argument("output_path", type=str)
    args = parser.parse_args()

    if not (ckpt_path := Path(args.checkpoint_path)).exists():
        raise FileNotFoundError(f"Checkpoint path {ckpt_path} does not exist")

    task: HumanoidWalkingTask = HumanoidWalkingTask.load_task(ckpt_path)
    model: Model = task.load_ckpt(ckpt_path, part="model")[0]

    # Loads the Mujoco model and gets the joint names.
    mujoco_model = task.get_mujoco_model()
    joint_names = ksim.get_joint_names_in_order(mujoco_model)[1:]  # Removes the root joint.

    # Constant values.
    carry_shape = (task.config.depth, task.config.hidden_size)

    # The actor's input width depends on this flag, so the exported
    # observation vector must be assembled the same way as during training.
    use_acc_gyro = task.config.use_acc_gyro

    metadata = PyModelMetadata(
        joint_names=joint_names,
        num_commands=None,
        carry_size=carry_shape,
    )

    @jax.jit
    def init_fn() -> Array:
        # The initial recurrent state is all zeros.
        return jnp.zeros(carry_shape)

    @jax.jit
    def step_fn(
        joint_angles: Array,
        joint_angular_velocities: Array,
        projected_gravity: Array,
        accelerometer: Array,
        gyroscope: Array,
        time: Array,
        carry: Array,
    ) -> tuple[Array, Array]:
        obs_parts = [
            jnp.sin(time),
            jnp.cos(time),
            joint_angles,
            joint_angular_velocities,
            projected_gravity,
        ]
        # Only feed the raw IMU signals to models that were trained with them;
        # otherwise the observation is wider than the actor's input layer.
        # NOTE(review): the signature still lists both IMU inputs when they are
        # unused - confirm the exporter tolerates unused inputs.
        if use_acc_gyro:
            obs_parts += [accelerometer, gyroscope]

        obs = jnp.concatenate(obs_parts, axis=-1)
        dist, carry = model.actor.forward(obs, carry)
        # Deployment is deterministic: take the distribution's mode.
        return dist.mode(), carry

    init_onnx = export_fn(
        model=init_fn,
        metadata=metadata,
    )

    step_onnx = export_fn(
        model=step_fn,
        metadata=metadata,
    )

    kinfer_model = pack(
        init_fn=init_onnx,
        step_fn=step_onnx,
        metadata=metadata,
    )

    # Saves the resulting model.
    (output_path := Path(args.output_path)).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "wb") as f:
        f.write(kinfer_model)
91 |
92 |
93 | if __name__ == "__main__":
94 | main()
95 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # K-Sim Gym
3 | Train and deploy your own humanoid robot controller in 700 lines of Python
4 |
12 |
13 | https://github.com/user-attachments/assets/82e5e998-1d62-43e2-ae52-864af6e72629
14 |
15 |
16 |
17 | ## Getting Started
18 |
19 | You can use this repository as a GitHub template or as a Google Colab.
20 |
21 | ### Google Colab
22 |
23 | You can quickly try out the humanoid benchmark by running the [training notebook](https://colab.research.google.com/github/kscalelabs/ksim-gym/blob/master/train.ipynb) in Google Colab.
24 |
25 | ### On your own GPU
26 |
27 | 1. Read through the [current leaderboard](https://www.kscale.dev/benchmarks) submissions and through the [ksim examples](https://github.com/kscalelabs/ksim/tree/master/examples)
28 | 2. Create a new repository from this template by clicking [here](https://github.com/new?template_name=ksim-gym&template_owner=kscalelabs)
29 | 3. Clone the new repository you create from this template:
30 |
31 | ```bash
32 | git clone git@github.com:<YOUR_USERNAME>/ksim-gym.git
33 | cd ksim-gym
34 | ```
35 |
36 | 4. Create a new Python environment (we require Python 3.11 or later and recommend using [conda](https://docs.conda.io/projects/conda/en/stable/user-guide/getting-started.html))
37 | 5. Install the package with its dependencies:
38 |
39 | ```bash
40 | pip install -r requirements.txt
41 | pip install 'jax[cuda12]' # If using GPU machine, install JAX CUDA libraries
42 | python -c "import jax; print(jax.default_backend())" # Should print "gpu"
43 | ```
44 |
45 | 6. Train a policy:
46 | - Your robot should be walking within ~80 training steps, which takes 30 minutes on an RTX 4090 GPU.
47 | - Training runs indefinitely, unless you set the `max_steps` argument. You can also use `Ctrl+C` to stop it.
48 | - Click on the TensorBoard link in the terminal to visualize the current run's training logs and videos.
49 | - List all the available arguments with `python -m train --help`.
50 | ```bash
51 | python -m train
52 | ```
53 | ```bash
54 | # You can override default arguments like this
55 | python -m train max_steps=100
56 | ```
57 | 7. To see the TensorBoard logs for all your runs:
58 | ```bash
59 | tensorboard --logdir humanoid_walking_task
60 | ```
61 | 8. To view your trained checkpoint in the interactive viewer:
62 | - Use the mouse to move the camera around
63 | - Hold `Ctrl` and double click to select a body on the robot, and then left or right click to apply forces to it.
64 | ```bash
65 | python -m train run_mode=view load_from_ckpt_path=humanoid_walking_task/run_<number>/checkpoints/ckpt.bin
66 | ```
67 |
68 | 9. Convert your trained checkpoint to a `kinfer` model, which can be deployed on a real robot:
69 |
70 | ```bash
71 | python -m convert /path/to/ckpt.bin /path/to/model.kinfer
72 | ```
73 |
74 | 10. Visualize the converted model in [`kinfer-sim`](https://docs.kscale.dev/docs/k-infer):
75 |
76 | ```bash
77 | kinfer-sim assets/model.kinfer kbot --start-height 1.2 --save-video video.mp4
78 | ```
79 |
80 | 11. Commit the K-Infer model and the recorded video to this repository
81 | 12. Push your code and model to your repository, and make sure the repository is public (you may need to use [Git LFS](https://git-lfs.com))
82 | 13. Write a message with a link to your repository on our [Discord](https://url.kscale.dev/discord) in the "【🧠】submissions" channel
83 | 14. Wait for one of us to run it on the real robot - this should take about a day, but if we are dragging our feet, please message us on Discord
84 | 15. Voila! Your name will now appear on our [leaderboard](https://url.kscale.dev/leaderboard)
85 |
86 | ## Troubleshooting
87 |
88 | If you encounter issues, please consult the [ksim documentation](https://docs.kscale.dev/docs/ksim#/) or reach out to us on [Discord](https://url.kscale.dev/discord).
89 |
90 | ## Tips and Tricks
91 |
92 | To see all the available command line arguments, use the command:
93 |
94 | ```bash
95 | python -m train --help
96 | ```
97 |
98 | To visualize running your model without using `kinfer-sim`, use the command:
99 |
100 | ```bash
101 | python -m train run_mode=view
102 | ```
103 |
104 | To see an example of a locomotion task with more complex reward tuning, see our [kbot-joystick](https://github.com/kscalelabs/kbot-joystick) task which was generated from this template. It also contains a pretrained checkpoint that you can initialize training from by running
105 |
106 | ```bash
107 | python -m train load_from_ckpt_path=assets/ckpt.bin
108 | ```
109 |
110 | You can also visualize the pre-trained model by combining these two commands:
111 |
112 | ```bash
113 | python -m train load_from_ckpt_path=assets/ckpt.bin run_mode=view
114 | ```
115 |
116 | If you want to use the Jupyter notebook and don't want to commit your training logs, we suggest using [pre-commit](https://pre-commit.com/) to clean the notebook before committing:
117 |
118 | ```bash
119 | pip install pre-commit
120 | pre-commit install
121 | ```
122 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """Defines simple task for training a walking policy for the default humanoid."""
2 |
3 | import asyncio
4 | import functools
5 | import math
6 | from dataclasses import dataclass
7 | from typing import Self
8 |
9 | import attrs
10 | import distrax
11 | import equinox as eqx
12 | import jax
13 | import jax.numpy as jnp
14 | import ksim
15 | import mujoco
16 | import mujoco_scenes
17 | import mujoco_scenes.mjcf
18 | import optax
19 | import xax
20 | from jaxtyping import Array, PRNGKeyArray
21 |
# Per-joint reference angles, written in degrees for readability. The order of
# this table is significant: it is the order of the neural network outputs.
_ZEROS_DEG: list[tuple[str, float]] = [
    ("dof_right_shoulder_pitch_03", 0.0),
    ("dof_right_shoulder_roll_03", -10.0),
    ("dof_right_shoulder_yaw_02", 0.0),
    ("dof_right_elbow_02", 90.0),
    ("dof_right_wrist_00", 0.0),
    ("dof_left_shoulder_pitch_03", 0.0),
    ("dof_left_shoulder_roll_03", 10.0),
    ("dof_left_shoulder_yaw_02", 0.0),
    ("dof_left_elbow_02", -90.0),
    ("dof_left_wrist_00", 0.0),
    ("dof_right_hip_pitch_04", -20.0),
    ("dof_right_hip_roll_03", -0.0),
    ("dof_right_hip_yaw_03", 0.0),
    ("dof_right_knee_04", -50.0),
    ("dof_right_ankle_02", 30.0),
    ("dof_left_hip_pitch_04", 20.0),
    ("dof_left_hip_roll_03", 0.0),
    ("dof_left_hip_yaw_03", 0.0),
    ("dof_left_knee_04", 50.0),
    ("dof_left_ankle_02", -30.0),
]

# These are in the order of the neural network outputs (angles in radians).
ZEROS: list[tuple[str, float]] = [(name, math.radians(degrees)) for name, degrees in _ZEROS_DEG]
45 |
46 |
@dataclass
class HumanoidWalkingTaskConfig(ksim.PPOConfig):
    """PPO configuration extended with this task's model and optimizer settings."""

    # --- Model parameters ---
    # NOTE(review): the first two help strings say "MLPs", but the Actor and
    # Critic in this file are built from stacks of GRU cells.
    hidden_size: int = xax.field(value=128, help="The hidden size for the MLPs.")
    depth: int = xax.field(value=5, help="The depth for the MLPs.")
    num_mixtures: int = xax.field(value=5, help="The number of mixtures for the actor.")
    var_scale: float = xax.field(value=0.5, help="The scale for the standard deviations of the actor.")
    use_acc_gyro: bool = xax.field(value=True, help="Whether to use the IMU acceleration and gyroscope observations.")

    # --- Optimizer parameters ---
    learning_rate: float = xax.field(value=3e-4, help="Learning rate for PPO.")
    adam_weight_decay: float = xax.field(value=1e-5, help="Weight decay for the Adam optimizer.")
82 |
83 |
@attrs.define(frozen=True, kw_only=True)
class JointPositionPenalty(ksim.JointDeviationPenalty):
    """Joint deviation penalty whose targets are looked up in ``ZEROS``."""

    @classmethod
    def create_from_names(
        cls,
        names: list[str],
        physics_model: ksim.PhysicsModel,
        scale: float = -1.0,
        scale_by_curriculum: bool = False,
    ) -> Self:
        """Builds the penalty for the given joints.

        Args:
            names: Joint names to penalize; each must appear in ``ZEROS``.
            physics_model: Physics model forwarded to ``cls.create``.
            scale: Scale forwarded to ``cls.create``.
            scale_by_curriculum: Flag forwarded to ``cls.create``.

        Returns:
            The penalty instance produced by ``cls.create``.

        Raises:
            KeyError: If a name is not listed in ``ZEROS``.
        """
        target_by_name = dict(ZEROS)
        targets = tuple(target_by_name[joint] for joint in names)

        return cls.create(
            physics_model=physics_model,
            joint_names=tuple(names),
            joint_targets=targets,
            scale=scale,
            scale_by_curriculum=scale_by_curriculum,
        )
104 |
105 |
@attrs.define(frozen=True, kw_only=True)
class BentArmPenalty(JointPositionPenalty):
    """Joint position penalty applied to the ten arm joints."""

    @classmethod
    def create_penalty(
        cls,
        physics_model: ksim.PhysicsModel,
        scale: float = -1.0,
        scale_by_curriculum: bool = False,
    ) -> Self:
        """Builds the penalty over the shoulder, elbow and wrist joints of both arms."""
        arm_joints = [
            "dof_right_shoulder_pitch_03",
            "dof_right_shoulder_roll_03",
            "dof_right_shoulder_yaw_02",
            "dof_right_elbow_02",
            "dof_right_wrist_00",
            "dof_left_shoulder_pitch_03",
            "dof_left_shoulder_roll_03",
            "dof_left_shoulder_yaw_02",
            "dof_left_elbow_02",
            "dof_left_wrist_00",
        ]
        return cls.create_from_names(
            names=arm_joints,
            physics_model=physics_model,
            scale=scale,
            scale_by_curriculum=scale_by_curriculum,
        )
132 |
133 |
@attrs.define(frozen=True, kw_only=True)
class StraightLegPenalty(JointPositionPenalty):
    """Joint position penalty applied to the hip roll and hip yaw joints."""

    @classmethod
    def create_penalty(
        cls,
        physics_model: ksim.PhysicsModel,
        scale: float = -1.0,
        scale_by_curriculum: bool = False,
    ) -> Self:
        """Builds the penalty over the hip roll and yaw joints of both legs."""
        hip_joints = [
            "dof_left_hip_roll_03",
            "dof_left_hip_yaw_03",
            "dof_right_hip_roll_03",
            "dof_right_hip_yaw_03",
        ]
        return cls.create_from_names(
            names=hip_joints,
            physics_model=physics_model,
            scale=scale,
            scale_by_curriculum=scale_by_curriculum,
        )
154 |
155 |
class Actor(eqx.Module):
    """Actor for the walking task.

    A linear input projection feeds a stack of GRU cells; a linear head then
    produces the means, standard deviations and logits of a mixture of
    Gaussians for every output.
    """

    input_proj: eqx.nn.Linear
    rnns: tuple[eqx.nn.GRUCell, ...]
    output_proj: eqx.nn.Linear
    # `eqx.field(static=True)` replaces the deprecated `eqx.static_field()`.
    num_inputs: int = eqx.field(static=True)
    num_outputs: int = eqx.field(static=True)
    num_mixtures: int = eqx.field(static=True)
    min_std: float = eqx.field(static=True)
    max_std: float = eqx.field(static=True)
    var_scale: float = eqx.field(static=True)

    def __init__(
        self,
        key: PRNGKeyArray,
        *,
        num_inputs: int,
        num_outputs: int,
        min_std: float,
        max_std: float,
        var_scale: float,
        hidden_size: int,
        num_mixtures: int,
        depth: int,
    ) -> None:
        """Initializes the actor.

        Args:
            key: PRNG key used to initialize all layers.
            num_inputs: Size of the observation vector.
            num_outputs: Number of action dimensions.
            min_std: Offset added to the softplus of the raw standard deviations.
            max_std: Upper clip applied to the standard deviations.
            var_scale: Multiplier applied to the standard deviations before clipping.
            hidden_size: Width of the projection and of every GRU cell.
            num_mixtures: Number of mixture components per output.
            depth: Number of stacked GRU cells.
        """
        # Project input to hidden size
        key, input_proj_key = jax.random.split(key)
        self.input_proj = eqx.nn.Linear(
            in_features=num_inputs,
            out_features=hidden_size,
            key=input_proj_key,
        )

        # Create RNN layers, one independent key per layer.
        key, rnn_key = jax.random.split(key)
        layer_keys = jax.random.split(rnn_key, depth)
        self.rnns = tuple(
            eqx.nn.GRUCell(
                input_size=hidden_size,
                hidden_size=hidden_size,
                key=layer_key,
            )
            for layer_key in layer_keys
        )

        # Project to output: mean, std and logit for every (output, mixture) pair.
        self.output_proj = eqx.nn.Linear(
            in_features=hidden_size,
            out_features=num_outputs * 3 * num_mixtures,
            key=key,
        )

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.num_mixtures = num_mixtures
        self.min_std = min_std
        self.max_std = max_std
        self.var_scale = var_scale

    def forward(self, obs_n: Array, carry: Array) -> tuple[distrax.Distribution, Array]:
        """Runs one step of the actor.

        Args:
            obs_n: Observation vector for a single step.
            carry: Hidden states of the GRU layers, indexed by layer along the
                first axis.

        Returns:
            The action distribution and the updated hidden states, stacked
            along the first axis in layer order.
        """
        x_n = self.input_proj(obs_n)
        out_carries = []
        for i, rnn in enumerate(self.rnns):
            # Each layer's new hidden state is also the next layer's input.
            x_n = rnn(x_n, carry[i])
            out_carries.append(x_n)
        out_n = self.output_proj(x_n)

        # Reshape the output to be a mixture of gaussians.
        slice_len = self.num_outputs * self.num_mixtures
        mean_nm = out_n[..., :slice_len].reshape(self.num_outputs, self.num_mixtures)
        std_nm = out_n[..., slice_len : slice_len * 2].reshape(self.num_outputs, self.num_mixtures)
        logits_nm = out_n[..., slice_len * 2 :].reshape(self.num_outputs, self.num_mixtures)

        # Softplus and clip to ensure positive standard deviations.
        std_nm = jnp.clip((jax.nn.softplus(std_nm) + self.min_std) * self.var_scale, max=self.max_std)

        # Apply bias to the means: every mixture of output i is offset by the
        # i-th reference angle in ZEROS, so num_outputs must equal len(ZEROS).
        mean_nm = mean_nm + jnp.array([v for _, v in ZEROS])[:, None]

        dist_n = ksim.MixtureOfGaussians(means_nm=mean_nm, stds_nm=std_nm, logits_nm=logits_nm)

        return dist_n, jnp.stack(out_carries, axis=0)
241 |
242 |
class Critic(eqx.Module):
    """Critic for the walking task.

    A linear input projection feeds a stack of GRU cells; a linear head then
    produces a single value estimate.
    """

    input_proj: eqx.nn.Linear
    rnns: tuple[eqx.nn.GRUCell, ...]
    output_proj: eqx.nn.Linear
    # `eqx.field(static=True)` replaces the deprecated `eqx.static_field()`.
    num_inputs: int = eqx.field(static=True)

    def __init__(
        self,
        key: PRNGKeyArray,
        *,
        num_inputs: int,
        hidden_size: int,
        depth: int,
    ) -> None:
        """Initializes the critic.

        Args:
            key: PRNG key used to initialize all layers.
            num_inputs: Size of the observation vector.
            hidden_size: Width of the projection and of every GRU cell.
            depth: Number of stacked GRU cells.
        """
        num_outputs = 1

        # Project input to hidden size
        key, input_proj_key = jax.random.split(key)
        self.input_proj = eqx.nn.Linear(
            in_features=num_inputs,
            out_features=hidden_size,
            key=input_proj_key,
        )

        # Create RNN layers, one independent key per layer.
        key, rnn_key = jax.random.split(key)
        layer_keys = jax.random.split(rnn_key, depth)
        self.rnns = tuple(
            eqx.nn.GRUCell(
                input_size=hidden_size,
                hidden_size=hidden_size,
                key=layer_key,
            )
            for layer_key in layer_keys
        )

        # Project to output
        self.output_proj = eqx.nn.Linear(
            in_features=hidden_size,
            out_features=num_outputs,
            key=key,
        )

        self.num_inputs = num_inputs

    def forward(self, obs_n: Array, carry: Array) -> tuple[Array, Array]:
        """Runs one step of the critic.

        Args:
            obs_n: Observation vector for a single step.
            carry: Hidden states of the GRU layers, indexed by layer along the
                first axis.

        Returns:
            The output of the value head (one element) and the updated hidden
            states, stacked along the first axis in layer order.
        """
        x_n = self.input_proj(obs_n)
        out_carries = []
        for i, rnn in enumerate(self.rnns):
            # Each layer's new hidden state is also the next layer's input.
            x_n = rnn(x_n, carry[i])
            out_carries.append(x_n)
        out_n = self.output_proj(x_n)

        return out_n, jnp.stack(out_carries, axis=0)
301 |
302 |
class Model(eqx.Module):
    """Container holding the actor and critic networks."""

    actor: Actor
    critic: Critic

    def __init__(
        self,
        key: PRNGKeyArray,
        *,
        num_actor_inputs: int,
        num_actor_outputs: int,
        num_critic_inputs: int,
        min_std: float,
        max_std: float,
        var_scale: float,
        hidden_size: int,
        num_mixtures: int,
        depth: int,
    ) -> None:
        """Builds both networks; ``key`` is split so each gets its own PRNG key.

        ``hidden_size`` and ``depth`` are shared by the two networks; the
        remaining keyword arguments are forwarded to the actor only, except
        ``num_critic_inputs`` which sizes the critic's input.
        """
        actor_key, critic_key = jax.random.split(key)

        shared_kwargs = {"hidden_size": hidden_size, "depth": depth}

        self.actor = Actor(
            actor_key,
            num_inputs=num_actor_inputs,
            num_outputs=num_actor_outputs,
            min_std=min_std,
            max_std=max_std,
            var_scale=var_scale,
            num_mixtures=num_mixtures,
            **shared_kwargs,
        )
        self.critic = Critic(
            critic_key,
            num_inputs=num_critic_inputs,
            **shared_kwargs,
        )
339 |
340 |
341 | class HumanoidWalkingTask(ksim.PPOTask[HumanoidWalkingTaskConfig]):
def get_optimizer(self) -> optax.GradientTransformation:
    """Returns plain Adam when the configured weight decay is zero, AdamW otherwise."""
    weight_decay = self.config.adam_weight_decay
    if weight_decay == 0.0:
        return optax.adam(self.config.learning_rate)
    return optax.adamw(self.config.learning_rate, weight_decay=weight_decay)
348 |
def get_mujoco_model(self) -> mujoco.MjModel:
    """Resolves the "kbot" MJCF path through ksim and loads it with the "smooth" scene."""
    # The ksim lookup is a coroutine, so it is driven to completion here.
    path_coroutine = ksim.get_mujoco_model_path("kbot", name="robot")
    mjcf_path = asyncio.run(path_coroutine)
    return mujoco_scenes.mjcf.load_mjmodel(mjcf_path, scene="smooth")
352 |
def get_mujoco_model_metadata(self, mj_model: mujoco.MjModel) -> ksim.Metadata:
    """Fetches the "kbot" metadata and checks that joint and actuator data are present.

    Args:
        mj_model: Not used by this implementation.

    Raises:
        ValueError: If the joint or the actuator metadata is missing.
    """
    metadata = asyncio.run(ksim.get_mujoco_model_metadata("kbot"))

    # Checked in this order so the joint error takes precedence.
    required_sections = (
        (metadata.joint_name_to_metadata, "Joint metadata is not available"),
        (metadata.actuator_type_to_metadata, "Actuator metadata is not available"),
    )
    for section, message in required_sections:
        if section is None:
            raise ValueError(message)

    return metadata
360 |
def get_actuators(
    self,
    physics_model: ksim.PhysicsModel,
    metadata: ksim.Metadata | None = None,
) -> ksim.Actuators:
    """Creates position actuators for the robot.

    Although ``metadata`` defaults to ``None``, a value is mandatory: the
    method asserts that it was provided.
    """
    assert metadata is not None, "Metadata is required"
    actuators = ksim.PositionActuators(physics_model=physics_model, metadata=metadata)
    return actuators
371 |
def get_physics_randomizers(self, physics_model: ksim.PhysicsModel) -> list[ksim.PhysicsRandomizer]:
    """Lists the physics randomizers used by this task (``physics_model`` is unused)."""
    # Joint zero positions are perturbed by up to two degrees either way.
    zero_position_bound = math.radians(2)

    randomizers: list[ksim.PhysicsRandomizer] = [
        ksim.StaticFrictionRandomizer(),
        ksim.ArmatureRandomizer(),
        ksim.AllBodiesMassMultiplicationRandomizer(scale_lower=0.95, scale_upper=1.05),
        ksim.JointDampingRandomizer(),
        ksim.JointZeroPositionRandomizer(scale_lower=-zero_position_bound, scale_upper=zero_position_bound),
    ]
    return randomizers
380 |
def get_events(self, physics_model: ksim.PhysicsModel) -> list[ksim.Event]:
    """Returns the single push event used by this task (``physics_model`` is unused)."""
    # Linear components only: all three angular force components are zero.
    push = ksim.PushEvent(
        x_force=1.0,
        y_force=1.0,
        z_force=0.3,
        force_range=(0.5, 1.0),
        x_angular_force=0.0,
        y_angular_force=0.0,
        z_angular_force=0.0,
        interval_range=(0.5, 4.0),
    )
    return [push]
394 |
def get_resets(self, physics_model: ksim.PhysicsModel) -> list[ksim.Reset]:
    """Returns the joint position reset (built around ``ZEROS``) and the joint velocity reset."""
    reference_angles = dict(ZEROS)
    position_reset = ksim.RandomJointPositionReset.create(physics_model, reference_angles, scale=0.1)
    velocity_reset = ksim.RandomJointVelocityReset()
    return [position_reset, velocity_reset]
400 |
def get_observations(self, physics_model: ksim.PhysicsModel) -> list[ksim.Observation]:
    """Lists every observation the task produces, in a fixed order.

    The first group takes no sensor configuration; the second group is
    created from named entries of ``physics_model``.
    """
    state_observations: list[ksim.Observation] = [
        ksim.TimestepObservation(),
        ksim.JointPositionObservation(noise=math.radians(2)),
        ksim.JointVelocityObservation(noise=math.radians(10)),
        ksim.ActuatorForceObservation(),
        ksim.CenterOfMassInertiaObservation(),
        ksim.CenterOfMassVelocityObservation(),
        ksim.BasePositionObservation(),
        ksim.BaseOrientationObservation(),
        ksim.BaseLinearVelocityObservation(),
        ksim.BaseAngularVelocityObservation(),
        ksim.BaseLinearAccelerationObservation(),
        ksim.BaseAngularAccelerationObservation(),
        ksim.ActuatorAccelerationObservation(),
    ]

    imu_observations: list[ksim.Observation] = [
        ksim.ProjectedGravityObservation.create(
            physics_model=physics_model,
            framequat_name="imu_site_quat",
            lag_range=(0.0, 0.1),
            noise=math.radians(1),
        ),
        ksim.SensorObservation.create(physics_model=physics_model, sensor_name="imu_acc", noise=1.0),
        ksim.SensorObservation.create(physics_model=physics_model, sensor_name="imu_gyro", noise=math.radians(10)),
    ]

    return [*state_observations, *imu_observations]
433 |
def get_commands(self, physics_model: ksim.PhysicsModel) -> list[ksim.Command]:
    """Returns an empty list: this task defines no commands."""
    no_commands: list[ksim.Command] = []
    return no_commands
436 |
def get_rewards(self, physics_model: ksim.PhysicsModel) -> list[ksim.Reward]:
    """Assembles the reward terms in four groups, returned in this order.

    Positive scales are rewards; negative scales are penalties.
    """
    # Standard rewards.
    standard: list[ksim.Reward] = [
        ksim.NaiveForwardReward(clip_max=1.25, in_robot_frame=False, scale=3.0),
        ksim.NaiveForwardOrientationReward(scale=1.0),
        ksim.StayAliveReward(scale=1.0),
        ksim.UprightReward(scale=0.5),
    ]

    # Avoid movement penalties.
    movement: list[ksim.Reward] = [
        ksim.AngularVelocityPenalty(index=("x", "y"), scale=-0.1),
        # NOTE(review): this index is the plain string "z", unlike the
        # ("x", "y") tuple above - confirm the API accepts both forms.
        ksim.LinearVelocityPenalty(index="z", scale=-0.1),
    ]

    # Normalization penalties.
    normalization: list[ksim.Reward] = [
        ksim.AvoidLimitsPenalty.create(physics_model, scale=-0.01),
        ksim.JointAccelerationPenalty(scale=-0.01, scale_by_curriculum=True),
        ksim.JointJerkPenalty(scale=-0.01, scale_by_curriculum=True),
        ksim.LinkAccelerationPenalty(scale=-0.01, scale_by_curriculum=True),
        ksim.LinkJerkPenalty(scale=-0.01, scale_by_curriculum=True),
        ksim.ActionAccelerationPenalty(scale=-0.01, scale_by_curriculum=True),
    ]

    # Bespoke rewards.
    bespoke: list[ksim.Reward] = [
        BentArmPenalty.create_penalty(physics_model, scale=-0.1),
        StraightLegPenalty.create_penalty(physics_model, scale=-0.1),
    ]

    return [*standard, *movement, *normalization, *bespoke]
458 |
def get_terminations(self, physics_model: ksim.PhysicsModel) -> list[ksim.Termination]:
    """Returns the height-based and distance-based terminations (``physics_model`` is unused)."""
    height_termination = ksim.BadZTermination(unhealthy_z_lower=0.6, unhealthy_z_upper=1.2)
    distance_termination = ksim.FarFromOriginTermination(max_dist=10.0)
    return [height_termination, distance_termination]
464 |
def get_curriculum(self, physics_model: ksim.PhysicsModel) -> ksim.Curriculum:
    """Returns a distance-from-origin curriculum with ``min_level_steps=5``."""
    curriculum = ksim.DistanceFromOriginCurriculum(min_level_steps=5)
    return curriculum
469 |
def get_model(self, key: PRNGKeyArray) -> Model:
    """Creates the actor-critic model sized for this task's observation vectors."""
    cfg = self.config

    # These widths must equal the lengths of the vectors concatenated in
    # run_actor and run_critic. The actor only receives the two 3-element
    # IMU signals when use_acc_gyro is set, hence the difference of six.
    actor_inputs = 51 if cfg.use_acc_gyro else 45
    critic_inputs = 446

    return Model(
        key,
        num_actor_inputs=actor_inputs,
        num_actor_outputs=len(ZEROS),
        num_critic_inputs=critic_inputs,
        min_std=0.001,
        max_std=1.0,
        var_scale=cfg.var_scale,
        hidden_size=cfg.hidden_size,
        num_mixtures=cfg.num_mixtures,
        depth=cfg.depth,
    )
483 |
def run_actor(
    self,
    model: Actor,
    observations: xax.FrozenDict[str, Array],
    commands: xax.FrozenDict[str, Array],
    carry: Array,
) -> tuple[distrax.Distribution, Array]:
    """Builds the actor's observation vector and runs one actor step.

    Args:
        model: The actor network.
        observations: Observations keyed by name.
        commands: Not used by this implementation.
        carry: The actor's recurrent state.

    Returns:
        The action distribution and the actor's next recurrent state.
    """
    timestep = observations["timestep_observation"]
    joint_pos = observations["joint_position_observation"]
    joint_vel = observations["joint_velocity_observation"]
    gravity = observations["projected_gravity_observation"]
    # Both IMU entries are looked up even when the config leaves them out.
    accel = observations["sensor_observation_imu_acc"]
    gyro = observations["sensor_observation_imu_gyro"]

    # The timestep enters as a sin/cos pair.
    features = [jnp.sin(timestep), jnp.cos(timestep), joint_pos, joint_vel, gravity]
    if self.config.use_acc_gyro:
        features.extend([accel, gyro])

    actor_input = jnp.concatenate(features, axis=-1)
    dist, next_carry = model.forward(actor_input, carry)
    return dist, next_carry
515 |
def run_critic(
    self,
    model: Critic,
    observations: xax.FrozenDict[str, Array],
    commands: xax.FrozenDict[str, Array],
    carry: Array,
) -> tuple[Array, Array]:
    """Builds the critic's observation vector and runs one critic step.

    Args:
        model: The critic network.
        observations: Observations keyed by name.
        commands: Not used by this implementation.
        carry: The critic's recurrent state.

    Returns:
        The critic output and the critic's next recurrent state.
    """
    timestep = observations["timestep_observation"]
    joint_pos = observations["joint_position_observation"]
    joint_vel = observations["joint_velocity_observation"]
    com_inertia = observations["center_of_mass_inertia_observation"]
    com_vel = observations["center_of_mass_velocity_observation"]
    accel = observations["sensor_observation_imu_acc"]
    gyro = observations["sensor_observation_imu_gyro"]
    gravity = observations["projected_gravity_observation"]
    actuator_force = observations["actuator_force_observation"]
    base_pos = observations["base_position_observation"]
    base_quat = observations["base_orientation_observation"]

    # NOTE(review): get_model sizes the critic input at 446. Assuming a scalar
    # timestep and 20 joints, every entry except the two centre-of-mass ones
    # sums to 78, which leaves 368 for com_inertia + com_vel - confirm.
    features = [
        jnp.sin(timestep),
        jnp.cos(timestep),
        joint_pos,
        joint_vel / 10.0,  # Scaled down by a constant factor.
        com_inertia,
        com_vel,
        accel,
        gyro,
        gravity,
        actuator_force / 100.0,  # Scaled down by a constant factor.
        base_pos,
        base_quat,
    ]
    critic_input = jnp.concatenate(features, axis=-1)

    return model.forward(critic_input, carry)
554 |
555 | def _model_scan_fn(
556 | self,
557 | actor_critic_carry: tuple[Array, Array],
558 | xs: tuple[ksim.Trajectory, PRNGKeyArray],
559 | model: Model,
560 | ) -> tuple[tuple[Array, Array], ksim.PPOVariables]:
561 | transition, rng = xs
562 |
563 | actor_carry, critic_carry = actor_critic_carry
564 | actor_dist, next_actor_carry = self.run_actor(
565 | model=model.actor,
566 | observations=transition.obs,
567 | commands=transition.command,
568 | carry=actor_carry,
569 | )
570 |
571 | # Gets the log probabilities of the action.
572 | log_probs = actor_dist.log_prob(transition.action)
573 | assert isinstance(log_probs, Array)
574 |
575 | value, next_critic_carry = self.run_critic(
576 | model=model.critic,
577 | observations=transition.obs,
578 | commands=transition.command,
579 | carry=critic_carry,
580 | )
581 |
582 | transition_ppo_variables = ksim.PPOVariables(
583 | log_probs=log_probs,
584 | values=value.squeeze(-1),
585 | )
586 |
587 | next_carry = jax.tree.map(
588 | lambda x, y: jnp.where(transition.done, x, y),
589 | self.get_initial_model_carry(rng),
590 | (next_actor_carry, next_critic_carry),
591 | )
592 |
593 | return next_carry, transition_ppo_variables
594 |
595 | def get_ppo_variables(
596 | self,
597 | model: Model,
598 | trajectory: ksim.Trajectory,
599 | model_carry: tuple[Array, Array],
600 | rng: PRNGKeyArray,
601 | ) -> tuple[ksim.PPOVariables, tuple[Array, Array]]:
602 | scan_fn = functools.partial(self._model_scan_fn, model=model)
603 | next_model_carry, ppo_variables = xax.scan(
604 | scan_fn,
605 | model_carry,
606 | (trajectory, jax.random.split(rng, len(trajectory.done))),
607 | jit_level=4,
608 | )
609 | return ppo_variables, next_model_carry
610 |
611 | def get_initial_model_carry(self, rng: PRNGKeyArray) -> tuple[Array, Array]:
612 | return (
613 | jnp.zeros(shape=(self.config.depth, self.config.hidden_size)),
614 | jnp.zeros(shape=(self.config.depth, self.config.hidden_size)),
615 | )
616 |
617 | def sample_action(
618 | self,
619 | model: Model,
620 | model_carry: tuple[Array, Array],
621 | physics_model: ksim.PhysicsModel,
622 | physics_state: ksim.PhysicsState,
623 | observations: xax.FrozenDict[str, Array],
624 | commands: xax.FrozenDict[str, Array],
625 | rng: PRNGKeyArray,
626 | argmax: bool,
627 | ) -> ksim.Action:
628 | actor_carry_in, critic_carry_in = model_carry
629 | action_dist_j, actor_carry = self.run_actor(
630 | model=model.actor,
631 | observations=observations,
632 | commands=commands,
633 | carry=actor_carry_in,
634 | )
635 | action_j = action_dist_j.mode() if argmax else action_dist_j.sample(seed=rng)
636 | return ksim.Action(action=action_j, carry=(actor_carry, critic_carry_in))
637 |
638 |
if __name__ == "__main__":
    # Script entry point: build the run configuration first, then hand it to
    # the task's launcher.
    launch_config = HumanoidWalkingTaskConfig(
        # Training parameters.
        num_envs=2048,
        batch_size=256,
        num_passes=4,
        epochs_per_log_step=1,
        rollout_length_seconds=8.0,
        global_grad_clip=2.0,
        # Simulation parameters.
        dt=0.002,
        ctrl_dt=0.02,
        iterations=8,
        ls_iterations=8,
        action_latency_range=(0.003, 0.01),  # Simulate 3-10ms of latency.
        drop_action_prob=0.05,  # Drop 5% of commands.
        # Visualization parameters.
        render_track_body_id=0,
        # Checkpointing parameters.
        save_every_n_seconds=60,
    )
    HumanoidWalkingTask.launch(launch_config)
662 |
--------------------------------------------------------------------------------
/train.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "0",
6 | "metadata": {
7 | "id": "0"
8 | },
9 | "source": [
10 | "# K-Scale Humanoid Benchmark\n",
11 | "\n",
12 | "Welcome to the K-Scale Humanoid Benchmark! This notebook will walk you through training your own reinforcement learning policy, which you can then use to control a K-Scale robot.\n",
13 | "\n",
14 | "*Note:* The Just-In-Time compilation may take a while and cause your Colab instance to appear to disconnect. However, your training cell may actually still be running. Make sure to check before restarting!\n",
15 | "\n",
16 | "*Last updated: 2025/05/15*"
17 | ]
18 | },
19 | {
20 | "cell_type": "markdown",
21 | "id": "1",
22 | "metadata": {
23 | "id": "GcUyV2BhHCxS"
24 | },
25 | "source": [
26 | "## Dependencies and Config\n",
27 | "\n",
28 | "The K-Scale Humanoid Benchmark uses K-Scale's open-source RL framework [K-Sim](https://github.com/kscalelabs/ksim) for training and the [K-Scale API](https://github.com/kscalelabs/kscale) for asset management."
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": null,
34 | "id": "2",
35 | "metadata": {
36 | "colab": {
37 | "base_uri": "https://localhost:8080/"
38 | },
39 | "id": "X9GR-PWjgynB",
40 | "outputId": "8b5227f7-e06e-465e-ce65-ed19f75d1191"
41 | },
42 | "outputs": [],
43 | "source": [
44 | "# Install packages\n",
45 | "\n",
46 | "!pip install ksim==0.1.2 xax==0.3.0 mujoco-scenes"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": null,
52 | "id": "3",
53 | "metadata": {
54 | "colab": {
55 | "base_uri": "https://localhost:8080/"
56 | },
57 | "id": "19e07786",
58 | "outputId": "78e173ed-afd3-45fc-a518-2a4a2ba31b8f"
59 | },
60 | "outputs": [],
61 | "source": [
62 | "# Set up environment variables\n",
63 | "%env MUJOCO_GL=egl"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": null,
69 | "id": "4",
70 | "metadata": {
71 | "id": "1"
72 | },
73 | "outputs": [],
74 | "source": [
75 | "import asyncio\n",
76 | "import functools\n",
77 | "import math\n",
78 | "from dataclasses import dataclass\n",
79 | "from typing import Self\n",
80 | "\n",
81 | "import attrs\n",
82 | "import distrax\n",
83 | "import equinox as eqx\n",
84 | "import jax\n",
85 | "import jax.numpy as jnp\n",
86 | "import ksim\n",
87 | "import mujoco\n",
88 | "import mujoco_scenes\n",
89 | "import mujoco_scenes.mjcf\n",
90 | "import nest_asyncio\n",
91 | "import optax\n",
92 | "import xax\n",
93 | "from jaxtyping import Array, PRNGKeyArray\n",
94 | "\n",
95 | "nest_asyncio.apply()"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "id": "5",
102 | "metadata": {
103 | "id": "2"
104 | },
105 | "outputs": [],
106 | "source": [
107 | "# These are in the order of the neural network outputs.\n",
108 | "ZEROS: list[tuple[str, float]] = [\n",
109 | " (\"dof_right_shoulder_pitch_03\", 0.0),\n",
110 | " (\"dof_right_shoulder_roll_03\", math.radians(-10.0)),\n",
111 | " (\"dof_right_shoulder_yaw_02\", 0.0),\n",
112 | " (\"dof_right_elbow_02\", math.radians(90.0)),\n",
113 | " (\"dof_right_wrist_00\", 0.0),\n",
114 | " (\"dof_left_shoulder_pitch_03\", 0.0),\n",
115 | " (\"dof_left_shoulder_roll_03\", math.radians(10.0)),\n",
116 | " (\"dof_left_shoulder_yaw_02\", 0.0),\n",
117 | " (\"dof_left_elbow_02\", math.radians(-90.0)),\n",
118 | " (\"dof_left_wrist_00\", 0.0),\n",
119 | " (\"dof_right_hip_pitch_04\", math.radians(-20.0)),\n",
120 | " (\"dof_right_hip_roll_03\", math.radians(-0.0)),\n",
121 | " (\"dof_right_hip_yaw_03\", 0.0),\n",
122 | " (\"dof_right_knee_04\", math.radians(-50.0)),\n",
123 | " (\"dof_right_ankle_02\", math.radians(30.0)),\n",
124 | " (\"dof_left_hip_pitch_04\", math.radians(20.0)),\n",
125 | " (\"dof_left_hip_roll_03\", math.radians(0.0)),\n",
126 | " (\"dof_left_hip_yaw_03\", 0.0),\n",
127 | " (\"dof_left_knee_04\", math.radians(50.0)),\n",
128 | " (\"dof_left_ankle_02\", math.radians(-30.0)),\n",
129 | "]"
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "id": "6",
135 | "metadata": {
136 | "id": "3"
137 | },
138 | "source": [
139 | "## Rewards\n",
140 | "\n",
141 | "When training a reinforcement learning agent, the most important thing to define is what reward you want the agent to maximize. `ksim` includes a number of useful default rewards for training walking agents, but it is often a good idea to define new rewards to encourage specific types of behavior. The cell below shows an example of how to define a custom reward. A similar pattern can be used to define custom objectives, events, observations, and more."
142 | ]
143 | },
144 | {
145 | "cell_type": "code",
146 | "execution_count": null,
147 | "id": "7",
148 | "metadata": {
149 | "id": "4"
150 | },
151 | "outputs": [],
152 | "source": [
153 | "@attrs.define(frozen=True, kw_only=True)\n",
154 | "class JointPositionPenalty(ksim.JointDeviationPenalty):\n",
155 | " @classmethod\n",
156 | " def create_from_names(\n",
157 | " cls,\n",
158 | " names: list[str],\n",
159 | " physics_model: ksim.PhysicsModel,\n",
160 | " scale: float = -1.0,\n",
161 | " scale_by_curriculum: bool = False,\n",
162 | " ) -> Self:\n",
163 | " zeros = {k: v for k, v in ZEROS}\n",
164 | " joint_targets = [zeros[name] for name in names]\n",
165 | "\n",
166 | " return cls.create(\n",
167 | " physics_model=physics_model,\n",
168 | " joint_names=tuple(names),\n",
169 | " joint_targets=tuple(joint_targets),\n",
170 | " scale=scale,\n",
171 | " scale_by_curriculum=scale_by_curriculum,\n",
172 | " )\n",
173 | "\n",
174 | "\n",
175 | "@attrs.define(frozen=True, kw_only=True)\n",
176 | "class BentArmPenalty(JointPositionPenalty):\n",
177 | " @classmethod\n",
178 | " def create_penalty(\n",
179 | " cls,\n",
180 | " physics_model: ksim.PhysicsModel,\n",
181 | " scale: float = -1.0,\n",
182 | " scale_by_curriculum: bool = False,\n",
183 | " ) -> Self:\n",
184 | " return cls.create_from_names(\n",
185 | " names=[\n",
186 | " \"dof_right_shoulder_pitch_03\",\n",
187 | " \"dof_right_shoulder_roll_03\",\n",
188 | " \"dof_right_shoulder_yaw_02\",\n",
189 | " \"dof_right_elbow_02\",\n",
190 | " \"dof_right_wrist_00\",\n",
191 | " \"dof_left_shoulder_pitch_03\",\n",
192 | " \"dof_left_shoulder_roll_03\",\n",
193 | " \"dof_left_shoulder_yaw_02\",\n",
194 | " \"dof_left_elbow_02\",\n",
195 | " \"dof_left_wrist_00\",\n",
196 | " ],\n",
197 | " physics_model=physics_model,\n",
198 | " scale=scale,\n",
199 | " scale_by_curriculum=scale_by_curriculum,\n",
200 | " )\n",
201 | "\n",
202 | "\n",
203 | "@attrs.define(frozen=True, kw_only=True)\n",
204 | "class StraightLegPenalty(JointPositionPenalty):\n",
205 | " @classmethod\n",
206 | " def create_penalty(\n",
207 | " cls,\n",
208 | " physics_model: ksim.PhysicsModel,\n",
209 | " scale: float = -1.0,\n",
210 | " scale_by_curriculum: bool = False,\n",
211 | " ) -> Self:\n",
212 | " return cls.create_from_names(\n",
213 | " names=[\n",
214 | " \"dof_left_hip_roll_03\",\n",
215 | " \"dof_left_hip_yaw_03\",\n",
216 | " \"dof_right_hip_roll_03\",\n",
217 | " \"dof_right_hip_yaw_03\",\n",
218 | " ],\n",
219 | " physics_model=physics_model,\n",
220 | " scale=scale,\n",
221 | " scale_by_curriculum=scale_by_curriculum,\n",
222 | " )"
223 | ]
224 | },
225 | {
226 | "cell_type": "markdown",
227 | "id": "8",
228 | "metadata": {
229 | "id": "5"
230 | },
231 | "source": [
232 | "## Actor-Critic Model\n",
233 | "\n",
234 | "We train our reinforcement learning agent using an RNN-based actor and critic, which we define below."
235 | ]
236 | },
237 | {
238 | "cell_type": "code",
239 | "execution_count": null,
240 | "id": "9",
241 | "metadata": {
242 | "id": "6"
243 | },
244 | "outputs": [],
245 | "source": [
246 | "class Actor(eqx.Module):\n",
247 | " \"\"\"Actor for the walking task.\"\"\"\n",
248 | "\n",
249 | " input_proj: eqx.nn.Linear\n",
250 | " rnns: tuple[eqx.nn.GRUCell, ...]\n",
251 | " output_proj: eqx.nn.Linear\n",
252 | " num_inputs: int = eqx.static_field()\n",
253 | " num_outputs: int = eqx.static_field()\n",
254 | " num_mixtures: int = eqx.static_field()\n",
255 | " min_std: float = eqx.static_field()\n",
256 | " max_std: float = eqx.static_field()\n",
257 | " var_scale: float = eqx.static_field()\n",
258 | "\n",
259 | " def __init__(\n",
260 | " self,\n",
261 | " key: PRNGKeyArray,\n",
262 | " *,\n",
263 | " num_inputs: int,\n",
264 | " num_outputs: int,\n",
265 | " min_std: float,\n",
266 | " max_std: float,\n",
267 | " var_scale: float,\n",
268 | " hidden_size: int,\n",
269 | " num_mixtures: int,\n",
270 | " depth: int,\n",
271 | " ) -> None:\n",
272 | " # Project input to hidden size\n",
273 | " key, input_proj_key = jax.random.split(key)\n",
274 | " self.input_proj = eqx.nn.Linear(\n",
275 | " in_features=num_inputs,\n",
276 | " out_features=hidden_size,\n",
277 | " key=input_proj_key,\n",
278 | " )\n",
279 | "\n",
280 | " # Create RNN layer\n",
281 | " key, rnn_key = jax.random.split(key)\n",
282 | " rnn_keys = jax.random.split(rnn_key, depth)\n",
283 | " self.rnns = tuple(\n",
284 | " [\n",
285 | " eqx.nn.GRUCell(\n",
286 | " input_size=hidden_size,\n",
287 | " hidden_size=hidden_size,\n",
288 | " key=rnn_key,\n",
289 | " )\n",
290 | " for rnn_key in rnn_keys\n",
291 | " ]\n",
292 | " )\n",
293 | "\n",
294 | " # Project to output\n",
295 | " self.output_proj = eqx.nn.Linear(\n",
296 | " in_features=hidden_size,\n",
297 | " out_features=num_outputs * 3 * num_mixtures,\n",
298 | " key=key,\n",
299 | " )\n",
300 | "\n",
301 | " self.num_inputs = num_inputs\n",
302 | " self.num_outputs = num_outputs\n",
303 | " self.num_mixtures = num_mixtures\n",
304 | " self.min_std = min_std\n",
305 | " self.max_std = max_std\n",
306 | " self.var_scale = var_scale\n",
307 | "\n",
308 | " def forward(self, obs_n: Array, carry: Array) -> tuple[distrax.Distribution, Array]:\n",
309 | " x_n = self.input_proj(obs_n)\n",
310 | " out_carries = []\n",
311 | " for i, rnn in enumerate(self.rnns):\n",
312 | " x_n = rnn(x_n, carry[i])\n",
313 | " out_carries.append(x_n)\n",
314 | " out_n = self.output_proj(x_n)\n",
315 | "\n",
316 | " # Reshape the output to be a mixture of gaussians.\n",
317 | " slice_len = self.num_outputs * self.num_mixtures\n",
318 | " mean_nm = out_n[..., :slice_len].reshape(self.num_outputs, self.num_mixtures)\n",
319 | " std_nm = out_n[..., slice_len : slice_len * 2].reshape(self.num_outputs, self.num_mixtures)\n",
320 | " logits_nm = out_n[..., slice_len * 2 :].reshape(self.num_outputs, self.num_mixtures)\n",
321 | "\n",
322 | " # Softplus and clip to ensure positive standard deviations.\n",
323 | " std_nm = jnp.clip((jax.nn.softplus(std_nm) + self.min_std) * self.var_scale, max=self.max_std)\n",
324 | "\n",
325 | " # Apply bias to the means.\n",
326 | " mean_nm = mean_nm + jnp.array([v for _, v in ZEROS])[:, None]\n",
327 | "\n",
328 | " dist_n = ksim.MixtureOfGaussians(means_nm=mean_nm, stds_nm=std_nm, logits_nm=logits_nm)\n",
329 | "\n",
330 | " return dist_n, jnp.stack(out_carries, axis=0)\n",
331 | "\n",
332 | "\n",
333 | "class Critic(eqx.Module):\n",
334 | " \"\"\"Critic for the walking task.\"\"\"\n",
335 | "\n",
336 | " input_proj: eqx.nn.Linear\n",
337 | " rnns: tuple[eqx.nn.GRUCell, ...]\n",
338 | " output_proj: eqx.nn.Linear\n",
339 | " num_inputs: int = eqx.static_field()\n",
340 | "\n",
341 | " def __init__(\n",
342 | " self,\n",
343 | " key: PRNGKeyArray,\n",
344 | " *,\n",
345 | " num_inputs: int,\n",
346 | " hidden_size: int,\n",
347 | " depth: int,\n",
348 | " ) -> None:\n",
349 | " num_outputs = 1\n",
350 | "\n",
351 | " # Project input to hidden size\n",
352 | " key, input_proj_key = jax.random.split(key)\n",
353 | " self.input_proj = eqx.nn.Linear(\n",
354 | " in_features=num_inputs,\n",
355 | " out_features=hidden_size,\n",
356 | " key=input_proj_key,\n",
357 | " )\n",
358 | "\n",
359 | " # Create RNN layer\n",
360 | " key, rnn_key = jax.random.split(key)\n",
361 | " rnn_keys = jax.random.split(rnn_key, depth)\n",
362 | " self.rnns = tuple(\n",
363 | " [\n",
364 | " eqx.nn.GRUCell(\n",
365 | " input_size=hidden_size,\n",
366 | " hidden_size=hidden_size,\n",
367 | " key=rnn_key,\n",
368 | " )\n",
369 | " for rnn_key in rnn_keys\n",
370 | " ]\n",
371 | " )\n",
372 | "\n",
373 | " # Project to output\n",
374 | " self.output_proj = eqx.nn.Linear(\n",
375 | " in_features=hidden_size,\n",
376 | " out_features=num_outputs,\n",
377 | " key=key,\n",
378 | " )\n",
379 | "\n",
380 | " self.num_inputs = num_inputs\n",
381 | "\n",
382 | " def forward(self, obs_n: Array, carry: Array) -> tuple[Array, Array]:\n",
383 | " x_n = self.input_proj(obs_n)\n",
384 | " out_carries = []\n",
385 | " for i, rnn in enumerate(self.rnns):\n",
386 | " x_n = rnn(x_n, carry[i])\n",
387 | " out_carries.append(x_n)\n",
388 | " out_n = self.output_proj(x_n)\n",
389 | "\n",
390 | " return out_n, jnp.stack(out_carries, axis=0)\n",
391 | "\n",
392 | "\n",
393 | "class Model(eqx.Module):\n",
394 | " actor: Actor\n",
395 | " critic: Critic\n",
396 | "\n",
397 | " def __init__(\n",
398 | " self,\n",
399 | " key: PRNGKeyArray,\n",
400 | " *,\n",
401 | " num_actor_inputs: int,\n",
402 | " num_actor_outputs: int,\n",
403 | " num_critic_inputs: int,\n",
404 | " min_std: float,\n",
405 | " max_std: float,\n",
406 | " var_scale: float,\n",
407 | " hidden_size: int,\n",
408 | " num_mixtures: int,\n",
409 | " depth: int,\n",
410 | " ) -> None:\n",
411 | " actor_key, critic_key = jax.random.split(key)\n",
412 | " self.actor = Actor(\n",
413 | " actor_key,\n",
414 | " num_inputs=num_actor_inputs,\n",
415 | " num_outputs=num_actor_outputs,\n",
416 | " min_std=min_std,\n",
417 | " max_std=max_std,\n",
418 | " var_scale=var_scale,\n",
419 | " hidden_size=hidden_size,\n",
420 | " num_mixtures=num_mixtures,\n",
421 | " depth=depth,\n",
422 | " )\n",
423 | " self.critic = Critic(\n",
424 | " critic_key,\n",
425 | " hidden_size=hidden_size,\n",
426 | " depth=depth,\n",
427 | " num_inputs=num_critic_inputs,\n",
428 | " )"
429 | ]
430 | },
431 | {
432 | "cell_type": "markdown",
433 | "id": "10",
434 | "metadata": {
435 | "id": "7"
436 | },
437 | "source": [
438 | "## Config\n",
439 | "\n",
440 | "The [ksim framework](https://github.com/kscalelabs/ksim) is based on [xax](https://github.com/kscalelabs/xax), a JAX training library built by K-Scale. To provide configuration options, xax uses a Config dataclass to parse command-line options. We define the config here."
441 | ]
442 | },
443 | {
444 | "cell_type": "code",
445 | "execution_count": null,
446 | "id": "11",
447 | "metadata": {
448 | "id": "8"
449 | },
450 | "outputs": [],
451 | "source": [
452 | "@dataclass\n",
453 | "class HumanoidWalkingTaskConfig(ksim.PPOConfig):\n",
454 | " \"\"\"Config for the humanoid walking task.\"\"\"\n",
455 | "\n",
456 | " # Model parameters.\n",
457 | " hidden_size: int = xax.field(\n",
458 | " value=128,\n",
459 | " help=\"The hidden size for the MLPs.\",\n",
460 | " )\n",
461 | " depth: int = xax.field(\n",
462 | " value=5,\n",
463 | " help=\"The depth for the MLPs.\",\n",
464 | " )\n",
465 | " num_mixtures: int = xax.field(\n",
466 | " value=5,\n",
467 | " help=\"The number of mixtures for the actor.\",\n",
468 | " )\n",
469 | " var_scale: float = xax.field(\n",
470 | " value=0.5,\n",
471 | " help=\"The scale for the standard deviations of the actor.\",\n",
472 | " )\n",
473 | " use_acc_gyro: bool = xax.field(\n",
474 | " value=True,\n",
475 | " help=\"Whether to use the IMU acceleration and gyroscope observations.\",\n",
476 | " )\n",
477 | "\n",
478 | " # Curriculum parameters.\n",
479 | " num_curriculum_levels: int = xax.field(\n",
480 | " value=100,\n",
481 | " help=\"The number of curriculum levels to use.\",\n",
482 | " )\n",
483 | " increase_threshold: float = xax.field(\n",
484 | " value=5.0,\n",
485 | " help=\"Increase the curriculum level when the mean trajectory length is above this threshold.\",\n",
486 | " )\n",
487 | " decrease_threshold: float = xax.field(\n",
488 | " value=1.0,\n",
489 | " help=\"Decrease the curriculum level when the mean trajectory length is below this threshold.\",\n",
490 | " )\n",
491 | " min_level_steps: int = xax.field(\n",
492 | " value=1,\n",
493 | " help=\"The minimum number of steps to wait before changing the curriculum level.\",\n",
494 | " )\n",
495 | "\n",
496 | " # Optimizer parameters.\n",
497 | " learning_rate: float = xax.field(\n",
498 | " value=3e-4,\n",
499 | " help=\"Learning rate for PPO.\",\n",
500 | " )\n",
501 | " adam_weight_decay: float = xax.field(\n",
502 | " value=1e-5,\n",
503 | " help=\"Weight decay for the Adam optimizer.\",\n",
504 | " )"
505 | ]
506 | },
507 | {
508 | "cell_type": "markdown",
509 | "id": "12",
510 | "metadata": {
511 | "id": "9"
512 | },
513 | "source": [
514 | "## Task\n",
515 | "\n",
516 | "The meat-and-potatoes of our training code is the task. This defines the observations, rewards, model calling logic, and everything else needed by `ksim` to train our reinforcement learning agent."
517 | ]
518 | },
519 | {
520 | "cell_type": "code",
521 | "execution_count": null,
522 | "id": "13",
523 | "metadata": {
524 | "id": "10"
525 | },
526 | "outputs": [],
527 | "source": [
528 | "class HumanoidWalkingTask(ksim.PPOTask[HumanoidWalkingTaskConfig]):\n",
529 | " def get_optimizer(self) -> optax.GradientTransformation:\n",
530 | " return (\n",
531 | " optax.adam(self.config.learning_rate)\n",
532 | " if self.config.adam_weight_decay == 0.0\n",
533 | " else optax.adamw(self.config.learning_rate, weight_decay=self.config.adam_weight_decay)\n",
534 | " )\n",
535 | "\n",
536 | " def get_mujoco_model(self) -> mujoco.MjModel:\n",
537 | " mjcf_path = asyncio.run(ksim.get_mujoco_model_path(\"kbot\", name=\"robot\"))\n",
538 | " return mujoco_scenes.mjcf.load_mjmodel(mjcf_path, scene=\"smooth\")\n",
539 | "\n",
540 | " def get_mujoco_model_metadata(self, mj_model: mujoco.MjModel) -> ksim.Metadata:\n",
541 | " metadata = asyncio.run(ksim.get_mujoco_model_metadata(\"kbot\"))\n",
542 | " if metadata.joint_name_to_metadata is None:\n",
543 | " raise ValueError(\"Joint metadata is not available\")\n",
544 | " if metadata.actuator_type_to_metadata is None:\n",
545 | " raise ValueError(\"Actuator metadata is not available\")\n",
546 | " return metadata\n",
547 | "\n",
548 | " def get_actuators(\n",
549 | " self,\n",
550 | " physics_model: ksim.PhysicsModel,\n",
551 | " metadata: ksim.Metadata | None = None,\n",
552 | " ) -> ksim.Actuators:\n",
553 | " assert metadata is not None, \"Metadata is required\"\n",
554 | " return ksim.PositionActuators(\n",
555 | " physics_model=physics_model,\n",
556 | " metadata=metadata,\n",
557 | " )\n",
558 | "\n",
559 | " def get_physics_randomizers(self, physics_model: ksim.PhysicsModel) -> list[ksim.PhysicsRandomizer]:\n",
560 | " return [\n",
561 | " ksim.StaticFrictionRandomizer(),\n",
562 | " ksim.ArmatureRandomizer(),\n",
563 | " ksim.AllBodiesMassMultiplicationRandomizer(scale_lower=0.95, scale_upper=1.05),\n",
564 | " ksim.JointDampingRandomizer(),\n",
565 | " ksim.JointZeroPositionRandomizer(scale_lower=math.radians(-2), scale_upper=math.radians(2)),\n",
566 | " ]\n",
567 | "\n",
568 | " def get_events(self, physics_model: ksim.PhysicsModel) -> list[ksim.Event]:\n",
569 | " return [\n",
570 | " ksim.PushEvent(\n",
571 | " x_force=1.0,\n",
572 | " y_force=1.0,\n",
573 | " z_force=0.3,\n",
574 | " force_range=(0.5, 1.0),\n",
575 | " x_angular_force=0.0,\n",
576 | " y_angular_force=0.0,\n",
577 | " z_angular_force=0.0,\n",
578 | " interval_range=(0.5, 4.0),\n",
579 | " ),\n",
580 | " ]\n",
581 | "\n",
582 | " def get_resets(self, physics_model: ksim.PhysicsModel) -> list[ksim.Reset]:\n",
583 | " return [\n",
584 | " ksim.RandomJointPositionReset.create(physics_model, {k: v for k, v in ZEROS}, scale=0.1),\n",
585 | " ksim.RandomJointVelocityReset(),\n",
586 | " ]\n",
587 | "\n",
588 | " def get_observations(self, physics_model: ksim.PhysicsModel) -> list[ksim.Observation]:\n",
589 | " return [\n",
590 | " ksim.TimestepObservation(),\n",
591 | " ksim.JointPositionObservation(noise=math.radians(2)),\n",
592 | " ksim.JointVelocityObservation(noise=math.radians(10)),\n",
593 | " ksim.ActuatorForceObservation(),\n",
594 | " ksim.CenterOfMassInertiaObservation(),\n",
595 | " ksim.CenterOfMassVelocityObservation(),\n",
596 | " ksim.BasePositionObservation(),\n",
597 | " ksim.BaseOrientationObservation(),\n",
598 | " ksim.BaseLinearVelocityObservation(),\n",
599 | " ksim.BaseAngularVelocityObservation(),\n",
600 | " ksim.BaseLinearAccelerationObservation(),\n",
601 | " ksim.BaseAngularAccelerationObservation(),\n",
602 | " ksim.ActuatorAccelerationObservation(),\n",
603 | " ksim.ProjectedGravityObservation.create(\n",
604 | " physics_model=physics_model,\n",
605 | " framequat_name=\"imu_site_quat\",\n",
606 | " lag_range=(0.0, 0.1),\n",
607 | " noise=math.radians(1),\n",
608 | " ),\n",
609 | " ksim.SensorObservation.create(\n",
610 | " physics_model=physics_model,\n",
611 | " sensor_name=\"imu_acc\",\n",
612 | " noise=1.0,\n",
613 | " ),\n",
614 | " ksim.SensorObservation.create(\n",
615 | " physics_model=physics_model,\n",
616 | " sensor_name=\"imu_gyro\",\n",
617 | " noise=math.radians(10),\n",
618 | " ),\n",
619 | " ]\n",
620 | "\n",
621 | " def get_commands(self, physics_model: ksim.PhysicsModel) -> list[ksim.Command]:\n",
622 | " return []\n",
623 | "\n",
624 | " def get_rewards(self, physics_model: ksim.PhysicsModel) -> list[ksim.Reward]:\n",
625 | " return [\n",
626 | " # Standard rewards.\n",
627 | " ksim.NaiveForwardReward(clip_max=1.25, in_robot_frame=False, scale=3.0),\n",
628 | " ksim.NaiveForwardOrientationReward(scale=1.0),\n",
629 | " ksim.StayAliveReward(scale=1.0),\n",
630 | " ksim.UprightReward(scale=0.5),\n",
631 | " # Avoid movement penalties.\n",
632 | " ksim.AngularVelocityPenalty(index=(\"x\", \"y\"), scale=-0.1),\n",
633 | " ksim.LinearVelocityPenalty(index=(\"z\"), scale=-0.1),\n",
634 | " # Normalization penalties.\n",
635 | " ksim.AvoidLimitsPenalty.create(physics_model, scale=-0.01),\n",
636 | " ksim.JointAccelerationPenalty(scale=-0.01, scale_by_curriculum=True),\n",
637 | " ksim.JointJerkPenalty(scale=-0.01, scale_by_curriculum=True),\n",
638 | " ksim.LinkAccelerationPenalty(scale=-0.01, scale_by_curriculum=True),\n",
639 | " ksim.LinkJerkPenalty(scale=-0.01, scale_by_curriculum=True),\n",
640 | " ksim.ActionAccelerationPenalty(scale=-0.01, scale_by_curriculum=True),\n",
641 | " # Bespoke rewards.\n",
642 | " BentArmPenalty.create_penalty(physics_model, scale=-0.1),\n",
643 | " StraightLegPenalty.create_penalty(physics_model, scale=-0.1),\n",
644 | " ]\n",
645 | "\n",
646 | " def get_terminations(self, physics_model: ksim.PhysicsModel) -> list[ksim.Termination]:\n",
647 | " return [\n",
648 | " ksim.BadZTermination(unhealthy_z_lower=0.6, unhealthy_z_upper=1.2),\n",
649 | " ksim.FarFromOriginTermination(max_dist=10.0),\n",
650 | " ]\n",
651 | "\n",
652 | " def get_curriculum(self, physics_model: ksim.PhysicsModel) -> ksim.Curriculum:\n",
653 | " return ksim.DistanceFromOriginCurriculum(\n",
654 | " min_level_steps=5,\n",
655 | " )\n",
656 | "\n",
657 | " def get_model(self, key: PRNGKeyArray) -> Model:\n",
658 | " return Model(\n",
659 | " key,\n",
660 | " num_actor_inputs=51 if self.config.use_acc_gyro else 45,\n",
661 | " num_actor_outputs=len(ZEROS),\n",
662 | " num_critic_inputs=446,\n",
663 | " min_std=0.001,\n",
664 | " max_std=1.0,\n",
665 | " var_scale=self.config.var_scale,\n",
666 | " hidden_size=self.config.hidden_size,\n",
667 | " num_mixtures=self.config.num_mixtures,\n",
668 | " depth=self.config.depth,\n",
669 | " )\n",
670 | "\n",
671 | " def run_actor(\n",
672 | " self,\n",
673 | " model: Actor,\n",
674 | " observations: xax.FrozenDict[str, Array],\n",
675 | " commands: xax.FrozenDict[str, Array],\n",
676 | " carry: Array,\n",
677 | " ) -> tuple[distrax.Distribution, Array]:\n",
678 | " time_1 = observations[\"timestep_observation\"]\n",
679 | " joint_pos_n = observations[\"joint_position_observation\"]\n",
680 | " joint_vel_n = observations[\"joint_velocity_observation\"]\n",
681 | " proj_grav_3 = observations[\"projected_gravity_observation\"]\n",
682 | " imu_acc_3 = observations[\"sensor_observation_imu_acc\"]\n",
683 | " imu_gyro_3 = observations[\"sensor_observation_imu_gyro\"]\n",
684 | "\n",
685 | " obs = [\n",
686 | " jnp.sin(time_1),\n",
687 | " jnp.cos(time_1),\n",
688 | " joint_pos_n, # NUM_JOINTS\n",
689 | " joint_vel_n, # NUM_JOINTS\n",
690 | " proj_grav_3, # 3\n",
691 | " ]\n",
692 | " if self.config.use_acc_gyro:\n",
693 | " obs += [\n",
694 | " imu_acc_3, # 3\n",
695 | " imu_gyro_3, # 3\n",
696 | " ]\n",
697 | "\n",
698 | " obs_n = jnp.concatenate(obs, axis=-1)\n",
699 | " action, carry = model.forward(obs_n, carry)\n",
700 | "\n",
701 | " return action, carry\n",
702 | "\n",
703 | " def run_critic(\n",
704 | " self,\n",
705 | " model: Critic,\n",
706 | " observations: xax.FrozenDict[str, Array],\n",
707 | " commands: xax.FrozenDict[str, Array],\n",
708 | " carry: Array,\n",
709 | " ) -> tuple[Array, Array]:\n",
710 | " time_1 = observations[\"timestep_observation\"]\n",
711 | " dh_joint_pos_j = observations[\"joint_position_observation\"]\n",
712 | " dh_joint_vel_j = observations[\"joint_velocity_observation\"]\n",
713 | " com_inertia_n = observations[\"center_of_mass_inertia_observation\"]\n",
714 | " com_vel_n = observations[\"center_of_mass_velocity_observation\"]\n",
715 | " imu_acc_3 = observations[\"sensor_observation_imu_acc\"]\n",
716 | " imu_gyro_3 = observations[\"sensor_observation_imu_gyro\"]\n",
717 | " proj_grav_3 = observations[\"projected_gravity_observation\"]\n",
718 | " act_frc_obs_n = observations[\"actuator_force_observation\"]\n",
719 | " base_pos_3 = observations[\"base_position_observation\"]\n",
720 | " base_quat_4 = observations[\"base_orientation_observation\"]\n",
721 | "\n",
722 | " obs_n = jnp.concatenate(\n",
723 | " [\n",
724 | " jnp.sin(time_1),\n",
725 | " jnp.cos(time_1),\n",
726 | " dh_joint_pos_j, # NUM_JOINTS\n",
727 | " dh_joint_vel_j / 10.0, # NUM_JOINTS\n",
728 | " com_inertia_n, # 160\n",
729 | " com_vel_n, # 96\n",
730 | " imu_acc_3, # 3\n",
731 | " imu_gyro_3, # 3\n",
732 | " proj_grav_3, # 3\n",
733 | " act_frc_obs_n / 100.0, # NUM_JOINTS\n",
734 | " base_pos_3, # 3\n",
735 | " base_quat_4, # 4\n",
736 | " ],\n",
737 | " axis=-1,\n",
738 | " )\n",
739 | "\n",
740 | " return model.forward(obs_n, carry)\n",
741 | "\n",
742 | " def _model_scan_fn(\n",
743 | " self,\n",
744 | " actor_critic_carry: tuple[Array, Array],\n",
745 | " xs: tuple[ksim.Trajectory, PRNGKeyArray],\n",
746 | " model: Model,\n",
747 | " ) -> tuple[tuple[Array, Array], ksim.PPOVariables]:\n",
748 | " transition, rng = xs\n",
749 | "\n",
750 | " actor_carry, critic_carry = actor_critic_carry\n",
751 | " actor_dist, next_actor_carry = self.run_actor(\n",
752 | " model=model.actor,\n",
753 | " observations=transition.obs,\n",
754 | " commands=transition.command,\n",
755 | " carry=actor_carry,\n",
756 | " )\n",
757 | "\n",
758 | " # Gets the log probabilities of the action.\n",
759 | " log_probs = actor_dist.log_prob(transition.action)\n",
760 | " assert isinstance(log_probs, Array)\n",
761 | "\n",
762 | " value, next_critic_carry = self.run_critic(\n",
763 | " model=model.critic,\n",
764 | " observations=transition.obs,\n",
765 | " commands=transition.command,\n",
766 | " carry=critic_carry,\n",
767 | " )\n",
768 | "\n",
769 | " transition_ppo_variables = ksim.PPOVariables(\n",
770 | " log_probs=log_probs,\n",
771 | " values=value.squeeze(-1),\n",
772 | " )\n",
773 | "\n",
774 | " next_carry = jax.tree.map(\n",
775 | " lambda x, y: jnp.where(transition.done, x, y),\n",
776 | " self.get_initial_model_carry(rng),\n",
777 | " (next_actor_carry, next_critic_carry),\n",
778 | " )\n",
779 | "\n",
780 | " return next_carry, transition_ppo_variables\n",
781 | "\n",
782 | " def get_ppo_variables(\n",
783 | " self,\n",
784 | " model: Model,\n",
785 | " trajectory: ksim.Trajectory,\n",
786 | " model_carry: tuple[Array, Array],\n",
787 | " rng: PRNGKeyArray,\n",
788 | " ) -> tuple[ksim.PPOVariables, tuple[Array, Array]]:\n",
789 | " scan_fn = functools.partial(self._model_scan_fn, model=model)\n",
790 | " next_model_carry, ppo_variables = xax.scan(\n",
791 | " scan_fn,\n",
792 | " model_carry,\n",
793 | " (trajectory, jax.random.split(rng, len(trajectory.done))),\n",
794 | " jit_level=4,\n",
795 | " )\n",
796 | " return ppo_variables, next_model_carry\n",
797 | "\n",
798 | " def get_initial_model_carry(self, rng: PRNGKeyArray) -> tuple[Array, Array]:\n",
799 | " return (\n",
800 | " jnp.zeros(shape=(self.config.depth, self.config.hidden_size)),\n",
801 | " jnp.zeros(shape=(self.config.depth, self.config.hidden_size)),\n",
802 | " )\n",
803 | "\n",
804 | " def sample_action(\n",
805 | " self,\n",
806 | " model: Model,\n",
807 | " model_carry: tuple[Array, Array],\n",
808 | " physics_model: ksim.PhysicsModel,\n",
809 | " physics_state: ksim.PhysicsState,\n",
810 | " observations: xax.FrozenDict[str, Array],\n",
811 | " commands: xax.FrozenDict[str, Array],\n",
812 | " rng: PRNGKeyArray,\n",
813 | " argmax: bool,\n",
814 | " ) -> ksim.Action:\n",
815 | " actor_carry_in, critic_carry_in = model_carry\n",
816 | " action_dist_j, actor_carry = self.run_actor(\n",
817 | " model=model.actor,\n",
818 | " observations=observations,\n",
819 | " commands=commands,\n",
820 | " carry=actor_carry_in,\n",
821 | " )\n",
822 | " action_j = action_dist_j.mode() if argmax else action_dist_j.sample(seed=rng)\n",
823 | " return ksim.Action(action=action_j, carry=(actor_carry, critic_carry_in))"
824 | ]
825 | },
826 | {
827 | "cell_type": "markdown",
828 | "id": "4dcf85a7",
829 | "metadata": {},
830 | "source": [
831 | "# Launch TensorBoard\n",
832 | "\n",
833 | "The cell below launches TensorBoard to visualize training progress.\n",
834 | "\n",
835 | "After launching an experiment, please wait ~5 minutes for the task to start running, then click the reload button in the top right corner of the TensorBoard page. You can also open the settings and check the \"Reload data\" option to have TensorBoard reload automatically."
836 | ]
837 | },
838 | {
839 | "cell_type": "code",
840 | "execution_count": null,
841 | "id": "14",
842 | "metadata": {},
843 | "outputs": [],
844 | "source": [
845 | "# Launch TensorBoard\n",
846 | "%load_ext tensorboard\n",
847 | "%tensorboard --logdir humanoid_walking_task"
848 | ]
849 | },
850 | {
851 | "cell_type": "markdown",
852 | "id": "15",
853 | "metadata": {
854 | "id": "11"
855 | },
856 | "source": [
857 | "## Launching an Experiment\n",
858 | "\n",
859 | "To launch an experiment with `xax`, you can use `Task.launch(config)`. Note that this is usually called from the command line, so by default it attempts to parse additional command-line arguments; the cell below passes `use_cli=False` to disable this inside the notebook.\n",
860 | "\n",
861 | "By default, runs will be logged to a directory called `run_[x]` in the task directory (`/content/humanoid_walking_task/` in Colab). From there, you can download the checkpoint `.bin` files and the TensorBoard logs.\n",
862 | "\n",
863 | "Also note that since this is a Jupyter notebook, the task will be unable to find the task training code and will emit a warning like \"Could not resolve task path for ..., returning current working directory\". You can safely ignore this warning."
864 | ]
865 | },
866 | {
867 | "cell_type": "code",
868 | "execution_count": null,
869 | "id": "16",
870 | "metadata": {
871 | "colab": {
872 | "base_uri": "https://localhost:8080/"
873 | },
874 | "id": "12",
875 | "outputId": "b37b0dca-fc55-445b-f175-4f4d533bd22b"
876 | },
877 | "outputs": [],
878 | "source": [
879 | "if __name__ == \"__main__\":\n",
880 | " HumanoidWalkingTask.launch(\n",
881 | " HumanoidWalkingTaskConfig(\n",
882 | " # Training parameters.\n",
883 | " num_envs=2048,\n",
884 | " batch_size=256,\n",
885 | " num_passes=4,\n",
886 | " epochs_per_log_step=1,\n",
887 | " rollout_length_seconds=8.0,\n",
888 | " global_grad_clip=2.0,\n",
889 | " # Simulation parameters.\n",
890 | " dt=0.002,\n",
891 | " ctrl_dt=0.02,\n",
892 | " iterations=8,\n",
893 | " ls_iterations=8,\n",
894 | " action_latency_range=(0.003, 0.01), # Simulate 3-10ms of latency.\n",
895 | " drop_action_prob=0.05, # Drop 5% of commands.\n",
896 | " # Visualization parameters\n",
897 | " render_track_body_id=0,\n",
898 | " # Checkpointing parameters.\n",
899 | " save_every_n_seconds=60,\n",
900 | " ),\n",
901 | " use_cli=False,\n",
902 | " )"
903 | ]
904 | }
905 | ],
906 | "metadata": {
907 | "accelerator": "GPU",
908 | "colab": {
909 | "gpuType": "T4",
910 | "provenance": []
911 | },
912 | "kernelspec": {
913 | "display_name": "Python 3",
914 | "name": "python3"
915 | },
916 | "language_info": {
917 | "codemirror_mode": {
918 | "name": "ipython",
919 | "version": 3
920 | },
921 | "file_extension": ".py",
922 | "mimetype": "text/x-python",
923 | "name": "python",
924 | "nbconvert_exporter": "python",
925 | "pygments_lexer": "ipython3",
926 | "version": "3.11.11"
927 | }
928 | },
929 | "nbformat": 4,
930 | "nbformat_minor": 5
931 | }
932 |
--------------------------------------------------------------------------------