├── .circleci
│   └── config.yml
├── .codespell.skip
├── .dockerignore
├── .github
│   └── workflows
│       └── publish-to-pypi.yml
├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── ci
│   ├── Xdummy-entrypoint.py
│   ├── build_venv.sh
│   ├── code_checks.sh
│   └── xorg.conf
├── codecov.yml
├── docs
│   ├── Makefile
│   ├── _static
│   │   └── img
│   │       └── logo.svg
│   ├── common
│   │   ├── base_envs.rst
│   │   ├── testing.rst
│   │   └── util.rst
│   ├── conf.py
│   ├── environments
│   │   ├── diagnostic.rst
│   │   └── renovated.rst
│   ├── guide
│   │   └── install.rst
│   ├── index.rst
│   └── make.bat
├── mypy.ini
├── pyproject.toml
├── readthedocs.yml
├── setup.cfg
├── setup.py
├── src
│   └── seals
│       ├── __init__.py
│       ├── atari.py
│       ├── base_envs.py
│       ├── classic_control.py
│       ├── diagnostics
│       │   ├── __init__.py
│       │   ├── branching.py
│       │   ├── cliff_world.py
│       │   ├── early_term.py
│       │   ├── init_shift.py
│       │   ├── largest_sum.py
│       │   ├── noisy_obs.py
│       │   ├── parabola.py
│       │   ├── proc_goal.py
│       │   ├── random_trans.py
│       │   ├── risky_path.py
│       │   └── sort.py
│       ├── mujoco.py
│       ├── py.typed
│       ├── testing
│       │   ├── __init__.py
│       │   └── envs.py
│       └── util.py
└── tests
    ├── conftest.py
    ├── test_base_env.py
    ├── test_diagnostics.py
    ├── test_envs.py
    ├── test_mujoco_rl.py
    ├── test_util.py
    └── test_wrappers.py
/.circleci/config.yml:
--------------------------------------------------------------------------------
1 | version: 2.1 # Version of CircleCI config format
2 |
3 | # "Orbs" are reusable packages of CircleCI config.
4 | # They can simplify common tasks, such as interacting with external services.
5 | # This section lists orbs this config uses.
6 | orbs:
7 | codecov: codecov/codecov@1.1.0 # support for uploading code coverage to codecov
8 |
9 | defaults: &defaults
10 | docker:
11 | - image: humancompatibleai/seals:base-alpha
12 | auth:
13 | username: $DOCKERHUB_USERNAME
14 | password: $DOCKERHUB_PASSWORD
15 | working_directory: /seals
16 |
17 | executors:
18 | unit-test:
19 | <<: *defaults
20 | resource_class: large
21 | environment:
22 | # Don't use auto-detect since it sees all CPUs available, but container is throttled.
23 | NUM_CPUS: 4
24 | lintandtype:
25 | <<: *defaults
26 | resource_class: medium
27 | environment:
28 | # If you change these, also change ci/code_checks.sh
29 | LINT_FILES: src/ tests/ docs/conf.py setup.py # files we lint
30 | # Files we statically type check. Source files like src/ should almost always be present.
31 | # In this repo we also typecheck tests/ -- but sometimes you may want to exclude these
32 | # if they do strange things with types (e.g. mocking).
33 | TYPECHECK_FILES: src/ tests/ setup.py
34 | # Don't use auto-detect since it sees all CPUs available, but container is throttled.
35 | NUM_CPUS: 2
36 |
37 |
38 | commands:
39 | # Define common function to install dependencies and seals, used in the jobs defined in the next section
40 | dependencies:
41 | description: "Check out and update Python dependencies."
42 | steps:
43 | - checkout # Check out the code from Git
44 |
45 | # Download and cache dependencies
46 | # Note the Docker image must still be manually updated if any binary (non-Python) dependencies change.
47 |
48 | # Restore cache if it exists. setup.py defines all the requirements, so we checksum that.
49 | # If you want to force an update despite setup.py not changing, you can bump the version
50 | # number `vn-dependencies`. This can be useful if newer versions of a package have been
51 | # released that you want to upgrade to, without mandating the newer version in setup.py.
52 | - restore_cache:
53 | keys:
54 | - v2-dependencies-{{ checksum "setup.py" }}-{{ checksum "ci/build_venv.sh" }}
55 |
56 | # Create virtual environment and install dependencies using `ci/build_venv.sh`.
57 | # `mujoco_py` needs a MuJoCo key, so download that first.
58 | # We do some sanity checks to ensure the key works.
59 | - run:
60 | name: install dependencies
61 | command: "[[ -d /venv ]] || /seals/ci/build_venv.sh /venv"
62 |
63 | # Save the cache of dependencies.
64 | - save_cache:
65 | paths:
66 | - /venv
67 | key: v2-dependencies-{{ checksum "setup.py" }}-{{ checksum "ci/build_venv.sh" }}
68 |
69 | # Install seals.
70 | # Note we install the source distribution, not in developer mode (`pip install -e`).
71 | # This ensures we're testing the package our users would experience, and in particular
72 | # will catch e.g. modules or data files missing from `setup.py`.
73 | - run:
74 | name: install seals
75 | # Build a wheel then install to avoid copying whole directory (pip issue #2195)
76 | command: |
77 | python setup.py sdist bdist_wheel
78 | pip install --upgrade --force-reinstall --no-deps dist/seals-*.whl
79 |
80 | # The `jobs` section defines jobs that can be executed on CircleCI as part of workflows.
81 | jobs:
82 | # `lintandtype` installs dependencies + `seals`, lints the code, builds the docs, and runs type checks.
83 | lintandtype:
84 | executor: lintandtype
85 |
86 | steps:
87 | - dependencies
88 | - run:
89 | name: flake8
90 | command: flake8 ${LINT_FILES}
91 |
92 | - run:
93 | name: black
94 | command: black --check ${LINT_FILES}
95 |
96 | - run:
97 | name: codespell
98 | command: codespell -I .codespell.skip --skip='*.pyc' ${LINT_FILES}
99 |
100 | - run:
101 | name: sphinx
102 | command: pushd docs/ && make clean && make html && popd
103 |
104 | - run:
105 | name: pytype
106 | command: pytype ${TYPECHECK_FILES}
107 |
108 | - run:
109 | name: mypy
110 | command: mypy ${TYPECHECK_FILES}
111 |
112 | # `unit-test` runs the unit tests in `tests/`.
113 | unit-test:
114 | executor: unit-test
115 | steps:
116 | - dependencies
117 |
118 | # Running out of memory is a common cause of spurious test failures.
119 | # In particular, the CI machines have less memory than most workstations.
120 | # So tests can pass locally but fail on CI. Record memory and other resource
121 | # usage over time to aid with diagnosing these failures.
122 | - run:
123 | name: Memory Monitor
124 | # | is needed for multi-line values in YAML
125 | command: |
126 | mkdir /tmp/resource-usage
127 | export FILE=/tmp/resource-usage/memory.txt
128 | while true; do
129 | ps -u root eo pid,%cpu,%mem,args,uname --sort=-%mem >> $FILE
130 | echo "----------" >> $FILE
131 | sleep 1
132 | done
133 | background: true
134 |
135 | # Run the unit tests themselves
136 | - run:
137 | name: run tests
138 | command: |
139 | # Xdummy-entrypoint.py: starts an X server and sets DISPLAY, then runs wrapped command.
140 | # pytest arguments:
141 | # --cov specifies which directories to report code coverage for
142 | # Since we test the installed `seals`, our source files live in `venv`, not in `src/seals`.
143 | # --junitxml records test results in JUnit format. We upload this file using `store_test_results`
144 | # later, and CircleCI then parses this to pretty-print results.
146 | # --shard-id and --num-shards can be used to split tests across parallel executors using `pytest-shard`.
146 | # -n uses `pytest-xdist` to parallelize tests within a single instance.
147 | Xdummy-entrypoint.py pytest --cov=/venv/lib/python3.8/site-packages/seals --cov=tests \
148 | --junitxml=/tmp/test-reports/junit.xml \
149 | -n ${NUM_CPUS} -vv tests/
150 | # Following two lines rewrite paths from venv/ to src/, based on `coverage:paths` in `setup.cfg`
151 | # This is needed to avoid confusing Codecov
152 | mv .coverage .coverage.bench
153 | coverage combine
154 | - codecov/upload
155 |
156 | # Upload the test results and resource usage to CircleCI
157 | - store_artifacts:
158 | path: /tmp/test-reports
159 | destination: test-reports
160 | - store_artifacts:
161 | path: /tmp/resource-usage
162 | destination: resource-usage
163 | # store_test_results uploads the files and tells CircleCI that it should parse them as test results
164 | - store_test_results:
165 | path: /tmp/test-reports
166 |
167 | # Workflows specify what jobs to actually run on CircleCI. If we didn't specify this,
168 | # nothing would run! Here we have just a single workflow, `test`, containing both the
169 | # jobs defined above. By default, the jobs all run in parallel. We can make them run
170 | # sequentially, or have more complex dependency structures, using the `require` command;
171 | # see https://circleci.com/docs/2.0/workflows/
172 | #
173 | # We attach two contexts to both jobs, each of which defines a set of environment variables:
174 | # - `MuJoCo`, which contains the URL for our MuJoCo license key.
175 | # - `docker-hub-creds`, which contains the credentials for our Docker Hub machine user.
176 | # It's important these are kept confidential -- so don't echo the environment variables
177 | # anywhere in the config!
178 | workflows:
179 | version: 2
180 | test:
181 | jobs:
182 | - lintandtype:
183 | context:
184 | - docker-hub-creds
185 | - unit-test:
186 | context:
187 | - docker-hub-creds
188 |
--------------------------------------------------------------------------------
/.codespell.skip:
--------------------------------------------------------------------------------
1 | ith
2 | reacher
3 | iff
4 |
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | .gitignore
--------------------------------------------------------------------------------
/.github/workflows/publish-to-pypi.yml:
--------------------------------------------------------------------------------
1 | # Adapted from https://packaging.python.org/en/latest/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/
2 |
3 | name: Publish seals distributions 📦 to PyPI and TestPyPI
4 |
5 | on: push
6 |
7 | jobs:
8 | build-n-publish:
9 | name: Build and publish seals distributions 📦 to PyPI and TestPyPI
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - uses: actions/checkout@v3
14 | with:
15 | # Fetch tags needed by setuptools_scm to infer version number
16 | # See https://github.com/pypa/setuptools_scm/issues/414
17 | fetch-depth: 0
18 | - name: Set up Python 3.10
19 | uses: actions/setup-python@v3
20 | with:
21 | python-version: "3.10"
22 |
23 | - name: Install pypa/build
24 | run: >-
25 | python -m
26 | pip install
27 | build
28 | --user
29 | - name: Build a binary wheel and a source tarball
30 | run: >-
31 | python -m
32 | build
33 | --sdist
34 | --wheel
35 | --outdir dist/
36 | .
37 |
38 | # Publish new distribution to Test PyPI on every push.
39 | # This ensures the workflow stays healthy, and will also serve
40 | # as a source of alpha builds.
41 | - name: Publish distribution 📦 to Test PyPI
42 | uses: pypa/gh-action-pypi-publish@release/v1
43 | with:
44 | password: ${{ secrets.TEST_PYPI_API_TOKEN }}
45 | repository_url: https://test.pypi.org/legacy/
46 |
47 | # Publish new distribution to production PyPI on releases.
48 | - name: Publish distribution 📦 to PyPI
49 | if: startsWith(github.ref, 'refs/tags/v')
50 | uses: pypa/gh-action-pypi-publish@release/v1
51 | with:
52 | password: ${{ secrets.PYPI_API_TOKEN }}
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # PyCharm project settings
121 | .idea/
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype type checker
135 | .pytype/
136 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # base stage contains just binary dependencies.
2 | # This is used in the CI build.
3 | FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu20.04 AS base
4 | ARG DEBIAN_FRONTEND=noninteractive
5 |
6 | RUN apt-get update -q \
7 | && apt-get install -y --no-install-recommends \
8 | build-essential \
9 | curl \
10 | ffmpeg \
11 | git \
12 | ssh \
13 | libgl1-mesa-dev \
14 | libgl1-mesa-glx \
15 | libglew-dev \
16 | libosmesa6-dev \
17 | net-tools \
18 | parallel \
19 | patchelf \
20 | python3.8 \
21 | python3.8-dev \
22 | python3-pip \
23 | rsync \
24 | software-properties-common \
25 | unzip \
26 | vim \
27 | virtualenv \
28 | xpra \
29 | xserver-xorg-dev \
30 | && apt-get clean \
31 | && rm -rf /var/lib/apt/lists/*
32 |
33 | ENV LANG C.UTF-8
34 |
35 | RUN mkdir -p /root/.mujoco \
36 | && curl -o mjpro150.zip https://www.roboti.us/download/mjpro150_linux.zip \
37 | && unzip mjpro150.zip -d /root/.mujoco \
38 | && rm mjpro150.zip \
39 | && curl -o /root/.mujoco/mjkey.txt https://www.roboti.us/file/mjkey.txt
40 |
41 | # Set the PATH to the venv before we create the venv, so it's visible in base.
42 | # This is because we may create the venv outside of Docker, e.g. in CI
43 | # or by binding it in for local development.
44 | ENV PATH="/venv/bin:$PATH"
45 | ENV LD_LIBRARY_PATH /root/.mujoco/mjpro150/bin:${LD_LIBRARY_PATH}
46 |
47 | # Run Xdummy mock X server by default so that rendering will work.
48 | COPY ci/xorg.conf /etc/dummy_xorg.conf
49 | COPY ci/Xdummy-entrypoint.py /usr/bin/Xdummy-entrypoint.py
50 | ENTRYPOINT ["/usr/bin/Xdummy-entrypoint.py"]
51 |
52 | # python-req stage contains Python venv, but not code.
53 | # It is useful for development purposes: you can mount
54 | # code from outside the Docker container.
55 | FROM base as python-req
56 |
57 | WORKDIR /seals
58 | # Copy only necessary dependencies to build virtual environment.
59 | # This minimizes how often this layer needs to be rebuilt.
60 | COPY ./setup.py ./setup.py
61 | COPY ./README.md ./README.md
62 | COPY ./src/seals/version.py ./src/seals/version.py
63 | COPY ./ci/build_venv.sh ./ci/build_venv.sh
64 | RUN /seals/ci/build_venv.sh /venv \
65 | && rm -rf $HOME/.cache/pip
66 |
67 | # full stage contains everything.
68 | # Can be used for deployment and local testing.
69 | FROM python-req as full
70 |
71 | # Delay copying (and installing) the code until the very end
72 | COPY . /seals
73 | # Build a wheel then install to avoid copying whole directory (pip issue #2195)
74 | RUN python3 setup.py sdist bdist_wheel
75 | RUN pip install --upgrade dist/seals-*.whl
76 |
77 | # Default entrypoints
78 | CMD ["pytest", "-n", "auto", "-vv", "tests/"]
79 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Center for Human-Compatible AI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [CI Status](https://circleci.com/gh/HumanCompatibleAI/seals)
2 | [Documentation Status](https://seals.readthedocs.io/en/latest/?badge=latest)
3 | [Code Coverage](https://codecov.io/gh/HumanCompatibleAI/seals)
4 | [PyPI Version](https://badge.fury.io/py/seals)
5 |
6 |
7 |
8 | **Status**: early beta.
9 |
10 | *seals*, the Suite of Environments for Algorithms that Learn Specifications, is a toolkit for
11 | evaluating specification learning algorithms, such as reward or imitation learning. The
12 | environments are compatible with [Gym](https://github.com/openai/gym), but are designed
13 | to test algorithms that learn from user data, without requiring a procedurally specified
14 | reward function.
15 |
16 | There are two types of environments in *seals*:
17 |
18 | - **Diagnostic Tasks** which test individual facets of algorithm performance in isolation.
19 | - **Renovated Environments**, adaptations of widely-used benchmarks such as MuJoCo continuous
20 | control tasks and Atari games to be suitable for specification learning benchmarks. In particular,
21 | we remove any side-channel sources of reward information from MuJoCo tasks, and give Atari games constant-length episodes (although most Atari environments have observations that include the score).
22 |
23 | *seals* is under active development and we intend to add more categories of tasks soon.
24 |
25 | You may also be interested in our sister project [imitation](https://github.com/humancompatibleai/imitation/),
26 | providing implementations of a variety of imitation and reward learning algorithms.
27 |
28 | Check out our [documentation](https://seals.readthedocs.io/en/latest/) for more information about *seals*.
29 |
30 | # Quickstart
31 |
32 | To install the latest release from PyPI, run:
33 |
34 | ```bash
35 | pip install seals
36 | ```
37 |
38 | All *seals* environments are available in the Gym registry. Simply import *seals* and then use
39 | the environments as you would with your usual RL or specification learning algorithm:
40 |
41 | ```python
42 | import gymnasium as gym
43 | import seals
44 |
45 | env = gym.make('seals/CartPole-v0')
46 | ```
47 |
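To see exactly which environments *seals* registered, you can filter the Gym registry
(a small illustrative snippet, not part of the *seals* API):

```python
import gymnasium as gym

import seals  # noqa: F401  # importing seals registers its environments as a side effect

seals_env_ids = sorted(env_id for env_id in gym.registry if env_id.startswith("seals/"))
print(seals_env_ids)
```
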
48 | We make releases periodically, but if you wish to use the latest version of the code, you can
49 | install directly from Git master:
50 |
51 | ```bash
52 | pip install git+https://github.com/HumanCompatibleAI/seals.git
53 | ```
54 |
55 | # Contributing
56 |
57 | For development, clone the source code and create a virtual environment for this project:
58 |
59 | ```bash
60 | git clone git@github.com:HumanCompatibleAI/seals.git
61 | cd seals
62 | ./ci/build_venv.sh
63 | pip install -e .[dev] # install extra tools useful for development
64 | ```
65 |
66 | ## Code style
67 |
68 | We follow a PEP8 code style with line length 88, and typically follow the [Google Code Style Guide](http://google.github.io/styleguide/pyguide.html),
69 | but defer to PEP8 where they conflict. We use the `black` autoformatter to avoid arguing over formatting.
70 | Docstrings follow the Google docstring convention defined [here](http://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings),
71 | with an extensive example in the [Sphinx docs](https://www.sphinx-doc.org/en/master/usage/extensions/example_google.html).
72 |
73 | All PRs must pass linting via the `ci/code_checks.sh` script. It is convenient to install this as a commit hook:
74 |
75 | ```bash
76 | ln -s ../../ci/code_checks.sh .git/hooks/pre-commit
77 | ```
78 |
79 | ## Tests
80 |
81 | We use [pytest](https://docs.pytest.org/en/latest/) for unit tests
82 | and [codecov](http://codecov.io/) for code coverage.
83 | We also use [pytype](https://github.com/google/pytype) and [mypy](http://mypy-lang.org/)
84 | for type checking.
85 |
86 | ## Workflow
87 |
88 | Trivial changes (e.g. typo fixes) may be made directly by maintainers. Any non-trivial changes
89 | must be proposed in a PR and approved by at least one maintainer. PRs must pass the continuous
90 | integration tests (CircleCI linting, type checking, unit tests and CodeCov) to be merged.
91 |
92 | It is often helpful to open an issue before proposing a PR, to allow for discussion of the design
93 | before coding commences.
94 |
95 | # Citing seals
96 |
97 | To cite this project in publications:
98 |
99 | ```bibtex
100 | @misc{seals,
101 | author = {Adam Gleave and Pedro Freire and Steven Wang and Sam Toyer},
102 | title = {{seals}: Suite of Environments for Algorithms that Learn Specifications},
103 | year = {2020},
104 | publisher = {GitHub},
105 | journal = {GitHub repository},
106 | howpublished = {\url{https://github.com/HumanCompatibleAI/seals}},
107 | }
108 | ```
109 |
--------------------------------------------------------------------------------
/ci/Xdummy-entrypoint.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python3
2 |
3 | # Adapted from https://github.com/openai/mujoco-py/blob/master/vendor/Xdummy-entrypoint
4 | # Copyright OpenAI; MIT License
5 |
6 | import argparse
7 | import os
8 | import sys
9 | import subprocess
10 |
11 | if __name__ == "__main__":
12 | parser = argparse.ArgumentParser()
13 | args, extra_args = parser.parse_known_args()
14 |
15 | subprocess.Popen(
16 | [
17 | "nohup",
18 | "Xorg",
19 | "-noreset",
20 | "+extension",
21 | "GLX",
22 | "+extension",
23 | "RANDR",
24 | "+extension",
25 | "RENDER",
26 | "-logfile",
27 | "/tmp/xdummy.log",
28 | "-config",
29 | "/etc/dummy_xorg.conf",
30 | ":0",
31 | ]
32 | )
33 | subprocess.Popen(
34 | ["nohup", "Xdummy"],
35 | stdout=open("/dev/null", "w"),
36 | stderr=open("/dev/null", "w"),
37 | )
38 | os.environ["DISPLAY"] = ":0"
39 |
40 | if not extra_args:
41 | argv = ["/bin/bash"]
42 | else:
43 | argv = extra_args
44 |
45 | # Explicitly flush right before the exec since otherwise things might get
46 | # lost in Python's buffers around stdout/stderr (!).
47 | sys.stdout.flush()
48 | sys.stderr.flush()
49 |
50 | os.execvpe(argv[0], argv, os.environ)
51 |
--------------------------------------------------------------------------------
/ci/build_venv.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -e # exit immediately on any error
4 |
5 | venv=$1
6 | if [[ ${venv} == "" ]]; then
7 | venv="venv"
8 | fi
9 |
10 | virtualenv -p python3.8 ${venv}
11 | source ${venv}/bin/activate
12 | pip install --upgrade pip # Ensure we have the newest pip
13 | pip install .[cpu,docs,mujoco,test]
14 |
--------------------------------------------------------------------------------
/ci/code_checks.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # If you change these, also change .circleci/config.yml.
4 | SRC_FILES=(src/ tests/ docs/conf.py setup.py)
5 |
6 | set -x # echo commands
7 | set -e # quit immediately on error
8 |
9 | echo "Source format checking"
10 | flake8 ${SRC_FILES[@]}
11 | black --check ${SRC_FILES[@]}
12 | codespell -I .codespell.skip --skip='*.pyc' ${SRC_FILES[@]}
13 |
14 | if [ -x "`which circleci`" ]; then
15 | circleci config validate
16 | fi
17 |
18 | if [ "$skipexpensive" != "true" ]; then
19 | echo "Building docs (validates docstrings)"
20 | pushd docs/
21 | make clean
22 | make html
23 | popd
24 |
25 | echo "Type checking"
26 | pytype ${SRC_FILES[@]}
27 | mypy ${SRC_FILES[@]}
28 | fi
29 |
--------------------------------------------------------------------------------
/ci/xorg.conf:
--------------------------------------------------------------------------------
1 | # This xorg configuration file is meant to be used by xpra
2 | # to start a dummy X11 server.
3 | # For details, please see:
4 | # https://xpra.org/Xdummy.html
5 |
6 | Section "ServerFlags"
7 | Option "DontVTSwitch" "true"
8 | Option "AllowMouseOpenFail" "true"
9 | Option "PciForceNone" "true"
10 | Option "AutoEnableDevices" "false"
11 | Option "AutoAddDevices" "false"
12 | EndSection
13 |
14 | Section "InputDevice"
15 | Identifier "dummy_mouse"
16 | Option "CorePointer" "true"
17 | Driver "void"
18 | EndSection
19 |
20 | Section "InputDevice"
21 | Identifier "dummy_keyboard"
22 | Option "CoreKeyboard" "true"
23 | Driver "void"
24 | EndSection
25 |
26 | Section "Device"
27 | Identifier "dummy_videocard"
28 | Driver "dummy"
29 | Option "ConstantDPI" "true"
30 | #VideoRam 4096000
31 | #VideoRam 256000
32 | VideoRam 192000
33 | EndSection
34 |
35 | Section "Monitor"
36 | Identifier "dummy_monitor"
37 | HorizSync 5.0 - 1000.0
38 | VertRefresh 5.0 - 200.0
39 | #This can be used to get a specific DPI, but only for the default resolution:
40 | #DisplaySize 508 317
41 | #NOTE: the highest modes will not work without increasing the VideoRam
42 | # for the dummy video card.
43 | Modeline "32768x32768" 15226.50 32768 35800 39488 46208 32768 32771 32781 32953
44 | Modeline "32768x16384" 7516.25 32768 35544 39192 45616 16384 16387 16397 16478
45 | Modeline "16384x8192" 2101.93 16384 16416 24400 24432 8192 8390 8403 8602
46 | Modeline "8192x4096" 424.46 8192 8224 9832 9864 4096 4195 4202 4301
47 | Modeline "5496x1200" 199.13 5496 5528 6280 6312 1200 1228 1233 1261
48 | Modeline "5280x1080" 169.96 5280 5312 5952 5984 1080 1105 1110 1135
49 | Modeline "5280x1200" 191.40 5280 5312 6032 6064 1200 1228 1233 1261
50 | Modeline "5120x3200" 199.75 5120 5152 5904 5936 3200 3277 3283 3361
51 | Modeline "4800x1200" 64.42 4800 4832 5072 5104 1200 1229 1231 1261
52 | Modeline "3840x2880" 133.43 3840 3872 4376 4408 2880 2950 2955 3025
53 | Modeline "3840x2560" 116.93 3840 3872 4312 4344 2560 2622 2627 2689
54 | Modeline "3840x2048" 91.45 3840 3872 4216 4248 2048 2097 2101 2151
55 | Modeline "3840x1080" 100.38 3840 3848 4216 4592 1080 1081 1084 1093
56 | Modeline "3600x1200" 106.06 3600 3632 3984 4368 1200 1201 1204 1214
57 | Modeline "3288x1080" 39.76 3288 3320 3464 3496 1080 1106 1108 1135
58 | Modeline "2048x2048" 49.47 2048 2080 2264 2296 2048 2097 2101 2151
59 | Modeline "2048x1536" 80.06 2048 2104 2312 2576 1536 1537 1540 1554
60 | Modeline "2560x1600" 47.12 2560 2592 2768 2800 1600 1639 1642 1681
61 | Modeline "2560x1440" 42.12 2560 2592 2752 2784 1440 1475 1478 1513
62 | Modeline "1920x1440" 69.47 1920 1960 2152 2384 1440 1441 1444 1457
63 | Modeline "1920x1200" 26.28 1920 1952 2048 2080 1200 1229 1231 1261
64 | Modeline "1920x1080" 23.53 1920 1952 2040 2072 1080 1106 1108 1135
65 | Modeline "1680x1050" 20.08 1680 1712 1784 1816 1050 1075 1077 1103
66 | Modeline "1600x1200" 22.04 1600 1632 1712 1744 1200 1229 1231 1261
67 | Modeline "1600x900" 33.92 1600 1632 1760 1792 900 921 924 946
68 | Modeline "1440x900" 30.66 1440 1472 1584 1616 900 921 924 946
69 | ModeLine "1366x768" 72.00 1366 1414 1446 1494 768 771 777 803
70 | Modeline "1280x1024" 31.50 1280 1312 1424 1456 1024 1048 1052 1076
71 | Modeline "1280x800" 24.15 1280 1312 1400 1432 800 819 822 841
72 | Modeline "1280x768" 23.11 1280 1312 1392 1424 768 786 789 807
73 | Modeline "1360x768" 24.49 1360 1392 1480 1512 768 786 789 807
74 | Modeline "1024x768" 18.71 1024 1056 1120 1152 768 786 789 807
75 | Modeline "768x1024" 19.50 768 800 872 904 1024 1048 1052 1076
76 |
77 |
78 | #common resolutions for android devices (both orientations):
79 | Modeline "800x1280" 25.89 800 832 928 960 1280 1310 1315 1345
80 | Modeline "1280x800" 24.15 1280 1312 1400 1432 800 819 822 841
81 | Modeline "720x1280" 30.22 720 752 864 896 1280 1309 1315 1345
82 | Modeline "1280x720" 27.41 1280 1312 1416 1448 720 737 740 757
83 | Modeline "768x1024" 24.93 768 800 888 920 1024 1047 1052 1076
84 | Modeline "1024x768" 23.77 1024 1056 1144 1176 768 785 789 807
85 | Modeline "600x1024" 19.90 600 632 704 736 1024 1047 1052 1076
86 | Modeline "1024x600" 18.26 1024 1056 1120 1152 600 614 617 631
87 | Modeline "536x960" 16.74 536 568 624 656 960 982 986 1009
88 | Modeline "960x536" 15.23 960 992 1048 1080 536 548 551 563
89 | Modeline "600x800" 15.17 600 632 688 720 800 818 822 841
90 | Modeline "800x600" 14.50 800 832 880 912 600 614 617 631
91 | Modeline "480x854" 13.34 480 512 560 592 854 873 877 897
92 | Modeline "848x480" 12.09 848 880 920 952 480 491 493 505
93 | Modeline "480x800" 12.43 480 512 552 584 800 818 822 841
94 | Modeline "800x480" 11.46 800 832 872 904 480 491 493 505
95 | #resolutions for android devices (both orientations)
96 | #minus the status bar
97 | #38px status bar (and width rounded up)
98 | Modeline "800x1242" 25.03 800 832 920 952 1242 1271 1275 1305
99 | Modeline "1280x762" 22.93 1280 1312 1392 1424 762 780 783 801
100 | Modeline "720x1242" 29.20 720 752 856 888 1242 1271 1276 1305
101 | Modeline "1280x682" 25.85 1280 1312 1408 1440 682 698 701 717
102 | Modeline "768x986" 23.90 768 800 888 920 986 1009 1013 1036
103 | Modeline "1024x730" 22.50 1024 1056 1136 1168 730 747 750 767
104 | Modeline "600x986" 19.07 600 632 704 736 986 1009 1013 1036
105 | Modeline "1024x562" 17.03 1024 1056 1120 1152 562 575 578 591
106 | Modeline "536x922" 16.01 536 568 624 656 922 943 947 969
107 | Modeline "960x498" 14.09 960 992 1040 1072 498 509 511 523
108 | Modeline "600x762" 14.39 600 632 680 712 762 779 783 801
109 | Modeline "800x562" 13.52 800 832 880 912 562 575 578 591
110 | Modeline "480x810" 12.59 480 512 552 584 810 828 832 851
111 | Modeline "848x442" 11.09 848 880 920 952 442 452 454 465
112 | Modeline "480x762" 11.79 480 512 552 584 762 779 783 801
113 | EndSection
114 |
115 | Section "Screen"
116 | Identifier "dummy_screen"
117 | Device "dummy_videocard"
118 | Monitor "dummy_monitor"
119 | DefaultDepth 24
120 | SubSection "Display"
121 | Viewport 0 0
122 | Depth 24
123 | #Modes "32768x32768" "32768x16384" "16384x8192" "8192x4096" "5120x3200" "3840x2880" "3840x2560" "3840x2048" "2048x2048" "2560x1600" "1920x1440" "1920x1200" "1920x1080" "1600x1200" "1680x1050" "1600x900" "1400x1050" "1440x900" "1280x1024" "1366x768" "1280x800" "1024x768" "1024x600" "800x600" "320x200"
124 | Modes "5120x3200" "3840x2880" "3840x2560" "3840x2048" "2048x2048" "2560x1600" "1920x1440" "1920x1200" "1920x1080" "1600x1200" "1680x1050" "1600x900" "1400x1050" "1440x900" "1280x1024" "1366x768" "1280x800" "1024x768" "1024x600" "800x600" "320x200"
125 | #Virtual 32000 32000
126 | #Virtual 16384 8192
127 | # 1024x768 is big enough for testing, but small enough it won't eat up lots of RAM
128 | Virtual 1024 768
129 | #Virtual 5120 3200
130 | EndSubSection
131 | EndSection
132 |
133 | Section "ServerLayout"
134 | Identifier "dummy_layout"
135 | Screen "dummy_screen"
136 | InputDevice "dummy_mouse"
137 | InputDevice "dummy_keyboard"
138 | EndSection
139 |
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | coverage:
2 | status:
3 | project:
4 | default: false
5 | main:
6 | paths:
7 | - "src/seals/"
8 | - "!src/seals/testing/"
9 | tests:
10 | # Should not have dead code in our tests
11 | target: 100%
12 | paths:
13 | - "tests/"
14 | - "src/seals/testing/"
15 | patch:
16 | default: false
17 | main:
18 | paths:
19 | - "src/seals/"
20 | - "!src/seals/testing/"
21 | tests:
22 | target: 100%
23 | paths:
24 | - "tests/"
25 | - "src/seals/testing/"
26 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_static/img/logo.svg:
--------------------------------------------------------------------------------
1 | (SVG image data omitted from this listing)
--------------------------------------------------------------------------------
/docs/common/base_envs.rst:
--------------------------------------------------------------------------------
1 | .. _base_envs:
2 |
3 | Base Environments
4 | =================
5 |
6 | .. automodule:: seals.base_envs
7 |
--------------------------------------------------------------------------------
/docs/common/testing.rst:
--------------------------------------------------------------------------------
1 | .. _testing:
2 |
3 | Helpers for unit-testing environments
4 | =====================================
5 |
6 | .. automodule:: seals.testing.envs
7 |
--------------------------------------------------------------------------------
/docs/common/util.rst:
--------------------------------------------------------------------------------
1 | .. _util:
2 |
3 | Utilities
4 | =========
5 |
6 | .. automodule:: seals.util
7 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | """Configuration file for the Sphinx documentation builder."""
2 |
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # http://www.sphinx-doc.org/en/master/config
6 |
7 | from importlib import metadata
8 |
9 | # -- Project information -----------------------------------------------------
10 |
11 | project = "seals"
12 | copyright = "2020, Center for Human-Compatible AI" # noqa: A001
13 | author = "Center for Human-Compatible AI"
14 |
15 | # The full version, including alpha/beta/rc tags
16 | version = metadata.version("seals")
17 |
18 |
19 | # -- General configuration ---------------------------------------------------
20 |
21 | # Add any Sphinx extension module names here, as strings. They can be
22 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
23 | # ones.
24 | extensions = [
25 | "sphinx.ext.autodoc",
26 | "sphinx_autodoc_typehints",
27 | "sphinx.ext.autosummary",
28 | "sphinx.ext.mathjax",
29 | "sphinx.ext.napoleon",
30 | "sphinx.ext.viewcode",
31 | "sphinx_rtd_theme",
32 | ]
33 | autodoc_mock_imports = ["mujoco_py"]
34 |
35 | # Add any paths that contain templates here, relative to this directory.
36 | templates_path = ["_templates"]
37 |
38 | # List of patterns, relative to source directory, that match files and
39 | # directories to ignore when looking for source files.
40 | # This pattern also affects html_static_path and html_extra_path.
41 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
42 |
43 | autodoc_default_options = {
44 | "members": True,
45 | "undoc-members": True,
46 | "special-members": "__init__",
47 | "show-inheritance": True,
48 | }
49 |
50 | # -- Options for HTML output -------------------------------------------------
51 |
52 | # The theme to use for HTML and HTML Help pages. See the documentation for
53 | # a list of builtin themes.
54 | #
55 | html_theme = "sphinx_rtd_theme"
56 | html_logo = "_static/img/logo.svg"
57 | html_theme_options = {
58 | "logo_only": True,
59 | "style_nav_header_background": "#06203A",
60 | }
61 |
62 | # Add any paths that contain custom static files (such as style sheets) here,
63 | # relative to this directory. They are copied after the builtin static files,
64 | # so a file named "default.css" will overwrite the builtin "default.css".
65 | html_static_path = ["_static"]
66 |
--------------------------------------------------------------------------------
/docs/environments/diagnostic.rst:
--------------------------------------------------------------------------------
1 | .. _diagnostic:
2 |
3 | Diagnostic Tasks
4 | ================
5 |
6 | Diagnostic tasks test individual facets of algorithm performance in isolation.
7 |
8 | Branching
9 | ---------
10 |
11 | **Gym ID**: ``seals/Branching-v0``
12 |
13 | .. automodule:: seals.diagnostics.branching
14 |
15 | EarlyTerm
16 | ---------
17 |
18 | **Gym ID**: ``seals/EarlyTermPos-v0`` and ``seals/EarlyTermNeg-v0``
19 |
20 | .. automodule:: seals.diagnostics.early_term
21 |
22 | InitShift
23 | ---------
24 |
25 | **Gym ID**: ``seals/InitShiftTrain-v0`` and ``seals/InitShiftTest-v0``
26 |
27 | .. automodule:: seals.diagnostics.init_shift
28 |
29 | LargestSum
30 | -------------------
31 |
32 | **Gym ID**: ``seals/LargestSum-v0``
33 |
34 | .. automodule:: seals.diagnostics.largest_sum
35 |
36 | NoisyObs
37 | --------
38 |
39 | **Gym ID**: ``seals/NoisyObs-v0``
40 |
41 | .. automodule:: seals.diagnostics.noisy_obs
42 |
43 | Parabola
44 | --------
45 |
46 | **Gym ID**: ``seals/Parabola-v0``
47 |
48 | .. automodule:: seals.diagnostics.parabola
49 |
50 | ProcGoal
51 | --------
52 |
53 | **Gym ID**: ``seals/ProcGoal-v0``
54 |
55 | .. automodule:: seals.diagnostics.proc_goal
56 |
57 | RiskyPath
58 | ---------
59 |
60 | **Gym ID**: ``seals/RiskyPath-v0``
61 |
62 | .. automodule:: seals.diagnostics.risky_path
63 |
64 | Sort
65 | ----
66 |
67 | **Gym ID**: ``seals/Sort-v0``
68 |
69 | .. automodule:: seals.diagnostics.sort
70 |
--------------------------------------------------------------------------------
/docs/environments/renovated.rst:
--------------------------------------------------------------------------------
1 | .. _renovated:
2 |
3 | Renovated Environments
4 | ======================
5 |
6 | These environments are adaptations of widely-used reinforcement learning benchmarks from
7 | `Gym <https://github.com/openai/gym>`_, modified to be suitable for benchmarking specification
8 | learning algorithms. In particular, we:
9 |
10 | * Make episodes fixed length. Since episode termination conditions are often correlated with
11 | reward, variable-length episodes provide a side-channel of reward information that algorithms
12 | can exploit. Critically, episode boundaries do not exist outside of simulation: in the
13 | real world, a human must often `"reset" the RL algorithm `_.
14 |
15 | Moreover, many algorithms do not properly handle episode termination, and so are
16 | `biased `_ towards shorter or longer episode boundaries.
17 | This confounds evaluation, making some algorithms appear spuriously good or bad depending
18 | on whether their bias aligns with the task objective.
19 |
20 | For most tasks, we make the episode fixed length simply by removing the early termination
21 | condition. In some environments, such as *MountainCar*, it does not make sense to continue
22 | after the terminal state: in this case, we make the terminal state an absorbing state that
23 | is repeated until the end of the episode (see the sketch after this list).
24 | * Ensure observations include all information necessary to compute the ground-truth reward
25 | function. For some environments, this has required augmenting the observation space.
26 | We make this modification to make RL and specification learning of comparable difficulty
27 | in these environments. While in general both RL and specification learning may need to
28 | operate in partially observable environments, the observations in these relatively simple
29 | environments were typically engineered to *make RL easy*: for a fair comparison, we must
30 | therefore also provide reward learning algorithms with sufficient features to recover the
31 | reward.
32 |
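A minimal sketch of the absorbing-state idea is shown below. It is illustrative only, not the
wrapper *seals* actually uses; the class name and attributes are hypothetical.

.. code-block:: python

    import gymnasium as gym

    class AbsorbingStateWrapper(gym.Wrapper):
        """Illustrative sketch: freeze the terminal state until a fixed horizon."""

        def __init__(self, env: gym.Env, horizon: int):
            super().__init__(env)
            self.horizon = horizon

        def reset(self, **kwargs):
            self._t = 0
            self._absorbing_obs = None  # set once the inner environment terminates
            return self.env.reset(**kwargs)

        def step(self, action):
            self._t += 1
            if self._absorbing_obs is None:
                obs, reward, terminated, _, info = self.env.step(action)
                if terminated:
                    # Enter the absorbing state instead of ending the episode early.
                    self._absorbing_obs = obs
            else:
                # Absorbing state: frozen observation, zero reward, empty info.
                obs, reward, info = self._absorbing_obs, 0.0, {}
            # The episode only ends once the fixed horizon is reached.
            return obs, reward, self._t >= self.horizon, False, info
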
33 | In the future, we intend to add Atari tasks with the score masked, another reward side-channel.
34 |
35 | Classic Control
36 | ---------------
37 |
38 | CartPole
39 | ********
40 |
41 | **Gym ID**: ``seals/CartPole-v0``
42 |
43 | .. autoclass:: seals.classic_control.FixedHorizonCartPole
44 |
45 | MountainCar
46 | ***********
47 |
48 | **Gym ID**: ``seals/MountainCar-v0``
49 |
50 | .. autofunction:: seals.classic_control.mountain_car
51 |
52 | MuJoCo
53 | ------
54 |
55 | Ant
56 | ***
57 |
58 | **Gym ID**: ``seals/Ant-v0``
59 |
60 | .. autoclass:: seals.mujoco.AntEnv
61 |
62 | HalfCheetah
63 | ***********
64 |
65 | **Gym ID**: ``seals/HalfCheetah-v0``
66 |
67 | .. autoclass:: seals.mujoco.HalfCheetahEnv
68 |
69 | Hopper
70 | ******
71 |
72 | **Gym ID**: ``seals/Hopper-v0``
73 |
74 | .. autoclass:: seals.mujoco.HopperEnv
75 |
76 | Humanoid
77 | ********
78 |
79 | **Gym ID**: ``seals/Humanoid-v0``
80 |
81 | .. autoclass:: seals.mujoco.HumanoidEnv
82 |
83 | Swimmer
84 | *******
85 |
86 | **Gym ID**: ``seals/Swimmer-v0``
87 |
88 | .. autoclass:: seals.mujoco.SwimmerEnv
89 |
90 | Walker2d
91 | ********
92 |
93 | **Gym ID**: ``seals/Walker2d-v0``
94 |
95 | .. autoclass:: seals.mujoco.Walker2dEnv
96 |
--------------------------------------------------------------------------------
/docs/guide/install.rst:
--------------------------------------------------------------------------------
1 | .. _install:
2 |
3 | Installation Instructions
4 | =========================
5 |
6 | To install the latest release from PyPI, run::
7 |
8 | pip install seals
9 |
10 | We make releases periodically, but if you wish to use the latest version of the code, you can
11 | always install directly from Git master::
12 |
13 | pip install git+https://github.com/HumanCompatibleAI/seals.git
14 |
15 | *seals* has optional dependencies needed by some subset of environments. In particular,
16 | to use MuJoCo environments, you will need to install `MuJoCo `_ 1.5
17 | and then run::
18 |
19 | pip install seals[mujoco]
20 |
21 | You may need to install some other binary dependencies: see the instructions in
22 | `Gym <https://github.com/openai/gym>`_ and `mujoco-py <https://github.com/openai/mujoco-py>`_
23 | for further information.
24 |
25 | You can also use our Docker image, which includes all necessary binary dependencies. You can
26 | either build it from the ``Dockerfile`` or download a pre-built image::
27 |
28 | docker pull humancompatibleai/seals:base
29 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | seals User Guide
2 | ================
3 |
4 | The Suite of Environments for Algorithms that Learn Specifications, or *seals*, is a toolkit for
5 | evaluating specification learning algorithms, such as reward or imitation learning. The environments
6 | are compatible with `Gym <https://github.com/openai/gym>`_, but are designed to test algorithms
7 | that learn from user data, without requiring a procedurally specified reward function.
8 |
9 | There are two types of environments in *seals*:
10 |
11 | * **Diagnostic Tasks** which test individual facets of algorithm performance in isolation.
12 | * **Renovated Environments**, adaptations of widely-used benchmarks such as MuJoCo continuous
13 | control tasks to be suitable for specification learning benchmarks. In particular, this
14 | involves removing any side-channel sources of reward information (such as episode boundaries,
15 | the score appearing in the observation, etc) and including all the information needed to
16 | compute the reward in the observation space.
17 |
18 | *seals* is under active development and we intend to add more categories of tasks soon.
19 |
20 | .. toctree::
21 | :maxdepth: 1
22 | :caption: User Guide
23 |
24 | guide/install
25 |
26 |
27 | .. toctree::
28 | :maxdepth: 3
29 | :caption: Environments
30 |
31 | environments/diagnostic
32 | environments/renovated
33 |
34 | .. toctree::
35 | :maxdepth: 2
36 | :caption: Common
37 |
38 | common/base_envs
39 | common/util
40 | common/testing
41 |
42 | Citing seals
43 | ------------
44 | To cite this project in publications:
45 |
46 | .. code-block:: bibtex
47 |
48 | @misc{seals,
49 | author = {Adam Gleave and Pedro Freire and Steven Wang and Sam Toyer},
50 | title = {{seals}: Suite of Environments for Algorithms that Learn Specifications},
51 | year = {2020},
52 | publisher = {GitHub},
53 | journal = {GitHub repository},
54 | howpublished = {\url{https://github.com/HumanCompatibleAI/seals}},
55 | }
56 |
57 | Indices and tables
58 | ==================
59 |
60 | * :ref:`genindex`
61 | * :ref:`modindex`
62 | * :ref:`search`
63 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 | ignore_missing_imports = true
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [tool.black]
6 | target-version = ["py38"]
7 |
8 | [[tool.mypy.overrides]]
9 | module = ["gym.*", "setuptools_scm.*"]
10 | ignore_missing_imports = true
11 |
--------------------------------------------------------------------------------
/readthedocs.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | sphinx:
4 | configuration: docs/conf.py
5 |
6 | formats: all
7 |
8 | python:
9 | version: 3.8
10 | install:
11 | - method: pip
12 | path: .
13 | extra_requirements:
14 | - docs
15 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [coverage:run]
2 | source = seals
3 | include=
4 | src/*
5 | tests/*
6 |
7 | [coverage:report]
8 | exclude_lines =
9 | pragma: no cover
10 | if __name__ == .__main__.:
11 | omit =
12 | setup.py
13 |
14 | [coverage:paths]
15 | source =
16 | src/seals
17 | *venv/lib/python*/site-packages/seals
18 |
19 | [darglint]
20 | strictness=long
21 |
22 | [flake8]
23 | docstring-convention=google
24 | ignore = E203, W503
25 | max-line-length = 88
26 |
27 | [isort]
28 | line_length=88
29 | known_first_party=seals,tests
30 | default_section=THIRDPARTY
31 | multi_line_output=3
32 | include_trailing_comma=True
33 | force_sort_within_sections=True
34 | skip=.pytype
35 |
36 | [pytype]
37 | inputs =
38 | src/
39 | tests/
40 | setup.py
41 | python_version >= 3.8
42 |
43 | [tool:pytest]
44 | markers =
45 | expensive: mark a test as expensive (deselect with '-m "not expensive"')
46 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """setup.py for seals project."""
2 |
3 | from typing import TYPE_CHECKING
4 |
5 | from setuptools import find_packages, setup # type:ignore
6 |
7 | if TYPE_CHECKING:
8 | from setuptools_scm.version import ScmVersion
9 |
10 |
11 | def get_version(version: "ScmVersion") -> str:
12 | """Generates the version string for the package.
13 |
14 | This function replaces the default version format used by setuptools_scm
15 | to allow development builds to be versioned using the git commit hash
16 | instead of the number of commits since the last release, which leads to
17 | duplicate version identifiers when using multiple branches
18 | (see https://github.com/HumanCompatibleAI/imitation/issues/500).
19 | The version has the following format:
20 | {version}[.dev{build}]
21 | where build is the shortened commit hash converted to base 10.
22 |
23 | Args:
24 | version: The version object given by setuptools_scm, calculated
25 | from the git repository.
26 |
27 | Returns:
28 | The formatted version string to use for the package.
29 | """
30 | # We import setuptools_scm here because it is only installed after the module
31 | # is loaded and the setup function is called.
32 | from setuptools_scm import version as scm_version
33 |
34 | if version.node:
35 | # By default node corresponds to the short commit hash when using git,
36 | # plus a "g" prefix. We remove the "g" prefix from the commit hash which
37 | # is added by setuptools_scm by default ("g" for git vs. mercurial etc.)
38 | # because letters are not valid for version identifiers in PEP 440.
39 | # We also convert from hexadecimal to base 10 for the same reason.
40 | version.node = str(int(version.node.lstrip("g"), 16))
41 | if version.exact:
42 | # an exact version is when the current commit is tagged with a version.
43 | return version.format_with("{tag}")
44 | else:
45 | # the current commit is not tagged with a version, so we guess
46 | # what the "next" version will be (this can be disabled but is the
47 | # default behavior of setuptools_scm so it has been left in).
48 | return version.format_next_version(
49 | scm_version.guess_next_version,
50 | fmt="{guessed}.dev{node}",
51 | )
52 |
53 |
54 | def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str:
55 | """Generates the local version string for the package.
56 |
57 | By default, when commits are made on top of a release version, setuptools_scm
58 | sets the version to be {version}.dev{distance}+{node} where {distance} is the number
59 | of commits since the last release and {node} is the short commit hash.
60 | This function replaces the default version format used by setuptools_scm
61 | so that committed changes away from a release version are not considered
62 | local versions but dev versions instead (by using the format
63 | {version}.dev{node} instead). This is so that we can push test releases
64 | to TestPyPI (it does not accept local versions).
65 | Local versions are still present if there are uncommitted changes (if the tree
66 | is dirty), in which case the current date is added to the version.
67 |
68 | Args:
69 | version: The version object given by setuptools_scm, calculated
70 | from the git repository.
71 | time_format: The format to use for the date.
72 |
73 | Returns:
74 | The formatted local version string to use for the package.
75 | """
76 | return version.format_choice(
77 | "",
78 | "+d{time:{time_format}}",
79 | time_format=time_format,
80 | )
81 |
82 |
83 | def get_readme() -> str:
84 | """Retrieve content from README."""
85 | with open("README.md", "r") as f:
86 | return f.read()
87 |
88 |
89 | ATARI_REQUIRE = [
90 | "opencv-python",
91 | "ale-py~=0.8.1",
92 | "pillow",
93 | "autorom[accept-rom-license]~=0.4.2",
94 | "shimmy[atari] >=0.1.0,<1.0",
95 | ]
96 | TESTS_REQUIRE = [
97 | "black",
98 | "coverage~=4.5.4",
99 | "codecov",
100 | "codespell",
101 | "darglint>=1.5.6",
102 | "flake8",
103 | "flake8-blind-except",
104 | "flake8-builtins",
105 | "flake8-commas",
106 | "flake8-debugger",
107 | "flake8-docstrings",
108 | "flake8-isort",
109 | "isort",
110 | "matplotlib",
111 | "mypy",
112 | "pydocstyle",
113 | "pytest",
114 | "pytest-cov",
115 | "pytest-xdist",
116 | "pytype",
117 | "stable-baselines3>=0.9.0",
118 | "setuptools_scm~=7.0.5",
119 | "gymnasium[classic-control,mujoco]",
120 | *ATARI_REQUIRE,
121 | ]
122 | DOCS_REQUIRE = [
123 | "sphinx",
124 | "sphinx-autodoc-typehints>=1.21.5",
125 | "sphinx-rtd-theme",
126 | ]
127 |
128 |
129 | setup(
130 | name="seals",
131 | use_scm_version={"local_scheme": get_local_version, "version_scheme": get_version},
132 | description="Suite of Environments for Algorithms that Learn Specifications",
133 | long_description=get_readme(),
134 | long_description_content_type="text/markdown",
135 | author="Center for Human-Compatible AI",
136 | python_requires=">=3.8.0",
137 | packages=find_packages("src"),
138 | package_dir={"": "src"},
139 | package_data={"seals": ["py.typed"]},
140 | install_requires=["gymnasium", "numpy"],
141 | tests_require=TESTS_REQUIRE,
142 | extras_require={
143 | # recommended packages for development
144 | "dev": ["ipdb", "jupyter", *TESTS_REQUIRE, *DOCS_REQUIRE],
145 | "docs": DOCS_REQUIRE,
146 | "test": TESTS_REQUIRE,
147 | "mujoco": ["gymnasium[mujoco]"],
148 | "atari": ATARI_REQUIRE,
149 | },
150 | url="https://github.com/HumanCompatibleAI/seals",
151 | license="MIT",
152 | classifiers=[
153 | # Trove classifiers
154 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
155 | "License :: OSI Approved :: MIT License",
156 | "Programming Language :: Python",
157 | "Programming Language :: Python :: 3",
158 | "Programming Language :: Python :: 3.8",
159 | "Programming Language :: Python :: 3.9",
160 | "Programming Language :: Python :: 3.10",
161 | "Programming Language :: Python :: Implementation :: CPython",
162 | "Programming Language :: Python :: Implementation :: PyPy",
163 | "License :: OSI Approved :: MIT License",
164 | "Operating System :: OS Independent",
165 | ],
166 | )
167 |
--------------------------------------------------------------------------------
/src/seals/__init__.py:
--------------------------------------------------------------------------------
1 | """Benchmark environments for reward modeling and imitation."""
2 |
3 | from importlib import metadata
4 |
5 | import gymnasium as gym
6 |
7 | from seals import atari, util
8 | import seals.diagnostics # noqa: F401
9 |
10 | try:
11 | __version__ = metadata.version("seals")
12 | except metadata.PackageNotFoundError: # pragma: no cover
13 | # package is not installed
14 | pass
15 |
16 | # Classic control
17 |
18 | gym.register(
19 | id="seals/CartPole-v0",
20 | entry_point="seals.classic_control:FixedHorizonCartPole",
21 | max_episode_steps=500,
22 | )
23 |
24 | gym.register(
25 | id="seals/MountainCar-v0",
26 | entry_point="seals.classic_control:mountain_car",
27 | max_episode_steps=200,
28 | )
29 |
30 | # MuJoCo
31 |
32 | for env_base in ["Ant", "HalfCheetah", "Hopper", "Humanoid", "Swimmer", "Walker2d"]:
33 | gym.register(
34 | id=f"seals/{env_base}-v1",
35 | entry_point=f"seals.mujoco:{env_base}Env",
36 | max_episode_steps=util.get_gym_max_episode_steps(f"{env_base}-v4"),
37 | )
38 |
39 | # Atari
40 |
41 | GYM_ATARI_ENV_SPECS = list(filter(atari._supported_atari_env, gym.registry.values()))
42 | atari.register_atari_envs(GYM_ATARI_ENV_SPECS)
43 |
--------------------------------------------------------------------------------
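A minimal usage sketch (not part of the repository), assuming seals and gymnasium are installed: importing seals registers every "seals/..." ID with gymnasium's registry, after which the environments are created like any other gym environment.

import gymnasium as gym

import seals  # noqa: F401  # side effect: registers the seals environment IDs

# List the IDs that the import above registered.
seals_ids = sorted(env_id for env_id in gym.registry if env_id.startswith("seals/"))
print(seals_ids)

# Create and roll out one of them.
env = gym.make("seals/MountainCar-v0")
obs, info = env.reset(seed=0)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()
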
/src/seals/atari.py:
--------------------------------------------------------------------------------
1 | """Adaptation of Atari environments for specification learning algorithms."""
2 |
3 | from typing import Dict, Iterable, Optional
4 |
5 | import gymnasium as gym
6 | from gymnasium.envs.registration import EnvSpec
7 |
8 | from seals.util import (
9 | AutoResetWrapper,
10 | BoxRegion,
11 | MaskedRegionSpecifier,
12 | MaskScoreWrapper,
13 | get_gym_max_episode_steps,
14 | )
15 |
16 | SCORE_REGIONS: Dict[str, MaskedRegionSpecifier] = {
17 | "BeamRider": [
18 | BoxRegion(x=(5, 20), y=(45, 120)),
19 | BoxRegion(x=(28, 40), y=(15, 40)),
20 | ],
21 | "Breakout": [BoxRegion(x=(0, 16), y=(35, 80))],
22 | "Enduro": [
23 | BoxRegion(x=(163, 173), y=(55, 110)),
24 | BoxRegion(x=(177, 188), y=(68, 107)),
25 | ],
26 | "Pong": [BoxRegion(x=(0, 24), y=(0, 160))],
27 | "Qbert": [BoxRegion(x=(6, 15), y=(33, 71))],
28 | "Seaquest": [BoxRegion(x=(7, 19), y=(80, 110))],
29 | "SpaceInvaders": [BoxRegion(x=(10, 20), y=(0, 160))],
30 | }
31 |
32 |
33 | def _get_score_region(atari_env_id: str) -> Optional[MaskedRegionSpecifier]:
34 | basename = atari_env_id.split("/")[-1].split("-")[0]
35 | basename = basename.replace("NoFrameskip", "")
36 | return SCORE_REGIONS.get(basename)
37 |
38 |
39 | def make_atari_env(atari_env_id: str, masked: bool, *args, **kwargs) -> gym.Env:
40 | """Fixed-length, optionally masked-score variant of a given Atari environment."""
41 | env: gym.Env = AutoResetWrapper(gym.make(atari_env_id, *args, **kwargs))
42 |
43 | if masked:
44 | score_region = _get_score_region(atari_env_id)
45 | if score_region is None:
46 | raise ValueError(
47 | "Requested environment does not yet support masking. "
48 | "See https://github.com/HumanCompatibleAI/seals/issues/61.",
49 | )
50 | env = MaskScoreWrapper(env, score_region)
51 |
52 | return env
53 |
54 |
55 | def _not_ram_or_det(env_id: str) -> bool:
56 | """Checks a gym Atari environment isn't deterministic or using RAM observations."""
57 | slash_separated = env_id.split("/")
58 | # environment name should look like "ALE/Amidar-v5" or "Amidar-ramNoFrameskip-v4"
59 | assert len(slash_separated) in (1, 2)
60 | after_slash = slash_separated[-1]
61 | hyphen_separated = after_slash.split("-")
62 | assert len(hyphen_separated) > 1
63 | not_ram = "ram" not in hyphen_separated[1]
64 | not_deterministic = "Deterministic" not in env_id
65 | return not_ram and not_deterministic
66 |
67 |
68 | def _supported_atari_env(gym_spec: EnvSpec) -> bool:
69 | """Checks if a gym Atari environment is one of the ones we will support."""
70 | is_atari = gym_spec.entry_point == "shimmy.atari_env:AtariEnv"
71 | v5_and_plain = gym_spec.id.endswith("-v5") and "NoFrameskip" not in gym_spec.id
72 | v4_and_no_frameskip = gym_spec.id.endswith("-v4") and "NoFrameskip" in gym_spec.id
73 | return (
74 | is_atari
75 | and _not_ram_or_det(gym_spec.id)
76 | and (v5_and_plain or v4_and_no_frameskip)
77 | )
78 |
79 |
80 | def _seals_name(gym_spec: EnvSpec, masked: bool) -> str:
81 | """Makes a Gym ID for an Atari environment in the seals namespace."""
82 | slash_separated = gym_spec.id.split("/")
83 | name = "seals/" + slash_separated[-1]
84 |
85 | if not masked:
86 | last_hyphen_idx = name.rfind("-v")
87 | name = name[:last_hyphen_idx] + "-Unmasked" + name[last_hyphen_idx:]
88 | return name
89 |
90 |
91 | def register_atari_envs(
92 | gym_atari_env_specs: Iterable[EnvSpec],
93 | ) -> None:
94 | """Register masked and unmasked wrapped gym Atari environments."""
95 |
96 | def register_gym(masked):
97 | gym.register(
98 | id=_seals_name(gym_spec, masked=masked),
99 | entry_point="seals.atari:make_atari_env",
100 | max_episode_steps=get_gym_max_episode_steps(gym_spec.id),
101 | kwargs=dict(atari_env_id=gym_spec.id, masked=masked),
102 | )
103 |
104 | for gym_spec in gym_atari_env_specs:
105 | register_gym(masked=False)
106 | if _get_score_region(gym_spec.id) is not None:
107 | register_gym(masked=True)
108 |
--------------------------------------------------------------------------------
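A hedged sketch of how the registration above is used, assuming the "atari" extra (and the corresponding ROMs) is installed. Per `_seals_name`, masked variants keep the plain ID (e.g. "seals/Pong-v5") while unmasked variants get an "-Unmasked" suffix.

import gymnasium as gym

import seals  # noqa: F401  # registers the seals Atari IDs found in gym.registry
from seals.atari import make_atari_env

masked = gym.make("seals/Pong-v5")             # score pixels masked out
unmasked = gym.make("seals/Pong-Unmasked-v5")  # raw frames, still fixed horizon

# Equivalent to the masked variant, wrapping an upstream Atari ID directly:
env = make_atari_env("ALE/Pong-v5", masked=True)
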
/src/seals/base_envs.py:
--------------------------------------------------------------------------------
1 | """Base environment classes."""
2 |
3 | import abc
4 | from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Union
5 |
6 | import gymnasium as gym
7 | from gymnasium import spaces
8 | import numpy as np
9 | import numpy.typing as npt
10 |
11 | from seals import util
12 |
13 | # Note: we redefine the type vars from gymnasium.core here, because pytype does not
14 | # recognize them as valid type vars if we import them from gymnasium.core.
15 | StateType = TypeVar("StateType")
16 | ActType = TypeVar("ActType")
17 | ObsType = TypeVar("ObsType")
18 |
19 |
20 | class ResettablePOMDP(
21 | gym.Env[ObsType, ActType],
22 | abc.ABC,
23 | Generic[StateType, ObsType, ActType],
24 | ):
25 | """ABC for POMDPs that are resettable.
26 |
27 | Specifically, these environments provide oracle access to sample from
28 | the initial state distribution and transition dynamics, and compute the
29 | reward and termination condition. Almost all simulated environments can
30 | meet these criteria.
31 | """
32 |
33 | state_space: spaces.Space[StateType]
34 |
35 | _cur_state: Optional[StateType]
36 | _n_actions_taken: Optional[int]
37 |
38 | def __init__(self):
39 | """Build resettable (PO)MDP."""
40 | self._cur_state = None
41 | self._n_actions_taken = None
42 |
43 | @abc.abstractmethod
44 | def initial_state(self) -> StateType:
45 | """Samples from the initial state distribution."""
46 |
47 | @abc.abstractmethod
48 | def transition(self, state: StateType, action: ActType) -> StateType:
49 | """Samples from transition distribution."""
50 |
51 | @abc.abstractmethod
52 | def reward(self, state: StateType, action: ActType, new_state: StateType) -> float:
53 | """Computes reward for a given transition."""
54 |
55 | @abc.abstractmethod
56 | def terminal(self, state: StateType, step: int) -> bool:
57 | """Is the state terminal?"""
58 |
59 | @abc.abstractmethod
60 | def obs_from_state(self, state: StateType) -> ObsType:
61 | """Sample observation for given state."""
62 |
63 | @property
64 | def n_actions_taken(self) -> int:
65 | """Number of steps taken so far."""
66 | assert self._n_actions_taken is not None
67 | return self._n_actions_taken
68 |
69 | @property
70 | def state(self) -> StateType:
71 | """Current state."""
72 | assert self._cur_state is not None
73 | return self._cur_state
74 |
75 | @state.setter
76 | def state(self, state: StateType):
77 | """Set the current state."""
78 | if state not in self.state_space:
79 | raise ValueError(f"{state} not in {self.state_space}")
80 | self._cur_state = state
81 |
82 | def reset(
83 | self,
84 | *,
85 | seed: Optional[int] = None,
86 | options: Optional[Dict[str, Any]] = None,
87 | ) -> Tuple[ObsType, Dict[str, Any]]:
88 | """Reset the episode and return initial observation."""
89 | if options is not None:
90 | raise NotImplementedError("Options not supported.")
91 |
92 | super().reset(seed=seed)
93 | self.state = self.initial_state()
94 | self._n_actions_taken = 0
95 | obs = self.obs_from_state(self.state)
96 | info: Dict[str, Any] = dict()
97 | return obs, info
98 |
99 | def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
100 | """Transition state using given action."""
101 | if self._cur_state is None or self._n_actions_taken is None:
102 | raise RuntimeError("Need to call reset() before first step()")
103 | if action not in self.action_space:
104 | raise ValueError(f"{action} not in {self.action_space}")
105 |
106 | old_state = self.state
107 | self.state = self.transition(self.state, action)
108 | obs = self.obs_from_state(self.state)
109 | assert obs in self.observation_space
110 | reward = self.reward(old_state, action, self.state)
111 | self._n_actions_taken += 1
112 | terminated = self.terminal(self.state, self.n_actions_taken)
113 | truncated = False
114 |
115 | infos = {"old_state": old_state, "new_state": self._cur_state}
116 | return obs, reward, terminated, truncated, infos
117 |
118 |
119 | class ExposePOMDPStateWrapper(
120 | gym.Wrapper[StateType, ActType, ObsType, ActType],
121 | Generic[StateType, ObsType, ActType],
122 | ):
123 | """A wrapper that exposes the current state of the POMDP as the observation."""
124 |
125 | env: ResettablePOMDP[StateType, ObsType, ActType]
126 |
127 | def __init__(self, env: ResettablePOMDP[StateType, ObsType, ActType]) -> None:
128 | """Build wrapper.
129 |
130 | Args:
131 | env: POMDP to wrap.
132 | """
133 | super().__init__(env)
134 | self._observation_space = env.state_space
135 |
136 | def reset(
137 | self,
138 | seed: Optional[int] = None,
139 | options: Optional[Dict[str, Any]] = None,
140 | ) -> Tuple[StateType, Dict[str, Any]]:
141 | """Reset environment and return initial state."""
142 | _, info = self.env.reset(seed=seed, options=options)
143 | return self.env.state, info
144 |
145 | def step(self, action) -> Tuple[StateType, float, bool, bool, dict]:
146 | """Transition state using given action."""
147 | _, reward, terminated, truncated, info = self.env.step(action)
148 | return self.env.state, reward, terminated, truncated, info
149 |
150 |
151 | class ResettableMDP(
152 | ResettablePOMDP[StateType, StateType, ActType],
153 | abc.ABC,
154 | Generic[StateType, ActType],
155 | ):
156 | """ABC for MDPs that are resettable."""
157 |
158 | @property
159 | def observation_space(self):
160 | """Observation space."""
161 | return self.state_space
162 |
163 | def obs_from_state(self, state: StateType) -> StateType:
164 | """Identity since observation == state in an MDP."""
165 | return state
166 |
167 |
168 | DiscreteSpaceInt = np.int64
169 |
170 |
171 | # TODO(juan) this does not implement the .render() method,
172 | # so in theory it should not be instantiated directly.
173 | # Not sure why this is not raising an error?
174 | class BaseTabularModelPOMDP(
175 | ResettablePOMDP[DiscreteSpaceInt, ObsType, DiscreteSpaceInt],
176 | Generic[ObsType],
177 | ):
178 | """Base class for tabular environments with known dynamics.
179 |
180 | This is the general class that also allows subclassing for creating
181 | MDP (where observation == state) or POMDP (where observation != state).
182 | """
183 |
184 | transition_matrix: np.ndarray
185 | reward_matrix: np.ndarray
186 |
187 | state_space: spaces.Discrete
188 |
189 | def __init__(
190 | self,
191 | *,
192 | transition_matrix: np.ndarray,
193 | reward_matrix: np.ndarray,
194 | horizon: Optional[int] = None,
195 | initial_state_dist: Optional[np.ndarray] = None,
196 | ):
197 | """Build tabular environment.
198 |
199 | Args:
200 | transition_matrix: 3-D array with transition probabilities for a
201 | given state-action pair, of shape `(n_states,n_actions,n_states)`.
202 |             reward_matrix: 1-D, 2-D or 3-D array giving the reward for a
203 | given `(state, action, next_state)` triple. A 2-D array assumes
204 | the `next_state` is not used in the reward, and a 1-D array
205 | assumes neither the `action` nor `next_state` are used.
206 | Of shape `(n_states,n_actions,n_states)[:n]` where `n`
207 | is the dimensionality of the array.
208 | horizon: Maximum number of timesteps. The default is `None`,
209 | which represents an infinite horizon.
210 | initial_state_dist: Distribution from which state is sampled at the
211 | start of the episode. If `None`, it is assumed initial state
212 | is always 0. Shape `(n_states,)`.
213 |
214 | Raises:
215 | ValueError: `transition_matrix`, `reward_matrix` or
216 |                 `initial_state_dist` have shapes different from those specified above.
217 | """
218 | super().__init__()
219 |
220 | # The following matrices should conform to the shapes below:
221 |
222 | # transition matrix: n_states x n_actions x n_states
223 | n_states = transition_matrix.shape[0]
224 | if n_states != transition_matrix.shape[2]:
225 | raise ValueError(
226 | "Malformed transition_matrix:\n"
227 | f"transition_matrix.shape: {transition_matrix.shape}\n"
228 | f"{n_states} != {transition_matrix.shape[2]}",
229 | )
230 |
231 | # reward matrix: n_states x n_actions x n_states
232 | # OR n_states x n_actions
233 | # OR n_states
234 | if reward_matrix.shape != transition_matrix.shape[: len(reward_matrix.shape)]:
235 | raise ValueError(
236 | "transition_matrix and reward_matrix are not compatible:\n"
237 | f"transition_matrix.shape: {transition_matrix.shape}\n"
238 | f"reward_matrix.shape: {reward_matrix.shape}",
239 | )
240 |
241 | # initial state dist: n_states
242 | if initial_state_dist is None:
243 | initial_state_dist = util.one_hot_encoding(0, n_states)
244 | if initial_state_dist.ndim != 1:
245 | raise ValueError(
246 | "initial_state_dist has multiple dimensions:\n"
247 | f"{initial_state_dist.ndim} != 1",
248 | )
249 | if initial_state_dist.shape[0] != n_states:
250 | raise ValueError(
251 | "transition_matrix and initial_state_dist are not compatible:\n"
252 | f"number of states = {n_states}\n"
253 | f"len(initial_state_dist) = {len(initial_state_dist)}",
254 | )
255 |
256 | self.transition_matrix = transition_matrix
257 | self.reward_matrix = reward_matrix
258 | self._feature_matrix = None
259 | self.horizon = horizon
260 | self.initial_state_dist = initial_state_dist
261 |
262 | self.state_space = spaces.Discrete(self.state_dim)
263 | self.action_space = spaces.Discrete(self.action_dim)
264 |
265 | def initial_state(self) -> DiscreteSpaceInt:
266 | """Samples from the initial state distribution."""
267 | return DiscreteSpaceInt(
268 | util.sample_distribution(
269 | self.initial_state_dist,
270 | random=self.np_random,
271 | ),
272 | )
273 |
274 | def transition(
275 | self,
276 | state: DiscreteSpaceInt,
277 | action: DiscreteSpaceInt,
278 | ) -> DiscreteSpaceInt:
279 | """Samples from transition distribution."""
280 | return DiscreteSpaceInt(
281 | util.sample_distribution(
282 | self.transition_matrix[state, action],
283 | random=self.np_random,
284 | ),
285 | )
286 |
287 | def reward(
288 | self,
289 | state: DiscreteSpaceInt,
290 | action: DiscreteSpaceInt,
291 | new_state: DiscreteSpaceInt,
292 | ) -> float:
293 | """Computes reward for a given transition."""
294 | inputs = (state, action, new_state)[: len(self.reward_matrix.shape)]
295 | return self.reward_matrix[inputs]
296 |
297 | def terminal(self, state: DiscreteSpaceInt, n_actions_taken: int) -> bool:
298 | """Checks if state is terminal."""
299 | del state
300 | return self.horizon is not None and n_actions_taken >= self.horizon
301 |
302 | @property
303 | def feature_matrix(self):
304 | """Matrix mapping states to feature vectors."""
305 | # Construct lazily to save memory in algorithms that don't need features.
306 | if self._feature_matrix is None:
307 | n_states = self.state_space.n
308 | self._feature_matrix = np.eye(n_states)
309 | return self._feature_matrix
310 |
311 | @property
312 | def state_dim(self):
313 | """Number of states in this MDP (int)."""
314 | return self.transition_matrix.shape[0]
315 |
316 | @property
317 | def action_dim(self) -> int:
318 | """Number of action vectors (int)."""
319 | return self.transition_matrix.shape[1]
320 |
321 |
322 | ObsEntryType = TypeVar(
323 | "ObsEntryType",
324 | bound=Union[np.floating, np.integer],
325 | )
326 |
327 |
328 | class TabularModelPOMDP(BaseTabularModelPOMDP[np.ndarray], Generic[ObsEntryType]):
329 | """Tabular model POMDP.
330 |
331 | This class is specifically for environments where observation != state,
332 |     both from a typing perspective and by defining the method that
333 | draws observations from the state.
334 |
335 | The tabular model is deterministic in drawing observations from the state,
336 | in that given a certain state, the observation is always the same;
337 | a vector with self.obs_dim entries.
338 | """
339 |
340 | observation_matrix: npt.NDArray[ObsEntryType]
341 | observation_space: spaces.Box
342 |
343 | def __init__(
344 | self,
345 | *,
346 | transition_matrix: np.ndarray,
347 | observation_matrix: npt.NDArray[ObsEntryType],
348 | reward_matrix: np.ndarray,
349 | horizon: Optional[int] = None,
350 | initial_state_dist: Optional[np.ndarray] = None,
351 | ):
352 | """Initializes a tabular model POMDP."""
353 | self.observation_matrix = observation_matrix
354 | super().__init__(
355 | transition_matrix=transition_matrix,
356 | reward_matrix=reward_matrix,
357 | horizon=horizon,
358 | initial_state_dist=initial_state_dist,
359 | )
360 |
361 | # observation matrix: n_states x n_observations
362 | if observation_matrix.shape[0] != self.state_dim:
363 | raise ValueError(
364 | "transition_matrix and observation_matrix are not compatible:\n"
365 | f"transition_matrix.shape[0]: {self.state_dim}\n"
366 | f"observation_matrix.shape[0]: {observation_matrix.shape[0]}",
367 | )
368 |
369 | min_val: float
370 | max_val: float
371 | try:
372 | dtype_iinfo = np.iinfo(self.obs_dtype)
373 | min_val, max_val = dtype_iinfo.min, dtype_iinfo.max
374 | except ValueError:
375 | min_val = -np.inf
376 | max_val = np.inf
377 | self.observation_space = spaces.Box(
378 | low=min_val,
379 | high=max_val,
380 | shape=(self.obs_dim,),
381 | dtype=self.obs_dtype, # type: ignore
382 | )
383 |
384 | def obs_from_state(self, state: DiscreteSpaceInt) -> npt.NDArray[ObsEntryType]:
385 | """Computes observation from state."""
386 | # Copy so it can't be mutated in-place (updates will be reflected in
387 | # self.observation_matrix!)
388 | obs = self.observation_matrix[state].copy()
389 | assert obs.ndim == 1, obs.shape
390 | return obs
391 |
392 | @property
393 | def obs_dim(self) -> int:
394 | """Size of observation vectors for this MDP."""
395 | return self.observation_matrix.shape[1]
396 |
397 | @property
398 | def obs_dtype(self) -> np.dtype:
399 | """Data type of observation vectors (e.g. np.float32)."""
400 | return self.observation_matrix.dtype
401 |
402 |
403 | class TabularModelMDP(BaseTabularModelPOMDP[DiscreteSpaceInt]):
404 | """Tabular model MDP.
405 |
406 | A tabular model MDP is a tabular MDP where the transition and reward
407 | matrices are constant.
408 | """
409 |
410 | def __init__(
411 | self,
412 | *,
413 | transition_matrix: np.ndarray,
414 | reward_matrix: np.ndarray,
415 | horizon: Optional[int] = None,
416 | initial_state_dist: Optional[np.ndarray] = None,
417 | ):
418 | """Initializes a tabular model MDP.
419 |
420 | Args:
421 | transition_matrix: Matrix of shape `(n_states, n_actions, n_states)`
422 | containing transition probabilities.
423 | reward_matrix: Matrix of shape `(n_states, n_actions, n_states)`
424 | containing reward values.
425 | initial_state_dist: Distribution over initial states. Shape `(n_states,)`.
426 | horizon: Maximum number of steps to take in an episode.
427 | """
428 | super().__init__(
429 | transition_matrix=transition_matrix,
430 | reward_matrix=reward_matrix,
431 | horizon=horizon,
432 | initial_state_dist=initial_state_dist,
433 | )
434 | self.observation_space = self.state_space
435 |
436 | def obs_from_state(self, state: DiscreteSpaceInt) -> DiscreteSpaceInt:
437 | """Identity since observation == state in an MDP."""
438 | return state
439 |
--------------------------------------------------------------------------------
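A small self-contained sketch (not part of the repository) of the matrix shapes the tabular classes above expect: a two-state, two-action TabularModelMDP with a state-only reward.

import numpy as np

from seals.base_envs import TabularModelMDP

# transition_matrix: (n_states, n_actions, n_states); reward_matrix: (n_states,)
transition = np.zeros((2, 2, 2))
transition[0, 0, 0] = 1.0  # action 0 keeps the agent in state 0
transition[0, 1, 1] = 1.0  # action 1 moves the agent to state 1
transition[1, :, 1] = 1.0  # state 1 is absorbing
reward = np.array([0.0, 1.0])  # 1-D reward: depends on the origin state only

env = TabularModelMDP(
    transition_matrix=transition,
    reward_matrix=reward,
    horizon=5,  # terminal() becomes True once 5 actions have been taken
)
obs, info = env.reset(seed=0)
obs, rew, terminated, truncated, info = env.step(1)
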
/src/seals/classic_control.py:
--------------------------------------------------------------------------------
1 | """Adaptation of classic Gym environments for specification learning algorithms."""
2 |
3 | from typing import Any, Dict, Optional
4 | import warnings
5 |
6 | import gymnasium as gym
7 | from gymnasium import spaces
8 | from gymnasium.envs import classic_control
9 | import numpy as np
10 |
11 | from seals import util
12 |
13 |
14 | class FixedHorizonCartPole(classic_control.CartPoleEnv):
15 | """Fixed-length variant of CartPole-v1.
16 |
17 | Reward is 1.0 whenever the CartPole is an "ok" state (i.e., the pole is upright
18 | and the cart is on the screen). Otherwise, reward is 0.0.
19 |
20 | Terminated is always False.
21 | By default, this environment is wrapped in 'TimeLimit' with max steps 500,
22 | which sets `truncated` to true after that many steps.
23 | """
24 |
25 | def __init__(self, *args, **kwargs):
26 | """Builds FixedHorizonCartPole, modifying observation_space from gym parent."""
27 | super().__init__(*args, **kwargs)
28 |
29 | high = [
30 | np.finfo(np.float32).max, # x axis
31 | np.finfo(np.float32).max, # x velocity
32 | np.pi, # theta in radians
33 | np.finfo(np.float32).max, # theta velocity
34 | ]
35 | high = np.array(high)
36 | self.observation_space = spaces.Box(-high, high, dtype=np.float32)
37 |
38 | def reset(
39 | self,
40 | seed: Optional[int] = None,
41 | options: Optional[Dict[str, Any]] = None,
42 | ):
43 | """Reset for FixedHorizonCartPole."""
44 | observation, info = super().reset(seed=seed, options=options)
45 | return observation.astype(np.float32), info
46 |
47 | def step(self, action):
48 | """Step function for FixedHorizonCartPole."""
49 | with warnings.catch_warnings():
50 | # Filter out CartPoleEnv warning for calling step() beyond
51 | # terminated=True or truncated=True
52 | warnings.filterwarnings("ignore", ".*You are calling.*")
53 | super().step(action)
54 |
55 | self.state = list(self.state)
56 | x, _, theta, _ = self.state
57 |
58 | # Normalize theta to [-pi, pi] range.
59 | theta = (theta + np.pi) % (2 * np.pi) - np.pi
60 | self.state[2] = theta
61 |
62 | state_ok = bool(
63 | abs(x) < self.x_threshold and abs(theta) < self.theta_threshold_radians,
64 | )
65 |
66 | rew = 1.0 if state_ok else 0.0
67 | return np.array(self.state, dtype=np.float32), rew, False, False, {}
68 |
69 |
70 | def mountain_car(*args, **kwargs):
71 | """Fixed-length variant of MountainCar-v0.
72 |
73 | In the event of early episode completion (i.e., the car reaches the
74 | goal), we enter an absorbing state that repeats the final observation
75 | and returns reward 0.
76 |
77 | Done is always returned on timestep 200 only.
78 | """
79 | env = gym.make("MountainCar-v0", *args, **kwargs)
80 | env = util.ObsCastWrapper(env, dtype=np.float32)
81 | env = util.AbsorbAfterDoneWrapper(env)
82 | return env
83 |
--------------------------------------------------------------------------------
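A short rollout sketch illustrating the fixed-horizon behaviour documented above; "seals/CartPole-v0" never terminates, so the TimeLimit wrapper truncates every episode at 500 steps.

import gymnasium as gym

import seals  # noqa: F401

env = gym.make("seals/CartPole-v0")
obs, info = env.reset(seed=0)
steps, truncated = 0, False
while not truncated:
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    assert not terminated  # FixedHorizonCartPole always returns terminated=False
    steps += 1
assert steps == 500  # max_episode_steps from the registration in __init__.py
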
/src/seals/diagnostics/__init__.py:
--------------------------------------------------------------------------------
1 | """Simple diagnostic environments."""
2 |
3 | import gymnasium as gym
4 |
5 | gym.register(
6 | id="seals/Branching-v0",
7 | entry_point="seals.diagnostics.branching:BranchingEnv",
8 | max_episode_steps=11,
9 | )
10 |
11 | gym.register(
12 | id="seals/EarlyTermNeg-v0",
13 | entry_point="seals.diagnostics.early_term:EarlyTermNegEnv",
14 | max_episode_steps=10,
15 | )
16 |
17 | gym.register(
18 | id="seals/EarlyTermPos-v0",
19 | entry_point="seals.diagnostics.early_term:EarlyTermPosEnv",
20 | max_episode_steps=10,
21 | )
22 |
23 | gym.register(
24 | id="seals/InitShiftTrain-v0",
25 | entry_point="seals.diagnostics.init_shift:InitShiftTrainEnv",
26 | max_episode_steps=3,
27 | )
28 |
29 | gym.register(
30 | id="seals/InitShiftTest-v0",
31 | entry_point="seals.diagnostics.init_shift:InitShiftTestEnv",
32 | max_episode_steps=3,
33 | )
34 |
35 | gym.register(
36 | id="seals/LargestSum-v0",
37 | entry_point="seals.diagnostics.largest_sum:LargestSumEnv",
38 | max_episode_steps=1,
39 | )
40 |
41 | gym.register(
42 | id="seals/NoisyObs-v0",
43 | entry_point="seals.diagnostics.noisy_obs:NoisyObsEnv",
44 | max_episode_steps=15,
45 | )
46 |
47 | gym.register(
48 | id="seals/Parabola-v0",
49 | entry_point="seals.diagnostics.parabola:ParabolaEnv",
50 | max_episode_steps=20,
51 | )
52 |
53 | gym.register(
54 | id="seals/ProcGoal-v0",
55 | entry_point="seals.diagnostics.proc_goal:ProcGoalEnv",
56 | max_episode_steps=20,
57 | )
58 |
59 | gym.register(
60 | id="seals/RiskyPath-v0",
61 | entry_point="seals.diagnostics.risky_path:RiskyPathEnv",
62 | max_episode_steps=5,
63 | )
64 |
65 | gym.register(
66 | id="seals/Sort-v0",
67 | entry_point="seals.diagnostics.sort:SortEnv",
68 | max_episode_steps=6,
69 | )
70 |
71 |
72 | def register_cliff_world(suffix, kwargs):
73 | """Register a CliffWorld with the given suffix and keyword arguments."""
74 | gym.register(
75 | f"seals/CliffWorld{suffix}-v0",
76 | entry_point="seals.diagnostics.cliff_world:CliffWorldEnv",
77 | kwargs=kwargs,
78 | )
79 |
80 |
81 | def register_all_cliff_worlds():
82 | """Register all CliffWorld environments."""
83 | for width, height, horizon in [(7, 4, 9), (15, 6, 18), (100, 20, 110)]:
84 | for use_xy in [False, True]:
85 | use_xy_str = "XY" if use_xy else ""
86 | register_cliff_world(
87 | f"{width}x{height}{use_xy_str}",
88 | kwargs={
89 | "width": width,
90 | "height": height,
91 | "use_xy_obs": use_xy,
92 | "horizon": horizon,
93 | },
94 | )
95 |
96 |
97 | register_all_cliff_worlds()
98 |
99 | # These parameter choices are somewhat arbitrary.
100 | # We anticipate most users will want to construct RandomTransitionEnv directly.
101 | gym.register(
102 | "seals/Random-v0",
103 | entry_point="seals.diagnostics.random_trans:RandomTransitionEnv",
104 | kwargs={
105 | "n_states": 16,
106 | "n_actions": 3,
107 | "branch_factor": 2,
108 | "horizon": 20,
109 | "random_obs": True,
110 | "obs_dim": 5,
111 | "generator_seed": 42,
112 | },
113 | )
114 |
--------------------------------------------------------------------------------
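A sketch of creating the diagnostic environments registered above, including one of the parameterised CliffWorld IDs produced by `register_all_cliff_worlds`.

import gymnasium as gym

import seals.diagnostics  # noqa: F401  # registers the diagnostic IDs

early_term = gym.make("seals/EarlyTermPos-v0")
cliff = gym.make("seals/CliffWorld7x4-v0")  # width=7, height=4, tabular observations
random_mdp = gym.make("seals/Random-v0")    # RandomTransitionEnv with the defaults above
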
/src/seals/diagnostics/branching.py:
--------------------------------------------------------------------------------
1 | """Hard-exploration environment."""
2 |
3 | import itertools
4 |
5 | import numpy as np
6 |
7 | from seals import base_envs
8 |
9 |
10 | class BranchingEnv(base_envs.TabularModelMDP):
11 | """Long branching environment requiring exploration.
12 |
13 | The agent must traverse a specific path of length L to reach a
14 | final goal, with B choices at each step. Wrong actions lead to
15 | dead-ends with zero reward.
16 | """
17 |
18 | def __init__(self, branch_factor: int = 2, length: int = 10):
19 | """Construct environment.
20 |
21 | Args:
22 | branch_factor: number of actions at each state.
23 | length: path length from initial state to goal.
24 | """
25 | nS = 1 + branch_factor * length
26 | nA = branch_factor
27 |
28 | def get_next(state: int, action: int) -> int:
29 | can_move = state % branch_factor == 0 and state != nS - 1
30 | return state + (action + 1) * can_move
31 |
32 | transition_matrix = np.zeros((nS, nA, nS))
33 | for state, action in itertools.product(range(nS), range(nA)):
34 | transition_matrix[state, action, get_next(state, action)] = 1.0
35 |
36 | reward_matrix = np.zeros((nS,))
37 | reward_matrix[-1] = 1.0
38 |
39 | super().__init__(
40 | transition_matrix=transition_matrix,
41 | reward_matrix=reward_matrix,
42 | )
43 |
--------------------------------------------------------------------------------
/src/seals/diagnostics/cliff_world.py:
--------------------------------------------------------------------------------
1 | """A cliff world that uses the TabularModelPOMDP."""
2 |
3 | import numpy as np
4 |
5 | from seals.base_envs import TabularModelPOMDP
6 |
7 |
8 | class CliffWorldEnv(TabularModelPOMDP):
9 | """A grid world with a goal next to a cliff the agent may fall into.
10 |
11 | Illustration::
12 |
13 | 0 1 2 3 4 5 6 7 8 9
14 | +-+-+-+-+-+-+-+-+-+-+ Wind:
15 | 0 |S|C|C|C|C|C|C|C|C|G|
16 | +-+-+-+-+-+-+-+-+-+-+ ^ ^ ^
17 | 1 | | | | | | | | | | | | | |
18 | +-+-+-+-+-+-+-+-+-+-+
19 | 2 | | | | | | | | | | | ^ ^ ^
20 | +-+-+-+-+-+-+-+-+-+-+ | | |
21 |
22 |     The aim is to get from S to G. The G square has reward +10, the C squares
23 |     ("cliff") have reward -10, and all other squares have reward -1. The agent
24 |     can move in all directions (except through walls), but there is a 30% chance
25 |     that they will be blown upwards by one more unit than intended due to wind.
26 | Optimal policy is to go out a bit and avoid the cliff, but still hit goal
27 | eventually.
28 | """
29 |
30 | width: int
31 | height: int
32 |
33 | def __init__(
34 | self,
35 | *,
36 | width: int,
37 | height: int,
38 | horizon: int,
39 | use_xy_obs: bool,
40 | rew_default: int = -1,
41 | rew_goal: int = 10,
42 | rew_cliff: int = -10,
43 | fail_p: float = 0.3,
44 | ):
45 | """Builds CliffWorld with specified dimensions and reward."""
46 | assert (
47 | width >= 3 and height >= 2
48 | ), "degenerate grid world requested; is this a bug?"
49 | self.width = width
50 | self.height = height
51 | succ_p = 1 - fail_p
52 | n_states = width * height
53 | O_mat = np.zeros(
54 | (n_states, 2 if use_xy_obs else n_states),
55 | dtype=np.float32,
56 | )
57 | R_vec = np.zeros((n_states,))
58 | T_mat = np.zeros((n_states, 4, n_states))
59 |
60 | def to_id_clamp(row, col):
61 | """Convert (x,y) state to state ID, after clamp x & y to lie in grid."""
62 | row = min(max(row, 0), height - 1)
63 | col = min(max(col, 0), width - 1)
64 | state_id = row * width + col
65 | assert 0 <= state_id < T_mat.shape[0]
66 | return state_id
67 |
68 | for row in range(height):
69 | for col in range(width):
70 | state_id = to_id_clamp(row, col)
71 |
72 | # start by computing reward
73 | if row > 0:
74 | r = rew_default # blank
75 | elif col == 0:
76 | r = rew_default # start
77 | elif col == width - 1:
78 | r = rew_goal # goal
79 | else:
80 | r = rew_cliff # cliff
81 | R_vec[state_id] = r
82 |
83 | # now compute observation
84 | if use_xy_obs:
85 | # (x, y) coordinate scaled to (0,1)
86 | O_mat[state_id, :] = [
87 | float(col) / (width - 1),
88 | float(row) / (height - 1),
89 | ]
90 | else:
91 | # our observation matrix is just the identity; observation
92 | # is an indicator vector telling us exactly what state
93 | # we're in
94 | O_mat[state_id, state_id] = 1
95 |
96 | # finally, compute transition matrix entries for each of the
97 | # four actions
98 | for drow in [-1, 1]:
99 | for dcol in [-1, 1]:
100 | action_id = (drow + 1) + (dcol + 1) // 2
101 | target_state = to_id_clamp(row + drow, col + dcol)
102 | fail_state = to_id_clamp(row + drow - 1, col + dcol)
103 | T_mat[state_id, action_id, fail_state] += fail_p
104 | T_mat[state_id, action_id, target_state] += succ_p
105 |
106 | assert np.allclose(np.sum(T_mat, axis=-1), 1, rtol=1e-5), (
107 |             "un-normalised transition matrix %s" % T_mat
108 | )
109 | super().__init__(
110 | transition_matrix=T_mat,
111 | observation_matrix=O_mat,
112 | reward_matrix=R_vec,
113 | horizon=horizon,
114 | initial_state_dist=None,
115 | )
116 |
117 | def draw_value_vec(self, D: np.ndarray) -> None:
118 | """Use matplotlib to plot a vector of values for each state.
119 |
120 | The vector could represent things like reward, occupancy measure, etc.
121 |
122 | Args:
123 | D: the vector to plot.
124 |
125 | Raises:
126 | ImportError: if matplotlib is not installed.
127 | """
128 | try: # pragma: no cover
129 | import matplotlib.pyplot as plt
130 | except ImportError as exc: # pragma: no cover
131 | raise ImportError(
132 | "matplotlib is not installed in your system, "
133 | "and is required for this function.",
134 | ) from exc
135 |
136 | grid = D.reshape(self.height, self.width)
137 | plt.imshow(grid)
138 | plt.gca().grid(False)
139 |
--------------------------------------------------------------------------------
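A sketch of constructing CliffWorldEnv directly and plotting its per-state reward with `draw_value_vec`; matplotlib (an optional dependency) is assumed to be installed.

import matplotlib.pyplot as plt

from seals.diagnostics.cliff_world import CliffWorldEnv

env = CliffWorldEnv(width=7, height=4, horizon=9, use_xy_obs=True)
env.draw_value_vec(env.reward_matrix)  # one value per state, reshaped onto the grid
plt.show()
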
/src/seals/diagnostics/early_term.py:
--------------------------------------------------------------------------------
1 | """Environment checking for correctness under early termination."""
2 |
3 | import functools
4 |
5 | import numpy as np
6 |
7 | from seals import base_envs
8 |
9 |
10 | class EarlyTerminationEnv(base_envs.TabularModelMDP):
11 | """Three-state MDP with early termination state.
12 |
13 | Many implementations of imitation learning algorithms incorrectly assign a
14 | value of zero to terminal states [1]. Depending on the sign of the learned
15 | reward function in non-terminal states, this can either bias the agent to
16 | end episodes early or prolong them as long as possible. This confounds
17 | evaluation as performance is spuriously high in tasks where the termination
18 | bias aligns with the task objective. These tasks attempt to detect this
19 | type of bias, and they are adapted from [1].
20 |
21 | The environment is a 3-state MDP, in which the agent can either alternate
22 | between two initial states until reaching the time horizon, or they can
23 | move to a terminal state causing the episode to terminate early.
24 |
25 | [1] Kostrikov, Ilya, et al. "Discriminator-actor-critic: Addressing sample
26 | inefficiency and reward bias in adversarial imitation learning." arXiv
27 | preprint arXiv:1809.02925 (2018).
28 | """
29 |
30 | def __init__(self, is_reward_positive: bool = True):
31 | """Construct environment.
32 |
33 | Args:
34 | is_reward_positive: whether rewards are positive or negative.
35 | """
36 | nS = 3
37 | nA = 2
38 |
39 | transition_matrix = np.zeros((nS, nA, nS))
40 |
41 | transition_matrix[0, :, 1] = 1.0
42 |
43 | transition_matrix[1, 0, 0] = 1.0
44 | transition_matrix[1, 1, 2] = 1.0
45 |
46 | transition_matrix[2, :, 2] = 1.0
47 |
48 | reward_sign = 2 * is_reward_positive - 1
49 | reward_matrix = reward_sign * np.ones((nS,), dtype=float)
50 |
51 | super().__init__(
52 | transition_matrix=transition_matrix,
53 | reward_matrix=reward_matrix,
54 | )
55 |
56 | def terminal(self, state: base_envs.DiscreteSpaceInt, n_actions_taken: int) -> bool:
57 | """Returns True if (and only if) in state 2."""
58 | return bool(state == 2)
59 |
60 |
61 | EarlyTermPosEnv = functools.partial(EarlyTerminationEnv, is_reward_positive=True)
62 | EarlyTermNegEnv = functools.partial(EarlyTerminationEnv, is_reward_positive=False)
63 |
--------------------------------------------------------------------------------
/src/seals/diagnostics/init_shift.py:
--------------------------------------------------------------------------------
1 | """Environment with shift in initial state distribution."""
2 |
3 | import functools
4 | import itertools
5 |
6 | import numpy as np
7 |
8 | from seals import base_envs
9 |
10 |
11 | class InitShiftEnv(base_envs.TabularModelMDP):
12 | """Tests for robustness to initial state shift.
13 |
14 | Many LfH algorithms learn from expert demonstrations. This can be
15 | problematic when the environment the demonstrations were gathered in
16 | differs even slightly from the learner's environment.
17 |
18 | This task illustrates this problem. We have a depth-2 full binary tree
19 | where the agent moves left or right until reaching a leaf. The expert
20 | starts at the root s_0, whereas the learner starts at the left branch s_1
21 | and so can only reach leaves s_3 and s_4. Reward is only given at the
22 | leaves.
23 |
24 |     The expert always moves to the highest reward leaf s_6, so any algorithm
25 | that relies on demonstrations will not know whether it is better to go to
26 | s_3 or s_4. By contrast, feedback such as preference comparison can
27 | disambiguate this case.
28 | """
29 |
30 | def __init__(self, initial_state: base_envs.DiscreteSpaceInt):
31 | """Constructs environment.
32 |
33 | Args:
34 | initial_state: fixed initial state.
35 |
36 | Raises:
37 | ValueError: `initial_state` not in [0,6].
38 | """
39 | nS = 7
40 | nA = 2
41 |
42 | if not 0 <= initial_state < nS:
43 | raise ValueError(f"Initial state {initial_state} must lie in [0,{nS})")
44 |
45 | self._initial_state = initial_state
46 |
47 | non_leaves = np.arange(3)
48 | leaves = np.arange(3, 7)
49 |
50 | transition_matrix = np.zeros((nS, nA, nS))
51 |
52 | for state, action in itertools.product(non_leaves, range(nA)):
53 | next_state = 2 * state + 1 + action
54 | transition_matrix[state, action, next_state] = 1.0
55 |
56 | transition_matrix[leaves, :, leaves] = 1.0
57 |
58 | reward_matrix = np.zeros((nS,))
59 | reward_matrix[leaves] = [1, -1, -1, 2]
60 |
61 | super().__init__(
62 | transition_matrix=transition_matrix,
63 | reward_matrix=reward_matrix,
64 | )
65 |
66 | def initial_state(self) -> base_envs.DiscreteSpaceInt:
67 | """Returns initial state defined in constructor."""
68 | return self._initial_state
69 |
70 |
71 | InitShiftTrainEnv = functools.partial(InitShiftEnv, initial_state=0)
72 | InitShiftTestEnv = functools.partial(InitShiftEnv, initial_state=1)
73 |
--------------------------------------------------------------------------------
/src/seals/diagnostics/largest_sum.py:
--------------------------------------------------------------------------------
1 | """Environment testing scalability to high-dimensional tasks."""
2 |
3 | from gymnasium import spaces
4 | import numpy as np
5 |
6 | from seals import base_envs
7 |
8 |
9 | class LargestSumEnv(base_envs.ResettableMDP):
10 | """High-dimensional linear classification problem.
11 |
12 | This environment evaluates how algorithms scale with increasing
13 | dimensionality. It is a classification task with binary actions
14 | and uniformly sampled states s in [0, 1]**L. The agent is
15 |     rewarded for taking action 1 if the sum of the first half s[:L//2]
16 |     is greater than the sum of the second half s[L//2:], and otherwise
17 | is rewarded for taking action 0.
18 | """
19 |
20 | def __init__(self, length: int = 50):
21 | """Build environment.
22 |
23 | Args:
24 | length: dimensionality of state space vector.
25 | """
26 | super().__init__()
27 | self._length = length
28 | self.state_space = spaces.Box(low=0.0, high=1.0, shape=(length,))
29 | self.action_space = spaces.Discrete(2)
30 |
31 | def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool:
32 | """Always returns True, since this task should have a 1-timestep horizon."""
33 | return True
34 |
35 | def initial_state(self) -> np.ndarray:
36 | """Returns vector sampled uniformly in [0, 1]**L."""
37 | init_state = self.np_random.random((self._length,))
38 | return init_state.astype(self.observation_space.dtype)
39 |
40 | def reward(self, state: np.ndarray, act: int, next_state: np.ndarray) -> float:
41 | """Returns +1.0 reward when action is the right label and 0.0 otherwise."""
42 | n = self._length
43 | label = np.sum(state[: n // 2]) > np.sum(state[n // 2 :])
44 | return float(act == label)
45 |
46 | def transition(self, state: np.ndarray, action: int) -> np.ndarray:
47 | """Returns same state."""
48 | return state
49 |
--------------------------------------------------------------------------------
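A one-episode sketch of the labelling rule described in the docstring: action 1 is rewarded exactly when the first half of the state sums higher than the second half.

import numpy as np

from seals.diagnostics.largest_sum import LargestSumEnv

env = LargestSumEnv(length=4)
obs, info = env.reset(seed=0)
label = int(np.sum(obs[:2]) > np.sum(obs[2:]))  # 1 iff the first half sums higher
obs, reward, terminated, truncated, info = env.step(label)
assert reward == 1.0 and terminated  # single-timestep horizon
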
/src/seals/diagnostics/noisy_obs.py:
--------------------------------------------------------------------------------
1 | """Environment testing for robustness to noise."""
2 |
3 | from gymnasium import spaces
4 | import numpy as np
5 |
6 | from seals import base_envs, util
7 |
8 |
9 | class NoisyObsEnv(base_envs.ResettablePOMDP):
10 | """Simple gridworld with noisy observations.
11 |
12 |     The agent randomly starts at one of the corners of an MxM grid and
13 |     tries to reach and stay at the center. The observation consists of the
14 |     agent's (x,y) coordinates and L "distractor" samples of Gaussian noise.
15 | The challenge is to select the relevant features in the observations, and
16 | not overfit to noise.
17 | """
18 |
19 | def __init__(self, *, size: int = 5, noise_length: int = 20):
20 | """Build environment.
21 |
22 | Args:
23 | size: width and height of gridworld.
24 | noise_length: dimension of noise vector in observation.
25 | """
26 | super().__init__()
27 |
28 | self._size = size
29 | self._noise_length = noise_length
30 | self._goal = np.array([self._size // 2, self._size // 2])
31 |
32 | obs_box_low = np.concatenate(
33 | ([0, 0], np.full(self._noise_length, -np.inf)), # type: ignore
34 | )
35 | obs_box_high = np.concatenate(
36 | ([size - 1, size - 1], np.full(self._noise_length, np.inf)), # type: ignore
37 | )
38 |
39 | self.state_space = spaces.MultiDiscrete([size, size])
40 | self.action_space = spaces.Discrete(5)
41 | self.observation_space = spaces.Box(
42 | low=obs_box_low,
43 | high=obs_box_high,
44 | dtype=np.float32,
45 | )
46 |
47 | def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool:
48 | """Always returns False."""
49 | return False
50 |
51 | def initial_state(self) -> np.ndarray:
52 | """Returns one of the grid's corners."""
53 | n = self._size
54 | corners = np.array([[0, 0], [n - 1, 0], [0, n - 1], [n - 1, n - 1]])
55 | return corners[self.np_random.integers(4)]
56 |
57 | def reward(self, state: np.ndarray, action: int, new_state: np.ndarray) -> float:
58 | """Returns +1.0 reward if state is the goal and 0.0 otherwise."""
59 | return float(np.all(state == self._goal))
60 |
61 | def transition(self, state: np.ndarray, action: int) -> np.ndarray:
62 | """Returns next state according to grid."""
63 | return util.grid_transition_fn(
64 | state,
65 | action,
66 | x_bounds=(0, self._size - 1),
67 | y_bounds=(0, self._size - 1),
68 | )
69 |
70 | def obs_from_state(self, state: np.ndarray) -> np.ndarray:
71 | """Returns (x, y) concatenated with Gaussian noise."""
72 | noise_vector = self.np_random.normal(size=self._noise_length)
73 | return np.concatenate([state, noise_vector]).astype(np.float32)
74 |
--------------------------------------------------------------------------------
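A sketch of the observation layout described above: the first two entries are the agent's grid coordinates, and the remaining `noise_length` entries are fresh Gaussian noise at every observation.

from seals.diagnostics.noisy_obs import NoisyObsEnv

env = NoisyObsEnv(size=5, noise_length=20)
obs, info = env.reset(seed=0)
position, distractors = obs[:2], obs[2:]
assert obs.shape == (22,)  # 2 coordinates + 20 noise dimensions
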
/src/seals/diagnostics/parabola.py:
--------------------------------------------------------------------------------
1 | """Environment testing for generalization in continuous spaces."""
2 |
3 | from gymnasium import spaces
4 | import numpy as np
5 |
6 | from seals import base_envs
7 |
8 |
9 | class ParabolaEnv(base_envs.ResettableMDP):
10 | """Environment to mimic parabola curves.
11 |
12 | This environment tests algorithms' ability to learn in continuous
13 | action spaces, a challenge for Q-learning methods in particular.
14 | The goal is to mimic the path of a parabola p(x) = A*x**2 + B*x +
15 | C, where A, B and C are constants sampled uniformly from [-1, 1]
16 | at the start of the episode.
17 | """
18 |
19 | def __init__(self, x_step: float = 0.05, bounds: float = 5):
20 | """Construct environment.
21 |
22 | Args:
23 | x_step: x position difference between timesteps.
24 | bounds: limits coordinates, useful for keeping rewards in
25 | a small bounded range.
26 | """
27 | super().__init__()
28 | self._x_step = x_step
29 | self._bounds = bounds
30 |
31 | state_high = np.array([bounds, bounds, 1.0, 1.0, 1.0])
32 | state_low = (-1) * state_high
33 |
34 | self.state_space = spaces.Box(low=state_low, high=state_high)
35 | self.action_space = spaces.Box(low=(-2) * bounds, high=2 * bounds, shape=())
36 |
37 |     def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool:
38 | """Always returns False."""
39 | return False
40 |
41 | def initial_state(self) -> np.ndarray:
42 | """Get state by sampling a random parabola."""
43 | a, b, c = -1 + 2 * self.np_random.random((3,))
44 | x, y = 0, c
45 | return np.array([x, y, a, b, c], dtype=self.state_space.dtype)
46 |
47 | def reward(self, state: np.ndarray, action: int, new_state: np.ndarray) -> float:
48 | """Negative squared vertical distance from parabola."""
49 | x, y, a, b, c = state
50 | target_y = a * x**2 + b * x + c
51 | return (-1) * (y - target_y) ** 2
52 |
53 | def transition(self, state: np.ndarray, action: int) -> np.ndarray:
54 | """Update x according to x_step and y according to action."""
55 | x, y, a, b, c = state
56 | next_x = np.clip(x + self._x_step, -self._bounds, self._bounds)
57 | next_y = np.clip(y + action, -self._bounds, self._bounds)
58 | return np.array([next_x, next_y, a, b, c], dtype=self.state_space.dtype)
59 |
--------------------------------------------------------------------------------
/src/seals/diagnostics/proc_goal.py:
--------------------------------------------------------------------------------
1 | """Large gridworld with random agent and goal position."""
2 |
3 | from gymnasium import spaces
4 | import numpy as np
5 |
6 | from seals import base_envs, util
7 |
8 |
9 | class ProcGoalEnv(base_envs.ResettableMDP):
10 | """Large gridworld with random agent and goal position.
11 |
12 | In this task, the agent starts at a random position in a large
13 | grid, and must navigate to a goal randomly placed in a
14 | neighborhood around the agent. The observation is a 4-dimensional
15 | vector containing the (x,y) coordinates of the agent and the goal.
16 | The reward at each timestep is the negative Manhattan distance
17 | between the two positions. With a large enough grid, generalizing
18 | is necessary to achieve good performance, since most initial
19 | states will be unseen.
20 | """
21 |
22 | def __init__(self, bounds: int = 100, distance: int = 10):
23 | """Constructs environment.
24 |
25 | Args:
26 | bounds: the absolute values of the coordinates of the initial agent
27 | position are bounded by `bounds`. Increasing the value might make
28 | generalization harder.
29 | distance: initial distance between agent and goal.
30 | """
31 | super().__init__()
32 | self._bounds = bounds
33 | self._distance = distance
34 | self.state_space = spaces.Box(low=-np.inf, high=np.inf, shape=(4,))
35 | self.action_space = spaces.Discrete(5)
36 |
37 | def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool:
38 | """Always returns False."""
39 | return False
40 |
41 | def initial_state(self) -> np.ndarray:
42 | """Samples random agent position and random goal."""
43 | pos = self.np_random.integers(low=-self._bounds, high=self._bounds, size=(2,))
44 |
45 | x_dist = self.np_random.integers(self._distance)
46 | y_dist = self._distance - x_dist
47 | random_signs = 2 * self.np_random.integers(2, size=2) - 1
48 | goal = pos + random_signs * (x_dist, y_dist)
49 |
50 | return np.concatenate([pos, goal]).astype(self.observation_space.dtype)
51 |
52 | def reward(self, state: np.ndarray, action: int, new_state: np.ndarray) -> float:
53 | """Negative L1 distance to goal."""
54 | return (-1) * np.sum(np.abs(state[2:] - state[:2]))
55 |
56 | def transition(self, state: np.ndarray, action: int) -> np.ndarray:
57 | """Returns next state according to grid."""
58 | pos, goal = state[:2], state[2:]
59 | next_pos = util.grid_transition_fn(pos, action)
60 | return np.concatenate([next_pos, goal])
61 |
--------------------------------------------------------------------------------
/src/seals/diagnostics/random_trans.py:
--------------------------------------------------------------------------------
1 | """A tabular model MDP with a random transition matrix."""
2 |
3 | from typing import Optional
4 |
5 | import numpy as np
6 |
7 | from seals.base_envs import TabularModelPOMDP
8 |
9 |
10 | class RandomTransitionEnv(TabularModelPOMDP):
11 |     """An MDP with a random transition matrix.
12 |
13 | Random matrix is created by `make_random_trans_mat`.
14 | """
15 |
16 | reward_weights: np.ndarray
17 |
18 | def __init__(
19 | self,
20 | *,
21 | n_states: int,
22 | n_actions: int,
23 | branch_factor: int,
24 | horizon: int,
25 | random_obs: bool,
26 | obs_dim: Optional[int] = None,
27 | generator_seed: Optional[int] = None,
28 | ):
29 | """Builds RandomTransitionEnv.
30 |
31 | Args:
32 | n_states: Number of states.
33 | n_actions: Number of actions.
34 | branch_factor: Maximum number of states that can be reached from
35 | each state-action pair.
36 | horizon: The horizon of the MDP, i.e. the episode length.
37 | random_obs: Whether to use random observations (True)
38 | or one-hot coded (False).
39 | obs_dim: The size of the observation vectors; must be `None`
40 | if `random_obs == False`.
41 | generator_seed: Seed for NumPy RNG.
42 |
43 | Raises:
44 | ValueError: If ``obs_dim`` is not ``None`` when ``random_obs == False``.
45 | ValueError: If ``obs_dim`` is ``None`` when ``random_obs == True``.
46 | """
47 | # this generator is ONLY for constructing the MDP, not for controlling
48 | # random outcomes during rollouts
49 | rand_gen = np.random.default_rng(generator_seed)
50 |
51 | if random_obs:
52 | if obs_dim is None:
53 | obs_dim = n_states
54 | else:
55 | if obs_dim is not None:
56 | raise ValueError("obs_dim must be None if random_obs is False")
57 |
58 | observation_matrix = self.make_obs_mat(
59 | n_states=n_states,
60 | is_random=random_obs,
61 | obs_dim=obs_dim,
62 | rand_state=rand_gen,
63 | )
64 | transition_matrix = self.make_random_trans_mat(
65 | n_states=n_states,
66 | n_actions=n_actions,
67 | max_branch_factor=branch_factor,
68 | rand_state=rand_gen,
69 | )
70 | initial_state_dist = self.make_random_state_dist(
71 | n_avail=branch_factor,
72 | n_states=n_states,
73 | rand_state=rand_gen,
74 | )
75 |
76 | self.reward_weights = rand_gen.normal(size=(observation_matrix.shape[-1],))
77 | reward_matrix = observation_matrix @ self.reward_weights
78 | super().__init__(
79 | transition_matrix=transition_matrix,
80 | observation_matrix=observation_matrix,
81 | reward_matrix=reward_matrix,
82 | horizon=horizon,
83 | initial_state_dist=initial_state_dist,
84 | )
85 |
86 | @staticmethod
87 | def make_random_trans_mat(
88 | n_states,
89 | n_actions,
90 | max_branch_factor,
91 | rand_state: Optional[np.random.Generator] = None,
92 | ) -> np.ndarray:
93 | """Make a 'random' transition matrix.
94 |
95 | Each action goes to at least `max_branch_factor` other states from the
96 | current state, with transition distribution sampled from Dirichlet(1,1,…,1).
97 |
98 | This roughly apes the strategy from some old Lisp code that Rich Sutton
99 | left on the internet (http://incompleteideas.net/RandomMDPs.html), and is
100 | therefore a legitimate way to generate MDPs.
101 |
102 | Args:
103 | n_states: Number of states.
104 | n_actions: Number of actions.
105 | max_branch_factor: Maximum number of states that can be reached from
106 | each state-action pair.
107 | rand_state: NumPy random state.
108 |
109 | Returns:
110 | The transition matrix `mat`, where `mat[s,a,next_s]` gives the probability
111 | of transitioning to `next_s` after taking action `a` in state `s`.
112 | """
113 | if rand_state is None:
114 | rand_state = np.random.default_rng()
115 | assert rand_state is not None
116 | out_mat = np.zeros((n_states, n_actions, n_states), dtype="float32")
117 | for start_state in range(n_states):
118 | for action in range(n_actions):
119 | # uniformly sample a number of successors in [1,max_branch_factor]
120 | # for this action
121 | successors = rand_state.integers(1, max_branch_factor + 1)
122 | next_states = rand_state.choice(
123 | n_states,
124 | size=(successors,),
125 | replace=False,
126 | )
127 | # generate random vec in probability simplex
128 | next_vec = rand_state.dirichlet(np.ones((successors,)))
129 | next_vec = next_vec / np.sum(next_vec)
130 | out_mat[start_state, action, next_states] = next_vec
131 | return out_mat
132 |
133 | @staticmethod
134 | def make_random_state_dist(
135 | n_avail: int,
136 | n_states: int,
137 | rand_state: Optional[np.random.Generator] = None,
138 | ) -> np.ndarray:
139 | """Make a random initial state distribution over n_states.
140 |
141 | Args:
142 | n_avail: Number of states available to transition into.
143 | n_states: Total number of states.
144 | rand_state: NumPy random state.
145 |
146 | Returns:
147 |             An initial state distribution that is zero at all but a uniformly
148 |             randomly chosen subset of `n_avail` states. This subset of states is set to a
149 | sample from the uniform distribution over the (n_avail-1) simplex, aka the
150 | flat Dirichlet distribution.
151 |
152 | Raises:
153 | ValueError: If `n_avail` is not in the range `(0, n_states]`.
154 | """ # noqa: DAR402
155 | if rand_state is None:
156 | rand_state = np.random.default_rng()
157 | assert rand_state is not None
158 | assert 0 < n_avail <= n_states
159 | init_dist = np.zeros((n_states,))
160 | next_states = rand_state.choice(n_states, size=(n_avail,), replace=False)
161 | avail_state_dist = rand_state.dirichlet(np.ones((n_avail,)))
162 | init_dist[next_states] = avail_state_dist
163 | assert np.sum(init_dist > 0) == n_avail
164 | init_dist = init_dist / np.sum(init_dist)
165 | return init_dist
166 |
167 | @staticmethod
168 | def make_obs_mat(
169 | n_states: int,
170 | is_random: bool,
171 | obs_dim: Optional[int] = None,
172 | rand_state: Optional[np.random.Generator] = None,
173 | ) -> np.ndarray:
174 | """Makes an observation matrix with a single observation for each state.
175 |
176 | Args:
177 | n_states (int): Number of states.
178 | is_random (bool): Are observations drawn at random?
179 | If `True`, draw from random normal distribution.
180 |                 If `False`, observations are unique one-hot vectors for each state.
181 | obs_dim (int or NoneType): Must be `None` if `is_random == False`.
182 | Otherwise, this must be set to the size of the random vectors.
183 | rand_state (np.random.Generator): Random number generator.
184 |
185 | Returns:
186 | A matrix of shape `(n_states, obs_dim if is_random else n_states)`.
187 |
188 | Raises:
189 | ValueError: If ``is_random == False`` and ``obs_dim is not None``.
190 | """
191 | if rand_state is None:
192 | rand_state = np.random.default_rng()
193 | assert rand_state is not None
194 | if is_random:
195 | if obs_dim is None:
196 | raise ValueError("obs_dim must be set if random_obs is True")
197 | obs_mat = rand_state.normal(0, 2, (n_states, obs_dim))
198 | else:
199 | if obs_dim is not None:
200 | raise ValueError("obs_dim must be None if random_obs is False")
201 | obs_mat = np.identity(n_states)
202 | assert (
203 | obs_mat.ndim == 2
204 | and obs_mat.shape[:1] == (n_states,)
205 | and obs_mat.shape[1] > 0
206 | )
207 | return obs_mat.astype(np.float32)
208 |
--------------------------------------------------------------------------------
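A sketch showing that the static helpers above can be used on their own; every (state, action) row of the generated transition matrix is a probability distribution over next states.

import numpy as np

from seals.diagnostics.random_trans import RandomTransitionEnv

rng = np.random.default_rng(0)
T = RandomTransitionEnv.make_random_trans_mat(
    n_states=8,
    n_actions=3,
    max_branch_factor=2,
    rand_state=rng,
)
assert T.shape == (8, 3, 8)
assert np.allclose(T.sum(axis=-1), 1.0)

# Or build the full environment in one call:
env = RandomTransitionEnv(
    n_states=8,
    n_actions=3,
    branch_factor=2,
    horizon=10,
    random_obs=True,
    obs_dim=4,
    generator_seed=0,
)
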
/src/seals/diagnostics/risky_path.py:
--------------------------------------------------------------------------------
1 | """Environment testing for correct behavior under stochasticity."""
2 |
3 | import numpy as np
4 |
5 | from seals import base_envs
6 |
7 |
8 | class RiskyPathEnv(base_envs.TabularModelMDP):
9 | """Environment with two paths to a goal: one safe and one risky.
10 |
11 | Many LfH algorithms are derived from Maximum Entropy Inverse Reinforcement
12 | Learning [1], which models the demonstrator as producing trajectories with
13 | probability p(tau) proportional to exp(R(tau)). This model implies that a
14 | demonstrator can "control" the environment well enough to follow any
15 | high-reward trajectory with high probability [2]. However, in stochastic
16 | environments, the agent cannot control the probability of each trajectory
17 | independently. This misspecification may lead to poor behavior.
18 |
19 | This task tests for this behavior. The agent starts at s_0 and can reach
20 | the goal s_2 (reward 1.0) by either taking the safe path s_0 to s_1 to s_2,
21 | or taking a risky action, which has equal chances of going to either s_3
22 | (reward -100.0) or s_2. The safe path has the highest expected return, but
23 | the risky action sometimes reaches the goal s_2 in fewer timesteps, leading
24 | to higher best-case return. Algorithms that fail to correctly handle
25 | stochastic dynamics may therefore wrongly believe the reward favors taking
26 | the risky path.
27 |
28 | [1] Ziebart, Brian D., et al. "Maximum entropy inverse reinforcement
29 | learning." AAAI. Vol. 8. 2008.
30 | [2] Ziebart, Brian D. "Modeling purposeful adaptive behavior with the
31 | principle of maximum causal entropy." (2010); PhD thesis,
32 | CMU-ML-10-110; page 105.
33 | """
34 |
35 | def __init__(self):
36 | """Initialize environment."""
37 | nS = 4
38 | nA = 2
39 |
40 | transition_matrix = np.zeros((nS, nA, nS))
41 | transition_matrix[0, 0, 1] = 1.0
42 | transition_matrix[0, 1, [2, 3]] = 0.5
43 |
44 | transition_matrix[1, 0, 2] = 1.0
45 | transition_matrix[1, 1, 1] = 1.0
46 |
47 | transition_matrix[[2, 3], :, [2, 3]] = 1.0
48 |
49 | reward_matrix = np.array([0.0, 0.0, 1.0, -100.0])
50 |
51 | super().__init__(
52 | transition_matrix=transition_matrix,
53 | reward_matrix=reward_matrix,
54 | )
55 |
--------------------------------------------------------------------------------
/src/seals/diagnostics/sort.py:
--------------------------------------------------------------------------------
1 | """Environment to sort a list using swap actions."""
2 |
3 | from gymnasium import spaces
4 | import numpy as np
5 |
6 | from seals import base_envs
7 |
8 |
9 | class SortEnv(base_envs.ResettableMDP):
10 | """Environment to sort a list using swap actions."""
11 |
12 | def __init__(self, length: int = 4):
13 | """Constructs environment.
14 |
15 | The initial state is a vector x sampled uniformly from
16 | [0,1]**L, with actions a = (i,j) swapping x_i and x_j. The
17 | reward is given according to the number of elements in the
18 | correct position. To perform well, the learned policy must
19 | compare elements, otherwise it will not generalize to all
20 | possible randomly selected initial states.
21 | """
22 | self._length = length
23 |
24 | super().__init__()
25 | self.state_space = spaces.Box(low=0, high=1.0, shape=(length,))
26 | self.action_space = spaces.MultiDiscrete([length, length])
27 |
28 | def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool:
29 | """Always returns False."""
30 | return False
31 |
32 | def initial_state(self):
33 | """Sample random vector uniformly in [0, 1]**L."""
34 | sample = self.np_random.random(size=self._length)
35 | return sample.astype(self.state_space.dtype)
36 |
37 | def reward(
38 | self,
39 | state: np.ndarray,
40 | action: np.ndarray,
41 | new_state: np.ndarray,
42 | ) -> float:
43 | """Rewards fully sorted lists, and new correct positions."""
44 | del action
45 |         # This is not potential-based shaping in the formal sense,
46 |         # as it changes the trajectory returns (since we do not transition
47 |         # to a fixed-potential state at termination).
48 | num_correct = self._num_correct_positions(state)
49 | new_num_correct = self._num_correct_positions(new_state)
50 | potential_diff = new_num_correct - num_correct
51 |
52 | return float(self._is_sorted(new_state)) + potential_diff
53 |
54 | def transition(self, state: np.ndarray, action: np.ndarray) -> np.ndarray:
55 | """Action a = (i, j) swaps elements in positions i and j."""
56 | new_state = state.copy()
57 | i, j = action
58 | new_state[[i, j]] = new_state[[j, i]]
59 | return new_state
60 |
61 | def _is_sorted(self, arr: np.ndarray) -> bool:
62 | return list(arr) == sorted(arr)
63 |
64 | def _num_correct_positions(self, arr: np.ndarray) -> int:
65 | return np.sum(arr == sorted(arr))
66 |
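A quick sketch of the swap-action interface described in the constructor docstring (illustration only; assumes direct instantiation and the gymnasium-style 5-tuple `step` return used throughout seals)::

    import numpy as np

    from seals.diagnostics.sort import SortEnv

    env = SortEnv(length=4)
    obs, info = env.reset(seed=0)  # uniform sample from [0, 1]**4
    action = np.array([0, 1])      # swap the elements at positions 0 and 1
    obs, rew, terminated, truncated, info = env.step(action)
    # rew adds +1 if the new list is fully sorted to the change in the
    # number of elements already in their correct (sorted) position.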
--------------------------------------------------------------------------------
/src/seals/mujoco.py:
--------------------------------------------------------------------------------
1 | """Adaptation of MuJoCo environments for specification learning algorithms."""
2 |
3 | import functools
4 |
5 | from gymnasium.envs.mujoco import (
6 | ant_v4,
7 | half_cheetah_v4,
8 | hopper_v4,
9 | humanoid_v4,
10 | swimmer_v4,
11 | walker2d_v4,
12 | )
13 |
14 |
15 | def _include_position_in_observation(cls):
16 | cls.__init__ = functools.partialmethod(
17 | cls.__init__,
18 | exclude_current_positions_from_observation=False,
19 | )
20 | return cls
21 |
22 |
23 | def _no_early_termination(cls):
24 | cls.__init__ = functools.partialmethod(cls.__init__, terminate_when_unhealthy=False)
25 | return cls
26 |
27 |
28 | @_include_position_in_observation
29 | @_no_early_termination
30 | class AntEnv(ant_v4.AntEnv):
31 | """Ant with position observation and no early termination."""
32 |
33 |
34 | @_include_position_in_observation
35 | class HalfCheetahEnv(half_cheetah_v4.HalfCheetahEnv):
36 | """HalfCheetah with position observation. Naturally does not terminate early."""
37 |
38 |
39 | @_include_position_in_observation
40 | @_no_early_termination
41 | class HopperEnv(hopper_v4.HopperEnv):
42 | """Hopper with position observation and no early termination."""
43 |
44 |
45 | @_include_position_in_observation
46 | @_no_early_termination
47 | class HumanoidEnv(humanoid_v4.HumanoidEnv):
48 | """Humanoid with position observation and no early termination."""
49 |
50 |
51 | @_include_position_in_observation
52 | class SwimmerEnv(swimmer_v4.SwimmerEnv):
53 | """Swimmer with position observation. Naturally does not terminate early."""
54 |
55 |
56 | @_include_position_in_observation
57 | @_no_early_termination
58 | class Walker2dEnv(walker2d_v4.Walker2dEnv):
59 | """Walker2d with position observation and no early termination."""
60 |
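These adapted classes are normally reached through the Gym registry rather than imported directly. A sketch, assuming MuJoCo and its Python bindings are installed (the `seals/<Name>-v1` IDs match those used in `tests/test_mujoco_rl.py`)::

    import gymnasium as gym

    import seals  # noqa: F401  registers the seals/ environments

    env = gym.make("seals/HalfCheetah-v1")
    obs, info = env.reset(seed=0)
    # Unlike the upstream default, the observation includes the agent's position.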
--------------------------------------------------------------------------------
/src/seals/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HumanCompatibleAI/seals/90f68cfbb694d687f8f2cb05bdf3aa85714bbf6c/src/seals/py.typed
--------------------------------------------------------------------------------
/src/seals/testing/__init__.py:
--------------------------------------------------------------------------------
1 | """Helper methods for unit tests. May also be useful for users of seals."""
2 |
--------------------------------------------------------------------------------
/src/seals/testing/envs.py:
--------------------------------------------------------------------------------
1 | """Helper methods for tests of custom Gym environments.
2 |
3 | This is used in our test suite in `tests/test_envs.py`. It is also used in sister
4 | projects such as `imitation`, and may be useful in other codebases.
5 | """
6 |
7 | import re
8 | from typing import (
9 | Any,
10 | Callable,
11 | Iterable,
12 | Iterator,
13 | List,
14 | Mapping,
15 | Sequence,
16 | SupportsFloat,
17 | Tuple,
18 | )
19 |
20 | import gymnasium as gym
21 | import numpy as np
22 |
23 | Step = Tuple[Any, SupportsFloat, bool, bool, Mapping[str, Any]]
24 | Rollout = Sequence[Step]
25 | """A sequence of 5-tuples (obs, rew, terminated, truncated, info) as returned by
26 | `get_rollout`."""
27 |
28 |
29 | def make_env_fixture(
30 | skip_fn: Callable[[str], None],
31 | ) -> Callable[[str], Iterator[gym.Env]]:
32 | """Creates a fixture function, calling `skip_fn` when dependencies are missing.
33 |
34 | For example, in `pytest`, one would use::
35 |
36 | env = pytest.fixture(make_env_fixture(skip_fn=pytest.skip))
37 |
38 | Then any method with an `env` parameter will receive the created environment, with
39 | the `env_name` parameter automatically passed to the fixture.
40 |
41 | In `unittest`, one would use::
42 |
43 | def skip_fn(msg):
44 | raise unittest.SkipTest(msg)
45 |
46 | make_env = contextlib.contextmanager(make_env_fixture(skip_fn=skip_fn))
47 |
48 | And then call `with make_env(env_name) as env:` to create environments.
49 |
50 | Args:
51 | skip_fn: the function called when a dependency is missing to skip the test.
52 |
53 | Returns:
54 | A method to create Gym environments given their name.
55 | """
56 |
57 | def f(env_name: str) -> Iterator[gym.Env]:
58 | """Create environment `env_name`.
59 |
60 | Args:
61 | env_name: The name of the environment in the Gym registry.
62 |
63 | Yields:
64 | The created environment.
65 |
66 | Raises:
67 | gym.error.DependencyNotInstalled: if a dependency is missing
68 | other than MuJoCo (for MuJoCo, the test is instead skipped).
69 | """
70 | env = None
71 | try:
72 | env = gym.make(env_name)
73 | yield env
74 | except gym.error.DependencyNotInstalled as e: # pragma: no cover
75 | if e.args[0].find("mujoco_py") != -1:
76 | skip_fn("Requires `mujoco_py`, which isn't installed.")
77 | else:
78 | raise
79 | finally:
80 | if env is not None:
81 | env.close()
82 |
83 | return f
84 |
85 |
86 | def matches_list(env_name: str, patterns: Iterable[str]) -> bool:
87 | """Returns True if `env_name` matches any of the patterns in `patterns`."""
88 | return any(re.match(env_pattern, env_name) for env_pattern in patterns)
89 |
90 |
91 | def get_rollout(env: gym.Env, actions: Iterable[Any]) -> Rollout:
92 | """Performs a sequence of actions `actions` in `env`.
93 |
94 | Args:
95 | env: the environment to roll out in.
96 | actions: the actions to perform.
97 |
98 | Returns:
99 | A sequence of 5-tuples (obs, rew, terminated, truncated, info).
100 | """
101 | obs, info = env.reset()
102 | ret: List[Step] = [(obs, 0, False, False, info)]
103 | for act in actions:
104 | ret.append(env.step(act))
105 | return ret
106 |
107 |
108 | def assert_equal_rollout(rollout_a: Rollout, rollout_b: Rollout) -> None:
109 | """Checks rollouts for equality.
110 |
111 | Raises:
112 | AssertionError if they are not equal.
113 | """
114 | for step_a, step_b in zip(rollout_a, rollout_b):
115 | ob_a, rew_a, terminated_a, truncated_a, info_a = step_a
116 | ob_b, rew_b, terminated_b, truncated_b, info_b = step_b
117 | np.testing.assert_equal(ob_a, ob_b)
118 | assert rew_a == rew_b
119 | assert terminated_a == terminated_b
120 | assert truncated_a == truncated_b
121 | np.testing.assert_equal(info_a, info_b)
122 |
123 |
124 | def has_same_observations(rollout_a: Rollout, rollout_b: Rollout) -> bool:
125 | """True if `rollout_a` and `rollout_b` have the same observations."""
126 | obs_list_a = [step[0] for step in rollout_a]
127 | obs_list_b = [step[0] for step in rollout_b]
128 | if len(obs_list_a) != len(obs_list_b): # pragma: no cover
129 | return False
130 | for obs_a, obs_b in zip(obs_list_a, obs_list_b):
131 | if isinstance(obs_a, Mapping): # pragma: no cover
132 | if obs_a.keys() != obs_b.keys():
133 | return False
134 | obs_a = list(obs_a.values())
135 | obs_b = list(obs_b.values())
136 | else:
137 | obs_a, obs_b = [obs_a], [obs_b]
138 | if any([np.any(x != y) for x, y in zip(obs_a, obs_b)]):
139 | return False
140 | return True
141 |
142 |
143 | def test_seed(
144 | env: gym.Env,
145 | env_name: str,
146 | deterministic_envs: Iterable[str],
147 | rollout_len: int = 10,
148 | num_seeds: int = 100,
149 | ) -> None:
150 | """Tests environment seeding.
151 |
152 | If non-deterministic, different seeds should produce different transitions.
153 | If deterministic, should be invariant to seed.
154 |
155 | Raises:
156 | AssertionError if test fails.
157 | """
158 | env.action_space.seed(0)
159 | actions = [env.action_space.sample() for _ in range(rollout_len)]
160 | # With the same seed, should always get the same result
161 | env.reset(seed=42)
162 | rollout_a = get_rollout(env, actions)
163 |
164 | env.reset(seed=42)
165 | rollout_b = get_rollout(env, actions)
166 |
167 | assert_equal_rollout(rollout_a, rollout_b)
168 |
169 | # For most non-deterministic environments, if we try enough seeds we should
170 | # eventually get a different result. For deterministic environments, all
171 | # seeds should produce the same starting state.
172 | def different_seeds_same_rollout(seed1, seed2):
173 | new_actions = [env.action_space.sample() for _ in range(rollout_len)]
174 | env.reset(seed=seed1)
175 | new_rollout_1 = get_rollout(env, new_actions)
176 | env.reset(seed=seed2)
177 | new_rollout_2 = get_rollout(env, new_actions)
178 | return has_same_observations(new_rollout_1, new_rollout_2)
179 |
180 | is_deterministic = matches_list(env_name, deterministic_envs)
181 | same_obs = all(
182 | different_seeds_same_rollout(seed, seed + 1) for seed in range(num_seeds)
183 | )
184 | assert same_obs == is_deterministic
185 |
186 |
187 | def _check_obs(obs: np.ndarray, obs_space: gym.Space) -> None:
188 | """Check obs is consistent with obs_space."""
189 | if obs_space.shape:
190 | assert obs.shape == obs_space.shape
191 | assert obs.dtype == obs_space.dtype
192 | assert obs in obs_space
193 |
194 |
195 | def _sample_and_check(env: gym.Env, obs_space: gym.Space) -> bool:
196 | """Sample from env and check return value is of valid type."""
197 | act = env.action_space.sample()
198 | obs, rew, terminated, truncated, info = env.step(act)
199 | _check_obs(obs, obs_space)
200 | assert isinstance(rew, float)
201 | assert isinstance(terminated, bool)
202 | assert isinstance(truncated, bool)
203 | assert isinstance(info, dict)
204 | return terminated or truncated
205 |
206 |
207 | def is_mujoco_env(env: gym.Env) -> bool:
208 | """True if `env` is a MuJoCo environment."""
209 | return hasattr(env, "sim") and hasattr(env, "model")
210 |
211 |
212 | def test_rollout_schema(
213 | env: gym.Env,
214 | steps_after_terminated: int = 10,
215 | max_steps: int = 10_000,
216 | check_episode_ends: bool = True,
217 | ) -> None:
218 | """Check custom environments have correct types on `step` and `reset`.
219 |
220 | Args:
221 | env: The environment to test.
222 | steps_after_terminated: The number of steps to take after `terminated` is True,
223 | the nominal episode termination. This is an abuse of the Gym API,
224 | but we would like the environments to handle this case gracefully.
225 |         max_steps: Test fails if `terminated` is not returned within this many timesteps.
226 |         check_episode_ends: Check that the episode ends within `max_steps` steps,
227 |             and that steps taken after `terminated` is true are of the correct type.
228 |
229 | Raises:
230 | AssertionError if test fails.
231 | """
232 | obs_space = env.observation_space
233 | obs, _ = env.reset(seed=0)
234 | _check_obs(obs, obs_space)
235 |
236 | done = False
237 | for _ in range(max_steps):
238 | done = _sample_and_check(env, obs_space)
239 | if done:
240 | break
241 |
242 | if check_episode_ends:
243 | assert done, "did not get to end of episode"
244 |
245 | for _ in range(steps_after_terminated):
246 | _sample_and_check(env, obs_space)
247 |
248 |
249 | def test_premature_step(env: gym.Env, skip_fn, raises_fn) -> None:
250 | """Test that you must call reset() before calling step().
251 |
252 | Example usage in pytest:
253 | test_premature_step(env, skip_fn=pytest.skip, raises_fn=pytest.raises)
254 |
255 | Args:
256 | env: The environment to test.
257 | skip_fn: called when the environment is incompatible with the test.
258 | raises_fn: Context manager to check exception is thrown.
259 |
260 | Raises:
261 | AssertionError if test fails.
262 | """
263 | if is_mujoco_env(env): # pragma: no cover
264 | # We can't use isinstance since importing mujoco_py will fail on
265 | # machines without MuJoCo installed
266 | skip_fn("MuJoCo environments cannot perform this check.")
267 |
268 | act = env.action_space.sample()
269 | with raises_fn(Exception): # need to call env.reset() first
270 | env.step(act)
271 |
272 |
273 | class CountingEnv(gym.Env):
274 | """At timestep `t` of each episode, has `t == obs == reward / 10`.
275 |
276 | Episodes finish after `episode_length` calls to `step()`, or equivalently
277 | `episode_length` actions. For example, if we have `episode_length=5`,
278 | then an episode has the following observations and rewards:
279 |
280 | ```
281 | obs = [0, 1, 2, 3, 4, 5]
282 | rews = [10, 20, 30, 40, 50]
283 | ```
284 | """
285 |
286 | def __init__(self, episode_length: int = 5):
287 | """Initialize a CountingEnv.
288 |
289 | Params:
290 | episode_length: The number of actions before each episode ends.
291 | """
292 | assert episode_length >= 1
293 | self.observation_space = gym.spaces.Box(low=0, high=np.inf, shape=())
294 | self.action_space = gym.spaces.Box(low=0, high=np.inf, shape=())
295 | self.episode_length = episode_length
296 | self.timestep = None
297 |
298 |     def reset(self, seed=None, options=None):
299 | """Reset method for CountingEnv."""
300 | t, self.timestep = 0, 1
301 | return np.array(t, dtype=self.observation_space.dtype), {}
302 |
303 | def step(self, action):
304 | """Step method for CountingEnv."""
305 | if self.timestep is None: # pragma: no cover
306 | raise RuntimeError("Need to reset before first step().")
307 | if np.array(action) not in self.action_space: # pragma: no cover
308 | raise ValueError(f"Invalid action {action}")
309 | if self.timestep > self.episode_length: # pragma: no cover
310 | raise ValueError("Should reset env. Episode is over.")
311 |
312 | t, self.timestep = self.timestep, self.timestep + 1
313 | obs = np.array(t, dtype=self.observation_space.dtype)
314 | rew = t * 10.0
315 | terminated = t == self.episode_length
316 | return obs, rew, terminated, False, {}
317 |
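A short sketch of how these helpers compose, using the `CountingEnv` defined above (illustration only; it mirrors the patterns in `tests/test_envs.py` and `tests/test_wrappers.py`)::

    from seals.testing import envs

    env = envs.CountingEnv(episode_length=5)
    env.action_space.seed(0)
    actions = [env.action_space.sample() for _ in range(3)]

    # get_rollout resets the environment itself before applying the actions.
    rollout_a = envs.get_rollout(env, actions)
    rollout_b = envs.get_rollout(env, actions)

    # CountingEnv is deterministic, so the two rollouts should match exactly.
    envs.assert_equal_rollout(rollout_a, rollout_b)
    assert envs.has_same_observations(rollout_a, rollout_b)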
--------------------------------------------------------------------------------
/src/seals/util.py:
--------------------------------------------------------------------------------
1 | """Miscellaneous utilities."""
2 |
3 | from dataclasses import dataclass
4 | from typing import (
5 | Any,
6 | Dict,
7 | Generic,
8 | List,
9 | Optional,
10 | Sequence,
11 | SupportsFloat,
12 | Tuple,
13 | TypeVar,
14 | Union,
15 | )
16 |
17 | import gymnasium as gym
18 | import numpy as np
19 |
20 | # Note: we redefine the type vars from gymnasium.core here, because pytype does not
21 | # recognize them as valid type vars if we import them from gymnasium.core.
22 | WrapperObsType = TypeVar("WrapperObsType")
23 | WrapperActType = TypeVar("WrapperActType")
24 | ObsType = TypeVar("ObsType")
25 | ActType = TypeVar("ActType")
26 |
27 |
28 | class AutoResetWrapper(
29 | gym.Wrapper,
30 | Generic[WrapperObsType, WrapperActType, ObsType, ActType],
31 | ):
32 |     """Hides terminated/truncated signals and auto-resets at the end of each episode.
33 | 
34 |     Depending on the flag 'discard_terminal_observation', either discards the terminal
35 |     observation or pads with an additional 'reset transition'. The former is the default
36 |     behavior.
37 |     In the latter case, the action taken during the 'reset transition' has no effect,
38 |     the reward is constant (set by the wrapper argument `reset_reward`, which defaults
39 |     to 0.0), and `info` holds the dict returned by `reset()` under `reset_info_dict`.
40 | """
41 |
42 | def __init__(self, env, discard_terminal_observation=True, reset_reward=0.0):
43 | """Builds the wrapper.
44 |
45 | Args:
46 | env: The environment to wrap.
47 | discard_terminal_observation: Defaults to True. If True, the terminal
48 | observation is discarded and the environment is reset immediately. The
49 | returned observation will then be the start of the next episode. The
50 | overridden observation is stored in `info["terminal_observation"]`.
51 | If False, the terminal observation is returned and the environment is
52 | reset in the next step.
53 | reset_reward: The reward to return for the reset transition. Defaults to
54 | 0.0.
55 | """
56 | super().__init__(env)
57 | self.discard_terminal_observation = discard_terminal_observation
58 | self.reset_reward = reset_reward
59 | self.previous_done = False # Whether the previous step returned done=True.
60 |
61 | def step(
62 | self,
63 | action: WrapperActType,
64 | ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
65 | """When terminated or truncated, resets the environment.
66 |
67 | Always returns False for terminated and truncated.
68 |
69 | Depending on whether we are discarding the terminal observation,
70 | either resets the environment and discards,
71 | or returns the terminal observation, and then uses the next step to reset the
72 | environment, after which steps will be performed as normal.
73 | """
74 | if self.discard_terminal_observation:
75 | return self._step_discard(action)
76 | else:
77 | return self._step_pad(action)
78 |
79 | def _step_pad(
80 | self,
81 | action: WrapperActType,
82 | ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
83 | """When terminated or truncated, resets the environment.
84 |
85 | Always returns False for terminated and truncated.
86 |
87 | The agent will then usually be asked to perform an action based on
88 | the terminal observation. In the next step, this final action will be ignored
89 | to instead reset the environment and return the initial observation of the new
90 | episode.
91 |
92 | Some potential caveats:
93 | - The underlying environment will perform fewer steps than the wrapped
94 | environment.
95 | - The number of steps the agent performs and the number of steps recorded in the
96 | underlying environment will not match, which could cause issues if these are
97 | assumed to be the same.
98 | """
99 | if self.previous_done:
100 | self.previous_done = False
101 | reset_obs, reset_info_dict = self.env.reset()
102 | info = {"reset_info_dict": reset_info_dict}
103 | # This transition will only reset the environment, the action is ignored.
104 | return reset_obs, self.reset_reward, False, False, info
105 |
106 | obs, rew, terminated, truncated, info = self.env.step(action)
107 | self.previous_done = terminated or truncated
108 | return obs, rew, False, False, info
109 |
110 | def _step_discard(
111 | self,
112 | action: WrapperActType,
113 | ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
114 | """When terminated or truncated, return False for both and automatically reset.
115 |
116 | When an automatic reset happens, the observation from reset is returned,
117 | and the overridden observation is stored in
118 | `info["terminal_observation"]`.
119 | """
120 | obs, rew, terminated, truncated, info = self.env.step(action)
121 | if terminated or truncated:
122 | info["terminal_observation"] = obs
123 | obs, reset_info_dict = self.env.reset()
124 | info["reset_info_dict"] = reset_info_dict
125 | return obs, rew, False, False, info
126 |
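A minimal sketch of the 'padding' mode of the wrapper above (illustration only; it uses `CountingEnv` from `seals.testing.envs`, mirroring `tests/test_wrappers.py`)::

    from seals import util
    from seals.testing import envs

    env = util.AutoResetWrapper(
        envs.CountingEnv(episode_length=3),
        discard_terminal_observation=False,
    )
    obs, info = env.reset()
    for _ in range(5):
        obs, rew, terminated, truncated, info = env.step(env.action_space.sample())
        # Episode boundaries are hidden: terminated and truncated are always False.
        assert not terminated and not truncated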
127 |
128 | @dataclass
129 | class BoxRegion:
130 | """A rectangular region dataclass used by MaskScoreWrapper."""
131 |
132 | x: Tuple
133 | y: Tuple
134 |
135 |
136 | MaskedRegionSpecifier = List[BoxRegion]
137 |
138 |
139 | class MaskScoreWrapper(gym.ObservationWrapper):
140 | """Mask a list of box-shaped regions in the observation to hide reward info.
141 |
142 | Intended for environments whose observations are raw pixels (like Atari
143 | environments). Used to mask regions of the observation that include information
144 | that could be used to infer the reward, like the game score or enemy ship count.
145 | """
146 |
147 | def __init__(
148 | self,
149 | env: gym.Env,
150 | score_regions: MaskedRegionSpecifier,
151 | fill_value: Union[float, Sequence[float]] = 0,
152 | ):
153 | """Builds MaskScoreWrapper.
154 |
155 | Args:
156 | env: The environment to wrap.
157 |             score_regions: A list of box-shaped regions to mask, each given as a
158 |                 `BoxRegion(x=(x0, x1), y=(y0, y1))`, where `x0 < x1`
159 |                 and `y0 < y1`.
160 |             fill_value: The fill value for the masked regions. Defaults to black (0).
161 | Can support RGB colors by being a sequence of values [r, g, b].
162 |
163 | Raises:
164 | ValueError: If a score region does not conform to the spec.
165 | """
166 | super().__init__(env)
167 | self.fill_value = np.array(fill_value, env.observation_space.dtype)
168 |
169 | if env.observation_space.shape is None:
170 | raise ValueError("Observation space must have a shape.") # pragma: no cover
171 | self.mask = np.ones(env.observation_space.shape, dtype=bool)
172 | for r in score_regions:
173 | if r.x[0] >= r.x[1] or r.y[0] >= r.y[1]:
174 | raise ValueError('Invalid region: "x" and "y" must be increasing.')
175 | self.mask[r.x[0] : r.x[1], r.y[0] : r.y[1]] = 0
176 |
177 | def observation(self, obs):
178 | """Returns observation with masked regions filled with `fill_value`."""
179 | return np.where(self.mask, obs, self.fill_value)
180 |
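A sketch of masking a score region (illustration only; the environment ID below requires the Atari ROMs to be installed, and the `(x, y)` ranges are placeholder values, the real per-game regions live in `seals.atari.SCORE_REGIONS`)::

    import gymnasium as gym

    from seals import util

    env = gym.make("ALE/SpaceInvaders-v5")
    masked = util.MaskScoreWrapper(
        env,
        score_regions=[util.BoxRegion(x=(0, 10), y=(0, 50))],  # placeholder region
        fill_value=0,  # black out the masked pixels
    )
    obs, info = masked.reset(seed=0)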
181 |
182 | class ObsCastWrapper(gym.ObservationWrapper):
183 | """Cast observations to specified dtype.
184 |
185 | Some external environments return observations of a different type than the
186 | declared observation space. Where possible, this should be fixed upstream,
187 | but casting can be a viable workaround -- especially when the returned
188 | observations are higher resolution than the observation space.
189 | """
190 |
191 | def __init__(self, env: gym.Env, dtype: np.dtype):
192 | """Builds ObsCastWrapper.
193 |
194 | Args:
195 | env: the environment to wrap.
196 | dtype: the dtype to cast observations to.
197 | """
198 | super().__init__(env)
199 | self.dtype = dtype
200 |
201 | def observation(self, obs):
202 | """Returns observation cast to self.dtype."""
203 | return obs.astype(self.dtype)
204 |
205 |
206 | class AbsorbAfterDoneWrapper(gym.Wrapper):
207 | """Transition into absorbing state instead of episode termination.
208 |
209 | When the environment being wrapped returns `terminated=True` or `truncated=True`,
210 | we return an absorbing observation.
211 | This wrapper always returns `terminated=False` and `truncated=False`.
212 |
213 | A convenient way to add absorbing states to environments like MountainCar.
214 | """
215 |
216 | def __init__(
217 | self,
218 | env: gym.Env,
219 | absorb_reward: float = 0.0,
220 | absorb_obs: Optional[np.ndarray] = None,
221 | ):
222 | """Initialize AbsorbAfterDoneWrapper.
223 |
224 | Args:
225 | env: The wrapped Env.
226 | absorb_reward: The reward returned at the absorb state.
227 | absorb_obs: The observation returned at the absorb state. If None, then
228 | repeat the final observation before absorb.
229 | """
230 | super().__init__(env)
231 | self.absorb_reward = absorb_reward
232 | self.absorb_obs_default = absorb_obs
233 | self.absorb_obs_this_episode = None
234 | self.at_absorb_state = None
235 |
236 | def reset(self, *args, **kwargs):
237 | """Reset the environment."""
238 | self.at_absorb_state = False
239 | self.absorb_obs_this_episode = None
240 | return self.env.reset(*args, **kwargs)
241 |
242 | def step(self, action):
243 | """Advance the environment by one step.
244 |
245 | This wrapped `step()` always returns terminated=False and truncated=False.
246 |
247 | After the first time either terminated or truncated is returned by the
248 | underlying Env, we enter an artificial absorb state.
249 |
250 | In this artificial absorb state, we stop calling
251 | `self.env.step(action)` (i.e. the `action` argument is entirely ignored) and
252 | we return fixed values for obs, rew, terminated, truncated, and info.
253 | The values of `obs` and `rew` depend on initialization arguments.
254 | `info` is always an empty dictionary.
255 | """
256 | if not self.at_absorb_state:
257 | obs, rew, terminated, truncated, info = self.env.step(action)
258 | if terminated or truncated:
259 | # Initialize the artificial absorb state, which we will repeatedly use
260 | # starting on the next call to `step()`.
261 | self.at_absorb_state = True
262 |
263 | if self.absorb_obs_default is None:
264 | self.absorb_obs_this_episode = obs
265 | else:
266 | self.absorb_obs_this_episode = self.absorb_obs_default
267 | else:
268 | assert self.absorb_obs_this_episode is not None
269 | assert self.absorb_reward is not None
270 | obs = self.absorb_obs_this_episode
271 | rew = self.absorb_reward
272 | info = {}
273 |
274 | return obs, rew, False, False, info
275 |
276 |
277 | def get_gym_max_episode_steps(env_name: str) -> Optional[int]:
278 | """Get the `max_episode_steps` attribute associated with a gym Spec."""
279 | return gym.spec(env_name).max_episode_steps
280 |
281 |
282 | def sample_distribution(
283 | p: np.ndarray,
284 | random: np.random.Generator,
285 | ) -> int:
286 | """Samples an integer with probabilities given by p."""
287 | return random.choice(np.arange(len(p)), p=p)
288 |
289 |
290 | def one_hot_encoding(pos: int, size: int) -> np.ndarray:
291 |     """Returns a one-hot encoding (1-D array) of the given position and size."""
292 | return np.eye(size)[pos]
293 |
294 |
295 | def grid_transition_fn(
296 | state: np.ndarray,
297 | action: int,
298 | x_bounds: Tuple[float, float] = (-np.inf, np.inf),
299 | y_bounds: Tuple[float, float] = (-np.inf, np.inf),
300 | ):
301 | """Returns transition of a deterministic gridworld.
302 |
303 |     The agent is confined to the region given by x_bounds and y_bounds,
304 |     both ends inclusive.
305 |
306 | (0, 0) is interpreted to be top-left corner.
307 |
308 | Actions:
309 | 0: Right
310 | 1: Down
311 | 2: Left
312 | 3: Up
313 | 4: Stay put
314 | """
315 | dirs = [
316 | (1, 0),
317 | (0, 1),
318 | (-1, 0),
319 | (0, -1),
320 | (0, 0),
321 | ]
322 |
323 | x, y = state
324 | dx, dy = dirs[action]
325 |
326 | next_x = np.clip(x + dx, *x_bounds)
327 | next_y = np.clip(y + dy, *y_bounds)
328 | next_state = np.array([next_x, next_y], dtype=state.dtype)
329 |
330 | return next_state
331 |
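A worked example of the gridworld transition above (illustration only)::

    import numpy as np

    from seals import util

    state = np.array([0, 0])  # top-left corner
    right = util.grid_transition_fn(state, action=0, x_bounds=(0, 4), y_bounds=(0, 4))
    up = util.grid_transition_fn(state, action=3, x_bounds=(0, 4), y_bounds=(0, 4))
    print(right)  # [1 0] -- moved one cell right
    print(up)     # [0 0] -- clipped at the top boundary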
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | """Configuration for pytest."""
2 |
3 | import pytest
4 |
5 | pytest.register_assert_rewrite("seals.testing")
6 |
7 |
8 | def pytest_addoption(parser):
9 | """Add --expensive option."""
10 | parser.addoption(
11 | "--expensive",
12 | action="store_true",
13 | dest="expensive",
14 | default=False,
15 | help="enable expensive tests",
16 | )
17 |
18 |
19 | def pytest_collection_modifyitems(config, items):
20 | """Make expensive tests be skipped without an --expensive flag."""
21 | if config.getoption("--expensive"): # pragma: no cover
22 | return
23 | skip_expensive = pytest.mark.skip(reason="needs --expensive option to run")
24 | for item in items:
25 | if "expensive" in item.keywords:
26 | item.add_marker(skip_expensive)
27 |
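For reference, a test opts in to this behaviour via the `expensive` marker, as `tests/test_mujoco_rl.py` does. A hypothetical placeholder test::

    import pytest

    @pytest.mark.expensive
    def test_something_slow():  # hypothetical example test
        ...  # only runs when pytest is invoked with --expensive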
--------------------------------------------------------------------------------
/tests/test_base_env.py:
--------------------------------------------------------------------------------
1 | """Test the base_envs module.
2 |
3 | Note base_envs is also tested indirectly via smoke tests in `test_envs`,
4 | so the tests in this file focus on features unique to classes in `base_envs`.
5 | """
6 |
7 | import gymnasium as gym
8 | import numpy as np
9 | import pytest
10 |
11 | from seals import base_envs
12 | from seals.testing import envs
13 |
14 |
15 | class NewEnv(base_envs.TabularModelMDP):
16 | """Test the TabularModelMDP class."""
17 |
18 | def __init__(self):
19 | """Build environment."""
20 | np.random.seed(0)
21 | nS = 3
22 | nA = 2
23 | transition_matrix = np.random.random((nS, nA, nS))
24 | transition_matrix /= transition_matrix.sum(axis=2)[:, :, None]
25 | reward_matrix = np.random.random((nS,))
26 | super().__init__(
27 | transition_matrix=transition_matrix,
28 | reward_matrix=reward_matrix,
29 | )
30 |
31 |
32 | def test_base_envs():
33 | """Test parts of base_envs not covered elsewhere."""
34 | env = NewEnv()
35 |
36 | assert np.all(np.eye(3) == env.feature_matrix)
37 |
38 | envs.test_premature_step(env, skip_fn=pytest.skip, raises_fn=pytest.raises)
39 |
40 | env.reset(seed=0)
41 | assert env.n_actions_taken == 0
42 | env.step(env.action_space.sample())
43 | assert env.n_actions_taken == 1
44 | env.step(env.action_space.sample())
45 | assert env.n_actions_taken == 2
46 |
47 | new_state = env.state_space.sample()
48 | env.state = new_state
49 | assert env.state == new_state
50 |
51 | bad_state = "not a state"
52 | with pytest.raises(ValueError, match=r".*not in.*"):
53 | env.state = bad_state # type: ignore
54 |
55 | with pytest.raises(NotImplementedError, match=r"Options not supported.*"):
56 | env.reset(options={"option": "value"})
57 |
58 |
59 | def test_tabular_env_validation():
60 | """Test input validation for base_envs.TabularModelEnv."""
61 | with pytest.raises(ValueError, match=r"Malformed transition_matrix.*"):
62 | base_envs.TabularModelMDP(
63 | transition_matrix=np.zeros((3, 1, 4)),
64 | reward_matrix=np.zeros((3,)),
65 | )
66 | with pytest.raises(ValueError, match=r"initial_state_dist has multiple.*"):
67 | base_envs.TabularModelMDP(
68 | transition_matrix=np.zeros((3, 1, 3)),
69 | reward_matrix=np.zeros((3,)),
70 | initial_state_dist=np.zeros((3, 4)),
71 | )
72 | with pytest.raises(ValueError, match=r"transition_matrix and initial_state_dist.*"):
73 | base_envs.TabularModelMDP(
74 | transition_matrix=np.zeros((3, 1, 3)),
75 | reward_matrix=np.zeros((3,)),
76 | initial_state_dist=np.zeros((2)),
77 | )
78 | with pytest.raises(ValueError, match=r"transition_matrix and reward_matrix.*"):
79 | base_envs.TabularModelMDP(
80 | transition_matrix=np.zeros((4, 1, 4)),
81 | reward_matrix=np.zeros((3,)),
82 | )
83 | with pytest.raises(ValueError, match=r"transition_matrix and observation_matrix.*"):
84 | base_envs.TabularModelPOMDP(
85 | transition_matrix=np.zeros((3, 1, 3)),
86 | reward_matrix=np.zeros((3,)),
87 | observation_matrix=np.zeros((4, 3)),
88 | )
89 |
90 | env = base_envs.TabularModelMDP(
91 | transition_matrix=np.zeros((3, 1, 3)),
92 | reward_matrix=np.zeros((3,)),
93 | )
94 | env.reset(seed=0)
95 | with pytest.raises(ValueError, match=r".*not in.*"):
96 | env.step(4)
97 |
98 |
99 | def test_expose_pomdp_state_wrapper():
100 | """Test the ExposePOMDPStateWrapper class."""
101 | env = NewEnv()
102 | wrapped_env = base_envs.ExposePOMDPStateWrapper(env)
103 |
104 | assert wrapped_env.observation_space == env.state_space
105 | state, _ = wrapped_env.reset(seed=0)
106 | assert state == env.state
107 | assert state in env.state_space
108 |
109 | action = env.action_space.sample()
110 | next_state, reward, terminated, truncated, info = wrapped_env.step(action)
111 | assert next_state == env.state
112 | assert next_state in env.state_space
113 |
114 |
115 | def test_tabular_pomdp_obs_space_int():
116 | """Test the TabularModelPOMDP class with an integer observation space."""
117 | env = base_envs.TabularModelPOMDP(
118 | transition_matrix=np.zeros(
119 | (3, 1, 3),
120 | ),
121 | reward_matrix=np.zeros((3,)),
122 | observation_matrix=np.zeros((3, 3), dtype=np.int64),
123 | )
124 | assert isinstance(env.observation_space, gym.spaces.Box)
125 | assert env.observation_space.dtype == np.int64
126 |
--------------------------------------------------------------------------------
/tests/test_diagnostics.py:
--------------------------------------------------------------------------------
1 | """Test the `diagnostics.*` environments."""
2 |
3 | import numpy as np
4 | import pytest
5 |
6 | from seals.diagnostics import cliff_world, init_shift, random_trans
7 |
8 |
9 | def test_init_shift_validation():
10 | """Test input validation for init_shift.InitShiftEnv."""
11 | for invalid_state in [-1, 7, 8, 100]:
12 | with pytest.raises(ValueError, match=r"Initial state.*"):
13 | init_shift.InitShiftEnv(initial_state=invalid_state)
14 |
15 |
16 | def test_cliff_world_draw_value_vec():
17 | """Smoke test for cliff_world.CliffWorldEnv.draw_value_vec()."""
18 | env = cliff_world.CliffWorldEnv(
19 | width=7,
20 | height=4,
21 | horizon=9,
22 | use_xy_obs=False,
23 | )
24 | D = np.zeros(env.state_dim)
25 | env.draw_value_vec(D)
26 |
27 |
28 | def test_random_transition_env_init():
29 | """Test that RandomTransitionEnv initializes correctly."""
30 | random_trans.RandomTransitionEnv(
31 | n_states=3,
32 | n_actions=2,
33 | branch_factor=3,
34 | horizon=10,
35 | random_obs=False,
36 | )
37 | random_trans.RandomTransitionEnv(
38 | n_states=3,
39 | n_actions=2,
40 | branch_factor=3,
41 | horizon=10,
42 | random_obs=True,
43 | )
44 | random_trans.RandomTransitionEnv(
45 | n_states=3,
46 | n_actions=2,
47 | branch_factor=3,
48 | horizon=10,
49 | random_obs=True,
50 | obs_dim=10,
51 | )
52 | with pytest.raises(ValueError, match="obs_dim must be None if random_obs is False"):
53 | random_trans.RandomTransitionEnv(
54 | n_states=3,
55 | n_actions=2,
56 | branch_factor=3,
57 | horizon=10,
58 | random_obs=False,
59 | obs_dim=3,
60 | )
61 |
62 |
63 | def test_make_random_matrices_no_explicit_rng():
64 | """Test that random matrix maker static methods work without an explicit RNG."""
65 | random_trans.RandomTransitionEnv.make_random_trans_mat(3, 2, 3)
66 | random_trans.RandomTransitionEnv.make_random_state_dist(3, 3)
67 | random_trans.RandomTransitionEnv.make_obs_mat(3, True, 3)
68 | with pytest.raises(ValueError, match="obs_dim must be set if random_obs is True"):
69 | random_trans.RandomTransitionEnv.make_obs_mat(3, True)
70 | random_trans.RandomTransitionEnv.make_obs_mat(3, False)
71 | with pytest.raises(ValueError, match="obs_dim must be None if random_obs is False"):
72 | random_trans.RandomTransitionEnv.make_obs_mat(3, False, 3)
73 |
--------------------------------------------------------------------------------
/tests/test_envs.py:
--------------------------------------------------------------------------------
1 | """Smoke tests for all environments."""
2 |
3 | from typing import List, Union
4 |
5 | import gymnasium as gym
6 | from gymnasium.envs import registration
7 | import numpy as np
8 | import pytest
9 |
10 | import seals # noqa: F401 required for env registration
11 | from seals.atari import SCORE_REGIONS, _get_score_region, _seals_name, make_atari_env
12 | from seals.testing import envs
13 | from seals.testing.envs import is_mujoco_env
14 |
15 | ENV_NAMES: List[str] = [
16 | env_id for env_id in registration.registry.keys() if env_id.startswith("seals/")
17 | ]
18 |
19 |
20 | DETERMINISTIC_ENVS: List[str] = [
21 | "seals/EarlyTermPos-v0",
22 | "seals/EarlyTermNeg-v0",
23 | "seals/Branching-v0",
24 | "seals/InitShiftTrain-v0",
25 | "seals/InitShiftTest-v0",
26 | ]
27 |
28 | UNMASKED_ATARI_ENVS: List[str] = [
29 | _seals_name(gym_spec, masked=False) for gym_spec in seals.GYM_ATARI_ENV_SPECS
30 | ]
31 | MASKED_ATARI_ENVS: List[str] = [
32 | _seals_name(gym_spec, masked=True)
33 | for gym_spec in seals.GYM_ATARI_ENV_SPECS
34 | if _get_score_region(gym_spec.id) is not None
35 | ]
36 | ATARI_ENVS = UNMASKED_ATARI_ENVS + MASKED_ATARI_ENVS
37 |
38 | ATARI_V5_ENVS: List[str] = list(filter(lambda name: name.endswith("-v5"), ATARI_ENVS))
39 | ATARI_NO_FRAMESKIP_ENVS: List[str] = list(
40 | filter(lambda name: name.endswith("-v4"), ATARI_ENVS),
41 | )
42 |
43 | DETERMINISTIC_ENVS += ATARI_NO_FRAMESKIP_ENVS
44 |
45 |
46 | env = pytest.fixture(envs.make_env_fixture(skip_fn=pytest.skip))
47 |
48 |
49 | def test_some_atari_envs():
50 | """Tests if we succeeded in finding any Atari envs."""
51 | assert len(seals.GYM_ATARI_ENV_SPECS) > 0
52 |
53 |
54 | def test_atari_space_invaders():
55 | """Tests for masked and unmasked Atari space invaders environments."""
56 | masked_space_invader_environments = list(
57 | filter(
58 | lambda name: "SpaceInvaders" in name and "Unmasked" not in name,
59 | ATARI_ENVS,
60 | ),
61 | )
62 | assert len(masked_space_invader_environments) > 0
63 |
64 | unmasked_space_invader_environments = list(
65 | filter(
66 | lambda name: "SpaceInvaders" in name and "Unmasked" in name,
67 | ATARI_ENVS,
68 | ),
69 | )
70 | assert len(unmasked_space_invader_environments) > 0
71 |
72 |
73 | def test_atari_unmasked_env_naming():
74 | """Tests that all unmasked Atari envs have the appropriate name qualifier."""
75 | noncompliant_envs = list(
76 | filter(
77 | lambda name: _get_score_region(name) is None and "Unmasked" not in name,
78 | ATARI_ENVS,
79 | ),
80 | )
81 | assert len(noncompliant_envs) == 0
82 |
83 |
84 | def test_make_unsupported_masked_atari_env_throws_error():
85 | """Tests that making an unsupported masked Atari env throws an error."""
86 | match_str = (
87 | "Requested environment does not yet support masking. "
88 | "See https://github.com/HumanCompatibleAI/seals/issues/61."
89 | )
90 | with pytest.raises(ValueError, match=match_str):
91 | make_atari_env("ALE/Bowling-v5", masked=True)
92 |
93 |
94 | def test_atari_masks_satisfy_spec():
95 | """Tests that all Atari masks satisfy the spec."""
96 | masks_satisfy_spec = [
97 | mask.x[0] < mask.x[1] and mask.y[0] < mask.y[1]
98 | for env_regions in SCORE_REGIONS.values()
99 | for mask in env_regions
100 | ]
101 | assert all(masks_satisfy_spec)
102 |
103 |
104 | @pytest.mark.parametrize("env_name", ENV_NAMES)
105 | class TestEnvs:
106 | """Battery of simple tests for environments."""
107 |
108 | def test_seed(self, env: gym.Env, env_name: str):
109 | """Tests environment seeding.
110 |
111 | Deterministic Atari environments are run with fewer seeds to minimize the number
112 | of resets done in this test suite, since Atari resets take a long time and there
113 | are many Atari environments.
114 | """
115 | if env_name in ATARI_ENVS:
116 | # these environments take a while for their non-determinism to show.
117 | slow_random_envs = [
118 | "seals/Bowling-Unmasked-v5",
119 | "seals/Frogger-Unmasked-v5",
120 | "seals/KingKong-Unmasked-v5",
121 | "seals/Koolaid-Unmasked-v5",
122 | "seals/NameThisGame-Unmasked-v5",
123 | "seals/Casino-Unmasked-v5",
124 | ]
125 | rollout_len = 100 if env_name not in slow_random_envs else 400
126 | num_seeds = 2 if env_name in ATARI_NO_FRAMESKIP_ENVS else 10
127 | envs.test_seed(
128 | env,
129 | env_name,
130 | DETERMINISTIC_ENVS,
131 | rollout_len=rollout_len,
132 | num_seeds=num_seeds,
133 | )
134 | else:
135 | envs.test_seed(env, env_name, DETERMINISTIC_ENVS)
136 |
137 | def test_premature_step(self, env: gym.Env):
138 | """Tests if step() before reset() raises error."""
139 | envs.test_premature_step(env, skip_fn=pytest.skip, raises_fn=pytest.raises)
140 |
141 | def test_rollout_schema(self, env: gym.Env, env_name: str):
142 | """Tests if environments have correct types on `step()` and `reset()`.
143 |
144 | Atari environments have a very long episode length (~100k observations), so in
145 | the interest of time we do not run them to the end of their episodes or check
146 | the return time of `env.step` after the end of the episode.
147 | """
148 | if env_name in ATARI_ENVS:
149 | envs.test_rollout_schema(env, max_steps=1_000, check_episode_ends=False)
150 | else:
151 | envs.test_rollout_schema(env)
152 |
153 | def test_render_modes(self, env_name: str):
154 |         """Tests that all render modes specified in the metadata work.
155 |
156 | Note: we only check that no exception is thrown.
157 | There is no test to see if something reasonable is rendered.
158 | """
159 | for mode in gym.make(env_name).metadata["render_modes"]:
160 | # GIVEN
161 | env = gym.make(env_name, render_mode=mode)
162 | env.reset(seed=0)
163 |
164 | # WHEN
165 | if mode == "rgb_array" and not is_mujoco_env(env):
166 | # The render should not change without calling `step()`.
167 | # MuJoCo rendering fails this check, ignore -- not much we can do.
168 | r1: Union[np.ndarray, List[np.ndarray], None] = env.render()
169 | r2: Union[np.ndarray, List[np.ndarray], None] = env.render()
170 | assert r1 is not None
171 | assert r2 is not None
172 | assert np.allclose(r1, r2)
173 | else:
174 | env.render()
175 |
176 | # THEN
177 | # no error raised
178 |
179 | # CLEANUP
180 | env.close()
181 |
--------------------------------------------------------------------------------
/tests/test_mujoco_rl.py:
--------------------------------------------------------------------------------
1 | """Test RL on MuJoCo adapted environments."""
2 |
3 | from typing import Tuple
4 |
5 | import gymnasium as gym
6 | import pytest
7 | import stable_baselines3
8 | from stable_baselines3.common import evaluation
9 |
10 | import seals # noqa: F401 Import required for env registration
11 |
12 |
13 | def _eval_env(
14 | env_name: str,
15 | total_timesteps: int,
16 | ) -> Tuple[float, float]: # pragma: no cover
17 |     """Train PPO for `total_timesteps` on `env_name` and evaluate returns."""
18 | env = gym.make(env_name)
19 | model = stable_baselines3.PPO("MlpPolicy", env)
20 | model.learn(total_timesteps=total_timesteps)
21 | res = evaluation.evaluate_policy(model, env)
22 | assert isinstance(res[0], float)
23 | return res
24 |
25 |
26 | # SOMEDAY(adam): tests are flaky and consistently fail in some environments
27 | # Unclear if they even should pass in some cases.
28 | # See discussion in GH#6 and GH#40.
29 | @pytest.mark.expensive
30 | @pytest.mark.parametrize(
31 | "env_base",
32 | ["HalfCheetah", "Ant", "Hopper", "Humanoid", "Swimmer", "Walker2d"],
33 | )
34 | def test_fixed_env_model_as_good_as_gym_env_model(env_base: str): # pragma: no cover
35 |     """Compare original and modified MuJoCo v4 envs."""
36 | train_timesteps = 200000
37 |
38 | gym_reward, _ = _eval_env(f"{env_base}-v4", total_timesteps=train_timesteps)
39 | fixed_reward, _ = _eval_env(
40 | f"seals/{env_base}-v1",
41 | total_timesteps=train_timesteps,
42 | )
43 |
44 | epsilon = 0.1
45 | sign = 1 if gym_reward > 0 else -1
46 | assert (1 - sign * epsilon) * gym_reward <= fixed_reward
47 |
--------------------------------------------------------------------------------
/tests/test_util.py:
--------------------------------------------------------------------------------
1 | """Test `seals.util`."""
2 |
3 | import collections
4 |
5 | import gymnasium as gym
6 | import numpy as np
7 | import pytest
8 |
9 | from seals import GYM_ATARI_ENV_SPECS, util
10 |
11 |
12 | def test_mask_score_wrapper_enforces_spec():
13 | """Test that MaskScoreWrapper enforces the spec."""
14 | atari_env = gym.make(GYM_ATARI_ENV_SPECS[0].id)
15 | desired_error_message = 'Invalid region: "x" and "y" must be increasing.'
16 | with pytest.raises(ValueError, match=desired_error_message):
17 | util.MaskScoreWrapper(atari_env, [util.BoxRegion(x=(0, 1), y=(1, 0))])
18 | with pytest.raises(ValueError, match=desired_error_message):
19 | util.MaskScoreWrapper(atari_env, [util.BoxRegion(x=(1, 0), y=(0, 1))])
20 |
21 |
22 | def test_sample_distribution():
23 | """Test util.sample_distribution."""
24 | distr_size = 5
25 | distr = np.random.random((distr_size,))
26 | distr /= distr.sum()
27 |
28 | n_samples = 1000
29 | rng = np.random.default_rng(0)
30 | sample_count = collections.Counter(
31 | util.sample_distribution(distr, rng) for _ in range(n_samples)
32 | )
33 |
34 | empirical_distr = np.array([sample_count[i] for i in range(distr_size)]) / n_samples
35 |
36 | # Empirical distribution matches real distribution
37 | l1_err = np.sum(np.abs(empirical_distr - distr))
38 | assert l1_err < 0.1
39 |
40 | # Same seed gives same samples
41 | assert all(
42 | util.sample_distribution(distr, random=np.random.default_rng(seed))
43 | == util.sample_distribution(distr, random=np.random.default_rng(seed))
44 | for seed in range(20)
45 | )
46 |
47 |
48 | def test_one_hot_encoding():
49 | """Test util.one_hot_encoding."""
50 | Case = collections.namedtuple("Case", ["pos", "size", "encoding"])
51 |
52 | cases = [
53 | Case(pos=0, size=1, encoding=np.array([1.0])),
54 | Case(pos=1, size=5, encoding=np.array([0.0, 1.0, 0.0, 0.0, 0.0])),
55 | Case(pos=3, size=4, encoding=np.array([0.0, 0.0, 0.0, 1.0])),
56 | Case(pos=2, size=3, encoding=np.array([0.0, 0.0, 1.0])),
57 | Case(pos=2, size=6, encoding=np.array([0.0, 0.0, 1.0, 0.0, 0.0, 0.0])),
58 | ]
59 |
60 | assert all(
61 | np.all(util.one_hot_encoding(pos, size) == encoding)
62 | for pos, size, encoding in cases
63 | )
64 |
--------------------------------------------------------------------------------
/tests/test_wrappers.py:
--------------------------------------------------------------------------------
1 | """Tests for wrapper classes."""
2 |
3 | import numpy as np
4 | import pytest
5 |
6 | from seals import util
7 | from seals.testing import envs
8 |
9 |
10 | def test_auto_reset_wrapper_pad(episode_length=3, n_steps=100, n_manual_reset=2):
11 | """Check that AutoResetWrapper returns correct values from step and reset.
12 |
13 | AutoResetWrapper that pads trajectory with an extra transition containing the
14 | terminal observations.
15 | Also check that calls to .reset() do not interfere with automatic resets.
16 | Due to the padding, the number of steps counted inside the environment and the
17 | number of steps performed outside the environment, i.e., the number of actions
18 | performed, will differ. This test checks that this difference is consistent.
19 | """
20 | env = util.AutoResetWrapper(
21 | envs.CountingEnv(episode_length=episode_length),
22 | discard_terminal_observation=False,
23 | )
24 |
25 | for _ in range(n_manual_reset):
26 | obs, info = env.reset()
27 | assert obs == 0
28 |
29 | # We count the number of episodes, so we can sanity check the padding.
30 | num_episodes = 0
31 | next_episode_end = episode_length
32 | for t in range(1, n_steps + 1):
33 | act = env.action_space.sample()
34 | obs, rew, terminated, truncated, info = env.step(act)
35 |
36 | # AutoResetWrapper overrides all terminated and truncated signals.
37 | assert terminated is False
38 | assert truncated is False
39 |
40 | if t == next_episode_end:
41 | # Unlike the AutoResetWrapper that discards terminal observations,
42 | # here the final observation is returned directly, and is not stored
43 | # in the info dict.
44 | # Due to padding, for every episode the final observation is offset from
45 | # the outer step by one.
46 | assert obs == (t - num_episodes) / (num_episodes + 1)
47 | assert rew == episode_length * 10
48 | if t == next_episode_end + 1:
49 | num_episodes += 1
50 | # Because the final step returned the final observation, the initial
51 | # obs of the next episode is returned in this additional step.
52 | assert obs == 0
53 | # Consequently, the next episode end is one step later, so it is
54 | # episode_length steps from now.
55 | next_episode_end = t + episode_length
56 |
57 | # Reward of the 'reset transition' is fixed to be 0.
58 | assert rew == 0
59 |
60 | # Sanity check padding. Padding should be 1 for each past episode.
61 | assert (
62 | next_episode_end
63 | == (num_episodes + 1) * episode_length + num_episodes
64 | )
65 |
66 |
67 | def test_auto_reset_wrapper_discard(episode_length=3, n_steps=100, n_manual_reset=2):
68 | """Check that AutoResetWrapper returns correct values from step and reset.
69 |
70 | Tests for AutoResetWrapper that discards terminal observations.
71 | Also check that calls to .reset() do not interfere with automatic resets.
72 | """
73 | env = util.AutoResetWrapper(
74 | envs.CountingEnv(episode_length=episode_length),
75 | discard_terminal_observation=True,
76 | )
77 |
78 | for _ in range(n_manual_reset):
79 | obs, info = env.reset()
80 | assert obs == 0
81 |
82 | for t in range(1, n_steps + 1):
83 | act = env.action_space.sample()
84 | obs, rew, terminated, truncated, info = env.step(act)
85 | expected_obs = t % episode_length
86 |
87 | assert obs == expected_obs
88 | assert terminated is False
89 | assert truncated is False
90 |
91 | if expected_obs == 0: # End of episode
92 | assert info.get("terminal_observation", None) == episode_length
93 | assert rew == episode_length * 10
94 | else:
95 | assert "terminal_observation" not in info
96 | assert rew == expected_obs * 10
97 |
98 |
99 | def test_absorb_repeat_custom_state(
100 | absorb_reward=-4,
101 | absorb_obs=-3.0,
102 | episode_length=6,
103 | n_steps=100,
104 | n_manual_reset=3,
105 | ):
106 | """Check that AbsorbAfterDoneWrapper returns custom state and reward."""
107 | env = envs.CountingEnv(episode_length=episode_length)
108 | env = util.AbsorbAfterDoneWrapper(
109 | env,
110 | absorb_reward=absorb_reward,
111 | absorb_obs=absorb_obs,
112 | )
113 |
114 | for r in range(n_manual_reset):
115 | env.reset()
116 | for t in range(1, n_steps + 1):
117 | act = env.action_space.sample()
118 | obs, rew, terminated, truncated, _ = env.step(act)
119 | assert terminated is False
120 | assert truncated is False
121 | if t > episode_length:
122 | expected_obs = absorb_obs
123 | expected_rew = absorb_reward
124 | else:
125 | expected_obs = t
126 | expected_rew = t * 10.0
127 | assert obs == expected_obs
128 | assert rew == expected_rew
129 |
130 |
131 | def test_absorb_repeat_final_state(episode_length=6, n_steps=100, n_manual_reset=3):
132 | """Check that AbsorbAfterDoneWrapper can repeat final state."""
133 | env = envs.CountingEnv(episode_length=episode_length)
134 | env = util.AbsorbAfterDoneWrapper(env, absorb_reward=-1, absorb_obs=None)
135 |
136 | for _ in range(n_manual_reset):
137 | env.reset()
138 | for t in range(1, n_steps + 1):
139 | act = env.action_space.sample()
140 | obs, rew, terminated, truncated, _ = env.step(act)
141 | assert terminated is False
142 | assert truncated is False
143 | if t > episode_length:
144 | expected_obs = episode_length
145 | expected_rew = -1
146 | else:
147 | expected_obs = t
148 | expected_rew = t * 10.0
149 | assert obs == expected_obs
150 | assert rew == expected_rew
151 |
152 |
153 | @pytest.mark.parametrize("dtype", [np.int64, np.float32, np.float64])
154 | def test_obs_cast(dtype: np.dtype, episode_length: int = 5):
155 | """Check obs_cast observations are of specified dtype and not mangled.
156 |
157 | Test uses CountingEnv with small integers, which can be represented in
158 | all the specified dtypes without any loss of precision.
159 | """
160 | env = util.ObsCastWrapper(
161 | envs.CountingEnv(episode_length=episode_length),
162 | dtype,
163 | )
164 |
165 | obs, _ = env.reset()
166 | assert obs.dtype == dtype
167 | assert obs == 0
168 | for t in range(1, episode_length + 1):
169 | act = env.action_space.sample()
170 | obs, rew, terminated, truncated, _ = env.step(act)
171 | assert terminated == (t == episode_length)
172 | assert truncated is False
173 | assert obs.dtype == dtype
174 | assert obs == t
175 | assert rew == t * 10.0
176 |
--------------------------------------------------------------------------------