├── .coveragerc
├── .gitignore
├── .travis.yml
├── LICENSE
├── README.md
├── doc
├── Makefile
├── api.rst
├── conf.py
├── index.rst
├── requirements.txt
└── template.rst
├── examples
├── mars
│ ├── README.md
│ ├── generate_plots.py
│ ├── mars.py
│ ├── mars_utilities.py
│ ├── plot_utilities.py
│ └── requirements.txt
└── sample.py
├── requirements.txt
├── safemdp
├── SafeMDP_class.py
├── __init__.py
├── grid_world.py
├── test.py
└── utilities.py
├── setup.py
└── test_code.sh
/.coveragerc:
--------------------------------------------------------------------------------
1 | [report]
2 | # Regexes for lines to exclude from consideration
3 | exclude_lines =
4 | # Have to re-enable the standard pragma
5 | pragma: no cover
6 |
7 | # Don't complain about missing debug-only code:
8 | def __repr__
9 | if self\.debug
10 |
11 | # Don't complain if tests don't hit defensive assertion code:
12 | raise AssertionError
13 | raise NotImplementedError
14 |
15 | # Don't complain if non-runnable code isn't run:
16 | if False:
17 | if __name__ == .__main__.:
18 |
19 | show_missing = True
20 |
21 | [html]
22 | directory = htmlcov
23 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | .ipynb_checkpoints
3 | .coverage
4 | htmlcov
5 |
6 | # Python files
7 | *.pyc
8 |
9 | # Data files
10 | *.IMG
11 | *.tif
12 |
13 | # Documentation
14 | doc/safemdp.*.rst
15 | doc/_build
16 |
17 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 | python:
3 | - 2.7
4 | notifications:
5 | email: false
6 |
7 | # Setup anaconda
8 | before_install:
9 | - wget http://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh
10 | - chmod +x miniconda.sh
11 | - ./miniconda.sh -b -p ~/miniconda
12 | - export PATH=~/miniconda/bin:$PATH
13 | - conda update --yes conda
14 | # The next couple lines fix a crash with multiprocessing on Travis and are not specific to using Miniconda
15 | - sudo rm -rf /dev/shm
16 | - sudo ln -s /run/shm /dev/shm
17 | # Install packages
18 | install:
19 | - conda install --yes python=$TRAVIS_PYTHON_VERSION flake8 numpy scipy matplotlib nose networkx
20 | - pip install GPy
21 | # Coverage packages are on my binstar channel
22 | # - conda install --yes -c dan_blanchard python-coveralls nose-cov
23 | # - python setup.py install
24 |
25 | # Run tests
26 | script:
27 | - flake8 safemdp --exclude test*.py,__init__.py --ignore=E402,W503 --show-source
28 | - flake8 safemdp --filename=__init__.py,test*.py --ignore=F,E402,W503 --show-source
29 | - nosetests --with-doctest safemdp
30 |
31 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | The MIT License (MIT)
2 |
3 | Copyright (c) 2015 Felix Berkenkamp
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SafeMDP
2 |
3 | [](https://travis-ci.org/befelix/SafeMDP)
4 | [](http://safemdp.readthedocs.io/en/latest/?badge=latest)
5 |
6 | Code for safe exploration in Markov Decision Processes (MDPs). This code accompanies the paper
7 |
8 | M. Turchetta, F. Berkenkamp, A. Krause, "Safe Exploration in Finite Markov Decision Processes with Gaussian Processes", Proc. of the Conference on Neural Information Processing Systems (NIPS), 2016, [PDF]
9 |
10 | # Installation
11 |
12 | The easiest way to install and use the library is to install the Anaconda Python distribution. Then, run the following commands in the root directory of this repository:
13 | ```
14 | pip install GPy
15 | python setup.py install
16 | ```
17 |
18 | # Usage
19 |
20 | The documentation of the library is available on Read the Docs
21 |
22 | The file `examples/sample.py` implements a simple example that samples a random world from a Gaussian process and shows exploration results.
23 |
24 | The code for the experiments in the paper can be found in the `examples/mars/` directory.
25 |
--------------------------------------------------------------------------------
/doc/Makefile:
--------------------------------------------------------------------------------
1 | # Makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | PAPER =
8 | BUILDDIR = _build
9 |
10 | # Internal variables.
11 | PAPEROPT_a4 = -D latex_paper_size=a4
12 | PAPEROPT_letter = -D latex_paper_size=letter
13 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
14 | # the i18n builder cannot share the environment and doctrees with the others
15 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
16 |
17 | .PHONY: help
18 | help:
19 | @echo "Please use \`make ' where is one of"
20 | @echo " html to make standalone HTML files"
21 | @echo " dirhtml to make HTML files named index.html in directories"
22 | @echo " singlehtml to make a single large HTML file"
23 | @echo " pickle to make pickle files"
24 | @echo " json to make JSON files"
25 | @echo " htmlhelp to make HTML files and a HTML help project"
26 | @echo " qthelp to make HTML files and a qthelp project"
27 | @echo " applehelp to make an Apple Help Book"
28 | @echo " devhelp to make HTML files and a Devhelp project"
29 | @echo " epub to make an epub"
30 | @echo " epub3 to make an epub3"
31 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
32 | @echo " latexpdf to make LaTeX files and run them through pdflatex"
33 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx"
34 | @echo " text to make text files"
35 | @echo " man to make manual pages"
36 | @echo " texinfo to make Texinfo files"
37 | @echo " info to make Texinfo files and run them through makeinfo"
38 | @echo " gettext to make PO message catalogs"
39 | @echo " changes to make an overview of all changed/added/deprecated items"
40 | @echo " xml to make Docutils-native XML files"
41 | @echo " pseudoxml to make pseudoxml-XML files for display purposes"
42 | @echo " linkcheck to check all external links for integrity"
43 | @echo " doctest to run all doctests embedded in the documentation (if enabled)"
44 | @echo " coverage to run coverage check of the documentation (if enabled)"
45 | @echo " dummy to check syntax errors of document sources"
46 |
47 | .PHONY: clean
48 | clean:
49 | rm -rf $(BUILDDIR)/*
50 |
51 | .PHONY: html
52 | html:
53 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
54 | @echo
55 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
56 |
57 | .PHONY: dirhtml
58 | dirhtml:
59 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
60 | @echo
61 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
62 |
63 | .PHONY: singlehtml
64 | singlehtml:
65 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
66 | @echo
67 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
68 |
69 | .PHONY: pickle
70 | pickle:
71 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
72 | @echo
73 | @echo "Build finished; now you can process the pickle files."
74 |
75 | .PHONY: json
76 | json:
77 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
78 | @echo
79 | @echo "Build finished; now you can process the JSON files."
80 |
81 | .PHONY: htmlhelp
82 | htmlhelp:
83 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
84 | @echo
85 | @echo "Build finished; now you can run HTML Help Workshop with the" \
86 | ".hhp project file in $(BUILDDIR)/htmlhelp."
87 |
88 | .PHONY: qthelp
89 | qthelp:
90 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp
91 | @echo
92 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \
93 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:"
94 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/SafeMDP.qhcp"
95 | @echo "To view the help file:"
96 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/SafeMDP.qhc"
97 |
98 | .PHONY: applehelp
99 | applehelp:
100 | $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp
101 | @echo
102 | @echo "Build finished. The help book is in $(BUILDDIR)/applehelp."
103 | @echo "N.B. You won't be able to view it unless you put it in" \
104 | "~/Library/Documentation/Help or install it in your application" \
105 | "bundle."
106 |
107 | .PHONY: devhelp
108 | devhelp:
109 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp
110 | @echo
111 | @echo "Build finished."
112 | @echo "To view the help file:"
113 | @echo "# mkdir -p $$HOME/.local/share/devhelp/SafeMDP"
114 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/SafeMDP"
115 | @echo "# devhelp"
116 |
117 | .PHONY: epub
118 | epub:
119 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub
120 | @echo
121 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub."
122 |
123 | .PHONY: epub3
124 | epub3:
125 | $(SPHINXBUILD) -b epub3 $(ALLSPHINXOPTS) $(BUILDDIR)/epub3
126 | @echo
127 | @echo "Build finished. The epub3 file is in $(BUILDDIR)/epub3."
128 |
129 | .PHONY: latex
130 | latex:
131 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
132 | @echo
133 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex."
134 | @echo "Run \`make' in that directory to run these through (pdf)latex" \
135 | "(use \`make latexpdf' here to do that automatically)."
136 |
137 | .PHONY: latexpdf
138 | latexpdf:
139 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
140 | @echo "Running LaTeX files through pdflatex..."
141 | $(MAKE) -C $(BUILDDIR)/latex all-pdf
142 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
143 |
144 | .PHONY: latexpdfja
145 | latexpdfja:
146 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
147 | @echo "Running LaTeX files through platex and dvipdfmx..."
148 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja
149 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
150 |
151 | .PHONY: text
152 | text:
153 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text
154 | @echo
155 | @echo "Build finished. The text files are in $(BUILDDIR)/text."
156 |
157 | .PHONY: man
158 | man:
159 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man
160 | @echo
161 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man."
162 |
163 | .PHONY: texinfo
164 | texinfo:
165 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
166 | @echo
167 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo."
168 | @echo "Run \`make' in that directory to run these through makeinfo" \
169 | "(use \`make info' here to do that automatically)."
170 |
171 | .PHONY: info
172 | info:
173 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
174 | @echo "Running Texinfo files through makeinfo..."
175 | make -C $(BUILDDIR)/texinfo info
176 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo."
177 |
178 | .PHONY: gettext
179 | gettext:
180 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale
181 | @echo
182 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale."
183 |
184 | .PHONY: changes
185 | changes:
186 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes
187 | @echo
188 | @echo "The overview file is in $(BUILDDIR)/changes."
189 |
190 | .PHONY: linkcheck
191 | linkcheck:
192 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck
193 | @echo
194 | @echo "Link check complete; look for any errors in the above output " \
195 | "or in $(BUILDDIR)/linkcheck/output.txt."
196 |
197 | .PHONY: doctest
198 | doctest:
199 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest
200 | @echo "Testing of doctests in the sources finished, look at the " \
201 | "results in $(BUILDDIR)/doctest/output.txt."
202 |
203 | .PHONY: coverage
204 | coverage:
205 | $(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage
206 | @echo "Testing of coverage in the sources finished, look at the " \
207 | "results in $(BUILDDIR)/coverage/python.txt."
208 |
209 | .PHONY: xml
210 | xml:
211 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml
212 | @echo
213 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml."
214 |
215 | .PHONY: pseudoxml
216 | pseudoxml:
217 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml
218 | @echo
219 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml."
220 |
221 | .PHONY: dummy
222 | dummy:
223 | $(SPHINXBUILD) -b dummy $(ALLSPHINXOPTS) $(BUILDDIR)/dummy
224 | @echo
225 | @echo "Build finished. Dummy builder generates no files."
226 |
--------------------------------------------------------------------------------
/doc/api.rst:
--------------------------------------------------------------------------------
1 | API Documentation
2 | *****************
3 |
4 | .. automodule:: safemdp
5 |
--------------------------------------------------------------------------------
/doc/conf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | #
4 | # SafeMDP documentation build configuration file, created by
5 | # sphinx-quickstart on Thu Jun 16 08:40:23 2016.
6 | #
7 | # This file is execfile()d with the current directory set to its
8 | # containing dir.
9 | #
10 | # Note that not all possible configuration values are present in this
11 | # autogenerated file.
12 | #
13 | # All configuration values have a default; values that are commented out
14 | # serve to show the default.
15 |
16 | # If extensions (or modules to document with autodoc) are in another directory,
17 | # add these directories to sys.path here. If the directory is relative to the
18 | # documentation root, use os.path.abspath to make it absolute, like shown here.
19 |
20 | import os
21 | import sys
22 | import shlex
23 | import mock
24 |
25 |
26 | MOCK_MODULES = ['GPy',
27 | 'GPy.util',
28 | 'GPy.util.linalg',
29 | 'GPy.inference',
30 | 'GPy.inference.latent_function_inference',
31 | 'GPy.inference.latent_function_inference.posterior',
32 | 'mpl_toolkits',
33 | 'mpl_toolkits.mplot3d',
34 | 'matplotlib',
35 | 'matplotlib.pyplot',
36 | 'networkx',
37 | 'nose',
38 | 'nose.tools',
39 | 'numpy',
40 | 'numpy.testing',
41 | 'scipy',
42 | 'scipy.interpolate',
43 | 'scipy.spatial',
44 | 'scipy.spatial.distance',
45 | ]
46 |
47 | for mod_name in MOCK_MODULES:
48 | sys.modules[mod_name] = mock.Mock()
49 |
50 | sys.path.insert(0, os.path.abspath('../'))
51 |
52 | # -- General configuration ------------------------------------------------
53 |
54 | # If your documentation needs a minimal Sphinx version, state it here.
55 | #
56 | # needs_sphinx = '1.0'
57 |
58 | # Add any Sphinx extension module names here, as strings. They can be
59 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
60 | # ones.
61 | extensions = [
62 | 'sphinx.ext.autodoc',
63 | 'numpydoc',
64 | 'sphinx.ext.autosummary',
65 | ]
66 |
67 | # Add any paths that contain templates here, relative to this directory.
68 | templates_path = ['']
69 |
70 | # Generate an autosummary with one file per function.
71 | autosummary_generate = True
72 |
73 | autodoc_default_flags = []
74 |
75 | # The suffix(es) of source filenames.
76 | # You can specify multiple suffix as a list of string:
77 | #
78 | # source_suffix = ['.rst', '.md']
79 | source_suffix = '.rst'
80 |
81 | # The encoding of source files.
82 | #
83 | # source_encoding = 'utf-8-sig'
84 |
85 | # The master toctree document.
86 | master_doc = 'index'
87 |
88 | # General information about the project.
89 | project = 'SafeMDP'
90 | copyright = '2016, Matteo Turchetta, Felix Berkenkamp, Andreas Krause'
91 | author = 'Matteo Turchetta, Felix Berkenkamp, Andreas Krause'
92 |
93 | # The version info for the project you're documenting, acts as replacement for
94 | # |version| and |release|, also used in various other places throughout the
95 | # built documents.
96 | #
97 | # The short X.Y version.
98 | version = '1.0'
99 | # The full version, including alpha/beta/rc tags.
100 | release = '1.0'
101 |
102 | # The language for content autogenerated by Sphinx. Refer to documentation
103 | # for a list of supported languages.
104 | #
105 | # This is also used if you do content translation via gettext catalogs.
106 | # Usually you set "language" from the command line for these cases.
107 | language = None
108 |
109 | # There are two options for replacing |today|: either, you set today to some
110 | # non-false value, then it is used:
111 | #
112 | # today = ''
113 | #
114 | # Else, today_fmt is used as the format for a strftime call.
115 | #
116 | # today_fmt = '%B %d, %Y'
117 |
118 | # List of patterns, relative to source directory, that match files and
119 | # directories to ignore when looking for source files.
120 | # This patterns also effect to html_static_path and html_extra_path
121 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
122 |
123 | # The reST default role (used for this markup: `text`) to use for all
124 | # documents.
125 | #
126 | # default_role = None
127 |
128 | # If true, '()' will be appended to :func: etc. cross-reference text.
129 | #
130 | # add_function_parentheses = True
131 |
132 | # If true, the current module name will be prepended to all description
133 | # unit titles (such as .. function::).
134 | #
135 | # add_module_names = True
136 |
137 | # If true, sectionauthor and moduleauthor directives will be shown in the
138 | # output. They are ignored by default.
139 | #
140 | # show_authors = False
141 |
142 | # The name of the Pygments (syntax highlighting) style to use.
143 | pygments_style = 'sphinx'
144 |
145 | # A list of ignored prefixes for module index sorting.
146 | # modindex_common_prefix = []
147 |
148 | # If true, keep warnings as "system message" paragraphs in the built documents.
149 | # keep_warnings = False
150 |
151 | # If true, `todo` and `todoList` produce output, else they produce nothing.
152 | todo_include_todos = False
153 |
154 |
155 | # -- Options for HTML output ----------------------------------------------
156 |
157 | # The theme to use for HTML and HTML Help pages. See the documentation for
158 | # a list of builtin themes.
159 | #
160 | html_theme = 'sphinx_rtd_theme'
161 |
162 | # Theme options are theme-specific and customize the look and feel of a theme
163 | # further. For a list of options available for each theme, see the
164 | # documentation.
165 | #
166 | # html_theme_options = {}
167 |
168 | # Add any paths that contain custom themes here, relative to this directory.
169 | # html_theme_path = []
170 |
171 | # The name for this set of Sphinx documents.
172 | # " v documentation" by default.
173 | #
174 | # html_title = 'SafeMDP v1.0'
175 |
176 | # A shorter title for the navigation bar. Default is the same as html_title.
177 | #
178 | # html_short_title = None
179 |
180 | # The name of an image file (relative to this directory) to place at the top
181 | # of the sidebar.
182 | #
183 | # html_logo = None
184 |
185 | # The name of an image file (relative to this directory) to use as a favicon of
186 | # the docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
187 | # pixels large.
188 | #
189 | # html_favicon = None
190 |
191 | # Add any paths that contain custom static files (such as style sheets) here,
192 | # relative to this directory. They are copied after the builtin static files,
193 | # so a file named "default.css" will overwrite the builtin "default.css".
194 | #html_static_path = ['_static']
195 |
196 | # Add any extra paths that contain custom files (such as robots.txt or
197 | # .htaccess) here, relative to this directory. These files are copied
198 | # directly to the root of the documentation.
199 | #
200 | # html_extra_path = []
201 |
202 | # If not None, a 'Last updated on:' timestamp is inserted at every page
203 | # bottom, using the given strftime format.
204 | # The empty string is equivalent to '%b %d, %Y'.
205 | #
206 | # html_last_updated_fmt = None
207 |
208 | # If true, SmartyPants will be used to convert quotes and dashes to
209 | # typographically correct entities.
210 | #
211 | # html_use_smartypants = True
212 |
213 | # Custom sidebar templates, maps document names to template names.
214 | #
215 | # html_sidebars = {}
216 |
217 | # Additional templates that should be rendered to pages, maps page names to
218 | # template names.
219 | #
220 | # html_additional_pages = {}
221 |
222 | # If false, no module index is generated.
223 | #
224 | # html_domain_indices = True
225 |
226 | # If false, no index is generated.
227 | #
228 | # html_use_index = True
229 |
230 | # If true, the index is split into individual pages for each letter.
231 | #
232 | # html_split_index = False
233 |
234 | # If true, links to the reST sources are added to the pages.
235 | #
236 | # html_show_sourcelink = True
237 |
238 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
239 | #
240 | # html_show_sphinx = True
241 |
242 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
243 | #
244 | # html_show_copyright = True
245 |
246 | # If true, an OpenSearch description file will be output, and all pages will
247 | # contain a tag referring to it. The value of this option must be the
248 | # base URL from which the finished HTML is served.
249 | #
250 | # html_use_opensearch = ''
251 |
252 | # This is the file name suffix for HTML files (e.g. ".xhtml").
253 | # html_file_suffix = None
254 |
255 | # Language to be used for generating the HTML full-text search index.
256 | # Sphinx supports the following languages:
257 | # 'da', 'de', 'en', 'es', 'fi', 'fr', 'h', 'it', 'ja'
258 | # 'nl', 'no', 'pt', 'ro', 'r', 'sv', 'tr', 'zh'
259 | #
260 | html_search_language = 'en'
261 |
262 | # A dictionary with options for the search language support, empty by default.
263 | # 'ja' uses this config value.
264 | # 'zh' user can custom change `jieba` dictionary path.
265 | #
266 | # html_search_options = {'type': 'default'}
267 |
268 | # The name of a javascript file (relative to the configuration directory) that
269 | # implements a search results scorer. If empty, the default will be used.
270 | #
271 | # html_search_scorer = 'scorer.js'
272 |
273 | # Output file base name for HTML help builder.
274 | htmlhelp_basename = 'SafeMDPdoc'
275 |
276 | # -- Options for LaTeX output ---------------------------------------------
277 |
278 | latex_elements = {
279 | # The paper size ('letterpaper' or 'a4paper').
280 | #
281 | # 'papersize': 'letterpaper',
282 |
283 | # The font size ('10pt', '11pt' or '12pt').
284 | #
285 | # 'pointsize': '10pt',
286 |
287 | # Additional stuff for the LaTeX preamble.
288 | #
289 | # 'preamble': '',
290 |
291 | # Latex figure (float) alignment
292 | #
293 | # 'figure_align': 'htbp',
294 | }
295 |
296 | # Grouping the document tree into LaTeX files. List of tuples
297 | # (source start file, target name, title,
298 | # author, documentclass [howto, manual, or own class]).
299 | latex_documents = [
300 | (master_doc, 'SafeMDP.tex', 'SafeMDP Documentation',
301 | 'Matteo Turchetta, Felix Berkenkamp, Andreas Krause', 'manual'),
302 | ]
303 |
304 | # The name of an image file (relative to this directory) to place at the top of
305 | # the title page.
306 | #
307 | # latex_logo = None
308 |
309 | # For "manual" documents, if this is true, then toplevel headings are parts,
310 | # not chapters.
311 | #
312 | # latex_use_parts = False
313 |
314 | # If true, show page references after internal links.
315 | #
316 | # latex_show_pagerefs = False
317 |
318 | # If true, show URL addresses after external links.
319 | #
320 | # latex_show_urls = False
321 |
322 | # Documents to append as an appendix to all manuals.
323 | #
324 | # latex_appendices = []
325 |
326 | # If false, no module index is generated.
327 | #
328 | # latex_domain_indices = True
329 |
330 |
331 | # -- Options for manual page output ---------------------------------------
332 |
333 | # One entry per manual page. List of tuples
334 | # (source start file, name, description, authors, manual section).
335 | man_pages = [
336 | (master_doc, 'safemdp', 'SafeMDP Documentation',
337 | [author], 1)
338 | ]
339 |
340 | # If true, show URL addresses after external links.
341 | #
342 | # man_show_urls = False
343 |
344 |
345 | # -- Options for Texinfo output -------------------------------------------
346 |
347 | # Grouping the document tree into Texinfo files. List of tuples
348 | # (source start file, target name, title, author,
349 | # dir menu entry, description, category)
350 | texinfo_documents = [
351 | (master_doc, 'SafeMDP', 'SafeMDP Documentation',
352 | author, 'SafeMDP', 'One line description of project.',
353 | 'Miscellaneous'),
354 | ]
355 |
356 | # Documents to append as an appendix to all manuals.
357 | #
358 | # texinfo_appendices = []
359 |
360 | # If false, no module index is generated.
361 | #
362 | # texinfo_domain_indices = True
363 |
364 | # How to display URL addresses: 'footnote', 'no', or 'inline'.
365 | #
366 | # texinfo_show_urls = 'footnote'
367 |
368 | # If true, do not generate a @detailmenu in the "Top" node's menu.
369 | #
370 | # texinfo_no_detailmenu = False
371 |
372 |
373 | # Example configuration for intersphinx: refer to the Python standard library.
374 | # intersphinx_mapping = {'https://docs.python.org/': None}
375 |
--------------------------------------------------------------------------------
/doc/index.rst:
--------------------------------------------------------------------------------
1 | .. SafeMDP documentation master file, created by
2 | sphinx-quickstart on Thu Jun 16 08:40:23 2016.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Welcome to SafeMDP's documentation!
7 | ===================================
8 |
9 | .. toctree::
10 | :caption: Contents
11 | :maxdepth: 3
12 |
13 | api
14 |
15 |
16 | Indices and tables
17 | ==================
18 |
19 | * :ref:`genindex`
20 | * :ref:`modindex`
21 | * :ref:`search`
22 |
23 |
--------------------------------------------------------------------------------
/doc/requirements.txt:
--------------------------------------------------------------------------------
1 | numpydoc >= 0.5
2 | sphinx_rtd_theme >= 0.1.8
3 |
4 |
--------------------------------------------------------------------------------
/doc/template.rst:
--------------------------------------------------------------------------------
1 | {{ name }}
2 | {{ underline }}
3 |
4 | .. currentmodule:: {{ module }}
5 | .. auto{{ objtype }}:: {{ objname }} {% if objtype == "class" %}
6 | :members:
7 | :inherited-members:
8 | {% endif %}
9 |
--------------------------------------------------------------------------------
/examples/mars/README.md:
--------------------------------------------------------------------------------
1 | The experiments were run using Python 3 and the packages specified in the `requirements.txt` file.
2 |
--------------------------------------------------------------------------------
/examples/mars/generate_plots.py:
--------------------------------------------------------------------------------
1 | from plot_utilities import *
2 | import numpy as np
3 | import cPickle
4 | import time
5 |
6 | # Safe plots
7 | data = np.load("mars safe experiment.npz")
8 | mu = data["mu_alt"]
9 | var = data["var_alt"]
10 | beta = data["beta"]
11 | world_shape = data["world_shape"]
12 | altitudes = data["altitudes"]
13 | X = data["X"]
14 | Y = data["Y"]
15 | coord = data["coord"]
16 | coverage_over_t = data["coverage_over_t"]
17 | S_hat = data["S_hat"]
18 |
19 | plot_dist_from_C(mu, var, beta, altitudes, world_shape)
20 | plot_coverage(coverage_over_t)
21 | plot_paper(altitudes, S_hat, world_shape, './safe_exploration.pdf')
22 |
23 | # Unsafe plot
24 | data = np.load("mars unsafe experiment.npz")
25 | coverage = data["coverage"]
26 | altitudes = data["altitudes"]
27 | visited = data["visited"]
28 | world_shape = data["world_shape"]
29 |
30 | plot_paper(altitudes, visited, world_shape, './no_safe_exploration1.pdf')
31 |
32 | # Random plot
33 | data = np.load("mars random experiment.npz")
34 | coverage = data["coverage"]
35 | altitudes = data["altitudes"]
36 | visited = data["visited"]
37 | world_shape = data["world_shape"]
38 |
39 | plot_paper(altitudes, visited, world_shape, './random_exploration1.pdf')
40 |
41 | # Non ergodic plot
42 | data = np.load("mars non ergodic experiment.npz")
43 | coverage = data["coverage"]
44 | altitudes = data["altitudes"]
45 | S_hat = data["S_hat"]
46 | world_shape = data["world_shape"]
47 |
48 | plot_paper(altitudes, S_hat, world_shape, './no_ergodic_exploration1.pdf')
49 |
50 | # No expanders plot
51 | data = np.load("mars no G experiment.npz")
52 | coverage = data["coverage"]
53 | altitudes = data["altitudes"]
54 | S_hat = data["S_hat"]
55 | world_shape = data["world_shape"]
56 | coverage_over_t = data["coverage_over_t"]
57 | plot_coverage(coverage_over_t)
58 |
59 | plot_paper(altitudes, S_hat, world_shape, './no_G_exploration1.pdf')
--------------------------------------------------------------------------------
/examples/mars/mars.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function
2 |
3 | import sys
4 | import time
5 |
6 | import numpy as np
7 |
8 | from safemdp.grid_world import *
9 | from mars_utilities import (mars_map, initialize_SafeMDP_object,
10 | performance_metrics)
11 | from plot_utilities import *
12 |
13 | print(sys.version)
14 |
15 |
16 | # Control experiments and saving
17 | save_performance = True
18 | random_experiment = True
19 | non_safe_experiment = True
20 | non_ergodic_experiment = True
21 | no_expanders_exploration = True
22 |
############################## SAFE ###########################################

# Get mars data
altitudes, coord, world_shape, step_size, num_of_points = mars_map()

# Initialize object for simulation
start, x, true_S_hat, true_S_hat_epsilon, h_hard = initialize_SafeMDP_object(
    altitudes, coord, world_shape, step_size)

# Initialize for performance storage
time_steps = 525
coverage_over_t = np.empty(time_steps, dtype=float)

# Simulation loop
t = time.time()
unsafe_count = 0
source = start

for i in range(time_steps):

    # Simulation: update safe sets, pick the most uncertain safe (s, a),
    # observe it and travel there along the shortest safe path.
    x.update_sets()
    next_sample = x.target_sample()
    x.add_observation(*next_sample)
    path = shortest_path(source, next_sample, x.graph)
    source = path[-1]

    # Performances
    unsafe_transitions, coverage, false_safe = performance_metrics(path, x,
                                                           true_S_hat_epsilon,
                                                           true_S_hat, h_hard)
    unsafe_count += unsafe_transitions
    coverage_over_t[i] = coverage
    print(coverage, false_safe, unsafe_count, i)

print(str(time.time() - t) + " seconds elapsed")

# Posterior over heights for plotting
mu_alt, var_alt = x.gp.predict(x.coord, include_likelihood=False)


# BUG FIX: this summary belongs to the SAFE experiment; the label previously
# read "----UNSAFE EXPERIMENT---", colliding with the actual unsafe baseline.
print("----SAFE EXPERIMENT---")
print("Number of points for interpolation: " + str(num_of_points))
print("False safe: " + str(false_safe))
print("Unsafe evaluations: " + str(unsafe_count))
print("Coverage: " + str(coverage))

if save_performance:
    file_name = "mars safe experiment"

    np.savez(file_name, false_safe=false_safe, coverage=coverage,
             coverage_over_t=coverage_over_t, mu_alt=mu_alt,
             var_alt=var_alt, altitudes=x.altitudes, S_hat=x.S_hat,
             time_steps=time_steps, world_shape=world_shape, X=x.gp.X,
             Y=x.gp.Y, coord=x.coord, beta=x.beta)
78 |
79 |
########################## NON SAFE ###########################################
if non_safe_experiment:

    # Get mars data
    altitudes, coord, world_shape, step_size, num_of_points = mars_map()

    # Initialize object for simulation
    start, x, true_S_hat, true_S_hat_epsilon, h_hard = initialize_SafeMDP_object(
        altitudes, coord, world_shape, step_size)

    # Assume all transitions are safe
    x.S[:] = True

    # Simulation loop
    unsafe_count = 0
    source = start
    trajectory = []

    for i in range(time_steps):
        x.update_sets()
        next_sample = x.target_sample()
        x.add_observation(*next_sample)
        path = shortest_path(source, next_sample, x.graph)
        source = path[-1]

        # Check safety
        # NOTE(review): unsafe_count is recomputed for the current path only
        # (not accumulated); the loop breaks at the first unsafe path below.
        path_altitudes = x.altitudes[path]
        unsafe_count = np.sum(-np.diff(path_altitudes) < h_hard)

        if unsafe_count == 0:
            trajectory = trajectory + path[:-1]
        else:
            # Keep only the prefix of the path up to the first unsafe step
            trajectory = trajectory + safe_subpath(path, altitudes, h_hard)

        # Convert trajectory to S_hat-like matrix
        visited = path_to_boolean_matrix(trajectory, x.graph, x.S)

        # Normalization factor
        max_size = float(np.count_nonzero(true_S_hat_epsilon))

        # Performance
        coverage = 100 * np.count_nonzero(
            np.logical_and(visited, true_S_hat_epsilon)) / max_size

        # Print (per-iteration progress for the unsafe baseline)
        print("----UNSAFE EXPERIMENT---")
        print("Unsafe evaluations: " + str(unsafe_count))
        print("Coverage: " + str(coverage))

        if unsafe_count > 0:
            break

    if save_performance:
        file_name = "mars unsafe experiment"

        np.savez(file_name, coverage=coverage, altitudes=x.altitudes,
                 visited=visited, world_shape=world_shape)
137 |
############################### RANDOM ########################################
if random_experiment:

    # Get mars data
    altitudes, coord, world_shape, step_size, num_of_points = mars_map()

    # Initialize object for simulation
    start, x, true_S_hat, true_S_hat_epsilon, h_hard = initialize_SafeMDP_object(
        altitudes, coord, world_shape, step_size)

    source = start
    trajectory = [source]

    for i in range(time_steps):

        # Choose action at random (actions are labelled 1..4)
        a = np.random.choice([1, 2, 3, 4])

        # Add resulting state to trajectory by following the out-edge whose
        # metadata matches the sampled action
        for _, next_node, data in x.graph.out_edges(nbunch=source, data=True):
            if data["action"] == a:
                trajectory = trajectory + [next_node]
                source = next_node

    # Check safety over the whole random walk
    path_altitudes = x.altitudes[trajectory]
    unsafe_count = np.sum(-np.diff(path_altitudes) < h_hard)

    # Get trajectory up to first unsafe transition
    trajectory = safe_subpath(trajectory, x.altitudes, h_hard)

    # Convert trajectory to S_hat-like matrix
    visited = path_to_boolean_matrix(trajectory, x.graph, x.S)

    # Normalization factor
    max_size = float(np.count_nonzero(true_S_hat_epsilon))

    # Performance
    coverage = 100 * np.count_nonzero(np.logical_and(visited,
                                                     true_S_hat_epsilon))/max_size
    # Print
    print("----RANDOM EXPERIMENT---")
    print("Unsafe evaluations: " + str(unsafe_count))
    print("Coverage: " + str(coverage))

    if save_performance:
        file_name = "mars random experiment"

        np.savez(file_name, coverage=coverage, altitudes=x.altitudes,
                 visited=visited, world_shape=world_shape)
188 |
############################# NON ERGODIC #####################################
if non_ergodic_experiment:

    # Get mars data
    altitudes, coord, world_shape, step_size, num_of_points = mars_map()

    # Need to remove expanders otherwise next sample will be in G and
    # therefore in S_hat before I can set S_hat = S
    L_non_ergodic = 1000.

    # Initialize object for simulation
    start, x, true_S_hat, true_S_hat_epsilon, h_hard = initialize_SafeMDP_object(
        altitudes, coord, world_shape, step_size, L=L_non_ergodic)

    # Simulation loop
    unsafe_count = 0
    source = start

    for i in range(time_steps):
        x.update_sets()

        # Remove ergodicity properties: treat every safe (s, a) as if it
        # were also reachable/returnable
        x.S_hat = x.S.copy()

        next_sample = x.target_sample()
        x.add_observation(*next_sample)
        try:
            path = shortest_path(source, next_sample, x.graph)
            source = path[-1]

            # Check safety
            path_altitudes = x.altitudes[path]
            unsafe_transitions = np.sum(-np.diff(path_altitudes) < h_hard)
            unsafe_count += unsafe_transitions
        except Exception:
            # Without ergodicity the agent can end up with no safe path to
            # the next target; stop the experiment in that case
            print ("No safe path available")
            break

    # For coverage we consider every state that has at least one action
    # classified as safe
    x.S_hat[:, 0] = np.any(x.S_hat[:, 1:], axis=1)

    # Normalization factor
    max_size = float(np.count_nonzero(true_S_hat_epsilon))

    # Performance
    coverage = 100 * np.count_nonzero(np.logical_and(x.S_hat,
                                                     true_S_hat_epsilon))/max_size

    # Print
    print("----NON ERGODIC EXPERIMENT---")
    print("Unsafe evaluations: " + str(unsafe_count))
    print("Coverage: " + str(coverage))

    if save_performance:
        file_name = "mars non ergodic experiment"

        np.savez(file_name, coverage=coverage, altitudes=x.altitudes,
                 S_hat=x.S_hat, world_shape=world_shape)
248 |
################################## NO EXPANDERS ###############################
if no_expanders_exploration:

    # Get mars data
    altitudes, coord, world_shape, step_size, num_of_points = mars_map()

    # Large Lipschitz constant removes the expanders (see the comment in the
    # NON ERGODIC experiment above, which uses the same value)
    L_no_expanders = 1000.

    # Initialize object for simulation
    start, x, true_S_hat, true_S_hat_epsilon, h_hard = initialize_SafeMDP_object(
        altitudes, coord, world_shape, step_size, L=L_no_expanders)

    # Initialize for performance storage
    coverage_over_t = np.empty(time_steps, dtype=float)

    # Simulation loop
    source = start
    unsafe_count = 0

    # NOTE(review): int(time_steps) is redundant; time_steps is already an int
    for i in range(int(time_steps)):

        # Simulation
        x.update_sets()
        next_sample = x.target_sample()
        x.add_observation(*next_sample)
        path = shortest_path(source, next_sample, x.graph)
        source = path[-1]

        # Performances
        unsafe_transitions, coverage, false_safe = performance_metrics(path, x,
                                                           true_S_hat_epsilon,
                                                           true_S_hat, h_hard)
        unsafe_count += unsafe_transitions
        coverage_over_t[i] = coverage
        print(coverage, false_safe, unsafe_count, i)

    # Print
    print("----NO EXPANDER EXPERIMENT---")
    print("False safe: " + str(false_safe))
    print("Unsafe evaluations: " + str(unsafe_count))
    print("Coverage: " + str(coverage))

    if save_performance:
        file_name = "mars no G experiment"

        np.savez(file_name, false_safe=false_safe, coverage=coverage,
                 coverage_over_t=coverage_over_t, altitudes=x.altitudes, S_hat=x.S_hat,
                 world_shape=world_shape)
297 |
--------------------------------------------------------------------------------
/examples/mars/mars_utilities.py:
--------------------------------------------------------------------------------
1 | from safemdp.grid_world import *
2 | from osgeo import gdal
3 | from scipy import interpolate
4 | import numpy as np
5 | import os
6 | import matplotlib.pyplot as plt
7 | import GPy
8 |
9 |
10 | __all__ = ['mars_map', 'initialize_SafeMDP_object', 'performance_metrics']
11 |
12 |
def mars_map(plot_map=False, interpolation=False):
    """
    Extract the map for the simulation from the HiRISE data. If the HiRISE
    data is not in the current folder it will be downloaded and converted to
    GeoTiff extension with gdal.

    Parameters
    ----------
    plot_map: bool
        If true plots the map that will be used for exploration
    interpolation: bool
        If true the data of the map will be interpolated with splines to
        obtain a finer grid

    Returns
    -------
    altitudes: np.array
        1-d vector with altitudes for each node
    coord: np.array
        Coordinate of the map we use for exploration
    world_shape: tuple
        Size of the grid world (rows, columns)
    step_size: tuple
        Step size for the grid (row, column)
    num_of_points: int
        Interpolation parameter. Indicates the scaling factor for the
        original step size
    """

    # Define the dimension of the map we want to investigate and its resolution
    world_shape = (120, 70)
    step_size = (1., 1.)

    # Download and convert to GEOtiff Mars data
    if not os.path.exists('./mars.tif'):
        if not os.path.exists("./mars.IMG"):
            # BUG FIX: urllib.urlretrieve is Python-2-only; in Python 3 it
            # lives in urllib.request.
            try:
                from urllib import urlretrieve  # Python 2
            except ImportError:
                from urllib.request import urlretrieve  # Python 3

            print('Downloading MARS map, this make take a while...')
            # Download the IMG file
            urlretrieve(
                "http://www.uahirise.org/PDS/DTM/PSP/ORB_010200_010299"
                "/PSP_010228_1490_ESP_016320_1490"
                "/DTEEC_010228_1490_016320_1490_A01.IMG", "mars.IMG")

        # Convert to tif
        print('Converting map to geotif...')
        os.system("gdal_translate -of GTiff ./mars.IMG ./mars.tif")
        print('Done')

    # Read the data with gdal module
    gdal.UseExceptions()
    ds = gdal.Open("./mars.tif")
    band = ds.GetRasterBand(1)
    elevation = band.ReadAsArray()

    # Extract the area of interest
    startX = 2890
    startY = 1955
    altitudes = np.copy(elevation[startX:startX + world_shape[0],
                        startY:startY + world_shape[1]])

    # Center the data around the mid-point of its range
    mean_val = (np.max(altitudes) + np.min(altitudes)) / 2.
    altitudes[:] = altitudes - mean_val

    # Define coordinates
    n, m = world_shape
    step1, step2 = step_size
    xx, yy = np.meshgrid(np.linspace(0, (n - 1) * step1, n),
                         np.linspace(0, (m - 1) * step2, m), indexing="ij")
    coord = np.vstack((xx.flatten(), yy.flatten())).T

    # Interpolate data
    if interpolation:

        # Interpolating function
        # BUG FIX: the second (column) axis previously used step1 instead of
        # step2; harmless for the default square step but wrong in general.
        spline_interpolator = interpolate.RectBivariateSpline(
            np.linspace(0, (n - 1) * step1, n),
            np.linspace(0, (m - 1) * step2, m), altitudes)

        # New size and resolution
        num_of_points = 1
        world_shape = tuple([(x - 1) * num_of_points + 1 for x in world_shape])
        step_size = tuple([x / num_of_points for x in step_size])

        # New coordinates and altitudes
        n, m = world_shape
        step1, step2 = step_size
        xx, yy = np.meshgrid(np.linspace(0, (n - 1) * step1, n),
                             np.linspace(0, (m - 1) * step2, m), indexing="ij")
        coord = np.vstack((xx.flatten(), yy.flatten())).T

        altitudes = spline_interpolator(np.linspace(0, (n - 1) * step1, n),
                                        np.linspace(0, (m - 1) * step2, m))
    else:
        num_of_points = 1

    # Plot area
    if plot_map:
        plt.imshow(altitudes.T, origin="lower", interpolation="nearest")
        plt.colorbar()
        plt.show()
    altitudes = altitudes.flatten()

    return altitudes, coord, world_shape, step_size, num_of_points
119 |
120 |
def initialize_SafeMDP_object(altitudes, coord, world_shape, step_size, L=0.2,
                              beta=2, length=14.5, sigma_n=0.075, start_x=60,
                              start_y=61):
    """
    Set up the GridWorld SafeMDP instance used in the mars experiments.

    Parameters
    ----------
    altitudes: np.array
        1-d vector with altitudes for each node
    coord: np.array
        Coordinate of the map we use for exploration
    world_shape: tuple
        Size of the grid world (rows, columns)
    step_size: tuple
        Step size for the grid (row, column)
    L: float
        Lipschitz constant to compute expanders
    beta: float
        Scaling factor for confidence intervals
    length: float
        Lengthscale for Matern kernel
    sigma_n: float
        Standard deviation for gaussian noise
    start_x: int
        x coordinate of the starting point
    start_y: int
        y coordinate of the starting point

    Returns
    -------
    start: int
        Node number of initial state
    x: SafeMDP
        Instance of the SafeMDP class for the mars exploration problem
    true_S_hat: np.array
        True S_hat if safety feature is known with no error and h_hard is used
    true_S_hat_epsilon: np.array
        True S_hat if safety feature is known up to epsilon and h is used
    h_hard: float
        True safety threshold. It can be different from the safety threshold
        used for classification in case the agent needs to use extra caution
        (in our experiments h=25 deg, h_hard=30 deg)
    """

    # Safety threshold: max downhill step for a 25 deg slope
    # (pi/9 = 20 deg plus pi/36 = 5 deg), expressed as altitude difference
    h = -np.tan(np.pi / 9. + np.pi / 36.) * step_size[0]

    # Initial node (row-major index of grid cell (start_x, start_y))
    start = start_x * world_shape[1] + start_y

    # Initial safe sets
    S_hat0 = compute_S_hat0(start, world_shape, 4, altitudes,
                            step_size, h)
    S0 = np.copy(S_hat0)
    S0[:, 0] = True

    # Initialize GP with a single sample taken at the starting location
    X = coord[start, :].reshape(1, 2)
    Y = altitudes[start].reshape(1, 1)
    kernel = GPy.kern.Matern52(input_dim=2, lengthscale=length, variance=100.)
    lik = GPy.likelihoods.Gaussian(variance=sigma_n ** 2)
    gp = GPy.core.GP(X, Y, kernel, lik)

    # Define SafeMDP object
    x = GridWorld(gp, world_shape, step_size, beta, altitudes, h, S0,
                  S_hat0, L, update_dist=25)

    # Add samples about actions from starting node
    # (5 noisy observations of each of the 4 actions)
    for i in range(5):
        x.add_observation(start, 1)
        x.add_observation(start, 2)
        x.add_observation(start, 3)
        x.add_observation(start, 4)

    # Drop the first sample, which was only needed to construct the GP
    x.gp.set_XY(X=x.gp.X[1:, :], Y=x.gp.Y[1:, :]) # Necessary for results as in
    # paper

    # True safe set for false safe (h_hard = 30 deg slope, tan(pi/6))
    h_hard = -np.tan(np.pi / 6.) * step_size[0]
    true_S = compute_true_safe_set(x.world_shape, x.altitudes, h_hard)
    true_S_hat = compute_true_S_hat(x.graph, true_S, x.initial_nodes)

    # True safe set for completeness
    epsilon = sigma_n * beta
    true_S_epsilon = compute_true_safe_set(x.world_shape, x.altitudes,
                                           x.h + epsilon)
    true_S_hat_epsilon = compute_true_S_hat(x.graph, true_S_epsilon,
                                            x.initial_nodes)

    return start, x, true_S_hat, true_S_hat_epsilon, h_hard
211 |
212 |
def performance_metrics(path, x, true_S_hat_epsilon, true_S_hat, h_hard):
    """Compute safety and exploration metrics for one traversed path.

    Parameters
    ----------
    path: np.array
        Nodes of the shortest safe path
    x: SafeMDP
        Instance of the SafeMDP class for the mars exploration problem
    true_S_hat_epsilon: np.array
        True S_hat if safety feature is known up to epsilon and h is used
    true_S_hat: np.array
        True S_hat if safety feature is known with no error and h_hard is used
    h_hard: float
        True safety threshold. It can be different from the safety threshold
        used for classification in case the agent needs to use extra caution
        (in our experiments h=25 deg, h_hard=30 deg)

    Returns
    -------
    unsafe_transitions: int
        Number of unsafe transitions along the path
    coverage: float
        Percentage of coverage of true_S_hat_epsilon
    false_safe: int
        Number of misclassifications (classifying something as safe when it
        actually is unsafe according to h_hard)
    """

    # A transition is unsafe when the negated altitude difference drops
    # below the hard threshold.
    heights = x.altitudes[path]
    unsafe_transitions = np.sum(-np.diff(heights) < h_hard)

    # Coverage: share of the epsilon-true safe set currently in S_hat.
    denom = float(np.count_nonzero(true_S_hat_epsilon))
    overlap = np.count_nonzero(np.logical_and(x.S_hat, true_S_hat_epsilon))
    coverage = 100 * overlap / denom

    # Entries classified safe that are not safe under h_hard.
    false_safe = np.count_nonzero(np.logical_and(x.S_hat, ~true_S_hat))

    return unsafe_transitions, coverage, false_safe
254 |
--------------------------------------------------------------------------------
/examples/mars/plot_utilities.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | from matplotlib import rcParams
4 | from matplotlib.colors import ColorConverter
5 |
6 |
def paper_figure(figsize, subplots=None, **kwargs):
    """Create a figure (or figure + axes) with the paper's font settings.

    Parameters
    ----------
    figsize: tuple
        Figure size in centimeters (width, height).
    subplots: tuple, optional
        (rows, cols) forwarded to plt.subplots; when None a bare figure is
        created instead.

    Returns
    -------
    fig, or (fig, ax) when `subplots` is given.
    """

    def cm2inch(cm_tupl):
        """Convert a tuple of centimeters to a tuple of inches."""
        inch = 2.54
        # BUG FIX: return a tuple rather than a generator expression;
        # matplotlib's figsize expects a sequence, and a generator can only
        # be consumed once.
        return tuple(cm / inch for cm in cm_tupl)

    if subplots is None:
        fig = plt.figure(figsize=cm2inch(figsize))
    else:
        fig, ax = plt.subplots(subplots[0], subplots[1],
                               figsize=cm2inch(figsize), **kwargs)

    # Parameters for IJRR
    params = {
        'font.family': 'serif',
        'font.serif': ['Times',
                       'Palatino',
                       'New Century Schoolbook',
                       'Bookman',
                       'Computer Modern Roman'],
        'font.sans-serif': ['Times',
                            'Helvetica',
                            'Avant Garde',
                            'Computer Modern Sans serif'],
        'text.usetex': True,
        # Make sure mathcal doesn't use the Times style
        'text.latex.preamble':
        r'\DeclareMathAlphabet{\mathcal}{OMS}{cmsy}{m}{n}',

        'axes.labelsize': 9,
        'axes.linewidth': .75,

        'font.size': 9,
        'legend.fontsize': 9,
        'xtick.labelsize': 8,
        'ytick.labelsize': 8,

        # 'figure.dpi': 150,
        # 'savefig.dpi': 600,
        'legend.numpoints': 1,
    }

    rcParams.update(params)

    if subplots is None:
        return fig
    else:
        return fig, ax
57 |
58 |
def format_figure(axis, cbar=None):
    """Apply the paper's axis styling (spines, ticks, labels) in place.

    Parameters
    ----------
    axis: matplotlib axes to format.
    cbar: matplotlib colorbar, optional
        When given, its label, ticks and spines are styled as well.
    """
    axis.spines['top'].set_linewidth(0.1)
    axis.spines['top'].set_alpha(0.5)
    axis.spines['right'].set_linewidth(0.1)
    axis.spines['right'].set_alpha(0.5)
    axis.xaxis.set_ticks_position('bottom')
    axis.yaxis.set_ticks_position('left')

    # Tick positions are hard-coded for the 120x70 mars map; y labels are
    # reversed to match the flipped image orientation.
    axis.set_xticks(np.arange(0, 121, 30))
    yticks = np.arange(0, 71, 35)
    axis.set_yticks(yticks)
    axis.set_yticklabels(['{0}'.format(tick) for tick in yticks[::-1]])

    axis.set_xlabel(r'distance [m]')
    axis.set_ylabel(r'distance [m]', labelpad=2)
    if cbar is not None:
        cbar.set_label(r'altitude [m]')

        cbar.set_ticks(np.arange(0, 36, 10))

        # BUG FIX: dict.itervalues() is Python-2-only; values() works on
        # both Python 2 and 3.
        for spine in cbar.ax.spines.values():
            spine.set_linewidth(0.1)
        cbar.ax.yaxis.set_tick_params(color=emulate_color('k', 0.7))

    plt.tight_layout(pad=0.1)
84 |
85 |
def emulate_color(color, alpha=1, background_color=(1, 1, 1)):
    """Take an RGBA color and an RGB background, return the emulated RGB color.

    The RGBA color with transparency alpha is converted to an RGB color via
    emulation in front of the background_color.
    """
    converter = ColorConverter()
    foreground = converter.to_rgb(color)
    background = converter.to_rgb(background_color)

    # Alpha-blend each channel against the background.
    blended = []
    for fg_channel, bg_channel in zip(foreground, background):
        blended.append((1 - alpha) * bg_channel + alpha * fg_channel)
    return blended
97 |
98 |
def plot_paper(altitudes, S_hat, world_shape, fileName=""):
    """
    Plots for NIPS paper

    Parameters
    ----------
    altitudes: np.array
        True value of the altitudes of the map
    S_hat: np.array
        Safe and ergodic set
    world_shape: tuple
        Size of the grid world (rows, columns)
    fileName: string
        Name of the file to save the plot. If empty string the plot is not
        saved
    """
    # Size of figures and colormap
    tw = cw = 13.968
    cmap = 'jet'
    alpha = 1.
    alpha_world = 0.25
    size_wb = np.array([cw / 2.2, tw / 4.])
    #size_wb = np.array([cw / 4.2, cw / 4.2])

    # Shift altitudes so the minimum is zero.
    # BUG FIX: use an out-of-place subtraction; the previous in-place `-=`
    # silently mutated the caller's array.
    altitudes = altitudes - np.nanmin(altitudes)
    vmin, vmax = (np.nanmin(altitudes), np.nanmax(altitudes))
    origin = 'lower'

    fig = paper_figure(size_wb)

    # Copy altitudes for different alpha values; NaN-out unexplored states so
    # they show through the faded background layer only.
    altitudes2 = altitudes.copy()
    altitudes2[~S_hat[:, 0]] = np.nan

    axis = fig.gca()

    # Plot world (faded background)
    c = axis.imshow(np.reshape(altitudes, world_shape).T, origin=origin,
                    vmin=vmin, vmax=vmax, cmap=cmap, alpha=alpha_world)

    cbar = plt.colorbar(c)
    #cbar = None

    # Plot explored area (fully opaque overlay)
    plt.imshow(np.reshape(altitudes2, world_shape).T, origin=origin, vmin=vmin,
               vmax=vmax, interpolation='nearest', cmap=cmap, alpha=alpha)
    format_figure(axis, cbar)

    # Save figure
    if fileName:
        plt.savefig(fileName, transparent=False, format="pdf")
    plt.show()
154 |
155 |
156 |
157 | def plot_dist_from_C(mu, var, beta, altitudes, world_shape):
158 | """
159 | Image plot of the distance of the true safety feature from the
160 | confidence interval. Distance is equal to 0 if the true r(s) lies within
161 | C(s), it is > 0 if r(s)>u(s) and < 0 if r(s) 0] = diff_u[diff_u > 0]
189 |
190 | # Below l
191 | diff_l = altitudes - l
192 | dist_from_confidence_interval[diff_l < 0] = diff_l[diff_l < 0]
193 |
194 | # Define limits
195 | max_value = np.max(dist_from_confidence_interval)
196 | min_value = np.min(dist_from_confidence_interval)
197 | limit = np.max([max_value, np.abs(min_value)])
198 |
199 | # Plot
200 | plt.figure()
201 | plt.imshow(
202 | np.reshape(dist_from_confidence_interval, world_shape).T, origin='lower',
203 | interpolation='nearest', vmin=-limit, vmax=limit)
204 | title = "Distance from confidence interval"
205 | plt.title(title)
206 | plt.colorbar()
207 | plt.show()
208 |
209 |
def plot_coverage(coverage_over_t):
    """
    Plots coverage of true_S_hat_epsilon as a function of time

    """
    plt.figure()
    plt.plot(coverage_over_t)
    plt.title("Coverage over time")
    plt.show()
220 |
221 |
--------------------------------------------------------------------------------
/examples/mars/requirements.txt:
--------------------------------------------------------------------------------
1 | cython == 0.23.4
2 | matplotlib == 1.5.1
3 | mkl == 11.3.1
4 | mkl-service == 1.1.2
5 | networkx == 1.11
6 | numpy == 1.10.0
7 | scipy == 0.17.0
GDAL >= 2.0.0
9 |
--------------------------------------------------------------------------------
/examples/sample.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function, absolute_import
2 |
3 | import time
4 |
5 | import GPy
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 |
9 | from safemdp.grid_world import (compute_true_safe_set, compute_S_hat0,
10 | compute_true_S_hat, draw_gp_sample, GridWorld)
11 |
# Define world
world_shape = (20, 20)
step_size = (0.5, 0.5)

# Define GP
noise = 0.001
kernel = GPy.kern.RBF(input_dim=2, lengthscale=(2., 2.), variance=1.,
                      ARD=True)
lik = GPy.likelihoods.Gaussian(variance=noise ** 2)
lik.constrain_bounded(1e-6, 10000.)

# Sample and plot world: the ground truth altitudes are drawn from the GP
altitudes, coord = draw_gp_sample(kernel, world_shape, step_size)
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot_trisurf(coord[:, 0], coord[:, 1], altitudes)
plt.show()

# Define coordinates
n, m = world_shape
step1, step2 = step_size
xx, yy = np.meshgrid(np.linspace(0, (n - 1) * step1, n),
                     np.linspace(0, (m - 1) * step2, m),
                     indexing="ij")
coord = np.vstack((xx.flatten(), yy.flatten())).T

# Safety threshold
h = -0.25

# Lipschitz constant
L = 0

# Scaling factor for confidence interval
beta = 2

# Data to initialize GP: one random noisy sample
n_samples = 1
ind = np.random.choice(range(altitudes.size), n_samples)
X = coord[ind, :]
Y = altitudes[ind].reshape(n_samples, 1) + np.random.randn(n_samples,
                                                           1)
gp = GPy.core.GP(X, Y, kernel, lik)

# Initialize safe sets (column 0 = state, columns 1-4 = actions)
S0 = np.zeros((np.prod(world_shape), 5), dtype=bool)
S0[:, 0] = True
S_hat0 = compute_S_hat0(np.nan, world_shape, 4, altitudes,
                        step_size, h)

# Define SafeMDP object
x = GridWorld(gp, world_shape, step_size, beta, altitudes, h, S0, S_hat0,
              L)

# Insert samples from (s, a) in S_hat0: pick a state with at least one safe
# action and observe three of its safe actions at random
tmp = np.arange(x.coord.shape[0])
s_vec_ind = np.random.choice(tmp[np.any(x.S_hat[:, 1:], axis=1)])
tmp = np.arange(1, x.S.shape[1])
actions = tmp[x.S_hat[s_vec_ind, 1:].squeeze()]
for i in range(3):
    x.add_observation(s_vec_ind, np.random.choice(actions))

# Remove samples used for GP initialization
x.gp.set_XY(x.gp.X[n_samples:, :], x.gp.Y[n_samples:])

# Exploration loop: update safe sets, then sample the most informative
# safe state-action pair
t = time.time()
for i in range(100):
    x.update_sets()
    next_sample = x.target_sample()
    x.add_observation(*next_sample)
    # x.compute_graph_lazy()
    # plt.figure(1)
    # plt.clf()
    # nx.draw_networkx(x.graph)
    # plt.show()
    print("Iteration: " + str(i))

print(str(time.time() - t) + "seconds elapsed")

# Ground-truth safe sets for evaluation
true_S = compute_true_safe_set(x.world_shape, x.altitudes, x.h)
true_S_hat = compute_true_S_hat(x.graph, true_S, x.initial_nodes)

# Plot safe sets
x.plot_S(x.S_hat)
x.plot_S(true_S_hat)

# Classification performance
print(np.sum(np.logical_and(true_S_hat, np.logical_not(
    x.S_hat)))) # in true S_hat and not S_hat
print(np.sum(np.logical_and(x.S_hat, np.logical_not(true_S_hat))))
101 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | GPy >= 0.8.0
2 | numpy >= 1.7.2
3 | scipy >= 0.16
4 | matplotlib >= 1.5.0
5 | networkx >= 1.1
6 |
7 |
--------------------------------------------------------------------------------
/safemdp/SafeMDP_class.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function, absolute_import
2 |
3 | import numpy as np
4 |
5 | from .utilities import max_out_degree
6 |
7 | __all__ = ['SafeMDP', 'link_graph_and_safe_set', 'reachable_set',
8 | 'returnable_set']
9 |
10 |
class SafeMDP(object):
    """Base class for safe exploration in MDPs.

    This class only provides basic options to compute the safely reachable
    and returnable sets. The actual update of the safety feature must be done
    in a class that inherits from `SafeMDP`. See `safempd.GridWorld` for an
    example.

    Parameters
    ----------
    graph: networkx.DiGraph
        The graph that models the MDP. Each edge has an attribute `safe` in its
        metadata, which determines the safety of the transition.
    gp: GPy.core.GPRegression
        A Gaussian process model that can be used to determine the safety of
        transitions. Exact structure depends heavily on the usecase.
    S_hat0: boolean array
        An array that has True on the ith position if the ith node in the graph
        is part of the safe set.
    h: float
        The safety threshold.
    L: float
        The lipschitz constant
    beta: float, optional
        The confidence interval used by the GP model.
    """
    def __init__(self, graph, gp, S_hat0, h, L, beta=2):
        super(SafeMDP, self).__init__()
        # Scalar for gp confidence intervals
        self.beta = beta

        # Safety threshold
        self.h = h

        # Lipschitz constant
        self.L = L

        # GP model
        self.gp = gp

        self.graph = graph
        self.graph_reverse = self.graph.reverse()

        num_nodes = self.graph.number_of_nodes()
        num_edges = max_out_degree(graph)
        # One column per action plus column 0 for "node is in the set".
        safe_set_size = (num_nodes, num_edges + 1)

        # BUG FIX: `np.bool` is a deprecated alias removed in NumPy 1.24;
        # the builtin `bool` is the equivalent dtype.
        self.reach = np.empty(safe_set_size, dtype=bool)
        self.G = np.empty(safe_set_size, dtype=bool)

        self.S_hat = S_hat0.copy()
        self.S_hat0 = self.S_hat.copy()
        self.initial_nodes = self.S_hat0[:, 0].nonzero()[0].tolist()

    def compute_S_hat(self):
        """Compute the safely reachable set given the current safe_set."""
        self.reach[:] = False
        reachable_set(self.graph, self.initial_nodes, out=self.reach)

        self.S_hat[:] = False
        returnable_set(self.graph, self.graph_reverse, self.initial_nodes,
                       out=self.S_hat)

        # S_hat keeps only state-actions that are both reachable and
        # returnable.
        self.S_hat &= self.reach

    def add_gp_observations(self, x_new, y_new):
        """Add observations to the gp model."""
        # Update GP with observations
        self.gp.set_XY(np.vstack((self.gp.X,
                                  x_new)),
                       np.vstack((self.gp.Y,
                                  y_new)))
83 |
84 |
def link_graph_and_safe_set(graph, safe_set):
    """Link the safe set to the graph model.

    Each edge stores a length-1 slice (a view) into `safe_set`, so later
    updates to the safe set are automatically visible in the edge metadata.

    Parameters
    ----------
    graph: nx.DiGraph()
    safe_set: np.array
        Safe set. For each node the edge (i, j) under action (a) is linked to
        safe_set[i, a]
    """
    # COMPAT: networkx 1.x provides edges_iter(); networkx >= 2.0 removed it
    # in favor of edges().
    edge_iter = (graph.edges_iter() if hasattr(graph, 'edges_iter')
                 else graph.edges())
    for node, next_node in edge_iter:
        edge = graph[node][next_node]
        # Slice (not scalar index) so the edge holds a view into safe_set.
        edge['safe'] = safe_set[node:node + 1, edge['action']]
98 |
99 |
def reachable_set(graph, initial_nodes, out=None):
    """
    Compute the safe, reachable set of a graph

    Parameters
    ----------
    graph: nx.DiGraph
        Directed graph. Each edge must have associated action metadata,
        which specifies the action that this edge corresponds to.
        Each edge has an attribute ['safe'], which is a boolean that
        indicates safety
    initial_nodes: list
        List of the initial, safe nodes that are used as a starting point to
        compute the reachable set.
    out: np.array
        The array to write the results to. Is assumed to be False everywhere
        except at the initial nodes

    Returns
    -------
    reachable_set: np.array
        Boolean array that indicates whether a node belongs to the reachable
        set. Only returned when `out` is None.
    """

    if not initial_nodes:
        raise AttributeError('Set of initial nodes needs to be non-empty.')

    if out is None:
        # BUG FIX: `np.bool` was removed in NumPy 1.24; use builtin `bool`.
        visited = np.zeros((graph.number_of_nodes(),
                            max_out_degree(graph) + 1),
                           dtype=bool)
    else:
        visited = out

    # All nodes in the initial set are visited
    visited[initial_nodes, 0] = True

    stack = list(initial_nodes)

    # COMPAT: networkx 1.x exposes edges_iter(); networkx >= 2.0 only edges().
    edge_iter = (graph.edges_iter if hasattr(graph, 'edges_iter')
                 else graph.edges)

    # TODO: rather than checking if things are safe, specify a safe subgraph?
    # Breadth-first expansion over safe, unvisited edges.
    while stack:
        node = stack.pop(0)
        # iterate over edges going away from node
        for _, next_node, data in edge_iter(node, data=True):
            action = data['action']
            if not visited[node, action] and data['safe']:
                visited[node, action] = True
                if not visited[next_node, 0]:
                    stack.append(next_node)
                    visited[next_node, 0] = True

    if out is None:
        return visited
154 |
155 |
def returnable_set(graph, reverse_graph, initial_nodes, out=None):
    """
    Compute the safe, returnable set of a graph

    Parameters
    ----------
    graph: nx.DiGraph
        Directed graph. Each edge must have associated action metadata,
        which specifies the action that this edge corresponds to.
        Each edge has an attribute ['safe'], which is a boolean that
        indicates safety
    reverse_graph: nx.DiGraph
        The reversed directed graph, `graph.reverse()`
    initial_nodes: list
        List of the initial, safe nodes that are used as a starting point to
        compute the returnable set.
    out: np.array
        The array to write the results to. Is assumed to be False everywhere
        except at the initial nodes

    Returns
    -------
    returnable_set: np.array
        Boolean array that indicates whether a node belongs to the returnable
        set. Only returned when `out` is None.
    """

    if not initial_nodes:
        raise AttributeError('Set of initial nodes needs to be non-empty.')

    if out is None:
        # BUG FIX: `np.bool` was removed in NumPy 1.24; use builtin `bool`.
        visited = np.zeros((graph.number_of_nodes(),
                            max_out_degree(graph) + 1),
                           dtype=bool)
    else:
        visited = out

    # All nodes in the initial set are visited
    visited[initial_nodes, 0] = True

    stack = list(initial_nodes)

    # COMPAT: networkx 1.x exposes edges_iter(); networkx >= 2.0 only edges().
    rev_edge_iter = (reverse_graph.edges_iter
                     if hasattr(reverse_graph, 'edges_iter')
                     else reverse_graph.edges)

    # TODO: rather than checking if things are safe, specify a safe subgraph?
    # Breadth-first expansion backwards along safe, unvisited edges.
    while stack:
        node = stack.pop(0)
        # iterate over edges going into node
        for _, prev_node in rev_edge_iter(node):
            data = graph.get_edge_data(prev_node, node)
            if not visited[prev_node, data['action']] and data['safe']:
                visited[prev_node, data['action']] = True
                if not visited[prev_node, 0]:
                    stack.append(prev_node)
                    visited[prev_node, 0] = True

    if out is None:
        return visited
212 |
--------------------------------------------------------------------------------
/safemdp/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | The `safemdp` package implements tools for safe exploration in finite MDPs.
3 |
4 | Main classes
5 | ------------
6 |
7 | These classes provide the main functionality for the safe exploration
8 |
9 | .. autosummary::
10 | :template: template.rst
11 | :toctree:
12 |
13 | SafeMDP
14 | link_graph_and_safe_set
15 | reachable_set
16 | returnable_set
17 |
18 | Grid world
19 | ----------
20 |
21 | Some additional functionality specific to gridworlds.
22 |
23 | .. autosummary::
24 | :template: template.rst
25 | :toctree:
26 |
27 | GridWorld
28 | states_to_nodes
29 | nodes_to_states
30 | draw_gp_sample
31 | grid_world_graph
32 | grid
33 | compute_true_safe_set
34 | compute_true_S_hat
35 | compute_S_hat0
36 | shortest_path
37 | path_to_boolean_matrix
38 | safe_subpath
39 |
40 |
41 | Utilities
42 | ---------
43 |
44 | The following are utilities to make testing and working with the library more
45 | pleasant.
46 |
47 | .. autosummary::
48 | :template: template.rst
49 | :toctree:
50 |
51 | DifferenceKernel
52 | max_out_degree
53 | """
54 |
55 | from __future__ import absolute_import
56 |
57 | from .SafeMDP_class import *
58 | from .utilities import *
59 | from .grid_world import *
60 |
61 | # Add everything to __all__
62 | __all__ = [s for s in dir() if not s.startswith('_')]
63 |
64 | # Import test after __all__ (no documentation)
65 | from numpy.testing import Tester
66 | test = Tester().test
67 |
--------------------------------------------------------------------------------
/safemdp/grid_world.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function, absolute_import
2 |
3 | import networkx as nx
4 | import numpy as np
5 | from matplotlib import pyplot as plt
6 | from scipy.spatial.distance import cdist
7 |
8 | from .utilities import DifferenceKernel
9 | from .SafeMDP_class import (reachable_set, returnable_set, SafeMDP,
10 | link_graph_and_safe_set)
11 |
12 |
13 | __all__ = ['compute_true_safe_set', 'compute_true_S_hat', 'compute_S_hat0',
14 | 'grid_world_graph', 'grid', 'GridWorld', 'draw_gp_sample',
15 | 'states_to_nodes', 'nodes_to_states', 'shortest_path',
16 | 'path_to_boolean_matrix', 'safe_subpath']
17 |
18 |
def compute_true_safe_set(world_shape, altitude, h):
    """
    Computes the safe set given a perfect knowledge of the map

    Parameters
    ----------
    world_shape: tuple
        Size of the grid world (rows, columns)
    altitude: np.array
        1-d vector with altitudes for each node
    h: float
        Safety threshold for height differences

    Returns
    -------
    true_safe: np.array
        Boolean array n_states x (n_actions + 1). Column 0 (the state
        itself) is always safe; columns 1-4 mark the four movement actions.
    """

    # `np.bool` was deprecated in NumPy 1.20 and removed in 1.24; the builtin
    # `bool` is the supported, equivalent dtype.
    true_safe = np.zeros((world_shape[0] * world_shape[1], 5), dtype=bool)

    altitude_grid = altitude.reshape(world_shape)

    # Reshape so that first dimensions are actions, the rest is the grid
    # world. This is a view into `true_safe`, so assignments write through.
    safe_grid = true_safe.T.reshape((5,) + world_shape)

    # Height difference (next height - current height) --> positive if downhill
    up_diff = altitude_grid[:, :-1] - altitude_grid[:, 1:]
    right_diff = altitude_grid[:-1, :] - altitude_grid[1:, :]

    # State are always safe
    true_safe[:, 0] = True

    # Actions 3/4 traverse the same edges as 1/2 in the opposite direction,
    # hence the negated differences.
    safe_grid[1, :, :-1] = up_diff >= h
    safe_grid[2, :-1, :] = right_diff >= h
    safe_grid[3, :, 1:] = -up_diff >= h
    safe_grid[4, 1:, :] = -right_diff >= h

    return true_safe
58 |
59 |
def dynamics_vec_ind(states_vec_ind, action, world_shape):
    """
    Apply one action to an array of states given in vector representation.

    States that would leave the grid stay where they are.

    Parameters
    ----------
    states_vec_ind: np.array
        Vector indexes of the states for which to compute the successor
    action: int
        Action performed by the agent (1: right, 2: down, 3: left, 4: up)
    world_shape: tuple
        Size of the grid world (rows, columns)

    Returns
    -------
    next_states_vec_ind: np.array
        Vector indexes of the successor states

    Raises
    ------
    ValueError
        If ``action`` is not in {1, 2, 3, 4}.
    """
    n_rows, n_cols = world_shape
    successors = np.copy(states_vec_ind)

    if action == 1:
        # Move right; blocked when the increment wraps to the next row
        successors += 1
        blocked = np.mod(successors, n_cols) == 0
    elif action == 2:
        # Move down; blocked when falling off the bottom of the grid
        successors += n_cols
        blocked = successors >= n_cols * n_rows
    elif action == 3:
        # Move left; blocked for states on the leftmost column
        successors -= 1
        blocked = np.mod(states_vec_ind, n_cols) == 0
    elif action == 4:
        # Move up; blocked when the index becomes negative
        successors -= n_cols
        blocked = successors <= -1
    else:
        raise ValueError("Unknown action")

    # Blocked transitions are self-loops
    successors[blocked] = states_vec_ind[blocked]
    return successors
100 |
101 |
def compute_S_hat0(s, world_shape, n_actions, altitudes, step_size, h):
    """
    Compute a valid initial safe seed.

    A seed is considered valid when the start state has at least one safe
    action along each grid axis (one of actions 1/3 and one of 2/4).

    Parameters
    ---------
    s: int or nan
        Vector index of the state where we start computing the safe seed
        from. If it is equal to nan, a state is chosen at random
    world_shape: tuple
        Size of the grid world (rows, columns)
    n_actions: int
        Number of actions available to the agent
    altitudes: np.array
        It contains the flattened n x m matrix where the altitudes of all
        the points in the map are stored
    step_size: tuple
        step sizes along each direction to create a linearly spaced grid
    h: float
        Safety threshold (presumably non-positive, since the condition below
        compares a non-positive slope against it — TODO confirm)

    Returns
    ------
    S_hat: np.array
        Boolean array n_states x (n_actions + 1). All False if no valid
        seed exists at ``s``.
    """
    # Initialize
    n, m = world_shape
    n_states = n * m
    S_hat = np.zeros((n_states, n_actions + 1), dtype=bool)

    # In case an initial state is given
    if not np.isnan(s):
        S_hat[s, 0] = True
        valid_initial_seed = False
        vertical = False
        horizontal = False
        altitude_prev = altitudes[s]
        # Wrap the scalar state so dynamics_vec_ind can operate on arrays
        if not isinstance(s, np.ndarray):
            s = np.array([s])

        # Loop through actions
        for action in range(1, n_actions + 1):

            # Compute next state to check steepness
            next_vec_ind = dynamics_vec_ind(s, action, world_shape)
            altitude_next = altitudes[next_vec_ind]

            # Safe in both directions: the (negated) absolute slope must
            # exceed the threshold, and the move must not be a self-loop.
            # NOTE(review): both sides are length-1 arrays here, so the `!=`
            # comparison in the `if` relies on single-element truthiness.
            if s != next_vec_ind and -np.abs(altitude_prev - altitude_next) / \
                    step_size[0] >= h:
                S_hat[s, action] = True
                S_hat[next_vec_ind, 0] = True
                # The transition is safe in both directions, so mark the
                # reverse action from the successor state as well.
                S_hat[next_vec_ind, reverse_action(action)] = True
                if action == 1 or action == 3:
                    vertical = True
                if action == 2 or action == 4:
                    horizontal = True

        # Valid only if safe movement exists along both axes
        if vertical and horizontal:
            valid_initial_seed = True

        if valid_initial_seed:
            return S_hat
        else:
            print ("No valid initial seed starting from this state")
            S_hat[:] = False
            return S_hat

    # If an explicit initial state is not given, sample random start states
    # until one of them yields a valid (non-empty) seed.
    else:
        while np.all(np.logical_not(S_hat)):
            initial_state = np.random.choice(n_states)
            S_hat = compute_S_hat0(initial_state, world_shape, n_actions,
                                   altitudes, step_size, h)
        return S_hat
177 |
178 |
def reverse_action(action):
    """Return the action opposite to ``action`` (1 <-> 3, 2 <-> 4)."""
    opposite = np.mod(action + 2, 4)
    # Action indices are 1-based, so a result of 0 wraps around to 4
    return 4 if opposite == 0 else opposite
186 |
187 |
def grid_world_graph(world_size):
    """Create a graph that represents a grid world.

    In the grid world there are four actions, (1, 2, 3, 4), which correspond
    to going (up, right, down, left) in the x-y plane. The states are
    ordered so that `np.arange(np.prod(world_size)).reshape(world_size)`
    corresponds to a matrix where increasing the row index corresponds to the
    x direction in the graph, and increasing y index corresponds to the y
    direction.

    Parameters
    ----------
    world_size: tuple
        The size of the grid world (rows, columns)

    Returns
    -------
    graph: nx.DiGraph()
        The directed graph representing the grid world.
    """
    grid_nodes = np.arange(np.prod(world_size)).reshape(world_size)

    graph = nx.DiGraph()

    # Each entry is (action id, source nodes, destination nodes); actions
    # 3 and 4 are the reverses of 1 and 2.
    moves = ((1, grid_nodes[:, :-1], grid_nodes[:, 1:]),
             (2, grid_nodes[:-1, :], grid_nodes[1:, :]),
             (3, grid_nodes[:, 1:], grid_nodes[:, :-1]),
             (4, grid_nodes[1:, :], grid_nodes[:-1, :]))

    for action_id, sources, targets in moves:
        graph.add_edges_from(zip(sources.reshape(-1),
                                 targets.reshape(-1)),
                             action=action_id)

    return graph
234 |
235 |
def compute_true_S_hat(graph, safe_set, initial_nodes, reverse_graph=None):
    """
    Compute the true safe set with reachability and returnability.

    Parameters
    ----------
    graph: nx.DiGraph
    safe_set: np.array
    initial_nodes: list of int
    reverse_graph: nx.DiGraph
        graph.reverse()

    Returns
    -------
    true_safe: np.array
        Boolean array n_states x (n_actions + 1).
    """
    # Work on a copy so the caller's graph is untouched
    linked_graph = graph.copy()
    link_graph_and_safe_set(linked_graph, safe_set)
    if reverse_graph is None:
        reverse_graph = linked_graph.reverse()

    # The true S_hat is the intersection of the returnable and reachable sets
    true_s_hat = returnable_set(linked_graph, reverse_graph, initial_nodes)
    true_s_hat &= reachable_set(linked_graph, initial_nodes)
    return true_s_hat
261 |
262 |
class GridWorld(SafeMDP):
    """
    Grid world with Safe exploration

    Parameters
    ----------
    gp: GPy.core.GP
        Gaussian process that expresses our current belief over the safety
        feature
    world_shape: shape
        Tuple that contains the shape of the grid world n x m
    step_size: tuple of floats
        Tuple that contains the step sizes along each direction to
        create a linearly spaced grid
    beta: float
        Scaling factor to determine the amplitude of the confidence
        intervals (NOTE(review): currently ignored — see __init__)
    altitudes: np.array
        It contains the flattened n x m matrix where the altitudes
        of all the points in the map are stored
    h: float
        Safety threshold
    S0: np.array
        n_states x (n_actions + 1) array of booleans that indicates which
        states (first column) and which state-action pairs belong to the
        initial safe seed. Notice that, by convention we initialize all
        the states to be safe
    S_hat0: np.array or nan
        n_states x (n_actions + 1) array of booleans that indicates which
        states (first column) and which state-action pairs belong to the
        initial safe seed and satisfy recovery and reachability properties.
        If it is nan, such a boolean matrix is computed during
        initialization
    noise: float
        Standard deviation of the measurement noise
        (NOTE(review): not an actual ``__init__`` argument — stale doc?)
    L: float
        Lipschitz constant to compute expanders
    update_dist: int
        Distance in unweighted graph used for confidence interval update.
        A sample will only influence other nodes within this distance.
    """
    def __init__(self, gp, world_shape, step_size, beta, altitudes, h, S0,
                 S_hat0, L, update_dist=0):

        # Safe set
        self.S = S0.copy()
        graph = grid_world_graph(world_shape)
        link_graph_and_safe_set(graph, self.S)
        # NOTE(review): the `beta` constructor argument is ignored here and a
        # hard-coded beta=2 is passed to the superclass — confirm intended.
        super(GridWorld, self).__init__(graph, gp, S_hat0, h, L, beta=2)

        self.altitudes = altitudes
        self.world_shape = world_shape
        self.step_size = step_size
        self.update_dist = update_dist

        # Grids for the map (physical coordinates of every node)
        self.coord = grid(self.world_shape, self.step_size)

        # Distances (pairwise Euclidean, used to compute expanders)
        self.distance_matrix = cdist(self.coord, self.coord)

        # Confidence intervals: lower/upper bounds on the safety feature.
        # States in the initial safe set get a lower bound at the threshold.
        self.l = np.empty(self.S.shape, dtype=float)
        self.u = np.empty(self.S.shape, dtype=float)
        self.l[:] = -np.inf
        self.u[:] = np.inf
        self.l[self.S] = h

        # Prediction with difference of altitudes: precompute the node pairs
        # connected by actions 1/3 (_up) and 2/4 (_right).
        states_ind = np.arange(np.prod(self.world_shape))
        states_grid = states_ind.reshape(world_shape)

        self._prev_up = states_grid[:, :-1].flatten()
        self._next_up = states_grid[:, 1:].flatten()
        self._prev_right = states_grid[:-1, :].flatten()
        self._next_right = states_grid[1:, :].flatten()

        # Stacked coordinate pairs fed to the DifferenceKernel predictions
        self._mat_up = np.hstack((self.coord[self._prev_up, :],
                                  self.coord[self._next_up, :]))
        self._mat_right = np.hstack((self.coord[self._prev_right, :],
                                     self.coord[self._next_right, :]))

    def update_confidence_interval(self, jacobian=False):
        """
        Updates the lower and the upper bound of the confidence intervals
        using then posterior distribution over the gradients of the altitudes

        Parameters
        ----------
        jacobian: bool
            If True, use the GP's gradient (Jacobian) prediction instead of
            the DifferenceKernel-based altitude-difference prediction.

        Returns
        -------
        l: np.array
            lower bound of the safety feature (mean - beta*std)
        u: np.array
            upper bound of the safety feature (mean + beta*std)
        """
        if jacobian:
            # Predict safety feature from the posterior over gradients
            mu, s = self.gp.predict_jacobian(self.coord, full_cov=False)
            mu = np.squeeze(mu)

            # Confidence interval scaled by beta
            s = self.beta * np.sqrt(s)

            # State are always safe
            self.l[:, 0] = self.u[:, 0] = self.h

            # Update safety feature
            # NOTE(review): `[:, ::-1]` swaps the two gradient components to
            # match the action ordering, and the sign flips between action
            # pairs (1,2) and (3,4) — confirm axis convention.
            self.l[:, [1, 2]] = -mu[:, ::-1] - s[:, ::-1]
            self.l[:, [3, 4]] = mu[:, ::-1] - s[:, ::-1]

            self.u[:, [1, 2]] = -mu[:, ::-1] + s[:, ::-1]
            self.u[:, [3, 4]] = mu[:, ::-1] + s[:, ::-1]

        elif self.update_dist > 0:
            # Local update: only refresh bounds near the latest samples.
            # States are always safe
            self.l[:, 0] = self.u[:, 0] = self.h

            # Extract last two sampled states in the grid
            last_states = self.gp.X[-2:]
            last_nodes = states_to_nodes(last_states, self.world_shape,
                                         self.step_size)

            # Extract nodes within `update_dist` hops of either sample
            nodes1 = nx.single_source_shortest_path(self.graph, last_nodes[0],
                                                    self.update_dist).keys()
            nodes2 = nx.single_source_shortest_path(self.graph, last_nodes[1],
                                                    self.update_dist).keys()
            update_nodes = np.union1d(nodes1, nodes2)
            subgraph = self.graph.subgraph(update_nodes)

            # Sort states to be updated according to actions
            prev_up = []
            next_up = []
            prev_right = []
            next_right = []

            for node1, node2, act in subgraph.edges_iter(data='action'):
                if act == 2:
                    prev_right.append(node1)
                    next_right.append(node2)
                elif act == 1:
                    prev_up.append(node1)
                    next_up.append(node2)

            mat_up = np.hstack((self.coord[prev_up, :],
                                self.coord[next_up, :]))
            mat_right = np.hstack((self.coord[prev_right, :],
                                   self.coord[next_right, :]))

            # Update confidence for nodes around last sample
            mu_up, s_up = self.gp.predict(mat_up,
                                          kern=DifferenceKernel(self.gp.kern),
                                          full_cov=False)
            s_up = self.beta * np.sqrt(s_up)

            # Action 3 from the successor traverses the same edge backwards,
            # hence the negated mean.
            self.l[prev_up, 1, None] = mu_up - s_up
            self.u[prev_up, 1, None] = mu_up + s_up

            self.l[next_up, 3, None] = -mu_up - s_up
            self.u[next_up, 3, None] = -mu_up + s_up

            mu_right, s_right = self.gp.predict(mat_right,
                                                kern=DifferenceKernel(
                                                    self.gp.kern),
                                                full_cov=False)
            s_right = self.beta * np.sqrt(s_right)

            self.l[prev_right, 2, None] = mu_right - s_right
            self.u[prev_right, 2, None] = mu_right + s_right

            self.l[next_right, 4, None] = -mu_right - s_right
            self.u[next_right, 4, None] = -mu_right + s_right

        else:
            # Global update over all precomputed node pairs.
            # Initialize to unsafe
            self.l[:] = self.u[:] = self.h - 1

            # States are always safe
            self.l[:, 0] = self.u[:, 0] = self.h

            # Actions up and down
            mu_up, s_up = self.gp.predict(self._mat_up,
                                          kern=DifferenceKernel(self.gp.kern),
                                          full_cov=False)
            s_up = self.beta * np.sqrt(s_up)

            self.l[self._prev_up, 1, None] = mu_up - s_up
            self.u[self._prev_up, 1, None] = mu_up + s_up

            self.l[self._next_up, 3, None] = -mu_up - s_up
            self.u[self._next_up, 3, None] = -mu_up + s_up

            # Actions left and right
            mu_right, s_right = self.gp.predict(self._mat_right,
                                                kern=DifferenceKernel(
                                                    self.gp.kern),
                                                full_cov=False)
            s_right = self.beta * np.sqrt(s_right)
            self.l[self._prev_right, 2, None] = mu_right - s_right
            self.u[self._prev_right, 2, None] = mu_right + s_right

            self.l[self._next_right, 4, None] = -mu_right - s_right
            self.u[self._next_right, 4, None] = -mu_right + s_right

    def compute_expanders(self):
        """Compute the expanders based on the current estimate of S_hat."""
        self.G[:] = False

        for action in range(1, self.S_hat.shape[1]):

            # action-specific safe set
            s_hat = self.S_hat[:, action]

            # Extract distance from safe points to non safe ones
            distance = self.distance_matrix[np.ix_(s_hat, ~self.S[:, action])]

            # Update expanders for this particular action: a pair expands if
            # its optimistic value minus the Lipschitz bound could certify
            # some currently-unsafe state.
            self.G[s_hat, action] = np.any(
                self.u[s_hat, action, None] - self.L * distance >= self.h,
                axis=1)

    def update_sets(self):
        """
        Update the sets S, S_hat and G taking with the available observation
        """
        self.update_confidence_interval()
        # self.S[:] = self.l >= self.h
        # Monotone update: the safe set only ever grows
        self.S |= self.l >= self.h

        self.compute_S_hat()
        self.compute_expanders()

    def plot_S(self, safe_set, action=0):
        """
        Plot the set of safe states

        Parameters
        ----------
        safe_set: np.array(dtype=bool)
            n_states x (n_actions + 1) array of boolean values that indicates
            the safe set
        action: int
            The action for which we want to plot the safe set.
        """
        plt.figure(action)
        plt.imshow(np.reshape(safe_set[:, action], self.world_shape).T,
                   origin='lower', interpolation='nearest', vmin=0, vmax=1)
        plt.title('action {0}'.format(action))
        plt.show()

    def add_observation(self, node, action):
        """
        Add an observation of the given state-action pair.

        Observing the pair (s, a) means adding an observation of the altitude
        at s and an observation of the altitude at f(s, a)

        Parameters
        ----------
        node: int
            Node index
        action: int
            Action index
        """
        # Observation of next state
        # NOTE(review): if no outgoing edge carries `action`, `next_node` is
        # simply the last edge's target (or undefined for an isolated node)
        # — confirm callers always pass a valid action.
        for _, next_node, data in self.graph.edges_iter(node, data=True):
            if data['action'] == action:
                break

        self.add_gp_observations(self.coord[[node, next_node], :],
                                 self.altitudes[[node, next_node], None])

    def target_sample(self):
        """
        Compute the next target (s, a) to sample (highest uncertainty within
        G or S_hat)

        Returns
        -------
        node: int
            The next node to sample
        action: int
            The next action to sample
        """
        if np.any(self.G):
            # Extract elements in G
            expander_id = np.nonzero(self.G)

            # Compute uncertainty
            w = self.u[self.G] - self.l[self.G]

            # Find max uncertainty
            max_id = np.argmax(w)

        else:
            # NOTE(review): implicit string concatenation here produces
            # "...S_hatinstead." (missing space) — cosmetic message bug.
            print('No expanders, using most uncertain element in S_hat'
                  'instead.')

            # Extract elements in S_hat
            expander_id = np.nonzero(self.S_hat)

            # Compute uncertainty
            w = self.u[self.S_hat] - self.l[self.S_hat]

            # Find max uncertainty
            max_id = np.argmax(w)

        return expander_id[0][max_id], expander_id[1][max_id]
570 |
571 |
def states_to_nodes(states, world_shape, step_size):
    """Convert physical states to node numbers.

    Parameters
    ----------
    states: np.array
        States with physical coordinates
    world_shape: tuple
        The size of the grid_world
    step_size: tuple
        The step size of the grid world

    Returns
    -------
    nodes: np.array
        The node indices corresponding to the states
    """
    states = np.asanyarray(states)
    # Round to the nearest grid index. `np.int` was deprecated in NumPy 1.20
    # and removed in 1.24; the builtin `int` is the supported equivalent.
    node_indices = np.rint(states / step_size).astype(int)
    # Row-major flattening: column index + n_cols * row index
    return node_indices[:, 1] + world_shape[1] * node_indices[:, 0]
592 |
593 |
def nodes_to_states(nodes, world_shape, step_size):
    """Convert node numbers to physical states.

    Parameters
    ----------
    nodes: np.array
        Node indices of the grid world
    world_shape: tuple
        The size of the grid_world
    step_size: np.array
        The step size of the grid world

    Returns
    -------
    states: np.array
        The states in physical coordinates
    """
    nodes = np.asanyarray(nodes)
    step_size = np.asanyarray(step_size)
    # Invert the row-major flattening, then scale indices by the step size
    row_indices = nodes // world_shape[1]
    col_indices = nodes % world_shape[1]
    return np.vstack((row_indices, col_indices)).T * step_size
615 |
616 |
def grid(world_shape, step_size):
    """
    Creates grids of coordinates and indices of state space

    Parameters
    ----------
    world_shape: tuple
        Size of the grid world (rows, columns)
    step_size: tuple
        Phyiscal step size in the grid world

    Returns
    -------
    states_coord: np.array
        (n*m) x 2 array containing the coordinates of the states
    """
    # Enumerate every node and map it to its physical coordinates
    n_nodes = world_shape[0] * world_shape[1]
    node_ids = np.arange(n_nodes)
    return nodes_to_states(node_ids, world_shape, step_size)
637 |
638 |
def draw_gp_sample(kernel, world_shape, step_size):
    """
    Draws a sample from a Gaussian process distribution over a user
    specified grid

    Parameters
    ----------
    kernel: GPy kernel
        Defines the GP we draw a sample from
    world_shape: tuple
        Shape of the grid we use for sampling
    step_size: tuple
        Step size along any axis to find linearly spaced points

    Returns
    -------
    sample: np.array
        A zero-mean GP sample evaluated at every grid point
    coord: np.array
        The grid coordinates the sample was drawn over
    """
    # Compute linearly spaced grid
    coords = grid(world_shape, step_size)
    n_points = coords.shape[0]

    # Kernel covariance with a small diagonal jitter for numerical stability
    covariance = kernel.K(coords) + np.eye(n_points) * 1e-10

    # Draw a sample from the zero-mean GP
    sample = np.random.multivariate_normal(np.zeros(n_points), covariance)
    return sample, coords
660 |
661 |
def shortest_path(source, next_sample, G):
    """
    Computes shortest safe path from a source to the next state-action pair
    the agent needs to sample

    Parameters
    ----------
    source: int
        Staring node for the path
    next_sample: (int, int)
        Next state-action pair the agent needs to sample. First entry is the
        number that indicates the state. Second entry indicates the action
    G: networkx DiGraph
        Graph that indicates the dynamics. It is linked to S matrix

    Returns
    -------
    path: list
        shortest safe path
    """
    target, action = next_sample

    # Restrict the dynamics to the edges currently marked as safe
    safe_edges = [edge for edge in G.edges_iter(data=True) if edge[2]['safe']]
    graph_safe = nx.DiGraph(safe_edges)

    # Shortest safe path from the source to the target state
    path = nx.astar_path(graph_safe, source, target)

    # Append the successor reached by executing `action` at the target
    for _, succ, data in graph_safe.out_edges(nbunch=target, data=True):
        if data["action"] == action:
            path = path + [succ]

    return path
697 |
698 |
def path_to_boolean_matrix(path, graph, S):
    """
    Computes a S-like matrix for approaches where performances is based
    on the trajectory of the agent (e.g. unsafe or random exploration)

    Parameters
    ----------
    path: np.array
        Contains the nodes that are visited along the path
    graph: networkx.DiGraph
        Graph that indicates the dynamics
    S: np.array
        Array describing the safe set (needed for initialization)

    Returns
    -------
    bool_mat: np.array
        S-like array that is true for all the states and state-action pairs
        along the path
    """

    # Initialize matrix
    bool_mat = np.zeros_like(S, dtype=bool)

    # Walk over consecutive node pairs and mark the visited states and
    # state-action pairs.
    for prev, succ in zip(path[:-1], path[1:]):
        for _, next_node, data in graph.out_edges(nbunch=prev, data=True):
            if next_node == succ:
                bool_mat[prev, 0] = True
                bool_mat[prev, data["action"]] = True
                break
        bool_mat[succ, 0] = True

    return bool_mat
736 |
737 |
def safe_subpath(path, altitudes, h):
    """
    Computes the maximum subpath of path along which the safety constraint is
    not violated

    Parameters
    ----------
    path: np.array
        Contains the nodes that are visited along the path
    altitudes: np.array
        1-d vector with altitudes for each node
    h: float
        Safety threshold

    Returns
    -------
    subpath: list
        Maximum prefix of ``path`` that fulfills the safety constraint
    """
    # The starting node is always part of the subpath
    subpath = [path[0]]

    # Extend the prefix until the first unsafe transition
    for prev, succ in zip(path[:-1], path[1:]):
        if altitudes[prev] - altitudes[succ] >= h:
            subpath.append(succ)
        else:
            break
    return subpath
771 |
--------------------------------------------------------------------------------
/safemdp/test.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function, absolute_import
2 |
3 | import unittest
4 | import GPy
5 | import numpy as np
6 | import networkx as nx
7 | from numpy.testing import *
8 |
9 | from .utilities import *
10 |
11 | from safemdp.SafeMDP_class import reachable_set, returnable_set
12 | from safemdp.grid_world import compute_true_safe_set, grid_world_graph
13 | from .SafeMDP_class import link_graph_and_safe_set
14 |
15 |
class DifferenceKernelTest(unittest.TestCase):

    @staticmethod
    def _check(gp, x1, x2):
        """Compare the gp difference predictions on X1 and X2.

        Parameters
        ----------
        gp: GPy.core.GP
        x1: np.array
        x2: np.array
        """
        num_points = x1.shape[0]

        # Reference computation: predict all points jointly, then apply the
        # linear operator that takes pairwise differences.
        diff_operator = np.hstack((np.eye(num_points), -np.eye(num_points)))
        mean_ref, cov_ref = gp.predict_noiseless(np.vstack((x1, x2)),
                                                 full_cov=True)
        mean_ref = diff_operator.dot(mean_ref)
        cov_ref = np.linalg.multi_dot((diff_operator, cov_ref,
                                       diff_operator.T))

        # DifferenceKernel prediction, diagonal of the covariance only
        mean_diff, var_diff = gp.predict_noiseless(
            np.hstack((x1, x2)),
            kern=DifferenceKernel(gp.kern),
            full_cov=False)

        assert_allclose(mean_ref, mean_diff)
        assert_allclose(np.diag(cov_ref), var_diff.squeeze())

        # DifferenceKernel prediction, full covariance matrix
        mean_diff, cov_diff = gp.predict_noiseless(
            np.hstack((x1, x2)),
            kern=DifferenceKernel(gp.kern),
            full_cov=True)

        assert_allclose(mean_ref, mean_diff)
        assert_allclose(cov_ref, cov_diff, atol=1e-12)

    def test_1d(self):
        """Test the difference kernel for a 1D input."""
        # GP model fitted to a simple quadratic
        kernel = GPy.kern.RBF(input_dim=1, lengthscale=0.05)
        likelihood = GPy.likelihoods.Gaussian(variance=0.005 ** 2)
        train_x = np.linspace(0, 1, 5)[:, None]
        train_y = train_x ** 2
        gp = GPy.core.GP(train_x, train_y, kernel, likelihood)

        # Test points: the second set is slightly offset from the first
        num_test = 10
        x1 = np.linspace(0, 1, num_test)[:, None]
        x2 = x1 + np.linspace(0, 0.1, num_test)[::-1, None]

        self._check(gp, x1, x2)

    def test_2d(self):
        """Test the difference kernel for a 2D input."""
        # GP model fitted to a simple 2d quadratic
        kernel = GPy.kern.RBF(input_dim=2, lengthscale=0.05)
        likelihood = GPy.likelihoods.Gaussian(variance=0.005 ** 2)
        train_x = np.hstack((np.linspace(0, 1, 5)[:, None],
                             np.linspace(0.5, 1.5, 5)[:, None]))
        train_y = train_x[:, [0]] ** 2 + train_x[:, [1]] ** 2
        gp = GPy.core.GP(train_x, train_y, kernel, likelihood)

        # Test points: the second set is slightly offset from the first
        num_test = 10
        x1 = np.hstack((np.linspace(0, 1, num_test)[:, None],
                        np.linspace(0.5, 1.5, num_test)[:, None]))
        x2 = x1 + np.hstack((np.linspace(0, 0.1, num_test)[::-1, None],
                             np.linspace(0., 0.1, num_test)[::-1, None]))

        self._check(gp, x1, x2)
88 |
89 |
class MaxOutDegreeTest(unittest.TestCase):
    def test_all(self):
        """Test the max_out_degree function."""
        graph = nx.DiGraph()
        graph.add_edges_from(((0, 1),
                              (1, 2),
                              (2, 3),
                              (3, 1)))
        # `assert_(val, msg)` only checks that `val` is truthy and treats the
        # second argument as a message, so the original expected values were
        # never actually compared. Use assert_equal instead.
        assert_equal(max_out_degree(graph), 1)

        graph.add_edge(0, 2)
        assert_equal(max_out_degree(graph), 2)

        # (2, 3) already exists; DiGraph ignores duplicate edges
        assert_equal(max_out_degree(graph), 2)

        graph.add_edge(3, 2)
        assert_equal(max_out_degree(graph), 2)

        # (3, 1) is also a duplicate, so the maximum stays at 2 (the original
        # expected value of 3 was wrong but hidden by the assert_ misuse)
        graph.add_edge(3, 1)
        assert_equal(max_out_degree(graph), 2)
111 |
112 |
class ReachableSetTest(unittest.TestCase):
    """Tests for reachable_set on a small hand-built graph."""

    def __init__(self, *args, **kwargs):
        super(ReachableSetTest, self).__init__(*args, **kwargs)
        # 3
        # ^
        # |
        # 0 --> 1 --> 2 --> 0
        # ^
        # |
        # 4
        self.graph = nx.DiGraph()
        self.graph.add_edges_from([(0, 1),
                                   (1, 2),
                                   (2, 0),
                                   (4, 1)], action=1)
        self.graph.add_edge(2, 3, action=2)

        # `np.bool` was removed from numpy (1.24); the builtin `bool` is the
        # supported, equivalent dtype.
        self.safe_set = np.ones((self.graph.number_of_nodes(),
                                 max_out_degree(self.graph) + 1),
                                dtype=bool)
        link_graph_and_safe_set(self.graph, self.safe_set)
        self.true = np.zeros(self.safe_set.shape[0], dtype=bool)

    def setUp(self):
        # Reset the safe set before each test case
        self.safe_set[:] = True

    def _check(self):
        # Compare the states reachable from node 0 with the expected set
        reach = reachable_set(self.graph, [0])
        assert_equal(reach[:, 0], self.true)

    def test_all_safe(self):
        """Test reachable set if everything is safe"""
        self.true[:] = [1, 1, 1, 1, 0]
        self._check()

    def test_unsafe1(self):
        """Test safety aspect"""
        self.safe_set[1, 1] = False
        self.true[:] = [1, 1, 0, 0, 0]
        self._check()

    def test_unsafe2(self):
        """Test safety aspect"""
        self.safe_set[2, 2] = False
        self.true[:] = [1, 1, 1, 0, 0]
        self._check()

    def test_unsafe3(self):
        """Test safety aspect"""
        self.safe_set[2, 1] = False
        self.true[:] = [1, 1, 1, 1, 0]
        self._check()

    def test_unsafe4(self):
        """Test safety aspect"""
        self.safe_set[4, 1] = False
        self.true[:] = [1, 1, 1, 1, 0]
        self._check()

    def test_out(self):
        """Test writing the output"""
        self.safe_set[2, 2] = False
        self.true[:] = [1, 1, 1, 0, 0]
        out = np.zeros_like(self.safe_set)
        reachable_set(self.graph, [0], out=out)
        assert_equal(out[:, 0], self.true)

    def test_error(self):
        """Check error condition"""
        with assert_raises(AttributeError):
            reachable_set(self.graph, [])
185 |
186 |
class ReturnableSetTest(unittest.TestCase):
    """Tests for returnable_set on a small hand-built graph."""

    def __init__(self, *args, **kwargs):
        super(ReturnableSetTest, self).__init__(*args, **kwargs)
        # Test graph (edge 2 -> 3 has action 2, all other edges action 1):
        #             3
        #             ^
        #             |
        # 0 --> 1 --> 2 --> 0
        #       ^
        #       |
        #       4
        self.graph = nx.DiGraph()
        self.graph.add_edges_from([(0, 1),
                                   (1, 2),
                                   (2, 0),
                                   (4, 1)], action=1)
        self.graph.add_edge(2, 3, action=2)
        # returnable_set also needs the reversed graph.
        self.graph_rev = self.graph.reverse()

        # np.bool was deprecated in numpy 1.20 and removed in 1.24;
        # the builtin bool is equivalent.
        self.safe_set = np.ones((self.graph.number_of_nodes(),
                                 max_out_degree(self.graph) + 1),
                                dtype=bool)
        link_graph_and_safe_set(self.graph, self.safe_set)
        self.true = np.zeros(self.safe_set.shape[0], dtype=bool)

    def setUp(self):
        # Reset the safe set before each test.
        self.safe_set[:] = True

    def _check(self):
        """Compare the returnable set for node 0 against self.true."""
        ret = returnable_set(self.graph, self.graph_rev, [0])
        assert_equal(ret[:, 0], self.true)

    def test_all_safe(self):
        """Test returnable set if everything is safe"""
        self.true[:] = [1, 1, 1, 0, 1]
        self._check()

    def test_unsafe1(self):
        """Test safety aspect"""
        self.safe_set[1, 1] = False
        self.true[:] = [1, 0, 1, 0, 0]
        self._check()

    def test_unsafe2(self):
        """Test safety aspect"""
        self.safe_set[2, 1] = False
        self.true[:] = [1, 0, 0, 0, 0]
        self._check()

    def test_unsafe3(self):
        """Test safety aspect"""
        self.safe_set[2, 2] = False
        self.true[:] = [1, 1, 1, 0, 1]
        self._check()

    def test_unsafe4(self):
        """Test safety aspect"""
        self.safe_set[4, 1] = False
        self.true[:] = [1, 1, 1, 0, 0]
        self._check()

    def test_out(self):
        """Test writing the output into a preallocated array"""
        self.safe_set[1, 1] = False
        self.true[:] = [1, 0, 1, 0, 0]
        out = np.zeros_like(self.safe_set)
        returnable_set(self.graph, self.graph_rev, [0], out=out)
        assert_equal(out[:, 0], self.true)

    def test_error(self):
        """An empty initial set must raise an error"""
        # The original called reachable_set here (copy-paste slip); this
        # class should exercise returnable_set's error path.
        with assert_raises(AttributeError):
            returnable_set(self.graph, self.graph_rev, [])
260 |
261 |
class GridWorldGraphTest(unittest.TestCase):
    """Test the grid_world_graph function."""

    def test(self):
        """Simple test"""
        # Node layout of the 2 x 3 grid:
        # 1 2 3
        # 4 5 6
        graph = grid_world_graph((2, 3))

        # Expected edges, grouped by their action attribute.
        edges_by_action = {
            1: [(1, 2), (2, 3), (4, 5), (5, 6)],
            2: [(1, 4), (2, 5), (3, 6)],
            3: [(2, 1), (3, 2), (5, 4), (6, 5)],
            4: [(4, 1), (5, 2), (6, 3)],
        }

        graph_true = nx.DiGraph()
        for action, edges in edges_by_action.items():
            graph_true.add_edges_from(edges, action=action)

        assert_(nx.is_isomorphic(graph, graph_true))
291 |
292 |
class TestTrueSafeSet(unittest.TestCase):
    """Tests for compute_true_safe_set on a 2 x 3 altitude grid."""

    def test_differences_safe(self):
        """All altitude differences are within the threshold of -1."""
        altitudes = np.array([[1, 2, 3],
                              [2, 3, 4]])
        safe = compute_true_safe_set((2, 3), altitudes.reshape(-1), -1)
        # np.bool was deprecated in numpy 1.20 and removed in 1.24;
        # the builtin bool is equivalent.
        true_safe = np.array([[1, 1, 1, 1, 1, 1],
                              [1, 1, 0, 1, 1, 0],
                              [1, 1, 1, 0, 0, 0],
                              [0, 1, 1, 0, 1, 1],
                              [0, 0, 0, 1, 1, 1]],
                             dtype=bool).T

        assert_equal(safe, true_safe)

    def test_differences_unsafe(self):
        """Some altitude drops exceed the threshold and become unsafe."""
        altitudes = np.array([[1, 0, 3],
                              [2, 3, 0]])
        safe = compute_true_safe_set((2, 3), altitudes.reshape(-1), -1)
        true_safe = np.array([[1, 1, 1, 1, 1, 1],
                              [1, 0, 0, 1, 1, 0],
                              [1, 0, 1, 0, 0, 0],
                              [0, 1, 1, 0, 1, 0],
                              [0, 0, 0, 1, 1, 0]],
                             dtype=bool).T
        assert_equal(safe, true_safe)
319 |
320 |
# Run all test cases in this module when executed directly.
if __name__ == '__main__':
    unittest.main()
323 |
--------------------------------------------------------------------------------
/safemdp/utilities.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | import numpy as np
4 |
5 | __all__ = ['DifferenceKernel', 'max_out_degree']
6 |
7 |
class DifferenceKernel(object):
    """
    A fake kernel that predicts differences between two function values.

    Given a GP based on measurements, we aim to predict the difference
    between the function values at two different test points, X1 and X2;
    that is, we want mean and variance of f(X1) - f(X2).  Using this fake
    kernel, this can be achieved with
    `mean, var = gp.predict(np.hstack((X1, X2)),
                            kern=DifferenceKernel(gp.kern))`

    Parameters
    ----------
    kernel: GPy.kern.*
        The kernel used by the GP
    """

    def __init__(self, kernel):
        self.kern = kernel

    def K(self, x1, x2=None):
        """Equivalent of kern.K

        If only `x1` is passed, it is assumed to contain both sets of
        points stacked side by side (the pairs whose differences we are
        computing).  Otherwise `x2` contains these extended states (see
        PosteriorExact._raw_predict in
        GPy/inference/latent_function_inference0/posterior.py).

        Parameters
        ----------
        x1: np.array
        x2: np.array
        """
        d = self.kern.input_dim
        if x2 is None:
            # Covariance of the difference with itself:
            # k(a, a) + k(b, b) - k(a, b) - k(b, a)
            a = x1[:, :d]
            b = x1[:, d:]
            return (self.kern.K(a) + self.kern.K(b)
                    - self.kern.K(a, b) - self.kern.K(b, a))
        # Cross-covariance between plain points x1 and difference pairs x2:
        # k(x1, a) - k(x1, b)
        a = x2[:, :d]
        b = x2[:, d:]
        return self.kern.K(x1, a) - self.kern.K(x1, b)

    def Kdiag(self, x):
        """Equivalent of kern.Kdiag for the difference prediction.

        Parameters
        ----------
        x: np.array
        """
        d = self.kern.input_dim
        a = x[:, :d]
        b = x[:, d:]
        # var(f(a) - f(b)) = k(a, a) + k(b, b) - 2 k(a, b), elementwise.
        return (self.kern.Kdiag(a) + self.kern.Kdiag(b)
                - 2 * np.diag(self.kern.K(a, b)))
63 |
64 |
def max_out_degree(graph):
    """Compute the maximum out_degree of a graph

    Parameters
    ----------
    graph: nx.DiGraph

    Returns
    -------
    max_out_degree: int
        The maximum out_degree of the graph

    Raises
    ------
    ValueError
        If the graph has no nodes (same as the original generator-based
        implementation, since max() of an empty sequence raises).
    """
    # `out_degree_iter` was removed in networkx 2.x.  `graph.out_degree()`
    # returns a dict in networkx 1.x and a DegreeView of (node, degree)
    # pairs in 2.x; dict() normalizes both so .values() gives the degrees.
    return max(dict(graph.out_degree()).values())
81 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup
3 |
4 |
def read(fname):
    """Read a file relative to the directory containing this script.

    Parameters
    ----------
    fname: str
        Path of the file, relative to the directory of setup.py.

    Returns
    -------
    str
        The full contents of the file.
    """
    # Use a context manager so the file handle is closed deterministically
    # (the original relied on the garbage collector to close it).
    with open(os.path.join(os.path.dirname(__file__), fname)) as f:
        return f.read()
8 |
# Package metadata and dependencies; `read` inlines README.md as the
# long description shown on PyPI.
setup(
    name="safemdp",
    version="1.0",
    author="Matteo Turchetta, Felix Berkenkamp",
    author_email="matteotu@ethz.ch, befelix@ethz.ch",
    description=("Safe exploration in MDPs"),
    license="MIT",
    # NOTE(review): this is the placeholder URL from the packaging example
    # docs -- presumably it should point at the project's repository;
    # confirm and replace.
    url="http://packages.python.org/an_example_pypi_project",
    packages=['safemdp'],
    long_description=read('README.md'),
    install_requires=[
        'GPy >= 0.8.0',
        'numpy >= 1.7.2',
        'scipy >= 0.16',
        'matplotlib >= 1.5.0',
        'networkx >= 1.1',
    ],
)
27 |
--------------------------------------------------------------------------------
/test_code.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | module="safemdp"
4 |
# Print the absolute, physical directory containing this script,
# resolving any chain of symlinks along the way.
get_script_dir () {
    local src="${BASH_SOURCE[0]}"
    local dir
    # Keep resolving while $src is still a symlink
    while [ -h "$src" ]; do
        dir="$( cd -P "$( dirname "$src" )" && pwd )"
        src="$( readlink "$src" )"
        # A relative link target (no leading "/") is resolved against the
        # directory the symlink itself lives in
        [[ $src != /* ]] && src="$dir/$src"
    done
    cd -P "$( dirname "$src" )" && pwd
}
17 |
# Change to the script's own directory so relative paths below work.
# Quote the substitution: unquoted it word-splits on paths with spaces.
cd "$(get_script_dir)" || exit 1

# ANSI colors for section headers.
GREEN='\033[0;32m'
NC='\033[0m'

# Run style tests (quote the --exclude pattern so the shell cannot
# glob-expand test*.py against files in the current directory).
echo -e "${GREEN}Running style tests.${NC}"
flake8 "$module" --exclude='test*.py,__init__.py' --ignore=E402,W503 --show-source

# Ignore import errors for __init__ and tests
flake8 "$module" --filename='__init__.py,test*.py' --ignore=F,E402,W503 --show-source

echo -e "${GREEN}Testing docstring conventions.${NC}"
# Test docstring conventions; use $module (the original hard-coded the
# package name here) and filter a known harmless pydocstyle warning.
pydocstyle "$module" --match='(?!__init__).*\.py' 2>&1 | grep -v "WARNING: __all__"

# Run unit tests with coverage
echo -e "${GREEN}Running unit tests.${NC}"
nosetests --with-doctest --with-coverage --cover-erase --cover-package="$module" "$module"

# Export coverage report as html
coverage html
41 |
--------------------------------------------------------------------------------