├── .coveragerc ├── .gitignore ├── .pylintrc ├── CMakeLists.txt ├── LICENSE ├── README.md ├── docs ├── internal_deps.png ├── internal_deps.puml └── sgb.md ├── examples ├── duplo │ ├── README.md │ ├── __init__.py │ ├── bt_collection.py │ ├── environment.py │ ├── execution_nodes.py │ ├── fitness_function.py │ ├── paths.py │ ├── run_execute_bt.py │ ├── run_learn_bt.py │ └── world.py ├── simple_execution_demo │ ├── __init__.py │ ├── bt_collection.py │ ├── dummy_world.py │ ├── execution_nodes.py │ ├── paths.py │ └── run_execute_bt.py └── tiago_pnp │ ├── README.md │ ├── __init__.py │ ├── bt_collection.py │ ├── environment.py │ ├── execution_nodes.py │ ├── fitness_function.py │ ├── paths.py │ ├── run_execute_bt.py │ ├── run_learn_bt.py │ └── world.py ├── package.xml ├── requirements.txt ├── setup.py └── src └── behavior_tree_learning ├── __init__.py ├── core ├── __init__.py ├── gp │ ├── __init__.py │ ├── algorithm.py │ ├── environment.py │ ├── hash_table.py │ ├── operators.py │ ├── parameters.py │ ├── selection.py │ └── steps.py ├── gp_sbt │ ├── __init__.py │ ├── environment.py │ ├── fitness_function.py │ ├── gp_operators.py │ ├── learning.py │ └── world_factory.py ├── logger │ ├── __init__.py │ └── logplot.py ├── planner │ ├── __init__.py │ ├── node_factory.py │ └── planner.py ├── plotter │ ├── __init__.py │ └── print_functions.py └── sbt │ ├── __init__.py │ ├── behavior_tree.py │ ├── behaviors.py │ ├── executor.py │ ├── graphics.py │ ├── node_factory.py │ ├── parse_operation.py │ ├── py_tree.py │ └── world.py ├── gp.py ├── learning.py ├── sbt.py └── tests ├── fwk ├── BT_SETTINGS.yaml ├── __init__.py ├── behavior_nodes.py └── world.py ├── paths.py ├── test_gp_algorithm.py ├── test_gp_hash_table.py ├── test_gp_sbt_operators.py ├── test_gp_selection.py ├── test_logplot.py ├── test_parse_operation.py ├── test_sbt_btsr.py ├── test_sbt_node_factory.py └── test_sbt_pytree.py /.coveragerc: -------------------------------------------------------------------------------- 1 | # 
.coveragerc to control coverage.py 2 | [run] 3 | branch = True 4 | omit = 5 | *\__init__.py 6 | *\site-packages\* 7 | *eloquent* 8 | *lib* 9 | *\tests\* 10 | *\py_trees-release* 11 | [report] 12 | # Regexes for lines to exclude from consideration 13 | exclude_lines = 14 | # Have to re-enable the standard pragma 15 | pragma: no cover 16 | 17 | ignore_errors = True 18 | show_missing = True 19 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | logs -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 2.8.3) 2 | project(behavior_tree_learning) 3 | 4 | find_package(catkin REQUIRED COMPONENTS 5 | rospy 6 | ) 7 | 8 | catkin_python_setup() 9 | catkin_package( 10 | LIBRARIES ${PROJECT_NAME} 11 | CATKIN_DEPENDS rospy 12 | ) 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2021 Jonathan Styrud, Matteo Iovino, Diego Escudero (dgerod@xyz-lab.org.es) 190 | Copyright 2022 Diego Escudero (dgerod@xyz-lab.org.es) 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Behavior Trees using Genetic Programming 2 | 3 | This repository contains an implementation of a Genetic Programming (GP) algorithm that evolves Behavior Trees (BTs) to solve different tasks. 4 | 5 | This repository is based on: 6 | * https://github.com/jstyrud/planning-and-learning 7 | * https://github.com/matiov/learn-BTs-with-GP 8 | 9 | References: 10 | * __Towards Blended Reactive Planning and Acting using Behavior Trees__. 11 | Colledanchise, Michele & Almeida, Diogo & Ogren, Petter. 12 | ICRA-2019, May 2019. DOI:10.1109/ICRA.2019.8794128. [PDF](https://arxiv.org/pdf/1611.00230.pdf) 13 | * __A Survey of Behavior Trees in Robotics and AI__. 14 | Iovino, Matteo & Scukins, Edvards & Styrud, Jonathan & Ogren, Petter & Smith, Christian. 15 | May 2020. 
[PDF](https://arxiv.org/pdf/2005.05842.pdf) 16 | * __Behavior Trees in Robotics and AI: An Introduction__. 17 | Colledanchise, Michele & Ogren, Petter. 18 | June 2020. [PDF](https://arxiv.org/pdf/1709.00084.pdf) 19 | * __Genetic Programming__. 20 | Accessed on December 22, 2021. [Web](https://geneticprogramming.com) 21 | * __A Field Guide to Genetic Programming__. 22 | Poli, Riccardo & Langdon, William & McPhee, Nicholas. 23 | 2008. ISBN 978-1-4092-0073-4. 24 | * __Learning Behavior Trees with Genetic Programming in Unpredictable Environments__. 25 | Iovino, Matteo & Styrud, Jonathan & Falco, Pietro & Smith, Christian. 26 | ICRA 2021, May 2021. DOI:10.1109/ICRA48506.2021.9562088. [PDF](https://arxiv.org/pdf/2011.03252v1.pdf) 27 | * __Combining Planning and Learning of Behavior Trees for Robotic Assembly__. 28 | Styrud, Jonathan & Iovino, Matteo & Norrlöf, Mikael & Björkman, Mårten & Smith, Christian. 29 | March 2021. [PDF](https://arxiv.org/pdf/2103.09036v1.pdf) 30 | 31 | Other references: 32 | * __Combining Context Awareness and Planning to Learn Behavior Trees from Demonstration__. 33 | Gustavsson, Oscar & Iovino, Matteo & Styrud, Jonathan & Smith, Christian. 34 | September 2021. [PDF](https://arxiv.org/pdf/2109.07133.pdf) 35 | * __Integrating Reinforcement Learning into Behavior Trees by Hierarchical Composition__. 36 | Kartašev, Mart. 37 | Degree Project in Computer Science and Engineering, KTH, Stockholm (Sweden) 38 | 2019. [PDF](https://www.diva-portal.org/smash/get/diva2:1368535/FULLTEXT01.pdf) 39 | * __Learning Behavior Trees From Demonstration__. 40 | French, Kevin & Wu, Shiyu & Pan, Tianyang & Zhou, Zheming & Jenkins, Odest Chadwicke. 41 | ICRA-2019, May 2019.
[IEEE](https://ieeexplore.ieee.org/document/8794104) 42 | 43 | ### Installation 44 | 45 | After cloning the repository, run the following command to install the required dependencies: 46 | ```bash 47 | pip3 install -r requirements.txt 48 | ``` 49 | 50 | To check that the package works correctly, run all the tests. Move to the 51 | test directory of the package and execute them: 52 | ```bash 53 | cd %PACKAGE_DIRECTORY%/src/behavior_tree_learning/tests 54 | python -m unittest discover -s . -p 'test_*.py' 55 | ``` 56 | ### Examples 57 | 58 | Execute an existing behavior tree stored in "bt_collection.py": 59 | ```bash 60 | cd %PACKAGE_DIRECTORY%/examples/duplo 61 | python ./run_execute_bt.py 62 | ``` 63 | 64 | Learn a behavior tree using genetic programming: 65 | ```bash 66 | cd %PACKAGE_DIRECTORY%/examples/duplo 67 | python ./run_learn_bt.py 68 | ``` 69 | -------------------------------------------------------------------------------- /docs/internal_deps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dgerod/behavior_tree_learning/71da80c91ecd48fd5da377f83604b62112ba9629/docs/internal_deps.png -------------------------------------------------------------------------------- /docs/internal_deps.puml: -------------------------------------------------------------------------------- 1 | @startuml 2 | 3 | !theme toy 4 | [GP-SBT] -> [GP] 5 | [GP] -> [Logger] 6 | [GP-SBT] -> [SBT] 7 | [Plotter] -> [SBT] 8 | 9 | @enduml -------------------------------------------------------------------------------- /docs/sgb.md: -------------------------------------------------------------------------------- 1 | Software Guidebook 2 | ================== 3 | 4 | ## Introduction 5 | 6 | ## API 7 | 8 | ### BehaviorTreeExecutor 9 | 10 | ### BehaviorTreeLearner 11 | 12 | ## Internal Structure 13 | ### Dependencies 14 | 15 | ![Packages dependencies](internal_deps.png) 16 | 17 | ## How to extend 18 |
-------------------------------------------------------------------------------- /examples/duplo/README.md: -------------------------------------------------------------------------------- 1 | Picking a Lego Duplo 2 | ==================== 3 | 4 | Example adapted from https://github.com/jstyrud/planning-and-learning, the world is 5 | implemented using a state machine. -------------------------------------------------------------------------------- /examples/duplo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dgerod/behavior_tree_learning/71da80c91ecd48fd5da377f83604b62112ba9629/examples/duplo/__init__.py -------------------------------------------------------------------------------- /examples/duplo/bt_collection.py: -------------------------------------------------------------------------------- 1 | _bt_1 = ['s(', 2 | 'apply force 0!', 3 | 'apply force 0!', 4 | 'apply force 0!', 5 | ')'] 6 | 7 | _bt_2 = ['f(', 8 | '0 at pos (0.0, 0.05, 0.0)?', 9 | 's(', 10 | 'pick 1!', 11 | 'place on 0!', 12 | 'pick 2!', 13 | 'place on 1!', 14 | ')', 15 | ')'] 16 | 17 | 18 | def select_bt(idx: int): 19 | collection = [_bt_1, _bt_2] 20 | return collection[idx] 21 | -------------------------------------------------------------------------------- /examples/duplo/environment.py: -------------------------------------------------------------------------------- 1 | from interface import implements 2 | from behavior_tree_learning.sbt import BehaviorTreeExecutor, ExecutionParameters 3 | from behavior_tree_learning.sbt import StringBehaviorTree, BehaviorNodeFactory 4 | from behavior_tree_learning.learning import Environment 5 | from duplo.world import ApplicationWorldFactory 6 | from duplo.fitness_function import FitnessFunction 7 | 8 | 9 | class ApplicationEnvironment(implements(Environment)): 10 | 11 | def __init__(self, node_factory: BehaviorNodeFactory, world_factory: ApplicationWorldFactory, 12 | 
target_positions, 13 | static_tree=None, fitness_coefficients=None, verbose=False): 14 | 15 | self._node_factory = node_factory 16 | self._world_factory = world_factory 17 | self._verbose = verbose 18 | 19 | self._targets = target_positions 20 | self._static_tree = static_tree 21 | self._fitness_coefficients = fitness_coefficients 22 | self._random_events = False 23 | 24 | def run_and_compute(self, individual, verbose): 25 | 26 | verbose_enabled = self._verbose or verbose 27 | 28 | sbt = list(individual) 29 | if verbose_enabled: 30 | print("SBT: ", sbt) 31 | 32 | world = self._world_factory.make() 33 | 34 | tree = StringBehaviorTree(sbt, behaviors=self._node_factory, world=world, verbose=verbose) 35 | success, ticks = tree.run_bt(parameters=ExecutionParameters(successes_required=1)) 36 | 37 | fitness = FitnessFunction().compute_cost(world, tree, ticks, self._targets, 38 | self._fitness_coefficients, verbose=verbose) 39 | 40 | if verbose_enabled: 41 | print("fitness: ", fitness) 42 | 43 | return fitness 44 | 45 | def plot_individual(self, path, plot_name, individual): 46 | """ Saves a graphical representation of the individual """ 47 | 48 | sbt = list(individual) 49 | 50 | if self._static_tree is not None: 51 | tree = StringBehaviorTree(self._add_to_static_tree(sbt), behaviors=self._node_factory) 52 | else: 53 | tree = StringBehaviorTree(sbt[:], behaviors=self._node_factory) 54 | 55 | tree.save_figure(path, name=plot_name) 56 | -------------------------------------------------------------------------------- /examples/duplo/execution_nodes.py: -------------------------------------------------------------------------------- 1 | import re 2 | import py_trees as pt 3 | from behavior_tree_learning.sbt import BehaviorRegister, BehaviorNode 4 | from duplo import world as sm 5 | 6 | 7 | class HandEmpty(BehaviorNode): 8 | """ 9 | Check if hand is empty 10 | """ 11 | 12 | @staticmethod 13 | def make(text, world, verbose=False): 14 | return HandEmpty(text, world, verbose) 15 | 
16 | def __init__(self, name, world, verbose): 17 | self._world = world 18 | super(HandEmpty, self).__init__(name) 19 | 20 | def update(self): 21 | if self._world.hand_empty(): 22 | return pt.common.Status.SUCCESS 23 | return pt.common.Status.FAILURE 24 | 25 | 26 | class Picked(BehaviorNode): 27 | """ 28 | Check if brick is picked 29 | """ 30 | 31 | @staticmethod 32 | def make(text, world, verbose=False): 33 | return Picked(text, world, verbose, re.findall(r'\d+', text)) 34 | 35 | def __init__(self, name, world, verbose, brick): 36 | self._world = world 37 | self._brick = int(brick[0]) 38 | super(Picked, self).__init__(name) 39 | 40 | def update(self): 41 | if self._world.get_picked() == self._brick: 42 | return pt.common.Status.SUCCESS 43 | return pt.common.Status.FAILURE 44 | 45 | 46 | class AtPos(BehaviorNode): 47 | """ 48 | Check if brick is at position 49 | """ 50 | 51 | @staticmethod 52 | def make(text, world, verbose=False): 53 | return AtPos(text, world, verbose, re.findall(r'-?\d+\.\d+|-?\d+', text)) 54 | 55 | def __init__(self, name, world, verbose, brick_and_pos): 56 | self._world = world 57 | self._brick = int(brick_and_pos[0]) 58 | self._pos = sm.Pos(float(brick_and_pos[1]), float(brick_and_pos[2]), float(brick_and_pos[3])) 59 | self._verbose = verbose 60 | super(AtPos, self).__init__(name) 61 | 62 | def update(self): 63 | if self._world.distance(self._brick, self._pos) < self._world.sm_par.pos_margin: 64 | if self._verbose: 65 | print(self.name, ": SUCCESS") 66 | return pt.common.Status.SUCCESS 67 | if self._verbose: 68 | print(self.name, ": FAILURE") 69 | return pt.common.Status.FAILURE 70 | 71 | 72 | class On(BehaviorNode): 73 | """ 74 | Check if one brick is on other brick 75 | """ 76 | 77 | @staticmethod 78 | def make(text, world, verbose=False): 79 | return On(text, world, verbose, re.findall(r'\d+', text)) 80 | 81 | def __init__(self, name, world, verbose, bricks): 82 | self._world = world 83 | self._upper = int(bricks[0]) 84 | self._lower = 
int(bricks[1]) 85 | self._verbose = verbose 86 | super(On, self).__init__(name) 87 | 88 | def update(self): 89 | if self._world.on_top(self._world.state.bricks[self._upper], 90 | self._world.state.bricks[self._lower]): 91 | if self._verbose: 92 | print(self.name, ": SUCCESS") 93 | return pt.common.Status.SUCCESS 94 | if self._verbose: 95 | print(self.name, ": FAILURE") 96 | return pt.common.Status.FAILURE 97 | 98 | 99 | class StateMachineBehavior(BehaviorNode): 100 | """ 101 | Class template for state machine behaviors 102 | """ 103 | 104 | def __init__(self, name, world, verbose=False): 105 | self._world = world 106 | self._state = None 107 | self._verbose = verbose 108 | super(StateMachineBehavior, self).__init__(name) 109 | 110 | def update(self): 111 | if self._verbose and self._state == pt.common.Status.RUNNING: 112 | print(self.name, ":", self._state) 113 | 114 | def success(self): 115 | self._state = pt.common.Status.SUCCESS 116 | if self._verbose: 117 | print(self.name, ": SUCCESS") 118 | 119 | def failure(self): 120 | self._state = pt.common.Status.FAILURE 121 | if self._verbose: 122 | print(self.name, ": FAILURE") 123 | 124 | 125 | class Pick(StateMachineBehavior): 126 | """ 127 | Pick up a brick 128 | """ 129 | 130 | @staticmethod 131 | def make(text, world, verbose=False): 132 | return Pick(text, world, re.findall(r'\d+', text), verbose) 133 | 134 | def __init__(self, name, world, brick, verbose): 135 | self._brick = int(brick[0]) 136 | super(Pick, self).__init__(name, world, verbose) 137 | 138 | def initialise(self): 139 | if self._world.get_picked() == self._brick: 140 | self.success() 141 | elif self._world.get_picked() is not None: 142 | self.failure() 143 | else: 144 | self._state = None 145 | 146 | def update(self): 147 | super(Pick, self).update() 148 | 149 | if self._state is None: 150 | self._state = pt.common.Status.RUNNING 151 | elif self._state is pt.common.Status.RUNNING: 152 | if self._world.pick(self._brick): 153 | self.success() 154 | 
else: 155 | self.failure() 156 | self._world.random_event() 157 | return self._state 158 | 159 | 160 | class Place(StateMachineBehavior): 161 | """ 162 | Place current brick at given position 163 | """ 164 | 165 | @staticmethod 166 | def make(text, world, verbose=False): 167 | if 'place at' in text: 168 | return Place(text, world, position=re.findall(r'-?\d+\.\d+|-?\d+', text), verbose=verbose) 169 | elif 'place on' in text: 170 | return Place(text, world, brick=re.findall(r'\d+', text), verbose=verbose) 171 | else: 172 | raise ValueError('Unknown [%s] node' % text) 173 | 174 | def __init__(self, name, world, brick=None, position=None, verbose=False): 175 | # pylint: disable=too-many-arguments 176 | if brick is not None: 177 | self._brick = int(brick[0]) 178 | self._position = None 179 | elif position is not None: 180 | self._position = sm.Pos(float(position[0]), float(position[1]), float(position[2])) 181 | self._brick = None 182 | super(Place, self).__init__(name, world, verbose) 183 | 184 | def initialise(self): 185 | if self._world.get_picked() is None: 186 | self.failure() 187 | else: 188 | self._state = None 189 | 190 | def update(self): 191 | super(Place, self).update() 192 | 193 | if self._state is None: 194 | self._state = pt.common.Status.RUNNING 195 | elif self._state is pt.common.Status.RUNNING: 196 | if self._brick is not None: 197 | success = self._world.place(brick=self._brick) 198 | else: 199 | success = self._world.place(position=self._position) 200 | if success: 201 | self.success() 202 | else: 203 | self.failure() 204 | self._world.random_event() 205 | return self._state 206 | 207 | 208 | class Put(StateMachineBehavior): 209 | """ 210 | Picks brick and places it on other brick 211 | """ 212 | 213 | @staticmethod 214 | def make(text, world, verbose=False): 215 | return Put(text, world, re.findall(r'-?\d+\.\d+|-?\d+', text), verbose) 216 | 217 | def __init__(self, name, world, brick_and_pos, verbose): 218 | self._brick = int(brick_and_pos[0]) 219 | 
if len(brick_and_pos) > 2: 220 | self._position = sm.Pos(float(brick_and_pos[1]), float(brick_and_pos[2]), float(brick_and_pos[3])) 221 | self._lower = None 222 | else: 223 | self._lower = int(brick_and_pos[1]) 224 | self._position = None 225 | 226 | super(Put, self).__init__(name, world, verbose) 227 | 228 | def initialise(self): 229 | if self._lower is not None: 230 | if self.world_interface.on_top(self.world_interface.state.bricks[self._brick], 231 | self.world_interface.state.bricks[self._lower]): 232 | self.success() 233 | else: 234 | self._state = None 235 | elif self._world.distance(self._brick, self._position) < self._world.sm_par.pos_margin: 236 | self.success() 237 | elif self._world.get_picked() is not None and self._world.get_picked() != self._brick: 238 | self.failure() 239 | else: 240 | self._state = None 241 | 242 | def update(self): 243 | super(Put, self).update() 244 | 245 | if self._state is None: 246 | self._state = pt.common.Status.RUNNING 247 | elif self._state is pt.common.Status.RUNNING: 248 | success = self._world.pick(self.brick) 249 | if success: 250 | if self._lower is not None: 251 | success = self._world.place(brick=self._lower) 252 | else: 253 | success = self._world.place(position=self._position) 254 | if success: 255 | self.success() 256 | else: 257 | self.failure() 258 | self._world.random_event() 259 | return self._state 260 | 261 | 262 | class ApplyForce(StateMachineBehavior): 263 | """ 264 | Apply force on given brick 265 | """ 266 | 267 | @staticmethod 268 | def make(text, world, verbose=False): 269 | return ApplyForce(text, world, re.findall(r'\d+', text), verbose) 270 | 271 | def __init__(self, name, world, brick, verbose): 272 | self._brick = int(brick[0]) 273 | super(ApplyForce, self).__init__(name, world, verbose) 274 | 275 | def initialise(self): 276 | if self._world.get_picked() is not None: 277 | self.failure() 278 | else: 279 | self._state = None 280 | 281 | def update(self): 282 | super(ApplyForce, self).update() 283 | 
if self._state is None: 284 | self._state = pt.common.Status.RUNNING 285 | elif self._state is pt.common.Status.RUNNING: 286 | if self._world.apply_force(self._brick): 287 | self.success() 288 | else: 289 | self.failure() 290 | self._world.random_event() 291 | return self._state 292 | 293 | 294 | def _make_tower_nodes(): 295 | 296 | behavior_register = BehaviorRegister() 297 | behavior_register.add_condition('picked 0?', Picked) 298 | behavior_register.add_condition('picked 1?', Picked) 299 | behavior_register.add_condition('picked 2?', Picked) 300 | behavior_register.add_condition('0 at pos (0.0, 0.05, 0.0)?', AtPos) 301 | behavior_register.add_condition('1 at pos (0.0, 0.05, 0.0192)?', AtPos) 302 | behavior_register.add_condition('2 at pos (0.0, 0.05, 0.0384)?', AtPos) 303 | behavior_register.add_condition('0 on 1?', On) 304 | behavior_register.add_condition('0 on 2?', On) 305 | behavior_register.add_condition('1 on 0?', On) 306 | behavior_register.add_condition('1 on 2?', On) 307 | behavior_register.add_condition('2 on 0?', On) 308 | behavior_register.add_condition('2 on 1?', On) 309 | behavior_register.add_action('pick 0!', Pick) 310 | behavior_register.add_action('pick 1!', Pick) 311 | behavior_register.add_action('pick 2!', Pick) 312 | behavior_register.add_action('place at (0.0, 0.05, 0.0)!', Place) 313 | behavior_register.add_action('place on 0!', Place) 314 | behavior_register.add_action('place on 1!', Place) 315 | behavior_register.add_action('place on 2!', Place) 316 | behavior_register.add_action('apply force 0!', ApplyForce) 317 | behavior_register.add_action('apply force 1!', ApplyForce) 318 | behavior_register.add_action('apply force 2!', ApplyForce) 319 | 320 | return behavior_register 321 | 322 | 323 | def _make_croissant_nodes(): 324 | 325 | behavior_register = BehaviorRegister() 326 | behavior_register.add_condition('picked 0?', Picked) 327 | behavior_register.add_condition('picked 1?', Picked) 328 | behavior_register.add_condition('picked 2?', 
Picked) 329 | behavior_register.add_condition('picked 3?', Picked) 330 | behavior_register.add_condition('0 at pos (0.0, 0.0, 0.0)?', AtPos) 331 | behavior_register.add_condition('1 at pos (0.0, 0.0, 0.0192)?', AtPos) 332 | behavior_register.add_condition('2 at pos (0.016, -0.032, 0.0)?', AtPos) 333 | behavior_register.add_condition('3 at pos (0.016, 0.032, 0.0)?', AtPos) 334 | behavior_register.add_condition('0 on 1?', On) 335 | behavior_register.add_condition('0 on 2?', On) 336 | behavior_register.add_condition('0 on 3?', On) 337 | behavior_register.add_condition('1 on 0?', On) 338 | behavior_register.add_condition('1 on 2?', On) 339 | behavior_register.add_condition('1 on 3?', On) 340 | behavior_register.add_condition('2 on 0?', On) 341 | behavior_register.add_condition('2 on 1?', On) 342 | behavior_register.add_condition('2 on 3?', On) 343 | behavior_register.add_condition('3 on 0?', On) 344 | behavior_register.add_condition('3 on 1?', On) 345 | behavior_register.add_condition('3 on 2?', On) 346 | behavior_register.add_action('pick 0!', Pick) 347 | behavior_register.add_action('pick 1!', Pick) 348 | behavior_register.add_action('pick 2!', Pick) 349 | behavior_register.add_action('pick 3!', Pick) 350 | behavior_register.add_action('place at (0.0, 0.05, 0.0)!', Place) 351 | behavior_register.add_action('place at (0.016, -0.032, 0.0)!', Place) 352 | behavior_register.add_action('place at (0.016, 0.032, 0.0)!', Place) 353 | behavior_register.add_action('place on 0!', Place) 354 | behavior_register.add_action('place on 1!', Place) 355 | behavior_register.add_action('place on 2!', Place) 356 | behavior_register.add_action('place on 3!', Place) 357 | behavior_register.add_action('apply force 0!', ApplyForce) 358 | behavior_register.add_action('apply force 1!', ApplyForce) 359 | behavior_register.add_action('apply force 2!', ApplyForce) 360 | behavior_register.add_action('apply force 3!', ApplyForce) 361 | 362 | return behavior_register 363 | 364 | 365 | def 
get_behaviors(name): 366 | 367 | if name == 'tower': 368 | return _make_tower_nodes() 369 | elif name == 'croissant': 370 | return _make_croissant_nodes() 371 | else: 372 | raise ValueError('Unknown %s name', name) 373 | -------------------------------------------------------------------------------- /examples/duplo/fitness_function.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from duplo.world import ApplicationWorld 3 | 4 | 5 | @dataclass 6 | class Coefficients: 7 | """ 8 | Coefficients for tuning the cost function 9 | """ 10 | task_completion: float = 1000.0 11 | pos_acc: float = 0.0004 12 | depth: float = 0.0 13 | length: float = 0.1 14 | ticks: float = 0.0 15 | failed: float = 50.0 16 | timeout: float = 10.0 17 | hand_not_empty: float = 0.0 18 | 19 | 20 | class FitnessFunction: 21 | 22 | def compute_cost(self, world: ApplicationWorld, behavior_tree, ticks, targets, coefficients=None, verbose=False): 23 | 24 | if coefficients is None: 25 | coefficients = Coefficients() 26 | 27 | depth = behavior_tree.depth 28 | length = behavior_tree.length 29 | 30 | cost = (coefficients.length * length + coefficients.depth * depth + 31 | coefficients.ticks * ticks) 32 | 33 | if verbose: 34 | print("Cost from length:", cost) 35 | 36 | for i in range(len(targets)): 37 | cost += coefficients.task_completion * max(0, world.distance(i, targets[i]) - coefficients.pos_acc) 38 | if verbose: 39 | print("Cost:", cost) 40 | 41 | if behavior_tree.failed: 42 | cost += coefficients.failed 43 | if verbose: 44 | print("Failed: ", cost) 45 | if behavior_tree.timeout: 46 | cost += coefficients.timeout 47 | if verbose: 48 | print("Timed out: ", cost) 49 | if world.get_picked() is not None: 50 | cost += coefficients.hand_not_empty 51 | if verbose: 52 | print("Hand not empty: ", cost) 53 | 54 | fitness = -cost 55 | return fitness 56 | -------------------------------------------------------------------------------- 
# /examples/duplo/paths.py:
# --------------------------------------------------------------------------------
import os
import sys


# Layout derived from this file's location: <root>/examples/duplo/paths.py
_this_file_path = os.path.abspath(__file__)
_PACKAGE_DIRECTORY = os.path.dirname(os.path.dirname(os.path.dirname(_this_file_path)))
_EXAMPLES_DIRECTORY = os.path.dirname(os.path.dirname(_this_file_path))
_CURRENT_EXAMPLE_DIRECTORY = os.path.dirname(_this_file_path)


def add_modules_to_path():
    """Append the root directory and the examples directory to ``sys.path``."""
    sys.path.append(os.path.normpath(_PACKAGE_DIRECTORY))
    sys.path.append(os.path.normpath(_EXAMPLES_DIRECTORY))


def get_example_directory():
    """Return the directory that contains this example."""
    return _CURRENT_EXAMPLE_DIRECTORY


def get_outputs_directory():
    """Return the example's ``results`` directory."""
    return os.path.join(_CURRENT_EXAMPLE_DIRECTORY, 'results')


def get_log_directory():
    """Return the example's ``logs`` directory."""
    return os.path.join(_CURRENT_EXAMPLE_DIRECTORY, 'logs')


# --------------------------------------------------------------------------------
# /examples/duplo/run_execute_bt.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python3

import paths
paths.add_modules_to_path()

import os
from behavior_tree_learning.sbt import BehaviorTreeExecutor, ExecutionParameters
from behavior_tree_learning.sbt import BehaviorNodeFactory

from duplo.paths import get_log_directory
from duplo import bt_collection
from duplo.execution_nodes import get_behaviors
from duplo.world import ApplicationWorld
from duplo.world import Pos as WorldPos


def run():
    """
    Execute the first two trees of ``bt_collection`` in the 'tower' scenario.

    Each trial runs in a freshly built world, with
    ``ExecutionParameters(successes_required=1)``. The resulting tree figure is
    saved to the log directory as ``trial_<n>`` (1-based), and the outcome is
    printed.
    """

    scenario_name = 'tower'
    node_factory = BehaviorNodeFactory(get_behaviors(scenario_name))

    sbt_1 = bt_collection.select_bt(0)
    sbt_2 = bt_collection.select_bt(1)
    trials = [sbt_1, sbt_2]

    start_position = [WorldPos(-0.05, -0.1, 0), WorldPos(0.0, -0.1, 0), WorldPos(0.05, -0.1, 0)]

    # Create the output directory once instead of retrying mkdir on every trial.
    os.makedirs(get_log_directory(), exist_ok=True)

    for tdx, trial in enumerate(trials, start=1):

        print("Trial: %d" % tdx)

        sbt = list(trial)
        print("SBT: ", sbt)

        # ApplicationWorld copies the start positions, so every trial starts from the same layout.
        simulated_world = ApplicationWorld(start_position, scenario=scenario_name)
        bt_executor = BehaviorTreeExecutor(node_factory, simulated_world)

        success, _ticks, tree = bt_executor.run(sbt, ExecutionParameters(successes_required=1),
                                                verbose=True)

        file_name = 'trial_%d' % tdx
        tree.save_figure(get_log_directory(), name=file_name)
        print("Succeed: ", success)


if __name__ == "__main__":
    run()


# --------------------------------------------------------------------------------
# /examples/duplo/run_learn_bt.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python3

import paths
paths.add_modules_to_path()

import os
import logging

from behavior_tree_learning.sbt import BehaviorNodeFactory
from behavior_tree_learning.learning import BehaviorTreeLearner, GeneticParameters, GeneticSelectionMethods
from behavior_tree_learning.learning import TraceConfiguration

from duplo.execution_nodes import get_behaviors
from duplo.world import Pos as WorldPos
from duplo.world import ApplicationWorldFactory
from duplo.environment import ApplicationEnvironment


def _configure_logger(level, directory_path, name):
    """
    Send root logging output to ``<directory_path>/<name>.log`` and set the
    level of the "gp" logger to ``level``.

    Handlers already installed on the root logger are removed and closed first,
    so that ``basicConfig`` installs a new file handler for this name.
    """

    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
        # removeHandler() does not close the handler; without this, every
        # previous trial's log file would stay open for the rest of the run.
        handler.close()

    # Replaces a try/mkdir/bare-except that also swallowed unrelated errors.
    os.makedirs(directory_path, exist_ok=True)
    file_path = os.path.join(directory_path, name + '.log')

    logging.basicConfig(filename=file_path,
                        format='%(filename)s: %(message)s')
    logging.getLogger("gp").setLevel(level)


def _plot_summary(outputs_dir_path, scenario_name, trials):
    """
    Plot the learning curves of ``trials`` (with a std band) and save the
    figure as ``<outputs_dir_path>/<scenario_name>.pdf``.
    """

    from behavior_tree_learning.core.logger import logplot

    parameters = logplot.PlotParameters()
    parameters.plot_std = True
    parameters.xlabel = 'Episodes'
    parameters.ylabel = 'Fitness'
    parameters.mean_color = 'r'
    parameters.std_color = 'r'
    parameters.horizontal = -3.0
    parameters.save_fig = True
    parameters.path = os.path.join(outputs_dir_path, scenario_name + '.pdf')

    logplot.plot_learning_curves(trials, parameters)


def _prepare_scenarios():
    """Return ``(name, start_positions, target_positions)`` for 'tower' and 'croissant'."""

    scenarios = []

    scenario_name = 'tower'
    start_position = [WorldPos(-0.05, -0.1, 0), WorldPos(0.0, -0.1, 0), WorldPos(0.05, -0.1, 0)]
    target_position = [WorldPos(0.0, 0.05, 0), WorldPos(0.0, 0.05, 0.0192), WorldPos(0.0, 0.05, 2 * 0.0192)]
    scenarios.append((scenario_name, start_position, target_position))

    scenario_name = 'croissant'
    start_position = [WorldPos(-0.05, -0.1, 0), WorldPos(0.05, -0.1, 0), WorldPos(0.05, 0.1, 0),
                      WorldPos(-0.05, 0.1, 0)]
    target_position = [WorldPos(0.0, 0.0, 0.0), WorldPos(0.0, 0.0, 0.0192), WorldPos(0.016, -0.032, 0.0),
                       WorldPos(0.016, 0.032, 0.0)]
    scenarios.append((scenario_name, start_position, target_position))

    return scenarios


def run():
    """
    Learn a behavior tree for every scenario, 10 trials each (seed = trial
    number), logging each trial to its own file and plotting one summary per
    scenario.
    """

    parameters = GeneticParameters()

    parameters.n_generations = 200
    parameters.fitness_threshold = 0.

    parameters.n_population = 16
    parameters.ind_start_length = 8
    parameters.f_crossover = 0.5
    parameters.n_offspring_crossover = 2
    parameters.replace_crossover = False
    parameters.f_mutation = 0.5
    parameters.n_offspring_mutation = 2
    parameters.parent_selection = GeneticSelectionMethods.RANK
    parameters.survivor_selection = GeneticSelectionMethods.RANK
    parameters.f_elites = 0.1
    parameters.f_parents = parameters.f_elites
    parameters.mutate_co_offspring = False
    parameters.mutate_co_parents = True
    parameters.mutation_p_add = 0.4
    parameters.mutation_p_delete = 0.3
    parameters.allow_identical = False

    tracer = TraceConfiguration()
    tracer.plot_fitness = True
    tracer.plot_best_individual = True
    tracer.plot_last_generation = True

    scenarios = _prepare_scenarios()
    for scenario_name, start_position, target_position in scenarios:

        num_trials = 10
        trials = []
        for tdx in range(1, num_trials + 1):

            trial_name = scenario_name + '_' + str(tdx)
            trials.append(trial_name)
            print("Trial: %s" % trial_name)

            log_name = trial_name
            _configure_logger(logging.DEBUG, paths.get_log_directory(), log_name)

            parameters.log_name = log_name
            seed = tdx

            node_factory = BehaviorNodeFactory(get_behaviors(scenario_name))
            world_factory = ApplicationWorldFactory(start_position, scenario=scenario_name)
            environment = ApplicationEnvironment(node_factory, world_factory, target_position, verbose=False)

            bt_learner = BehaviorTreeLearner.from_environment(environment)
            success = bt_learner.run(parameters, seed,
                                     outputs_dir_path=paths.get_outputs_directory(),
                                     trace_conf=tracer,
                                     verbose=False)

            print("Trial: %d, Succeed: %s" % (tdx, success))

        _plot_summary(paths.get_outputs_directory(), scenario_name, trials)


if __name__ == "__main__":
    run()


# --------------------------------------------------------------------------------
# /examples/duplo/world.py:
# --------------------------------------------------------------------------------
import random
from math import sqrt
from enum import IntEnum
from dataclasses import dataclass
from copy import copy
from typing import List, Optional
from interface import implements
from behavior_tree_learning.sbt import World


@dataclass
class Pos:
    """
    Cartesian position
    """

    x: float = 0
    y: float = 0
    z: float = 0

    def __add__(self, other):
        return Pos(self.x + other.x, self.y + other.y, self.z + other.z)

    def __str__(self):
        return f'({self.x}, {self.y}, {self.z})'


@dataclass
class State:
    """
    Definition of substates
    """

    # Annotated so that @dataclass actually treats these as fields
    # (un-annotated class attributes are ignored by dataclasses).
    bricks: Optional[List[Pos]] = None  # brick positions, indexed by brick number
    picked: Optional[int] = None        # index of the brick in hand; None when the hand is empty


@dataclass
class SMParameters:
    """Data class for parameters for the state machine simulator """

    pick_height: float = 0.04
    brick_height: float = 0.0192
    not_pressed_dist: float = 0.002
    pos_margin: float = 0.001
    ontop_margin: float = 0.003
    random_events: bool = False  # Random events
    verbose: bool = False  # Extra prints


class SMMode(IntEnum):
    """Special state machine modes for testing different situations """

    DEFAULT = 0
    CROISSANT = 1
    BALANCE = 2
    BLOCKING = 3


def move_brick_to(brick, position):
    """ Move brick to given position (copies coordinates, does not alias ``position``) """

    brick.x = position.x
    brick.y = position.y
    brick.z = position.z


class ApplicationWorld(implements(World)):
    """
    Class for handling the State Machine Simulator
    """

    def __init__(self, start_positions, random_events=False, parameters=None, scenario: str = ""):
        """
        :param start_positions: initial brick positions; copied, never aliased.
        :param random_events: enables ``random_event`` drops.
        :param parameters: optional SMParameters; copied, so the caller's object
            (e.g. one shared by ApplicationWorldFactory) is not modified.
        :param scenario: "tower" or "croissant".
        :raises ValueError: for any other scenario name.
        """

        self.sm_par = SMParameters() if parameters is None else copy(parameters)
        self.sm_par.random_events = random_events

        if scenario == "tower":
            self.mode = SMMode.DEFAULT
        elif scenario == "croissant":
            self.mode = SMMode.CROISSANT
        else:
            raise ValueError("Unknown [%s] scenario" % scenario)

        self.state = State()
        self.state.bricks = [copy(pos) for pos in start_positions]

    def startup(self, verbose):
        return True

    def is_alive(self):
        return True

    def shutdown(self):
        pass

    def random_event(self):
        """
        Has a probability of creating a random event,
        dropping the current picked brick at a random position.

        When random events are enabled, with probability 0.5 the picked brick (if
        any) is shifted in x and y by Gaussian noise (sigma 0.05), set to z = 0,
        and the hand is emptied.
        """

        if self.sm_par.random_events:
            number = random.random()
            if number < 0.5 and self.state.picked is not None:
                dropped = self.state.bricks[self.state.picked]
                dropped.x += random.gauss(0, 0.05)
                dropped.y += random.gauss(0, 0.05)
                dropped.z = 0
                self.state.picked = None

    def hand_empty(self):
        """ Checks if any object is picked """

        return self.state.picked is None

    def pick(self, brick):
        """ Picks given brick (raised by pick_height); fails if the hand is not empty """

        if self.state.picked is None:
            self.state.picked = brick
            self.state.bricks[brick].z += self.sm_par.pick_height
            return True
        return False

    def place(self, position=None, brick=None):
        # pylint: disable=too-many-branches
        """
        Place current picked object on given position or given brick.

        ``position`` takes precedence when both are given. Returns True when the
        picked brick was moved and the hand emptied. Returns False when nothing
        is picked, no target was given, or a mode-specific rule rejects the move.
        """

        if self.state.picked is None or (position is None and brick is None):
            return False

        if self.mode == SMMode.CROISSANT:
            # Bricks 2 and 3 cannot be placed while brick 1 is within 0.01 of y == 0.
            if self.state.picked in (2, 3) and abs(self.state.bricks[1].y) < 0.01:
                return False

        if position is not None:
            # Copy: the mode-specific adjustments below use +=, which previously
            # modified the caller's position object in place.
            new_brick_position = copy(position)
        else:
            new_brick_position = copy(self.state.bricks[brick])
            new_brick_position.z += self.sm_par.brick_height + self.sm_par.not_pressed_dist

        if self.mode == SMMode.BALANCE:
            if self.state.picked == 0 and brick == 1:
                new_brick_position.y += 0.01
            elif self.state.picked == 2 and brick is None and self.state.bricks[1].y == 0.0:
                new_brick_position.z += 0.0192
        elif self.mode == SMMode.BLOCKING:
            # NOTE(review): nesting reconstructed from a dump that lost indentation.
            # The "occupied position" loop is assumed to apply to any picked brick,
            # not only to brick 2 -- confirm against the original source.
            if position is not None:
                if self.state.picked == 2:
                    for i, other in enumerate(self.state.bricks):
                        if (self.state.picked != i
                                and new_brick_position.x == other.x
                                and abs(new_brick_position.y - other.y) < 0.1):
                            return False
                for i in range(len(self.state.bricks)):
                    if self.state.picked != i and self.distance(i, new_brick_position) < 0.001:
                        return False
            elif brick == 2:
                new_brick_position.z += 0.05

        move_brick_to(self.state.bricks[self.state.picked], new_brick_position)
        self.state.picked = None
        return True

    def apply_force(self, brick):
        """
        Applies force on brick.

        With an empty hand, the brick is pressed down to exactly one brick_height
        above the first brick it is on top of, if any. It returns True even when
        no such brick exists, and False when the hand is not empty.
        """

        if self.state.picked is None:
            upper_brick = self.state.bricks[brick]
            for lower_brick in self.state.bricks:
                if lower_brick is not upper_brick and self.on_top(upper_brick, lower_brick):
                    upper_brick.z = lower_brick.z + self.sm_par.brick_height
                    break
            return True
        return False

    def get_picked(self):
        """ return picked state """

        return self.state.picked

    def distance(self, brick, position):
        """ Returns Euclidean distance between given brick and given position """

        return sqrt((self.state.bricks[brick].x - position.x)**2 +
                    (self.state.bricks[brick].y - position.y)**2 +
                    (self.state.bricks[brick].z - position.z)**2)

    def on_top(self, upper_brick, lower_brick):
        """ Checks if upper brick is on top of lower brick with margins """

        return (abs(upper_brick.x - lower_brick.x) < self.sm_par.ontop_margin
                and abs(upper_brick.y - lower_brick.y) < self.sm_par.ontop_margin
                and 0 < upper_brick.z - lower_brick.z <= self.sm_par.brick_height + self.sm_par.ontop_margin)


class ApplicationWorldFactory:
    """Builds new ApplicationWorld instances from a fixed set of constructor arguments."""

    def __init__(self, start_position, random_events=False, parameters=None, scenario: str = ""):
        self._start_position = start_position
        self._random_events = random_events
        self._parameters = parameters
        self._scenario = scenario

    def make(self):
        """Return a new ApplicationWorld built from the stored arguments."""
        return ApplicationWorld(self._start_position, self._random_events, self._parameters, self._scenario)


# --------------------------------------------------------------------------------
# /examples/simple_execution_demo/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/dgerod/behavior_tree_learning/71da80c91ecd48fd5da377f83604b62112ba9629/examples/simple_execution_demo/__init__.py

# --------------------------------------------------------------------------------
# /examples/simple_execution_demo/bt_collection.py:
# --------------------------------------------------------------------------------
_bt_1 = ['s(',
         'DO_PickGearPart[]',
         'DO_MoveGearPart[P: place]',
         'DO_PlaceGearPart[]',
         ')']

_bt_2 = ['f(',
         'CHECK_GearPartPlaced[]',
         's(',
         'DO_PickGearPart[]',
         'DO_MoveGearPart[P: place]',
         'DO_PlaceGearPart[]',
         ')',
         ')']


def select_bt(idx: int):
    """Return the string behavior tree at ``idx`` (0 or 1); raises IndexError otherwise."""
    collection = [_bt_1, _bt_2]
    return collection[idx]


# --------------------------------------------------------------------------------
# /examples/simple_execution_demo/dummy_world.py:
# --------------------------------------------------------------------------------
import collections
from interface import implements
from behavior_tree_learning.sbt import World


_WORLD_OPERATION_FIELDS = ['is_picked_succeed', 'is_placed_succeed',
                           'do_pick_succeed', 'do_place_succeed', 'do_move_succeed']

# Canned answers returned by DummyWorld, one per query/operation.
WorldOperationResults = collections.namedtuple('WorldOperationResults', _WORLD_OPERATION_FIELDS)


class DummyWorld(implements(World)):
    """World stub: every query/operation answers from a fixed WorldOperationResults."""

    def __init__(self, operation_results: WorldOperationResults, is_alive: bool = True):
        self._operation_results = operation_results
        self._is_alive = is_alive

    def startup(self, verbose):
        """Start-up always succeeds."""
        return True

    def is_alive(self):
        """Report the liveness flag given at construction."""
        return self._is_alive

    def shutdown(self):
        """Nothing to release."""

    def is_gear_part_picked(self):
        """Canned answer for the 'gear part picked' condition."""
        results = self._operation_results
        return results.is_picked_succeed

    def is_gear_part_placed(self):
        """Canned answer for the 'gear part placed' condition."""
        results = self._operation_results
        return results.is_placed_succeed

    def pick_gear_part(self):
        """Canned outcome of the pick action."""
        results = self._operation_results
        return results.do_pick_succeed

    def place_gear_part(self):
        """Canned outcome of the place action."""
        results = self._operation_results
        return results.do_place_succeed

    def move_gear_part(self):
        """Canned outcome of the move action."""
        results = self._operation_results
        return results.do_move_succeed


# --------------------------------------------------------------------------------
# /examples/simple_execution_demo/execution_nodes.py:
# --------------------------------------------------------------------------------
import py_trees as pt
from behavior_tree_learning.sbt import BehaviorRegister, BehaviorNodeWithOperation


def _to_status(succeeded):
    """Translate a truthy/falsy world answer into a py_trees SUCCESS/FAILURE status."""
    if succeeded:
        return pt.common.Status.SUCCESS
    return pt.common.Status.FAILURE


class CheckGearPartPicked(BehaviorNodeWithOperation):
    """Condition node backed by ``world.is_gear_part_picked()``."""

    @staticmethod
    def make(text, world, verbose=False):
        return CheckGearPartPicked(name=text, world=world, verbose=verbose)

    def __init__(self, name, world, verbose):
        super().__init__(name)
        self._world = world

    def initialise(self):
        pass

    def update(self):
        print("CheckGearPartPicked::update() [%s]" % self.name)
        return _to_status(self._world.is_gear_part_picked())


class CheckGearPartPlaced(BehaviorNodeWithOperation):
    """Condition node backed by ``world.is_gear_part_placed()``."""

    @staticmethod
    def make(text, world, verbose=False):
        return CheckGearPartPlaced(name=text, world=world, verbose=verbose)

    def __init__(self, name, world, verbose):
        super().__init__(name)
        self._world = world

    def initialise(self):
        pass

    def update(self):
        print("CheckGearPartPlaced::update() [%s]" % self.name)
        return _to_status(self._world.is_gear_part_placed())


class DoPickGearPart(BehaviorNodeWithOperation):
    """Action node backed by ``world.pick_gear_part()``."""

    @staticmethod
    def make(text, world, verbose=False):
        return DoPickGearPart(name=text, world=world, verbose=verbose)

    def __init__(self, name, world, verbose):
        super().__init__(name)
        self._world = world

    def initialise(self):
        pass

    def update(self):
        print("DoPickGearPart::update() [%s]" % self.name)
        return _to_status(self._world.pick_gear_part())


class DoPlaceGearPart(BehaviorNodeWithOperation):
    """Action node backed by ``world.place_gear_part()``."""

    @staticmethod
    def make(text, world, verbose=False):
        return DoPlaceGearPart(name=text, world=world, verbose=verbose)

    def __init__(self, name, world, verbose):
        super().__init__(name)
        self._world = world

    def initialise(self):
        pass

    def update(self):
        print("DoPlaceGearPart::update() [%s]" % self.name)
        return _to_status(self._world.place_gear_part())


class DoMoveGearPart(BehaviorNodeWithOperation):
    """Action node backed by ``world.move_gear_part()``."""

    @staticmethod
    def make(text, world, verbose=False):
        return DoMoveGearPart(name=text, world=world, verbose=verbose)

    def __init__(self, name, world, verbose):
        super().__init__(name)
        self._world = world

    def initialise(self):
        pass

    def update(self):
        print("DoMoveGearPart::update() [%s]" % self.name)
        return _to_status(self._world.move_gear_part())


def get_behaviors():
    """Build a BehaviorRegister with the gear-part conditions and actions, in a fixed order."""

    conditions = (('CHECK_GearPartPlaced[]', CheckGearPartPlaced),
                  ('CHECK_GearPartPicked[]', CheckGearPartPicked))
    actions = (('DO_PickGearPart[]', DoPickGearPart),
               ('DO_PlaceGearPart[]', DoPlaceGearPart),
               ('DO_MoveGearPart[P: place]', DoMoveGearPart))

    behavior_register = BehaviorRegister()
    for text, node_type in conditions:
        behavior_register.add_condition(text, node_type)
    for text, node_type in actions:
        behavior_register.add_action(text, node_type)
    return behavior_register


# --------------------------------------------------------------------------------
# /examples/simple_execution_demo/paths.py:
# --------------------------------------------------------------------------------
import os
import sys


# Layout derived from this file's location: <root>/examples/simple_execution_demo/paths.py
_this_file_path = os.path.abspath(__file__)
_CURRENT_EXAMPLE_DIRECTORY = os.path.dirname(_this_file_path)
_EXAMPLES_DIRECTORY = os.path.dirname(_CURRENT_EXAMPLE_DIRECTORY)
_PACKAGE_DIRECTORY = os.path.dirname(_EXAMPLES_DIRECTORY)


def add_modules_to_path():
    """Append the root directory, then the examples directory, to ``sys.path``."""
    for directory in (_PACKAGE_DIRECTORY, _EXAMPLES_DIRECTORY):
        sys.path.append(os.path.normpath(directory))


def get_example_directory():
    """Return the directory that contains this example."""
    return _CURRENT_EXAMPLE_DIRECTORY


def get_log_directory():
    """Return the example's ``logs`` directory."""
    return os.path.join(_CURRENT_EXAMPLE_DIRECTORY, 'logs')


# --------------------------------------------------------------------------------
# /examples/simple_execution_demo/run_execute_bt.py:
# --------------------------------------------------------------------------------
#!/usr/bin/env python3

# Must run before the package imports below so they can be resolved.
import paths
paths.add_modules_to_path()

import os
from behavior_tree_learning.sbt import BehaviorTreeExecutor, ExecutionParameters, BehaviorNodeFactory
from simple_execution_demo.paths import get_log_directory 11 | from simple_execution_demo import bt_collection 12 | from simple_execution_demo.execution_nodes import get_behaviors 13 | from simple_execution_demo.dummy_world import DummyWorld, WorldOperationResults 14 | 15 | 16 | def run(): 17 | 18 | node_factory = BehaviorNodeFactory(get_behaviors()) 19 | 20 | sbt_1 = bt_collection.select_bt(0) 21 | sbt_2 = bt_collection.select_bt(1) 22 | trials = [(sbt_1, WorldOperationResults(is_picked_succeed=True, is_placed_succeed=True, 23 | do_pick_succeed=True, do_place_succeed=True, do_move_succeed=True)), 24 | (sbt_1, WorldOperationResults(is_picked_succeed=False, is_placed_succeed=False, 25 | do_pick_succeed=False, do_place_succeed=False, do_move_succeed=False)), 26 | (sbt_2, WorldOperationResults(is_picked_succeed=True, is_placed_succeed=True, 27 | do_pick_succeed=True, do_place_succeed=True, do_move_succeed=True)), 28 | (sbt_2, WorldOperationResults(is_picked_succeed=False, is_placed_succeed=False, 29 | do_pick_succeed=False, do_place_succeed=False, do_move_succeed=False))] 30 | world_feedback_succeed = True 31 | 32 | for tdx, trial in zip(range(0, len(trials)), trials): 33 | 34 | print("Trial: %d" % tdx) 35 | 36 | sbt = list(trial[0]) 37 | world_operations = trial[1] 38 | 39 | print("SBT: ", sbt) 40 | print("World operations: ", world_operations) 41 | 42 | world = DummyWorld(world_operations, world_feedback_succeed) 43 | bt_executor = BehaviorTreeExecutor(node_factory, world) 44 | success, ticks, tree = bt_executor.run(sbt, ExecutionParameters(successes_required=1)) 45 | 46 | try: 47 | os.mkdir(get_log_directory()) 48 | except OSError: 49 | pass 50 | 51 | file_name = 'trial_%d' % (tdx + 1) 52 | 53 | tree.save_figure(get_log_directory(), name=file_name) 54 | print("Succeed: ", success) 55 | 56 | 57 | if __name__ == "__main__": 58 | run() 59 | -------------------------------------------------------------------------------- /examples/tiago_pnp/README.md: 
-------------------------------------------------------------------------------- 1 | Tiago Pick and Place 2 | ==================== 3 | 4 | Example adapted from https://github.com/matiov/learn-BTs-with-GP, the world is 5 | implemented using a state machine. -------------------------------------------------------------------------------- /examples/tiago_pnp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dgerod/behavior_tree_learning/71da80c91ecd48fd5da377f83604b62112ba9629/examples/tiago_pnp/__init__.py -------------------------------------------------------------------------------- /examples/tiago_pnp/bt_collection.py: -------------------------------------------------------------------------------- 1 | _scenario_1 = ['f(', 2 | 'task_done?', 3 | 's(', 4 | 'localise!', 5 | 'head [Up]!', 6 | 'f(', 7 | 'have_block?', 8 | 's(', 9 | 'arm [Tucked]!', 10 | 'move_to_pick [0]!', 11 | ')', 12 | ')', 13 | 'head [Down]!', 14 | 'pick!', 15 | 'move_to_place!', 16 | 'place!', 17 | ')', 18 | ')'] 19 | 20 | 21 | def select_bt(name: str): 22 | collection = {'scenario_1': _scenario_1} 23 | return collection[name] 24 | -------------------------------------------------------------------------------- /examples/tiago_pnp/environment.py: -------------------------------------------------------------------------------- 1 | from interface import implements 2 | from behavior_tree_learning.sbt import StringBehaviorTree, BehaviorNodeFactory 3 | from behavior_tree_learning.learning import Environment 4 | from tiago_pnp.world import ApplicationWorldFactory 5 | from tiago_pnp.fitness_function import FitnessFunction 6 | 7 | 8 | class ApplicationEnvironment(implements(Environment)): 9 | """ 10 | Class defining the environment in which the individual operates 11 | """ 12 | 13 | def __init__(self, node_factory: BehaviorNodeFactory, world_factory: ApplicationWorldFactory, 14 | scenario: str, verbose=False): 15 | 16 | if 
scenario != 'scenario_1' and scenario != 'scenario_3': 17 | raise ValueError('Unknown selected scenario') 18 | 19 | self._scenario = scenario 20 | self._world_factory = world_factory 21 | self._node_factory = node_factory 22 | self._verbose = verbose 23 | 24 | def run_and_compute(self, individual, verbose): 25 | """ 26 | Run the simulation and return the fitness 27 | """ 28 | 29 | sbt = list(individual) 30 | verbose_enabled = self._verbose or verbose 31 | 32 | if verbose_enabled: 33 | print("SBT: ", sbt) 34 | 35 | if self._scenario == 'scenario_2': 36 | # In this case we run the same BT against the state machine in 3 different setups 37 | # every setup features a different spawn pose for the cube 38 | 39 | fitness = 0 40 | performance = 0 41 | 42 | for i in range(3): 43 | world = self._world_factory.make() 44 | behavior_tree = StringBehaviorTree(sbt[:], behaviors=self._node_factory, world=world, 45 | verbose=verbose) 46 | _, ticks = behavior_tree.run_bt() 47 | 48 | cost, output = FitnessFunction().compute_cost(world, behavior_tree, ticks, verbose) 49 | 50 | fitness += -cost / 3.0 51 | performance += int(output) 52 | 53 | if performance == 3: 54 | completed = True 55 | else: 56 | completed = False 57 | 58 | elif self._scenario == 'scenario_1' or self._scenario == 'scenario_3': 59 | 60 | world = self._world_factory.make() 61 | behavior_tree = StringBehaviorTree(sbt[:], behaviors=self._node_factory, world=world, 62 | verbose=verbose) 63 | _, ticks = behavior_tree.run_bt() 64 | 65 | cost, completed = FitnessFunction().compute_cost(world, behavior_tree, ticks, verbose) 66 | fitness = -cost 67 | 68 | if verbose_enabled: 69 | print("fitness: ", fitness) 70 | 71 | return fitness 72 | 73 | def plot_individual(self, path, plot_name, individual): 74 | """ Saves a graphical representation of the individual """ 75 | 76 | sbt = list(individual) 77 | tree = StringBehaviorTree(sbt[:], behaviors=self._node_factory) 78 | tree.save_figure(path, name=plot_name) 79 | 
-------------------------------------------------------------------------------- /examples/tiago_pnp/execution_nodes.py: -------------------------------------------------------------------------------- 1 | import re 2 | import py_trees as pt 3 | from behavior_tree_learning.sbt import BehaviorRegister, BehaviorNode 4 | from tiago_pnp import world as sm 5 | 6 | 7 | class BlockOnTable(BehaviorNode): 8 | """ 9 | Condition checking if the cube is on table 10 | """ 11 | 12 | @staticmethod 13 | def make(text, world, verbose=False): 14 | return BlockOnTable(text, world) 15 | 16 | def __init__(self, name, world): 17 | self._world = world 18 | super(BlockOnTable, self).__init__(name) 19 | 20 | def update(self): 21 | if self._world.feedback[sm.Feedback.CUBE] == self._world.poses.cube_goal_pose: 22 | return pt.common.Status.SUCCESS 23 | return pt.common.Status.FAILURE 24 | 25 | 26 | class IsLocalised(BehaviorNode): 27 | """ 28 | Condition checking if the robot is localised 29 | """ 30 | 31 | @staticmethod 32 | def make(text, world, verbose=False): 33 | return IsLocalised(text, world, verbose) 34 | 35 | def __init__(self, name, world, verbose): 36 | self._world = world 37 | self._verbose = verbose 38 | super(IsLocalised, self).__init__(name) 39 | 40 | def update(self): 41 | if self._verbose: 42 | print("Checking LOC") 43 | 44 | if self._world.current[sm.State.LOCALISED]: 45 | return pt.common.Status.SUCCESS 46 | return pt.common.Status.FAILURE 47 | 48 | 49 | class Localise(BehaviorNode): 50 | """ 51 | Localise behavior 52 | """ 53 | 54 | @staticmethod 55 | def make(text, world, verbose=False): 56 | return Localise(text, world, verbose) 57 | 58 | def __init__(self, name, world, verbose): 59 | self._world = world 60 | self._verbose = verbose 61 | self._state = None 62 | super(Localise, self).__init__(name) 63 | 64 | def initialise(self): 65 | if not self._world.current[sm.State.LOCALISED]: 66 | self._state = None 67 | 68 | def update(self): 69 | if self._verbose: 70 | 
print("LOC") 71 | 72 | if self._state is None: 73 | self._state = pt.common.Status.RUNNING 74 | elif self._state is pt.common.Status.RUNNING: 75 | if self._world.localise_robot(): 76 | self._state = pt.common.Status.SUCCESS 77 | else: 78 | self._state = pt.common.Status.FAILURE 79 | return self._state 80 | 81 | 82 | class MoveArm(BehaviorNode): 83 | """ 84 | Moving arm behavior 85 | """ 86 | 87 | @staticmethod 88 | def make(text, world, verbose=False): 89 | 90 | configuration = text[text.find("[") + 1:text.find("]")] 91 | return MoveArm(text, world, verbose, configuration) 92 | 93 | def __init__(self, name, world, verbose, configuration): 94 | 95 | self._world = world 96 | self._verbose = verbose 97 | self._configuration = configuration 98 | self._state = None 99 | 100 | super(MoveArm, self).__init__(name) 101 | 102 | def initialise(self): 103 | 104 | if self._world.current[sm.State.ARM] != self._configuration: 105 | self._state = None 106 | 107 | def update(self): 108 | 109 | if self._state is None: 110 | self._state = pt.common.Status.RUNNING 111 | self._world.manipulating = True 112 | elif self._state is pt.common.Status.RUNNING: 113 | if self._world.move_arm(self._configuration): 114 | self._state = pt.common.Status.SUCCESS 115 | else: 116 | self._state = pt.common.Status.FAILURE 117 | self._world.manipulating = False 118 | return self._state 119 | 120 | 121 | class IsTucked(BehaviorNode): 122 | """ 123 | Condition checking if the robot arm is tucked 124 | """ 125 | 126 | @staticmethod 127 | def make(text, world, verbose=False): 128 | return IsTucked(text, world, verbose) 129 | 130 | def __init__(self, name, world, verbose): 131 | self._world = world 132 | self._verbose = verbose 133 | super(IsTucked, self).__init__(name) 134 | 135 | def update(self): 136 | if self._verbose: 137 | print("Checking TUCK") 138 | 139 | # you don't want to tuck again if the robot has the cube 140 | if self._world.current[sm.State.ARM] == "Tucked": 141 | return 
pt.common.Status.SUCCESS 142 | return pt.common.Status.FAILURE 143 | 144 | 145 | class NotHaveBlock(BehaviorNode): 146 | """ 147 | Condition checking if the robot does not have the cube 148 | """ 149 | 150 | @staticmethod 151 | def make(text, world, verbose=False): 152 | return NotHaveBlock(text, world, verbose) 153 | 154 | def __init__(self, name, world, verbose): 155 | self._world = world 156 | self._verbose = verbose 157 | super(NotHaveBlock, self).__init__(name) 158 | 159 | def update(self): 160 | if self._verbose: 161 | print("Checking NOT BLOCK") 162 | 163 | if not self._world.current[sm.State.HAS_CUBE]: 164 | return pt.common.Status.SUCCESS 165 | return pt.common.Status.FAILURE 166 | 167 | 168 | class HaveBlock(BehaviorNode): 169 | """ 170 | Condition checking if the robot has the cube 171 | """ 172 | 173 | @staticmethod 174 | def make(text, world, verbose=False): 175 | return HaveBlock(text, world, verbose) 176 | 177 | def __init__(self, name, world, verbose): 178 | self._world = world 179 | self._verbose = verbose 180 | super(HaveBlock, self).__init__(name) 181 | 182 | def update(self): 183 | if self._verbose: 184 | print("Checking PICK") 185 | if self._world.current[sm.State.HAS_CUBE]: 186 | return pt.common.Status.SUCCESS 187 | return pt.common.Status.FAILURE 188 | 189 | 190 | class PickUp(BehaviorNode): 191 | """ 192 | Picking behavior 193 | """ 194 | 195 | @staticmethod 196 | def make(text, world, verbose=False): 197 | return PickUp(text, world, verbose) 198 | 199 | def __init__(self, name, world, verbose): 200 | self._world = world 201 | self._verbose = verbose 202 | self._state = None 203 | super(PickUp, self).__init__(name) 204 | 205 | def initialise(self): 206 | if self._world.feedback[sm.State.ARM] != "Pick" and not self._world.current[sm.State.HAS_CUBE]: 207 | self._state = None 208 | 209 | def update(self): 210 | if self._verbose: 211 | print("PICK") 212 | 213 | if self._state is None: 214 | self._state = pt.common.Status.RUNNING 215 | 
self._world.manipulating = True 216 | elif self._state is pt.common.Status.RUNNING: 217 | self._world.manipulating = False 218 | if self._world.pick(): 219 | self._state = pt.common.Status.SUCCESS 220 | else: 221 | self._state = pt.common.Status.FAILURE 222 | 223 | if self._world.current[sm.State.POSE] == self._world.poses.pick_table0: 224 | self._world.current[sm.State.VISITED][0] = True 225 | elif self._world.current[sm.State.POSE] == self._world.poses.pick_table1: 226 | self._world.current[sm.State.VISITED][1] = True 227 | elif self._world.current[sm.State.POSE] == self._world.poses.pick_table2: 228 | self._world.current[sm.State.VISITED][2] = True 229 | 230 | return self._state 231 | 232 | 233 | class Placed(BehaviorNode): 234 | """ 235 | Condition checking if the robot has placed the cube 236 | """ 237 | 238 | @staticmethod 239 | def make(text, world, verbose=False): 240 | 241 | cube_id = int(re.findall(r'[0-9]+', text)[0]) 242 | if verbose: 243 | print('Cube id: %d' % cube_id) 244 | return Placed(text, world, verbose, cube_id) 245 | 246 | def __init__(self, name, world, verbose, cube_id): 247 | self._world = world 248 | self._verbose = verbose 249 | self._cube_id = cube_id 250 | super(Placed, self).__init__(name) 251 | 252 | def update(self): 253 | if self._verbose: 254 | print("Checking PLACED") 255 | 256 | if (self._world.feedback[sm.Feedback.CUBE][self._cube_id] == self._world.poses.cube_goal_pose 257 | and not self._world.current[sm.State.HAS_CUBE]): 258 | return pt.common.Status.SUCCESS 259 | return pt.common.Status.FAILURE 260 | 261 | 262 | class Place(BehaviorNode): 263 | """ 264 | Placing behavior 265 | """ 266 | 267 | @staticmethod 268 | def make(text, world, verbose=False): 269 | return Place(text, world, verbose) 270 | 271 | def __init__(self, name, world, verbose): 272 | self._world = world 273 | self._verbose = verbose 274 | self._state = None 275 | super(Place, self).__init__(name) 276 | 277 | def initialise(self): 278 | if 
self._world.current[sm.State.HAS_CUBE]: 279 | self._state = None 280 | 281 | def update(self): 282 | if self._verbose: 283 | print("PLACE") 284 | 285 | if self._state is None: 286 | self._state = pt.common.Status.RUNNING 287 | self._world.manipulating = True 288 | elif self._state is pt.common.Status.RUNNING: 289 | if self._world.place(): 290 | self._state = pt.common.Status.SUCCESS 291 | else: 292 | self._state = pt.common.Status.FAILURE 293 | self._world.manipulating = False 294 | return self._state 295 | 296 | 297 | class Visited(BehaviorNode): 298 | """ 299 | Condition checking if robot has visited a pick table and attempted the picking 300 | """ 301 | 302 | @staticmethod 303 | def make(text, world, verbose=False): 304 | pose = 'xxx' 305 | return Visited(text, world, verbose, pose) 306 | 307 | def __init__(self, name, world, verbose, pose): 308 | self._world = world 309 | self._verbose = verbose 310 | self._pose = pose 311 | self._pose_idx = None 312 | super(Visited, self).__init__("{} visited?".format(self._pose)) 313 | 314 | def update(self): 315 | if self._pose == "pick_table0": 316 | self._pose_idx = 0 317 | elif self._pose == "pick_table1": 318 | self._pose_idx = 1 319 | elif self._pose == "pick_table2": 320 | self._pose_idx = 2 321 | 322 | if self._world.current[sm.State.VISITED][self.pose_idx]: 323 | return pt.common.Status.SUCCESS 324 | return pt.common.Status.FAILURE 325 | 326 | 327 | class MoveToPose(BehaviorNode): 328 | """ 329 | Move to pose behavior 330 | """ 331 | 332 | @staticmethod 333 | def make(text, world, verbose=False): 334 | 335 | if text == "move_to_pick [0]!": 336 | return MoveToPose(text, world, verbose, "pick_table_0") 337 | elif text == "move_to_pick [1]!": 338 | return MoveToPose(text, world, verbose, "pick_table_1") 339 | elif text == "move_to_pick [2]!": 340 | return MoveToPose(text, world, verbose, "pick_table_2") 341 | elif text == "move_to_place!": 342 | return MoveToPose(text, world, verbose, "place_table") 343 | else: 344 | 
raise ValueError("Unknown [%s] behavior node" % text) 345 | 346 | def __init__(self, name, world, verbose, pose): 347 | self._world = world 348 | self._verbose = verbose 349 | 350 | self._state = None 351 | self._pose = pose 352 | self._sm_pose = [] 353 | 354 | super(MoveToPose, self).__init__(name) 355 | 356 | def initialise(self): 357 | 358 | if self._pose == "pick_table_0": 359 | self._sm_pose = self._world.poses.pick_table0 360 | elif self._pose == "pick_table_1": 361 | self._sm_pose = self._world.poses.pick_table1 362 | elif self._pose == "pick_table_2": 363 | self._sm_pose = self._world.poses.pick_table2 364 | elif self._pose == "place_table": 365 | self._sm_pose = self._world.poses.place_table 366 | elif self._pose == "random_1": 367 | self._sm_pose = self._world.poses.random_pose1 368 | elif self._pose == "random_2": 369 | self._sm_pose = self._world.poses.random_pose2 370 | elif self._pose == "random_3": 371 | self._sm_pose = self._world.poses.random_pose3 372 | elif self._pose == "random_4": 373 | self._sm_pose = self._world.poses.random_pose4 374 | elif self._pose == "random_5": 375 | self._sm_pose = self._world.poses.random_pose5 376 | elif self._pose == "random_6": 377 | self._sm_pose = self._world.poses.random_pose6 378 | elif self._pose == "random_7": 379 | self._sm_pose = self._world.poses.random_pose7 380 | elif self._pose == "random_8": 381 | self._sm_pose = self._world.poses.random_pose8 382 | elif self._pose == "random_9": 383 | self._sm_pose = self._world.poses.random_pose9 384 | elif self._pose == "origin": 385 | self._sm_pose = self._world.poses.origin 386 | elif self._pose == "spawn": 387 | self._sm_pose = self._world.poses.spawn_pose 388 | 389 | if self._world.current[sm.State.POSE] != self._sm_pose: 390 | self._state = None 391 | 392 | def update(self): 393 | if self._verbose: 394 | print("MPiT") 395 | 396 | if self._state is None: 397 | self._state = pt.common.Status.RUNNING 398 | self._world.moving = True 399 | elif self._state is 
pt.common.Status.RUNNING: 400 | if self._world.move_to(self._sm_pose): 401 | self._state = pt.common.Status.SUCCESS 402 | else: 403 | self._state = pt.common.Status.FAILURE 404 | self._world.moving = False 405 | return self._state 406 | 407 | 408 | class MoveToPoseSafely(BehaviorNode): 409 | """ 410 | Move to palce pose behavior taking a slower but safer path 411 | """ 412 | 413 | @staticmethod 414 | def make(text, world, verbose=False): 415 | pose = 'xxx' 416 | return MoveToPoseSafely(text, world, verbose) 417 | 418 | def __init__(self, name, world, verbose, pose): 419 | 420 | self._world = world 421 | self._verbose = verbose 422 | 423 | self._state = None 424 | self._pose = pose 425 | self._sm_pose = [] 426 | 427 | super(MoveToPoseSafely, self).__init__("Safely to {}!".format(self._pose)) 428 | 429 | def initialise(self): 430 | if self._pose == "pick_table0": 431 | self._sm_pose = self._world.poses.pick_table0 432 | elif self._pose == "place_table": 433 | self._sm_pose = self._world.poses.place_table 434 | if self._world.current[sm.State.POSE] != self._sm_pose: 435 | self._state = None 436 | 437 | def update(self): 438 | 439 | if self._verbose: 440 | print("MPlT") 441 | 442 | if self._state is None: 443 | self._state = pt.common.Status.RUNNING 444 | self._world.moving = True 445 | elif self._state is pt.common.Status.RUNNING: 446 | if self._world.move_to(self.sm_pose, safe=True): 447 | self._state = pt.common.Status.SUCCESS 448 | else: 449 | self._state = pt.common.Status.FAILURE 450 | self._world.moving = False 451 | return self._state 452 | 453 | 454 | class MoveHeadUp(BehaviorNode): 455 | """ 456 | Move the head up behavior 457 | """ 458 | 459 | @staticmethod 460 | def make(text, world, verbose=False): 461 | return MoveHeadUp(text, world, verbose) 462 | 463 | def __init__(self, name, world, verbose): 464 | self._world = world 465 | self._verbose = verbose 466 | self._state = None 467 | super(MoveHeadUp, self).__init__(name) 468 | 469 | def initialise(self): 
470 | if not self._world.manipulating and self._world.current[sm.State.HEAD] != 'Up': 471 | self._state = None 472 | 473 | def update(self): 474 | if self._verbose: 475 | print("UP") 476 | 477 | if self._state is None: 478 | self._state = pt.common.Status.RUNNING 479 | elif self._state is pt.common.Status.RUNNING: 480 | if self._world.move_head_up(): 481 | self._state = pt.common.Status.SUCCESS 482 | else: 483 | self._state = pt.common.Status.FAILURE 484 | return self._state 485 | 486 | 487 | class MoveHeadDown(BehaviorNode): 488 | """ 489 | Move the head down behavior 490 | """ 491 | 492 | @staticmethod 493 | def make(text, world, verbose=False): 494 | return MoveHeadDown(text, world, verbose) 495 | 496 | def __init__(self, name, world, verbose): 497 | self._world = world 498 | self._verbose = verbose 499 | self._state = None 500 | super(MoveHeadDown, self).__init__(name) 501 | 502 | def initialise(self): 503 | if not self._world.moving and self._world.current[sm.State.HEAD] != 'Down': 504 | self._state = None 505 | 506 | def update(self): 507 | if self._verbose: 508 | print("DOWN") 509 | 510 | if self._state is None: 511 | self._state = pt.common.Status.RUNNING 512 | elif self._state is pt.common.Status.RUNNING: 513 | if self._world.move_head_down(): 514 | self._state = pt.common.Status.SUCCESS 515 | else: 516 | self._state = pt.common.Status.FAILURE 517 | return self._state 518 | 519 | 520 | class Finished(BehaviorNode): 521 | """ 522 | Condition checking if the task is finished 523 | """ 524 | 525 | @staticmethod 526 | def make(text, world, verbose=False): 527 | return Finished(text, world, verbose) 528 | 529 | def __init__(self, name, world, verbose): 530 | self._world = world 531 | self._verbose = verbose 532 | super(Finished, self).__init__(name) 533 | 534 | def update(self): 535 | if self._verbose: 536 | print("Checking PLACED") 537 | 538 | cube_dist = sum(self._world.feedback[sm.Feedback.CUBE_DISTANCE]) 539 | if cube_dist == 0.0: 540 | return 
pt.common.Status.SUCCESS 541 | return pt.common.Status.FAILURE 542 | 543 | 544 | def _make_scenario1_nodes(): 545 | 546 | behavior_register = BehaviorRegister() 547 | behavior_register.add_condition('have_block?', HaveBlock) 548 | behavior_register.add_condition('cube_placed [0]?', Placed) 549 | behavior_register.add_condition('task_done?', Finished) 550 | behavior_register.add_action('head [Up]!', MoveHeadUp) 551 | behavior_register.add_action('head [Down]!', MoveHeadDown) 552 | behavior_register.add_action('localise!', Localise) 553 | behavior_register.add_action('move_to_pick [0]!', MoveToPose) 554 | behavior_register.add_action('move_to_place!', MoveToPose) 555 | behavior_register.add_action('place!', Place) 556 | behavior_register.add_action('pick!', PickUp) 557 | behavior_register.add_action('arm [Tucked]!', MoveArm) 558 | 559 | return behavior_register 560 | 561 | 562 | def _make_scenario3_nodes(): 563 | 564 | behavior_register = BehaviorRegister() 565 | behavior_register.add_condition('have_block?', HaveBlock) 566 | behavior_register.add_condition('cube_placed [0]?', Placed) 567 | behavior_register.add_condition('cube_placed [1]?', Placed) 568 | behavior_register.add_condition('cube_placed [2]?', Placed) 569 | behavior_register.add_condition('task_done?', Finished) 570 | behavior_register.add_action('head [Up]!', MoveHeadUp) 571 | behavior_register.add_action('head [Down]!', MoveHeadDown) 572 | behavior_register.add_action('localise!', Localise) 573 | behavior_register.add_action('move_to_pick [0]!', MoveToPose) 574 | behavior_register.add_action('move_to_pick [1]!', MoveToPose) 575 | behavior_register.add_action('move_to_pick [2]!', MoveToPose) 576 | behavior_register.add_action('move_to_place!', MoveToPose) 577 | behavior_register.add_action('place!', Place) 578 | behavior_register.add_action('pick!', PickUp) 579 | behavior_register.add_action('arm [Tucked]!', MoveArm) 580 | 581 | return behavior_register 582 | 583 | 584 | def get_behaviors(name): 585 | 
586 | if name == 'scenario_1': 587 | return _make_scenario1_nodes() 588 | elif name == 'scenario_3': 589 | return _make_scenario3_nodes() 590 | else: 591 | raise ValueError('Unknown %s name', name) 592 | 593 | -------------------------------------------------------------------------------- /examples/tiago_pnp/fitness_function.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from tiago_pnp import world as sm 3 | 4 | 5 | @dataclass 6 | class Coefficients: 7 | 8 | # BT structure: 9 | depth: float = 0.0 10 | length: float = 0.5 11 | time: float = 0.1 12 | failure: float = 0.0 13 | # Task steps: 14 | task_completion: int = 300 15 | subtask: int = 100 16 | pick: int = 50 17 | # Gradient: 18 | cube_dist: int = 10 19 | localization: int = 1 20 | distance_robot_cube: int = 2 21 | robot_dist: int = 0 22 | min_distance_robot_cube: int = 0 23 | min_distance_cube_goal: int = 0 24 | 25 | 26 | class FitnessFunction: 27 | 28 | def compute_cost(self, world, behavior_tree, ticks, verbose=False) -> (float, bool): 29 | 30 | completed = False 31 | coefficients = Coefficients() 32 | 33 | depth = behavior_tree.depth 34 | length = behavior_tree.length 35 | 36 | cube_distance = sum(world.feedback[sm.Feedback.CUBE_DISTANCE]) 37 | min_cube_distance = sum(world.feedback[sm.Feedback.MIN_CUBE_DISTANCE]) 38 | robot_cube_distance = sum(world.feedback[sm.Feedback.ROBOT_CUBE_DISTANCE]) 39 | min_rc_distance = sum(world.feedback[sm.Feedback.MIN_RC_DISTANCE]) 40 | 41 | robot_distance = world.feedback[sm.Feedback.ROBOT_DISTANCE] 42 | loc_error = world.feedback[sm.Feedback.LOCALIZATION_ERROR] 43 | time = world.feedback[sm.Feedback.ELAPSED_TIME] 44 | failure_probability = world.feedback[sm.Feedback.FAILURE_PB] 45 | 46 | cost = float(coefficients.length*length + 47 | coefficients.depth*depth + 48 | coefficients.cube_dist*cube_distance**2 + 49 | coefficients.localization*loc_error**2 + 50 | 
coefficients.distance_robot_cube*robot_cube_distance**2 + 51 | coefficients.min_distance_cube_goal*min_cube_distance**2 + 52 | coefficients.min_distance_robot_cube*min_rc_distance**2 + 53 | coefficients.robot_dist*robot_distance**2 + 54 | coefficients.time*time) + coefficients.failure*failure_probability 55 | 56 | if cube_distance == 0.0: 57 | completed = True 58 | else: 59 | cost += coefficients.task_completion 60 | for i in range(world.cubes): 61 | if world.feedback[sm.Feedback.CUBE_DISTANCE][i] == 0.0: 62 | cost -= coefficients.subtask 63 | if world.current[sm.State.HAS_CUBE] and world.current[sm.State.CUBE_ID] == i: 64 | cost -= coefficients.pick 65 | 66 | if verbose: 67 | print("\n") 68 | print("Ticks: " + str(ticks)) 69 | print("Cube pose: " + str(world.feedback[sm.Feedback.CUBE])) 70 | print("Robot pose: " + str(world.feedback[sm.Feedback.AMCL])) 71 | print("State pose: " + str(world.current[sm.State.POSE])) 72 | print("\n") 73 | print("Cube distance from goal: " + str(cube_distance)) 74 | print("Contribution: " + str(coefficients.cube_dist*cube_distance**2)) 75 | print("Min cube distance: " + str(min_cube_distance)) 76 | print("Contribution: " + str(coefficients.min_distance_cube_goal*min_cube_distance**2)) 77 | print("Robot distance from cube: " + str(robot_cube_distance)) 78 | print("Contribution: " + str(coefficients.distance_robot_cube*robot_cube_distance**2)) 79 | print("Min robot distance: " + str(min_rc_distance)) 80 | print("Contribution: " + str(coefficients.min_distance_robot_cube*min_rc_distance**2)) 81 | print("Robot distance from goal: " + str(robot_distance)) 82 | print("Contribution: " + str(coefficients.robot_dist*robot_distance**2)) 83 | print("Localisation Error: " + str(loc_error)) 84 | print("Contribution: " + str(coefficients.localization*loc_error**2)) 85 | print("Behavior Tree: L " + str(length) + ", D " + str(depth)) 86 | print("Contribution: " + str(coefficients.length*length + coefficients.depth*depth)) 87 | print("Elapsed Time: " 
+ str(time)) 88 | print("Contribution: " + str(coefficients.time*time)) 89 | print("Failure Probability: " + str(failure_probability)) 90 | print("Contribution: " + str(coefficients.failure*failure_probability)) 91 | print("Total Cost: " + str(cost)) 92 | print("\n") 93 | 94 | return cost, completed 95 | -------------------------------------------------------------------------------- /examples/tiago_pnp/paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 5 | _this_file_path = os.path.abspath(__file__) 6 | _PACKAGE_DIRECTORY = os.path.dirname(os.path.dirname(os.path.dirname(_this_file_path))) 7 | _EXAMPLES_DIRECTORY = os.path.dirname(os.path.dirname(_this_file_path)) 8 | _CURRENT_EXAMPLE_DIRECTORY = os.path.dirname(_this_file_path) 9 | 10 | 11 | def add_modules_to_path(): 12 | sys.path.append(os.path.normpath(_PACKAGE_DIRECTORY)) 13 | sys.path.append(os.path.normpath(_EXAMPLES_DIRECTORY)) 14 | 15 | 16 | def get_example_directory(): 17 | return _CURRENT_EXAMPLE_DIRECTORY 18 | 19 | 20 | def get_log_directory(): 21 | return os.path.join(_CURRENT_EXAMPLE_DIRECTORY, 'logs') 22 | -------------------------------------------------------------------------------- /examples/tiago_pnp/run_execute_bt.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import paths 4 | paths.add_modules_to_path() 5 | 6 | import os 7 | from behavior_tree_learning.sbt import BehaviorNodeFactory, BehaviorTreeExecutor, ExecutionParameters 8 | from tiago_pnp.paths import get_log_directory 9 | from tiago_pnp import bt_collection 10 | from tiago_pnp.execution_nodes import get_behaviors 11 | from tiago_pnp.world import ApplicationWorld 12 | 13 | 14 | def run(): 15 | 16 | scenario = 'scenario_1' 17 | deterministic = False 18 | 19 | node_factory_1 = BehaviorNodeFactory(get_behaviors(scenario)) 20 | sbt_1 = bt_collection.select_bt(scenario) 21 | trials = [(sbt_1, 
node_factory_1)] 22 | 23 | for tdx, trial in zip(range(0, len(trials)), trials): 24 | 25 | print("Trial: %d" % tdx) 26 | 27 | sbt = list(trial[0]) 28 | node_factory = trial[1] 29 | print("SBT: ", sbt) 30 | 31 | simulated_world = ApplicationWorld(scenario, deterministic) 32 | bt_executor = BehaviorTreeExecutor(node_factory, simulated_world) 33 | 34 | success, ticks, tree = bt_executor.run(sbt, ExecutionParameters(successes_required=1), 35 | verbose=True) 36 | 37 | try: 38 | os.mkdir(get_log_directory()) 39 | except OSError: 40 | pass 41 | 42 | file_name = 'trial_%d' % (tdx + 1) 43 | tree.save_figure(get_log_directory(), name=file_name) 44 | print("Succeed: ", success) 45 | 46 | 47 | if __name__ == "__main__": 48 | run() 49 | -------------------------------------------------------------------------------- /examples/tiago_pnp/run_learn_bt.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import paths 4 | paths.add_modules_to_path() 5 | 6 | import os 7 | import logging 8 | 9 | from behavior_tree_learning.sbt import BehaviorNodeFactory 10 | from behavior_tree_learning.learning import BehaviorTreeLearner, GeneticParameters, GeneticSelectionMethods 11 | from behavior_tree_learning.learning import TraceConfiguration 12 | 13 | from tiago_pnp.execution_nodes import get_behaviors 14 | from tiago_pnp.world import ApplicationWorld, ApplicationWorldFactory 15 | from tiago_pnp.environment import ApplicationEnvironment 16 | 17 | 18 | def _configure_logger(level, directory_path, name): 19 | 20 | for handler in logging.root.handlers[:]: 21 | logging.root.removeHandler(handler) 22 | 23 | try: 24 | file_path = os.path.join(directory_path, name + '.log') 25 | os.mkdir(directory_path) 26 | except: 27 | pass 28 | 29 | logging.basicConfig(filename=file_path, 30 | format='%(filename)s: %(message)s') 31 | logging.getLogger("gp").setLevel(level) 32 | 33 | 34 | def run(): 35 | 36 | scenario = 'scenario_1' 37 | parameters = 
GeneticParameters() 38 | 39 | parameters.n_generations = 8000 40 | parameters.fitness_threshold = -16. 41 | 42 | parameters.n_population = 30 43 | parameters.ind_start_length = 4 44 | parameters.f_crossover = 0.4 45 | parameters.f_mutation = 0.6 46 | parameters.n_offspring_crossover = 2 47 | parameters.n_offspring_mutation = 4 48 | parameters.parent_selection = GeneticSelectionMethods.TOURNAMENT 49 | parameters.survivor_selection = GeneticSelectionMethods.TOURNAMENT 50 | parameters.f_elites = 0.1 51 | parameters.f_parents = 1 52 | parameters.mutation_p_add = 0.5 53 | parameters.mutation_p_delete = 0.2 54 | parameters.rerun_fitness = 0 55 | parameters.allow_identical = False 56 | 57 | tracer = TraceConfiguration() 58 | tracer.plot_fitness = True 59 | tracer.plot_best_individual = True 60 | tracer.plot_last_generation = True 61 | 62 | num_trials = 10 63 | for tdx in range(1, num_trials+1): 64 | 65 | log_name = scenario + '_' + str(tdx) 66 | _configure_logger(logging.DEBUG, paths.get_log_directory(), log_name) 67 | 68 | parameters.log_name = log_name 69 | seed = tdx*100 70 | 71 | node_factory = BehaviorNodeFactory(get_behaviors(scenario)) 72 | world_factory = ApplicationWorldFactory(scenario, deterministic=True) 73 | environment = ApplicationEnvironment(node_factory, world_factory, scenario, verbose=False) 74 | 75 | bt_learner = BehaviorTreeLearner.from_environment(environment) 76 | success = bt_learner.run(parameters, seed, 77 | trace_conf=tracer, 78 | verbose=False) 79 | 80 | print("Trial: %d, Succeed: %s" % (tdx, success)) 81 | 82 | 83 | if __name__ == "__main__": 84 | run() 85 | -------------------------------------------------------------------------------- /package.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | behavior_tree_learning 5 | 6 | 0.0.0 7 | TODO: Package description 8 | Diego Escudero 9 | TODO: License declaration 10 | 11 | catkin 12 | rospy 13 | rospy 14 | 15 | rosunit 16 | 17 | 18 | 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | py_trees 3 | numpy 4 | dataclasses 5 | scipy 6 | PyYAML 7 | python-interface 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | ## ! DO NOT MANUALLY INVOKE THIS setup.py, USE CATKIN INSTEAD 2 | 3 | from distutils.core import setup 4 | from catkin_pkg.python_setup import generate_distutils_setup 5 | 6 | setup_args = generate_distutils_setup( 7 | packages=["behavior_tree_learning"], 8 | package_dir={"": "src"}) 9 | 10 | setup(**setup_args) 11 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dgerod/behavior_tree_learning/71da80c91ecd48fd5da377f83604b62112ba9629/src/behavior_tree_learning/__init__.py -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/gp/__init__.py: -------------------------------------------------------------------------------- 1 | from behavior_tree_learning.core.gp.algorithm import GeneticProgramming 2 | from behavior_tree_learning.core.gp.environment import GeneticEnvironment, make_steps 3 | from behavior_tree_learning.core.gp.operators import GeneticOperators 4 | from behavior_tree_learning.core.gp.parameters import GeneticParameters, TraceConfiguration 5 | from behavior_tree_learning.core.gp.selection import SelectionMethods as GeneticSelectionMethods 6 | from 
behavior_tree_learning.core.gp.steps import AlgorithmSteps 7 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/gp/environment.py: -------------------------------------------------------------------------------- 1 | from interface import Interface, implements 2 | from behavior_tree_learning.core.gp.steps import AlgorithmSteps 3 | 4 | 5 | class GeneticEnvironment(Interface): 6 | 7 | def run_and_compute(self, individual, verbose): 8 | """ 9 | Run the simulation and return the fitness 10 | 11 | Parameters: 12 | individual 13 | verbose (bool) 14 | Returns: 15 | fitness (float) 16 | """ 17 | pass 18 | 19 | def plot_individual(self, path, plot_name, individual): 20 | """ 21 | Saves a graphical representation of the individual 22 | 23 | Parameters: 24 | path (str) : where to store the figure 25 | plot_name (str) : name of the figure 26 | individual 27 | Returns: 28 | None 29 | """ 30 | pass 31 | 32 | 33 | def make_steps(environment: GeneticEnvironment): 34 | 35 | class StepsForEnvironment(implements(AlgorithmSteps)): 36 | def __init__(self, environment_): 37 | self._environment = environment_ 38 | 39 | def calculate_fitness(self, individual, verbose): 40 | return self._environment.run_and_compute(individual, verbose) 41 | 42 | def plot_individual(self, path, plot_name, individual): 43 | self._environment.plot_individual(path, plot_name, individual) 44 | 45 | return StepsForEnvironment(environment) 46 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/gp/hash_table.py: -------------------------------------------------------------------------------- 1 | """ 2 | Hash table with linked list for entries with same hash 3 | """ 4 | 5 | import os 6 | import pathlib 7 | import hashlib 8 | import ast 9 | 10 | 11 | class _Node: 12 | """ 13 | Node data structure - essentially a LinkedList node 14 | """ 15 | 16 | def __init__(self, key, value): 17 | 
self.key = key 18 | self.value = [value] 19 | self.next = None 20 | 21 | def __eq__(self, other): 22 | if not isinstance(other, _Node): 23 | return False 24 | equal = self.key == other.key and self.value == other.value 25 | if equal: 26 | if self.next is not None or other.next is not None: 27 | if self.next is None or other.next is None: 28 | equal = False 29 | else: 30 | equal = self.next == other.next 31 | return equal 32 | 33 | 34 | class HashTable: 35 | 36 | _FILE_NAME = 'hash_log.txt' 37 | _DEFAULT_DIRECTORY_NAME = 'logs' 38 | 39 | def __init__(self, size=100000, path: str = ''): 40 | """_ 41 | Initialize hash table to fixed size 42 | """ 43 | 44 | self._size = size 45 | self._directory_name = self._DEFAULT_DIRECTORY_NAME if path == '' else path 46 | self._buckets = [None]*self._size 47 | self._num_values = 0 48 | 49 | def __eq__(self, other): 50 | 51 | if not isinstance(other, HashTable): 52 | return False 53 | 54 | equal = True 55 | for i in range(self._size): 56 | if self._buckets[i] != other._buckets[i]: 57 | equal = False 58 | break 59 | return equal 60 | 61 | def num_values(self): 62 | return self._num_values 63 | 64 | def hash(self, key: str): 65 | """ 66 | Generate a hash for a given key 67 | Input: string key 68 | Output: hash 69 | """ 70 | 71 | string = ''.join(key) 72 | new_hash = hashlib.md5() 73 | new_hash.update(string.encode('utf-8')) 74 | hashcode = new_hash.hexdigest() 75 | hashcode = int(hashcode, 16) 76 | return hashcode % self._size 77 | 78 | def insert(self, key: str, value): 79 | """ 80 | Insert a key - value pair to the hashtable 81 | Input: key - string 82 | value - anything 83 | """ 84 | 85 | index = self.hash(key) 86 | node = self._buckets[index] 87 | if node is None: 88 | self._buckets[index] = _Node(key, value) 89 | else: 90 | done = False 91 | while not done: 92 | if node.key == key: 93 | node.value.append(value) 94 | done = True 95 | elif node.next is None: 96 | node.next = _Node(key, value) 97 | done = True 98 | else: 99 | node = 
node.next 100 | 101 | self._num_values += 1 102 | 103 | def find(self, key: str): 104 | """ 105 | Find a data value based on key 106 | Input: key - string 107 | Output: value stored under "key" or None if not found 108 | """ 109 | 110 | index = self.hash(key) 111 | node = self._buckets[index] 112 | while node is not None and node.key != key: 113 | node = node.next 114 | 115 | if node is None: 116 | return None 117 | return node.value 118 | 119 | def load(self): 120 | """ 121 | Loads hash table information. 122 | """ 123 | 124 | self._create_directory(self._directory_name) 125 | 126 | with open(os.path.join(self._directory_name, self._FILE_NAME), 'r') as f: 127 | lines = f.read().splitlines() 128 | 129 | for i in range(0, len(lines)): 130 | individual = lines[i] 131 | individual = individual[5:].split(', value: ') 132 | key = ast.literal_eval(individual[0]) 133 | individual = individual[1].split(', count: ') 134 | values = individual[0][1:-1].split(', ') # Remove brackets and split multiples 135 | for value in values: 136 | self.insert(key, float(value)) 137 | 138 | def write(self): 139 | """ 140 | Writes table contents to a file 141 | """ 142 | 143 | self._create_directory(self._directory_name) 144 | 145 | with open(os.path.join(self._directory_name, self._FILE_NAME), 'w') as f: 146 | for node in filter(lambda x: x is not None, self._buckets): 147 | while node is not None: 148 | f.writelines('key: ' + str(node.key) + 149 | ', value: ' + str(node.value) + 150 | ', count: ' + str(len(node.value)) + '\n') 151 | node = node.next 152 | f.close() 153 | 154 | @staticmethod 155 | def _create_directory(directory_path): 156 | 157 | pathlib.Path(directory_path).mkdir(parents=True, exist_ok=True) 158 | 159 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/gp/operators.py: -------------------------------------------------------------------------------- 1 | from interface import Interface 2 | 3 | 4 | class 
GeneticOperators(Interface): 5 | 6 | def random_genome(self, length): 7 | """ 8 | Generate one genome 9 | 10 | Parameters: 11 | length (int) : length of the genome 12 | Returns: 13 | genome 14 | """ 15 | pass 16 | 17 | def mutate_gene(self, genome, p_add, p_delete): 18 | """ 19 | Mutate only a single gene. 20 | 21 | Parameters: 22 | genome 23 | p_add (int) : mutate parameter 24 | p_delete (int) : mutate parameter 25 | Returns: 26 | genome 27 | """ 28 | pass 29 | 30 | def crossover_genome(self, genome1, genome2, replace): 31 | """ 32 | Do crossover between genomes at random points 33 | 34 | Parameters: 35 | genome1 36 | genome2 37 | replace (bool) 38 | Returns: 39 | genome1 40 | genome2 41 | """ 42 | pass 43 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/gp/parameters.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from behavior_tree_learning.core.gp.selection import SelectionMethods 3 | 4 | 5 | @dataclass 6 | class GeneticParameters: 7 | 8 | ind_start_length: int = 5 # Start length of initial genomes 9 | min_length: int = 2 # Minimum length of individual 10 | n_population: int = 8 # Number of individuals in population 11 | f_crossover: float = 0.5 # Fraction of parent pool selected for crossover 12 | n_offspring_crossover: int = 1 # Number of offspring from crossover per parent 13 | replace_crossover: bool = False # Crossover replaces subtree at receiving genome or inserts 14 | f_mutation: float = 0.5 # Fraction of parent pool selected for mutation 15 | n_offspring_mutation: int = 1 # Number of offspring from mutation per parent 16 | parent_selection: int = SelectionMethods.TOURNAMENT # Selection method for parents 17 | survivor_selection: int = SelectionMethods.TOURNAMENT # Selection method for survival 18 | f_elites: float = 0.1 # Fraction of population that survive as elites 19 | f_parents: float = 1 # Fraction of 
parents that may survive to next generation 20 | mutate_co_offspring: bool = False # Offspring from crossover may also be mutated 21 | mutate_co_parents: bool = False # Parents for crossover may also be mutated 22 | mutation_p_add: float = 0.4 # Probability of mutation adding a gene 23 | mutation_p_delete: float = 0.3 # Probability of mutation deleting a gene 24 | allow_identical: bool = False # Offspring may be identical to any parent in prev generation 25 | keep_baseline: bool = True # Baseline, if any, is always kept in population for breeding 26 | boost_baseline: bool = False # Baseline is boosted to have higher probability of breeding 27 | boost_baseline_only_co: bool = True # Baseline is boosted for crossover selection, not mutation 28 | n_generations: int = 100 # Maximum number of generations 29 | fitness_threshold: float = 0.0 # Finish when best fitness is over this threshold 30 | hash_table_size: int = 100000 # Size of hash table 31 | rerun_fitness: int = 0 # 0-run only once, 1-according to prob, 2-always 32 | log_name: str = '1' # Name of log for folder and file handling 33 | 34 | 35 | @dataclass 36 | class TraceConfiguration: 37 | 38 | plot_fitness: bool = False # Save a plot with all fitness as figure 39 | plot_best_individual: bool = False # Save final best individual as figure 40 | plot_last_generation: bool = False # Save figures of entire last generation 41 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/gp/selection.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, auto 2 | import random 3 | import numpy as np 4 | 5 | 6 | def _elite_selection(population, fitness, n_elites, verbose): 7 | """ 8 | Elite selection from population 9 | """ 10 | 11 | sorted_population = sorted(zip(fitness, population), reverse=True) 12 | selected = [x for _, x in sorted_population[:n_elites]] 13 | 14 | if verbose: 15 | print('Elite selection - 
population: %d, num selected: %d' % (len(population), n_elites)) 16 | print('Sorted: population:') 17 | for fitness, genome in sorted_population: 18 | print('fitness: %f, genome: %s' % (fitness, genome)) 19 | 20 | return selected 21 | 22 | 23 | def _tournament_selection(population, fitness, n_winners, verbose): 24 | """ 25 | Tournament selection. 26 | """ 27 | 28 | tournament_size = n_winners 29 | while tournament_size < len(population): 30 | tournament_size *= 2 31 | 32 | tournament_population = list(zip(fitness, population)) 33 | random.shuffle(tournament_population) 34 | 35 | for i in range(tournament_size - len(population)): 36 | # Add dummies to make sure we have a full tournament 37 | tournament_population.insert(i * 2, (-float("inf"), [])) 38 | 39 | winner_fitness, winners = [list(x) for x in zip(*tournament_population)] 40 | while len(winners) > n_winners: 41 | for i in range(0, int(len(winners) / 2)): 42 | 43 | if winner_fitness[i] < winner_fitness[i+1]: 44 | winner_fitness.pop(i) 45 | winners.pop(i) 46 | else: 47 | winner_fitness.pop(i + 1) 48 | winners.pop(i + 1) 49 | 50 | return winners 51 | 52 | 53 | def _rank_selection(population, fitness, n_selected, verbose): 54 | """ 55 | Rank proportional selection 56 | Probabilities for each individual are scaled linearly according to rank 57 | such that the highest ranked individual get n_ranks as weight 58 | and the lowest ranked individual gets 1. The weights are then scaled so 59 | that they sum to 1. 
60 | """ 61 | 62 | sorted_population = sorted(zip(fitness, population), reverse=True) 63 | _, sorted_indices = [list(x) for x in zip(*sorted_population)] 64 | n_ranks = len(sorted_indices) 65 | p = np.linspace(2 / (n_ranks + 1), 2 / (n_ranks * (n_ranks + 1)), n_ranks) 66 | return list(np.random.choice(sorted_indices, size=n_selected, replace=False, p=p)) 67 | 68 | 69 | def _random_selection(population, n_selected, verbose): 70 | return random.sample(population, n_selected) 71 | 72 | 73 | class SelectionMethods(Enum): 74 | """ 75 | Enum class for selection methods 76 | """ 77 | 78 | ELITISM = auto() 79 | TOURNAMENT = auto() 80 | RANK = auto() 81 | RANDOM = auto() 82 | ALL = auto() 83 | 84 | 85 | def selection(selection_method, population, fitness, n_selected, verbose=False): 86 | """ 87 | Select individuals from population 88 | """ 89 | 90 | if selection_method == SelectionMethods.ELITISM: 91 | selected = _elite_selection(population, fitness, n_selected, verbose) 92 | elif selection_method == SelectionMethods.TOURNAMENT: 93 | selected = _tournament_selection(population, fitness, n_selected, verbose) 94 | elif selection_method == SelectionMethods.RANK: 95 | selected = _rank_selection(population, fitness, n_selected, verbose) 96 | elif selection_method == SelectionMethods.RANDOM: 97 | selected = _random_selection(population, n_selected, verbose) 98 | elif selection_method == SelectionMethods.ALL: 99 | selected = population 100 | else: 101 | raise Exception('Invalid selection method') 102 | 103 | return selected 104 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/gp/steps.py: -------------------------------------------------------------------------------- 1 | import interface 2 | from interface import Interface, implements 3 | 4 | 5 | class AlgorithmSteps(Interface): 6 | 7 | @interface.default 8 | def execution_started(self): 9 | pass 10 | 11 | @interface.default 12 | def execute_generation(self, 
generation): 13 | pass 14 | 15 | @interface.default 16 | def current_population(self, population): 17 | pass 18 | 19 | @interface.default 20 | def crossover_population(self, population): 21 | pass 22 | 23 | @interface.default 24 | def mutated_population(self, population): 25 | pass 26 | 27 | @interface.default 28 | def survided_population(self, population): 29 | pass 30 | 31 | def calculate_fitness(self, individual, verbose): 32 | pass 33 | 34 | @interface.default 35 | def more_generations(self, generation, last_generation, fitness_achieved): 36 | pass 37 | 38 | @interface.default 39 | def plot_individual(self, path, plot_name, individual): 40 | pass 41 | 42 | @interface.default 43 | def execution_completed(self): 44 | pass 45 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/gp_sbt/__init__.py: -------------------------------------------------------------------------------- 1 | from behavior_tree_learning.core.gp_sbt.environment import Environment, EnvironmentWithFitnessFunction 2 | from behavior_tree_learning.core.gp_sbt.world_factory import WorldFactory 3 | from behavior_tree_learning.core.gp_sbt.gp_operators import Operators 4 | from behavior_tree_learning.core.gp_sbt.learning import BehaviorTreeLearner 5 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/gp_sbt/environment.py: -------------------------------------------------------------------------------- 1 | from interface import Interface, implements 2 | from behavior_tree_learning.core.sbt import BehaviorNodeFactory, StringBehaviorTree, ExecutionParameters 3 | from behavior_tree_learning.core.gp_sbt.world_factory import WorldFactory 4 | from behavior_tree_learning.core.gp_sbt.fitness_function import FitnessFunction 5 | 6 | 7 | class Environment(Interface): 8 | 9 | def run_and_compute(self, individual, verbose): 10 | """ 11 | Run the simulation and return the fitness 12 | 13 | 
Parameters: 14 | individual 15 | verbose (bool) 16 | Returns: 17 | fitness (float) 18 | """ 19 | pass 20 | 21 | def plot_individual(self, path, plot_name, individual): 22 | """ 23 | Saves a graphical representation of the individual 24 | 25 | Parameters: 26 | path (str) : where to store the figure 27 | plot_name (str) : name of the figure 28 | individual 29 | Returns: 30 | None 31 | """ 32 | pass 33 | 34 | 35 | class EnvironmentWithFitnessFunction(implements(Environment)): 36 | 37 | def __init__(self, 38 | node_factory: BehaviorNodeFactory, 39 | world_factory: WorldFactory, 40 | fitness_function: FitnessFunction, 41 | verbose=False): 42 | 43 | self._node_factory = node_factory 44 | self._world_factory = world_factory 45 | self._fitness_function = fitness_function 46 | self._verbose = verbose 47 | 48 | def run_and_compute(self, individual, verbose): 49 | 50 | verbose_enabled = self._verbose or verbose 51 | 52 | sbt = list(individual) 53 | if verbose_enabled: 54 | print("SBT: ", sbt) 55 | 56 | world = self._world_factory.make() 57 | 58 | tree = StringBehaviorTree(sbt, behaviors=self._node_factory, world=world, verbose=verbose) 59 | success, ticks = tree.run_bt(parameters=ExecutionParameters(successes_required=1)) 60 | 61 | fitness = FitnessFunction().compute_cost(world, tree, ticks, self._targets, 62 | self._fitness_coefficients, verbose=verbose) 63 | 64 | if verbose_enabled: 65 | print("fitness: ", fitness) 66 | 67 | return fitness 68 | 69 | def plot_individual(self, path, plot_name, individual): 70 | 71 | sbt = list(individual) 72 | tree = StringBehaviorTree(sbt[:], behaviors=self._node_factory) 73 | tree.save_figure(path, name=plot_name) 74 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/gp_sbt/fitness_function.py: -------------------------------------------------------------------------------- 1 | from interface import Interface 2 | 3 | 4 | class FitnessFunction(Interface): 5 | 6 | def 
compute_cost(self, world, behavior_tree, ticks, verbose): 7 | """ 8 | Retrieve values and compute cost 9 | 10 | Parameters: 11 | world (World) 12 | behavior_tree (StringBehaviorTree) 13 | ticks (int) 14 | verbose (bool) 15 | Returns: 16 | cost (float) 17 | """ 18 | pass 19 | 20 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/gp_sbt/gp_operators.py: -------------------------------------------------------------------------------- 1 | import random 2 | from interface import implements 3 | from behavior_tree_learning.core.gp import GeneticOperators 4 | from behavior_tree_learning.core.sbt import BehaviorTreeStringRepresentation 5 | 6 | 7 | class Operators(implements(GeneticOperators)): 8 | 9 | def random_genome(self, length): 10 | """ 11 | Returns a random genome 12 | """ 13 | 14 | bt = BehaviorTreeStringRepresentation([]) 15 | return bt.random(length) 16 | 17 | def mutate_gene(self, genome, p_add, p_delete): 18 | """ 19 | Mutate only a single gene. 
20 | """ 21 | 22 | if p_add < 0 or p_delete < 0: 23 | raise Exception("Mutation parameters must not be negative.") 24 | 25 | if p_add + p_delete > 1: 26 | raise Exception("Sum of the mutation probabilities must be less than 1.") 27 | 28 | mutated_individual = BehaviorTreeStringRepresentation([]) 29 | max_attempts = 100 30 | attempts = 0 31 | while (not mutated_individual.is_valid() or mutated_individual.bt == genome) and attempts < max_attempts: 32 | mutated_individual.set(genome) 33 | index = random.randint(0, len(genome) - 1) 34 | mutation = random.random() 35 | 36 | if mutation < p_delete: 37 | mutated_individual.delete_node(index) 38 | elif mutation < p_delete + p_add: 39 | mutated_individual.add_node(index) 40 | else: 41 | mutated_individual.change_node(index) 42 | 43 | mutated_individual.close() 44 | mutated_individual.trim() 45 | attempts += 1 46 | 47 | if attempts >= max_attempts and (not mutated_individual.is_valid() or mutated_individual.bt == genome): 48 | mutated_individual = BehaviorTreeStringRepresentation([]) 49 | 50 | return mutated_individual.bt 51 | 52 | def crossover_genome(self, genome1, genome2, replace): 53 | """ 54 | Do crossover between genomes at random points 55 | """ 56 | 57 | bt1 = BehaviorTreeStringRepresentation(genome1) 58 | bt2 = BehaviorTreeStringRepresentation(genome2) 59 | offspring1 = BehaviorTreeStringRepresentation([]) 60 | offspring2 = BehaviorTreeStringRepresentation([]) 61 | 62 | if bt1.is_valid() and bt2.is_valid(): 63 | max_attempts = 100 64 | attempts = 0 65 | found = False 66 | while not found and attempts < max_attempts: 67 | offspring1.set(bt1.bt) 68 | offspring2.set(bt2.bt) 69 | cop1 = -1 70 | cop2 = -1 71 | if len(genome1) == 1: 72 | cop1 = 0 # Change whole tree 73 | else: 74 | while not offspring1.is_subtree(cop1): 75 | cop1 = random.randint(1, len(genome1) - 1) 76 | if len(genome2) == 1: 77 | cop2 = 0 # Change whole tree 78 | else: 79 | while not offspring2.is_subtree(cop2): 80 | cop2 = random.randint(1, 
len(genome2) - 1) 81 | 82 | if replace: 83 | offspring1.swap_subtrees(offspring2, cop1, cop2) 84 | else: 85 | subtree1 = offspring1.get_subtree(cop1) 86 | subtree2 = offspring2.get_subtree(cop2) 87 | if len(genome1) == 1: 88 | index1 = random.randint(0, 1) 89 | else: 90 | index1 = random.randint(1, len(genome1) - 1) 91 | if len(genome2) == 1: 92 | index2 = random.randint(0, 1) 93 | else: 94 | index2 = random.randint(1, len(genome2) - 1) 95 | offspring1.insert_subtree(subtree2, index1) 96 | offspring2.insert_subtree(subtree1, index2) 97 | 98 | attempts += 1 99 | if offspring1.is_valid() and offspring2.is_valid(): 100 | found = True 101 | if not found: 102 | offspring1.set([]) 103 | offspring2.set([]) 104 | 105 | return offspring1.bt, offspring2.bt 106 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/gp_sbt/learning.py: -------------------------------------------------------------------------------- 1 | from interface import implements 2 | from behavior_tree_learning.core.gp import GeneticEnvironment, make_steps 3 | from behavior_tree_learning.core.gp import GeneticParameters, TraceConfiguration 4 | from behavior_tree_learning.core.gp import AlgorithmSteps 5 | from behavior_tree_learning.core.gp import GeneticProgramming 6 | from behavior_tree_learning.core.gp_sbt.environment \ 7 | import Environment, EnvironmentWithFitnessFunction 8 | from behavior_tree_learning.core.gp_sbt.gp_operators \ 9 | import Operators as GeneticOperatorsForSBT 10 | 11 | 12 | class BehaviorTreeLearner: 13 | 14 | @staticmethod 15 | def from_environment(environment: Environment): 16 | 17 | bt = BehaviorTreeLearner() 18 | bt._gp_operators = GeneticOperatorsForSBT() 19 | bt._steps = make_steps(_EnvironmentAdapter(environment)) 20 | return bt 21 | 22 | @staticmethod 23 | def from_steps(steps: AlgorithmSteps): 24 | 25 | bt = BehaviorTreeLearner() 26 | bt._gp_operators = GeneticOperatorsForSBT() 27 | bt._steps = steps 28 | return bt 
29 | 30 | def __init__(self): 31 | self._gp_operators = None 32 | self._steps = None 33 | 34 | def run(self, parameters: GeneticParameters, seed=None, hot_start=False, base_line=None, verbose=False, 35 | outputs_dir_path="", trace_conf=TraceConfiguration()): 36 | 37 | if not self._gp_operators or not self._steps: 38 | raise RuntimeError("Object not created correctly, a factory method should be used") 39 | 40 | gp = GeneticProgramming(self._gp_operators, outputs_dir_path) 41 | gp.run(self._steps, parameters, seed, hot_start, base_line, trace_conf=trace_conf, verbose=verbose) 42 | 43 | return True 44 | 45 | 46 | class _EnvironmentAdapter(implements(GeneticEnvironment)): 47 | 48 | def __init__(self, environment_): 49 | self._environment = environment_ 50 | 51 | def run_and_compute(self, individual, verbose): 52 | return self._environment.run_and_compute(individual, verbose) 53 | 54 | def plot_individual(self, path, plot_name, individual): 55 | self._environment.plot_individual(path, plot_name, individual) 56 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/gp_sbt/world_factory.py: -------------------------------------------------------------------------------- 1 | from interface import Interface 2 | 3 | 4 | class WorldFactory(Interface): 5 | 6 | def make(self): 7 | pass 8 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/logger/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dgerod/behavior_tree_learning/71da80c91ecd48fd5da377f83604b62112ba9629/src/behavior_tree_learning/core/logger/__init__.py -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/logger/logplot.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=too-many-instance-attributes 2 | 
""" 3 | Handling of logs and plots for learning 4 | """ 5 | import os 6 | import shutil 7 | import pickle 8 | from dataclasses import dataclass 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | import matplotlib 12 | from scipy import interpolate 13 | 14 | 15 | matplotlib.rcParams['pdf.fonttype'] = 42 16 | _DIRECTORY_PATH = "logs" 17 | 18 | 19 | def open_file(path, mode): 20 | """ 21 | Attempts to open file at path. 22 | Tried up to max_attempts times because of intermittent permission errors on Windows 23 | """ 24 | 25 | max_attempts = 100 26 | f = None 27 | for _ in range(max_attempts): 28 | try: 29 | f = open(path, mode) 30 | except PermissionError: 31 | continue 32 | break 33 | return f 34 | 35 | 36 | def make_directory(path): 37 | """ 38 | Attempts to create directory at path. 39 | Tried up to max_attempts times because of intermittent permission errors on Windows 40 | """ 41 | 42 | max_attempts = 100 43 | for _ in range(max_attempts): 44 | try: 45 | os.mkdir(path) 46 | except PermissionError: 47 | continue 48 | break 49 | 50 | 51 | def configure_log(directory_path=""): 52 | 53 | if directory_path != "": 54 | global _DIRECTORY_PATH 55 | _DIRECTORY_PATH = directory_path 56 | 57 | 58 | def get_log_folder(log_name): 59 | 60 | return _get_log_folder(log_name) 61 | 62 | 63 | def trim_logs(logs): 64 | """ Trims a list of logs so that all logs have the same number of entries/generations """ 65 | 66 | min_rowlength = 9999999 67 | for row in logs: 68 | rowlen = len(row) 69 | if rowlen < min_rowlength: 70 | min_rowlength = rowlen 71 | 72 | for row in logs: 73 | del row[min_rowlength:] 74 | 75 | 76 | def clear_logs(log_name): 77 | """ Clears previous log folders of same same """ 78 | 79 | log_folder = _get_log_folder(log_name) 80 | try: 81 | shutil.rmtree(log_folder) 82 | except FileNotFoundError: # pragma: no cover 83 | pass 84 | 85 | make_directory(log_folder) 86 | fitness_log_path = log_folder + '/fitness_log.txt' 87 | population_log_path = log_folder + 
'/population_log.txt' 88 | open(fitness_log_path, "x") 89 | open(population_log_path, "x") 90 | 91 | 92 | def clear_after_generation(log_name, generation): 93 | """ Clears fitness and population logs after given generation """ 94 | 95 | with open_file(_get_log_folder(log_name) + '/fitness_log.txt', 'r') as f: 96 | lines = f.readlines() 97 | with open_file(_get_log_folder(log_name) + '/fitness_log.txt', 'w') as f: 98 | for i in range(generation + 1): 99 | f.write(lines[i]) 100 | with open_file(_get_log_folder(log_name) + '/population_log.txt', 'r') as f: 101 | lines = f.readlines() 102 | with open_file(_get_log_folder(log_name) + '/population_log.txt', 'w') as f: 103 | for i in range(generation + 1): 104 | f.write(lines[i]) 105 | 106 | 107 | def log_best_individual(log_name, best_individual): 108 | """ Saves the best individual """ 109 | 110 | with open_file(_get_log_folder(log_name) + '/best_individual.pickle', 'wb') as f: 111 | pickle.dump(best_individual, f) 112 | 113 | 114 | def log_fitness(log_name, fitness): 115 | """ Logs fitness of all individuals """ 116 | 117 | with open_file(_get_log_folder(log_name) + '/fitness_log.txt', 'a') as f: 118 | f.write("%s\n" % fitness) 119 | 120 | 121 | def log_best_fitness(log_name, best_fitness): 122 | """ Logs best fitness of each generation """ 123 | 124 | with open_file(_get_log_folder(log_name) + '/best_fitness_log.pickle', 'wb') as f: 125 | pickle.dump(best_fitness, f) 126 | 127 | 128 | def log_n_episodes(log_name, n_episodes): 129 | """ Logs number of episodes """ 130 | with open_file(_get_log_folder(log_name) + '/n_episodes_log.pickle', 'wb') as f: 131 | pickle.dump(n_episodes, f) 132 | 133 | 134 | def log_population(log_name, population): 135 | """ Logs full population of the generation""" 136 | 137 | with open_file(_get_log_folder(log_name) + '/population_log.txt', 'a') as f: 138 | f.write("%s\n" % population) 139 | 140 | 141 | def log_last_population(log_name, population): 142 | """ Logs current population as 
pickle object """ 143 | 144 | with open_file(_get_log_folder(log_name) + '/population.pickle', 'wb') as f: 145 | pickle.dump(population, f) 146 | 147 | 148 | def log_settings(log_name, settings, base_line): 149 | """ Logs settings used for the run """ 150 | 151 | with open_file(_get_log_folder(log_name) + '/settings.txt', 'w') as f: 152 | for key, value in vars(settings).items(): 153 | f.write(key + ' ' + str(value) + '\n') 154 | f.write('Baseline: ' + str(base_line) + '\n') 155 | 156 | 157 | def log_state(log_name, randomstate, np_randomstate, generation): 158 | """ Logs the current random state and generation number """ 159 | 160 | with open_file(_get_log_folder(log_name) + '/states.pickle', 'wb') as f: 161 | pickle.dump(randomstate, f) 162 | pickle.dump(np_randomstate, f) 163 | pickle.dump(generation, f) 164 | 165 | 166 | def get_best_fitness(log_name): 167 | """ Gets the best fitness list from the given log """ 168 | 169 | with open_file(_get_log_folder(log_name) + '/best_fitness_log.pickle', 'rb') as f: 170 | best_fitness = pickle.load(f) 171 | return best_fitness 172 | 173 | 174 | def get_n_episodes(log_name): 175 | """ Gets the list of n_episodes from the given log """ 176 | 177 | with open_file(_get_log_folder(log_name) + '/n_episodes_log.pickle', 'rb') as f: 178 | n_episodes = pickle.load(f) 179 | return n_episodes 180 | 181 | 182 | def get_state(log_name): 183 | """ Gets the random state and generation number """ 184 | 185 | with open_file(get_log_folder(log_name) + '/states.pickle', 'rb') as f: 186 | randomstate = pickle.load(f) 187 | np_randomstate = pickle.load(f) 188 | generation = pickle.load(f) 189 | return randomstate, np_randomstate, generation 190 | 191 | 192 | def get_last_population(log_name): 193 | """ Gets the last population list from the given log """ 194 | 195 | with open_file(_get_log_folder(log_name) + '/population.pickle', 'rb') as f: 196 | population = pickle.load(f) 197 | return population 198 | 199 | 200 | def 
get_best_individual(log_name): 201 | """ Return the best individual from the given log """ 202 | 203 | with open_file(_get_log_folder(log_name) + '/best_individual.pickle', 'rb') as f: 204 | best_individual = pickle.load(f) 205 | return best_individual 206 | 207 | 208 | def plot_fitness(log_name, fitness, n_episodes=None): 209 | """ 210 | Plots fitness over iterations or individuals 211 | """ 212 | 213 | if n_episodes is not None: 214 | plt.plot(n_episodes, fitness) 215 | plt.xlabel("Episodes") 216 | else: 217 | plt.plot(fitness) 218 | plt.xlabel("Generation") 219 | plt.ylabel("Fitness") 220 | plt.savefig(_get_log_folder(log_name) + '/Fitness.png') 221 | plt.close() 222 | 223 | 224 | @dataclass 225 | class PlotParameters: 226 | """ 227 | Data class for parameters for plotting 228 | """ 229 | 230 | plot_mean: bool = True # Plot the mean of the logs 231 | mean_color: str = 'b' # Color for mean curve 232 | plot_std: bool = True # Plot the standard deviation 233 | std_color: str = 'b' # Color of the std fill 234 | plot_minmax: bool = False # Plots minmax instead of std, should not be combined 235 | plot_ind: bool = False # Plot each individual log 236 | ind_color: str = 'aquamarine' # Ind color 237 | label: str = '' # Label name 238 | title: str = '' # Plot title 239 | xlabel: str = '' # Label of x axis 240 | x_max: int = 0 # Upper limit of x axis 241 | extend_gens: int = 0 # Extend until this minimum number of gens 242 | ylabel: str = '' # Label of y axis 243 | extrapolate_y: bool = False # Extrapolate y as constant to x_max 244 | logarithmic_y: bool = False # Logarithmic y scale 245 | plot_horizontal: bool = True # Plot thin horizontal line 246 | horizontal: float = 0 # Horizontal value to plot 247 | horizontal_label: str = '' # Label of horizontal line 248 | horizontal_linestyle: str = 'dashed' # Style of horizontal line 249 | legend_position: str = 'lower right' # Position of legend 250 | save_fig: bool = True # Save figure. If false, more plots is possible. 
251 | path: str = 'logs/plot.svg' # Path to save log 252 | 253 | 254 | def plot_learning_curves(logs, parameters): 255 | # pylint: disable=too-many-branches, too-many-statements, too-many-locals 256 | """ 257 | Plots mean and standard deviation of a number of logs in the same figure 258 | """ 259 | 260 | fitness = [] 261 | n_episodes = [] 262 | for log_name in logs: 263 | fitness.append(get_best_fitness(log_name)) 264 | n_episodes.append(get_n_episodes(log_name)) 265 | 266 | n_logs = len(logs) 267 | 268 | if parameters.extend_gens > 0: 269 | # Extend until this minimum number of gens, assuming shorter logs are stopped because 270 | # they have converged there is no difference to end result 271 | for i in range(n_logs): 272 | if len(fitness[i]) < parameters.extend_gens: 273 | last_fitness = fitness[i][-1] 274 | 275 | while len(fitness[i]) < parameters.extend_gens: 276 | fitness[i].append(last_fitness) 277 | n_episodes[i].append(parameters.x_max) 278 | 279 | trim_logs(fitness) 280 | trim_logs(n_episodes) 281 | 282 | fitness = np.array(fitness) 283 | n_episodes = np.array(n_episodes) 284 | 285 | startx = np.max(n_episodes[:, 0]) 286 | endx = np.min(n_episodes[:, -1]) 287 | if parameters.extrapolate_y: 288 | x = np.arange(startx, parameters.x_max + 1) 289 | else: 290 | x = np.arange(startx, endx + 1) 291 | 292 | if parameters.plot_horizontal: 293 | plt.plot([0, parameters.x_max], 294 | [parameters.horizontal, parameters.horizontal], 295 | color='k', linestyle=parameters.horizontal_linestyle, linewidth=1, label=parameters.horizontal_label) 296 | 297 | y = np.zeros((len(x), n_logs)) 298 | for i in range(0, n_logs): 299 | f = interpolate.interp1d(n_episodes[i, :], fitness[i, :], bounds_error=False) 300 | y[:, i] = f(x) 301 | if parameters.extrapolate_y: 302 | n_extrapolated = int(parameters.x_max - n_episodes[i, -1]) 303 | if n_extrapolated > 0: 304 | left = y[:n_episodes[i, -1] - n_episodes[i, 0] + 1, i] 305 | y[:, i] = np.concatenate((left, np.full(n_extrapolated, 
left[-1]))) 306 | if parameters.plot_ind: 307 | plt.plot(x, y[:, i], color=parameters.ind_color, linestyle='dashed', linewidth=1) 308 | 309 | y_mean = np.mean(y, axis=1) 310 | if parameters.plot_mean: 311 | plt.plot(x, y_mean, color=parameters.mean_color, label=parameters.label) 312 | 313 | if parameters.plot_std: 314 | y_std = np.std(y, axis=1) 315 | plt.fill_between(x, y_mean - y_std, y_mean + y_std, alpha=.1, color=parameters.std_color) 316 | if parameters.plot_minmax: 317 | max_curve = np.max(y, axis=1) 318 | min_curve = np.min(y, axis=1) 319 | plt.fill_between(x, min_curve, max_curve, alpha=.1, color=parameters.std_color) 320 | 321 | plt.legend(loc=parameters.legend_position) 322 | plt.xlabel(parameters.xlabel) 323 | if parameters.x_max > 0: 324 | plt.xlim(0, parameters.x_max) 325 | 326 | if parameters.logarithmic_y: 327 | plt.yscale('symlog') 328 | plt.yticks([0, -1, -10, -100], ('0', '-1', '-10', '-100')) 329 | plt.ylabel(parameters.ylabel) 330 | plt.title(parameters.title) 331 | if parameters.save_fig: 332 | 333 | plt.savefig(parameters.path, format='pdf', dpi=300) 334 | plt.close() 335 | 336 | 337 | def _get_log_folder(log_name): 338 | 339 | directory_path = _DIRECTORY_PATH 340 | if not os.path.exists(directory_path): 341 | os.mkdir(directory_path) 342 | 343 | return os.path.join(directory_path, log_name) 344 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/planner/__init__.py: -------------------------------------------------------------------------------- 1 | from behavior_tree_learning.core.planner import planner 2 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/planner/node_factory.py: -------------------------------------------------------------------------------- 1 | import py_trees as pt 2 | from behavior_tree_learning.core.sbt.behavior_factory import BehaviorNodeFactory 3 | 4 | 5 | class PlannerBehaviorNodeFactory: 
6 | 7 | def __init__(self, make_execution_node, get_condition_parameters): 8 | 9 | self._behavior_factory = BehaviorNodeFactory(make_execution_node) 10 | self._make_execution_node = make_execution_node 11 | self._get_condition_parameters = get_condition_parameters 12 | 13 | def get_condition_parameters(self, condition): 14 | 15 | return self._get_condition_parameters(condition) 16 | 17 | def get_node(self, name: str, world, condition_parameters): 18 | 19 | node = self._behavior_factory.make_node(name) 20 | 21 | if node is None: 22 | return self._make_execution_node(name, world, condition_parameters) 23 | else: 24 | return node, True 25 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/planner/planner.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements a very simple task planner inspired from 'Towards Blended Reactive 3 | Planning and Acting using Behavior Trees'. Generates a behaviors tree to solve 4 | task given a set of goals and behaviors with preconditions and postconditions. 5 | Since the conditions are not always static, it actually runs the tree while 6 | evaluating the conditions. 
7 | """ 8 | import py_trees as pt 9 | 10 | from behavior_tree_learning.core.sbt import StringBehaviorTree, BehaviorNodeFactory 11 | from behavior_tree_learning.core.sbt.behaviors import RSequence 12 | from behavior_tree_learning.core.sbt.behavior_tree import get_action_list 13 | from behavior_tree_learning.core.planner import PlannerBehaviorNodeFactory 14 | 15 | 16 | def _handle_precondition(precondition, behavior_factory, world): 17 | """ 18 | Handles precondition by creating a subtree whose post-conditions (aka effects) 19 | fulfill the pre-condition (aka condition) 20 | """ 21 | 22 | print("Pre-condition in: ", precondition) 23 | condition_parameters = behavior_factory.get_condition_parameters(precondition) 24 | 25 | for action in get_action_list(): 26 | 27 | action_node, _ = behavior_factory.get_node(action, world, condition_parameters) 28 | if precondition in action_node.get_postconditions(): 29 | 30 | action_preconditions = action_node.get_preconditions() 31 | 32 | if action_preconditions: 33 | 34 | bt = RSequence('Sequence') 35 | for action_precondition in action_preconditions: 36 | 37 | condition_parameters = behavior_factory.get_condition_parameters(action_precondition) 38 | child, _ = behavior_factory.get_node(action_precondition, 39 | world, 40 | condition_parameters) 41 | bt.add_child(child) 42 | 43 | bt.add_child(action_node) 44 | 45 | else: 46 | bt = action_node 47 | 48 | return bt 49 | 50 | print("ERROR, no matching action found to ensure precondition") 51 | return None 52 | 53 | 54 | def _extend_leaf_node(leaf_node, behavior_factory, world): 55 | """ 56 | If leaf node fails, it should be replaced with a selector that checks leaf node 57 | and a subtree that fixes the pre-condition (condition) whenever it's not met. 58 | """ 59 | 60 | bt = pt.composites.Selector(name='Fallback') 61 | leaf_node.parent.replace_child(leaf_node, bt) 62 | bt.add_child(leaf_node) 63 | print("What is failing? 
", leaf_node.name) 64 | 65 | extended = _handle_precondition(leaf_node.name, behavior_factory, world) 66 | if extended is not None: 67 | bt.add_child(extended) 68 | 69 | 70 | def _expand_tree(node, behavior_factory, world): 71 | """ 72 | Expands the part of the tree that fails 73 | """ 74 | 75 | print("TREE COMING IN :", node) 76 | 77 | if node.name == 'Fallback': 78 | print("Fallback node fails\n") 79 | for index, child in enumerate(node.children): 80 | if index >= 1: # Normally there will only be two children 81 | _expand_tree(child, behavior_factory, world) 82 | 83 | elif node.name == 'Sequence': 84 | print("Sequence node fails\n") 85 | for i in range(len(node.children)): 86 | if node.children[i].status == pt.common.Status.FAILURE: 87 | print("Child that fails: ", node.children[i].name) 88 | _expand_tree(node.children[i], behavior_factory, world) 89 | elif isinstance(node, pt.behaviour.Behaviour) and node.status == pt.common.Status.FAILURE: 90 | 91 | _extend_leaf_node(node, behavior_factory, world) 92 | 93 | else: 94 | print("Tree", node.name) 95 | 96 | 97 | def plan(world, get_execution_node, get_condition_parameters, goals): 98 | """ 99 | Generates a behaviors tree to solve task given a set of goals 100 | and behaviors with pre-conditions and post-conditions. Since the 101 | conditions are not always static, it actually runs the tree while evaluating 102 | the conditions. 
103 | """ 104 | 105 | tree = RSequence() 106 | planner_behavior_factory = PlannerBehaviorNodeFactory(get_execution_node, get_condition_parameters) 107 | 108 | for goal in goals: 109 | 110 | goal_condition, _ = planner_behavior_factory.get_node(goal, world, []) 111 | tree.add_child(goal_condition) 112 | 113 | print(pt.display.unicode_tree(root=tree, show_status=True)) 114 | 115 | for i in range(60): 116 | 117 | tree.tick_once() 118 | print("Tick: ", i) 119 | print(pt.display.unicode_tree(root=tree, show_status=True)) 120 | 121 | if tree.status is pt.common.Status.FAILURE: 122 | _expand_tree(tree, planner_behavior_factory, world) 123 | print(pt.display.unicode_tree(root=tree, show_status=True)) 124 | 125 | elif tree.status is pt.common.Status.SUCCESS: 126 | break 127 | 128 | sbt_behavior_factory = BehaviorNodeFactory(get_execution_node) 129 | pt.display.render_dot_tree(tree, name='Planned bt', target_directory='') 130 | print(StringBehaviorTree('', sbt_behavior_factory, world, tree).to_string()) 131 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/plotter/__init__.py: -------------------------------------------------------------------------------- 1 | from behavior_tree_learning.core.plotter.print_functions import print_ascii_tree -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/plotter/print_functions.py: -------------------------------------------------------------------------------- 1 | import py_trees as pt 2 | 3 | 4 | def print_ascii_tree(py_tree): 5 | print(pt.display.ascii_tree(py_tree.root)) 6 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/sbt/__init__.py: -------------------------------------------------------------------------------- 1 | from behavior_tree_learning.core.sbt.behavior_tree import BehaviorTreeStringRepresentation 2 | from 
behavior_tree_learning.core.sbt.node_factory import BehaviorNodeFactory, BehaviorRegister, \ 3 | BehaviorNode, BehaviorNodeWithOperation 4 | from behavior_tree_learning.core.sbt.executor import BehaviorTreeExecutor 5 | from behavior_tree_learning.core.sbt.py_tree import StringBehaviorTree, ExecutionParameters 6 | from behavior_tree_learning.core.sbt.world import World 7 | from behavior_tree_learning.core.sbt.graphics import plot_behavior_tree 8 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/sbt/behaviors.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common behaviors for all implementations 3 | """ 4 | 5 | import py_trees as pt 6 | 7 | 8 | class RSequence(pt.composites.Selector): 9 | """ 10 | Rsequence for py_trees 11 | Reactive sequence overidding sequence with memory, py_trees' only available sequence. 12 | 13 | Author: Chrisotpher Iliffe Sprague, sprague@kth.se 14 | """ 15 | 16 | def __init__(self, name="Sequence", children=None): 17 | super(RSequence, self).__init__(name=name, children=children) 18 | 19 | def tick(self): 20 | """ 21 | Run the tick behaviour for this selector. Note that the status 22 | of the tick is always determined by its children, not 23 | by the user customized update function. 
24 | Yields: 25 | :class:`~py_trees.behaviour.Behaviour`: a reference to itself or one of its children 26 | """ 27 | self.logger.debug("%s.tick()" % self.__class__.__name__) 28 | # Required behaviour for *all* behaviours and composites is 29 | # for tick() to check if it isn't running and initialise 30 | if self.status != pt.common.Status.RUNNING: 31 | # selectors dont do anything specific on initialisation 32 | # - the current child is managed by the update, never needs to be 'initialised' 33 | # run subclass (user) handles 34 | self.initialise() 35 | # run any work designated by a customized instance of this class 36 | self.update() 37 | previous = self.current_child 38 | for child in self.children: 39 | for node in child.tick(): 40 | yield node 41 | if node is child and \ 42 | (node.status == pt.common.Status.RUNNING or node.status == pt.common.Status.FAILURE): 43 | self.current_child = child 44 | self.status = node.status 45 | if previous is None or previous != self.current_child: 46 | # we interrupted, invalidate everything at a lower priority 47 | passed = False 48 | for sibling in self.children: 49 | if passed and sibling.status != pt.common.Status.INVALID: 50 | sibling.stop(pt.common.Status.INVALID) 51 | if sibling == self.current_child: 52 | passed = True 53 | yield self 54 | return 55 | # all children succeded, set succeed ourselves and current child to the last bugger who failed us 56 | self.status = pt.common.Status.SUCCESS 57 | try: 58 | self.current_child = self.children[-1] 59 | except IndexError: 60 | self.current_child = None 61 | yield self 62 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/sbt/executor.py: -------------------------------------------------------------------------------- 1 | from behavior_tree_learning.core.sbt.node_factory import BehaviorNodeFactory 2 | from behavior_tree_learning.core.sbt.py_tree import StringBehaviorTree, ExecutionParameters 3 | from 
behavior_tree_learning.core.sbt.world import World 4 | 5 | 6 | class BehaviorTreeExecutor: 7 | 8 | def __init__(self, node_factory: BehaviorNodeFactory, world: World): 9 | 10 | self._node_factory = node_factory 11 | self._world = world 12 | 13 | def run(self, sbt: str, parameters: ExecutionParameters, verbose=False): 14 | 15 | tree = StringBehaviorTree(sbt, behaviors=self._node_factory, world=self._world, verbose=verbose) 16 | success, ticks = tree.run_bt(parameters=parameters) 17 | return success, ticks, tree 18 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/sbt/graphics.py: -------------------------------------------------------------------------------- 1 | from behavior_tree_learning.core.sbt.node_factory import BehaviorNodeFactory 2 | from behavior_tree_learning.core.sbt.py_tree import StringBehaviorTree 3 | 4 | 5 | def plot_behavior_tree(bt_name: str, sbt: str, node_factory: BehaviorNodeFactory, directory_path): 6 | 7 | tree = StringBehaviorTree(sbt, behaviors=node_factory, world=None) 8 | tree.save_figure(directory_path, name=bt_name) 9 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/sbt/node_factory.py: -------------------------------------------------------------------------------- 1 | import py_trees as pt 2 | from behavior_tree_learning.core.sbt import behavior_tree as bt 3 | from behavior_tree_learning.core.sbt.behaviors import RSequence 4 | from behavior_tree_learning.core.sbt import parse_operation as operation 5 | 6 | 7 | class BehaviorNode(pt.behaviour.Behaviour): 8 | """ 9 | From 'pt.behaviour.Behaviour': 10 | 11 | def initialise(self): 12 | pass 13 | def update(self): 14 | return pt.common.Status.SUCCESS 15 | def terminate(self, new_status): 16 | pass 17 | 18 | The initialize() method should call World::is_alive(), same for 19 | 'BehaviorNodeWithOperation' 20 | """ 21 | 22 | def __init__(self, name): 23 | 
super().__init__(name) 24 | 25 | 26 | class BehaviorNodeWithOperation(pt.behaviour.Behaviour): 27 | 28 | @staticmethod 29 | def make(text, world, verbose=False): 30 | raise NotImplementedError 31 | 32 | def __init__(self, text): 33 | operation.print_parsed_function(text) 34 | self._operation = operation.parse_function(text) 35 | super().__init__(self._operation[0]) 36 | 37 | 38 | class BehaviorRegister: 39 | 40 | class BehaviorType: 41 | 42 | CONDITION = 1 43 | ACTION = 2 44 | 45 | def __init__(self): 46 | self._behaviors = {} 47 | 48 | def add_condition(self, name, behavior_class): 49 | self._behaviors[name] = [self.BehaviorType.CONDITION, behavior_class] 50 | 51 | def add_action(self, name, behavior_class): 52 | self._behaviors[name] = [self.BehaviorType.ACTION, behavior_class] 53 | 54 | def behaviors(self): 55 | return self._behaviors 56 | 57 | 58 | class BehaviorNodeFactory: 59 | 60 | def __init__(self, execution_behavior_register: BehaviorRegister = None): 61 | 62 | if execution_behavior_register is not None: 63 | self._execution_behavior_register = execution_behavior_register 64 | self._load_sbt_settings() 65 | else: 66 | self._execution_behavior_register = None 67 | 68 | def make_node(self, name, world=None, verbose=False): 69 | 70 | if name == 'nonpytreesbehavior': 71 | return None, False 72 | 73 | has_children = True 74 | node = self._make_control_node(name) 75 | 76 | if node is None and self._execution_behavior_register is not None: 77 | 78 | has_children = False 79 | node = self._make_execution_node(name, world, verbose) 80 | 81 | if node is None: 82 | raise Exception("Unexpected character", name) 83 | 84 | return node, has_children 85 | 86 | def _load_sbt_settings(self): 87 | 88 | bt.initialize_settings() 89 | bt.add_node('fallback', 'f(') 90 | bt.add_node('sequence', 's(') 91 | bt.add_node('parallel', 'p(') 92 | bt.add_node('up_node', ')') 93 | 94 | behaviors = self._execution_behavior_register.behaviors() 95 | for key in behaviors.keys(): 96 | 97 | if 
behaviors[key][0] == BehaviorRegister.BehaviorType.CONDITION: 98 | type_ = 'condition' 99 | elif behaviors[key][0] == BehaviorRegister.BehaviorType.ACTION: 100 | type_ = 'action' 101 | 102 | name = key 103 | bt.add_node(type_, name) 104 | 105 | def _make_control_node(self, name): 106 | 107 | node = None 108 | 109 | if name == 'f(': 110 | node = pt.composites.Selector('Fallback') 111 | elif name == 's(': 112 | node = RSequence() 113 | elif name == 'p(': 114 | node = pt.composites.Parallel(name="Parallel", 115 | policy=pt.common.ParallelPolicy.SuccessOnAll(synchronise=False)) 116 | elif name == ')': 117 | node = None 118 | 119 | return node 120 | 121 | def _make_execution_node(self, name, world, verbose): 122 | 123 | behaviors = self._execution_behavior_register.behaviors() 124 | for key in behaviors.keys(): 125 | if name == key: 126 | return behaviors[key][1].make(name, world, verbose) 127 | 128 | return None 129 | 130 | 131 | def make_factory_from_file(file_path): 132 | return NotImplementedError 133 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/sbt/parse_operation.py: -------------------------------------------------------------------------------- 1 | def _check_main_syntax(plain_text): 2 | 3 | text = plain_text.replace(" ", "") 4 | 5 | idx = text.find("=>") 6 | if idx != -1: 7 | function_and_arguments = text[:idx] 8 | return_values = text[idx + 2:] 9 | else: 10 | function_and_arguments = text 11 | return_values = None 12 | 13 | #print("left:", function_and_arguments) 14 | #print("right:", return_values) 15 | 16 | if (function_and_arguments.find("[") == -1 or 17 | function_and_arguments.find("]") == -1): 18 | raise Exception("Wrong FUNCTION_DEFINITION syntax, missed parenthesis after name") 19 | 20 | if return_values is not None: 21 | if (return_values.find("[") == -1 or 22 | return_values.find("]") == -1): 23 | raise Exception("Wrong RETURN_VALUE syntax, missed parenthesis") 24 | elif 
return_values.find("[") == (return_values.find("]") - 1): 25 | raise Exception("Wrong RETURN_VALUE syntax, any value between parenthesis") 26 | 27 | return text 28 | 29 | 30 | def _extract_function_name(text): 31 | return text[:text.find("[")] 32 | 33 | 34 | def _extract_variable(text): 35 | idx = text.find(":") 36 | name = text[:idx] 37 | type_ = text[idx + 1:] 38 | return name, type_ 39 | 40 | 41 | def _extract_variables(text): 42 | args_as_list = text.split(',') 43 | args = {} 44 | for arg in args_as_list: 45 | name, type_ = _extract_variable(arg) 46 | args[name] = type_ 47 | return args 48 | 49 | 50 | def _extract_function_arguments(text): 51 | 52 | if text.find("[") == -1 or text.find("]") == -1: 53 | raise Exception("Wrong ARGUMENTS syntax, missed parenthesis") 54 | 55 | args_as_text = text[text.find("[") + 1: text.find("]")] 56 | if args_as_text == "": 57 | return None 58 | else: 59 | return _extract_variables(args_as_text) 60 | 61 | 62 | def _extract_function_return_value(text): 63 | 64 | idx = text.find("=>") 65 | if idx == -1: 66 | return None 67 | 68 | args_as_text = text[idx + 2:] 69 | args_as_text = args_as_text[args_as_text.find("[") + 1: args_as_text.find("]")] 70 | if args_as_text == "": 71 | raise Exception("Wrong RETURN_VALUE syntax, missed parenthesis") 72 | 73 | return _extract_variables(args_as_text) 74 | 75 | 76 | def parse_function(text): 77 | 78 | text = _check_main_syntax(text) 79 | name = _extract_function_name(text) 80 | arguments = _extract_function_arguments(text) 81 | return_value = _extract_function_return_value(text) 82 | return name, arguments, return_value 83 | 84 | 85 | def print_parsed_function(text): 86 | 87 | name, arguments, return_value = parse_function(text) 88 | print("Text: ", text) 89 | print("Function name: %s, arguments: %s, return: %s" % (name, arguments, return_value)) 90 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/sbt/py_tree.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | import py_trees as pt 3 | from behavior_tree_learning.core.sbt.world import World 4 | from behavior_tree_learning.core.sbt.behavior_tree import BehaviorTreeStringRepresentation 5 | from behavior_tree_learning.core.sbt.node_factory import BehaviorNodeFactory 6 | 7 | 8 | class ExecutionParameters: 9 | 10 | def __init__(self, max_ticks=30, max_time=30.0, max_straight_fails=1, successes_required=2): 11 | 12 | self.max_ticks = max_ticks 13 | self.max_time = max_time 14 | self.max_straight_fails = max_straight_fails 15 | self.successes_required = successes_required 16 | 17 | 18 | class StringBehaviorTree(pt.trees.BehaviourTree): 19 | 20 | class TraceInfo: 21 | 22 | def __init__(self, verbose): 23 | self.verbose = verbose 24 | 25 | def __init__(self, string: str, behaviors: BehaviorNodeFactory, world: World = None, root=None, verbose=False): 26 | 27 | if root is not None: 28 | self.root = root 29 | string = self.to_string() 30 | 31 | self.bt = BehaviorTreeStringRepresentation(string) 32 | self.depth = self.bt.depth() 33 | self.length = self.bt.length() 34 | self.failed = False 35 | self.timeout = False 36 | 37 | self._world = world 38 | self._behavior_factory = behaviors 39 | self._trace_info = self.TraceInfo(verbose) 40 | 41 | if root is not None: 42 | has_children = False 43 | else: 44 | self.root, has_children = self._behavior_factory.make_node(string[0], self._world, self._trace_info.verbose) 45 | string.pop(0) 46 | 47 | super().__init__(root=self.root) 48 | 49 | if has_children: 50 | self.create_from_string(string, self.root) 51 | 52 | def to_string(self): 53 | """ 54 | Returns bt string (actually a list) from py tree root 55 | by cleaning the ascii tree from py trees 56 | Not complete or beautiful by any means but works for many trees 57 | """ 58 | 59 | string = pt.display.ascii_tree(self.root) 60 | string = string.replace("[o] ", "") 61 | string = 
string.replace("\t", "") 62 | string = string.replace("-->", "") 63 | string = string.replace("Fallback", "f(") 64 | string = string.replace("Sequence", "s(") 65 | bt = string.split("\n") 66 | bt = bt[:-1] 67 | 68 | prev_leading_spaces = 999999 69 | for i in range(len(bt) - 1, -1, -1): 70 | leading_spaces = len(bt[i]) - len(bt[i].lstrip(' ')) 71 | bt[i] = bt[i].lstrip(' ') 72 | if leading_spaces > prev_leading_spaces: 73 | for _ in range(round((leading_spaces - prev_leading_spaces) / 4)): 74 | bt.insert(i + 1, ')') 75 | prev_leading_spaces = leading_spaces 76 | 77 | bt_obj = BehaviorTreeStringRepresentation(bt) 78 | bt_obj.close() 79 | 80 | return bt_obj.bt 81 | 82 | def create_from_string(self, string: str, node): 83 | """ 84 | Recursive function to generate the tree from a string 85 | """ 86 | 87 | while len(string) > 0: 88 | if string[0] == ")": 89 | string.pop(0) 90 | return node 91 | 92 | new_node, has_children = self._behavior_factory.make_node(string[0], self._world, self._trace_info.verbose) 93 | string.pop(0) 94 | if has_children: 95 | # Node is a control node or decorator with children - add subtree via string and then add to parent 96 | new_node = self.create_from_string(string, new_node) 97 | node.add_child(new_node) 98 | else: 99 | # Node is a leaf/action node - add to parent, then keep looking for siblings 100 | node.add_child(new_node) 101 | 102 | # This return is only reached if there are too few up nodes 103 | return node 104 | 105 | def run_bt(self, parameters: ExecutionParameters = ExecutionParameters()): 106 | """ 107 | Function executing the behavior tree 108 | """ 109 | 110 | if not self._world.startup(self._trace_info.verbose): 111 | return False, 0 112 | 113 | max_ticks = parameters.max_ticks 114 | max_time = parameters.max_time 115 | max_straight_fails = parameters.max_straight_fails 116 | successes_required = parameters.successes_required 117 | 118 | ticks = 0 119 | straight_fails = 0 120 | successes = 0 121 | status_ok = True 122 | start 
= time.time() 123 | 124 | while (self.root.status is not pt.common.Status.FAILURE or straight_fails < max_straight_fails) \ 125 | and (self.root.status is not pt.common.Status.SUCCESS or successes < successes_required) \ 126 | and ticks < max_ticks and status_ok: 127 | 128 | status_ok = self._world.is_alive() 129 | 130 | if status_ok: 131 | self.root.tick_once() 132 | 133 | ticks += 1 134 | if self.root.status is pt.common.Status.SUCCESS: 135 | successes += 1 136 | else: 137 | successes = 0 138 | 139 | if self.root.status is pt.common.Status.FAILURE: 140 | straight_fails += 1 141 | else: 142 | straight_fails = 0 143 | 144 | if time.time() - start > max_time: 145 | status_ok = False 146 | if self._trace_info.verbose: 147 | print("Max time expired") 148 | 149 | if self._trace_info.verbose: 150 | print("Status: %s Ticks: %d, Time: %s" % (status_ok, ticks, time.time() - start)) 151 | 152 | if ticks >= max_ticks: 153 | self.timeout = True 154 | if straight_fails >= max_straight_fails: 155 | self.failed = True 156 | 157 | self._world.shutdown() 158 | 159 | return status_ok, ticks 160 | 161 | def save_figure(self, path: str, name: str = "bt"): 162 | 163 | pt.display.render_dot_tree(self.root, name=name, target_directory=path) 164 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/core/sbt/world.py: -------------------------------------------------------------------------------- 1 | from interface import Interface 2 | 3 | 4 | class World(Interface): 5 | 6 | def startup(self, verbose): 7 | pass 8 | 9 | def is_alive(self): 10 | pass 11 | 12 | def shutdown(self): 13 | pass 14 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/gp.py: -------------------------------------------------------------------------------- 1 | from behavior_tree_learning.core.gp import GeneticEnvironment, GeneticOperators 2 | from behavior_tree_learning.core.gp import 
GeneticParameters, GeneticSelectionMethods, TraceConfiguration 3 | from behavior_tree_learning.core.gp import GeneticProgramming 4 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/learning.py: -------------------------------------------------------------------------------- 1 | from behavior_tree_learning.core.gp import GeneticParameters, GeneticSelectionMethods, TraceConfiguration 2 | from behavior_tree_learning.core.sbt import World, StringBehaviorTree, BehaviorTreeStringRepresentation 3 | from behavior_tree_learning.core.sbt import BehaviorNodeFactory, BehaviorRegister 4 | from behavior_tree_learning.core.sbt import ExecutionParameters 5 | from behavior_tree_learning.core.gp_sbt import Environment, BehaviorTreeLearner 6 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/sbt.py: -------------------------------------------------------------------------------- 1 | from behavior_tree_learning.core.sbt import BehaviorTreeExecutor, ExecutionParameters 2 | from behavior_tree_learning.core.sbt import BehaviorNodeFactory, BehaviorRegister, \ 3 | BehaviorNode, BehaviorNodeWithOperation 4 | from behavior_tree_learning.core.sbt import World, StringBehaviorTree 5 | from behavior_tree_learning.core.sbt import BehaviorTreeStringRepresentation 6 | from behavior_tree_learning.core.sbt import plot_behavior_tree 7 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/tests/fwk/BT_SETTINGS.yaml: -------------------------------------------------------------------------------- 1 | fallback_nodes: 2 | - 'f(' 3 | sequence_nodes: 4 | - 's(' 5 | condition_nodes: 6 | - 'c0' 7 | - 'c1' 8 | action_nodes: 9 | - 'a0' 10 | - 'a1' 11 | - 'a2' 12 | - 'a3' 13 | - 'a4' 14 | - 'a5' 15 | up_node: 16 | - ')' -------------------------------------------------------------------------------- 
/src/behavior_tree_learning/tests/fwk/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/tests/fwk/behavior_nodes.py: -------------------------------------------------------------------------------- 1 | import py_trees as pt 2 | from behavior_tree_learning.sbt import BehaviorRegister, BehaviorNode 3 | 4 | 5 | class C(BehaviorNode): 6 | 7 | @staticmethod 8 | def make(text, world, _): 9 | return C(text, world) 10 | 11 | def __init__(self, name, _): 12 | self._name = name 13 | super(C, self).__init__(str(self._name)) 14 | 15 | def update(self): 16 | #print(self._name) 17 | return pt.common.Status.SUCCESS 18 | 19 | 20 | class A(BehaviorNode): 21 | 22 | @staticmethod 23 | def make(text, world, verbose): 24 | return A(text, world) 25 | 26 | def __init__(self, name, wold): 27 | super(A, self).__init__(str(name)) 28 | 29 | def update(self): 30 | return pt.common.Status.SUCCESS 31 | 32 | 33 | class Toggle1(BehaviorNode): 34 | 35 | @staticmethod 36 | def make(text, world, verbose=False): 37 | return Toggle1(text, world) 38 | 39 | def __init__(self, name, world): 40 | self._world = world 41 | super(Toggle1, self).__init__(str(name)) 42 | 43 | def update(self): 44 | self._world.toggle_1() 45 | return pt.common.Status.SUCCESS 46 | 47 | 48 | class Toggle2(BehaviorNode): 49 | 50 | @staticmethod 51 | def make(text, world, verbose=False): 52 | return Toggle2(text, world) 53 | 54 | def __init__(self, name, world): 55 | self._world = world 56 | super(Toggle2, self).__init__(str(name)) 57 | 58 | def update(self): 59 | self._world.toggle_2() 60 | return pt.common.Status.SUCCESS 61 | 62 | 63 | class Toggle3(BehaviorNode): 64 | 65 | @staticmethod 66 | def make(text, world, verbose=False): 67 | return Toggle3(text, world) 68 | 69 | def __init__(self, name, world): 70 | self._world = world 71 | super(Toggle3, 
self).__init__(str(name)) 72 | 73 | def update(self): 74 | self._world.toggle_3() 75 | return pt.common.Status.SUCCESS 76 | 77 | 78 | class Toggle4(BehaviorNode): 79 | 80 | @staticmethod 81 | def make(text, world, verbose=False): 82 | return Toggle4(text, world) 83 | 84 | def __init__(self, name, world): 85 | self._world = world 86 | super(Toggle4, self).__init__(str(name)) 87 | 88 | def update(self): 89 | self._world.toggle_4() 90 | return pt.common.Status.SUCCESS 91 | 92 | 93 | class Read1(BehaviorNode): 94 | 95 | @staticmethod 96 | def make(text, world, verbose=False): 97 | return Read1(text, world) 98 | 99 | def __init__(self, name, world): 100 | self._world = world 101 | super(Read1, self).__init__(str(name)) 102 | 103 | def update(self): 104 | if self._world.read_1(): 105 | return pt.common.Status.SUCCESS 106 | return pt.common.Status.FAILURE 107 | 108 | 109 | class Read2(BehaviorNode): 110 | 111 | @staticmethod 112 | def make(text, world, verbose=False): 113 | return Read2(text, world) 114 | 115 | def __init__(self, name, world): 116 | self._world = world 117 | super(Read2, self).__init__(str(name)) 118 | 119 | def update(self): 120 | if self._world.read_2(): 121 | return pt.common.Status.SUCCESS 122 | return pt.common.Status.FAILURE 123 | 124 | 125 | class Read3(BehaviorNode): 126 | 127 | @staticmethod 128 | def make(text, world, verbose=False): 129 | return Read3(text, world) 130 | 131 | def __init__(self, name, world): 132 | self._world = world 133 | super(Read3, self).__init__(str(name)) 134 | 135 | def update(self): 136 | if self._world.read_3(): 137 | return pt.common.Status.SUCCESS 138 | return pt.common.Status.FAILURE 139 | 140 | 141 | class Read4(BehaviorNode): 142 | 143 | @staticmethod 144 | def make(text, world, verbose=False): 145 | return Read4(text, world) 146 | 147 | def __init__(self, name, world): 148 | self._world = world 149 | super(Read4, self).__init__(str(name)) 150 | 151 | def update(self): 152 | if self._world.read_4(): 153 | return 
pt.common.Status.SUCCESS 154 | return pt.common.Status.FAILURE 155 | 156 | 157 | def get_behaviors(): 158 | 159 | behavior_register = BehaviorRegister() 160 | 161 | behavior_register.add_condition('c0', C) 162 | behavior_register.add_condition('c1', C) 163 | behavior_register.add_condition('c2', C) 164 | behavior_register.add_action('a0', A) 165 | behavior_register.add_action('a1', A) 166 | behavior_register.add_action('a2', A) 167 | behavior_register.add_action('a3', A) 168 | behavior_register.add_action('a4', A) 169 | behavior_register.add_action('a5', A) 170 | 171 | return behavior_register 172 | 173 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/tests/fwk/world.py: -------------------------------------------------------------------------------- 1 | 2 | from enum import IntEnum 3 | from dataclasses import dataclass 4 | from interface import implements 5 | from behavior_tree_learning.core.sbt import World 6 | 7 | 8 | class State(IntEnum): 9 | 10 | state1 = 0 11 | state2 = 1 12 | state3 = 2 13 | state4 = 3 14 | 15 | 16 | @dataclass 17 | class SMParameters: 18 | """Data class for parameters for the state machine simulator """ 19 | verbose: bool = False #Extra prints 20 | 21 | 22 | class DummyWorld(implements(World)): 23 | 24 | def __init__(self): 25 | self.sm_par = SMParameters() 26 | self.state = [False]*(len(State)) 27 | 28 | def startup(self): 29 | return True 30 | 31 | def is_alive(self): 32 | return True 33 | 34 | def shutdown(self): 35 | return 36 | 37 | def toggle_1(self): 38 | self.state[State.state1] = True 39 | 40 | def toggle_22(self): 41 | self.state[State.state2] = True 42 | 43 | def toggle_3(self): 44 | self.state[State.state3] = True 45 | 46 | def toggle_4(self): 47 | self.state[State.state4] = True 48 | 49 | def read_1(self): 50 | return self.state[State.state1] 51 | 52 | def read_2(self): 53 | return self.state[State.state2] 54 | 55 | def read_3(self): 56 | return 
self.state[State.state3] 57 | 58 | def read4(self): 59 | return self.state[State.state4] 60 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/tests/paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | 5 | _this_file_path = os.path.abspath(__file__) 6 | _PACKAGE_DIRECTORY = os.path.dirname(os.path.dirname(_this_file_path)) 7 | _TEST_DIRECTORY = os.path.dirname(_this_file_path) 8 | 9 | 10 | def add_modules_to_path(): 11 | sys.path.append(os.path.normpath(_PACKAGE_DIRECTORY)) 12 | 13 | 14 | def get_test_directory(): 15 | return _TEST_DIRECTORY 16 | 17 | 18 | def get_log_directory(): 19 | return os.path.join(_TEST_DIRECTORY, 'logs') 20 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/tests/test_gp_algorithm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import paths 4 | paths.add_modules_to_path() 5 | 6 | import unittest 7 | from interface import implements 8 | from behavior_tree_learning.core.gp.hash_table import HashTable 9 | from behavior_tree_learning.core.gp.parameters import GeneticParameters 10 | from behavior_tree_learning.core.gp.operators import GeneticOperators 11 | from behavior_tree_learning.core.gp.environment import GeneticEnvironment 12 | from behavior_tree_learning.core.gp.algorithm import GeneticProgramming 13 | 14 | 15 | class FakeOperators(implements(GeneticOperators)): 16 | 17 | def random_genome(self, length): 18 | pass 19 | 20 | def mutate_gene(self, genome, p_add, p_delete): 21 | pass 22 | 23 | def crossover_genome(self, genome1, genome2, replace): 24 | pass 25 | 26 | 27 | class FakEnvironment(implements(GeneticEnvironment)): 28 | 29 | def run_and_compute(self, individual, verbose): 30 | pass 31 | 32 | def plot_individual(self, path, plot_name, individual): 33 | pass 34 | 35 | 36 
| class TestGpAlgorithm(unittest.TestCase): 37 | 38 | def test_create_population(self): 39 | 40 | operators = FakeOperators() 41 | parameters = GeneticParameters() 42 | environment = FakEnvironment() 43 | 44 | #gp_algorithm = GeneticProgramming(operators) 45 | #gp_algorithm.run(environment, parameters) 46 | 47 | 48 | if __name__ == '__main__': 49 | unittest.main() 50 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/tests/test_gp_hash_table.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import paths 4 | paths.add_modules_to_path() 5 | 6 | import os 7 | import unittest 8 | from behavior_tree_learning.core.gp.hash_table import HashTable, _Node 9 | 10 | 11 | class TestHastTable(unittest.TestCase): 12 | 13 | def test_save_table_and_load(self): 14 | 15 | directory_path = os.path.join('logs', 'test_1') 16 | hash_table1 = HashTable(size=10, path=directory_path) 17 | hash_table1.insert(['1'], 1) 18 | hash_table1.insert(['2'], 2) 19 | hash_table1.insert(['3'], 3) 20 | hash_table1.insert(['4'], 4) 21 | hash_table1.insert(['4'], 5) 22 | hash_table1.write() 23 | 24 | hash_table2 = HashTable(size=10, path=directory_path) 25 | hash_table2.load() 26 | 27 | self.assertEqual(hash_table1, hash_table2) 28 | 29 | def test_tables_are_equal(self): 30 | 31 | hash_table1 = HashTable(size=10) 32 | hash_table2 = HashTable(size=10) 33 | hash_table1.insert(['1'], 1) 34 | hash_table1.insert(['2'], 2) 35 | hash_table2.insert(['3'], 3) 36 | hash_table2.insert(['4'], 4) 37 | 38 | self.assertEqual(hash_table1, hash_table1) 39 | self.assertNotEqual(hash_table1, hash_table2) 40 | self.assertNotEqual(hash_table1, 1) 41 | 42 | def test_nodes_are_equal(self): 43 | 44 | node1 = _Node(['a'], 1) 45 | node2 = _Node(['a'], 1) 46 | node3 = _Node(['b'], 2) 47 | self.assertEqual(node1, node2) 48 | self.assertNotEqual(node1, node3) 49 | self.assertNotEqual(node1, ['a']) 50 | 
self.assertNotEqual(node1, 1) 51 | 52 | node1.next = node3 53 | self.assertNotEqual(node1, node2) 54 | 55 | node2.next = node3 56 | self.assertEqual(node1, node2) 57 | 58 | def test_multiple_entries_in_one_table(self): 59 | 60 | hash_table1 = HashTable(size=10) 61 | hash_table1.insert(['a'], 1) 62 | hash_table1.insert(['a'], 2) 63 | hash_table1.insert(['a'], 3) 64 | hash_table1.insert(['b'], 4) 65 | hash_table1.insert(['b'], 5) 66 | hash_table1.insert(['b'], 6) 67 | 68 | self.assertEqual(hash_table1.find(['a']), [1, 2, 3]) 69 | self.assertEqual(hash_table1.find(['b']), [4, 5, 6]) 70 | 71 | 72 | if __name__ == '__main__': 73 | unittest.main() 74 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/tests/test_gp_sbt_operators.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import paths 4 | paths.add_modules_to_path() 5 | 6 | import unittest 7 | 8 | import random 9 | from behavior_tree_learning.core.sbt import BehaviorTreeStringRepresentation 10 | from behavior_tree_learning.core.sbt import BehaviorNodeFactory 11 | from behavior_tree_learning.core.gp_sbt import Operators 12 | from tests.fwk.behavior_nodes import get_behaviors 13 | 14 | 15 | class TestGpForSbtOperations(unittest.TestCase): 16 | 17 | def setUp(self) -> None: 18 | self._node_factory = BehaviorNodeFactory(get_behaviors()) 19 | 20 | def test_random_genome(self): 21 | 22 | gp_operators = Operators() 23 | length = 5 24 | 25 | for _ in range(10): 26 | genome = gp_operators.random_genome(length) 27 | btsr = BehaviorTreeStringRepresentation(genome) 28 | 29 | self.assertEqual(length, btsr.length()) 30 | self.assertTrue(btsr.is_valid()) 31 | 32 | def test_mutate_gene(self): 33 | 34 | gp_operators = Operators() 35 | genome = ['s(', 'c0', ')'] 36 | 37 | with self.assertRaises(Exception): 38 | gp_operators.mutate_gene(genome, p_add=-1, p_delete=1) 39 | 40 | with 
self.assertRaises(Exception): 41 | gp_operators.mutate_gene(genome, p_add=1, p_delete=1) 42 | 43 | for _ in range(10): 44 | mutated_genome = gp_operators.mutate_gene(genome, p_add=1, p_delete=0) 45 | self.assertGreaterEqual(len(mutated_genome), len(genome)) 46 | 47 | mutated_genome = gp_operators.mutate_gene(genome, p_add=0, p_delete=1) 48 | self.assertLessEqual(len(mutated_genome), len(genome)) 49 | 50 | mutated_genome = gp_operators.mutate_gene(genome, p_add=0, p_delete=0) 51 | btsr = BehaviorTreeStringRepresentation(mutated_genome) 52 | self.assertNotEqual(mutated_genome, genome) 53 | self.assertTrue(btsr.is_valid()) 54 | 55 | mutated_genome = gp_operators.mutate_gene(genome, p_add=0.3, p_delete=0.3) 56 | btsr.set(mutated_genome) 57 | self.assertNotEqual(mutated_genome, genome) 58 | self.assertTrue(btsr.is_valid()) 59 | 60 | def test_crossover_genome(self): 61 | 62 | gp_operators = Operators() 63 | genome1 = ['s(', 'c0', 'f(', 'c0', 'a0', ')', 'a0', ')'] 64 | genome2 = ['f(', 'c1', 's(', 'c1', 'a1', ')', 'a1', ')'] 65 | offspring1, offspring2 = gp_operators.crossover_genome(genome1, genome2, replace=True) 66 | 67 | self.assertNotEqual(offspring1, []) 68 | self.assertNotEqual(offspring2, []) 69 | self.assertNotEqual(offspring1, genome1) 70 | self.assertNotEqual(offspring1, genome2) 71 | self.assertNotEqual(offspring2, genome1) 72 | self.assertNotEqual(offspring2, genome2) 73 | 74 | btsr_1 = BehaviorTreeStringRepresentation(offspring1) 75 | self.assertTrue(btsr_1.is_valid()) 76 | btsr_1 = btsr_1.set(offspring2) 77 | self.assertTrue(btsr_1.is_valid()) 78 | 79 | genome1 = ['a0'] 80 | genome2 = ['a1'] 81 | offspring1, offspring2 = gp_operators.crossover_genome(genome1, genome2, replace=True) 82 | self.assertEqual(offspring1, genome2) 83 | self.assertEqual(offspring2, genome1) 84 | 85 | genome1 = [] 86 | offspring1, offspring2 = gp_operators.crossover_genome(genome1, genome2, replace=True) 87 | self.assertEqual(offspring1, []) 88 | self.assertEqual(offspring2, []) 89 
| 90 | for i in range(10): 91 | random.seed(i) 92 | offspring1, offspring2 = gp_operators.crossover_genome(gp_operators.random_genome(10), 93 | gp_operators.random_genome(10), 94 | replace=True) 95 | btsr_1 = btsr_1.set(offspring1) 96 | self.assertTrue(btsr_1.is_valid()) 97 | btsr_1 = btsr_1.set(offspring2) 98 | self.assertTrue(btsr_1.is_valid()) 99 | 100 | genome1 = ['s(', 'f(', 'c0', 'a0', ')', 'a0', ')'] 101 | genome2 = ['f(', 's(', 'c1', 'a1', ')', 'a1', ')'] 102 | offspring1, offspring2 = gp_operators.crossover_genome(genome1, genome2, replace=False) 103 | self.assertNotEqual(offspring1, genome1) 104 | self.assertNotEqual(offspring2, genome2) 105 | 106 | for gene in genome1: 107 | self.assertTrue(gene in offspring1) 108 | for gene in genome2: 109 | self.assertTrue(gene in offspring2) 110 | 111 | 112 | if __name__ == '__main__': 113 | unittest.main() 114 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/tests/test_gp_selection.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import paths 4 | paths.add_modules_to_path() 5 | 6 | import unittest 7 | import random 8 | from behavior_tree_learning.core.sbt import BehaviorNodeFactory 9 | from behavior_tree_learning.core.gp import selection as gps 10 | from tests.fwk.behavior_nodes import get_behaviors 11 | 12 | 13 | class TestSelection(unittest.TestCase): 14 | 15 | def setUp(self) -> None: 16 | 17 | self._node_factory = BehaviorNodeFactory(get_behaviors()) 18 | 19 | def test_elite_selection(self): 20 | 21 | population = list(range(6)) 22 | fitness = [0, 1, 2, 1, 3, 1] 23 | 24 | selected = gps.selection(gps.SelectionMethods.ELITISM, population, fitness, 2) 25 | self.assertEqual(selected, [4, 2]) 26 | 27 | population = list(range(8)) 28 | fitness = [0, 6, 2, 8, 3, 2, 1, 1] 29 | 30 | selected = gps.selection(gps.SelectionMethods.ELITISM, population, fitness, 2) 31 | self.assertEqual(selected, 
[3, 1]) 32 | 33 | def test_tournament_selection(self): 34 | 35 | population = list(range(6)) 36 | fitness = [0, 1, 2, 1, 3, 1] 37 | 38 | selected = gps.selection(gps.SelectionMethods.TOURNAMENT, population, fitness, 2) 39 | self.assertEqual(selected, [4, 2]) 40 | 41 | population = list(range(10)) 42 | fitness = [2, 1, 2, 1, 3, 1, 4, 5, 0, 0] 43 | 44 | selected = gps.selection(gps.SelectionMethods.TOURNAMENT, population, fitness, 5) 45 | self.assertEqual(selected, [6, 3, 7, 4, 2]) 46 | 47 | population = list(range(10)) 48 | fitness = [2, 1, 2, 1, 3, 1, 4, 5, 0, 0] 49 | 50 | selected = gps.selection(gps.SelectionMethods.TOURNAMENT, population, fitness, 3) 51 | self.assertEqual(selected, [7, 2, 6]) 52 | 53 | def test_rank_selection(self): 54 | 55 | population = list(range(6)) 56 | fitness = [0, 1, 2, 1, 3, 1] 57 | 58 | selected = [] 59 | for _ in range(10): 60 | selected += gps.selection(gps.SelectionMethods.RANK, population, fitness, 2) 61 | assert 4 in selected 62 | 63 | population = list(range(10)) 64 | fitness = [2, 1, 1, 1, 3, 1, 4, 5, 3, 0] 65 | 66 | num_times_selected = 0 67 | num_runs = 100 68 | for seed in range(0, num_runs): 69 | selected = gps.selection(gps.SelectionMethods.RANK, population, fitness, 2) 70 | if 0 in selected: 71 | num_times_selected += 1 72 | 73 | # 0 is the 5th in rank so the probability of getting selected 74 | # should be (10 - 4) / sum(1 to 10) = 6 / 55 75 | # Probability of getting selected when picking 2 out of 10 is then 76 | # 6 / 55 + (55 - 6) / 55 * 6 / 55 = 6 / 55 * (2 - 6 / 55) 77 | # Check is with some margin 78 | assert 6/55 * (2 - 6 / 55) - 0.05 < num_times_selected / num_runs < 6/55 * (2 - 6 / 55) + 0.05 79 | 80 | def test_random_selection(self): 81 | 82 | population = list(range(6)) 83 | fitness = [0, 1, 2, 1, 3, 1] 84 | 85 | random.seed(0) 86 | selected_1 = gps.selection(gps.SelectionMethods.RANDOM, population, fitness, 3) 87 | selected2 = gps.selection(gps.SelectionMethods.RANDOM, population, fitness, 3) 88 | 
self.assertNotEqual(selected_1, selected2) 89 | 90 | def test_all_selection(self): 91 | 92 | population = list(range(6)) 93 | fitness = [0, 1, 2, 1, 3, 1] 94 | 95 | selected = gps.selection(gps.SelectionMethods.ALL, population, fitness, 2) 96 | self.assertEqual(selected, population) 97 | 98 | 99 | if __name__ == '__main__': 100 | unittest.main() 101 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/tests/test_logplot.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import paths 4 | paths.add_modules_to_path() 5 | 6 | import unittest 7 | import os 8 | import shutil 9 | from behavior_tree_learning.core.logger import logplot 10 | 11 | 12 | class TestStringBehaviorTreeForPyTree(unittest.TestCase): 13 | 14 | OUTPUT_DIR = 'logs' 15 | 16 | def test_trim_logs(self): 17 | 18 | logs = [] 19 | logs.append([1, 2, 3]) 20 | logs.append([1, 2, 3, 4]) 21 | logs.append([1, 2, 3, 4, 5]) 22 | 23 | logplot.trim_logs(logs) 24 | 25 | assert logs == [[1, 2, 3], [1, 2, 3], [1, 2, 3]] 26 | 27 | def test_plot_fitness(self): 28 | 29 | LOG_NAME = 'test' 30 | 31 | logplot.clear_logs(LOG_NAME) 32 | logplot.plot_fitness(LOG_NAME, [0, 1, 2]) 33 | assert os.path.isfile(logplot.get_log_folder(LOG_NAME) + '/Fitness.png') 34 | try: 35 | shutil.rmtree(logplot.get_log_folder('test')) 36 | except FileNotFoundError: 37 | pass 38 | 39 | def test_plot_learning_curves_with_extend_gens(self): 40 | 41 | LOG_NAME_1 = 'test1' 42 | LOG_NAME_2 = 'test2' 43 | PDF_FILE_NAME = 'test.pdf' 44 | 45 | logplot.clear_logs(LOG_NAME_1) 46 | logplot.log_best_fitness(LOG_NAME_1, [1, 2, 3, 4, 5]) 47 | logplot.log_n_episodes(LOG_NAME_1, [5, 10, 15, 20, 25]) 48 | logplot.clear_logs(LOG_NAME_2) 49 | logplot.log_best_fitness(LOG_NAME_2, [1, 2, 5]) 50 | logplot.log_n_episodes(LOG_NAME_2, [5, 10, 15]) 51 | 52 | parameters = logplot.PlotParameters() 53 | parameters.path = PDF_FILE_NAME 54 | 
parameters.extend_gens = 5 55 | parameters.save_fig = True 56 | parameters.x_max = 30 57 | 58 | logplot.plot_learning_curves([LOG_NAME_1, LOG_NAME_2], parameters) 59 | 60 | def test_plot_learning_curves(self): 61 | 62 | LOG_NAME = 'test' 63 | PDF_FILE_NAME = 'test.pdf' 64 | 65 | try: 66 | os.remove(PDF_FILE_NAME) 67 | except FileNotFoundError: 68 | pass 69 | 70 | logplot.clear_logs(LOG_NAME) 71 | logplot.log_best_fitness(LOG_NAME, [1, 2, 3, 4, 5]) 72 | logplot.log_n_episodes(LOG_NAME, [5, 10, 15, 20, 25]) 73 | 74 | parameters = logplot.PlotParameters() 75 | parameters.path = PDF_FILE_NAME 76 | parameters.extrapolate_y = False 77 | parameters.plot_mean = False 78 | parameters.plot_std = False 79 | parameters.plot_ind = False 80 | parameters.save_fig = False 81 | parameters.x_max = 0 82 | parameters.plot_horizontal = True 83 | logplot.plot_learning_curves([LOG_NAME], parameters) 84 | self.assertFalse(os.path.isfile(PDF_FILE_NAME)) 85 | 86 | parameters.extrapolate_y = True 87 | parameters.plot_mean = True 88 | parameters.plot_std = True 89 | parameters.plot_ind = True 90 | parameters.save_fig = True 91 | parameters.x_max = 100 92 | parameters.plot_horizontal = True 93 | parameters.save_fig = True 94 | logplot.plot_learning_curves([LOG_NAME], parameters) 95 | self.assertTrue(os.path.isfile(PDF_FILE_NAME)) 96 | os.remove(PDF_FILE_NAME) 97 | 98 | parameters.x_max = 10 99 | parameters.plot_horizontal = False 100 | logplot.plot_learning_curves([LOG_NAME], parameters) 101 | assert os.path.isfile(PDF_FILE_NAME) 102 | 103 | os.remove(PDF_FILE_NAME) 104 | try: 105 | shutil.rmtree(logplot.get_log_folder(LOG_NAME)) 106 | except FileNotFoundError: 107 | pass 108 | 109 | 110 | if __name__ == '__main__': 111 | unittest.main() 112 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/tests/test_parse_operation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | 
import paths 4 | paths.add_modules_to_path() 5 | 6 | import unittest 7 | from behavior_tree_learning.core.sbt.parse_operation import parse_function 8 | 9 | 10 | class TestParseFunctionAsText(unittest.TestCase): # parse_function('name[ARG:type, ...] => [RET:type, ...]') returns (name, arguments, return_value); arguments is None for an empty [] list, return_value is None when '=> [...]' is omitted, an empty '=> []' is rejected, and extra whitespace is tolerated. 11 | 12 | @staticmethod 13 | def _execute_test(test_name, text): # Debug helper that prints parse_function's output; not called by any test here. 14 | 15 | print("\n") 16 | print(test_name) 17 | print(" + original: ", text) 18 | name, arguments, return_value = parse_function(text) 19 | print(" + Result") 20 | print(" + name: ", name) 21 | print(" + arguments: ", arguments) 22 | print(" + return value: ", return_value) 23 | 24 | def test_correct_syntax_using_without_arguments_1(self): 25 | 26 | text = "a_function[] => [E :bool]" 27 | name, arguments, return_value = parse_function(text) 28 | 29 | self.assertEqual(name, "a_function") 30 | self.assertEqual(arguments, None) 31 | self.assertEqual(return_value, {'E': 'bool'}) 32 | 33 | def test_correct_syntax_using_without_arguments_2(self): 34 | 35 | text = "a_function []=>[E :bool]" 36 | name, arguments, return_value = parse_function(text) 37 | 38 | self.assertEqual(name, "a_function") 39 | self.assertEqual(arguments, None) 40 | self.assertEqual(return_value, {'E': 'bool'}) 41 | 42 | def test_correct_syntax_using_one_argument(self): 43 | 44 | text = "another_function[A :place] => [E :bool]" 45 | name, arguments, return_value = parse_function(text) 46 | 47 | self.assertEqual(name, "another_function") 48 | self.assertEqual(arguments, {'A': 'place'}) 49 | self.assertEqual(return_value, {'E': 'bool'}) 50 | 51 | def test_correct_syntax_using_two_argument(self): 52 | 53 | text = "another_function[A:place, B :pose]=>[E :bool]" 54 | name, arguments, return_value = parse_function(text) 55 | 56 | self.assertEqual(name, "another_function") 57 | self.assertEqual(arguments, {'A': 'place', 'B': 'pose'}) 58 | self.assertEqual(return_value, {'E': 'bool'}) 59 | 60 | def test_correct_syntax_without_return_value_1(self): 61 | text = "another_function[]" 62 | name, arguments, return_value = parse_function(text)
63 | 64 | self.assertEqual(name, "another_function") 65 | self.assertEqual(arguments, None) 66 | self.assertEqual(return_value, None) 67 | 68 | def test_correct_syntax_without_return_value_2(self): 69 | 70 | text = "another_function[A:place]" 71 | name, arguments, return_value = parse_function(text) 72 | 73 | self.assertEqual(name, "another_function") 74 | self.assertEqual(arguments, {'A': 'place'}) 75 | self.assertEqual(return_value, None) 76 | 77 | def test_correct_syntax_without_return_value_3(self): 78 | 79 | text = "another_function[A:place, B :pose]" 80 | name, arguments, return_value = parse_function(text) 81 | 82 | self.assertEqual(name, "another_function") 83 | self.assertEqual(arguments, {'A': 'place', 'B': 'pose'}) 84 | self.assertEqual(return_value, None) 85 | 86 | def test_correct_syntax_with_two_return_values(self): 87 | 88 | text = "another_function[] => [E :bool, F:place]" 89 | name, arguments, return_value = parse_function(text) 90 | 91 | self.assertEqual(name, "another_function") 92 | self.assertEqual(arguments, None) 93 | self.assertEqual(return_value, {'E': 'bool', 'F': 'place'}) 94 | 95 | def test_wrong_syntax_return_value_without_parenthesis(self): # NOTE(review): assertRaises(Exception) accepts any error raised by parse_function; a specific exception type would make these wrong-syntax tests stricter. 96 | 97 | with self.assertRaises(Exception): 98 | text = "another_function[] => E :bool" 99 | parse_function(text) 100 | 101 | def test_wrong_syntax_function_declaration_without_parenthesis(self): 102 | 103 | with self.assertRaises(Exception): 104 | text = "another_function => (E :bool)" 105 | parse_function(text) 106 | 107 | def test_wrong_syntax_return_values_only_parenthesis(self): 108 | 109 | with self.assertRaises(Exception): 110 | text = "another_function[] => []" 111 | parse_function(text) 112 | 113 | 114 | if __name__ == '__main__': 115 | unittest.main() 116 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/tests/test_sbt_btsr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env
python 2 | 3 | import paths 4 | paths.add_modules_to_path() 5 | 6 | import unittest 7 | import os 8 | import random 9 | from behavior_tree_learning.core.sbt import behavior_tree 10 | from behavior_tree_learning.core.sbt import BehaviorTreeStringRepresentation, StringBehaviorTree 11 | 12 | 13 | class TestBehaviorTreeStringRepresentation(unittest.TestCase): 14 | 15 | def setUp(self): 16 | 17 | behavior_tree.load_settings_from_file( 18 | os.path.join(paths.get_test_directory(), 'fwk', 'BT_SETTINGS.yaml')) 19 | 20 | def test_init(self): 21 | 22 | _ = BehaviorTreeStringRepresentation([]) 23 | 24 | def test_random(self): 25 | 26 | btsr = BehaviorTreeStringRepresentation([]) 27 | random.seed(1337) 28 | 29 | for length in range(1, 11): 30 | btsr.random(length) 31 | assert btsr.length() == length 32 | assert btsr.is_valid() 33 | 34 | def test_is_valid(self): 35 | 36 | btsr = BehaviorTreeStringRepresentation([]) 37 | self.assertFalse(btsr.is_valid()) 38 | 39 | # Valid tree 40 | btsr.set(['s(', 'c0', 'f(', 'c0', 'a0', ')', 'a0', ')']) 41 | self.assertTrue(btsr.is_valid()) 42 | 43 | # Minimal valid tree - just an action node 44 | btsr.set(['a0']) 45 | self.assertTrue(btsr.is_valid()) 46 | 47 | # Two control nodes at root level - not valid 48 | btsr.set(['s(', 'c0', 'f(', 'c0', 'a0', ')', 'a0', ')', 's(', 'a0', ')']) 49 | self.assertFalse(btsr.is_valid()) 50 | 51 | # Action node at root level - not valid 52 | btsr.set(['s(', 'c0', 'f(', 'c0', 'a0', ')', ')', 'a0', ')']) 53 | self.assertFalse(btsr.is_valid()) 54 | 55 | # Too few up nodes - not valid 56 | btsr.set(['s(', 'c0', 'f(', 'c0', 'a0', ')', 'a0']) 57 | self.assertFalse(btsr.is_valid()) 58 | 59 | # Too few up nodes - not valid 60 | btsr.set(['s(', 'c0', 'f(', 'c0', 'a0', ')']) 61 | self.assertFalse(btsr.is_valid()) 62 | 63 | # No control nodes, but more than one action - not valid 64 | btsr.set(['a0', 'a0']) 65 | self.assertFalse(btsr.is_valid()) 66 | 67 | # Starts with an up node - not valid 68 | btsr.set([')', 'f(', 
'c0', 'a0', ')']) 69 | self.assertFalse(btsr.is_valid()) 70 | 71 | # Just a control node - not valid 72 | btsr.set(['s(', ')']) 73 | self.assertFalse(btsr.is_valid()) 74 | 75 | # Just a control node - not valid 76 | btsr.set(['s(', 's(']) 77 | self.assertFalse(btsr.is_valid()) 78 | 79 | # Up just after control node 80 | btsr.set(['s(', 'f(', ')', 'a0', ')']) 81 | self.assertFalse(btsr.is_valid()) 82 | 83 | # Unknown characters 84 | btsr.set(['s(', 'c0', 'x', 'y', 'z', ')']) 85 | self.assertFalse(btsr.is_valid()) 86 | 87 | def test_subtree_is_valid(self): 88 | 89 | btsr = BehaviorTreeStringRepresentation([]) 90 | 91 | self.assertTrue(btsr.is_subtree_valid(['s(', 'f(', 'a0', ')', ')', ')'], True, True)) 92 | 93 | self.assertFalse(btsr.is_subtree_valid(['s(', 'f(', 'a0', ')', ')', ')'], True, False)) 94 | 95 | self.assertFalse(btsr.is_subtree_valid(['f(', 's(', 'a0', ')', ')', ')'], False, True)) 96 | 97 | self.assertFalse(btsr.is_subtree_valid(['f(', 'f(', 'a0', ')', ')', ')'], True, True)) 98 | 99 | self.assertFalse(btsr.is_subtree_valid(['s(', 's(', 'a0', ')', ')', ')'], True, True)) 100 | 101 | self.assertFalse(btsr.is_subtree_valid(['s(', 'f(', 'a0', ')', ')'], True, True)) 102 | 103 | self.assertTrue(btsr.is_subtree_valid(['s(', 'f(', 'c0', ')', ')', ')'], True, True)) 104 | 105 | def test_close(self): 106 | """ Tests close function """ 107 | 108 | btsr = BehaviorTreeStringRepresentation([]) 109 | 110 | btsr.close() 111 | self.assertEqual(btsr.bt, []) 112 | 113 | # Correct tree with just one action 114 | btsr.set(['a0']).close() 115 | self.assertEqual(btsr.bt, ['a0']) 116 | 117 | # Correct tree 118 | btsr.set(['s(', 's(', 'a0', ')', ')']).close() 119 | self.assertEqual(btsr.bt, ['s(', 's(', 'a0', ')', ')']) 120 | 121 | # Missing up at end 122 | btsr.set(['s(', 's(', 'a0', ')', 's(', 'a0', 's(', 'a0']).close() 123 | self.assertEqual(btsr.bt, ['s(', 's(', 'a0', ')', 's(', 'a0', 's(', 'a0', ')', ')', ')']) 124 | 125 | # Too many up at end 126 | btsr.set(['s(', 
'a0', ')', ')', ')']).close() 127 | self.assertEqual(btsr.bt, ['s(', 'a0', ')']) 128 | 129 | # Too many up but not at the end 130 | btsr.set(['s(', 's(', 'a0', ')', ')', ')', 'a1', ')']).close() 131 | self.assertEqual(btsr.bt, ['s(', 's(', 'a0', ')', 'a1', ')']) 132 | 133 | def test_trim(self): 134 | """ Tests trim function """ 135 | 136 | btsr = BehaviorTreeStringRepresentation([]) 137 | 138 | btsr.set(['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', 's(', 'a0', ')', ')']) 139 | btsr.trim() 140 | self.assertEqual(btsr.bt, ['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', 'a0', ')']) 141 | 142 | btsr.set(['s(', 'a0', 'f(', ')', 'a0', 's(', 'a0', ')', ')']) 143 | btsr.trim() 144 | self.assertEqual(btsr.bt, ['s(', 'a0', 'a0', 'a0', ')']) 145 | 146 | btsr.set(['s(', 'a0', 'f(', 'a1', 's(', 'a2', ')', 'a3', ')', 'a4', ')']) 147 | btsr.trim() 148 | self.assertEqual(btsr.bt, ['s(', 'a0', 'f(', 'a1', 'a2', 'a3', ')', 'a4', ')']) 149 | 150 | btsr.set(['s(', 'a0', 'f(', 's(', 'a2', 'a3', ')', ')', 'a4', ')']) 151 | btsr.trim() 152 | self.assertEqual(btsr.bt, ['s(', 'a0', 'a2', 'a3', 'a4', ')']) 153 | 154 | btsr.set(['s(', 'a0', ')']) 155 | btsr.trim() 156 | self.assertEqual(btsr.bt, ['s(', 'a0', ')']) 157 | 158 | def test_depth(self): 159 | """ Tests bt_depth function """ 160 | 161 | btsr = BehaviorTreeStringRepresentation([]) 162 | 163 | # Normal correct tree 164 | btsr.set(['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')']) 165 | self.assertEqual(btsr.depth(), 2) 166 | 167 | # Goes to 0 before last node - invalid 168 | btsr.set(['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')', 's(', 'a0', ')']) 169 | self.assertEqual(btsr.depth(), -1) 170 | 171 | # Goes to 0 before last node - invalid 172 | btsr.set(['s(', 'a0', 'f(', 'a0', 'a0', ')', ')', 'a0', ')']) 173 | self.assertEqual(btsr.depth(), -1) 174 | 175 | # Goes to 0 before last node - invalid 176 | btsr.set(['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0']) 177 | self.assertEqual(btsr.depth(), -1) 178 | 179 | # Just an action node - no depth 180 | 
btsr.set(['a0']) 181 | self.assertEqual(btsr.depth(), 0) 182 | 183 | def test_length(self): 184 | """ Tests bt_length function """ 185 | 186 | btsr = BehaviorTreeStringRepresentation([]) 187 | 188 | btsr.set(['s(', 'a0', 'a1', ')']) 189 | self.assertEqual(btsr.length(), 3) 190 | 191 | btsr.set(['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')']) 192 | self.assertEqual(btsr.length(), 6) 193 | 194 | btsr.set(['s(', ')']) 195 | self.assertEqual(btsr.length(), 1) 196 | 197 | btsr.set(['a0']) 198 | self.assertEqual(btsr.length(), 1) 199 | 200 | def test_change_node(self): 201 | """ Tests change_node function """ 202 | 203 | btsr = BehaviorTreeStringRepresentation([]) 204 | random.seed(1337) 205 | 206 | # No new node given, change to random node 207 | btsr.set(['s(', 'a0', 'a0', ')']).change_node(2) 208 | self.assertNotEqual(btsr.bt[2], 'a0') 209 | 210 | # Change control node to action node 211 | btsr.set(['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')']).change_node(2, 'a0') 212 | self.assertEqual(btsr.bt, ['s(', 'a0', 'a0', 'a0', ')']) 213 | 214 | # Change control node to action node - correct up must be removed too 215 | btsr.set(['s(', 'a0', 'f(', 's(', 'a0', ')', 'a0', ')', 'a0', ')']).change_node(2, 'a0') 216 | self.assertEqual(btsr.bt, ['s(', 'a0', 'a0', 'a0', ')']) 217 | 218 | btsr.set(['s(', 'a0', 'f(', 's(', 'a0', ')', 'a1', ')', 'a0', ')']).change_node(3, 'a0') 219 | self.assertEqual(btsr.bt, ['s(', 'a0', 'f(', 'a0', 'a1', ')', 'a0', ')']) 220 | 221 | # Change action node to control node 222 | btsr.set(['s(', 'a0', 'a0', ')']).change_node(1, 'f(') 223 | self.assertEqual(btsr.bt, ['s(', 'f(', 'a0', 'a0', ')', 'a0', ')']) 224 | 225 | # Change action node to action node 226 | btsr.set(['s(', 'a0', 'a0', ')']).change_node(1, 'a1') 227 | self.assertEqual(btsr.bt, ['s(', 'a1', 'a0', ')']) 228 | 229 | # Change control node to control node 230 | btsr.set(['s(', 'a0', 'a0', ')']).change_node(0, 'f(') 231 | self.assertEqual(btsr.bt, ['f(', 'a0', 'a0', ')']) 232 | 233 | # Change 
up node, not possible 234 | btsr.set(['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')']).change_node(5, 'a0') 235 | self.assertEqual(btsr.bt, ['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')']) 236 | 237 | def test_add_node(self): 238 | """ Tests add_node function """ 239 | 240 | btsr = BehaviorTreeStringRepresentation([]) 241 | random.seed(1337) 242 | 243 | btsr.set(['a0']).add_node(0, 's(') 244 | self.assertEqual(btsr.bt, ['s(', 'a0', ')']) 245 | 246 | btsr.set(['s(', 'a0', 'a0', ')']).add_node(2) 247 | self.assertEqual(btsr.bt, ['s(', 'a0', 'a3', 'a0', ')']) 248 | 249 | btsr.set(['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')']).add_node(2, 'a0') 250 | self.assertEqual(btsr.bt, ['s(', 'a0', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')']) 251 | 252 | btsr.set(['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')']).add_node(3, 'a0') 253 | self.assertEqual(btsr.bt, ['s(', 'a0', 'f(', 'a0', 'a0', 'a0', ')', 'a0', ')']) 254 | 255 | btsr.set(['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')']).add_node(0, 'f(') 256 | self.assertEqual(btsr.bt, ['f(', 's(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')', ')']) 257 | 258 | btsr.set(['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')']).add_node(4, 's(') 259 | self.assertTrue(btsr.is_valid()) 260 | 261 | btsr.set(['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')']).add_node(2, 'f(') 262 | self.assertTrue(btsr.is_valid()) 263 | 264 | btsr.set(['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')']).add_node(1, 'f(') 265 | self.assertTrue(btsr.is_valid()) 266 | 267 | btsr.set(['s(', 'a0', 'f(', 'c1', 'a0', ')', ')']).add_node(2, 'f(') 268 | self.assertTrue(btsr.is_valid()) 269 | 270 | def test_delete_node(self): 271 | """ Tests delete_node function """ 272 | 273 | btsr = BehaviorTreeStringRepresentation([]) 274 | 275 | btsr.set(['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')']).delete_node(0) 276 | self.assertEqual(btsr.bt, []) 277 | 278 | btsr.set(['s(', 'a0', 'f(', 'a0', 'a0', ')', 's(', 'a0', ')', ')']).delete_node(0) 279 | self.assertEqual(btsr.bt, []) 280 | 281 | btsr.set(['s(', 
'a0', 'f(', 'a0', 's(', 'a0', ')', ')', 's(', 'a0', ')', ')']).delete_node(0) 282 | self.assertEqual(btsr.bt, []) 283 | 284 | btsr.set(['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')']).delete_node(1) 285 | self.assertEqual(btsr.bt, ['s(', 'f(', 'a0', 'a0', ')', 'a0', ')']) 286 | 287 | btsr.set(['s(', 'a0', 'f(', 'a1', 'a2', ')', 'a3', ')']).delete_node(2) 288 | self.assertEqual(btsr.bt, ['s(', 'a0', 'a3', ')']) 289 | 290 | btsr.set(['s(', 'a0', 'f(', 'a0', ')', 'a0', ')']).delete_node(3) 291 | self.assertEqual(btsr.bt, ['s(', 'a0', 'f(', ')', 'a0', ')']) 292 | 293 | btsr.set(['s(', 'a0', ')']).delete_node(2) 294 | self.assertEqual(btsr.bt, ['s(', 'a0', ')']) 295 | 296 | def test_find_parent(self): 297 | """ Tests find_parent function """ 298 | 299 | btsr = BehaviorTreeStringRepresentation([]) 300 | btsr.set(['s(', 'a0', 'f(', 'a0', ')', 'a0', ')']) 301 | 302 | self.assertEqual(btsr.find_parent(0), None) 303 | self.assertEqual(btsr.find_parent(1), 0) 304 | self.assertEqual(btsr.find_parent(2), 0) 305 | self.assertEqual(btsr.find_parent(3), 2) 306 | self.assertEqual(btsr.find_parent(4), 2) 307 | self.assertEqual(btsr.find_parent(5), 0) 308 | 309 | def test_find_children(self): 310 | """ Tests find_children function """ 311 | 312 | btsr = BehaviorTreeStringRepresentation([]) 313 | btsr.set(['s(', 'a0', 'f(', 'a0', ')', 'a0', ')']) 314 | 315 | self.assertEqual(btsr.find_children(0), [1, 2, 5]) 316 | self.assertEqual(btsr.find_children(1), []) 317 | self.assertEqual(btsr.find_children(2), [3]) 318 | self.assertEqual(btsr.find_children(3), []) 319 | self.assertEqual(btsr.find_children(4), []) 320 | self.assertEqual(btsr.find_children(5), []) 321 | 322 | def test_find_up_node(self): 323 | """ Tests find_up_node function """ 324 | 325 | btsr = BehaviorTreeStringRepresentation([]) 326 | 327 | btsr.set(['s(', 'a0', 'f(', 'a0', ')', 'a0', ')']) 328 | self.assertEqual(btsr.find_up_node(0), 6) 329 | 330 | btsr.set(['s(', 'a0', 'f(', 'a0', ')', 'a0', ')']) 331 | 
self.assertEqual(btsr.find_up_node(2), 4) 332 | 333 | btsr.set(['s(', 'a0', 'f(', 's(', 'a0', ')', 'a0', ')']) 334 | self.assertEqual(btsr.find_up_node(2), 7) 335 | 336 | btsr.set(['s(', 'a0', 'f(', 'a0', ')', 'a0', ')']) 337 | with self.assertRaises(Exception): 338 | _ = btsr.find_up_node(1) 339 | 340 | btsr.set(['s(', 'a0', 'f(', 'a0', ')', 'a0']) 341 | with self.assertRaises(Exception): 342 | _ = btsr.find_up_node(0) 343 | 344 | btsr.set(['s(', 's(', 'a0', 'f(', 'a0', ')', 'a0']) 345 | with self.assertRaises(Exception): 346 | _ = btsr.find_up_node(1) 347 | 348 | def test_get_subtree(self): 349 | """ Tests get_subtree function """ 350 | 351 | btsr = BehaviorTreeStringRepresentation(['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')']) 352 | 353 | subtree = btsr.get_subtree(1) 354 | self.assertEqual(subtree, ['a0']) 355 | 356 | btsr.set(['s(', 'a0', 'f(', 'a0', 'a0', ')', 's(', 'a0', 'a0', ')', ')']) 357 | subtree = btsr.get_subtree(6) 358 | self.assertEqual(subtree, ['s(', 'a0', 'a0', ')']) 359 | 360 | btsr.set(['s(', 'a0', 'f(', 'a0', 'a0', ')', ')']) 361 | subtree = btsr.get_subtree(2) 362 | self.assertEqual(subtree, ['f(', 'a0', 'a0', ')']) 363 | 364 | subtree = btsr.get_subtree(5) 365 | self.assertEqual(subtree, []) 366 | 367 | def test_insert_subtree(self): 368 | 369 | btsr = BehaviorTreeStringRepresentation(['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')']) 370 | 371 | btsr.insert_subtree(['f(', 'a1', ')'], 1) 372 | self.assertEqual(btsr.bt, ['s(', 'f(', 'a1', ')', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')']) 373 | 374 | btsr.insert_subtree(['f(', 'a1', ')'], 6) 375 | self.assertEqual(btsr.bt, ['s(', 'f(', 'a1', ')', 'a0', 'f(', 'f(', 'a1', ')', 'a0', 'a0', ')', 'a0', ')']) 376 | 377 | def test_swap_subtrees(self): 378 | 379 | btsr_1 = BehaviorTreeStringRepresentation(['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')']) 380 | btsr_2 = BehaviorTreeStringRepresentation(['s(', 'a0', 'f(', 'a0', 'a0', ')', 's(', 'a0', 'a0', ')', ')']) 381 | btsr_1.swap_subtrees(btsr_2, 6, 6) 382 
| 383 | self.assertEqual(btsr_1.bt, ['s(', 'a0', 'f(', 'a0', 'a0', ')', 's(', 'a0', 'a0', ')', ')']) 384 | self.assertEqual(btsr_2.bt, ['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')']) 385 | 386 | # Invalid subtree because it's an up node, no swap 387 | btsr_1.set(['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')']) 388 | btsr_2.set(['s(', 'a0', 'f(', 'a0', 'a0', ')', 's(', 'a0', 'a0', ')', ')']) 389 | btsr_1.swap_subtrees(btsr_2, 5, 6) 390 | 391 | self.assertEqual(btsr_1.bt, ['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')']) 392 | self.assertEqual(btsr_2.bt, ['s(', 'a0', 'f(', 'a0', 'a0', ')', 's(', 'a0', 'a0', ')', ')']) 393 | 394 | def test_is_subtree(self): 395 | 396 | btsr = BehaviorTreeStringRepresentation(['s(', 'a0', 'f(', 'a0', 'a0', ')', 'a0', ')']) 397 | 398 | self.assertTrue(btsr.is_subtree(0)) 399 | self.assertTrue(btsr.is_subtree(1)) 400 | self.assertFalse(btsr.is_subtree(5)) 401 | 402 | 403 | if __name__ == '__main__': 404 | unittest.main() 405 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/tests/test_sbt_node_factory.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import paths 4 | 5 | paths.add_modules_to_path() 6 | 7 | import unittest 8 | 9 | from behavior_tree_learning.core.sbt import BehaviorTreeStringRepresentation 10 | from behavior_tree_learning.core.sbt import BehaviorNodeFactory 11 | from tests.fwk.behavior_nodes import get_behaviors 12 | 13 | 14 | class TestBehaviorNodeFactory(unittest.TestCase): 15 | 16 | def setUp(self) -> None: 17 | 18 | from behavior_tree_learning.core.sbt.behavior_tree import _clean_settings 19 | _clean_settings() 20 | 21 | def test_btsr_does_not_work_without_factory(self): 22 | 23 | btsr = BehaviorTreeStringRepresentation(['s(', 'c0', 'f(', 'c0', 'a0', ')', 'a0', ')']) 24 | self.assertFalse(btsr.is_valid()) 25 | 26 | def test_btsr_work_after_initialize_factory(self): 27 | 28 | node_factory = 
BehaviorNodeFactory(get_behaviors()) 29 | btsr = BehaviorTreeStringRepresentation(['s(', 'c0', 'f(', 'c0', 'a0', ')', 'a0', ')']) 30 | self.assertTrue(btsr.is_valid()) 31 | 32 | 33 | if __name__ == '__main__': 34 | unittest.main() 35 | -------------------------------------------------------------------------------- /src/behavior_tree_learning/tests/test_sbt_pytree.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import paths 4 | paths.add_modules_to_path() 5 | 6 | import unittest 7 | 8 | from behavior_tree_learning.core.sbt import StringBehaviorTree 9 | from behavior_tree_learning.core.sbt import BehaviorNodeFactory 10 | from behavior_tree_learning.core.plotter import print_ascii_tree 11 | from tests.fwk.behavior_nodes import get_behaviors 12 | 13 | 14 | class TestStringBehaviorTree(unittest.TestCase): 15 | 16 | def setUp(self) -> None: 17 | 18 | self._node_factory = BehaviorNodeFactory(get_behaviors()) 19 | 20 | def test_bt_from_str(self): 21 | 22 | sbt = ['f(', 'c0', 'c0', ')'] 23 | bt = StringBehaviorTree(sbt, behaviors=self._node_factory) 24 | self.assertEqual(sbt, []) 25 | self.assertEqual(len(bt.root.children), 2) 26 | print_ascii_tree(bt) 27 | 28 | sbt = ['f(', 'f(', 'c0', 'c0', ')', ')'] 29 | bt = StringBehaviorTree(sbt, behaviors=self._node_factory) 30 | self.assertEqual(sbt, []) 31 | self.assertEqual(len(bt.root.children), 1) 32 | print_ascii_tree(bt) 33 | 34 | sbt = ['f(', 'f(', 'c0', 'c0', ')', 's(', 'c0', 'c0', ')', ')'] 35 | bt = StringBehaviorTree(sbt, behaviors=self._node_factory) 36 | self.assertEqual(sbt, []) 37 | self.assertEqual(len(bt.root.children), 2) 38 | print_ascii_tree(bt) 39 | 40 | sbt = ['f(', 'f(', 'c0', 'c0', ')', 'f(', 's(', 'c0', ')', ')', ')'] 41 | bt = StringBehaviorTree(sbt, behaviors=self._node_factory) 42 | self.assertEqual(sbt, []) 43 | self.assertEqual(len(bt.root.children), 2) 44 | print_ascii_tree(bt) 45 | 46 | sbt = ['f(', 'f(', 'c0', 'c0', ')'] 47 | bt 
= StringBehaviorTree(sbt, behaviors=self._node_factory) 48 | self.assertEqual(sbt, []) 49 | self.assertEqual(len(bt.root.children), 1) 50 | print_ascii_tree(bt) 51 | 52 | sbt = ['f(', 'f(', 'c0', ')', ')', 'c0', ')'] 53 | bt = StringBehaviorTree(sbt, behaviors=self._node_factory) 54 | self.assertEqual(sbt, ['c0', ')']) 55 | self.assertEqual(len(bt.root.children), 1) 56 | print_ascii_tree(bt) 57 | 58 | with self.assertRaises(Exception): 59 | StringBehaviorTree(['nonbehavior'], behaviors=self._node_factory) 60 | 61 | with self.assertRaises(Exception): 62 | StringBehaviorTree(['f(', 'nonpytreesbehavior', ')'], behaviors=self._node_factory) 63 | 64 | def test_str_from_bt(self): 65 | 66 | sbt = ['f(', 'f(', 'c0', 'c0', ')', 'f(', 's(', 'c0', ')', ')', ')'] 67 | bt = StringBehaviorTree(sbt[:], behaviors=self._node_factory) 68 | self.assertEqual(bt.to_string(), sbt) 69 | 70 | sbt = ['f(', 'f(', 'c0', ')', 'c0', ')'] 71 | bt = StringBehaviorTree(sbt[:], behaviors=self._node_factory) 72 | self.assertEqual(bt.to_string(), sbt) 73 | 74 | 75 | if __name__ == '__main__': 76 | unittest.main() 77 | --------------------------------------------------------------------------------