├── .gitignore
├── gifs
    ├── first.gif
    ├── op3_17.gif
    ├── box_handling.gif
    ├── great_goal.gif
    ├── great_save.gif
    ├── humanoid_foot_84.gif
    ├── humanoid_head_7.gif
    ├── op3_foot_ballance.gif
    ├── humanoid_box_head_30.gif
    └── humanoid_head_balance.gif
├── README.md
└── tutorials
    ├── foot_tricks.ipynb
    ├── head_tricks.ipynb
    └── robot_tricks.ipynb

/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | 
--------------------------------------------------------------------------------
/gifs/first.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/goncalog/ai-robotics/HEAD/gifs/first.gif
--------------------------------------------------------------------------------
/gifs/op3_17.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/goncalog/ai-robotics/HEAD/gifs/op3_17.gif
--------------------------------------------------------------------------------
/gifs/box_handling.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/goncalog/ai-robotics/HEAD/gifs/box_handling.gif
--------------------------------------------------------------------------------
/gifs/great_goal.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/goncalog/ai-robotics/HEAD/gifs/great_goal.gif
--------------------------------------------------------------------------------
/gifs/great_save.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/goncalog/ai-robotics/HEAD/gifs/great_save.gif
--------------------------------------------------------------------------------
/gifs/humanoid_foot_84.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/goncalog/ai-robotics/HEAD/gifs/humanoid_foot_84.gif
--------------------------------------------------------------------------------
/gifs/humanoid_head_7.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/goncalog/ai-robotics/HEAD/gifs/humanoid_head_7.gif
--------------------------------------------------------------------------------
/gifs/op3_foot_ballance.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/goncalog/ai-robotics/HEAD/gifs/op3_foot_ballance.gif
--------------------------------------------------------------------------------
/gifs/humanoid_box_head_30.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/goncalog/ai-robotics/HEAD/gifs/humanoid_box_head_30.gif
--------------------------------------------------------------------------------
/gifs/humanoid_head_balance.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/goncalog/ai-robotics/HEAD/gifs/humanoid_head_balance.gif
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ai-robotics
2 | AI Robotics tutorials for hobbyists
3 | 
4 | ## Summary
5 | Learn how to train a virtual humanoid/robot to do football tricks with Reinforcement Learning
6 | 
7 | From -> To
8 | 
9 | ![humanoid_falling movie](https://github.com/goncalog/ai-robotics/raw/main/gifs/first.gif)
10 | ![op3_foot_bounces movie](https://github.com/goncalog/ai-robotics/raw/main/gifs/op3_17.gif)
11 | 
12 | Learn how to:
13 | - tune training hyperparameters for better model performance
14 | - optimise reward functions for faster learning
15 | - get results in a couple of hours of training, and for free (I don't own any GPUs either)
16 | 
17 | ## Tutorials
18 | 1 - [Foot tricks](https://github.com/goncalog/ai-robotics/blob/main/tutorials/foot_tricks.ipynb)
19 | 
20 | ![humanoid_foot_bounces movie](https://github.com/goncalog/ai-robotics/raw/main/gifs/humanoid_foot_84.gif)
21 | 
22 | 2 - [Head tricks](https://github.com/goncalog/ai-robotics/blob/main/tutorials/head_tricks.ipynb)
23 | 
24 | ![humanoid_head_bounces movie](https://github.com/goncalog/ai-robotics/raw/main/gifs/humanoid_head_7.gif)
25 | ![humanoid_box_head_bounces movie](https://github.com/goncalog/ai-robotics/raw/main/gifs/humanoid_box_head_30.gif)
26 | 
27 | ![humanoid_head_balance movie](https://github.com/goncalog/ai-robotics/raw/main/gifs/humanoid_head_balance.gif)
28 | 
29 | 3 - [Penalty taking and stopping](https://github.com/goncalog/ai-robotics/blob/main/tutorials/penalties.ipynb)
30 | 
31 | ![great_goal movie](https://github.com/goncalog/ai-robotics/raw/main/gifs/great_goal.gif)
32 | ![great_save movie](https://github.com/goncalog/ai-robotics/raw/main/gifs/great_save.gif)
33 | 
34 | 4 - [Robot tricks](https://github.com/goncalog/ai-robotics/blob/main/tutorials/robot_tricks.ipynb)
35 | 
36 | ![op3_foot_bounces movie](https://github.com/goncalog/ai-robotics/raw/main/gifs/op3_17.gif)
37 | ![op3_foot_balance movie](https://github.com/goncalog/ai-robotics/raw/main/gifs/op3_foot_ballance.gif)
38 | 
39 | 5 - Box handling [Coming soon]
40 | 
41 | ![h1_box_handling movie](https://github.com/goncalog/ai-robotics/raw/main/gifs/box_handling.gif)
42 | 
43 | ## Notes
44 | * Most of these results can be achieved in a few hours of GPU training. Sometimes less.
And for free on [Kaggle](https://www.kaggle.com/) (30h per week usage limit) 45 | * Notebooks originally adapted from [Mujoco's MJX tutorial](https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb) 46 | -------------------------------------------------------------------------------- /tutorials/foot_tricks.ipynb: -------------------------------------------------------------------------------- 1 | {"metadata":{"colab":{"gpuClass":"premium","private_outputs":true,"provenance":[],"collapsed_sections":["YvyGCsgSCxHQ","P1K6IznI2y83"],"gpuType":"T4"},"kernelspec":{"name":"python3","display_name":"Python 3","language":"python"},"accelerator":"GPU","language_info":{"name":"python","version":"3.10.13","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[],"dockerImageVersionId":30683,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"markdown","source":"# Tutorial: Foot Tricks\n\nIn this tutorial you will learn how to use the physics simulator MuJoCo, and Reinforcement Learning to teach a humanoid to bounce a ball with its foot - an essential skill in the art of football.","metadata":{"id":"Jz-I63-7YgRO"}},{"cell_type":"markdown","source":"## 1 - Kaggle\nIt's recommended to run this notebook on www.kaggle.com where you can use two T4 GPUs for 30h/week for free.\n\nTo import this notebook into Kaggle you need to:\n- Login to your Kaggle account\n- Create a new notebook\n- Click on `File` and then `Import Notebook`\n- Select the tab `GitHub`\n- Search `goncalog/ai-robotics`\n- Select the file `tutorials/foot_tricks.ipynb`\n- Click the `Import` button\n\nTo run this notebook you can either click the `Run All` button or run each cell individually by clicking the `Run current cell` button.","metadata":{}},{"cell_type":"markdown","source":"## 2 - Config\n\nThis is the configuration to run the tutorial, it includes:\n- Training hyperparameters\n- Mujoco environment variables\n- File paths\n- Rendering variables","metadata":{}},{"cell_type":"code","source":"num_timesteps = 180_000_000\nnum_evals = 9\n# num_envs: the number of parallel environments to use for rollouts\nnum_envs = 2048\n\n# learning_rate: learning rate for ppo loss\nlearning_rate = 3e-4\n# discounting: discounting rate\ndiscounting = 0.97\n# episode_length: the length of an environment episode\nepisode_length = 1000\n# normalize_observations: whether to normalize observations\nnormalize_observations = True\n# action_repeat: the number of timesteps to repeat an action\naction_repeat = 1\n# unroll_length: the number of timesteps to unroll in each environment.\n# The PPO loss is computed over `unroll_length` timesteps\nunroll_length = 10\n# entropy_cost: entropy reward for ppo loss, higher values increase entropy of the policy\nentropy_cost = 1e-3\n# batch_size: the batch size for each minibatch SGD step\nbatch_size = 1024\n# num_minibatches: the number of times to run the SGD step,\n# each with a different minibatch with leading dimension of `batch_size`\nnum_minibatches = 32\n# num_updates_per_batch: the number of times to run the gradient update over\n# all minibatches before doing a new environment rollout\nnum_updates_per_batch = 8\n# reward_scaling: float scaling for reward\nreward_scaling = 1\n# clipping_epsilon: clipping epsilon for PPO 
loss\nclipping_epsilon = 0.3\n# gae_lambda: General advantage estimation lambda\ngae_lambda = 0.95\n# normalize_advantage: whether to normalize advantage estimate\nnormalize_advantage = True\n\npolicy_hidden_layer_sizes = (32,) * 4\nvalue_hidden_layer_sizes = (256,) * 5\n\nball_size = 0.15\ntorso_index = 2 # index of torso body in mjx data (it contains the head geom)\nball_height = 0.8 # z coordinate of centre of mass\nball_x = 0.35 # x coordinate of centre of mass\nfoot_left_index = 10 # index of foot_left body in mjx data\nbounce_threshold = ball_height - 0.05 # z coordinate\n\n# Simulation time step in seconds. \n# This is the single most important parameter affecting the speed-accuracy trade-off \n# which is inherent in every physics simulation. \n# Smaller values result in better accuracy and stability\nmj_model_timestep = 0.005\n\nsave_path = \"/kaggle/working/mjx_brax_nn\"\n\nnum_rollouts = 1\nnum_bounces_threshold = 0","metadata":{"id":"-Xt8DyfJYrbd","execution":{"iopub.status.busy":"2024-04-18T15:28:41.918292Z","iopub.execute_input":"2024-04-18T15:28:41.919062Z","iopub.status.idle":"2024-04-18T15:28:41.928845Z","shell.execute_reply.started":"2024-04-18T15:28:41.919030Z","shell.execute_reply":"2024-04-18T15:28:41.927841Z"},"trusted":true},"execution_count":35,"outputs":[]},{"cell_type":"markdown","source":"## 3 - Install MuJoCo, MJX, and Brax","metadata":{"id":"YvyGCsgSCxHQ"}},{"cell_type":"code","source":"!pip install mujoco\n!pip install mujoco_mjx\n!pip install brax","metadata":{"id":"Xqo7pyX-n72M","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Check if MuJoCo installation was successful\n\nimport distutils.util\nimport os\nimport subprocess\nif subprocess.run('nvidia-smi').returncode:\n raise RuntimeError(\n 'Cannot communicate with GPU. '\n 'Make sure you are using a GPU runtime. '\n 'Go to the Runtime menu and select Choose runtime type.')\n\n# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.\n# This is usually installed as part of an Nvidia driver package, but this\n# kernel doesn't install its driver via APT, and as a result the ICD is missing.\n# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)\nNVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/'\nNVIDIA_ICD_CONFIG_FILE = '10_nvidia.json'\nif not os.path.exists(NVIDIA_ICD_CONFIG_PATH):\n os.makedirs(NVIDIA_ICD_CONFIG_PATH)\n file_path = os.path.join(NVIDIA_ICD_CONFIG_PATH, NVIDIA_ICD_CONFIG_FILE)\n with open(file_path, 'w') as f:\n f.write(\"\"\"{\n \"file_format_version\" : \"1.0.0\",\n \"ICD\" : {\n \"library_path\" : \"libEGL_nvidia.so.0\"\n }\n}\n\"\"\")\n\n# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs\nxla_flags = os.environ.get('XLA_FLAGS', '')\nxla_flags += ' --xla_gpu_triton_gemm_any=True'\nos.environ['XLA_FLAGS'] = xla_flags\n\n# Configure MuJoCo to use the EGL rendering backend (requires GPU)\nprint('Setting environment variable to use GPU rendering:')\n%env MUJOCO_GL=egl\n\ntry:\n print('Checking that the installation succeeded:')\n import mujoco\n mujoco.MjModel.from_xml_string('')\nexcept Exception as e:\n raise e from RuntimeError(\n 'Something went wrong during installation. 
Check the shell output above '\n 'for more information.\\n'\n 'If using a hosted runtime, make sure you enable GPU acceleration '\n 'by going to the Runtime menu and selecting \"Choose runtime type\".')\n\nprint('Installation successful.')","metadata":{"id":"IbZxYDxzoz5R","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Import packages for plotting and creating graphics\nimport time\nimport itertools\nimport numpy as np\nfrom typing import Callable, NamedTuple, Optional, Union, List\n\n# Graphics and plotting.\nprint('Installing mediapy:')\n!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)\n!pip install -q mediapy\nimport mediapy as media\nimport matplotlib.pyplot as plt\n\n# More legible printing from numpy.\nnp.set_printoptions(precision=3, suppress=True, linewidth=100)","metadata":{"id":"T5f4w3Kq2X14","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Import MuJoCo, MJX, and Brax\n\nfrom datetime import datetime\nimport functools\nimport jax\nfrom jax import numpy as jp\nimport numpy as np\nfrom typing import Any, Dict, Sequence, Tuple, Union\n\nfrom brax import base\nfrom brax import envs\nfrom brax import math\nfrom brax.base import Base, Motion, Transform\nfrom brax.envs.base import Env, PipelineEnv, State\nfrom brax.mjx.base import State as MjxState\nfrom brax.io import html, mjcf, model\nfrom brax.training import distribution, networks\n\nfrom etils import epath\nfrom flax import linen, struct\nfrom matplotlib import pyplot as plt\nimport mediapy as media\nfrom ml_collections import config_dict\nimport mujoco\nfrom mujoco import mjx\n","metadata":{"id":"ObF1UXrkb0Nd","execution":{"iopub.status.busy":"2024-04-18T14:11:17.892024Z","iopub.execute_input":"2024-04-18T14:11:17.895648Z","iopub.status.idle":"2024-04-18T14:11:22.029329Z","shell.execute_reply.started":"2024-04-18T14:11:17.895581Z","shell.execute_reply":"2024-04-18T14:11:22.028499Z"},"trusted":true},"execution_count":5,"outputs":[]},{"cell_type":"markdown","source":"## 4 - Setting up the Humanoid environment with MJX\nMJX is an implementation of MuJoCo written in [JAX](https://jax.readthedocs.io/en/latest/index.html), enabling large batch training on GPU/TPU. In this notebook, we train RL policies with MJX.\n\nHere we implement our environment by adapting the original [Humanoid](https://github.com/google-deepmind/mujoco/blob/546a27ca72397b888e314ee4549bcf12d9fd5957/model/humanoid/humanoid.xml) environment to also include a ball. Notice that `reset` initializes a `State`, and `step` steps through the physics step and reward logic. 
The reward and stepping logic train the Humanoid to bounce a ball with its left foot.","metadata":{"id":"RAv6WUVUm78k"}},{"cell_type":"code","source":"# Humanoid XML\n\nball_material = \"\"\"\n \n \n \"\"\"\nball_default = f\"\"\"\n \n \n \n \"\"\"\nball_body = f\"\"\"\n \n \n \n \n \"\"\"\nball_contact = \"\"\"\n \n \n \n \"\"\"\n\nxml = f\"\"\"\n\n \n\n \n \n \n \n \n \n\n \n\n \n \n \n \n\n \n \n \n {ball_material}\n \n\n \n {ball_default}\n \n \n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n \n \n \n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n {ball_body}\n \n\n \n \n \n \n \n \n \n {ball_contact}\n \n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n\"\"\"","metadata":{"id":"G5RKmMLkUuC8","cellView":"form","execution":{"iopub.status.busy":"2024-04-18T14:11:22.030726Z","iopub.execute_input":"2024-04-18T14:11:22.031229Z","iopub.status.idle":"2024-04-18T14:11:22.047138Z","shell.execute_reply.started":"2024-04-18T14:11:22.031200Z","shell.execute_reply":"2024-04-18T14:11:22.046195Z"},"trusted":true},"execution_count":6,"outputs":[]},{"cell_type":"code","source":"# Humanoid Env\n\nclass Humanoid(PipelineEnv):\n\n def __init__(\n self,\n terminate_when_unhealthy=True,\n reset_noise_scale=1e-2,\n exclude_current_positions_from_observation=True,\n **kwargs,\n ):\n mj_model = mujoco.MjModel.from_xml_string(xml)\n mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG\n mj_model.opt.iterations = 6\n mj_model.opt.ls_iterations = 6\n\n sys = mjcf.load_model(mj_model)\n\n physics_steps_per_control_step = 5\n kwargs['n_frames'] = kwargs.get(\n 'n_frames', physics_steps_per_control_step)\n kwargs['backend'] = 'mjx'\n\n super().__init__(sys, **kwargs)\n\n self._terminate_when_unhealthy = terminate_when_unhealthy\n self._reset_noise_scale = reset_noise_scale\n self._exclude_current_positions_from_observation = (\n exclude_current_positions_from_observation\n )\n\n\n def reset(self, rng: jp.ndarray) -> State:\n \"\"\"Resets the environment to an initial state.\"\"\"\n rng, rng1, rng2 = jax.random.split(rng, 3)\n\n low, hi = -self._reset_noise_scale, self._reset_noise_scale\n qpos = self.sys.qpos0 + jax.random.uniform(\n rng1, (self.sys.nq,), minval=low, maxval=hi\n )\n qvel = jax.random.uniform(\n rng2, (self.sys.nv,), minval=low, maxval=hi\n )\n \n data = self.pipeline_init(qpos, qvel)\n\n obs = self._get_obs(data, jp.zeros(self.sys.nu))\n reward, done, zero = jp.zeros(3)\n metrics = {\n 'ball_reward': zero,\n 'reward_quadctrl': zero,\n 'reward_alive': zero,\n 'reward': zero,\n 'bounces': zero,\n }\n return State(data, obs, reward, done, metrics)\n\n\n def step(self, state: State, action: jp.ndarray) -> State:\n \"\"\"Runs one timestep of the environment's dynamics.\"\"\"\n data0 = state.pipeline_state\n data = self.pipeline_step(data0, action)\n \n reward, done = self._get_reward(state, action, data0, data)\n obs = self._get_obs(data, action)\n return state.replace(\n pipeline_state=data, obs=obs, reward=reward, done=done\n )\n\n\n def _get_obs(\n self, data: mjx.Data, action: jp.ndarray\n ) -> jp.ndarray:\n \"\"\"Observes humanoid body and ball position, velocities, and angles.\"\"\"\n position = data.qpos\n if 
self._exclude_current_positions_from_observation:\n position = position[2:]\n\n # external_contact_forces are excluded\n return jp.concatenate([\n # qpos: position / nq: number of generalized coordinates = dim(qpos)\n position,\n # qvel: velocity / nv: number of degrees of freedom = dim(qvel)\n data.qvel,\n # cinert: com-based body inertia and mass / (nbody, 10)\n data.cinert[1:].ravel(),\n # cvel: com-based velocity [3D rot; 3D tran] / (nbody, 6)\n data.cvel[1:].ravel(),\n # qfrc_actuator: actuator force / nv: number of degrees of freedom\n data.qfrc_actuator,\n ])\n\n\n def _get_reward(\n self, state: State, action: jp.ndarray, data0: mjx.Data, data: mjx.Data\n ) -> Tuple[jp.ndarray, jp.ndarray]:\n \"\"\"Apply reward func based on ball distance to normal of the left foot and target height.\"\"\"\n ctrl_cost_weight = 0.1\n healthy_reward = 5.0\n healthy_z_range = (1.0, 3.0)\n ball_reward = 5.0\n ball_healthy_z_range = (0.3, 3.0)\n ball_reward_min_z = 0.3\n ball_reward_target_z = 1.0\n distance_feet_reward = 5.0\n distance_feet_max = 2.0\n \n com_before_ball = data0.subtree_com[-1]\n com_after_ball = data.subtree_com[-1]\n com_after_foot = data.subtree_com[foot_left_index]\n distance_foot = jp.sqrt(jp.square(com_after_ball[0] - com_after_foot[0]) + jp.square(com_after_ball[1] - com_after_foot[1]))\n\n min_z, max_z = healthy_z_range\n is_healthy = jp.where(data.q[torso_index] < min_z, 0.0, 1.0)\n is_healthy = jp.where(data.q[torso_index] > max_z, 0.0, is_healthy)\n\n ball_min_z, ball_max_z = ball_healthy_z_range\n is_healthy = jp.where(com_after_ball[2] < ball_min_z, 0.0, is_healthy)\n is_healthy = jp.where(com_after_ball[2] > ball_max_z, 0.0, is_healthy)\n \n is_healthy = jp.where(distance_foot > distance_feet_max, 0.0, is_healthy)\n\n ctrl_cost = ctrl_cost_weight * jp.sum(jp.square(action))\n \n distance_target_height = jp.sqrt(jp.square(com_after_ball[2] - ball_reward_target_z))\n ball_reward = ball_reward * (1 - (distance_target_height / (ball_max_z - ball_reward_target_z)))\n is_ball_reward = jp.where(com_after_ball[2] >= ball_reward_min_z, 1.0, 0.0)\n \n distance_feet_reward = distance_feet_reward * (1 - (distance_foot / distance_feet_max))\n \n reward = ball_reward * is_ball_reward + healthy_reward - ctrl_cost + distance_feet_reward\n\n state.metrics.update(\n ball_reward=ball_reward * is_ball_reward,\n reward_quadctrl=-ctrl_cost,\n reward_alive=healthy_reward,\n reward=reward,\n bounces=self._is_bounce(com_before_ball, com_after_ball),\n )\n \n done = 1.0 - is_healthy\n return reward, done\n \n \n # There is a lot of room to improve this function as it should check for contacts\n # between the ball and the lower left limb of the Humanoid\n # (at the time of implementation the contacts data wasn't easily accessible in MJX)\n def _is_bounce(\n self, com_before_ball: jp.ndarray, com_after_ball: jp.ndarray\n ) -> jp.ndarray:\n \"\"\"Check if ball bounced.\"\"\"\n is_bounce = jp.where(com_before_ball[2] < bounce_threshold, 1.0, 0.0)\n is_bounce = jp.where(com_after_ball[2] >= bounce_threshold, is_bounce, 0.0)\n return is_bounce\n \n\nenvs.register_environment(\"humanoid\", 
Humanoid)","metadata":{"id":"mtGMYNLE3QJN","cellView":"form","execution":{"iopub.status.busy":"2024-04-18T14:11:22.049162Z","iopub.execute_input":"2024-04-18T14:11:22.049526Z","iopub.status.idle":"2024-04-18T14:11:22.076037Z","shell.execute_reply.started":"2024-04-18T14:11:22.049499Z","shell.execute_reply":"2024-04-18T14:11:22.075036Z"},"trusted":true},"execution_count":7,"outputs":[]},{"cell_type":"markdown","source":"## 5 - Visualize a rollout\n\nLet's instantiate the environment and visualize a short rollout.\n\nNOTE: Since episodes terminate early if the torso is below the healthy z-range, the only relevant contacts for this task are between the feet and the plane, and the lower left limb and the ball. The other contacts weren't included. This also speeds up the training later on.","metadata":{"id":"P1K6IznI2y83"}},{"cell_type":"code","source":"# Instantiate the environment\nenv_name = \"humanoid\"\nenv = envs.get_environment(env_name)\n\n# Define the jit reset/step functions\njit_reset = jax.jit(env.reset)\njit_step = jax.jit(env.step)","metadata":{"id":"EhKLFK54C1CH","execution":{"iopub.status.busy":"2024-04-18T14:11:22.077149Z","iopub.execute_input":"2024-04-18T14:11:22.077463Z","iopub.status.idle":"2024-04-18T14:11:26.730947Z","shell.execute_reply.started":"2024-04-18T14:11:22.077438Z","shell.execute_reply":"2024-04-18T14:11:26.730059Z"},"trusted":true},"execution_count":8,"outputs":[]},{"cell_type":"code","source":"# Initialize the state\nstate = jit_reset(jax.random.PRNGKey(0))\nprint(f\"Observations size: {len(state.obs)}\")\nprint(f\"Actions size: {env.sys.nu}\")\n\nrollout = [state.pipeline_state]\n\n# Grab a trajectory\nfor i in range(50):\n # ctrl: control / nu: number of actuators/controls = dim(ctrl)\n ctrl = -0.1 * jp.ones(env.sys.nu)\n state = jit_step(state, ctrl)\n rollout.append(state.pipeline_state)\n\nmedia.show_video(env.render(rollout, camera='side', height=480, width=640), fps=1.0 / env.dt)","metadata":{"id":"Ph8u-v2Q2xLS","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 6 - Define the training functions\n\nLet's define the training functions using [PPO](https://openai.com/research/openai-baselines-ppo) to make the Humanoid bounce the ball with its left foot.","metadata":{"id":"BQDG6NQ1CbZD"}},{"cell_type":"code","source":"# Define the Acting/Evaluator (adapted from https://github.com/google/brax)\n\n\"\"\"Brax training acting functions.\"\"\"\n\nimport time\nfrom typing import Callable, Sequence, Tuple, Union\n\nfrom brax import envs\nfrom brax.training.types import Metrics\nfrom brax.training.types import Policy\nfrom brax.training.types import PolicyParams\nfrom brax.training.types import PRNGKey\nfrom brax.training.types import Transition\nfrom brax.v1 import envs as envs_v1\nimport jax\nimport numpy as np\n\nActingState = Union[envs.State, envs_v1.State]\nActingEnv = Union[envs.Env, envs_v1.Env, envs_v1.Wrapper]\n\n\ndef actor_step(\n env: ActingEnv,\n env_state: ActingState,\n policy: Policy,\n key: PRNGKey,\n extra_fields: Sequence[str] = ()\n) -> Tuple[ActingState, Transition]:\n \"\"\"Collect data.\"\"\"\n actions, policy_extras = policy(env_state.obs, key)\n nstate = env.step(env_state, actions)\n state_extras = {x: nstate.info[x] for x in extra_fields}\n return nstate, Transition( # pytype: disable=wrong-arg-types # jax-ndarray\n observation=env_state.obs,\n action=actions,\n reward=nstate.reward,\n discount=1 - nstate.done,\n next_observation=nstate.obs,\n extras={\n 'policy_extras': policy_extras,\n 
'state_extras': state_extras\n })\n\n\ndef generate_unroll(\n env: ActingEnv,\n env_state: ActingState,\n policy: Policy,\n key: PRNGKey,\n unroll_length: int,\n extra_fields: Sequence[str] = ()\n) -> Tuple[ActingState, Transition]:\n \"\"\"Collect trajectories of given unroll_length.\"\"\"\n\n @jax.jit\n def f(carry, unused_t):\n state, current_key = carry\n current_key, next_key = jax.random.split(current_key)\n nstate, transition = actor_step(\n env, state, policy, current_key, extra_fields=extra_fields)\n return (nstate, next_key), transition\n\n (final_state, _), data = jax.lax.scan(\n f, (env_state, key), (), length=unroll_length)\n return final_state, data\n\n\n# TODO: Consider moving this to its own file.\nclass Evaluator:\n \"\"\"Class to run evaluations.\"\"\"\n\n def __init__(self, eval_env: envs.Env,\n eval_policy_fn: Callable[[PolicyParams],\n Policy], num_eval_envs: int,\n episode_length: int, action_repeat: int, key: PRNGKey):\n \"\"\"Init.\n\n Args:\n eval_env: Batched environment to run evals on.\n eval_policy_fn: Function returning the policy from the policy parameters.\n num_eval_envs: Each env will run 1 episode in parallel for each eval.\n episode_length: Maximum length of an episode.\n action_repeat: Number of physics steps per env step.\n key: RNG key.\n \"\"\"\n self._key = key\n self._eval_walltime = 0.\n\n eval_env = envs.training.EvalWrapper(eval_env)\n\n def generate_eval_unroll(policy_params: PolicyParams,\n key: PRNGKey) -> ActingState:\n reset_keys = jax.random.split(key, num_eval_envs)\n eval_first_state = eval_env.reset(reset_keys)\n return generate_unroll(\n eval_env,\n eval_first_state,\n eval_policy_fn(policy_params),\n key,\n unroll_length=episode_length // action_repeat)[0]\n\n self._generate_eval_unroll = jax.jit(generate_eval_unroll)\n self._steps_per_unroll = episode_length * num_eval_envs\n\n def run_evaluation(self,\n policy_params: PolicyParams,\n training_metrics: Metrics,\n aggregate_episodes: bool = True) -> Metrics:\n \"\"\"Run one epoch of evaluation.\"\"\"\n self._key, unroll_key = jax.random.split(self._key)\n\n t = time.time()\n eval_state = self._generate_eval_unroll(policy_params, unroll_key)\n eval_metrics = eval_state.info['eval_metrics']\n eval_metrics.active_episodes.block_until_ready()\n epoch_eval_time = time.time() - t\n metrics = {}\n for fn in [np.mean, np.std, np.max]:\n suffix = '_std' if fn == np.std else '_max' if fn == np.max else ''\n metrics.update(\n {\n f'eval/episode_{name}{suffix}': (\n fn(value) if aggregate_episodes else value\n )\n for name, value in eval_metrics.episode_metrics.items()\n }\n )\n metrics['eval/avg_episode_length'] = np.mean(eval_metrics.episode_steps)\n metrics['eval/epoch_eval_time'] = epoch_eval_time\n metrics['eval/sps'] = self._steps_per_unroll / epoch_eval_time\n self._eval_walltime = self._eval_walltime + epoch_eval_time\n metrics = {\n 'eval/walltime': self._eval_walltime,\n **training_metrics,\n **metrics\n }\n\n return metrics # pytype: disable=bad-return-type # jax-ndarray","metadata":{"execution":{"iopub.status.busy":"2024-04-18T14:12:22.415865Z","iopub.execute_input":"2024-04-18T14:12:22.416165Z","iopub.status.idle":"2024-04-18T14:12:22.873283Z","shell.execute_reply.started":"2024-04-18T14:12:22.416135Z","shell.execute_reply":"2024-04-18T14:12:22.872307Z"},"trusted":true},"execution_count":10,"outputs":[]},{"cell_type":"code","source":"# Define the Training Function (adapted from https://github.com/google/brax)\n\n\"\"\"Proximal policy optimization training.\n\nSee: 
https://arxiv.org/pdf/1707.06347.pdf\n\"\"\"\n\nimport functools\nimport time\nfrom typing import Callable, Optional, Tuple, Union\n\nfrom absl import logging\nfrom brax import base\nfrom brax import envs\nfrom brax.training import gradients\nfrom brax.training import pmap\nfrom brax.training import types\nfrom brax.training.acme import running_statistics\nfrom brax.training.acme import specs\nfrom brax.training.agents.ppo import losses as ppo_losses\nfrom brax.training.agents.ppo import networks as ppo_networks\nfrom brax.training.types import Params, PolicyParams, PreprocessorParams\nfrom brax.training.types import PRNGKey\nfrom brax.v1 import envs as envs_v1\nimport flax\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport optax\n\n\nInferenceParams = Tuple[running_statistics.NestedMeanStd, Params]\nMetrics = types.Metrics\nValueParams = Any\n\n_PMAP_AXIS_NAME = 'i'\n\n\n@flax.struct.dataclass\nclass TrainingState:\n \"\"\"Contains training state for the learner.\"\"\"\n optimizer_state: optax.OptState\n params: ppo_losses.PPONetworkParams\n normalizer_params: running_statistics.RunningStatisticsState\n env_steps: jnp.ndarray\n\n\ndef _unpmap(v):\n return jax.tree_util.tree_map(lambda x: x[0], v)\n\n\ndef _strip_weak_type(tree):\n # brax user code is sometimes ambiguous about weak_type. in order to\n # avoid extra jit recompilations we strip all weak types from user input\n def f(leaf):\n leaf = jnp.asarray(leaf)\n return leaf.astype(leaf.dtype)\n return jax.tree_util.tree_map(f, tree)\n\n\ndef train(\n environment: Union[envs_v1.Env, envs.Env],\n num_timesteps: int,\n episode_length: int,\n action_repeat: int = 1,\n num_envs: int = 1,\n max_devices_per_host: Optional[int] = None,\n num_eval_envs: int = 128,\n learning_rate: float = 1e-4,\n entropy_cost: float = 1e-4,\n discounting: float = 0.9,\n seed: int = 0,\n unroll_length: int = 10,\n batch_size: int = 32,\n num_minibatches: int = 16,\n num_updates_per_batch: int = 2,\n num_evals: int = 1,\n num_resets_per_eval: int = 0,\n normalize_observations: bool = False,\n reward_scaling: float = 1.0,\n clipping_epsilon: float = 0.3,\n gae_lambda: float = 0.95,\n deterministic_eval: bool = False,\n network_factory: types.NetworkFactory[\n ppo_networks.PPONetworks\n ] = ppo_networks.make_ppo_networks,\n progress_fn: Callable[[int, Metrics], None] = lambda *args: None,\n normalize_advantage: bool = True,\n eval_env: Optional[envs.Env] = None,\n policy_params_fn: Callable[..., None] = lambda *args: None,\n randomization_fn: Optional[\n Callable[[base.System, jnp.ndarray], Tuple[base.System, base.System]]\n ] = None,\n saved_params: Optional[\n Tuple[PreprocessorParams, PolicyParams, ValueParams]\n ] = None,\n):\n \"\"\"PPO training.\n\n Args:\n environment: the environment to train\n num_timesteps: the total number of environment steps to use during training\n episode_length: the length of an environment episode\n action_repeat: the number of timesteps to repeat an action\n num_envs: the number of parallel environments to use for rollouts\n NOTE: `num_envs` must be divisible by the total number of chips since each\n chip gets `num_envs // total_number_of_chips` environments to roll out\n NOTE: `batch_size * num_minibatches` must be divisible by `num_envs` since\n data generated by `num_envs` parallel envs gets used for gradient\n updates over `num_minibatches` of data, where each minibatch has a\n leading dimension of `batch_size`\n max_devices_per_host: maximum number of chips to use per host process\n num_eval_envs: the number of 
envs to use for evaluation. Each env will run 1\n episode, and all envs run in parallel during eval.\n learning_rate: learning rate for ppo loss\n entropy_cost: entropy reward for ppo loss, higher values increase entropy\n of the policy\n discounting: discounting rate\n seed: random seed\n unroll_length: the number of timesteps to unroll in each environment. The\n PPO loss is computed over `unroll_length` timesteps\n batch_size: the batch size for each minibatch SGD step\n num_minibatches: the number of times to run the SGD step, each with a\n different minibatch with leading dimension of `batch_size`\n num_updates_per_batch: the number of times to run the gradient update over\n all minibatches before doing a new environment rollout\n num_evals: the number of evals to run during the entire training run.\n Increasing the number of evals increases total training time\n num_resets_per_eval: the number of environment resets to run between each\n eval. The environment resets occur on the host\n normalize_observations: whether to normalize observations\n reward_scaling: float scaling for reward\n clipping_epsilon: clipping epsilon for PPO loss\n gae_lambda: General advantage estimation lambda\n deterministic_eval: whether to run the eval with a deterministic policy\n network_factory: function that generates networks for policy and value\n functions\n progress_fn: a user-defined callback function for reporting/plotting metrics\n normalize_advantage: whether to normalize advantage estimate\n eval_env: an optional environment for eval only, defaults to `environment`\n policy_params_fn: a user-defined callback function that can be used for\n saving policy checkpoints\n randomization_fn: a user-defined callback function that generates randomized\n environments\n saved_params: params to init the training with; includes normalizer_params\n and policy and value network params\n\n Returns:\n Tuple of (make_policy function, network params, metrics)\n \"\"\"\n assert batch_size * num_minibatches % num_envs == 0\n xt = time.time()\n\n process_count = jax.process_count()\n process_id = jax.process_index()\n local_device_count = jax.local_device_count()\n local_devices_to_use = local_device_count\n if max_devices_per_host:\n local_devices_to_use = min(local_devices_to_use, max_devices_per_host)\n logging.info(\n 'Device count: %d, process count: %d (id %d), local device count: %d, '\n 'devices to be used count: %d', jax.device_count(), process_count,\n process_id, local_device_count, local_devices_to_use)\n device_count = local_devices_to_use * process_count\n\n # The number of environment steps executed for every training step.\n env_step_per_training_step = (\n batch_size * unroll_length * num_minibatches * action_repeat)\n num_evals_after_init = max(num_evals - 1, 1)\n # The number of training_step calls per training_epoch call.\n # equals to ceil(num_timesteps / (num_evals * env_step_per_training_step *\n # num_resets_per_eval))\n num_training_steps_per_epoch = np.ceil(\n num_timesteps\n / (\n num_evals_after_init\n * env_step_per_training_step\n * max(num_resets_per_eval, 1)\n )\n ).astype(int)\n\n key = jax.random.PRNGKey(seed)\n global_key, local_key = jax.random.split(key)\n del key\n local_key = jax.random.fold_in(local_key, process_id)\n local_key, key_env, eval_key = jax.random.split(local_key, 3)\n # key_networks should be global, so that networks are initialized the same\n # way for different processes.\n key_policy, key_value = jax.random.split(global_key)\n del global_key\n\n assert num_envs % 
device_count == 0\n\n v_randomization_fn = None\n if randomization_fn is not None:\n randomization_batch_size = num_envs // local_device_count\n # all devices gets the same randomization rng\n randomization_rng = jax.random.split(key_env, randomization_batch_size)\n v_randomization_fn = functools.partial(\n randomization_fn, rng=randomization_rng\n )\n\n if isinstance(environment, envs.Env):\n wrap_for_training = envs.training.wrap\n else:\n wrap_for_training = envs_v1.wrappers.wrap_for_training\n\n env = wrap_for_training(\n environment,\n episode_length=episode_length,\n action_repeat=action_repeat,\n randomization_fn=v_randomization_fn,\n )\n\n reset_fn = jax.jit(jax.vmap(env.reset))\n key_envs = jax.random.split(key_env, num_envs // process_count)\n key_envs = jnp.reshape(key_envs,\n (local_devices_to_use, -1) + key_envs.shape[1:])\n env_state = reset_fn(key_envs)\n\n normalize = lambda x, y: x\n if normalize_observations:\n normalize = running_statistics.normalize\n ppo_network = network_factory(\n env_state.obs.shape[-1],\n env.action_size,\n preprocess_observations_fn=normalize)\n make_policy = ppo_networks.make_inference_fn(ppo_network)\n\n optimizer = optax.adam(learning_rate=learning_rate)\n\n loss_fn = functools.partial(\n ppo_losses.compute_ppo_loss,\n ppo_network=ppo_network,\n entropy_cost=entropy_cost,\n discounting=discounting,\n reward_scaling=reward_scaling,\n gae_lambda=gae_lambda,\n clipping_epsilon=clipping_epsilon,\n normalize_advantage=normalize_advantage)\n\n gradient_update_fn = gradients.gradient_update_fn(\n loss_fn, optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True)\n\n def minibatch_step(\n carry, data: types.Transition,\n normalizer_params: running_statistics.RunningStatisticsState):\n optimizer_state, params, key = carry\n key, key_loss = jax.random.split(key)\n (_, metrics), params, optimizer_state = gradient_update_fn(\n params,\n normalizer_params,\n data,\n key_loss,\n optimizer_state=optimizer_state)\n\n return (optimizer_state, params, key), metrics\n\n def sgd_step(carry, unused_t, data: types.Transition,\n normalizer_params: running_statistics.RunningStatisticsState):\n optimizer_state, params, key = carry\n key, key_perm, key_grad = jax.random.split(key, 3)\n\n def convert_data(x: jnp.ndarray):\n x = jax.random.permutation(key_perm, x)\n x = jnp.reshape(x, (num_minibatches, -1) + x.shape[1:])\n return x\n\n shuffled_data = jax.tree_util.tree_map(convert_data, data)\n (optimizer_state, params, _), metrics = jax.lax.scan(\n functools.partial(minibatch_step, normalizer_params=normalizer_params),\n (optimizer_state, params, key_grad),\n shuffled_data,\n length=num_minibatches)\n return (optimizer_state, params, key), metrics\n\n def training_step(\n carry: Tuple[TrainingState, envs.State, PRNGKey],\n unused_t) -> Tuple[Tuple[TrainingState, envs.State, PRNGKey], Metrics]:\n training_state, state, key = carry\n key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3)\n\n policy = make_policy(\n (training_state.normalizer_params, training_state.params.policy))\n\n def f(carry, unused_t):\n current_state, current_key = carry\n current_key, next_key = jax.random.split(current_key)\n next_state, data = generate_unroll(\n env,\n current_state,\n policy,\n current_key,\n unroll_length,\n extra_fields=('truncation',))\n return (next_state, next_key), data\n\n (state, _), data = jax.lax.scan(\n f, (state, key_generate_unroll), (),\n length=batch_size * num_minibatches // num_envs)\n # Have leading dimensions (batch_size * num_minibatches, 
unroll_length)\n data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 1, 2), data)\n data = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]),\n data)\n assert data.discount.shape[1:] == (unroll_length,)\n\n # Update normalization params and normalize observations.\n normalizer_params = running_statistics.update(\n training_state.normalizer_params,\n data.observation,\n pmap_axis_name=_PMAP_AXIS_NAME)\n\n (optimizer_state, params, _), metrics = jax.lax.scan(\n functools.partial(\n sgd_step, data=data, normalizer_params=normalizer_params),\n (training_state.optimizer_state, training_state.params, key_sgd), (),\n length=num_updates_per_batch)\n\n new_training_state = TrainingState(\n optimizer_state=optimizer_state,\n params=params,\n normalizer_params=normalizer_params,\n env_steps=training_state.env_steps + env_step_per_training_step)\n return (new_training_state, state, new_key), metrics\n\n def training_epoch(training_state: TrainingState, state: envs.State,\n key: PRNGKey) -> Tuple[TrainingState, envs.State, Metrics]:\n (training_state, state, _), loss_metrics = jax.lax.scan(\n training_step, (training_state, state, key), (),\n length=num_training_steps_per_epoch)\n loss_metrics = jax.tree_util.tree_map(jnp.mean, loss_metrics)\n return training_state, state, loss_metrics\n\n training_epoch = jax.pmap(training_epoch, axis_name=_PMAP_AXIS_NAME)\n\n # Note that this is NOT a pure jittable method.\n def training_epoch_with_timing(\n training_state: TrainingState, env_state: envs.State,\n key: PRNGKey) -> Tuple[TrainingState, envs.State, Metrics]:\n nonlocal training_walltime\n t = time.time()\n training_state, env_state = _strip_weak_type((training_state, env_state))\n result = training_epoch(training_state, env_state, key)\n training_state, env_state, metrics = _strip_weak_type(result)\n\n metrics = jax.tree_util.tree_map(jnp.mean, metrics)\n jax.tree_util.tree_map(lambda x: x.block_until_ready(), metrics)\n\n epoch_training_time = time.time() - t\n training_walltime += epoch_training_time\n sps = (num_training_steps_per_epoch *\n env_step_per_training_step *\n max(num_resets_per_eval, 1)) / epoch_training_time\n metrics = {\n 'training/sps': sps,\n 'training/walltime': training_walltime,\n **{f'training/{name}': value for name, value in metrics.items()}\n }\n return training_state, env_state, metrics # pytype: disable=bad-return-type # py311-upgrade\n\n\n if saved_params is None:\n init_params = ppo_losses.PPONetworkParams(\n policy=ppo_network.policy_network.init(key_policy),\n value=ppo_network.value_network.init(key_value))\n normalizer_params = running_statistics.init_state(\n specs.Array(env_state.obs.shape[-1:], jnp.dtype('float32')))\n else:\n init_params = ppo_losses.PPONetworkParams(\n policy=saved_params[1],\n value=saved_params[2])\n normalizer_params = saved_params[0]\n\n training_state = TrainingState( # pytype: disable=wrong-arg-types # jax-ndarray\n optimizer_state=optimizer.init(init_params), # pytype: disable=wrong-arg-types # numpy-scalars\n params=init_params,\n normalizer_params=normalizer_params,\n env_steps=0)\n training_state = jax.device_put_replicated(\n training_state,\n jax.local_devices()[:local_devices_to_use])\n\n if not eval_env:\n eval_env = environment\n if randomization_fn is not None:\n v_randomization_fn = functools.partial(\n randomization_fn, rng=jax.random.split(eval_key, num_eval_envs)\n )\n eval_env = wrap_for_training(\n eval_env,\n episode_length=episode_length,\n action_repeat=action_repeat,\n 
randomization_fn=v_randomization_fn,\n )\n\n evaluator = Evaluator(\n eval_env,\n functools.partial(make_policy, deterministic=deterministic_eval),\n num_eval_envs=num_eval_envs,\n episode_length=episode_length,\n action_repeat=action_repeat,\n key=eval_key)\n\n # Run initial eval\n metrics = {}\n if process_id == 0 and num_evals > 1:\n metrics = evaluator.run_evaluation(\n _unpmap(\n (training_state.normalizer_params, training_state.params.policy)),\n training_metrics={})\n logging.info(metrics)\n progress_fn(0, metrics)\n\n training_metrics = {}\n training_walltime = 0\n current_step = 0\n # Initialize variables to allow saving params of run with max score\n max_score = 0\n max_score_params = {}\n for it in range(num_evals_after_init):\n logging.info('starting iteration %s %s', it, time.time() - xt)\n\n for _ in range(max(num_resets_per_eval, 1)):\n # optimization\n epoch_key, local_key = jax.random.split(local_key)\n epoch_keys = jax.random.split(epoch_key, local_devices_to_use)\n (training_state, env_state, training_metrics) = (\n training_epoch_with_timing(training_state, env_state, epoch_keys)\n )\n current_step = int(_unpmap(training_state.env_steps))\n\n key_envs = jax.vmap(\n lambda x, s: jax.random.split(x[0], s),\n in_axes=(0, None))(key_envs, key_envs.shape[1])\n # TODO: move extra reset logic to the AutoResetWrapper.\n env_state = reset_fn(key_envs) if num_resets_per_eval > 0 else env_state\n\n if process_id == 0:\n # Run evals.\n metrics = evaluator.run_evaluation(\n _unpmap(\n (training_state.normalizer_params, training_state.params.policy)),\n training_metrics)\n logging.info(metrics)\n progress_fn(current_step, metrics)\n params = _unpmap(\n (training_state.normalizer_params, training_state.params.policy,\n training_state.params.value))\n\n # Save params if this is the max score\n eval_score = metrics['eval/episode_reward']\n if eval_score > max_score:\n max_score = eval_score\n max_score_params = {\n \"score\": max_score,\n \"params\": params,\n }\n policy_params_fn(current_step, make_policy, params)\n\n total_steps = current_step\n assert total_steps >= num_timesteps\n\n # If there was no mistakes the training_state should still be identical on all\n # devices.\n pmap.assert_is_replicated(training_state)\n params = _unpmap(\n (training_state.normalizer_params, training_state.params.policy,\n training_state.params.value))\n logging.info('total steps: %s', total_steps)\n pmap.synchronize_hosts()\n return (make_policy, params, metrics, max_score_params)\n","metadata":{"cellView":"form","id":"fxmLdUcPUMSD","execution":{"iopub.status.busy":"2024-04-18T14:12:22.876924Z","iopub.execute_input":"2024-04-18T14:12:22.877476Z","iopub.status.idle":"2024-04-18T14:12:23.617090Z","shell.execute_reply.started":"2024-04-18T14:12:22.877448Z","shell.execute_reply":"2024-04-18T14:12:23.616214Z"},"trusted":true},"execution_count":11,"outputs":[]},{"cell_type":"code","source":"# Define the PPO networks (adapted from https://github.com/google/brax)\n\n@flax.struct.dataclass\nclass PPONetworks:\n policy_network: networks.FeedForwardNetwork\n value_network: networks.FeedForwardNetwork\n parametric_action_distribution: distribution.ParametricDistribution\n\ndef make_ppo_networks(\n observation_size: int,\n action_size: int,\n preprocess_observations_fn: types.PreprocessObservationFn = types\n .identity_observation_preprocessor,\n policy_hidden_layer_sizes: Sequence[int] = policy_hidden_layer_sizes,\n value_hidden_layer_sizes: Sequence[int] = value_hidden_layer_sizes,\n activation: 
networks.ActivationFn = linen.swish) -> PPONetworks:\n \"\"\"Make PPO networks with preprocessor.\"\"\"\n parametric_action_distribution = distribution.NormalTanhDistribution(\n event_size=action_size)\n policy_network = networks.make_policy_network(\n parametric_action_distribution.param_size,\n observation_size,\n preprocess_observations_fn=preprocess_observations_fn,\n hidden_layer_sizes=policy_hidden_layer_sizes,\n activation=activation)\n value_network = networks.make_value_network(\n observation_size,\n preprocess_observations_fn=preprocess_observations_fn,\n hidden_layer_sizes=value_hidden_layer_sizes,\n activation=activation)\n\n return PPONetworks(\n policy_network=policy_network,\n value_network=value_network,\n parametric_action_distribution=parametric_action_distribution)","metadata":{"execution":{"iopub.status.busy":"2024-04-18T14:12:23.618307Z","iopub.execute_input":"2024-04-18T14:12:23.618814Z","iopub.status.idle":"2024-04-18T14:12:23.627668Z","shell.execute_reply.started":"2024-04-18T14:12:23.618787Z","shell.execute_reply":"2024-04-18T14:12:23.626695Z"},"trusted":true},"execution_count":12,"outputs":[]},{"cell_type":"markdown","source":"## 7 - Training the Humanoid\n\nTraining for 120m timesteps with 9 evals takes about 45min with two T4 GPUs. That can be enough for it to learn to do 7 bounces on average (although it can take longer in some training runs as the optimization is non-deterministic).\n\nLearning to do better (~18 bounces on average) is possible in less than 3 hours.","metadata":{}},{"cell_type":"code","source":"# Load params to restart training from a saved checkpoint\n# (i.e. from the saved policy and value neural networks' weights)\nupload_model = False\nif upload_model:\n saved_params = model.load_params(save_path)\nelse:\n saved_params = None","metadata":{"id":"vgeiw_vNjwcq","execution":{"iopub.status.busy":"2024-04-18T15:36:50.405373Z","iopub.execute_input":"2024-04-18T15:36:50.405778Z","iopub.status.idle":"2024-04-18T15:36:50.419138Z","shell.execute_reply.started":"2024-04-18T15:36:50.405746Z","shell.execute_reply":"2024-04-18T15:36:50.418001Z"},"trusted":true},"execution_count":37,"outputs":[]},{"cell_type":"code","source":"# Train\ntrain_fn = functools.partial(\n train, num_timesteps=num_timesteps, num_evals=num_evals,\n episode_length=episode_length, normalize_observations=normalize_observations,\n action_repeat=action_repeat, unroll_length=unroll_length, num_minibatches=num_minibatches,\n num_updates_per_batch=num_updates_per_batch, discounting=discounting,\n learning_rate=learning_rate, entropy_cost=entropy_cost, num_envs=num_envs,\n reward_scaling=reward_scaling, clipping_epsilon=clipping_epsilon, gae_lambda=gae_lambda,\n normalize_advantage=normalize_advantage, batch_size=batch_size, seed=0,\n network_factory=make_ppo_networks, saved_params=saved_params)\n\nx_data = []\ny_data = []\nydataerr = []\ntimes = [datetime.now()]\n\nmax_y, min_y = 15000, 0\ndef progress(num_steps, metrics):\n times.append(datetime.now())\n x_data.append(num_steps)\n y_data.append(metrics['eval/episode_reward'])\n ydataerr.append(metrics['eval/episode_reward_std'])\n\n plt.xlim([0, train_fn.keywords['num_timesteps'] * 1.1])\n plt.ylim([min_y, max_y])\n\n plt.xlabel('# environment steps')\n plt.ylabel('reward per episode')\n plt.title(f'y={y_data[-1]:.1f}')\n\n plt.errorbar(\n x_data, y_data, yerr=ydataerr)\n plt.show()\n\n if 'training/policy_loss' in metrics:\n print(\"Other metrics\") \n print(f\"entropy loss: {metrics['training/entropy_loss']:.2f}\")\n print(f\"value 
loss: {metrics['training/v_loss']:.2f}\")\n print(f\"max episode reward: {int(metrics['eval/episode_reward_max'])}\")\n print(f\"avg bounces: {metrics['eval/episode_bounces']:.2f}\")\n print(f\"max bounces: {metrics['eval/episode_bounces_max']}\\n\")\n\nmake_inference_fn, train_params, _, max_score_params = train_fn(\n environment=env, progress_fn=progress)\n\nprint(f'time to jit: {times[1] - times[0]}')\nprint(f'time to train: {times[-1] - times[1]}')\nprint(f'total time: {times[-1] - times[0]}\\n')\nprint(f\"max score: {int(max_score_params['score'])}\")","metadata":{"id":"xLiddQYPApBw","cellView":"form","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 8 - Save and load the policy\n\nWe can save and load the policy using the brax model API.","metadata":{"id":"YYIch0HEApBx"}},{"cell_type":"code","source":"# Save the model\nmodel.save_params(save_path, train_params)\n# model.save_params(save_path, max_score_params['params'])","metadata":{"id":"Z8gI6qH6ApBx","execution":{"iopub.status.busy":"2024-04-18T16:01:31.935352Z","iopub.execute_input":"2024-04-18T16:01:31.935655Z","iopub.status.idle":"2024-04-18T16:01:31.950113Z","shell.execute_reply.started":"2024-04-18T16:01:31.935630Z","shell.execute_reply":"2024-04-18T16:01:31.949221Z"},"trusted":true},"execution_count":39,"outputs":[]},{"cell_type":"code","source":"# Load the model and define the inference function\ninference_fn = make_inference_fn(model.load_params(save_path)[:2])\njit_inference_fn = jax.jit(inference_fn)","metadata":{"id":"h4reaWgxApBx","cellView":"form","execution":{"iopub.status.busy":"2024-04-18T16:01:31.951495Z","iopub.execute_input":"2024-04-18T16:01:31.951804Z","iopub.status.idle":"2024-04-18T16:01:31.983341Z","shell.execute_reply.started":"2024-04-18T16:01:31.951777Z","shell.execute_reply":"2024-04-18T16:01:31.982292Z"},"trusted":true},"execution_count":40,"outputs":[]},{"cell_type":"markdown","source":"## 9 - Visualize the policy\n\nFinally we can visualize the Humanoid in action and watch while it bounces the ball with its foot!\n\nThis can also be saved to an mp4 file which you can then download from the `Output` section (can be found on the right if running in a laptop).","metadata":{"id":"0G357XIfApBy"}},{"cell_type":"code","source":"eval_env = envs.get_environment(env_name)\n\njit_reset = jax.jit(eval_env.reset)\njit_step = jax.jit(eval_env.step)","metadata":{"id":"osYasMw4ApBy","execution":{"iopub.status.busy":"2024-04-18T16:01:31.988305Z","iopub.execute_input":"2024-04-18T16:01:31.988616Z","iopub.status.idle":"2024-04-18T16:01:32.421997Z","shell.execute_reply.started":"2024-04-18T16:01:31.988589Z","shell.execute_reply":"2024-04-18T16:01:32.421164Z"},"trusted":true},"execution_count":41,"outputs":[]},{"cell_type":"code","source":"# Visualize the Humanoid and optionally save it to a mp4 file\n\ninit_time = datetime.now()\nrollouts = []\nmax_rollout_reward = 0\nmax_bounces = 0\nfor i in range(num_rollouts):\n # Initialize the state\n rng = jax.random.PRNGKey(i)\n state = jit_reset(rng)\n rollout = [state.pipeline_state]\n total_reward = 0\n total_bounces = 0\n\n # Grab a trajectory\n n_steps = 100000\n render_every = 2\n\n for i in range(n_steps):\n act_rng, rng = jax.random.split(rng)\n ctrl, _ = jit_inference_fn(state.obs, act_rng)\n state = jit_step(state, ctrl)\n total_reward += state.metrics[\"reward\"]\n total_bounces += state.metrics[\"bounces\"]\n rollout.append(state.pipeline_state)\n\n if state.done:\n break\n\n max_rollout_reward = max(max_rollout_reward, 
total_reward)\n max_bounces = max(max_bounces, total_bounces)\n \n if total_bounces > num_bounces_threshold:\n print(f\"Iteration with reward {int(total_reward)} and {int(total_bounces)} bounces\")\n video = env.render(rollout[::render_every], camera='side', height=480, width=640)\n media.show_video(video, fps=1.0 / env.dt / render_every)\n media.write_video(f\"/kaggle/working/ball_bounce_{int(total_reward)}_{int(total_bounces)}.mp4\", video, fps=1.0 / env.dt / render_every)\n \nprint(f\"Max rollout reward was - {int(max_rollout_reward)}\")\nprint(f\"Max bounces was - {int(max_bounces)}\")\nprint(f'total time: {datetime.now() - init_time}')","metadata":{"id":"d-UhypudApBy","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 10 - Tune training hyperparameters for better performance\n\nIt's possible that the current training hyperparams won't enable the Humanoid to learn to bounce the ball more than 6-8 times on average. \n\nPerformance can actually decline if you keep training beyond that point. This is a sign you may want to try a lower `learning_rate`.\n\nChange `upload_model` to `True` (to restart training from a saved checkpoint) and try updating the learning rates in the `2 - Config` section to see how far you can go. Good luck! \n\nAnd if you manage to go beyond 84 bounces, well done on the amazing work and please share!","metadata":{}}]} -------------------------------------------------------------------------------- /tutorials/head_tricks.ipynb: -------------------------------------------------------------------------------- 1 | {"metadata":{"colab":{"gpuClass":"premium","private_outputs":true,"provenance":[],"collapsed_sections":["YvyGCsgSCxHQ","P1K6IznI2y83"],"gpuType":"T4"},"kernelspec":{"name":"python3","display_name":"Python 3","language":"python"},"accelerator":"GPU","language_info":{"name":"python","version":"3.10.13","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"nvidiaTeslaT4","dataSources":[],"dockerImageVersionId":30698,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"markdown","source":"# Tutorial: Head Tricks\n\nIn this tutorial you will learn how to use the physics simulator MuJoCo, and Reinforcement Learning to teach a humanoid to bounce a ball with its head - an essential skill in the art of football.","metadata":{"id":"Jz-I63-7YgRO"}},{"cell_type":"markdown","source":"## 1 - Kaggle\nIt's recommended to run this notebook on www.kaggle.com where you can use two T4 GPUs for 30h/week for free.\n\nTo import this notebook into Kaggle you need to:\n- Login to your Kaggle account\n- Create a new notebook\n- Click on `File` and then `Import Notebook`\n- Select the tab `GitHub`\n- Search `goncalog/ai-robotics`\n- Select the file `tutorials/head_tricks.ipynb`\n- Click the `Import` button\n\nTo run this notebook you can either click the `Run All` button or run each cell individually by clicking the `Run current cell` button.","metadata":{}},{"cell_type":"markdown","source":"## 2 - Config\n\nThis is the configuration to run the tutorial, it includes:\n- Training hyperparameters\n- Mujoco environment variables\n- File paths\n- Rendering variables","metadata":{}},{"cell_type":"code","source":"num_timesteps = 30_000_000\nnum_evals = 5\n# num_envs: the number of parallel environments to use for rollouts\nnum_envs = 
2048\n\n# learning_rate: learning rate for ppo loss\nlearning_rate = 3e-4\n# discounting: discounting rate\ndiscounting = 0.97\n# episode_length: the length of an environment episode\nepisode_length = 1000\n# normalize_observations: whether to normalize observations\nnormalize_observations = True\n# action_repeat: the number of timesteps to repeat an action\naction_repeat = 1\n# unroll_length: the number of timesteps to unroll in each environment.\n# The PPO loss is computed over `unroll_length` timesteps\nunroll_length = 10\n# entropy_cost: entropy reward for ppo loss, higher values increase entropy of the policy\nentropy_cost = 1e-3\n# batch_size: the batch size for each minibatch SGD step\nbatch_size = 1024\n# num_minibatches: the number of times to run the SGD step,\n# each with a different minibatch with leading dimension of `batch_size`\nnum_minibatches = 32\n# num_updates_per_batch: the number of times to run the gradient update over\n# all minibatches before doing a new environment rollout\nnum_updates_per_batch = 8\n# reward_scaling: float scaling for reward\nreward_scaling = 1\n# clipping_epsilon: clipping epsilon for PPO loss\nclipping_epsilon = 0.3\n# gae_lambda: General advantage estimation lambda\ngae_lambda = 0.95\n# normalize_advantage: whether to normalize advantage estimate\nnormalize_advantage = True\n\npolicy_hidden_layer_sizes = (32,) * 4\nvalue_hidden_layer_sizes = (256,) * 5\n\nhead_type = \"sphere\" # can be sphere or box\nball_size = 0.15\ntorso_index = 2 # index of torso body in mjx data (it contains the head geom)\nball_height = 1.85 # z coordinate of centre of mass\nball_x = 0 # x coordinate of centre of mass\nbounce_threshold = ball_height - 0.05 # z coordinate\n\n# Simulation time step in seconds. \n# This is the single most important parameter affecting the speed-accuracy trade-off \n# which is inherent in every physics simulation. \n# Smaller values result in better accuracy and stability\nmj_model_timestep = 0.005\n\nsave_path = \"/kaggle/working/mjx_brax_nn\"\n\nnum_rollouts = 1\nnum_bounces_threshold = 0","metadata":{"id":"-Xt8DyfJYrbd","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 3 - Install MuJoCo, MJX, and Brax","metadata":{"id":"YvyGCsgSCxHQ"}},{"cell_type":"code","source":"!pip install mujoco==3.1.2\n!pip install mujoco_mjx==3.1.2\n!pip install brax==0.10.0","metadata":{"id":"Xqo7pyX-n72M","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Check if MuJoCo installation was successful\n\nimport distutils.util\nimport os\nimport subprocess\nif subprocess.run('nvidia-smi').returncode:\n raise RuntimeError(\n 'Cannot communicate with GPU. '\n 'Make sure you are using a GPU runtime. 
'\n 'Go to the Runtime menu and select Choose runtime type.')\n\n# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.\n# This is usually installed as part of an Nvidia driver package, but this\n# kernel doesn't install its driver via APT, and as a result the ICD is missing.\n# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)\nNVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/'\nNVIDIA_ICD_CONFIG_FILE = '10_nvidia.json'\nif not os.path.exists(NVIDIA_ICD_CONFIG_PATH):\n os.makedirs(NVIDIA_ICD_CONFIG_PATH)\n file_path = os.path.join(NVIDIA_ICD_CONFIG_PATH, NVIDIA_ICD_CONFIG_FILE)\n with open(file_path, 'w') as f:\n f.write(\"\"\"{\n \"file_format_version\" : \"1.0.0\",\n \"ICD\" : {\n \"library_path\" : \"libEGL_nvidia.so.0\"\n }\n}\n\"\"\")\n\n# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs\nxla_flags = os.environ.get('XLA_FLAGS', '')\nxla_flags += ' --xla_gpu_triton_gemm_any=True'\nos.environ['XLA_FLAGS'] = xla_flags\n\n# Configure MuJoCo to use the EGL rendering backend (requires GPU)\nprint('Setting environment variable to use GPU rendering:')\n%env MUJOCO_GL=egl\n\ntry:\n print('Checking that the installation succeeded:')\n import mujoco\n mujoco.MjModel.from_xml_string('')\nexcept Exception as e:\n raise e from RuntimeError(\n 'Something went wrong during installation. Check the shell output above '\n 'for more information.\\n'\n 'If using a hosted runtime, make sure you enable GPU acceleration '\n 'by going to the Runtime menu and selecting \"Choose runtime type\".')\n\nprint('Installation successful.')","metadata":{"id":"IbZxYDxzoz5R","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Import packages for plotting and creating graphics\nimport time\nimport itertools\nimport numpy as np\nfrom typing import Callable, NamedTuple, Optional, Union, List\n\n# Graphics and plotting.\nprint('Installing mediapy:')\n!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)\n!pip install -q mediapy\nimport mediapy as media\nimport matplotlib.pyplot as plt\n\n# More legible printing from numpy.\nnp.set_printoptions(precision=3, suppress=True, linewidth=100)","metadata":{"id":"T5f4w3Kq2X14","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Import MuJoCo, MJX, and Brax\n\nfrom datetime import datetime\nimport functools\nimport jax\nfrom jax import numpy as jp\nimport numpy as np\nfrom typing import Any, Dict, Sequence, Tuple, Union\n\nfrom brax import base\nfrom brax import envs\nfrom brax import math\nfrom brax.base import Base, Motion, Transform\nfrom brax.envs.base import Env, PipelineEnv, State\nfrom brax.mjx.base import State as MjxState\nfrom brax.io import html, mjcf, model\nfrom brax.training import distribution, networks\n\nfrom etils import epath\nfrom flax import linen, struct\nfrom matplotlib import pyplot as plt\nimport mediapy as media\nfrom ml_collections import config_dict\nimport mujoco\nfrom mujoco import mjx\n","metadata":{"id":"ObF1UXrkb0Nd","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 4 - Setting up the Humanoid environment with MJX\nMJX is an implementation of MuJoCo written in [JAX](https://jax.readthedocs.io/en/latest/index.html), enabling large batch training on GPU/TPU. 
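(An aside, not a cell in the original notebook.) Since all of the batched MJX training below runs through JAX, it can also be worth confirming that JAX itself sees the GPUs, not just MuJoCo's EGL renderer. A minimal check, assuming the 2x T4 Kaggle runtime recommended in section 1, might look like this:

```python
# Hedged sanity check that the JAX CUDA backend is active before training.
import jax

print(jax.default_backend())     # expected: 'gpu' (or similar) on a GPU runtime
print(jax.devices())             # expected: two CUDA devices on the 2x T4 runtime
print(jax.local_device_count())  # the PPO trainer below pmaps across these devices
```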
In this notebook, we train RL policies with MJX.\n\nHere we implement our environment by adapting the original [Humanoid](https://github.com/google-deepmind/mujoco/blob/546a27ca72397b888e314ee4549bcf12d9fd5957/model/humanoid/humanoid.xml) environment to also include a ball. Notice that `reset` initializes a `State`, and `step` steps through the physics step and reward logic. The reward and stepping logic train the Humanoid to bounce a ball with its head.","metadata":{"id":"RAv6WUVUm78k"}},{"cell_type":"code","source":"# Humanoid XML\n\nball_material = \"\"\"\n \n \n \"\"\"\nball_default = f\"\"\"\n \n \n \n \"\"\"\nball_body = f\"\"\"\n \n \n \n \n \"\"\"\nball_contact = ''\nif head_type == \"sphere\":\n head_xml = ''\nelif head_type == \"box\":\n head_xml = ''\nelse:\n raise Exception(f\"{head_type=} isn't supported.\")\n\nxml = f\"\"\"\n\n \n\n \n \n \n \n \n \n\n \n\n \n \n \n \n\n \n \n \n {ball_material}\n \n\n \n {ball_default}\n \n \n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n \n \n \n\n \n \n \n \n \n \n \n \n {head_xml}\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n {ball_body}\n \n\n \n \n \n \n \n \n \n {ball_contact}\n \n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n\"\"\"","metadata":{"id":"G5RKmMLkUuC8","cellView":"form","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Humanoid Env\n\nclass Humanoid(PipelineEnv):\n\n def __init__(\n self,\n terminate_when_unhealthy=True,\n reset_noise_scale=1e-2,\n exclude_current_positions_from_observation=True,\n **kwargs,\n ):\n mj_model = mujoco.MjModel.from_xml_string(xml)\n mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG\n mj_model.opt.iterations = 6\n mj_model.opt.ls_iterations = 6\n\n sys = mjcf.load_model(mj_model)\n\n physics_steps_per_control_step = 5\n kwargs['n_frames'] = kwargs.get(\n 'n_frames', physics_steps_per_control_step)\n kwargs['backend'] = 'mjx'\n\n super().__init__(sys, **kwargs)\n\n self._terminate_when_unhealthy = terminate_when_unhealthy\n self._reset_noise_scale = reset_noise_scale\n self._exclude_current_positions_from_observation = (\n exclude_current_positions_from_observation\n )\n\n\n def reset(self, rng: jp.ndarray) -> State:\n \"\"\"Resets the environment to an initial state.\"\"\"\n rng, rng1, rng2 = jax.random.split(rng, 3)\n\n low, hi = -self._reset_noise_scale, self._reset_noise_scale\n qpos = self.sys.qpos0 + jax.random.uniform(\n rng1, (self.sys.nq,), minval=low, maxval=hi\n )\n qvel = jax.random.uniform(\n rng2, (self.sys.nv,), minval=low, maxval=hi\n )\n \n data = self.pipeline_init(qpos, qvel)\n\n obs = self._get_obs(data, jp.zeros(self.sys.nu))\n reward, done, zero = jp.zeros(3)\n metrics = {\n 'ball_reward': zero,\n 'reward_quadctrl': zero,\n 'reward_alive': zero,\n 'reward': zero,\n 'bounces': zero,\n }\n return State(data, obs, reward, done, metrics)\n\n\n def step(self, state: State, action: jp.ndarray) -> State:\n \"\"\"Runs one timestep of the environment's dynamics.\"\"\"\n data0 = state.pipeline_state\n data = self.pipeline_step(data0, action)\n \n reward, done = self._get_reward(state, action, data0, data)\n obs = self._get_obs(data, action)\n return state.replace(\n pipeline_state=data, obs=obs, reward=reward, 
done=done\n )\n\n\n def _get_obs(\n self, data: mjx.Data, action: jp.ndarray\n ) -> jp.ndarray:\n \"\"\"Observes humanoid body and ball position, velocities, and angles.\"\"\"\n position = data.qpos\n if self._exclude_current_positions_from_observation:\n position = position[2:]\n\n # external_contact_forces are excluded\n return jp.concatenate([\n # qpos: position / nq: number of generalized coordinates = dim(qpos)\n position,\n # qvel: velocity / nv: number of degrees of freedom = dim(qvel)\n data.qvel,\n # cinert: com-based body inertia and mass / (nbody, 10)\n data.cinert[1:].ravel(),\n # cvel: com-based velocity [3D rot; 3D tran] / (nbody, 6)\n data.cvel[1:].ravel(),\n # qfrc_actuator: actuator force / nv: number of degrees of freedom\n data.qfrc_actuator,\n ])\n\n\n def _get_reward(\n self, state: State, action: jp.ndarray, data0: mjx.Data, data: mjx.Data\n ) -> Tuple[jp.ndarray, jp.ndarray]:\n \"\"\"Apply reward function and return outputs.\"\"\"\n ball_reward = 100.0\n ctrl_cost_weight = 0.1\n healthy_reward = 1.0\n healthy_z_range = (0.5, 3.0)\n ball_healthy_x_y_range = (-0.5, 0.5)\n ball_healthy_z_range = (1.0, 2.4)\n ball_reward_min_z = 2.2\n \n com_before_ball = data0.subtree_com[-1]\n com_after_ball = data.subtree_com[-1]\n\n min_z, max_z = healthy_z_range\n is_healthy = jp.where(data.q[torso_index] < min_z, 0.0, 1.0)\n is_healthy = jp.where(data.q[torso_index] > max_z, 0.0, is_healthy)\n\n ball_min_x, ball_max_x = ball_healthy_x_y_range\n ball_min_y, ball_max_y = ball_healthy_x_y_range\n ball_min_z, ball_max_z = ball_healthy_z_range\n is_healthy = jp.where(com_after_ball[0] < ball_min_x, 0.0, is_healthy)\n is_healthy = jp.where(com_after_ball[0] > ball_max_x, 0.0, is_healthy)\n is_healthy = jp.where(com_after_ball[1] < ball_min_y, 0.0, is_healthy)\n is_healthy = jp.where(com_after_ball[1] > ball_max_y, 0.0, is_healthy)\n is_healthy = jp.where(com_after_ball[2] < ball_min_z, 0.0, is_healthy)\n is_healthy = jp.where(com_after_ball[2] > ball_max_z, 0.0, is_healthy)\n\n if not self._terminate_when_unhealthy:\n healthy_reward = healthy_reward * is_healthy\n ball_reward = ball_reward * is_healthy\n\n ctrl_cost = ctrl_cost_weight * jp.sum(jp.square(action))\n\n is_ball_reward = jp.where(com_after_ball[2] >= ball_reward_min_z, 1.0, 0.0)\n reward = ball_reward * is_ball_reward + healthy_reward - ctrl_cost\n\n state.metrics.update(\n ball_reward=ball_reward * is_ball_reward,\n reward_quadctrl=-ctrl_cost,\n reward_alive=healthy_reward,\n reward=reward,\n bounces=self._is_bounce(com_before_ball, com_after_ball),\n )\n \n done = 1.0 - is_healthy if self._terminate_when_unhealthy else 0.0\n return reward, done\n\n\n def _get_reward_new(\n self, state: State, action: jp.ndarray, data0: mjx.Data, data: mjx.Data\n ) -> Tuple[jp.ndarray, jp.ndarray]:\n \"\"\"Apply reward func based on ball distance to normal of the head and target height.\"\"\"\n ctrl_cost_weight = 0.1\n healthy_reward = 5.0\n healthy_z_range = (1.0, 3.0)\n ball_reward = 5.0\n ball_healthy_z_range = (1.0, 4.0)\n ball_reward_min_z = 2.0\n ball_reward_target_z = 2.2\n distance_head_reward = 5.0\n distance_head_max = 2.0\n \n com_before_ball = data0.subtree_com[-1]\n com_after_ball = data.subtree_com[-1]\n com_after_head = data.subtree_com[torso_index]\n distance_head = jp.sqrt(jp.square(com_after_ball[0] - com_after_head[0]) + jp.square(com_after_ball[1] - com_after_head[1]))\n\n min_z, max_z = healthy_z_range\n is_healthy = jp.where(data.q[torso_index] < min_z, 0.0, 1.0)\n is_healthy = jp.where(data.q[torso_index] > max_z, 0.0, 
is_healthy)\n\n ball_min_z, ball_max_z = ball_healthy_z_range\n is_healthy = jp.where(com_after_ball[2] < ball_min_z, 0.0, is_healthy)\n is_healthy = jp.where(com_after_ball[2] > ball_max_z, 0.0, is_healthy)\n \n is_healthy = jp.where(distance_head > distance_head_max, 0.0, is_healthy)\n\n ctrl_cost = ctrl_cost_weight * jp.sum(jp.square(action))\n \n distance_target_height = jp.sqrt(jp.square(com_after_ball[2] - ball_reward_target_z))\n ball_reward = ball_reward * (1 - (distance_target_height / (ball_max_z - ball_reward_target_z)))\n is_ball_reward = jp.where(com_after_ball[2] >= ball_reward_min_z, 1.0, 0.0)\n \n distance_head_reward = distance_head_reward * (1 - (distance_head / distance_head_max))\n \n reward = ball_reward * is_ball_reward + healthy_reward - ctrl_cost + distance_head_reward\n\n state.metrics.update(\n ball_reward=ball_reward * is_ball_reward,\n reward_quadctrl=-ctrl_cost,\n reward_alive=healthy_reward,\n reward=reward,\n bounces=self._is_bounce(com_before_ball, com_after_ball),\n )\n \n done = 1.0 - is_healthy\n return reward, done\n \n \n # There is a lot of room to improve this function as it should check for contacts\n # between the ball and the head of the Humanoid\n # (at the time of implementation the contacts data wasn't easily accessible in MJX)\n def _is_bounce(\n self, com_before_ball: jp.ndarray, com_after_ball: jp.ndarray\n ) -> jp.ndarray:\n \"\"\"Check if ball bounced.\"\"\"\n is_bounce = jp.where(com_before_ball[2] < bounce_threshold, 1.0, 0.0)\n is_bounce = jp.where(com_after_ball[2] >= bounce_threshold, is_bounce, 0.0)\n return is_bounce\n \n\nenvs.register_environment(\"humanoid\", Humanoid)","metadata":{"id":"mtGMYNLE3QJN","cellView":"form","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 5 - Visualize a rollout\n\nLet's instantiate the environment and visualize a short rollout.\n\nNOTE: Since episodes terminate early if the torso is below the healthy z-range, the only relevant contacts for this task are between the feet and the plane, and the head and the ball. The other contacts weren't included. 
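As the comment inside the environment code above notes, `_is_bounce` is only a proxy for real head-ball contacts (contact data wasn't easily accessible in MJX at the time). A self-contained sketch of that threshold-crossing heuristic, using the `bounce_threshold = ball_height - 0.05 = 1.80` derived from the section 2 config, is shown here; it is an illustration, not a new notebook cell:

```python
# Standalone illustration of the heuristic in Humanoid._is_bounce: a bounce is
# counted only on the step where the ball's centre of mass rises back above
# bounce_threshold.
import jax.numpy as jp

bounce_threshold = 1.80  # ball_height - 0.05 for the section 2 config

def is_bounce(z_before, z_after):
    hit = jp.where(z_before < bounce_threshold, 1.0, 0.0)   # was below last step
    return jp.where(z_after >= bounce_threshold, hit, 0.0)  # and is at/above now

print(is_bounce(1.70, 1.85))  # 1.0 -> upward crossing, one bounce
print(is_bounce(1.85, 1.90))  # 0.0 -> already above, not a new bounce
print(is_bounce(1.70, 1.75))  # 0.0 -> still below the threshold
```

As for the contact pairs themselves, only feet-plane and head-ball are declared in the model.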
This also speeds up the training later on.","metadata":{"id":"P1K6IznI2y83"}},{"cell_type":"code","source":"# Instantiate the environment\nenv_name = \"humanoid\"\nenv = envs.get_environment(env_name)\n\n# Define the jit reset/step functions\njit_reset = jax.jit(env.reset)\njit_step = jax.jit(env.step)","metadata":{"id":"EhKLFK54C1CH","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Initialize the state\nstate = jit_reset(jax.random.PRNGKey(0))\nprint(f\"Observations size: {len(state.obs)}\")\nprint(f\"Actions size: {env.sys.nu}\")\n\nrollout = [state.pipeline_state]\n\n# Grab a trajectory\nfor i in range(50):\n # ctrl: control / nu: number of actuators/controls = dim(ctrl)\n ctrl = -0.1 * jp.ones(env.sys.nu)\n state = jit_step(state, ctrl)\n rollout.append(state.pipeline_state)\n\nmedia.show_video(env.render(rollout, camera='side', height=480, width=640), fps=1.0 / env.dt)","metadata":{"id":"Ph8u-v2Q2xLS","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 6 - Define the training functions\n\nLet's define the training functions using [PPO](https://openai.com/research/openai-baselines-ppo) to make the Humanoid bounce the ball with its head.","metadata":{"id":"BQDG6NQ1CbZD"}},{"cell_type":"code","source":"# Define the Acting/Evaluator (adapted from https://github.com/google/brax)\n\n\"\"\"Brax training acting functions.\"\"\"\n\nimport time\nfrom typing import Callable, Sequence, Tuple, Union\n\nfrom brax import envs\nfrom brax.training.types import Metrics\nfrom brax.training.types import Policy\nfrom brax.training.types import PolicyParams\nfrom brax.training.types import PRNGKey\nfrom brax.training.types import Transition\nfrom brax.v1 import envs as envs_v1\nimport jax\nimport numpy as np\n\nActingState = Union[envs.State, envs_v1.State]\nActingEnv = Union[envs.Env, envs_v1.Env, envs_v1.Wrapper]\n\n\ndef actor_step(\n env: ActingEnv,\n env_state: ActingState,\n policy: Policy,\n key: PRNGKey,\n extra_fields: Sequence[str] = ()\n) -> Tuple[ActingState, Transition]:\n \"\"\"Collect data.\"\"\"\n actions, policy_extras = policy(env_state.obs, key)\n nstate = env.step(env_state, actions)\n state_extras = {x: nstate.info[x] for x in extra_fields}\n return nstate, Transition( # pytype: disable=wrong-arg-types # jax-ndarray\n observation=env_state.obs,\n action=actions,\n reward=nstate.reward,\n discount=1 - nstate.done,\n next_observation=nstate.obs,\n extras={\n 'policy_extras': policy_extras,\n 'state_extras': state_extras\n })\n\n\ndef generate_unroll(\n env: ActingEnv,\n env_state: ActingState,\n policy: Policy,\n key: PRNGKey,\n unroll_length: int,\n extra_fields: Sequence[str] = ()\n) -> Tuple[ActingState, Transition]:\n \"\"\"Collect trajectories of given unroll_length.\"\"\"\n\n @jax.jit\n def f(carry, unused_t):\n state, current_key = carry\n current_key, next_key = jax.random.split(current_key)\n nstate, transition = actor_step(\n env, state, policy, current_key, extra_fields=extra_fields)\n return (nstate, next_key), transition\n\n (final_state, _), data = jax.lax.scan(\n f, (env_state, key), (), length=unroll_length)\n return final_state, data\n\n\n# TODO: Consider moving this to its own file.\nclass Evaluator:\n \"\"\"Class to run evaluations.\"\"\"\n\n def __init__(self, eval_env: envs.Env,\n eval_policy_fn: Callable[[PolicyParams],\n Policy], num_eval_envs: int,\n episode_length: int, action_repeat: int, key: PRNGKey):\n \"\"\"Init.\n\n Args:\n eval_env: Batched environment to run evals on.\n 
eval_policy_fn: Function returning the policy from the policy parameters.\n num_eval_envs: Each env will run 1 episode in parallel for each eval.\n episode_length: Maximum length of an episode.\n action_repeat: Number of physics steps per env step.\n key: RNG key.\n \"\"\"\n self._key = key\n self._eval_walltime = 0.\n\n eval_env = envs.training.EvalWrapper(eval_env)\n\n def generate_eval_unroll(policy_params: PolicyParams,\n key: PRNGKey) -> ActingState:\n reset_keys = jax.random.split(key, num_eval_envs)\n eval_first_state = eval_env.reset(reset_keys)\n return generate_unroll(\n eval_env,\n eval_first_state,\n eval_policy_fn(policy_params),\n key,\n unroll_length=episode_length // action_repeat)[0]\n\n self._generate_eval_unroll = jax.jit(generate_eval_unroll)\n self._steps_per_unroll = episode_length * num_eval_envs\n\n def run_evaluation(self,\n policy_params: PolicyParams,\n training_metrics: Metrics,\n aggregate_episodes: bool = True) -> Metrics:\n \"\"\"Run one epoch of evaluation.\"\"\"\n self._key, unroll_key = jax.random.split(self._key)\n\n t = time.time()\n eval_state = self._generate_eval_unroll(policy_params, unroll_key)\n eval_metrics = eval_state.info['eval_metrics']\n eval_metrics.active_episodes.block_until_ready()\n epoch_eval_time = time.time() - t\n metrics = {}\n for fn in [np.mean, np.std, np.max]:\n suffix = '_std' if fn == np.std else '_max' if fn == np.max else ''\n metrics.update(\n {\n f'eval/episode_{name}{suffix}': (\n fn(value) if aggregate_episodes else value\n )\n for name, value in eval_metrics.episode_metrics.items()\n }\n )\n metrics['eval/avg_episode_length'] = np.mean(eval_metrics.episode_steps)\n metrics['eval/epoch_eval_time'] = epoch_eval_time\n metrics['eval/sps'] = self._steps_per_unroll / epoch_eval_time\n self._eval_walltime = self._eval_walltime + epoch_eval_time\n metrics = {\n 'eval/walltime': self._eval_walltime,\n **training_metrics,\n **metrics\n }\n\n return metrics # pytype: disable=bad-return-type # jax-ndarray","metadata":{"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Define the Training Function (adapted from https://github.com/google/brax)\n\n\"\"\"Proximal policy optimization training.\n\nSee: https://arxiv.org/pdf/1707.06347.pdf\n\"\"\"\n\nimport functools\nimport time\nfrom typing import Callable, Optional, Tuple, Union\n\nfrom absl import logging\nfrom brax import base\nfrom brax import envs\nfrom brax.training import gradients\nfrom brax.training import pmap\nfrom brax.training import types\nfrom brax.training.acme import running_statistics\nfrom brax.training.acme import specs\nfrom brax.training.agents.ppo import losses as ppo_losses\nfrom brax.training.agents.ppo import networks as ppo_networks\nfrom brax.training.types import Params, PolicyParams, PreprocessorParams\nfrom brax.training.types import PRNGKey\nfrom brax.v1 import envs as envs_v1\nimport flax\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport optax\n\n\nInferenceParams = Tuple[running_statistics.NestedMeanStd, Params]\nMetrics = types.Metrics\nValueParams = Any\n\n_PMAP_AXIS_NAME = 'i'\n\n\n@flax.struct.dataclass\nclass TrainingState:\n \"\"\"Contains training state for the learner.\"\"\"\n optimizer_state: optax.OptState\n params: ppo_losses.PPONetworkParams\n normalizer_params: running_statistics.RunningStatisticsState\n env_steps: jnp.ndarray\n\n\ndef _unpmap(v):\n return jax.tree_util.tree_map(lambda x: x[0], v)\n\n\ndef _strip_weak_type(tree):\n # brax user code is sometimes ambiguous about weak_type. 
in order to\n # avoid extra jit recompilations we strip all weak types from user input\n def f(leaf):\n leaf = jnp.asarray(leaf)\n return leaf.astype(leaf.dtype)\n return jax.tree_util.tree_map(f, tree)\n\n\ndef train(\n environment: Union[envs_v1.Env, envs.Env],\n num_timesteps: int,\n episode_length: int,\n action_repeat: int = 1,\n num_envs: int = 1,\n max_devices_per_host: Optional[int] = None,\n num_eval_envs: int = 128,\n learning_rate: float = 1e-4,\n entropy_cost: float = 1e-4,\n discounting: float = 0.9,\n seed: int = 0,\n unroll_length: int = 10,\n batch_size: int = 32,\n num_minibatches: int = 16,\n num_updates_per_batch: int = 2,\n num_evals: int = 1,\n num_resets_per_eval: int = 0,\n normalize_observations: bool = False,\n reward_scaling: float = 1.0,\n clipping_epsilon: float = 0.3,\n gae_lambda: float = 0.95,\n deterministic_eval: bool = False,\n network_factory: types.NetworkFactory[\n ppo_networks.PPONetworks\n ] = ppo_networks.make_ppo_networks,\n progress_fn: Callable[[int, Metrics], None] = lambda *args: None,\n normalize_advantage: bool = True,\n eval_env: Optional[envs.Env] = None,\n policy_params_fn: Callable[..., None] = lambda *args: None,\n randomization_fn: Optional[\n Callable[[base.System, jnp.ndarray], Tuple[base.System, base.System]]\n ] = None,\n saved_params: Optional[\n Tuple[PreprocessorParams, PolicyParams, ValueParams]\n ] = None,\n):\n \"\"\"PPO training.\n\n Args:\n environment: the environment to train\n num_timesteps: the total number of environment steps to use during training\n episode_length: the length of an environment episode\n action_repeat: the number of timesteps to repeat an action\n num_envs: the number of parallel environments to use for rollouts\n NOTE: `num_envs` must be divisible by the total number of chips since each\n chip gets `num_envs // total_number_of_chips` environments to roll out\n NOTE: `batch_size * num_minibatches` must be divisible by `num_envs` since\n data generated by `num_envs` parallel envs gets used for gradient\n updates over `num_minibatches` of data, where each minibatch has a\n leading dimension of `batch_size`\n max_devices_per_host: maximum number of chips to use per host process\n num_eval_envs: the number of envs to use for evaluation. Each env will run 1\n episode, and all envs run in parallel during eval.\n learning_rate: learning rate for ppo loss\n entropy_cost: entropy reward for ppo loss, higher values increase entropy\n of the policy\n discounting: discounting rate\n seed: random seed\n unroll_length: the number of timesteps to unroll in each environment. The\n PPO loss is computed over `unroll_length` timesteps\n batch_size: the batch size for each minibatch SGD step\n num_minibatches: the number of times to run the SGD step, each with a\n different minibatch with leading dimension of `batch_size`\n num_updates_per_batch: the number of times to run the gradient update over\n all minibatches before doing a new environment rollout\n num_evals: the number of evals to run during the entire training run.\n Increasing the number of evals increases total training time\n num_resets_per_eval: the number of environment resets to run between each\n eval. 
The environment resets occur on the host\n normalize_observations: whether to normalize observations\n reward_scaling: float scaling for reward\n clipping_epsilon: clipping epsilon for PPO loss\n gae_lambda: General advantage estimation lambda\n deterministic_eval: whether to run the eval with a deterministic policy\n network_factory: function that generates networks for policy and value\n functions\n progress_fn: a user-defined callback function for reporting/plotting metrics\n normalize_advantage: whether to normalize advantage estimate\n eval_env: an optional environment for eval only, defaults to `environment`\n policy_params_fn: a user-defined callback function that can be used for\n saving policy checkpoints\n randomization_fn: a user-defined callback function that generates randomized\n environments\n saved_params: params to init the training with; includes normalizer_params\n and policy and value network params\n\n Returns:\n Tuple of (make_policy function, network params, metrics)\n \"\"\"\n assert batch_size * num_minibatches % num_envs == 0\n xt = time.time()\n\n process_count = jax.process_count()\n process_id = jax.process_index()\n local_device_count = jax.local_device_count()\n local_devices_to_use = local_device_count\n if max_devices_per_host:\n local_devices_to_use = min(local_devices_to_use, max_devices_per_host)\n logging.info(\n 'Device count: %d, process count: %d (id %d), local device count: %d, '\n 'devices to be used count: %d', jax.device_count(), process_count,\n process_id, local_device_count, local_devices_to_use)\n device_count = local_devices_to_use * process_count\n\n # The number of environment steps executed for every training step.\n env_step_per_training_step = (\n batch_size * unroll_length * num_minibatches * action_repeat)\n num_evals_after_init = max(num_evals - 1, 1)\n # The number of training_step calls per training_epoch call.\n # equals to ceil(num_timesteps / (num_evals * env_step_per_training_step *\n # num_resets_per_eval))\n num_training_steps_per_epoch = np.ceil(\n num_timesteps\n / (\n num_evals_after_init\n * env_step_per_training_step\n * max(num_resets_per_eval, 1)\n )\n ).astype(int)\n\n key = jax.random.PRNGKey(seed)\n global_key, local_key = jax.random.split(key)\n del key\n local_key = jax.random.fold_in(local_key, process_id)\n local_key, key_env, eval_key = jax.random.split(local_key, 3)\n # key_networks should be global, so that networks are initialized the same\n # way for different processes.\n key_policy, key_value = jax.random.split(global_key)\n del global_key\n\n assert num_envs % device_count == 0\n\n v_randomization_fn = None\n if randomization_fn is not None:\n randomization_batch_size = num_envs // local_device_count\n # all devices gets the same randomization rng\n randomization_rng = jax.random.split(key_env, randomization_batch_size)\n v_randomization_fn = functools.partial(\n randomization_fn, rng=randomization_rng\n )\n\n if isinstance(environment, envs.Env):\n wrap_for_training = envs.training.wrap\n else:\n wrap_for_training = envs_v1.wrappers.wrap_for_training\n\n env = wrap_for_training(\n environment,\n episode_length=episode_length,\n action_repeat=action_repeat,\n randomization_fn=v_randomization_fn,\n )\n\n reset_fn = jax.jit(jax.vmap(env.reset))\n key_envs = jax.random.split(key_env, num_envs // process_count)\n key_envs = jnp.reshape(key_envs,\n (local_devices_to_use, -1) + key_envs.shape[1:])\n env_state = reset_fn(key_envs)\n\n normalize = lambda x, y: x\n if normalize_observations:\n normalize = 
running_statistics.normalize\n ppo_network = network_factory(\n env_state.obs.shape[-1],\n env.action_size,\n preprocess_observations_fn=normalize)\n make_policy = ppo_networks.make_inference_fn(ppo_network)\n\n optimizer = optax.adam(learning_rate=learning_rate)\n\n loss_fn = functools.partial(\n ppo_losses.compute_ppo_loss,\n ppo_network=ppo_network,\n entropy_cost=entropy_cost,\n discounting=discounting,\n reward_scaling=reward_scaling,\n gae_lambda=gae_lambda,\n clipping_epsilon=clipping_epsilon,\n normalize_advantage=normalize_advantage)\n\n gradient_update_fn = gradients.gradient_update_fn(\n loss_fn, optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True)\n\n def minibatch_step(\n carry, data: types.Transition,\n normalizer_params: running_statistics.RunningStatisticsState):\n optimizer_state, params, key = carry\n key, key_loss = jax.random.split(key)\n (_, metrics), params, optimizer_state = gradient_update_fn(\n params,\n normalizer_params,\n data,\n key_loss,\n optimizer_state=optimizer_state)\n\n return (optimizer_state, params, key), metrics\n\n def sgd_step(carry, unused_t, data: types.Transition,\n normalizer_params: running_statistics.RunningStatisticsState):\n optimizer_state, params, key = carry\n key, key_perm, key_grad = jax.random.split(key, 3)\n\n def convert_data(x: jnp.ndarray):\n x = jax.random.permutation(key_perm, x)\n x = jnp.reshape(x, (num_minibatches, -1) + x.shape[1:])\n return x\n\n shuffled_data = jax.tree_util.tree_map(convert_data, data)\n (optimizer_state, params, _), metrics = jax.lax.scan(\n functools.partial(minibatch_step, normalizer_params=normalizer_params),\n (optimizer_state, params, key_grad),\n shuffled_data,\n length=num_minibatches)\n return (optimizer_state, params, key), metrics\n\n def training_step(\n carry: Tuple[TrainingState, envs.State, PRNGKey],\n unused_t) -> Tuple[Tuple[TrainingState, envs.State, PRNGKey], Metrics]:\n training_state, state, key = carry\n key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3)\n\n policy = make_policy(\n (training_state.normalizer_params, training_state.params.policy))\n\n def f(carry, unused_t):\n current_state, current_key = carry\n current_key, next_key = jax.random.split(current_key)\n next_state, data = generate_unroll(\n env,\n current_state,\n policy,\n current_key,\n unroll_length,\n extra_fields=('truncation',))\n return (next_state, next_key), data\n\n (state, _), data = jax.lax.scan(\n f, (state, key_generate_unroll), (),\n length=batch_size * num_minibatches // num_envs)\n # Have leading dimensions (batch_size * num_minibatches, unroll_length)\n data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 1, 2), data)\n data = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]),\n data)\n assert data.discount.shape[1:] == (unroll_length,)\n\n # Update normalization params and normalize observations.\n normalizer_params = running_statistics.update(\n training_state.normalizer_params,\n data.observation,\n pmap_axis_name=_PMAP_AXIS_NAME)\n\n (optimizer_state, params, _), metrics = jax.lax.scan(\n functools.partial(\n sgd_step, data=data, normalizer_params=normalizer_params),\n (training_state.optimizer_state, training_state.params, key_sgd), (),\n length=num_updates_per_batch)\n\n new_training_state = TrainingState(\n optimizer_state=optimizer_state,\n params=params,\n normalizer_params=normalizer_params,\n env_steps=training_state.env_steps + env_step_per_training_step)\n return (new_training_state, state, new_key), metrics\n\n def training_epoch(training_state: 
TrainingState, state: envs.State,\n key: PRNGKey) -> Tuple[TrainingState, envs.State, Metrics]:\n (training_state, state, _), loss_metrics = jax.lax.scan(\n training_step, (training_state, state, key), (),\n length=num_training_steps_per_epoch)\n loss_metrics = jax.tree_util.tree_map(jnp.mean, loss_metrics)\n return training_state, state, loss_metrics\n\n training_epoch = jax.pmap(training_epoch, axis_name=_PMAP_AXIS_NAME)\n\n # Note that this is NOT a pure jittable method.\n def training_epoch_with_timing(\n training_state: TrainingState, env_state: envs.State,\n key: PRNGKey) -> Tuple[TrainingState, envs.State, Metrics]:\n nonlocal training_walltime\n t = time.time()\n training_state, env_state = _strip_weak_type((training_state, env_state))\n result = training_epoch(training_state, env_state, key)\n training_state, env_state, metrics = _strip_weak_type(result)\n\n metrics = jax.tree_util.tree_map(jnp.mean, metrics)\n jax.tree_util.tree_map(lambda x: x.block_until_ready(), metrics)\n\n epoch_training_time = time.time() - t\n training_walltime += epoch_training_time\n sps = (num_training_steps_per_epoch *\n env_step_per_training_step *\n max(num_resets_per_eval, 1)) / epoch_training_time\n metrics = {\n 'training/sps': sps,\n 'training/walltime': training_walltime,\n **{f'training/{name}': value for name, value in metrics.items()}\n }\n return training_state, env_state, metrics # pytype: disable=bad-return-type # py311-upgrade\n\n\n if saved_params is None:\n init_params = ppo_losses.PPONetworkParams(\n policy=ppo_network.policy_network.init(key_policy),\n value=ppo_network.value_network.init(key_value))\n normalizer_params = running_statistics.init_state(\n specs.Array(env_state.obs.shape[-1:], jnp.dtype('float32')))\n else:\n init_params = ppo_losses.PPONetworkParams(\n policy=saved_params[1],\n value=saved_params[2])\n normalizer_params = saved_params[0]\n\n training_state = TrainingState( # pytype: disable=wrong-arg-types # jax-ndarray\n optimizer_state=optimizer.init(init_params), # pytype: disable=wrong-arg-types # numpy-scalars\n params=init_params,\n normalizer_params=normalizer_params,\n env_steps=0)\n training_state = jax.device_put_replicated(\n training_state,\n jax.local_devices()[:local_devices_to_use])\n\n if not eval_env:\n eval_env = environment\n if randomization_fn is not None:\n v_randomization_fn = functools.partial(\n randomization_fn, rng=jax.random.split(eval_key, num_eval_envs)\n )\n eval_env = wrap_for_training(\n eval_env,\n episode_length=episode_length,\n action_repeat=action_repeat,\n randomization_fn=v_randomization_fn,\n )\n\n evaluator = Evaluator(\n eval_env,\n functools.partial(make_policy, deterministic=deterministic_eval),\n num_eval_envs=num_eval_envs,\n episode_length=episode_length,\n action_repeat=action_repeat,\n key=eval_key)\n\n # Run initial eval\n metrics = {}\n if process_id == 0 and num_evals > 1:\n metrics = evaluator.run_evaluation(\n _unpmap(\n (training_state.normalizer_params, training_state.params.policy)),\n training_metrics={})\n logging.info(metrics)\n progress_fn(0, metrics)\n\n training_metrics = {}\n training_walltime = 0\n current_step = 0\n # Initialize variables to allow saving params of run with max score\n max_score = 0\n max_score_params = {}\n for it in range(num_evals_after_init):\n logging.info('starting iteration %s %s', it, time.time() - xt)\n\n for _ in range(max(num_resets_per_eval, 1)):\n # optimization\n epoch_key, local_key = jax.random.split(local_key)\n epoch_keys = jax.random.split(epoch_key, 
local_devices_to_use)\n (training_state, env_state, training_metrics) = (\n training_epoch_with_timing(training_state, env_state, epoch_keys)\n )\n current_step = int(_unpmap(training_state.env_steps))\n\n key_envs = jax.vmap(\n lambda x, s: jax.random.split(x[0], s),\n in_axes=(0, None))(key_envs, key_envs.shape[1])\n # TODO: move extra reset logic to the AutoResetWrapper.\n env_state = reset_fn(key_envs) if num_resets_per_eval > 0 else env_state\n\n if process_id == 0:\n # Run evals.\n metrics = evaluator.run_evaluation(\n _unpmap(\n (training_state.normalizer_params, training_state.params.policy)),\n training_metrics)\n logging.info(metrics)\n progress_fn(current_step, metrics)\n params = _unpmap(\n (training_state.normalizer_params, training_state.params.policy,\n training_state.params.value))\n\n # Save params if this is the max score\n eval_score = metrics['eval/episode_reward']\n if eval_score > max_score:\n max_score = eval_score\n max_score_params = {\n \"score\": max_score,\n \"params\": params,\n }\n policy_params_fn(current_step, make_policy, params)\n\n total_steps = current_step\n assert total_steps >= num_timesteps\n\n # If there was no mistakes the training_state should still be identical on all\n # devices.\n pmap.assert_is_replicated(training_state)\n params = _unpmap(\n (training_state.normalizer_params, training_state.params.policy,\n training_state.params.value))\n logging.info('total steps: %s', total_steps)\n pmap.synchronize_hosts()\n return (make_policy, params, metrics, max_score_params)\n","metadata":{"cellView":"form","id":"fxmLdUcPUMSD","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Define the PPO networks (adapted from https://github.com/google/brax)\n\n@flax.struct.dataclass\nclass PPONetworks:\n policy_network: networks.FeedForwardNetwork\n value_network: networks.FeedForwardNetwork\n parametric_action_distribution: distribution.ParametricDistribution\n\ndef make_ppo_networks(\n observation_size: int,\n action_size: int,\n preprocess_observations_fn: types.PreprocessObservationFn = types\n .identity_observation_preprocessor,\n policy_hidden_layer_sizes: Sequence[int] = policy_hidden_layer_sizes,\n value_hidden_layer_sizes: Sequence[int] = value_hidden_layer_sizes,\n activation: networks.ActivationFn = linen.swish) -> PPONetworks:\n \"\"\"Make PPO networks with preprocessor.\"\"\"\n parametric_action_distribution = distribution.NormalTanhDistribution(\n event_size=action_size)\n policy_network = networks.make_policy_network(\n parametric_action_distribution.param_size,\n observation_size,\n preprocess_observations_fn=preprocess_observations_fn,\n hidden_layer_sizes=policy_hidden_layer_sizes,\n activation=activation)\n value_network = networks.make_value_network(\n observation_size,\n preprocess_observations_fn=preprocess_observations_fn,\n hidden_layer_sizes=value_hidden_layer_sizes,\n activation=activation)\n\n return PPONetworks(\n policy_network=policy_network,\n value_network=value_network,\n parametric_action_distribution=parametric_action_distribution)","metadata":{"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 7 - Training the Humanoid\n\nTraining for 30m timesteps with 5 evals takes about 12min with two T4 GPUs. 
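(A hedged aside, not an original cell.) Before committing to a long run, you can check the section 2 hyperparameters against the divisibility constraints that `train` asserts, and see how much experience each training step consumes, using the same formula `train` computes internally:

```python
# Sanity-check the section 2 config against the asserts inside train() above.
num_envs, batch_size, num_minibatches = 2048, 1024, 32
unroll_length, action_repeat = 10, 1
device_count = 2  # two T4 GPUs on the recommended Kaggle runtime

assert num_envs % device_count == 0                    # whole number of envs per chip
assert (batch_size * num_minibatches) % num_envs == 0  # rollouts split evenly into minibatches

# Same formula as env_step_per_training_step in train():
print(batch_size * unroll_length * num_minibatches * action_repeat)  # 327680
```

With these values the 30m-timestep run works out to roughly 92 training steps, plus the one-off jit compilation.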
But the humanoid won't learn to do more than 1-2 bounces.\n\nLearning to do better (~3 bounces on average and max of 5) is possible though - read section 10 on optimising the reward function.","metadata":{}},{"cell_type":"code","source":"# Load params to restart training from a saved checkpoint\n# (i.e. from the saved policy and value neural networks' weights)\nupload_model = False\nif upload_model:\n saved_params = model.load_params(save_path)\nelse:\n saved_params = None","metadata":{"id":"vgeiw_vNjwcq","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Train\ntrain_fn = functools.partial(\n train, num_timesteps=num_timesteps, num_evals=num_evals,\n episode_length=episode_length, normalize_observations=normalize_observations,\n action_repeat=action_repeat, unroll_length=unroll_length, num_minibatches=num_minibatches,\n num_updates_per_batch=num_updates_per_batch, discounting=discounting,\n learning_rate=learning_rate, entropy_cost=entropy_cost, num_envs=num_envs,\n reward_scaling=reward_scaling, clipping_epsilon=clipping_epsilon, gae_lambda=gae_lambda,\n normalize_advantage=normalize_advantage, batch_size=batch_size, seed=0,\n network_factory=make_ppo_networks, saved_params=saved_params)\n\nx_data = []\ny_data = []\nydataerr = []\ntimes = [datetime.now()]\n\nmax_y, min_y = 15000, 0\ndef progress(num_steps, metrics):\n times.append(datetime.now())\n x_data.append(num_steps)\n y_data.append(metrics['eval/episode_reward'])\n ydataerr.append(metrics['eval/episode_reward_std'])\n\n plt.xlim([0, train_fn.keywords['num_timesteps'] * 1.1])\n plt.ylim([min_y, max_y])\n\n plt.xlabel('# environment steps')\n plt.ylabel('reward per episode')\n plt.title(f'y={y_data[-1]:.1f}')\n\n plt.errorbar(\n x_data, y_data, yerr=ydataerr)\n plt.show()\n\n if 'training/policy_loss' in metrics:\n print(\"Other metrics\") \n print(f\"entropy loss: {metrics['training/entropy_loss']:.2f}\")\n print(f\"value loss: {metrics['training/v_loss']:.2f}\")\n print(f\"max episode reward: {int(metrics['eval/episode_reward_max'])}\")\n print(f\"avg bounces: {metrics['eval/episode_bounces']:.2f}\")\n print(f\"max bounces: {metrics['eval/episode_bounces_max']}\\n\")\n\nmake_inference_fn, train_params, _, max_score_params = train_fn(\n environment=env, progress_fn=progress)\n\nprint(f'time to jit: {times[1] - times[0]}')\nprint(f'time to train: {times[-1] - times[1]}')\nprint(f'total time: {times[-1] - times[0]}\\n')\nprint(f\"max score: {int(max_score_params['score'])}\")","metadata":{"id":"xLiddQYPApBw","cellView":"form","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 8 - Save and load the policy\n\nWe can save and load the policy using the brax model API.","metadata":{"id":"YYIch0HEApBx"}},{"cell_type":"code","source":"# Save the model\nmodel.save_params(save_path, train_params)\n# model.save_params(save_path, max_score_params['params'])","metadata":{"id":"Z8gI6qH6ApBx","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Load the model and define the inference function\ninference_fn = make_inference_fn(model.load_params(save_path)[:2])\njit_inference_fn = jax.jit(inference_fn)","metadata":{"id":"h4reaWgxApBx","cellView":"form","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 9 - Visualize the policy\n\nFinally we can visualize the Humanoid in action and watch while it bounces the ball with its head!\n\nThis can also be saved to an mp4 file which you can then download from the 
`Output` section (can be found on the right if running in a laptop).","metadata":{"id":"0G357XIfApBy"}},{"cell_type":"code","source":"eval_env = envs.get_environment(env_name)\n\njit_reset = jax.jit(eval_env.reset)\njit_step = jax.jit(eval_env.step)","metadata":{"id":"osYasMw4ApBy","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Visualize the Humanoid and optionally save it to a mp4 file\n\ninit_time = datetime.now()\nrollouts = []\nmax_rollout_reward = 0\nmax_bounces = 0\nfor i in range(num_rollouts):\n # Initialize the state\n rng = jax.random.PRNGKey(i)\n state = jit_reset(rng)\n rollout = [state.pipeline_state]\n total_reward = 0\n total_bounces = 0\n\n # Grab a trajectory\n n_steps = 100000\n render_every = 2\n\n for i in range(n_steps):\n act_rng, rng = jax.random.split(rng)\n ctrl, _ = jit_inference_fn(state.obs, act_rng)\n state = jit_step(state, ctrl)\n total_reward += state.metrics[\"reward\"]\n total_bounces += state.metrics[\"bounces\"]\n rollout.append(state.pipeline_state)\n\n if state.done:\n break\n\n max_rollout_reward = max(max_rollout_reward, total_reward)\n max_bounces = max(max_bounces, total_bounces)\n \n if total_bounces > num_bounces_threshold:\n print(f\"Iteration with reward {int(total_reward)} and {int(total_bounces)} bounces\")\n video = env.render(rollout[::render_every], camera='side', height=480, width=640)\n media.show_video(video, fps=1.0 / env.dt / render_every)\n media.write_video(f\"/kaggle/working/ball_bounce_{int(total_reward)}_{int(total_bounces)}.mp4\", video, fps=1.0 / env.dt / render_every)\n \nprint(f\"Max rollout reward was - {int(max_rollout_reward)}\")\nprint(f\"Max bounces was - {int(max_bounces)}\")\nprint(f'total time: {datetime.now() - init_time}')","metadata":{"id":"d-UhypudApBy","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 10 - Optimise the reward function for faster learning\n\nThe current reward function is unlikely to enable the Humanoid to learn to bounce the ball more than 1-2 times. This is because it is a sparse reward function (i.e. the rewards are infrequent - no rewards for a long time and then, suddenly, a large reward). This makes learning hard.\n\nThere's an improved reward function very creatively named `get_reward_new` which enables faster learning. Try using it instead to see how far you can go (change the function called from `step` in section 4). Good luck!\n\nTraining for 120m timesteps with 9 evals takes about 43min with two T4 GPUs. That can be enough for it to learn to do 2 bounces on average and a max of 5 (although it can take longer in some training runs as the optimization is non-deterministic). Change `num_timesteps` and `num_evals` in section `2 - Config` accordingly.\n\nLearning to do better (~3 bounces on average) is possible in about 3 hours.\n\nAnd if you manage to go beyond 7 bounces, well done on the amazing work and please share!","metadata":{}},{"cell_type":"markdown","source":"## 11 - Optimise the Humanoid for faster learning\n\nIn addition to optimising the reward function (see section 10), another option is to adapt the Humanoid to the task at hand.\n\nChange `head_type` to `box` in the `2 - Config` section to for the Humanoid to have a cubic head instead and see how far you can go. Good luck!\n\nTraining for 120m timesteps with 9 evals takes about 43min with two T4 GPUs. 
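Relating to section 10 just above: instead of editing `step` in section 4 in place, one hedged, non-destructive option (not part of the original notebook) is to subclass `Humanoid`, route `step` through `_get_reward_new`, and register the variant under a new name:

```python
# Hypothetical variant of the section 4 environment that uses the denser
# _get_reward_new shaping; everything else matches Humanoid.step.
class HumanoidDenseReward(Humanoid):
    def step(self, state: State, action: jp.ndarray) -> State:
        data0 = state.pipeline_state
        data = self.pipeline_step(data0, action)
        reward, done = self._get_reward_new(state, action, data0, data)  # swapped in
        obs = self._get_obs(data, action)
        return state.replace(pipeline_state=data, obs=obs, reward=reward, done=done)

envs.register_environment("humanoid_dense", HumanoidDenseReward)
# env = envs.get_environment("humanoid_dense")  # then train exactly as in section 7
```

Whichever route you take, the budget to aim for is the 120m-timestep, 9-eval run quoted just above.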
That can be enough for it to learn to do 6 bounces on average and a max of 11 (although it can take longer in some training runs as the optimization is non-deterministic).\n\nLearning to do better (~15 bounces on average) is possible in about 2 hours.\n\nAnd if you manage to go beyond 30 bounces, well done on the amazing work and please share!","metadata":{}},{"cell_type":"markdown","source":"## 12 - Adjust reward function params for a different skill\n\nBy adjusting the reward function params, namely the `ball_reward_min_z` and `ball_reward_target_z`, it's possible for the Humanoid to learn how to balance the ball on its head instead of trying to bounce it.\n\nLook for the reward function on section 5 and try updating these params to see if you can do it. Good luck!","metadata":{}}]} -------------------------------------------------------------------------------- /tutorials/robot_tricks.ipynb: -------------------------------------------------------------------------------- 1 | {"metadata":{"colab":{"gpuClass":"premium","private_outputs":true,"provenance":[],"collapsed_sections":["YvyGCsgSCxHQ","P1K6IznI2y83"],"gpuType":"T4"},"kernelspec":{"name":"python3","display_name":"Python 3","language":"python"},"accelerator":"GPU","language_info":{"name":"python","version":"3.10.13","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"none","dataSources":[{"sourceId":8081927,"sourceType":"datasetVersion","datasetId":4770280}],"dockerImageVersionId":30698,"isInternetEnabled":true,"language":"python","sourceType":"notebook","isGpuEnabled":false}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"markdown","source":"# Tutorial: Robot Tricks\n\nIn this tutorial you will learn how to use the physics simulator MuJoCo, and Reinforcement Learning to teach a robot to bounce a ball with its foot - an essential skill in the art of football.","metadata":{"id":"Jz-I63-7YgRO"}},{"cell_type":"markdown","source":"## 1 - Kaggle\nIt's recommended to run this notebook on www.kaggle.com where you can use two T4 GPUs for 30h/week for free.\n\nTo import this notebook into Kaggle you need to:\n- Login to your Kaggle account\n- Create a new notebook\n- Click on `File` and then `Import Notebook`\n- Select the tab `GitHub`\n- Search `goncalog/ai-robotics`\n- Select the file `tutorials/robot_tricks.ipynb`\n- Click the `Import` button\n\nTo run this notebook you can either click the `Run All` button or run each cell individually by clicking the `Run current cell` button.","metadata":{}},{"cell_type":"markdown","source":"## 2 - Config\n\nThis is the configuration to run the tutorial, it includes:\n- Training hyperparameters\n- Mujoco environment variables\n- File paths\n- Rendering variables","metadata":{}},{"cell_type":"code","source":"num_timesteps = 60_000_000\nnum_evals = 9\n# num_envs: the number of parallel environments to use for rollouts\nnum_envs = 2048\n\n# learning_rate: learning rate for ppo loss\nlearning_rate = 3e-4\n# discounting: discounting rate\ndiscounting = 0.97\n# episode_length: the length of an environment episode\nepisode_length = 1000\n# normalize_observations: whether to normalize observations\nnormalize_observations = True\n# action_repeat: the number of timesteps to repeat an action\naction_repeat = 1\n# unroll_length: the number of timesteps to unroll in each environment.\n# The PPO loss is computed over `unroll_length` timesteps\nunroll_length = 10\n# entropy_cost: 
entropy reward for ppo loss, higher values increase entropy of the policy\nentropy_cost = 1e-3\n# batch_size: the batch size for each minibatch SGD step\nbatch_size = 1024\n# num_minibatches: the number of times to run the SGD step,\n# each with a different minibatch with leading dimension of `batch_size`\nnum_minibatches = 32\n# num_updates_per_batch: the number of times to run the gradient update over\n# all minibatches before doing a new environment rollout\nnum_updates_per_batch = 8\n# reward_scaling: float scaling for reward\nreward_scaling = 1\n# clipping_epsilon: clipping epsilon for PPO loss\nclipping_epsilon = 0.3\n# gae_lambda: General advantage estimation lambda\ngae_lambda = 0.95\n# normalize_advantage: whether to normalize advantage estimate\nnormalize_advantage = True\n\npolicy_hidden_layer_sizes = (32,) * 4\nvalue_hidden_layer_sizes = (256,) * 5\n\nball_size = 0.04\ntorso_index = 3 # index of torso body in mjx data (it contains the head geom)\nball_height = 0.5 # z coordinate of centre of mass\nball_x = 0.12 # x coordinate of centre of mass\nball_y = 0.05 # y coordinate of centre of mass\nfoot_left_index = 15 # index of foot_left body in mjx data\nop3_contacts = False\n\n# Simulation time step in seconds. \n# This is the single most important parameter affecting the speed-accuracy trade-off \n# which is inherent in every physics simulation. \n# Smaller values result in better accuracy and stability\nmj_model_timestep = 0.005\n\nsave_path = \"/kaggle/working/mjx_brax_nn\"\nop3_assets_path = \"/kaggle/input/assets-op3\"\n\nnum_rollouts = 1\nnum_bounces_threshold = 0","metadata":{"id":"-Xt8DyfJYrbd","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 3 - Install MuJoCo, MJX, and Brax","metadata":{"id":"YvyGCsgSCxHQ"}},{"cell_type":"code","source":"!pip install mujoco==3.1.2\n!pip install mujoco_mjx==3.1.2\n!pip install brax==0.10.0\n\n# Check if MuJoCo installation was successful\nimport distutils.util\nimport os\nimport subprocess\nif subprocess.run('nvidia-smi').returncode:\n raise RuntimeError(\n 'Cannot communicate with GPU. '\n 'Make sure you are using a GPU runtime. '\n 'Go to the Runtime menu and select Choose runtime type.')\n\n# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.\n# This is usually installed as part of an Nvidia driver package, but this\n# kernel doesn't install its driver via APT, and as a result the ICD is missing.\n# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)\nNVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/'\nNVIDIA_ICD_CONFIG_FILE = '10_nvidia.json'\nif not os.path.exists(NVIDIA_ICD_CONFIG_PATH):\n os.makedirs(NVIDIA_ICD_CONFIG_PATH)\n file_path = os.path.join(NVIDIA_ICD_CONFIG_PATH, NVIDIA_ICD_CONFIG_FILE)\n with open(file_path, 'w') as f:\n f.write(\"\"\"{\n \"file_format_version\" : \"1.0.0\",\n \"ICD\" : {\n \"library_path\" : \"libEGL_nvidia.so.0\"\n }\n}\n\"\"\")\n\n# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs\nxla_flags = os.environ.get('XLA_FLAGS', '')\nxla_flags += ' --xla_gpu_triton_gemm_any=True'\nos.environ['XLA_FLAGS'] = xla_flags\n\n# Configure MuJoCo to use the EGL rendering backend (requires GPU)\nprint('Setting environment variable to use GPU rendering:')\n%env MUJOCO_GL=egl\n\ntry:\n print('Checking that the installation succeeded:')\n import mujoco\n mujoco.MjModel.from_xml_string('')\nexcept Exception as e:\n raise e from RuntimeError(\n 'Something went wrong during installation. 
Check the shell output above '\n 'for more information.\\n'\n 'If using a hosted runtime, make sure you enable GPU acceleration '\n 'by going to the Runtime menu and selecting \"Choose runtime type\".')\n\nprint('Installation successful.')\n\n# Import packages for plotting and creating graphics\nimport time\nimport itertools\nimport numpy as np\nfrom typing import Callable, NamedTuple, Optional, Union, List\n\n# Graphics and plotting.\nprint('Installing mediapy:')\n!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)\n!pip install -q mediapy\nimport mediapy as media\nimport matplotlib.pyplot as plt\n\n# More legible printing from numpy.\nnp.set_printoptions(precision=3, suppress=True, linewidth=100)\n\n# Import MuJoCo, MJX, and Brax\n\nfrom datetime import datetime\nimport functools\nimport jax\nfrom jax import numpy as jp\nimport numpy as np\nfrom typing import Any, Dict, Sequence, Tuple, Union\n\nfrom brax import base\nfrom brax import envs\nfrom brax import math\nfrom brax.base import Base, Motion, Transform\nfrom brax.envs.base import Env, PipelineEnv, State\nfrom brax.mjx.base import State as MjxState\nfrom brax.io import html, mjcf, model\nfrom brax.training import distribution, networks\n\nfrom etils import epath\nfrom flax import linen, struct\nfrom matplotlib import pyplot as plt\nimport mediapy as media\nfrom ml_collections import config_dict\nimport mujoco\nfrom mujoco import mjx","metadata":{"id":"Xqo7pyX-n72M","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 4 - Setting up the OP3 environment with MJX\nMJX is an implementation of MuJoCo written in [JAX](https://jax.readthedocs.io/en/latest/index.html), enabling large batch training on GPU/TPU. In this notebook, we train RL policies with MJX.\n\nHere we implement our environment by adapting the original [OP3](https://github.com/google-deepmind/mujoco_menagerie/blob/main/robotis_op3/op3.xml) environment to also include a ball. Notice that `reset` initializes a `State`, and `step` steps through the physics step and reward logic. 
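The code sketch below (an illustration, not a cell from the original notebook) spells out that `reset`/`step` contract for any registered Brax/MJX environment: `reset` maps an RNG key to an initial `State`, and `step` maps a `State` and an action to the next `State`, with `reward`, `done`, and the `metrics` dict (including `bounces`) filled in by the reward logic.

```python
# Hedged sketch of a single reset/step transition; "op3" is registered a few
# cells below, after the XML and environment class are defined.
import jax
from jax import numpy as jp
from brax import envs

def one_transition(env_name: str = "op3"):
    env = envs.get_environment(env_name)
    state = jax.jit(env.reset)(jax.random.PRNGKey(0))  # reset -> initial State
    action = jp.zeros(env.sys.nu)                      # zero controls, just to illustrate
    next_state = jax.jit(env.step)(state, action)      # physics step + reward logic
    return next_state.reward, next_state.done, next_state.metrics["bounces"]
```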
The reward and stepping logic train the robot to bounce a ball with its left foot.\n\nNote: if this is the first time you're running this environment on Kaggle, you'll have to upload the OP3 [assets](https://github.com/google-deepmind/mujoco_menagerie/tree/main/robotis_op3/assets) into the `op3_assets_path` set in section `2 - Config` (you can do this by clicking the `Upload` button on the right-hand side and then `New Dataset`)","metadata":{"id":"RAv6WUVUm78k"}},{"cell_type":"code","source":"# OP3 XML\n\nball_material = \"\"\"\n \n \n \"\"\"\nball_default = f\"\"\"\n \n \n \n \"\"\"\nball_body = f\"\"\"\n \n \n \n \n \"\"\"\nif op3_contacts:\n body_collision = \"\"\"\n \n \n \n \n \n \"\"\"\n h1c_collision = ''\n h2_collision = \"\"\"\n \n \n \n \"\"\"\n la1c_collision = ''\n la2c_collision = ''\n la3c_collision = ''\n ra1c_collision = ''\n ra2c_collision = ''\n ra3c_collision = ''\n ll1c_collision = ''\n ll2c_collision = ''\n ll3c_collision = ''\n ll4c_collision = ''\n ll5c_collision = ''\n rl1c_collision = ''\n rl2c_collision = ''\n rl3c_collision = ''\n rl4c_collision = ''\n rl5c_collision = ''\n foot_collision = 'class=\"collision\"'\n \nelse:\n body_collision = \"\"\n h1c_collision = \"\"\n h2_collision = \"\"\n la1c_collision = \"\"\n la2c_collision = \"\"\n la3c_collision = \"\"\n ra1c_collision = \"\"\n ra2c_collision = \"\"\n ra3c_collision = \"\"\n ll1c_collision = \"\"\n ll2c_collision = \"\"\n ll3c_collision = \"\"\n ll4c_collision = \"\"\n ll5c_collision = \"\"\n rl1c_collision = \"\"\n rl2c_collision = \"\"\n rl3c_collision = \"\"\n rl4c_collision = \"\"\n rl5c_collision = \"\"\n foot_collision = 'class=\"no_collision\"'\n\nxml = f\"\"\"\n\n \n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n {ball_default}\n \n\n \n \n \n \n {ball_material}\n \n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n \n \n \n\n \n \n \n \n \n \n \n {body_collision}\n \n \n \n \n \n {h1c_collision}\n \n \n \n \n {h2_collision}\n \n \n \n \n \n \n \n {la1c_collision}\n \n \n \n \n {la2c_collision}\n \n \n \n \n {la3c_collision}\n \n \n \n \n \n \n \n {ra1c_collision}\n \n \n \n \n {ra2c_collision}\n \n \n \n \n {ra3c_collision}\n \n \n \n \n \n \n \n {ll1c_collision}\n \n \n \n \n {ll2c_collision}\n \n \n \n \n {ll3c_collision}\n \n \n \n \n {ll4c_collision}\n \n \n \n \n {ll5c_collision}\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n {rl1c_collision}\n \n \n \n \n {rl2c_collision}\n \n \n \n \n {rl3c_collision}\n \n \n \n \n {rl4c_collision}\n \n \n \n \n {rl5c_collision}\n \n \n \n \n \n \n \n \n \n \n \n \n \n {ball_body}\n \n\n \n \n \n \n \n \n \n \n \n \n \n \n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n\n\"\"\"","metadata":{"id":"G5RKmMLkUuC8","cellView":"form","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# OP3 Env\n\nclass OP3(PipelineEnv):\n\n def __init__(\n self,\n terminate_when_unhealthy=True,\n reset_noise_scale=1e-2,\n exclude_current_positions_from_observation=True,\n **kwargs,\n ):\n mj_model = mujoco.MjModel.from_xml_string(xml)\n mj_model.opt.solver = mujoco.mjtSolver.mjSOL_CG\n mj_model.opt.iterations = 6\n mj_model.opt.ls_iterations = 6\n\n sys = mjcf.load_model(mj_model)\n\n physics_steps_per_control_step = 5\n kwargs['n_frames'] = kwargs.get(\n 'n_frames', physics_steps_per_control_step)\n kwargs['backend'] = 'mjx'\n\n super().__init__(sys, **kwargs)\n\n self._terminate_when_unhealthy = terminate_when_unhealthy\n 
self._reset_noise_scale = reset_noise_scale\n self._exclude_current_positions_from_observation = (\n exclude_current_positions_from_observation\n )\n\n\n def reset(self, rng: jp.ndarray) -> State:\n \"\"\"Resets the environment to an initial state.\"\"\"\n rng, rng1, rng2 = jax.random.split(rng, 3)\n\n low, hi = -self._reset_noise_scale, self._reset_noise_scale\n qpos = self.sys.qpos0 + jax.random.uniform(\n rng1, (self.sys.nq,), minval=low, maxval=hi\n )\n qvel = jax.random.uniform(\n rng2, (self.sys.nv,), minval=low, maxval=hi\n )\n \n data = self.pipeline_init(qpos, qvel)\n\n obs = self._get_obs(data, jp.zeros(self.sys.nu))\n reward, done, zero = jp.zeros(3)\n metrics = {\n 'ball_reward': zero,\n 'reward_quadctrl': zero,\n 'reward_alive': zero,\n 'reward': zero,\n 'bounces': zero,\n }\n return State(data, obs, reward, done, metrics)\n\n\n def step(self, state: State, action: jp.ndarray) -> State:\n \"\"\"Runs one timestep of the environment's dynamics.\"\"\"\n data0 = state.pipeline_state\n data = self.pipeline_step(data0, action)\n \n reward, done = self._get_reward(state, action, data0, data)\n obs = self._get_obs(data, action)\n return state.replace(\n pipeline_state=data, obs=obs, reward=reward, done=done\n )\n\n\n def _get_obs(\n self, data: mjx.Data, action: jp.ndarray\n ) -> jp.ndarray:\n \"\"\"Observes robot body and ball position, velocities, and angles.\"\"\"\n position = data.qpos\n if self._exclude_current_positions_from_observation:\n position = position[2:]\n\n # external_contact_forces are excluded\n return jp.concatenate([\n # qpos: position / nq: number of generalized coordinates = dim(qpos)\n position,\n # qvel: velocity / nv: number of degrees of freedom = dim(qvel)\n data.qvel,\n # cinert: com-based body inertia and mass / (nbody, 10)\n data.cinert[1:].ravel(),\n # cvel: com-based velocity [3D rot; 3D tran] / (nbody, 6)\n data.cvel[1:].ravel(),\n # qfrc_actuator: actuator force / nv: number of degrees of freedom\n data.qfrc_actuator,\n ])\n\n\n def _get_reward(\n self, state: State, action: jp.ndarray, data0: mjx.Data, data: mjx.Data\n ) -> Tuple[jp.ndarray, jp.ndarray]:\n \"\"\"Apply reward func based on ball distance to normal of the left foot and target height.\"\"\"\n ctrl_cost_weight = 0.1\n healthy_reward = 5.0\n healthy_z_range = (0.4, 1.5)\n ball_reward = 5.0\n ball_healthy_z_range = (ball_size*2.1, 1.0)\n ball_reward_min_z = ball_size*2.1\n ball_reward_target_z = 0.5\n distance_feet_reward = 5.0\n distance_feet_max = 1.0\n bounce_threshold = ball_reward_target_z - 0.05 # z coordinate\n \n com_before_ball = data0.subtree_com[-1]\n com_after_ball = data.subtree_com[-1]\n com_after_foot = data.subtree_com[foot_left_index]\n distance_foot = jp.sqrt(jp.square(com_after_ball[0] - com_after_foot[0]) + jp.square(com_after_ball[1] - com_after_foot[1]))\n\n min_z, max_z = healthy_z_range\n is_healthy = jp.where(data.q[torso_index] < min_z, 0.0, 1.0)\n is_healthy = jp.where(data.q[torso_index] > max_z, 0.0, is_healthy)\n\n ball_min_z, ball_max_z = ball_healthy_z_range\n is_healthy = jp.where(com_after_ball[2] < ball_min_z, 0.0, is_healthy)\n is_healthy = jp.where(com_after_ball[2] > ball_max_z, 0.0, is_healthy)\n \n is_healthy = jp.where(distance_foot > distance_feet_max, 0.0, is_healthy)\n\n ctrl_cost = ctrl_cost_weight * jp.sum(jp.square(action))\n \n distance_target_height = jp.sqrt(jp.square(com_after_ball[2] - ball_reward_target_z))\n ball_reward = ball_reward * (1 - (distance_target_height / (ball_max_z - ball_reward_target_z)))\n is_ball_reward = 
jp.where(com_after_ball[2] >= ball_reward_min_z, 1.0, 0.0)\n \n distance_feet_reward = distance_feet_reward * (1 - (distance_foot / distance_feet_max))\n \n reward = ball_reward * is_ball_reward + healthy_reward - ctrl_cost + distance_feet_reward\n\n state.metrics.update(\n ball_reward=ball_reward * is_ball_reward,\n reward_quadctrl=-ctrl_cost,\n reward_alive=healthy_reward,\n reward=reward,\n bounces=self._is_bounce(com_before_ball, com_after_ball, bounce_threshold),\n )\n \n done = 1.0 - is_healthy\n return reward, done\n \n \n # There's a lot of room to improve this function as it should check for contacts\n # between the ball and the lower left limb of the robot\n # (at the time of implementation the contacts data wasn't easily accessible in MJX)\n def _is_bounce(\n self, com_before_ball: jp.ndarray, com_after_ball: jp.ndarray, bounce_threshold: jp.ndarray\n ) -> jp.ndarray:\n \"\"\"Check if ball bounced.\"\"\"\n is_bounce = jp.where(com_before_ball[2] < bounce_threshold, 1.0, 0.0)\n is_bounce = jp.where(com_after_ball[2] >= bounce_threshold, is_bounce, 0.0)\n return is_bounce\n \n\nenvs.register_environment(\"op3\", OP3)","metadata":{"id":"mtGMYNLE3QJN","cellView":"form","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 5 - Visualize a rollout\n\nLet's instantiate the environment and visualize a short rollout.\n\nNOTE: Since episodes terminate early if the torso is below the healthy z-range, the only relevant contacts for this task are between the feet and the plane, and the lower left limb and the ball. The other contacts weren't included. This also speeds up the training later on.","metadata":{"id":"P1K6IznI2y83"}},{"cell_type":"code","source":"# Instantiate the environment\nenv_name = \"op3\"\nenv = envs.get_environment(env_name)\n\n# Define the jit reset/step functions\njit_reset = jax.jit(env.reset)\njit_step = jax.jit(env.step)","metadata":{"id":"EhKLFK54C1CH","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Initialize the state\nstate = jit_reset(jax.random.PRNGKey(0))\nprint(f\"Observations size: {len(state.obs)}\")\nprint(f\"Actions size: {env.sys.nu}\")\n\nrollout = [state.pipeline_state]\n\n# Grab a trajectory\nfor i in range(50):\n # ctrl: control / nu: number of actuators/controls = dim(ctrl)\n ctrl = -0.1 * jp.ones(env.sys.nu)\n state = jit_step(state, ctrl)\n rollout.append(state.pipeline_state)\n\nmedia.show_video(env.render(rollout, camera='side', height=480, width=640), fps=1.0 / env.dt)","metadata":{"id":"Ph8u-v2Q2xLS","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 6 - Define the training functions\n\nLet's define the training functions using [PPO](https://openai.com/research/openai-baselines-ppo) to make the robot bounce the ball with its foot.","metadata":{"id":"BQDG6NQ1CbZD"}},{"cell_type":"code","source":"# Define the Acting/Evaluator (adapted from https://github.com/google/brax)\n\n\"\"\"Brax training acting functions.\"\"\"\n\nimport time\nfrom typing import Callable, Sequence, Tuple, Union\n\nfrom brax import envs\nfrom brax.training.types import Metrics\nfrom brax.training.types import Policy\nfrom brax.training.types import PolicyParams\nfrom brax.training.types import PRNGKey\nfrom brax.training.types import Transition\nfrom brax.v1 import envs as envs_v1\nimport jax\nimport numpy as np\n\nActingState = Union[envs.State, envs_v1.State]\nActingEnv = Union[envs.Env, envs_v1.Env, envs_v1.Wrapper]\n\n\ndef actor_step(\n env: ActingEnv,\n 
env_state: ActingState,\n policy: Policy,\n key: PRNGKey,\n extra_fields: Sequence[str] = ()\n) -> Tuple[ActingState, Transition]:\n \"\"\"Collect data.\"\"\"\n actions, policy_extras = policy(env_state.obs, key)\n nstate = env.step(env_state, actions)\n state_extras = {x: nstate.info[x] for x in extra_fields}\n return nstate, Transition( # pytype: disable=wrong-arg-types # jax-ndarray\n observation=env_state.obs,\n action=actions,\n reward=nstate.reward,\n discount=1 - nstate.done,\n next_observation=nstate.obs,\n extras={\n 'policy_extras': policy_extras,\n 'state_extras': state_extras\n })\n\n\ndef generate_unroll(\n env: ActingEnv,\n env_state: ActingState,\n policy: Policy,\n key: PRNGKey,\n unroll_length: int,\n extra_fields: Sequence[str] = ()\n) -> Tuple[ActingState, Transition]:\n \"\"\"Collect trajectories of given unroll_length.\"\"\"\n\n @jax.jit\n def f(carry, unused_t):\n state, current_key = carry\n current_key, next_key = jax.random.split(current_key)\n nstate, transition = actor_step(\n env, state, policy, current_key, extra_fields=extra_fields)\n return (nstate, next_key), transition\n\n (final_state, _), data = jax.lax.scan(\n f, (env_state, key), (), length=unroll_length)\n return final_state, data\n\n\n# TODO: Consider moving this to its own file.\nclass Evaluator:\n \"\"\"Class to run evaluations.\"\"\"\n\n def __init__(self, eval_env: envs.Env,\n eval_policy_fn: Callable[[PolicyParams],\n Policy], num_eval_envs: int,\n episode_length: int, action_repeat: int, key: PRNGKey):\n \"\"\"Init.\n\n Args:\n eval_env: Batched environment to run evals on.\n eval_policy_fn: Function returning the policy from the policy parameters.\n num_eval_envs: Each env will run 1 episode in parallel for each eval.\n episode_length: Maximum length of an episode.\n action_repeat: Number of physics steps per env step.\n key: RNG key.\n \"\"\"\n self._key = key\n self._eval_walltime = 0.\n\n eval_env = envs.training.EvalWrapper(eval_env)\n\n def generate_eval_unroll(policy_params: PolicyParams,\n key: PRNGKey) -> ActingState:\n reset_keys = jax.random.split(key, num_eval_envs)\n eval_first_state = eval_env.reset(reset_keys)\n return generate_unroll(\n eval_env,\n eval_first_state,\n eval_policy_fn(policy_params),\n key,\n unroll_length=episode_length // action_repeat)[0]\n\n self._generate_eval_unroll = jax.jit(generate_eval_unroll)\n self._steps_per_unroll = episode_length * num_eval_envs\n\n def run_evaluation(self,\n policy_params: PolicyParams,\n training_metrics: Metrics,\n aggregate_episodes: bool = True) -> Metrics:\n \"\"\"Run one epoch of evaluation.\"\"\"\n self._key, unroll_key = jax.random.split(self._key)\n\n t = time.time()\n eval_state = self._generate_eval_unroll(policy_params, unroll_key)\n eval_metrics = eval_state.info['eval_metrics']\n eval_metrics.active_episodes.block_until_ready()\n epoch_eval_time = time.time() - t\n metrics = {}\n for fn in [np.mean, np.std, np.max]:\n suffix = '_std' if fn == np.std else '_max' if fn == np.max else ''\n metrics.update(\n {\n f'eval/episode_{name}{suffix}': (\n fn(value) if aggregate_episodes else value\n )\n for name, value in eval_metrics.episode_metrics.items()\n }\n )\n metrics['eval/avg_episode_length'] = np.mean(eval_metrics.episode_steps)\n metrics['eval/epoch_eval_time'] = epoch_eval_time\n metrics['eval/sps'] = self._steps_per_unroll / epoch_eval_time\n self._eval_walltime = self._eval_walltime + epoch_eval_time\n metrics = {\n 'eval/walltime': self._eval_walltime,\n **training_metrics,\n **metrics\n }\n\n return metrics # 
pytype: disable=bad-return-type # jax-ndarray","metadata":{"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Define the Training Function (adapted from https://github.com/google/brax)\n\n\"\"\"Proximal policy optimization training.\n\nSee: https://arxiv.org/pdf/1707.06347.pdf\n\"\"\"\n\nimport functools\nimport time\nfrom typing import Callable, Optional, Tuple, Union\n\nfrom absl import logging\nfrom brax import base\nfrom brax import envs\nfrom brax.training import gradients\nfrom brax.training import pmap\nfrom brax.training import types\nfrom brax.training.acme import running_statistics\nfrom brax.training.acme import specs\nfrom brax.training.agents.ppo import losses as ppo_losses\nfrom brax.training.agents.ppo import networks as ppo_networks\nfrom brax.training.types import Params, PolicyParams, PreprocessorParams\nfrom brax.training.types import PRNGKey\nfrom brax.v1 import envs as envs_v1\nimport flax\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport optax\n\n\nInferenceParams = Tuple[running_statistics.NestedMeanStd, Params]\nMetrics = types.Metrics\nValueParams = Any\n\n_PMAP_AXIS_NAME = 'i'\n\n\n@flax.struct.dataclass\nclass TrainingState:\n \"\"\"Contains training state for the learner.\"\"\"\n optimizer_state: optax.OptState\n params: ppo_losses.PPONetworkParams\n normalizer_params: running_statistics.RunningStatisticsState\n env_steps: jnp.ndarray\n\n\ndef _unpmap(v):\n return jax.tree_util.tree_map(lambda x: x[0], v)\n\n\ndef _strip_weak_type(tree):\n # brax user code is sometimes ambiguous about weak_type. in order to\n # avoid extra jit recompilations we strip all weak types from user input\n def f(leaf):\n leaf = jnp.asarray(leaf)\n return leaf.astype(leaf.dtype)\n return jax.tree_util.tree_map(f, tree)\n\n\ndef train(\n environment: Union[envs_v1.Env, envs.Env],\n num_timesteps: int,\n episode_length: int,\n action_repeat: int = 1,\n num_envs: int = 1,\n max_devices_per_host: Optional[int] = None,\n num_eval_envs: int = 128,\n learning_rate: float = 1e-4,\n entropy_cost: float = 1e-4,\n discounting: float = 0.9,\n seed: int = 0,\n unroll_length: int = 10,\n batch_size: int = 32,\n num_minibatches: int = 16,\n num_updates_per_batch: int = 2,\n num_evals: int = 1,\n num_resets_per_eval: int = 0,\n normalize_observations: bool = False,\n reward_scaling: float = 1.0,\n clipping_epsilon: float = 0.3,\n gae_lambda: float = 0.95,\n deterministic_eval: bool = False,\n network_factory: types.NetworkFactory[\n ppo_networks.PPONetworks\n ] = ppo_networks.make_ppo_networks,\n progress_fn: Callable[[int, Metrics], None] = lambda *args: None,\n normalize_advantage: bool = True,\n eval_env: Optional[envs.Env] = None,\n policy_params_fn: Callable[..., None] = lambda *args: None,\n randomization_fn: Optional[\n Callable[[base.System, jnp.ndarray], Tuple[base.System, base.System]]\n ] = None,\n saved_params: Optional[\n Tuple[PreprocessorParams, PolicyParams, ValueParams]\n ] = None,\n):\n \"\"\"PPO training.\n\n Args:\n environment: the environment to train\n num_timesteps: the total number of environment steps to use during training\n episode_length: the length of an environment episode\n action_repeat: the number of timesteps to repeat an action\n num_envs: the number of parallel environments to use for rollouts\n NOTE: `num_envs` must be divisible by the total number of chips since each\n chip gets `num_envs // total_number_of_chips` environments to roll out\n NOTE: `batch_size * num_minibatches` must be divisible by `num_envs` since\n 
data generated by `num_envs` parallel envs gets used for gradient\n updates over `num_minibatches` of data, where each minibatch has a\n leading dimension of `batch_size`\n max_devices_per_host: maximum number of chips to use per host process\n num_eval_envs: the number of envs to use for evaluation. Each env will run 1\n episode, and all envs run in parallel during eval.\n learning_rate: learning rate for ppo loss\n entropy_cost: entropy reward for ppo loss, higher values increase entropy\n of the policy\n discounting: discounting rate\n seed: random seed\n unroll_length: the number of timesteps to unroll in each environment. The\n PPO loss is computed over `unroll_length` timesteps\n batch_size: the batch size for each minibatch SGD step\n num_minibatches: the number of times to run the SGD step, each with a\n different minibatch with leading dimension of `batch_size`\n num_updates_per_batch: the number of times to run the gradient update over\n all minibatches before doing a new environment rollout\n num_evals: the number of evals to run during the entire training run.\n Increasing the number of evals increases total training time\n num_resets_per_eval: the number of environment resets to run between each\n eval. The environment resets occur on the host\n normalize_observations: whether to normalize observations\n reward_scaling: float scaling for reward\n clipping_epsilon: clipping epsilon for PPO loss\n gae_lambda: General advantage estimation lambda\n deterministic_eval: whether to run the eval with a deterministic policy\n network_factory: function that generates networks for policy and value\n functions\n progress_fn: a user-defined callback function for reporting/plotting metrics\n normalize_advantage: whether to normalize advantage estimate\n eval_env: an optional environment for eval only, defaults to `environment`\n policy_params_fn: a user-defined callback function that can be used for\n saving policy checkpoints\n randomization_fn: a user-defined callback function that generates randomized\n environments\n saved_params: params to init the training with; includes normalizer_params\n and policy and value network params\n\n Returns:\n Tuple of (make_policy function, network params, metrics)\n \"\"\"\n assert batch_size * num_minibatches % num_envs == 0\n xt = time.time()\n\n process_count = jax.process_count()\n process_id = jax.process_index()\n local_device_count = jax.local_device_count()\n local_devices_to_use = local_device_count\n if max_devices_per_host:\n local_devices_to_use = min(local_devices_to_use, max_devices_per_host)\n logging.info(\n 'Device count: %d, process count: %d (id %d), local device count: %d, '\n 'devices to be used count: %d', jax.device_count(), process_count,\n process_id, local_device_count, local_devices_to_use)\n device_count = local_devices_to_use * process_count\n\n # The number of environment steps executed for every training step.\n env_step_per_training_step = (\n batch_size * unroll_length * num_minibatches * action_repeat)\n num_evals_after_init = max(num_evals - 1, 1)\n # The number of training_step calls per training_epoch call.\n # equals to ceil(num_timesteps / (num_evals * env_step_per_training_step *\n # num_resets_per_eval))\n num_training_steps_per_epoch = np.ceil(\n num_timesteps\n / (\n num_evals_after_init\n * env_step_per_training_step\n * max(num_resets_per_eval, 1)\n )\n ).astype(int)\n\n key = jax.random.PRNGKey(seed)\n global_key, local_key = jax.random.split(key)\n del key\n local_key = jax.random.fold_in(local_key, 
process_id)\n local_key, key_env, eval_key = jax.random.split(local_key, 3)\n # key_networks should be global, so that networks are initialized the same\n # way for different processes.\n key_policy, key_value = jax.random.split(global_key)\n del global_key\n\n assert num_envs % device_count == 0\n\n v_randomization_fn = None\n if randomization_fn is not None:\n randomization_batch_size = num_envs // local_device_count\n # all devices gets the same randomization rng\n randomization_rng = jax.random.split(key_env, randomization_batch_size)\n v_randomization_fn = functools.partial(\n randomization_fn, rng=randomization_rng\n )\n\n if isinstance(environment, envs.Env):\n wrap_for_training = envs.training.wrap\n else:\n wrap_for_training = envs_v1.wrappers.wrap_for_training\n\n env = wrap_for_training(\n environment,\n episode_length=episode_length,\n action_repeat=action_repeat,\n randomization_fn=v_randomization_fn,\n )\n\n reset_fn = jax.jit(jax.vmap(env.reset))\n key_envs = jax.random.split(key_env, num_envs // process_count)\n key_envs = jnp.reshape(key_envs,\n (local_devices_to_use, -1) + key_envs.shape[1:])\n env_state = reset_fn(key_envs)\n\n normalize = lambda x, y: x\n if normalize_observations:\n normalize = running_statistics.normalize\n ppo_network = network_factory(\n env_state.obs.shape[-1],\n env.action_size,\n preprocess_observations_fn=normalize)\n make_policy = ppo_networks.make_inference_fn(ppo_network)\n\n optimizer = optax.adam(learning_rate=learning_rate)\n\n loss_fn = functools.partial(\n ppo_losses.compute_ppo_loss,\n ppo_network=ppo_network,\n entropy_cost=entropy_cost,\n discounting=discounting,\n reward_scaling=reward_scaling,\n gae_lambda=gae_lambda,\n clipping_epsilon=clipping_epsilon,\n normalize_advantage=normalize_advantage)\n\n gradient_update_fn = gradients.gradient_update_fn(\n loss_fn, optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True)\n\n def minibatch_step(\n carry, data: types.Transition,\n normalizer_params: running_statistics.RunningStatisticsState):\n optimizer_state, params, key = carry\n key, key_loss = jax.random.split(key)\n (_, metrics), params, optimizer_state = gradient_update_fn(\n params,\n normalizer_params,\n data,\n key_loss,\n optimizer_state=optimizer_state)\n\n return (optimizer_state, params, key), metrics\n\n def sgd_step(carry, unused_t, data: types.Transition,\n normalizer_params: running_statistics.RunningStatisticsState):\n optimizer_state, params, key = carry\n key, key_perm, key_grad = jax.random.split(key, 3)\n\n def convert_data(x: jnp.ndarray):\n x = jax.random.permutation(key_perm, x)\n x = jnp.reshape(x, (num_minibatches, -1) + x.shape[1:])\n return x\n\n shuffled_data = jax.tree_util.tree_map(convert_data, data)\n (optimizer_state, params, _), metrics = jax.lax.scan(\n functools.partial(minibatch_step, normalizer_params=normalizer_params),\n (optimizer_state, params, key_grad),\n shuffled_data,\n length=num_minibatches)\n return (optimizer_state, params, key), metrics\n\n def training_step(\n carry: Tuple[TrainingState, envs.State, PRNGKey],\n unused_t) -> Tuple[Tuple[TrainingState, envs.State, PRNGKey], Metrics]:\n training_state, state, key = carry\n key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3)\n\n policy = make_policy(\n (training_state.normalizer_params, training_state.params.policy))\n\n def f(carry, unused_t):\n current_state, current_key = carry\n current_key, next_key = jax.random.split(current_key)\n next_state, data = generate_unroll(\n env,\n current_state,\n policy,\n current_key,\n 
unroll_length,\n extra_fields=('truncation',))\n return (next_state, next_key), data\n\n (state, _), data = jax.lax.scan(\n f, (state, key_generate_unroll), (),\n length=batch_size * num_minibatches // num_envs)\n # Have leading dimensions (batch_size * num_minibatches, unroll_length)\n data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 1, 2), data)\n data = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]),\n data)\n assert data.discount.shape[1:] == (unroll_length,)\n\n # Update normalization params and normalize observations.\n normalizer_params = running_statistics.update(\n training_state.normalizer_params,\n data.observation,\n pmap_axis_name=_PMAP_AXIS_NAME)\n\n (optimizer_state, params, _), metrics = jax.lax.scan(\n functools.partial(\n sgd_step, data=data, normalizer_params=normalizer_params),\n (training_state.optimizer_state, training_state.params, key_sgd), (),\n length=num_updates_per_batch)\n\n new_training_state = TrainingState(\n optimizer_state=optimizer_state,\n params=params,\n normalizer_params=normalizer_params,\n env_steps=training_state.env_steps + env_step_per_training_step)\n return (new_training_state, state, new_key), metrics\n\n def training_epoch(training_state: TrainingState, state: envs.State,\n key: PRNGKey) -> Tuple[TrainingState, envs.State, Metrics]:\n (training_state, state, _), loss_metrics = jax.lax.scan(\n training_step, (training_state, state, key), (),\n length=num_training_steps_per_epoch)\n loss_metrics = jax.tree_util.tree_map(jnp.mean, loss_metrics)\n return training_state, state, loss_metrics\n\n training_epoch = jax.pmap(training_epoch, axis_name=_PMAP_AXIS_NAME)\n\n # Note that this is NOT a pure jittable method.\n def training_epoch_with_timing(\n training_state: TrainingState, env_state: envs.State,\n key: PRNGKey) -> Tuple[TrainingState, envs.State, Metrics]:\n nonlocal training_walltime\n t = time.time()\n training_state, env_state = _strip_weak_type((training_state, env_state))\n result = training_epoch(training_state, env_state, key)\n training_state, env_state, metrics = _strip_weak_type(result)\n\n metrics = jax.tree_util.tree_map(jnp.mean, metrics)\n jax.tree_util.tree_map(lambda x: x.block_until_ready(), metrics)\n\n epoch_training_time = time.time() - t\n training_walltime += epoch_training_time\n sps = (num_training_steps_per_epoch *\n env_step_per_training_step *\n max(num_resets_per_eval, 1)) / epoch_training_time\n metrics = {\n 'training/sps': sps,\n 'training/walltime': training_walltime,\n **{f'training/{name}': value for name, value in metrics.items()}\n }\n return training_state, env_state, metrics # pytype: disable=bad-return-type # py311-upgrade\n\n\n if saved_params is None:\n init_params = ppo_losses.PPONetworkParams(\n policy=ppo_network.policy_network.init(key_policy),\n value=ppo_network.value_network.init(key_value))\n normalizer_params = running_statistics.init_state(\n specs.Array(env_state.obs.shape[-1:], jnp.dtype('float32')))\n else:\n init_params = ppo_losses.PPONetworkParams(\n policy=saved_params[1],\n value=saved_params[2])\n normalizer_params = saved_params[0]\n\n training_state = TrainingState( # pytype: disable=wrong-arg-types # jax-ndarray\n optimizer_state=optimizer.init(init_params), # pytype: disable=wrong-arg-types # numpy-scalars\n params=init_params,\n normalizer_params=normalizer_params,\n env_steps=0)\n training_state = jax.device_put_replicated(\n training_state,\n jax.local_devices()[:local_devices_to_use])\n\n if not eval_env:\n eval_env = environment\n if 
randomization_fn is not None:\n v_randomization_fn = functools.partial(\n randomization_fn, rng=jax.random.split(eval_key, num_eval_envs)\n )\n eval_env = wrap_for_training(\n eval_env,\n episode_length=episode_length,\n action_repeat=action_repeat,\n randomization_fn=v_randomization_fn,\n )\n\n evaluator = Evaluator(\n eval_env,\n functools.partial(make_policy, deterministic=deterministic_eval),\n num_eval_envs=num_eval_envs,\n episode_length=episode_length,\n action_repeat=action_repeat,\n key=eval_key)\n\n # Run initial eval\n metrics = {}\n if process_id == 0 and num_evals > 1:\n metrics = evaluator.run_evaluation(\n _unpmap(\n (training_state.normalizer_params, training_state.params.policy)),\n training_metrics={})\n logging.info(metrics)\n progress_fn(0, metrics)\n\n training_metrics = {}\n training_walltime = 0\n current_step = 0\n # Initialize variables to allow saving params of run with max score\n max_score = 0\n max_score_params = {}\n for it in range(num_evals_after_init):\n logging.info('starting iteration %s %s', it, time.time() - xt)\n\n for _ in range(max(num_resets_per_eval, 1)):\n # optimization\n epoch_key, local_key = jax.random.split(local_key)\n epoch_keys = jax.random.split(epoch_key, local_devices_to_use)\n (training_state, env_state, training_metrics) = (\n training_epoch_with_timing(training_state, env_state, epoch_keys)\n )\n current_step = int(_unpmap(training_state.env_steps))\n\n key_envs = jax.vmap(\n lambda x, s: jax.random.split(x[0], s),\n in_axes=(0, None))(key_envs, key_envs.shape[1])\n # TODO: move extra reset logic to the AutoResetWrapper.\n env_state = reset_fn(key_envs) if num_resets_per_eval > 0 else env_state\n\n if process_id == 0:\n # Run evals.\n metrics = evaluator.run_evaluation(\n _unpmap(\n (training_state.normalizer_params, training_state.params.policy)),\n training_metrics)\n logging.info(metrics)\n progress_fn(current_step, metrics)\n params = _unpmap(\n (training_state.normalizer_params, training_state.params.policy,\n training_state.params.value))\n\n # Save params if this is the max score\n eval_score = metrics['eval/episode_reward']\n if eval_score > max_score:\n max_score = eval_score\n max_score_params = {\n \"score\": max_score,\n \"params\": params,\n }\n policy_params_fn(current_step, make_policy, params)\n\n total_steps = current_step\n assert total_steps >= num_timesteps\n\n # If there was no mistakes the training_state should still be identical on all\n # devices.\n pmap.assert_is_replicated(training_state)\n params = _unpmap(\n (training_state.normalizer_params, training_state.params.policy,\n training_state.params.value))\n logging.info('total steps: %s', total_steps)\n pmap.synchronize_hosts()\n return (make_policy, params, metrics, max_score_params)\n","metadata":{"cellView":"form","id":"fxmLdUcPUMSD","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Define the PPO networks (adapted from https://github.com/google/brax)\n\n@flax.struct.dataclass\nclass PPONetworks:\n policy_network: networks.FeedForwardNetwork\n value_network: networks.FeedForwardNetwork\n parametric_action_distribution: distribution.ParametricDistribution\n\ndef make_ppo_networks(\n observation_size: int,\n action_size: int,\n preprocess_observations_fn: types.PreprocessObservationFn = types\n .identity_observation_preprocessor,\n policy_hidden_layer_sizes: Sequence[int] = policy_hidden_layer_sizes,\n value_hidden_layer_sizes: Sequence[int] = value_hidden_layer_sizes,\n activation: networks.ActivationFn = linen.swish) -> 
PPONetworks:\n \"\"\"Make PPO networks with preprocessor.\"\"\"\n parametric_action_distribution = distribution.NormalTanhDistribution(\n event_size=action_size)\n policy_network = networks.make_policy_network(\n parametric_action_distribution.param_size,\n observation_size,\n preprocess_observations_fn=preprocess_observations_fn,\n hidden_layer_sizes=policy_hidden_layer_sizes,\n activation=activation)\n value_network = networks.make_value_network(\n observation_size,\n preprocess_observations_fn=preprocess_observations_fn,\n hidden_layer_sizes=value_hidden_layer_sizes,\n activation=activation)\n\n return PPONetworks(\n policy_network=policy_network,\n value_network=value_network,\n parametric_action_distribution=parametric_action_distribution)","metadata":{"trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 7 - Training the robot\n\nTraining for 60m timesteps with 9 evals takes about 50min with two T4 GPUs. That can be enough for it to learn to do 2 bounces on average and a max of 3 (although it can take longer in some training runs as the optimization is non-deterministic).\n\nLearning to do better (~10 bounces on average) is possible in about 4 hours.","metadata":{}},{"cell_type":"code","source":"# Load params to restart training from a saved checkpoint\n# (i.e. from the saved policy and value neural networks' weights)\nupload_model = False\nif upload_model:\n saved_params = model.load_params(save_path)\nelse:\n saved_params = None","metadata":{"id":"vgeiw_vNjwcq","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Train\ntrain_fn = functools.partial(\n train, num_timesteps=num_timesteps, num_evals=num_evals,\n episode_length=episode_length, normalize_observations=normalize_observations,\n action_repeat=action_repeat, unroll_length=unroll_length, num_minibatches=num_minibatches,\n num_updates_per_batch=num_updates_per_batch, discounting=discounting,\n learning_rate=learning_rate, entropy_cost=entropy_cost, num_envs=num_envs,\n reward_scaling=reward_scaling, clipping_epsilon=clipping_epsilon, gae_lambda=gae_lambda,\n normalize_advantage=normalize_advantage, batch_size=batch_size, seed=0,\n network_factory=make_ppo_networks, saved_params=saved_params)\n\nx_data = []\ny_data = []\nydataerr = []\ntimes = [datetime.now()]\n\nmax_y, min_y = 15000, 0\ndef progress(num_steps, metrics):\n times.append(datetime.now())\n x_data.append(num_steps)\n y_data.append(metrics['eval/episode_reward'])\n ydataerr.append(metrics['eval/episode_reward_std'])\n\n plt.xlim([0, train_fn.keywords['num_timesteps'] * 1.1])\n plt.ylim([min_y, max_y])\n\n plt.xlabel('# environment steps')\n plt.ylabel('reward per episode')\n plt.title(f'y={y_data[-1]:.1f}')\n\n plt.errorbar(\n x_data, y_data, yerr=ydataerr)\n plt.show()\n\n if 'training/policy_loss' in metrics:\n print(\"Other metrics\") \n print(f\"entropy loss: {metrics['training/entropy_loss']:.2f}\")\n print(f\"value loss: {metrics['training/v_loss']:.2f}\")\n print(f\"max episode reward: {int(metrics['eval/episode_reward_max'])}\")\n print(f\"avg bounces: {metrics['eval/episode_bounces']:.2f}\")\n print(f\"max bounces: {metrics['eval/episode_bounces_max']}\\n\")\n\nmake_inference_fn, train_params, _, max_score_params = train_fn(\n environment=env, progress_fn=progress)\n\nprint(f'time to jit: {times[1] - times[0]}')\nprint(f'time to train: {times[-1] - times[1]}')\nprint(f'total time: {times[-1] - times[0]}\\n')\nprint(f\"max score: 
{int(max_score_params['score'])}\")","metadata":{"id":"xLiddQYPApBw","cellView":"form","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 8 - Save and load the policy\n\nWe can save and load the policy using the brax model API.","metadata":{"id":"YYIch0HEApBx"}},{"cell_type":"code","source":"# Save the model\nmodel.save_params(save_path, train_params)\n# model.save_params(save_path, max_score_params['params'])","metadata":{"id":"Z8gI6qH6ApBx","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Load the model and define the inference function\ninference_fn = make_inference_fn(model.load_params(save_path)[:2])\njit_inference_fn = jax.jit(inference_fn)","metadata":{"id":"h4reaWgxApBx","cellView":"form","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":"## 9 - Visualize the policy\n\nFinally we can visualize the robot in action and watch while it bounces the ball with its foot!\n\nThis can also be saved to an mp4 file which you can then download from the `Output` section (can be found on the right if running in a laptop).","metadata":{"id":"0G357XIfApBy"}},{"cell_type":"code","source":"eval_env = envs.get_environment(env_name)\n\njit_reset = jax.jit(eval_env.reset)\njit_step = jax.jit(eval_env.step)","metadata":{"id":"osYasMw4ApBy","trusted":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":"# Visualize the robot and optionally save it to a mp4 file\n\ninit_time = datetime.now()\nrollouts = []\nmax_rollout_reward = 0\nmax_bounces = 0\nfor i in range(num_rollouts):\n # Initialize the state\n rng = jax.random.PRNGKey(i)\n state = jit_reset(rng)\n rollout = [state.pipeline_state]\n total_reward = 0\n total_bounces = 0\n\n # Grab a trajectory\n n_steps = 100000\n render_every = 2\n\n for i in range(n_steps):\n act_rng, rng = jax.random.split(rng)\n ctrl, _ = jit_inference_fn(state.obs, act_rng)\n state = jit_step(state, ctrl)\n total_reward += state.metrics[\"reward\"]\n total_bounces += state.metrics[\"bounces\"]\n rollout.append(state.pipeline_state)\n\n if state.done:\n break\n\n max_rollout_reward = max(max_rollout_reward, total_reward)\n max_bounces = max(max_bounces, total_bounces)\n \n if total_bounces > num_bounces_threshold:\n print(f\"Iteration with reward {int(total_reward)} and {int(total_bounces)} bounces\")\n video = env.render(rollout[::render_every], camera='side', height=480, width=640)\n media.show_video(video, fps=1.0 / env.dt / render_every)\n media.write_video(f\"/kaggle/working/ball_bounce_{int(total_reward)}_{int(total_bounces)}.mp4\", video, fps=1.0 / env.dt / render_every)\n \nprint(f\"Max rollout reward was - {int(max_rollout_reward)}\")\nprint(f\"Max bounces was - {int(max_bounces)}\")\nprint(f'total time: {datetime.now() - init_time}')","metadata":{"id":"d-UhypudApBy","trusted":true},"execution_count":null,"outputs":[]}]} --------------------------------------------------------------------------------
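
The `_is_bounce` heuristic in the OP3 environment above only thresholds the ball's height, and the notebook itself notes it should ideally check for contacts between the ball and the robot's lower left limb. The following is an illustrative sketch of such a contact-based check, not part of the tutorial code: it assumes an MJX version that exposes `data.contact` with `geom1`, `geom2`, and `dist` arrays (field names and padding behaviour can differ across MuJoCo/MJX releases), and the geom ids used are placeholders.

```python
import jax.numpy as jp
from mujoco import mjx

# Placeholder geom ids (assumptions, not taken from the tutorial XML).
# In practice look them up once from the compiled model, e.g.
#   mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_GEOM, "<geom name>")
BALL_GEOM_ID = 30        # id of the ball geom
LEFT_FOOT_GEOM_ID = 25   # id of the left-foot collision geom

def is_foot_ball_contact(data: mjx.Data) -> jp.ndarray:
    """Return 1.0 if any active contact pairs the ball with the left-foot geom."""
    g1, g2 = data.contact.geom1, data.contact.geom2
    # Treat a contact as active when its distance is non-positive (penetration);
    # padded/inactive contact slots typically carry a large positive distance.
    active = data.contact.dist <= 0.0
    pair = ((g1 == BALL_GEOM_ID) & (g2 == LEFT_FOOT_GEOM_ID)) | (
        (g1 == LEFT_FOOT_GEOM_ID) & (g2 == BALL_GEOM_ID)
    )
    return jp.where(jp.any(active & pair), 1.0, 0.0)
```

If it holds up in your MJX version, a flag like this could be combined with the existing height threshold inside `_get_reward` so that the `bounces` metric only increments when the ball actually touches the left foot, at the cost of a little extra contact bookkeeping per step.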