├── requirements.txt ├── justfile ├── .gitignore ├── src ├── error.rs ├── lib.rs ├── space_data.rs ├── utility.rs ├── environment.rs ├── client.rs └── space_template.rs ├── examples ├── basic.rs ├── agent.py ├── readme.md └── agent.rs ├── Cargo.toml ├── .travis.yml ├── LICENSE ├── .github └── workflows │ └── test.yml ├── .rustfmt.toml ├── README.md └── tests └── integration_tests.rs /requirements.txt: -------------------------------------------------------------------------------- 1 | gymnasium[all]==0.26.3 -------------------------------------------------------------------------------- /justfile: -------------------------------------------------------------------------------- 1 | 2 | check: 3 | cargo +nightly fmt -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target 2 | Cargo.lock 3 | .idea 4 | *.iml 5 | *.pyc 6 | -------------------------------------------------------------------------------- /src/error.rs: -------------------------------------------------------------------------------- 1 | use cpython::{PyErr, PyObject}; 2 | use thiserror::Error; 3 | 4 | #[derive(Debug, Error)] 5 | pub enum GymError { 6 | #[error("Invalid action")] 7 | InvalidAction, 8 | #[error("Invalid conversion")] 9 | InvalidConversion, 10 | #[error("Wrong type")] 11 | WrongType, 12 | #[error("Unable to parse step result")] 13 | WrongStepResult, 14 | #[error("Unable to parse reset result")] 15 | WrongResetResult, 16 | #[error("Invalid seed")] 17 | InvalidSeed, 18 | #[error("Invalid render mode")] 19 | InvalidRenderMode, 20 | #[error("Unable to make environment '{0}' with dict '{1:?}' (Error: {2:?})")] 21 | InvalidMake(String, Vec<(PyObject, PyObject)>, PyErr), 22 | } 23 | -------------------------------------------------------------------------------- /examples/basic.rs: -------------------------------------------------------------------------------- 1 | use 
gym::client::MakeOptions; 2 | 3 | extern crate gym; 4 | 5 | fn main() { 6 | let gym = gym::client::GymClient::default(); 7 | let env = gym 8 | .make( 9 | "CartPole-v1", 10 | Some(MakeOptions { 11 | render_mode: Some(gym::client::RenderMode::Human), 12 | ..Default::default() 13 | }), 14 | ) 15 | .expect("Unable to create environment"); 16 | 17 | for _ in 0..10 { 18 | env.reset(None).expect("Unable to reset"); 19 | 20 | for _ in 0..100 { 21 | let action = env.action_space().sample(); 22 | let state = env.step(&action).unwrap(); 23 | env.render(); 24 | if state.is_done { 25 | break; 26 | } 27 | } 28 | } 29 | 30 | env.close(); 31 | } 32 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "gym" 3 | version = "2.2.1" 4 | authors = ["Roberto "] 5 | edition = "2021" 6 | 7 | description = "Open AI environments bindings for Rust" 8 | repository = "https://github.com/MrRobb/gym-rs" 9 | documentation = "https://docs.rs/gym" 10 | readme = "README.md" 11 | keywords = ["openai", "gym", "environments", "reinforcement", "learning"] 12 | categories = ["api-bindings"] 13 | license = "MIT/Apache-2.0" 14 | 15 | [badges] 16 | maintenance = { status = "actively-developed" } 17 | 18 | [dependencies] 19 | rand = "0.8" 20 | cpython = "0.7" 21 | ndarray = { version = "0.15", features = ["serde"] } 22 | thiserror = "1.0" 23 | serde = { version = "1.0.160", features = ["derive"] } 24 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | #![allow( 2 | clippy::missing_errors_doc, 3 | clippy::missing_const_for_fn, 4 | clippy::missing_panics_doc, 5 | clippy::must_use_candidate, 6 | clippy::module_name_repetitions, 7 | clippy::use_self 8 | )] 9 | 10 | extern crate cpython; 11 | extern crate ndarray; 12 | extern crate rand; 13 
| use space_data::SpaceData; 14 | 15 | pub mod client; 16 | mod environment; 17 | pub mod error; 18 | pub mod space_data; 19 | pub mod space_template; 20 | pub mod utility; 21 | 22 | type DiscreteType = usize; 23 | type VectorType = ndarray::Array1; 24 | pub type Action = SpaceData; 25 | pub type Observation = SpaceData; 26 | pub type Reward = f64; 27 | 28 | pub struct State { 29 | pub observation: SpaceData, 30 | pub reward: f64, 31 | pub is_done: bool, 32 | pub is_truncated: bool, 33 | } 34 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: rust 2 | 3 | addons: 4 | apt: 5 | update: true 6 | packages: 7 | - "python3" 8 | - "python3-pip" 9 | - "python3-setuptools" 10 | - "python3-wheel" 11 | - "libssl-dev" 12 | 13 | python: 14 | - 3.7 15 | rust: 16 | - stable 17 | 18 | cache: cargo 19 | 20 | before_cache: | 21 | if [[ "$TRAVIS_RUST_VERSION" == stable ]]; then 22 | cargo install cargo-tarpaulin 23 | fi 24 | 25 | install: 26 | - sudo python3 -m pip install --upgrade pip 27 | - sudo python3 -m pip install -r requirements.txt 28 | 29 | script: 30 | - cargo clean 31 | - cargo build 32 | - cargo test 33 | 34 | after_success: | 35 | if [[ "$TRAVIS_RUST_VERSION" == stable ]]; then 36 | # Create and upload a report for codecov.io 37 | cargo tarpaulin --out Xml 38 | bash <(curl -s https://codecov.io/bash) 39 | fi 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Mr.Robb 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, 
sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /examples/agent.py: -------------------------------------------------------------------------------- 1 | """Training the agent""" 2 | 3 | import random 4 | import numpy as np 5 | import gymnasium as gym 6 | 7 | print(gym.__version__) 8 | 9 | env = gym.make("Taxi-v3") 10 | 11 | # Hyperparameters 12 | alpha = 0.1 13 | gamma = 0.6 14 | epsilon = 0.1 15 | 16 | q_table = np.zeros([env.observation_space.n, env.action_space.n]) 17 | 18 | for i in range(0, 100_000): 19 | state, info = env.reset() 20 | 21 | epochs, penalties, reward, = (0, 0, 0) 22 | done = False 23 | 24 | while not done: 25 | if random.uniform(0, 1) < epsilon: 26 | action = env.action_space.sample() # Explore action space 27 | else: 28 | action = np.argmax(q_table[state]) 29 | 30 | next_state, reward, terminated, truncated, info = env.step(action) 31 | 32 | done = terminated or truncated 33 | 34 | old_value = q_table[state, action] 35 | next_max = np.max(q_table[next_state]) 36 | 37 | new_value = (1 - alpha) * old_value + alpha * (reward + gamma * next_max) 38 | q_table[state, action] = new_value 39 | 40 | if reward == -10: 41 | 
penalties += 1 42 | 43 | state = next_state 44 | epochs += 1 45 | 46 | if i % 100 == 0: 47 | print(f"Episode: {i} in {epochs}") 48 | 49 | print("Training finished.\n") 50 | -------------------------------------------------------------------------------- /src/space_data.rs: -------------------------------------------------------------------------------- 1 | use cpython::{PyObject, Python, PythonObject, ToPyObject}; 2 | 3 | use crate::error::GymError; 4 | use crate::{DiscreteType, VectorType}; 5 | 6 | #[derive(Debug, Clone)] 7 | pub enum SpaceData { 8 | Discrete(DiscreteType), 9 | Box(VectorType), 10 | Tuple(VectorType), 11 | } 12 | 13 | impl SpaceData { 14 | pub fn get_discrete(self) -> Result { 15 | match self { 16 | SpaceData::Discrete(n) => Ok(n), 17 | _ => Err(GymError::WrongType), 18 | } 19 | } 20 | 21 | pub fn get_box(self) -> Result, GymError> { 22 | match self { 23 | SpaceData::Box(v) => Ok(v), 24 | _ => Err(GymError::WrongType), 25 | } 26 | } 27 | 28 | pub fn get_tuple(self) -> Result, GymError> { 29 | match self { 30 | SpaceData::Tuple(s) => Ok(s), 31 | _ => Err(GymError::WrongType), 32 | } 33 | } 34 | 35 | pub fn into_pyo(self) -> PyObject { 36 | let gil = Python::acquire_gil(); 37 | let py = gil.python(); 38 | match self { 39 | SpaceData::Discrete(n) => n.into_py_object(py).into_object(), 40 | SpaceData::Box(v) => v.to_vec().into_py_object(py).into_object(), 41 | SpaceData::Tuple(spaces) => { 42 | let vpyo = spaces.to_vec().into_iter().map(Self::into_pyo).collect::>(); 43 | vpyo.into_py_object(py).into_object() 44 | }, 45 | } 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /examples/readme.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | | # | Topic | Example | Code | Finished | 4 | |--------|---------------------|---------------------------------------|-----------------------------------------------------|----------| 5 | | **1.** | Test | 
[Basic](#basic) | [basic.rs](basic.rs) | ≈ | 6 | | **2.** | Dynamic Programming | [Value iteration](#value-iteration) | [value_iteration.rs](value_iteration.rs) | ✗ | 7 | | **3.** | | [Policy iteration](#policy-iteration) | [policy_iteration.rs](policy_iteration.rs) | ✗ | 8 | | **4.** | Tabular RL | [Q-Learning](#q-learning) | [qlearning.rs](qlearning.rs) | ✗ | 9 | | **5.** | | [Monte Carlo with ε-soft policy](#monte-carlo-with-ε-soft-policy) | [softegreedy.rs](softegreedy.rs) | ✗ | 10 | | **6.** | | [Wolf-PHC](#wolf-phc) | [wolfphc.rs](wolfphc.rs) | ✗ | 11 | 12 | ## Test 13 | 14 | ### Basic 15 | 16 | ## Dynamic Programming 17 | 18 | ### Value Iteration 19 | 20 | ### Policy Iteration 21 | 22 | ## Tabular RL 23 | 24 | ### Q-Learning 25 | 26 | ### Monte Carlo with ε-soft policy 27 | 28 | ### Wolf-PHC 29 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Rust 2 | 3 | on: 4 | push: 5 | branches: ["master"] 6 | pull_request: 7 | branches: ["master"] 8 | 9 | env: 10 | CARGO_TERM_COLOR: always 11 | 12 | jobs: 13 | format: 14 | name: Format 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v3 18 | 19 | # Setup toolchain 20 | - name: Setup toolchain 21 | uses: actions-rs/toolchain@v1 22 | with: 23 | toolchain: nightly 24 | 25 | # Check formatting 26 | - uses: actions-rs/cargo@v1 27 | with: 28 | command: fmt 29 | toolchain: nightly 30 | args: --all -- --check 31 | 32 | test: 33 | runs-on: ubuntu-latest 34 | 35 | steps: 36 | - uses: actions/checkout@v3 37 | 38 | # Setup toolchain 39 | - name: Setup toolchain 40 | uses: actions-rs/toolchain@v1 41 | with: 42 | toolchain: stable 43 | 44 | # Install module 45 | - name: Install module 46 | uses: actions/setup-python@v4 47 | with: 48 | python-version: "3.10" 49 | cache: "pip" # caching pip dependencies 50 | - run: | 51 | pip install swig 52 | pip install 
gym[accept-rom-license] 53 | pip install --user -r requirements.txt 54 | 55 | # Test 56 | - name: Run tests 57 | run: cargo test -- --test-threads=1 58 | 59 | # Code coverage 60 | - name: Codecov 61 | uses: codecov/codecov-action@v3.1.1 62 | with: 63 | token: ${{ secrets.CODECOV_TOKEN }} 64 | fail_ci_if_error: true 65 | verbose: true 66 | -------------------------------------------------------------------------------- /examples/agent.rs: -------------------------------------------------------------------------------- 1 | extern crate gym; 2 | extern crate rand; 3 | 4 | use gym::client::GymClient; 5 | use gym::{Action, State}; 6 | use rand::Rng; 7 | 8 | // Hyperparameters 9 | const ALPHA: f64 = 0.1; 10 | const GAMMA: f64 = 0.6; 11 | const EPSILON: f64 = 0.1; 12 | const INFINITY: f64 = -1.0 / 0.0; 13 | 14 | fn argmax(v: &[f64]) -> usize { 15 | let mut i_max = 0; 16 | let mut f_max = -1.0 / 0.0; 17 | for (i, &f) in v.iter().enumerate() { 18 | if f > f_max { 19 | i_max = i; 20 | f_max = f; 21 | } 22 | } 23 | i_max 24 | } 25 | 26 | fn main() { 27 | let mut rng = rand::thread_rng(); 28 | let client = GymClient::default(); 29 | let env = client.make("Taxi-v3", None).expect("Unable to create environment"); 30 | let mut qtable = [[0.0; 6]; 500]; 31 | 32 | // Exploration 33 | for ep in 0..100_000 { 34 | let mut epochs = 0; 35 | let mut done = false; 36 | let (obs, _info) = env.reset(None).expect("Unable to reset"); 37 | let mut state: usize = obs.get_discrete().unwrap(); 38 | 39 | while !done { 40 | let action = if rng.gen_bool(EPSILON) { 41 | env.action_space().sample().get_discrete().unwrap() 42 | } else { 43 | argmax(&qtable[state]) 44 | }; 45 | 46 | let State { 47 | observation, 48 | reward, 49 | is_done, 50 | is_truncated, 51 | } = env.step(&Action::Discrete(action)).unwrap(); 52 | let next_state: usize = observation.get_discrete().unwrap(); 53 | 54 | let old_value = qtable[state][action]; 55 | let next_max = qtable[next_state].iter().copied().fold(INFINITY, f64::max); 
56 | 57 | let next_value = (1.0 - ALPHA).mul_add(old_value, ALPHA * GAMMA.mul_add(next_max, reward)); 58 | qtable[state][action] = next_value; 59 | 60 | state = next_state; 61 | epochs += 1; 62 | done = is_done || is_truncated; 63 | } 64 | 65 | if ep % 100 == 0 { 66 | println!("Finished episode {} in {}", ep, epochs); 67 | } 68 | } 69 | } 70 | -------------------------------------------------------------------------------- /src/utility.rs: -------------------------------------------------------------------------------- 1 | use crate::client::GymClient; 2 | use crate::State; 3 | use ndarray::concatenate; 4 | use ndarray::{Array1, Array2, ArrayView1, ArrayView2, ArrayViewMut1, Axis}; 5 | use serde::{Deserialize, Serialize}; 6 | 7 | #[derive(Debug, Clone, Deserialize, Serialize)] 8 | pub struct StandardScaler { 9 | means: Array1, 10 | standard_deviations: Array1, 11 | } 12 | 13 | impl StandardScaler { 14 | pub fn for_environment(environment: &str) -> Self { 15 | Self::new(Self::generate_standard_scaler_samples(environment, 1000).view()) 16 | } 17 | 18 | pub fn new(samples: ArrayView2) -> Self { 19 | let standard_deviations = samples.var_axis(Axis(0), 0.0).mapv_into(|x| (x + f64::EPSILON).sqrt()); 20 | 21 | let means = samples.mean_axis(Axis(0)).unwrap(); 22 | 23 | StandardScaler { 24 | means, 25 | standard_deviations, 26 | } 27 | } 28 | 29 | pub fn scale_inplace(&self, mut sample: ArrayViewMut1) { 30 | sample -= &self.means; 31 | sample /= &self.standard_deviations; 32 | } 33 | 34 | pub fn scale(&self, sample: ArrayView1) -> Array1 { 35 | (&sample - &self.means) / &self.standard_deviations 36 | } 37 | 38 | fn generate_standard_scaler_samples(environment: &str, num_samples: usize) -> Array2 { 39 | let gym = GymClient::default(); 40 | let env = gym.make(environment, None).unwrap(); 41 | 42 | // collect samples for standard scaler 43 | let samples = env 44 | .reset(None) 45 | .unwrap() 46 | .0 47 | .get_box() 48 | .expect("expected gym environment with box type 
observations"); 49 | 50 | let mut samples = samples.insert_axis(Axis(0)); 51 | 52 | println!("sampling for scaler"); 53 | for _ in 0..num_samples { 54 | let State { observation, .. } = env.step(&env.action_space().sample()).unwrap(); 55 | 56 | samples = concatenate![Axis(0), samples, observation.get_box().unwrap().insert_axis(Axis(0))]; 57 | } 58 | println!("done sampling"); 59 | 60 | samples 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /.rustfmt.toml: -------------------------------------------------------------------------------- 1 | max_width = 120 2 | hard_tabs = true 3 | tab_spaces = 4 4 | newline_style = "Auto" 5 | use_small_heuristics = "Default" 6 | indent_style = "Block" 7 | wrap_comments = false 8 | format_code_in_doc_comments = true 9 | comment_width = 80 10 | normalize_comments = true 11 | normalize_doc_attributes = true 12 | format_strings = true 13 | format_macro_matchers = true 14 | format_macro_bodies = true 15 | hex_literal_case = "Preserve" 16 | empty_item_single_line = true 17 | struct_lit_single_line = true 18 | fn_single_line = false 19 | where_single_line = false 20 | imports_indent = "Block" 21 | imports_layout = "Mixed" 22 | imports_granularity = "Module" 23 | group_imports = "StdExternalCrate" 24 | reorder_imports = true 25 | reorder_modules = true 26 | reorder_impl_items = true 27 | type_punctuation_density = "Wide" 28 | space_before_colon = false 29 | space_after_colon = true 30 | spaces_around_ranges = false 31 | binop_separator = "Front" 32 | remove_nested_parens = true 33 | combine_control_expr = false 34 | overflow_delimited_expr = false 35 | struct_field_align_threshold = 0 36 | enum_discrim_align_threshold = 0 37 | match_arm_blocks = true 38 | match_arm_leading_pipes = "Never" 39 | force_multiline_blocks = false 40 | fn_args_layout = "Tall" 41 | brace_style = "SameLineWhere" 42 | control_brace_style = "ClosingNextLine" 43 | trailing_semicolon = true 44 | trailing_comma = 
"Vertical" 45 | match_block_trailing_comma = true 46 | blank_lines_upper_bound = 1 47 | blank_lines_lower_bound = 0 48 | edition = "2021" 49 | version = "Two" 50 | inline_attribute_width = 0 51 | format_generated_files = true 52 | merge_derives = true 53 | use_try_shorthand = true 54 | use_field_init_shorthand = true 55 | force_explicit_abi = true 56 | condense_wildcard_suffixes = false 57 | color = "Auto" 58 | required_version = "1.5.1" 59 | unstable_features = true 60 | disable_all_formatting = false 61 | skip_children = false 62 | hide_parse_errors = false 63 | error_on_line_overflow = false 64 | error_on_unformatted = false 65 | ignore = [] 66 | emit_mode = "Files" 67 | make_backup = false -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # gym-rs 2 | 3 | 4 | [![Crates.io](https://img.shields.io/crates/v/gym)](https://crates.io/crates/gym) 5 | [![Docs.rs](https://docs.rs/gym/badge.svg)](https://docs.rs/gym/latest/gym) 6 | [![codecov](https://codecov.io/gh/MrRobb/gym-rs/branch/master/graph/badge.svg)](https://codecov.io/gh/MrRobb/gym-rs) 7 | [![license](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/MrRobb/gym-rs/blob/master/LICENSE) 8 | 9 | OpenAI gym binding for Rust. 10 | 11 | > Actively maintained! If you have any problem just [create an issue](https://github.com/MrRobb/gym-rs/issues/new). 12 | 13 | ### Install 14 | 15 | Just install the requirements laid out in the [requirements.txt](https://github.com/MrRobb/gym-rs/blob/master/requirements.txt). 
16 | 17 | > If you don't have python installed, go [here](https://realpython.com/installing-python/) 18 | 19 | ```sh 20 | curl "https://raw.githubusercontent.com/MrRobb/gym-rs/master/requirements.txt" > requirements.txt 21 | pip3 install -r requirements.txt 22 | ``` 23 | 24 | ### Usage 25 | 26 | Once everything is installed, just add this crate to your Rust project. 27 | 28 | ```toml 29 | # Cargo.toml 30 | 31 | [dependencies] 32 | gym = "*" # Update * with the latest version 33 | ``` 34 | 35 | ### Example 36 | 37 | Once you have installed the library correctly, the only thing left is to test if it's working. To do so, you just have to execute the following commands: 38 | 39 | > If you don't have Rust installed go [here](https://www.rust-lang.org/tools/install) 40 | 41 | ```sh script 42 | git clone https://github.com/MrRobb/gym-rs.git 43 | cd gym-rs 44 | pip3 install -r requirements.txt 45 | cargo run --example basic 46 | ``` 47 | 48 | ### Troubleshooting 49 | 50 | In Ubuntu 20.04, it is possible that you need to install `swig`. To do that, execute: 51 | 52 | ```sh 53 | sudo apt-get install swig 54 | ``` 55 | 56 | The example can fail with virtualenv. It's more of a general problem of the cpython crate rather than this one; you can resolve it by setting the PYTHONPATH env var to the module path of the venv, e.g.: 57 | 58 | ```sh 59 | PYTHONPATH=~/venv-py37/lib/python3.7/site-packages cargo run --example basic 60 | ``` 61 | 62 | ## Donation (BTC) 63 | 64 |

65 | 66 | 67 | 68 |

69 |

BTC address: 3KRM66geiaXWzqs5hRb35dGiQEQAa6JTYU

70 | -------------------------------------------------------------------------------- /src/environment.rs: -------------------------------------------------------------------------------- 1 | use cpython::{GILGuard, NoArgs, ObjectProtocol, PyDict, PyObject, PyTuple}; 2 | 3 | use crate::error::GymError; 4 | use crate::space_data::SpaceData; 5 | use crate::space_template::SpaceTemplate; 6 | use crate::{Action, State}; 7 | 8 | pub struct Environment<'a> { 9 | pub gil: &'a GILGuard, 10 | pub env: PyObject, 11 | pub observation_space: SpaceTemplate, 12 | pub action_space: SpaceTemplate, 13 | } 14 | 15 | impl<'a> Drop for Environment<'a> { 16 | fn drop(&mut self) { 17 | self.close(); 18 | } 19 | } 20 | 21 | impl<'a> Environment<'a> { 22 | pub fn reset(&self, seed: Option) -> Result<(SpaceData, PyObject), GymError> { 23 | let py = self.gil.python(); 24 | let dict = PyDict::new(py); 25 | if let Some(seed) = seed { 26 | dict.set_item(py, "seed", seed).map_err(|_| GymError::InvalidSeed)?; 27 | } 28 | let result = self 29 | .env 30 | .call_method(py, "reset", NoArgs, Some(&dict)) 31 | .expect("Unable to call 'reset'"); 32 | let observation = self 33 | .observation_space 34 | .extract_data(&result.get_item(py, 0).map_err(|_| GymError::WrongResetResult)?)?; 35 | let info = result.get_item(py, 1).map_err(|_| GymError::WrongResetResult)?; 36 | Ok((observation, info)) 37 | } 38 | 39 | pub fn render(&self) { 40 | let py = self.gil.python(); 41 | self.env 42 | .call_method(py, "render", NoArgs, None) 43 | .expect("Unable to call 'render'"); 44 | } 45 | 46 | pub fn step(&self, action: &Action) -> Result { 47 | let py = self.gil.python(); 48 | let result = match action { 49 | Action::Discrete(n) => self 50 | .env 51 | .call_method(py, "step", (n,), None) 52 | .map_err(|_| GymError::InvalidAction)?, 53 | Action::Box(v) => { 54 | let vv = v.to_vec(); 55 | self.env 56 | .call_method(py, "step", (vv,), None) 57 | .map_err(|_| GymError::InvalidAction)? 
58 | }, 59 | Action::Tuple(spaces) => { 60 | let vpyo = spaces.to_vec().into_iter().map(SpaceData::into_pyo).collect::>(); 61 | let tuple_pyo = PyTuple::new(py, &vpyo); 62 | self.env 63 | .call_method(py, "step", (tuple_pyo,), None) 64 | .map_err(|_| GymError::InvalidAction)? 65 | }, 66 | }; 67 | 68 | let s = State { 69 | observation: self 70 | .observation_space 71 | .extract_data(&result.get_item(py, 0).map_err(|_| GymError::WrongStepResult)?)?, 72 | reward: result 73 | .get_item(py, 1) 74 | .map_err(|_| GymError::WrongStepResult)? 75 | .extract(py) 76 | .map_err(|_| GymError::WrongStepResult)?, 77 | is_done: result 78 | .get_item(py, 2) 79 | .map_err(|_| GymError::WrongStepResult)? 80 | .extract(py) 81 | .map_err(|_| GymError::WrongStepResult)?, 82 | is_truncated: result 83 | .get_item(py, 3) 84 | .map_err(|_| GymError::WrongStepResult)? 85 | .extract(py) 86 | .map_err(|_| GymError::WrongStepResult)?, 87 | }; 88 | 89 | Ok(s) 90 | } 91 | 92 | pub fn close(&self) { 93 | let py = self.gil.python(); 94 | let _res = self 95 | .env 96 | .call_method(py, "close", NoArgs, None) 97 | .expect("Unable to call 'close'"); 98 | } 99 | 100 | /// Returns the number of allowed actions for this environment. 101 | pub fn action_space(&self) -> &SpaceTemplate { 102 | &self.action_space 103 | } 104 | 105 | /// Returns the shape of the observation tensors. 
106 | pub fn observation_space(&self) -> &SpaceTemplate { 107 | &self.observation_space 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /src/client.rs: -------------------------------------------------------------------------------- 1 | use cpython::{GILGuard, ObjectProtocol, PyDict, PyModule, Python}; 2 | 3 | use crate::space_template::SpaceTemplate; 4 | use crate::{environment::Environment, error::GymError}; 5 | 6 | pub struct GymClient { 7 | pub gil: GILGuard, 8 | pub gym: PyModule, 9 | pub version: String, 10 | } 11 | 12 | pub enum RenderMode { 13 | Human, 14 | RgbArray, 15 | Custom(String), 16 | } 17 | 18 | impl ToString for RenderMode { 19 | fn to_string(&self) -> String { 20 | match self { 21 | RenderMode::Human => "human".to_string(), 22 | RenderMode::RgbArray => "rgb_array".to_string(), 23 | RenderMode::Custom(s) => s.to_string(), 24 | } 25 | } 26 | } 27 | 28 | #[derive(Default)] 29 | pub struct MakeOptions { 30 | pub render_mode: Option, 31 | pub apply_api_compatibility: bool, 32 | pub use_old_gym_enviroment: bool, 33 | } 34 | 35 | impl Default for GymClient { 36 | fn default() -> Self { 37 | // Get python 38 | let gil = Python::acquire_gil(); 39 | let py = gil.python(); 40 | 41 | // Set argv[0] -> otherwise render() fails 42 | let sys = py.import("sys").expect("Error: import sys"); 43 | 44 | match sys.get(py, "argv") { 45 | Result::Ok(argv) => { 46 | argv.call_method(py, "append", ("",), None) 47 | .expect("Error: sys.argv.append('')"); 48 | }, 49 | Result::Err(_) => {}, 50 | }; 51 | 52 | // Import gym 53 | let gym = py.import("gymnasium").expect("Error: import gym"); 54 | let version = gym 55 | .get(py, "__version__") 56 | .expect("Unable to call gym.__version__") 57 | .extract(py) 58 | .expect("Unable to call gym.__version__"); 59 | 60 | Self { gil, gym, version } 61 | } 62 | } 63 | 64 | impl GymClient { 65 | pub fn make(&self, mut env_id: &str, options: Option) -> Result { 66 | let py = 
self.gil.python(); 67 | let dict = PyDict::new(py); 68 | if let Some(options) = options { 69 | dict.set_item(py, "apply_api_compatibility", options.apply_api_compatibility) 70 | .expect("Unable to set apply_api_compatibility"); 71 | if let Some(render_mode) = options.render_mode { 72 | dict.set_item(py, "render_mode", render_mode.to_string()) 73 | .map_err(|_| GymError::InvalidRenderMode)?; 74 | } 75 | if options.use_old_gym_enviroment { 76 | dict.set_item(py, "env_id", env_id).expect("Unable to set env_id"); 77 | env_id = "GymV26Environment-v0"; 78 | } 79 | } 80 | let env = self 81 | .gym 82 | .call(py, "make", (env_id,), Some(&dict)) 83 | .map_err(|e| GymError::InvalidMake(env_id.to_owned(), dict.items(py), e))?; 84 | 85 | Ok(Environment { 86 | gil: &self.gil, 87 | observation_space: SpaceTemplate::extract_template( 88 | &env.getattr(py, "observation_space") 89 | .expect("Unable to get attribute 'observation_space'"), 90 | ), 91 | action_space: SpaceTemplate::extract_template( 92 | &env.getattr(py, "action_space") 93 | .expect("Unable to get attribute 'action_space'"), 94 | ), 95 | env, 96 | }) 97 | } 98 | 99 | pub fn list_all(&self) -> Vec { 100 | let py = self.gil.python(); 101 | // gymnasiun.envs.registry.keys() 102 | self.gym 103 | .get(py, "envs") 104 | .expect("Unable to call gym.envs") 105 | .getattr(py, "registry") 106 | .expect("Unable to get attribute 'all'") 107 | .cast_as::(py) 108 | .unwrap() 109 | .items(py) 110 | .iter() 111 | .map(|(k, _)| k.extract::(py).unwrap()) 112 | .collect() 113 | } 114 | 115 | pub fn version(&self) -> &str { 116 | self.version.as_str() 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /tests/integration_tests.rs: -------------------------------------------------------------------------------- 1 | #[cfg(test)] 2 | mod tests { 3 | use gym::client::{GymClient, MakeOptions}; 4 | use gym::space_data::SpaceData; 5 | use gym::Action; 6 | 7 | #[test] 8 | fn test_gym_client() { 9 
| let _client = GymClient::default(); 10 | } 11 | 12 | #[test] 13 | fn test_make() { 14 | let client = GymClient::default(); 15 | client.make("CartPole-v1", None).unwrap(); 16 | } 17 | 18 | #[test] 19 | fn test_seed() { 20 | let client = GymClient::default(); 21 | let env = client.make("FrozenLake-v1", None).unwrap(); 22 | let (obs, _) = env.reset(Some(1002)).unwrap(); 23 | assert_eq!(0, obs.get_discrete().unwrap()); 24 | let action = SpaceData::Discrete(1); 25 | let state = env.step(&action).unwrap(); 26 | assert_eq!(4, state.observation.get_discrete().unwrap()); 27 | } 28 | 29 | #[test] 30 | fn test_reset() { 31 | let client = GymClient::default(); 32 | let env = client.make("CartPole-v1", None).unwrap(); 33 | env.reset(None).unwrap(); 34 | } 35 | 36 | #[test] 37 | fn test_box_observation_3d() { 38 | let client = GymClient::default(); 39 | let env = client 40 | .make( 41 | "ALE/Asteroids-v5", 42 | Some(MakeOptions { 43 | use_old_gym_enviroment: true, 44 | ..Default::default() 45 | }), 46 | ) 47 | .unwrap(); 48 | env.reset(None).unwrap(); 49 | env.step(&env.action_space().sample()).unwrap(); 50 | } 51 | 52 | #[test] 53 | fn test_step() { 54 | let client = GymClient::default(); 55 | let env = client.make("CartPole-v1", None).unwrap(); 56 | env.reset(None).unwrap(); 57 | let action = env.action_space().sample(); 58 | env.step(&action).unwrap(); 59 | } 60 | 61 | #[test] 62 | #[should_panic] 63 | fn test_invalid_action() { 64 | let client = GymClient::default(); 65 | let env = client.make("CartPole-v1", None).unwrap(); 66 | env.reset(None).unwrap(); 67 | let action = Action::Discrete(500); // invalid action 68 | env.step(&action).unwrap(); 69 | } 70 | 71 | #[test] 72 | #[should_panic] 73 | fn test_wrong_type() { 74 | let client = GymClient::default(); 75 | let env = client.make("CartPole-v1", None).unwrap(); 76 | env.reset(None).unwrap(); 77 | let _res = env.action_space().sample().get_box().unwrap(); 78 | } 79 | 80 | #[test] 81 | fn test_box_action() { 82 | let 
client = GymClient::default(); 83 | let env = client.make("BipedalWalker-v3", None).unwrap(); 84 | env.reset(None).unwrap(); 85 | let action = env.action_space().sample(); 86 | env.step(&action).unwrap(); 87 | } 88 | 89 | #[test] 90 | fn test_tuple_template() { 91 | let client = GymClient::default(); 92 | let _res = client.make("Blackjack-v1", None).unwrap(); 93 | } 94 | 95 | #[test] 96 | fn test_tuple_obs() { 97 | let client = GymClient::default(); 98 | let env = client.make("Blackjack-v1", None).unwrap(); 99 | env.reset(None).unwrap(); 100 | let action = env.action_space().sample(); 101 | env.step(&action).unwrap(); 102 | } 103 | 104 | #[test] 105 | fn test_tuple_action() { 106 | let client = GymClient::default(); 107 | let env = client 108 | .make( 109 | "ReversedAddition3-v0", 110 | Some(MakeOptions { 111 | use_old_gym_enviroment: true, 112 | ..Default::default() 113 | }), 114 | ) 115 | .unwrap(); 116 | env.reset(None).unwrap(); 117 | let action = env.action_space().sample(); 118 | env.step(&action).unwrap(); 119 | } 120 | 121 | #[test] 122 | fn test_gym_version() { 123 | let client = GymClient::default(); 124 | assert!(!client.version().is_empty()); 125 | } 126 | 127 | #[test] 128 | fn test_render() { 129 | let client = GymClient::default(); 130 | let env = client.make("FrozenLake-v1", None).unwrap(); 131 | env.reset(None).unwrap(); 132 | let action = env.action_space().sample(); 133 | env.step(&action).unwrap(); 134 | env.render(); 135 | } 136 | 137 | #[test] 138 | fn test_close() { 139 | let client = GymClient::default(); 140 | let env = client.make("FrozenLake-v1", None).unwrap(); 141 | env.close(); 142 | } 143 | } 144 | -------------------------------------------------------------------------------- /src/space_template.rs: -------------------------------------------------------------------------------- 1 | use cpython::{NoArgs, ObjectProtocol, PyObject, Python}; 2 | use rand::Rng; 3 | 4 | use crate::error::GymError; 5 | use crate::space_data::SpaceData; 6 
| use crate::DiscreteType; 7 | 8 | #[derive(Debug)] 9 | pub enum SpaceTemplate { 10 | Discrete { 11 | n: DiscreteType, 12 | }, 13 | Box { 14 | high: Vec, 15 | low: Vec, 16 | shape: Vec, 17 | }, 18 | Tuple { 19 | spaces: Vec, 20 | }, 21 | } 22 | 23 | impl SpaceTemplate { 24 | pub fn extract_data(&self, pyo: &PyObject) -> Result { 25 | let gil = Python::acquire_gil(); 26 | let py = gil.python(); 27 | 28 | match self { 29 | SpaceTemplate::Discrete { .. } => { 30 | let n = pyo 31 | .extract::(py) 32 | .map_err(|_| GymError::InvalidConversion)?; 33 | Ok(SpaceData::Discrete(n)) 34 | }, 35 | SpaceTemplate::Box { .. } => { 36 | let v = pyo 37 | .call_method(py, "flatten", NoArgs, None) 38 | .map_err(|_| GymError::InvalidConversion)? 39 | .extract::>(py) 40 | .map_err(|_| GymError::InvalidConversion)?; 41 | Ok(SpaceData::Box(v.into())) 42 | }, 43 | SpaceTemplate::Tuple { .. } => { 44 | let mut tuple = vec![]; 45 | let mut i = 0; 46 | let mut item = pyo.get_item(py, i); 47 | while item.is_ok() { 48 | let pyo_item = self.extract_data(&item.unwrap())?; 49 | tuple.push(pyo_item); 50 | i += 1; 51 | item = pyo.get_item(py, i); 52 | } 53 | Ok(SpaceData::Tuple(tuple.into())) 54 | }, 55 | } 56 | } 57 | 58 | pub fn extract_template(pyo: &PyObject) -> Self { 59 | let gil = Python::acquire_gil(); 60 | let py = gil.python(); 61 | 62 | let class = pyo 63 | .getattr(py, "__class__") 64 | .expect("Unable to extract __class__ (this should never happen)"); 65 | 66 | let name = class 67 | .getattr(py, "__name__") 68 | .expect("Unable to extract __name__ (this should never happen)") 69 | .extract::(py) 70 | .expect("Unable to extract __name__ (this should never happen)"); 71 | 72 | match name.as_ref() { 73 | "Discrete" => { 74 | let n = pyo 75 | .getattr(py, "n") 76 | .expect("Unable to get attribute 'n'") 77 | .extract::(py) 78 | .expect("Unable to convert 'n' to usize"); 79 | Self::Discrete { n } 80 | }, 81 | "Box" => { 82 | let high = pyo 83 | .getattr(py, "high") 84 | .expect("Unable to 
get attribute 'high'") 85 | .call_method(py, "flatten", NoArgs, None) 86 | .expect("Unable to call 'flatten'") 87 | .extract::>(py) 88 | .expect("Unable to convert 'high' to Vec"); 89 | 90 | let low = pyo 91 | .getattr(py, "low") 92 | .expect("Unable to get attribute 'low'") 93 | .call_method(py, "flatten", NoArgs, None) 94 | .expect("Unable to call 'flatten'") 95 | .extract::>(py) 96 | .expect("Unable to convert 'low' to Vec"); 97 | 98 | let shape = pyo 99 | .getattr(py, "shape") 100 | .expect("Unable to get attribute 'shape'") 101 | .extract::>(py) 102 | .expect("Unable to convert 'shape' to Vec"); 103 | 104 | debug_assert_eq!(high.len(), low.len()); 105 | debug_assert_eq!(low.len(), shape.iter().product()); 106 | high.iter().zip(low.iter()).for_each(|(h, l)| debug_assert!(h > l)); 107 | 108 | Self::Box { high, low, shape } 109 | }, 110 | "Tuple" => { 111 | let mut i = 0; 112 | let mut tuple = vec![]; 113 | let mut item = pyo.get_item(py, i); 114 | 115 | while item.is_ok() { 116 | let pyo_item = item.unwrap(); 117 | let space = Self::extract_template(&pyo_item); 118 | tuple.push(space); 119 | i += 1; 120 | item = pyo.get_item(py, i); 121 | } 122 | 123 | Self::Tuple { spaces: tuple } 124 | }, 125 | _ => unreachable!(), 126 | } 127 | } 128 | 129 | pub fn sample(&self) -> SpaceData { 130 | let mut rng = rand::thread_rng(); 131 | match self { 132 | SpaceTemplate::Discrete { n } => SpaceData::Discrete(rng.gen_range(0..*n)), 133 | SpaceTemplate::Box { high, low, shape } => { 134 | let dimensions = shape.len(); 135 | let mut v = vec![]; 136 | for d in 0..dimensions { 137 | for _ in 0..shape[d] { 138 | v.push(rng.gen_range(low[d]..high[d])); 139 | } 140 | } 141 | SpaceData::Box(v.into()) 142 | }, 143 | SpaceTemplate::Tuple { spaces } => { 144 | let mut tuple = vec![]; 145 | for space in spaces { 146 | let sample = space.sample(); 147 | tuple.push(sample); 148 | } 149 | SpaceData::Tuple(tuple.into()) 150 | }, 151 | } 152 | } 153 | } 154 | 
--------------------------------------------------------------------------------