├── A_star.py
├── BFS.py
├── DFS.py
├── Dijkstra.py
├── GBFS.py
├── HybridA_star.py
├── README.md
├── common.py
├── image.jpg
├── image1.jpg
└── 图片
├── astar.png
├── astar_1.png
├── bfs.png
├── dfs.png
├── dfs_1.png
├── dij.png
├── gbfs.png
├── hybrida.png
└── hybrida1.png
/A_star.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Thu Mar 30 16:45:58 2023
4 |
5 | @author: HJ
6 | """
7 |
8 | # A*算法
9 | # http://www.360doc.com/content/21/0811/13/40892371_990562118.shtml
10 | from functools import lru_cache
11 | from common import *
12 |
13 | Queue_Type = 0
14 | """
15 | # OpenList 采用的 PriorityQueue 的结构
16 | ## 0 -> SetQueue
17 | ## 1 -> ListQueue
18 | ## 2 -> PriorityQueuePro
19 | List/Set可以实现更新OpenList中节点的parent和cost, 找到的路径更优\n
20 | PriorityQueuePro速度最快, 但无法更新信息, 路径较差\n
21 | List速度最慢, Set速度接近PriorityQueuePro甚至更快\n
22 | """
23 |
24 | # 地图读取
25 | IMAGE_PATH = 'image.jpg' # 原图路径
26 | THRESH = 172 # 图片二值化阈值, 大于阈值的部分被置为255, 小于部分被置为0
27 | HIGHT = 350 # 地图高度
28 | WIDTH = 600 # 地图宽度
29 |
30 | MAP = GridMap(IMAGE_PATH, THRESH, HIGHT, WIDTH) # 栅格地图对象
31 |
32 | # 起点终点
33 | START = (290, 270) # 起点坐标 y轴向下为正
34 | END = (298, 150) # 终点坐标 y轴向下为正
35 |
36 |
37 |
38 |
39 |
40 |
41 | # ----------------- ↓ ↓ ↓ ↓ 避障地图大概长这样 ↓ ↓ ↓ ↓ -----------------
42 |
43 | #
44 | # ...
45 | # .=BBBB#-
46 | # .B%&&&&&
47 | # .=## #&&&&%&%
48 | # -B&&&&& &&&&&B=-.
49 | # =&@&&&&&& &&&&@B
50 | # -%@@@&&&&&&& &&&&@%.
51 | # =&@@@%%@&&& 起点 &&@@@%
52 | # =@@@$#.%@@@@@@ @@ &@@@-
53 | # .&@@@%&@@@@@@& @@@ &@@=
54 | # #&@@@&@@@@@ @@@@ B@@=
55 | # -%@@@@@@@@ d@@@@@B&@-
56 | # .B%&&&&@B @@@@@&@#
57 | # #B###BBBBBBB%%&&%#
58 | # .######BBBBBBBBBB.
59 | # =####BBBBBBBBBBBB#-
60 | # .=####BB%%B%%%%%%BB##=
61 | # .=##BBB%%#- -#%%%BBB##.
62 | # .=##BBB%#. .#%%BBBB#.
63 | # =##BB%%- 终点 =%%BBBB=
64 | # =#BB%%B- .B%%%B#-
65 | # =##BBB- -BB###.
66 | # -=##BB- -##=#-
67 | # ==##B=- -####=
68 | # =##B#- -####=
69 | # ###B= =###=
70 | # =##B#- ###=
71 | # =BB#= =BB=
72 | # -%&% =&
73 | # %&%% B%&&=
74 | #
75 |
76 | # ----------------------------------------------------------------------
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 | """ ---------------------------- A*算法 ---------------------------- """
85 | # F = G + H
86 |
87 |
88 | # 设置OpenList使用的优先队列
89 | if Queue_Type == 0:
90 | NodeQueue = SetQueue
91 | elif Queue_Type == 1:
92 | NodeQueue = ListQueue
93 | else:
94 | NodeQueue = PriorityQueuePro
95 |
96 |
97 | # A*算法
98 | class AStar:
99 | """A*算法"""
100 |
101 | def __init__(
102 | self,
103 | start_pos = START,
104 | end_pos = END,
105 | map_array = MAP.map_array,
106 | move_step = 3,
107 | move_direction = 8,
108 | ):
109 | """A*算法
110 |
111 | Parameters
112 | ----------
113 | start_pos : tuple/list
114 | 起点坐标
115 | end_pos : tuple/list
116 | 终点坐标
117 | map_array : ndarray
118 | 二值化地图, 0表示障碍物, 255表示空白, H*W维
119 | move_step : int
120 | 移动步数, 默认3
121 | move_direction : int (8 or 4)
122 | 移动方向, 默认8个方向
123 | """
124 |
125 | # 网格化地图
126 | self.map_array = map_array # H * W
127 |
128 | self.width = self.map_array.shape[1]
129 | self.high = self.map_array.shape[0]
130 |
131 | # 起点终点
132 | self.start = Node(*start_pos) # 初始位置
133 | self.end = Node(*end_pos) # 结束位置
134 |
135 | # Error Check
136 | if not self._in_map(self.start) or not self._in_map(self.end):
137 | raise ValueError(f"x坐标范围0~{self.width-1}, y坐标范围0~{self.height-1}")
138 | if self._is_collided(self.start):
139 | raise ValueError(f"起点x坐标或y坐标在障碍物上")
140 | if self._is_collided(self.end):
141 | raise ValueError(f"终点x坐标或y坐标在障碍物上")
142 |
143 | # 算法初始化
144 | self.reset(move_step, move_direction)
145 |
146 |
147 | def reset(self, move_step=3, move_direction=8):
148 | """重置算法"""
149 | self.__reset_flag = False
150 | self.move_step = move_step # 移动步长(搜索后期会减小)
151 | self.move_direction = move_direction # 移动方向 8 个
152 | self.close_set = set() # 存储已经走过的位置及其G值
153 | self.open_queue = NodeQueue() # 存储当前位置周围可行的位置及其F值
154 | self.path_list = [] # 存储路径(CloseList里的数据无序)
155 |
156 |
157 | def search(self):
158 | """搜索路径"""
159 | return self.__call__()
160 |
161 |
162 | def _in_map(self, node: Node):
163 | """点是否在网格地图中"""
164 | return (0 <= node.x < self.width) and (0 <= node.y < self.high) # 右边不能取等!!!
165 |
166 |
167 | def _is_collided(self, node: Node):
168 | """点是否和障碍物碰撞"""
169 | return self.map_array[node.y, node.x] == 0
170 |
171 |
172 | def _move(self):
173 | """移动点"""
174 | @lru_cache(maxsize=3) # 避免参数相同时重复计算
175 | def _move(move_step:int, move_direction:int):
176 | move = [
177 | (0, move_step), # 上
178 | (0, -move_step), # 下
179 | (-move_step, 0), # 左
180 | (move_step, 0), # 右
181 | (move_step, move_step), # 右上
182 | (move_step, -move_step), # 右下
183 | (-move_step, move_step), # 左上
184 | (-move_step, -move_step), # 左下
185 | ]
186 | return move[0:move_direction] # 坐标增量
187 | return _move(self.move_step, self.move_direction)
188 |
189 |
190 | def _update_open_list(self, curr: Node):
191 | """open_list添加可行点"""
192 | for add in self._move():
193 | # 更新节点
194 | next_ = curr + add # x、y、cost、parent都更新了
195 |
196 | # 新位置是否在地图外边
197 | if not self._in_map(next_):
198 | continue
199 | # 新位置是否碰到障碍物
200 | if self._is_collided(next_):
201 | continue
202 | # 新位置是否在 CloseList 中
203 | if next_ in self.close_set:
204 | continue
205 |
206 | # 把节点的 G 代价改成 F 代价
207 | H = next_ - self.end
208 | next_.cost += H
209 |
210 | # open-list添加/更新结点
211 | self.open_queue.put(next_)
212 |
213 | # 当剩余距离小时, 走慢一点
214 | if H < 20:
215 | self.move_step = 1
216 |
217 |
218 | def __call__(self):
219 | """A*路径搜索"""
220 | assert not self.__reset_flag, "call之前需要reset"
221 | print("搜索中\n")
222 |
223 | # 初始化 OpenList
224 | self.open_queue.put(self.start)
225 |
226 | # 正向搜索节点
227 | tic()
228 | while not self.open_queue.empty():
229 | # 弹出 OpenList 代价 F 最小的点
230 | curr = self.open_queue.get() # OpenList里是 F
231 | curr.cost -= (curr - self.end) # G = F - H
232 | # 更新 OpenList
233 | self._update_open_list(curr)
234 | # 更新 CloseList
235 | self.close_set.add(curr)
236 | # 结束迭代
237 | if curr == self.end:
238 | break
239 | print("路径搜索完成\n")
240 | toc()
241 |
242 | # 节点组合成路径
243 | while curr.parent is not None:
244 | self.path_list.append(curr)
245 | curr = curr.parent
246 | self.path_list.reverse()
247 |
248 | # 需要重置
249 | self.__reset_flag = True
250 |
251 | return self.path_list
252 |
253 |
254 |
255 |
256 |
257 |
258 |
259 |
260 |
261 |
262 |
263 | # debug
264 | if __name__ == '__main__':
265 | p = AStar()()
266 | MAP.show_path(p)
267 |
--------------------------------------------------------------------------------
/BFS.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Thu Mar 30 16:45:58 2023
4 |
5 | @author: HJ
6 | """
7 |
8 | # 广度优先搜索(Breadth First Search, BFS)算法
9 | from typing import Union
10 | from functools import lru_cache
11 | from collections import deque
12 | from dataclasses import dataclass
13 | from common import tic, toc, GridMap
14 |
15 |
16 | # 地图读取
17 | IMAGE_PATH = 'image.jpg' # 原图路径
18 | THRESH = 172 # 图片二值化阈值, 大于阈值的部分被置为255, 小于部分被置为0
19 | HIGHT = 350 # 地图高度
20 | WIDTH = 600 # 地图宽度
21 |
22 | MAP = GridMap(IMAGE_PATH, THRESH, HIGHT, WIDTH) # 栅格地图对象
23 |
24 | # 起点终点
25 | START = (290, 270) # 起点坐标 y轴向下为正
26 | END = (298, 150) # 终点坐标 y轴向下为正
27 |
28 |
29 |
30 |
31 |
32 |
33 | """ ---------------------------- Breadth First Search算法 ---------------------------- """
34 |
35 | Number = Union[int, float]
36 |
37 |
38 |
39 | @dataclass(eq=False)
40 | class Node:
41 | """节点"""
42 |
43 | x: int
44 | y: int
45 | parent: "Node" = None
46 |
47 | def __sub__(self, other) -> int:
48 | """计算节点与坐标的曼哈顿距离"""
49 | if isinstance(other, Node):
50 | return abs(self.x - other.x) + abs(self.y - other.y)
51 | elif isinstance(other, (tuple, list)):
52 | return abs(self.x - other[0]) + abs(self.y - other[1])
53 | raise ValueError("other必须为坐标或Node")
54 |
55 | def __add__(self, other: Union[tuple, list]) -> "Node":
56 | """生成新节点"""
57 | x = self.x + other[0]
58 | y = self.y + other[1]
59 | return Node(x, y, self)
60 |
61 | def __eq__(self, other):
62 | """坐标x,y比较 -> node in close_list"""
63 | if isinstance(other, Node):
64 | return self.x == other.x and self.y == other.y
65 | elif isinstance(other, (tuple, list)):
66 | return self.x == other[0] and self.y == other[1]
67 | return False
68 |
69 | def __hash__(self) -> int:
70 | """使可变对象可hash, 能放入set中"""
71 | return hash((self.x, self.y)) # tuple 可 hash
72 | # data in set 时间复杂度为 O(1), 但 data必须可hash
73 | # data in list 时间复杂度 O(n)
74 |
75 |
76 |
77 |
78 | # NOTE 广度优先搜索先入先出, 双向队列左侧弹出数据O(1), 列表左侧弹出数据O(n)
79 |
80 |
81 | # 广度优先搜索算法
82 | class BFS:
83 | """BFS算法"""
84 |
85 | def __init__(
86 | self,
87 | start_pos = START,
88 | end_pos = END,
89 | map_array = MAP.map_array,
90 | move_step = 3,
91 | move_direction = 8,
92 | ):
93 | """BFS算法
94 |
95 | Parameters
96 | ----------
97 | start_pos : tuple/list
98 | 起点坐标
99 | end_pos : tuple/list
100 | 终点坐标
101 | map_array : ndarray
102 | 二值化地图, 0表示障碍物, 255表示空白, H*W维
103 | move_step : int
104 | 移动步数, 默认3
105 | move_direction : int (8 or 4)
106 | 移动方向, 默认8个方向
107 | """
108 | # 网格化地图
109 | self.map_array = map_array # H * W
110 |
111 | self.width = self.map_array.shape[1]
112 | self.high = self.map_array.shape[0]
113 |
114 | # 起点终点
115 | self.start = Node(*start_pos) # 初始位置
116 | self.end = Node(*end_pos) # 结束位置
117 |
118 | # Error Check
119 | if not self._in_map(self.start) or not self._in_map(self.end):
120 | raise ValueError(f"x坐标范围0~{self.width-1}, y坐标范围0~{self.height-1}")
121 | if self._is_collided(self.start):
122 | raise ValueError(f"起点x坐标或y坐标在障碍物上")
123 | if self._is_collided(self.end):
124 | raise ValueError(f"终点x坐标或y坐标在障碍物上")
125 |
126 | # 算法初始化
127 | self.reset(move_step, move_direction)
128 |
129 |
130 | def reset(self, move_step=3, move_direction=8):
131 | """重置算法"""
132 | self.__reset_flag = False
133 | self.move_step = move_step # 移动步长(搜索后期会减小)
134 | self.move_direction = move_direction # 移动方向 8 个
135 | self.close_set = set() # 存储已经走过的位置及其G值
136 | self.open_deque = deque() # 存储当前位置周围可行的位置及其F值
137 | self.path_list = [] # 存储路径(CloseList里的数据无序)
138 |
139 |
140 | def search(self):
141 | """搜索路径"""
142 | return self.__call__()
143 |
144 |
145 | def _in_map(self, node: Node):
146 | """点是否在网格地图中"""
147 | return (0 <= node.x < self.width) and (0 <= node.y < self.high) # 右边不能取等!!!
148 |
149 |
150 | def _is_collided(self, node: Node):
151 | """点是否和障碍物碰撞"""
152 | return self.map_array[node.y, node.x] == 0
153 |
154 |
155 | def _move(self):
156 | """移动点"""
157 | @lru_cache(maxsize=3) # 避免参数相同时重复计算
158 | def _move(move_step:int, move_direction:int):
159 | move = [
160 | (0, move_step), # 上
161 | (0, -move_step), # 下
162 | (-move_step, 0), # 左
163 | (move_step, 0), # 右
164 | (move_step, move_step), # 右上
165 | (move_step, -move_step), # 右下
166 | (-move_step, move_step), # 左上
167 | (-move_step, -move_step), # 左下
168 | ]
169 | return move[0:move_direction] # 坐标增量
170 | return _move(self.move_step, self.move_direction)
171 |
172 |
173 | def _update_open_list(self, curr: Node):
174 | """open_list添加可行点"""
175 | for add in self._move():
176 | # 更新节点
177 | next_ = curr + add # x、y、cost、parent都更新了
178 |
179 | # 新位置是否在地图外边
180 | if not self._in_map(next_):
181 | continue
182 | # 新位置是否碰到障碍物
183 | if self._is_collided(next_):
184 | continue
185 | # 新位置是否在 CloseList 和 OpenDeque 中
186 | if next_ in self.close_set or next_ in self.open_deque:
187 | continue
188 |
189 | # open-list添加结点
190 | self.open_deque.append(next_)
191 |
192 | # 当剩余距离小时, 走慢一点
193 | if next_ - self.end < 20:
194 | self.move_step = 1
195 |
196 |
197 | def __call__(self):
198 | """BFS路径搜索"""
199 | assert not self.__reset_flag, "call之前需要reset"
200 | print("搜索中\n")
201 |
202 | # 初始化 OpenList
203 | self.open_deque.append(self.start)
204 |
205 | # 正向搜索节点
206 | tic()
207 | while self.open_deque:
208 | # 弹出 OpenList 最前的节点
209 | curr = self.open_deque.popleft()
210 | # 更新 OpenList
211 | self._update_open_list(curr)
212 | # 更新 CloseList
213 | self.close_set.add(curr)
214 | # 结束迭代
215 | if curr == self.end:
216 | break
217 | print("路径搜索完成\n")
218 | toc()
219 |
220 | # 节点组合成路径
221 | while curr.parent is not None:
222 | self.path_list.append(curr)
223 | curr = curr.parent
224 | self.path_list.reverse()
225 |
226 | # 需要重置
227 | self.__reset_flag = True
228 |
229 | return self.path_list
230 |
231 |
232 |
233 |
234 |
235 |
236 | # debug
237 | if __name__ == '__main__':
238 | p = BFS()()
239 | MAP.show_path(p)
240 |
--------------------------------------------------------------------------------
/DFS.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Thu Mar 30 16:45:58 2023
4 |
5 | @author: HJ
6 | """
7 |
8 | # 深度优先搜索(Depth First Search, DFS)算法
9 | from typing import Union
10 | from functools import lru_cache
11 | from dataclasses import dataclass
12 | from common import tic, toc, GridMap
13 |
14 |
15 | # 地图读取
16 | IMAGE_PATH = 'image.jpg' # 原图路径
17 | THRESH = 172 # 图片二值化阈值, 大于阈值的部分被置为255, 小于部分被置为0
18 | HIGHT = 350 # 地图高度
19 | WIDTH = 600 # 地图宽度
20 |
21 | MAP = GridMap(IMAGE_PATH, THRESH, HIGHT, WIDTH) # 栅格地图对象
22 |
23 | # 起点终点
24 | START = (290, 270) # 起点坐标 y轴向下为正
25 | END = (298, 150) # 终点坐标 y轴向下为正
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 | """ ---------------------------- Depth First Search算法 ---------------------------- """
34 |
35 |
36 |
37 |
38 | @dataclass(eq=False)
39 | class Node:
40 | """节点"""
41 |
42 | x: int
43 | y: int
44 | parent: "Node" = None
45 |
46 | def __sub__(self, other) -> int:
47 | """计算节点与坐标的曼哈顿距离"""
48 | if isinstance(other, Node):
49 | return abs(self.x - other.x) + abs(self.y - other.y)
50 | elif isinstance(other, (tuple, list)):
51 | return abs(self.x - other[0]) + abs(self.y - other[1])
52 | raise ValueError("other必须为坐标或Node")
53 |
54 | def __add__(self, other: Union[tuple, list]) -> "Node":
55 | """生成新节点"""
56 | x = self.x + other[0]
57 | y = self.y + other[1]
58 |
59 | return Node(x, y, self)
60 |
61 | def __eq__(self, other):
62 | """坐标x,y比较 -> node in close_list"""
63 | if isinstance(other, Node):
64 | return self.x == other.x and self.y == other.y
65 | elif isinstance(other, (tuple, list)):
66 | return self.x == other[0] and self.y == other[1]
67 | return False
68 |
69 | def __hash__(self) -> int:
70 | """使可变对象可hash, 能放入set中"""
71 | return hash((self.x, self.y)) # tuple 可 hash
72 | # data in set 时间复杂度为 O(1), 但 data必须可hash
73 | # data in list 时间复杂度 O(n)
74 |
75 |
76 |
77 |
78 | # NOTE 深度优先先进后出, 用列表即可
79 |
80 |
81 | # 深度优先搜索算法
82 | class DFS:
83 | """DFS算法"""
84 |
85 | def __init__(
86 | self,
87 | start_pos = START,
88 | end_pos = END,
89 | map_array = MAP.map_array,
90 | move_step = 5,
91 | move_direction = 8,
92 | ):
93 | """DFS算法
94 |
95 | Parameters
96 | ----------
97 | start_pos : tuple/list
98 | 起点坐标
99 | end_pos : tuple/list
100 | 终点坐标
101 | map_array : ndarray
102 | 二值化地图, 0表示障碍物, 255表示空白, H*W维
103 | move_step : int
104 | 移动步数, 默认5
105 | move_direction : int (8 or 4)
106 | 移动方向, 默认8个方向
107 | """
108 | # 网格化地图
109 | self.map_array = map_array # H * W
110 |
111 | self.width = self.map_array.shape[1]
112 | self.high = self.map_array.shape[0]
113 |
114 | # 起点终点
115 | self.start = Node(*start_pos) # 初始位置
116 | self.end = Node(*end_pos) # 结束位置
117 |
118 | # Error Check
119 | if not self._in_map(self.start) or not self._in_map(self.end):
120 | raise ValueError(f"x坐标范围0~{self.width-1}, y坐标范围0~{self.height-1}")
121 | if self._is_collided(self.start):
122 | raise ValueError(f"起点x坐标或y坐标在障碍物上")
123 | if self._is_collided(self.end):
124 | raise ValueError(f"终点x坐标或y坐标在障碍物上")
125 |
126 | # 算法初始化
127 | self.reset(move_step, move_direction)
128 |
129 |
130 | def reset(self, move_step=3, move_direction=8):
131 | """重置算法"""
132 | self.__reset_flag = False
133 | self.move_step = move_step # 移动步长(搜索后期会减小)
134 | self.move_direction = move_direction # 移动方向 8 个
135 | self.close_set = set() # 存储已经走过的位置及其G值
136 | self.open_list = [] # 存储当前位置周围可行的位置及其F值
137 | self.path_list = [] # 存储路径(CloseList里的数据无序)
138 |
139 |
140 | def search(self):
141 | """搜索路径"""
142 | return self.__call__()
143 |
144 |
145 | def _in_map(self, node: Node):
146 | """点是否在网格地图中"""
147 | return (0 <= node.x < self.width) and (0 <= node.y < self.high) # 右边不能取等!!!
148 |
149 |
150 | def _is_collided(self, node: Node):
151 | """点是否和障碍物碰撞"""
152 | return self.map_array[node.y, node.x] == 0
153 |
154 |
155 | def _move(self):
156 | """移动点"""
157 | @lru_cache(maxsize=3) # 避免参数相同时重复计算
158 | def _move(move_step:int, move_direction:int):
159 | move = [
160 | (0, move_step), # 上
161 | (0, -move_step), # 下
162 | (-move_step, 0), # 左
163 | (move_step, 0), # 右
164 | (move_step, move_step), # 右上
165 | (move_step, -move_step), # 右下
166 | (-move_step, move_step), # 左上
167 | (-move_step, -move_step), # 左下
168 | ]
169 | return move[0:move_direction] # 坐标增量
170 | return _move(self.move_step, self.move_direction)[::-1] # 后入先出, 斜着搜索太慢, 把直的放后面
171 |
172 |
173 | def _update_open_list(self, curr: Node):
174 | """open_list添加可行点"""
175 | for add in self._move():
176 | # 更新节点
177 | next_ = curr + add # x、y、cost、parent都更新了
178 |
179 | # 新位置是否在地图外边
180 | if not self._in_map(next_):
181 | continue
182 | # 新位置是否碰到障碍物
183 | if self._is_collided(next_):
184 | continue
185 | # 新位置是否在 CloseList 和 OpenList 中
186 | if next_ in self.close_set or next_ in self.open_list:
187 | continue
188 |
189 | # open-list添加结点
190 | self.open_list.append(next_)
191 |
192 | # 当剩余距离小时, 走慢一点
193 | if (next_ - self.end) < 20:
194 | self.move_step = 1
195 |
196 |
197 | def __call__(self):
198 | """DFS路径搜索"""
199 | assert not self.__reset_flag, "call之前需要reset"
200 | print("搜索中\n")
201 |
202 | # 初始化 OpenList
203 | self.open_list.append(self.start)
204 |
205 | # 正向搜索节点
206 | tic()
207 | while self.open_list:
208 | # 弹出 OpenList 最后的节点
209 | curr = self.open_list.pop()
210 | # 更新 OpenList
211 | self._update_open_list(curr)
212 | # 更新 CloseList
213 | self.close_set.add(curr)
214 | # 结束迭代
215 | if curr == self.end:
216 | break
217 | print("路径搜索完成\n")
218 | toc()
219 |
220 | # 节点组合成路径
221 | while curr.parent is not None:
222 | self.path_list.append(curr)
223 | curr = curr.parent
224 | self.path_list.reverse()
225 |
226 | # 需要重置
227 | self.__reset_flag = True
228 |
229 | return self.path_list
230 |
231 |
232 |
233 |
234 |
235 |
236 |
237 |
238 |
239 | # debug
240 | if __name__ == '__main__':
241 | p = DFS()()
242 | MAP.show_path(p)
243 |
244 |
--------------------------------------------------------------------------------
/Dijkstra.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Thu Mar 30 16:45:58 2023
4 |
5 | @author: HJ
6 | """
7 |
8 | # 迪杰斯特拉(Dijkstra)算法
9 | # A*: F = G + H
10 | # Dijkstra: F = G
11 | # https://zhuanlan.zhihu.com/p/346666812
12 | from functools import lru_cache
13 | from common import *
14 |
15 | Queue_Type = 2
16 | """
17 | # OpenList 采用的 PriorityQueue 的结构
18 | ## 0 -> SetQueue
19 | ## 1 -> ListQueue
20 | ## 2 -> PriorityQueuePro
21 | List/Set可以实现更新OpenList中节点的parent和cost, 找到的路径更优\n
22 | PriorityQueuePro速度最快, 但无法更新信息, 路径较差\n
23 | List速度最慢, Set速度接近PriorityQueuePro甚至更快\n
24 | """
25 |
26 | # 地图读取
27 | IMAGE_PATH = 'image.jpg' # 原图路径
28 | THRESH = 172 # 图片二值化阈值, 大于阈值的部分被置为255, 小于部分被置为0
29 | HIGHT = 350 # 地图高度
30 | WIDTH = 600 # 地图宽度
31 |
32 | MAP = GridMap(IMAGE_PATH, THRESH, HIGHT, WIDTH) # 栅格地图对象
33 |
34 | # 起点终点
35 | START = (290, 270) # 起点坐标 y轴向下为正
36 | END = (298, 150) # 终点坐标 y轴向下为正
37 |
38 |
39 |
40 |
41 |
42 | """ ---------------------------- Dijkstra算法 ---------------------------- """
43 | # F = G + 0
44 |
45 |
46 | # 设置OpenList使用的优先队列
47 | if Queue_Type == 0:
48 | NodeQueue = SetQueue
49 | elif Queue_Type == 1:
50 | NodeQueue = ListQueue
51 | else:
52 | NodeQueue = PriorityQueuePro
53 |
54 |
55 | # 迪杰斯特拉算法
56 | class Dijkstra:
57 | """Dijkstra算法"""
58 |
59 | def __init__(
60 | self,
61 | start_pos = START,
62 | end_pos = END,
63 | map_array = MAP.map_array,
64 | move_step = 3,
65 | move_direction = 8,
66 | ):
67 | """Dijkstra算法
68 |
69 | Parameters
70 | ----------
71 | start_pos : tuple/list
72 | 起点坐标
73 | end_pos : tuple/list
74 | 终点坐标
75 | map_array : ndarray
76 | 二值化地图, 0表示障碍物, 255表示空白, H*W维
77 | move_step : int
78 | 移动步数, 默认3
79 | move_direction : int (8 or 4)
80 | 移动方向, 默认8个方向
81 | """
82 | # 网格化地图
83 | self.map_array = map_array # H * W
84 |
85 | self.width = self.map_array.shape[1]
86 | self.high = self.map_array.shape[0]
87 |
88 | # 起点终点
89 | self.start = Node(*start_pos) # 初始位置
90 | self.end = Node(*end_pos) # 结束位置
91 |
92 | # Error Check
93 | if not self._in_map(self.start) or not self._in_map(self.end):
94 | raise ValueError(f"x坐标范围0~{self.width-1}, y坐标范围0~{self.height-1}")
95 | if self._is_collided(self.start):
96 | raise ValueError(f"起点x坐标或y坐标在障碍物上")
97 | if self._is_collided(self.end):
98 | raise ValueError(f"终点x坐标或y坐标在障碍物上")
99 |
100 | # 算法初始化
101 | self.reset(move_step, move_direction)
102 |
103 |
104 | def reset(self, move_step=3, move_direction=8):
105 | """重置算法"""
106 | self.__reset_flag = False
107 | self.move_step = move_step # 移动步长(搜索后期会减小)
108 | self.move_direction = move_direction # 移动方向 8 个
109 | self.close_set = set() # 存储已经走过的位置及其G值
110 | self.open_queue = NodeQueue() # 存储当前位置周围可行的位置及其F值
111 | self.path_list = [] # 存储路径(CloseList里的数据无序)
112 |
113 |
114 | def search(self):
115 | """搜索路径"""
116 | return self.__call__()
117 |
118 |
119 | def _in_map(self, node: Node):
120 | """点是否在网格地图中"""
121 | return (0 <= node.x < self.width) and (0 <= node.y < self.high) # 右边不能取等!!!
122 |
123 |
124 | def _is_collided(self, node: Node):
125 | """点是否和障碍物碰撞"""
126 | return self.map_array[node.y, node.x] < 1
127 |
128 |
129 | def _move(self):
130 | """移动点"""
131 | @lru_cache(maxsize=3) # 避免参数相同时重复计算
132 | def _move(move_step:int, move_direction:int):
133 | move = (
134 | (0, move_step), # 上
135 | (0, -move_step), # 下
136 | (-move_step, 0), # 左
137 | (move_step, 0), # 右
138 | (move_step, move_step), # 右上
139 | (move_step, -move_step), # 右下
140 | (-move_step, move_step), # 左上
141 | (-move_step, -move_step), # 左下
142 | )
143 | return move[0:move_direction] # 坐标增量+代价
144 | return _move(self.move_step, self.move_direction)
145 |
146 |
147 | def _update_open_list(self, curr: Node):
148 | """open_list添加可行点"""
149 | for add in self._move():
150 | # 更新可行位置
151 | next_ = curr + add
152 |
153 | # 新位置是否在地图外边
154 | if not self._in_map(next_):
155 | continue
156 | # 新位置是否碰到障碍物
157 | if self._is_collided(next_):
158 | continue
159 | # 新位置是否在 CloseList 中
160 | if next_ in self.close_set:
161 | continue
162 |
163 | # open-list添加/更新结点
164 | self.open_queue.put(next_)
165 |
166 | # 当剩余距离小时, 走慢一点
167 | if next_ - self.end < 20:
168 | self.move_step = 1
169 |
170 |
171 | def __call__(self):
172 | """Dijkstra路径搜索"""
173 | assert not self.__reset_flag, "call之前需要reset"
174 | print("搜索中\n")
175 |
176 | # 初始化列表
177 | self.open_queue.put(self.start) # 初始化 OpenList
178 |
179 | # 正向搜索节点(CloseList里的数据无序)
180 | tic()
181 | while not self.open_queue.empty():
182 | # 弹出 OpenList 代价 G 最小的点
183 | curr = self.open_queue.get()
184 | # 更新 OpenList
185 | self._update_open_list(curr)
186 | # 更新 CloseList
187 | self.close_set.add(curr)
188 | # 结束迭代
189 | if curr == self.end:
190 | break
191 | print("路径搜索完成\n")
192 | toc()
193 |
194 | # 节点组合成路径
195 | while curr.parent is not None:
196 | self.path_list.append(curr)
197 | curr = curr.parent
198 | self.path_list.reverse()
199 |
200 | # 需要重置
201 | self.__reset_flag = True
202 |
203 | return self.path_list
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 | # debug
214 | if __name__ == '__main__':
215 | p = Dijkstra()()
216 | MAP.show_path(p)
--------------------------------------------------------------------------------
/GBFS.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Thu Mar 30 16:45:58 2023
4 |
5 | @author: HJ
6 | """
7 |
8 | # 贪婪最佳优先搜索(Greedy Best First Search, GBFS)算法
9 | # A*: F = G + H
10 | # GBFS: F = H
11 | # https://zhuanlan.zhihu.com/p/346666812
12 | from functools import lru_cache
13 | from common import *
14 |
15 | Queue_Type = 0
16 | """
17 | # OpenList 采用的 PriorityQueue 的结构
18 | ## 0 -> SetQueue
19 | ## 1 -> ListQueue
20 | ## 2 -> PriorityQueuePro
21 | List/Set可以实现更新OpenList中节点的parent和cost, 找到的路径更优\n
22 | PriorityQueuePro速度最快, 但无法更新信息, 路径较差\n
23 | List速度最慢, Set速度接近PriorityQueuePro甚至更快\n
24 | """
25 |
26 |
27 | # 地图读取
28 | IMAGE_PATH = 'image.jpg' # 原图路径
29 | THRESH = 172 # 图片二值化阈值, 大于阈值的部分被置为255, 小于部分被置为0
30 | HIGHT = 350 # 地图高度
31 | WIDTH = 600 # 地图宽度
32 |
33 | MAP = GridMap(IMAGE_PATH, THRESH, HIGHT, WIDTH) # 栅格地图对象
34 |
35 | # 起点终点
36 | START = (290, 270) # 起点坐标 y轴向下为正
37 | END = (298, 150) # 终点坐标 y轴向下为正
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 | """ ---------------------------- Greedy Best First Search算法 ---------------------------- """
46 | # F = H
47 |
48 |
49 | # 设置OpenList使用的优先队列
50 | if Queue_Type == 0:
51 | NodeQueue = SetQueue
52 | elif Queue_Type == 1:
53 | NodeQueue = ListQueue
54 | else:
55 | NodeQueue = PriorityQueuePro
56 |
57 |
58 | # 贪婪最佳优先搜索算法
59 | class GBFS:
60 | """GBFS算法"""
61 |
62 | def __init__(
63 | self,
64 | start_pos = START,
65 | end_pos = END,
66 | map_array = MAP.map_array,
67 | move_step = 3,
68 | move_direction = 8,
69 | ):
70 | """GBFS算法
71 |
72 | Parameters
73 | ----------
74 | start_pos : tuple/list
75 | 起点坐标
76 | end_pos : tuple/list
77 | 终点坐标
78 | map_array : ndarray
79 | 二值化地图, 0表示障碍物, 255表示空白, H*W维
80 | move_step : int
81 | 移动步数, 默认3
82 | move_direction : int (8 or 4)
83 | 移动方向, 默认8个方向
84 | """
85 | # 网格化地图
86 | self.map_array = map_array # H * W
87 |
88 | self.width = self.map_array.shape[1]
89 | self.high = self.map_array.shape[0]
90 |
91 | # 起点终点
92 | self.start = Node(*start_pos) # 初始位置
93 | self.end = Node(*end_pos) # 结束位置
94 |
95 | # Error Check
96 | if not self._in_map(self.start) or not self._in_map(self.end):
97 | raise ValueError(f"x坐标范围0~{self.width-1}, y坐标范围0~{self.height-1}")
98 | if self._is_collided(self.start):
99 | raise ValueError(f"起点x坐标或y坐标在障碍物上")
100 | if self._is_collided(self.end):
101 | raise ValueError(f"终点x坐标或y坐标在障碍物上")
102 |
103 | # 算法初始化
104 | self.reset(move_step, move_direction)
105 |
106 |
107 | def reset(self, move_step=3, move_direction=8):
108 | """重置算法"""
109 | self.__reset_flag = False
110 | self.move_step = move_step # 移动步长(搜索后期会减小)
111 | self.move_direction = move_direction # 移动方向 8 个
112 | self.close_set = set() # 存储已经走过的位置及其G值
113 | self.open_queue = NodeQueue() # 存储当前位置周围可行的位置及其F值
114 | self.path_list = [] # 存储路径(CloseList里的数据无序)
115 |
116 |
117 | def search(self):
118 | """搜索路径"""
119 | return self.__call__()
120 |
121 |
122 | def _in_map(self, node: Node):
123 | """点是否在网格地图中"""
124 | return (0 <= node.x < self.width) and (0 <= node.y < self.high) # 右边不能取等!!!
125 |
126 |
127 | def _is_collided(self, node: Node):
128 | """点是否和障碍物碰撞"""
129 | return self.map_array[node.y, node.x] == 0
130 |
131 |
132 | def _move(self):
133 | """移动点"""
134 | @lru_cache(maxsize=3) # 避免参数相同时重复计算
135 | def _move(move_step:int, move_direction:int):
136 | move = (
137 | [0, move_step], # 上
138 | [0, -move_step], # 下
139 | [-move_step, 0], # 左
140 | [move_step, 0], # 右
141 | [move_step, move_step], # 右上
142 | [move_step, -move_step], # 右下
143 | [-move_step, move_step], # 左上
144 | [-move_step, -move_step], # 左下
145 | )
146 | return move[0:move_direction] # 坐标增量+代价
147 | return _move(self.move_step, self.move_direction)
148 |
149 |
150 | def _update_open_list(self, curr: Node):
151 | """open_list添加可行点"""
152 | for add in self._move():
153 | # 更新可行位置
154 | next_ = curr + add
155 |
156 | # 新位置是否在地图外边
157 | if not self._in_map(next_):
158 | continue
159 | # 新位置是否碰到障碍物
160 | if self._is_collided(next_):
161 | continue
162 | # 新位置是否在 CloseList 中
163 | if next_ in self.close_set:
164 | continue
165 |
166 | # 计算所添加的结点的代价
167 | H = next_ - self.end # 剩余距离估计
168 | next_.cost = H # G = 0
169 |
170 | # open-list添加/更新结点
171 | self.open_queue.put(next_)
172 |
173 | # 当剩余距离小时, 走慢一点
174 | if H < 20:
175 | self.move_step = 1
176 |
177 |
178 | def __call__(self):
179 | """GBFS路径搜索"""
180 | assert not self.__reset_flag, "call之前需要reset"
181 | print("搜索中\n")
182 |
183 | # 初始化列表
184 | self.open_queue.put(self.start) # 初始化 OpenList
185 |
186 | # 正向搜索节点(CloseList里的数据无序)
187 | tic()
188 | while not self.open_queue.empty():
189 | # 弹出 OpenList 代价 H 最小的点
190 | curr = self.open_queue.get()
191 | # 更新 OpenList
192 | self._update_open_list(curr)
193 | # 更新 CloseList
194 | self.close_set.add(curr) # G始终为0
195 | # 结束迭代
196 | if curr == self.end:
197 | break
198 | print("路径搜索完成\n")
199 | toc()
200 |
201 | # 节点组合成路径
202 | while curr.parent is not None:
203 | self.path_list.append(curr)
204 | curr = curr.parent
205 | self.path_list.reverse()
206 |
207 | # 需要重置
208 | self.__reset_flag = True
209 |
210 | return self.path_list
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 | # debug
220 | if __name__ == '__main__':
221 | p = GBFS()()
222 | MAP.show_path(p)
223 |
--------------------------------------------------------------------------------
/HybridA_star.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Thu Mar 30 16:45:58 2023
4 |
5 | @author: HJ
6 | """
7 |
8 | # Hybrid A*算法
9 | import math
10 | import numpy as np
11 | from dataclasses import dataclass
12 | from itertools import product
13 | from copy import deepcopy
14 | from common import SetQueue, GridMap, tic, toc, limit_angle
15 |
16 |
17 |
18 |
19 | # 地图读取
20 | IMAGE_PATH = 'image1.jpg' # 原图路径
21 | THRESH = 172 # 图片二值化阈值, 大于阈值的部分被置为255, 小于部分被置为0
22 | MAP_HIGHT = 70 # 地图高度 (1)
23 | MAP_WIDTH = 120 # 地图宽度 (1)
24 |
25 | MAP = GridMap(IMAGE_PATH, THRESH, MAP_HIGHT, MAP_WIDTH) # 栅格地图对象
26 |
27 | # 栅格化位置和方向
28 | MAP_NORM = 1.0 # 地图一个像素表示多少米 (m/1) #! BUG MAP_NORM不为 1 时绘图鸡哥背景错位
29 | YAW_NORM = math.pi / 6 # 每多少rad算同一个角度 (rad/1)
30 |
31 | # 起点终点设置
32 | START = [5.0, 35.0, -math.pi/6] # 起点 (x, y, yaw), y轴向下为正, yaw顺时针为正
33 | END = [115.0, 60.0, math.pi/2] # 终点 (x, y, yaw), y轴向下为正, yaw顺时针为正
34 | ERR = 0.5 # 与终点距离小于 ERR 米时停止搜索
35 |
36 | # 车辆模型
37 | CAR_LENGTH = 4.5 # 车辆长度 (m)
38 | CAR_WIDTH = 2.0 # 车辆宽度 (m)
39 | CAR_MAX_STEER = math.radians(30) # 最大转角 (rad)
40 | CAR_MAX_SPEED = 8 # 最大速度 (m/s)
41 |
42 |
43 | # 定义运动模型
44 | def motion_model(s, u, dt):
45 | """
46 | >>> u = [v, δ]
47 | >>> dx/dt = v * cos(θ)
48 | >>> dy/dt = v * sin(θ)
49 | >>> dθ/dt = v/L * tan(δ)
50 | """
51 | s = deepcopy(s)
52 | s[0] += u[0] * math.cos(s[2]) * dt
53 | s[1] += u[0] * math.sin(s[2]) * dt
54 | s[2] += u[0]/CAR_LENGTH * math.tan(u[1]) * dt
55 | s[2] = limit_angle(s[2])
56 | return s
57 |
58 |
59 |
60 |
61 | # 坐标节点
62 | @dataclass(eq=False)
63 | class HybridNode:
64 | """节点"""
65 |
66 | x: float
67 | y: float
68 | yaw: float
69 |
70 | G: float = 0. # G代价
71 | cost: float = None # F代价
72 | parent: "HybridNode" = None # 父节点指针
73 |
74 | def __post_init__(self):
75 | # 坐标和方向栅格化
76 | self.x_idx = round(self.x / MAP_NORM) # int向下取整, round四舍五入
77 | self.y_idx = round(self.y / MAP_NORM)
78 | self.yaw_idx = round(self.yaw / YAW_NORM)
79 | if self.cost is None:
80 | self.cost = self.calculate_heuristic([self.x, self.y], END)
81 |
82 | def __call__(self, u, dt):
83 | # 生成新节点 -> new_node = node(u, dt)
84 | x, y, yaw = motion_model([self.x, self.y, self.yaw], u, dt)
85 | G = self.G + self.calculate_distance([self.x, self.y], [x, y]) + abs(yaw - self.yaw)
86 | return HybridNode(x, y, yaw, G, parent=self)
87 |
88 | def __eq__(self, other: "HybridNode"):
89 | # 节点eq比较 -> node in list
90 | return self.x_idx == other.x_idx and self.y_idx == other.y_idx and self.yaw_idx == other.yaw_idx
91 | #return self.__hash__() == hash(other)
92 |
93 | def __le__(self, other: "HybridNode"):
94 | # 代价<=比较 -> min(open_list)
95 | return self.cost <= other.cost
96 |
97 | def __lt__(self, other: "HybridNode"):
98 | # 代价<比较 -> min(open_list)
99 | return self.cost < other.cost
100 |
101 | def __hash__(self) -> int:
102 | # 节点hash比较 -> node in set
103 | return hash((self.x_idx, self.y_idx, self.yaw_idx))
104 |
105 | def heuristic(self, TARG = END):
106 | """启发搜索, 计算启发值H并更新F值"""
107 | H = self.calculate_heuristic([self.x, self.y], TARG)
108 | self.cost = self.G + H
109 | return H
110 |
111 | def is_end(self, err = ERR):
112 | """是否终点, 启发值H小于err"""
113 | if self.cost - self.G < err:
114 | return True
115 | return False
116 |
117 | def in_map(self, map_array = MAP.map_array):
118 | """是否在地图中"""
119 | return (0 <= self.x < map_array.shape[1]) and (0 <= self.y < map_array.shape[0]) # h*w维, 右边不能取等!!!
120 |
121 | def is_collided(self, map_array = MAP.map_array):
122 | """是否发生碰撞"""
123 | # 计算车辆的边界框的四个顶点坐标
124 | cos_ = math.cos(self.yaw)
125 | sin_ = math.sin(self.yaw)
126 | LC = CAR_LENGTH/2 * cos_
127 | LS = CAR_LENGTH/2 * sin_
128 | WC = CAR_WIDTH/2 * cos_
129 | WS = CAR_WIDTH/2 * sin_
130 | x1 = self.x + LC + WS
131 | y1 = self.y - LS + WC
132 | x2 = self.x + LC - WS
133 | y2 = self.y - LS - WC
134 | x3 = self.x - LC + WS
135 | y3 = self.y + LS + WC
136 | x4 = self.x - LC - WS
137 | y4 = self.y + LS - WC
138 | # 检查边界框所覆盖的栅格是否包含障碍物和出界
139 | for i in range(int(min([x1, x2, x3, x4])/MAP_NORM), int(max([x1, x2, x3, x4])/MAP_NORM)):
140 | for j in range(int(min([y1, y2, y3, y4])/MAP_NORM), int(max([y1, y2, y3, y4])/MAP_NORM)):
141 | if i < 0 or i >= map_array.shape[1]:
142 | return True
143 | if j < 0 or j >= map_array.shape[0]:
144 | return True
145 | if map_array[j, i] == 0: # h*w维, y是第一个索引, 0表示障碍物
146 | return True
147 | return False
148 |
149 | @staticmethod
150 | def calculate_distance(P1, P2):
151 | """欧氏距离"""
152 | return math.hypot(P1[0] - P2[0], P1[1] - P2[1])
153 |
154 | @classmethod
155 | def calculate_heuristic(cls, P, TARG):
156 | """启发函数"""
157 | return cls.calculate_distance(P, TARG) # 欧式距离
158 | #return abs(P[0]-TARG[0]) + abs(P[1]-TARG[1]) # 曼哈顿距离
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 | """ ---------------------------- Hybrid A*算法 ---------------------------- """
172 | # F = G + H
173 |
174 |
175 |
176 |
177 | # 混合A*算法
178 | class HybridAStar:
179 | """混合A*算法"""
180 |
181 | def __init__(self, num_speed=3, num_steer=3, move_step=2, dt=0.2):
182 | """混合A*算法
183 |
184 | Parameters
185 | ----------
186 | num_speed : int
187 | 控制量 v 离散个数, num>=1
188 | num_steer : int
189 | 控制量 δ 离散个数, num>=2
190 | move_step : int
191 | 向后搜索的次数
192 | dt : float
193 | 决策周期
194 | """
195 |
196 | # 起点
197 | self.start = HybridNode(*START) # 起点
198 | self.start.heuristic() # 更新 F 代价
199 |
200 | # Error Check
201 | end = HybridNode(*END)
202 | if not self.start.in_map() or not end.in_map():
203 | raise ValueError(f"x坐标y坐标超出地图边界")
204 | if self.start.is_collided():
205 | raise ValueError(f"起点x坐标或y坐标在障碍物上")
206 | if end.is_collided():
207 | raise ValueError(f"终点x坐标或y坐标在障碍物上")
208 |
209 | # 算法初始化
210 | self.reset(num_speed, num_steer, move_step, dt)
211 |
212 |
213 | def reset(self, num_speed=3, num_steer=3, move_step=2, dt=0.2):
214 | """重置算法"""
215 | self.__reset_flag = False
216 | assert num_steer > 1, "转向离散个数必须大于1"
217 | self.u_all = [
218 | np.linspace(CAR_MAX_SPEED, 0, num_speed) if num_speed > 1 else np.array([CAR_MAX_SPEED]),
219 | np.linspace(-CAR_MAX_STEER, CAR_MAX_STEER, num_steer),
220 | ]
221 | self.dt = dt
222 | self.move_step = move_step
223 | self.close_set = set() # 存储已经走过的位置及其G值
224 | self.open_queue = SetQueue() # 存储当前位置周围可行的位置及其F值
225 | self.path_list = [] # 存储路径(CloseList里的数据无序)
226 |
227 |
228 | def search(self):
229 | """搜索路径"""
230 | return self.__call__()
231 |
232 |
233 | def _update_open_list(self, curr: HybridNode):
234 | """open_list添加可行点"""
235 | for v, delta in product(*self.u_all):
236 | # 更新节点
237 | next_ = curr
238 | for _ in range(self.move_step):
239 | next_ = next_([v, delta], self.dt) # x、y、yaw、G_cost、parent都更新了, F_cost未更新
240 |
241 | # 新位置是否在地图外边
242 | if not next_.in_map():
243 | continue
244 | # 新位置是否碰到障碍物
245 | if next_.is_collided():
246 | continue
247 | # 新位置是否在 CloseList 中
248 | if next_ in self.close_set:
249 | continue
250 |
251 | # 更新F代价
252 | H = next_.heuristic()
253 |
254 | # open-list添加/更新结点
255 | self.open_queue.put(next_)
256 |
257 | # 当剩余距离小时, 走慢一点
258 | if H < 20:
259 | self.move_step = 1
260 |
261 |
262 | def __call__(self):
263 | """A*路径搜索"""
264 | assert not self.__reset_flag, "call之前需要reset"
265 | print("搜索中\n")
266 |
267 | # 初始化 OpenList
268 | self.open_queue.put(self.start)
269 |
270 | # 正向搜索节点
271 | tic()
272 | while not self.open_queue.empty():
273 | # 弹出 OpenList 代价 F 最小的点
274 | curr: HybridNode = self.open_queue.get()
275 | # 更新 OpenList
276 | self._update_open_list(curr)
277 | # 更新 CloseList
278 | self.close_set.add(curr)
279 | # 结束迭代
280 | if curr.is_end():
281 | break
282 | print("路径搜索完成\n")
283 | toc()
284 |
285 | # 节点组合成路径
286 | while curr.parent is not None:
287 | self.path_list.append(curr)
288 | curr = curr.parent
289 | self.path_list.reverse()
290 |
291 | # 需要重置
292 | self.__reset_flag = True
293 |
294 | return self.path_list
295 |
296 |
297 |
298 |
299 |
300 |
301 |
302 |
303 |
304 |
305 |
306 | # debug
307 | if __name__ == '__main__':
308 | p = HybridAStar()()
309 | MAP.show_path(p)
310 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 路径规划算法
2 |
3 | ## 算法:
4 |
5 | | 算法 | file | 类别 | 性质 | OpenList数据结构 |
6 | | ------------------------------------------------------------ | ----------- | ------------------- | --------- | ---------------- |
7 | | A星算法
(A*) | A_star.py | 启发搜索
F=G+H | 最优+最速 | PriorityQueue |
8 | | 混合A星算法
(Hybrid A*) | HybridA_star.py | 启发搜索
F=G+H | 考虑车辆运动学约束 | PriorityQueue |
9 | | 迪杰斯特拉搜索算法
(Dijkstra) | Dijkstra.py | 启发搜索
F=G | 最优 | PriorityQueue |
10 | | 贪婪最佳优先搜索算法
(Greedy Best First Search, GBFS) | GBFS.py | 启发搜索
F=H | 最速 | PriorityQueue |
11 | | 广度优先搜索算法
(Breadth First Search, BFS) | BFS.py | 遍历搜索 | 最优 | deque 先入先出 |
12 | | 深度优先搜索算法
(Depth First Search, DFS) | DFS.py | 遍历搜索 | 最速 | list 后入先出 |
13 | | 概率路图算法
(Probabilistic Road Map, PRM) | | 采样 | | |
14 | | 快速随机扩展树算法
(Rapidly-exploring Random Tree, RRT) | | 采样 | | |
15 |
16 | ###### 备注:
17 |
18 | * 原版 PriorityQueue:无法更新Node,路径较长,但搜索速度贼快
19 | * set版 PriorityQueue:动态更新Node,路径最短,速度接近或超过优先队列(要求Node可hash)
20 | * list版 PriorityQueue:动态更新Node,路径最短,速度最慢(in和pop时间复杂度比set的in和remove大)
21 |
22 | ## 用法:
23 |
24 | * 在草纸上随便画点障碍物,拍照上传替换鲲鲲图片 image.jpg,在 A_star.py 等脚本中设置起点终点等参数,运行即可.
25 | * 程序并没有设置复杂的继承/依赖关系,只需要如 common.py + A_star.py + image.jpg 三个文件在同一目录就能运行.
26 |
27 | ## 效果:
28 |
29 | **复杂障碍物地图下的路径规划结果(只能看一眼,不然会爆炸)**
30 |
31 | ### 混合A*算法:(考虑车辆运动学约束)
32 |
33 | 
34 |
35 | 
36 |
37 | ### A*算法:(介于最优和快速之间)
38 |
39 | * List耗时0.67s,Set耗时0.45s,PriorityQueue耗时0.48s,步长3
40 | * 由于List/Set存储结构能动态更新OpenList中Node的cost和parent信息,路径会更优
41 |
42 | ###### List/Set存储结构:
43 |
44 | 
45 |
46 | ###### PriorityQueue存储结构:
47 |
48 | 
49 |
50 | ### Dijkstra算法:(最优路径,耗时较大)
51 |
52 | * List耗时81s,Set耗时15.6s,PriorityQueue耗时15s,步长3
53 |
54 | 
55 |
56 | ### GBFS算法:(路径较差,速度贼快)
57 |
58 | * List耗时0.16s,Set耗时0.12s,PriorityQueue耗时0.13s,步长3
59 |
60 | 
61 |
62 | ### BFS算法:(最优路径,耗时较大)
63 |
64 | * Deque耗时8.92s,步长3
65 |
66 | 
67 |
68 | ### DFS算法:(最烂路径,速度较快)
69 |
70 | * List耗时1.96s,步长5
71 |
72 | 
73 |
74 | * List耗时15.07s,步长3
75 |
76 | 
77 |
78 | ### PRM算法:
79 |
80 | raise NotImplementedError
81 |
82 | ### RRT算法:
83 |
84 | raise NotImplementedError
85 |
86 | ## Requirement:
87 |
88 | python >= 3.9
89 |
90 | opencv-python >= 4.7.0.72
91 |
92 | matplotlib >= 3.5.1
93 |
94 | numpy >= 1.22.3
95 |
96 | ###### 广告:
97 |
98 | [DRL-for-Path-Planning: 深度强化学习路径规划, SAC路径规划](https://github.com/zhaohaojie1998/DRL-for-Path-Planning)
99 |
100 | [Grey-Wolf-Optimizer-for-Path-Planning: 灰狼优化算法路径规划、多智能体/多无人机航迹规划](https://github.com/zhaohaojie1998/Grey-Wolf-Optimizer-for-Path-Planning)
101 |
--------------------------------------------------------------------------------
/common.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Fri May 26 2023 16:03:59
4 | Modified on 2023-5-26 16:03:59
5 |
6 | @auther: HJ https://github.com/zhaohaojie1998
7 | """
8 | # 算法共同组成部分
9 | from typing import Union
10 | import cv2
11 | import time
12 | import math
13 | import numpy as np
14 | import matplotlib.pyplot as plt
15 | from queue import PriorityQueue
16 | from dataclasses import dataclass, field
17 | Number = Union[int, float]
18 |
19 |
20 | __all__ = ['tic', 'toc', 'limit_angle', 'GridMap', 'PriorityQueuePro', 'ListQueue', 'SetQueue', 'Node']
21 |
22 |
23 |
24 |
25 | # 坐标节点
26 | @dataclass(eq=False)
27 | class Node:
28 | """节点"""
29 |
30 | x: int
31 | y: int
32 | cost: Number = 0 # F代价
33 | parent: "Node" = None # 父节点指针
34 |
35 | def __sub__(self, other) -> int:
36 | """计算节点与坐标的曼哈顿距离"""
37 | if isinstance(other, Node):
38 | return abs(self.x - other.x) + abs(self.y - other.y)
39 | elif isinstance(other, (tuple, list)):
40 | return abs(self.x - other[0]) + abs(self.y - other[1])
41 | raise ValueError("other必须为坐标或Node")
42 |
43 | def __add__(self, other: Union[tuple, list]) -> "Node":
44 | """生成新节点"""
45 | x = self.x + other[0]
46 | y = self.y + other[1]
47 | cost = self.cost + math.sqrt(other[0]**2 + other[1]**2) # 欧式距离
48 | return Node(x, y, cost, self)
49 |
50 | def __eq__(self, other):
51 | """坐标x,y比较 -> node in list"""
52 | if isinstance(other, Node):
53 | return self.x == other.x and self.y == other.y
54 | elif isinstance(other, (tuple, list)):
55 | return self.x == other[0] and self.y == other[1]
56 | return False
57 |
58 | def __le__(self, other: "Node"):
59 | """代价<=比较 -> min(open_list)"""
60 | return self.cost <= other.cost
61 |
62 | def __lt__(self, other: "Node"):
63 | """代价<比较 -> min(open_list)"""
64 | return self.cost < other.cost
65 |
66 | def __hash__(self) -> int:
67 | """使可变对象可hash, 能放入set中 -> node in set"""
68 | return hash((self.x, self.y)) # tuple可hash
69 | # data in set 时间复杂度为 O(1), 但data必须可hash
70 | # data in list 时间复杂度 O(n)
71 |
72 |
73 |
74 |
75 |
76 | # Set版优先队列
77 | @dataclass
78 | class SetQueue:
79 | """节点优先存储队列 set 版"""
80 |
81 | queue: set[Node] = field(default_factory=set)
82 |
83 | # Queue容器增强
84 | def __bool__(self):
85 | """判断: while Queue:"""
86 | return bool(self.queue)
87 |
88 | def __contains__(self, item):
89 | """包含: pos in Queue"""
90 | return item in self.queue
91 | #NOTE: in是值比较, 只看hash是否在集合, 不看id是否在集合
92 |
93 | def __len__(self):
94 | """长度: len(Queue)"""
95 | return len(self.queue)
96 |
97 | # PriorityQueue操作
98 | def get(self):
99 | """Queue 弹出代价最小节点"""
100 | node = min(self.queue) # O(n)?
101 | self.queue.remove(node) # O(1)
102 | return node
103 |
104 | def put(self, node: Node):
105 | """Queue 加入/更新节点"""
106 | if node in self.queue: # O(1)
107 | qlist = list(self.queue) # 索引元素, set无法索引需转换
108 | idx = qlist.index(node) # O(n)
109 | if node.cost < qlist[idx].cost: # 新节点代价更小则加入新节点
110 | self.queue.remove(node) # O(1)
111 | self.queue.add(node) # O(1) 移除node和加入node的hash相同, 但cost和parent不同
112 | else:
113 | self.queue.add(node) # O(1)
114 |
115 | def empty(self):
116 | """Queue 是否为空"""
117 | return len(self.queue) == 0
118 |
119 |
120 |
121 |
122 |
123 | # List版优先队列
124 | @dataclass
125 | class ListQueue:
126 | """节点优先存储队列 list 版"""
127 |
128 | queue: list[Node] = field(default_factory=list)
129 |
130 | # Queue容器增强
131 | def __bool__(self):
132 | """判断: while Queue:"""
133 | return bool(self.queue)
134 |
135 | def __contains__(self, item):
136 | """包含: pos in Queue"""
137 | return item in self.queue
138 | #NOTE: in是值比较, 只看value是否在列表, 不看id是否在列表
139 |
140 | def __len__(self):
141 | """长度: len(Queue)"""
142 | return len(self.queue)
143 |
144 | def __getitem__(self, idx):
145 | """索引: Queue[i]"""
146 | return self.queue[idx]
147 |
148 | # List操作
149 | def append(self, node: Node):
150 | """List 添加节点"""
151 | self.queue.append(node) # O(1)
152 |
153 | def pop(self, idx = -1):
154 | """List 弹出节点"""
155 | return self.queue.pop(idx) # O(1) ~ O(n)
156 |
157 | # PriorityQueue操作
158 | def get(self):
159 | """Queue 弹出代价最小节点"""
160 | idx = self.queue.index(min(self.queue)) # O(n) + O(n)
161 | return self.queue.pop(idx) # O(1) ~ O(n)
162 |
163 | def put(self, node: Node):
164 | """Queue 加入/更新节点"""
165 | if node in self.queue: # O(n)
166 | idx = self.queue.index(node) # O(n)
167 | if node.cost < self.queue[idx].cost: # 新节点代价更小
168 | self.queue[idx].cost = node.cost # O(1) 更新代价
169 | self.queue[idx].parent = node.parent # O(1) 更新父节点
170 | else:
171 | self.queue.append(node) # O(1)
172 |
173 | # NOTE try语法虽然时间复杂度更小, 但频繁抛出异常速度反而更慢
174 | # try:
175 | # idx = self.queue.index(node) # O(n)
176 | # if node.cost < self.queue[idx].cost: # 新节点代价更小
177 | # self.queue[idx].cost = node.cost # O(1) 更新代价
178 | # self.queue[idx].parent = node.parent # O(1) 更新父节点
179 | # except ValueError:
180 | # self.queue.append(node) # O(1)
181 |
182 | def empty(self):
183 | """Queue 是否为空"""
184 | return len(self.queue) == 0
185 |
186 |
187 |
188 |
189 | # 原版优先队列增强(原版也是list实现, 但get更快, put更慢)
190 | class PriorityQueuePro(PriorityQueue):
191 | """节点优先存储队列 原版"""
192 |
193 | # PriorityQueue操作
194 | def put(self, item, block=True, timeout=None):
195 | """Queue 加入/更新节点"""
196 | if item in self.queue: # O(n)
197 | return # 修改数据会破坏二叉树结构, 就不存了
198 | else:
199 | super().put(item, block, timeout) # O(logn)
200 |
201 | # Queue容器增强
202 | def __bool__(self):
203 | """判断: while Queue:"""
204 | return bool(self.queue)
205 |
206 | def __contains__(self, item):
207 | """包含: pos in Queue"""
208 | return item in self.queue
209 | #NOTE: in是值比较, 只看value是否在列表, 不看id是否在列表
210 |
211 | def __len__(self):
212 | """长度: len(Queue)"""
213 | return len(self.queue)
214 |
215 | def __getitem__(self, idx):
216 | """索引: Queue[i]"""
217 | return self.queue[idx]
218 |
219 |
220 |
221 |
222 |
223 | # 图像处理生成网格地图
224 | class GridMap:
225 | """从图片中提取栅格地图"""
226 |
227 | def __init__(
228 | self,
229 | img_path: str,
230 | thresh: int,
231 | high: int,
232 | width: int,
233 | ):
234 | """提取栅格地图
235 |
236 | Parameters
237 | ----------
238 | img_path : str
239 | 原图片路径
240 | thresh : int
241 | 图片二值化阈值, 大于阈值的部分被置为255, 小于部分被置为0
242 | high : int
243 | 栅格地图高度
244 | width : int
245 | 栅格地图宽度
246 | """
247 | # 存储路径
248 | self.__map_path = 'map.png' # 栅格地图路径
249 | self.__path_path = 'path.png' # 路径规划结果路径
250 |
251 | # 图像处理 # NOTE cv2 按 HWC 存储图片
252 | image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE) # 读取原图 H,W,C
253 | thresh, map_img = cv2.threshold(image, thresh, 255, cv2.THRESH_BINARY) # 地图二值化
254 | map_img = cv2.resize(map_img, (width, high)) # 设置地图尺寸
255 | cv2.imwrite(self.__map_path, map_img) # 存储二值地图
256 |
257 | # 栅格地图属性
258 | self.map_array = np.array(map_img)
259 | """ndarray地图, H*W, 0代表障碍物"""
260 | self.high = high
261 | """ndarray地图高度"""
262 | self.width = width
263 | """ndarray地图宽度"""
264 |
265 | def show_path(self, path_list, *, save = False):
266 | """路径规划结果绘制
267 |
268 | Parameters
269 | ----------
270 | path_list : list[Node]
271 | 路径节点组成的列表, 要求Node有x,y属性
272 | save : bool, optional
273 | 是否保存结果图片
274 | """
275 |
276 | if not path_list:
277 | print("\n传入空列表, 无法绘图\n")
278 | return
279 | if not hasattr(path_list[0], "x") or not hasattr(path_list[0], "y"):
280 | print("\n路径节点中没有坐标x或坐标y属性, 无法绘图\n")
281 | return
282 |
283 | x, y = [], []
284 | for p in path_list:
285 | x.append(p.x)
286 | y.append(p.y)
287 |
288 | fig, ax = plt.subplots()
289 | map_ = cv2.imread(self.__map_path)
290 | map_ = cv2.resize(map_, (self.width, self.high))
291 | #img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # R G B
292 | #img = img[:, :, ::-1] # R G B
293 | map_ = map_[::-1] # 画出来的鸡哥是反的, 需要转过来
294 | ax.imshow(map_, extent=[0, self.width, 0, self.high]) # extent[x_min, x_max, y_min, y_max]
295 | ax.plot(x, y, c = 'r', label='path', linewidth=2)
296 | ax.scatter(x[0], y[0], c='c', marker='o', label='start', s=40, linewidth=2)
297 | ax.scatter(x[-1], y[-1], c='c', marker='x', label='end', s=40, linewidth=2)
298 | ax.invert_yaxis() # 反转y轴
299 | ax.legend().set_draggable(True)
300 | plt.show()
301 | if save:
302 | plt.savefig(self.__path_path)
303 |
304 |
305 |
306 |
307 |
308 | # matlab计时器
309 | def tic():
310 | '''计时开始'''
311 | if 'global_tic_time' not in globals():
312 | global global_tic_time
313 | global_tic_time = []
314 | global_tic_time.append(time.time())
315 |
316 | def toc(name='', *, CN=True, digit=6):
317 | '''计时结束'''
318 | if 'global_tic_time' not in globals() or not global_tic_time: # 未设置全局变量或全局变量为[]
319 | print('未设置tic' if CN else 'tic not set')
320 | return
321 | name = name+' ' if (name and not CN) else name
322 | if CN:
323 | print('%s历时 %f 秒。\n' % (name, round(time.time() - global_tic_time.pop(), digit)))
324 | else:
325 | print('%sElapsed time is %f seconds.\n' % (name, round(time.time() - global_tic_time.pop(), digit)))
326 |
327 |
328 |
329 |
330 | # 角度归一化
331 | def limit_angle(x, mode=1):
332 | """
333 | mode1 : (-inf, inf) -> (-π, π]
334 | mode2 : (-inf, inf) -> [0, 2π)
335 | """
336 | x = x - x//(2*math.pi) * 2*math.pi # any -> [0, 2π)
337 | if mode == 1 and x > math.pi:
338 | return x - 2*math.pi # [0, 2π) -> (-π, π]
339 | return x
340 |
341 |
342 |
343 |
344 |
345 |
346 |
347 |
--------------------------------------------------------------------------------
/image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohaojie1998/Path-Planning/a6922951b82b067871093452f577f1a8c063a6b7/image.jpg
--------------------------------------------------------------------------------
/image1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohaojie1998/Path-Planning/a6922951b82b067871093452f577f1a8c063a6b7/image1.jpg
--------------------------------------------------------------------------------
/图片/astar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohaojie1998/Path-Planning/a6922951b82b067871093452f577f1a8c063a6b7/图片/astar.png
--------------------------------------------------------------------------------
/图片/astar_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohaojie1998/Path-Planning/a6922951b82b067871093452f577f1a8c063a6b7/图片/astar_1.png
--------------------------------------------------------------------------------
/图片/bfs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohaojie1998/Path-Planning/a6922951b82b067871093452f577f1a8c063a6b7/图片/bfs.png
--------------------------------------------------------------------------------
/图片/dfs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohaojie1998/Path-Planning/a6922951b82b067871093452f577f1a8c063a6b7/图片/dfs.png
--------------------------------------------------------------------------------
/图片/dfs_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohaojie1998/Path-Planning/a6922951b82b067871093452f577f1a8c063a6b7/图片/dfs_1.png
--------------------------------------------------------------------------------
/图片/dij.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohaojie1998/Path-Planning/a6922951b82b067871093452f577f1a8c063a6b7/图片/dij.png
--------------------------------------------------------------------------------
/图片/gbfs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohaojie1998/Path-Planning/a6922951b82b067871093452f577f1a8c063a6b7/图片/gbfs.png
--------------------------------------------------------------------------------
/图片/hybrida.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohaojie1998/Path-Planning/a6922951b82b067871093452f577f1a8c063a6b7/图片/hybrida.png
--------------------------------------------------------------------------------
/图片/hybrida1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhaohaojie1998/Path-Planning/a6922951b82b067871093452f577f1a8c063a6b7/图片/hybrida1.png
--------------------------------------------------------------------------------