├── .github
└── workflows
│ ├── ci.yml
│ └── pypi.yml
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── bin
└── learn
├── brax
├── __init__.py
├── actuator.py
├── actuator_test.py
├── base.py
├── base_test.py
├── com.py
├── com_test.py
├── contact.py
├── contact_test.py
├── envs
│ ├── __init__.py
│ ├── ant.py
│ ├── assets
│ │ ├── ant.xml
│ │ ├── half_cheetah.xml
│ │ ├── hopper.xml
│ │ ├── humanoid.xml
│ │ ├── humanoidstandup.xml
│ │ ├── inverted_double_pendulum.xml
│ │ ├── inverted_pendulum.xml
│ │ ├── pusher.xml
│ │ ├── reacher.xml
│ │ ├── swimmer.xml
│ │ └── walker2d.xml
│ ├── base.py
│ ├── env_test.py
│ ├── fast.py
│ ├── half_cheetah.py
│ ├── hopper.py
│ ├── humanoid.py
│ ├── humanoidstandup.py
│ ├── inverted_double_pendulum.py
│ ├── inverted_pendulum.py
│ ├── pusher.py
│ ├── reacher.py
│ ├── swimmer.py
│ ├── walker2d.py
│ └── wrappers
│ │ ├── __init__.py
│ │ ├── dm_env.py
│ │ ├── dm_env_test.py
│ │ ├── gym.py
│ │ ├── gym_test.py
│ │ ├── torch.py
│ │ ├── training.py
│ │ └── training_test.py
├── experimental
│ ├── __init__.py
│ └── barkour
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── assets
│ │ ├── joystick.gif
│ │ ├── joystick_real.gif
│ │ └── joystick_v0.gif
│ │ ├── data
│ │ ├── barkour_run_0.csv
│ │ ├── barkour_run_1.csv
│ │ └── barkour_run_2.csv
│ │ ├── score_barkour.py
│ │ └── tutorial.ipynb
├── fluid.py
├── fluid_test.py
├── generalized
│ ├── __init__.py
│ ├── base.py
│ ├── constraint.py
│ ├── constraint_test.py
│ ├── dynamics.py
│ ├── dynamics_test.py
│ ├── integrator.py
│ ├── mass.py
│ ├── mass_test.py
│ ├── perf_test.py
│ ├── pipeline.py
│ └── pipeline_test.py
├── io
│ ├── __init__.py
│ ├── html.py
│ ├── image.py
│ ├── json.py
│ ├── json_test.py
│ ├── metrics.py
│ ├── mjcf.py
│ ├── mjcf_test.py
│ ├── model.py
│ └── torch.py
├── kinematics.py
├── kinematics_test.py
├── math.py
├── math_test.py
├── mjx
│ ├── __init__.py
│ ├── base.py
│ ├── perf_test.py
│ ├── pipeline.py
│ └── pipeline_test.py
├── positional
│ ├── __init__.py
│ ├── base.py
│ ├── collisions.py
│ ├── integrator.py
│ ├── joints.py
│ ├── joints_test.py
│ ├── perf_test.py
│ ├── pipeline.py
│ └── pipeline_test.py
├── scan.py
├── scan_test.py
├── spring
│ ├── __init__.py
│ ├── base.py
│ ├── collisions.py
│ ├── integrator.py
│ ├── joints.py
│ ├── joints_test.py
│ ├── perf_test.py
│ ├── pipeline.py
│ └── pipeline_test.py
├── test_data
│ ├── capsule.xml
│ ├── colour_objects.xml
│ ├── convex_convex.xml
│ ├── double_pendulum.xml
│ ├── double_prismatic.xml
│ ├── fluid_box.xml
│ ├── fluid_box_offset_com.xml
│ ├── fluid_ellipsoid.xml
│ ├── fluid_sphere.xml
│ ├── fluid_two_spheres.xml
│ ├── fluid_wind.xml
│ ├── meshes
│ │ ├── cylinder.stl
│ │ ├── dodecahedron.stl
│ │ ├── pyramid.stl
│ │ └── tetrahedron.stl
│ ├── nonzero_joint_ref.xml
│ ├── prismaversal_2dof_joint.xml
│ ├── prismaversal_3dof_joint.xml
│ ├── single_pendulum.xml
│ ├── single_pendulum_motor.xml
│ ├── single_pendulum_position.xml
│ ├── single_pendulum_position_frclimit.xml
│ ├── single_pendulum_velocity.xml
│ ├── single_prismatic.xml
│ ├── single_spherical_pendulum.xml
│ ├── single_spherical_pendulum_motor.xml
│ ├── single_spherical_pendulum_position.xml
│ ├── single_universal_pendulum.xml
│ ├── solver_params_v2.xml
│ ├── triple_pendulum.xml
│ ├── triple_pendulum_motor.xml
│ ├── triple_prismatic.xml
│ ├── world_body_transform.xml
│ ├── world_fromto.xml
│ └── world_self_collision.xml
├── test_utils.py
├── training
│ ├── __init__.py
│ ├── acme
│ │ ├── __init__.py
│ │ ├── running_statistics.py
│ │ ├── specs.py
│ │ └── types.py
│ ├── acting.py
│ ├── agents
│ │ ├── __init__.py
│ │ ├── apg
│ │ │ ├── __init__.py
│ │ │ ├── networks.py
│ │ │ ├── train.py
│ │ │ └── train_test.py
│ │ ├── ars
│ │ │ ├── __init__.py
│ │ │ ├── networks.py
│ │ │ ├── train.py
│ │ │ └── train_test.py
│ │ ├── bc
│ │ │ ├── __init__.py
│ │ │ ├── checkpoint.py
│ │ │ ├── losses.py
│ │ │ ├── networks.py
│ │ │ ├── train.py
│ │ │ └── train_test.py
│ │ ├── es
│ │ │ ├── __init__.py
│ │ │ ├── networks.py
│ │ │ ├── train.py
│ │ │ └── train_test.py
│ │ ├── ppo
│ │ │ ├── __init__.py
│ │ │ ├── checkpoint.py
│ │ │ ├── checkpoint_test.py
│ │ │ ├── losses.py
│ │ │ ├── networks.py
│ │ │ ├── networks_vision.py
│ │ │ ├── train.py
│ │ │ └── train_test.py
│ │ └── sac
│ │ │ ├── __init__.py
│ │ │ ├── checkpoint.py
│ │ │ ├── checkpoint_test.py
│ │ │ ├── losses.py
│ │ │ ├── networks.py
│ │ │ ├── train.py
│ │ │ └── train_test.py
│ ├── checkpoint.py
│ ├── distribution.py
│ ├── gradients.py
│ ├── learner.py
│ ├── logger.py
│ ├── networks.py
│ ├── pmap.py
│ ├── replay_buffers.py
│ ├── replay_buffers_test.py
│ ├── spectral_norm.py
│ ├── types.py
│ └── types_test.py
└── visualizer
│ ├── favicon.ico
│ ├── index.html
│ ├── js
│ ├── animator.js
│ ├── selector.js
│ ├── system.js
│ └── viewer.js
│ └── visualizer.py
├── datasets
├── README.md
├── ppo_10_million_steps.tar.gz
├── ppo_500_million_steps.tar.gz
└── sac_5_million_steps.tar.gz
├── docs
├── code_of_conduct.md
├── contributing.md
├── img
│ ├── a1.gif
│ ├── ant.gif
│ ├── ant_v2.gif
│ ├── brax_logo.gif
│ ├── braxlines
│ │ ├── ant_diayn.png
│ │ ├── ant_diayn_skill1.gif
│ │ ├── ant_diayn_skill2.gif
│ │ ├── ant_diayn_skill3.gif
│ │ ├── ant_diayn_skill4.gif
│ │ ├── ant_smm.gif
│ │ ├── ant_smm.png
│ │ ├── humanoid_diayn.png
│ │ ├── humanoid_diayn_skill1.gif
│ │ ├── humanoid_diayn_skill2.gif
│ │ ├── humanoid_diayn_skill3.gif
│ │ ├── humanoid_smm.gif
│ │ ├── humanoid_smm.png
│ │ └── sketches.png
│ ├── composer
│ │ ├── ant_chase.gif
│ │ ├── ant_push.gif
│ │ ├── pro_ant1.gif
│ │ └── pro_ant2.gif
│ ├── fetch.gif
│ ├── grasp.gif
│ ├── halfcheetah.gif
│ ├── humanoid.gif
│ ├── humanoid_v2.gif
│ └── ur5e.gif
└── release-notes
│ ├── next-release.md
│ ├── v0.0.11.md
│ ├── v0.0.12.md
│ ├── v0.0.13.md
│ ├── v0.0.14.md
│ ├── v0.0.15.md
│ ├── v0.0.16.md
│ ├── v0.1.0.md
│ ├── v0.1.1.md
│ ├── v0.1.2.md
│ ├── v0.10.0.md
│ ├── v0.10.1.md
│ ├── v0.10.2.md
│ ├── v0.10.3.md
│ ├── v0.10.4.md
│ ├── v0.10.5.md
│ ├── v0.11.0.md
│ ├── v0.12.0.md
│ ├── v0.12.1.md
│ ├── v0.12.2.md
│ ├── v0.12.3.md
│ ├── v0.9.0.md
│ ├── v0.9.1.md
│ ├── v0.9.2.md
│ ├── v0.9.3.md
│ └── v0.9.4.md
├── notebooks
├── basics.ipynb
├── training.ipynb
└── training_torch.ipynb
└── pyproject.toml
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
name: CI

on:
  # Triggers the workflow on push or pull request events, but only for the main branch.
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: ['3.10', '3.11']
    steps:
      # checkout@v2 / setup-python@v1 target Node.js runtimes that GitHub
      # Actions has deprecated; v4 / v5 are the maintained major versions.
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      # NOTE(review): pypi.yml installs the `dev` extra while this job installs
      # `develop`; confirm which extra pyproject.toml actually defines.
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e ".[develop]"
      - name: Run tests
        run: pytest --ignore=brax/v1
29 |
--------------------------------------------------------------------------------
/.github/workflows/pypi.yml:
--------------------------------------------------------------------------------
# This workflow builds the Python package and uploads it to PyPI with Twine
# whenever a version tag (v*) is pushed; it can also be started manually.
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: pypi

on:
  push:
    tags: ["v*"]
  workflow_dispatch:

jobs:
  deploy:
    runs-on: ubuntu-latest
    permissions:
      contents: write  # Needed to create releases

    steps:
      # checkout@v2 / setup-python@v1 target Node.js runtimes that GitHub
      # Actions has deprecated; v4 / v5 are the maintained major versions.
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.x"
      # NOTE(review): the release notes are read from
      # docs/release-notes/<ref_name>.md, so a manual dispatch from a branch
      # (rather than from a v* tag) will not find a matching file — confirm.
      - name: Create Release
        id: create_release
        uses: actions/create-release@latest
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}  # This token is provided by Actions, you do not need to create your own token
        with:
          tag_name: ${{ github.ref }}
          release_name: ${{ github.ref }}
          body_path: docs/release-notes/${{ github.ref_name }}.md
          draft: false
          prerelease: false

      - name: Install dependencies
        run: |
          pip install uv
          uv pip install --system -e ".[dev]"
          uv pip install --system build twine
      - name: Build and publish
        env:
          TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
        run: |
          python -m build
          twine upload dist/*
47 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .ipynb_checkpoints
3 | node_modules
4 | /pip_test
5 | /_python_build
6 | *.pyc
7 | __pycache__
8 | *.swp
9 | .vscode/
10 | .idea/
11 | *.egg-info
12 |
13 | ####### UNITY STUFF
14 |
15 | */[Ll]ibrary/
16 | */[Tt]emp/
17 | */[Oo]bj/
18 | [Bb]uild/
19 | [Bb]uilds/
20 | */[Bb]uild/
21 | */[Bb]uilds/
22 | */[Ll]ogs/
23 | */[Aa]ssets/AssetStoreTools*
24 | */[Ii]mport/
25 | */[Aa]ssets/AssetBundles*
26 | */[Aa]ssets/StreamingAssets/AssetBundles*
27 |
28 | # Visual Studio cache directory
29 | .vs/
30 |
31 | # Gradle cache directory
32 | .gradle/
33 |
34 | # Autogenerated VS/MD/Consulo solution and project files
35 | ExportedObj/
36 | .consulo/
37 | *.csproj
38 | *.unityproj
39 | *.sln
40 | *.suo
41 | *.tmp
42 | *.user
43 | *.userprefs
44 | *.pidb
45 | *.booproj
46 | *.svd
47 | *.pdb
48 | *.opendb
49 | *.VC.db
50 |
51 | # Unity3D generated meta files
52 | *.pidb.meta
53 | *.pdb.meta
54 |
55 | # Unity3D generated file on crash reports
56 | sysinfo.txt
57 |
58 | # Builds
59 | *.apk
60 | *.unitypackage
61 |
62 | UnitySDK.log
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | prune docs
2 | prune notebooks
3 | prune datasets
4 |
5 | include brax/envs/assets/*.xml
6 | recursive-include brax/experimental/barkour *.csv *.stl *.xml
7 | recursive-include brax/test_data *.xml *.stl *.obj *.urdf
8 | recursive-include brax/visualizer *
9 |
--------------------------------------------------------------------------------
/bin/learn:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from absl import app
4 | from brax.training import learner
5 |
def _run_learner():
  """Hands control to absl's app.run with learner.main as the main function."""
  app.run(learner.main)


if __name__ == '__main__':
  _run_learner()
8 |
--------------------------------------------------------------------------------
/brax/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Import top-level classes and functions here for encapsulation/clarity."""
16 |
17 | __version__ = '0.12.3'
18 |
19 | from brax.base import Motion
20 | from brax.base import State
21 | from brax.base import System
22 | from brax.base import Transform
23 |
--------------------------------------------------------------------------------
/brax/actuator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint:disable=g-multiple-import
16 | """Functions for applying actuators to a physics pipeline."""
17 |
18 | from brax.base import System
19 | import jax
20 | from jax import numpy as jp
21 |
22 |
def to_tau(
    sys: System, act: jax.Array, q: jax.Array, qd: jax.Array
) -> jax.Array:
  """Maps actuator inputs to a vector of generalized joint forces.

  The control input is clipped to each actuator's ctrl_range and turned into a
  force ``gain * act + gear * (bias_q * q + bias_qd * qd)``. That force is
  clipped to force_range, scaled by gear, and added into the joint-velocity
  slots given by ``sys.actuator.qd_id``.

  Args:
    sys: system defining the kinematic tree and other properties
    act: (act_size,) actuator force input vector
    q: joint position vector
    qd: joint velocity vector

  Returns:
    tau: (qd_size,) vector of joint forces
  """
  if sys.act_size() == 0:
    # No actuators: every joint force is zero.
    return jp.zeros(sys.qd_size())

  actuator = sys.actuator
  q_act = q[actuator.q_id]
  qd_act = qd[actuator.qd_id]

  ctrl = jp.clip(act, actuator.ctrl_range[:, 0], actuator.ctrl_range[:, 1])
  # Gear also scales the bias term; see
  # https://github.com/deepmind/mujoco/discussions/754 for the rationale.
  bias = actuator.gear * (q_act * actuator.bias_q + qd_act * actuator.bias_qd)
  raw_force = actuator.gain * ctrl + bias
  clipped_force = jp.clip(
      raw_force, actuator.force_range[:, 0], actuator.force_range[:, 1]
  )
  joint_force = clipped_force * actuator.gear

  return jp.zeros(sys.qd_size()).at[actuator.qd_id].add(joint_force)
58 |
--------------------------------------------------------------------------------
/brax/base_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for brax.base."""
16 |
17 | from absl.testing import absltest
18 | from brax import test_utils
19 | import jax
20 | import numpy as np
21 |
22 |
class BaseTest(absltest.TestCase):

  def test_write_mass_array(self):
    """tree_replace writes a per-link array into link.inertia.mass."""
    sys = test_utils.load_fixture('ant.xml')
    rng = jax.random.PRNGKey(0)
    noise = jax.random.uniform(rng, (sys.link.inertia.mass.shape[0],))
    new_mass = sys.link.inertia.mass + noise

    sys_w = sys.tree_replace({'link.inertia.mass': new_mass})
    np.testing.assert_array_equal(sys_w.link.inertia.mass, new_mass)

  def test_write_mass_value(self):
    """tree_replace writes a scalar into link.inertia.mass."""
    sys = test_utils.load_fixture('ant.xml')
    sys_w = sys.tree_replace({'link.inertia.mass': 1.0})
    self.assertEqual(sys_w.link.inertia.mass, 1.0)

  def test_write_array(self):
    """tree_replace writes an array into the top-level elasticity field."""
    sys = test_utils.load_fixture('ant.xml')
    # Draw one sample per elasticity entry. Passing the shape positionally
    # would bind it to `low` and yield a 1-element sample, not this shape.
    # A local Generator also avoids mutating numpy's global RNG state.
    expected = np.random.default_rng(0).uniform(size=sys.elasticity.shape)
    sys_w = sys.tree_replace({'elasticity': expected})
    np.testing.assert_array_equal(sys_w.elasticity, expected)


if __name__ == '__main__':
  absltest.main()
50 |
--------------------------------------------------------------------------------
/brax/com.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Helper functions for physics calculations in maximal coordinates."""
16 |
17 | # pylint:disable=g-multiple-import
18 | from typing import Tuple
19 |
20 | from brax import math
21 | from brax.base import Motion, System, Transform
22 | import jax
23 | from jax import numpy as jp
24 |
25 |
def from_world(
    sys: System, x: Transform, xd: Motion
) -> Tuple[Transform, Motion]:
  """Converts link transform and motion from world frame to com frame.

  Args:
    sys: system whose ``link.inertia.transform.pos`` gives each com offset
    x: link transforms in world frame
    xd: link motions in world frame

  Returns:
    (x_i, xd_i): transforms and motions re-expressed at the link coms
  """
  com_offset = Transform.create(pos=sys.link.inertia.transform.pos)
  x_i = x.vmap().do(com_offset)
  # Shift each motion by the world-space displacement from link origin to com.
  shift = Transform.create(pos=x_i.pos - x.pos)
  xd_i = shift.vmap().do(xd)
  return x_i, xd_i
33 |
34 |
def to_world(
    sys: System, x_i: Transform, xd_i: Motion
) -> Tuple[Transform, Motion]:
  """Converts link transform and motion from com frame to world frame.

  Args:
    sys: system whose ``link.inertia.transform.pos`` gives each com offset
    x_i: link transforms expressed at the coms
    xd_i: link motions expressed at the coms

  Returns:
    (x, xd): transforms and motions re-expressed at the link origins
  """
  undo_offset = Transform.create(pos=-sys.link.inertia.transform.pos)
  x = x_i.vmap().do(undo_offset)
  # Shift each motion by the world-space displacement from com to link origin.
  shift = Transform.create(pos=x.pos - x_i.pos)
  xd = shift.vmap().do(xd_i)
  return x, xd
42 |
43 |
def inv_inertia(sys, x) -> jax.Array:
  """Gets the inverse inertia at the center of mass in world frame.

  For every link, the diagonal of the inertia tensor is raised to
  ``1 - sys.spring_inertia_scale`` and inverted. It is then rotated by the
  link rotation composed with the inertia-frame rotation.

  Args:
    sys: system providing ``link.inertia`` and ``spring_inertia_scale``
    x: link transforms; only ``x.rot`` is read

  Returns:
    one 3x3 inverse inertia matrix per link
  """
  rotate_rows = jax.vmap(math.rotate, in_axes=[0, None])

  def _per_link(inertia, link_rot):
    total_rot = math.quat_mul(link_rot, inertia.transform.rot)
    # Only the diagonal entries of the inertia tensor are used.
    diag = jp.diagonal(inertia.i) ** (1 - sys.spring_inertia_scale)
    inv_local = jp.diag(1 / diag)
    # Rotate the rows, transpose, and rotate the rows again: R @ M @ R.T.
    rotated = rotate_rows(inv_local, total_rot)
    return rotate_rows(rotated.T, total_rot)

  return jax.vmap(_per_link)(sys.link.inertia, x.rot)
57 |
--------------------------------------------------------------------------------
/brax/com_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for com."""
16 |
17 | # pylint:disable=g-multiple-import
18 | from absl.testing import absltest
19 | from brax import com
20 | from brax import kinematics
21 | from brax import math
22 | from brax import test_utils
23 | import jax
24 | from jax import numpy as jp
25 | import numpy as np
26 |
27 |
class ComTest(absltest.TestCase):

  def test_transform(self):
    """from_world/to_world round-trip for a link with a nonzero com offset."""
    sys = test_utils.load_fixture('capsule.xml')
    sys = sys.replace(
        link=sys.link.replace(
            inertia=sys.link.inertia.replace(
                transform=sys.link.inertia.transform.replace(
                    pos=jp.array([[0.0, 0.0, -0.1]])
                )
            )
        )
    )

    x, xd = kinematics.forward(sys, sys.init_q, jp.zeros(sys.qd_size()))
    x = x.replace(
        pos=jp.array([[1.0, -1.0, 0.3]]),
        rot=jp.array([[0.976442, 0.16639178, -0.13593051, 0.01994515]]),
    )
    xd = xd.replace(
        vel=jp.array([[5.0, 1.0, -1.0]]), ang=jp.array([[1.0, 2.0, -3.0]])
    )

    x_i, xd_i = com.from_world(sys, x, xd)
    # Positions and linear velocities move; rotations and angular velocities
    # are unchanged by the com shift.
    self.assertNotAlmostEqual(jp.abs(x_i.pos - x.pos).sum(), 0)
    np.testing.assert_array_almost_equal(x_i.rot, x.rot)
    self.assertNotAlmostEqual(jp.abs(xd_i.vel - xd.vel).sum(), 0)
    np.testing.assert_array_almost_equal(xd_i.ang, xd.ang)
    xp, xdp = com.to_world(sys, x_i, xd_i)
    np.testing.assert_array_almost_equal(x.pos, xp.pos)
    np.testing.assert_array_almost_equal(x.rot, xp.rot)
    np.testing.assert_array_almost_equal(xd.vel, xdp.vel)
    np.testing.assert_array_almost_equal(xd.ang, xdp.ang)

  def test_inv_inertia(self):
    """com.inv_inertia matches directly inverting the world-frame inertia."""
    sys = test_utils.load_fixture('capsule.xml')
    sys = sys.replace(
        link=sys.link.replace(
            transform=sys.link.transform.replace(
                rot=math.euler_to_quat(jp.array([0.0, 0.0, 45.0])).reshape(
                    1, -1
                )
            )
        )
    )
    x, _ = kinematics.forward(sys, sys.init_q, jp.zeros(sys.qd_size()))
    x = x.replace(
        rot=math.euler_to_quat(jp.array([45.0, 0.0, 0.0])).reshape(1, -1)
    )

    # get the expected inv inertia
    x_i = x.vmap().do(sys.link.inertia.transform)
    cinr = x_i.replace(pos=jp.zeros_like(x_i.pos)).vmap().do(sys.link.inertia)
    expected = jax.vmap(math.inv_3x3)(cinr.i)

    inv_i = com.inv_inertia(sys, x)
    # The third positional argument of assert_array_almost_equal is `decimal`
    # (a digit count). Passing 1e-6 there gave an absolute tolerance of ~1.5,
    # so use explicit float32-appropriate tolerances instead.
    np.testing.assert_allclose(inv_i, expected, rtol=1e-4, atol=1e-4)


if __name__ == '__main__':
  absltest.main()
89 |
--------------------------------------------------------------------------------
/brax/contact.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint:disable=g-multiple-import
16 | """Calculations for generating contacts."""
17 |
18 | from typing import Optional
19 | from brax import math
20 | from brax.base import Contact
21 | from brax.base import System
22 | from brax.base import Transform
23 | import jax
24 | from jax import numpy as jp
25 | from mujoco import mjx
26 |
27 |
def get(sys: System, x: Transform) -> Optional[Contact]:
  """Calculates contacts for the given link poses.

  Geom world poses are built from the link transforms and passed to
  ``mjx.collision``. Each resulting contact is tagged with the mean
  elasticity of its two geoms and the link indices of the bodies involved.

  Args:
    sys: system defining the kinematic tree and other properties
    x: link transforms in world frame

  Returns:
    Contact pytree, or None when ``mjx.make_data(sys).ncon`` is 0
  """
  data = mjx.make_data(sys)
  if data.ncon == 0:
    return None

  # Append one zero transform at the end. A geom whose body id is 0 (the
  # MuJoCo world body) maps to index -1 and therefore picks up this entry.
  x = x.concatenate(Transform.zero((1,)))
  link_of_geom = sys.geom_bodyid - 1
  body_pos = x.pos[link_of_geom]
  body_quat = x.rot[link_of_geom]

  def _geom_world_pose(pos1, quat1, pos2, quat2):
    world_pos = pos1 + math.rotate(pos2, quat1)
    world_mat = math.quat_to_3x3(math.quat_mul(quat1, quat2))
    return world_pos, world_mat

  geom_xpos, geom_xmat = jax.vmap(_geom_world_pose)(
      body_pos, body_quat, sys.geom_pos, sys.geom_quat
  )

  # pytype: disable=wrong-arg-types
  data = data.replace(geom_xpos=geom_xpos, geom_xmat=geom_xmat)
  data = mjx.collision(sys, data)
  # pytype: enable=wrong-arg-types

  contact = data.contact
  elasticity = (
      sys.elasticity[contact.geom1] + sys.elasticity[contact.geom2]
  ) * 0.5

  geom_body = jp.array(sys.geom_bodyid)
  link_idx = (geom_body[contact.geom1] - 1, geom_body[contact.geom2] - 1)

  return Contact(elasticity=elasticity, link_idx=link_idx, **contact.__dict__)
68 |
--------------------------------------------------------------------------------
/brax/envs/assets/inverted_double_pendulum.xml:
--------------------------------------------------------------------------------
1 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
--------------------------------------------------------------------------------
/brax/envs/assets/inverted_pendulum.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
--------------------------------------------------------------------------------
/brax/envs/assets/reacher.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
--------------------------------------------------------------------------------
/brax/envs/assets/swimmer.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
--------------------------------------------------------------------------------
/brax/envs/env_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for brax envs."""
16 |
17 | import functools
18 |
19 | from absl.testing import absltest
20 | from absl.testing import parameterized
21 | from brax import envs
22 | from brax import test_utils
23 | import jax
24 | from jax import numpy as jp
25 |
_EXPECTED_SPS = {'spring': {'ant': 1000, 'humanoid': 1000}}


class EnvTest(parameterized.TestCase):
  # One (backend, env_name, expected_sps) case per entry in _EXPECTED_SPS.
  params = [
      (backend, env_name, sps)
      for backend, per_env in _EXPECTED_SPS.items()
      for env_name, sps in per_env.items()
  ]

  @parameterized.parameters(params)
  def testSpeed(self, backend, env_name, expected_sps):
    """Benchmarks a batched env and checks its steps-per-second floor."""
    # Envs expected to reach at least 10k SPS run 1000-step episodes;
    # slower ones run 100-step episodes.
    episode_length = 1000 if expected_sps >= 10_000 else 100

    env = envs.create(
        env_name,
        backend=backend,
        episode_length=episode_length,
        auto_reset=True,
    )
    step_fn = functools.partial(env.step, action=jp.zeros(env.action_size))

    mean_sps = test_utils.benchmark(
        f'{backend}_{env_name}',
        env.reset,
        step_fn,
        batch_size=128,
        length=episode_length,
    )
    # Allow 1% slack below the expected throughput.
    self.assertGreater(mean_sps, expected_sps * 0.99)


if __name__ == '__main__':
  absltest.main()
64 |
--------------------------------------------------------------------------------
/brax/envs/wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/brax/envs/wrappers/dm_env_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests the dm env wrapper."""
16 |
17 | from absl.testing import absltest
18 | from brax import envs
19 | from brax.envs.wrappers import dm_env
20 | import numpy as np
21 |
22 |
class DmEnvTest(absltest.TestCase):

  def test_action_space(self):
    """Tests the action space of the DmEnvWrapper."""
    base_env = envs.create('pusher')
    spec = dm_env.DmEnvWrapper(base_env).action_spec()
    ctrl_range = base_env.sys.actuator.ctrl_range
    # Spec bounds must be the per-actuator control limits.
    np.testing.assert_array_equal(spec.minimum, ctrl_range[:, 0])
    np.testing.assert_array_equal(spec.maximum, ctrl_range[:, 1])


if __name__ == '__main__':
  absltest.main()
39 |
--------------------------------------------------------------------------------
/brax/envs/wrappers/gym_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests the gym wrapper."""
16 |
17 | from absl.testing import absltest
18 | from brax import envs
19 | from brax.envs.wrappers import gym
20 | from brax.envs.wrappers import training
21 | import numpy as np
22 |
23 |
class GymTest(absltest.TestCase):

  def test_action_space(self):
    """Tests the action space of the GymWrapper."""
    base_env = envs.create('pusher')
    space = gym.GymWrapper(base_env).action_space
    ctrl_range = base_env.sys.actuator.ctrl_range
    np.testing.assert_array_equal(space.low, ctrl_range[:, 0])
    np.testing.assert_array_equal(space.high, ctrl_range[:, 1])

  def test_vector_action_space(self):
    """Tests the action space of the VectorGymWrapper."""
    batch_size = 256
    base_env = envs.create('pusher')
    env = gym.VectorGymWrapper(
        training.VmapWrapper(base_env, batch_size=batch_size)
    )
    ctrl_range = base_env.sys.actuator.ctrl_range
    # Every row of the batched bounds repeats the single-env limits.
    np.testing.assert_array_equal(
        env.action_space.low, np.tile(ctrl_range[:, 0], [batch_size, 1])
    )
    np.testing.assert_array_equal(
        env.action_space.high, np.tile(ctrl_range[:, 1], [batch_size, 1])
    )


if __name__ == '__main__':
  absltest.main()
54 |
--------------------------------------------------------------------------------
/brax/envs/wrappers/torch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Wrapper around a Brax GymWrapper that converts outputs to PyTorch tensors.
16 |
17 | This conversion happens directly on-device, without moving values to the CPU.
18 | """
19 |
20 | from typing import Optional
21 |
22 | # NOTE: The following line will emit a warning and raise ImportError if `torch`
23 | # isn't available.
24 | from brax.io import torch
25 | import gym
26 |
27 |
class TorchWrapper(gym.Wrapper):
  """Wrapper that converts Jax tensors to PyTorch tensors."""

  def __init__(self, env: gym.Env, device: Optional[torch.Device] = None):
    """Wraps a gym Env so that it outputs PyTorch tensors.

    Args:
      env: the gym environment to wrap.
      device: target device for converted tensors, forwarded as ``device=`` to
        ``torch.jax_to_torch``; ``None`` is passed through unchanged.
    """
    super().__init__(env)
    self.device = device

  def reset(self, **kwargs):
    """Resets the wrapped env and returns its observation as a torch tensor.

    Keyword arguments (e.g. ``seed``) are forwarded to the wrapped env's
    ``reset`` rather than being rejected by this wrapper, matching
    ``gym.Wrapper.reset(**kwargs)``. Calling ``reset()`` with no arguments
    behaves exactly as before.

    Returns:
      the observation converted with ``torch.jax_to_torch``.
    """
    obs = super().reset(**kwargs)
    return torch.jax_to_torch(obs, device=self.device)

  def step(self, action):
    """Steps the wrapped env with a torch action.

    Args:
      action: action tensor, converted with ``torch.torch_to_jax`` before it is
        passed to the wrapped env.

    Returns:
      ``(obs, reward, done, info)``, each converted with ``torch.jax_to_torch``
      onto ``self.device``.
    """
    action = torch.torch_to_jax(action)
    obs, reward, done, info = super().step(action)
    obs = torch.jax_to_torch(obs, device=self.device)
    reward = torch.jax_to_torch(reward, device=self.device)
    done = torch.jax_to_torch(done, device=self.device)
    info = torch.jax_to_torch(info, device=self.device)
    return obs, reward, done, info
48 |
--------------------------------------------------------------------------------
/brax/envs/wrappers/training_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for training wrappers."""
16 |
17 | import functools
18 |
19 | from absl.testing import absltest
20 | from brax import envs
21 | from brax.envs.wrappers import training
22 | import jax
23 | import jax.numpy as jp
24 | import numpy as np
25 |
26 |
class TrainingTest(absltest.TestCase):
  """Tests for the training wrappers in ``brax.envs.wrappers.training``."""

  def test_domain_randomization_wrapper(self):
    """Checks that per-env system randomization produces per-env dynamics.

    Every env in the batch is reset with the same key, so any variation that
    appears after ``env.step`` must come from the randomized systems.
    """

    def rand(sys, rng):
      """Returns a batch of randomized systems and the matching vmap in_axes."""

      @jax.vmap
      def get_offset(rng):
        # one uniform random 3-vector in [-0.1, 0.1) per key in the batch
        offset = jax.random.uniform(rng, shape=(3,), minval=-0.1, maxval=0.1)
        # NOTE(review): this starts from `link.transform.pos`, but the result
        # is written into `link.inertia.transform.pos` below. Every link's
        # inertia offset is therefore replaced by its link-transform offset,
        # with link 0 set to `offset`. If only link 0's inertia offset was
        # meant to change, `sys.link.inertia.transform.pos` may be intended —
        # confirm.
        pos = sys.link.transform.pos.at[0].set(offset)
        return pos

      sys_v = sys.tree_replace({'link.inertia.transform.pos': get_offset(rng)})
      # only the randomized leaf is batched along axis 0; all other leaves
      # are shared (in_axes None)
      in_axes = jax.tree.map(lambda x: None, sys)
      in_axes = in_axes.tree_replace({'link.inertia.transform.pos': 0})
      return sys_v, in_axes

    env = envs.create('ant')
    rng = jax.random.PRNGKey(0)
    # 256 keys -> 256 randomized systems
    rng = jax.random.split(rng, 256)
    env = training.wrap(
        env,
        episode_length=200,
        randomization_fn=functools.partial(rand, rng=rng),
    )

    # set the same key across the batch for env.reset so that only the
    # randomization wrapper creates variability in the env.step
    key = jp.zeros((256, 2), dtype=jp.uint32)
    state = jax.jit(env.reset)(key)
    self.assertEqual(state.pipeline_state.q[:, 0].shape[0], 256)
    # identical keys -> identical reset states across the batch
    self.assertEqual(np.unique(state.pipeline_state.q[:, 0]).shape[0], 1)

    # test that the DomainRandomizationWrapper creates variability in env.step
    state = jax.jit(env.step)(state, jp.zeros((256, env.sys.act_size())))
    self.assertEqual(state.pipeline_state.q[:, 0].shape[0], 256)
    self.assertEqual(np.unique(state.pipeline_state.q[:, 0]).shape[0], 256)


if __name__ == '__main__':
  absltest.main()
66 |
--------------------------------------------------------------------------------
/brax/experimental/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/brax/experimental/barkour/README.md:
--------------------------------------------------------------------------------
1 | # Google Barkour
2 |
3 | ## Overview
4 |
5 | This directory contains a script for evaluating policies on the [Barkour](https://github.com/google-deepmind/mujoco_menagerie/tree/main/google_barkour_v0#google-barkour-v0-description-mjcf) environment benchmark. See `score_barkour.py` for details on how policies are evaluated, or see the publication on [arXiv](https://arxiv.org/abs/2305.14654).
6 |
7 | ## Training Joystick Policies
8 |
9 | For instructions on how to train a joystick policy on the [Barkour v0](https://github.com/google-deepmind/mujoco_menagerie/tree/main/google_barkour_v0) quadruped on GPU/TPU with Brax, see the [tutorial notebook](https://colab.research.google.com/github/google/brax/blob/main/brax/experimental/barkour/tutorial.ipynb). The environment is defined in `BarkourEnv`, and joystick policies train in 6 minutes on an A100 GPU.
10 |
11 |
12 |
13 |
14 |
15 |
16 | For [Barkour vB](https://github.com/google-deepmind/mujoco_menagerie/tree/main/google_barkour_vb) quadruped, see the [MJX tutorial notebook](https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb). These joystick policies successfully transfer onto the quadruped robot.
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | ## MJCF Instructions
25 |
26 | For robot assets, see [Barkour vB](https://github.com/google-deepmind/mujoco_menagerie/tree/main/google_barkour_vb) and [Barkour v0](https://github.com/google-deepmind/mujoco_menagerie/tree/main/google_barkour_v0) in the MuJoCo Menagerie repository. The Barkour environment is defined in [`scene_barkour.xml`](https://github.com/google-deepmind/mujoco_menagerie/blob/main/google_barkour_v0/scene_barkour.xml).
27 |
28 |
29 | ## Publications
30 |
31 | If you use this work in an academic context, please cite the following publication:
32 |
33 | @misc{caluwaerts2023barkour,
34 | title={Barkour: Benchmarking Animal-level Agility with Quadruped Robots},
35 | author={Ken Caluwaerts and Atil Iscen and J. Chase Kew and Wenhao Yu and Tingnan Zhang and Daniel Freeman and Kuang-Huei Lee and Lisa Lee and Stefano Saliceti and Vincent Zhuang and Nathan Batchelor and Steven Bohez and Federico Casarini and Jose Enrique Chen and Omar Cortes and Erwin Coumans and Adil Dostmohamed and Gabriel Dulac-Arnold and Alejandro Escontrela and Erik Frey and Roland Hafner and Deepali Jain and Bauyrjan Jyenis and Yuheng Kuang and Edward Lee and Linda Luu and Ofir Nachum and Ken Oslund and Jason Powell and Diego Reyes and Francesco Romano and Feresteh Sadeghi and Ron Sloat and Baruch Tabanpour and Daniel Zheng and Michael Neunert and Raia Hadsell and Nicolas Heess and Francesco Nori and Jeff Seto and Carolina Parada and Vikas Sindhwani and Vincent Vanhoucke and Jie Tan},
36 | year={2023},
37 | eprint={2305.14654},
38 | archivePrefix={arXiv},
39 | primaryClass={cs.RO}
40 | }
--------------------------------------------------------------------------------
/brax/experimental/barkour/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/brax/experimental/barkour/assets/joystick.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/brax/experimental/barkour/assets/joystick.gif
--------------------------------------------------------------------------------
/brax/experimental/barkour/assets/joystick_real.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/brax/experimental/barkour/assets/joystick_real.gif
--------------------------------------------------------------------------------
/brax/experimental/barkour/assets/joystick_v0.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/brax/experimental/barkour/assets/joystick_v0.gif
--------------------------------------------------------------------------------
/brax/fluid.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint:disable=g-multiple-import, g-importing-member
16 | """Functions for forces/torques through fluids."""
17 |
18 | from typing import Union
19 | from brax.base import Force, Motion, System, Transform
20 | import jax
21 | import jax.numpy as jp
22 |
23 |
def _box_viscosity(box: jax.Array, xd_i: Motion, viscosity: jax.Array) -> Force:
  """Linear viscous resistance for boxes moving through a fluid.

  Each box is treated as a sphere whose diameter is the mean of its three
  dimensions. The torque coefficient scales with diameter cubed and the force
  coefficient scales linearly with diameter.

  Args:
    box: per-link box dimensions; the last axis holds the three sizes.
    xd_i: per-link velocity at the com frame.
    viscosity: fluid viscosity.

  Returns:
    resisting force/torque per link, proportional to the velocity.
  """
  sphere_diam = jp.mean(box, axis=-1)
  torque_coef = -jp.pi * sphere_diam**3 * viscosity
  force_coef = -3.0 * jp.pi * sphere_diam * viscosity
  return Force(
      ang=torque_coef[:, None] * xd_i.ang,
      vel=force_coef[:, None] * xd_i.vel,
  )
33 |
34 |
def _box_density(box: jax.Array, xd_i: Motion, density: jax.Array) -> Force:
  """Quadratic drag on boxes moving through a dense fluid.

  Args:
    box: per-link box dimensions; each row holds the three sizes.
    xd_i: per-link velocity at the com frame.
    density: fluid density.

  Returns:
    per-link drag force and torque, each proportional to |v| * v.
  """

  def per_link(dims: jax.Array, motion: Motion) -> Force:
    sx, sy, sz = dims[0], dims[1], dims[2]
    # projected face area perpendicular to each axis
    face_area = jp.array([sy * sz, sx * sz, sx * sy])
    drag_vel = -0.5 * density * face_area * jp.abs(motion.vel) * motion.vel
    rot_coef = jp.array([
        sx * (sy**4 + sz**4),
        sy * (sx**4 + sz**4),
        sz * (sx**4 + sy**4),
    ])
    drag_ang = (
        -1.0 * density * rot_coef * jp.abs(motion.ang) * motion.ang / 64.0
    )
    return Force(vel=drag_vel, ang=drag_ang)

  return jax.vmap(per_link)(box, xd_i)
51 |
52 |
def force(
    sys: System,
    x: Transform,
    xd: Motion,
    mass: jax.Array,
    inertia: jax.Array,
    root_com: Union[jax.Array, None] = None,
) -> Force:
  """Returns force due to motion through a fluid.

  Each link is approximated by an equivalent box whose dimensions are derived
  from its diagonal inertia and mass. Viscous and density (drag) forces are
  computed for that box in the com frame and rotated back to world orientation.

  Args:
    sys: system providing ``link.inertia.transform``, ``viscosity`` and
      ``density``.
    x: link transforms.
    xd: link velocities.
    mass: per-link masses (indexed as ``mass[:, None]`` below).
    inertia: per-link 3x3 inertia matrices; only their diagonals are used.
    root_com: optional point used instead of ``x.pos`` as the reference when
      computing the com offset.

  Returns:
    per-link fluid force/torque in world orientation.
  """
  # get the velocity at the com position/orientation
  x_i = x.vmap().do(sys.link.inertia.transform)
  # TODO: remove root_com when xd is fixed for stacked joints
  offset = x_i.pos - x.pos if root_com is None else x_i.pos - root_com
  xd_i = x_i.replace(pos=offset).vmap().do(xd)

  # TODO: add ellipsoid fluid model from mujoco
  # TODO: consider adding wind from mj.opt.wind
  # Equivalent box dimensions from the principal moments (I0, I1, I2):
  #   box_i = sqrt(6 * max(I_j + I_k - I_i, eps) / mass)
  diag_inertia = jax.vmap(jp.diag)(inertia)
  i0, i1, i2 = diag_inertia[:, 0], diag_inertia[:, 1], diag_inertia[:, 2]
  box = jp.stack([i1 + i2 - i0, i0 + i2 - i1, i0 + i1 - i2], axis=-1)
  # jp.maximum instead of jp.clip(..., a_min=...): the `a_min` keyword of
  # jax.numpy.clip is deprecated in recent JAX releases.
  box = 6.0 * jp.maximum(box, 1e-12)
  box = jp.sqrt(box / mass[:, None])

  frc = _box_viscosity(box, xd_i, sys.viscosity)
  frc += _box_density(box, xd_i, sys.density)

  # rotate back to the world orientation
  frc = Transform.create(rot=x_i.rot).vmap().do(frc)

  return frc
83 |
--------------------------------------------------------------------------------
/brax/generalized/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/brax/generalized/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint:disable=g-multiple-import, g-importing-member
16 | """Base types for generalized pipeline."""
17 |
18 | from brax import base
19 | from brax.base import Inertia, Motion, Transform
20 | from flax import struct
21 | import jax
22 | from jax import numpy as jp
23 |
24 |
@struct.dataclass
class State(base.State):
  """Dynamic state that changes after every step.

  Attributes:
    root_com: (num_links, 3) center of mass position of link root kinematic
      tree
    cinr: (num_links,) inertia in com frame
    cd: (num_links,) link velocities in com frame
    cdof: (qd_size,) dofs in com frame
    cdofd: (qd_size,) cdof velocity
    mass_mx: (qd_size, qd_size) mass matrix
    mass_mx_inv: (qd_size, qd_size) inverse mass matrix
    contact: calculated contacts
    con_jac: constraint jacobian
    con_diag: constraint A diagonal
    con_aref: constraint reference acceleration
    qf_smooth: (qd_size,) smooth dynamics force
    qf_constraint: (qd_size,) force from constraints (collision etc)
    qdd: (qd_size,) joint acceleration vector
  """

  # position/velocity based terms are updated at the end of each step:
  root_com: jax.Array
  cinr: Inertia
  cd: Motion
  cdof: Motion
  cdofd: Motion
  mass_mx: jax.Array
  mass_mx_inv: jax.Array
  con_jac: jax.Array
  con_diag: jax.Array
  con_aref: jax.Array
  # acceleration based terms are calculated using terms from the previous step:
  qf_smooth: jax.Array
  qf_constraint: jax.Array
  qdd: jax.Array

  @classmethod
  def init(
      cls, q: jax.Array, qd: jax.Array, x: Transform, xd: Motion
  ) -> 'State':
    """Returns an initial State with zero-filled derived terms.

    Placeholders use the shapes documented on the class: per-link
    ``root_com``, and per-dof ``cdof``/``cdofd``. The state therefore has the
    same array shapes before and after those terms are computed.

    Args:
      q: (q_size,) joint angle vector
      qd: (qd_size,) joint velocity vector
      x: link transforms; ``x.pos.shape[0]`` gives the number of links
      xd: link velocities

    Returns:
      a State holding q/qd/x/xd with all other terms zero and no contact.
    """
    num_links = x.pos.shape[0]
    qd_size = qd.shape[0]
    return State(
        q=q,
        qd=qd,
        x=x,
        xd=xd,
        contact=None,
        root_com=jp.zeros((num_links, 3)),
        cinr=Inertia(
            Transform.zero((num_links,)),
            jp.zeros((num_links, 3, 3)),
            jp.zeros((num_links,)),
        ),
        cd=Motion.zero((num_links,)),
        cdof=Motion.zero((qd_size,)),
        cdofd=Motion.zero((qd_size,)),
        mass_mx=jp.zeros((qd_size, qd_size)),
        mass_mx_inv=jp.zeros((qd_size, qd_size)),
        # constraint shapes are not known here; scalar placeholders
        con_jac=jp.zeros(()),
        con_diag=jp.zeros(()),
        con_aref=jp.zeros(()),
        qf_smooth=jp.zeros_like(qd),
        qf_constraint=jp.zeros_like(qd),
        qdd=jp.zeros_like(qd),
    )
93 |
--------------------------------------------------------------------------------
/brax/generalized/dynamics_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint:disable=g-multiple-import
16 | """Tests for dynamics."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from brax import test_utils
21 | from brax.generalized import pipeline
22 | import jax
23 | from jax import numpy as jp
24 | import numpy as np
25 |
26 |
class DynamicsTest(parameterized.TestCase):
  """Compares generalized dynamics quantities against MuJoCo."""

  # Every parameter is a 1-tuple, consistent with the other generalized tests
  # (previously a mix of bare strings and 1-tuples).
  @parameterized.parameters(
      ('ant.xml',),
      ('triple_pendulum.xml',),
      ('humanoid.xml',),
      ('half_cheetah.xml',),
      ('swimmer.xml',),
  )
  def test_transform_com(self, xml_file):
    """Test dynamics transform com."""
    sys = test_utils.load_fixture(xml_file)
    for mj_prev, mj_next in test_utils.sample_mujoco_states(xml_file):
      state = jax.jit(pipeline.init)(sys, mj_prev.qpos, mj_prev.qvel)

      np.testing.assert_almost_equal(
          state.root_com[0], mj_next.subtree_com[0], 5
      )
      # MuJoCo packs cinert as [diag(3), off-diag(3), pos(3), mass(1)];
      # row 0 (the world body) is skipped.
      mj_cinr_i = np.zeros((state.cinr.i.shape[0], 3, 3))
      mj_cinr_i[:, [0, 1, 2], [0, 1, 2]] = mj_next.cinert[1:, 0:3]  # diagonal
      mj_cinr_i[:, [0, 0, 1], [1, 2, 2]] = mj_next.cinert[1:, 3:6]  # upper tri
      mj_cinr_i[:, [1, 2, 2], [0, 0, 1]] = mj_next.cinert[1:, 3:6]  # lower tri
      mj_cinr_pos = mj_next.cinert[1:, 6:9]

      np.testing.assert_almost_equal(state.cinr.i, mj_cinr_i, 5)
      np.testing.assert_almost_equal(state.cinr.transform.pos, mj_cinr_pos, 5)
      np.testing.assert_almost_equal(state.cinr.mass, mj_next.cinert[1:, 9], 6)
      np.testing.assert_almost_equal(state.cd.matrix(), mj_next.cvel[1:], 4)
      np.testing.assert_almost_equal(state.cdof.matrix(), mj_next.cdof, 6)
      np.testing.assert_almost_equal(state.cdofd.matrix(), mj_next.cdof_dot, 5)

  @parameterized.parameters(
      ('ant.xml',),
      ('triple_pendulum.xml',),
      ('humanoid.xml',),
      ('half_cheetah.xml',),
      ('swimmer.xml',),
  )
  def test_forward(self, xml_file):
    """Test dynamics forward."""
    sys = test_utils.load_fixture(xml_file)
    for mj_prev, mj_next in test_utils.sample_mujoco_states(xml_file):
      act = jp.zeros(sys.act_size())
      state = jax.jit(pipeline.init)(sys, mj_prev.qpos, mj_prev.qvel)
      state = jax.jit(pipeline.step)(sys, state, act)

      np.testing.assert_allclose(
          state.qf_smooth, mj_next.qfrc_smooth, rtol=1e-4, atol=1e-4
      )


if __name__ == '__main__':
  absltest.main()
80 |
--------------------------------------------------------------------------------
/brax/generalized/integrator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint:disable=g-multiple-import
16 | """Integrator functions."""
17 |
18 | from brax import math
19 | from brax import scan
20 | from brax.base import System
21 | from brax.generalized.base import State
22 | import jax
23 | from jax import numpy as jp
24 |
25 |
def _integrate_q_axis(sys: System, q: jax.Array, qd: jax.Array) -> jax.Array:
  """Explicit Euler position update for revolute/prismatic joints.

  Returns ``q + qd * sys.opt.timestep``.
  """
  dt = sys.opt.timestep
  return q + qd * dt
29 |
30 |
def _integrate_q_free(sys: System, q: jax.Array, qd: jax.Array) -> jax.Array:
  """Integrates next q for free joints.

  Args:
    sys: system whose ``opt.timestep`` is the integration step.
    q: free-joint coordinates; ``q[0:3]`` is position and ``q[3:7]`` the
      rotation quaternion.
    qd: free-joint velocities; ``qd[0:3]`` is linear and ``qd[3:6]`` angular.

  Returns:
    ``concatenate([pos, rot])``: the advanced position and the renormalized
    rotation, which is the old rotation multiplied on the right by the
    incremental rotation.
  """
  dt = sys.opt.timestep
  rot, ang = q[3:7], qd[3:6]
  # jp.linalg.norm has a NaN gradient at ang == 0 (d sqrt / dx at 0 times 0).
  # Adding 1e-8 afterwards does not help, so guard the sqrt with the
  # "double where" pattern. Forward values equal jp.linalg.norm(ang) + 1e-8.
  ang_sq = jp.sum(ang * ang)
  nonzero = ang_sq > 0.0
  ang_norm = jp.where(nonzero, jp.sqrt(jp.where(nonzero, ang_sq, 1.0)), 0.0)
  ang_norm = ang_norm + 1e-8
  axis = ang / ang_norm
  angle = dt * ang_norm
  qrot = math.quat_rot_axis(axis, angle)
  rot = math.quat_mul(rot, qrot)
  rot = rot / jp.linalg.norm(rot)
  pos = q[0:3] + qd[0:3] * dt

  return jp.concatenate([pos, rot])
44 |
45 |
def integrate(sys: System, state: State) -> State:
  """Semi-implicit Euler integration.

  The joint velocity is advanced first, using the acceleration produced by the
  accumulated smooth and constraint forces. The joint positions are then
  advanced using that new velocity.

  Args:
    sys: system defining the kinematic tree and other properties
    state: generalized state

  Returns:
    state: state with q, qd, and qdd updated
  """
  dt = sys.opt.timestep
  qf = state.qf_smooth + state.qf_constraint
  if sys.matrix_inv_iterations == 0:
    # integrate joint damping implicitly to increase stability when we are not
    # using approximate inverse: solve (M + dt * diag(damping)) qdd = qf
    # directly instead of forming the full inverse and multiplying by it.
    mx = state.mass_mx + jp.diag(sys.dof.damping) * dt
    qdd = jax.scipy.linalg.solve(mx, qf, assume_a='pos')
  else:
    # use the (approximate) inverse mass matrix stored on the state
    qdd = state.mass_mx_inv @ qf
  qd = state.qd + qdd * dt

  def q_fn(typ, link, q, qd):
    # one row per link of this joint type
    num_rows = link.transform.pos.shape[0]
    q = q.reshape(num_rows, -1)
    qd = qd.reshape(num_rows, -1)
    fun = jax.vmap(
        {
            'f': _integrate_q_free,
            '1': _integrate_q_axis,
            '2': _integrate_q_axis,
            '3': _integrate_q_axis,
        }[typ],
        in_axes=(None, 0, 0),
    )
    return fun(sys, q, qd).reshape(-1)

  q = scan.link_types(sys, q_fn, 'lqd', 'q', sys.link, state.q, qd)

  return state.replace(q=q, qd=qd, qdd=qdd)
85 |
--------------------------------------------------------------------------------
/brax/generalized/mass_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint:disable=g-multiple-import
16 | """Tests for mass matrices."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from brax import test_utils
21 | from brax.generalized import pipeline
22 | import jax
23 | import mujoco
24 | import numpy as np
25 |
26 |
class MassTest(parameterized.TestCase):
  """Compares the generalized mass matrix with MuJoCo's dense mass matrix."""

  @parameterized.parameters(
      ('ant.xml',),
      ('triple_pendulum.xml',),
      ('humanoid.xml',),
      ('half_cheetah.xml',),
  )
  def test_matrix(self, xml_file):
    """Test mass matrix calculation."""
    sys = test_utils.load_fixture(xml_file)
    mj_model = test_utils.load_fixture_mujoco(xml_file)
    dof_count = sys.qd_size()
    # dense buffer that mj_fullM fills from MuJoCo's sparse qM each iteration
    expected = np.zeros((dof_count, dof_count))
    init_fn = jax.jit(pipeline.init)

    for mj_prev, mj_next in test_utils.sample_mujoco_states(xml_file):
      state = init_fn(sys, mj_prev.qpos, mj_prev.qvel)
      mujoco.mj_fullM(mj_model, expected, mj_next.qM)
      np.testing.assert_almost_equal(state.mass_mx, expected, 5)


if __name__ == '__main__':
  absltest.main()
49 |
--------------------------------------------------------------------------------
/brax/generalized/perf_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint:disable=g-multiple-import
16 | """Generalized perf tests."""
17 |
18 | from absl.testing import absltest
19 | from brax import test_utils
20 | from brax.generalized import pipeline
21 | import jax
22 | from jax import numpy as jp
23 |
24 |
class PerfTest(absltest.TestCase):
  """Benchmarks for the generalized pipeline."""

  def test_pipeline_ant(self):
    """Benchmarks init/step of the generalized pipeline on the ant fixture."""
    sys = test_utils.load_fixture('ant.xml')

    def init_fn(rng):
      q_rng, qd_rng = jax.random.split(rng, 2)
      # nominal pose: q[2] = 0.75, q[3] = 1.0, every other entry zero
      # (presumably a raised root with identity quaternion — confirm)
      nominal_q = jp.array([0, 0, 0.75, 1.0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
      q_noise = jax.random.uniform(
          q_rng, (sys.q_size(),), minval=-0.1, maxval=0.1
      )
      qd = 0.1 * jax.random.normal(qd_rng, (sys.qd_size(),))
      return pipeline.init(sys, nominal_q + q_noise, qd)

    def step_fn(state):
      zero_act = jp.zeros(sys.act_size())
      return pipeline.step(sys, state, zero_act)

    test_utils.benchmark('generalized pipeline ant', init_fn, step_fn)


if __name__ == '__main__':
  absltest.main()
45 |
--------------------------------------------------------------------------------
/brax/generalized/pipeline.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint:disable=g-multiple-import
16 | """Physics pipeline for generalized coordinates engine."""
17 |
18 | from typing import Optional
19 | from brax import actuator
20 | from brax import contact
21 | from brax import kinematics
22 | from brax.base import System
23 | from brax.generalized import constraint
24 | from brax.generalized import dynamics
25 | from brax.generalized import integrator
26 | from brax.generalized import mass
27 | from brax.generalized.base import State
28 | from brax.io import mjcf
29 | import jax
30 |
31 |
def init(
    sys: System,
    q: jax.Array,
    qd: jax.Array,
    unused_act: Optional[jax.Array] = None,
    unused_ctrl: Optional[jax.Array] = None,
    debug: bool = False,
) -> State:
  """Initializes physics state.

  Args:
    sys: a brax system
    q: (q_size,) joint angle vector
    qd: (qd_size,) joint velocity vector
    unused_act: ignored by this pipeline; presumably accepted so the signature
      matches other pipelines' ``init`` — TODO confirm
    unused_ctrl: ignored by this pipeline (same reason as ``unused_act``)
    debug: if True, adds contact to the state for debugging

  Returns:
    state: initial physics state
  """
  # validate the source MuJoCo model, when one is attached to the system
  if sys.mj_model is not None:
    mjcf.validate_model(sys.mj_model)
  # link transforms/velocities from the joint coordinates
  x, xd = kinematics.forward(sys, q, qd)
  state = State.init(q, qd, x, xd)  # pytype: disable=wrong-arg-types  # jax-ndarray
  state = dynamics.transform_com(sys, state)
  # NOTE(review): the third argument is 0 here, whereas `step` passes
  # sys.matrix_inv_iterations; presumably 0 selects an exact inverse — confirm
  state = mass.matrix_inv(sys, state, 0)
  state = constraint.jacobian(sys, state)
  if debug:
    state = state.replace(contact=contact.get(sys, state.x))

  return state
62 |
63 |
def step(
    sys: System, state: State, act: jax.Array, debug: bool = False
) -> State:
  """Performs a physics step.

  Args:
    sys: a brax system
    state: physics state prior to step
    act: (act_size,) actuator input vector
    debug: if True, adds contact to the state for debugging

  Returns:
    state: physics state after step
  """
  # calculate acceleration terms
  # actuator input -> generalized forces (tau), then smooth and constraint
  # forces, both of which are read by integrator.integrate below
  tau = actuator.to_tau(sys, act, state.q, state.qd)
  state = state.replace(qf_smooth=dynamics.forward(sys, state, tau))
  state = state.replace(qf_constraint=constraint.force(sys, state))

  # update position/velocity level terms
  # advance q/qd, then recompute kinematics and the derived quantities
  # (com terms, mass matrix inverse, constraint jacobian) for the new q/qd
  state = integrator.integrate(sys, state)
  x, xd = kinematics.forward(sys, state.q, state.qd)
  state = state.replace(x=x, xd=xd)
  state = dynamics.transform_com(sys, state)
  state = mass.matrix_inv(sys, state, sys.matrix_inv_iterations)
  state = constraint.jacobian(sys, state)

  if debug:
    state = state.replace(contact=contact.get(sys, state.x))

  return state
95 |
--------------------------------------------------------------------------------
/brax/generalized/pipeline_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint:disable=g-multiple-import
16 | """Tests for generalized pipeline."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from brax import test_utils
21 | from brax.generalized import pipeline
22 | from brax.io import mjcf
23 | import jax
24 | from jax import numpy as jp
25 | import numpy as np
26 |
27 |
class PipelineTest(parameterized.TestCase):
  """Compares one generalized-pipeline step against recorded MuJoCo states."""

  @parameterized.parameters(
      ('ant.xml',),
      ('triple_pendulum.xml',),
      ('humanoid.xml',),
      ('half_cheetah.xml',),
      ('swimmer.xml',),
  )
  def test_forward(self, xml_file):
    """Test pipeline step."""
    sys = test_utils.load_fixture(xml_file)
    # crank up solver iterations just to demonstrate close match to mujoco
    sys = sys.replace(solver_iterations=500)

    init_fn = jax.jit(pipeline.init)
    step_fn = jax.jit(pipeline.step)
    zero_act = jp.zeros(sys.act_size())

    for before, after in test_utils.sample_mujoco_states(xml_file):
      state = init_fn(sys, before.qpos, before.qvel)
      state = step_fn(sys, state, zero_act)

      np.testing.assert_allclose(state.q, after.qpos, atol=0.002)
      np.testing.assert_allclose(state.qd, after.qvel, atol=0.5)
48 |
49 |
class GradientTest(absltest.TestCase):
  """Tests that gradients are not NaN."""

  def test_grad(self):
    """Tests that gradients are not NaN."""
    # NOTE(review): the MJCF literal below is blank in this copy of the file;
    # mjcf.loads needs a valid model here -- confirm against the original.
    xml = """









    """
    sys = mjcf.loads(xml)
    init_state = jax.jit(pipeline.init)(
        sys, sys.init_q, jp.zeros(sys.qd_size())
    )

    def fn(xd):
      # Put `xd` into the first generalized velocity, roll the pipeline
      # forward 10 steps with act=None, and return the resulting first
      # velocity component.
      qd = jp.zeros(sys.qd_size()).at[0].set(xd)
      state = init_state.replace(qd=qd)
      for _ in range(10):
        state = jax.jit(pipeline.step)(sys, state, None)
      return state.qd[0]

    # Differentiate the final velocity w.r.t. the initial one at xd=-1.
    grad = jax.grad(fn)(-1.0)
    self.assertFalse(np.isnan(grad))
80 |
81 |
if __name__ == '__main__':
  # Run this module's tests with absltest when executed as a script.
  absltest.main()
84 |
--------------------------------------------------------------------------------
/brax/io/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/brax/io/html.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint:disable=g-multiple-import
16 | """Exports a system config and trajectory as an html view."""
17 |
18 | import base64
19 | from typing import List, Optional, Union
20 | import zlib
21 |
22 | import brax
23 | from brax.base import State, System
24 | from brax.io import json
25 | from etils import epath
26 | import jinja2
27 |
28 |
def save(path: str, sys: System, states: List[State]):
  """Renders a trajectory with `render` and writes it as an HTML text file.

  Missing parent directories of `path` are created first.

  Args:
    path: destination file path (anything accepted by `epath.Path`)
    sys: brax System object
    states: list of system states to render
  """
  path = epath.Path(path)
  # exist_ok removes the race between checking for the directory and
  # creating it (another process may create it in between).
  path.parent.mkdir(parents=True, exist_ok=True)
  path.write_text(render(sys, states))
35 |
36 |
def render_from_json(
    sys: str, height: Union[int, str], colab: bool, base_url: Optional[str]
) -> str:
  """Returns an HTML string that visualizes the brax system json string.

  Args:
    sys: JSON string describing the system (and trajectory)
    height: height of the render window
    colab: whether to use css styles for colab
    base_url: if given, used verbatim as the viewer script URL; if None, the
      jsDelivr CDN copy of viewer.js matching `brax.__version__` is used

  Returns:
    the rendered HTML page
  """
  template_path = epath.resource_path('brax') / 'visualizer/index.html'
  page = jinja2.Template(template_path.read_text())

  if base_url is not None:
    js_url = base_url
  else:
    cdn_root = 'https://cdn.jsdelivr.net/gh/google/brax'
    js_url = f'{cdn_root}@v{brax.__version__}/brax/visualizer/js/viewer.js'

  # The template receives the system JSON zlib-compressed and base64-encoded.
  packed = zlib.compress(bytes(sys, 'utf-8'))
  encoded = base64.b64encode(packed).decode('ascii')
  return page.render(
      system_json_b64=encoded, height=height, js_url=js_url, colab=colab
  )
54 |
55 |
def render(
    sys: System,
    states: List[State],
    height: Union[int, str] = 480,
    colab: bool = True,
    base_url: Optional[str] = None,
) -> str:
  """Returns an HTML string for the brax system and trajectory.

  The system and states are serialized with `brax.io.json.dumps` and handed
  to `render_from_json`.

  Args:
    sys: brax System object
    states: list of system states to render
    height: the height of the render window
    colab: whether to use css styles for colab
    base_url: the base url for serving the visualizer files. By default, a CDN
      url is used

  Returns:
    string containing HTML for the brax visualizer
  """
  system_json = json.dumps(sys, states)
  return render_from_json(
      system_json, height=height, colab=colab, base_url=base_url
  )
77 |
--------------------------------------------------------------------------------
/brax/io/image.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Exports a system config and state as an image."""
16 |
17 | import io
18 | from typing import List, Optional, Sequence, Union
19 |
20 | import brax
21 | from brax import base
22 | import mujoco
23 | import numpy as np
24 | from PIL import Image
25 |
26 |
def render_array(
    sys: brax.System,
    trajectory: Union[List[base.State], base.State],
    height: int = 240,
    width: int = 320,
    camera: Optional[str] = None,
) -> Union[Sequence[np.ndarray], np.ndarray]:
  """Renders states to np.ndarray images using the MuJoCo renderer.

  Args:
    sys: brax System whose `mj_model` is rendered
    trajectory: a single state, or a list/tuple of states
    height: image height in pixels
    width: image width in pixels
    camera: camera name; when None (or empty), -1 is passed to
      `update_scene`

  Returns:
    a list with one image per state when `trajectory` is a list or tuple,
    otherwise a single image
  """
  renderer = mujoco.Renderer(sys.mj_model, height=height, width=width)
  camera = camera or -1

  def get_image(state: base.State):
    d = mujoco.MjData(sys.mj_model)
    d.qpos, d.qvel = state.q, state.qd
    # mocap bodies are only copied when the state type carries them
    if hasattr(state, 'mocap_pos') and hasattr(state, 'mocap_quat'):
      d.mocap_pos, d.mocap_quat = state.mocap_pos, state.mocap_quat
    mujoco.mj_forward(sys.mj_model, d)
    renderer.update_scene(d, camera=camera)
    return renderer.render()

  try:
    if isinstance(trajectory, (list, tuple)):
      return [get_image(s) for s in trajectory]
    return get_image(trajectory)
  finally:
    # Free the renderer's resources now rather than relying on garbage
    # collection to do it.
    renderer.close()
51 |
52 |
def render(
    sys: brax.System,
    trajectory: List[base.State],
    height: int = 240,
    width: int = 320,
    camera: Optional[str] = None,
    fmt: str = 'png',
) -> bytes:
  """Returns an encoded image of a brax System.

  A single state produces a still image; several states produce an animation
  with one frame per state and a frame duration of `sys.opt.timestep`.

  Args:
    sys: brax System to render
    trajectory: list (or other sequence) of states; a bare State is also
      accepted and treated as a one-element trajectory
    height: image height in pixels
    width: image width in pixels
    camera: camera name passed through to `render_array`
    fmt: PIL image format name, e.g. 'png' or 'gif'

  Returns:
    the encoded image bytes

  Raises:
    RuntimeError: if the trajectory is empty
  """
  # A bare State used to be iterated row-by-row as if it were a list of
  # frames; wrap it so it becomes a single frame.
  if isinstance(trajectory, base.State):
    trajectory = [trajectory]
  if not trajectory:
    raise RuntimeError('must have at least one state')

  # render_array only returns a list of frames for list/tuple input.
  frames = render_array(sys, list(trajectory), height, width, camera)
  images = [Image.fromarray(frame) for frame in frames]

  buf = io.BytesIO()
  if len(images) == 1:
    images[0].save(buf, format=fmt)
  else:
    images[0].save(
        buf,
        format=fmt,
        append_images=images[1:],
        save_all=True,
        duration=sys.opt.timestep * 1000,  # milliseconds per frame
        loop=0,
    )

  return buf.getvalue()
82 |
--------------------------------------------------------------------------------
/brax/io/json_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for json."""
16 |
17 | import json
18 |
19 | from absl.testing import absltest
20 | from brax import test_utils
21 | from brax.generalized import pipeline
22 | from brax.io import json as bjson
23 | import jax
24 | import jax.numpy as jp
25 |
26 |
class JsonTest(absltest.TestCase):
  """Checks the structure of `brax.io.json.dumps` output."""

  def test_dumps(self):
    sys = test_utils.load_fixture('convex_convex.xml')
    state = pipeline.init(sys, sys.init_q, jp.zeros(sys.qd_size()))
    parsed = json.loads(bjson.dumps(sys, [state]))
    geoms = parsed['geoms']

    self.assertIsInstance(geoms, dict)
    expected_names = ['box', 'dodecahedron', 'pyramid', 'tetrahedron', 'world']
    self.assertSequenceEqual(sorted(geoms.keys()), expected_names)
    self.assertLen(geoms['world'], 1)

    box = geoms['box'][0]
    for field in ('size', 'rgba', 'name', 'link_idx', 'pos', 'rot'):
      self.assertIn(field, box)

  def test_dumps_invalidstate_raises(self):
    sys = test_utils.load_fixture('convex_convex.xml')
    state = pipeline.init(sys, sys.init_q, jp.zeros(sys.qd_size()))
    # A batched (stacked) state is not a valid single trajectory frame.
    batched = jax.tree.map(lambda leaf: jp.stack([leaf, leaf]), state)
    with self.assertRaises(RuntimeError):
      bjson.dumps(sys, [batched])
51 |
52 |
if __name__ == '__main__':
  # Run this module's tests with absltest when executed as a script.
  absltest.main()
55 |
--------------------------------------------------------------------------------
/brax/io/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 |
17 | """General purpose metrics writer interface."""
18 |
19 | from absl import logging
20 |
21 | try:
22 | from tensorboardX import SummaryWriter # type: ignore
23 | finally:
24 | pass
25 |
26 |
27 |
class Writer:
  """General purpose metrics writer.

  Logs metrics through absl logging and records them with a tensorboardX
  `SummaryWriter`. Usable as a context manager; the underlying writer is
  closed on exit.
  """

  def __init__(self, logdir=''):
    self._writer = SummaryWriter(logdir=logdir)

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    # Exceptions from the with-body are not suppressed (returns None).
    self._writer.close()

  def write_hparams(self, hparams):
    """Writes global hparams (logged, then recorded with no metrics)."""
    logging.info('Hyperparameters: %s', hparams)
    self._writer.add_hparams(hparams, {})

  def write_scalars(self, step, scalars):
    """Writes scalar metrics to the log and to the summary writer.

    Args:
      step: training step, logged with `%d`
      scalars: mapping from metric name to value; Python floats are logged
        with six decimal places, other values with their default formatting
    """
    values = [
        f'{k}={v:.6f}' if isinstance(v, float) else f'{k}={v}'
        for k, v in sorted(scalars.items())
    ]
    logging.info('[%d] %s', step, ', '.join(values))
    for k, v in scalars.items():
      self._writer.add_scalar(k, v, step)
54 |
--------------------------------------------------------------------------------
/brax/io/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Loading/saving of inference functions."""
16 |
17 | import pickle
18 | from typing import Any
19 | from etils import epath
20 |
21 |
def load_params(path: str) -> Any:
  """Loads parameters previously written by `save_params`.

  WARNING: the file is read with `pickle`, which can execute arbitrary code
  while loading. Only load files from trusted sources.

  Args:
    path: file path (anything accepted by `epath.Path`)

  Returns:
    the unpickled parameters
  """
  with epath.Path(path).open('rb') as fin:
    buf = fin.read()
  return pickle.loads(buf)  # trusted input only, see docstring
26 |
27 |
def save_params(path: str, params: Any):
  """Saves parameters to `path` as a pickle (readable by `load_params`).

  Args:
    path: file path (anything accepted by `epath.Path`)
    params: any picklable object, e.g. a pytree of arrays
  """
  with epath.Path(path).open('wb') as fout:
    fout.write(pickle.dumps(params))
32 |
--------------------------------------------------------------------------------
/brax/io/torch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Functions to convert Jax Arrays into PyTorch Tensors and vice-versa."""
16 |
17 | from collections import abc
18 | import functools
19 | from typing import Any, Dict, Union
20 | import warnings
21 |
22 | import jax
23 | from jax import dlpack as jax_dlpack
24 |
25 | try:
26 | # pylint:disable=g-import-not-at-top
27 | import torch
28 | from torch.utils import dlpack as torch_dlpack
29 | except ImportError:
30 | warnings.warn(
31 | "brax.io.torch requires PyTorch. Please run `pip install torch` to use "
32 | "functions from this module."
33 | )
34 | raise
35 |
# Target device accepted by `jax_to_torch`: a device string or a torch.device.
Device = Union[str, torch.device]
37 |
38 |
@functools.singledispatch
def torch_to_jax(value: Any) -> Any:
  """Converts PyTorch tensors (and mappings of them) to JAX arrays.

  This is the fallback used for types with no registered handler. Such values
  are returned unchanged, so non-tensor entries (e.g. ints inside a dict)
  survive conversion instead of being replaced by None.

  Args:
    value: torch tensor, mapping of tensors, or any other value

  Returns:
    a JAX array for registered types; otherwise `value` itself
  """
  return value
50 |
51 |
@torch_to_jax.register(torch.Tensor)
def _tensor_to_jax(value: torch.Tensor) -> jax.Array:
  """Converts a PyTorch Tensor into a jax.Array by way of a DLPack capsule."""
  capsule = torch_dlpack.to_dlpack(value)
  return jax_dlpack.from_dlpack(capsule)
58 |
59 |
@torch_to_jax.register(abc.Mapping)
def _torch_dict_to_jax(
    value: Dict[str, Union[torch.Tensor, Any]],
) -> Dict[str, Union[jax.Array, Any]]:
  """Converts each entry with `torch_to_jax`, rebuilding the same mapping type.

  The mapping type is rebuilt via keyword arguments, so keys must be strings.
  """
  converted = {}
  for key, item in value.items():
    converted[key] = torch_to_jax(item)
  return type(value)(**converted)  # type: ignore
66 |
67 |
@functools.singledispatch
def jax_to_torch(value: Any, device: Union[Device, None] = None) -> Any:
  """Convert JAX values to PyTorch Tensors.

  This is the fallback used for types with no registered handler. Such values
  are returned unchanged (and `device` is ignored), so non-array entries
  survive conversion instead of being replaced by None.

  Args:
    value: jax array or pytree
    device: device to copy value to (or None to leave on same device)

  Returns:
    Torch tensor on device for registered types; otherwise `value` itself

  By default, the returned tensors are on the same device as the Jax inputs,
  but if `device` is passed, the tensors will be moved to that device.
  """
  del device  # only meaningful for the registered array handlers
  return value
83 |
84 |
@jax_to_torch.register(jax.Array)
def _jaxarray_to_tensor(
    value: jax.Array, device: Union[Device, None] = None
) -> torch.Tensor:
  """Converts a jax.Array into a PyTorch Tensor via DLPack.

  The array is cast to float32 first, whatever its original dtype. When
  `device` is truthy, the tensor is moved there with `Tensor.to`.
  """
  as_float32 = value.astype("float32")
  tensor = torch_dlpack.from_dlpack(as_float32)
  return tensor.to(device=device) if device else tensor
94 |
95 |
@jax_to_torch.register(abc.Mapping)
def _jax_dict_to_torch(
    value: Dict[str, Union[jax.Array, Any]], device: Union[Device, None] = None
) -> Dict[str, Union[torch.Tensor, Any]]:
  """Converts each entry with `jax_to_torch`, rebuilding the same mapping type.

  `device` is forwarded to every entry. The mapping type is rebuilt via
  keyword arguments, so keys must be strings.
  """
  converted = {}
  for key, item in value.items():
    converted[key] = jax_to_torch(item, device=device)
  return type(value)(**converted)  # type: ignore
104 |
--------------------------------------------------------------------------------
/brax/math_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for geometry."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | from brax import math
20 | import jax
21 | from jax import numpy as jp
22 | import numpy as np
23 |
24 |
25 | def _get_rand_norm(seed: int):
26 | np.random.seed(seed)
27 | theta = np.random.random(1) * 2 * np.pi
28 | a = (np.random.random(1) - 0.5) * 2.0
29 | phi = np.arccos(a)
30 | x = np.sin(phi) * np.cos(theta)
31 | y = np.sin(phi) * np.sin(theta)
32 | z = np.cos(phi)
33 | return jp.array([x, y, z]).squeeze()
34 |
35 |
class MathTest(absltest.TestCase):
  """Tests for matrix inversion and rotation helpers in brax.math."""

  def test_inv_approximate(self):
    # m @ m.T is positive semi-definite; the small diagonal term makes the
    # matrix positive definite and therefore invertible.
    m = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
    m = jp.eye(4) * 0.001 + m @ m.T

    expected = jp.linalg.inv(m)
    approx = math.inv_approximate(m, jp.zeros((4, 4)), num_iter=100)

    np.testing.assert_array_almost_equal(approx, expected)

  def test_from_to(self):
    e_x = jp.array([1.0, 0.0, 0.0])
    e_y = jp.array([0.0, 1.0, 0.0])
    diag = jp.array([-0.5, 0.5, 0.0])
    diag = diag / jp.linalg.norm(diag)

    # each (src, dst) pair: rotating src by from_to(src, dst) must give dst
    cases = (
        (e_x, e_x),
        (e_x, -e_x),
        (-e_x, e_x),
        (e_y, -e_y),
        (e_y, diag),
    )
    for src, dst in cases:
      rot = math.from_to(src, dst)
      np.testing.assert_array_almost_equal(dst, math.rotate(src, rot))
67 |
68 |
class OrthoganalsTest(parameterized.TestCase):
  """Tests the orthogonals function."""

  @parameterized.parameters(range(100))
  def test_orthogonals(self, i):
    a = _get_rand_norm(i)
    b, c = math.orthogonals(a)

    # all three vectors are unit length ...
    for vec in (a, b, c):
      np.testing.assert_almost_equal(jp.linalg.norm(vec), 1)
    # ... and mutually perpendicular
    for lhs, rhs in ((a, b), (b, c), (a, c)):
      self.assertAlmostEqual(np.abs(lhs.dot(rhs)), 0, 6)
82 |
83 |
if __name__ == '__main__':
  # Run this module's tests with absltest when executed as a script.
  absltest.main()
86 |
--------------------------------------------------------------------------------
/brax/mjx/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/brax/mjx/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint:disable=g-multiple-import
16 | """Brax adapter for MJX physics engine."""
17 |
18 | from brax import base
19 | from mujoco import mjx
20 |
21 |
class State(base.State, mjx.Data):
  """Dynamic state that changes after every pipeline step.

  Subclasses both brax's `base.State` and MJX's `mjx.Data`, so an instance is
  an instance of each. The base-class order fixes the MRO and must be kept.
  """
24 |
--------------------------------------------------------------------------------
/brax/mjx/perf_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint:disable=g-multiple-import
16 | """PBD perf tests."""
17 |
18 | from absl.testing import absltest
19 | from brax import test_utils
20 | from brax.mjx import pipeline
21 | import jax
22 | from jax import numpy as jp
23 |
24 |
class PerfTest(absltest.TestCase):
  """Benchmarks the MJX pipeline on the ant fixture."""

  def test_pipeline_ant(self):
    model = test_utils.load_fixture('ant.xml')

    def init_fn(rng):
      # small random pose and velocity around the origin
      q_key, qd_key = jax.random.split(rng, 2)
      q = jax.random.uniform(q_key, (model.nq,), minval=-0.1, maxval=0.1)
      qd = 0.1 * jax.random.normal(qd_key, (model.nv,))
      return pipeline.init(model, q, qd)

    def step_fn(data):
      zero_ctrl = jp.zeros(model.nu)
      return pipeline.step(model, data, zero_ctrl)

    test_utils.benchmark('mjx pipeline ant', init_fn, step_fn)
40 |
41 |
if __name__ == '__main__':
  # Run this module's tests with absltest when executed as a script.
  absltest.main()
44 |
--------------------------------------------------------------------------------
/brax/mjx/pipeline_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint:disable=g-multiple-import
16 | # pylint:disable=g-importing-member
17 | """Tests for spring physics pipeline."""
18 |
19 | from absl.testing import absltest
20 | from brax import test_utils
21 | from brax.base import Contact
22 | from brax.mjx import pipeline
23 | import jax
24 | from jax import numpy as jp
25 | import mujoco
26 | import numpy as np
27 |
28 |
class PipelineTest(absltest.TestCase):
  """Checks the MJX pipeline against plain MuJoCo."""

  def test_pendulum(self):
    sys = test_utils.load_fixture('double_pendulum.xml')
    state = pipeline.init(sys, sys.init_q, jp.zeros(sys.qd_size()))
    self.assertIsInstance(state.contact, Contact)

    step_fn = jax.jit(pipeline.step)
    act = jp.zeros(sys.act_size())
    for _ in range(20):
      state = step_fn(sys, state, act)

    # reference: advance plain MuJoCo the same 20 steps
    mj_model = test_utils.load_fixture_mujoco('double_pendulum.xml')
    mj_data = mujoco.MjData(mj_model)
    mujoco.mj_step(mj_model, mj_data, 20)

    np.testing.assert_almost_equal(mj_data.qpos, state.q, decimal=4)
    np.testing.assert_almost_equal(mj_data.qvel, state.qd, decimal=3)
    # xpos[0] is the world body, which brax does not include in state.x
    np.testing.assert_almost_equal(mj_data.xpos[1:], state.x.pos, decimal=4)

  def test_pipeline_init_with_ctrl(self):
    sys = test_utils.load_fixture('single_spherical_pendulum_position.xml')
    ctrl = jp.array([0.3, 0.5, 0.4])
    qd0 = jp.zeros(sys.qd_size())
    state = pipeline.init(sys, sys.init_q, qd0, ctrl=ctrl)
    np.testing.assert_array_almost_equal(state.ctrl, ctrl)
61 |
62 |
if __name__ == '__main__':
  # Run this module's tests with absltest when executed as a script.
  absltest.main()
65 |
--------------------------------------------------------------------------------
/brax/positional/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/brax/positional/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint:disable=g-multiple-import
16 | """Base types for positional pipeline."""
17 |
18 | from brax import base
19 | from brax.base import Motion, Transform
20 | from flax import struct
21 | import jax
22 | import jax.numpy as jp
23 |
24 |
@struct.dataclass
class State(base.State):
  """Dynamic state that changes after every step of the positional pipeline.

  Extends `base.State` with the fields below. As with any dataclass, the
  declaration order defines the field order.

  Attributes:
    x_i: link center of mass in world frame
    xd_i: link center of mass motion in world frame
    j: link position in joint frame
    jd: link motion in joint frame
    a_p: joint parent anchor in world frame
    a_c: joint child anchor in world frame
    mass: link mass
  """

  x_i: Transform
  xd_i: Motion
  j: Transform
  jd: Motion
  a_p: Transform
  a_c: Transform
  mass: jax.Array
46 |
--------------------------------------------------------------------------------
/brax/positional/integrator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Functions for integrating maximal coordinate dynamics."""
16 |
17 | # pylint:disable=g-multiple-import
18 | from typing import Tuple
19 |
20 | from brax import math
21 | from brax.base import Motion, System, Transform
22 | import jax
23 | from jax import numpy as jp
24 |
25 |
def integrate_xdv(sys: System, xd: Motion, xdv: Motion) -> Motion:
  """Damps a velocity over one timestep, then applies a delta-velocity.

  Args:
    sys: System supplying `vel_damping`, `ang_damping` and `opt.timestep`
    xd: velocity
    xdv: delta-velocity

  Returns:
    xd: exp(damping * timestep) * xd + xdv, with linear and angular parts
      damped by their own coefficients
  """
  dt = sys.opt.timestep
  damping = Motion(vel=sys.vel_damping, ang=sys.ang_damping)
  damped = jax.tree.map(lambda coef, v: jp.exp(coef * dt) * v, damping, xd)
  return damped + xdv
44 |
45 |
def integrate_xdd(
    sys: System,
    x: Transform,
    xd: Motion,
    xdd: Motion,
) -> Tuple[Transform, Motion]:
  """Advances position and velocity by one timestep given an acceleration.

  The velocity is updated and damped first. Positions are then advanced with
  the new velocity, and rotations are stepped with a first-order quaternion
  update followed by renormalization.

  Args:
    sys: System to forward propagate
    x: position
    xd: velocity
    xdd: acceleration

  Returns:
    x: updated position
    xd: updated velocity
  """
  dt = sys.opt.timestep

  xd = xd + xdd * dt
  damping = Motion(vel=sys.vel_damping, ang=sys.ang_damping)
  xd = jax.tree.map(lambda coef, v: jp.exp(coef * dt) * v, damping, xd)

  def advance(link_x, link_xd):
    new_pos = link_x.pos + link_xd.vel * dt
    # q <- normalize(q + (0.5 * dt * omega) * q), omega as a pure quaternion
    half_step = math.ang_to_quat(link_xd.ang) * 0.5 * dt
    new_rot, _ = math.normalize(
        link_x.rot + math.quat_mul(half_step, link_x.rot)
    )
    return Transform(pos=new_pos, rot=new_rot)

  return jax.vmap(advance)(x, xd), xd
79 |
80 |
def project_xd(sys: System, x: Transform, x_prev: Transform) -> Motion:
  """Performs the position based dynamics velocity projection step.

  Velocities are recovered by finite differences so that they agree with the
  spatial distance (linear part) and relative quaternion (angular part)
  between `x_prev` and `x`.

  Args:
    sys: The system definition
    x: The current transform
    x_prev: The transform at the previous step

  Returns:
    New state with velocity pinned to respect distance traveled since x_prev
  """
  dt = sys.opt.timestep

  def finite_difference(cur, prev):
    vel = (cur.pos - prev.pos) / dt
    dq = math.relative_quat(prev.rot, cur.rot)
    # q and -q describe the same rotation; flip the vector part when the
    # scalar part is negative so the recovered angular velocity is consistent.
    sign = jp.where(dq[0] >= 0.0, 1.0, -1.0)
    ang = sign * (2.0 * dq[1:] / dt)
    return Motion(vel=vel, ang=ang)

  return jax.vmap(finite_difference)(x, x_prev)
106 |
--------------------------------------------------------------------------------
/brax/positional/perf_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint:disable=g-multiple-import
16 | """PBD perf tests."""
17 |
18 | from absl.testing import absltest
19 | from brax import test_utils
20 | from brax.positional import pipeline
21 | import jax
22 | from jax import numpy as jp
23 |
24 |
class PerfTest(absltest.TestCase):
  """Throughput benchmark for the PBD (positional) pipeline."""

  def test_pipeline_ant(self):
    """Benchmarks init plus zero-action stepping of the ant fixture."""
    sys = test_utils.load_fixture('ant.xml')
    q_shape = (sys.q_size(),)
    qd_shape = (sys.qd_size(),)

    def init_fn(rng):
      # Small random perturbation of joint positions and velocities.
      keys = jax.random.split(rng, 2)
      q = jax.random.uniform(keys[0], q_shape, minval=-0.1, maxval=0.1)
      qd = 0.1 * jax.random.normal(keys[1], qd_shape)
      return pipeline.init(sys, q, qd)

    def step_fn(state):
      # Advance one step with every actuator input at zero.
      return pipeline.step(sys, state, jp.zeros(sys.act_size()))

    test_utils.benchmark('pbd pipeline ant', init_fn, step_fn)
40 |
41 |
if __name__ == '__main__':
  # Run the test cases above through absltest when executed as a script.
  absltest.main()
44 |
--------------------------------------------------------------------------------
/brax/spring/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/brax/spring/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint:disable=g-multiple-import
16 | """Base types for spring pipeline."""
17 |
18 | from brax import base
19 | from brax.base import Motion, Transform
20 | from flax import struct
21 | import jax
22 | from jax import numpy as jp
23 |
24 |
@struct.dataclass
class State(base.State):
  """Dynamic state that changes after every step.

  Extends `base.State` with the extra quantities tracked by the spring
  pipeline.

  Attributes:
    x_i: (num_links,) link center of mass position in world frame
    xd_i: (num_links,) link center of mass velocity in world frame
    j: link position in joint frame
    jd: link motion in joint frame
    a_p: joint parent anchor in world frame
    a_c: joint child anchor in world frame
    i_inv: link inverse inertia
    mass: link mass
  """

  # Center-of-mass pose and motion of each link, world frame.
  x_i: Transform
  xd_i: Motion
  # Link pose and motion expressed in the joint frame.
  j: Transform
  jd: Motion
  # Joint anchor transforms on the parent (a_p) and child (a_c) side.
  a_p: Transform
  a_c: Transform
  # Per-link inverse inertia and mass; exact array shapes are not visible
  # here -- confirm against the pipeline that builds this state.
  i_inv: jax.Array
  mass: jax.Array
48 |
--------------------------------------------------------------------------------
/brax/spring/integrator.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Functions for integrating maximal coordinate dynamics."""
16 |
17 | # pylint:disable=g-multiple-import
18 | from typing import Tuple
19 |
20 | from brax import math
21 | from brax.base import Motion, System, Transform
22 | import jax
23 | from jax import numpy as jp
24 |
25 |
def integrate(
    sys: System,
    x_i: Transform,
    xd_i: Motion,
    xdv_i: Motion,
) -> Tuple[Transform, Motion]:
  """Advances link center-of-mass poses and motions by one timestep.

  Each link's velocity is first damped by exp(damping * dt) and then has the
  delta-velocity added. The updated velocity drives the position update and a
  first-order quaternion update, which is renormalized to unit length.

  Args:
    sys: System to forward propagate
    x_i: link center of mass transform in world frame
    xd_i: link center of mass motion in world frame
    xdv_i: link center of mass delta-velocity in world frame

  Returns:
    x_i: updated link center of mass transform in world frame
    xd_i: updated link center of mass motion in world frame
  """
  dt = sys.opt.timestep
  # Damping factors are shared by every link (not mapped over).
  vel_scale = jp.exp(sys.vel_damping * dt)
  ang_scale = jp.exp(sys.ang_damping * dt)

  def _step_link(pose, motion, delta):
    damped = Motion(vel=vel_scale * motion.vel, ang=ang_scale * motion.ang)
    new_motion = damped + delta

    # q' = q + 0.5 * dt * (omega as quaternion) * q, then renormalize.
    spin = math.ang_to_quat(new_motion.ang) * 0.5 * dt
    raw_rot = pose.rot + math.quat_mul(spin, pose.rot)
    new_pose = Transform(
        pos=pose.pos + new_motion.vel * dt,
        rot=raw_rot / jp.linalg.norm(raw_rot),
    )
    return new_pose, new_motion

  return jax.vmap(_step_link)(x_i, xd_i, xdv_i)
63 |
--------------------------------------------------------------------------------
/brax/spring/joints_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for spring physics joints."""
16 |
17 | from absl.testing import absltest
18 | from absl.testing import parameterized
19 | from brax import test_utils
20 | from brax.spring import pipeline
21 | import jax
22 | from jax import numpy as jp
23 |
24 |
class JointTest(parameterized.TestCase):
  """Physical sanity checks for joints simulated by the spring pipeline."""

  @parameterized.parameters(
      (2.0, 0.125, 0.0625), (5.0, 0.125, 0.03125), (1.0, 0.0625, 0.1)
  )
  def test_pendulum_period(self, mass, radius, vel):
    """A small spherical mass swings for approximately one period."""
    sys = test_utils.load_fixture('single_pendulum.xml')

    # Compound pendulum: a solid sphere of `radius` at distance `arm` from the
    # pivot. Its inertia about the pivot comes from the parallel axis theorem.
    arm = 0.5
    g = 9.81
    i_sphere = 2.0 / 5.0 * mass * radius**2.0
    i_pivot = mass * arm**2.0 + i_sphere
    period = 2 * jp.pi * jp.sqrt(i_pivot / (mass * g * arm))

    steps = 1_000
    sys = sys.tree_replace({'opt.timestep': period / steps})
    # Stiff, undamped joint constraint plus the sphere's inertia and mass.
    link = sys.link.replace(
        constraint_limit_stiffness=jp.array([0.0]),
        constraint_stiffness=jp.array([10_000.0]),
        constraint_ang_damping=jp.array([0.0]),
        constraint_vel_damping=jp.array([0.0]),
        inertia=sys.link.inertia.replace(
            i=jp.array([i_sphere * jp.eye(3)]),
            mass=jp.array([mass]),
        ),
    )
    sys = sys.replace(link=link, ang_damping=0.0)

    # Start at q = -pi/2 with a small angular velocity so the small-angle
    # approximation behind `period` stays valid.
    state = pipeline.init(sys, jp.array([-jp.pi / 2.0]), jp.array([vel]))

    step = jax.jit(pipeline.step)
    no_act = jp.zeros(sys.act_size())
    for _ in range(steps):
      state = step(sys, state, no_act)

    # After one full period the pendulum should again move with its initial
    # angular velocity.
    self.assertAlmostEqual(state.xd.ang[0, 0], vel, 2)
67 |
68 |
if __name__ == '__main__':
  # Run the test cases above through absltest when executed as a script.
  absltest.main()
71 |
--------------------------------------------------------------------------------
/brax/spring/perf_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint:disable=g-multiple-import
16 | """Spring perf tests."""
17 |
18 | from absl.testing import absltest
19 | from brax import test_utils
20 | from brax.spring import pipeline
21 | import jax
22 | from jax import numpy as jp
23 |
24 |
class PerfTest(absltest.TestCase):
  """Throughput benchmark for the spring pipeline."""

  def test_pipeline_ant(self):
    """Benchmarks init plus zero-action stepping of the ant fixture."""
    sys = test_utils.load_fixture('ant.xml')
    q_shape = (sys.q_size(),)
    qd_shape = (sys.qd_size(),)

    def init_fn(rng):
      # Small random perturbation of joint positions and velocities.
      keys = jax.random.split(rng, 2)
      q = jax.random.uniform(keys[0], q_shape, minval=-0.1, maxval=0.1)
      qd = 0.1 * jax.random.normal(keys[1], qd_shape)
      return pipeline.init(sys, q, qd)

    def step_fn(state):
      # Advance one step with every actuator input at zero.
      return pipeline.step(sys, state, jp.zeros(sys.act_size()))

    test_utils.benchmark('spring pipeline ant', init_fn, step_fn)
40 |
41 |
if __name__ == '__main__':
  # Run the test cases above through absltest when executed as a script.
  absltest.main()
44 |
--------------------------------------------------------------------------------
/brax/test_data/capsule.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/brax/test_data/colour_objects.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/brax/test_data/convex_convex.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
--------------------------------------------------------------------------------
/brax/test_data/double_pendulum.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
--------------------------------------------------------------------------------
/brax/test_data/double_prismatic.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/brax/test_data/fluid_box.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/brax/test_data/fluid_box_offset_com.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
--------------------------------------------------------------------------------
/brax/test_data/fluid_ellipsoid.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/brax/test_data/fluid_sphere.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/brax/test_data/fluid_two_spheres.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
--------------------------------------------------------------------------------
/brax/test_data/fluid_wind.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/brax/test_data/meshes/cylinder.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/brax/test_data/meshes/cylinder.stl
--------------------------------------------------------------------------------
/brax/test_data/meshes/dodecahedron.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/brax/test_data/meshes/dodecahedron.stl
--------------------------------------------------------------------------------
/brax/test_data/meshes/pyramid.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/brax/test_data/meshes/pyramid.stl
--------------------------------------------------------------------------------
/brax/test_data/meshes/tetrahedron.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/brax/test_data/meshes/tetrahedron.stl
--------------------------------------------------------------------------------
/brax/test_data/nonzero_joint_ref.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
--------------------------------------------------------------------------------
/brax/test_data/prismaversal_2dof_joint.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/brax/test_data/prismaversal_3dof_joint.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/brax/test_data/single_pendulum.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/brax/test_data/single_pendulum_motor.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/brax/test_data/single_pendulum_position.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/brax/test_data/single_pendulum_position_frclimit.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/brax/test_data/single_pendulum_velocity.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/brax/test_data/single_prismatic.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
--------------------------------------------------------------------------------
/brax/test_data/single_spherical_pendulum.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/brax/test_data/single_spherical_pendulum_motor.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/brax/test_data/single_spherical_pendulum_position.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/brax/test_data/single_universal_pendulum.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/brax/test_data/solver_params_v2.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/brax/test_data/triple_pendulum.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/brax/test_data/triple_pendulum_motor.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
--------------------------------------------------------------------------------
/brax/test_data/triple_prismatic.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/brax/test_data/world_body_transform.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
--------------------------------------------------------------------------------
/brax/test_data/world_fromto.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/brax/test_data/world_self_collision.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
--------------------------------------------------------------------------------
/brax/training/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/brax/training/acme/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/brax/training/acme/specs.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Objects which specify the input/output spaces of an environment.
16 |
17 |
18 | This file was taken from acme and modified to simplify dependencies:
19 |
20 | https://github.com/deepmind/acme/blob/master/acme/specs.py
21 | """
22 | import dataclasses
23 | from typing import Tuple
24 |
25 | import jax.numpy as jnp
26 |
27 |
@dataclasses.dataclass(frozen=True)
class Array:
  """Describes a numpy array or scalar shape and dtype.

  Similar to dm_env.specs.Array. Instances are immutable (frozen dataclass),
  so a spec cannot be modified after construction.
  """
  # Shape of the described array; presumably `()` describes a scalar --
  # confirm against callers.
  shape: Tuple[int, ...]
  # Element dtype of the described array.
  dtype: jnp.dtype
36 |
--------------------------------------------------------------------------------
/brax/training/acme/types.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Common types used throughout Acme.
16 |
17 | This file was taken from acme and modified to simplify dependencies:
18 |
19 | https://github.com/deepmind/acme/blob/master/acme/types.py
20 | """
21 | from typing import Any, Iterable, Mapping, Union
22 |
23 | from brax.training.acme import specs
24 | import jax.numpy as jnp
25 |
# Define types for nested arrays and tensors.
# NOTE(review): in this dependency-trimmed copy, NestedArray is simply the flat
# jnp.ndarray type and NestedTensor is an unconstrained Any.
NestedArray = jnp.ndarray
NestedTensor = Any

# A spec is either an Array spec or an arbitrarily nested iterable/mapping of
# specs. The alias refers to itself via string forward references; the pytype
# suppression below is presumably because pytype cannot check such recursive
# aliases yet.
# pytype: disable=not-supported-yet
NestedSpec = Union[
    specs.Array,
    Iterable['NestedSpec'],
    Mapping[Any, 'NestedSpec'],
]
# pytype: enable=not-supported-yet

# Union of all the nested structures defined above.
Nest = Union[NestedArray, NestedTensor, NestedSpec]
39 |
--------------------------------------------------------------------------------
/brax/training/agents/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/brax/training/agents/apg/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/brax/training/agents/apg/networks.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """APG networks."""
16 |
17 | from typing import Sequence, Tuple
18 |
19 | from brax.training import distribution
20 | from brax.training import networks
21 | from brax.training import types
22 | from brax.training.types import PRNGKey
23 | import flax
24 | from flax import linen
25 |
26 |
@flax.struct.dataclass
class APGNetworks:
  """Networks used by the APG agent."""
  # Feed-forward policy network; its outputs are the parameters ("logits")
  # consumed by `parametric_action_distribution`.
  policy_network: networks.FeedForwardNetwork
  # Turns policy-network outputs into actions (via `mode` or `sample`).
  parametric_action_distribution: distribution.ParametricDistribution
31 |
32 |
def make_inference_fn(apg_networks: APGNetworks):
  """Builds a factory that turns APG parameters into a policy function.

  Args:
    apg_networks: the agent's policy network and action distribution.

  Returns:
    `make_policy(params, deterministic=False)`. It returns a policy mapping
    `(observations, key_sample)` to `(actions, {})`. With `deterministic=True`
    the distribution mode is used and `key_sample` is ignored. Otherwise an
    action is sampled with `key_sample`.
  """
  action_dist = apg_networks.parametric_action_distribution
  policy_net = apg_networks.policy_network

  def make_policy(
      params: types.PolicyParams, deterministic: bool = False
  ) -> types.Policy:

    def policy(
        observations: types.Observation, key_sample: PRNGKey
    ) -> Tuple[types.Action, types.Extra]:
      # `params` is splatted into the network's apply function.
      logits = policy_net.apply(*params, observations)
      if deterministic:
        action = action_dist.mode(logits)
      else:
        action = action_dist.sample(logits, key_sample)
      return action, {}

    return policy

  return make_policy
56 |
57 |
def make_apg_networks(
    observation_size: int,
    action_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
    hidden_layer_sizes: Sequence[int] = (32,) * 4,
    activation: networks.ActivationFn = linen.elu,
    layer_norm: bool = True,
) -> APGNetworks:
  """Make APG networks.

  Args:
    observation_size: size of the observation input.
    action_size: number of action dimensions (event size of the distribution).
    preprocess_observations_fn: observation preprocessing used by the network.
    hidden_layer_sizes: widths of the hidden layers.
    activation: hidden-layer activation function.
    layer_norm: whether the policy network uses layer normalization.

  Returns:
    APGNetworks pairing a `NormalTanhDistribution` (var_scale=0.1) with a
    policy network that outputs that distribution's parameters.
  """
  action_dist = distribution.NormalTanhDistribution(
      event_size=action_size, var_scale=0.1
  )
  # The network's output width equals the number of distribution parameters.
  policy_net = networks.make_policy_network(
      action_dist.param_size,
      observation_size,
      preprocess_observations_fn=preprocess_observations_fn,
      hidden_layer_sizes=hidden_layer_sizes,
      activation=activation,
      # Orthogonal initialization with a small gain (0.01).
      kernel_init=linen.initializers.orthogonal(0.01),
      layer_norm=layer_norm,
  )
  return APGNetworks(
      policy_network=policy_net,
      parametric_action_distribution=action_dist,
  )
83 |
--------------------------------------------------------------------------------
/brax/training/agents/ars/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/brax/training/agents/ars/networks.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """ARS networks."""
16 |
17 | from typing import Tuple
18 |
19 | from brax.training import networks
20 | from brax.training import types
21 | from brax.training.types import PRNGKey
22 | import jax.numpy as jnp
23 |
# ARS policies use the generic feed-forward network container (an `init` /
# `apply` pair); this alias only gives the type an agent-specific name.
ARSNetwork = networks.FeedForwardNetwork
25 |
26 |
def make_policy_network(
    observation_size: int,
    action_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
) -> ARSNetwork:
  """Builds the linear ARS policy network.

  The policy is a single `(observation_size, action_size)` weight matrix
  applied to preprocessed observations, with no bias and no nonlinearity.

  Args:
    observation_size: Size of the flat observation vector.
    action_size: Number of action dimensions produced.
    preprocess_observations_fn: Called as `fn(obs, processor_params)` before
      the matrix product (e.g. observation normalization).

  Returns:
    An `ARSNetwork` whose `init` ignores its argument and returns zero weights,
    and whose `apply(processor_params, policy_params, obs)` returns
    `preprocess(obs) @ policy_params`.
  """
  weight_shape = (observation_size, action_size)

  def init(unused_key):
    # Weights start at zero; the argument (presumably an RNG key) is unused.
    return jnp.zeros(weight_shape)

  def apply(processor_params, policy_params, obs):
    processed_obs = preprocess_observations_fn(obs, processor_params)
    return jnp.matmul(processed_obs, policy_params)

  return ARSNetwork(init=init, apply=apply)
41 |
42 |
def make_inference_fn(policy_network: ARSNetwork):
  """Creates params and inference function for the ARS agent.

  Args:
    policy_network: Network whose `apply` receives `*params` followed by the
      observations (for networks from `make_policy_network`, `params` is
      `(processor_params, policy_params)`).

  Returns:
    `make_policy(params)`, which returns a policy
    `(observations, unused_key_sample) -> (actions, {})`.
  """

  def make_policy(params: types.PolicyParams) -> types.Policy:

    def policy(
        observations: types.Observation, unused_key_sample: PRNGKey
    ) -> Tuple[types.Action, types.Extra]:
      # The policy is a pure function of the observations: the key is ignored
      # and no extras are reported.
      actions = policy_network.apply(*params, observations)
      extras = {}
      return actions, extras

    return policy

  return make_policy
56 |
--------------------------------------------------------------------------------
/brax/training/agents/ars/train_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Augmented Random Search training tests."""
16 |
17 | import pickle
18 |
19 | from absl.testing import absltest
20 | from absl.testing import parameterized
21 | from brax import envs
22 | from brax.training.acme import running_statistics
23 | from brax.training.agents.ars import networks as ars_networks
24 | from brax.training.agents.ars import train as ars
25 | import jax
26 |
27 |
class ARSTest(parameterized.TestCase):
  """Tests for the Augmented Random Search trainer."""

  @parameterized.parameters(True, False)
  def testModelEncoding(self, normalize_observations):
    """Trained params survive a pickle round trip and can drive one env step."""
    env = envs.get_environment('fast')
    _, params, _ = ars.train(
        env,
        num_timesteps=128,
        episode_length=128,
        normalize_observations=normalize_observations,
    )

    # The inference network must preprocess observations the same way
    # training did.
    if normalize_observations:
      normalize_fn = running_statistics.normalize
    else:
      normalize_fn = lambda x, y: x

    network = ars_networks.make_policy_network(
        env.observation_size, env.action_size, normalize_fn
    )
    make_policy = ars_networks.make_inference_fn(network)
    restored_params = pickle.loads(pickle.dumps(params))

    # Push a single action through the environment.
    key = jax.random.PRNGKey(0)
    state = env.reset(key)
    policy = make_policy(restored_params)
    action = policy(state.obs, key)[0]
    env.step(state, action)

  def testTrainDomainRandomize(self):
    """Test with domain randomization."""

    def rand_fn(sys, rng):
      def sample_pos(key):
        offset = jax.random.uniform(key, shape=(3,), minval=-0.1, maxval=0.1)
        return sys.link.transform.pos.at[0].set(offset)

      # NOTE(review): positions are derived from link.transform.pos but written
      # to link.inertia.transform.pos; confirm this is intended.
      randomized_pos = jax.vmap(sample_pos)(rng)
      sys_v = sys.tree_replace({'link.inertia.transform.pos': randomized_pos})
      # Only the randomized field carries a batch axis; all else is shared.
      in_axes = jax.tree.map(lambda x: None, sys).tree_replace(
          {'link.inertia.transform.pos': 0}
      )
      return sys_v, in_axes

    env = envs.get_environment('inverted_pendulum', backend='spring')
    _, _, _ = ars.train(
        env,
        num_timesteps=128,
        episode_length=128,
        normalize_observations=True,
        randomization_fn=rand_fn,
    )
78 |
if __name__ == '__main__':
  # Run the tests in this module through absl's test runner.
  absltest.main()
81 |
--------------------------------------------------------------------------------
/brax/training/agents/bc/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/brax/training/agents/bc/checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Checkpointing for BC."""
16 |
17 | from typing import Any, Union
18 |
19 | from brax.training import checkpoint
20 | from brax.training import types
21 | from brax.training.agents.bc import networks as bc_networks
22 | from etils import epath
23 | from ml_collections import config_dict
24 |
# File name used by `save` when writing the BC network config and by
# `load_config` when reading it back.
_CONFIG_FNAME = 'bc_network_config.json'
26 |
27 |
def save(
    path: Union[str, epath.Path],
    step: int,
    params: Any,
    config: config_dict.ConfigDict,
):
  """Saves a BC checkpoint.

  Thin wrapper over `checkpoint.save` that pins the config file name to
  `_CONFIG_FNAME`.

  Args:
    path: Checkpoint location, forwarded to `checkpoint.save`.
    step: Training step of this checkpoint.
    params: Parameters to save.
    config: Network config saved alongside the params.

  Returns:
    Whatever `checkpoint.save` returns.
  """
  save_args = (path, step, params, config, _CONFIG_FNAME)
  return checkpoint.save(*save_args)
36 |
37 |
def load(
    path: Union[str, epath.Path],
):
  """Loads checkpoint.

  Args:
    path: Checkpoint location, forwarded unchanged to `checkpoint.load`.

  Returns:
    The parameters restored by `checkpoint.load`.
  """
  restored = checkpoint.load(path)
  return restored
43 |
44 |
def network_config(
    observation_size: types.ObservationSize,
    action_size: int,
    normalize_observations: bool,
    network_factory: types.NetworkFactory[bc_networks.BCNetworks],
) -> config_dict.ConfigDict:
  """Returns a config dict for re-creating a network from a checkpoint.

  All arguments are forwarded positionally to `checkpoint.network_config`.
  """
  factory_inputs = (
      observation_size,
      action_size,
      normalize_observations,
      network_factory,
  )
  return checkpoint.network_config(*factory_inputs)
55 |
56 |
def _get_bc_network(
    config: config_dict.ConfigDict,
    network_factory: types.NetworkFactory[bc_networks.BCNetworks],
) -> bc_networks.BCNetworks:
  """Rebuilds the BC networks described by `config` via `network_factory`."""
  bc_network = checkpoint.get_network(config, network_factory)
  return bc_network  # pytype: disable=bad-return-type
63 |
64 |
def load_config(
    path: Union[str, epath.Path],
) -> config_dict.ConfigDict:
  """Loads BC config from checkpoint.

  Reads `_CONFIG_FNAME` from inside the directory `path`.
  """
  return checkpoint.load_config(epath.Path(path) / _CONFIG_FNAME)
72 |
73 |
def load_policy(
    path: Union[str, epath.Path],
    network_factory: types.NetworkFactory[bc_networks.BCNetworks],
    deterministic: bool = True,
):
  """Loads policy inference function from BC checkpoint.

  Both the network config (`_CONFIG_FNAME`) and the parameters are read from
  the same directory `path`.

  Args:
    path: Checkpoint directory.
    network_factory: Factory used to rebuild the BC networks from the config.
    deterministic: Forwarded to the inference-function factory; defaults to
      True. NOTE(review): an earlier docstring claimed the BC policy is always
      deterministic; confirm whether `bc_networks.make_inference_fn` honors
      this flag.

  Returns:
    The policy built by `bc_networks.make_inference_fn` from the loaded params.
  """
  path = epath.Path(path)
  # The config lives next to the params in `path`, not in its parent; reading
  # from `path.parent` would look for the config one directory too high.
  config = load_config(path)
  params = load(path)
  bc_network = _get_bc_network(config, network_factory)
  make_inference_fn = bc_networks.make_inference_fn(bc_network)

  return make_inference_fn(params, deterministic=deterministic)
90 |
--------------------------------------------------------------------------------
/brax/training/agents/bc/losses.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Losses for BC."""
16 |
17 | from typing import Any, Callable, Dict, Tuple
18 |
19 | from brax.training.agents.bc import networks
20 | from brax.training.types import Params
21 | import jax.numpy as jp
22 |
23 |
def bc_loss(
    params: Params,
    normalizer_params: Any,
    data: Dict,
    make_policy: Callable[[Tuple[Any, Params]], networks.BCInferenceFn],
):
  """Behavior-cloning loss: L2 between tanh-squashed student/teacher means.

  The student's `loc` action extra and the teacher's
  `data['teacher_action_extras']['loc']` are each passed through `tanh`. The
  squared difference is summed over the last axis and averaged over the
  remaining axes.

  Args:
    params: Student policy parameters.
    normalizer_params: Observation-normalizer state, paired with `params` as
      `(normalizer_params, params)` when building the student policy.
    data: Batch containing 'observations' and 'teacher_action_extras'.
    make_policy: Builds the student inference function from the params pair.

  Returns:
    `(loss, metrics)`, where `metrics` holds the same scalar under both
    'actor_loss' and 'mse_loss'.
  """
  student = make_policy((normalizer_params, params))
  _, student_extras = student(data['observations'], key_sample=None)  # pytype: disable=wrong-keyword-args
  diff = jp.tanh(student_extras['loc']) - jp.tanh(
      data['teacher_action_extras']['loc']
  )
  # Sum over action dimensions, then average; the result is already a scalar,
  # so no second reduction is needed.
  loss = jp.mean(jp.sum(diff**2, axis=-1))
  return loss, {'actor_loss': loss, 'mse_loss': loss}
46 |
--------------------------------------------------------------------------------
/brax/training/agents/es/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/brax/training/agents/es/networks.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Evolution strategy networks."""
16 |
17 | from typing import Sequence, Tuple
18 |
19 | from brax.training import distribution
20 | from brax.training import networks
21 | from brax.training import types
22 | from brax.training.types import PRNGKey
23 | import flax
24 | from flax import linen
25 |
26 |
@flax.struct.dataclass
class ESNetworks:
  """Networks used by the ES agent, as built by `make_es_networks`.

  Field order defines the dataclass constructor's positional order, so keep it
  stable.
  """

  # Maps (preprocessed) observations to action-distribution parameters.
  policy_network: networks.FeedForwardNetwork
  # Turns policy outputs into actions (mode or sample).
  parametric_action_distribution: distribution.ParametricDistribution
31 |
32 |
def make_inference_fn(es_networks: ESNetworks):
  """Creates params and inference function for the ES agent.

  Args:
    es_networks: Policy network plus the action distribution used to turn its
      outputs into actions.

  Returns:
    `make_policy(params, deterministic=False)`, which returns a policy
    `(observations, key_sample) -> (actions, {})`. `params` is unpacked
    positionally into `policy_network.apply` ahead of the observations.
  """

  def make_policy(
      params: types.PolicyParams, deterministic: bool = False
  ) -> types.Policy:

    def policy(
        observations: types.Observation, key_sample: PRNGKey
    ) -> Tuple[types.Action, types.Extra]:
      logits = es_networks.policy_network.apply(*params, observations)
      action_dist = es_networks.parametric_action_distribution
      if not deterministic:
        return action_dist.sample(logits, key_sample), {}
      # Deterministic inference takes the distribution's mode; key unused.
      return action_dist.mode(logits), {}

    return policy

  return make_policy
54 |
55 |
def make_es_networks(
    observation_size: int,
    action_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
    hidden_layer_sizes: Sequence[int] = (32,) * 4,
    activation: networks.ActivationFn = linen.relu,
) -> ESNetworks:
  """Make ES networks.

  Args:
    observation_size: Size of the observation input.
    action_size: Number of action dimensions.
    preprocess_observations_fn: Observation preprocessing applied by the
      policy network.
    hidden_layer_sizes: Widths of the policy MLP's hidden layers.
    activation: Hidden-layer activation.

  Returns:
    `ESNetworks` whose policy outputs `param_size` values for a
    `NormalTanhDistribution` over `action_size` dimensions.
  """
  action_dist = distribution.NormalTanhDistribution(event_size=action_size)
  policy = networks.make_policy_network(
      action_dist.param_size,
      observation_size,
      preprocess_observations_fn=preprocess_observations_fn,
      hidden_layer_sizes=hidden_layer_sizes,
      activation=activation,
  )
  return ESNetworks(
      policy_network=policy, parametric_action_distribution=action_dist
  )
78 |
--------------------------------------------------------------------------------
/brax/training/agents/es/train_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Evolution Strategy training tests."""
16 |
17 | import pickle
18 |
19 | from absl.testing import absltest
20 | from absl.testing import parameterized
21 | from brax import envs
22 | from brax.training.acme import running_statistics
23 | from brax.training.agents.es import networks as es_networks
24 | from brax.training.agents.es import train as es
25 | import jax
26 |
27 |
class ESTest(parameterized.TestCase):
  """Tests for the Evolution Strategy trainer."""

  def testTrain(self):
    """ES on the 'fast' env clears a minimum eval reward."""
    env = envs.get_environment('fast')
    _, _, metrics = es.train(
        environment=env,
        num_timesteps=65536,
        episode_length=128,
        learning_rate=0.1,
    )
    self.assertGreater(metrics['eval/episode_reward'], 140)

  @parameterized.parameters(True, False)
  def testModelEncoding(self, normalize_observations):
    """Trained params survive a pickle round trip and can drive one env step."""
    env = envs.get_environment('fast')
    _, params, _ = es.train(
        env,
        num_timesteps=128,
        episode_length=128,
        normalize_observations=normalize_observations,
    )

    # The inference network must preprocess observations the same way
    # training did.
    if normalize_observations:
      normalize_fn = running_statistics.normalize
    else:
      normalize_fn = lambda x, y: x

    network = es_networks.make_es_networks(
        env.observation_size, env.action_size, normalize_fn
    )
    make_policy = es_networks.make_inference_fn(network)
    restored_params = pickle.loads(pickle.dumps(params))

    # Push a single action through the environment.
    key = jax.random.PRNGKey(0)
    state = env.reset(key)
    policy = make_policy(restored_params)
    action = policy(state.obs, key)[0]
    env.step(state, action)

  def testTrainDomainRandomize(self):
    """Test with domain randomization."""

    def rand_fn(sys, rng):
      def sample_pos(key):
        offset = jax.random.uniform(key, shape=(3,), minval=-0.1, maxval=0.1)
        return sys.link.transform.pos.at[0].set(offset)

      # NOTE(review): positions are derived from link.transform.pos but written
      # to link.inertia.transform.pos; confirm this is intended.
      randomized_pos = jax.vmap(sample_pos)(rng)
      sys_v = sys.tree_replace({'link.inertia.transform.pos': randomized_pos})
      # Only the randomized field carries a batch axis; all else is shared.
      in_axes = jax.tree.map(lambda x: None, sys).tree_replace(
          {'link.inertia.transform.pos': 0}
      )
      return sys_v, in_axes

    env = envs.get_environment('inverted_pendulum', backend='spring')
    _, _, _ = es.train(
        env,
        num_timesteps=1280,
        episode_length=128,
        randomization_fn=rand_fn,
    )
86 |
87 |
if __name__ == '__main__':
  # NOTE(review): turns off the partitionable threefry PRNG implementation,
  # presumably so random streams (and hence the reward threshold checked in
  # testTrain) match what the test was tuned against; confirm.
  jax.config.update('jax_threefry_partitionable', False)
  absltest.main()
91 |
--------------------------------------------------------------------------------
/brax/training/agents/ppo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/brax/training/agents/ppo/checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Checkpointing for PPO."""
16 |
17 | from typing import Any, Union
18 |
19 | from brax.training import checkpoint
20 | from brax.training import types
21 | from brax.training.agents.ppo import networks as ppo_networks
22 | from etils import epath
23 | from ml_collections import config_dict
24 |
# File name used by `save` when writing the PPO network config and by
# `load_config` when reading it back.
_CONFIG_FNAME = 'ppo_network_config.json'
26 |
27 |
def save(
    path: Union[str, epath.Path],
    step: int,
    params: Any,
    config: config_dict.ConfigDict,
):
  """Saves a PPO checkpoint.

  Thin wrapper over `checkpoint.save` that pins the config file name to
  `_CONFIG_FNAME`.

  Args:
    path: Checkpoint location, forwarded to `checkpoint.save`.
    step: Training step of this checkpoint.
    params: Parameters to save.
    config: Network config saved alongside the params.

  Returns:
    Whatever `checkpoint.save` returns.
  """
  save_args = (path, step, params, config, _CONFIG_FNAME)
  return checkpoint.save(*save_args)
36 |
37 |
def load(
    path: Union[str, epath.Path],
):
  """Loads checkpoint.

  Args:
    path: Checkpoint location, forwarded unchanged to `checkpoint.load`.

  Returns:
    The parameters restored by `checkpoint.load`.
  """
  restored = checkpoint.load(path)
  return restored
43 |
44 |
def network_config(
    observation_size: types.ObservationSize,
    action_size: int,
    normalize_observations: bool,
    network_factory: types.NetworkFactory[ppo_networks.PPONetworks],
) -> config_dict.ConfigDict:
  """Returns a config dict for re-creating a network from a checkpoint.

  All arguments are forwarded positionally to `checkpoint.network_config`.
  """
  factory_inputs = (
      observation_size,
      action_size,
      normalize_observations,
      network_factory,
  )
  return checkpoint.network_config(*factory_inputs)
55 |
56 |
def _get_ppo_network(
    config: config_dict.ConfigDict,
    network_factory: types.NetworkFactory[ppo_networks.PPONetworks],
) -> ppo_networks.PPONetworks:
  """Rebuilds the PPO networks described by `config` via `network_factory`."""
  ppo_network = checkpoint.get_network(config, network_factory)
  return ppo_network  # pytype: disable=bad-return-type
63 |
64 |
def load_config(
    path: Union[str, epath.Path],
) -> config_dict.ConfigDict:
  """Loads PPO config from checkpoint.

  Reads `_CONFIG_FNAME` from inside the directory `path`.
  """
  return checkpoint.load_config(epath.Path(path) / _CONFIG_FNAME)
72 |
73 |
def load_policy(
    path: Union[str, epath.Path],
    network_factory: types.NetworkFactory[
        ppo_networks.PPONetworks
    ] = ppo_networks.make_ppo_networks,
    deterministic: bool = True,
):
  """Loads policy inference function from PPO checkpoint.

  Args:
    path: Checkpoint directory holding both the params and `_CONFIG_FNAME`.
    network_factory: Factory used to rebuild the PPO networks from the config.
    deterministic: Forwarded to the inference-function factory.

  Returns:
    The policy built by `ppo_networks.make_inference_fn` from the loaded params.
  """
  ckpt_path = epath.Path(path)
  config = load_config(ckpt_path)
  params = load(ckpt_path)
  network = _get_ppo_network(config, network_factory)
  make_policy = ppo_networks.make_inference_fn(network)
  return make_policy(params, deterministic=deterministic)
89 |
--------------------------------------------------------------------------------
/brax/training/agents/ppo/networks_vision.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """PPO vision networks."""
16 |
17 | from typing import Any, Callable, Mapping, Sequence, Tuple
18 |
19 | from brax.training import distribution
20 | from brax.training import networks
21 | from brax.training import types
22 | import flax
23 | from flax import linen
24 | import jax.numpy as jp
25 |
26 |
# Type aliases used in this module's signatures.
Initializer = Callable[..., Any]  # any callable, arbitrary signature
ActivationFn = Callable[[jp.ndarray], jp.ndarray]  # array in, array out
ModuleDef = Any  # presumably a Flax module definition; untyped here
30 |
31 |
@flax.struct.dataclass
class PPONetworks:
  """Vision PPO networks, as built by `make_ppo_networks_vision`.

  Field order defines the dataclass constructor's positional order, so keep it
  stable.
  """

  # Maps observations to action-distribution parameters.
  policy_network: networks.FeedForwardNetwork
  # Value-function network.
  value_network: networks.FeedForwardNetwork
  # Turns policy outputs into actions.
  parametric_action_distribution: distribution.ParametricDistribution
37 |
38 |
def make_ppo_networks_vision(
    observation_size: Mapping[str, Tuple[int, ...]],
    action_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
    policy_hidden_layer_sizes: Sequence[int] = (256, 256),
    value_hidden_layer_sizes: Sequence[int] = (256, 256),
    activation: ActivationFn = linen.swish,
    normalise_channels: bool = False,
    policy_obs_key: str = "",
    value_obs_key: str = "",
) -> PPONetworks:
  """Make Vision PPO networks with preprocessor.

  Args:
    observation_size: Mapping from observation key to that observation's shape.
    action_size: Number of action dimensions.
    preprocess_observations_fn: Observation preprocessing shared by both
      networks.
    policy_hidden_layer_sizes: Hidden-layer widths of the policy head.
    value_hidden_layer_sizes: Hidden-layer widths of the value head.
    activation: Activation shared by both networks.
    normalise_channels: Forwarded to both vision network builders.
    policy_obs_key: Forwarded as `state_obs_key` to the policy builder.
    value_obs_key: Forwarded as `state_obs_key` to the value builder.

  Returns:
    `PPONetworks` with a `NormalTanhDistribution` over `action_size` dims.
  """
  action_dist = distribution.NormalTanhDistribution(event_size=action_size)

  # Keyword arguments common to the policy and value builders.
  common = dict(
      observation_size=observation_size,
      preprocess_observations_fn=preprocess_observations_fn,
      activation=activation,
      normalise_channels=normalise_channels,
  )

  policy_net = networks.make_policy_network_vision(
      output_size=action_dist.param_size,
      hidden_layer_sizes=policy_hidden_layer_sizes,
      state_obs_key=policy_obs_key,
      **common,
  )
  value_net = networks.make_value_network_vision(
      hidden_layer_sizes=value_hidden_layer_sizes,
      state_obs_key=value_obs_key,
      **common,
  )

  return PPONetworks(
      policy_network=policy_net,
      value_network=value_net,
      parametric_action_distribution=action_dist,
  )
80 |
--------------------------------------------------------------------------------
/brax/training/agents/sac/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/brax/training/agents/sac/checkpoint.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Checkpointing for SAC."""
16 |
17 | from typing import Any, Union
18 |
19 | from brax.training import checkpoint
20 | from brax.training import types
21 | from brax.training.agents.sac import networks as sac_networks
22 | from etils import epath
23 | from ml_collections import config_dict
24 |
# File name used by `save` when writing the SAC network config and by
# `load_config` when reading it back.
_CONFIG_FNAME = 'sac_network_config.json'
26 |
27 |
def save(
    path: Union[str, epath.Path],
    step: int,
    params: Any,
    config: config_dict.ConfigDict,
):
  """Saves a SAC checkpoint.

  Thin wrapper over `checkpoint.save` that pins the config file name to
  `_CONFIG_FNAME`.

  Args:
    path: Checkpoint location, forwarded to `checkpoint.save`.
    step: Training step of this checkpoint.
    params: Parameters to save.
    config: Network config saved alongside the params.

  Returns:
    Whatever `checkpoint.save` returns.
  """
  save_args = (path, step, params, config, _CONFIG_FNAME)
  return checkpoint.save(*save_args)
36 |
37 |
def load(
    path: Union[str, epath.Path],
):
  """Loads SAC checkpoint.

  Args:
    path: Checkpoint location, forwarded unchanged to `checkpoint.load`.

  Returns:
    The parameters restored by `checkpoint.load`.
  """
  restored = checkpoint.load(path)
  return restored
43 |
44 |
def network_config(
    observation_size: types.ObservationSize,
    action_size: int,
    normalize_observations: bool,
    network_factory: types.NetworkFactory[sac_networks.SACNetworks],
) -> config_dict.ConfigDict:
  """Returns a config dict for re-creating a network from a checkpoint.

  All arguments are forwarded positionally to `checkpoint.network_config`.
  """
  factory_inputs = (
      observation_size,
      action_size,
      normalize_observations,
      network_factory,
  )
  return checkpoint.network_config(*factory_inputs)
55 |
56 |
def _get_network(
    config: config_dict.ConfigDict,
    network_factory: types.NetworkFactory[sac_networks.SACNetworks],
) -> sac_networks.SACNetworks:
  """Rebuilds the SAC networks described by `config` via `network_factory`."""
  sac_network = checkpoint.get_network(config, network_factory)
  return sac_network  # pytype: disable=bad-return-type
63 |
64 |
def load_config(
    path: Union[str, epath.Path],
) -> config_dict.ConfigDict:
  """Loads SAC config from checkpoint.

  Reads `_CONFIG_FNAME` from inside the directory `path`.
  """
  return checkpoint.load_config(epath.Path(path) / _CONFIG_FNAME)
72 |
73 |
def load_policy(
    path: Union[str, epath.Path],
    network_factory: types.NetworkFactory[
        sac_networks.SACNetworks
    ] = sac_networks.make_sac_networks,
    deterministic: bool = True,
):
  """Loads policy inference function from SAC checkpoint.

  Args:
    path: Checkpoint directory holding both the params and `_CONFIG_FNAME`.
    network_factory: Factory used to rebuild the SAC networks from the config.
    deterministic: Forwarded to the inference-function factory.

  Returns:
    The policy built by `sac_networks.make_inference_fn` from the loaded params.
  """
  ckpt_path = epath.Path(path)
  config = load_config(ckpt_path)
  params = load(ckpt_path)
  network = _get_network(config, network_factory)
  make_policy = sac_networks.make_inference_fn(network)
  return make_policy(params, deterministic=deterministic)
89 |
--------------------------------------------------------------------------------
/brax/training/agents/sac/networks.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """SAC networks."""
16 |
17 | from typing import Sequence, Tuple
18 |
19 | from brax.training import distribution
20 | from brax.training import networks
21 | from brax.training import types
22 | from brax.training.types import PRNGKey
23 | import flax
24 | from flax import linen
25 |
26 |
@flax.struct.dataclass
class SACNetworks:
  """Networks used by the SAC agent, as built by `make_sac_networks`.

  Field order defines the dataclass constructor's positional order, so keep it
  stable.
  """

  # Maps (preprocessed) observations to action-distribution parameters.
  policy_network: networks.FeedForwardNetwork
  # Q-function network built from observation and action sizes.
  q_network: networks.FeedForwardNetwork
  # Turns policy outputs into actions (mode or sample).
  parametric_action_distribution: distribution.ParametricDistribution
32 |
33 |
def make_inference_fn(sac_networks: SACNetworks):
  """Creates params and inference function for the SAC agent.

  Args:
    sac_networks: Networks whose policy and action distribution are used.

  Returns:
    `make_policy(params, deterministic=False)`, which returns a policy
    `(observations, key_sample) -> (actions, {})`. `params` is unpacked
    positionally into `policy_network.apply` ahead of the observations.
  """

  def make_policy(
      params: types.PolicyParams, deterministic: bool = False
  ) -> types.Policy:

    def policy(
        observations: types.Observation, key_sample: PRNGKey
    ) -> Tuple[types.Action, types.Extra]:
      logits = sac_networks.policy_network.apply(*params, observations)
      action_dist = sac_networks.parametric_action_distribution
      if not deterministic:
        return action_dist.sample(logits, key_sample), {}
      # Deterministic inference takes the distribution's mode; key unused.
      return action_dist.mode(logits), {}

    return policy

  return make_policy
57 |
58 |
def make_sac_networks(
    observation_size: int,
    action_size: int,
    preprocess_observations_fn: types.PreprocessObservationFn = types.identity_observation_preprocessor,
    hidden_layer_sizes: Sequence[int] = (256, 256),
    activation: networks.ActivationFn = linen.relu,
    policy_network_layer_norm: bool = False,
    q_network_layer_norm: bool = False,
) -> SACNetworks:
  """Make SAC networks.

  Args:
    observation_size: Size of the observation input.
    action_size: Number of action dimensions.
    preprocess_observations_fn: Observation preprocessing shared by both
      networks.
    hidden_layer_sizes: Hidden-layer widths shared by the policy and Q MLPs.
    activation: Hidden-layer activation shared by both networks.
    policy_network_layer_norm: Whether the policy network uses layer norm.
    q_network_layer_norm: Whether the Q network uses layer norm.

  Returns:
    `SACNetworks` with a `NormalTanhDistribution` over `action_size` dims.
  """
  action_dist = distribution.NormalTanhDistribution(event_size=action_size)

  # Keyword arguments common to the policy and Q network builders.
  shared_kwargs = dict(
      preprocess_observations_fn=preprocess_observations_fn,
      hidden_layer_sizes=hidden_layer_sizes,
      activation=activation,
  )

  policy_net = networks.make_policy_network(
      action_dist.param_size,
      observation_size,
      layer_norm=policy_network_layer_norm,
      **shared_kwargs,
  )
  q_net = networks.make_q_network(
      observation_size,
      action_size,
      layer_norm=q_network_layer_norm,
      **shared_kwargs,
  )

  return SACNetworks(
      policy_network=policy_net,
      q_network=q_net,
      parametric_action_distribution=action_dist,
  )
93 |
--------------------------------------------------------------------------------
/brax/training/gradients.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Brax training gradient utility functions."""
16 |
17 | from typing import Callable, Optional
18 |
19 | import jax
20 | import optax
21 |
22 |
def loss_and_pgrad(
    loss_fn: Callable[..., float],
    pmap_axis_name: Optional[str],
    has_aux: bool = False,
):
  """Returns a value-and-gradient function for `loss_fn`.

  Args:
    loss_fn: The loss function to differentiate (w.r.t. its first argument,
      as `jax.value_and_grad` does by default).
    pmap_axis_name: If not None, gradients are averaged with `jax.lax.pmean`
      across this pmap axis; if None, they are returned unchanged.
    has_aux: Whether `loss_fn` returns `(loss, aux)` instead of just `loss`.

  Returns:
    A function with the same arguments as `loss_fn` that returns
    `(value, grads)`, where `value` is whatever `jax.value_and_grad` produces.
  """
  value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=has_aux)
  if pmap_axis_name is None:
    return value_and_grad_fn

  def synced_value_and_grad(*args, **kwargs):
    loss_value, grads = value_and_grad_fn(*args, **kwargs)
    # Average gradients across devices so every replica applies the same step.
    return loss_value, jax.lax.pmean(grads, axis_name=pmap_axis_name)

  return synced_value_and_grad
36 |
def gradient_update_fn(
    loss_fn: Callable[..., float],
    optimizer: optax.GradientTransformation,
    pmap_axis_name: Optional[str],
    has_aux: bool = False,
):
  """Wraps a loss function into a step that also applies gradient updates.

  Args:
    loss_fn: The loss function.
    optimizer: The optimizer used to turn gradients into parameter updates.
    pmap_axis_name: If relevant, the name of the pmap axis across which
      gradients are averaged.
    has_aux: Whether the loss_fn returns auxiliary data.

  Returns:
    A function taking the loss function's positional arguments plus a
    keyword-only `optimizer_state`. Its first positional argument is treated
    as the parameters to update. It returns `(value, new_params,
    new_optimizer_state)`, where `value` is the loss (or `(loss, aux)` when
    `has_aux` is True).
  """
  compute_loss_and_grads = loss_and_pgrad(
      loss_fn, pmap_axis_name=pmap_axis_name, has_aux=has_aux
  )

  def update_step(*args, optimizer_state):
    loss_value, grads = compute_loss_and_grads(*args)
    updates, new_optimizer_state = optimizer.update(grads, optimizer_state)
    new_params = optax.apply_updates(args[0], updates)
    return loss_value, new_params, new_optimizer_state

  return update_step
68 |
--------------------------------------------------------------------------------
/brax/training/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Logger for training metrics."""
16 |
17 | import collections
18 | import logging
19 | from jax import numpy as jnp
20 | import numpy as np
21 |
22 |
class EpisodeMetricsLogger:
  """Accumulates metrics of finished episodes and periodically logs means.

  Values are recorded only where `dones` is set. At most `buffer_size` recent
  values are kept per metric.
  """

  def __init__(
      self, buffer_size=100, steps_between_logging=1e5, progress_fn=None
  ):
    self._buffer_size = buffer_size
    # One bounded deque per metric name; the oldest values are dropped first.
    self._metrics_buffer = collections.defaultdict(
        lambda: collections.deque(maxlen=buffer_size)
    )
    self._steps_between_logging = steps_between_logging
    self._progress_fn = progress_fn
    self._num_steps = 0
    self._last_log_steps = 0
    self._log_count = 0

  def update_episode_metrics(self, metrics, dones):
    """Buffers metrics for finished episodes and logs when enough steps passed.

    Args:
      metrics: mapping from metric name to an array that can be boolean-masked
        with `dones` (presumably one value per environment; verify with
        caller).
      dones: array of termination flags; each element counts as one step.
    """
    self._num_steps += np.prod(dones.shape)
    if jnp.sum(dones) > 0:
      finished = dones.astype(bool)
      for name, values in metrics.items():
        self._metrics_buffer[name].extend(values[finished].flatten().tolist())
    steps_since_log = self._num_steps - self._last_log_steps
    if steps_since_log >= self._steps_between_logging:
      self.log_metrics()
      self._last_log_steps = self._num_steps

  def log_metrics(self, pad=35):
    """Logs the mean of each buffered metric and forwards it to progress_fn."""
    self._log_count += 1
    mean_metrics = {
        name: np.mean(buffer) for name, buffer in self._metrics_buffer.items()
    }
    lines = [
        f"\n{'Steps':>{pad}} Env: {self._num_steps} Log: {self._log_count}\n"
    ]
    for name, mean in mean_metrics.items():
      label = f'Episode {name}:'
      lines.append(f"{label:>{pad}} {mean:.4f}\n")
    logging.info("".join(lines))
    if self._progress_fn is not None:
      self._progress_fn(
          int(self._num_steps),
          {f"episode/{name}": value for name, value in mean_metrics.items()},
      )
--------------------------------------------------------------------------------
/brax/training/pmap.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The Brax Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Input normalization utils."""
16 |
17 | import functools
18 | from typing import Any
19 |
20 | import jax
21 | import jax.numpy as jnp
22 |
23 |
def bcast_local_devices(value, local_devices_to_use=1):
  """Replicates `value` onto the first `local_devices_to_use` local devices.

  Returns:
    The result of `jax.device_put_replicated`, which stacks one copy of
    `value` per device along a new leading axis.
  """
  return jax.device_put_replicated(
      value, jax.local_devices()[:local_devices_to_use]
  )
28 |
29 |
def synchronize_hosts():
  """Runs a cross-process collective so all processes must participate.

  This is a no-op when there is a single process. Otherwise it `psum`s a vector
  of ones over every device and checks that the total equals the global
  device count.
  """
  if jax.process_count() != 1:
    # Make sure all processes stay up until the end of main.
    ones = jnp.ones([jax.local_device_count()])
    summed = jax.pmap(lambda v: jax.lax.psum(v, 'i'), 'i')(ones)
    totals = jax.device_get(summed)
    assert totals[0] == jax.device_count()
37 |
38 |
39 | def _fingerprint(x: Any) -> float:
40 | sums = jax.tree_util.tree_map(jnp.sum, x)
41 | return jax.tree_util.tree_reduce(lambda x, y: x + y, sums)
42 |
43 |
def is_replicated(x: Any, axis_name: str) -> jnp.ndarray:
  """Checks whether `x` has the same fingerprint on every device.

  Must be called inside a function pmapped along `axis_name`.

  Args:
    x: Pytree to check for replication.
    axis_name: pmap axis name.

  Returns:
    Boolean array that is True when the minimum and maximum per-device
    fingerprints (sums of all elements of `x`) agree.
  """
  fingerprint = _fingerprint(x)
  lowest = jax.lax.pmin(fingerprint, axis_name=axis_name)
  highest = jax.lax.pmax(fingerprint, axis_name=axis_name)
  return lowest == highest
59 |
60 |
def assert_is_replicated(x: Any, debug: Any = None):
  """Asserts that `x` holds the same values on every local device.

  Should be called from non-jitted code. `x` is passed straight to `jax.pmap`,
  so it must carry a leading device axis (presumably produced by
  `bcast_local_devices` or a pmapped computation).

  The comparison uses `is_replicated`, which compares per-device sums of all
  elements. Differences that happen to sum to the same value are therefore
  not detected.

  Args:
    x: Pytree whose replication across devices is checked.
    debug: Message attached to the AssertionError on failure.

  Raises:
    AssertionError: if the per-device fingerprints of `x` differ.
  """
  check = functools.partial(is_replicated, axis_name='i')
  # pmin == pmax yields the same boolean on every device, so device 0 suffices.
  assert jax.pmap(check, axis_name='i')(x)[0], debug
71 |
--------------------------------------------------------------------------------
/brax/visualizer/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/brax/visualizer/favicon.ico
--------------------------------------------------------------------------------
/brax/visualizer/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Brax visualizer
6 |
7 |
8 |
9 |
27 |
28 |
29 |
39 |
40 |
41 |
42 |
55 |
56 |
57 |
58 |
59 |
60 |
65 |
66 |
67 |
68 |
--------------------------------------------------------------------------------
/brax/visualizer/js/selector.js:
--------------------------------------------------------------------------------
1 | import * as THREE from 'three';
2 |
class Selector extends THREE.EventDispatcher {
  /**
   * Tracks hover and click selection of the viewer's top-level groups.
   *
   * Dispatches 'hoveron'/'hoveroff' as the pointer moves onto/off objects and
   * 'select'/'deselect' on clicks. A click is a pointerdown followed by a
   * pointerup with no pointermove in between.
   *
   * @param {Object} viewer Object exposing `scene`, `camera` and `domElement`.
   */
  constructor(viewer) {
    super();

    this.viewer = viewer;
    this.raycaster = new THREE.Raycaster();
    // Only objects on layer 1 are hit-tested.
    this.raycaster.layers.set(1);
    this.mousePos = new THREE.Vector2();
    this.selected = null;
    this.hovered = null;
    this.dragging = false;
    // Every top-level group except the one named 'world' is pickable.
    this.selectable = viewer.scene.children.filter(
        o => o instanceof THREE.Group && o.name !== 'world');

    // Keep the bound handlers so dispose() can remove these exact listeners.
    this._onPointerMove = this.onPointerMove.bind(this);
    this._onPointerDown = this.onPointerDown.bind(this);
    this._onPointerUp = this.onPointerUp.bind(this);
    const domElement = this.viewer.domElement;
    domElement.addEventListener('pointermove', this._onPointerMove);
    domElement.addEventListener('pointerdown', this._onPointerDown);
    domElement.addEventListener('pointerup', this._onPointerUp);
  }

  /** Removes the pointer listeners installed by the constructor. */
  dispose() {
    const domElement = this.viewer.domElement;
    domElement.removeEventListener('pointermove', this._onPointerMove);
    domElement.removeEventListener('pointerdown', this._onPointerDown);
    domElement.removeEventListener('pointerup', this._onPointerUp);
  }

  /**
   * Stores the event's position in `this.mousePos` as normalized device
   * coordinates (x and y in [-1, 1], y pointing up).
   */
  _updateMousePos(event) {
    const rect = this.viewer.domElement.getBoundingClientRect();
    this.mousePos.x = ((event.clientX - rect.left) / rect.width) * 2 - 1;
    this.mousePos.y = -((event.clientY - rect.top) / rect.height) * 2 + 1;
  }

  /**
   * Raycasts from the camera through `this.mousePos`. Returns the nearest
   * named ancestor of the first hit object (or the topmost ancestor if none
   * is named), or null when nothing selectable is hit.
   */
  _pickObject() {
    this.raycaster.setFromCamera(this.mousePos, this.viewer.camera);
    const intersections =
        this.raycaster.intersectObjects(this.selectable, true);
    if (intersections.length === 0) return null;
    let object = intersections[0].object;
    while (object.parent && !object.name) {
      object = object.parent;
    }
    return object;
  }

  onPointerMove(event) {
    event.preventDefault();
    // Any movement turns the current press into a drag rather than a click.
    this.dragging = true;
    this._updateMousePos(event);
    const object = this._pickObject();

    if (object !== null) {
      if (this.hovered !== object) {
        if (this.hovered) {
          this.dispatchEvent({type: 'hoveroff', object: this.hovered});
        }
        this.hovered = object;
        this.dispatchEvent({type: 'hoveron', object: this.hovered});
        this.viewer.domElement.style.cursor = 'pointer';
      }
    } else if (this.hovered !== null) {
      this.dispatchEvent({type: 'hoveroff', object: this.hovered});

      this.viewer.domElement.style.cursor = 'auto';
      this.hovered = null;
    }
  }

  onPointerDown(event) {
    event.preventDefault();
    this.dragging = false;
  }

  onPointerUp(event) {
    event.preventDefault();
    if (this.dragging) return;  // ignore drag events, only handle clicks
    // Use this event's own position: for touch input a tap fires no
    // pointermove beforehand, so the stored mousePos would be stale.
    this._updateMousePos(event);
    const object = this._pickObject();

    if (object !== null) {
      if (this.selected !== object) {
        if (this.selected) {
          this.dispatchEvent({type: 'deselect', object: this.selected});
        }
        this.selected = object;
        this.dispatchEvent({type: 'select', object: this.selected});
      }
    } else if (this.selected !== null) {
      this.dispatchEvent({type: 'deselect', object: this.selected});
      this.selected = null;
    }
  }
}
85 |
86 | export {Selector};
87 |
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
1 | In this directory, we provide results from hyperparameter sweeps for SAC and PPO as zipped json files. Within, hyperparameters are sorted by environment, and then by performance.\
2 | \
3 | Hyperparameter ranges for PPO:\
4 | \
5 | total_env_steps: 10_000_000 , 500_000_000 (split between files with these names)\
6 | eval_frequency: 20\
7 | reward_scaling: 1, 5\
8 | episode_length: 1000\
9 | normalize_observations: True\
10 | action_repeat: 1\
11 | entropy_cost: 1e-3\
12 | learning_rate: 3e-4\
13 | discounting: 0.99, 0.997\
14 | num_envs: 2048\
15 | unroll_length: 1, 5, 20\
16 | batch_size: 512, 1024\
17 | num_minibatches: 4, 8, 16, 32\
18 | num_update_epochs: 2, 4, 8\
19 | \
20 | \
21 | Hyperparameter ranges for SAC:\
22 | \
23 | env: 'halfcheetah'\
24 | total_env_steps: 1048576 * 5\
25 | eval_frequency: 131012\
26 | reward_scaling: 5, 10, 30\
27 | episode_length: 1000\
28 | normalize_observations: True\
29 | action_repeat: 1\
30 | learning_rate: 3e-4, 6e-4\
31 | discounting: 0.95, 0.99, 0.997\
32 | num_envs: 64, 128, 256\
33 | min_replay_size: 8192\
34 | max_replay_size: 1048576\
35 | batch_size: 128, 256, 512\
36 | grad_updates_per_step: 0.125 / 2, 0.125, 0.125 * 2
37 |
--------------------------------------------------------------------------------
/datasets/ppo_10_million_steps.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/datasets/ppo_10_million_steps.tar.gz
--------------------------------------------------------------------------------
/datasets/ppo_500_million_steps.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/datasets/ppo_500_million_steps.tar.gz
--------------------------------------------------------------------------------
/datasets/sac_5_million_steps.tar.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/datasets/sac_5_million_steps.tar.gz
--------------------------------------------------------------------------------
/docs/contributing.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to [cla.developers.google.com](https://cla.developers.google.com/) to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code Reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows [Google's Open Source Community
28 | Guidelines](https://opensource.google/conduct/).
29 |
--------------------------------------------------------------------------------
/docs/img/a1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/a1.gif
--------------------------------------------------------------------------------
/docs/img/ant.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/ant.gif
--------------------------------------------------------------------------------
/docs/img/ant_v2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/ant_v2.gif
--------------------------------------------------------------------------------
/docs/img/brax_logo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/brax_logo.gif
--------------------------------------------------------------------------------
/docs/img/braxlines/ant_diayn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/ant_diayn.png
--------------------------------------------------------------------------------
/docs/img/braxlines/ant_diayn_skill1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/ant_diayn_skill1.gif
--------------------------------------------------------------------------------
/docs/img/braxlines/ant_diayn_skill2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/ant_diayn_skill2.gif
--------------------------------------------------------------------------------
/docs/img/braxlines/ant_diayn_skill3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/ant_diayn_skill3.gif
--------------------------------------------------------------------------------
/docs/img/braxlines/ant_diayn_skill4.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/ant_diayn_skill4.gif
--------------------------------------------------------------------------------
/docs/img/braxlines/ant_smm.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/ant_smm.gif
--------------------------------------------------------------------------------
/docs/img/braxlines/ant_smm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/ant_smm.png
--------------------------------------------------------------------------------
/docs/img/braxlines/humanoid_diayn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/humanoid_diayn.png
--------------------------------------------------------------------------------
/docs/img/braxlines/humanoid_diayn_skill1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/humanoid_diayn_skill1.gif
--------------------------------------------------------------------------------
/docs/img/braxlines/humanoid_diayn_skill2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/humanoid_diayn_skill2.gif
--------------------------------------------------------------------------------
/docs/img/braxlines/humanoid_diayn_skill3.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/humanoid_diayn_skill3.gif
--------------------------------------------------------------------------------
/docs/img/braxlines/humanoid_smm.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/humanoid_smm.gif
--------------------------------------------------------------------------------
/docs/img/braxlines/humanoid_smm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/humanoid_smm.png
--------------------------------------------------------------------------------
/docs/img/braxlines/sketches.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/braxlines/sketches.png
--------------------------------------------------------------------------------
/docs/img/composer/ant_chase.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/composer/ant_chase.gif
--------------------------------------------------------------------------------
/docs/img/composer/ant_push.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/composer/ant_push.gif
--------------------------------------------------------------------------------
/docs/img/composer/pro_ant1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/composer/pro_ant1.gif
--------------------------------------------------------------------------------
/docs/img/composer/pro_ant2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/composer/pro_ant2.gif
--------------------------------------------------------------------------------
/docs/img/fetch.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/fetch.gif
--------------------------------------------------------------------------------
/docs/img/grasp.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/grasp.gif
--------------------------------------------------------------------------------
/docs/img/halfcheetah.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/halfcheetah.gif
--------------------------------------------------------------------------------
/docs/img/humanoid.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/humanoid.gif
--------------------------------------------------------------------------------
/docs/img/humanoid_v2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/humanoid_v2.gif
--------------------------------------------------------------------------------
/docs/img/ur5e.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/brax/d59e4db582e98da1734c098aed7219271c940bda/docs/img/ur5e.gif
--------------------------------------------------------------------------------
/docs/release-notes/next-release.md:
--------------------------------------------------------------------------------
1 | # Brax Release Notes
2 |
3 | * Make sure brax is compatible with MJX API change https://github.com/google-deepmind/mujoco/commit/6cfea719850f7dbc91b9d6fb36deafb3c5d04eaf.
4 | * Fix #595, patch for checkpointing with activations other than relu when no checkpoint path is specified.
5 | * Merge https://github.com/google/brax/pull/582.
6 | * Merge https://github.com/google/brax/pull/549.
7 | * Remove brax/v1.
8 | * Issue a warning in brax.io.mjcf. We point users to
9 | [MJX](https://github.com/google-deepmind/mujoco/tree/main/mjx) and
10 | [MuJoCo Playground](https://github.com/google-deepmind/mujoco_playground) instead. However, Brax training is still
11 | actively maintained.
12 |
--------------------------------------------------------------------------------
/docs/release-notes/v0.0.11.md:
--------------------------------------------------------------------------------
1 | # Brax Version 0.0.11 Release Notes
2 |
3 | This version introduces a significant overhaul to the physics algorithms. We now support position based dynamics for resolving joint and collision constraints. See [this paper](https://matthias-research.github.io/pages/publications/PBDBodies.pdf) for details about PBD.
4 |
5 | The most noticeable difference to prior versions of Brax is that joints are now modeled as infinitely stiff, whereas before they were stiff damped spring systems. This new physics is now default, and all environments use PBD-based joints and collisions by default.
6 |
7 | If you would like to preserve the behavior of previous versions of Brax, you can do one of the following:
8 |
9 | 1. Version pin to `0.0.10` – the version right before this upgrade. While you will not get the latest and greatest improvements to Brax, you will have unambiguously consistent behavior.
10 |
11 | 2. Add `dynamics_mode: "legacy_spring"` to your brax configuration file. This makes Brax use the old code path.
12 |
13 | 3. Supply `legacy_spring=True` as a keyword argument when creating an env. This causes Brax to load the older config for all the environments currently defined in Brax (see the logic in the init functions of each env for details).
14 |
15 | Thank you for using Brax, and feel free to open an Issue if you have any questions!
--------------------------------------------------------------------------------
/docs/release-notes/v0.0.12.md:
--------------------------------------------------------------------------------
1 | # Brax v0.0.12 Release Notes
2 |
3 | This release fixes a JavaScript bug that prevented the viewer from rendering.
4 |
--------------------------------------------------------------------------------
/docs/release-notes/v0.0.13.md:
--------------------------------------------------------------------------------
1 | # Brax v0.0.13 Release Notes
2 |
3 | This release fixes a few bugs in the collision handling in PBD, and adds
4 | support for specifying collider visibility, color, and contact participation.
5 |
--------------------------------------------------------------------------------
/docs/release-notes/v0.0.14.md:
--------------------------------------------------------------------------------
1 | # Brax v0.0.14 Release Notes
2 |
3 | This release includes a refactor of the training code to make it more modular and hackable, with each algorithm now as a separate submodule under `brax.training.agents`.
4 |
5 | This release also updates references to the deprecated jax.tree* functions to their new home in jax.tree_util, fixes a few bugs in physics/collision code, and adds an initial implementation of box-box collisions.
6 |
--------------------------------------------------------------------------------
/docs/release-notes/v0.0.15.md:
--------------------------------------------------------------------------------
1 | # Brax v0.0.15 Release Notes
2 |
3 | This release includes a refactor of the training code to make it more modular and hackable, with each algorithm now as a separate submodule under `brax.training.agents`.
4 |
5 | This release also updates references to the deprecated jax.tree* functions to their new home in jax.tree_util, fixes a few bugs in physics/collision code, and adds an initial implementation of box-box collisions.
6 |
--------------------------------------------------------------------------------
/docs/release-notes/v0.0.16.md:
--------------------------------------------------------------------------------
1 | # Brax v0.0.16 Release Notes
2 |
3 | This release adds a new module: `brax.experimental.tracing` that allows for domain randomization during training. This release also adds support for placing replay buffers on device using `pjit` which allows for more configurable parallelism across many devices. Finally this release includes a number of small bug fixes.
4 |
5 | This will be the final release before we release a preview of a significant API change, so users may want to pin to this version if API stability is important.
--------------------------------------------------------------------------------
/docs/release-notes/v0.1.0.md:
--------------------------------------------------------------------------------
1 | # Brax v0.1.0 Release Notes
2 |
3 | This minor release adds a preview of a major overhaul to Brax's API and
4 | functionality. This overhaul (found in the `v2/` folder) will eventually become
5 | Brax's first stable (1.0) release.
6 |
7 | The new features of Brax v2 include:
8 |
9 | * Generalized physics backend.
10 | * Continued support for the Spring physics backends. PBD will soon follow.
11 | * Direct support for MuJoCo XML format, and URDF by association.
12 | * Fully traceable System object.
13 | * Env API that better supports custom physics backends.
14 | * Open sourced visualizer server.
--------------------------------------------------------------------------------
/docs/release-notes/v0.1.1.md:
--------------------------------------------------------------------------------
1 | # Brax v0.1.1 Release Notes
2 |
3 | This patch release includes:
4 | * Contact debugger added to the visualizer.
5 | * Refactor of how state is passed to the v2 visualizer.
6 | * Small clean up to v2/generalized.
7 | * Convex-convex collisions added to v2.
8 | * URDF loader gets mass and revolute joint limits.
--------------------------------------------------------------------------------
/docs/release-notes/v0.1.2.md:
--------------------------------------------------------------------------------
1 | # Brax v0.1.2 Release Notes
2 |
3 | This patch release includes:
4 | * Add positional physics pipeline.
5 | * Refactor spring physics pipeline.
6 | * Add more brax envs.
7 | * Add better documentation for brax v2 and update v2 notebooks.
8 | * Fixes to kinematics.py.
9 | * Add missing dependencies and files to MANIFEST.in
10 | * Add convex-capsule, convex-sphere, and convex-plane collisions.
11 | * Add rgba color to brax visualizer.
--------------------------------------------------------------------------------
/docs/release-notes/v0.10.0.md:
--------------------------------------------------------------------------------
1 | # Brax v0.10.0 Release Notes
2 |
3 | This minor release makes several changes to the brax API, such that [MJX](https://mujoco.readthedocs.io/en/stable/mjx.html) data structures are the core data structures used in brax. This allows for more seamless model loading from `MuJoCo` XMLs, and allows for running `MJX` physics more seamlessly in brax.
4 |
5 | * Rebase brax `System` and `State` onto `mjx.Model` and `mjx.Data`.
6 | * Separate validation logic from the model loading logic in `brax.io.mjcf`. This allows users to load an [MJX](https://mujoco.readthedocs.io/en/stable/mjx.html) model in brax, without hitting validation errors for other physics backends like `positional` and `spring`.
7 | * Remove `System.geoms`, since `brax.System` inherits from `mjx.Model` and all geom information is available in `mjx.Model`. We also update the brax viewer to work with this new schema.
8 | * Delete the brax contact library and use the contact library from `MJX`.
9 | * Use the MuJoCo renderer instead of pytinyrenderer for `brax.io.image`.
10 |
--------------------------------------------------------------------------------
/docs/release-notes/v0.10.1.md:
--------------------------------------------------------------------------------
1 | # Brax v0.10.1 Release Notes
2 |
3 | * Fixes #460, #461, and an issue related to #353.
4 | * Removes the Barkour v0 joystick policy in favor of the [MJX tutorial](https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/mjx/tutorial.ipynb), which trains Barkour vb.
5 | * Fixes #466.
6 |
--------------------------------------------------------------------------------
/docs/release-notes/v0.10.2.md:
--------------------------------------------------------------------------------
1 | # Brax v0.10.2 Release Notes
2 |
3 | - Fix bug in rendering cylinders and planes.
4 | - Fix issue with link offsets in `io.mjcf.load_model`.
--------------------------------------------------------------------------------
/docs/release-notes/v0.10.3.md:
--------------------------------------------------------------------------------
1 | # Brax v0.10.3 Release Notes
2 |
3 | - Fix a bug in rendering capsules and cylinders with the wrong size.
4 |
--------------------------------------------------------------------------------
/docs/release-notes/v0.10.4.md:
--------------------------------------------------------------------------------
1 | # Brax v0.10.4 Release Notes
2 |
3 | - Add support for compressing json embedded in HTML output for large models.
4 | - Remove legacy `dt` field in the brax System. We rely on MuJoCo's `opt.timestep` field instead.
5 | - Add `act` to `pipeline.init` functions. See https://mujoco.readthedocs.io/en/stable/computation/index.html#physics-state.
6 | - Updated basic APG algorithm #476, h/t @Andrew-Luo1.
7 |
--------------------------------------------------------------------------------
/docs/release-notes/v0.10.5.md:
--------------------------------------------------------------------------------
1 | # Brax v0.10.5 Release Notes
2 |
3 | * Modify `policy_params_fn` in the PPO implementation to take in the full model params. This can be used for checkpointing models.
4 | * Add `restore_checkpoint_path` in PPO implementation.
5 |
--------------------------------------------------------------------------------
/docs/release-notes/v0.11.0.md:
--------------------------------------------------------------------------------
1 | # Brax v0.11.0 Release Notes
2 |
3 | * Remove contact debugging from the viewer. This is a breaking change compared to older versions of brax. Old saved files will not render using the new viewer.
4 | * Added `ctrl` as input to the `MJX pipeline.init`.
5 | * Fixes to #513, #512, and #504.
6 |
--------------------------------------------------------------------------------
/docs/release-notes/v0.12.0.md:
--------------------------------------------------------------------------------
1 | # Brax v0.12.0 Release Notes
2 |
3 | * Add boolean `wrap_env` to all brax `train` functions, which optionally wraps the env for training, or uses the env as is.
4 | * Fix bug in PPO train to return loaded checkpoint when `num_timesteps` is 0.
5 | * Add `layer_norm` to `make_q_network` and set `layer_norm` to `True` in the `make_sac_networks` Q Network.
6 | * Change PPO train function to return both value and policy network params, rather than just policy params.
7 | * Merge https://github.com/google/brax/pull/561, adds grad norm clipping to PPO.
8 | * Merge https://github.com/google/brax/issues/477, changes pusher vel damping.
9 | * Merge https://github.com/google/brax/pull/558, adds `mocap_pos` and `mocap_quat` to render function.
10 | * Merge https://github.com/google/brax/pull/559, which allows dictionary observations in the environment `State`.
11 | * Merge https://github.com/google/brax/pull/562, which supports asymmetric actor-critic for PPO.
12 | * Merge https://github.com/google/brax/pull/560, allows PPO from vision.
13 |
--------------------------------------------------------------------------------
/docs/release-notes/v0.12.1.md:
--------------------------------------------------------------------------------
1 | # Brax v0.12.1 Release Notes
2 |
3 | * Add `wrap_env_fn` to training API. This allows users to specify custom wrapping functions for their environments.
4 | * Remove `FrozenDict` in brax env/training API.
5 |
--------------------------------------------------------------------------------
/docs/release-notes/v0.12.2.md:
--------------------------------------------------------------------------------
1 | # Brax v0.12.2 Release Notes
2 |
3 | * See v0.12.3.
--------------------------------------------------------------------------------
/docs/release-notes/v0.12.3.md:
--------------------------------------------------------------------------------
1 | # Brax v0.12.3 Release Notes
2 |
3 | * Add training metrics to brax PPO, which allows users to avoid running evals during training while getting more frequent metric updates (à la RSL-RL). Set `num_evals=0` and `log_training_metrics=True`.
4 | * Add checkpointing directly to brax PPO, rather than relying on the `policy_params_fn` callback.
5 | * Fix bug in inverted pendulum (#574) where the position of the tip was being mis-calculated.
6 | * Add UInt64 to prevent overflow of training steps in brax training. Fixes #578.
7 |
--------------------------------------------------------------------------------
/docs/release-notes/v0.9.0.md:
--------------------------------------------------------------------------------
1 | # Brax v0.9.0 Release Notes
2 |
3 | This patch release moves:
4 | * brax's older API to the `brax.v1` module
5 | * and the `brax.v2` module to `brax`
--------------------------------------------------------------------------------
/docs/release-notes/v0.9.1.md:
--------------------------------------------------------------------------------
1 | # Brax v0.9.1 Release Notes
2 |
3 | This patch release includes:
4 | * Add support for positional actuators.
5 | * Add fluid viscosity + density via box model.
6 | * Adds cylinder collider (but only for wafer-thin cylinders)
7 | * Bring back dm_env and torch env wrappers.
8 | * Bring back image rendering via pytinyrenderer.
9 |
--------------------------------------------------------------------------------
/docs/release-notes/v0.9.2.md:
--------------------------------------------------------------------------------
1 | # Brax v0.9.2 Release Notes
2 |
3 | This patch release:
4 | * Adds domain randomization module
5 | * Adds swimmer env
6 | * Adds a quadruped training example in `brax/experimental` that demonstrates sim2real transfer.
--------------------------------------------------------------------------------
/docs/release-notes/v0.9.3.md:
--------------------------------------------------------------------------------
1 | # Brax v0.9.3 Release Notes
2 |
3 | This patch release:
4 |
5 | * Fixes compatibility issues with MuJoCo 3.0.0
6 | * Fixes bugs with gym and dm envs.
7 |
--------------------------------------------------------------------------------
/docs/release-notes/v0.9.4.md:
--------------------------------------------------------------------------------
1 | # Brax v0.9.4 Release Notes
2 |
3 | * Fixes gradients for the generalized pipeline by changing a `jp.linalg.norm` to `safe_norm`.
4 | * Adds the [MJX](https://mujoco.readthedocs.io/en/stable/mjx.html) pipeline to Brax as well as an MjxEnv for RL training.
5 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "brax"
7 | version = "0.12.3"
8 | description = "A differentiable physics engine written in JAX."
9 | authors = [
10 | { name = "Brax Authors", email = "no-reply@google.com" },
11 | ]
12 | readme = { file = "README.md", content-type = "text/markdown" }
13 | requires-python = ">=3.10"
14 | license = { file = "LICENSE" }
15 | keywords = [
16 | "JAX",
17 | "reinforcement learning",
18 | "rigidbody",
19 | "physics",
20 | ]
21 | classifiers = [
22 | "Development Status :: 4 - Beta",
23 | "Intended Audience :: Developers",
24 | "Intended Audience :: Science/Research",
25 | "License :: OSI Approved :: Apache Software License",
26 | "Programming Language :: Python",
27 | "Programming Language :: Python :: 3.10",
28 | "Programming Language :: Python :: 3.11",
29 | "Programming Language :: Python :: 3.12",
30 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
31 | ]
32 | dependencies = [
33 | "absl-py",
34 | "dataclasses; python_version < '3.7'",
35 | "etils",
36 | "flask",
37 | "flask_cors",
38 | "flax",
39 | "jax>=0.4.6",
40 | "jaxlib>=0.4.6",
41 | "jaxopt",
42 | "jinja2",
43 | "ml_collections",
44 | "mujoco",
45 | "mujoco-mjx",
46 | "numpy",
47 | "optax",
48 | "orbax-checkpoint",
49 | "Pillow",
50 | "scipy",
51 | "tensorboardX",
52 | "trimesh",
53 | "typing-extensions",
54 | ]
55 |
56 | [project.optional-dependencies]
57 | develop = [
58 | "pytest",
59 | "transforms3d",
60 | "gym",
61 | "dm_env",
62 | ]
63 |
64 | [project.urls]
65 | Homepage = "http://github.com/google/brax"
66 |
67 | [tool.hatch.build.targets.wheel]
68 | packages = ["brax"]
69 | exclude = [
70 | "/datasets",
71 | "/docs",
72 | "/notebooks",
73 | "/tests",
74 | "brax/experimental/barkour/assets",
75 | "brax/experimental/barkour/data",
76 | ]
77 |
78 | [tool.hatch.build.targets.sdist]
79 | exclude = [
80 | "/datasets",
81 | "/docs",
82 | "/notebooks",
83 | "brax/experimental/barkour/assets",
84 | "brax/experimental/barkour/data",
85 | ]
86 |
87 | [tool.isort]
88 | force_single_line = true
89 | force_sort_within_sections = true
90 | lexicographical = true
91 | single_line_exclusions = ["typing"]
92 | order_by_type = false
93 | group_by_package = true
94 | line_length = 120
95 | use_parentheses = true
96 | multi_line_output = 3
97 | skip_glob = ["**/*.ipynb"]
98 |
99 | [tool.pyink]
100 | line-length = 80
101 | unstable = true
102 | pyink-indentation = 2
103 | pyink-use-majority-quotes = true
104 | extend-exclude = '''(
105 | .ipynb$
106 | )'''
107 |
--------------------------------------------------------------------------------