├── .Dockerignore ├── .codecov.yml ├── .github └── workflows │ └── push.yml ├── .gitignore ├── .multicore.env ├── .onecore.env ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── DCO.md ├── Dockerfile ├── LICENSE ├── Makefile ├── Makefile.docker ├── README.md ├── _old_Dockerfile ├── conftest.py ├── docs ├── Makefile ├── make.bat └── source │ ├── conf.py │ ├── index.md │ ├── markdown │ └── readme.md │ ├── notebooks │ ├── plangym_introduction.ipynb │ └── tutorial.md │ └── notes │ ├── __init__.py │ └── interface.rst ├── install-lua-macos.sh ├── pyproject.toml ├── requirements-dev.lock ├── requirements.lock ├── src ├── __init__.py └── plangym │ ├── __init__.py │ ├── api_tests.py │ ├── control │ ├── __init__.py │ ├── balloon.py │ ├── box_2d.py │ ├── classic_control.py │ ├── dm_control.py │ └── lunar_lander.py │ ├── core.py │ ├── environment_names.py │ ├── registry.py │ ├── scripts │ ├── __init__.py │ └── import_retro_roms.py │ ├── utils.py │ ├── vectorization │ ├── __init__.py │ ├── env.py │ ├── parallel.py │ └── ray.py │ ├── version.py │ └── videogames │ ├── __init__.py │ ├── atari.py │ ├── env.py │ ├── montezuma.py │ ├── nes.py │ └── retro.py ├── tests ├── __init__.py ├── control │ ├── __init__.py │ ├── test_balloon.py │ ├── test_box_2d.py │ ├── test_classic_control.py │ ├── test_dm_control.py │ └── test_lunar_lander.py ├── test_core.py ├── test_registry.py ├── test_utils.py ├── vectorization │ ├── __init__.py │ ├── test_parallel.py │ └── test_ray.py └── videogames │ ├── __init__.py │ ├── test_atari.py │ ├── test_montezuma.py │ ├── test_nes.py │ └── test_retro.py └── uncompressed_ROMs.zip /.Dockerignore: -------------------------------------------------------------------------------- 1 | venv 2 | .pytest_cache 3 | .hypothesis 4 | *.egg-info -------------------------------------------------------------------------------- /.codecov.yml: -------------------------------------------------------------------------------- 1 | 
comment: false 2 | ignore: 3 | - tests 4 | coverage: 5 | status: 6 | project: 7 | default: 8 | target: auto 9 | threshold: 0% 10 | informational: true 11 | paths: 12 | - plangym 13 | -------------------------------------------------------------------------------- /.github/workflows/push.yml: -------------------------------------------------------------------------------- 1 | name: Push 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | 11 | env: 12 | PROJECT_NAME: plangym 13 | PROJECT_DIR: src/plangym 14 | VERSION_FILE: src/plangym/version.py 15 | DEFAULT_BRANCH: master 16 | BOT_NAME: fragile-bot 17 | BOT_EMAIL: bot@fragile.tech 18 | DOCKER_ORG: fragiletech 19 | LOCAL_CACHE: | 20 | ~/.local/bin 21 | ~/.local/lib/python3.*/site-packages 22 | /opt/homebrew 23 | 24 | jobs: 25 | style-check: 26 | name: Style check 27 | if: "!contains(github.event.head_commit.message, 'Bump version')" 28 | runs-on: ubuntu-latest 29 | steps: 30 | - name: actions/checkout 31 | uses: actions/checkout@v3 32 | - name: Set up Python 3.10 33 | uses: actions/setup-python@v2 34 | with: 35 | python-version: "3.10" 36 | - name: Setup ruff 37 | id: setup-ruff 38 | uses: astral-sh/ruff-action@v3 39 | with: 40 | version-file: "pyproject.toml" 41 | args: "check --diff" 42 | 43 | pytest: 44 | name: Run Pytest 45 | if: "!contains(github.event.head_commit.message, 'Bump version')" 46 | strategy: 47 | matrix: 48 | python-version: [ '3.10' ] 49 | os: [ 'ubuntu-latest', 'macos-latest' ] 50 | runs-on: ${{ matrix.os }} 51 | steps: 52 | - name: actions/checkout 53 | uses: actions/checkout@v3 54 | - name: Set up Python ${{ matrix.python-version }} 55 | uses: actions/setup-python@v2 56 | with: 57 | python-version: ${{ matrix.python-version }} 58 | - name: Setup Rye 59 | id: setup-rye 60 | uses: eifinger/setup-rye@v4 61 | with: 62 | enable-cache: true 63 | cache-prefix: ${{ matrix.os }}-latest-rye-test-${{ matrix.python-version }}-${{ hashFiles('pyproject.toml') }} 
64 | - name: actions/cache 65 | uses: actions/cache@v3 66 | with: 67 | path: ${{ env.LOCAL_CACHE }} 68 | key: ${{ matrix.os }}-latest-system-test-${{ matrix.python-version }}-${{ hashFiles('pyproject.toml') }} 69 | restore-keys: ${{ matrix.os }}-latest-system-test-${{ matrix.python-version }} 70 | 71 | - name: Install Ubuntu test and package dependencies 72 | if: ${{ matrix.os == 'ubuntu-latest' }} 73 | run: | 74 | set -x 75 | sudo apt-get install -y xvfb libglu1-mesa x11-utils 76 | rye pin --relaxed cpython@${{ matrix.python-version }} 77 | rye sync --all-features 78 | ROM_PASSWORD=${{ secrets.ROM_PASSWORD }} rye run import-roms 79 | 80 | - name: Install MacOS test and package dependencies 81 | if: ${{ matrix.os == 'macos-latest' }} 82 | run: | 83 | set -x 84 | brew install --cask xquartz 85 | brew install swig libzip 86 | # When building retro from source we may need the deprecated version of lua 5.1. 87 | # brew install qt5 capnp 88 | # Retro does not build in MacOS due to ancient requirements, so we will be installing retro==0.9.1 from pypi 89 | # because it contains pre-build wheels for MacOS. 90 | # chmod +x install-lua-macos.sh 91 | # sudo ./install-lua-macos.sh 92 | # echo 'export PATH="/opt/homebrew/opt/qt@5/bin:$PATH"' >> ~/.zshrc 93 | # export SDKROOT=$(xcrun --sdk macosx --show-sdk-path) 94 | # https://docs.github.com/en/actions/learn-github-actions/workflow-commands-for-github-actions#adding-a-system-path 95 | echo "/opt/X11/bin" >> $GITHUB_PATH 96 | # https://github.com/ponty/PyVirtualDisplay/issues/42 97 | if [ ! 
-d /tmp/.X11-unix ]; then 98 | mkdir /tmp/.X11-unix 99 | fi 100 | sudo chmod 1777 /tmp/.X11-unix 101 | sudo chown root /tmp/.X11-unix 102 | rye pin --relaxed cpython@${{ matrix.python-version }} 103 | rye sync --all-features 104 | # Fix a bug in retro.data where it tries to load an inexistent version file 105 | # sed -i '' 's/VERSION\.txt/VERSION/g' /Users/runner/work/plangym/plangym/.venv/lib/python3.10/site-packages/retro/__init__.py 106 | if [ ! -d /Users/runner/work/plangym/plangym/.venv/lib/python3.10/site-packages/retro/VERSION ]; then 107 | echo "0.9.1" > /Users/runner/work/plangym/plangym/.venv/lib/python3.10/site-packages/retro/VERSION.txt 108 | fi 109 | ROM_PASSWORD=${{ secrets.ROM_PASSWORD }} rye run import-roms 110 | 111 | - name: Run Pytest on MacOS 112 | if: ${{ matrix.os == 'macos-latest' }} 113 | run: | 114 | set -x 115 | # TODO: Figure out how to emulate a display in headless machines, and figure out why the commented files fail 116 | # SKIP_RENDER=True rye run pytest tests/test_registry.py tests/videogames/test_retro.py 117 | SKIP_RENDER=True rye run pytest tests/control tests/videogames/test_atari.py tests/videogames/test_nes.py tests/test_core.py tests/test_utils.py 118 | 119 | - name: Run code coverage on Ubuntu 120 | if: ${{ matrix.os == 'ubuntu-latest' }} 121 | run: | 122 | set -x 123 | xvfb-run -s "-screen 0 1400x900x24" rye run codecov 124 | 125 | - name: Upload coverage report 126 | # if: ${{ matrix.python-version == '3.10' && matrix.os == 'ubuntu-latest' }} 127 | if: ${{ matrix.os == 'ubuntu-latest' }} 128 | uses: codecov/codecov-action@v4 129 | with: 130 | fail_ci_if_error: false # optional (default = false) 131 | files: ./coverage.xml,./coverage_parallel_1.xml,./coverage_parallel_2.xml,./coverage_parallel_3.xml,./coverage_vectorization.xml 132 | flags: unittests # optional 133 | name: codecov-umbrella # optional 134 | token: ${{ secrets.CODECOV_TOKEN }} # required 135 | verbose: true # optional (default = false) 136 | 137 | # 
test-docker: 138 | # name: Test Docker container 139 | # runs-on: ubuntu-20.04 140 | # if: "!contains(github.event.head_commit.message, 'Bump version')" 141 | # steps: 142 | # - uses: actions/checkout@v2 143 | # - name: Build container 144 | # run: | 145 | # set -x 146 | # ROM_PASSWORD=${{ secrets.ROM_PASSWORD }} make docker-build 147 | # - name: Run tests 148 | # run: | 149 | # set -x 150 | # make docker-test 151 | 152 | build-test-package: 153 | name: Build and test the package 154 | needs: style-check 155 | runs-on: ubuntu-latest 156 | if: "!contains(github.event.head_commit.message, 'Bump version')" 157 | steps: 158 | - name: actions/checkout 159 | uses: actions/checkout@v3 160 | - name: Set up Python 3.10 161 | uses: actions/setup-python@v2 162 | with: 163 | python-version: '3.10' 164 | - name: Setup Rye 165 | id: setup-rye 166 | uses: eifinger/setup-rye@v4 167 | with: 168 | enable-cache: true 169 | cache-prefix: ubuntu-latest-rye-build-3.10-${{ hashFiles('pyproject.toml') }} 170 | - name: actions/cache 171 | uses: actions/cache@v3 172 | with: 173 | path: ${{ env.LOCAL_CACHE }} 174 | key: ubuntu-latest-system-build-3.10-${{ hashFiles('pyproject.toml') }} 175 | restore-keys: ubuntu-latest-system-test 176 | - name: Install build dependencies 177 | run: | 178 | set -x 179 | pip install uv 180 | rye install bump2version 181 | rye install twine 182 | 183 | - name: Create unique version for test.pypi 184 | run: | 185 | set -x 186 | current_version=$(grep __version__ $VERSION_FILE | cut -d\" -f2) 187 | ts=$(date +%s) 188 | new_version="$current_version$ts" 189 | bumpversion --current-version $current_version --new-version $new_version patch $VERSION_FILE 190 | 191 | - name: Build package 192 | run: | 193 | set -x 194 | rye build --clean 195 | twine check dist/* 196 | 197 | - name: Publish package to TestPyPI 198 | env: 199 | TEST_PYPI_PASS: ${{ secrets.TEST_PYPI_PASS }} 200 | if: "'$TEST_PYPI_PASS' != ''" 201 | uses: pypa/gh-action-pypi-publish@release/v1 202 | with: 
203 | password: ${{ secrets.TEST_PYPI_PASS }} 204 | repository-url: https://test.pypi.org/legacy/ 205 | skip-existing: true 206 | 207 | - name: Install dependencies 208 | env: 209 | UV_SYSTEM_PYTHON: 1 210 | run: | 211 | set -x 212 | sudo apt-get install -y xvfb libglu1-mesa x11-utils 213 | rye lock --all-features 214 | uv pip install -r requirements.lock 215 | uv pip install dist/*.whl 216 | # ROM_PASSWORD=${{ secrets.ROM_PASSWORD }} python -m plangym.scripts.import_retro_roms 217 | 218 | # - name: Test package 219 | # env: 220 | # UV_SYSTEM_PYTHON: 1 221 | # run: | 222 | # set -x 223 | # rm -rf $PROJECT_DIR 224 | # find . -name "*.pyc" -delete 225 | # make test 226 | 227 | bump-version: 228 | name: Bump package version 229 | env: 230 | BOT_AUTH_TOKEN: ${{ secrets.BOT_AUTH_TOKEN }} 231 | if: "!contains(github.event.head_commit.message, 'Bump version') && github.ref == 'refs/heads/master' && '$BOT_AUTH_TOKEN' != ''" 232 | runs-on: ubuntu-latest 233 | needs: 234 | - pytest 235 | - build-test-package 236 | # - test-docker 237 | steps: 238 | - name: actions/checkout 239 | uses: actions/checkout@v3 240 | with: 241 | persist-credentials: false 242 | fetch-depth: 100 243 | - name: current_version 244 | run: | 245 | set -x 246 | echo "current_version=$(grep __version__ $VERSION_FILE | cut -d\" -f2)" >> $GITHUB_ENV 247 | echo "version_file=$VERSION_FILE" >> $GITHUB_ENV 248 | echo 'bot_name="${BOT_NAME}"' >> $GITHUB_ENV 249 | echo 'bot_email="${BOT_EMAIL}"' >> $GITHUB_ENV 250 | - name: FragileTech/bump-version 251 | uses: FragileTech/bump-version@main 252 | with: 253 | current_version: "${{ env.current_version }}" 254 | files: "${{ env.version_file }}" 255 | commit_name: "${{ env.bot_name }}" 256 | commit_email: "${{ env.bot_email }}" 257 | login: "${{ env.bot_name }}" 258 | token: "${{ secrets.BOT_AUTH_TOKEN }}" 259 | 260 | release-package: 261 | name: Release PyPI package 262 | env: 263 | PYPI_PASS: ${{ secrets.PYPI_PASS }} 264 | if: 
"contains(github.event.head_commit.message, 'Bump version') && github.ref == 'refs/heads/master' && '$PYPI_PASS' != ''" 265 | runs-on: ubuntu-latest 266 | steps: 267 | - name: actions/checkout 268 | uses: actions/checkout@v3 269 | - name: Set up Python 3.10 270 | uses: actions/setup-python@v2 271 | with: 272 | python-version: '3.10' 273 | - name: Setup Rye 274 | id: setup-rye 275 | uses: eifinger/setup-rye@v4 276 | with: 277 | enable-cache: true 278 | cache-prefix: ubuntu-latest-rye-release-3.10-${{ hashFiles('pyproject.toml') }} 279 | - name: Install dependencies 280 | run: | 281 | set -x 282 | rye install twine 283 | 284 | - name: Build package 285 | run: | 286 | set -x 287 | rye build --clean 288 | twine check dist/* 289 | 290 | - name: Publish package to PyPI 291 | env: 292 | PYPI_PASS: ${{ secrets.PYPI_PASS }} 293 | if: "'$PYPI_PASS' != ''" 294 | uses: pypa/gh-action-pypi-publish@release/v1 295 | with: 296 | password: ${{ secrets.PYPI_PASS }} 297 | skip-existing: true 298 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | gym_api.py 2 | roms 3 | uncompressed ROMs.zip 4 | uncompressed ROMs 5 | # This file is here just for the reference 6 | WHAT_MLOQ_GENERATED.md 7 | 8 | #Mac OS 9 | *.DS_Store 10 | 11 | #PyCharm IDE 12 | .idea/ 13 | 14 | # Documentation build templates 15 | docs/_build/ 16 | docs/build/ 17 | 18 | # Byte-compiled / optimized / DLL templates 19 | __pycache__/ 20 | *.py[cod] 21 | *$py.class 22 | 23 | # C extensions 24 | *.so 25 | 26 | # Distribution / packaging 27 | .Python 28 | env/ 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 | downloads/ 33 | eggs/ 34 | .eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | wheels/ 41 | pip-wheel-metadata 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | 46 | # PyInstaller 47 | # Usually these templates are written by a python script from a template 48 | # before 
PyInstaller builds the exe, so as to inject date/other infos into it. 49 | *.manifest 50 | *.spec 51 | 52 | # Installer logs 53 | pip-log.txt 54 | pip-delete-this-directory.txt 55 | 56 | # Unit test / coverage reports 57 | htmlcov/ 58 | .tox/ 59 | .coverage 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | coverage_*.xml 65 | *.cover 66 | .hypothesis/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | 76 | # Flask stuff: 77 | instance/ 78 | .webassets-cache 79 | 80 | # Scrapy stuff: 81 | .scrapy 82 | 83 | # Sphinx documentation 84 | docsrc/_build/ 85 | 86 | # PyBuilder 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # celery beat schedule file 96 | celerybeat-schedule 97 | 98 | # SageMath parsed templates 99 | *.sage.py 100 | 101 | # dotenv 102 | .env 103 | 104 | # virtualenv 105 | .venv 106 | venv/ 107 | ENV/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | 122 | # CI 123 | .ci 124 | 125 | # Notebooks by default are ignored 126 | *.ipynb 127 | 128 | # lock files 129 | *.lock 130 | 131 | # hydra command history 132 | outputs/ 133 | 134 | *.pck 135 | *.npy 136 | -------------------------------------------------------------------------------- /.multicore.env: -------------------------------------------------------------------------------- 1 | PYVIRTUALDISPLAY_DISPLAYFD=0 2 | SKIP_CLASSIC_CONTROL=1 -------------------------------------------------------------------------------- /.onecore.env: -------------------------------------------------------------------------------- 1 | PYTEST_XDIST_AUTO_NUM_WORKERS=1 2 | PYVIRTUALDISPLAY_DISPLAYFD=0 -------------------------------------------------------------------------------- /.pre-commit-config.yaml: 
-------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 20.8b1 4 | hooks: 5 | - id: black 6 | - repo: https://github.com/PyCQA/isort 7 | rev: 5.8.0 8 | hooks: 9 | - id: isort 10 | args: ["."] 11 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file for Sphinx projects 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | # Required 5 | version: 2 6 | 7 | # Set the OS, Python version and other tools you might need 8 | build: 9 | os: ubuntu-22.04 10 | tools: 11 | python: "3.10" 12 | # You can also specify other tool versions: 13 | # nodejs: "20" 14 | # rust: "1.70" 15 | # golang: "1.20" 16 | apt_packages: 17 | - build-essential 18 | - libglu1-mesa 19 | - xvfb 20 | - clang 21 | - swig 22 | 23 | # Build documentation in the "docs/" directory with Sphinx 24 | sphinx: 25 | configuration: docs/source/conf.py 26 | # You can configure Sphinx to use a different builder, for instance use the dirhtml builder for simpler URLs 27 | # builder: "dirhtml" 28 | # Fail on all warnings to avoid broken references 29 | # fail_on_warning: true 30 | 31 | # Optionally build your docs in additional formats such as PDF and ePub 32 | # formats: 33 | # - pdf 34 | # - epub 35 | 36 | # Optional but recommended, declare the Python requirements required 37 | # to build your documentation 38 | # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html 39 | python: 40 | install: 41 | - requirements: requirements.lock -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | 2 | # Contributor Covenant Code of Conduct 3 | 4 | ## Our Pledge 5 | 6 | We as members, 
contributors, and leaders pledge to make participation in our 7 | community a harassment-free experience for everyone, regardless of age, body 8 | size, visible or invisible disability, ethnicity, sex characteristics, gender 9 | identity and expression, level of experience, education, socio-economic status, 10 | nationality, personal appearance, race, religion, or sexual identity 11 | and orientation. 12 | 13 | We pledge to act and interact in ways that contribute to an open, welcoming, 14 | diverse, inclusive, and healthy community. 15 | 16 | ## Our Standards 17 | 18 | Examples of behavior that contributes to a positive environment for our 19 | community include: 20 | 21 | * Demonstrating empathy and kindness toward other people 22 | * Being respectful of differing opinions, viewpoints, and experiences 23 | * Giving and gracefully accepting constructive feedback 24 | * Accepting responsibility and apologizing to those affected by our mistakes, 25 | and learning from the experience 26 | * Focusing on what is best not just for us as individuals, but for the 27 | overall community 28 | 29 | Examples of unacceptable behavior include: 30 | 31 | * The use of sexualized language or imagery, and sexual attention or 32 | advances of any kind 33 | * Trolling, insulting or derogatory comments, and personal or political attacks 34 | * Public or private harassment 35 | * Publishing others' private information, such as a physical or email 36 | address, without their explicit permission 37 | * Other conduct which could reasonably be considered inappropriate in a 38 | professional setting 39 | 40 | ## Enforcement Responsibilities 41 | 42 | Community leaders are responsible for clarifying and enforcing our standards of 43 | acceptable behavior and will take appropriate and fair corrective action in 44 | response to any behavior that they deem inappropriate, threatening, offensive, 45 | or harmful. 
46 | 47 | Community leaders have the right and responsibility to remove, edit, or reject 48 | comments, commits, code, wiki edits, issues, and other contributions that are 49 | not aligned to this Code of Conduct, and will communicate reasons for moderation 50 | decisions when appropriate. 51 | 52 | ## Scope 53 | 54 | This Code of Conduct applies within all community spaces, and also applies when 55 | an individual is officially representing the community in public spaces. 56 | Examples of representing our community include using an official e-mail address, 57 | posting via an official social media account, or acting as an appointed 58 | representative at an online or offline event. 59 | 60 | ## Enforcement 61 | 62 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 63 | reported to the community leaders responsible for enforcement at info@fragile.tech. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. 
No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | [https://www.contributor-covenant.org/version/2/0/code_of_conduct.html][v2.0]. 120 | 121 | Community Impact Guidelines were inspired by 122 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. 123 | 124 | For answers to common questions about this code of conduct, see the FAQ at 125 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available 126 | at [https://www.contributor-covenant.org/translations][translations]. 
127 | 128 | [homepage]: https://www.contributor-covenant.org 129 | [v2.0]: https://www.contributor-covenant.org/version/2/0/code_of_conduct.html 130 | [Mozilla CoC]: https://github.com/mozilla/diversity 131 | [FAQ]: https://www.contributor-covenant.org/faq 132 | [translations]: https://www.contributor-covenant.org/translations 133 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | plangym is [MIT licensed](LICENSE) and accepts 4 | contributions via GitHub pull requests. This document outlines some of the 5 | conventions on development workflow, commit message formatting, contact points, 6 | and other resources to make it easier to get your contribution accepted. 7 | 8 | ## Certificate of Origin 9 | 10 | By contributing to this project you agree to the [Developer Certificate of 11 | Origin (DCO)](DCO.md). This document was created by the Linux Kernel community and is a 12 | simple statement that you, as a contributor, have the legal right to make the 13 | contribution. 14 | 15 | In order to show your agreement with the DCO you should include, at the end of the commit message, 16 | the following line: `Signed-off-by: John Doe <john.doe@example.com>`, using your real name and email address. 17 | 18 | This can be done easily using the [`-s`](https://github.com/git/git/blob/b2c150d3aa82f6583b9aadfecc5f8fa1c74aca09/Documentation/git-commit.txt#L154-L161) flag of `git commit`. 19 | 20 | 21 | ## Support Channels 22 | 23 | The official support channels, for both users and contributors, are: 24 | 25 | - GitHub [issues](https://github.com/FragileTech/plangym/issues)* 26 | 27 | *Before opening a new issue or submitting a new pull request, it's helpful to 28 | search the project - it's likely that another user has already reported the 29 | issue you're facing, or it's a known issue that we're already aware of.
30 | 31 | 32 | ## How to Contribute 33 | 34 | Pull Requests (PRs) are the main and exclusive way to contribute to the official project. 35 | In order for a PR to be accepted it needs to pass a list of requirements: 36 | 37 | - The CI style check passes (run locally with `ruff check`). 38 | - Code Coverage does not decrease. 39 | - All the tests pass. 40 | - Python code is formatted according to [![PEP8](https://img.shields.io/badge/code%20style-pep8-orange.svg)](https://www.python.org/dev/peps/pep-0008/). 41 | - If the PR is a bug fix, it has to include a new unit test that fails without the fix and passes with it. 42 | - If the PR is a new feature, it has to come with a suite of unit tests that test the new functionality. 43 | - In any case, all the PRs have to pass the personal evaluation of at least one of the project maintainers. 44 | 45 | 46 | ### Format of the commit message 47 | 48 | The commit summary must start with a capital letter and with a verb in the present tense. No period at the end. 49 | 50 | ``` 51 | Add a feature 52 | Remove unused code 53 | Fix a bug 54 | ``` 55 | 56 | The details of every commit should describe what was changed, under which context and, if applicable, the GitHub issue it relates to. 57 | -------------------------------------------------------------------------------- /DCO.md: -------------------------------------------------------------------------------- 1 | Developer Certificate of Origin 2 | Version 1.1 3 | 4 | Copyright (C) 2004, 2006 The Linux Foundation and its contributors. 5 | 660 York Street, Suite 102, 6 | San Francisco, CA 94110 USA 7 | 8 | Everyone is permitted to copy and distribute verbatim copies of this 9 | license document, but changing it is not allowed.
10 | 11 | 12 | Developer's Certificate of Origin 1.1 13 | 14 | By making a contribution to this project, I certify that: 15 | 16 | (a) The contribution was created in whole or in part by me and I 17 | have the right to submit it under the open source license 18 | indicated in the file; or 19 | 20 | (b) The contribution is based upon previous work that, to the best 21 | of my knowledge, is covered under an appropriate open source 22 | license and I have the right under that license to submit that 23 | work with modifications, whether created in whole or in part 24 | by me, under the same open source license (unless I am 25 | permitted to submit under a different license), as indicated 26 | in the file; or 27 | 28 | (c) The contribution was provided directly to me by some other 29 | person who certified (a), (b) or (c) and I have not modified 30 | it. 31 | 32 | (d) I understand and agree that this project and the contribution 33 | are public and that a record of the contribution (including all 34 | personal information I submit with it, including my sign-off) is 35 | maintained indefinitely and may be redistributed consistent with 36 | this project or the open source license(s) involved. 37 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:22.04 2 | 3 | 4 | RUN apt-get update && \ 5 | apt-get install -y --no-install-suggests --no-install-recommends \ 6 | make cmake git xvfb build-essential curl ssh wget ca-certificates swig \ 7 | python3 python3-dev && \ 8 | wget -O - https://bootstrap.pypa.io/get-pip.py | python3 9 | 10 | RUN python3 -m pip install uv 11 | COPY src ./src 12 | COPY requirements.lock . 13 | COPY pyproject.toml . 14 | COPY src/plangym/scripts/import_retro_roms.py . 15 | COPY LICENSE . 16 | COPY README.md . 17 | COPY ROMS.zip . 
18 | 19 | RUN uv pip install --no-cache --system -r requirements.lock 20 | RUN python3 src/plangym/scripts/import_retro_roms.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 - 2021 FragileTech 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | current_dir = $(shell pwd) 2 | 3 | PROJECT = plangym 4 | n ?= auto 5 | DOCKER_ORG = fragiletech 6 | DOCKER_TAG ?= ${PROJECT} 7 | ROM_FILE ?= "uncompressed_ROMs.zip" 8 | ROM_PASSWORD ?= "NO_PASSWORD" 9 | VERSION ?= latest 10 | MUJOCO_PATH?=~/.mujoco 11 | 12 | .PHONY: install-mujoco 13 | install-mujoco: 14 | mkdir ${MUJOCO_PATH} 15 | wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz 16 | tar -xvzf mujoco210-linux-x86_64.tar.gz -C ${MUJOCO_PATH} 17 | rm mujoco210-linux-x86_64.tar.gz 18 | 19 | .PHONY: import-roms 20 | import-roms: 21 | ifeq (${ROM_PASSWORD}, "NO_PASSWORD") 22 | unzip -o ${ROM_FILE} 23 | else 24 | unzip -o -P ${ROM_PASSWORD} ${ROM_FILE} 25 | endif 26 | python3 src/plangym/scripts/import_retro_roms.py 27 | 28 | .PHONY: install-envs 29 | install-envs: 30 | make -f Makefile.docker install-env-deps 31 | make install-mujoco 32 | 33 | .PHONY: docker-shell 34 | docker-shell: 35 | docker run --rm --gpus all -v ${current_dir}:/${PROJECT} --network host -w /${PROJECT} -it ${DOCKER_ORG}/${PROJECT}:${VERSION} bash 36 | 37 | .PHONY: docker-notebook 38 | docker-notebook: 39 | docker run --rm --gpus all -v ${current_dir}:/${PROJECT} --network host -w /${PROJECT} -it ${DOCKER_ORG}/${PROJECT}:${VERSION} 40 | 41 | .PHONY: docker-build 42 | docker-build: 43 | docker build --pull -t ${DOCKER_ORG}/${PROJECT}:${VERSION} . 
--build-arg ROM_PASSWORD=${ROM_PASSWORD} 44 | 45 | .PHONY: docker-test 46 | docker-test: 47 | find -name "*.pyc" -delete 48 | docker run --rm --network host -w /${PROJECT} -e MUJOCO_GL=egl -e SKIP_RENDER=True -e DISABLE_RAY=True --entrypoint python3 ${DOCKER_ORG}/${PROJECT}:${VERSION} -m pytest -n $n -s -o log_cli=true -o log_cli_level=info 49 | docker run --rm --network host -w /${PROJECT} -e MUJOCO_GL=egl -e SKIP_RENDER=True -e DISABLE_RAY=False --entrypoint python3 ${DOCKER_ORG}/${PROJECT}:${VERSION} -m pytest tests/vectorization/test_ray.py -s -o log_cli=true -o log_cli_level=info 50 | 51 | .PHONY: docker-push 52 | docker-push: 53 | docker push ${DOCKER_ORG}/${DOCKER_TAG}:${VERSION} 54 | docker tag ${DOCKER_ORG}/${DOCKER_TAG}:${VERSION} ${DOCKER_ORG}/${DOCKER_TAG}:latest 55 | docker push ${DOCKER_ORG}/${DOCKER_TAG}:latest 56 | -------------------------------------------------------------------------------- /Makefile.docker: -------------------------------------------------------------------------------- 1 | current_dir = $(shell pwd) 2 | 3 | PROJECT = dockerfiles 4 | VERSION ?= latest 5 | DOCKER_TAG = None 6 | PYTHON_VERSION = "3.8" 7 | UBUNTU_NAME = $(lsb_release -s -c) 8 | 9 | # Install system packages 10 | .PHONY: install-common-dependencies 11 | install-common-dependencies: 12 | apt-get update && \ 13 | apt-get install -y --no-install-suggests --no-install-recommends \ 14 | ca-certificates locales pkg-config apt-utils gcc g++ wget make cmake git curl flex ssh gpgv \ 15 | libffi-dev libjpeg-turbo-progs libjpeg8-dev libjpeg-turbo8 libjpeg-turbo8-dev gnupg2 \ 16 | libpng-dev libpng16-16 libglib2.0-0 bison gfortran lsb-release \ 17 | libsm6 libxext6 libxrender1 libfontconfig1 libhdf5-dev libopenblas-base libopenblas-dev \ 18 | libfreetype6 libfreetype6-dev zlib1g-dev zlib1g xvfb python-opengl ffmpeg libhdf5-dev unzip && \ 19 | ln -s /usr/lib/x86_64-linux-gnu/libz.so /lib/ && \ 20 | ln -s /usr/lib/x86_64-linux-gnu/libjpeg.so /lib/ && \ 21 | echo "en_US.UTF-8 
UTF-8" > /etc/locale.gen && \ 22 | locale-gen && \ 23 | wget -O - https://bootstrap.pypa.io/get-pip.py | python3 && \ 24 | rm -rf /var/lib/apt/lists/* && \ 25 | echo '#!/bin/bash\n\\n\echo\n\echo " $@"\n\echo\n\' > /browser && \ 26 | chmod +x /browser 27 | 28 | 29 | .PHONY: remove-dev-packages 30 | remove-dev-packages: 31 | pip3 uninstall -y cython && \ 32 | apt-get remove -y cmake pkg-config flex bison curl libpng-dev \ 33 | libjpeg-turbo8-dev zlib1g-dev libhdf5-dev libopenblas-dev gfortran \ 34 | libfreetype6-dev libjpeg8-dev libffi-dev && \ 35 | apt-get autoremove -y && \ 36 | apt-get clean && \ 37 | rm -rf /var/lib/apt/lists/* 38 | 39 | # Install Python 3.9 40 | .PHONY: install-python3.9 41 | install-python3.9: 42 | apt-get install -y --no-install-suggests --no-install-recommends \ 43 | python3.9 python3.9-dev python3-distutils python3-setuptools 44 | 45 | # Install Python 3.8 46 | .PHONY: install-python3.8 47 | install-python3.8: 48 | apt-get install -y --no-install-suggests --no-install-recommends \ 49 | python3.8 python3.8-dev python3-distutils python3-setuptools 50 | 51 | # Install Python 3.7 52 | .PHONY: install-python3.7 53 | install-python3.7: 54 | apt-get install -y --no-install-suggests --no-install-recommends \ 55 | python3.7 python3.7-dev python3-distutils python3-setuptools 56 | 57 | # Install Python 3.6 58 | .PHONY: install-python3.6 59 | install-python3.6: 60 | apt-get install -y --no-install-suggests --no-install-recommends \ 61 | python3.6 python3.6-dev python3-distutils python3-setuptools \ 62 | 63 | # Install phantomjs for holoviews image save 64 | .PHONY: install-phantomjs 65 | install-phantomjs: 66 | curl -sSL https://deb.nodesource.com/gpgkey/nodesource.gpg.key | apt-key add - && \ 67 | echo "deb https://deb.nodesource.com/node_10.x ${UBUNTU_NAME} main" | tee /etc/apt/sources.list.d/nodesource.list && \ 68 | echo "deb-src https://deb.nodesource.com/node_10.x ${UBUNTU_NAME} main" | tee -a /etc/apt/sources.list.d/nodesource.list && \ 69 | 
apt-get update && apt-get install -y nodejs && \ 70 | npm install phantomjs --unsafe-perm && \ 71 | npm install -g phantomjs-prebuilt --unsafe-perm 72 | 73 | # Install common python dependencies 74 | .PHONY: install-python-libs 75 | install-python-libs: 76 | python3 -m pip install -U pip && \ 77 | pip3 install --no-cache-dir setuptools wheel cython pipenv && \ 78 | pip3 install --no-cache-dir matplotlib && \ 79 | python3 -c "import matplotlib; matplotlib.use('Agg'); import matplotlib.pyplot" 80 | 81 | .PHONY: install-env-deps 82 | install-env-deps: 83 | apt-get update 84 | apt-get install -y --no-install-suggests --no-install-recommends \ 85 | libglfw3 libglew-dev libgl1-mesa-glx libosmesa6 xvfb swig 86 | 87 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Welcome to Plangym 2 | 3 | [![Documentation Status](https://readthedocs.org/projects/plangym/badge/?version=latest)](https://plangym.readthedocs.io/en/latest/?badge=latest) 4 | [![Code coverage](https://codecov.io/github/FragileTech/plangym/coverage.svg)](https://codecov.io/github/FragileTech/plangym) 5 | [![PyPI package](https://badgen.net/pypi/v/plangym)](https://pypi.org/project/plangym/) 6 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/ambv/black) 7 | [![license: MIT](https://img.shields.io/badge/license-MIT-green.svg)](https://opensource.org/licenses/MIT) 8 | 9 | Plangym is an open source Python library for developing and comparing planning algorithms by providing a 10 | standard API to communicate between algorithms and environments, as well as a standard set of environments 11 | compliant with that API. 
12 | 13 | Given that OpenAI's `gym` has become the de-facto standard in the research community, `plangym`'s API 14 | is designed to be as similar as possible to `gym`'s API while allowing to modify the environment state. 15 | 16 | Furthermore, it provides additional functionality for stepping the environments in parallel, delayed environment 17 | initialization for dealing with environments that are difficult to serialize, compatibility with `gym.Wrappers`, 18 | and more. 19 | 20 | ## Supported environments 21 | Plangym currently supports all the following environments: 22 | 23 | * OpenAI gym classic control environments 24 | * OpenAI gym Box2D environments 25 | * OpenAI gym Atari 2600 environments 26 | * Deepmind's dm_control environments 27 | * Stable-retro environments 28 | 29 | ## Getting started 30 | 31 | ### Stepping an environment 32 | ```python 33 | import plangym 34 | env = plangym.make(name="CartPole-v0") 35 | state, obs, info = env.reset() 36 | 37 | state = state.copy() 38 | action = env.action_space.sample() 39 | 40 | data = env.step(state=state, action=action) 41 | new_state, observ, reward, end, truncated, info = data 42 | ``` 43 | 44 | 45 | ### Stepping a batch of states and actions 46 | ```python 47 | import plangym 48 | env = plangym.make(name="CartPole-v0") 49 | state, obs, info = env.reset() 50 | 51 | states = [state.copy() for _ in range(10)] 52 | actions = [env.action_space.sample() for _ in range(10)] 53 | 54 | data = env.step_batch(states=states, actions=actions) 55 | new_states, observs, rewards, ends, truncateds, infos = data 56 | ``` 57 | 58 | 59 | ### Using parallel steps 60 | 61 | ```python 62 | import plangym 63 | env = plangym.make(name="MsPacman-v0", n_workers=2) 64 | 65 | state, obs, info = env.reset() 66 | 67 | states = [state.copy() for _ in range(10)] 68 | actions = [env.action_space.sample() for _ in range(10)] 69 | 70 | data = env.step_batch(states=states, actions=actions) 71 | new_states, observs, rewards, ends, truncateds, 
infos = data 72 | ``` 73 | 74 | ## Installation 75 | TODO: Meanwhile take a look at how we set up the repository in `.github/workflows/push.yaml`. 76 | 77 | ## License 78 | Plangym is released under the [MIT](LICENSE) license. 79 | 80 | ## Contributing 81 | 82 | Contributions are very welcome! Please check the [contributing guidelines](CONTRIBUTING.md) before opening a pull request. 83 | 84 | If you have any suggestions for improvement, or you want to report a bug please open 85 | an [issue](https://github.com/FragileTech/plangym/issues). 86 | 87 | 88 | # Installing nes-py 89 | 90 | #### Step 1: Install necessary development tools and libraries 91 | sudo apt-get update 92 | sudo apt-get install build-essential clang 93 | sudo apt-get install libstdc++-10-dev 94 | 95 | #### Step 2: Verify the compiler and include paths 96 | #### Ensure you are using g++ instead of clang++ if clang++ is not properly configured 97 | export CXX=g++ 98 | export CC=gcc 99 | 100 | # Rebuild the project 101 | rye install nes-py --git=https://github.com/FragileTech/nes-py -------------------------------------------------------------------------------- /_old_Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:20.04 2 | ARG JUPYTER_PASSWORD="plangym" 3 | ARG ROM_PASSWORD="" 4 | 5 | ENV BROWSER=/browser \ 6 | LC_ALL=en_US.UTF-8 \ 7 | LANG=en_US.UTF-8 8 | COPY Makefile.docker Makefile 9 | 10 | RUN apt-get update && \ 11 | apt-get install -y --no-install-suggests --no-install-recommends make cmake curl ssh && \ 12 | make install-python3.8 && \ 13 | make install-common-dependencies && \ 14 | make install-python-libs && \ 15 | make install-env-deps 16 | 17 | COPY . 
plangym/ 18 | 19 | RUN cd plangym \ 20 | && make install-mujoco \ 21 | && python3 -m pip install -r requirements-lint.txt \ 22 | && python3 -m pip install -r requirements-test.txt \ 23 | && python3 -m pip install -r requirements.txt \ 24 | && python3 -m pip install ipython jupyter \ 25 | && python3 -m pip install -e . \ 26 | && ROM_PASSWORD=${ROM_PASSWORD} make import-roms \ 27 | && git config --global init.defaultBranch master \ 28 | && git config --global user.name "Whoever" \ 29 | && git config --global user.email "whoever@fragile.tech" 30 | 31 | RUN make remove-dev-packages 32 | 33 | RUN mkdir /root/.jupyter && \ 34 | echo 'c.NotebookApp.token = "'${JUPYTER_PASSWORD}'"' > /root/.jupyter/jupyter_notebook_config.py 35 | CMD jupyter notebook --allow-root --port 8080 --ip 0.0.0.0 36 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import pytest 3 | 4 | import plangym 5 | 6 | 7 | @pytest.fixture(autouse=True) 8 | def add_imports(doctest_namespace): 9 | """Define names and aliases for the code docstrings.""" 10 | doctest_namespace["np"] = numpy 11 | doctest_namespace["plangym"] = plangym 12 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | 22 | .PHONY: server 23 | server: 24 | python3 -m http.server --directory build/html/ 25 | 26 | .PHONY: test 27 | test: 28 | make html 29 | make server -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. 
For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | from unittest.mock import MagicMock 16 | 17 | 18 | class Mock(MagicMock): 19 | @classmethod 20 | def __getattr__(cls, name): 21 | return MagicMock() 22 | 23 | 24 | sys.path.insert(0, os.path.abspath("../../")) 25 | sys.setrecursionlimit(1500) 26 | MOCK_MODULES = [] 27 | sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) 28 | 29 | # -- Project information ----------------------------------------------------- 30 | project = "plangym" 31 | copyright = "2018, FragileTech" 32 | author = "Guillem Duran Ballester" 33 | 34 | # The short X.Y version 35 | from plangym.version import __version__ 36 | 37 | 38 | version = __version__ 39 | # The full version, including alpha/beta/rc tags 40 | release = __version__ 41 | # -- General configuration --------------------------------------------------- 42 | # List of patterns, relative to source directory, that match files and 43 | # directories to ignore when looking for source files. 44 | exclude_patterns = ["_build", "**.ipynb_checkpoints"] 45 | # The master toctree document. 46 | master_doc = "index" 47 | # Add any Sphinx extension module names here, as strings. They can be 48 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 49 | # ones. 
50 | extensions = [ 51 | "sphinx.ext.autodoc", 52 | "autoapi.extension", 53 | "sphinx.ext.doctest", 54 | "sphinx.ext.intersphinx", 55 | "sphinx.ext.todo", 56 | "sphinx.ext.coverage", 57 | "sphinx.ext.imgmath", 58 | "sphinx.ext.viewcode", 59 | "sphinx.ext.napoleon", 60 | "sphinx.ext.autosectionlabel", 61 | "sphinx.ext.autodoc.typehints", 62 | "sphinx_book_theme", 63 | "myst_nb", 64 | "sphinxcontrib.mermaid", 65 | "sphinx.ext.githubpages", 66 | "sphinx_copybutton", 67 | "sphinx_togglebutton", 68 | ] 69 | suppress_warnings = ["image.nonlocal_uri"] 70 | autodoc_typehints = "description" 71 | # Autoapi settings 72 | autoapi_type = "python" 73 | autoapi_dirs = ["../../src/plangym"] 74 | autoapi_add_toctree_entry = True 75 | # Make use of custom templates 76 | autoapi_template_dir = "_autoapi_templates" 77 | exclude_patterns.append("_autoapi_templates/index.rst") 78 | 79 | # Ignore sphinx-autoapi warnings on multiple target description 80 | suppress_warnings.append("ref.python") 81 | 82 | # Napoleon settings 83 | napoleon_google_docstring = True 84 | napoleon_numpy_docstring = True 85 | napoleon_include_init_with_doc = True 86 | napoleon_include_private_with_doc = False 87 | napoleon_include_special_with_doc = True 88 | napoleon_use_admonition_for_examples = False 89 | napoleon_use_admonition_for_notes = False 90 | napoleon_use_admonition_for_references = False 91 | napoleon_use_ivar = False 92 | napoleon_use_param = True 93 | napoleon_use_rtype = True 94 | 95 | # Add any paths that contain templates here, relative to this directory. 96 | templates_path = ["_templates"] 97 | 98 | # List of patterns, relative to source directory, that match files and 99 | # directories to ignore when looking for source files. 100 | # This pattern also affects html_static_path and html_extra_path. 101 | exclude_patterns = [] 102 | 103 | 104 | # -- Options for HTML output ------------------------------------------------- 105 | 106 | # The theme to use for HTML and HTML Help pages. 
See the documentation for 107 | # a list of builtin themes. 108 | # 109 | html_title = "" 110 | html_theme = "sphinx_book_theme" 111 | # html_logo = "_static/logo-wide.svg" 112 | # html_favicon = "_static/logo-square.svg" 113 | html_theme_options = { 114 | "github_url": "https://github.com/fragiletech/plangym", 115 | "repository_url": "https://github.com/fragiletech/plangym", 116 | "repository_branch": "gh-pages", 117 | "home_page_in_toc": True, 118 | "path_to_docs": "docs", 119 | "show_navbar_depth": 1, 120 | "use_edit_page_button": True, 121 | "use_repository_button": True, 122 | "use_download_button": True, 123 | "launch_buttons": { 124 | "binderhub_url": "https://mybinder.org", 125 | "notebook_interface": "classic", 126 | }, 127 | } 128 | 129 | # Add any paths that contain custom static files (such as style sheets) here, 130 | # relative to this directory. They are copied after the builtin static files, 131 | # so a file named "default.css" will overwrite the builtin "default.css". 132 | html_static_path = ["_static"] 133 | 134 | # myst_parser options 135 | nb_execution_mode = "off" 136 | myst_heading_anchors = 2 137 | myst_enable_extensions = [ 138 | "amsmath", 139 | "colon_fence", 140 | "deflist", 141 | "dollarmath", 142 | "html_admonition", 143 | "html_image", 144 | "linkify", 145 | "replacements", 146 | "smartquotes", 147 | "substitution", 148 | ] 149 | 150 | 151 | # If true, `todo` and `todoList` produce output, else they produce nothing. 
152 | todo_include_todos = True 153 | -------------------------------------------------------------------------------- /docs/source/index.md: -------------------------------------------------------------------------------- 1 | ```{include} ../../README.md 2 | ``` 3 | 4 | ```{toctree} 5 | --- 6 | maxdepth: 1 7 | caption: Introduction 8 | --- 9 | markdown/readme.md 10 | ``` 11 | 12 | ```{toctree} 13 | --- 14 | maxdepth: 5 15 | caption: User guide 16 | --- 17 | notebooks/plangym_introduction.ipynb 18 | notebooks/tutorial.md 19 | ``` 20 | 21 | ```{toctree} 22 | --- 23 | maxdepth: 2 24 | caption: API Reference 25 | --- 26 | autoapi/index.rst 27 | ``` -------------------------------------------------------------------------------- /docs/source/markdown/readme.md: -------------------------------------------------------------------------------- 1 | ```{include} ../../../README.md 2 | ``` -------------------------------------------------------------------------------- /docs/source/notebooks/plangym_introduction.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "This is an introductory tutorial to the main features of plangym." 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Working with states" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "### `reset` and `step` return the environment state" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "The main difference with the `gym` API is that environment state is considered as important as observations, rewards and terminal flags. 
This is why plangym incorporates them to the tuples that the environment returns after calling `step` and `reset`:\n", 29 | "\n", 30 | "- The `reset` method will return a tuple of (state, observation) unless you pass `return_state=False` as an argument.\n", 31 | "\n", 32 | "- When `step` is called passing the environment state as an argument it will return a tuple containing `(state, obs, reward, end, info)`" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 1, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "import plangym\n", 42 | "\n", 43 | "env = plangym.make(\"CartPole-v0\")\n", 44 | "action = env.action_space.sample()\n", 45 | "\n", 46 | "state, obs = env.reset()\n", 47 | "state, obs, reward, end, info = env.step(action, state)" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": {}, 53 | "source": [ 54 | "However, if you don't provide the environment state when calling `step`, the returned tuple will match the standard `gym` interface:" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 2, 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "env = plangym.make(\"CartPole-v0\")\n", 64 | "action = env.action_space.sample()\n", 65 | "\n", 66 | "obs = env.reset(return_state=False)\n", 67 | "obs, reward, end, info = env.step(action)" 68 | ] 69 | }, 70 | { 71 | "cell_type": "markdown", 72 | "metadata": {}, 73 | "source": [ 74 | "### Accessing and modifying the environment state" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "You can get a copy of the environment's state calling `env.get_state()`:" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 3, 87 | "metadata": {}, 88 | "outputs": [ 89 | { 90 | "data": { 91 | "text/plain": [ 92 | "array([ 0.03145539, 0.17749025, 0.01348916, -0.25611924])" 93 | ] 94 | }, 95 | "execution_count": 3, 96 | "metadata": {}, 97 | "output_type": "execute_result" 98 | } 99 | ], 100 
| "source": [ 101 | "state = env.get_state()\n", 102 | "state" 103 | ] 104 | }, 105 | { 106 | "cell_type": "markdown", 107 | "metadata": {}, 108 | "source": [ 109 | "And set the environment state using `env.set_state(state)`" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 4, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "env.set_state(state)\n", 119 | "assert (state == env.get_state()).all()" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "## Step vectorization" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": {}, 132 | "source": [ 133 | "All plangym environments offer a `step_batch` method that allows vectorized steps of batches of states and actions. \n", 134 | "\n", 135 | "Calling `step_batch` with a list of states and actions will return a tuple of lists containing the step data for each of the states and actions provided." 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 7, 141 | "metadata": {}, 142 | "outputs": [ 143 | { 144 | "data": { 145 | "text/plain": [ 146 | "(list, list)" 147 | ] 148 | }, 149 | "execution_count": 7, 150 | "metadata": {}, 151 | "output_type": "execute_result" 152 | } 153 | ], 154 | "source": [ 155 | "states = [state.copy() for _ in range(10)]\n", 156 | "actions = [env.action_space.sample() for _ in range(10)]\n", 157 | "\n", 158 | "data = env.step_batch(states=states, actions=actions)\n", 159 | "new_states, observs, rewards, ends, infos = data\n", 160 | "type(new_states), type(observs)" 161 | ] 162 | }, 163 | { 164 | "cell_type": "markdown", 165 | "metadata": {}, 166 | "source": [ 167 | "### Parallel step vectorization using multiprocessing" 168 | ] 169 | }, 170 | { 171 | "cell_type": "markdown", 172 | "metadata": {}, 173 | "source": [ 174 | "Passing the argument `n_workers` to `plangym.make` will return an environment that steps a batch of actions and states in parallel using 
multiprocessing." 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 9, 180 | "metadata": {}, 181 | "outputs": [ 182 | { 183 | "data": { 184 | "text/plain": [ 185 | "(plangym.parallel.ParallelEnv, list, list)" 186 | ] 187 | }, 188 | "execution_count": 9, 189 | "metadata": {}, 190 | "output_type": "execute_result" 191 | } 192 | ], 193 | "source": [ 194 | "env = plangym.make(\"CartPole-v0\", n_workers=2)\n", 195 | "states = [state.copy() for _ in range(10)]\n", 196 | "actions = [env.action_space.sample() for _ in range(10)]\n", 197 | "\n", 198 | "data = env.step_batch(states=states, actions=actions)\n", 199 | "new_states, observs, rewards, ends, infos = data\n", 200 | "type(env), type(new_states), type(observs)" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "metadata": {}, 206 | "source": [ 207 | "### Step vectorization using ray" 208 | ] 209 | }, 210 | { 211 | "cell_type": "markdown", 212 | "metadata": {}, 213 | "source": [ 214 | "It is possible to use ray actors to step the environment in parallel when calling `step_batch` by passing `ray=True` to `plangym.make`" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 10, 220 | "metadata": {}, 221 | "outputs": [ 222 | { 223 | "name": "stderr", 224 | "output_type": "stream", 225 | "text": [ 226 | "2021-12-13 10:01:47,772\tINFO services.py:1247 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265\u001b[39m\u001b[22m\n" 227 | ] 228 | }, 229 | { 230 | "data": { 231 | "text/plain": [ 232 | "(plangym.ray.RayEnv, list, list)" 233 | ] 234 | }, 235 | "execution_count": 10, 236 | "metadata": {}, 237 | "output_type": "execute_result" 238 | } 239 | ], 240 | "source": [ 241 | "import ray\n", 242 | "ray.init()\n", 243 | "\n", 244 | "env = plangym.make(\"CartPole-v0\", n_workers=2, ray=True)\n", 245 | "states = [state.copy() for _ in range(10)]\n", 246 | "actions = [env.action_space.sample() for _ in range(10)]\n", 247 | "\n", 248 | "data = 
env.step_batch(states=states, actions=actions)\n", 249 | "new_states, observs, rewards, ends, infos = data\n", 250 | "type(env), type(new_states), type(observs)" 251 | ] 252 | } 253 | ], 254 | "metadata": { 255 | "kernelspec": { 256 | "display_name": "Python 3", 257 | "language": "python", 258 | "name": "python3" 259 | }, 260 | "language_info": { 261 | "codemirror_mode": { 262 | "name": "ipython", 263 | "version": 3 264 | }, 265 | "file_extension": ".py", 266 | "mimetype": "text/x-python", 267 | "name": "python", 268 | "nbconvert_exporter": "python", 269 | "pygments_lexer": "ipython3", 270 | "version": "3.8.10" 271 | } 272 | }, 273 | "nbformat": 4, 274 | "nbformat_minor": 4 275 | } 276 | -------------------------------------------------------------------------------- /docs/source/notebooks/tutorial.md: -------------------------------------------------------------------------------- 1 | ## Welcome to Plangym 2 | 3 | Plangym is an open source Python library for developing and comparing planning algorithms by providing a 4 | standard API to communicate between algorithms and environments, as well as a standard set of environments 5 | compliant with that API. 6 | 7 | Furthermore, it provides additional functionality for stepping the environments in parallel, delayed environment 8 | initialization for dealing with environments that are difficult to serialize, compatibility with `gym.Wrappers`, 9 | and more. 10 | 11 | ## API for reinforcement learning 12 | 13 | OpenAI's `gym` has become the de-facto standard in the research community, `plangym`'s API 14 | is designed to be as similar as possible to `gym`'s API while allowing to modify the environment state. 
15 | `plangym` offers a standard API for reinforcement learning problems with a simple, intuitive interface.\ 16 | Users with general knowledge of `gym` syntax will feel comfortable using `plangym`; it uses the 17 | same schema and philosophy as the former, yet `plangym` provides new advanced functionality beyond `gym`'s 18 | capabilities. 19 | 20 | ## Plangym states 21 | 22 | The principal attribute that characterizes `plangym` and distinguishes it from other libraries is the capacity to 23 | save the current state of the environment. By simply calling `get_state()`, the user is able to store the positions, 24 | attributes, actions, and all necessary information that has led the agent and the environment to their current state. 25 | In this way, the user can load a specific configuration of the simulation and continue the process from that 26 | precise state. 27 | 28 | ## Getting started 29 | 30 | ### Stepping an environment 31 | 32 | We initialize the environment using the command `plangym.make`, similarly to `gym` syntax. By resetting 33 | the environment, we get our initial _state_, _observation_, and _info_. As mentioned, the fact that the environment 34 | returns its current state is one of the main `plangym` features; we are able to __get__ and __set__ 35 | the precise configuration of the environment in each step as if we were loading and saving the 36 | data of a game. This option allows the user to apply a specific action to an explicit state: 37 | 38 | ```python 39 | import plangym 40 | env = plangym.make(name="CartPole-v0") 41 | state, obs, info = env.reset() 42 | 43 | state = state.copy() 44 | action = env.action_space.sample() 45 | 46 | data = env.step(state=state, action=action) 47 | new_state, observ, reward, end, truncated, info = data 48 | ``` 49 | 50 | We interact with the environment by applying an action to a specific environment state via `plangym.PlanEnv.step`. 51 | We can define the exact environment state over which we apply our action.
52 | 53 | As expected, this function returns the new state of the environment, 54 | the observation, the reward of the performed action, whether the agent entered 55 | a terminal state, whether the episode was truncated, and additional information about the process. 56 | 57 | If we are not interested in getting the current state of the environment, we simply set the argument 58 | `return_state = False` in the methods `plangym.PlanEnv.reset` and `plangym.PlanEnv.step`: 59 | 60 | ```python 61 | import plangym 62 | env = plangym.make(name="CartPole-v0") 63 | obs = env.reset(return_state=False) 64 | 65 | action = env.action_space.sample() 66 | 67 | data = env.step(action=action, return_state=False) 68 | observ, reward, end, info = data 69 | ``` 70 | 71 | By setting `return_state=False`, neither `reset()` nor `step()` will return the state of the simulation. In this way, 72 | we obtain exactly the same outputs as if we were working with a plain `gym` interface. Thus, `plangym` 73 | provides a complete tool for developing planning projects __as well as__ a general, standard API for reinforcement learning problems. 74 | 75 | ### Stepping a batch of states and actions 76 | ```python 77 | import plangym 78 | env = plangym.make(name="CartPole-v0") 79 | state, obs, info = env.reset() 80 | 81 | states = [state.copy() for _ in range(10)] 82 | actions = [env.action_space.sample() for _ in range(10)] 83 | 84 | data = env.step_batch(states=states, actions=actions) 85 | new_states, observs, rewards, ends, truncateds, infos = data 86 | ``` 87 | 88 | `plangym` allows applying multiple actions in a single call via the command `plangym.PlanEnv.step_batch`. 89 | The syntax used for this case is reminiscent of that employed when calling the `step` function; we should define 90 | a __list__ of states and actions and use them as arguments of the function `step_batch()`. `plangym` will 91 | take care of distributing the states and actions accordingly, returning a tuple with the results 92 | of such actions.
93 | 94 | ### Making environments 95 | 96 | To initialize an environment, `plangym` uses the same syntax as `gym` via the `plangym.make` command. However, 97 | this command offers more advanced options than the `gym` standard; it controls the general behavior of the API and 98 | its different environments, and it serves as a command center between the user and the library. 99 | 100 | Instead of using a specific syntax for each environment (with distinct arguments and parameters), 101 | `plangym` unifies all options within a single, common framework under the control of the 102 | `make()` command. 103 | 104 | All instance attributes are defined through the `make()` command, which classifies and distributes them depending on whether 105 | they are `plangym` parameters or standard `gym` parameters. In addition, `make()` also allows the user to configure the 106 | parameters needed for stepping the environment in parallel. One only needs to select the desired mode, and 107 | `plangym` will do the rest. 108 | 109 | ```python 110 | import plangym 111 | env = plangym.make( 112 | name="PlanMontezuma-v0", # name of the environment 113 | n_workers=4, # Number of parallel processes 114 | state='', # Define a specific state for the environment 115 | ) 116 | ``` 117 | Once the parameters have been provided, the command instantiates the appropriate environment class 118 | with the given attributes. 119 | 120 | 121 | #### Make arguments 122 | 123 | `make()` accepts multiple arguments when creating an environment. We should distinguish between the arguments 124 | passed to configure the environment creation process and those used to instantiate the environment itself. 125 | * Make signature: 126 | Attributes used to configure the process that creates the environment. 127 | * `name`: Name of the environment. 128 | * `n_workers`: Number of workers that will be used to step the environment. 129 | * `ray`: Use ray for taking steps in parallel when calling `step_batch()`.
130 | * `domain_name`: Name of the agent (domain) to simulate. 131 | It is a keyword argument that is only valid for 132 | `dm_control` environments. 133 | * `state`: Define a specific state for the environment. The state parameter 134 | only works for `RetroEnv`, and it is used to select the starting level of the 135 | selected game. All the other environments do not accept `state` as a keyword argument, 136 | and specific states can be set using `get_state()` and `set_state()`. 137 | * Environment instance attributes: 138 | Parameters passed when the class is created. They define and configure the attributes of the class. `make()` accepts 139 | these arguments as _kwargs_. 140 | 141 | All keyword arguments that do not belong to the _Make arguments_ list are passed as _kwargs_ inside `make()` 142 | to instantiate the corresponding environment class (we must emphasize that `plangym` will also use 143 | some attributes included inside the _Make arguments_ classification as instance attributes of the class, such 144 | as `state` or `domain_name`). 145 | 146 | #### Instance attributes 147 | 148 | As mentioned, users have several parameters at their disposal to configure the environment creation process 149 | and the attributes of the class itself. Instance parameters are passed as _kwargs_ to the environment class. 150 | 151 | Among these instance attributes, we should differentiate between the attributes managed by `plangym` and 152 | those that are specific to the `gym` library. `plangym` attributes characterize the envelope that wraps 153 | the original `gym` environment, offering a standard interface across all environments. `gym` attributes are 154 | those not managed by `plangym` and are passed __directly__ to the `gym.make` method. 155 | 156 | The instance attributes (managed by `plangym`) common to all environment classes are: 157 | * `name`: Name of the environment. Follows standard gym syntax conventions.
158 | * `frameskip`: Number of times an action will be applied for each ``dt``. When we __step__ the environment, 159 | we take `dt` simulation steps, i.e., we evolve the environment _dt_ times (by applying an action) __in each__ 160 | step. Within __each__ simulation step `dt`, we apply the same action `frameskip` times. In total, 161 | the environment will have evolved `dt * frameskip` times. 162 | * `autoreset`: Automatically reset the `plangym.environment` when the OpenAI environment returns ``end = True``. 163 | * `wrappers`: Wrappers that will be applied to the underlying OpenAI environment. Every element 164 | of the iterable can be either a class `gym.Wrapper` or a tuple containing ``(gym.Wrapper, kwargs)``. 165 | * `delay_setup`: If ``True``, `plangym` does not initialize the class `gym.environment` and 166 | waits for ``setup`` to be called later. Deferring the environment instantiation gives the users 167 | the option to create it in external processes or when demanded. This allows sending `plangym.environment` 168 | as serializable objects, leaving all the settings already defined and prepared for the user to 169 | instantiate the environment when needed. 170 | * `remove_time_limit`: If `True`, remove the time limit from the environment. 171 | * `render_mode`: Select how the environment and the observations are represented. Options to be 172 | selected are `[None, "human", "rgb_array"]`. 173 | * `episodic_life`: If `True`, `plangym` sends a terminal signal after losing a life. 174 | * `obs_type`: Define how `plangym` calculates the observations. Options to be selected 175 | are `["coords", "rgb", "grayscale", None]`. 176 | * `return_image`: If ``True``, `plangym` adds an "rgb" key in the `info` dictionary returned by 177 | the `plangym.env.step` method. This key contains an RGB representation of the environment state.
178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | -------------------------------------------------------------------------------- /docs/source/notes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FragileTech/plangym/417325eac9be9668f4065351ba22f3f6461f5a4f/docs/source/notes/__init__.py -------------------------------------------------------------------------------- /docs/source/notes/interface.rst: -------------------------------------------------------------------------------- 1 | Plangym API 2 | ======================================= 3 | 4 | .. autoclass:: plangym.core::PlanEnv 5 | :members: 6 | 7 | .. autoclass:: plangym.core::PlangymEnv 8 | :members: 9 | 10 | Videogames 11 | =========== 12 | 13 | Atari 2600 14 | ----------------- 15 | .. autoclass:: plangym.videogames.atari.AtariEnv 16 | :members: 17 | 18 | .. autoclass:: plangym.videogames.montezuma.MontezumaEnv 19 | :members: 20 | 21 | Gym retro 22 | ----------------- 23 | 24 | .. autoclass:: plangym.videogames.retro.RetroEnv 25 | :members: 26 | 27 | Super Mario (NES) 28 | ----------------- 29 | 30 | .. autoclass:: plangym.videogames.nes.MarioEnv 31 | :members: 32 | 33 | .. autoclass:: plangym.videogames.nes.NesEnv 34 | :members: 35 | 36 | Video games API 37 | ----------------- 38 | .. autoclass:: plangym.videogames.env.VideogameEnv 39 | :members: 40 | 41 | Control Tasks 42 | ============= 43 | 44 | DM Control 45 | ----------------- 46 | .. autoclass:: plangym.control.dm_control.DMControlEnv 47 | :members: 48 | 49 | Classic control 50 | ----------------- 51 | 52 | .. autoclass:: plangym.control.classic_control.ClassicControl 53 | :members: 54 | 55 | Box2D 56 | ----------------- 57 | 58 | .. autoclass:: plangym.control.box_2d.Box2DEnv 59 | :members: 60 | 61 | .. autoclass:: plangym.control.lunar_lander.LunarLander 62 | :members: 63 | 64 | Vectorization 65 | ============== 66 | 67 | Multiprocessing 68 | ---------------- 69 | ..
autoclass:: plangym.vectorization.parallel.ParallelEnv 70 | :members: 71 | 72 | Ray 73 | ---------------- 74 | .. autoclass:: plangym.vectorization.ray.RayEnv 75 | :members: 76 | 77 | Vectorization API 78 | ----------------- 79 | .. autoclass:: plangym.vectorization.env.VectorizedEnv 80 | :members: 81 | 82 | -------------------------------------------------------------------------------- /install-lua-macos.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # This script installs Lua, LuaRocks, and some Lua libraries on macOS. 4 | # The main purpose is to install Busted for testing Neovim plugins. 5 | # After the installation, you will be able to run tests using busted: 6 | # busted --lua nlua spec/mytest_spec.lua 7 | 8 | ################################################################################ 9 | # Dependencies 10 | ################################################################################ 11 | 12 | xcode-select --install 13 | 14 | # Lua Directory: where Lua and Luarocks will be installed 15 | # You can change the installation location by changing this variable 16 | LUA_DIR="$HOME/Developer/lua" 17 | 18 | mkdir -p $LUA_DIR 19 | 20 | ################################################################################ 21 | # Lua 22 | ################################################################################ 23 | 24 | # Download and Extract Lua Sources 25 | cd /tmp 26 | rm -rf lua-5.1.5.* 27 | wget https://www.lua.org/ftp/lua-5.1.5.tar.gz 28 | LUA_SHA='2640fc56a795f29d28ef15e13c34a47e223960b0240e8cb0a82d9b0738695333' 29 | shasum -a 256 lua-5.1.5.tar.gz | grep -q $LUA_SHA && echo "Hash matches" || echo "Hash don't match" 30 | tar xvf lua-5.1.5.tar.gz 31 | cd lua-5.1.5/ 32 | 33 | # Modify Makefile to set destination dir 34 | sed -i '' "s#/usr/local#${LUA_DIR}/#g" Makefile 35 | 36 | # Compile and install Lua 37 | make macosx 38 | make test && make install 39 | 40 | # Export PATHs 41 | export
PATH="$PATH:$LUA_DIR/bin" 42 | export LUA_CPATH="$LUA_DIR/lib/lua/5.1/?.so" 43 | export LUA_PATH="$LUA_DIR/share/lua/5.1/?.lua;;" 44 | export MANPATH="$LUA_DIR/share/man:$MANPATH" 45 | 46 | # Verify Lua Installation 47 | which lua 48 | echo "Expected Output:" 49 | echo " ${LUA_DIR}/bin/lua" 50 | lua -v 51 | echo 'Expected Output:' 52 | echo ' Lua 5.1.5 Copyright (C) 1994-2012 Lua.org, PUC-Rio' 53 | file ${LUA_DIR}/bin/lua 54 | echo "Expected Output (on Apple Silicon):" 55 | echo " ${LUA_DIR}/bin/lua: Mach-O 64-bit executable arm64" -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "plangym" 3 | dynamic = ["version"] 4 | description = "Plangym is an interface to use gymnasium for planning problems. It extends the standard interface to allow setting and recovering the environment states." 5 | authors = [{ name = "Guillem Duran Ballester", email = "guillem@fragile.tech" }] 6 | maintainers = [{ name = "Guillem Duran Ballester", email = "guillem@fragile.tech" }] 7 | license = {file = "LICENSE"} 8 | readme = "README.md" 9 | requires-python = ">=3.10" 10 | packages = [{ include = "plangym", from = "src" }] 11 | include = ["tests/**/*", "tests/**/.*"] 12 | homepage = "https://github.com/FragileTech/plangym" 13 | repository = "https://github.com/FragileTech/plangym" 14 | documentation = "https://github.com/FragileTech/plangym" 15 | keywords = ["RL", "gymnasium", "planning", "plangym"] 16 | classifiers = [ 17 | "Development Status :: 3 - Alpha", 18 | "Intended Audience :: Developers", 19 | "License :: OSI Approved :: MIT License", 20 | "Programming Language :: Python :: 3.10", 21 | "Topic :: Software Development :: Libraries", 22 | ] 23 | dependencies = [ 24 | "numpy", 25 | "pillow; sys_platform != 'darwin'", 26 | "fragile-gym", 27 | "opencv-python>=4.10.0.84", 28 | "pyglet==1.5.11", 29 | "pyvirtualdisplay>=3.0", 30 | 
"imageio>=2.35.1", 31 | "flogging>=0.0.22", 32 | ] 33 | [project.optional-dependencies] 34 | atari = ["ale-py", "gymnasium[accept-rom-license,atari]>=0.29.1, == 0.*"] 35 | nes = [ 36 | "fragile-gym[accept-rom-license]", 37 | "fragile-nes-py>=10.0.1", # Requires clang, build-essential 38 | "fragile-gym-super-mario-bros>=7.4.1", 39 | ] 40 | classic-control = ["gymnasium[classic_control]>=0.29.1, == 0.*", "pygame>=2.6.0"] 41 | ray = ["ray>=2.35.0"] 42 | dm_control = ["mujoco>=3.2.2", "dm-control>=1.0.22"] 43 | retro = [ 44 | "stable-retro==0.9.2; sys_platform != 'darwin'", 45 | "stable-retro==0.9.1; sys_platform == 'darwin'" 46 | ] 47 | jupyter = ["jupyterlab>=3.2.0"] 48 | box_2d = ["box2d-py==2.3.5"] 49 | test = [ 50 | "psutil>=5.8.0", 51 | "pytest>=6.2.5", 52 | "pytest-cov>=3.0.0", 53 | "pytest-xdist>=2.4.0", 54 | "pytest-rerunfailures>=10.2", 55 | "pyvirtualdisplay>=1.3.2", 56 | "tomli>=1.2.3", 57 | "hypothesis>=6.24.6" 58 | ] 59 | docs = [ 60 | "sphinx", 61 | "linkify-it-py", 62 | "myst-parser", 63 | "myst-nb", 64 | "ruyaml", 65 | "sphinx-autoapi", 66 | "pydata-sphinx-theme", 67 | "sphinx-autodoc2", 68 | "sphinxcontrib-mermaid", 69 | "sphinx_book_theme", 70 | "sphinx_rtd_theme", 71 | "jupyter-cache", 72 | "sphinx-copybutton", 73 | "sphinx-togglebutton", 74 | "sphinxext-opengraph", 75 | "sphinxcontrib-bibtex", 76 | ] 77 | 78 | [build-system] 79 | requires = ["hatchling"] 80 | build-backend = "hatchling.build" 81 | [tool.hatch.metadata] 82 | allow-direct-references = true 83 | [tool.hatch.version] 84 | path = "src/plangym/version.py" 85 | 86 | [tool.rye] 87 | dev-dependencies = ["ruff"] 88 | #excluded-dependencies = ["gym"] 89 | universal = true 90 | 91 | [tool.rye.scripts] 92 | style = { chain = ["ruff check --fix-only --unsafe-fixes tests src", "ruff format tests src"] } 93 | check = { chain = ["ruff check --diff tests src", "ruff format --diff tests src"]} #,"mypy src tests" ] } 94 | test = { chain = ["test:doctest", "test:parallel", "test:singlecore"] } 95 | 
codecov = { chain = ["codecov:singlecore", "codecov:parallel"] } 96 | import-roms = { cmd = "python3 src/plangym/scripts/import_retro_roms.py" } 97 | "test:parallel" = { cmd = "pytest -n auto -s -o log_cli=true -o log_cli_level=info tests", env-file = ".multicore.env" } 98 | "test:singlecore" = { cmd = "pytest -s -o log_cli=true -o log_cli_level=info tests/control/test_classic_control.py", env-file = ".onecore.env" } 99 | "test:doctest" = { cmd = "pytest --doctest-modules -n 0 -s -o log_cli=true -o log_cli_level=info src", env-file = ".multicore.env" } 100 | "codecov:parallel" = { chain = ["codecov:parallel_1", "codecov:parallel_2", "codecov:parallel_3", "codecov:vectorization"] } 101 | "codecov:parallel_1" = { cmd = "pytest -n auto -s -o log_cli=true -o log_cli_level=info --cov=./ --cov-report=xml:coverage_parallel_1.xml --cov-config=pyproject.toml tests/test_core.py tests/test_registry.py tests/test_utils.py", env-file = ".multicore.env" } 102 | "codecov:parallel_2" = { cmd = "pytest -n auto -s -o log_cli=true -o log_cli_level=info --cov=./ --cov-report=xml:coverage_parallel_2.xml --cov-config=pyproject.toml tests/videogames", env-file = ".multicore.env" } 103 | "codecov:parallel_3" = { cmd = "pytest -n auto -s -o log_cli=true -o log_cli_level=info --cov=./ --cov-report=xml:coverage_parallel_3.xml --cov-config=pyproject.toml tests/control", env-file = ".multicore.env" } 104 | "codecov:vectorization" = { cmd = "pytest -n 0 -s -o log_cli=true -o log_cli_level=info --cov=./ --cov-report=xml:coverage_vectorization.xml --cov-config=pyproject.toml tests/vectorization", env-file = ".multicore.env" } 105 | "codecov:singlecore" = { cmd = "pytest --doctest-modules -s -o log_cli=true -o log_cli_level=info --cov=./ --cov-report=xml --cov-config=pyproject.toml tests/control/test_classic_control.py", env-file = ".onecore.env" } 106 | docs = {chain = ["build-docs", "serve-docs"]} 107 | build-docs = { cmd = "sphinx-build -b html docs/source docs/build"} 108 | serve-docs = { cmd 
= "python3 -m http.server --directory docs/build" } 109 | 110 | [tool.ruff] 111 | # Assume Python 3.10 112 | target-version = "py310" 113 | preview = true 114 | include = ["*.py", "*.pyi", "**/pyproject.toml"]#, "*.ipynb"] 115 | # Exclude a variety of commonly ignored directories. 116 | exclude = [ 117 | ".bzr", 118 | ".direnv", 119 | ".eggs", 120 | ".git", 121 | ".git-rewrite", 122 | ".hg", 123 | ".mypy_cache", 124 | ".nox", 125 | ".pants.d", 126 | ".pytype", 127 | ".ruff_cache", 128 | ".svn", 129 | ".tox", 130 | ".venv", 131 | ".idea", 132 | "__pypackages__", 133 | "_build", 134 | "buck-out", 135 | "build", 136 | "dist", 137 | "node_modules", 138 | "output", 139 | "venv", 140 | "experimental", 141 | ".pytest_cache", 142 | "**/.ipynb_checkpoints/**", 143 | "**/proto/**", 144 | "data", 145 | "config", 146 | ] 147 | # Same as Black. 148 | line-length = 99 149 | [tool.ruff.lint] 150 | # Allow unused variables when underscore-prefixed. 151 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 152 | select = [ 153 | "ARG", "C4", "D", "E", "EM", "F", "FBT", 154 | "FLY", "FIX", "FURB", "N", "NPY", 155 | "INP", "ISC", "PERF", "PIE", "PL", 156 | "PTH", "RET", "RUF", "S", "T10", 157 | "TD", "T20", "UP", "YTT", "W", 158 | ] 159 | ignore = [ 160 | "D100", "D211", "D213", "D104", "D203", "D301", "D407", "S101", 161 | "FBT001", "FBT002", "FIX002", "ISC001", "PLR0913", "RUF012", "TD003", 162 | "PTH123", "PLR6301", "PLR0917", "S311", "S403", "PLR0914", "PLR0915", "S608", 163 | "EM102", "PTH111", "FIX004", "UP035", "PLW2901", "S318", "S408", 'S405', 164 | 'E902', "TD001", "TD002", "FIX001", 165 | ] 166 | # Allow autofix for all enabled rules (when `--fix`) is provided. 
167 | fixable = ["ALL"] 168 | unfixable = ["I"] 169 | 170 | [tool.ruff.lint.flake8-quotes] 171 | docstring-quotes = "double" 172 | 173 | [tool.ruff.lint.per-file-ignores] 174 | "__init__.py" = ["E402", "F401"] 175 | "cli.py" = ["PLC0415", "D205", "D400", "D415"] 176 | "core.py" = ["ARG002", "PLR0904"] 177 | "_old_core.py" = ["ALL"] 178 | "lunar_lander.py" = ["PLR2004", "FBT003", "N806"] 179 | "api_tests.py" = ["D", "ARG002", "PLW1508", "FBT003", "PLR2004"] 180 | "montezuma.py" = ["PLR2004", "S101", "ARG002", "TD002"] 181 | "registry.py" = ["PLC0415", "PLR0911"] 182 | "**/docs/**" = ["INP001", "PTH100"] 183 | "**/super_mario_gym/**" = ["ALL"] 184 | "**/{tests,docs,tools}/*" = [ 185 | "E402", "F401", "F811", "D", "S101", "PLR2004", "S105", 186 | "PLW1514", "PTH123", "PTH107", "N811", "PLC0415", "ARG002", 187 | ] 188 | # Enable reformatting of code snippets in docstrings. 189 | [tool.ruff.format] 190 | docstring-code-line-length = 80 191 | docstring-code-format = true 192 | indent-style = "space" 193 | line-ending = "auto" 194 | preview = true 195 | quote-style = "double" 196 | 197 | [tool.mypy] 198 | exclude = ["experimental.*", "deprecated.*"] 199 | ignore_missing_imports = true 200 | 201 | # isort orders and lints imports 202 | [tool.isort] 203 | profile = "black" 204 | line_length = 99 205 | multi_line_output = 3 206 | order_by_type = false 207 | force_alphabetical_sort_within_sections = true 208 | force_sort_within_sections = true 209 | combine_as_imports = true 210 | include_trailing_comma = true 211 | color_output = true 212 | lines_after_imports = 2 213 | honor_noqa = true 214 | skip = ["venv", ".venv"] 215 | skip_glob = ["*.pyx"] 216 | 217 | [tool.pylint.master] 218 | ignore = 'tests' 219 | load-plugins =' pylint.extensions.docparams' 220 | 221 | [tool.pylint.messages_control] 222 | disable = 'all,' 223 | enable = """, 224 | missing-param-doc, 225 | differing-param-doc, 226 | differing-type-doc, 227 | missing-return-doc, 228 | """ 229 | 230 | 
[tool.pytest.ini_options] 231 | # To disable a specific warning --> action:message:category:module:line 232 | filterwarnings = ["ignore::UserWarning", 'ignore::DeprecationWarning'] 233 | addopts = "--ignore=scripts --doctest-continue-on-failure" 234 | 235 | # Code coverage config 236 | [tool.coverage.run] 237 | branch = true 238 | 239 | [tool.coverage.report] 240 | exclude_lines =["no cover", 241 | 'raise NotImplementedError', 242 | 'if __name__ == "__main__":'] 243 | ignore_errors = true 244 | omit = ["tests/*", "src/plangym/scripts/*"] 245 | -------------------------------------------------------------------------------- /requirements.lock: -------------------------------------------------------------------------------- 1 | # generated by rye 2 | # use `rye lock` or `rye sync` to update this lockfile 3 | # 4 | # last locked with the following flags: 5 | # pre: false 6 | # features: [] 7 | # all-features: true 8 | # with-sources: false 9 | # generate-hashes: false 10 | # universal: true 11 | 12 | -e file:. 
13 | absl-py==2.1.0 14 | # via dm-control 15 | # via dm-env 16 | # via labmaze 17 | # via mujoco 18 | accessible-pygments==0.0.5 19 | # via pydata-sphinx-theme 20 | aiosignal==1.3.1 21 | # via ray 22 | alabaster==1.0.0 23 | # via sphinx 24 | ale-py==0.8.1 25 | # via plangym 26 | # via shimmy 27 | anyio==4.4.0 28 | # via httpx 29 | # via jupyter-server 30 | appnope==0.1.4 ; platform_system == 'Darwin' 31 | # via ipykernel 32 | argon2-cffi==23.1.0 33 | # via jupyter-server 34 | argon2-cffi-bindings==21.2.0 35 | # via argon2-cffi 36 | arrow==1.3.0 37 | # via isoduration 38 | astroid==3.3.2 39 | # via sphinx-autoapi 40 | # via sphinx-autodoc2 41 | asttokens==2.4.1 42 | # via stack-data 43 | async-lru==2.0.4 44 | # via jupyterlab 45 | attrs==24.2.0 46 | # via hypothesis 47 | # via jsonschema 48 | # via jupyter-cache 49 | # via referencing 50 | autorom==0.4.2 51 | # via fragile-gym 52 | # via gymnasium 53 | autorom-accept-rom-license==0.6.1 54 | # via autorom 55 | babel==2.16.0 56 | # via jupyterlab-server 57 | # via pydata-sphinx-theme 58 | # via sphinx 59 | beautifulsoup4==4.12.3 60 | # via nbconvert 61 | # via pydata-sphinx-theme 62 | bleach==6.1.0 63 | # via nbconvert 64 | box2d-py==2.3.5 65 | # via plangym 66 | certifi==2024.8.30 67 | # via httpcore 68 | # via httpx 69 | # via requests 70 | cffi==1.17.0 71 | # via argon2-cffi-bindings 72 | # via pyzmq 73 | charset-normalizer==3.3.2 74 | # via requests 75 | click==8.1.7 76 | # via autorom 77 | # via autorom-accept-rom-license 78 | # via jupyter-cache 79 | # via ray 80 | cloudpickle==3.0.0 81 | # via fragile-gym 82 | # via gymnasium 83 | colorama==0.4.6 ; sys_platform == 'win32' or platform_system == 'Windows' 84 | # via click 85 | # via ipython 86 | # via pytest 87 | # via sphinx 88 | # via tqdm 89 | comm==0.2.2 90 | # via ipykernel 91 | coverage==7.6.1 92 | # via pytest-cov 93 | debugpy==1.8.5 94 | # via ipykernel 95 | decorator==5.1.1 96 | # via ipython 97 | defusedxml==0.7.1 98 | # via nbconvert 99 | distro==1.9.0 
100 | # via ruyaml 101 | dm-control==1.0.22 102 | # via plangym 103 | dm-env==1.6 104 | # via dm-control 105 | dm-tree==0.1.8 106 | # via dm-control 107 | # via dm-env 108 | docutils==0.21.2 109 | # via myst-parser 110 | # via pybtex-docutils 111 | # via pydata-sphinx-theme 112 | # via sphinx 113 | # via sphinx-togglebutton 114 | # via sphinxcontrib-bibtex 115 | etils==1.9.4 116 | # via mujoco 117 | exceptiongroup==1.2.2 ; python_full_version < '3.11' 118 | # via anyio 119 | # via hypothesis 120 | # via ipython 121 | # via pytest 122 | execnet==2.1.1 123 | # via pytest-xdist 124 | executing==2.1.0 125 | # via stack-data 126 | farama-notifications==0.0.4 127 | # via gymnasium 128 | fastjsonschema==2.20.0 129 | # via nbformat 130 | filelock==3.15.4 131 | # via ray 132 | flogging==0.0.22 133 | # via plangym 134 | fqdn==1.5.1 135 | # via jsonschema 136 | fragile-gym==1.21.1 137 | # via fragile-nes-py 138 | # via plangym 139 | fragile-gym-super-mario-bros==7.4.1 140 | # via plangym 141 | fragile-nes-py==10.0.1 142 | # via plangym 143 | frozenlist==1.4.1 144 | # via aiosignal 145 | # via ray 146 | fsspec==2024.6.1 147 | # via etils 148 | glfw==2.7.0 149 | # via dm-control 150 | # via mujoco 151 | greenlet==3.0.3 ; (python_full_version < '3.13' and platform_machine == 'AMD64') or (python_full_version < '3.13' and platform_machine == 'WIN32') or (python_full_version < '3.13' and platform_machine == 'aarch64') or (python_full_version < '3.13' and platform_machine == 'amd64') or (python_full_version < '3.13' and platform_machine == 'ppc64le') or (python_full_version < '3.13' and platform_machine == 'win32') or (python_full_version < '3.13' and platform_machine == 'x86_64') 152 | # via sqlalchemy 153 | gymnasium==0.29.1 154 | # via plangym 155 | # via shimmy 156 | # via stable-retro 157 | h11==0.14.0 158 | # via httpcore 159 | httpcore==1.0.5 160 | # via httpx 161 | httpx==0.27.2 162 | # via jupyterlab 163 | hypothesis==6.111.2 164 | # via plangym 165 | idna==3.8 166 | # via 
anyio 167 | # via httpx 168 | # via jsonschema 169 | # via requests 170 | imageio==2.35.1 171 | # via plangym 172 | imagesize==1.4.1 173 | # via sphinx 174 | importlib-metadata==8.4.0 175 | # via jupyter-cache 176 | # via myst-nb 177 | importlib-resources==6.4.4 178 | # via ale-py 179 | # via etils 180 | iniconfig==2.0.0 181 | # via pytest 182 | ipykernel==6.29.5 183 | # via jupyterlab 184 | # via myst-nb 185 | ipython==8.27.0 186 | # via fragile-gym 187 | # via ipykernel 188 | # via myst-nb 189 | isoduration==20.11.0 190 | # via jsonschema 191 | jedi==0.19.1 192 | # via ipython 193 | jinja2==3.1.4 194 | # via jupyter-server 195 | # via jupyterlab 196 | # via jupyterlab-server 197 | # via myst-parser 198 | # via nbconvert 199 | # via sphinx 200 | # via sphinx-autoapi 201 | json5==0.9.25 202 | # via jupyterlab-server 203 | jsonpointer==3.0.0 204 | # via jsonschema 205 | jsonschema==4.23.0 206 | # via jupyter-events 207 | # via jupyterlab-server 208 | # via nbformat 209 | # via ray 210 | jsonschema-specifications==2023.12.1 211 | # via jsonschema 212 | jupyter-cache==1.0.0 213 | # via myst-nb 214 | # via plangym 215 | jupyter-client==8.6.2 216 | # via ipykernel 217 | # via jupyter-server 218 | # via nbclient 219 | jupyter-core==5.7.2 220 | # via ipykernel 221 | # via jupyter-client 222 | # via jupyter-server 223 | # via jupyterlab 224 | # via nbclient 225 | # via nbconvert 226 | # via nbformat 227 | jupyter-events==0.10.0 228 | # via jupyter-server 229 | jupyter-lsp==2.2.5 230 | # via jupyterlab 231 | jupyter-server==2.14.2 232 | # via jupyter-lsp 233 | # via jupyterlab 234 | # via jupyterlab-server 235 | # via notebook-shim 236 | jupyter-server-terminals==0.5.3 237 | # via jupyter-server 238 | jupyterlab==4.2.5 239 | # via plangym 240 | jupyterlab-pygments==0.3.0 241 | # via nbconvert 242 | jupyterlab-server==2.27.3 243 | # via jupyterlab 244 | labmaze==1.0.6 245 | # via dm-control 246 | latexcodec==3.0.0 247 | # via pybtex 248 | linkify-it-py==2.0.3 249 | # via 
plangym 250 | lxml==5.3.0 251 | # via dm-control 252 | markdown-it-py==3.0.0 253 | # via mdit-py-plugins 254 | # via myst-parser 255 | markupsafe==2.1.5 256 | # via jinja2 257 | # via nbconvert 258 | matplotlib-inline==0.1.7 259 | # via ipykernel 260 | # via ipython 261 | mdit-py-plugins==0.4.1 262 | # via myst-parser 263 | mdurl==0.1.2 264 | # via markdown-it-py 265 | mistune==3.0.2 266 | # via nbconvert 267 | msgpack==1.0.8 268 | # via ray 269 | mujoco==3.2.2 270 | # via dm-control 271 | # via plangym 272 | myst-nb==1.1.1 273 | # via plangym 274 | myst-parser==4.0.0 275 | # via myst-nb 276 | # via plangym 277 | nbclient==0.10.0 278 | # via jupyter-cache 279 | # via myst-nb 280 | # via nbconvert 281 | nbconvert==7.16.4 282 | # via jupyter-server 283 | nbformat==5.10.4 284 | # via jupyter-cache 285 | # via jupyter-server 286 | # via myst-nb 287 | # via nbclient 288 | # via nbconvert 289 | nest-asyncio==1.6.0 290 | # via ipykernel 291 | notebook-shim==0.2.4 292 | # via jupyterlab 293 | numpy==2.1.1 294 | # via ale-py 295 | # via dm-control 296 | # via dm-env 297 | # via fragile-gym 298 | # via fragile-nes-py 299 | # via gymnasium 300 | # via imageio 301 | # via labmaze 302 | # via mujoco 303 | # via opencv-python 304 | # via plangym 305 | # via scipy 306 | # via shimmy 307 | opencv-python==4.10.0.84 308 | # via plangym 309 | overrides==7.7.0 310 | # via jupyter-server 311 | packaging==24.1 312 | # via ipykernel 313 | # via jupyter-server 314 | # via jupyterlab 315 | # via jupyterlab-server 316 | # via nbconvert 317 | # via pydata-sphinx-theme 318 | # via pytest 319 | # via pytest-rerunfailures 320 | # via ray 321 | # via sphinx 322 | pandocfilters==1.5.1 323 | # via nbconvert 324 | parso==0.8.4 325 | # via jedi 326 | pexpect==4.9.0 ; sys_platform != 'emscripten' and sys_platform != 'win32' 327 | # via ipython 328 | pillow==10.4.0 329 | # via imageio 330 | # via plangym 331 | platformdirs==4.2.2 332 | # via jupyter-core 333 | pluggy==1.5.0 334 | # via pytest 335 | 
prometheus-client==0.20.0 336 | # via jupyter-server 337 | prompt-toolkit==3.0.47 338 | # via ipython 339 | protobuf==5.28.0 340 | # via dm-control 341 | # via ray 342 | psutil==6.0.0 343 | # via ipykernel 344 | # via plangym 345 | ptyprocess==0.7.0 ; os_name != 'nt' or (sys_platform != 'emscripten' and sys_platform != 'win32') 346 | # via pexpect 347 | # via terminado 348 | pure-eval==0.2.3 349 | # via stack-data 350 | pybtex==0.24.0 351 | # via pybtex-docutils 352 | # via sphinxcontrib-bibtex 353 | pybtex-docutils==1.0.3 354 | # via sphinxcontrib-bibtex 355 | pycparser==2.22 356 | # via cffi 357 | pydata-sphinx-theme==0.15.4 358 | # via plangym 359 | # via sphinx-book-theme 360 | pygame==2.6.0 361 | # via gymnasium 362 | # via plangym 363 | pyglet==1.5.11 364 | # via fragile-nes-py 365 | # via plangym 366 | # via stable-retro 367 | pygments==2.18.0 368 | # via accessible-pygments 369 | # via ipython 370 | # via nbconvert 371 | # via pydata-sphinx-theme 372 | # via sphinx 373 | pyopengl==3.1.7 374 | # via dm-control 375 | # via mujoco 376 | pyparsing==3.1.4 377 | # via dm-control 378 | pytest==8.3.2 379 | # via plangym 380 | # via pytest-cov 381 | # via pytest-rerunfailures 382 | # via pytest-xdist 383 | pytest-cov==5.0.0 384 | # via plangym 385 | pytest-rerunfailures==14.0 386 | # via plangym 387 | pytest-xdist==3.6.1 388 | # via plangym 389 | python-dateutil==2.9.0.post0 390 | # via arrow 391 | # via jupyter-client 392 | python-json-logger==2.0.7 393 | # via jupyter-events 394 | pyvirtualdisplay==3.0 395 | # via plangym 396 | pywin32==306 ; platform_python_implementation != 'PyPy' and sys_platform == 'win32' 397 | # via jupyter-core 398 | pywinpty==2.0.13 ; os_name == 'nt' 399 | # via jupyter-server 400 | # via jupyter-server-terminals 401 | # via terminado 402 | pyyaml==6.0.2 403 | # via jupyter-cache 404 | # via jupyter-events 405 | # via myst-nb 406 | # via myst-parser 407 | # via pybtex 408 | # via ray 409 | # via sphinx-autoapi 410 | pyzmq==26.2.0 411 | # 
via ipykernel 412 | # via jupyter-client 413 | # via jupyter-server 414 | ray==2.35.0 415 | # via plangym 416 | referencing==0.35.1 417 | # via jsonschema 418 | # via jsonschema-specifications 419 | # via jupyter-events 420 | requests==2.32.3 421 | # via autorom 422 | # via autorom-accept-rom-license 423 | # via dm-control 424 | # via jupyterlab-server 425 | # via ray 426 | # via sphinx 427 | rfc3339-validator==0.1.4 428 | # via jsonschema 429 | # via jupyter-events 430 | rfc3986-validator==0.1.1 431 | # via jsonschema 432 | # via jupyter-events 433 | rpds-py==0.20.0 434 | # via jsonschema 435 | # via referencing 436 | ruyaml==0.91.0 437 | # via plangym 438 | scipy==1.14.1 439 | # via dm-control 440 | send2trash==1.8.3 441 | # via jupyter-server 442 | setuptools==74.1.1 443 | # via dm-control 444 | # via jupyterlab 445 | # via labmaze 446 | # via ruyaml 447 | # via sphinx-togglebutton 448 | shimmy==0.2.1 449 | # via gymnasium 450 | six==1.16.0 451 | # via asttokens 452 | # via bleach 453 | # via pybtex 454 | # via python-dateutil 455 | # via rfc3339-validator 456 | sniffio==1.3.1 457 | # via anyio 458 | # via httpx 459 | snowballstemmer==2.2.0 460 | # via sphinx 461 | sortedcontainers==2.4.0 462 | # via hypothesis 463 | soupsieve==2.6 464 | # via beautifulsoup4 465 | sphinx==8.0.2 466 | # via myst-nb 467 | # via myst-parser 468 | # via plangym 469 | # via pydata-sphinx-theme 470 | # via sphinx-autoapi 471 | # via sphinx-book-theme 472 | # via sphinx-copybutton 473 | # via sphinx-rtd-theme 474 | # via sphinx-togglebutton 475 | # via sphinxcontrib-bibtex 476 | # via sphinxext-opengraph 477 | sphinx-autoapi==3.3.1 478 | # via plangym 479 | sphinx-autodoc2==0.5.0 480 | # via plangym 481 | sphinx-book-theme==1.1.3 482 | # via plangym 483 | sphinx-copybutton==0.5.2 484 | # via plangym 485 | sphinx-rtd-theme==0.5.1 486 | # via plangym 487 | sphinx-togglebutton==0.3.2 488 | # via plangym 489 | sphinxcontrib-applehelp==2.0.0 490 | # via sphinx 491 | 
sphinxcontrib-bibtex==2.6.2 492 | # via plangym 493 | sphinxcontrib-devhelp==2.0.0 494 | # via sphinx 495 | sphinxcontrib-htmlhelp==2.1.0 496 | # via sphinx 497 | sphinxcontrib-jsmath==1.0.1 498 | # via sphinx 499 | sphinxcontrib-mermaid==0.9.2 500 | # via plangym 501 | sphinxcontrib-qthelp==2.0.0 502 | # via sphinx 503 | sphinxcontrib-serializinghtml==2.0.0 504 | # via sphinx 505 | sphinxext-opengraph==0.9.1 506 | # via plangym 507 | sqlalchemy==2.0.33 508 | # via jupyter-cache 509 | stable-retro==0.9.1 ; sys_platform == 'darwin' 510 | # via plangym 511 | stable-retro==0.9.2 ; sys_platform != 'darwin' 512 | # via plangym 513 | stack-data==0.6.3 514 | # via ipython 515 | tabulate==0.9.0 516 | # via jupyter-cache 517 | terminado==0.18.1 518 | # via jupyter-server 519 | # via jupyter-server-terminals 520 | tinycss2==1.3.0 521 | # via nbconvert 522 | tomli==2.0.1 523 | # via coverage 524 | # via jupyterlab 525 | # via plangym 526 | # via pytest 527 | # via sphinx 528 | # via sphinx-autodoc2 529 | tornado==6.4.1 530 | # via ipykernel 531 | # via jupyter-client 532 | # via jupyter-server 533 | # via jupyterlab 534 | # via terminado 535 | tqdm==4.66.5 536 | # via autorom 537 | # via dm-control 538 | # via fragile-nes-py 539 | traitlets==5.14.3 540 | # via comm 541 | # via ipykernel 542 | # via ipython 543 | # via jupyter-client 544 | # via jupyter-core 545 | # via jupyter-events 546 | # via jupyter-server 547 | # via jupyterlab 548 | # via matplotlib-inline 549 | # via nbclient 550 | # via nbconvert 551 | # via nbformat 552 | types-python-dateutil==2.9.0.20240821 553 | # via arrow 554 | typing-extensions==4.12.2 555 | # via ale-py 556 | # via anyio 557 | # via astroid 558 | # via async-lru 559 | # via etils 560 | # via gymnasium 561 | # via ipython 562 | # via myst-nb 563 | # via pydata-sphinx-theme 564 | # via sphinx-autodoc2 565 | # via sqlalchemy 566 | uc-micro-py==1.0.3 567 | # via linkify-it-py 568 | uri-template==1.3.0 569 | # via jsonschema 570 | urllib3==2.2.2 
571 | # via requests 572 | wcwidth==0.2.13 573 | # via prompt-toolkit 574 | webcolors==24.8.0 575 | # via jsonschema 576 | webencodings==0.5.1 577 | # via bleach 578 | # via tinycss2 579 | websocket-client==1.8.0 580 | # via jupyter-server 581 | wheel==0.44.0 582 | # via sphinx-togglebutton 583 | xxhash==3.5.0 584 | # via flogging 585 | zipp==3.20.1 586 | # via etils 587 | # via importlib-metadata 588 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FragileTech/plangym/417325eac9be9668f4065351ba22f3f6461f5a4f/src/__init__.py -------------------------------------------------------------------------------- /src/plangym/__init__.py: -------------------------------------------------------------------------------- 1 | """Various environments for plangym.""" 2 | 3 | import warnings 4 | 5 | 6 | warnings.filterwarnings( 7 | "ignore", 8 | message=( 9 | "Using or importing the ABCs from 'collections' instead of from 'collections.abc' " 10 | "is deprecated since Python 3.3,and in 3.9 it will stop working" 11 | ), 12 | ) 13 | warnings.filterwarnings( 14 | "ignore", 15 | message=( 16 | "the imp module is deprecated in favour of importlib; see the module's " 17 | "documentation for alternative uses" 18 | ), 19 | ) 20 | warnings.filterwarnings( 21 | "ignore", 22 | message=( 23 | "Using or importing the ABCs from 'collections' instead of from " 24 | "'collections.abc' is deprecated, and in 3.8 it will stop working" 25 | ), 26 | ) 27 | warnings.filterwarnings( 28 | "ignore", 29 | message=( 30 | "The set_clim function was deprecated in Matplotlib 3.1 " 31 | "and will be removed in 3.3. Use ScalarMappable.set_clim " 32 | "instead." 
33 | ), 34 | ) 35 | warnings.filterwarnings( 36 | "ignore", 37 | category=UserWarning, 38 | ) 39 | warnings.filterwarnings( 40 | "ignore", 41 | message="invalid escape sequence", 42 | ) 43 | 44 | warnings.filterwarnings("ignore", message="Gdk.Cursor.new is deprecated") 45 | warnings.filterwarnings( 46 | "ignore", 47 | message=( 48 | " `numpy.bool` is a deprecated alias for the builtin `bool`. " 49 | "To silence this warning, use `bool` by itself. Doing this will not modify any " 50 | "behavior and is safe. If you specifically wanted the numpy scalar type, " 51 | "use `numpy.bool_` here." 52 | ), 53 | ) 54 | warnings.filterwarnings( 55 | "ignore", 56 | message=" WARN: Box bound precision lowered by casting to float32", 57 | ) 58 | warnings.filterwarnings( 59 | "ignore", 60 | message=( 61 | " DeprecationWarning: The binary mode of fromstring is deprecated, " 62 | "as it behaves surprisingly on unicode inputs. Use frombuffer instead" 63 | ), 64 | ) 65 | warnings.filterwarnings( 66 | "ignore", 67 | message=( 68 | " DeprecationWarning: distutils Version classes are deprecated. " 69 | "Use packaging.version instead." 
70 | ), 71 | ) 72 | warnings.filterwarnings( 73 | "ignore", 74 | message=" WARNING:root:The use of `check_types` is deprecated and does not have any effect.", 75 | ) 76 | 77 | from plangym.core import PlanEnv 78 | from plangym.registry import make 79 | from plangym.version import __version__ 80 | -------------------------------------------------------------------------------- /src/plangym/control/__init__.py: -------------------------------------------------------------------------------- 1 | """Module that contains environments representing control tasks.""" 2 | 3 | from plangym.control.balloon import BalloonEnv 4 | from plangym.control.box_2d import Box2DEnv 5 | from plangym.control.classic_control import ClassicControl 6 | from plangym.control.dm_control import DMControlEnv 7 | from plangym.control.lunar_lander import LunarLander 8 | -------------------------------------------------------------------------------- /src/plangym/control/balloon.py: -------------------------------------------------------------------------------- 1 | """Implement the ``plangym`` API for the Balloon Learning Environment.""" 2 | 3 | from typing import Any 4 | 5 | import numpy 6 | 7 | 8 | try: 9 | import balloon_learning_environment.env.balloon_env # noqa: F401 10 | from balloon_learning_environment.env.rendering.matplotlib_renderer import MatplotlibRenderer 11 | except ImportError: 12 | 13 | def MatplotlibRenderer(): # noqa: D103, N802 14 | return None 15 | 16 | 17 | from plangym.core import PlangymEnv 18 | 19 | 20 | class BalloonEnv(PlangymEnv): 21 | """Balloon Learning Environment. 22 | 23 | Implements the 'BalloonLearningEnvironment-v0' released by Google in the \ 24 | balloon_learning_environment. 25 | 26 | For more information about the environment, please refer to \ 27 | https://github.com/google/balloon-learning-environment. 
28 | """ 29 | 30 | AVAILABLE_RENDER_MODES = {"human", "rgb_array", "tensorboard", None} 31 | AVAILABLE_OBS_TYPES = {"coords", "rgb", "grayscale"} 32 | STATE_IS_ARRAY = False 33 | 34 | def __init__( 35 | self, 36 | name: str = "BalloonLearningEnvironment-v0", 37 | renderer=None, 38 | array_state: bool = True, 39 | **kwargs, 40 | ): 41 | """Initialize a :class:`BalloonEnv`. 42 | 43 | Args: 44 | name: Name of the environment. Follows standard gym syntax conventions. 45 | renderer: MatplotlibRenderer object (or any renderer object) to plot 46 | the ``balloons`` environment. For more information, see the 47 | official documentation. 48 | array_state: boolean value. If True, transform the state object to 49 | a ``numpy.array``. 50 | kwargs: Additional arguments to be passed to the ``gym.make`` function. 51 | 52 | """ 53 | renderer = renderer or MatplotlibRenderer() 54 | self.STATE_IS_ARRAY = array_state 55 | super().__init__(name=name, renderer=renderer, **kwargs) 56 | 57 | def get_state(self) -> Any: 58 | """Get the state of the environment.""" 59 | state = self.gym_env.unwrapped.get_simulator_state() 60 | if self.STATE_IS_ARRAY: 61 | state = numpy.array((state, None), dtype=object) 62 | return state 63 | 64 | def set_state(self, state: Any) -> None: 65 | """Set the state of the environment.""" 66 | if self.STATE_IS_ARRAY: 67 | state = state[0] 68 | return self.gym_env.unwrapped.arena.set_simulator_state(state) 69 | 70 | def seed(self, seed: int | None = None): # noqa: ARG002 71 | """Ignore seeding until next release.""" 72 | return 73 | -------------------------------------------------------------------------------- /src/plangym/control/box_2d.py: -------------------------------------------------------------------------------- 1 | """Implement the ``plangym`` API for Box2D environments.""" 2 | 3 | import copy 4 | 5 | import numpy 6 | 7 | from plangym.core import PlangymEnv 8 | 9 | 10 | class Box2DState: 11 | """Extract state information from Box2D environments. 
12 | 13 | This class implements basic functionalities to get the necessary 14 | elements to construct a Box2D state. 15 | """ 16 | 17 | @staticmethod 18 | def get_body_attributes(body) -> dict: 19 | """Return a dictionary containing the attributes of a given body. 20 | 21 | Given a ``Env.world.body`` element, this method constructs a dictionary 22 | whose entries describe all body attributes. 23 | """ 24 | base = { 25 | "mass": body.mass, 26 | "inertia": body.inertia, 27 | "localCenter": body.localCenter, 28 | } 29 | state_info = { 30 | "type": body.type, 31 | "bullet": body.bullet, 32 | "awake": body.awake, 33 | "sleepingAllowed": body.sleepingAllowed, 34 | "active": body.active, 35 | "fixedRotation": body.fixedRotation, 36 | } 37 | kinematics = {"transform": body.transform, "position": body.position, "angle": body.angle} 38 | other = { 39 | "worldCenter": body.worldCenter, 40 | "localCenter": body.localCenter, 41 | "linearVelocity": body.linearVelocity, 42 | "angularVelocity": body.angularVelocity, 43 | } 44 | base.update(kinematics) 45 | base.update(state_info) 46 | base.update(other) 47 | return base 48 | 49 | @staticmethod 50 | def serialize_body_attribute(value): 51 | """Copy one body attribute.""" 52 | from Box2D.Box2D import b2Transform, b2Vec2 # noqa: PLC0415 53 | 54 | if isinstance(value, b2Vec2): 55 | return (*value.copy(),) 56 | if isinstance(value, b2Transform): 57 | return { 58 | "angle": float(value.angle), 59 | "position": (*value.position.copy(),), 60 | } 61 | return copy.copy(value) 62 | 63 | @classmethod 64 | def serialize_body_state(cls, state_dict): 65 | """Serialize the state of the target body data. 66 | 67 | This method takes as argument the result given by the method 68 | ``self.get_body_attributes``, the latter consisting in a dictionary 69 | containing all attribute elements of a body. The method returns 70 | a dictionary whose values are the serialized attributes of the body. 
71 | """ 72 | return {k: cls.serialize_body_attribute(v) for k, v in state_dict.items()} 73 | 74 | @staticmethod 75 | def set_value_to_body(body, name, value): 76 | """Set the target value to a body attribute.""" 77 | from Box2D.Box2D import b2Transform, b2Vec2 # noqa: PLC0415 78 | 79 | body_object = getattr(body, name) 80 | if isinstance(body_object, b2Vec2): 81 | body_object.Set(*value) 82 | elif isinstance(body_object, b2Transform): 83 | body_object.angle = value["angle"] 84 | body_object.position.Set(*value["position"]) 85 | else: 86 | setattr(body, name, value) 87 | 88 | @classmethod 89 | def set_body_state(cls, body, state): 90 | """Set the state to the target body. 91 | 92 | The method defines the corresponding body attribute to the value 93 | selected by the user. 94 | """ 95 | state = state[0] if isinstance(state, numpy.ndarray) else state 96 | for k, v in state.items(): 97 | cls.set_value_to_body(body, k, v) 98 | return body 99 | 100 | @classmethod 101 | def serialize_body(cls, body): 102 | """Serialize the data of the target ``Env.world.body`` instance.""" 103 | data: dict = cls.get_body_attributes(body) 104 | return cls.serialize_body_state(data) 105 | 106 | @classmethod 107 | def serialize_world_state(cls, world): 108 | """Serialize the state of all the bodies in world. 109 | 110 | The method serializes all body elements contained within the 111 | given ``Env.world`` object. 112 | """ 113 | return [cls.serialize_body(b) for b in world.bodies] 114 | 115 | @classmethod 116 | def set_world_state(cls, world, state): 117 | """Set the state of the world environment to the provided state. 118 | 119 | The method states the current environment by defining its world 120 | bodies' attributes. 
121 | """ 122 | for body, state_ in zip(world.bodies, state): 123 | cls.set_body_state(body, state_) 124 | 125 | @classmethod 126 | def get_env_state(cls, env): 127 | """Get the serialized state of the target environment.""" 128 | return cls.serialize_world_state(env.unwrapped.world) 129 | 130 | @classmethod 131 | def set_env_state(cls, env, state): 132 | """Set the serialized state to the target environment.""" 133 | return cls.set_world_state(env.unwrapped.world, state) 134 | 135 | 136 | class Box2DEnv(PlangymEnv): 137 | """Common interface for working with Box2D environments released by `gym`.""" 138 | 139 | def get_state(self) -> numpy.ndarray: 140 | """Recover the internal state of the simulation. 141 | 142 | A state must completely describe the Environment at a given moment. 143 | """ 144 | state = Box2DState.get_env_state(self.gym_env) 145 | return numpy.array((state, None), dtype=object) 146 | 147 | def set_state(self, state: numpy.ndarray) -> None: 148 | """Set the internal state of the simulation. 149 | 150 | Args: 151 | state: Target state to be set in the environment. 152 | 153 | Returns: 154 | None 155 | 156 | """ 157 | Box2DState.set_env_state(self.gym_env, state[0]) 158 | -------------------------------------------------------------------------------- /src/plangym/control/classic_control.py: -------------------------------------------------------------------------------- 1 | """Implement the ``plangym`` API for ``gym`` classic control environments.""" 2 | 3 | import copy 4 | 5 | import numpy 6 | 7 | from plangym.core import PlangymEnv 8 | 9 | 10 | class ClassicControl(PlangymEnv): 11 | """Environment for OpenAI gym classic control environments.""" 12 | 13 | def get_state(self) -> numpy.ndarray: 14 | """Recover the internal state of the environment.""" 15 | return numpy.array(copy.copy(self.gym_env.unwrapped.state)) 16 | 17 | def set_state(self, state: numpy.ndarray): 18 | """Set the internal state of the environment.
19 | 20 | Args: 21 | state: Target state to be set in the environment. 22 | 23 | Returns: 24 | The ``state`` argument, returned unchanged. 25 | 26 | """ 27 | self.gym_env.unwrapped.state = copy.copy(tuple(state.tolist())) 28 | return state 29 | -------------------------------------------------------------------------------- /src/plangym/control/dm_control.py: -------------------------------------------------------------------------------- 1 | """Implement the ``plangym`` API for ``dm_control`` environments.""" 2 | 3 | from typing import Iterable 4 | import time 5 | import warnings 6 | 7 | from gymnasium.spaces import Box 8 | import numpy 9 | 10 | from plangym.core import PlangymEnv, wrap_callable 11 | 12 | 13 | try: 14 | from gym.envs.classic_control import rendering 15 | 16 | novideo_mode = False 17 | except Exception: # pragma: no cover 18 | novideo_mode = True 19 | 20 | 21 | class DMControlEnv(PlangymEnv): 22 | """Wrap the `dm_control` library, allowing its implementation in planning problems. 23 | 24 | The dm_control library is DeepMind's software stack for physics-based 25 | simulation and Reinforcement Learning environments, using MuJoCo physics. 26 | 27 | For more information about the environment, please refer to 28 | https://github.com/deepmind/dm_control 29 | 30 | This class allows the implementation of dm_control in planning problems. 31 | It allows parallel and vectorized execution of the environments. 32 | """ 33 | 34 | DEFAULT_OBS_TYPE = "coords" 35 | 36 | def __init__( 37 | self, 38 | name: str = "cartpole-balance", 39 | frameskip: int = 1, 40 | episodic_life: bool = False, 41 | autoreset: bool = True, 42 | wrappers: Iterable[wrap_callable] | None = None, 43 | delay_setup: bool = False, 44 | visualize_reward: bool = True, 45 | domain_name=None, 46 | task_name=None, 47 | render_mode="rgb_array", 48 | obs_type: str | None = None, 49 | remove_time_limit=None, # noqa: ARG002 50 | return_image: bool = False, 51 | ): 52 | """Initialize a :class:`DMControlEnv`.
53 | 54 | Args: 55 | name: Name of the task. Provide the task to be solved as 56 | `domain_name-task_name`. For example 'cartpole-balance'. 57 | frameskip: Set a deterministic frameskip to apply the same 58 | action N times. 59 | episodic_life: Send terminal signal after losing a life. 60 | autoreset: Restart environment when reaching a terminal state. 61 | wrappers: Wrappers that will be applied to the underlying OpenAI env. \ 62 | Every element of the iterable can be either a :class:`gym.Wrapper` \ 63 | or a tuple containing ``(gym.Wrapper, kwargs)``. 64 | delay_setup: If ``True``, do not initialize the ``gym.Environment`` \ 65 | and wait for ``setup`` to be called later. 66 | visualize_reward: Define the color of the agent, which depends 67 | on the reward on its last timestep. 68 | domain_name: Same as in dm_control.suite.load. 69 | task_name: Same as in dm_control.suite.load. 70 | render_mode: None|human|rgb_array. 71 | remove_time_limit: Ignored. 72 | obs_type: One of {"coords", "rgb", "grayscale"}. 73 | return_image: If ``True``, add a "rgb" key to the observation dict.
74 | 75 | """ 76 | self._visualize_reward = visualize_reward 77 | self.viewer = [] 78 | self._viewer = None 79 | name, self._domain_name, self._task_name = self._parse_names(name, domain_name, task_name) 80 | super().__init__( 81 | name=name, 82 | frameskip=frameskip, 83 | episodic_life=episodic_life, 84 | wrappers=wrappers, 85 | delay_setup=delay_setup, 86 | autoreset=autoreset, 87 | render_mode=render_mode, 88 | obs_type=obs_type, 89 | return_image=return_image, 90 | ) 91 | 92 | @property 93 | def physics(self): 94 | """Alias for gym_env.physics.""" 95 | return self.gym_env.physics 96 | 97 | @property 98 | def domain_name(self) -> str: 99 | """Return the name of the agent in the current simulation.""" 100 | return self._domain_name 101 | 102 | @property 103 | def task_name(self) -> str: 104 | """Return the name of the task in the current simulation.""" 105 | return self._task_name 106 | 107 | @staticmethod 108 | def _parse_names(name, domain_name, task_name): 109 | """Return the name, domain name, and task name of the project.""" 110 | if isinstance(name, str) and domain_name is None: 111 | domain_name = name if "-" not in name else name.split("-")[0] 112 | 113 | if isinstance(name, str) and "-" in name and task_name is None: 114 | task_name = task_name if "-" not in name else name.split("-")[1] 115 | if (not isinstance(name, str) or "-" not in name) and task_name is None: 116 | raise ValueError( 117 | f"Invalid combination: name {name}," 118 | f" domain_name {domain_name}, task_name {task_name}", 119 | ) 120 | name = f"{domain_name}-{task_name}" 121 | return name, domain_name, task_name 122 | 123 | def init_gym_env(self): 124 | """Initialize the environment instance (dm_control) that the current class is wrapping.""" 125 | from dm_control import suite # noqa: PLC0415 126 | 127 | env = suite.load( 128 | domain_name=self.domain_name, 129 | task_name=self.task_name, 130 | visualize_reward=self._visualize_reward, 131 | ) 132 | self.viewer = [] 133 | self._viewer = 
None if novideo_mode else rendering.SimpleImageViewer() 134 | return env 135 | 136 | def setup(self): 137 | """Initialize the target :class:`gym.Env` instance.""" 138 | with warnings.catch_warnings(): 139 | warnings.simplefilter("ignore") 140 | super().setup() 141 | 142 | def _init_action_space(self): 143 | """Define the action space of the environment. 144 | 145 | This method determines the spectrum of possible actions that the 146 | agent can perform. The action space consists in a grid representing 147 | the Cartesian product of the closed intervals defined by the user. 148 | """ 149 | self._action_space = Box( 150 | low=self.action_spec().minimum, 151 | high=self.action_spec().maximum, 152 | dtype=numpy.float32, 153 | ) 154 | 155 | def _init_obs_space_coords(self): 156 | """Define the observation space of the environment.""" 157 | obs, _info = self.reset(return_state=False) 158 | shape = obs.shape 159 | self._obs_space = Box(low=-numpy.inf, high=numpy.inf, shape=shape, dtype=numpy.float32) 160 | 161 | def action_spec(self): 162 | """Alias for the environment's ``action_spec``.""" 163 | return self.gym_env.action_spec() 164 | 165 | def get_image(self) -> numpy.ndarray: 166 | """Return a numpy array containing the rendered view of the environment. 167 | 168 | Square matrices are interpreted as a greyscale image. Three-dimensional arrays 169 | are interpreted as RGB images with channels (Height, Width, RGB). 170 | """ 171 | return self.gym_env.physics.render(camera_id=0) 172 | 173 | def render(self, mode=None): 174 | """Render the environment. 175 | 176 | Store all the RGB images rendered to be shown when the `show_game`\ 177 | function is called. 178 | 179 | Args: 180 | mode: `rgb_array` return an RGB image stored in a numpy array. `human` 181 | stores the rendered image in a viewer to be shown when `show_game` 182 | is called. 183 | 184 | Returns: 185 | numpy.ndarray when mode == `rgb_array`. 
True when mode == `human` 186 | 187 | """ 188 | curr_mode = self.render_mode 189 | mode_ = mode or curr_mode 190 | self._render_mode = mode_ 191 | img = self.get_image() 192 | self._render_mode = curr_mode 193 | if mode == "rgb_array": 194 | return img 195 | if mode == "human": 196 | self.viewer.append(img) 197 | return True 198 | 199 | def show_game(self, sleep: float = 0.05): 200 | """Render the collected RGB images. 201 | 202 | When 'human' option is selected as argument for the `render` method, 203 | it stores a collection of RGB images inside the ``self.viewer`` 204 | attribute. This method calls the latter to visualize the collected 205 | images. 206 | """ 207 | if self._viewer is None: 208 | self._viewer = rendering.SimpleImageViewer() 209 | for img in self.viewer: 210 | self._viewer.imshow(img) 211 | time.sleep(sleep) 212 | 213 | def get_coords_obs(self, obs, **kwargs) -> numpy.ndarray: # noqa: ARG002 214 | """Get the environment observation from a time_step object. 215 | 216 | Args: 217 | obs: Time step object returned after stepping the environment. 218 | **kwargs: Ignored 219 | 220 | Returns: 221 | Numpy array containing the environment observation. 222 | 223 | """ 224 | return self._time_step_to_obs(time_step=obs) 225 | 226 | def set_state(self, state: numpy.ndarray) -> None: 227 | """Set the state of the simulator to the target State. 228 | 229 | Args: 230 | state: numpy.ndarray containing the information about the state to be set. 231 | 232 | Returns: 233 | None 234 | 235 | """ 236 | with self.gym_env.physics.reset_context(): 237 | self.gym_env.physics.set_state(state) 238 | 239 | def get_state(self) -> numpy.ndarray: 240 | """Return the state of the environment. 241 | 242 | Return a tuple containing the three arrays that characterize the state\ 243 | of the system. 244 | 245 | Each tuple contains the position of the robot, its velocity 246 | and the control variables currently being applied. 
247 | 248 | Returns 249 | Tuple of numpy arrays containing all the information needed to describe 250 | the current state of the simulation. 251 | 252 | """ 253 | return self.gym_env.physics.get_state() 254 | 255 | def apply_action(self, action): 256 | """Transform the returned time_step object to a compatible gym tuple.""" 257 | info = {} 258 | time_step = self.gym_env.step(action) 259 | obs = time_step 260 | terminal = time_step.last() 261 | _reward = time_step.reward if time_step.reward is not None else 0.0 262 | reward = _reward + self._reward_step 263 | truncated = False 264 | return obs, reward, terminal, truncated, info 265 | 266 | @staticmethod 267 | def _time_step_to_obs(time_step) -> numpy.ndarray: 268 | """Stack observation values as a horizontal sequence. 269 | 270 | Concat observations in a single array, making easier calculating 271 | distances. 272 | """ 273 | return numpy.hstack( 274 | [numpy.array([time_step.observation[x]]).flatten() for x in time_step.observation], 275 | ) 276 | 277 | def close(self): 278 | """Tear down the environment and close rendering.""" 279 | try: 280 | super().close() 281 | if self._viewer is not None: 282 | self._viewer.close() 283 | except Exception: # noqa: S110 284 | pass 285 | -------------------------------------------------------------------------------- /src/plangym/registry.py: -------------------------------------------------------------------------------- 1 | """Functionality for instantiating the environment by passing the environment id.""" 2 | 3 | from plangym.environment_names import ATARI, BOX_2D, CLASSIC_CONTROL, DM_CONTROL, RETRO 4 | 5 | 6 | def get_planenv_class(name, domain_name, state): 7 | """Return the class corresponding to the environment name.""" 8 | # if name == "MinimalPacman-v0": 9 | # return MinimalPacman 10 | # elif name == "MinimalPong-v0": 11 | # return MinimalPong 12 | if name == "PlanMontezuma-v0": 13 | from plangym.videogames import MontezumaEnv 14 | 15 | return MontezumaEnv 16 | if 
state is not None or name in set(RETRO): 17 | from plangym.videogames import RetroEnv 18 | 19 | return RetroEnv 20 | if name in set(CLASSIC_CONTROL): 21 | from plangym.control import ClassicControl 22 | 23 | return ClassicControl 24 | if name in set(BOX_2D): 25 | if name == "FastLunarLander-v0": 26 | from plangym.control import LunarLander 27 | 28 | return LunarLander 29 | from plangym.control import Box2DEnv 30 | 31 | return Box2DEnv 32 | if name in ATARI: 33 | from plangym.videogames import AtariEnv 34 | 35 | return AtariEnv 36 | if domain_name is not None or any(x[0] in name for x in DM_CONTROL): 37 | from plangym.control import DMControlEnv 38 | 39 | return DMControlEnv 40 | if "SuperMarioBros" in name: 41 | from plangym.videogames import MarioEnv 42 | 43 | return MarioEnv 44 | if "BalloonLearningEnvironment-v0": 45 | from plangym.control import BalloonEnv 46 | 47 | return BalloonEnv 48 | raise ValueError(f"Environment {name} is not supported.") 49 | 50 | 51 | def get_environment_class( 52 | name: str | None = None, 53 | n_workers: int | None = None, 54 | ray: bool = False, 55 | domain_name: str | None = None, 56 | state: str | None = None, 57 | ): 58 | """Get the class and vectorized environment and PlangymEnv class from the make params.""" 59 | env_class = get_planenv_class(name, domain_name, state) 60 | if ray: 61 | from plangym.vectorization import RayEnv 62 | 63 | return RayEnv, env_class 64 | if n_workers is not None: 65 | from plangym.vectorization import ParallelEnv 66 | 67 | return ParallelEnv, env_class 68 | return None, env_class 69 | 70 | 71 | def make( 72 | name: str | None = None, 73 | n_workers: int | None = None, 74 | ray: bool = False, 75 | domain_name: str | None = None, 76 | state: str | None = None, 77 | **kwargs, 78 | ): 79 | """Create the appropriate PlangymEnv from the environment name and other parameters.""" 80 | parallel_class, env_class = get_environment_class( 81 | name=name, 82 | n_workers=n_workers, 83 | ray=ray, 84 | 
domain_name=domain_name, 85 | state=state, 86 | ) 87 | kwargs["name"] = name 88 | if state is not None: 89 | kwargs["state"] = state 90 | if domain_name is not None: 91 | kwargs["domain_name"] = domain_name 92 | if parallel_class is not None: 93 | return parallel_class(env_class=env_class, n_workers=n_workers, **kwargs) 94 | return env_class(**kwargs) 95 | -------------------------------------------------------------------------------- /src/plangym/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FragileTech/plangym/417325eac9be9668f4065351ba22f3f6461f5a4f/src/plangym/scripts/__init__.py -------------------------------------------------------------------------------- /src/plangym/scripts/import_retro_roms.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import zipfile 4 | import logging 5 | import flogging 6 | from pathlib import Path 7 | 8 | import retro.data 9 | 10 | flogging.setup() 11 | logger = logging.getLogger("import-roms") 12 | 13 | 14 | def _check_zipfile(f, process_f): 15 | with zipfile.ZipFile(f) as zf: 16 | for entry in zf.infolist(): 17 | _root, ext = os.path.splitext(entry.filename) # noqa: PTH122 18 | with zf.open(entry) as innerf: 19 | if ext == ".zip": 20 | _check_zipfile(innerf, process_f) 21 | else: 22 | process_f(entry.filename, innerf) 23 | 24 | 25 | def main(): 26 | """Import ROMs from a directory into the retro data directory.""" 27 | from retro.data import EMU_EXTENSIONS # noqa: PLC0415 28 | 29 | # This avoids a bug when loading the emu_extensions. 
30 | 31 | emu_extensions = { 32 | ".sfc": "Snes", 33 | ".md": "Genesis", 34 | ".sms": "Sms", 35 | ".gg": "GameGear", 36 | ".nes": "Nes", 37 | ".gba": "GbAdvance", 38 | ".gb": "GameBoy", 39 | ".gbc": "GbColor", 40 | ".a26": "Atari2600", 41 | ".pce": "PCEngine", 42 | } 43 | EMU_EXTENSIONS.update(emu_extensions) 44 | paths = sys.argv[1:] or [Path.cwd()] 45 | logger.info(f"Importing ROMs from: {paths}") 46 | logger.info("Fetching known hashes") 47 | known_hashes = retro.data.get_known_hashes() 48 | logger.info(f"Found {len(known_hashes)} known hashes") 49 | imported_games = 0 50 | 51 | def save_if_matches(filename, f): 52 | nonlocal imported_games 53 | try: 54 | data, hash = retro.data.groom_rom(filename, f) 55 | except (OSError, ValueError): 56 | logging.warning(f"Failed to process file: {filename}") 57 | return 58 | if hash in known_hashes: 59 | game, ext, curpath = known_hashes[hash] 60 | # print('Importing', game) 61 | rompath = os.path.join(curpath, game, f"rom{ext}") # noqa: PTH118 62 | # print("ROM PATH", rompath) 63 | with open(rompath, "wb") as file: # noqa: FURB103 64 | file.write(data) 65 | imported_games += 1 66 | logger.info(f"Imported game: {game}") 67 | logger.debug(f"to {rompath}") 68 | 69 | for path in paths: # noqa: PLR1702 70 | logger.info(f"Processing path: {path}") 71 | for root, dirs, files in os.walk(path): 72 | for filename in files: 73 | logger.debug(f"Processing file: {root}/{filename}") 74 | filepath = os.path.join(root, filename) # noqa: PTH118 75 | with open(filepath, "rb") as f: 76 | _root, ext = os.path.splitext(filename) # noqa: PTH122 77 | if ext == ".zip": 78 | try: 79 | _check_zipfile(f, save_if_matches) 80 | except (zipfile.BadZipFile, RuntimeError, OSError): 81 | logger.warning(f"Failed to process zip file: {filepath}") 82 | else: 83 | save_if_matches(filename, f) 84 | logger.info(f"Total imported games: {imported_games}") 85 | 86 | 87 | if __name__ == "__main__": 88 | sys.exit(main()) 89 | 
-------------------------------------------------------------------------------- /src/plangym/utils.py: -------------------------------------------------------------------------------- 1 | """Generic utilities for working with environments.""" 2 | 3 | import os 4 | 5 | import gymnasium as gym 6 | from gymnasium.spaces import Box 7 | from gymnasium.wrappers.time_limit import TimeLimit 8 | import numpy 9 | from pyvirtualdisplay import Display 10 | import cv2 11 | 12 | try: 13 | from PIL import Image 14 | 15 | USE_PIL = True 16 | except ImportError: # pragma: no cover 17 | USE_PIL = False 18 | 19 | 20 | def get_display(visible=False, size=(400, 400), **kwargs): 21 | """Start a virtual display.""" 22 | os.environ["PYVIRTUALDISPLAY_DISPLAYFD"] = "0" 23 | display = Display(visible=visible, size=size, **kwargs) 24 | display.start() 25 | return display 26 | 27 | 28 | def remove_time_limit_from_spec(spec): 29 | """Remove the maximum time limit of an environment spec.""" 30 | if hasattr(spec, "max_episode_steps"): 31 | spec._max_episode_steps = spec.max_episode_steps 32 | spec.max_episode_steps = 1e100 33 | if hasattr(spec, "max_episode_time"): 34 | spec._max_episode_time = spec.max_episode_time 35 | spec.max_episode_time = 1e100 36 | 37 | 38 | def remove_time_limit(gym_env: gym.Env) -> gym.Env: 39 | """Remove the maximum time limit of the provided environment.""" 40 | if hasattr(gym_env, "spec") and gym_env.spec is not None: 41 | remove_time_limit_from_spec(gym_env.spec) 42 | if not isinstance(gym_env, gym.Wrapper): 43 | return gym_env 44 | for _ in range(5): 45 | try: 46 | if isinstance(gym_env, TimeLimit): 47 | return gym_env.env 48 | if isinstance(gym_env.env, gym.Wrapper) and isinstance(gym_env.env, TimeLimit): 49 | gym_env.env = gym_env.env.env 50 | # This is an ugly hack to make sure that we can remove the TimeLimit even 51 | # if somebody is crazy enough to apply three other wrappers on top of the TimeLimit 52 | elif isinstance(gym_env.env.env, gym.Wrapper) and 
isinstance( 53 | gym_env.env.env, 54 | TimeLimit, 55 | ): # pragma: no cover 56 | gym_env.env.env = gym_env.env.env.env 57 | elif isinstance(gym_env.env.env.env, gym.Wrapper) and isinstance( 58 | gym_env.env.env.env, 59 | TimeLimit, 60 | ): # pragma: no cover 61 | gym_env.env.env.env = gym_env.env.env.env.env 62 | else: # pragma: no cover 63 | break 64 | except AttributeError: 65 | break 66 | return gym_env 67 | 68 | 69 | def process_frame_pil( 70 | frame: numpy.ndarray, 71 | width: int | None = None, 72 | height: int | None = None, 73 | mode: str = "RGB", 74 | ) -> numpy.ndarray: 75 | """Resize an RGB frame to a specified shape and mode. 76 | 77 | Use PIL to resize an RGB frame to a specified height and width \ 78 | or changing it to a different mode. 79 | 80 | Args: 81 | frame: Target numpy array representing the image that will be resized. 82 | width: Width of the resized image. 83 | height: Height of the resized image. 84 | mode: Passed to Image.convert. 85 | 86 | Returns: 87 | The resized frame that matches the provided width and height. 88 | 89 | """ 90 | mode = "L" if mode == "GRAY" else mode 91 | height = height or frame.shape[0] 92 | width = width or frame.shape[1] 93 | frame = Image.fromarray(frame) 94 | frame = frame.convert(mode).resize(size=(width, height)) 95 | return numpy.array(frame) 96 | 97 | 98 | def process_frame_opencv( 99 | frame: numpy.ndarray, 100 | width: int | None = None, 101 | height: int | None = None, 102 | mode: str = "RGB", 103 | ) -> numpy.ndarray: 104 | """Resize an RGB frame to a specified shape and mode. 105 | 106 | Use OpenCV to resize an RGB frame to a specified height and width \ 107 | or changing it to a different mode. 108 | 109 | Args: 110 | frame: Target numpy array representing the image that will be resized. 111 | width: Width of the resized image. 112 | height: Height of the resized image. 113 | mode: Passed to cv2.cvtColor. 114 | 115 | Returns: 116 | The resized frame that matches the provided width and height. 
117 | 118 | """ 119 | height = height or frame.shape[0] 120 | width = width or frame.shape[1] 121 | frame = cv2.resize(frame, (width, height)) 122 | if mode in {"GRAY", "L"}: 123 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) 124 | elif mode == "BGR": 125 | frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) 126 | return frame 127 | 128 | 129 | def process_frame( 130 | frame: numpy.ndarray, 131 | width: int | None = None, 132 | height: int | None = None, 133 | mode: str = "RGB", 134 | ) -> numpy.ndarray: 135 | """Resize an RGB frame to a specified shape and mode. 136 | 137 | Use either PIL or OpenCV to resize an RGB frame to a specified height and width \ 138 | or changing it to a different mode. 139 | 140 | Args: 141 | frame: Target numpy array representing the image that will be resized. 142 | width: Width of the resized image. 143 | height: Height of the resized image. 144 | mode: Passed to either Image.convert or cv2.cvtColor. 145 | 146 | Returns: 147 | The resized frame that matches the provided width and height. 148 | 149 | """ 150 | func = process_frame_pil if USE_PIL else process_frame_opencv # pragma: no cover 151 | return func(frame, width, height, mode) 152 | 153 | 154 | class GrayScaleObservation(gym.ObservationWrapper, gym.utils.RecordConstructorArgs): 155 | """Convert the image observation from RGB to gray scale. 
156 | 157 | Example: 158 | >>> import gymnasium as gym 159 | >>> from gymnasium.wrappers import GrayScaleObservation 160 | >>> env = gym.make("CarRacing-v2") 161 | >>> env.observation_space 162 | Box(0, 255, (96, 96, 3), uint8) 163 | >>> env = GrayScaleObservation(gym.make("CarRacing-v2")) 164 | >>> env.observation_space 165 | Box(0, 255, (96, 96), uint8) 166 | >>> env = GrayScaleObservation(gym.make("CarRacing-v2"), keep_dim=True) 167 | >>> env.observation_space 168 | Box(0, 255, (96, 96, 1), uint8) 169 | 170 | """ 171 | 172 | def __init__(self, env: gym.Env, keep_dim: bool = False): 173 | """Convert the image observation from RGB to gray scale. 174 | 175 | Args: 176 | env (Env): The environment to apply the wrapper 177 | keep_dim (bool): If `True`, a singleton dimension will be added, i.e. \ 178 | observations are of the shape AxBx1. Otherwise, they are of shape AxB. 179 | 180 | """ 181 | gym.utils.RecordConstructorArgs.__init__(self, keep_dim=keep_dim) 182 | gym.ObservationWrapper.__init__(self, env) 183 | 184 | self.keep_dim = keep_dim 185 | 186 | assert ( 187 | "Box" in self.observation_space.__class__.__name__ # works for both gym and gymnasium 188 | and len(self.observation_space.shape) == 3 # noqa: PLR2004 189 | and self.observation_space.shape[-1] == 3 # noqa: PLR2004 190 | ), f"Expected input to be of shape (..., 3), got {self.observation_space.shape}" 191 | 192 | obs_shape = self.observation_space.shape[:2] 193 | if self.keep_dim: 194 | self.observation_space = Box( 195 | low=0, high=255, shape=(obs_shape[0], obs_shape[1], 1), dtype=numpy.uint8 196 | ) 197 | else: 198 | self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=numpy.uint8) 199 | 200 | def observation(self, observation): 201 | """Convert the colour observation to greyscale. 
202 | 203 | Args: 204 | observation: Color observations 205 | 206 | Returns: 207 | Grayscale observations 208 | 209 | """ 210 | import cv2 # noqa: PLC0415 211 | 212 | observation = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY) 213 | if self.keep_dim: 214 | observation = numpy.expand_dims(observation, -1) 215 | return observation 216 | -------------------------------------------------------------------------------- /src/plangym/vectorization/__init__.py: -------------------------------------------------------------------------------- 1 | """Module that contains the code implementing vectorization for `PlangymEnv.step_batch`.""" 2 | 3 | from plangym.vectorization.parallel import ParallelEnv 4 | from plangym.vectorization.ray import RayEnv 5 | -------------------------------------------------------------------------------- /src/plangym/vectorization/env.py: -------------------------------------------------------------------------------- 1 | """Plangym API implementation.""" 2 | 3 | from abc import ABC 4 | from typing import Callable, Generator 5 | 6 | from gymnasium.spaces import Space 7 | import numpy 8 | 9 | from plangym.core import PlanEnv, PlangymEnv 10 | 11 | 12 | class VectorizedEnv(PlangymEnv, ABC): # noqa: PLR0904 13 | """Base class that defines the API for working with vectorized environments. 14 | 15 | A vectorized environment allows to step several copies of the environment in parallel 16 | when calling ``step_batch``. 17 | 18 | It creates a local copy of the environment that is the target of all the other 19 | methods of :class:`PlanEnv`. In practise, a :class:`VectorizedEnv` 20 | acts as a wrapper of an environment initialized with the provided parameters when calling 21 | __init__. 22 | 23 | """ 24 | 25 | def __init__( 26 | self, 27 | env_class, 28 | name: str, 29 | frameskip: int = 1, 30 | autoreset: bool = True, 31 | delay_setup: bool = False, 32 | n_workers: int = 8, 33 | **kwargs, 34 | ): 35 | """Initialize a :class:`VectorizedEnv`. 
36 | 37 | Args: 38 | env_class: Class of the environment to be wrapped. 39 | name: Name of the environment. 40 | frameskip: Number of times ``step`` will be called with the same action. 41 | autoreset: Ignored. Always set to True. Automatically reset the environment 42 | when the OpenAI environment returns ``end = True``. 43 | delay_setup: If ``True`` do not initialize the :class:`gym.Environment` 44 | and wait for ``setup`` to be called later. 45 | n_workers: Number of workers that will be used to step the env. 46 | **kwargs: Additional keyword arguments passed to env_class.__init__. 47 | 48 | """ 49 | self._n_workers = n_workers 50 | self._env_class = env_class 51 | self._env_kwargs = kwargs 52 | self._plangym_env: PlangymEnv | PlanEnv | None = None 53 | self.SINGLETON = env_class.SINGLETON if hasattr(env_class, "SINGLETON") else False 54 | self.STATE_IS_ARRAY = ( 55 | env_class.STATE_IS_ARRAY if hasattr(env_class, "STATE_IS_ARRAY") else True 56 | ) 57 | super().__init__( 58 | name=name, 59 | frameskip=frameskip, 60 | autoreset=autoreset, 61 | delay_setup=delay_setup, 62 | ) 63 | 64 | @property 65 | def n_workers(self) -> int: 66 | """Return the number of parallel processes that run ``step_batch`` in parallel.""" 67 | return self._n_workers 68 | 69 | @property 70 | def plan_env(self) -> PlanEnv: 71 | """Environment that is wrapped by the current instance.""" 72 | return self._plangym_env 73 | 74 | @property 75 | def obs_shape(self) -> tuple[int]: 76 | """Tuple containing the shape of the observations returned by the Environment.""" 77 | return self.plan_env.obs_shape 78 | 79 | @property 80 | def action_shape(self) -> tuple[int]: 81 | """Tuple containing the shape of the actions applied to the Environment.""" 82 | return self.plan_env.action_shape 83 | 84 | @property 85 | def action_space(self) -> Space: 86 | """Return the action_space of the environment.""" 87 | return self.plan_env.action_space 88 | 89 | @property 90 | def observation_space(self) -> Space: 91 | 
"""Return the observation_space of the environment.""" 92 | return self.plan_env.observation_space 93 | 94 | @property 95 | def gym_env(self): 96 | """Return the instance of the environment that is being wrapped by plangym.""" 97 | try: 98 | return self.plan_env.gym_env 99 | except AttributeError: 100 | return None 101 | 102 | def __getattr__(self, item): 103 | """Forward attributes to the wrapped environment.""" 104 | return getattr(self.plan_env, item) 105 | 106 | @staticmethod 107 | def split_similar_chunks( 108 | vector: list | numpy.ndarray, 109 | n_chunks: int, 110 | ) -> Generator[list | numpy.ndarray, None, None]: 111 | """Split an indexable object into similar chunks. 112 | 113 | Args: 114 | vector: Target indexable object to be split. 115 | n_chunks: Number of similar chunks. 116 | 117 | Returns: 118 | Generator that returns the chunks created after splitting the target object. 119 | 120 | """ 121 | chunk_size = int(numpy.ceil(len(vector) / n_chunks)) 122 | for i in range(0, len(vector), chunk_size): 123 | yield vector[i : i + chunk_size] 124 | 125 | @classmethod 126 | def batch_step_data(cls, actions, states, dt, batch_size): 127 | """Make batches of step data to distribute across workers.""" 128 | no_states = states is None or states[0] is None 129 | states = [None] * len(actions) if no_states else states 130 | dt = dt if isinstance(dt, numpy.ndarray) else numpy.ones(len(states), dtype=int) * dt 131 | states_chunks = cls.split_similar_chunks(states, n_chunks=batch_size) 132 | actions_chunks = cls.split_similar_chunks(actions, n_chunks=batch_size) 133 | dt_chunks = cls.split_similar_chunks(dt, n_chunks=batch_size) 134 | return states_chunks, actions_chunks, dt_chunks 135 | 136 | @staticmethod 137 | def unpack_transitions(results: list, return_states: bool): 138 | """Aggregate the results of stepping across diferent workers.""" 139 | _states, observs, rewards, terminals, truncateds, infos = [], [], [], [], [], [] 140 | for result in results: 141 | if not 
return_states: 142 | obs, rew, ends, trunc, info = result 143 | else: 144 | _sts, obs, rew, ends, trunc, info = result 145 | _states += _sts 146 | 147 | observs += obs 148 | rewards += rew 149 | terminals += ends 150 | infos += info 151 | truncateds += trunc 152 | if not return_states: 153 | transitions = observs, rewards, terminals, truncateds, infos 154 | else: 155 | transitions = _states, observs, rewards, terminals, truncateds, infos 156 | return transitions 157 | 158 | def create_env_callable(self, **kwargs) -> Callable[..., PlanEnv]: 159 | """Return a callable that initializes the environment that is being vectorized.""" 160 | 161 | def create_env_callable(env_class, **env_kwargs): 162 | def _inner(**inner_kwargs): 163 | env_kwargs.update(inner_kwargs) 164 | return env_class(**env_kwargs) 165 | 166 | return _inner 167 | 168 | sub_env_kwargs = dict(self._env_kwargs) 169 | sub_env_kwargs["render_mode"] = self.render_mode if self.render_mode != "human" else None 170 | callable_kwargs = dict( 171 | env_class=self._env_class, 172 | name=self.name, 173 | frameskip=self.frameskip, 174 | delay_setup=self._env_class.SINGLETON, 175 | **sub_env_kwargs, 176 | ) 177 | callable_kwargs.update(kwargs) 178 | return create_env_callable(**callable_kwargs) 179 | 180 | def setup(self) -> None: 181 | """Initialize the target environment with the parameters provided at __init__.""" 182 | self._plangym_env: PlangymEnv = self.create_env_callable()() 183 | self._plangym_env.setup() 184 | 185 | def step( 186 | self, 187 | action: numpy.ndarray, 188 | state: numpy.ndarray = None, 189 | dt: int = 1, 190 | return_state: bool | None = None, 191 | ): 192 | """Step the environment applying a given action from an arbitrary state. 193 | 194 | If is not provided the signature matches the `step` method from OpenAI gym. 195 | 196 | Args: 197 | action: Array containing the action to be applied. 198 | state: State to be set before stepping the environment. 
199 | dt: Consecutive number of times to apply the given action. 200 | return_state: Whether to return the state in the returned tuple. \ 201 | If None, `step` will return the state if `state` was passed as a parameter. 202 | 203 | Returns: 204 | if states is `None` returns `(observs, rewards, ends, infos)` else 205 | `(new_states, observs, rewards, ends, infos)`. 206 | 207 | """ 208 | return self.plan_env.step(action=action, state=state, dt=dt, return_state=return_state) 209 | 210 | def reset(self, return_state: bool = True): 211 | """Reset the environment. 212 | 213 | Reset the environment and returns the first observation, or the first \ 214 | (state, obs, info) tuple. 215 | 216 | Args: 217 | return_state: If true return a also the initial state of the env. 218 | 219 | Returns: 220 | Observation of the environment if `return_state` is False. Otherwise, 221 | return (state, obs) after reset. 222 | 223 | """ 224 | if self.plan_env is None and self.delay_setup: 225 | self.setup() 226 | state, obs, info = self.plan_env.reset(return_state=True) 227 | self.sync_states(state) 228 | return (state, obs, info) if return_state else (obs, info) 229 | 230 | def get_state(self): 231 | """Recover the internal state of the simulation. 232 | 233 | A state completely describes the Environment at a given moment. 234 | 235 | Returns 236 | State of the simulation. 237 | 238 | """ 239 | return self.plan_env.get_state() 240 | 241 | def set_state(self, state): 242 | """Set the internal state of the simulation. 243 | 244 | Args: 245 | state: Target state to be set in the environment. 246 | 247 | """ 248 | self.plan_env.set_state(state) 249 | self.sync_states(state) 250 | 251 | def render(self, mode="human"): # noqa: ARG002 252 | """Render the environment using OpenGL. This wraps the OpenAI render method.""" 253 | return self.plan_env.render() 254 | 255 | def get_image(self) -> numpy.ndarray: 256 | """Return a numpy array containing the rendered view of the environment. 
257 | 258 | Square matrices are interpreted as a greyscale image. Three-dimensional arrays 259 | are interpreted as RGB images with channels (Height, Width, RGB) 260 | """ 261 | return self.plan_env.get_image() 262 | 263 | def step_with_dt(self, action: numpy.ndarray | int | float, dt: int = 1) -> tuple: 264 | """Step the environment ``dt`` times with the same action. 265 | 266 | Take ``dt`` simulation steps and make the environment evolve in multiples \ 267 | of ``self.frameskip`` for a total of ``dt`` * ``self.frameskip`` steps. 268 | 269 | Args: 270 | action: Chosen action applied to the environment. 271 | dt: Consecutive number of times that the action will be applied. 272 | 273 | Returns: 274 | If state is `None` returns `(observs, reward, terminal, info)` 275 | else returns `(new_state, observs, reward, terminal, info)`. 276 | 277 | """ 278 | return self.plan_env.step_with_dt(action=action, dt=dt) 279 | 280 | def sample_action(self): 281 | """Return a valid action that can be used to step the Environment. 282 | 283 | Implementing this method is optional, and it's only intended to make the 284 | testing process of the Environment easier. 285 | """ 286 | return self.plan_env.sample_action() 287 | 288 | def step_batch( 289 | self, 290 | actions: numpy.ndarray, 291 | states: numpy.ndarray = None, 292 | dt: numpy.ndarray | int = 1, 293 | return_state: bool | None = None, 294 | ): 295 | """Vectorized version of the ``step`` method. 296 | 297 | It allows to step a vector of states and actions. The signature and 298 | behaviour is the same as ``step``, but taking a list of states, actions 299 | and dts as input. 300 | 301 | Args: 302 | actions: Iterable containing the different actions to be applied. 303 | states: Iterable containing the different states to be set. 304 | dt: int or array containing the frameskips that will be applied. 305 | return_state: Whether to return the state in the returned tuple. 
\ 306 | If None, `step` will return the state if `state` was passed as a parameter. 307 | 308 | Returns: 309 | if states is None returns `(observs, rewards, ends, infos)` else 310 | `(new_states, observs, rewards, ends, infos)`. 311 | 312 | """ 313 | dt_is_array = dt.shape if isinstance(dt, numpy.ndarray) else isinstance(dt, list | tuple) 314 | dt = dt if dt_is_array else numpy.ones(len(actions), dtype=int) * dt 315 | return self.make_transitions(actions, states, dt, return_state=return_state) 316 | 317 | def clone(self, **kwargs) -> "PlanEnv": 318 | """Return a copy of the environment.""" 319 | self_kwargs = dict( 320 | name=self.name, 321 | frameskip=self.frameskip, 322 | delay_setup=self.delay_setup, 323 | env_class=self._env_class, 324 | n_workers=self.n_workers, 325 | **self._env_kwargs, 326 | ) 327 | self_kwargs.update(kwargs) 328 | return self.__class__(**self_kwargs) 329 | 330 | def sync_states(self, state: None): 331 | """Synchronize the workers' states with the state of `self.gym_env`. 332 | 333 | Set all the states of the different workers of the internal :class:`BatchEnv` 334 | to the same state as the internal :class:`Environment` used to apply the 335 | non-vectorized steps. 
336 | """ 337 | raise NotImplementedError() 338 | 339 | def make_transitions(self, actions, states, dt, return_state: bool | None = None): 340 | """Implement the logic for stepping the environment in parallel.""" 341 | raise NotImplementedError() 342 | -------------------------------------------------------------------------------- /src/plangym/vectorization/ray.py: -------------------------------------------------------------------------------- 1 | """Implement a :class:`plangym.VectorizedEnv` that uses ray when calling `step_batch`.""" 2 | 3 | import numpy 4 | 5 | 6 | try: 7 | import ray 8 | except ImportError: 9 | pass 10 | 11 | from plangym.core import PlanEnv 12 | from plangym.vectorization.env import VectorizedEnv 13 | 14 | 15 | @ray.remote 16 | class RemoteEnv(PlanEnv): 17 | """Remote ray Actor interface for a plangym.PlanEnv.""" 18 | 19 | def __init__(self, env_callable): 20 | """Initialize a :class:`RemoteEnv`.""" 21 | self._env_callable = env_callable 22 | self.env = None 23 | 24 | @property 25 | def unwrapped(self): 26 | """Completely unwrap this Environment. 27 | 28 | Returns 29 | plangym.Environment: The base non-wrapped plangym.Environment instance 30 | 31 | """ 32 | return self.env 33 | 34 | @property 35 | def name(self) -> str: 36 | """Return the name of the environment.""" 37 | return self.env.name 38 | 39 | def setup(self): 40 | """Init the wrapped environment.""" 41 | self.env = self._env_callable() 42 | 43 | def step(self, action, state=None, dt: int = 1, return_state: bool | None = None) -> tuple: 44 | """Take a simulation step and make the environment evolve. 45 | 46 | Args: 47 | action: Chosen action applied to the environment. 48 | state: Set the environment to the given state before stepping it. 49 | If state is None the behaviour of this function will be the 50 | same as in OpenAI gym. 51 | dt: Consecutive number of times to apply an action. 52 | return_state: Whether to return the state in the returned tuple. 
\ 53 | If None, `step` will return the state if `state` was passed as a parameter. 54 | 55 | Returns: 56 | if states is None returns (observs, rewards, ends, infos) 57 | else returns(new_states, observs, rewards, ends, infos) 58 | 59 | """ 60 | return self.env.step(action=action, state=state, dt=dt, return_state=return_state) 61 | 62 | def step_batch( 63 | self, 64 | actions: [numpy.ndarray, list], 65 | states=None, 66 | dt: int = 1, 67 | return_state: bool | None = None, 68 | ) -> tuple: 69 | """Take a step on a batch of states and actions. 70 | 71 | Args: 72 | actions: Chosen actions applied to the environment. 73 | states: Set the environment to the given states before stepping it. 74 | If state is None the behaviour of this function will be the same 75 | as in OpenAI gym. 76 | dt: Consecutive number of times that the action will be 77 | applied. 78 | return_state: Whether to return the state in the returned tuple. \ 79 | If None, `step` will return the state if `state` was passed as a parameter. 80 | 81 | Returns: 82 | if states is None returns (observs, rewards, ends, infos) 83 | else returns(new_states, observs, rewards, ends, infos) 84 | 85 | """ 86 | return self.env.step_batch( 87 | actions=actions, 88 | states=states, 89 | dt=dt, 90 | return_state=return_state, 91 | ) 92 | 93 | def reset(self, return_state: bool = True) -> [numpy.ndarray, tuple]: 94 | """Restart the environment.""" 95 | return self.env.reset(return_state=return_state) 96 | 97 | def get_state(self): 98 | """Recover the internal state of the simulation. 99 | 100 | A state must completely describe the Environment at a given moment. 101 | """ 102 | return self.env.get_state() 103 | 104 | def set_state(self, state): 105 | """Set the internal state of the simulation. 106 | 107 | Args: 108 | state: Target state to be set in the environment. 
109 | 110 | Returns: 111 | None 112 | 113 | """ 114 | return self.env.set_state(state=state) 115 | 116 | 117 | class RayEnv(VectorizedEnv): 118 | """Use ray for taking steps in parallel when calling `step_batch`.""" 119 | 120 | def __init__( 121 | self, 122 | env_class, 123 | name: str, 124 | frameskip: int = 1, 125 | autoreset: bool = True, 126 | delay_setup: bool = False, 127 | n_workers: int = 8, 128 | **kwargs, 129 | ): 130 | """Initialize a :class:`ParallelEnv`. 131 | 132 | Args: 133 | env_class: Class of the environment to be wrapped. 134 | name: Name of the environment. 135 | frameskip: Number of times ``step`` will me called with the same action. 136 | autoreset: Ignored. Always set to True. Automatically reset the environment 137 | when the OpenAI environment returns ``end = True``. 138 | delay_setup: If ``True`` do not initialize the ``gym.Environment`` \ 139 | and wait for ``setup`` to be called later. 140 | env_callable: Callable that returns an instance of the environment \ 141 | that will be parallelized. 142 | n_workers: Number of workers that will be used to step the env. 143 | *args: Additional args for the environment. 144 | **kwargs: Additional kwargs for the environment. 
145 | 146 | """ 147 | self._workers = None 148 | super().__init__( 149 | env_class=env_class, 150 | name=name, 151 | frameskip=frameskip, 152 | autoreset=autoreset, 153 | delay_setup=delay_setup, 154 | n_workers=n_workers, 155 | **kwargs, 156 | ) 157 | 158 | @property 159 | def workers(self) -> list[RemoteEnv]: 160 | """Remote actors exposing copies of the environment.""" 161 | return self._workers 162 | 163 | def setup(self): 164 | """Run environment initialization and create the subprocesses for stepping in parallel.""" 165 | env_callable = self.create_env_callable(autoreset=True, delay_setup=False) 166 | workers = [RemoteEnv.remote(env_callable=env_callable) for _ in range(self.n_workers)] 167 | ray.get([w.setup.remote() for w in workers]) 168 | self._workers = workers 169 | # Initialize local copy last to tolerate singletons better 170 | super().setup() 171 | 172 | def make_transitions( 173 | self, 174 | actions, 175 | states=None, 176 | dt: [numpy.ndarray, int] = 1, 177 | return_state: bool | None = None, 178 | ): 179 | """Implement the logic for stepping the environment in parallel.""" 180 | ret_states = not ( 181 | states is None or (isinstance(states, list | numpy.ndarray) and states[0] is None) 182 | ) 183 | _return_state = ret_states if return_state is None else return_state 184 | chunks = self.batch_step_data( 185 | actions=actions, 186 | states=states, 187 | dt=dt, 188 | batch_size=len(self.workers), 189 | ) 190 | results_ids = [] 191 | for env, states_batch, actions_batch, _dt in zip(self.workers, *chunks): 192 | result = env.step_batch.remote( 193 | actions=actions_batch, 194 | states=states_batch, 195 | dt=_dt, 196 | return_state=return_state, 197 | ) 198 | results_ids.append(result) 199 | results = ray.get(results_ids) 200 | return self.unpack_transitions(results=results, return_states=_return_state) 201 | 202 | def reset(self, return_state: bool = True) -> [numpy.ndarray, tuple]: 203 | """Restart the environment.""" 204 | if self.plan_env is None 
and self.delay_setup: 205 | self.setup() 206 | ray.get([w.reset.remote(return_state=return_state) for w in self.workers]) 207 | return super().reset(return_state=return_state) 208 | 209 | def sync_states(self, state: None) -> None: 210 | """Synchronize all the copies of the wrapped environment. 211 | 212 | Set all the states of the different workers of the internal :class:`BatchEnv` 213 | to the same state as the internal :class:`Environment` used to apply the 214 | non-vectorized steps. 215 | """ 216 | state = super().get_state() if state is None else state 217 | obj_ids = [w.set_state.remote(state) for w in self.workers] 218 | self.plan_env.set_state(state) 219 | ray.get(obj_ids) 220 | -------------------------------------------------------------------------------- /src/plangym/version.py: -------------------------------------------------------------------------------- 1 | """Current version of the project. Do not modify manually.""" 2 | 3 | __version__ = "0.1.30" 4 | -------------------------------------------------------------------------------- /src/plangym/videogames/__init__.py: -------------------------------------------------------------------------------- 1 | """Module that contains environments representing video games.""" 2 | 3 | from plangym.videogames.atari import AtariEnv 4 | from plangym.videogames.montezuma import MontezumaEnv 5 | from plangym.videogames.nes import MarioEnv 6 | from plangym.videogames.retro import RetroEnv 7 | -------------------------------------------------------------------------------- /src/plangym/videogames/env.py: -------------------------------------------------------------------------------- 1 | """Plangym API implementation.""" 2 | 3 | from abc import ABC 4 | from typing import Any, Iterable 5 | 6 | import gymnasium as gym 7 | import numpy 8 | 9 | from plangym.core import PlangymEnv, wrap_callable 10 | 11 | 12 | LIFE_KEY = "lifes" 13 | 14 | 15 | class VideogameEnv(PlangymEnv, ABC): 16 | """Common interface for working 
with video games that run using an emulator.""" 17 | 18 | AVAILABLE_OBS_TYPES = {"rgb", "grayscale", "ram"} 19 | DEFAULT_OBS_TYPE = "rgb" 20 | 21 | def __init__( 22 | self, 23 | name: str, 24 | frameskip: int = 5, 25 | episodic_life: bool = False, 26 | autoreset: bool = True, 27 | delay_setup: bool = False, 28 | remove_time_limit: bool = True, 29 | obs_type: str = "rgb", # ram | rgb | grayscale 30 | render_mode: str | None = None, # None | human | rgb_array 31 | wrappers: Iterable[wrap_callable] | None = None, 32 | **kwargs, 33 | ): 34 | """Initialize a :class:`VideogameEnv`. 35 | 36 | Args: 37 | name: Name of the environment. Follows standard gym syntax conventions. 38 | frameskip: Number of times an action will be applied for each step 39 | in dt. 40 | episodic_life: Return ``end = True`` when losing a life. 41 | autoreset: Restart environment when reaching a terminal state. 42 | delay_setup: If ``True`` do not initialize the ``gym.Environment`` 43 | and wait for ``setup`` to be called later. 44 | remove_time_limit: If True, remove the time limit from the environment. 45 | obs_type: One of {"rgb", "ram", "grayscale"}. 46 | mode: Integer or string indicating the game mode, when available. 47 | difficulty: Difficulty level of the game, when available. 48 | repeat_action_probability: Repeat the last action with this probability. 49 | full_action_space: Whether to use the full range of possible actions 50 | or only those available in the game. 51 | render_mode: One of {None, "human", "rgb_aray"}. 52 | wrappers: Wrappers that will be applied to the underlying OpenAI env. 53 | Every element of the iterable can be either a :class:`gym.Wrapper` 54 | or a tuple containing ``(gym.Wrapper, kwargs)``. 55 | kwargs: Additional arguments to be passed to the ``gym.make`` function. 
56 | 57 | """ 58 | self.episodic_life = episodic_life 59 | self._info_step = {LIFE_KEY: -1, "lost_life": False} 60 | super().__init__( 61 | name=name, 62 | frameskip=frameskip, 63 | autoreset=autoreset, 64 | wrappers=wrappers, 65 | delay_setup=delay_setup, 66 | render_mode=render_mode, 67 | remove_time_limit=remove_time_limit, 68 | obs_type=obs_type, 69 | **kwargs, 70 | ) 71 | 72 | @property 73 | def n_actions(self) -> int: 74 | """Return the number of actions available.""" 75 | return self.action_space.n 76 | 77 | @staticmethod 78 | def get_lifes_from_info(info: dict[str, Any]) -> int: 79 | """Return the number of lifes remaining in the current game.""" 80 | return info.get("life", -1) 81 | 82 | def apply_action(self, action): 83 | """Evolve the environment for one time step applying the provided action.""" 84 | obs, reward, terminal, truncated, info = super().apply_action(action=action) 85 | info[LIFE_KEY] = self.get_lifes_from_info(info) 86 | past_lifes = self._info_step.get(LIFE_KEY, -1) 87 | lost_life = past_lifes > info[LIFE_KEY] or self._info_step.get("lost_life") 88 | info["lost_life"] = lost_life 89 | terminal = (terminal or lost_life) if self.episodic_life else terminal 90 | return obs, reward, terminal, truncated, info 91 | 92 | def clone(self, **kwargs) -> "VideogameEnv": 93 | """Return a copy of the environment.""" 94 | params = { 95 | "episodic_life": self.episodic_life, 96 | "obs_type": self.obs_type, 97 | "render_mode": self.render_mode, 98 | } 99 | params.update(**kwargs) 100 | return super().clone(**params) 101 | 102 | def begin_step( 103 | self, action=None, dt=None, state=None, return_state: bool | None = None 104 | ) -> None: 105 | """Perform setup of step variables before starting `step_with_dt`.""" 106 | self._info_step = {LIFE_KEY: -1, "lost_life": False} 107 | super().begin_step( 108 | action=action, 109 | dt=dt, 110 | state=state, 111 | return_state=return_state, 112 | ) 113 | 114 | def init_spaces(self) -> None: 115 | """Initialize the 
action_space and the observation_space of the environment.""" 116 | super().init_spaces() 117 | if self.obs_type == "ram": 118 | if self.DEFAULT_OBS_TYPE == "ram": 119 | space = self.gym_env.observation_space 120 | else: 121 | ram_size = self.get_ram().shape 122 | space = gym.spaces.Box(low=0, high=255, dtype=numpy.uint8, shape=ram_size) 123 | self._obs_space = space 124 | 125 | def process_obs(self, obs, **kwargs): 126 | """Return the ram vector if obs_type == "ram" or and image otherwise.""" 127 | obs = super().process_obs(obs, **kwargs) 128 | if self.obs_type == "ram" and self.DEFAULT_OBS_TYPE != "ram": 129 | obs = self.get_ram() 130 | return obs 131 | 132 | def get_ram(self) -> numpy.ndarray: 133 | """Return the ram of the emulator as a numpy array.""" 134 | raise NotImplementedError() 135 | -------------------------------------------------------------------------------- /src/plangym/videogames/nes.py: -------------------------------------------------------------------------------- 1 | """Environment for playing Mario bros using gym-super-mario-bros.""" 2 | 3 | from typing import Any, TypeVar 4 | 5 | import gymnasium as gym 6 | import numpy 7 | 8 | from plangym.videogames.env import VideogameEnv 9 | 10 | # actions for the simple run right environment 11 | RIGHT_ONLY = [ 12 | ["NOOP"], 13 | ["right"], 14 | ["right", "A"], 15 | ["right", "B"], 16 | ["right", "A", "B"], 17 | ] 18 | 19 | 20 | # actions for very simple movement 21 | SIMPLE_MOVEMENT = [ 22 | ["NOOP"], 23 | ["right"], 24 | ["right", "A"], 25 | ["right", "B"], 26 | ["right", "A", "B"], 27 | ["A"], 28 | ["left"], 29 | ] 30 | 31 | 32 | # actions for more complex movement 33 | COMPLEX_MOVEMENT = [ 34 | ["NOOP"], 35 | ["right"], 36 | ["right", "A"], 37 | ["right", "B"], 38 | ["right", "A", "B"], 39 | ["A"], 40 | ["left"], 41 | ["left", "A"], 42 | ["left", "B"], 43 | ["left", "A", "B"], 44 | ["down"], 45 | ["up"], 46 | ] 47 | 48 | ObsType = TypeVar("ObsType") 49 | ActType = TypeVar("ActType") 50 | 
RenderFrame = TypeVar("RenderFrame") 51 | 52 | 53 | class NESWrapper: 54 | """A wrapper for the NES environment.""" 55 | 56 | def __init__(self, wrapped): 57 | """Initialize the NESWrapper.""" 58 | self._wrapped = wrapped 59 | 60 | def __getattr__(self, name): 61 | """Get an attribute from the wrapped object.""" 62 | return getattr(self._wrapped, name) 63 | 64 | def __setattr__(self, name, value): 65 | """Set an attribute on the wrapped object.""" 66 | if name == "_wrapped": 67 | super().__setattr__(name, value) 68 | else: 69 | setattr(self._wrapped, name, value) # pragma: no cover 70 | 71 | def __delattr__(self, name): 72 | """Delete an attribute from the wrapped object.""" 73 | delattr(self._wrapped, name) # pragma: no cover 74 | 75 | def step( 76 | self, action: ActType 77 | ) -> tuple[gym.core.WrapperObsType, gym.core.SupportsFloat, bool, bool, dict[str, Any]]: 78 | """Modify the :attr:`env` after calling :meth:`step` using :meth:`self.observation`.""" 79 | observation, reward, terminated, info = self._wrapped.step(action) 80 | truncated = False 81 | return self.observation(observation), reward, terminated, truncated, info 82 | 83 | def reset( 84 | self, 85 | *, 86 | seed: int | None = None, # noqa: ARG002 87 | options: dict[str, Any] | None = None, # noqa: ARG002 88 | ) -> tuple[gym.core.WrapperObsType, dict[str, Any]]: 89 | """Modify the :attr:`env` after calling :meth:`reset`, returning a modified observation.""" 90 | obs = self.env.reset() 91 | info = {} 92 | return self.observation(obs), info 93 | 94 | def observation(self, observation: ObsType) -> gym.core.WrapperObsType: 95 | """Return a modified observation. 
96 | 97 | Args: 98 | observation: The :attr:`env` observation 99 | 100 | Returns: 101 | The modified observation 102 | 103 | """ 104 | return observation 105 | 106 | 107 | class JoypadSpace(gym.Wrapper): 108 | """An environment wrapper to convert binary to discrete action space.""" 109 | 110 | # a mapping of buttons to binary values 111 | _button_map = { 112 | "right": 0b10000000, 113 | "left": 0b01000000, 114 | "down": 0b00100000, 115 | "up": 0b00010000, 116 | "start": 0b00001000, 117 | "select": 0b00000100, 118 | "B": 0b00000010, 119 | "A": 0b00000001, 120 | "NOOP": 0b00000000, 121 | } 122 | 123 | @classmethod 124 | def buttons(cls) -> list: 125 | """Return the buttons that can be used as actions.""" 126 | return list(cls._button_map.keys()) 127 | 128 | def __init__(self, env: gym.Env, actions: list): 129 | """Initialize a new binary to discrete action space wrapper. 130 | 131 | Args: 132 | env: the environment to wrap 133 | actions: an ordered list of actions (as lists of buttons). 134 | The index of each button list is its discrete coded value 135 | 136 | Returns: 137 | None 138 | 139 | """ 140 | super().__init__(env) 141 | # create the new action space 142 | self.action_space = gym.spaces.Discrete(len(actions)) 143 | # create the action map from the list of discrete actions 144 | self._action_map = {} 145 | self._action_meanings = {} 146 | # iterate over all the actions (as button lists) 147 | for action, button_list in enumerate(actions): 148 | # the value of this action's bitmap 149 | byte_action = 0 150 | # iterate over the buttons in this button list 151 | for button in button_list: 152 | byte_action |= self._button_map[button] 153 | # set this action maps value to the byte action value 154 | self._action_map[action] = byte_action 155 | self._action_meanings[action] = " ".join(button_list) 156 | 157 | def step(self, action): 158 | """Take a step using the given action. 
159 | 160 | Args: 161 | action (int): the discrete action to perform 162 | 163 | Returns: 164 | a tuple of: 165 | - (numpy.ndarray) the state as a result of the action 166 | - (float) the reward achieved by taking the action 167 | - (bool) a flag denoting whether the episode has ended 168 | - (dict) a dictionary of extra information 169 | 170 | """ 171 | # take the step and record the output 172 | return self.env.step(self._action_map[action]) 173 | 174 | # def reset(self, *, seed=None, options=None): 175 | # """Reset the environment and return the initial observation.""" 176 | # return self.env.reset(), {} 177 | 178 | def get_keys_to_action(self): 179 | """Return the dictionary of keyboard keys to actions.""" 180 | # get the old mapping of keys to actions 181 | old_keys_to_action = self.env.unwrapped.get_keys_to_action() 182 | # invert the keys to action mapping to lookup key combos by action 183 | action_to_keys = {v: k for k, v in old_keys_to_action.items()} 184 | # create a new mapping of keys to actions 185 | keys_to_action = {} 186 | # iterate over the actions and their byte values in this mapper 187 | for action, byte in self._action_map.items(): 188 | # get the keys to press for the action 189 | keys = action_to_keys[byte] 190 | # set the keys value in the dictionary to the current discrete act 191 | keys_to_action[keys] = action 192 | 193 | return keys_to_action 194 | 195 | def get_action_meanings(self): 196 | """Return a list of actions meanings.""" 197 | actions = sorted(self._action_meanings.keys()) 198 | return [self._action_meanings[action] for action in actions] 199 | 200 | 201 | class NesEnv(VideogameEnv): 202 | """Environment for working with the NES-py emulator.""" 203 | 204 | @property 205 | def nes_env(self) -> "NESEnv": # noqa: F821 206 | """Access the underlying NESEnv.""" 207 | return self.gym_env.unwrapped 208 | 209 | def get_image(self) -> numpy.ndarray: 210 | """Return a numpy array containing the rendered view of the environment. 
211 | 212 | Square matrices are interpreted as a greyscale image. Three-dimensional arrays 213 | are interpreted as RGB images with channels (Height, Width, RGB) 214 | """ 215 | return self.gym_env.screen.copy() 216 | 217 | def get_ram(self) -> numpy.ndarray: 218 | """Return a copy of the emulator environment.""" 219 | return self.nes_env.ram.copy() 220 | 221 | def get_state(self, state: numpy.ndarray | None = None) -> numpy.ndarray: 222 | """Recover the internal state of the simulation. 223 | 224 | A state must completely describe the Environment at a given moment. 225 | """ 226 | return self.gym_env.get_state(state) 227 | 228 | def set_state(self, state: numpy.ndarray) -> None: 229 | """Set the internal state of the simulation. 230 | 231 | Args: 232 | state: Target state to be set in the environment. 233 | 234 | Returns: 235 | None 236 | 237 | """ 238 | self.gym_env.set_state(state) 239 | 240 | def close(self) -> None: 241 | """Close the underlying :class:`gym.Env`.""" 242 | if self.nes_env._env is None: 243 | return 244 | try: 245 | super().close() 246 | except ValueError: # pragma: no cover 247 | pass 248 | 249 | def __del__(self): 250 | """Tear down the environment.""" 251 | try: 252 | self.close() 253 | except ValueError: # pragma: no cover 254 | pass 255 | 256 | def render(self, mode="rgb_array"): # noqa: ARG002 257 | """Render the environment.""" 258 | return self.gym_env.screen.copy() 259 | 260 | 261 | class MarioEnv(NesEnv): 262 | """Interface for using gym-super-mario-bros in plangym.""" 263 | 264 | AVAILABLE_OBS_TYPES = {"coords", "rgb", "grayscale", "ram"} 265 | MOVEMENTS = { 266 | "complex": COMPLEX_MOVEMENT, 267 | "simple": SIMPLE_MOVEMENT, 268 | "right": RIGHT_ONLY, 269 | } 270 | 271 | def __init__( 272 | self, 273 | name: str, 274 | movement_type: str = "simple", 275 | original_reward: bool = False, 276 | **kwargs, 277 | ): 278 | """Initialize a MarioEnv. 279 | 280 | Args: 281 | name: Name of the environment. 
282 | movement_type: One of {complex|simple|right} 283 | original_reward: If False return a custom reward based on mario position and level. 284 | **kwargs: passed to super().__init__. 285 | 286 | """ 287 | self._movement_type = movement_type 288 | self._original_reward = original_reward 289 | super().__init__(name=name, **kwargs) 290 | 291 | def get_state(self, state: numpy.ndarray | None = None) -> numpy.ndarray: 292 | """Recover the internal state of the simulation. 293 | 294 | A state must completely describe the Environment at a given moment. 295 | """ 296 | state = numpy.empty(250288, dtype=numpy.byte) if state is None else state 297 | state[-2:] = 0 # Some states use the last two bytes. Set to zero by default. 298 | return super().get_state(state) 299 | 300 | def init_gym_env(self) -> gym.Env: 301 | """Initialize the :class:`NESEnv`` instance that the current class is wrapping.""" 302 | from gym_super_mario_bros import make # noqa: PLC0415 303 | from gym_super_mario_bros.actions import COMPLEX_MOVEMENT # noqa: PLC0415 304 | 305 | env = make(self.name) 306 | gym_env = NESWrapper(JoypadSpace(env.unwrapped, COMPLEX_MOVEMENT)) 307 | gym_env.reset() 308 | return gym_env 309 | 310 | def _update_info(self, info: dict[str, Any]) -> dict[str, Any]: 311 | info["player_state"] = self.nes_env._player_state 312 | info["area"] = self.nes_env._area 313 | info["left_x_position"] = self.nes_env._left_x_position 314 | info["is_stage_over"] = self.nes_env._is_stage_over 315 | info["is_dying"] = self.nes_env._is_dying 316 | info["is_dead"] = self.nes_env._is_dead 317 | info["y_pixel"] = self.nes_env._y_pixel 318 | info["y_viewport"] = self.nes_env._y_viewport 319 | info["x_position_last"] = self.nes_env._x_position_last 320 | info["in_pipe"] = (info["player_state"] == 0x02) or (info["player_state"] == 0x03) # noqa: PLR2004 321 | return info 322 | 323 | def _get_info( 324 | self, 325 | ): 326 | info = { 327 | "x_pos": 0, 328 | "y_pos": 0, 329 | "world": 0, 330 | "stage": 0, 331 
| "life": 0, 332 | "coins": 0, 333 | "flag_get": False, 334 | "in_pipe": False, 335 | } 336 | return self._update_info(info) 337 | 338 | def get_coords_obs( 339 | self, 340 | obs: numpy.ndarray, 341 | info: dict[str, Any] | None = None, 342 | **kwargs, # noqa: ARG002 343 | ) -> numpy.ndarray: 344 | """Return the information contained in info as an observation if obs_type == "info".""" 345 | if self.obs_type == "coords": 346 | info = info or self._get_info() 347 | obs = numpy.array( 348 | [ 349 | info.get("x_pos", 0), 350 | info.get("y_pos", 0), 351 | info.get("world" * 10, 0), 352 | info.get("stage", 0), 353 | info.get("life", 0), 354 | int(info.get("flag_get", 0)), 355 | info.get("coins", 0), 356 | ], 357 | ) 358 | return obs 359 | 360 | def process_reward(self, reward, info, **kwargs) -> float: # noqa: ARG002 361 | """Return a custom reward based on the x, y coordinates and level mario is in.""" 362 | if not self._original_reward: 363 | world = int(info.get("world", 0)) 364 | stage = int(info.get("stage", 0)) 365 | x_pos = int(info.get("x_pos", 0)) 366 | reward = ( 367 | (world * 25000) 368 | + (stage * 5000) 369 | + x_pos 370 | + 10 * int(bool(info.get("in_pipe", 0))) 371 | + 100 * int(bool(info.get("flag_get", 0))) 372 | # + (abs(info["x_pos"] - info["x_position_last"])) 373 | ) 374 | return reward 375 | 376 | def process_terminal(self, terminal, info, **kwargs) -> bool: # noqa: ARG002 377 | """Return True if terminal or mario is dying.""" 378 | return terminal or info.get("is_dying", False) or info.get("is_dead", False) 379 | 380 | def process_info(self, info, **kwargs) -> dict[str, Any]: # noqa: ARG002 381 | """Add additional data to the info dictionary.""" 382 | return self._update_info(info) 383 | -------------------------------------------------------------------------------- /src/plangym/videogames/retro.py: -------------------------------------------------------------------------------- 1 | """Implement the ``plangym`` API for retro environments.""" 2 | 
3 | from typing import Any, Iterable 4 | 5 | import gymnasium as gym 6 | from gymnasium import spaces 7 | import numpy 8 | 9 | from plangym.core import wrap_callable 10 | from plangym.videogames.env import VideogameEnv 11 | 12 | 13 | class ActionDiscretizer(gym.ActionWrapper): 14 | """Wrap a gym-retro environment and make it use discrete actions for the Sonic game.""" 15 | 16 | def __init__(self, env, actions=None): 17 | """Initialize a :class`ActionDiscretizer`.""" 18 | super().__init__(env) 19 | buttons = ["B", "A", "MODE", "START", "UP", "DOWN", "LEFT", "RIGHT", "C", "Y", "X", "Z"] 20 | actions = ( 21 | [ 22 | ["LEFT"], 23 | ["RIGHT"], 24 | ["LEFT", "DOWN"], 25 | ["RIGHT", "DOWN"], 26 | ["DOWN"], 27 | ["DOWN", "B"], 28 | ["B"], 29 | ] 30 | if actions is None 31 | else actions 32 | ) 33 | self._actions = [] 34 | for action in actions: 35 | arr = numpy.array([False] * 12) 36 | for button in action: 37 | arr[buttons.index(button)] = True 38 | self._actions.append(arr) 39 | self.action_space = spaces.Discrete(len(self._actions)) 40 | 41 | def action(self, a) -> int: # pylint: disable=W0221 42 | """Return the corresponding action in the emulator's format.""" 43 | return self._actions[a].copy() 44 | 45 | 46 | class RetroEnv(VideogameEnv): 47 | """Environment for playing ``gym-retro`` games.""" 48 | 49 | AVAILABLE_OBS_TYPES = {"coords", "rgb", "grayscale", "ram"} 50 | SINGLETON = True 51 | 52 | def __init__( 53 | self, 54 | name: str, 55 | frameskip: int = 5, 56 | episodic_life: bool = False, 57 | autoreset: bool = True, 58 | delay_setup: bool = False, 59 | remove_time_limit: bool = True, 60 | obs_type: str = "rgb", # ram | rgb | grayscale 61 | render_mode: str | None = None, # None | human | rgb_array 62 | wrappers: Iterable[wrap_callable] | None = None, 63 | **kwargs, 64 | ): 65 | """Initialize a :class:`RetroEnv`. 66 | 67 | Args: 68 | name: Name of the environment. Follows standard gym syntax conventions. 
69 | frameskip: Number of times an action will be applied for each step \ 70 | in dt. 71 | episodic_life: Return ``end = True`` when losing a life. 72 | autoreset: Restart environment when reaching a terminal state. 73 | delay_setup: If ``True`` do not initialize the ``gym.Environment`` \ 74 | and wait for ``setup`` to be called later. 75 | remove_time_limit: If True, remove the time limit from the environment. 76 | obs_type: One of {"rgb", "ram", "grayscale"}. 77 | render_mode: One of {None, "human", "rgb_aray"}. 78 | wrappers: Wrappers that will be applied to the underlying OpenAI env. \ 79 | Every element of the iterable can be either a :class:`gym.Wrapper` \ 80 | or a tuple containing ``(gym.Wrapper, kwargs)``. 81 | kwargs: Additional arguments to be passed to the ``gym.make`` function. 82 | 83 | """ 84 | super().__init__( 85 | name=name, 86 | frameskip=frameskip, 87 | episodic_life=episodic_life, 88 | autoreset=autoreset, 89 | delay_setup=delay_setup, 90 | remove_time_limit=remove_time_limit, 91 | obs_type=obs_type, # ram | rgb | grayscale 92 | render_mode=render_mode, # None | human | rgb_array 93 | wrappers=wrappers, 94 | **kwargs, 95 | ) 96 | 97 | def __getattr__(self, item): 98 | """Forward getattr to self.gym_env.""" 99 | return getattr(self.gym_env, item) 100 | 101 | @staticmethod 102 | def get_win_condition(info: dict[str, Any]) -> bool: # pragma: no cover 103 | """Get win condition for games that have the end of the screen available.""" 104 | end_screen = info.get("screen_x", 0) >= info.get("screen_x_end", 1e6) 105 | return info.get("x", 0) >= info.get("screen_x_end", 1e6) or end_screen 106 | 107 | def get_ram(self) -> numpy.ndarray: 108 | """Return the ram of the emulator as a numpy array.""" 109 | return self.get_state() # .copy() 110 | 111 | def clone(self, **kwargs) -> "RetroEnv": 112 | """Return a copy of the environment with its initialization delayed.""" 113 | default_kwargs = { 114 | "name": self.name, 115 | "frameskip": self.frameskip, 116 | 
"wrappers": self._wrappers, 117 | "episodic_life": self.episodic_life, 118 | "autoreset": self.autoreset, 119 | "delay_setup": self.delay_setup, 120 | "obs_type": self.obs_type, 121 | } 122 | default_kwargs.update(kwargs) 123 | return super().clone(**default_kwargs) 124 | 125 | def init_gym_env(self) -> gym.Env: 126 | """Initialize the retro environment.""" 127 | import retro # noqa: PLC0415 128 | 129 | if self._gym_env is not None: 130 | self._gym_env.close() 131 | return retro.make(self.name, **self._gym_env_kwargs) 132 | 133 | def get_state(self) -> numpy.ndarray: 134 | """Get the state of the retro environment.""" 135 | state = self.gym_env.em.get_state() 136 | return numpy.frombuffer(state, dtype=numpy.uint8) 137 | 138 | def set_state(self, state: numpy.ndarray): 139 | """Set the state of the retro environment.""" 140 | raw_state = state.tobytes() 141 | self.gym_env.em.set_state(raw_state) 142 | return state 143 | 144 | def close(self): 145 | """Close the underlying :class:`gym.Env`.""" 146 | if hasattr(self, "_gym_env") and hasattr(self._gym_env, "close"): 147 | import gc # noqa: PLC0415 148 | 149 | self._gym_env.close() 150 | gc.collect() 151 | 152 | def reset( 153 | self, 154 | return_state: bool = True, 155 | ) -> numpy.ndarray | tuple[numpy.ndarray, numpy.ndarray]: 156 | """Restart the environment. 157 | 158 | Args: 159 | return_state: If ``True``, it will return the state of the environment. 160 | 161 | Returns: 162 | ``(state, obs)`` if ```return_state`` is ``True`` else return ``obs``. 
163 | 164 | """ 165 | obs, _info = self.apply_reset() 166 | obs = self.process_obs(obs) 167 | info = _info or {} 168 | info = self.process_info(obs=obs, reward=0, terminal=False, info=info) 169 | return (self.get_state(), obs, info) if return_state else (obs, info) 170 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import pytest 4 | 5 | 6 | warnings.filterwarnings( 7 | action="ignore", category=pytest.PytestUnraisableExceptionWarning, module="pytest" 8 | ) 9 | try: 10 | import retro 11 | 12 | SKIP_RETRO_TESTS = False 13 | except ImportError: 14 | SKIP_RETRO_TESTS = True 15 | 16 | try: 17 | import ray 18 | 19 | SKIP_RAY_TESTS = False 20 | except ImportError: 21 | SKIP_RAY_TESTS = True 22 | 23 | try: 24 | from plangym.videogames.atari import AtariEnv 25 | 26 | SKIP_ATARI_TESTS = False 27 | except ImportError: 28 | SKIP_ATARI_TESTS = True 29 | 30 | try: 31 | from plangym.control.dm_control import DMControlEnv 32 | 33 | DMControlEnv(name="walker-run", frameskip=3) 34 | SKIP_DM_CONTROL_TESTS = False 35 | except (ImportError, AttributeError, ValueError): 36 | SKIP_DM_CONTROL_TESTS = True 37 | 38 | 39 | try: 40 | import Box2D 41 | 42 | SKIP_BOX2D_TESTS = False 43 | except ImportError: 44 | SKIP_BOX2D_TESTS = True 45 | -------------------------------------------------------------------------------- /tests/control/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FragileTech/plangym/417325eac9be9668f4065351ba22f3f6461f5a4f/tests/control/__init__.py -------------------------------------------------------------------------------- /tests/control/test_balloon.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | 6 | 
pytest.importorskip("balloon_learning_environment") 7 | from src.plangym.api_tests import ( 8 | batch_size, 9 | display, 10 | generate_test_cases, 11 | TestPlanEnv, 12 | TestPlangymEnv, 13 | ) 14 | from plangym.control.balloon import BalloonEnv 15 | 16 | 17 | disable_balloon_tests = not bool(os.getenv("DISABLE_BALLOON_ENV")) 18 | if disable_balloon_tests and str(disable_balloon_tests).lower() != "false": 19 | pytest.skip("balloon_learning_environment tests are disabled", allow_module_level=True) 20 | 21 | 22 | @pytest.fixture( 23 | params=generate_test_cases(["BalloonLearningEnvironment-v0"], BalloonEnv), 24 | scope="module", 25 | ) 26 | def env(request) -> BalloonEnv: 27 | return request.param() 28 | -------------------------------------------------------------------------------- /tests/control/test_box_2d.py: -------------------------------------------------------------------------------- 1 | from gymnasium.wrappers import TimeLimit 2 | import pytest 3 | 4 | 5 | pytest.importorskip("Box2D") 6 | from src.plangym.api_tests import ( 7 | batch_size, 8 | display, 9 | generate_test_cases, 10 | TestPlanEnv, 11 | TestPlangymEnv, 12 | ) 13 | from plangym.control.box_2d import Box2DEnv 14 | from plangym.environment_names import BOX_2D 15 | 16 | 17 | def bipedal_walker(): 18 | timelimit = [(TimeLimit, {"max_episode_steps": 1000})] 19 | return Box2DEnv(name="BipedalWalker-v3", autoreset=True, wrappers=timelimit) 20 | 21 | 22 | @pytest.fixture( 23 | params=generate_test_cases(BOX_2D, Box2DEnv, custom_tests=[bipedal_walker]), scope="module" 24 | ) 25 | def env(request) -> Box2DEnv: 26 | return request.param() 27 | -------------------------------------------------------------------------------- /tests/control/test_classic_control.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | from plangym.control.classic_control import ClassicControl 6 | from plangym.environment_names import CLASSIC_CONTROL 7 | 8 
# Skip the module when SKIP_CLASSIC_CONTROL is set to anything other than "false".
if (
    os.getenv("SKIP_CLASSIC_CONTROL", None)
    and str(os.getenv("SKIP_CLASSIC_CONTROL", "false")).lower() != "false"
):
    pytest.skip("Skipping classic control", allow_module_level=True)

# Importing the shared Test* classes and fixtures into this module lets pytest collect them here.
from plangym.api_tests import (
    batch_size,
    display,
    generate_test_cases,
    TestPlanEnv,
    TestPlangymEnv,
)
import operator


# NOTE(review): zip() stops at the shorter iterable. If generate_test_cases yields more
# cases than CLASSIC_CONTROL has names, the extra cases are dropped and the ids may not
# match the env actually built. Confirm.
@pytest.fixture(
    params=zip(generate_test_cases(CLASSIC_CONTROL, ClassicControl), iter(CLASSIC_CONTROL)),
    ids=operator.itemgetter(1),
    scope="module",
)
def env(request) -> ClassicControl:
    """Build the env for this param, yield it, and close it after the module's tests."""
    env = request.param[0]()
    yield env
    env.close()


class TestClassic(TestPlangymEnv):
    def test_wrap_environment(self, env):
        """Run the shared wrap test, except for Acrobot-v1 where it is bypassed."""
        if env.name == "Acrobot-v1":
            return None
        return super().test_wrap_environment(env)
--------------------------------------------------------------------------------
/tests/control/test_dm_control.py:
--------------------------------------------------------------------------------
import os

import numpy
import pytest


pytest.importorskip("dm_control")
from src.plangym.api_tests import (
    batch_size,
    display,
    generate_test_cases,
    TestPlanEnv,
    TestPlangymEnv,
)
from plangym.control.dm_control import DMControlEnv
from plangym.environment_names import DM_CONTROL


class DummyTimeLimit:
    """Minimal stand-in for a TimeLimit wrapper: stores the limit and forwards attribute lookups."""

    def __init__(self, env, max_episode_steps=None):
        self._max_episode_steps = max_episode_steps
        self._elapsed_steps = None
        self.env = env

    def __getattr__(self, item):
        return getattr(self.env, item)


def walker_run():
    """Return a walker-run DMControlEnv wrapped in DummyTimeLimit (1000 steps)."""
    timelimit = [(DummyTimeLimit, {"max_episode_steps": 1000})]
    return DMControlEnv(name="walker-run", frameskip=3, wrappers=timelimit)


def parallel_dm():
    return DMControlEnv(name="cartpole-balance", frameskip=3)


# NOTE(review): ``environments`` is not referenced by the ``env`` fixture below.
environments = [walker_run, parallel_dm]


@pytest.fixture(
    params=generate_test_cases(DM_CONTROL, DMControlEnv, n_workers_values=[None]),
    scope="module",
)
def env(request) -> DMControlEnv:
    """Build the env for this param, yield it, and close it afterwards (best effort)."""
    env = request.param()
    yield env
    try:
        env.close()
    except Exception:  # noqa S110
        pass


class TestDMControl:
    def test_attributes(self, env):
        env.reset()
        assert hasattr(env, "physics")
        assert hasattr(env, "action_spec")
        assert hasattr(env, "action_space")
        assert hasattr(env, "render_mode")
        assert env.render_mode in {"human", "rgb_array", "coords", None}

    @pytest.mark.skipif(os.getenv("SKIP_RENDER", None), reason="No display in CI.")
    def test_render(self, env):
        """Check rgb_array rendering returns an array and human rendering grows ``env.viewer``."""
        env.reset()
        obs_rgb = env.render(mode="rgb_array")
        assert isinstance(obs_rgb, numpy.ndarray)
        old_len = len(env.viewer)
        action = env.sample_action()
        env.step(action)
        env.render(mode="human")
        assert len(env.viewer) > old_len
        env.show_game(sleep=0.01)

    def test_parse_name_fails(self):
        """A name without a task part ("cartpole") must raise ValueError."""
        with pytest.raises(ValueError):
            DMControlEnv(name="cartpole")
--------------------------------------------------------------------------------
/tests/control/test_lunar_lander.py:
--------------------------------------------------------------------------------
import pytest


pytest.importorskip("Box2D")
from plangym import api_tests
from plangym.api_tests import (
    batch_size,
    display,
    generate_test_cases,
    TestPlangymEnv,
)
from plangym.control.lunar_lander import FastGymLunarLander, LunarLander


def lunar_lander_det_discrete():
    return LunarLander(autoreset=False, deterministic=True, continuous=False)


def lunar_lander_random_discrete():
    return LunarLander(autoreset=False, deterministic=False, continuous=False)


def lunar_lander_random_continuous():
    return LunarLander(
        autoreset=False,
        deterministic=False,
        continuous=True,
    )


# Extra factories passed to generate_test_cases via ``custom_tests``.
environments = [
    lunar_lander_det_discrete,
    lunar_lander_random_discrete,
    lunar_lander_random_continuous,
]


@pytest.fixture(
    params=generate_test_cases(["FastLunarLander-v0"], LunarLander, custom_tests=environments),
    scope="module",
)
def env(request) -> LunarLander:
    """Build the env for this param, yield it, and close it after the module's tests."""
    env = request.param()
    yield env
    env.close()


class TestFastGymLunarLander:
    def test_death(self):
        """Step with random actions (up to 1000 steps) until ``end`` is truthy."""
        gym_env = FastGymLunarLander()
        gym_env.reset()
        for _ in range(1000):
            # NOTE(review): if step returns (obs, reward, terminated, truncated, info),
            # ``end`` binds to ``truncated``, not ``terminated``. Confirm the intent.
            *_, end, _info = gym_env.step(gym_env.action_space.sample())
            if end:
                break


# NOTE(review): TestPlangymEnv is also imported into this module, so pytest may collect
# the same tests twice. Confirm whether that is intended.
class TestLunarLander(api_tests.TestPlangymEnv):
    pass
--------------------------------------------------------------------------------
/tests/test_core.py:
--------------------------------------------------------------------------------
# NOTE(review): ``Tuple`` appears unused in this module.
from typing import Tuple

import numpy
import pytest

from plangym.api_tests import batch_size, display, TestPlanEnv
from plangym.core import PlanEnv


class DummyPlanEnv(PlanEnv):
    """Deterministic in-memory PlanEnv: 10-element observations, scalar actions, reward 1."""

    _step_count = 0
    _state = None

    @property
    def obs_shape(self) -> tuple[int]:
        """Tuple containing the shape of the observations returned by the Environment."""
        return (10,)

    @property
    def action_shape(self) -> tuple[int]:
        """Tuple containing the shape of the actions applied to the Environment."""
        return ()

    def get_image(self):
        return numpy.zeros((10, 10, 3))

    def get_state(self):
        # The state is created lazily on first access; its last element records the step count then.
        if self._state is None:
            state = numpy.ones(10)
            state[-1] = self._step_count
            self._state = state
            return state
        return self._state

    def set_state(self, state: numpy.ndarray) -> None:
        self._state = state

    def sample_action(self):
        return 0

    def apply_reset(self, **kwargs):
        self._step_count = 0
        return numpy.zeros(10), {}

    def apply_action(self, action) -> tuple:
        self._step_count += 1
        obs, reward, end, truncated, info = numpy.ones(10), 1, False, False, {}
        return obs, reward, end, truncated, info

    def clone(self):
        return self


environments = [lambda: DummyPlanEnv(name="dummy")]


@pytest.fixture(params=environments, scope="class")
def env(request) -> PlanEnv:
    return request.param()


@pytest.fixture(params=environments, scope="class")
def plangym_env(request) -> PlanEnv:
    return request.param()


class TestPrivateAPI:
    @pytest.mark.parametrize("dt", [1, 3])
    def test_step_with_dt(self, env, dt):
        """step_with_dt returns a tuple, and the sampled action's shape matches action_shape."""
        _ = env.reset(return_state=False)
        action = env.sample_action()
        assert env.action_shape == numpy.array(action).shape
        data = env.step_with_dt(action, dt=dt)
        assert isinstance(data, tuple)
--------------------------------------------------------------------------------
/tests/test_registry.py:
--------------------------------------------------------------------------------
import os
import warnings

import gymnasium as gym
import pytest

from plangym.control.classic_control import ClassicControl
from plangym.environment_names import ATARI, BOX_2D, CLASSIC_CONTROL, DM_CONTROL, RETRO
from plangym.registry import make
from plangym.vectorization.parallel import ParallelEnv
from tests import (
    SKIP_ATARI_TESTS,
    SKIP_BOX2D_TESTS,
    SKIP_DM_CONTROL_TESTS,
    SKIP_RAY_TESTS,
    SKIP_RETRO_TESTS,
)


def _test_env_class(name, cls, **kwargs):
    """Check that ``make`` builds ``cls`` directly, and ParallelEnv/RayEnv wrappers for n_workers=2.

    The RayEnv check runs only when ray is importable.
    """
    n_workers = 2
    assert isinstance(make(name, delay_setup=False, **kwargs), cls)
    env = make(name=name, n_workers=n_workers, delay_setup=True, **kwargs)
    assert isinstance(env, ParallelEnv)
    assert env._env_class == cls
    assert env.n_workers == n_workers
    if not SKIP_RAY_TESTS:
        from plangym.vectorization.ray import RayEnv

        env = make(name=name, n_workers=n_workers, ray=True, delay_setup=True, **kwargs)
        assert isinstance(env, RayEnv)
        assert env._env_class == cls
        assert env.n_workers == n_workers


class TestMake:
    @pytest.mark.parametrize("name", CLASSIC_CONTROL)
    def test_classic_control_make(self, name):
        _test_env_class(name, ClassicControl)

    # Only every 10th Atari name is tested, to keep the run short.
    @pytest.mark.skipif(SKIP_ATARI_TESTS, reason="Atari not installed")
    @pytest.mark.parametrize("name", ATARI[::10])
    def test_atari_make(self, name):
        from plangym.videogames.atari import AtariEnv

        _test_env_class(name, AtariEnv)

    @pytest.mark.skipif(SKIP_BOX2D_TESTS, reason="BOX_2D not installed")
    @pytest.mark.parametrize("name", BOX_2D)
    def test_box2d_make(self, name):
        from plangym.control.box_2d import Box2DEnv
        from plangym.control.lunar_lander import LunarLander

        # FastLunarLander-v0 is served by LunarLander rather than Box2DEnv.
        if name == "FastLunarLander-v0":
            _test_env_class(name, LunarLander)
            return
        if name == "CarRacing-v0" and os.getenv("SKIP_RENDER", None):
            return
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            _test_env_class(name, Box2DEnv)

    # FileNotFoundError is tolerated below; presumably it signals a ROM that is not installed.
    @pytest.mark.skipif(SKIP_RETRO_TESTS, reason="Retro not installed")
    @pytest.mark.parametrize("name", RETRO[::10])
    def test_retro_make(self, name):
        from plangym.videogames.retro import RetroEnv

        try:
            _test_env_class(name, RetroEnv)
        except FileNotFoundError:
            pass

    @pytest.mark.skipif(SKIP_RETRO_TESTS, reason="Retro not installed")
    def test_retro_make_with_state(self):
        from plangym.videogames.retro import ActionDiscretizer, RetroEnv

        try:
            _test_env_class(
                "SonicTheHedgehog-Genesis",
                RetroEnv,
                state="GreenHillZone.Act3",
                wrappers=[ActionDiscretizer],
            )
        except FileNotFoundError:
            pass

    @pytest.mark.skipif(SKIP_ATARI_TESTS, reason="Atari not installed")
    def test_custom_atari_make(self):
        """The custom "PlanMontezuma-v0" id resolves to MontezumaEnv."""
        # from plangym.minimal import MinimalPacman, MinimalPong
        from plangym.videogames import MontezumaEnv

        # _test_env_class("MinimalPacman-v0", MinimalPacman)
        # _test_env_class("MinimalPong-v0", MinimalPong)
        _test_env_class("PlanMontezuma-v0", MontezumaEnv)

    # DM_CONTROL entries are unpacked below as (domain_name, task_name) pairs.
    @pytest.mark.skipif(SKIP_DM_CONTROL_TESTS, reason="dm_control not installed")
    @pytest.mark.parametrize("name", DM_CONTROL)
    def test_dmcontrol_make(self, name):
        from plangym.control.dm_control import DMControlEnv

        domain_name, task_name = name
        if task_name is not None:
            _test_env_class(domain_name, DMControlEnv, task_name=task_name)
        else:
            _test_env_class(domain_name, DMControlEnv)

    # Same as above, but the env is selected via ``domain_name`` with ``name=None``.
    @pytest.mark.skipif(SKIP_DM_CONTROL_TESTS, reason="dm_control not installed")
    @pytest.mark.parametrize("name", DM_CONTROL)
    def test_dmcontrol_domain_name_make(self, name):
        from plangym.control.dm_control import DMControlEnv

        domain_name, task_name = name
        if task_name is not None:
            _test_env_class(
                name=None, domain_name=domain_name, cls=DMControlEnv, task_name=task_name
            )
        else:
            _test_env_class(name=None, domain_name=domain_name, cls=DMControlEnv)

    def test_invalid_name(self):
        """Unknown names raise gym errors: a generic Error without a version, UnregisteredEnv with one."""
        with pytest.raises(gym.error.Error):
            make(name="Miaudb")
        with pytest.raises(gym.error.UnregisteredEnv):
            make(name="Miaudb-v0")
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
import gymnasium as gym
from gymnasium.wrappers.atari_preprocessing import AtariPreprocessing
from gymnasium.wrappers.time_limit import TimeLimit
from gymnasium.wrappers.transform_reward import TransformReward
import numpy
from numpy.random import default_rng

from plangym.utils import process_frame, remove_time_limit, process_frame_opencv

# Unseeded generator: the frame contents vary between runs, but the tests only check shapes.
rng = default_rng()


def test_remove_time_limit():
    """remove_time_limit strips a nested TimeLimit and sets max_episode_steps to int(1e100)."""
    env = gym.make("MsPacmanNoFrameskip-v4")
    env = TransformReward(TimeLimit(AtariPreprocessing(env), max_episode_steps=100), lambda x: x)
    rem_env = remove_time_limit(env)
    assert rem_env.spec.max_episode_steps == int(1e100)
    assert not isinstance(rem_env.env, TimeLimit)
    assert "TimeLimit" not in str(rem_env)


def test_process_frame():
    """process_frame resizes to (height, width) and drops the channel axis in mode "L"."""
    example = (rng.random((100, 100, 3)) * 255).astype(numpy.uint8)
    frame = process_frame(example, mode="L")
    assert frame.shape == (100, 100)
    frame = process_frame(example, width=30, height=50)
    assert frame.shape == (50, 30, 3)
    frame = process_frame(example, width=80, height=70, mode="L")
    assert frame.shape == (70, 80)


def test_process_frame_opencv():
    """process_frame_opencv: modes "L"/"GRAY" give 2-D frames, "BGR" and the default keep 3 channels."""
    example = (rng.random((100, 100, 3)) * 255).astype(numpy.uint8)
    frame = process_frame_opencv(example, mode="L")
    assert frame.shape == (100, 100)
    frame = process_frame_opencv(example, width=30, height=50)
    assert frame.shape == (50, 30, 3)
    frame = process_frame_opencv(example, width=30, height=50, mode="BGR")
    assert frame.shape == (50, 30, 3)
    frame = process_frame_opencv(example, width=80, height=70, mode="GRAY")
    assert frame.shape == (70, 80)
--------------------------------------------------------------------------------
/tests/vectorization/test_parallel.py:
--------------------------------------------------------------------------------
import numpy 2 | import pytest 3 | 4 | from plangym.api_tests import batch_size, display, TestPlanEnv, TestPlangymEnv 5 | from plangym.control.classic_control import ClassicControl 6 | from plangym.vectorization.parallel import BatchEnv, ExternalProcess, ParallelEnv 7 | from plangym.videogames.atari import AtariEnv 8 | 9 | 10 | def parallel_cartpole(): 11 | return ParallelEnv(env_class=ClassicControl, name="CartPole-v0", blocking=True, n_workers=2) 12 | 13 | 14 | def parallel_pacman(): 15 | return ParallelEnv(env_class=AtariEnv, name="MsPacman-ram-v0", n_workers=2) 16 | 17 | 18 | environments = [parallel_cartpole, parallel_pacman] 19 | 20 | 21 | @pytest.fixture(params=environments, scope="module") 22 | def env(request) -> ClassicControl: 23 | return request.param() 24 | 25 | 26 | class TestBatchEnv: 27 | def test_len(self, env): 28 | assert len(env._batch_env) == 2 29 | 30 | def test_getattr(self, env): 31 | assert isinstance(env._batch_env, BatchEnv) 32 | assert env._batch_env.observation_space is not None 33 | 34 | def test_getitem(self, env): 35 | assert isinstance(env._batch_env[0], ExternalProcess) 36 | 37 | def test_reset(self, env): 38 | obs, _ = env._batch_env.reset(return_states=False) 39 | assert isinstance(obs, numpy.ndarray) 40 | indices = numpy.arange(len(env._batch_env._envs)) 41 | state, obs, _ = env._batch_env.reset(return_states=True, indices=indices) 42 | if env.STATE_IS_ARRAY: 43 | assert isinstance(state, numpy.ndarray) 44 | 45 | 46 | class TestExternalProcess: 47 | def test_reset(self, env): 48 | ep = env._batch_env[0] 49 | obs, *_ = ep.reset(return_states=False, blocking=True) 50 | assert isinstance(obs, numpy.ndarray) 51 | state, obs, _ = ep.reset(return_states=True, blocking=True) 52 | if env.STATE_IS_ARRAY: 53 | assert isinstance(state, numpy.ndarray) 54 | 55 | obs, *_ = ep.reset(return_states=False, blocking=False)() 56 | assert isinstance(obs, numpy.ndarray) 57 | state, obs, _ = ep.reset(return_states=True, blocking=False)() 58 | if 
env.STATE_IS_ARRAY: 59 | assert isinstance(state, numpy.ndarray) 60 | 61 | def test_step(self, env): 62 | ep = env._batch_env[0] 63 | state, *_ = ep.reset(return_states=True, blocking=True) 64 | ep.set_state(state, blocking=False)() 65 | action = env.sample_action() 66 | data = ep.step(action, dt=2, blocking=True) 67 | assert isinstance(data, tuple) 68 | state, *data = ep.step(action, state, blocking=True) 69 | assert len(data) > 0 70 | if env.STATE_IS_ARRAY: 71 | assert isinstance(state, numpy.ndarray) 72 | 73 | state, *_ = ep.reset(return_states=True, blocking=False)() 74 | action = env.sample_action() 75 | data = ep.step(action, dt=2, blocking=False)() 76 | assert isinstance(data, tuple) 77 | state, *data = ep.step(action, state, blocking=False)() 78 | assert len(data) > 0 79 | 80 | def test_attributes(self, env): 81 | ep = env._batch_env[0] 82 | ep.observation_space 83 | ep.action_space.sample() 84 | ep.unwrapped 85 | -------------------------------------------------------------------------------- /tests/vectorization/test_ray.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | 4 | import numpy 5 | import pytest 6 | import ray 7 | 8 | from plangym.control.classic_control import ClassicControl 9 | from plangym.vectorization.ray import RayEnv, RemoteEnv 10 | from plangym.videogames.atari import AtariEnv 11 | 12 | 13 | pytest.importorskip("ray") 14 | if os.getenv("DISABLE_RAY") and str(os.getenv("DISABLE_RAY", "False")).lower() != "false": 15 | pytest.skip("Ray not installed or disabled", allow_module_level=True) 16 | from src.plangym.api_tests import batch_size, display, TestPlanEnv, TestPlangymEnv 17 | 18 | 19 | def ray_cartpole(): 20 | return RayEnv(env_class=ClassicControl, name="CartPole-v0", n_workers=2) 21 | 22 | 23 | def ray_retro(): 24 | from plangym.videogames.retro import RetroEnv 25 | 26 | return RayEnv(env_class=RetroEnv, name="Airstriker-Genesis", n_workers=2) 27 | 28 | 29 | def 
ray_dm_control(): 30 | from plangym.control.dm_control import DMControlEnv 31 | 32 | return RayEnv(env_class=DMControlEnv, name="walker-walk", n_workers=2) 33 | 34 | 35 | environments = [(ray_cartpole, True), (ray_dm_control, True), (ray_retro, False)] 36 | 37 | 38 | @pytest.fixture(params=environments, scope="module") 39 | def env(request) -> AtariEnv: 40 | env_call, local = request.param 41 | with warnings.catch_warnings(): 42 | warnings.simplefilter("ignore") 43 | ray.init(ignore_reinit_error=True, local_mode=local) 44 | yield env_call() 45 | ray.shutdown() 46 | 47 | 48 | def test_remote_actor(): 49 | with warnings.catch_warnings(): 50 | warnings.simplefilter("ignore") 51 | ray.init(ignore_reinit_error=True, local_mode=True) 52 | 53 | def create_cartpole(): 54 | return ClassicControl(name="CartPole-v0") 55 | 56 | env = RemoteEnv.remote(create_cartpole) 57 | ray.get(env.setup.remote()) 58 | ray.get(env.reset.remote()) 59 | ray.get(env.step.remote(0)) 60 | state = ray.get(env.get_state.remote()) 61 | assert isinstance(state, numpy.ndarray) 62 | -------------------------------------------------------------------------------- /tests/videogames/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FragileTech/plangym/417325eac9be9668f4065351ba22f3f6461f5a4f/tests/videogames/__init__.py -------------------------------------------------------------------------------- /tests/videogames/test_atari.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gymnasium.wrappers import TimeLimit 3 | import numpy 4 | import pytest 5 | 6 | from plangym.environment_names import ATARI 7 | from plangym.videogames.atari import ale_to_ram, AtariEnv 8 | from tests import SKIP_ATARI_TESTS 9 | 10 | 11 | if SKIP_ATARI_TESTS: 12 | pytest.skip("Atari not installed, skipping", allow_module_level=True) 13 | from plangym.api_tests import ( 14 | batch_size, 15 | 
display, 16 | generate_test_cases, 17 | TestPlanEnv, 18 | TestPlangymEnv, 19 | ) 20 | 21 | 22 | def qbert_ram(): 23 | return AtariEnv(name="Qbert-ram-v4", clone_seeds=False, autoreset=False) 24 | 25 | 26 | @pytest.fixture( 27 | params=generate_test_cases(ATARI, AtariEnv, custom_tests=[qbert_ram]), 28 | scope="module", 29 | ) 30 | def env(request) -> AtariEnv: 31 | env = request.param() 32 | yield env 33 | env.close() 34 | 35 | 36 | class TestAtariEnv: 37 | def test_ale_to_ram(self, env): 38 | _ = env.reset() 39 | ram = ale_to_ram(env.ale) 40 | env_ram = env.get_ram() 41 | assert isinstance(ram, numpy.ndarray) 42 | assert ram.shape == env_ram.shape 43 | assert (ram == env_ram).all() 44 | 45 | def test_get_image(self): 46 | env = qbert_ram() 47 | obs = env.get_image() 48 | assert isinstance(obs, numpy.ndarray) 49 | 50 | def test_n_actions(self, env): 51 | n_actions = env.n_actions 52 | assert isinstance(n_actions, int | np.int64) 53 | -------------------------------------------------------------------------------- /tests/videogames/test_montezuma.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import pytest 3 | 4 | from plangym.vectorization.parallel import ParallelEnv 5 | from plangym.videogames.montezuma import CustomMontezuma, MontezumaEnv, MontezumaPosLevel 6 | from tests import SKIP_ATARI_TESTS 7 | 8 | 9 | if SKIP_ATARI_TESTS: 10 | pytest.skip("Atari not installed, skipping", allow_module_level=True) 11 | from plangym import api_tests 12 | from plangym.api_tests import batch_size, display, TestPlangymEnv 13 | 14 | 15 | def montezuma(): 16 | return MontezumaEnv(obs_type="coords", autoreset=True, score_objects=True) 17 | 18 | 19 | def montezuma_unproc(): 20 | return MontezumaEnv(obs_type="rgb", autoreset=True, only_keys=True) 21 | 22 | 23 | def parallel_montezuma(): 24 | return ParallelEnv( 25 | env_class=MontezumaEnv, 26 | frameskip=5, 27 | name="", 28 | score_objects=True, 29 | objects_from_pixels=True, 
30 | ) 31 | 32 | 33 | def montezuma_coords(): 34 | return MontezumaEnv(autoreset=True, obs_type="coords") 35 | 36 | 37 | environments = [montezuma, montezuma_unproc, parallel_montezuma, montezuma_coords] 38 | 39 | 40 | @pytest.fixture(params=environments, scope="module") 41 | def env(request) -> MontezumaEnv: 42 | env = request.param() 43 | yield env 44 | env.close() 45 | 46 | 47 | @pytest.fixture(scope="module") 48 | def pos_level(): 49 | return MontezumaPosLevel(1, 100, 2, 30, 16) 50 | 51 | 52 | class TestMontezumaPosLevel: 53 | def test_hash(self, pos_level): 54 | assert isinstance(hash(pos_level), int) 55 | 56 | def test_compate(self, pos_level): 57 | assert pos_level == MontezumaPosLevel(*pos_level.tuple) 58 | assert not pos_level == 6 59 | 60 | def test_get_state(self, pos_level): 61 | assert pos_level.__getstate__() == pos_level.tuple 62 | 63 | def test_set_state(self, pos_level): 64 | level, score, room, x, y = (10, 9, 8, 7, 6) 65 | pos_level.__setstate__((level, score, room, x, y)) 66 | assert pos_level.tuple == (10, 9, 8, 7, 6) 67 | 68 | def test_repr(self, pos_level): 69 | assert isinstance(repr(pos_level), str) 70 | 71 | 72 | class TestCustomMontezuma: 73 | def test_pos_from_unproc_state(self): 74 | env = CustomMontezuma(obs_type="rgb") 75 | obs = env.reset() 76 | for i in range(20): 77 | obs, *_ = env.step(0) 78 | facepix = env.get_face_pixels(obs) 79 | pos = env.pos_from_obs(face_pixels=facepix, obs=obs) 80 | assert isinstance(pos, MontezumaPosLevel) 81 | 82 | def test_get_objects_from_pixel(self): 83 | env = CustomMontezuma(obs_type="rgb") 84 | obs = env.reset() 85 | for i in range(20): 86 | obs, *_ = env.step(0) 87 | ob = env.get_objects_from_pixels(room=0, obs=obs, old_objects=[]) 88 | assert isinstance(ob, int) 89 | 90 | env = CustomMontezuma(obs_type="rgb", objects_remember_rooms=True) 91 | obs = env.reset() 92 | for i in range(20): 93 | obs, *_ = env.step(0) 94 | tup = env.get_objects_from_pixels(room=0, obs=obs, old_objects=[]) 95 | assert 
isinstance(tup, tuple) 96 | 97 | def test_get_room_xy(self): 98 | # Test cases for known rooms 99 | assert CustomMontezuma.get_room_xy(0) == (3, 0) 100 | assert CustomMontezuma.get_room_xy(23) == (8, 3) 101 | assert CustomMontezuma.get_room_xy(10) == (3, 2) 102 | 103 | # Test case for a room not in the pyramid 104 | assert CustomMontezuma.get_room_xy(24) is None 105 | assert CustomMontezuma.get_room_xy(-2) is None 106 | 107 | 108 | class TestMontezuma(api_tests.TestPlanEnv): 109 | @pytest.mark.parametrize("state", [None, True]) 110 | @pytest.mark.parametrize("return_state", [None, True, False]) 111 | def test_step(self, env, state, return_state, dt=1): 112 | _state, *_ = env.reset(return_state=True) 113 | if state is not None: 114 | state = _state 115 | action = env.sample_action() 116 | data = env.step(action, dt=dt, state=state, return_state=return_state) 117 | *new_state, obs, reward, terminal, _truncated, info = data 118 | assert isinstance(data, tuple) 119 | # Test return state works correctly 120 | should_return_state = state is not None if return_state is None else return_state 121 | if should_return_state: 122 | assert len(new_state) == 1 123 | new_state = new_state[0] 124 | state_is_array = isinstance(new_state, numpy.ndarray) 125 | assert state_is_array if env.STATE_IS_ARRAY else not state_is_array 126 | if state_is_array: 127 | assert _state.shape == new_state.shape 128 | if not env.SINGLETON and env.STATE_IS_ARRAY: 129 | curr_state = env.get_state() 130 | curr_state, new_state = curr_state[1:], new_state[1:] 131 | assert new_state.shape == curr_state.shape 132 | assert (new_state == curr_state).all(), ( 133 | f"original: {new_state[new_state != curr_state]} " 134 | f"env: {curr_state[new_state != curr_state]}" 135 | ) 136 | else: 137 | assert len(new_state) == 0 138 | api_tests.step_tuple_test(env, obs, reward, terminal, info, dt=dt) 139 | -------------------------------------------------------------------------------- /tests/videogames/test_nes.py: 
-------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from plangym.api_tests import ( 4 | batch_size, 5 | display, 6 | generate_test_cases, 7 | TestPlanEnv, 8 | TestPlangymEnv, 9 | TestVideogameEnv, 10 | ) 11 | from plangym.videogames.nes import MarioEnv 12 | 13 | 14 | env_names = ["SuperMarioBros-v0", "SuperMarioBros-v1", "SuperMarioBros2-v0"] 15 | 16 | 17 | @pytest.fixture( 18 | params=generate_test_cases(env_names, MarioEnv, n_workers_values=None), scope="module" 19 | ) 20 | def env(request): 21 | return request.param() 22 | 23 | 24 | class TestMarioEnv: 25 | def test_get_keys_to_action(self, env): 26 | vals = env.gym_env.get_keys_to_action() 27 | assert isinstance(vals, dict) 28 | 29 | def test_get_action_meanings(self, env): 30 | vals = env.gym_env.get_action_meanings() 31 | assert isinstance(vals, list) 32 | 33 | def test_buttons(self, env): 34 | buttons = env.gym_env.buttons() 35 | assert isinstance(buttons, list), buttons 36 | assert all(isinstance(b, str) for b in buttons) 37 | -------------------------------------------------------------------------------- /tests/videogames/test_retro.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import gymnasium as gym 4 | import pytest 5 | 6 | from plangym.vectorization.parallel import ParallelEnv 7 | from plangym.videogames.retro import ActionDiscretizer, RetroEnv 8 | 9 | 10 | pytest.importorskip("retro") 11 | 12 | from plangym import api_tests 13 | from plangym.api_tests import batch_size, display, TestPlanEnv 14 | 15 | 16 | def retro_airstrike(): 17 | res_obs = gym.wrappers.resize_observation.ResizeObservation 18 | return RetroEnv(name="Airstriker-Genesis", wrappers=[(res_obs, {"shape": (90, 90)})]) 19 | 20 | 21 | def retro_sonic(): 22 | return RetroEnv( 23 | name="SonicTheHedgehog-Genesis", 24 | state="GreenHillZone.Act3", 25 | wrappers=[ActionDiscretizer], 26 | obs_type="grayscale", 27 | 
) 28 | 29 | 30 | def parallel_retro(): 31 | return ParallelEnv( 32 | name="Airstriker-Genesis", 33 | env_class=RetroEnv, 34 | n_workers=2, 35 | obs_type="ram", 36 | wrappers=[ActionDiscretizer], 37 | ) 38 | 39 | 40 | environments = [retro_airstrike, retro_sonic, parallel_retro] 41 | 42 | 43 | @pytest.fixture(params=environments, scope="class") 44 | def env(request) -> RetroEnv | ParallelEnv: 45 | env_ = request.param() 46 | if env_.delay_setup and env_.gym_env is None: 47 | env_.setup() 48 | yield env_ 49 | env_.close() 50 | 51 | 52 | class TestRetro: 53 | def test_init_env(self): 54 | env = retro_airstrike() 55 | env.reset() 56 | env.setup() 57 | 58 | def test_getattribute(self): 59 | env = retro_airstrike() 60 | env.em.get_state() 61 | 62 | def test_clone(self): 63 | env = RetroEnv(name="Airstriker-Genesis", obs_type="ram", delay_setup=True) 64 | new_env = env.clone() 65 | del env 66 | new_env.reset() 67 | 68 | 69 | class TestPlangymRetro(api_tests.TestPlangymEnv): 70 | pass 71 | -------------------------------------------------------------------------------- /uncompressed_ROMs.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FragileTech/plangym/417325eac9be9668f4065351ba22f3f6461f5a4f/uncompressed_ROMs.zip --------------------------------------------------------------------------------