├── .github
├── ISSUE_TEMPLATE
│ ├── bug.md
│ ├── proposal.md
│ └── question.md
├── PULL_REQUEST_TEMPLATE.md
├── stale.yml
└── workflows
│ ├── build.yml
│ └── pre-commit.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CODE_OF_CONDUCT.rst
├── CONTRIBUTING.md
├── LICENSE.md
├── README.md
├── bin
└── docker_entrypoint
├── gym
├── __init__.py
├── core.py
├── envs
│ ├── __init__.py
│ ├── box2d
│ │ ├── __init__.py
│ │ ├── bipedal_walker.py
│ │ ├── car_dynamics.py
│ │ ├── car_racing.py
│ │ └── lunar_lander.py
│ ├── classic_control
│ │ ├── __init__.py
│ │ ├── acrobot.py
│ │ ├── assets
│ │ │ └── clockwise.png
│ │ ├── cartpole.py
│ │ ├── continuous_mountain_car.py
│ │ ├── mountain_car.py
│ │ ├── pendulum.py
│ │ └── utils.py
│ ├── mujoco
│ │ ├── __init__.py
│ │ ├── ant.py
│ │ ├── ant_v3.py
│ │ ├── ant_v4.py
│ │ ├── assets
│ │ │ ├── ant.xml
│ │ │ ├── half_cheetah.xml
│ │ │ ├── hopper.xml
│ │ │ ├── humanoid.xml
│ │ │ ├── humanoidstandup.xml
│ │ │ ├── inverted_double_pendulum.xml
│ │ │ ├── inverted_pendulum.xml
│ │ │ ├── point.xml
│ │ │ ├── pusher.xml
│ │ │ ├── reacher.xml
│ │ │ ├── swimmer.xml
│ │ │ └── walker2d.xml
│ │ ├── half_cheetah.py
│ │ ├── half_cheetah_v3.py
│ │ ├── half_cheetah_v4.py
│ │ ├── hopper.py
│ │ ├── hopper_v3.py
│ │ ├── hopper_v4.py
│ │ ├── humanoid.py
│ │ ├── humanoid_v3.py
│ │ ├── humanoid_v4.py
│ │ ├── humanoidstandup.py
│ │ ├── humanoidstandup_v4.py
│ │ ├── inverted_double_pendulum.py
│ │ ├── inverted_double_pendulum_v4.py
│ │ ├── inverted_pendulum.py
│ │ ├── inverted_pendulum_v4.py
│ │ ├── mujoco_env.py
│ │ ├── mujoco_rendering.py
│ │ ├── pusher.py
│ │ ├── pusher_v4.py
│ │ ├── reacher.py
│ │ ├── reacher_v4.py
│ │ ├── swimmer.py
│ │ ├── swimmer_v3.py
│ │ ├── swimmer_v4.py
│ │ ├── walker2d.py
│ │ ├── walker2d_v3.py
│ │ └── walker2d_v4.py
│ ├── registration.py
│ └── toy_text
│ │ ├── __init__.py
│ │ ├── blackjack.py
│ │ ├── cliffwalking.py
│ │ ├── font
│ │ └── Minecraft.ttf
│ │ ├── frozen_lake.py
│ │ ├── img
│ │ ├── C2.png
│ │ ├── C3.png
│ │ ├── C4.png
│ │ ├── C5.png
│ │ ├── C6.png
│ │ ├── C7.png
│ │ ├── C8.png
│ │ ├── C9.png
│ │ ├── CA.png
│ │ ├── CJ.png
│ │ ├── CK.png
│ │ ├── CQ.png
│ │ ├── CT.png
│ │ ├── Card.png
│ │ ├── D2.png
│ │ ├── D3.png
│ │ ├── D4.png
│ │ ├── D5.png
│ │ ├── D6.png
│ │ ├── D7.png
│ │ ├── D8.png
│ │ ├── D9.png
│ │ ├── DA.png
│ │ ├── DJ.png
│ │ ├── DK.png
│ │ ├── DQ.png
│ │ ├── DT.png
│ │ ├── H2.png
│ │ ├── H3.png
│ │ ├── H4.png
│ │ ├── H5.png
│ │ ├── H6.png
│ │ ├── H7.png
│ │ ├── H8.png
│ │ ├── H9.png
│ │ ├── HA.png
│ │ ├── HJ.png
│ │ ├── HK.png
│ │ ├── HQ.png
│ │ ├── HT.png
│ │ ├── S2.png
│ │ ├── S3.png
│ │ ├── S4.png
│ │ ├── S5.png
│ │ ├── S6.png
│ │ ├── S7.png
│ │ ├── S8.png
│ │ ├── S9.png
│ │ ├── SA.png
│ │ ├── SJ.png
│ │ ├── SK.png
│ │ ├── SQ.png
│ │ ├── ST.png
│ │ ├── cab_front.png
│ │ ├── cab_left.png
│ │ ├── cab_rear.png
│ │ ├── cab_right.png
│ │ ├── cookie.png
│ │ ├── cracked_hole.png
│ │ ├── elf_down.png
│ │ ├── elf_left.png
│ │ ├── elf_right.png
│ │ ├── elf_up.png
│ │ ├── goal.png
│ │ ├── gridworld_median_bottom.png
│ │ ├── gridworld_median_horiz.png
│ │ ├── gridworld_median_left.png
│ │ ├── gridworld_median_right.png
│ │ ├── gridworld_median_top.png
│ │ ├── gridworld_median_vert.png
│ │ ├── hole.png
│ │ ├── hotel.png
│ │ ├── ice.png
│ │ ├── mountain_bg1.png
│ │ ├── mountain_bg2.png
│ │ ├── mountain_cliff.png
│ │ ├── mountain_near-cliff1.png
│ │ ├── mountain_near-cliff2.png
│ │ ├── passenger.png
│ │ ├── stool.png
│ │ └── taxi_background.png
│ │ ├── taxi.py
│ │ └── utils.py
├── error.py
├── logger.py
├── py.typed
├── spaces
│ ├── __init__.py
│ ├── box.py
│ ├── dict.py
│ ├── discrete.py
│ ├── graph.py
│ ├── multi_binary.py
│ ├── multi_discrete.py
│ ├── sequence.py
│ ├── space.py
│ ├── text.py
│ ├── tuple.py
│ └── utils.py
├── utils
│ ├── __init__.py
│ ├── colorize.py
│ ├── env_checker.py
│ ├── ezpickle.py
│ ├── passive_env_checker.py
│ ├── play.py
│ ├── save_video.py
│ ├── seeding.py
│ └── step_api_compatibility.py
├── vector
│ ├── __init__.py
│ ├── async_vector_env.py
│ ├── sync_vector_env.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── misc.py
│ │ ├── numpy_utils.py
│ │ ├── shared_memory.py
│ │ └── spaces.py
│ └── vector_env.py
├── version.py
└── wrappers
│ ├── README.md
│ ├── __init__.py
│ ├── atari_preprocessing.py
│ ├── autoreset.py
│ ├── clip_action.py
│ ├── compatibility.py
│ ├── env_checker.py
│ ├── filter_observation.py
│ ├── flatten_observation.py
│ ├── frame_stack.py
│ ├── gray_scale_observation.py
│ ├── human_rendering.py
│ ├── monitoring
│ ├── __init__.py
│ └── video_recorder.py
│ ├── normalize.py
│ ├── order_enforcing.py
│ ├── pixel_observation.py
│ ├── record_episode_statistics.py
│ ├── record_video.py
│ ├── render_collection.py
│ ├── rescale_action.py
│ ├── resize_observation.py
│ ├── step_api_compatibility.py
│ ├── time_aware_observation.py
│ ├── time_limit.py
│ ├── transform_observation.py
│ ├── transform_reward.py
│ └── vector_list_info.py
├── py.Dockerfile
├── pyproject.toml
├── requirements.txt
├── setup.py
├── test_requirements.txt
└── tests
├── __init__.py
├── envs
├── __init__.py
├── test_action_dim_check.py
├── test_compatibility.py
├── test_env_implementation.py
├── test_envs.py
├── test_make.py
├── test_mujoco.py
├── test_register.py
├── test_spec.py
├── utils.py
└── utils_envs.py
├── spaces
├── __init__.py
├── test_box.py
├── test_dict.py
├── test_discrete.py
├── test_graph.py
├── test_multibinary.py
├── test_multidiscrete.py
├── test_sequence.py
├── test_space.py
├── test_spaces.py
├── test_text.py
├── test_tuple.py
├── test_utils.py
└── utils.py
├── test_core.py
├── testing_env.py
├── utils
├── __init__.py
├── test_env_checker.py
├── test_passive_env_checker.py
├── test_play.py
├── test_save_video.py
├── test_seeding.py
└── test_step_api_compatibility.py
├── vector
├── __init__.py
├── test_async_vector_env.py
├── test_numpy_utils.py
├── test_shared_memory.py
├── test_spaces.py
├── test_sync_vector_env.py
├── test_vector_env.py
├── test_vector_env_info.py
├── test_vector_env_wrapper.py
├── test_vector_make.py
└── utils.py
└── wrappers
├── __init__.py
├── test_atari_preprocessing.py
├── test_autoreset.py
├── test_clip_action.py
├── test_filter_observation.py
├── test_flatten.py
├── test_flatten_observation.py
├── test_frame_stack.py
├── test_gray_scale_observation.py
├── test_human_rendering.py
├── test_nested_dict.py
├── test_normalize.py
├── test_order_enforcing.py
├── test_passive_env_checker.py
├── test_pixel_observation.py
├── test_record_episode_statistics.py
├── test_record_video.py
├── test_rescale_action.py
├── test_resize_observation.py
├── test_step_compatibility.py
├── test_time_aware_observation.py
├── test_time_limit.py
├── test_transform_observation.py
├── test_transform_reward.py
├── test_vector_list_info.py
├── test_video_recorder.py
└── utils.py
/.github/ISSUE_TEMPLATE/bug.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug Report
3 | about: Submit a bug report
4 | title: "[Bug Report] Bug title"
5 |
6 | ---
7 |
8 | If you are submitting a bug report, please fill in the following details and use the tag [bug].
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **Code example**
14 | Please try to provide a minimal example to reproduce the bug. Error messages and stack traces are also helpful.
15 |
16 | **System Info**
17 | Describe the characteristic of your environment:
18 | * Describe how Gym was installed (pip, docker, source, ...)
19 | * What OS/version of Linux you're using. Note that while we will accept PRs to improve Window's support, we do not officially support it.
20 | * Python version
21 |
22 | **Additional context**
23 | Add any other context about the problem here.
24 |
25 | ### Checklist
26 |
27 | - [ ] I have checked that there is no similar [issue](https://github.com/openai/gym/issues) in the repo (**required**)
28 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/proposal.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Proposal
3 | about: Propose changes that are not bug fixes
4 | title: "[Proposal] Proposal title"
5 | ---
6 |
7 |
8 |
9 | ### Proposal
10 |
11 | A clear and concise description of the proposal.
12 |
13 | ### Motivation
14 |
15 | Please outline the motivation for the proposal.
16 | Is your feature request related to a problem? e.g., "I'm always frustrated when [...]".
17 | If this is related to another GitHub issue, please link here too.
18 |
19 | ### Pitch
20 |
21 | A clear and concise description of what you want to happen.
22 |
23 | ### Alternatives
24 |
25 | A clear and concise description of any alternative solutions or features you've considered, if any.
26 |
27 | ### Additional context
28 |
29 | Add any other context or screenshots about the feature request here.
30 |
31 | ### Checklist
32 |
33 | - [ ] I have checked that there is no similar [issue](https://github.com/openai/gym/issues) in the repo (**required**)
34 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/question.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Question
3 | about: Ask a question
4 | title: "[Question] Question title"
5 | ---
6 |
7 |
8 | ### Question
9 |
10 | If you're a beginner and have basic questions, please ask on [r/reinforcementlearning](https://www.reddit.com/r/reinforcementlearning/) or in the [RL Discord](https://discord.com/invite/xhfNqQv) (if you're new please use the beginners channel). Basic questions that are not bugs or feature requests will be closed without reply, because GitHub issues are not an appropriate venue for these.
11 |
12 | Advanced/nontrivial questions, especially in areas where documentation is lacking, are very much welcome.
13 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | # Description
2 |
3 | Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
4 |
5 | Fixes # (issue)
6 |
7 | ## Type of change
8 |
9 | Please delete options that are not relevant.
10 |
11 | - [ ] Bug fix (non-breaking change which fixes an issue)
12 | - [ ] New feature (non-breaking change which adds functionality)
13 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
14 | - [ ] This change requires a documentation update
15 |
16 | ### Screenshots
17 | Please attach before and after screenshots of the change if applicable.
18 |
19 |
29 |
30 | # Checklist:
31 |
32 | - [ ] I have run the [`pre-commit` checks](https://pre-commit.com/) with `pre-commit run --all-files` (see `CONTRIBUTING.md` instructions to set it up)
33 | - [ ] I have commented my code, particularly in hard-to-understand areas
34 | - [ ] I have made corresponding changes to the documentation
35 | - [ ] My changes generate no new warnings
36 | - [ ] I have added tests that prove my fix is effective or that my feature works
37 | - [ ] New and existing unit tests pass locally with my changes
38 |
39 |
46 |
--------------------------------------------------------------------------------
/.github/stale.yml:
--------------------------------------------------------------------------------
1 | # Configuration for probot-stale - https://github.com/probot/stale
2 |
3 | # Number of days of inactivity before an Issue or Pull Request becomes stale
4 | daysUntilStale: 60
5 |
6 | # Number of days of inactivity before an Issue or Pull Request with the stale label is closed.
7 | # Set to false to disable. If disabled, issues still need to be closed manually, but will remain marked as stale.
8 | daysUntilClose: 14
9 |
10 | # Only issues or pull requests with all of these labels are checked for staleness. Defaults to `[]` (disabled)
11 | onlyLabels:
12 | - more-information-needed
13 |
14 | # Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
15 | exemptLabels:
16 | - pinned
17 | - security
18 | - "[Status] Maybe Later"
19 |
20 | # Set to true to ignore issues in a project (defaults to false)
21 | exemptProjects: true
22 |
23 | # Set to true to ignore issues in a milestone (defaults to false)
24 | exemptMilestones: true
25 |
26 | # Set to true to ignore issues with an assignee (defaults to false)
27 | exemptAssignees: true
28 |
29 | # Label to use when marking as stale
30 | staleLabel: stale
31 |
32 | # Comment to post when marking as stale. Set to `false` to disable
33 | markComment: >
34 | This issue has been automatically marked as stale because it has not had
35 | recent activity. It will be closed if no further activity occurs. Thank you
36 | for your contributions.
37 |
38 | # Comment to post when removing the stale label.
39 | # unmarkComment: >
40 | # Your comment here.
41 |
42 | # Comment to post when closing a stale Issue or Pull Request.
43 | # closeComment: >
44 | # Your comment here.
45 |
46 | # Limit the number of actions per hour, from 1-30. Default is 30
47 | limitPerRun: 30
48 |
49 | # Limit to only `issues` or `pulls`
50 | only: issues
51 |
52 | # Optionally, specify configuration settings that are specific to just 'issues' or 'pulls':
53 | # pulls:
54 | # daysUntilStale: 30
55 | # markComment: >
56 | # This pull request has been automatically marked as stale because it has not had
57 | # recent activity. It will be closed if no further activity occurs. Thank you
58 | # for your contributions.
59 |
60 | # issues:
61 | # exemptLabels:
62 | # - confirmed
--------------------------------------------------------------------------------
/.github/workflows/build.yml:
--------------------------------------------------------------------------------
1 | name: build
2 | on: [pull_request, push]
3 |
4 | permissions:
5 | contents: read # to fetch code (actions/checkout)
6 |
7 | jobs:
8 | build:
9 | runs-on: ubuntu-latest
10 | strategy:
11 | matrix:
12 | python-version: ['3.6', '3.7', '3.8', '3.9', '3.10']
13 | steps:
14 | - uses: actions/checkout@v2
15 | - run: |
16 | docker build -f py.Dockerfile \
17 | --build-arg PYTHON_VERSION=${{ matrix.python-version }} \
18 | --tag gym-docker .
19 | - name: Run tests
20 | run: docker run gym-docker pytest
21 |
--------------------------------------------------------------------------------
/.github/workflows/pre-commit.yml:
--------------------------------------------------------------------------------
1 | # https://pre-commit.com
2 | # This GitHub Action assumes that the repo contains a valid .pre-commit-config.yaml file.
3 | name: pre-commit
4 | on:
5 | pull_request:
6 | push:
7 | branches: [master]
8 | permissions:
9 | contents: read # to fetch code (actions/checkout)
10 | jobs:
11 | pre-commit:
12 | runs-on: ubuntu-latest
13 | steps:
14 | - uses: actions/checkout@v2
15 | - uses: actions/setup-python@v2
16 | - run: pip install pre-commit
17 | - run: pre-commit --version
18 | - run: pre-commit install
19 | - run: pre-commit run --all-files
20 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | *.pyc
3 | *.py~
4 | .DS_Store
5 | .cache
6 | .pytest_cache/
7 |
8 | # Setuptools distribution and build folders.
9 | /dist/
10 | /build
11 |
12 | # Virtualenv
13 | /env
14 |
15 | # Python egg metadata, regenerated from source files by setuptools.
16 | /*.egg-info
17 |
18 | *.sublime-project
19 | *.sublime-workspace
20 |
21 | logs/
22 |
23 | .ipynb_checkpoints
24 | ghostdriver.log
25 |
26 | junk
27 | MUJOCO_LOG.txt
28 |
29 | rllab_mujoco
30 |
31 | tutorial/*.html
32 |
33 | # IDE files
34 | .eggs
35 | .tox
36 |
37 | # PyCharm project files
38 | .idea
39 | vizdoom.ini
40 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | repos:
3 | - repo: https://github.com/python/black
4 | rev: 22.3.0
5 | hooks:
6 | - id: black
7 | - repo: https://github.com/codespell-project/codespell
8 | rev: v2.1.0
9 | hooks:
10 | - id: codespell
11 | args:
12 | - --ignore-words-list=nd,reacher,thist,ths, ure, referenc
13 | - repo: https://gitlab.com/PyCQA/flake8
14 | rev: 4.0.1
15 | hooks:
16 | - id: flake8
17 | args:
18 | - '--per-file-ignores=*/__init__.py:F401 gym/envs/registration.py:E704'
19 | - --ignore=E203,W503,E741
20 | - --max-complexity=30
21 | - --max-line-length=456
22 | - --show-source
23 | - --statistics
24 | - repo: https://github.com/PyCQA/isort
25 | rev: 5.10.1
26 | hooks:
27 | - id: isort
28 | args: ["--profile", "black"]
29 | - repo: https://github.com/pycqa/pydocstyle
30 | rev: 6.1.1 # pick a git hash / tag to point to
31 | hooks:
32 | - id: pydocstyle
33 | exclude: ^(gym/version.py)|(gym/envs/)|(tests/)
34 | args:
35 | - --source
36 | - --explain
37 | - --convention=google
38 | additional_dependencies: ["toml"]
39 | - repo: https://github.com/asottile/pyupgrade
40 | rev: v2.32.0
41 | hooks:
42 | - id: pyupgrade
43 | # TODO: remove `--keep-runtime-typing` option
44 | args: ["--py36-plus", "--keep-runtime-typing"]
45 | - repo: local
46 | hooks:
47 | - id: pyright
48 | name: pyright
49 | entry: pyright
50 | language: node
51 | pass_filenames: false
52 | types: [python]
53 | additional_dependencies: ["pyright"]
54 | args:
55 | - --project=pyproject.toml
56 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.rst:
--------------------------------------------------------------------------------
1 | OpenAI Gym is dedicated to providing a harassment-free experience for
2 | everyone, regardless of gender, gender identity and expression, sexual
3 | orientation, disability, physical appearance, body size, age, race, or
4 | religion. We do not tolerate harassment of participants in any form.
5 |
6 | This code of conduct applies to all OpenAI Gym spaces (including Gist
7 | comments) both online and off. Anyone who violates this code of
8 | conduct may be sanctioned or expelled from these spaces at the
9 | discretion of the OpenAI team.
10 |
11 | We may add additional rules over time, which will be made clearly
12 | available to participants. Participants are responsible for knowing
13 | and abiding by these rules.
14 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Gym Contribution Guidelines
2 |
3 | At this time we are accepting the following forms of contributions:
4 |
5 | - Bug reports (keep in mind that changing environment behavior should be minimized as that requires releasing a new version of the environment and makes results hard to compare across versions)
6 | - Pull requests for bug fixes
7 | - Documentation improvements
8 |
9 | Notably, we are not accepting these forms of contributions:
10 |
11 | - New environments
12 | - New features
13 |
14 | This may change in the future.
15 | If you wish to make a Gym environment, follow the instructions in [Creating Environments](https://github.com/openai/gym/blob/master/docs/creating_environments.md). When your environment works, you can make a PR to add it to the bottom of the [List of Environments](https://github.com/openai/gym/blob/master/docs/third_party_environments.md).
16 |
17 |
18 | Edit July 27, 2021: Please see https://github.com/openai/gym/issues/2259 for new contributing standards
19 |
20 | # Development
21 | This section contains technical instructions & hints for the contributors.
22 |
23 | ## Type checking
24 | The project uses `pyright` to check types.
25 | To type check locally, install `pyright` per official [instructions](https://github.com/microsoft/pyright#command-line).
26 | Its configuration lives within `pyproject.toml`. It includes the list of included and excluded files currently supporting type checks.
27 | To run `pyright` for the project, run the pre-commit process (`pre-commit run --all-files`) or `pyright --project=pyproject.toml`
28 | Alternatively, pyright is a built-in feature of VSCode that will automatically provide type hinting.
29 |
30 | ### Adding typing to more modules and packages
31 | If you would like to add typing to a module in the project,
32 | the list of included, excluded and strict files can be found in pyproject.toml (pyproject.toml -> [tool.pyright]).
33 | To run `pyright` for the project, run the pre-commit process (`pre-commit run --all-files`) or `pyright`
34 |
35 | ## Git hooks
36 | The CI will run several checks on the new code pushed to the Gym repository. These checks can also be run locally without waiting for the CI by following the steps below:
37 | 1. [install `pre-commit`](https://pre-commit.com/#install),
38 | 2. Install the Git hooks by running `pre-commit install`.
39 |
40 | Once those two steps are done, the Git hooks will be run automatically at every new commit.
41 | The Git hooks can also be run manually with `pre-commit run --all-files`, and if needed they can be skipped (not recommended) with `git commit --no-verify`.
42 | **Note:** you may have to run `pre-commit run --all-files` manually a couple of times to make it pass when you commit, as each formatting tool will first format the code and fail the first time but should pass the second time.
43 |
44 | Additionally, for pull requests, the project runs a number of tests for the whole project using [pytest](https://docs.pytest.org/en/latest/getting-started.html#install-pytest).
45 | These tests can be run locally with `pytest` in the root folder.
46 |
47 | ## Docstrings
48 | Pydocstyle has been added to the pre-commit process such that all new functions follow the [google docstring style](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html).
49 | All new functions require either a short docstring, a single line explaining the purpose of a function
50 | or a multiline docstring that documents each argument and the return type (if there is one) of the function.
51 | In addition, new files and classes require top docstrings that should outline the purpose of the file/class.
52 | For classes, code block examples can be provided in the top docstring and not the constructor arguments.
53 |
54 | To check your docstrings are correct, run `pre-commit run --all-files` or `pydocstyle --source --explain --convention=google`.
55 | For all docstrings that fail, the source and reason for the failure are provided.
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | The MIT License
2 |
3 | Copyright (c) 2016 OpenAI (https://openai.com)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in
13 | all copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 | THE SOFTWARE.
22 |
23 | # Mujoco models
24 | This work is derived from [MuJoCo models](http://www.mujoco.org/forum/index.php?resources/) used under the following license:
25 | ```
26 | This file is part of MuJoCo.
27 | Copyright 2009-2015 Roboti LLC.
28 | Mujoco :: Advanced physics simulation engine
29 | Source : www.roboti.us
30 | Version : 1.31
31 | Released : 23Apr16
32 | Author :: Vikash Kumar
33 | Contacts : kumar@roboti.us
34 | ```
35 |
--------------------------------------------------------------------------------
/bin/docker_entrypoint:
--------------------------------------------------------------------------------
#!/bin/bash
# This script is the entrypoint for our Docker image.

# -e: abort on the first failing command; -x: echo commands for build-log debugging.
set -ex

# Set up display; otherwise rendering will fail
Xvfb -screen 0 1024x768x24 &
export DISPLAY=:0

# Wait for the file to come up
# Xvfb creates this UNIX socket once the virtual display is ready; poll for it
# with a linearly increasing backoff (1s, 2s, ... up to 10 tries).
display=0
file="/tmp/.X11-unix/X$display"
for i in $(seq 1 10); do
    if [ -e "$file" ]; then
        break
    fi

    echo "Waiting for $file to be created (try $i/10)"
    sleep "$i"
done
# Fail fast if the display never appeared instead of running the payload headless.
if ! [ -e "$file" ]; then
    echo "Timing out: $file was not created"
    exit 1
fi

# Replace this shell with the container's command so signals are delivered directly.
exec "$@"
27 |
--------------------------------------------------------------------------------
/gym/__init__.py:
--------------------------------------------------------------------------------
1 | """Root __init__ of the gym module setting the __all__ of gym modules."""
2 | # isort: skip_file
3 |
4 | from gym import error
5 | from gym.version import VERSION as __version__
6 |
7 | from gym.core import (
8 | Env,
9 | Wrapper,
10 | ObservationWrapper,
11 | ActionWrapper,
12 | RewardWrapper,
13 | )
14 | from gym.spaces import Space
15 | from gym.envs import make, spec, register
16 | from gym import logger
17 | from gym import vector
18 | from gym import wrappers
19 | import os
20 | import sys
21 |
22 | __all__ = ["Env", "Space", "Wrapper", "make", "spec", "register"]
23 |
# Initializing pygame initializes audio connections through SDL. SDL uses alsa by default on all Linux systems
# SDL connecting to alsa frequently create these giant lists of warnings every time you import an environment using
# pygame
# DSP is far more benign (and should probably be the default in SDL anyways)

if sys.platform.startswith("linux"):
    os.environ["SDL_AUDIODRIVER"] = "dsp"

# Suppress pygame's "Hello from the pygame community" banner on import.
os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "hide"

# Best-effort: if the optional gym-notices package is installed, print any
# notice registered for this exact gym version (e.g. deprecation warnings).
try:
    import gym_notices.notices as notices

    # print version warning if necessary
    notice = notices.notices.get(__version__)
    if notice:
        print(notice, file=sys.stderr)

# Deliberately swallow all errors: a missing/broken notices package must
# never prevent `import gym` from succeeding.
except Exception:  # nosec
    pass
44 |
--------------------------------------------------------------------------------
/gym/envs/box2d/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.box2d.bipedal_walker import BipedalWalker, BipedalWalkerHardcore
2 | from gym.envs.box2d.car_racing import CarRacing
3 | from gym.envs.box2d.lunar_lander import LunarLander, LunarLanderContinuous
4 |
--------------------------------------------------------------------------------
/gym/envs/classic_control/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.classic_control.acrobot import AcrobotEnv
2 | from gym.envs.classic_control.cartpole import CartPoleEnv
3 | from gym.envs.classic_control.continuous_mountain_car import Continuous_MountainCarEnv
4 | from gym.envs.classic_control.mountain_car import MountainCarEnv
5 | from gym.envs.classic_control.pendulum import PendulumEnv
6 |
--------------------------------------------------------------------------------
/gym/envs/classic_control/assets/clockwise.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/classic_control/assets/clockwise.png
--------------------------------------------------------------------------------
/gym/envs/classic_control/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions used for classic control environments.
3 | """
4 |
5 | from typing import Optional, SupportsFloat, Tuple
6 |
7 |
def verify_number_and_cast(x: SupportsFloat) -> float:
    """Verify parameter is a single number and cast to a float.

    Args:
        x: The value to validate and convert.

    Returns:
        ``x`` converted to a ``float``.

    Raises:
        ValueError: If ``x`` cannot be converted to a float (this also
            covers the ``TypeError`` raised for non-numeric types).
    """
    try:
        return float(x)
    except (ValueError, TypeError) as err:
        # Chain the original exception so the root cause stays visible in tracebacks.
        raise ValueError(f"An option ({x}) could not be converted to a float.") from err
15 |
16 |
def maybe_parse_reset_bounds(
    options: Optional[dict], default_low: float, default_high: float
) -> Tuple[float, float]:
    """Parse optional ``low``/``high`` reset bounds from a reset() options dict.

    This function can be called during a reset() to customize the sampling
    ranges for setting the initial state distributions.

    Args:
        options: Options passed in to reset(); may contain "low" and/or "high".
        default_low: Default lower limit to use, if none specified in options.
        default_high: Default upper limit to use, if none specified in options.

    Returns:
        Tuple of the lower and upper limits (cast to floats when options are given).

    Raises:
        ValueError: If a bound is not castable to float, or if ``low > high``.
    """
    if options is None:
        return default_low, default_high

    # dict.get with a default replaces the redundant `"low" in options` membership
    # check followed by a second lookup.
    low = options.get("low", default_low)
    high = options.get("high", default_high)

    # We expect only numerical inputs.
    low = verify_number_and_cast(low)
    high = verify_number_and_cast(high)
    if low > high:
        raise ValueError(
            f"Lower bound ({low}) must be lower than higher bound ({high})."
        )

    return low, high
47 |
--------------------------------------------------------------------------------
/gym/envs/mujoco/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.mujoco.mujoco_env import MujocoEnv, MuJocoPyEnv # isort:skip
2 |
3 | from gym.envs.mujoco.ant import AntEnv
4 | from gym.envs.mujoco.half_cheetah import HalfCheetahEnv
5 | from gym.envs.mujoco.hopper import HopperEnv
6 | from gym.envs.mujoco.humanoid import HumanoidEnv
7 | from gym.envs.mujoco.humanoidstandup import HumanoidStandupEnv
8 | from gym.envs.mujoco.inverted_double_pendulum import InvertedDoublePendulumEnv
9 | from gym.envs.mujoco.inverted_pendulum import InvertedPendulumEnv
10 | from gym.envs.mujoco.pusher import PusherEnv
11 | from gym.envs.mujoco.reacher import ReacherEnv
12 | from gym.envs.mujoco.swimmer import SwimmerEnv
13 | from gym.envs.mujoco.walker2d import Walker2dEnv
14 |
--------------------------------------------------------------------------------
/gym/envs/mujoco/ant.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from gym import utils
4 | from gym.envs.mujoco import MuJocoPyEnv
5 | from gym.spaces import Box
6 |
7 |
class AntEnv(MuJocoPyEnv, utils.EzPickle):
    """Ant locomotion environment using the legacy mujoco-py bindings.

    The agent is rewarded for moving the torso forward along the x-axis,
    with penalties for control effort and external contact forces, plus a
    constant survive bonus per step.
    """

    # Render metadata consumed by the MuJoCo base environment.
    metadata = {
        "render_modes": [
            "human",
            "rgb_array",
            "depth_array",
        ],
        "render_fps": 20,
    }

    def __init__(self, **kwargs):
        """Initialize from ``ant.xml`` with a frame skip of 5."""
        # 111-dim observation: qpos[2:] + qvel + clipped external contact
        # forces (see _get_obs below).
        observation_space = Box(
            low=-np.inf, high=np.inf, shape=(111,), dtype=np.float64
        )
        MuJocoPyEnv.__init__(
            self, "ant.xml", 5, observation_space=observation_space, **kwargs
        )
        utils.EzPickle.__init__(self, **kwargs)

    def step(self, a):
        """Advance the simulation with action ``a``.

        Returns ``(observation, reward, terminated, truncated, info)``;
        ``truncated`` is always False here (time limits are applied by wrappers).
        """
        # Forward progress is measured as torso x-displacement over this step.
        xposbefore = self.get_body_com("torso")[0]
        self.do_simulation(a, self.frame_skip)
        xposafter = self.get_body_com("torso")[0]

        forward_reward = (xposafter - xposbefore) / self.dt
        # Quadratic penalty on the action magnitude.
        ctrl_cost = 0.5 * np.square(a).sum()
        # Penalty on external contact forces, clipped to [-1, 1] per component.
        contact_cost = (
            0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
        )
        survive_reward = 1.0
        reward = forward_reward - ctrl_cost - contact_cost + survive_reward
        state = self.state_vector()
        # Healthy iff the full state is finite and the torso height (z) stays
        # within [0.2, 1.0]; otherwise the episode terminates.
        not_terminated = (
            np.isfinite(state).all() and state[2] >= 0.2 and state[2] <= 1.0
        )
        terminated = not not_terminated
        ob = self._get_obs()

        if self.render_mode == "human":
            self.render()
        return (
            ob,
            reward,
            terminated,
            False,
            dict(
                reward_forward=forward_reward,
                reward_ctrl=-ctrl_cost,
                reward_contact=-contact_cost,
                reward_survive=survive_reward,
            ),
        )

    def _get_obs(self):
        """Build the 111-dim observation (root x/y position excluded)."""
        return np.concatenate(
            [
                self.sim.data.qpos.flat[2:],
                self.sim.data.qvel.flat,
                np.clip(self.sim.data.cfrc_ext, -1, 1).flat,
            ]
        )

    def reset_model(self):
        """Reset to the initial pose with small uniform/Gaussian noise."""
        qpos = self.init_qpos + self.np_random.uniform(
            size=self.model.nq, low=-0.1, high=0.1
        )
        qvel = self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1
        self.set_state(qpos, qvel)
        return self._get_obs()

    def viewer_setup(self):
        """Zoom the camera relative to the model's spatial extent."""
        assert self.viewer is not None
        self.viewer.cam.distance = self.model.stat.extent * 0.5
81 |
--------------------------------------------------------------------------------
/gym/envs/mujoco/assets/hopper.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
--------------------------------------------------------------------------------
/gym/envs/mujoco/assets/inverted_double_pendulum.xml:
--------------------------------------------------------------------------------
1 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
--------------------------------------------------------------------------------
/gym/envs/mujoco/assets/inverted_pendulum.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
--------------------------------------------------------------------------------
/gym/envs/mujoco/assets/point.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
--------------------------------------------------------------------------------
/gym/envs/mujoco/assets/reacher.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
--------------------------------------------------------------------------------
/gym/envs/mujoco/assets/swimmer.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
--------------------------------------------------------------------------------
/gym/envs/mujoco/half_cheetah.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from gym import utils
4 | from gym.envs.mujoco import MuJocoPyEnv
5 | from gym.spaces import Box
6 |
7 |
class HalfCheetahEnv(MuJocoPyEnv, utils.EzPickle):
    """mujoco-py half-cheetah locomotion task rewarding forward velocity."""

    metadata = {
        "render_modes": ["human", "rgb_array", "depth_array"],
        "render_fps": 20,
    }

    def __init__(self, **kwargs):
        # 17-dim observation: qpos[1:] plus all joint velocities.
        obs_space = Box(low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64)
        MuJocoPyEnv.__init__(
            self, "half_cheetah.xml", 5, observation_space=obs_space, **kwargs
        )
        utils.EzPickle.__init__(self, **kwargs)

    def step(self, action):
        """Advance physics; reward = forward velocity minus control cost."""
        x_before = self.sim.data.qpos[0]
        self.do_simulation(action, self.frame_skip)
        x_after = self.sim.data.qpos[0]

        ob = self._get_obs()
        reward_ctrl = -0.1 * np.square(action).sum()
        reward_run = (x_after - x_before) / self.dt
        reward = reward_ctrl + reward_run

        if self.render_mode == "human":
            self.render()
        # The cheetah never terminates on its own; time limits come from wrappers.
        info = dict(reward_run=reward_run, reward_ctrl=reward_ctrl)
        return ob, reward, False, False, info

    def _get_obs(self):
        data = self.sim.data
        # Root x-position (qpos[0]) is excluded from the observation.
        return np.concatenate([data.qpos.flat[1:], data.qvel.flat])

    def reset_model(self):
        """Reset with small uniform position noise and Gaussian velocity noise."""
        noise_pos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq)
        noise_vel = self.np_random.standard_normal(self.model.nv) * 0.1
        self.set_state(self.init_qpos + noise_pos, self.init_qvel + noise_vel)
        return self._get_obs()

    def viewer_setup(self):
        assert self.viewer is not None
        self.viewer.cam.distance = self.model.stat.extent * 0.5
65 |
--------------------------------------------------------------------------------
/gym/envs/mujoco/hopper.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from gym import utils
4 | from gym.envs.mujoco import MuJocoPyEnv
5 | from gym.spaces import Box
6 |
7 |
class HopperEnv(MuJocoPyEnv, utils.EzPickle):
    """mujoco-py one-legged hopper; rewarded for hopping forward while staying upright."""

    metadata = {
        "render_modes": ["human", "rgb_array", "depth_array"],
        "render_fps": 125,
    }

    def __init__(self, **kwargs):
        obs_space = Box(low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64)
        MuJocoPyEnv.__init__(
            self, "hopper.xml", 4, observation_space=obs_space, **kwargs
        )
        utils.EzPickle.__init__(self, **kwargs)

    def step(self, a):
        """Advance physics; reward = forward velocity + alive bonus - control cost."""
        pos_before = self.sim.data.qpos[0]
        self.do_simulation(a, self.frame_skip)
        pos_after, height, ang = self.sim.data.qpos[0:3]

        alive_bonus = 1.0
        reward = (pos_after - pos_before) / self.dt + alive_bonus
        reward -= 1e-3 * np.square(a).sum()

        # Healthy while state values are finite and bounded, the torso is high
        # enough, and the torso angle stays near vertical.
        s = self.state_vector()
        healthy = (
            np.isfinite(s).all()
            and (np.abs(s[2:]) < 100).all()
            and height > 0.7
            and abs(ang) < 0.2
        )
        ob = self._get_obs()

        if self.render_mode == "human":
            self.render()
        return ob, reward, not healthy, False, {}

    def _get_obs(self):
        data = self.sim.data
        # Exclude root x; clip velocities to [-10, 10].
        return np.concatenate([data.qpos.flat[1:], np.clip(data.qvel.flat, -10, 10)])

    def reset_model(self):
        """Reset with tiny uniform noise on both positions and velocities."""
        noise_pos = self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nq)
        noise_vel = self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
        self.set_state(self.init_qpos + noise_pos, self.init_qvel + noise_vel)
        return self._get_obs()

    def viewer_setup(self):
        assert self.viewer is not None
        cam = self.viewer.cam
        cam.trackbodyid = 2
        cam.distance = self.model.stat.extent * 0.75
        cam.lookat[2] = 1.15
        cam.elevation = -20
68 |
--------------------------------------------------------------------------------
/gym/envs/mujoco/humanoid.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from gym import utils
4 | from gym.envs.mujoco import MuJocoPyEnv
5 | from gym.spaces import Box
6 |
7 |
def mass_center(model, sim):
    """Return the x-coordinate of the model's mass-weighted center of mass."""
    masses = np.expand_dims(model.body_mass, 1)
    com = np.sum(masses * sim.data.xipos, 0) / np.sum(masses)
    return com[0]
12 |
13 |
class HumanoidEnv(MuJocoPyEnv, utils.EzPickle):
    """mujoco-py humanoid locomotion task: walk forward without falling."""

    metadata = {
        "render_modes": ["human", "rgb_array", "depth_array"],
        "render_fps": 67,
    }

    def __init__(self, **kwargs):
        # 376-dim observation built from qpos[2:], qvel, body inertias, body
        # velocities, actuator forces and external contact forces.
        obs_space = Box(low=-np.inf, high=np.inf, shape=(376,), dtype=np.float64)
        MuJocoPyEnv.__init__(
            self, "humanoid.xml", 5, observation_space=obs_space, **kwargs
        )
        utils.EzPickle.__init__(self, **kwargs)

    def _get_obs(self):
        data = self.sim.data
        parts = (
            data.qpos.flat[2:],
            data.qvel.flat,
            data.cinert.flat,
            data.cvel.flat,
            data.qfrc_actuator.flat,
            data.cfrc_ext.flat,
        )
        return np.concatenate(parts)

    def step(self, a):
        """Advance physics; reward forward CoM velocity, penalize control and impacts."""
        pos_before = mass_center(self.model, self.sim)
        self.do_simulation(a, self.frame_skip)
        pos_after = mass_center(self.model, self.sim)

        data = self.sim.data
        alive_bonus = 5.0
        lin_vel_cost = 1.25 * (pos_after - pos_before) / self.dt
        quad_ctrl_cost = 0.1 * np.square(data.ctrl).sum()
        # Impact cost is capped at 10 to keep contact spikes from dominating.
        quad_impact_cost = min(0.5e-6 * np.square(data.cfrc_ext).sum(), 10)
        reward = lin_vel_cost - quad_ctrl_cost - quad_impact_cost + alive_bonus

        # Terminate once the torso height leaves the healthy band [1.0, 2.0].
        torso_height = self.sim.data.qpos[2]
        terminated = bool(torso_height < 1.0 or torso_height > 2.0)

        if self.render_mode == "human":
            self.render()
        info = dict(
            reward_linvel=lin_vel_cost,
            reward_quadctrl=-quad_ctrl_cost,
            reward_alive=alive_bonus,
            reward_impact=-quad_impact_cost,
        )
        return self._get_obs(), reward, terminated, False, info

    def reset_model(self):
        """Reset near the standing pose with tiny uniform noise."""
        c = 0.01
        qpos = self.init_qpos + self.np_random.uniform(
            low=-c, high=c, size=self.model.nq
        )
        qvel = self.init_qvel + self.np_random.uniform(
            low=-c, high=c, size=self.model.nv
        )
        self.set_state(qpos, qvel)
        return self._get_obs()

    def viewer_setup(self):
        assert self.viewer is not None
        cam = self.viewer.cam
        cam.trackbodyid = 1
        cam.distance = self.model.stat.extent * 1.0
        cam.lookat[2] = 2.0
        cam.elevation = -20
95 |
--------------------------------------------------------------------------------
/gym/envs/mujoco/humanoidstandup.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from gym import utils
4 | from gym.envs.mujoco import MuJocoPyEnv
5 | from gym.spaces import Box
6 |
7 |
class HumanoidStandupEnv(MuJocoPyEnv, utils.EzPickle):
    """mujoco-py humanoid stand-up task: rewarded for gaining torso height."""

    metadata = {
        "render_modes": ["human", "rgb_array", "depth_array"],
        "render_fps": 67,
    }

    def __init__(self, **kwargs):
        # Same 376-dim observation layout as the walking humanoid.
        obs_space = Box(low=-np.inf, high=np.inf, shape=(376,), dtype=np.float64)
        MuJocoPyEnv.__init__(
            self, "humanoidstandup.xml", 5, observation_space=obs_space, **kwargs
        )
        utils.EzPickle.__init__(self, **kwargs)

    def _get_obs(self):
        data = self.sim.data
        parts = (
            data.qpos.flat[2:],
            data.qvel.flat,
            data.cinert.flat,
            data.cvel.flat,
            data.qfrc_actuator.flat,
            data.cfrc_ext.flat,
        )
        return np.concatenate(parts)

    def step(self, a):
        """Advance physics; reward torso height, penalize control and impacts."""
        self.do_simulation(a, self.frame_skip)
        data = self.sim.data
        # NOTE: divides by the raw model timestep, not self.dt (= frame_skip * timestep).
        uph_cost = (data.qpos[2] - 0) / self.model.opt.timestep

        quad_ctrl_cost = 0.1 * np.square(data.ctrl).sum()
        # Impact cost is capped at 10 to keep contact spikes from dominating.
        quad_impact_cost = min(0.5e-6 * np.square(data.cfrc_ext).sum(), 10)
        reward = uph_cost - quad_ctrl_cost - quad_impact_cost + 1

        if self.render_mode == "human":
            self.render()
        info = dict(
            reward_linup=uph_cost,
            reward_quadctrl=-quad_ctrl_cost,
            reward_impact=-quad_impact_cost,
        )
        # This task never terminates on its own.
        return self._get_obs(), reward, False, False, info

    def reset_model(self):
        """Reset near the lying pose with tiny uniform noise."""
        c = 0.01
        qpos = self.init_qpos + self.np_random.uniform(
            low=-c, high=c, size=self.model.nq
        )
        qvel = self.init_qvel + self.np_random.uniform(
            low=-c, high=c, size=self.model.nv
        )
        self.set_state(qpos, qvel)
        return self._get_obs()

    def viewer_setup(self):
        assert self.viewer is not None
        cam = self.viewer.cam
        cam.trackbodyid = 1
        cam.distance = self.model.stat.extent * 1.0
        cam.lookat[2] = 0.8925
        cam.elevation = -20
88 |
--------------------------------------------------------------------------------
/gym/envs/mujoco/inverted_double_pendulum.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from gym import utils
4 | from gym.envs.mujoco import MuJocoPyEnv
5 | from gym.spaces import Box
6 |
7 |
class InvertedDoublePendulumEnv(MuJocoPyEnv, utils.EzPickle):
    """mujoco-py cart with a two-link pole; keep the pole tip high and still."""

    metadata = {
        "render_modes": ["human", "rgb_array", "depth_array"],
        "render_fps": 20,
    }

    def __init__(self, **kwargs):
        obs_space = Box(low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64)
        MuJocoPyEnv.__init__(
            self,
            "inverted_double_pendulum.xml",
            5,
            observation_space=obs_space,
            **kwargs
        )
        utils.EzPickle.__init__(self, **kwargs)

    def step(self, action):
        """Advance physics; reward = alive bonus - tip-distance penalty - velocity penalty."""
        self.do_simulation(action, self.frame_skip)

        ob = self._get_obs()
        x, _, y = self.sim.data.site_xpos[0]  # pole-tip site position
        dist_penalty = 0.01 * x**2 + (y - 2) ** 2
        v1, v2 = self.sim.data.qvel[1:3]
        vel_penalty = 1e-3 * v1**2 + 5e-3 * v2**2
        alive_bonus = 10
        reward = alive_bonus - dist_penalty - vel_penalty
        # Episode ends once the pole tip drops to or below height 1.
        terminated = bool(y <= 1)

        if self.render_mode == "human":
            self.render()
        return ob, reward, terminated, False, {}

    def _get_obs(self):
        data = self.sim.data
        return np.concatenate(
            [
                data.qpos[:1],  # cart x pos
                np.sin(data.qpos[1:]),  # link angles
                np.cos(data.qpos[1:]),
                np.clip(data.qvel, -10, 10),
                np.clip(data.qfrc_constraint, -10, 10),
            ]
        ).ravel()

    def reset_model(self):
        """Reset with uniform position noise and Gaussian velocity noise."""
        qpos = self.init_qpos + self.np_random.uniform(
            low=-0.1, high=0.1, size=self.model.nq
        )
        qvel = self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1
        self.set_state(qpos, qvel)
        return self._get_obs()

    def viewer_setup(self):
        assert self.viewer is not None
        cam = self.viewer.cam
        cam.trackbodyid = 0
        cam.distance = self.model.stat.extent * 0.5
        cam.lookat[2] = 0.12250000000000005  # v.model.stat.center[2]
70 |
--------------------------------------------------------------------------------
/gym/envs/mujoco/inverted_pendulum.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from gym import utils
4 | from gym.envs.mujoco import MuJocoPyEnv
5 | from gym.spaces import Box
6 |
7 |
class InvertedPendulumEnv(MuJocoPyEnv, utils.EzPickle):
    """mujoco-py cart-pole balance task: +1 reward per step while the pole stays up."""

    metadata = {
        "render_modes": ["human", "rgb_array", "depth_array"],
        "render_fps": 25,
    }

    def __init__(self, **kwargs):
        utils.EzPickle.__init__(self, **kwargs)
        obs_space = Box(low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64)
        MuJocoPyEnv.__init__(
            self, "inverted_pendulum.xml", 2, observation_space=obs_space, **kwargs
        )

    def step(self, a):
        """Advance physics; terminate when the pole angle leaves [-0.2, 0.2] rad."""
        self.do_simulation(a, self.frame_skip)
        ob = self._get_obs()
        unhealthy = not np.isfinite(ob).all() or np.abs(ob[1]) > 0.2

        if self.render_mode == "human":
            self.render()
        return ob, 1.0, bool(unhealthy), False, {}

    def reset_model(self):
        """Reset with tiny uniform noise on positions and velocities."""
        noise_pos = self.np_random.uniform(size=self.model.nq, low=-0.01, high=0.01)
        noise_vel = self.np_random.uniform(size=self.model.nv, low=-0.01, high=0.01)
        self.set_state(self.init_qpos + noise_pos, self.init_qvel + noise_vel)
        return self._get_obs()

    def _get_obs(self):
        # [cart x, pole angle, cart velocity, pole angular velocity]
        return np.concatenate([self.sim.data.qpos, self.sim.data.qvel]).ravel()

    def viewer_setup(self):
        assert self.viewer is not None
        self.viewer.cam.trackbodyid = 0
        self.viewer.cam.distance = self.model.stat.extent
57 |
--------------------------------------------------------------------------------
/gym/envs/mujoco/pusher.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from gym import utils
4 | from gym.envs.mujoco import MuJocoPyEnv
5 | from gym.spaces import Box
6 |
7 |
class PusherEnv(MuJocoPyEnv, utils.EzPickle):
    """mujoco-py pusher arm: push a cylinder ("object") toward a fixed goal.

    Reward combines the object-to-goal distance, the fingertips-to-object
    distance, and a control (action magnitude) penalty.
    """

    metadata = {
        "render_modes": [
            "human",
            "rgb_array",
            "depth_array",
        ],
        "render_fps": 20,
    }

    def __init__(self, **kwargs):
        utils.EzPickle.__init__(self, **kwargs)
        # 23 dims: 7 joint positions, 7 joint velocities, 3x3 body positions.
        observation_space = Box(low=-np.inf, high=np.inf, shape=(23,), dtype=np.float64)
        MuJocoPyEnv.__init__(
            self, "pusher.xml", 5, observation_space=observation_space, **kwargs
        )

    def step(self, a):
        """Advance the simulation one control step.

        NOTE: rewards are intentionally computed from the state *before* the
        physics step, matching the historical Pusher-v2 behavior.
        """
        vec_1 = self.get_body_com("object") - self.get_body_com("tips_arm")
        vec_2 = self.get_body_com("object") - self.get_body_com("goal")

        reward_near = -np.linalg.norm(vec_1)
        reward_dist = -np.linalg.norm(vec_2)
        reward_ctrl = -np.square(a).sum()
        reward = reward_dist + 0.1 * reward_ctrl + 0.5 * reward_near

        self.do_simulation(a, self.frame_skip)
        if self.render_mode == "human":
            self.render()

        ob = self._get_obs()
        return (
            ob,
            reward,
            False,
            False,
            dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl),
        )

    def viewer_setup(self):
        assert self.viewer is not None
        self.viewer.cam.trackbodyid = -1
        self.viewer.cam.distance = 4.0

    def reset_model(self):
        """Reset the arm to its default pose and randomize the cylinder position.

        Returns the initial observation.
        """
        # Copy so the slice assignments below do not mutate self.init_qpos in
        # place (the original aliased it, silently rewriting the stored
        # initial state on every reset).
        qpos = self.init_qpos.copy()

        self.goal_pos = np.asarray([0, 0])
        # Rejection-sample a cylinder position far enough from the goal.
        while True:
            self.cylinder_pos = np.concatenate(
                [
                    self.np_random.uniform(low=-0.3, high=0, size=1),
                    self.np_random.uniform(low=-0.2, high=0.2, size=1),
                ]
            )
            if np.linalg.norm(self.cylinder_pos - self.goal_pos) > 0.17:
                break

        qpos[-4:-2] = self.cylinder_pos
        qpos[-2:] = self.goal_pos
        qvel = self.init_qvel + self.np_random.uniform(
            low=-0.005, high=0.005, size=self.model.nv
        )
        qvel[-4:] = 0  # cylinder and goal start at rest
        self.set_state(qpos, qvel)
        return self._get_obs()

    def _get_obs(self):
        return np.concatenate(
            [
                self.sim.data.qpos.flat[:7],
                self.sim.data.qvel.flat[:7],
                self.get_body_com("tips_arm"),
                self.get_body_com("object"),
                self.get_body_com("goal"),
            ]
        )
85 |
--------------------------------------------------------------------------------
/gym/envs/mujoco/reacher.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from gym import utils
4 | from gym.envs.mujoco import MuJocoPyEnv
5 | from gym.spaces import Box
6 |
7 |
class ReacherEnv(MuJocoPyEnv, utils.EzPickle):
    """mujoco-py two-link planar arm; move the fingertip onto a random target."""

    metadata = {
        "render_modes": ["human", "rgb_array", "depth_array"],
        "render_fps": 50,
    }

    def __init__(self, **kwargs):
        utils.EzPickle.__init__(self, **kwargs)
        obs_space = Box(low=-np.inf, high=np.inf, shape=(11,), dtype=np.float64)
        MuJocoPyEnv.__init__(
            self, "reacher.xml", 2, observation_space=obs_space, **kwargs
        )

    def step(self, a):
        """Advance physics; reward = -(fingertip-to-target distance) - control cost.

        Rewards use the pre-step state (historical Reacher-v2 behavior).
        """
        vec = self.get_body_com("fingertip") - self.get_body_com("target")
        reward_dist = -np.linalg.norm(vec)
        reward_ctrl = -np.square(a).sum()
        reward = reward_dist + reward_ctrl

        self.do_simulation(a, self.frame_skip)
        if self.render_mode == "human":
            self.render()

        info = dict(reward_dist=reward_dist, reward_ctrl=reward_ctrl)
        return self._get_obs(), reward, False, False, info

    def viewer_setup(self):
        assert self.viewer is not None
        self.viewer.cam.trackbodyid = 0

    def reset_model(self):
        """Reset with joint noise and rejection-sample a reachable target."""
        qpos = (
            self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq)
            + self.init_qpos
        )
        # Keep drawing until the goal lies inside the arm's reachable disc.
        while True:
            self.goal = self.np_random.uniform(low=-0.2, high=0.2, size=2)
            if np.linalg.norm(self.goal) < 0.2:
                break
        qpos[-2:] = self.goal
        qvel = self.init_qvel + self.np_random.uniform(
            low=-0.005, high=0.005, size=self.model.nv
        )
        qvel[-2:] = 0  # target is static
        self.set_state(qpos, qvel)
        return self._get_obs()

    def _get_obs(self):
        data = self.sim.data
        theta = data.qpos.flat[:2]
        return np.concatenate(
            [
                np.cos(theta),
                np.sin(theta),
                data.qpos.flat[2:],
                data.qvel.flat[:2],
                self.get_body_com("fingertip") - self.get_body_com("target"),
            ]
        )
76 |
--------------------------------------------------------------------------------
/gym/envs/mujoco/swimmer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from gym import utils
4 | from gym.envs.mujoco import MuJocoPyEnv
5 | from gym.spaces import Box
6 |
7 |
class SwimmerEnv(MuJocoPyEnv, utils.EzPickle):
    """mujoco-py multi-link swimmer rewarded for forward progress."""

    metadata = {
        "render_modes": ["human", "rgb_array", "depth_array"],
        "render_fps": 25,
    }

    def __init__(self, **kwargs):
        obs_space = Box(low=-np.inf, high=np.inf, shape=(8,), dtype=np.float64)
        MuJocoPyEnv.__init__(
            self, "swimmer.xml", 4, observation_space=obs_space, **kwargs
        )
        utils.EzPickle.__init__(self, **kwargs)

    def step(self, a):
        """Advance physics; reward = forward velocity minus a small control cost."""
        ctrl_cost_coeff = 0.0001
        x_before = self.sim.data.qpos[0]
        self.do_simulation(a, self.frame_skip)
        x_after = self.sim.data.qpos[0]

        reward_fwd = (x_after - x_before) / self.dt
        reward_ctrl = -ctrl_cost_coeff * np.square(a).sum()
        ob = self._get_obs()

        if self.render_mode == "human":
            self.render()

        # The swimmer never terminates on its own.
        info = dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl)
        return ob, reward_fwd + reward_ctrl, False, False, info

    def _get_obs(self):
        data = self.sim.data
        # Root x/y position is excluded from the observation.
        return np.concatenate([data.qpos.flat[2:], data.qvel.flat])

    def reset_model(self):
        """Reset with uniform noise on positions and velocities."""
        noise_pos = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nq)
        noise_vel = self.np_random.uniform(low=-0.1, high=0.1, size=self.model.nv)
        self.set_state(self.init_qpos + noise_pos, self.init_qvel + noise_vel)
        return self._get_obs()
60 |
--------------------------------------------------------------------------------
/gym/envs/mujoco/walker2d.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from gym import utils
4 | from gym.envs.mujoco import MuJocoPyEnv
5 | from gym.spaces import Box
6 |
7 |
class Walker2dEnv(MuJocoPyEnv, utils.EzPickle):
    """mujoco-py planar biped; rewarded for walking forward while staying upright."""

    metadata = {
        "render_modes": ["human", "rgb_array", "depth_array"],
        "render_fps": 125,
    }

    def __init__(self, **kwargs):
        obs_space = Box(low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64)
        MuJocoPyEnv.__init__(
            self, "walker2d.xml", 4, observation_space=obs_space, **kwargs
        )
        utils.EzPickle.__init__(self, **kwargs)

    def step(self, a):
        """Advance physics; reward = forward velocity + alive bonus - control cost."""
        pos_before = self.sim.data.qpos[0]
        self.do_simulation(a, self.frame_skip)
        pos_after, height, ang = self.sim.data.qpos[0:3]

        alive_bonus = 1.0
        reward = (pos_after - pos_before) / self.dt + alive_bonus
        reward -= 1e-3 * np.square(a).sum()
        # Healthy while torso height stays in (0.8, 2.0) and angle in (-1, 1).
        healthy = 0.8 < height < 2.0 and -1.0 < ang < 1.0
        ob = self._get_obs()

        if self.render_mode == "human":
            self.render()

        return ob, reward, not healthy, False, {}

    def _get_obs(self):
        data = self.sim.data
        # Exclude root x; clip velocities to [-10, 10].
        return np.concatenate([data.qpos[1:], np.clip(data.qvel, -10, 10)]).ravel()

    def reset_model(self):
        """Reset with tiny uniform noise on positions and velocities."""
        noise_pos = self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nq)
        noise_vel = self.np_random.uniform(low=-0.005, high=0.005, size=self.model.nv)
        self.set_state(self.init_qpos + noise_pos, self.init_qvel + noise_vel)
        return self._get_obs()

    def viewer_setup(self):
        assert self.viewer is not None
        cam = self.viewer.cam
        cam.trackbodyid = 2
        cam.distance = self.model.stat.extent * 0.5
        cam.lookat[2] = 1.15
        cam.elevation = -20
62 |
--------------------------------------------------------------------------------
/gym/envs/toy_text/__init__.py:
--------------------------------------------------------------------------------
1 | from gym.envs.toy_text.blackjack import BlackjackEnv
2 | from gym.envs.toy_text.cliffwalking import CliffWalkingEnv
3 | from gym.envs.toy_text.frozen_lake import FrozenLakeEnv
4 | from gym.envs.toy_text.taxi import TaxiEnv
5 |
--------------------------------------------------------------------------------
/gym/envs/toy_text/font/Minecraft.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/font/Minecraft.ttf
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/C2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/C2.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/C3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/C3.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/C4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/C4.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/C5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/C5.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/C6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/C6.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/C7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/C7.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/C8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/C8.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/C9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/C9.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/CA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/CA.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/CJ.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/CJ.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/CK.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/CK.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/CQ.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/CQ.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/CT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/CT.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/Card.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/Card.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/D2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/D2.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/D3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/D3.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/D4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/D4.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/D5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/D5.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/D6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/D6.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/D7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/D7.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/D8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/D8.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/D9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/D9.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/DA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/DA.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/DJ.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/DJ.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/DK.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/DK.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/DQ.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/DQ.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/DT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/DT.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/H2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/H2.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/H3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/H3.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/H4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/H4.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/H5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/H5.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/H6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/H6.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/H7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/H7.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/H8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/H8.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/H9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/H9.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/HA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/HA.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/HJ.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/HJ.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/HK.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/HK.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/HQ.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/HQ.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/HT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/HT.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/S2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/S2.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/S3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/S3.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/S4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/S4.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/S5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/S5.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/S6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/S6.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/S7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/S7.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/S8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/S8.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/S9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/S9.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/SA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/SA.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/SJ.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/SJ.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/SK.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/SK.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/SQ.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/SQ.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/ST.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/ST.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/cab_front.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/cab_front.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/cab_left.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/cab_left.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/cab_rear.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/cab_rear.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/cab_right.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/cab_right.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/cookie.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/cookie.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/cracked_hole.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/cracked_hole.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/elf_down.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/elf_down.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/elf_left.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/elf_left.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/elf_right.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/elf_right.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/elf_up.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/elf_up.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/goal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/goal.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/gridworld_median_bottom.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/gridworld_median_bottom.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/gridworld_median_horiz.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/gridworld_median_horiz.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/gridworld_median_left.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/gridworld_median_left.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/gridworld_median_right.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/gridworld_median_right.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/gridworld_median_top.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/gridworld_median_top.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/gridworld_median_vert.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/gridworld_median_vert.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/hole.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/hole.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/hotel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/hotel.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/ice.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/ice.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/mountain_bg1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/mountain_bg1.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/mountain_bg2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/mountain_bg2.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/mountain_cliff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/mountain_cliff.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/mountain_near-cliff1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/mountain_near-cliff1.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/mountain_near-cliff2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/mountain_near-cliff2.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/passenger.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/passenger.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/stool.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/stool.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/img/taxi_background.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/envs/toy_text/img/taxi_background.png
--------------------------------------------------------------------------------
/gym/envs/toy_text/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
def categorical_sample(prob_n, np_random: np.random.Generator):
    """Draw one class index from a categorical distribution.

    ``prob_n`` holds the probability of each class (assumed to sum to 1);
    the index of the first cumulative probability that exceeds a uniform
    draw from ``np_random`` is returned.
    """
    cumulative = np.cumsum(np.asarray(prob_n))
    threshold = np_random.random()
    return np.argmax(cumulative > threshold)
9 |
--------------------------------------------------------------------------------
/gym/logger.py:
--------------------------------------------------------------------------------
1 | """Set of functions for logging messages."""
2 | import sys
3 | import warnings
4 | from typing import Optional, Type
5 |
6 | from gym.utils import colorize
7 |
# Numeric logging levels, mirroring the stdlib ``logging`` constants.
DEBUG = 10
INFO = 20
WARN = 30
ERROR = 40
DISABLED = 50  # above every message level, so nothing is emitted

# Module-wide threshold: messages with a level below this are dropped.
min_level = 30


# Ensure DeprecationWarning to be displayed (#2685, #3059)
warnings.filterwarnings("once", "", DeprecationWarning, module=r"^gym\.")
19 |
20 |
def set_level(level: int):
    """Set the module-wide logging threshold used by all log functions."""
    global min_level
    min_level = level
25 |
26 |
def debug(msg: str, *args: object):
    """Print a ``DEBUG:``-prefixed message to stderr when the threshold allows."""
    if min_level > DEBUG:
        return
    print(f"DEBUG: {msg % args}", file=sys.stderr)
31 |
32 |
def info(msg: str, *args: object):
    """Print an ``INFO:``-prefixed message to stderr when the threshold allows."""
    if min_level > INFO:
        return
    print(f"INFO: {msg % args}", file=sys.stderr)
37 |
38 |
def warn(
    msg: str,
    *args: object,
    category: Optional[Type[Warning]] = None,
    stacklevel: int = 1,
):
    """Emit a Python warning with a yellow ``WARN:`` prefix if the min_level <= WARN.

    Args:
        msg: The message to warn the user
        *args: Additional information to warn the user
        category: The category of warning
        stacklevel: The stack level to raise to
    """
    if min_level > WARN:
        return
    # +1 so the warning points at our caller rather than at this wrapper.
    text = colorize(f"WARN: {msg % args}", "yellow")
    warnings.warn(text, category=category, stacklevel=stacklevel + 1)
59 |
60 |
def deprecation(msg: str, *args: object):
    """Forward ``msg`` to :func:`warn` as a ``DeprecationWarning`` for the caller."""
    warn(msg, *args, category=DeprecationWarning, stacklevel=2)
64 |
65 |
def error(msg: str, *args: object):
    """Print a red ``ERROR:``-prefixed message to stderr if min_level <= ERROR."""
    if min_level > ERROR:
        return
    print(colorize(f"ERROR: {msg % args}", "red"), file=sys.stderr)
70 |
71 |
# DEPRECATED: camelCase alias kept for backwards compatibility; prefer ``set_level``.
setLevel = set_level
74 |
--------------------------------------------------------------------------------
/gym/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/gym/py.typed
--------------------------------------------------------------------------------
/gym/spaces/__init__.py:
--------------------------------------------------------------------------------
1 | """This module implements various spaces.
2 |
3 | Spaces describe mathematical sets and are used in Gym to specify valid actions and observations.
4 | Every Gym environment must have the attributes ``action_space`` and ``observation_space``.
5 | If, for instance, three possible actions (0,1,2) can be performed in your environment and observations
6 | are vectors in the two-dimensional unit cube, the environment code may contain the following two lines::
7 |
8 | self.action_space = spaces.Discrete(3)
9 | self.observation_space = spaces.Box(0, 1, shape=(2,))
10 | """
11 | from gym.spaces.box import Box
12 | from gym.spaces.dict import Dict
13 | from gym.spaces.discrete import Discrete
14 | from gym.spaces.graph import Graph, GraphInstance
15 | from gym.spaces.multi_binary import MultiBinary
16 | from gym.spaces.multi_discrete import MultiDiscrete
17 | from gym.spaces.sequence import Sequence
18 | from gym.spaces.space import Space
19 | from gym.spaces.text import Text
20 | from gym.spaces.tuple import Tuple
21 | from gym.spaces.utils import flatdim, flatten, flatten_space, unflatten
22 |
# Explicit public API of ``gym.spaces``: the space classes plus the
# flatten/unflatten helpers re-exported above.
__all__ = [
    "Space",
    "Box",
    "Discrete",
    "Text",
    "Graph",
    "GraphInstance",
    "MultiDiscrete",
    "MultiBinary",
    "Tuple",
    "Sequence",
    "Dict",
    "flatdim",
    "flatten_space",
    "flatten",
    "unflatten",
]
40 |
--------------------------------------------------------------------------------
/gym/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """A set of common utilities used within the environments.
2 |
3 | These are not intended as API functions, and will not remain stable over time.
4 | """
5 |
6 | # These submodules should not have any import-time dependencies.
7 | # We want this since we use `utils` during our import-time sanity checks
8 | # that verify that our dependencies are actually present.
9 | from gym.utils.colorize import colorize
10 | from gym.utils.ezpickle import EzPickle
11 |
--------------------------------------------------------------------------------
/gym/utils/colorize.py:
--------------------------------------------------------------------------------
1 | """A set of common utilities used within the environments.
2 |
3 | These are not intended as API functions, and will not remain stable over time.
4 | """
5 |
# ANSI SGR foreground colour codes for the colours ``colorize`` supports.
color2num = dict(
    gray=30,
    red=31,
    green=32,
    yellow=33,
    blue=34,
    magenta=35,
    cyan=36,
    white=37,
    crimson=38,
)


def colorize(
    string: str, color: str, bold: bool = False, highlight: bool = False
) -> str:
    """Wrap ``string`` in ANSI escape codes so a terminal prints it colourised.

    Args:
        string: The message to colourise
        color: Literal values are gray, red, green, yellow, blue, magenta, cyan, white, crimson
        bold: If to bold the string
        highlight: If to highlight the string

    Returns:
        Colourised string
    """
    num = color2num[color]
    if highlight:
        # Background colour codes sit 10 above the foreground ones.
        num += 10
    codes = [str(num)]
    if bold:
        codes.append("1")
    return f"\x1b[{';'.join(codes)}m{string}\x1b[0m"
42 |
--------------------------------------------------------------------------------
/gym/utils/ezpickle.py:
--------------------------------------------------------------------------------
1 | """Class for pickling and unpickling objects via their constructor arguments."""
2 |
3 |
class EzPickle:
    """Objects that are pickled and unpickled via their constructor arguments.

    Example::

        >>> class Dog(Animal, EzPickle):
        ...     def __init__(self, furcolor, tailkind="bushy"):
        ...         Animal.__init__()
        ...         EzPickle.__init__(furcolor, tailkind)

    When this object is unpickled, a new ``Dog`` will be constructed by passing the provided furcolor and tailkind into the constructor.
    However, philosophers are still not sure whether it is still the same dog.

    This is generally needed only for environments which wrap C/C++ code, such as MuJoCo and Atari.
    """

    def __init__(self, *args, **kwargs):
        """Record the constructor arguments so they can be replayed on unpickling."""
        self._ezpickle_args = args
        self._ezpickle_kwargs = kwargs

    def __getstate__(self):
        """Return the pickle state: only the recorded args and kwargs."""
        return {
            "_ezpickle_args": self._ezpickle_args,
            "_ezpickle_kwargs": self._ezpickle_kwargs,
        }

    def __setstate__(self, d):
        """Rebuild the object by re-running the constructor with the saved arguments."""
        fresh = type(self)(*d["_ezpickle_args"], **d["_ezpickle_kwargs"])
        self.__dict__.update(fresh.__dict__)
36 |
--------------------------------------------------------------------------------
/gym/utils/seeding.py:
--------------------------------------------------------------------------------
1 | """Set of random number generator functions: seeding, generator, hashing seeds."""
2 | from typing import Any, Optional, Tuple
3 |
4 | import numpy as np
5 |
6 | from gym import error
7 |
8 |
def np_random(seed: Optional[int] = None) -> Tuple[np.random.Generator, Any]:
    """Generates a random number generator from the seed and returns the Generator and seed.

    Args:
        seed: The seed used to create the generator

    Returns:
        The generator and resulting seed

    Raises:
        Error: Seed must be a non-negative integer or omitted
    """
    seed_ok = seed is None or (isinstance(seed, int) and seed >= 0)
    if not seed_ok:
        raise error.Error(f"Seed must be a non-negative integer or omitted, not {seed}")

    seed_seq = np.random.SeedSequence(seed)
    generator = RandomNumberGenerator(np.random.PCG64(seed_seq))
    # ``entropy`` echoes the supplied seed, or the auto-generated one when omitted.
    return generator, seed_seq.entropy


# Backwards-compatible aliases: gym's RNG type is just NumPy's ``Generator``.
RandomNumberGenerator = np.random.Generator
RNG = RandomNumberGenerator
31 |
--------------------------------------------------------------------------------
/gym/vector/__init__.py:
--------------------------------------------------------------------------------
1 | """Module for vector environments."""
2 | from typing import Iterable, List, Optional, Union
3 |
4 | import gym
5 | from gym.vector.async_vector_env import AsyncVectorEnv
6 | from gym.vector.sync_vector_env import SyncVectorEnv
7 | from gym.vector.vector_env import VectorEnv, VectorEnvWrapper
8 |
9 | __all__ = ["AsyncVectorEnv", "SyncVectorEnv", "VectorEnv", "VectorEnvWrapper", "make"]
10 |
11 |
def make(
    id: str,
    num_envs: int = 1,
    asynchronous: bool = True,
    wrappers: Optional[Union[callable, List[callable]]] = None,
    disable_env_checker: Optional[bool] = None,
    **kwargs,
) -> VectorEnv:
    """Create a vectorized environment from multiple copies of an environment, from its id.

    Example::

        >>> import gym
        >>> env = gym.vector.make('CartPole-v1', num_envs=3)
        >>> env.reset()
        array([[-0.04456399, 0.04653909, 0.01326909, -0.02099827],
               [ 0.03073904, 0.00145001, -0.03088818, -0.03131252],
               [ 0.03468829, 0.01500225, 0.01230312, 0.01825218]],
              dtype=float32)

    Args:
        id: The environment ID. This must be a valid ID from the registry.
        num_envs: Number of copies of the environment.
        asynchronous: If `True`, wraps the environments in an :class:`AsyncVectorEnv` (which uses `multiprocessing`_ to run the environments in parallel). If ``False``, wraps the environments in a :class:`SyncVectorEnv`.
        wrappers: If not ``None``, then apply the wrappers to each internal environment during creation.
        disable_env_checker: If to run the env checker for the first environment only. None will default to the environment spec `disable_env_checker` parameter
            (that is by default False), otherwise will run according to this argument (True = not run, False = run)
        **kwargs: Keywords arguments applied during `gym.make`

    Returns:
        The vectorized environment.
    """

    def create_env(env_num: int):
        """Creates an environment that can enable or disable the environment checker."""
        # Only the first environment (env_num == 0) may run the checker,
        # controlled by `disable_env_checker`; all later copies skip it.
        _disable_env_checker = True if env_num > 0 else disable_env_checker

        def _make_env():
            env = gym.envs.registration.make(
                id,
                disable_env_checker=_disable_env_checker,
                **kwargs,
            )
            if wrappers is not None:
                if callable(wrappers):
                    env = wrappers(env)
                elif isinstance(wrappers, Iterable) and all(
                    callable(w) for w in wrappers
                ):
                    for wrapper in wrappers:
                        env = wrapper(env)
                else:
                    raise NotImplementedError
            return env

        return _make_env

    # Bug fix: the previous code called `create_env(disable_env_checker or env_num > 0)`,
    # passing a bool where the environment index is expected (it only behaved correctly
    # through bool->int coercion and duplicated the disable logic). Pass the index itself.
    env_fns = [create_env(env_num) for env_num in range(num_envs)]
    return AsyncVectorEnv(env_fns) if asynchronous else SyncVectorEnv(env_fns)
74 |
--------------------------------------------------------------------------------
/gym/vector/utils/__init__.py:
--------------------------------------------------------------------------------
1 | """Module for gym vector utils."""
2 | from gym.vector.utils.misc import CloudpickleWrapper, clear_mpi_env_vars
3 | from gym.vector.utils.numpy_utils import concatenate, create_empty_array
4 | from gym.vector.utils.shared_memory import (
5 | create_shared_memory,
6 | read_from_shared_memory,
7 | write_to_shared_memory,
8 | )
9 | from gym.vector.utils.spaces import _BaseGymSpaces # pyright: reportPrivateUsage=false
10 | from gym.vector.utils.spaces import BaseGymSpaces, batch_space, iterate
11 |
# Explicit public API for ``gym.vector.utils`` (the private ``_BaseGymSpaces``
# import above is intentionally excluded).
__all__ = [
    "CloudpickleWrapper",
    "clear_mpi_env_vars",
    "concatenate",
    "create_empty_array",
    "create_shared_memory",
    "read_from_shared_memory",
    "write_to_shared_memory",
    "BaseGymSpaces",
    "batch_space",
    "iterate",
]
24 |
--------------------------------------------------------------------------------
/gym/vector/utils/misc.py:
--------------------------------------------------------------------------------
1 | """Miscellaneous utilities."""
import contextlib
import os
from typing import Any, Callable
4 |
5 | __all__ = ["CloudpickleWrapper", "clear_mpi_env_vars"]
6 |
7 |
class CloudpickleWrapper:
    """Wrapper that uses cloudpickle to pickle and unpickle the result."""

    def __init__(self, fn: Callable[[], Any]):
        """Cloudpickle wrapper for a zero-argument function.

        Args:
            fn: The function to wrap so it can cross process boundaries via pickling.
        """
        self.fn = fn

    def __getstate__(self):
        """Get the state using `cloudpickle.dumps(self.fn)`."""
        # cloudpickle can serialize lambdas and closures that stdlib pickle cannot.
        import cloudpickle

        return cloudpickle.dumps(self.fn)

    def __setstate__(self, ob):
        """Sets the state by unpickling the bytes in `ob`.

        Args:
            ob: Pickled bytes produced by `__getstate__`.
        """
        # cloudpickle output is readable by the stdlib unpickler, so the
        # receiving side does not need cloudpickle installed.
        import pickle

        self.fn = pickle.loads(ob)

    def __call__(self):
        """Calls the function `self.fn` with no arguments."""
        return self.fn()
30 |
31 |
@contextlib.contextmanager
def clear_mpi_env_vars():
    """Clears the MPI of environment variables.

    `from mpi4py import MPI` will call `MPI_Init` by default.
    If the child process has MPI environment variables, MPI will think that the child process
    is an MPI process just like the parent and do bad things such as hang.

    This context manager is a hacky way to clear those environment variables
    temporarily such as when we are starting multiprocessing Processes.

    Yields:
        Yields for the context manager
    """
    # Snapshot every MPI-related variable, then drop it from the environment.
    saved_env = {
        name: value
        for name, value in os.environ.items()
        if name.startswith(("OMPI_", "PMI_"))
    }
    for name in saved_env:
        del os.environ[name]
    try:
        yield
    finally:
        # Restore whatever was removed, even if the managed block raised.
        os.environ.update(saved_env)
56 |
--------------------------------------------------------------------------------
/gym/version.py:
--------------------------------------------------------------------------------
1 | VERSION = "0.26.2"
2 |
--------------------------------------------------------------------------------
/gym/wrappers/README.md:
--------------------------------------------------------------------------------
1 | # Wrappers
2 |
3 | Wrappers are used to transform an environment in a modular way:
4 |
5 | ```python
6 | env = gym.make('Pong-v0')
7 | env = MyWrapper(env)
8 | ```
9 |
10 | Note that we may later restructure any of the files in this directory,
11 | but will keep the wrappers available at the wrappers' top-level
12 | folder. So for example, you should access `MyWrapper` as follows:
13 |
14 | ```python
15 | from gym.wrappers import MyWrapper
16 | ```
17 |
18 | ## Quick tips for writing your own wrapper
19 |
- Don't forget to call `super().__init__(env)` if you override the wrapper's `__init__` function
21 | - You can access the inner environment with `self.unwrapped`
22 | - You can access the previous layer using `self.env`
23 | - The variables `metadata`, `action_space`, `observation_space`, `reward_range`, and `spec` are copied to `self` from the previous layer
24 | - Create a wrapped function for at least one of the following: `__init__(self, env)`, `step`, `reset`, `render`, `close`, or `seed`
25 | - Your layered function should take its input from the previous layer (`self.env`) and/or the inner layer (`self.unwrapped`)
26 |
--------------------------------------------------------------------------------
/gym/wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | """Module of wrapper classes."""
2 | from gym import error
3 | from gym.wrappers.atari_preprocessing import AtariPreprocessing
4 | from gym.wrappers.autoreset import AutoResetWrapper
5 | from gym.wrappers.clip_action import ClipAction
6 | from gym.wrappers.filter_observation import FilterObservation
7 | from gym.wrappers.flatten_observation import FlattenObservation
8 | from gym.wrappers.frame_stack import FrameStack, LazyFrames
9 | from gym.wrappers.gray_scale_observation import GrayScaleObservation
10 | from gym.wrappers.human_rendering import HumanRendering
11 | from gym.wrappers.normalize import NormalizeObservation, NormalizeReward
12 | from gym.wrappers.order_enforcing import OrderEnforcing
13 | from gym.wrappers.record_episode_statistics import RecordEpisodeStatistics
14 | from gym.wrappers.record_video import RecordVideo, capped_cubic_video_schedule
15 | from gym.wrappers.render_collection import RenderCollection
16 | from gym.wrappers.rescale_action import RescaleAction
17 | from gym.wrappers.resize_observation import ResizeObservation
18 | from gym.wrappers.step_api_compatibility import StepAPICompatibility
19 | from gym.wrappers.time_aware_observation import TimeAwareObservation
20 | from gym.wrappers.time_limit import TimeLimit
21 | from gym.wrappers.transform_observation import TransformObservation
22 | from gym.wrappers.transform_reward import TransformReward
23 | from gym.wrappers.vector_list_info import VectorListInfo
24 |
--------------------------------------------------------------------------------
/gym/wrappers/autoreset.py:
--------------------------------------------------------------------------------
1 | """Wrapper that autoreset environments when `terminated=True` or `truncated=True`."""
2 | import gym
3 |
4 |
class AutoResetWrapper(gym.Wrapper):
    """Automatically resets the wrapped environment whenever an episode ends.

    When a call to :meth:`step` yields ``terminated=True`` or ``truncated=True``,
    :meth:`Env.reset` is called immediately and :meth:`step` returns
    ``(new_obs, final_reward, final_terminated, final_truncated, info)`` where:

    - ``new_obs`` is the first observation of the next episode, from :meth:`Env.reset`,
    - ``final_reward``, ``final_terminated`` and ``final_truncated`` come from the final
      step of the episode that just ended (at least one of the two flags is True),
    - ``info`` is the reset's info dict, extended with the key ``"final_observation"``
      (the terminal observation) and ``"final_info"`` (the terminal step's info dict).

    Warning: When collecting rollouts with this wrapper, the observation returned
    alongside a True terminated/truncated flag belongs to the *next* episode; the
    terminal observation of the finished episode is only available under the
    ``"final_observation"`` info key. Make sure you know what you're doing if you
    use this wrapper!
    """

    def __init__(self, env: gym.Env):
        """Initializes the auto-reset wrapper.

        Args:
            env (gym.Env): The environment to apply the wrapper
        """
        super().__init__(env)

    def step(self, action):
        """Steps through the environment, resetting it when the episode ends.

        Args:
            action: The action to take

        Returns:
            The autoreset environment :meth:`step`
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        if not (terminated or truncated):
            return obs, reward, terminated, truncated, info

        # Episode finished: reset and surface the terminal transition via info.
        new_obs, new_info = self.env.reset()
        assert (
            "final_observation" not in new_info
        ), 'info dict cannot contain key "final_observation" '
        assert (
            "final_info" not in new_info
        ), 'info dict cannot contain key "final_info" '

        new_info["final_observation"] = obs
        new_info["final_info"] = info

        return new_obs, reward, terminated, truncated, new_info
62 |
--------------------------------------------------------------------------------
/gym/wrappers/clip_action.py:
--------------------------------------------------------------------------------
1 | """Wrapper for clipping actions within a valid bound."""
2 | import numpy as np
3 |
4 | import gym
5 | from gym import ActionWrapper
6 | from gym.spaces import Box
7 |
8 |
class ClipAction(ActionWrapper):
    """Clip the continuous action within the valid :class:`Box` action space bounds.

    Example:
        >>> import gym
        >>> env = gym.make('BipedalWalker-v3')
        >>> env = ClipAction(env)
        >>> env.action_space
        Box(-1.0, 1.0, (4,), float32)
        >>> env.step(np.array([5.0, 2.0, -10.0, 0.0]))
        # Executes the action np.array([1.0, 1.0, -1.0, 0]) in the base environment
    """

    def __init__(self, env: gym.Env):
        """A wrapper for clipping continuous actions within the valid bound.

        Args:
            env: The environment to apply the wrapper
        """
        assert isinstance(
            env.action_space, Box
        ), f"expected Box action space, got {type(env.action_space)}"
        super().__init__(env)

    def action(self, action):
        """Clips the action within the valid bounds.

        Args:
            action: The action to clip

        Returns:
            The clipped action
        """
        return np.clip(action, self.action_space.low, self.action_space.high)
41 |
--------------------------------------------------------------------------------
/gym/wrappers/env_checker.py:
--------------------------------------------------------------------------------
1 | """A passive environment checker wrapper for an environment's observation and action space along with the reset, step and render functions."""
2 | import gym
3 | from gym.core import ActType
4 | from gym.utils.passive_env_checker import (
5 | check_action_space,
6 | check_observation_space,
7 | env_render_passive_checker,
8 | env_reset_passive_checker,
9 | env_step_passive_checker,
10 | )
11 |
12 |
class PassiveEnvChecker(gym.Wrapper):
    """A passive environment checker wrapper that surrounds the step, reset and render functions to check they follow the gym API."""

    def __init__(self, env):
        """Initialises the wrapper with the environment and runs the action/observation space checks once."""
        super().__init__(env)

        assert hasattr(
            env, "action_space"
        ), "The environment must specify an action space. https://www.gymlibrary.dev/content/environment_creation/"
        check_action_space(env.action_space)
        assert hasattr(
            env, "observation_space"
        ), "The environment must specify an observation space. https://www.gymlibrary.dev/content/environment_creation/"
        check_observation_space(env.observation_space)

        # Each of step/reset/render is validated only on its first call.
        self.checked_reset = False
        self.checked_step = False
        self.checked_render = False

    def step(self, action: ActType):
        """Steps through the environment; the first call is routed through `env_step_passive_checker`."""
        if self.checked_step:
            return self.env.step(action)
        self.checked_step = True
        return env_step_passive_checker(self.env, action)

    def reset(self, **kwargs):
        """Resets the environment; the first call is routed through `env_reset_passive_checker`."""
        if self.checked_reset:
            return self.env.reset(**kwargs)
        self.checked_reset = True
        return env_reset_passive_checker(self.env, **kwargs)

    def render(self, *args, **kwargs):
        """Renders the environment; the first call is routed through `env_render_passive_checker`."""
        if self.checked_render:
            return self.env.render(*args, **kwargs)
        self.checked_render = True
        return env_render_passive_checker(self.env, *args, **kwargs)
56 |
--------------------------------------------------------------------------------
/gym/wrappers/filter_observation.py:
--------------------------------------------------------------------------------
1 | """A wrapper for filtering dictionary observations by their keys."""
import copy
from typing import Optional, Sequence
4 |
5 | import gym
6 | from gym import spaces
7 |
8 |
class FilterObservation(gym.ObservationWrapper):
    """Filter Dict observation space by the keys.

    Example:
        >>> import gym
        >>> env = gym.wrappers.TransformObservation(
        ...     gym.make('CartPole-v1'), lambda obs: {'obs': obs, 'time': 0}
        ... )
        >>> env.observation_space = gym.spaces.Dict(obs=env.observation_space, time=gym.spaces.Discrete(1))
        >>> env.reset()
        {'obs': array([-0.00067088, -0.01860439,  0.04772898, -0.01911527], dtype=float32), 'time': 0}
        >>> env = FilterObservation(env, filter_keys=['obs'])
        >>> env.reset()
        {'obs': array([ 0.04560107,  0.04466959, -0.0328232 , -0.02367178], dtype=float32)}
        >>> env.step(0)
        ({'obs': array([ 0.04649447, -0.14996664, -0.03329664,  0.25847703], dtype=float32)}, 1.0, False, {})
    """

    def __init__(self, env: gym.Env, filter_keys: Optional[Sequence[str]] = None):
        """A wrapper that filters dictionary observations by their keys.

        Args:
            env: The environment to apply the wrapper
            filter_keys: List of keys to be included in the observations. If ``None``, observations will not be filtered and this wrapper has no effect

        Raises:
            ValueError: If the environment's observation space is not :class:`spaces.Dict`
            ValueError: If any of the `filter_keys` are not included in the original `env`'s observation space
        """
        super().__init__(env)

        wrapped_observation_space = env.observation_space
        if not isinstance(wrapped_observation_space, spaces.Dict):
            raise ValueError(
                f"FilterObservationWrapper is only usable with dict observations, "
                f"environment observation space is {type(wrapped_observation_space)}"
            )

        observation_keys = wrapped_observation_space.spaces.keys()
        if filter_keys is None:
            # No filter given: keep every key, so the wrapper is a no-op.
            filter_keys = tuple(observation_keys)

        # Reject keys that do not exist in the wrapped observation space.
        missing_keys = {key for key in filter_keys if key not in observation_keys}
        if missing_keys:
            raise ValueError(
                "All the filter_keys must be included in the original observation space.\n"
                f"Filter keys: {filter_keys}\n"
                f"Observation keys: {observation_keys}\n"
                f"Missing keys: {missing_keys}"
            )

        # Rebuild the Dict space from deep copies of the retained subspaces so
        # later mutation of the wrapped env's space cannot leak through.
        self.observation_space = type(wrapped_observation_space)(
            [
                (name, copy.deepcopy(space))
                for name, space in wrapped_observation_space.spaces.items()
                if name in filter_keys
            ]
        )

        # NOTE(review): gym.Wrapper already stores the env as self.env;
        # self._env looks redundant but is kept for backward compatibility.
        self._env = env
        self._filter_keys = tuple(filter_keys)

    def observation(self, observation):
        """Filters the observations.

        Args:
            observation: The observation to filter

        Returns:
            The filtered observations
        """
        filter_observation = self._filter_observation(observation)
        return filter_observation

    def _filter_observation(self, observation):
        # Rebuild the mapping keeping only the retained keys, preserving the
        # concrete mapping type of the incoming observation.
        observation = type(observation)(
            [
                (name, value)
                for name, value in observation.items()
                if name in self._filter_keys
            ]
        )
        return observation
92 |
--------------------------------------------------------------------------------
/gym/wrappers/flatten_observation.py:
--------------------------------------------------------------------------------
1 | """Wrapper for flattening observations of an environment."""
2 | import gym
3 | import gym.spaces as spaces
4 |
5 |
class FlattenObservation(gym.ObservationWrapper):
    """Observation wrapper that flattens each observation into a 1-D array.

    Example:
        >>> import gym
        >>> env = gym.make('CarRacing-v1')
        >>> env.observation_space.shape
        (96, 96, 3)
        >>> env = FlattenObservation(env)
        >>> env.observation_space.shape
        (27648,)
        >>> obs = env.reset()
        >>> obs.shape
        (27648,)
    """

    def __init__(self, env: gym.Env):
        """Flattens the observations of an environment.

        Args:
            env: The environment to apply the wrapper
        """
        super().__init__(env)
        # The flattened space is derived once from the wrapped env's space.
        flat_space = spaces.flatten_space(env.observation_space)
        self.observation_space = flat_space

    def observation(self, observation):
        """Returns the given observation flattened against the wrapped env's observation space.

        Args:
            observation: The observation to flatten

        Returns:
            The flattened observation
        """
        return spaces.flatten(self.env.observation_space, observation)
41 |
--------------------------------------------------------------------------------
/gym/wrappers/gray_scale_observation.py:
--------------------------------------------------------------------------------
1 | """Wrapper that converts a color observation to grayscale."""
import numpy as np

import gym
from gym.error import DependencyNotInstalled
from gym.spaces import Box
6 |
7 |
class GrayScaleObservation(gym.ObservationWrapper):
    """Convert the image observation from RGB to gray scale.

    Example:
        >>> env = gym.make('CarRacing-v1')
        >>> env.observation_space
        Box(0, 255, (96, 96, 3), uint8)
        >>> env = GrayScaleObservation(gym.make('CarRacing-v1'))
        >>> env.observation_space
        Box(0, 255, (96, 96), uint8)
        >>> env = GrayScaleObservation(gym.make('CarRacing-v1'), keep_dim=True)
        >>> env.observation_space
        Box(0, 255, (96, 96, 1), uint8)
    """

    def __init__(self, env: gym.Env, keep_dim: bool = False):
        """Convert the image observation from RGB to gray scale.

        Args:
            env (Env): The environment to apply the wrapper
            keep_dim (bool): If `True`, a singleton dimension will be added, i.e. observations are of the shape AxBx1.
                Otherwise, they are of shape AxB.
        """
        super().__init__(env)
        self.keep_dim = keep_dim

        assert (
            isinstance(self.observation_space, Box)
            and len(self.observation_space.shape) == 3
            and self.observation_space.shape[-1] == 3
        ), f"Expected an RGB Box observation space of shape (H, W, 3), actual observation space: {self.observation_space}"

        obs_shape = self.observation_space.shape[:2]
        if self.keep_dim:
            self.observation_space = Box(
                low=0, high=255, shape=(obs_shape[0], obs_shape[1], 1), dtype=np.uint8
            )
        else:
            self.observation_space = Box(
                low=0, high=255, shape=obs_shape, dtype=np.uint8
            )

    def observation(self, observation):
        """Converts the colour observation to greyscale.

        Args:
            observation: Color observations

        Returns:
            Grayscale observations

        Raises:
            DependencyNotInstalled: opencv-python is not installed
        """
        # Raise the gym-standard error (consistent with ResizeObservation)
        # instead of a bare ImportError when opencv is missing.
        try:
            import cv2
        except ImportError:
            raise DependencyNotInstalled(
                "opencv is not installed, run `pip install gym[other]`"
            )

        observation = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
        if self.keep_dim:
            observation = np.expand_dims(observation, -1)
        return observation
65 |
--------------------------------------------------------------------------------
/gym/wrappers/monitoring/__init__.py:
--------------------------------------------------------------------------------
1 | """Module for monitoring.video_recorder."""
2 |
--------------------------------------------------------------------------------
/gym/wrappers/order_enforcing.py:
--------------------------------------------------------------------------------
1 | """Wrapper to enforce the proper ordering of environment operations."""
2 | import gym
3 | from gym.error import ResetNeeded
4 |
5 |
class OrderEnforcing(gym.Wrapper):
    """Raises :class:`ResetNeeded` whenever :meth:`step` (or, unless disabled, :meth:`render`) is called before the first :meth:`reset`.

    Example:
        >>> from gym.envs.classic_control import CartPoleEnv
        >>> env = OrderEnforcing(CartPoleEnv())
        >>> env.step(0)
        ResetNeeded: Cannot call env.step() before calling env.reset()
        >>> env.render()
        ResetNeeded: Cannot call env.render() before calling env.reset()
        >>> env.reset()
        >>> env.render()
        >>> env.step(0)
    """

    def __init__(self, env: gym.Env, disable_render_order_enforcing: bool = False):
        """Initializes the order-enforcing wrapper.

        Args:
            env: The environment to wrap
            disable_render_order_enforcing: If True, `render` may be called before the first `reset`
        """
        super().__init__(env)
        self._has_reset: bool = False
        self._disable_render_order_enforcing: bool = disable_render_order_enforcing

    def step(self, action):
        """Steps through the environment, requiring a prior reset."""
        if self._has_reset:
            return self.env.step(action)
        raise ResetNeeded("Cannot call env.step() before calling env.reset()")

    def reset(self, **kwargs):
        """Resets the environment and records that a reset has happened."""
        self._has_reset = True
        return self.env.reset(**kwargs)

    def render(self, *args, **kwargs):
        """Renders the environment, requiring a prior reset unless enforcement is disabled."""
        if self._has_reset or self._disable_render_order_enforcing:
            return self.env.render(*args, **kwargs)
        raise ResetNeeded(
            "Cannot call `env.render()` before calling `env.reset()`, if this is a intended action, "
            "set `disable_render_order_enforcing=True` on the OrderEnforcer wrapper."
        )

    @property
    def has_reset(self):
        """Whether :meth:`reset` has been called at least once."""
        return self._has_reset
57 |
--------------------------------------------------------------------------------
/gym/wrappers/render_collection.py:
--------------------------------------------------------------------------------
1 | """A wrapper that adds render collection mode to an environment."""
2 | import gym
3 |
4 |
class RenderCollection(gym.Wrapper):
    """Save collection of render frames."""

    def __init__(self, env: gym.Env, pop_frames: bool = True, reset_clean: bool = True):
        """Initialize a :class:`RenderCollection` instance.

        Args:
            env: The environment that is being wrapped
            pop_frames (bool): If true, clear the collection frames after .render() is called.
                Default value is True.
            reset_clean (bool): If true, clear the collection frames when .reset() is called.
                Default value is True.
        """
        super().__init__(env)
        # The wrapped env must render itself and must not already be a list-mode env.
        assert env.render_mode is not None
        assert not env.render_mode.endswith("_list")
        self.frame_list = []
        self.pop_frames = pop_frames
        self.reset_clean = reset_clean

    @property
    def render_mode(self):
        """Returns the collection render_mode name."""
        return f"{self.env.render_mode}_list"

    def step(self, *args, **kwargs):
        """Perform a step in the base environment and collect a frame."""
        step_result = self.env.step(*args, **kwargs)
        self.frame_list.append(self.env.render())
        return step_result

    def reset(self, *args, **kwargs):
        """Reset the base environment, eventually clear the frame_list, and collect a frame."""
        reset_result = self.env.reset(*args, **kwargs)

        if self.reset_clean:
            self.frame_list = []
        self.frame_list.append(self.env.render())

        return reset_result

    def render(self):
        """Returns the collection of frames and, if pop_frames = True, clears it."""
        collected = self.frame_list
        if self.pop_frames:
            self.frame_list = []

        return collected
53 |
--------------------------------------------------------------------------------
/gym/wrappers/rescale_action.py:
--------------------------------------------------------------------------------
1 | """Wrapper for rescaling actions to within a max and min action."""
2 | from typing import Union
3 |
4 | import numpy as np
5 |
6 | import gym
7 | from gym import spaces
8 |
9 |
class RescaleAction(gym.ActionWrapper):
    """Affinely rescales the continuous action space of the environment to the range [min_action, max_action].

    The base environment :attr:`env` must have an action space of type :class:`spaces.Box`. If :attr:`min_action`
    or :attr:`max_action` are numpy arrays, the shape must match the shape of the environment's action space.

    Example:
        >>> import gym
        >>> env = gym.make('BipedalWalker-v3')
        >>> env.action_space
        Box(-1.0, 1.0, (4,), float32)
        >>> min_action = -0.5
        >>> max_action = np.array([0.0, 0.5, 1.0, 0.75])
        >>> env = RescaleAction(env, min_action=min_action, max_action=max_action)
        >>> env.action_space
        Box(-0.5, [0.  0.5 1.  0.75], (4,), float32)
        >>> RescaleAction(env, min_action, max_action).action_space == gym.spaces.Box(min_action, max_action)
        True
    """

    def __init__(
        self,
        env: gym.Env,
        min_action: Union[float, int, np.ndarray],
        max_action: Union[float, int, np.ndarray],
    ):
        """Initializes the :class:`RescaleAction` wrapper.

        Args:
            env (Env): The environment to apply the wrapper
            min_action (float, int or np.ndarray): The min values for each action. This may be a numpy array or a scalar.
            max_action (float, int or np.ndarray): The max values for each action. This may be a numpy array or a scalar.
        """
        assert isinstance(
            env.action_space, spaces.Box
        ), f"expected Box action space, got {type(env.action_space)}"
        assert np.less_equal(min_action, max_action).all(), (min_action, max_action)

        super().__init__(env)
        space_shape = env.action_space.shape
        space_dtype = env.action_space.dtype
        # Broadcast scalar bounds up to the full action shape.
        self.min_action = np.zeros(space_shape, dtype=space_dtype) + min_action
        self.max_action = np.zeros(space_shape, dtype=space_dtype) + max_action
        self.action_space = spaces.Box(
            low=min_action,
            high=max_action,
            shape=space_shape,
            dtype=space_dtype,
        )

    def action(self, action):
        """Rescales the action affinely from [:attr:`min_action`, :attr:`max_action`] to the action space of the base environment, :attr:`env`.

        Args:
            action: The action to rescale

        Returns:
            The rescaled action
        """
        assert np.all(np.greater_equal(action, self.min_action)), (
            action,
            self.min_action,
        )
        assert np.all(np.less_equal(action, self.max_action)), (action, self.max_action)
        base_low = self.env.action_space.low
        base_high = self.env.action_space.high
        # Map action into [0, 1] within the wrapper's bounds, then stretch to the base bounds.
        fraction = (action - self.min_action) / (self.max_action - self.min_action)
        rescaled = base_low + (base_high - base_low) * fraction
        return np.clip(rescaled, base_low, base_high)
83 |
--------------------------------------------------------------------------------
/gym/wrappers/resize_observation.py:
--------------------------------------------------------------------------------
1 | """Wrapper for resizing observations."""
2 | from typing import Union
3 |
4 | import numpy as np
5 |
6 | import gym
7 | from gym.error import DependencyNotInstalled
8 | from gym.spaces import Box
9 |
10 |
class ResizeObservation(gym.ObservationWrapper):
    """Resize the image observation.

    This wrapper works on environments with image observations (or more generally observations of shape AxBxC) and resizes
    the observation to the shape given by the 2-tuple :attr:`shape`. The argument :attr:`shape` may also be an integer.
    In that case, the observation is scaled to a square of side-length :attr:`shape`.

    Example:
        >>> import gym
        >>> env = gym.make('CarRacing-v1')
        >>> env.observation_space.shape
        (96, 96, 3)
        >>> env = ResizeObservation(env, 64)
        >>> env.observation_space.shape
        (64, 64, 3)
    """

    def __init__(self, env: gym.Env, shape: Union[tuple, int]):
        """Resizes image observations to shape given by :attr:`shape`.

        Args:
            env: The environment to apply the wrapper
            shape: The shape of the resized observations
        """
        super().__init__(env)
        if isinstance(shape, int):
            shape = (shape, shape)
        assert all(
            x > 0 for x in shape
        ), f"Expected shape with positive dimensions, got {shape}"

        self.shape = tuple(shape)

        assert isinstance(
            env.observation_space, Box
        ), f"Expected the observation space to be Box, actual type: {type(env.observation_space)}"
        # Preserve any trailing channel dimension from the original space.
        obs_shape = self.shape + env.observation_space.shape[2:]
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

    def observation(self, observation):
        """Updates the observations by resizing the observation to shape given by :attr:`shape`.

        Args:
            observation: The observation to reshape

        Returns:
            The reshaped observations

        Raises:
            DependencyNotInstalled: opencv-python is not installed
        """
        try:
            import cv2
        except ImportError:
            raise DependencyNotInstalled(
                "opencv is not installed, run `pip install gym[other]`"
            )

        # cv2.resize expects (width, height), the reverse of our (rows, cols) shape.
        observation = cv2.resize(
            observation, self.shape[::-1], interpolation=cv2.INTER_AREA
        )
        if observation.ndim == 2:
            # Resizing a single-channel image drops the channel axis; restore it.
            observation = np.expand_dims(observation, -1)
        return observation
73 |
--------------------------------------------------------------------------------
/gym/wrappers/step_api_compatibility.py:
--------------------------------------------------------------------------------
1 | """Implementation of StepAPICompatibility wrapper class for transforming envs between new and old step API."""
2 | import gym
3 | from gym.logger import deprecation
4 | from gym.utils.step_api_compatibility import (
5 | convert_to_done_step_api,
6 | convert_to_terminated_truncated_step_api,
7 | )
8 |
9 |
class StepAPICompatibility(gym.Wrapper):
    r"""A wrapper which can transform an environment from new step API to old and vice-versa.

    Old step API refers to step() method returning (observation, reward, done, info)
    New step API refers to step() method returning (observation, reward, terminated, truncated, info)
    (Refer to docs for details on the API change)

    Args:
        env (gym.Env): the env to wrap. Can be in old or new API
        output_truncation_bool (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API). (True by default)

    Examples:
        >>> env = gym.make("CartPole-v1")
        >>> env  # wrapper not applied by default, set to new API
        >>> env = gym.make("CartPole-v1", apply_api_compatibility=True)  # set to old API
        >>> env = StepAPICompatibility(CustomEnv(), output_truncation_bool=False)  # manually using wrapper on unregistered envs

    """

    def __init__(self, env: gym.Env, output_truncation_bool: bool = True):
        """A wrapper which can transform an environment from new step API to old and vice-versa.

        Args:
            env (gym.Env): the env to wrap. Can be in old or new API
            output_truncation_bool (bool): Whether the wrapper's step method outputs two booleans (new API) or one boolean (old API)
        """
        super().__init__(env)
        self.output_truncation_bool = output_truncation_bool
        if not self.output_truncation_bool:
            # The old single-bool (done) output is deprecated; warn once at construction.
            deprecation(
                "Initializing environment in old step API which returns one bool instead of two."
            )

    def step(self, action):
        """Steps through the environment, returning 5 or 4 items depending on `output_truncation_bool`.

        Args:
            action: action to step through the environment with

        Returns:
            (observation, reward, terminated, truncated, info) or (observation, reward, done, info)
        """
        step_returns = self.env.step(action)
        if self.output_truncation_bool:
            return convert_to_terminated_truncated_step_api(step_returns)
        else:
            return convert_to_done_step_api(step_returns)
59 |
--------------------------------------------------------------------------------
/gym/wrappers/time_aware_observation.py:
--------------------------------------------------------------------------------
1 | """Wrapper for adding time aware observations to environment observation."""
2 | import numpy as np
3 |
4 | import gym
5 | from gym.spaces import Box
6 |
7 |
class TimeAwareObservation(gym.ObservationWrapper):
    """Augment the observation with the current time step in the episode.

    The observation space of the wrapped environment is assumed to be a flat :class:`Box`.
    In particular, pixel observations are not supported. This wrapper will append the current timestep within the current episode to the observation.

    Example:
        >>> import gym
        >>> env = gym.make('CartPole-v1')
        >>> env = TimeAwareObservation(env)
        >>> env.reset()
        array([ 0.03810719,  0.03522411,  0.02231044, -0.01088205,  0.        ])
        >>> env.step(env.action_space.sample())[0]
        array([ 0.03881167, -0.16021058,  0.0220928 ,  0.28875574,  1.        ])
    """

    def __init__(self, env: gym.Env):
        """Initialize :class:`TimeAwareObservation` that requires an environment with a flat :class:`Box` observation space.

        Args:
            env: The environment to apply the wrapper
        """
        super().__init__(env)
        # Only flat float32 Box observations are supported.
        assert isinstance(env.observation_space, Box)
        assert env.observation_space.dtype == np.float32
        # Extend the Box by one entry for the timestep: it starts at 0 and has
        # no declared upper bound, hence np.inf for the new `high` entry.
        low = np.append(self.observation_space.low, 0.0)
        high = np.append(self.observation_space.high, np.inf)
        self.observation_space = Box(low, high, dtype=np.float32)
        # Stored but not used in this class; presumably read by downstream
        # vector-env utilities — confirm against callers.
        self.is_vector_env = getattr(env, "is_vector_env", False)

    def observation(self, observation):
        """Adds to the observation with the current time step.

        Args:
            observation: The observation to add the time step to

        Returns:
            The observation with the time step appended to
        """
        # NOTE(review): np.append builds a new array; appending the Python int
        # self.t to a float32 observation presumably promotes the result to
        # float64, diverging from the declared float32 space — verify whether
        # callers depend on the dtype.
        return np.append(observation, self.t)

    def step(self, action):
        """Steps through the environment, incrementing the time step.

        Args:
            action: The action to take

        Returns:
            The environment's step using the action.
        """
        # self.t is created in reset(); calling step() before the first
        # reset() would raise AttributeError.
        self.t += 1
        return super().step(action)

    def reset(self, **kwargs):
        """Reset the environment setting the time to zero.

        Args:
            **kwargs: Kwargs to apply to env.reset()

        Returns:
            The reset environment
        """
        # Restart the per-episode timestep counter.
        self.t = 0
        return super().reset(**kwargs)
72 |
--------------------------------------------------------------------------------
/gym/wrappers/time_limit.py:
--------------------------------------------------------------------------------
1 | """Wrapper for limiting the time steps of an environment."""
2 | from typing import Optional
3 |
4 | import gym
5 |
6 |
class TimeLimit(gym.Wrapper):
    """This wrapper will issue a `truncated` signal if a maximum number of timesteps is exceeded.

    If a truncation is not defined inside the environment itself, this is the only place that the truncation signal is issued.
    Critically, this is different from the `terminated` signal that originates from the underlying environment as part of the MDP.

    Example:
        >>> from gym.envs.classic_control import CartPoleEnv
        >>> from gym.wrappers import TimeLimit
        >>> env = CartPoleEnv()
        >>> env = TimeLimit(env, max_episode_steps=1000)
    """

    def __init__(
        self,
        env: gym.Env,
        max_episode_steps: Optional[int] = None,
    ):
        """Initializes the :class:`TimeLimit` wrapper with an environment and the number of steps after which truncation will occur.

        Args:
            env: The environment to apply the wrapper
            max_episode_steps: An optional max episode steps (if ``None``, ``env.spec.max_episode_steps`` is used)
        """
        super().__init__(env)
        # Fall back to the registered spec's limit when no explicit limit is given.
        if max_episode_steps is None and self.env.spec is not None:
            max_episode_steps = env.spec.max_episode_steps
        if self.env.spec is not None:
            # Keep the spec in sync so copies made from it carry the same limit.
            self.env.spec.max_episode_steps = max_episode_steps
        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = None

    def step(self, action):
        """Steps through the environment and if the number of steps elapsed exceeds ``max_episode_steps`` then truncate.

        Args:
            action: The environment step action

        Returns:
            The environment step ``(observation, reward, terminated, truncated, info)`` with `truncated=True`
            if the number of steps elapsed >= max episode steps

        """
        observation, reward, terminated, truncated, info = self.env.step(action)
        self._elapsed_steps += 1

        # NOTE(review): if neither an explicit limit nor a spec limit exists,
        # `_max_episode_steps` is None and this comparison raises — confirm callers
        # always provide one of the two.
        if self._elapsed_steps >= self._max_episode_steps:
            truncated = True

        return observation, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Resets the environment with :param:`**kwargs` and sets the number of steps elapsed to zero.

        Args:
            **kwargs: The kwargs to reset the environment with

        Returns:
            The reset environment
        """
        self._elapsed_steps = 0
        return self.env.reset(**kwargs)
69 |
--------------------------------------------------------------------------------
/gym/wrappers/transform_observation.py:
--------------------------------------------------------------------------------
1 | """Wrapper for transforming observations."""
2 | from typing import Any, Callable
3 |
4 | import gym
5 |
6 |
class TransformObservation(gym.ObservationWrapper):
    """Transform the observation via an arbitrary function :attr:`f`.

    The callable :attr:`f` should accept values from the base environment's
    observation space and, ideally, produce values within that same space.

    If your transformation maps observations into a *different* space, subclass
    :class:`ObservationWrapper` directly, implement the transform there, and set
    the new observation space accordingly; using this wrapper in that situation
    would leave the observation space set incorrectly.

    Example:
        >>> import gym
        >>> import numpy as np
        >>> env = gym.make('CartPole-v1')
        >>> env = TransformObservation(env, lambda obs: obs + 0.1*np.random.randn(*obs.shape))
        >>> env.reset()
        array([-0.08319338,  0.04635121, -0.07394746,  0.20877492])
    """

    def __init__(self, env: gym.Env, f: Callable[[Any], Any]):
        """Initialize the wrapper with an environment and an observation transform.

        Args:
            env: The environment to apply the wrapper
            f: A function that transforms the observation
        """
        super().__init__(env)
        assert callable(f)
        self.f = f

    def observation(self, observation):
        """Apply the callable :attr:`f` to the observation.

        Args:
            observation: The observation to transform

        Returns:
            The transformed observation
        """
        transformed = self.f(observation)
        return transformed
44 |
--------------------------------------------------------------------------------
/gym/wrappers/transform_reward.py:
--------------------------------------------------------------------------------
1 | """Wrapper for transforming the reward."""
2 | from typing import Callable
3 |
4 | import gym
5 | from gym import RewardWrapper
6 |
7 |
class TransformReward(RewardWrapper):
    """Transform the reward via an arbitrary function.

    Warning:
        If the base environment specifies a reward range which is not invariant
        under :attr:`f`, the :attr:`reward_range` of the wrapped environment
        will be incorrect.

    Example:
        >>> import gym
        >>> env = gym.make('CartPole-v1')
        >>> env = TransformReward(env, lambda r: 0.01*r)
        >>> env.reset()
        >>> observation, reward, terminated, truncated, info = env.step(env.action_space.sample())
        >>> reward
        0.01
    """

    def __init__(self, env: gym.Env, f: Callable[[float], float]):
        """Initialize the wrapper with an environment and a reward transform.

        Args:
            env: The environment to apply the wrapper
            f: A function that transforms the reward
        """
        super().__init__(env)
        assert callable(f)
        self.f = f

    def reward(self, reward):
        """Apply the callable :attr:`f` to the reward.

        Args:
            reward: The reward to transform

        Returns:
            The transformed reward
        """
        transformed = self.f(reward)
        return transformed
45 |
--------------------------------------------------------------------------------
/py.Dockerfile:
--------------------------------------------------------------------------------
# A Dockerfile that sets up a full Gym install with test dependencies
ARG PYTHON_VERSION
FROM python:$PYTHON_VERSION
# Fix: an ARG declared before FROM is only in scope for the FROM instruction
# itself; it must be re-declared after FROM so that ${PYTHON_VERSION} expands
# (non-empty) in the RUN instruction below.
ARG PYTHON_VERSION

SHELL ["/bin/bash", "-o", "pipefail", "-c"]

# System libraries for rendering (OpenGL/OSMesa), headless display (xvfb),
# video encoding (ffmpeg) and building mujoco-py (patchelf, cmake).
RUN apt-get -y update \
    && apt-get install --no-install-recommends -y \
    unzip \
    libglu1-mesa-dev \
    libgl1-mesa-dev \
    libosmesa6-dev \
    xvfb \
    patchelf \
    ffmpeg cmake \
    && apt-get autoremove -y \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/* \
    # Download mujoco
    && mkdir /root/.mujoco \
    && cd /root/.mujoco \
    && wget -qO- 'https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz' | tar -xzvf -

ENV LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/root/.mujoco/mujoco210/bin"

COPY . /usr/local/gym/
WORKDIR /usr/local/gym/

# Python 3.6 cannot install the full [testing] extra, so a reduced set is pinned there.
RUN if [ "python:${PYTHON_VERSION}" = "python:3.6.15" ] ; then pip install .[box2d,classic_control,toy_text,other] pytest=="7.0.1" --no-cache-dir; else pip install .[testing] --no-cache-dir; fi

ENTRYPOINT ["/usr/local/gym/bin/docker_entrypoint"]
32 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.pyright]
2 |
3 | include = [
4 | "gym/**",
5 | "tests/**"
6 | ]
7 |
8 | exclude = [
9 | "**/node_modules",
10 | "**/__pycache__",
11 | ]
12 |
13 | strict = [
14 |
15 | ]
16 |
17 | typeCheckingMode = "basic"
18 | pythonVersion = "3.6"
19 | pythonPlatform = "All"
20 | typeshedPath = "typeshed"
21 | enableTypeIgnoreComments = true
22 |
23 | # This is required as the CI pre-commit does not download the module (i.e. numpy, pygame, box2d)
24 | # Therefore, we have to ignore missing imports
25 | reportMissingImports = "none"
26 | # Some modules are missing type stubs, which is an issue when running pyright locally
27 | reportMissingTypeStubs = false
# The checks below are set to "none" to silence them; enabling them as warning or error produces many findings
29 | reportInvalidTypeVarUse = "none"
30 |
31 | # reportUnknownMemberType = "warning" # -> raises 6035 warnings
32 | # reportUnknownParameterType = "warning" # -> raises 1327 warnings
33 | # reportUnknownVariableType = "warning" # -> raises 2585 warnings
34 | # reportUnknownArgumentType = "warning" # -> raises 2104 warnings
35 | reportGeneralTypeIssues = "none" # -> commented out raises 489 errors
36 | reportUntypedFunctionDecorator = "none" # -> pytest.mark.parameterize issues
37 |
38 | reportPrivateUsage = "warning"
39 | reportUnboundVariable = "warning"
40 |
41 | [tool.pytest.ini_options]
42 | filterwarnings = ['ignore:.*step API.*:DeprecationWarning'] # TODO: to be removed when old step API is removed
43 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.18.0
2 | cloudpickle>=1.2.0
3 | importlib_metadata>=4.8.0; python_version < '3.10'
4 | gym_notices>=0.0.4
5 | dataclasses==0.8; python_version == '3.6'
6 | typing_extensions==4.3.0; python_version == '3.7'
7 | opencv-python>=3.0
8 | lz4>=3.1.0
9 | matplotlib>=3.0
10 | box2d-py==2.3.5
11 | pygame==2.1.0
12 | ale-py~=0.8.0
13 | mujoco==2.2.0
14 | mujoco_py<2.2,>=2.1
15 | imageio>=2.14.1
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
"""Setups the project."""
import itertools
import re

from setuptools import find_packages, setup

# Parse the version out of gym/version.py, asserting the file contains exactly
# one `VERSION = "x.y.z"` line and nothing else.
with open("gym/version.py") as file:
    full_version = file.read()
    assert (
        re.match(r'VERSION = "\d\.\d+\.\d+"\n', full_version).group(0) == full_version
    ), f"Unexpected version: {full_version}"
    VERSION = re.search(r"\d\.\d+\.\d+", full_version).group(0)

# Environment-specific dependencies.
extras = {
    "atari": ["ale-py~=0.8.0"],
    "accept-rom-license": ["autorom[accept-rom-license]~=0.4.2"],
    "box2d": ["box2d-py==2.3.5", "pygame==2.1.0", "swig==4.*"],
    "classic_control": ["pygame==2.1.0"],
    "mujoco_py": ["mujoco_py<2.2,>=2.1"],
    "mujoco": ["mujoco==2.2", "imageio>=2.14.1"],
    "toy_text": ["pygame==2.1.0"],
    "other": ["lz4>=3.1.0", "opencv-python>=3.0", "matplotlib>=3.0", "moviepy>=1.0.0"],
}

# Testing dependency groups.
# The "testing" extra is the union of all groups above except atari and its
# rom-license helper, plus pytest.
testing_group = set(extras.keys()) - {"accept-rom-license", "atari"}
extras["testing"] = list(
    set(itertools.chain.from_iterable(map(lambda group: extras[group], testing_group)))
) + ["pytest==7.0.1"]

# All dependency groups - accept rom license as requires user to run
all_groups = set(extras.keys()) - {"accept-rom-license"}
extras["all"] = list(
    set(itertools.chain.from_iterable(map(lambda group: extras[group], all_groups)))
)

# Uses the readme as the description on PyPI
# Only the text up to (but excluding) the second "##" header is kept.
with open("README.md") as fh:
    long_description = ""
    header_count = 0
    for line in fh:
        if line.startswith("##"):
            header_count += 1
        if header_count < 2:
            long_description += line
        else:
            break

setup(
    author="Gym Community",
    author_email="jkterry@umd.edu",
    classifiers=[
        # Python 3.6 is minimally supported (only with basic gym environments and API)
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
    ],
    description="Gym: A universal API for reinforcement learning environments",
    extras_require=extras,
    install_requires=[
        "numpy >= 1.18.0",
        "cloudpickle >= 1.2.0",
        "importlib_metadata >= 4.8.0; python_version < '3.10'",
        "gym_notices >= 0.0.4",
        "dataclasses == 0.8; python_version == '3.6'",
    ],
    license="MIT",
    long_description=long_description,
    long_description_content_type="text/markdown",
    name="gym",
    packages=[package for package in find_packages() if package.startswith("gym")],
    # Non-Python asset files that must ship inside the wheel.
    package_data={
        "gym": [
            "envs/mujoco/assets/*.xml",
            "envs/classic_control/assets/*.png",
            "envs/toy_text/font/*.ttf",
            "envs/toy_text/img/*.png",
            "py.typed",
        ]
    },
    python_requires=">=3.6",
    tests_require=extras["testing"],
    url="https://www.gymlibrary.dev/",
    version=VERSION,
    zip_safe=False,
)
91 |
--------------------------------------------------------------------------------
/test_requirements.txt:
--------------------------------------------------------------------------------
1 | box2d-py==2.3.5
2 | lz4>=3.1.0
3 | opencv-python>=3.0
4 | mujoco==2.2.0
5 | matplotlib>=3.0
6 | imageio>=2.14.1
7 | pygame==2.1.0
8 | mujoco_py<2.2,>=2.1
9 | pytest==7.0.1
10 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/tests/__init__.py
--------------------------------------------------------------------------------
/tests/envs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/tests/envs/__init__.py
--------------------------------------------------------------------------------
/tests/envs/test_spec.py:
--------------------------------------------------------------------------------
1 | """Tests that gym.spec works as expected."""
2 |
3 | import re
4 |
5 | import pytest
6 |
7 | import gym
8 |
9 |
def test_spec():
    """Looking up a registered id returns the registry's own spec object."""
    cartpole_spec = gym.spec("CartPole-v1")
    assert cartpole_spec.id == "CartPole-v1"
    assert cartpole_spec is gym.envs.registry["CartPole-v1"]
14 |
15 |
def test_spec_kwargs():
    """Kwargs passed to `make` are recorded on the resulting env's spec."""
    env = gym.make("FrozenLake-v1", map_name="8x8")
    assert env.spec.kwargs["map_name"] == "8x8"
20 |
21 |
def test_spec_missing_lookup():
    """Spec lookups for missing/deprecated versions raise the matching errors."""
    for versioned_id in ("Test1-v0", "Test1-v15", "Test1-v9", "Other1-v100"):
        gym.register(id=versioned_id, entry_point="no-entry-point")

    deprecated_message = re.escape(
        "Environment version v1 for `Test1` is deprecated. Please use `Test1-v15` instead."
    )
    with pytest.raises(gym.error.DeprecatedEnv, match=deprecated_message):
        gym.spec("Test1-v1")

    unknown_version_message = re.escape(
        "Environment version `v1000` for environment `Test1` doesn't exist. It provides versioned environments: [ `v0`, `v9`, `v15` ]."
    )
    with pytest.raises(gym.error.UnregisteredEnv, match=unknown_version_message):
        gym.spec("Test1-v1000")

    unknown_env_message = re.escape("Environment Unknown1 doesn't exist. ")
    with pytest.raises(gym.error.UnregisteredEnv, match=unknown_env_message):
        gym.spec("Unknown1-v1")
49 |
50 |
def test_spec_malformed_lookup():
    """A malformed env id raises `gym.error.Error` with a descriptive message."""
    malformed_message = re.escape(
        "Malformed environment ID: “Breakout-v0”.(Currently all IDs must be of the form [namespace/](env-name)-v(version). (namespace is optional))"
    )
    with pytest.raises(gym.error.Error, match=f"^{malformed_message}$"):
        gym.spec("“Breakout-v0”")
57 |
58 |
def test_spec_versioned_lookups():
    """Namespaced specs resolve valid versions and reject missing/deprecated ones."""
    gym.register("test/Test2-v5", "no-entry-point")

    missing_version_message = re.escape(
        "Environment version `v9` for environment `test/Test2` doesn't exist. It provides versioned environments: [ `v5` ]."
    )
    with pytest.raises(gym.error.VersionNotFound, match=missing_version_message):
        gym.spec("test/Test2-v9")

    deprecated_message = re.escape(
        "Environment version v4 for `test/Test2` is deprecated. Please use `test/Test2-v5` instead."
    )
    with pytest.raises(gym.error.DeprecatedEnv, match=deprecated_message):
        gym.spec("test/Test2-v4")

    assert gym.spec("test/Test2-v5") is not None
79 |
80 |
def test_spec_default_lookups():
    """An unversioned registration rejects versioned lookups but allows the default one."""
    gym.register("test/Test3", "no-entry-point")

    default_only_message = re.escape(
        "Environment version `v0` for environment `test/Test3` doesn't exist. It provides the default version test/Test3`."
    )
    with pytest.raises(gym.error.DeprecatedEnv, match=default_only_message):
        gym.spec("test/Test3-v0")

    assert gym.spec("test/Test3") is not None
93 |
--------------------------------------------------------------------------------
/tests/envs/utils.py:
--------------------------------------------------------------------------------
1 | """Finds all the specs that we can test with"""
2 | from typing import List, Optional
3 |
4 | import numpy as np
5 |
6 | import gym
7 | from gym import error, logger
8 | from gym.envs.registration import EnvSpec
9 |
10 |
def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]:
    """Tries to make the environment showing if it is possible.

    Warning the environments have no wrappers, including time limit and order enforcing.
    """
    # Only envs whose entry point lives inside gym.envs are used for testing;
    # anything registered by third-party code during the test run is skipped.
    if "gym.envs." not in env_spec.entry_point:
        return None
    try:
        return env_spec.make(disable_env_checker=True).unwrapped
    except (ImportError, error.DependencyNotInstalled) as e:
        logger.warn(f"Not testing {env_spec.id} due to error: {e}")
    return None
23 |
24 |
# Tries to make all environment to test with
all_testing_initialised_envs: List[Optional[gym.Env]] = [
    try_make_env(env_spec) for env_spec in gym.envs.registry.values()
]
all_testing_initialised_envs: List[gym.Env] = [
    env for env in all_testing_initialised_envs if env is not None
]

# All testing, mujoco and gym environment specs
all_testing_env_specs: List[EnvSpec] = [
    env.spec for env in all_testing_initialised_envs
]
mujoco_testing_env_specs: List[EnvSpec] = [
    env_spec
    for env_spec in all_testing_env_specs
    if "gym.envs.mujoco" in env_spec.entry_point
]
gym_testing_env_specs: List[EnvSpec] = [
    env_spec
    for env_spec in all_testing_env_specs
    if any(
        f"gym.envs.{ep}" in env_spec.entry_point
        for ep in ["box2d", "classic_control", "toy_text"]
    )
]
# TODO, add minimum testing env spec in testing
# Bug fix: the previous code tested env-id *strings* for membership in a list
# of EnvSpec objects, which is always False and left this list empty. Filter
# the spec list by spec id instead.
minimum_testing_env_specs: List[EnvSpec] = [
    env_spec
    for env_spec in all_testing_env_specs
    if env_spec.id
    in {
        "CartPole-v1",
        "MountainCarContinuous-v0",
        "LunarLander-v2",
        "LunarLanderContinuous-v2",
        "CarRacing-v2",
        "Blackjack-v1",
        "Reacher-v4",
    }
]
64 |
65 |
def assert_equals(a, b, prefix=None):
    """Assert equality of data structures `a` and `b`, recursing into containers.

    Args:
        a: first data structure
        b: second data structure
        prefix: prefix for failed assertion message for types and dicts

    Raises:
        AssertionError: if `a` and `b` differ in type, structure, or value
    """
    assert type(a) == type(b), f"{prefix}Differing types: {a} and {b}"
    if isinstance(a, dict):
        assert list(a.keys()) == list(b.keys()), f"{prefix}Key sets differ: {a} and {b}"

        for k in a.keys():
            v_a = a[k]
            v_b = b[k]
            # Propagate the prefix so nested failures remain attributable
            # (previously dropped, losing the failure context).
            assert_equals(v_a, v_b, prefix)
    elif isinstance(a, np.ndarray):
        np.testing.assert_array_equal(a, b)
    elif isinstance(a, tuple):
        # zip alone would silently ignore trailing elements of the longer
        # tuple, so check lengths explicitly first.
        assert len(a) == len(b), f"{prefix}Differing lengths: {a} and {b}"
        for elem_from_a, elem_from_b in zip(a, b):
            assert_equals(elem_from_a, elem_from_b, prefix)
    else:
        assert a == b
--------------------------------------------------------------------------------
/tests/envs/utils_envs.py:
--------------------------------------------------------------------------------
1 | import gym
2 |
3 |
class RegisterDuringMakeEnv(gym.Env):
    """Used in `test_registration.py` to check if `env.make` can import and register an env"""

    def __init__(self):
        # Minimal spaces only; the env is never stepped in the tests.
        self.observation_space = gym.spaces.Discrete(1)
        self.action_space = gym.spaces.Discrete(1)
10 |
11 |
class ArgumentEnv(gym.Env):
    """Env that records its constructor arguments, for testing kwargs forwarding."""

    observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
    action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))

    def __init__(self, arg1, arg2, arg3):
        """Store each argument so tests can assert what `make` passed through."""
        self.arg3 = arg3
        self.arg2 = arg2
        self.arg1 = arg1
20 |
21 |
22 | # Environments to test render_mode
class NoHuman(gym.Env):
    """Environment that does not have human-rendering."""

    metadata = {"render_modes": ["rgb_array_list"], "render_fps": 4}

    def __init__(self, render_mode=None):
        """Accept only the advertised (non-human) render modes."""
        assert render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode
31 |
32 |
class NoHumanOldAPI(gym.Env):
    """Environment that does not have human-rendering."""

    # Advertises only rgb_array_list, so "human" mode is unavailable.
    metadata = {"render_modes": ["rgb_array_list"], "render_fps": 4}

    def __init__(self):
        # Deliberately takes no `render_mode` argument, mimicking the old render API.
        pass
40 |
41 |
class NoHumanNoRGB(gym.Env):
    """Environment that has neither human- nor rgb-rendering"""

    metadata = {"render_modes": ["ascii"], "render_fps": 4}

    def __init__(self, render_mode=None):
        """Accept only the advertised ascii render mode (or None)."""
        assert render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode
50 |
--------------------------------------------------------------------------------
/tests/spaces/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/tests/spaces/__init__.py
--------------------------------------------------------------------------------
/tests/spaces/test_discrete.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from gym.spaces import Discrete
4 |
5 |
def test_space_legacy_pickling():
    """Test the legacy pickle of Discrete that is missing the `start` parameter."""
    legacy_state = {
        "shape": (1, 2, 3),
        "dtype": np.int64,
        "np_random": np.random.default_rng(),
        "n": 3,
    }
    space = Discrete(1)
    space.__setstate__(legacy_state)

    assert space.shape == legacy_state["shape"]
    assert space.np_random == legacy_state["np_random"]
    assert space.n == 3
    assert space.dtype == legacy_state["dtype"]

    # Simulate a pickle written before `start` existed by removing the attribute.
    assert "start" in space.__dict__
    del space.__dict__["start"]
    assert "start" not in space.__dict__

    # Restoring such a legacy state must default `start` to 0.
    space.__setstate__(legacy_state)
    assert space.start == 0
33 |
34 |
def test_sample_mask():
    """Masked Discrete sampling restricts results to unmasked actions (offset by start)."""
    masked_space = Discrete(4, start=2)
    # Unmasked samples stay within [start, start + n).
    assert 2 <= masked_space.sample() < 6
    # A single unmasked entry forces that action.
    assert masked_space.sample(mask=np.array([0, 1, 0, 0], dtype=np.int8)) == 3
    # An all-zero mask falls back to the start value.
    assert masked_space.sample(mask=np.array([0, 0, 0, 0], dtype=np.int8)) == 2
    # Multiple unmasked entries allow any of them.
    assert masked_space.sample(mask=np.array([0, 1, 0, 1], dtype=np.int8)) in [3, 5]
41 |
--------------------------------------------------------------------------------
/tests/spaces/test_multibinary.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from gym.spaces import MultiBinary
4 |
5 |
def test_sample():
    """MultiBinary mask semantics: 0/1 force the value, 2 leaves it random."""
    space = MultiBinary(4)

    forced = space.sample(mask=np.array([0, 0, 1, 1], dtype=np.int8))
    assert np.all(forced == [0, 0, 1, 1])

    partly_random = space.sample(mask=np.array([0, 1, 2, 2], dtype=np.int8))
    assert partly_random[0] == 0 and partly_random[1] == 1
    assert partly_random[2] == 0 or partly_random[2] == 1
    assert partly_random[3] == 0 or partly_random[3] == 1

    # Masks also apply elementwise to multi-dimensional MultiBinary spaces.
    space = MultiBinary(np.array([2, 3]))
    forced_2d = space.sample(mask=np.array([[0, 0, 0], [1, 1, 1]], dtype=np.int8))
    assert np.all(forced_2d == [[0, 0, 0], [1, 1, 1]]), forced_2d
20 |
--------------------------------------------------------------------------------
/tests/spaces/test_multidiscrete.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from gym.spaces import Discrete, MultiDiscrete
4 | from gym.utils.env_checker import data_equivalence
5 |
6 |
def test_multidiscrete_as_tuple():
    """Indexing and slicing a MultiDiscrete behaves like a tuple of sub-spaces."""
    # 1D multi-discrete
    space_1d = MultiDiscrete([3, 4, 5])

    assert space_1d.shape == (3,)
    assert space_1d[0] == Discrete(3)
    assert space_1d[0:1] == MultiDiscrete([3])
    assert space_1d[0:2] == MultiDiscrete([3, 4])
    # Full slices are equal copies, not the same object.
    assert space_1d[:] == space_1d and space_1d[:] is not space_1d

    # 2D multi-discrete
    space_2d = MultiDiscrete([[3, 4, 5], [6, 7, 8]])

    assert space_2d.shape == (2, 3)
    assert space_2d[0, 1] == Discrete(4)
    assert space_2d[0] == MultiDiscrete([3, 4, 5])
    assert space_2d[0:1] == MultiDiscrete([[3, 4, 5]])
    assert space_2d[0:2, :] == MultiDiscrete([[3, 4, 5], [6, 7, 8]])
    assert space_2d[:, 0:1] == MultiDiscrete([[3], [6]])
    assert space_2d[0:2, 0:2] == MultiDiscrete([[3, 4], [6, 7]])
    assert space_2d[:] == space_2d and space_2d[:] is not space_2d
    assert space_2d[:, :] == space_2d and space_2d[:, :] is not space_2d
29 |
30 |
def test_multidiscrete_subspace_reproducibility():
    """Sub-spaces obtained by indexing share the parent's seeded RNG stream."""
    # 1D multi-discrete
    space_1d = MultiDiscrete([100, 200, 300])
    space_1d.seed()

    assert data_equivalence(space_1d[0].sample(), space_1d[0].sample())
    assert data_equivalence(space_1d[0:1].sample(), space_1d[0:1].sample())
    assert data_equivalence(space_1d[0:2].sample(), space_1d[0:2].sample())
    assert data_equivalence(space_1d[:].sample(), space_1d[:].sample())
    assert data_equivalence(space_1d[:].sample(), space_1d.sample())

    # 2D multi-discrete
    space_2d = MultiDiscrete([[300, 400, 500], [600, 700, 800]])
    space_2d.seed()

    assert data_equivalence(space_2d[0, 1].sample(), space_2d[0, 1].sample())
    assert data_equivalence(space_2d[0].sample(), space_2d[0].sample())
    assert data_equivalence(space_2d[0:1].sample(), space_2d[0:1].sample())
    assert data_equivalence(space_2d[0:2, :].sample(), space_2d[0:2, :].sample())
    assert data_equivalence(space_2d[:, 0:1].sample(), space_2d[:, 0:1].sample())
    assert data_equivalence(space_2d[0:2, 0:2].sample(), space_2d[0:2, 0:2].sample())
    assert data_equivalence(space_2d[:].sample(), space_2d[:].sample())
    assert data_equivalence(space_2d[:, :].sample(), space_2d[:, :].sample())
    assert data_equivalence(space_2d[:, :].sample(), space_2d.sample())
55 |
56 |
def test_multidiscrete_length():
    """len() returns the leading dimension, warning for multi-dimensional nvec."""
    flat_space = MultiDiscrete(nvec=[3, 2, 4])
    assert len(flat_space) == 3

    nested_space = MultiDiscrete(nvec=[[2, 3], [3, 2]])
    with pytest.warns(
        UserWarning,
        match="Getting the length of a multi-dimensional MultiDiscrete space.",
    ):
        assert len(nested_space) == 2
67 |
--------------------------------------------------------------------------------
/tests/spaces/test_sequence.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import numpy as np
4 | import pytest
5 |
6 | import gym.spaces
7 |
8 |
def test_sample():
    """Tests the sequence sampling works as expects and the errors are correctly raised."""
    space = gym.spaces.Sequence(gym.spaces.Box(0, 1))

    # Test integer mask length
    for length in range(4):
        sample = space.sample(mask=(length, None))
        assert sample in space
        assert len(sample) == length

    # A negative length is rejected outright.
    with pytest.raises(
        AssertionError,
        match=re.escape(
            "Expects the length mask to be greater than or equal to zero, actual value: -1"
        ),
    ):
        space.sample(mask=(-1, None))

    # Test np.array mask length
    # A singleton array fixes the length exactly.
    sample = space.sample(mask=(np.array([5]), None))
    assert sample in space
    assert len(sample) == 5

    # A multi-element array means "choose one of these lengths".
    sample = space.sample(mask=(np.array([3, 4, 5]), None))
    assert sample in space
    assert len(sample) in [3, 4, 5]

    # Length masks must be 1-dimensional.
    with pytest.raises(
        AssertionError,
        match=re.escape(
            "Expects the shape of the length mask to be 1-dimensional, actual shape: (2, 2)"
        ),
    ):
        space.sample(mask=(np.array([[2, 2], [2, 2]]), None))

    # Every candidate length must be non-negative.
    with pytest.raises(
        AssertionError,
        match=re.escape(
            "Expects all values in the length_mask to be greater than or equal to zero, actual values: [ 1 2 -1]"
        ),
    ):
        space.sample(mask=(np.array([1, 2, -1]), None))

    # Test with an invalid length
    with pytest.raises(
        TypeError,
        match=re.escape(
            "Expects the type of length_mask to an integer or a np.ndarray, actual type: "
        ),
    ):
        space.sample(mask=("abc", None))
60 |
--------------------------------------------------------------------------------
/tests/spaces/test_space.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import pytest
4 |
5 | from gym import Space
6 | from gym.spaces import utils
7 |
# A bare Space instance: the abstract base with no method implementations.
TESTING_SPACE = Space()


# Each callable below hits an abstract operation on the base Space, so every
# one of them is expected to raise NotImplementedError.
@pytest.mark.parametrize(
    "func",
    [
        TESTING_SPACE.sample,
        partial(TESTING_SPACE.contains, None),
        partial(utils.flatdim, TESTING_SPACE),
        partial(utils.flatten, TESTING_SPACE, None),
        partial(utils.flatten_space, TESTING_SPACE),
        partial(utils.unflatten, TESTING_SPACE, None),
    ],
)
def test_not_implemented_errors(func):
    """The abstract base Space raises NotImplementedError for all operations."""
    with pytest.raises(NotImplementedError):
        func()
25 |
--------------------------------------------------------------------------------
/tests/spaces/test_text.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import numpy as np
4 | import pytest
5 |
6 | from gym.spaces import Text
7 |
8 |
def test_sample_mask():
    """Text sampling honours (length, character) masks."""
    text_space = Text(min_length=1, max_length=5)

    # An integer length mask fixes the sample length exactly.
    sampled = text_space.sample(mask=(3, None))
    assert sampled in text_space
    assert len(sampled) == 3

    # Without a mask the length may be anything in [min_length, max_length].
    sampled = text_space.sample(mask=None)
    assert sampled in text_space
    assert 1 <= len(sampled) <= 5

    # An all-zero character mask conflicts with a positive minimum length.
    with pytest.raises(
        ValueError,
        match=re.escape(
            "Trying to sample with a minimum length > 0 (1) but the character mask is all zero meaning that no character could be sampled."
        ),
    ):
        text_space.sample(
            mask=(3, np.zeros(len(text_space.character_set), dtype=np.int8))
        )

    # With min_length=0, an all-zero character mask yields the empty string.
    text_space = Text(min_length=0, max_length=5)
    sampled = text_space.sample(
        mask=(None, np.zeros(len(text_space.character_set), dtype=np.int8))
    )
    assert sampled in text_space
    assert sampled == ""

    # A character mask with a single allowed character forces that character.
    text_space = Text(max_length=5, charset="abcd")

    sampled = text_space.sample(mask=(3, np.array([0, 1, 0, 0], dtype=np.int8)))
    assert sampled in text_space
    assert sampled == "bbb"
42 |
--------------------------------------------------------------------------------
/tests/spaces/test_tuple.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | import gym.spaces
5 | from gym.spaces import Box, Dict, Discrete, MultiBinary, Tuple
6 | from gym.utils.env_checker import data_equivalence
7 |
8 |
def test_sequence_inheritance():
    """The gym Tuple space inherits from abc.Sequences, this test checks all functions work"""
    sub_spaces = [Discrete(5), Discrete(10), Discrete(5)]
    tuple_space = Tuple(sub_spaces)

    assert len(tuple_space) == len(sub_spaces)

    # Indexing and iteration are forwarded to the underlying subspaces.
    assert all(tuple_space[idx] == sub_spaces[idx] for idx in range(len(tuple_space)))
    assert all(subspace in sub_spaces for subspace in tuple_space)

    # `count` compares subspaces by equality.
    assert tuple_space.count(Discrete(5)) == 2
    assert tuple_space.count(Discrete(6)) == 0
    assert tuple_space.count(MultiBinary(2)) == 0

    # `index` supports an optional start position.
    assert tuple_space.index(Discrete(5)) == 0
    assert tuple_space.index(Discrete(5), 1) == 2

    # Searching a restricted range that excludes the subspace raises ValueError.
    with pytest.raises(ValueError):
        tuple_space.index(Discrete(10), 0, 1)
    # Out-of-range integer indexing raises IndexError.
    with pytest.raises(IndexError):
        assert tuple_space[4]
37 |
38 |
@pytest.mark.parametrize(
    "space, seed, expected_len",
    [
        (Tuple([Discrete(5), Discrete(4)]), None, 2),
        (Tuple([Discrete(5), Discrete(4)]), 123, 3),
        (Tuple([Discrete(5), Discrete(4)]), (123, 456), 2),
        (
            Tuple(
                (Discrete(5), Tuple((Box(low=0.0, high=1.0, shape=(3,)), Discrete(2))))
            ),
            (123, (456, 789)),
            3,
        ),
        (
            Tuple(
                (
                    Discrete(3),
                    Dict(position=Box(low=0.0, high=1.0), velocity=Discrete(2)),
                )
            ),
            (123, {"position": 456, "velocity": 789}),
            3,
        ),
    ],
)
def test_seeds(space, seed, expected_len):
    """Seeding a Tuple space returns one integer sub-seed per (nested) component,
    and re-seeding with the same value reproduces the same seeds and samples."""
    seeds = space.seed(seed)
    assert isinstance(seeds, list) and all(isinstance(elem, int) for elem in seeds)
    assert len(seeds) == expected_len

    sample1 = space.sample()

    seeds2 = space.seed(seed)
    sample2 = space.sample()

    # Bug fix: the boolean results of `data_equivalence` were previously
    # discarded, which made these reproducibility checks vacuous.
    assert data_equivalence(seeds, seeds2)
    assert data_equivalence(sample1, sample2)
76 |
77 |
@pytest.mark.parametrize(
    "space_fn",
    [
        lambda: Tuple(["abc"]),
        lambda: Tuple([gym.spaces.Box(0, 1), "abc"]),
        lambda: Tuple("abc"),
    ],
)
def test_bad_space_calls(space_fn):
    """Constructing a Tuple from elements that are not gym spaces raises an AssertionError."""
    with pytest.raises(AssertionError):
        space_fn()
89 |
90 |
def test_contains_promotion():
    """Checks that `in` accepts samples whose shapes are compatible with the subspaces."""
    space = gym.spaces.Tuple((gym.spaces.Box(0, 1), gym.spaces.Box(-1, 0, (2,))))

    element = (
        np.array([0.0], dtype=np.float32),
        np.array([0.0, 0.0], dtype=np.float32),
    )
    assert element in space

    # A (2, 1) array is accepted for a tuple of two (1,)-shaped Box spaces.
    space = gym.spaces.Tuple((gym.spaces.Box(0, 1), gym.spaces.Box(-1, 0, (1,))))
    assert np.array([[0.0], [0.0]], dtype=np.float32) in space
101 |
102 |
def test_bad_seed():
    """A float seed is rejected by Tuple.seed with a TypeError."""
    space = gym.spaces.Tuple((gym.spaces.Box(0, 1), gym.spaces.Box(0, 1)))
    expected_message = "Expected seed type: list, tuple, int or None, actual type: "
    with pytest.raises(TypeError, match=expected_message):
        space.seed(0.0)
110 |
--------------------------------------------------------------------------------
/tests/spaces/utils.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 |
5 | from gym.spaces import (
6 | Box,
7 | Dict,
8 | Discrete,
9 | Graph,
10 | MultiBinary,
11 | MultiDiscrete,
12 | Sequence,
13 | Space,
14 | Text,
15 | Tuple,
16 | )
17 |
# Fundamental (non-composite) spaces exercised across the space test-suite.
TESTING_FUNDAMENTAL_SPACES = [
    Discrete(3),
    Discrete(3, start=-1),
    Box(low=0.0, high=1.0),
    Box(low=0.0, high=np.inf, shape=(2, 2)),
    Box(low=np.array([-10.0, 0.0]), high=np.array([10.0, 10.0]), dtype=np.float64),
    Box(low=-np.inf, high=0.0, shape=(2, 1)),
    Box(low=0.0, high=np.inf, shape=(2, 1)),
    MultiDiscrete([2, 2]),
    MultiDiscrete([[2, 3], [3, 2]]),
    MultiBinary(8),
    MultiBinary([2, 3]),
    Text(6),
    Text(min_length=3, max_length=6),
    Text(6, charset="abcdef"),
]
# Human-readable ids so pytest parametrization names each space.
TESTING_FUNDAMENTAL_SPACES_IDS = [f"{space}" for space in TESTING_FUNDAMENTAL_SPACES]


# Composite spaces (Tuple / Dict / Graph / Sequence), including nested combinations.
TESTING_COMPOSITE_SPACES = [
    # Tuple spaces
    Tuple([Discrete(5), Discrete(4)]),
    Tuple(
        (
            Discrete(5),
            Box(
                low=np.array([0.0, 0.0]),
                high=np.array([1.0, 5.0]),
                dtype=np.float64,
            ),
        )
    ),
    Tuple((Discrete(5), Tuple((Box(low=0.0, high=1.0, shape=(3,)), Discrete(2))))),
    Tuple((Discrete(3), Dict(position=Box(low=0.0, high=1.0), velocity=Discrete(2)))),
    Tuple((Graph(node_space=Box(-1, 1, shape=(2, 1)), edge_space=None), Discrete(2))),
    # Dict spaces
    Dict(
        {
            "position": Discrete(5),
            "velocity": Box(
                low=np.array([0.0, 0.0]),
                high=np.array([1.0, 5.0]),
                dtype=np.float64,
            ),
        }
    ),
    Dict(
        position=Discrete(6),
        velocity=Box(
            low=np.array([0.0, 0.0]),
            high=np.array([1.0, 5.0]),
            dtype=np.float64,
        ),
    ),
    Dict(
        {
            "a": Box(low=0, high=1, shape=(3, 3)),
            "b": Dict(
                {
                    "b_1": Box(low=-100, high=100, shape=(2,)),
                    "b_2": Box(low=-1, high=1, shape=(2,)),
                }
            ),
            "c": Discrete(4),
        }
    ),
    Dict(
        a=Dict(
            a=Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=None),
            b=Box(-100, 100, shape=(2, 2)),
        ),
        b=Tuple((Box(-100, 100, shape=(2,)), Box(-100, 100, shape=(2,)))),
    ),
    # Graph spaces
    Graph(node_space=Box(low=-100, high=100, shape=(3, 4)), edge_space=Discrete(5)),
    Graph(node_space=Discrete(5), edge_space=Box(low=-100, high=100, shape=(3, 4))),
    Graph(node_space=Discrete(3), edge_space=Discrete(4)),
    # Sequence spaces
    Sequence(Discrete(4)),
    Sequence(Dict({"feature": Box(0, 1, (3,))})),
    Sequence(Graph(node_space=Box(-100, 100, shape=(2, 2)), edge_space=Discrete(4))),
]
TESTING_COMPOSITE_SPACES_IDS = [f"{space}" for space in TESTING_COMPOSITE_SPACES]

# The full set of testing spaces with matching parametrization ids.
TESTING_SPACES: List[Space] = TESTING_FUNDAMENTAL_SPACES + TESTING_COMPOSITE_SPACES
TESTING_SPACES_IDS = TESTING_FUNDAMENTAL_SPACES_IDS + TESTING_COMPOSITE_SPACES_IDS
104 |
--------------------------------------------------------------------------------
/tests/testing_env.py:
--------------------------------------------------------------------------------
1 | """Provides a generic testing environment for use in tests with custom reset, step and render functions."""
2 | import types
3 | from typing import Any, Dict, Optional, Tuple, Union
4 |
5 | import gym
6 | from gym import spaces
7 | from gym.core import ActType, ObsType
8 | from gym.envs.registration import EnvSpec
9 |
10 |
def basic_reset_fn(
    self,
    *,
    seed: Optional[int] = None,
    options: Optional[dict] = None,
) -> Union[ObsType, Tuple[ObsType, dict]]:
    """A basic reset function that will pass the environment check using random actions from the observation space."""
    # Seed the base Env PRNG first, then the observation space, before sampling.
    super(GenericTestEnv, self).reset(seed=seed)
    self.observation_space.seed(seed)
    observation = self.observation_space.sample()
    # The reset options are echoed back through the info dict.
    return observation, {"options": options}
21 |
22 |
def new_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
    """A step function that follows the new step api that will pass the environment check using random actions from the observation space."""
    observation = self.observation_space.sample()
    # (obs, reward, terminated, truncated, info)
    return observation, 0, False, False, {}
26 |
27 |
def old_step_fn(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
    """A step function that follows the old step api that will pass the environment check using random actions from the observation space."""
    observation = self.observation_space.sample()
    # Legacy 4-tuple: (obs, reward, done, info)
    return observation, 0, False, {}
31 |
32 |
def basic_render_fn(self):
    """Basic render function that performs no rendering and returns None."""
    return None
36 |
37 |
38 | # todo: change all testing environment to this generic class
39 | class GenericTestEnv(gym.Env):
40 | """A generic testing environment for use in testing with modified environments are required."""
41 |
42 | def __init__(
43 | self,
44 | action_space: spaces.Space = spaces.Box(0, 1, (1,)),
45 | observation_space: spaces.Space = spaces.Box(0, 1, (1,)),
46 | reset_fn: callable = basic_reset_fn,
47 | step_fn: callable = new_step_fn,
48 | render_fn: callable = basic_render_fn,
49 | metadata: Optional[Dict[str, Any]] = None,
50 | render_mode: Optional[str] = None,
51 | spec: EnvSpec = EnvSpec("TestingEnv-v0", "testing-env-no-entry-point"),
52 | ):
53 | self.metadata = {} if metadata is None else metadata
54 | self.render_mode = render_mode
55 | self.spec = spec
56 |
57 | if observation_space is not None:
58 | self.observation_space = observation_space
59 | if action_space is not None:
60 | self.action_space = action_space
61 |
62 | if reset_fn is not None:
63 | self.reset = types.MethodType(reset_fn, self)
64 | if step_fn is not None:
65 | self.step = types.MethodType(step_fn, self)
66 | if render_fn is not None:
67 | self.render = types.MethodType(render_fn, self)
68 |
69 | def reset(
70 | self,
71 | *,
72 | seed: Optional[int] = None,
73 | options: Optional[dict] = None,
74 | ) -> Union[ObsType, Tuple[ObsType, dict]]:
75 | # If you need a default working reset function, use `basic_reset_fn` above
76 | raise NotImplementedError("TestingEnv reset_fn is not set.")
77 |
78 | def step(self, action: ActType) -> Tuple[ObsType, float, bool, dict]:
79 | raise NotImplementedError("TestingEnv step_fn is not set.")
80 |
81 | def render(self):
82 | raise NotImplementedError("testingEnv render_fn is not set.")
83 |
--------------------------------------------------------------------------------
/tests/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/tests/utils/__init__.py
--------------------------------------------------------------------------------
/tests/utils/test_seeding.py:
--------------------------------------------------------------------------------
1 | import pickle
2 |
3 | from gym import error
4 | from gym.utils import seeding
5 |
6 |
def test_invalid_seeds():
    """Negative or non-integer seeds must be rejected by `seeding.np_random`."""
    for seed in [-1, "test"]:
        raised = False
        try:
            seeding.np_random(seed)
        except error.Error:
            raised = True
        assert raised, f"Invalid seed {seed} passed validation"
15 |
16 |
def test_valid_seeds():
    """Valid integer seeds are returned unchanged alongside the generator."""
    for seed in (0, 1):
        _, used_seed = seeding.np_random(seed)
        assert used_seed == seed
21 |
22 |
def test_rng_pickle():
    """A seeded RNG survives a pickle round-trip with identical type and state."""
    rng, _ = seeding.np_random(seed=0)
    restored = pickle.loads(pickle.dumps(rng))
    assert isinstance(
        restored, seeding.RandomNumberGenerator
    ), "Unpickled object is not a RandomNumberGenerator"
    # Identical internal state => identical next draw.
    assert rng.random() == restored.random()
31 |
--------------------------------------------------------------------------------
/tests/vector/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/tests/vector/__init__.py
--------------------------------------------------------------------------------
/tests/vector/test_vector_env_info.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | import gym
5 | from gym.vector.sync_vector_env import SyncVectorEnv
6 | from tests.vector.utils import make_env
7 |
8 | ENV_ID = "CartPole-v1"
9 | NUM_ENVS = 3
10 | ENV_STEPS = 50
11 | SEED = 42
12 |
13 |
@pytest.mark.parametrize("asynchronous", [True, False])
def test_vector_env_info(asynchronous):
    """Checks that episode-ending steps expose `final_observation` entries in `info`."""
    env = gym.vector.make(
        ENV_ID, num_envs=NUM_ENVS, asynchronous=asynchronous, disable_env_checker=True
    )
    env.reset(seed=SEED)
    for _ in range(ENV_STEPS):
        # NOTE(review): re-seeding the action space every step makes each sampled
        # action identical — presumably intended for determinism; confirm.
        env.action_space.seed(SEED)
        action = env.action_space.sample()
        _, _, terminateds, truncateds, infos = env.step(action)
        if any(terminateds) or any(truncateds):
            # One entry per sub-env; `_final_observation` is the boolean mask
            # marking which sub-envs actually ended this step.
            assert len(infos["final_observation"]) == NUM_ENVS
            assert len(infos["_final_observation"]) == NUM_ENVS

            assert isinstance(infos["final_observation"], np.ndarray)
            assert isinstance(infos["_final_observation"], np.ndarray)

            for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
                if terminated or truncated:
                    assert infos["_final_observation"][i]
                else:
                    assert not infos["_final_observation"][i]
                    assert infos["final_observation"][i] is None
37 |
38 |
@pytest.mark.parametrize("concurrent_ends", [1, 2, 3])
def test_vector_env_info_concurrent_termination(concurrent_ends):
    """Checks the final-observation info mask when several sub-envs end on the same step."""
    # envs that need to terminate together will have the same action
    actions = [0] * concurrent_ends + [1] * (NUM_ENVS - concurrent_ends)
    envs = [make_env(ENV_ID, SEED) for _ in range(NUM_ENVS)]
    envs = SyncVectorEnv(envs)

    for _ in range(ENV_STEPS):
        _, _, terminateds, truncateds, infos = envs.step(actions)
        if any(terminateds) or any(truncateds):
            for i, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
                # The first `concurrent_ends` envs share an action and must end together.
                if i < concurrent_ends:
                    assert terminated or truncated
                    assert infos["_final_observation"][i]
                else:
                    assert not infos["_final_observation"][i]
                    assert infos["final_observation"][i] is None
            # Stop after the first step on which any sub-env has ended.
            return
57 |
--------------------------------------------------------------------------------
/tests/vector/test_vector_env_wrapper.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from gym.vector import VectorEnvWrapper, make
4 |
5 |
class DummyWrapper(VectorEnvWrapper):
    """Minimal VectorEnvWrapper that counts how many times a reset is initiated."""

    def __init__(self, env):
        self.env = env
        self.counter = 0

    def reset_async(self, **kwargs):
        # Bug fix: forward the reset kwargs (e.g. `seed`, `options`) to the
        # wrapped env instead of silently dropping them.
        super().reset_async(**kwargs)
        self.counter += 1
14 |
15 |
def test_vector_env_wrapper_inheritance():
    """An overridden `reset_async` on a wrapper is invoked by the inherited `reset`."""
    wrapped = DummyWrapper(make("FrozenLake-v1", asynchronous=False))
    wrapped.reset()
    assert wrapped.counter == 1
21 |
22 |
def test_vector_env_wrapper_attributes():
    """Test if `set_attr`, `call` methods for VecEnvWrapper get correctly forwarded to the vector env it is wrapping."""
    env = make("CartPole-v1", num_envs=3)
    wrapped = DummyWrapper(make("CartPole-v1", num_envs=3))

    assert np.allclose(wrapped.call("gravity"), env.call("gravity"))
    env.set_attr("gravity", [20.0, 20.0, 20.0])
    wrapped.set_attr("gravity", [20.0, 20.0, 20.0])
    assert np.allclose(wrapped.get_attr("gravity"), env.get_attr("gravity"))

    # Bug fix: close both vector envs so their worker resources are released
    # (`make` defaults to asynchronous, which spawns subprocesses).
    env.close()
    wrapped.env.close()
32 |
--------------------------------------------------------------------------------
/tests/vector/test_vector_make.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import gym
4 | from gym.vector import AsyncVectorEnv, SyncVectorEnv
5 | from gym.wrappers import OrderEnforcing, TimeLimit, TransformObservation
6 | from gym.wrappers.env_checker import PassiveEnvChecker
7 | from tests.wrappers.utils import has_wrapper
8 |
9 |
def test_vector_make_id():
    """`gym.vector.make` defaults to a single-env asynchronous vector env."""
    vec_env = gym.vector.make("CartPole-v1")
    assert isinstance(vec_env, AsyncVectorEnv)
    assert vec_env.num_envs == 1
    vec_env.close()
15 |
16 |
@pytest.mark.parametrize("num_envs", [1, 3, 10])
def test_vector_make_num_envs(num_envs):
    """`gym.vector.make` creates exactly the requested number of sub-environments."""
    env = gym.vector.make("CartPole-v1", num_envs=num_envs)
    assert env.num_envs == num_envs
    env.close()
22 |
23 |
def test_vector_make_asynchronous():
    """The `asynchronous` flag selects between Async and Sync vector env classes."""
    for asynchronous, expected_cls in ((True, AsyncVectorEnv), (False, SyncVectorEnv)):
        vec_env = gym.vector.make("CartPole-v1", asynchronous=asynchronous)
        assert isinstance(vec_env, expected_cls)
        vec_env.close()
32 |
33 |
def test_vector_make_wrappers():
    """Checks the default sub-env wrappers and the custom `wrappers` argument."""
    env = gym.vector.make("CartPole-v1", num_envs=2, asynchronous=False)
    assert isinstance(env, SyncVectorEnv)
    assert len(env.envs) == 2

    # Default wrappers mirror what `gym.make` applies, driven by the env spec.
    sub_env = env.envs[0]
    assert isinstance(sub_env, gym.Env)
    if sub_env.spec.order_enforce:
        assert has_wrapper(sub_env, OrderEnforcing)
    if sub_env.spec.max_episode_steps is not None:
        assert has_wrapper(sub_env, TimeLimit)

    assert all(
        has_wrapper(sub_env, TransformObservation) is False for sub_env in env.envs
    )
    env.close()

    # The `wrappers` callable is applied to every sub-env.
    env = gym.vector.make(
        "CartPole-v1",
        num_envs=2,
        asynchronous=False,
        wrappers=lambda _env: TransformObservation(_env, lambda obs: obs * 2),
    )
    # As asynchronous environment are inaccessible, synchronous vector must be used
    assert isinstance(env, SyncVectorEnv)
    assert all(has_wrapper(sub_env, TransformObservation) for sub_env in env.envs)

    env.close()
62 |
63 |
def test_vector_make_disable_env_checker():
    """By default only the first sub-env is wrapped in `PassiveEnvChecker`;
    `disable_env_checker=True` removes it entirely."""
    # As asynchronous environment are inaccessible, synchronous vector must be used
    env = gym.vector.make("CartPole-v1", num_envs=1, asynchronous=False)
    assert isinstance(env, SyncVectorEnv)
    assert has_wrapper(env.envs[0], PassiveEnvChecker)
    env.close()

    # With several sub-envs, only the first one carries the checker.
    env = gym.vector.make("CartPole-v1", num_envs=5, asynchronous=False)
    assert isinstance(env, SyncVectorEnv)
    assert has_wrapper(env.envs[0], PassiveEnvChecker)
    assert all(
        has_wrapper(env.envs[i], PassiveEnvChecker) is False for i in [1, 2, 3, 4]
    )
    env.close()

    env = gym.vector.make(
        "CartPole-v1", num_envs=3, asynchronous=False, disable_env_checker=True
    )
    assert isinstance(env, SyncVectorEnv)
    assert all(has_wrapper(sub_env, PassiveEnvChecker) is False for sub_env in env.envs)
    env.close()
85 |
--------------------------------------------------------------------------------
/tests/wrappers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/openai/gym/dcd185843a62953e27c2d54dc8c2d647d604b635/tests/wrappers/__init__.py
--------------------------------------------------------------------------------
/tests/wrappers/test_clip_action.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import gym
4 | from gym.wrappers import ClipAction
5 |
6 |
def test_clip_action():
    """`ClipAction` must behave exactly like manually clipping actions to the bounds."""
    # mountaincar: action-based rewards
    base_env = gym.make("MountainCarContinuous-v0", disable_env_checker=True)
    wrapped_env = ClipAction(
        gym.make("MountainCarContinuous-v0", disable_env_checker=True)
    )

    base_env.reset(seed=0)
    wrapped_env.reset(seed=0)

    low, high = base_env.action_space.low, base_env.action_space.high
    for action in ([0.4], [1.2], [-0.3], [0.0], [-2.5]):
        obs1, reward1, terminated1, truncated1, _ = base_env.step(
            np.clip(action, low, high)
        )
        obs2, reward2, terminated2, truncated2, _ = wrapped_env.step(action)
        assert np.allclose(reward1, reward2)
        assert np.allclose(obs1, obs2)
        assert terminated1 == terminated2
        assert truncated1 == truncated2
29 |
--------------------------------------------------------------------------------
/tests/wrappers/test_filter_observation.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import numpy as np
4 | import pytest
5 |
6 | import gym
7 | from gym import spaces
8 | from gym.wrappers.filter_observation import FilterObservation
9 |
10 |
class FakeEnvironment(gym.Env):
    """A minimal environment with a Dict observation space, used to test
    `FilterObservation`.

    Args:
        render_mode: stored on `self.render_mode`.
        observation_keys: keys of the Dict observation space.
    """

    def __init__(
        self, render_mode=None, observation_keys: Tuple[str, ...] = ("state",)
    ):
        self.observation_space = spaces.Dict(
            {
                name: spaces.Box(shape=(2,), low=-1, high=1, dtype=np.float32)
                for name in observation_keys
            }
        )
        self.action_space = spaces.Box(shape=(1,), low=-1, high=1, dtype=np.float32)
        self.render_mode = render_mode

    def render(self, mode="human"):
        # Returns a fixed black RGB frame.
        image_shape = (32, 32, 3)
        return np.zeros(image_shape, dtype=np.uint8)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        super().reset(seed=seed)
        observation = self.observation_space.sample()
        return observation, {}

    def step(self, action):
        del action
        observation = self.observation_space.sample()
        # Bug fix: return the new 5-tuple step API (obs, reward, terminated,
        # truncated, info) instead of the legacy 4-tuple, consistent with
        # `reset` above which already follows the new API.
        reward, terminated, truncated, info = 0.0, False, False, {}
        return observation, reward, terminated, truncated, info
38 |
39 |
# (observation_keys, filter_keys) pairs; filter_keys=None means "keep every key".
FILTER_OBSERVATION_TEST_CASES = (
    (("key1", "key2"), ("key1",)),
    (("key1", "key2"), ("key1", "key2")),
    (("key1",), None),
    (("key1",), ("key1",)),
)

# (filter_keys, expected error type, error message regex) triples.
ERROR_TEST_CASES = (
    ("key", ValueError, "All the filter_keys must be included..*"),
    (False, TypeError, "'bool' object is not iterable"),
    (1, TypeError, "'int' object is not iterable"),
)
52 |
53 |
class TestFilterObservation:
    """Tests for the `FilterObservation` wrapper."""

    @pytest.mark.parametrize(
        "observation_keys,filter_keys", FILTER_OBSERVATION_TEST_CASES
    )
    def test_filter_observation(self, observation_keys, filter_keys):
        """The wrapped space and observations contain exactly `filter_keys`."""
        env = FakeEnvironment(observation_keys=observation_keys)

        # Make sure we are testing the right environment for the test.
        observation_space = env.observation_space
        assert isinstance(observation_space, spaces.Dict)

        wrapped_env = FilterObservation(env, filter_keys=filter_keys)

        assert isinstance(wrapped_env.observation_space, spaces.Dict)

        # filter_keys=None keeps all of the original keys.
        if filter_keys is None:
            filter_keys = tuple(observation_keys)

        assert len(wrapped_env.observation_space.spaces) == len(filter_keys)
        assert tuple(wrapped_env.observation_space.spaces.keys()) == tuple(filter_keys)

        # Check that the added space item is consistent with the added observation.
        observation, info = wrapped_env.reset()
        assert len(observation) == len(filter_keys)
        assert isinstance(info, dict)

    @pytest.mark.parametrize("filter_keys,error_type,error_match", ERROR_TEST_CASES)
    def test_raises_with_incorrect_arguments(
        self, filter_keys, error_type, error_match
    ):
        """Invalid `filter_keys` raise the expected error at wrapper construction."""
        env = FakeEnvironment(observation_keys=("key1", "key2"))

        with pytest.raises(error_type, match=error_match):
            FilterObservation(env, filter_keys=filter_keys)
88 |
--------------------------------------------------------------------------------
/tests/wrappers/test_flatten.py:
--------------------------------------------------------------------------------
1 | """Tests for the flatten observation wrapper."""
2 |
3 | from collections import OrderedDict
4 | from typing import Optional
5 |
6 | import numpy as np
7 | import pytest
8 |
9 | import gym
10 | from gym.spaces import Box, Dict, flatten, unflatten
11 | from gym.wrappers import FlattenObservation
12 |
13 |
class FakeEnvironment(gym.Env):
    """Environment with a configurable observation space; `reset` stores the
    last sampled observation on `self.observation` for later comparison."""

    def __init__(self, observation_space):
        self.observation_space = observation_space

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        super().reset(seed=seed)
        # Keep the raw observation so tests can compare it with the flattened one.
        self.observation = self.observation_space.sample()
        return self.observation, {}
22 |
23 |
# (observation_space, ordered_values) pairs. `ordered_values` is True when each
# key's constant Box values increase in insertion order, so the flattened
# vector is expected to come out sorted.
OBSERVATION_SPACES = (
    (
        Dict(
            OrderedDict(
                [
                    ("key1", Box(shape=(2, 3), low=0, high=0, dtype=np.float32)),
                    ("key2", Box(shape=(), low=1, high=1, dtype=np.float32)),
                    ("key3", Box(shape=(2,), low=2, high=2, dtype=np.float32)),
                ]
            )
        ),
        True,
    ),
    (
        Dict(
            OrderedDict(
                [
                    ("key2", Box(shape=(), low=0, high=0, dtype=np.float32)),
                    ("key3", Box(shape=(2,), low=1, high=1, dtype=np.float32)),
                    ("key1", Box(shape=(2, 3), low=2, high=2, dtype=np.float32)),
                ]
            )
        ),
        True,
    ),
    (
        Dict(
            {
                "key1": Box(shape=(2, 3), low=-1, high=1, dtype=np.float32),
                "key2": Box(shape=(), low=-1, high=1, dtype=np.float32),
                "key3": Box(shape=(2,), low=-1, high=1, dtype=np.float32),
            }
        ),
        False,
    ),
)
60 |
61 |
class TestFlattenEnvironment:
    """Tests the FlattenObservation wrapper and the flatten/unflatten utilities."""

    @pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES)
    def test_flattened_environment(self, observation_space, ordered_values):
        """
        make sure that flattened observations occur in the order expected
        """
        env = FakeEnvironment(observation_space=observation_space)
        wrapped_env = FlattenObservation(env)
        flattened, info = wrapped_env.reset()

        # `env.observation` holds the raw (unflattened) sample from reset.
        unflattened = unflatten(env.observation_space, flattened)
        original = env.observation

        self._check_observations(original, flattened, unflattened, ordered_values)

    @pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES)
    def test_flatten_unflatten(self, observation_space, ordered_values):
        """
        test flatten and unflatten functions directly
        """
        original = observation_space.sample()

        flattened = flatten(observation_space, original)
        unflattened = unflatten(observation_space, flattened)

        self._check_observations(original, flattened, unflattened, ordered_values)

    def _check_observations(self, original, flattened, unflattened, ordered_values):
        """Asserts the flatten/unflatten round-trip and (optionally) value ordering."""
        # make sure that unflatten(flatten(original)) == original
        assert set(unflattened.keys()) == set(original.keys())
        for k, v in original.items():
            np.testing.assert_allclose(unflattened[k], v)

        if ordered_values:
            # make sure that the values were flattened in the order they appeared in the
            # OrderedDict
            np.testing.assert_allclose(sorted(flattened), flattened)
99 |
--------------------------------------------------------------------------------
/tests/wrappers/test_flatten_observation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | import gym
5 | from gym import spaces
6 | from gym.wrappers import FlattenObservation
7 |
8 |
@pytest.mark.parametrize("env_id", ["Blackjack-v1"])
def test_flatten_observation(env_id):
    """Flattening Blackjack's tuple observation yields a one-hot Box vector."""
    env = gym.make(env_id, disable_env_checker=True)
    wrapped_env = FlattenObservation(env)

    obs, info = env.reset()
    wrapped_obs, wrapped_obs_info = wrapped_env.reset()

    # Blackjack observations: (player sum, dealer card, usable ace).
    space = spaces.Tuple((spaces.Discrete(32), spaces.Discrete(11), spaces.Discrete(2)))
    # The flattened space one-hot encodes every Discrete component.
    wrapped_space = spaces.Box(0, 1, [32 + 11 + 2], dtype=np.int64)

    assert obs in space
    assert wrapped_obs in wrapped_space
    assert isinstance(info, dict)
    assert isinstance(wrapped_obs_info, dict)
24 |
--------------------------------------------------------------------------------
/tests/wrappers/test_frame_stack.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | import gym
5 | from gym.wrappers import FrameStack
6 |
7 | try:
8 | import lz4
9 | except ImportError:
10 | lz4 = None
11 |
12 |
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1", "CarRacing-v2"])
@pytest.mark.parametrize("num_stack", [2, 3, 4])
@pytest.mark.parametrize(
    "lz4_compress",
    [
        pytest.param(
            True,
            marks=pytest.mark.skipif(
                lz4 is None, reason="Need lz4 to run tests with compression"
            ),
        ),
        False,
    ],
)
def test_frame_stack(env_id, num_stack, lz4_compress):
    """A frame-stacked env tracks an unwrapped duplicate, with the newest frame last."""
    env = gym.make(env_id, disable_env_checker=True)
    shape = env.observation_space.shape
    env = FrameStack(env, num_stack, lz4_compress)
    # The stack adds a leading dimension of size `num_stack`.
    assert env.observation_space.shape == (num_stack,) + shape
    assert env.observation_space.dtype == env.env.observation_space.dtype

    # An identically seeded, unwrapped duplicate provides the expected frames.
    dup = gym.make(env_id, disable_env_checker=True)

    obs, _ = env.reset(seed=0)
    dup_obs, _ = dup.reset(seed=0)
    assert np.allclose(obs[-1], dup_obs)

    for _ in range(num_stack**2):
        action = env.action_space.sample()
        dup_obs, _, dup_terminated, dup_truncated, _ = dup.step(action)
        obs, _, terminated, truncated, _ = env.step(action)

        assert dup_terminated == terminated
        assert dup_truncated == truncated
        # The most recent frame in the stack equals the duplicate's observation.
        assert np.allclose(obs[-1], dup_obs)

        if terminated or truncated:
            break

    assert len(obs) == num_stack
53 |
--------------------------------------------------------------------------------
/tests/wrappers/test_gray_scale_observation.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import gym
4 | from gym import spaces
5 | from gym.wrappers import GrayScaleObservation
6 |
7 |
@pytest.mark.parametrize("env_id", ["CarRacing-v2"])
@pytest.mark.parametrize("keep_dim", [True, False])
def test_gray_scale_observation(env_id, keep_dim):
    """GrayScaleObservation collapses the RGB channel dim; `keep_dim` keeps it as size 1."""
    rgb_env = gym.make(env_id, disable_env_checker=True)

    # Pre-condition: the unwrapped env emits (H, W, 3) RGB observations.
    assert isinstance(rgb_env.observation_space, spaces.Box)
    assert len(rgb_env.observation_space.shape) == 3
    assert rgb_env.observation_space.shape[-1] == 3

    wrapped_env = GrayScaleObservation(rgb_env, keep_dim=keep_dim)
    assert isinstance(wrapped_env.observation_space, spaces.Box)
    if keep_dim:
        # (H, W, 1) with the channel dimension retained.
        assert len(wrapped_env.observation_space.shape) == 3
        assert wrapped_env.observation_space.shape[-1] == 1
    else:
        # (H, W) with the channel dimension dropped.
        assert len(wrapped_env.observation_space.shape) == 2

    wrapped_obs, info = wrapped_env.reset()
    assert wrapped_obs in wrapped_env.observation_space
27 |
--------------------------------------------------------------------------------
/tests/wrappers/test_human_rendering.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import pytest
4 |
5 | import gym
6 | from gym.wrappers import HumanRendering
7 |
8 |
def test_human_rendering():
    """HumanRendering converts rgb_array(-list) envs to human rendering; other modes are rejected."""
    for mode in ["rgb_array", "rgb_array_list"]:
        env = HumanRendering(
            gym.make("CartPole-v1", render_mode=mode, disable_env_checker=True)
        )
        # The wrapper reports "human" regardless of the inner render mode.
        assert env.render_mode == "human"
        env.reset()

        for _ in range(75):
            _, _, terminated, truncated, _ = env.step(env.action_space.sample())
            if terminated or truncated:
                env.reset()

        env.close()

    # Wrapping an env that already renders in "human" mode must fail.
    env = gym.make("CartPole-v1", render_mode="human")
    with pytest.raises(
        AssertionError,
        match=re.escape(
            "Expected env.render_mode to be one of 'rgb_array' or 'rgb_array_list' but got 'human'"
        ),
    ):
        HumanRendering(env)
    env.close()
33 |
--------------------------------------------------------------------------------
/tests/wrappers/test_order_enforcing.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import gym
4 | from gym.envs.classic_control import CartPoleEnv
5 | from gym.error import ResetNeeded
6 | from gym.wrappers import OrderEnforcing
7 | from tests.envs.utils import all_testing_env_specs
8 | from tests.wrappers.utils import has_wrapper
9 |
10 |
@pytest.mark.parametrize(
    "spec", all_testing_env_specs, ids=[spec.id for spec in all_testing_env_specs]
)
def test_gym_make_order_enforcing(spec):
    """Checks that gym.make wraps the environment with the OrderEnforcing wrapper."""
    env = gym.make(spec.id, disable_env_checker=True)

    assert has_wrapper(env, OrderEnforcing)
19 |
20 |
def test_order_enforcing():
    """Checks that the order enforcing works as expected, raising an error before reset is called and not after."""
    # The reason for not using gym.make is that all environments are by default wrapped in the order enforcing wrapper
    env = CartPoleEnv(render_mode="rgb_array_list")
    assert not has_wrapper(env, OrderEnforcing)

    # Assert that the order enforcing works for step and render before reset
    order_enforced_env = OrderEnforcing(env)
    assert order_enforced_env.has_reset is False
    with pytest.raises(ResetNeeded):
        order_enforced_env.step(0)
    with pytest.raises(ResetNeeded):
        order_enforced_env.render()
    # Failed step/render attempts must not mark the env as reset.
    assert order_enforced_env.has_reset is False

    # Assert that the Assertion errors are not raised after reset
    order_enforced_env.reset()
    assert order_enforced_env.has_reset is True
    order_enforced_env.step(0)
    order_enforced_env.render()

    # Assert that with disable_render_order_enforcing works, the environment has already been reset
    env = CartPoleEnv(render_mode="rgb_array_list")
    env = OrderEnforcing(env, disable_render_order_enforcing=True)
    env.render()  # no assertion error
46 |
--------------------------------------------------------------------------------
/tests/wrappers/test_passive_env_checker.py:
--------------------------------------------------------------------------------
1 | import re
2 | import warnings
3 |
4 | import numpy as np
5 | import pytest
6 |
7 | import gym
8 | from gym.wrappers.env_checker import PassiveEnvChecker
9 | from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING
10 | from tests.envs.utils import all_testing_initialised_envs
11 | from tests.testing_env import GenericTestEnv
12 |
13 |
@pytest.mark.parametrize(
    "env",
    all_testing_initialised_envs,
    ids=[env.spec.id for env in all_testing_initialised_envs],
)
def test_passive_checker_wrapper_warnings(env):
    """PassiveEnvChecker must only emit warnings from the known ignore-list."""
    with warnings.catch_warnings(record=True) as caught_warnings:
        checker_env = PassiveEnvChecker(env)
        checker_env.reset()
        checker_env.step(checker_env.action_space.sample())
        # todo, add check for render, bugged due to mujoco v2/3 and v4 envs
        checker_env.close()

    for caught in caught_warnings:
        # Any warning outside the ignore-list is a regression.
        if caught.message.args[0] not in PASSIVE_CHECK_IGNORE_WARNING:
            raise gym.error.Error(f"Unexpected warning: {caught.message}")
31 |
32 |
@pytest.mark.parametrize(
    "env, message",
    [
        (
            GenericTestEnv(action_space=None),
            "The environment must specify an action space. https://www.gymlibrary.dev/content/environment_creation/",
        ),
        (
            GenericTestEnv(action_space="error"),
            "action space does not inherit from `gym.spaces.Space`, actual type: ",
        ),
        (
            GenericTestEnv(observation_space=None),
            "The environment must specify an observation space. https://www.gymlibrary.dev/content/environment_creation/",
        ),
        (
            GenericTestEnv(observation_space="error"),
            "observation space does not inherit from `gym.spaces.Space`, actual type: ",
        ),
    ],
)
def test_initialise_failures(env, message):
    """Constructing PassiveEnvChecker on a malformed env raises immediately."""
    # Anchor the pattern so the assertion message must match exactly.
    pattern = f"^{re.escape(message)}$"
    with pytest.raises(AssertionError, match=pattern):
        PassiveEnvChecker(env)

    env.close()
59 |
60 |
61 | def _reset_failure(self, seed=None, options=None):
62 | return np.array([-1.0], dtype=np.float32), {}
63 |
64 |
65 | def _step_failure(self, action):
66 | return "error"
67 |
68 |
def test_api_failures():
    """A malformed env triggers the checker's reset, step and render checks."""
    env = PassiveEnvChecker(
        GenericTestEnv(
            reset_fn=_reset_failure,
            step_fn=_step_failure,
            metadata={"render_modes": "error"},
        )
    )
    # Nothing has been validated yet.
    assert env.checked_reset is False
    assert env.checked_step is False
    assert env.checked_render is False

    # reset returns an observation outside the observation space -> warning.
    with pytest.warns(
        UserWarning,
        match=re.escape(
            "The obs returned by the `reset()` method is not within the observation space"
        ),
    ):
        env.reset()
    assert env.checked_reset

    # step returns a plain string instead of a tuple -> hard assertion error.
    with pytest.raises(
        AssertionError,
        match="Expects step result to be a tuple, actual type: ",
    ):
        env.step(env.action_space.sample())
    assert env.checked_step

    # render_modes metadata is a string rather than a sequence -> warning.
    with pytest.warns(
        UserWarning,
        match=r"Expects the render_modes to be a sequence \(i\.e\. list, tuple\), actual type: ",
    ):
        env.render()
    assert env.checked_render

    env.close()
104 |
--------------------------------------------------------------------------------
/tests/wrappers/test_record_video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | import gym
5 | from gym.wrappers import capped_cubic_video_schedule
6 |
7 |
def test_record_video_using_default_trigger():
    """With no trigger given, RecordVideo uses the capped cubic schedule."""
    env = gym.wrappers.RecordVideo(
        gym.make("CartPole-v1", render_mode="rgb_array_list", disable_env_checker=True),
        "videos",
    )
    env.reset()
    for _ in range(199):
        _, _, terminated, truncated, _ = env.step(env.action_space.sample())
        if terminated or truncated:
            env.reset()
    env.close()

    assert os.path.isdir("videos")
    recorded = [name for name in os.listdir("videos") if name.endswith(".mp4")]
    # One video per episode index selected by the default schedule.
    expected_count = sum(
        capped_cubic_video_schedule(i) for i in range(env.episode_id + 1)
    )
    assert len(recorded) == expected_count
    shutil.rmtree("videos")
26 |
27 |
def test_record_video_reset():
    """reset() through RecordVideo still returns a valid (obs, info) pair."""
    env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
    env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
    space = env.observation_space

    observation, info = env.reset()
    env.close()

    assert os.path.isdir("videos")
    shutil.rmtree("videos")
    assert space.contains(observation)
    assert isinstance(info, dict)
38 |
39 |
def test_record_video_step_trigger():
    """A step_trigger firing every 100 steps yields one video per trigger."""
    env = gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True)
    # Shorten episodes so several episodes fit inside 199 steps.
    env._max_episode_steps = 20
    env = gym.wrappers.RecordVideo(env, "videos", step_trigger=lambda x: x % 100 == 0)
    env.reset()
    for _ in range(199):
        _, _, terminated, truncated, _ = env.step(env.action_space.sample())
        if terminated or truncated:
            env.reset()
    env.close()

    assert os.path.isdir("videos")
    recordings = [name for name in os.listdir("videos") if name.endswith(".mp4")]
    # The trigger fires at steps 0 and 100 -> exactly two videos.
    assert len(recordings) == 2
    shutil.rmtree("videos")
55 |
56 |
def make_env(gym_id, seed, **kwargs):
    """Return a zero-arg env factory; only the seed == 1 factory records video."""

    def _factory():
        env = gym.make(gym_id, disable_env_checker=True, **kwargs)
        env._max_episode_steps = 20
        if seed != 1:
            return env
        return gym.wrappers.RecordVideo(
            env, "videos", step_trigger=lambda x: x % 100 == 0
        )

    return _factory
68 |
69 |
def test_record_video_within_vector():
    """RecordVideo applied to one sub-env of a SyncVectorEnv records its videos."""
    envs = gym.vector.SyncVectorEnv(
        [make_env("CartPole-v1", 1 + i, render_mode="rgb_array") for i in range(2)]
    )
    envs = gym.wrappers.RecordEpisodeStatistics(envs)
    envs.reset()
    for i in range(199):
        _, _, _, _, infos = envs.step(envs.action_space.sample())

        # NOTE(review): the loop intentionally has no `break` — it must run all
        # 199 steps so the `x % 100 == 0` step trigger fires twice (steps 0 and
        # 100); when every sub-env finished an episode we only log the rewards.
        if "episode" in infos and all(infos["_episode"]):
            print(f"episode_reward={infos['episode']['r']}")

    assert os.path.isdir("videos")
    mp4_files = [file for file in os.listdir("videos") if file.endswith(".mp4")]
    assert len(mp4_files) == 2
    shutil.rmtree("videos")
87 |
--------------------------------------------------------------------------------
/tests/wrappers/test_rescale_action.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | import gym
5 | from gym.wrappers import RescaleAction
6 |
7 |
def test_rescale_action():
    """RescaleAction rejects discrete spaces and rescales continuous actions."""
    # Discrete action spaces cannot be rescaled.
    env = gym.make("CartPole-v1", disable_env_checker=True)
    with pytest.raises(AssertionError):
        env = RescaleAction(env, -1, 1)
    del env

    env = gym.make("Pendulum-v1", disable_env_checker=True)
    rescaled_env = RescaleAction(
        gym.make("Pendulum-v1", disable_env_checker=True), -1, 1
    )

    # Identical seeds keep both envs in lockstep.
    obs, info = env.reset(seed=0)
    rescaled_obs, rescaled_info = rescaled_env.reset(seed=0)
    assert np.allclose(obs, rescaled_obs)

    # 1.5 is a valid raw action but lies outside the rescaled [-1, 1] range.
    obs, reward, _, _, _ = env.step([1.5])
    with pytest.raises(AssertionError):
        rescaled_env.step([1.5])
    # 0.75 in [-1, 1] should map onto the same transition as the raw 1.5.
    rescaled_obs, rescaled_reward, _, _, _ = rescaled_env.step([0.75])

    assert np.allclose(obs, rescaled_obs)
    assert np.allclose(reward, rescaled_reward)
32 |
--------------------------------------------------------------------------------
/tests/wrappers/test_resize_observation.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import gym
4 | from gym import spaces
5 | from gym.wrappers import ResizeObservation
6 |
7 |
@pytest.mark.parametrize("env_id", ["CarRacing-v2"])
@pytest.mark.parametrize("shape", [16, 32, (8, 5), [10, 7]])
def test_resize_observation(env_id, shape):
    """ResizeObservation resizes H and W while keeping the 3 colour channels."""
    env = ResizeObservation(gym.make(env_id, disable_env_checker=True), shape)

    assert isinstance(env.observation_space, spaces.Box)
    assert env.observation_space.shape[-1] == 3
    obs, _ = env.reset()

    # An int shape means a square target; otherwise the shape is used as given.
    expected_hw = (shape, shape) if isinstance(shape, int) else tuple(shape)
    assert env.observation_space.shape[:2] == expected_hw
    assert obs.shape == expected_hw + (3,)
23 |
--------------------------------------------------------------------------------
/tests/wrappers/test_step_compatibility.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import gym
4 | from gym.spaces import Discrete
5 | from gym.wrappers import StepAPICompatibility
6 |
7 |
class OldStepEnv(gym.Env):
    """Minimal env using the old 4-tuple (obs, reward, done, info) step API."""

    def __init__(self):
        self.action_space = Discrete(2)
        self.observation_space = Discrete(2)

    def step(self, action):
        # Old-style return: a single `done` flag instead of terminated/truncated.
        return self.observation_space.sample(), 0, False, {}
19 |
20 |
class NewStepEnv(gym.Env):
    """Minimal env using the new 5-tuple (obs, reward, terminated, truncated, info) step API."""

    def __init__(self):
        self.action_space = Discrete(2)
        self.observation_space = Discrete(2)

    def step(self, action):
        # New-style return: separate terminated and truncated flags.
        return self.observation_space.sample(), 0, False, False, {}
33 |
34 |
@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv])
@pytest.mark.parametrize("output_truncation_bool", [None, True])
def test_step_compatibility_to_new_api(env, output_truncation_bool):
    """Both step APIs convert to the 5-tuple API; True is also the default."""
    if output_truncation_bool is None:
        wrapped = StepAPICompatibility(env())
    else:
        wrapped = StepAPICompatibility(env(), output_truncation_bool)
    _, _, terminated, truncated, _ = wrapped.step(0)
    assert isinstance(terminated, bool)
    assert isinstance(truncated, bool)
46 |
47 |
@pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv])
def test_step_compatibility_to_old_api(env):
    """output_truncation_bool=False converts both APIs to the old 4-tuple."""
    wrapped = StepAPICompatibility(env(), False)
    result = wrapped.step(0)
    assert len(result) == 4
    _, _, done, _ = result
    assert isinstance(done, bool)
55 |
56 |
@pytest.mark.parametrize("apply_api_compatibility", [None, True, False])
def test_step_compatibility_in_make(apply_api_compatibility):
    """gym.make only converts an old-API env when apply_api_compatibility=True."""
    gym.register("OldStepEnv-v0", entry_point=OldStepEnv)

    make_kwargs = {"disable_env_checker": True}
    if apply_api_compatibility is not None:
        make_kwargs["apply_api_compatibility"] = apply_api_compatibility
    env = gym.make("OldStepEnv-v0", **make_kwargs)

    env.reset()
    step_returns = env.step(0)
    if apply_api_compatibility:
        # Converted to the new five-element API.
        assert len(step_returns) == 5
        _, _, terminated, truncated, _ = step_returns
        assert isinstance(terminated, bool)
        assert isinstance(truncated, bool)
    else:
        # Both None and False leave the old four-element API untouched.
        assert len(step_returns) == 4
        _, _, done, _ = step_returns
        assert isinstance(done, bool)

    gym.envs.registry.pop("OldStepEnv-v0")
83 |
--------------------------------------------------------------------------------
/tests/wrappers/test_time_aware_observation.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import gym
4 | from gym import spaces
5 | from gym.wrappers import TimeAwareObservation
6 |
7 |
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
def test_time_aware_observation(env_id):
    """TimeAwareObservation appends a step counter as the final observation entry."""
    env = gym.make(env_id, disable_env_checker=True)
    wrapped_env = TimeAwareObservation(env)

    # The wrapper widens the Box observation by exactly one entry.
    assert isinstance(env.observation_space, spaces.Box)
    assert isinstance(wrapped_env.observation_space, spaces.Box)
    assert wrapped_env.observation_space.shape[0] == env.observation_space.shape[0] + 1

    obs, info = env.reset()
    wrapped_obs, wrapped_obs_info = wrapped_env.reset()
    assert wrapped_env.t == 0.0
    assert wrapped_obs[-1] == 0.0
    assert wrapped_obs.shape[0] == obs.shape[0] + 1

    # The trailing entry counts steps taken since the last reset.
    for expected_t in (1.0, 2.0):
        wrapped_obs, _, _, _, _ = wrapped_env.step(env.action_space.sample())
        assert wrapped_env.t == expected_t
        assert wrapped_obs[-1] == expected_t
        assert wrapped_obs.shape[0] == obs.shape[0] + 1

    # Resetting zeroes the counter again.
    wrapped_obs, wrapped_obs_info = wrapped_env.reset()
    assert wrapped_env.t == 0.0
    assert wrapped_obs[-1] == 0.0
    assert wrapped_obs.shape[0] == obs.shape[0] + 1
37 |
--------------------------------------------------------------------------------
/tests/wrappers/test_time_limit.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import gym
4 | from gym.envs.classic_control.pendulum import PendulumEnv
5 | from gym.wrappers import TimeLimit
6 |
7 |
def test_time_limit_reset_info():
    """reset() through TimeLimit still returns a valid (obs, info) pair."""
    env = TimeLimit(gym.make("CartPole-v1", disable_env_checker=True))
    space = env.observation_space
    obs, info = env.reset()
    assert space.contains(obs)
    assert isinstance(info, dict)
15 |
16 |
@pytest.mark.parametrize("double_wrap", [False, True])
def test_time_limit_wrapper(double_wrap):
    """TimeLimit truncates after exactly max_episode_length steps."""
    # Pendulum never terminates on its own, so any episode end is a timeout.
    limit = 20
    env = TimeLimit(PendulumEnv(), limit)
    if double_wrap:
        # A second TimeLimit with the same limit must not change the behaviour.
        env = TimeLimit(env, limit)
    env.reset()

    steps_taken = 0
    terminated = truncated = False
    while not (terminated or truncated):
        _, _, terminated, truncated, _ = env.step(env.action_space.sample())
        steps_taken += 1

    assert steps_taken == limit
    assert truncated
36 |
37 |
@pytest.mark.parametrize("double_wrap", [False, True])
def test_termination_on_last_step(double_wrap):
    """Termination and truncation may both be reported on the same final step."""
    base = PendulumEnv()

    def always_terminate(_action):
        # Force env termination on every step.
        return base.observation_space.sample(), 0.0, True, False, {}

    base.step = always_terminate

    # A limit of 1 makes the timeout coincide with the forced termination.
    limit = 1
    env = TimeLimit(base, limit)
    if double_wrap:
        env = TimeLimit(env, limit)
    env.reset()
    _, _, terminated, truncated, _ = env.step(env.action_space.sample())
    assert terminated is True
    assert truncated is True
58 |
--------------------------------------------------------------------------------
/tests/wrappers/test_transform_observation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | import gym
5 | from gym.wrappers import TransformObservation
6 |
7 |
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
def test_transform_observation(env_id):
    """TransformObservation applies the given function to every observation."""

    def affine_transform(x):
        return 3 * x + 2

    env = gym.make(env_id, disable_env_checker=True)
    wrapped_env = TransformObservation(
        gym.make(env_id, disable_env_checker=True), lambda obs: affine_transform(obs)
    )

    # Identical seeds keep both envs in lockstep.
    obs, info = env.reset(seed=0)
    wrapped_obs, wrapped_obs_info = wrapped_env.reset(seed=0)
    assert np.allclose(wrapped_obs, affine_transform(obs))
    assert isinstance(wrapped_obs_info, dict)

    action = env.action_space.sample()
    obs, reward, terminated, truncated, _ = env.step(action)
    (
        wrapped_obs,
        wrapped_reward,
        wrapped_terminated,
        wrapped_truncated,
        _,
    ) = wrapped_env.step(action)
    # Only the observation is transformed; everything else passes through.
    assert np.allclose(wrapped_obs, affine_transform(obs))
    assert np.allclose(wrapped_reward, reward)
    assert wrapped_terminated == terminated
    assert wrapped_truncated == truncated
36 |
--------------------------------------------------------------------------------
/tests/wrappers/test_transform_reward.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | import gym
5 | from gym.wrappers import TransformReward
6 |
7 |
@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"])
def test_transform_reward(env_id):
    """TransformReward applies the given callable to each raw reward."""

    def lockstep_rewards(transform):
        """Step a plain and a wrapped env identically; return both rewards."""
        plain = gym.make(env_id, disable_env_checker=True)
        wrapped = TransformReward(
            gym.make(env_id, disable_env_checker=True), transform
        )
        action = plain.action_space.sample()
        plain.reset(seed=0)
        wrapped.reset(seed=0)
        _, raw, _, _, _ = plain.step(action)
        _, transformed, _, _, _ = wrapped.step(action)
        return raw, transformed

    # use case #1: scale
    for scale in (0.1, 200):
        raw, transformed = lockstep_rewards(lambda r: scale * r)
        assert transformed == scale * raw

    # use case #2: clip into a tiny band so the result is always saturated
    min_r = -0.0005
    max_r = 0.0002
    raw, transformed = lockstep_rewards(lambda r: np.clip(r, min_r, max_r))
    assert abs(transformed) < abs(raw)
    assert transformed == -0.0005 or transformed == 0.0002

    # use case #3: sign
    env = gym.make(env_id, disable_env_checker=True)
    wrapped_env = TransformReward(
        gym.make(env_id, disable_env_checker=True), lambda r: np.sign(r)
    )

    env.reset(seed=0)
    wrapped_env.reset(seed=0)

    for _ in range(1000):
        action = env.action_space.sample()
        _, transformed, terminated, truncated, _ = wrapped_env.step(action)
        assert transformed in [-1.0, 0.0, 1.0]
        if terminated or truncated:
            break
    del env, wrapped_env
63 |
--------------------------------------------------------------------------------
/tests/wrappers/test_vector_list_info.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import gym
4 | from gym.wrappers import RecordEpisodeStatistics, VectorListInfo
5 |
6 | ENV_ID = "CartPole-v1"
7 | NUM_ENVS = 3
8 | ENV_STEPS = 50
9 | SEED = 42
10 |
11 |
def test_usage_in_vector_env():
    """VectorListInfo accepts vector envs and rejects non-vector envs."""
    single_env = gym.make(ENV_ID, disable_env_checker=True)
    vector_env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True)

    # A vector env is accepted without error.
    VectorListInfo(vector_env)

    with pytest.raises(AssertionError):
        VectorListInfo(single_env)
20 |
21 |
def test_info_to_list():
    """VectorListInfo converts the dict-of-arrays info into a per-env list."""
    wrapped_env = VectorListInfo(
        gym.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True)
    )
    wrapped_env.action_space.seed(SEED)
    _, info = wrapped_env.reset(seed=SEED)
    assert isinstance(info, list)
    assert len(info) == NUM_ENVS

    for _ in range(ENV_STEPS):
        action = wrapped_env.action_space.sample()
        _, _, terminateds, truncateds, list_info = wrapped_env.step(action)
        # final_observation appears only for the sub-envs that just ended.
        for index, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
            if terminated or truncated:
                assert "final_observation" in list_info[index]
            else:
                assert "final_observation" not in list_info[index]
38 |
39 |
def test_info_to_list_statistics():
    """Episode statistics survive the dict-to-list info conversion."""
    wrapped_env = VectorListInfo(
        RecordEpisodeStatistics(
            gym.vector.make(ENV_ID, num_envs=NUM_ENVS, disable_env_checker=True)
        )
    )
    _, info = wrapped_env.reset(seed=SEED)
    wrapped_env.action_space.seed(SEED)
    assert isinstance(info, list)
    assert len(info) == NUM_ENVS

    for _ in range(ENV_STEPS):
        action = wrapped_env.action_space.sample()
        _, _, terminateds, truncateds, list_info = wrapped_env.step(action)
        for index, (terminated, truncated) in enumerate(zip(terminateds, truncateds)):
            if terminated or truncated:
                # A finished sub-env carries reward/length/time statistics.
                assert "episode" in list_info[index]
                episode_stats = list_info[index]["episode"]
                for stat_key in ("r", "l", "t"):
                    assert stat_key in episode_stats
                    assert isinstance(episode_stats[stat_key], float)
            else:
                assert "episode" not in list_info[index]
59 |
--------------------------------------------------------------------------------
/tests/wrappers/test_video_recorder.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 | import re
4 | import time
5 |
6 | import pytest
7 |
8 | import gym
9 | from gym.wrappers.monitoring.video_recorder import VideoRecorder
10 |
11 |
class BrokenRecordableEnv(gym.Env):
    """Env that advertises a recordable mode but renders nothing."""

    metadata = {"render_modes": ["rgb_array_list"]}

    def __init__(self, render_mode="rgb_array_list"):
        self.render_mode = render_mode

    def render(self):
        # Returns None despite claiming rgb_array_list support.
        pass
20 |
21 |
class UnrecordableEnv(gym.Env):
    """Env with no recordable render mode at all."""

    metadata = {"render_modes": [None]}

    def __init__(self, render_mode=None):
        self.render_mode = render_mode

    def render(self):
        pass
30 |
31 |
def test_record_simple():
    """A single captured frame produces a non-trivial video file on disk."""
    env = gym.make(
        "CartPole-v1", render_mode="rgb_array_list", disable_env_checker=True
    )
    rec = VideoRecorder(env)
    env.reset()
    rec.capture_frame()

    rec.close()

    assert not rec.broken
    assert os.path.exists(rec.path)
    # Fix: os.stat replaces the previous open()+os.fstat pair, which leaked an
    # open file handle for the rest of the test session.
    assert os.stat(rec.path).st_size > 100
46 |
47 |
def test_autoclose():
    """The recorder finalizes its video when garbage-collected without close()."""

    def record():
        env = gym.make(
            "CartPole-v1", render_mode="rgb_array_list", disable_env_checker=True
        )
        rec = VideoRecorder(env)
        env.reset()
        rec.capture_frame()

        rec_path = rec.path

        # The function ends without an explicit `rec.close()` call
        # The Python interpreter will implicitly do `del rec` on garbage cleaning
        return rec_path

    rec_path = record()

    gc.collect()  # do explicit garbage collection for test
    time.sleep(5)  # wait for subprocess exiting

    assert os.path.exists(rec_path)
    # Fix: os.stat replaces the previous open()+os.fstat pair, which leaked an
    # open file handle for the rest of the test session.
    assert os.stat(rec_path).st_size > 100
71 |
72 |
def test_no_frames():
    """Closing a recorder that captured nothing writes no file."""
    rec = VideoRecorder(BrokenRecordableEnv())
    rec.close()
    assert rec.functional
    assert not os.path.exists(rec.path)
79 |
80 |
def test_record_unrecordable_method():
    """A recorder on an env without a compatible render mode disables itself."""
    with pytest.warns(
        UserWarning,
        match=re.escape(
            "\x1b[33mWARN: Disabling video recorder because environment was not initialized with any compatible video mode between `rgb_array` and `rgb_array_list`\x1b[0m"
        ),
    ):
        rec = VideoRecorder(UnrecordableEnv())
    assert not rec.enabled
    rec.close()
92 |
93 |
def test_record_breaking_render_method():
    """A render() returning None marks the recorder broken; no file is kept."""
    with pytest.warns(
        UserWarning,
        match=re.escape(
            "Env returned None on `render()`. Disabling further rendering for video recorder by marking as disabled:"
        ),
    ):
        rec = VideoRecorder(BrokenRecordableEnv())
        rec.capture_frame()
        rec.close()
    assert rec.broken
    assert not os.path.exists(rec.path)
107 |
108 |
def test_text_envs():
    """Toy-text envs rendered as rgb_array_list can be video-recorded."""
    env = gym.make(
        "FrozenLake-v1", render_mode="rgb_array_list", disable_env_checker=True
    )
    video = VideoRecorder(env)
    try:
        env.reset()
        video.capture_frame()
        video.close()
    finally:
        # Fix: guard the cleanup — if recording failed before any file was
        # written, an unconditional os.remove would raise FileNotFoundError
        # and mask the original error.
        if os.path.exists(video.path):
            os.remove(video.path)
120 |
--------------------------------------------------------------------------------
/tests/wrappers/utils.py:
--------------------------------------------------------------------------------
1 | import gym
2 |
3 |
def has_wrapper(wrapped_env: gym.Env, wrapper_type: type) -> bool:
    """Return True if any layer of the wrapper stack is an instance of wrapper_type."""
    current = wrapped_env
    # Walk the wrapper chain until the unwrapped base env is reached.
    while isinstance(current, gym.Wrapper):
        if isinstance(current, wrapper_type):
            return True
        current = current.env
    return False
10 |
--------------------------------------------------------------------------------