├── .github └── workflows │ └── ci.yaml ├── .gitignore ├── Cargo.toml ├── LICENSE ├── README.md ├── assets ├── g1-forward.onnx ├── g1-forward.zip └── unitree_a1 │ ├── LICENSE │ ├── README.md │ ├── a1.png │ ├── a1.xml │ ├── assets │ ├── calf.obj │ ├── hip.obj │ ├── thigh.obj │ ├── thigh_mirror.obj │ ├── trunk.obj │ └── trunk_A1.png │ └── scene.xml ├── build.rs ├── examples ├── a1_walk.rs └── bevy_rl_rest.rs └── python ├── compare_onnx_pytorch.ipynb ├── policy.py └── run_onnx_policy.ipynb /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | # Run cargo test 14 | test: 15 | name: Build 16 | runs-on: ubuntu-latest 17 | steps: 18 | - name: Checkout sources 19 | uses: actions/checkout@v2 20 | - name: Initalize submobules 21 | run: git submodule init && git submodule update 22 | - name: Cache 23 | uses: actions/cache@v2 24 | with: 25 | path: | 26 | ~/.cargo/bin/ 27 | ~/.cargo/registry/index/ 28 | ~/.cargo/registry/cache/ 29 | ~/.cargo/git/db/ 30 | target/ 31 | key: ${{ runner.os }}-cargo-test-${{ hashFiles('**/Cargo.toml') }} 32 | - name: Install stable toolchain 33 | uses: actions-rs/toolchain@v1 34 | with: 35 | profile: minimal 36 | toolchain: nightly 37 | override: true 38 | - name: Install Dependencies 39 | run: sudo apt-get update; sudo apt-get install pkg-config libx11-dev libasound2-dev libudev-dev libxcb-render0-dev libxcb-shape0-dev libxcb-xfixes0-dev 40 | - name: Build all examples 41 | uses: actions-rs/cargo@v1 42 | with: 43 | command: build 44 | 45 | # Run cargo clippy -- -D warnings 46 | clippy_check: 47 | name: Clippy 48 | runs-on: ubuntu-latest 49 | steps: 50 | - name: Checkout sources 51 | uses: actions/checkout@v2 52 | - name: Cache 53 | uses: actions/cache@v2 54 | with: 55 | path: | 56 | ~/.cargo/bin/ 57 | ~/.cargo/registry/index/ 58 | 
~/.cargo/registry/cache/ 59 | ~/.cargo/git/db/ 60 | target/ 61 | key: ${{ runner.os }}-cargo-clippy-${{ hashFiles('**/Cargo.toml') }} 62 | - name: Install stable toolchain 63 | uses: actions-rs/toolchain@v1 64 | with: 65 | toolchain: nightly 66 | profile: minimal 67 | components: clippy 68 | override: true 69 | - name: Install Dependencies 70 | run: sudo apt-get update; sudo apt-get install pkg-config libx11-dev libasound2-dev libudev-dev 71 | - name: Run clippy 72 | uses: actions-rs/clippy-check@v1 73 | with: 74 | token: ${{ secrets.GITHUB_TOKEN }} 75 | args: -- -D warnings 76 | 77 | # Run cargo fmt --all -- --check 78 | format: 79 | name: Format 80 | runs-on: ubuntu-latest 81 | steps: 82 | - name: Checkout sources 83 | uses: actions/checkout@v2 84 | - name: Install stable toolchain 85 | uses: actions-rs/toolchain@v1 86 | with: 87 | toolchain: nightly 88 | profile: minimal 89 | components: rustfmt 90 | override: true 91 | - name: Run cargo fmt 92 | uses: actions-rs/cargo@v1 93 | with: 94 | command: fmt 95 | args: --all -- --check 96 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "bevy_quadruped_neural_control" 3 | version = "0.0.3" 4 | edition = "2021" 5 | license = "MIT OR Apache-2.0" 6 | repository = "https://github.com/stillonearth/bevy_quadruped_neural_control" 7 | 8 | [dependencies] 9 | bevy = "0.12" 10 | bevy_rl = { version = "0.12" } # { version = "0.9.6" } 11 | bevy_mujoco = { version = "0.12" } 12 | # bevy-inspector-egui = { version = "0.18.1" } 13 | bevy_flycam = { git = "https://github.com/sburris0/bevy_flycam" } 14 | rand = "0.8.5" 15 | serde = "1.0.158" 16 | serde_derive = "1.0.158" 17 | serde_json = "1.0.94" 18 
| tract-onnx = { version = "0.19.7" } 19 | lazy_static = "1.4.0" 20 | 21 | [profile.dev] 22 | opt-level = 3 23 | 24 | [[example]] 25 | name = "a1_walk" 26 | 27 | [[example]] 28 | name = "bevy_rl_rest" 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Bevy is dual-licensed under either 2 | 3 | * MIT License (docs/LICENSE-MIT or http://opensource.org/licenses/MIT) 4 | * Apache License, Version 2.0 (docs/LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0) 5 | 6 | at your option. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bevy Quadruped Neural Control 2 | 3 | [![MIT/Apache 2.0](https://img.shields.io/badge/license-MIT%2FApache-blue.svg)](https://github.com/bevyengine/bevy#license) 4 | [![Rust](https://github.com/stillonearth/bevy_quadruped_neural_control/workflows/CI/badge.svg)](https://github.com/stillonearth/bevy_quadruped_neural_control/actions) 5 | 6 | Control a quadruped robot in simulation using a neural network. These demos use [bevy_mujoco](https://github.com/stillonearth/bevy_mujoco) (MuJoCo physics for Bevy) and [sonos/tract](https://github.com/sonos/tract) (invoke neural networks in Rust) to make an environment with the neurally-controlled quadruped robot Unitree A1. 
7 | 8 | https://user-images.githubusercontent.com/97428129/210613348-82a5e59d-96af-42a9-a94a-c47093eb8297.mp4 9 | 10 | | Example | Description | 11 | | ------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | 12 | | [examples/bevy_rl_rest.rs](https://github.com/stillonearth/bevy_quadruped_neural_control/blob/main/examples/bevy_rl_rest.rs) | Runs Unitree A1 simulation wrapped as a Reinforcement Learning Gym environment with [bevy_rl](https://github.com/stillonearth/bevy_rl). It also runs a REST API so you can control the robot from another environment such as Python. | 13 | | [python/policy.py](https://github.com/stillonearth/bevy_quadruped_neural_control/blob/main/python/policy.py) | An example of how to control a robot with a trained [stable_baselines3/SAC](https://stable-baselines3.readthedocs.io/en/master/modules/sac.html) policy from python env | 14 | | [python/run_onnx_policy.ipynb](https://github.com/stillonearth/bevy_quadruped_neural_control/blob/main/python/run_onnx_policy.ipynb) | Exports PyTorch stable_baselines3 SAC policy to [Open Neural Network Exchange](https://onnx.ai/) format and runs the policy from python env | 15 | | [examples/a1_walk.rs](https://github.com/stillonearth/bevy_quadruped_neural_control/blob/main/examples/a1_walk.rs) | Runs Unitree A1 simulation and ONNX neural network to control it all in Rust env | 16 | 17 | Details on control system synthesis [here](https://github.com/stillonearth/continuous_control-unitree-a1). 
18 | 19 | ## Running Neural Network in Rust Environment with Sonos Tract 20 | 21 | ```bash 22 | cargo run --example a1_walk 23 | ``` 24 | 25 | ## Running Rust Simulator and Neural Networks in Python 26 | 27 | ```bash 28 | cargo run --example bevy_rl_rest & 29 | cd python && python policy.py 30 | ``` 31 | -------------------------------------------------------------------------------- /assets/g1-forward.onnx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stillonearth/bevy_quadruped_neural_control/e1535d0029c35b6cbd364d0171482d07bfa1cfb8/assets/g1-forward.onnx -------------------------------------------------------------------------------- /assets/g1-forward.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stillonearth/bevy_quadruped_neural_control/e1535d0029c35b6cbd364d0171482d07bfa1cfb8/assets/g1-forward.zip -------------------------------------------------------------------------------- /assets/unitree_a1/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2016-2022 HangZhou YuShu TECHNOLOGY CO.,LTD. ("Unitree Robotics") 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 
15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /assets/unitree_a1/README.md: -------------------------------------------------------------------------------- 1 | # Unitree A1 Description (MJCF) 2 | 3 | ## Overview 4 | 5 | This package contains a simplified robot description (MJCF) of the [A1 Quadruped 6 | Robot](https://www.unitree.com/products/a1/) developed by [Unitree 7 | Robotics](https://www.unitree.com/). It is derived from the [publicly available 8 | URDF 9 | description](https://github.com/unitreerobotics/unitree_mujoco/tree/main/data/a1/urdf). 10 | 11 |

12 | 13 |

14 | 15 | ## URDF → MJCF derivation steps 16 | 17 | 1. Converted the DAE [mesh 18 | files](https://github.com/unitreerobotics/unitree_mujoco/tree/main/data/a1/meshes) 19 | to OBJ format using [Blender](https://www.blender.org/). 20 | 2. Processed `.obj` files with [`obj2mjcf`](https://github.com/kevinzakka/obj2mjcf). 21 | 3. Added `<mujoco> <compiler discardvisual="false"/> </mujoco>` to the URDF's 22 | `<robot>` clause in order to preserve visual geometries. 23 | 4. Loaded the URDF into MuJoCo and saved a corresponding MJCF. 24 | 5. Added a `<freejoint/>` to the base, and a tracking light. 25 | 6. Manually edited the MJCF to extract common properties into the `<default>` section. 26 | 7. Manually designed collision geometries. 27 | 8. Shifted joint reference values and ranges so that 0 corresponds to standing pose. 28 | 9. Softened the contacts of the feet to approximate the effect of rubber and 29 | increased `impratio` to reduce slippage. 30 | 10. Added `scene.xml` which includes the robot, with a textured groundplane, skybox, and haze. 31 | 32 | ## License 33 | 34 | This model is released under a [BSD-3-Clause License](LICENSE). 
35 | -------------------------------------------------------------------------------- /assets/unitree_a1/a1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stillonearth/bevy_quadruped_neural_control/e1535d0029c35b6cbd364d0171482d07bfa1cfb8/assets/unitree_a1/a1.png -------------------------------------------------------------------------------- /assets/unitree_a1/a1.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 193 | -------------------------------------------------------------------------------- /assets/unitree_a1/assets/trunk_A1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stillonearth/bevy_quadruped_neural_control/e1535d0029c35b6cbd364d0171482d07bfa1cfb8/assets/unitree_a1/assets/trunk_A1.png -------------------------------------------------------------------------------- /assets/unitree_a1/scene.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 16 | 18 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /build.rs: -------------------------------------------------------------------------------- 1 | use std::env; 2 | use std::fs; 3 | use std::path::Path; 4 | use std::path::PathBuf; 5 | use std::str::FromStr; 6 | 7 | fn get_output_path() -> PathBuf { 8 | let manifest_dir_string = env::var("CARGO_MANIFEST_DIR").unwrap(); 9 | let build_type = env::var("PROFILE").unwrap(); 10 | let path = Path::new(&manifest_dir_string) 11 | .join("target") 12 | .join(build_type); 13 | path 14 | } 15 | 16 | fn main() { 17 | let (_, _, default_install) = match env::var("CARGO_CFG_UNIX") { 18 | Ok(_) => ("", "", ""), 19 | _ => match env::var("CARGO_CFG_WINDOWS") { 20 | Ok(_) => ("", "dll", "C:\\Program Files\\MuJoCo"), 
21 | _ => ("", "", ""), 22 | }, 23 | }; 24 | 25 | if option_env!("DOCS_RS").is_none() { 26 | let mj_root = match (env::var("MUJOCO_DIR"), env::var("MUJOCO_PREFIX")) { 27 | (Ok(dir), _) | (Err(..), Ok(dir)) => dir, 28 | (Err(..), Err(..)) => default_install.to_string(), 29 | }; 30 | let mj_root = PathBuf::from_str(&mj_root).expect("Unable to get path"); 31 | let mj_lib_windows = mj_root.join("bin"); 32 | 33 | // Copy mujoco.dll to target directory on Windows targets 34 | if env::var("CARGO_CFG_WINDOWS").is_ok() { 35 | let target_dir = get_output_path(); 36 | let src = Path::join( 37 | &env::current_dir().unwrap(), 38 | mj_lib_windows.join("mujoco.dll"), 39 | ); 40 | 41 | fs::create_dir_all(&target_dir).unwrap(); 42 | let dest = Path::join(Path::new(&target_dir), Path::new("mujoco.dll")); 43 | eprintln!("Copying {:?} to {:?}", src, dest); 44 | std::fs::copy(src, dest).unwrap(); 45 | } 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /examples/a1_walk.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::type_complexity)] 2 | #![allow(clippy::approx_constant)] 3 | #![feature(allocator_api)] 4 | // This example runs a simple policy on the A1 robot. The policy is a feedforward neural network 5 | // Is uses tract_onnx to load the model and run it 6 | 7 | use std::alloc::Global; 8 | 9 | use bevy::prelude::*; 10 | use bevy_flycam::*; 11 | use bevy_mujoco::*; 12 | 13 | use tract_ndarray::Array2; 14 | use tract_onnx::prelude::*; 15 | 16 | use lazy_static::lazy_static; 17 | 18 | lazy_static! 
{ 19 | static ref MODEL: SimplePlan, Graph>> = 20 | tract_onnx::onnx() 21 | .model_for_path("assets/g1-forward.onnx") 22 | .unwrap() 23 | .with_input_fact(0, f32::fact([1, 119]).into()) 24 | .unwrap() 25 | .into_optimized() 26 | .unwrap() 27 | .into_runnable() 28 | .unwrap(); 29 | } 30 | 31 | fn setup(mut commands: Commands) { 32 | commands.spawn(PointLightBundle { 33 | point_light: PointLight { 34 | intensity: 9000.0, 35 | range: 100., 36 | shadows_enabled: false, 37 | ..default() 38 | }, 39 | transform: Transform::from_xyz(8.0, 16.0, 8.0), 40 | ..default() 41 | }); 42 | 43 | commands 44 | .spawn(Camera3dBundle { 45 | transform: Transform::from_xyz(0.0, 2.0, 2.0).looking_at(Vec3::ZERO, Vec3::Y), 46 | ..default() 47 | }) 48 | .insert(FlyCam); 49 | } 50 | 51 | // These numbers aren't exported to onnx, it's action scaling coefficients from original python code 52 | // https://github.com/DLR-RM/stable-baselines3/blob/4fa17dcf0f72455aa3d36308291d4b052b2544f7/stable_baselines3/common/policies.py#L263 53 | // can be obtained from `python/compare_onnx_pytorch.ipynb` 54 | const LOW: [f32; 12] = [ 55 | -0.802851, -1.0472, -2.69653, -0.802851, -1.0472, -2.69653, -0.802851, -1.0472, -2.69653, 56 | -0.802851, -1.0472, -2.69653, 57 | ]; 58 | 59 | const HIGH: [f32; 12] = [ 60 | 0.802851, 4.18879, -0.916298, 0.802851, 4.18879, -0.916298, 0.802851, 4.18879, -0.916298, 61 | 0.802851, 4.18879, -0.916298, 62 | ]; 63 | 64 | fn robot_control_loop(mut mujoco_resources: ResMut) { 65 | // prepare simulation data for the NN 66 | let qpos = mujoco_resources.state.qpos.clone(); 67 | let qvel = mujoco_resources.state.qvel.clone(); 68 | let cfrc_ext = mujoco_resources.state.cfrc_ext.clone(); 69 | 70 | // make an input vector for a neural network 71 | let mut input_vec: Vec = Vec::new(); 72 | for value in qpos.iter().skip(2) { 73 | input_vec.push(*value as f32); 74 | } 75 | for value in qvel.iter() { 76 | input_vec.push(*value as f32); 77 | } 78 | for value in cfrc_ext.iter() { 79 | 
input_vec.push(value[0] as f32); 80 | input_vec.push(value[1] as f32); 81 | input_vec.push(value[2] as f32); 82 | input_vec.push(value[3] as f32); 83 | input_vec.push(value[4] as f32); 84 | input_vec.push(value[5] as f32); 85 | } 86 | 87 | // convert this to a tensor 88 | let input: Tensor = Array2::from_shape_vec((1, 119), input_vec).unwrap().into(); 89 | // get model execution result 90 | let result = MODEL.run(tvec!(input.into())).unwrap(); 91 | // extract model output 92 | let actions = result[0].to_array_view::().unwrap(); 93 | // prepare control vector fo passing to mujoco 94 | let mut control: Vec = vec![0.0; mujoco_resources.control.number_of_controls]; 95 | // fill control vector with actions (copy and adjust model output) 96 | for i in 0..mujoco_resources.control.number_of_controls { 97 | control[i] = actions[[0, i]] as f64; 98 | // scaling actions 99 | control[i] = LOW[i] as f64 + 0.5 * (control[i] + 1.0) * (HIGH[i] as f64 - LOW[i] as f64); 100 | } 101 | 102 | mujoco_resources.control.data = control; 103 | } 104 | 105 | fn main() { 106 | App::new() 107 | .add_plugins(DefaultPlugins) 108 | .insert_resource(MuJoCoPluginSettings { 109 | model_xml_path: "assets/unitree_a1/scene.xml".to_string(), 110 | pause_simulation: false, 111 | target_fps: 600.0, // this is not actual fps (bug in bevy_mujoco), 112 | // the bigger the value, the slower the simulation 113 | }) 114 | .add_plugins(NoCameraPlayerPlugin) 115 | .insert_resource(MovementSettings { 116 | speed: 1.0, 117 | ..default() 118 | }) 119 | .add_plugins(MuJoCoPlugin) 120 | .add_systems(Startup, setup) 121 | .add_systems(Update, robot_control_loop) 122 | .run(); 123 | } 124 | -------------------------------------------------------------------------------- /examples/bevy_rl_rest.rs: -------------------------------------------------------------------------------- 1 | #![allow(clippy::needless_range_loop)] 2 | #![allow(clippy::approx_constant)] 3 | #![allow(clippy::type_complexity)] 4 | 5 | ///! 
This example shows how to use the REST API to control the simulation 6 | ///! It uses the bevy_rl crate to provide a REST API 7 | use bevy::prelude::*; 8 | use bevy_flycam::*; 9 | use bevy_mujoco::*; 10 | 11 | use bevy_rl::{AIGymPlugin, AIGymSettings, AIGymState, EventControl, EventPause, SimulationState}; 12 | use serde::{Deserialize, Serialize}; 13 | 14 | #[derive(Default, Deref, DerefMut, Clone, Deserialize)] 15 | pub struct Actions(Vec); 16 | 17 | // Observation space 18 | 19 | #[derive(Default, Deref, DerefMut, Serialize, Clone)] 20 | pub struct EnvironmentState(MuJoCoState); 21 | 22 | fn setup(mut commands: Commands) { 23 | commands.spawn(PointLightBundle { 24 | point_light: PointLight { 25 | intensity: 9000.0, 26 | range: 100., 27 | shadows_enabled: false, 28 | ..default() 29 | }, 30 | transform: Transform::from_xyz(8.0, 16.0, 8.0), 31 | ..default() 32 | }); 33 | 34 | commands 35 | .spawn(Camera3dBundle { 36 | transform: Transform::from_xyz(0.0, 2.0, 2.0).looking_at(Vec3::ZERO, Vec3::Y), 37 | ..default() 38 | }) 39 | .insert(FlyCam); 40 | } 41 | 42 | fn bevy_rl_pause_request( 43 | mut pause_event_reader: EventReader, 44 | mut mujoco_settings: ResMut, 45 | mujoco_resources: Res, 46 | ai_gym_state: Res>, 47 | ) { 48 | for _ in pause_event_reader.iter() { 49 | // Pause physics engine 50 | mujoco_settings.pause_simulation = true; 51 | // Collect state into serializable struct 52 | let env_state = EnvironmentState(mujoco_resources.state.clone()); 53 | // Set bevy_rl gym state 54 | let mut ai_gym_state = ai_gym_state.lock().unwrap(); 55 | ai_gym_state.set_env_state(env_state); 56 | } 57 | } 58 | 59 | #[allow(unused_must_use)] 60 | fn bevy_rl_control_request( 61 | mut pause_event_reader: EventReader, 62 | mut mujoco_settings: ResMut, 63 | mut mujoco_resources: ResMut, 64 | mut simulation_state: ResMut>, 65 | ) { 66 | for control in pause_event_reader.iter() { 67 | println!("Control request received"); 68 | let unparsed_actions = &control.0; 69 | for i in 
0..unparsed_actions.len() { 70 | if let Some(unparsed_action) = unparsed_actions[i].clone() { 71 | let action: Vec = serde_json::from_str(&unparsed_action).unwrap(); 72 | mujoco_resources.control.data = action; 73 | } 74 | } 75 | // Resume simulation 76 | mujoco_settings.pause_simulation = false; 77 | simulation_state.set(SimulationState::Running); 78 | } 79 | } 80 | 81 | fn main() { 82 | let mut app = App::new(); 83 | 84 | // Basic bevy setup 85 | app.add_plugins(DefaultPlugins) 86 | .add_plugins(NoCameraPlayerPlugin) 87 | .insert_resource(MovementSettings { 88 | speed: 3.0, 89 | ..default() 90 | }) 91 | .add_systems(Startup, setup); 92 | 93 | // Setup bevy_mujoco 94 | app.insert_resource(MuJoCoPluginSettings { 95 | model_xml_path: "assets/unitree_a1/scene.xml".to_string(), 96 | pause_simulation: false, 97 | target_fps: 600.0, 98 | }) 99 | .add_plugins(MuJoCoPlugin); 100 | 101 | // Setup bevy_rl 102 | let ai_gym_state = AIGymState::::new(AIGymSettings { 103 | num_agents: 1, 104 | render_to_buffer: false, 105 | pause_interval: 0.01, 106 | ..default() 107 | }); 108 | app.insert_resource(ai_gym_state) 109 | .add_plugins(AIGymPlugin::::default()); 110 | 111 | // bevy_rl events 112 | app.add_systems(Update, bevy_rl_pause_request); 113 | app.add_systems(Update, bevy_rl_control_request); 114 | 115 | // Start 116 | app.run(); 117 | } 118 | -------------------------------------------------------------------------------- /python/compare_onnx_pytorch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 36, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import numpy as np\n", 11 | "import onnxruntime as ort\n", 12 | "from stable_baselines3 import SAC" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 37, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "torch_model = 
SAC.load('./../assets/g1-forward')" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 38, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "input = np.array([0.0] * 119).reshape(1, 119).astype(np.float32)" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": 51, 36 | "metadata": {}, 37 | "outputs": [ 38 | { 39 | "data": { 40 | "text/plain": [ 41 | "array([-5232.02768734, -9256.49316448, -9769.03249478, 4742.5439958 ,\n", 42 | " 4766.3663854 , -9727.51090693, -7289.71801066, -4803.99115235,\n", 43 | " -9639.31192291, 7165.20577157, -4364.83300203, -9530.80311477])" 44 | ] 45 | }, 46 | "execution_count": 51, 47 | "metadata": {}, 48 | "output_type": "execute_result" 49 | } 50 | ], 51 | "source": [ 52 | "torch_mean = np.zeros(12)\n", 53 | "\n", 54 | "n_iter = 10000\n", 55 | "for i in range(n_iter):\n", 56 | " torch_mean += torch_model.policy.predict(torch.from_numpy(input).float())[0][0]\n", 57 | "\n", 58 | "torch_mean/n_iter\n", 59 | "torch_mean" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 52, 65 | "metadata": {}, 66 | "outputs": [ 67 | { 68 | "data": { 69 | "text/plain": [ 70 | "Box([-0.802851 -1.0472 -2.69653 -0.802851 -1.0472 -2.69653 -0.802851\n", 71 | " -1.0472 -2.69653 -0.802851 -1.0472 -2.69653 ], [ 0.802851 4.18879 -0.916298 0.802851 4.18879 -0.916298 0.802851\n", 72 | " 4.18879 -0.916298 0.802851 4.18879 -0.916298], (12,), float32)" 73 | ] 74 | }, 75 | "execution_count": 52, 76 | "metadata": {}, 77 | "output_type": "execute_result" 78 | } 79 | ], 80 | "source": [ 81 | "torch_model.policy.action_space" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "def scale_action(scaled_action ,action_space):\n", 91 | " low, high = action_space.low, action_space.high\n", 92 | " return low + (0.5 * (scaled_action + 1.0) * (high - low))" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | 
"execution_count": 46, 98 | "metadata": {}, 99 | "outputs": [ 100 | { 101 | "data": { 102 | "text/plain": [ 103 | "array([-6514.36762797, -9523.70335982, 9328.56398707, 5927.71452494,\n", 104 | " -4158.10327213, 9337.60048291, -9094.07935127, -7836.99624672,\n", 105 | " 9462.4126415 , 8929.80290843, -7671.65321006, 9574.82129016])" 106 | ] 107 | }, 108 | "execution_count": 46, 109 | "metadata": {}, 110 | "output_type": "execute_result" 111 | } 112 | ], 113 | "source": [ 114 | "onnx_mean = np.zeros(12)\n", 115 | "ort_sess = ort.InferenceSession('./../assets/g1-forward.onnx')\n", 116 | "\n", 117 | "n_iter = 10000\n", 118 | "for i in range(n_iter):\n", 119 | " onnx_mean += ort_sess.run(None, {'obs': input})[0][0]\n", 120 | "\n", 121 | "onnx_mean/n_iter\n", 122 | "onnx_mean" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [] 131 | } 132 | ], 133 | "metadata": { 134 | "kernelspec": { 135 | "display_name": "torch", 136 | "language": "python", 137 | "name": "python3" 138 | }, 139 | "language_info": { 140 | "codemirror_mode": { 141 | "name": "ipython", 142 | "version": 3 143 | }, 144 | "file_extension": ".py", 145 | "mimetype": "text/x-python", 146 | "name": "python", 147 | "nbconvert_exporter": "python", 148 | "pygments_lexer": "ipython3", 149 | "version": "3.10.8" 150 | }, 151 | "orig_nbformat": 4, 152 | "vscode": { 153 | "interpreter": { 154 | "hash": "82ea6adb180b12ed72836e614d5d57295654ca2a9780d621124b81b6a9baa809" 155 | } 156 | } 157 | }, 158 | "nbformat": 4, 159 | "nbformat_minor": 2 160 | } 161 | -------------------------------------------------------------------------------- /python/policy.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | import numpy as np 4 | 5 | from stable_baselines3 import SAC 6 | 7 | 8 | API_STEP = 'http://127.0.0.1:7878/step' 9 | API_STATE = 'http://127.0.0.1:7878/state' 10 
| 11 | def send_action(action): 12 | payload = json.dumps([{"action": json.dumps(action)}], indent=4) 13 | requests.get(API_STEP, params={'payload': payload}) 14 | 15 | def get_obs(): 16 | state = requests.get(API_STATE).json() 17 | qpos = np.array(state['qpos']).flat.copy() 18 | qvel = np.array(state['qvel']).flat.copy() 19 | cfrc_ext = np.array(state['cfrc_ext']).flat.copy() 20 | 21 | return np.concatenate([qpos[2:], qvel, cfrc_ext]) 22 | 23 | model = SAC.load('./../assets/g1-forward') 24 | 25 | while True: 26 | obs = get_obs() 27 | action = model.predict(obs)[0].tolist() 28 | 29 | send_action(action) 30 | -------------------------------------------------------------------------------- /python/run_onnx_policy.ipynb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stillonearth/bevy_quadruped_neural_control/e1535d0029c35b6cbd364d0171482d07bfa1cfb8/python/run_onnx_policy.ipynb --------------------------------------------------------------------------------