├── .github ├── ISSUE_TEMPLATE │ ├── bug.md │ ├── proposal.md │ └── question.md ├── PULL_REQUEST_TEMPLATE.md ├── stale.yml └── workflows │ ├── build.yml │ └── pre-commit.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.rst ├── CONTRIBUTING.md ├── LICENSE.md ├── README.md ├── bin └── docker_entrypoint ├── gym ├── __init__.py ├── core.py ├── envs │ ├── __init__.py │ ├── box2d │ │ ├── __init__.py │ │ ├── bipedal_walker.py │ │ ├── car_dynamics.py │ │ ├── car_racing.py │ │ └── lunar_lander.py │ ├── classic_control │ │ ├── __init__.py │ │ ├── acrobot.py │ │ ├── assets │ │ │ └── clockwise.png │ │ ├── cartpole.py │ │ ├── continuous_mountain_car.py │ │ ├── mountain_car.py │ │ ├── pendulum.py │ │ └── utils.py │ ├── mujoco │ │ ├── __init__.py │ │ ├── ant.py │ │ ├── ant_v3.py │ │ ├── ant_v4.py │ │ ├── assets │ │ │ ├── ant.xml │ │ │ ├── half_cheetah.xml │ │ │ ├── hopper.xml │ │ │ ├── humanoid.xml │ │ │ ├── humanoidstandup.xml │ │ │ ├── inverted_double_pendulum.xml │ │ │ ├── inverted_pendulum.xml │ │ │ ├── point.xml │ │ │ ├── pusher.xml │ │ │ ├── reacher.xml │ │ │ ├── swimmer.xml │ │ │ └── walker2d.xml │ │ ├── half_cheetah.py │ │ ├── half_cheetah_v3.py │ │ ├── half_cheetah_v4.py │ │ ├── hopper.py │ │ ├── hopper_v3.py │ │ ├── hopper_v4.py │ │ ├── humanoid.py │ │ ├── humanoid_v3.py │ │ ├── humanoid_v4.py │ │ ├── humanoidstandup.py │ │ ├── humanoidstandup_v4.py │ │ ├── inverted_double_pendulum.py │ │ ├── inverted_double_pendulum_v4.py │ │ ├── inverted_pendulum.py │ │ ├── inverted_pendulum_v4.py │ │ ├── mujoco_env.py │ │ ├── mujoco_rendering.py │ │ ├── pusher.py │ │ ├── pusher_v4.py │ │ ├── reacher.py │ │ ├── reacher_v4.py │ │ ├── swimmer.py │ │ ├── swimmer_v3.py │ │ ├── swimmer_v4.py │ │ ├── walker2d.py │ │ ├── walker2d_v3.py │ │ └── walker2d_v4.py │ ├── registration.py │ └── toy_text │ │ ├── __init__.py │ │ ├── blackjack.py │ │ ├── cliffwalking.py │ │ ├── font │ │ └── Minecraft.ttf │ │ ├── frozen_lake.py │ │ ├── img │ │ ├── C2.png │ │ ├── C3.png │ │ ├── C4.png │ 
│ ├── C5.png │ │ ├── C6.png │ │ ├── C7.png │ │ ├── C8.png │ │ ├── C9.png │ │ ├── CA.png │ │ ├── CJ.png │ │ ├── CK.png │ │ ├── CQ.png │ │ ├── CT.png │ │ ├── Card.png │ │ ├── D2.png │ │ ├── D3.png │ │ ├── D4.png │ │ ├── D5.png │ │ ├── D6.png │ │ ├── D7.png │ │ ├── D8.png │ │ ├── D9.png │ │ ├── DA.png │ │ ├── DJ.png │ │ ├── DK.png │ │ ├── DQ.png │ │ ├── DT.png │ │ ├── H2.png │ │ ├── H3.png │ │ ├── H4.png │ │ ├── H5.png │ │ ├── H6.png │ │ ├── H7.png │ │ ├── H8.png │ │ ├── H9.png │ │ ├── HA.png │ │ ├── HJ.png │ │ ├── HK.png │ │ ├── HQ.png │ │ ├── HT.png │ │ ├── S2.png │ │ ├── S3.png │ │ ├── S4.png │ │ ├── S5.png │ │ ├── S6.png │ │ ├── S7.png │ │ ├── S8.png │ │ ├── S9.png │ │ ├── SA.png │ │ ├── SJ.png │ │ ├── SK.png │ │ ├── SQ.png │ │ ├── ST.png │ │ ├── cab_front.png │ │ ├── cab_left.png │ │ ├── cab_rear.png │ │ ├── cab_right.png │ │ ├── cookie.png │ │ ├── cracked_hole.png │ │ ├── elf_down.png │ │ ├── elf_left.png │ │ ├── elf_right.png │ │ ├── elf_up.png │ │ ├── goal.png │ │ ├── gridworld_median_bottom.png │ │ ├── gridworld_median_horiz.png │ │ ├── gridworld_median_left.png │ │ ├── gridworld_median_right.png │ │ ├── gridworld_median_top.png │ │ ├── gridworld_median_vert.png │ │ ├── hole.png │ │ ├── hotel.png │ │ ├── ice.png │ │ ├── mountain_bg1.png │ │ ├── mountain_bg2.png │ │ ├── mountain_cliff.png │ │ ├── mountain_near-cliff1.png │ │ ├── mountain_near-cliff2.png │ │ ├── passenger.png │ │ ├── stool.png │ │ └── taxi_background.png │ │ ├── taxi.py │ │ └── utils.py ├── error.py ├── logger.py ├── py.typed ├── spaces │ ├── __init__.py │ ├── box.py │ ├── dict.py │ ├── discrete.py │ ├── graph.py │ ├── multi_binary.py │ ├── multi_discrete.py │ ├── sequence.py │ ├── space.py │ ├── text.py │ ├── tuple.py │ └── utils.py ├── utils │ ├── __init__.py │ ├── colorize.py │ ├── env_checker.py │ ├── ezpickle.py │ ├── passive_env_checker.py │ ├── play.py │ ├── save_video.py │ ├── seeding.py │ └── step_api_compatibility.py ├── vector │ ├── __init__.py │ ├── async_vector_env.py │ ├── 
sync_vector_env.py │ ├── utils │ │ ├── __init__.py │ │ ├── misc.py │ │ ├── numpy_utils.py │ │ ├── shared_memory.py │ │ └── spaces.py │ └── vector_env.py ├── version.py └── wrappers │ ├── README.md │ ├── __init__.py │ ├── atari_preprocessing.py │ ├── autoreset.py │ ├── clip_action.py │ ├── compatibility.py │ ├── env_checker.py │ ├── filter_observation.py │ ├── flatten_observation.py │ ├── frame_stack.py │ ├── gray_scale_observation.py │ ├── human_rendering.py │ ├── monitoring │ ├── __init__.py │ └── video_recorder.py │ ├── normalize.py │ ├── order_enforcing.py │ ├── pixel_observation.py │ ├── record_episode_statistics.py │ ├── record_video.py │ ├── render_collection.py │ ├── rescale_action.py │ ├── resize_observation.py │ ├── step_api_compatibility.py │ ├── time_aware_observation.py │ ├── time_limit.py │ ├── transform_observation.py │ ├── transform_reward.py │ └── vector_list_info.py ├── py.Dockerfile ├── pyproject.toml ├── requirements.txt ├── setup.py ├── test_requirements.txt └── tests ├── __init__.py ├── envs ├── __init__.py ├── test_action_dim_check.py ├── test_compatibility.py ├── test_env_implementation.py ├── test_envs.py ├── test_make.py ├── test_mujoco.py ├── test_register.py ├── test_spec.py ├── utils.py └── utils_envs.py ├── spaces ├── __init__.py ├── test_box.py ├── test_dict.py ├── test_discrete.py ├── test_graph.py ├── test_multibinary.py ├── test_multidiscrete.py ├── test_sequence.py ├── test_space.py ├── test_spaces.py ├── test_text.py ├── test_tuple.py ├── test_utils.py └── utils.py ├── test_core.py ├── testing_env.py ├── utils ├── __init__.py ├── test_env_checker.py ├── test_passive_env_checker.py ├── test_play.py ├── test_save_video.py ├── test_seeding.py └── test_step_api_compatibility.py ├── vector ├── __init__.py ├── test_async_vector_env.py ├── test_numpy_utils.py ├── test_shared_memory.py ├── test_spaces.py ├── test_sync_vector_env.py ├── test_vector_env.py ├── test_vector_env_info.py ├── test_vector_env_wrapper.py ├── test_vector_make.py 
└── utils.py └── wrappers ├── __init__.py ├── test_atari_preprocessing.py ├── test_autoreset.py ├── test_clip_action.py ├── test_filter_observation.py ├── test_flatten.py ├── test_flatten_observation.py ├── test_frame_stack.py ├── test_gray_scale_observation.py ├── test_human_rendering.py ├── test_nested_dict.py ├── test_normalize.py ├── test_order_enforcing.py ├── test_passive_env_checker.py ├── test_pixel_observation.py ├── test_record_episode_statistics.py ├── test_record_video.py ├── test_rescale_action.py ├── test_resize_observation.py ├── test_step_compatibility.py ├── test_time_aware_observation.py ├── test_time_limit.py ├── test_transform_observation.py ├── test_transform_reward.py ├── test_vector_list_info.py ├── test_video_recorder.py └── utils.py /.github/ISSUE_TEMPLATE/bug.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug Report 3 | about: Submit a bug report 4 | title: "[Bug Report] Bug title" 5 | 6 | --- 7 | 8 | If you are submitting a bug report, please fill in the following details and use the tag [bug]. 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **Code example** 14 | Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful. 15 | 16 | **System Info** 17 | Describe the characteristic of your environment: 18 | * Describe how Gym was installed (pip, docker, source, ...) 19 | * What OS/version of Linux you're using. Note that while we will accept PRs to improve Window's support, we do not officially support it. 20 | * Python version 21 | 22 | **Additional context** 23 | Add any other context about the problem here. 
24 | 25 | ### Checklist 26 | 27 | - [ ] I have checked that there is no similar [issue](https://github.com/openai/gym/issues) in the repo (**required**) 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/proposal.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Proposal 3 | about: Propose changes that are not fixes bugs 4 | title: "[Proposal] Proposal title" 5 | --- 6 | 7 | 8 | 9 | ### Proposal 10 | 11 | A clear and concise description of the proposal. 12 | 13 | ### Motivation 14 | 15 | Please outline the motivation for the proposal. 16 | Is your feature request related to a problem? e.g.,"I'm always frustrated when [...]". 17 | If this is related to another GitHub issue, please link here too. 18 | 19 | ### Pitch 20 | 21 | A clear and concise description of what you want to happen. 22 | 23 | ### Alternatives 24 | 25 | A clear and concise description of any alternative solutions or features you've considered, if any. 26 | 27 | ### Additional context 28 | 29 | Add any other context or screenshots about the feature request here. 30 | 31 | ### Checklist 32 | 33 | - [ ] I have checked that there is no similar [issue](https://github.com/openai/gym/issues) in the repo (**required**) 34 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Question 3 | about: Ask a question 4 | title: "[Question] Question title" 5 | --- 6 | 7 | 8 | ### Question 9 | 10 | If you're a beginner and have basic questions, please ask on [r/reinforcementlearning](https://www.reddit.com/r/reinforcementlearning/) or in the [RL Discord](https://discord.com/invite/xhfNqQv) (if you're new please use the beginners channel). 
Basic questions that are not bugs or feature requests will be closed without reply, because GitHub issues are not an appropriate venue for these. 11 | 12 | Advanced/nontrivial questions, especially in areas where documentation is lacking, are very much welcome. 13 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # Description 2 | 3 | Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. 4 | 5 | Fixes # (issue) 6 | 7 | ## Type of change 8 | 9 | Please delete options that are not relevant. 10 | 11 | - [ ] Bug fix (non-breaking change which fixes an issue) 12 | - [ ] New feature (non-breaking change which adds functionality) 13 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) 14 | - [ ] This change requires a documentation update 15 | 16 | ### Screenshots 17 | Please attach before and after screenshots of the change if applicable. 
18 | 19 | 29 | 30 | # Checklist: 31 | 32 | - [ ] I have run the [`pre-commit` checks](https://pre-commit.com/) with `pre-commit run --all-files` (see `CONTRIBUTING.md` instructions to set it up) 33 | - [ ] I have commented my code, particularly in hard-to-understand areas 34 | - [ ] I have made corresponding changes to the documentation 35 | - [ ] My changes generate no new warnings 36 | - [ ] I have added tests that prove my fix is effective or that my feature works 37 | - [ ] New and existing unit tests pass locally with my changes 38 | 39 | 46 | -------------------------------------------------------------------------------- /.github/stale.yml: -------------------------------------------------------------------------------- 1 | # Configuration for probot-stale - https://github.com/probot/stale 2 | 3 | # Number of days of inactivity before an Issue or Pull Request becomes stale 4 | daysUntilStale: 60 5 | 6 | # Number of days of inactivity before an Issue or Pull Request with the stale label is closed. 7 | # Set to false to disable. If disabled, issues still need to be closed manually, but will remain marked as stale. 8 | daysUntilClose: 14 9 | 10 | # Only issues or pull requests with all of these labels are check if stale. Defaults to `[]` (disabled) 11 | onlyLabels: 12 | - more-information-needed 13 | 14 | # Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable 15 | exemptLabels: 16 | - pinned 17 | - security 18 | - "[Status] Maybe Later" 19 | 20 | # Set to true to ignore issues in a project (defaults to false) 21 | exemptProjects: true 22 | 23 | # Set to true to ignore issues in a milestone (defaults to false) 24 | exemptMilestones: true 25 | 26 | # Set to true to ignore issues with an assignee (defaults to false) 27 | exemptAssignees: true 28 | 29 | # Label to use when marking as stale 30 | staleLabel: stale 31 | 32 | # Comment to post when marking as stale. 
Set to `false` to disable 33 | markComment: > 34 | This issue has been automatically marked as stale because it has not had 35 | recent activity. It will be closed if no further activity occurs. Thank you 36 | for your contributions. 37 | 38 | # Comment to post when removing the stale label. 39 | # unmarkComment: > 40 | # Your comment here. 41 | 42 | # Comment to post when closing a stale Issue or Pull Request. 43 | # closeComment: > 44 | # Your comment here. 45 | 46 | # Limit the number of actions per hour, from 1-30. Default is 30 47 | limitPerRun: 30 48 | 49 | # Limit to only `issues` or `pulls` 50 | only: issues 51 | 52 | # Optionally, specify configuration settings that are specific to just 'issues' or 'pulls': 53 | # pulls: 54 | # daysUntilStale: 30 55 | # markComment: > 56 | # This pull request has been automatically marked as stale because it has not had 57 | # recent activity. It will be closed if no further activity occurs. Thank you 58 | # for your contributions. 59 | 60 | # issues: 61 | # exemptLabels: 62 | # - confirmed -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | on: [pull_request, push] 3 | 4 | permissions: 5 | contents: read # to fetch code (actions/checkout) 6 | 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | strategy: 11 | matrix: 12 | python-version: ['3.6', '3.7', '3.8', '3.9', '3.10'] 13 | steps: 14 | - uses: actions/checkout@v2 15 | - run: | 16 | docker build -f py.Dockerfile \ 17 | --build-arg PYTHON_VERSION=${{ matrix.python-version }} \ 18 | --tag gym-docker . 
19 | - name: Run tests 20 | run: docker run gym-docker pytest 21 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | # https://pre-commit.com 2 | # This GitHub Action assumes that the repo contains a valid .pre-commit-config.yaml file. 3 | name: pre-commit 4 | on: 5 | pull_request: 6 | push: 7 | branches: [master] 8 | permissions: 9 | contents: read # to fetch code (actions/checkout) 10 | jobs: 11 | pre-commit: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v2 15 | - uses: actions/setup-python@v2 16 | - run: pip install pre-commit 17 | - run: pre-commit --version 18 | - run: pre-commit install 19 | - run: pre-commit run --all-files 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.pyc 3 | *.py~ 4 | .DS_Store 5 | .cache 6 | .pytest_cache/ 7 | 8 | # Setuptools distribution and build folders. 9 | /dist/ 10 | /build 11 | 12 | # Virtualenv 13 | /env 14 | 15 | # Python egg metadata, regenerated from source files by setuptools. 
16 | /*.egg-info 17 | 18 | *.sublime-project 19 | *.sublime-workspace 20 | 21 | logs/ 22 | 23 | .ipynb_checkpoints 24 | ghostdriver.log 25 | 26 | junk 27 | MUJOCO_LOG.txt 28 | 29 | rllab_mujoco 30 | 31 | tutorial/*.html 32 | 33 | # IDE files 34 | .eggs 35 | .tox 36 | 37 | # PyCharm project files 38 | .idea 39 | vizdoom.ini 40 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | repos: 3 | - repo: https://github.com/python/black 4 | rev: 22.3.0 5 | hooks: 6 | - id: black 7 | - repo: https://github.com/codespell-project/codespell 8 | rev: v2.1.0 9 | hooks: 10 | - id: codespell 11 | args: 12 | - --ignore-words-list=nd,reacher,thist,ths, ure, referenc 13 | - repo: https://gitlab.com/PyCQA/flake8 14 | rev: 4.0.1 15 | hooks: 16 | - id: flake8 17 | args: 18 | - '--per-file-ignores=*/__init__.py:F401 gym/envs/registration.py:E704' 19 | - --ignore=E203,W503,E741 20 | - --max-complexity=30 21 | - --max-line-length=456 22 | - --show-source 23 | - --statistics 24 | - repo: https://github.com/PyCQA/isort 25 | rev: 5.10.1 26 | hooks: 27 | - id: isort 28 | args: ["--profile", "black"] 29 | - repo: https://github.com/pycqa/pydocstyle 30 | rev: 6.1.1 # pick a git hash / tag to point to 31 | hooks: 32 | - id: pydocstyle 33 | exclude: ^(gym/version.py)|(gym/envs/)|(tests/) 34 | args: 35 | - --source 36 | - --explain 37 | - --convention=google 38 | additional_dependencies: ["toml"] 39 | - repo: https://github.com/asottile/pyupgrade 40 | rev: v2.32.0 41 | hooks: 42 | - id: pyupgrade 43 | # TODO: remove `--keep-runtime-typing` option 44 | args: ["--py36-plus", "--keep-runtime-typing"] 45 | - repo: local 46 | hooks: 47 | - id: pyright 48 | name: pyright 49 | entry: pyright 50 | language: node 51 | pass_filenames: false 52 | types: [python] 53 | additional_dependencies: ["pyright"] 54 | args: 55 | - --project=pyproject.toml 56 | 
-------------------------------------------------------------------------------- /CODE_OF_CONDUCT.rst: -------------------------------------------------------------------------------- 1 | OpenAI Gym is dedicated to providing a harassment-free experience for 2 | everyone, regardless of gender, gender identity and expression, sexual 3 | orientation, disability, physical appearance, body size, age, race, or 4 | religion. We do not tolerate harassment of participants in any form. 5 | 6 | This code of conduct applies to all OpenAI Gym spaces (including Gist 7 | comments) both online and off. Anyone who violates this code of 8 | conduct may be sanctioned or expelled from these spaces at the 9 | discretion of the OpenAI team. 10 | 11 | We may add additional rules over time, which will be made clearly 12 | available to participants. Participants are responsible for knowing 13 | and abiding by these rules. 14 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Gym Contribution Guidelines 2 | 3 | At this time we are currently accepting the current forms of contributions: 4 | 5 | - Bug reports (keep in mind that changing environment behavior should be minimized as that requires releasing a new version of the environment and makes results hard to compare across versions) 6 | - Pull requests for bug fixes 7 | - Documentation improvements 8 | 9 | Notably, we are not accepting these forms of contributions: 10 | 11 | - New environments 12 | - New features 13 | 14 | This may change in the future. 15 | If you wish to make a Gym environment, follow the instructions in [Creating Environments](https://github.com/openai/gym/blob/master/docs/creating_environments.md). When your environment works, you can make a PR to add it to the bottom of the [List of Environments](https://github.com/openai/gym/blob/master/docs/third_party_environments.md). 
16 | 17 | 18 | Edit July 27, 2021: Please see https://github.com/openai/gym/issues/2259 for new contributing standards 19 | 20 | # Development 21 | This section contains technical instructions & hints for the contributors. 22 | 23 | ## Type checking 24 | The project uses `pyright` to check types. 25 | To type check locally, install `pyright` per official [instructions](https://github.com/microsoft/pyright#command-line). 26 | It's configuration lives within `pyproject.toml`. It includes list of included and excluded files currently supporting type checks. 27 | To run `pyright` for the project, run the pre-commit process (`pre-commit run --all-files`) or `pyright --project=pyproject.toml` 28 | Alternatively, pyright is a built-in feature of VSCode that will automatically provide type hinting. 29 | 30 | ### Adding typing to more modules and packages 31 | If you would like to add typing to a module in the project, 32 | the list of included, excluded and strict files can be found in pyproject.toml (pyproject.toml -> [tool.pyright]). 33 | To run `pyright` for the project, run the pre-commit process (`pre-commit run --all-files`) or `pyright` 34 | 35 | ## Git hooks 36 | The CI will run several checks on the new code pushed to the Gym repository. These checks can also be run locally without waiting for the CI by following the steps below: 37 | 1. [install `pre-commit`](https://pre-commit.com/#install), 38 | 2. Install the Git hooks by running `pre-commit install`. 39 | 40 | Once those two steps are done, the Git hooks will be run automatically at every new commit. 41 | The Git hooks can also be run manually with `pre-commit run --all-files`, and if needed they can be skipped (not recommended) with `git commit --no-verify`. 42 | **Note:** you may have to run `pre-commit run --all-files` manually a couple of times to make it pass when you commit, as each formatting tool will first format the code and fail the first time but should pass the second time. 
43 | 44 | Additionally, for pull requests, the project runs a number of tests for the whole project using [pytest](https://docs.pytest.org/en/latest/getting-started.html#install-pytest). 45 | These tests can be run locally with `pytest` in the root folder. 46 | 47 | ## Docstrings 48 | Pydocstyle has been added to the pre-commit process such that all new functions follow the [google docstring style](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html). 49 | All new functions require either a short docstring, a single line explaining the purpose of a function 50 | or a multiline docstring that documents each argument and the return type (if there is one) of the function. 51 | In addition, new file and class require top docstrings that should outline the purpose of the file/class. 52 | For classes, code block examples can be provided in the top docstring and not the constructor arguments. 53 | 54 | To check your docstrings are correct, run `pre-commit run --all-files` or `pydocstyle --source --explain --convention=google`. 55 | If all docstrings that fail, the source and reason for the failure is provided. -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2016 OpenAI (https://openai.com) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | 23 | # Mujoco models 24 | This work is derived from [MuJuCo models](http://www.mujoco.org/forum/index.php?resources/) used under the following license: 25 | ``` 26 | This file is part of MuJoCo. 27 | Copyright 2009-2015 Roboti LLC. 28 | Mujoco :: Advanced physics simulation engine 29 | Source : www.roboti.us 30 | Version : 1.31 31 | Released : 23Apr16 32 | Author :: Vikash Kumar 33 | Contacts : kumar@roboti.us 34 | ``` 35 | -------------------------------------------------------------------------------- /bin/docker_entrypoint: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is the entrypoint for our Docker image. 3 | 4 | set -ex 5 | 6 | # Set up display; otherwise rendering will fail 7 | Xvfb -screen 0 1024x768x24 & 8 | export DISPLAY=:0 9 | 10 | # Wait for the file to come up 11 | display=0 12 | file="/tmp/.X11-unix/X$display" 13 | for i in $(seq 1 10); do 14 | if [ -e "$file" ]; then 15 | break 16 | fi 17 | 18 | echo "Waiting for $file to be created (try $i/10)" 19 | sleep "$i" 20 | done 21 | if ! 
[ -e "$file" ]; then 22 | echo "Timing out: $file was not created" 23 | exit 1 24 | fi 25 | 26 | exec "$@" 27 | -------------------------------------------------------------------------------- /gym/__init__.py: -------------------------------------------------------------------------------- 1 | """Root __init__ of the gym module setting the __all__ of gym modules.""" 2 | # isort: skip_file 3 | 4 | from gym import error 5 | from gym.version import VERSION as __version__ 6 | 7 | from gym.core import ( 8 | Env, 9 | Wrapper, 10 | ObservationWrapper, 11 | ActionWrapper, 12 | RewardWrapper, 13 | ) 14 | from gym.spaces import Space 15 | from gym.envs import make, spec, register 16 | from gym import logger 17 | from gym import vector 18 | from gym import wrappers 19 | import os 20 | import sys 21 | 22 | __all__ = ["Env", "Space", "Wrapper", "make", "spec", "register"] 23 | 24 | # Initializing pygame initializes audio connections through SDL. SDL uses alsa by default on all Linux systems 25 | # SDL connecting to alsa frequently create these giant lists of warnings every time you import an environment using 26 | # pygame 27 | # DSP is far more benign (and should probably be the default in SDL anyways) 28 | 29 | if sys.platform.startswith("linux"): 30 | os.environ["SDL_AUDIODRIVER"] = "dsp" 31 | 32 | os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "hide" 33 | 34 | try: 35 | import gym_notices.notices as notices 36 | 37 | # print version warning if necessary 38 | notice = notices.notices.get(__version__) 39 | if notice: 40 | print(notice, file=sys.stderr) 41 | 42 | except Exception: # nosec 43 | pass 44 | -------------------------------------------------------------------------------- /gym/envs/box2d/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.box2d.bipedal_walker import BipedalWalker, BipedalWalkerHardcore 2 | from gym.envs.box2d.car_racing import CarRacing 3 | from gym.envs.box2d.lunar_lander import LunarLander, 
LunarLanderContinuous 4 | -------------------------------------------------------------------------------- /gym/envs/classic_control/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.classic_control.acrobot import AcrobotEnv 2 | from gym.envs.classic_control.cartpole import CartPoleEnv 3 | from gym.envs.classic_control.continuous_mountain_car import Continuous_MountainCarEnv 4 | from gym.envs.classic_control.mountain_car import MountainCarEnv 5 | from gym.envs.classic_control.pendulum import PendulumEnv 6 | -------------------------------------------------------------------------------- /gym/envs/classic_control/assets/clockwise.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/classic_control/assets/clockwise.png -------------------------------------------------------------------------------- /gym/envs/classic_control/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions used for classic control environments. 3 | """ 4 | 5 | from typing import Optional, SupportsFloat, Tuple 6 | 7 | 8 | def verify_number_and_cast(x: SupportsFloat) -> float: 9 | """Verify parameter is a single number and cast to a float.""" 10 | try: 11 | x = float(x) 12 | except (ValueError, TypeError): 13 | raise ValueError(f"An option ({x}) could not be converted to a float.") 14 | return x 15 | 16 | 17 | def maybe_parse_reset_bounds( 18 | options: Optional[dict], default_low: float, default_high: float 19 | ) -> Tuple[float, float]: 20 | """ 21 | This function can be called during a reset() to customize the sampling 22 | ranges for setting the initial state distributions. 23 | 24 | Args: 25 | options: Options passed in to reset(). 26 | default_low: Default lower limit to use, if none specified in options. 
# NOTE(review): this chunk is a flattened repository dump — each original file
# line carries an "N |" prefix and files are joined with dash separators.  The
# code below restores every embedded unit as properly formatted Python.

# --- gym/envs/classic_control/utils.py (fragment) -----------------------------
# Tail of a reset-bounds helper whose `def` line lies before this chunk: when
# `options` is None it returns (default_low, default_high); otherwise it reads
# "low"/"high" from `options`, casts both through verify_number_and_cast,
# raises ValueError if low > high, and returns the (low, high) tuple.  Kept as
# a note because the signature is not visible here — TODO confirm against the
# full file.

# --- gym/envs/mujoco/__init__.py ----------------------------------------------
from gym.envs.mujoco.mujoco_env import MujocoEnv, MuJocoPyEnv  # isort:skip

from gym.envs.mujoco.ant import AntEnv
from gym.envs.mujoco.half_cheetah import HalfCheetahEnv
from gym.envs.mujoco.hopper import HopperEnv
from gym.envs.mujoco.humanoid import HumanoidEnv
from gym.envs.mujoco.humanoidstandup import HumanoidStandupEnv
from gym.envs.mujoco.inverted_double_pendulum import InvertedDoublePendulumEnv
from gym.envs.mujoco.inverted_pendulum import InvertedPendulumEnv
from gym.envs.mujoco.pusher import PusherEnv
from gym.envs.mujoco.reacher import ReacherEnv
from gym.envs.mujoco.swimmer import SwimmerEnv
from gym.envs.mujoco.walker2d import Walker2dEnv

# --- gym/envs/mujoco/ant.py ---------------------------------------------------
import numpy as np

from gym import utils
from gym.envs.mujoco import MuJocoPyEnv
from gym.spaces import Box


class AntEnv(MuJocoPyEnv, utils.EzPickle):
    """Quadruped ("ant") locomotion task on the mujoco-py backend.

    Reward = forward velocity - control cost - contact cost + survive bonus.
    """

    metadata = {
        "render_modes": ["human", "rgb_array", "depth_array"],
        "render_fps": 20,
    }

    def __init__(self, **kwargs):
        obs_space = Box(low=-np.inf, high=np.inf, shape=(111,), dtype=np.float64)
        MuJocoPyEnv.__init__(
            self, "ant.xml", 5, observation_space=obs_space, **kwargs
        )
        utils.EzPickle.__init__(self, **kwargs)

    def step(self, a):
        """Apply action `a` for `frame_skip` frames; return the Gym 5-tuple."""
        x_before = self.get_body_com("torso")[0]
        self.do_simulation(a, self.frame_skip)
        x_after = self.get_body_com("torso")[0]

        forward_reward = (x_after - x_before) / self.dt
        ctrl_cost = 0.5 * np.square(a).sum()
        clipped_contact = np.clip(self.sim.data.cfrc_ext, -1, 1)
        contact_cost = 0.5 * 1e-3 * np.sum(np.square(clipped_contact))
        survive_reward = 1.0
        reward = forward_reward - ctrl_cost - contact_cost + survive_reward

        # Episode ends when the torso z leaves [0.2, 1.0] or the state goes
        # non-finite (was an inverted double-negative in the original).
        state = self.state_vector()
        healthy = np.isfinite(state).all() and 0.2 <= state[2] <= 1.0
        terminated = not healthy

        obs = self._get_obs()
        if self.render_mode == "human":
            self.render()
        info = {
            "reward_forward": forward_reward,
            "reward_ctrl": -ctrl_cost,
            "reward_contact": -contact_cost,
            "reward_survive": survive_reward,
        }
        return obs, reward, terminated, False, info

    def _get_obs(self):
        """Joint positions (minus root x/y), velocities, clipped contact forces."""
        data = self.sim.data
        return np.concatenate(
            (
                data.qpos.flat[2:],
                data.qvel.flat,
                np.clip(data.cfrc_ext, -1, 1).flat,
            )
        )

    def reset_model(self):
        """Reset to the initial pose plus small uniform/Gaussian noise."""
        qpos = self.init_qpos + self.np_random.uniform(
            size=self.model.nq, low=-0.1, high=0.1
        )
        qvel = self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1
        self.set_state(qpos, qvel)
        return self._get_obs()

    def viewer_setup(self):
        assert self.viewer is not None
        self.viewer.cam.distance = self.model.stat.extent * 0.5

# --- gym/envs/mujoco/assets/{hopper,inverted_double_pendulum,inverted_pendulum,
#     point,reacher,swimmer}.xml: MuJoCo model files whose XML markup was
#     stripped by the dump; nothing recoverable here. ---

# --- gym/envs/mujoco/half_cheetah.py ------------------------------------------
# (original file header imports are identical to ant.py: numpy, gym.utils,
#  MuJocoPyEnv, Box)


class HalfCheetahEnv(MuJocoPyEnv, utils.EzPickle):
    """Planar half-cheetah runner on the mujoco-py backend."""

    metadata = {
        "render_modes": ["human", "rgb_array", "depth_array"],
        "render_fps": 20,
    }

    def __init__(self, **kwargs):
        obs_space = Box(low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64)
        MuJocoPyEnv.__init__(
            self, "half_cheetah.xml", 5, observation_space=obs_space, **kwargs
        )
        utils.EzPickle.__init__(self, **kwargs)

    def step(self, action):
        """Reward = forward velocity minus a 0.1-weighted control cost."""
        x_before = self.sim.data.qpos[0]
        self.do_simulation(action, self.frame_skip)
        x_after = self.sim.data.qpos[0]

        obs = self._get_obs()
        reward_ctrl = -0.1 * np.square(action).sum()
        reward_run = (x_after - x_before) / self.dt
        reward = reward_ctrl + reward_run

        if self.render_mode == "human":
            self.render()
        # This task never terminates on its own; episodes end via time limit.
        return (
            obs,
            reward,
            False,
            False,
            {"reward_run": reward_run, "reward_ctrl": reward_ctrl},
        )

    def _get_obs(self):
        """Joint positions (minus root x) and joint velocities."""
        data = self.sim.data
        return np.concatenate((data.qpos.flat[1:], data.qvel.flat))

    def reset_model(self):
        qpos = self.init_qpos + self.np_random.uniform(
            low=-0.1, high=0.1, size=self.model.nq
        )
        qvel = self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1
        self.set_state(qpos, qvel)
        return self._get_obs()

    def viewer_setup(self):
        assert self.viewer is not None
        self.viewer.cam.distance = self.model.stat.extent * 0.5

# --- gym/envs/mujoco/hopper.py ------------------------------------------------
# (same file header imports as ant.py)


class HopperEnv(MuJocoPyEnv, utils.EzPickle):
    """One-legged planar hopper on the mujoco-py backend."""

    metadata = {
        "render_modes": ["human", "rgb_array", "depth_array"],
        "render_fps": 125,
    }

    def __init__(self, **kwargs):
        obs_space = Box(low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64)
        MuJocoPyEnv.__init__(
            self, "hopper.xml", 4, observation_space=obs_space, **kwargs
        )
        utils.EzPickle.__init__(self, **kwargs)

    def step(self, a):
        """Forward-progress reward + alive bonus - small control cost."""
        x_before = self.sim.data.qpos[0]
        self.do_simulation(a, self.frame_skip)
        x_after, height, angle = self.sim.data.qpos[0:3]

        alive_bonus = 1.0
        reward = (x_after - x_before) / self.dt
        reward += alive_bonus
        reward -= 1e-3 * np.square(a).sum()

        # Healthy: finite state, bounded non-root coordinates, torso high
        # enough and near upright.
        s = self.state_vector()
        healthy = (
            np.isfinite(s).all()
            and (np.abs(s[2:]) < 100).all()
            and height > 0.7
            and abs(angle) < 0.2
        )
        terminated = not healthy

        obs = self._get_obs()
        if self.render_mode == "human":
            self.render()
        return obs, reward, terminated, False, {}

    def _get_obs(self):
        """Positions (minus root x) and velocities clipped to [-10, 10]."""
        data = self.sim.data
        return np.concatenate(
            (data.qpos.flat[1:], np.clip(data.qvel.flat, -10, 10))
        )

    def reset_model(self):
        qpos = self.init_qpos + self.np_random.uniform(
            low=-0.005, high=0.005, size=self.model.nq
        )
        qvel = self.init_qvel + self.np_random.uniform(
            low=-0.005, high=0.005, size=self.model.nv
        )
        self.set_state(qpos, qvel)
        return self._get_obs()

    def viewer_setup(self):
        assert self.viewer is not None
        self.viewer.cam.trackbodyid = 2
        self.viewer.cam.distance = self.model.stat.extent * 0.75
        self.viewer.cam.lookat[2] = 1.15
        self.viewer.cam.elevation = -20

# --- gym/envs/mujoco/humanoid.py ----------------------------------------------
# (same file header imports as ant.py)


def mass_center(model, sim):
    """Return the x-coordinate of the model's whole-body center of mass."""
    mass = np.expand_dims(model.body_mass, 1)
    com = np.sum(mass * sim.data.xipos, 0) / np.sum(mass)
    return com[0]


class HumanoidEnv(MuJocoPyEnv, utils.EzPickle):
    """3D humanoid locomotion on the mujoco-py backend."""

    metadata = {
        "render_modes": ["human", "rgb_array", "depth_array"],
        "render_fps": 67,
    }

    def __init__(self, **kwargs):
        obs_space = Box(low=-np.inf, high=np.inf, shape=(376,), dtype=np.float64)
        MuJocoPyEnv.__init__(
            self, "humanoid.xml", 5, observation_space=obs_space, **kwargs
        )
        utils.EzPickle.__init__(self, **kwargs)

    def _get_obs(self):
        """Positions, velocities, body inertias/velocities, actuator and
        external contact forces, all flattened."""
        data = self.sim.data
        return np.concatenate(
            (
                data.qpos.flat[2:],
                data.qvel.flat,
                data.cinert.flat,
                data.cvel.flat,
                data.qfrc_actuator.flat,
                data.cfrc_ext.flat,
            )
        )

    def step(self, a):
        """Center-of-mass velocity reward - control/impact costs + alive bonus."""
        pos_before = mass_center(self.model, self.sim)
        self.do_simulation(a, self.frame_skip)
        pos_after = mass_center(self.model, self.sim)

        alive_bonus = 5.0
        data = self.sim.data
        lin_vel_cost = 1.25 * (pos_after - pos_before) / self.dt
        quad_ctrl_cost = 0.1 * np.square(data.ctrl).sum()
        # Impact cost is capped at 10 so violent contacts don't dominate.
        quad_impact_cost = min(0.5e-6 * np.square(data.cfrc_ext).sum(), 10)
        reward = lin_vel_cost - quad_ctrl_cost - quad_impact_cost + alive_bonus

        # Terminate when torso height leaves the standing band [1.0, 2.0].
        height = self.sim.data.qpos[2]
        terminated = bool(height < 1.0 or height > 2.0)

        if self.render_mode == "human":
            self.render()
        info = {
            "reward_linvel": lin_vel_cost,
            "reward_quadctrl": -quad_ctrl_cost,
            "reward_alive": alive_bonus,
            "reward_impact": -quad_impact_cost,
        }
        return self._get_obs(), reward, terminated, False, info

    def reset_model(self):
        c = 0.01
        self.set_state(
            self.init_qpos
            + self.np_random.uniform(low=-c, high=c, size=self.model.nq),
            self.init_qvel
            + self.np_random.uniform(low=-c, high=c, size=self.model.nv),
        )
        return self._get_obs()

    def viewer_setup(self):
        assert self.viewer is not None
        self.viewer.cam.trackbodyid = 1
        self.viewer.cam.distance = self.model.stat.extent * 1.0
        self.viewer.cam.lookat[2] = 2.0
        self.viewer.cam.elevation = -20

# --- gym/envs/mujoco/humanoidstandup.py ---------------------------------------
# (same file header imports as ant.py)


class HumanoidStandupEnv(MuJocoPyEnv, utils.EzPickle):
    """Humanoid stand-up task: reward height gained, never terminates."""

    metadata = {
        "render_modes": ["human", "rgb_array", "depth_array"],
        "render_fps": 67,
    }

    def __init__(self, **kwargs):
        obs_space = Box(low=-np.inf, high=np.inf, shape=(376,), dtype=np.float64)
        MuJocoPyEnv.__init__(
            self,
            "humanoidstandup.xml",
            5,
            observation_space=obs_space,
            **kwargs
        )
        utils.EzPickle.__init__(self, **kwargs)

    def _get_obs(self):
        """Same observation layout as HumanoidEnv."""
        data = self.sim.data
        return np.concatenate(
            (
                data.qpos.flat[2:],
                data.qvel.flat,
                data.cinert.flat,
                data.cvel.flat,
                data.qfrc_actuator.flat,
                data.cfrc_ext.flat,
            )
        )

    def step(self, a):
        """Reward torso height gained per simulator timestep minus costs."""
        self.do_simulation(a, self.frame_skip)
        data = self.sim.data
        pos_after = data.qpos[2]
        uph_cost = (pos_after - 0) / self.model.opt.timestep

        quad_ctrl_cost = 0.1 * np.square(data.ctrl).sum()
        quad_impact_cost = min(0.5e-6 * np.square(data.cfrc_ext).sum(), 10)
        reward = uph_cost - quad_ctrl_cost - quad_impact_cost + 1

        if self.render_mode == "human":
            self.render()
        info = {
            "reward_linup": uph_cost,
            "reward_quadctrl": -quad_ctrl_cost,
            "reward_impact": -quad_impact_cost,
        }
        # Standing up never terminates; only the time limit ends episodes.
        return self._get_obs(), reward, False, False, info

    def reset_model(self):
        c = 0.01
        self.set_state(
            self.init_qpos
            + self.np_random.uniform(low=-c, high=c, size=self.model.nq),
            self.init_qvel
            + self.np_random.uniform(low=-c, high=c, size=self.model.nv),
        )
        return self._get_obs()

    def viewer_setup(self):
        assert self.viewer is not None
        self.viewer.cam.trackbodyid = 1
        self.viewer.cam.distance = self.model.stat.extent * 1.0
        self.viewer.cam.lookat[2] = 0.8925
        self.viewer.cam.elevation = -20

# --- gym/envs/mujoco/inverted_double_pendulum.py ------------------------------
# (same file header imports as ant.py)


class InvertedDoublePendulumEnv(MuJocoPyEnv, utils.EzPickle):
    """Cart with a two-link pendulum; keep the tip high and still."""

    metadata = {
        "render_modes": ["human", "rgb_array", "depth_array"],
        "render_fps": 20,
    }

    def __init__(self, **kwargs):
        obs_space = Box(low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64)
        MuJocoPyEnv.__init__(
            self,
            "inverted_double_pendulum.xml",
            5,
            observation_space=obs_space,
            **kwargs
        )
        utils.EzPickle.__init__(self, **kwargs)

    def step(self, action):
        """Alive bonus minus tip-distance and joint-velocity penalties."""
        self.do_simulation(action, self.frame_skip)

        obs = self._get_obs()
        x, _, y = self.sim.data.site_xpos[0]  # pendulum tip site
        dist_penalty = 0.01 * x**2 + (y - 2) ** 2
        v1, v2 = self.sim.data.qvel[1:3]
        vel_penalty = 1e-3 * v1**2 + 5e-3 * v2**2
        alive_bonus = 10
        reward = alive_bonus - dist_penalty - vel_penalty
        # Fail when the tip drops to or below height 1.
        terminated = bool(y <= 1)

        if self.render_mode == "human":
            self.render()
        return obs, reward, terminated, False, {}

    def _get_obs(self):
        """Cart position, link angle sines/cosines, clipped velocities/forces."""
        data = self.sim.data
        return np.concatenate(
            (
                data.qpos[:1],  # cart x position
                np.sin(data.qpos[1:]),  # link angles
                np.cos(data.qpos[1:]),
                np.clip(data.qvel, -10, 10),
                np.clip(data.qfrc_constraint, -10, 10),
            )
        ).ravel()

    def reset_model(self):
        self.set_state(
            self.init_qpos
            + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
            self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1,
        )
        return self._get_obs()

    def viewer_setup(self):
        assert self.viewer is not None
        cam = self.viewer.cam
        cam.trackbodyid = 0
        cam.distance = self.model.stat.extent * 0.5
        cam.lookat[2] = 0.12250000000000005  # v.model.stat.center[2]

# --- gym/envs/mujoco/inverted_pendulum.py -------------------------------------
# (same file header imports as ant.py)


class InvertedPendulumEnv(MuJocoPyEnv, utils.EzPickle):
    """Classic cart-pole on the mujoco-py backend: +1 per step while upright."""

    metadata = {
        "render_modes": ["human", "rgb_array", "depth_array"],
        "render_fps": 25,
    }

    def __init__(self, **kwargs):
        utils.EzPickle.__init__(self, **kwargs)
        obs_space = Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64)
        MuJocoPyEnv.__init__(
            self,
            "inverted_pendulum.xml",
            2,
            observation_space=obs_space,
            **kwargs
        )

    def step(self, a):
        reward = 1.0
        self.do_simulation(a, self.frame_skip)

        obs = self._get_obs()
        # Fail when the state goes non-finite or the pole tilts past 0.2 rad.
        terminated = bool(not np.isfinite(obs).all() or (np.abs(obs[1]) > 0.2))

        if self.render_mode == "human":
            self.render()
        return obs, reward, terminated, False, {}

    def reset_model(self):
        qpos = self.init_qpos + self.np_random.uniform(
            size=self.model.nq, low=-0.01, high=0.01
        )
        qvel = self.init_qvel + self.np_random.uniform(
            size=self.model.nv, low=-0.01, high=0.01
        )
        self.set_state(qpos, qvel)
        return self._get_obs()

    def _get_obs(self):
        """Cart position, pole angle, and their velocities."""
        return np.concatenate((self.sim.data.qpos, self.sim.data.qvel)).ravel()

    def viewer_setup(self):
        assert self.viewer is not None
        self.viewer.cam.trackbodyid = 0
        self.viewer.cam.distance = self.model.stat.extent

# --- gym/envs/mujoco/pusher.py ------------------------------------------------
# (same file header imports as ant.py)


class PusherEnv(MuJocoPyEnv, utils.EzPickle):
    """Arm that pushes a cylinder toward a goal on the table."""

    metadata = {
        "render_modes": ["human", "rgb_array", "depth_array"],
        "render_fps": 20,
    }

    def __init__(self, **kwargs):
        utils.EzPickle.__init__(self, **kwargs)
        obs_space = Box(low=-np.inf, high=np.inf, shape=(23,), dtype=np.float64)
        MuJocoPyEnv.__init__(
            self, "pusher.xml", 5, observation_space=obs_space, **kwargs
        )

    def step(self, a):
        """Distance/control-cost reward; never terminates on its own.

        NOTE: the reward is computed from the PRE-step state (before
        do_simulation), matching the historical gym behavior.
        """
        vec_to_arm = self.get_body_com("object") - self.get_body_com("tips_arm")
        vec_to_goal = self.get_body_com("object") - self.get_body_com("goal")

        reward_near = -np.linalg.norm(vec_to_arm)
        reward_dist = -np.linalg.norm(vec_to_goal)
        reward_ctrl = -np.square(a).sum()
        reward = reward_dist + 0.1 * reward_ctrl + 0.5 * reward_near

        self.do_simulation(a, self.frame_skip)
        if self.render_mode == "human":
            self.render()

        obs = self._get_obs()
        return (
            obs,
            reward,
            False,
            False,
            {"reward_dist": reward_dist, "reward_ctrl": reward_ctrl},
        )

    def viewer_setup(self):
        assert self.viewer is not None
        self.viewer.cam.trackbodyid = -1
        self.viewer.cam.distance = 4.0

    def reset_model(self):
        """Randomize the cylinder position; keep the goal fixed at the origin."""
        # BUGFIX: the original aliased `qpos = self.init_qpos` and then wrote
        # into qpos[-4:], silently mutating `self.init_qpos` on every reset.
        # Copying preserves observable behavior while keeping init_qpos intact.
        qpos = self.init_qpos.copy()

        self.goal_pos = np.asarray([0, 0])
        # Rejection-sample a cylinder position at least 0.17 from the goal.
        while True:
            self.cylinder_pos = np.concatenate(
                [
                    self.np_random.uniform(low=-0.3, high=0, size=1),
                    self.np_random.uniform(low=-0.2, high=0.2, size=1),
                ]
            )
            if np.linalg.norm(self.cylinder_pos - self.goal_pos) > 0.17:
                break

        qpos[-4:-2] = self.cylinder_pos
        qpos[-2:] = self.goal_pos
        qvel = self.init_qvel + self.np_random.uniform(
            low=-0.005, high=0.005, size=self.model.nv
        )
        qvel[-4:] = 0  # object and goal start at rest
        self.set_state(qpos, qvel)
        return self._get_obs()

    def _get_obs(self):
        """Arm joint positions/velocities plus key body positions."""
        data = self.sim.data
        return np.concatenate(
            (
                data.qpos.flat[:7],
                data.qvel.flat[:7],
                self.get_body_com("tips_arm"),
                self.get_body_com("object"),
                self.get_body_com("goal"),
            )
        )

# --- gym/envs/mujoco/reacher.py -----------------------------------------------
# (same file header imports as ant.py)


class ReacherEnv(MuJocoPyEnv, utils.EzPickle):
    """Two-link arm reaching toward a randomized target."""

    metadata = {
        "render_modes": ["human", "rgb_array", "depth_array"],
        "render_fps": 50,
    }

    def __init__(self, **kwargs):
        utils.EzPickle.__init__(self, **kwargs)
        obs_space = Box(low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64)
        MuJocoPyEnv.__init__(
            self, "reacher.xml", 2, observation_space=obs_space, **kwargs
        )

    def step(self, a):
        """Negative fingertip-to-target distance minus control cost.

        NOTE: reward uses the PRE-step state, matching historical gym behavior.
        """
        vec = self.get_body_com("fingertip") - self.get_body_com("target")
        reward_dist = -np.linalg.norm(vec)
        reward_ctrl = -np.square(a).sum()
        reward = reward_dist + reward_ctrl

        self.do_simulation(a, self.frame_skip)
        if self.render_mode == "human":
            self.render()

        obs = self._get_obs()
        return (
            obs,
            reward,
            False,
            False,
            {"reward_dist": reward_dist, "reward_ctrl": reward_ctrl},
        )

    def viewer_setup(self):
        assert self.viewer is not None
        self.viewer.cam.trackbodyid = 0

    def reset_model(self):
        """Randomize arm pose and resample the goal inside radius 0.2."""
        qpos = (
            self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq)
            + self.init_qpos
        )
        while True:
            self.goal = self.np_random.uniform(low=-0.2, high=0.2, size=2)
            if np.linalg.norm(self.goal) < 0.2:
                break
        qpos[-2:] = self.goal
        qvel = self.init_qvel + self.np_random.uniform(
            low=-0.005, high=0.005, size=self.model.nv
        )
        qvel[-2:] = 0  # target is static
        self.set_state(qpos, qvel)
        return self._get_obs()

    def _get_obs(self):
        """Joint-angle cos/sin, target position, arm velocities, tip offset."""
        data = self.sim.data
        theta = data.qpos.flat[:2]
        return np.concatenate(
            (
                np.cos(theta),
                np.sin(theta),
                data.qpos.flat[2:],
                data.qvel.flat[:2],
                self.get_body_com("fingertip") - self.get_body_com("target"),
            )
        )

# --- gym/envs/mujoco/swimmer.py -----------------------------------------------
# (same file header imports as ant.py)


class SwimmerEnv(MuJocoPyEnv, utils.EzPickle):
    """Three-link swimmer rewarded for forward progress."""

    metadata = {
        "render_modes": ["human", "rgb_array", "depth_array"],
        "render_fps": 25,
    }

    def __init__(self, **kwargs):
        obs_space = Box(low=-np.inf, high=np.inf, shape=(8,), dtype=np.float64)
        MuJocoPyEnv.__init__(
            self, "swimmer.xml", 4, observation_space=obs_space, **kwargs
        )
        utils.EzPickle.__init__(self, **kwargs)

    def step(self, a):
        """Forward-velocity reward minus a tiny control cost; never terminates."""
        ctrl_cost_coeff = 0.0001
        x_before = self.sim.data.qpos[0]
        self.do_simulation(a, self.frame_skip)
        x_after = self.sim.data.qpos[0]

        reward_fwd = (x_after - x_before) / self.dt
        reward_ctrl = -ctrl_cost_coeff * np.square(a).sum()
        reward = reward_fwd + reward_ctrl
        obs = self._get_obs()

        if self.render_mode == "human":
            self.render()

        return (
            obs,
            reward,
            False,
            False,
            {"reward_fwd": reward_fwd, "reward_ctrl": reward_ctrl},
        )

    def _get_obs(self):
        """Body angles (minus root x/y) and all joint velocities."""
        data = self.sim.data
        return np.concatenate((data.qpos.flat[2:], data.qvel.flat))

    def reset_model(self):
        self.set_state(
            self.init_qpos
            + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq),
            self.init_qvel
            + self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nv),
        )
        return self._get_obs()

# --- gym/envs/mujoco/walker2d.py ----------------------------------------------
# (same file header imports as ant.py)


class Walker2dEnv(MuJocoPyEnv, utils.EzPickle):
    """Planar biped walker on the mujoco-py backend."""

    metadata = {
        "render_modes": ["human", "rgb_array", "depth_array"],
        "render_fps": 125,
    }

    def __init__(self, **kwargs):
        obs_space = Box(low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64)
        MuJocoPyEnv.__init__(
            self, "walker2d.xml", 4, observation_space=obs_space, **kwargs
        )
        utils.EzPickle.__init__(self, **kwargs)

    def step(self, a):
        """Forward-progress reward + alive bonus - small control cost."""
        x_before = self.sim.data.qpos[0]
        self.do_simulation(a, self.frame_skip)
        x_after, height, angle = self.sim.data.qpos[0:3]

        alive_bonus = 1.0
        reward = (x_after - x_before) / self.dt
        reward += alive_bonus
        reward -= 1e-3 * np.square(a).sum()
        # Healthy: torso height in (0.8, 2.0) and angle within (-1, 1) rad.
        terminated = not (0.8 < height < 2.0 and -1.0 < angle < 1.0)

        obs = self._get_obs()
        if self.render_mode == "human":
            self.render()

        return obs, reward, terminated, False, {}

    def _get_obs(self):
        """Positions (minus root x) and velocities clipped to [-10, 10]."""
        data = self.sim.data
        return np.concatenate((data.qpos[1:], np.clip(data.qvel, -10, 10))).ravel()

    def reset_model(self):
        self.set_state(
            self.init_qpos
            + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nq),
            self.init_qvel
            + self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv),
        )
        return self._get_obs()

    def viewer_setup(self):
        assert self.viewer is not None
        self.viewer.cam.trackbodyid = 2
        self.viewer.cam.distance = self.model.stat.extent * 0.5
        self.viewer.cam.lookat[2] = 1.15
        self.viewer.cam.elevation = -20

# --- gym/envs/toy_text/__init__.py --------------------------------------------
from gym.envs.toy_text.blackjack import BlackjackEnv
from gym.envs.toy_text.cliffwalking import CliffWalkingEnv
from gym.envs.toy_text.frozen_lake import FrozenLakeEnv
from gym.envs.toy_text.taxi import TaxiEnv

# --- binary assets (gym/envs/toy_text/font/Minecraft.ttf, img/C2.png, ...):
#     the dump replaced these with raw GitHub URLs; the listing continues past
#     this chunk. ---
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/C3.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/C4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/C4.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/C5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/C5.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/C6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/C6.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/C7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/C7.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/C8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/C8.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/C9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/C9.png 
-------------------------------------------------------------------------------- /gym/envs/toy_text/img/CA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/CA.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/CJ.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/CJ.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/CK.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/CK.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/CQ.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/CQ.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/CT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/CT.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/Card.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/Card.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/D2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/D2.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/D3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/D3.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/D4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/D4.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/D5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/D5.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/D6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/D6.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/D7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/D7.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/D8.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/D8.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/D9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/D9.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/DA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/DA.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/DJ.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/DJ.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/DK.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/DK.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/DQ.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/DQ.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/DT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/DT.png 
-------------------------------------------------------------------------------- /gym/envs/toy_text/img/H2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/H2.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/H3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/H3.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/H4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/H4.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/H5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/H5.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/H6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/H6.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/H7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/H7.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/H8.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/H8.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/H9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/H9.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/HA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/HA.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/HJ.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/HJ.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/HK.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/HK.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/HQ.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/HQ.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/HT.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/HT.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/S2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/S2.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/S3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/S3.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/S4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/S4.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/S5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/S5.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/S6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/S6.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/S7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/S7.png 
-------------------------------------------------------------------------------- /gym/envs/toy_text/img/S8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/S8.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/S9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/S9.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/SA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/SA.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/SJ.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/SJ.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/SK.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/SK.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/SQ.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/SQ.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/ST.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/ST.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/cab_front.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/cab_front.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/cab_left.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/cab_left.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/cab_rear.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/cab_rear.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/cab_right.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/cab_right.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/cookie.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/cookie.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/cracked_hole.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/cracked_hole.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/elf_down.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/elf_down.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/elf_left.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/elf_left.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/elf_right.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/elf_right.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/elf_up.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/elf_up.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/goal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/goal.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/gridworld_median_bottom.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/gridworld_median_bottom.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/gridworld_median_horiz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/gridworld_median_horiz.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/gridworld_median_left.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/gridworld_median_left.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/gridworld_median_right.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/gridworld_median_right.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/gridworld_median_top.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/gridworld_median_top.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/gridworld_median_vert.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/gridworld_median_vert.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/hole.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/hole.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/hotel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/hotel.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/ice.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/ice.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/mountain_bg1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/mountain_bg1.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/mountain_bg2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/mountain_bg2.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/mountain_cliff.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/mountain_cliff.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/mountain_near-cliff1.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/mountain_near-cliff1.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/mountain_near-cliff2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/mountain_near-cliff2.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/passenger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/passenger.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/stool.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/stool.png -------------------------------------------------------------------------------- /gym/envs/toy_text/img/taxi_background.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/taxi_background.png -------------------------------------------------------------------------------- /gym/envs/toy_text/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def categorical_sample(prob_n, np_random: np.random.Generator): 5 | """Sample from categorical distribution where each row specifies class probabilities.""" 6 | prob_n = np.asarray(prob_n) 7 | csprob_n = np.cumsum(prob_n) 8 | return 
np.argmax(csprob_n > np_random.random()) 9 | -------------------------------------------------------------------------------- /gym/logger.py: -------------------------------------------------------------------------------- 1 | """Set of functions for logging messages.""" 2 | import sys 3 | import warnings 4 | from typing import Optional, Type 5 | 6 | from gym.utils import colorize 7 | 8 | DEBUG = 10 9 | INFO = 20 10 | WARN = 30 11 | ERROR = 40 12 | DISABLED = 50 13 | 14 | min_level = 30 15 | 16 | 17 | # Ensure DeprecationWarning to be displayed (#2685, #3059) 18 | warnings.filterwarnings("once", "", DeprecationWarning, module=r"^gym\.") 19 | 20 | 21 | def set_level(level: int): 22 | """Set logging threshold on current logger.""" 23 | global min_level 24 | min_level = level 25 | 26 | 27 | def debug(msg: str, *args: object): 28 | """Logs a debug message to the user.""" 29 | if min_level <= DEBUG: 30 | print(f"DEBUG: {msg % args}", file=sys.stderr) 31 | 32 | 33 | def info(msg: str, *args: object): 34 | """Logs an info message to the user.""" 35 | if min_level <= INFO: 36 | print(f"INFO: {msg % args}", file=sys.stderr) 37 | 38 | 39 | def warn( 40 | msg: str, 41 | *args: object, 42 | category: Optional[Type[Warning]] = None, 43 | stacklevel: int = 1, 44 | ): 45 | """Raises a warning to the user if the min_level <= WARN. 
46 | 47 | Args: 48 | msg: The message to warn the user 49 | *args: Additional information to warn the user 50 | category: The category of warning 51 | stacklevel: The stack level to raise to 52 | """ 53 | if min_level <= WARN: 54 | warnings.warn( 55 | colorize(f"WARN: {msg % args}", "yellow"), 56 | category=category, 57 | stacklevel=stacklevel + 1, 58 | ) 59 | 60 | 61 | def deprecation(msg: str, *args: object): 62 | """Logs a deprecation warning to users.""" 63 | warn(msg, *args, category=DeprecationWarning, stacklevel=2) 64 | 65 | 66 | def error(msg: str, *args: object): 67 | """Logs an error message if min_level <= ERROR in red on the sys.stderr.""" 68 | if min_level <= ERROR: 69 | print(colorize(f"ERROR: {msg % args}", "red"), file=sys.stderr) 70 | 71 | 72 | # DEPRECATED: 73 | setLevel = set_level 74 | -------------------------------------------------------------------------------- /gym/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/py.typed -------------------------------------------------------------------------------- /gym/spaces/__init__.py: -------------------------------------------------------------------------------- 1 | """This module implements various spaces. 2 | 3 | Spaces describe mathematical sets and are used in Gym to specify valid actions and observations. 4 | Every Gym environment must have the attributes ``action_space`` and ``observation_space``. 
5 | If, for instance, three possible actions (0,1,2) can be performed in your environment and observations 6 | are vectors in the two-dimensional unit cube, the environment code may contain the following two lines:: 7 | 8 | self.action_space = spaces.Discrete(3) 9 | self.observation_space = spaces.Box(0, 1, shape=(2,)) 10 | """ 11 | from gym.spaces.box import Box 12 | from gym.spaces.dict import Dict 13 | from gym.spaces.discrete import Discrete 14 | from gym.spaces.graph import Graph, GraphInstance 15 | from gym.spaces.multi_binary import MultiBinary 16 | from gym.spaces.multi_discrete import MultiDiscrete 17 | from gym.spaces.sequence import Sequence 18 | from gym.spaces.space import Space 19 | from gym.spaces.text import Text 20 | from gym.spaces.tuple import Tuple 21 | from gym.spaces.utils import flatdim, flatten, flatten_space, unflatten 22 | 23 | __all__ = [ 24 | "Space", 25 | "Box", 26 | "Discrete", 27 | "Text", 28 | "Graph", 29 | "GraphInstance", 30 | "MultiDiscrete", 31 | "MultiBinary", 32 | "Tuple", 33 | "Sequence", 34 | "Dict", 35 | "flatdim", 36 | "flatten_space", 37 | "flatten", 38 | "unflatten", 39 | ] 40 | -------------------------------------------------------------------------------- /gym/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """A set of common utilities used within the environments. 2 | 3 | These are not intended as API functions, and will not remain stable over time. 4 | """ 5 | 6 | # These submodules should not have any import-time dependencies. 7 | # We want this since we use `utils` during our import-time sanity checks 8 | # that verify that our dependencies are actually present. 
9 | from gym.utils.colorize import colorize 10 | from gym.utils.ezpickle import EzPickle 11 | -------------------------------------------------------------------------------- /gym/utils/colorize.py: -------------------------------------------------------------------------------- 1 | """A set of common utilities used within the environments. 2 | 3 | These are not intended as API functions, and will not remain stable over time. 4 | """ 5 | 6 | color2num = dict( 7 | gray=30, 8 | red=31, 9 | green=32, 10 | yellow=33, 11 | blue=34, 12 | magenta=35, 13 | cyan=36, 14 | white=37, 15 | crimson=38, 16 | ) 17 | 18 | 19 | def colorize( 20 | string: str, color: str, bold: bool = False, highlight: bool = False 21 | ) -> str: 22 | """Returns string surrounded by appropriate terminal colour codes to print colourised text. 23 | 24 | Args: 25 | string: The message to colourise 26 | color: Literal values are gray, red, green, yellow, blue, magenta, cyan, white, crimson 27 | bold: If to bold the string 28 | highlight: If to highlight the string 29 | 30 | Returns: 31 | Colourised string 32 | """ 33 | attr = [] 34 | num = color2num[color] 35 | if highlight: 36 | num += 10 37 | attr.append(str(num)) 38 | if bold: 39 | attr.append("1") 40 | attrs = ";".join(attr) 41 | return f"\x1b[{attrs}m{string}\x1b[0m" 42 | -------------------------------------------------------------------------------- /gym/utils/ezpickle.py: -------------------------------------------------------------------------------- 1 | """Class for pickling and unpickling objects via their constructor arguments.""" 2 | 3 | 4 | class EzPickle: 5 | """Objects that are pickled and unpickled via their constructor arguments. 6 | 7 | Example:: 8 | 9 | >>> class Dog(Animal, EzPickle): 10 | ... def __init__(self, furcolor, tailkind="bushy"): 11 | ... Animal.__init__() 12 | ... 
EzPickle.__init__(furcolor, tailkind) 13 | 14 | When this object is unpickled, a new ``Dog`` will be constructed by passing the provided furcolor and tailkind into the constructor. 15 | However, philosophers are still not sure whether it is still the same dog. 16 | 17 | This is generally needed only for environments which wrap C/C++ code, such as MuJoCo and Atari. 18 | """ 19 | 20 | def __init__(self, *args, **kwargs): 21 | """Uses the ``args`` and ``kwargs`` from the object's constructor for pickling.""" 22 | self._ezpickle_args = args 23 | self._ezpickle_kwargs = kwargs 24 | 25 | def __getstate__(self): 26 | """Returns the object pickle state with args and kwargs.""" 27 | return { 28 | "_ezpickle_args": self._ezpickle_args, 29 | "_ezpickle_kwargs": self._ezpickle_kwargs, 30 | } 31 | 32 | def __setstate__(self, d): 33 | """Sets the object pickle state using d.""" 34 | out = type(self)(*d["_ezpickle_args"], **d["_ezpickle_kwargs"]) 35 | self.__dict__.update(out.__dict__) 36 | -------------------------------------------------------------------------------- /gym/utils/seeding.py: -------------------------------------------------------------------------------- 1 | """Set of random number generator functions: seeding, generator, hashing seeds.""" 2 | from typing import Any, Optional, Tuple 3 | 4 | import numpy as np 5 | 6 | from gym import error 7 | 8 | 9 | def np_random(seed: Optional[int] = None) -> Tuple[np.random.Generator, Any]: 10 | """Generates a random number generator from the seed and returns the Generator and seed. 
11 | 12 | Args: 13 | seed: The seed used to create the generator 14 | 15 | Returns: 16 | The generator and resulting seed 17 | 18 | Raises: 19 | Error: Seed must be a non-negative integer or omitted 20 | """ 21 | if seed is not None and not (isinstance(seed, int) and 0 <= seed): 22 | raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}") 23 | 24 | seed_seq = np.random.SeedSequence(seed) 25 | np_seed = seed_seq.entropy 26 | rng = RandomNumberGenerator(np.random.PCG64(seed_seq)) 27 | return rng, np_seed 28 | 29 | 30 | RNG = RandomNumberGenerator = np.random.Generator 31 | -------------------------------------------------------------------------------- /gym/vector/__init__.py: -------------------------------------------------------------------------------- 1 | """Module for vector environments.""" 2 | from typing import Iterable, List, Optional, Union 3 | 4 | import gym 5 | from gym.vector.async_vector_env import AsyncVectorEnv 6 | from gym.vector.sync_vector_env import SyncVectorEnv 7 | from gym.vector.vector_env import VectorEnv, VectorEnvWrapper 8 | 9 | __all__ = ["AsyncVectorEnv", "SyncVectorEnv", "VectorEnv", "VectorEnvWrapper", "make"] 10 | 11 | 12 | def make( 13 | id: str, 14 | num_envs: int = 1, 15 | asynchronous: bool = True, 16 | wrappers: Optional[Union[callable, List[callable]]] = None, 17 | disable_env_checker: Optional[bool] = None, 18 | **kwargs, 19 | ) -> VectorEnv: 20 | """Create a vectorized environment from multiple copies of an environment, from its id. 21 | 22 | Example:: 23 | 24 | >>> import gym 25 | >>> env = gym.vector.make('CartPole-v1', num_envs=3) 26 | >>> env.reset() 27 | array([[-0.04456399, 0.04653909, 0.01326909, -0.02099827], 28 | [ 0.03073904, 0.00145001, -0.03088818, -0.03131252], 29 | [ 0.03468829, 0.01500225, 0.01230312, 0.01825218]], 30 | dtype=float32) 31 | 32 | Args: 33 | id: The environment ID. This must be a valid ID from the registry. 34 | num_envs: Number of copies of the environment. 
35 | asynchronous: If `True`, wraps the environments in an :class:`AsyncVectorEnv` (which uses `multiprocessing`_ to run the environments in parallel). If ``False``, wraps the environments in a :class:`SyncVectorEnv`. 36 | wrappers: If not ``None``, then apply the wrappers to each internal environment during creation. 37 | disable_env_checker: If to run the env checker for the first environment only. None will default to the environment spec `disable_env_checker` parameter 38 | (that is by default False), otherwise will run according to this argument (True = not run, False = run) 39 | **kwargs: Keywords arguments applied during `gym.make` 40 | 41 | Returns: 42 | The vectorized environment. 43 | """ 44 | 45 | def create_env(env_num: int): 46 | """Creates an environment that can enable or disable the environment checker.""" 47 | # If the env_num > 0 then disable the environment checker otherwise use the parameter 48 | _disable_env_checker = True if env_num > 0 else disable_env_checker 49 | 50 | def _make_env(): 51 | env = gym.envs.registration.make( 52 | id, 53 | disable_env_checker=_disable_env_checker, 54 | **kwargs, 55 | ) 56 | if wrappers is not None: 57 | if callable(wrappers): 58 | env = wrappers(env) 59 | elif isinstance(wrappers, Iterable) and all( 60 | [callable(w) for w in wrappers] 61 | ): 62 | for wrapper in wrappers: 63 | env = wrapper(env) 64 | else: 65 | raise NotImplementedError 66 | return env 67 | 68 | return _make_env 69 | 70 | env_fns = [ 71 | create_env(disable_env_checker or env_num > 0) for env_num in range(num_envs) 72 | ] 73 | return AsyncVectorEnv(env_fns) if asynchronous else SyncVectorEnv(env_fns) 74 | -------------------------------------------------------------------------------- /gym/vector/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Module for gym vector utils.""" 2 | from gym.vector.utils.misc import CloudpickleWrapper, clear_mpi_env_vars 3 | from 
gym.vector.utils.numpy_utils import concatenate, create_empty_array 4 | from gym.vector.utils.shared_memory import ( 5 | create_shared_memory, 6 | read_from_shared_memory, 7 | write_to_shared_memory, 8 | ) 9 | from gym.vector.utils.spaces import _BaseGymSpaces # pyright: reportPrivateUsage=false 10 | from gym.vector.utils.spaces import BaseGymSpaces, batch_space, iterate 11 | 12 | __all__ = [ 13 | "CloudpickleWrapper", 14 | "clear_mpi_env_vars", 15 | "concatenate", 16 | "create_empty_array", 17 | "create_shared_memory", 18 | "read_from_shared_memory", 19 | "write_to_shared_memory", 20 | "BaseGymSpaces", 21 | "batch_space", 22 | "iterate", 23 | ] 24 | -------------------------------------------------------------------------------- /gym/vector/utils/misc.py: -------------------------------------------------------------------------------- 1 | """Miscellaneous utilities.""" 2 | import contextlib 3 | import os 4 | 5 | __all__ = ["CloudpickleWrapper", "clear_mpi_env_vars"] 6 | 7 | 8 | class CloudpickleWrapper: 9 | """Wrapper that uses cloudpickle to pickle and unpickle the result.""" 10 | 11 | def __init__(self, fn: callable): 12 | """Cloudpickle wrapper for a function.""" 13 | self.fn = fn 14 | 15 | def __getstate__(self): 16 | """Get the state using `cloudpickle.dumps(self.fn)`.""" 17 | import cloudpickle 18 | 19 | return cloudpickle.dumps(self.fn) 20 | 21 | def __setstate__(self, ob): 22 | """Sets the state with obs.""" 23 | import pickle 24 | 25 | self.fn = pickle.loads(ob) 26 | 27 | def __call__(self): 28 | """Calls the function `self.fn` with no arguments.""" 29 | return self.fn() 30 | 31 | 32 | @contextlib.contextmanager 33 | def clear_mpi_env_vars(): 34 | """Clears the MPI of environment variables. 35 | 36 | `from mpi4py import MPI` will call `MPI_Init` by default. 37 | If the child process has MPI environment variables, MPI will think that the child process 38 | is an MPI process just like the parent and do bad things such as hang. 
39 | 40 | This context manager is a hacky way to clear those environment variables 41 | temporarily such as when we are starting multiprocessing Processes. 42 | 43 | Yields: 44 | Yields for the context manager 45 | """ 46 | removed_environment = {} 47 | for k, v in list(os.environ.items()): 48 | for prefix in ["OMPI_", "PMI_"]: 49 | if k.startswith(prefix): 50 | removed_environment[k] = v 51 | del os.environ[k] 52 | try: 53 | yield 54 | finally: 55 | os.environ.update(removed_environment) 56 | -------------------------------------------------------------------------------- /gym/version.py: -------------------------------------------------------------------------------- 1 | VERSION = "0.26.2" 2 | -------------------------------------------------------------------------------- /gym/wrappers/README.md: -------------------------------------------------------------------------------- 1 | # Wrappers 2 | 3 | Wrappers are used to transform an environment in a modular way: 4 | 5 | ```python 6 | env = gym.make('Pong-v0') 7 | env = MyWrapper(env) 8 | ``` 9 | 10 | Note that we may later restructure any of the files in this directory, 11 | but will keep the wrappers available at the wrappers' top-level 12 | folder. 
So for example, you should access `MyWrapper` as follows: 13 | 14 | ```python 15 | from gym.wrappers import MyWrapper 16 | ``` 17 | 18 | ## Quick tips for writing your own wrapper 19 | 20 | - Don't forget to call `super(class_name, self).__init__(env)` if you override the wrapper's `__init__` function 21 | - You can access the inner environment with `self.unwrapped` 22 | - You can access the previous layer using `self.env` 23 | - The variables `metadata`, `action_space`, `observation_space`, `reward_range`, and `spec` are copied to `self` from the previous layer 24 | - Create a wrapped function for at least one of the following: `__init__(self, env)`, `step`, `reset`, `render`, `close`, or `seed` 25 | - Your layered function should take its input from the previous layer (`self.env`) and/or the inner layer (`self.unwrapped`) 26 | -------------------------------------------------------------------------------- /gym/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | """Module of wrapper classes.""" 2 | from gym import error 3 | from gym.wrappers.atari_preprocessing import AtariPreprocessing 4 | from gym.wrappers.autoreset import AutoResetWrapper 5 | from gym.wrappers.clip_action import ClipAction 6 | from gym.wrappers.filter_observation import FilterObservation 7 | from gym.wrappers.flatten_observation import FlattenObservation 8 | from gym.wrappers.frame_stack import FrameStack, LazyFrames 9 | from gym.wrappers.gray_scale_observation import GrayScaleObservation 10 | from gym.wrappers.human_rendering import HumanRendering 11 | from gym.wrappers.normalize import NormalizeObservation, NormalizeReward 12 | from gym.wrappers.order_enforcing import OrderEnforcing 13 | from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics 14 | from gym.wrappers.record_video import RecordVideo, capped_cubic_video_schedule 15 | from gym.wrappers.render_collection import RenderCollection 16 | from 
class AutoResetWrapper(gym.Wrapper):
    """Automatically resets the wrapped environment when an episode ends.

    When a call to :meth:`step` yields ``terminated=True`` or ``truncated=True``,
    :meth:`reset` is invoked on the inner environment and :meth:`step` returns
    ``(new_obs, final_reward, final_terminated, final_truncated, info)`` where:

    - ``new_obs`` is the first observation from the automatic :meth:`reset`,
    - ``final_reward``, ``final_terminated`` and ``final_truncated`` come from
      the step that ended the episode (they cannot both be ``False``),
    - ``info`` is the reset's info dict, augmented with the keys
      ``"final_observation"`` and ``"final_info"`` holding the observation and
      info returned by the episode-ending call to :meth:`self.env.step`.

    Warning: When using this wrapper to collect rollouts, the observation
    returned alongside the final reward belongs to the *new* episode. If you
    need the terminal state of the previous episode, retrieve it from
    ``info["final_observation"]``. Make sure you know what you're doing if you
    use this wrapper!
    """

    def __init__(self, env: gym.Env):
        """Wrap ``env`` with automatic reset behaviour.

        Args:
            env (gym.Env): The environment to apply the wrapper
        """
        super().__init__(env)

    def step(self, action):
        """Step the environment, resetting it if the episode has ended.

        Args:
            action: The action to take

        Returns:
            The autoreset environment :meth:`step`
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        if not (terminated or truncated):
            return obs, reward, terminated, truncated, info

        # Episode over: reset and stash the terminal observation/info under
        # reserved keys so the caller can still access them.
        reset_obs, reset_info = self.env.reset()
        assert (
            "final_observation" not in reset_info
        ), 'info dict cannot contain key "final_observation" '
        assert (
            "final_info" not in reset_info
        ), 'info dict cannot contain key "final_info" '

        reset_info["final_observation"] = obs
        reset_info["final_info"] = info

        return reset_obs, reward, terminated, truncated, reset_info
class ClipAction(ActionWrapper):
    """Clip the continuous action within the valid :class:`Box` action space bound.

    Example:
        >>> import gym
        >>> env = gym.make('BipedalWalker-v3')
        >>> env = ClipAction(env)
        >>> env.action_space
        Box(-1.0, 1.0, (4,), float32)
        >>> env.step(np.array([5.0, 2.0, -10.0, 0.0]))
        # Executes the action np.array([1.0, 1.0, -1.0, 0]) in the base environment
    """

    def __init__(self, env: gym.Env):
        """A wrapper for clipping continuous actions within the valid bound.

        Args:
            env: The environment to apply the wrapper
        """
        # Clipping is only meaningful for continuous (Box) action spaces.
        assert isinstance(env.action_space, Box)
        super().__init__(env)

    def action(self, action):
        """Clips the action within the valid bounds.

        Args:
            action: The action to clip

        Returns:
            The clipped action
        """
        return np.clip(action, self.action_space.low, self.action_space.high)
class PassiveEnvChecker(gym.Wrapper):
    """A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gym API."""

    def __init__(self, env):
        """Initialises the wrapper with the environments, run the observation and action space tests."""
        super().__init__(env)

        assert hasattr(
            env, "action_space"
        ), "The environment must specify an action space. https://www.gymlibrary.dev/content/environment_creation/"
        check_action_space(env.action_space)
        assert hasattr(
            env, "observation_space"
        ), "The environment must specify an observation space. https://www.gymlibrary.dev/content/environment_creation/"
        check_observation_space(env.observation_space)

        # Each passive check runs exactly once, on the first call of the
        # corresponding function; later calls go straight to the inner env.
        self.checked_reset = False
        self.checked_step = False
        self.checked_render = False

    def step(self, action: ActType):
        """Steps through the environment that on the first call will run the `passive_env_step_check`."""
        if self.checked_step:
            return self.env.step(action)
        self.checked_step = True
        return env_step_passive_checker(self.env, action)

    def reset(self, **kwargs):
        """Resets the environment that on the first call will run the `passive_env_reset_check`."""
        if self.checked_reset:
            return self.env.reset(**kwargs)
        self.checked_reset = True
        return env_reset_passive_checker(self.env, **kwargs)

    def render(self, *args, **kwargs):
        """Renders the environment that on the first call will run the `passive_env_render_check`."""
        if self.checked_render:
            return self.env.render(*args, **kwargs)
        self.checked_render = True
        return env_render_passive_checker(self.env, *args, **kwargs)
class FilterObservation(gym.ObservationWrapper):
    """Filter Dict observation space by the keys.

    Example:
        >>> import gym
        >>> env = gym.wrappers.TransformObservation(
        ...     gym.make('CartPole-v1'), lambda obs: {'obs': obs, 'time': 0}
        ... )
        >>> env.observation_space = gym.spaces.Dict(obs=env.observation_space, time=gym.spaces.Discrete(1))
        >>> env.reset()
        {'obs': array([-0.00067088, -0.01860439,  0.04772898, -0.01911527], dtype=float32), 'time': 0}
        >>> env = FilterObservation(env, filter_keys=['obs'])
        >>> env.reset()
        {'obs': array([ 0.04560107,  0.04466959, -0.0328232 , -0.02367178], dtype=float32)}
        >>> env.step(0)
        ({'obs': array([ 0.04649447, -0.14996664, -0.03329664,  0.25847703], dtype=float32)}, 1.0, False, {})
    """

    def __init__(self, env: gym.Env, filter_keys: Sequence[str] = None):
        """A wrapper that filters dictionary observations by their keys.

        Args:
            env: The environment to apply the wrapper
            filter_keys: List of keys to be included in the observations. If ``None``, observations will not be filtered and this wrapper has no effect

        Raises:
            ValueError: If the environment's observation space is not :class:`spaces.Dict`
            ValueError: If any of the `filter_keys` are not included in the original `env`'s observation space
        """
        super().__init__(env)

        wrapped_observation_space = env.observation_space
        if not isinstance(wrapped_observation_space, spaces.Dict):
            raise ValueError(
                f"FilterObservationWrapper is only usable with dict observations, "
                f"environment observation space is {type(wrapped_observation_space)}"
            )

        observation_keys = wrapped_observation_space.spaces.keys()
        if filter_keys is None:
            # No filter supplied: keep every key, making the wrapper a no-op.
            filter_keys = tuple(observation_keys)

        missing_keys = {key for key in filter_keys if key not in observation_keys}
        if missing_keys:
            raise ValueError(
                "All the filter_keys must be included in the original observation space.\n"
                f"Filter keys: {filter_keys}\n"
                f"Observation keys: {observation_keys}\n"
                f"Missing keys: {missing_keys}"
            )

        # Rebuild the Dict space from deep copies of the retained sub-spaces so
        # later mutations of the wrapper's space cannot leak into the inner env.
        self.observation_space = type(wrapped_observation_space)(
            [
                (name, copy.deepcopy(space))
                for name, space in wrapped_observation_space.spaces.items()
                if name in filter_keys
            ]
        )

        self._env = env
        self._filter_keys = tuple(filter_keys)

    def observation(self, observation):
        """Filters the observations.

        Args:
            observation: The observation to filter

        Returns:
            The filtered observations
        """
        return self._filter_observation(observation)

    def _filter_observation(self, observation):
        # Preserve the observation's mapping type (e.g. OrderedDict) when
        # rebuilding it with only the retained keys.
        return type(observation)(
            [
                (name, value)
                for name, value in observation.items()
                if name in self._filter_keys
            ]
        )
class FlattenObservation(gym.ObservationWrapper):
    """Observation wrapper that flattens the observation into a 1-D array.

    Example:
        >>> import gym
        >>> env = gym.make('CarRacing-v1')
        >>> env.observation_space.shape
        (96, 96, 3)
        >>> env = FlattenObservation(env)
        >>> env.observation_space.shape
        (27648,)
        >>> obs = env.reset()
        >>> obs.shape
        (27648,)
    """

    def __init__(self, env: gym.Env):
        """Flattens the observations of an environment.

        Args:
            env: The environment to apply the wrapper
        """
        super().__init__(env)
        # Derive the flat space once from the wrapped env's observation space.
        self.observation_space = spaces.flatten_space(env.observation_space)

    def observation(self, observation):
        """Flattens an observation.

        Args:
            observation: The observation to flatten

        Returns:
            The flattened observation
        """
        # Flatten against the *inner* env's space, which describes `observation`.
        return spaces.flatten(self.env.observation_space, observation)
30 | """ 31 | super().__init__(env) 32 | self.keep_dim = keep_dim 33 | 34 | assert ( 35 | isinstance(self.observation_space, Box) 36 | and len(self.observation_space.shape) == 3 37 | and self.observation_space.shape[-1] == 3 38 | ) 39 | 40 | obs_shape = self.observation_space.shape[:2] 41 | if self.keep_dim: 42 | self.observation_space = Box( 43 | low=0, high=255, shape=(obs_shape[0], obs_shape[1], 1), dtype=np.uint8 44 | ) 45 | else: 46 | self.observation_space = Box( 47 | low=0, high=255, shape=obs_shape, dtype=np.uint8 48 | ) 49 | 50 | def observation(self, observation): 51 | """Converts the colour observation to greyscale. 52 | 53 | Args: 54 | observation: Color observations 55 | 56 | Returns: 57 | Grayscale observations 58 | """ 59 | import cv2 60 | 61 | observation = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY) 62 | if self.keep_dim: 63 | observation = np.expand_dims(observation, -1) 64 | return observation 65 | -------------------------------------------------------------------------------- /gym/wrappers/monitoring/__init__.py: -------------------------------------------------------------------------------- 1 | """Module for monitoring.video_recorder.""" 2 | -------------------------------------------------------------------------------- /gym/wrappers/order_enforcing.py: -------------------------------------------------------------------------------- 1 | """Wrapper to enforce the proper ordering of environment operations.""" 2 | import gym 3 | from gym.error import ResetNeeded 4 | 5 | 6 | class OrderEnforcing(gym.Wrapper): 7 | """A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`. 
class OrderEnforcing(gym.Wrapper):
    """A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.

    Example:
        >>> from gym.envs.classic_control import CartPoleEnv
        >>> env = CartPoleEnv()
        >>> env = OrderEnforcing(env)
        >>> env.step(0)
        ResetNeeded: Cannot call env.step() before calling env.reset()
        >>> env.render()
        ResetNeeded: Cannot call env.render() before calling env.reset()
        >>> env.reset()
        >>> env.render()
        >>> env.step(0)
    """

    def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False):
        """A wrapper that will produce an error if :meth:`step` is called before an initial :meth:`reset`.

        Args:
            env: The environment to wrap
            disable_render_order_enforcing: If to disable render order enforcing
        """
        super().__init__(env)
        self._has_reset: bool = False
        self._disable_render_order_enforcing: bool = disable_render_order_enforcing

    def step(self, action):
        """Steps through the environment with `kwargs`."""
        if not self._has_reset:
            raise ResetNeeded("Cannot call env.step() before calling env.reset()")
        return self.env.step(action)

    def reset(self, **kwargs):
        """Resets the environment with `kwargs`."""
        self._has_reset = True
        return self.env.reset(**kwargs)

    def render(self, *args, **kwargs):
        """Renders the environment with `kwargs`."""
        if not self._disable_render_order_enforcing and not self._has_reset:
            # Message grammar fixed: "an intended action" (was "a intended").
            raise ResetNeeded(
                "Cannot call `env.render()` before calling `env.reset()`, if this is an intended action, "
                "set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper."
            )
        return self.env.render(*args, **kwargs)

    @property
    def has_reset(self):
        """Returns if the environment has been reset before."""
        return self._has_reset
17 | """ 18 | super().__init__(env) 19 | assert env.render_mode is not None 20 | assert not env.render_mode.endswith("_list") 21 | self.frame_list = [] 22 | self.reset_clean = reset_clean 23 | self.pop_frames = pop_frames 24 | 25 | @property 26 | def render_mode(self): 27 | """Returns the collection render_mode name.""" 28 | return f"{self.env.render_mode}_list" 29 | 30 | def step(self, *args, **kwargs): 31 | """Perform a step in the base environment and collect a frame.""" 32 | output = self.env.step(*args, **kwargs) 33 | self.frame_list.append(self.env.render()) 34 | return output 35 | 36 | def reset(self, *args, **kwargs): 37 | """Reset the base environment, eventually clear the frame_list, and collect a frame.""" 38 | result = self.env.reset(*args, **kwargs) 39 | 40 | if self.reset_clean: 41 | self.frame_list = [] 42 | self.frame_list.append(self.env.render()) 43 | 44 | return result 45 | 46 | def render(self): 47 | """Returns the collection of frames and, if pop_frames = True, clears it.""" 48 | frames = self.frame_list 49 | if self.pop_frames: 50 | self.frame_list = [] 51 | 52 | return frames 53 | -------------------------------------------------------------------------------- /gym/wrappers/rescale_action.py: -------------------------------------------------------------------------------- 1 | """Wrapper for rescaling actions to within a max and min action.""" 2 | from typing import Union 3 | 4 | import numpy as np 5 | 6 | import gym 7 | from gym import spaces 8 | 9 | 10 | class RescaleAction(gym.ActionWrapper): 11 | """Affinely rescales the continuous action space of the environment to the range [min_action, max_action]. 12 | 13 | The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action` 14 | or :attr:`max_action` are numpy arrays, the shape must match the shape of the environment's action space. 
class RescaleAction(gym.ActionWrapper):
    """Affinely rescales the continuous action space of the environment to the range [min_action, max_action].

    The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
    or :attr:`max_action` are numpy arrays, the shape must match the shape of the environment's action space.

    Example:
        >>> import gym
        >>> env = gym.make('BipedalWalker-v3')
        >>> env.action_space
        Box(-1.0, 1.0, (4,), float32)
        >>> min_action = -0.5
        >>> max_action = np.array([0.0, 0.5, 1.0, 0.75])
        >>> env = RescaleAction(env, min_action=min_action, max_action=max_action)
        >>> env.action_space
        Box(-0.5, [0.  0.5 1.  0.75], (4,), float32)
        >>> RescaleAction(env, min_action, max_action).action_space == gym.spaces.Box(min_action, max_action)
        True
    """

    def __init__(
        self,
        env: gym.Env,
        min_action: Union[float, int, np.ndarray],
        max_action: Union[float, int, np.ndarray],
    ):
        """Initializes the :class:`RescaleAction` wrapper.

        Args:
            env (Env): The environment to apply the wrapper
            min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
            max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
        """
        assert isinstance(
            env.action_space, spaces.Box
        ), f"expected Box action space, got {type(env.action_space)}"
        assert np.less_equal(min_action, max_action).all(), (min_action, max_action)

        super().__init__(env)
        # Broadcast scalar bounds up to the full action shape in the space's dtype.
        self.min_action = (
            np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + min_action
        )
        self.max_action = (
            np.zeros(env.action_space.shape, dtype=env.action_space.dtype) + max_action
        )
        self.action_space = spaces.Box(
            low=min_action,
            high=max_action,
            shape=env.action_space.shape,
            dtype=env.action_space.dtype,
        )

    def action(self, action):
        """Rescales the action affinely from [:attr:`min_action`, :attr:`max_action`] to the action space of the base environment, :attr:`env`.

        Args:
            action: The action to rescale

        Returns:
            The rescaled action
        """
        assert np.all(np.greater_equal(action, self.min_action)), (
            action,
            self.min_action,
        )
        assert np.all(np.less_equal(action, self.max_action)), (action, self.max_action)
        inner_low = self.env.action_space.low
        inner_high = self.env.action_space.high
        # Affine map of [min_action, max_action] onto [inner_low, inner_high],
        # then clip to guard against floating-point overshoot.
        fraction = (action - self.min_action) / (self.max_action - self.min_action)
        rescaled = inner_low + (inner_high - inner_low) * fraction
        return np.clip(rescaled, inner_low, inner_high)
class ResizeObservation(gym.ObservationWrapper):
    """Resize the image observation.

    This wrapper works on environments with image observations (or more generally observations of shape AxBxC) and resizes
    the observation to the shape given by the 2-tuple :attr:`shape`. The argument :attr:`shape` may also be an integer.
    In that case, the observation is scaled to a square of side-length :attr:`shape`.

    Example:
        >>> import gym
        >>> env = gym.make('CarRacing-v1')
        >>> env.observation_space.shape
        (96, 96, 3)
        >>> env = ResizeObservation(env, 64)
        >>> env.observation_space.shape
        (64, 64, 3)
    """

    def __init__(self, env: gym.Env, shape: Union[tuple, int]):
        """Resizes image observations to shape given by :attr:`shape`.

        Args:
            env: The environment to apply the wrapper
            shape: The shape of the resized observations
        """
        super().__init__(env)
        if isinstance(shape, int):
            # An int means a square target of side-length `shape`.
            shape = (shape, shape)
        assert all(x > 0 for x in shape), shape

        self.shape = tuple(shape)

        assert isinstance(
            env.observation_space, Box
        ), f"Expected the observation space to be Box, actual type: {type(env.observation_space)}"
        # Keep any trailing channel dimension from the original space.
        obs_shape = self.shape + env.observation_space.shape[2:]
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

    def observation(self, observation):
        """Updates the observations by resizing the observation to shape given by :attr:`shape`.

        Args:
            observation: The observation to reshape

        Returns:
            The reshaped observations

        Raises:
            DependencyNotInstalled: opencv-python is not installed
        """
        try:
            import cv2
        except ImportError:
            # Typo fixed in the message: "installed" (was "install").
            raise DependencyNotInstalled(
                "opencv is not installed, run `pip install gym[other]`"
            )

        # cv2.resize expects (width, height), i.e. the reverse of our (H, W) shape.
        observation = cv2.resize(
            observation, self.shape[::-1], interpolation=cv2.INTER_AREA
        )
        if observation.ndim == 2:
            # cv2 drops the channel axis for single-channel images; restore it.
            observation = np.expand_dims(observation, -1)
        return observation
class StepAPICompatibility(gym.Wrapper):
    r"""A wrapper which can transform an environment from new step API to old and vice-versa.

    Old step API refers to step() method returning (observation, reward, done, info)
    New step API refers to step() method returning (observation, reward, terminated, truncated, info)
    (Refer to docs for details on the API change)

    Args:
        env (gym.Env): the env to wrap. Can be in old or new API
        output_truncation_bool (bool): Whether the wrapper's step method outputs two booleans (new API, default) or one boolean (old API)

    Examples:
        >>> env = gym.make("CartPole-v1")  # wrapper not applied by default, set to new API
        >>> env = gym.make("CartPole-v1", apply_api_compatibility=True)  # set to old API
        >>> env = StepAPICompatibility(CustomEnv(), output_truncation_bool=False)  # manually using wrapper on unregistered envs
    """

    def __init__(self, env: gym.Env, output_truncation_bool: bool = True):
        """A wrapper which can transform an environment from new step API to old and vice-versa.

        Args:
            env (gym.Env): the env to wrap. Can be in old or new API
            output_truncation_bool (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
        """
        super().__init__(env)
        self.output_truncation_bool = output_truncation_bool
        if not self.output_truncation_bool:
            deprecation(
                "Initializing environment in old step API which returns one bool instead of two."
            )

    def step(self, action):
        """Steps through the environment, returning 5 or 4 items depending on `output_truncation_bool`.

        Args:
            action: action to step through the environment with

        Returns:
            (observation, reward, terminated, truncated, info) or (observation, reward, done, info)
        """
        step_returns = self.env.step(action)
        # Pick the converter for the requested API and apply it.
        if self.output_truncation_bool:
            converter = convert_to_terminated_truncated_step_api
        else:
            converter = convert_to_done_step_api
        return converter(step_returns)
class TimeAwareObservation(gym.ObservationWrapper):
    """Augment the observation with the current time step in the episode.

    The observation space of the wrapped environment is assumed to be a flat :class:`Box`.
    In particular, pixel observations are not supported. This wrapper will append the current timestep within the current episode to the observation.

    Example:
        >>> import gym
        >>> env = gym.make('CartPole-v1')
        >>> env = TimeAwareObservation(env)
        >>> env.reset()
        array([ 0.03810719,  0.03522411,  0.02231044, -0.01088205,  0.        ])
        >>> env.step(env.action_space.sample())[0]
        array([ 0.03881167, -0.16021058,  0.0220928 ,  0.28875574,  1.        ])
    """

    def __init__(self, env: gym.Env):
        """Initialize :class:`TimeAwareObservation` that requires an environment with a flat :class:`Box` observation space.

        Args:
            env: The environment to apply the wrapper
        """
        super().__init__(env)
        assert isinstance(env.observation_space, Box)
        assert env.observation_space.dtype == np.float32
        # Extend the box by one dimension for the step counter, which lives in [0, inf).
        low = np.append(self.observation_space.low, 0.0)
        high = np.append(self.observation_space.high, np.inf)
        self.observation_space = Box(low, high, dtype=np.float32)
        self.is_vector_env = getattr(env, "is_vector_env", False)

    def observation(self, observation):
        """Adds to the observation with the current time step.

        Args:
            observation: The observation to add the time step to

        Returns:
            The observation with the time step appended to
        """
        return np.append(observation, self.t)

    def step(self, action):
        """Steps through the environment, incrementing the time step.

        Args:
            action: The action to take

        Returns:
            The environment's step using the action.
        """
        self.t += 1
        return super().step(action)

    def reset(self, **kwargs):
        """Reset the environment setting the time to zero.

        Args:
            **kwargs: Kwargs to apply to env.reset()

        Returns:
            The reset environment
        """
        self.t = 0
        return super().reset(**kwargs)
12 | 13 | Example: 14 | >>> from gym.envs.classic_control import CartPoleEnv 15 | >>> from gym.wrappers import TimeLimit 16 | >>> env = CartPoleEnv() 17 | >>> env = TimeLimit(env, max_episode_steps=1000) 18 | """ 19 | 20 | def __init__( 21 | self, 22 | env: gym.Env, 23 | max_episode_steps: Optional[int] = None, 24 | ): 25 | """Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur. 26 | 27 | Args: 28 | env: The environment to apply the wrapper 29 | max_episode_steps: An optional max episode steps (if ``Ǹone``, ``env.spec.max_episode_steps`` is used) 30 | """ 31 | super().__init__(env) 32 | if max_episode_steps is None and self.env.spec is not None: 33 | max_episode_steps = env.spec.max_episode_steps 34 | if self.env.spec is not None: 35 | self.env.spec.max_episode_steps = max_episode_steps 36 | self._max_episode_steps = max_episode_steps 37 | self._elapsed_steps = None 38 | 39 | def step(self, action): 40 | """Steps through the environment and if the number of steps elapsed exceeds ``max_episode_steps`` then truncate. 41 | 42 | Args: 43 | action: The environment step action 44 | 45 | Returns: 46 | The environment step ``(observation, reward, terminated, truncated, info)`` with `truncated=True` 47 | if the number of steps elapsed >= max episode steps 48 | 49 | """ 50 | observation, reward, terminated, truncated, info = self.env.step(action) 51 | self._elapsed_steps += 1 52 | 53 | if self._elapsed_steps >= self._max_episode_steps: 54 | truncated = True 55 | 56 | return observation, reward, terminated, truncated, info 57 | 58 | def reset(self, **kwargs): 59 | """Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero. 
    def reset(self, **kwargs):
        """Resets the environment with ``**kwargs`` and sets the number of steps elapsed to zero.

        Args:
            **kwargs: The kwargs to reset the environment with

        Returns:
            The result of resetting the wrapped environment
        """
        self._elapsed_steps = 0
        return self.env.reset(**kwargs)
    def __init__(self, env: gym.Env, f: Callable[[float], float]):
        """Initialize the :class:`TransformReward` wrapper with an environment and reward transform function :attr:`f`.

        Args:
            env: The environment to apply the wrapper
            f: A function that transforms the reward

        Raises:
            AssertionError: If ``f`` is not callable
        """
        super().__init__(env)
        assert callable(f)
        self.f = f
37 | 38 | Args: 39 | reward: The reward to transform 40 | 41 | Returns: 42 | The transformed reward 43 | """ 44 | return self.f(reward) 45 | -------------------------------------------------------------------------------- /py.Dockerfile: -------------------------------------------------------------------------------- 1 | # A Dockerfile that sets up a full Gym install with test dependencies 2 | ARG PYTHON_VERSION 3 | FROM python:$PYTHON_VERSION 4 | 5 | SHELL ["/bin/bash", "-o", "pipefail", "-c"] 6 | 7 | RUN apt-get -y update \ 8 | && apt-get install --no-install-recommends -y \ 9 | unzip \ 10 | libglu1-mesa-dev \ 11 | libgl1-mesa-dev \ 12 | libosmesa6-dev \ 13 | xvfb \ 14 | patchelf \ 15 | ffmpeg cmake \ 16 | && apt-get autoremove -y \ 17 | && apt-get clean \ 18 | && rm -rf /var/lib/apt/lists/* \ 19 | # Download mujoco 20 | && mkdir /root/.mujoco \ 21 | && cd /root/.mujoco \ 22 | && wget -qO- 'https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz' | tar -xzvf - 23 | 24 | ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/root/.mujoco/mujoco210/bin" 25 | 26 | COPY . 
# For rules set to "warning" or "error" severity, pyright reports a diagnostic when the rule is violated
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.18.0 2 | cloudpickle>=1.2.0 3 | importlib_metadata>=4.8.0; python_version < '3.10' 4 | gym_notices>=0.0.4 5 | dataclasses==0.8; python_version == '3.6' 6 | typing_extensions==4.3.0; python_version == '3.7' 7 | opencv-python>=3.0 8 | lz4>=3.1.0 9 | matplotlib>=3.0 10 | box2d-py==2.3.5 11 | pygame==2.1.0 12 | ale-py~=0.8.0 13 | mujoco==2.2.0 14 | mujoco_py<2.2,>=2.1 15 | imageio>=2.14.1 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Setups the project.""" 2 | import itertools 3 | import re 4 | 5 | from setuptools import find_packages, setup 6 | 7 | with open("gym/version.py") as file: 8 | full_version = file.read() 9 | assert ( 10 | re.match(r'VERSION = "\d\.\d+\.\d+"\n', full_version).group(0) == full_version 11 | ), f"Unexpected version: {full_version}" 12 | VERSION = re.search(r"\d\.\d+\.\d+", full_version).group(0) 13 | 14 | # Environment-specific dependencies. 15 | extras = { 16 | "atari": ["ale-py~=0.8.0"], 17 | "accept-rom-license": ["autorom[accept-rom-license]~=0.4.2"], 18 | "box2d": ["box2d-py==2.3.5", "pygame==2.1.0", "swig==4.*"], 19 | "classic_control": ["pygame==2.1.0"], 20 | "mujoco_py": ["mujoco_py<2.2,>=2.1"], 21 | "mujoco": ["mujoco==2.2", "imageio>=2.14.1"], 22 | "toy_text": ["pygame==2.1.0"], 23 | "other": ["lz4>=3.1.0", "opencv-python>=3.0", "matplotlib>=3.0", "moviepy>=1.0.0"], 24 | } 25 | 26 | # Testing dependency groups. 
27 | testing_group = set(extras.keys()) - {"accept-rom-license", "atari"} 28 | extras["testing"] = list( 29 | set(itertools.chain.from_iterable(map(lambda group: extras[group], testing_group))) 30 | ) + ["pytest==7.0.1"] 31 | 32 | # All dependency groups - accept rom license as requires user to run 33 | all_groups = set(extras.keys()) - {"accept-rom-license"} 34 | extras["all"] = list( 35 | set(itertools.chain.from_iterable(map(lambda group: extras[group], all_groups))) 36 | ) 37 | 38 | # Uses the readme as the description on PyPI 39 | with open("README.md") as fh: 40 | long_description = "" 41 | header_count = 0 42 | for line in fh: 43 | if line.startswith("##"): 44 | header_count += 1 45 | if header_count < 2: 46 | long_description += line 47 | else: 48 | break 49 | 50 | setup( 51 | author="Gym Community", 52 | author_email="jkterry@umd.edu", 53 | classifiers=[ 54 | # Python 3.6 is minimally supported (only with basic gym environments and API) 55 | "Programming Language :: Python :: 3", 56 | "Programming Language :: Python :: 3.6", 57 | "Programming Language :: Python :: 3.7", 58 | "Programming Language :: Python :: 3.8", 59 | "Programming Language :: Python :: 3.9", 60 | "Programming Language :: Python :: 3.10", 61 | ], 62 | description="Gym: A universal API for reinforcement learning environments", 63 | extras_require=extras, 64 | install_requires=[ 65 | "numpy >= 1.18.0", 66 | "cloudpickle >= 1.2.0", 67 | "importlib_metadata >= 4.8.0; python_version < '3.10'", 68 | "gym_notices >= 0.0.4", 69 | "dataclasses == 0.8; python_version == '3.6'", 70 | ], 71 | license="MIT", 72 | long_description=long_description, 73 | long_description_content_type="text/markdown", 74 | name="gym", 75 | packages=[package for package in find_packages() if package.startswith("gym")], 76 | package_data={ 77 | "gym": [ 78 | "envs/mujoco/assets/*.xml", 79 | "envs/classic_control/assets/*.png", 80 | "envs/toy_text/font/*.ttf", 81 | "envs/toy_text/img/*.png", 82 | "py.typed", 83 | ] 84 | }, 
def test_spec():
    """Tests that `gym.spec` returns the registry's spec object for a valid environment id."""
    spec = gym.spec("CartPole-v1")
    assert spec.id == "CartPole-v1"
    assert spec is gym.envs.registry["CartPole-v1"]
def test_spec_versioned_lookups():
    """Tests spec lookups of missing and deprecated versions for a namespaced environment."""
    gym.register("test/Test2-v5", "no-entry-point")

    # A version that was never registered raises VersionNotFound.
    with pytest.raises(
        gym.error.VersionNotFound,
        match=re.escape(
            "Environment version `v9` for environment `test/Test2` doesn't exist. It provides versioned environments: [ `v5` ]."
        ),
    ):
        gym.spec("test/Test2-v9")

    # A version older than the latest registered one raises DeprecatedEnv.
    with pytest.raises(
        gym.error.DeprecatedEnv,
        match=re.escape(
            "Environment version v4 for `test/Test2` is deprecated. Please use `test/Test2-v5` instead."
        ),
    ):
        gym.spec("test/Test2-v4")

    assert gym.spec("test/Test2-v5") is not None
def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]:
    """Tries to make the environment showing if it is possible.

    Warning the environments have no wrappers, including time limit and order enforcing.
    """
    # Only instantiate environments whose entry point lives inside gym itself;
    # externally registered environments are skipped to avoid test pollution.
    if "gym.envs." not in env_spec.entry_point:
        return None
    try:
        return env_spec.make(disable_env_checker=True).unwrapped
    except (ImportError, error.DependencyNotInstalled) as e:
        logger.warn(f"Not testing {env_spec.id} due to error: {e}")
        return None
"toy_text"] 48 | ) 49 | ] 50 | # TODO, add minimum testing env spec in testing 51 | minimum_testing_env_specs = [ 52 | env_spec 53 | for env_spec in [ 54 | "CartPole-v1", 55 | "MountainCarContinuous-v0", 56 | "LunarLander-v2", 57 | "LunarLanderContinuous-v2", 58 | "CarRacing-v2", 59 | "Blackjack-v1", 60 | "Reacher-v4", 61 | ] 62 | if env_spec in all_testing_env_specs 63 | ] 64 | 65 | 66 | def assert_equals(a, b, prefix=None): 67 | """Assert equality of data structures `a` and `b`. 68 | 69 | Args: 70 | a: first data structure 71 | b: second data structure 72 | prefix: prefix for failed assertion message for types and dicts 73 | """ 74 | assert type(a) == type(b), f"{prefix}Differing types: {a} and {b}" 75 | if isinstance(a, dict): 76 | assert list(a.keys()) == list(b.keys()), f"{prefix}Key sets differ: {a} and {b}" 77 | 78 | for k in a.keys(): 79 | v_a = a[k] 80 | v_b = b[k] 81 | assert_equals(v_a, v_b) 82 | elif isinstance(a, np.ndarray): 83 | np.testing.assert_array_equal(a, b) 84 | elif isinstance(a, tuple): 85 | for elem_from_a, elem_from_b in zip(a, b): 86 | assert_equals(elem_from_a, elem_from_b) 87 | else: 88 | assert a == b 89 | -------------------------------------------------------------------------------- /tests/envs/utils_envs.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | 4 | class RegisterDuringMakeEnv(gym.Env): 5 | """Used in `test_registration.py` to check if `env.make` can import and register an env""" 6 | 7 | def __init__(self): 8 | self.action_space = gym.spaces.Discrete(1) 9 | self.observation_space = gym.spaces.Discrete(1) 10 | 11 | 12 | class ArgumentEnv(gym.Env): 13 | observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,)) 14 | action_space = gym.spaces.Box(low=-1, high=1, shape=(1,)) 15 | 16 | def __init__(self, arg1, arg2, arg3): 17 | self.arg1 = arg1 18 | self.arg2 = arg2 19 | self.arg3 = arg3 20 | 21 | 22 | # Environments to test render_mode 23 | class NoHuman(gym.Env): 24 
| """Environment that does not have human-rendering.""" 25 | 26 | metadata = {"render_modes": ["rgb_array_list"], "render_fps": 4} 27 | 28 | def __init__(self, render_mode=None): 29 | assert render_mode in self.metadata["render_modes"] 30 | self.render_mode = render_mode 31 | 32 | 33 | class NoHumanOldAPI(gym.Env): 34 | """Environment that does not have human-rendering.""" 35 | 36 | metadata = {"render_modes": ["rgb_array_list"], "render_fps": 4} 37 | 38 | def __init__(self): 39 | pass 40 | 41 | 42 | class NoHumanNoRGB(gym.Env): 43 | """Environment that has neither human- nor rgb-rendering""" 44 | 45 | metadata = {"render_modes": ["ascii"], "render_fps": 4} 46 | 47 | def __init__(self, render_mode=None): 48 | assert render_mode in self.metadata["render_modes"] 49 | self.render_mode = render_mode 50 | -------------------------------------------------------------------------------- /tests/spaces/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/tests/spaces/__init__.py -------------------------------------------------------------------------------- /tests/spaces/test_discrete.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from gym.spaces import Discrete 4 | 5 | 6 | def test_space_legacy_pickling(): 7 | """Test the legacy pickle of Discrete that is missing the `start` parameter.""" 8 | legacy_state = { 9 | "shape": ( 10 | 1, 11 | 2, 12 | 3, 13 | ), 14 | "dtype": np.int64, 15 | "np_random": np.random.default_rng(), 16 | "n": 3, 17 | } 18 | space = Discrete(1) 19 | space.__setstate__(legacy_state) 20 | 21 | assert space.shape == legacy_state["shape"] 22 | assert space.np_random == legacy_state["np_random"] 23 | assert space.n == 3 24 | assert space.dtype == legacy_state["dtype"] 25 | 26 | # Test that start is missing 27 | assert "start" in space.__dict__ 28 | del 
space.__dict__["start"] # legacy did not include start param 29 | assert "start" not in space.__dict__ 30 | 31 | space.__setstate__(legacy_state) 32 | assert space.start == 0 33 | 34 | 35 | def test_sample_mask(): 36 | space = Discrete(4, start=2) 37 | assert 2 <= space.sample() < 6 38 | assert space.sample(mask=np.array([0, 1, 0, 0], dtype=np.int8)) == 3 39 | assert space.sample(mask=np.array([0, 0, 0, 0], dtype=np.int8)) == 2 40 | assert space.sample(mask=np.array([0, 1, 0, 1], dtype=np.int8)) in [3, 5] 41 | -------------------------------------------------------------------------------- /tests/spaces/test_multibinary.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from gym.spaces import MultiBinary 4 | 5 | 6 | def test_sample(): 7 | space = MultiBinary(4) 8 | 9 | sample = space.sample(mask=np.array([0, 0, 1, 1], dtype=np.int8)) 10 | assert np.all(sample == [0, 0, 1, 1]) 11 | 12 | sample = space.sample(mask=np.array([0, 1, 2, 2], dtype=np.int8)) 13 | assert sample[0] == 0 and sample[1] == 1 14 | assert sample[2] == 0 or sample[2] == 1 15 | assert sample[3] == 0 or sample[3] == 1 16 | 17 | space = MultiBinary(np.array([2, 3])) 18 | sample = space.sample(mask=np.array([[0, 0, 0], [1, 1, 1]], dtype=np.int8)) 19 | assert np.all(sample == [[0, 0, 0], [1, 1, 1]]), sample 20 | -------------------------------------------------------------------------------- /tests/spaces/test_multidiscrete.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from gym.spaces import Discrete, MultiDiscrete 4 | from gym.utils.env_checker import data_equivalence 5 | 6 | 7 | def test_multidiscrete_as_tuple(): 8 | # 1D multi-discrete 9 | space = MultiDiscrete([3, 4, 5]) 10 | 11 | assert space.shape == (3,) 12 | assert space[0] == Discrete(3) 13 | assert space[0:1] == MultiDiscrete([3]) 14 | assert space[0:2] == MultiDiscrete([3, 4]) 15 | assert space[:] == space and 
def test_multidiscrete_subspace_reproducibility():
    """Test that subspaces obtained by indexing a MultiDiscrete sample reproducibly.

    After seeding the parent space, repeatedly indexing out the same subspace is
    expected to yield equivalent samples -- presumably each indexing derives its
    RNG from the parent's state (see `MultiDiscrete.__getitem__`); TODO confirm.
    """
    # 1D multi-discrete
    space = MultiDiscrete([100, 200, 300])
    space.seed()

    assert data_equivalence(space[0].sample(), space[0].sample())
    assert data_equivalence(space[0:1].sample(), space[0:1].sample())
    assert data_equivalence(space[0:2].sample(), space[0:2].sample())
    assert data_equivalence(space[:].sample(), space[:].sample())
    assert data_equivalence(space[:].sample(), space.sample())

    # 2D multi-discrete
    space = MultiDiscrete([[300, 400, 500], [600, 700, 800]])
    space.seed()

    assert data_equivalence(space[0, 1].sample(), space[0, 1].sample())
    assert data_equivalence(space[0].sample(), space[0].sample())
    assert data_equivalence(space[0:1].sample(), space[0:1].sample())
    assert data_equivalence(space[0:2, :].sample(), space[0:2, :].sample())
    assert data_equivalence(space[:, 0:1].sample(), space[:, 0:1].sample())
    assert data_equivalence(space[0:2, 0:2].sample(), space[0:2, 0:2].sample())
    assert data_equivalence(space[:].sample(), space[:].sample())
    assert data_equivalence(space[:, :].sample(), space[:, :].sample())
    assert data_equivalence(space[:, :].sample(), space.sample())
MultiDiscrete(nvec=[[2, 3], [3, 2]]) 62 | with pytest.warns( 63 | UserWarning, 64 | match="Getting the length of a multi-dimensional MultiDiscrete space.", 65 | ): 66 | assert len(space) == 2 67 | -------------------------------------------------------------------------------- /tests/spaces/test_sequence.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | import gym.spaces 7 | 8 | 9 | def test_sample(): 10 | """Tests the sequence sampling works as expects and the errors are correctly raised.""" 11 | space = gym.spaces.Sequence(gym.spaces.Box(0, 1)) 12 | 13 | # Test integer mask length 14 | for length in range(4): 15 | sample = space.sample(mask=(length, None)) 16 | assert sample in space 17 | assert len(sample) == length 18 | 19 | with pytest.raises( 20 | AssertionError, 21 | match=re.escape( 22 | "Expects the length mask to be greater than or equal to zero, actual value: -1" 23 | ), 24 | ): 25 | space.sample(mask=(-1, None)) 26 | 27 | # Test np.array mask length 28 | sample = space.sample(mask=(np.array([5]), None)) 29 | assert sample in space 30 | assert len(sample) == 5 31 | 32 | sample = space.sample(mask=(np.array([3, 4, 5]), None)) 33 | assert sample in space 34 | assert len(sample) in [3, 4, 5] 35 | 36 | with pytest.raises( 37 | AssertionError, 38 | match=re.escape( 39 | "Expects the shape of the length mask to be 1-dimensional, actual shape: (2, 2)" 40 | ), 41 | ): 42 | space.sample(mask=(np.array([[2, 2], [2, 2]]), None)) 43 | 44 | with pytest.raises( 45 | AssertionError, 46 | match=re.escape( 47 | "Expects all values in the length_mask to be greater than or equal to zero, actual values: [ 1 2 -1]" 48 | ), 49 | ): 50 | space.sample(mask=(np.array([1, 2, -1]), None)) 51 | 52 | # Test with an invalid length 53 | with pytest.raises( 54 | TypeError, 55 | match=re.escape( 56 | "Expects the type of length_mask to an integer or a np.ndarray, actual type: " 57 | ), 
# A bare abstract Space instance used to probe the unimplemented base methods.
TESTING_SPACE = Space()


@pytest.mark.parametrize(
    "func",
    [
        TESTING_SPACE.sample,
        partial(TESTING_SPACE.contains, None),
        partial(utils.flatdim, TESTING_SPACE),
        partial(utils.flatten, TESTING_SPACE, None),
        partial(utils.flatten_space, TESTING_SPACE),
        partial(utils.unflatten, TESTING_SPACE, None),
    ],
)
def test_not_implemented_errors(func):
    """Test that core operations of the abstract base `Space` raise NotImplementedError."""
    with pytest.raises(NotImplementedError):
        func()
def test_sequence_inheritance():
    """The gym Tuple space inherits from abc.Sequences, this test checks all functions work"""
    spaces = [Discrete(5), Discrete(10), Discrete(5)]
    tuple_space = Tuple(spaces)

    assert len(tuple_space) == len(spaces)

    # Indexing returns the constituent spaces in order.
    for index, expected_space in enumerate(spaces):
        assert tuple_space[index] == expected_space

    # Iterating the Tuple visits its constituent spaces.
    for contained_space in tuple_space:
        assert contained_space in spaces

    # count() reports how many constituents equal the query space.
    assert tuple_space.count(Discrete(5)) == 2
    assert tuple_space.count(Discrete(6)) == 0
    assert tuple_space.count(MultiBinary(2)) == 0

    # index() finds the first matching constituent at or after the start argument.
    assert tuple_space.index(Discrete(5)) == 0
    assert tuple_space.index(Discrete(5), 1) == 2

    # Errors: no match within the given range, and out-of-bounds indexing.
    with pytest.raises(ValueError):
        tuple_space.index(Discrete(10), 0, 1)
    with pytest.raises(IndexError):
        assert tuple_space[4]
def test_contains_promotion():
    """Test `in` checks on a Tuple space with compatible element shapes.

    A tuple of per-subspace samples is contained; the second case suggests a
    single stacked array is also accepted when the subspace shapes line up --
    presumably split across the subspaces (verify against `Tuple.contains`).
    """
    space = gym.spaces.Tuple((gym.spaces.Box(0, 1), gym.spaces.Box(-1, 0, (2,))))

    assert (
        np.array([0.0], dtype=np.float32),
        np.array([0.0, 0.0], dtype=np.float32),
    ) in space

    space = gym.spaces.Tuple((gym.spaces.Box(0, 1), gym.spaces.Box(-1, 0, (1,))))
    assert np.array([[0.0], [0.0]], dtype=np.float32) in space
# --- tests/spaces/utils.py ---
from typing import List

import numpy as np

from gym.spaces import (
    Box,
    Dict,
    Discrete,
    Graph,
    MultiBinary,
    MultiDiscrete,
    Sequence,
    Space,
    Text,
    Tuple,
)

# Fundamental (non-composite) spaces exercised by the shared space tests.
TESTING_FUNDAMENTAL_SPACES = [
    Discrete(3),
    Discrete(3, start=-1),
    Box(low=0.0, high=1.0),
    Box(low=0.0, high=np.inf, shape=(2, 2)),
    Box(low=np.array([-10.0, 0.0]), high=np.array([10.0, 10.0]), dtype=np.float64),
    Box(low=-np.inf, high=0.0, shape=(2, 1)),
    Box(low=0.0, high=np.inf, shape=(2, 1)),
    MultiDiscrete([2, 2]),
    MultiDiscrete([[2, 3], [3, 2]]),
    MultiBinary(8),
    MultiBinary([2, 3]),
    Text(6),
    Text(min_length=3, max_length=6),
    Text(6, charset="abcdef"),
]
TESTING_FUNDAMENTAL_SPACES_IDS = [str(space) for space in TESTING_FUNDAMENTAL_SPACES]


# Composite spaces (Tuple/Dict/Graph/Sequence) including nested combinations.
TESTING_COMPOSITE_SPACES = [
    # Tuple spaces
    Tuple([Discrete(5), Discrete(4)]),
    Tuple(
        (
            Discrete(5),
            Box(
                low=np.array([0.0, 0.0]),
                high=np.array([1.0, 5.0]),
                dtype=np.float64,
            ),
        )
    ),
    Tuple((Discrete(5), Tuple((Box(low=0.0, high=1.0, shape=(3,)), Discrete(2))))),
    Tuple((Discrete(3), Dict(position=Box(low=0.0, high=1.0), velocity=Discrete(2)))),
    Tuple((Graph(node_space=Box(-1, 1, shape=(2, 1)), edge_space=None), Discrete(2))),
    # Dict spaces
    Dict(
        {
            "position": Discrete(5),
            "velocity": Box(
                low=np.array([0.0, 0.0]),
                high=np.array([1.0, 5.0]),
                dtype=np.float64,
            ),
        }
    ),
    Dict(
        position=Discrete(6),
        velocity=Box(
            low=np.array([0.0, 0.0]),
            high=np.array([1.0, 5.0]),
            dtype=np.float64,
        ),
    ),
    Dict(
        {
            "a": Box(low=0, high=1, shape=(3, 3)),
            "b": Dict(
                {
                    "b_1": Box(low=-100, high=100, shape=(2,)),
                    "b_2": Box(low=-1, high=1, shape=(2,)),
                }
            ),
            "c": Discrete(4),
        }
    ),
    Dict(
        a=Dict(
            a=Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=None),
            b=Box(-100, 100, shape=(2, 2)),
        ),
        b=Tuple((Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,)))),
    ),
    # Graph spaces
    Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
    Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
    Graph(node_space=Discrete(3), edge_space=Discrete(4)),
    # Sequence spaces
    Sequence(Discrete(4)),
    Sequence(Dict({"feature": Box(0, 1, (3,))})),
    Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))),
]
TESTING_COMPOSITE_SPACES_IDS = [str(space) for space in TESTING_COMPOSITE_SPACES]

TESTING_SPACES: List[Space] = TESTING_FUNDAMENTAL_SPACES + TESTING_COMPOSITE_SPACES
TESTING_SPACES_IDS = TESTING_FUNDAMENTAL_SPACES_IDS + TESTING_COMPOSITE_SPACES_IDS


# --- tests/testing_env.py ---
"""Provides a generic testing environment for use in tests with custom reset, step and render functions."""
import types
from typing import Any, Dict, Optional, Tuple, Union

import gym
from gym import spaces
from gym.core import ActType, ObsType
from gym.envs.registration import EnvSpec


def basic_reset_fn(
    self,
    *,
    seed: Optional[int] = None,
    options: Optional[dict] = None,
) -> Union[ObsType, Tuple[ObsType, dict]]:
    """A basic reset function that will pass the environment check using random actions from the observation space."""
    super(GenericTestEnv, self).reset(seed=seed)
    self.observation_space.seed(seed)
    return self.observation_space.sample(), {"options": options}


def new_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
    """A step function that follows the new step api that will pass the environment check using random actions from the observation space."""
    return self.observation_space.sample(), 0, False, False, {}
def old_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
    """A step function that follows the old step api that will pass the environment check using random actions from the observation space."""
    return self.observation_space.sample(), 0, False, {}


def basic_render_fn(self):
    """Basic render fn that does nothing."""
    pass


# todo: change all testing environment to this generic class
class GenericTestEnv(gym.Env):
    """A generic testing environment for use in testing with modified environments are required.

    The reset/step/render behaviour is supplied as plain functions and bound
    onto the instance, so individual tests can swap in failing or partial
    implementations without subclassing.
    """

    def __init__(
        self,
        action_space: spaces.Space = spaces.Box(0, 1, (1,)),
        observation_space: spaces.Space = spaces.Box(0, 1, (1,)),
        reset_fn: callable = basic_reset_fn,
        step_fn: callable = new_step_fn,
        render_fn: callable = basic_render_fn,
        metadata: Optional[Dict[str, Any]] = None,
        render_mode: Optional[str] = None,
        spec: Optional[EnvSpec] = None,
    ):
        self.metadata = {} if metadata is None else metadata
        self.render_mode = render_mode
        # Bug fix: the default spec was previously a single EnvSpec instance
        # created once at import time and shared (mutably) by every
        # GenericTestEnv; build a fresh, equivalent spec per instance instead.
        self.spec = (
            EnvSpec("TestingEnv-v0", "testing-env-no-entry-point")
            if spec is None
            else spec
        )
        # NOTE(review): the default action/observation spaces above are also
        # module-level shared instances (seeding one env's space affects all
        # envs using the default) — None already means "do not set", so they
        # cannot use the same sentinel fix; confirm whether this sharing is
        # acceptable for the tests.
        if observation_space is not None:
            self.observation_space = observation_space
        if action_space is not None:
            self.action_space = action_space

        # Bind the supplied behaviour functions as instance methods.
        if reset_fn is not None:
            self.reset = types.MethodType(reset_fn, self)
        if step_fn is not None:
            self.step = types.MethodType(step_fn, self)
        if render_fn is not None:
            self.render = types.MethodType(render_fn, self)

    def reset(
        self,
        *,
        seed: Optional[int] = None,
        options: Optional[dict] = None,
    ) -> Union[ObsType, Tuple[ObsType, dict]]:
        # If you need a default working reset function, use `basic_reset_fn` above
        raise NotImplementedError("TestingEnv reset_fn is not set.")

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
        raise NotImplementedError("TestingEnv step_fn is not set.")

    def render(self):
        # Bug fix: error message previously said "testingEnv", inconsistent
        # with the reset/step messages.
        raise NotImplementedError("TestingEnv render_fn is not set.")


# --- tests/utils/test_seeding.py ---
import pickle

import pytest

from gym import error
from gym.utils import seeding


def test_invalid_seeds():
    """Negative or non-integer seeds must be rejected with a gym error.

    Rewritten with pytest.raises: the previous try/except/else `assert False`
    pattern is silently disabled when Python runs with -O.
    """
    for seed in [-1, "test"]:
        with pytest.raises(error.Error):
            seeding.np_random(seed)


def test_valid_seeds():
    """np_random must echo back a valid integer seed unchanged."""
    for seed in [0, 1]:
        random, seed1 = seeding.np_random(seed)
        assert seed == seed1


def test_rng_pickle():
    """The RNG must survive a pickle round-trip with its state intact."""
    rng, _ = seeding.np_random(seed=0)
    pickled = pickle.dumps(rng)
    rng2 = pickle.loads(pickled)
    assert isinstance(
        rng2, seeding.RandomNumberGenerator
    ), "Unpickled object is not a RandomNumberGenerator"
    assert rng.random() == rng2.random()
# --- tests/vector/test_vector_env_info.py ---
import numpy as np
import pytest

import gym
from gym.vector.sync_vector_env import SyncVectorEnv
from tests.vector.utils import make_env

ENV_ID = "CartPole-v1"
NUM_ENVS = 3
ENV_STEPS = 50
SEED = 42


@pytest.mark.parametrize("asynchronous", [True, False])
def test_vector_env_info(asynchronous):
    """On episode end, `final_observation` info arrays must have one entry per sub-env."""
    env = gym.vector.make(
        ENV_ID, num_envs=NUM_ENVS, asynchronous=asynchronous, disable_env_checker=True
    )
    env.reset(seed=SEED)
    for _ in range(ENV_STEPS):
        # NOTE(review): re-seeding the action space with the same seed on
        # every iteration makes the sampled action identical each step —
        # confirm this determinism is intended rather than seeding once.
        env.action_space.seed(SEED)
        action = env.action_space.sample()
        _, _, terminateds, truncateds, infos = env.step(action)
        if any(terminateds) or any(truncateds):
            assert len(infos["final_observation"]) == NUM_ENVS
            assert len(infos["_final_observation"]) == NUM_ENVS

            assert isinstance(infos["final_observation"], np.ndarray)
            assert isinstance(infos["_final_observation"], np.ndarray)

            # The mask marks exactly the sub-envs that finished; others keep None.
            for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
                if terminated or truncated:
                    assert infos["_final_observation"][i]
                else:
                    assert not infos["_final_observation"][i]
                    assert infos["final_observation"][i] is None


@pytest.mark.parametrize("concurrent_ends", [1, 2, 3])
def test_vector_env_info_concurrent_termination(concurrent_ends):
    """Sub-envs driven with identical actions must report termination info together."""
    # envs that need to terminate together will have the same action
    actions = [0] * concurrent_ends + [1] * (NUM_ENVS - concurrent_ends)
    envs = [make_env(ENV_ID, SEED) for _ in range(NUM_ENVS)]
    envs = SyncVectorEnv(envs)

    for _ in range(ENV_STEPS):
        _, _, terminateds, truncateds, infos = envs.step(actions)
        if any(terminateds) or any(truncateds):
            for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
                if i < concurrent_ends:
                    assert terminated or truncated
                    assert infos["_final_observation"][i]
                else:
                    assert not infos["_final_observation"][i]
                    assert infos["final_observation"][i] is None
            return


# --- tests/vector/test_vector_env_wrapper.py ---
import numpy as np

from gym.vector import VectorEnvWrapper, make


class DummyWrapper(VectorEnvWrapper):
    """Minimal vector wrapper that counts how many resets were issued."""

    def __init__(self, env):
        self.env = env
        self.counter = 0

    def reset_async(self, **kwargs):
        # Bug fix: keyword arguments (e.g. seed/options) were previously
        # dropped instead of being forwarded to the wrapped environment.
        super().reset_async(**kwargs)
        self.counter += 1


def test_vector_env_wrapper_inheritance():
    """reset() on the wrapper must route through the overridden reset_async."""
    env = make("FrozenLake-v1", asynchronous=False)
    wrapped = DummyWrapper(env)
    wrapped.reset()
    assert wrapped.counter == 1


def test_vector_env_wrapper_attributes():
    """Test if `set_attr`, `call` methods for VecEnvWrapper get correctly forwarded to the vector env it is wrapping."""
    env = make("CartPole-v1", num_envs=3)
    wrapped = DummyWrapper(make("CartPole-v1", num_envs=3))

    assert np.allclose(wrapped.call("gravity"), env.call("gravity"))
    env.set_attr("gravity", [20.0, 20.0, 20.0])
    wrapped.set_attr("gravity", [20.0, 20.0, 20.0])
    assert np.allclose(wrapped.get_attr("gravity"), env.get_attr("gravity"))


# --- tests/vector/test_vector_make.py ---
import pytest

import gym
from gym.vector import AsyncVectorEnv, SyncVectorEnv
from gym.wrappers import OrderEnforcing, TimeLimit, TransformObservation
from gym.wrappers.env_checker import PassiveEnvChecker
from tests.wrappers.utils import has_wrapper


def test_vector_make_id():
    """gym.vector.make defaults to a single-env asynchronous vector."""
    env = gym.vector.make("CartPole-v1")
    assert isinstance(env, AsyncVectorEnv)
    assert env.num_envs == 1
    env.close()
@pytest.mark.parametrize("num_envs", [1, 3, 10])
def test_vector_make_num_envs(num_envs):
    """The vector env exposes exactly the requested number of sub-envs."""
    env = gym.vector.make("CartPole-v1", num_envs=num_envs)
    assert env.num_envs == num_envs
    env.close()


def test_vector_make_asynchronous():
    """The `asynchronous` flag selects between Async and Sync vector env classes."""
    async_env = gym.vector.make("CartPole-v1", asynchronous=True)
    assert isinstance(async_env, AsyncVectorEnv)
    async_env.close()

    sync_env = gym.vector.make("CartPole-v1", asynchronous=False)
    assert isinstance(sync_env, SyncVectorEnv)
    sync_env.close()


def test_vector_make_wrappers():
    """Sub-envs get the spec-driven default wrappers plus any user-supplied wrapper factory."""
    env = gym.vector.make("CartPole-v1", num_envs=2, asynchronous=False)
    assert isinstance(env, SyncVectorEnv)
    assert len(env.envs) == 2

    sub_env = env.envs[0]
    assert isinstance(sub_env, gym.Env)
    # Default wrappers follow the registered spec.
    if sub_env.spec.order_enforce:
        assert has_wrapper(sub_env, OrderEnforcing)
    if sub_env.spec.max_episode_steps is not None:
        assert has_wrapper(sub_env, TimeLimit)

    assert all(
        has_wrapper(sub_env, TransformObservation) is False for sub_env in env.envs
    )
    env.close()

    env = gym.vector.make(
        "CartPole-v1",
        num_envs=2,
        asynchronous=False,
        wrappers=lambda _env: TransformObservation(_env, lambda obs: obs * 2),
    )
    # As asynchronous environment are inaccessible, synchronous vector must be used
    assert isinstance(env, SyncVectorEnv)
    assert all(has_wrapper(sub_env, TransformObservation) for sub_env in env.envs)

    env.close()


def test_vector_make_disable_env_checker():
    """Only the first sub-env is checked, and the checker can be disabled entirely."""
    # As asynchronous environment are inaccessible, synchronous vector must be used
    env = gym.vector.make("CartPole-v1", num_envs=1, asynchronous=False)
    assert isinstance(env, SyncVectorEnv)
    assert has_wrapper(env.envs[0], PassiveEnvChecker)
    env.close()

    env = gym.vector.make("CartPole-v1", num_envs=5, asynchronous=False)
    assert isinstance(env, SyncVectorEnv)
    # The checker wraps only sub-env 0; the remaining four run unchecked.
    assert has_wrapper(env.envs[0], PassiveEnvChecker)
    assert all(
        has_wrapper(env.envs[i], PassiveEnvChecker) is False for i in [1, 2, 3, 4]
    )
    env.close()

    env = gym.vector.make(
        "CartPole-v1", num_envs=3, asynchronous=False, disable_env_checker=True
    )
    assert isinstance(env, SyncVectorEnv)
    assert all(has_wrapper(sub_env, PassiveEnvChecker) is False for sub_env in env.envs)
    env.close()


# --- tests/wrappers/test_clip_action.py ---
import numpy as np

import gym
from gym.wrappers import ClipAction


def test_clip_action():
    """ClipAction must match manually clipping actions to the action-space bounds."""
    # mountaincar: action-based rewards
    base_env = gym.make("MountainCarContinuous-v0", disable_env_checker=True)
    wrapped_env = ClipAction(
        gym.make("MountainCarContinuous-v0", disable_env_checker=True)
    )

    seed = 0
    base_env.reset(seed=seed)
    wrapped_env.reset(seed=seed)

    # Includes in-bounds, above-bound and below-bound actions.
    for action in ([0.4], [1.2], [-0.3], [0.0], [-2.5]):
        obs1, r1, ter1, trunc1, _ = base_env.step(
            np.clip(action, base_env.action_space.low, base_env.action_space.high)
        )
        obs2, r2, ter2, trunc2, _ = wrapped_env.step(action)
        assert np.allclose(r1, r2)
        assert np.allclose(obs1, obs2)
        assert ter1 == ter2
        assert trunc1 == trunc2
# --- tests/wrappers/test_filter_observation.py ---
from typing import Optional, Tuple

import numpy as np
import pytest

import gym
from gym import spaces
from gym.wrappers.filter_observation import FilterObservation


class FakeEnvironment(gym.Env):
    """Stub environment whose Dict observation-space keys are configurable."""

    def __init__(
        self, render_mode=None, observation_keys: Tuple[str, ...] = ("state",)
    ):
        self.observation_space = spaces.Dict(
            {
                name: spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32)
                for name in observation_keys
            }
        )
        self.action_space = spaces.Box(shape=(1,), low=-1, high=1, dtype=np.float32)
        self.render_mode = render_mode

    def render(self, mode="human"):
        image_shape = (32, 32, 3)
        return np.zeros(image_shape, dtype=np.uint8)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        super().reset(seed=seed)
        observation = self.observation_space.sample()
        return observation, {}

    def step(self, action):
        del action
        # NOTE(review): returns the old 4-tuple step API; the tests below only
        # call reset(), so this path is never exercised — confirm before reuse.
        observation = self.observation_space.sample()
        reward, terminal, info = 0.0, False, {}
        return observation, reward, terminal, info


# (available keys, keys to keep) pairs; None keeps everything.
FILTER_OBSERVATION_TEST_CASES = (
    (("key1", "key2"), ("key1",)),
    (("key1", "key2"), ("key1", "key2")),
    (("key1",), None),
    (("key1",), ("key1",)),
)

ERROR_TEST_CASES = (
    ("key", ValueError, "All the filter_keys must be included..*"),
    (False, TypeError, "'bool' object is not iterable"),
    (1, TypeError, "'int' object is not iterable"),
)


class TestFilterObservation:
    @pytest.mark.parametrize(
        "observation_keys,filter_keys", FILTER_OBSERVATION_TEST_CASES
    )
    def test_filter_observation(self, observation_keys, filter_keys):
        """The wrapped space and observations contain exactly the filtered keys."""
        env = FakeEnvironment(observation_keys=observation_keys)

        # Make sure we are testing the right environment for the test.
        observation_space = env.observation_space
        assert isinstance(observation_space, spaces.Dict)

        wrapped_env = FilterObservation(env, filter_keys=filter_keys)

        assert isinstance(wrapped_env.observation_space, spaces.Dict)

        # filter_keys=None means "keep every key".
        if filter_keys is None:
            filter_keys = tuple(observation_keys)

        assert len(wrapped_env.observation_space.spaces) == len(filter_keys)
        assert tuple(wrapped_env.observation_space.spaces.keys()) == tuple(filter_keys)

        # Check that the added space item is consistent with the added observation.
        observation, info = wrapped_env.reset()
        assert len(observation) == len(filter_keys)
        assert isinstance(info, dict)

    @pytest.mark.parametrize("filter_keys,error_type,error_match", ERROR_TEST_CASES)
    def test_raises_with_incorrect_arguments(
        self, filter_keys, error_type, error_match
    ):
        """Unknown keys or non-iterable filter_keys raise at construction time."""
        env = FakeEnvironment(observation_keys=("key1", "key2"))

        with pytest.raises(error_type, match=error_match):
            FilterObservation(env, filter_keys=filter_keys)


# --- tests/wrappers/test_flatten.py ---
"""Tests for the flatten observation wrapper."""

from collections import OrderedDict
from typing import Optional

import numpy as np
import pytest

import gym
from gym.spaces import Box, Dict, flatten, unflatten
from gym.wrappers import FlattenObservation


class FakeEnvironment(gym.Env):
    """Stub environment that records its last sampled observation on reset."""

    def __init__(self, observation_space):
        self.observation_space = observation_space

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        super().reset(seed=seed)
        # Keep the sampled observation so tests can compare against it later.
        self.observation = self.observation_space.sample()
        return self.observation, {}
# Each entry pairs a Dict space with whether the constant values it holds are
# strictly increasing when flattened in insertion order.
OBSERVATION_SPACES = (
    (
        Dict(
            OrderedDict(
                [
                    ("key1", Box(shape=(2, 3), low=0, high=0, dtype=np.float32)),
                    ("key2", Box(shape=(), low=1, high=1, dtype=np.float32)),
                    ("key3", Box(shape=(2,), low=2, high=2, dtype=np.float32)),
                ]
            )
        ),
        True,
    ),
    (
        Dict(
            OrderedDict(
                [
                    ("key2", Box(shape=(), low=0, high=0, dtype=np.float32)),
                    ("key3", Box(shape=(2,), low=1, high=1, dtype=np.float32)),
                    ("key1", Box(shape=(2, 3), low=2, high=2, dtype=np.float32)),
                ]
            )
        ),
        True,
    ),
    (
        Dict(
            {
                "key1": Box(shape=(2, 3), low=-1, high=1, dtype=np.float32),
                "key2": Box(shape=(), low=-1, high=1, dtype=np.float32),
                "key3": Box(shape=(2,), low=-1, high=1, dtype=np.float32),
            }
        ),
        False,
    ),
)


class TestFlattenEnvironment:
    @pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES)
    def test_flattened_environment(self, observation_space, ordered_values):
        """make sure that flattened observations occur in the order expected"""
        env = FakeEnvironment(observation_space=observation_space)
        wrapped_env = FlattenObservation(env)
        flattened, info = wrapped_env.reset()

        unflattened = unflatten(env.observation_space, flattened)
        original = env.observation

        self._check_observations(original, flattened, unflattened, ordered_values)

    @pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES)
    def test_flatten_unflatten(self, observation_space, ordered_values):
        """test flatten and unflatten functions directly"""
        original = observation_space.sample()

        flattened = flatten(observation_space, original)
        unflattened = unflatten(observation_space, flattened)

        self._check_observations(original, flattened, unflattened, ordered_values)

    def _check_observations(self, original, flattened, unflattened, ordered_values):
        # make sure that unflatten(flatten(original)) == original
        assert set(unflattened.keys()) == set(original.keys())
        for k, v in original.items():
            np.testing.assert_allclose(unflattened[k], v)

        if ordered_values:
            # The constant-valued keys were flattened in OrderedDict insertion
            # order, so the flat vector must already be sorted ascending.
            np.testing.assert_allclose(sorted(flattened), flattened)


# --- tests/wrappers/test_flatten_observation.py ---
import numpy as np
import pytest

import gym
from gym import spaces
from gym.wrappers import FlattenObservation


@pytest.mark.parametrize("env_id", ["Blackjack-v1"])
def test_flatten_observation(env_id):
    """A Tuple of Discretes flattens into a multi-hot Box of the summed sizes."""
    env = gym.make(env_id, disable_env_checker=True)
    wrapped_env = FlattenObservation(env)

    obs, info = env.reset()
    wrapped_obs, wrapped_obs_info = wrapped_env.reset()

    space = spaces.Tuple((spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)))
    wrapped_space = spaces.Box(0, 1, [32 + 11 + 2], dtype=np.int64)

    assert space.contains(obs)
    assert wrapped_space.contains(wrapped_obs)
    assert isinstance(info, dict)
    assert isinstance(wrapped_obs_info, dict)


# --- tests/wrappers/test_frame_stack.py ---
import numpy as np
import pytest

import gym
from gym.wrappers import FrameStack

# lz4 is optional; the compression variants are skipped when it is missing.
try:
    import lz4
except ImportError:
    lz4 = None
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1", "CarRacing-v2"])
@pytest.mark.parametrize("num_stack", [2, 3, 4])
@pytest.mark.parametrize(
    "lz4_compress",
    [
        pytest.param(
            True,
            marks=pytest.mark.skipif(
                lz4 is None, reason="Need lz4 to run tests with compression"
            ),
        ),
        False,
    ],
)
def test_frame_stack(env_id, num_stack, lz4_compress):
    """FrameStack's newest frame must track a duplicate unwrapped env step-for-step."""
    env = gym.make(env_id, disable_env_checker=True)
    shape = env.observation_space.shape
    env = FrameStack(env, num_stack, lz4_compress)
    # The stack dimension is prepended to the original observation shape.
    assert env.observation_space.shape == (num_stack,) + shape
    assert env.observation_space.dtype == env.env.observation_space.dtype

    dup = gym.make(env_id, disable_env_checker=True)

    obs, _ = env.reset(seed=0)
    dup_obs, _ = dup.reset(seed=0)
    assert np.allclose(obs[-1], dup_obs)

    for _ in range(num_stack**2):
        action = env.action_space.sample()
        dup_obs, _, dup_terminated, dup_truncated, _ = dup.step(action)
        obs, _, terminated, truncated, _ = env.step(action)

        assert dup_terminated == terminated
        assert dup_truncated == truncated
        # The most recent frame in the stack equals the duplicate's observation.
        assert np.allclose(obs[-1], dup_obs)

        if terminated or truncated:
            break

    assert len(obs) == num_stack


# --- tests/wrappers/test_gray_scale_observation.py ---
import pytest

import gym
from gym import spaces
from gym.wrappers import GrayScaleObservation


@pytest.mark.parametrize("env_id", ["CarRacing-v2"])
@pytest.mark.parametrize("keep_dim", [True, False])
def test_gray_scale_observation(env_id, keep_dim):
    """Greyscaling drops the RGB channel axis unless keep_dim retains it as size 1."""
    rgb_env = gym.make(env_id, disable_env_checker=True)

    assert isinstance(rgb_env.observation_space, spaces.Box)
    assert len(rgb_env.observation_space.shape) == 3
    assert rgb_env.observation_space.shape[-1] == 3

    wrapped_env = GrayScaleObservation(rgb_env, keep_dim=keep_dim)
    assert isinstance(wrapped_env.observation_space, spaces.Box)
    if keep_dim:
        assert len(wrapped_env.observation_space.shape) == 3
        assert wrapped_env.observation_space.shape[-1] == 1
    else:
        assert len(wrapped_env.observation_space.shape) == 2

    wrapped_obs, info = wrapped_env.reset()
    assert wrapped_obs in wrapped_env.observation_space


# --- tests/wrappers/test_human_rendering.py ---
import re

import pytest

import gym
from gym.wrappers import HumanRendering


def test_human_rendering():
    """HumanRendering accepts rgb_array(_list) sources and rejects human-mode envs."""
    for mode in ["rgb_array", "rgb_array_list"]:
        env = HumanRendering(
            gym.make("CartPole-v1", render_mode=mode, disable_env_checker=True)
        )
        assert env.render_mode == "human"
        env.reset()

        for _ in range(75):
            _, _, terminated, truncated, _ = env.step(env.action_space.sample())
            if terminated or truncated:
                env.reset()

        env.close()

    # Wrapping an env that already renders to a window must fail.
    env = gym.make("CartPole-v1", render_mode="human")
    with pytest.raises(
        AssertionError,
        match=re.escape(
            "Expected env.render_mode to be one of 'rgb_array' or 'rgb_array_list' but got 'human'"
        ),
    ):
        HumanRendering(env)
    env.close()


# --- tests/wrappers/test_order_enforcing.py ---
import pytest

import gym
from gym.envs.classic_control import CartPoleEnv
from gym.error import ResetNeeded
from gym.wrappers import OrderEnforcing
from tests.envs.utils import all_testing_env_specs
from tests.wrappers.utils import has_wrapper


@pytest.mark.parametrize(
    "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
)
def test_gym_make_order_enforcing(spec):
    """Checks that gym.make wraps the environment with the OrderEnforcing wrapper."""
    env = gym.make(spec.id, disable_env_checker=True)

    assert has_wrapper(env, OrderEnforcing)
def test_order_enforcing():
    """Checks that the order enforcing works as expected, raising an error before reset is called and not after."""
    # The reason for not using gym.make is that all environments are by default wrapped in the order enforcing wrapper
    env = CartPoleEnv(render_mode="rgb_array_list")
    assert not has_wrapper(env, OrderEnforcing)

    # Before reset, both step and render must be rejected.
    order_enforced_env = OrderEnforcing(env)
    assert order_enforced_env.has_reset is False
    with pytest.raises(ResetNeeded):
        order_enforced_env.step(0)
    with pytest.raises(ResetNeeded):
        order_enforced_env.render()
    assert order_enforced_env.has_reset is False

    # After reset, step and render succeed without errors.
    order_enforced_env.reset()
    assert order_enforced_env.has_reset is True
    order_enforced_env.step(0)
    order_enforced_env.render()

    # With disable_render_order_enforcing, render works even before reset.
    env = CartPoleEnv(render_mode="rgb_array_list")
    env = OrderEnforcing(env, disable_render_order_enforcing=True)
    env.render()  # no assertion error


# --- tests/wrappers/test_passive_env_checker.py ---
import re
import warnings

import numpy as np
import pytest

import gym
from gym.wrappers.env_checker import PassiveEnvChecker
from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING
from tests.envs.utils import all_testing_initialised_envs
from tests.testing_env import GenericTestEnv


@pytest.mark.parametrize(
    "env",
    all_testing_initialised_envs,
    ids=[env.spec.id for env in all_testing_initialised_envs],
)
def test_passive_checker_wrapper_warnings(env):
    """Registered envs must pass the passive checker without unexpected warnings."""
    with warnings.catch_warnings(record=True) as caught_warnings:
        checker_env = PassiveEnvChecker(env)
        checker_env.reset()
        checker_env.step(checker_env.action_space.sample())
        # todo, add check for render, bugged due to mujoco v2/3 and v4 envs

        checker_env.close()

    for warning in caught_warnings:
        if warning.message.args[0] not in PASSIVE_CHECK_IGNORE_WARNING:
            raise gym.error.Error(f"Unexpected warning: {warning.message}")


@pytest.mark.parametrize(
    "env, message",
    [
        (
            GenericTestEnv(action_space=None),
            "The environment must specify an action space. https://www.gymlibrary.dev/content/environment_creation/",
        ),
        (
            GenericTestEnv(action_space="error"),
            "action space does not inherit from `gym.spaces.Space`, actual type: ",
        ),
        (
            GenericTestEnv(observation_space=None),
            "The environment must specify an observation space. https://www.gymlibrary.dev/content/environment_creation/",
        ),
        (
            GenericTestEnv(observation_space="error"),
            "observation space does not inherit from `gym.spaces.Space`, actual type: ",
        ),
    ],
)
def test_initialise_failures(env, message):
    """Missing or malformed spaces must fail checker construction with the exact message."""
    with pytest.raises(AssertionError, match=f"^{re.escape(message)}$"):
        PassiveEnvChecker(env)

    env.close()
def _reset_failure(self, seed=None, options=None):
    # Deliberately out-of-bounds observation for the default Box(0, 1) space.
    return np.array([-1.0], dtype=np.float32), {}


def _step_failure(self, action):
    # Deliberately malformed step result (not a tuple).
    return "error"


def test_api_failures():
    """Each checker flag flips after its API is exercised, with the expected warning/error."""
    env = GenericTestEnv(
        reset_fn=_reset_failure,
        step_fn=_step_failure,
        metadata={"render_modes": "error"},
    )
    env = PassiveEnvChecker(env)
    assert env.checked_reset is False
    assert env.checked_step is False
    assert env.checked_render is False

    # Out-of-space reset observation -> warning, but the reset is still checked.
    with pytest.warns(
        UserWarning,
        match=re.escape(
            "The obs returned by the `reset()` method is not within the observation space"
        ),
    ):
        env.reset()
    assert env.checked_reset

    # Non-tuple step result -> hard assertion error.
    with pytest.raises(
        AssertionError,
        match="Expects step result to be a tuple, actual type: ",
    ):
        env.step(env.action_space.sample())
    assert env.checked_step

    # Malformed render_modes metadata -> warning on first render.
    with pytest.warns(
        UserWarning,
        match=r"Expects the render_modes to be a sequence \(i\.e\. list, tuple\), actual type: ",
    ):
        env.render()
    assert env.checked_render

    env.close()


# --- tests/wrappers/test_record_video.py ---
import os
import shutil

import gym
from gym.wrappers import capped_cubic_video_schedule


def test_record_video_using_default_trigger():
    """Without a trigger, episodes are recorded on the capped-cubic schedule."""
    env = gym.make(
        "CartPole-v1", render_mode="rgb_array_list", disable_env_checker=True
    )
    env = gym.wrappers.RecordVideo(env, "videos")
    env.reset()
    for _ in range(199):
        action = env.action_space.sample()
        _, _, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            env.reset()
    env.close()
    assert os.path.isdir("videos")
    mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
    # One video per episode the default schedule selects.
    assert len(mp4_files) == sum(
        capped_cubic_video_schedule(i) for i in range(env.episode_id + 1)
    )
    shutil.rmtree("videos")


def test_record_video_reset():
    """A single reset creates the video folder and returns a valid observation."""
    env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
    env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
    ob_space = env.observation_space
    obs, info = env.reset()
    env.close()
    assert os.path.isdir("videos")
    shutil.rmtree("videos")
    assert ob_space.contains(obs)
    assert isinstance(info, dict)


def test_record_video_step_trigger():
    """A step trigger firing every 100 steps records exactly two videos in 199 steps."""
    env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
    env._max_episode_steps = 20
    env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
    env.reset()
    for _ in range(199):
        action = env.action_space.sample()
        _, _, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            env.reset()
    env.close()
    assert os.path.isdir("videos")
    mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
    # Triggered at steps 0 and 100.
    assert len(mp4_files) == 2
    shutil.rmtree("videos")
def test_record_video_step_trigger():
    """A step-based trigger records exactly the episodes it selects."""
    env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
    env._max_episode_steps = 20
    env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
    env.reset()
    for _ in range(199):
        _, _, terminated, truncated, _ = env.step(env.action_space.sample())
        if terminated or truncated:
            env.reset()
    env.close()

    assert os.path.isdir("videos")
    videos = [name for name in os.listdir("videos") if name.endswith(".mp4")]
    # The trigger fires at steps 0 and 100 within the 199-step run.
    assert len(videos) == 2
    shutil.rmtree("videos")


def make_env(gym_id, seed, **kwargs):
    """Return a thunk that builds `gym_id`; only seed 1 wraps with RecordVideo."""

    def thunk():
        env = gym.make(gym_id, disable_env_checker=True, **kwargs)
        env._max_episode_steps = 20
        if seed == 1:
            env = gym.wrappers.RecordVideo(
                env, "videos", step_trigger=lambda x: x % 100 == 0
            )
        return env

    return thunk


def test_record_video_within_vector():
    """RecordVideo keeps working when its env lives inside a SyncVectorEnv."""
    envs = gym.vector.SyncVectorEnv(
        [make_env("CartPole-v1", 1 + i, render_mode="rgb_array") for i in range(2)]
    )
    envs = gym.wrappers.RecordEpisodeStatistics(envs)
    envs.reset()
    for _ in range(199):
        _, _, _, _, infos = envs.step(envs.action_space.sample())

        # Log the rewards whenever every sub-env finished an episode this step
        # (the loop always runs the full 199 steps; there is no early exit).
        if "episode" in infos and all(infos["_episode"]):
            print(f"episode_reward={infos['episode']['r']}")

    assert os.path.isdir("videos")
    videos = [name for name in os.listdir("videos") if name.endswith(".mp4")]
    assert len(videos) == 2
    shutil.rmtree("videos")


# --- tests/wrappers/test_rescale_action.py ---


def test_rescale_action():
    """RescaleAction rejects discrete spaces and rescales continuous actions."""
    env = gym.make("CartPole-v1", disable_env_checker=True)
    with pytest.raises(AssertionError):
        env = RescaleAction(env, -1, 1)  # discrete action space is invalid
    del env

    env = gym.make("Pendulum-v1", disable_env_checker=True)
    wrapped_env = RescaleAction(
        gym.make("Pendulum-v1", disable_env_checker=True), -1, 1
    )

    seed = 0
    obs, info = env.reset(seed=seed)
    wrapped_obs, wrapped_obs_info = wrapped_env.reset(seed=seed)
    assert np.allclose(obs, wrapped_obs)

    obs, reward, _, _, _ = env.step([1.5])
    with pytest.raises(AssertionError):
        wrapped_env.step([1.5])  # outside the rescaled [-1, 1] bounds
    # 0.75 in [-1, 1] maps back onto 1.5 in the underlying action range.
    wrapped_obs, wrapped_reward, _, _, _ = wrapped_env.step([0.75])

    assert np.allclose(obs, wrapped_obs)
    assert np.allclose(reward, wrapped_reward)


# --- tests/wrappers/test_resize_observation.py ---


@pytest.mark.parametrize("env_id", ["CarRacing-v2"])
@pytest.mark.parametrize("shape", [16, 32, (8, 5), [10, 7]])
def test_resize_observation(env_id, shape):
    """ResizeObservation reshapes both the space and the observations."""
    env = ResizeObservation(gym.make(env_id, disable_env_checker=True), shape)

    assert isinstance(env.observation_space, spaces.Box)
    assert env.observation_space.shape[-1] == 3  # RGB channels preserved
    obs, _ = env.reset()
    expected_hw = (shape, shape) if isinstance(shape, int) else tuple(shape)
    assert env.observation_space.shape[:2] == expected_hw
    assert obs.shape == expected_hw + (3,)


# --- tests/wrappers/test_step_compatibility.py ---


class OldStepEnv(gym.Env):
    """Minimal env using the legacy 4-tuple (obs, rew, done, info) step API."""

    def __init__(self):
        self.action_space = Discrete(2)
        self.observation_space = Discrete(2)

    def step(self, action):
        return self.observation_space.sample(), 0, False, {}
class NewStepEnv(gym.Env):
    """Minimal env using the 5-tuple (obs, rew, terminated, truncated, info) API."""

    def __init__(self):
        self.action_space = Discrete(2)
        self.observation_space = Discrete(2)

    def step(self, action):
        return self.observation_space.sample(), 0, False, False, {}


@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv])
@pytest.mark.parametrize("output_truncation_bool", [None, True])
def test_step_compatibility_to_new_api(env, output_truncation_bool):
    """Wrapping either API (truncation output enabled) yields a 5-tuple."""
    if output_truncation_bool is None:
        wrapped = StepAPICompatibility(env())
    else:
        wrapped = StepAPICompatibility(env(), output_truncation_bool)
    _, _, terminated, truncated, _ = wrapped.step(0)
    assert isinstance(terminated, bool)
    assert isinstance(truncated, bool)


@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv])
def test_step_compatibility_to_old_api(env):
    """Wrapping with output_truncation=False yields the legacy 4-tuple."""
    wrapped = StepAPICompatibility(env(), False)
    step_returns = wrapped.step(0)
    assert len(step_returns) == 4
    _, _, done, _ = step_returns
    assert isinstance(done, bool)


@pytest.mark.parametrize("apply_api_compatibility", [None, True, False])
def test_step_compatibility_in_make(apply_api_compatibility):
    """gym.make only converts the step API when explicitly asked to."""
    gym.register("OldStepEnv-v0", entry_point=OldStepEnv)

    make_kwargs = {"disable_env_checker": True}
    if apply_api_compatibility is not None:
        make_kwargs["apply_api_compatibility"] = apply_api_compatibility
    env = gym.make("OldStepEnv-v0", **make_kwargs)

    env.reset()
    step_returns = env.step(0)
    if apply_api_compatibility:
        assert len(step_returns) == 5
        _, _, terminated, truncated, _ = step_returns
        assert isinstance(terminated, bool)
        assert isinstance(truncated, bool)
    else:
        assert len(step_returns) == 4
        _, _, done, _ = step_returns
        assert isinstance(done, bool)

    gym.envs.registry.pop("OldStepEnv-v0")


# --- tests/wrappers/test_time_aware_observation.py ---


@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
def test_time_aware_observation(env_id):
    """TimeAwareObservation appends a step counter that resets with the env."""
    env = gym.make(env_id, disable_env_checker=True)
    wrapped_env = TimeAwareObservation(env)

    assert isinstance(env.observation_space, spaces.Box)
    assert isinstance(wrapped_env.observation_space, spaces.Box)
    assert wrapped_env.observation_space.shape[0] == env.observation_space.shape[0] + 1

    obs, info = env.reset()
    wrapped_obs, wrapped_obs_info = wrapped_env.reset()
    assert wrapped_env.t == 0.0
    assert wrapped_obs[-1] == 0.0
    assert wrapped_obs.shape[0] == obs.shape[0] + 1

    # The trailing entry counts steps: 1.0 after one step, 2.0 after two.
    for expected_t in (1.0, 2.0):
        wrapped_obs, _, _, _, _ = wrapped_env.step(env.action_space.sample())
        assert wrapped_env.t == expected_t
        assert wrapped_obs[-1] == expected_t
        assert wrapped_obs.shape[0] == obs.shape[0] + 1

    # A fresh reset puts the counter back to zero.
    wrapped_obs, wrapped_obs_info = wrapped_env.reset()
    assert wrapped_env.t == 0.0
    assert wrapped_obs[-1] == 0.0
    assert wrapped_obs.shape[0] == obs.shape[0] + 1
def test_time_limit_reset_info():
    """TimeLimit.reset forwards the (obs, info) pair untouched."""
    env = TimeLimit(gym.make("CartPole-v1", disable_env_checker=True))
    obs, info = env.reset()
    assert env.observation_space.contains(obs)
    assert isinstance(info, dict)


@pytest.mark.parametrize("double_wrap", [False, True])
def test_time_limit_wrapper(double_wrap):
    """Episodes truncate after exactly `max_episode_length` steps.

    Pendulum never terminates on its own, so any episode end here is
    guaranteed to come from the TimeLimit wrapper.
    """
    max_episode_length = 20
    env = TimeLimit(PendulumEnv(), max_episode_length)
    if double_wrap:
        env = TimeLimit(env, max_episode_length)
    env.reset()

    n_steps = 0
    terminated = truncated = False
    while not (terminated or truncated):
        n_steps += 1
        _, _, terminated, truncated, _ = env.step(env.action_space.sample())

    assert n_steps == max_episode_length
    assert truncated


@pytest.mark.parametrize("double_wrap", [False, True])
def test_termination_on_last_step(double_wrap):
    """Termination and truncation can both be reported on the same step."""
    env = PendulumEnv()

    def patched_step(_action):
        # Always terminate immediately, never truncate on its own.
        return env.observation_space.sample(), 0.0, True, False, {}

    env.step = patched_step

    max_episode_length = 1
    env = TimeLimit(env, max_episode_length)
    if double_wrap:
        env = TimeLimit(env, max_episode_length)
    env.reset()
    _, _, terminated, truncated, _ = env.step(env.action_space.sample())
    assert terminated is True
    assert truncated is True


# --- tests/wrappers/test_transform_observation.py ---


@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
def test_transform_observation(env_id):
    """TransformObservation applies the function on reset and step outputs."""

    def affine_transform(x):
        return 3 * x + 2

    env = gym.make(env_id, disable_env_checker=True)
    wrapped_env = TransformObservation(
        gym.make(env_id, disable_env_checker=True), lambda obs: affine_transform(obs)
    )

    obs, info = env.reset(seed=0)
    wrapped_obs, wrapped_obs_info = wrapped_env.reset(seed=0)
    assert np.allclose(wrapped_obs, affine_transform(obs))
    assert isinstance(wrapped_obs_info, dict)

    action = env.action_space.sample()
    obs, reward, terminated, truncated, _ = env.step(action)
    (
        wrapped_obs,
        wrapped_reward,
        wrapped_terminated,
        wrapped_truncated,
        _,
    ) = wrapped_env.step(action)
    assert np.allclose(wrapped_obs, affine_transform(obs))
    assert np.allclose(wrapped_reward, reward)
    assert wrapped_terminated == terminated
    assert wrapped_truncated == truncated
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
def test_transform_reward(env_id):
    """TransformReward supports scaling, clipping and sign reward functions."""
    # Case 1: constant scaling.
    for scale in (0.1, 200):
        env = gym.make(env_id, disable_env_checker=True)
        wrapped_env = TransformReward(
            gym.make(env_id, disable_env_checker=True), lambda r: scale * r
        )
        action = env.action_space.sample()

        env.reset(seed=0)
        wrapped_env.reset(seed=0)

        _, reward, _, _, _ = env.step(action)
        _, wrapped_reward, _, _, _ = wrapped_env.step(action)
        assert wrapped_reward == scale * reward
        del env, wrapped_env

    # Case 2: clipping into [min_r, max_r].
    min_r, max_r = -0.0005, 0.0002
    env = gym.make(env_id, disable_env_checker=True)
    wrapped_env = TransformReward(
        gym.make(env_id, disable_env_checker=True), lambda r: np.clip(r, min_r, max_r)
    )
    action = env.action_space.sample()

    env.reset(seed=0)
    wrapped_env.reset(seed=0)

    _, reward, _, _, _ = env.step(action)
    _, wrapped_reward, _, _, _ = wrapped_env.step(action)
    assert abs(wrapped_reward) < abs(reward)
    assert wrapped_reward == -0.0005 or wrapped_reward == 0.0002
    del env, wrapped_env

    # Case 3: sign of the reward.
    env = gym.make(env_id, disable_env_checker=True)
    wrapped_env = TransformReward(
        gym.make(env_id, disable_env_checker=True), lambda r: np.sign(r)
    )

    env.reset(seed=0)
    wrapped_env.reset(seed=0)

    for _ in range(1000):
        action = env.action_space.sample()
        _, wrapped_reward, terminated, truncated, _ = wrapped_env.step(action)
        assert wrapped_reward in [-1.0, 0.0, 1.0]
        if terminated or truncated:
            break
    del env, wrapped_env


# --- tests/wrappers/test_vector_list_info.py ---


def test_usage_in_vector_env():
    """VectorListInfo accepts vector envs and rejects plain ones."""
    env = gym.make(ENV_ID, disable_env_checker=True)
    vector_env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True)

    VectorListInfo(vector_env)

    with pytest.raises(AssertionError):
        VectorListInfo(env)


def test_info_to_list():
    """The dict-of-arrays info is converted to one dict per sub-env."""
    wrapped_env = VectorListInfo(
        gym.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True)
    )
    wrapped_env.action_space.seed(SEED)
    _, info = wrapped_env.reset(seed=SEED)
    assert isinstance(info, list)
    assert len(info) == NUM_ENVS

    for _ in range(ENV_STEPS):
        _, _, terminateds, truncateds, list_info = wrapped_env.step(
            wrapped_env.action_space.sample()
        )
        for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
            # Only finished sub-envs carry a final observation entry.
            if terminated or truncated:
                assert "final_observation" in list_info[i]
            else:
                assert "final_observation" not in list_info[i]


def test_info_to_list_statistics():
    """Episode statistics survive the conversion to per-env info dicts."""
    wrapped_env = VectorListInfo(
        RecordEpisodeStatistics(
            gym.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True)
        )
    )
    _, info = wrapped_env.reset(seed=SEED)
    wrapped_env.action_space.seed(SEED)
    assert isinstance(info, list)
    assert len(info) == NUM_ENVS

    for _ in range(ENV_STEPS):
        _, _, terminateds, truncateds, list_info = wrapped_env.step(
            wrapped_env.action_space.sample()
        )
        for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
            if not (terminated or truncated):
                assert "episode" not in list_info[i]
                continue
            episode = list_info[i]["episode"]
            for stats in ("r", "l", "t"):
                assert stats in episode
                assert isinstance(episode[stats], float)


# --- tests/wrappers/test_video_recorder.py ---


class BrokenRecordableEnv(gym.Env):
    """Declares a recordable render mode but returns None from render()."""

    metadata = {"render_modes": ["rgb_array_list"]}

    def __init__(self, render_mode="rgb_array_list"):
        self.render_mode = render_mode

    def render(self):
        pass
class UnrecordableEnv(gym.Env):
    """Env with no video-capable render mode; recording must be disabled."""

    metadata = {"render_modes": [None]}

    def __init__(self, render_mode=None):
        self.render_mode = render_mode

    def render(self):
        pass


def test_record_simple():
    """A captured frame produces a non-trivial video file on disk."""
    env = gym.make(
        "CartPole-v1", render_mode="rgb_array_list", disable_env_checker=True
    )
    rec = VideoRecorder(env)
    env.reset()
    rec.capture_frame()

    rec.close()

    assert not rec.broken
    assert os.path.exists(rec.path)
    # Bug fix: the original `f = open(rec.path)` leaked the file handle
    # (never closed); os.stat reads the size without opening the file.
    assert os.stat(rec.path).st_size > 100


def test_autoclose():
    """The recorder finalizes its file when garbage-collected without close()."""

    def record():
        env = gym.make(
            "CartPole-v1", render_mode="rgb_array_list", disable_env_checker=True
        )
        rec = VideoRecorder(env)
        env.reset()
        rec.capture_frame()

        rec_path = rec.path

        # The function ends without an explicit `rec.close()` call;
        # garbage collection must trigger the recorder's cleanup.
        return rec_path

    rec_path = record()

    gc.collect()  # do explicit garbage collection for test
    time.sleep(5)  # wait for the encoder subprocess to exit

    assert os.path.exists(rec_path)
    # Bug fix: avoid the leaked `open(rec_path)` handle from the original.
    assert os.stat(rec_path).st_size > 100


def test_no_frames():
    """Closing without frames leaves the recorder functional; no file written."""
    env = BrokenRecordableEnv()
    rec = VideoRecorder(env)
    rec.close()
    assert rec.functional
    assert not os.path.exists(rec.path)


def test_record_unrecordable_method():
    """Recording is disabled (with a warning) for incompatible render modes."""
    with pytest.warns(
        UserWarning,
        match=re.escape(
            "\x1b[33mWARN: Disabling video recorder because environment was not initialized with any compatible video mode between `rgb_array` and `rgb_array_list`\x1b[0m"
        ),
    ):
        env = UnrecordableEnv()
        rec = VideoRecorder(env)
        assert not rec.enabled
        rec.close()


def test_record_breaking_render_method():
    """A render() returning None marks the recorder broken; no file is written."""
    with pytest.warns(
        UserWarning,
        match=re.escape(
            "Env returned None on `render()`. Disabling further rendering for video recorder by marking as disabled:"
        ),
    ):
        env = BrokenRecordableEnv()
        rec = VideoRecorder(env)
        rec.capture_frame()
        rec.close()
        assert rec.broken
        assert not os.path.exists(rec.path)


def test_text_envs():
    """Text-based envs render to video through their rgb_array_list mode."""
    env = gym.make(
        "FrozenLake-v1", render_mode="rgb_array_list", disable_env_checker=True
    )
    video = VideoRecorder(env)
    try:
        env.reset()
        video.capture_frame()
        video.close()
    finally:
        os.remove(video.path)


# --- tests/wrappers/utils.py ---


def has_wrapper(wrapped_env: gym.Env, wrapper_type: type) -> bool:
    """Return True if `wrapper_type` appears anywhere in the wrapper stack."""
    while isinstance(wrapped_env, gym.Wrapper):
        if isinstance(wrapped_env, wrapper_type):
            return True
        wrapped_env = wrapped_env.env
    return False