from dataclasses import dataclass
30 | from typing import List
31 |
32 | import numpy as np
33 | import pybullet as p
34 | from numpy.typing import NDArray
35 |
36 |
@dataclass
class CollisionObject:
    """
    Dataclass which contains the UID of the manipulator/body and the link number of the joint from which to
    calculate distances to other bodies.
    """
    # PyBullet body unique IDs (as consumed by p.getClosestPoints elsewhere in this
    # module) are integers, so the previous `str` annotation was misleading.
    body: int
    link: int
45 |
46 |
class CollisionDetector:

    def __init__(self, collision_object: CollisionObject, obstacle_ids: List[int]):
        """
        Calculates distances between bodies' joints.
        Args:
            collision_object: CollisionObject instance, which indicates the body/joint from
                which to calculate distances/collisions.
            obstacle_ids: Obstacle body UIDs (PyBullet body unique IDs are integers).
                Distances are calculated from the joint/body given in the
                "collision_object" parameter to the "obstacle_ids" bodies.
        """
        self.obstacles = obstacle_ids
        self.collision_object = collision_object

    @staticmethod
    def _saturated_min_distance(closest_points: list, max_distance: float) -> float:
        """
        Return the smallest contact distance among the PyBullet closest-point tuples,
        saturating at max_distance when no points were returned (bodies farther apart
        than max_distance are not reported by PyBullet).
        Args:
            closest_points: Result of p.getClosestPoints().
            max_distance: Saturation value returned when closest_points is empty.
        Returns:
            Minimum distance found, or max_distance if no point was returned.
        """
        if not closest_points:
            return max_distance
        # Index 8 of each closest-point tuple is the contact distance (PyBullet API).
        return np.min([point[8] for point in closest_points])

    def compute_distances(self, max_distance: float = 10.0) -> NDArray:
        """
        Compute the closest distances from the joint given by the CollisionObject instance in self.collision_object
        to the bodies defined in self.obstacles.
        Args:
            max_distance: Bodies farther apart than this distance are not queried by PyBullet, the return value
                for the distance between such bodies will be max_distance.
        Returns:
            A numpy array of distances, one per pair of collision objects.
        """
        distances = []
        for obstacle in self.obstacles:

            # Compute the shortest distances between the collision-object and the given obstacle
            closest_points = p.getClosestPoints(
                self.collision_object.body,
                obstacle,
                distance=max_distance,
                linkIndexA=self.collision_object.link
            )
            distances.append(self._saturated_min_distance(closest_points, max_distance))

        return np.array(distances)

    def compute_collisions_in_manipulator(self, affected_joints: List[int], max_distance: float = 10.) -> NDArray:
        """
        Compute collisions between manipulator's parts.
        Args:
            affected_joints: Joints to consider when calculating distances.
            max_distance: Maximum distance to be considered. Distances further than this will be ignored, and
                the "max_distance" value will be returned.
        Returns:
            Array where each element corresponds to the distances from a given joint to the other joints.
        """
        distances = []
        for joint_ind in affected_joints:

            # Collisions with the same joint and with the previous/next joints are
            # omitted, as adjacent links will always be in contact
            if abs(joint_ind - self.collision_object.link) <= 1:
                continue  # pragma: no cover

            # Compute the shortest distances between the collision-object link and
            # the other link of the same body
            closest_points = p.getClosestPoints(
                self.collision_object.body,
                self.collision_object.body,
                distance=max_distance,
                linkIndexA=self.collision_object.link,
                linkIndexB=joint_ind
            )
            distances.append(self._saturated_min_distance(closest_points, max_distance))

        return np.array(distances)
CollisionObject instance, which indicates the body/joint from
148 | which to calculate distances/collisions.
149 |
obstacle_ids
150 |
Obstacle body UID. Distances are calculated from the joint/body given in the
151 | "collision_object" parameter to the "obstacle_ids" bodies.
152 |
153 |
154 |
155 | Expand source code
156 |
157 |
class CollisionDetector:
158 |
159 | def __init__(self, collision_object: CollisionObject, obstacle_ids: List[str]):
160 | """
161 | Calculates distances between bodies' joints.
162 | Args:
163 | collision_object: CollisionObject instance, which indicates the body/joint from
164 | which to calculate distances/collisions.
165 | obstacle_ids: Obstacle body UID. Distances are calculated from the joint/body given in the
166 | "collision_object" parameter to the "obstacle_ids" bodies.
167 | """
168 | self.obstacles = obstacle_ids
169 | self.collision_object = collision_object
170 |
171 | def compute_distances(self, max_distance: float = 10.0) -> NDArray:
172 | """
173 | Compute the closest distances from the joint given by the CollisionObject instance in self.collision_object
174 | to the bodies defined in self.obstacles.
175 | Args:
176 | max_distance: Bodies farther apart than this distance are not queried by PyBullet, the return value
177 | for the distance between such bodies will be max_distance.
178 | Returns:
179 | A numpy array of distances, one per pair of collision objects.
180 | """
181 | distances = list()
182 | for obstacle in self.obstacles:
183 |
184 | # Compute the shortest distances between the collision-object and the given obstacle
185 | closest_points = p.getClosestPoints(
186 | self.collision_object.body,
187 | obstacle,
188 | distance=max_distance,
189 | linkIndexA=self.collision_object.link
190 | )
191 |
192 | # If bodies are above max_distance apart, nothing is returned, so
193 | # we just saturate at max_distance. Otherwise, take the minimum
194 | if len(closest_points) == 0:
195 | distances.append(max_distance)
196 | else:
197 | distances.append(np.min([point[8] for point in closest_points]))
198 |
199 | return np.array(distances)
200 |
201 | def compute_collisions_in_manipulator(self, affected_joints: List[int], max_distance: float = 10.) -> NDArray:
202 | """
203 | Compute collisions between manipulator's parts.
204 | Args:
205 | affected_joints: Joints to consider when calculating distances.
206 | max_distance: Maximum distance to be considered. Distances further than this will be ignored, and
207 | the "max_distance" value will be returned.
208 | Returns:
209 | Array where each element corresponds to the distances from a given joint to the other joints.
210 | """
211 | distances = list()
212 | for joint_ind in affected_joints:
213 |
214 | # Collisions with the previous and next joints are omitted, as they will be always in contact
215 | if (self.collision_object.link == joint_ind) or \
216 | (joint_ind == self.collision_object.link - 1) or \
217 | (joint_ind == self.collision_object.link + 1):
218 | continue # pragma: no cover
219 |
220 | # Compute the shortest distances between all object pairs
221 | closest_points = p.getClosestPoints(
222 | self.collision_object.body,
223 | self.collision_object.body,
224 | distance=max_distance,
225 | linkIndexA=self.collision_object.link,
226 | linkIndexB=joint_ind
227 | )
228 |
229 | # If bodies are above max_distance apart, nothing is returned, so
230 | # we just saturate at max_distance. Otherwise, take the minimum
231 | if len(closest_points) == 0:
232 | distances.append(max_distance)
233 | else:
234 | distances.append(np.min([point[8] for point in closest_points]))
235 |
236 | return np.array(distances)
Maximum distance to be considered. Distances further than this will be ignored, and
251 | the "max_distance" value will be returned.
252 |
253 |
Returns
254 |
Array where each element corresponds to the distances from a given joint to the other joints.
255 |
256 |
257 | Expand source code
258 |
259 |
def compute_collisions_in_manipulator(self, affected_joints: List[int], max_distance: float = 10.) -> NDArray:
260 | """
261 | Compute collisions between manipulator's parts.
262 | Args:
263 | affected_joints: Joints to consider when calculating distances.
264 | max_distance: Maximum distance to be considered. Distances further than this will be ignored, and
265 | the "max_distance" value will be returned.
266 | Returns:
267 | Array where each element corresponds to the distances from a given joint to the other joints.
268 | """
269 | distances = list()
270 | for joint_ind in affected_joints:
271 |
272 | # Collisions with the previous and next joints are omitted, as they will be always in contact
273 | if (self.collision_object.link == joint_ind) or \
274 | (joint_ind == self.collision_object.link - 1) or \
275 | (joint_ind == self.collision_object.link + 1):
276 | continue # pragma: no cover
277 |
278 | # Compute the shortest distances between all object pairs
279 | closest_points = p.getClosestPoints(
280 | self.collision_object.body,
281 | self.collision_object.body,
282 | distance=max_distance,
283 | linkIndexA=self.collision_object.link,
284 | linkIndexB=joint_ind
285 | )
286 |
287 | # If bodies are above max_distance apart, nothing is returned, so
288 | # we just saturate at max_distance. Otherwise, take the minimum
289 | if len(closest_points) == 0:
290 | distances.append(max_distance)
291 | else:
292 | distances.append(np.min([point[8] for point in closest_points]))
293 |
294 | return np.array(distances)
Compute the closest distances from the joint given by the CollisionObject instance in self.collision_object
302 | to the bodies defined in self.obstacles.
303 |
Args
304 |
305 |
max_distance
306 |
Bodies farther apart than this distance are not queried by PyBullet, the return value
307 | for the distance between such bodies will be max_distance.
308 |
309 |
Returns
310 |
A numpy array of distances, one per pair of collision objects.
311 |
312 |
313 | Expand source code
314 |
315 |
def compute_distances(self, max_distance: float = 10.0) -> NDArray:
316 | """
317 | Compute the closest distances from the joint given by the CollisionObject instance in self.collision_object
318 | to the bodies defined in self.obstacles.
319 | Args:
320 | max_distance: Bodies farther apart than this distance are not queried by PyBullet, the return value
321 | for the distance between such bodies will be max_distance.
322 | Returns:
323 | A numpy array of distances, one per pair of collision objects.
324 | """
325 | distances = list()
326 | for obstacle in self.obstacles:
327 |
328 | # Compute the shortest distances between the collision-object and the given obstacle
329 | closest_points = p.getClosestPoints(
330 | self.collision_object.body,
331 | obstacle,
332 | distance=max_distance,
333 | linkIndexA=self.collision_object.link
334 | )
335 |
336 | # If bodies are above max_distance apart, nothing is returned, so
337 | # we just saturate at max_distance. Otherwise, take the minimum
338 | if len(closest_points) == 0:
339 | distances.append(max_distance)
340 | else:
341 | distances.append(np.min([point[8] for point in closest_points]))
342 |
343 | return np.array(distances)
Dataclass which contains the UID of the manipulator/body and the link number of the joint from which to
354 | calculate distances to other bodies.
355 |
356 |
357 | Expand source code
358 |
359 |
class CollisionObject:
360 | """
361 | Dataclass which contains the UID of the manipulator/body and the link number of the joint from which to
362 | calculate distances to other bodies.
363 | """
364 | body: str
365 | link: int
def get_global_logger() -> logging.Logger:
    """
    Getter for the logger.

    Returns:
        Logger instance to be used on the framework.
    """
    framework_logger = logging.getLogger(__name__)
    return framework_logger
Formatter instances are used to convert a LogRecord to text.
177 |
Formatters need to know how a LogRecord is constructed. They are
178 | responsible for converting a LogRecord to (usually) a string which can
179 | be interpreted by either a human or an external system. The base Formatter
180 | allows a formatting string to be specified. If none is supplied, the
181 | style-dependent default value, "%(message)s", "{message}", or
182 | "${message}", is used.
183 |
The Formatter can be initialized with a format string which makes use of
184 | knowledge of the LogRecord attributes - e.g. the default value mentioned
185 | above makes use of the fact that the user's message and arguments are pre-
186 | formatted into a LogRecord's message attribute. Currently, the useful
187 | attributes in a LogRecord are described by:
188 |
%(name)s
189 | Name of the logger (logging channel)
190 | %(levelno)s
191 | Numeric logging level for the message (DEBUG, INFO,
192 | WARNING, ERROR, CRITICAL)
193 | %(levelname)s
194 | Text logging level for the message ("DEBUG", "INFO",
195 | "WARNING", "ERROR", "CRITICAL")
196 | %(pathname)s
197 | Full pathname of the source file where the logging
198 | call was issued (if available)
199 | %(filename)s
200 | Filename portion of pathname
201 | %(module)s
202 | Module (name portion of filename)
203 | %(lineno)d
204 | Source line number where the logging call was issued
205 | (if available)
206 | %(funcName)s
207 | Function name
208 | %(created)f
209 | Time when the LogRecord was created (time.time()
210 | return value)
211 | %(asctime)s
212 | Textual time when the LogRecord was created
213 | %(msecs)d
214 | Millisecond portion of the creation time
215 | %(relativeCreated)d Time in milliseconds when the LogRecord was created,
216 | relative to the time the logging module was loaded
217 | (typically at application startup time)
218 | %(thread)d
219 | Thread ID (if available)
220 | %(threadName)s
221 | Thread name (if available)
222 | %(process)d
223 | Process ID (if available)
224 | %(message)s
225 | The result of record.getMessage(), computed just as
226 | the record is emitted
import random
30 | from collections import deque, namedtuple
31 | from typing import Tuple
32 |
33 | import numpy as np
34 | import torch
35 | from numpy.typing import NDArray
36 |
37 |
38 | # deque is a Doubly Ended Queue, provides O(1) complexity for pop and append actions
39 | # namedtuple is a tuple that can be accessed by both its index and attributes
40 |
41 |
class ReplayBuffer:

    def __init__(self, buffer_size: int, batch_size: int, device: torch.device, seed: int):
        """
        Buffer to store experience tuples. Each experience has the following structure:
        (state, action, reward, next_state, done)
        Args:
            buffer_size: Maximum size for the buffer. Higher buffer size imply higher RAM consumption.
            batch_size: Number of experiences to be retrieved from the ReplayBuffer per batch.
            device: CUDA device.
            seed: Random seed.
        """
        self.device = device
        # deque gives O(1) appends and silently drops the oldest experiences when full
        self.memory = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
        # NOTE(review): this seeds the module-global RNG, affecting every other user
        # of `random` in the process — confirm a dedicated random.Random instance
        # is not preferable.
        random.seed(seed)

    def add(self, state: NDArray, action: NDArray, reward: float, next_state: NDArray, done: int) -> None:
        """
        Add a new experience to the Replay Buffer.
        Args:
            state: NDArray of the current state.
            action: NDArray of the action taken from state {state}.
            reward: Reward obtained after performing action {action} from state {state}.
            next_state: NDArray of the state reached after performing action {action} from state {state}.
            done: Integer (0 or 1) indicating whether the next_state is a terminal state.
        """
        # Create namedtuple object from the experience
        exp = self.experience(state, action, reward, next_state, done)
        # Add the experience object to memory
        self.memory.append(exp)

    def sample(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Randomly sample a batch of experiences from memory.
        Returns:
            Tuple of 5 elements, which are (states, actions, rewards, next_states, dones). Each element
            in the tuple is a torch Tensor composed of {batch_size} items.
        """
        # Randomly sample a batch of experiences. random.sample of a deque of
        # namedtuples can never yield None, so the previous `if e is not None`
        # filters (applied inconsistently to 4 of the 5 tensors) were dead code
        # and have been removed uniformly.
        experiences = random.sample(self.memory, k=self.batch_size)

        # Some environment APIs return the state as an (observation, info) tuple;
        # unwrap the observation when that is the case — presumably only relevant
        # for initial states (TODO confirm next_state never needs the same unwrap).
        states = torch.from_numpy(
            np.stack([e.state if not isinstance(e.state, tuple) else e.state[0] for e in experiences])
        ).float().to(self.device)
        actions = torch.from_numpy(np.vstack([e.action for e in experiences])).long().to(self.device)
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences])).float().to(self.device)
        next_states = torch.from_numpy(np.stack([e.next_state for e in experiences])).float().to(self.device)
        dones = torch.from_numpy(
            np.vstack([e.done for e in experiences]).astype(np.uint8)
        ).float().to(self.device)

        return states, actions, rewards, next_states, dones

    def __len__(self) -> int:
        """
        Return the current size of the Replay Buffer
        Returns:
            Size of Replay Buffer
        """
        return len(self.memory)
Buffer to store experience tuples. Each experience has the following structure:
121 | (state, action, reward, next_state, done)
122 |
Args
123 |
124 |
buffer_size
125 |
Maximum size for the buffer. Higher buffer size imply higher RAM consumption.
126 |
batch_size
127 |
Number of experiences to be retrieved from the ReplayBuffer per batch.
128 |
device
129 |
CUDA device.
130 |
seed
131 |
Random seed.
132 |
133 |
134 |
135 | Expand source code
136 |
137 |
class ReplayBuffer:
138 |
139 | def __init__(self, buffer_size: int, batch_size: int, device: torch.device, seed: int):
140 | """
141 | Buffer to store experience tuples. Each experience has the following structure:
142 | (state, action, reward, next_state, done)
143 | Args:
144 | buffer_size: Maximum size for the buffer. Higher buffer size imply higher RAM consumption.
145 | batch_size: Number of experiences to be retrieved from the ReplayBuffer per batch.
146 | device: CUDA device.
147 | seed: Random seed.
148 | """
149 | self.device = device
150 | self.memory = deque(maxlen=buffer_size)
151 | self.batch_size = batch_size
152 | self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
153 | random.seed(seed)
154 |
155 | def add(self, state: NDArray, action: NDArray, reward: float, next_state: NDArray, done: int) -> None:
156 | """
157 | Add a new experience to the Replay Buffer.
158 | Args:
159 | state: NDArray of the current state.
160 | action: NDArray of the action taken from state {state}.
161 | reward: Reward obtained after performing action {action} from state {state}.
162 | next_state: NDArray of the state reached after performing action {action} from state {state}.
163 | done: Integer (0 or 1) indicating whether the next_state is a terminal state.
164 | """
165 | # Create namedtuple object from the experience
166 | exp = self.experience(state, action, reward, next_state, done)
167 | # Add the experience object to memory
168 | self.memory.append(exp)
169 |
170 | def sample(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
171 | """
172 | Randomly sample a batch of experiences from memory.
173 | Returns:
174 | Tuple of 5 elements, which are (states, actions, rewards, next_states, dones). Each element
175 | in the tuple is a torch Tensor composed of {batch_size} items.
176 | """
177 | # Randomly sample a batch of experiences
178 | experiences = random.sample(self.memory, k=self.batch_size)
179 |
180 | states = torch.from_numpy(
181 | np.stack([e.state if not isinstance(e.state, tuple) else e.state[0] for e in experiences])).float().to(
182 | self.device)
183 | actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(self.device)
184 | rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(self.device)
185 | next_states = torch.from_numpy(np.stack([e.next_state for e in experiences if e is not None])).float().to(
186 | self.device)
187 | dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(
188 | self.device)
189 |
190 | return states, actions, rewards, next_states, dones
191 |
192 | def __len__(self) -> int:
193 | """
194 | Return the current size of the Replay Buffer
195 | Returns:
196 | Size of Replay Buffer
197 | """
198 | return len(self.memory)
Reward obtained after performing action {action} from state {state}.
215 |
next_state
216 |
NDArray of the state reached after performing action {action} from state {state}.
217 |
done
218 |
Integer (0 or 1) indicating whether the next_state is a terminal state.
219 |
220 |
221 |
222 | Expand source code
223 |
224 |
def add(self, state: NDArray, action: NDArray, reward: float, next_state: NDArray, done: int) -> None:
225 | """
226 | Add a new experience to the Replay Buffer.
227 | Args:
228 | state: NDArray of the current state.
229 | action: NDArray of the action taken from state {state}.
230 | reward: Reward obtained after performing action {action} from state {state}.
231 | next_state: NDArray of the state reached after performing action {action} from state {state}.
232 | done: Integer (0 or 1) indicating whether the next_state is a terminal state.
233 | """
234 | # Create namedtuple object from the experience
235 | exp = self.experience(state, action, reward, next_state, done)
236 | # Add the experience object to memory
237 | self.memory.append(exp)
Randomly sample a batch of experiences from memory.
245 |
Returns
246 |
Tuple of 5 elements, which are (states, actions, rewards, next_states, dones). Each element
247 | in the tuple is a torch Tensor composed of {batch_size} items.
248 |
249 |
250 | Expand source code
251 |
252 |
def sample(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
253 | """
254 | Randomly sample a batch of experiences from memory.
255 | Returns:
256 | Tuple of 5 elements, which are (states, actions, rewards, next_states, dones). Each element
257 | in the tuple is a torch Tensor composed of {batch_size} items.
258 | """
259 | # Randomly sample a batch of experiences
260 | experiences = random.sample(self.memory, k=self.batch_size)
261 |
262 | states = torch.from_numpy(
263 | np.stack([e.state if not isinstance(e.state, tuple) else e.state[0] for e in experiences])).float().to(
264 | self.device)
265 | actions = torch.from_numpy(np.vstack([e.action for e in experiences if e is not None])).long().to(self.device)
266 | rewards = torch.from_numpy(np.vstack([e.reward for e in experiences if e is not None])).float().to(self.device)
267 | next_states = torch.from_numpy(np.stack([e.next_state for e in experiences if e is not None])).float().to(
268 | self.device)
269 | dones = torch.from_numpy(np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(
270 | self.device)
271 |
272 | return states, actions, rewards, next_states, dones
273 |
274 |
275 |
276 |
277 |
278 |
279 |
280 |
304 |
305 |
308 |
309 |
310 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0.0", "wheel"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "robotic-manipulator-rloa"
7 | version = "1.0.0"
8 | description = "Framework for training Robotic Manipulators on the Obstacle Avoidance task through Reinforcement Learning."
9 | readme = "README.md"
10 | authors = [{ name = "Javier Martinez", email = "jmartinezojeda5upv@gmail.com" }]
11 | license = { file = "LICENSE" }
12 | classifiers = [
13 | "License :: OSI Approved :: MIT License",
14 | "Programming Language :: Python",
15 | "Programming Language :: Python :: 3",
16 | "Operating System :: Microsoft :: Windows",
17 | "Operating System :: POSIX :: Linux",
18 | "Framework :: Robot Framework"
19 | ]
20 | keywords = ["Robotics", "Manipulator", "Framework", "Reinforcement Learning", "Obstacle Avoidance"]
21 | dependencies = [
22 | "torch",
23 | "numpy",
24 | "pybullet",
25 | "matplotlib"
26 | ]
27 | requires-python = ">=3.8"
28 |
29 | [tool.setuptools.package-data]
30 | myModule = ["*.p"]
31 |
32 | [project.urls]
33 | Homepage = "https://github.com/JavierMtz5/robotic_manipulator_rloa"
34 | Repository = "https://github.com/JavierMtz5/robotic_manipulator_rloa"
35 | Documentation = "https://javiermtz5.github.io/robotic_manipulator_rloa/"
--------------------------------------------------------------------------------
/robotic_manipulator_rloa/__init__.py:
--------------------------------------------------------------------------------
1 | from .rl_framework import ManipulatorFramework
2 |
--------------------------------------------------------------------------------
/robotic_manipulator_rloa/environment/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JavierMtz5/robotic_manipulator_rloa/263aab50f70deeb48b8f0bc28041c86b76453816/robotic_manipulator_rloa/environment/__init__.py
--------------------------------------------------------------------------------
/robotic_manipulator_rloa/environment/environment.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import List, Tuple
3 |
4 | import numpy as np
5 | import pybullet as p
6 | import pybullet_data
7 | from numpy.typing import NDArray
8 |
9 | from robotic_manipulator_rloa.utils.logger import get_global_logger
10 | from robotic_manipulator_rloa.utils.exceptions import (
11 | InvalidManipulatorFile,
12 | InvalidEnvironmentParameter
13 | )
14 | from robotic_manipulator_rloa.utils.collision_detector import CollisionObject, CollisionDetector
15 |
16 | logger = get_global_logger()
17 |
18 |
class EnvironmentConfiguration:

    def __init__(self,
                 endeffector_index: int,
                 fixed_joints: List[int],
                 involved_joints: List[int],
                 target_position: List[float],
                 obstacle_position: List[float],
                 initial_joint_positions: List[float] = None,
                 initial_positions_variation_range: List[float] = None,
                 max_force: float = 200.,
                 visualize: bool = True):
        """
        Validates each of the parameters required for the Environment class initialization.
        Args:
            endeffector_index: Index of the manipulator's end-effector.
            fixed_joints: List containing the indices of every joint not involved in the training.
            involved_joints: List containing the indices of every joint involved in the training.
            target_position: List containing the position of the target object, as 3D Cartesian coordinates.
            obstacle_position: List containing the position of the obstacle, as 3D Cartesian coordinates.
            initial_joint_positions: List containing as many items as the number of joints of the manipulator.
                Each item in the list corresponds to the initial position wanted for the joint with that same index.
            initial_positions_variation_range: List containing as many items as the number of joints of the manipulator.
                Each item in the list corresponds to the variation range wanted for the joint with that same index.
            max_force: Maximum force to be applied on the joints.
            visualize: Visualization mode.
        Raises:
            InvalidEnvironmentParameter: If any of the parameters fails validation.
        """
        self._validate_endeffector_index(endeffector_index)
        self._validate_fixed_joints(fixed_joints)
        self._validate_involved_joints(involved_joints)
        self._validate_target_position(target_position)
        self._validate_obstacle_position(obstacle_position)
        self._validate_initial_joint_positions(initial_joint_positions)
        self._validate_initial_positions_variation_range(initial_positions_variation_range)
        self._validate_max_force(max_force)
        self._validate_visualize(visualize)

    @staticmethod
    def _ensure_list_of(value: list, item_types: tuple, list_error: str, item_error: str) -> None:
        """
        Shared check used by the list validators below: "value" must be a list and
        every item must be an instance of "item_types".
        Args:
            value: Candidate value to validate.
            item_types: Tuple of accepted types for the list items.
            list_error: Error message raised when "value" is not a list.
            item_error: Error message raised when an item has a wrong type.
        Raises:
            InvalidEnvironmentParameter
        """
        if not isinstance(value, list):
            raise InvalidEnvironmentParameter(list_error)
        for val in value:
            if not isinstance(val, item_types):
                raise InvalidEnvironmentParameter(item_error)

    def _validate_endeffector_index(self, endeffector_index: int) -> None:
        """
        Validates the "endeffector_index" parameter.
        Args:
            endeffector_index: int
        Raises:
            InvalidEnvironmentParameter
        """
        if not isinstance(endeffector_index, int):
            raise InvalidEnvironmentParameter('End Effector index received is not an integer')
        self.endeffector_index = endeffector_index

    def _validate_fixed_joints(self, fixed_joints: List[int]) -> None:
        """
        Validates the "fixed_joints" parameter
        Args:
            fixed_joints: list of integers
        Raises:
            InvalidEnvironmentParameter
        """
        self._ensure_list_of(fixed_joints, (int,),
                             'Fixed Joints received is not a list',
                             'An item inside the Fixed Joints list is not an integer')
        self.fixed_joints = fixed_joints

    def _validate_involved_joints(self, involved_joints: List[int]) -> None:
        """
        Validates the "involved_joints" parameter
        Args:
            involved_joints: list of integers
        Raises:
            InvalidEnvironmentParameter
        """
        self._ensure_list_of(involved_joints, (int,),
                             'Involved Joints received is not a list',
                             'An item inside the Involved Joints list is not an integer')
        self.involved_joints = involved_joints

    def _validate_target_position(self, target_position: List[float]) -> None:
        """
        Validates the "target_position" parameter
        Args:
            target_position: list of floats
        Raises:
            InvalidEnvironmentParameter
        """
        self._ensure_list_of(target_position, (int, float),
                             'Target Position received is not a list',
                             'An item inside the Target Position list is not a float')
        self.target_position = target_position

    def _validate_obstacle_position(self, obstacle_position: List[float]) -> None:
        """
        Validates the "obstacle_position" parameter
        Args:
            obstacle_position: list of floats
        Raises:
            InvalidEnvironmentParameter
        """
        self._ensure_list_of(obstacle_position, (int, float),
                             'Obstacle Position received is not a list',
                             'An item inside the Obstacle Position list is not a float')
        self.obstacle_position = obstacle_position

    def _validate_initial_joint_positions(self, initial_joint_positions: List[float]) -> None:
        """
        Validates the "initial_joint_positions" parameter
        Args:
            initial_joint_positions: list of floats, or None (parameter is optional)
        Raises:
            InvalidEnvironmentParameter
        """
        # The parameter is optional: None means "use the manipulator's defaults"
        if initial_joint_positions is None:
            self.initial_joint_positions = None
            return
        self._ensure_list_of(initial_joint_positions, (int, float),
                             'Initial Joint Positions received is not a list',
                             'An item inside the Initial Joint Positions list is not a float')
        self.initial_joint_positions = initial_joint_positions

    def _validate_initial_positions_variation_range(self, initial_positions_variation_range: List[float]) -> None:
        """
        Validates the "initial_positions_variation_range" parameter
        Args:
            initial_positions_variation_range: list of floats, or None (parameter is optional)
        Raises:
            InvalidEnvironmentParameter
        """
        # The parameter is optional: None means "no random variation of initial positions"
        if initial_positions_variation_range is None:
            self.initial_positions_variation_range = None
            return
        self._ensure_list_of(initial_positions_variation_range, (float, int),
                             'Initial Positions Variation Range received is not a list',
                             'An item inside the Initial Positions Variation Range '
                             'list is not a float')
        self.initial_positions_variation_range = initial_positions_variation_range

    def _validate_max_force(self, max_force: float) -> None:
        """
        Validates the "max_force" parameter
        Args:
            max_force: float
        Raises:
            InvalidEnvironmentParameter
        """
        if not isinstance(max_force, (int, float)):
            raise InvalidEnvironmentParameter('Maximum Force value received is not a float')
        self.max_force = max_force

    def _validate_visualize(self, visualize: bool) -> None:
        """
        Validates the "visualize" parameter
        Args:
            visualize: bool
        Raises:
            InvalidEnvironmentParameter
        """
        if not isinstance(visualize, bool):
            raise InvalidEnvironmentParameter('Visualize value received is not a boolean')
        self.visualize = visualize
188 |
189 |
class Environment:
    """
    PyBullet simulation environment for training a robotic manipulator.

    Loads the manipulator from a URDF/SDF file, spawns a small sphere (the obstacle)
    and a small cube (the target) at the configured positions, and exposes the
    reset/step/reward/terminal-state helpers used by the training loop.
    """

    def __init__(self,
                 manipulator_file: str,
                 environment_config: EnvironmentConfiguration):
        """
        Creates the Pybullet environment used along the training.
        Args:
            manipulator_file: Path to the URDF or SDF file from which to load the Robotic Manipulator.
            environment_config: Instance of the EnvironmentConfiguration class with all its attributes set.
        Raises:
            InvalidManipulatorFile: The URDF/SDF file doesn't exist, is invalid or has an invalid extension.
        """
        self.manipulator_file = manipulator_file
        self.visualize = environment_config.visualize

        # Initialize pybullet (GUI connection only when visualization was requested)
        self.physics_client = p.connect(p.GUI if environment_config.visualize else p.DIRECT)
        p.setGravity(0, 0, -9.81)
        p.setRealTimeSimulation(0) # stepping is driven manually via stepSimulation()
        p.setAdditionalSearchPath(pybullet_data.getDataPath())

        self.target_pos = environment_config.target_position
        self.obstacle_pos = environment_config.obstacle_position
        self.max_force = environment_config.max_force # Maximum force to be applied (DEFAULT=200)
        self.initial_joint_positions = environment_config.initial_joint_positions
        self.initial_positions_variation_range = environment_config.initial_positions_variation_range

        self.endeffector_index = environment_config.endeffector_index # Index of the Manipulator's End-Effector
        self.fixed_joints = environment_config.fixed_joints # List of indexes for the joints to be fixed
        self.involved_joints = environment_config.involved_joints # List of indexes of joints involved in the training

        # Load Manipulator from URDF/SDF file
        logger.debug(f'Loading URDF/SDF file {manipulator_file} for Robot Manipulator...')
        if not isinstance(manipulator_file, str):
            raise InvalidManipulatorFile('The filename provided is not a string')

        try:
            if manipulator_file.endswith('.urdf'):
                self.manipulator_uid = p.loadURDF(manipulator_file)
            elif manipulator_file.endswith('.sdf'):
                # loadSDF returns a sequence of body UIDs; the manipulator is the first one
                self.manipulator_uid = p.loadSDF(manipulator_file)[0]
            else:
                raise InvalidManipulatorFile('The file extension is neither .sdf nor .urdf')
        except p.error as err:
            logger.critical(err)
            raise InvalidManipulatorFile

        self.num_joints = p.getNumJoints(self.manipulator_uid)

        logger.debug(f'Robot Manipulator URDF/SDF file {manipulator_file} has been successfully loaded. '
                     f'The Robot Manipulator has {self.num_joints} joints, and its joints, '
                     f'together with the information of each, are:')
        data = list()
        for joint_ind in range(self.num_joints):
            # Collect (index, name, upper limit, lower limit, axis) per joint -- see print_table()
            joint_info = p.getJointInfo(self.manipulator_uid, joint_ind)
            data.append((joint_ind, joint_info[1].decode("utf-8"), joint_info[9], joint_info[8], joint_info[13]))

        # Print Joints info
        self.print_table(data)

        # Create obstacle with the shape of a sphere, and the target object with square shape
        self.obstacle = p.loadURDF('sphere_small.urdf', basePosition=self.obstacle_pos,
                                   useFixedBase=1, globalScaling=2.5)
        self.target = p.loadURDF('cube_small.urdf', basePosition=self.target_pos,
                                 useFixedBase=1, globalScaling=1)
        logger.debug(f'Both the obstacle and the target object have been generated in positions {self.obstacle_pos} '
                     f'and {self.target_pos} respectively')

        # 9 elements correspond to the 3 vector indicating the position of the target, end effector and obstacle
        # The other elements are the two arrays of the involved joint's position and velocities
        self._observation_space = np.zeros((9 + 2 * len(self.involved_joints),))
        self._action_space = np.zeros((len(self.involved_joints),))

    def reset(self, verbose: bool = True) -> NDArray:
        """
        Resets the environment to an initial state.\n
        - If "initial_joint_positions" and "initial_positions_variation_range" are not set, all joints will be reset to
        the 0 position.\n
        - If only "initial_joint_positions" is set, the joints will be reset to those positions.\n
        - If only "initial_positions_variation_range" is set, the joints will be reset to 0 plus the variation noise.\n
        - If both "initial_joint_positions" and "initial_positions_variation_range" are set, the joints will be reset
        to the positions specified plus the variation noise.
        Args:
            verbose: Boolean indicating whether to print context information or not.
        Returns:
            New state reached after reset.
        """
        if verbose: logger.info('Resetting Environment...')

        # Reset the robot's base position and orientation
        p.resetBasePositionAndOrientation(self.manipulator_uid, [0.000000, 0.000000, 0.000000],
                                          [0.000000, 0.000000, 0.000000, 1.000000])

        # NOTE(review): truthiness is used below, so an empty list behaves the same as None here
        if not self.initial_joint_positions and not self.initial_positions_variation_range:
            initial_state = [0 for _ in range(self.num_joints)]
        elif self.initial_joint_positions:
            if self.initial_positions_variation_range:
                # Uniform noise of +-var around each configured initial position
                initial_state = [random.uniform(pos - var, pos + var) for pos, var
                                 in zip(self.initial_joint_positions, self.initial_positions_variation_range)]
            else:
                initial_state = self.initial_joint_positions
        else:
            initial_state = [random.uniform(0 - var, 0 + var) for var in self.initial_positions_variation_range]

        for joint_index, pos in enumerate(initial_state):
            p.setJointMotorControl2(self.manipulator_uid, joint_index,
                                    controlMode=p.POSITION_CONTROL,
                                    targetPosition=pos)

        # Step the simulation so the position controllers can drive the joints
        # towards the requested initial positions
        for _ in range(50):
            p.stepSimulation(self.physics_client)

        # Generate first state, and return it
        # The states are defined as {joint_pos, joint_vel, end-effector_pos, target_pos, obstacle_pos}, where
        # both joint_pos and joint_vel are arrays with the pos and vel of each joint
        new_state = self.get_state()
        if verbose: logger.info('Environment Reset')

        return new_state

    def is_terminal_state(self, target_threshold: float = 0.05, obstacle_threshold: float = 0.,
                          consider_autocollision: bool = False) -> int:
        """
        Calculates if a terminal state is reached.
        Args:
            target_threshold: Threshold which delimits the terminal state. If the end-effector is closer
                to the target position than the threshold value, then a terminal state is reached.
            obstacle_threshold: Threshold which delimits the terminal state. If the end-effector is closer
                to the obstacle position than the threshold value, then a terminal state is reached.
            consider_autocollision: If set to True, the collision of any of the joints and parts of the manipulator
                with any other joint or part will be considered a terminal state.
        Returns:
            Integer (0 or 1) indicating whether the new state reached is a terminal state or not.
        """
        # If the manipulator has a collision with the obstacle, the episode terminates
        if self.get_manipulator_obstacle_collisions(threshold=obstacle_threshold):
            logger.info('Collision detected, terminating episode...')
            return 1

        # If the position of the end-effector is the same as the one of the target position, episode terminates
        if self.get_endeffector_target_collision(threshold=target_threshold)[0]:
            logger.info('The goal state has been reached, terminating episode...')
            return 1

        # If the manipulator collides with itself, a terminal state is reached
        if consider_autocollision:
            self_distances = self.get_manipulator_collisions_with_itself()
            for distances in self_distances.values():
                # A negative distance means the two bodies overlap (i.e. a collision)
                if (distances < 0).any():
                    logger.info('Auto-Collision detected, terminating episode...')
                    return 1

        return 0

    def get_reward(self, consider_autocollision: bool = False) -> float:
        """
        Computes the reward from the given state.
        Args:
            consider_autocollision: If set to True, the manipulator's collisions with itself are also
                penalized with the collision reward.
        Returns:
            Rewards:\n
            - If the end effector reaches the target position, a reward of +250 is returned.\n
            - If the end effector collides with the obstacle or with itself*, a reward of -1000 is returned.\n
            - Otherwise, the negative value of the distance from end effector to the target is returned.\n
            * The manipulator's collisions with itself are only considered if "consider_autocollision" parameter is set
            to True.
        """
        # Auto-Collision is only calculated if requested
        self_collision = False
        if consider_autocollision:
            self_distances = self.get_manipulator_collisions_with_itself()
            for distances in self_distances.values():
                if (distances < 0).any():
                    self_collision = True

        endeffector_target_collision, endeffector_target_dist = self.get_endeffector_target_collision(threshold=0.05)

        if endeffector_target_collision:
            return 250
        elif self.get_manipulator_obstacle_collisions(threshold=0) or self_collision:
            return -1000
        else:
            return -1 * float(endeffector_target_dist)

    def get_manipulator_obstacle_collisions(self, threshold: float) -> bool:
        """
        Calculates if there is a collision between the manipulator and the obstacle.
        Args:
            threshold: If the distance between the end effector and the obstacle is below the "threshold", then
                it is considered a collision.
        Returns:
            Boolean indicating whether a collision occurred.
        """
        joint_distances = list()
        for joint_ind in range(self.num_joints):
            # Distances are checked for every link of the manipulator, not only the end effector
            end_effector_collision_obj = CollisionObject(body=self.manipulator_uid, link=joint_ind)
            collision_detector = CollisionDetector(collision_object=end_effector_collision_obj,
                                                   obstacle_ids=[self.obstacle])

            dist = collision_detector.compute_distances()
            joint_distances.append(dist[0])

        joint_distances = np.array(joint_distances)
        return (joint_distances < threshold).any()

    def get_manipulator_collisions_with_itself(self) -> dict:
        """
        Calculates the distances between each of the manipulator's joints and the other joints.
        Returns:
            Dictionary where each key is the index of a joint (formatted as "joint_<index>"), and where
            each value is an array with the distances from that joint to any other joint in the manipulator.
        """
        joint_distances = dict()
        for joint_ind in range(self.num_joints):
            joint_collision_obj = CollisionObject(body=self.manipulator_uid, link=joint_ind)
            collision_detector = CollisionDetector(collision_object=joint_collision_obj,
                                                   obstacle_ids=[])
            distances = collision_detector.compute_collisions_in_manipulator(
                affected_joints=[_ for _ in range(self.num_joints)], # all joints are taken into account
                max_distance=10
            )
            joint_distances[f'joint_{joint_ind}'] = distances

        return joint_distances

    def get_endeffector_target_collision(self, threshold: float) -> Tuple[bool, float]:
        """
        Calculates if there are any collisions between the end effector and the target.
        Args:
            threshold: If the distance between the end effector and the target is below {threshold}, then
                it is considered a collision.
        Returns:
            Tuple where the first element is a boolean indicating whether a collision occurred, and where
            the second is the distance from end effector to target minus the threshold.
            NOTE(review): the second element is actually a one-element numpy array (compute_distances()
            returns one distance per obstacle body), not a plain float -- callers convert it with float().
        """
        kuka_end_effector = CollisionObject(body=self.manipulator_uid, link=self.endeffector_index)
        collision_detector = CollisionDetector(collision_object=kuka_end_effector, obstacle_ids=[self.target])

        dist = collision_detector.compute_distances()

        return (dist < threshold).any(), dist - threshold

    def get_state(self) -> NDArray:
        """
        Retrieves information from the environment's current state.
        Returns:
            State as (joint_pos, joint_vel, end-effector_pos, target_pos, obstacle_pos):\n
            - The positions of the target, obstacle and end effector are given as 3D cartesian coordinates.\n
            - The joint positions and joint velocities are given as arrays of length equal to the number of
            joints involved in the training.
        """
        joint_pos, joint_vel = list(), list()

        # NOTE(review): this iterates joint indices 0..len(involved_joints)-1, not the indices
        # listed in self.involved_joints -- confirm this is intended when the involved joints
        # are not the first ones of the manipulator
        for joint_index in range(len(self.involved_joints)):
            joint_pos.append(p.getJointState(self.manipulator_uid, joint_index)[0])
            joint_vel.append(p.getJointState(self.manipulator_uid, joint_index)[1])

        end_effector_pos = p.getLinkState(self.manipulator_uid, self.endeffector_index)[0]
        end_effector_pos = list(end_effector_pos)

        state = np.hstack([np.array(joint_pos), np.array(joint_vel), np.array(end_effector_pos),
                           np.array(self.target_pos), np.array(self.obstacle_pos)])
        return state.astype(float)

    def step(self, action: NDArray) -> Tuple[NDArray, float, int]:
        """
        Applies the action on the Robot's joints, so that each joint reaches the desired velocity for
        each involved joint.
        Args:
            action: Array where each element corresponds to the velocity to be applied on the joint
                with that same index.
        Returns:
            (new_state, reward, done)
        """
        # Apply velocities on the involved joints according to action
        for joint_index, vel in zip(self.involved_joints, action):
            p.setJointMotorControl2(self.manipulator_uid,
                                    joint_index,
                                    p.VELOCITY_CONTROL,
                                    targetVelocity=vel,
                                    force=self.max_force)

        # Create constraint for fixed joints (maintain joint on fixed position)
        for joint_ind in self.fixed_joints:
            p.setJointMotorControl2(self.manipulator_uid,
                                    joint_ind,
                                    p.POSITION_CONTROL,
                                    targetPosition=0)

        # Perform actions on simulation
        p.stepSimulation(physicsClientId=self.physics_client)

        reward = self.get_reward()
        new_state = self.get_state()
        done = self.is_terminal_state()

        return new_state, reward, done

    @staticmethod
    def print_table(data: List[Tuple[int, str, float, float, tuple]]) -> None:
        """
        Prints a table such that the elements received in the "data" parameter are displayed under
        "Index", "Name", "Upper Limit", "Lower Limit" and "Axis" columns. It is used to print the Manipulator's
        joint's information in an ordered manner.
        Args:
            data: List where each element contains all the information about a given joint.
                Each element on the list will be a tuple containing (index, name, upper_limit, lower_limit, axis).
        """
        logger.debug('{:<6} {:<35} {:<15} {:<15} {:<15}'.format('Index', 'Name', 'Upper Limit', 'Lower Limit', 'Axis'))
        for index, name, up_limit, lo_limit, axis in data:
            logger.debug('{:<6} {:<35} {:<15} {:<15} {:<15}'.format(index, name, up_limit, lo_limit, str(axis)))

    @property
    def observation_space(self) -> np.ndarray:
        """
        Getter for the observation space of the environment.
        Returns:
            Numpy array of zeros with same shape as the environment's states.
        """
        return self._observation_space

    @property
    def action_space(self) -> np.ndarray:
        """
        Getter for the action space of the environment.
        Returns:
            Numpy array of zeros with same shape as the environment's actions.
        """
        return self._action_space
518 |
--------------------------------------------------------------------------------
/robotic_manipulator_rloa/naf_components/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JavierMtz5/robotic_manipulator_rloa/263aab50f70deeb48b8f0bc28041c86b76453816/robotic_manipulator_rloa/naf_components/__init__.py
--------------------------------------------------------------------------------
/robotic_manipulator_rloa/naf_components/demo_weights/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JavierMtz5/robotic_manipulator_rloa/263aab50f70deeb48b8f0bc28041c86b76453816/robotic_manipulator_rloa/naf_components/demo_weights/__init__.py
--------------------------------------------------------------------------------
/robotic_manipulator_rloa/naf_components/demo_weights/weights_kuka.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JavierMtz5/robotic_manipulator_rloa/263aab50f70deeb48b8f0bc28041c86b76453816/robotic_manipulator_rloa/naf_components/demo_weights/weights_kuka.p
--------------------------------------------------------------------------------
/robotic_manipulator_rloa/naf_components/demo_weights/weights_xarm6.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JavierMtz5/robotic_manipulator_rloa/263aab50f70deeb48b8f0bc28041c86b76453816/robotic_manipulator_rloa/naf_components/demo_weights/weights_xarm6.p
--------------------------------------------------------------------------------
/robotic_manipulator_rloa/naf_components/naf_algorithm.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 | import time
5 | from typing import Tuple, Dict
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | import torch.optim as optim
10 | from numpy.typing import NDArray
11 | from torch.nn.utils import clip_grad_norm_
12 |
13 | from robotic_manipulator_rloa.utils.logger import get_global_logger
14 | from robotic_manipulator_rloa.environment.environment import Environment
15 | from robotic_manipulator_rloa.utils.exceptions import MissingWeightsFile
16 | from robotic_manipulator_rloa.naf_components.naf_neural_network import NAF
17 | from robotic_manipulator_rloa.utils.replay_buffer import ReplayBuffer
18 |
19 |
20 | logger = get_global_logger()
21 |
22 |
class NAFAgent:
    """
    Agent that interacts with and learns from an Environment via the NAF
    (Normalized Advantage Function) algorithm, using a main and a target
    Q-network, a replay buffer, and soft updates of the target network.
    """

    MODEL_PATH = 'model.p' # Filename where the parameters of the trained torch neural network are stored

    def __init__(self,
                 environment: Environment,
                 state_size: int,
                 action_size: int,
                 layer_size: int,
                 batch_size: int,
                 buffer_size: int,
                 learning_rate: float,
                 tau: float,
                 gamma: float,
                 update_freq: int,
                 num_updates: int,
                 checkpoint_frequency: int,
                 device: torch.device,
                 seed: int) -> None:
        """
        Interacts with and learns from the environment via the NAF algorithm.
        Args:
            environment: Instance of Environment class.
            state_size: Dimension of the states.
            action_size: Dimension of the actions.
            layer_size: Size for the hidden layers of the neural network.
            batch_size: Number of experiences to train with per training batch.
            buffer_size: Maximum number of experiences to be stored in Replay Buffer.
            learning_rate: Learning rate for neural network's optimizer.
            tau: Hyperparameter for soft updating the target network.
            gamma: Discount factor.
            update_freq: Number of timesteps after which the main neural network is updated.
            num_updates: Number of updates performed when learning.
            checkpoint_frequency: Number of episodes after which a checkpoint is generated.
            device: Device used (CPU or CUDA).
            seed: Random seed.
        """
        # Create required parent directory
        os.makedirs('checkpoints/', exist_ok=True)

        self.environment = environment
        self.state_size = state_size
        self.action_size = action_size
        self.layer_size = layer_size
        self.buffer_size = buffer_size
        self.learning_rate = learning_rate
        random.seed(seed)
        self.device = device
        self.tau = tau
        self.gamma = gamma
        self.update_freq = update_freq
        self.num_updates = num_updates
        self.batch_size = batch_size
        self.checkpoint_frequency = checkpoint_frequency

        # Initialize Q-Networks
        self.qnetwork_main = NAF(state_size, action_size, layer_size, seed, device).to(device)
        self.qnetwork_target = NAF(state_size, action_size, layer_size, seed, device).to(device)

        # Define Adam as optimizer
        self.optimizer = optim.Adam(self.qnetwork_main.parameters(), lr=learning_rate)

        # Initialize Replay memory
        self.memory = ReplayBuffer(buffer_size, batch_size, self.device, seed)

        # Initialize update time step counter (for updating every {update_freq} steps)
        self.update_t_step = 0

    def initialize_pretrained_agent_from_episode(self, episode: int) -> None:
        """
        Loads the previously trained weights into the main and target neural networks.
        The pretrained weights are retrieved from the checkpoints generated on a training execution, so
        the episode provided must be present in the checkpoints/ folder.
        Args:
            episode: Episode from which to retrieve the pretrained weights.
        Raises:
            MissingWeightsFile: The weights.p file is not present in the checkpoints/{episode}/ folder provided.
        """
        # Check if file is present in checkpoints/{episode}/ directory
        if not os.path.isfile(f'checkpoints/{episode}/weights.p'):
            raise MissingWeightsFile

        logger.debug(f'Loading naf_components weights from trained naf_components on episode {episode}...')
        self.qnetwork_main.load_state_dict(torch.load(f'checkpoints/{episode}/weights.p'))
        self.qnetwork_target.load_state_dict(torch.load(f'checkpoints/{episode}/weights.p'))
        logger.info(f'Loaded weights from trained naf_components on episode {episode}')

    def initialize_pretrained_agent_from_weights_file(self, weights_path: str) -> None:
        """
        Loads the previously trained weights into the main and target neural networks.
        The pretrained weights are retrieved from a .p file containing the weights, located in
        the {weights_path} path.
        Args:
            weights_path: Path to the .p file containing the network's weights.
        Raises:
            MissingWeightsFile: The file path provided does not exist.
        """
        # Check if file is present
        if not os.path.isfile(weights_path):
            raise MissingWeightsFile

        logger.debug('Loading naf_components weights from trained naf_components...')
        self.qnetwork_main.load_state_dict(torch.load(weights_path))
        self.qnetwork_target.load_state_dict(torch.load(weights_path))
        logger.info('Loaded pre-trained weights for the NN')

    def step(self, state: NDArray, action: NDArray, reward: float, next_state: NDArray, done: int) -> None:
        """
        Stores in the ReplayBuffer the new experience composed by the parameters received,
        and learns only if the Buffer contains enough experiences to fill a batch. The
        learning will occur if the update frequency {update_freq} is reached, in which case it
        will learn {num_updates} times.
        Args:
            state: Current state.
            action: Action performed from state {state}.
            reward: Reward obtained after performing action {action} from state {state}.
            next_state: New state reached after performing action {action} from state {state}.
            done: Integer (0 or 1) indicating whether a terminal state have been reached.
        """

        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)

        # Learning will be performed every {update_freq} time-steps.
        self.update_t_step = (self.update_t_step + 1) % self.update_freq # Update time step counter
        if self.update_t_step == 0:
            # If enough samples are available in memory, get random subset and learn
            if len(self.memory) > self.batch_size:
                for _ in range(self.num_updates):
                    # Pick random batch of experiences from memory
                    experiences = self.memory.sample()

                    # Learn from experiences and get loss
                    self.learn(experiences)

    def act(self, state: NDArray) -> NDArray:
        """
        Extracts the action which maximizes the Q-Function, by getting the output of the mu layer
        of the main neural network.
        Args:
            state: Current state from which to pick the best action.
        Returns:
            Action which maximizes Q-Function.
        """
        state = torch.from_numpy(state).float().to(self.device)

        # Set evaluation mode on naf_components for obtaining a prediction
        self.qnetwork_main.eval()
        with torch.no_grad():
            # Get the action with maximum Q-Value from the local network
            action, _, _ = self.qnetwork_main(state.unsqueeze(0))

        # Set training mode on naf_components for future use
        self.qnetwork_main.train()

        return action.cpu().squeeze().numpy()

    def learn(self, experiences: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]) -> None:
        """
        Calculate the Q-Function estimate from the main neural network, the Target value from
        the target neural network, and calculate the loss with both values, all by feeding the received
        batch of experience tuples to both networks. After loss is calculated, backpropagation is performed on the
        main network from the given loss, so that the weights of the main network are updated.
        Args:
            experiences: Tuple of five elements, where each element is a torch.Tensor of length {batch_size}.
        """
        # Set gradients of all optimized torch Tensors to zero
        self.optimizer.zero_grad()
        states, actions, rewards, next_states, dones = experiences

        # Get the Value Function for the next state from target naf_components (no_grad() disables gradient calculation)
        with torch.no_grad():
            _, _, V_ = self.qnetwork_target(next_states)

        # Compute the target Value Functions for the given experiences.
        # The target value is calculated as target_val = r + gamma * V(s')
        target_values = rewards + (self.gamma * V_)

        # Compute the expected Value Function from main network
        _, q_estimate, _ = self.qnetwork_main(states, actions)

        # Compute loss between target value and expected Q value
        loss = F.mse_loss(q_estimate, target_values)

        # Perform backpropagation for minimizing loss
        loss.backward()
        clip_grad_norm_(self.qnetwork_main.parameters(), 1)
        self.optimizer.step()

        # Update the target network softly with the local one
        self.soft_update(self.qnetwork_main, self.qnetwork_target)

        # return loss.detach().cpu().numpy()

    def soft_update(self, main_nn: NAF, target_nn: NAF) -> None:
        """
        Soft update naf_components parameters following this formula:\n
        θ_target = τ*θ_local + (1 - τ)*θ_target
        Args:
            main_nn: Main torch neural network.
            target_nn: Target torch neural network.
        """
        for target_param, main_param in zip(target_nn.parameters(), main_nn.parameters()):
            target_param.data.copy_(self.tau * main_param.data + (1. - self.tau) * target_param.data)

    def run(self, frames: int = 1000, episodes: int = 1000, verbose: bool = True) -> Dict[int, Tuple[float, int]]:
        """
        Execute training flow of the NAF algorithm on the given environment.
        Args:
            frames: Number of maximum frames or timesteps per episode.
            episodes: Number of episodes required to terminate the training.
            verbose: Boolean indicating whether many or few logs are shown.
        Returns:
            Returns the score history generated along the training.
        """
        logger.info('Training started')
        # Initialize 'scores' dictionary to store rewards and timesteps executed for each episode
        scores = {episode: (0, 0) for episode in range(1, episodes + 1)}

        # Iterate through every episode
        for episode in range(episodes):
            logger.info(f'Running Episode {episode + 1}')
            start = time.time() # Timer to measure execution time per episode
            state = self.environment.reset(verbose)
            score, mean = 0, list()

            for frame in range(1, frames + 1):
                if verbose: logger.info(f'Running frame {frame} in episode {episode + 1}')

                # Pick action according to current state
                if verbose: logger.info(f'Current State: {state}')
                action = self.act(state)
                if verbose: logger.info(f'Action chosen for the given state is: {action}')

                # Perform action on environment and get new state and reward
                next_state, reward, done = self.environment.step(action)

                # Save the experience in the ReplayBuffer, and learn from previous experiences if applicable
                self.step(state, action, reward, next_state, done)

                state = next_state # Update state to next state
                score += reward
                mean.append(reward)

                if verbose: logger.info(f'Reward: {reward} - Cumulative reward: {score}\n')

                if done:
                    break

            # Updates scores history
            scores[episode + 1] = (score, frame) # save most recent score and last frame
            logger.info(f'Reward: {score}')
            logger.info(f'Number of frames: {frame}')
            # NOTE(review): divides by the frame budget, not by the number of frames actually run,
            # so episodes that terminate early report a smaller mean
            logger.info(f'Mean of rewards on this episode: {sum(mean) / frames}')
            logger.info(f'Time taken for this episode: {round(time.time() - start, 3)} secs\n')

            # Save the episode's performance if it is a checkpoint episode
            if (episode + 1) % self.checkpoint_frequency == 0:
                # Create parent directory for current episode
                os.makedirs(f'checkpoints/{episode + 1}/', exist_ok=True)
                # Save naf_components weights
                torch.save(self.qnetwork_main.state_dict(), f'checkpoints/{episode + 1}/weights.p')
                # Save naf_components's performance metrics
                with open(f'checkpoints/{episode + 1}/scores.txt', 'w') as f:
                    f.write(json.dumps(scores))

        torch.save(self.qnetwork_main.state_dict(), self.MODEL_PATH)
        logger.info(f'Model has been successfully saved in {self.MODEL_PATH}')

        return scores
293 |
--------------------------------------------------------------------------------
/robotic_manipulator_rloa/naf_components/naf_neural_network.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple, Any
2 |
3 | import torch
4 | from torch import nn
5 | from torch.distributions import MultivariateNormal
6 |
7 |
8 | class NAF(nn.Module):
9 |
    def __init__(self, state_size: int, action_size: int, layer_size: int, seed: int, device: torch.device) -> None:
        """
        Model to be used in the NAF algorithm. Network Architecture:\n
        - Common network\n
            - Linear + BatchNormalization (input_shape, layer_size)\n
            - Linear + BatchNormalization (layer_size, layer_size)\n

        - Output for mu network (used for calculating A)\n
            - Linear (layer_size, action_size)\n

        - Output for V network (used for calculating Q = A + V)\n
            - Linear (layer_size, 1)\n

        - Output for L network (used for calculating P = L . Lt)\n
            - Linear (layer_size, action_size*(action_size+1)/2)\n
        Args:
            state_size: Dimension of a state.
            action_size: Dimension of an action.
            layer_size: Size of the hidden layers of the neural network.
            seed: Random seed.
            device: Device used (CPU or CUDA).
        """
        super(NAF, self).__init__()
        self.seed = torch.manual_seed(seed)
        self.state_size = state_size
        self.action_size = action_size
        self.device = device

        # DEFINE THE MODEL

        # Define the first NN hidden layer + BatchNormalization
        self.input_layer = nn.Linear(in_features=self.state_size, out_features=layer_size)
        self.bn1 = nn.BatchNorm1d(layer_size)

        # Define the second NN hidden layer + BatchNormalization
        self.hidden_layer = nn.Linear(in_features=layer_size, out_features=layer_size)
        self.bn2 = nn.BatchNorm1d(layer_size)

        # Define the output layer for the mu Network
        self.action_values = nn.Linear(in_features=layer_size, out_features=action_size)
        # Define the output layer for the V Network
        self.value = nn.Linear(in_features=layer_size, out_features=1)
        # Define the output layer for the L Network
        # (n*(n+1)/2 outputs: one entry per element of a triangular action_size x action_size matrix)
        self.matrix_entries = nn.Linear(in_features=layer_size,
                                        out_features=int(self.action_size * (self.action_size + 1) / 2))
55 |
56 | def forward(self,
57 | input_: torch.Tensor,
58 | action: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[Any], Any]:
59 | """
60 | Forward propagation.
61 | It feeds the NN with the input, and gets the output for the mu, V and L networks.\n
62 | - Output from the L network is used to create the P matrix.\n
63 | - Output from the V network is used to calculate the Q value: Q = A + V\n
64 | - Output from the mu network is used to calculate A. The action output of mu nn is considered
65 | the action that maximizes Q-function.
66 | Args:
67 | input_: Input for the neural network's input layer.
68 | action: Current action, used for calculating the Q-Function estimate.
69 | Returns:
70 | Returns a tuple containing the action which maximizes the Q-Function, the
71 | Q-Function estimate and the Value Function.
72 | """
73 | # ============ FEED INPUT DATA TO THE NEURAL NETWORK =================================
74 |
75 | # Feed the input to the INPUT_LAYER and apply ReLu activation function (+ BatchNorm)
76 | x = torch.relu(self.bn1(self.input_layer(input_)))
77 | # Feed the output of INPUT_LAYER to the HIDDEN_LAYER layer and apply ReLu activation function (+ BatchNorm)
78 | x = torch.relu(self.bn2(self.hidden_layer(x)))
79 |
80 | # Feed the output of HIDDEN_LAYER to the mu layer and apply tanh activation function
81 | action_value = torch.tanh(self.action_values(x))
82 |
83 | # Feed the output of HIDDEN_LAYER to the L layer and apply tanh activation function
84 | matrix_entries = torch.tanh(self.matrix_entries(x))
85 |
86 | # Feed the output of HIDDEN_LAYER to the V layer
87 | V = self.value(x)
88 |
89 | # Modifies the output of the mu layer by unsqueezing it (all tensor as a 1D vector)
90 | action_value = action_value.unsqueeze(-1)
91 |
92 | # ============ CREATE L MATRIX from the outputs of the L layer =======================
93 |
94 | # Create lower-triangular matrix, size: (n_samples, action_size, action_size)
95 | L = torch.zeros((input_.shape[0], self.action_size, self.action_size)).to(self.device)
96 | # Get lower triagular indices (returns list of 2 elems, where the first row contains row coordinates
97 | # of all indices and the second row contains column coordinates)
98 | lower_tri_indices = torch.tril_indices(row=self.action_size, col=self.action_size, offset=0)
99 | # Fill matrix with the outputs of the L layer
100 | L[:, lower_tri_indices[0], lower_tri_indices[1]] = matrix_entries
101 | # Raise the diagonal elements of the matrix to the square
102 | L.diagonal(dim1=1, dim2=2).exp_()
103 | # Calculate state-dependent, positive-definite square matrix P
104 | P = L * L.transpose(2, 1)
105 |
106 | # ============================ CALCULATE Q-VALUE ===================================== #
107 |
108 | Q = None
109 | if action is not None:
110 | # Calculate Advantage Function estimate
111 | A = (-0.5 * torch.matmul(torch.matmul((action.unsqueeze(-1) - action_value).transpose(2, 1), P),
112 | (action.unsqueeze(-1) - action_value))).squeeze(-1)
113 |
114 | # Calculate Q-values
115 | Q = A + V
116 |
117 | # =========================== ADD NOISE TO ACTION ==================================== #
118 |
119 | dist = MultivariateNormal(action_value.squeeze(-1), torch.inverse(P))
120 | action = dist.sample()
121 | action = torch.clamp(action, min=-1, max=1)
122 |
123 | return action, Q, V
124 |
--------------------------------------------------------------------------------
/robotic_manipulator_rloa/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JavierMtz5/robotic_manipulator_rloa/263aab50f70deeb48b8f0bc28041c86b76453816/robotic_manipulator_rloa/utils/__init__.py
--------------------------------------------------------------------------------
/robotic_manipulator_rloa/utils/collision_detector.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import List
3 |
4 | import numpy as np
5 | import pybullet as p
6 | from numpy.typing import NDArray
7 |
8 |
@dataclass
class CollisionObject:
    """
    Dataclass which contains the UID of the manipulator/body and the link number of the joint from which to
    calculate distances to other bodies.
    """
    # PyBullet body unique ID, as passed to getClosestPoints as bodyA.
    # NOTE(review): PyBullet body UIDs are ints, so the previous `str` annotation
    # looked wrong — confirm against callers.
    body: int
    # Link index within the body (passed as linkIndexA to getClosestPoints).
    link: int
17 |
18 |
class CollisionDetector:
    """Computes shortest distances from one manipulator link to obstacles or to other links."""

    def __init__(self, collision_object: CollisionObject, obstacle_ids: List[int]):
        """
        Calculates distances between bodies' joints.
        Args:
            collision_object: CollisionObject instance, which indicates the body/joint from
                which to calculate distances/collisions.
            obstacle_ids: Obstacle body UIDs (PyBullet UIDs are ints; the previous List[str]
                annotation was wrong). Distances are calculated from the joint/body given in
                the "collision_object" parameter to the "obstacle_ids" bodies.
        """
        self.obstacles = obstacle_ids
        self.collision_object = collision_object

    @staticmethod
    def _min_distance(closest_points: list, max_distance: float) -> float:
        """
        Return the smallest contact distance from a PyBullet getClosestPoints() result.
        If the query returned no points (the bodies are farther apart than the query
        distance), saturate at max_distance.
        """
        if len(closest_points) == 0:
            return max_distance
        # Index 8 of each contact-point tuple is the contact distance
        return np.min([point[8] for point in closest_points])

    def compute_distances(self, max_distance: float = 10.0) -> NDArray:
        """
        Compute the closest distances from the joint given by the CollisionObject instance in self.collision_object
        to the bodies defined in self.obstacles.
        Args:
            max_distance: Bodies farther apart than this distance are not queried by PyBullet, the return value
                for the distance between such bodies will be max_distance.
        Returns:
            A numpy array of distances, one per pair of collision objects.
        """
        distances = list()
        for obstacle in self.obstacles:
            # Shortest distances between the collision-object and the given obstacle
            closest_points = p.getClosestPoints(
                self.collision_object.body,
                obstacle,
                distance=max_distance,
                linkIndexA=self.collision_object.link
            )
            distances.append(self._min_distance(closest_points, max_distance))

        return np.array(distances)

    def compute_collisions_in_manipulator(self, affected_joints: List[int], max_distance: float = 10.) -> NDArray:
        """
        Compute collisions between manipulator's parts.
        Args:
            affected_joints: Joints to consider when calculating distances.
            max_distance: Maximum distance to be considered. Distances further than this will be ignored, and
                the "max_distance" value will be returned.
        Returns:
            Array where each element corresponds to the distances from a given joint to the other joints.
        """
        distances = list()
        for joint_ind in affected_joints:

            # The link itself and its immediate neighbours are always in contact, so skip them
            if abs(joint_ind - self.collision_object.link) <= 1:
                continue  # pragma: no cover

            # Shortest distances between the two links of the same body
            closest_points = p.getClosestPoints(
                self.collision_object.body,
                self.collision_object.body,
                distance=max_distance,
                linkIndexA=self.collision_object.link,
                linkIndexB=joint_ind
            )
            distances.append(self._min_distance(closest_points, max_distance))

        return np.array(distances)
99 |
--------------------------------------------------------------------------------
/robotic_manipulator_rloa/utils/exceptions.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Optional
4 |
5 |
class FrameworkException(Exception):
    """Base exception for the framework; every custom exception must extend it."""

    def __init__(self, message: str) -> None:
        """
        Creates new FrameworkException. Base class that must be extended by any custom exception.
        Args:
            message: Info about the exception.
        """
        # Cooperative super() call (instead of Exception.__init__) so the
        # exception args are populated idiomatically.
        super().__init__(message)
        self.message = message

    def __str__(self) -> str:
        """ Returns the string representation of the object: '<ClassName>: <message>' """
        return f'{self.__class__.__name__}: {self.message}'

    def set_message(self, value: str) -> FrameworkException:
        """
        Set the message to be printed on terminal.
        Args:
            value: Message with info about exception.
        Returns:
            This exception instance (fluent API).
        """
        self.message = value
        return self
31 |
32 |
class InvalidManipulatorFile(FrameworkException):
    """Raised when the given URDF/SDF file cannot be loaded by PyBullet's loadURDF/loadSDF."""

    message = 'The URDF/SDF file received is not valid'

    def __init__(self, message: Optional[str] = None) -> None:
        """A truthy argument overrides the class-level default message."""
        if message:
            self.message = message
        super().__init__(self.message)
44 |
45 |
class InvalidHyperParameter(FrameworkException):
    """Raised when a hyper-parameter is assigned an invalid value."""

    message = 'The hyperparameter received is not valid'

    def __init__(self, message: Optional[str] = None) -> None:
        """A truthy argument overrides the class-level default message."""
        if message:
            self.message = message
        super().__init__(self.message)
56 |
57 |
class InvalidEnvironmentParameter(FrameworkException):
    """Raised when the Environment is initialized with one or more invalid parameters."""

    message = 'The Environment parameter received is not valid'

    def __init__(self, message: Optional[str] = None) -> None:
        """A truthy argument overrides the class-level default message."""
        if message:
            self.message = message
        super().__init__(self.message)
68 |
69 |
class InvalidNAFAgentParameter(FrameworkException):
    """Raised when the NAFAgent is initialized with one or more invalid parameters."""

    message = 'The NAF Agent parameter received is not valid'

    def __init__(self, message: Optional[str] = None) -> None:
        """A truthy argument overrides the class-level default message."""
        if message:
            self.message = message
        super().__init__(self.message)
80 |
81 |
class EnvironmentNotInitialized(FrameworkException):
    """Raised when a method requiring an initialized Environment is called before initialization."""

    message = 'The Environment is not yet initialized. The environment can be initialized via the ' \
              'initialize_environment() method'

    def __init__(self, message: Optional[str] = None) -> None:
        """A truthy argument overrides the class-level default message."""
        if message:
            self.message = message
        super().__init__(self.message)
94 |
95 |
class NAFAgentNotInitialized(FrameworkException):
    """Raised when a method requiring an initialized NAFAgent is called before initialization."""

    message = 'The NAF Agent is not yet initialized. The agent can be initialized via the ' \
              'initialize_naf_agent() method'

    def __init__(self, message: Optional[str] = None) -> None:
        """A truthy argument overrides the class-level default message."""
        if message:
            self.message = message
        super().__init__(self.message)
108 |
109 |
class MissingWeightsFile(FrameworkException):
    """Raised when pretrained weights are loaded from a path that does not exist."""

    message = 'The weight file provided does not exist'

    def __init__(self, message: Optional[str] = None) -> None:
        """A truthy argument overrides the class-level default message."""
        if message:
            self.message = message
        super().__init__(self.message)
120 |
121 |
class ConfigurationIncomplete(FrameworkException):
    """Raised when run_training() is called while the Environment and/or NAFAgent are uninitialized."""

    message = 'The configuration for the training is incomplete. Either the Environment, the ' \
              'NAF Agent or both are not yet initialized. The environment can be initialized via the ' \
              'initialize_environment() method, and the agent can be initialized via the ' \
              'initialize_naf_agent() method'

    def __init__(self, message: Optional[str] = None) -> None:
        """A truthy argument overrides the class-level default message."""
        if message:
            self.message = message
        super().__init__(self.message)
136 |
--------------------------------------------------------------------------------
/robotic_manipulator_rloa/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from datetime import datetime
3 | from logging import LogRecord
4 | from logging.config import dictConfig
5 | from logging.handlers import RotatingFileHandler
6 |
7 |
class Logger:  # pylint: disable=too-many-instance-attributes
    """ Stores and processes the logs """

    @staticmethod
    def set_logger_setup() -> None:
        """
        Sets the logger setup with a predefined configuration: a colorized console
        handler (via dictConfig) plus a rotating file handler writing to
        'training_logs.log'.
        """

        log_config_dict = Logger.generate_logging_config_dict()
        dictConfig(log_config_dict)
        rotating_file_handler = RotatingFileHandler(filename='training_logs.log', mode='a', maxBytes=50000000,
                                                    backupCount=10, encoding='utf-8')
        # BUG FIX: the previous format string interpolated datetime.utcnow() ONCE
        # at setup time, so every line in the log file carried the same (startup)
        # timestamp. Use the record's own %(asctime)s so each line is stamped when
        # it is actually logged.
        rotating_file_handler.setFormatter(logging.Formatter(
            '"%(levelname)s"|"%(asctime)s"|%(message)s', datefmt='%Y-%m-%dT%H:%M:%S'
        ))
        rotating_file_handler.setLevel(logging.INFO)
        logger = get_global_logger()
        logger.addHandler(rotating_file_handler)
        logger.setLevel(logging.INFO)  # 20 == logging.INFO

    @staticmethod
    def generate_logging_config_dict() -> dict:
        """
        Generates the configuration dictionary that is used to configure the logger
        (a single root logger with a stdout StreamHandler using CustomFormatter).
        Returns:
            Configuration dictionary.
        """
        return {
            'version': 1,
            'disable_existing_loggers': False,
            'formatters': {
                'custom_formatter': {
                    '()': CustomFormatter,
                    'dateformat': '%Y-%m-%dT%H:%M:%S.%06d%z'
                },
            },
            'handlers': {
                'debug_console_handler': {
                    'level': 'NOTSET',
                    'formatter': 'custom_formatter',
                    'class': 'logging.StreamHandler',
                    'stream': 'ext://sys.stdout',
                }
            },
            'loggers': {
                '': {
                    'handlers': ['debug_console_handler'],
                    'level': 'NOTSET',
                },
            }
        }
60 |
61 |
def get_global_logger() -> logging.Logger:
    """
    Getter for the framework-wide logger.

    Returns:
        The module-level logging.Logger instance used throughout the framework.
    """
    global_logger = logging.getLogger(__name__)
    return global_logger
69 |
70 |
class CustomFormatter(logging.Formatter):
    """Console formatter that colorizes each record according to its severity."""

    # ANSI escape codes per log level, hoisted to class level so the mapping is
    # not rebuilt on every format() call
    _COLORS = {
        logging.DEBUG: "\033[32;20m",    # green
        logging.INFO: "\033[38;20m",     # grey
        logging.WARNING: "\033[33;20m",  # yellow
        logging.ERROR: "\033[31;20m",    # red
        logging.CRITICAL: "\033[31;1m",  # bold red
    }
    _RESET = "\033[0m"

    def __init__(self, dateformat: "str | None" = None):
        """
        CustomFormatter for the logger.
        Args:
            dateformat: strftime-style date format forwarded to logging.Formatter.
        """
        super().__init__()
        self.dateformat = dateformat

    def format(self, record: LogRecord) -> str:
        """
        Formats the provided LogRecord instance, wrapped in the level's ANSI color.
        Returns:
            Formatted LogRecord as string.
        """
        # NOTE(review): the timestamp is the wall-clock time at formatting (as in
        # the original implementation), not record.created.
        base_format = '[%(levelname)-8s] - {datetime} - %(message)s'.format(
            datetime=datetime.now().astimezone().strftime('%Y-%m-%dT%H:%M:%S.%f%z')
        )

        # Unknown levels get no color and fall back to logging.Formatter's default format
        color = self._COLORS.get(record.levelno)
        log_format = color + base_format + self._RESET if color is not None else None

        formatter = logging.Formatter(log_format, datefmt=self.dateformat)
        return formatter.format(record)
109 |
--------------------------------------------------------------------------------
/robotic_manipulator_rloa/utils/replay_buffer.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections import deque, namedtuple
3 | from typing import Tuple
4 |
5 | import numpy as np
6 | import torch
7 | from numpy.typing import NDArray
8 |
9 |
10 | # deque is a Doubly Ended Queue, provides O(1) complexity for pop and append actions
11 | # namedtuple is a tuple that can be accessed by both its index and attributes
12 |
13 |
class ReplayBuffer:
    """Fixed-size experience replay buffer backed by a deque."""

    def __init__(self, buffer_size: int, batch_size: int, device: torch.device, seed: int):
        """
        Buffer to store experience tuples. Each experience has the following structure:
        (state, action, reward, next_state, done)
        Args:
            buffer_size: Maximum size for the buffer. Higher buffer size imply higher RAM consumption.
            batch_size: Number of experiences to be retrieved from the ReplayBuffer per batch.
            device: CUDA device.
            seed: Random seed.
        """
        self.device = device
        # deque evicts the oldest experience once buffer_size is reached
        self.memory = deque(maxlen=buffer_size)
        self.batch_size = batch_size
        self.experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
        random.seed(seed)

    def add(self, state: NDArray, action: NDArray, reward: float, next_state: NDArray, done: int) -> None:
        """
        Add a new experience to the Replay Buffer.
        Args:
            state: NDArray of the current state.
            action: NDArray of the action taken from state {state}.
            reward: Reward obtained after performing action {action} from state {state}.
            next_state: NDArray of the state reached after performing action {action} from state {state}.
            done: Integer (0 or 1) indicating whether the next_state is a terminal state.
        """
        exp = self.experience(state, action, reward, next_state, done)
        self.memory.append(exp)

    def sample(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Randomly sample a batch of experiences from memory.
        Returns:
            Tuple of 5 elements, which are (states, actions, rewards, next_states, dones). Each element
            in the tuple is a torch Tensor composed of {batch_size} items.
        """
        # random.sample never yields None, so no per-element None filtering is needed
        experiences = random.sample(self.memory, k=self.batch_size)

        # Some gym versions return (observation, info) from reset(); unwrap those tuples
        states = torch.from_numpy(
            np.stack([e.state if not isinstance(e.state, tuple) else e.state[0] for e in experiences])).float().to(
            self.device)
        # BUG FIX: actions in NAF are continuous; the previous .long() cast
        # silently truncated them to integers. Keep them as floats.
        actions = torch.from_numpy(np.vstack([e.action for e in experiences])).float().to(self.device)
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences])).float().to(self.device)
        next_states = torch.from_numpy(np.stack([e.next_state for e in experiences])).float().to(self.device)
        dones = torch.from_numpy(np.vstack([e.done for e in experiences]).astype(np.uint8)).float().to(self.device)

        return states, actions, rewards, next_states, dones

    def __len__(self) -> int:
        """
        Return the current size of the Replay Buffer
        Returns:
            Size of Replay Buffer
        """
        return len(self.memory)
76 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JavierMtz5/robotic_manipulator_rloa/263aab50f70deeb48b8f0bc28041c86b76453816/tests/__init__.py
--------------------------------------------------------------------------------
/tests/robotic_manipulator_rloa/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JavierMtz5/robotic_manipulator_rloa/263aab50f70deeb48b8f0bc28041c86b76453816/tests/robotic_manipulator_rloa/__init__.py
--------------------------------------------------------------------------------
/tests/robotic_manipulator_rloa/environment/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JavierMtz5/robotic_manipulator_rloa/263aab50f70deeb48b8f0bc28041c86b76453816/tests/robotic_manipulator_rloa/environment/__init__.py
--------------------------------------------------------------------------------
/tests/robotic_manipulator_rloa/naf_components/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JavierMtz5/robotic_manipulator_rloa/263aab50f70deeb48b8f0bc28041c86b76453816/tests/robotic_manipulator_rloa/naf_components/__init__.py
--------------------------------------------------------------------------------
/tests/robotic_manipulator_rloa/naf_components/test_naf_algorithm.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from mock import MagicMock, patch, mock_open
4 | import pytest
5 |
6 | from robotic_manipulator_rloa.naf_components.naf_algorithm import NAFAgent
7 | from robotic_manipulator_rloa.utils.exceptions import MissingWeightsFile
8 |
9 |
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.time')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.clip_grad_norm_')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.F')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.torch')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.logger')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.Environment')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.os')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.random')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.NAF')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.ReplayBuffer')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.optim.Adam')
def test_naf_agent(mock_optim: MagicMock,
                   mock_replay_buffer: MagicMock,
                   mock_naf: MagicMock,
                   mock_random: MagicMock,
                   mock_os: MagicMock,
                   mock_environment: MagicMock,
                   mock_logger: MagicMock,
                   mock_torch: MagicMock,
                   mock_F: MagicMock,
                   mock_clip_grad_norm_: MagicMock,
                   mock_time: MagicMock) -> None:
    """
    Test for the NAFAgent class constructor and the pretrained-agent loaders.

    NOTE: mock arguments are received in reverse order of the @patch decorators,
    because patches are applied bottom-up.
    """
    # ================== Constructor ==========================================
    naf_agent = NAFAgent(environment=mock_environment,
                         state_size=10,
                         action_size=5,
                         layer_size=128,
                         batch_size=64,
                         buffer_size=100000,
                         learning_rate=0.001,
                         tau=0.001,
                         gamma=0.001,
                         update_freq=1,
                         num_updates=1,
                         checkpoint_frequency=1,
                         device='cpu',
                         seed=0)

    # Resolve the mock object graph the constructor is expected to have built:
    # NAF(...).to(device) -> network, network.parameters() -> optimizer args, etc.
    mock_os.makedirs.return_value = None
    fake_naf_instance = mock_naf.return_value
    fake_torch_naf = fake_naf_instance.to.return_value
    fake_network_params = fake_torch_naf.parameters.return_value
    fake_optimizer = mock_optim.return_value
    fake_replay_buffer = mock_replay_buffer.return_value

    assert naf_agent.environment == mock_environment
    assert naf_agent.state_size == 10
    assert naf_agent.action_size == 5
    assert naf_agent.layer_size == 128
    assert naf_agent.buffer_size == 100000
    assert naf_agent.learning_rate == 0.001
    mock_random.seed.assert_any_call(0)
    assert naf_agent.device == 'cpu'
    assert naf_agent.tau == 0.001
    assert naf_agent.gamma == 0.001
    assert naf_agent.update_freq == 1
    assert naf_agent.num_updates == 1
    assert naf_agent.batch_size == 64
    assert naf_agent.checkpoint_frequency == 1
    # Main and target Q-networks are both created from the same (mocked) NAF class
    assert naf_agent.qnetwork_main == fake_torch_naf
    assert naf_agent.qnetwork_target == fake_torch_naf
    mock_naf.assert_any_call(10, 5, 128, 0, 'cpu')
    assert naf_agent.optimizer == fake_optimizer
    mock_optim.assert_any_call(fake_network_params, lr=0.001)
    assert naf_agent.memory == fake_replay_buffer
    mock_replay_buffer.assert_any_call(100000, 64, 'cpu', 0)
    assert naf_agent.update_t_step == 0

    # ================== initialize_pretrained_agent_from_episode() ===========
    fake_torch_load = mock_torch.load.return_value
    mock_os.path.isfile.return_value = True
    naf_agent.initialize_pretrained_agent_from_episode(0)
    fake_torch_naf.load_state_dict.assert_any_call(fake_torch_load)
    mock_torch.load.assert_any_call('checkpoints/0/weights.p')

    # ================== initialize_pretrained_agent_from_episode() when file is not present
    mock_os.path.isfile.return_value = False
    with pytest.raises(MissingWeightsFile):
        naf_agent.initialize_pretrained_agent_from_episode(0)

    # ================== initialize_pretrained_agent_from_weights_file() ======
    fake_torch_load = mock_torch.load.return_value
    mock_os.path.isfile.return_value = True
    naf_agent.initialize_pretrained_agent_from_weights_file('weights.p')
    fake_torch_naf.load_state_dict.assert_any_call(fake_torch_load)
    mock_torch.load.assert_any_call('weights.p')

    # ================== initialize_pretrained_agent_from_weights_file() when file is not present
    mock_os.path.isfile.return_value = False
    with pytest.raises(MissingWeightsFile):
        naf_agent.initialize_pretrained_agent_from_weights_file('weights.p')
104 |
105 |
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.ReplayBuffer')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.Environment')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.NAFAgent.learn')
def test_naf_agent__step(mock_learn: MagicMock,
                         mock_environment: MagicMock,
                         mock_replay_buffer: MagicMock) -> None:
    """Test for the step() method of the NAFAgent class"""
    # ================== step() ===============================================
    naf_agent = NAFAgent(environment=mock_environment,
                         state_size=10,
                         action_size=5,
                         layer_size=128,
                         batch_size=64,
                         buffer_size=100000,
                         learning_rate=0.001,
                         tau=0.001,
                         gamma=0.001,
                         update_freq=1,
                         num_updates=1,
                         checkpoint_frequency=1,
                         device='cpu',
                         seed=0)
    # Buffer reports 100 stored experiences (> batch_size=64), so step() should
    # trigger a sample + learn cycle (update_freq=1, num_updates=1)
    fake_replay_buffer = mock_replay_buffer.return_value
    fake_replay_buffer.__len__.return_value = 100
    fake_experiences = MagicMock()
    fake_replay_buffer.sample.return_value = fake_experiences

    naf_agent.step('state', 'action', 'reward', 'next_state', 'done')

    # The experience must be stored, and a learning step performed on a sampled batch
    fake_replay_buffer.add.assert_any_call('state', 'action', 'reward', 'next_state', 'done')
    assert naf_agent.update_t_step == 0
    fake_replay_buffer.sample.assert_any_call()
    mock_learn.assert_any_call(fake_experiences)
139 |
140 |
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.Environment')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.torch')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.NAF')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.optim.Adam')
def test_naf_agent__act(mock_optimizer: MagicMock,
                        mock_naf: MagicMock,
                        mock_torch: MagicMock,
                        mock_environment: MagicMock) -> None:
    """Test for the act() method of the NAFAgent class"""
    # ================== act() ================================================
    naf_agent = NAFAgent(environment=mock_environment,
                         state_size=10,
                         action_size=5,
                         layer_size=128,
                         batch_size=64,
                         buffer_size=100000,
                         learning_rate=0.001,
                         tau=0.001,
                         gamma=0.001,
                         update_freq=1,
                         num_updates=1,
                         checkpoint_frequency=1,
                         device='cpu',
                         seed=0)
    # Mirror the tensor-conversion chain performed inside act():
    # torch.from_numpy(state).float().to(device).unsqueeze(0)
    fake_state_from_numpy = mock_torch.from_numpy.return_value
    fake_state_float = fake_state_from_numpy.float.return_value
    fake_state_to = fake_state_float.to.return_value
    fake_state_unsqueezed = fake_state_to.unsqueeze.return_value
    fake_action = MagicMock()
    fake_naf_instance = mock_naf.return_value
    fake_torch_naf = fake_naf_instance.to.return_value
    # The network forward returns (action, Q, V); act() only uses the action
    fake_torch_naf.return_value = (fake_action, None, None)
    fake_action_cpu = fake_action.cpu.return_value
    fake_action_squeeze = fake_action_cpu.squeeze.return_value
    fake_action_numpy = fake_action_squeeze.numpy.return_value
    fake_no_grad = MagicMock(__enter__=MagicMock())
    mock_torch.no_grad.return_value = fake_no_grad

    assert naf_agent.act('state') == fake_action_numpy

    # The network must be switched to eval mode for inference and back to train mode
    fake_torch_naf.eval.assert_any_call()
    mock_torch.no_grad.assert_any_call()
    fake_torch_naf.assert_any_call(fake_state_unsqueezed)
    fake_torch_naf.train.assert_any_call()
185 |
186 |
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.Environment')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.torch')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.NAF')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.optim.Adam')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.F')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.clip_grad_norm_')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.NAFAgent.soft_update')
def test_naf_agent__learn(mock_soft_update: MagicMock,
                          mock_clip_grad_norm_: MagicMock,
                          mock_F: MagicMock,
                          mock_optimizer: MagicMock,
                          mock_naf: MagicMock,
                          mock_torch: MagicMock,
                          mock_environment: MagicMock) -> None:
    """Test for the learn() method in the NAFAgent class"""
    # ================== learn() ==============================================
    naf_agent = NAFAgent(environment=mock_environment,
                         state_size=10,
                         action_size=5,
                         layer_size=128,
                         batch_size=64,
                         buffer_size=100000,
                         learning_rate=0.001,
                         tau=0.001,
                         gamma=0.001,
                         update_freq=1,
                         num_updates=1,
                         checkpoint_frequency=1,
                         device='cpu',
                         seed=0)
    fake_states, fake_actions, fake_rewards, fake_next_states, fake_dones = \
        MagicMock(), MagicMock(), 1, MagicMock(), MagicMock()
    fake_experiences = (fake_states, fake_actions, fake_rewards, fake_next_states, fake_dones)
    fake_v, fake_q_estimate = 1, 1
    fake_naf_instance = mock_naf.return_value
    fake_torch_naf = fake_naf_instance.to.return_value
    fake_network_params = fake_torch_naf.parameters.return_value
    # Network forward returns (action, Q, V); learn() uses Q for the estimate and V for the target
    fake_torch_naf.return_value = (None, fake_q_estimate, fake_v)
    fake_optimizer = mock_optimizer.return_value
    fake_loss = mock_F.mse_loss.return_value

    naf_agent.learn(fake_experiences)

    fake_optimizer.zero_grad.assert_any_call()
    fake_torch_naf.assert_any_call(fake_next_states)
    fake_torch_naf.assert_any_call(fake_states, fake_actions)
    # Expected TD target: reward (1) + gamma (0.001) * V of next state (1) = 1.001
    mock_F.mse_loss.assert_any_call(fake_q_estimate, 1.001)
    fake_loss.backward.assert_any_call()
    # Gradients are clipped to max norm 1 before the optimizer step
    mock_clip_grad_norm_.assert_any_call(fake_network_params, 1)
    fake_optimizer.step.assert_any_call()
    mock_soft_update.assert_any_call(fake_torch_naf, fake_torch_naf)
238 |
239 |
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.Environment')
def test_naf_agent__soft_update(mock_environment: MagicMock) -> None:
    """Test for the soft_update() method for the NAFAgent class.

    soft_update() blends the main network parameters into the target network.
    Both networks are mocked here, so the interpolation arithmetic itself is
    not checked — only that soft_update() consumes both parameter sets.
    """
    # ================== soft_update() ========================================
    naf_agent = NAFAgent(environment=mock_environment,
                         state_size=10,
                         action_size=5,
                         layer_size=128,
                         batch_size=64,
                         buffer_size=100000,
                         learning_rate=0.001,
                         tau=0.001,
                         gamma=0.001,
                         update_freq=1,
                         num_updates=1,
                         checkpoint_frequency=1,
                         device='cpu',
                         seed=0)
    fake_main_nn, fake_target_nn = MagicMock(), MagicMock()
    fake_param_value = MagicMock(value=1)
    fake_main_params, fake_target_params = MagicMock(data=fake_param_value), MagicMock(data=fake_param_value)
    fake_main_nn.parameters.return_value = [fake_main_params]
    fake_target_nn.parameters.return_value = [fake_target_params]

    naf_agent.soft_update(fake_main_nn, fake_target_nn)

    # The original test called soft_update() without verifying anything.
    # At minimum the method must iterate the parameters of both networks.
    fake_main_nn.parameters.assert_any_call()
    fake_target_nn.parameters.assert_any_call()
265 |
266 |
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.logger')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.Environment')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.optim.Adam')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.time')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.os')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.torch')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.NAF')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.NAFAgent.act')
@patch('robotic_manipulator_rloa.naf_components.naf_algorithm.NAFAgent.step')
# NOTE(review): `new_callable` normally receives the factory itself
# (`new_callable=mock_open`); passing `mock_open()` makes patch() call the
# pre-built mock_open instance, so `mock_open_file` is that call's result and
# the file handle is reached via `.return_value.__enter__()` below. The
# assertions are written against this wiring — confirm before "fixing" it.
@patch('builtins.open', new_callable=mock_open())
def test_naf_agent__run(mock_open_file: MagicMock,
                        mock_step: MagicMock,
                        mock_act: MagicMock,
                        mock_naf: MagicMock,
                        mock_torch: MagicMock,
                        mock_os: MagicMock,
                        mock_time: MagicMock,
                        mock_optimizer: MagicMock,
                        mock_environment: MagicMock,
                        mock_logger: MagicMock) -> None:
    """Test for the run() method for the NAFAgent class.

    Runs run(1, 1, True) — one episode of one timestep — with every
    collaborator mocked, and verifies the expected interactions: environment
    reset/step, action selection, experience recording, checkpoint directory
    creation, weight saving and score-file dumping.
    """
    # ================== run() ================================================
    naf_agent = NAFAgent(environment=mock_environment,
                         state_size=10,
                         action_size=5,
                         layer_size=128,
                         batch_size=64,
                         buffer_size=100000,
                         learning_rate=0.001,
                         tau=0.001,
                         gamma=0.001,
                         update_freq=1,
                         num_updates=1,
                         checkpoint_frequency=1,
                         device='cpu',
                         seed=0)
    mock_time.time.return_value = 50
    fake_next_state, fake_reward, fake_done = MagicMock(), 1, MagicMock()
    fake_state = mock_environment.reset.return_value
    fake_action = MagicMock()
    mock_act.return_value = fake_action
    mock_environment.step.return_value = (fake_next_state, fake_reward, fake_done)
    mock_step.return_value = None
    mock_os.makedirs.return_value = None
    # The agent interacts with the network produced by NAF(...).to(device).
    fake_naf_instance = mock_naf.return_value
    fake_torch_naf = fake_naf_instance.to.return_value
    fake_torch_naf.return_value = (fake_action, None, None)
    fake_torch_naf.state_dict.return_value = 'params_dict'

    # Expected return: {episode: (score, timesteps)} for the single episode.
    assert naf_agent.run(1, 1, True) == {1: (1, 1)}

    mock_time.time.assert_any_call()
    mock_environment.reset.assert_any_call(True)
    mock_act.assert_any_call(fake_state)
    mock_environment.step.assert_any_call(fake_action)
    mock_step.assert_any_call(fake_state, fake_action, fake_reward, fake_next_state, fake_done)
    # checkpoint_frequency=1, so episode 1 triggers a checkpoint write.
    mock_os.makedirs.assert_any_call('checkpoints/1/', exist_ok=True)
    fake_torch_naf.state_dict.assert_any_call()
    mock_torch.save.assert_any_call('params_dict', 'checkpoints/1/weights.p')
    mock_open_file.assert_called_once_with('checkpoints/1/scores.txt', 'w')
    mock_open_file.return_value.__enter__().write.assert_called_once_with(json.dumps({1: (1, 1)}))
    # The final model weights are also saved at the end of run().
    mock_torch.save.assert_any_call('params_dict', 'model.p')
329 |
--------------------------------------------------------------------------------
/tests/robotic_manipulator_rloa/naf_components/test_naf_neural_network.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from mock import MagicMock, patch
4 |
5 | from robotic_manipulator_rloa.naf_components.naf_neural_network import NAF
6 |
7 |
@patch('robotic_manipulator_rloa.naf_components.naf_neural_network.torch.manual_seed')
@patch('robotic_manipulator_rloa.naf_components.naf_neural_network.torch.nn.Module')
@patch('robotic_manipulator_rloa.naf_components.naf_neural_network.torch.nn.Linear')
@patch('robotic_manipulator_rloa.naf_components.naf_neural_network.torch.nn.BatchNorm1d')
def test_naf(mock_batchnorm: MagicMock,
             mock_linear: MagicMock,
             mock_nn: MagicMock,
             mock_manual_seed: MagicMock) -> None:
    """Test for the NAF class constructor"""
    mock_manual_seed.return_value = 'manual_seed'
    # Layers are instantiated in declaration order, so side_effect pins which
    # sentinel string lands on each attribute.
    mock_linear.side_effect = [
        'linear_for_input_layer',
        'linear_for_hidden_layer',
        'linear_for_action_values',
        'linear_for_value',
        'linear_for_matrix_entries'
    ]
    mock_batchnorm.side_effect = ['batchnorm_for_bn1', 'batchnorm_for_bn2']

    network = NAF(state_size=10, action_size=5, layer_size=128, seed=0, device='cpu')

    assert network.seed == 'manual_seed'
    assert network.state_size == 10
    assert network.action_size == 5
    assert network.device == 'cpu'
    assert network.input_layer == 'linear_for_input_layer'
    assert network.bn1 == 'batchnorm_for_bn1'
    assert network.hidden_layer == 'linear_for_hidden_layer'
    assert network.bn2 == 'batchnorm_for_bn2'
    assert network.action_values == 'linear_for_action_values'
    assert network.value == 'linear_for_value'
    assert network.matrix_entries == 'linear_for_matrix_entries'

    mock_manual_seed.assert_any_call(0)
    # One Linear per (in_features, out_features) pair used by the constructor.
    for in_features, out_features in [(10, 128), (128, 128), (128, 5), (128, 1), (128, 15)]:
        mock_linear.assert_any_call(in_features=in_features, out_features=out_features)
    mock_batchnorm.assert_any_call(128)
51 |
52 |
def test_naf__forward() -> None:
    """Test for the forward() method of the NAF class.

    With a fixed seed, the forward pass on a fixed 2-sample batch must be
    deterministic, so Q and V are pinned to exact values.
    """
    device = torch.device('cpu')
    network = NAF(state_size=10, action_size=5, layer_size=256, seed=0, device=device)
    # Two states of size 10 and two actions of size 5.
    states = torch.from_numpy(np.stack([np.arange(10), np.arange(10, 20)])).float().to(device)
    actions = torch.from_numpy(np.vstack([np.arange(5), np.arange(10, 15)])).long().to(device)

    _, q, v = network(states, actions)

    assert q.tolist() == [[-35.50931930541992], [-638.494873046875]]
    assert v.tolist() == [[0.5665180683135986], [-0.08311141282320023]]
68 |
--------------------------------------------------------------------------------
/tests/robotic_manipulator_rloa/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JavierMtz5/robotic_manipulator_rloa/263aab50f70deeb48b8f0bc28041c86b76453816/tests/robotic_manipulator_rloa/utils/__init__.py
--------------------------------------------------------------------------------
/tests/robotic_manipulator_rloa/utils/test_collision_detector.py:
--------------------------------------------------------------------------------
1 | from mock import MagicMock, patch
2 | import pytest
3 | import numpy as np
4 |
5 | from robotic_manipulator_rloa.utils.collision_detector import CollisionDetector, CollisionObject
6 |
7 |
def test_collisionobject() -> None:
    """Test for the CollisionObject dataclass"""
    obj = CollisionObject(body='manipulator_body', link=0)
    assert obj.body == 'manipulator_body'
    assert obj.link == 0
13 |
14 |
@pytest.mark.parametrize('closest_points', [
    [(None, None, None, None, None, None, None, None, 1),
     (None, None, None, None, None, None, None, None, 3),
     (None, None, None, None, None, None, None, None, 5)],
    []
])
@patch('robotic_manipulator_rloa.utils.collision_detector.p')
def test_collision_detector(mock_pybullet: MagicMock,
                            closest_points: list) -> None:
    """Test for the CollisionDetector class"""
    collision_object = CollisionObject(body='manipulator_body', link=0)
    detector = CollisionDetector(collision_object=collision_object,
                                 obstacle_ids=['obstacle'])

    assert detector.obstacles == ['obstacle']
    assert detector.collision_object == collision_object

    # ================== TEST FOR compute_distances() method ==================

    mock_pybullet.getClosestPoints.return_value = closest_points
    # Minimum reported distance (index 8) when points exist, else max_distance.
    expected = np.array([1]) if closest_points else np.array([10.0])

    assert detector.compute_distances() == expected

    mock_pybullet.getClosestPoints.assert_any_call(collision_object.body,
                                                   'obstacle',
                                                   distance=10.0,
                                                   linkIndexA=collision_object.link)

    # ================== TEST FOR compute_collisions_in_manipulator() method ==================

    mock_pybullet.getClosestPoints.return_value = closest_points
    expected = np.array([1]) if closest_points else np.array([10.0])

    assert detector.compute_collisions_in_manipulator([0, 3]) == expected

    mock_pybullet.getClosestPoints.assert_any_call(collision_object.body,
                                                   collision_object.body,
                                                   distance=10.0,
                                                   linkIndexA=collision_object.link,
                                                   linkIndexB=3)
56 |
--------------------------------------------------------------------------------
/tests/robotic_manipulator_rloa/utils/test_exceptions.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from mock import MagicMock, patch
3 | from typing import Union
4 |
5 | from robotic_manipulator_rloa.utils.exceptions import (
6 | FrameworkException,
7 | InvalidManipulatorFile,
8 | InvalidHyperParameter,
9 | InvalidEnvironmentParameter,
10 | InvalidNAFAgentParameter,
11 | EnvironmentNotInitialized,
12 | NAFAgentNotInitialized,
13 | MissingWeightsFile,
14 | ConfigurationIncomplete
15 | )
16 |
17 |
def test_framework_exception() -> None:
    """Test for the base exception class FrameworkException"""
    exc = FrameworkException(message='test_exception')
    assert exc.message == 'test_exception'
    # str() prepends the exception class name.
    assert str(exc) == 'FrameworkException: test_exception'
    # set_message() replaces the stored message in place.
    exc.set_message('test msg')
    assert exc.message == 'test msg'
25 |
26 |
@pytest.mark.parametrize('msg', [None, 'error_msg'])
def test_invalid_manipulator_file(msg: Union[str, None]) -> None:
    """Test for the InvalidManipulatorFile exception class"""
    # With no argument the exception falls back to its default message.
    exception = InvalidManipulatorFile(msg) if msg else InvalidManipulatorFile()
    assert exception.message == (msg or 'The URDF/SDF file received is not valid')
36 |
37 |
@pytest.mark.parametrize('msg', [None, 'error_msg'])
def test_invalid_hyperparameter(msg: Union[str, None]) -> None:
    """Test for the InvalidHyperParameter exception class"""
    # With no argument the exception falls back to its default message.
    exception = InvalidHyperParameter(msg) if msg else InvalidHyperParameter()
    assert exception.message == (msg or 'The hyperparameter received is not valid')
47 |
48 |
@pytest.mark.parametrize('msg', [None, 'error_msg'])
def test_invalid_environment_parameter(msg: Union[str, None]) -> None:
    """Test for the InvalidEnvironmentParameter exception class"""
    # With no argument the exception falls back to its default message.
    exception = InvalidEnvironmentParameter(msg) if msg else InvalidEnvironmentParameter()
    assert exception.message == (msg or 'The Environment parameter received is not valid')
58 |
59 |
@pytest.mark.parametrize('msg', [None, 'error_msg'])
def test_invalid_nafagent_parameter(msg: Union[str, None]) -> None:
    """Test for the InvalidNAFAgentParameter exception class"""
    # With no argument the exception falls back to its default message.
    exception = InvalidNAFAgentParameter(msg) if msg else InvalidNAFAgentParameter()
    assert exception.message == (msg or 'The NAF Agent parameter received is not valid')
69 |
70 |
@pytest.mark.parametrize('msg', [None, 'error_msg'])
def test_environment_not_initialized(msg: Union[str, None]) -> None:
    """Test for the EnvironmentNotInitialized exception class"""
    # With no argument the exception falls back to its default message.
    exception = EnvironmentNotInitialized(msg) if msg else EnvironmentNotInitialized()
    default = ('The Environment is not yet initialized. The environment can be initialized '
               'via the initialize_environment() method')
    assert exception.message == (msg or default)
81 |
82 |
@pytest.mark.parametrize('msg', [None, 'error_msg'])
def test_nafagent_not_initialized(msg: Union[str, None]) -> None:
    """Test for the NAFAgentNotInitialized exception class"""
    # With no argument the exception falls back to its default message.
    exception = NAFAgentNotInitialized(msg) if msg else NAFAgentNotInitialized()
    default = ('The NAF Agent is not yet initialized. The agent can be initialized via the '
               'initialize_naf_agent() method')
    assert exception.message == (msg or default)
93 |
94 |
@pytest.mark.parametrize('msg', [None, 'error_msg'])
def test_missing_weights_file(msg: Union[str, None]) -> None:
    """Test for the MissingWeightsFile exception class"""
    # With no argument the exception falls back to its default message.
    exception = MissingWeightsFile(msg) if msg else MissingWeightsFile()
    assert exception.message == (msg or 'The weight file provided does not exist')
104 |
105 |
@pytest.mark.parametrize('msg', [None, 'error_msg'])
def test_configuration_incomplete(msg: Union[str, None]) -> None:
    """Test for the ConfigurationIncomplete exception class"""
    # With no argument the exception falls back to its default message.
    exception = ConfigurationIncomplete(msg) if msg else ConfigurationIncomplete()
    default = ('The configuration for the training is incomplete. Either the Environment, '
               'the NAF Agent or both are not yet initialized. The environment can be initialized '
               'via the initialize_environment() method, and the agent can be initialized via '
               'the initialize_naf_agent() method')
    assert exception.message == (msg or default)
118 |
--------------------------------------------------------------------------------
/tests/robotic_manipulator_rloa/utils/test_logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import pytest
4 | from mock import MagicMock, patch
5 |
6 | from robotic_manipulator_rloa.utils.logger import Logger, get_global_logger, CustomFormatter
7 |
8 |
@patch('robotic_manipulator_rloa.utils.logger.Logger.generate_logging_config_dict')
@patch('robotic_manipulator_rloa.utils.logger.dictConfig')
@patch('robotic_manipulator_rloa.utils.logger.RotatingFileHandler')
@patch('robotic_manipulator_rloa.utils.logger.logging.Formatter')
@patch('robotic_manipulator_rloa.utils.logger.get_global_logger')
@patch('robotic_manipulator_rloa.utils.logger.datetime')
def test_logger(mock_datetime: MagicMock,
                mock_get_global_logger: MagicMock,
                mock_formatter: MagicMock,
                mock_rotating_file_handler: MagicMock,
                mock_dictConfig: MagicMock,
                mock_generate_config_dict: MagicMock) -> None:
    """Test for the Logger class.

    Calls Logger.set_logger_setup() with every collaborator mocked and checks
    the setup sequence: config-dict generation, dictConfig(), rotating file
    handler creation/configuration, and attachment to the global logger.
    """
    # datetime.utcnow().strftime(...) is stubbed so the file-handler format
    # string is deterministic ('datetime'). fake_strftime is unused setup.
    fake_utcnow, fake_strftime = MagicMock(), MagicMock()
    mock_datetime.utcnow.return_value = fake_utcnow
    fake_utcnow.strftime.return_value = 'datetime'
    fake_config_dict = mock_generate_config_dict.return_value
    mock_dictConfig.return_value = None
    fake_rotating_file_handler, fake_formatter = MagicMock(), MagicMock()
    mock_rotating_file_handler.return_value = fake_rotating_file_handler
    mock_formatter.return_value = fake_formatter
    fake_rotating_file_handler.setFormatter.return_value = None
    fake_rotating_file_handler.setLevel.return_value = None
    fake_logger = MagicMock()
    mock_get_global_logger.return_value = fake_logger
    fake_logger.addHandler.return_value = None
    fake_logger.setLevel.return_value = None

    Logger.set_logger_setup()

    mock_generate_config_dict.assert_any_call()
    mock_dictConfig.assert_any_call(fake_config_dict)
    # File handler: append mode, 50 MB per file, 10 backups, utf-8.
    mock_rotating_file_handler.assert_any_call(filename='training_logs.log', mode='a', maxBytes=50000000,
                                               backupCount=10, encoding='utf-8')
    # 'datetime' comes from the stubbed strftime() above.
    mock_formatter.assert_any_call('"%(levelname)s"|"datetime"|%(message)s')
    fake_rotating_file_handler.setFormatter.assert_any_call(fake_formatter)
    fake_rotating_file_handler.setLevel.assert_any_call(logging.INFO)
    mock_get_global_logger.assert_any_call()
    fake_logger.addHandler.assert_any_call(fake_rotating_file_handler)
    # 20 == logging.INFO
    fake_logger.setLevel.assert_any_call(20)
49 |
50 |
@patch('robotic_manipulator_rloa.utils.logger.CustomFormatter')
def test_logger__generate_logging_config_dict(mock_custom_formatter: MagicMock) -> None:
    """Test for the generate_logging_config_dict() method of the Logger class"""
    # Assemble the expected dictConfig structure piece by piece.
    formatters = {
        'custom_formatter': {
            '()': mock_custom_formatter,
            'dateformat': '%Y-%m-%dT%H:%M:%S.%06d%z'
        },
    }
    handlers = {
        'debug_console_handler': {
            'level': 'NOTSET',
            'formatter': 'custom_formatter',
            'class': 'logging.StreamHandler',
            'stream': 'ext://sys.stdout',
        }
    }
    loggers = {
        '': {
            'handlers': ['debug_console_handler'],
            'level': 'NOTSET',
        },
    }
    expected = {
        'version': 1,
        'disable_existing_loggers': False,
        'formatters': formatters,
        'handlers': handlers,
        'loggers': loggers,
    }
    assert Logger.generate_logging_config_dict() == expected
79 |
80 |
@patch('robotic_manipulator_rloa.utils.logger.logging.getLogger')
def test_get_global_logger(mock_get_logger: MagicMock) -> None:
    """Test for the get_global_logger() method"""
    mock_get_logger.return_value = 'logger'
    result = get_global_logger()
    assert result == 'logger'
    # The module-level logger is requested by the logger module's name.
    mock_get_logger.assert_any_call('robotic_manipulator_rloa.utils.logger')
87 |
88 |
@patch('robotic_manipulator_rloa.utils.logger.logging.Formatter')
@patch('robotic_manipulator_rloa.utils.logger.datetime')
def test_customformatter(mock_datetime: MagicMock,
                         mock_formatter: MagicMock) -> None:
    """Test for the CustomFormatter class"""
    formatter_under_test = CustomFormatter(dateformat='dateformat')
    assert formatter_under_test.dateformat == 'dateformat'

    # ================== TEST FOR format() method =============================

    # A DEBUG-level record should be rendered through a green-colored format.
    fake_record = MagicMock(levelno=logging.DEBUG)
    fake_inner_formatter = MagicMock()
    fake_inner_formatter.format.return_value = 'formatted_record'
    mock_formatter.return_value = fake_inner_formatter
    fake_now, fake_astimezone = MagicMock(), MagicMock()
    mock_datetime.now.return_value = fake_now
    fake_now.astimezone.return_value = fake_astimezone
    fake_astimezone.strftime.return_value = 'datetime'

    assert formatter_under_test.format(fake_record) == 'formatted_record'

    mock_formatter.assert_any_call("\033[32;20m" + "[%(levelname)-8s] - datetime - %(message)s" + "\033[0m",
                                   datefmt='dateformat')
111 |
112 |
113 |
--------------------------------------------------------------------------------
/tests/robotic_manipulator_rloa/utils/test_replay_buffer.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from mock import MagicMock, patch
3 | from collections import deque, namedtuple
4 | import numpy as np
5 |
6 | from robotic_manipulator_rloa.utils.replay_buffer import ReplayBuffer
7 |
8 |
@patch('robotic_manipulator_rloa.utils.replay_buffer.namedtuple')
@patch('robotic_manipulator_rloa.utils.replay_buffer.deque')
@patch('robotic_manipulator_rloa.utils.replay_buffer.random')
def test_replaybuffer(mock_random: MagicMock,
                      mock_deque: MagicMock,
                      mock_namedtuple: MagicMock) -> None:
    """Test for the ReplayBuffer class"""
    buffer = ReplayBuffer(buffer_size=10000,
                          batch_size=128,
                          device='cpu',
                          seed=0)

    assert buffer.device == 'cpu'
    assert buffer.batch_size == 128
    assert buffer.memory == mock_deque.return_value
    assert buffer.experience == mock_namedtuple.return_value

    # The constructor must size the deque, declare the Experience record
    # and seed the RNG.
    mock_deque.assert_any_call(maxlen=10000)
    mock_namedtuple.assert_any_call("Experience", field_names=["state", "action", "reward", "next_state", "done"])
    mock_random.seed.assert_any_call(0)
28 |
29 |
def test_replaybuffer__add() -> None:
    """Test for the add() method of the ReplayBuffer class"""
    buffer = ReplayBuffer(buffer_size=10000,
                          batch_size=128,
                          device='cpu',
                          seed=0)
    # Build the memory we expect after one add(): a deque holding a single
    # Experience namedtuple (deque equality ignores maxlen).
    Experience = namedtuple("Experience", field_names=["state", "action", "reward", "next_state", "done"])
    expected_memory = deque(maxlen=10)
    expected_memory.append(Experience('state', 'action', 'reward', 'next_state', 'done'))

    buffer.add('state', 'action', 'reward', 'next_state', 'done')

    assert buffer.memory == expected_memory
44 |
45 |
@patch('robotic_manipulator_rloa.utils.replay_buffer.torch')
@patch('robotic_manipulator_rloa.utils.replay_buffer.random.sample')
@patch('robotic_manipulator_rloa.utils.replay_buffer.np.stack')
@patch('robotic_manipulator_rloa.utils.replay_buffer.np.vstack')
def test_replaybuffer__sample(mock_vstack: MagicMock,
                              mock_stack: MagicMock,
                              mock_random_sample: MagicMock,
                              mock_torch: MagicMock) -> None:
    """Test for the sample() method for the ReplayBuffer class.

    Mocks the whole numpy/torch conversion pipeline and checks that sample()
    returns the (state, action, reward, next_state, done) tensors produced by
    the from_numpy -> float()/long() -> to(device) chain.
    """
    replay_buffer = ReplayBuffer(buffer_size=10000,
                                 batch_size=1,
                                 device='cpu',
                                 seed=0)
    # Single experience returned by the mocked random.sample().
    fake_state = np.array([0, 1])
    fake_action = np.array([2, 3])
    fake_reward = 1.5
    fake_next_state = np.array([4, 5])
    fake_done = 0
    mock_random_sample.return_value = [MagicMock(state=fake_state,
                                                 action=fake_action,
                                                 reward=fake_reward,
                                                 next_state=fake_next_state,
                                                 done=fake_done)]
    fake_stack_state, fake_stack_next_state = MagicMock(), MagicMock()
    fake_vstack_action, fake_vstack_reward, fake_vstack_done = MagicMock(), MagicMock(), MagicMock()

    # The dones array is cast via astype(...) before tensor conversion —
    # presumably to uint8; confirm against ReplayBuffer.sample().
    fake_vstack_done.astype.return_value = 'done_astype'
    # Order matters: sample() is expected to np.stack() states then
    # next_states, and np.vstack() actions, rewards and dones, in that order.
    mock_stack.side_effect = [fake_stack_state, fake_stack_next_state]
    mock_vstack.side_effect = [fake_vstack_action, fake_vstack_reward, fake_vstack_done]

    fake_from_numpy_state, fake_from_numpy_action, fake_from_numpy_reward, \
        fake_from_numpy_next_state, fake_from_numpy_done = MagicMock(), MagicMock(), MagicMock(), MagicMock(), MagicMock()
    # torch.from_numpy() is called once per field, in this order.
    mock_torch.from_numpy.side_effect = [
        fake_from_numpy_state, fake_from_numpy_action, fake_from_numpy_reward,
        fake_from_numpy_next_state, fake_from_numpy_done]

    # Expected conversion chains: actions go through .long(), everything else
    # through .float(), each followed by .to(device).
    output_state_float = fake_from_numpy_state.float.return_value
    output_state = output_state_float.to.return_value

    output_action_long = fake_from_numpy_action.long.return_value
    output_action = output_action_long.to.return_value

    output_reward_float = fake_from_numpy_reward.float.return_value
    output_reward = output_reward_float.to.return_value

    output_next_state_float = fake_from_numpy_next_state.float.return_value
    output_next_state = output_next_state_float.to.return_value

    output_done_float = fake_from_numpy_done.float.return_value
    output_done = output_done_float.to.return_value

    assert replay_buffer.sample() == (output_state, output_action, output_reward, output_next_state, output_done)
98 |
99 |
def test_replaybuffer__len__() -> None:
    """Test for the __len__() method of the ReplayBuffer class"""
    buffer = ReplayBuffer(buffer_size=10000,
                          batch_size=1,
                          device='cpu',
                          seed=0)
    # len() tracks the number of stored experiences.
    assert len(buffer) == 0
    buffer.add('state', 'action', 'reward', 'next_state', 'done')
    assert len(buffer) == 1
109 |
--------------------------------------------------------------------------------