├── .github └── workflows │ ├── ci.yml │ └── pypi.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── bin └── learn ├── brax ├── __init__.py ├── actuator.py ├── actuator_test.py ├── base.py ├── base_test.py ├── com.py ├── com_test.py ├── contact.py ├── contact_test.py ├── envs │ ├── __init__.py │ ├── ant.py │ ├── assets │ │ ├── ant.xml │ │ ├── half_cheetah.xml │ │ ├── hopper.xml │ │ ├── humanoid.xml │ │ ├── humanoidstandup.xml │ │ ├── inverted_double_pendulum.xml │ │ ├── inverted_pendulum.xml │ │ ├── pusher.xml │ │ ├── reacher.xml │ │ ├── swimmer.xml │ │ └── walker2d.xml │ ├── base.py │ ├── env_test.py │ ├── fast.py │ ├── half_cheetah.py │ ├── hopper.py │ ├── humanoid.py │ ├── humanoidstandup.py │ ├── inverted_double_pendulum.py │ ├── inverted_pendulum.py │ ├── pusher.py │ ├── reacher.py │ ├── swimmer.py │ ├── walker2d.py │ └── wrappers │ │ ├── __init__.py │ │ ├── dm_env.py │ │ ├── dm_env_test.py │ │ ├── gym.py │ │ ├── gym_test.py │ │ ├── torch.py │ │ ├── training.py │ │ └── training_test.py ├── experimental │ ├── __init__.py │ └── barkour │ │ ├── README.md │ │ ├── __init__.py │ │ ├── assets │ │ ├── joystick.gif │ │ ├── joystick_real.gif │ │ └── joystick_v0.gif │ │ ├── data │ │ ├── barkour_run_0.csv │ │ ├── barkour_run_1.csv │ │ └── barkour_run_2.csv │ │ ├── score_barkour.py │ │ └── tutorial.ipynb ├── fluid.py ├── fluid_test.py ├── generalized │ ├── __init__.py │ ├── base.py │ ├── constraint.py │ ├── constraint_test.py │ ├── dynamics.py │ ├── dynamics_test.py │ ├── integrator.py │ ├── mass.py │ ├── mass_test.py │ ├── perf_test.py │ ├── pipeline.py │ └── pipeline_test.py ├── io │ ├── __init__.py │ ├── html.py │ ├── image.py │ ├── json.py │ ├── json_test.py │ ├── metrics.py │ ├── mjcf.py │ ├── mjcf_test.py │ ├── model.py │ └── torch.py ├── kinematics.py ├── kinematics_test.py ├── math.py ├── math_test.py ├── mjx │ ├── __init__.py │ ├── base.py │ ├── perf_test.py │ ├── pipeline.py │ └── pipeline_test.py ├── positional │ ├── __init__.py │ ├── base.py │ ├── 
collisions.py │ ├── integrator.py │ ├── joints.py │ ├── joints_test.py │ ├── perf_test.py │ ├── pipeline.py │ └── pipeline_test.py ├── scan.py ├── scan_test.py ├── spring │ ├── __init__.py │ ├── base.py │ ├── collisions.py │ ├── integrator.py │ ├── joints.py │ ├── joints_test.py │ ├── perf_test.py │ ├── pipeline.py │ └── pipeline_test.py ├── test_data │ ├── capsule.xml │ ├── colour_objects.xml │ ├── convex_convex.xml │ ├── double_pendulum.xml │ ├── double_prismatic.xml │ ├── fluid_box.xml │ ├── fluid_box_offset_com.xml │ ├── fluid_ellipsoid.xml │ ├── fluid_sphere.xml │ ├── fluid_two_spheres.xml │ ├── fluid_wind.xml │ ├── meshes │ │ ├── cylinder.stl │ │ ├── dodecahedron.stl │ │ ├── pyramid.stl │ │ └── tetrahedron.stl │ ├── nonzero_joint_ref.xml │ ├── prismaversal_2dof_joint.xml │ ├── prismaversal_3dof_joint.xml │ ├── single_pendulum.xml │ ├── single_pendulum_motor.xml │ ├── single_pendulum_position.xml │ ├── single_pendulum_position_frclimit.xml │ ├── single_pendulum_velocity.xml │ ├── single_prismatic.xml │ ├── single_spherical_pendulum.xml │ ├── single_spherical_pendulum_motor.xml │ ├── single_spherical_pendulum_position.xml │ ├── single_universal_pendulum.xml │ ├── solver_params_v2.xml │ ├── triple_pendulum.xml │ ├── triple_pendulum_motor.xml │ ├── triple_prismatic.xml │ ├── world_body_transform.xml │ ├── world_fromto.xml │ └── world_self_collision.xml ├── test_utils.py ├── training │ ├── __init__.py │ ├── acme │ │ ├── __init__.py │ │ ├── running_statistics.py │ │ ├── specs.py │ │ └── types.py │ ├── acting.py │ ├── agents │ │ ├── __init__.py │ │ ├── apg │ │ │ ├── __init__.py │ │ │ ├── networks.py │ │ │ ├── train.py │ │ │ └── train_test.py │ │ ├── ars │ │ │ ├── __init__.py │ │ │ ├── networks.py │ │ │ ├── train.py │ │ │ └── train_test.py │ │ ├── bc │ │ │ ├── __init__.py │ │ │ ├── checkpoint.py │ │ │ ├── losses.py │ │ │ ├── networks.py │ │ │ ├── train.py │ │ │ └── train_test.py │ │ ├── es │ │ │ ├── __init__.py │ │ │ ├── networks.py │ │ │ ├── train.py │ │ │ └── 
train_test.py │ │ ├── ppo │ │ │ ├── __init__.py │ │ │ ├── checkpoint.py │ │ │ ├── checkpoint_test.py │ │ │ ├── losses.py │ │ │ ├── networks.py │ │ │ ├── networks_vision.py │ │ │ ├── train.py │ │ │ └── train_test.py │ │ └── sac │ │ │ ├── __init__.py │ │ │ ├── checkpoint.py │ │ │ ├── checkpoint_test.py │ │ │ ├── losses.py │ │ │ ├── networks.py │ │ │ ├── train.py │ │ │ └── train_test.py │ ├── checkpoint.py │ ├── distribution.py │ ├── gradients.py │ ├── learner.py │ ├── logger.py │ ├── networks.py │ ├── pmap.py │ ├── replay_buffers.py │ ├── replay_buffers_test.py │ ├── spectral_norm.py │ ├── types.py │ └── types_test.py └── visualizer │ ├── favicon.ico │ ├── index.html │ ├── js │ ├── animator.js │ ├── selector.js │ ├── system.js │ └── viewer.js │ └── visualizer.py ├── datasets ├── README.md ├── ppo_10_million_steps.tar.gz ├── ppo_500_million_steps.tar.gz └── sac_5_million_steps.tar.gz ├── docs ├── code_of_conduct.md ├── contributing.md ├── img │ ├── a1.gif │ ├── ant.gif │ ├── ant_v2.gif │ ├── brax_logo.gif │ ├── braxlines │ │ ├── ant_diayn.png │ │ ├── ant_diayn_skill1.gif │ │ ├── ant_diayn_skill2.gif │ │ ├── ant_diayn_skill3.gif │ │ ├── ant_diayn_skill4.gif │ │ ├── ant_smm.gif │ │ ├── ant_smm.png │ │ ├── humanoid_diayn.png │ │ ├── humanoid_diayn_skill1.gif │ │ ├── humanoid_diayn_skill2.gif │ │ ├── humanoid_diayn_skill3.gif │ │ ├── humanoid_smm.gif │ │ ├── humanoid_smm.png │ │ └── sketches.png │ ├── composer │ │ ├── ant_chase.gif │ │ ├── ant_push.gif │ │ ├── pro_ant1.gif │ │ └── pro_ant2.gif │ ├── fetch.gif │ ├── grasp.gif │ ├── halfcheetah.gif │ ├── humanoid.gif │ ├── humanoid_v2.gif │ └── ur5e.gif └── release-notes │ ├── next-release.md │ ├── v0.0.11.md │ ├── v0.0.12.md │ ├── v0.0.13.md │ ├── v0.0.14.md │ ├── v0.0.15.md │ ├── v0.0.16.md │ ├── v0.1.0.md │ ├── v0.1.1.md │ ├── v0.1.2.md │ ├── v0.10.0.md │ ├── v0.10.1.md │ ├── v0.10.2.md │ ├── v0.10.3.md │ ├── v0.10.4.md │ ├── v0.10.5.md │ ├── v0.11.0.md │ ├── v0.12.0.md │ ├── v0.12.1.md │ ├── v0.12.2.md │ ├── v0.12.3.md 
│ ├── v0.9.0.md │ ├── v0.9.1.md │ ├── v0.9.2.md │ ├── v0.9.3.md │ └── v0.9.4.md ├── notebooks ├── basics.ipynb ├── training.ipynb └── training_torch.ipynb └── pyproject.toml /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | # Triggers the workflow on push or pull request events but only for the main branch 5 | push: 6 | branches: [ main ] 7 | pull_request: 8 | branches: [ main ] 9 | 10 | jobs: 11 | build: 12 | runs-on: ubuntu-latest 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | python-version: ['3.10', '3.11'] 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v1 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install -e .[develop] 27 | - name: Run tests 28 | run: pytest --ignore=brax/v1 29 | -------------------------------------------------------------------------------- /.github/workflows/pypi.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: pypi 5 | 6 | on: 7 | push: 8 | tags: "v*" 9 | workflow_dispatch: 10 | 11 | jobs: 12 | deploy: 13 | runs-on: ubuntu-latest 14 | permissions: 15 | contents: write # Needed to create releases 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python 20 | uses: actions/setup-python@v1 21 | with: 22 | python-version: "3.x" 23 | - name: Create Release 24 | id: create_release 25 | uses: actions/create-release@latest 26 | env: 27 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # This token is provided by Actions, you do not need to create 
your own token 28 | with: 29 | tag_name: ${{ github.ref }} 30 | release_name: ${{ github.ref }} 31 | body_path: docs/release-notes/${{ github.ref_name }}.md 32 | draft: false 33 | prerelease: false 34 | 35 | - name: Install dependencies 36 | run: | 37 | pip install uv 38 | uv pip install --system -e ".[dev]" 39 | uv pip install --system build twine 40 | - name: Build and publish 41 | env: 42 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 43 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 44 | run: | 45 | python -m build 46 | twine upload dist/* 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .ipynb_checkpoints 3 | node_modules 4 | /pip_test 5 | /_python_build 6 | *.pyc 7 | __pycache__ 8 | *.swp 9 | .vscode/ 10 | .idea/ 11 | *.egg-info 12 | 13 | ####### UNITY STUFF 14 | 15 | */[Ll]ibrary/ 16 | */[Tt]emp/ 17 | */[Oo]bj/ 18 | [Bb]uild/ 19 | [Bb]uilds/ 20 | */[Bb]uild/ 21 | */[Bb]uilds/ 22 | */[Ll]ogs/ 23 | */[Aa]ssets/AssetStoreTools* 24 | */[Ii]mport/ 25 | */[Aa]ssets/AssetBundles* 26 | */[Aa]ssets/StreamingAssets/AssetBundles* 27 | 28 | # Visual Studio cache directory 29 | .vs/ 30 | 31 | # Gradle cache directory 32 | .gradle/ 33 | 34 | # Autogenerated VS/MD/Consulo solution and project files 35 | ExportedObj/ 36 | .consulo/ 37 | *.csproj 38 | *.unityproj 39 | *.sln 40 | *.suo 41 | *.tmp 42 | *.user 43 | *.userprefs 44 | *.pidb 45 | *.booproj 46 | *.svd 47 | *.pdb 48 | *.opendb 49 | *.VC.db 50 | 51 | # Unity3D generated meta files 52 | *.pidb.meta 53 | *.pdb.meta 54 | 55 | # Unity3D generated file on crash reports 56 | sysinfo.txt 57 | 58 | # Builds 59 | *.apk 60 | *.unitypackage 61 | 62 | UnitySDK.log -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | prune docs 2 | prune notebooks 3 | prune 
datasets 4 | 5 | include brax/envs/assets/*.xml 6 | recursive-include brax/experimental/barkour *.csv *.stl *.xml 7 | recursive-include brax/test_data *.xml *.stl *.obj *.urdf 8 | recursive-include brax/visualizer * 9 | -------------------------------------------------------------------------------- /bin/learn: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from absl import app 4 | from brax.training import learner 5 | 6 | if __name__ == '__main__': 7 | app.run(learner.main) 8 | -------------------------------------------------------------------------------- /brax/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Brax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Import top-level classes and functions here for encapsulation/clarity.""" 16 | 17 | __version__ = '0.12.3' 18 | 19 | from brax.base import Motion 20 | from brax.base import State 21 | from brax.base import System 22 | from brax.base import Transform 23 | -------------------------------------------------------------------------------- /brax/actuator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Brax Authors. 
def to_tau(
    sys: System, act: jax.Array, q: jax.Array, qd: jax.Array
) -> jax.Array:
  """Turns actuator inputs into a vector of joint forces.

  The input is clipped to the actuator control range, scaled by the actuator
  gain, offset by a position/velocity dependent bias, clipped to the actuator
  force range and finally scaled by the gear before being accumulated onto
  the actuated degrees of freedom.

  Args:
    sys: system defining the kinematic tree and other properties
    act: (act_size,) actuator force input vector
    q: joint position vector
    qd: joint velocity vector

  Returns:
    tau: (qd_size,) vector of joint forces
  """
  tau = jp.zeros(sys.qd_size())
  if sys.act_size() == 0:
    # Nothing is actuated: every joint force is zero.
    return tau

  actuator = sys.actuator
  ctrl_lo, ctrl_hi = actuator.ctrl_range[:, 0], actuator.ctrl_range[:, 1]
  force_lo, force_hi = actuator.force_range[:, 0], actuator.force_range[:, 1]

  # Keep only the coordinates that each actuator reads from.
  q_act = q[actuator.q_id]
  qd_act = qd[actuator.qd_id]
  ctrl = jp.clip(act, ctrl_lo, ctrl_hi)

  # See https://github.com/deepmind/mujoco/discussions/754 for why gear is
  # used for the bias term.
  bias = actuator.gear * (q_act * actuator.bias_q + qd_act * actuator.bias_qd)

  force = jp.clip(actuator.gain * ctrl + bias, force_lo, force_hi)

  # `.add` (rather than `.set`) accumulates the contributions of actuators
  # that share a degree of freedom.
  return tau.at[actuator.qd_id].add(force * actuator.gear)
class BaseTest(absltest.TestCase):
  """Checks that `tree_replace` writes new values into (nested) fields."""

  def test_write_mass_array(self):
    """Replaces the link masses with an array of per-link values."""
    sys = test_utils.load_fixture('ant.xml')
    rng = jax.random.PRNGKey(0)
    noise = jax.random.uniform(rng, (sys.link.inertia.mass.shape[0],))
    new_mass = sys.link.inertia.mass + noise

    sys_w = sys.tree_replace({'link.inertia.mass': new_mass})
    np.testing.assert_array_equal(sys_w.link.inertia.mass, new_mass)

  def test_write_mass_value(self):
    """Replaces the link masses with a single scalar value."""
    sys = test_utils.load_fixture('ant.xml')
    sys_w = sys.tree_replace({'link.inertia.mass': 1.0})
    self.assertEqual(sys_w.link.inertia.mass, 1.0)

  def test_write_array(self):
    """Replaces a top-level array field with random values of equal shape."""
    sys = test_utils.load_fixture('ant.xml')
    np.random.seed(0)

    # The shape has to be passed as `size=`: the first positional parameter of
    # np.random.uniform is `low`, so passing the shape positionally would use
    # it as the lower bound and return an array of the wrong shape and range.
    expected = np.random.uniform(size=sys.elasticity.shape)
    sys_w = sys.tree_replace({'elasticity': expected})
    np.testing.assert_array_equal(sys_w.elasticity, expected)
def from_world(
    sys: System, x: Transform, xd: Motion
) -> Tuple[Transform, Motion]:
  """Re-expresses link transform and motion (world frame) in the com frame."""
  # Offset from each link origin to its center of mass.
  com_offset = Transform.create(pos=sys.link.inertia.transform.pos)
  x_i = x.vmap().do(com_offset)
  # The motion is moved by the world-space displacement between both frames.
  displacement = Transform.create(pos=x_i.pos - x.pos)
  xd_i = displacement.vmap().do(xd)
  return x_i, xd_i


def to_world(
    sys: System, x_i: Transform, xd_i: Motion
) -> Tuple[Transform, Motion]:
  """Re-expresses link transform and motion (com frame) in the world frame."""
  # Same offset as in from_world, negated to undo it.
  link_offset = Transform.create(pos=-sys.link.inertia.transform.pos)
  x = x_i.vmap().do(link_offset)
  displacement = Transform.create(pos=x.pos - x_i.pos)
  xd = displacement.vmap().do(xd_i)
  return x, xd


def inv_inertia(sys, x) -> jax.Array:
  """Gets the inverse inertia at the center of mass in world frame.

  Args:
    sys: system; reads `link.inertia` and `spring_inertia_scale`
    x: link transforms; only `x.rot` is used

  Returns:
    one inverse inertia matrix per link
  """
  # Rotates every row of a matrix by the same quaternion.
  rotate_rows = jax.vmap(math.rotate, in_axes=[0, None])

  def per_link(link_inertia, x_rot):
    # Orientation of the inertia frame: link rotation composed with the
    # inertia transform's own rotation.
    ri = math.quat_mul(x_rot, link_inertia.transform.rot)
    # Only the diagonal of the inertia is used, raised to a power set by
    # sys.spring_inertia_scale.
    i_diag = jp.diagonal(link_inertia.i) ** (1 - sys.spring_inertia_scale)
    i_inv_mx = jp.diag(1 / i_diag)
    # Rotate the rows, then the rows of the transpose (i.e. the columns).
    rotated_rows = rotate_rows(i_inv_mx, ri)
    return rotate_rows(rotated_rows.T, ri)

  return jax.vmap(per_link)(sys.link.inertia, x.rot)
class ComTest(absltest.TestCase):
  """Tests for the frame-conversion and inertia helpers in `com`."""

  def test_transform(self):
    """Round-trips a link state through the com frame and back."""
    sys = test_utils.load_fixture('capsule.xml')
    # Give the center of mass a non-zero offset so that the conversion is
    # expected to change positions and linear velocities.
    com_transform = sys.link.inertia.transform.replace(
        pos=jp.array([[0.0, 0.0, -0.1]])
    )
    inertia = sys.link.inertia.replace(transform=com_transform)
    sys = sys.replace(link=sys.link.replace(inertia=inertia))

    x, xd = kinematics.forward(sys, sys.init_q, jp.zeros(sys.qd_size()))
    x = x.replace(
        pos=jp.array([[1.0, -1.0, 0.3]]),
        rot=jp.array([[0.976442, 0.16639178, -0.13593051, 0.01994515]]),
    )
    xd = xd.replace(
        vel=jp.array([[5.0, 1.0, -1.0]]), ang=jp.array([[1.0, 2.0, -3.0]])
    )

    x_i, xd_i = com.from_world(sys, x, xd)
    # Position and linear velocity must differ; rotation and angular velocity
    # must be untouched.
    self.assertNotAlmostEqual(jp.abs(x_i.pos - x.pos).sum(), 0)
    np.testing.assert_array_almost_equal(x_i.rot, x.rot)
    self.assertNotAlmostEqual(jp.abs(xd_i.vel - xd.vel).sum(), 0)
    np.testing.assert_array_almost_equal(xd_i.ang, xd.ang)

    # Converting back must reproduce the original world-frame state.
    x_back, xd_back = com.to_world(sys, x_i, xd_i)
    round_trip = (
        (x.pos, x_back.pos),
        (x.rot, x_back.rot),
        (xd.vel, xd_back.vel),
        (xd.ang, xd_back.ang),
    )
    for original, recovered in round_trip:
      np.testing.assert_array_almost_equal(original, recovered)

  def test_inv_inertia(self):
    """Compares com.inv_inertia against a directly inverted inertia."""
    sys = test_utils.load_fixture('capsule.xml')
    link_rot = math.euler_to_quat(jp.array([0.0, 0.0, 45.0])).reshape(1, -1)
    link_transform = sys.link.transform.replace(rot=link_rot)
    sys = sys.replace(link=sys.link.replace(transform=link_transform))

    x, _ = kinematics.forward(sys, sys.init_q, jp.zeros(sys.qd_size()))
    world_rot = math.euler_to_quat(jp.array([45.0, 0.0, 0.0])).reshape(1, -1)
    x = x.replace(rot=world_rot)

    # get the expected inv inertia
    x_i = x.vmap().do(sys.link.inertia.transform)
    rotation_only = x_i.replace(pos=jp.zeros_like(x_i.pos))
    cinr = rotation_only.vmap().do(sys.link.inertia)
    expected = jax.vmap(math.inv_3x3)(cinr.i)

    inv_i = com.inv_inertia(sys, x)
    # NOTE(review): the third positional argument of
    # assert_array_almost_equal is `decimal`, not a tolerance, so 1e-6 makes
    # this comparison very loose. Kept as-is here; confirm the intended
    # precision before tightening.
    np.testing.assert_array_almost_equal(expected, inv_i, 1e-6)
def get(sys: System, x: Transform) -> Optional[Contact]:
  """Computes contacts between the system's geometries.

  Args:
    sys: system defining the kinematic tree and other properties
    x: link transforms in world frame

  Returns:
    Contact pytree, or None when the mjx data created for `sys` reports
    `ncon == 0`
  """
  data = mjx.make_data(sys)
  if data.ncon == 0:
    return None

  def to_world_frame(parent_pos, parent_quat, local_pos, local_quat):
    # Compose a parent pose with a pose expressed relative to that parent.
    world_pos = parent_pos + math.rotate(local_pos, parent_quat)
    world_mat = math.quat_to_3x3(math.quat_mul(parent_quat, local_quat))
    return world_pos, world_mat

  # A geom whose body id is 0 gets index -1 below and therefore picks up the
  # extra transform appended here (presumably the world body - TODO confirm).
  x = x.concatenate(Transform.zero((1,)))
  geom_link = sys.geom_bodyid - 1
  geom_xpos, geom_xmat = jax.vmap(to_world_frame)(
      x.pos[geom_link], x.rot[geom_link], sys.geom_pos, sys.geom_quat
  )

  # pytype: disable=wrong-arg-types
  data = data.replace(geom_xpos=geom_xpos, geom_xmat=geom_xmat)
  data = mjx.collision(sys, data)
  # pytype: enable=wrong-arg-types

  c = data.contact
  # Contact elasticity is the mean of the two colliding geoms' elasticity.
  elasticity = (sys.elasticity[c.geom1] + sys.elasticity[c.geom2]) * 0.5

  # Translate geom indices into the (body id - 1) link indexing used above.
  geom_bodyid = jp.array(sys.geom_bodyid)
  link_idx = (geom_bodyid[c.geom1] - 1, geom_bodyid[c.geom2] - 1)

  return Contact(elasticity=elasticity, link_idx=link_idx, **c.__dict__)
/brax/envs/assets/inverted_pendulum.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /brax/envs/assets/reacher.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 55 | -------------------------------------------------------------------------------- /brax/envs/assets/swimmer.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 40 | -------------------------------------------------------------------------------- /brax/envs/env_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Brax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# Expected benchmark result per backend and environment name; used as the
# lower bound (with 1% slack) in EnvTest.testSpeed.
_EXPECTED_SPS = {'spring': {'ant': 1000, 'humanoid': 1000}}


class EnvTest(parameterized.TestCase):
  """Benchmarks env stepping against the values in _EXPECTED_SPS."""

  # One (backend, env_name, expected_sps) tuple per entry of _EXPECTED_SPS.
  params = [
      (backend, env_name, expected)
      for backend, per_env in _EXPECTED_SPS.items()
      for env_name, expected in per_env.items()
  ]

  @parameterized.parameters(params)
  def testSpeed(self, backend, env_name, expected_sps):
    """Fails when the benchmark result drops below 99% of expected_sps."""
    batch_size = 128
    # Longer episodes are only used when the expected value is high.
    episode_length = 1000 if expected_sps >= 10_000 else 100

    env = envs.create(
        env_name,
        backend=backend,
        episode_length=episode_length,
        auto_reset=True,
    )
    # Step with an all-zero action of the env's action size.
    step_fn = functools.partial(env.step, action=jp.zeros(env.action_size))

    mean_sps = test_utils.benchmark(
        f'{backend}_{env_name}',
        env.reset,
        step_fn,
        batch_size=batch_size,
        length=episode_length,
    )
    self.assertGreater(mean_sps, expected_sps * 0.99)
class DmEnvTest(absltest.TestCase):
  """Tests for the dm_env wrapper around brax envs."""

  def test_action_space(self):
    """Tests the action space of the DmEnvWrapper."""
    base_env = envs.create('pusher')
    env = dm_env.DmEnvWrapper(base_env)
    # The action spec bounds must equal the two columns of the actuator
    # control range.
    ctrl_range = base_env.sys.actuator.ctrl_range
    np.testing.assert_array_equal(env.action_spec().minimum, ctrl_range[:, 0])
    np.testing.assert_array_equal(env.action_spec().maximum, ctrl_range[:, 1])
class GymTest(absltest.TestCase):
  """Tests for the gym wrappers around brax envs."""

  def test_action_space(self):
    """Tests the action space of the GymWrapper."""
    base_env = envs.create('pusher')
    env = gym.GymWrapper(base_env)
    # low/high must equal the two columns of the actuator control range.
    ctrl_range = base_env.sys.actuator.ctrl_range
    np.testing.assert_array_equal(env.action_space.low, ctrl_range[:, 0])
    np.testing.assert_array_equal(env.action_space.high, ctrl_range[:, 1])

  def test_vector_action_space(self):
    """Tests the action space of the VectorGymWrapper."""
    batch_size = 256
    base_env = envs.create('pusher')
    batched_env = training.VmapWrapper(base_env, batch_size=batch_size)
    env = gym.VectorGymWrapper(batched_env)
    # The batched bounds are the single-env bounds repeated once per env.
    ctrl_range = base_env.sys.actuator.ctrl_range
    np.testing.assert_array_equal(
        env.action_space.low,
        np.tile(ctrl_range[:, 0], [batch_size, 1]),
    )
    np.testing.assert_array_equal(
        env.action_space.high,
        np.tile(ctrl_range[:, 1], [batch_size, 1]),
    )
-------------------------------------------------------------------------------- 1 | # Copyright 2024 The Brax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Wrapper around a Brax GymWrapper, that converts outputs to PyTorch tensors. 16 | 17 | This conversion happens directly on-device, without moving values to the CPU. 18 | """ 19 | 20 | from typing import Optional 21 | 22 | # NOTE: The following line will emit a warning and raise ImportError if `torch` 23 | # isn't available. 
class TorchWrapper(gym.Wrapper):
  """Gym wrapper that hands PyTorch tensors to the caller instead of Jax ones."""

  def __init__(self, env: gym.Env, device: Optional[torch.Device] = None):
    """Wraps a gym Env so that its outputs are converted to PyTorch tensors.

    Args:
      env: the gym environment to wrap
      device: forwarded as `device` to every `torch.jax_to_torch` call
    """
    super().__init__(env)
    self.device = device

  def reset(self):
    """Resets the wrapped env and returns its observation via jax_to_torch."""
    return torch.jax_to_torch(super().reset(), device=self.device)

  def step(self, action):
    """Steps the wrapped env with a PyTorch action.

    The action is passed through `torch.torch_to_jax` before the wrapped step;
    obs, reward, done and info are each passed through `torch.jax_to_torch`.

    Returns:
      The (obs, reward, done, info) tuple after conversion.
    """
    obs, reward, done, info = super().step(torch.torch_to_jax(action))
    return tuple(
        torch.jax_to_torch(value, device=self.device)
        for value in (obs, reward, done, info)
    )
class TrainingTest(absltest.TestCase):
  """Tests for the wrappers applied by `training.wrap`."""

  def test_domain_randomization_wrapper(self):
    """Identical resets must diverge after one step of randomized systems."""

    def rand(sys, rng):
      def sample_pos(key):
        offset = jax.random.uniform(key, shape=(3,), minval=-0.1, maxval=0.1)
        return sys.link.transform.pos.at[0].set(offset)

      # one sampled position array per key in `rng`
      batched_pos = jax.vmap(sample_pos)(rng)
      sys_v = sys.tree_replace({'link.inertia.transform.pos': batched_pos})
      # only the replaced leaf is mapped over axis 0, every other leaf is None
      in_axes = jax.tree.map(lambda x: None, sys)
      in_axes = in_axes.tree_replace({'link.inertia.transform.pos': 0})
      return sys_v, in_axes

    env = training.wrap(
        envs.create('ant'),
        episode_length=200,
        randomization_fn=functools.partial(
            rand, rng=jax.random.split(jax.random.PRNGKey(0), 256)
        ),
    )

    # set the same key across the batch for env.reset so that only the
    # randomization wrapper creates variability in the env.step
    reset_keys = jp.zeros((256, 2), dtype=jp.uint32)
    state = jax.jit(env.reset)(reset_keys)
    first_q = state.pipeline_state.q[:, 0]
    self.assertEqual(first_q.shape[0], 256)
    self.assertEqual(np.unique(first_q).shape[0], 1)

    # test that the DomainRandomizationWrapper creates variability in env.step
    zero_action = jp.zeros((256, env.sys.act_size()))
    state = jax.jit(env.step)(state, zero_action)
    first_q = state.pipeline_state.q[:, 0]
    self.assertEqual(first_q.shape[0], 256)
    self.assertEqual(np.unique(first_q).shape[0], 256)


if __name__ == '__main__':
  absltest.main()
Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /brax/experimental/barkour/README.md: -------------------------------------------------------------------------------- 1 | # Google Barkour 2 | 3 | ## Overview 4 | 5 | This repository contains a script to evaluate policies on the [Barkour](https://github.com/google-deepmind/mujoco_menagerie/tree/main/google_barkour_v0#google-barkour-v0-description-mjcf) environment benchmark. See `score_barkour.py` for additional details on how policies get evaluated, or see the publication on [arxiv](https://arxiv.org/abs/2305.14654). 6 | 7 | ## Training Joystick Policies 8 | 9 | For instructions on how to train a joystick policy on the [Barkour v0](https://github.com/google-deepmind/mujoco_menagerie/tree/main/google_barkour_v0) quadruped on GPU/TPU with Brax, see the [tutorial notebook](https://colab.research.google.com/github/google/brax/blob/main/brax/experimental/barkour/tutorial.ipynb). The environment is defined in `BarkourEnv`, and joystick policies train in 6 minutes on an A100 GPU. 10 | 11 |

12 | 13 |

14 | 15 | 16 | For [Barkour vB](https://github.com/google-deepmind/mujoco_menagerie/tree/main/google_barkour_vb) quadruped, see the [MJX tutorial notebook](https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb). These joystick policies successfully transfer onto the quadruped robot. 17 | 18 |

19 | 20 | 21 |

22 | 23 | 24 | ## MJCF Instructions 25 | 26 | For robot assets, please see [Barkour vB](https://github.com/google-deepmind/mujoco_menagerie/tree/main/google_barkour_vb) and [Barkour v0](https://github.com/google-deepmind/mujoco_menagerie/tree/main/google_barkour_v0) in the MuJoCo Menagerie repository. The Barkour environment can be found [here](https://github.com/google-deepmind/mujoco_menagerie/blob/main/google_barkour_v0/scene_barkour.xml). 27 | 28 | 29 | ## Publications 30 | 31 | If you use this work in an academic context, please cite the following publication: 32 | 33 | @misc{caluwaerts2023barkour, 34 | title={Barkour: Benchmarking Animal-level Agility with Quadruped Robots}, 35 | author={Ken Caluwaerts and Atil Iscen and J. Chase Kew and Wenhao Yu and Tingnan Zhang and Daniel Freeman and Kuang-Huei Lee and Lisa Lee and Stefano Saliceti and Vincent Zhuang and Nathan Batchelor and Steven Bohez and Federico Casarini and Jose Enrique Chen and Omar Cortes and Erwin Coumans and Adil Dostmohamed and Gabriel Dulac-Arnold and Alejandro Escontrela and Erik Frey and Roland Hafner and Deepali Jain and Bauyrjan Jyenis and Yuheng Kuang and Edward Lee and Linda Luu and Ofir Nachum and Ken Oslund and Jason Powell and Diego Reyes and Francesco Romano and Feresteh Sadeghi and Ron Sloat and Baruch Tabanpour and Daniel Zheng and Michael Neunert and Raia Hadsell and Nicolas Heess and Francesco Nori and Jeff Seto and Carolina Parada and Vikas Sindhwani and Vincent Vanhoucke and Jie Tan}, 36 | year={2023}, 37 | eprint={2305.14654}, 38 | archivePrefix={arXiv}, 39 | primaryClass={cs.RO} 40 | } -------------------------------------------------------------------------------- /brax/experimental/barkour/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /brax/experimental/barkour/assets/joystick.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/brax/experimental/barkour/assets/joystick.gif -------------------------------------------------------------------------------- /brax/experimental/barkour/assets/joystick_real.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/brax/experimental/barkour/assets/joystick_real.gif -------------------------------------------------------------------------------- /brax/experimental/barkour/assets/joystick_v0.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/brax/experimental/barkour/assets/joystick_v0.gif -------------------------------------------------------------------------------- /brax/fluid.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Brax Authors. 
def _box_viscosity(box: jax.Array, xd_i: Motion, viscosity: jax.Array) -> Force:
  """Gets force due to motion through a viscous fluid.

  Args:
    box: per-row box dimensions; averaged over the last axis to one diameter
      per row
    xd_i: motion whose `ang` and `vel` parts are scaled row by row
    viscosity: fluid viscosity multiplier

  Returns:
    Force opposing the motion (both scale factors are negative).
  """
  diam = jp.mean(box, axis=-1)
  ang_scale = -jp.pi * diam**3 * viscosity
  vel_scale = -3.0 * jp.pi * diam * viscosity
  # [:, None] broadcasts the one-per-row scale over the vector components
  frc = Force(
      ang=ang_scale[:, None] * xd_i.ang, vel=vel_scale[:, None] * xd_i.vel
  )
  return frc


def _box_density(box: jax.Array, xd_i: Motion, density: jax.Array) -> Force:
  """Gets force due to motion through dense fluid.

  Args:
    box: per-row box dimensions; each row is indexed as b[0], b[1], b[2]
    xd_i: motion, vmapped row by row together with `box`
    density: fluid density multiplier

  Returns:
    Force that is quadratic in velocity (`abs(v) * v`) and opposes the motion.
  """

  @jax.vmap
  def apply(b: jax.Array, xd: Motion) -> Force:
    # product of the two box dimensions other than the axis itself
    box_mult_vel = jp.array([b[1] * b[2], b[0] * b[2], b[0] * b[1]])
    vel = -0.5 * density * box_mult_vel * jp.abs(xd.vel) * xd.vel
    box_mult_ang = jp.array([
        b[0] * (b[1] ** 4 + b[2] ** 4),
        b[1] * (b[0] ** 4 + b[2] ** 4),
        b[2] * (b[0] ** 4 + b[1] ** 4),
    ])
    ang = -1.0 * density * box_mult_ang * jp.abs(xd.ang) * xd.ang / 64.0
    return Force(vel=vel, ang=ang)

  return apply(box, xd_i)


def force(
    sys: System,
    x: Transform,
    xd: Motion,
    mass: jax.Array,
    inertia: jax.Array,
    root_com: Union[jax.Array, None] = None,
) -> Force:
  """Returns force due to motion through a fluid.

  Args:
    sys: system providing `link.inertia.transform`, `viscosity` and `density`
    x: link transforms
    xd: link motions
    mass: per-link mass, broadcast as `mass[:, None]`
    inertia: per-link inertia; `jp.diag` is vmapped over its leading axis
      (presumably (num_links, 3, 3) -- TODO confirm against callers)
    root_com: if given, offsets are taken relative to it instead of `x.pos`

  Returns:
    The sum of the viscous and density forces, rotated by `x_i.rot`.
  """
  # get the velocity at the com position/orientation
  x_i = x.vmap().do(sys.link.inertia.transform)
  # TODO: remove root_com when xd is fixed for stacked joints
  offset = x_i.pos - x.pos if root_com is None else x_i.pos - root_com
  xd_i = x_i.replace(pos=offset).vmap().do(xd)

  # TODO: add ellipsoid fluid model from mujoco
  # TODO: consider adding wind from mj.opt.wind
  diag_inertia = jax.vmap(jp.diag)(inertia)
  # Repeat each row three times, then flip the sign of the diagonal, so that
  # summing row i gives (sum of the other two entries) - (entry i).
  diag_inertia_v = jp.repeat(diag_inertia, 3, axis=-2).reshape((-1, 3, 3))
  diag_inertia_v *= jp.ones((3, 3)) - 2 * jp.eye(3)
  # The lower bound is passed positionally: the `a_min` keyword is deprecated
  # in newer JAX releases in favour of `min`, and the positional form is
  # accepted by both signatures. Clipping keeps the sqrt argument positive.
  box = 6.0 * jp.clip(jp.sum(diag_inertia_v, axis=-1), 1e-12)
  box = jp.sqrt(box / mass[:, None])

  frc = _box_viscosity(box, xd_i, sys.viscosity)
  frc += _box_density(box, xd_i, sys.density)

  # rotate back to the world orientation
  frc = Transform.create(rot=x_i.rot).vmap().do(frc)

  return frc
# --- brax/generalized/base.py ---


@struct.dataclass
class State(base.State):
  """Dynamic state that changes after every step.

  Attributes:
    root_com: (num_links,) center of mass position of link root kinematic tree
    cinr: (num_links,) inertia in com frame
    cd: (num_links,) link velocities in com frame
    cdof: (qd_size,) dofs in com frame
    cdofd: (qd_size,) cdof velocity
    mass_mx: (qd_size, qd_size) mass matrix
    mass_mx_inv: (qd_size, qd_size) inverse mass matrix
    contact: calculated contacts
    con_jac: constraint jacobian
    con_diag: constraint A diagonal
    con_aref: constraint reference acceleration
    qf_smooth: (qd_size,) smooth dynamics force
    qf_constraint: (qd_size,) force from constraints (collision etc)
    qdd: (qd_size,) joint acceleration vector
  """

  # position/velocity based terms are updated at the end of each step:
  root_com: jax.Array
  cinr: Inertia
  cd: Motion
  cdof: Motion
  cdofd: Motion
  mass_mx: jax.Array
  mass_mx_inv: jax.Array
  con_jac: jax.Array
  con_diag: jax.Array
  con_aref: jax.Array
  # acceleration based terms are calculated using terms from the previous step:
  qf_smooth: jax.Array
  qf_constraint: jax.Array
  qdd: jax.Array

  @classmethod
  def init(
      cls, q: jax.Array, qd: jax.Array, x: Transform, xd: Motion
  ) -> 'State':
    """Returns an initial State with every derived term zeroed.

    Args:
      q: joint position vector, stored as-is
      qd: joint velocity vector, stored as-is; its length sets qd_size
      x: link transforms, stored as-is; `x.pos.shape[0]` sets num_links
      xd: link motions, stored as-is

    Returns:
      A State whose remaining fields are zero placeholders and whose
      `contact` is None.
    """
    num_links = x.pos.shape[0]
    qd_size = qd.shape[0]
    # NOTE(review): the placeholder shapes below do not all match the class
    # docstring (root_com is (3,), cdof/cdofd use num_links, con_* are
    # scalars). Presumably they are overwritten before use -- confirm.
    return State(
        q=q,
        qd=qd,
        x=x,
        xd=xd,
        contact=None,
        root_com=jp.zeros(3),
        cinr=Inertia(
            Transform.zero((num_links,)),
            jp.zeros((num_links, 3, 3)),
            jp.zeros((num_links,)),
        ),
        cd=Motion.zero((num_links,)),
        cdof=Motion.zero((num_links,)),
        cdofd=Motion.zero((num_links,)),
        mass_mx=jp.zeros((qd_size, qd_size)),
        mass_mx_inv=jp.zeros((qd_size, qd_size)),
        con_jac=jp.zeros(()),
        con_diag=jp.zeros(()),
        con_aref=jp.zeros(()),
        qf_smooth=jp.zeros_like(qd),
        qf_constraint=jp.zeros_like(qd),
        qdd=jp.zeros_like(qd),
    )


# --- brax/generalized/dynamics_test.py ---


class DynamicsTest(parameterized.TestCase):
  """Compares generalized dynamics terms against sampled MuJoCo states."""

  # Every case is written as a 1-tuple; bare strings and 1-tuples both bind
  # to `xml_file`, so this only makes the list uniform.
  @parameterized.parameters(
      ('ant.xml',),
      ('triple_pendulum.xml',),
      ('humanoid.xml',),
      ('half_cheetah.xml',),
      ('swimmer.xml',),
  )
  def test_transform_com(self, xml_file):
    """Test dynamics transform com."""
    sys = test_utils.load_fixture(xml_file)
    for mj_prev, mj_next in test_utils.sample_mujoco_states(xml_file):
      state = jax.jit(pipeline.init)(sys, mj_prev.qpos, mj_prev.qvel)

      np.testing.assert_almost_equal(
          state.root_com[0], mj_next.subtree_com[0], 5
      )
      # unpack the first six cinert columns into symmetric 3x3 matrices
      mj_cinr_i = np.zeros((state.cinr.i.shape[0], 3, 3))
      mj_cinr_i[:, [0, 1, 2], [0, 1, 2]] = mj_next.cinert[1:, 0:3]  # diagonal
      mj_cinr_i[:, [0, 0, 1], [1, 2, 2]] = mj_next.cinert[1:, 3:6]  # upper tri
      mj_cinr_i[:, [1, 2, 2], [0, 0, 1]] = mj_next.cinert[1:, 3:6]  # lower tri
      mj_cinr_pos = mj_next.cinert[1:, 6:9]

      np.testing.assert_almost_equal(state.cinr.i, mj_cinr_i, 5)
      np.testing.assert_almost_equal(state.cinr.transform.pos, mj_cinr_pos, 5)
      np.testing.assert_almost_equal(state.cinr.mass, mj_next.cinert[1:, 9], 6)
      np.testing.assert_almost_equal(state.cd.matrix(), mj_next.cvel[1:], 4)
      np.testing.assert_almost_equal(state.cdof.matrix(), mj_next.cdof, 6)
      np.testing.assert_almost_equal(state.cdofd.matrix(), mj_next.cdof_dot, 5)

  @parameterized.parameters(
      ('ant.xml',),
      ('triple_pendulum.xml',),
      ('humanoid.xml',),
      ('half_cheetah.xml',),
      ('swimmer.xml',),
  )
  def test_forward(self, xml_file):
    """Test dynamics forward."""
    sys = test_utils.load_fixture(xml_file)
    for mj_prev, mj_next in test_utils.sample_mujoco_states(xml_file):
      act = jp.zeros(sys.act_size())
      state = jax.jit(pipeline.init)(sys, mj_prev.qpos, mj_prev.qvel)
      state = jax.jit(pipeline.step)(sys, state, act)

      np.testing.assert_allclose(
          state.qf_smooth, mj_next.qfrc_smooth, rtol=1e-4, atol=1e-4
      )


if __name__ == '__main__':
  absltest.main()
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pylint:disable=g-multiple-import 16 | """Integrator functions.""" 17 | 18 | from brax import math 19 | from brax import scan 20 | from brax.base import System 21 | from brax.generalized.base import State 22 | import jax 23 | from jax import numpy as jp 24 | 25 | 26 | def _integrate_q_axis(sys: System, q: jax.Array, qd: jax.Array) -> jax.Array: 27 | """Integrates next q for revolute/prismatic joints.""" 28 | return q + qd * sys.opt.timestep 29 | 30 | 31 | def _integrate_q_free(sys: System, q: jax.Array, qd: jax.Array) -> jax.Array: 32 | """Integrates next q for free joints.""" 33 | rot, ang = q[3:7], qd[3:6] 34 | ang_norm = jp.linalg.norm(ang) + 1e-8 35 | axis = ang / ang_norm 36 | angle = sys.opt.timestep * ang_norm 37 | qrot = math.quat_rot_axis(axis, angle) 38 | rot = math.quat_mul(rot, qrot) 39 | rot = rot / jp.linalg.norm(rot) 40 | pos, vel = q[0:3], qd[0:3] 41 | pos += vel * sys.opt.timestep 42 | 43 | return jp.concatenate([pos, rot]) 44 | 45 | 46 | def integrate(sys: System, state: State) -> State: 47 | """Semi-implicit Euler integration. 
48 | 49 | Args: 50 | sys: system defining the kinematic tree and other properties 51 | state: generalized state 52 | 53 | Returns: 54 | state: state with q, qd, and qdd updated 55 | """ 56 | # integrate joint damping implicitly to increase stability when we are not 57 | # using approximate inverse 58 | if sys.matrix_inv_iterations == 0: 59 | mx = state.mass_mx + jp.diag(sys.dof.damping) * sys.opt.timestep 60 | mx_inv = jax.scipy.linalg.solve(mx, jp.eye(sys.qd_size()), assume_a='pos') 61 | else: 62 | mx_inv = state.mass_mx_inv 63 | qdd = mx_inv @ (state.qf_smooth + state.qf_constraint) 64 | qd = state.qd + qdd * sys.opt.timestep 65 | 66 | def q_fn(typ, link, q, qd): 67 | q = q.reshape(link.transform.pos.shape[0], -1) 68 | qd = qd.reshape(link.transform.pos.shape[0], -1) 69 | fun = jax.vmap( 70 | { 71 | 'f': _integrate_q_free, 72 | '1': _integrate_q_axis, 73 | '2': _integrate_q_axis, 74 | '3': _integrate_q_axis, 75 | }[typ], 76 | in_axes=(None, 0, 0), 77 | ) 78 | q_s = fun(sys, q, qd).reshape(-1) 79 | 80 | return q_s 81 | 82 | q = scan.link_types(sys, q_fn, 'lqd', 'q', sys.link, state.q, qd) 83 | 84 | return state.replace(q=q, qd=qd, qdd=qdd) 85 | -------------------------------------------------------------------------------- /brax/generalized/mass_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Brax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class MassTest(parameterized.TestCase):
  """Compares the generalized mass matrix against MuJoCo's."""

  @parameterized.parameters(
      ('ant.xml',),
      ('triple_pendulum.xml',),
      ('humanoid.xml',),
      ('half_cheetah.xml',),
  )
  def test_matrix(self, xml_file):
    """Test mass matrix calculation."""
    sys = test_utils.load_fixture(xml_file)
    model = test_utils.load_fixture_mujoco(xml_file)
    size = sys.qd_size()
    # buffer that mj_fullM fills on every iteration
    mj_mass_mx = np.zeros((size, size))
    init_fn = jax.jit(pipeline.init)

    for mj_prev, mj_next in test_utils.sample_mujoco_states(xml_file):
      state = init_fn(sys, mj_prev.qpos, mj_prev.qvel)
      mujoco.mj_fullM(model, mj_mass_mx, mj_next.qM)
      np.testing.assert_almost_equal(state.mass_mx, mj_mass_mx, 5)


if __name__ == '__main__':
  absltest.main()
class PerfTest(absltest.TestCase):
  """Benchmarks the generalized pipeline."""

  def test_pipeline_ant(self):
    """Runs `test_utils.benchmark` on the ant fixture."""
    sys = test_utils.load_fixture('ant.xml')

    def init_fn(rng):
      q_key, qd_key = jax.random.split(rng, 2)
      # 15 entries: 0, 0, 0.75, 1.0 followed by eleven zeros
      base_q = jp.array([0, 0, 0.75, 1.0] + [0] * 11)
      noise = jax.random.uniform(
          q_key, (sys.q_size(),), minval=-0.1, maxval=0.1
      )
      qd = 0.1 * jax.random.normal(qd_key, (sys.qd_size(),))
      return pipeline.init(sys, base_q + noise, qd)

    def step_fn(state):
      zero_act = jp.zeros(sys.act_size())
      return pipeline.step(sys, state, zero_act)

    test_utils.benchmark('generalized pipeline ant', init_fn, step_fn)


if __name__ == '__main__':
  absltest.main()
def _refresh_derived_terms(
    sys: System, state: State, matrix_inv_iterations: int, debug: bool
) -> State:
  """Recomputes the terms that depend on the state's position and velocity.

  Runs dynamics.transform_com, mass.matrix_inv and constraint.jacobian in that
  order; when `debug` is set, also stores contact.get(sys, state.x).
  """
  state = dynamics.transform_com(sys, state)
  state = mass.matrix_inv(sys, state, matrix_inv_iterations)
  state = constraint.jacobian(sys, state)
  if debug:
    state = state.replace(contact=contact.get(sys, state.x))
  return state


def init(
    sys: System,
    q: jax.Array,
    qd: jax.Array,
    unused_act: Optional[jax.Array] = None,
    unused_ctrl: Optional[jax.Array] = None,
    debug: bool = False,
) -> State:
  """Initializes physics state.

  Args:
    sys: a brax system
    q: (q_size,) joint angle vector
    qd: (qd_size,) joint velocity vector
    unused_act: accepted but never read by this function
    unused_ctrl: accepted but never read by this function
    debug: if True, adds contact to the state for debugging

  Returns:
    state: initial physics state
  """
  if sys.mj_model is not None:
    mjcf.validate_model(sys.mj_model)

  x, xd = kinematics.forward(sys, q, qd)
  state = State.init(q, qd, x, xd)  # pytype: disable=wrong-arg-types  # jax-ndarray
  # the initial state always uses 0 matrix-inverse iterations
  return _refresh_derived_terms(sys, state, 0, debug)


def step(
    sys: System, state: State, act: jax.Array, debug: bool = False
) -> State:
  """Performs a physics step.

  Args:
    sys: a brax system
    state: physics state prior to step
    act: (act_size,) actuator input vector
    debug: if True, adds contact to the state for debugging

  Returns:
    state: physics state after step
  """
  # calculate acceleration terms
  joint_forces = actuator.to_tau(sys, act, state.q, state.qd)
  smooth = dynamics.forward(sys, state, joint_forces)
  state = state.replace(qf_smooth=smooth)
  # constraint.force receives the state that already holds qf_smooth
  state = state.replace(qf_constraint=constraint.force(sys, state))

  # update position/velocity level terms
  state = integrator.integrate(sys, state)
  x, xd = kinematics.forward(sys, state.q, state.qd)
  state = state.replace(x=x, xd=xd)

  return _refresh_derived_terms(
      sys, state, sys.matrix_inv_iterations, debug
  )
class PipelineTest(parameterized.TestCase):
  """Compares one generalized pipeline step against sampled MuJoCo states."""

  @parameterized.parameters(
      ('ant.xml',),
      ('triple_pendulum.xml',),
      ('humanoid.xml',),
      ('half_cheetah.xml',),
      ('swimmer.xml',),
  )
  def test_forward(self, xml_file):
    """Test pipeline step."""
    sys = test_utils.load_fixture(xml_file)
    # crank up solver iterations just to demonstrate close match to mujoco
    sys = sys.replace(solver_iterations=500)
    for mj_prev, mj_next in test_utils.sample_mujoco_states(xml_file):
      # initialise from the previous MuJoCo state, step once with zero actions
      state = jax.jit(pipeline.init)(sys, mj_prev.qpos, mj_prev.qvel)
      state = jax.jit(pipeline.step)(sys, state, jp.zeros(sys.act_size()))

      # velocities are compared with a much looser tolerance than positions
      np.testing.assert_allclose(state.q, mj_next.qpos, atol=0.002)
      np.testing.assert_allclose(state.qd, mj_next.qvel, atol=0.5)


class GradientTest(absltest.TestCase):
  """Tests that gradients are not NaN."""

  def test_grad(self):
    """Tests that gradients are not NaN."""
    # NOTE(review): this literal is blank in this copy of the file; the MJCF
    # markup appears to have been stripped. Restore it before running.
    xml = """









    """
    sys = mjcf.loads(xml)
    init_state = jax.jit(pipeline.init)(
        sys, sys.init_q, jp.zeros(sys.qd_size())
    )

    def fn(xd):
      # set the first velocity entry to xd, then take ten steps with act=None
      qd = jp.zeros(sys.qd_size()).at[0].set(xd)
      state = init_state.replace(qd=qd)
      for _ in range(10):
        state = jax.jit(pipeline.step)(sys, state, None)
      return state.qd[0]

    # differentiate the final first velocity w.r.t. the initial one
    grad = jax.grad(fn)(-1.0)
    self.assertFalse(np.isnan(grad))


if __name__ == '__main__':
  absltest.main()
-------------------------------------------------------------------------------- 1 | # Copyright 2024 The Brax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /brax/io/html.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Brax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def save(path: str, sys: System, states: List[State]):
  """Saves trajectory as an HTML text file.

  Args:
    path: destination file path; missing parent directories are created
    sys: brax System object
    states: list of system states to render
  """
  path = epath.Path(path)
  # exist_ok=True replaces the previous exists()-then-mkdir() sequence, which
  # could raise if another process created the directory in between.
  path.parent.mkdir(parents=True, exist_ok=True)
  path.write_text(render(sys, states))


def render_from_json(
    sys: str, height: Union[int, str], colab: bool, base_url: Optional[str]
) -> str:
  """Returns an HTML string that visualizes the brax system json string.

  Args:
    sys: brax system (and trajectory) serialized as a json string
    height: the height of the render window
    colab: whether to use css styles for colab
    base_url: url passed to the template as the viewer script location. It is
      used verbatim when given; when None, a jsdelivr CDN url pinned to the
      installed brax version is built instead

  Returns:
    string containing HTML for the brax visualizer
  """
  html_path = epath.resource_path('brax') / 'visualizer/index.html'
  template = jinja2.Template(html_path.read_text())

  if base_url is None:
    cdn_url = 'https://cdn.jsdelivr.net/gh/google/brax'
    js_url = f'{cdn_url}@v{brax.__version__}/brax/visualizer/js/viewer.js'
  else:
    js_url = base_url

  # the json payload is zlib-compressed and base64-encoded before being
  # embedded in the page
  system_json_b64 = base64.b64encode(
      zlib.compress(sys.encode('utf-8'))
  ).decode('ascii')
  return template.render(
      system_json_b64=system_json_b64,
      height=height,
      js_url=js_url,
      colab=colab,
  )


def render(
    sys: System,
    states: List[State],
    height: Union[int, str] = 480,
    colab: bool = True,
    base_url: Optional[str] = None,
) -> str:
  """Returns an HTML string for the brax system and trajectory.

  Args:
    sys: brax System object
    states: list of system states to render
    height: the height of the render window
    colab: whether to use css styles for colab
    base_url: the url for serving the visualizer script. By default, a CDN
      url is used

  Returns:
    string containing HTML for the brax visualizer
  """
  return render_from_json(json.dumps(sys, states), height, colab, base_url)
def render_array(
    sys: brax.System,
    trajectory: Union[List[base.State], base.State],
    height: int = 240,
    width: int = 320,
    camera: Optional[str] = None,
) -> Union[Sequence[np.ndarray], np.ndarray]:
  """Renders one state or a list of states with the MuJoCo renderer.

  Args:
    sys: brax system; its `mj_model` is used for rendering
    trajectory: a single state, or a list of states
    height: image height in pixels
    width: image width in pixels
    camera: camera to render from; a falsy value selects camera -1

  Returns:
    a list of images when `trajectory` is a list, otherwise a single image
  """
  renderer = mujoco.Renderer(sys.mj_model, height=height, width=width)
  camera = camera or -1

  def draw(state: base.State):
    # every frame gets a fresh MjData populated from the brax state
    data = mujoco.MjData(sys.mj_model)
    data.qpos = state.q
    data.qvel = state.qd
    has_mocap = hasattr(state, 'mocap_pos') and hasattr(state, 'mocap_quat')
    if has_mocap:
      data.mocap_pos = state.mocap_pos
      data.mocap_quat = state.mocap_quat
    mujoco.mj_forward(sys.mj_model, data)
    renderer.update_scene(data, camera=camera)
    return renderer.render()

  if not isinstance(trajectory, list):
    return draw(trajectory)

  return [draw(state) for state in trajectory]


def render(
    sys: brax.System,
    trajectory: List[base.State],
    height: int = 240,
    width: int = 320,
    camera: Optional[str] = None,
    fmt: str = 'png',
) -> bytes:
  """Returns an encoded image (or animation) of a brax System.

  Args:
    sys: brax system to render
    trajectory: non-empty list of states; more than one state is written as a
      looping multi-frame image
    height: image height in pixels
    width: image width in pixels
    camera: camera to render from
    fmt: image format name handed to PIL

  Returns:
    the encoded image bytes

  Raises:
    RuntimeError: if `trajectory` is empty
  """
  if not trajectory:
    raise RuntimeError('must have at least one state')

  arrays = render_array(sys, trajectory, height, width, camera)
  images = [Image.fromarray(array) for array in arrays]

  # the animation options are only passed when there are several frames
  extra = {}
  if len(images) != 1:
    extra = dict(
        append_images=images[1:],
        save_all=True,
        duration=sys.opt.timestep * 1000,
        loop=0,
    )

  buf = io.BytesIO()
  images[0].save(buf, format=fmt, **extra)
  return buf.getvalue()
-------------------------------------------------------------------------------- /brax/io/json_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Brax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for json.""" 16 | 17 | import json 18 | 19 | from absl.testing import absltest 20 | from brax import test_utils 21 | from brax.generalized import pipeline 22 | from brax.io import json as bjson 23 | import jax 24 | import jax.numpy as jp 25 | 26 | 27 | class JsonTest(absltest.TestCase): 28 | 29 | def test_dumps(self): 30 | sys = test_utils.load_fixture('convex_convex.xml') 31 | state = pipeline.init(sys, sys.init_q, jp.zeros(sys.qd_size())) 32 | res = bjson.dumps(sys, [state]) 33 | res = json.loads(res) 34 | 35 | self.assertIsInstance(res['geoms'], dict) 36 | self.assertSequenceEqual( 37 | sorted(res['geoms'].keys()), 38 | ['box', 'dodecahedron', 'pyramid', 'tetrahedron', 'world'], 39 | ) 40 | self.assertLen(res['geoms']['world'], 1) 41 | 42 | for f in ['size', 'rgba', 'name', 'link_idx', 'pos', 'rot']: 43 | self.assertIn(f, res['geoms']['box'][0]) 44 | 45 | def test_dumps_invalidstate_raises(self): 46 | sys = test_utils.load_fixture('convex_convex.xml') 47 | state = pipeline.init(sys, sys.init_q, jp.zeros(sys.qd_size())) 48 | state = jax.tree.map(lambda x: jp.stack([x, x]), state) 49 | with self.assertRaises(RuntimeError): 50 | 
class Writer:
  """Metrics writer that logs values and records them to a SummaryWriter."""

  def __init__(self, logdir=''):
    self._writer = SummaryWriter(logdir=logdir)

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    # closes the underlying writer whether or not an exception occurred
    self._writer.close()

  def write_hparams(self, hparams):
    """Logs the global hparams and records them with an empty metric dict."""
    logging.info('Hyperparameters: %s', hparams)
    self._writer.add_hparams(hparams, {})

  def write_scalars(self, step, scalars):
    """Writes scalar metrics for a step.

    Args:
      step: step index the scalars belong to
      scalars: mapping from metric name to value
    """
    # the log line lists metrics in sorted order; floats get 6 decimals
    formatted = []
    for name, value in sorted(scalars.items()):
      if isinstance(value, float):
        formatted.append(f'{name}={value:.6f}')
      else:
        formatted.append(f'{name}={value}')
    logging.info('[%d] %s', step, ', '.join(formatted))

    for name, value in scalars.items():
      self._writer.add_scalar(name, value, step)
def load_params(path: str) -> Any:
  """Loads pickled parameters from a file.

  NOTE(security): the file content is passed to `pickle.loads`, which can
  execute arbitrary code; only load files from trusted sources.

  Args:
    path: location of the pickled parameters

  Returns:
    the unpickled object
  """
  serialized = epath.Path(path).read_bytes()
  return pickle.loads(serialized)


def save_params(path: str, params: Any):
  """Saves parameters to a file using pickle.

  Args:
    path: destination file
    params: object to serialize
  """
  serialized = pickle.dumps(params)
  epath.Path(path).write_bytes(serialized)
Device = Union[str, torch.device]


@functools.singledispatch
def torch_to_jax(value: Any) -> Any:
  """Converts PyTorch tensors to JAX arrays.

  This is the fallback for types without a registered converter; it returns
  None.

  Args:
    value: torch tensor

  Returns:
    a JAX array
  """
  del value
  return None


@torch_to_jax.register(torch.Tensor)
def _tensor_to_jax(value: torch.Tensor) -> jax.Array:
  """Converts a PyTorch Tensor into a jax.Array via DLPack."""
  return jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(value))


@torch_to_jax.register(abc.Mapping)
def _torch_dict_to_jax(
    value: Dict[str, Union[torch.Tensor, Any]],
) -> Dict[str, Union[jax.Array, Any]]:
  """Converts a dict of PyTorch tensors into a dict of jax.Arrays."""
  # the result is rebuilt with the input's own mapping type via keyword args
  converted = {key: torch_to_jax(item) for key, item in value.items()}
  return type(value)(**converted)  # type: ignore


@functools.singledispatch
def jax_to_torch(value: Any, device: Union[Device, None] = None) -> Any:
  """Convert JAX values to PyTorch Tensors.

  This is the fallback for types without a registered converter; it returns
  None.

  Args:
    value: jax array or pytree
    device: device to copy value to (or None to leave on same device)

  Returns:
    Torch tensor on device

  By default, the returned tensors are on the same device as the Jax inputs,
  but if `device` is passed, the tensors will be moved to that device.
  """
  del value, device
  return None


@jax_to_torch.register(jax.Array)
def _jaxarray_to_tensor(
    value: jax.Array, device: Union[Device, None] = None
) -> torch.Tensor:
  """Converts a jax.Array into a float32 PyTorch Tensor."""
  # the array is cast to float32 before being handed to torch
  tensor = torch_dlpack.from_dlpack(value.astype("float32"))
  return tensor.to(device=device) if device else tensor


@jax_to_torch.register(abc.Mapping)
def _jax_dict_to_torch(
    value: Dict[str, Union[jax.Array, Any]], device: Union[Device, None] = None
) -> Dict[str, Union[torch.Tensor, Any]]:
  """Converts a dict of jax.Arrays into a dict of PyTorch tensors."""
  converted = {
      key: jax_to_torch(item, device=device) for key, item in value.items()
  }
  return type(value)(**converted)  # type: ignore
def _get_rand_norm(seed: int):
  """Returns a unit-length 3-vector drawn with numpy's RNG seeded by `seed`."""
  np.random.seed(seed)
  # draw order matters for reproducibility: theta first, then a
  theta = np.random.random(1) * 2 * np.pi
  a = (np.random.random(1) - 0.5) * 2.0
  phi = np.arccos(a)
  sin_phi = np.sin(phi)
  components = [sin_phi * np.cos(theta), sin_phi * np.sin(theta), np.cos(phi)]
  return jp.array(components).squeeze()


class MathTest(absltest.TestCase):
  """Tests for approximate inversion and from_to rotations."""

  def test_inv_approximate(self):
    """Checks inv_approximate against jp.linalg.inv."""
    # create a 4x4 matrix we know is invertible
    rand = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
    mat = jp.eye(4) * 0.001 + rand @ rand.T

    expected = jp.linalg.inv(mat)
    actual = math.inv_approximate(mat, jp.zeros((4, 4)), num_iter=100)

    np.testing.assert_array_almost_equal(actual, expected)

  def test_from_to(self):
    """Checks that from_to(a, b) yields a rotation taking a onto b."""
    x_axis = jp.array([1.0, 0.0, 0.0])
    y_axis = jp.array([0.0, 1.0, 0.0])
    diagonal = jp.array([-0.5, 0.5, 0.0])
    diagonal /= jp.linalg.norm(diagonal)

    cases = (
        (x_axis, x_axis),
        (x_axis, -x_axis),
        (-x_axis, x_axis),
        (y_axis, -y_axis),
        (y_axis, diagonal),
    )
    for src, dst in cases:
      rot = math.from_to(src, dst)
      np.testing.assert_array_almost_equal(dst, math.rotate(src, rot))


class OrthoganalsTest(parameterized.TestCase):
  """Tests the orthogonals function."""

  @parameterized.parameters(range(100))
  def test_orthogonals(self, i):
    """Checks that orthogonals returns two unit vectors orthogonal to the input."""
    a = _get_rand_norm(i)
    b, c = math.orthogonals(a)
    for vec in (a, b, c):
      np.testing.assert_almost_equal(jp.linalg.norm(vec), 1)
    for lhs, rhs in ((a, b), (b, c), (a, c)):
      self.assertAlmostEqual(np.abs(lhs.dot(rhs)), 0, 6)
class State(base.State, mjx.Data):
  """Dynamic state that changes after every pipeline step.

  Inherits from both brax `base.State` and `mjx.Data`; this class adds no
  fields or methods of its own.
  """
class PerfTest(absltest.TestCase):
  """Benchmarks the MJX pipeline."""

  def test_pipeline_ant(self):
    """Benchmarks init and zero-control stepping on the ant fixture."""
    model = test_utils.load_fixture('ant.xml')

    def init_fn(rng):
      # small random perturbations of position and velocity
      key_q, key_qd = jax.random.split(rng, 2)
      qpos = jax.random.uniform(key_q, (model.nq,), minval=-0.1, maxval=0.1)
      qvel = 0.1 * jax.random.normal(key_qd, (model.nv,))
      return pipeline.init(model, qpos, qvel)

    def step_fn(data):
      zero_ctrl = jp.zeros(model.nu)
      return pipeline.step(model, data, zero_ctrl)

    test_utils.benchmark('mjx pipeline ant', init_fn, step_fn)
class PipelineTest(absltest.TestCase):
  """Tests for the MJX pipeline."""

  def test_pendulum(self):
    """Steps a double pendulum 20 times and compares against MuJoCo."""
    sys = test_utils.load_fixture('double_pendulum.xml')

    state = pipeline.init(sys, sys.init_q, jp.zeros(sys.qd_size()))

    self.assertIsInstance(state.contact, Contact)

    jit_step = jax.jit(pipeline.step)
    for _ in range(20):
      state = jit_step(sys, state, jp.zeros(sys.act_size()))

    # compare against mujoco
    mj_model = test_utils.load_fixture_mujoco('double_pendulum.xml')
    mj_data = mujoco.MjData(mj_model)
    mujoco.mj_step(mj_model, mj_data, 20)

    np.testing.assert_almost_equal(mj_data.qpos, state.q, decimal=4)
    np.testing.assert_almost_equal(mj_data.qvel, state.qd, decimal=3)
    # the first MuJoCo xpos row is dropped before comparing with state.x.pos
    np.testing.assert_almost_equal(mj_data.xpos[1:], state.x.pos, decimal=4)

  def test_pipeline_init_with_ctrl(self):
    """Checks that a ctrl passed to init is stored on the returned state."""
    sys = test_utils.load_fixture('single_spherical_pendulum_position.xml')
    ctrl = jp.array([0.3, 0.5, 0.4])
    qd = jp.zeros(sys.qd_size())
    state = pipeline.init(sys, sys.init_q, qd, ctrl=ctrl)
    np.testing.assert_array_almost_equal(state.ctrl, ctrl)
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /brax/positional/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Brax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pylint:disable=g-multiple-import 16 | """Base types for positional pipeline.""" 17 | 18 | from brax import base 19 | from brax.base import Motion, Transform 20 | from flax import struct 21 | import jax 22 | import jax.numpy as jp 23 | 24 | 25 | @struct.dataclass 26 | class State(base.State): 27 | """Dynamic state that changes after every step. 
def integrate_xdv(sys: System, xd: Motion, xdv: Motion) -> Motion:
  """Damps the velocity and then adds a delta-velocity.

  Each velocity component is scaled by exp(damping * timestep) before the
  delta-velocity is added.

  Args:
    sys: System to forward propagate
    xd: velocity
    xdv: delta-velocity

  Returns:
    xd: updated velocity
  """
  dt = sys.opt.timestep
  damping = Motion(vel=sys.vel_damping, ang=sys.ang_damping)
  damped = jax.tree.map(lambda rate, v: jp.exp(rate * dt) * v, damping, xd)
  return damped + xdv


def integrate_xdd(
    sys: System,
    x: Transform,
    xd: Motion,
    xdd: Motion,
) -> Tuple[Transform, Motion]:
  """Advances position and velocity by one system time step given acceleration.

  The velocity is first updated with the acceleration and damped; the damped
  velocity is then used to move the position and rotation.

  Args:
    sys: System to forward propagate
    x: position
    xd: velocity
    xdd: acceleration

  Returns:
    x: updated position
    xd: updated velocity
  """
  dt = sys.opt.timestep

  accelerated = xd + xdd * dt
  damping = Motion(vel=sys.vel_damping, ang=sys.ang_damping)
  xd_next = jax.tree.map(
      lambda rate, v: jp.exp(rate * dt) * v, damping, accelerated
  )

  def advance(x_link, xd_link):
    half_ang_quat = math.ang_to_quat(xd_link.ang) * 0.5 * dt
    # the integrated rotation is renormalized
    rot = math.normalize(x_link.rot + math.quat_mul(half_ang_quat, x_link.rot))[0]
    return Transform(pos=x_link.pos + xd_link.vel * dt, rot=rot)

  x_next = jax.vmap(advance)(x, xd_next)

  return x_next, xd_next


def project_xd(sys: System, x: Transform, x_prev: Transform) -> Motion:
  """Performs the position based dynamics velocity projection step.

  The velocity and angular velocity must respect the spatial and quaternion
  distance (respectively) between x and x_prev.

  Args:
    sys: The system definition
    x: The current transform
    x_prev: The transform at the previous step

  Returns:
    Motion with velocity pinned to respect distance traveled since x_prev
  """
  dt = sys.opt.timestep

  def project(cur, prev):
    linear = (cur.pos - prev.pos) / dt
    dq = math.relative_quat(prev.rot, cur.rot)
    # flip the sign of the angular velocity when the scalar part of the
    # relative quaternion is negative
    sign = jp.where(dq[0] >= 0.0, 1.0, -1.0)
    angular = sign * (2.0 * dq[1:] / dt)
    return Motion(vel=linear, ang=angular)

  return jax.vmap(project)(x, x_prev)
class PerfTest(absltest.TestCase):
  """Benchmarks the positional pipeline."""

  def test_pipeline_ant(self):
    """Benchmarks init and zero-action stepping on the ant fixture."""
    sys = test_utils.load_fixture('ant.xml')

    def init_fn(rng):
      # small random perturbations of position and velocity
      key_q, key_qd = jax.random.split(rng, 2)
      q0 = jax.random.uniform(key_q, (sys.q_size(),), minval=-0.1, maxval=0.1)
      qd0 = 0.1 * jax.random.normal(key_qd, (sys.qd_size(),))
      return pipeline.init(sys, q0, qd0)

    def step_fn(state):
      zero_act = jp.zeros(sys.act_size())
      return pipeline.step(sys, state, zero_act)

    test_utils.benchmark('pbd pipeline ant', init_fn, step_fn)
@struct.dataclass
class State(base.State):
  """Dynamic state that changes after every step.

  Extends `base.State` with the extra quantities listed below. The field
  order is part of the dataclass constructor interface and must not change.

  Attributes:
    x_i: (num_links,) link center of mass position in world frame
    xd_i: (num_links,) link center of mass velocity in world frame
    j: link position in joint frame
    jd: link motion in joint frame
    a_p: joint parent anchor in world frame
    a_c: joint child anchor in world frame
    i_inv: link inverse inertia
    mass: link mass
  """

  # NOTE(review): x_i / xd_i are typed as full Transform / Motion values, so
  # they presumably carry rotation / angular parts as well -- confirm.
  x_i: Transform
  xd_i: Motion
  j: Transform
  jd: Motion
  a_p: Transform
  a_c: Transform
  i_inv: jax.Array
  mass: jax.Array
def integrate(
    sys: System,
    x_i: Transform,
    xd_i: Motion,
    xdv_i: Motion,
) -> Tuple[Transform, Motion]:
  """Advances center of mass transforms and motions by one timestep.

  For every link the current motion is scaled by
  `exp(damping * sys.opt.timestep)`, the delta-velocity is added, and the
  resulting motion is used to move the position and rotate the (renormalized)
  orientation quaternion.

  Args:
    sys: System to forward propagate
    x_i: link center of mass transform in world frame
    xd_i: link center of mass motion in world frame
    xdv_i: link center of mass delta-velocity in world frame

  Returns:
    x_i: updated link center of mass transform in world frame
    xd_i: updated link center of mass motion in world frame
  """

  def _integrate_link(x, xd, xdv):
    # Damp the incoming motion, then apply the velocity update.
    xd = Motion(
        vel=jp.exp(sys.vel_damping * sys.opt.timestep) * xd.vel,
        ang=jp.exp(sys.ang_damping * sys.opt.timestep) * xd.ang,
    )
    xd += xdv

    # First-order quaternion update: rot + (0.5 * dt * ang_quat) * rot.
    half_step_quat = math.ang_to_quat(xd.ang) * 0.5 * sys.opt.timestep
    new_rot = x.rot + math.quat_mul(half_step_quat, x.rot)
    new_pos = x.pos + xd.vel * sys.opt.timestep

    # Renormalize so the rotation stays a unit quaternion.
    x = Transform(pos=new_pos, rot=new_rot / jp.linalg.norm(new_rot))
    return x, xd

  # Map the per-link update over the leading axis of all three inputs.
  return jax.vmap(_integrate_link)(x_i, xd_i, xdv_i)
class JointTest(parameterized.TestCase):
  """Checks spring-pipeline joints against an analytic pendulum."""

  @parameterized.parameters(
      (2.0, 0.125, 0.0625), (5.0, 0.125, 0.03125), (1.0, 0.0625, 0.1)
  )
  def test_pendulum_period(self, mass, radius, vel):
    """A small spherical mass swings for approximately one period."""
    pendulum = test_utils.load_fixture('single_pendulum.xml')

    # Physical pendulum: solid sphere inertia shifted to the anchor point
    # via the parallel axis theorem.
    g = 9.81
    dist_to_anchor = 0.5
    inertia_cm = 2.0 / 5.0 * mass * radius**2.0
    inertia_about_anchor = mass * dist_to_anchor**2.0 + inertia_cm
    # formula for period of pendulum
    period = (
        2 * jp.pi * jp.sqrt(inertia_about_anchor / (mass * g * dist_to_anchor))
    )

    # Choose the timestep so that exactly one period is simulated.
    num_timesteps = 1_000
    pendulum = pendulum.tree_replace(
        {'opt.timestep': period / num_timesteps}
    )

    # Stiff joint constraint with no limit stiffness and no damping, and the
    # link given the mass and inertia of the sphere under test.
    sphere_inertia = pendulum.link.inertia.replace(
        i=jp.array([0.4 * mass * radius**2 * jp.eye(3)]),
        mass=jp.array([mass]),
    )
    link = pendulum.link.replace(
        constraint_limit_stiffness=jp.array([0.0]),
        constraint_stiffness=jp.array([10_000.0]),
        constraint_ang_damping=jp.array([0.0]),
        constraint_vel_damping=jp.array([0.0]),
        inertia=sphere_inertia,
    )
    pendulum = pendulum.replace(link=link, ang_damping=0.0)

    # init with small initial velocity for small angle approx. validity
    q0 = jp.array([-jp.pi / 2.0])
    qd0 = jp.array([vel])
    state = pipeline.init(pendulum, q0, qd0)

    step = jax.jit(pipeline.step)
    for _ in range(num_timesteps):
      state = step(pendulum, state, jp.zeros(pendulum.act_size()))

    # returned to the origin
    self.assertAlmostEqual(state.xd.ang[0, 0], vel, 2)
class PerfTest(absltest.TestCase):
  """Performance benchmark for the spring pipeline."""

  def test_pipeline_ant(self):
    """Benchmarks pipeline init and step on the ant fixture."""
    ant = test_utils.load_fixture('ant.xml')

    def init_fn(rng):
      # Separate keys for the random joint positions and joint velocities.
      q_key, qd_key = jax.random.split(rng, 2)
      q = jax.random.uniform(
          q_key, (ant.q_size(),), minval=-0.1, maxval=0.1
      )
      qd = 0.1 * jax.random.normal(qd_key, (ant.qd_size(),))
      return pipeline.init(ant, q, qd)

    def step_fn(state):
      # Step with an all-zero action vector.
      zero_action = jp.zeros(ant.act_size())
      return pipeline.step(ant, state, zero_action)

    test_utils.benchmark('spring pipeline ant', init_fn, step_fn)
/brax/test_data/double_prismatic.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /brax/test_data/fluid_box.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /brax/test_data/fluid_box_offset_com.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /brax/test_data/fluid_ellipsoid.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /brax/test_data/fluid_sphere.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /brax/test_data/fluid_two_spheres.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /brax/test_data/fluid_wind.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /brax/test_data/meshes/cylinder.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/brax/test_data/meshes/cylinder.stl -------------------------------------------------------------------------------- /brax/test_data/meshes/dodecahedron.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/brax/test_data/meshes/dodecahedron.stl -------------------------------------------------------------------------------- /brax/test_data/meshes/pyramid.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/brax/test_data/meshes/pyramid.stl -------------------------------------------------------------------------------- /brax/test_data/meshes/tetrahedron.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/brax/test_data/meshes/tetrahedron.stl -------------------------------------------------------------------------------- /brax/test_data/nonzero_joint_ref.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /brax/test_data/prismaversal_2dof_joint.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /brax/test_data/prismaversal_3dof_joint.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /brax/test_data/single_pendulum.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /brax/test_data/single_pendulum_motor.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 21 | -------------------------------------------------------------------------------- /brax/test_data/single_pendulum_position.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 21 | -------------------------------------------------------------------------------- /brax/test_data/single_pendulum_position_frclimit.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 21 | -------------------------------------------------------------------------------- /brax/test_data/single_pendulum_velocity.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 21 | -------------------------------------------------------------------------------- /brax/test_data/single_prismatic.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /brax/test_data/single_spherical_pendulum.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 14 | -------------------------------------------------------------------------------- /brax/test_data/single_spherical_pendulum_motor.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 26 | -------------------------------------------------------------------------------- /brax/test_data/single_spherical_pendulum_position.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 26 | -------------------------------------------------------------------------------- /brax/test_data/single_universal_pendulum.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 13 | -------------------------------------------------------------------------------- 
/brax/test_data/solver_params_v2.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /brax/test_data/triple_pendulum.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 30 | -------------------------------------------------------------------------------- /brax/test_data/triple_pendulum_motor.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 34 | -------------------------------------------------------------------------------- /brax/test_data/triple_prismatic.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /brax/test_data/world_body_transform.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /brax/test_data/world_fromto.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 21 | -------------------------------------------------------------------------------- /brax/test_data/world_self_collision.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /brax/training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Brax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /brax/training/acme/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Brax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /brax/training/acme/specs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Brax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
@dataclasses.dataclass(frozen=True)
class Array:
  """Describes a numpy array or scalar shape and dtype.

  Similar to dm_env.specs.Array.

  Instances are immutable (the dataclass is declared with `frozen=True`).
  """
  # Dimensions of the described array.
  # NOTE(review): presumably an empty tuple denotes a scalar -- confirm.
  shape: Tuple[int, ...]
  # Element type of the described array.
  dtype: jnp.dtype
# Define types for nested arrays and tensors.
# NestedArray is currently just an alias of the jax.numpy array type, and
# NestedTensor is left unconstrained.
NestedArray = jnp.ndarray
NestedTensor = Any

# A spec is either a single `specs.Array`, an iterable of specs, or a mapping
# whose values are specs; the string forward references make the alias
# recursive, which is why the pytype check is disabled around it.
# pytype: disable=not-supported-yet
NestedSpec = Union[
    specs.Array,
    Iterable['NestedSpec'],
    Mapping[Any, 'NestedSpec'],
]
# pytype: enable=not-supported-yet

# Union of all of the nested structure aliases above.
Nest = Union[NestedArray, NestedTensor, NestedSpec]
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /brax/training/agents/apg/networks.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Brax Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
@flax.struct.dataclass
class APGNetworks:
  """Bundle of the APG policy network and its action distribution."""

  policy_network: networks.FeedForwardNetwork
  parametric_action_distribution: distribution.ParametricDistribution


def make_inference_fn(apg_networks: APGNetworks):
  """Creates params and inference function for the APG agent.

  Args:
    apg_networks: networks whose policy and action distribution are used

  Returns:
    A `make_policy(params, deterministic=False)` factory that builds a policy
    mapping (observations, key) to (action, extras).
  """

  def make_policy(
      params: types.PolicyParams, deterministic: bool = False
  ) -> types.Policy:

    def policy(
        observations: types.Observation, key_sample: PRNGKey
    ) -> Tuple[types.Action, types.Extra]:
      action_dist = apg_networks.parametric_action_distribution
      logits = apg_networks.policy_network.apply(*params, observations)
      # Deterministic policies take the distribution mode; otherwise an
      # action is sampled with the provided key.
      if not deterministic:
        return action_dist.sample(logits, key_sample), {}
      return action_dist.mode(logits), {}

    return policy

  return make_policy


def make_apg_networks(
    observation_size: int,
    action_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
    hidden_layer_sizes: Sequence[int] = (32,) * 4,
    activation: networks.ActivationFn = linen.elu,
    layer_norm: bool = True,
) -> APGNetworks:
  """Make APG networks.

  Args:
    observation_size: size of the observation fed to the policy network
    action_size: event size of the action distribution
    preprocess_observations_fn: observation preprocessor passed to the policy
    hidden_layer_sizes: hidden layer sizes of the policy network
    activation: activation function of the policy network
    layer_norm: forwarded to the policy network constructor

  Returns:
    APGNetworks holding the policy network and a tanh-normal distribution.
  """
  action_dist = distribution.NormalTanhDistribution(
      event_size=action_size, var_scale=0.1
  )
  policy = networks.make_policy_network(
      action_dist.param_size,
      observation_size,
      preprocess_observations_fn=preprocess_observations_fn,
      hidden_layer_sizes=hidden_layer_sizes,
      activation=activation,
      kernel_init=linen.initializers.orthogonal(0.01),
      layer_norm=layer_norm,
  )
  return APGNetworks(
      policy_network=policy,
      parametric_action_distribution=action_dist,
  )
ARSNetwork = networks.FeedForwardNetwork


def make_policy_network(
    observation_size: int,
    action_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
) -> ARSNetwork:
  """Creates a policy network.

  The policy is a single matrix of shape (observation_size, action_size)
  that is initialized to zeros and multiplied with the preprocessed
  observation.

  Args:
    observation_size: number of rows of the policy matrix
    action_size: number of columns of the policy matrix
    preprocess_observations_fn: applied to observations before the matmul

  Returns:
    An ARSNetwork with `init` and `apply` functions.
  """

  def init(_):
    # The argument is ignored; the initial policy is all zeros.
    return jnp.zeros((observation_size, action_size))

  def apply(processor_params, policy_params, obs):
    processed_obs = preprocess_observations_fn(obs, processor_params)
    return jnp.matmul(processed_obs, policy_params)

  return ARSNetwork(init=init, apply=apply)


def make_inference_fn(policy_network: ARSNetwork):
  """Creates params and inference function for the ARS agent.

  Args:
    policy_network: network whose `apply` computes actions

  Returns:
    A `make_policy(params)` factory building a policy that ignores its key.
  """

  def make_policy(params: types.PolicyParams) -> types.Policy:

    def policy(
        observations: types.Observation, unused_key_sample: PRNGKey
    ) -> Tuple[types.Action, types.Extra]:
      action = policy_network.apply(*params, observations)
      return action, {}

    return policy

  return make_policy
class ARSTest(parameterized.TestCase):
  """Tests for ARS module."""

  @parameterized.parameters(True, False)
  def testModelEncoding(self, normalize_observations):
    """Trained params survive a pickle round trip and can drive the env."""
    env = envs.get_environment('fast')
    _, params, _ = ars.train(
        env,
        num_timesteps=128,
        episode_length=128,
        normalize_observations=normalize_observations,
    )

    # Use the running-statistics normalizer only when training used it;
    # otherwise observations pass through unchanged.
    if normalize_observations:
      normalize_fn = running_statistics.normalize
    else:
      normalize_fn = lambda x, y: x
    ars_network = ars_networks.make_policy_network(
        env.observation_size, env.action_size, normalize_fn
    )
    inference = ars_networks.make_inference_fn(ars_network)

    decoded_params = pickle.loads(pickle.dumps(params))

    # Compute one action.
    state = env.reset(jax.random.PRNGKey(0))
    policy = inference(decoded_params)
    action, _ = policy(state.obs, jax.random.PRNGKey(0))
    env.step(state, action)

  def testTrainDomainRandomize(self):
    """Test with domain randomization."""

    def rand_fn(sys, rng):
      def offset_positions(key):
        # Replace the first link position with a small random offset.
        offset = jax.random.uniform(key, shape=(3,), minval=-0.1, maxval=0.1)
        return sys.link.transform.pos.at[0].set(offset)

      randomized_pos = jax.vmap(offset_positions)(rng)
      sys_v = sys.tree_replace({'link.inertia.transform.pos': randomized_pos})

      # Only the randomized leaf is batched (axis 0); every other leaf is
      # marked None.
      in_axes = jax.tree.map(lambda x: None, sys)
      in_axes = in_axes.tree_replace({'link.inertia.transform.pos': 0})
      return sys_v, in_axes

    env = envs.get_environment('inverted_pendulum', backend='spring')
    _, _, _ = ars.train(
        env,
        num_timesteps=128,
        episode_length=128,
        normalize_observations=True,
        randomization_fn=rand_fn,
    )
# File name under which the BC network config is stored with a checkpoint.
_CONFIG_FNAME = 'bc_network_config.json'


def save(
    path: Union[str, epath.Path],
    step: int,
    params: Any,
    config: config_dict.ConfigDict,
):
  """Saves a checkpoint.

  Delegates to `checkpoint.save`, storing the config as `_CONFIG_FNAME`.

  Args:
    path: location handed to `checkpoint.save`.
    step: training step the checkpoint belongs to.
    params: parameters to store.
    config: network config to store alongside the parameters.

  Returns:
    Whatever `checkpoint.save` returns.
  """
  return checkpoint.save(path, step, params, config, _CONFIG_FNAME)


def load(
    path: Union[str, epath.Path],
):
  """Loads the parameters of a BC checkpoint from `path`."""
  return checkpoint.load(path)


def network_config(
    observation_size: types.ObservationSize,
    action_size: int,
    normalize_observations: bool,
    network_factory: types.NetworkFactory[bc_networks.BCNetworks],
) -> config_dict.ConfigDict:
  """Returns a config dict for re-creating a network from a checkpoint."""
  return checkpoint.network_config(
      observation_size, action_size, normalize_observations, network_factory
  )


def _get_bc_network(
    config: config_dict.ConfigDict,
    network_factory: types.NetworkFactory[bc_networks.BCNetworks],
) -> bc_networks.BCNetworks:
  """Generates a BC network given config."""
  return checkpoint.get_network(config, network_factory)  # pytype: disable=bad-return-type


def load_config(
    path: Union[str, epath.Path],
) -> config_dict.ConfigDict:
  """Loads the BC config file `_CONFIG_FNAME` located directly under `path`."""
  path = epath.Path(path)
  config_path = path / _CONFIG_FNAME
  return checkpoint.load_config(config_path)


def load_policy(
    path: Union[str, epath.Path],
    network_factory: types.NetworkFactory[bc_networks.BCNetworks],
    deterministic: bool = True,
):
  """Loads policy inference function from BC checkpoint.

  Args:
    path: checkpoint directory; both the parameters and the network config are
      read from it.
    network_factory: factory used to rebuild the BC networks from the config.
    deterministic: forwarded to the inference-function factory; defaults to
      True.

  Returns:
    The policy produced by `bc_networks.make_inference_fn` for the restored
    network and parameters.
  """
  path = epath.Path(path)
  # Read the config from the same directory the parameters are loaded from,
  # as the PPO and SAC loaders do (previously this looked in `path.parent`).
  config = load_config(path)
  params = load(path)
  bc_network = _get_bc_network(config, network_factory)
  make_inference_fn = bc_networks.make_inference_fn(bc_network)

  return make_inference_fn(params, deterministic=deterministic)
# Vanilla L2 with postprocessing
def bc_loss(
    params: Params,
    normalizer_params: Any,
    data: Dict,
    make_policy: Callable[[Tuple[Any, Params]], networks.BCInferenceFn],
):
  """Squared error between tanh-squashed student and teacher `loc` outputs.

  Args:
    params: policy parameters being optimized.
    normalizer_params: normalizer parameters, passed to `make_policy` together
      with `params`.
    data: dict holding 'observations' and 'teacher_action_extras' (the latter
      must contain 'loc').
    make_policy: builds the policy from `(normalizer_params, params)`.

  Returns:
    `(loss, metrics)`; `metrics` reports the same scalar under both
    'actor_loss' and 'mse_loss'.
  """
  student = make_policy((normalizer_params, params))
  _, student_extras = student(data['observations'], key_sample=None)  # pytype: disable=wrong-keyword-args

  student_loc = jp.tanh(student_extras['loc'])
  teacher_loc = jp.tanh(data['teacher_action_extras']['loc'])

  # Sum the squared error over the last axis, then average what remains.
  squared_error = (student_loc - teacher_loc) ** 2
  actor_loss = squared_error.sum(-1).mean()
  return actor_loss, {'actor_loss': actor_loss, 'mse_loss': actor_loss}
@flax.struct.dataclass
class ESNetworks:
  """Policy network together with the action distribution it parameterizes."""

  policy_network: networks.FeedForwardNetwork
  parametric_action_distribution: distribution.ParametricDistribution


def make_inference_fn(es_networks: ESNetworks):
  """Creates params and inference function for the ES agent.

  Args:
    es_networks: networks whose policy and action distribution are used.

  Returns:
    A `make_policy(params, deterministic=False)` factory. The policy it returns
    maps `(observations, key_sample)` to `(action, {})`.
  """
  action_distribution = es_networks.parametric_action_distribution
  apply_policy = es_networks.policy_network.apply

  def make_policy(
      params: types.PolicyParams, deterministic: bool = False
  ) -> types.Policy:

    def policy(
        observations: types.Observation, key_sample: PRNGKey
    ) -> Tuple[types.Action, types.Extra]:
      logits = apply_policy(*params, observations)
      # Deterministic policies take the distribution mode; otherwise sample.
      if not deterministic:
        return action_distribution.sample(logits, key_sample), {}
      return action_distribution.mode(logits), {}

    return policy

  return make_policy


def make_es_networks(
    observation_size: int,
    action_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
    hidden_layer_sizes: Sequence[int] = (32,) * 4,
    activation: networks.ActivationFn = linen.relu,
) -> ESNetworks:
  """Make ES networks.

  Args:
    observation_size: observation size passed to the policy network.
    action_size: event size of the tanh-normal action distribution.
    preprocess_observations_fn: observation preprocessor for the policy.
    hidden_layer_sizes: hidden layer widths of the policy network.
    activation: activation function of the policy network.

  Returns:
    The policy network bundled with its action distribution.
  """
  action_distribution = distribution.NormalTanhDistribution(
      event_size=action_size
  )
  # The policy emits one output per distribution parameter.
  policy = networks.make_policy_network(
      action_distribution.param_size,
      observation_size,
      preprocess_observations_fn=preprocess_observations_fn,
      hidden_layer_sizes=hidden_layer_sizes,
      activation=activation,
  )
  return ESNetworks(
      policy_network=policy,
      parametric_action_distribution=action_distribution,
  )
class ESTest(parameterized.TestCase):
  """Smoke and regression tests for Evolution Strategy training."""

  def testTrain(self):
    """Test ES with a simple env."""
    fast_env = envs.get_environment('fast')
    _, _, metrics = es.train(
        environment=fast_env,
        num_timesteps=65536,
        episode_length=128,
        learning_rate=0.1,
    )
    self.assertGreater(metrics['eval/episode_reward'], 140)

  @parameterized.parameters(True, False)
  def testModelEncoding(self, normalize_observations):
    """Trained params survive a pickle round trip and still drive a policy."""
    env = envs.get_environment('fast')
    _, params, _ = es.train(
        env,
        num_timesteps=128,
        episode_length=128,
        normalize_observations=normalize_observations,
    )

    # Rebuild the networks with the same preprocessing used in training.
    if normalize_observations:
      preprocess = running_statistics.normalize
    else:
      preprocess = lambda x, y: x
    rebuilt_networks = es_networks.make_es_networks(
        env.observation_size, env.action_size, preprocess
    )
    make_policy = es_networks.make_inference_fn(rebuilt_networks)

    restored_params = pickle.loads(pickle.dumps(params))

    # Compute one action with the restored params and step the env with it.
    first_state = env.reset(jax.random.PRNGKey(0))
    policy = make_policy(restored_params)
    action = policy(first_state.obs, jax.random.PRNGKey(0))[0]
    env.step(first_state, action)

  def testTrainDomainRandomize(self):
    """Test with domain randomization."""

    def rand_fn(sys, rng):
      def sample_pos(key):
        offset = jax.random.uniform(key, shape=(3,), minval=-0.1, maxval=0.1)
        return sys.link.transform.pos.at[0].set(offset)

      # One randomized position array per key in `rng`.
      randomized_pos = jax.vmap(sample_pos)(rng)
      sys_v = sys.tree_replace({'link.inertia.transform.pos': randomized_pos})
      # Only the randomized field is mapped over axis 0; everything else is
      # marked None.
      in_axes = jax.tree.map(lambda x: None, sys).tree_replace(
          {'link.inertia.transform.pos': 0}
      )
      return sys_v, in_axes

    pendulum = envs.get_environment('inverted_pendulum', backend='spring')
    _, _, _ = es.train(
        pendulum,
        num_timesteps=1280,
        episode_length=128,
        randomization_fn=rand_fn,
    )


if __name__ == '__main__':
  jax.config.update('jax_threefry_partitionable', False)
  absltest.main()
# File name under which the PPO network config is stored with a checkpoint.
_CONFIG_FNAME = 'ppo_network_config.json'


def save(
    path: Union[str, epath.Path],
    step: int,
    params: Any,
    config: config_dict.ConfigDict,
):
  """Saves a checkpoint, storing the network config as `_CONFIG_FNAME`."""
  return checkpoint.save(path, step, params, config, _CONFIG_FNAME)


def load(
    path: Union[str, epath.Path],
):
  """Loads the parameters of a PPO checkpoint from `path`."""
  return checkpoint.load(path)


def network_config(
    observation_size: types.ObservationSize,
    action_size: int,
    normalize_observations: bool,
    network_factory: types.NetworkFactory[ppo_networks.PPONetworks],
) -> config_dict.ConfigDict:
  """Builds the config dict needed to re-create a network from a checkpoint."""
  return checkpoint.network_config(
      observation_size, action_size, normalize_observations, network_factory
  )


def _get_ppo_network(
    config: config_dict.ConfigDict,
    network_factory: types.NetworkFactory[ppo_networks.PPONetworks],
) -> ppo_networks.PPONetworks:
  """Rebuilds a PPO network from `config` using `network_factory`."""
  return checkpoint.get_network(config, network_factory)  # pytype: disable=bad-return-type


def load_config(
    path: Union[str, epath.Path],
) -> config_dict.ConfigDict:
  """Loads the PPO config file `_CONFIG_FNAME` located directly under `path`."""
  return checkpoint.load_config(epath.Path(path) / _CONFIG_FNAME)


def load_policy(
    path: Union[str, epath.Path],
    network_factory: types.NetworkFactory[
        ppo_networks.PPONetworks
    ] = ppo_networks.make_ppo_networks,
    deterministic: bool = True,
):
  """Loads policy inference function from PPO checkpoint.

  Args:
    path: checkpoint directory holding both the parameters and the config.
    network_factory: factory used to rebuild the PPO networks.
    deterministic: forwarded to the inference-function factory.

  Returns:
    The policy built from the restored network and parameters.
  """
  ckpt_dir = epath.Path(path)
  restored_config = load_config(ckpt_dir)
  restored_params = load(ckpt_dir)
  ppo_network = _get_ppo_network(restored_config, network_factory)
  return ppo_networks.make_inference_fn(ppo_network)(
      restored_params, deterministic=deterministic
  )
# Loose type aliases used by the vision network builders.
ModuleDef = Any
ActivationFn = Callable[[jp.ndarray], jp.ndarray]
Initializer = Callable[..., Any]


@flax.struct.dataclass
class PPONetworks:
  """Policy and value networks plus the policy's action distribution."""

  policy_network: networks.FeedForwardNetwork
  value_network: networks.FeedForwardNetwork
  parametric_action_distribution: distribution.ParametricDistribution


def make_ppo_networks_vision(
    observation_size: Mapping[str, Tuple[int, ...]],
    action_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
    policy_hidden_layer_sizes: Sequence[int] = (256, 256),
    value_hidden_layer_sizes: Sequence[int] = (256, 256),
    activation: ActivationFn = linen.swish,
    normalise_channels: bool = False,
    policy_obs_key: str = "",
    value_obs_key: str = "",
) -> PPONetworks:
  """Make Vision PPO networks with preprocessor.

  Args:
    observation_size: mapping from observation key to its shape.
    action_size: event size of the tanh-normal action distribution.
    preprocess_observations_fn: preprocessor shared by both networks.
    policy_hidden_layer_sizes: hidden layer widths of the policy network.
    value_hidden_layer_sizes: hidden layer widths of the value network.
    activation: activation function shared by both networks.
    normalise_channels: forwarded to both vision network builders.
    policy_obs_key: passed as `state_obs_key` to the policy network builder.
    value_obs_key: passed as `state_obs_key` to the value network builder.

  Returns:
    The policy network, value network and action distribution.
  """
  action_distribution = distribution.NormalTanhDistribution(
      event_size=action_size
  )

  # Settings that the policy and value builders receive identically.
  shared_kwargs = dict(
      observation_size=observation_size,
      preprocess_observations_fn=preprocess_observations_fn,
      activation=activation,
      normalise_channels=normalise_channels,
  )

  policy = networks.make_policy_network_vision(
      output_size=action_distribution.param_size,
      hidden_layer_sizes=policy_hidden_layer_sizes,
      state_obs_key=policy_obs_key,
      **shared_kwargs,
  )
  value = networks.make_value_network_vision(
      hidden_layer_sizes=value_hidden_layer_sizes,
      state_obs_key=value_obs_key,
      **shared_kwargs,
  )

  return PPONetworks(
      policy_network=policy,
      value_network=value,
      parametric_action_distribution=action_distribution,
  )
# File name under which the SAC network config is stored with a checkpoint.
_CONFIG_FNAME = 'sac_network_config.json'


def save(
    path: Union[str, epath.Path],
    step: int,
    params: Any,
    config: config_dict.ConfigDict,
):
  """Saves a checkpoint, storing the network config as `_CONFIG_FNAME`."""
  return checkpoint.save(path, step, params, config, _CONFIG_FNAME)


def load(
    path: Union[str, epath.Path],
):
  """Loads the parameters of a SAC checkpoint from `path`."""
  return checkpoint.load(path)


def network_config(
    observation_size: types.ObservationSize,
    action_size: int,
    normalize_observations: bool,
    network_factory: types.NetworkFactory[sac_networks.SACNetworks],
) -> config_dict.ConfigDict:
  """Builds the config dict needed to re-create a network from a checkpoint."""
  return checkpoint.network_config(
      observation_size, action_size, normalize_observations, network_factory
  )


def _get_network(
    config: config_dict.ConfigDict,
    network_factory: types.NetworkFactory[sac_networks.SACNetworks],
) -> sac_networks.SACNetworks:
  """Rebuilds a SAC network from `config` using `network_factory`."""
  return checkpoint.get_network(config, network_factory)  # pytype: disable=bad-return-type


def load_config(
    path: Union[str, epath.Path],
) -> config_dict.ConfigDict:
  """Loads the SAC config file `_CONFIG_FNAME` located directly under `path`."""
  return checkpoint.load_config(epath.Path(path) / _CONFIG_FNAME)


def load_policy(
    path: Union[str, epath.Path],
    network_factory: types.NetworkFactory[
        sac_networks.SACNetworks
    ] = sac_networks.make_sac_networks,
    deterministic: bool = True,
):
  """Loads policy inference function from SAC checkpoint.

  Args:
    path: checkpoint directory holding both the parameters and the config.
    network_factory: factory used to rebuild the SAC networks.
    deterministic: forwarded to the inference-function factory.

  Returns:
    The policy built from the restored network and parameters.
  """
  ckpt_dir = epath.Path(path)
  restored_config = load_config(ckpt_dir)
  restored_params = load(ckpt_dir)
  sac_network = _get_network(restored_config, network_factory)
  return sac_networks.make_inference_fn(sac_network)(
      restored_params, deterministic=deterministic
  )
@flax.struct.dataclass
class SACNetworks:
  """Policy and Q networks plus the policy's action distribution."""

  policy_network: networks.FeedForwardNetwork
  q_network: networks.FeedForwardNetwork
  parametric_action_distribution: distribution.ParametricDistribution


def make_inference_fn(sac_networks: SACNetworks):
  """Creates params and inference function for the SAC agent.

  Args:
    sac_networks: networks whose policy and action distribution are used.

  Returns:
    A `make_policy(params, deterministic=False)` factory. The policy it returns
    maps `(observations, key_sample)` to `(action, {})`.
  """
  action_distribution = sac_networks.parametric_action_distribution
  apply_policy = sac_networks.policy_network.apply

  def make_policy(
      params: types.PolicyParams, deterministic: bool = False
  ) -> types.Policy:

    def policy(
        observations: types.Observation, key_sample: PRNGKey
    ) -> Tuple[types.Action, types.Extra]:
      logits = apply_policy(*params, observations)
      # Deterministic policies take the distribution mode; otherwise sample.
      if not deterministic:
        return action_distribution.sample(logits, key_sample), {}
      return action_distribution.mode(logits), {}

    return policy

  return make_policy


def make_sac_networks(
    observation_size: int,
    action_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
    hidden_layer_sizes: Sequence[int] = (256, 256),
    activation: networks.ActivationFn = linen.relu,
    policy_network_layer_norm: bool = False,
    q_network_layer_norm: bool = False,
) -> SACNetworks:
  """Make SAC networks.

  Args:
    observation_size: observation size passed to both networks.
    action_size: event size of the action distribution and Q-network input.
    preprocess_observations_fn: observation preprocessor for both networks.
    hidden_layer_sizes: hidden layer widths shared by both networks.
    activation: activation function shared by both networks.
    policy_network_layer_norm: `layer_norm` flag for the policy network.
    q_network_layer_norm: `layer_norm` flag for the Q network.

  Returns:
    The policy network, Q network and action distribution.
  """
  action_distribution = distribution.NormalTanhDistribution(
      event_size=action_size
  )

  # Settings that the policy and Q builders receive identically.
  shared_kwargs = dict(
      preprocess_observations_fn=preprocess_observations_fn,
      hidden_layer_sizes=hidden_layer_sizes,
      activation=activation,
  )

  policy = networks.make_policy_network(
      action_distribution.param_size,
      observation_size,
      layer_norm=policy_network_layer_norm,
      **shared_kwargs,
  )
  critic = networks.make_q_network(
      observation_size,
      action_size,
      layer_norm=q_network_layer_norm,
      **shared_kwargs,
  )
  return SACNetworks(
      policy_network=policy,
      q_network=critic,
      parametric_action_distribution=action_distribution,
  )
def loss_and_pgrad(
    loss_fn: Callable[..., float],
    pmap_axis_name: Optional[str],
    has_aux: bool = False,
):
  """Wraps a loss function to return its value and gradient.

  Args:
    loss_fn: The loss function; it is differentiated with respect to its first
      positional argument.
    pmap_axis_name: If not None, gradients are averaged with `jax.lax.pmean`
      over this axis.
    has_aux: Whether the loss_fn has auxiliary data.

  Returns:
    A function with the same arguments as `loss_fn` that returns
    `(value, grad)`.
  """
  g = jax.value_and_grad(loss_fn, has_aux=has_aux)

  def h(*args, **kwargs):
    value, grad = g(*args, **kwargs)
    return value, jax.lax.pmean(grad, axis_name=pmap_axis_name)

  # Without an axis name there is nothing to synchronize across.
  return g if pmap_axis_name is None else h


def gradient_update_fn(
    loss_fn: Callable[..., float],
    optimizer: optax.GradientTransformation,
    pmap_axis_name: Optional[str],
    has_aux: bool = False,
):
  """Wrapper of the loss function that apply gradient updates.

  Args:
    loss_fn: The loss function.
    optimizer: The optimizer to apply gradients.
    pmap_axis_name: If relevant, the name of the pmap axis to synchronize
      gradients.
    has_aux: Whether the loss_fn has auxiliary data.

  Returns:
    A function that takes the same argument as the loss function plus the
    optimizer state. The output of this function is the loss, the new parameter,
    and the new optimizer state.
  """
  loss_and_pgrad_fn = loss_and_pgrad(
      loss_fn, pmap_axis_name=pmap_axis_name, has_aux=has_aux
  )

  def f(*args, optimizer_state):
    value, grads = loss_and_pgrad_fn(*args)
    # The first positional argument holds the parameters being optimized. Pass
    # it to the optimizer so transformations that need the current params
    # (e.g. weight decay) work; others ignore it.
    params_update, optimizer_state = optimizer.update(
        grads, optimizer_state, args[0]
    )
    params = optax.apply_updates(args[0], params_update)
    return value, params, optimizer_state

  return f
class EpisodeMetricsLogger:
  """Buffers metrics of finished episodes and periodically logs their means."""

  def __init__(
      self, buffer_size=100, steps_between_logging=1e5, progress_fn=None
  ):
    """Initializes the logger.

    Args:
      buffer_size: maximum number of recent values kept per metric.
      steps_between_logging: number of counted steps between two logs.
      progress_fn: optional callback invoked with `(num_steps, metrics)` on
        every log.
    """
    self._buffer_size = buffer_size
    self._steps_between_logging = steps_between_logging
    self._progress_fn = progress_fn
    self._num_steps = 0
    self._last_log_steps = 0
    self._log_count = 0
    # One bounded deque per metric name, created on first use.
    self._metrics_buffer = collections.defaultdict(
        lambda: collections.deque(maxlen=buffer_size)
    )

  def update_episode_metrics(self, metrics, dones):
    """Records metrics of finished episodes and logs when enough steps passed.

    Args:
      metrics: mapping from metric name to an array indexable by the `dones`
        mask.
      dones: array whose non-zero entries mark finished episodes; its total
        size is added to the step counter.
    """
    self._num_steps += np.prod(dones.shape)
    if jnp.sum(dones) > 0:
      finished = dones.astype(bool)
      for name, values in metrics.items():
        self._metrics_buffer[name].extend(values[finished].flatten().tolist())
    steps_since_log = self._num_steps - self._last_log_steps
    if steps_since_log >= self._steps_between_logging:
      self.log_metrics()
      self._last_log_steps = self._num_steps

  def log_metrics(self, pad=35):
    """Log metrics to console."""
    self._log_count += 1
    mean_metrics = {
        name: np.mean(values) for name, values in self._metrics_buffer.items()
    }
    # Header line followed by one right-aligned line per metric.
    pieces = [
        f"\n{'Steps':>{pad}} Env: {self._num_steps} Log: {self._log_count}\n"
    ]
    pieces.extend(
        f"{f'Episode {name}:':>{pad}}" f" {mean:.4f}\n"
        for name, mean in mean_metrics.items()
    )
    logging.info(''.join(pieces))
    if self._progress_fn is None:
      return
    self._progress_fn(
        int(self._num_steps),
        {f"episode/{name}": value for name, value in mean_metrics.items()},
    )
def _fingerprint(x: Any) -> jnp.ndarray:
  """Returns the sum over all entries of all leaves of the pytree `x`.

  The sum is only a cheap summary: two values that differ but have the same
  total produce the same fingerprint.
  """
  sums = jax.tree_util.tree_map(jnp.sum, x)
  return jax.tree_util.tree_reduce(lambda a, b: a + b, sums)


def is_replicated(x: Any, axis_name: str) -> jnp.ndarray:
  """Returns whether x is replicated.

  Should be called inside a function pmapped along `axis_name`. The check
  compares the smallest and the largest fingerprint of x along that axis.

  Args:
    x: Object to check replication.
    axis_name: pmap axis_name.

  Returns:
    Boolean array that is true when the fingerprint of x is the same along
    the whole axis.
  """
  fp = _fingerprint(x)
  lowest = jax.lax.pmin(fp, axis_name=axis_name)
  highest = jax.lax.pmax(fp, axis_name=axis_name)
  return lowest == highest


def assert_is_replicated(x: Any, debug: Any = None) -> None:
  """Raises an AssertionError unless x is replicated.

  Should be called from non-jitted code. x is passed to `jax.pmap`, which
  maps it over its leading axis.

  Args:
    x: Object to check replication.
    debug: Debug message in case of failure.

  Raises:
    AssertionError: If the fingerprint of x is not the same on every device;
      `debug` is used as the error message.
  """
  check = functools.partial(is_replicated, axis_name='i')
  replicated = jax.pmap(check, axis_name='i')(x)
  # Raise explicitly: an `assert` statement is removed when Python runs with
  # -O, which would silently drop the whole check.
  if not bool(replicated[0]):
    raise AssertionError(debug)
/**
 * Tracks the named scene object under the pointer and the one last clicked.
 *
 * Dispatches 'hoveron' / 'hoveroff' when the hovered object changes and
 * 'select' / 'deselect' when the clicked object changes. Every event carries
 * the affected object in its `object` property.
 */
class Selector extends THREE.EventDispatcher {
  /**
   * @param {!Object} viewer Supplies `scene`, `camera` and `domElement`.
   */
  constructor(viewer) {
    super();

    this.viewer = viewer;
    this.raycaster = new THREE.Raycaster();
    this.raycaster.layers.set(1);
    this.mousePos = new THREE.Vector2();
    this.selected = null;
    this.hovered = null;
    this.dragging = false;
    // Picking candidates: every top-level group except the one named 'world'.
    // The list is built once, from the scene as it is at construction time.
    this.selectable = viewer.scene.children.filter(
        o => o instanceof THREE.Group && o.name !== 'world');

    const domElement = this.viewer.domElement;
    domElement.addEventListener('pointermove', this.onPointerMove.bind(this));
    domElement.addEventListener('pointerdown', this.onPointerDown.bind(this));
    domElement.addEventListener('pointerup', this.onPointerUp.bind(this));
  }

  /**
   * Stores the position of `event` in `this.mousePos`, scaled so that the
   * viewer element spans -1..1 on both axes with y pointing up.
   */
  _updateMousePos(event) {
    const rect = this.viewer.domElement.getBoundingClientRect();
    this.mousePos.x = ((event.clientX - rect.left) / rect.width) * 2 - 1;
    this.mousePos.y = -((event.clientY - rect.top) / rect.height) * 2 + 1;
  }

  /**
   * Casts a ray through `this.mousePos` and returns the first hit object, or
   * its closest ancestor that has a name. Returns null when nothing is hit.
   */
  _pick() {
    this.raycaster.setFromCamera(this.mousePos, this.viewer.camera);
    const intersections =
        this.raycaster.intersectObjects(this.selectable, true);
    if (intersections.length === 0) return null;

    let object = intersections[0].object;
    while (object.parent && !object.name) {
      object = object.parent;
    }
    return object;
  }

  /** Updates the hovered object and the cursor for a pointermove event. */
  onPointerMove(event) {
    event.preventDefault();
    // Any movement after a pointerdown turns the gesture into a drag.
    this.dragging = true;

    this._updateMousePos(event);
    const object = this._pick();
    if (object === this.hovered) return;

    if (this.hovered) {
      this.dispatchEvent({type: 'hoveroff', object: this.hovered});
    }
    this.hovered = object;
    if (object) {
      this.dispatchEvent({type: 'hoveron', object: object});
      this.viewer.domElement.style.cursor = 'pointer';
    } else {
      this.viewer.domElement.style.cursor = 'auto';
    }
  }

  /** Starts a new gesture; it counts as a click until the pointer moves. */
  onPointerDown(event) {
    event.preventDefault();
    this.dragging = false;
  }

  /** Updates the selected object when a gesture ends without movement. */
  onPointerUp(event) {
    event.preventDefault();
    if (this.dragging) return;  // ignore drag events, only handle clicks

    // Pick at the position of this event: if no pointermove preceded it,
    // this.mousePos would still hold an older position.
    this._updateMousePos(event);
    const object = this._pick();
    if (object === this.selected) return;

    if (this.selected) {
      this.dispatchEvent({type: 'deselect', object: this.selected});
    }
    this.selected = object;
    if (object) {
      this.dispatchEvent({type: 'select', object: object});
    }
  }
}
Within, hyperparameters are sorted by environment, and then by performance.\ 2 | \ 3 | Hyperparameter ranges for PPO:\ 4 | \ 5 | total_env_steps: 10_000_000 , 500_000_000 (split between files with these names)\ 6 | eval_frequency: 20\ 7 | reward_scaling: 1, 5\ 8 | episode_length: 1000\ 9 | normalize_observations: True\ 10 | action_repeat: 1\ 11 | entropy_cost: 1e-3\ 12 | learning_rate: 3e-4\ 13 | discounting: 0.99, 0.997\ 14 | num_envs: 2048\ 15 | unroll_length: 1, 5, 20\ 16 | batch_size: 512,1024\ 17 | num_minibatches: 4, 8, 16, 32\ 18 | num_update_epochs: 2, 4, 8\ 19 | \ 20 | \ 21 | Hyperparameter ranges for SAC:\ 22 | \ 23 | env: 'halfcheetah'\ 24 | total_env_steps: 1048576 * 5\ 25 | eval_frequency: 131012\ 26 | reward_scaling: 5, 10, 30\ 27 | episode_length: 1000\ 28 | normalize_observations: True\ 29 | action_repeat: 1\ 30 | learning_rate: 3e-4, 6e-4\ 31 | discounting: 0.95, 0.99, 0.997\ 32 | num_envs: 64, 128, 256\ 33 | min_replay_size: 8192\ 34 | max_replay_size: 1048576\ 35 | batch_size: 128, 256, 512\ 36 | grad_updates_per_step: 0.125 / 2, 0.125, 0.125 * 2 37 | -------------------------------------------------------------------------------- /datasets/ppo_10_million_steps.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/datasets/ppo_10_million_steps.tar.gz -------------------------------------------------------------------------------- /datasets/ppo_500_million_steps.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/datasets/ppo_500_million_steps.tar.gz -------------------------------------------------------------------------------- /datasets/sac_5_million_steps.tar.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/datasets/sac_5_million_steps.tar.gz -------------------------------------------------------------------------------- /docs/contributing.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code Reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows [Google's Open Source Community 28 | Guidelines](https://opensource.google/conduct/). 
29 | -------------------------------------------------------------------------------- /docs/img/a1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/a1.gif -------------------------------------------------------------------------------- /docs/img/ant.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/ant.gif -------------------------------------------------------------------------------- /docs/img/ant_v2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/ant_v2.gif -------------------------------------------------------------------------------- /docs/img/brax_logo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/brax_logo.gif -------------------------------------------------------------------------------- /docs/img/braxlines/ant_diayn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/ant_diayn.png -------------------------------------------------------------------------------- /docs/img/braxlines/ant_diayn_skill1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/ant_diayn_skill1.gif -------------------------------------------------------------------------------- /docs/img/braxlines/ant_diayn_skill2.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/ant_diayn_skill2.gif -------------------------------------------------------------------------------- /docs/img/braxlines/ant_diayn_skill3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/ant_diayn_skill3.gif -------------------------------------------------------------------------------- /docs/img/braxlines/ant_diayn_skill4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/ant_diayn_skill4.gif -------------------------------------------------------------------------------- /docs/img/braxlines/ant_smm.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/ant_smm.gif -------------------------------------------------------------------------------- /docs/img/braxlines/ant_smm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/ant_smm.png -------------------------------------------------------------------------------- /docs/img/braxlines/humanoid_diayn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/humanoid_diayn.png -------------------------------------------------------------------------------- /docs/img/braxlines/humanoid_diayn_skill1.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/humanoid_diayn_skill1.gif -------------------------------------------------------------------------------- /docs/img/braxlines/humanoid_diayn_skill2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/humanoid_diayn_skill2.gif -------------------------------------------------------------------------------- /docs/img/braxlines/humanoid_diayn_skill3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/humanoid_diayn_skill3.gif -------------------------------------------------------------------------------- /docs/img/braxlines/humanoid_smm.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/humanoid_smm.gif -------------------------------------------------------------------------------- /docs/img/braxlines/humanoid_smm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/humanoid_smm.png -------------------------------------------------------------------------------- /docs/img/braxlines/sketches.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/sketches.png -------------------------------------------------------------------------------- /docs/img/composer/ant_chase.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/composer/ant_chase.gif -------------------------------------------------------------------------------- /docs/img/composer/ant_push.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/composer/ant_push.gif -------------------------------------------------------------------------------- /docs/img/composer/pro_ant1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/composer/pro_ant1.gif -------------------------------------------------------------------------------- /docs/img/composer/pro_ant2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/composer/pro_ant2.gif -------------------------------------------------------------------------------- /docs/img/fetch.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/fetch.gif -------------------------------------------------------------------------------- /docs/img/grasp.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/grasp.gif -------------------------------------------------------------------------------- /docs/img/halfcheetah.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/halfcheetah.gif -------------------------------------------------------------------------------- /docs/img/humanoid.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/humanoid.gif -------------------------------------------------------------------------------- /docs/img/humanoid_v2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/humanoid_v2.gif -------------------------------------------------------------------------------- /docs/img/ur5e.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/ur5e.gif -------------------------------------------------------------------------------- /docs/release-notes/next-release.md: -------------------------------------------------------------------------------- 1 | # Brax Release Notes 2 | 3 | * Make sure brax is compatible with MJX API change https://github.com/google-deepmind/mujoco/commit/6cfea719850f7dbc91b9d6fb36deafb3c5d04eaf. 4 | * Fix #595, patch for checkpointing with activations other than relu when no checkpoint path is specified. 5 | * Merge https://github.com/google/brax/pull/582. 6 | * Merge https://github.com/google/brax/pull/549. 7 | * Remove brax/v1. 8 | * Issue a warning in brax.io.mjcf. We point users to 9 | [MJX](https://github.com/google-deepmind/mujoco/tree/main/mjx) and 10 | [MuJoCo Playground](https://github.com/google-deepmind/mujoco_playground) instead. Brax training is still 11 | actively maintained however. 
12 | -------------------------------------------------------------------------------- /docs/release-notes/v0.0.11.md: -------------------------------------------------------------------------------- 1 | # Brax Version 0.0.11 Release Notes 2 | 3 | This version introduces a significant overhaul to the physics algorithms. We now support position based dynamics for resolving joint and collision constraints. See [this paper](https://matthias-research.github.io/pages/publications/PBDBodies.pdf) for details about PBD. 4 | 5 | The most noticeable difference to prior versions of Brax is that joints are now modeled as infinitely stiff, whereas before they were stiff damped spring systems. This new physics is now default, and all environments use PBD-based joints and collisions by default. 6 | 7 | If you would like to preserve the behavior used in previous versions of brax, you can either: 8 | 9 | 1. Version pin to `0.0.10` – the version right before this upgrade. While you will not get the latest and greatest improvements to Brax, you will have unambiguously consistent behavior. 10 | 11 | 2. Add `dynamics_mode: "legacy_spring"` to your brax configuration file. This causes brax to navigate the old codepath. 12 | 13 | 3. Supply `legacy_spring=True,` as a kwarg to env creation (without `s). This causes Brax to load the older config for all the environments currently defined in Brax (see the logic in the init functions of each env for details). 14 | 15 | Thank you for using Brax, and feel free to open an Issue if you have any questions! -------------------------------------------------------------------------------- /docs/release-notes/v0.0.12.md: -------------------------------------------------------------------------------- 1 | # Brax v0.0.12 Release Notes 2 | 3 | This release fixes a javascript bug that is preventing the viewer from rendering. 
4 | -------------------------------------------------------------------------------- /docs/release-notes/v0.0.13.md: -------------------------------------------------------------------------------- 1 | # Brax v0.0.13 Release Notes 2 | 3 | This release fixes a few bugs in the collision handling in PBD, and adds 4 | support for specifying collider visibility, color, and contact participation. 5 | -------------------------------------------------------------------------------- /docs/release-notes/v0.0.14.md: -------------------------------------------------------------------------------- 1 | # Brax v0.0.14 Release Notes 2 | 3 | This release includes a refactor of the training code to make it more modular and hackable, with each algorithm now as a separate submodule under `brax.training.agents`. 4 | 5 | This release also updates references to the deprecated jax.tree* functions to their new home in jax.tree_util, fixes a few bugs in physics/collision code, and adds an initial implementation of box-box collisions. 6 | -------------------------------------------------------------------------------- /docs/release-notes/v0.0.15.md: -------------------------------------------------------------------------------- 1 | # Brax v0.0.15 Release Notes 2 | 3 | This release includes a refactor of the training code to make it more modular and hackable, with each algorithm now as a separate submodule under `brax.training.agents`. 4 | 5 | This release also updates references to the deprecated jax.tree* functions to their new home in jax.tree_util, fixes a few bugs in physics/collision code, and adds an initial implementation of box-box collisions.
6 | -------------------------------------------------------------------------------- /docs/release-notes/v0.0.16.md: -------------------------------------------------------------------------------- 1 | # Brax v0.0.16 Release Notes 2 | 3 | This release adds a new module: `brax.experimental.tracing` that allows for domain randomization during training. This release also adds support for placing replay buffers on device using `pjit` which allows for more configurable parallelism across many devices. Finally this release includes a number of small bug fixes. 4 | 5 | This will be the final release before we release a preview of a significant API change, so users may want to pin to this version if API stability is important. -------------------------------------------------------------------------------- /docs/release-notes/v0.1.0.md: -------------------------------------------------------------------------------- 1 | # Brax v0.1.0 Release Notes 2 | 3 | This minor release adds a preview of a major overhaul to Brax's API and 4 | functionality. This overhaul (found in the `v2/` folder) will eventually become 5 | Brax's first stable (1.0) release. 6 | 7 | The new features of Brax v2 include: 8 | 9 | * Generalized physics backend. 10 | * Continued support for the Spring physics backends. PBD will soon follow. 11 | * Direct support for Mujoco XML format, and URDF by association. 12 | * Fully traceable System object. 13 | * Env API that better supports custom physics backends. 14 | * Open sourced visualizer server. -------------------------------------------------------------------------------- /docs/release-notes/v0.1.1.md: -------------------------------------------------------------------------------- 1 | # Brax v0.1.1 Release Notes 2 | 3 | This patch release includes: 4 | * Contact debugger added to the visualizer. 5 | * Refactor of how state is passed to the v2 visualizer. 6 | * Small clean up to v2/generalized. 7 | * Convex-convex collisions added to v2. 
8 | * URDF loader gets mass and revolute joint limits. -------------------------------------------------------------------------------- /docs/release-notes/v0.1.2.md: -------------------------------------------------------------------------------- 1 | # Brax v0.1.2 Release Notes 2 | 3 | This patch release includes: 4 | * Add positional physics pipeline. 5 | * Refactor spring physics pipeline. 6 | * Add more brax envs. 7 | * Add better documentation for brax v2 and update v2 notebooks. 8 | * Fixes to kinematics.py. 9 | * Add missing dependencies and files to MANIFEST.in 10 | * Add convex-capsule, convex-sphere, and convex-plane collisions. 11 | * Add rgba color to brax visualizer. -------------------------------------------------------------------------------- /docs/release-notes/v0.10.0.md: -------------------------------------------------------------------------------- 1 | # Brax v0.10.0 Release Notes 2 | 3 | This minor release makes several changes to the brax API, such that [MJX](https://mujoco.readthedocs.io/en/stable/mjx.html) data structures are the core data structures used in brax. This allows for more seamless model loading from `MuJoCo` XMLs, and allows for running `MJX` physics more seamlessly in brax. 4 | 5 | * Rebase brax `System` and `State` onto `mjx.Model` and `mjx.Data`. 6 | * Separate validation logic from the model loading logic in `brax.io.mjcf`. This allows users to load an [MJX](https://mujoco.readthedocs.io/en/stable/mjx.html) model in brax, without hitting validation errors for other physics backends like `positional` and `spring`. 7 | * Remove `System.geoms`, since `brax.System` inherits from `mjx.Model` and all geom information is available in `mjx.Model`. We also update the brax viewer to work with this new schema. 8 | * Delete the brax contact library and use the contact library from `MJX`. 9 | * Use the MuJoCo renderer instead of pytinyrenderer for `brax.io.image`. 
10 | -------------------------------------------------------------------------------- /docs/release-notes/v0.10.1.md: -------------------------------------------------------------------------------- 1 | # Brax v0.10.1 Release Notes 2 | 3 | * Fixes #460, #461, and an issue related to #353. 4 | * Removes barkour v0 joystick policy in favor of the [MJX tutorial](https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb) training barkour vb. 5 | * Fixes #466. 6 | -------------------------------------------------------------------------------- /docs/release-notes/v0.10.2.md: -------------------------------------------------------------------------------- 1 | # Brax v0.10.2 Release Notes 2 | 3 | - Fix bug in rendering cylinders and planes. 4 | - Fix issue with link offsets in `io.mjcf.load_model`. -------------------------------------------------------------------------------- /docs/release-notes/v0.10.3.md: -------------------------------------------------------------------------------- 1 | # Brax v0.10.3 Release Notes 2 | 3 | - Fix a bug in rendering capsules and cylinders with the wrong size. 4 | -------------------------------------------------------------------------------- /docs/release-notes/v0.10.4.md: -------------------------------------------------------------------------------- 1 | # Brax v0.10.4 Release Notes 2 | 3 | - Add support for compressing json embedded in HTML output for large models. 4 | - Remove legacy `dt` field in the brax System. We rely on MuJoCo's `opt.timestep` field instead. 5 | - Add `act` to `pipeline.init` functions. See https://mujoco.readthedocs.io/en/stable/computation/index.html#physics-state. 6 | - Updated basic APG algorithm #476, h/t @Andrew-Luo1. 
7 | -------------------------------------------------------------------------------- /docs/release-notes/v0.10.5.md: -------------------------------------------------------------------------------- 1 | # Brax v0.10.5 Release Notes 2 | 3 | * Modify `policy_params_fn` in the PPO implementation to take in the full model params. This can be used for checkpointing models. 4 | * Add `restore_checkpoint_path` in PPO implementation. 5 | -------------------------------------------------------------------------------- /docs/release-notes/v0.11.0.md: -------------------------------------------------------------------------------- 1 | # Brax v0.11.0 Release Notes 2 | 3 | * Remove contact debugging from the viewer. This is a breaking change compared to older versions of brax. Old saved files will not render using the new viewer. 4 | * Added `ctrl` as input to the `MJX pipeline.init`. 5 | * Fixes to #513, #512, and #504. 6 | -------------------------------------------------------------------------------- /docs/release-notes/v0.12.0.md: -------------------------------------------------------------------------------- 1 | # Brax v0.12.0 Release Notes 2 | 3 | * Add boolean `wrap_env` to all brax `train` functions, which optionally wraps the env for training, or uses the env as is. 4 | * Fix bug in PPO train to return loaded checkpoint when `num_timesteps` is 0. 5 | * Add `layer_norm` to `make_q_network` and set `layer_norm` to `True` in `make_sace_networks` Q Network. 6 | * Change PPO train function to return both value and policy network params, rather than just policy params. 7 | * Merge https://github.com/google/brax/pull/561, adds grad norm clipping to PPO. 8 | * Merge https://github.com/google/brax/issues/477, changes pusher vel damping. 9 | * Merge https://github.com/google/brax/pull/558, adds `mocap_pos` and `mocap_quat` to render function. 10 | * Merge https://github.com/google/brax/pull/559, allows for dictionary observations environment `State`. 
11 | * Merge https://github.com/google/brax/pull/562, which supports asymmetric actor-critic for PPO. 12 | * Merge https://github.com/google/brax/pull/560, allows PPO from vision. 13 | -------------------------------------------------------------------------------- /docs/release-notes/v0.12.1.md: -------------------------------------------------------------------------------- 1 | # Brax v0.12.1 Release Notes 2 | 3 | * Add `wrap_env_fn` to training API. This allows users to specify custom wrapping functions for their environments. 4 | * Remove `FrozenDict` in brax env/training API. 5 | -------------------------------------------------------------------------------- /docs/release-notes/v0.12.2.md: -------------------------------------------------------------------------------- 1 | # Brax v0.12.2 Release Notes 2 | 3 | * See v0.12.3. -------------------------------------------------------------------------------- /docs/release-notes/v0.12.3.md: -------------------------------------------------------------------------------- 1 | # Brax v0.12.3 Release Notes 2 | 3 | * Add training metrics to brax PPO, which allows users to avoid running evals during training while getting more frequent metric updates (à la RSL-RL). Set `num_evals=0` and `log_training_metrics=True`. 4 | * Add checkpointing directly to brax PPO, rather than relying on the `policy_params_fn` callback. 5 | * Fix bug in inverted pendulum (#574) where the position of the tip was being mis-calculated. 6 | * Add UInt64 to prevent overflow of training steps in brax training. Fixes #578.
7 | -------------------------------------------------------------------------------- /docs/release-notes/v0.9.0.md: -------------------------------------------------------------------------------- 1 | # Brax v0.9.0 Release Notes 2 | 3 | This patch release moves: 4 | * brax's older API to the `brax.v1` module 5 | * and the `brax.v2` module to `brax` -------------------------------------------------------------------------------- /docs/release-notes/v0.9.1.md: -------------------------------------------------------------------------------- 1 | # Brax v0.9.1 Release Notes 2 | 3 | This patch release includes: 4 | * Add support for positional actuators. 5 | * Add fluid viscosity + density via box model. 6 | * Adds cylinder collider (but only for wafer-thin cylinders) 7 | * Bring back dm_env and torch env wrappers. 8 | * Bring back image rendering via pytinyrenderer. 9 | -------------------------------------------------------------------------------- /docs/release-notes/v0.9.2.md: -------------------------------------------------------------------------------- 1 | # Brax v0.9.2 Release Notes 2 | 3 | This patch release: 4 | * Adds domain randomization module 5 | * Adds swimmer env 6 | * Adds a quadruped training example in `brax/experimental` that demonstrates sim2real transfer. -------------------------------------------------------------------------------- /docs/release-notes/v0.9.3.md: -------------------------------------------------------------------------------- 1 | # Brax v0.9.3 Release Notes 2 | 3 | This patch release: 4 | 5 | * Fixes compatibility issues with MuJoCo 3.0.0 6 | * Fixes bugs with gym and dm envs. 7 | -------------------------------------------------------------------------------- /docs/release-notes/v0.9.4.md: -------------------------------------------------------------------------------- 1 | # Brax v0.9.4 Release Notes 2 | 3 | * Fixes gradients for generalized by changing a jp.linalg.norm to safe_norm. 
4 | * Adds the [MJX](https://mujoco.readthedocs.io/en/stable/mjx.html) pipeline to Brax as well as an MjxEnv for RL training. 5 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "brax" 7 | version = "0.12.3" 8 | description = "A differentiable physics engine written in JAX." 9 | authors = [ 10 | { name = "Brax Authors", email = "no-reply@google.com" }, 11 | ] 12 | readme = { file = "README.md", content-type = "text/markdown" } 13 | requires-python = ">=3.10" 14 | license = { file = "LICENSE" } 15 | keywords = [ 16 | "JAX", 17 | "reinforcement learning", 18 | "rigidbody", 19 | "physics", 20 | ] 21 | classifiers = [ 22 | "Development Status :: 4 - Beta", 23 | "Intended Audience :: Developers", 24 | "Intended Audience :: Science/Research", 25 | "License :: OSI Approved :: Apache Software License", 26 | "Programming Language :: Python", 27 | "Programming Language :: Python :: 3.10", 28 | "Programming Language :: Python :: 3.11", 29 | "Programming Language :: Python :: 3.12", 30 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 31 | ] 32 | dependencies = [ 33 | "absl-py", 34 | "dataclasses; python_version < '3.7'", 35 | "etils", 36 | "flask", 37 | "flask_cors", 38 | "flax", 39 | "jax>=0.4.6", 40 | "jaxlib>=0.4.6", 41 | "jaxopt", 42 | "jinja2", 43 | "ml_collections", 44 | "mujoco", 45 | "mujoco-mjx", 46 | "numpy", 47 | "optax", 48 | "orbax-checkpoint", 49 | "Pillow", 50 | "scipy", 51 | "tensorboardX", 52 | "trimesh", 53 | "typing-extensions", 54 | ] 55 | 56 | [project.optional-dependencies] 57 | develop = [ 58 | "pytest", 59 | "transforms3d", 60 | "gym", 61 | "dm_env", 62 | ] 63 | 64 | [project.urls] 65 | Homepage = "http://github.com/google/brax" 66 | 67 | [tool.hatch.build.targets.wheel] 68 | packages = ["brax"] 
69 | exclude = [ 70 | "/datasets", 71 | "/docs", 72 | "/notebooks", 73 | "/tests", 74 | "brax/experimental/barkour/assets", 75 | "brax/experimental/barkour/data", 76 | ] 77 | 78 | [tool.hatch.build.targets.sdist] 79 | exclude = [ 80 | "/datasets", 81 | "/docs", 82 | "/notebooks", 83 | "brax/experimental/barkour/assets", 84 | "brax/experimental/barkour/data", 85 | ] 86 | 87 | [tool.isort] 88 | force_single_line = true 89 | force_sort_within_sections = true 90 | lexicographical = true 91 | single_line_exclusions = ["typing"] 92 | order_by_type = false 93 | group_by_package = true 94 | line_length = 120 95 | use_parentheses = true 96 | multi_line_output = 3 97 | skip_glob = ["**/*.ipynb"] 98 | 99 | [tool.pyink] 100 | line-length = 80 101 | unstable = true 102 | pyink-indentation = 2 103 | pyink-use-majority-quotes = true 104 | extend-exclude = '''( 105 | .ipynb$ 106 | )''' 107 | --------------------------------------------------------------------------------