├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── docs
│   ├── environment
│   │   ├── environment.html
│   │   └── index.html
│   ├── index.html
│   ├── naf_components
│   │   ├── index.html
│   │   ├── naf_algorithm.html
│   │   └── naf_neural_network.html
│   ├── rl_framework.html
│   └── utils
│       ├── collision_detector.html
│       ├── exceptions.html
│       ├── index.html
│       ├── logger.html
│       └── replay_buffer.html
├── pyproject.toml
├── robotic_manipulator_rloa
│   ├── __init__.py
│   ├── environment
│   │   ├── __init__.py
│   │   └── environment.py
│   ├── naf_components
│   │   ├── __init__.py
│   │   ├── demo_weights
│   │   │   ├── __init__.py
│   │   │   ├── weights_kuka.p
│   │   │   └── weights_xarm6.p
│   │   ├── naf_algorithm.py
│   │   └── naf_neural_network.py
│   ├── rl_framework.py
│   └── utils
│       ├── __init__.py
│       ├── collision_detector.py
│       ├── exceptions.py
│       ├── logger.py
│       └── replay_buffer.py
└── tests
    ├── __init__.py
    └── robotic_manipulator_rloa
        ├── __init__.py
        ├── environment
        │   ├── __init__.py
        │   └── test_environment.py
        ├── naf_components
        │   ├── __init__.py
        │   ├── test_naf_algorithm.py
        │   └── test_naf_neural_network.py
        ├── test_rl_framework.py
        └── utils
            ├── __init__.py
            ├── test_collision_detector.py
            ├── test_exceptions.py
            ├── test_logger.py
            └── test_replay_buffer.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
robotic_manipulator_rloa.egg-info
.idea
dist
.pypirc
__pycache__/
checkpoints/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Javier Martinez

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
include robotic_manipulator_rloa/naf_components/demo_weights/weights_kuka.p
include robotic_manipulator_rloa/naf_components/demo_weights/weights_xarm6.p
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# robotic_manipulator_rloa

**robotic_manipulator_rloa** is a framework for training Robotic Manipulators on the Obstacle Avoidance task through Reinforcement Learning.

## Installation

Install the package with [pip](https://pip.pypa.io/en/stable/).
```bash
$ pip install robotic-manipulator-rloa
```

> **_For Windows:_** If the installation fails because Microsoft Visual C++ 14.0 or greater is required,
> download and install the Microsoft C++ Build Tools from here: https://visualstudio.microsoft.com/es/visual-cpp-build-tools/

## Usage

### Execution of a demo training and testing process for the KUKA IIWA Robotic Manipulator

```python
from robotic_manipulator_rloa import ManipulatorFramework

# Initialize the framework
mf = ManipulatorFramework()

# Run a demo of the training process for the KUKA IIWA Robotic Manipulator
mf.run_demo_training('kuka_training', verbose=False)

# Run a demo of the testing process for the KUKA IIWA Robotic Manipulator
mf.run_demo_testing('kuka_testing')
```

### Execution of a training for the KUKA IIWA Robotic Manipulator

```python
from robotic_manipulator_rloa import ManipulatorFramework

# Initialize the framework
mf = ManipulatorFramework()

# Initialize KUKA IIWA Robotic Manipulator environment
mf.initialize_environment(manipulator_file='kuka_iiwa/kuka_with_gripper2.sdf',
                          endeffector_index=13,
                          fixed_joints=[6, 7, 8, 9, 10, 11, 12, 13],
                          involved_joints=[0, 1, 2, 3, 4, 5],
                          target_position=[0.4, 0.85, 0.71],
                          obstacle_position=[0.45, 0.55, 0.55],
                          initial_joint_positions=[0.9, 0.45, 0, 0, 0, 0],
                          initial_positions_variation_range=[0, 0, 0.5, 0.5, 0.5, 0.5],
                          visualize=False)

# Initialize NAF Agent (checkpoint files will be generated every 100 episodes)
mf.initialize_naf_agent(checkpoint_frequency=100)

# Run training for 3000 episodes, 400 timesteps per episode
mf.run_training(3000, 400, verbose=False)
```

### Execution of a testing process for the KUKA IIWA Robotic Manipulator (requires a prior training run of 3000 episodes)

```python
import os
import pybullet_data
from robotic_manipulator_rloa import ManipulatorFramework

# Initialize the framework
mf = ManipulatorFramework()

# Initialize KUKA IIWA Robotic Manipulator environment
kuka_path = os.path.join(pybullet_data.getDataPath(), 'kuka_iiwa/kuka_with_gripper2.sdf')
mf.initialize_environment(manipulator_file=kuka_path,
                          endeffector_index=13,
                          fixed_joints=[6, 7, 8, 9, 10, 11, 12, 13],
                          involved_joints=[0, 1, 2, 3, 4, 5],
                          target_position=[0.4, 0.85, 0.71],
                          obstacle_position=[0.45, 0.55, 0.55],
                          initial_joint_positions=[0.9, 0.45, 0, 0, 0, 0],
                          initial_positions_variation_range=[0, 0, 0.5, 0.5, 0.5, 0.5],
                          visualize=False)

# Initialize NAF Agent
mf.initialize_naf_agent()

# Load pretrained weights from the checkpoint file generated at episode 3000
mf.load_pretrained_parameters_from_episode(3000)

# Test the pretrained model for 50 test episodes, 750 timesteps each
mf.test_trained_model(50, 750)
```

## Contributing

Pull requests are welcome! For major changes, please open an issue first
to discuss what you would like to change. Please make sure to update and execute the tests!
```bash
robotic_manipulator_rloa$ pytest --cov-report term-missing --cov=robotic_manipulator_rloa/ tests/robotic_manipulator_rloa/
```

## License

[MIT License](https://choosealicense.com/licenses/mit/)
--------------------------------------------------------------------------------
/docs/environment/index.html:
--------------------------------------------------------------------------------
[pdoc-generated HTML documentation page for the robotic_manipulator_rloa.environment package; it lists the sub-module robotic_manipulator_rloa.environment.environment and navigation links.]
--------------------------------------------------------------------------------
/docs/index.html:
--------------------------------------------------------------------------------
[pdoc-generated HTML documentation page for the robotic_manipulator_rloa package; it shows the package export (from .rl_framework import ManipulatorFramework) and lists the sub-modules robotic_manipulator_rloa.environment, robotic_manipulator_rloa.naf_components, robotic_manipulator_rloa.rl_framework and robotic_manipulator_rloa.utils.]
--------------------------------------------------------------------------------
/docs/naf_components/index.html:
--------------------------------------------------------------------------------
[pdoc-generated HTML documentation page for the robotic_manipulator_rloa.naf_components package; it lists the sub-modules robotic_manipulator_rloa.naf_components.naf_algorithm and robotic_manipulator_rloa.naf_components.naf_neural_network.]
--------------------------------------------------------------------------------
/docs/utils/collision_detector.html:
--------------------------------------------------------------------------------
[pdoc-generated HTML documentation page for robotic_manipulator_rloa.utils.collision_detector; the page documents the following module source.]
from dataclasses import dataclass
from typing import List

import numpy as np
import pybullet as p
from numpy.typing import NDArray


@dataclass
class CollisionObject:
    """
    Dataclass which contains the UID of the manipulator/body and the link number of the joint from which to
    calculate distances to other bodies.
    """
    body: int
    link: int


class CollisionDetector:

    def __init__(self, collision_object: CollisionObject, obstacle_ids: List[int]):
        """
        Calculates distances between bodies' joints.
        Args:
            collision_object: CollisionObject instance, which indicates the body/joint from
                which to calculate distances/collisions.
            obstacle_ids: Obstacle body UIDs. Distances are calculated from the joint/body given in the
                "collision_object" parameter to the "obstacle_ids" bodies.
        """
        self.obstacles = obstacle_ids
        self.collision_object = collision_object

    def compute_distances(self, max_distance: float = 10.0) -> NDArray:
        """
        Compute the closest distances from the joint given by the CollisionObject instance in self.collision_object
        to the bodies defined in self.obstacles.
        Args:
            max_distance: Bodies farther apart than this distance are not queried by PyBullet; the return value
                for the distance between such bodies will be max_distance.
        Returns:
            A numpy array of distances, one per pair of collision objects.
        """
        distances = list()
        for obstacle in self.obstacles:

            # Compute the shortest distances between the collision-object and the given obstacle
            closest_points = p.getClosestPoints(
                self.collision_object.body,
                obstacle,
                distance=max_distance,
                linkIndexA=self.collision_object.link
            )

            # If bodies are more than max_distance apart, nothing is returned, so
            # we just saturate at max_distance. Otherwise, take the minimum
            if len(closest_points) == 0:
                distances.append(max_distance)
            else:
                distances.append(np.min([point[8] for point in closest_points]))

        return np.array(distances)

    def compute_collisions_in_manipulator(self, affected_joints: List[int], max_distance: float = 10.) -> NDArray:
        """
        Compute collisions between the manipulator's own parts.
        Args:
            affected_joints: Joints to consider when calculating distances.
            max_distance: Maximum distance to be considered. Distances greater than this will be ignored, and
                the "max_distance" value will be returned.
        Returns:
            Array where each element corresponds to the distance from the given joint to another joint.
        """
        distances = list()
        for joint_ind in affected_joints:

            # Collisions with the previous and next joints are omitted, as they will always be in contact
            if (self.collision_object.link == joint_ind) or \
                    (joint_ind == self.collision_object.link - 1) or \
                    (joint_ind == self.collision_object.link + 1):
                continue    # pragma: no cover

            # Compute the shortest distances between all object pairs
            closest_points = p.getClosestPoints(
                self.collision_object.body,
                self.collision_object.body,
                distance=max_distance,
                linkIndexA=self.collision_object.link,
                linkIndexB=joint_ind
            )

            # If bodies are more than max_distance apart, nothing is returned, so
            # we just saturate at max_distance. Otherwise, take the minimum
            if len(closest_points) == 0:
                distances.append(max_distance)
            else:
                distances.append(np.min([point[8] for point in closest_points]))

        return np.array(distances)
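A minimal usage sketch for the collision utilities (illustrative only, not a file in the repository; it assumes the stock kuka_iiwa/model.urdf and sphere2.urdf assets shipped with pybullet_data, and treats link 6 as the KUKA end-effector):

```python
import pybullet as p
import pybullet_data

from robotic_manipulator_rloa.utils.collision_detector import CollisionObject, CollisionDetector

p.connect(p.DIRECT)
p.setAdditionalSearchPath(pybullet_data.getDataPath())

manipulator_uid = p.loadURDF('kuka_iiwa/model.urdf')
obstacle_uid = p.loadURDF('sphere2.urdf', basePosition=[0.45, 0.55, 0.55])

# Measure distances from the end-effector link to the obstacle body
collision_object = CollisionObject(body=manipulator_uid, link=6)
detector = CollisionDetector(collision_object, [obstacle_uid])

print(detector.compute_distances(max_distance=10.0))  # distance to the sphere, saturated at 10.0
```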
--------------------------------------------------------------------------------
/docs/utils/index.html:
--------------------------------------------------------------------------------
[pdoc-generated HTML documentation page for the robotic_manipulator_rloa.utils package; it lists the sub-modules robotic_manipulator_rloa.utils.collision_detector, robotic_manipulator_rloa.utils.exceptions, robotic_manipulator_rloa.utils.logger and robotic_manipulator_rloa.utils.replay_buffer.]
--------------------------------------------------------------------------------
/docs/utils/logger.html:
--------------------------------------------------------------------------------
[pdoc-generated HTML documentation page for robotic_manipulator_rloa.utils.logger; the page documents the following module source.]

import logging
from datetime import datetime
from logging import LogRecord
from logging.config import dictConfig
from logging.handlers import RotatingFileHandler


class Logger:  # pylint: disable=too-many-instance-attributes
    """ Stores and processes the logs """

    @staticmethod
    def set_logger_setup() -> None:
        """
        Sets the logger setup with a predefined configuration.
        """

        log_config_dict = Logger.generate_logging_config_dict()
        dictConfig(log_config_dict)
        rotating_file_handler = RotatingFileHandler(filename='training_logs.log', mode='a', maxBytes=50000000,
                                                    backupCount=10, encoding='utf-8')
        rotating_file_handler.setFormatter(logging.Formatter(
            '"%(levelname)s"|"{datetime}"|%(message)s'.format(datetime=datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%S.%fZ'))
        ))
        rotating_file_handler.setLevel(logging.INFO)
        logger = get_global_logger()
        logger.addHandler(rotating_file_handler)
        logger.setLevel(20)

    @staticmethod
    def generate_logging_config_dict() -> dict:
        """
        Generates the configuration dictionary that is used to configure the logger.
        Returns:
            Configuration dictionary.
        """
        return {
            'version': 1,
            'disable_existing_loggers': False,
            'formatters': {
                'custom_formatter': {
                    '()': CustomFormatter,
                    'dateformat': '%Y-%m-%dT%H:%M:%S.%06d%z'
                },
            },
            'handlers': {
                'debug_console_handler': {
                    'level': 'NOTSET',
                    'formatter': 'custom_formatter',
                    'class': 'logging.StreamHandler',
                    'stream': 'ext://sys.stdout',
                }
            },
            'loggers': {
                '': {
                    'handlers': ['debug_console_handler'],
                    'level': 'NOTSET',
                },
            }
        }


def get_global_logger() -> logging.Logger:
    """
    Getter for the logger.
    Returns:
        Logger instance to be used on the framework.
    """
    return logging.getLogger(__name__)


class CustomFormatter(logging.Formatter):

    def __init__(self, dateformat: str = None):
        """
        CustomFormatter for the logger.
        """
        super().__init__()
        self.dateformat = dateformat

    def format(self, record: LogRecord) -> str:
        """
        Formats the provided LogRecord instance.
        Returns:
            Formatted LogRecord as string.
        """
        # Set format and colors
        grey = "\033[38;20m"
        green = "\033[32;20m"
        yellow = "\033[33;20m"
        red = "\033[31;20m"
        bold_red = "\033[31;1m"
        reset = "\033[0m"
        format_ = '[%(levelname)-8s] - {datetime} - %(message)s'.format(
            datetime=datetime.now().astimezone().strftime('%Y-%m-%dT%H:%M:%S.%f%z')
        )

        self.FORMATS = {
            logging.DEBUG: green + format_ + reset,
            logging.INFO: grey + format_ + reset,
            logging.WARNING: yellow + format_ + reset,
            logging.ERROR: red + format_ + reset,
            logging.CRITICAL: bold_red + format_ + reset
        }

        log_format = self.FORMATS.get(record.levelno)

        formatter = logging.Formatter(log_format, datefmt=self.dateformat)
        return formatter.format(record)
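A minimal sketch of configuring and using the logger (illustrative only, not a file in the repository):

```python
from robotic_manipulator_rloa.utils.logger import Logger, get_global_logger

# Configure the colored console handler and the rotating file handler
# (the file handler appends to 'training_logs.log' in the working directory)
Logger.set_logger_setup()

logger = get_global_logger()
logger.info('Environment initialized')  # printed to stdout and appended to the log file
logger.debug('Not shown')               # filtered out: the logger level is set to INFO (20)
```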
--------------------------------------------------------------------------------
/docs/utils/replay_buffer.html:
--------------------------------------------------------------------------------
[pdoc-generated HTML documentation page for robotic_manipulator_rloa.utils.replay_buffer; the page documents the following module source.]

import random
from collections import deque, namedtuple
from typing import Tuple

import numpy as np
import torch
from numpy.typing import NDArray


# deque is a double-ended queue, which provides O(1) complexity for pop and append operations
# namedtuple is a tuple whose elements can be accessed both by index and by attribute


class ReplayBuffer:

    def __init__(self, buffer_size: int, batch_size: int, device: torch.device, seed: int):
        """
        Buffer to store experience tuples. Each experience has the following structure:
        (state, action, reward, next_state, done)
        Args:
            buffer_size: Maximum size for the buffer. A higher buffer size implies higher RAM consumption.
            batch_size: Number of experiences to be retrieved from the ReplayBuffer per batch.
            device: CUDA device.
            seed: Random seed.
        """
        self.device = device
        self.memory = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
        random.seed(seed)

    def add(self, state: NDArray, action: NDArray, reward: float, next_state: NDArray, done: int) -> None:
        """
        Add a new experience to the Replay Buffer.
        Args:
            state: NDArray of the current state.
            action: NDArray of the action taken from state {state}.
            reward: Reward obtained after performing action {action} from state {state}.
            next_state: NDArray of the state reached after performing action {action} from state {state}.
            done: Integer (0 or 1) indicating whether the next_state is a terminal state.
        """
        # Create namedtuple object from the experience
        exp = self.experience(state, action, reward, next_state, done)
        # Add the experience object to memory
        self.memory.append(exp)

    def sample(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Randomly sample a batch of experiences from memory.
        Returns:
            Tuple of 5 elements, which are (states, actions, rewards, next_states, dones). Each element
            in the tuple is a torch Tensor composed of {batch_size} items.
        """
        # Randomly sample a batch of experiences
        experiences = random.sample(self.memory, k=self.batch_size)

        states = torch.from_numpy(
            np.stack([e.state if not isinstance(e.state, tuple) else e.state[0] for e in experiences])).float().to(
            self.device)
        actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(self.device)
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(self.device)
        next_states = torch.from_numpy(np.stack([e.next_state for e in experiences if e is not None])).float().to(
            self.device)
        dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(
            self.device)

        return states, actions, rewards, next_states, dones

    def __len__(self) -> int:
        """
        Return the current size of the Replay Buffer.
        Returns:
            Size of the Replay Buffer.
        """
        return len(self.memory)
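A minimal sketch of how the ReplayBuffer is typically filled and sampled (illustrative only, not a file in the repository; the state and action sizes below are arbitrary):

```python
import numpy as np
import torch

from robotic_manipulator_rloa.utils.replay_buffer import ReplayBuffer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
buffer = ReplayBuffer(buffer_size=100000, batch_size=32, device=device, seed=0)

# Store random transitions until one full batch can be sampled
for _ in range(32):
    buffer.add(state=np.random.rand(10),
               action=np.random.rand(6),
               reward=1.0,
               next_state=np.random.rand(10),
               done=0)

states, actions, rewards, next_states, dones = buffer.sample()
print(states.shape)  # torch.Size([32, 10])
```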
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = ["setuptools>=61.0.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "robotic-manipulator-rloa"
version = "1.0.0"
description = "Framework for training Robotic Manipulators on the Obstacle Avoidance task through Reinforcement Learning."
readme = "README.md"
authors = [{ name = "Javier Martinez", email = "jmartinezojeda5upv@gmail.com" }]
license = { file = "LICENSE" }
classifiers = [
    "License :: OSI Approved :: MIT License",
    "Programming Language :: Python",
    "Programming Language :: Python :: 3",
    "Operating System :: Microsoft :: Windows",
    "Operating System :: POSIX :: Linux",
    "Framework :: Robot Framework"
]
keywords = ["Robotics", "Manipulator", "Framework", "Reinforcement Learning", "Obstacle Avoidance"]
dependencies = [
    "torch",
    "numpy",
    "pybullet",
    "matplotlib"
]
requires-python = ">=3.8"

[tool.setuptools.package-data]
myModule = ["*.p"]

[project.urls]
Homepage = "https://github.com/JavierMtz5/robotic_manipulator_rloa"
Repository = "https://github.com/JavierMtz5/robotic_manipulator_rloa"
Documentation = "https://javiermtz5.github.io/robotic_manipulator_rloa/"
--------------------------------------------------------------------------------
/robotic_manipulator_rloa/__init__.py:
--------------------------------------------------------------------------------
from .rl_framework import ManipulatorFramework
--------------------------------------------------------------------------------
/robotic_manipulator_rloa/environment/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JavierMtz5/robotic_manipulator_rloa/263aab50f70deeb48b8f0bc28041c86b76453816/robotic_manipulator_rloa/environment/__init__.py
--------------------------------------------------------------------------------
/robotic_manipulator_rloa/environment/environment.py:
--------------------------------------------------------------------------------
import random
from typing import List, Tuple

import numpy as np
import pybullet as p
import pybullet_data
from numpy.typing import NDArray

from robotic_manipulator_rloa.utils.logger import get_global_logger
from robotic_manipulator_rloa.utils.exceptions import (
    InvalidManipulatorFile,
    InvalidEnvironmentParameter
)
from robotic_manipulator_rloa.utils.collision_detector import CollisionObject, CollisionDetector

logger = get_global_logger()


class EnvironmentConfiguration:

    def __init__(self,
                 endeffector_index: int,
                 fixed_joints: List[int],
                 involved_joints: List[int],
                 target_position: List[float],
                 obstacle_position: List[float],
                 initial_joint_positions: List[float] = None,
                 initial_positions_variation_range: List[float] = None,
                 max_force: float = 200.,
                 visualize: bool = True):
        """
        Validates each of the parameters required for the Environment class initialization.
        Args:
            endeffector_index: Index of the manipulator's end-effector.
            fixed_joints: List containing the indices of every joint not involved in the training.
            involved_joints: List containing the indices of every joint involved in the training.
            target_position: List containing the position of the target object, as 3D Cartesian coordinates.
            obstacle_position: List containing the position of the obstacle, as 3D Cartesian coordinates.
            initial_joint_positions: List containing as many items as the number of joints of the manipulator.
                Each item in the list corresponds to the initial position wanted for the joint with that same index.
            initial_positions_variation_range: List containing as many items as the number of joints of the manipulator.
                Each item in the list corresponds to the variation range wanted for the joint with that same index.
            max_force: Maximum force to be applied on the joints.
            visualize: Visualization mode.
        """
        self._validate_endeffector_index(endeffector_index)
        self._validate_fixed_joints(fixed_joints)
        self._validate_involved_joints(involved_joints)
        self._validate_target_position(target_position)
        self._validate_obstacle_position(obstacle_position)
        self._validate_initial_joint_positions(initial_joint_positions)
        self._validate_initial_positions_variation_range(initial_positions_variation_range)
        self._validate_max_force(max_force)
        self._validate_visualize(visualize)

    def _validate_endeffector_index(self, endeffector_index: int) -> None:
        """
        Validates the "endeffector_index" parameter.
        Args:
            endeffector_index: int
        Raises:
            InvalidEnvironmentParameter
        """
        if not isinstance(endeffector_index, int):
            raise InvalidEnvironmentParameter('End Effector index received is not an integer')
        self.endeffector_index = endeffector_index

    def _validate_fixed_joints(self, fixed_joints: List[int]) -> None:
        """
        Validates the "fixed_joints" parameter
        Args:
            fixed_joints: list of integers
        Raises:
            InvalidEnvironmentParameter
        """
        if not isinstance(fixed_joints, list):
            raise InvalidEnvironmentParameter('Fixed Joints received is not a list')
        for val in fixed_joints:
            if not isinstance(val, int):
                raise InvalidEnvironmentParameter('An item inside the Fixed Joints list is not an integer')
        self.fixed_joints = fixed_joints

    def _validate_involved_joints(self, involved_joints: List[int]) -> None:
        """
        Validates the "involved_joints" parameter
        Args:
            involved_joints: list of integers
        Raises:
            InvalidEnvironmentParameter
        """
        if not isinstance(involved_joints, list):
            raise InvalidEnvironmentParameter('Involved Joints received is not a list')
        for val in involved_joints:
            if not isinstance(val, int):
                raise InvalidEnvironmentParameter('An item inside the Involved Joints list is not an integer')
        self.involved_joints = involved_joints

    def _validate_target_position(self, target_position: List[float]) -> None:
        """
        Validates the "target_position" parameter
        Args:
            target_position: list of floats
        Raises:
            InvalidEnvironmentParameter
        """
        if not isinstance(target_position, list):
            raise InvalidEnvironmentParameter('Target Position received is not a list')
        for val in target_position:
            if not isinstance(val, (int, float)):
                raise InvalidEnvironmentParameter('An item inside the Target Position list is not a float')
        self.target_position = target_position

    def _validate_obstacle_position(self, obstacle_position: List[float]) -> None:
        """
        Validates the "obstacle_position" parameter
        Args:
            obstacle_position: list of floats
        Raises:
            InvalidEnvironmentParameter
        """
        if not isinstance(obstacle_position, list):
            raise InvalidEnvironmentParameter('Obstacle Position received is not a list')
        for val in obstacle_position:
            if not isinstance(val, (int, float)):
                raise InvalidEnvironmentParameter('An item inside the Obstacle Position list is not a float')
        self.obstacle_position = obstacle_position

    def _validate_initial_joint_positions(self, initial_joint_positions: List[float]) -> None:
        """
        Validates the "initial_joint_positions" parameter
        Args:
            initial_joint_positions: list of floats
        Raises:
            InvalidEnvironmentParameter
        """
        if initial_joint_positions is None:
            self.initial_joint_positions = None
            return
        if not isinstance(initial_joint_positions, list):
            raise InvalidEnvironmentParameter('Initial Joint Positions received is not a list')
        for val in initial_joint_positions:
            if not isinstance(val, (int, float)):
                raise InvalidEnvironmentParameter('An item inside the Initial Joint Positions list is not a float')
        self.initial_joint_positions = initial_joint_positions

    def _validate_initial_positions_variation_range(self, initial_positions_variation_range: List[float]) -> None:
        """
        Validates the "initial_positions_variation_range" parameter
        Args:
            initial_positions_variation_range: list of floats
        Raises:
            InvalidEnvironmentParameter
        """
        if initial_positions_variation_range is None:
            self.initial_positions_variation_range = None
            return
        if not isinstance(initial_positions_variation_range, list):
            raise InvalidEnvironmentParameter('Initial Positions Variation Range received is not a list')
        for val in initial_positions_variation_range:
            if not isinstance(val, (float, int)):
                raise InvalidEnvironmentParameter('An item inside the Initial Positions Variation Range '
                                                  'list is not a float')
        self.initial_positions_variation_range = initial_positions_variation_range

    def _validate_max_force(self, max_force: float) -> None:
        """
        Validates the "max_force" parameter
        Args:
            max_force: float
        Raises:
            InvalidEnvironmentParameter
        """
        if not isinstance(max_force, (int, float)):
            raise InvalidEnvironmentParameter('Maximum Force value received is not a float')
        self.max_force = max_force

    def _validate_visualize(self, visualize: bool) -> None:
        """
        Validates the "visualize" parameter
        Args:
            visualize: bool
        Raises:
            InvalidEnvironmentParameter
        """
        if not isinstance(visualize, bool):
            raise InvalidEnvironmentParameter('Visualize value received is not a boolean')
        self.visualize = visualize


class Environment:

    def __init__(self,
                 manipulator_file: str,
                 environment_config: EnvironmentConfiguration):
        """
        Creates the Pybullet environment used along the training.
        Args:
            manipulator_file: Path to the URDF or SDF file from which to load the Robotic Manipulator.
            environment_config: Instance of the EnvironmentConfiguration class with all its attributes set.
        Raises:
            InvalidManipulatorFile: The URDF/SDF file doesn't exist, is invalid or has an invalid extension.
        """
        self.manipulator_file = manipulator_file
        self.visualize = environment_config.visualize

        # Initialize pybullet
        self.physics_client = p.connect(p.GUI if environment_config.visualize else p.DIRECT)
        p.setGravity(0, 0, -9.81)
        p.setRealTimeSimulation(0)
        p.setAdditionalSearchPath(pybullet_data.getDataPath())

        self.target_pos = environment_config.target_position
        self.obstacle_pos = environment_config.obstacle_position
        self.max_force = environment_config.max_force  # Maximum force to be applied (DEFAULT=200)
        self.initial_joint_positions = environment_config.initial_joint_positions
        self.initial_positions_variation_range = environment_config.initial_positions_variation_range

        self.endeffector_index = environment_config.endeffector_index  # Index of the Manipulator's End-Effector
        self.fixed_joints = environment_config.fixed_joints  # List of indexes for the joints to be fixed
        self.involved_joints = environment_config.involved_joints  # List of indexes of joints involved in the training

        # Load Manipulator from URDF/SDF file
        logger.debug(f'Loading URDF/SDF file {manipulator_file} for Robot Manipulator...')
        if not isinstance(manipulator_file, str):
            raise InvalidManipulatorFile('The filename provided is not a string')

        try:
            if manipulator_file.endswith('.urdf'):
                self.manipulator_uid = p.loadURDF(manipulator_file)
            elif manipulator_file.endswith('.sdf'):
                self.manipulator_uid = p.loadSDF(manipulator_file)[0]
            else:
                raise InvalidManipulatorFile('The file extension is neither .sdf nor .urdf')
        except p.error as err:
            logger.critical(err)
            raise InvalidManipulatorFile

        self.num_joints = p.getNumJoints(self.manipulator_uid)

        logger.debug(f'Robot Manipulator URDF/SDF file {manipulator_file} has been successfully loaded.
'
241 |                      f'The Robot Manipulator has {self.num_joints} joints, and its joints, '
242 |                      f'together with the information of each, are:')
243 |         data = list()
244 |         for joint_ind in range(self.num_joints):
245 |             joint_info = p.getJointInfo(self.manipulator_uid, joint_ind)
246 |             data.append((joint_ind, joint_info[1].decode("utf-8"), joint_info[9], joint_info[8], joint_info[13]))
247 | 
248 |         # Print Joints info
249 |         self.print_table(data)
250 | 
251 |         # Create the obstacle with the shape of a sphere, and the target object with the shape of a cube
252 |         self.obstacle = p.loadURDF('sphere_small.urdf', basePosition=self.obstacle_pos,
253 |                                    useFixedBase=1, globalScaling=2.5)
254 |         self.target = p.loadURDF('cube_small.urdf', basePosition=self.target_pos,
255 |                                  useFixedBase=1, globalScaling=1)
256 |         logger.debug(f'Both the obstacle and the target object have been generated in positions {self.obstacle_pos} '
257 |                      f'and {self.target_pos} respectively')
258 | 
259 |         # 9 elements correspond to the three 3D vectors indicating the positions of the target, end effector and obstacle
260 |         # The remaining elements are the two arrays with the involved joints' positions and velocities
261 |         self._observation_space = np.zeros((9 + 2 * len(self.involved_joints),))
262 |         self._action_space = np.zeros((len(self.involved_joints),))
263 | 
264 |     def reset(self, verbose: bool = True) -> NDArray:
265 |         """
266 |         Resets the environment to an initial state.\n
267 |         - If "initial_joint_positions" and "initial_positions_variation_range" are not set, all joints will be reset to
268 |         the 0 position.\n
269 |         - If only "initial_joint_positions" is set, the joints will be reset to those positions.\n
270 |         - If only "initial_positions_variation_range" is set, the joints will be reset to 0 plus the variation noise.\n
271 |         - If both "initial_joint_positions" and "initial_positions_variation_range" are set, the joints will be reset
272 |         to the positions specified plus the variation noise.
273 |         Args:
274 |             verbose: Boolean indicating whether to print context information or not.
275 |         Returns:
276 |             New state reached after reset.
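        Example:
            A minimal usage sketch, assuming an already-initialized Environment instance
            named "env" (the variable name is illustrative):

                state = env.reset(verbose=False)
                assert state.shape == env.observation_space.shape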
277 | """ 278 | if verbose: logger.info('Resetting Environment...') 279 | 280 | # Reset the robot's base position and orientation 281 | p.resetBasePositionAndOrientation(self.manipulator_uid, [0.000000, 0.000000, 0.000000], 282 | [0.000000, 0.000000, 0.000000, 1.000000]) 283 | 284 | if not self.initial_joint_positions and not self.initial_positions_variation_range: 285 | initial_state = [0 for _ in range(self.num_joints)] 286 | elif self.initial_joint_positions: 287 | if self.initial_positions_variation_range: 288 | initial_state = [random.uniform(pos - var, pos + var) for pos, var 289 | in zip(self.initial_joint_positions, self.initial_positions_variation_range)] 290 | else: 291 | initial_state = self.initial_joint_positions 292 | else: 293 | initial_state = [random.uniform(0 - var, 0 + var) for var in self.initial_positions_variation_range] 294 | 295 | for joint_index, pos in enumerate(initial_state): 296 | p.setJointMotorControl2(self.manipulator_uid, joint_index, 297 | controlMode=p.POSITION_CONTROL, 298 | targetPosition=pos) 299 | 300 | for _ in range(50): 301 | p.stepSimulation(self.physics_client) 302 | 303 | # Generate first state, and return it 304 | # The states are defined as {joint_pos, joint_vel, end-effector_pos, target_pos, obstacle_pos}, where 305 | # both joint_pos and joint_vel are arrays with the pos and vel of each joint 306 | new_state = self.get_state() 307 | if verbose: logger.info('Environment Reset') 308 | 309 | return new_state 310 | 311 | def is_terminal_state(self, target_threshold: float = 0.05, obstacle_threshold: float = 0., 312 | consider_autocollision: bool = False) -> int: 313 | """ 314 | Calculates if a terminal state is reached. 315 | Args: 316 | target_threshold: Threshold which delimits the terminal state. If the end-effector is closer 317 | to the target position than the threshold value, then a terminal state is reached. 318 | obstacle_threshold: Threshold which delimits the terminal state. If the end-effector is closer 319 | to the obstacle position than the threshold value, then a terminal state is reached. 320 | consider_autocollision: If set to True, the collision of any of the joints and parts of the manipulator 321 | with any other joint or part will be considered a terminal state. 322 | Returns: 323 | Integer (0 or 1) indicating whether the new state reached is a terminal state or not. 324 | """ 325 | # If the manipulator has a collision with the obstacle, the episode terminates 326 | if self.get_manipulator_obstacle_collisions(threshold=obstacle_threshold): 327 | logger.info('Collision detected, terminating episode...') 328 | return 1 329 | 330 | # If the position of the end-effector is the same as the one of the target position, episode terminates 331 | if self.get_endeffector_target_collision(threshold=target_threshold)[0]: 332 | logger.info('The goal state has been reached, terminating episode...') 333 | return 1 334 | 335 | # If the manipulator collides with itself, a terminal state is reached 336 | if consider_autocollision: 337 | self_distances = self.get_manipulator_collisions_with_itself() 338 | for distances in self_distances.values(): 339 | if (distances < 0).any(): 340 | logger.info('Auto-Collision detected, terminating episode...') 341 | return 1 342 | 343 | return 0 344 | 345 | def get_reward(self, consider_autocollision: bool = False) -> float: 346 | """ 347 | Computes the reward from the given state. 
348 |         Returns:
349 |             Rewards:\n
350 |             - If the end effector reaches the target position, a reward of +250 is returned.\n
351 |             - If the end effector collides with the obstacle or with itself*, a reward of -1000 is returned.\n
352 |             - Otherwise, the negative value of the distance from end effector to the target is returned.\n
353 |             * The manipulator's collisions with itself are only considered if the "consider_autocollision" parameter is set
354 |             to True.
355 |         """
356 |         # Auto-Collision is only calculated if requested
357 |         self_collision = False
358 |         if consider_autocollision:
359 |             self_distances = self.get_manipulator_collisions_with_itself()
360 |             for distances in self_distances.values():
361 |                 if (distances < 0).any():
362 |                     self_collision = True
363 | 
364 |         endeffector_target_collision, endeffector_target_dist = self.get_endeffector_target_collision(threshold=0.05)
365 | 
366 |         if endeffector_target_collision:
367 |             return 250
368 |         elif self.get_manipulator_obstacle_collisions(threshold=0) or self_collision:
369 |             return -1000
370 |         else:
371 |             return -1 * float(endeffector_target_dist)
372 | 
373 |     def get_manipulator_obstacle_collisions(self, threshold: float) -> bool:
374 |         """
375 |         Calculates if there is a collision between the manipulator and the obstacle.
376 |         Args:
377 |             threshold: If the distance between any of the manipulator's joints and the obstacle is below the
378 |                 "threshold", then it is considered a collision.
379 |         Returns:
380 |             Boolean indicating whether a collision occurred.
381 |         """
382 |         joint_distances = list()
383 |         for joint_ind in range(self.num_joints):
384 |             end_effector_collision_obj = CollisionObject(body=self.manipulator_uid, link=joint_ind)
385 |             collision_detector = CollisionDetector(collision_object=end_effector_collision_obj,
386 |                                                    obstacle_ids=[self.obstacle])
387 | 
388 |             dist = collision_detector.compute_distances()
389 |             joint_distances.append(dist[0])
390 | 
391 |         joint_distances = np.array(joint_distances)
392 |         return (joint_distances < threshold).any()
393 | 
394 |     def get_manipulator_collisions_with_itself(self) -> dict:
395 |         """
396 |         Calculates the distances between each of the manipulator's joints and the other joints.
397 |         Returns:
398 |             Dictionary where each key is the index of a joint, and where each value is an array with the
399 |             distances from that joint to any other joint in the manipulator.
400 |         """
401 |         joint_distances = dict()
402 |         for joint_ind in range(self.num_joints):
403 |             joint_collision_obj = CollisionObject(body=self.manipulator_uid, link=joint_ind)
404 |             collision_detector = CollisionDetector(collision_object=joint_collision_obj,
405 |                                                    obstacle_ids=[])
406 |             distances = collision_detector.compute_collisions_in_manipulator(
407 |                 affected_joints=[_ for _ in range(self.num_joints)],  # all joints are taken into account
408 |                 max_distance=10
409 |             )
410 |             joint_distances[f'joint_{joint_ind}'] = distances
411 | 
412 |         return joint_distances
413 | 
414 |     def get_endeffector_target_collision(self, threshold: float) -> Tuple[bool, float]:
415 |         """
416 |         Calculates if there are any collisions between the end effector and the target.
417 |         Args:
418 |             threshold: If the distance between the end effector and the target is below {threshold}, then
419 |                 it is considered a collision.
420 |         Returns:
421 |             Tuple where the first element is a boolean indicating whether a collision occurred, and where
422 |             the second is the distance from end effector to target minus the threshold.
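        Example:
            A minimal usage sketch, assuming an initialized Environment instance named "env"
            (the variable name is illustrative):

                collided, distance = env.get_endeffector_target_collision(threshold=0.05)
                # "collided" is True once the end effector is within 0.05 units of the target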
423 | """ 424 | kuka_end_effector = CollisionObject(body=self.manipulator_uid, link=self.endeffector_index) 425 | collision_detector = CollisionDetector(collision_object=kuka_end_effector, obstacle_ids=[self.target]) 426 | 427 | dist = collision_detector.compute_distances() 428 | 429 | return (dist < threshold).any(), dist - threshold 430 | 431 | def get_state(self) -> NDArray: 432 | """ 433 | Retrieves information from the environment's current state. 434 | Returns: 435 | State as (joint_pos, joint_vel, end-effector_pos, target_pos, obstacle_pos):\n 436 | - The positions of the target, obstacle and end effector are given as 3D cartesian coordinates.\n 437 | - The joint positions and joint velocities are given as arrays of length equal to the number of 438 | joint involved in the training. 439 | """ 440 | joint_pos, joint_vel = list(), list() 441 | 442 | for joint_index in range(len(self.involved_joints)): 443 | joint_pos.append(p.getJointState(self.manipulator_uid, joint_index)[0]) 444 | joint_vel.append(p.getJointState(self.manipulator_uid, joint_index)[1]) 445 | 446 | end_effector_pos = p.getLinkState(self.manipulator_uid, self.endeffector_index)[0] 447 | end_effector_pos = list(end_effector_pos) 448 | 449 | state = np.hstack([np.array(joint_pos), np.array(joint_vel), np.array(end_effector_pos), 450 | np.array(self.target_pos), np.array(self.obstacle_pos)]) 451 | return state.astype(float) 452 | 453 | def step(self, action: NDArray) -> Tuple[NDArray, float, int]: 454 | """ 455 | Applies the action on the Robot's joints, so that each joint reaches the desired velocity for 456 | each involved joint. 457 | Args: 458 | action: Array where each element corresponds to the velocity to be applied on the joint 459 | with that same index. 460 | Returns: 461 | (new_state, reward, done) 462 | """ 463 | # Apply velocities on the involved joints according to action 464 | for joint_index, vel in zip(self.involved_joints, action): 465 | p.setJointMotorControl2(self.manipulator_uid, 466 | joint_index, 467 | p.VELOCITY_CONTROL, 468 | targetVelocity=vel, 469 | force=self.max_force) 470 | 471 | # Create constraint for fixed joints (maintain joint on fixed position) 472 | for joint_ind in self.fixed_joints: 473 | p.setJointMotorControl2(self.manipulator_uid, 474 | joint_ind, 475 | p.POSITION_CONTROL, 476 | targetPosition=0) 477 | 478 | # Perform actions on simulation 479 | p.stepSimulation(physicsClientId=self.physics_client) 480 | 481 | reward = self.get_reward() 482 | new_state = self.get_state() 483 | done = self.is_terminal_state() 484 | 485 | return new_state, reward, done 486 | 487 | @staticmethod 488 | def print_table(data: List[Tuple[int, str, float, float, tuple]]) -> None: 489 | """ 490 | Prints a table such that the elements received in the "data" parameter are displayed under 491 | "Index", "Name", "Upper Limit", "Lower Limit" and "Axis" columns. It is used to print the Manipulator's 492 | joint's information in an ordered manner. 493 | Args: 494 | data: List where each element contains all the information about a given joint. 495 | Each element on the list will be a tuple containing (index, name, upper_limit, lower_limit, axis). 
496 | """ 497 | logger.debug('{:<6} {:<35} {:<15} {:<15} {:<15}'.format('Index', 'Name', 'Upper Limit', 'Lower Limit', 'Axis')) 498 | for index, name, up_limit, lo_limit, axis in data: 499 | logger.debug('{:<6} {:<35} {:<15} {:<15} {:<15}'.format(index, name, up_limit, lo_limit, str(axis))) 500 | 501 | @property 502 | def observation_space(self) -> np.ndarray: 503 | """ 504 | Getter for the observation space of the environment. 505 | Returns: 506 | Numpy array of zeros with same shape as the environment's states. 507 | """ 508 | return self._observation_space 509 | 510 | @property 511 | def action_space(self) -> np.ndarray: 512 | """ 513 | Getter for the action space of the environment. 514 | Returns: 515 | Numpy array of zeros with same shape as the environment's actions. 516 | """ 517 | return self._action_space 518 | -------------------------------------------------------------------------------- /robotic_manipulator_rloa/naf_components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavierMtz5/robotic_manipulator_rloa/263aab50f70deeb48b8f0bc28041c86b76453816/robotic_manipulator_rloa/naf_components/__init__.py -------------------------------------------------------------------------------- /robotic_manipulator_rloa/naf_components/demo_weights/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavierMtz5/robotic_manipulator_rloa/263aab50f70deeb48b8f0bc28041c86b76453816/robotic_manipulator_rloa/naf_components/demo_weights/__init__.py -------------------------------------------------------------------------------- /robotic_manipulator_rloa/naf_components/demo_weights/weights_kuka.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavierMtz5/robotic_manipulator_rloa/263aab50f70deeb48b8f0bc28041c86b76453816/robotic_manipulator_rloa/naf_components/demo_weights/weights_kuka.p -------------------------------------------------------------------------------- /robotic_manipulator_rloa/naf_components/demo_weights/weights_xarm6.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavierMtz5/robotic_manipulator_rloa/263aab50f70deeb48b8f0bc28041c86b76453816/robotic_manipulator_rloa/naf_components/demo_weights/weights_xarm6.p -------------------------------------------------------------------------------- /robotic_manipulator_rloa/naf_components/naf_algorithm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import time 5 | from typing import Tuple, Dict 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.optim as optim 10 | from numpy.typing import NDArray 11 | from torch.nn.utils import clip_grad_norm_ 12 | 13 | from robotic_manipulator_rloa.utils.logger import get_global_logger 14 | from robotic_manipulator_rloa.environment.environment import Environment 15 | from robotic_manipulator_rloa.utils.exceptions import MissingWeightsFile 16 | from robotic_manipulator_rloa.naf_components.naf_neural_network import NAF 17 | from robotic_manipulator_rloa.utils.replay_buffer import ReplayBuffer 18 | 19 | 20 | logger = get_global_logger() 21 | 22 | 23 | class NAFAgent: 24 | 25 | MODEL_PATH = 'model.p' # Filename where the parameters of the trained torch neural network are stored 26 | 27 | def __init__(self, 28 | 
environment: Environment,
29 |                  state_size: int,
30 |                  action_size: int,
31 |                  layer_size: int,
32 |                  batch_size: int,
33 |                  buffer_size: int,
34 |                  learning_rate: float,
35 |                  tau: float,
36 |                  gamma: float,
37 |                  update_freq: int,
38 |                  num_updates: int,
39 |                  checkpoint_frequency: int,
40 |                  device: torch.device,
41 |                  seed: int) -> None:
42 |         """
43 |         Interacts with and learns from the environment via the NAF algorithm.
44 |         Args:
45 |             environment: Instance of Environment class.
46 |             state_size: Dimension of the states.
47 |             action_size: Dimension of the actions.
48 |             layer_size: Size for the hidden layers of the neural network.
49 |             batch_size: Number of experiences to train with per training batch.
50 |             buffer_size: Maximum number of experiences to be stored in Replay Buffer.
51 |             learning_rate: Learning rate for neural network's optimizer.
52 |             tau: Hyperparameter for soft updating the target network.
53 |             gamma: Discount factor.
54 |             update_freq: Number of timesteps after which the main neural network is updated.
55 |             num_updates: Number of updates performed when learning.
56 |             checkpoint_frequency: Number of episodes after which a checkpoint is generated.
57 |             device: Device used (CPU or CUDA).
58 |             seed: Random seed.
59 |         """
60 |         # Create required parent directory
61 |         os.makedirs('checkpoints/', exist_ok=True)
62 | 
63 |         self.environment = environment
64 |         self.state_size = state_size
65 |         self.action_size = action_size
66 |         self.layer_size = layer_size
67 |         self.buffer_size = buffer_size
68 |         self.learning_rate = learning_rate
69 |         random.seed(seed)
70 |         self.device = device
71 |         self.tau = tau
72 |         self.gamma = gamma
73 |         self.update_freq = update_freq
74 |         self.num_updates = num_updates
75 |         self.batch_size = batch_size
76 |         self.checkpoint_frequency = checkpoint_frequency
77 | 
78 |         # Initialize Q-Networks
79 |         self.qnetwork_main = NAF(state_size, action_size, layer_size, seed, device).to(device)
80 |         self.qnetwork_target = NAF(state_size, action_size, layer_size, seed, device).to(device)
81 | 
82 |         # Define Adam as optimizer
83 |         self.optimizer = optim.Adam(self.qnetwork_main.parameters(), lr=learning_rate)
84 | 
85 |         # Initialize Replay memory
86 |         self.memory = ReplayBuffer(buffer_size, batch_size, self.device, seed)
87 | 
88 |         # Initialize update time step counter (for updating every {update_freq} steps)
89 |         self.update_t_step = 0
90 | 
91 |     def initialize_pretrained_agent_from_episode(self, episode: int) -> None:
92 |         """
93 |         Loads the previously trained weights into the main and target neural networks.
94 |         The pretrained weights are retrieved from the checkpoints generated on a training execution, so
95 |         the episode provided must be present in the checkpoints/ folder.
96 |         Args:
97 |             episode: Episode from which to retrieve the pretrained weights.
98 |         Raises:
99 |             MissingWeightsFile: The weights.p file is not present in the checkpoints/{episode}/ folder provided.
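        Example:
            A minimal usage sketch, assuming a checkpoint was generated at episode 3000, so that
            checkpoints/3000/weights.p exists (the episode number is illustrative):

                agent.initialize_pretrained_agent_from_episode(3000)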
100 | """ 101 | # Check if file is present in checkpoints/{episode}/ directory 102 | if not os.path.isfile(f'checkpoints/{episode}/weights.p'): 103 | raise MissingWeightsFile 104 | 105 | logger.debug(f'Loading naf_components weights from trained naf_components on episode {episode}...') 106 | self.qnetwork_main.load_state_dict(torch.load(f'checkpoints/{episode}/weights.p')) 107 | self.qnetwork_target.load_state_dict(torch.load(f'checkpoints/{episode}/weights.p')) 108 | logger.info(f'Loaded weights from trained naf_components on episode {episode}') 109 | 110 | def initialize_pretrained_agent_from_weights_file(self, weights_path: str) -> None: 111 | """ 112 | Loads the previously trained weights into the main and target neural networks. 113 | The pretrained weights are retrieved from a .p file containing the weights, located in 114 | the {weights_path} path. 115 | Args: 116 | weights_path: Path to the .p file containing the network's weights. 117 | Raises: 118 | MissingWeightsFile: The file path provided does not exist. 119 | """ 120 | # Check if file is present 121 | if not os.path.isfile(weights_path): 122 | raise MissingWeightsFile 123 | 124 | logger.debug('Loading naf_components weights from trained naf_components...') 125 | self.qnetwork_main.load_state_dict(torch.load(weights_path)) 126 | self.qnetwork_target.load_state_dict(torch.load(weights_path)) 127 | logger.info('Loaded pre-trained weights for the NN') 128 | 129 | def step(self, state: NDArray, action: NDArray, reward: float, next_state: NDArray, done: int) -> None: 130 | """ 131 | Stores in the ReplayBuffer the new experience composed by the parameters received, 132 | and learns only if the Buffer contains enough experiences to fill a batch. The 133 | learning will occur if the update frequency {update_freq} is reached, in which case it 134 | will learn {num_updates} times. 135 | Args: 136 | state: Current state. 137 | action: Action performed from state {state}. 138 | reward: Reward obtained after performing action {action} from state {state}. 139 | next_state: New state reached after performing action {action} from state {state}. 140 | done: Integer (0 or 1) indicating whether a terminal state have been reached. 141 | """ 142 | 143 | # Save experience in replay memory 144 | self.memory.add(state, action, reward, next_state, done) 145 | 146 | # Learning will be performed every {update_freq}} time-steps. 147 | self.update_t_step = (self.update_t_step + 1) % self.update_freq # Update time step counter 148 | if self.update_t_step == 0: 149 | # If enough samples are available in memory, get random subset and learn 150 | if len(self.memory) > self.batch_size: 151 | for _ in range(self.num_updates): 152 | # Pick random batch of experiences from memory 153 | experiences = self.memory.sample() 154 | 155 | # Learn from experiences and get loss 156 | self.learn(experiences) 157 | 158 | def act(self, state: NDArray) -> NDArray: 159 | """ 160 | Extracts the action which maximizes the Q-Function, by getting the output of the mu layer 161 | of the main neural network. 162 | Args: 163 | state: Current state from which to pick the best action. 164 | Returns: 165 | Action which maximizes Q-Function. 
166 | """ 167 | state = torch.from_numpy(state).float().to(self.device) 168 | 169 | # Set evaluation mode on naf_components for obtaining a prediction 170 | self.qnetwork_main.eval() 171 | with torch.no_grad(): 172 | # Get the action with maximum Q-Value from the local network 173 | action, _, _ = self.qnetwork_main(state.unsqueeze(0)) 174 | 175 | # Set training mode on naf_components for future use 176 | self.qnetwork_main.train() 177 | 178 | return action.cpu().squeeze().numpy() 179 | 180 | def learn(self, experiences: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -> None: 181 | """ 182 | Calculate the Q-Function estimate from the main neural network, the Target value from 183 | the target neural network, and calculate the loss with both values, all by feeding the received 184 | batch of experience tuples to both networks. After loss is calculated, backpropagation is performed on the 185 | main network from the given loss, so that the weights of the main network are updated. 186 | Args: 187 | experiences: Tuple of five elements, where each element is a torch.Tensor of length {batch_size}. 188 | """ 189 | # Set gradients of all optimized torch Tensors to zero 190 | self.optimizer.zero_grad() 191 | states, actions, rewards, next_states, dones = experiences 192 | 193 | # Get the Value Function for the next state from target naf_components (no_grad() disables gradient calculation) 194 | with torch.no_grad(): 195 | _, _, V_ = self.qnetwork_target(next_states) 196 | 197 | # Compute the target Value Functions for the given experiences. 198 | # The target value is calculated as target_val = r + gamma * V(s') 199 | target_values = rewards + (self.gamma * V_) 200 | 201 | # Compute the expected Value Function from main network 202 | _, q_estimate, _ = self.qnetwork_main(states, actions) 203 | 204 | # Compute loss between target value and expected Q value 205 | loss = F.mse_loss(q_estimate, target_values) 206 | 207 | # Perform backpropagation for minimizing loss 208 | loss.backward() 209 | clip_grad_norm_(self.qnetwork_main.parameters(), 1) 210 | self.optimizer.step() 211 | 212 | # Update the target network softly with the local one 213 | self.soft_update(self.qnetwork_main, self.qnetwork_target) 214 | 215 | # return loss.detach().cpu().numpy() 216 | 217 | def soft_update(self, main_nn: NAF, target_nn: NAF) -> None: 218 | """ 219 | Soft update naf_components parameters following this formula:\n 220 | θ_target = τ*θ_local + (1 - τ)*θ_target 221 | Args: 222 | main_nn: Main torch neural network. 223 | target_nn: Target torch neural network. 224 | """ 225 | for target_param, main_param in zip(target_nn.parameters(), main_nn.parameters()): 226 | target_param.data.copy_(self.tau * main_param.data + (1. - self.tau) * target_param.data) 227 | 228 | def run(self, frames: int = 1000, episodes: int = 1000, verbose: bool = True) -> Dict[int, Tuple[float, int]]: 229 | """ 230 | Execute training flow of the NAF algorithm on the given environment. 231 | Args: 232 | frames: Number of maximum frames or timesteps per episode. 233 | episodes: Number of episodes required to terminate the training. 234 | verbose: Boolean indicating whether many or few logs are shown. 235 | Returns: 236 | Returns the score history generated along the training. 
237 | """ 238 | logger.info('Training started') 239 | # Initialize 'scores' dictionary to store rewards and timesteps executed for each episode 240 | scores = {episode: (0, 0) for episode in range(1, episodes + 1)} 241 | 242 | # Iterate through every episode 243 | for episode in range(episodes): 244 | logger.info(f'Running Episode {episode + 1}') 245 | start = time.time() # Timer to measure execution time per episode 246 | state = self.environment.reset(verbose) 247 | score, mean = 0, list() 248 | 249 | for frame in range(1, frames + 1): 250 | if verbose: logger.info(f'Running frame {frame} in episode {episode + 1}') 251 | 252 | # Pick action according to current state 253 | if verbose: logger.info(f'Current State: {state}') 254 | action = self.act(state) 255 | if verbose: logger.info(f'Action chosen for the given state is: {action}') 256 | 257 | # Perform action on environment and get new state and reward 258 | next_state, reward, done = self.environment.step(action) 259 | 260 | # Save the experience in the ReplayBuffer, and learn from previous experiences if applicable 261 | self.step(state, action, reward, next_state, done) 262 | 263 | state = next_state # Update state to next state 264 | score += reward 265 | mean.append(reward) 266 | 267 | if verbose: logger.info(f'Reward: {reward} - Cumulative reward: {score}\n') 268 | 269 | if done: 270 | break 271 | 272 | # Updates scores history 273 | scores[episode + 1] = (score, frame) # save most recent score and last frame 274 | logger.info(f'Reward: {score}') 275 | logger.info(f'Number of frames: {frame}') 276 | logger.info(f'Mean of rewards on this episode: {sum(mean) / frames}') 277 | logger.info(f'Time taken for this episode: {round(time.time() - start, 3)} secs\n') 278 | 279 | # Save the episode's performance if it is a checkpoint episode 280 | if (episode + 1) % self.checkpoint_frequency == 0: 281 | # Create parent directory for current episode 282 | os.makedirs(f'checkpoints/{episode + 1}/', exist_ok=True) 283 | # Save naf_components weights 284 | torch.save(self.qnetwork_main.state_dict(), f'checkpoints/{episode + 1}/weights.p') 285 | # Save naf_components's performance metrics 286 | with open(f'checkpoints/{episode + 1}/scores.txt', 'w') as f: 287 | f.write(json.dumps(scores)) 288 | 289 | torch.save(self.qnetwork_main.state_dict(), self.MODEL_PATH) 290 | logger.info(f'Model has been successfully saved in {self.MODEL_PATH}') 291 | 292 | return scores 293 | -------------------------------------------------------------------------------- /robotic_manipulator_rloa/naf_components/naf_neural_network.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Any 2 | 3 | import torch 4 | from torch import nn 5 | from torch.distributions import MultivariateNormal 6 | 7 | 8 | class NAF(nn.Module): 9 | 10 | def __init__(self, state_size: int, action_size: int, layer_size: int, seed: int, device: torch.device) -> None: 11 | """ 12 | Model to be used in the NAF algorithm. Network Architecture:\n 13 | - Common network\n 14 | - Linear + BatchNormalization (input_shape, layer_size)\n 15 | - Linear + BatchNormalization (layer_size, layer_size)\n 16 | 17 | - Output for mu network (used for calculating A)\n 18 | - Linear (layer_size, action_size)\n 19 | 20 | - Output for V network (used for calculating Q = A + V)\n 21 | - Linear (layer_size, 1)\n 22 | 23 | - Output for L network (used for calculating P = L . 
Lt)\n 24 | - Linear (layer_size, (action_size*action_size+1)/2)\n 25 | Args: 26 | state_size: Dimension of a state. 27 | action_size: Dimension of an action. 28 | layer_size: Size of the hidden layers of the neural network. 29 | seed: Random seed. 30 | device: CUDA device. 31 | """ 32 | super(NAF, self).__init__() 33 | self.seed = torch.manual_seed(seed) 34 | self.state_size = state_size 35 | self.action_size = action_size 36 | self.device = device 37 | 38 | # DEFINE THE MODEL 39 | 40 | # Define the first NN hidden layer + BatchNormalization 41 | self.input_layer = nn.Linear(in_features=self.state_size, out_features=layer_size) 42 | self.bn1 = nn.BatchNorm1d(layer_size) 43 | 44 | # Define the second NN hidden layer + BatchNormalization 45 | self.hidden_layer = nn.Linear(in_features=layer_size, out_features=layer_size) 46 | self.bn2 = nn.BatchNorm1d(layer_size) 47 | 48 | # Define the output layer for the mu Network 49 | self.action_values = nn.Linear(in_features=layer_size, out_features=action_size) 50 | # Define the output layer for the V Network 51 | self.value = nn.Linear(in_features=layer_size, out_features=1) 52 | # Define the output layer for the L Network 53 | self.matrix_entries = nn.Linear(in_features=layer_size, 54 | out_features=int(self.action_size * (self.action_size + 1) / 2)) 55 | 56 | def forward(self, 57 | input_: torch.Tensor, 58 | action: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[Any], Any]: 59 | """ 60 | Forward propagation. 61 | It feeds the NN with the input, and gets the output for the mu, V and L networks.\n 62 | - Output from the L network is used to create the P matrix.\n 63 | - Output from the V network is used to calculate the Q value: Q = A + V\n 64 | - Output from the mu network is used to calculate A. The action output of mu nn is considered 65 | the action that maximizes Q-function. 66 | Args: 67 | input_: Input for the neural network's input layer. 68 | action: Current action, used for calculating the Q-Function estimate. 69 | Returns: 70 | Returns a tuple containing the action which maximizes the Q-Function, the 71 | Q-Function estimate and the Value Function. 
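        Shapes (illustrative, for a batch of B input states):
            input_: (B, state_size) -> action: (B, action_size), Q: (B, 1) or None, V: (B, 1)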
72 | """ 73 | # ============ FEED INPUT DATA TO THE NEURAL NETWORK ================================= 74 | 75 | # Feed the input to the INPUT_LAYER and apply ReLu activation function (+ BatchNorm) 76 | x = torch.relu(self.bn1(self.input_layer(input_))) 77 | # Feed the output of INPUT_LAYER to the HIDDEN_LAYER layer and apply ReLu activation function (+ BatchNorm) 78 | x = torch.relu(self.bn2(self.hidden_layer(x))) 79 | 80 | # Feed the output of HIDDEN_LAYER to the mu layer and apply tanh activation function 81 | action_value = torch.tanh(self.action_values(x)) 82 | 83 | # Feed the output of HIDDEN_LAYER to the L layer and apply tanh activation function 84 | matrix_entries = torch.tanh(self.matrix_entries(x)) 85 | 86 | # Feed the output of HIDDEN_LAYER to the V layer 87 | V = self.value(x) 88 | 89 | # Modifies the output of the mu layer by unsqueezing it (all tensor as a 1D vector) 90 | action_value = action_value.unsqueeze(-1) 91 | 92 | # ============ CREATE L MATRIX from the outputs of the L layer ======================= 93 | 94 | # Create lower-triangular matrix, size: (n_samples, action_size, action_size) 95 | L = torch.zeros((input_.shape[0], self.action_size, self.action_size)).to(self.device) 96 | # Get lower triagular indices (returns list of 2 elems, where the first row contains row coordinates 97 | # of all indices and the second row contains column coordinates) 98 | lower_tri_indices = torch.tril_indices(row=self.action_size, col=self.action_size, offset=0) 99 | # Fill matrix with the outputs of the L layer 100 | L[:, lower_tri_indices[0], lower_tri_indices[1]] = matrix_entries 101 | # Raise the diagonal elements of the matrix to the square 102 | L.diagonal(dim1=1, dim2=2).exp_() 103 | # Calculate state-dependent, positive-definite square matrix P 104 | P = L * L.transpose(2, 1) 105 | 106 | # ============================ CALCULATE Q-VALUE ===================================== # 107 | 108 | Q = None 109 | if action is not None: 110 | # Calculate Advantage Function estimate 111 | A = (-0.5 * torch.matmul(torch.matmul((action.unsqueeze(-1) - action_value).transpose(2, 1), P), 112 | (action.unsqueeze(-1) - action_value))).squeeze(-1) 113 | 114 | # Calculate Q-values 115 | Q = A + V 116 | 117 | # =========================== ADD NOISE TO ACTION ==================================== # 118 | 119 | dist = MultivariateNormal(action_value.squeeze(-1), torch.inverse(P)) 120 | action = dist.sample() 121 | action = torch.clamp(action, min=-1, max=1) 122 | 123 | return action, Q, V 124 | -------------------------------------------------------------------------------- /robotic_manipulator_rloa/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavierMtz5/robotic_manipulator_rloa/263aab50f70deeb48b8f0bc28041c86b76453816/robotic_manipulator_rloa/utils/__init__.py -------------------------------------------------------------------------------- /robotic_manipulator_rloa/utils/collision_detector.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | 4 | import numpy as np 5 | import pybullet as p 6 | from numpy.typing import NDArray 7 | 8 | 9 | @dataclass 10 | class CollisionObject: 11 | """ 12 | Dataclass which contains the UID of the manipulator/body and the link number of the joint from which to 13 | calculate distances to other bodies. 
14 | """ 15 | body: str 16 | link: int 17 | 18 | 19 | class CollisionDetector: 20 | 21 | def __init__(self, collision_object: CollisionObject, obstacle_ids: List[str]): 22 | """ 23 | Calculates distances between bodies' joints. 24 | Args: 25 | collision_object: CollisionObject instance, which indicates the body/joint from 26 | which to calculate distances/collisions. 27 | obstacle_ids: Obstacle body UID. Distances are calculated from the joint/body given in the 28 | "collision_object" parameter to the "obstacle_ids" bodies. 29 | """ 30 | self.obstacles = obstacle_ids 31 | self.collision_object = collision_object 32 | 33 | def compute_distances(self, max_distance: float = 10.0) -> NDArray: 34 | """ 35 | Compute the closest distances from the joint given by the CollisionObject instance in self.collision_object 36 | to the bodies defined in self.obstacles. 37 | Args: 38 | max_distance: Bodies farther apart than this distance are not queried by PyBullet, the return value 39 | for the distance between such bodies will be max_distance. 40 | Returns: 41 | A numpy array of distances, one per pair of collision objects. 42 | """ 43 | distances = list() 44 | for obstacle in self.obstacles: 45 | 46 | # Compute the shortest distances between the collision-object and the given obstacle 47 | closest_points = p.getClosestPoints( 48 | self.collision_object.body, 49 | obstacle, 50 | distance=max_distance, 51 | linkIndexA=self.collision_object.link 52 | ) 53 | 54 | # If bodies are above max_distance apart, nothing is returned, so 55 | # we just saturate at max_distance. Otherwise, take the minimum 56 | if len(closest_points) == 0: 57 | distances.append(max_distance) 58 | else: 59 | distances.append(np.min([point[8] for point in closest_points])) 60 | 61 | return np.array(distances) 62 | 63 | def compute_collisions_in_manipulator(self, affected_joints: List[int], max_distance: float = 10.) -> NDArray: 64 | """ 65 | Compute collisions between manipulator's parts. 66 | Args: 67 | affected_joints: Joints to consider when calculating distances. 68 | max_distance: Maximum distance to be considered. Distances further than this will be ignored, and 69 | the "max_distance" value will be returned. 70 | Returns: 71 | Array where each element corresponds to the distances from a given joint to the other joints. 72 | """ 73 | distances = list() 74 | for joint_ind in affected_joints: 75 | 76 | # Collisions with the previous and next joints are omitted, as they will be always in contact 77 | if (self.collision_object.link == joint_ind) or \ 78 | (joint_ind == self.collision_object.link - 1) or \ 79 | (joint_ind == self.collision_object.link + 1): 80 | continue # pragma: no cover 81 | 82 | # Compute the shortest distances between all object pairs 83 | closest_points = p.getClosestPoints( 84 | self.collision_object.body, 85 | self.collision_object.body, 86 | distance=max_distance, 87 | linkIndexA=self.collision_object.link, 88 | linkIndexB=joint_ind 89 | ) 90 | 91 | # If bodies are above max_distance apart, nothing is returned, so 92 | # we just saturate at max_distance. 
Otherwise, take the minimum 93 | if len(closest_points) == 0: 94 | distances.append(max_distance) 95 | else: 96 | distances.append(np.min([point[8] for point in closest_points])) 97 | 98 | return np.array(distances) 99 | -------------------------------------------------------------------------------- /robotic_manipulator_rloa/utils/exceptions.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Optional 4 | 5 | 6 | class FrameworkException(Exception): 7 | 8 | def __init__(self, message: str) -> None: 9 | """ 10 | Creates new FrameworkException. Base class that must be extended by any custom exception. 11 | Args: 12 | message: Info about the exception. 13 | """ 14 | Exception.__init__(self, message) 15 | self.message = message 16 | 17 | def __str__(self) -> str: 18 | """ Returns the string representation of the object """ 19 | return self.__class__.__name__ + ': ' + self.message 20 | 21 | def set_message(self, value: str) -> FrameworkException: 22 | """ 23 | Set the message to be printed on terminal. 24 | Args: 25 | value: Message with info about exception. 26 | Returns: 27 | FrameworkException 28 | """ 29 | self.message = value 30 | return self 31 | 32 | 33 | class InvalidManipulatorFile(FrameworkException): 34 | """ 35 | Exception raised when the URDF/SDF file received cannot be loaded with 36 | Pybullet's loadURDF/loadSDF methods. 37 | """ 38 | message = 'The URDF/SDF file received is not valid' 39 | 40 | def __init__(self, message: Optional[str] = None) -> None: 41 | if message: 42 | self.message = message 43 | FrameworkException.__init__(self, self.message) 44 | 45 | 46 | class InvalidHyperParameter(FrameworkException): 47 | """ 48 | Exception raised when the user tries to set an invalid value on a hyper-parameter 49 | """ 50 | message = 'The hyperparameter received is not valid' 51 | 52 | def __init__(self, message: Optional[str] = None) -> None: 53 | if message: 54 | self.message = message 55 | FrameworkException.__init__(self, self.message) 56 | 57 | 58 | class InvalidEnvironmentParameter(FrameworkException): 59 | """ 60 | Exception raised when the Environment is initialized with invalid parameter/parameters. 61 | """ 62 | message = 'The Environment parameter received is not valid' 63 | 64 | def __init__(self, message: Optional[str] = None) -> None: 65 | if message: 66 | self.message = message 67 | FrameworkException.__init__(self, self.message) 68 | 69 | 70 | class InvalidNAFAgentParameter(FrameworkException): 71 | """ 72 | Exception raised when the NAFAgent is initialized with invalid parameter/parameters. 73 | """ 74 | message = 'The NAF Agent parameter received is not valid' 75 | 76 | def __init__(self, message: Optional[str] = None) -> None: 77 | if message: 78 | self.message = message 79 | FrameworkException.__init__(self, self.message) 80 | 81 | 82 | class EnvironmentNotInitialized(FrameworkException): 83 | """ 84 | Exception raised when the Environment has not yet been initialized and the user tries to 85 | call a method which requires the Environment to be initialized. 86 | """ 87 | message = 'The Environment is not yet initialized. 
The environment can be initialized via the ' \ 88 | 'initialize_environment() method' 89 | 90 | def __init__(self, message: Optional[str] = None) -> None: 91 | if message: 92 | self.message = message 93 | FrameworkException.__init__(self, self.message) 94 | 95 | 96 | class NAFAgentNotInitialized(FrameworkException): 97 | """ 98 | Exception raised when the NAFAgent has not yet been initialized and the user tries to 99 | call a method which requires the NAFAgent to be initialized. 100 | """ 101 | message = 'The NAF Agent is not yet initialized. The agent can be initialized via the ' \ 102 | 'initialize_naf_agent() method' 103 | 104 | def __init__(self, message: Optional[str] = None) -> None: 105 | if message: 106 | self.message = message 107 | FrameworkException.__init__(self, self.message) 108 | 109 | 110 | class MissingWeightsFile(FrameworkException): 111 | """ 112 | Exception raised when the user loads pretrained weights from an invalid location. 113 | """ 114 | message = 'The weight file provided does not exist' 115 | 116 | def __init__(self, message: Optional[str] = None) -> None: 117 | if message: 118 | self.message = message 119 | FrameworkException.__init__(self, self.message) 120 | 121 | 122 | class ConfigurationIncomplete(FrameworkException): 123 | """ 124 | Exception raised when either the Environment, the NAFAgent or both have not been initialized 125 | yet, and the user tries to execute a training by calling the run_training() method. 126 | """ 127 | message = 'The configuration for the training is incomplete. Either the Environment, the ' \ 128 | 'NAF Agent or both are not yet initialized. The environment can be initialized via the ' \ 129 | 'initialize_environment() method, and the agent can be initialized via the ' \ 130 | 'initialize_naf_agent() method' 131 | 132 | def __init__(self, message: Optional[str] = None) -> None: 133 | if message: 134 | self.message = message 135 | FrameworkException.__init__(self, self.message) 136 | -------------------------------------------------------------------------------- /robotic_manipulator_rloa/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from datetime import datetime 3 | from logging import LogRecord 4 | from logging.config import dictConfig 5 | from logging.handlers import RotatingFileHandler 6 | 7 | 8 | class Logger: # pylint: disable=too-many-instance-attributes 9 | """ Stores and processes the logs """ 10 | 11 | @staticmethod 12 | def set_logger_setup() -> None: 13 | """ 14 | Sets the logger setup with a predefined configuration. 15 | """ 16 | 17 | log_config_dict = Logger.generate_logging_config_dict() 18 | dictConfig(log_config_dict) 19 | rotating_file_handler = RotatingFileHandler(filename='training_logs.log', mode='a', maxBytes=50000000, 20 | backupCount=10, encoding='utf-8') 21 | rotating_file_handler.setFormatter(logging.Formatter( 22 | '"%(levelname)s"|"{datetime}"|%(message)s'.format(datetime=datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%S.%fZ')) 23 | )) 24 | rotating_file_handler.setLevel(logging.INFO) 25 | logger = get_global_logger() 26 | logger.addHandler(rotating_file_handler) 27 | logger.setLevel(20) 28 | 29 | @staticmethod 30 | def generate_logging_config_dict() -> dict: 31 | """ 32 | Generates the configuration dictionary that is used to configure the logger. 33 | Returns: 34 | Configuration dictionary. 
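        Example:
            A minimal sketch of how the dictionary is consumed (this mirrors what
            set_logger_setup() does):

                dictConfig(Logger.generate_logging_config_dict())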
35 | """ 36 | return { 37 | 'version': 1, 38 | 'disable_existing_loggers': False, 39 | 'formatters': { 40 | 'custom_formatter': { 41 | '()': CustomFormatter, 42 | 'dateformat': '%Y-%m-%dT%H:%M:%S.%06d%z' 43 | }, 44 | }, 45 | 'handlers': { 46 | 'debug_console_handler': { 47 | 'level': 'NOTSET', 48 | 'formatter': 'custom_formatter', 49 | 'class': 'logging.StreamHandler', 50 | 'stream': 'ext://sys.stdout', 51 | } 52 | }, 53 | 'loggers': { 54 | '': { 55 | 'handlers': ['debug_console_handler'], 56 | 'level': 'NOTSET', 57 | }, 58 | } 59 | } 60 | 61 | 62 | def get_global_logger() -> logging.Logger: 63 | """ 64 | Getter for the logger. 65 | Returns: 66 | Logger instance to be used on the framework. 67 | """ 68 | return logging.getLogger(__name__) 69 | 70 | 71 | class CustomFormatter(logging.Formatter): 72 | 73 | def __init__(self, dateformat: str = None): 74 | """ 75 | CustomFormatter for the logger. 76 | """ 77 | super().__init__() 78 | self.dateformat = dateformat 79 | 80 | def format(self, record: LogRecord) -> str: 81 | """ 82 | Formats the provided LogRecord instance. 83 | Returns: 84 | Formatted LogRecord as string. 85 | """ 86 | # Set format and colors 87 | grey = "\033[38;20m" 88 | green = "\033[32;20m" 89 | yellow = "\033[33;20m" 90 | red = "\033[31;20m" 91 | bold_red = "\033[31;1m" 92 | reset = "\033[0m" 93 | format_ = '[%(levelname)-8s] - {datetime} - %(message)s'.format( 94 | datetime=datetime.now().astimezone().strftime('%Y-%m-%dT%H:%M:%S.%f%z') 95 | ) 96 | 97 | self.FORMATS = { 98 | logging.DEBUG: green + format_ + reset, 99 | logging.INFO: grey + format_ + reset, 100 | logging.WARNING: yellow + format_ + reset, 101 | logging.ERROR: red + format_ + reset, 102 | logging.CRITICAL: bold_red + format_ + reset 103 | } 104 | 105 | log_format = self.FORMATS.get(record.levelno) 106 | 107 | formatter = logging.Formatter(log_format, datefmt=self.dateformat) 108 | return formatter.format(record) 109 | -------------------------------------------------------------------------------- /robotic_manipulator_rloa/utils/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import deque, namedtuple 3 | from typing import Tuple 4 | 5 | import numpy as np 6 | import torch 7 | from numpy.typing import NDArray 8 | 9 | 10 | # deque is a Doubly Ended Queue, provides O(1) complexity for pop and append actions 11 | # namedtuple is a tuple that can be accessed by both its index and attributes 12 | 13 | 14 | class ReplayBuffer: 15 | 16 | def __init__(self, buffer_size: int, batch_size: int, device: torch.device, seed: int): 17 | """ 18 | Buffer to store experience tuples. Each experience has the following structure: 19 | (state, action, reward, next_state, done) 20 | Args: 21 | buffer_size: Maximum size for the buffer. Higher buffer size imply higher RAM consumption. 22 | batch_size: Number of experiences to be retrieved from the ReplayBuffer per batch. 23 | device: CUDA device. 24 | seed: Random seed. 25 | """ 26 | self.device = device 27 | self.memory = deque(maxlen=buffer_size) 28 | self.batch_size = batch_size 29 | self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"]) 30 | random.seed(seed) 31 | 32 | def add(self, state: NDArray, action: NDArray, reward: float, next_state: NDArray, done: int) -> None: 33 | """ 34 | Add a new experience to the Replay Buffer. 35 | Args: 36 | state: NDArray of the current state. 37 | action: NDArray of the action taken from state {state}. 
38 | reward: Reward obtained after performing action {action} from state {state}. 39 | next_state: NDArray of the state reached after performing action {action} from state {state}. 40 | done: Integer (0 or 1) indicating whether the next_state is a terminal state. 41 | """ 42 | # Create namedtuple object from the experience 43 | exp = self.experience(state, action, reward, next_state, done) 44 | # Add the experience object to memory 45 | self.memory.append(exp) 46 | 47 | def sample(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 48 | """ 49 | Randomly sample a batch of experiences from memory. 50 | Returns: 51 | Tuple of 5 elements, which are (states, actions, rewards, next_states, dones). Each element 52 | in the tuple is a torch Tensor composed of {batch_size} items. 53 | """ 54 | # Randomly sample a batch of experiences 55 | experiences = random.sample(self.memory, k=self.batch_size) 56 | 57 | states = torch.from_numpy( 58 | np.stack([e.state if not isinstance(e.state, tuple) else e.state[0] for e in experiences])).float().to( 59 | self.device) 60 | actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(self.device) 61 | rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(self.device) 62 | next_states = torch.from_numpy(np.stack([e.next_state for e in experiences if e is not None])).float().to( 63 | self.device) 64 | dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to( 65 | self.device) 66 | 67 | return states, actions, rewards, next_states, dones 68 | 69 | def __len__(self) -> int: 70 | """ 71 | Return the current size of the Replay Buffer 72 | Returns: 73 | Size of Replay Buffer 74 | """ 75 | return len(self.memory) 76 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavierMtz5/robotic_manipulator_rloa/263aab50f70deeb48b8f0bc28041c86b76453816/tests/__init__.py -------------------------------------------------------------------------------- /tests/robotic_manipulator_rloa/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavierMtz5/robotic_manipulator_rloa/263aab50f70deeb48b8f0bc28041c86b76453816/tests/robotic_manipulator_rloa/__init__.py -------------------------------------------------------------------------------- /tests/robotic_manipulator_rloa/environment/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavierMtz5/robotic_manipulator_rloa/263aab50f70deeb48b8f0bc28041c86b76453816/tests/robotic_manipulator_rloa/environment/__init__.py -------------------------------------------------------------------------------- /tests/robotic_manipulator_rloa/naf_components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JavierMtz5/robotic_manipulator_rloa/263aab50f70deeb48b8f0bc28041c86b76453816/tests/robotic_manipulator_rloa/naf_components/__init__.py -------------------------------------------------------------------------------- /tests/robotic_manipulator_rloa/naf_components/test_naf_algorithm.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from mock 
import MagicMock, patch, mock_open 4 | import pytest 5 | 6 | from robotic_manipulator_rloa.naf_components.naf_algorithm import NAFAgent 7 | from robotic_manipulator_rloa.utils.exceptions import MissingWeightsFile 8 | 9 | 10 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.time') 11 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.clip_grad_norm_') 12 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.F') 13 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.torch') 14 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.logger') 15 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.Environment') 16 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.os') 17 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.random') 18 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.NAF') 19 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.ReplayBuffer') 20 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.optim.Adam') 21 | def test_naf_agent(mock_optim: MagicMock, 22 | mock_replay_buffer: MagicMock, 23 | mock_naf: MagicMock, 24 | mock_random: MagicMock, 25 | mock_os: MagicMock, 26 | mock_environment: MagicMock, 27 | mock_logger: MagicMock, 28 | mock_torch: MagicMock, 29 | mock_F: MagicMock, 30 | mock_clip_grad_norm_: MagicMock, 31 | mock_time: MagicMock) -> None: 32 | """ 33 | Test for the NAFAgent class constructor 34 | """ 35 | # ================== Constructor ========================================== 36 | naf_agent = NAFAgent(environment=mock_environment, 37 | state_size=10, 38 | action_size=5, 39 | layer_size=128, 40 | batch_size=64, 41 | buffer_size=100000, 42 | learning_rate=0.001, 43 | tau=0.001, 44 | gamma=0.001, 45 | update_freq=1, 46 | num_updates=1, 47 | checkpoint_frequency=1, 48 | device='cpu', 49 | seed=0) 50 | 51 | mock_os.makedirs.return_value = None 52 | fake_naf_instance = mock_naf.return_value 53 | fake_torch_naf = fake_naf_instance.to.return_value 54 | fake_network_params = fake_torch_naf.parameters.return_value 55 | fake_optimizer = mock_optim.return_value 56 | fake_replay_buffer = mock_replay_buffer.return_value 57 | 58 | assert naf_agent.environment == mock_environment 59 | assert naf_agent.state_size == 10 60 | assert naf_agent.action_size == 5 61 | assert naf_agent.layer_size == 128 62 | assert naf_agent.buffer_size == 100000 63 | assert naf_agent.learning_rate == 0.001 64 | mock_random.seed.assert_any_call(0) 65 | assert naf_agent.device == 'cpu' 66 | assert naf_agent.tau == 0.001 67 | assert naf_agent.gamma == 0.001 68 | assert naf_agent.update_freq == 1 69 | assert naf_agent.num_updates == 1 70 | assert naf_agent.batch_size == 64 71 | assert naf_agent.checkpoint_frequency == 1 72 | assert naf_agent.qnetwork_main == fake_torch_naf 73 | assert naf_agent.qnetwork_target == fake_torch_naf 74 | mock_naf.assert_any_call(10, 5, 128, 0, 'cpu') 75 | assert naf_agent.optimizer == fake_optimizer 76 | mock_optim.assert_any_call(fake_network_params, lr=0.001) 77 | assert naf_agent.memory == fake_replay_buffer 78 | mock_replay_buffer.assert_any_call(100000, 64, 'cpu', 0) 79 | assert naf_agent.update_t_step == 0 80 | 81 | # ================== initialize_pretrained_agent_from_episode() =========== 82 | fake_torch_load = mock_torch.load.return_value 83 | mock_os.path.isfile.return_value = True 84 | naf_agent.initialize_pretrained_agent_from_episode(0) 85 | fake_torch_naf.load_state_dict.assert_any_call(fake_torch_load) 86 | 
mock_torch.load.assert_any_call('checkpoints/0/weights.p') 87 | 88 | # ================== initialize_pretrained_agent_from_episode() when file is not present 89 | mock_os.path.isfile.return_value = False 90 | with pytest.raises(MissingWeightsFile): 91 | naf_agent.initialize_pretrained_agent_from_episode(0) 92 | 93 | # ================== initialize_pretrained_agent_from_weights_file() ====== 94 | fake_torch_load = mock_torch.load.return_value 95 | mock_os.path.isfile.return_value = True 96 | naf_agent.initialize_pretrained_agent_from_weights_file('weights.p') 97 | fake_torch_naf.load_state_dict.assert_any_call(fake_torch_load) 98 | mock_torch.load.assert_any_call('weights.p') 99 | 100 | # ================== initialize_pretrained_agent_from_weights_file() when file is not present 101 | mock_os.path.isfile.return_value = False 102 | with pytest.raises(MissingWeightsFile): 103 | naf_agent.initialize_pretrained_agent_from_weights_file('weights.p') 104 | 105 | 106 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.ReplayBuffer') 107 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.Environment') 108 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.NAFAgent.learn') 109 | def test_naf_agent__step(mock_learn: MagicMock, 110 | mock_environment: MagicMock, 111 | mock_replay_buffer: MagicMock) -> None: 112 | """Test for the step() method of the NAFAgent class""" 113 | # ================== step() =============================================== 114 | naf_agent = NAFAgent(environment=mock_environment, 115 | state_size=10, 116 | action_size=5, 117 | layer_size=128, 118 | batch_size=64, 119 | buffer_size=100000, 120 | learning_rate=0.001, 121 | tau=0.001, 122 | gamma=0.001, 123 | update_freq=1, 124 | num_updates=1, 125 | checkpoint_frequency=1, 126 | device='cpu', 127 | seed=0) 128 | fake_replay_buffer = mock_replay_buffer.return_value 129 | fake_replay_buffer.__len__.return_value = 100 130 | fake_experiences = MagicMock() 131 | fake_replay_buffer.sample.return_value = fake_experiences 132 | 133 | naf_agent.step('state', 'action', 'reward', 'next_state', 'done') 134 | 135 | fake_replay_buffer.add.assert_any_call('state', 'action', 'reward', 'next_state', 'done') 136 | assert naf_agent.update_t_step == 0 137 | fake_replay_buffer.sample.assert_any_call() 138 | mock_learn.assert_any_call(fake_experiences) 139 | 140 | 141 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.Environment') 142 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.torch') 143 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.NAF') 144 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.optim.Adam') 145 | def test_naf_agent__act(mock_optimizer: MagicMock, 146 | mock_naf: MagicMock, 147 | mock_torch: MagicMock, 148 | mock_environment: MagicMock) -> None: 149 | """Test for the act() method of the NAFAgent class""" 150 | # ================== act() ================================================ 151 | naf_agent = NAFAgent(environment=mock_environment, 152 | state_size=10, 153 | action_size=5, 154 | layer_size=128, 155 | batch_size=64, 156 | buffer_size=100000, 157 | learning_rate=0.001, 158 | tau=0.001, 159 | gamma=0.001, 160 | update_freq=1, 161 | num_updates=1, 162 | checkpoint_frequency=1, 163 | device='cpu', 164 | seed=0) 165 | fake_state_from_numpy = mock_torch.from_numpy.return_value 166 | fake_state_float = fake_state_from_numpy.float.return_value 167 | fake_state_to = fake_state_float.to.return_value 168 | fake_state_unsqueezed = 
fake_state_to.unsqueeze.return_value 169 | fake_action = MagicMock() 170 | fake_naf_instance = mock_naf.return_value 171 | fake_torch_naf = fake_naf_instance.to.return_value 172 | fake_torch_naf.return_value = (fake_action, None, None) 173 | fake_action_cpu = fake_action.cpu.return_value 174 | fake_action_squeeze = fake_action_cpu.squeeze.return_value 175 | fake_action_numpy = fake_action_squeeze.numpy.return_value 176 | fake_no_grad = MagicMock(__enter__=MagicMock()) 177 | mock_torch.no_grad.return_value = fake_no_grad 178 | 179 | assert naf_agent.act('state') == fake_action_numpy 180 | 181 | fake_torch_naf.eval.assert_any_call() 182 | mock_torch.no_grad.assert_any_call() 183 | fake_torch_naf.assert_any_call(fake_state_unsqueezed) 184 | fake_torch_naf.train.assert_any_call() 185 | 186 | 187 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.Environment') 188 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.torch') 189 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.NAF') 190 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.optim.Adam') 191 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.F') 192 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.clip_grad_norm_') 193 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.NAFAgent.soft_update') 194 | def test_naf_agent__learn(mock_soft_update: MagicMock, 195 | mock_clip_grad_norm_: MagicMock, 196 | mock_F: MagicMock, 197 | mock_optimizer: MagicMock, 198 | mock_naf: MagicMock, 199 | mock_torch: MagicMock, 200 | mock_environment: MagicMock) -> None: 201 | """Test for the learn() method in the NAFAgent class""" 202 | # ================== learn() ============================================== 203 | naf_agent = NAFAgent(environment=mock_environment, 204 | state_size=10, 205 | action_size=5, 206 | layer_size=128, 207 | batch_size=64, 208 | buffer_size=100000, 209 | learning_rate=0.001, 210 | tau=0.001, 211 | gamma=0.001, 212 | update_freq=1, 213 | num_updates=1, 214 | checkpoint_frequency=1, 215 | device='cpu', 216 | seed=0) 217 | fake_states, fake_actions, fake_rewards, fake_next_states, fake_dones = \ 218 | MagicMock(), MagicMock(), 1, MagicMock(), MagicMock() 219 | fake_experiences = (fake_states, fake_actions, fake_rewards, fake_next_states, fake_dones) 220 | fake_v, fake_q_estimate = 1, 1 221 | fake_naf_instance = mock_naf.return_value 222 | fake_torch_naf = fake_naf_instance.to.return_value 223 | fake_network_params = fake_torch_naf.parameters.return_value 224 | fake_torch_naf.return_value = (None, fake_q_estimate, fake_v) 225 | fake_optimizer = mock_optimizer.return_value 226 | fake_loss = mock_F.mse_loss.return_value 227 | 228 | naf_agent.learn(fake_experiences) 229 | 230 | fake_optimizer.zero_grad.assert_any_call() 231 | fake_torch_naf.assert_any_call(fake_next_states) 232 | fake_torch_naf.assert_any_call(fake_states, fake_actions) 233 | mock_F.mse_loss.assert_any_call(fake_q_estimate, 1.001) 234 | fake_loss.backward.assert_any_call() 235 | mock_clip_grad_norm_.assert_any_call(fake_network_params, 1) 236 | fake_optimizer.step.assert_any_call() 237 | mock_soft_update.assert_any_call(fake_torch_naf, fake_torch_naf) 238 | 239 | 240 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.Environment') 241 | def test_naf_agent__soft_update(mock_environment: MagicMock) -> None: 242 | """Test for the soft_update() method for the NAFAgent class""" 243 | # ================== soft_update() ======================================== 244 | 
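# (note: soft_update() presumably applies the standard Polyak blend,
# target_params <- tau * main_params + (1 - tau) * target_params, with tau taken from
# the constructor below; one fake parameter per network is enough to drive the loop once)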
naf_agent = NAFAgent(environment=mock_environment, 245 | state_size=10, 246 | action_size=5, 247 | layer_size=128, 248 | batch_size=64, 249 | buffer_size=100000, 250 | learning_rate=0.001, 251 | tau=0.001, 252 | gamma=0.001, 253 | update_freq=1, 254 | num_updates=1, 255 | checkpoint_frequency=1, 256 | device='cpu', 257 | seed=0) 258 | fake_main_nn, fake_target_nn = MagicMock(), MagicMock() 259 | fake_param_value = MagicMock(value=1) 260 | fake_main_params, fake_target_params = MagicMock(data=fake_param_value), MagicMock(data=fake_param_value) 261 | fake_main_nn.parameters.return_value = [fake_main_params] 262 | fake_target_nn.parameters.return_value = [fake_target_params] 263 | 264 | naf_agent.soft_update(fake_main_nn, fake_target_nn) 265 | 266 | 267 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.logger') 268 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.Environment') 269 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.optim.Adam') 270 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.time') 271 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.os') 272 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.torch') 273 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.NAF') 274 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.NAFAgent.act') 275 | @patch('robotic_manipulator_rloa.naf_components.naf_algorithm.NAFAgent.step') 276 | @patch('builtins.open', new_callable=mock_open()) 277 | def test_naf_agent__run(mock_open_file: MagicMock, 278 | mock_step: MagicMock, 279 | mock_act: MagicMock, 280 | mock_naf: MagicMock, 281 | mock_torch: MagicMock, 282 | mock_os: MagicMock, 283 | mock_time: MagicMock, 284 | mock_optimizer: MagicMock, 285 | mock_environment: MagicMock, 286 | mock_logger: MagicMock) -> None: 287 | """Test for the run() method for the NAFAgent class""" 288 | # ================== run() ================================================ 289 | naf_agent = NAFAgent(environment=mock_environment, 290 | state_size=10, 291 | action_size=5, 292 | layer_size=128, 293 | batch_size=64, 294 | buffer_size=100000, 295 | learning_rate=0.001, 296 | tau=0.001, 297 | gamma=0.001, 298 | update_freq=1, 299 | num_updates=1, 300 | checkpoint_frequency=1, 301 | device='cpu', 302 | seed=0) 303 | mock_time.time.return_value = 50 304 | fake_next_state, fake_reward, fake_done = MagicMock(), 1, MagicMock() 305 | fake_state = mock_environment.reset.return_value 306 | fake_action = MagicMock() 307 | mock_act.return_value = fake_action 308 | mock_environment.step.return_value = (fake_next_state, fake_reward, fake_done) 309 | mock_step.return_value = None 310 | mock_os.makedirs.return_value = None 311 | fake_naf_instance = mock_naf.return_value 312 | fake_torch_naf = fake_naf_instance.to.return_value 313 | fake_torch_naf.return_value = (fake_action, None, None) 314 | fake_torch_naf.state_dict.return_value = 'params_dict' 315 | 316 | assert naf_agent.run(1, 1, True) == {1: (1, 1)} 317 | 318 | mock_time.time.assert_any_call() 319 | mock_environment.reset.assert_any_call(True) 320 | mock_act.assert_any_call(fake_state) 321 | mock_environment.step.assert_any_call(fake_action) 322 | mock_step.assert_any_call(fake_state, fake_action, fake_reward, fake_next_state, fake_done) 323 | mock_os.makedirs.assert_any_call('checkpoints/1/', exist_ok=True) 324 | fake_torch_naf.state_dict.assert_any_call() 325 | mock_torch.save.assert_any_call('params_dict', 'checkpoints/1/weights.p') 326 | 
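# (note: after the final episode, run() is expected to dump the per-episode scores as
# JSON into 'checkpoints/<episode>/scores.txt' and save the final weights to 'model.p',
# which is what the remaining assertions verify)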
mock_open_file.assert_called_once_with('checkpoints/1/scores.txt', 'w') 327 | mock_open_file.return_value.__enter__().write.assert_called_once_with(json.dumps({1: (1, 1)})) 328 | mock_torch.save.assert_any_call('params_dict', 'model.p') 329 | -------------------------------------------------------------------------------- /tests/robotic_manipulator_rloa/naf_components/test_naf_neural_network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from mock import MagicMock, patch 4 | 5 | from robotic_manipulator_rloa.naf_components.naf_neural_network import NAF 6 | 7 | 8 | @patch('robotic_manipulator_rloa.naf_components.naf_neural_network.torch.manual_seed') 9 | @patch('robotic_manipulator_rloa.naf_components.naf_neural_network.torch.nn.Module') 10 | @patch('robotic_manipulator_rloa.naf_components.naf_neural_network.torch.nn.Linear') 11 | @patch('robotic_manipulator_rloa.naf_components.naf_neural_network.torch.nn.BatchNorm1d') 12 | def test_naf(mock_batchnorm: MagicMock, 13 | mock_linear: MagicMock, 14 | mock_nn: MagicMock, 15 | mock_manual_seed: MagicMock) -> None: 16 | """Test for the NAF class constructor""" 17 | mock_manual_seed.return_value = 'manual_seed' 18 | mock_linear.side_effect = [ 19 | 'linear_for_input_layer', 20 | 'linear_for_hidden_layer', 21 | 'linear_for_action_values', 22 | 'linear_for_value', 23 | 'linear_for_matrix_entries' 24 | ] 25 | mock_batchnorm.side_effect = [ 26 | 'batchnorm_for_bn1', 27 | 'batchnorm_for_bn2' 28 | ] 29 | 30 | naf = NAF(state_size=10, action_size=5, layer_size=128, seed=0, device='cpu') 31 | 32 | assert naf.seed == 'manual_seed' 33 | assert naf.state_size == 10 34 | assert naf.action_size == 5 35 | assert naf.device == 'cpu' 36 | assert naf.input_layer == 'linear_for_input_layer' 37 | assert naf.bn1 == 'batchnorm_for_bn1' 38 | assert naf.hidden_layer == 'linear_for_hidden_layer' 39 | assert naf.bn2 == 'batchnorm_for_bn2' 40 | assert naf.action_values == 'linear_for_action_values' 41 | assert naf.value == 'linear_for_value' 42 | assert naf.matrix_entries == 'linear_for_matrix_entries' 43 | mock_manual_seed.assert_any_call(0) 44 | mock_linear.assert_any_call(in_features=10, out_features=128) 45 | mock_linear.assert_any_call(in_features=128, out_features=128) 46 | mock_linear.assert_any_call(in_features=128, out_features=5) 47 | mock_linear.assert_any_call(in_features=128, out_features=1) 48 | mock_linear.assert_any_call(in_features=128, out_features=15) 49 | mock_batchnorm.assert_any_call(128) 50 | mock_batchnorm.assert_any_call(128) 51 | 52 | 53 | def test_naf__forward() -> None: 54 | """Test for the forward() method of the NAF class""" 55 | device = torch.device('cpu') 56 | naf = NAF(state_size=10, action_size=5, layer_size=256, seed=0, device=device) 57 | states = torch.from_numpy(np.stack( 58 | [np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])] 59 | )).float().to(device) 60 | actions = torch.from_numpy(np.vstack( 61 | [np.array([0, 1, 2, 3, 4]), np.array([10, 11, 12, 13, 14])] 62 | )).long().to(device) 63 | 64 | action, q, v = naf(states, actions) 65 | 66 | assert q.tolist() == [[-35.50931930541992], [-638.494873046875]] 67 | assert v.tolist() == [[0.5665180683135986], [-0.08311141282320023]] 68 | -------------------------------------------------------------------------------- /tests/robotic_manipulator_rloa/utils/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/JavierMtz5/robotic_manipulator_rloa/263aab50f70deeb48b8f0bc28041c86b76453816/tests/robotic_manipulator_rloa/utils/__init__.py -------------------------------------------------------------------------------- /tests/robotic_manipulator_rloa/utils/test_collision_detector.py: -------------------------------------------------------------------------------- 1 | from mock import MagicMock, patch 2 | import pytest 3 | import numpy as np 4 | 5 | from robotic_manipulator_rloa.utils.collision_detector import CollisionDetector, CollisionObject 6 | 7 | 8 | def test_collisionobject() -> None: 9 | """Test for the CollisionObject dataclass""" 10 | collision_object = CollisionObject(body='manipulator_body', link=0) 11 | assert collision_object.body == 'manipulator_body' 12 | assert collision_object.link == 0 13 | 14 | 15 | @pytest.mark.parametrize('closest_points', [ 16 | [(None, None, None, None, None, None, None, None, 1), 17 | (None, None, None, None, None, None, None, None, 3), 18 | (None, None, None, None, None, None, None, None, 5)], 19 | [] 20 | ]) 21 | @patch('robotic_manipulator_rloa.utils.collision_detector.p') 22 | def test_collision_detector(mock_pybullet: MagicMock, 23 | closest_points: list) -> None: 24 | """Test for the CollisionDetector class""" 25 | collision_object = CollisionObject(body='manipulator_body', link=0) 26 | collision_detector = CollisionDetector(collision_object=collision_object, 27 | obstacle_ids=['obstacle']) 28 | 29 | assert collision_detector.obstacles == ['obstacle'] 30 | assert collision_detector.collision_object == collision_object 31 | 32 | # ================== TEST FOR compute_distances() method ================== 33 | 34 | mock_pybullet.getClosestPoints.return_value = closest_points 35 | output = np.array([10.0]) if len(closest_points) == 0 else np.array([1]) 36 | 37 | assert collision_detector.compute_distances() == output 38 | 39 | mock_pybullet.getClosestPoints.assert_any_call(collision_object.body, 40 | 'obstacle', 41 | distance=10.0, 42 | linkIndexA=collision_object.link) 43 | 44 | # ================== TEST FOR compute_collisions_in_manipulator() method ================== 45 | 46 | mock_pybullet.getClosestPoints.return_value = closest_points 47 | output = np.array([10.0]) if len(closest_points) == 0 else np.array([1]) 48 | 49 | assert collision_detector.compute_collisions_in_manipulator([0, 3]) == output 50 | 51 | mock_pybullet.getClosestPoints.assert_any_call(collision_object.body, 52 | collision_object.body, 53 | distance=10.0, 54 | linkIndexA=collision_object.link, 55 | linkIndexB=3) 56 | -------------------------------------------------------------------------------- /tests/robotic_manipulator_rloa/utils/test_exceptions.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from mock import MagicMock, patch 3 | from typing import Union 4 | 5 | from robotic_manipulator_rloa.utils.exceptions import ( 6 | FrameworkException, 7 | InvalidManipulatorFile, 8 | InvalidHyperParameter, 9 | InvalidEnvironmentParameter, 10 | InvalidNAFAgentParameter, 11 | EnvironmentNotInitialized, 12 | NAFAgentNotInitialized, 13 | MissingWeightsFile, 14 | ConfigurationIncomplete 15 | ) 16 | 17 | 18 | def test_framework_exception() -> None: 19 | """Test for the base exception class FrameworkException""" 20 | exception = FrameworkException(message='test_exception') 21 | assert exception.message == 'test_exception' 22 | assert str(exception) == 'FrameworkException: test_exception' 23 | 
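# set_message() overwrites the message stored on the base exception in place: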
exception.set_message('test msg') 24 | assert exception.message == 'test msg' 25 | 26 | 27 | @pytest.mark.parametrize('msg', [None, 'error_msg']) 28 | def test_invalid_manipulator_file(msg: Union[str, None]) -> None: 29 | """Test for the InvalidManipulatorFile exception class""" 30 | if msg: 31 | exception = InvalidManipulatorFile(msg) 32 | assert exception.message == msg 33 | else: 34 | exception = InvalidManipulatorFile() 35 | assert exception.message == 'The URDF/SDF file received is not valid' 36 | 37 | 38 | @pytest.mark.parametrize('msg', [None, 'error_msg']) 39 | def test_invalid_hyperparameter(msg: Union[str, None]) -> None: 40 | """Test for the InvalidHyperParameter exception class""" 41 | if msg: 42 | exception = InvalidHyperParameter(msg) 43 | assert exception.message == msg 44 | else: 45 | exception = InvalidHyperParameter() 46 | assert exception.message == 'The hyperparameter received is not valid' 47 | 48 | 49 | @pytest.mark.parametrize('msg', [None, 'error_msg']) 50 | def test_invalid_environment_parameter(msg: Union[str, None]) -> None: 51 | """Test for the InvalidEnvironmentParameter exception class""" 52 | if msg: 53 | exception = InvalidEnvironmentParameter(msg) 54 | assert exception.message == msg 55 | else: 56 | exception = InvalidEnvironmentParameter() 57 | assert exception.message == 'The Environment parameter received is not valid' 58 | 59 | 60 | @pytest.mark.parametrize('msg', [None, 'error_msg']) 61 | def test_invalid_nafagent_parameter(msg: Union[str, None]) -> None: 62 | """Test for the InvalidNAFAgentParameter exception class""" 63 | if msg: 64 | exception = InvalidNAFAgentParameter(msg) 65 | assert exception.message == msg 66 | else: 67 | exception = InvalidNAFAgentParameter() 68 | assert exception.message == 'The NAF Agent parameter received is not valid' 69 | 70 | 71 | @pytest.mark.parametrize('msg', [None, 'error_msg']) 72 | def test_environment_not_initialized(msg: Union[str, None]) -> None: 73 | """Test for the EnvironmentNotInitialized exception class""" 74 | if msg: 75 | exception = EnvironmentNotInitialized(msg) 76 | assert exception.message == msg 77 | else: 78 | exception = EnvironmentNotInitialized() 79 | assert exception.message == 'The Environment is not yet initialized. The environment can be initialized ' \ 80 | 'via the initialize_environment() method' 81 | 82 | 83 | @pytest.mark.parametrize('msg', [None, 'error_msg']) 84 | def test_nafagent_not_initialized(msg: Union[str, None]) -> None: 85 | """Test for the NAFAgentNotInitialized exception class""" 86 | if msg: 87 | exception = NAFAgentNotInitialized(msg) 88 | assert exception.message == msg 89 | else: 90 | exception = NAFAgentNotInitialized() 91 | assert exception.message == 'The NAF Agent is not yet initialized. 
The agent can be initialized via the ' \ 92 | 'initialize_naf_agent() method' 93 | 94 | 95 | @pytest.mark.parametrize('msg', [None, 'error_msg']) 96 | def test_missing_weights_file(msg: Union[str, None]) -> None: 97 | """Test for the MissingWeightsFile exception class""" 98 | if msg: 99 | exception = MissingWeightsFile(msg) 100 | assert exception.message == msg 101 | else: 102 | exception = MissingWeightsFile() 103 | assert exception.message == 'The weight file provided does not exist' 104 | 105 | 106 | @pytest.mark.parametrize('msg', [None, 'error_msg']) 107 | def test_configuration_incomplete(msg: Union[str, None]) -> None: 108 | """Test for the ConfigurationIncomplete exception class""" 109 | if msg: 110 | exception = ConfigurationIncomplete(msg) 111 | assert exception.message == msg 112 | else: 113 | exception = ConfigurationIncomplete() 114 | assert exception.message == 'The configuration for the training is incomplete. Either the Environment, ' \ 115 | 'the NAF Agent or both are not yet initialized. The environment can be initialized ' \ 116 | 'via the initialize_environment() method, and the agent can be initialized via ' \ 117 | 'the initialize_naf_agent() method' 118 | -------------------------------------------------------------------------------- /tests/robotic_manipulator_rloa/utils/test_logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pytest 4 | from mock import MagicMock, patch 5 | 6 | from robotic_manipulator_rloa.utils.logger import Logger, get_global_logger, CustomFormatter 7 | 8 | 9 | @patch('robotic_manipulator_rloa.utils.logger.Logger.generate_logging_config_dict') 10 | @patch('robotic_manipulator_rloa.utils.logger.dictConfig') 11 | @patch('robotic_manipulator_rloa.utils.logger.RotatingFileHandler') 12 | @patch('robotic_manipulator_rloa.utils.logger.logging.Formatter') 13 | @patch('robotic_manipulator_rloa.utils.logger.get_global_logger') 14 | @patch('robotic_manipulator_rloa.utils.logger.datetime') 15 | def test_logger(mock_datetime: MagicMock, 16 | mock_get_global_logger: MagicMock, 17 | mock_formatter: MagicMock, 18 | mock_rotating_file_handler: MagicMock, 19 | mock_dictConfig: MagicMock, 20 | mock_generate_config_dict: MagicMock) -> None: 21 | """Test for the Logger class""" 22 | fake_utcnow, fake_strftime = MagicMock(), MagicMock() 23 | mock_datetime.utcnow.return_value = fake_utcnow 24 | fake_utcnow.strftime.return_value = 'datetime' 25 | fake_config_dict = mock_generate_config_dict.return_value 26 | mock_dictConfig.return_value = None 27 | fake_rotating_file_handler, fake_formatter = MagicMock(), MagicMock() 28 | mock_rotating_file_handler.return_value = fake_rotating_file_handler 29 | mock_formatter.return_value = fake_formatter 30 | fake_rotating_file_handler.setFormatter.return_value = None 31 | fake_rotating_file_handler.setLevel.return_value = None 32 | fake_logger = MagicMock() 33 | mock_get_global_logger.return_value = fake_logger 34 | fake_logger.addHandler.return_value = None 35 | fake_logger.setLevel.return_value = None 36 | 37 | Logger.set_logger_setup() 38 | 39 | mock_generate_config_dict.assert_any_call() 40 | mock_dictConfig.assert_any_call(fake_config_dict) 41 | mock_rotating_file_handler.assert_any_call(filename='training_logs.log', mode='a', maxBytes=50000000, 42 | backupCount=10, encoding='utf-8') 43 | mock_formatter.assert_any_call('"%(levelname)s"|"datetime"|%(message)s') 44 | fake_rotating_file_handler.setFormatter.assert_any_call(fake_formatter) 45 | 
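# (note: the file handler is pinned to INFO here; the logger-level assertion further
# down checks the numeric value 20, which is what logging.INFO resolves to)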
fake_rotating_file_handler.setLevel.assert_any_call(logging.INFO) 46 | mock_get_global_logger.assert_any_call() 47 | fake_logger.addHandler.assert_any_call(fake_rotating_file_handler) 48 | fake_logger.setLevel.assert_any_call(20) 49 | 50 | 51 | @patch('robotic_manipulator_rloa.utils.logger.CustomFormatter') 52 | def test_logger__generate_logging_config_dict(mock_custom_formatter: MagicMock) -> None: 53 | """Test for the generate_logging_config_dict() method of the Logger class""" 54 | output = { 55 | 'version': 1, 56 | 'disable_existing_loggers': False, 57 | 'formatters': { 58 | 'custom_formatter': { 59 | '()': mock_custom_formatter, 60 | 'dateformat': '%Y-%m-%dT%H:%M:%S.%06d%z' 61 | }, 62 | }, 63 | 'handlers': { 64 | 'debug_console_handler': { 65 | 'level': 'NOTSET', 66 | 'formatter': 'custom_formatter', 67 | 'class': 'logging.StreamHandler', 68 | 'stream': 'ext://sys.stdout', 69 | } 70 | }, 71 | 'loggers': { 72 | '': { 73 | 'handlers': ['debug_console_handler'], 74 | 'level': 'NOTSET', 75 | }, 76 | } 77 | } 78 | assert Logger.generate_logging_config_dict() == output 79 | 80 | 81 | @patch('robotic_manipulator_rloa.utils.logger.logging.getLogger') 82 | def test_get_global_logger(mock_get_logger: MagicMock) -> None: 83 | """Test for the get_global_logger() method""" 84 | mock_get_logger.return_value = 'logger' 85 | assert get_global_logger() == 'logger' 86 | mock_get_logger.assert_any_call('robotic_manipulator_rloa.utils.logger') 87 | 88 | 89 | @patch('robotic_manipulator_rloa.utils.logger.logging.Formatter') 90 | @patch('robotic_manipulator_rloa.utils.logger.datetime') 91 | def test_customformatter(mock_datetime: MagicMock, 92 | mock_formatter: MagicMock) -> None: 93 | """Test for the CustomFormatter class""" 94 | custom_formatter = CustomFormatter(dateformat='dateformat') 95 | assert custom_formatter.dateformat == 'dateformat' 96 | 97 | # ================== TEST FOR format() method ============================= 98 | 99 | fake_formatter, fake_now, fake_astimezone = MagicMock(), MagicMock(), MagicMock() 100 | fake_input_record = MagicMock(levelno=logging.DEBUG) 101 | mock_formatter.return_value = fake_formatter 102 | fake_formatter.format.return_value = 'formatted_record' 103 | mock_datetime.now.return_value = fake_now 104 | fake_now.astimezone.return_value = fake_astimezone 105 | fake_astimezone.strftime.return_value = 'datetime' 106 | 107 | assert custom_formatter.format(fake_input_record) == 'formatted_record' 108 | 109 | mock_formatter.assert_any_call("\033[32;20m" + "[%(levelname)-8s] - datetime - %(message)s" + "\033[0m", 110 | datefmt='dateformat') 111 | 112 | 113 | -------------------------------------------------------------------------------- /tests/robotic_manipulator_rloa/utils/test_replay_buffer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from mock import MagicMock, patch 3 | from collections import deque, namedtuple 4 | import numpy as np 5 | 6 | from robotic_manipulator_rloa.utils.replay_buffer import ReplayBuffer 7 | 8 | 9 | @patch('robotic_manipulator_rloa.utils.replay_buffer.namedtuple') 10 | @patch('robotic_manipulator_rloa.utils.replay_buffer.deque') 11 | @patch('robotic_manipulator_rloa.utils.replay_buffer.random') 12 | def test_replaybuffer(mock_random: MagicMock, 13 | mock_deque: MagicMock, 14 | mock_namedtuple: MagicMock) -> None: 15 | """Test for the ReplayBuffer class""" 16 | replay_buffer = ReplayBuffer(buffer_size=10000, 17 | batch_size=128, 18 | device='cpu', 19 | seed=0) 20 | assert 
replay_buffer.device == 'cpu' 21 | assert replay_buffer.memory == mock_deque.return_value 22 | assert replay_buffer.batch_size == 128 23 | assert replay_buffer.experience == mock_namedtuple.return_value 24 | 25 | mock_deque.assert_any_call(maxlen=10000) 26 | mock_namedtuple.assert_any_call("Experience", field_names=["state", "action", "reward", "next_state", "done"]) 27 | mock_random.seed.assert_any_call(0) 28 | 29 | 30 | def test_replaybuffer__add() -> None: 31 | """Test for the add() method of the ReplayBuffer class""" 32 | replay_buffer = ReplayBuffer(buffer_size=10000, 33 | batch_size=128, 34 | device='cpu', 35 | seed=0) 36 | named_tuple = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"]) 37 | deque_ = deque(maxlen=10) 38 | experience = named_tuple('state', 'action', 'reward', 'next_state', 'done') 39 | deque_.append(experience) 40 | 41 | replay_buffer.add('state', 'action', 'reward', 'next_state', 'done') 42 | 43 | assert replay_buffer.memory == deque_ 44 | 45 | 46 | @patch('robotic_manipulator_rloa.utils.replay_buffer.torch') 47 | @patch('robotic_manipulator_rloa.utils.replay_buffer.random.sample') 48 | @patch('robotic_manipulator_rloa.utils.replay_buffer.np.stack') 49 | @patch('robotic_manipulator_rloa.utils.replay_buffer.np.vstack') 50 | def test_replaybuffer__sample(mock_vstack: MagicMock, 51 | mock_stack: MagicMock, 52 | mock_random_sample: MagicMock, 53 | mock_torch: MagicMock) -> None: 54 | """Test for the sample() method for the ReplayBuffer class""" 55 | replay_buffer = ReplayBuffer(buffer_size=10000, 56 | batch_size=1, 57 | device='cpu', 58 | seed=0) 59 | fake_state = np.array([0, 1]) 60 | fake_action = np.array([2, 3]) 61 | fake_reward = 1.5 62 | fake_next_state = np.array([4, 5]) 63 | fake_done = 0 64 | mock_random_sample.return_value = [MagicMock(state=fake_state, 65 | action=fake_action, 66 | reward=fake_reward, 67 | next_state=fake_next_state, 68 | done=fake_done)] 69 | fake_stack_state, fake_stack_next_state = MagicMock(), MagicMock() 70 | fake_vstack_action, fake_vstack_reward, fake_vstack_done = MagicMock(), MagicMock(), MagicMock() 71 | 72 | fake_vstack_done.astype.return_value = 'done_astype' 73 | mock_stack.side_effect = [fake_stack_state, fake_stack_next_state] 74 | mock_vstack.side_effect = [fake_vstack_action, fake_vstack_reward, fake_vstack_done] 75 | 76 | fake_from_numpy_state, fake_from_numpy_action, fake_from_numpy_reward, \ 77 | fake_from_numpy_next_state, fake_from_numpy_done = MagicMock(), MagicMock(), MagicMock(), MagicMock(), MagicMock() 78 | mock_torch.from_numpy.side_effect = [ 79 | fake_from_numpy_state, fake_from_numpy_action, fake_from_numpy_reward, 80 | fake_from_numpy_next_state, fake_from_numpy_done] 81 | 82 | output_state_float = fake_from_numpy_state.float.return_value 83 | output_state = output_state_float.to.return_value 84 | 85 | output_action_long = fake_from_numpy_action.long.return_value 86 | output_action = output_action_long.to.return_value 87 | 88 | output_reward_float = fake_from_numpy_reward.float.return_value 89 | output_reward = output_reward_float.to.return_value 90 | 91 | output_next_state_float = fake_from_numpy_next_state.float.return_value 92 | output_next_state = output_next_state_float.to.return_value 93 | 94 | output_done_float = fake_from_numpy_done.float.return_value 95 | output_done = output_done_float.to.return_value 96 | 97 | assert replay_buffer.sample() == (output_state, output_action, output_reward, output_next_state, output_done) 98 | 99 | 100 | def 
test_replaybuffer__len__() -> None: 101 | """Test for the __len__() method of the ReplayBuffer class""" 102 | replay_buffer = ReplayBuffer(buffer_size=10000, 103 | batch_size=1, 104 | device='cpu', 105 | seed=0) 106 | assert len(replay_buffer) == 0 107 | replay_buffer.add('state', 'action', 'reward', 'next_state', 'done') 108 | assert len(replay_buffer) == 1 109 | --------------------------------------------------------------------------------