├── .circleci └── config.yml ├── .codespell.skip ├── .dockerignore ├── .github └── workflows │ └── publish-to-pypi.yml ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── ci ├── Xdummy-entrypoint.py ├── build_venv.sh ├── code_checks.sh └── xorg.conf ├── codecov.yml ├── docs ├── Makefile ├── _static │ └── img │ │ └── logo.svg ├── common │ ├── base_envs.rst │ ├── testing.rst │ └── util.rst ├── conf.py ├── environments │ ├── diagnostic.rst │ └── renovated.rst ├── guide │ └── install.rst ├── index.rst └── make.bat ├── mypy.ini ├── pyproject.toml ├── readthedocs.yml ├── setup.cfg ├── setup.py ├── src └── seals │ ├── __init__.py │ ├── atari.py │ ├── base_envs.py │ ├── classic_control.py │ ├── diagnostics │ ├── __init__.py │ ├── branching.py │ ├── cliff_world.py │ ├── early_term.py │ ├── init_shift.py │ ├── largest_sum.py │ ├── noisy_obs.py │ ├── parabola.py │ ├── proc_goal.py │ ├── random_trans.py │ ├── risky_path.py │ └── sort.py │ ├── mujoco.py │ ├── py.typed │ ├── testing │ ├── __init__.py │ └── envs.py │ └── util.py └── tests ├── conftest.py ├── test_base_env.py ├── test_diagnostics.py ├── test_envs.py ├── test_mujoco_rl.py ├── test_util.py └── test_wrappers.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 # Version of CircleCI config format 2 | 3 | # "Orbs" are reusable packages of CircleCI config. 4 | # They can simplify common tasks, such as interacting with external services. 5 | # This section lists orbs this config uses. 6 | orbs: 7 | codecov: codecov/codecov@1.1.0 # support for uploading code coverage to codecov 8 | 9 | defaults: &defaults 10 | docker: 11 | - image: humancompatibleai/seals:base-alpha 12 | auth: 13 | username: $DOCKERHUB_USERNAME 14 | password: $DOCKERHUB_PASSWORD 15 | working_directory: /seals 16 | 17 | executors: 18 | unit-test: 19 | <<: *defaults 20 | resource_class: large 21 | environment: 22 | # Don't use auto-detect since it sees all CPUs available, but container is throttled. 23 | NUM_CPUS: 4 24 | lintandtype: 25 | <<: *defaults 26 | resource_class: medium 27 | environment: 28 | # If you change these, also change ci/code_checks.sh 29 | LINT_FILES: src/ tests/ docs/conf.py setup.py # files we lint 30 | # Files we statically type check. Source files like src/ should almost always be present. 31 | # In this repo we also typecheck tests/ -- but sometimes you may want to exclude these 32 | # if they do strange things with types (e.g. mocking). 33 | TYPECHECK_FILES: src/ tests/ setup.py 34 | # Don't use auto-detect since it sees all CPUs available, but container is throttled. 35 | NUM_CPUS: 2 36 | 37 | 38 | commands: 39 | # Define common function to install dependencies and seals, used in the jobs defined in the next section 40 | dependencies: 41 | description: "Check out and update Python dependencies." 42 | steps: 43 | - checkout # Check out the code from Git 44 | 45 | # Download and cache dependencies 46 | # Note the Docker image must still be manually updated if any binary (non-Python) dependencies change. 47 | 48 | # Restore cache if it exists. setup.py defines all the requirements, so we checksum that. 49 | # If you want to force an update despite setup.py not changing, you can bump the version 50 | # number `vn-dependencies`. This can be useful if newer versions of a package have been 51 | # released that you want to upgrade to, without mandating the newer version in setup.py. 
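# For example, renaming the key below from `v2-dependencies-...` to `v3-dependencies-...`
# (in both `restore_cache` and `save_cache`) would invalidate the existing cache and force a rebuild.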
52 | - restore_cache:
53 | keys:
54 | - v2-dependencies-{{ checksum "setup.py" }}-{{ checksum "ci/build_venv.sh" }}
55 | 
56 | # Create virtual environment and install dependencies using `ci/build_venv.sh`.
57 | # The command below is a no-op if `/venv` already exists (e.g. restored from the cache above);
58 | # otherwise it builds the virtual environment from scratch.
59 | - run:
60 | name: install dependencies
61 | command: "[[ -d /venv ]] || /seals/ci/build_venv.sh /venv"
62 | 
63 | # Save the cache of dependencies.
64 | - save_cache:
65 | paths:
66 | - /venv
67 | key: v2-dependencies-{{ checksum "setup.py" }}-{{ checksum "ci/build_venv.sh" }}
68 | 
69 | # Install seals.
70 | # Note we install the source distribution, not in developer mode (`pip install -e`).
71 | # This ensures we're testing the package our users would experience, and in particular
72 | # will catch e.g. modules or data files missing from `setup.py`.
73 | - run:
74 | name: install seals
75 | # Build a wheel then install to avoid copying whole directory (pip issue #2195)
76 | command: |
77 | python setup.py sdist bdist_wheel
78 | pip install --upgrade --force-reinstall --no-deps dist/seals-*.whl
79 | 
80 | # The `jobs` section defines jobs that can be executed on CircleCI as part of workflows.
81 | jobs:
82 | # `lintandtype` installs dependencies + `seals`, lints the code, builds the docs, and runs type checks.
83 | lintandtype:
84 | executor: lintandtype
85 | 
86 | steps:
87 | - dependencies
88 | - run:
89 | name: flake8
90 | command: flake8 ${LINT_FILES}
91 | 
92 | - run:
93 | name: black
94 | command: black --check ${LINT_FILES}
95 | 
96 | - run:
97 | name: codespell
98 | command: codespell -I .codespell.skip --skip='*.pyc' ${LINT_FILES}
99 | 
100 | - run:
101 | name: sphinx
102 | command: pushd docs/ && make clean && make html && popd
103 | 
104 | - run:
105 | name: pytype
106 | command: pytype ${TYPECHECK_FILES}
107 | 
108 | - run:
109 | name: mypy
110 | command: mypy ${TYPECHECK_FILES}
111 | 
112 | # `unit-test` runs the unit tests in `tests/`.
113 | unit-test:
114 | executor: unit-test
115 | steps:
116 | - dependencies
117 | 
118 | # Running out of memory is a common cause of spurious test failures.
119 | # In particular, the CI machines have less memory than most workstations.
120 | # So tests can pass locally but fail on CI. Record memory and other resource
121 | # usage over time to aid with diagnosing these failures.
122 | - run:
123 | name: Memory Monitor
124 | # | is needed for multi-line values in YAML
125 | command: |
126 | mkdir /tmp/resource-usage
127 | export FILE=/tmp/resource-usage/memory.txt
128 | while true; do
129 | ps -u root eo pid,%cpu,%mem,args,uname --sort=-%mem >> $FILE
130 | echo "----------" >> $FILE
131 | sleep 1
132 | done
133 | background: true
134 | 
135 | # Run the unit tests themselves
136 | - run:
137 | name: run tests
138 | command: |
139 | # Xdummy-entrypoint.py: starts an X server and sets DISPLAY, then runs wrapped command.
140 | # pytest arguments:
141 | # --cov specifies which directories to report code coverage for
142 | # Since we test the installed `seals`, our source files live in `venv`, not in `src/seals`.
143 | # --junitxml records test results in JUnit format. We upload this file using `store_test_results`
144 | # later, and CircleCI then parses this to pretty-print results.
145 | # --shard-id and --num-shards can be used with `pytest-shard` to split tests across parallel executors.
146 | # -n uses `pytest-xdist` to parallelize tests within a single instance.
147 | Xdummy-entrypoint.py pytest --cov=/venv/lib/python3.8/site-packages/seals --cov=tests \ 148 | --junitxml=/tmp/test-reports/junit.xml \ 149 | -n ${NUM_CPUS} -vv tests/ 150 | # Following two lines rewrite paths from venv/ to src/, based on `coverage:paths` in `setup.cfg` 151 | # This is needed to avoid confusing Codecov 152 | mv .coverage .coverage.bench 153 | coverage combine 154 | - codecov/upload 155 | 156 | # Upload the test results and resource usage to CircleCI 157 | - store_artifacts: 158 | path: /tmp/test-reports 159 | destination: test-reports 160 | - store_artifacts: 161 | path: /tmp/resource-usage 162 | destination: resource-usage 163 | # store_test_results uploads the files and tells CircleCI that it should parse them as test results 164 | - store_test_results: 165 | path: /tmp/test-reports 166 | 167 | # Workflows specify what jobs to actually run on CircleCI. If we didn't specify this, 168 | # nothing would run! Here we have just a single workflow, `test`, containing both the 169 | # jobs defined above. By default, the jobs all run in parallel. We can make them run 170 | # sequentially, or have more complex dependency structures, using the `require` command; 171 | # see https://circleci.com/docs/2.0/workflows/ 172 | # 173 | # We attach two contexts to both jobs, which define a set of environment variable: 174 | # - `MuJoCo` which contains the URL for our MuJoCo license key. 175 | # - `docker-hub-creds` which contain the credentials for our Dockerhub machine user. 176 | # It's important these are kept confidential -- so don't echo the environment variables 177 | # anywhere in the config! 178 | workflows: 179 | version: 2 180 | test: 181 | jobs: 182 | - lintandtype: 183 | context: 184 | - docker-hub-creds 185 | - unit-test: 186 | context: 187 | - docker-hub-creds 188 | -------------------------------------------------------------------------------- /.codespell.skip: -------------------------------------------------------------------------------- 1 | ith 2 | reacher 3 | iff 4 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | .gitignore -------------------------------------------------------------------------------- /.github/workflows/publish-to-pypi.yml: -------------------------------------------------------------------------------- 1 | # Adapted from https://packaging.python.org/en/latest/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/ 2 | 3 | name: Publish seals distributions 📦 to PyPI and TestPyPI 4 | 5 | on: push 6 | 7 | jobs: 8 | build-n-publish: 9 | name: Build and publish seals distributions 📦 to PyPI and TestPyPI 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/checkout@v3 14 | with: 15 | # Fetch tags needed by setuptools_scm to infer version number 16 | # See https://github.com/pypa/setuptools_scm/issues/414 17 | fetch-depth: 0 18 | - name: Set up Python 3.10 19 | uses: actions/setup-python@v3 20 | with: 21 | python-version: "3.10" 22 | 23 | - name: Install pypa/build 24 | run: >- 25 | python -m 26 | pip install 27 | build 28 | --user 29 | - name: Build a binary wheel and a source tarball 30 | run: >- 31 | python -m 32 | build 33 | --sdist 34 | --wheel 35 | --outdir dist/ 36 | . 37 | 38 | # Publish new distribution to Test PyPi on every push. 39 | # This ensures the workflow stays healthy, and will also serve 40 | # as a source of alpha builds. 
41 | - name: Publish distribution 📦 to Test PyPI 42 | uses: pypa/gh-action-pypi-publish@release/v1 43 | with: 44 | password: ${{ secrets.TEST_PYPI_API_TOKEN }} 45 | repository_url: https://test.pypi.org/legacy/ 46 | 47 | # Publish new distribution to production PyPi on releases. 48 | - name: Publish distribution 📦 to PyPI 49 | if: startsWith(github.ref, 'refs/tags/v') 50 | uses: pypa/gh-action-pypi-publish@release/v1 51 | with: 52 | password: ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # PyCharm project settings 121 | .idea/ 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype type checker 135 | .pytype/ 136 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # base stage contains just binary dependencies. 2 | # This is used in the CI build. 
3 | FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu20.04 AS base 4 | ARG DEBIAN_FRONTEND=noninteractive 5 | 6 | RUN apt-get update -q \ 7 | && apt-get install -y --no-install-recommends \ 8 | build-essential \ 9 | curl \ 10 | ffmpeg \ 11 | git \ 12 | ssh \ 13 | libgl1-mesa-dev \ 14 | libgl1-mesa-glx \ 15 | libglew-dev \ 16 | libosmesa6-dev \ 17 | net-tools \ 18 | parallel \ 19 | patchelf \ 20 | python3.8 \ 21 | python3.8-dev \ 22 | python3-pip \ 23 | rsync \ 24 | software-properties-common \ 25 | unzip \ 26 | vim \ 27 | virtualenv \ 28 | xpra \ 29 | xserver-xorg-dev \ 30 | && apt-get clean \ 31 | && rm -rf /var/lib/apt/lists/* 32 | 33 | ENV LANG C.UTF-8 34 | 35 | RUN mkdir -p /root/.mujoco \ 36 | && curl -o mjpro150.zip https://www.roboti.us/download/mjpro150_linux.zip \ 37 | && unzip mjpro150.zip -d /root/.mujoco \ 38 | && rm mjpro150.zip \ 39 | && curl -o /root/.mujoco/mjkey.txt https://www.roboti.us/file/mjkey.txt 40 | 41 | # Set the PATH to the venv before we create the venv, so it's visible in base. 42 | # This is since we may create the venv outside of Docker, e.g. in CI 43 | # or by binding it in for local development. 44 | ENV PATH="/venv/bin:$PATH" 45 | ENV LD_LIBRARY_PATH /root/.mujoco/mjpro150/bin:${LD_LIBRARY_PATH} 46 | 47 | # Run Xdummy mock X server by default so that rendering will work. 48 | COPY ci/xorg.conf /etc/dummy_xorg.conf 49 | COPY ci/Xdummy-entrypoint.py /usr/bin/Xdummy-entrypoint.py 50 | ENTRYPOINT ["/usr/bin/Xdummy-entrypoint.py"] 51 | 52 | # python-req stage contains Python venv, but not code. 53 | # It is useful for development purposes: you can mount 54 | # code from outside the Docker container. 55 | FROM base as python-req 56 | 57 | WORKDIR /seals 58 | # Copy only necessary dependencies to build virtual environment. 59 | # This minimizes how often this layer needs to be rebuilt. 60 | COPY ./setup.py ./setup.py 61 | COPY ./README.md ./README.md 62 | COPY ./src/seals/version.py ./src/seals/version.py 63 | COPY ./ci/build_venv.sh ./ci/build_venv.sh 64 | RUN /seals/ci/build_venv.sh /venv \ 65 | && rm -rf $HOME/.cache/pip 66 | 67 | # full stage contains everything. 68 | # Can be used for deployment and local testing. 69 | FROM python-req as full 70 | 71 | # Delay copying (and installing) the code until the very end 72 | COPY . /seals 73 | # Build a wheel then install to avoid copying whole directory (pip issue #2195) 74 | RUN python3 setup.py sdist bdist_wheel 75 | RUN pip install --upgrade dist/seals-*.whl 76 | 77 | # Default entrypoints 78 | CMD ["pytest", "-n", "auto", "-vv", "tests/"] 79 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Center for Human-Compatible AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![CircleCI](https://circleci.com/gh/HumanCompatibleAI/seals.svg?style=svg)](https://circleci.com/gh/HumanCompatibleAI/seals) 2 | [![Documentation Status](https://readthedocs.org/projects/seals/badge/?version=latest)](https://seals.readthedocs.io/en/latest/?badge=latest) 3 | [![codecov](https://codecov.io/gh/HumanCompatibleAI/seals/branch/master/graph/badge.svg)](https://codecov.io/gh/HumanCompatibleAI/seals) 4 | [![PyPI version](https://badge.fury.io/py/seals.svg)](https://badge.fury.io/py/seals) 5 | 6 |

7 | 
8 | **Status**: early beta.
9 | 
10 | *seals*, the Suite of Environments for Algorithms that Learn Specifications, is a toolkit for
11 | evaluating specification learning algorithms, such as reward or imitation learning. The
12 | environments are compatible with [Gym](https://github.com/openai/gym), but are designed
13 | to test algorithms that learn from user data, without requiring a procedurally specified
14 | reward function.
15 | 
16 | There are two types of environments in *seals*:
17 | 
18 | - **Diagnostic Tasks** which test individual facets of algorithm performance in isolation.
19 | - **Renovated Environments**, adaptations of widely-used benchmarks such as MuJoCo continuous
20 | control tasks and Atari games to be suitable for specification learning benchmarks. In particular,
21 | we remove any side-channel sources of reward information from MuJoCo tasks, and give Atari games constant-length episodes (although most Atari environments have observations that include the score).
22 | 
23 | *seals* is under active development and we intend to add more categories of tasks soon.
24 | 
25 | You may also be interested in our sister project [imitation](https://github.com/humancompatibleai/imitation/),
26 | providing implementations of a variety of imitation and reward learning algorithms.
27 | 
28 | Check out our [documentation](https://seals.readthedocs.io/en/latest/) for more information about *seals*.
29 | 
30 | # Quickstart
31 | 
32 | To install the latest release from PyPI, run:
33 | 
34 | ```bash
35 | pip install seals
36 | ```
37 | 
38 | All *seals* environments are available in the Gym registry. Simply import it and then use the environments as you
39 | would with your usual RL or specification learning algorithm:
40 | 
41 | ```python
42 | import gymnasium as gym
43 | import seals
44 | 
45 | env = gym.make('seals/CartPole-v0')
46 | ```
47 | 
48 | We make releases periodically, but if you wish to use the latest version of the code, you can
49 | install directly from Git master:
50 | 
51 | ```bash
52 | pip install git+https://github.com/HumanCompatibleAI/seals.git
53 | ```
54 | 
55 | # Contributing
56 | 
57 | For development, clone the source code and create a virtual environment for this project:
58 | 
59 | ```bash
60 | git clone git@github.com:HumanCompatibleAI/seals.git
61 | cd seals
62 | ./ci/build_venv.sh
63 | pip install -e .[dev] # install extra tools useful for development
64 | ```
65 | 
66 | ## Code style
67 | 
68 | We follow a PEP8 code style with line length 88, and typically follow the [Google Code Style Guide](http://google.github.io/styleguide/pyguide.html),
69 | but defer to PEP8 where they conflict. We use the `black` autoformatter to avoid arguing over formatting.
70 | Docstrings follow the Google docstring convention defined [here](http://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings),
71 | with an extensive example in the [Sphinx docs](https://www.sphinx-doc.org/en/master/usage/extensions/example_google.html).
72 | 
73 | All PRs must pass linting via the `ci/code_checks.sh` script. It is convenient to install this as a commit hook:
74 | 
75 | ```bash
76 | ln -s ../../ci/code_checks.sh .git/hooks/pre-commit
77 | ```
78 | 
79 | ## Tests
80 | 
81 | We use [pytest](https://docs.pytest.org/en/latest/) for unit tests
82 | and [codecov](http://codecov.io/) for code coverage.
83 | We also use [pytype](https://github.com/google/pytype) and [mypy](http://mypy-lang.org/)
84 | for type checking.
85 | 
86 | ## Workflow
87 | 
88 | Trivial changes (e.g. 
typo fixes) may be made directly by maintainers. Any non-trivial changes 89 | must be proposed in a PR and approved by at least one maintainer. PRs must pass the continuous 90 | integration tests (CircleCI linting, type checking, unit tests and CodeCov) to be merged. 91 | 92 | It is often helpful to open an issue before proposing a PR, to allow for discussion of the design 93 | before coding commences. 94 | 95 | # Citing seals 96 | 97 | To cite this project in publications: 98 | 99 | ```bibtex 100 | @misc{seals, 101 | author = {Adam Gleave and Pedro Freire and Steven Wang and Sam Toyer}, 102 | title = {{seals}: Suite of Environments for Algorithms that Learn Specifications}, 103 | year = {2020}, 104 | publisher = {GitHub}, 105 | journal = {GitHub repository}, 106 | howpublished = {\url{https://github.com/HumanCompatibleAI/seals}}, 107 | } 108 | ``` 109 | -------------------------------------------------------------------------------- /ci/Xdummy-entrypoint.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | # Adapted from https://github.com/openai/mujoco-py/blob/master/vendor/Xdummy-entrypoint 4 | # Copyright OpenAI; MIT License 5 | 6 | import argparse 7 | import os 8 | import sys 9 | import subprocess 10 | 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser() 13 | args, extra_args = parser.parse_known_args() 14 | 15 | subprocess.Popen( 16 | [ 17 | "nohup", 18 | "Xorg", 19 | "-noreset", 20 | "+extension", 21 | "GLX", 22 | "+extension", 23 | "RANDR", 24 | "+extension", 25 | "RENDER", 26 | "-logfile", 27 | "/tmp/xdummy.log", 28 | "-config", 29 | "/etc/dummy_xorg.conf", 30 | ":0", 31 | ] 32 | ) 33 | subprocess.Popen( 34 | ["nohup", "Xdummy"], 35 | stdout=open("/dev/null", "w"), 36 | stderr=open("/dev/null", "w"), 37 | ) 38 | os.environ["DISPLAY"] = ":0" 39 | 40 | if not extra_args: 41 | argv = ["/bin/bash"] 42 | else: 43 | argv = extra_args 44 | 45 | # Explicitly flush right before the exec since otherwise things might get 46 | # lost in Python's buffers around stdout/stderr (!). 47 | sys.stdout.flush() 48 | sys.stderr.flush() 49 | 50 | os.execvpe(argv[0], argv, os.environ) 51 | -------------------------------------------------------------------------------- /ci/build_venv.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -e # exit immediately on any error 4 | 5 | venv=$1 6 | if [[ ${venv} == "" ]]; then 7 | venv="venv" 8 | fi 9 | 10 | virtualenv -p python3.8 ${venv} 11 | source ${venv}/bin/activate 12 | pip install --upgrade pip # Ensure we have the newest pip 13 | pip install .[cpu,docs,mujoco,test] 14 | -------------------------------------------------------------------------------- /ci/code_checks.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # If you change these, also change .circleci/config.yml. 
4 | SRC_FILES=(src/ tests/ docs/conf.py setup.py) 5 | 6 | set -x # echo commands 7 | set -e # quit immediately on error 8 | 9 | echo "Source format checking" 10 | flake8 ${SRC_FILES[@]} 11 | black --check ${SRC_FILES[@]} 12 | codespell -I .codespell.skip --skip='*.pyc' ${SRC_FILES[@]} 13 | 14 | if [ -x "`which circleci`" ]; then 15 | circleci config validate 16 | fi 17 | 18 | if [ "$skipexpensive" != "true" ]; then 19 | echo "Building docs (validates docstrings)" 20 | pushd docs/ 21 | make clean 22 | make html 23 | popd 24 | 25 | echo "Type checking" 26 | pytype ${SRC_FILES[@]} 27 | mypy ${SRC_FILES[@]} 28 | fi 29 | -------------------------------------------------------------------------------- /ci/xorg.conf: -------------------------------------------------------------------------------- 1 | # This xorg configuration file is meant to be used by xpra 2 | # to start a dummy X11 server. 3 | # For details, please see: 4 | # https://xpra.org/Xdummy.html 5 | 6 | Section "ServerFlags" 7 | Option "DontVTSwitch" "true" 8 | Option "AllowMouseOpenFail" "true" 9 | Option "PciForceNone" "true" 10 | Option "AutoEnableDevices" "false" 11 | Option "AutoAddDevices" "false" 12 | EndSection 13 | 14 | Section "InputDevice" 15 | Identifier "dummy_mouse" 16 | Option "CorePointer" "true" 17 | Driver "void" 18 | EndSection 19 | 20 | Section "InputDevice" 21 | Identifier "dummy_keyboard" 22 | Option "CoreKeyboard" "true" 23 | Driver "void" 24 | EndSection 25 | 26 | Section "Device" 27 | Identifier "dummy_videocard" 28 | Driver "dummy" 29 | Option "ConstantDPI" "true" 30 | #VideoRam 4096000 31 | #VideoRam 256000 32 | VideoRam 192000 33 | EndSection 34 | 35 | Section "Monitor" 36 | Identifier "dummy_monitor" 37 | HorizSync 5.0 - 1000.0 38 | VertRefresh 5.0 - 200.0 39 | #This can be used to get a specific DPI, but only for the default resolution: 40 | #DisplaySize 508 317 41 | #NOTE: the highest modes will not work without increasing the VideoRam 42 | # for the dummy video card. 
43 | Modeline "32768x32768" 15226.50 32768 35800 39488 46208 32768 32771 32781 32953 44 | Modeline "32768x16384" 7516.25 32768 35544 39192 45616 16384 16387 16397 16478 45 | Modeline "16384x8192" 2101.93 16384 16416 24400 24432 8192 8390 8403 8602 46 | Modeline "8192x4096" 424.46 8192 8224 9832 9864 4096 4195 4202 4301 47 | Modeline "5496x1200" 199.13 5496 5528 6280 6312 1200 1228 1233 1261 48 | Modeline "5280x1080" 169.96 5280 5312 5952 5984 1080 1105 1110 1135 49 | Modeline "5280x1200" 191.40 5280 5312 6032 6064 1200 1228 1233 1261 50 | Modeline "5120x3200" 199.75 5120 5152 5904 5936 3200 3277 3283 3361 51 | Modeline "4800x1200" 64.42 4800 4832 5072 5104 1200 1229 1231 1261 52 | Modeline "3840x2880" 133.43 3840 3872 4376 4408 2880 2950 2955 3025 53 | Modeline "3840x2560" 116.93 3840 3872 4312 4344 2560 2622 2627 2689 54 | Modeline "3840x2048" 91.45 3840 3872 4216 4248 2048 2097 2101 2151 55 | Modeline "3840x1080" 100.38 3840 3848 4216 4592 1080 1081 1084 1093 56 | Modeline "3600x1200" 106.06 3600 3632 3984 4368 1200 1201 1204 1214 57 | Modeline "3288x1080" 39.76 3288 3320 3464 3496 1080 1106 1108 1135 58 | Modeline "2048x2048" 49.47 2048 2080 2264 2296 2048 2097 2101 2151 59 | Modeline "2048x1536" 80.06 2048 2104 2312 2576 1536 1537 1540 1554 60 | Modeline "2560x1600" 47.12 2560 2592 2768 2800 1600 1639 1642 1681 61 | Modeline "2560x1440" 42.12 2560 2592 2752 2784 1440 1475 1478 1513 62 | Modeline "1920x1440" 69.47 1920 1960 2152 2384 1440 1441 1444 1457 63 | Modeline "1920x1200" 26.28 1920 1952 2048 2080 1200 1229 1231 1261 64 | Modeline "1920x1080" 23.53 1920 1952 2040 2072 1080 1106 1108 1135 65 | Modeline "1680x1050" 20.08 1680 1712 1784 1816 1050 1075 1077 1103 66 | Modeline "1600x1200" 22.04 1600 1632 1712 1744 1200 1229 1231 1261 67 | Modeline "1600x900" 33.92 1600 1632 1760 1792 900 921 924 946 68 | Modeline "1440x900" 30.66 1440 1472 1584 1616 900 921 924 946 69 | ModeLine "1366x768" 72.00 1366 1414 1446 1494 768 771 777 803 70 | Modeline "1280x1024" 31.50 1280 1312 1424 1456 1024 1048 1052 1076 71 | Modeline "1280x800" 24.15 1280 1312 1400 1432 800 819 822 841 72 | Modeline "1280x768" 23.11 1280 1312 1392 1424 768 786 789 807 73 | Modeline "1360x768" 24.49 1360 1392 1480 1512 768 786 789 807 74 | Modeline "1024x768" 18.71 1024 1056 1120 1152 768 786 789 807 75 | Modeline "768x1024" 19.50 768 800 872 904 1024 1048 1052 1076 76 | 77 | 78 | #common resolutions for android devices (both orientations): 79 | Modeline "800x1280" 25.89 800 832 928 960 1280 1310 1315 1345 80 | Modeline "1280x800" 24.15 1280 1312 1400 1432 800 819 822 841 81 | Modeline "720x1280" 30.22 720 752 864 896 1280 1309 1315 1345 82 | Modeline "1280x720" 27.41 1280 1312 1416 1448 720 737 740 757 83 | Modeline "768x1024" 24.93 768 800 888 920 1024 1047 1052 1076 84 | Modeline "1024x768" 23.77 1024 1056 1144 1176 768 785 789 807 85 | Modeline "600x1024" 19.90 600 632 704 736 1024 1047 1052 1076 86 | Modeline "1024x600" 18.26 1024 1056 1120 1152 600 614 617 631 87 | Modeline "536x960" 16.74 536 568 624 656 960 982 986 1009 88 | Modeline "960x536" 15.23 960 992 1048 1080 536 548 551 563 89 | Modeline "600x800" 15.17 600 632 688 720 800 818 822 841 90 | Modeline "800x600" 14.50 800 832 880 912 600 614 617 631 91 | Modeline "480x854" 13.34 480 512 560 592 854 873 877 897 92 | Modeline "848x480" 12.09 848 880 920 952 480 491 493 505 93 | Modeline "480x800" 12.43 480 512 552 584 800 818 822 841 94 | Modeline "800x480" 11.46 800 832 872 904 480 491 493 505 95 | #resolutions for android devices (both orientations) 96 | 
#minus the status bar 97 | #38px status bar (and width rounded up) 98 | Modeline "800x1242" 25.03 800 832 920 952 1242 1271 1275 1305 99 | Modeline "1280x762" 22.93 1280 1312 1392 1424 762 780 783 801 100 | Modeline "720x1242" 29.20 720 752 856 888 1242 1271 1276 1305 101 | Modeline "1280x682" 25.85 1280 1312 1408 1440 682 698 701 717 102 | Modeline "768x986" 23.90 768 800 888 920 986 1009 1013 1036 103 | Modeline "1024x730" 22.50 1024 1056 1136 1168 730 747 750 767 104 | Modeline "600x986" 19.07 600 632 704 736 986 1009 1013 1036 105 | Modeline "1024x562" 17.03 1024 1056 1120 1152 562 575 578 591 106 | Modeline "536x922" 16.01 536 568 624 656 922 943 947 969 107 | Modeline "960x498" 14.09 960 992 1040 1072 498 509 511 523 108 | Modeline "600x762" 14.39 600 632 680 712 762 779 783 801 109 | Modeline "800x562" 13.52 800 832 880 912 562 575 578 591 110 | Modeline "480x810" 12.59 480 512 552 584 810 828 832 851 111 | Modeline "848x442" 11.09 848 880 920 952 442 452 454 465 112 | Modeline "480x762" 11.79 480 512 552 584 762 779 783 801 113 | EndSection 114 | 115 | Section "Screen" 116 | Identifier "dummy_screen" 117 | Device "dummy_videocard" 118 | Monitor "dummy_monitor" 119 | DefaultDepth 24 120 | SubSection "Display" 121 | Viewport 0 0 122 | Depth 24 123 | #Modes "32768x32768" "32768x16384" "16384x8192" "8192x4096" "5120x3200" "3840x2880" "3840x2560" "3840x2048" "2048x2048" "2560x1600" "1920x1440" "1920x1200" "1920x1080" "1600x1200" "1680x1050" "1600x900" "1400x1050" "1440x900" "1280x1024" "1366x768" "1280x800" "1024x768" "1024x600" "800x600" "320x200" 124 | Modes "5120x3200" "3840x2880" "3840x2560" "3840x2048" "2048x2048" "2560x1600" "1920x1440" "1920x1200" "1920x1080" "1600x1200" "1680x1050" "1600x900" "1400x1050" "1440x900" "1280x1024" "1366x768" "1280x800" "1024x768" "1024x600" "800x600" "320x200" 125 | #Virtual 32000 32000 126 | #Virtual 16384 8192 127 | # 1024x768 is big enough for testing, but small enough it won't eat up lots of RAM 128 | Virtual 1024 768 129 | #Virtual 5120 3200 130 | EndSubSection 131 | EndSection 132 | 133 | Section "ServerLayout" 134 | Identifier "dummy_layout" 135 | Screen "dummy_screen" 136 | InputDevice "dummy_mouse" 137 | InputDevice "dummy_keyboard" 138 | EndSection 139 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: false 5 | main: 6 | paths: 7 | - "src/seals/" 8 | - "!src/seals/testing/" 9 | tests: 10 | # Should not have dead code in our tests 11 | target: 100% 12 | paths: 13 | - "tests/" 14 | - "src/seals/testing/" 15 | patch: 16 | default: false 17 | main: 18 | paths: 19 | - "src/seals/" 20 | - "!src/seals/testing/" 21 | tests: 22 | target: 100% 23 | paths: 24 | - "tests/" 25 | - "src/seals/testing/" 26 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 
12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/img/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | image/svg+xml 46 | 55 | 57 | 59 | 65 | 71 | 73 | 79 | 85 | 91 | 92 | 93 | 95 | 102 | 106 | 110 | 114 | 115 | 121 | 127 | 133 | 140 | 144 | 148 | 152 | 153 | 159 | 166 | 170 | 174 | 178 | 179 | 185 | 186 | 187 | -------------------------------------------------------------------------------- /docs/common/base_envs.rst: -------------------------------------------------------------------------------- 1 | .. _base_envs: 2 | 3 | Base Environments 4 | ================= 5 | 6 | .. automodule:: seals.base_envs 7 | -------------------------------------------------------------------------------- /docs/common/testing.rst: -------------------------------------------------------------------------------- 1 | .. _testing: 2 | 3 | Helpers for unit-testing environments 4 | ===================================== 5 | 6 | .. automodule:: seals.testing.envs 7 | -------------------------------------------------------------------------------- /docs/common/util.rst: -------------------------------------------------------------------------------- 1 | .. _util: 2 | 3 | Utilities 4 | ========= 5 | 6 | .. automodule:: seals.util 7 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | """Configuration file for the Sphinx documentation builder.""" 2 | 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # http://www.sphinx-doc.org/en/master/config 6 | 7 | from importlib import metadata 8 | 9 | # -- Project information ----------------------------------------------------- 10 | 11 | project = "seals" 12 | copyright = "2020, Center for Human-Compatible AI" # noqa: A001 13 | author = "Center for Human-Compatible AI" 14 | 15 | # The full version, including alpha/beta/rc tags 16 | version = metadata.version("seals") 17 | 18 | 19 | # -- General configuration --------------------------------------------------- 20 | 21 | # Add any Sphinx extension module names here, as strings. They can be 22 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 23 | # ones. 24 | extensions = [ 25 | "sphinx.ext.autodoc", 26 | "sphinx_autodoc_typehints", 27 | "sphinx.ext.autosummary", 28 | "sphinx.ext.mathjax", 29 | "sphinx.ext.napoleon", 30 | "sphinx.ext.viewcode", 31 | "sphinx_rtd_theme", 32 | ] 33 | autodoc_mock_imports = ["mujoco_py"] 34 | 35 | # Add any paths that contain templates here, relative to this directory. 36 | templates_path = ["_templates"] 37 | 38 | # List of patterns, relative to source directory, that match files and 39 | # directories to ignore when looking for source files. 40 | # This pattern also affects html_static_path and html_extra_path. 
41 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
42 | 
43 | autodoc_default_options = {
44 | "members": True,
45 | "undoc-members": True,
46 | "special-members": "__init__",
47 | "show-inheritance": True,
48 | }
49 | 
50 | # -- Options for HTML output -------------------------------------------------
51 | 
52 | # The theme to use for HTML and HTML Help pages. See the documentation for
53 | # a list of builtin themes.
54 | #
55 | html_theme = "sphinx_rtd_theme"
56 | html_logo = "_static/img/logo.svg"
57 | html_theme_options = {
58 | "logo_only": True,
59 | "style_nav_header_background": "#06203A",
60 | }
61 | 
62 | # Add any paths that contain custom static files (such as style sheets) here,
63 | # relative to this directory. They are copied after the builtin static files,
64 | # so a file named "default.css" will overwrite the builtin "default.css".
65 | html_static_path = ["_static"]
66 | 
-------------------------------------------------------------------------------- /docs/environments/diagnostic.rst: --------------------------------------------------------------------------------
1 | .. _diagnostic:
2 | 
3 | Diagnostic Tasks
4 | ================
5 | 
6 | Diagnostic tasks test individual facets of algorithm performance in isolation.
7 | 
8 | Branching
9 | ---------
10 | 
11 | **Gym ID**: ``seals/Branching-v0``
12 | 
13 | .. automodule:: seals.diagnostics.branching
14 | 
15 | EarlyTerm
16 | ---------
17 | 
18 | **Gym ID**: ``seals/EarlyTermPos-v0`` and ``seals/EarlyTermNeg-v0``
19 | 
20 | .. automodule:: seals.diagnostics.early_term
21 | 
22 | InitShift
23 | ---------
24 | 
25 | **Gym ID**: ``seals/InitShiftTrain-v0`` and ``seals/InitShiftTest-v0``
26 | 
27 | .. automodule:: seals.diagnostics.init_shift
28 | 
29 | LargestSum
30 | ----------
31 | 
32 | **Gym ID**: ``seals/LargestSum-v0``
33 | 
34 | .. automodule:: seals.diagnostics.largest_sum
35 | 
36 | NoisyObs
37 | --------
38 | 
39 | **Gym ID**: ``seals/NoisyObs-v0``
40 | 
41 | .. automodule:: seals.diagnostics.noisy_obs
42 | 
43 | Parabola
44 | --------
45 | 
46 | **Gym ID**: ``seals/Parabola-v0``
47 | 
48 | .. automodule:: seals.diagnostics.parabola
49 | 
50 | ProcGoal
51 | --------
52 | 
53 | **Gym ID**: ``seals/ProcGoal-v0``
54 | 
55 | .. automodule:: seals.diagnostics.proc_goal
56 | 
57 | RiskyPath
58 | ---------
59 | 
60 | **Gym ID**: ``seals/RiskyPath-v0``
61 | 
62 | .. automodule:: seals.diagnostics.risky_path
63 | 
64 | Sort
65 | ----
66 | 
67 | **Gym ID**: ``seals/Sort-v0``
68 | 
69 | .. automodule:: seals.diagnostics.sort
70 | 
-------------------------------------------------------------------------------- /docs/environments/renovated.rst: --------------------------------------------------------------------------------
1 | .. _renovated:
2 | 
3 | Renovated Environments
4 | ======================
5 | 
6 | These environments are adaptations of widely-used reinforcement learning benchmarks from
7 | `Gym `_, modified to be suitable for benchmarking specification
8 | learning algorithms. In particular, we:
9 | 
10 | * Make episodes fixed length. Since episode termination conditions are often correlated with
11 | reward, variable-length episodes provide a side-channel of reward information that algorithms
12 | can exploit. Critically, episode boundaries do not exist outside of simulation: in the
13 | real world, a human must often `"reset" the RL algorithm `_.
14 | 
15 | Moreover, many algorithms do not properly handle episode termination, and so are
16 | `biased `_ towards shorter or longer episode boundaries.
17 | This confounds evaluation, making some algorithms appear spuriously good or bad depending 18 | on if their bias aligns with the task objective. 19 | 20 | For most tasks, we make the episode fixed length simply by removing the early termination 21 | condition. In some environments, such as *MountainCar*, it does not make sense to continue 22 | after the terminal state: in this case, we make the terminal state an absorbing state that 23 | is repeated until the end of the episode. 24 | * Ensure observations include all information necessary to compute the ground-truth reward 25 | function. For some environments, this has required augmenting the observation space. 26 | We make this modification to make RL and specification learning of comparable difficulty 27 | in these environments. While in general both RL and specification learning may need to 28 | operate in partially observable environments, the observations in these relatively simple 29 | environments were typically engineered to *make RL easy*: for a fair comparison, we must 30 | therefore also provide reward learning algorithms with sufficient features to recover the 31 | reward. 32 | 33 | In the future, we intend to add Atari tasks with the score masked, another reward side-channel. 34 | 35 | Classic Control 36 | --------------- 37 | 38 | CartPole 39 | ******** 40 | 41 | **Gym ID**: ``seals/CartPole-v0`` 42 | 43 | .. autoclass:: seals.classic_control.FixedHorizonCartPole 44 | 45 | MountainCar 46 | *********** 47 | 48 | **Gym ID**: ``seals/MountainCar-v0`` 49 | 50 | .. autofunction:: seals.classic_control.mountain_car 51 | 52 | MuJoCo 53 | ------ 54 | 55 | Ant 56 | *** 57 | 58 | **Gym ID**: ``seals/Ant-v0`` 59 | 60 | .. autoclass:: seals.mujoco.AntEnv 61 | 62 | HalfCheetah 63 | *********** 64 | 65 | **Gym ID**: ``seals/HalfCheetah-v0`` 66 | 67 | .. autoclass:: seals.mujoco.HalfCheetahEnv 68 | 69 | Hopper 70 | ****** 71 | 72 | **Gym ID**: ``seals/Hopper-v0`` 73 | 74 | .. autoclass:: seals.mujoco.HopperEnv 75 | 76 | Humanoid 77 | ******** 78 | 79 | **Gym ID**: ``seals/Humanoid-v0`` 80 | 81 | .. autoclass:: seals.mujoco.HumanoidEnv 82 | 83 | Swimmer 84 | ******* 85 | 86 | **Gym ID**: ``seals/Swimmer-v0`` 87 | 88 | .. autoclass:: seals.mujoco.SwimmerEnv 89 | 90 | Walker2d 91 | ******** 92 | 93 | **Gym ID**: ``seals/Walker2d-v0`` 94 | 95 | .. autoclass:: seals.mujoco.Walker2dEnv 96 | -------------------------------------------------------------------------------- /docs/guide/install.rst: -------------------------------------------------------------------------------- 1 | .. _install: 2 | 3 | Installation Instructions 4 | ========================= 5 | 6 | To install the latest release from PyPi, run:: 7 | 8 | pip install seals 9 | 10 | We make releases periodically, but if you wish to use the latest version of the code, you can 11 | always install directly from Git master:: 12 | 13 | pip install git+https://github.com/HumanCompatibleAI/seals.git 14 | 15 | *seals* has optional dependencies needed by some subset of environments. In particular, 16 | to use MuJoCo environments, you will need to install `MuJoCo `_ 1.5 17 | and then run:: 18 | 19 | pip install seals[mujoco] 20 | 21 | You may need to install some other binary dependencies: see the instructions in 22 | `Gym `_ and `mujoco-py `_ 23 | for further information. 24 | 25 | You can also use our Docker image which includes all necessary binary dependencies. 
You can either 26 | build it from the ``Dockerfile``, or by downloading a pre-built image:: 27 | 28 | docker pull humancompatibleai/seals:base 29 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | seals User Guide 2 | ================ 3 | 4 | The Suite of Environments for Algorithms that Learn Specifications, or *seals*, is a toolkit for 5 | evaluating specification learning algorithms, such as reward or imitation learning. The environments 6 | are compatible with `Gym `_, but are designed to test algorithms 7 | that learn from user data, without requiring a procedurally specified reward function. 8 | 9 | There are two types of environments in *seals*: 10 | 11 | * **Diagnostic Tasks** which test individual facets of algorithm performance in isolation. 12 | * **Renovated Environments**, adaptations of widely-used benchmarks such as MuJoCo continuous 13 | control tasks to be suitable for specification learning benchmarks. In particular, this 14 | involves removing any side-channel sources of reward information (such as episode boundaries, 15 | the score appearing in the observation, etc) and including all the information needed to 16 | compute the reward in the observation space. 17 | 18 | *seals* is under active development and we intend to add more categories of tasks soon. 19 | 20 | .. toctree:: 21 | :maxdepth: 1 22 | :caption: User Guide 23 | 24 | guide/install 25 | 26 | 27 | .. toctree:: 28 | :maxdepth: 3 29 | :caption: Environments 30 | 31 | environments/diagnostic 32 | environments/renovated 33 | 34 | .. toctree:: 35 | :maxdepth: 2 36 | :caption: Common 37 | 38 | common/base_envs 39 | common/util 40 | common/testing 41 | 42 | Citing seals 43 | ------------ 44 | To cite this project in publications: 45 | 46 | .. code-block:: bibtex 47 | 48 | @misc{seals, 49 | author = {Adam Gleave and Pedro Freire and Steven Wang and Sam Toyer}, 50 | title = {{seals}: Suite of Environments for Algorithms that Learn Specifications}, 51 | year = {2020}, 52 | publisher = {GitHub}, 53 | journal = {GitHub repository}, 54 | howpublished = {\url{https://github.com/HumanCompatibleAI/seals}}, 55 | } 56 | 57 | Indices and tables 58 | ================== 59 | 60 | * :ref:`genindex` 61 | * :ref:`modindex` 62 | * :ref:`search` 63 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | ignore_missing_imports = true -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.black] 6 | target-version = ["py38"] 7 | 8 | [[tool.mypy.overrides]] 9 | module = ["gym.*", "setuptools_scm.*"] 10 | ignore_missing_imports = true 11 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | sphinx: 4 | configuration: docs/conf.py 5 | 6 | formats: all 7 | 8 | python: 9 | version: 3.8 10 | install: 11 | - method: pip 12 | path: . 13 | extra_requirements: 14 | - docs 15 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [coverage:run] 2 | source = seals 3 | include= 4 | src/* 5 | tests/* 6 | 7 | [coverage:report] 8 | exclude_lines = 9 | pragma: no cover 10 | if __name__ == .__main__.: 11 | omit = 12 | setup.py 13 | 14 | [coverage:paths] 15 | source = 16 | src/seals 17 | *venv/lib/python*/site-packages/seals 18 | 19 | [darglint] 20 | strictness=long 21 | 22 | [flake8] 23 | docstring-convention=google 24 | ignore = E203, W503 25 | max-line-length = 88 26 | 27 | [isort] 28 | line_length=88 29 | known_first_party=seals,tests 30 | default_section=THIRDPARTY 31 | multi_line_output=3 32 | include_trailing_comma=True 33 | force_sort_within_sections=True 34 | skip=.pytype 35 | 36 | [pytype] 37 | inputs = 38 | src/ 39 | tests/ 40 | setup.py 41 | python_version >= 3.8 42 | 43 | [tool:pytest] 44 | markers = 45 | expensive: mark a test as expensive (deselect with '-m "not expensive"') 46 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """setup.py for seals project.""" 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | from setuptools import find_packages, setup # type:ignore 6 | 7 | if TYPE_CHECKING: 8 | from setuptools_scm.version import ScmVersion 9 | 10 | 11 | def get_version(version: "ScmVersion") -> str: 12 | """Generates the version string for the package. 13 | 14 | This function replaces the default version format used by setuptools_scm 15 | to allow development builds to be versioned using the git commit hash 16 | instead of the number of commits since the last release, which leads to 17 | duplicate version identifiers when using multiple branches 18 | (see https://github.com/HumanCompatibleAI/imitation/issues/500). 19 | The version has the following format: 20 | {version}[.dev{build}] 21 | where build is the shortened commit hash converted to base 10. 
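For example, a commit with short hash "abc1234" corresponds to build number 180097588
(int("abc1234", 16)), so a development build would be versioned like 0.2.2.dev180097588,
where 0.2.2 is merely an illustrative placeholder for whatever release version is guessed next.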
22 | 23 | Args: 24 | version: The version object given by setuptools_scm, calculated 25 | from the git repository. 26 | 27 | Returns: 28 | The formatted version string to use for the package. 29 | """ 30 | # We import setuptools_scm here because it is only installed after the module 31 | # is loaded and the setup function is called. 32 | from setuptools_scm import version as scm_version 33 | 34 | if version.node: 35 | # By default node corresponds to the short commit hash when using git, 36 | # plus a "g" prefix. We remove the "g" prefix from the commit hash which 37 | # is added by setuptools_scm by default ("g" for git vs. mercurial etc.) 38 | # because letters are not valid for version identifiers in PEP 440. 39 | # We also convert from hexadecimal to base 10 for the same reason. 40 | version.node = str(int(version.node.lstrip("g"), 16)) 41 | if version.exact: 42 | # an exact version is when the current commit is tagged with a version. 43 | return version.format_with("{tag}") 44 | else: 45 | # the current commit is not tagged with a version, so we guess 46 | # what the "next" version will be (this can be disabled but is the 47 | # default behavior of setuptools_scm so it has been left in). 48 | return version.format_next_version( 49 | scm_version.guess_next_version, 50 | fmt="{guessed}.dev{node}", 51 | ) 52 | 53 | 54 | def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str: 55 | """Generates the local version string for the package. 56 | 57 | By default, when commits are made on top of a release version, setuptools_scm 58 | sets the version to be {version}.dev{distance}+{node} where {distance} is the number 59 | of commits since the last release and {node} is the short commit hash. 60 | This function replaces the default version format used by setuptools_scm 61 | so that committed changes away from a release version are not considered 62 | local versions but dev versions instead (by using the format 63 | {version}.dev{node} instead. This is so that we can push test releases 64 | to TestPyPI (it does not accept local versions). 65 | Local versions are still present if there are uncommitted changes (if the tree 66 | is dirty), in which case the current date is added to the version. 67 | 68 | Args: 69 | version: The version object given by setuptools_scm, calculated 70 | from the git repository. 71 | time_format: The format to use for the date. 72 | 73 | Returns: 74 | The formatted local version string to use for the package. 
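For example, a tree with uncommitted changes on 2023-01-15 would get the
local suffix "+d20230115" appended, while a clean tree gets no local suffix.
(The date is illustrative; the suffix always uses the current date in %Y%m%d format.)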
75 | """ 76 | return version.format_choice( 77 | "", 78 | "+d{time:{time_format}}", 79 | time_format=time_format, 80 | ) 81 | 82 | 83 | def get_readme() -> str: 84 | """Retrieve content from README.""" 85 | with open("README.md", "r") as f: 86 | return f.read() 87 | 88 | 89 | ATARI_REQUIRE = [ 90 | "opencv-python", 91 | "ale-py~=0.8.1", 92 | "pillow", 93 | "autorom[accept-rom-license]~=0.4.2", 94 | "shimmy[atari] >=0.1.0,<1.0", 95 | ] 96 | TESTS_REQUIRE = [ 97 | "black", 98 | "coverage~=4.5.4", 99 | "codecov", 100 | "codespell", 101 | "darglint>=1.5.6", 102 | "flake8", 103 | "flake8-blind-except", 104 | "flake8-builtins", 105 | "flake8-commas", 106 | "flake8-debugger", 107 | "flake8-docstrings", 108 | "flake8-isort", 109 | "isort", 110 | "matplotlib", 111 | "mypy", 112 | "pydocstyle", 113 | "pytest", 114 | "pytest-cov", 115 | "pytest-xdist", 116 | "pytype", 117 | "stable-baselines3>=0.9.0", 118 | "setuptools_scm~=7.0.5", 119 | "gymnasium[classic-control,mujoco]", 120 | *ATARI_REQUIRE, 121 | ] 122 | DOCS_REQUIRE = [ 123 | "sphinx", 124 | "sphinx-autodoc-typehints>=1.21.5", 125 | "sphinx-rtd-theme", 126 | ] 127 | 128 | 129 | setup( 130 | name="seals", 131 | use_scm_version={"local_scheme": get_local_version, "version_scheme": get_version}, 132 | description="Suite of Environments for Algorithms that Learn Specifications", 133 | long_description=get_readme(), 134 | long_description_content_type="text/markdown", 135 | author="Center for Human-Compatible AI", 136 | python_requires=">=3.8.0", 137 | packages=find_packages("src"), 138 | package_dir={"": "src"}, 139 | package_data={"seals": ["py.typed"]}, 140 | install_requires=["gymnasium", "numpy"], 141 | tests_require=TESTS_REQUIRE, 142 | extras_require={ 143 | # recommended packages for development 144 | "dev": ["ipdb", "jupyter", *TESTS_REQUIRE, *DOCS_REQUIRE], 145 | "docs": DOCS_REQUIRE, 146 | "test": TESTS_REQUIRE, 147 | "mujoco": ["gymnasium[mujoco]"], 148 | "atari": ATARI_REQUIRE, 149 | }, 150 | url="https://github.com/HumanCompatibleAI/benchmark-environments", 151 | license="MIT", 152 | classifiers=[ 153 | # Trove classifiers 154 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 155 | "License :: OSI Approved :: MIT License", 156 | "Programming Language :: Python", 157 | "Programming Language :: Python :: 3", 158 | "Programming Language :: Python :: 3.8", 159 | "Programming Language :: Python :: 3.9", 160 | "Programming Language :: Python :: 3.10", 161 | "Programming Language :: Python :: Implementation :: CPython", 162 | "Programming Language :: Python :: Implementation :: PyPy", 163 | "License :: OSI Approved :: MIT License", 164 | "Operating System :: OS Independent", 165 | ], 166 | ) 167 | -------------------------------------------------------------------------------- /src/seals/__init__.py: -------------------------------------------------------------------------------- 1 | """Benchmark environments for reward modeling and imitation.""" 2 | 3 | from importlib import metadata 4 | 5 | import gymnasium as gym 6 | 7 | from seals import atari, util 8 | import seals.diagnostics # noqa: F401 9 | 10 | try: 11 | __version__ = metadata.version("seals") 12 | except metadata.PackageNotFoundError: # pragma: no cover 13 | # package is not installed 14 | pass 15 | 16 | # Classic control 17 | 18 | gym.register( 19 | id="seals/CartPole-v0", 20 | entry_point="seals.classic_control:FixedHorizonCartPole", 21 | max_episode_steps=500, 22 | ) 23 | 24 | gym.register( 25 | id="seals/MountainCar-v0", 26 | 
entry_point="seals.classic_control:mountain_car", 27 | max_episode_steps=200, 28 | ) 29 | 30 | # MuJoCo 31 | 32 | for env_base in ["Ant", "HalfCheetah", "Hopper", "Humanoid", "Swimmer", "Walker2d"]: 33 | gym.register( 34 | id=f"seals/{env_base}-v1", 35 | entry_point=f"seals.mujoco:{env_base}Env", 36 | max_episode_steps=util.get_gym_max_episode_steps(f"{env_base}-v4"), 37 | ) 38 | 39 | # Atari 40 | 41 | GYM_ATARI_ENV_SPECS = list(filter(atari._supported_atari_env, gym.registry.values())) 42 | atari.register_atari_envs(GYM_ATARI_ENV_SPECS) 43 | -------------------------------------------------------------------------------- /src/seals/atari.py: -------------------------------------------------------------------------------- 1 | """Adaptation of Atari environments for specification learning algorithms.""" 2 | 3 | from typing import Dict, Iterable, Optional 4 | 5 | import gymnasium as gym 6 | from gymnasium.envs.registration import EnvSpec 7 | 8 | from seals.util import ( 9 | AutoResetWrapper, 10 | BoxRegion, 11 | MaskedRegionSpecifier, 12 | MaskScoreWrapper, 13 | get_gym_max_episode_steps, 14 | ) 15 | 16 | SCORE_REGIONS: Dict[str, MaskedRegionSpecifier] = { 17 | "BeamRider": [ 18 | BoxRegion(x=(5, 20), y=(45, 120)), 19 | BoxRegion(x=(28, 40), y=(15, 40)), 20 | ], 21 | "Breakout": [BoxRegion(x=(0, 16), y=(35, 80))], 22 | "Enduro": [ 23 | BoxRegion(x=(163, 173), y=(55, 110)), 24 | BoxRegion(x=(177, 188), y=(68, 107)), 25 | ], 26 | "Pong": [BoxRegion(x=(0, 24), y=(0, 160))], 27 | "Qbert": [BoxRegion(x=(6, 15), y=(33, 71))], 28 | "Seaquest": [BoxRegion(x=(7, 19), y=(80, 110))], 29 | "SpaceInvaders": [BoxRegion(x=(10, 20), y=(0, 160))], 30 | } 31 | 32 | 33 | def _get_score_region(atari_env_id: str) -> Optional[MaskedRegionSpecifier]: 34 | basename = atari_env_id.split("/")[-1].split("-")[0] 35 | basename = basename.replace("NoFrameskip", "") 36 | return SCORE_REGIONS.get(basename) 37 | 38 | 39 | def make_atari_env(atari_env_id: str, masked: bool, *args, **kwargs) -> gym.Env: 40 | """Fixed-length, optionally masked-score variant of a given Atari environment.""" 41 | env: gym.Env = AutoResetWrapper(gym.make(atari_env_id, *args, **kwargs)) 42 | 43 | if masked: 44 | score_region = _get_score_region(atari_env_id) 45 | if score_region is None: 46 | raise ValueError( 47 | "Requested environment does not yet support masking. 
" 48 | "See https://github.com/HumanCompatibleAI/seals/issues/61.", 49 | ) 50 | env = MaskScoreWrapper(env, score_region) 51 | 52 | return env 53 | 54 | 55 | def _not_ram_or_det(env_id: str) -> bool: 56 | """Checks a gym Atari environment isn't deterministic or using RAM observations.""" 57 | slash_separated = env_id.split("/") 58 | # environment name should look like "ALE/Amidar-v5" or "Amidar-ramNoFrameskip-v4" 59 | assert len(slash_separated) in (1, 2) 60 | after_slash = slash_separated[-1] 61 | hyphen_separated = after_slash.split("-") 62 | assert len(hyphen_separated) > 1 63 | not_ram = "ram" not in hyphen_separated[1] 64 | not_deterministic = "Deterministic" not in env_id 65 | return not_ram and not_deterministic 66 | 67 | 68 | def _supported_atari_env(gym_spec: EnvSpec) -> bool: 69 | """Checks if a gym Atari environment is one of the ones we will support.""" 70 | is_atari = gym_spec.entry_point == "shimmy.atari_env:AtariEnv" 71 | v5_and_plain = gym_spec.id.endswith("-v5") and "NoFrameskip" not in gym_spec.id 72 | v4_and_no_frameskip = gym_spec.id.endswith("-v4") and "NoFrameskip" in gym_spec.id 73 | return ( 74 | is_atari 75 | and _not_ram_or_det(gym_spec.id) 76 | and (v5_and_plain or v4_and_no_frameskip) 77 | ) 78 | 79 | 80 | def _seals_name(gym_spec: EnvSpec, masked: bool) -> str: 81 | """Makes a Gym ID for an Atari environment in the seals namespace.""" 82 | slash_separated = gym_spec.id.split("/") 83 | name = "seals/" + slash_separated[-1] 84 | 85 | if not masked: 86 | last_hyphen_idx = name.rfind("-v") 87 | name = name[:last_hyphen_idx] + "-Unmasked" + name[last_hyphen_idx:] 88 | return name 89 | 90 | 91 | def register_atari_envs( 92 | gym_atari_env_specs: Iterable[EnvSpec], 93 | ) -> None: 94 | """Register masked and unmasked wrapped gym Atari environments.""" 95 | 96 | def register_gym(masked): 97 | gym.register( 98 | id=_seals_name(gym_spec, masked=masked), 99 | entry_point="seals.atari:make_atari_env", 100 | max_episode_steps=get_gym_max_episode_steps(gym_spec.id), 101 | kwargs=dict(atari_env_id=gym_spec.id, masked=masked), 102 | ) 103 | 104 | for gym_spec in gym_atari_env_specs: 105 | register_gym(masked=False) 106 | if _get_score_region(gym_spec.id) is not None: 107 | register_gym(masked=True) 108 | -------------------------------------------------------------------------------- /src/seals/base_envs.py: -------------------------------------------------------------------------------- 1 | """Base environment classes.""" 2 | 3 | import abc 4 | from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Union 5 | 6 | import gymnasium as gym 7 | from gymnasium import spaces 8 | import numpy as np 9 | import numpy.typing as npt 10 | 11 | from seals import util 12 | 13 | # Note: we redefine the type vars from gymnasium.core here, because pytype does not 14 | # recognize them as valid type vars if we import them from gymnasium.core. 15 | StateType = TypeVar("StateType") 16 | ActType = TypeVar("ActType") 17 | ObsType = TypeVar("ObsType") 18 | 19 | 20 | class ResettablePOMDP( 21 | gym.Env[ObsType, ActType], 22 | abc.ABC, 23 | Generic[StateType, ObsType, ActType], 24 | ): 25 | """ABC for POMDPs that are resettable. 26 | 27 | Specifically, these environments provide oracle access to sample from 28 | the initial state distribution and transition dynamics, and compute the 29 | reward and termination condition. Almost all simulated environments can 30 | meet these criteria. 
31 | """ 32 | 33 | state_space: spaces.Space[StateType] 34 | 35 | _cur_state: Optional[StateType] 36 | _n_actions_taken: Optional[int] 37 | 38 | def __init__(self): 39 | """Build resettable (PO)MDP.""" 40 | self._cur_state = None 41 | self._n_actions_taken = None 42 | 43 | @abc.abstractmethod 44 | def initial_state(self) -> StateType: 45 | """Samples from the initial state distribution.""" 46 | 47 | @abc.abstractmethod 48 | def transition(self, state: StateType, action: ActType) -> StateType: 49 | """Samples from transition distribution.""" 50 | 51 | @abc.abstractmethod 52 | def reward(self, state: StateType, action: ActType, new_state: StateType) -> float: 53 | """Computes reward for a given transition.""" 54 | 55 | @abc.abstractmethod 56 | def terminal(self, state: StateType, step: int) -> bool: 57 | """Is the state terminal?""" 58 | 59 | @abc.abstractmethod 60 | def obs_from_state(self, state: StateType) -> ObsType: 61 | """Sample observation for given state.""" 62 | 63 | @property 64 | def n_actions_taken(self) -> int: 65 | """Number of steps taken so far.""" 66 | assert self._n_actions_taken is not None 67 | return self._n_actions_taken 68 | 69 | @property 70 | def state(self) -> StateType: 71 | """Current state.""" 72 | assert self._cur_state is not None 73 | return self._cur_state 74 | 75 | @state.setter 76 | def state(self, state: StateType): 77 | """Set the current state.""" 78 | if state not in self.state_space: 79 | raise ValueError(f"{state} not in {self.state_space}") 80 | self._cur_state = state 81 | 82 | def reset( 83 | self, 84 | *, 85 | seed: Optional[int] = None, 86 | options: Optional[Dict[str, Any]] = None, 87 | ) -> Tuple[ObsType, Dict[str, Any]]: 88 | """Reset the episode and return initial observation.""" 89 | if options is not None: 90 | raise NotImplementedError("Options not supported.") 91 | 92 | super().reset(seed=seed) 93 | self.state = self.initial_state() 94 | self._n_actions_taken = 0 95 | obs = self.obs_from_state(self.state) 96 | info: Dict[str, Any] = dict() 97 | return obs, info 98 | 99 | def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: 100 | """Transition state using given action.""" 101 | if self._cur_state is None or self._n_actions_taken is None: 102 | raise RuntimeError("Need to call reset() before first step()") 103 | if action not in self.action_space: 104 | raise ValueError(f"{action} not in {self.action_space}") 105 | 106 | old_state = self.state 107 | self.state = self.transition(self.state, action) 108 | obs = self.obs_from_state(self.state) 109 | assert obs in self.observation_space 110 | reward = self.reward(old_state, action, self.state) 111 | self._n_actions_taken += 1 112 | terminated = self.terminal(self.state, self.n_actions_taken) 113 | truncated = False 114 | 115 | infos = {"old_state": old_state, "new_state": self._cur_state} 116 | return obs, reward, terminated, truncated, infos 117 | 118 | 119 | class ExposePOMDPStateWrapper( 120 | gym.Wrapper[StateType, ActType, ObsType, ActType], 121 | Generic[StateType, ObsType, ActType], 122 | ): 123 | """A wrapper that exposes the current state of the POMDP as the observation.""" 124 | 125 | env: ResettablePOMDP[StateType, ObsType, ActType] 126 | 127 | def __init__(self, env: ResettablePOMDP[StateType, ObsType, ActType]) -> None: 128 | """Build wrapper. 129 | 130 | Args: 131 | env: POMDP to wrap. 
132 | """ 133 | super().__init__(env) 134 | self._observation_space = env.state_space 135 | 136 | def reset( 137 | self, 138 | seed: Optional[int] = None, 139 | options: Optional[Dict[str, Any]] = None, 140 | ) -> Tuple[StateType, Dict[str, Any]]: 141 | """Reset environment and return initial state.""" 142 | _, info = self.env.reset(seed=seed, options=options) 143 | return self.env.state, info 144 | 145 | def step(self, action) -> Tuple[StateType, float, bool, bool, dict]: 146 | """Transition state using given action.""" 147 | _, reward, terminated, truncated, info = self.env.step(action) 148 | return self.env.state, reward, terminated, truncated, info 149 | 150 | 151 | class ResettableMDP( 152 | ResettablePOMDP[StateType, StateType, ActType], 153 | abc.ABC, 154 | Generic[StateType, ActType], 155 | ): 156 | """ABC for MDPs that are resettable.""" 157 | 158 | @property 159 | def observation_space(self): 160 | """Observation space.""" 161 | return self.state_space 162 | 163 | def obs_from_state(self, state: StateType) -> StateType: 164 | """Identity since observation == state in an MDP.""" 165 | return state 166 | 167 | 168 | DiscreteSpaceInt = np.int64 169 | 170 | 171 | # TODO(juan) this does not implement the .render() method, 172 | # so in theory it should not be instantiated directly. 173 | # Not sure why this is not raising an error? 174 | class BaseTabularModelPOMDP( 175 | ResettablePOMDP[DiscreteSpaceInt, ObsType, DiscreteSpaceInt], 176 | Generic[ObsType], 177 | ): 178 | """Base class for tabular environments with known dynamics. 179 | 180 | This is the general class that also allows subclassing for creating 181 | MDP (where observation == state) or POMDP (where observation != state). 182 | """ 183 | 184 | transition_matrix: np.ndarray 185 | reward_matrix: np.ndarray 186 | 187 | state_space: spaces.Discrete 188 | 189 | def __init__( 190 | self, 191 | *, 192 | transition_matrix: np.ndarray, 193 | reward_matrix: np.ndarray, 194 | horizon: Optional[int] = None, 195 | initial_state_dist: Optional[np.ndarray] = None, 196 | ): 197 | """Build tabular environment. 198 | 199 | Args: 200 | transition_matrix: 3-D array with transition probabilities for a 201 | given state-action pair, of shape `(n_states,n_actions,n_states)`. 202 | reward_matrix: 1-D, 2-D or 3-D array corresponding to rewards to a 203 | given `(state, action, next_state)` triple. A 2-D array assumes 204 | the `next_state` is not used in the reward, and a 1-D array 205 | assumes neither the `action` nor `next_state` are used. 206 | Of shape `(n_states,n_actions,n_states)[:n]` where `n` 207 | is the dimensionality of the array. 208 | horizon: Maximum number of timesteps. The default is `None`, 209 | which represents an infinite horizon. 210 | initial_state_dist: Distribution from which state is sampled at the 211 | start of the episode. If `None`, it is assumed initial state 212 | is always 0. Shape `(n_states,)`. 213 | 214 | Raises: 215 | ValueError: `transition_matrix`, `reward_matrix` or 216 | `initial_state_dist` have shapes different to specified above. 
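    For illustration, with 2 states and 2 actions any of the following
    (all-zero, purely illustrative) reward arrays would be accepted; lower
    dimensional arrays simply ignore the trailing elements of the
    `(state, action, next_state)` triple::

        reward_matrix = np.zeros((2,))       # R[s]: state-only reward
        reward_matrix = np.zeros((2, 2))     # R[s, a]: ignores next_state
        reward_matrix = np.zeros((2, 2, 2))  # R[s, a, s']: full triple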
217 | """ 218 | super().__init__() 219 | 220 | # The following matrices should conform to the shapes below: 221 | 222 | # transition matrix: n_states x n_actions x n_states 223 | n_states = transition_matrix.shape[0] 224 | if n_states != transition_matrix.shape[2]: 225 | raise ValueError( 226 | "Malformed transition_matrix:\n" 227 | f"transition_matrix.shape: {transition_matrix.shape}\n" 228 | f"{n_states} != {transition_matrix.shape[2]}", 229 | ) 230 | 231 | # reward matrix: n_states x n_actions x n_states 232 | # OR n_states x n_actions 233 | # OR n_states 234 | if reward_matrix.shape != transition_matrix.shape[: len(reward_matrix.shape)]: 235 | raise ValueError( 236 | "transition_matrix and reward_matrix are not compatible:\n" 237 | f"transition_matrix.shape: {transition_matrix.shape}\n" 238 | f"reward_matrix.shape: {reward_matrix.shape}", 239 | ) 240 | 241 | # initial state dist: n_states 242 | if initial_state_dist is None: 243 | initial_state_dist = util.one_hot_encoding(0, n_states) 244 | if initial_state_dist.ndim != 1: 245 | raise ValueError( 246 | "initial_state_dist has multiple dimensions:\n" 247 | f"{initial_state_dist.ndim} != 1", 248 | ) 249 | if initial_state_dist.shape[0] != n_states: 250 | raise ValueError( 251 | "transition_matrix and initial_state_dist are not compatible:\n" 252 | f"number of states = {n_states}\n" 253 | f"len(initial_state_dist) = {len(initial_state_dist)}", 254 | ) 255 | 256 | self.transition_matrix = transition_matrix 257 | self.reward_matrix = reward_matrix 258 | self._feature_matrix = None 259 | self.horizon = horizon 260 | self.initial_state_dist = initial_state_dist 261 | 262 | self.state_space = spaces.Discrete(self.state_dim) 263 | self.action_space = spaces.Discrete(self.action_dim) 264 | 265 | def initial_state(self) -> DiscreteSpaceInt: 266 | """Samples from the initial state distribution.""" 267 | return DiscreteSpaceInt( 268 | util.sample_distribution( 269 | self.initial_state_dist, 270 | random=self.np_random, 271 | ), 272 | ) 273 | 274 | def transition( 275 | self, 276 | state: DiscreteSpaceInt, 277 | action: DiscreteSpaceInt, 278 | ) -> DiscreteSpaceInt: 279 | """Samples from transition distribution.""" 280 | return DiscreteSpaceInt( 281 | util.sample_distribution( 282 | self.transition_matrix[state, action], 283 | random=self.np_random, 284 | ), 285 | ) 286 | 287 | def reward( 288 | self, 289 | state: DiscreteSpaceInt, 290 | action: DiscreteSpaceInt, 291 | new_state: DiscreteSpaceInt, 292 | ) -> float: 293 | """Computes reward for a given transition.""" 294 | inputs = (state, action, new_state)[: len(self.reward_matrix.shape)] 295 | return self.reward_matrix[inputs] 296 | 297 | def terminal(self, state: DiscreteSpaceInt, n_actions_taken: int) -> bool: 298 | """Checks if state is terminal.""" 299 | del state 300 | return self.horizon is not None and n_actions_taken >= self.horizon 301 | 302 | @property 303 | def feature_matrix(self): 304 | """Matrix mapping states to feature vectors.""" 305 | # Construct lazily to save memory in algorithms that don't need features. 
306 | if self._feature_matrix is None: 307 | n_states = self.state_space.n 308 | self._feature_matrix = np.eye(n_states) 309 | return self._feature_matrix 310 | 311 | @property 312 | def state_dim(self): 313 | """Number of states in this MDP (int).""" 314 | return self.transition_matrix.shape[0] 315 | 316 | @property 317 | def action_dim(self) -> int: 318 | """Number of action vectors (int).""" 319 | return self.transition_matrix.shape[1] 320 | 321 | 322 | ObsEntryType = TypeVar( 323 | "ObsEntryType", 324 | bound=Union[np.floating, np.integer], 325 | ) 326 | 327 | 328 | class TabularModelPOMDP(BaseTabularModelPOMDP[np.ndarray], Generic[ObsEntryType]): 329 | """Tabular model POMDP. 330 | 331 | This class is specifically for environments where observation != state, 332 | from both a typing perspective but also by defining the method that 333 | draws observations from the state. 334 | 335 | The tabular model is deterministic in drawing observations from the state, 336 | in that given a certain state, the observation is always the same; 337 | a vector with self.obs_dim entries. 338 | """ 339 | 340 | observation_matrix: npt.NDArray[ObsEntryType] 341 | observation_space: spaces.Box 342 | 343 | def __init__( 344 | self, 345 | *, 346 | transition_matrix: np.ndarray, 347 | observation_matrix: npt.NDArray[ObsEntryType], 348 | reward_matrix: np.ndarray, 349 | horizon: Optional[int] = None, 350 | initial_state_dist: Optional[np.ndarray] = None, 351 | ): 352 | """Initializes a tabular model POMDP.""" 353 | self.observation_matrix = observation_matrix 354 | super().__init__( 355 | transition_matrix=transition_matrix, 356 | reward_matrix=reward_matrix, 357 | horizon=horizon, 358 | initial_state_dist=initial_state_dist, 359 | ) 360 | 361 | # observation matrix: n_states x n_observations 362 | if observation_matrix.shape[0] != self.state_dim: 363 | raise ValueError( 364 | "transition_matrix and observation_matrix are not compatible:\n" 365 | f"transition_matrix.shape[0]: {self.state_dim}\n" 366 | f"observation_matrix.shape[0]: {observation_matrix.shape[0]}", 367 | ) 368 | 369 | min_val: float 370 | max_val: float 371 | try: 372 | dtype_iinfo = np.iinfo(self.obs_dtype) 373 | min_val, max_val = dtype_iinfo.min, dtype_iinfo.max 374 | except ValueError: 375 | min_val = -np.inf 376 | max_val = np.inf 377 | self.observation_space = spaces.Box( 378 | low=min_val, 379 | high=max_val, 380 | shape=(self.obs_dim,), 381 | dtype=self.obs_dtype, # type: ignore 382 | ) 383 | 384 | def obs_from_state(self, state: DiscreteSpaceInt) -> npt.NDArray[ObsEntryType]: 385 | """Computes observation from state.""" 386 | # Copy so it can't be mutated in-place (updates will be reflected in 387 | # self.observation_matrix!) 388 | obs = self.observation_matrix[state].copy() 389 | assert obs.ndim == 1, obs.shape 390 | return obs 391 | 392 | @property 393 | def obs_dim(self) -> int: 394 | """Size of observation vectors for this MDP.""" 395 | return self.observation_matrix.shape[1] 396 | 397 | @property 398 | def obs_dtype(self) -> np.dtype: 399 | """Data type of observation vectors (e.g. np.float32).""" 400 | return self.observation_matrix.dtype 401 | 402 | 403 | class TabularModelMDP(BaseTabularModelPOMDP[DiscreteSpaceInt]): 404 | """Tabular model MDP. 405 | 406 | A tabular model MDP is a tabular MDP where the transition and reward 407 | matrices are constant. 
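    A minimal construction sketch (toy numbers, not a seals benchmark)::

        transition = np.zeros((2, 2, 2))
        transition[:, 0, 0] = 1.0  # action 0 always leads to state 0
        transition[:, 1, 1] = 1.0  # action 1 always leads to state 1
        reward = np.array([0.0, 1.0])  # 1-D: reward depends only on the state
        env = TabularModelMDP(
            transition_matrix=transition,
            reward_matrix=reward,
            horizon=10,
        )
        obs, info = env.reset(seed=0)  # obs == state == 0, since
                                       # initial_state_dist defaults to state 0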
408 | """ 409 | 410 | def __init__( 411 | self, 412 | *, 413 | transition_matrix: np.ndarray, 414 | reward_matrix: np.ndarray, 415 | horizon: Optional[int] = None, 416 | initial_state_dist: Optional[np.ndarray] = None, 417 | ): 418 | """Initializes a tabular model MDP. 419 | 420 | Args: 421 | transition_matrix: Matrix of shape `(n_states, n_actions, n_states)` 422 | containing transition probabilities. 423 | reward_matrix: Matrix of shape `(n_states, n_actions, n_states)` 424 | containing reward values. 425 | initial_state_dist: Distribution over initial states. Shape `(n_states,)`. 426 | horizon: Maximum number of steps to take in an episode. 427 | """ 428 | super().__init__( 429 | transition_matrix=transition_matrix, 430 | reward_matrix=reward_matrix, 431 | horizon=horizon, 432 | initial_state_dist=initial_state_dist, 433 | ) 434 | self.observation_space = self.state_space 435 | 436 | def obs_from_state(self, state: DiscreteSpaceInt) -> DiscreteSpaceInt: 437 | """Identity since observation == state in an MDP.""" 438 | return state 439 | -------------------------------------------------------------------------------- /src/seals/classic_control.py: -------------------------------------------------------------------------------- 1 | """Adaptation of classic Gym environments for specification learning algorithms.""" 2 | 3 | from typing import Any, Dict, Optional 4 | import warnings 5 | 6 | import gymnasium as gym 7 | from gymnasium import spaces 8 | from gymnasium.envs import classic_control 9 | import numpy as np 10 | 11 | from seals import util 12 | 13 | 14 | class FixedHorizonCartPole(classic_control.CartPoleEnv): 15 | """Fixed-length variant of CartPole-v1. 16 | 17 | Reward is 1.0 whenever the CartPole is an "ok" state (i.e., the pole is upright 18 | and the cart is on the screen). Otherwise, reward is 0.0. 19 | 20 | Terminated is always False. 21 | By default, this environment is wrapped in 'TimeLimit' with max steps 500, 22 | which sets `truncated` to true after that many steps. 23 | """ 24 | 25 | def __init__(self, *args, **kwargs): 26 | """Builds FixedHorizonCartPole, modifying observation_space from gym parent.""" 27 | super().__init__(*args, **kwargs) 28 | 29 | high = [ 30 | np.finfo(np.float32).max, # x axis 31 | np.finfo(np.float32).max, # x velocity 32 | np.pi, # theta in radians 33 | np.finfo(np.float32).max, # theta velocity 34 | ] 35 | high = np.array(high) 36 | self.observation_space = spaces.Box(-high, high, dtype=np.float32) 37 | 38 | def reset( 39 | self, 40 | seed: Optional[int] = None, 41 | options: Optional[Dict[str, Any]] = None, 42 | ): 43 | """Reset for FixedHorizonCartPole.""" 44 | observation, info = super().reset(seed=seed, options=options) 45 | return observation.astype(np.float32), info 46 | 47 | def step(self, action): 48 | """Step function for FixedHorizonCartPole.""" 49 | with warnings.catch_warnings(): 50 | # Filter out CartPoleEnv warning for calling step() beyond 51 | # terminated=True or truncated=True 52 | warnings.filterwarnings("ignore", ".*You are calling.*") 53 | super().step(action) 54 | 55 | self.state = list(self.state) 56 | x, _, theta, _ = self.state 57 | 58 | # Normalize theta to [-pi, pi] range. 
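        # For example (illustrative value): theta = 3.5 rad maps to
        # (3.5 + pi) % (2 * pi) - pi, approximately -2.78 rad, i.e. 3.5 - 2*pi,
        # so equivalent angles give identical observations within the Box
        # bounds set above.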
59 | theta = (theta + np.pi) % (2 * np.pi) - np.pi 60 | self.state[2] = theta 61 | 62 | state_ok = bool( 63 | abs(x) < self.x_threshold and abs(theta) < self.theta_threshold_radians, 64 | ) 65 | 66 | rew = 1.0 if state_ok else 0.0 67 | return np.array(self.state, dtype=np.float32), rew, False, False, {} 68 | 69 | 70 | def mountain_car(*args, **kwargs): 71 | """Fixed-length variant of MountainCar-v0. 72 | 73 | In the event of early episode completion (i.e., the car reaches the 74 | goal), we enter an absorbing state that repeats the final observation 75 | and returns reward 0. 76 | 77 | Done is always returned on timestep 200 only. 78 | """ 79 | env = gym.make("MountainCar-v0", *args, **kwargs) 80 | env = util.ObsCastWrapper(env, dtype=np.float32) 81 | env = util.AbsorbAfterDoneWrapper(env) 82 | return env 83 | -------------------------------------------------------------------------------- /src/seals/diagnostics/__init__.py: -------------------------------------------------------------------------------- 1 | """Simple diagnostic environments.""" 2 | 3 | import gymnasium as gym 4 | 5 | gym.register( 6 | id="seals/Branching-v0", 7 | entry_point="seals.diagnostics.branching:BranchingEnv", 8 | max_episode_steps=11, 9 | ) 10 | 11 | gym.register( 12 | id="seals/EarlyTermNeg-v0", 13 | entry_point="seals.diagnostics.early_term:EarlyTermNegEnv", 14 | max_episode_steps=10, 15 | ) 16 | 17 | gym.register( 18 | id="seals/EarlyTermPos-v0", 19 | entry_point="seals.diagnostics.early_term:EarlyTermPosEnv", 20 | max_episode_steps=10, 21 | ) 22 | 23 | gym.register( 24 | id="seals/InitShiftTrain-v0", 25 | entry_point="seals.diagnostics.init_shift:InitShiftTrainEnv", 26 | max_episode_steps=3, 27 | ) 28 | 29 | gym.register( 30 | id="seals/InitShiftTest-v0", 31 | entry_point="seals.diagnostics.init_shift:InitShiftTestEnv", 32 | max_episode_steps=3, 33 | ) 34 | 35 | gym.register( 36 | id="seals/LargestSum-v0", 37 | entry_point="seals.diagnostics.largest_sum:LargestSumEnv", 38 | max_episode_steps=1, 39 | ) 40 | 41 | gym.register( 42 | id="seals/NoisyObs-v0", 43 | entry_point="seals.diagnostics.noisy_obs:NoisyObsEnv", 44 | max_episode_steps=15, 45 | ) 46 | 47 | gym.register( 48 | id="seals/Parabola-v0", 49 | entry_point="seals.diagnostics.parabola:ParabolaEnv", 50 | max_episode_steps=20, 51 | ) 52 | 53 | gym.register( 54 | id="seals/ProcGoal-v0", 55 | entry_point="seals.diagnostics.proc_goal:ProcGoalEnv", 56 | max_episode_steps=20, 57 | ) 58 | 59 | gym.register( 60 | id="seals/RiskyPath-v0", 61 | entry_point="seals.diagnostics.risky_path:RiskyPathEnv", 62 | max_episode_steps=5, 63 | ) 64 | 65 | gym.register( 66 | id="seals/Sort-v0", 67 | entry_point="seals.diagnostics.sort:SortEnv", 68 | max_episode_steps=6, 69 | ) 70 | 71 | 72 | def register_cliff_world(suffix, kwargs): 73 | """Register a CliffWorld with the given suffix and keyword arguments.""" 74 | gym.register( 75 | f"seals/CliffWorld{suffix}-v0", 76 | entry_point="seals.diagnostics.cliff_world:CliffWorldEnv", 77 | kwargs=kwargs, 78 | ) 79 | 80 | 81 | def register_all_cliff_worlds(): 82 | """Register all CliffWorld environments.""" 83 | for width, height, horizon in [(7, 4, 9), (15, 6, 18), (100, 20, 110)]: 84 | for use_xy in [False, True]: 85 | use_xy_str = "XY" if use_xy else "" 86 | register_cliff_world( 87 | f"{width}x{height}{use_xy_str}", 88 | kwargs={ 89 | "width": width, 90 | "height": height, 91 | "use_xy_obs": use_xy, 92 | "horizon": horizon, 93 | }, 94 | ) 95 | 96 | 97 | register_all_cliff_worlds() 98 | 99 | # These parameter choices are somewhat 
arbitrary. 100 | # We anticipate most users will want to construct RandomTransitionEnv directly. 101 | gym.register( 102 | "seals/Random-v0", 103 | entry_point="seals.diagnostics.random_trans:RandomTransitionEnv", 104 | kwargs={ 105 | "n_states": 16, 106 | "n_actions": 3, 107 | "branch_factor": 2, 108 | "horizon": 20, 109 | "random_obs": True, 110 | "obs_dim": 5, 111 | "generator_seed": 42, 112 | }, 113 | ) 114 | -------------------------------------------------------------------------------- /src/seals/diagnostics/branching.py: -------------------------------------------------------------------------------- 1 | """Hard-exploration environment.""" 2 | 3 | import itertools 4 | 5 | import numpy as np 6 | 7 | from seals import base_envs 8 | 9 | 10 | class BranchingEnv(base_envs.TabularModelMDP): 11 | """Long branching environment requiring exploration. 12 | 13 | The agent must traverse a specific path of length L to reach a 14 | final goal, with B choices at each step. Wrong actions lead to 15 | dead-ends with zero reward. 16 | """ 17 | 18 | def __init__(self, branch_factor: int = 2, length: int = 10): 19 | """Construct environment. 20 | 21 | Args: 22 | branch_factor: number of actions at each state. 23 | length: path length from initial state to goal. 24 | """ 25 | nS = 1 + branch_factor * length 26 | nA = branch_factor 27 | 28 | def get_next(state: int, action: int) -> int: 29 | can_move = state % branch_factor == 0 and state != nS - 1 30 | return state + (action + 1) * can_move 31 | 32 | transition_matrix = np.zeros((nS, nA, nS)) 33 | for state, action in itertools.product(range(nS), range(nA)): 34 | transition_matrix[state, action, get_next(state, action)] = 1.0 35 | 36 | reward_matrix = np.zeros((nS,)) 37 | reward_matrix[-1] = 1.0 38 | 39 | super().__init__( 40 | transition_matrix=transition_matrix, 41 | reward_matrix=reward_matrix, 42 | ) 43 | -------------------------------------------------------------------------------- /src/seals/diagnostics/cliff_world.py: -------------------------------------------------------------------------------- 1 | """A cliff world that uses the TabularModelPOMDP.""" 2 | 3 | import numpy as np 4 | 5 | from seals.base_envs import TabularModelPOMDP 6 | 7 | 8 | class CliffWorldEnv(TabularModelPOMDP): 9 | """A grid world with a goal next to a cliff the agent may fall into. 10 | 11 | Illustration:: 12 | 13 | 0 1 2 3 4 5 6 7 8 9 14 | +-+-+-+-+-+-+-+-+-+-+ Wind: 15 | 0 |S|C|C|C|C|C|C|C|C|G| 16 | +-+-+-+-+-+-+-+-+-+-+ ^ ^ ^ 17 | 1 | | | | | | | | | | | | | | 18 | +-+-+-+-+-+-+-+-+-+-+ 19 | 2 | | | | | | | | | | | ^ ^ ^ 20 | +-+-+-+-+-+-+-+-+-+-+ | | | 21 | 22 | Aim is to get from S to G. The G square has reward +10, the C squares 23 | ("cliff") have reward -10, and all other squares have reward -1. Agent can 24 | move in all directions (except through walls), but there is 30% chance that 25 | they will be blown upwards by one more unit than intended due to wind. 26 | Optimal policy is to go out a bit and avoid the cliff, but still hit goal 27 | eventually. 28 | """ 29 | 30 | width: int 31 | height: int 32 | 33 | def __init__( 34 | self, 35 | *, 36 | width: int, 37 | height: int, 38 | horizon: int, 39 | use_xy_obs: bool, 40 | rew_default: int = -1, 41 | rew_goal: int = 10, 42 | rew_cliff: int = -10, 43 | fail_p: float = 0.3, 44 | ): 45 | """Builds CliffWorld with specified dimensions and reward.""" 46 | assert ( 47 | width >= 3 and height >= 2 48 | ), "degenerate grid world requested; is this a bug?" 
49 | self.width = width 50 | self.height = height 51 | succ_p = 1 - fail_p 52 | n_states = width * height 53 | O_mat = np.zeros( 54 | (n_states, 2 if use_xy_obs else n_states), 55 | dtype=np.float32, 56 | ) 57 | R_vec = np.zeros((n_states,)) 58 | T_mat = np.zeros((n_states, 4, n_states)) 59 | 60 | def to_id_clamp(row, col): 61 | """Convert (x,y) state to state ID, after clamp x & y to lie in grid.""" 62 | row = min(max(row, 0), height - 1) 63 | col = min(max(col, 0), width - 1) 64 | state_id = row * width + col 65 | assert 0 <= state_id < T_mat.shape[0] 66 | return state_id 67 | 68 | for row in range(height): 69 | for col in range(width): 70 | state_id = to_id_clamp(row, col) 71 | 72 | # start by computing reward 73 | if row > 0: 74 | r = rew_default # blank 75 | elif col == 0: 76 | r = rew_default # start 77 | elif col == width - 1: 78 | r = rew_goal # goal 79 | else: 80 | r = rew_cliff # cliff 81 | R_vec[state_id] = r 82 | 83 | # now compute observation 84 | if use_xy_obs: 85 | # (x, y) coordinate scaled to (0,1) 86 | O_mat[state_id, :] = [ 87 | float(col) / (width - 1), 88 | float(row) / (height - 1), 89 | ] 90 | else: 91 | # our observation matrix is just the identity; observation 92 | # is an indicator vector telling us exactly what state 93 | # we're in 94 | O_mat[state_id, state_id] = 1 95 | 96 | # finally, compute transition matrix entries for each of the 97 | # four actions 98 | for drow in [-1, 1]: 99 | for dcol in [-1, 1]: 100 | action_id = (drow + 1) + (dcol + 1) // 2 101 | target_state = to_id_clamp(row + drow, col + dcol) 102 | fail_state = to_id_clamp(row + drow - 1, col + dcol) 103 | T_mat[state_id, action_id, fail_state] += fail_p 104 | T_mat[state_id, action_id, target_state] += succ_p 105 | 106 | assert np.allclose(np.sum(T_mat, axis=-1), 1, rtol=1e-5), ( 107 | "un-normalised matrix %s" % O_mat 108 | ) 109 | super().__init__( 110 | transition_matrix=T_mat, 111 | observation_matrix=O_mat, 112 | reward_matrix=R_vec, 113 | horizon=horizon, 114 | initial_state_dist=None, 115 | ) 116 | 117 | def draw_value_vec(self, D: np.ndarray) -> None: 118 | """Use matplotlib to plot a vector of values for each state. 119 | 120 | The vector could represent things like reward, occupancy measure, etc. 121 | 122 | Args: 123 | D: the vector to plot. 124 | 125 | Raises: 126 | ImportError: if matplotlib is not installed. 127 | """ 128 | try: # pragma: no cover 129 | import matplotlib.pyplot as plt 130 | except ImportError as exc: # pragma: no cover 131 | raise ImportError( 132 | "matplotlib is not installed in your system, " 133 | "and is required for this function.", 134 | ) from exc 135 | 136 | grid = D.reshape(self.height, self.width) 137 | plt.imshow(grid) 138 | plt.gca().grid(False) 139 | -------------------------------------------------------------------------------- /src/seals/diagnostics/early_term.py: -------------------------------------------------------------------------------- 1 | """Environment checking for correctness under early termination.""" 2 | 3 | import functools 4 | 5 | import numpy as np 6 | 7 | from seals import base_envs 8 | 9 | 10 | class EarlyTerminationEnv(base_envs.TabularModelMDP): 11 | """Three-state MDP with early termination state. 12 | 13 | Many implementations of imitation learning algorithms incorrectly assign a 14 | value of zero to terminal states [1]. Depending on the sign of the learned 15 | reward function in non-terminal states, this can either bias the agent to 16 | end episodes early or prolong them as long as possible. 
This confounds 17 | evaluation as performance is spuriously high in tasks where the termination 18 | bias aligns with the task objective. These tasks attempt to detect this 19 | type of bias, and they are adapted from [1]. 20 | 21 | The environment is a 3-state MDP, in which the agent can either alternate 22 | between two initial states until reaching the time horizon, or they can 23 | move to a terminal state causing the episode to terminate early. 24 | 25 | [1] Kostrikov, Ilya, et al. "Discriminator-actor-critic: Addressing sample 26 | inefficiency and reward bias in adversarial imitation learning." arXiv 27 | preprint arXiv:1809.02925 (2018). 28 | """ 29 | 30 | def __init__(self, is_reward_positive: bool = True): 31 | """Construct environment. 32 | 33 | Args: 34 | is_reward_positive: whether rewards are positive or negative. 35 | """ 36 | nS = 3 37 | nA = 2 38 | 39 | transition_matrix = np.zeros((nS, nA, nS)) 40 | 41 | transition_matrix[0, :, 1] = 1.0 42 | 43 | transition_matrix[1, 0, 0] = 1.0 44 | transition_matrix[1, 1, 2] = 1.0 45 | 46 | transition_matrix[2, :, 2] = 1.0 47 | 48 | reward_sign = 2 * is_reward_positive - 1 49 | reward_matrix = reward_sign * np.ones((nS,), dtype=float) 50 | 51 | super().__init__( 52 | transition_matrix=transition_matrix, 53 | reward_matrix=reward_matrix, 54 | ) 55 | 56 | def terminal(self, state: base_envs.DiscreteSpaceInt, n_actions_taken: int) -> bool: 57 | """Returns True if (and only if) in state 2.""" 58 | return bool(state == 2) 59 | 60 | 61 | EarlyTermPosEnv = functools.partial(EarlyTerminationEnv, is_reward_positive=True) 62 | EarlyTermNegEnv = functools.partial(EarlyTerminationEnv, is_reward_positive=False) 63 | -------------------------------------------------------------------------------- /src/seals/diagnostics/init_shift.py: -------------------------------------------------------------------------------- 1 | """Environment with shift in initial state distribution.""" 2 | 3 | import functools 4 | import itertools 5 | 6 | import numpy as np 7 | 8 | from seals import base_envs 9 | 10 | 11 | class InitShiftEnv(base_envs.TabularModelMDP): 12 | """Tests for robustness to initial state shift. 13 | 14 | Many LfH algorithms learn from expert demonstrations. This can be 15 | problematic when the environment the demonstrations were gathered in 16 | differs even slightly from the learner's environment. 17 | 18 | This task illustrates this problem. We have a depth-2 full binary tree 19 | where the agent moves left or right until reaching a leaf. The expert 20 | starts at the root s_0, whereas the learner starts at the left branch s_1 21 | and so can only reach leaves s_3 and s_4. Reward is only given at the 22 | leaves. 23 | 24 | The expert always move to the highest reward leaf s_6, so any algorithm 25 | that relies on demonstrations will not know whether it is better to go to 26 | s_3 or s_4. By contrast, feedback such as preference comparison can 27 | disambiguate this case. 28 | """ 29 | 30 | def __init__(self, initial_state: base_envs.DiscreteSpaceInt): 31 | """Constructs environment. 32 | 33 | Args: 34 | initial_state: fixed initial state. 35 | 36 | Raises: 37 | ValueError: `initial_state` not in [0,6]. 
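        For concreteness, these numbers mirror the matrices built below::

            # transitions: next_state = 2 * state + 1 + action
            # leaf rewards: s_3 -> +1, s_4 -> -1, s_5 -> -1, s_6 -> +2
            # expert from s_0:  0 -> 2 -> 6 collects the +2 reward
            # learner from s_1: can only reach s_3 (+1) or s_4 (-1)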
38 | """ 39 | nS = 7 40 | nA = 2 41 | 42 | if not 0 <= initial_state < nS: 43 | raise ValueError(f"Initial state {initial_state} must lie in [0,{nS})") 44 | 45 | self._initial_state = initial_state 46 | 47 | non_leaves = np.arange(3) 48 | leaves = np.arange(3, 7) 49 | 50 | transition_matrix = np.zeros((nS, nA, nS)) 51 | 52 | for state, action in itertools.product(non_leaves, range(nA)): 53 | next_state = 2 * state + 1 + action 54 | transition_matrix[state, action, next_state] = 1.0 55 | 56 | transition_matrix[leaves, :, leaves] = 1.0 57 | 58 | reward_matrix = np.zeros((nS,)) 59 | reward_matrix[leaves] = [1, -1, -1, 2] 60 | 61 | super().__init__( 62 | transition_matrix=transition_matrix, 63 | reward_matrix=reward_matrix, 64 | ) 65 | 66 | def initial_state(self) -> base_envs.DiscreteSpaceInt: 67 | """Returns initial state defined in constructor.""" 68 | return self._initial_state 69 | 70 | 71 | InitShiftTrainEnv = functools.partial(InitShiftEnv, initial_state=0) 72 | InitShiftTestEnv = functools.partial(InitShiftEnv, initial_state=1) 73 | -------------------------------------------------------------------------------- /src/seals/diagnostics/largest_sum.py: -------------------------------------------------------------------------------- 1 | """Environment testing scalability to high-dimensional tasks.""" 2 | 3 | from gymnasium import spaces 4 | import numpy as np 5 | 6 | from seals import base_envs 7 | 8 | 9 | class LargestSumEnv(base_envs.ResettableMDP): 10 | """High-dimensional linear classification problem. 11 | 12 | This environment evaluates how algorithms scale with increasing 13 | dimensionality. It is a classification task with binary actions 14 | and uniformly sampled states s in [0, 1]**L. The agent is 15 | rewarded for taking action 1 if the sum of the first half x[:L//2] 16 | is greater than the sum of the second half x[L//2:], and otherwise 17 | is rewarded for taking action 0. 18 | """ 19 | 20 | def __init__(self, length: int = 50): 21 | """Build environment. 22 | 23 | Args: 24 | length: dimensionality of state space vector. 25 | """ 26 | super().__init__() 27 | self._length = length 28 | self.state_space = spaces.Box(low=0.0, high=1.0, shape=(length,)) 29 | self.action_space = spaces.Discrete(2) 30 | 31 | def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool: 32 | """Always returns True, since this task should have a 1-timestep horizon.""" 33 | return True 34 | 35 | def initial_state(self) -> np.ndarray: 36 | """Returns vector sampled uniformly in [0, 1]**L.""" 37 | init_state = self.np_random.random((self._length,)) 38 | return init_state.astype(self.observation_space.dtype) 39 | 40 | def reward(self, state: np.ndarray, act: int, next_state: np.ndarray) -> float: 41 | """Returns +1.0 reward when action is the right label and 0.0 otherwise.""" 42 | n = self._length 43 | label = np.sum(state[: n // 2]) > np.sum(state[n // 2 :]) 44 | return float(act == label) 45 | 46 | def transition(self, state: np.ndarray, action: int) -> np.ndarray: 47 | """Returns same state.""" 48 | return state 49 | -------------------------------------------------------------------------------- /src/seals/diagnostics/noisy_obs.py: -------------------------------------------------------------------------------- 1 | """Environment testing for robustness to noise.""" 2 | 3 | from gymnasium import spaces 4 | import numpy as np 5 | 6 | from seals import base_envs, util 7 | 8 | 9 | class NoisyObsEnv(base_envs.ResettablePOMDP): 10 | """Simple gridworld with noisy observations. 
11 | 
12 |     The agent randomly starts at one of the corners of an MxM grid and
13 |     tries to reach and stay at the center. The observation consists of the
14 |     agent's (x,y) coordinates and L "distractor" samples of Gaussian noise.
15 |     The challenge is to select the relevant features in the observations, and
16 |     not overfit to noise.
17 |     """
18 | 
19 |     def __init__(self, *, size: int = 5, noise_length: int = 20):
20 |         """Build environment.
21 | 
22 |         Args:
23 |             size: width and height of gridworld.
24 |             noise_length: dimension of noise vector in observation.
25 |         """
26 |         super().__init__()
27 | 
28 |         self._size = size
29 |         self._noise_length = noise_length
30 |         self._goal = np.array([self._size // 2, self._size // 2])
31 | 
32 |         obs_box_low = np.concatenate(
33 |             ([0, 0], np.full(self._noise_length, -np.inf)),  # type: ignore
34 |         )
35 |         obs_box_high = np.concatenate(
36 |             ([size - 1, size - 1], np.full(self._noise_length, np.inf)),  # type: ignore
37 |         )
38 | 
39 |         self.state_space = spaces.MultiDiscrete([size, size])
40 |         self.action_space = spaces.Discrete(5)
41 |         self.observation_space = spaces.Box(
42 |             low=obs_box_low,
43 |             high=obs_box_high,
44 |             dtype=np.float32,
45 |         )
46 | 
47 |     def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool:
48 |         """Always returns False."""
49 |         return False
50 | 
51 |     def initial_state(self) -> np.ndarray:
52 |         """Returns one of the grid's corners."""
53 |         n = self._size
54 |         corners = np.array([[0, 0], [n - 1, 0], [0, n - 1], [n - 1, n - 1]])
55 |         return corners[self.np_random.integers(4)]
56 | 
57 |     def reward(self, state: np.ndarray, action: int, new_state: np.ndarray) -> float:
58 |         """Returns +1.0 reward if state is the goal and 0.0 otherwise."""
59 |         return float(np.all(state == self._goal))
60 | 
61 |     def transition(self, state: np.ndarray, action: int) -> np.ndarray:
62 |         """Returns next state according to grid."""
63 |         return util.grid_transition_fn(
64 |             state,
65 |             action,
66 |             x_bounds=(0, self._size - 1),
67 |             y_bounds=(0, self._size - 1),
68 |         )
69 | 
70 |     def obs_from_state(self, state: np.ndarray) -> np.ndarray:
71 |         """Returns (x, y) concatenated with Gaussian noise."""
72 |         noise_vector = self.np_random.normal(size=self._noise_length)
73 |         return np.concatenate([state, noise_vector]).astype(np.float32)
74 | --------------------------------------------------------------------------------
/src/seals/diagnostics/parabola.py:
--------------------------------------------------------------------------------
1 | """Environment testing for generalization in continuous spaces."""
2 | 
3 | from gymnasium import spaces
4 | import numpy as np
5 | 
6 | from seals import base_envs
7 | 
8 | 
9 | class ParabolaEnv(base_envs.ResettableMDP):
10 |     """Environment to mimic parabola curves.
11 | 
12 |     This environment tests algorithms' ability to learn in continuous
13 |     action spaces, a challenge for Q-learning methods in particular.
14 |     The goal is to mimic the path of a parabola p(x) = A*x**2 + B*x +
15 |     C, where A, B and C are constants sampled uniformly from [-1, 1]
16 |     at the start of the episode.
17 |     """
18 | 
19 |     def __init__(self, x_step: float = 0.05, bounds: float = 5):
20 |         """Construct environment.
21 | 
22 |         Args:
23 |             x_step: x position difference between timesteps.
24 |             bounds: limits coordinates, useful for keeping rewards in
25 |                 a small bounded range.
26 | """ 27 | super().__init__() 28 | self._x_step = x_step 29 | self._bounds = bounds 30 | 31 | state_high = np.array([bounds, bounds, 1.0, 1.0, 1.0]) 32 | state_low = (-1) * state_high 33 | 34 | self.state_space = spaces.Box(low=state_low, high=state_high) 35 | self.action_space = spaces.Box(low=(-2) * bounds, high=2 * bounds, shape=()) 36 | 37 | def terminal(self, state: int, n_actions_taken: int) -> bool: 38 | """Always returns False.""" 39 | return False 40 | 41 | def initial_state(self) -> np.ndarray: 42 | """Get state by sampling a random parabola.""" 43 | a, b, c = -1 + 2 * self.np_random.random((3,)) 44 | x, y = 0, c 45 | return np.array([x, y, a, b, c], dtype=self.state_space.dtype) 46 | 47 | def reward(self, state: np.ndarray, action: int, new_state: np.ndarray) -> float: 48 | """Negative squared vertical distance from parabola.""" 49 | x, y, a, b, c = state 50 | target_y = a * x**2 + b * x + c 51 | return (-1) * (y - target_y) ** 2 52 | 53 | def transition(self, state: np.ndarray, action: int) -> np.ndarray: 54 | """Update x according to x_step and y according to action.""" 55 | x, y, a, b, c = state 56 | next_x = np.clip(x + self._x_step, -self._bounds, self._bounds) 57 | next_y = np.clip(y + action, -self._bounds, self._bounds) 58 | return np.array([next_x, next_y, a, b, c], dtype=self.state_space.dtype) 59 | -------------------------------------------------------------------------------- /src/seals/diagnostics/proc_goal.py: -------------------------------------------------------------------------------- 1 | """Large gridworld with random agent and goal position.""" 2 | 3 | from gymnasium import spaces 4 | import numpy as np 5 | 6 | from seals import base_envs, util 7 | 8 | 9 | class ProcGoalEnv(base_envs.ResettableMDP): 10 | """Large gridworld with random agent and goal position. 11 | 12 | In this task, the agent starts at a random position in a large 13 | grid, and must navigate to a goal randomly placed in a 14 | neighborhood around the agent. The observation is a 4-dimensional 15 | vector containing the (x,y) coordinates of the agent and the goal. 16 | The reward at each timestep is the negative Manhattan distance 17 | between the two positions. With a large enough grid, generalizing 18 | is necessary to achieve good performance, since most initial 19 | states will be unseen. 20 | """ 21 | 22 | def __init__(self, bounds: int = 100, distance: int = 10): 23 | """Constructs environment. 24 | 25 | Args: 26 | bounds: the absolute values of the coordinates of the initial agent 27 | position are bounded by `bounds`. Increasing the value might make 28 | generalization harder. 29 | distance: initial distance between agent and goal. 
30 | """ 31 | super().__init__() 32 | self._bounds = bounds 33 | self._distance = distance 34 | self.state_space = spaces.Box(low=-np.inf, high=np.inf, shape=(4,)) 35 | self.action_space = spaces.Discrete(5) 36 | 37 | def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool: 38 | """Always returns False.""" 39 | return False 40 | 41 | def initial_state(self) -> np.ndarray: 42 | """Samples random agent position and random goal.""" 43 | pos = self.np_random.integers(low=-self._bounds, high=self._bounds, size=(2,)) 44 | 45 | x_dist = self.np_random.integers(self._distance) 46 | y_dist = self._distance - x_dist 47 | random_signs = 2 * self.np_random.integers(2, size=2) - 1 48 | goal = pos + random_signs * (x_dist, y_dist) 49 | 50 | return np.concatenate([pos, goal]).astype(self.observation_space.dtype) 51 | 52 | def reward(self, state: np.ndarray, action: int, new_state: np.ndarray) -> float: 53 | """Negative L1 distance to goal.""" 54 | return (-1) * np.sum(np.abs(state[2:] - state[:2])) 55 | 56 | def transition(self, state: np.ndarray, action: int) -> np.ndarray: 57 | """Returns next state according to grid.""" 58 | pos, goal = state[:2], state[2:] 59 | next_pos = util.grid_transition_fn(pos, action) 60 | return np.concatenate([next_pos, goal]) 61 | -------------------------------------------------------------------------------- /src/seals/diagnostics/random_trans.py: -------------------------------------------------------------------------------- 1 | """A tabular model MDP with a random transition matrix.""" 2 | 3 | from typing import Optional 4 | 5 | import numpy as np 6 | 7 | from seals.base_envs import TabularModelPOMDP 8 | 9 | 10 | class RandomTransitionEnv(TabularModelPOMDP): 11 | """AN MDP with a random transition matrix. 12 | 13 | Random matrix is created by `make_random_trans_mat`. 14 | """ 15 | 16 | reward_weights: np.ndarray 17 | 18 | def __init__( 19 | self, 20 | *, 21 | n_states: int, 22 | n_actions: int, 23 | branch_factor: int, 24 | horizon: int, 25 | random_obs: bool, 26 | obs_dim: Optional[int] = None, 27 | generator_seed: Optional[int] = None, 28 | ): 29 | """Builds RandomTransitionEnv. 30 | 31 | Args: 32 | n_states: Number of states. 33 | n_actions: Number of actions. 34 | branch_factor: Maximum number of states that can be reached from 35 | each state-action pair. 36 | horizon: The horizon of the MDP, i.e. the episode length. 37 | random_obs: Whether to use random observations (True) 38 | or one-hot coded (False). 39 | obs_dim: The size of the observation vectors; must be `None` 40 | if `random_obs == False`. 41 | generator_seed: Seed for NumPy RNG. 42 | 43 | Raises: 44 | ValueError: If ``obs_dim`` is not ``None`` when ``random_obs == False``. 45 | ValueError: If ``obs_dim`` is ``None`` when ``random_obs == True``. 
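        A direct-construction sketch, using the same (fairly arbitrary)
        parameters as the `seals/Random-v0` registration::

            env = RandomTransitionEnv(
                n_states=16,
                n_actions=3,
                branch_factor=2,
                horizon=20,
                random_obs=True,
                obs_dim=5,
                generator_seed=42,
            )
            obs, info = env.reset(seed=0)  # obs: length-5 float32 vector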
46 | """ 47 | # this generator is ONLY for constructing the MDP, not for controlling 48 | # random outcomes during rollouts 49 | rand_gen = np.random.default_rng(generator_seed) 50 | 51 | if random_obs: 52 | if obs_dim is None: 53 | obs_dim = n_states 54 | else: 55 | if obs_dim is not None: 56 | raise ValueError("obs_dim must be None if random_obs is False") 57 | 58 | observation_matrix = self.make_obs_mat( 59 | n_states=n_states, 60 | is_random=random_obs, 61 | obs_dim=obs_dim, 62 | rand_state=rand_gen, 63 | ) 64 | transition_matrix = self.make_random_trans_mat( 65 | n_states=n_states, 66 | n_actions=n_actions, 67 | max_branch_factor=branch_factor, 68 | rand_state=rand_gen, 69 | ) 70 | initial_state_dist = self.make_random_state_dist( 71 | n_avail=branch_factor, 72 | n_states=n_states, 73 | rand_state=rand_gen, 74 | ) 75 | 76 | self.reward_weights = rand_gen.normal(size=(observation_matrix.shape[-1],)) 77 | reward_matrix = observation_matrix @ self.reward_weights 78 | super().__init__( 79 | transition_matrix=transition_matrix, 80 | observation_matrix=observation_matrix, 81 | reward_matrix=reward_matrix, 82 | horizon=horizon, 83 | initial_state_dist=initial_state_dist, 84 | ) 85 | 86 | @staticmethod 87 | def make_random_trans_mat( 88 | n_states, 89 | n_actions, 90 | max_branch_factor, 91 | rand_state: Optional[np.random.Generator] = None, 92 | ) -> np.ndarray: 93 | """Make a 'random' transition matrix. 94 | 95 | Each action goes to at least `max_branch_factor` other states from the 96 | current state, with transition distribution sampled from Dirichlet(1,1,…,1). 97 | 98 | This roughly apes the strategy from some old Lisp code that Rich Sutton 99 | left on the internet (http://incompleteideas.net/RandomMDPs.html), and is 100 | therefore a legitimate way to generate MDPs. 101 | 102 | Args: 103 | n_states: Number of states. 104 | n_actions: Number of actions. 105 | max_branch_factor: Maximum number of states that can be reached from 106 | each state-action pair. 107 | rand_state: NumPy random state. 108 | 109 | Returns: 110 | The transition matrix `mat`, where `mat[s,a,next_s]` gives the probability 111 | of transitioning to `next_s` after taking action `a` in state `s`. 112 | """ 113 | if rand_state is None: 114 | rand_state = np.random.default_rng() 115 | assert rand_state is not None 116 | out_mat = np.zeros((n_states, n_actions, n_states), dtype="float32") 117 | for start_state in range(n_states): 118 | for action in range(n_actions): 119 | # uniformly sample a number of successors in [1,max_branch_factor] 120 | # for this action 121 | successors = rand_state.integers(1, max_branch_factor + 1) 122 | next_states = rand_state.choice( 123 | n_states, 124 | size=(successors,), 125 | replace=False, 126 | ) 127 | # generate random vec in probability simplex 128 | next_vec = rand_state.dirichlet(np.ones((successors,))) 129 | next_vec = next_vec / np.sum(next_vec) 130 | out_mat[start_state, action, next_states] = next_vec 131 | return out_mat 132 | 133 | @staticmethod 134 | def make_random_state_dist( 135 | n_avail: int, 136 | n_states: int, 137 | rand_state: Optional[np.random.Generator] = None, 138 | ) -> np.ndarray: 139 | """Make a random initial state distribution over n_states. 140 | 141 | Args: 142 | n_avail: Number of states available to transition into. 143 | n_states: Total number of states. 144 | rand_state: NumPy random state. 145 | 146 | Returns: 147 | An initial state distribution that is zero at all but a uniformly random 148 | chosen subset of `n_avail` states. 
This subset of chosen states are set to a 149 | sample from the uniform distribution over the (n_avail-1) simplex, aka the 150 | flat Dirichlet distribution. 151 | 152 | Raises: 153 | ValueError: If `n_avail` is not in the range `(0, n_states]`. 154 | """ # noqa: DAR402 155 | if rand_state is None: 156 | rand_state = np.random.default_rng() 157 | assert rand_state is not None 158 | assert 0 < n_avail <= n_states 159 | init_dist = np.zeros((n_states,)) 160 | next_states = rand_state.choice(n_states, size=(n_avail,), replace=False) 161 | avail_state_dist = rand_state.dirichlet(np.ones((n_avail,))) 162 | init_dist[next_states] = avail_state_dist 163 | assert np.sum(init_dist > 0) == n_avail 164 | init_dist = init_dist / np.sum(init_dist) 165 | return init_dist 166 | 167 | @staticmethod 168 | def make_obs_mat( 169 | n_states: int, 170 | is_random: bool, 171 | obs_dim: Optional[int] = None, 172 | rand_state: Optional[np.random.Generator] = None, 173 | ) -> np.ndarray: 174 | """Makes an observation matrix with a single observation for each state. 175 | 176 | Args: 177 | n_states (int): Number of states. 178 | is_random (bool): Are observations drawn at random? 179 | If `True`, draw from random normal distribution. 180 | If `False`, are unique one-hot vectors for each state. 181 | obs_dim (int or NoneType): Must be `None` if `is_random == False`. 182 | Otherwise, this must be set to the size of the random vectors. 183 | rand_state (np.random.Generator): Random number generator. 184 | 185 | Returns: 186 | A matrix of shape `(n_states, obs_dim if is_random else n_states)`. 187 | 188 | Raises: 189 | ValueError: If ``is_random == False`` and ``obs_dim is not None``. 190 | """ 191 | if rand_state is None: 192 | rand_state = np.random.default_rng() 193 | assert rand_state is not None 194 | if is_random: 195 | if obs_dim is None: 196 | raise ValueError("obs_dim must be set if random_obs is True") 197 | obs_mat = rand_state.normal(0, 2, (n_states, obs_dim)) 198 | else: 199 | if obs_dim is not None: 200 | raise ValueError("obs_dim must be None if random_obs is False") 201 | obs_mat = np.identity(n_states) 202 | assert ( 203 | obs_mat.ndim == 2 204 | and obs_mat.shape[:1] == (n_states,) 205 | and obs_mat.shape[1] > 0 206 | ) 207 | return obs_mat.astype(np.float32) 208 | -------------------------------------------------------------------------------- /src/seals/diagnostics/risky_path.py: -------------------------------------------------------------------------------- 1 | """Environment testing for correct behavior under stochasticity.""" 2 | 3 | import numpy as np 4 | 5 | from seals import base_envs 6 | 7 | 8 | class RiskyPathEnv(base_envs.TabularModelMDP): 9 | """Environment with two paths to a goal: one safe and one risky. 10 | 11 | Many LfH algorithms are derived from Maximum Entropy Inverse Reinforcement 12 | Learning [1], which models the demonstrator as producing trajectories with 13 | probability p(tau) proportional to exp(R(tau)). This model implies that a 14 | demonstrator can "control" the environment well enough to follow any 15 | high-reward trajectory with high probability [2]. However, in stochastic 16 | environments, the agent cannot control the probability of each trajectory 17 | independently. This misspecification may lead to poor behavior. 18 | 19 | This task tests for this behavior. 
The agent starts at s_0 and can reach 20 | the goal s_2 (reward 1.0) by either taking the safe path s_0 to s_1 to s_2, 21 | or taking a risky action, which has equal chances of going to either s_3 22 | (reward -100.0) or s_2. The safe path has the highest expected return, but 23 | the risky action sometimes reaches the goal s_2 in fewer timesteps, leading 24 | to higher best-case return. Algorithms that fail to correctly handle 25 | stochastic dynamics may therefore wrongly believe the reward favors taking 26 | the risky path. 27 | 28 | [1] Ziebart, Brian D., et al. "Maximum entropy inverse reinforcement 29 | learning." AAAI. Vol. 8. 2008. 30 | [2] Ziebart, Brian D. "Modeling purposeful adaptive behavior with the 31 | principle of maximum causal entropy." (2010); PhD thesis, 32 | CMU-ML-10-110; page 105. 33 | """ 34 | 35 | def __init__(self): 36 | """Initialize environment.""" 37 | nS = 4 38 | nA = 2 39 | 40 | transition_matrix = np.zeros((nS, nA, nS)) 41 | transition_matrix[0, 0, 1] = 1.0 42 | transition_matrix[0, 1, [2, 3]] = 0.5 43 | 44 | transition_matrix[1, 0, 2] = 1.0 45 | transition_matrix[1, 1, 1] = 1.0 46 | 47 | transition_matrix[[2, 3], :, [2, 3]] = 1.0 48 | 49 | reward_matrix = np.array([0.0, 0.0, 1.0, -100.0]) 50 | 51 | super().__init__( 52 | transition_matrix=transition_matrix, 53 | reward_matrix=reward_matrix, 54 | ) 55 | -------------------------------------------------------------------------------- /src/seals/diagnostics/sort.py: -------------------------------------------------------------------------------- 1 | """Environment to sort a list using swap actions.""" 2 | 3 | from gymnasium import spaces 4 | import numpy as np 5 | 6 | from seals import base_envs 7 | 8 | 9 | class SortEnv(base_envs.ResettableMDP): 10 | """Environment to sort a list using swap actions.""" 11 | 12 | def __init__(self, length: int = 4): 13 | """Constructs environment. 14 | 15 | The initial state is a vector x sampled uniformly from 16 | [0,1]**L, with actions a = (i,j) swapping x_i and x_j. The 17 | reward is given according to the number of elements in the 18 | correct position. To perform well, the learned policy must 19 | compare elements, otherwise it will not generalize to all 20 | possible randomly selected initial states. 21 | """ 22 | self._length = length 23 | 24 | super().__init__() 25 | self.state_space = spaces.Box(low=0, high=1.0, shape=(length,)) 26 | self.action_space = spaces.MultiDiscrete([length, length]) 27 | 28 | def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool: 29 | """Always returns False.""" 30 | return False 31 | 32 | def initial_state(self): 33 | """Sample random vector uniformly in [0, 1]**L.""" 34 | sample = self.np_random.random(size=self._length) 35 | return sample.astype(self.state_space.dtype) 36 | 37 | def reward( 38 | self, 39 | state: np.ndarray, 40 | action: np.ndarray, 41 | new_state: np.ndarray, 42 | ) -> float: 43 | """Rewards fully sorted lists, and new correct positions.""" 44 | del action 45 | # This is not meant to be a potential shaping in the formal sense, 46 | # as it changes the trajectory returns (since we do not return 47 | # a fixed-potential state at termination). 
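        # Worked example (illustrative numbers): state [0.3, 0.1, 0.2, 0.4] has
        # 1 element in its sorted position; swapping indices 0 and 2 gives
        # [0.2, 0.1, 0.3, 0.4] with 2 in place, so potential_diff = 1 and the
        # step reward is 0 + 1 = 1 (the list is not yet fully sorted).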
48 | num_correct = self._num_correct_positions(state) 49 | new_num_correct = self._num_correct_positions(new_state) 50 | potential_diff = new_num_correct - num_correct 51 | 52 | return float(self._is_sorted(new_state)) + potential_diff 53 | 54 | def transition(self, state: np.ndarray, action: np.ndarray) -> np.ndarray: 55 | """Action a = (i, j) swaps elements in positions i and j.""" 56 | new_state = state.copy() 57 | i, j = action 58 | new_state[[i, j]] = new_state[[j, i]] 59 | return new_state 60 | 61 | def _is_sorted(self, arr: np.ndarray) -> bool: 62 | return list(arr) == sorted(arr) 63 | 64 | def _num_correct_positions(self, arr: np.ndarray) -> int: 65 | return np.sum(arr == sorted(arr)) 66 | -------------------------------------------------------------------------------- /src/seals/mujoco.py: -------------------------------------------------------------------------------- 1 | """Adaptation of MuJoCo environments for specification learning algorithms.""" 2 | 3 | import functools 4 | 5 | from gymnasium.envs.mujoco import ( 6 | ant_v4, 7 | half_cheetah_v4, 8 | hopper_v4, 9 | humanoid_v4, 10 | swimmer_v4, 11 | walker2d_v4, 12 | ) 13 | 14 | 15 | def _include_position_in_observation(cls): 16 | cls.__init__ = functools.partialmethod( 17 | cls.__init__, 18 | exclude_current_positions_from_observation=False, 19 | ) 20 | return cls 21 | 22 | 23 | def _no_early_termination(cls): 24 | cls.__init__ = functools.partialmethod(cls.__init__, terminate_when_unhealthy=False) 25 | return cls 26 | 27 | 28 | @_include_position_in_observation 29 | @_no_early_termination 30 | class AntEnv(ant_v4.AntEnv): 31 | """Ant with position observation and no early termination.""" 32 | 33 | 34 | @_include_position_in_observation 35 | class HalfCheetahEnv(half_cheetah_v4.HalfCheetahEnv): 36 | """HalfCheetah with position observation. Naturally does not terminate early.""" 37 | 38 | 39 | @_include_position_in_observation 40 | @_no_early_termination 41 | class HopperEnv(hopper_v4.HopperEnv): 42 | """Hopper with position observation and no early termination.""" 43 | 44 | 45 | @_include_position_in_observation 46 | @_no_early_termination 47 | class HumanoidEnv(humanoid_v4.HumanoidEnv): 48 | """Humanoid with position observation and no early termination.""" 49 | 50 | 51 | @_include_position_in_observation 52 | class SwimmerEnv(swimmer_v4.SwimmerEnv): 53 | """Swimmer with position observation. Naturally does not terminate early.""" 54 | 55 | 56 | @_include_position_in_observation 57 | @_no_early_termination 58 | class Walker2dEnv(walker2d_v4.Walker2dEnv): 59 | """Walker2d with position observation and no early termination.""" 60 | -------------------------------------------------------------------------------- /src/seals/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HumanCompatibleAI/seals/90f68cfbb694d687f8f2cb05bdf3aa85714bbf6c/src/seals/py.typed -------------------------------------------------------------------------------- /src/seals/testing/__init__.py: -------------------------------------------------------------------------------- 1 | """Helper methods for unit tests. May also be useful for users of imitation.""" 2 | -------------------------------------------------------------------------------- /src/seals/testing/envs.py: -------------------------------------------------------------------------------- 1 | """Helper methods for tests of custom Gym environments. 2 | 3 | This is used in our test suite in `tests/test_envs.py`. 
It is also used in sister 4 | projects such as `imitation`, and may be useful in other codebases. 5 | """ 6 | 7 | import re 8 | from typing import ( 9 | Any, 10 | Callable, 11 | Iterable, 12 | Iterator, 13 | List, 14 | Mapping, 15 | Sequence, 16 | SupportsFloat, 17 | Tuple, 18 | ) 19 | 20 | import gymnasium as gym 21 | import numpy as np 22 | 23 | Step = Tuple[Any, SupportsFloat, bool, bool, Mapping[str, Any]] 24 | Rollout = Sequence[Step] 25 | """A sequence of 5-tuples (obs, rew, terminated, truncated, info) as returned by 26 | `get_rollout`.""" 27 | 28 | 29 | def make_env_fixture( 30 | skip_fn: Callable[[str], None], 31 | ) -> Callable[[str], Iterator[gym.Env]]: 32 | """Creates a fixture function, calling `skip_fn` when dependencies are missing. 33 | 34 | For example, in `pytest`, one would use:: 35 | 36 | env = pytest.fixture(make_env_fixture(skip_fn=pytest.skip)) 37 | 38 | Then any method with an `env` parameter will receive the created environment, with 39 | the `env_name` parameter automatically passed to the fixture. 40 | 41 | In `unittest`, one would use:: 42 | 43 | def skip_fn(msg): 44 | raise unittest.SkipTest(msg) 45 | 46 | make_env = contextlib.contextmanager(make_env_fixture(skip_fn=skip_fn)) 47 | 48 | And then call `with make_env(env_name) as env:` to create environments. 49 | 50 | Args: 51 | skip_fn: the function called when a dependency is missing to skip the test. 52 | 53 | Returns: 54 | A method to create Gym environments given their name. 55 | """ 56 | 57 | def f(env_name: str) -> Iterator[gym.Env]: 58 | """Create environment `env_name`. 59 | 60 | Args: 61 | env_name: The name of the environment in the Gym registry. 62 | 63 | Yields: 64 | The created environment. 65 | 66 | Raises: 67 | gym.error.DependencyNotInstalled: if a dependency is missing 68 | other than MuJoCo (for MuJoCo, the test is instead skipped). 69 | """ 70 | env = None 71 | try: 72 | env = gym.make(env_name) 73 | yield env 74 | except gym.error.DependencyNotInstalled as e: # pragma: no cover 75 | if e.args[0].find("mujoco_py") != -1: 76 | skip_fn("Requires `mujoco_py`, which isn't installed.") 77 | else: 78 | raise 79 | finally: 80 | if env is not None: 81 | env.close() 82 | 83 | return f 84 | 85 | 86 | def matches_list(env_name: str, patterns: Iterable[str]) -> bool: 87 | """Returns True if `env_name` matches any of the patterns in `patterns`.""" 88 | return any(re.match(env_pattern, env_name) for env_pattern in patterns) 89 | 90 | 91 | def get_rollout(env: gym.Env, actions: Iterable[Any]) -> Rollout: 92 | """Performs a sequence of actions `actions` in `env`. 93 | 94 | Args: 95 | env: the environment to roll out in. 96 | actions: the actions to perform. 97 | 98 | Returns: 99 | A sequence of 5-tuples (obs, rew, terminated, truncated, info). 100 | """ 101 | obs, info = env.reset() 102 | ret: List[Step] = [(obs, 0, False, False, info)] 103 | for act in actions: 104 | ret.append(env.step(act)) 105 | return ret 106 | 107 | 108 | def assert_equal_rollout(rollout_a: Rollout, rollout_b: Rollout) -> None: 109 | """Checks rollouts for equality. 110 | 111 | Raises: 112 | AssertionError if they are not equal. 
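
    Example (a minimal sketch mirroring `test_seed` below, assuming `env` and
    `actions` are already defined)::

        env.reset(seed=42)
        rollout_a = get_rollout(env, actions)
        env.reset(seed=42)
        rollout_b = get_rollout(env, actions)
        assert_equal_rollout(rollout_a, rollout_b)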
113 | """ 114 | for step_a, step_b in zip(rollout_a, rollout_b): 115 | ob_a, rew_a, terminated_a, truncated_a, info_a = step_a 116 | ob_b, rew_b, terminated_b, truncated_b, info_b = step_b 117 | np.testing.assert_equal(ob_a, ob_b) 118 | assert rew_a == rew_b 119 | assert terminated_a == terminated_b 120 | assert truncated_a == truncated_b 121 | np.testing.assert_equal(info_a, info_b) 122 | 123 | 124 | def has_same_observations(rollout_a: Rollout, rollout_b: Rollout) -> bool: 125 | """True if `rollout_a` and `rollout_b` have the same observations.""" 126 | obs_list_a = [step[0] for step in rollout_a] 127 | obs_list_b = [step[0] for step in rollout_b] 128 | if len(obs_list_a) != len(obs_list_b): # pragma: no cover 129 | return False 130 | for obs_a, obs_b in zip(obs_list_a, obs_list_b): 131 | if isinstance(obs_a, Mapping): # pragma: no cover 132 | if obs_a.keys() != obs_b.keys(): 133 | return False 134 | obs_a = list(obs_a.values()) 135 | obs_b = list(obs_b.values()) 136 | else: 137 | obs_a, obs_b = [obs_a], [obs_b] 138 | if any([np.any(x != y) for x, y in zip(obs_a, obs_b)]): 139 | return False 140 | return True 141 | 142 | 143 | def test_seed( 144 | env: gym.Env, 145 | env_name: str, 146 | deterministic_envs: Iterable[str], 147 | rollout_len: int = 10, 148 | num_seeds: int = 100, 149 | ) -> None: 150 | """Tests environment seeding. 151 | 152 | If non-deterministic, different seeds should produce different transitions. 153 | If deterministic, should be invariant to seed. 154 | 155 | Raises: 156 | AssertionError if test fails. 157 | """ 158 | env.action_space.seed(0) 159 | actions = [env.action_space.sample() for _ in range(rollout_len)] 160 | # With the same seed, should always get the same result 161 | env.reset(seed=42) 162 | rollout_a = get_rollout(env, actions) 163 | 164 | env.reset(seed=42) 165 | rollout_b = get_rollout(env, actions) 166 | 167 | assert_equal_rollout(rollout_a, rollout_b) 168 | 169 | # For most non-deterministic environments, if we try enough seeds we should 170 | # eventually get a different result. For deterministic environments, all 171 | # seeds should produce the same starting state. 
172 | def different_seeds_same_rollout(seed1, seed2): 173 | new_actions = [env.action_space.sample() for _ in range(rollout_len)] 174 | env.reset(seed=seed1) 175 | new_rollout_1 = get_rollout(env, new_actions) 176 | env.reset(seed=seed2) 177 | new_rollout_2 = get_rollout(env, new_actions) 178 | return has_same_observations(new_rollout_1, new_rollout_2) 179 | 180 | is_deterministic = matches_list(env_name, deterministic_envs) 181 | same_obs = all( 182 | different_seeds_same_rollout(seed, seed + 1) for seed in range(num_seeds) 183 | ) 184 | assert same_obs == is_deterministic 185 | 186 | 187 | def _check_obs(obs: np.ndarray, obs_space: gym.Space) -> None: 188 | """Check obs is consistent with obs_space.""" 189 | if obs_space.shape: 190 | assert obs.shape == obs_space.shape 191 | assert obs.dtype == obs_space.dtype 192 | assert obs in obs_space 193 | 194 | 195 | def _sample_and_check(env: gym.Env, obs_space: gym.Space) -> bool: 196 | """Sample from env and check return value is of valid type.""" 197 | act = env.action_space.sample() 198 | obs, rew, terminated, truncated, info = env.step(act) 199 | _check_obs(obs, obs_space) 200 | assert isinstance(rew, float) 201 | assert isinstance(terminated, bool) 202 | assert isinstance(truncated, bool) 203 | assert isinstance(info, dict) 204 | return terminated or truncated 205 | 206 | 207 | def is_mujoco_env(env: gym.Env) -> bool: 208 | """True if `env` is a MuJoCo environment.""" 209 | return hasattr(env, "sim") and hasattr(env, "model") 210 | 211 | 212 | def test_rollout_schema( 213 | env: gym.Env, 214 | steps_after_terminated: int = 10, 215 | max_steps: int = 10_000, 216 | check_episode_ends: bool = True, 217 | ) -> None: 218 | """Check custom environments have correct types on `step` and `reset`. 219 | 220 | Args: 221 | env: The environment to test. 222 | steps_after_terminated: The number of steps to take after `terminated` is True, 223 | the nominal episode termination. This is an abuse of the Gym API, 224 | but we would like the environments to handle this case gracefully. 225 | max_steps: Test fails if we do not get `terminated` after this many timesteps. 226 | check_episode_ends: Check that episode ends after `max_steps` steps, and that 227 | steps taken after `terminated` is true are of the correct type. 228 | 229 | Raises: 230 | AssertionError if test fails. 231 | """ 232 | obs_space = env.observation_space 233 | obs, _ = env.reset(seed=0) 234 | _check_obs(obs, obs_space) 235 | 236 | done = False 237 | for _ in range(max_steps): 238 | done = _sample_and_check(env, obs_space) 239 | if done: 240 | break 241 | 242 | if check_episode_ends: 243 | assert done, "did not get to end of episode" 244 | 245 | for _ in range(steps_after_terminated): 246 | _sample_and_check(env, obs_space) 247 | 248 | 249 | def test_premature_step(env: gym.Env, skip_fn, raises_fn) -> None: 250 | """Test that you must call reset() before calling step(). 251 | 252 | Example usage in pytest: 253 | test_premature_step(env, skip_fn=pytest.skip, raises_fn=pytest.raises) 254 | 255 | Args: 256 | env: The environment to test. 257 | skip_fn: called when the environment is incompatible with the test. 258 | raises_fn: Context manager to check exception is thrown. 259 | 260 | Raises: 261 | AssertionError if test fails. 
262 | """ 263 | if is_mujoco_env(env): # pragma: no cover 264 | # We can't use isinstance since importing mujoco_py will fail on 265 | # machines without MuJoCo installed 266 | skip_fn("MuJoCo environments cannot perform this check.") 267 | 268 | act = env.action_space.sample() 269 | with raises_fn(Exception): # need to call env.reset() first 270 | env.step(act) 271 | 272 | 273 | class CountingEnv(gym.Env): 274 | """At timestep `t` of each episode, has `t == obs == reward / 10`. 275 | 276 | Episodes finish after `episode_length` calls to `step()`, or equivalently 277 | `episode_length` actions. For example, if we have `episode_length=5`, 278 | then an episode has the following observations and rewards: 279 | 280 | ``` 281 | obs = [0, 1, 2, 3, 4, 5] 282 | rews = [10, 20, 30, 40, 50] 283 | ``` 284 | """ 285 | 286 | def __init__(self, episode_length: int = 5): 287 | """Initialize a CountingEnv. 288 | 289 | Params: 290 | episode_length: The number of actions before each episode ends. 291 | """ 292 | assert episode_length >= 1 293 | self.observation_space = gym.spaces.Box(low=0, high=np.inf, shape=()) 294 | self.action_space = gym.spaces.Box(low=0, high=np.inf, shape=()) 295 | self.episode_length = episode_length 296 | self.timestep = None 297 | 298 | def reset(self, seed=None, options={}): 299 | """Reset method for CountingEnv.""" 300 | t, self.timestep = 0, 1 301 | return np.array(t, dtype=self.observation_space.dtype), {} 302 | 303 | def step(self, action): 304 | """Step method for CountingEnv.""" 305 | if self.timestep is None: # pragma: no cover 306 | raise RuntimeError("Need to reset before first step().") 307 | if np.array(action) not in self.action_space: # pragma: no cover 308 | raise ValueError(f"Invalid action {action}") 309 | if self.timestep > self.episode_length: # pragma: no cover 310 | raise ValueError("Should reset env. Episode is over.") 311 | 312 | t, self.timestep = self.timestep, self.timestep + 1 313 | obs = np.array(t, dtype=self.observation_space.dtype) 314 | rew = t * 10.0 315 | terminated = t == self.episode_length 316 | return obs, rew, terminated, False, {} 317 | -------------------------------------------------------------------------------- /src/seals/util.py: -------------------------------------------------------------------------------- 1 | """Miscellaneous utilities.""" 2 | 3 | from dataclasses import dataclass 4 | from typing import ( 5 | Any, 6 | Dict, 7 | Generic, 8 | List, 9 | Optional, 10 | Sequence, 11 | SupportsFloat, 12 | Tuple, 13 | TypeVar, 14 | Union, 15 | ) 16 | 17 | import gymnasium as gym 18 | import numpy as np 19 | 20 | # Note: we redefine the type vars from gymnasium.core here, because pytype does not 21 | # recognize them as valid type vars if we import them from gymnasium.core. 22 | WrapperObsType = TypeVar("WrapperObsType") 23 | WrapperActType = TypeVar("WrapperActType") 24 | ObsType = TypeVar("ObsType") 25 | ActType = TypeVar("ActType") 26 | 27 | 28 | class AutoResetWrapper( 29 | gym.Wrapper, 30 | Generic[WrapperObsType, WrapperActType, ObsType, ActType], 31 | ): 32 | """Hides terminated truncated and auto-resets at the end of each episode. 33 | 34 | Depending on the flag 'discard_terminal_observation', either discards the terminal 35 | observation or pads with an additional 'reset transition'. The former is the default 36 | behavior. 
37 | In the latter case, the action taken during the 'reset transition' will not have an
38 | effect, the reward will be constant (set by the wrapper argument `reset_reward`,
39 | which has default value 0.0), and `info` will be an empty dictionary.
40 | """
41 |
42 | def __init__(self, env, discard_terminal_observation=True, reset_reward=0.0):
43 | """Builds the wrapper.
44 |
45 | Args:
46 | env: The environment to wrap.
47 | discard_terminal_observation: Defaults to True. If True, the terminal
48 | observation is discarded and the environment is reset immediately. The
49 | returned observation will then be the start of the next episode. The
50 | overridden observation is stored in `info["terminal_observation"]`.
51 | If False, the terminal observation is returned and the environment is
52 | reset in the next step.
53 | reset_reward: The reward to return for the reset transition. Defaults to
54 | 0.0.
55 | """
56 | super().__init__(env)
57 | self.discard_terminal_observation = discard_terminal_observation
58 | self.reset_reward = reset_reward
59 | self.previous_done = False # Whether the previous step returned done=True.
60 |
61 | def step(
62 | self,
63 | action: WrapperActType,
64 | ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
65 | """When terminated or truncated, resets the environment.
66 |
67 | Always returns False for terminated and truncated.
68 |
69 | Depending on whether we are discarding the terminal observation,
70 | either resets the environment and discards the terminal observation,
71 | or returns the terminal observation, and then uses the next step to reset the
72 | environment, after which steps will be performed as normal.
73 | """
74 | if self.discard_terminal_observation:
75 | return self._step_discard(action)
76 | else:
77 | return self._step_pad(action)
78 |
79 | def _step_pad(
80 | self,
81 | action: WrapperActType,
82 | ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
83 | """When terminated or truncated, resets the environment.
84 |
85 | Always returns False for terminated and truncated.
86 |
87 | The agent will then usually be asked to perform an action based on
88 | the terminal observation. In the next step, this final action will be ignored
89 | to instead reset the environment and return the initial observation of the new
90 | episode.
91 |
92 | Some potential caveats:
93 | - The underlying environment will perform fewer steps than the wrapped
94 | environment.
95 | - The number of steps the agent performs and the number of steps recorded in the
96 | underlying environment will not match, which could cause issues if these are
97 | assumed to be the same.
98 | """
99 | if self.previous_done:
100 | self.previous_done = False
101 | reset_obs, reset_info_dict = self.env.reset()
102 | info = {"reset_info_dict": reset_info_dict}
103 | # This transition will only reset the environment; the action is ignored.
104 | return reset_obs, self.reset_reward, False, False, info
105 |
106 | obs, rew, terminated, truncated, info = self.env.step(action)
107 | self.previous_done = terminated or truncated
108 | return obs, rew, False, False, info
109 |
110 | def _step_discard(
111 | self,
112 | action: WrapperActType,
113 | ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
114 | """When terminated or truncated, return False for both and automatically reset.
115 |
116 | When an automatic reset happens, the observation from reset is returned,
117 | and the overridden observation is stored in
118 | `info["terminal_observation"]`.
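
    For example, on the step where the underlying episode ends, the returned
    `info` dict contains both of the following keys (names as used in the code
    below)::

        info["terminal_observation"]  # final observation of the finished episode
        info["reset_info_dict"]       # info dict returned by the automatic reset()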
119 | """ 120 | obs, rew, terminated, truncated, info = self.env.step(action) 121 | if terminated or truncated: 122 | info["terminal_observation"] = obs 123 | obs, reset_info_dict = self.env.reset() 124 | info["reset_info_dict"] = reset_info_dict 125 | return obs, rew, False, False, info 126 | 127 | 128 | @dataclass 129 | class BoxRegion: 130 | """A rectangular region dataclass used by MaskScoreWrapper.""" 131 | 132 | x: Tuple 133 | y: Tuple 134 | 135 | 136 | MaskedRegionSpecifier = List[BoxRegion] 137 | 138 | 139 | class MaskScoreWrapper(gym.ObservationWrapper): 140 | """Mask a list of box-shaped regions in the observation to hide reward info. 141 | 142 | Intended for environments whose observations are raw pixels (like Atari 143 | environments). Used to mask regions of the observation that include information 144 | that could be used to infer the reward, like the game score or enemy ship count. 145 | """ 146 | 147 | def __init__( 148 | self, 149 | env: gym.Env, 150 | score_regions: MaskedRegionSpecifier, 151 | fill_value: Union[float, Sequence[float]] = 0, 152 | ): 153 | """Builds MaskScoreWrapper. 154 | 155 | Args: 156 | env: The environment to wrap. 157 | score_regions: A list of box-shaped regions to mask, each denoted by 158 | a dictionary `{"x": (x0, x1), "y": (y0, y1)}`, where `x0 < x1` 159 | and `y0 < y1`. 160 | fill_value: The fill_value for the masked region. By default is black. 161 | Can support RGB colors by being a sequence of values [r, g, b]. 162 | 163 | Raises: 164 | ValueError: If a score region does not conform to the spec. 165 | """ 166 | super().__init__(env) 167 | self.fill_value = np.array(fill_value, env.observation_space.dtype) 168 | 169 | if env.observation_space.shape is None: 170 | raise ValueError("Observation space must have a shape.") # pragma: no cover 171 | self.mask = np.ones(env.observation_space.shape, dtype=bool) 172 | for r in score_regions: 173 | if r.x[0] >= r.x[1] or r.y[0] >= r.y[1]: 174 | raise ValueError('Invalid region: "x" and "y" must be increasing.') 175 | self.mask[r.x[0] : r.x[1], r.y[0] : r.y[1]] = 0 176 | 177 | def observation(self, obs): 178 | """Returns observation with masked regions filled with `fill_value`.""" 179 | return np.where(self.mask, obs, self.fill_value) 180 | 181 | 182 | class ObsCastWrapper(gym.ObservationWrapper): 183 | """Cast observations to specified dtype. 184 | 185 | Some external environments return observations of a different type than the 186 | declared observation space. Where possible, this should be fixed upstream, 187 | but casting can be a viable workaround -- especially when the returned 188 | observations are higher resolution than the observation space. 189 | """ 190 | 191 | def __init__(self, env: gym.Env, dtype: np.dtype): 192 | """Builds ObsCastWrapper. 193 | 194 | Args: 195 | env: the environment to wrap. 196 | dtype: the dtype to cast observations to. 197 | """ 198 | super().__init__(env) 199 | self.dtype = dtype 200 | 201 | def observation(self, obs): 202 | """Returns observation cast to self.dtype.""" 203 | return obs.astype(self.dtype) 204 | 205 | 206 | class AbsorbAfterDoneWrapper(gym.Wrapper): 207 | """Transition into absorbing state instead of episode termination. 208 | 209 | When the environment being wrapped returns `terminated=True` or `truncated=True`, 210 | we return an absorbing observation. 211 | This wrapper always returns `terminated=False` and `truncated=False`. 212 | 213 | A convenient way to add absorbing states to environments like MountainCar. 
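
    For example, a minimal sketch (the environment id is only illustrative)::

        import gymnasium as gym

        env = AbsorbAfterDoneWrapper(gym.make("MountainCar-v0"), absorb_reward=0.0)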
214 | """
215 |
216 | def __init__(
217 | self,
218 | env: gym.Env,
219 | absorb_reward: float = 0.0,
220 | absorb_obs: Optional[np.ndarray] = None,
221 | ):
222 | """Initialize AbsorbAfterDoneWrapper.
223 |
224 | Args:
225 | env: The wrapped Env.
226 | absorb_reward: The reward returned at the absorb state.
227 | absorb_obs: The observation returned at the absorb state. If None, then
228 | the final observation of the episode is repeated.
229 | """
230 | super().__init__(env)
231 | self.absorb_reward = absorb_reward
232 | self.absorb_obs_default = absorb_obs
233 | self.absorb_obs_this_episode = None
234 | self.at_absorb_state = None
235 |
236 | def reset(self, *args, **kwargs):
237 | """Reset the environment."""
238 | self.at_absorb_state = False
239 | self.absorb_obs_this_episode = None
240 | return self.env.reset(*args, **kwargs)
241 |
242 | def step(self, action):
243 | """Advance the environment by one step.
244 |
245 | This wrapped `step()` always returns terminated=False and truncated=False.
246 |
247 | After the first time either terminated or truncated is returned by the
248 | underlying Env, we enter an artificial absorb state.
249 |
250 | In this artificial absorb state, we stop calling
251 | `self.env.step(action)` (i.e. the `action` argument is entirely ignored) and
252 | we return fixed values for obs, rew, terminated, truncated, and info.
253 | The values of `obs` and `rew` depend on initialization arguments.
254 | `info` is always an empty dictionary.
255 | """
256 | if not self.at_absorb_state:
257 | obs, rew, terminated, truncated, info = self.env.step(action)
258 | if terminated or truncated:
259 | # Initialize the artificial absorb state, which we will repeatedly use
260 | # starting on the next call to `step()`.
261 | self.at_absorb_state = True
262 |
263 | if self.absorb_obs_default is None:
264 | self.absorb_obs_this_episode = obs
265 | else:
266 | self.absorb_obs_this_episode = self.absorb_obs_default
267 | else:
268 | assert self.absorb_obs_this_episode is not None
269 | assert self.absorb_reward is not None
270 | obs = self.absorb_obs_this_episode
271 | rew = self.absorb_reward
272 | info = {}
273 |
274 | return obs, rew, False, False, info
275 |
276 |
277 | def get_gym_max_episode_steps(env_name: str) -> Optional[int]:
278 | """Get the `max_episode_steps` attribute associated with a gym Spec."""
279 | return gym.spec(env_name).max_episode_steps
280 |
281 |
282 | def sample_distribution(
283 | p: np.ndarray,
284 | random: np.random.Generator,
285 | ) -> int:
286 | """Samples an integer with probabilities given by p."""
287 | return random.choice(np.arange(len(p)), p=p)
288 |
289 |
290 | def one_hot_encoding(pos: int, size: int) -> np.ndarray:
291 | """Returns a 1-D one-hot encoding of a given position and size."""
292 | return np.eye(size)[pos]
293 |
294 |
295 | def grid_transition_fn(
296 | state: np.ndarray,
297 | action: int,
298 | x_bounds: Tuple[float, float] = (-np.inf, np.inf),
299 | y_bounds: Tuple[float, float] = (-np.inf, np.inf),
300 | ):
301 | """Returns transition of a deterministic gridworld.
302 |
303 | The agent is bounded in the region limited by x_bounds and y_bounds,
304 | with both ends inclusive.
305 |
306 | (0, 0) is interpreted to be the top-left corner.
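
    For example, with the default (unbounded) limits and the action encoding
    listed below::

        grid_transition_fn(np.array([0, 0]), action=1)  # Down -> array([0, 1])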
307 | 308 | Actions: 309 | 0: Right 310 | 1: Down 311 | 2: Left 312 | 3: Up 313 | 4: Stay put 314 | """ 315 | dirs = [ 316 | (1, 0), 317 | (0, 1), 318 | (-1, 0), 319 | (0, -1), 320 | (0, 0), 321 | ] 322 | 323 | x, y = state 324 | dx, dy = dirs[action] 325 | 326 | next_x = np.clip(x + dx, *x_bounds) 327 | next_y = np.clip(y + dy, *y_bounds) 328 | next_state = np.array([next_x, next_y], dtype=state.dtype) 329 | 330 | return next_state 331 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | """Configuration for pytest.""" 2 | 3 | import pytest 4 | 5 | pytest.register_assert_rewrite("seals.testing") 6 | 7 | 8 | def pytest_addoption(parser): 9 | """Add --expensive option.""" 10 | parser.addoption( 11 | "--expensive", 12 | action="store_true", 13 | dest="expensive", 14 | default=False, 15 | help="enable expensive tests", 16 | ) 17 | 18 | 19 | def pytest_collection_modifyitems(config, items): 20 | """Make expensive tests be skipped without an --expensive flag.""" 21 | if config.getoption("--expensive"): # pragma: no cover 22 | return 23 | skip_expensive = pytest.mark.skip(reason="needs --expensive option to run") 24 | for item in items: 25 | if "expensive" in item.keywords: 26 | item.add_marker(skip_expensive) 27 | -------------------------------------------------------------------------------- /tests/test_base_env.py: -------------------------------------------------------------------------------- 1 | """Test the base_envs module. 2 | 3 | Note base_envs is also tested indirectly via smoke tests in `test_envs`, 4 | so the tests in this file focus on features unique to classes in `base_envs`. 5 | """ 6 | 7 | import gymnasium as gym 8 | import numpy as np 9 | import pytest 10 | 11 | from seals import base_envs 12 | from seals.testing import envs 13 | 14 | 15 | class NewEnv(base_envs.TabularModelMDP): 16 | """Test the TabularModelMDP class.""" 17 | 18 | def __init__(self): 19 | """Build environment.""" 20 | np.random.seed(0) 21 | nS = 3 22 | nA = 2 23 | transition_matrix = np.random.random((nS, nA, nS)) 24 | transition_matrix /= transition_matrix.sum(axis=2)[:, :, None] 25 | reward_matrix = np.random.random((nS,)) 26 | super().__init__( 27 | transition_matrix=transition_matrix, 28 | reward_matrix=reward_matrix, 29 | ) 30 | 31 | 32 | def test_base_envs(): 33 | """Test parts of base_envs not covered elsewhere.""" 34 | env = NewEnv() 35 | 36 | assert np.all(np.eye(3) == env.feature_matrix) 37 | 38 | envs.test_premature_step(env, skip_fn=pytest.skip, raises_fn=pytest.raises) 39 | 40 | env.reset(seed=0) 41 | assert env.n_actions_taken == 0 42 | env.step(env.action_space.sample()) 43 | assert env.n_actions_taken == 1 44 | env.step(env.action_space.sample()) 45 | assert env.n_actions_taken == 2 46 | 47 | new_state = env.state_space.sample() 48 | env.state = new_state 49 | assert env.state == new_state 50 | 51 | bad_state = "not a state" 52 | with pytest.raises(ValueError, match=r".*not in.*"): 53 | env.state = bad_state # type: ignore 54 | 55 | with pytest.raises(NotImplementedError, match=r"Options not supported.*"): 56 | env.reset(options={"option": "value"}) 57 | 58 | 59 | def test_tabular_env_validation(): 60 | """Test input validation for base_envs.TabularModelEnv.""" 61 | with pytest.raises(ValueError, match=r"Malformed transition_matrix.*"): 62 | base_envs.TabularModelMDP( 63 | transition_matrix=np.zeros((3, 1, 4)), 64 | reward_matrix=np.zeros((3,)), 65 | ) 66 | 
with pytest.raises(ValueError, match=r"initial_state_dist has multiple.*"): 67 | base_envs.TabularModelMDP( 68 | transition_matrix=np.zeros((3, 1, 3)), 69 | reward_matrix=np.zeros((3,)), 70 | initial_state_dist=np.zeros((3, 4)), 71 | ) 72 | with pytest.raises(ValueError, match=r"transition_matrix and initial_state_dist.*"): 73 | base_envs.TabularModelMDP( 74 | transition_matrix=np.zeros((3, 1, 3)), 75 | reward_matrix=np.zeros((3,)), 76 | initial_state_dist=np.zeros((2)), 77 | ) 78 | with pytest.raises(ValueError, match=r"transition_matrix and reward_matrix.*"): 79 | base_envs.TabularModelMDP( 80 | transition_matrix=np.zeros((4, 1, 4)), 81 | reward_matrix=np.zeros((3,)), 82 | ) 83 | with pytest.raises(ValueError, match=r"transition_matrix and observation_matrix.*"): 84 | base_envs.TabularModelPOMDP( 85 | transition_matrix=np.zeros((3, 1, 3)), 86 | reward_matrix=np.zeros((3,)), 87 | observation_matrix=np.zeros((4, 3)), 88 | ) 89 | 90 | env = base_envs.TabularModelMDP( 91 | transition_matrix=np.zeros((3, 1, 3)), 92 | reward_matrix=np.zeros((3,)), 93 | ) 94 | env.reset(seed=0) 95 | with pytest.raises(ValueError, match=r".*not in.*"): 96 | env.step(4) 97 | 98 | 99 | def test_expose_pomdp_state_wrapper(): 100 | """Test the ExposePOMDPStateWrapper class.""" 101 | env = NewEnv() 102 | wrapped_env = base_envs.ExposePOMDPStateWrapper(env) 103 | 104 | assert wrapped_env.observation_space == env.state_space 105 | state, _ = wrapped_env.reset(seed=0) 106 | assert state == env.state 107 | assert state in env.state_space 108 | 109 | action = env.action_space.sample() 110 | next_state, reward, terminated, truncated, info = wrapped_env.step(action) 111 | assert next_state == env.state 112 | assert next_state in env.state_space 113 | 114 | 115 | def test_tabular_pompd_obs_space_int(): 116 | """Test the TabularModelPOMDP class with an integer observation space.""" 117 | env = base_envs.TabularModelPOMDP( 118 | transition_matrix=np.zeros( 119 | (3, 1, 3), 120 | ), 121 | reward_matrix=np.zeros((3,)), 122 | observation_matrix=np.zeros((3, 3), dtype=np.int64), 123 | ) 124 | assert isinstance(env.observation_space, gym.spaces.Box) 125 | assert env.observation_space.dtype == np.int64 126 | -------------------------------------------------------------------------------- /tests/test_diagnostics.py: -------------------------------------------------------------------------------- 1 | """Test the `diagnostics.*` environments.""" 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from seals.diagnostics import cliff_world, init_shift, random_trans 7 | 8 | 9 | def test_init_shift_validation(): 10 | """Test input validation for init_shift.InitShiftEnv.""" 11 | for invalid_state in [-1, 7, 8, 100]: 12 | with pytest.raises(ValueError, match=r"Initial state.*"): 13 | init_shift.InitShiftEnv(initial_state=invalid_state) 14 | 15 | 16 | def test_cliff_world_draw_value_vec(): 17 | """Smoke test for cliff_world.CliffWorldEnv.draw_value_vec().""" 18 | env = cliff_world.CliffWorldEnv( 19 | width=7, 20 | height=4, 21 | horizon=9, 22 | use_xy_obs=False, 23 | ) 24 | D = np.zeros(env.state_dim) 25 | env.draw_value_vec(D) 26 | 27 | 28 | def test_random_transition_env_init(): 29 | """Test that RandomTransitionEnv initializes correctly.""" 30 | random_trans.RandomTransitionEnv( 31 | n_states=3, 32 | n_actions=2, 33 | branch_factor=3, 34 | horizon=10, 35 | random_obs=False, 36 | ) 37 | random_trans.RandomTransitionEnv( 38 | n_states=3, 39 | n_actions=2, 40 | branch_factor=3, 41 | horizon=10, 42 | random_obs=True, 43 | ) 44 | 
random_trans.RandomTransitionEnv( 45 | n_states=3, 46 | n_actions=2, 47 | branch_factor=3, 48 | horizon=10, 49 | random_obs=True, 50 | obs_dim=10, 51 | ) 52 | with pytest.raises(ValueError, match="obs_dim must be None if random_obs is False"): 53 | random_trans.RandomTransitionEnv( 54 | n_states=3, 55 | n_actions=2, 56 | branch_factor=3, 57 | horizon=10, 58 | random_obs=False, 59 | obs_dim=3, 60 | ) 61 | 62 | 63 | def test_make_random_matrices_no_explicit_rng(): 64 | """Test that random matrix maker static methods work without an explicit RNG.""" 65 | random_trans.RandomTransitionEnv.make_random_trans_mat(3, 2, 3) 66 | random_trans.RandomTransitionEnv.make_random_state_dist(3, 3) 67 | random_trans.RandomTransitionEnv.make_obs_mat(3, True, 3) 68 | with pytest.raises(ValueError, match="obs_dim must be set if random_obs is True"): 69 | random_trans.RandomTransitionEnv.make_obs_mat(3, True) 70 | random_trans.RandomTransitionEnv.make_obs_mat(3, False) 71 | with pytest.raises(ValueError, match="obs_dim must be None if random_obs is False"): 72 | random_trans.RandomTransitionEnv.make_obs_mat(3, False, 3) 73 | -------------------------------------------------------------------------------- /tests/test_envs.py: -------------------------------------------------------------------------------- 1 | """Smoke tests for all environments.""" 2 | 3 | from typing import List, Union 4 | 5 | import gymnasium as gym 6 | from gymnasium.envs import registration 7 | import numpy as np 8 | import pytest 9 | 10 | import seals # noqa: F401 required for env registration 11 | from seals.atari import SCORE_REGIONS, _get_score_region, _seals_name, make_atari_env 12 | from seals.testing import envs 13 | from seals.testing.envs import is_mujoco_env 14 | 15 | ENV_NAMES: List[str] = [ 16 | env_id for env_id in registration.registry.keys() if env_id.startswith("seals/") 17 | ] 18 | 19 | 20 | DETERMINISTIC_ENVS: List[str] = [ 21 | "seals/EarlyTermPos-v0", 22 | "seals/EarlyTermNeg-v0", 23 | "seals/Branching-v0", 24 | "seals/InitShiftTrain-v0", 25 | "seals/InitShiftTest-v0", 26 | ] 27 | 28 | UNMASKED_ATARI_ENVS: List[str] = [ 29 | _seals_name(gym_spec, masked=False) for gym_spec in seals.GYM_ATARI_ENV_SPECS 30 | ] 31 | MASKED_ATARI_ENVS: List[str] = [ 32 | _seals_name(gym_spec, masked=True) 33 | for gym_spec in seals.GYM_ATARI_ENV_SPECS 34 | if _get_score_region(gym_spec.id) is not None 35 | ] 36 | ATARI_ENVS = UNMASKED_ATARI_ENVS + MASKED_ATARI_ENVS 37 | 38 | ATARI_V5_ENVS: List[str] = list(filter(lambda name: name.endswith("-v5"), ATARI_ENVS)) 39 | ATARI_NO_FRAMESKIP_ENVS: List[str] = list( 40 | filter(lambda name: name.endswith("-v4"), ATARI_ENVS), 41 | ) 42 | 43 | DETERMINISTIC_ENVS += ATARI_NO_FRAMESKIP_ENVS 44 | 45 | 46 | env = pytest.fixture(envs.make_env_fixture(skip_fn=pytest.skip)) 47 | 48 | 49 | def test_some_atari_envs(): 50 | """Tests if we succeeded in finding any Atari envs.""" 51 | assert len(seals.GYM_ATARI_ENV_SPECS) > 0 52 | 53 | 54 | def test_atari_space_invaders(): 55 | """Tests for masked and unmasked Atari space invaders environments.""" 56 | masked_space_invader_environments = list( 57 | filter( 58 | lambda name: "SpaceInvaders" in name and "Unmasked" not in name, 59 | ATARI_ENVS, 60 | ), 61 | ) 62 | assert len(masked_space_invader_environments) > 0 63 | 64 | unmasked_space_invader_environments = list( 65 | filter( 66 | lambda name: "SpaceInvaders" in name and "Unmasked" in name, 67 | ATARI_ENVS, 68 | ), 69 | ) 70 | assert len(unmasked_space_invader_environments) > 0 71 | 72 | 73 | def 
test_atari_unmasked_env_naming():
74 | """Tests that all unmasked Atari envs have the appropriate name qualifier."""
75 | noncompliant_envs = list(
76 | filter(
77 | lambda name: _get_score_region(name) is None and "Unmasked" not in name,
78 | ATARI_ENVS,
79 | ),
80 | )
81 | assert len(noncompliant_envs) == 0
82 |
83 |
84 | def test_make_unsupported_masked_atari_env_throws_error():
85 | """Tests that making an unsupported masked Atari env throws an error."""
86 | match_str = (
87 | "Requested environment does not yet support masking. "
88 | "See https://github.com/HumanCompatibleAI/seals/issues/61."
89 | )
90 | with pytest.raises(ValueError, match=match_str):
91 | make_atari_env("ALE/Bowling-v5", masked=True)
92 |
93 |
94 | def test_atari_masks_satisfy_spec():
95 | """Tests that all Atari masks satisfy the spec."""
96 | masks_satisfy_spec = [
97 | mask.x[0] < mask.x[1] and mask.y[0] < mask.y[1]
98 | for env_regions in SCORE_REGIONS.values()
99 | for mask in env_regions
100 | ]
101 | assert all(masks_satisfy_spec)
102 |
103 |
104 | @pytest.mark.parametrize("env_name", ENV_NAMES)
105 | class TestEnvs:
106 | """Battery of simple tests for environments."""
107 |
108 | def test_seed(self, env: gym.Env, env_name: str):
109 | """Tests environment seeding.
110 |
111 | Deterministic Atari environments are run with fewer seeds to minimize the number
112 | of resets done in this test suite, since Atari resets take a long time and there
113 | are many Atari environments.
114 | """
115 | if env_name in ATARI_ENVS:
116 | # these environments take a while for their non-determinism to show.
117 | slow_random_envs = [
118 | "seals/Bowling-Unmasked-v5",
119 | "seals/Frogger-Unmasked-v5",
120 | "seals/KingKong-Unmasked-v5",
121 | "seals/Koolaid-Unmasked-v5",
122 | "seals/NameThisGame-Unmasked-v5",
123 | "seals/Casino-Unmasked-v5",
124 | ]
125 | rollout_len = 100 if env_name not in slow_random_envs else 400
126 | num_seeds = 2 if env_name in ATARI_NO_FRAMESKIP_ENVS else 10
127 | envs.test_seed(
128 | env,
129 | env_name,
130 | DETERMINISTIC_ENVS,
131 | rollout_len=rollout_len,
132 | num_seeds=num_seeds,
133 | )
134 | else:
135 | envs.test_seed(env, env_name, DETERMINISTIC_ENVS)
136 |
137 | def test_premature_step(self, env: gym.Env):
138 | """Tests if step() before reset() raises error."""
139 | envs.test_premature_step(env, skip_fn=pytest.skip, raises_fn=pytest.raises)
140 |
141 | def test_rollout_schema(self, env: gym.Env, env_name: str):
142 | """Tests if environments have correct types on `step()` and `reset()`.
143 |
144 | Atari environments have a very long episode length (~100k observations), so in
145 | the interest of time we do not run them to the end of their episodes or check
146 | the return type of `env.step` after the end of the episode.
147 | """
148 | if env_name in ATARI_ENVS:
149 | envs.test_rollout_schema(env, max_steps=1_000, check_episode_ends=False)
150 | else:
151 | envs.test_rollout_schema(env)
152 |
153 | def test_render_modes(self, env_name: str):
154 | """Tests that all render modes specified in the metadata work.
155 |
156 | Note: we only check that no exception is thrown.
157 | There is no test to see if something reasonable is rendered.
158 | """
159 | for mode in gym.make(env_name).metadata["render_modes"]:
160 | # GIVEN
161 | env = gym.make(env_name, render_mode=mode)
162 | env.reset(seed=0)
163 |
164 | # WHEN
165 | if mode == "rgb_array" and not is_mujoco_env(env):
166 | # The render should not change without calling `step()`.
167 | # MuJoCo rendering fails this check, ignore -- not much we can do.
168 | r1: Union[np.ndarray, List[np.ndarray], None] = env.render()
169 | r2: Union[np.ndarray, List[np.ndarray], None] = env.render()
170 | assert r1 is not None
171 | assert r2 is not None
172 | assert np.allclose(r1, r2)
173 | else:
174 | env.render()
175 |
176 | # THEN
177 | # no error raised
178 |
179 | # CLEANUP
180 | env.close()
181 |
-------------------------------------------------------------------------------- /tests/test_mujoco_rl.py: --------------------------------------------------------------------------------
1 | """Test RL on MuJoCo adapted environments."""
2 |
3 | from typing import Tuple
4 |
5 | import gymnasium as gym
6 | import pytest
7 | import stable_baselines3
8 | from stable_baselines3.common import evaluation
9 |
10 | import seals # noqa: F401 Import required for env registration
11 |
12 |
13 | def _eval_env(
14 | env_name: str,
15 | total_timesteps: int,
16 | ) -> Tuple[float, float]: # pragma: no cover
17 | """Train PPO for `total_timesteps` on `env_name` and evaluate returns."""
18 | env = gym.make(env_name)
19 | model = stable_baselines3.PPO("MlpPolicy", env)
20 | model.learn(total_timesteps=total_timesteps)
21 | res = evaluation.evaluate_policy(model, env)
22 | assert isinstance(res[0], float)
23 | return res
24 |
25 |
26 | # SOMEDAY(adam): tests are flaky and consistently fail in some environments
27 | # Unclear if they even should pass in some cases.
28 | # See discussion in GH#6 and GH#40.
29 | @pytest.mark.expensive
30 | @pytest.mark.parametrize(
31 | "env_base",
32 | ["HalfCheetah", "Ant", "Hopper", "Humanoid", "Swimmer", "Walker2d"],
33 | )
34 | def test_fixed_env_model_as_good_as_gym_env_model(env_base: str): # pragma: no cover
35 | """Compare original and modified MuJoCo v4 envs."""
36 | train_timesteps = 200000
37 |
38 | gym_reward, _ = _eval_env(f"{env_base}-v4", total_timesteps=train_timesteps)
39 | fixed_reward, _ = _eval_env(
40 | f"seals/{env_base}-v1",
41 | total_timesteps=train_timesteps,
42 | )
43 |
44 | epsilon = 0.1
45 | sign = 1 if gym_reward > 0 else -1
46 | assert (1 - sign * epsilon) * gym_reward <= fixed_reward
47 |
-------------------------------------------------------------------------------- /tests/test_util.py: --------------------------------------------------------------------------------
1 | """Test `seals.util`."""
2 |
3 | import collections
4 |
5 | import gymnasium as gym
6 | import numpy as np
7 | import pytest
8 |
9 | from seals import GYM_ATARI_ENV_SPECS, util
10 |
11 |
12 | def test_mask_score_wrapper_enforces_spec():
13 | """Test that MaskScoreWrapper enforces the spec."""
14 | atari_env = gym.make(GYM_ATARI_ENV_SPECS[0].id)
15 | desired_error_message = 'Invalid region: "x" and "y" must be increasing.'
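# MaskScoreWrapper requires each BoxRegion to have increasing coordinates
# (x0 < x1 and y0 < y1); each region below violates this in one coordinate.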
16 | with pytest.raises(ValueError, match=desired_error_message): 17 | util.MaskScoreWrapper(atari_env, [util.BoxRegion(x=(0, 1), y=(1, 0))]) 18 | with pytest.raises(ValueError, match=desired_error_message): 19 | util.MaskScoreWrapper(atari_env, [util.BoxRegion(x=(1, 0), y=(0, 1))]) 20 | 21 | 22 | def test_sample_distribution(): 23 | """Test util.sample_distribution.""" 24 | distr_size = 5 25 | distr = np.random.random((distr_size,)) 26 | distr /= distr.sum() 27 | 28 | n_samples = 1000 29 | rng = np.random.default_rng(0) 30 | sample_count = collections.Counter( 31 | util.sample_distribution(distr, rng) for _ in range(n_samples) 32 | ) 33 | 34 | empirical_distr = np.array([sample_count[i] for i in range(distr_size)]) / n_samples 35 | 36 | # Empirical distribution matches real distribution 37 | l1_err = np.sum(np.abs(empirical_distr - distr)) 38 | assert l1_err < 0.1 39 | 40 | # Same seed gives same samples 41 | assert all( 42 | util.sample_distribution(distr, random=np.random.default_rng(seed)) 43 | == util.sample_distribution(distr, random=np.random.default_rng(seed)) 44 | for seed in range(20) 45 | ) 46 | 47 | 48 | def test_one_hot_encoding(): 49 | """Test util.one_hot_encoding.""" 50 | Case = collections.namedtuple("Case", ["pos", "size", "encoding"]) 51 | 52 | cases = [ 53 | Case(pos=0, size=1, encoding=np.array([1.0])), 54 | Case(pos=1, size=5, encoding=np.array([0.0, 1.0, 0.0, 0.0, 0.0])), 55 | Case(pos=3, size=4, encoding=np.array([0.0, 0.0, 0.0, 1.0])), 56 | Case(pos=2, size=3, encoding=np.array([0.0, 0.0, 1.0])), 57 | Case(pos=2, size=6, encoding=np.array([0.0, 0.0, 1.0, 0.0, 0.0, 0.0])), 58 | ] 59 | 60 | assert all( 61 | np.all(util.one_hot_encoding(pos, size) == encoding) 62 | for pos, size, encoding in cases 63 | ) 64 | -------------------------------------------------------------------------------- /tests/test_wrappers.py: -------------------------------------------------------------------------------- 1 | """Tests for wrapper classes.""" 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from seals import util 7 | from seals.testing import envs 8 | 9 | 10 | def test_auto_reset_wrapper_pad(episode_length=3, n_steps=100, n_manual_reset=2): 11 | """Check that AutoResetWrapper returns correct values from step and reset. 12 | 13 | AutoResetWrapper that pads trajectory with an extra transition containing the 14 | terminal observations. 15 | Also check that calls to .reset() do not interfere with automatic resets. 16 | Due to the padding, the number of steps counted inside the environment and the 17 | number of steps performed outside the environment, i.e., the number of actions 18 | performed, will differ. This test checks that this difference is consistent. 19 | """ 20 | env = util.AutoResetWrapper( 21 | envs.CountingEnv(episode_length=episode_length), 22 | discard_terminal_observation=False, 23 | ) 24 | 25 | for _ in range(n_manual_reset): 26 | obs, info = env.reset() 27 | assert obs == 0 28 | 29 | # We count the number of episodes, so we can sanity check the padding. 30 | num_episodes = 0 31 | next_episode_end = episode_length 32 | for t in range(1, n_steps + 1): 33 | act = env.action_space.sample() 34 | obs, rew, terminated, truncated, info = env.step(act) 35 | 36 | # AutoResetWrapper overrides all terminated and truncated signals. 
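# (The underlying CountingEnv sets terminated=True at the end of each of its
# episodes; the wrapper converts that signal into an extra 'reset transition'.)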
37 | assert terminated is False 38 | assert truncated is False 39 | 40 | if t == next_episode_end: 41 | # Unlike the AutoResetWrapper that discards terminal observations, 42 | # here the final observation is returned directly, and is not stored 43 | # in the info dict. 44 | # Due to padding, for every episode the final observation is offset from 45 | # the outer step by one. 46 | assert obs == (t - num_episodes) / (num_episodes + 1) 47 | assert rew == episode_length * 10 48 | if t == next_episode_end + 1: 49 | num_episodes += 1 50 | # Because the final step returned the final observation, the initial 51 | # obs of the next episode is returned in this additional step. 52 | assert obs == 0 53 | # Consequently, the next episode end is one step later, so it is 54 | # episode_length steps from now. 55 | next_episode_end = t + episode_length 56 | 57 | # Reward of the 'reset transition' is fixed to be 0. 58 | assert rew == 0 59 | 60 | # Sanity check padding. Padding should be 1 for each past episode. 61 | assert ( 62 | next_episode_end 63 | == (num_episodes + 1) * episode_length + num_episodes 64 | ) 65 | 66 | 67 | def test_auto_reset_wrapper_discard(episode_length=3, n_steps=100, n_manual_reset=2): 68 | """Check that AutoResetWrapper returns correct values from step and reset. 69 | 70 | Tests for AutoResetWrapper that discards terminal observations. 71 | Also check that calls to .reset() do not interfere with automatic resets. 72 | """ 73 | env = util.AutoResetWrapper( 74 | envs.CountingEnv(episode_length=episode_length), 75 | discard_terminal_observation=True, 76 | ) 77 | 78 | for _ in range(n_manual_reset): 79 | obs, info = env.reset() 80 | assert obs == 0 81 | 82 | for t in range(1, n_steps + 1): 83 | act = env.action_space.sample() 84 | obs, rew, terminated, truncated, info = env.step(act) 85 | expected_obs = t % episode_length 86 | 87 | assert obs == expected_obs 88 | assert terminated is False 89 | assert truncated is False 90 | 91 | if expected_obs == 0: # End of episode 92 | assert info.get("terminal_observation", None) == episode_length 93 | assert rew == episode_length * 10 94 | else: 95 | assert "terminal_observation" not in info 96 | assert rew == expected_obs * 10 97 | 98 | 99 | def test_absorb_repeat_custom_state( 100 | absorb_reward=-4, 101 | absorb_obs=-3.0, 102 | episode_length=6, 103 | n_steps=100, 104 | n_manual_reset=3, 105 | ): 106 | """Check that AbsorbAfterDoneWrapper returns custom state and reward.""" 107 | env = envs.CountingEnv(episode_length=episode_length) 108 | env = util.AbsorbAfterDoneWrapper( 109 | env, 110 | absorb_reward=absorb_reward, 111 | absorb_obs=absorb_obs, 112 | ) 113 | 114 | for r in range(n_manual_reset): 115 | env.reset() 116 | for t in range(1, n_steps + 1): 117 | act = env.action_space.sample() 118 | obs, rew, terminated, truncated, _ = env.step(act) 119 | assert terminated is False 120 | assert truncated is False 121 | if t > episode_length: 122 | expected_obs = absorb_obs 123 | expected_rew = absorb_reward 124 | else: 125 | expected_obs = t 126 | expected_rew = t * 10.0 127 | assert obs == expected_obs 128 | assert rew == expected_rew 129 | 130 | 131 | def test_absorb_repeat_final_state(episode_length=6, n_steps=100, n_manual_reset=3): 132 | """Check that AbsorbAfterDoneWrapper can repeat final state.""" 133 | env = envs.CountingEnv(episode_length=episode_length) 134 | env = util.AbsorbAfterDoneWrapper(env, absorb_reward=-1, absorb_obs=None) 135 | 136 | for _ in range(n_manual_reset): 137 | env.reset() 138 | for t in range(1, n_steps + 1): 139 | 
act = env.action_space.sample() 140 | obs, rew, terminated, truncated, _ = env.step(act) 141 | assert terminated is False 142 | assert truncated is False 143 | if t > episode_length: 144 | expected_obs = episode_length 145 | expected_rew = -1 146 | else: 147 | expected_obs = t 148 | expected_rew = t * 10.0 149 | assert obs == expected_obs 150 | assert rew == expected_rew 151 | 152 | 153 | @pytest.mark.parametrize("dtype", [np.int64, np.float32, np.float64]) 154 | def test_obs_cast(dtype: np.dtype, episode_length: int = 5): 155 | """Check obs_cast observations are of specified dtype and not mangled. 156 | 157 | Test uses CountingEnv with small integers, which can be represented in 158 | all the specified dtypes without any loss of precision. 159 | """ 160 | env = util.ObsCastWrapper( 161 | envs.CountingEnv(episode_length=episode_length), 162 | dtype, 163 | ) 164 | 165 | obs, _ = env.reset() 166 | assert obs.dtype == dtype 167 | assert obs == 0 168 | for t in range(1, episode_length + 1): 169 | act = env.action_space.sample() 170 | obs, rew, terminated, truncated, _ = env.step(act) 171 | assert terminated == (t == episode_length) 172 | assert truncated is False 173 | assert obs.dtype == dtype 174 | assert obs == t 175 | assert rew == t * 10.0 176 | --------------------------------------------------------------------------------