├── .github
│   └── workflows
│       ├── ros-ci.yml
│       └── ros-lint.yml
├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── turtlebot3_dqn
│   ├── CHANGELOG.rst
│   ├── package.xml
│   ├── resource
│   │   └── turtlebot3_dqn
│   ├── saved_model
│   │   ├── stage1_episode600.h5
│   │   └── stage1_episode600.json
│   ├── setup.cfg
│   ├── setup.py
│   └── turtlebot3_dqn
│       ├── __init__.py
│       ├── action_graph.py
│       ├── dqn_agent.py
│       ├── dqn_environment.py
│       ├── dqn_gazebo.py
│       ├── dqn_test.py
│       └── result_graph.py
├── turtlebot3_machine_learning
│   ├── CHANGELOG.rst
│   ├── CMakeLists.txt
│   └── package.xml
└── turtlebot3_machine_learning_ci.repos

/.github/workflows/ros-ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 | 
3 | on:
4 |   push:
5 |     branches: [ humble, jazzy, main ]
6 |   pull_request:
7 |     branches: [ humble, jazzy, main ]
8 | 
9 | jobs:
10 |   ROS_CI:
11 |     runs-on: ubuntu-22.04
12 |     strategy:
13 |       fail-fast: false
14 |       matrix:
15 |         ros_distribution:
16 |           - humble
17 |           - jazzy
18 |           - rolling
19 |         include:
20 |           # ROS 2 Humble Hawksbill
21 |           - docker_image: ubuntu:jammy
22 |             ros_distribution: humble
23 |             ros_version: 2
24 |           # ROS 2 Jazzy Jalisco
25 |           - docker_image: ubuntu:noble
26 |             ros_distribution: jazzy
27 |             ros_version: 2
28 |           # ROS 2 Rolling Ridley
29 |           - docker_image: ubuntu:noble
30 |             ros_distribution: rolling
31 |             ros_version: 2
32 |     container:
33 |       image: ${{ matrix.docker_image }}
34 |     steps:
35 |       - name: Setup workspace
36 |         run: mkdir -p ros_ws/src
37 | 
38 |       - name: Checkout code
39 |         uses: actions/checkout@v4
40 |         with:
41 |           path: ros_ws/src
42 | 
43 |       - name: Setup ROS environment
44 |         uses: ros-tooling/setup-ros@v0.7
45 |         with:
46 |           required-ros-distributions: ${{ matrix.ros_distribution }}
47 | 
48 |       - name: Add pip break-system-packages for rosdep
49 |         run: |
50 |           printf "[install]\nbreak-system-packages = true\n" | sudo tee /etc/pip.conf
51 | 
52 |       - name: Build and Test
53 |         uses: ros-tooling/action-ros-ci@v0.3
54 |         env:
55 |           PIP_BREAK_SYSTEM_PACKAGES: "1"
56 |         with:
57 |           target-ros2-distro: ${{ matrix.ros_distribution }}
58 |           vcs-repo-file-url: "https://raw.githubusercontent.com/ROBOTIS-GIT/turtlebot3_machine_learning/main/turtlebot3_machine_learning_ci.repos"
59 |           package-name: |
60 |             turtlebot3_dqn
61 |             turtlebot3_machine_learning
--------------------------------------------------------------------------------
/.github/workflows/ros-lint.yml:
--------------------------------------------------------------------------------
1 | # The name of the workflow
2 | name: Lint
3 | 
4 | # Specifies the events that trigger the workflow
5 | on:
6 |   pull_request:
7 | 
8 | # Defines a set of jobs to be run as part of the workflow
9 | jobs:
10 |   ament_lint:
11 |     runs-on: ubuntu-latest
12 |     container:
13 |       image: rostooling/setup-ros-docker:ubuntu-noble-ros-rolling-ros-base-latest
14 |     strategy:
15 |       fail-fast: false
16 |       matrix:
17 |         linter: [flake8, pep257, xmllint, copyright]
18 |     steps:
19 |       - name: Checkout code
20 |         uses: actions/checkout@v4
21 | 
22 |       - name: Setup ROS environment
23 |         uses: ros-tooling/setup-ros@v0.7
24 | 
25 |       - name: Run Linter
26 |         env:
27 |           AMENT_CPPCHECK_ALLOW_SLOW_VERSIONS: 1
28 |         uses: ros-tooling/action-ros-lint@master
29 |         with:
30 |           linter: ${{ matrix.linter }}
31 |           distribution: rolling
32 |           package-name: "*"
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | 
*.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | 107 | # own 108 | /docker_setup/backup 109 | *.idea 110 | *.log 111 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Any contribution that you make to this repository will 2 | be under the Apache 2 License, as dictated by that 3 | [license](http://www.apache.org/licenses/LICENSE-2.0.html): 4 | 5 | ~~~ 6 | 5. Submission of Contributions. Unless You explicitly state otherwise, 7 | any Contribution intentionally submitted for inclusion in the Work 8 | by You to the Licensor shall be under the terms and conditions of 9 | this License, without any additional terms or conditions. 10 | Notwithstanding the above, nothing herein shall supersede or modify 11 | the terms of any separate license agreement you may have executed 12 | with Licensor regarding such Contributions. 13 | ~~~ 14 | 15 | Contributors must sign-off each commit by adding a `Signed-off-by: ...` 16 | line to commit messages to certify that they have the right to submit 17 | the code they are contributing to the project according to the 18 | [Developer Certificate of Origin (DCO)](https://developercertificate.org/). 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TurtleBot3 2 | 3 | 4 | - Active Branches: noetic, humble, jazzy, main(rolling) 5 | - Legacy Branches: *-devel 6 | 7 | ## Open Source Projects Related to TurtleBot3 8 | - [turtlebot3](https://github.com/ROBOTIS-GIT/turtlebot3) 9 | - [turtlebot3_msgs](https://github.com/ROBOTIS-GIT/turtlebot3_msgs) 10 | - [turtlebot3_simulations](https://github.com/ROBOTIS-GIT/turtlebot3_simulations) 11 | - [turtlebot3_manipulation](https://github.com/ROBOTIS-GIT/turtlebot3_manipulation) 12 | - [turtlebot3_manipulation_simulations](https://github.com/ROBOTIS-GIT/turtlebot3_manipulation_simulations) 13 | - [turtlebot3_applications](https://github.com/ROBOTIS-GIT/turtlebot3_applications) 14 | - [turtlebot3_applications_msgs](https://github.com/ROBOTIS-GIT/turtlebot3_applications_msgs) 15 | - [turtlebot3_machine_learning](https://github.com/ROBOTIS-GIT/turtlebot3_machine_learning) 16 | - [turtlebot3_autorace](https://github.com/ROBOTIS-GIT/turtlebot3_autorace) 17 | - [turtlebot3_home_service_challenge](https://github.com/ROBOTIS-GIT/turtlebot3_home_service_challenge) 18 | - [hls_lfcd_lds_driver](https://github.com/ROBOTIS-GIT/hls_lfcd_lds_driver) 19 | - [ld08_driver](https://github.com/ROBOTIS-GIT/ld08_driver) 20 | - [open_manipulator](https://github.com/ROBOTIS-GIT/open_manipulator) 21 | - [dynamixel_sdk](https://github.com/ROBOTIS-GIT/DynamixelSDK) 22 | - [OpenCR-Hardware](https://github.com/ROBOTIS-GIT/OpenCR-Hardware) 23 | - [OpenCR](https://github.com/ROBOTIS-GIT/OpenCR) 24 | 25 | ## Documentation, Videos, and Community 26 | 27 | ### Official Documentation 28 | - ⚙️ **[ROBOTIS DYNAMIXEL](https://dynamixel.com/)** 29 | - 📚 **[ROBOTIS e-Manual for Dynamixel SDK](http://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/overview/)** 30 | - 📚 **[ROBOTIS e-Manual for TurtleBot3](http://turtlebot3.robotis.com/)** 31 | - 📚 **[ROBOTIS e-Manual for OpenMANIPULATOR-X](https://emanual.robotis.com/docs/en/platform/openmanipulator_x/overview/)** 32 | 33 | ### Learning Resources 34 | - 🎥 **[ROBOTIS YouTube Channel](https://www.youtube.com/@ROBOTISCHANNEL)** 35 | - 🎥 **[ROBOTIS Open Source YouTube Channel](https://www.youtube.com/@ROBOTISOpenSourceTeam)** 36 | - 🎥 **[ROBOTIS TurtleBot3 YouTube Playlist](https://www.youtube.com/playlist?list=PLRG6WP3c31_XI3wlvHlx2Mp8BYqgqDURU)** 37 | - 🎥 **[ROBOTIS OpenMANIPULATOR YouTube Playlist](https://www.youtube.com/playlist?list=PLRG6WP3c31_WpEsB6_Rdt3KhiopXQlUkb)** 38 | 39 | ### Community & Support 40 | - 💬 **[ROBOTIS Community Forum](https://forum.robotis.com/)** 41 | - 💬 **[TurtleBot category from ROS Community](https://discourse.ros.org/c/turtlebot/)** 42 | -------------------------------------------------------------------------------- /turtlebot3_dqn/CHANGELOG.rst: -------------------------------------------------------------------------------- 1 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 2 | Changelog for package turtlebot3_dqn 3 | 
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
4 | 
5 | 1.0.1 (2025-05-02)
6 | ------------------
7 | * Support for ROS 2 Jazzy version
8 | * Gazebo simulation support for the package
9 | * Contributors: ChanHyeong Lee
10 | 
11 | 1.0.0 (2025-04-17)
12 | ------------------
13 | * Support for ROS 2 Humble version
14 | * Renewal of package structure
15 | * Improved behavioral rewards for agents
16 | * Contributors: ChanHyeong Lee
17 | 
--------------------------------------------------------------------------------
/turtlebot3_dqn/package.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0"?>
2 | <package format="3">
3 |   <name>turtlebot3_dqn</name>
4 |   <version>1.0.1</version>
5 |   <description>
6 |     The turtlebot3_dqn package using reinforcement learning with DQN (Deep Q-Network).
7 |   </description>
8 |   <maintainer email="pyo@robotis.com">Pyo</maintainer>
9 |   <license>Apache 2.0</license>
10 |   <url type="website">http://turtlebot3.robotis.com</url>
11 |   <url type="repository">https://github.com/ROBOTIS-GIT/turtlebot3_machine_learning</url>
12 |   <url type="bugtracker">https://github.com/ROBOTIS-GIT/turtlebot3_machine_learning/issues</url>
13 |   <author email="kkjong@robotis.com">Gilbert</author>
14 |   <author email="dddoggi1207@gmail.com">ChanHyeong Lee</author>
15 |   <depend>python3-pip</depend>
16 |   <depend>ament_index_python</depend>
17 |   <depend>geometry_msgs</depend>
18 |   <depend>python-tensorflow-pip</depend>
19 |   <depend>python3-numpy</depend>
20 |   <depend>python3-pyqt5</depend>
21 |   <depend>python3-pyqtgraph</depend>
22 |   <depend>rclpy</depend>
23 |   <depend>sensor_msgs</depend>
24 |   <depend>std_msgs</depend>
25 |   <depend>std_srvs</depend>
26 |   <depend>turtlebot3_msgs</depend>
27 |   <export>
28 |     <build_type>ament_python</build_type>
29 |   </export>
30 | </package>
31 | 
--------------------------------------------------------------------------------
/turtlebot3_dqn/resource/turtlebot3_dqn:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ROBOTIS-GIT/turtlebot3_machine_learning/95847712f2d06507dffb6c94b2f356ff38dfa79f/turtlebot3_dqn/resource/turtlebot3_dqn
--------------------------------------------------------------------------------
/turtlebot3_dqn/saved_model/stage1_episode600.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ROBOTIS-GIT/turtlebot3_machine_learning/95847712f2d06507dffb6c94b2f356ff38dfa79f/turtlebot3_dqn/saved_model/stage1_episode600.h5
--------------------------------------------------------------------------------
/turtlebot3_dqn/saved_model/stage1_episode600.json:
--------------------------------------------------------------------------------
1 | {"epsilon": 0.1416302983127139}
--------------------------------------------------------------------------------
/turtlebot3_dqn/setup.cfg:
--------------------------------------------------------------------------------
1 | [develop]
2 | script_dir=$base/lib/turtlebot3_dqn
3 | [install]
4 | install_scripts=$base/lib/turtlebot3_dqn
--------------------------------------------------------------------------------
/turtlebot3_dqn/setup.py:
--------------------------------------------------------------------------------
1 | import glob
2 | 
3 | from setuptools import find_packages
4 | from setuptools import setup
5 | 
6 | package_name = 'turtlebot3_dqn'
7 | authors_info = [
8 |     ('Gilbert', 'kkjong@robotis.com'),
9 |     ('Ryan Shim', 'N/A'),
10 |     ('ChanHyeong Lee', 'dddoggi1207@gmail.com'),
11 | ]
12 | authors = ', '.join(author for author, _ in authors_info)
13 | author_emails = ', '.join(email for _, email in authors_info)
14 | 
15 | setup(
16 |     name=package_name,
17 |     version='1.0.1',
18 |     packages=find_packages(),
19 |     data_files=[
20 |         ('share/ament_index/resource_index/packages', ['resource/' + package_name]),
21 |         ('share/' + package_name, ['package.xml']),
22 |         ('share/' + package_name + '/launch', glob.glob('launch/*.py')),
23 |     ],
24 |     install_requires=['setuptools', 'launch'],
25 |     zip_safe=True,
26 |     author=authors,
27 | 
author_email=author_emails, 28 | maintainer='Pyo', 29 | maintainer_email='pyo@robotis.com', 30 | description='ROS 2 packages for TurtleBot3 machine learning', 31 | license='Apache 2.0', 32 | tests_require=['pytest'], 33 | entry_points={ 34 | 'console_scripts': [ 35 | 'action_graph = turtlebot3_dqn.action_graph:main', 36 | 'dqn_agent = turtlebot3_dqn.dqn_agent:main', 37 | 'dqn_environment = turtlebot3_dqn.dqn_environment:main', 38 | 'dqn_gazebo = turtlebot3_dqn.dqn_gazebo:main', 39 | 'dqn_test = turtlebot3_dqn.dqn_test:main', 40 | 'result_graph = turtlebot3_dqn.result_graph:main', 41 | ], 42 | }, 43 | ) 44 | -------------------------------------------------------------------------------- /turtlebot3_dqn/turtlebot3_dqn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ROBOTIS-GIT/turtlebot3_machine_learning/95847712f2d06507dffb6c94b2f356ff38dfa79f/turtlebot3_dqn/turtlebot3_dqn/__init__.py -------------------------------------------------------------------------------- /turtlebot3_dqn/turtlebot3_dqn/action_graph.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | ################################################################################# 3 | # Copyright 2018 ROBOTIS CO., LTD. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | ################################################################################# 17 | # 18 | # Authors: Ryan Shim, Gilbert, ChanHyeong Lee 19 | 20 | import signal 21 | import sys 22 | import threading 23 | import time 24 | 25 | from PyQt5.QtCore import pyqtSignal 26 | from PyQt5.QtCore import Qt 27 | from PyQt5.QtCore import QThread 28 | from PyQt5.QtWidgets import QApplication 29 | from PyQt5.QtWidgets import QGridLayout 30 | from PyQt5.QtWidgets import QLabel 31 | from PyQt5.QtWidgets import QLineEdit 32 | from PyQt5.QtWidgets import QProgressBar 33 | from PyQt5.QtWidgets import QWidget 34 | import rclpy 35 | from rclpy.node import Node 36 | from std_msgs.msg import Float32MultiArray 37 | 38 | 39 | class Ros2Subscriber(Node): 40 | 41 | def __init__(self, qt_thread): 42 | super().__init__('progress_subscriber') 43 | self.qt_thread = qt_thread 44 | 45 | self.subscription = self.create_subscription( 46 | Float32MultiArray, 47 | '/get_action', 48 | self.get_array_callback, 49 | 10 50 | ) 51 | 52 | def get_array_callback(self, msg): 53 | data = list(msg.data) 54 | 55 | self.qt_thread.signal_action0.emit(0) 56 | self.qt_thread.signal_action1.emit(0) 57 | self.qt_thread.signal_action2.emit(0) 58 | self.qt_thread.signal_action3.emit(0) 59 | self.qt_thread.signal_action4.emit(0) 60 | 61 | if data[0] == 0: 62 | self.qt_thread.signal_action0.emit(100) 63 | elif data[0] == 1: 64 | self.qt_thread.signal_action1.emit(100) 65 | elif data[0] == 2: 66 | self.qt_thread.signal_action2.emit(100) 67 | elif data[0] == 3: 68 | self.qt_thread.signal_action3.emit(100) 69 | elif data[0] == 4: 70 | self.qt_thread.signal_action4.emit(100) 71 | 72 | if len(data) >= 2: 73 | self.qt_thread.signal_total_reward.emit(str(round(data[-2], 2))) 74 | self.qt_thread.signal_reward.emit(str(round(data[-1], 2))) 75 | 76 | 77 | class Thread(QThread): 78 | 79 | signal_action0 = pyqtSignal(int) 80 | signal_action1 = pyqtSignal(int) 81 | signal_action2 = pyqtSignal(int) 82 | signal_action3 = pyqtSignal(int) 83 | signal_action4 = pyqtSignal(int) 84 | signal_total_reward = pyqtSignal(str) 85 | signal_reward = pyqtSignal(str) 86 | 87 | def __init__(self): 88 | super().__init__() 89 | self.node = None 90 | 91 | def run(self): 92 | self.node = Ros2Subscriber(self) 93 | rclpy.spin(self.node) 94 | self.node.destroy_node() 95 | 96 | 97 | class Form(QWidget): 98 | 99 | def __init__(self, qt_thread): 100 | super().__init__(flags=Qt.Widget) 101 | self.qt_thread = qt_thread 102 | self.setWindowTitle('Action State') 103 | 104 | layout = QGridLayout() 105 | 106 | self.pgsb1 = QProgressBar() 107 | self.pgsb1.setOrientation(Qt.Vertical) 108 | self.pgsb1.setValue(0) 109 | self.pgsb1.setRange(0, 100) 110 | 111 | self.pgsb2 = QProgressBar() 112 | self.pgsb2.setOrientation(Qt.Vertical) 113 | self.pgsb2.setValue(0) 114 | self.pgsb2.setRange(0, 100) 115 | 116 | self.pgsb3 = QProgressBar() 117 | self.pgsb3.setOrientation(Qt.Vertical) 118 | self.pgsb3.setValue(0) 119 | self.pgsb3.setRange(0, 100) 120 | 121 | self.pgsb4 = QProgressBar() 122 | self.pgsb4.setOrientation(Qt.Vertical) 123 | self.pgsb4.setValue(0) 124 | self.pgsb4.setRange(0, 100) 125 | 126 | self.pgsb5 = QProgressBar() 127 | self.pgsb5.setOrientation(Qt.Vertical) 128 | self.pgsb5.setValue(0) 129 | self.pgsb5.setRange(0, 100) 130 | 131 | self.label_total_reward = QLabel('Total reward') 132 | self.edit_total_reward = QLineEdit('') 133 | self.edit_total_reward.setDisabled(True) 134 | self.edit_total_reward.setFixedWidth(100) 135 | 136 | self.label_reward = QLabel('Reward') 137 | 
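        # The five vertical bars above mirror the discrete action index
        # published as the first element of /get_action, while 'Total reward'
        # and 'Reward' display the last two elements of that message: the
        # running episode return and the most recent per-step reward
        # (see Ros2Subscriber.get_array_callback above).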
self.edit_reward = QLineEdit('') 138 | self.edit_reward.setDisabled(True) 139 | self.edit_reward.setFixedWidth(100) 140 | 141 | self.label_left = QLabel('Left') 142 | self.label_front = QLabel('Front') 143 | self.label_right = QLabel('Right') 144 | 145 | layout.addWidget(self.label_total_reward, 0, 0) 146 | layout.addWidget(self.edit_total_reward, 1, 0) 147 | layout.addWidget(self.label_reward, 2, 0) 148 | layout.addWidget(self.edit_reward, 3, 0) 149 | 150 | layout.addWidget(self.pgsb1, 0, 4, 4, 1) 151 | layout.addWidget(self.pgsb2, 0, 5, 4, 1) 152 | layout.addWidget(self.pgsb3, 0, 6, 4, 1) 153 | layout.addWidget(self.pgsb4, 0, 7, 4, 1) 154 | layout.addWidget(self.pgsb5, 0, 8, 4, 1) 155 | 156 | layout.addWidget(self.label_left, 4, 4) 157 | layout.addWidget(self.label_front, 4, 6) 158 | layout.addWidget(self.label_right, 4, 8) 159 | 160 | self.setLayout(layout) 161 | 162 | qt_thread.signal_action0.connect(self.pgsb1.setValue) 163 | qt_thread.signal_action1.connect(self.pgsb2.setValue) 164 | qt_thread.signal_action2.connect(self.pgsb3.setValue) 165 | qt_thread.signal_action3.connect(self.pgsb4.setValue) 166 | qt_thread.signal_action4.connect(self.pgsb5.setValue) 167 | qt_thread.signal_total_reward.connect(self.edit_total_reward.setText) 168 | qt_thread.signal_reward.connect(self.edit_reward.setText) 169 | 170 | def closeEvent(self, event): 171 | if hasattr(self.qt_thread, 'node') and self.qt_thread.node is not None: 172 | self.qt_thread.node.destroy_node() 173 | rclpy.shutdown() 174 | event.accept() 175 | 176 | 177 | def run_qt_app(qt_thread): 178 | app = QApplication(sys.argv) 179 | form = Form(qt_thread) 180 | form.show() 181 | app.exec_() 182 | 183 | 184 | def main(): 185 | rclpy.init() 186 | qt_thread = Thread() 187 | qt_thread.start() 188 | qt_gui_thread = threading.Thread(target=run_qt_app, args=(qt_thread,), daemon=True) 189 | qt_gui_thread.start() 190 | 191 | def shutdown_handler(sig, frame): 192 | print('shutdown') 193 | qt_thread.node.destroy_node() 194 | rclpy.shutdown() 195 | sys.exit(0) 196 | 197 | signal.signal(signal.SIGINT, shutdown_handler) 198 | signal.signal(signal.SIGTERM, shutdown_handler) 199 | try: 200 | while rclpy.ok(): 201 | time.sleep(0.1) 202 | except KeyboardInterrupt: 203 | shutdown_handler(None, None) 204 | 205 | 206 | if __name__ == '__main__': 207 | main() 208 | -------------------------------------------------------------------------------- /turtlebot3_dqn/turtlebot3_dqn/dqn_agent.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | ################################################################################# 3 | # Copyright 2019 ROBOTIS CO., LTD. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | #################################################################################
17 | #
18 | # Authors: Ryan Shim, Gilbert, ChanHyeong Lee
19 | 
20 | import collections
21 | import datetime
22 | import json
23 | import math
24 | import os
25 | import random
26 | import sys
27 | import time
28 | 
29 | import numpy
30 | import rclpy
31 | from rclpy.node import Node
32 | from std_msgs.msg import Float32MultiArray
33 | from std_srvs.srv import Empty
34 | import tensorflow
35 | from tensorflow.keras.layers import Dense
36 | from tensorflow.keras.layers import Input
37 | from tensorflow.keras.losses import MeanSquaredError
38 | from tensorflow.keras.models import load_model
39 | from tensorflow.keras.models import Sequential
40 | from tensorflow.keras.optimizers import Adam
41 | 
42 | from turtlebot3_msgs.srv import Dqn
43 | 
44 | 
45 | tensorflow.config.set_visible_devices([], 'GPU')
46 | 
47 | LOGGING = True
48 | current_time = datetime.datetime.now().strftime('[%mm%dd-%H:%M]')
49 | 
50 | 
51 | class DQNMetric(tensorflow.keras.metrics.Metric):
52 | 
53 |     def __init__(self, name='dqn_metric'):
54 |         super(DQNMetric, self).__init__(name=name)
55 |         self.loss = self.add_weight(name='loss', initializer='zeros')
56 |         self.episode_step = self.add_weight(name='step', initializer='zeros')
57 | 
58 |     def update_state(self, y_true, y_pred=0, sample_weight=None):
59 |         self.loss.assign_add(y_true)
60 |         self.episode_step.assign_add(1)
61 | 
62 |     def result(self):
63 |         return self.loss / self.episode_step
64 | 
65 |     def reset_states(self):
66 |         self.loss.assign(0)
67 |         self.episode_step.assign(0)
68 | 
69 | 
70 | class DQNAgent(Node):
71 | 
72 |     def __init__(self, stage_num, max_training_episodes):
73 |         super().__init__('dqn_agent')
74 | 
75 |         self.stage = int(stage_num)
76 |         self.train_mode = True
77 |         self.state_size = 26
78 |         self.action_size = 5
79 |         self.max_training_episodes = int(max_training_episodes)
80 | 
81 |         self.done = False
82 |         self.succeed = False
83 |         self.fail = False
84 | 
85 |         self.discount_factor = 0.99
86 |         self.learning_rate = 0.0007
87 |         self.epsilon = 1.0
88 |         self.step_counter = 0
89 |         self.epsilon_decay = 6000 * self.stage
90 |         self.epsilon_min = 0.05
91 |         self.batch_size = 128
92 | 
93 |         self.replay_memory = collections.deque(maxlen=500000)
94 |         self.min_replay_memory_size = 5000
95 | 
96 |         self.model = self.create_qnetwork()
97 |         self.target_model = self.create_qnetwork()
98 |         self.update_target_model()
99 |         self.update_target_after = 5000
100 |         self.target_update_after_counter = 0
101 | 
102 |         self.load_model = False
103 |         self.load_episode = 0
104 |         self.model_dir_path = os.path.join(
105 |             os.path.dirname(os.path.dirname(os.path.realpath(__file__))),
106 |             'saved_model'
107 |         )
108 |         self.model_path = os.path.join(
109 |             self.model_dir_path,
110 |             'stage' + str(self.stage) + '_episode' + str(self.load_episode) + '.h5'
111 |         )
112 | 
113 |         if self.load_model:
114 |             self.model.set_weights(load_model(self.model_path).get_weights())
115 |             with open(os.path.join(
116 |                 self.model_dir_path,
117 |                 'stage' + str(self.stage) + '_episode' + str(self.load_episode) + '.json'
118 |             )) as outfile:
119 |                 param = json.load(outfile)
120 |                 self.epsilon = param.get('epsilon')
121 |                 # Default to 0 for older checkpoints (e.g. the bundled
122 |                 # stage1_episode600.json) that only stored epsilon.
123 |                 self.step_counter = param.get('step_counter', 0)
124 | 
125 |         if LOGGING:
126 |             tensorboard_file_name = current_time + ' dqn_stage' + str(self.stage) + '_reward'
127 |             home_dir = os.path.expanduser('~')
128 |             dqn_reward_log_dir = os.path.join(
129 |                 home_dir, 'turtlebot3_dqn_logs', 'gradient_tape', tensorboard_file_name
130 |             )
131 |             self.dqn_reward_writer = tensorflow.summary.create_file_writer(dqn_reward_log_dir)
132 |             self.dqn_reward_metric = DQNMetric()
133 | 
134 |         self.rl_agent_interface_client = self.create_client(Dqn, 'rl_agent_interface')
135 |         self.make_environment_client = self.create_client(Empty, 'make_environment')
136 |         self.reset_environment_client = self.create_client(Dqn, 'reset_environment')
137 | 
138 |         self.action_pub = self.create_publisher(Float32MultiArray, '/get_action', 10)
139 |         self.result_pub = self.create_publisher(Float32MultiArray, 'result', 10)
140 | 
141 |         self.process()
142 | 
143 |     def process(self):
144 |         self.env_make()
145 |         time.sleep(1.0)
146 | 
147 |         episode_num = self.load_episode
148 | 
149 |         for episode in range(self.load_episode + 1, self.max_training_episodes + 1):
150 |             state = self.reset_environment()
151 |             episode_num += 1
152 |             local_step = 0
153 |             score = 0
154 |             sum_max_q = 0.0
155 | 
156 |             time.sleep(1.0)
157 | 
158 |             while True:
159 |                 local_step += 1
160 | 
161 |                 q_values = self.model.predict(state)
162 |                 sum_max_q += float(numpy.max(q_values))
163 | 
164 |                 action = int(self.get_action(state))
165 |                 next_state, reward, done = self.step(action)
166 |                 score += reward
167 | 
168 |                 msg = Float32MultiArray()
169 |                 msg.data = [float(action), float(score), float(reward)]
170 |                 self.action_pub.publish(msg)
171 | 
172 |                 if self.train_mode:
173 |                     self.append_sample((state, action, reward, next_state, done))
174 |                     self.train_model(done)
175 | 
176 |                 state = next_state
177 | 
178 |                 if done:
179 |                     avg_max_q = sum_max_q / local_step if local_step > 0 else 0.0
180 | 
181 |                     msg = Float32MultiArray()
182 |                     msg.data = [float(score), float(avg_max_q)]
183 |                     self.result_pub.publish(msg)
184 | 
185 |                     if LOGGING:
186 |                         self.dqn_reward_metric.update_state(score)
187 |                         with self.dqn_reward_writer.as_default():
188 |                             tensorflow.summary.scalar(
189 |                                 'dqn_reward', self.dqn_reward_metric.result(), step=episode_num
190 |                             )
191 |                         self.dqn_reward_metric.reset_states()
192 | 
193 |                     print(
194 |                         'Episode:', episode,
195 |                         'score:', score,
196 |                         'memory length:', len(self.replay_memory),
197 |                         'epsilon:', self.epsilon)
198 | 
199 |                     # The key must match what __init__ reads back via
200 |                     # param.get('step_counter') when load_model is True.
201 |                     param_keys = ['epsilon', 'step_counter']
202 |                     param_values = [self.epsilon, self.step_counter]
203 |                     param_dictionary = dict(zip(param_keys, param_values))
204 |                     break
205 | 
206 |                 time.sleep(0.01)
207 | 
208 |             if self.train_mode:
209 |                 if episode % 100 == 0:
210 |                     self.model_path = os.path.join(
211 |                         self.model_dir_path,
212 |                         'stage' + str(self.stage) + '_episode' + str(episode) + '.h5')
213 |                     self.model.save(self.model_path)
214 |                     with open(
215 |                         os.path.join(
216 |                             self.model_dir_path,
217 |                             'stage' + str(self.stage) + '_episode' + str(episode) + '.json'
218 |                         ),
219 |                         'w'
220 |                     ) as outfile:
221 |                         json.dump(param_dictionary, outfile)
222 | 
223 |     def env_make(self):
224 |         while not self.make_environment_client.wait_for_service(timeout_sec=1.0):
225 |             self.get_logger().warn(
226 |                 'Environment make client failed to connect to the server, try again ...'
227 |             )
228 | 
229 |         self.make_environment_client.call_async(Empty.Request())
230 | 
231 |     def reset_environment(self):
232 |         while not self.reset_environment_client.wait_for_service(timeout_sec=1.0):
233 |             self.get_logger().warn(
234 |                 'Reset environment client failed to connect to the server, try again ...'
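                # 'reset_environment' is served by the RLEnvironment node
                # (dqn_environment.py); its reply carries the initial state
                # vector that starts the next episode.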
231 | ) 232 | 233 | future = self.reset_environment_client.call_async(Dqn.Request()) 234 | 235 | rclpy.spin_until_future_complete(self, future) 236 | if future.result() is not None: 237 | state = future.result().state 238 | state = numpy.reshape(numpy.asarray(state), [1, self.state_size]) 239 | else: 240 | self.get_logger().error( 241 | 'Exception while calling service: {0}'.format(future.exception())) 242 | 243 | return state 244 | 245 | def get_action(self, state): 246 | if self.train_mode: 247 | self.step_counter += 1 248 | self.epsilon = self.epsilon_min + (1.0 - self.epsilon_min) * math.exp( 249 | -1.0 * self.step_counter / self.epsilon_decay) 250 | lucky = random.random() 251 | if lucky > (1 - self.epsilon): 252 | result = random.randint(0, self.action_size - 1) 253 | else: 254 | result = numpy.argmax(self.model.predict(state)) 255 | else: 256 | result = numpy.argmax(self.model.predict(state)) 257 | 258 | return result 259 | 260 | def step(self, action): 261 | req = Dqn.Request() 262 | req.action = action 263 | 264 | while not self.rl_agent_interface_client.wait_for_service(timeout_sec=1.0): 265 | self.get_logger().info('rl_agent interface service not available, waiting again...') 266 | 267 | future = self.rl_agent_interface_client.call_async(req) 268 | 269 | rclpy.spin_until_future_complete(self, future) 270 | 271 | if future.result() is not None: 272 | next_state = future.result().state 273 | next_state = numpy.reshape(numpy.asarray(next_state), [1, self.state_size]) 274 | reward = future.result().reward 275 | done = future.result().done 276 | else: 277 | self.get_logger().error( 278 | 'Exception while calling service: {0}'.format(future.exception())) 279 | 280 | return next_state, reward, done 281 | 282 | def create_qnetwork(self): 283 | model = Sequential() 284 | model.add(Input(shape=(self.state_size,))) 285 | model.add(Dense(512, activation='relu')) 286 | model.add(Dense(256, activation='relu')) 287 | model.add(Dense(128, activation='relu')) 288 | model.add(Dense(self.action_size, activation='linear')) 289 | model.compile(loss=MeanSquaredError(), optimizer=Adam(learning_rate=self.learning_rate)) 290 | model.summary() 291 | 292 | return model 293 | 294 | def update_target_model(self): 295 | self.target_model.set_weights(self.model.get_weights()) 296 | self.target_update_after_counter = 0 297 | print('*Target model updated*') 298 | 299 | def append_sample(self, transition): 300 | self.replay_memory.append(transition) 301 | 302 | def train_model(self, terminal): 303 | if len(self.replay_memory) < self.min_replay_memory_size: 304 | return 305 | data_in_mini_batch = random.sample(self.replay_memory, self.batch_size) 306 | 307 | current_states = numpy.array([transition[0] for transition in data_in_mini_batch]) 308 | current_states = current_states.squeeze() 309 | current_qvalues_list = self.model.predict(current_states) 310 | 311 | next_states = numpy.array([transition[3] for transition in data_in_mini_batch]) 312 | next_states = next_states.squeeze() 313 | next_qvalues_list = self.target_model.predict(next_states) 314 | 315 | x_train = [] 316 | y_train = [] 317 | 318 | for index, (current_state, action, reward, _, done) in enumerate(data_in_mini_batch): 319 | current_q_values = current_qvalues_list[index] 320 | 321 | if not done: 322 | future_reward = numpy.max(next_qvalues_list[index]) 323 | desired_q = reward + self.discount_factor * future_reward 324 | else: 325 | desired_q = reward 326 | 327 | current_q_values[action] = desired_q 328 | x_train.append(current_state) 329 | 
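            # Standard DQN target: only the entry for the taken action is
            # overwritten with r + discount_factor * max_a' Q_target(s', a');
            # the other entries keep the online network's own predictions, so
            # the MSE loss penalizes only the action that was executed.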
y_train.append(current_q_values) 330 | 331 | x_train = numpy.array(x_train) 332 | y_train = numpy.array(y_train) 333 | x_train = numpy.reshape(x_train, [len(data_in_mini_batch), self.state_size]) 334 | y_train = numpy.reshape(y_train, [len(data_in_mini_batch), self.action_size]) 335 | 336 | self.model.fit( 337 | tensorflow.convert_to_tensor(x_train, tensorflow.float32), 338 | tensorflow.convert_to_tensor(y_train, tensorflow.float32), 339 | batch_size=self.batch_size, verbose=0 340 | ) 341 | self.target_update_after_counter += 1 342 | 343 | if self.target_update_after_counter > self.update_target_after and terminal: 344 | self.update_target_model() 345 | 346 | 347 | def main(args=None): 348 | if args is None: 349 | args = sys.argv 350 | stage_num = args[1] if len(args) > 1 else '1' 351 | max_training_episodes = args[2] if len(args) > 2 else '1000' 352 | rclpy.init(args=args) 353 | 354 | dqn_agent = DQNAgent(stage_num, max_training_episodes) 355 | rclpy.spin(dqn_agent) 356 | 357 | dqn_agent.destroy_node() 358 | rclpy.shutdown() 359 | 360 | 361 | if __name__ == '__main__': 362 | main() 363 | -------------------------------------------------------------------------------- /turtlebot3_dqn/turtlebot3_dqn/dqn_environment.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | ################################################################################# 3 | # Copyright 2019 ROBOTIS CO., LTD. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | ################################################################################# 17 | # 18 | # Authors: Ryan Shim, Gilbert, ChanHyeong Lee 19 | 20 | import math 21 | import os 22 | 23 | from geometry_msgs.msg import Twist 24 | from geometry_msgs.msg import TwistStamped 25 | from nav_msgs.msg import Odometry 26 | import numpy 27 | import rclpy 28 | from rclpy.callback_groups import MutuallyExclusiveCallbackGroup 29 | from rclpy.node import Node 30 | from rclpy.qos import qos_profile_sensor_data 31 | from rclpy.qos import QoSProfile 32 | from sensor_msgs.msg import LaserScan 33 | from std_srvs.srv import Empty 34 | 35 | from turtlebot3_msgs.srv import Dqn 36 | from turtlebot3_msgs.srv import Goal 37 | 38 | 39 | ROS_DISTRO = os.environ.get('ROS_DISTRO') 40 | 41 | 42 | class RLEnvironment(Node): 43 | 44 | def __init__(self): 45 | super().__init__('rl_environment') 46 | self.goal_pose_x = 0.0 47 | self.goal_pose_y = 0.0 48 | self.robot_pose_x = 0.0 49 | self.robot_pose_y = 0.0 50 | 51 | self.action_size = 5 52 | self.max_step = 800 53 | 54 | self.done = False 55 | self.fail = False 56 | self.succeed = False 57 | 58 | self.goal_angle = 0.0 59 | self.goal_distance = 1.0 60 | self.init_goal_distance = 0.5 61 | self.scan_ranges = [] 62 | self.front_ranges = [] 63 | self.min_obstacle_distance = 10.0 64 | self.is_front_min_actual_front = False 65 | 66 | self.local_step = 0 67 | self.stop_cmd_vel_timer = None 68 | self.angular_vel = [1.5, 0.75, 0.0, -0.75, -1.5] 69 | 70 | qos = QoSProfile(depth=10) 71 | 72 | if ROS_DISTRO == 'humble': 73 | self.cmd_vel_pub = self.create_publisher(Twist, 'cmd_vel', qos) 74 | else: 75 | self.cmd_vel_pub = self.create_publisher(TwistStamped, 'cmd_vel', qos) 76 | 77 | self.odom_sub = self.create_subscription( 78 | Odometry, 79 | 'odom', 80 | self.odom_sub_callback, 81 | qos 82 | ) 83 | self.scan_sub = self.create_subscription( 84 | LaserScan, 85 | 'scan', 86 | self.scan_sub_callback, 87 | qos_profile_sensor_data 88 | ) 89 | 90 | self.clients_callback_group = MutuallyExclusiveCallbackGroup() 91 | self.task_succeed_client = self.create_client( 92 | Goal, 93 | 'task_succeed', 94 | callback_group=self.clients_callback_group 95 | ) 96 | self.task_failed_client = self.create_client( 97 | Goal, 98 | 'task_failed', 99 | callback_group=self.clients_callback_group 100 | ) 101 | self.initialize_environment_client = self.create_client( 102 | Goal, 103 | 'initialize_env', 104 | callback_group=self.clients_callback_group 105 | ) 106 | 107 | self.rl_agent_interface_service = self.create_service( 108 | Dqn, 109 | 'rl_agent_interface', 110 | self.rl_agent_interface_callback 111 | ) 112 | self.make_environment_service = self.create_service( 113 | Empty, 114 | 'make_environment', 115 | self.make_environment_callback 116 | ) 117 | self.reset_environment_service = self.create_service( 118 | Dqn, 119 | 'reset_environment', 120 | self.reset_environment_callback 121 | ) 122 | 123 | def make_environment_callback(self, request, response): 124 | self.get_logger().info('Make environment called') 125 | while not self.initialize_environment_client.wait_for_service(timeout_sec=1.0): 126 | self.get_logger().warn( 127 | 'service for initialize the environment is not available, waiting ...' 
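                # 'initialize_env' is served by the Gazebo interface node
                # (dqn_gazebo.py); the reply carries the first goal position,
                # which is stored below before training starts.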
128 | ) 129 | future = self.initialize_environment_client.call_async(Goal.Request()) 130 | rclpy.spin_until_future_complete(self, future) 131 | response_goal = future.result() 132 | if not response_goal.success: 133 | self.get_logger().error('initialize environment request failed') 134 | else: 135 | self.goal_pose_x = response_goal.pose_x 136 | self.goal_pose_y = response_goal.pose_y 137 | self.get_logger().info( 138 | 'goal initialized at [%f, %f]' % (self.goal_pose_x, self.goal_pose_y) 139 | ) 140 | 141 | return response 142 | 143 | def reset_environment_callback(self, request, response): 144 | state = self.calculate_state() 145 | self.init_goal_distance = state[0] 146 | self.prev_goal_distance = self.init_goal_distance 147 | response.state = state 148 | 149 | return response 150 | 151 | def call_task_succeed(self): 152 | while not self.task_succeed_client.wait_for_service(timeout_sec=1.0): 153 | self.get_logger().warn('service for task succeed is not available, waiting ...') 154 | future = self.task_succeed_client.call_async(Goal.Request()) 155 | rclpy.spin_until_future_complete(self, future) 156 | if future.result() is not None: 157 | response = future.result() 158 | self.goal_pose_x = response.pose_x 159 | self.goal_pose_y = response.pose_y 160 | self.get_logger().info('service for task succeed finished') 161 | else: 162 | self.get_logger().error('task succeed service call failed') 163 | 164 | def call_task_failed(self): 165 | while not self.task_failed_client.wait_for_service(timeout_sec=1.0): 166 | self.get_logger().warn('service for task failed is not available, waiting ...') 167 | future = self.task_failed_client.call_async(Goal.Request()) 168 | rclpy.spin_until_future_complete(self, future) 169 | if future.result() is not None: 170 | response = future.result() 171 | self.goal_pose_x = response.pose_x 172 | self.goal_pose_y = response.pose_y 173 | self.get_logger().info('service for task failed finished') 174 | else: 175 | self.get_logger().error('task failed service call failed') 176 | 177 | def scan_sub_callback(self, scan): 178 | self.scan_ranges = [] 179 | self.front_ranges = [] 180 | self.front_angles = [] 181 | 182 | num_of_lidar_rays = len(scan.ranges) 183 | angle_min = scan.angle_min 184 | angle_increment = scan.angle_increment 185 | 186 | self.front_distance = scan.ranges[0] 187 | 188 | for i in range(num_of_lidar_rays): 189 | angle = angle_min + i * angle_increment 190 | distance = scan.ranges[i] 191 | 192 | if distance == float('Inf'): 193 | distance = 3.5 194 | elif numpy.isnan(distance): 195 | distance = 0.0 196 | 197 | self.scan_ranges.append(distance) 198 | 199 | if (0 <= angle <= math.pi/2) or (3*math.pi/2 <= angle <= 2*math.pi): 200 | self.front_ranges.append(distance) 201 | self.front_angles.append(angle) 202 | 203 | self.min_obstacle_distance = min(self.scan_ranges) 204 | self.front_min_obstacle_distance = min(self.front_ranges) if self.front_ranges else 10.0 205 | 206 | def odom_sub_callback(self, msg): 207 | self.robot_pose_x = msg.pose.pose.position.x 208 | self.robot_pose_y = msg.pose.pose.position.y 209 | _, _, self.robot_pose_theta = self.euler_from_quaternion(msg.pose.pose.orientation) 210 | 211 | goal_distance = math.sqrt( 212 | (self.goal_pose_x - self.robot_pose_x) ** 2 213 | + (self.goal_pose_y - self.robot_pose_y) ** 2) 214 | path_theta = math.atan2( 215 | self.goal_pose_y - self.robot_pose_y, 216 | self.goal_pose_x - self.robot_pose_x) 217 | 218 | goal_angle = path_theta - self.robot_pose_theta 219 | if goal_angle > math.pi: 220 | goal_angle -= 2 * 
math.pi 221 | 222 | elif goal_angle < -math.pi: 223 | goal_angle += 2 * math.pi 224 | 225 | self.goal_distance = goal_distance 226 | self.goal_angle = goal_angle 227 | 228 | def calculate_state(self): 229 | state = [] 230 | state.append(float(self.goal_distance)) 231 | state.append(float(self.goal_angle)) 232 | for var in self.front_ranges: 233 | state.append(float(var)) 234 | self.local_step += 1 235 | 236 | if self.goal_distance < 0.20: 237 | self.get_logger().info('Goal Reached') 238 | self.succeed = True 239 | self.done = True 240 | if ROS_DISTRO == 'humble': 241 | self.cmd_vel_pub.publish(Twist()) 242 | else: 243 | self.cmd_vel_pub.publish(TwistStamped()) 244 | self.local_step = 0 245 | self.call_task_succeed() 246 | 247 | if self.min_obstacle_distance < 0.15: 248 | self.get_logger().info('Collision happened') 249 | self.fail = True 250 | self.done = True 251 | if ROS_DISTRO == 'humble': 252 | self.cmd_vel_pub.publish(Twist()) 253 | else: 254 | self.cmd_vel_pub.publish(TwistStamped()) 255 | self.local_step = 0 256 | self.call_task_failed() 257 | 258 | if self.local_step == self.max_step: 259 | self.get_logger().info('Time out!') 260 | self.fail = True 261 | self.done = True 262 | if ROS_DISTRO == 'humble': 263 | self.cmd_vel_pub.publish(Twist()) 264 | else: 265 | self.cmd_vel_pub.publish(TwistStamped()) 266 | self.local_step = 0 267 | self.call_task_failed() 268 | 269 | return state 270 | 271 | def compute_directional_weights(self, relative_angles, max_weight=10.0): 272 | power = 6 273 | raw_weights = (numpy.cos(relative_angles))**power + 0.1 274 | scaled_weights = raw_weights * (max_weight / numpy.max(raw_weights)) 275 | normalized_weights = scaled_weights / numpy.sum(scaled_weights) 276 | return normalized_weights 277 | 278 | def compute_weighted_obstacle_reward(self): 279 | if not self.front_ranges or not self.front_angles: 280 | return 0.0 281 | 282 | front_ranges = numpy.array(self.front_ranges) 283 | front_angles = numpy.array(self.front_angles) 284 | 285 | valid_mask = front_ranges <= 0.5 286 | if not numpy.any(valid_mask): 287 | return 0.0 288 | 289 | front_ranges = front_ranges[valid_mask] 290 | front_angles = front_angles[valid_mask] 291 | 292 | relative_angles = numpy.unwrap(front_angles) 293 | relative_angles[relative_angles > numpy.pi] -= 2 * numpy.pi 294 | 295 | weights = self.compute_directional_weights(relative_angles, max_weight=10.0) 296 | 297 | safe_dists = numpy.clip(front_ranges - 0.25, 1e-2, 3.5) 298 | decay = numpy.exp(-3.0 * safe_dists) 299 | 300 | weighted_decay = numpy.dot(weights, decay) 301 | 302 | reward = - (1.0 + 4.0 * weighted_decay) 303 | 304 | return reward 305 | 306 | def calculate_reward(self): 307 | yaw_reward = 1 - (2 * abs(self.goal_angle) / math.pi) 308 | obstacle_reward = self.compute_weighted_obstacle_reward() 309 | 310 | print('directional_reward: %f, obstacle_reward: %f' % (yaw_reward, obstacle_reward)) 311 | reward = yaw_reward + obstacle_reward 312 | 313 | if self.succeed: 314 | reward = 100.0 315 | elif self.fail: 316 | reward = -50.0 317 | 318 | return reward 319 | 320 | def rl_agent_interface_callback(self, request, response): 321 | action = request.action 322 | if ROS_DISTRO == 'humble': 323 | msg = Twist() 324 | msg.linear.x = 0.2 325 | msg.angular.z = self.angular_vel[action] 326 | else: 327 | msg = TwistStamped() 328 | msg.twist.linear.x = 0.2 329 | msg.twist.angular.z = self.angular_vel[action] 330 | 331 | self.cmd_vel_pub.publish(msg) 332 | if self.stop_cmd_vel_timer is None: 333 | self.prev_goal_distance = self.init_goal_distance 
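            # The 0.8 s timer below acts as a stop watchdog: each new action
            # request re-arms it, and if no request arrives in time,
            # timer_callback publishes a zero-velocity command so the robot
            # halts instead of keeping its last cmd_vel.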
334 | self.stop_cmd_vel_timer = self.create_timer(0.8, self.timer_callback) 335 | else: 336 | self.destroy_timer(self.stop_cmd_vel_timer) 337 | self.stop_cmd_vel_timer = self.create_timer(0.8, self.timer_callback) 338 | 339 | response.state = self.calculate_state() 340 | response.reward = self.calculate_reward() 341 | response.done = self.done 342 | 343 | if self.done is True: 344 | self.done = False 345 | self.succeed = False 346 | self.fail = False 347 | 348 | return response 349 | 350 | def timer_callback(self): 351 | self.get_logger().info('Stop called') 352 | if ROS_DISTRO == 'humble': 353 | self.cmd_vel_pub.publish(Twist()) 354 | else: 355 | self.cmd_vel_pub.publish(TwistStamped()) 356 | self.destroy_timer(self.stop_cmd_vel_timer) 357 | 358 | def euler_from_quaternion(self, quat): 359 | x = quat.x 360 | y = quat.y 361 | z = quat.z 362 | w = quat.w 363 | 364 | sinr_cosp = 2 * (w * x + y * z) 365 | cosr_cosp = 1 - 2 * (x * x + y * y) 366 | roll = numpy.arctan2(sinr_cosp, cosr_cosp) 367 | 368 | sinp = 2 * (w * y - z * x) 369 | pitch = numpy.arcsin(sinp) 370 | 371 | siny_cosp = 2 * (w * z + x * y) 372 | cosy_cosp = 1 - 2 * (y * y + z * z) 373 | yaw = numpy.arctan2(siny_cosp, cosy_cosp) 374 | 375 | return roll, pitch, yaw 376 | 377 | 378 | def main(args=None): 379 | rclpy.init(args=args) 380 | rl_environment = RLEnvironment() 381 | try: 382 | while rclpy.ok(): 383 | rclpy.spin_once(rl_environment, timeout_sec=0.1) 384 | except KeyboardInterrupt: 385 | pass 386 | finally: 387 | rl_environment.destroy_node() 388 | rclpy.shutdown() 389 | 390 | 391 | if __name__ == '__main__': 392 | main() 393 | -------------------------------------------------------------------------------- /turtlebot3_dqn/turtlebot3_dqn/dqn_gazebo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | ################################################################################# 3 | # Copyright 2019 ROBOTIS CO., LTD. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | #################################################################################
17 | #
18 | # Authors: Ryan Shim, Gilbert, ChanHyeong Lee
19 | 
20 | import os
21 | import random
22 | import subprocess
23 | import sys
24 | import time
25 | 
26 | from ament_index_python.packages import get_package_share_directory
27 | import rclpy
28 | from rclpy.callback_groups import MutuallyExclusiveCallbackGroup
29 | from rclpy.node import Node
30 | from std_srvs.srv import Empty
31 | 
32 | from turtlebot3_msgs.srv import Goal
33 | 
34 | 
35 | ROS_DISTRO = os.environ.get('ROS_DISTRO')
36 | if ROS_DISTRO == 'humble':  # Humble pairs with Gazebo Classic; newer distros use the gz CLI instead
37 |     from gazebo_msgs.srv import DeleteEntity
38 |     from gazebo_msgs.srv import SpawnEntity
39 |     from geometry_msgs.msg import Pose
40 | 
41 | 
42 | class GazeboInterface(Node):
43 | 
44 |     def __init__(self, stage_num):
45 |         super().__init__('gazebo_interface')
46 |         self.stage = int(stage_num)
47 | 
48 |         self.entity_name = 'goal_box'
49 |         self.entity_pose_x = 0.5
50 |         self.entity_pose_y = 0.0
51 | 
52 |         if ROS_DISTRO == 'humble':
53 |             self.entity = None
54 |             self.open_entity()
55 |             self.delete_entity_client = self.create_client(DeleteEntity, 'delete_entity')
56 |             self.spawn_entity_client = self.create_client(SpawnEntity, 'spawn_entity')
57 |             self.reset_simulation_client = self.create_client(Empty, 'reset_simulation')
58 | 
59 |         self.callback_group = MutuallyExclusiveCallbackGroup()  # the three services never run concurrently
60 |         self.initialize_env_service = self.create_service(
61 |             Goal,
62 |             'initialize_env',
63 |             self.initialize_env_callback,
64 |             callback_group=self.callback_group
65 |         )
66 |         self.task_succeed_service = self.create_service(
67 |             Goal,
68 |             'task_succeed',
69 |             self.task_succeed_callback,
70 |             callback_group=self.callback_group
71 |         )
72 |         self.task_failed_service = self.create_service(
73 |             Goal,
74 |             'task_failed',
75 |             self.task_failed_callback,
76 |             callback_group=self.callback_group
77 |         )
78 | 
79 |     def open_entity(self):
80 |         try:
81 |             package_share = get_package_share_directory('turtlebot3_gazebo')
82 |             model_path = os.path.join(
83 |                 package_share, 'models', 'turtlebot3_dqn_world', 'goal_box', 'model.sdf'
84 |             )
85 |             with open(model_path, 'r') as f:
86 |                 self.entity = f.read()
87 |             self.get_logger().info('Loaded entity from: ' + model_path)
88 |         except Exception as e:
89 |             self.get_logger().error('Failed to load entity file: {}'.format(e))
90 |             raise e
91 | 
92 |     def spawn_entity(self):
93 |         if ROS_DISTRO == 'humble':
94 |             entity_pose = Pose()
95 |             entity_pose.position.x = self.entity_pose_x
96 |             entity_pose.position.y = self.entity_pose_y
97 | 
98 |             spawn_req = SpawnEntity.Request()
99 |             spawn_req.name = self.entity_name
100 |             spawn_req.xml = self.entity
101 |             spawn_req.initial_pose = entity_pose
102 | 
103 |             while not self.spawn_entity_client.wait_for_service(timeout_sec=1.0):
104 |                 self.get_logger().warn('service for spawn_entity is not available, waiting ...')
105 |             future = self.spawn_entity_client.call_async(spawn_req)
106 |             rclpy.spin_until_future_complete(self, future)
107 |             print(f'Spawn Goal at ({self.entity_pose_x}, {self.entity_pose_y}, {0.0})')
108 |         else:
109 |             service_name = '/world/dqn/create'
110 |             package_share = get_package_share_directory('turtlebot3_gazebo')
111 |             model_path = os.path.join(
112 |                 package_share, 'models', 'turtlebot3_dqn_world', 'goal_box', 'model.sdf'
113 |             )
114 |             req = (
115 |                 f'sdf_filename: "{model_path}", '
116 |                 f'name: "{self.entity_name}", '
117 |                 f'pose: {{ position: {{ '
118 |                 f'x: {self.entity_pose_x}, '
119 |                 f'y: {self.entity_pose_y}, '
120 |                 f'z: 0.0 }} }}'
121
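                # The request above is Protobuf text format for a gz.msgs.EntityFactory
                # message: modern Gazebo exposes entity creation as a gz-transport
                # service, which this node calls through the `gz service` CLI below
                # rather than through a ROS service client.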
| ) 122 | cmd = [ 123 | 'gz', 'service', 124 | '-s', service_name, 125 | '--reqtype', 'gz.msgs.EntityFactory', 126 | '--reptype', 'gz.msgs.Boolean', 127 | '--timeout', '1000', 128 | '--req', req 129 | ] 130 | try: 131 | subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL) 132 | print(f'Spawn Goal at ({self.entity_pose_x}, {self.entity_pose_y}, {0.0})') 133 | except subprocess.CalledProcessError: 134 | pass 135 | 136 | def delete_entity(self): 137 | if ROS_DISTRO == 'humble': 138 | delete_req = DeleteEntity.Request() 139 | delete_req.name = self.entity_name 140 | 141 | while not self.delete_entity_client.wait_for_service(timeout_sec=1.0): 142 | self.get_logger().warn('service for delete_entity is not available, waiting ...') 143 | future = self.delete_entity_client.call_async(delete_req) 144 | rclpy.spin_until_future_complete(self, future) 145 | print('Delete Goal') 146 | else: 147 | service_name = '/world/dqn/remove' 148 | req = f'name: "{self.entity_name}", type: 2' 149 | cmd = [ 150 | 'gz', 'service', 151 | '-s', service_name, 152 | '--reqtype', 'gz.msgs.Entity', 153 | '--reptype', 'gz.msgs.Boolean', 154 | '--timeout', '1000', 155 | '--req', req 156 | ] 157 | try: 158 | subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL) 159 | print('Delete Goal') 160 | except subprocess.CalledProcessError: 161 | pass 162 | 163 | def reset_simulation(self): 164 | reset_req = Empty.Request() 165 | 166 | while not self.reset_simulation_client.wait_for_service(timeout_sec=1.0): 167 | self.get_logger().warn('service for reset_simulation is not available, waiting ...') 168 | 169 | self.reset_simulation_client.call_async(reset_req) 170 | 171 | def reset_burger(self): 172 | service_name_delete = '/world/dqn/remove' 173 | req_delete = 'name: "burger", type: 2' 174 | cmd_delete = [ 175 | 'gz', 'service', 176 | '-s', service_name_delete, 177 | '--reqtype', 'gz.msgs.Entity', 178 | '--reptype', 'gz.msgs.Boolean', 179 | '--timeout', '1000', 180 | '--req', req_delete 181 | ] 182 | try: 183 | subprocess.run(cmd_delete, check=True, stdout=subprocess.DEVNULL) 184 | print('Delete Burger') 185 | except subprocess.CalledProcessError: 186 | pass 187 | time.sleep(0.2) 188 | service_name_spawn = '/world/dqn/create' 189 | package_share = get_package_share_directory('turtlebot3_gazebo') 190 | model_path = os.path.join(package_share, 'models', 'turtlebot3_burger', 'model.sdf') 191 | req_spawn = ( 192 | f'sdf_filename: "{model_path}", ' 193 | f'name: "burger", ' 194 | f'pose: {{ position: {{ x: 0.0, y: 0.0, z: 0.0 }} }}' 195 | ) 196 | cmd_spawn = [ 197 | 'gz', 'service', 198 | '-s', service_name_spawn, 199 | '--reqtype', 'gz.msgs.EntityFactory', 200 | '--reptype', 'gz.msgs.Boolean', 201 | '--timeout', '1000', 202 | '--req', req_spawn 203 | ] 204 | try: 205 | subprocess.run(cmd_spawn, check=True, stdout=subprocess.DEVNULL) 206 | print('Spawn Burger') 207 | except subprocess.CalledProcessError: 208 | pass 209 | 210 | def task_succeed_callback(self, request, response): 211 | self.delete_entity() 212 | time.sleep(0.2) 213 | self.generate_goal_pose() 214 | time.sleep(0.2) 215 | self.spawn_entity() 216 | response.pose_x = self.entity_pose_x 217 | response.pose_y = self.entity_pose_y 218 | response.success = True 219 | return response 220 | 221 | def task_failed_callback(self, request, response): 222 | self.delete_entity() 223 | time.sleep(0.2) 224 | if ROS_DISTRO == 'humble': 225 | self.reset_simulation() 226 | else: 227 | self.reset_burger() 228 | time.sleep(0.2) 229 | self.generate_goal_pose() 230 | time.sleep(0.2) 231 | 
self.spawn_entity() 232 | response.pose_x = self.entity_pose_x 233 | response.pose_y = self.entity_pose_y 234 | response.success = True 235 | return response 236 | 237 | def initialize_env_callback(self, request, response): 238 | self.delete_entity() 239 | time.sleep(0.2) 240 | if ROS_DISTRO == 'humble': 241 | self.reset_simulation() 242 | else: 243 | self.reset_burger() 244 | time.sleep(0.2) 245 | self.spawn_entity() 246 | response.pose_x = self.entity_pose_x 247 | response.pose_y = self.entity_pose_y 248 | response.success = True 249 | return response 250 | 251 | def generate_goal_pose(self): 252 | if self.stage != 4: 253 | self.entity_pose_x = random.randrange(-21, 21) / 10 254 | self.entity_pose_y = random.randrange(-21, 21) / 10 255 | else: 256 | goal_pose_list = [ 257 | [1.0, 0.0], [2.0, -1.5], [0.0, -2.0], [2.0, 1.5], [0.5, 2.0], [-1.5, 2.1], 258 | [-2.0, 0.5], [-2.0, -0.5], [-1.5, -2.0], [-0.5, -1.0], [2.0, -0.5], [-1.0, -1.0] 259 | ] 260 | rand_index = random.randint(0, len(goal_pose_list) - 1) 261 | self.entity_pose_x = goal_pose_list[rand_index][0] 262 | self.entity_pose_y = goal_pose_list[rand_index][1] 263 | 264 | 265 | def main(args=None): 266 | rclpy.init(args=sys.argv) 267 | stage_num = sys.argv[1] if len(sys.argv) > 1 else '1' 268 | gazebo_interface = GazeboInterface(stage_num) 269 | try: 270 | while rclpy.ok(): 271 | rclpy.spin_once(gazebo_interface, timeout_sec=0.1) 272 | except KeyboardInterrupt: 273 | pass 274 | finally: 275 | gazebo_interface.destroy_node() 276 | rclpy.shutdown() 277 | 278 | 279 | if __name__ == '__main__': 280 | main() 281 | -------------------------------------------------------------------------------- /turtlebot3_dqn/turtlebot3_dqn/dqn_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | ################################################################################# 3 | # Copyright 2019 ROBOTIS CO., LTD. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | ################################################################################# 17 | # 18 | # Authors: Ryan Shim, Gilbert, ChanHyeong Lee 19 | 20 | import collections 21 | import os 22 | import sys 23 | import time 24 | 25 | import numpy 26 | import rclpy 27 | from rclpy.node import Node 28 | from tensorflow.keras.layers import Dense 29 | from tensorflow.keras.losses import MeanSquaredError 30 | from tensorflow.keras.models import load_model 31 | from tensorflow.keras.models import Sequential 32 | from tensorflow.keras.optimizers import RMSprop 33 | 34 | from turtlebot3_msgs.srv import Dqn 35 | 36 | 37 | class DQNTest(Node): 38 | 39 | def __init__(self, stage, load_episode): 40 | super().__init__('dqn_test') 41 | 42 | self.stage = int(stage) 43 | self.load_episode = int(load_episode) 44 | 45 | self.state_size = 26 46 | self.action_size = 5 47 | 48 | self.memory = collections.deque(maxlen=1000000) 49 | 50 | self.model = self.build_model() 51 | model_path = os.path.join( 52 | os.path.dirname(os.path.dirname(os.path.realpath(__file__))), 53 | 'saved_model', 54 | f'stage{self.stage}_episode{self.load_episode}.h5' 55 | ) 56 | 57 | loaded_model = load_model( 58 | model_path, compile=False, custom_objects={'mse': MeanSquaredError()} 59 | ) 60 | self.model.set_weights(loaded_model.get_weights()) 61 | 62 | self.rl_agent_interface_client = self.create_client(Dqn, 'rl_agent_interface') 63 | 64 | self.run_test() 65 | 66 | def build_model(self): 67 | model = Sequential() 68 | model.add(Dense( 69 | 512, input_shape=(self.state_size,), 70 | activation='relu', 71 | kernel_initializer='lecun_uniform' 72 | )) 73 | model.add(Dense(256, activation='relu', kernel_initializer='lecun_uniform')) 74 | model.add(Dense(128, activation='relu', kernel_initializer='lecun_uniform')) 75 | model.add(Dense(self.action_size, activation='linear', kernel_initializer='lecun_uniform')) 76 | model.compile(loss=MeanSquaredError(), optimizer=RMSprop(learning_rate=0.00025)) 77 | return model 78 | 79 | def get_action(self, state): 80 | state = numpy.asarray(state) 81 | q_values = self.model.predict(state.reshape(1, -1), verbose=0) 82 | return int(numpy.argmax(q_values[0])) 83 | 84 | def run_test(self): 85 | while True: 86 | done = False 87 | init = True 88 | score = 0 89 | local_step = 0 90 | next_state = [] 91 | 92 | time.sleep(1.0) 93 | 94 | while not done: 95 | local_step += 1 96 | action = 2 if local_step == 1 else self.get_action(next_state) 97 | 98 | req = Dqn.Request() 99 | req.action = action 100 | req.init = init 101 | 102 | while not self.rl_agent_interface_client.wait_for_service(timeout_sec=1.0): 103 | self.get_logger().warn( 104 | 'rl_agent interface service not available, waiting again...') 105 | 106 | future = self.rl_agent_interface_client.call_async(req) 107 | rclpy.spin_until_future_complete(self, future) 108 | 109 | if future.done() and future.result() is not None: 110 | next_state = future.result().state 111 | reward = future.result().reward 112 | done = future.result().done 113 | score += reward 114 | init = False 115 | else: 116 | self.get_logger().error(f'Service call failure: {future.exception()}') 117 | 118 | time.sleep(0.01) 119 | 120 | 121 | def main(args=None): 122 | rclpy.init(args=args if args else sys.argv) 123 | stage = sys.argv[1] if len(sys.argv) > 1 else '1' 124 | load_episode = sys.argv[2] if len(sys.argv) > 2 else '600' 125 | node = DQNTest(stage, load_episode) 126 | 127 | try: 128 | rclpy.spin(node) 129 | except KeyboardInterrupt: 130 | pass 131 | finally: 132 | node.destroy_node() 133 | 
rclpy.shutdown()
134 | 
135 | 
136 | if __name__ == '__main__':
137 |     main()
138 | 
--------------------------------------------------------------------------------
/turtlebot3_dqn/turtlebot3_dqn/result_graph.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #################################################################################
3 | # Copyright 2018 ROBOTIS CO., LTD.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #################################################################################
17 | #
18 | # Authors: Ryan Shim, Gilbert, ChanHyeong Lee
19 | 
20 | import signal
21 | import sys
22 | import threading
23 | 
24 | from PyQt5.QtCore import QTimer
25 | from PyQt5.QtWidgets import QApplication
26 | from PyQt5.QtWidgets import QMainWindow
27 | import pyqtgraph
28 | import rclpy
29 | from rclpy.node import Node
30 | from std_msgs.msg import Float32MultiArray
31 | 
32 | 
33 | class GraphSubscriber(Node):
34 | 
35 |     def __init__(self, window):
36 |         super().__init__('graph')
37 | 
38 |         self.window = window
39 | 
40 |         self.subscription = self.create_subscription(
41 |             Float32MultiArray,
42 |             '/result',
43 |             self.data_callback,
44 |             10
45 |         )
46 |         self.subscription  # prevent unused variable warning
47 | 
48 |     def data_callback(self, msg):
49 |         self.window.receive_data(msg)
50 | 
51 | 
52 | class Window(QMainWindow):
53 | 
54 |     def __init__(self):
55 |         super(Window, self).__init__()
56 | 
57 |         self.setWindowTitle('Result')
58 |         self.setGeometry(50, 50, 600, 650)
59 | 
60 |         self.ep = []
61 |         self.data_list = []
62 |         self.rewards = []
63 |         self.count = 1
64 | 
65 |         self.plot()
66 | 
67 |         self.ros_subscriber = GraphSubscriber(self)
68 |         self.ros_thread = threading.Thread(
69 |             target=rclpy.spin, args=(self.ros_subscriber,), daemon=True
70 |         )
71 |         self.ros_thread.start()
72 | 
73 |     def receive_data(self, msg):
74 |         self.data_list.append(msg.data[0])  # data[0] is drawn on the 'Total reward' plot
75 |         self.ep.append(self.count)
76 |         self.count += 1
77 |         self.rewards.append(msg.data[1])  # data[1] is drawn on the 'Average max Q-value' plot
78 | 
79 |     def plot(self):
80 |         self.qValuePlt = pyqtgraph.PlotWidget(self, title='Average max Q-value')
81 |         self.qValuePlt.setGeometry(0, 320, 600, 300)
82 | 
83 |         self.rewardsPlt = pyqtgraph.PlotWidget(self, title='Total reward')
84 |         self.rewardsPlt.setGeometry(0, 10, 600, 300)
85 | 
86 |         self.timer = QTimer()
87 |         self.timer.timeout.connect(self.update_plot)
88 |         self.timer.start(200)
89 | 
90 |         self.show()
91 | 
92 |     def update_plot(self):  # named update_plot to avoid shadowing QWidget.update()
93 |         self.rewardsPlt.showGrid(x=True, y=True)
94 |         self.qValuePlt.showGrid(x=True, y=True)
95 | 
96 |         self.rewardsPlt.plot(self.ep, self.data_list, pen=(255, 0, 0), clear=True)
97 |         self.qValuePlt.plot(self.ep, self.rewards, pen=(0, 255, 0), clear=True)
98 | 
99 |     def closeEvent(self, event):
100 |         if self.ros_subscriber is not None:
101 |             self.ros_subscriber.destroy_node()
102 |         rclpy.shutdown()
103 |         event.accept()
104 | 
105 | 
106 | def main():
107 |     rclpy.init()
108 |     app = QApplication(sys.argv)
109 |     win = Window()
110 | 
111 |     def shutdown_handler(sig, frame):
112 |         print('shutdown')
113 |         if win.ros_subscriber is not None:
114 |             win.ros_subscriber.destroy_node()
115 |         rclpy.shutdown()
116 |         app.quit()
117 | 
118 |     signal.signal(signal.SIGINT, shutdown_handler)
119 |     signal.signal(signal.SIGTERM, shutdown_handler)
120 |     sys.exit(app.exec())
121 | 
122 | 
123 | if __name__ == '__main__':
124 |     main()
125 | 
--------------------------------------------------------------------------------
/turtlebot3_machine_learning/CHANGELOG.rst:
--------------------------------------------------------------------------------
1 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2 | Changelog for package turtlebot3_machine_learning
3 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
4 | 
5 | 1.0.1 (2025-05-02)
6 | ------------------
7 | * Support for ROS 2 Jazzy version
8 | * Gazebo simulation support for the package
9 | * Contributors: ChanHyeong Lee
10 | 
11 | 1.0.0 (2025-04-17)
12 | ------------------
13 | * Support for ROS 2 Humble version
14 | * Renewal of package structure
15 | * Improved behavioral rewards for agents
16 | * Contributors: ChanHyeong Lee
17 | 
--------------------------------------------------------------------------------
/turtlebot3_machine_learning/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | ################################################################################
2 | # Set minimum required version of cmake, project name and compile options
3 | ################################################################################
4 | cmake_minimum_required(VERSION 3.5)
5 | project(turtlebot3_machine_learning)
6 | 
7 | if(NOT CMAKE_CXX_STANDARD)
8 |   set(CMAKE_CXX_STANDARD 17)
9 | endif()
10 | 
11 | ################################################################################
12 | # Find ament packages and libraries for ament and system dependencies
13 | ################################################################################
14 | find_package(ament_cmake REQUIRED)
15 | 
16 | ################################################################################
17 | # Macro for ament package
18 | ################################################################################
19 | ament_package()
20 | 
--------------------------------------------------------------------------------
/turtlebot3_machine_learning/package.xml:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | turtlebot3_machine_learning
4 | 1.0.1
5 | 
6 | This metapackage is for ROS 2 TurtleBot3 machine learning.
7 | 
8 | Pyo
9 | Apache 2.0
10 | http://turtlebot3.robotis.com
11 | https://github.com/ROBOTIS-GIT/turtlebot3_machine_learning
12 | https://github.com/ROBOTIS-GIT/turtlebot3_machine_learning/issues
13 | Gilbert
14 | ChanHyeong Lee
15 | turtlebot3_dqn
16 | 
17 | ament_cmake
18 | 
19 | 
20 | 
--------------------------------------------------------------------------------
/turtlebot3_machine_learning_ci.repos:
--------------------------------------------------------------------------------
1 | repositories:
2 |   turtlebot3_msgs:
3 |     type: git
4 |     url: https://github.com/ROBOTIS-GIT/turtlebot3_msgs.git
5 |     version: main
6 | 
--------------------------------------------------------------------------------