├── .coveragerc ├── .gitignore ├── .pypirc ├── .travis.yml ├── DESCRIPTION.rst ├── LICENSE.txt ├── Makefile ├── README.md ├── docs ├── Makefile ├── make.bat └── source │ ├── conf.py │ └── index.rst ├── lspi ├── __init__.py ├── basis_functions.py ├── domains.py ├── lspi.py ├── policy.py ├── sample.py └── solvers.py ├── lspi_testsuite ├── __init__.py ├── test_basis_functions.py ├── test_domains.py ├── test_learn.py ├── test_learning_chain_domain.py ├── test_policy.py ├── test_sample.py └── test_solvers.py ├── requirements.txt ├── setup.cfg └── setup.py /.coveragerc: -------------------------------------------------------------------------------- 1 | # .coveragerc to control coverage.py 2 | [run] 3 | branch = True 4 | 5 | [report] 6 | # Regexes for lines to exclude from consideration 7 | exclude_lines = 8 | # Have to re-enable the standard pragma 9 | pragma: no cover 10 | 11 | # Don't complain about missing debug-only code: 12 | def __repr__ 13 | if self\.debug 14 | 15 | # Don't complain if tests don't hit defensive assertion code: 16 | raise AssertionError 17 | raise NotImplementedError 18 | 19 | # Don't complain if non-runnable code isn't run: 20 | if 0: 21 | if __name__ == .__main__.: 22 | 23 | ignore_errors = True -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | lspi-python-docs.zip 2 | 3 | # Livereload 4 | Gemfile 5 | Gemfile.lock 6 | Guardfile 7 | 8 | # Sphinx 9 | docs/build/ 10 | docs/source/autodoc 11 | 12 | # autoenv script 13 | .env 14 | 15 | # virtualenv 16 | lspienv/ 17 | 18 | ########## 19 | # Python # 20 | ########## 21 | 22 | # Byte-compiled / optimized / DLL files 23 | __pycache__/ 24 | *.py[cod] 25 | 26 | # C extensions 27 | *.so 28 | 29 | # Distribution / packaging 30 | .Python 31 | env/ 32 | build/ 33 | develop-eggs/ 34 | dist/ 35 | downloads/ 36 | eggs/ 37 | .eggs/ 38 | lib/ 39 | lib64/ 40 | parts/ 41 | sdist/ 42 | var/ 43 | *.egg-info/ 44 | .installed.cfg 45 | *.egg 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .tox/ 60 | .coverage 61 | .coverage.* 62 | .cache 63 | nosetests.xml 64 | coverage.xml 65 | *,cover 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | ############# 81 | # JetBrains # 82 | ############# 83 | 84 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm 85 | 86 | *.iml 87 | 88 | ## Directory-based project format: 89 | .idea/ 90 | # if you remove the above rule, at least ignore the following: 91 | 92 | # User-specific stuff: 93 | # .idea/workspace.xml 94 | # .idea/tasks.xml 95 | # .idea/dictionaries 96 | 97 | # Sensitive or high-churn files: 98 | # .idea/dataSources.ids 99 | # .idea/dataSources.xml 100 | # .idea/sqlDataSources.xml 101 | # .idea/dynamic.xml 102 | # .idea/uiDesigner.xml 103 | 104 | # Gradle: 105 | # .idea/gradle.xml 106 | # .idea/libraries 107 | 108 | # Mongo Explorer plugin: 109 | # .idea/mongoSettings.xml 110 | 111 | ## File-based project format: 112 | *.ipr 113 | *.iws 114 | 115 | ## Plugin-specific files: 116 | 117 | # IntelliJ 118 | /out/ 119 | 120 | # mpeltonen/sbt-idea plugin 121 | .idea_modules/ 122 | 123 | # JIRA plugin 124 | atlassian-ide-plugin.xml 125 | 126 | # Crashlytics plugin (for Android Studio and IntelliJ) 127 | com_crashlytics_export_strings.xml 128 | crashlytics.properties 129 | crashlytics-build.properties 130 | -------------------------------------------------------------------------------- /.pypirc: -------------------------------------------------------------------------------- 1 | [distutils] 2 | index-servers=pypi 3 | 4 | [pypi] 5 | repository = https://pypi.python.org/pypi 6 | username = rhololkeolke 7 | 8 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "2.7" 4 | before_install: 5 | - sudo apt-get install -qq python-numpy python-scipy 6 | virtualenv: 7 | system_site_packages: true 8 | install: "pip install -r requirements.txt" 9 | script: make travis-test 10 | -------------------------------------------------------------------------------- /DESCRIPTION.rst: -------------------------------------------------------------------------------- 1 | LSPI Python 2 | =========== 3 | 4 | This is a Python implementation of the Least Squares Policy Iteration (LSPI) reinforcement learning algorithm. 5 | For more information on the algorithm please refer to the paper 6 | 7 | | “Least-Squares Policy Iteration.” 8 | | Lagoudakis, Michail G., and Ronald Parr. 9 | | Journal of Machine Learning Research 4, 2003. 10 | | `<https://www.cs.duke.edu/research/AI/LSPI/jmlr03.pdf>`_ 11 | 12 | You can also visit their website where more information and a Matlab version is provided. 13 | 14 | `<http://www.cs.duke.edu/research/AI/LSPI/>`_ -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2015, Devin Schwab 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer.
9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | * Neither the name of lspi-python nor the names of its 15 | contributors may be used to endorse or promote products derived from 16 | this software without specific prior written permission. 17 | 18 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: html-docs clean clean-pyc clean-tests clean-releases test all flake8 sphinx-apidoc travis-test release upload-release 2 | 3 | all: lspienv flake8 test html-docs 4 | 5 | lspienv: lspienv/bin/activate 6 | 7 | lspienv/bin/activate: requirements.txt 8 | test -d lspienv || virtualenv lspienv 9 | . lspienv/bin/activate; pip install -Ur requirements.txt 10 | touch lspienv/bin/activate 11 | 12 | flake8: 13 | . lspienv/bin/activate; flake8 lspi 14 | 15 | sphinx-apidoc: 16 | . lspienv/bin/activate; sphinx-apidoc -f -e -o docs/source/autodoc lspi 17 | 18 | html-docs: sphinx-apidoc 19 | . lspienv/bin/activate; PYTHONPATH=.. $(MAKE) -C docs html 20 | 21 | clean: clean-pyc clean-docs clean-tests clean-releases 22 | 23 | 24 | clean-pyc: 25 | find . -name '*.pyc' -exec rm -f {} + 26 | find . -name '*.pyo' -exec rm -f {} + 27 | 28 | clean-docs: 29 | $(MAKE) -C docs clean 30 | 31 | clean-releases: 32 | rm -rf dist/ 33 | rm -rf lspi_python.egg-info/ 34 | rm -rf build/ 35 | rm -f lspi-python-docs.zip 36 | 37 | clean-tests: 38 | rm -rf htmlcov/ 39 | 40 | test: 41 | . lspienv/bin/activate; nosetests --config=setup.cfg lspi_testsuite 42 | 43 | travis-test: 44 | nosetests --config=setup.cfg lspi_testsuite 45 | 46 | release: lspienv flake8 test html-docs 47 | . lspienv/bin/activate; python setup.py sdist bdist_wheel 48 | zip -r lspi-python-docs.zip docs/build/html/* 49 | 50 | upload-release: release 51 | . lspienv/bin/activate; twine upload dist/* 52 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LSPI Python 2 | 3 | [![Build Status](https://travis-ci.org/rhololkeolke/lspi-python.svg?branch=master)](https://travis-ci.org/rhololkeolke/lspi-python) 4 | 5 | This is a Python implementation of the Least Squares Policy Iteration (LSPI) reinforcement learning algorithm. 6 | For more information on the algorithm please refer to the paper 7 | 8 | “Least-Squares Policy Iteration.” 9 | Lagoudakis, Michail G., and Ronald Parr. 10 | Journal of Machine Learning Research 4, 2003.
11 | [https://www.cs.duke.edu/research/AI/LSPI/jmlr03.pdf](https://www.cs.duke.edu/research/AI/LSPI/jmlr03.pdf) 12 | 13 | You can also visit their website where more information and a Matlab version is provided. 14 | 15 | [http://www.cs.duke.edu/research/AI/LSPI/](http://www.cs.duke.edu/research/AI/LSPI/) 16 | 17 | ## Requirements 18 | 19 | The requirements.txt file contains the python module requirements to use this 20 | library, run the tests, and generate the docs. To install all of the listed 21 | requirements automatically you can use the command 22 | 23 | ``` 24 | pip install -r requirements.txt 25 | ``` 26 | 27 | ## Testing 28 | 29 | If you have nosetests you can run the tests with `nosetests --config=setup.cfg lspi_testsuite`. 30 | If you have virtualenv installed you can run `make test`, which will automatically create a virtual environment 31 | with all of the dependencies and then run the tests. 32 | 33 | ## Docs 34 | 35 | To generate the docs you will need sphinx. If you have virtualenv installed you can run 36 | `make html-docs`. This will automatically create a virtual environment with all of the dependencies 37 | and then run sphinx. The output will exist in `docs/build/html/`. 38 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables.
16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) source 21 | 22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext 23 | 24 | help: 25 | @echo "Please use \`make <target>' where <target> is one of" 26 | @echo " html to make standalone HTML files" 27 | @echo " dirhtml to make HTML files named index.html in directories" 28 | @echo " singlehtml to make a single large HTML file" 29 | @echo " pickle to make pickle files" 30 | @echo " json to make JSON files" 31 | @echo " htmlhelp to make HTML files and a HTML help project" 32 | @echo " qthelp to make HTML files and a qthelp project" 33 | @echo " devhelp to make HTML files and a Devhelp project" 34 | @echo " epub to make an epub" 35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 36 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 38 | @echo " text to make text files" 39 | @echo " man to make manual pages" 40 | @echo " texinfo to make Texinfo files" 41 | @echo " info to make Texinfo files and run them through makeinfo" 42 | @echo " gettext to make PO message catalogs" 43 | @echo " changes to make an overview of all changed/added/deprecated items" 44 | @echo " xml to make Docutils-native XML files" 45 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 46 | @echo " linkcheck to check all external links for integrity" 47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 48 | 49 | clean: 50 | rm -rf $(BUILDDIR)/* 51 | 52 | html: 53 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 54 | @echo 55 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 56 | 57 | dirhtml: 58 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 59 | @echo 60 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 61 | 62 | singlehtml: 63 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 64 | @echo 65 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 66 | 67 | pickle: 68 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 69 | @echo 70 | @echo "Build finished; now you can process the pickle files." 71 | 72 | json: 73 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 74 | @echo 75 | @echo "Build finished; now you can process the JSON files." 76 | 77 | htmlhelp: 78 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 79 | @echo 80 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 81 | ".hhp project file in $(BUILDDIR)/htmlhelp." 82 | 83 | qthelp: 84 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 85 | @echo 86 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 87 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 88 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/LSPIPython.qhcp" 89 | @echo "To view the help file:" 90 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/LSPIPython.qhc" 91 | 92 | devhelp: 93 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 94 | @echo 95 | @echo "Build finished."
96 | @echo "To view the help file:" 97 | @echo "# mkdir -p $$HOME/.local/share/devhelp/LSPIPython" 98 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/LSPIPython" 99 | @echo "# devhelp" 100 | 101 | epub: 102 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 103 | @echo 104 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 105 | 106 | latex: 107 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 108 | @echo 109 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 110 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 111 | "(use \`make latexpdf' here to do that automatically)." 112 | 113 | latexpdf: 114 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 115 | @echo "Running LaTeX files through pdflatex..." 116 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 117 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 118 | 119 | latexpdfja: 120 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 121 | @echo "Running LaTeX files through platex and dvipdfmx..." 122 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 123 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 124 | 125 | text: 126 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 127 | @echo 128 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 129 | 130 | man: 131 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 132 | @echo 133 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 134 | 135 | texinfo: 136 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 137 | @echo 138 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 139 | @echo "Run \`make' in that directory to run these through makeinfo" \ 140 | "(use \`make info' here to do that automatically)." 141 | 142 | info: 143 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 144 | @echo "Running Texinfo files through makeinfo..." 145 | make -C $(BUILDDIR)/texinfo info 146 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 147 | 148 | gettext: 149 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 150 | @echo 151 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 152 | 153 | changes: 154 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 155 | @echo 156 | @echo "The overview file is in $(BUILDDIR)/changes." 157 | 158 | linkcheck: 159 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 160 | @echo 161 | @echo "Link check complete; look for any errors in the above output " \ 162 | "or in $(BUILDDIR)/linkcheck/output.txt." 163 | 164 | doctest: 165 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 166 | @echo "Testing of doctests in the sources finished, look at the " \ 167 | "results in $(BUILDDIR)/doctest/output.txt." 168 | 169 | xml: 170 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 171 | @echo 172 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 173 | 174 | pseudoxml: 175 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 176 | @echo 177 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 
178 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | REM Command file for Sphinx documentation 4 | 5 | if "%SPHINXBUILD%" == "" ( 6 | set SPHINXBUILD=sphinx-build 7 | ) 8 | set BUILDDIR=build 9 | set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% source 10 | set I18NSPHINXOPTS=%SPHINXOPTS% source 11 | if NOT "%PAPER%" == "" ( 12 | set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% 13 | set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% 14 | ) 15 | 16 | if "%1" == "" goto help 17 | 18 | if "%1" == "help" ( 19 | :help 20 | echo.Please use `make ^<target^>` where ^<target^> is one of 21 | echo. html to make standalone HTML files 22 | echo. dirhtml to make HTML files named index.html in directories 23 | echo. singlehtml to make a single large HTML file 24 | echo. pickle to make pickle files 25 | echo. json to make JSON files 26 | echo. htmlhelp to make HTML files and a HTML help project 27 | echo. qthelp to make HTML files and a qthelp project 28 | echo. devhelp to make HTML files and a Devhelp project 29 | echo. epub to make an epub 30 | echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter 31 | echo. text to make text files 32 | echo. man to make manual pages 33 | echo. texinfo to make Texinfo files 34 | echo. gettext to make PO message catalogs 35 | echo. changes to make an overview over all changed/added/deprecated items 36 | echo. xml to make Docutils-native XML files 37 | echo. pseudoxml to make pseudoxml-XML files for display purposes 38 | echo. linkcheck to check all external links for integrity 39 | echo. doctest to run all doctests embedded in the documentation if enabled 40 | goto end 41 | ) 42 | 43 | if "%1" == "clean" ( 44 | for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i 45 | del /q /s %BUILDDIR%\* 46 | goto end 47 | ) 48 | 49 | 50 | %SPHINXBUILD% 2> nul 51 | if errorlevel 9009 ( 52 | echo. 53 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 54 | echo.installed, then set the SPHINXBUILD environment variable to point 55 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 56 | echo.may add the Sphinx directory to PATH. 57 | echo. 58 | echo.If you don't have Sphinx installed, grab it from 59 | echo.http://sphinx-doc.org/ 60 | exit /b 1 61 | ) 62 | 63 | if "%1" == "html" ( 64 | %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html 65 | if errorlevel 1 exit /b 1 66 | echo. 67 | echo.Build finished. The HTML pages are in %BUILDDIR%/html. 68 | goto end 69 | ) 70 | 71 | if "%1" == "dirhtml" ( 72 | %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml 73 | if errorlevel 1 exit /b 1 74 | echo. 75 | echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. 76 | goto end 77 | ) 78 | 79 | if "%1" == "singlehtml" ( 80 | %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml 81 | if errorlevel 1 exit /b 1 82 | echo. 83 | echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. 84 | goto end 85 | ) 86 | 87 | if "%1" == "pickle" ( 88 | %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle 89 | if errorlevel 1 exit /b 1 90 | echo. 91 | echo.Build finished; now you can process the pickle files. 92 | goto end 93 | ) 94 | 95 | if "%1" == "json" ( 96 | %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json 97 | if errorlevel 1 exit /b 1 98 | echo. 99 | echo.Build finished; now you can process the JSON files.
100 | goto end 101 | ) 102 | 103 | if "%1" == "htmlhelp" ( 104 | %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp 105 | if errorlevel 1 exit /b 1 106 | echo. 107 | echo.Build finished; now you can run HTML Help Workshop with the ^ 108 | .hhp project file in %BUILDDIR%/htmlhelp. 109 | goto end 110 | ) 111 | 112 | if "%1" == "qthelp" ( 113 | %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp 114 | if errorlevel 1 exit /b 1 115 | echo. 116 | echo.Build finished; now you can run "qcollectiongenerator" with the ^ 117 | .qhcp project file in %BUILDDIR%/qthelp, like this: 118 | echo.^> qcollectiongenerator %BUILDDIR%\qthelp\LSPIPython.qhcp 119 | echo.To view the help file: 120 | echo.^> assistant -collectionFile %BUILDDIR%\qthelp\LSPIPython.qhc 121 | goto end 122 | ) 123 | 124 | if "%1" == "devhelp" ( 125 | %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp 126 | if errorlevel 1 exit /b 1 127 | echo. 128 | echo.Build finished. 129 | goto end 130 | ) 131 | 132 | if "%1" == "epub" ( 133 | %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub 134 | if errorlevel 1 exit /b 1 135 | echo. 136 | echo.Build finished. The epub file is in %BUILDDIR%/epub. 137 | goto end 138 | ) 139 | 140 | if "%1" == "latex" ( 141 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 142 | if errorlevel 1 exit /b 1 143 | echo. 144 | echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. 145 | goto end 146 | ) 147 | 148 | if "%1" == "latexpdf" ( 149 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 150 | cd %BUILDDIR%/latex 151 | make all-pdf 152 | cd %BUILDDIR%/.. 153 | echo. 154 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 155 | goto end 156 | ) 157 | 158 | if "%1" == "latexpdfja" ( 159 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 160 | cd %BUILDDIR%/latex 161 | make all-pdf-ja 162 | cd %BUILDDIR%/.. 163 | echo. 164 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 165 | goto end 166 | ) 167 | 168 | if "%1" == "text" ( 169 | %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text 170 | if errorlevel 1 exit /b 1 171 | echo. 172 | echo.Build finished. The text files are in %BUILDDIR%/text. 173 | goto end 174 | ) 175 | 176 | if "%1" == "man" ( 177 | %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man 178 | if errorlevel 1 exit /b 1 179 | echo. 180 | echo.Build finished. The manual pages are in %BUILDDIR%/man. 181 | goto end 182 | ) 183 | 184 | if "%1" == "texinfo" ( 185 | %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo 186 | if errorlevel 1 exit /b 1 187 | echo. 188 | echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. 189 | goto end 190 | ) 191 | 192 | if "%1" == "gettext" ( 193 | %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale 194 | if errorlevel 1 exit /b 1 195 | echo. 196 | echo.Build finished. The message catalogs are in %BUILDDIR%/locale. 197 | goto end 198 | ) 199 | 200 | if "%1" == "changes" ( 201 | %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes 202 | if errorlevel 1 exit /b 1 203 | echo. 204 | echo.The overview file is in %BUILDDIR%/changes. 205 | goto end 206 | ) 207 | 208 | if "%1" == "linkcheck" ( 209 | %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck 210 | if errorlevel 1 exit /b 1 211 | echo. 212 | echo.Link check complete; look for any errors in the above output ^ 213 | or in %BUILDDIR%/linkcheck/output.txt.
214 | goto end 215 | ) 216 | 217 | if "%1" == "doctest" ( 218 | %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest 219 | if errorlevel 1 exit /b 1 220 | echo. 221 | echo.Testing of doctests in the sources finished, look at the ^ 222 | results in %BUILDDIR%/doctest/output.txt. 223 | goto end 224 | ) 225 | 226 | if "%1" == "xml" ( 227 | %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml 228 | if errorlevel 1 exit /b 1 229 | echo. 230 | echo.Build finished. The XML files are in %BUILDDIR%/xml. 231 | goto end 232 | ) 233 | 234 | if "%1" == "pseudoxml" ( 235 | %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml 236 | if errorlevel 1 exit /b 1 237 | echo. 238 | echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. 239 | goto end 240 | ) 241 | 242 | :end 243 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # LSPI Python documentation build configuration file, created by 4 | # sphinx-quickstart on Thu Mar 19 04:02:44 2015. 5 | # 6 | # This file is execfile()d with the current directory set to its 7 | # containing dir. 8 | # 9 | # Note that not all possible configuration values are present in this 10 | # autogenerated file. 11 | # 12 | # All configuration values have a default; values that are commented out 13 | # serve to show the default. 14 | 15 | import sys 16 | import os 17 | 18 | # If extensions (or modules to document with autodoc) are in another directory, 19 | # add these directories to sys.path here. If the directory is relative to the 20 | # documentation root, use os.path.abspath to make it absolute, like shown here. 21 | #sys.path.insert(0, os.path.abspath('.')) 22 | 23 | # -- General configuration ------------------------------------------------ 24 | 25 | # If your documentation needs a minimal Sphinx version, state it here. 26 | #needs_sphinx = '1.0' 27 | 28 | # Add any Sphinx extension module names here, as strings. They can be 29 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 30 | # ones. 31 | extensions = [ 32 | 'sphinx.ext.autodoc', 33 | 'sphinx.ext.doctest', 34 | 'sphinx.ext.todo', 35 | 'sphinx.ext.coverage', 36 | 'sphinx.ext.mathjax', 37 | 'sphinx.ext.viewcode', 38 | 'sphinxcontrib.napoleon' 39 | ] 40 | 41 | # Add any paths that contain templates here, relative to this directory. 42 | templates_path = ['_templates'] 43 | 44 | # The suffix of source filenames. 45 | source_suffix = '.rst' 46 | 47 | # The encoding of source files. 48 | #source_encoding = 'utf-8-sig' 49 | 50 | # The master toctree document. 51 | master_doc = 'index' 52 | 53 | # General information about the project. 54 | project = u'LSPI Python' 55 | copyright = u'2015, Devin Schwab' 56 | 57 | # The version info for the project you're documenting, acts as replacement for 58 | # |version| and |release|, also used in various other places throughout the 59 | # built documents. 60 | # 61 | # The short X.Y version. 62 | version = '0.0.0' 63 | # The full version, including alpha/beta/rc tags. 64 | release = '0.0.0' 65 | 66 | # The language for content autogenerated by Sphinx. Refer to documentation 67 | # for a list of supported languages. 68 | #language = None 69 | 70 | # There are two options for replacing |today|: either, you set today to some 71 | # non-false value, then it is used: 72 | #today = '' 73 | # Else, today_fmt is used as the format for a strftime call. 
74 | #today_fmt = '%B %d, %Y' 75 | 76 | # List of patterns, relative to source directory, that match files and 77 | # directories to ignore when looking for source files. 78 | exclude_patterns = [] 79 | 80 | # The reST default role (used for this markup: `text`) to use for all 81 | # documents. 82 | #default_role = None 83 | 84 | # If true, '()' will be appended to :func: etc. cross-reference text. 85 | #add_function_parentheses = True 86 | 87 | # If true, the current module name will be prepended to all description 88 | # unit titles (such as .. function::). 89 | #add_module_names = True 90 | 91 | # If true, sectionauthor and moduleauthor directives will be shown in the 92 | # output. They are ignored by default. 93 | #show_authors = False 94 | 95 | # The name of the Pygments (syntax highlighting) style to use. 96 | pygments_style = 'sphinx' 97 | 98 | # A list of ignored prefixes for module index sorting. 99 | #modindex_common_prefix = [] 100 | 101 | # If true, keep warnings as "system message" paragraphs in the built documents. 102 | #keep_warnings = False 103 | 104 | 105 | # -- Options for HTML output ---------------------------------------------- 106 | 107 | 108 | # on_rtd is whether we are on readthedocs.org 109 | 110 | on_rtd = os.environ.get('READTHEDOCS', None) == 'True' 111 | 112 | if not on_rtd: # only import and set the theme if we're building docs locally 113 | import sphinx_rtd_theme 114 | html_theme = 'sphinx_rtd_theme' 115 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 116 | else: 117 | html_theme = 'default' 118 | 119 | # Theme options are theme-specific and customize the look and feel of a theme 120 | # further. For a list of options available for each theme, see the 121 | # documentation. 122 | #html_theme_options = {} 123 | 124 | # Add any paths that contain custom themes here, relative to this directory. 125 | #html_theme_path = [] 126 | 127 | # The name for this set of Sphinx documents. If None, it defaults to 128 | # "<project> v<release> documentation". 129 | #html_title = None 130 | 131 | # A shorter title for the navigation bar. Default is the same as html_title. 132 | #html_short_title = None 133 | 134 | # The name of an image file (relative to this directory) to place at the top 135 | # of the sidebar. 136 | #html_logo = None 137 | 138 | # The name of an image file (within the static path) to use as favicon of the 139 | # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 140 | # pixels large. 141 | #html_favicon = None 142 | 143 | # Add any paths that contain custom static files (such as style sheets) here, 144 | # relative to this directory. They are copied after the builtin static files, 145 | # so a file named "default.css" will overwrite the builtin "default.css". 146 | html_static_path = ['_static'] 147 | 148 | # Add any extra paths that contain custom files (such as robots.txt or 149 | # .htaccess) here, relative to this directory. These files are copied 150 | # directly to the root of the documentation. 151 | #html_extra_path = [] 152 | 153 | # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, 154 | # using the given strftime format. 155 | #html_last_updated_fmt = '%b %d, %Y' 156 | 157 | # If true, SmartyPants will be used to convert quotes and dashes to 158 | # typographically correct entities. 159 | #html_use_smartypants = True 160 | 161 | # Custom sidebar templates, maps document names to template names.
162 | #html_sidebars = {} 163 | 164 | # Additional templates that should be rendered to pages, maps page names to 165 | # template names. 166 | #html_additional_pages = {} 167 | 168 | # If false, no module index is generated. 169 | #html_domain_indices = True 170 | 171 | # If false, no index is generated. 172 | #html_use_index = True 173 | 174 | # If true, the index is split into individual pages for each letter. 175 | #html_split_index = False 176 | 177 | # If true, links to the reST sources are added to the pages. 178 | #html_show_sourcelink = True 179 | 180 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. 181 | #html_show_sphinx = True 182 | 183 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. 184 | #html_show_copyright = True 185 | 186 | # If true, an OpenSearch description file will be output, and all pages will 187 | # contain a <link> tag referring to it. The value of this option must be the 188 | # base URL from which the finished HTML is served. 189 | #html_use_opensearch = '' 190 | 191 | # This is the file name suffix for HTML files (e.g. ".xhtml"). 192 | #html_file_suffix = None 193 | 194 | # Output file base name for HTML help builder. 195 | htmlhelp_basename = 'LSPIPythondoc' 196 | 197 | 198 | # -- Options for LaTeX output --------------------------------------------- 199 | 200 | latex_elements = { 201 | # The paper size ('letterpaper' or 'a4paper'). 202 | #'papersize': 'letterpaper', 203 | 204 | # The font size ('10pt', '11pt' or '12pt'). 205 | #'pointsize': '10pt', 206 | 207 | # Additional stuff for the LaTeX preamble. 208 | #'preamble': '', 209 | } 210 | 211 | # Grouping the document tree into LaTeX files. List of tuples 212 | # (source start file, target name, title, 213 | # author, documentclass [howto, manual, or own class]). 214 | latex_documents = [ 215 | ('index', 'LSPIPython.tex', u'LSPI Python Documentation', 216 | u'Devin Schwab', 'manual'), 217 | ] 218 | 219 | # The name of an image file (relative to this directory) to place at the top of 220 | # the title page. 221 | #latex_logo = None 222 | 223 | # For "manual" documents, if this is true, then toplevel headings are parts, 224 | # not chapters. 225 | #latex_use_parts = False 226 | 227 | # If true, show page references after internal links. 228 | #latex_show_pagerefs = False 229 | 230 | # If true, show URL addresses after external links. 231 | #latex_show_urls = False 232 | 233 | # Documents to append as an appendix to all manuals. 234 | #latex_appendices = [] 235 | 236 | # If false, no module index is generated. 237 | #latex_domain_indices = True 238 | 239 | 240 | # -- Options for manual page output --------------------------------------- 241 | 242 | # One entry per manual page. List of tuples 243 | # (source start file, name, description, authors, manual section). 244 | man_pages = [ 245 | ('index', 'lspipython', u'LSPI Python Documentation', 246 | [u'Devin Schwab'], 1) 247 | ] 248 | 249 | # If true, show URL addresses after external links. 250 | #man_show_urls = False 251 | 252 | 253 | # -- Options for Texinfo output ------------------------------------------- 254 | 255 | # Grouping the document tree into Texinfo files.
List of tuples 256 | # (source start file, target name, title, author, 257 | # dir menu entry, description, category) 258 | texinfo_documents = [ 259 | ('index', 'LSPIPython', u'LSPI Python Documentation', 260 | u'Devin Schwab', 'LSPIPython', 'One line description of project.', 261 | 'Miscellaneous'), 262 | ] 263 | 264 | # Documents to append as an appendix to all manuals. 265 | #texinfo_appendices = [] 266 | 267 | # If false, no module index is generated. 268 | #texinfo_domain_indices = True 269 | 270 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 271 | #texinfo_show_urls = 'footnote' 272 | 273 | # If true, do not generate a @detailmenu in the "Top" node's menu. 274 | #texinfo_no_detailmenu = False 275 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. LSPI Python documentation master file, created by 2 | sphinx-quickstart on Thu Mar 19 04:02:44 2015. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to LSPI Python's documentation! 7 | ======================================= 8 | 9 | Contents: 10 | 11 | .. toctree:: 12 | :maxdepth: 2 13 | 14 | autodoc/modules 15 | 16 | 17 | This is a Python implementation of the Least Squares Policy Iteration (LSPI) reinforcement learning algorithm. 18 | For more information on the algorithm please refer to the paper 19 | 20 | | “Least-Squares Policy Iteration.” 21 | | Lagoudakis, Michail G., and Ronald Parr. 22 | | Journal of Machine Learning Research 4, 2003. 23 | | `<https://www.cs.duke.edu/research/AI/LSPI/jmlr03.pdf>`_ 24 | 25 | You can also visit their website where more information and a Matlab version is provided. 26 | 27 | `<http://www.cs.duke.edu/research/AI/LSPI/>`_ 28 | 29 | Overview 30 | -------- 31 | 32 | When using this library the first thing you must do is collect a set of samples 33 | for LSPI to learn from. Each sample should be an instance of the :class:`Sample` class. 34 | These samples are then passed into the :func:`lspi.learn` method. This method 35 | takes in the list of samples, a policy, and a solver. The :class:`Policy` class 36 | provided should not need to be modified. The learn method then continuously 37 | calls the solver on the data samples and policy until the policy converges. Once 38 | the policy has converged the agent can use the policy to find the best action 39 | in every state and execute it. 40 | 41 | The Policy class contains the basis function approximation and its associated weights. 42 | Weights can be specified or, if left unspecified, randomly generated. The policy 43 | also contains the probability of doing an exploration action, and the discount factor. 44 | The Policy class should not need to be modified when using this library. 45 | 46 | The basis functions all inherit from the abstract base class :class:`lspi.basis_functions.BasisFunction`. This 47 | class provides the minimum interface for a basis function. Instances of this class 48 | may contain specialized fields and methods. There are a handful of basis function 49 | classes provided in this package including: :class:`lspi.basis_functions.FakeBasis`, :class:`lspi.basis_functions.ExactBasis`, 50 | :class:`lspi.basis_functions.OneDimensionalPolynomialBasis`, and :class:`lspi.basis_functions.RadialBasisFunction`. See 51 | each class for its respective construction parameters and how the basis is calculated. A minimal end-to-end sketch of this workflow is shown below.
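The following sketch is illustrative only: the ``Policy`` keyword arguments and the ``select_action``/``best_action`` method names are assumptions based on this overview (see the autodoc pages for the exact signatures), the ``LSTDQSolver`` is constructed with defaults, and the sample count is arbitrary.

.. code-block:: python

    from lspi import learn, Policy
    from lspi.basis_functions import FakeBasis
    from lspi.domains import ChainDomain
    from lspi.solvers import LSTDQSolver

    domain = ChainDomain()

    # Collect samples with a purely random policy (FakeBasis ignores the state).
    sampling_policy = Policy(FakeBasis(domain.num_actions()), explore=1.0)
    samples = [domain.apply_action(
                   sampling_policy.select_action(domain.current_state()))
               for _ in range(1000)]

    # Repeatedly apply the solver to the samples until the weights converge.
    learned_policy = learn(samples, sampling_policy, LSTDQSolver())

    # Act greedily with the converged policy.
    domain.reset()
    action = learned_policy.best_action(domain.current_state())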
52 | You can also implement your own BasisFunctions by inheriting from the BasisFunction class and implementing 53 | all of the abstract methods. 54 | 55 | As mentioned the learn method takes in a Solver instance. This instance is responsible 56 | for performing a single policy update step given the current policy and the samples being 57 | learned from. Currently the only implemented Solver is the :class:`lspi.solvers.LSTDQSolver` which implements 58 | the algorithm from Figure 5 of the LSPI paper. There are other variants in the LSPI paper that could 59 | also be implemented. Additionally if a different matrix solving style is needed (e.g. sparse matrix solver) 60 | then a new solver can be implemented. To implement a new Solver simply create a 61 | class that inherits from the :class:`lspi.solvers.Solver` class. You must implement all of the abstract methods. 62 | 63 | For testing and demonstration purposes the simple ChainDomain from the LSPI paper is included 64 | in the :mod:`lspi.domains` module. If you wish to implement other domains it is recommended 65 | that you inherit from the :class:`lspi.domains.Domain` class and implement the abstract methods. 66 | 67 | Indices and tables 68 | ================== 69 | 70 | * :ref:`genindex` 71 | * :ref:`modindex` 72 | * :ref:`search` 73 | -------------------------------------------------------------------------------- /lspi/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Least Squares Policy Iteration (LSPI) implementation. 3 | 4 | Implements the algorithms described in the paper 5 | 6 | "Least-Squares Policy Iteration." 7 | Lagoudakis, Michail G., and Ronald Parr. 8 | Journal of Machine Learning Research 4, 2003. 9 | https://www.cs.duke.edu/research/AI/LSPI/jmlr03.pdf 10 | 11 | The implementation is based on the Matlab implementation provided by 12 | the authors. The implementation is available for download at 13 | http://www.cs.duke.edu/research/AI/LSPI/ 14 | 15 | """ 16 | 17 | import basis_functions # noqa 18 | import domains # noqa 19 | from lspi import learn # noqa 20 | from policy import Policy # noqa 21 | from sample import Sample # noqa 22 | import solvers # noqa 23 | -------------------------------------------------------------------------------- /lspi/basis_functions.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Abstract Base Class for Basis Function and some common implementations.""" 3 | 4 | import abc 5 | 6 | import numpy as np 7 | 8 | 9 | class BasisFunction(object): 10 | 11 | r"""ABC for basis functions used by LSPI Policies. 12 | 13 | A basis function is a function that takes in a state vector and an action 14 | index and returns a vector of features. The resulting feature vector is 15 | referred to as :math:`\phi` in the LSPI paper (pg 9 of the PDF referenced 16 | in this package's documentation). The :math:`\phi` vector is dotted with 17 | the weight vector of the Policy to calculate the Q-value. 18 | 19 | The dimensions of the state vector are usually smaller than the dimensions 20 | of the :math:`\phi` vector. However, the dimensions of the :math:`\phi` 21 | vector are usually much smaller than the dimensions of an exact 22 | representation of the state which leads to significant savings when 23 | computing and storing a policy. 
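As the overview above notes, you can implement your own basis by inheriting from ``BasisFunction``. The following is an editor's sketch, not part of the library: a hypothetical basis that writes a bias term and the raw state into a per-action block, mirroring the block layout the built-in bases in this module use. It assumes the ``BasisFunction`` ABC and ``numpy`` from this module are in scope.

```python
import numpy as np

class StateActionBlockBasis(BasisFunction):
    """Hypothetical basis: [..., 1, s_1, ..., s_n, ...] in the action's block."""

    def __init__(self, state_dim, num_actions):
        self.__num_actions = BasisFunction._validate_num_actions(num_actions)
        self.state_dim = state_dim

    def size(self):
        # One block of (state_dim + 1) features per action.
        return (self.state_dim + 1) * self.__num_actions

    def evaluate(self, state, action):
        if action < 0 or action >= self.__num_actions:
            raise IndexError('Action index out of bounds')
        phi = np.zeros((self.size(), ))
        offset = (self.state_dim + 1) * action
        phi[offset] = 1.                                      # bias term
        phi[offset + 1:offset + 1 + self.state_dim] = state   # raw state copy
        return phi

    @property
    def num_actions(self):
        return self.__num_actions
```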
24 | 25 | """ 26 | 27 | __metaclass__ = abc.ABCMeta 28 | 29 | @abc.abstractmethod 30 | def size(self): 31 | r"""Return the vector size of the basis function. 32 | 33 | Returns 34 | ------- 35 | int 36 | The size of the :math:`\phi` vector. 37 | (Referred to as k in the paper). 38 | 39 | """ 40 | pass # pragma: no cover 41 | 42 | @abc.abstractmethod 43 | def evaluate(self, state, action): 44 | r"""Calculate the :math:`\phi` matrix for the given state-action pair. 45 | 46 | The way this value is calculated depends entirely on the concrete 47 | implementation of BasisFunction. 48 | 49 | Parameters 50 | ---------- 51 | state : numpy.array 52 | The state to get the features for. 53 | When calculating Q(s, a) this is the s. 54 | action : int 55 | The action index to get the features for. 56 | When calculating Q(s, a) this is the a. 57 | 58 | 59 | Returns 60 | ------- 61 | numpy.array 62 | The :math:`\phi` vector. Used by Policy to compute Q-value. 63 | 64 | """ 65 | pass # pragma: no cover 66 | 67 | @abc.abstractproperty 68 | def num_actions(self): 69 | """Return number of possible actions. 70 | 71 | Returns 72 | ------- 73 | int 74 | Number of possible actions. 75 | """ 76 | pass # pragma: no cover 77 | 78 | @staticmethod 79 | def _validate_num_actions(num_actions): 80 | """Return num_actions if valid. Otherwise raise ValueError. 81 | 82 | Return 83 | ------ 84 | int 85 | Number of possible actions. 86 | 87 | Raises 88 | ------ 89 | ValueError 90 | If num_actions < 1 91 | 92 | """ 93 | if num_actions < 1: 94 | raise ValueError('num_actions must be >= 1') 95 | return num_actions 96 | 97 | 98 | class FakeBasis(BasisFunction): 99 | 100 | r"""Basis that ignores all input. Useful for random sampling. 101 | 102 | When creating a purely random Policy a basis function is still required. 103 | This basis function just returns a :math:`\phi` equal to [1.] for all 104 | inputs. It will, however, still throw exceptions for impossible values like 105 | negative action indexes. 106 | 107 | """ 108 | 109 | def __init__(self, num_actions): 110 | """Initialize FakeBasis.""" 111 | self.__num_actions = BasisFunction._validate_num_actions(num_actions) 112 | 113 | def size(self): 114 | r"""Return size of 1. 115 | 116 | Returns 117 | ------- 118 | int 119 | Size of :math:`\phi` which is always 1 for FakeBasis 120 | 121 | Example 122 | ------- 123 | 124 | >>> FakeBasis(5).size() 125 | 1 126 | 127 | """ 128 | return 1 129 | 130 | def evaluate(self, state, action): 131 | r"""Return :math:`\phi` equal to [1.]. 132 | 133 | Parameters 134 | ---------- 135 | state : numpy.array 136 | The state to get the features for. 137 | When calculating Q(s, a) this is the s. FakeBasis ignores these 138 | values. 139 | action : int 140 | The action index to get the features for. 141 | When calculating Q(s, a) this is the a. FakeBasis ignores these 142 | values. 143 | 144 | Returns 145 | ------- 146 | numpy.array 147 | :math:`\phi` vector equal to [1.].
148 | 149 | Raises 150 | ------ 151 | IndexError 152 | If action index is < 0 or >= num_actions 153 | 154 | Example 155 | ------- 156 | 157 | >>> FakeBasis(5).evaluate(np.arange(10), 0) 158 | array([ 1.]) 159 | 160 | """ 161 | if action < 0: 162 | raise IndexError('action index must be >= 0') 163 | if action >= self.num_actions: 164 | raise IndexError('action must be < num_actions') 165 | return np.array([1.]) 166 | 167 | @property 168 | def num_actions(self): 169 | """Return number of possible actions.""" 170 | return self.__num_actions 171 | 172 | @num_actions.setter 173 | def num_actions(self, value): 174 | """Set the number of possible actions. 175 | 176 | Parameters 177 | ---------- 178 | value: int 179 | Number of possible actions. Must be >= 1. 180 | 181 | Raises 182 | ------ 183 | ValueError 184 | If value < 1. 185 | 186 | """ 187 | if value < 1: 188 | raise ValueError('num_actions must be at least 1.') 189 | self.__num_actions = value 190 | 191 | 192 | class OneDimensionalPolynomialBasis(BasisFunction): 193 | 194 | """Polynomial features for a state with one dimension. 195 | 196 | Takes the value of the state and constructs a vector proportional 197 | to the specified degree and number of actions. The polynomial is first 198 | constructed as [..., 1, value, value^2, ..., value^k, ...] 199 | where k is the degree. The rest of the vector is 0. 200 | 201 | Parameters 202 | ---------- 203 | degree : int 204 | The polynomial degree. 205 | num_actions: int 206 | The total number of possible actions 207 | 208 | Raises 209 | ------ 210 | ValueError 211 | If degree is less than 0 212 | ValueError 213 | If num_actions is less than 1 214 | 215 | """ 216 | 217 | def __init__(self, degree, num_actions): 218 | """Initialize polynomial basis function.""" 219 | self.__num_actions = BasisFunction._validate_num_actions(num_actions) 220 | 221 | if degree < 0: 222 | raise ValueError('Degree must be >= 0') 223 | self.degree = degree 224 | 225 | def size(self): 226 | """Calculate the size of the basis function. 227 | 228 | The base size will be degree + 1. This basic matrix is then 229 | duplicated once for every action. Therefore the size is equal to 230 | (degree + 1) * number of actions 231 | 232 | 233 | Returns 234 | ------- 235 | int 236 | The size of the phi matrix that will be returned from evaluate. 237 | 238 | 239 | Example 240 | ------- 241 | 242 | >>> basis = OneDimensionalPolynomialBasis(2, 2) 243 | >>> basis.size() 244 | 6 245 | 246 | """ 247 | return (self.degree + 1) * self.num_actions 248 | 249 | def evaluate(self, state, action): 250 | r"""Calculate :math:`\phi` matrix for given state action pair. 251 | 252 | The :math:`\phi` matrix is used to calculate the Q function for the 253 | given policy. 254 | 255 | Parameters 256 | ---------- 257 | state : numpy.array 258 | The state to get the features for. 259 | When calculating Q(s, a) this is the s. 260 | action : int 261 | The action index to get the features for. 262 | When calculating Q(s, a) this is the a. 263 | 264 | Returns 265 | ------- 266 | numpy.array 267 | The :math:`\phi` vector. Used by Policy to compute Q-value. 268 | 269 | Raises 270 | ------ 271 | IndexError 272 | If action is outside the range :math:`0 \le action < num\_actions` an IndexError is raised. 273 | ValueError 274 | If the state vector has any number of dimensions other than 1 a 275 | ValueError is raised.
276 | 277 | Example 278 | ------- 279 | 280 | >>> basis = OneDimensionalPolynomialBasis(2, 2) 281 | >>> basis.evaluate(np.array([2]), 0) 282 | array([ 1., 2., 4., 0., 0., 0.]) 283 | 284 | """ 285 | if action < 0 or action >= self.num_actions: 286 | raise IndexError('Action index out of bounds') 287 | 288 | if state.shape != (1, ): 289 | raise ValueError('This class only supports one dimensional states') 290 | 291 | phi = np.zeros((self.size(), )) 292 | 293 | offset = (self.size()/self.num_actions)*action 294 | 295 | value = state[0] 296 | 297 | phi[offset:offset + self.degree + 1] = \ 298 | np.array([pow(value, i) for i in range(self.degree+1)]) 299 | 300 | return phi 301 | 302 | @property 303 | def num_actions(self): 304 | """Return number of possible actions.""" 305 | return self.__num_actions 306 | 307 | @num_actions.setter 308 | def num_actions(self, value): 309 | """Set the number of possible actions. 310 | 311 | Parameters 312 | ---------- 313 | value: int 314 | Number of possible actions. Must be >= 1. 315 | 316 | Raises 317 | ------ 318 | ValueError 319 | If value < 1. 320 | 321 | """ 322 | if value < 1: 323 | raise ValueError('num_actions must be at least 1.') 324 | self.__num_actions = value 325 | 326 | 327 | class RadialBasisFunction(BasisFunction): 328 | 329 | r"""Gaussian Multidimensional Radial Basis Function (RBF). 330 | 331 | Given a set of k means :math:`(\mu_1 , \ldots, \mu_k)` produce a feature 332 | vector :math:`(1, e^{-\gamma || s - \mu_1 ||^2}, \cdots, 333 | e^{-\gamma || s - \mu_k ||^2})` where `s` is the state vector and 334 | :math:`\gamma` is a free parameter. This vector will be padded with 335 | 0's on both sides proportional to the number of possible actions 336 | specified. 337 | 338 | Parameters 339 | ---------- 340 | means: list(numpy.array) 341 | List of numpy arrays representing :math:`(\mu_1, \ldots, \mu_k)`. 342 | Each :math:`\mu` is a numpy array with dimensions matching the state 343 | vector this basis function will be used with. If the dimensions of each 344 | vector are not equal then an exception will be raised. If no means are 345 | specified then a ValueError will be raised 346 | gamma: float 347 | Free parameter which controls the size/spread of the Gaussian "bumps". 348 | This parameter is best selected via tuning through cross validation. 349 | gamma must be > 0. 350 | num_actions: int 351 | Number of actions. Must be in range [1, :math:`\infty`) otherwise 352 | an exception will be raised. 353 | 354 | Raises 355 | ------ 356 | ValueError 357 | If means list is empty 358 | ValueError 359 | If dimensions of each mean vector do not match. 360 | ValueError 361 | If gamma is <= 0. 362 | ValueError 363 | If num_actions is less than 1. 364 | 365 | Note 366 | ---- 367 | 368 | The numpy arrays specifying the means are not copied. 369 | 370 | """ 371 | 372 | def __init__(self, means, gamma, num_actions): 373 | """Initialize RBF instance.""" 374 | self.__num_actions = BasisFunction._validate_num_actions(num_actions) 375 | 376 | if len(means) == 0: 377 | raise ValueError('You must specify at least one mean') 378 | 379 | if reduce(RadialBasisFunction.__check_mean_size, means) is None: 380 | raise ValueError('All mean vectors must have the same dimensions') 381 | 382 | self.means = means 383 | 384 | if gamma <= 0: 385 | raise ValueError('gamma must be > 0') 386 | 387 | self.gamma = gamma 388 | 389 | @staticmethod 390 | def __check_mean_size(left, right): 391 | """Check that consecutive mean vectors have matching shapes. 392 | 393 | This method is meant to be used with reduce.
It will return either the 394 | rightmost numpy array or None if any of the arrays had 395 | differing shapes. I wanted to use a Maybe monad here, 396 | but Python doesn't support that out of the box. 397 | 398 | Return 399 | ------ 400 | None or numpy.array 401 | None values will propagate through the reduce automatically. 402 | 403 | """ 404 | if left is None or right is None: 405 | return None 406 | else: 407 | if left.shape != right.shape: 408 | return None 409 | return right 410 | 411 | def size(self): 412 | r"""Calculate size of the :math:`\phi` matrix. 413 | 414 | The size is equal to the number of means plus one, times the 415 | number of actions. 416 | 417 | Returns 418 | ------- 419 | int 420 | The size of the phi matrix that will be returned from evaluate. 421 | 422 | """ 423 | return (len(self.means) + 1) * self.num_actions 424 | 425 | def evaluate(self, state, action): 426 | r"""Calculate the :math:`\phi` matrix. 427 | 428 | Matrix will have the following form: 429 | 430 | :math:`[\cdots, 1, e^{-\gamma || s - \mu_1 ||^2}, \cdots, 431 | e^{-\gamma || s - \mu_k ||^2}, \cdots]` 432 | 433 | where the matrix will be padded with 0's on either side depending 434 | on the specified action index and the number of possible actions. 435 | 436 | Returns 437 | ------- 438 | numpy.array 439 | The :math:`\phi` vector. Used by Policy to compute Q-value. 440 | 441 | Raises 442 | ------ 443 | IndexError 444 | If action is outside the range :math:`0 \le action < num\_actions` an IndexError is raised. 445 | ValueError 446 | If the state vector has any number of dimensions other than 1 a 447 | ValueError is raised. 448 | 449 | """ 450 | if action < 0 or action >= self.num_actions: 451 | raise IndexError('Action index out of bounds') 452 | 453 | if state.shape != self.means[0].shape: 454 | raise ValueError('Dimensions of state must match ' 455 | 'dimensions of means') 456 | 457 | phi = np.zeros((self.size(), )) 458 | offset = (len(self.means)+1)*action 459 | 460 | rbf = [RadialBasisFunction.__calc_basis_component(state, 461 | mean, 462 | self.gamma) 463 | for mean in self.means] 464 | phi[offset] = 1. 465 | phi[offset+1:offset+1+len(rbf)] = rbf 466 | 467 | return phi 468 | 469 | @staticmethod 470 | def __calc_basis_component(state, mean, gamma): 471 | mean_diff = state - mean 472 | return np.exp(-gamma*np.sum(mean_diff*mean_diff)) 473 | 474 | @property 475 | def num_actions(self): 476 | """Return number of possible actions.""" 477 | return self.__num_actions 478 | 479 | @num_actions.setter 480 | def num_actions(self, value): 481 | """Set the number of possible actions. 482 | 483 | Parameters 484 | ---------- 485 | value: int 486 | Number of possible actions. Must be >= 1. 487 | 488 | Raises 489 | ------ 490 | ValueError 491 | If value < 1. 492 | 493 | """ 494 | if value < 1: 495 | raise ValueError('num_actions must be at least 1.') 496 | self.__num_actions = value 497 | 498 | 499 | class ExactBasis(BasisFunction): 500 | 501 | """Basis function with no functional approximation. 502 | 503 | This can only be used in domains with finite, discrete state-spaces. For 504 | example the Chain domain from the LSPI paper would work with this basis, 505 | but the inverted pendulum domain would not. 506 | 507 | Parameters 508 | ---------- 509 | num_states: list 510 | A list containing integers representing the number of possible values 511 | for each state variable. 512 | num_actions: int 513 | Number of possible actions.
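As a quick worked example of the tabular indexing ExactBasis implements (a hedged editor's sketch; the numbers simply follow the `size`/`get_state_action_index` arithmetic in the code that follows):

```python
import numpy as np
from lspi.basis_functions import ExactBasis

# Two state variables with 2 and 3 possible values respectively, and 2 actions.
basis = ExactBasis(np.array([2, 3]), 2)
print(basis.size())  # 2 * 3 * 2 = 12 features in total

# Cumulative offsets are [1, 2], so state [1, 2] with action 1 maps to
# base + offset = 1 * (12 / 2) + (1 * 1 + 2 * 2) = 6 + 5 = 11.
phi = basis.evaluate(np.array([1, 2]), 1)
print(np.nonzero(phi)[0])  # [11]
```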
514 | """ 515 | 516 | def __init__(self, num_states, num_actions): 517 | """Initialize ExactBasis.""" 518 | if len(np.where(num_states <= 0)[0]) != 0: 519 | raise ValueError('num_states values must be > 0') 520 | 521 | self.__num_actions = BasisFunction._validate_num_actions(num_actions) 522 | self._num_states = num_states 523 | 524 | self._offsets = [1] 525 | for i in range(1, len(num_states)): 526 | self._offsets.append(self._offsets[-1]*num_states[i-1]) 527 | 528 | def size(self): 529 | r"""Return the vector size of the basis function. 530 | 531 | Returns 532 | ------- 533 | int 534 | The size of the :math:`\phi` vector. 535 | (Referred to as k in the paper). 536 | """ 537 | return reduce(lambda x, y: x*y, self._num_states, 1)*self.__num_actions 538 | 539 | def get_state_action_index(self, state, action): 540 | """Return the non-zero index of the basis. 541 | 542 | Parameters 543 | ---------- 544 | state: numpy.array 545 | The state to get the index for. 546 | action: int 547 | The action to get the index for. 548 | 549 | Returns 550 | ------- 551 | int 552 | The non-zero index of the basis 553 | 554 | Raises 555 | ------ 556 | IndexError 557 | If action index < 0 or action index >= num_actions 558 | """ 559 | if action < 0: 560 | raise IndexError('action index must be >= 0') 561 | if action >= self.num_actions: 562 | raise IndexError('action must be < num_actions') 563 | 564 | base = action * int(self.size() / self.__num_actions) 565 | 566 | offset = 0 567 | for i, value in enumerate(state): 568 | offset += self._offsets[i] * value 569 | 570 | return base + offset 571 | 572 | def evaluate(self, state, action): 573 | r"""Return a :math:`\phi` vector that has a single non-zero value. 574 | 575 | Parameters 576 | ---------- 577 | state: numpy.array 578 | The state to get the features for. When calculating Q(s, a) this is 579 | the s. 580 | action: int 581 | The action index to get the features for. 582 | When calculating Q(s, a) this is the a. 583 | 584 | Returns 585 | ------- 586 | numpy.array 587 | :math:`\phi` vector 588 | 589 | Raises 590 | ------ 591 | IndexError 592 | If action index < 0 or action index >= num_actions 593 | ValueError 594 | If the size of the state does not match the size of the 595 | num_states list used during construction. 596 | ValueError 597 | If any of the state variables are < 0 or >= the corresponding 598 | value in the num_states list used during construction. 599 | """ 600 | if len(state) != len(self._num_states): 601 | raise ValueError('Number of state variables must match ' 602 | + 'size of num_states.') 603 | if len(np.where(state < 0)[0]) != 0: 604 | raise ValueError('state cannot contain negative values.') 605 | for state_var, num_state_values in zip(state, self._num_states): 606 | if state_var >= num_state_values: 607 | raise ValueError('state values must be < corresponding ' 608 | + 'num_states value.') 609 | 610 | phi = np.zeros(self.size()) 611 | phi[self.get_state_action_index(state, action)] = 1 612 | 613 | return phi 614 | 615 | @property 616 | def num_actions(self): 617 | """Return number of possible actions.""" 618 | return self.__num_actions 619 | 620 | @num_actions.setter 621 | def num_actions(self, value): 622 | """Set the number of possible actions. 623 | 624 | Parameters 625 | ---------- 626 | value: int 627 | Number of possible actions. Must be >= 1. 628 | 629 | Raises 630 | ------ 631 | ValueError 632 | if value < 1.
633 |         """
634 |         if value < 1:
635 |             raise ValueError('num_actions must be at least 1.')
636 |         self.__num_actions = value
637 | 
--------------------------------------------------------------------------------
/lspi/domains.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | """Contains example domains that LSPI works on."""
 3 | 
 4 | 
 5 | import abc
 6 | 
 7 | 
 8 | from random import randint, random
 9 | 
10 | import numpy as np
11 | 
12 | from sample import Sample
13 | 
14 | 
15 | class Domain(object):
16 | 
17 |     r"""ABC for domains.
18 | 
19 |     Minimum interface for a reinforcement learning domain.
20 |     """
21 | 
22 |     __metaclass__ = abc.ABCMeta
23 | 
24 |     @abc.abstractmethod
25 |     def num_actions(self):
26 |         """Return number of possible actions for the given domain.
27 | 
28 |         Actions are indexed from 0 to num_actions - 1.
29 | 
30 |         Returns
31 |         -------
32 |         int
33 |             Number of possible actions.
34 |         """
35 |         pass  # pragma: no cover
36 | 
37 |     @abc.abstractmethod
38 |     def current_state(self):
39 |         """Return the current state of the domain.
40 | 
41 |         Returns
42 |         -------
43 |         numpy.array
44 |             The current state of the environment expressed as a numpy array
45 |             of the individual state variables.
46 |         """
47 |         pass  # pragma: no cover
48 | 
49 |     @abc.abstractmethod
50 |     def apply_action(self, action):
51 |         """Apply action and return a sample.
52 | 
53 |         Parameters
54 |         ----------
55 |         action: int
56 |             The action index to apply. This should be a number in the range
57 |             [0, num_actions())
58 | 
59 |         Returns
60 |         -------
61 |         sample.Sample
62 |             Sample containing the previous state, the action applied, the
63 |             received reward and the resulting state.
64 |         """
65 |         pass  # pragma: no cover
66 | 
67 |     @abc.abstractmethod
68 |     def reset(self, initial_state=None):
69 |         """Reset the simulator to initial conditions.
70 | 
71 |         Parameters
72 |         ----------
73 |         initial_state: numpy.array
74 |             Optionally specify the state to reset to. If None then the domain
75 |             should use its default initial set of states. The type will
76 |             generally be a numpy.array, but a subclass may accept other types.
77 | 
78 |         """
79 |         pass  # pragma: no cover
80 | 
81 |     @abc.abstractmethod
82 |     def action_name(self, action):
83 |         """Return a string representation of the action.
84 | 
85 |         Parameters
86 |         ----------
87 |         action: int
88 |             The action index to look up. This number should be in the range
89 |             [0, num_actions())
90 | 
91 |         Returns
92 |         -------
93 |         str
94 |             String representation of the action index.
95 |         """
96 |         pass  # pragma: no cover
97 | 
98 | 
99 | class ChainDomain(Domain):
100 | 
101 |     """Chain domain from LSPI paper.
102 | 
103 |     Very simple MDP. Used to test LSPI methods and demonstrate the interface.
104 |     The state space is a series of discrete nodes in a chain. There are two
105 |     actions: Left and Right. These actions fail with a configurable
106 |     probability. When an action fails, the opposite action is performed. In
107 |     other words, if left is the applied action but it fails, then the agent
108 |     will actually move right (assuming it is not in the rightmost state).
109 | 
110 |     The default reward for any action in a state is 0. There are 2 special
111 |     states that give a +1 reward for entering. The two special states can
112 |     be configured to appear at the ends of the chain, in the middle, or
113 |     in the middle of each half of the state space.
114 | 
115 |     Parameters
116 |     ----------
117 |     num_states: int
118 |         Number of states in the chain. Must be at least 4.
119 |         Defaults to 10 states.
120 |     reward_location: ChainDomain.RewardLocation
121 |         Location of the states with +1 rewards.
122 |     failure_probability: float
123 |         The probability that the applied action will fail. Must be in range
124 |         [0, 1].
125 | 
126 |     """
127 | 
128 |     class RewardLocation(object):
129 | 
130 |         """Location of states giving +1 reward in the chain.
131 | 
132 |         Ends:
133 |             Rewards will be given at the ends of the chain.
134 |         Middle:
135 |             Rewards will be given at the middle two states of the chain.
136 |         HalfMiddles:
137 |             Rewards will be given at the middle two states of each half
138 |             of the chain.
139 | 
140 |         """
141 | 
142 |         Ends, Middle, HalfMiddles = range(3)
143 | 
144 |     __action_names = ['left', 'right']
145 | 
146 |     def __init__(self, num_states=10,
147 |                  reward_location=RewardLocation.Ends,
148 |                  failure_probability=.1):
149 |         """Initialize ChainDomain."""
150 |         if num_states < 4:
151 |             raise ValueError('num_states must be >= 4')
152 |         if failure_probability < 0 or failure_probability > 1:
153 |             raise ValueError('failure_probability must be in range [0, 1]')
154 | 
155 |         self.num_states = int(num_states)
156 |         self.reward_location = reward_location
157 |         self.failure_probability = failure_probability
158 | 
159 |         self._state = ChainDomain.__init_random_state(num_states)
160 | 
161 |     def num_actions(self):
162 |         """Return number of actions.
163 | 
164 |         Chain domain has 2 actions.
165 | 
166 |         Returns
167 |         -------
168 |         int
169 |             Number of actions
170 | 
171 |         """
172 |         return 2
173 | 
174 |     def current_state(self):
175 |         """Return the current state of the domain.
176 | 
177 |         Returns
178 |         -------
179 |         numpy.array
180 |             The current state as a 1D numpy vector of type int.
181 | 
182 |         """
183 |         return self._state
184 | 
185 |     def apply_action(self, action):
186 |         """Apply the action to the chain.
187 | 
188 |         If left is applied then the occupied state index will decrease by 1,
189 |         unless the agent is already at 0, in which case the state will not
190 |         change.
191 | 
192 |         If right is applied then the occupied state index will increase by 1,
193 |         unless the agent is already at num_states-1, in which case the state
194 |         will not change.
195 | 
196 |         The reward function is determined by the reward location specified when
197 |         constructing the domain.
198 | 
199 |         If failure_probability is > 0 then there is a chance for the left
200 |         and right actions to fail. If the left action fails then the agent
201 |         will move right. Similarly if the right action fails then the agent
202 |         will move left.
203 | 
204 |         Parameters
205 |         ----------
206 |         action: int
207 |             Action index. Must be in range [0, num_actions())
208 | 
209 |         Returns
210 |         -------
211 |         sample.Sample
212 |             The sample for the applied action.
213 | 
214 |         Raises
215 |         ------
216 |         ValueError
217 |             If the action index is outside of the range [0, num_actions())
218 | 
219 |         """
220 |         if action < 0 or action >= 2:
221 |             raise ValueError('Action index outside of bounds [0, %d)' %
222 |                              self.num_actions())
223 | 
224 |         action_failed = False
225 |         if random() < self.failure_probability:
226 |             action_failed = True
227 | 
228 |         # the state holds the index of the single occupied chain location
229 |         if (action == 0 and not action_failed) \
230 |                 or (action == 1 and action_failed):
231 |             new_location = max(0, self._state[0]-1)
232 |         else:
233 |             new_location = min(self.num_states-1, self._state[0]+1)
234 | 
235 |         next_state = np.array([new_location])
236 | 
237 |         reward = 0
238 |         if self.reward_location == ChainDomain.RewardLocation.Ends:
239 |             if new_location == 0 or new_location == self.num_states-1:
240 |                 reward = 1
241 |         elif self.reward_location == ChainDomain.RewardLocation.Middle:
242 |             if new_location == int(self.num_states/2) \
243 |                     or new_location == int(self.num_states/2 + 1):
244 |                 reward = 1
245 |         else:  # HalfMiddles case
246 |             if new_location == int(self.num_states/4) \
247 |                     or new_location == int(3*self.num_states/4):
248 |                 reward = 1
249 | 
250 |         sample = Sample(self._state.copy(), action, reward, next_state.copy())
251 | 
252 |         self._state = next_state
253 | 
254 |         return sample
255 | 
256 |     def reset(self, initial_state=None):
257 |         """Reset the domain to initial state or specified state.
258 | 
259 |         If the state is unspecified then it will generate a random state, just
260 |         like when constructing from scratch.
261 | 
262 |         The state must be a 1D numpy array with a single element: the index
263 |         of the currently occupied chain location, in the range
264 |         [0, num_states). Whatever the numpy array type used, it will be
265 |         converted to an integer numpy array.
266 | 
267 |         Parameters
268 |         ----------
269 |         initial_state: numpy.array
270 |             The state to set the simulator to. If None then set to a random
271 |             state.
272 | 
273 |         Raises
274 |         ------
275 |         ValueError
276 |             If initial state's shape does not match (1, ). In
277 |             other words the initial state must be a 1D numpy array with the
278 |             same length as the existing state.
279 |         ValueError
280 |             If the state value is outside the range [0, num_states). The
281 |             single state variable must be the index of an existing chain
282 |             location; it is converted to an integer before validation.
283 | 
284 | 
285 |         """
286 |         if initial_state is None:
287 |             self._state = ChainDomain.__init_random_state(self.num_states)
288 |         else:
289 |             if initial_state.shape != (1, ):
290 |                 raise ValueError('The specified state did not match the '
291 |                                  + 'current state size')
292 |             state = initial_state.astype(np.int)
293 |             if state[0] < 0 or state[0] >= self.num_states:
294 |                 raise ValueError('State value must be in range '
295 |                                  + '[0, num_states)')
296 |             self._state = state
297 | 
298 |     def action_name(self, action):
299 |         """Return string representation of the action.
300 | 
301 |         0:
302 |             left
303 |         1:
304 |             right
305 | 
306 |         Returns
307 |         -------
308 |         str
309 |             String representation of action.
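
        Example
        -------
        Illustrative usage with the default constructor:

        >>> domain = ChainDomain()
        >>> domain.action_name(0)
        'left'
        >>> domain.action_name(1)
        'right'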
310 |         """
311 |         return ChainDomain.__action_names[action]
312 | 
313 |     @staticmethod
314 |     def __init_random_state(num_states):
315 |         """Return randomly initialized state of the specified size."""
316 |         return np.array([randint(0, num_states-1)])
317 | 
--------------------------------------------------------------------------------
/lspi/lspi.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | """Contains main interface to LSPI algorithm."""
 3 | 
 4 | from copy import copy
 5 | 
 6 | import numpy as np
 7 | 
 8 | 
 9 | def learn(data, initial_policy, solver, epsilon=10**-5, max_iterations=10):
10 |     r"""Find the optimal policy for the specified data.
11 | 
12 |     Parameters
13 |     ----------
14 |     data:
15 |         Generally a list of samples; however, the type of data does not matter
16 |         so long as the specified solver can handle it in its solve routine. For
17 |         example, when doing model-based learning one might pass in a model
18 |         instead of sample data.
19 |     initial_policy: Policy
20 |         Starting policy. A copy of this policy will be made at the start of the
21 |         method. This means that the provided initial policy will be preserved.
22 |     solver: Solver
23 |         A subclass of the Solver abstract base class. This class must implement
24 |         the solve method. Examples of solvers might be steepest descent or
25 |         any other linear system of equation matrix solver. This is basically
26 |         going to be implementations of the LSTDQ algorithm.
27 |     epsilon: float
28 |         The threshold of the change in policy weights. Determines if the policy
29 |         has converged. When the L2-norm of the change in weights is less than
30 |         this value the policy is considered converged.
31 |     max_iterations: int
32 |         The maximum number of iterations to run before giving up on
33 |         convergence. The change in policy weights is not guaranteed to ever
34 |         go below epsilon. To prevent an infinite loop this parameter must be
35 |         specified.
36 | 
37 |     Returns
38 |     -------
39 |     Policy
40 |         The converged policy. If the policy does not converge by max_iterations
41 |         then this will be the last iteration's policy.
42 | 
43 |     Raises
44 |     ------
45 |     ValueError
46 |         If epsilon is <= 0.
47 |     ValueError
48 |         If max_iterations is <= 0.
49 | 
50 |     """
51 |     if epsilon <= 0:
52 |         raise ValueError('epsilon must be > 0: %g' % epsilon)
53 |     if max_iterations <= 0:
54 |         raise ValueError('max_iterations must be > 0: %d' % max_iterations)
55 | 
56 |     # this is just to make sure that changing the weight vector doesn't
57 |     # affect the original policy weights
58 |     curr_policy = copy(initial_policy)
59 | 
60 |     distance = float('inf')
61 |     iteration = 0
62 |     while distance > epsilon and iteration < max_iterations:
63 |         iteration += 1
64 |         new_weights = solver.solve(data, curr_policy)
65 | 
66 |         distance = np.linalg.norm(new_weights - curr_policy.weights)
67 |         curr_policy.weights = new_weights
68 | 
69 |     return curr_policy
70 | 
--------------------------------------------------------------------------------
/lspi/policy.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | """LSPI Policy class used for learning and executing policy."""
 3 | 
 4 | import random
 5 | 
 6 | import numpy as np
 7 | 
 8 | 
 9 | class Policy(object):
10 | 
11 |     r"""Represents LSPI policy. Used for sampling, learning, and executing.
12 | 
13 |     The policy class includes an exploration value which controls the
14 |     probability of performing a random action instead of the best action
15 |     according to the policy. This can be useful during sampling.
16 | 
17 |     It also includes the discount factor :math:`\gamma`, number of possible
18 |     actions and the basis function used for this policy.
19 | 
20 |     Parameters
21 |     ----------
22 |     basis: BasisFunction
23 |         The basis function used to compute :math:`\phi` which is used to select
24 |         the best action according to the policy.
25 |     discount: float, optional
26 |         The discount factor :math:`\gamma`. Defaults to 1.0 which is valid
27 |         for finite horizon problems.
28 |     explore: float, optional
29 |         Probability of executing a random action instead of the best action
30 |         according to the policy. Defaults to 0 which means no exploration.
31 |     weights: numpy.array or None
32 |         The weight vector which is dotted with the :math:`\phi` vector from
33 |         basis to produce the approximate Q value. When None is passed in
34 |         the weight vector is initialized with random weights.
35 |     tie_breaking_strategy: Policy.TieBreakingStrategy value
36 |         The strategy to use if a tie occurs when selecting the best action.
37 |         See the :py:class:`lspi.policy.Policy.TieBreakingStrategy`
38 |         class description for what the different options are.
39 | 
40 |     Raises
41 |     ------
42 |     ValueError
43 |         If discount is < 0 or > 1
44 |     ValueError
45 |         If explore is < 0 or > 1
46 |     ValueError
47 |         If weights are not None and the number of dimensions does not match
48 |         the size of the basis function.
49 |     """
50 | 
51 |     class TieBreakingStrategy(object):
52 | 
53 |         """Strategy for breaking a tie between actions in the policy.
54 | 
55 |         FirstWins:
56 |             In the event of a tie the first action encountered with that
57 |             value is returned.
58 |         LastWins:
59 |             In the event of a tie the last action encountered with that
60 |             value is returned.
61 |         RandomWins:
62 |             In the event of a tie a random action encountered with that
63 |             value is returned.
64 | 
65 |         """
66 | 
67 |         FirstWins, LastWins, RandomWins = range(3)
68 | 
69 |     def __init__(self, basis, discount=1.0,
70 |                  explore=0.0, weights=None,
71 |                  tie_breaking_strategy=TieBreakingStrategy.RandomWins):
72 |         """Initialize a Policy."""
73 |         self.basis = basis
74 | 
75 |         if discount < 0.0 or discount > 1.0:
76 |             raise ValueError('discount must be in range [0, 1]')
77 | 
78 |         self.discount = discount
79 | 
80 |         if explore < 0.0 or explore > 1.0:
81 |             raise ValueError('explore must be in range [0, 1]')
82 | 
83 |         self.explore = explore
84 | 
85 |         if weights is None:
86 |             self.weights = np.random.uniform(-1.0, 1.0, size=(basis.size(),))
87 |         else:
88 |             if weights.shape != (basis.size(), ):
89 |                 raise ValueError('weights shape must equal (basis.size(), )')
90 |             self.weights = weights
91 | 
92 |         self.tie_breaking_strategy = tie_breaking_strategy
93 | 
94 |     def __copy__(self):
95 |         """Return a copy of this class with a deep copy of the weights."""
96 |         return Policy(self.basis,
97 |                       self.discount,
98 |                       self.explore,
99 |                       self.weights.copy(),
100 |                      self.tie_breaking_strategy)
101 | 
102 |     def calc_q_value(self, state, action):
103 |         """Calculate the Q function for the given state action pair.
104 | 
105 |         Parameters
106 |         ----------
107 |         state: numpy.array
108 |             State vector that Q value is being calculated for. This is
109 |             the s in Q(s, a)
110 |         action: int
111 |             Action index that Q value is being calculated for. This is
112 |             the a in Q(s, a)
113 | 
114 |         Returns
115 |         -------
116 |         float
117 |             The Q value for the state action pair
118 | 
119 |         Raises
120 |         ------
121 |         ValueError
122 |             If state's dimensions do not conform to basis function expectations
123 |         IndexError
124 |             If action is outside of the range of valid action indexes
125 | 
126 |         """
127 |         if action < 0 or action >= self.basis.num_actions:
128 |             raise IndexError('action must be in range [0, num_actions)')
129 | 
130 |         return self.weights.dot(self.basis.evaluate(state, action))
131 | 
132 |     def best_action(self, state):
133 |         """Select the best action according to the policy.
134 | 
135 |         This calculates argmax_a Q(state, a). In other words it returns
136 |         the action that maximizes the Q value for this state.
137 | 
138 |         Parameters
139 |         ----------
140 |         state: numpy.array
141 |             State vector.
142 | 
143 |         Ties are resolved according to this policy instance's
144 |         tie_breaking_strategy attribute.
145 | 
146 |         Returns
147 |         -------
148 |         int
149 |             Action index
150 | 
151 |         Raises
152 |         ------
153 |         ValueError
154 |             If state's dimensions do not match the basis function's expectations.
155 | 
156 |         """
157 |         q_values = [self.calc_q_value(state, action)
158 |                     for action in range(self.basis.num_actions)]
159 | 
160 |         best_q = float('-inf')
161 |         best_actions = []
162 |         for action, q_value in enumerate(q_values):
163 |             if q_value > best_q:
164 |                 best_actions = [action]
165 |                 best_q = q_value
166 |             elif q_value == best_q:
167 |                 best_actions.append(action)
168 | 
169 |         if self.tie_breaking_strategy == Policy.TieBreakingStrategy.FirstWins:
170 |             return best_actions[0]
171 |         elif self.tie_breaking_strategy == Policy.TieBreakingStrategy.LastWins:
172 |             return best_actions[-1]
173 |         else:
174 |             return random.choice(best_actions)
175 | 
176 |     def select_action(self, state):
177 |         """With random probability select best action or random action.
178 | 
179 |         If the random number is below the explore value then pick a random
180 |         action, otherwise pick the best action according to the basis and
181 |         policy weights.
182 | 
183 |         Parameters
184 |         ----------
185 |         state: numpy.array
186 |             State vector
187 | 
188 |         Returns
189 |         -------
190 |         int
191 |             Action index
192 | 
193 |         Raises
194 |         ------
195 |         ValueError
196 |             If state's dimensions do not match the basis function's expectations.
197 | 
198 |         """
199 |         if random.random() < self.explore:
200 |             return random.choice(range(self.basis.num_actions))
201 |         else:
202 |             return self.best_action(state)
203 | 
204 |     @property
205 |     def num_actions(self):
206 |         r"""Return number of possible actions.
207 | 
208 |         This number should always match the value stored in basis.num_actions.
209 | 
210 |         Returns
211 |         -------
212 |         int
213 |             Number of possible actions. In range [1, :math:`\infty`)
214 | 
215 |         """
216 |         return self.basis.num_actions
217 | 
218 |     @num_actions.setter
219 |     def num_actions(self, value):
220 |         """Set the number of possible actions.
221 | 
222 |         This number should always match the value stored in basis.num_actions.
223 | 
224 |         Parameters
225 |         ----------
226 |         value: int
227 |             Value to set num_actions to. Must be >= 1
228 | 
229 |         Raises
230 |         ------
231 |         ValueError
232 |             If value is < 1
233 | 
234 |         """
235 |         self.basis.num_actions = value
236 | 
--------------------------------------------------------------------------------
/lspi/sample.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | """Contains class representing an LSPI sample."""
 3 | 
 4 | 
 5 | class Sample(object):
 6 | 
 7 |     """Represents an LSPI sample tuple ``(s, a, r, s', absorb)``.
 8 | 
 9 |     Parameters
10 |     ----------
11 | 
12 |     state : numpy.array
13 |         State of the environment at the start of the sample.
14 |         ``s`` in the sample tuple.
15 |         (The usual type is a numpy array.)
16 |     action : int
17 |         Index of action that was executed.
18 |         ``a`` in the sample tuple.
19 |     reward : float
20 |         Reward received from the environment.
21 |         ``r`` in the sample tuple.
22 |     next_state : numpy.array
23 |         State of the environment after executing the sample's action.
24 |         ``s'`` in the sample tuple.
25 |         (The type should match that of state.)
26 |     absorb : bool, optional
27 |         True if this sample ended the episode. False otherwise.
28 |         ``absorb`` in the sample tuple.
29 |         (The default is False, which implies that this is a
30 |         non-episode-ending sample.)
31 | 
32 | 
33 |     Assumes that this is a non-absorbing sample (as the vast majority
34 |     of samples will be non-absorbing).
35 | 
36 |     This class is just a dumb data holder so the types of the different
37 |     fields can be anything convenient for the problem domain.
38 | 
39 |     For states represented by vectors a numpy array works well.
40 | 
41 |     """
42 | 
43 |     def __init__(self, state, action, reward, next_state, absorb=False):
44 |         """Initialize Sample instance."""
45 |         self.state = state
46 |         self.action = action
47 |         self.reward = reward
48 |         self.next_state = next_state
49 |         self.absorb = absorb
50 | 
51 |     def __repr__(self):
52 |         """Create string representation of tuple."""
53 |         return 'Sample(%s, %s, %s, %s, %s)' % (self.state,
54 |                                                self.action,
55 |                                                self.reward,
56 |                                                self.next_state,
57 |                                                self.absorb)
58 | 
--------------------------------------------------------------------------------
/lspi/solvers.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | """Contains the Solver abstract base class and various LSTDQ solvers."""
 3 | 
 4 | import abc
 5 | import logging
 6 | 
 7 | import numpy as np
 8 | 
 9 | import scipy.linalg
10 | 
11 | 
12 | class Solver(object):
13 | 
14 |     r"""ABC for LSPI solvers.
15 | 
16 |     Implementations of this class will implement the various LSTDQ algorithms
17 |     with various linear algebra solving techniques. This solver will be used
18 |     by the lspi.learn method. The instance will be called iteratively until
19 |     the convergence parameters are satisfied.
20 | 
21 |     """
22 | 
23 |     __metaclass__ = abc.ABCMeta
24 | 
25 |     @abc.abstractmethod
26 |     def solve(self, data, policy):
27 |         r"""Return one-step update of the policy weights for the given data.
28 | 
29 |         Parameters
30 |         ----------
31 |         data:
32 |             This is the data used by the solver. In most cases this will be
33 |             a list of samples. But it can be anything supported by the specific
34 |             Solver implementation's solve method.
35 |         policy: Policy
36 |             The current policy to find an improvement to.
37 | 
38 |         Returns
39 |         -------
40 |         numpy.array
41 |             Return the new weights as determined by this method.
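
        Notes
        -----
        As a point of reference, an LSTDQ-style implementation (such as the
        LSTDQSolver below) accumulates, over the samples :math:`(s, a, r, s')`,

        .. math::
            A = \sum_i \phi(s_i, a_i)
                \left(\phi(s_i, a_i) - \gamma \phi(s'_i, \pi(s'_i))\right)^T

        .. math::
            b = \sum_i \phi(s_i, a_i) r_i

        and then returns the weights :math:`w` that solve the linear system
        :math:`Aw = b`.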
42 | 
43 |         """
44 |         pass  # pragma: no cover
45 | 
46 | 
47 | class LSTDQSolver(Solver):
48 | 
49 |     """LSTDQ Implementation with standard matrix solvers.
50 | 
51 |     Uses the algorithm from Figure 5 of the LSPI paper. If the A matrix
52 |     turns out to be full rank then scipy's standard linalg solver is used. If
53 |     the matrix turns out to be less than full rank then the least squares
54 |     method will be used.
55 | 
56 |     By default the A matrix will have its diagonal preconditioned with a small
57 |     positive value. This will help to ensure that even with few samples the
58 |     A matrix will be full rank. If you do not want the A matrix to be
59 |     preconditioned then you can set this value to 0.
60 | 
61 |     Parameters
62 |     ----------
63 |     precondition_value: float
64 |         Value to set A matrix diagonals to. Should be a small positive number.
65 |         If you do not want preconditioning enabled then set it to 0.
66 |     """
67 | 
68 |     def __init__(self, precondition_value=.1):
69 |         """Initialize LSTDQSolver."""
70 |         self.precondition_value = precondition_value
71 | 
72 |     def solve(self, data, policy):
73 |         """Run LSTDQ iteration.
74 | 
75 |         See Figure 5 of the LSPI paper for more information.
76 |         """
77 |         k = policy.basis.size()
78 |         a_mat = np.zeros((k, k))
79 |         np.fill_diagonal(a_mat, self.precondition_value)
80 | 
81 |         b_vec = np.zeros((k, 1))
82 | 
83 |         for sample in data:
84 |             phi_sa = (policy.basis.evaluate(sample.state, sample.action)
85 |                       .reshape((-1, 1)))
86 | 
87 |             if not sample.absorb:
88 |                 best_action = policy.best_action(sample.next_state)
89 |                 phi_sprime = (policy.basis
90 |                               .evaluate(sample.next_state, best_action)
91 |                               .reshape((-1, 1)))
92 |             else:
93 |                 phi_sprime = np.zeros((k, 1))
94 | 
95 |             a_mat += phi_sa.dot((phi_sa - policy.discount*phi_sprime).T)
96 |             b_vec += phi_sa*sample.reward
97 | 
98 |         a_rank = np.linalg.matrix_rank(a_mat)
99 |         if a_rank == k:
100 |             w = scipy.linalg.solve(a_mat, b_vec)
101 |         else:
102 |             logging.warning('A matrix is not full rank. %d < %d', a_rank, k)
103 |             w = scipy.linalg.lstsq(a_mat, b_vec)[0]
104 |         return w.reshape((-1, ))
105 | 
--------------------------------------------------------------------------------
/lspi_testsuite/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Testsuite for the lspi package."""
--------------------------------------------------------------------------------
/lspi_testsuite/test_basis_functions.py:
--------------------------------------------------------------------------------
  1 | # -*- coding: utf-8 -*-
  2 | """Contains unit tests for the basis function module."""
  3 | from unittest import TestCase
  4 | 
  5 | from lspi.basis_functions import (BasisFunction,
  6 |                                   FakeBasis,
  7 |                                   OneDimensionalPolynomialBasis,
  8 |                                   RadialBasisFunction,
  9 |                                   ExactBasis)
 10 | import numpy as np
 11 | 
 12 | class TestBasisFunction(TestCase):
 13 |     def test_require_size_method(self):
 14 |         """Test BasisFunction implementation requires size method."""
 15 | 
 16 |         class MissingSizeBasis(BasisFunction):
 17 |             def evaluate(self, state, action):
 18 |                 pass
 19 | 
 20 |             @property
 21 |             def num_actions(self):
 22 |                 pass
 23 | 
 24 |         with self.assertRaises(TypeError):
 25 |             MissingSizeBasis()
 26 | 
 27 |     def test_require_evaluate_method(self):
 28 |         """Test BasisFunction implementation requires evaluate method."""
 29 | 
 30 |         class MissingEvaluateBasis(BasisFunction):
 31 |             def size(self):
 32 |                 pass
 33 | 
 34 |             @property
 35 |             def num_actions(self):
 36 |                 pass
 37 | 
 38 |         with self.assertRaises(TypeError):
 39 |             MissingEvaluateBasis()
 40 | 
 41 |     def test_require_num_actions_property(self):
 42 | 
 43 |         class MissingNumActionsProperty(BasisFunction):
 44 |             def size(self):
 45 |                 pass
 46 | 
 47 |             def evaluate(self, state, action):
 48 |                 pass
 49 | 
 50 |         with self.assertRaises(TypeError):
 51 |             MissingNumActionsProperty()
 52 | 
 53 |     def test_works_with_both_methods_implemented(self):
 54 |         """Test BasisFunction implementation works when all methods are defined."""
 55 | 
 56 |         class ShouldWorkBasis(BasisFunction):
 57 | 
 58 |             def size(self):
 59 |                 pass
 60 | 
 61 |             def evaluate(self, state, action):
 62 |                 pass
 63 | 
 64 |             @property
 65 |             def num_actions(self):
 66 |                 pass
 67 | 
 68 |         ShouldWorkBasis()
 69 | 
 70 |     def test_validate_num_actions(self):
 71 |         self.assertEqual(BasisFunction._validate_num_actions(6), 6)
 72 | 
 73 |     def test_validate_num_actions_out_of_bounds(self):
 74 |         with self.assertRaises(ValueError):
 75 |             BasisFunction._validate_num_actions(0)
 76 | 
 77 | 
 78 | class TestFakeBasis(TestCase):
 79 |     def setUp(self):
 80 |         self.basis = FakeBasis(6)
 81 | 
 82 |     def test_num_actions_property(self):
 83 |         self.assertEqual(self.basis.num_actions, 6)
 84 | 
 85 |     def test_num_actions_setter(self):
 86 |         self.basis.num_actions = 10
 87 | 
 88 |         self.assertEqual(self.basis.num_actions, 10)
 89 | 
 90 |     def test_num_actions_setter_invalid_value(self):
 91 |         with self.assertRaises(ValueError):
 92 |             self.basis.num_actions = 0
 93 | 
 94 |     def test_size(self):
 95 |         self.assertEqual(self.basis.size(), 1)
 96 | 
 97 |     def test_evaluate(self):
 98 |         np.testing.assert_array_almost_equal(self.basis.evaluate(None, 0),
 99 |                                              np.array([1.]))
100 | 
101 |     def test_evaluate_negative_action_index(self):
102 |         with self.assertRaises(IndexError):
103 |             self.basis.evaluate(None, -1)
104 | 
105 |     def test_evaluate_out_of_bounds_action_index(self):
106 |         with self.assertRaises(IndexError):
107 |             self.basis.evaluate(None, 6)
108 | 
109 | class TestOneDimensionalPolynomialBasis(TestCase):
110 |     def setUp(self):
111 | 
112 |         self.basis =
OneDimensionalPolynomialBasis(2, 2) 113 | 114 | def test_specify_degree(self): 115 | 116 | self.assertEqual(self.basis.degree, 2) 117 | 118 | def test_specify_actions(self): 119 | 120 | self.assertEqual(self.basis.num_actions, 2) 121 | 122 | def test_num_actions_setter(self): 123 | self.basis.num_actions = 10 124 | 125 | self.assertEqual(self.basis.num_actions, 10) 126 | 127 | def test_num_actions_setter_invalid_value(self): 128 | with self.assertRaises(ValueError): 129 | self.basis.num_actions = 0 130 | 131 | def test_out_of_bounds_degree(self): 132 | 133 | with self.assertRaises(ValueError): 134 | OneDimensionalPolynomialBasis(-1, 2) 135 | 136 | def test_out_of_bounds_num_action(self): 137 | 138 | with self.assertRaises(ValueError): 139 | OneDimensionalPolynomialBasis(2, 0) 140 | 141 | with self.assertRaises(ValueError): 142 | OneDimensionalPolynomialBasis(2, -1) 143 | 144 | def test_size(self): 145 | 146 | self.assertEqual(self.basis.size(), 6) 147 | 148 | def test_evaluate(self): 149 | 150 | phi = self.basis.evaluate(np.array([2]), 1) 151 | self.assertEqual(phi.shape, (6, )) 152 | np.testing.assert_array_almost_equal(phi, 153 | np.array([0., 0., 0., 1., 2., 4.])) 154 | 155 | def test_evaluate_out_of_bounds_action(self): 156 | 157 | with self.assertRaises(IndexError): 158 | self.basis.evaluate(np.array([2]), 2) 159 | 160 | with self.assertRaises(IndexError): 161 | self.basis.evaluate(np.array([2]), -1) 162 | 163 | def test_evaluate_incorrect_state_dimensions(self): 164 | 165 | with self.assertRaises(ValueError): 166 | self.basis.evaluate(np.array([2, 3]), 0) 167 | 168 | class TestRadialBasisFunction(TestCase): 169 | def setUp(self): 170 | 171 | self.means = [-np.ones((3, )), np.zeros((3, )), np.ones((3, ))] 172 | self.gamma = 1 173 | self.num_actions = 2 174 | self.basis = RadialBasisFunction(self.means, 175 | self.gamma, 176 | self.num_actions) 177 | self.state = np.zeros((3, )) 178 | 179 | def test_specify_means(self): 180 | 181 | for mean, expected_mean in zip(self.basis.means, self.means): 182 | np.testing.assert_array_almost_equal(mean, expected_mean) 183 | 184 | def test_empty_means_list(self): 185 | with self.assertRaises(ValueError): 186 | RadialBasisFunction([], self.gamma, self.num_actions) 187 | 188 | def test_mismatched_mean_shapes(self): 189 | with self.assertRaises(ValueError): 190 | RadialBasisFunction([np.zeros((3, )), 191 | -np.ones((2, )), 192 | np.ones((3, ))], 193 | self.gamma, 194 | self.num_actions) 195 | 196 | def test_specify_gamma(self): 197 | self.assertAlmostEqual(self.gamma, self.basis.gamma) 198 | 199 | def test_out_of_bounds_gamma(self): 200 | with self.assertRaises(ValueError): 201 | RadialBasisFunction(self.means, 0, self.num_actions) 202 | 203 | def test_specify_actions(self): 204 | 205 | self.assertEqual(self.basis.num_actions, self.num_actions) 206 | 207 | def test_num_actions_setter(self): 208 | self.basis.num_actions = 10 209 | 210 | self.assertEqual(self.basis.num_actions, 10) 211 | 212 | def test_num_actions_setter_invalid_value(self): 213 | with self.assertRaises(ValueError): 214 | self.basis.num_actions = 0 215 | 216 | def test_out_of_bounds_num_action(self): 217 | 218 | with self.assertRaises(ValueError): 219 | RadialBasisFunction(self.means, self.gamma, 0) 220 | 221 | with self.assertRaises(ValueError): 222 | RadialBasisFunction(self.means, self.gamma, -1) 223 | 224 | def test_size(self): 225 | 226 | self.assertEqual(self.basis.size(), 8) 227 | 228 | def test_evaluate(self): 229 | 230 | phi = self.basis.evaluate(self.state, 0) 231 | 
self.assertEqual(phi.shape, (8, )) 232 | np.testing.assert_array_almost_equal(phi, 233 | np.array([1., 234 | 0.0498, 235 | 1., 236 | 0.0498, 237 | 0., 238 | 0., 239 | 0., 240 | 0.]), 241 | 4) 242 | 243 | def test_evaluate_out_of_bounds_action(self): 244 | 245 | with self.assertRaises(IndexError): 246 | self.basis.evaluate(self.state, 2) 247 | 248 | with self.assertRaises(IndexError): 249 | self.basis.evaluate(self.state, -1) 250 | 251 | def test_evaluate_incorrect_state_dimensions(self): 252 | 253 | with self.assertRaises(ValueError): 254 | self.basis.evaluate(np.zeros((2, )), 0) 255 | 256 | class TestExactBasis(TestCase): 257 | def setUp(self): 258 | self.basis = ExactBasis([2, 3, 4], 2) 259 | 260 | def test_invalid_num_states(self): 261 | num_states = np.ones(3) 262 | num_states[0] = 0 263 | 264 | with self.assertRaises(ValueError): 265 | ExactBasis(num_states, 2) 266 | 267 | def test_num_actions_property(self): 268 | self.assertEqual(self.basis.num_actions, 2) 269 | 270 | def test_num_actions_setter(self): 271 | self.basis.num_actions = 3 272 | 273 | self.assertEqual(self.basis.num_actions, 3) 274 | 275 | def test_num_actions_setter_invalid_value(self): 276 | with self.assertRaises(ValueError): 277 | self.basis.num_actions = 0 278 | 279 | def test_size(self): 280 | self.assertEqual(self.basis.size(), 48) 281 | 282 | def test_evaluate(self): 283 | phi = self.basis.evaluate(np.array([0, 0, 0]), 0) 284 | self.assertEqual(phi.shape, (48, )) 285 | 286 | expected_phi = np.zeros((48, )) 287 | expected_phi[0] = 1 288 | 289 | np.testing.assert_array_almost_equal(phi, expected_phi) 290 | 291 | phi = self.basis.evaluate(np.array([1, 0, 0]), 0) 292 | self.assertEqual(phi.shape, (48, )) 293 | 294 | expected_phi = np.zeros((48, )) 295 | expected_phi[1] = 1 296 | 297 | np.testing.assert_array_almost_equal(phi, expected_phi) 298 | 299 | phi = self.basis.evaluate(np.array([0, 1, 0]), 0) 300 | self.assertEqual(phi.shape, (48, )) 301 | 302 | expected_phi = np.zeros((48, )) 303 | expected_phi[2] = 1 304 | 305 | np.testing.assert_array_almost_equal(phi, expected_phi) 306 | 307 | phi = self.basis.evaluate(np.array([0, 0, 1]), 0) 308 | self.assertEqual(phi.shape, (48, )) 309 | 310 | expected_phi = np.zeros((48, )) 311 | expected_phi[6] = 1 312 | 313 | np.testing.assert_array_almost_equal(phi, expected_phi) 314 | 315 | phi = self.basis.evaluate(np.array([0, 0, 0]), 1) 316 | self.assertEqual(phi.shape, (48, )) 317 | 318 | expected_phi = np.zeros((48, )) 319 | expected_phi[24] = 1 320 | 321 | np.testing.assert_array_almost_equal(phi, expected_phi) 322 | 323 | phi = self.basis.evaluate(np.array([1, 2, 3]), 1) 324 | self.assertEqual(phi.shape, (48, )) 325 | 326 | expected_phi = np.zeros((48, )) 327 | expected_phi[47] = 1 328 | 329 | np.testing.assert_array_almost_equal(phi, expected_phi) 330 | 331 | def test_evaluate_out_of_bounds_action(self): 332 | with self.assertRaises(IndexError): 333 | self.basis.evaluate(np.array([0, 0, 0]), -1) 334 | 335 | 336 | with self.assertRaises(IndexError): 337 | self.basis.evaluate(np.array([0, 0, 0]), 3) 338 | 339 | def test_evaluate_out_of_bounds_state(self): 340 | with self.assertRaises(ValueError): 341 | self.basis.evaluate(np.array([-1, 0, 0]), 0) 342 | 343 | 344 | with self.assertRaises(ValueError): 345 | self.basis.evaluate(np.array([0, -1, 0]), 0) 346 | 347 | 348 | with self.assertRaises(ValueError): 349 | self.basis.evaluate(np.array([0, 0, -1]), 0) 350 | 351 | 352 | with self.assertRaises(ValueError): 353 | self.basis.evaluate(np.array([2, 0, 0]), 0) 354 | 355 | 356 | 
with self.assertRaises(ValueError):
357 |             self.basis.evaluate(np.array([0, 3, 0]), 0)
358 | 
359 | 
360 |         with self.assertRaises(ValueError):
361 |             self.basis.evaluate(np.array([0, 0, 4]), 0)
362 | 
363 |     def test_evaluate_wrong_size_state(self):
364 |         with self.assertRaises(ValueError):
365 |             self.basis.evaluate(np.array([0]), 0)
366 | 
367 |         with self.assertRaises(ValueError):
368 |             self.basis.evaluate(np.array([0, 0, 0, 0]), 0)
--------------------------------------------------------------------------------
/lspi_testsuite/test_domains.py:
--------------------------------------------------------------------------------
  1 | # -*- coding: utf-8 -*-
  2 | """Contains unit tests for the included domains."""
  3 | from unittest import TestCase
  4 | 
  5 | from lspi.domains import ChainDomain
  6 | import numpy as np
  7 | 
  8 | class TestChainDomain(TestCase):
  9 |     def setUp(self):
 10 |         self.num_states = 20
 11 |         self.reward_location = ChainDomain.RewardLocation.HalfMiddles
 12 |         self.failure_probability = .3
 13 |         self.domain = ChainDomain(self.num_states,
 14 |                                   self.reward_location,
 15 |                                   self.failure_probability)
 16 | 
 17 |     def test_minimum_number_of_states(self):
 18 |         """Test that the domain throws an error if num_states < 4."""
 19 | 
 20 |         with self.assertRaises(ValueError):
 21 |             ChainDomain(3)
 22 | 
 23 |     def test_invalid_failure_probability(self):
 24 |         """Test that an error is raised if failure probability is < 0 or > 1."""
 25 | 
 26 |         with self.assertRaises(ValueError):
 27 |             ChainDomain(failure_probability=-.1)
 28 | 
 29 |         with self.assertRaises(ValueError):
 30 |             ChainDomain(failure_probability=1.1)
 31 | 
 32 |     def test_init_parameters_are_used(self):
 33 |         """Test that init parameters are used."""
 34 | 
 35 |         self.assertEquals(self.domain.reward_location,
 36 |                           self.reward_location)
 37 |         self.assertEquals(self.domain.failure_probability,
 38 |                           self.failure_probability)
 39 | 
 40 |     def test_num_actions(self):
 41 |         """Test ChainDomain num_actions implementation."""
 42 | 
 43 |         self.assertEquals(self.domain.num_actions(), 2)
 44 | 
 45 |     def test_reset_with_no_specified_state(self):
 46 |         """Test reset with no specified state."""
 47 | 
 48 |         self.domain.reset()  # basically test that no exception is thrown
 49 | 
 50 |     def test_reset_with_specified_state(self):
 51 |         """Test reset with a valid state specified."""
 52 | 
 53 |         new_state = np.array([0])
 54 | 
 55 |         self.domain.reset(new_state)
 56 | 
 57 |         curr_state = self.domain.current_state()
 58 |         self.assertEquals(curr_state[0], 0)
 59 | 
 60 |     def test_reset_with_diff_sized_state(self):
 61 |         """Test reset with a different sized state vector."""
 62 | 
 63 |         new_state = np.zeros(self.num_states+1)
 64 |         new_state[0] = 1
 65 | 
 66 |         with self.assertRaises(ValueError):
 67 |             self.domain.reset(new_state)
 68 | 
 69 |     def test_reset_with_invalid_values(self):
 70 |         """Test reset with state values outside the range [0, num_states)."""
 71 | 
 72 |         new_state = np.array([-1])
 73 | 
 74 |         with self.assertRaises(ValueError):
 75 |             self.domain.reset(new_state)
 76 | 
 77 |         new_state = np.array([self.num_states])
 78 | 
 79 |         with self.assertRaises(ValueError):
 80 |             self.domain.reset(new_state)
 81 | 
 82 |     def test_action_name(self):
 83 |         """Test action_name method."""
 84 | 
 85 |         self.assertEquals(self.domain.action_name(0), "left")
 86 |         self.assertEquals(self.domain.action_name(1), "right")
 87 | 
 88 |     def test_deterministic_left(self):
 89 |         """Test deterministic left action."""
 90 | 
 91 |         num_states = 10
 92 |         starting_state = np.array([2])
 93 | 
 94 |         chain_domain = ChainDomain(num_states,
 95 |                                    ChainDomain.RewardLocation.Ends,
 96 |                                    0)
 97 | 
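        # the third argument sets failure_probability to 0, so the domain is
        # deterministic and the actions below always succeed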
chain_domain.reset(starting_state) 98 | 99 | np.testing.assert_array_equal(chain_domain.current_state(), 100 | starting_state) 101 | 102 | expected_state = np.array([1]) 103 | 104 | sample = chain_domain.apply_action(0) 105 | np.testing.assert_array_equal(sample.state, starting_state) 106 | self.assertEquals(sample.action, 0) 107 | self.assertEquals(sample.reward, 0) 108 | np.testing.assert_array_equal(sample.next_state, expected_state) 109 | self.assertFalse(sample.absorb) 110 | np.testing.assert_array_equal(chain_domain.current_state(), 111 | expected_state) 112 | 113 | def test_deterministic_left_chain_end(self): 114 | """Test deterministic left action at the end of the chain.""" 115 | 116 | num_states = 10 117 | starting_state = np.array([0]) 118 | 119 | chain_domain = ChainDomain(num_states, 120 | ChainDomain.RewardLocation.Ends, 121 | 0) 122 | chain_domain.reset(starting_state) 123 | 124 | np.testing.assert_array_equal(chain_domain.current_state(), 125 | starting_state) 126 | 127 | expected_state = starting_state.copy() 128 | 129 | sample = chain_domain.apply_action(0) 130 | np.testing.assert_array_equal(sample.state, starting_state) 131 | self.assertEquals(sample.action, 0) 132 | self.assertEquals(sample.reward, 1) 133 | np.testing.assert_array_equal(sample.next_state, expected_state) 134 | self.assertFalse(sample.absorb) 135 | np.testing.assert_array_equal(chain_domain.current_state(), 136 | expected_state) 137 | 138 | def test_deterministic_right(self): 139 | """Test deterministic right action.""" 140 | 141 | num_states = 10 142 | starting_state = np.array([num_states-3]) 143 | 144 | chain_domain = ChainDomain(num_states, 145 | ChainDomain.RewardLocation.Ends, 146 | 0) 147 | chain_domain.reset(starting_state) 148 | 149 | np.testing.assert_array_equal(chain_domain.current_state(), 150 | starting_state) 151 | 152 | expected_state = np.array([num_states-2]) 153 | 154 | sample = chain_domain.apply_action(1) 155 | np.testing.assert_array_equal(sample.state, starting_state) 156 | self.assertEquals(sample.action, 1) 157 | self.assertEquals(sample.reward, 0) 158 | np.testing.assert_array_equal(sample.next_state, expected_state) 159 | self.assertFalse(sample.absorb) 160 | np.testing.assert_array_equal(chain_domain.current_state(), 161 | expected_state) 162 | 163 | def test_deterministic_right_chain_end(self): 164 | """Test deterministic right action at the end of the chain.""" 165 | 166 | num_states = 10 167 | starting_state = np.array([num_states-1]) 168 | 169 | chain_domain = ChainDomain(num_states, 170 | ChainDomain.RewardLocation.Ends, 171 | 0) 172 | chain_domain.reset(starting_state) 173 | 174 | np.testing.assert_array_equal(chain_domain.current_state(), 175 | starting_state) 176 | 177 | expected_state = starting_state.copy() 178 | 179 | sample = chain_domain.apply_action(1) 180 | np.testing.assert_array_equal(sample.state, starting_state) 181 | self.assertEquals(sample.action, 1) 182 | self.assertEquals(sample.reward, 1) 183 | np.testing.assert_array_equal(sample.next_state, expected_state) 184 | self.assertFalse(sample.absorb) 185 | np.testing.assert_array_equal(chain_domain.current_state(), 186 | expected_state) 187 | 188 | def test_failed_left(self): 189 | """Test failing left action.""" 190 | 191 | num_states = 10 192 | starting_state = np.array([1]) 193 | 194 | chain_domain = ChainDomain(num_states, 195 | ChainDomain.RewardLocation.Ends, 196 | 1) 197 | chain_domain.reset(starting_state) 198 | 199 | np.testing.assert_array_equal(chain_domain.current_state(), 200 | starting_state) 
201 | 
202 |         expected_state = np.array([2])
203 | 
204 |         sample = chain_domain.apply_action(0)
205 |         np.testing.assert_array_equal(sample.state, starting_state)
206 |         self.assertEquals(sample.action, 0)
207 |         self.assertEquals(sample.reward, 0)
208 |         np.testing.assert_array_equal(sample.next_state, expected_state)
209 |         self.assertFalse(sample.absorb)
210 |         np.testing.assert_array_equal(chain_domain.current_state(),
211 |                                       expected_state)
212 | 
213 |     def test_failed_right(self):
214 |         """Test failing right action."""
215 | 
216 |         num_states = 10
217 |         starting_state = np.array([2])
218 | 
219 |         chain_domain = ChainDomain(num_states,
220 |                                    ChainDomain.RewardLocation.Ends,
221 |                                    1)
222 |         chain_domain.reset(starting_state)
223 | 
224 |         np.testing.assert_array_equal(chain_domain.current_state(),
225 |                                       starting_state)
226 | 
227 |         expected_state = np.array([1])
228 | 
229 |         sample = chain_domain.apply_action(1)
230 |         np.testing.assert_array_equal(sample.state, starting_state)
231 |         self.assertEquals(sample.action, 1)
232 |         self.assertEquals(sample.reward, 0)
233 |         np.testing.assert_array_equal(sample.next_state, expected_state)
234 |         self.assertFalse(sample.absorb)
235 |         np.testing.assert_array_equal(chain_domain.current_state(),
236 |                                       expected_state)
237 | 
238 |     def test_rewards_at_ends(self):
239 |         """Test rewards at the ends of the chain."""
240 | 
241 |         num_states = 10
242 |         starting_state = np.array([0])
243 | 
244 |         chain_domain = ChainDomain(num_states,
245 |                                    ChainDomain.RewardLocation.Ends,
246 |                                    0)
247 |         chain_domain.reset(starting_state)
248 | 
249 |         np.testing.assert_array_equal(chain_domain.current_state(),
250 |                                       starting_state)
251 | 
252 |         expected_state = starting_state.copy()
253 | 
254 |         sample = chain_domain.apply_action(0)
255 |         np.testing.assert_array_equal(sample.state, starting_state)
256 |         self.assertEquals(sample.action, 0)
257 |         self.assertEquals(sample.reward, 1)
258 |         np.testing.assert_array_equal(sample.next_state, expected_state)
259 |         self.assertFalse(sample.absorb)
260 |         np.testing.assert_array_equal(chain_domain.current_state(),
261 |                                       expected_state)
262 | 
263 |         starting_state = np.array([num_states-1])
264 | 
265 |         chain_domain.reset(starting_state)
266 | 
267 |         np.testing.assert_array_equal(chain_domain.current_state(),
268 |                                       starting_state)
269 | 
270 |         expected_state = starting_state.copy()
271 | 
272 |         sample = chain_domain.apply_action(1)
273 |         np.testing.assert_array_equal(sample.state, starting_state)
274 |         self.assertEquals(sample.action, 1)
275 |         self.assertEquals(sample.reward, 1)
276 |         np.testing.assert_array_equal(sample.next_state, expected_state)
277 |         self.assertFalse(sample.absorb)
278 |         np.testing.assert_array_equal(chain_domain.current_state(),
279 |                                       expected_state)
280 | 
281 |     def test_rewards_in_middle(self):
282 |         """Test chain with rewards in the middle."""
283 | 
284 |         num_states = 10
285 |         starting_state = np.array([num_states/2-1])
286 | 
287 |         chain_domain = ChainDomain(num_states,
288 |                                    ChainDomain.RewardLocation.Middle,
289 |                                    0)
290 |         chain_domain.reset(starting_state)
291 | 
292 |         np.testing.assert_array_equal(chain_domain.current_state(),
293 |                                       starting_state)
294 | 
295 |         expected_state = np.array([num_states/2])
296 | 
297 |         sample = chain_domain.apply_action(1)
298 |         np.testing.assert_array_equal(sample.state, starting_state)
299 |         self.assertEquals(sample.action, 1)
300 |         self.assertEquals(sample.reward, 1)
301 |         np.testing.assert_array_equal(sample.next_state, expected_state)
302 |         self.assertFalse(sample.absorb)
303 |         np.testing.assert_array_equal(chain_domain.current_state(),
304 |                                       expected_state)
305 | 
306 |         starting_state = expected_state.copy()
307 | 
308 |         expected_state = np.array([num_states/2+1])
309 | 
310 |         sample = chain_domain.apply_action(1)
311 |         np.testing.assert_array_equal(sample.state, starting_state)
312 |         self.assertEquals(sample.action, 1)
313 |         self.assertEquals(sample.reward, 1)
314 |         np.testing.assert_array_equal(sample.next_state, expected_state)
315 |         self.assertFalse(sample.absorb)
316 |         np.testing.assert_array_equal(chain_domain.current_state(),
317 |                                       expected_state)
318 | 
319 |         starting_state = expected_state.copy()
320 | 
321 |         expected_state = np.array([num_states/2+2])
322 | 
323 |         sample = chain_domain.apply_action(1)
324 |         np.testing.assert_array_equal(sample.state, starting_state)
325 |         self.assertEquals(sample.action, 1)
326 |         self.assertEquals(sample.reward, 0)
327 |         np.testing.assert_array_equal(sample.next_state, expected_state)
328 |         self.assertFalse(sample.absorb)
329 |         np.testing.assert_array_equal(chain_domain.current_state(),
330 |                                       expected_state)
331 | 
332 |     def test_rewards_in_half_middles(self):
333 |         """Test chain with rewards in the middle of each half of the chain."""
334 | 
335 |         num_states = 10
336 |         starting_state = np.array([num_states/4-1])
337 | 
338 |         chain_domain = ChainDomain(num_states,
339 |                                    ChainDomain.RewardLocation.HalfMiddles,
340 |                                    0)
341 |         chain_domain.reset(starting_state)
342 | 
343 |         np.testing.assert_array_equal(chain_domain.current_state(),
344 |                                       starting_state)
345 | 
346 |         expected_state = np.array([num_states/4])
347 | 
348 |         sample = chain_domain.apply_action(1)
349 |         np.testing.assert_array_equal(sample.state, starting_state)
350 |         self.assertEquals(sample.action, 1)
351 |         self.assertEquals(sample.reward, 1)
352 |         np.testing.assert_array_equal(sample.next_state, expected_state)
353 |         self.assertFalse(sample.absorb)
354 |         np.testing.assert_array_equal(chain_domain.current_state(),
355 |                                       expected_state)
356 | 
357 |         starting_state = np.array([3*num_states/4-1])
358 |         chain_domain.reset(starting_state)
359 | 
360 |         np.testing.assert_array_equal(chain_domain.current_state(),
361 |                                       starting_state)
362 | 
363 |         expected_state = np.array([3*num_states/4])
364 | 
365 |         sample = chain_domain.apply_action(1)
366 |         np.testing.assert_array_equal(sample.state, starting_state)
367 |         self.assertEquals(sample.action, 1)
368 |         self.assertEquals(sample.reward, 1)
369 |         np.testing.assert_array_equal(sample.next_state, expected_state)
370 |         self.assertFalse(sample.absorb)
371 |         np.testing.assert_array_equal(chain_domain.current_state(),
372 |                                       expected_state)
373 | 
374 |         chain_domain.reset(starting_state)
375 | 
376 |         expected_state = np.array([3*num_states/4-2])
377 | 
378 |         sample = chain_domain.apply_action(0)
379 |         np.testing.assert_array_equal(sample.state, starting_state)
380 |         self.assertEquals(sample.action, 0)
381 |         self.assertEquals(sample.reward, 0)
382 |         np.testing.assert_array_equal(sample.next_state, expected_state)
383 |         self.assertFalse(sample.absorb)
384 |         np.testing.assert_array_equal(chain_domain.current_state(),
385 |                                       expected_state)
386 | 
387 | 
388 |     def test_out_of_bounds_action_application(self):
389 |         """Test that an error is raised when the action is out of range."""
390 | 
391 |         with self.assertRaises(ValueError):
392 |             self.domain.apply_action(-1)
393 | 
394 |         with self.assertRaises(ValueError):
395 |             self.domain.apply_action(self.domain.num_actions())
--------------------------------------------------------------------------------
/lspi_testsuite/test_learn.py:
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Contains test for the lspi learn method.""" 3 | from unittest import TestCase 4 | 5 | import lspi 6 | from lspi.solvers import Solver 7 | from lspi.policy import Policy 8 | from lspi.basis_functions import FakeBasis 9 | import numpy as np 10 | 11 | class SolverStub(Solver): 12 | def __init__(self, max_iterations): 13 | self.max_iterations = max_iterations 14 | self.num_calls = 0 15 | 16 | def count_calls(self): 17 | self.num_calls += 1 18 | if self.num_calls > self.max_iterations: 19 | raise RuntimeError(("%s was called more than the specified " + 20 | "max_iterations: %d") % 21 | (self.__class__.__name__, self.max_iterations)) 22 | 23 | class MaxIterationsSolverStub(SolverStub): 24 | def __init__(self, max_iterations=10): 25 | super(MaxIterationsSolverStub, self).__init__(max_iterations) 26 | 27 | def solve(self, data, policy): 28 | super(MaxIterationsSolverStub, self).count_calls() 29 | return policy.weights + 100 30 | 31 | class EpsilonSolverStub(SolverStub): 32 | def __init__(self, epsilon, max_iterations=10): 33 | super(EpsilonSolverStub, self).__init__(max_iterations) 34 | self.epsilon = epsilon 35 | 36 | def solve(self, data, policy): 37 | super(EpsilonSolverStub, self).count_calls() 38 | return policy.weights.copy() 39 | 40 | class WeightSolverStub(SolverStub): 41 | 42 | def __init__(self, weights, max_iterations=10): 43 | super(WeightSolverStub, self).__init__(max_iterations) 44 | self.weights = weights 45 | 46 | def solve(self, data, policy): 47 | super(WeightSolverStub, self).count_calls() 48 | return self.weights 49 | 50 | class SolverParamStub(SolverStub): 51 | def __init__(self, data, policy, max_iterations=10): 52 | super(SolverParamStub, self).__init__(max_iterations) 53 | self.data = data 54 | self.policy = policy 55 | 56 | def solve(self, data, policy): 57 | super(SolverParamStub, self).count_calls() 58 | 59 | assert id(self.data) == id(data) 60 | np.testing.assert_array_almost_equal_nulp(self.policy.weights, 61 | policy.weights) 62 | assert policy.discount == self.policy.discount 63 | assert policy.explore == self.policy.explore 64 | assert policy.basis == self.policy.basis 65 | assert policy.tie_breaking_strategy == \ 66 | self.policy.tie_breaking_strategy 67 | 68 | return self.policy.weights 69 | 70 | 71 | class TestLearnFunction(TestCase): 72 | def test_max_iterations_stopping_condition(self): 73 | """Test if learning stops when max_iterations is reached.""" 74 | 75 | with self.assertRaises(ValueError): 76 | lspi.learn(None, None, None, max_iterations=0) 77 | 78 | max_iterations_solver = MaxIterationsSolverStub() 79 | 80 | lspi.learn(None, 81 | Policy(FakeBasis(1)), 82 | max_iterations_solver, 83 | epsilon=10**-200, 84 | max_iterations=10) 85 | 86 | self.assertEqual(max_iterations_solver.num_calls, 10) 87 | 88 | def test_epsilon_stopping_condition(self): 89 | """Test if learning stops when distance is less than epsilon.""" 90 | 91 | with self.assertRaises(ValueError): 92 | lspi.learn(None, None, None, epsilon=0) 93 | 94 | epsilon_solver = EpsilonSolverStub(10**-21) 95 | 96 | lspi.learn(None, 97 | Policy(FakeBasis(1)), 98 | epsilon_solver, 99 | epsilon=10**-20, 100 | max_iterations=1000) 101 | 102 | self.assertEqual(epsilon_solver.num_calls, 1) 103 | 104 | def test_returns_policy_with_new_weights(self): 105 | """Test if the weights in the new policy differ and are not the same underlying numpy vector.""" 106 | 107 | initial_policy = Policy(FakeBasis(1)) 
108 | 
109 |         weight_solver = WeightSolverStub(initial_policy.weights.copy())
110 | 
111 |         new_policy = lspi.learn(None,
112 |                                 initial_policy,
113 |                                 weight_solver,
114 |                                 max_iterations=1)
115 | 
116 |         self.assertEqual(weight_solver.num_calls, 1)
117 |         self.assertFalse(np.may_share_memory(initial_policy.weights,
118 |                                              new_policy.weights))
119 |         self.assertNotEquals(id(initial_policy), id(new_policy))
120 |         np.testing.assert_array_almost_equal(new_policy.weights,
121 |                                              weight_solver.weights)
122 | 
123 |     def test_solver_uses_policy_and_data(self):
124 |         """Test that the solver is passed the data and policy."""
125 | 
126 |         data = [10]
127 |         initial_policy = Policy(FakeBasis(1))
128 | 
129 |         solver_stub = SolverParamStub(data, initial_policy)
130 | 
131 |         lspi.learn(solver_stub.data,
132 |                    solver_stub.policy,
133 |                    solver_stub,
134 |                    max_iterations=1)
--------------------------------------------------------------------------------
/lspi_testsuite/test_learning_chain_domain.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | """Contains integration test of different learning methods on chain domain."""
 3 | from unittest import TestCase
 4 | 
 5 | import lspi
 6 | 
 7 | import numpy as np
 8 | 
 9 | class TestChainDomainLearning(TestCase):
10 |     def setUp(self):
11 |         self.domain = lspi.domains.ChainDomain()
12 | 
13 |         sampling_policy = lspi.Policy(lspi.basis_functions.FakeBasis(2), .9, 1)
14 | 
15 |         self.samples = []
16 |         for i in range(1000):
17 |             action = sampling_policy.select_action(self.domain.current_state())
18 |             self.samples.append(self.domain.apply_action(action))
19 | 
20 |         self.random_policy_cum_rewards = np.sum([sample.reward
21 |                                                  for sample in self.samples])
22 | 
23 |         self.solver = lspi.solvers.LSTDQSolver()
24 | 
25 |     def test_chain_polynomial_basis(self):
26 | 
27 |         initial_policy = lspi.Policy(
28 |             lspi.basis_functions.OneDimensionalPolynomialBasis(3, 2),
29 |             .9,
30 |             0)
31 | 
32 |         learned_policy = lspi.learn(self.samples, initial_policy, self.solver)
33 | 
34 |         self.domain.reset()
35 |         cumulative_reward = 0
36 |         for i in range(1000):
37 |             action = learned_policy.select_action(self.domain.current_state())
38 |             sample = self.domain.apply_action(action)
39 |             cumulative_reward += sample.reward
40 | 
41 |         self.assertGreater(cumulative_reward, self.random_policy_cum_rewards)
42 | 
43 |     def test_chain_rbf_basis(self):
44 | 
45 |         initial_policy = lspi.Policy(
46 |             lspi.basis_functions.RadialBasisFunction(
47 |                 np.array([[0], [2], [4], [6], [8]]), .5, 2),
48 |             .9,
49 |             0)
50 | 
51 |         learned_policy = lspi.learn(self.samples, initial_policy, self.solver)
52 | 
53 |         self.domain.reset()
54 |         cumulative_reward = 0
55 |         for i in range(1000):
56 |             action = learned_policy.select_action(self.domain.current_state())
57 |             sample = self.domain.apply_action(action)
58 |             cumulative_reward += sample.reward
59 | 
60 |         self.assertGreater(cumulative_reward, self.random_policy_cum_rewards)
--------------------------------------------------------------------------------
/lspi_testsuite/test_policy.py:
--------------------------------------------------------------------------------
 1 | # -*- coding: utf-8 -*-
 2 | from unittest import TestCase
 3 | 
 4 | from lspi.policy import Policy
 5 | from lspi.basis_functions import FakeBasis, OneDimensionalPolynomialBasis
 6 | import numpy as np
 7 | from copy import copy
 8 | 
 9 | class TestPolicy(TestCase):
10 | 
11 |     def create_policy(self, *args, **kwargs):
12 |         return Policy(FakeBasis(5), *args, **kwargs)
13 | 
14 |     @staticmethod
15 |     def list_has_duplicates(values, num_places=4):
 16 |         # check whether the list contains duplicate values.
 17 |         # round the values so that small floating point
 18 |         # inconsistencies don't prevent duplicates from being detected.
 19 |         # Then make a set of the list. If there are no duplicates then the
 20 |         # cardinality of the set will match the length of the list
 21 |         rounded_list = map(lambda x: round(x, num_places), values)
 22 |         return len(set(rounded_list)) < len(values)
 23 | 
 24 |     def setUp(self):
 25 |         self.poly_policy = Policy(OneDimensionalPolynomialBasis(1, 2),
 26 |                                   weights=np.array([1., 1, 2, 2]))
 27 |         self.state = np.array([-3.])
 28 |         self.tie_weights = np.ones((4,))
 29 | 
 30 |     def test_default_constructor(self):
 31 |         policy = self.create_policy()
 32 | 
 33 |         self.assertTrue(isinstance(policy.basis, FakeBasis))
 34 |         self.assertAlmostEqual(policy.discount, 1.0)
 35 |         self.assertAlmostEqual(policy.explore, 0.0)
 36 |         self.assertEqual(policy.weights.shape, (1,))
 37 |         self.assertEqual(policy.tie_breaking_strategy,
 38 |                          Policy.TieBreakingStrategy.RandomWins)
 39 | 
 40 |     def test_full_constructor(self):
 41 |         policy = self.create_policy(.5, .1, np.array([1.]),
 42 |                                     Policy.TieBreakingStrategy.FirstWins)
 43 | 
 44 |         self.assertTrue(isinstance(policy.basis, FakeBasis))
 45 |         self.assertAlmostEqual(policy.discount, .5)
 46 |         self.assertAlmostEqual(policy.explore, 0.1)
 47 |         np.testing.assert_array_almost_equal(policy.weights, np.array([1.]))
 48 |         self.assertEqual(policy.tie_breaking_strategy,
 49 |                          Policy.TieBreakingStrategy.FirstWins)
 50 | 
 51 |     def test_discount_out_of_bounds(self):
 52 |         with self.assertRaises(ValueError):
 53 |             self.create_policy(discount=-1.0)
 54 | 
 55 |         with self.assertRaises(ValueError):
 56 |             self.create_policy(discount=1.1)
 57 | 
 58 |     def test_explore_out_of_bounds(self):
 59 |         with self.assertRaises(ValueError):
 60 |             self.create_policy(explore=-.01)
 61 | 
 62 |         with self.assertRaises(ValueError):
 63 |             self.create_policy(explore=1.1)
 64 | 
 65 |     def test_weight_basis_dimensions_mismatch(self):
 66 |         with self.assertRaises(ValueError):
 67 |             self.create_policy(weights=np.arange(2))
 68 | 
 69 |     def test_copy(self):
 70 |         orig_policy = self.create_policy()
 71 |         policy_copy = copy(orig_policy)
 72 | 
 73 |         self.assertNotEqual(id(orig_policy), id(policy_copy))
 74 |         self.assertEqual(orig_policy.num_actions,
 75 |                          policy_copy.num_actions)
 76 |         self.assertEqual(orig_policy.discount, policy_copy.discount)
 77 |         self.assertEqual(orig_policy.explore, policy_copy.explore)
 78 |         np.testing.assert_array_almost_equal(orig_policy.weights,
 79 |                                              policy_copy.weights)
 80 | 
 81 |         self.assertNotEqual(id(orig_policy.weights), id(policy_copy.weights))
 82 | 
 83 |         # verify that changing a weight in the original doesn't affect the copy
 84 |         orig_policy.weights[0] *= -1
 85 | 
 86 |         # numpy doesn't have an assert-if-not-equal method
 87 |         # so to do the inverse I'm asserting the two arrays are equal
 88 |         # and expecting the assertion to fail
 89 |         with self.assertRaises(AssertionError):
 90 |             np.testing.assert_array_almost_equal(orig_policy.weights,
 91 |                                                  policy_copy.weights)
 92 | 
 93 |     def test_calc_q_value_unit_weights(self):
 94 |         q_value = self.poly_policy.calc_q_value(self.state, 0)
 95 |         self.assertAlmostEqual(q_value, -2.)
 96 | 
 97 |     def test_calc_q_value_non_unit_weights(self):
 98 |         q_value = self.poly_policy.calc_q_value(self.state, 1)
 99 |         self.assertAlmostEqual(q_value, -4.)
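        # with basis OneDimensionalPolynomialBasis(1, 2) and state [-3.],
        # action 1 activates the weight block [2., 2.] against the feature
        # block [1., -3.], so Q = 2*1 + 2*(-3) = -4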
100 | 
101 |     def test_calc_q_value_negative_action(self):
102 |         with self.assertRaises(IndexError):
103 |             self.poly_policy.calc_q_value(self.state, -1)
104 | 
105 |     def test_calc_q_value_out_of_bounds_action(self):
106 |         with self.assertRaises(IndexError):
107 |             self.poly_policy.calc_q_value(self.state, 2)
108 | 
109 |     def test_calc_q_value_mismatched_state_dimensions(self):
110 |         with self.assertRaises(ValueError):
111 |             self.poly_policy.calc_q_value(np.ones((2,)), 0)
112 | 
113 |     def test_best_action_no_ties(self):
114 | 
115 |         q_values = [self.poly_policy.calc_q_value(self.state, action)
116 |                     for action in range(self.poly_policy.num_actions)]
117 | 
118 |         self.assertFalse(TestPolicy.list_has_duplicates(q_values))
119 | 
120 |         best_action = self.poly_policy.best_action(self.state)
121 |         self.assertEqual(best_action, 0)
122 | 
123 |     def test_best_action_with_ties_first_wins(self):
124 |         self.poly_policy.weights = self.tie_weights
125 |         self.poly_policy.tie_breaking_strategy = \
126 |             Policy.TieBreakingStrategy.FirstWins
127 | 
128 |         q_values = [self.poly_policy.calc_q_value(self.state, action)
129 |                     for action in range(self.poly_policy.num_actions)]
130 | 
131 |         self.assertTrue(TestPolicy.list_has_duplicates(q_values))
132 | 
133 |         best_action = self.poly_policy.best_action(self.state)
134 |         self.assertEqual(best_action, 0)
135 | 
136 |     def test_best_action_with_ties_last_wins(self):
137 |         self.poly_policy.weights = self.tie_weights
138 |         self.poly_policy.tie_breaking_strategy = \
139 |             Policy.TieBreakingStrategy.LastWins
140 | 
141 |         q_values = [self.poly_policy.calc_q_value(self.state, action)
142 |                     for action in range(self.poly_policy.num_actions)]
143 | 
144 |         self.assertTrue(TestPolicy.list_has_duplicates(q_values))
145 | 
146 |         best_action = self.poly_policy.best_action(self.state)
147 |         self.assertEqual(best_action, 1)
148 | 
149 |     def test_best_action_with_ties_random_wins(self):
150 |         self.poly_policy.weights = self.tie_weights
151 |         self.poly_policy.tie_breaking_strategy = \
152 |             Policy.TieBreakingStrategy.RandomWins
153 | 
154 |         q_values = [self.poly_policy.calc_q_value(self.state, action)
155 |                     for action in range(self.poly_policy.num_actions)]
156 | 
157 |         self.assertTrue(TestPolicy.list_has_duplicates(q_values))
158 | 
159 |         # select the best action num_times times
160 |         num_times = 10
161 |         best_actions = [self.poly_policy.best_action(self.state)
162 |                         for i in range(num_times)]
163 | 
164 |         # This test will fail if every selection was action 0 or every
165 |         # selection was action 1: all zeros make the sum equal to 0,
166 |         # and all ones make the sum equal to
167 |         # num_times.
168 |         self.assertLess(int(sum(best_actions)), num_times)
169 |         self.assertNotEqual(int(sum(best_actions)), 0)
170 | 
171 |     def test_best_action_mismatched_state_dimensions(self):
172 |         with self.assertRaises(ValueError):
173 |             self.poly_policy.best_action(np.ones((2,)))
174 | 
175 |     def test_select_action_random(self):
176 |         # first verify there are no ties
177 |         # this way we know the tie breaking strategy isn't introducing
178 |         # the randomness
179 |         q_values = [self.poly_policy.calc_q_value(self.state, action)
180 |                     for action in range(self.poly_policy.num_actions)]
181 | 
182 |         self.assertFalse(TestPolicy.list_has_duplicates(q_values))
183 | 
184 |         self.poly_policy.explore = 1.0
185 |         self.poly_policy.tie_breaking_strategy = \
186 |             Policy.TieBreakingStrategy.FirstWins
187 | 
188 |         # the q-values were verified above to contain no ties
189 |         num_times = 10
190 |         best_actions = [self.poly_policy.select_action(self.state)
191 |                         for i in range(num_times)]
192 | 
193 |         self.assertNotEqual(sum(best_actions), 0)
194 |         self.assertNotEqual(sum(best_actions), num_times)
195 | 
196 |     def test_select_action_deterministic(self):
197 |         # first verify there are no ties
198 |         # this way we know the tie breaking strategy isn't introducing
199 |         # the randomness
200 |         q_values = [self.poly_policy.calc_q_value(self.state, action)
201 |                     for action in range(self.poly_policy.num_actions)]
202 | 
203 |         self.assertFalse(TestPolicy.list_has_duplicates(q_values))
204 | 
205 |         self.poly_policy.explore = 0.0
206 |         self.poly_policy.tie_breaking_strategy = \
207 |             Policy.TieBreakingStrategy.FirstWins
208 | 
209 |         # the q-values were verified above to contain no ties
210 |         num_times = 10
211 |         best_actions = [self.poly_policy.select_action(self.state)
212 |                         for i in range(num_times)]
213 |         self.assertEqual(sum(best_actions), 0)
214 | 
215 |     def test_select_action_mismatched_state_dimensions(self):
216 |         with self.assertRaises(ValueError):
217 |             self.poly_policy.select_action(np.ones((2,)))
218 | 
219 |     def test_num_actions_getter(self):
220 |         self.assertEqual(self.poly_policy.num_actions,
221 |                          self.poly_policy.basis.num_actions)
222 | 
223 |         self.poly_policy.basis.num_actions = 10
224 | 
225 |         self.assertEqual(self.poly_policy.num_actions,
226 |                          self.poly_policy.basis.num_actions)
227 | 
228 |     def test_num_actions_setter(self):
229 |         self.assertEqual(self.poly_policy.num_actions,
230 |                          self.poly_policy.basis.num_actions)
231 | 
232 |         self.poly_policy.num_actions = 10
233 | 
234 |         self.assertEqual(self.poly_policy.num_actions,
235 |                          self.poly_policy.basis.num_actions)
--------------------------------------------------------------------------------
/lspi_testsuite/test_sample.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Tests for the lspi.sample module."""
3 | from unittest import TestCase
4 | 
5 | from lspi import Sample
6 | 
7 | 
8 | class TestSample(TestCase):
9 | 
10 |     def setUp(self):  # flake8: noqa
11 |         """Set the constructor parameters to test with."""
12 |         self.state = [0, 1]
13 |         self.action = 2
14 |         self.reward = -1.5
15 |         self.next_state = [1, 0]
16 |         self.absorb = True
17 | 
18 |     def test_full_constructor(self):
19 |         """Construct a Sample."""
20 | 
21 |         sample = Sample(self.state,
22 |                         self.action,
23 |                         self.reward,
24 |                         self.next_state,
25 |                         self.absorb)
26 | 
27 |         self.assertEqual(sample.state, self.state)
28 |         self.assertEqual(sample.action, self.action)
29 |         self.assertAlmostEqual(sample.reward, self.reward, 3)
30 |         self.assertEqual(sample.next_state, self.next_state)
31 |         self.assertEqual(sample.absorb, self.absorb)
32 | 
33 |     def test_default_constructor(self):
34 |         """Construct a Sample with default arguments."""
35 | 
36 |         sample = Sample(self.state,
37 |                         self.action,
38 |                         self.reward,
39 |                         self.next_state)
40 | 
41 |         self.assertEqual(sample.state, self.state)
42 |         self.assertEqual(sample.action, self.action)
43 |         self.assertAlmostEqual(sample.reward, self.reward, 3)
44 |         self.assertEqual(sample.next_state, self.next_state)
45 |         self.assertEqual(sample.absorb, False)
--------------------------------------------------------------------------------
/lspi_testsuite/test_solvers.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Contains tests for the various solvers."""
3 | from unittest import TestCase
4 | 
5 | from lspi.basis_functions import ExactBasis
6 | from lspi.policy import Policy
7 | from lspi.sample import Sample
8 | from lspi.solvers import LSTDQSolver
9 | 
10 | import numpy as np
11 | 
12 | class TestLSTDQSolver(TestCase):
13 |     def setUp(self):
14 |         self.data = [Sample(np.array([0]), 0, 1, np.array([0])),
15 |                      Sample(np.array([1]), 0, -1, np.array([1]))]
16 | 
17 |         self.basis = ExactBasis([2], 1)
18 |         self.policy = Policy(self.basis,
19 |                              .9,
20 |                              0,
21 |                              np.zeros((2, )),
22 |                              Policy.TieBreakingStrategy.FirstWins)
23 | 
24 |     def test_precondition_value_set(self):
25 |         """Test that the precondition value is saved on the solver."""
26 |         precondition_value = .3
27 |         solver = LSTDQSolver(precondition_value)
28 |         self.assertEqual(solver.precondition_value, precondition_value)
29 | 
30 |     def test_solve_method_full_rank_matrix(self):
31 |         """Test that the solver works with a full-rank matrix."""
32 | 
33 |         solver = LSTDQSolver(precondition_value=0)
34 | 
35 |         weights = solver.solve(self.data, self.policy)
36 | 
37 |         expected_weights = np.array([10, -10])
38 | 
39 |         np.testing.assert_array_almost_equal(weights, expected_weights)
40 | 
41 |     def test_solve_method_singular_matrix(self):
42 |         """Test with a singular matrix and no preconditioning."""
43 | 
44 |         solver = LSTDQSolver(precondition_value=0)
45 | 
46 |         weights = solver.solve(self.data[:-1], self.policy)
47 | 
48 |         expected_weights = np.array([10, 0])
49 | 
50 |         np.testing.assert_array_almost_equal(weights, expected_weights)
51 | 
52 |     def test_solve_method_singular_matrix_with_preconditioning(self):
53 |         """Test with a singular matrix and preconditioning."""
54 | 
55 |         solver = LSTDQSolver(precondition_value=.1)
56 | 
57 |         weights = solver.solve(self.data[:-1], self.policy)
58 | 
59 |         expected_weights = np.array([5, 0])
60 | 
61 |         np.testing.assert_array_almost_equal(weights, expected_weights)
62 | 
63 |     def test_solve_method_with_absorbing_sample(self):
64 |         """Test with an absorbing sample."""
65 |         solver = LSTDQSolver(precondition_value=0)
66 | 
67 |         self.data[0].absorb = True
68 |         weights = solver.solve(self.data, self.policy)
69 | 
70 |         expected_weights = np.array([1, -10])
71 | 
72 |         np.testing.assert_array_almost_equal(weights, expected_weights)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | coverage
2 | numpy
3 | scipy
4 | flake8==2.4.0
5 | flake8-import-order
6 | flake8-docstrings
7 | #flake8-print
8 | flake8-quotes
9 | sphinx
10 | sphinxcontrib-napoleon
11 | sphinx_rtd_theme
12 | mccabe==0.2.1
13 | wheel
14 | twine
15 | 
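
For reference, the expected weights in TestLSTDQSolver above can be reproduced
by hand. The sketch below assumes the solver implements the standard LSTDQ
update from Lagoudakis and Parr, A += phi(s,a) (phi(s,a) - gamma * phi(s',a'))^T
and b += r * phi(s,a), with the next-state features zeroed for absorbing
samples and precondition_value added along the diagonal of A. This is
consistent with every expected value in the tests, but it is an inference,
not an excerpt from the solver source.

    import numpy as np

    gamma = 0.9
    phi = np.eye(2)  # ExactBasis([2], 1): phi(s, 0) is the s-th standard basis vector

    A = np.zeros((2, 2))
    b = np.zeros(2)
    for s, r, s_next in [(0, 1, 0), (1, -1, 1)]:  # the two samples in setUp
        A += np.outer(phi[s], phi[s] - gamma * phi[s_next])
        b += r * phi[s]

    print(np.linalg.solve(A, b))  # [ 10. -10.], matching the full-rank test

Dropping the second sample leaves A singular (hence the pseudo-inverse-style
expected [10, 0]), and adding precondition_value=.1 to the diagonal gives
A[0, 0] = 0.2 with b[0] = 1, hence the expected [5, 0].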
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [nosetests]
2 | with-doctest=1
3 | with-coverage=1
4 | cover-erase=1
5 | cover-html=1
6 | cover-html-dir=htmlcov
7 | cover-package=lspi
8 | cover-branches=1
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """A setuptools based setup module."""
2 | 
3 | from setuptools import setup, find_packages
4 | from os import path
5 | from io import open
6 | 
7 | here = path.abspath(path.dirname(__file__))
8 | 
9 | with open(path.join(here, 'DESCRIPTION.rst'), encoding='utf-8') as f:
10 |     long_description = f.read()
11 | 
12 | setup(
13 |     name='lspi-python',
14 |     version='1.0.1',
15 |     description='LSPI algorithm in Python',
16 |     long_description=long_description,
17 |     url='https://github.com/rhololkeolke/lspi-python',
18 |     author='Devin Schwab',
19 |     author_email='digidevin@gmail.com',
20 |     license='BSD-3-Clause',
21 |     classifiers=[
22 |         'Development Status :: 3 - Alpha',
23 |         'Intended Audience :: Science/Research',
24 |         'Topic :: Scientific/Engineering :: Artificial Intelligence',
25 |         'License :: OSI Approved :: BSD License',
26 |         'Programming Language :: Python :: 2.7'
27 |     ],
28 |     keywords='machinelearning ai',
29 |     packages=find_packages(exclude=['docs', '*testsuite', 'test*']),
30 |     install_requires=['numpy', 'scipy'],
31 |     extras_require={
32 |         'test': ['nose', 'coverage']  # the PyPI package providing nosetests is 'nose'
33 |     }
34 | )
35 | 
--------------------------------------------------------------------------------
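
As a closing reference, here is a minimal end-to-end sketch of the library's
intended use, assembled from the API exercised by
lspi_testsuite/test_learning_chain_domain.py above. It is a sketch based on
those tests, not an excerpt from the package's own documentation.

    import lspi
    import numpy as np

    domain = lspi.domains.ChainDomain()

    # Collect samples with a purely random policy (explore=1).
    sampling_policy = lspi.Policy(lspi.basis_functions.FakeBasis(2), .9, 1)
    samples = []
    for _ in range(1000):
        action = sampling_policy.select_action(domain.current_state())
        samples.append(domain.apply_action(action))

    # Run LSPI with a radial basis function representation.
    initial_policy = lspi.Policy(
        lspi.basis_functions.RadialBasisFunction(
            np.array([[0], [2], [4], [6], [8]]), .5, 2),
        .9,
        0)
    learned_policy = lspi.learn(samples, initial_policy,
                                lspi.solvers.LSTDQSolver())

    # Evaluate the learned policy on a fresh rollout.
    domain.reset()
    cumulative_reward = 0
    for _ in range(1000):
        sample = domain.apply_action(
            learned_policy.select_action(domain.current_state()))
        cumulative_reward += sample.reward
    print(cumulative_reward)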