├── .coveragerc ├── .gitignore ├── .travis.yml ├── LICENSE ├── README.md ├── doc ├── Makefile ├── api.rst ├── conf.py ├── index.rst ├── requirements.txt └── template.rst ├── examples ├── mars │ ├── README.md │ ├── generate_plots.py │ ├── mars.py │ ├── mars_utilities.py │ ├── plot_utilities.py │ └── requirements.txt └── sample.py ├── requirements.txt ├── safemdp ├── SafeMDP_class.py ├── __init__.py ├── grid_world.py ├── test.py └── utilities.py ├── setup.py └── test_code.sh /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | # Regexes for lines to exclude from consideration 3 | exclude_lines = 4 | # Have to re-enable the standard pragma 5 | pragma: no cover 6 | 7 | # Don't complain about missing debug-only code: 8 | def __repr__ 9 | if self\.debug 10 | 11 | # Don't complain if tests don't hit defensive assertion code: 12 | raise AssertionError 13 | raise NotImplementedError 14 | 15 | # Don't complain if non-runnable code isn't run: 16 | if False: 17 | if __name__ == .__main__.: 18 | 19 | show_missing = True 20 | 21 | [html] 22 | directory = htmlcov 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .ipynb_checkpoints 3 | .coverage 4 | htmlcov 5 | 6 | # Python files 7 | *.pyc 8 | 9 | # Data files 10 | *.IMG 11 | *.tif 12 | 13 | # Documentation 14 | doc/safemdp.*.rst 15 | doc/_build 16 | 17 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - 2.7 4 | notifications: 5 | email: false 6 | 7 | # Setup anaconda 8 | before_install: 9 | - wget http://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh 10 | - chmod +x miniconda.sh 11 | - ./miniconda.sh -b -p ~/miniconda 12 | - export PATH=~/miniconda/bin:$PATH 13 | - conda update --yes conda 14 | # The next couple lines fix a crash with multiprocessing on Travis and are not specific to using Miniconda 15 | - sudo rm -rf /dev/shm 16 | - sudo ln -s /run/shm /dev/shm 17 | # Install packages 18 | install: 19 | - conda install --yes python=$TRAVIS_PYTHON_VERSION flake8 numpy scipy matplotlib nose networkx 20 | - pip install GPy 21 | # Coverage packages are on my binstar channel 22 | # - conda install --yes -c dan_blanchard python-coveralls nose-cov 23 | # - python setup.py install 24 | 25 | # Run tests 26 | script: 27 | - flake8 safemdp --exclude test*.py,__init__.py --ignore=E402,W503 --show-source 28 | - flake8 safemdp --filename=__init__.py,test*.py --ignore=F,E402,W503 --show-source 29 | - nosetests --with-doctest safemdp 30 | 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2015 Felix Berkenkamp 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this 
permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SafeMDP 2 | 3 | [![Build Status](https://travis-ci.org/befelix/SafeMDP.svg?branch=master)](https://travis-ci.org/befelix/SafeMDP) 4 | [![Documentation Status](https://readthedocs.org/projects/safemdp/badge/?version=latest)](http://safemdp.readthedocs.io/en/latest/?badge=latest) 5 | 6 | Code for safe exploration in Markov Decision Processes (MDPs). This code accompanies the paper 7 | 8 | M. Turchetta, F. Berkenkamp, A. Krause, "Safe Exploration in Finite Markov Decision Processes with Gaussian Processes", Proc. of the Conference on Neural Information Processing Systems (NIPS), 2016, [PDF] 9 | 10 | # Installation 11 | 12 | The easiest way to use the library is to install the Anaconda Python distribution. Then, run the following commands in the root directory of this repository: 13 | ``` 14 | pip install GPy 15 | python setup.py install 16 | ``` 17 | 18 | # Usage 19 | 20 | The documentation of the library is available on Read the Docs. 21 | 22 | The file `examples/sample.py` implements a simple example that samples a random world from a Gaussian process and shows exploration results. 23 | 24 | The code for the experiments in the paper can be found in the `examples/mars/` directory. -------------------------------------------------------------------------------- /doc/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # Internal variables. 11 | PAPEROPT_a4 = -D latex_paper_size=a4 12 | PAPEROPT_letter = -D latex_paper_size=letter 13 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 14 | # the i18n builder cannot share the environment and doctrees with the others 15 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
16 | 17 | .PHONY: help 18 | help: 19 | @echo "Please use \`make ' where is one of" 20 | @echo " html to make standalone HTML files" 21 | @echo " dirhtml to make HTML files named index.html in directories" 22 | @echo " singlehtml to make a single large HTML file" 23 | @echo " pickle to make pickle files" 24 | @echo " json to make JSON files" 25 | @echo " htmlhelp to make HTML files and a HTML help project" 26 | @echo " qthelp to make HTML files and a qthelp project" 27 | @echo " applehelp to make an Apple Help Book" 28 | @echo " devhelp to make HTML files and a Devhelp project" 29 | @echo " epub to make an epub" 30 | @echo " epub3 to make an epub3" 31 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 32 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 33 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 34 | @echo " text to make text files" 35 | @echo " man to make manual pages" 36 | @echo " texinfo to make Texinfo files" 37 | @echo " info to make Texinfo files and run them through makeinfo" 38 | @echo " gettext to make PO message catalogs" 39 | @echo " changes to make an overview of all changed/added/deprecated items" 40 | @echo " xml to make Docutils-native XML files" 41 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 42 | @echo " linkcheck to check all external links for integrity" 43 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 44 | @echo " coverage to run coverage check of the documentation (if enabled)" 45 | @echo " dummy to check syntax errors of document sources" 46 | 47 | .PHONY: clean 48 | clean: 49 | rm -rf $(BUILDDIR)/* 50 | 51 | .PHONY: html 52 | html: 53 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 54 | @echo 55 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 56 | 57 | .PHONY: dirhtml 58 | dirhtml: 59 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 60 | @echo 61 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 62 | 63 | .PHONY: singlehtml 64 | singlehtml: 65 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 66 | @echo 67 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 68 | 69 | .PHONY: pickle 70 | pickle: 71 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 72 | @echo 73 | @echo "Build finished; now you can process the pickle files." 74 | 75 | .PHONY: json 76 | json: 77 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 78 | @echo 79 | @echo "Build finished; now you can process the JSON files." 80 | 81 | .PHONY: htmlhelp 82 | htmlhelp: 83 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 84 | @echo 85 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 86 | ".hhp project file in $(BUILDDIR)/htmlhelp." 87 | 88 | .PHONY: qthelp 89 | qthelp: 90 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 91 | @echo 92 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 93 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 94 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/SafeMDP.qhcp" 95 | @echo "To view the help file:" 96 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/SafeMDP.qhc" 97 | 98 | .PHONY: applehelp 99 | applehelp: 100 | $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp 101 | @echo 102 | @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." 103 | @echo "N.B. 
You won't be able to view it unless you put it in" \ 104 | "~/Library/Documentation/Help or install it in your application" \ 105 | "bundle." 106 | 107 | .PHONY: devhelp 108 | devhelp: 109 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 110 | @echo 111 | @echo "Build finished." 112 | @echo "To view the help file:" 113 | @echo "# mkdir -p $$HOME/.local/share/devhelp/SafeMDP" 114 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/SafeMDP" 115 | @echo "# devhelp" 116 | 117 | .PHONY: epub 118 | epub: 119 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 120 | @echo 121 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 122 | 123 | .PHONY: epub3 124 | epub3: 125 | $(SPHINXBUILD) -b epub3 $(ALLSPHINXOPTS) $(BUILDDIR)/epub3 126 | @echo 127 | @echo "Build finished. The epub3 file is in $(BUILDDIR)/epub3." 128 | 129 | .PHONY: latex 130 | latex: 131 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 132 | @echo 133 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 134 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 135 | "(use \`make latexpdf' here to do that automatically)." 136 | 137 | .PHONY: latexpdf 138 | latexpdf: 139 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 140 | @echo "Running LaTeX files through pdflatex..." 141 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 142 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 143 | 144 | .PHONY: latexpdfja 145 | latexpdfja: 146 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 147 | @echo "Running LaTeX files through platex and dvipdfmx..." 148 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 149 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 150 | 151 | .PHONY: text 152 | text: 153 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 154 | @echo 155 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 156 | 157 | .PHONY: man 158 | man: 159 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 160 | @echo 161 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 162 | 163 | .PHONY: texinfo 164 | texinfo: 165 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 166 | @echo 167 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 168 | @echo "Run \`make' in that directory to run these through makeinfo" \ 169 | "(use \`make info' here to do that automatically)." 170 | 171 | .PHONY: info 172 | info: 173 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 174 | @echo "Running Texinfo files through makeinfo..." 175 | make -C $(BUILDDIR)/texinfo info 176 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 177 | 178 | .PHONY: gettext 179 | gettext: 180 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 181 | @echo 182 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 183 | 184 | .PHONY: changes 185 | changes: 186 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 187 | @echo 188 | @echo "The overview file is in $(BUILDDIR)/changes." 189 | 190 | .PHONY: linkcheck 191 | linkcheck: 192 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 193 | @echo 194 | @echo "Link check complete; look for any errors in the above output " \ 195 | "or in $(BUILDDIR)/linkcheck/output.txt." 
196 | 197 | .PHONY: doctest 198 | doctest: 199 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 200 | @echo "Testing of doctests in the sources finished, look at the " \ 201 | "results in $(BUILDDIR)/doctest/output.txt." 202 | 203 | .PHONY: coverage 204 | coverage: 205 | $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage 206 | @echo "Testing of coverage in the sources finished, look at the " \ 207 | "results in $(BUILDDIR)/coverage/python.txt." 208 | 209 | .PHONY: xml 210 | xml: 211 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 212 | @echo 213 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 214 | 215 | .PHONY: pseudoxml 216 | pseudoxml: 217 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 218 | @echo 219 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 220 | 221 | .PHONY: dummy 222 | dummy: 223 | $(SPHINXBUILD) -b dummy $(ALLSPHINXOPTS) $(BUILDDIR)/dummy 224 | @echo 225 | @echo "Build finished. Dummy builder generates no files." 226 | -------------------------------------------------------------------------------- /doc/api.rst: -------------------------------------------------------------------------------- 1 | API Documentation 2 | ***************** 3 | 4 | .. automodule:: safemdp 5 | -------------------------------------------------------------------------------- /doc/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # SafeMDP documentation build configuration file, created by 5 | # sphinx-quickstart on Thu Jun 16 08:40:23 2016. 6 | # 7 | # This file is execfile()d with the current directory set to its 8 | # containing dir. 9 | # 10 | # Note that not all possible configuration values are present in this 11 | # autogenerated file. 12 | # 13 | # All configuration values have a default; values that are commented out 14 | # serve to show the default. 15 | 16 | # If extensions (or modules to document with autodoc) are in another directory, 17 | # add these directories to sys.path here. If the directory is relative to the 18 | # documentation root, use os.path.abspath to make it absolute, like shown here. 19 | 20 | import os 21 | import sys 22 | import shlex 23 | import mock 24 | 25 | 26 | MOCK_MODULES = ['GPy', 27 | 'GPy.util', 28 | 'GPy.util.linalg', 29 | 'GPy.inference', 30 | 'GPy.inference.latent_function_inference', 31 | 'GPy.inference.latent_function_inference.posterior', 32 | 'mpl_toolkits', 33 | 'mpl_toolkits.mplot3d', 34 | 'matplotlib', 35 | 'matplotlib.pyplot', 36 | 'networkx', 37 | 'nose', 38 | 'nose.tools', 39 | 'numpy', 40 | 'numpy.testing', 41 | 'scipy', 42 | 'scipy.interpolate', 43 | 'scipy.spatial', 44 | 'scipy.spatial.distance', 45 | ] 46 | 47 | for mod_name in MOCK_MODULES: 48 | sys.modules[mod_name] = mock.Mock() 49 | 50 | sys.path.insert(0, os.path.abspath('../')) 51 | 52 | # -- General configuration ------------------------------------------------ 53 | 54 | # If your documentation needs a minimal Sphinx version, state it here. 55 | # 56 | # needs_sphinx = '1.0' 57 | 58 | # Add any Sphinx extension module names here, as strings. They can be 59 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 60 | # ones. 61 | extensions = [ 62 | 'sphinx.ext.autodoc', 63 | 'numpydoc', 64 | 'sphinx.ext.autosummary', 65 | ] 66 | 67 | # Add any paths that contain templates here, relative to this directory. 
68 | templates_path = [''] 69 | 70 | # Generate an autosummary with one file per function. 71 | autosummary_generate = True 72 | 73 | autodoc_default_flags = [] 74 | 75 | # The suffix(es) of source filenames. 76 | # You can specify multiple suffix as a list of string: 77 | # 78 | # source_suffix = ['.rst', '.md'] 79 | source_suffix = '.rst' 80 | 81 | # The encoding of source files. 82 | # 83 | # source_encoding = 'utf-8-sig' 84 | 85 | # The master toctree document. 86 | master_doc = 'index' 87 | 88 | # General information about the project. 89 | project = 'SafeMDP' 90 | copyright = '2016, Matteo Turchetta, Felix Berkenkamp, Andreas Krause' 91 | author = 'Matteo Turchetta, Felix Berkenkamp, Andreas Krause' 92 | 93 | # The version info for the project you're documenting, acts as replacement for 94 | # |version| and |release|, also used in various other places throughout the 95 | # built documents. 96 | # 97 | # The short X.Y version. 98 | version = '1.0' 99 | # The full version, including alpha/beta/rc tags. 100 | release = '1.0' 101 | 102 | # The language for content autogenerated by Sphinx. Refer to documentation 103 | # for a list of supported languages. 104 | # 105 | # This is also used if you do content translation via gettext catalogs. 106 | # Usually you set "language" from the command line for these cases. 107 | language = None 108 | 109 | # There are two options for replacing |today|: either, you set today to some 110 | # non-false value, then it is used: 111 | # 112 | # today = '' 113 | # 114 | # Else, today_fmt is used as the format for a strftime call. 115 | # 116 | # today_fmt = '%B %d, %Y' 117 | 118 | # List of patterns, relative to source directory, that match files and 119 | # directories to ignore when looking for source files. 120 | # This patterns also effect to html_static_path and html_extra_path 121 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 122 | 123 | # The reST default role (used for this markup: `text`) to use for all 124 | # documents. 125 | # 126 | # default_role = None 127 | 128 | # If true, '()' will be appended to :func: etc. cross-reference text. 129 | # 130 | # add_function_parentheses = True 131 | 132 | # If true, the current module name will be prepended to all description 133 | # unit titles (such as .. function::). 134 | # 135 | # add_module_names = True 136 | 137 | # If true, sectionauthor and moduleauthor directives will be shown in the 138 | # output. They are ignored by default. 139 | # 140 | # show_authors = False 141 | 142 | # The name of the Pygments (syntax highlighting) style to use. 143 | pygments_style = 'sphinx' 144 | 145 | # A list of ignored prefixes for module index sorting. 146 | # modindex_common_prefix = [] 147 | 148 | # If true, keep warnings as "system message" paragraphs in the built documents. 149 | # keep_warnings = False 150 | 151 | # If true, `todo` and `todoList` produce output, else they produce nothing. 152 | todo_include_todos = False 153 | 154 | 155 | # -- Options for HTML output ---------------------------------------------- 156 | 157 | # The theme to use for HTML and HTML Help pages. See the documentation for 158 | # a list of builtin themes. 159 | # 160 | html_theme = 'sphinx_rtd_theme' 161 | 162 | # Theme options are theme-specific and customize the look and feel of a theme 163 | # further. For a list of options available for each theme, see the 164 | # documentation. 165 | # 166 | # html_theme_options = {} 167 | 168 | # Add any paths that contain custom themes here, relative to this directory. 
169 | # html_theme_path = [] 170 | 171 | # The name for this set of Sphinx documents. 172 | # " v documentation" by default. 173 | # 174 | # html_title = 'SafeMDP v1.0' 175 | 176 | # A shorter title for the navigation bar. Default is the same as html_title. 177 | # 178 | # html_short_title = None 179 | 180 | # The name of an image file (relative to this directory) to place at the top 181 | # of the sidebar. 182 | # 183 | # html_logo = None 184 | 185 | # The name of an image file (relative to this directory) to use as a favicon of 186 | # the docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 187 | # pixels large. 188 | # 189 | # html_favicon = None 190 | 191 | # Add any paths that contain custom static files (such as style sheets) here, 192 | # relative to this directory. They are copied after the builtin static files, 193 | # so a file named "default.css" will overwrite the builtin "default.css". 194 | #html_static_path = ['_static'] 195 | 196 | # Add any extra paths that contain custom files (such as robots.txt or 197 | # .htaccess) here, relative to this directory. These files are copied 198 | # directly to the root of the documentation. 199 | # 200 | # html_extra_path = [] 201 | 202 | # If not None, a 'Last updated on:' timestamp is inserted at every page 203 | # bottom, using the given strftime format. 204 | # The empty string is equivalent to '%b %d, %Y'. 205 | # 206 | # html_last_updated_fmt = None 207 | 208 | # If true, SmartyPants will be used to convert quotes and dashes to 209 | # typographically correct entities. 210 | # 211 | # html_use_smartypants = True 212 | 213 | # Custom sidebar templates, maps document names to template names. 214 | # 215 | # html_sidebars = {} 216 | 217 | # Additional templates that should be rendered to pages, maps page names to 218 | # template names. 219 | # 220 | # html_additional_pages = {} 221 | 222 | # If false, no module index is generated. 223 | # 224 | # html_domain_indices = True 225 | 226 | # If false, no index is generated. 227 | # 228 | # html_use_index = True 229 | 230 | # If true, the index is split into individual pages for each letter. 231 | # 232 | # html_split_index = False 233 | 234 | # If true, links to the reST sources are added to the pages. 235 | # 236 | # html_show_sourcelink = True 237 | 238 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 239 | # 240 | # html_show_sphinx = True 241 | 242 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 243 | # 244 | # html_show_copyright = True 245 | 246 | # If true, an OpenSearch description file will be output, and all pages will 247 | # contain a tag referring to it. The value of this option must be the 248 | # base URL from which the finished HTML is served. 249 | # 250 | # html_use_opensearch = '' 251 | 252 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 253 | # html_file_suffix = None 254 | 255 | # Language to be used for generating the HTML full-text search index. 256 | # Sphinx supports the following languages: 257 | # 'da', 'de', 'en', 'es', 'fi', 'fr', 'h', 'it', 'ja' 258 | # 'nl', 'no', 'pt', 'ro', 'r', 'sv', 'tr', 'zh' 259 | # 260 | html_search_language = 'en' 261 | 262 | # A dictionary with options for the search language support, empty by default. 263 | # 'ja' uses this config value. 264 | # 'zh' user can custom change `jieba` dictionary path. 
265 | # 266 | # html_search_options = {'type': 'default'} 267 | 268 | # The name of a javascript file (relative to the configuration directory) that 269 | # implements a search results scorer. If empty, the default will be used. 270 | # 271 | # html_search_scorer = 'scorer.js' 272 | 273 | # Output file base name for HTML help builder. 274 | htmlhelp_basename = 'SafeMDPdoc' 275 | 276 | # -- Options for LaTeX output --------------------------------------------- 277 | 278 | latex_elements = { 279 | # The paper size ('letterpaper' or 'a4paper'). 280 | # 281 | # 'papersize': 'letterpaper', 282 | 283 | # The font size ('10pt', '11pt' or '12pt'). 284 | # 285 | # 'pointsize': '10pt', 286 | 287 | # Additional stuff for the LaTeX preamble. 288 | # 289 | # 'preamble': '', 290 | 291 | # Latex figure (float) alignment 292 | # 293 | # 'figure_align': 'htbp', 294 | } 295 | 296 | # Grouping the document tree into LaTeX files. List of tuples 297 | # (source start file, target name, title, 298 | # author, documentclass [howto, manual, or own class]). 299 | latex_documents = [ 300 | (master_doc, 'SafeMDP.tex', 'SafeMDP Documentation', 301 | 'Matteo Turchetta, Felix Berkenkamp, Andreas Krause', 'manual'), 302 | ] 303 | 304 | # The name of an image file (relative to this directory) to place at the top of 305 | # the title page. 306 | # 307 | # latex_logo = None 308 | 309 | # For "manual" documents, if this is true, then toplevel headings are parts, 310 | # not chapters. 311 | # 312 | # latex_use_parts = False 313 | 314 | # If true, show page references after internal links. 315 | # 316 | # latex_show_pagerefs = False 317 | 318 | # If true, show URL addresses after external links. 319 | # 320 | # latex_show_urls = False 321 | 322 | # Documents to append as an appendix to all manuals. 323 | # 324 | # latex_appendices = [] 325 | 326 | # If false, no module index is generated. 327 | # 328 | # latex_domain_indices = True 329 | 330 | 331 | # -- Options for manual page output --------------------------------------- 332 | 333 | # One entry per manual page. List of tuples 334 | # (source start file, name, description, authors, manual section). 335 | man_pages = [ 336 | (master_doc, 'safemdp', 'SafeMDP Documentation', 337 | [author], 1) 338 | ] 339 | 340 | # If true, show URL addresses after external links. 341 | # 342 | # man_show_urls = False 343 | 344 | 345 | # -- Options for Texinfo output ------------------------------------------- 346 | 347 | # Grouping the document tree into Texinfo files. List of tuples 348 | # (source start file, target name, title, author, 349 | # dir menu entry, description, category) 350 | texinfo_documents = [ 351 | (master_doc, 'SafeMDP', 'SafeMDP Documentation', 352 | author, 'SafeMDP', 'One line description of project.', 353 | 'Miscellaneous'), 354 | ] 355 | 356 | # Documents to append as an appendix to all manuals. 357 | # 358 | # texinfo_appendices = [] 359 | 360 | # If false, no module index is generated. 361 | # 362 | # texinfo_domain_indices = True 363 | 364 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 365 | # 366 | # texinfo_show_urls = 'footnote' 367 | 368 | # If true, do not generate a @detailmenu in the "Top" node's menu. 369 | # 370 | # texinfo_no_detailmenu = False 371 | 372 | 373 | # Example configuration for intersphinx: refer to the Python standard library. 
374 | # intersphinx_mapping = {'https://docs.python.org/': None} 375 | -------------------------------------------------------------------------------- /doc/index.rst: -------------------------------------------------------------------------------- 1 | .. SafeMDP documentation master file, created by 2 | sphinx-quickstart on Thu Jun 16 08:40:23 2016. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to SafeMDP's documentation! 7 | =================================== 8 | 9 | .. toctree:: 10 | :caption: Contents 11 | :maxdepth: 3 12 | 13 | api 14 | 15 | 16 | Indices and tables 17 | ================== 18 | 19 | * :ref:`genindex` 20 | * :ref:`modindex` 21 | * :ref:`search` 22 | 23 | -------------------------------------------------------------------------------- /doc/requirements.txt: -------------------------------------------------------------------------------- 1 | numpydoc >= 0.5 2 | sphinx_rtd_theme >= 0.1.8 3 | 4 | -------------------------------------------------------------------------------- /doc/template.rst: -------------------------------------------------------------------------------- 1 | {{ name }} 2 | {{ underline }} 3 | 4 | .. currentmodule:: {{ module }} 5 | .. auto{{ objtype }}:: {{ objname }} {% if objtype == "class" %} 6 | :members: 7 | :inherited-members: 8 | {% endif %} 9 | -------------------------------------------------------------------------------- /examples/mars/README.md: -------------------------------------------------------------------------------- 1 | The experiments were run using Python 3 and the packages specified in the `requirements.txt` file. 2 | -------------------------------------------------------------------------------- /examples/mars/generate_plots.py: -------------------------------------------------------------------------------- 1 | from plot_utilities import * 2 | import numpy as np 3 | import cPickle 4 | import time 5 | 6 | # Safe plots 7 | data = np.load("mars safe experiment.npz") 8 | mu = data["mu_alt"] 9 | var = data["var_alt"] 10 | beta = data["beta"] 11 | world_shape = data["world_shape"] 12 | altitudes = data["altitudes"] 13 | X = data["X"] 14 | Y = data["Y"] 15 | coord = data["coord"] 16 | coverage_over_t = data["coverage_over_t"] 17 | S_hat = data["S_hat"] 18 | 19 | plot_dist_from_C(mu, var, beta, altitudes, world_shape) 20 | plot_coverage(coverage_over_t) 21 | plot_paper(altitudes, S_hat, world_shape, './safe_exploration.pdf') 22 | 23 | # Unsafe plot 24 | data = np.load("mars unsafe experiment.npz") 25 | coverage = data["coverage"] 26 | altitudes = data["altitudes"] 27 | visited = data["visited"] 28 | world_shape = data["world_shape"] 29 | 30 | plot_paper(altitudes, visited, world_shape, './no_safe_exploration1.pdf') 31 | 32 | # Random plot 33 | data = np.load("mars random experiment.npz") 34 | coverage = data["coverage"] 35 | altitudes = data["altitudes"] 36 | visited = data["visited"] 37 | world_shape = data["world_shape"] 38 | 39 | plot_paper(altitudes, visited, world_shape, './random_exploration1.pdf') 40 | 41 | # Non ergodic plot 42 | data = np.load("mars non ergodic experiment.npz") 43 | coverage = data["coverage"] 44 | altitudes = data["altitudes"] 45 | S_hat = data["S_hat"] 46 | world_shape = data["world_shape"] 47 | 48 | plot_paper(altitudes, S_hat, world_shape, './no_ergodic_exploration1.pdf') 49 | 50 | # No expanders plot 51 | data = np.load("mars no G experiment.npz") 52 | coverage = data["coverage"] 53 | altitudes = data["altitudes"] 
54 | S_hat = data["S_hat"] 55 | world_shape = data["world_shape"] 56 | coverage_over_t = data["coverage_over_t"] 57 | plot_coverage(coverage_over_t) 58 | 59 | plot_paper(altitudes, S_hat, world_shape, './no_G_exploration1.pdf') -------------------------------------------------------------------------------- /examples/mars/mars.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | 3 | import sys 4 | import time 5 | 6 | import numpy as np 7 | 8 | from safemdp.grid_world import * 9 | from mars_utilities import (mars_map, initialize_SafeMDP_object, 10 | performance_metrics) 11 | from plot_utilities import * 12 | 13 | print(sys.version) 14 | 15 | 16 | # Control experiments and saving 17 | save_performance = True 18 | random_experiment = True 19 | non_safe_experiment = True 20 | non_ergodic_experiment = True 21 | no_expanders_exploration = True 22 | 23 | ############################## SAFE ########################################### 24 | 25 | # Get mars data 26 | altitudes, coord, world_shape, step_size, num_of_points = mars_map() 27 | 28 | # Initialize object for simulation 29 | start, x, true_S_hat, true_S_hat_epsilon, h_hard = initialize_SafeMDP_object( 30 | altitudes, coord, world_shape, step_size) 31 | 32 | # Initialize for performance storage 33 | time_steps = 525 34 | coverage_over_t = np.empty(time_steps, dtype=float) 35 | 36 | # Simulation loop 37 | t = time.time() 38 | unsafe_count = 0 39 | source = start 40 | 41 | for i in range(time_steps): 42 | 43 | # Simulation 44 | x.update_sets() 45 | next_sample = x.target_sample() 46 | x.add_observation(*next_sample) 47 | path = shortest_path(source, next_sample, x.graph) 48 | source = path[-1] 49 | 50 | # Performances 51 | unsafe_transitions, coverage, false_safe = performance_metrics(path, x, 52 | true_S_hat_epsilon, 53 | true_S_hat, h_hard) 54 | unsafe_count += unsafe_transitions 55 | coverage_over_t[i] = coverage 56 | print(coverage, false_safe, unsafe_count, i) 57 | 58 | print(str(time.time() - t) + "seconds elapsed") 59 | 60 | # Posterior over heights for plotting 61 | mu_alt, var_alt = x.gp.predict(x.coord, include_likelihood=False) 62 | 63 | 64 | print("----UNSAFE EXPERIMENT---") 65 | print("Number of points for interpolation: " + str(num_of_points)) 66 | print("False safe: " +str(false_safe)) 67 | print("Unsafe evaluations: " + str(unsafe_count)) 68 | print("Coverage: " + str(coverage)) 69 | 70 | if save_performance: 71 | file_name = "mars safe experiment" 72 | 73 | np.savez(file_name, false_safe=false_safe, coverage=coverage, 74 | coverage_over_t=coverage_over_t, mu_alt=mu_alt, 75 | var_alt=var_alt, altitudes=x.altitudes, S_hat=x.S_hat, 76 | time_steps=time_steps, world_shape=world_shape, X=x.gp.X, 77 | Y=x.gp.Y, coord=x.coord, beta=x.beta) 78 | 79 | 80 | ########################## NON SAFE ########################################### 81 | if non_safe_experiment: 82 | 83 | # Get mars data 84 | altitudes, coord, world_shape, step_size, num_of_points = mars_map() 85 | 86 | # Initialize object for simulation 87 | start, x, true_S_hat, true_S_hat_epsilon, h_hard = initialize_SafeMDP_object( 88 | altitudes, coord, world_shape, step_size) 89 | 90 | # Assume all transitions are safe 91 | x.S[:] = True 92 | 93 | # Simulation loop 94 | unsafe_count = 0 95 | source = start 96 | trajectory = [] 97 | 98 | for i in range(time_steps): 99 | x.update_sets() 100 | next_sample = x.target_sample() 101 | x.add_observation(*next_sample) 102 | path = shortest_path(source, 
next_sample, x.graph) 103 | source = path[-1] 104 | 105 | # Check safety 106 | path_altitudes = x.altitudes[path] 107 | unsafe_count = np.sum(-np.diff(path_altitudes) < h_hard) 108 | 109 | if unsafe_count == 0: 110 | trajectory = trajectory + path[:-1] 111 | else: 112 | trajectory = trajectory + safe_subpath(path, altitudes, h_hard) 113 | 114 | # Convert trajectory to S_hat-like matrix 115 | visited = path_to_boolean_matrix(trajectory, x.graph, x.S) 116 | 117 | # Normalization factor 118 | max_size = float(np.count_nonzero(true_S_hat_epsilon)) 119 | 120 | # Performance 121 | coverage = 100 * np.count_nonzero( 122 | np.logical_and(visited, true_S_hat_epsilon)) / max_size 123 | 124 | # Print 125 | print("----UNSAFE EXPERIMENT---") 126 | print("Unsafe evaluations: " + str(unsafe_count)) 127 | print("Coverage: " + str(coverage)) 128 | 129 | if unsafe_count > 0: 130 | break 131 | 132 | if save_performance: 133 | file_name = "mars unsafe experiment" 134 | 135 | np.savez(file_name, coverage=coverage, altitudes=x.altitudes, 136 | visited=visited, world_shape=world_shape) 137 | 138 | ############################### RANDOM ######################################## 139 | if random_experiment: 140 | 141 | # Get mars data 142 | altitudes, coord, world_shape, step_size, num_of_points = mars_map() 143 | 144 | # Initialize object for simulation 145 | start, x, true_S_hat, true_S_hat_epsilon, h_hard = initialize_SafeMDP_object( 146 | altitudes, coord, world_shape, step_size) 147 | 148 | source = start 149 | trajectory = [source] 150 | 151 | for i in range(time_steps): 152 | 153 | # Choose action at random 154 | a = np.random.choice([1, 2, 3, 4]) 155 | 156 | # Add resulting state to trajectory 157 | for _, next_node, data in x.graph.out_edges(nbunch=source, data=True): 158 | if data["action"] == a: 159 | trajectory = trajectory + [next_node] 160 | source = next_node 161 | 162 | # Check safety 163 | path_altitudes = x.altitudes[trajectory] 164 | unsafe_count = np.sum(-np.diff(path_altitudes) < h_hard) 165 | 166 | # Get trajectory up to first unsafe transition 167 | trajectory = safe_subpath(trajectory, x.altitudes, h_hard) 168 | 169 | # Convert trajectory to S_hat-like matrix 170 | visited = path_to_boolean_matrix(trajectory, x.graph, x.S) 171 | 172 | # Normalization factor 173 | max_size = float(np.count_nonzero(true_S_hat_epsilon)) 174 | 175 | # Performance 176 | coverage = 100 * np.count_nonzero(np.logical_and(visited, 177 | true_S_hat_epsilon))/max_size 178 | # Print 179 | print("----RANDOM EXPERIMENT---") 180 | print("Unsafe evaluations: " + str(unsafe_count)) 181 | print("Coverage: " + str(coverage)) 182 | 183 | if save_performance: 184 | file_name = "mars random experiment" 185 | 186 | np.savez(file_name, coverage=coverage, altitudes=x.altitudes, 187 | visited=visited, world_shape=world_shape) 188 | 189 | ############################# NON ERGODIC ##################################### 190 | if non_ergodic_experiment: 191 | 192 | # Get mars data 193 | altitudes, coord, world_shape, step_size, num_of_points = mars_map() 194 | 195 | # Need to remove expanders otherwise next sample will be in G and 196 | # therefore in S_hat before I can set S_hat = S 197 | L_non_ergodic = 1000. 
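Annotation (not part of `mars.py`): the comment above relies on how the Lipschitz constant enters the expander set G. Roughly, as in the expander definition of the accompanying paper, a safe state-action pair can only enlarge the safe set if its optimistic safety value, reduced by L times the distance to a not-yet-classified state, still clears the threshold h. A minimal sketch of that test with purely illustrative numbers and a hypothetical helper name:
```
# Illustrative only: why L = 1000 empties the expander set G.
def could_expand(u_value, distance, L, h):
    """Expander test: the optimistic value u(s, a), reduced by L * d(s, s'),
    must still clear the safety threshold h to certify a new state."""
    return u_value - L * distance >= h

# With a grid step of 1 m and L = 1000 the margin is hopeless, so G stays
# empty and the safe set can only grow through direct observations.
print(could_expand(u_value=0.0, distance=1.0, L=1000., h=-0.5))  # False
print(could_expand(u_value=0.0, distance=1.0, L=0.2, h=-0.5))    # True
```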
198 | 199 | # Initialize object for simulation 200 | start, x, true_S_hat, true_S_hat_epsilon, h_hard = initialize_SafeMDP_object( 201 | altitudes, coord, world_shape, step_size, L=L_non_ergodic) 202 | 203 | # Simulation loop 204 | unsafe_count = 0 205 | source = start 206 | 207 | for i in range(time_steps): 208 | x.update_sets() 209 | 210 | # Remove ergodicity properties 211 | x.S_hat = x.S.copy() 212 | 213 | next_sample = x.target_sample() 214 | x.add_observation(*next_sample) 215 | try: 216 | path = shortest_path(source, next_sample, x.graph) 217 | source = path[-1] 218 | 219 | # Check safety 220 | path_altitudes = x.altitudes[path] 221 | unsafe_transitions = np.sum(-np.diff(path_altitudes) < h_hard) 222 | unsafe_count += unsafe_transitions 223 | except Exception: 224 | print ("No safe path available") 225 | break 226 | 227 | # For coverage we consider every state that has at least one action 228 | # classified as safe 229 | x.S_hat[:, 0] = np.any(x.S_hat[:, 1:], axis=1) 230 | 231 | # Normalization factor 232 | max_size = float(np.count_nonzero(true_S_hat_epsilon)) 233 | 234 | # Performance 235 | coverage = 100 * np.count_nonzero(np.logical_and(x.S_hat, 236 | true_S_hat_epsilon))/max_size 237 | 238 | # Print 239 | print("----NON ERGODIC EXPERIMENT---") 240 | print("Unsafe evaluations: " + str(unsafe_count)) 241 | print("Coverage: " + str(coverage)) 242 | 243 | if save_performance: 244 | file_name = "mars non ergodic experiment" 245 | 246 | np.savez(file_name, coverage=coverage, altitudes=x.altitudes, 247 | S_hat=x.S_hat, world_shape=world_shape) 248 | 249 | ################################## NO EXPANDERS ############################### 250 | if no_expanders_exploration: 251 | 252 | # Get mars data 253 | altitudes, coord, world_shape, step_size, num_of_points = mars_map() 254 | 255 | L_no_expanders = 1000. 
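Annotation (not part of `mars.py`): the boolean set arrays used in these experiments (`S`, `S_hat`, `true_S_hat_epsilon`, ...) appear to have one row per grid node and one column per action plus a leading column for the state itself, which is why the non-ergodic experiment above marks a state as covered as soon as any of its actions is safe (`x.S_hat[:, 0] = np.any(x.S_hat[:, 1:], axis=1)`). A small self-contained sketch of that convention, with illustrative shapes and values:
```
import numpy as np

# One row per node; column 0 is the state, columns 1-4 the four grid actions.
num_nodes, num_actions = 6, 4
S_hat = np.zeros((num_nodes, num_actions + 1), dtype=bool)

# Suppose actions 1 and 3 from node 2 have been classified as safe ...
S_hat[2, [1, 3]] = True

# ... then node 2 itself counts as safe/covered once any action from it is
# safe, mirroring `x.S_hat[:, 0] = np.any(x.S_hat[:, 1:], axis=1)` above.
S_hat[:, 0] = np.any(S_hat[:, 1:], axis=1)
print(S_hat[2, 0])  # True
```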
256 | 257 | # Initialize object for simulation 258 | start, x, true_S_hat, true_S_hat_epsilon, h_hard = initialize_SafeMDP_object( 259 | altitudes, coord, world_shape, step_size, L=L_no_expanders) 260 | 261 | # Initialize for performance storage 262 | coverage_over_t = np.empty(time_steps, dtype=float) 263 | 264 | # Simulation loop 265 | source = start 266 | unsafe_count = 0 267 | 268 | for i in range(int(time_steps)): 269 | 270 | #Simulation 271 | x.update_sets() 272 | next_sample = x.target_sample() 273 | x.add_observation(*next_sample) 274 | path = shortest_path(source, next_sample, x.graph) 275 | source = path[-1] 276 | 277 | # Performances 278 | unsafe_transitions, coverage, false_safe = performance_metrics(path, x, 279 | true_S_hat_epsilon, 280 | true_S_hat, h_hard) 281 | unsafe_count += unsafe_transitions 282 | coverage_over_t[i] = coverage 283 | print(coverage, false_safe, unsafe_count, i) 284 | 285 | # Print 286 | print("----NO EXPANDER EXPERIMENT---") 287 | print("False safe: " + str(false_safe)) 288 | print("Unsafe evaluations: " + str(unsafe_count)) 289 | print("Coverage: " + str(coverage)) 290 | 291 | if save_performance: 292 | file_name = "mars no G experiment" 293 | 294 | np.savez(file_name, false_safe=false_safe, coverage=coverage, 295 | coverage_over_t=coverage_over_t, altitudes=x.altitudes, S_hat=x.S_hat, 296 | world_shape=world_shape) 297 | -------------------------------------------------------------------------------- /examples/mars/mars_utilities.py: -------------------------------------------------------------------------------- 1 | from safemdp.grid_world import * 2 | from osgeo import gdal 3 | from scipy import interpolate 4 | import numpy as np 5 | import os 6 | import matplotlib.pyplot as plt 7 | import GPy 8 | 9 | 10 | __all__ = ['mars_map', 'initialize_SafeMDP_object', 'performance_metrics'] 11 | 12 | 13 | def mars_map(plot_map=False, interpolation=False): 14 | """ 15 | Extract the map for the simulation from the HiRISE data. If the HiRISE 16 | data is not in the current folder it will be downloaded and converted to 17 | GeoTiff extension with gdal. 18 | 19 | Parameters 20 | ---------- 21 | plot_map: bool 22 | If true plots the map that will be used for exploration 23 | interpolation: bool 24 | If true the data of the map will be interpolated with splines to 25 | obtain a finer grid 26 | 27 | Returns 28 | ------- 29 | altitudes: np.array 30 | 1-d vector with altitudes for each node 31 | coord: np.array 32 | Coordinate of the map we use for exploration 33 | world_shape: tuple 34 | Size of the grid world (rows, columns) 35 | step_size: tuple 36 | Step size for the grid (row, column) 37 | num_of_points: int 38 | Interpolation parameter. Indicates the scaling factor for the 39 | original step size 40 | """ 41 | 42 | # Define the dimension of the map we want to investigate and its resolution 43 | world_shape = (120, 70) 44 | step_size = (1., 1.) 
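Annotation (not part of `mars_utilities.py`): the grid nodes used below are numbered row-major, so with `world_shape = (n, m)` the cell `(row, col)` corresponds to node `row * m + col`; this matches how the start node is formed later in this file (`start_x * world_shape[1] + start_y`) and how `coord` is flattened from `meshgrid(..., indexing="ij")`. A small sketch with hypothetical helper names:
```
def rowcol_to_node(row, col, world_shape):
    """Row-major node index, matching start_x * world_shape[1] + start_y."""
    return row * world_shape[1] + col

def node_to_rowcol(node, world_shape):
    """Inverse mapping from a node index back to grid coordinates."""
    return divmod(node, world_shape[1])

world_shape = (120, 70)
assert rowcol_to_node(60, 61, world_shape) == 4261
assert node_to_rowcol(4261, world_shape) == (60, 61)
```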
45 | 46 | # Download and convert to GEOtiff Mars data 47 | if not os.path.exists('./mars.tif'): 48 | if not os.path.exists("./mars.IMG"): 49 | import urllib 50 | 51 | print('Downloading MARS map, this make take a while...') 52 | # Download the IMG file 53 | urllib.urlretrieve( 54 | "http://www.uahirise.org/PDS/DTM/PSP/ORB_010200_010299" 55 | "/PSP_010228_1490_ESP_016320_1490" 56 | "/DTEEC_010228_1490_016320_1490_A01.IMG", "mars.IMG") 57 | 58 | # Convert to tif 59 | print('Converting map to geotif...') 60 | os.system("gdal_translate -of GTiff ./mars.IMG ./mars.tif") 61 | print('Done') 62 | 63 | # Read the data with gdal module 64 | gdal.UseExceptions() 65 | ds = gdal.Open("./mars.tif") 66 | band = ds.GetRasterBand(1) 67 | elevation = band.ReadAsArray() 68 | 69 | # Extract the area of interest 70 | startX = 2890 71 | startY = 1955 72 | altitudes = np.copy(elevation[startX:startX + world_shape[0], 73 | startY:startY + world_shape[1]]) 74 | 75 | # Center the data 76 | mean_val = (np.max(altitudes) + np.min(altitudes)) / 2. 77 | altitudes[:] = altitudes - mean_val 78 | 79 | # Define coordinates 80 | n, m = world_shape 81 | step1, step2 = step_size 82 | xx, yy = np.meshgrid(np.linspace(0, (n - 1) * step1, n), 83 | np.linspace(0, (m - 1) * step2, m), indexing="ij") 84 | coord = np.vstack((xx.flatten(), yy.flatten())).T 85 | 86 | # Interpolate data 87 | if interpolation: 88 | 89 | # Interpolating function 90 | spline_interpolator = interpolate.RectBivariateSpline( 91 | np.linspace(0, (n - 1) * step1, n), 92 | np.linspace(0, (m - 1) * step1, m), altitudes) 93 | 94 | # New size and resolution 95 | num_of_points = 1 96 | world_shape = tuple([(x - 1) * num_of_points + 1 for x in world_shape]) 97 | step_size = tuple([x / num_of_points for x in step_size]) 98 | 99 | # New coordinates and altitudes 100 | n, m = world_shape 101 | step1, step2 = step_size 102 | xx, yy = np.meshgrid(np.linspace(0, (n - 1) * step1, n), 103 | np.linspace(0, (m - 1) * step2, m), indexing="ij") 104 | coord = np.vstack((xx.flatten(), yy.flatten())).T 105 | 106 | altitudes = spline_interpolator(np.linspace(0, (n - 1) * step1, n), 107 | np.linspace(0, (m - 1) * step2, m)) 108 | else: 109 | num_of_points = 1 110 | 111 | # Plot area 112 | if plot_map: 113 | plt.imshow(altitudes.T, origin="lower", interpolation="nearest") 114 | plt.colorbar() 115 | plt.show() 116 | altitudes = altitudes.flatten() 117 | 118 | return altitudes, coord, world_shape, step_size, num_of_points 119 | 120 | 121 | def initialize_SafeMDP_object(altitudes, coord, world_shape, step_size, L=0.2, 122 | beta=2, length=14.5, sigma_n=0.075, start_x=60, 123 | start_y=61): 124 | """ 125 | 126 | Parameters 127 | ---------- 128 | altitudes: np.array 129 | 1-d vector with altitudes for each node 130 | coord: np.array 131 | Coordinate of the map we use for exploration 132 | world_shape: tuple 133 | Size of the grid world (rows, columns) 134 | step_size: tuple 135 | Step size for the grid (row, column) 136 | L: float 137 | Lipschitz constant to compute expanders 138 | beta: float 139 | Scaling factor for confidence intervals 140 | length: float 141 | Lengthscale for Matern kernel 142 | sigma_n: 143 | Standard deviation for gaussian noise 144 | start_x: int 145 | x coordinate of the starting point 146 | start_y: 147 | y coordinate of the starting point 148 | 149 | Returns 150 | ------- 151 | start: int 152 | Node number of initial state 153 | x: SafeMDP 154 | Instance of the SafeMDP class for the mars exploration problem 155 | true_S_hat: np.array 156 | True S_hat if safety 
feature is known with no error and h_hard is used 157 | true_S_hat_epsilon: np.array 158 | True S_hat if safety feature is known up to epsilon and h is used 159 | h_hard: float 160 | True safety thrshold. It can be different from the safety threshold 161 | used for classification in case the agent needs to use extra caution 162 | (in our experiments h=25 deg, h_har=30 deg) 163 | """ 164 | 165 | # Safety threshold 166 | h = -np.tan(np.pi / 9. + np.pi / 36.) * step_size[0] 167 | 168 | #Initial node 169 | start = start_x * world_shape[1] + start_y 170 | 171 | # Initial safe sets 172 | S_hat0 = compute_S_hat0(start, world_shape, 4, altitudes, 173 | step_size, h) 174 | S0 = np.copy(S_hat0) 175 | S0[:, 0] = True 176 | 177 | # Initialize GP 178 | X = coord[start, :].reshape(1, 2) 179 | Y = altitudes[start].reshape(1, 1) 180 | kernel = GPy.kern.Matern52(input_dim=2, lengthscale=length, variance=100.) 181 | lik = GPy.likelihoods.Gaussian(variance=sigma_n ** 2) 182 | gp = GPy.core.GP(X, Y, kernel, lik) 183 | 184 | # Define SafeMDP object 185 | x = GridWorld(gp, world_shape, step_size, beta, altitudes, h, S0, 186 | S_hat0, L, update_dist=25) 187 | 188 | # Add samples about actions from starting node 189 | for i in range(5): 190 | x.add_observation(start, 1) 191 | x.add_observation(start, 2) 192 | x.add_observation(start, 3) 193 | x.add_observation(start, 4) 194 | 195 | x.gp.set_XY(X=x.gp.X[1:, :], Y=x.gp.Y[1:, :]) # Necessary for results as in 196 | # paper 197 | 198 | # True safe set for false safe 199 | h_hard = -np.tan(np.pi / 6.) * step_size[0] 200 | true_S = compute_true_safe_set(x.world_shape, x.altitudes, h_hard) 201 | true_S_hat = compute_true_S_hat(x.graph, true_S, x.initial_nodes) 202 | 203 | # True safe set for completeness 204 | epsilon = sigma_n * beta 205 | true_S_epsilon = compute_true_safe_set(x.world_shape, x.altitudes, 206 | x.h + epsilon) 207 | true_S_hat_epsilon = compute_true_S_hat(x.graph, true_S_epsilon, 208 | x.initial_nodes) 209 | 210 | return start, x, true_S_hat, true_S_hat_epsilon, h_hard 211 | 212 | 213 | def performance_metrics(path, x, true_S_hat_epsilon, true_S_hat, h_hard): 214 | """ 215 | 216 | Parameters 217 | ---------- 218 | path: np.array 219 | Nodes of the shortest safe path 220 | x: SafeMDP 221 | Instance of the SafeMDP class for the mars exploration problem 222 | true_S_hat_epsilon: np.array 223 | True S_hat if safety feature is known up to epsilon and h is used 224 | true_S_hat: np.array 225 | True S_hat if safety feature is known with no error and h_hard is used 226 | h_hard: float 227 | True safety thrshold. 
It can be different from the safety threshold 228 | used for classification in case the agent needs to use extra caution 229 | (in our experiments h=25 deg, h_har=30 deg) 230 | 231 | Returns 232 | ------- 233 | unsafe_transitions: int 234 | Number of unsafe transitions along the path 235 | coverage: float 236 | Percentage of coverage of true_S_hat_epsilon 237 | false_safe: int 238 | Number of misclassifications (classifing something as safe when it 239 | acutally is unsafe according to h_hard ) 240 | """ 241 | 242 | # Count unsafe transitions along the path 243 | path_altitudes = x.altitudes[path] 244 | unsafe_transitions = np.sum(-np.diff(path_altitudes) < h_hard) 245 | 246 | # Coverage 247 | max_size = float(np.count_nonzero(true_S_hat_epsilon)) 248 | coverage = 100 * np.count_nonzero(np.logical_and(x.S_hat, 249 | true_S_hat_epsilon))/max_size 250 | # False safe 251 | false_safe = np.count_nonzero(np.logical_and(x.S_hat, ~true_S_hat)) 252 | 253 | return unsafe_transitions, coverage, false_safe 254 | -------------------------------------------------------------------------------- /examples/mars/plot_utilities.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | from matplotlib import rcParams 4 | from matplotlib.colors import ColorConverter 5 | 6 | 7 | def paper_figure(figsize, subplots=None, **kwargs): 8 | """Define default values for font, fontsize and use latex""" 9 | 10 | def cm2inch(cm_tupl): 11 | """Convert cm to inches""" 12 | inch = 2.54 13 | return (cm / inch for cm in cm_tupl) 14 | 15 | if subplots is None: 16 | fig = plt.figure(figsize=cm2inch(figsize)) 17 | else: 18 | fig, ax = plt.subplots(subplots[0], subplots[1], 19 | figsize=cm2inch(figsize), **kwargs) 20 | 21 | # Parameters for IJRR 22 | params = { 23 | 'font.family': 'serif', 24 | 'font.serif': ['Times', 25 | 'Palatino', 26 | 'New Century Schoolbook', 27 | 'Bookman', 28 | 'Computer Modern Roman'], 29 | 'font.sans-serif': ['Times', 30 | 'Helvetica', 31 | 'Avant Garde', 32 | 'Computer Modern Sans serif'], 33 | 'text.usetex': True, 34 | # Make sure mathcal doesn't use the Times style 35 | 'text.latex.preamble': 36 | r'\DeclareMathAlphabet{\mathcal}{OMS}{cmsy}{m}{n}', 37 | 38 | 'axes.labelsize': 9, 39 | 'axes.linewidth': .75, 40 | 41 | 'font.size': 9, 42 | 'legend.fontsize': 9, 43 | 'xtick.labelsize': 8, 44 | 'ytick.labelsize': 8, 45 | 46 | # 'figure.dpi': 150, 47 | # 'savefig.dpi': 600, 48 | 'legend.numpoints': 1, 49 | } 50 | 51 | rcParams.update(params) 52 | 53 | if subplots is None: 54 | return fig 55 | else: 56 | return fig, ax 57 | 58 | 59 | def format_figure(axis, cbar=None): 60 | axis.spines['top'].set_linewidth(0.1) 61 | axis.spines['top'].set_alpha(0.5) 62 | axis.spines['right'].set_linewidth(0.1) 63 | axis.spines['right'].set_alpha(0.5) 64 | axis.xaxis.set_ticks_position('bottom') 65 | axis.yaxis.set_ticks_position('left') 66 | 67 | axis.set_xticks(np.arange(0, 121, 30)) 68 | yticks = np.arange(0, 71, 35) 69 | axis.set_yticks(yticks) 70 | axis.set_yticklabels(['{0}'.format(tick) for tick in yticks[::-1]]) 71 | 72 | axis.set_xlabel(r'distance [m]') 73 | axis.set_ylabel(r'distance [m]', labelpad=2) 74 | if cbar is not None: 75 | cbar.set_label(r'altitude [m]') 76 | 77 | cbar.set_ticks(np.arange(0, 36, 10)) 78 | 79 | for spine in cbar.ax.spines.itervalues(): 80 | spine.set_linewidth(0.1) 81 | cbar.ax.yaxis.set_tick_params(color=emulate_color('k', 0.7)) 82 | 83 | plt.tight_layout(pad=0.1) 84 | 85 | 86 | def 
emulate_color(color, alpha=1, background_color=(1, 1, 1)): 87 | """Take an RGBA color and an RGB background, return the emulated RGB color. 88 | 89 | The RGBA color with transparency alpha is converted to an RGB color via 90 | emulation in front of the background_color. 91 | """ 92 | to_rgb = ColorConverter().to_rgb 93 | color = to_rgb(color) 94 | background_color = to_rgb(background_color) 95 | return [(1 - alpha) * bg_col + alpha * col 96 | for col, bg_col in zip(color, background_color)] 97 | 98 | 99 | def plot_paper(altitudes, S_hat, world_shape, fileName=""): 100 | """ 101 | Plots for NIPS paper 102 | Parameters 103 | ---------- 104 | altitudes: np.array 105 | True value of the altitudes of the map 106 | S_hat: np.array 107 | Safe and ergodic set 108 | world_shape: tuple 109 | Size of the grid world (rows, columns) 110 | fileName: string 111 | Name of the file to save the plot. If empty string the plot is not 112 | saved 113 | Returns 114 | ------- 115 | 116 | """ 117 | # Size of figures and colormap 118 | tw = cw = 13.968 119 | cmap = 'jet' 120 | alpha = 1. 121 | alpha_world = 0.25 122 | size_wb = np.array([cw / 2.2, tw / 4.]) 123 | #size_wb = np.array([cw / 4.2, cw / 4.2]) 124 | 125 | # Shift altitudes 126 | altitudes -= np.nanmin(altitudes) 127 | vmin, vmax = (np.nanmin(altitudes), np.nanmax(altitudes)) 128 | origin = 'lower' 129 | 130 | fig = paper_figure(size_wb) 131 | 132 | # Copy altitudes for different alpha values 133 | altitudes2 = altitudes.copy() 134 | altitudes2[~S_hat[:, 0]] = np.nan 135 | 136 | axis = fig.gca() 137 | 138 | # Plot world 139 | c = axis.imshow(np.reshape(altitudes, world_shape).T, origin=origin, vmin=vmin, 140 | vmax=vmax, cmap=cmap, alpha=alpha_world) 141 | 142 | cbar = plt.colorbar(c) 143 | #cbar = None 144 | 145 | # Plot explored area 146 | plt.imshow(np.reshape(altitudes2, world_shape).T, origin=origin, vmin=vmin, 147 | vmax=vmax, interpolation='nearest', cmap=cmap, alpha=alpha) 148 | format_figure(axis, cbar) 149 | 150 | # Save figure 151 | if fileName: 152 | plt.savefig(fileName, transparent=False, format="pdf") 153 | plt.show() 154 | 155 | 156 | 157 | def plot_dist_from_C(mu, var, beta, altitudes, world_shape): 158 | """ 159 | Image plot of the distance of the true safety feature from the 160 | confidence interval. 
Distance is equal to 0 if the true r(s) lies within 161 | C(s), it is > 0 if r(s)>u(s) and < 0 if r(s) 0] = diff_u[diff_u > 0] 189 | 190 | # Below l 191 | diff_l = altitudes - l 192 | dist_from_confidence_interval[diff_l < 0] = diff_l[diff_l < 0] 193 | 194 | # Define limits 195 | max_value = np.max(dist_from_confidence_interval) 196 | min_value = np.min(dist_from_confidence_interval) 197 | limit = np.max([max_value, np.abs(min_value)]) 198 | 199 | # Plot 200 | plt.figure() 201 | plt.imshow( 202 | np.reshape(dist_from_confidence_interval, world_shape).T, origin='lower', 203 | interpolation='nearest', vmin=-limit, vmax=limit) 204 | title = "Distance from confidence interval" 205 | plt.title(title) 206 | plt.colorbar() 207 | plt.show() 208 | 209 | 210 | def plot_coverage(coverage_over_t): 211 | """ 212 | Plots coverage of true_S_hat_epsilon as a function of time 213 | 214 | """ 215 | plt.figure() 216 | plt.plot(coverage_over_t) 217 | title = "Coverage over time" 218 | plt.title(title) 219 | plt.show() 220 | 221 | -------------------------------------------------------------------------------- /examples/mars/requirements.txt: -------------------------------------------------------------------------------- 1 | cython == 0.23.4 2 | matplotlib == 1.5.1 3 | mkl == 11.3.1 4 | mkl-service == 1.1.2 5 | networkx == 1.11 6 | numpy == 1.10.0 7 | scipy == 0.17.0 8 | osgeo >= 2.0.0 9 | -------------------------------------------------------------------------------- /examples/sample.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | 3 | import time 4 | 5 | import GPy 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | 9 | from safemdp.grid_world import (compute_true_safe_set, compute_S_hat0, 10 | compute_true_S_hat, draw_gp_sample, GridWorld) 11 | 12 | # Define world 13 | world_shape = (20, 20) 14 | step_size = (0.5, 0.5) 15 | 16 | # Define GP 17 | noise = 0.001 18 | kernel = GPy.kern.RBF(input_dim=2, lengthscale=(2., 2.), variance=1., 19 | ARD=True) 20 | lik = GPy.likelihoods.Gaussian(variance=noise ** 2) 21 | lik.constrain_bounded(1e-6, 10000.) 
22 | 23 | # Sample and plot world 24 | altitudes, coord = draw_gp_sample(kernel, world_shape, step_size) 25 | fig = plt.figure() 26 | ax = fig.add_subplot(111, projection='3d') 27 | ax.plot_trisurf(coord[:, 0], coord[:, 1], altitudes) 28 | plt.show() 29 | 30 | # Define coordinates 31 | n, m = world_shape 32 | step1, step2 = step_size 33 | xx, yy = np.meshgrid(np.linspace(0, (n - 1) * step1, n), 34 | np.linspace(0, (m - 1) * step2, m), 35 | indexing="ij") 36 | coord = np.vstack((xx.flatten(), yy.flatten())).T 37 | 38 | # Safety threhsold 39 | h = -0.25 40 | 41 | # Lipschitz 42 | L = 0 43 | 44 | # Scaling factor for confidence interval 45 | beta = 2 46 | 47 | # Data to initialize GP 48 | n_samples = 1 49 | ind = np.random.choice(range(altitudes.size), n_samples) 50 | X = coord[ind, :] 51 | Y = altitudes[ind].reshape(n_samples, 1) + np.random.randn(n_samples, 52 | 1) 53 | gp = GPy.core.GP(X, Y, kernel, lik) 54 | 55 | # Initialize safe sets 56 | S0 = np.zeros((np.prod(world_shape), 5), dtype=bool) 57 | S0[:, 0] = True 58 | S_hat0 = compute_S_hat0(np.nan, world_shape, 4, altitudes, 59 | step_size, h) 60 | 61 | # Define SafeMDP object 62 | x = GridWorld(gp, world_shape, step_size, beta, altitudes, h, S0, S_hat0, 63 | L) 64 | 65 | # Insert samples from (s, a) in S_hat0 66 | tmp = np.arange(x.coord.shape[0]) 67 | s_vec_ind = np.random.choice(tmp[np.any(x.S_hat[:, 1:], axis=1)]) 68 | tmp = np.arange(1, x.S.shape[1]) 69 | actions = tmp[x.S_hat[s_vec_ind, 1:].squeeze()] 70 | for i in range(3): 71 | x.add_observation(s_vec_ind, np.random.choice(actions)) 72 | 73 | # Remove samples used for GP initialization 74 | x.gp.set_XY(x.gp.X[n_samples:, :], x.gp.Y[n_samples:]) 75 | 76 | t = time.time() 77 | for i in range(100): 78 | x.update_sets() 79 | next_sample = x.target_sample() 80 | x.add_observation(*next_sample) 81 | # x.compute_graph_lazy() 82 | # plt.figure(1) 83 | # plt.clf() 84 | # nx.draw_networkx(x.graph) 85 | # plt.show() 86 | print("Iteration: " + str(i)) 87 | 88 | print(str(time.time() - t) + "seconds elapsed") 89 | 90 | true_S = compute_true_safe_set(x.world_shape, x.altitudes, x.h) 91 | true_S_hat = compute_true_S_hat(x.graph, true_S, x.initial_nodes) 92 | 93 | # Plot safe sets 94 | x.plot_S(x.S_hat) 95 | x.plot_S(true_S_hat) 96 | 97 | # Classification performance 98 | print(np.sum(np.logical_and(true_S_hat, np.logical_not( 99 | x.S_hat)))) # in true S_hat and not S_hat 100 | print(np.sum(np.logical_and(x.S_hat, np.logical_not(true_S_hat)))) 101 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | GPy >= 0.8.0 2 | numpy >= 1.7.2 3 | scipy >= 0.16 4 | matplotlib >= 1.5.0 5 | networkx >= 1.1 6 | 7 | -------------------------------------------------------------------------------- /safemdp/SafeMDP_class.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | 3 | import numpy as np 4 | 5 | from .utilities import max_out_degree 6 | 7 | __all__ = ['SafeMDP', 'link_graph_and_safe_set', 'reachable_set', 8 | 'returnable_set'] 9 | 10 | 11 | class SafeMDP(object): 12 | """Base class for safe exploration in MDPs. 13 | 14 | This class only provides basic options to compute the safely reachable 15 | and returnable sets. The actual update of the safety feature must be done 16 | in a class that inherits from `SafeMDP`. See `safempd.GridWorld` for an 17 | example. 
18 | 19 | Parameters 20 | ---------- 21 | graph: networkx.DiGraph 22 | The graph that models the MDP. Each edge has an attribute `safe` in its 23 | metadata, which determines the safety of the transition. 24 | gp: GPy.core.GPRegression 25 | A Gaussian process model that can be used to determine the safety of 26 | transitions. Exact structure depends heavily on the usecase. 27 | S_hat0: boolean array 28 | An array that has True on the ith position if the ith node in the graph 29 | is part of the safe set. 30 | h: float 31 | The safety threshold. 32 | L: float 33 | The lipschitz constant 34 | beta: float, optional 35 | The confidence interval used by the GP model. 36 | """ 37 | def __init__(self, graph, gp, S_hat0, h, L, beta=2): 38 | super(SafeMDP, self).__init__() 39 | # Scalar for gp confidence intervals 40 | self.beta = beta 41 | 42 | # Threshold 43 | self.h = h 44 | 45 | # Lipschitz constant 46 | self.L = L 47 | 48 | # GP model 49 | self.gp = gp 50 | 51 | self.graph = graph 52 | self.graph_reverse = self.graph.reverse() 53 | 54 | num_nodes = self.graph.number_of_nodes() 55 | num_edges = max_out_degree(graph) 56 | safe_set_size = (num_nodes, num_edges + 1) 57 | 58 | self.reach = np.empty(safe_set_size, dtype=np.bool) 59 | self.G = np.empty(safe_set_size, dtype=np.bool) 60 | 61 | self.S_hat = S_hat0.copy() 62 | self.S_hat0 = self.S_hat.copy() 63 | self.initial_nodes = self.S_hat0[:, 0].nonzero()[0].tolist() 64 | 65 | def compute_S_hat(self): 66 | """Compute the safely reachable set given the current safe_set.""" 67 | self.reach[:] = False 68 | reachable_set(self.graph, self.initial_nodes, out=self.reach) 69 | 70 | self.S_hat[:] = False 71 | returnable_set(self.graph, self.graph_reverse, self.initial_nodes, 72 | out=self.S_hat) 73 | 74 | self.S_hat &= self.reach 75 | 76 | def add_gp_observations(self, x_new, y_new): 77 | """Add observations to the gp mode.""" 78 | # Update GP with observations 79 | self.gp.set_XY(np.vstack((self.gp.X, 80 | x_new)), 81 | np.vstack((self.gp.Y, 82 | y_new))) 83 | 84 | 85 | def link_graph_and_safe_set(graph, safe_set): 86 | """Link the safe set to the graph model. 87 | 88 | Parameters 89 | ---------- 90 | graph: nx.DiGraph() 91 | safe_set: np.array 92 | Safe set. For each node the edge (i, j) under action (a) is linked to 93 | safe_set[i, a] 94 | """ 95 | for node, next_node in graph.edges_iter(): 96 | edge = graph[node][next_node] 97 | edge['safe'] = safe_set[node:node + 1, edge['action']] 98 | 99 | 100 | def reachable_set(graph, initial_nodes, out=None): 101 | """ 102 | Compute the safe, reachable set of a graph 103 | 104 | Parameters 105 | ---------- 106 | graph: nx.DiGraph 107 | Directed graph. Each edge must have associated action metadata, 108 | which specifies the action that this edge corresponds to. 109 | Each edge has an attribute ['safe'], which is a boolean that 110 | indicates safety 111 | initial_nodes: list 112 | List of the initial, safe nodes that are used as a starting point to 113 | compute the reachable set. 114 | out: np.array 115 | The array to write the results to. Is assumed to be False everywhere 116 | except at the initial nodes 117 | 118 | Returns 119 | ------- 120 | reachable_set: np.array 121 | Boolean array that indicates whether a node belongs to the reachable 122 | set. 
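    Examples
    --------
    A minimal sketch on a two-node graph where every transition is marked
    safe (the ``safe`` attribute is set directly here instead of being
    linked to a safe-set array with `link_graph_and_safe_set`):

    >>> import networkx as nx
    >>> graph = nx.DiGraph()
    >>> graph.add_edge(0, 1, action=1, safe=True)
    >>> graph.add_edge(1, 0, action=1, safe=True)
    >>> reach = reachable_set(graph, [0])
    >>> bool(reach[0, 0]) and bool(reach[1, 0])
    True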
123 | """ 124 | 125 | if not initial_nodes: 126 | raise AttributeError('Set of initial nodes needs to be non-empty.') 127 | 128 | if out is None: 129 | visited = np.zeros((graph.number_of_nodes(), 130 | max_out_degree(graph) + 1), 131 | dtype=np.bool) 132 | else: 133 | visited = out 134 | 135 | # All nodes in the initial set are visited 136 | visited[initial_nodes, 0] = True 137 | 138 | stack = list(initial_nodes) 139 | 140 | # TODO: rather than checking if things are safe, specify a safe subgraph? 141 | while stack: 142 | node = stack.pop(0) 143 | # iterate over edges going away from node 144 | for _, next_node, data in graph.edges_iter(node, data=True): 145 | action = data['action'] 146 | if not visited[node, action] and data['safe']: 147 | visited[node, action] = True 148 | if not visited[next_node, 0]: 149 | stack.append(next_node) 150 | visited[next_node, 0] = True 151 | 152 | if out is None: 153 | return visited 154 | 155 | 156 | def returnable_set(graph, reverse_graph, initial_nodes, out=None): 157 | """ 158 | Compute the safe, returnable set of a graph 159 | 160 | Parameters 161 | ---------- 162 | graph: nx.DiGraph 163 | Directed graph. Each edge must have associated action metadata, 164 | which specifies the action that this edge corresponds to. 165 | Each edge has an attribute ['safe'], which is a boolean that 166 | indicates safety 167 | reverse_graph: nx.DiGraph 168 | The reversed directed graph, `graph.reverse()` 169 | initial_nodes: list 170 | List of the initial, safe nodes that are used as a starting point to 171 | compute the returnable set. 172 | out: np.array 173 | The array to write the results to. Is assumed to be False everywhere 174 | except at the initial nodes 175 | 176 | Returns 177 | ------- 178 | returnable_set: np.array 179 | Boolean array that indicates whether a node belongs to the returnable 180 | set. 181 | """ 182 | 183 | if not initial_nodes: 184 | raise AttributeError('Set of initial nodes needs to be non-empty.') 185 | 186 | if out is None: 187 | visited = np.zeros((graph.number_of_nodes(), 188 | max_out_degree(graph) + 1), 189 | dtype=np.bool) 190 | else: 191 | visited = out 192 | 193 | # All nodes in the initial set are visited 194 | visited[initial_nodes, 0] = True 195 | 196 | stack = list(initial_nodes) 197 | 198 | # TODO: rather than checking if things are safe, specify a safe subgraph? 199 | while stack: 200 | node = stack.pop(0) 201 | # iterate over edges going into node 202 | for _, prev_node in reverse_graph.edges_iter(node): 203 | data = graph.get_edge_data(prev_node, node) 204 | if not visited[prev_node, data['action']] and data['safe']: 205 | visited[prev_node, data['action']] = True 206 | if not visited[prev_node, 0]: 207 | stack.append(prev_node) 208 | visited[prev_node, 0] = True 209 | 210 | if out is None: 211 | return visited 212 | -------------------------------------------------------------------------------- /safemdp/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | The `safemdp` package implements tools for safe exploration in finite MDPs. 3 | 4 | Main classes 5 | ------------ 6 | 7 | These classes provide the main functionality for the safe exploration 8 | 9 | .. autosummary:: 10 | :template: template.rst 11 | :toctree: 12 | 13 | SafeMDP 14 | link_graph_and_safe_set 15 | reachable_set 16 | returnable_set 17 | 18 | Grid world 19 | ---------- 20 | 21 | Some additional functionality specific to gridworlds. 22 | 23 | .. 
autosummary:: 24 | :template: template.rst 25 | :toctree: 26 | 27 | GridWorld 28 | states_to_nodes 29 | nodes_to_states 30 | draw_gp_sample 31 | grid_world_graph 32 | grid 33 | compute_true_safe_set 34 | compute_true_S_hat 35 | compute_S_hat0 36 | shortest_path 37 | path_to_boolean_matrix 38 | safe_subpath 39 | 40 | 41 | Utilities 42 | --------- 43 | 44 | The following are utilities to make testing and working with the library more 45 | pleasant. 46 | 47 | .. autosummary:: 48 | :template: template.rst 49 | :toctree: 50 | 51 | DifferenceKernel 52 | max_out_degree 53 | """ 54 | 55 | from __future__ import absolute_import 56 | 57 | from .SafeMDP_class import * 58 | from .utilities import * 59 | from .grid_world import * 60 | 61 | # Add everything to __all__ 62 | __all__ = [s for s in dir() if not s.startswith('_')] 63 | 64 | # Import test after __all__ (no documentation) 65 | from numpy.testing import Tester 66 | test = Tester().test 67 | -------------------------------------------------------------------------------- /safemdp/grid_world.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | 3 | import networkx as nx 4 | import numpy as np 5 | from matplotlib import pyplot as plt 6 | from scipy.spatial.distance import cdist 7 | 8 | from .utilities import DifferenceKernel 9 | from .SafeMDP_class import (reachable_set, returnable_set, SafeMDP, 10 | link_graph_and_safe_set) 11 | 12 | 13 | __all__ = ['compute_true_safe_set', 'compute_true_S_hat', 'compute_S_hat0', 14 | 'grid_world_graph', 'grid', 'GridWorld', 'draw_gp_sample', 15 | 'states_to_nodes', 'nodes_to_states', 'shortest_path', 16 | 'path_to_boolean_matrix', 'safe_subpath'] 17 | 18 | 19 | def compute_true_safe_set(world_shape, altitude, h): 20 | """ 21 | Computes the safe set given a perfect knowledge of the map 22 | 23 | Parameters 24 | ---------- 25 | world_shape: tuple 26 | altitude: np.array 27 | 1-d vector with altitudes for each node 28 | h: float 29 | Safety threshold for height differences 30 | 31 | Returns 32 | ------- 33 | true_safe: np.array 34 | Boolean array n_states x (n_actions + 1). 35 | """ 36 | 37 | true_safe = np.zeros((world_shape[0] * world_shape[1], 5), dtype=np.bool) 38 | 39 | altitude_grid = altitude.reshape(world_shape) 40 | 41 | # Reshape so that first dimensions are actions, the rest is the grid world. 
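    # Note: because only a single axis is split, `true_safe.T.reshape(...)`
    # returns a strided view of `true_safe`, so the assignments to
    # `safe_grid` below write through to `true_safe`.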
42 | safe_grid = true_safe.T.reshape((5,) + world_shape) 43 | 44 | # Height difference (next height - current height) --> positive if downhill 45 | up_diff = altitude_grid[:, :-1] - altitude_grid[:, 1:] 46 | right_diff = altitude_grid[:-1, :] - altitude_grid[1:, :] 47 | 48 | # State are always safe 49 | true_safe[:, 0] = True 50 | 51 | # Going in the opposite direction 52 | safe_grid[1, :, :-1] = up_diff >= h 53 | safe_grid[2, :-1, :] = right_diff >= h 54 | safe_grid[3, :, 1:] = -up_diff >= h 55 | safe_grid[4, 1:, :] = -right_diff >= h 56 | 57 | return true_safe 58 | 59 | 60 | def dynamics_vec_ind(states_vec_ind, action, world_shape): 61 | """ 62 | Dynamic evolution of the system defined in vector representation of 63 | the states 64 | 65 | Parameters 66 | ---------- 67 | states_vec_ind: np.array 68 | Contains all the vector indexes of the states we want to compute 69 | the dynamic evolution for 70 | action: int 71 | action performed by the agent 72 | 73 | Returns 74 | ------- 75 | next_states_vec_ind: np.array 76 | vector index of states resulting from applying the action given 77 | as input to the array of starting points given as input 78 | """ 79 | n, m = world_shape 80 | next_states_vec_ind = np.copy(states_vec_ind) 81 | if action == 1: 82 | next_states_vec_ind[:] = states_vec_ind + 1 83 | condition = np.mod(next_states_vec_ind, m) == 0 84 | next_states_vec_ind[condition] = states_vec_ind[condition] 85 | elif action == 2: 86 | next_states_vec_ind[:] = states_vec_ind + m 87 | condition = next_states_vec_ind >= m * n 88 | next_states_vec_ind[condition] = states_vec_ind[condition] 89 | elif action == 3: 90 | next_states_vec_ind[:] = states_vec_ind - 1 91 | condition = np.mod(states_vec_ind, m) == 0 92 | next_states_vec_ind[condition] = states_vec_ind[condition] 93 | elif action == 4: 94 | next_states_vec_ind[:] = states_vec_ind - m 95 | condition = next_states_vec_ind <= -1 96 | next_states_vec_ind[condition] = states_vec_ind[condition] 97 | else: 98 | raise ValueError("Unknown action") 99 | return next_states_vec_ind 100 | 101 | 102 | def compute_S_hat0(s, world_shape, n_actions, altitudes, step_size, h): 103 | """ 104 | Compute a valid initial safe seed. 105 | 106 | Parameters 107 | --------- 108 | s: int or nan 109 | Vector index of the state where we start computing the safe seed 110 | from. If it is equal to nan, a state is chosen at random 111 | world_shape: tuple 112 | Size of the grid world (rows, columns) 113 | n_actions: int 114 | Number of actions available to the agent 115 | altitudes: np.array 116 | It contains the flattened n x m matrix where the altitudes of all 117 | the points in the map are stored 118 | step_size: tuple 119 | step sizes along each direction to create a linearly spaced grid 120 | h: float 121 | Safety threshold 122 | 123 | Returns 124 | ------ 125 | S_hat: np.array 126 | Boolean array n_states x (n_actions + 1). 
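    Examples
    --------
    A small sketch on a flat 2 x 2 world (illustrative values only), where
    every move trivially satisfies the safety threshold:

    >>> import numpy as np
    >>> altitudes = np.zeros(4)
    >>> S_hat = compute_S_hat0(0, (2, 2), 4, altitudes, (1., 1.), -0.25)
    >>> bool(S_hat[0, 1]) and bool(S_hat[0, 2])
    True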
127 | """ 128 | # Initialize 129 | n, m = world_shape 130 | n_states = n * m 131 | S_hat = np.zeros((n_states, n_actions + 1), dtype=bool) 132 | 133 | # In case an initial state is given 134 | if not np.isnan(s): 135 | S_hat[s, 0] = True 136 | valid_initial_seed = False 137 | vertical = False 138 | horizontal = False 139 | altitude_prev = altitudes[s] 140 | if not isinstance(s, np.ndarray): 141 | s = np.array([s]) 142 | 143 | # Loop through actions 144 | for action in range(1, n_actions + 1): 145 | 146 | # Compute next state to check steepness 147 | next_vec_ind = dynamics_vec_ind(s, action, world_shape) 148 | altitude_next = altitudes[next_vec_ind] 149 | 150 | if s != next_vec_ind and -np.abs(altitude_prev - altitude_next) / \ 151 | step_size[0] >= h: 152 | S_hat[s, action] = True 153 | S_hat[next_vec_ind, 0] = True 154 | S_hat[next_vec_ind, reverse_action(action)] = True 155 | if action == 1 or action == 3: 156 | vertical = True 157 | if action == 2 or action == 4: 158 | horizontal = True 159 | 160 | if vertical and horizontal: 161 | valid_initial_seed = True 162 | 163 | if valid_initial_seed: 164 | return S_hat 165 | else: 166 | print ("No valid initial seed starting from this state") 167 | S_hat[:] = False 168 | return S_hat 169 | 170 | # If an explicit initial state is not given 171 | else: 172 | while np.all(np.logical_not(S_hat)): 173 | initial_state = np.random.choice(n_states) 174 | S_hat = compute_S_hat0(initial_state, world_shape, n_actions, 175 | altitudes, step_size, h) 176 | return S_hat 177 | 178 | 179 | def reverse_action(action): 180 | # Computes the action that is the opposite of the one given as input 181 | 182 | rev_a = np.mod(action + 2, 4) 183 | if rev_a == 0: 184 | rev_a = 4 185 | return rev_a 186 | 187 | 188 | def grid_world_graph(world_size): 189 | """Create a graph that represents a grid world. 190 | 191 | In the grid world there are four actions, (1, 2, 3, 4), which correspond 192 | to going (up, right, down, left) in the x-y plane. The states are 193 | ordered so that `np.arange(np.prod(world_size)).reshape(world_size)` 194 | corresponds to a matrix where increasing the row index corresponds to the 195 | x direction in the graph, and increasing y index corresponds to the y 196 | direction. 197 | 198 | Parameters 199 | ---------- 200 | world_size: tuple 201 | The size of the grid world (rows, columns) 202 | 203 | Returns 204 | ------- 205 | graph: nx.DiGraph() 206 | The directed graph representing the grid world. 207 | """ 208 | nodes = np.arange(np.prod(world_size)) 209 | grid_nodes = nodes.reshape(world_size) 210 | 211 | graph = nx.DiGraph() 212 | 213 | # action 1: go right 214 | graph.add_edges_from(zip(grid_nodes[:, :-1].reshape(-1), 215 | grid_nodes[:, 1:].reshape(-1)), 216 | action=1) 217 | 218 | # action 2: go down 219 | graph.add_edges_from(zip(grid_nodes[:-1, :].reshape(-1), 220 | grid_nodes[1:, :].reshape(-1)), 221 | action=2) 222 | 223 | # action 3: go left 224 | graph.add_edges_from(zip(grid_nodes[:, 1:].reshape(-1), 225 | grid_nodes[:, :-1].reshape(-1)), 226 | action=3) 227 | 228 | # action 4: go up 229 | graph.add_edges_from(zip(grid_nodes[1:, :].reshape(-1), 230 | grid_nodes[:-1, :].reshape(-1)), 231 | action=4) 232 | 233 | return graph 234 | 235 | 236 | def compute_true_S_hat(graph, safe_set, initial_nodes, reverse_graph=None): 237 | """ 238 | Compute the true safe set with reachability and returnability. 
239 | 240 | Parameters 241 | ---------- 242 | graph: nx.DiGraph 243 | safe_set: np.array 244 | initial_nodes: list of int 245 | reverse_graph: nx.DiGraph 246 | graph.reverse() 247 | 248 | Returns 249 | ------- 250 | true_safe: np.array 251 | Boolean array n_states x (n_actions + 1). 252 | """ 253 | graph = graph.copy() 254 | link_graph_and_safe_set(graph, safe_set) 255 | if reverse_graph is None: 256 | reverse_graph = graph.reverse() 257 | reach = reachable_set(graph, initial_nodes) 258 | ret = returnable_set(graph, reverse_graph, initial_nodes) 259 | ret &= reach 260 | return ret 261 | 262 | 263 | class GridWorld(SafeMDP): 264 | """ 265 | Grid world with Safe exploration 266 | 267 | Parameters 268 | ---------- 269 | gp: GPy.core.GP 270 | Gaussian process that expresses our current belief over the safety 271 | feature 272 | world_shape: shape 273 | Tuple that contains the shape of the grid world n x m 274 | step_size: tuple of floats 275 | Tuple that contains the step sizes along each direction to 276 | create a linearly spaced grid 277 | beta: float 278 | Scaling factor to determine the amplitude of the confidence 279 | intervals 280 | altitudes: np.array 281 | It contains the flattened n x m matrix where the altitudes 282 | of all the points in the map are stored 283 | h: float 284 | Safety threshold 285 | S0: np.array 286 | n_states x (n_actions + 1) array of booleans that indicates which 287 | states (first column) and which state-action pairs belong to the 288 | initial safe seed. Notice that, by convention we initialize all 289 | the states to be safe 290 | S_hat0: np.array or nan 291 | n_states x (n_actions + 1) array of booleans that indicates which 292 | states (first column) and which state-action pairs belong to the 293 | initial safe seed and satisfy recovery and reachability properties. 294 | If it is nan, such a boolean matrix is computed during 295 | initialization 296 | noise: float 297 | Standard deviation of the measurement noise 298 | L: float 299 | Lipschitz constant to compute expanders 300 | update_dist: int 301 | Distance in unweighted graph used for confidence interval update. 302 | A sample will only influence other nodes within this distance. 
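    Examples
    --------
    A minimal construction on a tiny, flat 3 x 3 map (illustrative values
    only)::

        import numpy as np
        import GPy
        from safemdp import GridWorld, compute_S_hat0

        world_shape, step_size = (3, 3), (1., 1.)
        altitudes = np.zeros(9)
        h, beta, L = -0.25, 2., 0.

        kernel = GPy.kern.RBF(input_dim=2)
        likelihood = GPy.likelihoods.Gaussian(variance=1e-6)
        gp = GPy.core.GP(np.zeros((1, 2)), np.zeros((1, 1)), kernel,
                         likelihood)

        # Only the states are marked safe initially, by convention.
        S0 = np.zeros((9, 5), dtype=bool)
        S0[:, 0] = True
        S_hat0 = compute_S_hat0(0, world_shape, 4, altitudes, step_size, h)

        grid_mdp = GridWorld(gp, world_shape, step_size, beta, altitudes, h,
                             S0, S_hat0, L)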
303 | """ 304 | def __init__(self, gp, world_shape, step_size, beta, altitudes, h, S0, 305 | S_hat0, L, update_dist=0): 306 | 307 | # Safe set 308 | self.S = S0.copy() 309 | graph = grid_world_graph(world_shape) 310 | link_graph_and_safe_set(graph, self.S) 311 | super(GridWorld, self).__init__(graph, gp, S_hat0, h, L, beta=2) 312 | 313 | self.altitudes = altitudes 314 | self.world_shape = world_shape 315 | self.step_size = step_size 316 | self.update_dist = update_dist 317 | 318 | # Grids for the map 319 | self.coord = grid(self.world_shape, self.step_size) 320 | 321 | # Distances 322 | self.distance_matrix = cdist(self.coord, self.coord) 323 | 324 | # Confidence intervals 325 | self.l = np.empty(self.S.shape, dtype=float) 326 | self.u = np.empty(self.S.shape, dtype=float) 327 | self.l[:] = -np.inf 328 | self.u[:] = np.inf 329 | self.l[self.S] = h 330 | 331 | # Prediction with difference of altitudes 332 | states_ind = np.arange(np.prod(self.world_shape)) 333 | states_grid = states_ind.reshape(world_shape) 334 | 335 | self._prev_up = states_grid[:, :-1].flatten() 336 | self._next_up = states_grid[:, 1:].flatten() 337 | self._prev_right = states_grid[:-1, :].flatten() 338 | self._next_right = states_grid[1:, :].flatten() 339 | 340 | self._mat_up = np.hstack((self.coord[self._prev_up, :], 341 | self.coord[self._next_up, :])) 342 | self._mat_right = np.hstack((self.coord[self._prev_right, :], 343 | self.coord[self._next_right, :])) 344 | 345 | def update_confidence_interval(self, jacobian=False): 346 | """ 347 | Updates the lower and the upper bound of the confidence intervals 348 | using then posterior distribution over the gradients of the altitudes 349 | 350 | Returns 351 | ------- 352 | l: np.array 353 | lower bound of the safety feature (mean - beta*std) 354 | u: np.array 355 | upper bound of the safety feature (mean - beta*std) 356 | """ 357 | if jacobian: 358 | # Predict safety feature 359 | mu, s = self.gp.predict_jacobian(self.coord, full_cov=False) 360 | mu = np.squeeze(mu) 361 | 362 | # Confidence interval 363 | s = self.beta * np.sqrt(s) 364 | 365 | # State are always safe 366 | self.l[:, 0] = self.u[:, 0] = self.h 367 | 368 | # Update safety feature 369 | self.l[:, [1, 2]] = -mu[:, ::-1] - s[:, ::-1] 370 | self.l[:, [3, 4]] = mu[:, ::-1] - s[:, ::-1] 371 | 372 | self.u[:, [1, 2]] = -mu[:, ::-1] + s[:, ::-1] 373 | self.u[:, [3, 4]] = mu[:, ::-1] + s[:, ::-1] 374 | 375 | elif self.update_dist > 0: 376 | # States are always safe 377 | self.l[:, 0] = self.u[:, 0] = self.h 378 | 379 | # Extract last two sampled states in the grid 380 | last_states = self.gp.X[-2:] 381 | last_nodes = states_to_nodes(last_states, self.world_shape, 382 | self.step_size) 383 | 384 | # Extract nodes to be updated 385 | nodes1 = nx.single_source_shortest_path(self.graph, last_nodes[0], 386 | self.update_dist).keys() 387 | nodes2 = nx.single_source_shortest_path(self.graph, last_nodes[1], 388 | self.update_dist).keys() 389 | update_nodes = np.union1d(nodes1, nodes2) 390 | subgraph = self.graph.subgraph(update_nodes) 391 | 392 | # Sort states to be updated according to actions 393 | prev_up = [] 394 | next_up = [] 395 | prev_right = [] 396 | next_right = [] 397 | 398 | for node1, node2, act in subgraph.edges_iter(data='action'): 399 | if act == 2: 400 | prev_right.append(node1) 401 | next_right.append(node2) 402 | elif act == 1: 403 | prev_up.append(node1) 404 | next_up.append(node2) 405 | 406 | mat_up = np.hstack((self.coord[prev_up, :], 407 | self.coord[next_up, :])) 408 | mat_right = 
np.hstack((self.coord[prev_right, :], 409 | self.coord[next_right, :])) 410 | 411 | # Update confidence for nodes around last sample 412 | mu_up, s_up = self.gp.predict(mat_up, 413 | kern=DifferenceKernel(self.gp.kern), 414 | full_cov=False) 415 | s_up = self.beta * np.sqrt(s_up) 416 | 417 | self.l[prev_up, 1, None] = mu_up - s_up 418 | self.u[prev_up, 1, None] = mu_up + s_up 419 | 420 | self.l[next_up, 3, None] = -mu_up - s_up 421 | self.u[next_up, 3, None] = -mu_up + s_up 422 | 423 | mu_right, s_right = self.gp.predict(mat_right, 424 | kern=DifferenceKernel( 425 | self.gp.kern), 426 | full_cov=False) 427 | s_right = self.beta * np.sqrt(s_right) 428 | 429 | self.l[prev_right, 2, None] = mu_right - s_right 430 | self.u[prev_right, 2, None] = mu_right + s_right 431 | 432 | self.l[next_right, 4, None] = -mu_right - s_right 433 | self.u[next_right, 4, None] = -mu_right + s_right 434 | 435 | else: 436 | # Initialize to unsafe 437 | self.l[:] = self.u[:] = self.h - 1 438 | 439 | # States are always safe 440 | self.l[:, 0] = self.u[:, 0] = self.h 441 | 442 | # Actions up and down 443 | mu_up, s_up = self.gp.predict(self._mat_up, 444 | kern=DifferenceKernel(self.gp.kern), 445 | full_cov=False) 446 | s_up = self.beta * np.sqrt(s_up) 447 | 448 | self.l[self._prev_up, 1, None] = mu_up - s_up 449 | self.u[self._prev_up, 1, None] = mu_up + s_up 450 | 451 | self.l[self._next_up, 3, None] = -mu_up - s_up 452 | self.u[self._next_up, 3, None] = -mu_up + s_up 453 | 454 | # Actions left and right 455 | mu_right, s_right = self.gp.predict(self._mat_right, 456 | kern=DifferenceKernel( 457 | self.gp.kern), 458 | full_cov=False) 459 | s_right = self.beta * np.sqrt(s_right) 460 | self.l[self._prev_right, 2, None] = mu_right - s_right 461 | self.u[self._prev_right, 2, None] = mu_right + s_right 462 | 463 | self.l[self._next_right, 4, None] = -mu_right - s_right 464 | self.u[self._next_right, 4, None] = -mu_right + s_right 465 | 466 | def compute_expanders(self): 467 | """Compute the expanders based on the current estimate of S_hat.""" 468 | self.G[:] = False 469 | 470 | for action in range(1, self.S_hat.shape[1]): 471 | 472 | # action-specific safe set 473 | s_hat = self.S_hat[:, action] 474 | 475 | # Extract distance from safe points to non safe ones 476 | distance = self.distance_matrix[np.ix_(s_hat, ~self.S[:, action])] 477 | 478 | # Update expanders for this particular action 479 | self.G[s_hat, action] = np.any( 480 | self.u[s_hat, action, None] - self.L * distance >= self.h, 481 | axis=1) 482 | 483 | def update_sets(self): 484 | """ 485 | Update the sets S, S_hat and G taking with the available observation 486 | """ 487 | self.update_confidence_interval() 488 | # self.S[:] = self.l >= self.h 489 | self.S |= self.l >= self.h 490 | 491 | self.compute_S_hat() 492 | self.compute_expanders() 493 | 494 | def plot_S(self, safe_set, action=0): 495 | """ 496 | Plot the set of safe states 497 | 498 | Parameters 499 | ---------- 500 | safe_set: np.array(dtype=bool) 501 | n_states x (n_actions + 1) array of boolean values that indicates 502 | the safe set 503 | action: int 504 | The action for which we want to plot the safe set. 505 | """ 506 | plt.figure(action) 507 | plt.imshow(np.reshape(safe_set[:, action], self.world_shape).T, 508 | origin='lower', interpolation='nearest', vmin=0, vmax=1) 509 | plt.title('action {0}'.format(action)) 510 | plt.show() 511 | 512 | def add_observation(self, node, action): 513 | """ 514 | Add an observation of the given state-action pair. 
515 | 516 | Observing the pair (s, a) means adding an observation of the altitude 517 | at s and an observation of the altitude at f(s, a) 518 | 519 | Parameters 520 | ---------- 521 | node: int 522 | Node index 523 | action: int 524 | Action index 525 | """ 526 | # Observation of next state 527 | for _, next_node, data in self.graph.edges_iter(node, data=True): 528 | if data['action'] == action: 529 | break 530 | 531 | self.add_gp_observations(self.coord[[node, next_node], :], 532 | self.altitudes[[node, next_node], None]) 533 | 534 | def target_sample(self): 535 | """ 536 | Compute the next target (s, a) to sample (highest uncertainty within 537 | G or S_hat) 538 | 539 | Returns 540 | ------- 541 | node: int 542 | The next node to sample 543 | action: int 544 | The next action to sample 545 | """ 546 | if np.any(self.G): 547 | # Extract elements in G 548 | expander_id = np.nonzero(self.G) 549 | 550 | # Compute uncertainty 551 | w = self.u[self.G] - self.l[self.G] 552 | 553 | # Find max uncertainty 554 | max_id = np.argmax(w) 555 | 556 | else: 557 | print('No expanders, using most uncertain element in S_hat' 558 | 'instead.') 559 | 560 | # Extract elements in S_hat 561 | expander_id = np.nonzero(self.S_hat) 562 | 563 | # Compute uncertainty 564 | w = self.u[self.S_hat] - self.l[self.S_hat] 565 | 566 | # Find max uncertainty 567 | max_id = np.argmax(w) 568 | 569 | return expander_id[0][max_id], expander_id[1][max_id] 570 | 571 | 572 | def states_to_nodes(states, world_shape, step_size): 573 | """Convert physical states to node numbers. 574 | 575 | Parameters 576 | ---------- 577 | states: np.array 578 | States with physical coordinates 579 | world_shape: tuple 580 | The size of the grid_world 581 | step_size: tuple 582 | The step size of the grid world 583 | 584 | Returns 585 | ------- 586 | nodes: np.array 587 | The node indices corresponding to the states 588 | """ 589 | states = np.asanyarray(states) 590 | node_indices = np.rint(states / step_size).astype(np.int) 591 | return node_indices[:, 1] + world_shape[1] * node_indices[:, 0] 592 | 593 | 594 | def nodes_to_states(nodes, world_shape, step_size): 595 | """Convert node numbers to physical states. 
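    For example, with ``world_shape=(2, 3)`` and ``step_size=(1., 1.)``,
    node ``4`` (row 1, column 1) corresponds to the physical state
    ``(1., 1.)``.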
596 | 597 | Parameters 598 | ---------- 599 | nodes: np.array 600 | Node indices of the grid world 601 | world_shape: tuple 602 | The size of the grid_world 603 | step_size: np.array 604 | The step size of the grid world 605 | 606 | Returns 607 | ------- 608 | states: np.array 609 | The states in physical coordinates 610 | """ 611 | nodes = np.asanyarray(nodes) 612 | step_size = np.asanyarray(step_size) 613 | return np.vstack((nodes // world_shape[1], 614 | nodes % world_shape[1])).T * step_size 615 | 616 | 617 | def grid(world_shape, step_size): 618 | """ 619 | Creates grids of coordinates and indices of state space 620 | 621 | Parameters 622 | ---------- 623 | world_shape: tuple 624 | Size of the grid world (rows, columns) 625 | step_size: tuple 626 | Phyiscal step size in the grid world 627 | 628 | Returns 629 | ------- 630 | states_ind: np.array 631 | (n*m) x 2 array containing the indices of the states 632 | states_coord: np.array 633 | (n*m) x 2 array containing the coordinates of the states 634 | """ 635 | nodes = np.arange(0, world_shape[0] * world_shape[1]) 636 | return nodes_to_states(nodes, world_shape, step_size) 637 | 638 | 639 | def draw_gp_sample(kernel, world_shape, step_size): 640 | """ 641 | Draws a sample from a Gaussian process distribution over a user 642 | specified grid 643 | 644 | Parameters 645 | ---------- 646 | kernel: GPy kernel 647 | Defines the GP we draw a sample from 648 | world_shape: tuple 649 | Shape of the grid we use for sampling 650 | step_size: tuple 651 | Step size along any axis to find linearly spaced points 652 | """ 653 | # Compute linearly spaced grid 654 | coord = grid(world_shape, step_size) 655 | 656 | # Draw a sample from GP 657 | cov = kernel.K(coord) + np.eye(coord.shape[0]) * 1e-10 658 | sample = np.random.multivariate_normal(np.zeros(coord.shape[0]), cov) 659 | return sample, coord 660 | 661 | 662 | def shortest_path(source, next_sample, G): 663 | """ 664 | Computes shortest safe path from a source to the next state-action pair 665 | the agent needs to sample 666 | 667 | Parameters 668 | ---------- 669 | source: int 670 | Staring node for the path 671 | next_sample: (int, int) 672 | Next state-action pair the agent needs to sample. First entry is the 673 | number that indicates the state. Second entry indicates the action 674 | G: networkx DiGraph 675 | Graph that indicates the dynamics. It is linked to S matrix 676 | 677 | Returns 678 | ------- 679 | path: list 680 | shortest safe path 681 | """ 682 | 683 | # Extract safe graph 684 | safe_edges = [edge for edge in G.edges_iter(data=True) if edge[2]['safe']] 685 | graph_safe = nx.DiGraph(safe_edges) 686 | 687 | # Compute shortest path 688 | target = next_sample[0] 689 | action = next_sample[1] 690 | path = nx.astar_path(graph_safe, source, target) 691 | 692 | for _, next_node, data in graph_safe.out_edges(nbunch=target, data=True): 693 | if data["action"] == action: 694 | path = path + [next_node] 695 | 696 | return path 697 | 698 | 699 | def path_to_boolean_matrix(path, graph, S): 700 | """ 701 | Computes a S-like matrix for approaches where performances is based 702 | on the trajectory of the agent (e.g. 
unsafe or random exploration) 703 | Parameters 704 | ---------- 705 | path: np.array 706 | Contains the nodes that are visited along the path 707 | graph: networkx.DiGraph 708 | Graph that indicates the dynamics 709 | S: np.array 710 | Array describing the safe set (needed for initialization) 711 | 712 | Returns 713 | ------- 714 | bool_mat: np.array 715 | S-like array that is true for all the states and state-action pairs 716 | along the path 717 | """ 718 | 719 | # Initialize matrix 720 | bool_mat = np.zeros_like(S, dtype=bool) 721 | 722 | # Go through path to find actions 723 | for i in range(len(path) - 1): 724 | 725 | prev = path[i] 726 | succ = path[i + 1] 727 | 728 | for _, next_node, data in graph.out_edges(nbunch=prev, data=True): 729 | if next_node == succ: 730 | bool_mat[prev, 0] = True 731 | a = data["action"] 732 | bool_mat[prev, a] = True 733 | break 734 | bool_mat[succ, 0] = True 735 | return bool_mat 736 | 737 | 738 | def safe_subpath(path, altitudes, h): 739 | """ 740 | Computes the maximum subpath of path along which the safety constraint is 741 | not violated 742 | Parameters 743 | ---------- 744 | path: np.array 745 | Contains the nodes that are visited along the path 746 | altitudes: np.array 747 | 1-d vector with altitudes for each node 748 | h: float 749 | Safety threshold 750 | 751 | Returns 752 | ------- 753 | subpath: np.array 754 | Maximum subpath of path that fulfills the safety constraint 755 | 756 | """ 757 | # Initialize subpath 758 | subpath = [path[0]] 759 | 760 | # Loop through path 761 | for j in range(len(path) - 1): 762 | prev = path[j] 763 | succ = path[j + 1] 764 | 765 | # Check safety constraint 766 | if altitudes[prev] - altitudes[succ] >= h: 767 | subpath = subpath + [succ] 768 | else: 769 | break 770 | return subpath 771 | -------------------------------------------------------------------------------- /safemdp/test.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | 3 | import unittest 4 | import GPy 5 | import numpy as np 6 | import networkx as nx 7 | from numpy.testing import * 8 | 9 | from .utilities import * 10 | 11 | from safemdp.SafeMDP_class import reachable_set, returnable_set 12 | from safemdp.grid_world import compute_true_safe_set, grid_world_graph 13 | from .SafeMDP_class import link_graph_and_safe_set 14 | 15 | 16 | class DifferenceKernelTest(unittest.TestCase): 17 | 18 | @staticmethod 19 | def _check(gp, x1, x2): 20 | """Compare the gp difference predictions on X1 and X2. 
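        The reference values are computed explicitly by predicting the
        stacked inputs and applying the difference operator ``A = [I, -I]``
        to the joint mean and covariance.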
21 | 22 | Parameters 23 | ---------- 24 | gp: GPy.core.GP 25 | x1: np.array 26 | x2: np.array 27 | """ 28 | n = x1.shape[0] 29 | 30 | # Difference prediction with library 31 | a = np.hstack((np.eye(n), -np.eye(n))) 32 | m1, v1 = gp.predict_noiseless(np.vstack((x1, x2)), full_cov=True) 33 | m1 = a.dot(m1) 34 | v1 = np.linalg.multi_dot((a, v1, a.T)) 35 | 36 | # Predict diagonal 37 | m2, v2 = gp.predict_noiseless(np.hstack((x1, x2)), 38 | kern=DifferenceKernel(gp.kern), 39 | full_cov=False) 40 | 41 | assert_allclose(m1, m2) 42 | assert_allclose(np.diag(v1), v2.squeeze()) 43 | 44 | # Predict full covariance 45 | m2, v2 = gp.predict_noiseless(np.hstack((x1, x2)), 46 | kern=DifferenceKernel(gp.kern), 47 | full_cov=True) 48 | 49 | assert_allclose(m1, m2) 50 | assert_allclose(v1, v2, atol=1e-12) 51 | 52 | def test_1d(self): 53 | """Test the difference kernel for a 1D input.""" 54 | # Create some GP model 55 | kernel = GPy.kern.RBF(input_dim=1, lengthscale=0.05) 56 | likelihood = GPy.likelihoods.Gaussian(variance=0.005 ** 2) 57 | x = np.linspace(0, 1, 5)[:, None] 58 | y = x ** 2 59 | gp = GPy.core.GP(x, y, kernel, likelihood) 60 | 61 | # Create test points 62 | n = 10 63 | x1 = np.linspace(0, 1, n)[:, None] 64 | x2 = x1 + np.linspace(0, 0.1, n)[::-1, None] 65 | 66 | self._check(gp, x1, x2) 67 | 68 | def test_2d(self): 69 | """Test the difference kernel for a 2D input.""" 70 | 71 | # Create some GP model 72 | kernel = GPy.kern.RBF(input_dim=2, lengthscale=0.05) 73 | likelihood = GPy.likelihoods.Gaussian(variance=0.005 ** 2) 74 | x = np.hstack((np.linspace(0, 1, 5)[:, None], 75 | np.linspace(0.5, 1.5, 5)[:, None])) 76 | y = x[:, [0]] ** 2 + x[:, [1]] ** 2 77 | gp = GPy.core.GP(x, y, kernel, likelihood) 78 | 79 | # Create test points 80 | n = 10 81 | 82 | x1 = np.hstack((np.linspace(0, 1, n)[:, None], 83 | np.linspace(0.5, 1.5, n)[:, None])) 84 | x2 = x1 + np.hstack((np.linspace(0, 0.1, n)[::-1, None], 85 | np.linspace(0., 0.1, n)[::-1, None])) 86 | 87 | self._check(gp, x1, x2) 88 | 89 | 90 | class MaxOutDegreeTest(unittest.TestCase): 91 | def test_all(self): 92 | """Test the max_out_degree function.""" 93 | graph = nx.DiGraph() 94 | graph.add_edges_from(((0, 1), 95 | (1, 2), 96 | (2, 3), 97 | (3, 1))) 98 | assert_(max_out_degree(graph), 1) 99 | 100 | graph.add_edge(0, 2) 101 | assert_(max_out_degree(graph), 2) 102 | 103 | graph.add_edge(2, 3) 104 | assert_(max_out_degree(graph), 2) 105 | 106 | graph.add_edge(3, 2) 107 | assert_(max_out_degree(graph), 2) 108 | 109 | graph.add_edge(3, 1) 110 | assert_(max_out_degree(graph), 3) 111 | 112 | 113 | class ReachableSetTest(unittest.TestCase): 114 | 115 | def __init__(self, *args, **kwargs): 116 | super(ReachableSetTest, self).__init__(*args, **kwargs) 117 | # 3 118 | # ^ 119 | # | 120 | # 0 --> 1 --> 2 --> 0 121 | # ^ 122 | # | 123 | # 4 124 | self.graph = nx.DiGraph() 125 | self.graph.add_edges_from([(0, 1), 126 | (1, 2), 127 | (2, 0), 128 | (4, 1)], action=1) 129 | self.graph.add_edge(2, 3, action=2) 130 | 131 | self.safe_set = np.ones((self.graph.number_of_nodes(), 132 | max_out_degree(self.graph) + 1), 133 | dtype=np.bool) 134 | link_graph_and_safe_set(self.graph, self.safe_set) 135 | self.true = np.zeros(self.safe_set.shape[0], dtype=np.bool) 136 | 137 | def setUp(self): 138 | self.safe_set[:] = True 139 | 140 | def _check(self): 141 | reach = reachable_set(self.graph, [0]) 142 | assert_equal(reach[:, 0], self.true) 143 | 144 | def test_all_safe(self): 145 | """Test reachable set if everything is safe""" 146 | self.true[:] = [1, 1, 1, 1, 0] 147 | 
self._check() 148 | 149 | def test_unsafe1(self): 150 | """Test safety aspect""" 151 | self.safe_set[1, 1] = False 152 | self.true[:] = [1, 1, 0, 0, 0] 153 | self._check() 154 | 155 | def test_unsafe2(self): 156 | """Test safety aspect""" 157 | self.safe_set[2, 2] = False 158 | self.true[:] = [1, 1, 1, 0, 0] 159 | self._check() 160 | 161 | def test_unsafe3(self): 162 | """Test safety aspect""" 163 | self.safe_set[2, 1] = False 164 | self.true[:] = [1, 1, 1, 1, 0] 165 | self._check() 166 | 167 | def test_unsafe4(self): 168 | """Test safety aspect""" 169 | self.safe_set[4, 1] = False 170 | self.true[:] = [1, 1, 1, 1, 0] 171 | self._check() 172 | 173 | def test_out(self): 174 | """Test writing the output""" 175 | self.safe_set[2, 2] = False 176 | self.true[:] = [1, 1, 1, 0, 0] 177 | out = np.zeros_like(self.safe_set) 178 | reachable_set(self.graph, [0], out=out) 179 | assert_equal(out[:, 0], self.true) 180 | 181 | def test_error(self): 182 | """Check error condition""" 183 | with assert_raises(AttributeError): 184 | reachable_set(self.graph, []) 185 | 186 | 187 | class ReturnableSetTest(unittest.TestCase): 188 | 189 | def __init__(self, *args, **kwargs): 190 | super(ReturnableSetTest, self).__init__(*args, **kwargs) 191 | # 3 192 | # ^ 193 | # | 194 | # 0 --> 1 --> 2 --> 0 195 | # ^ 196 | # | 197 | # 4 198 | self.graph = nx.DiGraph() 199 | self.graph.add_edges_from([(0, 1), 200 | (1, 2), 201 | (2, 0), 202 | (4, 1)], action=1) 203 | self.graph.add_edge(2, 3, action=2) 204 | self.graph_rev = self.graph.reverse() 205 | 206 | self.safe_set = np.ones((self.graph.number_of_nodes(), 207 | max_out_degree(self.graph) + 1), 208 | dtype=np.bool) 209 | link_graph_and_safe_set(self.graph, self.safe_set) 210 | self.true = np.zeros(self.safe_set.shape[0], dtype=np.bool) 211 | 212 | def setUp(self): 213 | self.safe_set[:] = True 214 | 215 | def _check(self): 216 | ret = returnable_set(self.graph, self.graph_rev, [0]) 217 | assert_equal(ret[:, 0], self.true) 218 | 219 | def test_all_safe(self): 220 | """Test reachable set if everything is safe""" 221 | self.true[:] = [1, 1, 1, 0, 1] 222 | self._check() 223 | 224 | def test_unsafe1(self): 225 | """Test safety aspect""" 226 | self.safe_set[1, 1] = False 227 | self.true[:] = [1, 0, 1, 0, 0] 228 | self._check() 229 | 230 | def test_unsafe2(self): 231 | """Test safety aspect""" 232 | self.safe_set[2, 1] = False 233 | self.true[:] = [1, 0, 0, 0, 0] 234 | self._check() 235 | 236 | def test_unsafe3(self): 237 | """Test safety aspect""" 238 | self.safe_set[2, 2] = False 239 | self.true[:] = [1, 1, 1, 0, 1] 240 | self._check() 241 | 242 | def test_unsafe4(self): 243 | """Test safety aspect""" 244 | self.safe_set[4, 1] = False 245 | self.true[:] = [1, 1, 1, 0, 0] 246 | self._check() 247 | 248 | def test_out(self): 249 | """Test writing the output""" 250 | self.safe_set[1, 1] = False 251 | self.true[:] = [1, 0, 1, 0, 0] 252 | out = np.zeros_like(self.safe_set) 253 | returnable_set(self.graph, self.graph_rev, [0], out=out) 254 | assert_equal(out[:, 0], self.true) 255 | 256 | def test_error(self): 257 | """Check error condition""" 258 | with assert_raises(AttributeError): 259 | reachable_set(self.graph, []) 260 | 261 | 262 | class GridWorldGraphTest(unittest.TestCase): 263 | """Test the grid_world_graph function.""" 264 | 265 | def test(self): 266 | """Simple test""" 267 | # 1 2 3 268 | # 4 5 6 269 | graph = grid_world_graph((2, 3)) 270 | graph_true = nx.DiGraph() 271 | graph_true.add_edges_from(((1, 2), 272 | (2, 3), 273 | (4, 5), 274 | (5, 6)), 275 | action=1) 276 | 
graph_true.add_edges_from(((1, 4), 277 | (2, 5), 278 | (3, 6)), 279 | action=2) 280 | graph_true.add_edges_from(((2, 1), 281 | (3, 2), 282 | (5, 4), 283 | (6, 5)), 284 | action=3) 285 | graph_true.add_edges_from(((4, 1), 286 | (5, 2), 287 | (6, 3)), 288 | action=4) 289 | 290 | assert_(nx.is_isomorphic(graph, graph_true)) 291 | 292 | 293 | class TestTrueSafeSet(unittest.TestCase): 294 | 295 | def test_differences_safe(self): 296 | altitudes = np.array([[1, 2, 3], 297 | [2, 3, 4]]) 298 | safe = compute_true_safe_set((2, 3), altitudes.reshape(-1), -1) 299 | true_safe = np.array([[1, 1, 1, 1, 1, 1], 300 | [1, 1, 0, 1, 1, 0], 301 | [1, 1, 1, 0, 0, 0], 302 | [0, 1, 1, 0, 1, 1], 303 | [0, 0, 0, 1, 1, 1]], 304 | dtype=np.bool).T 305 | 306 | assert_equal(safe, true_safe) 307 | 308 | def test_differences_unsafe(self): 309 | altitudes = np.array([[1, 0, 3], 310 | [2, 3, 0]]) 311 | safe = compute_true_safe_set((2, 3), altitudes.reshape(-1), -1) 312 | true_safe = np.array([[1, 1, 1, 1, 1, 1], 313 | [1, 0, 0, 1, 1, 0], 314 | [1, 0, 1, 0, 0, 0], 315 | [0, 1, 1, 0, 1, 0], 316 | [0, 0, 0, 1, 1, 0]], 317 | dtype=np.bool).T 318 | assert_equal(safe, true_safe) 319 | 320 | 321 | if __name__ == '__main__': 322 | unittest.main() 323 | -------------------------------------------------------------------------------- /safemdp/utilities.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import numpy as np 4 | 5 | __all__ = ['DifferenceKernel', 'max_out_degree'] 6 | 7 | 8 | class DifferenceKernel(object): 9 | """ 10 | A fake kernel that can be used to predict differences two function values. 11 | 12 | Given a gp based on measurements, we aim to predict the difference between 13 | the function values at two different test points, X1 and X2; that is, we 14 | want to obtain mean and variance of f(X1) - f(X2). Using this fake 15 | kernel, this can be achieved with 16 | `mean, var = gp.predict(np.hstack((X1, X2)), kern=DiffKernel(gp.kern))` 17 | 18 | Parameters 19 | ---------- 20 | kernel: GPy.kern.* 21 | The kernel used by the GP 22 | """ 23 | 24 | def __init__(self, kernel): 25 | self.kern = kernel 26 | 27 | def K(self, x1, x2=None): 28 | """Equivalent of kern.K 29 | 30 | If only x1 is passed then it is assumed to contain the data for both 31 | whose differences we are computing. Otherwise, x2 will contain these 32 | extended states (see PosteriorExact._raw_predict in 33 | GPy/inference/latent_function_inference0/posterior.py) 34 | 35 | Parameters 36 | ---------- 37 | x1: np.array 38 | x2: np.array 39 | """ 40 | dim = self.kern.input_dim 41 | if x2 is None: 42 | x10 = x1[:, :dim] 43 | x11 = x1[:, dim:] 44 | return (self.kern.K(x10) + self.kern.K(x11) - 45 | self.kern.K(x10, x11) - self.kern.K(x11, x10)) 46 | else: 47 | x20 = x2[:, :dim] 48 | x21 = x2[:, dim:] 49 | return self.kern.K(x1, x20) - self.kern.K(x1, x21) 50 | 51 | def Kdiag(self, x): 52 | """Equivalent of kern.Kdiag for the difference prediction. 
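        This returns the marginal variance of the difference for each row
        of `x`, i.e. ``k(x0, x0) + k(x1, x1) - 2 * k(x0, x1)`` evaluated
        elementwise.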
53 | 54 | Parameters 55 | ---------- 56 | x: np.array 57 | """ 58 | dim = self.kern.input_dim 59 | x0 = x[:, :dim] 60 | x1 = x[:, dim:] 61 | return (self.kern.Kdiag(x0) + self.kern.Kdiag(x1) - 62 | 2 * np.diag(self.kern.K(x0, x1))) 63 | 64 | 65 | def max_out_degree(graph): 66 | """Compute the maximum out_degree of a graph 67 | 68 | Parameters 69 | ---------- 70 | graph: nx.DiGraph 71 | 72 | Returns 73 | ------- 74 | max_out_degree: int 75 | The maximum out_degree of the graph 76 | """ 77 | def degree_generator(graph): 78 | for _, degree in graph.out_degree_iter(): 79 | yield degree 80 | return max(degree_generator(graph)) 81 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | 4 | 5 | def read(fname): 6 | """Read the name relative to current directory""" 7 | return open(os.path.join(os.path.dirname(__file__), fname)).read() 8 | 9 | setup( 10 | name="safemdp", 11 | version="1.0", 12 | author="Matteo Turchetta, Felix Berkenkamp", 13 | author_email="matteotu@ethz.ch, befelix@ethz.ch", 14 | description=("Safe exploration in MDPs"), 15 | license="MIT", 16 | url="http://packages.python.org/an_example_pypi_project", 17 | packages=['safemdp'], 18 | long_description=read('README.md'), 19 | install_requires=[ 20 | 'GPy >= 0.8.0', 21 | 'numpy >= 1.7.2', 22 | 'scipy >= 0.16', 23 | 'matplotlib >= 1.5.0', 24 | 'networkx >= 1.1', 25 | ], 26 | ) 27 | -------------------------------------------------------------------------------- /test_code.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | module="safemdp" 4 | 5 | get_script_dir () { 6 | SOURCE="${BASH_SOURCE[0]}" 7 | # While $SOURCE is a symlink, resolve it 8 | while [ -h "$SOURCE" ]; do 9 | DIR="$( cd -P "$( dirname "$SOURCE" )" && pwd )" 10 | SOURCE="$( readlink "$SOURCE" )" 11 | # If $SOURCE was a relative symlink (so no "/" as prefix, need to resolve it relative to the symlink base directory 12 | [[ $SOURCE != /* ]] && SOURCE="$DIR/$SOURCE" 13 | done 14 | DIR="$( cd -P "$( dirname "$SOURCE" )" && pwd )" 15 | echo "$DIR" 16 | } 17 | 18 | # Change to script root 19 | cd $(get_script_dir) 20 | GREEN='\033[0;32m' 21 | NC='\033[0m' 22 | 23 | # Run style tests 24 | echo -e "${GREEN}Running style tests.${NC}" 25 | flake8 $module --exclude test*.py,__init__.py --ignore=E402,W503 --show-source 26 | 27 | # Ignore import errors for __init__ and tests 28 | flake8 $module --filename=__init__.py,test*.py --ignore=F,E402,W503 --show-source 29 | 30 | echo -e "${GREEN}Testing docstring conventions.${NC}" 31 | # Test docstring conventions 32 | pydocstyle safemdp --match='(?!__init__).*\.py' 2>&1 | grep -v "WARNING: __all__" 33 | 34 | # Run unit tests 35 | echo -e "${GREEN}Running unit tests.${NC}" 36 | nosetests --with-doctest --with-coverage --cover-erase --cover-package=safemdp $module 37 | 38 | # Export html 39 | coverage html 40 | 41 | --------------------------------------------------------------------------------
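A short, self-contained sketch of the `DifferenceKernel` usage pattern documented in `safemdp/utilities.py` (synthetic data and illustrative values only):

    import numpy as np
    import GPy

    from safemdp import DifferenceKernel

    # Fit a 1-D GP to a few noisy samples of a toy function.
    X = np.linspace(0, 1, 10)[:, None]
    Y = np.sin(3 * X) + 0.01 * np.random.randn(10, 1)
    gp = GPy.models.GPRegression(X, Y, GPy.kern.RBF(input_dim=1))

    # Mean and variance of f(x1) - f(x2) for pairs of test points.
    x1 = np.linspace(0, 1, 5)[:, None]
    x2 = x1 + 0.1
    mean_diff, var_diff = gp.predict(np.hstack((x1, x2)),
                                     kern=DifferenceKernel(gp.kern))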