├── .github
│   └── workflows
│       ├── ros-ci.yml
│       └── ros-lint.yml
├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── turtlebot3_dqn
│   ├── CHANGELOG.rst
│   ├── package.xml
│   ├── resource
│   │   └── turtlebot3_dqn
│   ├── saved_model
│   │   ├── stage1_episode600.h5
│   │   └── stage1_episode600.json
│   ├── setup.cfg
│   ├── setup.py
│   └── turtlebot3_dqn
│       ├── __init__.py
│       ├── action_graph.py
│       ├── dqn_agent.py
│       ├── dqn_environment.py
│       ├── dqn_gazebo.py
│       ├── dqn_test.py
│       └── result_graph.py
├── turtlebot3_machine_learning
│   ├── CHANGELOG.rst
│   ├── CMakeLists.txt
│   └── package.xml
└── turtlebot3_machine_learning_ci.repos
/.github/workflows/ros-ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 |   push:
5 |     branches: [ humble, jazzy, main ]
6 |   pull_request:
7 |     branches: [ humble, jazzy, main ]
8 |
9 | jobs:
10 |   ROS_CI:
11 |     runs-on: ubuntu-22.04
12 |     strategy:
13 |       fail-fast: false
14 |       matrix:
15 |         ros_distribution:
16 |           - humble
17 |           - jazzy
18 |           - rolling
19 |         include:
20 |           # ROS 2 Humble Hawksbill
21 |           - docker_image: ubuntu:jammy
22 |             ros_distribution: humble
23 |             ros_version: 2
24 |           # ROS 2 Jazzy Jalisco
25 |           - docker_image: ubuntu:noble
26 |             ros_distribution: jazzy
27 |             ros_version: 2
28 |           # ROS 2 Rolling Ridley
29 |           - docker_image: ubuntu:noble
30 |             ros_distribution: rolling
31 |             ros_version: 2
32 |     container:
33 |       image: ${{ matrix.docker_image }}
34 |     steps:
35 |       - name: Setup workspace
36 |         run: mkdir -p ros_ws/src
37 |
38 |       - name: Checkout code
39 |         uses: actions/checkout@v4
40 |         with:
41 |           path: ros_ws/src
42 |
43 |       - name: Setup ROS environment
44 |         uses: ros-tooling/setup-ros@v0.7
45 |         with:
46 |           required-ros-distributions: ${{ matrix.ros_distribution }}
47 |
48 |       - name: Add pip break-system-packages for rosdep
49 |         run: |
50 |           printf "[install]\nbreak-system-packages = true\n" | sudo tee /etc/pip.conf
51 |
52 |       - name: Build and Test
53 |         uses: ros-tooling/action-ros-ci@v0.3
54 |         env:
55 |           PIP_BREAK_SYSTEM_PACKAGES: "1"
56 |         with:
57 |           target-ros2-distro: ${{ matrix.ros_distribution }}
58 |           vcs-repo-file-url: "https://raw.githubusercontent.com/ROBOTIS-GIT/turtlebot3_machine_learning/main/turtlebot3_machine_learning_ci.repos"
59 |           package-name: |
60 |             turtlebot3_dqn
61 |             turtlebot3_machine_learning
62 |
--------------------------------------------------------------------------------
/.github/workflows/ros-lint.yml:
--------------------------------------------------------------------------------
1 | # The name of the workflow
2 | name: Lint
3 |
4 | # Specifies the events that trigger the workflow
5 | on:
6 |   pull_request:
7 |
8 | # Defines a set of jobs to be run as part of the workflow
9 | jobs:
10 |   ament_lint:
11 |     runs-on: ubuntu-latest
12 |     container:
13 |       image: rostooling/setup-ros-docker:ubuntu-noble-ros-rolling-ros-base-latest
14 |     strategy:
15 |       fail-fast: false
16 |       matrix:
17 |         linter: [flake8, pep257, xmllint, copyright]
18 |     steps:
19 |       - name: Checkout code
20 |         uses: actions/checkout@v4
21 |
22 |       - name: Setup ROS environment
23 |         uses: ros-tooling/setup-ros@v0.7
24 |
25 |       - name: Run Linter
26 |         env:
27 |           AMENT_CPPCHECK_ALLOW_SLOW_VERSIONS: 1
28 |         uses: ros-tooling/action-ros-lint@master
29 |         with:
30 |           linter: ${{ matrix.linter }}
31 |           distribution: rolling
32 |           package-name: "*"
33 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 |
107 | # own
108 | /docker_setup/backup
109 | *.idea
110 | *.log
111 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | Any contribution that you make to this repository will
2 | be under the Apache 2 License, as dictated by that
3 | [license](http://www.apache.org/licenses/LICENSE-2.0.html):
4 |
5 | ~~~
6 | 5. Submission of Contributions. Unless You explicitly state otherwise,
7 |    any Contribution intentionally submitted for inclusion in the Work
8 |    by You to the Licensor shall be under the terms and conditions of
9 |    this License, without any additional terms or conditions.
10 |    Notwithstanding the above, nothing herein shall supersede or modify
11 |    the terms of any separate license agreement you may have executed
12 |    with Licensor regarding such Contributions.
13 | ~~~
14 |
15 | Contributors must sign off on each commit by adding a `Signed-off-by: ...`
16 | line to commit messages to certify that they have the right to submit
17 | the code they are contributing to the project according to the
18 | [Developer Certificate of Origin (DCO)](https://developercertificate.org/).
19 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |                                  Apache License
2 |                            Version 2.0, January 2004
3 |                         http://www.apache.org/licenses/
4 |
5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 |    1. Definitions.
8 |
9 |       "License" shall mean the terms and conditions for use, reproduction,
10 |       and distribution as defined by Sections 1 through 9 of this document.
11 |
12 |       "Licensor" shall mean the copyright owner or entity authorized by
13 |       the copyright owner that is granting the License.
14 |
15 |       "Legal Entity" shall mean the union of the acting entity and all
16 |       other entities that control, are controlled by, or are under common
17 |       control with that entity. For the purposes of this definition,
18 |       "control" means (i) the power, direct or indirect, to cause the
19 |       direction or management of such entity, whether by contract or
20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 |       outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 |       "You" (or "Your") shall mean an individual or Legal Entity
24 |       exercising permissions granted by this License.
25 |
26 |       "Source" form shall mean the preferred form for making modifications,
27 |       including but not limited to software source code, documentation
28 |       source, and configuration files.
29 |
30 |       "Object" form shall mean any form resulting from mechanical
31 |       transformation or translation of a Source form, including but
32 |       not limited to compiled object code, generated documentation,
33 |       and conversions to other media types.
34 |
35 |       "Work" shall mean the work of authorship, whether in Source or
36 |       Object form, made available under the License, as indicated by a
37 |       copyright notice that is included in or attached to the work
38 |       (an example is provided in the Appendix below).
39 |
40 |       "Derivative Works" shall mean any work, whether in Source or Object
41 |       form, that is based on (or derived from) the Work and for which the
42 |       editorial revisions, annotations, elaborations, or other modifications
43 |       represent, as a whole, an original work of authorship. For the purposes
44 |       of this License, Derivative Works shall not include works that remain
45 |       separable from, or merely link (or bind by name) to the interfaces of,
46 |       the Work and Derivative Works thereof.
47 |
48 |       "Contribution" shall mean any work of authorship, including
49 |       the original version of the Work and any modifications or additions
50 |       to that Work or Derivative Works thereof, that is intentionally
51 |       submitted to Licensor for inclusion in the Work by the copyright owner
52 |       or by an individual or Legal Entity authorized to submit on behalf of
53 |       the copyright owner. For the purposes of this definition, "submitted"
54 |       means any form of electronic, verbal, or written communication sent
55 |       to the Licensor or its representatives, including but not limited to
56 |       communication on electronic mailing lists, source code control systems,
57 |       and issue tracking systems that are managed by, or on behalf of, the
58 |       Licensor for the purpose of discussing and improving the Work, but
59 |       excluding communication that is conspicuously marked or otherwise
60 |       designated in writing by the copyright owner as "Not a Contribution."
61 |
62 |       "Contributor" shall mean Licensor and any individual or Legal Entity
63 |       on behalf of whom a Contribution has been received by Licensor and
64 |       subsequently incorporated within the Work.
65 |
66 |    2. Grant of Copyright License. Subject to the terms and conditions of
67 |       this License, each Contributor hereby grants to You a perpetual,
68 |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 |       copyright license to reproduce, prepare Derivative Works of,
70 |       publicly display, publicly perform, sublicense, and distribute the
71 |       Work and such Derivative Works in Source or Object form.
72 |
73 |    3. Grant of Patent License. Subject to the terms and conditions of
74 |       this License, each Contributor hereby grants to You a perpetual,
75 |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 |       (except as stated in this section) patent license to make, have made,
77 |       use, offer to sell, sell, import, and otherwise transfer the Work,
78 |       where such license applies only to those patent claims licensable
79 |       by such Contributor that are necessarily infringed by their
80 |       Contribution(s) alone or by combination of their Contribution(s)
81 |       with the Work to which such Contribution(s) was submitted. If You
82 |       institute patent litigation against any entity (including a
83 |       cross-claim or counterclaim in a lawsuit) alleging that the Work
84 |       or a Contribution incorporated within the Work constitutes direct
85 |       or contributory patent infringement, then any patent licenses
86 |       granted to You under this License for that Work shall terminate
87 |       as of the date such litigation is filed.
88 |
89 |    4. Redistribution. You may reproduce and distribute copies of the
90 |       Work or Derivative Works thereof in any medium, with or without
91 |       modifications, and in Source or Object form, provided that You
92 |       meet the following conditions:
93 |
94 |       (a) You must give any other recipients of the Work or
95 |           Derivative Works a copy of this License; and
96 |
97 |       (b) You must cause any modified files to carry prominent notices
98 |           stating that You changed the files; and
99 |
100 |       (c) You must retain, in the Source form of any Derivative Works
101 |           that You distribute, all copyright, patent, trademark, and
102 |           attribution notices from the Source form of the Work,
103 |           excluding those notices that do not pertain to any part of
104 |           the Derivative Works; and
105 |
106 |       (d) If the Work includes a "NOTICE" text file as part of its
107 |           distribution, then any Derivative Works that You distribute must
108 |           include a readable copy of the attribution notices contained
109 |           within such NOTICE file, excluding those notices that do not
110 |           pertain to any part of the Derivative Works, in at least one
111 |           of the following places: within a NOTICE text file distributed
112 |           as part of the Derivative Works; within the Source form or
113 |           documentation, if provided along with the Derivative Works; or,
114 |           within a display generated by the Derivative Works, if and
115 |           wherever such third-party notices normally appear. The contents
116 |           of the NOTICE file are for informational purposes only and
117 |           do not modify the License. You may add Your own attribution
118 |           notices within Derivative Works that You distribute, alongside
119 |           or as an addendum to the NOTICE text from the Work, provided
120 |           that such additional attribution notices cannot be construed
121 |           as modifying the License.
122 |
123 |       You may add Your own copyright statement to Your modifications and
124 |       may provide additional or different license terms and conditions
125 |       for use, reproduction, or distribution of Your modifications, or
126 |       for any such Derivative Works as a whole, provided Your use,
127 |       reproduction, and distribution of the Work otherwise complies with
128 |       the conditions stated in this License.
129 |
130 |    5. Submission of Contributions. Unless You explicitly state otherwise,
131 |       any Contribution intentionally submitted for inclusion in the Work
132 |       by You to the Licensor shall be under the terms and conditions of
133 |       this License, without any additional terms or conditions.
134 |       Notwithstanding the above, nothing herein shall supersede or modify
135 |       the terms of any separate license agreement you may have executed
136 |       with Licensor regarding such Contributions.
137 |
138 |    6. Trademarks. This License does not grant permission to use the trade
139 |       names, trademarks, service marks, or product names of the Licensor,
140 |       except as required for reasonable and customary use in describing the
141 |       origin of the Work and reproducing the content of the NOTICE file.
142 |
143 |    7. Disclaimer of Warranty. Unless required by applicable law or
144 |       agreed to in writing, Licensor provides the Work (and each
145 |       Contributor provides its Contributions) on an "AS IS" BASIS,
146 |       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 |       implied, including, without limitation, any warranties or conditions
148 |       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 |       PARTICULAR PURPOSE. You are solely responsible for determining the
150 |       appropriateness of using or redistributing the Work and assume any
151 |       risks associated with Your exercise of permissions under this License.
152 |
153 |    8. Limitation of Liability. In no event and under no legal theory,
154 |       whether in tort (including negligence), contract, or otherwise,
155 |       unless required by applicable law (such as deliberate and grossly
156 |       negligent acts) or agreed to in writing, shall any Contributor be
157 |       liable to You for damages, including any direct, indirect, special,
158 |       incidental, or consequential damages of any character arising as a
159 |       result of this License or out of the use or inability to use the
160 |       Work (including but not limited to damages for loss of goodwill,
161 |       work stoppage, computer failure or malfunction, or any and all
162 |       other commercial damages or losses), even if such Contributor
163 |       has been advised of the possibility of such damages.
164 |
165 |    9. Accepting Warranty or Additional Liability. While redistributing
166 |       the Work or Derivative Works thereof, You may choose to offer,
167 |       and charge a fee for, acceptance of support, warranty, indemnity,
168 |       or other liability obligations and/or rights consistent with this
169 |       License. However, in accepting such obligations, You may act only
170 |       on Your own behalf and on Your sole responsibility, not on behalf
171 |       of any other Contributor, and only if You agree to indemnify,
172 |       defend, and hold each Contributor harmless for any liability
173 |       incurred by, or claims asserted against, such Contributor by reason
174 |       of your accepting any such warranty or additional liability.
175 |
176 |    END OF TERMS AND CONDITIONS
177 |
178 |    APPENDIX: How to apply the Apache License to your work.
179 |
180 |       To apply the Apache License to your work, attach the following
181 |       boilerplate notice, with the fields enclosed by brackets "[]"
182 |       replaced with your own identifying information. (Don't include
183 |       the brackets!) The text should be enclosed in the appropriate
184 |       comment syntax for the file format. We also recommend that a
185 |       file or class name and description of purpose be included on the
186 |       same "printed page" as the copyright notice for easier
187 |       identification within third-party archives.
188 |
189 |    Copyright [yyyy] [name of copyright owner]
190 |
191 |    Licensed under the Apache License, Version 2.0 (the "License");
192 |    you may not use this file except in compliance with the License.
193 |    You may obtain a copy of the License at
194 |
195 |        http://www.apache.org/licenses/LICENSE-2.0
196 |
197 |    Unless required by applicable law or agreed to in writing, software
198 |    distributed under the License is distributed on an "AS IS" BASIS,
199 |    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 |    See the License for the specific language governing permissions and
201 |    limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TurtleBot3
2 |
3 |
4 | - Active Branches: noetic, humble, jazzy, main (rolling)
5 | - Legacy Branches: *-devel
6 |
7 | ## Open Source Projects Related to TurtleBot3
8 | - [turtlebot3](https://github.com/ROBOTIS-GIT/turtlebot3)
9 | - [turtlebot3_msgs](https://github.com/ROBOTIS-GIT/turtlebot3_msgs)
10 | - [turtlebot3_simulations](https://github.com/ROBOTIS-GIT/turtlebot3_simulations)
11 | - [turtlebot3_manipulation](https://github.com/ROBOTIS-GIT/turtlebot3_manipulation)
12 | - [turtlebot3_manipulation_simulations](https://github.com/ROBOTIS-GIT/turtlebot3_manipulation_simulations)
13 | - [turtlebot3_applications](https://github.com/ROBOTIS-GIT/turtlebot3_applications)
14 | - [turtlebot3_applications_msgs](https://github.com/ROBOTIS-GIT/turtlebot3_applications_msgs)
15 | - [turtlebot3_machine_learning](https://github.com/ROBOTIS-GIT/turtlebot3_machine_learning)
16 | - [turtlebot3_autorace](https://github.com/ROBOTIS-GIT/turtlebot3_autorace)
17 | - [turtlebot3_home_service_challenge](https://github.com/ROBOTIS-GIT/turtlebot3_home_service_challenge)
18 | - [hls_lfcd_lds_driver](https://github.com/ROBOTIS-GIT/hls_lfcd_lds_driver)
19 | - [ld08_driver](https://github.com/ROBOTIS-GIT/ld08_driver)
20 | - [open_manipulator](https://github.com/ROBOTIS-GIT/open_manipulator)
21 | - [dynamixel_sdk](https://github.com/ROBOTIS-GIT/DynamixelSDK)
22 | - [OpenCR-Hardware](https://github.com/ROBOTIS-GIT/OpenCR-Hardware)
23 | - [OpenCR](https://github.com/ROBOTIS-GIT/OpenCR)
24 |
25 | ## Documentation, Videos, and Community
26 |
27 | ### Official Documentation
28 | - ⚙️ **[ROBOTIS DYNAMIXEL](https://dynamixel.com/)**
29 | - 📚 **[ROBOTIS e-Manual for Dynamixel SDK](http://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/overview/)**
30 | - 📚 **[ROBOTIS e-Manual for TurtleBot3](http://turtlebot3.robotis.com/)**
31 | - 📚 **[ROBOTIS e-Manual for OpenMANIPULATOR-X](https://emanual.robotis.com/docs/en/platform/openmanipulator_x/overview/)**
32 |
33 | ### Learning Resources
34 | - 🎥 **[ROBOTIS YouTube Channel](https://www.youtube.com/@ROBOTISCHANNEL)**
35 | - 🎥 **[ROBOTIS Open Source YouTube Channel](https://www.youtube.com/@ROBOTISOpenSourceTeam)**
36 | - 🎥 **[ROBOTIS TurtleBot3 YouTube Playlist](https://www.youtube.com/playlist?list=PLRG6WP3c31_XI3wlvHlx2Mp8BYqgqDURU)**
37 | - 🎥 **[ROBOTIS OpenMANIPULATOR YouTube Playlist](https://www.youtube.com/playlist?list=PLRG6WP3c31_WpEsB6_Rdt3KhiopXQlUkb)**
38 |
39 | ### Community & Support
40 | - 💬 **[ROBOTIS Community Forum](https://forum.robotis.com/)**
41 | - 💬 **[TurtleBot category from ROS Community](https://discourse.ros.org/c/turtlebot/)**
42 |
--------------------------------------------------------------------------------
/turtlebot3_dqn/CHANGELOG.rst:
--------------------------------------------------------------------------------
1 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2 | Changelog for package turtlebot3_dqn
3 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
4 |
5 | 1.0.1 (2025-05-02)
6 | ------------------
7 | * Support for ROS 2 Jazzy version
8 | * Gazebo simulation support for the package
9 | * Contributors: ChanHyeong Lee
10 |
11 | 1.0.0 (2025-04-17)
12 | ------------------
13 | * Support for ROS 2 Humble version
14 | * Renewal of package structure
15 | * Improved behavioral rewards for agents
16 | * Contributors: ChanHyeong Lee
17 |
--------------------------------------------------------------------------------
/turtlebot3_dqn/package.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0"?>
2 | <package format="3">
3 |   <name>turtlebot3_dqn</name>
4 |   <version>1.0.1</version>
5 |   <description>
6 |     The turtlebot3_dqn package using reinforcement learning with DQN (Deep Q-Learning).
7 |   </description>
8 |   <maintainer email="pyo@robotis.com">Pyo</maintainer>
9 |   <license>Apache 2.0</license>
10 |   <url type="website">http://turtlebot3.robotis.com</url>
11 |   <url type="repository">https://github.com/ROBOTIS-GIT/turtlebot3_machine_learning</url>
12 |   <url type="bugtracker">https://github.com/ROBOTIS-GIT/turtlebot3_machine_learning/issues</url>
13 |   <author email="kkjong@robotis.com">Gilbert</author>
14 |   <author email="dddoggi1207@gmail.com">ChanHyeong Lee</author>
15 |   <depend>python3-pip</depend>
16 |   <depend>ament_index_python</depend>
17 |   <depend>geometry_msgs</depend>
18 |   <depend>python-tensorflow-pip</depend>
19 |   <depend>python3-numpy</depend>
20 |   <depend>python3-pyqt5</depend>
21 |   <depend>python3-pyqtgraph</depend>
22 |   <depend>rclpy</depend>
23 |   <depend>sensor_msgs</depend>
24 |   <depend>std_msgs</depend>
25 |   <depend>std_srvs</depend>
26 |   <depend>turtlebot3_msgs</depend>
27 |
28 |   <export>
29 |     <build_type>ament_python</build_type>
30 |   </export>
31 | </package>
--------------------------------------------------------------------------------
/turtlebot3_dqn/resource/turtlebot3_dqn:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ROBOTIS-GIT/turtlebot3_machine_learning/95847712f2d06507dffb6c94b2f356ff38dfa79f/turtlebot3_dqn/resource/turtlebot3_dqn
--------------------------------------------------------------------------------
/turtlebot3_dqn/saved_model/stage1_episode600.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ROBOTIS-GIT/turtlebot3_machine_learning/95847712f2d06507dffb6c94b2f356ff38dfa79f/turtlebot3_dqn/saved_model/stage1_episode600.h5
--------------------------------------------------------------------------------
/turtlebot3_dqn/saved_model/stage1_episode600.json:
--------------------------------------------------------------------------------
1 | {"epsilon": 0.1416302983127139}
2 |
--------------------------------------------------------------------------------
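
The JSON half of a checkpoint carries only the exploration state, while the matching `.h5` file carries the Keras weights. A minimal sketch of how `dqn_agent.py` consumes such a pair when its `load_model` flag is enabled (paths here are illustrative; the agent builds them from the stage and episode number):

```python
import json

from tensorflow.keras.models import load_model

# Illustrative paths matching the checkpoint above.
model_path = 'saved_model/stage1_episode600.h5'    # Q-network weights
param_path = 'saved_model/stage1_episode600.json'  # exploration state

model = load_model(model_path)  # restore the trained Q-network
with open(param_path) as f:
    param = json.load(f)

epsilon = param.get('epsilon')  # 0.1416... for the checkpoint above
print('resuming with epsilon =', epsilon)
```
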
/turtlebot3_dqn/setup.cfg:
--------------------------------------------------------------------------------
1 | [develop]
2 | script_dir=$base/lib/turtlebot3_dqn
3 | [install]
4 | install_scripts=$base/lib/turtlebot3_dqn
5 |
--------------------------------------------------------------------------------
/turtlebot3_dqn/setup.py:
--------------------------------------------------------------------------------
1 | import glob
2 |
3 | from setuptools import find_packages
4 | from setuptools import setup
5 |
6 | package_name = 'turtlebot3_dqn'
7 | authors_info = [
8 |     ('Gilbert', 'kkjong@robotis.com'),
9 |     ('Ryan Shim', 'N/A'),
10 |     ('ChanHyeong Lee', 'dddoggi1207@gmail.com'),
11 | ]
12 | authors = ', '.join(author for author, _ in authors_info)
13 | author_emails = ', '.join(email for _, email in authors_info)
14 |
15 | setup(
16 |     name=package_name,
17 |     version='1.0.1',
18 |     packages=find_packages(),
19 |     data_files=[
20 |         ('share/ament_index/resource_index/packages', ['resource/' + package_name]),
21 |         ('share/' + package_name, ['package.xml']),
22 |         ('share/' + package_name + '/launch', glob.glob('launch/*.py')),
23 |     ],
24 |     install_requires=['setuptools', 'launch'],
25 |     zip_safe=True,
26 |     author=authors,
27 |     author_email=author_emails,
28 |     maintainer='Pyo',
29 |     maintainer_email='pyo@robotis.com',
30 |     description='ROS 2 packages for TurtleBot3 machine learning',
31 |     license='Apache 2.0',
32 |     tests_require=['pytest'],
33 |     entry_points={
34 |         'console_scripts': [
35 |             'action_graph = turtlebot3_dqn.action_graph:main',
36 |             'dqn_agent = turtlebot3_dqn.dqn_agent:main',
37 |             'dqn_environment = turtlebot3_dqn.dqn_environment:main',
38 |             'dqn_gazebo = turtlebot3_dqn.dqn_gazebo:main',
39 |             'dqn_test = turtlebot3_dqn.dqn_test:main',
40 |             'result_graph = turtlebot3_dqn.result_graph:main',
41 |         ],
42 |     },
43 | )
44 |
--------------------------------------------------------------------------------
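
Each `console_scripts` entry above becomes an executable installed under `lib/turtlebot3_dqn` (per `setup.cfg`), which is what lets `ros2 run` find the nodes. A small sketch of verifying that wiring from Python, assuming Python 3.10+ and that the package is installed in the current environment:

```python
from importlib.metadata import entry_points

# Look up the dqn_agent entry point declared in setup.py.
(ep,) = entry_points(group='console_scripts', name='dqn_agent')
print(ep.value)   # 'turtlebot3_dqn.dqn_agent:main'
main = ep.load()  # the turtlebot3_dqn.dqn_agent.main function
```
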
/turtlebot3_dqn/turtlebot3_dqn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ROBOTIS-GIT/turtlebot3_machine_learning/95847712f2d06507dffb6c94b2f356ff38dfa79f/turtlebot3_dqn/turtlebot3_dqn/__init__.py
--------------------------------------------------------------------------------
/turtlebot3_dqn/turtlebot3_dqn/action_graph.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #################################################################################
3 | # Copyright 2018 ROBOTIS CO., LTD.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #################################################################################
17 | #
18 | # Authors: Ryan Shim, Gilbert, ChanHyeong Lee
19 |
20 | import signal
21 | import sys
22 | import threading
23 | import time
24 |
25 | from PyQt5.QtCore import pyqtSignal
26 | from PyQt5.QtCore import Qt
27 | from PyQt5.QtCore import QThread
28 | from PyQt5.QtWidgets import QApplication
29 | from PyQt5.QtWidgets import QGridLayout
30 | from PyQt5.QtWidgets import QLabel
31 | from PyQt5.QtWidgets import QLineEdit
32 | from PyQt5.QtWidgets import QProgressBar
33 | from PyQt5.QtWidgets import QWidget
34 | import rclpy
35 | from rclpy.node import Node
36 | from std_msgs.msg import Float32MultiArray
37 |
38 |
39 | class Ros2Subscriber(Node):
40 |
41 |     def __init__(self, qt_thread):
42 |         super().__init__('progress_subscriber')
43 |         self.qt_thread = qt_thread
44 |
45 |         self.subscription = self.create_subscription(
46 |             Float32MultiArray,
47 |             '/get_action',
48 |             self.get_array_callback,
49 |             10
50 |         )
51 |
52 |     def get_array_callback(self, msg):
53 |         data = list(msg.data)
54 |
55 |         self.qt_thread.signal_action0.emit(0)
56 |         self.qt_thread.signal_action1.emit(0)
57 |         self.qt_thread.signal_action2.emit(0)
58 |         self.qt_thread.signal_action3.emit(0)
59 |         self.qt_thread.signal_action4.emit(0)
60 |
61 |         if data[0] == 0:
62 |             self.qt_thread.signal_action0.emit(100)
63 |         elif data[0] == 1:
64 |             self.qt_thread.signal_action1.emit(100)
65 |         elif data[0] == 2:
66 |             self.qt_thread.signal_action2.emit(100)
67 |         elif data[0] == 3:
68 |             self.qt_thread.signal_action3.emit(100)
69 |         elif data[0] == 4:
70 |             self.qt_thread.signal_action4.emit(100)
71 |
72 |         if len(data) >= 2:
73 |             self.qt_thread.signal_total_reward.emit(str(round(data[-2], 2)))
74 |             self.qt_thread.signal_reward.emit(str(round(data[-1], 2)))
75 |
76 |
77 | class Thread(QThread):
78 |
79 |     signal_action0 = pyqtSignal(int)
80 |     signal_action1 = pyqtSignal(int)
81 |     signal_action2 = pyqtSignal(int)
82 |     signal_action3 = pyqtSignal(int)
83 |     signal_action4 = pyqtSignal(int)
84 |     signal_total_reward = pyqtSignal(str)
85 |     signal_reward = pyqtSignal(str)
86 |
87 |     def __init__(self):
88 |         super().__init__()
89 |         self.node = None
90 |
91 |     def run(self):
92 |         self.node = Ros2Subscriber(self)
93 |         rclpy.spin(self.node)
94 |         self.node.destroy_node()
95 |
96 |
97 | class Form(QWidget):
98 |
99 |     def __init__(self, qt_thread):
100 |         super().__init__(flags=Qt.Widget)
101 |         self.qt_thread = qt_thread
102 |         self.setWindowTitle('Action State')
103 |
104 |         layout = QGridLayout()
105 |
106 |         self.pgsb1 = QProgressBar()
107 |         self.pgsb1.setOrientation(Qt.Vertical)
108 |         self.pgsb1.setValue(0)
109 |         self.pgsb1.setRange(0, 100)
110 |
111 |         self.pgsb2 = QProgressBar()
112 |         self.pgsb2.setOrientation(Qt.Vertical)
113 |         self.pgsb2.setValue(0)
114 |         self.pgsb2.setRange(0, 100)
115 |
116 |         self.pgsb3 = QProgressBar()
117 |         self.pgsb3.setOrientation(Qt.Vertical)
118 |         self.pgsb3.setValue(0)
119 |         self.pgsb3.setRange(0, 100)
120 |
121 |         self.pgsb4 = QProgressBar()
122 |         self.pgsb4.setOrientation(Qt.Vertical)
123 |         self.pgsb4.setValue(0)
124 |         self.pgsb4.setRange(0, 100)
125 |
126 |         self.pgsb5 = QProgressBar()
127 |         self.pgsb5.setOrientation(Qt.Vertical)
128 |         self.pgsb5.setValue(0)
129 |         self.pgsb5.setRange(0, 100)
130 |
131 |         self.label_total_reward = QLabel('Total reward')
132 |         self.edit_total_reward = QLineEdit('')
133 |         self.edit_total_reward.setDisabled(True)
134 |         self.edit_total_reward.setFixedWidth(100)
135 |
136 |         self.label_reward = QLabel('Reward')
137 |         self.edit_reward = QLineEdit('')
138 |         self.edit_reward.setDisabled(True)
139 |         self.edit_reward.setFixedWidth(100)
140 |
141 |         self.label_left = QLabel('Left')
142 |         self.label_front = QLabel('Front')
143 |         self.label_right = QLabel('Right')
144 |
145 |         layout.addWidget(self.label_total_reward, 0, 0)
146 |         layout.addWidget(self.edit_total_reward, 1, 0)
147 |         layout.addWidget(self.label_reward, 2, 0)
148 |         layout.addWidget(self.edit_reward, 3, 0)
149 |
150 |         layout.addWidget(self.pgsb1, 0, 4, 4, 1)
151 |         layout.addWidget(self.pgsb2, 0, 5, 4, 1)
152 |         layout.addWidget(self.pgsb3, 0, 6, 4, 1)
153 |         layout.addWidget(self.pgsb4, 0, 7, 4, 1)
154 |         layout.addWidget(self.pgsb5, 0, 8, 4, 1)
155 |
156 |         layout.addWidget(self.label_left, 4, 4)
157 |         layout.addWidget(self.label_front, 4, 6)
158 |         layout.addWidget(self.label_right, 4, 8)
159 |
160 |         self.setLayout(layout)
161 |
162 |         qt_thread.signal_action0.connect(self.pgsb1.setValue)
163 |         qt_thread.signal_action1.connect(self.pgsb2.setValue)
164 |         qt_thread.signal_action2.connect(self.pgsb3.setValue)
165 |         qt_thread.signal_action3.connect(self.pgsb4.setValue)
166 |         qt_thread.signal_action4.connect(self.pgsb5.setValue)
167 |         qt_thread.signal_total_reward.connect(self.edit_total_reward.setText)
168 |         qt_thread.signal_reward.connect(self.edit_reward.setText)
169 |
170 |     def closeEvent(self, event):
171 |         if hasattr(self.qt_thread, 'node') and self.qt_thread.node is not None:
172 |             self.qt_thread.node.destroy_node()
173 |             rclpy.shutdown()
174 |         event.accept()
175 |
176 |
177 | def run_qt_app(qt_thread):
178 |     app = QApplication(sys.argv)
179 |     form = Form(qt_thread)
180 |     form.show()
181 |     app.exec_()
182 |
183 |
184 | def main():
185 |     rclpy.init()
186 |     qt_thread = Thread()
187 |     qt_thread.start()
188 |     qt_gui_thread = threading.Thread(target=run_qt_app, args=(qt_thread,), daemon=True)
189 |     qt_gui_thread.start()
190 |
191 |     def shutdown_handler(sig, frame):
192 |         print('shutdown')
193 |         qt_thread.node.destroy_node()
194 |         rclpy.shutdown()
195 |         sys.exit(0)
196 |
197 |     signal.signal(signal.SIGINT, shutdown_handler)
198 |     signal.signal(signal.SIGTERM, shutdown_handler)
199 |     try:
200 |         while rclpy.ok():
201 |             time.sleep(0.1)
202 |     except KeyboardInterrupt:
203 |         shutdown_handler(None, None)
204 |
205 |
206 | if __name__ == '__main__':
207 |     main()
208 |
--------------------------------------------------------------------------------
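
A quick way to exercise this GUI without running the full training stack is to publish a `/get_action` message by hand. The sketch below is hypothetical test code, not part of the package; it assumes the `[action, total_reward, reward]` layout that `dqn_agent.py` publishes:

```python
import rclpy
from rclpy.node import Node
from std_msgs.msg import Float32MultiArray


def main():
    rclpy.init()
    node = Node('get_action_smoke_test')
    pub = node.create_publisher(Float32MultiArray, '/get_action', 10)

    msg = Float32MultiArray()
    # Action index 2 (the front bar), total reward 15.5, step reward 0.5.
    msg.data = [2.0, 15.5, 0.5]

    for _ in range(10):  # publish a few times so the late-joining GUI catches one
        pub.publish(msg)
        rclpy.spin_once(node, timeout_sec=0.1)

    node.destroy_node()
    rclpy.shutdown()


if __name__ == '__main__':
    main()
```
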
/turtlebot3_dqn/turtlebot3_dqn/dqn_agent.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #################################################################################
3 | # Copyright 2019 ROBOTIS CO., LTD.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #################################################################################
17 | #
18 | # Authors: Ryan Shim, Gilbert, ChanHyeong Lee
19 |
20 | import collections
21 | import datetime
22 | import json
23 | import math
24 | import os
25 | import random
26 | import sys
27 | import time
28 |
29 | import numpy
30 | import rclpy
31 | from rclpy.node import Node
32 | from std_msgs.msg import Float32MultiArray
33 | from std_srvs.srv import Empty
34 | import tensorflow
35 | from tensorflow.keras.layers import Dense
36 | from tensorflow.keras.layers import Input
37 | from tensorflow.keras.losses import MeanSquaredError
38 | from tensorflow.keras.models import load_model
39 | from tensorflow.keras.models import Sequential
40 | from tensorflow.keras.optimizers import Adam
41 |
42 | from turtlebot3_msgs.srv import Dqn
43 |
44 |
45 | tensorflow.config.set_visible_devices([], 'GPU')
46 |
47 | LOGGING = True
48 | current_time = datetime.datetime.now().strftime('[%mm%dd-%H:%M]')
49 |
50 |
51 | class DQNMetric(tensorflow.keras.metrics.Metric):
52 |
53 |     def __init__(self, name='dqn_metric'):
54 |         super(DQNMetric, self).__init__(name=name)
55 |         self.loss = self.add_weight(name='loss', initializer='zeros')
56 |         self.episode_step = self.add_weight(name='step', initializer='zeros')
57 |
58 |     def update_state(self, y_true, y_pred=0, sample_weight=None):
59 |         self.loss.assign_add(y_true)
60 |         self.episode_step.assign_add(1)
61 |
62 |     def result(self):
63 |         return self.loss / self.episode_step
64 |
65 |     def reset_states(self):
66 |         self.loss.assign(0)
67 |         self.episode_step.assign(0)
68 |
69 |
70 | class DQNAgent(Node):
71 |
72 |     def __init__(self, stage_num, max_training_episodes):
73 |         super().__init__('dqn_agent')
74 |
75 |         self.stage = int(stage_num)
76 |         self.train_mode = True
77 |         self.state_size = 26
78 |         self.action_size = 5
79 |         self.max_training_episodes = int(max_training_episodes)
80 |
81 |         self.done = False
82 |         self.succeed = False
83 |         self.fail = False
84 |
85 |         self.discount_factor = 0.99
86 |         self.learning_rate = 0.0007
87 |         self.epsilon = 1.0
88 |         self.step_counter = 0
89 |         self.epsilon_decay = 6000 * self.stage
90 |         self.epsilon_min = 0.05
91 |         self.batch_size = 128
92 |
93 |         self.replay_memory = collections.deque(maxlen=500000)
94 |         self.min_replay_memory_size = 5000
95 |
96 |         self.model = self.create_qnetwork()
97 |         self.target_model = self.create_qnetwork()
98 |         self.update_target_model()
99 |         self.update_target_after = 5000
100 |         self.target_update_after_counter = 0
101 |
102 |         self.load_model = False
103 |         self.load_episode = 0
104 |         self.model_dir_path = os.path.join(
105 |             os.path.dirname(os.path.dirname(os.path.realpath(__file__))),
106 |             'saved_model'
107 |         )
108 |         self.model_path = os.path.join(
109 |             self.model_dir_path,
110 |             'stage' + str(self.stage) + '_episode' + str(self.load_episode) + '.h5'
111 |         )
112 |
113 |         if self.load_model:
114 |             self.model.set_weights(load_model(self.model_path).get_weights())
115 |             with open(os.path.join(
116 |                 self.model_dir_path,
117 |                 'stage' + str(self.stage) + '_episode' + str(self.load_episode) + '.json'
118 |             )) as outfile:
119 |                 param = json.load(outfile)
120 |                 self.epsilon = param.get('epsilon')
121 |                 self.step_counter = param.get('step_counter')
122 |
123 |         if LOGGING:
124 |             tensorboard_file_name = current_time + ' dqn_stage' + str(self.stage) + '_reward'
125 |             home_dir = os.path.expanduser('~')
126 |             dqn_reward_log_dir = os.path.join(
127 |                 home_dir, 'turtlebot3_dqn_logs', 'gradient_tape', tensorboard_file_name
128 |             )
129 |             self.dqn_reward_writer = tensorflow.summary.create_file_writer(dqn_reward_log_dir)
130 |             self.dqn_reward_metric = DQNMetric()
131 |
132 |         self.rl_agent_interface_client = self.create_client(Dqn, 'rl_agent_interface')
133 |         self.make_environment_client = self.create_client(Empty, 'make_environment')
134 |         self.reset_environment_client = self.create_client(Dqn, 'reset_environment')
135 |
136 |         self.action_pub = self.create_publisher(Float32MultiArray, '/get_action', 10)
137 |         self.result_pub = self.create_publisher(Float32MultiArray, 'result', 10)
138 |
139 |         self.process()
140 |
141 |     def process(self):
142 |         self.env_make()
143 |         time.sleep(1.0)
144 |
145 |         episode_num = self.load_episode
146 |
147 |         for episode in range(self.load_episode + 1, self.max_training_episodes + 1):
148 |             state = self.reset_environment()
149 |             episode_num += 1
150 |             local_step = 0
151 |             score = 0
152 |             sum_max_q = 0.0
153 |
154 |             time.sleep(1.0)
155 |
156 |             while True:
157 |                 local_step += 1
158 |
159 |                 q_values = self.model.predict(state)
160 |                 sum_max_q += float(numpy.max(q_values))
161 |
162 |                 action = int(self.get_action(state))
163 |                 next_state, reward, done = self.step(action)
164 |                 score += reward
165 |
166 |                 msg = Float32MultiArray()
167 |                 msg.data = [float(action), float(score), float(reward)]
168 |                 self.action_pub.publish(msg)
169 |
170 |                 if self.train_mode:
171 |                     self.append_sample((state, action, reward, next_state, done))
172 |                     self.train_model(done)
173 |
174 |                 state = next_state
175 |
176 |                 if done:
177 |                     avg_max_q = sum_max_q / local_step if local_step > 0 else 0.0
178 |
179 |                     msg = Float32MultiArray()
180 |                     msg.data = [float(score), float(avg_max_q)]
181 |                     self.result_pub.publish(msg)
182 |
183 |                     if LOGGING:
184 |                         self.dqn_reward_metric.update_state(score)
185 |                         with self.dqn_reward_writer.as_default():
186 |                             tensorflow.summary.scalar(
187 |                                 'dqn_reward', self.dqn_reward_metric.result(), step=episode_num
188 |                             )
189 |                         self.dqn_reward_metric.reset_states()
190 |
191 |                     print(
192 |                         'Episode:', episode,
193 |                         'score:', score,
194 |                         'memory length:', len(self.replay_memory),
195 |                         'epsilon:', self.epsilon)
196 |
197 |                     param_keys = ['epsilon', 'step_counter']
198 |                     param_values = [self.epsilon, self.step_counter]
199 |                     param_dictionary = dict(zip(param_keys, param_values))
200 |                     break
201 |
202 |                 time.sleep(0.01)
203 |
204 |             if self.train_mode:
205 |                 if episode % 100 == 0:
206 |                     self.model_path = os.path.join(
207 |                         self.model_dir_path,
208 |                         'stage' + str(self.stage) + '_episode' + str(episode) + '.h5')
209 |                     self.model.save(self.model_path)
210 |                     with open(
211 |                         os.path.join(
212 |                             self.model_dir_path,
213 |                             'stage' + str(self.stage) + '_episode' + str(episode) + '.json'
214 |                         ),
215 |                         'w'
216 |                     ) as outfile:
217 |                         json.dump(param_dictionary, outfile)
218 |
219 |     def env_make(self):
220 |         while not self.make_environment_client.wait_for_service(timeout_sec=1.0):
221 |             self.get_logger().warn(
222 |                 'Environment make client failed to connect to the server, try again ...'
223 |             )
224 |
225 |         self.make_environment_client.call_async(Empty.Request())
226 |
227 |     def reset_environment(self):
228 |         while not self.reset_environment_client.wait_for_service(timeout_sec=1.0):
229 |             self.get_logger().warn(
230 |                 'Reset environment client failed to connect to the server, try again ...'
231 |             )
232 |
233 |         future = self.reset_environment_client.call_async(Dqn.Request())
234 |
235 |         rclpy.spin_until_future_complete(self, future)
236 |         if future.result() is not None:
237 |             state = future.result().state
238 |             state = numpy.reshape(numpy.asarray(state), [1, self.state_size])
239 |         else:
240 |             self.get_logger().error(
241 |                 'Exception while calling service: {0}'.format(future.exception()))
242 |
243 |         return state
244 |
245 |     def get_action(self, state):
246 |         if self.train_mode:
247 |             self.step_counter += 1
248 |             self.epsilon = self.epsilon_min + (1.0 - self.epsilon_min) * math.exp(
249 |                 -1.0 * self.step_counter / self.epsilon_decay)
250 |             lucky = random.random()
251 |             if lucky > (1 - self.epsilon):
252 |                 result = random.randint(0, self.action_size - 1)
253 |             else:
254 |                 result = numpy.argmax(self.model.predict(state))
255 |         else:
256 |             result = numpy.argmax(self.model.predict(state))
257 |
258 |         return result
259 |
260 |     def step(self, action):
261 |         req = Dqn.Request()
262 |         req.action = action
263 |
264 |         while not self.rl_agent_interface_client.wait_for_service(timeout_sec=1.0):
265 |             self.get_logger().info('rl_agent interface service not available, waiting again...')
266 |
267 |         future = self.rl_agent_interface_client.call_async(req)
268 |
269 |         rclpy.spin_until_future_complete(self, future)
270 |
271 |         if future.result() is not None:
272 |             next_state = future.result().state
273 |             next_state = numpy.reshape(numpy.asarray(next_state), [1, self.state_size])
274 |             reward = future.result().reward
275 |             done = future.result().done
276 |         else:
277 |             self.get_logger().error(
278 |                 'Exception while calling service: {0}'.format(future.exception()))
279 |
280 |         return next_state, reward, done
281 |
282 |     def create_qnetwork(self):
283 |         model = Sequential()
284 |         model.add(Input(shape=(self.state_size,)))
285 |         model.add(Dense(512, activation='relu'))
286 |         model.add(Dense(256, activation='relu'))
287 |         model.add(Dense(128, activation='relu'))
288 |         model.add(Dense(self.action_size, activation='linear'))
289 |         model.compile(loss=MeanSquaredError(), optimizer=Adam(learning_rate=self.learning_rate))
290 |         model.summary()
291 |
292 |         return model
293 |
294 |     def update_target_model(self):
295 |         self.target_model.set_weights(self.model.get_weights())
296 |         self.target_update_after_counter = 0
297 |         print('*Target model updated*')
298 |
299 |     def append_sample(self, transition):
300 |         self.replay_memory.append(transition)
301 |
302 |     def train_model(self, terminal):
303 |         if len(self.replay_memory) < self.min_replay_memory_size:
304 |             return
305 |         data_in_mini_batch = random.sample(self.replay_memory, self.batch_size)
306 |
307 |         current_states = numpy.array([transition[0] for transition in data_in_mini_batch])
308 |         current_states = current_states.squeeze()
309 |         current_qvalues_list = self.model.predict(current_states)
310 |
311 |         next_states = numpy.array([transition[3] for transition in data_in_mini_batch])
312 |         next_states = next_states.squeeze()
313 |         next_qvalues_list = self.target_model.predict(next_states)
314 |
315 |         x_train = []
316 |         y_train = []
317 |
318 |         for index, (current_state, action, reward, _, done) in enumerate(data_in_mini_batch):
319 |             current_q_values = current_qvalues_list[index]
320 |
321 |             if not done:
322 |                 future_reward = numpy.max(next_qvalues_list[index])
323 |                 desired_q = reward + self.discount_factor * future_reward
324 |             else:
325 |                 desired_q = reward
326 |
327 |             current_q_values[action] = desired_q
328 |             x_train.append(current_state)
329 |             y_train.append(current_q_values)
330 |
331 |         x_train = numpy.array(x_train)
332 |         y_train = numpy.array(y_train)
333 |         x_train = numpy.reshape(x_train, [len(data_in_mini_batch), self.state_size])
334 |         y_train = numpy.reshape(y_train, [len(data_in_mini_batch), self.action_size])
335 |
336 |         self.model.fit(
337 |             tensorflow.convert_to_tensor(x_train, tensorflow.float32),
338 |             tensorflow.convert_to_tensor(y_train, tensorflow.float32),
339 |             batch_size=self.batch_size, verbose=0
340 |         )
341 |         self.target_update_after_counter += 1
342 |
343 |         if self.target_update_after_counter > self.update_target_after and terminal:
344 |             self.update_target_model()
345 |
346 |
347 | def main(args=None):
348 |     if args is None:
349 |         args = sys.argv
350 |     stage_num = args[1] if len(args) > 1 else '1'
351 |     max_training_episodes = args[2] if len(args) > 2 else '1000'
352 |     rclpy.init(args=args)
353 |
354 |     dqn_agent = DQNAgent(stage_num, max_training_episodes)
355 |     rclpy.spin(dqn_agent)
356 |
357 |     dqn_agent.destroy_node()
358 |     rclpy.shutdown()
359 |
360 |
361 | if __name__ == '__main__':
362 |     main()
363 |
--------------------------------------------------------------------------------
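
The exploration schedule in `get_action()` above is driven purely by the global step counter: epsilon decays exponentially from 1.0 toward `epsilon_min = 0.05` with time constant `epsilon_decay = 6000 * stage`. A standalone check of the same formula, with the constants copied from `DQNAgent.__init__` for stage 1:

```python
import math

epsilon_min = 0.05
epsilon_decay = 6000 * 1  # stage 1, as in DQNAgent.__init__

for step in (0, 6000, 12000, 30000):
    epsilon = epsilon_min + (1.0 - epsilon_min) * math.exp(-1.0 * step / epsilon_decay)
    print(f'step {step:6d}: epsilon = {epsilon:.3f}')

# step      0: epsilon = 1.000
# step   6000: epsilon = 0.399   (one time constant)
# step  12000: epsilon = 0.179
# step  30000: epsilon = 0.056   (essentially epsilon_min)
```

So with the default 6000-step time constant, the agent still takes random actions roughly 40% of the time after 6000 steps and is almost fully greedy by 30000 steps; higher stages stretch the schedule proportionally.
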
/turtlebot3_dqn/turtlebot3_dqn/dqn_environment.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #################################################################################
3 | # Copyright 2019 ROBOTIS CO., LTD.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #################################################################################
17 | #
18 | # Authors: Ryan Shim, Gilbert, ChanHyeong Lee
19 |
20 | import math
21 | import os
22 |
23 | from geometry_msgs.msg import Twist
24 | from geometry_msgs.msg import TwistStamped
25 | from nav_msgs.msg import Odometry
26 | import numpy
27 | import rclpy
28 | from rclpy.callback_groups import MutuallyExclusiveCallbackGroup
29 | from rclpy.node import Node
30 | from rclpy.qos import qos_profile_sensor_data
31 | from rclpy.qos import QoSProfile
32 | from sensor_msgs.msg import LaserScan
33 | from std_srvs.srv import Empty
34 |
35 | from turtlebot3_msgs.srv import Dqn
36 | from turtlebot3_msgs.srv import Goal
37 |
38 |
39 | ROS_DISTRO = os.environ.get('ROS_DISTRO')
40 |
41 |
42 | class RLEnvironment(Node):
43 |
44 |     def __init__(self):
45 |         super().__init__('rl_environment')
46 |         self.goal_pose_x = 0.0
47 |         self.goal_pose_y = 0.0
48 |         self.robot_pose_x = 0.0
49 |         self.robot_pose_y = 0.0
50 |
51 |         self.action_size = 5
52 |         self.max_step = 800
53 |
54 |         self.done = False
55 |         self.fail = False
56 |         self.succeed = False
57 |
58 |         self.goal_angle = 0.0
59 |         self.goal_distance = 1.0
60 |         self.init_goal_distance = 0.5
61 |         self.scan_ranges = []
62 |         self.front_ranges = []
63 |         self.min_obstacle_distance = 10.0
64 |         self.is_front_min_actual_front = False
65 |
66 |         self.local_step = 0
67 |         self.stop_cmd_vel_timer = None
68 |         self.angular_vel = [1.5, 0.75, 0.0, -0.75, -1.5]
69 |
70 |         qos = QoSProfile(depth=10)
71 |
72 |         if ROS_DISTRO == 'humble':
73 |             self.cmd_vel_pub = self.create_publisher(Twist, 'cmd_vel', qos)
74 |         else:
75 |             self.cmd_vel_pub = self.create_publisher(TwistStamped, 'cmd_vel', qos)
76 |
77 |         self.odom_sub = self.create_subscription(
78 |             Odometry,
79 |             'odom',
80 |             self.odom_sub_callback,
81 |             qos
82 |         )
83 |         self.scan_sub = self.create_subscription(
84 |             LaserScan,
85 |             'scan',
86 |             self.scan_sub_callback,
87 |             qos_profile_sensor_data
88 |         )
89 |
90 |         self.clients_callback_group = MutuallyExclusiveCallbackGroup()
91 |         self.task_succeed_client = self.create_client(
92 |             Goal,
93 |             'task_succeed',
94 |             callback_group=self.clients_callback_group
95 |         )
96 |         self.task_failed_client = self.create_client(
97 |             Goal,
98 |             'task_failed',
99 |             callback_group=self.clients_callback_group
100 |         )
101 |         self.initialize_environment_client = self.create_client(
102 |             Goal,
103 |             'initialize_env',
104 |             callback_group=self.clients_callback_group
105 |         )
106 |
107 |         self.rl_agent_interface_service = self.create_service(
108 |             Dqn,
109 |             'rl_agent_interface',
110 |             self.rl_agent_interface_callback
111 |         )
112 |         self.make_environment_service = self.create_service(
113 |             Empty,
114 |             'make_environment',
115 |             self.make_environment_callback
116 |         )
117 |         self.reset_environment_service = self.create_service(
118 |             Dqn,
119 |             'reset_environment',
120 |             self.reset_environment_callback
121 |         )
122 |
123 |     def make_environment_callback(self, request, response):
124 |         self.get_logger().info('Make environment called')
125 |         while not self.initialize_environment_client.wait_for_service(timeout_sec=1.0):
126 |             self.get_logger().warn(
127 |                 'service for initialize the environment is not available, waiting ...'
128 |             )
129 |         future = self.initialize_environment_client.call_async(Goal.Request())
130 |         rclpy.spin_until_future_complete(self, future)
131 |         response_goal = future.result()
132 |         if not response_goal.success:
133 |             self.get_logger().error('initialize environment request failed')
134 |         else:
135 |             self.goal_pose_x = response_goal.pose_x
136 |             self.goal_pose_y = response_goal.pose_y
137 |             self.get_logger().info(
138 |                 'goal initialized at [%f, %f]' % (self.goal_pose_x, self.goal_pose_y)
139 |             )
140 |
141 |         return response
142 |
143 |     def reset_environment_callback(self, request, response):
144 |         state = self.calculate_state()
145 |         self.init_goal_distance = state[0]
146 |         self.prev_goal_distance = self.init_goal_distance
147 |         response.state = state
148 |
149 |         return response
150 |
151 |     def call_task_succeed(self):
152 |         while not self.task_succeed_client.wait_for_service(timeout_sec=1.0):
153 |             self.get_logger().warn('service for task succeed is not available, waiting ...')
154 |         future = self.task_succeed_client.call_async(Goal.Request())
155 |         rclpy.spin_until_future_complete(self, future)
156 |         if future.result() is not None:
157 |             response = future.result()
158 |             self.goal_pose_x = response.pose_x
159 |             self.goal_pose_y = response.pose_y
160 |             self.get_logger().info('service for task succeed finished')
161 |         else:
162 |             self.get_logger().error('task succeed service call failed')
163 |
164 |     def call_task_failed(self):
165 |         while not self.task_failed_client.wait_for_service(timeout_sec=1.0):
166 |             self.get_logger().warn('service for task failed is not available, waiting ...')
167 |         future = self.task_failed_client.call_async(Goal.Request())
168 |         rclpy.spin_until_future_complete(self, future)
169 |         if future.result() is not None:
170 |             response = future.result()
171 |             self.goal_pose_x = response.pose_x
172 |             self.goal_pose_y = response.pose_y
173 |             self.get_logger().info('service for task failed finished')
174 |         else:
175 |             self.get_logger().error('task failed service call failed')
176 |
177 |     def scan_sub_callback(self, scan):
178 |         self.scan_ranges = []
179 |         self.front_ranges = []
180 |         self.front_angles = []
181 |
182 |         num_of_lidar_rays = len(scan.ranges)
183 |         angle_min = scan.angle_min
184 |         angle_increment = scan.angle_increment
185 |
186 |         self.front_distance = scan.ranges[0]
187 |
188 |         for i in range(num_of_lidar_rays):
189 |             angle = angle_min + i * angle_increment
190 |             distance = scan.ranges[i]
191 |
192 |             if distance == float('Inf'):
193 |                 distance = 3.5
194 |             elif numpy.isnan(distance):
195 |                 distance = 0.0
196 |
197 |             self.scan_ranges.append(distance)
198 |
199 |             if (0 <= angle <= math.pi/2) or (3*math.pi/2 <= angle <= 2*math.pi):
200 |                 self.front_ranges.append(distance)
201 |                 self.front_angles.append(angle)
202 |
203 |         self.min_obstacle_distance = min(self.scan_ranges)
204 |         self.front_min_obstacle_distance = min(self.front_ranges) if self.front_ranges else 10.0
205 |
206 |     def odom_sub_callback(self, msg):
207 |         self.robot_pose_x = msg.pose.pose.position.x
208 |         self.robot_pose_y = msg.pose.pose.position.y
209 |         _, _, self.robot_pose_theta = self.euler_from_quaternion(msg.pose.pose.orientation)
210 |
211 |         goal_distance = math.sqrt(
212 |             (self.goal_pose_x - self.robot_pose_x) ** 2
213 |             + (self.goal_pose_y - self.robot_pose_y) ** 2)
214 |         path_theta = math.atan2(
215 |             self.goal_pose_y - self.robot_pose_y,
216 |             self.goal_pose_x - self.robot_pose_x)
217 |
218 |         goal_angle = path_theta - self.robot_pose_theta
219 |         if goal_angle > math.pi:
220 |             goal_angle -= 2 * math.pi
221 |
222 |         elif goal_angle < -math.pi:
223 |             goal_angle += 2 * math.pi
224 |
225 |         self.goal_distance = goal_distance
226 |         self.goal_angle = goal_angle
227 |
228 |     def calculate_state(self):
229 |         state = []
230 |         state.append(float(self.goal_distance))
231 |         state.append(float(self.goal_angle))
232 |         for var in self.front_ranges:
233 |             state.append(float(var))
234 |         self.local_step += 1
235 |
236 |         if self.goal_distance < 0.20:
237 |             self.get_logger().info('Goal Reached')
238 |             self.succeed = True
239 |             self.done = True
240 |             if ROS_DISTRO == 'humble':
241 |                 self.cmd_vel_pub.publish(Twist())
242 |             else:
243 |                 self.cmd_vel_pub.publish(TwistStamped())
244 |             self.local_step = 0
245 |             self.call_task_succeed()
246 |
247 |         if self.min_obstacle_distance < 0.15:
248 |             self.get_logger().info('Collision happened')
249 |             self.fail = True
250 |             self.done = True
251 |             if ROS_DISTRO == 'humble':
252 |                 self.cmd_vel_pub.publish(Twist())
253 |             else:
254 |                 self.cmd_vel_pub.publish(TwistStamped())
255 |             self.local_step = 0
256 |             self.call_task_failed()
257 |
258 |         if self.local_step == self.max_step:
259 |             self.get_logger().info('Time out!')
260 |             self.fail = True
261 |             self.done = True
262 |             if ROS_DISTRO == 'humble':
263 |                 self.cmd_vel_pub.publish(Twist())
264 |             else:
265 |                 self.cmd_vel_pub.publish(TwistStamped())
266 |             self.local_step = 0
267 |             self.call_task_failed()
268 |
269 |         return state
270 |
271 |     def compute_directional_weights(self, relative_angles, max_weight=10.0):
272 |         power = 6
273 |         raw_weights = (numpy.cos(relative_angles))**power + 0.1
274 |         scaled_weights = raw_weights * (max_weight / numpy.max(raw_weights))
275 |         normalized_weights = scaled_weights / numpy.sum(scaled_weights)
276 |         return normalized_weights
277 |
278 |     def compute_weighted_obstacle_reward(self):
279 |         if not self.front_ranges or not self.front_angles:
280 |             return 0.0
281 |
282 |         front_ranges = numpy.array(self.front_ranges)
283 |         front_angles = numpy.array(self.front_angles)
284 |
285 |         valid_mask = front_ranges <= 0.5
286 |         if not numpy.any(valid_mask):
287 |             return 0.0
288 |
289 |         front_ranges = front_ranges[valid_mask]
290 |         front_angles = front_angles[valid_mask]
291 |
292 |         relative_angles = numpy.unwrap(front_angles)
293 |         relative_angles[relative_angles > numpy.pi] -= 2 * numpy.pi
294 |
295 |         weights = self.compute_directional_weights(relative_angles, max_weight=10.0)
296 |
297 |         safe_dists = numpy.clip(front_ranges - 0.25, 1e-2, 3.5)
298 |         decay = numpy.exp(-3.0 * safe_dists)
299 |
300 |         weighted_decay = numpy.dot(weights, decay)
301 |
302 |         reward = - (1.0 + 4.0 * weighted_decay)
303 |
304 |         return reward
305 |
306 |     def calculate_reward(self):
307 |         yaw_reward = 1 - (2 * abs(self.goal_angle) / math.pi)
308 |         obstacle_reward = self.compute_weighted_obstacle_reward()
309 |
310 |         print('directional_reward: %f, obstacle_reward: %f' % (yaw_reward, obstacle_reward))
311 |         reward = yaw_reward + obstacle_reward
312 |
313 |         if self.succeed:
314 |             reward = 100.0
315 |         elif self.fail:
316 |             reward = -50.0
317 |
318 |         return reward
319 |
320 |     def rl_agent_interface_callback(self, request, response):
321 |         action = request.action
322 |         if ROS_DISTRO == 'humble':
323 |             msg = Twist()
324 |             msg.linear.x = 0.2
325 |             msg.angular.z = self.angular_vel[action]
326 |         else:
327 |             msg = TwistStamped()
328 |             msg.twist.linear.x = 0.2
329 |             msg.twist.angular.z = self.angular_vel[action]
330 |
331 |         self.cmd_vel_pub.publish(msg)
332 |         if self.stop_cmd_vel_timer is None:
333 |             self.prev_goal_distance = self.init_goal_distance
334 |             self.stop_cmd_vel_timer = self.create_timer(0.8, self.timer_callback)
335 |         else:
336 |             self.destroy_timer(self.stop_cmd_vel_timer)
337 |             self.stop_cmd_vel_timer = self.create_timer(0.8, self.timer_callback)
338 |
339 |         response.state = self.calculate_state()
340 |         response.reward = self.calculate_reward()
341 |         response.done = self.done
342 |
343 |         if self.done is True:
344 |             self.done = False
345 |             self.succeed = False
346 |             self.fail = False
347 |
348 |         return response
349 |
350 |     def timer_callback(self):
351 |         self.get_logger().info('Stop called')
352 |         if ROS_DISTRO == 'humble':
353 |             self.cmd_vel_pub.publish(Twist())
354 |         else:
355 |             self.cmd_vel_pub.publish(TwistStamped())
356 |         self.destroy_timer(self.stop_cmd_vel_timer)
357 |
358 |     def euler_from_quaternion(self, quat):
359 |         x = quat.x
360 |         y = quat.y
361 |         z = quat.z
362 |         w = quat.w
363 |
364 |         sinr_cosp = 2 * (w * x + y * z)
365 |         cosr_cosp = 1 - 2 * (x * x + y * y)
366 |         roll = numpy.arctan2(sinr_cosp, cosr_cosp)
367 |
368 |         sinp = 2 * (w * y - z * x)
369 |         pitch = numpy.arcsin(sinp)
370 |
371 |         siny_cosp = 2 * (w * z + x * y)
372 |         cosy_cosp = 1 - 2 * (y * y + z * z)
373 |         yaw = numpy.arctan2(siny_cosp, cosy_cosp)
374 |
375 |         return roll, pitch, yaw
376 |
377 |
378 | def main(args=None):
379 |     rclpy.init(args=args)
380 |     rl_environment = RLEnvironment()
381 |     try:
382 |         while rclpy.ok():
383 |             rclpy.spin_once(rl_environment, timeout_sec=0.1)
384 |     except KeyboardInterrupt:
385 |         pass
386 |     finally:
387 |         rl_environment.destroy_node()
388 |         rclpy.shutdown()
389 |
390 |
391 | if __name__ == '__main__':
392 |     main()
393 |
--------------------------------------------------------------------------------
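A minimal standalone sketch of the obstacle term in calculate_reward() above, using made-up scan values. It skips the angle unwrapping and drops max_weight, which cancels as soon as the weights are normalized, so only the cos^6 + 0.1 profile matters:

```python
import numpy


def weighted_obstacle_reward(ranges, angles):
    # Mirror of compute_weighted_obstacle_reward(): penalize nearby obstacles,
    # weighting beams that point straight ahead far more than side beams.
    ranges = numpy.asarray(ranges, dtype=float)
    angles = numpy.asarray(angles, dtype=float)
    mask = ranges <= 0.5  # only obstacles within 0.5 m contribute
    if not numpy.any(mask):
        return 0.0
    ranges, angles = ranges[mask], angles[mask]
    raw = numpy.cos(angles) ** 6 + 0.1  # directional profile
    weights = raw / numpy.sum(raw)  # the max_weight scaling cancels here
    safe = numpy.clip(ranges - 0.25, 1e-2, 3.5)  # distance beyond the robot body
    return -(1.0 + 4.0 * float(numpy.dot(weights, numpy.exp(-3.0 * safe))))


# Made-up readings: an obstacle 0.3 m dead ahead dominates the penalty.
print(weighted_obstacle_reward([0.3, 0.45, 1.2], [0.0, 0.4, -0.2]))
```

--------------------------------------------------------------------------------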
/turtlebot3_dqn/turtlebot3_dqn/dqn_gazebo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #################################################################################
3 | # Copyright 2019 ROBOTIS CO., LTD.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #################################################################################
17 | #
18 | # Authors: Ryan Shim, Gilbert, ChanHyeong Lee
19 |
20 | import os
21 | import random
22 | import subprocess
23 | import sys
24 | import time
25 |
26 | from ament_index_python.packages import get_package_share_directory
27 | import rclpy
28 | from rclpy.callback_groups import MutuallyExclusiveCallbackGroup
29 | from rclpy.node import Node
30 | from std_srvs.srv import Empty
31 |
32 | from turtlebot3_msgs.srv import Goal
33 |
34 |
35 | ROS_DISTRO = os.environ.get('ROS_DISTRO')
36 | if ROS_DISTRO == 'humble':
37 | from gazebo_msgs.srv import DeleteEntity
38 | from gazebo_msgs.srv import SpawnEntity
39 | from geometry_msgs.msg import Pose
40 |
41 |
42 | class GazeboInterface(Node):
43 |
44 | def __init__(self, stage_num):
45 | super().__init__('gazebo_interface')
46 | self.stage = int(stage_num)
47 |
48 | self.entity_name = 'goal_box'
49 | self.entity_pose_x = 0.5
50 | self.entity_pose_y = 0.0
51 |
52 | if ROS_DISTRO == 'humble':
53 | self.entity = None
54 | self.open_entity()
55 | self.delete_entity_client = self.create_client(DeleteEntity, 'delete_entity')
56 | self.spawn_entity_client = self.create_client(SpawnEntity, 'spawn_entity')
57 | self.reset_simulation_client = self.create_client(Empty, 'reset_simulation')
58 |
59 | self.callback_group = MutuallyExclusiveCallbackGroup()
60 | self.initialize_env_service = self.create_service(
61 | Goal,
62 | 'initialize_env',
63 | self.initialize_env_callback,
64 | callback_group=self.callback_group
65 | )
66 | self.task_succeed_service = self.create_service(
67 | Goal,
68 | 'task_succeed',
69 | self.task_succeed_callback,
70 | callback_group=self.callback_group
71 | )
72 | self.task_failed_service = self.create_service(
73 | Goal,
74 | 'task_failed',
75 | self.task_failed_callback,
76 | callback_group=self.callback_group
77 | )
78 |
79 | def open_entity(self):
80 | try:
81 | package_share = get_package_share_directory('turtlebot3_gazebo')
82 | model_path = os.path.join(
83 | package_share, 'models', 'turtlebot3_dqn_world', 'goal_box', 'model.sdf'
84 | )
85 | with open(model_path, 'r') as f:
86 | self.entity = f.read()
87 | self.get_logger().info('Loaded entity from: ' + model_path)
88 | except Exception as e:
89 | self.get_logger().error('Failed to load entity file: {}'.format(e))
90 |             raise
91 |
92 | def spawn_entity(self):
93 | if ROS_DISTRO == 'humble':
94 | entity_pose = Pose()
95 | entity_pose.position.x = self.entity_pose_x
96 | entity_pose.position.y = self.entity_pose_y
97 |
98 | spawn_req = SpawnEntity.Request()
99 | spawn_req.name = self.entity_name
100 | spawn_req.xml = self.entity
101 | spawn_req.initial_pose = entity_pose
102 |
103 | while not self.spawn_entity_client.wait_for_service(timeout_sec=1.0):
104 | self.get_logger().warn('service for spawn_entity is not available, waiting ...')
105 | future = self.spawn_entity_client.call_async(spawn_req)
106 | rclpy.spin_until_future_complete(self, future)
107 |             print(f'Spawn Goal at ({self.entity_pose_x}, {self.entity_pose_y}, 0.0)')
108 | else:
109 | service_name = '/world/dqn/create'
110 | package_share = get_package_share_directory('turtlebot3_gazebo')
111 | model_path = os.path.join(
112 | package_share, 'models', 'turtlebot3_dqn_world', 'goal_box', 'model.sdf'
113 | )
114 | req = (
115 | f'sdf_filename: "{model_path}", '
116 | f'name: "{self.entity_name}", '
117 | f'pose: {{ position: {{ '
118 | f'x: {self.entity_pose_x}, '
119 | f'y: {self.entity_pose_y}, '
120 | f'z: 0.0 }} }}'
121 | )
122 | cmd = [
123 | 'gz', 'service',
124 | '-s', service_name,
125 | '--reqtype', 'gz.msgs.EntityFactory',
126 | '--reptype', 'gz.msgs.Boolean',
127 | '--timeout', '1000',
128 | '--req', req
129 | ]
130 | try:
131 | subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL)
132 |                 print(f'Spawn Goal at ({self.entity_pose_x}, {self.entity_pose_y}, 0.0)')
133 | except subprocess.CalledProcessError:
134 | pass
135 |
136 | def delete_entity(self):
137 | if ROS_DISTRO == 'humble':
138 | delete_req = DeleteEntity.Request()
139 | delete_req.name = self.entity_name
140 |
141 | while not self.delete_entity_client.wait_for_service(timeout_sec=1.0):
142 | self.get_logger().warn('service for delete_entity is not available, waiting ...')
143 | future = self.delete_entity_client.call_async(delete_req)
144 | rclpy.spin_until_future_complete(self, future)
145 | print('Delete Goal')
146 | else:
147 | service_name = '/world/dqn/remove'
148 | req = f'name: "{self.entity_name}", type: 2'
149 | cmd = [
150 | 'gz', 'service',
151 | '-s', service_name,
152 | '--reqtype', 'gz.msgs.Entity',
153 | '--reptype', 'gz.msgs.Boolean',
154 | '--timeout', '1000',
155 | '--req', req
156 | ]
157 | try:
158 | subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL)
159 | print('Delete Goal')
160 | except subprocess.CalledProcessError:
161 | pass
162 |
163 | def reset_simulation(self):
164 | reset_req = Empty.Request()
165 |
166 | while not self.reset_simulation_client.wait_for_service(timeout_sec=1.0):
167 | self.get_logger().warn('service for reset_simulation is not available, waiting ...')
168 |
169 | self.reset_simulation_client.call_async(reset_req)
170 |
171 | def reset_burger(self):
172 | service_name_delete = '/world/dqn/remove'
173 | req_delete = 'name: "burger", type: 2'
174 | cmd_delete = [
175 | 'gz', 'service',
176 | '-s', service_name_delete,
177 | '--reqtype', 'gz.msgs.Entity',
178 | '--reptype', 'gz.msgs.Boolean',
179 | '--timeout', '1000',
180 | '--req', req_delete
181 | ]
182 | try:
183 | subprocess.run(cmd_delete, check=True, stdout=subprocess.DEVNULL)
184 | print('Delete Burger')
185 | except subprocess.CalledProcessError:
186 | pass
187 | time.sleep(0.2)
188 | service_name_spawn = '/world/dqn/create'
189 | package_share = get_package_share_directory('turtlebot3_gazebo')
190 | model_path = os.path.join(package_share, 'models', 'turtlebot3_burger', 'model.sdf')
191 | req_spawn = (
192 | f'sdf_filename: "{model_path}", '
193 | f'name: "burger", '
194 | f'pose: {{ position: {{ x: 0.0, y: 0.0, z: 0.0 }} }}'
195 | )
196 | cmd_spawn = [
197 | 'gz', 'service',
198 | '-s', service_name_spawn,
199 | '--reqtype', 'gz.msgs.EntityFactory',
200 | '--reptype', 'gz.msgs.Boolean',
201 | '--timeout', '1000',
202 | '--req', req_spawn
203 | ]
204 | try:
205 | subprocess.run(cmd_spawn, check=True, stdout=subprocess.DEVNULL)
206 | print('Spawn Burger')
207 | except subprocess.CalledProcessError:
208 | pass
209 |
210 | def task_succeed_callback(self, request, response):
211 | self.delete_entity()
212 | time.sleep(0.2)
213 | self.generate_goal_pose()
214 | time.sleep(0.2)
215 | self.spawn_entity()
216 | response.pose_x = self.entity_pose_x
217 | response.pose_y = self.entity_pose_y
218 | response.success = True
219 | return response
220 |
221 | def task_failed_callback(self, request, response):
222 | self.delete_entity()
223 | time.sleep(0.2)
224 | if ROS_DISTRO == 'humble':
225 | self.reset_simulation()
226 | else:
227 | self.reset_burger()
228 | time.sleep(0.2)
229 | self.generate_goal_pose()
230 | time.sleep(0.2)
231 | self.spawn_entity()
232 | response.pose_x = self.entity_pose_x
233 | response.pose_y = self.entity_pose_y
234 | response.success = True
235 | return response
236 |
237 | def initialize_env_callback(self, request, response):
238 | self.delete_entity()
239 | time.sleep(0.2)
240 | if ROS_DISTRO == 'humble':
241 | self.reset_simulation()
242 | else:
243 | self.reset_burger()
244 | time.sleep(0.2)
245 | self.spawn_entity()
246 | response.pose_x = self.entity_pose_x
247 | response.pose_y = self.entity_pose_y
248 | response.success = True
249 | return response
250 |
251 | def generate_goal_pose(self):
252 | if self.stage != 4:
253 |             self.entity_pose_x = random.randrange(-21, 21) / 10  # [-2.1, 2.0] in 0.1 m steps
254 |             self.entity_pose_y = random.randrange(-21, 21) / 10
255 | else:
256 | goal_pose_list = [
257 | [1.0, 0.0], [2.0, -1.5], [0.0, -2.0], [2.0, 1.5], [0.5, 2.0], [-1.5, 2.1],
258 | [-2.0, 0.5], [-2.0, -0.5], [-1.5, -2.0], [-0.5, -1.0], [2.0, -0.5], [-1.0, -1.0]
259 | ]
260 | rand_index = random.randint(0, len(goal_pose_list) - 1)
261 | self.entity_pose_x = goal_pose_list[rand_index][0]
262 | self.entity_pose_y = goal_pose_list[rand_index][1]
263 |
264 |
265 | def main(args=None):
266 |     rclpy.init(args=args)
267 | stage_num = sys.argv[1] if len(sys.argv) > 1 else '1'
268 | gazebo_interface = GazeboInterface(stage_num)
269 | try:
270 | while rclpy.ok():
271 | rclpy.spin_once(gazebo_interface, timeout_sec=0.1)
272 | except KeyboardInterrupt:
273 | pass
274 | finally:
275 | gazebo_interface.destroy_node()
276 | rclpy.shutdown()
277 |
278 |
279 | if __name__ == '__main__':
280 | main()
281 |
--------------------------------------------------------------------------------
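A hedged sketch (not part of the package) of a throwaway client for the services defined above; the node name is made up, and the Goal request is sent empty because the callbacks above never read it:

```python
import rclpy
from rclpy.node import Node

from turtlebot3_msgs.srv import Goal


def request_new_episode():
    # Ask GazeboInterface to reset the world and respawn the goal box,
    # then print where the new goal ended up.
    rclpy.init()
    node = Node('gazebo_interface_probe')  # hypothetical node name
    client = node.create_client(Goal, 'initialize_env')
    while not client.wait_for_service(timeout_sec=1.0):
        node.get_logger().warn('initialize_env not available, waiting ...')
    future = client.call_async(Goal.Request())
    rclpy.spin_until_future_complete(node, future)
    result = future.result()
    print(f'goal at ({result.pose_x}, {result.pose_y}), success={result.success}')
    node.destroy_node()
    rclpy.shutdown()


if __name__ == '__main__':
    request_new_episode()
```

--------------------------------------------------------------------------------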
/turtlebot3_dqn/turtlebot3_dqn/dqn_test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #################################################################################
3 | # Copyright 2019 ROBOTIS CO., LTD.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #################################################################################
17 | #
18 | # Authors: Ryan Shim, Gilbert, ChanHyeong Lee
19 |
20 | import collections
21 | import os
22 | import sys
23 | import time
24 |
25 | import numpy
26 | import rclpy
27 | from rclpy.node import Node
28 | from tensorflow.keras.layers import Dense
29 | from tensorflow.keras.losses import MeanSquaredError
30 | from tensorflow.keras.models import load_model
31 | from tensorflow.keras.models import Sequential
32 | from tensorflow.keras.optimizers import RMSprop
33 |
34 | from turtlebot3_msgs.srv import Dqn
35 |
36 |
37 | class DQNTest(Node):
38 |
39 | def __init__(self, stage, load_episode):
40 | super().__init__('dqn_test')
41 |
42 | self.stage = int(stage)
43 | self.load_episode = int(load_episode)
44 |
45 | self.state_size = 26
46 | self.action_size = 5
47 |
48 |         self.memory = collections.deque(maxlen=1000000)  # unused during testing
49 |
50 | self.model = self.build_model()
51 | model_path = os.path.join(
52 | os.path.dirname(os.path.dirname(os.path.realpath(__file__))),
53 | 'saved_model',
54 | f'stage{self.stage}_episode{self.load_episode}.h5'
55 | )
56 |
57 | loaded_model = load_model(
58 | model_path, compile=False, custom_objects={'mse': MeanSquaredError()}
59 | )
60 | self.model.set_weights(loaded_model.get_weights())
61 |
62 | self.rl_agent_interface_client = self.create_client(Dqn, 'rl_agent_interface')
63 |
64 |         self.run_test()  # blocks here: the loop in run_test() never returns
65 |
66 | def build_model(self):
67 | model = Sequential()
68 | model.add(Dense(
69 | 512, input_shape=(self.state_size,),
70 | activation='relu',
71 | kernel_initializer='lecun_uniform'
72 | ))
73 | model.add(Dense(256, activation='relu', kernel_initializer='lecun_uniform'))
74 | model.add(Dense(128, activation='relu', kernel_initializer='lecun_uniform'))
75 | model.add(Dense(self.action_size, activation='linear', kernel_initializer='lecun_uniform'))
76 | model.compile(loss=MeanSquaredError(), optimizer=RMSprop(learning_rate=0.00025))
77 | return model
78 |
79 | def get_action(self, state):
80 | state = numpy.asarray(state)
81 | q_values = self.model.predict(state.reshape(1, -1), verbose=0)
82 | return int(numpy.argmax(q_values[0]))
83 |
84 | def run_test(self):
85 | while True:
86 | done = False
87 | init = True
88 | score = 0
89 | local_step = 0
90 | next_state = []
91 |
92 | time.sleep(1.0)
93 |
94 | while not done:
95 | local_step += 1
96 | action = 2 if local_step == 1 else self.get_action(next_state)
97 |
98 | req = Dqn.Request()
99 | req.action = action
100 | req.init = init
101 |
102 | while not self.rl_agent_interface_client.wait_for_service(timeout_sec=1.0):
103 | self.get_logger().warn(
104 | 'rl_agent interface service not available, waiting again...')
105 |
106 | future = self.rl_agent_interface_client.call_async(req)
107 | rclpy.spin_until_future_complete(self, future)
108 |
109 | if future.done() and future.result() is not None:
110 | next_state = future.result().state
111 | reward = future.result().reward
112 | done = future.result().done
113 | score += reward
114 | init = False
115 | else:
116 | self.get_logger().error(f'Service call failure: {future.exception()}')
117 |
118 | time.sleep(0.01)
119 |
120 |
121 | def main(args=None):
122 |     rclpy.init(args=args)
123 | stage = sys.argv[1] if len(sys.argv) > 1 else '1'
124 | load_episode = sys.argv[2] if len(sys.argv) > 2 else '600'
125 | node = DQNTest(stage, load_episode)
126 |
127 | try:
128 | rclpy.spin(node)
129 | except KeyboardInterrupt:
130 | pass
131 | finally:
132 | node.destroy_node()
133 | rclpy.shutdown()
134 |
135 |
136 | if __name__ == '__main__':
137 | main()
138 |
--------------------------------------------------------------------------------
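A hedged sketch of querying the trained network outside ROS, mirroring build_model() and get_action() above; the relative model path and the all-zero state are illustrative only:

```python
import numpy
from tensorflow.keras.losses import MeanSquaredError
from tensorflow.keras.models import load_model

# Load the checkpoint the way DQNTest does (compile=False, 'mse' mapped by hand).
model = load_model(
    'saved_model/stage1_episode600.h5',  # assumed path, relative to the package
    compile=False,
    custom_objects={'mse': MeanSquaredError()},
)

state = numpy.zeros(26)  # the state_size used by DQNTest; all-zero is a dummy input
q_values = model.predict(state.reshape(1, -1), verbose=0)
print('greedy action:', int(numpy.argmax(q_values[0])))
```

--------------------------------------------------------------------------------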
/turtlebot3_dqn/turtlebot3_dqn/result_graph.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #################################################################################
3 | # Copyright 2018 ROBOTIS CO., LTD.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #################################################################################
17 | #
18 | # Authors: Ryan Shim, Gilbert, ChanHyeong Lee
19 |
20 | import signal
21 | import sys
22 | import threading
23 |
24 | from PyQt5.QtCore import QTimer
25 | from PyQt5.QtWidgets import QApplication
26 | from PyQt5.QtWidgets import QMainWindow
27 | import pyqtgraph
28 | import rclpy
29 | from rclpy.node import Node
30 | from std_msgs.msg import Float32MultiArray
31 |
32 |
33 | class GraphSubscriber(Node):
34 |
35 | def __init__(self, window):
36 | super().__init__('graph')
37 |
38 | self.window = window
39 |
40 | self.subscription = self.create_subscription(
41 | Float32MultiArray,
42 | '/result',
43 | self.data_callback,
44 | 10
45 | )
46 |         self.subscription  # prevent unused variable warning
47 |
48 | def data_callback(self, msg):
49 | self.window.receive_data(msg)
50 |
51 |
52 | class Window(QMainWindow):
53 |
54 | def __init__(self):
55 |         super().__init__()
56 |
57 | self.setWindowTitle('Result')
58 | self.setGeometry(50, 50, 600, 650)
59 |
60 |         self.ep = []  # episode indices
61 |         self.data_list = []  # msg.data[0]: total reward per episode
62 |         self.rewards = []  # msg.data[1]: average max Q-value per episode
63 |         self.count = 1
64 |
65 | self.plot()
66 |
67 | self.ros_subscriber = GraphSubscriber(self)
68 | self.ros_thread = threading.Thread(
69 | target=rclpy.spin, args=(self.ros_subscriber,), daemon=True
70 | )
71 | self.ros_thread.start()
72 |
73 | def receive_data(self, msg):
74 | self.data_list.append(msg.data[0])
75 | self.ep.append(self.count)
76 | self.count += 1
77 | self.rewards.append(msg.data[1])
78 |
79 | def plot(self):
80 | self.qValuePlt = pyqtgraph.PlotWidget(self, title='Average max Q-value')
81 | self.qValuePlt.setGeometry(0, 320, 600, 300)
82 |
83 | self.rewardsPlt = pyqtgraph.PlotWidget(self, title='Total reward')
84 | self.rewardsPlt.setGeometry(0, 10, 600, 300)
85 |
86 | self.timer = QTimer()
87 |         self.timer.timeout.connect(self.update_plots)
88 | self.timer.start(200)
89 |
90 | self.show()
91 |
92 |     def update_plots(self):  # not named update() to avoid shadowing QWidget.update()
93 | self.rewardsPlt.showGrid(x=True, y=True)
94 | self.qValuePlt.showGrid(x=True, y=True)
95 |
96 | self.rewardsPlt.plot(self.ep, self.data_list, pen=(255, 0, 0), clear=True)
97 | self.qValuePlt.plot(self.ep, self.rewards, pen=(0, 255, 0), clear=True)
98 |
99 | def closeEvent(self, event):
100 | if self.ros_subscriber is not None:
101 | self.ros_subscriber.destroy_node()
102 | rclpy.shutdown()
103 | event.accept()
104 |
105 |
106 | def main():
107 | rclpy.init()
108 | app = QApplication(sys.argv)
109 | win = Window()
110 |
111 | def shutdown_handler(sig, frame):
112 | print('shutdown')
113 | if win.ros_subscriber is not None:
114 | win.ros_subscriber.destroy_node()
115 | rclpy.shutdown()
116 | app.quit()
117 |
118 | signal.signal(signal.SIGINT, shutdown_handler)
119 | signal.signal(signal.SIGTERM, shutdown_handler)
120 | sys.exit(app.exec())
121 |
122 |
123 | if __name__ == '__main__':
124 | main()
125 |
--------------------------------------------------------------------------------
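A hedged sketch (not part of the package) of a publisher that drives the plots above with synthetic values; per update_plots(), data[0] feeds the 'Total reward' plot and data[1] the 'Average max Q-value' plot:

```python
import rclpy
from rclpy.node import Node
from std_msgs.msg import Float32MultiArray

rclpy.init()
node = Node('result_probe')  # hypothetical node name
pub = node.create_publisher(Float32MultiArray, '/result', 10)

for episode in range(10):
    msg = Float32MultiArray()
    msg.data = [float(episode * 10), float(episode)]  # [total_reward, avg_max_q]
    pub.publish(msg)
    rclpy.spin_once(node, timeout_sec=0.2)  # ~5 points/s, slow enough to watch

node.destroy_node()
rclpy.shutdown()
```

--------------------------------------------------------------------------------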
/turtlebot3_machine_learning/CHANGELOG.rst:
--------------------------------------------------------------------------------
1 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
2 | Changelog for package turtlebot3_machine_learning
3 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
4 |
5 | 1.0.1 (2025-05-02)
6 | ------------------
7 | * Support for ROS 2 Jazzy version
8 | * Gazebo simulation support for the package
9 | * Contributors: ChanHyeong Lee
10 |
11 | 1.0.0 (2025-04-17)
12 | ------------------
13 | * Support for ROS 2 Humble version
14 | * Renewal of package structure
15 | * Improved behavioral rewards for agents
16 | * Contributors: ChanHyeong Lee
17 |
--------------------------------------------------------------------------------
/turtlebot3_machine_learning/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | ################################################################################
2 | # Set minimum required version of cmake, project name and compile options
3 | ################################################################################
4 | cmake_minimum_required(VERSION 3.5)
5 | project(turtlebot3_machine_learning)
6 |
7 | if(NOT CMAKE_CXX_STANDARD)
8 | set(CMAKE_CXX_STANDARD 17)
9 | endif()
10 |
11 | ################################################################################
12 | # Find ament packages and libraries for ament and system dependencies
13 | ################################################################################
14 | find_package(ament_cmake REQUIRED)
15 |
16 | ################################################################################
17 | # Macro for ament package
18 | ################################################################################
19 | ament_package()
20 |
--------------------------------------------------------------------------------
/turtlebot3_machine_learning/package.xml:
--------------------------------------------------------------------------------
1 | <?xml version="1.0"?>
2 | <package format="3">
3 |   <name>turtlebot3_machine_learning</name>
4 |   <version>1.0.1</version>
5 |   <description>
6 |     A metapackage for ROS 2 TurtleBot3 machine learning.
7 |   </description>
8 |   <maintainer>Pyo</maintainer>
9 |   <license>Apache 2.0</license>
10 |   <url type="website">http://turtlebot3.robotis.com</url>
11 |   <url type="repository">https://github.com/ROBOTIS-GIT/turtlebot3_machine_learning</url>
12 |   <url type="bugtracker">https://github.com/ROBOTIS-GIT/turtlebot3_machine_learning/issues</url>
13 |   <author>Gilbert</author>
14 |   <author>ChanHyeong Lee</author>
15 |   <exec_depend>turtlebot3_dqn</exec_depend>
16 |
17 |   <export>
18 |     <build_type>ament_cmake</build_type>
19 |   </export>
20 | </package>
--------------------------------------------------------------------------------
/turtlebot3_machine_learning_ci.repos:
--------------------------------------------------------------------------------
1 | repositories:
2 | turtlebot3_msgs:
3 | type: git
4 | url: https://github.com/ROBOTIS-GIT/turtlebot3_msgs.git
5 | version: main
6 |
--------------------------------------------------------------------------------