├── .github
└── workflows
│ └── ci.yaml
├── .gitignore
├── Cargo.toml
├── LICENSE
├── README.md
├── assets
├── g1-forward.onnx
├── g1-forward.zip
└── unitree_a1
│ ├── LICENSE
│ ├── README.md
│ ├── a1.png
│ ├── a1.xml
│ ├── assets
│ ├── calf.obj
│ ├── hip.obj
│ ├── thigh.obj
│ ├── thigh_mirror.obj
│ ├── trunk.obj
│ └── trunk_A1.png
│ └── scene.xml
├── build.rs
├── examples
├── a1_walk.rs
└── bevy_rl_rest.rs
└── python
├── compare_onnx_pytorch.ipynb
├── policy.py
└── run_onnx_policy.ipynb
/.github/workflows/ci.yaml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | pull_request:
7 | branches: [main]
8 |
9 | env:
10 | CARGO_TERM_COLOR: always
11 |
12 | jobs:
13 | # Run cargo test
14 | test:
15 | name: Build
16 | runs-on: ubuntu-latest
17 | steps:
18 | - name: Checkout sources
19 | uses: actions/checkout@v2
20 | - name: Initialize submodules
21 | run: git submodule init && git submodule update
22 | - name: Cache
23 | uses: actions/cache@v2
24 | with:
25 | path: |
26 | ~/.cargo/bin/
27 | ~/.cargo/registry/index/
28 | ~/.cargo/registry/cache/
29 | ~/.cargo/git/db/
30 | target/
31 | key: ${{ runner.os }}-cargo-test-${{ hashFiles('**/Cargo.toml') }}
32 | - name: Install stable toolchain
33 | uses: actions-rs/toolchain@v1
34 | with:
35 | profile: minimal
36 | toolchain: nightly
37 | override: true
38 | - name: Install Dependencies
39 | run: sudo apt-get update; sudo apt-get install pkg-config libx11-dev libasound2-dev libudev-dev libxcb-render0-dev libxcb-shape0-dev libxcb-xfixes0-dev
40 | - name: Build all examples
41 | uses: actions-rs/cargo@v1
42 | with:
43 | command: build
44 |
45 | # Run cargo clippy -- -D warnings
46 | clippy_check:
47 | name: Clippy
48 | runs-on: ubuntu-latest
49 | steps:
50 | - name: Checkout sources
51 | uses: actions/checkout@v2
52 | - name: Cache
53 | uses: actions/cache@v2
54 | with:
55 | path: |
56 | ~/.cargo/bin/
57 | ~/.cargo/registry/index/
58 | ~/.cargo/registry/cache/
59 | ~/.cargo/git/db/
60 | target/
61 | key: ${{ runner.os }}-cargo-clippy-${{ hashFiles('**/Cargo.toml') }}
62 | - name: Install stable toolchain
63 | uses: actions-rs/toolchain@v1
64 | with:
65 | toolchain: nightly
66 | profile: minimal
67 | components: clippy
68 | override: true
69 | - name: Install Dependencies
70 | run: sudo apt-get update; sudo apt-get install pkg-config libx11-dev libasound2-dev libudev-dev
71 | - name: Run clippy
72 | uses: actions-rs/clippy-check@v1
73 | with:
74 | token: ${{ secrets.GITHUB_TOKEN }}
75 | args: -- -D warnings
76 |
77 | # Run cargo fmt --all -- --check
78 | format:
79 | name: Format
80 | runs-on: ubuntu-latest
81 | steps:
82 | - name: Checkout sources
83 | uses: actions/checkout@v2
84 | - name: Install stable toolchain
85 | uses: actions-rs/toolchain@v1
86 | with:
87 | toolchain: nightly
88 | profile: minimal
89 | components: rustfmt
90 | override: true
91 | - name: Run cargo fmt
92 | uses: actions-rs/cargo@v1
93 | with:
94 | command: fmt
95 | args: --all -- --check
96 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /target
2 |
--------------------------------------------------------------------------------
/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "bevy_quadruped_neural_control"
3 | version = "0.0.3"
4 | edition = "2021"
5 | license = "MIT OR Apache-2.0"
6 | repository = "https://github.com/stillonearth/bevy_quadruped_neural_control"
7 |
8 | [dependencies]
9 | bevy = "0.12"
10 | bevy_rl = { version = "0.12" } # { version = "0.9.6" }
11 | bevy_mujoco = { version = "0.12" }
12 | # bevy-inspector-egui = { version = "0.18.1" }
13 | bevy_flycam = { git = "https://github.com/sburris0/bevy_flycam" }
14 | rand = "0.8.5"
15 | serde = "1.0.158"
16 | serde_derive = "1.0.158"
17 | serde_json = "1.0.94"
18 | tract-onnx = { version = "0.19.7" }
19 | lazy_static = "1.4.0"
20 |
21 | [profile.dev]
22 | opt-level = 3
23 |
24 | [[example]]
25 | name = "a1_walk"
26 |
27 | [[example]]
28 | name = "bevy_rl_rest"
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Bevy is dual-licensed under either
2 |
3 | * MIT License (docs/LICENSE-MIT or http://opensource.org/licenses/MIT)
4 | * Apache License, Version 2.0 (docs/LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
5 |
6 | at your option.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Bevy Quadruped Neural Control
2 |
3 | [](https://github.com/bevyengine/bevy#license)
4 | [](https://github.com/stillonearth/bevy_quadruped_neural_control/actions)
5 |
6 | Control quadruped robot in simulation using neural network. These demos use [bevy_mujoco](https://github.com/stillonearth/bevy_mujoco) (mujoco physics for bevy) and [sonos/tract](https://github.com/sonos/tract) (invoke neural networks in Rust) to make an environment with neurally-controlled quadruped robot Unitree A1.
7 |
8 | https://user-images.githubusercontent.com/97428129/210613348-82a5e59d-96af-42a9-a94a-c47093eb8297.mp4
9 |
10 | | Example | Description |
11 | | ------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
12 | | [examples/bevy_rl_rest.rs](https://github.com/stillonearth/bevy_quadruped_neural_control/blob/main/examples/bevy_rl_rest.rs) | Runs Unitree A1 simulation wrapped as a Reinforcement Learning Gym environment with [bevy_rl](https://github.com/stillonearth/bevy_rl). It also runs a REST API so you can control a robot from another environment such as python. |
13 | | [python/policy.py](https://github.com/stillonearth/bevy_quadruped_neural_control/blob/main/python/policy.py) | An example how to control a robot with trained [stable_baselines3/SAC](https://stable-baselines3.readthedocs.io/en/master/modules/sac.html) policy from python env |
14 | | [python/run_onnx_policy.ipynb](https://github.com/stillonearth/bevy_quadruped_neural_control/blob/main/python/run_onnx_policy.ipynb) | Exports PyTorch stable_baselines3 SAC policy to [Open Neural Network Exchange](https://onnx.ai/) format and runs the policy from python env |
15 | | [examples/a1_walk.rs](https://github.com/stillonearth/bevy_quadruped_neural_control/blob/main/examples/a1_walk.rs) | Runs Unitree A1 simulation and ONNX neural network to control it all in Rust env |
16 |
17 | Details on control system synthesis [here](https://github.com/stillonearth/continuous_control-unitree-a1).
18 |
19 | ## Running Neural Network in Rust Environment with Sonos Tract
20 |
21 | ```bash
22 | cargo run --example a1_walk
23 | ```
24 |
25 | ## Running Rust Simulator and Neural Networks in Python
26 |
27 | ```bash
28 | cargo run --example bevy_rl_rest &
29 | cd python && python policy.py
30 | ```
31 |
--------------------------------------------------------------------------------
/assets/g1-forward.onnx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stillonearth/bevy_quadruped_neural_control/e1535d0029c35b6cbd364d0171482d07bfa1cfb8/assets/g1-forward.onnx
--------------------------------------------------------------------------------
/assets/g1-forward.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stillonearth/bevy_quadruped_neural_control/e1535d0029c35b6cbd364d0171482d07bfa1cfb8/assets/g1-forward.zip
--------------------------------------------------------------------------------
/assets/unitree_a1/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2016-2022 HangZhou YuShu TECHNOLOGY CO.,LTD. ("Unitree Robotics")
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | * Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | * Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | * Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/assets/unitree_a1/README.md:
--------------------------------------------------------------------------------
1 | # Unitree A1 Description (MJCF)
2 |
3 | ## Overview
4 |
5 | This package contains a simplified robot description (MJCF) of the [A1 Quadruped
6 | Robot](https://www.unitree.com/products/a1/) developed by [Unitree
7 | Robotics](https://www.unitree.com/). It is derived from the [publicly available
8 | URDF
9 | description](https://github.com/unitreerobotics/unitree_mujoco/tree/main/data/a1/urdf).
10 |
11 |
12 |
13 |
14 |
15 | ## URDF → MJCF derivation steps
16 |
17 | 1. Converted the DAE [mesh
18 | files](https://github.com/unitreerobotics/unitree_mujoco/tree/main/data/a1/meshes)
19 | to OBJ format using [Blender](https://www.blender.org/).
20 | 2. Processed `.obj` files with [`obj2mjcf`](https://github.com/kevinzakka/obj2mjcf).
21 | 3. Added ` ` to the URDF's
22 | `` clause in order to preserve visual geometries.
23 | 4. Loaded the URDF into MuJoCo and saved a corresponding MJCF.
24 | 5. Added a `` to the base, and a tracking light.
25 | 6. Manually edited the MJCF to extract common properties into the `` section.
26 | 7. Manually designed collision geometries.
27 | 8. Shifted joint reference values and ranges so that 0 corresponds to standing pose.
28 | 9. Softened the contacts of the feet to approximate the effect of rubber and
29 | increased `impratio` to reduce slippage.
30 | 10. Added `scene.xml` which includes the robot, with a textured groundplane, skybox, and haze.
31 |
32 | ## License
33 |
34 | This model is released under a [BSD-3-Clause License](LICENSE).
35 |
--------------------------------------------------------------------------------
/assets/unitree_a1/a1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stillonearth/bevy_quadruped_neural_control/e1535d0029c35b6cbd364d0171482d07bfa1cfb8/assets/unitree_a1/a1.png
--------------------------------------------------------------------------------
/assets/unitree_a1/a1.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
--------------------------------------------------------------------------------
/assets/unitree_a1/assets/trunk_A1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stillonearth/bevy_quadruped_neural_control/e1535d0029c35b6cbd364d0171482d07bfa1cfb8/assets/unitree_a1/assets/trunk_A1.png
--------------------------------------------------------------------------------
/assets/unitree_a1/scene.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
16 |
18 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
--------------------------------------------------------------------------------
/build.rs:
--------------------------------------------------------------------------------
1 | use std::env;
2 | use std::fs;
3 | use std::path::Path;
4 | use std::path::PathBuf;
5 | use std::str::FromStr;
6 |
/// Returns `<CARGO_MANIFEST_DIR>/target/<PROFILE>` — the directory Cargo
/// places final build artifacts in for the active profile.
///
/// Panics if `CARGO_MANIFEST_DIR` or `PROFILE` is unset; Cargo always
/// provides both when running a build script.
fn get_output_path() -> PathBuf {
    let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
    let profile = env::var("PROFILE").unwrap();
    // Return the joined path directly (avoids clippy::let_and_return).
    Path::new(&manifest_dir).join("target").join(profile)
}
15 |
16 | fn main() {
17 | let (_, _, default_install) = match env::var("CARGO_CFG_UNIX") {
18 | Ok(_) => ("", "", ""),
19 | _ => match env::var("CARGO_CFG_WINDOWS") {
20 | Ok(_) => ("", "dll", "C:\\Program Files\\MuJoCo"),
21 | _ => ("", "", ""),
22 | },
23 | };
24 |
25 | if option_env!("DOCS_RS").is_none() {
26 | let mj_root = match (env::var("MUJOCO_DIR"), env::var("MUJOCO_PREFIX")) {
27 | (Ok(dir), _) | (Err(..), Ok(dir)) => dir,
28 | (Err(..), Err(..)) => default_install.to_string(),
29 | };
30 | let mj_root = PathBuf::from_str(&mj_root).expect("Unable to get path");
31 | let mj_lib_windows = mj_root.join("bin");
32 |
33 | // Copy mujoco.dll to target directory on Windows targets
34 | if env::var("CARGO_CFG_WINDOWS").is_ok() {
35 | let target_dir = get_output_path();
36 | let src = Path::join(
37 | &env::current_dir().unwrap(),
38 | mj_lib_windows.join("mujoco.dll"),
39 | );
40 |
41 | fs::create_dir_all(&target_dir).unwrap();
42 | let dest = Path::join(Path::new(&target_dir), Path::new("mujoco.dll"));
43 | eprintln!("Copying {:?} to {:?}", src, dest);
44 | std::fs::copy(src, dest).unwrap();
45 | }
46 | }
47 | }
48 |
--------------------------------------------------------------------------------
/examples/a1_walk.rs:
--------------------------------------------------------------------------------
1 | #![allow(clippy::type_complexity)]
2 | #![allow(clippy::approx_constant)]
3 | #![feature(allocator_api)]
4 | // This example runs a simple policy on the A1 robot. The policy is a feedforward neural network
5 | // It uses tract_onnx to load the model and run it
6 |
7 | use std::alloc::Global;
8 |
9 | use bevy::prelude::*;
10 | use bevy_flycam::*;
11 | use bevy_mujoco::*;
12 |
13 | use tract_ndarray::Array2;
14 | use tract_onnx::prelude::*;
15 |
16 | use lazy_static::lazy_static;
17 |
18 | lazy_static! {
19 | static ref MODEL: SimplePlan, Graph>> =
20 | tract_onnx::onnx()
21 | .model_for_path("assets/g1-forward.onnx")
22 | .unwrap()
23 | .with_input_fact(0, f32::fact([1, 119]).into())
24 | .unwrap()
25 | .into_optimized()
26 | .unwrap()
27 | .into_runnable()
28 | .unwrap();
29 | }
30 |
31 | fn setup(mut commands: Commands) {
32 | commands.spawn(PointLightBundle {
33 | point_light: PointLight {
34 | intensity: 9000.0,
35 | range: 100.,
36 | shadows_enabled: false,
37 | ..default()
38 | },
39 | transform: Transform::from_xyz(8.0, 16.0, 8.0),
40 | ..default()
41 | });
42 |
43 | commands
44 | .spawn(Camera3dBundle {
45 | transform: Transform::from_xyz(0.0, 2.0, 2.0).looking_at(Vec3::ZERO, Vec3::Y),
46 | ..default()
47 | })
48 | .insert(FlyCam);
49 | }
50 |
// These numbers aren't exported to onnx, it's action scaling coefficients from original python code
// https://github.com/DLR-RM/stable-baselines3/blob/4fa17dcf0f72455aa3d36308291d4b052b2544f7/stable_baselines3/common/policies.py#L263
// can be obtained from `python/compare_onnx_pytorch.ipynb`
/// Per-joint lower bounds of the action space (12 actuators; the same
/// 3-value pattern repeats once per leg).
const LOW: [f32; 12] = [
    -0.802851, -1.0472, -2.69653, -0.802851, -1.0472, -2.69653, -0.802851, -1.0472, -2.69653,
    -0.802851, -1.0472, -2.69653,
];

/// Per-joint upper bounds of the action space, paired element-wise with `LOW`.
const HIGH: [f32; 12] = [
    0.802851, 4.18879, -0.916298, 0.802851, 4.18879, -0.916298, 0.802851, 4.18879, -0.916298,
    0.802851, 4.18879, -0.916298,
];
63 |
64 | fn robot_control_loop(mut mujoco_resources: ResMut) {
65 | // prepare simulation data for the NN
66 | let qpos = mujoco_resources.state.qpos.clone();
67 | let qvel = mujoco_resources.state.qvel.clone();
68 | let cfrc_ext = mujoco_resources.state.cfrc_ext.clone();
69 |
70 | // make an input vector for a neural network
71 | let mut input_vec: Vec = Vec::new();
72 | for value in qpos.iter().skip(2) {
73 | input_vec.push(*value as f32);
74 | }
75 | for value in qvel.iter() {
76 | input_vec.push(*value as f32);
77 | }
78 | for value in cfrc_ext.iter() {
79 | input_vec.push(value[0] as f32);
80 | input_vec.push(value[1] as f32);
81 | input_vec.push(value[2] as f32);
82 | input_vec.push(value[3] as f32);
83 | input_vec.push(value[4] as f32);
84 | input_vec.push(value[5] as f32);
85 | }
86 |
87 | // convert this to a tensor
88 | let input: Tensor = Array2::from_shape_vec((1, 119), input_vec).unwrap().into();
89 | // get model execution result
90 | let result = MODEL.run(tvec!(input.into())).unwrap();
91 | // extract model output
92 | let actions = result[0].to_array_view::().unwrap();
93 | // prepare control vector fo passing to mujoco
94 | let mut control: Vec = vec![0.0; mujoco_resources.control.number_of_controls];
95 | // fill control vector with actions (copy and adjust model output)
96 | for i in 0..mujoco_resources.control.number_of_controls {
97 | control[i] = actions[[0, i]] as f64;
98 | // scaling actions
99 | control[i] = LOW[i] as f64 + 0.5 * (control[i] + 1.0) * (HIGH[i] as f64 - LOW[i] as f64);
100 | }
101 |
102 | mujoco_resources.control.data = control;
103 | }
104 |
/// Entry point: wires Bevy, bevy_mujoco and the fly-cam together and
/// registers the neural-network control loop as an `Update` system.
fn main() {
    App::new()
        .add_plugins(DefaultPlugins)
        // MuJoCo physics: loads the Unitree A1 scene
        .insert_resource(MuJoCoPluginSettings {
            model_xml_path: "assets/unitree_a1/scene.xml".to_string(),
            pause_simulation: false,
            target_fps: 600.0, // this is not actual fps (bug in bevy_mujoco),
            // the bigger the value, the slower the simulation
        })
        .add_plugins(NoCameraPlayerPlugin)
        .insert_resource(MovementSettings {
            speed: 1.0,
            ..default()
        })
        .add_plugins(MuJoCoPlugin)
        .add_systems(Startup, setup)
        .add_systems(Update, robot_control_loop)
        .run();
}
124 |
--------------------------------------------------------------------------------
/examples/bevy_rl_rest.rs:
--------------------------------------------------------------------------------
1 | #![allow(clippy::needless_range_loop)]
2 | #![allow(clippy::approx_constant)]
3 | #![allow(clippy::type_complexity)]
4 |
5 | //! This example shows how to use the REST API to control the simulation
6 | //! It uses the bevy_rl crate to provide a REST API
7 | use bevy::prelude::*;
8 | use bevy_flycam::*;
9 | use bevy_mujoco::*;
10 |
11 | use bevy_rl::{AIGymPlugin, AIGymSettings, AIGymState, EventControl, EventPause, SimulationState};
12 | use serde::{Deserialize, Serialize};
13 |
14 | #[derive(Default, Deref, DerefMut, Clone, Deserialize)]
15 | pub struct Actions(Vec);
16 |
17 | // Observation space
18 |
19 | #[derive(Default, Deref, DerefMut, Serialize, Clone)]
20 | pub struct EnvironmentState(MuJoCoState);
21 |
22 | fn setup(mut commands: Commands) {
23 | commands.spawn(PointLightBundle {
24 | point_light: PointLight {
25 | intensity: 9000.0,
26 | range: 100.,
27 | shadows_enabled: false,
28 | ..default()
29 | },
30 | transform: Transform::from_xyz(8.0, 16.0, 8.0),
31 | ..default()
32 | });
33 |
34 | commands
35 | .spawn(Camera3dBundle {
36 | transform: Transform::from_xyz(0.0, 2.0, 2.0).looking_at(Vec3::ZERO, Vec3::Y),
37 | ..default()
38 | })
39 | .insert(FlyCam);
40 | }
41 |
42 | fn bevy_rl_pause_request(
43 | mut pause_event_reader: EventReader,
44 | mut mujoco_settings: ResMut,
45 | mujoco_resources: Res,
46 | ai_gym_state: Res>,
47 | ) {
48 | for _ in pause_event_reader.iter() {
49 | // Pause physics engine
50 | mujoco_settings.pause_simulation = true;
51 | // Collect state into serializable struct
52 | let env_state = EnvironmentState(mujoco_resources.state.clone());
53 | // Set bevy_rl gym state
54 | let mut ai_gym_state = ai_gym_state.lock().unwrap();
55 | ai_gym_state.set_env_state(env_state);
56 | }
57 | }
58 |
59 | #[allow(unused_must_use)]
60 | fn bevy_rl_control_request(
61 | mut pause_event_reader: EventReader,
62 | mut mujoco_settings: ResMut,
63 | mut mujoco_resources: ResMut,
64 | mut simulation_state: ResMut>,
65 | ) {
66 | for control in pause_event_reader.iter() {
67 | println!("Control request received");
68 | let unparsed_actions = &control.0;
69 | for i in 0..unparsed_actions.len() {
70 | if let Some(unparsed_action) = unparsed_actions[i].clone() {
71 | let action: Vec = serde_json::from_str(&unparsed_action).unwrap();
72 | mujoco_resources.control.data = action;
73 | }
74 | }
75 | // Resume simulation
76 | mujoco_settings.pause_simulation = false;
77 | simulation_state.set(SimulationState::Running);
78 | }
79 | }
80 |
81 | fn main() {
82 | let mut app = App::new();
83 |
84 | // Basic bevy setup
85 | app.add_plugins(DefaultPlugins)
86 | .add_plugins(NoCameraPlayerPlugin)
87 | .insert_resource(MovementSettings {
88 | speed: 3.0,
89 | ..default()
90 | })
91 | .add_systems(Startup, setup);
92 |
93 | // Setup bevy_mujoco
94 | app.insert_resource(MuJoCoPluginSettings {
95 | model_xml_path: "assets/unitree_a1/scene.xml".to_string(),
96 | pause_simulation: false,
97 | target_fps: 600.0,
98 | })
99 | .add_plugins(MuJoCoPlugin);
100 |
101 | // Setup bevy_rl
102 | let ai_gym_state = AIGymState::::new(AIGymSettings {
103 | num_agents: 1,
104 | render_to_buffer: false,
105 | pause_interval: 0.01,
106 | ..default()
107 | });
108 | app.insert_resource(ai_gym_state)
109 | .add_plugins(AIGymPlugin::::default());
110 |
111 | // bevy_rl events
112 | app.add_systems(Update, bevy_rl_pause_request);
113 | app.add_systems(Update, bevy_rl_control_request);
114 |
115 | // Start
116 | app.run();
117 | }
118 |
--------------------------------------------------------------------------------
/python/compare_onnx_pytorch.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 36,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "import numpy as np\n",
11 | "import onnxruntime as ort\n",
12 | "from stable_baselines3 import SAC"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 37,
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "torch_model = SAC.load('./../assets/g1-forward')"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 38,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "input = np.array([0.0] * 119).reshape(1, 119).astype(np.float32)"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": 51,
36 | "metadata": {},
37 | "outputs": [
38 | {
39 | "data": {
40 | "text/plain": [
41 | "array([-5232.02768734, -9256.49316448, -9769.03249478, 4742.5439958 ,\n",
42 | " 4766.3663854 , -9727.51090693, -7289.71801066, -4803.99115235,\n",
43 | " -9639.31192291, 7165.20577157, -4364.83300203, -9530.80311477])"
44 | ]
45 | },
46 | "execution_count": 51,
47 | "metadata": {},
48 | "output_type": "execute_result"
49 | }
50 | ],
51 | "source": [
52 | "torch_mean = np.zeros(12)\n",
53 | "\n",
54 | "n_iter = 10000\n",
55 | "for i in range(n_iter):\n",
56 | " torch_mean += torch_model.policy.predict(torch.from_numpy(input).float())[0][0]\n",
57 | "\n",
58 | "torch_mean/n_iter\n",
59 | "torch_mean"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": 52,
65 | "metadata": {},
66 | "outputs": [
67 | {
68 | "data": {
69 | "text/plain": [
70 | "Box([-0.802851 -1.0472 -2.69653 -0.802851 -1.0472 -2.69653 -0.802851\n",
71 | " -1.0472 -2.69653 -0.802851 -1.0472 -2.69653 ], [ 0.802851 4.18879 -0.916298 0.802851 4.18879 -0.916298 0.802851\n",
72 | " 4.18879 -0.916298 0.802851 4.18879 -0.916298], (12,), float32)"
73 | ]
74 | },
75 | "execution_count": 52,
76 | "metadata": {},
77 | "output_type": "execute_result"
78 | }
79 | ],
80 | "source": [
81 | "torch_model.policy.action_space"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": null,
87 | "metadata": {},
88 | "outputs": [],
89 | "source": [
90 | "def scale_action(scaled_action ,action_space):\n",
91 | " low, high = action_space.low, action_space.high\n",
92 | " return low + (0.5 * (scaled_action + 1.0) * (high - low))"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 46,
98 | "metadata": {},
99 | "outputs": [
100 | {
101 | "data": {
102 | "text/plain": [
103 | "array([-6514.36762797, -9523.70335982, 9328.56398707, 5927.71452494,\n",
104 | " -4158.10327213, 9337.60048291, -9094.07935127, -7836.99624672,\n",
105 | " 9462.4126415 , 8929.80290843, -7671.65321006, 9574.82129016])"
106 | ]
107 | },
108 | "execution_count": 46,
109 | "metadata": {},
110 | "output_type": "execute_result"
111 | }
112 | ],
113 | "source": [
114 | "onnx_mean = np.zeros(12)\n",
115 | "ort_sess = ort.InferenceSession('./../assets/g1-forward.onnx')\n",
116 | "\n",
117 | "n_iter = 10000\n",
118 | "for i in range(n_iter):\n",
119 | " onnx_mean += ort_sess.run(None, {'obs': input})[0][0]\n",
120 | "\n",
121 | "onnx_mean/n_iter\n",
122 | "onnx_mean"
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": null,
128 | "metadata": {},
129 | "outputs": [],
130 | "source": []
131 | }
132 | ],
133 | "metadata": {
134 | "kernelspec": {
135 | "display_name": "torch",
136 | "language": "python",
137 | "name": "python3"
138 | },
139 | "language_info": {
140 | "codemirror_mode": {
141 | "name": "ipython",
142 | "version": 3
143 | },
144 | "file_extension": ".py",
145 | "mimetype": "text/x-python",
146 | "name": "python",
147 | "nbconvert_exporter": "python",
148 | "pygments_lexer": "ipython3",
149 | "version": "3.10.8"
150 | },
151 | "orig_nbformat": 4,
152 | "vscode": {
153 | "interpreter": {
154 | "hash": "82ea6adb180b12ed72836e614d5d57295654ca2a9780d621124b81b6a9baa809"
155 | }
156 | }
157 | },
158 | "nbformat": 4,
159 | "nbformat_minor": 2
160 | }
161 |
--------------------------------------------------------------------------------
/python/policy.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import json
3 | import numpy as np
4 |
5 | from stable_baselines3 import SAC
6 |
7 |
8 | API_STEP = 'http://127.0.0.1:7878/step'
9 | API_STATE = 'http://127.0.0.1:7878/state'
10 |
def send_action(action):
    """Submit one action vector to the simulator's /step endpoint.

    The API expects a JSON list of per-agent dicts, each carrying its own
    JSON-encoded action under the "action" key.
    """
    body = [{"action": json.dumps(action)}]
    payload = json.dumps(body, indent=4)
    requests.get(API_STEP, params={'payload': payload})
14 |
def get_obs():
    """Fetch the current simulator state and build the observation vector."""
    state = requests.get(API_STATE).json()
    # Flatten each state component to a 1-D copy.
    qpos, qvel, cfrc_ext = (
        np.array(state[key]).flat.copy() for key in ('qpos', 'qvel', 'cfrc_ext')
    )

    # The first two qpos entries are dropped, mirroring the observation
    # built in examples/a1_walk.rs (`qpos.iter().skip(2)`).
    return np.concatenate([qpos[2:], qvel, cfrc_ext])
22 |
# Load the pretrained stable-baselines3 SAC policy (the .zip archive that
# sits next to the ONNX export in assets/).
model = SAC.load('./../assets/g1-forward')

# Closed-loop control against the running Rust simulator: observe via the
# REST API, predict an action, send it back. Runs until interrupted.
while True:
    obs = get_obs()
    action = model.predict(obs)[0].tolist()

    send_action(action)
30 |
--------------------------------------------------------------------------------
/python/run_onnx_policy.ipynb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stillonearth/bevy_quadruped_neural_control/e1535d0029c35b6cbd364d0171482d07bfa1cfb8/python/run_onnx_policy.ipynb
--------------------------------------------------------------------------------