├── .gitignore ├── A_Star ├── Astar.cpp ├── Astar.h └── main.cpp ├── B-spline ├── BSpline.cpp ├── BSpline.h └── main.cpp ├── Bezier ├── BezierCurve.cpp ├── BezierCurve.h └── main.cpp ├── CMakeLists.txt ├── Dijkstra ├── Dijkstra.cpp ├── Dijkstra.h └── main.cpp ├── README.assets ├── dijkstra_demo.gif ├── rrt_connect-16832995765544.gif ├── rrt_connect-16832995877726.gif ├── rrt_connect1.gif └── rrt_star_demo.gif ├── README.md ├── Rapidly-exploring_Random_Tree ├── RRT.cpp ├── RRT.h └── main.cpp ├── Rapidly-exploring_Random_Tree_Star ├── RRT_Star.cpp ├── RRT_Star.h └── main.cpp ├── Rapidly-exploring_Random_Tree_connect ├── RRT_connect.cpp ├── RRT_connect.h └── main.cpp ├── gif ├── astar.gif ├── b_spline_demo.gif ├── bezier_demo.gif ├── dijkstra_demo.gif ├── rrt_connect.gif ├── rrt_demo.gif └── rrt_star_demo.gif └── matplotlibcpp.h /.gitignore: -------------------------------------------------------------------------------- 1 | /.vscode 2 | /build -------------------------------------------------------------------------------- /A_Star/Astar.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file Astar.cpp 3 | * @author czj 4 | * @brief A* grid path planner implementation 5 | * @version 0.1 6 | * @date 2023-04-16 7 | * 8 | * @copyright Copyright (c) 2023 9 | * 10 | */ 11 | 12 | #include "Astar.h" 13 | 14 | int gifindex = 0; 15 | Astar::Node::Node(double x, double y, float cost, double parentIndex) : x(x), y(y), cost(cost), parent_index(parentIndex) {} 16 | 17 | Astar::Astar(double resolution, double robotRadius) : resolution(resolution), robot_radius(robotRadius) {} 18 | 19 | /** 20 | * @brief Append the map border and the interior walls to ox/oy. 21 | * 22 | */ 23 | void Astar::setObstale(vector &ox, vector &oy) 24 | { 25 | for (double i = -10; i < 60; i++) 26 | { 27 | ox.push_back(i); 28 | oy.push_back(-10.0); 29 | } 30 | for (double i = -10; i < 60; i++) 31 | { 32 | ox.push_back(60.0); 33 | oy.push_back(i); 34 | } 35 | for (double i = -10; i < 61; i++) 36 | { 37 | ox.push_back(i); 38 |
oy.push_back(60.0); 39 | } 40 | for (double i = -10; i < 61; i++) 41 | { 42 | ox.push_back(-10.0); 43 | oy.push_back(i); 44 | } 45 | for (double i = -10; i < 10; i++) 46 | { 47 | ox.push_back(i); 48 | oy.push_back(10); 49 | } 50 | for (double i = 0; i < 30; i++) 51 | { 52 | ox.push_back(40.0 - i); 53 | oy.push_back(30); 54 | } 55 | for (double i = 0; i < 15; i++) 56 | { 57 | ox.push_back(60.0 - i); 58 | oy.push_back(30); 59 | } 60 | } 61 | 62 | /** 63 | * @brief Build the obstacle grid: a cell is true when any obstacle point lies within robot_radius of the cell position, false otherwise. 64 | * @param ox obstacle x coordinates 65 | * @param oy obstacle y coordinates 66 | */ 67 | void Astar::calObstacleMap(const vector &ox, const vector &oy) 68 | { 69 | min_x = round(*min_element(ox.begin(), ox.end())); 70 | min_y = round(*min_element(oy.begin(), oy.end())); 71 | max_x = round(*max_element(ox.begin(), ox.end())); 72 | max_y = round(*max_element(oy.begin(), oy.end())); 73 | 74 | cout << "min_x:" << min_x << " min_y:" << min_y << " max_x:" << max_x << " max_y:" << max_y << endl; 75 | 76 | x_width = round((max_x - min_x) / resolution); 77 | y_width = round((max_y - min_y) / resolution); 78 | cout << "x_width:" << x_width << " y_width:" << y_width << endl; 79 | 80 | obstacle_map = vector>(x_width, vector(y_width, false)); 81 | 82 | for (double i = 0; i < x_width; i++) 83 | { 84 | double x = calPosition(i, min_x); 85 | for (double j = 0; j < y_width; j++) 86 | { 87 | double y = calPosition(j, min_y); 88 | for (double k = 0; k < ox.size(); k++) 89 | { 90 | double d = sqrt(pow(ox[k] - x, 2) + pow(oy[k] - y, 2)); 91 | if (d <= robot_radius) 92 | { 93 | obstacle_map[i][j] = true; 94 | break; 95 | } 96 | } 97 | } 98 | } 99 | } 100 | 101 | /** 102 | * @brief Convert a grid index to a world coordinate. 103 | * @param index grid index along one axis 104 | * @param minp map minimum along the same axis 105 | * @return index * resolution + minp 106 | */ 107 | double Astar::calPosition(double index, double minp) 108 | { 109 | double pos = index * resolution + minp; 110 | return pos; 111 | } 112 | 113 | /** 114 | * @brief Build the 8-connected motion model (4 straight steps of cost 1, 4 diagonal steps of cost sqrt(2)). 115 | * Fills the motion member; no return value. 116 | */ 117 | void Astar::getMotionModel() 118 | { 119 | // dx, dy, cost
120 | motion = {{1, 0, 1}, 121 | {0, 1, 1}, 122 | {-1, 0, 1}, 123 | {0, -1, 1}, 124 | {-1, -1, sqrt(2)}, 125 | {-1, 1, sqrt(2)}, 126 | {1, -1, sqrt(2)}, 127 | {1, 1, sqrt(2)}}; 128 | } 129 | 130 | /** 131 | * @brief Convert a world coordinate to a grid index (used for the start and goal). 132 | * @param position world coordinate 133 | * @param minp map minimum along the same axis 134 | * @return rounded grid index 135 | */ 136 | double Astar::calXyIndex(double position, double minp) 137 | { 138 | return round((position - minp) / resolution); 139 | } 140 | 141 | /** 142 | * @brief Flat key of a grid cell, used to index the open and closed sets. 143 | * @param node 144 | * @return 145 | */ 146 | double Astar::calIndex(Astar::Node *node) 147 | { 148 | // cout<x<<","<y<y - min_y) * x_width + (node->x - min_x); 150 | } 151 | 152 | /** 153 | * @brief Check that a node lies inside the map bounds and not on an obstacle cell. 154 | * @param node 155 | * @return true if the node can be used 156 | */ 157 | bool Astar::verifyNode(Astar::Node *node) 158 | { 159 | double px = calPosition(node->x, min_x); 160 | double py = calPosition(node->y, min_y); 161 | // outside the map 162 | if (px < min_x || py < min_y || px >= max_x || py >= max_y) 163 | return false; 164 | // on an obstacle 165 | if (obstacle_map[node->x][node->y]) 166 | return false; 167 | return true; 168 | } 169 | 170 | /** 171 | * @brief Follow parent_index links back from the goal to rebuild the path (goal first, start last). 172 | * @param goal_node 173 | * @param closed_set 174 | * @return path as (xs, ys) 175 | */ 176 | pair, vector> Astar::calFinalPath(Astar::Node *goal_node, map closed_set) 177 | { 178 | vector rx, ry; 179 | rx.push_back(calPosition(goal_node->x, min_x)); 180 | ry.push_back(calPosition(goal_node->y, min_y)); 181 | 182 | double parent_index = goal_node->parent_index; 183 | 184 | while (parent_index != -1) 185 | { 186 | Node *node = closed_set[parent_index]; 187 | rx.push_back(calPosition(node->x, min_x)); 188 | ry.push_back(calPosition(node->y, min_y)); 189 | parent_index = node->parent_index; 190 | } 191 | return {rx, ry}; 192 | } 193 | 194 | /** 195 | * @brief Run A* search from start to goal. 196 | * @param start start position (x, y) 197 | * @param goal goal position (x, y) 198 | * @return planned path as (xs, ys), goal first; empty if the open set runs out before the goal is reached 199 | */ 200 | pair, vector> Astar::planning(const vector start, const vector goal) 201 | { 202 | double sx = start[0], sy =
start[1]; 203 | double gx = goal[0], gy = goal[1]; 204 | Node *start_node = new Node(calXyIndex(sx, min_x), calXyIndex(sy, min_y), 0.0, -1); 205 | Node *goal_node = new Node(calXyIndex(gx, min_x), calXyIndex(gy, min_y), 0.0, -1); 206 | 207 | map open_set, closed_set; 208 | // seed the open set with the start node 209 | open_set[calIndex(start_node)] = start_node; 210 | 211 | Node *current = nullptr; 212 | while (true) 213 | { 214 | double cur_id = numeric_limits::max(); 215 | double cost = numeric_limits::max(); 216 | // pick the open node with the smallest f = g + h; this heuristic term is the only difference from the Dijkstra planner 217 | for (auto it = open_set.begin(); it != open_set.end(); it++) 218 | { 219 | double now_cost = it->second->cost + calHeuristic(goal_node, it->second); 220 | if (now_cost < cost) 221 | { 222 | cost = now_cost; 223 | cur_id = it->first; 224 | } 225 | } 226 | if (open_set.empty()) { cout << "Open set is empty, no path found" << endl; return {}; } // goal unreachable; open_set[cur_id] would otherwise insert and then dereference a null node 227 | current = open_set[cur_id]; 228 | plotGraph(current); // plot the expanded node 229 | 230 | // stop once the goal cell is reached 231 | if (abs(current->x - goal_node->x) < EPS && abs(current->y - goal_node->y) < EPS) 232 | { 233 | cout << "Find goal" << endl; 234 | goal_node->parent_index = current->parent_index; 235 | goal_node->cost = current->cost; 236 | break; 237 | } 238 | 239 | // remove the node from the open set ... 240 | auto iter = open_set.find(cur_id); 241 | open_set.erase(iter); 242 | // ... and add it to the closed set 243 | closed_set[cur_id] = current; 244 | 245 | // expand the neighbours given by the motion model 246 | for (vector move : motion) 247 | { 248 | // cout<x + move[0], current->y + move[1], current->cost + move[2], cur_id); 251 | // node->cost holds the path cost g only; h is added exactly once, when the next node is selected above 252 | double n_id = calIndex(node); 253 | // already expanded 254 | if (closed_set.find(n_id) != closed_set.end()) 255 | continue; 256 | // out of bounds or in collision 257 | if (!verifyNode(node)) 258 | continue; 259 | // first time this cell is reached 260 | if (open_set.find(n_id) == open_set.end()) 261 | { 262 | open_set[n_id] = node; 263 | } 264 | // already in the open set: keep whichever path is cheaper 265 | else 266 | { 267 | if (open_set[n_id]->cost >= node->cost) 268 | { 269 | open_set[n_id] = node; 270 | } 271 | } 272 |
} 273 | } 274 | return calFinalPath(goal_node, closed_set); 275 | } 276 | /** 277 | * @brief Euclidean-distance heuristic between two grid cells; this term is what distinguishes A* from Dijkstra. 278 | * 279 | * @param n1 280 | * @param n2 281 | * @return double weighted distance, in grid cells 282 | */ 283 | double Astar::calHeuristic(Node *n1, Node *n2) 284 | { 285 | double w = 1.0; // heuristic weight 286 | double d = w * sqrt(pow(n1->x - n2->x, 2) + pow(n1->y - n2->y, 2)); // dx and dy are both taken between n1 and n2 287 | return d; 288 | } 289 | 290 | /** 291 | * @brief Plot the node being expanded. 292 | * @param current 293 | */ 294 | void Astar::plotGraph(Astar::Node *current) 295 | { 296 | plt::plot(vector{calPosition(current->x, min_x)}, vector{calPosition(current->y, min_y)}, "xc"); 297 | // // save each frame as a separate file 298 | // stringstream filename; 299 | // filename << "./frame_" << gifindex++ << ".png"; 300 | // plt::save(filename.str()); 301 | plt::pause(0.0000001); 302 | } 303 | 304 | /** 305 | * @brief Store the start, goal and obstacle coordinates. 306 | * 307 | * @param st start 308 | * @param go goal 309 | * @param ox obstacle x coordinates 310 | * @param oy obstacle y coordinates 311 | */ 312 | void Astar::set(const vector &st, const vector &go, const vector &ox, const vector &oy) 313 | { 314 | Astar::st = st; 315 | Astar::go = go; 316 | Astar::ox = ox; 317 | Astar::oy = oy; 318 | } 319 | -------------------------------------------------------------------------------- /A_Star/Astar.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file Astar.h 3 | * @author czj 4 | * @brief A* grid path planner declaration 5 | * @version 0.1 6 | * @date 2023-04-10 7 | * 8 | * @copyright Copyright (c) 2023 9 | * 10 | */ 11 | 12 | #ifndef Astar_H 13 | #define Astar_H 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include "../matplotlibcpp.h" 24 | 25 | using namespace std; 26 | using namespace Eigen; 27 | namespace plt = matplotlibcpp; 28 | #define EPS 1e-4 29 | #define PI 3.14159265354 30 | 31 | //extern int gifindex; 32 | class Astar 33 | { 34 | public: 35 | struct Node 36 | { 37 | double x; // grid index along x 38 | double y; // grid index along y 39 | float cost; // path cost from the start 40 | // Node* p_node; 41 | double
parent_index; // calIndex key of the parent node, -1 for the start node 42 | 43 | Node(double x, double y, float cost, double parentIndex); 44 | }; 45 | 46 | private: 47 | double resolution; // grid cell size 48 | double robot_radius; 49 | double min_x, min_y, max_x, max_y; // map bounds 50 | double x_width, y_width; // grid size in cells 51 | vector> obstacle_map; // obstacle grid 52 | vector> motion; // motion model rows: {dx, dy, cost} 53 | vector st, go; 54 | vector ox, oy; 55 | 56 | public: 57 | Astar(double resolution, double robotRadius); 58 | void setObstale(vector &ox, vector &oy); 59 | void calObstacleMap(const vector &ox, const vector &oy); 60 | 61 | double calPosition(double index, double minp); 62 | 63 | void getMotionModel(); 64 | 65 | double calXyIndex(double position, double minp); 66 | 67 | double calIndex(Node *node); 68 | 69 | bool verifyNode(Node *node); 70 | 71 | pair, vector> calFinalPath(Node *goal_node, map closed_set); 72 | 73 | pair, vector> planning(vector start, vector goal); 74 | double calHeuristic(Node *n1, Node *n2); 75 | 76 | void plotGraph(Node *current); 77 | 78 | void set(const vector &st, const vector &go, const vector &ox, const vector &oy); 79 | }; 80 | 81 | #endif // Astar_H 82 | -------------------------------------------------------------------------------- /A_Star/main.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file main.cpp 3 | * @author czj 4 | * @brief Entry point: builds the demo map, runs A* and plots the result 5 | * @version 0.1 6 | * @date 2023-04-10 7 | * 8 | * @copyright Copyright (c) 2023 9 | * 10 | */ 11 | 12 | #include "Astar.h" 13 | 14 | int main() 15 | { 16 | vector start{-5, -5}, goal{50, 50}; 17 | double grid_size = 2.0; 18 | double robot_radius = 1.0; 19 | 20 | vector ox; 21 | vector oy; 22 | 23 | Astar astar(grid_size, robot_radius); 24 | astar.setObstale(ox, oy); // build the obstacle coordinates 25 | astar.set(start, goal, ox, oy); // store start, goal and obstacle (x, y) 26 | astar.calObstacleMap(ox, oy); // rasterize the obstacles into the grid 27 | astar.getMotionModel(); // build the motion model (step costs) 28 | 29 | // draw the map 30 | plt::plot(ox, oy, ".k"); 31 | plt::plot(vector{start[0]}, vector{start[1]}, "ob"); 32 |
plt::plot(vector{goal[0]}, vector{goal[1]}, "or"); 33 | plt::grid(true); 34 | 35 | // plan the path 36 | pair, vector> xy = astar.planning(start, goal); 37 | // draw the path 38 | plt::plot(xy.first, xy.second, "-r"); 39 | 40 | // stringstream filename; 41 | // filename << "./frame_" << gifindex << ".png"; 42 | // plt::save(filename.str()); 43 | 44 | // // assemble the frames into a GIF 45 | // const char *gif_filename = "./astar_demo.gif"; 46 | // stringstream cmd; 47 | // cmd << "convert -delay 2 -loop 0 ./frame_*.png " << gif_filename; 48 | // system(cmd.str().c_str()); 49 | // cout << "Saving result to " << gif_filename << std::endl; 50 | // plt::show(); 51 | // // delete the png frames 52 | // system("rm *.png"); 53 | // return 0; 54 | 55 | // save the figure 56 | const char *filename = "./astar_demo.png"; 57 | cout << "Saving result to " << filename << std::endl; 58 | plt::save(filename); 59 | plt::show(); 60 | 61 | return 0; 62 | } 63 | -------------------------------------------------------------------------------- /B-spline/BSpline.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file BSpline.cpp 3 | * @author czj 4 | * @brief B-spline basis function and knot-vector helpers 5 | * @version 0.1 6 | * @date 2023-04-22 7 | * 8 | * @copyright Copyright (c) 2023 9 | * 10 | */ 11 | 12 | #include "BSpline.h" 13 | 14 | /** 15 | * @brief B-spline basis function B_{i,k}(u), evaluated with the Cox-de Boor recursion. 16 | * @param i basis index 17 | * @param k B-spline order k (degree k-1) 18 | * @param u parameter value 19 | * @param node_vector knot vector [u_0, u_1, ..., u_{n+k}] (n+k+1 knots for n+1 control points). 20 | */ 21 | double basic(int i, int k, double u, vector node_vector) 22 | { 23 | // order 1 (degree 0): indicator of the half-open span [u_i, u_{i+1}) 24 | double Bik_u; 25 | if (k == 1) 26 | { 27 | if (u >= node_vector[i] && u < node_vector[i + 1]) 28 | { 29 | Bik_u = 1; 30 | } 31 | else 32 | { 33 | Bik_u = 0; 34 | } 35 | } 36 | else 37 | { 38 | // the two denominators of the recursion 39 | double denominator_1 = node_vector[i + k - 1] - node_vector[i]; 40 | double denominator_2 = node_vector[i + k] - node_vector[i + 1]; 41 | // # when a denominator is 0: 42 | // # 1. if the numerator is also 0, the whole term is taken as 0; 43 | // # 2.
if the numerator is non-zero, the denominator is taken as 1. 44 | if (denominator_1 == 0) 45 | denominator_1 = 1; 46 | if (denominator_2 == 0) 47 | denominator_2 = 1; 48 | Bik_u = (u - node_vector[i]) / denominator_1 * basic(i, k - 1, u, node_vector) + (node_vector[i + k] - u) / denominator_2 * 49 | basic(i + 1, k - 1, u, node_vector); 50 | } 51 | return Bik_u; 52 | } 53 | 54 | /** 55 | * @brief Knot vector of a quasi-uniform (clamped) B-spline. 56 | * The first k knots are 0, the last k knots are 1, and the interior knots are evenly spaced. 57 | * @param n number of control points 58 | * @param k B-spline order k (curve degree k-1). 59 | * @return vector of n + k knots 60 | */ 61 | vector QuasiUniform(int n, int k) 62 | { 63 | int numKnots = n + k; // number of knots 64 | vector knots(numKnots); 65 | 66 | for (int i = 0; i < numKnots; i++) 67 | { 68 | if (i < k) 69 | knots[i] = 0.0f; // first k knots are 0 70 | else if (i >= n) 71 | knots[i] = 1.0f; // last k knots are 1 72 | else 73 | { 74 | // interior knots advance by a constant step 75 | knots[i] = knots[i - 1] + (double)1 / (n - k + 1); 76 | } 77 | } 78 | 79 | return knots; 80 | } 81 | 82 | /** 83 | * @brief Knot vector of a uniform B-spline: n + k knots evenly spaced over [0, 1]. 84 | * 85 | * @param n number of control points 86 | * @param k B-spline order k (curve degree k-1) 87 | * @return vector 88 | */ 89 | vector Uniform(int n, int k) 90 | { 91 | int numKnots = n + k; 92 | vector knots(numKnots); 93 | 94 | for (int i = 0; i < numKnots; i++) 95 | { 96 | knots[i] = (double)i / (double)(numKnots - 1); 97 | } 98 | 99 | return knots; 100 | } -------------------------------------------------------------------------------- /B-spline/BSpline.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file BSpline.h 3 | * @author czj 4 | * @brief Declarations of the B-spline basis and knot-vector helpers 5 | * @version 0.1 6 | * @date 2023-04-23 7 | * 8 | * @copyright Copyright (c) 2023 9 | * 10 | */ 11 | 12 | #ifndef BSPLINE_H 13 | #define BSPLINE_H 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | using namespace std; 20 | using namespace Eigen; 21 | 22 | double basic(int i, int k, double u, vector knots); 23 | 24 | vector QuasiUniform(int n, int k); 25 | 26 | vector Uniform(int n, int k); 27 | 28 | #endif // BSPLINE_H 29 |
-------------------------------------------------------------------------------- /B-spline/main.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file main.cpp 3 | * @author czj 4 | * @brief B-spline demo: animates a uniform or quasi-uniform B-spline defined by the control points below 5 | * @version 0.1 6 | * @date 2023-04-23 7 | * 8 | * @copyright Copyright (c) 2023 9 | * 10 | */ 11 | 12 | #include "BSpline.h" 13 | #include "../matplotlibcpp.h" 14 | namespace plt = matplotlibcpp; 15 | 16 | int main() 17 | { 18 | // control points 19 | vector points{Vector2d(0, 0), Vector2d(1, 1), Vector2d(2, 1), Vector2d(3, 1), Vector2d(4, 2)}; 20 | 21 | // control-polygon coordinates for plotting 22 | vector x(points.size()), y(points.size()); 23 | for (int i = 0; i < points.size(); ++i) 24 | { 25 | x[i] = points[i][0]; 26 | y[i] = points[i][1]; 27 | } 28 | 29 | // sampled curve points 30 | vector x_, y_; 31 | int n = points.size(); // number of control points 32 | int k = 3; // order k, i.e. a degree k-1 B-spline 33 | Vector2d p_u(0.0, 0.0); 34 | vector b(n); 35 | 36 | int flag; 37 | cout << "请选择:1. 均匀B样条 2.准均匀B样条 0. 退出 " << endl; 38 | cin >> flag; 39 | vector knots; 40 | switch (flag) 41 | { 42 | case 1: // uniform B-spline 43 | knots = Uniform(n, k); 44 | for (double u = (double)(k - 1) / (n + k - 1); u < (double)n / (n + k - 1); u += 0.005) // valid domain [t_{k-1}, t_n) of the uniform knots t_i = i / (n + k - 1), where the basis functions sum to 1 45 | { 46 | plt::clf(); 47 | for (int i = 0; i < n; i++) 48 | { 49 | b[i] = basic(i, k, u, knots); 50 | // cout << b[i] << endl; 51 | } 52 | for (int i = 0; i < points.size(); i++) 53 | { 54 | p_u += points[i] * b[i]; 55 | } 56 | x_.push_back(p_u[0]); 57 | y_.push_back(p_u[1]); 58 | // plot 59 | // plt::xlim(0,1); 60 | plt::plot(x_, y_, "r"); 61 | plt::plot(x, y); 62 | plt::pause(0.01); 63 | p_u = Vector2d(0, 0); 64 | } 65 | break; 66 | case 2: // quasi-uniform (clamped) B-spline 67 | knots = QuasiUniform(n, k); 68 | for (double u = 0; u < 1; u += 0.005) 69 | { 70 | plt::clf(); 71 | for (int i = 0; i < n; i++) 72 | { 73 | b[i] = basic(i, k, u, knots); 74 | // cout << b[i] << endl; 75 | } 76 | for (int i = 0; i < points.size(); i++) 77 | { 78 | p_u += points[i] * b[i]; 79 | // cout<(n - i + 1) / static_cast(i); 25 | } 26 |
return result; 27 | } 28 | 29 | /** 30 | * Evaluate the degree-n Bezier curve (n = points.size() - 1) at t using the Bernstein form. 31 | * @param points control points 32 | * @param t curve parameter in [0, 1] 33 | * @return point on the curve at t 34 | */ 35 | Vector2d bezier(const vector &points, double t) 36 | { 37 | 38 | int n = points.size() - 1; 39 | Vector2d result = Vector2d::Zero(); // Eigen leaves fixed-size vectors uninitialized, so the accumulator must start at zero 40 | for (int i = 0; i <= n; ++i) 41 | { 42 | result += combination(n, i) * pow(t, i) * pow(1 - t, n - i) * points[i]; 43 | } 44 | return result; 45 | } -------------------------------------------------------------------------------- /Bezier/BezierCurve.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file BezierCurve.h 3 | * @author czj 4 | * @brief Bezier curve evaluation helpers 5 | * @version 0.1 6 | * @date 2023-04-21 7 | * 8 | * @copyright Copyright (c) 2023 9 | * 10 | */ 11 | 12 | #ifndef BEZIERCURVE_H 13 | #define BEZIERCURVE_H 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | using namespace std; 21 | using namespace Eigen; 22 | 23 | double combination(int n, int k); 24 | 25 | Vector2d bezier(const vector& points, double t); 26 | 27 | 28 | #endif // BEZIERCURVE_H 29 | -------------------------------------------------------------------------------- /Bezier/main.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file main.cpp 3 | * @author czj 4 | * @brief Bezier demo: animates the curve defined by the control points below 5 | * @version 0.1 6 | * @date 2023-04-21 7 | * 8 | * @copyright Copyright (c) 2023 9 | * 10 | */ 11 | 12 | #include "BezierCurve.h" 13 | #include "../matplotlibcpp.h" 14 | namespace plt = matplotlibcpp; 15 | 16 | int main() 17 | { 18 | // control points 19 | vector points{Vector2d(0, 0), Vector2d(1, 1), Vector2d(2.5, 1.5), Vector2d(3, 1), Vector2d(4, 2)}; 20 | 21 | // control-polygon coordinates for plotting 22 | cout << "Plotting..."
<< endl; 23 | vector x(points.size()), y(points.size()); 24 | for (int i = 0; i < points.size(); ++i) 25 | { 26 | x[i] = points[i][0]; 27 | y[i] = points[i][1]; 28 | } 29 | 30 | // sample the Bezier curve 31 | const int n_points = 100; 32 | vector curve(n_points); 33 | vector x_curve, y_curve; 34 | for (int i = 0; i < n_points; ++i) 35 | { 36 | plt::clf(); 37 | double t = static_cast(i) / (n_points - 1); // t covers [0, 1], so the last sample lands on the final control point 38 | curve[i] = bezier(points, t); 39 | x_curve.emplace_back(curve[i][0]); 40 | y_curve.emplace_back(curve[i][1]); 41 | // draw the control polygon and the curve so far 42 | plt::plot(x_curve, y_curve, "r"); 43 | plt::plot(x, y); 44 | // // save each frame as a separate file 45 | // stringstream filename; 46 | // filename << "./frame_" << i << ".png"; 47 | // plt::save(filename.str()); 48 | plt::pause(0.01); 49 | } 50 | // // assemble the frames into a GIF 51 | // const char *gif_filename = "./bezier_demo.gif"; 52 | // stringstream cmd; 53 | // cmd << "convert -delay 10 -loop 0 ./frame_*.png " << gif_filename; 54 | // system(cmd.str().c_str()); 55 | // cout << "Saving result to " << gif_filename << std::endl; 56 | // plt::show(); 57 | // // delete the png frames 58 | // system("rm *.png"); 59 | // return 0; 60 | 61 | 62 | // save figure 63 | const char* filename = "./bezier_demo.png"; 64 | cout << "Saving result to " << filename << std::endl; 65 | plt::save(filename); 66 | plt::show(); 67 | return 0; 68 | } -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.21) 2 | project(path_planning) 3 | 4 | set(CMAKE_CXX_STANDARD 14) 5 | 6 | 7 | # include(GNUInstallDirs) 8 | set(PACKAGE_NAME path_planning) 9 | 10 | # output 11 | set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) 12 | 13 | 14 | # Library target 15 | add_library(path_planning INTERFACE) 16 | 17 | # TODO: Use `Development.Embed` component when requiring cmake >= 3.18 18 | find_package(Python COMPONENTS Interpreter Development NumPy REQUIRED) 19 |
target_link_libraries(path_planning INTERFACE 20 | Python::Python 21 | Python::Module 22 | Python::NumPy 23 | ) 24 | install( 25 | TARGETS path_planning 26 | EXPORT install_targets 27 | ) 28 | 29 | 30 | find_package(Eigen3 REQUIRED) 31 | include_directories(${EIGEN3_INCLUDE_DIR}) 32 | 33 | 34 | # bezier 35 | add_executable(bezier_demo Bezier/main.cpp Bezier/BezierCurve.cpp) 36 | target_link_libraries(bezier_demo PRIVATE path_planning) 37 | 38 | # B-spline 39 | add_executable(b_spline_demo B-spline/main.cpp B-spline/BSpline.cpp) 40 | target_link_libraries(b_spline_demo PRIVATE path_planning) 41 | 42 | 43 | # RRT 44 | add_executable(rrt_demo Rapidly-exploring_Random_Tree/main.cpp Rapidly-exploring_Random_Tree/RRT.cpp) 45 | target_link_libraries(rrt_demo PRIVATE path_planning) 46 | 47 | 48 | # RRT_connect 49 | add_executable(rrt_connect_demo Rapidly-exploring_Random_Tree_connect/main.cpp Rapidly-exploring_Random_Tree_connect/RRT_connect.cpp) 50 | target_link_libraries(rrt_connect_demo path_planning) 51 | 52 | # RRT_star 53 | add_executable(rrt_star_demo Rapidly-exploring_Random_Tree_Star/main.cpp Rapidly-exploring_Random_Tree_Star/RRT_Star.cpp Rapidly-exploring_Random_Tree/RRT.cpp) 54 | target_link_libraries(rrt_star_demo path_planning) 55 | 56 | # Dijkstra 57 | add_executable(dijkstra_demo Dijkstra/main.cpp Dijkstra/Dijkstra.cpp) 58 | target_link_libraries(dijkstra_demo path_planning) 59 | 60 | # astar 61 | add_executable(astar_demo A_Star/Astar.cpp A_Star/main.cpp) 62 | target_link_libraries(astar_demo path_planning) -------------------------------------------------------------------------------- /Dijkstra/Dijkstra.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file Dijkstra.cpp 3 | * @author czj 4 | * @brief Dijkstra grid path planner implementation 5 | * @version 0.1 6 | * @date 2023-04-16 7 | * 8 | * @copyright Copyright (c) 2023 9 | * 10 | */ 11 | 12 | #include "Dijkstra.h" 13 | 14 | //int gifindex = 0; 15 | 16 | Dijkstra::Node::Node(double x, double
y, float cost, double parentIndex) : x(x), y(y), cost(cost), parent_index(parentIndex) {} 17 | 18 | Dijkstra::Dijkstra(double resolution, double robotRadius) : resolution(resolution), robot_radius(robotRadius) {} 19 | 20 | /** 21 | * @brief Append the map border and the interior walls to ox/oy. 22 | * 23 | */ 24 | void Dijkstra::setObstale(vector &ox, vector &oy) 25 | { 26 | for (double i = -10; i < 60; i++) 27 | { 28 | ox.push_back(i); 29 | oy.push_back(-10.0); 30 | } 31 | for (double i = -10; i < 60; i++) 32 | { 33 | ox.push_back(60.0); 34 | oy.push_back(i); 35 | } 36 | for (double i = -10; i < 61; i++) 37 | { 38 | ox.push_back(i); 39 | oy.push_back(60.0); 40 | } 41 | for (double i = -10; i < 61; i++) 42 | { 43 | ox.push_back(-10.0); 44 | oy.push_back(i); 45 | } 46 | for (double i = -10; i < 10; i++) 47 | { 48 | ox.push_back(i); 49 | oy.push_back(10); 50 | } 51 | for (double i = 0; i < 30; i++) 52 | { 53 | ox.push_back(40.0 - i); 54 | oy.push_back(30); 55 | } 56 | for (double i = 0; i < 15; i++) 57 | { 58 | ox.push_back(60.0 - i); 59 | oy.push_back(30); 60 | } 61 | } 62 | 63 | /** 64 | * @brief Build the obstacle grid: a cell is true when any obstacle point lies within robot_radius of the cell position, false otherwise. 65 | * @param ox obstacle x coordinates 66 | * @param oy obstacle y coordinates 67 | */ 68 | void Dijkstra::calObstacleMap(const vector &ox, const vector &oy) 69 | { 70 | min_x = round(*min_element(ox.begin(), ox.end())); 71 | min_y = round(*min_element(oy.begin(), oy.end())); 72 | max_x = round(*max_element(ox.begin(), ox.end())); 73 | max_y = round(*max_element(oy.begin(), oy.end())); 74 | 75 | cout << "min_x:" << min_x << " min_y:" << min_y << " max_x:" << max_x << " max_y:" << max_y << endl; 76 | 77 | x_width = round((max_x - min_x) / resolution); 78 | y_width = round((max_y - min_y) / resolution); 79 | cout << "x_width:" << x_width << " y_width:" << y_width << endl; 80 | 81 | obstacle_map = vector>(x_width, vector(y_width, false)); 82 | 83 | for (double i = 0; i < x_width; i++) 84 | { 85 | double x = calPosition(i, min_x); 86 | for (double j = 0; j < y_width; j++) 87 | { 88 | double y = calPosition(j,
min_y); 89 | for (double k = 0; k < ox.size(); k++) 90 | { 91 | double d = sqrt(pow(ox[k] - x, 2) + pow(oy[k] - y, 2)); 92 | if (d <= robot_radius) 93 | { 94 | obstacle_map[i][j] = true; 95 | break; 96 | } 97 | } 98 | } 99 | } 100 | } 101 | 102 | /** 103 | * @brief Convert a grid index to a world coordinate. 104 | * @param index grid index along one axis 105 | * @param minp map minimum along the same axis 106 | * @return index * resolution + minp 107 | */ 108 | double Dijkstra::calPosition(double index, double minp) 109 | { 110 | double pos = index * resolution + minp; 111 | return pos; 112 | } 113 | 114 | /** 115 | * @brief Build the 8-connected motion model (4 straight steps of cost 1, 4 diagonal steps of cost sqrt(2)). 116 | * Fills the motion member; no return value. 117 | */ 118 | void Dijkstra::getMotionModel() 119 | { 120 | // dx, dy, cost 121 | motion = {{1, 0, 1}, 122 | {0, 1, 1}, 123 | {-1, 0, 1}, 124 | {0, -1, 1}, 125 | {-1, -1, sqrt(2)}, 126 | {-1, 1, sqrt(2)}, 127 | {1, -1, sqrt(2)}, 128 | {1, 1, sqrt(2)}}; 129 | } 130 | 131 | /** 132 | * @brief Convert a world coordinate to a grid index (used for the start and goal). 133 | * @param position world coordinate 134 | * @param minp map minimum along the same axis 135 | * @return rounded grid index 136 | */ 137 | double Dijkstra::calXyIndex(double position, double minp) 138 | { 139 | return round((position - minp) / resolution); 140 | } 141 | 142 | /** 143 | * @brief Flat key of a grid cell, used to index the open and closed sets. 144 | * @param node 145 | * @return 146 | */ 147 | double Dijkstra::calIndex(Dijkstra::Node *node) 148 | { 149 | // cout<x<<","<y<y - min_y) * x_width + (node->x - min_x); 151 | } 152 | 153 | /** 154 | * @brief Check that a node lies inside the map bounds and not on an obstacle cell. 155 | * @param node 156 | * @return true if the node can be used 157 | */ 158 | bool Dijkstra::verifyNode(Dijkstra::Node *node) 159 | { 160 | double px = calPosition(node->x, min_x); 161 | double py = calPosition(node->y, min_y); 162 | // outside the map 163 | if (px < min_x || py < min_y || px >= max_x || py >= max_y) 164 | return false; 165 | // on an obstacle 166 | if (obstacle_map[node->x][node->y]) 167 | return false; 168 | return true; 169 | } 170 | 171 | /** 172 | * @brief Follow parent_index links back from the goal to rebuild the path (goal first, start last). 173 | * @param goal_node 174 | * @param closed_set 175 | * @return path as (xs, ys) 176 | */ 177 | pair, vector> Dijkstra::calFinalPath(Dijkstra::Node *goal_node, map closed_set) 178 | { 179 | vector rx, ry; 180 |
rx.push_back(calPosition(goal_node->x, min_x)); 181 | ry.push_back(calPosition(goal_node->y, min_y)); 182 | 183 | double parent_index = goal_node->parent_index; 184 | 185 | while (parent_index != -1) 186 | { 187 | Node *node = closed_set[parent_index]; 188 | rx.push_back(calPosition(node->x, min_x)); 189 | ry.push_back(calPosition(node->y, min_y)); 190 | parent_index = node->parent_index; 191 | } 192 | return {rx, ry}; 193 | } 194 | 195 | /** 196 | * @brief Run Dijkstra search from start to goal. 197 | * @param start start position (x, y) 198 | * @param goal goal position (x, y) 199 | * @return planned path as (xs, ys), goal first; empty if the open set runs out before the goal is reached 200 | */ 201 | pair, vector> Dijkstra::planning(const vector start, const vector goal) 202 | { 203 | double sx = start[0], sy = start[1]; 204 | double gx = goal[0], gy = goal[1]; 205 | Node *start_node = new Node(calXyIndex(sx, min_x), calXyIndex(sy, min_y), 0.0, -1); 206 | Node *goal_node = new Node(calXyIndex(gx, min_x), calXyIndex(gy, min_y), 0.0, -1); 207 | 208 | map open_set, closed_set; 209 | // seed the open set with the start node 210 | open_set[calIndex(start_node)] = start_node; 211 | 212 | Node *current = nullptr; 213 | while (true) 214 | { 215 | double cur_id = numeric_limits::max(); 216 | double cost = numeric_limits::max(); 217 | // pick the open node with the smallest cost (a priority queue would make this faster) 218 | for (auto it = open_set.begin(); it != open_set.end(); it++) 219 | { 220 | if (it->second->cost < cost) 221 | { 222 | cost = it->second->cost; 223 | cur_id = it->first; // index 224 | } 225 | } 226 | if (open_set.empty()) { cout << "Open set is empty, no path found" << endl; return {}; } // goal unreachable; open_set[cur_id] would otherwise insert and then dereference a null node 227 | current = open_set[cur_id]; 228 | plotGraph(current); // plot the expanded node 229 | 230 | // stop once the goal cell is reached 231 | if (abs(current->x - goal_node->x) < EPS && abs(current->y - goal_node->y) < EPS) 232 | { 233 | cout << "Find goal" << endl; 234 | goal_node->parent_index = current->parent_index; 235 | goal_node->cost = current->cost; 236 | break; 237 | } 238 | 239 | // remove the node from the open set ... 240 | auto iter = open_set.find(cur_id); 241 | open_set.erase(iter); 242 | // ... and add it to the closed set 243 | closed_set[cur_id] = current; 244 | 245 | // expand the neighbours given by the motion model 246 | for (vector move : motion) 247 | { 248 | // cout<x +
move[0], current->y + move[1], current->cost + move[2], cur_id); 251 | double n_id = calIndex(node); 252 | // already expanded 253 | if (closed_set.find(n_id) != closed_set.end()) 254 | continue; 255 | // out of bounds or in collision 256 | if (!verifyNode(node)) 257 | continue; 258 | // first time this cell is reached 259 | if (open_set.find(n_id) == open_set.end()) 260 | { 261 | open_set[n_id] = node; 262 | } 263 | // already in the open set: keep whichever path is cheaper 264 | else 265 | { 266 | if (open_set[n_id]->cost >= node->cost) 267 | { 268 | open_set[n_id] = node; 269 | } 270 | } 271 | } 272 | } 273 | return calFinalPath(goal_node, closed_set); 274 | } 275 | 276 | /** 277 | * @brief Plot the node being expanded. 278 | * @param current 279 | */ 280 | void Dijkstra::plotGraph(Dijkstra::Node *current) 281 | { 282 | plt::plot(vector{calPosition(current->x, min_x)}, vector{calPosition(current->y, min_y)}, "xc"); 283 | // // save each frame as a separate file 284 | // stringstream filename; 285 | // filename << "./frame_" << gifindex++ << ".png"; 286 | // plt::save(filename.str()); 287 | plt::pause(0.0000001); 288 | } 289 | 290 | /** 291 | * @brief Store the start, goal and obstacle coordinates. 292 | * 293 | * @param st start 294 | * @param go goal 295 | * @param ox obstacle x coordinates 296 | * @param oy obstacle y coordinates 297 | */ 298 | void Dijkstra::set(const vector &st, const vector &go, const vector &ox, const vector &oy) 299 | { 300 | Dijkstra::st = st; 301 | Dijkstra::go = go; 302 | Dijkstra::ox = ox; 303 | Dijkstra::oy = oy; 304 | } 305 | -------------------------------------------------------------------------------- /Dijkstra/Dijkstra.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file Dijkstra.h 3 | * @author czj 4 | * @brief Dijkstra grid path planner declaration 5 | * @version 0.1 6 | * @date 2023-04-10 7 | * 8 | * @copyright Copyright (c) 2023 9 | * 10 | */ 11 | 12 | #ifndef DIJKSTRA_H 13 | #define DIJKSTRA_H 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | #include "../matplotlibcpp.h" 24 | namespace plt = matplotlibcpp; 25 | using
namespace std; 26 | using namespace Eigen; 27 | 28 | #define EPS 1e-4 29 | #define PI 3.14159265354 30 | 31 | //extern int gifindex; 32 | 33 | class Dijkstra 34 | { 35 | public: 36 | struct Node 37 | { 38 | double x; // grid index along x 39 | double y; // grid index along y 40 | float cost; // path cost from the start 41 | // Node* p_node; 42 | double parent_index; // calIndex key of the parent node, -1 for the start node 43 | 44 | Node(double x, double y, float cost, double parentIndex); 45 | }; 46 | 47 | private: 48 | double resolution; // grid cell size 49 | double robot_radius; 50 | double min_x, min_y, max_x, max_y; // map bounds 51 | double x_width, y_width; // grid size in cells 52 | vector> obstacle_map; // obstacle grid 53 | vector> motion; // motion model rows: {dx, dy, cost} 54 | vector st, go; 55 | vector ox, oy; 56 | 57 | public: 58 | Dijkstra(double resolution, double robotRadius); 59 | void setObstale(vector &ox, vector &oy); 60 | void calObstacleMap(const vector &ox, const vector &oy); 61 | 62 | double calPosition(double index, double minp); 63 | 64 | void getMotionModel(); 65 | 66 | double calXyIndex(double position, double minp); 67 | 68 | double calIndex(Node *node); 69 | 70 | bool verifyNode(Node *node); 71 | 72 | pair, vector> calFinalPath(Node *goal_node, map closed_set); 73 | 74 | pair, vector> planning(vector start, vector goal); 75 | 76 | void plotGraph(Node *current); 77 | 78 | void set(const vector &st, const vector &go, const vector &ox, const vector &oy); 79 | }; 80 | 81 | #endif // DIJKSTRA_H 82 | -------------------------------------------------------------------------------- /Dijkstra/main.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file main.cpp 3 | * @author czj 4 | * @brief Entry point: builds the demo map, runs Dijkstra and plots the result 5 | * @version 0.1 6 | * @date 2023-04-10 7 | * 8 | * @copyright Copyright (c) 2023 9 | * 10 | */ 11 | 12 | #include "Dijkstra.h" 13 | 14 | int main() 15 | { 16 | vector start{-5, -5}, goal{50, 50}; 17 | double grid_size = 2.0; 18 | double robot_radius = 1.0; 19 | 20 | vector ox; 21 | vector oy; 22 | 23 | Dijkstra dijkstra(grid_size, robot_radius); 24 | dijkstra.setObstale(ox, oy); // build the obstacle coordinates
25 | dijkstra.set(start, goal, ox, oy); // store start, goal and obstacle (x, y) 26 | dijkstra.calObstacleMap(ox, oy); // rasterize the obstacles into the grid 27 | dijkstra.getMotionModel(); // build the motion model (step costs) 28 | 29 | // draw the map 30 | plt::plot(ox, oy, ".k"); 31 | plt::plot(vector{start[0]}, vector{start[1]}, "ob"); 32 | plt::plot(vector{goal[0]}, vector{goal[1]}, "or"); 33 | plt::grid(true); 34 | 35 | // plan the path 36 | pair, vector> xy = dijkstra.planning(start, goal); 37 | // draw the path 38 | plt::plot(xy.first, xy.second, "-r"); 39 | 40 | // // assemble the frames into a GIF 41 | // stringstream filename; 42 | // filename << "./frame_" << gifindex << ".png"; 43 | // plt::save(filename.str()); 44 | // const char *gif_filename = "./dijkstra_demo.gif"; 45 | // stringstream cmd; 46 | // cmd << "convert -delay 3 -loop 0 ./frame_*.png " << gif_filename; 47 | // system(cmd.str().c_str()); 48 | // cout << "Saving result to " << gif_filename << std::endl; 49 | // plt::show(); 50 | // // delete the png frames 51 | // system("rm *.png"); 52 | // return 0; 53 | 54 | // save the figure 55 | const char *filename = "./dijkstra_demo.png"; 56 | cout << "Saving result to " << filename << std::endl; 57 | plt::save(filename); 58 | plt::show(); 59 | 60 | return 0; 61 | } 62 | -------------------------------------------------------------------------------- /README.assets/dijkstra_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/czjaixuexi/path_planning/83627373069fe4d84dae0fde2de125dcde909a16/README.assets/dijkstra_demo.gif -------------------------------------------------------------------------------- /README.assets/rrt_connect-16832995765544.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/czjaixuexi/path_planning/83627373069fe4d84dae0fde2de125dcde909a16/README.assets/rrt_connect-16832995765544.gif -------------------------------------------------------------------------------- /README.assets/rrt_connect-16832995877726.gif:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/czjaixuexi/path_planning/83627373069fe4d84dae0fde2de125dcde909a16/README.assets/rrt_connect-16832995877726.gif -------------------------------------------------------------------------------- /README.assets/rrt_connect1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/czjaixuexi/path_planning/83627373069fe4d84dae0fde2de125dcde909a16/README.assets/rrt_connect1.gif -------------------------------------------------------------------------------- /README.assets/rrt_star_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/czjaixuexi/path_planning/83627373069fe4d84dae0fde2de125dcde909a16/README.assets/rrt_star_demo.gif -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 简介 2 | 3 | 自动驾驶常用路径规划算法C++实现 4 | 5 | 6 | 7 | # 项目依赖 8 | 9 | 推荐在Ubuntu 18.04/20.04 环境下运行 10 | 11 | - **cmake** 12 | 13 | 在Ubuntu中安装cmake: 14 | 15 | ``` 16 | sudo apt install cmake 17 | ``` 18 | 19 | - **Eigen** 20 | 21 | 在Ubuntu中安装Eigen: 22 | 23 | ``` 24 | sudo apt-get install libeigen3-dev 25 | ``` 26 | 27 | - **python3** 28 | 29 | 30 | 31 | # 编译 32 | 33 | 在当前目录下输入: 34 | 35 | ```shell 36 | mkdir build 37 | cd build 38 | cmake ../ 39 | make 40 | ``` 41 | 42 | 43 | 44 | # Path_planning 45 | 46 | ## Dijkstra 47 | 48 | 49 | 50 | ![dijkstra_demo](README.assets/dijkstra_demo.gif) 51 | 52 | ## A star 53 | 54 | ![astar](https://gitee.com/czjaixuexi/typora_pictures/raw/master/img/astar-168329807331712.gif) 55 | 56 | ## RRT 57 | 58 | 59 | 60 | ![rrt_demo](https://gitee.com/czjaixuexi/typora_pictures/raw/master/img/rrt_demo-168329807983514.gif) 61 | 62 | 63 | 64 | ## RRT connect 65 | 66 | 
![rrt_connect](README.assets/rrt_connect-16832995877726.gif) 67 | 68 | ## RRT star 69 | 70 | ![rrt_star_demo](README.assets/rrt_star_demo.gif) 71 | 72 | 73 | 74 | ## Bezier 75 | 76 | ![bezier_demo](https://gitee.com/czjaixuexi/typora_pictures/raw/master/img/bezier_demo.gif) 77 | 78 | 79 | 80 | ## B spline 81 | 82 | ![b_spline_demo](https://gitee.com/czjaixuexi/typora_pictures/raw/master/img/b_spline_demo.gif) 83 | 84 | 85 | 86 | 87 | 88 | # 参考 89 | 90 | [PythonRobotics](https://github.com/AtsushiSakai/PythonRobotics#pythonrobotics) 91 | 92 | [chhRobotics_CPP](https://github.com/CHH3213/chhRobotics_CPP) 93 | 94 | -------------------------------------------------------------------------------- /Rapidly-exploring_Random_Tree/RRT.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file RRT.cpp 3 | * @author czj 4 | * @brief 5 | * @version 0.1 6 | * @date 2023-04-23 7 | * 8 | * @copyright Copyright (c) 2023 9 | * 10 | */ 11 | 12 | #include "RRT.h" 13 | 14 | //int gifindex = 0; 15 | 16 | Node::Node(double x, double y) : x(x), y(y), parent(NULL), cost(0) {} 17 | 18 | /** 19 | * @brief Construct 20 | * 21 | * @param obstacleList 障碍物位置列表 [[x,y,size],...] 
22 | * @param randArea 采样区域 x,y ∈ [min,max]; 23 | * @param playArea 约束随机树的范围 [xmin,xmax,ymin,ymax] 24 | * @param robotRadius 机器人半径 25 | * @param expandDis 扩展的步长 26 | * @param goalSampleRate 采样目标点的概率,百分制.default: 5,即表示5%的概率直接采样目标点 27 | * @param maxIter 最大采样点数 28 | */ 29 | RRT::RRT(const vector> &obstacleList, 30 | const vector &randArea, const vector &playArea, double robotRadius, double expandDis, 31 | double goalSampleRate, int maxIter) : obstacle_list(obstacleList), rand_area(randArea), 32 | play_area(playArea), robot_radius(robotRadius), expand_dis(expandDis), 33 | goal_sample_rate(goalSampleRate), max_iter(maxIter) {} 34 | 35 | /** 36 | * @brief 计算两点间的距离和方位角 37 | * 38 | * @param from_node 39 | * @param to_node 40 | * @return vector 41 | */ 42 | vector RRT::calDistanceAngle(Node *from_node, Node *to_node) 43 | { 44 | double dx = to_node->x - from_node->x; 45 | double dy = to_node->y - from_node->y; 46 | double d = sqrt(pow(dx, 2) + pow(dy, 2)); 47 | double theta = atan2(dy, dx); 48 | return {d, theta}; 49 | } 50 | 51 | /** 52 | * @brief 判断是否有障碍物 53 | * 54 | * @param node 55 | * @return true 56 | * @return false 57 | */ 58 | bool RRT::obstacleFree(Node *node) 59 | { 60 | for (vector obs : obstacle_list) 61 | { 62 | for (int i = 0; i < node->path_x.size(); i++) 63 | { 64 | double x = node->path_x[i]; 65 | double y = node->path_y[i]; 66 | if (pow(obs[0] - x, 2) + pow(obs[1] - y, 2) <= pow(obs[2] + robot_radius, 2)) 67 | return false; // collision 68 | } 69 | } 70 | return true; 71 | } 72 | 73 | /** 74 | * @brief 判断是否在可行区域里面 75 | * @param node 76 | * @return 77 | */ 78 | bool RRT::isInsidePlayArea(Node *node) 79 | { 80 | if (node->x < play_area[0] || node->x > play_area[1] || node->y < play_area[2] || node->y > play_area[3]) 81 | return false; 82 | return true; 83 | } 84 | 85 | /** 86 | * @brief 计算最近的节点 87 | * @param node_list 节点列表 88 | * @param rnd_node 随机采样的节点 89 | * @return 最近的节点索引 90 | */ 91 | int RRT::getNearestNodeIndex(vector node_list, Node *rnd_node) 92 | { 
93 | int min_index = -1; 94 | double d = numeric_limits::max(); 95 | for (int i = 0; i < node_list.size(); i++) 96 | { 97 | Node *node = node_list[i]; 98 | double dist = pow(node->x - rnd_node->x, 2) + pow(node->y - rnd_node->y, 2); 99 | if (d > dist) 100 | { 101 | d = dist; 102 | min_index = i; 103 | } 104 | } 105 | return min_index; 106 | } 107 | 108 | /** 109 | * @brief 随机采样 以(100-goal_sample_rate%的概率随机生长,(goal_sample_rate)%的概率朝向目标点生长 110 | * @return 生成的节点 111 | */ 112 | Node *RRT::sampleFree() 113 | { 114 | Node *rnd = nullptr; 115 | if (rand() % (100) > goal_sample_rate) 116 | { 117 | double x_rand = rand() / double(RAND_MAX) * (rand_area[1] - rand_area[0]) + rand_area[0]; 118 | double y_rand = rand() / double(RAND_MAX) * (rand_area[1] - rand_area[0]) + rand_area[0]; 119 | rnd = new Node(x_rand, y_rand); 120 | } 121 | else 122 | { 123 | rnd = new Node(end->x, end->y); 124 | } 125 | return rnd; 126 | } 127 | 128 | void RRT::setBegin(Node *begin) 129 | { 130 | RRT::begin = begin; 131 | } 132 | 133 | void RRT::setEnd(Node *end) 134 | { 135 | RRT::end = end; 136 | } 137 | 138 | /** 139 | * @brief 计算(x,y)离目标点的距离 140 | * @param x 141 | * @param y 142 | * @return 143 | */ 144 | double RRT::calDistToGoal(double x, double y) 145 | { 146 | double dx = x - end->x; 147 | double dy = y - end->y; 148 | return sqrt(pow(dx, 2) + pow(dy, 2)); 149 | } 150 | 151 | /** 152 | * @brief 生成路径 153 | * @param goal_ind 154 | * @return 155 | */ 156 | pair, vector> RRT::generateFinalCourse(double goal_ind) 157 | { 158 | vector x_, y_; 159 | x_.push_back(end->x); 160 | y_.push_back(end->y); 161 | Node *node = node_list[goal_ind]; 162 | while (node->parent != nullptr) 163 | { 164 | x_.push_back(node->x); 165 | y_.push_back(node->y); 166 | node = node->parent; 167 | // cout<x<<","<y<x); 170 | y_.push_back(node->y); 171 | return {x_, y_}; 172 | } 173 | 174 | /** 175 | * @brief 连线方向扩展固定步长查找x_new 176 | * @param from_node x_near 177 | * @param to_node x_rand 178 | * @param extend_length 扩展步长u. 
Defaults to float("inf"). 179 | * @return 180 | */ 181 | Node *RRT::steer(Node *from_node, Node *to_node, double extend_length) 182 | { 183 | // 利用反正切计算角度, 然后利用角度和步长计算新坐标 184 | vector dist_angle = calDistanceAngle(from_node, to_node); 185 | 186 | double new_x, new_y; 187 | if (extend_length >= dist_angle[0]) 188 | { 189 | new_x = to_node->x; 190 | new_y = to_node->y; 191 | } 192 | else 193 | { 194 | new_x = from_node->x + extend_length * cos(dist_angle[1]); 195 | new_y = from_node->y + extend_length * sin(dist_angle[1]); 196 | } 197 | 198 | Node *new_node = new Node(new_x, new_y); 199 | new_node->path_x.push_back(from_node->x); 200 | new_node->path_y.push_back(from_node->y); 201 | new_node->path_x.push_back(new_node->x); 202 | new_node->path_y.push_back(new_node->y); 203 | 204 | new_node->parent = from_node; 205 | // cout<x<<","<y<, vector> RRT::planning() 215 | { 216 | node_list.push_back(begin); // 将起点作为根节点x_{init},加入到随机树的节点集合中。 217 | for (int i = 0; i < max_iter; i++) 218 | { 219 | // 从可行区域内随机选取一个节点x_{rand} 220 | Node *rnd_node = sampleFree(); 221 | 222 | // 已生成的树中利用欧氏距离判断距离x_{rand}最近的点x_{near}。 223 | int nearest_ind = getNearestNodeIndex(node_list, rnd_node); 224 | Node *nearest_node = node_list[nearest_ind]; 225 | 226 | // 从x_{near}与x_{rand}的连线方向上扩展固定步长u,得到新节点 x_{new} 227 | Node *new_node = steer(nearest_node, rnd_node, expand_dis); 228 | 229 | // 如果在可行区域内,且x_{near}与x_{new}之间无障碍物 230 | if (isInsidePlayArea(new_node) && obstacleFree(new_node)) 231 | { 232 | node_list.push_back(new_node); 233 | } 234 | // cout<x, node_list[node_list.size() - 1]->y) <= expand_dis) 237 | { 238 | Node *final_node = steer(node_list[node_list.size() - 1], end, expand_dis); 239 | if (obstacleFree(final_node)) 240 | { 241 | cout << "reaches the goal!" 
<< endl; 242 | // return {node_list[node_list.size()-1]->path_x,node_list[node_list.size()-1]->path_y}; 243 | return generateFinalCourse(node_list.size() - 1); 244 | } 245 | } 246 | // cout<x<<","<y< x_t, y_t; 262 | for (double i = 0.; i <= 2 * PI; i += 0.01) 263 | { 264 | x_t.push_back(x + size * cos(i)); 265 | y_t.push_back(y + size * sin(i)); 266 | } 267 | plt::plot(x_t, y_t, color); 268 | } 269 | 270 | /** 271 | * 画出搜索过程的图 272 | * @param node 273 | */ 274 | void RRT::draw(Node *node) 275 | { 276 | plt::clf(); 277 | // 画随机点 278 | if (node) 279 | { 280 | plt::plot(vector{node->x}, vector{node->y}, "^k"); 281 | if (robot_radius > 0) 282 | { 283 | plotCircle(node->x, node->y, robot_radius, "-r"); 284 | } 285 | } 286 | 287 | // 画已生成的树 288 | for (Node *node1 : node_list) 289 | { 290 | if (node1->parent) 291 | { 292 | plt::plot(node1->path_x, node1->path_y, "-g"); 293 | } 294 | } 295 | // 画障碍物 296 | for (vector ob : obstacle_list) 297 | { 298 | plotCircle(ob[0], ob[1], ob[2]); 299 | } 300 | 301 | plt::plot(vector{play_area[0], play_area[1], play_area[1], play_area[0], play_area[0]}, vector{play_area[2], play_area[2], play_area[3], play_area[3], play_area[2]}, "k-"); 302 | 303 | // 画出起点和目标点 304 | plt::plot(vector{begin->x}, vector{begin->y}, "xr"); 305 | plt::plot(vector{end->x}, vector{end->y}, "xr"); 306 | plt::axis("equal"); 307 | plt::grid(true); 308 | plt::xlim(play_area[0] - 1, play_area[1] + 1); 309 | plt::ylim(play_area[2] - 1, play_area[3] + 1); 310 | 311 | // // 将每一帧保存为单独的文件 312 | // stringstream filename; 313 | // filename << "./frame_" << gifindex++ << ".png"; 314 | // plt::save(filename.str()); 315 | 316 | plt::pause(0.01); 317 | } 318 | -------------------------------------------------------------------------------- /Rapidly-exploring_Random_Tree/RRT.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file RRT.h 3 | * @author czj 4 | * @brief 5 | * @version 0.1 6 | * @date 2023-04-23 7 | * 8 | * @copyright 
Copyright (c) 2023 9 | * 10 | */ 11 | #ifndef RRT_H 12 | #define RRT_H 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | #include "../matplotlibcpp.h" 23 | namespace plt = matplotlibcpp; 24 | 25 | using namespace std; 26 | using namespace Eigen; 27 | 28 | #define PI 3.14159265354 29 | 30 | //extern int gifindex; 31 | 32 | class Node 33 | { 34 | public: 35 | double x, y; // 节点坐标 36 | vector path_x = {}, path_y = {}; // 路径,作为画图的数据 37 | Node *parent; 38 | double cost; 39 | 40 | public: 41 | Node(double x, double y); 42 | }; 43 | 44 | class RRT 45 | { 46 | public: 47 | vector> obstacle_list; // 障碍物位置列表 [[x,y,size],...] 48 | vector rand_area, play_area; // 采样区域 x,y ∈ [min,max];约束随机树的范围 [xmin,xmax,ymin,ymax] 49 | double robot_radius; // 机器人半径 50 | double expand_dis; // 扩展的步长 51 | double goal_sample_rate; // 采样目标点的概率,百分制.default: 5,即表示5%的概率直接采样目标点 52 | vector node_list; 53 | Node *begin; // 起始点 54 | Node *end; // 目标点 55 | 56 | int max_iter; //最大采样点数 57 | 58 | public: 59 | 60 | RRT(const vector> &obstacleList, 61 | const vector &randArea, const vector &playArea, double robotRadius, double expandDis, 62 | double goalSampleRate, int maxIter); 63 | 64 | vector calDistanceAngle(Node *from_node, Node *to_node); // 计算两个节点间的距离和方位角 65 | 66 | bool obstacleFree(Node *node); // 判断是否有障碍物 67 | 68 | bool isInsidePlayArea(Node *node); // 判断是否在可行区域里面 69 | 70 | int getNearestNodeIndex(vector node_list, Node *rnd_node); // 计算最近的节点 71 | 72 | Node *sampleFree(); // 采样生成节点 73 | 74 | double calDistToGoal(double x, double y); // 计算(x,y)离目标点的距离 75 | 76 | pair, vector> generateFinalCourse(double goal_ind); // 生成路径,画图 77 | 78 | Node *steer(Node *from_node, Node *to_node, double extend_length = numeric_limits::max()); // 连线方向扩展固定步长查找x_new 79 | 80 | pair, vector> planning(); 81 | 82 | void setBegin(Node *begin); 83 | 84 | void setEnd(Node *End); 85 | 86 | void plotCircle(double x, double y, double size, string color = "b"); // 画圆 87 | 
void draw(Node *node = nullptr); 88 | }; 89 | 90 | #endif // RRT_H 91 | -------------------------------------------------------------------------------- /Rapidly-exploring_Random_Tree/main.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file main.cpp 3 | * @author czj 4 | * @brief 5 | * @version 0.1 6 | * @date 2023-04-23 7 | * 8 | * @copyright Copyright (c) 2023 9 | * 10 | */ 11 | #include "RRT.h" 12 | 13 | int main() 14 | { 15 | vector> obstacle_list{ 16 | {5, 5, 1}, 17 | {3, 6, 2}, 18 | {3, 8, 2}, 19 | {3, 10, 2}, 20 | {7, 5, 2}, 21 | {9, 5, 2}, 22 | {8, 10, 1}, 23 | {6, 12, 1}}; 24 | Node *begin = new Node(0.0, 0.0); 25 | Node *end = new Node(6.0, 10.0); 26 | vector rnd_area{-2, 15}; // 采样区域 x,y ∈ [min,max]; 27 | vector play_area{-2, 12, 0, 14}; // 约束随机树的范围 [xmin,xmax,ymin,ymax] 28 | double radius = 0.8; // 机器人半径 29 | double expand_dis = 2; // 扩展的步长 30 | double goal_sample_rate = 5; // 采样目标点的概率,百分制.default: 5,即表示5%的概率直接采样目标点 31 | int max_iter = 500; // 最大采样点数 32 | RRT rrt(obstacle_list, rnd_area, play_area, radius, expand_dis, goal_sample_rate, max_iter); 33 | rrt.setBegin(begin); 34 | rrt.setEnd(end); 35 | 36 | pair, vector> traj = rrt.planning(); 37 | 38 | plt::plot(traj.first, traj.second, "r"); 39 | 40 | // // 合成 GIF 图片 41 | // stringstream filename; 42 | // filename << "./frame_" << gifindex << ".png"; 43 | // plt::save(filename.str()); 44 | 45 | // const char *gif_filename = "./rrt_demo.gif"; 46 | // stringstream cmd; 47 | // cmd << "convert -delay 10 -loop 0 ./frame_*.png " << gif_filename; 48 | // system(cmd.str().c_str()); 49 | // cout << "Saving result to " << gif_filename << std::endl; 50 | // plt::show(); 51 | // //删除png图片 52 | // system("rm *.png"); 53 | // return 0; 54 | 55 | const char *filename = "./rrt_demo.png"; 56 | cout << "Saving result to " << filename << std::endl; 57 | plt::save(filename); 58 | plt::show(); 59 | 60 | return 0; 61 | } 
-------------------------------------------------------------------------------- /Rapidly-exploring_Random_Tree_Star/RRT_Star.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file RRT_Star.cpp 3 | * @author czj 4 | * @brief 5 | * @version 0.1 6 | * @date 2023-04-24 7 | * 8 | * @copyright Copyright (c) 2023 9 | * 10 | */ 11 | 12 | #include "RRT_Star.h" 13 | 14 | 15 | /** 16 | * @brief Construct a new rrt star::rrt star object 17 | * 18 | * @param obstacleList 障碍物位置列表 [[x,y,size],...] 19 | * @param randArea 采样区域 x,y ∈ [min,max]; 20 | * @param playArea 约束随机树的范围 [xmin,xmax,ymin,ymax] 21 | * @param robotRadius 机器人半径 22 | * @param expandDis 扩展的步长 23 | * @param goalSampleRate 采样目标点的概率,百分制.default: 5,即表示5%的概率直接采样目标点 24 | * @param maxIter 最大采样点数 25 | * @param connectCircleDist rewire的探索半径 26 | * @param searchUntilMaxIter 27 | */ 28 | RRT_Star::RRT_Star(const vector> &obstacleList, const vector &randArea, 29 | const vector &playArea, double robotRadius, double expandDis, double goalSampleRate, 30 | int maxIter, double connectCircleDist, bool searchUntilMaxIter) : RRT(obstacleList, randArea, 31 | playArea, robotRadius, 32 | expandDis, goalSampleRate, 33 | maxIter), 34 | connect_circle_dist( 35 | connectCircleDist), 36 | search_until_max_iter(searchUntilMaxIter) {} 37 | 38 | pair, vector> RRT_Star::planning() 39 | { 40 | node_list.push_back(begin); // 将起点作为根节点x_{init},加入到随机树的节点集合中。 41 | for (int i = 0; i < max_iter; i++) 42 | { 43 | // 从可行区域内随机选取一个节点x_{rand} 44 | Node *rnd_node = sampleFree(); 45 | cout << "随机树节点个数:" << node_list.size() << endl; 46 | // 已生成的树中利用欧氏距离判断距离x_{rand}最近的点x_{near}。 47 | int nearest_ind = getNearestNodeIndex(node_list, rnd_node); 48 | Node *nearest_node = node_list[nearest_ind]; 49 | // 从x_{near}与x_{rand}的连线方向上扩展固定步长u,得到新节点 x_{new} 50 | Node *new_node = steer(nearest_node, rnd_node, expand_dis); 51 | // 计算代价,欧氏距离 52 | new_node->cost = nearest_node->cost + sqrt(pow(new_node->x - nearest_node->x, 2) + 
pow(new_node->y - nearest_node->y, 2)); 53 | 54 | // 如果在可行区域内,且x_{near}与x_{new}之间无障碍物 55 | if (isInsidePlayArea(new_node) && obstacleFree(new_node)) 56 | { 57 | vector near_ind = findNearInds(new_node); // 找到x_new的邻近节点 58 | Node *node_with_updated_parent = chooseParent(new_node, near_ind); // 重新选择父节点 59 | // 如果父节点更新了(非空) 60 | if (node_with_updated_parent) 61 | { 62 | // 重布线 63 | rewire(node_with_updated_parent, near_ind); 64 | node_list.push_back(node_with_updated_parent); 65 | } 66 | else 67 | { 68 | node_list.push_back(new_node); 69 | } 70 | } 71 | 72 | draw(rnd_node); 73 | 74 | if ((!search_until_max_iter) && new_node) 75 | { // reaches goal 76 | int last_index = findBestGoalInd(); 77 | if (last_index != -1) 78 | { 79 | cout << "reaches the goal!" << endl; 80 | return generateFinalCourse(last_index); 81 | } 82 | } 83 | } 84 | cout << "达到最大回合数" << endl; 85 | int last_index = findBestGoalInd(); 86 | if (last_index != -1) 87 | return generateFinalCourse(last_index); 88 | return {}; 89 | } 90 | 91 | /** 92 | * @brief 计算周围一定半径内的所有邻近节点 93 | * @param new_node 94 | * @return 所有邻近节点索引 95 | */ 96 | vector RRT_Star::findNearInds(Node *new_node) 97 | { 98 | int nnode = node_list.size() + 1; 99 | vector inds; 100 | double r = connect_circle_dist * sqrt(log(nnode) / nnode); 101 | for (int i = 0; i < node_list.size(); i++) 102 | { 103 | Node *n_ = node_list[i]; 104 | if (pow(n_->x - new_node->x, 2) + pow(n_->y - new_node->y, 2) < r * r) 105 | { 106 | inds.push_back(i); 107 | } 108 | } 109 | return inds; 110 | } 111 | 112 | /** 113 | * @brief 更新以paren_node为父节点到起点的所有cost 114 | * 115 | * @param parent_node 116 | */ 117 | void RRT_Star::propagateCostToLeaves(Node *parent_node) 118 | { 119 | for (Node *node : node_list) 120 | { 121 | if (node->parent == parent_node) 122 | { 123 | node->cost = calcNewCost(parent_node, node); 124 | propagateCostToLeaves(node); 125 | } 126 | } 127 | } 128 | 129 | /** 130 | * @brief 计算代价 131 | * @param from_node 132 | * @param to_node 133 | * @return 
134 | */ 135 | double RRT_Star::calcNewCost(Node *from_node, Node *to_node) 136 | { 137 | vector da = calDistanceAngle(from_node, to_node); 138 | return from_node->cost + da[0]; 139 | } 140 | 141 | /** 142 | * @brief rewire 143 | * @param new_node 144 | * @param near_inds 145 | */ 146 | void RRT_Star::rewire(Node *new_node, vector near_inds) 147 | { 148 | for (int i : near_inds) 149 | { 150 | Node *near_node = node_list[i]; 151 | Node *edge_node = steer(new_node, near_node); 152 | if (!edge_node) 153 | continue; 154 | edge_node->cost = calcNewCost(new_node, near_node); 155 | 156 | if (obstacleFree(edge_node) && near_node->cost > edge_node->cost) 157 | { 158 | near_node->x = edge_node->x; 159 | near_node->y = edge_node->y; 160 | near_node->cost = edge_node->cost; 161 | near_node->path_x = edge_node->path_x; 162 | near_node->path_y = edge_node->path_y; 163 | near_node->parent = edge_node->parent; 164 | propagateCostToLeaves(new_node); 165 | } 166 | } 167 | } 168 | 169 | /** 170 | * @brief 计算离目标点的最佳索引 171 | * @return 172 | */ 173 | int RRT_Star::findBestGoalInd() 174 | { 175 | vector goal_inds, safe_goal_inds; 176 | for (int i = 0; i < node_list.size(); i++) 177 | { 178 | Node *node = node_list[i]; 179 | double dist = calDistToGoal(node->x, node->y); 180 | if (dist <= expand_dis) 181 | { 182 | goal_inds.push_back(i); 183 | } 184 | } 185 | 186 | for (int goal_ind : goal_inds) 187 | { 188 | Node *t_node = steer(node_list[goal_ind], end); 189 | if (obstacleFree(t_node)) 190 | { 191 | safe_goal_inds.push_back(goal_ind); 192 | } 193 | } 194 | if (safe_goal_inds.empty()) 195 | return -1; 196 | double min_cost = numeric_limits::max(); 197 | int safe_ind = -1; 198 | for (int ind : safe_goal_inds) 199 | { 200 | if (min_cost > node_list[ind]->cost) 201 | { 202 | min_cost = node_list[ind]->cost; 203 | safe_ind = ind; 204 | } 205 | } 206 | return safe_ind; 207 | } 208 | 209 | /** 210 | * @brief 在新产生的节点 $x_{new}$ 附近以定义的半径范围$r$内寻找所有的近邻节点 $X_{near}$, 211 | 作为替换 $x_{new}$ 原始父节点 
$x_{near}$ 的备选 212 | 我们需要依次计算起点到每个近邻节点 $X_{near}$ 的路径代价 加上近邻节点 $X_{near}$ 到 $x_{new}$ 的路径代价, 213 | 取路径代价最小的近邻节点$x_{min}$作为 $x_{new}$ 新的父节点 214 | * @param new_node 215 | * @param near_inds 216 | * @return 217 | */ 218 | Node *RRT_Star::chooseParent(Node *new_node, vector near_inds) 219 | { 220 | if (near_inds.empty()) 221 | return nullptr; 222 | vector costs; 223 | for (int i : near_inds) 224 | { 225 | Node *near_node = node_list[i]; 226 | Node *t_node = steer(near_node, new_node); 227 | if (t_node && obstacleFree(t_node)) 228 | { 229 | costs.push_back(calcNewCost(near_node, new_node)); 230 | } 231 | else 232 | { 233 | costs.push_back(numeric_limits::max()); // the cost of collision node 234 | } 235 | } 236 | double min_cost = *min_element(costs.begin(), costs.end()); 237 | 238 | if (min_cost == numeric_limits::max()) 239 | { 240 | cout << "There is no good path.(min_cost is inf)" << endl; 241 | return nullptr; 242 | } 243 | int min_ind = near_inds[min_element(costs.begin(), costs.end()) - costs.begin()]; 244 | 245 | Node *determine_node = steer(node_list[min_ind], new_node); 246 | determine_node->cost = min_cost; 247 | return determine_node; 248 | } 249 | -------------------------------------------------------------------------------- /Rapidly-exploring_Random_Tree_Star/RRT_Star.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file RRT_Star.h 3 | * @author czj 4 | * @brief 5 | * @version 0.1 6 | * @date 2023-04-24 7 | * 8 | * @copyright Copyright (c) 2023 9 | * 10 | */ 11 | 12 | #ifndef RRT_STAR_H 13 | #define RRT_STAR_H 14 | 15 | #include "../Rapidly-exploring_Random_Tree/RRT.h" 16 | 17 | 18 | class RRT_Star : public RRT 19 | { 20 | public: 21 | double connect_circle_dist; // rewire的探索半径 22 | bool search_until_max_iter; // 达到最大步长后是否停止搜索 23 | 24 | RRT_Star(const vector> &obstacleList, const vector &randArea, const vector &playArea, 25 | double robotRadius, double expandDis, double goalSampleRate, int maxIter, double 
connectCircleDist, 26 | bool searchUntilMaxIter); 27 | 28 | pair, vector> planning(); 29 | 30 | vector findNearInds(Node *new_node); // 找出邻近节点集合 31 | 32 | void propagateCostToLeaves(Node *parent_node); 33 | 34 | double calcNewCost(Node *from_node, Node *to_node); 35 | 36 | void rewire(Node *new_node, vector near_inds); 37 | 38 | int findBestGoalInd(); 39 | 40 | Node *chooseParent(Node *new_node, vector near_inds); 41 | }; 42 | 43 | #endif // RRT_STAR_H 44 | -------------------------------------------------------------------------------- /Rapidly-exploring_Random_Tree_Star/main.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by chh3213 on 2022/11/27. 3 | // 4 | #include "RRT_Star.h" 5 | 6 | int main() 7 | { 8 | vector> obstacle_list{ 9 | {5, 5, 1}, 10 | {3, 6, 2}, 11 | {3, 8, 2}, 12 | {3, 10, 2}, 13 | {7, 5, 2}, 14 | {9, 5, 2}, 15 | {8, 10, 1}, 16 | {6, 12, 1}}; 17 | Node *begin = new Node(0.0, 0.0); 18 | Node *end = new Node(6.0, 10.0); 19 | vector rnd_area{-2, 15}; 20 | vector play_area{-2, 12, 0, 14}; 21 | double radius = 0.5; 22 | double expand_dis = 1; // 扩展的步长 23 | double goal_sample_rate = 20; // 采样目标点的概率,百分制.default: 5,即表示5%的概率直接采样目标点 24 | int max_iter = 500; 25 | double connect_circle_dist = 5.0; 26 | bool search_until_max_iter = false; 27 | 28 | RRT_Star rrt(obstacle_list, rnd_area, play_area, radius, expand_dis, goal_sample_rate, max_iter, connect_circle_dist, search_until_max_iter); 29 | rrt.setBegin(begin); 30 | rrt.setEnd(end); 31 | 32 | pair, vector> traj = rrt.planning(); 33 | 34 | plt::plot(traj.first, traj.second, "r"); 35 | 36 | // // 合成 GIF 图片 37 | // stringstream filename; 38 | // filename << "./frame_" << gifindex << ".png"; 39 | // plt::save(filename.str()); 40 | 41 | // const char *gif_filename = "./rrt_star_demo.gif"; 42 | // stringstream cmd; 43 | // cmd << "convert -delay 10 -loop 0 ./frame_*.png " << gif_filename; 44 | // system(cmd.str().c_str()); 45 | // cout << "Saving result 
to " << gif_filename << std::endl; 46 | // plt::show(); 47 | // //删除png图片 48 | // system("rm *.png"); 49 | // return 0; 50 | 51 | // save figure 52 | const char *filename = "./rrt_star_demo.png"; 53 | cout << "Saving result to " << filename << std::endl; 54 | plt::save(filename); 55 | plt::show(); 56 | 57 | return 0; 58 | } 59 | -------------------------------------------------------------------------------- /Rapidly-exploring_Random_Tree_connect/RRT_connect.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file RRT_connect.cpp 3 | * @author czj 4 | * @brief 5 | * @version 0.1 6 | * @date 2023-04-24 7 | * 8 | * @copyright Copyright (c) 2023 9 | * 10 | */ 11 | 12 | #include "RRT_connect.h" 13 | 14 | Node::Node(double x, double y) : x(x), y(y), parent(NULL) {} 15 | 16 | //int gifindex = 0; 17 | 18 | /** 19 | * @brief Construct 20 | * 21 | * @param obstacleList 障碍物位置列表 [[x,y,size],...] 22 | * @param randArea 采样区域 x,y ∈ [min,max]; 23 | * @param playArea 约束随机树的范围 [xmin,xmax,ymin,ymax] 24 | * @param robotRadius 机器人半径 25 | * @param expandDis 扩展的步长 26 | * @param goalSampleRate 采样目标点的概率,百分制.default: 5,即表示5%的概率直接采样目标点 27 | * @param maxIter 最大采样点数 28 | */ 29 | RRTConnect::RRTConnect(const vector> &obstacleList, 30 | const vector &randArea, const vector &playArea, double robotRadius, double expandDis, 31 | double goalSampleRate, int maxIter) : obstacle_list(obstacleList), rand_area(randArea), 32 | play_area(playArea), robot_radius(robotRadius), expand_dis(expandDis), 33 | goal_sample_rate(goalSampleRate), max_iter(maxIter) {} 34 | 35 | /** 36 | * @brief 计算两个节点间的距离和方位角 37 | * @param from_node 38 | * @param to_node 39 | * @return 40 | */ 41 | vector RRTConnect::calDistanceAngle(Node *from_node, Node *to_node) 42 | { 43 | double dx = to_node->x - from_node->x; 44 | double dy = to_node->y - from_node->y; 45 | double d = sqrt(pow(dx, 2) + pow(dy, 2)); 46 | double theta = atan2(dy, dx); 47 | return {d, theta}; 48 | } 49 | 50 | /** 
51 | * @brief 判断是否有障碍物 52 | * @param node 节点坐标 53 | * @return 54 | */ 55 | bool RRTConnect::obstacleFree(Node *node) 56 | { 57 | for (vector obs : obstacle_list) 58 | { 59 | for (int i = 0; i < node->path_x.size(); i++) 60 | { 61 | double x = node->path_x[i]; 62 | double y = node->path_y[i]; 63 | if (pow(obs[0] - x, 2) + pow(obs[1] - y, 2) <= pow(obs[2] + robot_radius, 2)) 64 | return false; // collision 65 | } 66 | } 67 | return true; // safe 68 | } 69 | 70 | /** 71 | * @brief 判断是否在可行区域里面 72 | * @param node 73 | * @return 74 | */ 75 | bool RRTConnect::isInsidePlayArea(Node *node) 76 | { 77 | if (node->x < play_area[0] || node->x > play_area[1] || node->y < play_area[2] || node->y > play_area[3]) 78 | return false; 79 | return true; 80 | } 81 | 82 | /** 83 | * @brief 计算最近的节点 84 | * @param node_list 节点列表 85 | * @param rnd_node 随机采样的节点 86 | * @return 最近的节点索引 87 | */ 88 | int RRTConnect::getNearestNodeIndex(vector node_list, Node *rnd_node) 89 | { 90 | int min_index = -1; 91 | double d = numeric_limits::max(); 92 | for (int i = 0; i < node_list.size(); i++) 93 | { 94 | Node *node = node_list[i]; 95 | double dist = pow(node->x - rnd_node->x, 2) + pow(node->y - rnd_node->y, 2); 96 | if (d > dist) 97 | { 98 | d = dist; 99 | min_index = i; 100 | } 101 | } 102 | return min_index; 103 | } 104 | 105 | /** 106 | * @brief 以(100-goal_sample_rate)%的概率随机生长,(goal_sample_rate)%的概率朝向目标点生长 107 | * @return 生成的节点 108 | */ 109 | Node *RRTConnect::sampleFree() 110 | { 111 | Node *rnd = nullptr; 112 | if (rand() % (100) > goal_sample_rate) 113 | { 114 | 115 | double x_rand = rand() / double(RAND_MAX) * (rand_area[1] - rand_area[0]) + rand_area[0]; 116 | double y_rand = rand() / double(RAND_MAX) * (rand_area[1] - rand_area[0]) + rand_area[0]; 117 | // cout<x, end->y); 123 | } 124 | return rnd; 125 | } 126 | 127 | void RRTConnect::setBegin(Node *begin) 128 | { 129 | RRTConnect::begin = begin; 130 | } 131 | 132 | void RRTConnect::setEnd(Node *end) 133 | { 134 | RRTConnect::end = end; 135 | } 
136 | 137 | /** 138 | * @brief 计算(x,y)离目标点的距离 139 | * @param x 140 | * @param y 141 | * @return 142 | */ 143 | double RRTConnect::calDistToGoal(double x, double y) 144 | { 145 | double dx = x - end->x; 146 | double dy = y - end->y; 147 | return sqrt(pow(dx, 2) + pow(dy, 2)); 148 | } 149 | 150 | /** 151 | * @brief 生成路径 152 | * @param goal_ind 153 | * @return 154 | */ 155 | pair, vector> RRTConnect::generateFinalCourse() 156 | { 157 | vector x_, y_, x1, y1, x2, y2; 158 | 159 | Node *node = node_list_1[node_list_1.size() - 1]; 160 | while (node->parent != nullptr) 161 | { 162 | x1.push_back(node->x); 163 | y1.push_back(node->y); 164 | node = node->parent; 165 | // cout<x<<","<y<x); 168 | y1.push_back(node->y); 169 | 170 | node = node_list_2[node_list_2.size() - 1]; 171 | while (node->parent != nullptr) 172 | { 173 | x2.push_back(node->x); 174 | y2.push_back(node->y); 175 | node = node->parent; 176 | // cout<x<<","<y<x); 179 | y2.push_back(node->y); 180 | 181 | for (int i = x1.size() - 1; i >= 0; i--) 182 | { 183 | x_.push_back(x1[i]); 184 | y_.push_back(y1[i]); 185 | } 186 | for (int i = 0; i < x2.size(); i++) 187 | { 188 | x_.push_back(x2[i]); 189 | y_.push_back(y2[i]); 190 | } 191 | 192 | return {x_, y_}; 193 | } 194 | 195 | /** 196 | * @brief 连线方向扩展固定步长查找x_new 197 | * @param from_node x_near 198 | * @param to_node x_rand 199 | * @param extend_length 扩展步长u. Defaults to float("inf"). 
200 | * @return 201 | */ 202 | Node *RRTConnect::steer(Node *from_node, Node *to_node, double extend_length) 203 | { 204 | // 利用反正切计算角度, 然后利用角度和步长计算新坐标 205 | vector dist_angle = calDistanceAngle(from_node, to_node); 206 | double new_x, new_y; 207 | if (extend_length >= dist_angle[0]) 208 | { 209 | new_x = to_node->x; 210 | new_y = to_node->y; 211 | } 212 | else 213 | { 214 | new_x = from_node->x + extend_length * cos(dist_angle[1]); 215 | new_y = from_node->y + extend_length * sin(dist_angle[1]); 216 | } 217 | 218 | Node *new_node = new Node(new_x, new_y); 219 | new_node->path_x.push_back(from_node->x); 220 | new_node->path_y.push_back(from_node->y); 221 | new_node->path_x.push_back(new_node->x); 222 | new_node->path_y.push_back(new_node->y); 223 | 224 | new_node->parent = from_node; 225 | return new_node; 226 | } 227 | 228 | /** 229 | * @brief rrt path planning,两边同时进行搜索 230 | * @return 轨迹数据 231 | */ 232 | pair, vector> RRTConnect::planning() 233 | { 234 | node_list_1.push_back(begin); // 将起点作为根节点x_{init},加入到第一棵随机树的节点集合中。 235 | node_list_2.push_back(end); // 将终点作为根节点x_{init},加入到第二棵随机树的节点集合中。 236 | for (int i = 0; i < max_iter; i++) 237 | { 238 | // 从可行区域内随机选取一个节点x_{rand} 239 | Node *rnd_node = sampleFree(); 240 | 241 | // 已生成的树中利用欧氏距离判断距离x_{rand}最近的点x_{near}。 242 | int nearest_ind = getNearestNodeIndex(node_list_1, rnd_node); 243 | Node *nearest_node = node_list_1[nearest_ind]; 244 | // 从x_{near}与x_{rand}的连线方向上扩展固定步长u,得到新节点 x_{new} 245 | Node *new_node = steer(nearest_node, rnd_node, expand_dis); 246 | 247 | // 第一棵树,如果在可行区域内,且q_{near}与q_{new}之间无障碍物 248 | if (isInsidePlayArea(new_node) && obstacleFree(new_node)) 249 | { 250 | // 将x_new加入到集合中 251 | node_list_1.push_back(new_node); 252 | // 扩展完第一棵树的新节点x_{𝑛𝑒𝑤}后,以这个新的目标点x_{𝑛𝑒𝑤}作为第二棵树扩展的x_rand。 253 | nearest_ind = getNearestNodeIndex(node_list_2, new_node); 254 | nearest_node = node_list_2[nearest_ind]; 255 | // 从x_{near}与x_{rand}的连线方向上扩展固定步长u,得到新节点 x_{new2} 256 | Node *new_node_2 = steer(nearest_node, new_node, 
expand_dis); 257 | // 第二棵树,如果在可行区域内,且x_{near}与x_{new2}之间无障碍物 258 | if (isInsidePlayArea(new_node_2) && obstacleFree(new_node_2)) 259 | { 260 | // 将x_new2加入第二棵树 261 | node_list_2.push_back(new_node_2); 262 | // 接下来,第二棵树继续往第一棵树的x_{new}方向扩展,直到扩展失败(碰到障碍物)或者𝑞′𝑛𝑒𝑤=𝑞𝑛𝑒𝑤表示与第一棵树相连 263 | while (true) 264 | { 265 | Node *new_node_2_ = steer(new_node_2, new_node, expand_dis); 266 | if (obstacleFree(new_node_2_)) 267 | { 268 | node_list_2.push_back(new_node_2_); 269 | new_node_2 = new_node_2_; 270 | } 271 | else 272 | break; 273 | // 当$𝑞′_{𝑛𝑒𝑤}=𝑞_{𝑛𝑒𝑤}$时,表示与第一棵树相连,算法结束 274 | if (abs(new_node_2->x - new_node->x) < 0.00001 && abs(new_node_2->y - new_node->y) < 0.00001) 275 | { 276 | cout << "reaches the goal!" << endl; 277 | return generateFinalCourse(); 278 | } 279 | } 280 | } 281 | } 282 | // # 考虑两棵树的平衡性,即两棵树的节点数的多少,交换次序选择“小”的那棵树进行扩展。 283 | if (node_list_1.size() > node_list_2.size()) 284 | { 285 | swap(node_list_1, node_list_2); 286 | } 287 | draw(rnd_node, new_node); 288 | } 289 | return {}; 290 | } 291 | 292 | /** 293 | * 画圆 294 | * @param x 295 | * @param y 296 | * @param size 297 | * @param color 298 | */ 299 | void RRTConnect::plotCircle(double x, double y, double size, string color) 300 | { 301 | vector x_t, y_t; 302 | for (double i = 0.; i <= 2 * PI; i += 0.01) 303 | { 304 | x_t.push_back(x + size * cos(i)); 305 | y_t.push_back(y + size * sin(i)); 306 | } 307 | plt::plot(x_t, y_t, color); 308 | } 309 | 310 | /** 311 | * 画出搜索过程的图 312 | * @param node 313 | */ 314 | void RRTConnect::draw(Node *node1, Node *node2) 315 | { 316 | plt::clf(); 317 | // 画随机点 318 | if (node1) 319 | { 320 | plt::plot(vector{node1->x}, vector{node1->y}, "^k"); 321 | if (robot_radius > 0) 322 | { 323 | plotCircle(node1->x, node1->y, robot_radius, "-r"); 324 | } 325 | } 326 | if (node2) 327 | { 328 | plt::plot(vector{node2->x}, vector{node2->y}, "^k"); 329 | if (robot_radius > 0) 330 | { 331 | plotCircle(node2->x, node2->y, robot_radius, "-r"); 332 | } 333 | } 334 | 335 | // 画已生成的树 336 | for (Node 
*n1 : node_list_1) 337 | { 338 | if (n1->parent) 339 | { 340 | plt::plot(n1->path_x, n1->path_y, "-g"); 341 | } 342 | } 343 | // 画已生成的树 344 | for (Node *n2 : node_list_2) 345 | { 346 | if (n2->parent) 347 | { 348 | plt::plot(n2->path_x, n2->path_y, "-g"); 349 | } 350 | } 351 | // 画障碍物 352 | for (vector ob : obstacle_list) 353 | { 354 | plotCircle(ob[0], ob[1], ob[2]); 355 | } 356 | 357 | plt::plot(vector{play_area[0], play_area[1], play_area[1], play_area[0], play_area[0]}, vector{play_area[2], play_area[2], play_area[3], play_area[3], play_area[2]}, "k-"); 358 | 359 | // 画出起点和目标点 360 | plt::plot(vector{begin->x}, vector{begin->y}, "xr"); 361 | plt::plot(vector{end->x}, vector{end->y}, "xr"); 362 | plt::axis("equal"); 363 | plt::grid(true); 364 | plt::xlim(play_area[0] - 1, play_area[1] + 1); 365 | plt::ylim(play_area[2] - 1, play_area[3] + 1); 366 | // // 将每一帧保存为单独的文件 367 | // stringstream filename; 368 | // filename << "./frame_" << gifindex++ << ".png"; 369 | // plt::save(filename.str()); 370 | plt::pause(0.01); 371 | } 372 | -------------------------------------------------------------------------------- /Rapidly-exploring_Random_Tree_connect/RRT_connect.h: -------------------------------------------------------------------------------- 1 | /** 2 | * @file RRT_connect.h 3 | * @author czj 4 | * @brief 5 | * @version 0.1 6 | * @date 2023-04-24 7 | * 8 | * @copyright Copyright (c) 2023 9 | * 10 | */ 11 | 12 | #ifndef RRT_CONNECT_H 13 | #define RRT_CONNECT_H 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | 23 | #include "../matplotlibcpp.h" 24 | namespace plt = matplotlibcpp; 25 | 26 | using namespace std; 27 | using namespace Eigen; 28 | 29 | #define PI 3.14159265354 30 | 31 | //extern int gifindex; 32 | 33 | class Node 34 | { 35 | public: 36 | double x, y; // 节点坐标 37 | vector path_x = {}, path_y = {}; // 路径,作为画图的数据 38 | Node *parent; 39 | double cost; 40 | 41 | public: 42 | Node(double x, double y); 43 | 
}; 44 | 45 | class RRTConnect 46 | { 47 | public: 48 | vector> obstacle_list; // 障碍物位置列表 [[x,y,size],...] 49 | vector rand_area, play_area; // 采样区域 x,y ∈ [min,max];约束随机树的范围 [xmin,xmax,ymin,ymax] 50 | 51 | double robot_radius; // 机器人半径 52 | double expand_dis; // 扩展的步长 53 | double goal_sample_rate; // 采样目标点的概率,百分制.default: 5,即表示5%的概率直接采样目标点 54 | vector node_list_1, node_list_2; // 与RRT不同的地方 55 | Node *begin; // 根节点 56 | Node *end; // 终节点 57 | 58 | int max_iter; 59 | 60 | public: 61 | RRTConnect(const vector> &obstacleList, 62 | const vector &randArea, const vector &playArea, double robotRadius, double expandDis, 63 | double goalSampleRate, int maxIter); 64 | 65 | vector calDistanceAngle(Node *from_node, Node *to_node); // 计算两个节点间的距离和方位角 66 | 67 | bool obstacleFree(Node *node); // 判断是否有障碍物 68 | 69 | bool isInsidePlayArea(Node *node); // 判断是否在可行区域里面 70 | 71 | int getNearestNodeIndex(vector node_list, Node *rnd_node); // 计算最近的节点 72 | 73 | Node *sampleFree(); // 采样生成节点 74 | 75 | double calDistToGoal(double x, double y); // 计算(x,y)离目标点的距离 76 | 77 | pair, vector> generateFinalCourse(); // 生成路径,与RRT不同的地方 78 | 79 | Node *steer(Node *from_node, Node *to_node, double extend_length = numeric_limits::max()); // 连线方向扩展固定步长查找x_new 80 | 81 | pair, vector> planning(); // 与RRT不同的地方 82 | 83 | void setBegin(Node *begin); 84 | 85 | void setEnd(Node *end); 86 | 87 | void plotCircle(double x, double y, double size, string color = "b"); // 画圆 88 | void draw(Node *node1 = nullptr, Node *node2 = nullptr); 89 | }; 90 | 91 | #endif // RRT_CONNECT_H 92 | -------------------------------------------------------------------------------- /Rapidly-exploring_Random_Tree_connect/main.cpp: -------------------------------------------------------------------------------- 1 | /** 2 | * @file main.cpp 3 | * @author czj 4 | * @brief 5 | * @version 0.1 6 | * @date 2023-04-24 7 | * 8 | * @copyright Copyright (c) 2023 9 | * 10 | */ 11 | 12 | #include "RRT_connect.h" 13 | 14 | int main() 15 | { 16 | vector> 
obstacle_list{ 17 | {5, 5, 1}, 18 | {3, 6, 2}, 19 | {3, 8, 2}, 20 | {3, 10, 2}, 21 | {7, 5, 2}, 22 | {9, 5, 2}, 23 | {8, 10, 1}, 24 | {6, 12, 1}}; 25 | Node *begin = new Node(0.0, 0.0); 26 | Node *end = new Node(6.0, 10.0); 27 | vector rnd_area{-2, 15}; // 采样区域 x,y ∈ [min,max]; 28 | vector play_area{-2, 12, 0, 14}; // 约束随机树的范围 [xmin,xmax,ymin,ymax] 29 | double radius = 0.8; // 机器人半径 30 | double expand_dis = 2; // 扩展的步长 31 | double goal_sample_rate = 5; // 采样目标点的概率,百分制.default: 5,即表示5%的概率直接采样目标点 32 | int max_iter = 500; // 最大采样点数 33 | RRTConnect rrt(obstacle_list, rnd_area, play_area, radius, expand_dis, goal_sample_rate, max_iter); 34 | rrt.setBegin(begin); 35 | rrt.setEnd(end); 36 | 37 | pair, vector> traj = rrt.planning(); 38 | 39 | plt::plot(traj.first, traj.second, "r"); 40 | 41 | // // 合成 GIF 图片 42 | // stringstream filename; 43 | // filename << "./frame_" << gifindex << ".png"; 44 | // plt::save(filename.str()); 45 | 46 | // const char *gif_filename = "./rrt_connect.gif"; 47 | // stringstream cmd; 48 | // cmd << "convert -delay 10 -loop 0 ./frame_*.png " << gif_filename; 49 | // system(cmd.str().c_str()); 50 | // cout << "Saving result to " << gif_filename << std::endl; 51 | // plt::show(); 52 | // //删除png图片 53 | // system("rm *.png"); 54 | // return 0; 55 | 56 | const char *filename = "./rrt_connect_demo.png"; 57 | cout << "Saving result to " << filename << std::endl; 58 | plt::save(filename); 59 | plt::show(); 60 | 61 | return 0; 62 | } -------------------------------------------------------------------------------- /gif/astar.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/czjaixuexi/path_planning/83627373069fe4d84dae0fde2de125dcde909a16/gif/astar.gif -------------------------------------------------------------------------------- /gif/b_spline_demo.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/czjaixuexi/path_planning/83627373069fe4d84dae0fde2de125dcde909a16/gif/b_spline_demo.gif -------------------------------------------------------------------------------- /gif/bezier_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/czjaixuexi/path_planning/83627373069fe4d84dae0fde2de125dcde909a16/gif/bezier_demo.gif -------------------------------------------------------------------------------- /gif/dijkstra_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/czjaixuexi/path_planning/83627373069fe4d84dae0fde2de125dcde909a16/gif/dijkstra_demo.gif -------------------------------------------------------------------------------- /gif/rrt_connect.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/czjaixuexi/path_planning/83627373069fe4d84dae0fde2de125dcde909a16/gif/rrt_connect.gif -------------------------------------------------------------------------------- /gif/rrt_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/czjaixuexi/path_planning/83627373069fe4d84dae0fde2de125dcde909a16/gif/rrt_demo.gif -------------------------------------------------------------------------------- /gif/rrt_star_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/czjaixuexi/path_planning/83627373069fe4d84dae0fde2de125dcde909a16/gif/rrt_star_demo.gif -------------------------------------------------------------------------------- /matplotlibcpp.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | // Python headers must be included before any system headers, since 4 | // they define _POSIX_C_SOURCE 5 | #include 6 | 7 | #include 8 | #include 9 | 
#include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include // requires c++11 support 15 | #include 16 | #include // std::stod 17 | 18 | #ifndef WITHOUT_NUMPY 19 | # define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION 20 | # include 21 | 22 | # ifdef WITH_OPENCV 23 | # include 24 | # endif // WITH_OPENCV 25 | 26 | /* 27 | * A bunch of constants were removed in OpenCV 4 in favour of enum classes, so 28 | * define the ones we need here. 29 | */ 30 | # if CV_MAJOR_VERSION > 3 31 | # define CV_BGR2RGB cv::COLOR_BGR2RGB 32 | # define CV_BGRA2RGBA cv::COLOR_BGRA2RGBA 33 | # endif 34 | #endif // WITHOUT_NUMPY 35 | 36 | #if PY_MAJOR_VERSION >= 3 37 | # define PyString_FromString PyUnicode_FromString 38 | # define PyInt_FromLong PyLong_FromLong 39 | # define PyString_FromString PyUnicode_FromString 40 | #endif 41 | 42 | 43 | namespace matplotlibcpp { 44 | namespace detail { 45 | 46 | static std::string s_backend; 47 | 48 | struct _interpreter { 49 | PyObject* s_python_function_arrow; 50 | PyObject *s_python_function_show; 51 | PyObject *s_python_function_close; 52 | PyObject *s_python_function_draw; 53 | PyObject *s_python_function_pause; 54 | PyObject *s_python_function_save; 55 | PyObject *s_python_function_figure; 56 | PyObject *s_python_function_fignum_exists; 57 | PyObject *s_python_function_plot; 58 | PyObject *s_python_function_quiver; 59 | PyObject* s_python_function_contour; 60 | PyObject *s_python_function_semilogx; 61 | PyObject *s_python_function_semilogy; 62 | PyObject *s_python_function_loglog; 63 | PyObject *s_python_function_fill; 64 | PyObject *s_python_function_fill_between; 65 | PyObject *s_python_function_hist; 66 | PyObject *s_python_function_imshow; 67 | PyObject *s_python_function_scatter; 68 | PyObject *s_python_function_boxplot; 69 | PyObject *s_python_function_subplot; 70 | PyObject *s_python_function_subplot2grid; 71 | PyObject *s_python_function_legend; 72 | PyObject *s_python_function_xlim; 73 | PyObject *s_python_function_ion; 
74 | PyObject *s_python_function_ginput; 75 | PyObject *s_python_function_ylim; 76 | PyObject *s_python_function_title; 77 | PyObject *s_python_function_axis; 78 | PyObject *s_python_function_axhline; 79 | PyObject *s_python_function_axvline; 80 | PyObject *s_python_function_axvspan; 81 | PyObject *s_python_function_xlabel; 82 | PyObject *s_python_function_ylabel; 83 | PyObject *s_python_function_gca; 84 | PyObject *s_python_function_xticks; 85 | PyObject *s_python_function_yticks; 86 | PyObject* s_python_function_margins; 87 | PyObject *s_python_function_tick_params; 88 | PyObject *s_python_function_grid; 89 | PyObject* s_python_function_cla; 90 | PyObject *s_python_function_clf; 91 | PyObject *s_python_function_errorbar; 92 | PyObject *s_python_function_annotate; 93 | PyObject *s_python_function_tight_layout; 94 | PyObject *s_python_colormap; 95 | PyObject *s_python_empty_tuple; 96 | PyObject *s_python_function_stem; 97 | PyObject *s_python_function_xkcd; 98 | PyObject *s_python_function_text; 99 | PyObject *s_python_function_suptitle; 100 | PyObject *s_python_function_bar; 101 | PyObject *s_python_function_barh; 102 | PyObject *s_python_function_colorbar; 103 | PyObject *s_python_function_subplots_adjust; 104 | PyObject *s_python_function_rcparams; 105 | PyObject *s_python_function_spy; 106 | 107 | /* For now, _interpreter is implemented as a singleton since its currently not possible to have 108 | multiple independent embedded python interpreters without patching the python source code 109 | or starting a separate process for each. [1] 110 | Furthermore, many python objects expect that they are destructed in the same thread as they 111 | were constructed. [2] So for advanced usage, a `kill()` function is provided so that library 112 | users can manually ensure that the interpreter is constructed and destroyed within the 113 | same thread. 
114 | 115 | 1: http://bytes.com/topic/python/answers/793370-multiple-independent-python-interpreters-c-c-program 116 | 2: https://github.com/lava/matplotlib-cpp/pull/202#issue-436220256 117 | */ 118 | 119 | static _interpreter& get() { 120 | return interkeeper(false); 121 | } 122 | 123 | static _interpreter& kill() { 124 | return interkeeper(true); 125 | } 126 | 127 | // Stores the actual singleton object referenced by `get()` and `kill()`. 128 | static _interpreter& interkeeper(bool should_kill) { 129 | static _interpreter ctx; 130 | if (should_kill) 131 | ctx.~_interpreter(); 132 | return ctx; 133 | } 134 | 135 | PyObject* safe_import(PyObject* module, std::string fname) { 136 | PyObject* fn = PyObject_GetAttrString(module, fname.c_str()); 137 | 138 | if (!fn) 139 | throw std::runtime_error(std::string("Couldn't find required function: ") + fname); 140 | 141 | if (!PyFunction_Check(fn)) 142 | throw std::runtime_error(fname + std::string(" is unexpectedly not a PyFunction.")); 143 | 144 | return fn; 145 | } 146 | 147 | private: 148 | 149 | #ifndef WITHOUT_NUMPY 150 | # if PY_MAJOR_VERSION >= 3 151 | 152 | void *import_numpy() { 153 | import_array(); // initialize C-API 154 | return NULL; 155 | } 156 | 157 | # else 158 | 159 | void import_numpy() { 160 | import_array(); // initialize C-API 161 | } 162 | 163 | # endif 164 | #endif 165 | 166 | _interpreter() { 167 | 168 | // optional but recommended 169 | #if PY_MAJOR_VERSION >= 3 170 | wchar_t name[] = L"plotting"; 171 | #else 172 | char name[] = "plotting"; 173 | #endif 174 | Py_SetProgramName(name); 175 | Py_Initialize(); 176 | 177 | wchar_t const *dummy_args[] = {L"Python", NULL}; // const is needed because literals must not be modified 178 | wchar_t const **argv = dummy_args; 179 | int argc = sizeof(dummy_args)/sizeof(dummy_args[0])-1; 180 | 181 | #if PY_MAJOR_VERSION >= 3 182 | PySys_SetArgv(argc, const_cast(argv)); 183 | #else 184 | PySys_SetArgv(argc, (char **)(argv)); 185 | #endif 186 | 187 | #ifndef 
WITHOUT_NUMPY 188 | import_numpy(); // initialize numpy C-API 189 | #endif 190 | 191 | PyObject* matplotlibname = PyString_FromString("matplotlib"); 192 | PyObject* pyplotname = PyString_FromString("matplotlib.pyplot"); 193 | PyObject* cmname = PyString_FromString("matplotlib.cm"); 194 | PyObject* pylabname = PyString_FromString("pylab"); 195 | if (!pyplotname || !pylabname || !matplotlibname || !cmname) { 196 | throw std::runtime_error("couldnt create string"); 197 | } 198 | 199 | PyObject* matplotlib = PyImport_Import(matplotlibname); 200 | 201 | Py_DECREF(matplotlibname); 202 | if (!matplotlib) { 203 | PyErr_Print(); 204 | throw std::runtime_error("Error loading module matplotlib!"); 205 | } 206 | 207 | // matplotlib.use() must be called *before* pylab, matplotlib.pyplot, 208 | // or matplotlib.backends is imported for the first time 209 | if (!s_backend.empty()) { 210 | PyObject_CallMethod(matplotlib, const_cast("use"), const_cast("s"), s_backend.c_str()); 211 | } 212 | 213 | 214 | 215 | PyObject* pymod = PyImport_Import(pyplotname); 216 | Py_DECREF(pyplotname); 217 | if (!pymod) { throw std::runtime_error("Error loading module matplotlib.pyplot!"); } 218 | 219 | s_python_colormap = PyImport_Import(cmname); 220 | Py_DECREF(cmname); 221 | if (!s_python_colormap) { throw std::runtime_error("Error loading module matplotlib.cm!"); } 222 | 223 | PyObject* pylabmod = PyImport_Import(pylabname); 224 | Py_DECREF(pylabname); 225 | if (!pylabmod) { throw std::runtime_error("Error loading module pylab!"); } 226 | 227 | s_python_function_arrow = safe_import(pymod, "arrow"); 228 | s_python_function_show = safe_import(pymod, "show"); 229 | s_python_function_close = safe_import(pymod, "close"); 230 | s_python_function_draw = safe_import(pymod, "draw"); 231 | s_python_function_pause = safe_import(pymod, "pause"); 232 | s_python_function_figure = safe_import(pymod, "figure"); 233 | s_python_function_fignum_exists = safe_import(pymod, "fignum_exists"); 234 | 
s_python_function_plot = safe_import(pymod, "plot"); 235 | s_python_function_quiver = safe_import(pymod, "quiver"); 236 | s_python_function_contour = safe_import(pymod, "contour"); 237 | s_python_function_semilogx = safe_import(pymod, "semilogx"); 238 | s_python_function_semilogy = safe_import(pymod, "semilogy"); 239 | s_python_function_loglog = safe_import(pymod, "loglog"); 240 | s_python_function_fill = safe_import(pymod, "fill"); 241 | s_python_function_fill_between = safe_import(pymod, "fill_between"); 242 | s_python_function_hist = safe_import(pymod,"hist"); 243 | s_python_function_scatter = safe_import(pymod,"scatter"); 244 | s_python_function_boxplot = safe_import(pymod,"boxplot"); 245 | s_python_function_subplot = safe_import(pymod, "subplot"); 246 | s_python_function_subplot2grid = safe_import(pymod, "subplot2grid"); 247 | s_python_function_legend = safe_import(pymod, "legend"); 248 | s_python_function_xlim = safe_import(pymod, "xlim"); 249 | s_python_function_ylim = safe_import(pymod, "ylim"); 250 | s_python_function_title = safe_import(pymod, "title"); 251 | s_python_function_axis = safe_import(pymod, "axis"); 252 | s_python_function_axhline = safe_import(pymod, "axhline"); 253 | s_python_function_axvline = safe_import(pymod, "axvline"); 254 | s_python_function_axvspan = safe_import(pymod, "axvspan"); 255 | s_python_function_xlabel = safe_import(pymod, "xlabel"); 256 | s_python_function_ylabel = safe_import(pymod, "ylabel"); 257 | s_python_function_gca = safe_import(pymod, "gca"); 258 | s_python_function_xticks = safe_import(pymod, "xticks"); 259 | s_python_function_yticks = safe_import(pymod, "yticks"); 260 | s_python_function_margins = safe_import(pymod, "margins"); 261 | s_python_function_tick_params = safe_import(pymod, "tick_params"); 262 | s_python_function_grid = safe_import(pymod, "grid"); 263 | s_python_function_ion = safe_import(pymod, "ion"); 264 | s_python_function_ginput = safe_import(pymod, "ginput"); 265 | s_python_function_save = 
safe_import(pylabmod, "savefig"); 266 | s_python_function_annotate = safe_import(pymod,"annotate"); 267 | s_python_function_cla = safe_import(pymod, "cla"); 268 | s_python_function_clf = safe_import(pymod, "clf"); 269 | s_python_function_errorbar = safe_import(pymod, "errorbar"); 270 | s_python_function_tight_layout = safe_import(pymod, "tight_layout"); 271 | s_python_function_stem = safe_import(pymod, "stem"); 272 | s_python_function_xkcd = safe_import(pymod, "xkcd"); 273 | s_python_function_text = safe_import(pymod, "text"); 274 | s_python_function_suptitle = safe_import(pymod, "suptitle"); 275 | s_python_function_bar = safe_import(pymod,"bar"); 276 | s_python_function_barh = safe_import(pymod, "barh"); 277 | s_python_function_colorbar = PyObject_GetAttrString(pymod, "colorbar"); 278 | s_python_function_subplots_adjust = safe_import(pymod,"subplots_adjust"); 279 | s_python_function_rcparams = PyObject_GetAttrString(pymod, "rcParams"); 280 | s_python_function_spy = PyObject_GetAttrString(pymod, "spy"); 281 | #ifndef WITHOUT_NUMPY 282 | s_python_function_imshow = safe_import(pymod, "imshow"); 283 | #endif 284 | s_python_empty_tuple = PyTuple_New(0); 285 | } 286 | 287 | ~_interpreter() { 288 | Py_Finalize(); 289 | } 290 | }; 291 | 292 | } // end namespace detail 293 | 294 | /// Select the backend 295 | /// 296 | /// **NOTE:** This must be called before the first plot command to have 297 | /// any effect. 298 | /// 299 | /// Mainly useful to select the non-interactive 'Agg' backend when running 300 | /// matplotlibcpp in headless mode, for example on a machine with no display. 
301 | /// 302 | /// See also: https://matplotlib.org/2.0.2/api/matplotlib_configuration_api.html#matplotlib.use 303 | inline void backend(const std::string& name) 304 | { 305 | detail::s_backend = name; 306 | } 307 | 308 | inline bool annotate(std::string annotation, double x, double y) 309 | { 310 | detail::_interpreter::get(); 311 | 312 | PyObject * xy = PyTuple_New(2); 313 | PyObject * str = PyString_FromString(annotation.c_str()); 314 | 315 | PyTuple_SetItem(xy,0,PyFloat_FromDouble(x)); 316 | PyTuple_SetItem(xy,1,PyFloat_FromDouble(y)); 317 | 318 | PyObject* kwargs = PyDict_New(); 319 | PyDict_SetItemString(kwargs, "xy", xy); 320 | 321 | PyObject* args = PyTuple_New(1); 322 | PyTuple_SetItem(args, 0, str); 323 | 324 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_annotate, args, kwargs); 325 | 326 | Py_DECREF(args); 327 | Py_DECREF(kwargs); 328 | 329 | if(res) Py_DECREF(res); 330 | 331 | return res; 332 | } 333 | 334 | namespace detail { 335 | 336 | #ifndef WITHOUT_NUMPY 337 | // Type selector for numpy array conversion 338 | template struct select_npy_type { const static NPY_TYPES type = NPY_NOTYPE; }; //Default 339 | template <> struct select_npy_type { const static NPY_TYPES type = NPY_DOUBLE; }; 340 | template <> struct select_npy_type { const static NPY_TYPES type = NPY_FLOAT; }; 341 | template <> struct select_npy_type { const static NPY_TYPES type = NPY_BOOL; }; 342 | template <> struct select_npy_type { const static NPY_TYPES type = NPY_INT8; }; 343 | template <> struct select_npy_type { const static NPY_TYPES type = NPY_SHORT; }; 344 | template <> struct select_npy_type { const static NPY_TYPES type = NPY_INT; }; 345 | template <> struct select_npy_type { const static NPY_TYPES type = NPY_INT64; }; 346 | template <> struct select_npy_type { const static NPY_TYPES type = NPY_UINT8; }; 347 | template <> struct select_npy_type { const static NPY_TYPES type = NPY_USHORT; }; 348 | template <> struct select_npy_type { const 
static NPY_TYPES type = NPY_ULONG; }; 349 | template <> struct select_npy_type { const static NPY_TYPES type = NPY_UINT64; }; 350 | 351 | // Sanity checks; comment them out or change the numpy type below if you're compiling on 352 | // a platform where they don't apply 353 | static_assert(sizeof(long long) == 8); 354 | template <> struct select_npy_type { const static NPY_TYPES type = NPY_INT64; }; 355 | static_assert(sizeof(unsigned long long) == 8); 356 | template <> struct select_npy_type { const static NPY_TYPES type = NPY_UINT64; }; 357 | 358 | template 359 | PyObject* get_array(const std::vector& v) 360 | { 361 | npy_intp vsize = v.size(); 362 | NPY_TYPES type = select_npy_type::type; 363 | if (type == NPY_NOTYPE) { 364 | size_t memsize = v.size()*sizeof(double); 365 | double* dp = static_cast(::malloc(memsize)); 366 | for (size_t i=0; i(varray), NPY_ARRAY_OWNDATA); 370 | return varray; 371 | } 372 | 373 | PyObject* varray = PyArray_SimpleNewFromData(1, &vsize, type, (void*)(v.data())); 374 | return varray; 375 | } 376 | 377 | 378 | template 379 | PyObject* get_2darray(const std::vector<::std::vector>& v) 380 | { 381 | if (v.size() < 1) throw std::runtime_error("get_2d_array v too small"); 382 | 383 | npy_intp vsize[2] = {static_cast(v.size()), 384 | static_cast(v[0].size())}; 385 | 386 | PyArrayObject *varray = 387 | (PyArrayObject *)PyArray_SimpleNew(2, vsize, NPY_DOUBLE); 388 | 389 | double *vd_begin = static_cast(PyArray_DATA(varray)); 390 | 391 | for (const ::std::vector &v_row : v) { 392 | if (v_row.size() != static_cast(vsize[1])) 393 | throw std::runtime_error("Missmatched array size"); 394 | std::copy(v_row.begin(), v_row.end(), vd_begin); 395 | vd_begin += vsize[1]; 396 | } 397 | 398 | return reinterpret_cast(varray); 399 | } 400 | 401 | #else // fallback if we don't have numpy: copy every element of the given vector 402 | 403 | template 404 | PyObject* get_array(const std::vector& v) 405 | { 406 | PyObject* list = PyList_New(v.size()); 407 | 
for(size_t i = 0; i < v.size(); ++i) { 408 | PyList_SetItem(list, i, PyFloat_FromDouble(v.at(i))); 409 | } 410 | return list; 411 | } 412 | 413 | #endif // WITHOUT_NUMPY 414 | 415 | // sometimes, for labels and such, we need string arrays 416 | inline PyObject * get_array(const std::vector& strings) 417 | { 418 | PyObject* list = PyList_New(strings.size()); 419 | for (std::size_t i = 0; i < strings.size(); ++i) { 420 | PyList_SetItem(list, i, PyString_FromString(strings[i].c_str())); 421 | } 422 | return list; 423 | } 424 | 425 | // not all matplotlib need 2d arrays, some prefer lists of lists 426 | template 427 | PyObject* get_listlist(const std::vector>& ll) 428 | { 429 | PyObject* listlist = PyList_New(ll.size()); 430 | for (std::size_t i = 0; i < ll.size(); ++i) { 431 | PyList_SetItem(listlist, i, get_array(ll[i])); 432 | } 433 | return listlist; 434 | } 435 | 436 | } // namespace detail 437 | 438 | /// Plot a line through the given x and y data points.. 439 | /// 440 | /// See: https://matplotlib.org/3.2.1/api/_as_gen/matplotlib.pyplot.plot.html 441 | template 442 | bool plot(const std::vector &x, const std::vector &y, const std::map& keywords) 443 | { 444 | assert(x.size() == y.size()); 445 | 446 | detail::_interpreter::get(); 447 | 448 | // using numpy arrays 449 | PyObject* xarray = detail::get_array(x); 450 | PyObject* yarray = detail::get_array(y); 451 | 452 | // construct positional args 453 | PyObject* args = PyTuple_New(2); 454 | PyTuple_SetItem(args, 0, xarray); 455 | PyTuple_SetItem(args, 1, yarray); 456 | 457 | // construct keyword args 458 | PyObject* kwargs = PyDict_New(); 459 | for(std::map::const_iterator it = keywords.begin(); it != keywords.end(); ++it) 460 | { 461 | PyDict_SetItemString(kwargs, it->first.c_str(), PyString_FromString(it->second.c_str())); 462 | } 463 | 464 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_plot, args, kwargs); 465 | 466 | Py_DECREF(args); 467 | Py_DECREF(kwargs); 468 | if(res) 
Py_DECREF(res); 469 | 470 | return res; 471 | } 472 | 473 | // TODO - it should be possible to make this work by implementing 474 | // a non-numpy alternative for `detail::get_2darray()`. 475 | #ifndef WITHOUT_NUMPY 476 | template 477 | void plot_surface(const std::vector<::std::vector> &x, 478 | const std::vector<::std::vector> &y, 479 | const std::vector<::std::vector> &z, 480 | const std::map &keywords = 481 | std::map(), 482 | const long fig_number=0) 483 | { 484 | detail::_interpreter::get(); 485 | 486 | // We lazily load the modules here the first time this function is called 487 | // because I'm not sure that we can assume "matplotlib installed" implies 488 | // "mpl_toolkits installed" on all platforms, and we don't want to require 489 | // it for people who don't need 3d plots. 490 | static PyObject *mpl_toolkitsmod = nullptr, *axis3dmod = nullptr; 491 | if (!mpl_toolkitsmod) { 492 | detail::_interpreter::get(); 493 | 494 | PyObject* mpl_toolkits = PyString_FromString("mpl_toolkits"); 495 | PyObject* axis3d = PyString_FromString("mpl_toolkits.mplot3d"); 496 | if (!mpl_toolkits || !axis3d) { throw std::runtime_error("couldnt create string"); } 497 | 498 | mpl_toolkitsmod = PyImport_Import(mpl_toolkits); 499 | Py_DECREF(mpl_toolkits); 500 | if (!mpl_toolkitsmod) { throw std::runtime_error("Error loading module mpl_toolkits!"); } 501 | 502 | axis3dmod = PyImport_Import(axis3d); 503 | Py_DECREF(axis3d); 504 | if (!axis3dmod) { throw std::runtime_error("Error loading module mpl_toolkits.mplot3d!"); } 505 | } 506 | 507 | assert(x.size() == y.size()); 508 | assert(y.size() == z.size()); 509 | 510 | // using numpy arrays 511 | PyObject *xarray = detail::get_2darray(x); 512 | PyObject *yarray = detail::get_2darray(y); 513 | PyObject *zarray = detail::get_2darray(z); 514 | 515 | // construct positional args 516 | PyObject *args = PyTuple_New(3); 517 | PyTuple_SetItem(args, 0, xarray); 518 | PyTuple_SetItem(args, 1, yarray); 519 | PyTuple_SetItem(args, 2, zarray); 
520 | 521 | // Build up the kw args. 522 | PyObject *kwargs = PyDict_New(); 523 | PyDict_SetItemString(kwargs, "rstride", PyInt_FromLong(1)); 524 | PyDict_SetItemString(kwargs, "cstride", PyInt_FromLong(1)); 525 | 526 | PyObject *python_colormap_coolwarm = PyObject_GetAttrString( 527 | detail::_interpreter::get().s_python_colormap, "coolwarm"); 528 | 529 | PyDict_SetItemString(kwargs, "cmap", python_colormap_coolwarm); 530 | 531 | for (std::map::const_iterator it = keywords.begin(); 532 | it != keywords.end(); ++it) { 533 | if (it->first == "linewidth" || it->first == "alpha") { 534 | PyDict_SetItemString(kwargs, it->first.c_str(), 535 | PyFloat_FromDouble(std::stod(it->second))); 536 | } else { 537 | PyDict_SetItemString(kwargs, it->first.c_str(), 538 | PyString_FromString(it->second.c_str())); 539 | } 540 | } 541 | 542 | PyObject *fig_args = PyTuple_New(1); 543 | PyObject* fig = nullptr; 544 | PyTuple_SetItem(fig_args, 0, PyLong_FromLong(fig_number)); 545 | PyObject *fig_exists = 546 | PyObject_CallObject( 547 | detail::_interpreter::get().s_python_function_fignum_exists, fig_args); 548 | if (!PyObject_IsTrue(fig_exists)) { 549 | fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure, 550 | detail::_interpreter::get().s_python_empty_tuple); 551 | } else { 552 | fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure, 553 | fig_args); 554 | } 555 | Py_DECREF(fig_exists); 556 | if (!fig) throw std::runtime_error("Call to figure() failed."); 557 | 558 | PyObject *gca_kwargs = PyDict_New(); 559 | PyDict_SetItemString(gca_kwargs, "projection", PyString_FromString("3d")); 560 | 561 | PyObject *gca = PyObject_GetAttrString(fig, "gca"); 562 | if (!gca) throw std::runtime_error("No gca"); 563 | Py_INCREF(gca); 564 | PyObject *axis = PyObject_Call( 565 | gca, detail::_interpreter::get().s_python_empty_tuple, gca_kwargs); 566 | 567 | if (!axis) throw std::runtime_error("No axis"); 568 | Py_INCREF(axis); 569 | 570 | 
Py_DECREF(gca); 571 | Py_DECREF(gca_kwargs); 572 | 573 | PyObject *plot_surface = PyObject_GetAttrString(axis, "plot_surface"); 574 | if (!plot_surface) throw std::runtime_error("No surface"); 575 | Py_INCREF(plot_surface); 576 | PyObject *res = PyObject_Call(plot_surface, args, kwargs); 577 | if (!res) throw std::runtime_error("failed surface"); 578 | Py_DECREF(plot_surface); 579 | 580 | Py_DECREF(axis); 581 | Py_DECREF(args); 582 | Py_DECREF(kwargs); 583 | if (res) Py_DECREF(res); 584 | } 585 | 586 | template 587 | void contour(const std::vector<::std::vector> &x, 588 | const std::vector<::std::vector> &y, 589 | const std::vector<::std::vector> &z, 590 | const std::map &keywords = {}) 591 | { 592 | detail::_interpreter::get(); 593 | 594 | // using numpy arrays 595 | PyObject *xarray = detail::get_2darray(x); 596 | PyObject *yarray = detail::get_2darray(y); 597 | PyObject *zarray = detail::get_2darray(z); 598 | 599 | // construct positional args 600 | PyObject *args = PyTuple_New(3); 601 | PyTuple_SetItem(args, 0, xarray); 602 | PyTuple_SetItem(args, 1, yarray); 603 | PyTuple_SetItem(args, 2, zarray); 604 | 605 | // Build up the kw args. 
606 | PyObject *kwargs = PyDict_New(); 607 | 608 | PyObject *python_colormap_coolwarm = PyObject_GetAttrString( 609 | detail::_interpreter::get().s_python_colormap, "coolwarm"); 610 | 611 | PyDict_SetItemString(kwargs, "cmap", python_colormap_coolwarm); 612 | 613 | for (std::map::const_iterator it = keywords.begin(); 614 | it != keywords.end(); ++it) { 615 | PyDict_SetItemString(kwargs, it->first.c_str(), 616 | PyString_FromString(it->second.c_str())); 617 | } 618 | 619 | PyObject *res = PyObject_Call(detail::_interpreter::get().s_python_function_contour, args, kwargs); 620 | if (!res) 621 | throw std::runtime_error("failed contour"); 622 | 623 | Py_DECREF(args); 624 | Py_DECREF(kwargs); 625 | if (res) Py_DECREF(res); 626 | } 627 | 628 | template 629 | void spy(const std::vector<::std::vector> &x, 630 | const double markersize = -1, // -1 for default matplotlib size 631 | const std::map &keywords = {}) 632 | { 633 | detail::_interpreter::get(); 634 | 635 | PyObject *xarray = detail::get_2darray(x); 636 | 637 | PyObject *kwargs = PyDict_New(); 638 | if (markersize != -1) { 639 | PyDict_SetItemString(kwargs, "markersize", PyFloat_FromDouble(markersize)); 640 | } 641 | for (std::map::const_iterator it = keywords.begin(); 642 | it != keywords.end(); ++it) { 643 | PyDict_SetItemString(kwargs, it->first.c_str(), 644 | PyString_FromString(it->second.c_str())); 645 | } 646 | 647 | PyObject *plot_args = PyTuple_New(1); 648 | PyTuple_SetItem(plot_args, 0, xarray); 649 | 650 | PyObject *res = PyObject_Call( 651 | detail::_interpreter::get().s_python_function_spy, plot_args, kwargs); 652 | 653 | Py_DECREF(plot_args); 654 | Py_DECREF(kwargs); 655 | if (res) Py_DECREF(res); 656 | } 657 | #endif // WITHOUT_NUMPY 658 | 659 | template 660 | void plot3(const std::vector &x, 661 | const std::vector &y, 662 | const std::vector &z, 663 | const std::map &keywords = 664 | std::map(), 665 | const long fig_number=0) 666 | { 667 | detail::_interpreter::get(); 668 | 669 | // Same as with 
plot_surface: We lazily load the modules here the first time 670 | // this function is called because I'm not sure that we can assume "matplotlib 671 | // installed" implies "mpl_toolkits installed" on all platforms, and we don't 672 | // want to require it for people who don't need 3d plots. 673 | static PyObject *mpl_toolkitsmod = nullptr, *axis3dmod = nullptr; 674 | if (!mpl_toolkitsmod) { 675 | detail::_interpreter::get(); 676 | 677 | PyObject* mpl_toolkits = PyString_FromString("mpl_toolkits"); 678 | PyObject* axis3d = PyString_FromString("mpl_toolkits.mplot3d"); 679 | if (!mpl_toolkits || !axis3d) { throw std::runtime_error("couldnt create string"); } 680 | 681 | mpl_toolkitsmod = PyImport_Import(mpl_toolkits); 682 | Py_DECREF(mpl_toolkits); 683 | if (!mpl_toolkitsmod) { throw std::runtime_error("Error loading module mpl_toolkits!"); } 684 | 685 | axis3dmod = PyImport_Import(axis3d); 686 | Py_DECREF(axis3d); 687 | if (!axis3dmod) { throw std::runtime_error("Error loading module mpl_toolkits.mplot3d!"); } 688 | } 689 | 690 | assert(x.size() == y.size()); 691 | assert(y.size() == z.size()); 692 | 693 | PyObject *xarray = detail::get_array(x); 694 | PyObject *yarray = detail::get_array(y); 695 | PyObject *zarray = detail::get_array(z); 696 | 697 | // construct positional args 698 | PyObject *args = PyTuple_New(3); 699 | PyTuple_SetItem(args, 0, xarray); 700 | PyTuple_SetItem(args, 1, yarray); 701 | PyTuple_SetItem(args, 2, zarray); 702 | 703 | // Build up the kw args. 
704 | PyObject *kwargs = PyDict_New(); 705 | 706 | for (std::map::const_iterator it = keywords.begin(); 707 | it != keywords.end(); ++it) { 708 | PyDict_SetItemString(kwargs, it->first.c_str(), 709 | PyString_FromString(it->second.c_str())); 710 | } 711 | 712 | PyObject *fig_args = PyTuple_New(1); 713 | PyObject* fig = nullptr; 714 | PyTuple_SetItem(fig_args, 0, PyLong_FromLong(fig_number)); 715 | PyObject *fig_exists = 716 | PyObject_CallObject(detail::_interpreter::get().s_python_function_fignum_exists, fig_args); 717 | if (!PyObject_IsTrue(fig_exists)) { 718 | fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure, 719 | detail::_interpreter::get().s_python_empty_tuple); 720 | } else { 721 | fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure, 722 | fig_args); 723 | } 724 | if (!fig) throw std::runtime_error("Call to figure() failed."); 725 | 726 | PyObject *gca_kwargs = PyDict_New(); 727 | PyDict_SetItemString(gca_kwargs, "projection", PyString_FromString("3d")); 728 | 729 | PyObject *gca = PyObject_GetAttrString(fig, "gca"); 730 | if (!gca) throw std::runtime_error("No gca"); 731 | Py_INCREF(gca); 732 | PyObject *axis = PyObject_Call( 733 | gca, detail::_interpreter::get().s_python_empty_tuple, gca_kwargs); 734 | 735 | if (!axis) throw std::runtime_error("No axis"); 736 | Py_INCREF(axis); 737 | 738 | Py_DECREF(gca); 739 | Py_DECREF(gca_kwargs); 740 | 741 | PyObject *plot3 = PyObject_GetAttrString(axis, "plot"); 742 | if (!plot3) throw std::runtime_error("No 3D line plot"); 743 | Py_INCREF(plot3); 744 | PyObject *res = PyObject_Call(plot3, args, kwargs); 745 | if (!res) throw std::runtime_error("Failed 3D line plot"); 746 | Py_DECREF(plot3); 747 | 748 | Py_DECREF(axis); 749 | Py_DECREF(args); 750 | Py_DECREF(kwargs); 751 | if (res) Py_DECREF(res); 752 | } 753 | 754 | template 755 | bool stem(const std::vector &x, const std::vector &y, const std::map& keywords) 756 | { 757 | assert(x.size() == y.size()); 758 | 
759 | detail::_interpreter::get(); 760 | 761 | // using numpy arrays 762 | PyObject* xarray = detail::get_array(x); 763 | PyObject* yarray = detail::get_array(y); 764 | 765 | // construct positional args 766 | PyObject* args = PyTuple_New(2); 767 | PyTuple_SetItem(args, 0, xarray); 768 | PyTuple_SetItem(args, 1, yarray); 769 | 770 | // construct keyword args 771 | PyObject* kwargs = PyDict_New(); 772 | for (std::map::const_iterator it = 773 | keywords.begin(); it != keywords.end(); ++it) { 774 | PyDict_SetItemString(kwargs, it->first.c_str(), 775 | PyString_FromString(it->second.c_str())); 776 | } 777 | 778 | PyObject* res = PyObject_Call( 779 | detail::_interpreter::get().s_python_function_stem, args, kwargs); 780 | 781 | Py_DECREF(args); 782 | Py_DECREF(kwargs); 783 | if (res) 784 | Py_DECREF(res); 785 | 786 | return res; 787 | } 788 | 789 | template< typename Numeric > 790 | bool fill(const std::vector& x, const std::vector& y, const std::map& keywords) 791 | { 792 | assert(x.size() == y.size()); 793 | 794 | detail::_interpreter::get(); 795 | 796 | // using numpy arrays 797 | PyObject* xarray = detail::get_array(x); 798 | PyObject* yarray = detail::get_array(y); 799 | 800 | // construct positional args 801 | PyObject* args = PyTuple_New(2); 802 | PyTuple_SetItem(args, 0, xarray); 803 | PyTuple_SetItem(args, 1, yarray); 804 | 805 | // construct keyword args 806 | PyObject* kwargs = PyDict_New(); 807 | for (auto it = keywords.begin(); it != keywords.end(); ++it) { 808 | PyDict_SetItemString(kwargs, it->first.c_str(), PyUnicode_FromString(it->second.c_str())); 809 | } 810 | 811 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_fill, args, kwargs); 812 | 813 | Py_DECREF(args); 814 | Py_DECREF(kwargs); 815 | 816 | if (res) Py_DECREF(res); 817 | 818 | return res; 819 | } 820 | 821 | template< typename Numeric > 822 | bool fill_between(const std::vector& x, const std::vector& y1, const std::vector& y2, const std::map& keywords) 823 | { 824 | 
assert(x.size() == y1.size()); 825 | assert(x.size() == y2.size()); 826 | 827 | detail::_interpreter::get(); 828 | 829 | // using numpy arrays 830 | PyObject* xarray = detail::get_array(x); 831 | PyObject* y1array = detail::get_array(y1); 832 | PyObject* y2array = detail::get_array(y2); 833 | 834 | // construct positional args 835 | PyObject* args = PyTuple_New(3); 836 | PyTuple_SetItem(args, 0, xarray); 837 | PyTuple_SetItem(args, 1, y1array); 838 | PyTuple_SetItem(args, 2, y2array); 839 | 840 | // construct keyword args 841 | PyObject* kwargs = PyDict_New(); 842 | for(std::map::const_iterator it = keywords.begin(); it != keywords.end(); ++it) { 843 | PyDict_SetItemString(kwargs, it->first.c_str(), PyUnicode_FromString(it->second.c_str())); 844 | } 845 | 846 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_fill_between, args, kwargs); 847 | 848 | Py_DECREF(args); 849 | Py_DECREF(kwargs); 850 | if(res) Py_DECREF(res); 851 | 852 | return res; 853 | } 854 | 855 | template 856 | bool arrow(Numeric x, Numeric y, Numeric end_x, Numeric end_y, const std::string& fc = "r", 857 | const std::string ec = "k", Numeric head_length = 0.25, Numeric head_width = 0.1625) { 858 | PyObject* obj_x = PyFloat_FromDouble(x); 859 | PyObject* obj_y = PyFloat_FromDouble(y); 860 | PyObject* obj_end_x = PyFloat_FromDouble(end_x); 861 | PyObject* obj_end_y = PyFloat_FromDouble(end_y); 862 | 863 | PyObject* kwargs = PyDict_New(); 864 | PyDict_SetItemString(kwargs, "fc", PyString_FromString(fc.c_str())); 865 | PyDict_SetItemString(kwargs, "ec", PyString_FromString(ec.c_str())); 866 | PyDict_SetItemString(kwargs, "head_width", PyFloat_FromDouble(head_width)); 867 | PyDict_SetItemString(kwargs, "head_length", PyFloat_FromDouble(head_length)); 868 | 869 | PyObject* plot_args = PyTuple_New(4); 870 | PyTuple_SetItem(plot_args, 0, obj_x); 871 | PyTuple_SetItem(plot_args, 1, obj_y); 872 | PyTuple_SetItem(plot_args, 2, obj_end_x); 873 | PyTuple_SetItem(plot_args, 3, 
obj_end_y); 874 | 875 | PyObject* res = 876 | PyObject_Call(detail::_interpreter::get().s_python_function_arrow, plot_args, kwargs); 877 | 878 | Py_DECREF(plot_args); 879 | Py_DECREF(kwargs); 880 | if (res) 881 | Py_DECREF(res); 882 | 883 | return res; 884 | } 885 | 886 | template< typename Numeric> 887 | bool hist(const std::vector& y, long bins=10,std::string color="b", 888 | double alpha=1.0, bool cumulative=false) 889 | { 890 | detail::_interpreter::get(); 891 | 892 | PyObject* yarray = detail::get_array(y); 893 | 894 | PyObject* kwargs = PyDict_New(); 895 | PyDict_SetItemString(kwargs, "bins", PyLong_FromLong(bins)); 896 | PyDict_SetItemString(kwargs, "color", PyString_FromString(color.c_str())); 897 | PyDict_SetItemString(kwargs, "alpha", PyFloat_FromDouble(alpha)); 898 | PyDict_SetItemString(kwargs, "cumulative", cumulative ? Py_True : Py_False); 899 | 900 | PyObject* plot_args = PyTuple_New(1); 901 | 902 | PyTuple_SetItem(plot_args, 0, yarray); 903 | 904 | 905 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_hist, plot_args, kwargs); 906 | 907 | 908 | Py_DECREF(plot_args); 909 | Py_DECREF(kwargs); 910 | if(res) Py_DECREF(res); 911 | 912 | return res; 913 | } 914 | 915 | #ifndef WITHOUT_NUMPY 916 | namespace detail { 917 | 918 | inline void imshow(void *ptr, const NPY_TYPES type, const int rows, const int columns, const int colors, const std::map &keywords, PyObject** out) 919 | { 920 | assert(type == NPY_UINT8 || type == NPY_FLOAT); 921 | assert(colors == 1 || colors == 3 || colors == 4); 922 | 923 | detail::_interpreter::get(); 924 | 925 | // construct args 926 | npy_intp dims[3] = { rows, columns, colors }; 927 | PyObject *args = PyTuple_New(1); 928 | PyTuple_SetItem(args, 0, PyArray_SimpleNewFromData(colors == 1 ? 
2 : 3, dims, type, ptr)); 929 | 930 | // construct keyword args 931 | PyObject* kwargs = PyDict_New(); 932 | for(std::map::const_iterator it = keywords.begin(); it != keywords.end(); ++it) 933 | { 934 | PyDict_SetItemString(kwargs, it->first.c_str(), PyUnicode_FromString(it->second.c_str())); 935 | } 936 | 937 | PyObject *res = PyObject_Call(detail::_interpreter::get().s_python_function_imshow, args, kwargs); 938 | Py_DECREF(args); 939 | Py_DECREF(kwargs); 940 | if (!res) 941 | throw std::runtime_error("Call to imshow() failed"); 942 | if (out) 943 | *out = res; 944 | else 945 | Py_DECREF(res); 946 | } 947 | 948 | } // namespace detail 949 | 950 | inline void imshow(const unsigned char *ptr, const int rows, const int columns, const int colors, const std::map &keywords = {}, PyObject** out = nullptr) 951 | { 952 | detail::imshow((void *) ptr, NPY_UINT8, rows, columns, colors, keywords, out); 953 | } 954 | 955 | inline void imshow(const float *ptr, const int rows, const int columns, const int colors, const std::map &keywords = {}, PyObject** out = nullptr) 956 | { 957 | detail::imshow((void *) ptr, NPY_FLOAT, rows, columns, colors, keywords, out); 958 | } 959 | 960 | #ifdef WITH_OPENCV 961 | void imshow(const cv::Mat &image, const std::map &keywords = {}) 962 | { 963 | // Convert underlying type of matrix, if needed 964 | cv::Mat image2; 965 | NPY_TYPES npy_type = NPY_UINT8; 966 | switch (image.type() & CV_MAT_DEPTH_MASK) { 967 | case CV_8U: 968 | image2 = image; 969 | break; 970 | case CV_32F: 971 | image2 = image; 972 | npy_type = NPY_FLOAT; 973 | break; 974 | default: 975 | image.convertTo(image2, CV_MAKETYPE(CV_8U, image.channels())); 976 | } 977 | 978 | // If color image, convert from BGR to RGB 979 | switch (image2.channels()) { 980 | case 3: 981 | cv::cvtColor(image2, image2, CV_BGR2RGB); 982 | break; 983 | case 4: 984 | cv::cvtColor(image2, image2, CV_BGRA2RGBA); 985 | } 986 | 987 | detail::imshow(image2.data, npy_type, image2.rows, image2.cols, 
image2.channels(), keywords); 988 | } 989 | #endif // WITH_OPENCV 990 | #endif // WITHOUT_NUMPY 991 | 992 | template 993 | bool scatter(const std::vector& x, 994 | const std::vector& y, 995 | const double s=1.0, // The marker size in points**2 996 | const std::map & keywords = {}) 997 | { 998 | detail::_interpreter::get(); 999 | 1000 | assert(x.size() == y.size()); 1001 | 1002 | PyObject* xarray = detail::get_array(x); 1003 | PyObject* yarray = detail::get_array(y); 1004 | 1005 | PyObject* kwargs = PyDict_New(); 1006 | PyDict_SetItemString(kwargs, "s", PyLong_FromLong(s)); 1007 | for (const auto& it : keywords) 1008 | { 1009 | PyDict_SetItemString(kwargs, it.first.c_str(), PyString_FromString(it.second.c_str())); 1010 | } 1011 | 1012 | PyObject* plot_args = PyTuple_New(2); 1013 | PyTuple_SetItem(plot_args, 0, xarray); 1014 | PyTuple_SetItem(plot_args, 1, yarray); 1015 | 1016 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_scatter, plot_args, kwargs); 1017 | 1018 | Py_DECREF(plot_args); 1019 | Py_DECREF(kwargs); 1020 | if(res) Py_DECREF(res); 1021 | 1022 | return res; 1023 | } 1024 | 1025 | template 1026 | bool scatter_colored(const std::vector& x, 1027 | const std::vector& y, 1028 | const std::vector& colors, 1029 | const double s=1.0, // The marker size in points**2 1030 | const std::map & keywords = {}) 1031 | { 1032 | detail::_interpreter::get(); 1033 | 1034 | assert(x.size() == y.size()); 1035 | 1036 | PyObject* xarray = detail::get_array(x); 1037 | PyObject* yarray = detail::get_array(y); 1038 | PyObject* colors_array = detail::get_array(colors); 1039 | 1040 | PyObject* kwargs = PyDict_New(); 1041 | PyDict_SetItemString(kwargs, "s", PyLong_FromLong(s)); 1042 | PyDict_SetItemString(kwargs, "c", colors_array); 1043 | 1044 | for (const auto& it : keywords) 1045 | { 1046 | PyDict_SetItemString(kwargs, it.first.c_str(), PyString_FromString(it.second.c_str())); 1047 | } 1048 | 1049 | PyObject* plot_args = PyTuple_New(2); 1050 | 
PyTuple_SetItem(plot_args, 0, xarray); 1051 | PyTuple_SetItem(plot_args, 1, yarray); 1052 | 1053 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_scatter, plot_args, kwargs); 1054 | 1055 | Py_DECREF(plot_args); 1056 | Py_DECREF(kwargs); 1057 | if(res) Py_DECREF(res); 1058 | 1059 | return res; 1060 | } 1061 | 1062 | 1063 | template 1064 | bool scatter(const std::vector& x, 1065 | const std::vector& y, 1066 | const std::vector& z, 1067 | const double s=1.0, // The marker size in points**2 1068 | const std::map & keywords = {}, 1069 | const long fig_number=0) { 1070 | detail::_interpreter::get(); 1071 | 1072 | // Same as with plot_surface: We lazily load the modules here the first time 1073 | // this function is called because I'm not sure that we can assume "matplotlib 1074 | // installed" implies "mpl_toolkits installed" on all platforms, and we don't 1075 | // want to require it for people who don't need 3d plots. 1076 | static PyObject *mpl_toolkitsmod = nullptr, *axis3dmod = nullptr; 1077 | if (!mpl_toolkitsmod) { 1078 | detail::_interpreter::get(); 1079 | 1080 | PyObject* mpl_toolkits = PyString_FromString("mpl_toolkits"); 1081 | PyObject* axis3d = PyString_FromString("mpl_toolkits.mplot3d"); 1082 | if (!mpl_toolkits || !axis3d) { throw std::runtime_error("couldnt create string"); } 1083 | 1084 | mpl_toolkitsmod = PyImport_Import(mpl_toolkits); 1085 | Py_DECREF(mpl_toolkits); 1086 | if (!mpl_toolkitsmod) { throw std::runtime_error("Error loading module mpl_toolkits!"); } 1087 | 1088 | axis3dmod = PyImport_Import(axis3d); 1089 | Py_DECREF(axis3d); 1090 | if (!axis3dmod) { throw std::runtime_error("Error loading module mpl_toolkits.mplot3d!"); } 1091 | } 1092 | 1093 | assert(x.size() == y.size()); 1094 | assert(y.size() == z.size()); 1095 | 1096 | PyObject *xarray = detail::get_array(x); 1097 | PyObject *yarray = detail::get_array(y); 1098 | PyObject *zarray = detail::get_array(z); 1099 | 1100 | // construct positional args 1101 | 
PyObject *args = PyTuple_New(3); 1102 | PyTuple_SetItem(args, 0, xarray); 1103 | PyTuple_SetItem(args, 1, yarray); 1104 | PyTuple_SetItem(args, 2, zarray); 1105 | 1106 | // Build up the kw args. 1107 | PyObject *kwargs = PyDict_New(); 1108 | 1109 | for (std::map::const_iterator it = keywords.begin(); 1110 | it != keywords.end(); ++it) { 1111 | PyDict_SetItemString(kwargs, it->first.c_str(), 1112 | PyString_FromString(it->second.c_str())); 1113 | } 1114 | PyObject *fig_args = PyTuple_New(1); 1115 | PyObject* fig = nullptr; 1116 | PyTuple_SetItem(fig_args, 0, PyLong_FromLong(fig_number)); 1117 | PyObject *fig_exists = 1118 | PyObject_CallObject(detail::_interpreter::get().s_python_function_fignum_exists, fig_args); 1119 | if (!PyObject_IsTrue(fig_exists)) { 1120 | fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure, 1121 | detail::_interpreter::get().s_python_empty_tuple); 1122 | } else { 1123 | fig = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure, 1124 | fig_args); 1125 | } 1126 | Py_DECREF(fig_exists); 1127 | if (!fig) throw std::runtime_error("Call to figure() failed."); 1128 | 1129 | PyObject *gca_kwargs = PyDict_New(); 1130 | PyDict_SetItemString(gca_kwargs, "projection", PyString_FromString("3d")); 1131 | 1132 | PyObject *gca = PyObject_GetAttrString(fig, "gca"); 1133 | if (!gca) throw std::runtime_error("No gca"); 1134 | Py_INCREF(gca); 1135 | PyObject *axis = PyObject_Call( 1136 | gca, detail::_interpreter::get().s_python_empty_tuple, gca_kwargs); 1137 | 1138 | if (!axis) throw std::runtime_error("No axis"); 1139 | Py_INCREF(axis); 1140 | 1141 | Py_DECREF(gca); 1142 | Py_DECREF(gca_kwargs); 1143 | 1144 | PyObject *plot3 = PyObject_GetAttrString(axis, "scatter"); 1145 | if (!plot3) throw std::runtime_error("No 3D line plot"); 1146 | Py_INCREF(plot3); 1147 | PyObject *res = PyObject_Call(plot3, args, kwargs); 1148 | if (!res) throw std::runtime_error("Failed 3D line plot"); 1149 | Py_DECREF(plot3); 1150 | 
1151 | Py_DECREF(axis); 1152 | Py_DECREF(args); 1153 | Py_DECREF(kwargs); 1154 | Py_DECREF(fig); 1155 | if (res) Py_DECREF(res); 1156 | return res; 1157 | 1158 | } 1159 | 1160 | template 1161 | bool boxplot(const std::vector>& data, 1162 | const std::vector& labels = {}, 1163 | const std::map & keywords = {}) 1164 | { 1165 | detail::_interpreter::get(); 1166 | 1167 | PyObject* listlist = detail::get_listlist(data); 1168 | PyObject* args = PyTuple_New(1); 1169 | PyTuple_SetItem(args, 0, listlist); 1170 | 1171 | PyObject* kwargs = PyDict_New(); 1172 | 1173 | // kwargs needs the labels, if there are (the correct number of) labels 1174 | if (!labels.empty() && labels.size() == data.size()) { 1175 | PyDict_SetItemString(kwargs, "labels", detail::get_array(labels)); 1176 | } 1177 | 1178 | // take care of the remaining keywords 1179 | for (const auto& it : keywords) 1180 | { 1181 | PyDict_SetItemString(kwargs, it.first.c_str(), PyString_FromString(it.second.c_str())); 1182 | } 1183 | 1184 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_boxplot, args, kwargs); 1185 | 1186 | Py_DECREF(args); 1187 | Py_DECREF(kwargs); 1188 | 1189 | if(res) Py_DECREF(res); 1190 | 1191 | return res; 1192 | } 1193 | 1194 | template 1195 | bool boxplot(const std::vector& data, 1196 | const std::map & keywords = {}) 1197 | { 1198 | detail::_interpreter::get(); 1199 | 1200 | PyObject* vector = detail::get_array(data); 1201 | PyObject* args = PyTuple_New(1); 1202 | PyTuple_SetItem(args, 0, vector); 1203 | 1204 | PyObject* kwargs = PyDict_New(); 1205 | for (const auto& it : keywords) 1206 | { 1207 | PyDict_SetItemString(kwargs, it.first.c_str(), PyString_FromString(it.second.c_str())); 1208 | } 1209 | 1210 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_boxplot, args, kwargs); 1211 | 1212 | Py_DECREF(args); 1213 | Py_DECREF(kwargs); 1214 | 1215 | if(res) Py_DECREF(res); 1216 | 1217 | return res; 1218 | } 1219 | 1220 | template 1221 | bool 
bar(const std::vector & x, 1222 | const std::vector & y, 1223 | std::string ec = "black", 1224 | std::string ls = "-", 1225 | double lw = 1.0, 1226 | const std::map & keywords = {}) 1227 | { 1228 | detail::_interpreter::get(); 1229 | 1230 | PyObject * xarray = detail::get_array(x); 1231 | PyObject * yarray = detail::get_array(y); 1232 | 1233 | PyObject * kwargs = PyDict_New(); 1234 | 1235 | PyDict_SetItemString(kwargs, "ec", PyString_FromString(ec.c_str())); 1236 | PyDict_SetItemString(kwargs, "ls", PyString_FromString(ls.c_str())); 1237 | PyDict_SetItemString(kwargs, "lw", PyFloat_FromDouble(lw)); 1238 | 1239 | for (std::map::const_iterator it = 1240 | keywords.begin(); 1241 | it != keywords.end(); 1242 | ++it) { 1243 | PyDict_SetItemString( 1244 | kwargs, it->first.c_str(), PyUnicode_FromString(it->second.c_str())); 1245 | } 1246 | 1247 | PyObject * plot_args = PyTuple_New(2); 1248 | PyTuple_SetItem(plot_args, 0, xarray); 1249 | PyTuple_SetItem(plot_args, 1, yarray); 1250 | 1251 | PyObject * res = PyObject_Call( 1252 | detail::_interpreter::get().s_python_function_bar, plot_args, kwargs); 1253 | 1254 | Py_DECREF(plot_args); 1255 | Py_DECREF(kwargs); 1256 | if (res) Py_DECREF(res); 1257 | 1258 | return res; 1259 | } 1260 | 1261 | template 1262 | bool bar(const std::vector & y, 1263 | std::string ec = "black", 1264 | std::string ls = "-", 1265 | double lw = 1.0, 1266 | const std::map & keywords = {}) 1267 | { 1268 | using T = typename std::remove_reference::type::value_type; 1269 | 1270 | detail::_interpreter::get(); 1271 | 1272 | std::vector x; 1273 | for (std::size_t i = 0; i < y.size(); i++) { x.push_back(i); } 1274 | 1275 | return bar(x, y, ec, ls, lw, keywords); 1276 | } 1277 | 1278 | 1279 | template 1280 | bool barh(const std::vector &x, const std::vector &y, std::string ec = "black", std::string ls = "-", double lw = 1.0, const std::map &keywords = { }) { 1281 | PyObject *xarray = detail::get_array(x); 1282 | PyObject *yarray = detail::get_array(y); 1283 | 
1284 | PyObject *kwargs = PyDict_New(); 1285 | 1286 | PyDict_SetItemString(kwargs, "ec", PyString_FromString(ec.c_str())); 1287 | PyDict_SetItemString(kwargs, "ls", PyString_FromString(ls.c_str())); 1288 | PyDict_SetItemString(kwargs, "lw", PyFloat_FromDouble(lw)); 1289 | 1290 | for (std::map::const_iterator it = keywords.begin(); it != keywords.end(); ++it) { 1291 | PyDict_SetItemString(kwargs, it->first.c_str(), PyUnicode_FromString(it->second.c_str())); 1292 | } 1293 | 1294 | PyObject *plot_args = PyTuple_New(2); 1295 | PyTuple_SetItem(plot_args, 0, xarray); 1296 | PyTuple_SetItem(plot_args, 1, yarray); 1297 | 1298 | PyObject *res = PyObject_Call(detail::_interpreter::get().s_python_function_barh, plot_args, kwargs); 1299 | 1300 | Py_DECREF(plot_args); 1301 | Py_DECREF(kwargs); 1302 | if (res) Py_DECREF(res); 1303 | 1304 | return res; 1305 | } 1306 | 1307 | 1308 | inline bool subplots_adjust(const std::map& keywords = {}) 1309 | { 1310 | detail::_interpreter::get(); 1311 | 1312 | PyObject* kwargs = PyDict_New(); 1313 | for (std::map::const_iterator it = 1314 | keywords.begin(); it != keywords.end(); ++it) { 1315 | PyDict_SetItemString(kwargs, it->first.c_str(), 1316 | PyFloat_FromDouble(it->second)); 1317 | } 1318 | 1319 | 1320 | PyObject* plot_args = PyTuple_New(0); 1321 | 1322 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_subplots_adjust, plot_args, kwargs); 1323 | 1324 | Py_DECREF(plot_args); 1325 | Py_DECREF(kwargs); 1326 | if(res) Py_DECREF(res); 1327 | 1328 | return res; 1329 | } 1330 | 1331 | template< typename Numeric> 1332 | bool named_hist(std::string label,const std::vector& y, long bins=10, std::string color="b", double alpha=1.0) 1333 | { 1334 | detail::_interpreter::get(); 1335 | 1336 | PyObject* yarray = detail::get_array(y); 1337 | 1338 | PyObject* kwargs = PyDict_New(); 1339 | PyDict_SetItemString(kwargs, "label", PyString_FromString(label.c_str())); 1340 | PyDict_SetItemString(kwargs, "bins", 
PyLong_FromLong(bins)); 1341 | PyDict_SetItemString(kwargs, "color", PyString_FromString(color.c_str())); 1342 | PyDict_SetItemString(kwargs, "alpha", PyFloat_FromDouble(alpha)); 1343 | 1344 | 1345 | PyObject* plot_args = PyTuple_New(1); 1346 | PyTuple_SetItem(plot_args, 0, yarray); 1347 | 1348 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_hist, plot_args, kwargs); 1349 | 1350 | Py_DECREF(plot_args); 1351 | Py_DECREF(kwargs); 1352 | if(res) Py_DECREF(res); 1353 | 1354 | return res; 1355 | } 1356 | 1357 | template 1358 | bool plot(const std::vector& x, const std::vector& y, const std::string& s = "") 1359 | { 1360 | assert(x.size() == y.size()); 1361 | 1362 | detail::_interpreter::get(); 1363 | 1364 | PyObject* xarray = detail::get_array(x); 1365 | PyObject* yarray = detail::get_array(y); 1366 | 1367 | PyObject* pystring = PyString_FromString(s.c_str()); 1368 | 1369 | PyObject* plot_args = PyTuple_New(3); 1370 | PyTuple_SetItem(plot_args, 0, xarray); 1371 | PyTuple_SetItem(plot_args, 1, yarray); 1372 | PyTuple_SetItem(plot_args, 2, pystring); 1373 | 1374 | PyObject* res = PyObject_CallObject(detail::_interpreter::get().s_python_function_plot, plot_args); 1375 | 1376 | Py_DECREF(plot_args); 1377 | if(res) Py_DECREF(res); 1378 | 1379 | return res; 1380 | } 1381 | 1382 | template 1383 | bool contour(const std::vector& x, const std::vector& y, 1384 | const std::vector& z, 1385 | const std::map& keywords = {}) { 1386 | assert(x.size() == y.size() && x.size() == z.size()); 1387 | 1388 | PyObject* xarray = detail::get_array(x); 1389 | PyObject* yarray = detail::get_array(y); 1390 | PyObject* zarray = detail::get_array(z); 1391 | 1392 | PyObject* plot_args = PyTuple_New(3); 1393 | PyTuple_SetItem(plot_args, 0, xarray); 1394 | PyTuple_SetItem(plot_args, 1, yarray); 1395 | PyTuple_SetItem(plot_args, 2, zarray); 1396 | 1397 | // construct keyword args 1398 | PyObject* kwargs = PyDict_New(); 1399 | for (std::map::const_iterator it = 
keywords.begin(); 1400 | it != keywords.end(); ++it) { 1401 | PyDict_SetItemString(kwargs, it->first.c_str(), PyUnicode_FromString(it->second.c_str())); 1402 | } 1403 | 1404 | PyObject* res = 1405 | PyObject_Call(detail::_interpreter::get().s_python_function_contour, plot_args, kwargs); 1406 | 1407 | Py_DECREF(kwargs); 1408 | Py_DECREF(plot_args); 1409 | if (res) 1410 | Py_DECREF(res); 1411 | 1412 | return res; 1413 | } 1414 | 1415 | template 1416 | bool quiver(const std::vector& x, const std::vector& y, const std::vector& u, const std::vector& w, const std::map& keywords = {}) 1417 | { 1418 | assert(x.size() == y.size() && x.size() == u.size() && u.size() == w.size()); 1419 | 1420 | detail::_interpreter::get(); 1421 | 1422 | PyObject* xarray = detail::get_array(x); 1423 | PyObject* yarray = detail::get_array(y); 1424 | PyObject* uarray = detail::get_array(u); 1425 | PyObject* warray = detail::get_array(w); 1426 | 1427 | PyObject* plot_args = PyTuple_New(4); 1428 | PyTuple_SetItem(plot_args, 0, xarray); 1429 | PyTuple_SetItem(plot_args, 1, yarray); 1430 | PyTuple_SetItem(plot_args, 2, uarray); 1431 | PyTuple_SetItem(plot_args, 3, warray); 1432 | 1433 | // construct keyword args 1434 | PyObject* kwargs = PyDict_New(); 1435 | for(std::map::const_iterator it = keywords.begin(); it != keywords.end(); ++it) 1436 | { 1437 | PyDict_SetItemString(kwargs, it->first.c_str(), PyUnicode_FromString(it->second.c_str())); 1438 | } 1439 | 1440 | PyObject* res = PyObject_Call( 1441 | detail::_interpreter::get().s_python_function_quiver, plot_args, kwargs); 1442 | 1443 | Py_DECREF(kwargs); 1444 | Py_DECREF(plot_args); 1445 | if (res) 1446 | Py_DECREF(res); 1447 | 1448 | return res; 1449 | } 1450 | 1451 | template 1452 | bool quiver(const std::vector& x, const std::vector& y, const std::vector& z, const std::vector& u, const std::vector& w, const std::vector& v, const std::map& keywords = {}) 1453 | { 1454 | //set up 3d axes stuff 1455 | static PyObject *mpl_toolkitsmod = nullptr, 
*axis3dmod = nullptr; 1456 | if (!mpl_toolkitsmod) { 1457 | detail::_interpreter::get(); 1458 | 1459 | PyObject* mpl_toolkits = PyString_FromString("mpl_toolkits"); 1460 | PyObject* axis3d = PyString_FromString("mpl_toolkits.mplot3d"); 1461 | if (!mpl_toolkits || !axis3d) { throw std::runtime_error("couldnt create string"); } 1462 | 1463 | mpl_toolkitsmod = PyImport_Import(mpl_toolkits); 1464 | Py_DECREF(mpl_toolkits); 1465 | if (!mpl_toolkitsmod) { throw std::runtime_error("Error loading module mpl_toolkits!"); } 1466 | 1467 | axis3dmod = PyImport_Import(axis3d); 1468 | Py_DECREF(axis3d); 1469 | if (!axis3dmod) { throw std::runtime_error("Error loading module mpl_toolkits.mplot3d!"); } 1470 | } 1471 | 1472 | //assert sizes match up 1473 | assert(x.size() == y.size() && x.size() == u.size() && u.size() == w.size() && x.size() == z.size() && x.size() == v.size() && u.size() == v.size()); 1474 | 1475 | //set up parameters 1476 | detail::_interpreter::get(); 1477 | 1478 | PyObject* xarray = detail::get_array(x); 1479 | PyObject* yarray = detail::get_array(y); 1480 | PyObject* zarray = detail::get_array(z); 1481 | PyObject* uarray = detail::get_array(u); 1482 | PyObject* warray = detail::get_array(w); 1483 | PyObject* varray = detail::get_array(v); 1484 | 1485 | PyObject* plot_args = PyTuple_New(6); 1486 | PyTuple_SetItem(plot_args, 0, xarray); 1487 | PyTuple_SetItem(plot_args, 1, yarray); 1488 | PyTuple_SetItem(plot_args, 2, zarray); 1489 | PyTuple_SetItem(plot_args, 3, uarray); 1490 | PyTuple_SetItem(plot_args, 4, warray); 1491 | PyTuple_SetItem(plot_args, 5, varray); 1492 | 1493 | // construct keyword args 1494 | PyObject* kwargs = PyDict_New(); 1495 | for(std::map::const_iterator it = keywords.begin(); it != keywords.end(); ++it) 1496 | { 1497 | PyDict_SetItemString(kwargs, it->first.c_str(), PyUnicode_FromString(it->second.c_str())); 1498 | } 1499 | 1500 | //get figure gca to enable 3d projection 1501 | PyObject *fig = 1502 | 
PyObject_CallObject(detail::_interpreter::get().s_python_function_figure, 1503 | detail::_interpreter::get().s_python_empty_tuple); 1504 | if (!fig) throw std::runtime_error("Call to figure() failed."); 1505 | 1506 | PyObject *gca_kwargs = PyDict_New(); 1507 | PyDict_SetItemString(gca_kwargs, "projection", PyString_FromString("3d")); 1508 | 1509 | PyObject *gca = PyObject_GetAttrString(fig, "gca"); 1510 | if (!gca) throw std::runtime_error("No gca"); 1511 | Py_INCREF(gca); 1512 | PyObject *axis = PyObject_Call( 1513 | gca, detail::_interpreter::get().s_python_empty_tuple, gca_kwargs); 1514 | 1515 | if (!axis) throw std::runtime_error("No axis"); 1516 | Py_INCREF(axis); 1517 | Py_DECREF(gca); 1518 | Py_DECREF(gca_kwargs); 1519 | 1520 | //plot our boys bravely, plot them strongly, plot them with a wink and clap 1521 | PyObject *plot3 = PyObject_GetAttrString(axis, "quiver"); 1522 | if (!plot3) throw std::runtime_error("No 3D line plot"); 1523 | Py_INCREF(plot3); 1524 | PyObject* res = PyObject_Call( 1525 | plot3, plot_args, kwargs); 1526 | if (!res) throw std::runtime_error("Failed 3D plot"); 1527 | Py_DECREF(plot3); 1528 | Py_DECREF(axis); 1529 | Py_DECREF(kwargs); 1530 | Py_DECREF(plot_args); 1531 | if (res) 1532 | Py_DECREF(res); 1533 | 1534 | return res; 1535 | } 1536 | 1537 | template 1538 | bool stem(const std::vector& x, const std::vector& y, const std::string& s = "") 1539 | { 1540 | assert(x.size() == y.size()); 1541 | 1542 | detail::_interpreter::get(); 1543 | 1544 | PyObject* xarray = detail::get_array(x); 1545 | PyObject* yarray = detail::get_array(y); 1546 | 1547 | PyObject* pystring = PyString_FromString(s.c_str()); 1548 | 1549 | PyObject* plot_args = PyTuple_New(3); 1550 | PyTuple_SetItem(plot_args, 0, xarray); 1551 | PyTuple_SetItem(plot_args, 1, yarray); 1552 | PyTuple_SetItem(plot_args, 2, pystring); 1553 | 1554 | PyObject* res = PyObject_CallObject( 1555 | detail::_interpreter::get().s_python_function_stem, plot_args); 1556 | 1557 | 
Py_DECREF(plot_args); 1558 | if (res) 1559 | Py_DECREF(res); 1560 | 1561 | return res; 1562 | } 1563 | 1564 | template 1565 | bool semilogx(const std::vector& x, const std::vector& y, const std::string& s = "") 1566 | { 1567 | assert(x.size() == y.size()); 1568 | 1569 | detail::_interpreter::get(); 1570 | 1571 | PyObject* xarray = detail::get_array(x); 1572 | PyObject* yarray = detail::get_array(y); 1573 | 1574 | PyObject* pystring = PyString_FromString(s.c_str()); 1575 | 1576 | PyObject* plot_args = PyTuple_New(3); 1577 | PyTuple_SetItem(plot_args, 0, xarray); 1578 | PyTuple_SetItem(plot_args, 1, yarray); 1579 | PyTuple_SetItem(plot_args, 2, pystring); 1580 | 1581 | PyObject* res = PyObject_CallObject(detail::_interpreter::get().s_python_function_semilogx, plot_args); 1582 | 1583 | Py_DECREF(plot_args); 1584 | if(res) Py_DECREF(res); 1585 | 1586 | return res; 1587 | } 1588 | 1589 | template 1590 | bool semilogy(const std::vector& x, const std::vector& y, const std::string& s = "") 1591 | { 1592 | assert(x.size() == y.size()); 1593 | 1594 | detail::_interpreter::get(); 1595 | 1596 | PyObject* xarray = detail::get_array(x); 1597 | PyObject* yarray = detail::get_array(y); 1598 | 1599 | PyObject* pystring = PyString_FromString(s.c_str()); 1600 | 1601 | PyObject* plot_args = PyTuple_New(3); 1602 | PyTuple_SetItem(plot_args, 0, xarray); 1603 | PyTuple_SetItem(plot_args, 1, yarray); 1604 | PyTuple_SetItem(plot_args, 2, pystring); 1605 | 1606 | PyObject* res = PyObject_CallObject(detail::_interpreter::get().s_python_function_semilogy, plot_args); 1607 | 1608 | Py_DECREF(plot_args); 1609 | if(res) Py_DECREF(res); 1610 | 1611 | return res; 1612 | } 1613 | 1614 | template 1615 | bool loglog(const std::vector& x, const std::vector& y, const std::string& s = "") 1616 | { 1617 | assert(x.size() == y.size()); 1618 | 1619 | detail::_interpreter::get(); 1620 | 1621 | PyObject* xarray = detail::get_array(x); 1622 | PyObject* yarray = detail::get_array(y); 1623 | 1624 | PyObject* 
pystring = PyString_FromString(s.c_str()); 1625 | 1626 | PyObject* plot_args = PyTuple_New(3); 1627 | PyTuple_SetItem(plot_args, 0, xarray); 1628 | PyTuple_SetItem(plot_args, 1, yarray); 1629 | PyTuple_SetItem(plot_args, 2, pystring); 1630 | 1631 | PyObject* res = PyObject_CallObject(detail::_interpreter::get().s_python_function_loglog, plot_args); 1632 | 1633 | Py_DECREF(plot_args); 1634 | if(res) Py_DECREF(res); 1635 | 1636 | return res; 1637 | } 1638 | 1639 | template 1640 | bool errorbar(const std::vector &x, const std::vector &y, const std::vector &yerr, const std::map &keywords = {}) 1641 | { 1642 | assert(x.size() == y.size()); 1643 | 1644 | detail::_interpreter::get(); 1645 | 1646 | PyObject* xarray = detail::get_array(x); 1647 | PyObject* yarray = detail::get_array(y); 1648 | PyObject* yerrarray = detail::get_array(yerr); 1649 | 1650 | // construct keyword args 1651 | PyObject* kwargs = PyDict_New(); 1652 | for(std::map::const_iterator it = keywords.begin(); it != keywords.end(); ++it) 1653 | { 1654 | PyDict_SetItemString(kwargs, it->first.c_str(), PyString_FromString(it->second.c_str())); 1655 | } 1656 | 1657 | PyDict_SetItemString(kwargs, "yerr", yerrarray); 1658 | 1659 | PyObject *plot_args = PyTuple_New(2); 1660 | PyTuple_SetItem(plot_args, 0, xarray); 1661 | PyTuple_SetItem(plot_args, 1, yarray); 1662 | 1663 | PyObject *res = PyObject_Call(detail::_interpreter::get().s_python_function_errorbar, plot_args, kwargs); 1664 | 1665 | Py_DECREF(kwargs); 1666 | Py_DECREF(plot_args); 1667 | 1668 | if (res) 1669 | Py_DECREF(res); 1670 | else 1671 | throw std::runtime_error("Call to errorbar() failed."); 1672 | 1673 | return res; 1674 | } 1675 | 1676 | template 1677 | bool named_plot(const std::string& name, const std::vector& y, const std::string& format = "") 1678 | { 1679 | detail::_interpreter::get(); 1680 | 1681 | PyObject* kwargs = PyDict_New(); 1682 | PyDict_SetItemString(kwargs, "label", PyString_FromString(name.c_str())); 1683 | 1684 | PyObject* yarray 
= detail::get_array(y); 1685 | 1686 | PyObject* pystring = PyString_FromString(format.c_str()); 1687 | 1688 | PyObject* plot_args = PyTuple_New(2); 1689 | 1690 | PyTuple_SetItem(plot_args, 0, yarray); 1691 | PyTuple_SetItem(plot_args, 1, pystring); 1692 | 1693 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_plot, plot_args, kwargs); 1694 | 1695 | Py_DECREF(kwargs); 1696 | Py_DECREF(plot_args); 1697 | if (res) Py_DECREF(res); 1698 | 1699 | return res; 1700 | } 1701 | 1702 | template 1703 | bool named_plot(const std::string& name, const std::vector& x, const std::vector& y, const std::string& format = "") 1704 | { 1705 | detail::_interpreter::get(); 1706 | 1707 | PyObject* kwargs = PyDict_New(); 1708 | PyDict_SetItemString(kwargs, "label", PyString_FromString(name.c_str())); 1709 | 1710 | PyObject* xarray = detail::get_array(x); 1711 | PyObject* yarray = detail::get_array(y); 1712 | 1713 | PyObject* pystring = PyString_FromString(format.c_str()); 1714 | 1715 | PyObject* plot_args = PyTuple_New(3); 1716 | PyTuple_SetItem(plot_args, 0, xarray); 1717 | PyTuple_SetItem(plot_args, 1, yarray); 1718 | PyTuple_SetItem(plot_args, 2, pystring); 1719 | 1720 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_plot, plot_args, kwargs); 1721 | 1722 | Py_DECREF(kwargs); 1723 | Py_DECREF(plot_args); 1724 | if (res) Py_DECREF(res); 1725 | 1726 | return res; 1727 | } 1728 | 1729 | template 1730 | bool named_semilogx(const std::string& name, const std::vector& x, const std::vector& y, const std::string& format = "") 1731 | { 1732 | detail::_interpreter::get(); 1733 | 1734 | PyObject* kwargs = PyDict_New(); 1735 | PyDict_SetItemString(kwargs, "label", PyString_FromString(name.c_str())); 1736 | 1737 | PyObject* xarray = detail::get_array(x); 1738 | PyObject* yarray = detail::get_array(y); 1739 | 1740 | PyObject* pystring = PyString_FromString(format.c_str()); 1741 | 1742 | PyObject* plot_args = PyTuple_New(3); 1743 | 
PyTuple_SetItem(plot_args, 0, xarray); 1744 | PyTuple_SetItem(plot_args, 1, yarray); 1745 | PyTuple_SetItem(plot_args, 2, pystring); 1746 | 1747 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_semilogx, plot_args, kwargs); 1748 | 1749 | Py_DECREF(kwargs); 1750 | Py_DECREF(plot_args); 1751 | if (res) Py_DECREF(res); 1752 | 1753 | return res; 1754 | } 1755 | 1756 | template 1757 | bool named_semilogy(const std::string& name, const std::vector& x, const std::vector& y, const std::string& format = "") 1758 | { 1759 | detail::_interpreter::get(); 1760 | 1761 | PyObject* kwargs = PyDict_New(); 1762 | PyDict_SetItemString(kwargs, "label", PyString_FromString(name.c_str())); 1763 | 1764 | PyObject* xarray = detail::get_array(x); 1765 | PyObject* yarray = detail::get_array(y); 1766 | 1767 | PyObject* pystring = PyString_FromString(format.c_str()); 1768 | 1769 | PyObject* plot_args = PyTuple_New(3); 1770 | PyTuple_SetItem(plot_args, 0, xarray); 1771 | PyTuple_SetItem(plot_args, 1, yarray); 1772 | PyTuple_SetItem(plot_args, 2, pystring); 1773 | 1774 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_semilogy, plot_args, kwargs); 1775 | 1776 | Py_DECREF(kwargs); 1777 | Py_DECREF(plot_args); 1778 | if (res) Py_DECREF(res); 1779 | 1780 | return res; 1781 | } 1782 | 1783 | template 1784 | bool named_loglog(const std::string& name, const std::vector& x, const std::vector& y, const std::string& format = "") 1785 | { 1786 | detail::_interpreter::get(); 1787 | 1788 | PyObject* kwargs = PyDict_New(); 1789 | PyDict_SetItemString(kwargs, "label", PyString_FromString(name.c_str())); 1790 | 1791 | PyObject* xarray = detail::get_array(x); 1792 | PyObject* yarray = detail::get_array(y); 1793 | 1794 | PyObject* pystring = PyString_FromString(format.c_str()); 1795 | 1796 | PyObject* plot_args = PyTuple_New(3); 1797 | PyTuple_SetItem(plot_args, 0, xarray); 1798 | PyTuple_SetItem(plot_args, 1, yarray); 1799 | PyTuple_SetItem(plot_args, 2, 
pystring); 1800 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_loglog, plot_args, kwargs); 1801 | 1802 | Py_DECREF(kwargs); 1803 | Py_DECREF(plot_args); 1804 | if (res) Py_DECREF(res); 1805 | 1806 | return res; 1807 | } 1808 | 1809 | template 1810 | bool plot(const std::vector& y, const std::string& format = "") 1811 | { 1812 | std::vector x(y.size()); 1813 | for(size_t i=0; i 1818 | bool plot(const std::vector& y, const std::map& keywords) 1819 | { 1820 | std::vector x(y.size()); 1821 | for(size_t i=0; i 1826 | bool stem(const std::vector& y, const std::string& format = "") 1827 | { 1828 | std::vector x(y.size()); 1829 | for (size_t i = 0; i < x.size(); ++i) x.at(i) = i; 1830 | return stem(x, y, format); 1831 | } 1832 | 1833 | template 1834 | void text(Numeric x, Numeric y, const std::string& s = "") 1835 | { 1836 | detail::_interpreter::get(); 1837 | 1838 | PyObject* args = PyTuple_New(3); 1839 | PyTuple_SetItem(args, 0, PyFloat_FromDouble(x)); 1840 | PyTuple_SetItem(args, 1, PyFloat_FromDouble(y)); 1841 | PyTuple_SetItem(args, 2, PyString_FromString(s.c_str())); 1842 | 1843 | PyObject* res = PyObject_CallObject(detail::_interpreter::get().s_python_function_text, args); 1844 | if(!res) throw std::runtime_error("Call to text() failed."); 1845 | 1846 | Py_DECREF(args); 1847 | Py_DECREF(res); 1848 | } 1849 | 1850 | inline void colorbar(PyObject* mappable = NULL, const std::map& keywords = {}) 1851 | { 1852 | if (mappable == NULL) 1853 | throw std::runtime_error("Must call colorbar with PyObject* returned from an image, contour, surface, etc."); 1854 | 1855 | detail::_interpreter::get(); 1856 | 1857 | PyObject* args = PyTuple_New(1); 1858 | PyTuple_SetItem(args, 0, mappable); 1859 | 1860 | PyObject* kwargs = PyDict_New(); 1861 | for(std::map::const_iterator it = keywords.begin(); it != keywords.end(); ++it) 1862 | { 1863 | PyDict_SetItemString(kwargs, it->first.c_str(), PyFloat_FromDouble(it->second)); 1864 | } 1865 | 1866 | PyObject* 
res = PyObject_Call(detail::_interpreter::get().s_python_function_colorbar, args, kwargs); 1867 | if(!res) throw std::runtime_error("Call to colorbar() failed."); 1868 | 1869 | Py_DECREF(args); 1870 | Py_DECREF(kwargs); 1871 | Py_DECREF(res); 1872 | } 1873 | 1874 | 1875 | inline long figure(long number = -1) 1876 | { 1877 | detail::_interpreter::get(); 1878 | 1879 | PyObject *res; 1880 | if (number == -1) 1881 | res = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure, detail::_interpreter::get().s_python_empty_tuple); 1882 | else { 1883 | assert(number > 0); 1884 | 1885 | // Make sure interpreter is initialised 1886 | detail::_interpreter::get(); 1887 | 1888 | PyObject *args = PyTuple_New(1); 1889 | PyTuple_SetItem(args, 0, PyLong_FromLong(number)); 1890 | res = PyObject_CallObject(detail::_interpreter::get().s_python_function_figure, args); 1891 | Py_DECREF(args); 1892 | } 1893 | 1894 | if(!res) throw std::runtime_error("Call to figure() failed."); 1895 | 1896 | PyObject* num = PyObject_GetAttrString(res, "number"); 1897 | if (!num) throw std::runtime_error("Could not get number attribute of figure object"); 1898 | const long figureNumber = PyLong_AsLong(num); 1899 | 1900 | Py_DECREF(num); 1901 | Py_DECREF(res); 1902 | 1903 | return figureNumber; 1904 | } 1905 | 1906 | inline bool fignum_exists(long number) 1907 | { 1908 | detail::_interpreter::get(); 1909 | 1910 | PyObject *args = PyTuple_New(1); 1911 | PyTuple_SetItem(args, 0, PyLong_FromLong(number)); 1912 | PyObject *res = PyObject_CallObject(detail::_interpreter::get().s_python_function_fignum_exists, args); 1913 | if(!res) throw std::runtime_error("Call to fignum_exists() failed."); 1914 | 1915 | bool ret = PyObject_IsTrue(res); 1916 | Py_DECREF(res); 1917 | Py_DECREF(args); 1918 | 1919 | return ret; 1920 | } 1921 | 1922 | inline void figure_size(size_t w, size_t h) 1923 | { 1924 | detail::_interpreter::get(); 1925 | 1926 | const size_t dpi = 100; 1927 | PyObject* size = 
PyTuple_New(2); 1928 | PyTuple_SetItem(size, 0, PyFloat_FromDouble((double)w / dpi)); 1929 | PyTuple_SetItem(size, 1, PyFloat_FromDouble((double)h / dpi)); 1930 | 1931 | PyObject* kwargs = PyDict_New(); 1932 | PyDict_SetItemString(kwargs, "figsize", size); 1933 | PyDict_SetItemString(kwargs, "dpi", PyLong_FromSize_t(dpi)); 1934 | 1935 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_figure, 1936 | detail::_interpreter::get().s_python_empty_tuple, kwargs); 1937 | 1938 | Py_DECREF(kwargs); 1939 | 1940 | if(!res) throw std::runtime_error("Call to figure_size() failed."); 1941 | Py_DECREF(res); 1942 | } 1943 | 1944 | inline void legend() 1945 | { 1946 | detail::_interpreter::get(); 1947 | 1948 | PyObject* res = PyObject_CallObject(detail::_interpreter::get().s_python_function_legend, detail::_interpreter::get().s_python_empty_tuple); 1949 | if(!res) throw std::runtime_error("Call to legend() failed."); 1950 | 1951 | Py_DECREF(res); 1952 | } 1953 | 1954 | inline void legend(const std::map& keywords) 1955 | { 1956 | detail::_interpreter::get(); 1957 | 1958 | // construct keyword args 1959 | PyObject* kwargs = PyDict_New(); 1960 | for(std::map::const_iterator it = keywords.begin(); it != keywords.end(); ++it) 1961 | { 1962 | PyDict_SetItemString(kwargs, it->first.c_str(), PyString_FromString(it->second.c_str())); 1963 | } 1964 | 1965 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_legend, detail::_interpreter::get().s_python_empty_tuple, kwargs); 1966 | if(!res) throw std::runtime_error("Call to legend() failed."); 1967 | 1968 | Py_DECREF(kwargs); 1969 | Py_DECREF(res); 1970 | } 1971 | 1972 | template 1973 | inline void set_aspect(Numeric ratio) 1974 | { 1975 | detail::_interpreter::get(); 1976 | 1977 | PyObject* args = PyTuple_New(1); 1978 | PyTuple_SetItem(args, 0, PyFloat_FromDouble(ratio)); 1979 | PyObject* kwargs = PyDict_New(); 1980 | 1981 | PyObject *ax = 1982 | 
PyObject_CallObject(detail::_interpreter::get().s_python_function_gca, 1983 | detail::_interpreter::get().s_python_empty_tuple); 1984 | if (!ax) throw std::runtime_error("Call to gca() failed."); 1985 | Py_INCREF(ax); 1986 | 1987 | PyObject *set_aspect = PyObject_GetAttrString(ax, "set_aspect"); 1988 | if (!set_aspect) throw std::runtime_error("Attribute set_aspect not found."); 1989 | Py_INCREF(set_aspect); 1990 | 1991 | PyObject *res = PyObject_Call(set_aspect, args, kwargs); 1992 | if (!res) throw std::runtime_error("Call to set_aspect() failed."); 1993 | Py_DECREF(set_aspect); 1994 | 1995 | Py_DECREF(ax); 1996 | Py_DECREF(args); 1997 | Py_DECREF(kwargs); 1998 | } 1999 | 2000 | inline void set_aspect_equal() 2001 | { 2002 | // expect ratio == "equal". Leaving error handling to matplotlib. 2003 | detail::_interpreter::get(); 2004 | 2005 | PyObject* args = PyTuple_New(1); 2006 | PyTuple_SetItem(args, 0, PyString_FromString("equal")); 2007 | PyObject* kwargs = PyDict_New(); 2008 | 2009 | PyObject *ax = 2010 | PyObject_CallObject(detail::_interpreter::get().s_python_function_gca, 2011 | detail::_interpreter::get().s_python_empty_tuple); 2012 | if (!ax) throw std::runtime_error("Call to gca() failed."); 2013 | Py_INCREF(ax); 2014 | 2015 | PyObject *set_aspect = PyObject_GetAttrString(ax, "set_aspect"); 2016 | if (!set_aspect) throw std::runtime_error("Attribute set_aspect not found."); 2017 | Py_INCREF(set_aspect); 2018 | 2019 | PyObject *res = PyObject_Call(set_aspect, args, kwargs); 2020 | if (!res) throw std::runtime_error("Call to set_aspect() failed."); 2021 | Py_DECREF(set_aspect); 2022 | 2023 | Py_DECREF(ax); 2024 | Py_DECREF(args); 2025 | Py_DECREF(kwargs); 2026 | } 2027 | 2028 | template 2029 | void ylim(Numeric left, Numeric right) 2030 | { 2031 | detail::_interpreter::get(); 2032 | 2033 | PyObject* list = PyList_New(2); 2034 | PyList_SetItem(list, 0, PyFloat_FromDouble(left)); 2035 | PyList_SetItem(list, 1, PyFloat_FromDouble(right)); 2036 | 2037 | 
PyObject* args = PyTuple_New(1); 2038 | PyTuple_SetItem(args, 0, list); 2039 | 2040 | PyObject* res = PyObject_CallObject(detail::_interpreter::get().s_python_function_ylim, args); 2041 | if(!res) throw std::runtime_error("Call to ylim() failed."); 2042 | 2043 | Py_DECREF(args); 2044 | Py_DECREF(res); 2045 | } 2046 | 2047 | template 2048 | void xlim(Numeric left, Numeric right) 2049 | { 2050 | detail::_interpreter::get(); 2051 | 2052 | PyObject* list = PyList_New(2); 2053 | PyList_SetItem(list, 0, PyFloat_FromDouble(left)); 2054 | PyList_SetItem(list, 1, PyFloat_FromDouble(right)); 2055 | 2056 | PyObject* args = PyTuple_New(1); 2057 | PyTuple_SetItem(args, 0, list); 2058 | 2059 | PyObject* res = PyObject_CallObject(detail::_interpreter::get().s_python_function_xlim, args); 2060 | if(!res) throw std::runtime_error("Call to xlim() failed."); 2061 | 2062 | Py_DECREF(args); 2063 | Py_DECREF(res); 2064 | } 2065 | 2066 | 2067 | inline std::array xlim() 2068 | { 2069 | PyObject* args = PyTuple_New(0); 2070 | PyObject* res = PyObject_CallObject(detail::_interpreter::get().s_python_function_xlim, args); 2071 | 2072 | if(!res) throw std::runtime_error("Call to xlim() failed."); 2073 | 2074 | Py_DECREF(res); 2075 | 2076 | PyObject* left = PyTuple_GetItem(res,0); 2077 | PyObject* right = PyTuple_GetItem(res,1); 2078 | return { PyFloat_AsDouble(left), PyFloat_AsDouble(right) }; 2079 | } 2080 | 2081 | 2082 | inline std::array ylim() 2083 | { 2084 | PyObject* args = PyTuple_New(0); 2085 | PyObject* res = PyObject_CallObject(detail::_interpreter::get().s_python_function_ylim, args); 2086 | 2087 | if(!res) throw std::runtime_error("Call to ylim() failed."); 2088 | 2089 | Py_DECREF(res); 2090 | 2091 | PyObject* left = PyTuple_GetItem(res,0); 2092 | PyObject* right = PyTuple_GetItem(res,1); 2093 | return { PyFloat_AsDouble(left), PyFloat_AsDouble(right) }; 2094 | } 2095 | 2096 | template 2097 | inline void xticks(const std::vector &ticks, const std::vector &labels = {}, const 
std::map& keywords = {}) 2098 | { 2099 | assert(labels.size() == 0 || ticks.size() == labels.size()); 2100 | 2101 | detail::_interpreter::get(); 2102 | 2103 | // using numpy array 2104 | PyObject* ticksarray = detail::get_array(ticks); 2105 | 2106 | PyObject* args; 2107 | if(labels.size() == 0) { 2108 | // construct positional args 2109 | args = PyTuple_New(1); 2110 | PyTuple_SetItem(args, 0, ticksarray); 2111 | } else { 2112 | // make tuple of tick labels 2113 | PyObject* labelstuple = PyTuple_New(labels.size()); 2114 | for (size_t i = 0; i < labels.size(); i++) 2115 | PyTuple_SetItem(labelstuple, i, PyUnicode_FromString(labels[i].c_str())); 2116 | 2117 | // construct positional args 2118 | args = PyTuple_New(2); 2119 | PyTuple_SetItem(args, 0, ticksarray); 2120 | PyTuple_SetItem(args, 1, labelstuple); 2121 | } 2122 | 2123 | // construct keyword args 2124 | PyObject* kwargs = PyDict_New(); 2125 | for(std::map::const_iterator it = keywords.begin(); it != keywords.end(); ++it) 2126 | { 2127 | PyDict_SetItemString(kwargs, it->first.c_str(), PyString_FromString(it->second.c_str())); 2128 | } 2129 | 2130 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_xticks, args, kwargs); 2131 | 2132 | Py_DECREF(args); 2133 | Py_DECREF(kwargs); 2134 | if(!res) throw std::runtime_error("Call to xticks() failed"); 2135 | 2136 | Py_DECREF(res); 2137 | } 2138 | 2139 | template 2140 | inline void xticks(const std::vector &ticks, const std::map& keywords) 2141 | { 2142 | xticks(ticks, {}, keywords); 2143 | } 2144 | 2145 | template 2146 | inline void yticks(const std::vector &ticks, const std::vector &labels = {}, const std::map& keywords = {}) 2147 | { 2148 | assert(labels.size() == 0 || ticks.size() == labels.size()); 2149 | 2150 | detail::_interpreter::get(); 2151 | 2152 | // using numpy array 2153 | PyObject* ticksarray = detail::get_array(ticks); 2154 | 2155 | PyObject* args; 2156 | if(labels.size() == 0) { 2157 | // construct positional args 2158 | args = 
PyTuple_New(1); 2159 | PyTuple_SetItem(args, 0, ticksarray); 2160 | } else { 2161 | // make tuple of tick labels 2162 | PyObject* labelstuple = PyTuple_New(labels.size()); 2163 | for (size_t i = 0; i < labels.size(); i++) 2164 | PyTuple_SetItem(labelstuple, i, PyUnicode_FromString(labels[i].c_str())); 2165 | 2166 | // construct positional args 2167 | args = PyTuple_New(2); 2168 | PyTuple_SetItem(args, 0, ticksarray); 2169 | PyTuple_SetItem(args, 1, labelstuple); 2170 | } 2171 | 2172 | // construct keyword args 2173 | PyObject* kwargs = PyDict_New(); 2174 | for(std::map::const_iterator it = keywords.begin(); it != keywords.end(); ++it) 2175 | { 2176 | PyDict_SetItemString(kwargs, it->first.c_str(), PyString_FromString(it->second.c_str())); 2177 | } 2178 | 2179 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_yticks, args, kwargs); 2180 | 2181 | Py_DECREF(args); 2182 | Py_DECREF(kwargs); 2183 | if(!res) throw std::runtime_error("Call to yticks() failed"); 2184 | 2185 | Py_DECREF(res); 2186 | } 2187 | 2188 | template 2189 | inline void yticks(const std::vector &ticks, const std::map& keywords) 2190 | { 2191 | yticks(ticks, {}, keywords); 2192 | } 2193 | 2194 | template inline void margins(Numeric margin) 2195 | { 2196 | // construct positional args 2197 | PyObject* args = PyTuple_New(1); 2198 | PyTuple_SetItem(args, 0, PyFloat_FromDouble(margin)); 2199 | 2200 | PyObject* res = 2201 | PyObject_CallObject(detail::_interpreter::get().s_python_function_margins, args); 2202 | if (!res) 2203 | throw std::runtime_error("Call to margins() failed."); 2204 | 2205 | Py_DECREF(args); 2206 | Py_DECREF(res); 2207 | } 2208 | 2209 | template inline void margins(Numeric margin_x, Numeric margin_y) 2210 | { 2211 | // construct positional args 2212 | PyObject* args = PyTuple_New(2); 2213 | PyTuple_SetItem(args, 0, PyFloat_FromDouble(margin_x)); 2214 | PyTuple_SetItem(args, 1, PyFloat_FromDouble(margin_y)); 2215 | 2216 | PyObject* res = 2217 | 
PyObject_CallObject(detail::_interpreter::get().s_python_function_margins, args); 2218 | if (!res) 2219 | throw std::runtime_error("Call to margins() failed."); 2220 | 2221 | Py_DECREF(args); 2222 | Py_DECREF(res); 2223 | } 2224 | 2225 | 2226 | inline void tick_params(const std::map& keywords, const std::string axis = "both") 2227 | { 2228 | detail::_interpreter::get(); 2229 | 2230 | // construct positional args 2231 | PyObject* args; 2232 | args = PyTuple_New(1); 2233 | PyTuple_SetItem(args, 0, PyString_FromString(axis.c_str())); 2234 | 2235 | // construct keyword args 2236 | PyObject* kwargs = PyDict_New(); 2237 | for (std::map::const_iterator it = keywords.begin(); it != keywords.end(); ++it) 2238 | { 2239 | PyDict_SetItemString(kwargs, it->first.c_str(), PyString_FromString(it->second.c_str())); 2240 | } 2241 | 2242 | 2243 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_tick_params, args, kwargs); 2244 | 2245 | Py_DECREF(args); 2246 | Py_DECREF(kwargs); 2247 | if (!res) throw std::runtime_error("Call to tick_params() failed"); 2248 | 2249 | Py_DECREF(res); 2250 | } 2251 | 2252 | inline void subplot(long nrows, long ncols, long plot_number) 2253 | { 2254 | detail::_interpreter::get(); 2255 | 2256 | // construct positional args 2257 | PyObject* args = PyTuple_New(3); 2258 | PyTuple_SetItem(args, 0, PyFloat_FromDouble(nrows)); 2259 | PyTuple_SetItem(args, 1, PyFloat_FromDouble(ncols)); 2260 | PyTuple_SetItem(args, 2, PyFloat_FromDouble(plot_number)); 2261 | 2262 | PyObject* res = PyObject_CallObject(detail::_interpreter::get().s_python_function_subplot, args); 2263 | if(!res) throw std::runtime_error("Call to subplot() failed."); 2264 | 2265 | Py_DECREF(args); 2266 | Py_DECREF(res); 2267 | } 2268 | 2269 | inline void subplot2grid(long nrows, long ncols, long rowid=0, long colid=0, long rowspan=1, long colspan=1) 2270 | { 2271 | detail::_interpreter::get(); 2272 | 2273 | PyObject* shape = PyTuple_New(2); 2274 | PyTuple_SetItem(shape, 
0, PyLong_FromLong(nrows)); 2275 | PyTuple_SetItem(shape, 1, PyLong_FromLong(ncols)); 2276 | 2277 | PyObject* loc = PyTuple_New(2); 2278 | PyTuple_SetItem(loc, 0, PyLong_FromLong(rowid)); 2279 | PyTuple_SetItem(loc, 1, PyLong_FromLong(colid)); 2280 | 2281 | PyObject* args = PyTuple_New(4); 2282 | PyTuple_SetItem(args, 0, shape); 2283 | PyTuple_SetItem(args, 1, loc); 2284 | PyTuple_SetItem(args, 2, PyLong_FromLong(rowspan)); 2285 | PyTuple_SetItem(args, 3, PyLong_FromLong(colspan)); 2286 | 2287 | PyObject* res = PyObject_CallObject(detail::_interpreter::get().s_python_function_subplot2grid, args); 2288 | if(!res) throw std::runtime_error("Call to subplot2grid() failed."); 2289 | 2290 | Py_DECREF(shape); 2291 | Py_DECREF(loc); 2292 | Py_DECREF(args); 2293 | Py_DECREF(res); 2294 | } 2295 | 2296 | inline void title(const std::string &titlestr, const std::map &keywords = {}) 2297 | { 2298 | detail::_interpreter::get(); 2299 | 2300 | PyObject* pytitlestr = PyString_FromString(titlestr.c_str()); 2301 | PyObject* args = PyTuple_New(1); 2302 | PyTuple_SetItem(args, 0, pytitlestr); 2303 | 2304 | PyObject* kwargs = PyDict_New(); 2305 | for (auto it = keywords.begin(); it != keywords.end(); ++it) { 2306 | PyDict_SetItemString(kwargs, it->first.c_str(), PyUnicode_FromString(it->second.c_str())); 2307 | } 2308 | 2309 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_title, args, kwargs); 2310 | if(!res) throw std::runtime_error("Call to title() failed."); 2311 | 2312 | Py_DECREF(args); 2313 | Py_DECREF(kwargs); 2314 | Py_DECREF(res); 2315 | } 2316 | 2317 | inline void suptitle(const std::string &suptitlestr, const std::map &keywords = {}) 2318 | { 2319 | detail::_interpreter::get(); 2320 | 2321 | PyObject* pysuptitlestr = PyString_FromString(suptitlestr.c_str()); 2322 | PyObject* args = PyTuple_New(1); 2323 | PyTuple_SetItem(args, 0, pysuptitlestr); 2324 | 2325 | PyObject* kwargs = PyDict_New(); 2326 | for (auto it = keywords.begin(); it != 
keywords.end(); ++it) { 2327 | PyDict_SetItemString(kwargs, it->first.c_str(), PyUnicode_FromString(it->second.c_str())); 2328 | } 2329 | 2330 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_suptitle, args, kwargs); 2331 | if(!res) throw std::runtime_error("Call to suptitle() failed."); 2332 | 2333 | Py_DECREF(args); 2334 | Py_DECREF(kwargs); 2335 | Py_DECREF(res); 2336 | } 2337 | 2338 | inline void axis(const std::string &axisstr) 2339 | { 2340 | detail::_interpreter::get(); 2341 | 2342 | PyObject* str = PyString_FromString(axisstr.c_str()); 2343 | PyObject* args = PyTuple_New(1); 2344 | PyTuple_SetItem(args, 0, str); 2345 | 2346 | PyObject* res = PyObject_CallObject(detail::_interpreter::get().s_python_function_axis, args); 2347 | if(!res) throw std::runtime_error("Call to title() failed."); 2348 | 2349 | Py_DECREF(args); 2350 | Py_DECREF(res); 2351 | } 2352 | 2353 | inline void axhline(double y, double xmin = 0., double xmax = 1., const std::map& keywords = std::map()) 2354 | { 2355 | detail::_interpreter::get(); 2356 | 2357 | // construct positional args 2358 | PyObject* args = PyTuple_New(3); 2359 | PyTuple_SetItem(args, 0, PyFloat_FromDouble(y)); 2360 | PyTuple_SetItem(args, 1, PyFloat_FromDouble(xmin)); 2361 | PyTuple_SetItem(args, 2, PyFloat_FromDouble(xmax)); 2362 | 2363 | // construct keyword args 2364 | PyObject* kwargs = PyDict_New(); 2365 | for(std::map::const_iterator it = keywords.begin(); it != keywords.end(); ++it) 2366 | { 2367 | PyDict_SetItemString(kwargs, it->first.c_str(), PyString_FromString(it->second.c_str())); 2368 | } 2369 | 2370 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_axhline, args, kwargs); 2371 | 2372 | Py_DECREF(args); 2373 | Py_DECREF(kwargs); 2374 | 2375 | if(res) Py_DECREF(res); 2376 | } 2377 | 2378 | inline void axvline(double x, double ymin = 0., double ymax = 1., const std::map& keywords = std::map()) 2379 | { 2380 | detail::_interpreter::get(); 2381 | 2382 | // 
construct positional args 2383 | PyObject* args = PyTuple_New(3); 2384 | PyTuple_SetItem(args, 0, PyFloat_FromDouble(x)); 2385 | PyTuple_SetItem(args, 1, PyFloat_FromDouble(ymin)); 2386 | PyTuple_SetItem(args, 2, PyFloat_FromDouble(ymax)); 2387 | 2388 | // construct keyword args 2389 | PyObject* kwargs = PyDict_New(); 2390 | for(std::map::const_iterator it = keywords.begin(); it != keywords.end(); ++it) 2391 | { 2392 | PyDict_SetItemString(kwargs, it->first.c_str(), PyString_FromString(it->second.c_str())); 2393 | } 2394 | 2395 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_axvline, args, kwargs); 2396 | 2397 | Py_DECREF(args); 2398 | Py_DECREF(kwargs); 2399 | 2400 | if(res) Py_DECREF(res); 2401 | } 2402 | 2403 | inline void axvspan(double xmin, double xmax, double ymin = 0., double ymax = 1., const std::map& keywords = std::map()) 2404 | { 2405 | // construct positional args 2406 | PyObject* args = PyTuple_New(4); 2407 | PyTuple_SetItem(args, 0, PyFloat_FromDouble(xmin)); 2408 | PyTuple_SetItem(args, 1, PyFloat_FromDouble(xmax)); 2409 | PyTuple_SetItem(args, 2, PyFloat_FromDouble(ymin)); 2410 | PyTuple_SetItem(args, 3, PyFloat_FromDouble(ymax)); 2411 | 2412 | // construct keyword args 2413 | PyObject* kwargs = PyDict_New(); 2414 | for (auto it = keywords.begin(); it != keywords.end(); ++it) { 2415 | if (it->first == "linewidth" || it->first == "alpha") { 2416 | PyDict_SetItemString(kwargs, it->first.c_str(), 2417 | PyFloat_FromDouble(std::stod(it->second))); 2418 | } else { 2419 | PyDict_SetItemString(kwargs, it->first.c_str(), 2420 | PyString_FromString(it->second.c_str())); 2421 | } 2422 | } 2423 | 2424 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_axvspan, args, kwargs); 2425 | Py_DECREF(args); 2426 | Py_DECREF(kwargs); 2427 | 2428 | if(res) Py_DECREF(res); 2429 | } 2430 | 2431 | inline void xlabel(const std::string &str, const std::map &keywords = {}) 2432 | { 2433 | detail::_interpreter::get(); 
2434 | 2435 | PyObject* pystr = PyString_FromString(str.c_str()); 2436 | PyObject* args = PyTuple_New(1); 2437 | PyTuple_SetItem(args, 0, pystr); 2438 | 2439 | PyObject* kwargs = PyDict_New(); 2440 | for (auto it = keywords.begin(); it != keywords.end(); ++it) { 2441 | PyDict_SetItemString(kwargs, it->first.c_str(), PyUnicode_FromString(it->second.c_str())); 2442 | } 2443 | 2444 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_xlabel, args, kwargs); 2445 | if(!res) throw std::runtime_error("Call to xlabel() failed."); 2446 | 2447 | Py_DECREF(args); 2448 | Py_DECREF(kwargs); 2449 | Py_DECREF(res); 2450 | } 2451 | 2452 | inline void ylabel(const std::string &str, const std::map& keywords = {}) 2453 | { 2454 | detail::_interpreter::get(); 2455 | 2456 | PyObject* pystr = PyString_FromString(str.c_str()); 2457 | PyObject* args = PyTuple_New(1); 2458 | PyTuple_SetItem(args, 0, pystr); 2459 | 2460 | PyObject* kwargs = PyDict_New(); 2461 | for (auto it = keywords.begin(); it != keywords.end(); ++it) { 2462 | PyDict_SetItemString(kwargs, it->first.c_str(), PyUnicode_FromString(it->second.c_str())); 2463 | } 2464 | 2465 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_ylabel, args, kwargs); 2466 | if(!res) throw std::runtime_error("Call to ylabel() failed."); 2467 | 2468 | Py_DECREF(args); 2469 | Py_DECREF(kwargs); 2470 | Py_DECREF(res); 2471 | } 2472 | 2473 | inline void set_zlabel(const std::string &str, const std::map& keywords = {}) 2474 | { 2475 | detail::_interpreter::get(); 2476 | 2477 | // Same as with plot_surface: We lazily load the modules here the first time 2478 | // this function is called because I'm not sure that we can assume "matplotlib 2479 | // installed" implies "mpl_toolkits installed" on all platforms, and we don't 2480 | // want to require it for people who don't need 3d plots. 
2481 | static PyObject *mpl_toolkitsmod = nullptr, *axis3dmod = nullptr; 2482 | if (!mpl_toolkitsmod) { 2483 | PyObject* mpl_toolkits = PyString_FromString("mpl_toolkits"); 2484 | PyObject* axis3d = PyString_FromString("mpl_toolkits.mplot3d"); 2485 | if (!mpl_toolkits || !axis3d) { throw std::runtime_error("couldnt create string"); } 2486 | 2487 | mpl_toolkitsmod = PyImport_Import(mpl_toolkits); 2488 | Py_DECREF(mpl_toolkits); 2489 | if (!mpl_toolkitsmod) { throw std::runtime_error("Error loading module mpl_toolkits!"); } 2490 | 2491 | axis3dmod = PyImport_Import(axis3d); 2492 | Py_DECREF(axis3d); 2493 | if (!axis3dmod) { throw std::runtime_error("Error loading module mpl_toolkits.mplot3d!"); } 2494 | } 2495 | 2496 | PyObject* pystr = PyString_FromString(str.c_str()); 2497 | PyObject* args = PyTuple_New(1); 2498 | PyTuple_SetItem(args, 0, pystr); 2499 | 2500 | PyObject* kwargs = PyDict_New(); 2501 | for (auto it = keywords.begin(); it != keywords.end(); ++it) { 2502 | PyDict_SetItemString(kwargs, it->first.c_str(), PyUnicode_FromString(it->second.c_str())); 2503 | } 2504 | 2505 | PyObject *ax = 2506 | PyObject_CallObject(detail::_interpreter::get().s_python_function_gca, 2507 | detail::_interpreter::get().s_python_empty_tuple); 2508 | if (!ax) throw std::runtime_error("Call to gca() failed."); 2509 | Py_INCREF(ax); 2510 | 2511 | PyObject *zlabel = PyObject_GetAttrString(ax, "set_zlabel"); 2512 | if (!zlabel) throw std::runtime_error("Attribute set_zlabel not found."); 2513 | Py_INCREF(zlabel); 2514 | 2515 | PyObject *res = PyObject_Call(zlabel, args, kwargs); 2516 | if (!res) throw std::runtime_error("Call to set_zlabel() failed."); 2517 | Py_DECREF(zlabel); 2518 | 2519 | Py_DECREF(ax); 2520 | Py_DECREF(args); 2521 | Py_DECREF(kwargs); 2522 | if (res) Py_DECREF(res); 2523 | } 2524 | 2525 | inline void grid(bool flag) 2526 | { 2527 | detail::_interpreter::get(); 2528 | 2529 | PyObject* pyflag = flag ? 
Py_True : Py_False; 2530 | Py_INCREF(pyflag); 2531 | 2532 | PyObject* args = PyTuple_New(1); 2533 | PyTuple_SetItem(args, 0, pyflag); 2534 | 2535 | PyObject* res = PyObject_CallObject(detail::_interpreter::get().s_python_function_grid, args); 2536 | if(!res) throw std::runtime_error("Call to grid() failed."); 2537 | 2538 | Py_DECREF(args); 2539 | Py_DECREF(res); 2540 | } 2541 | 2542 | inline void show(const bool block = true) 2543 | { 2544 | detail::_interpreter::get(); 2545 | 2546 | PyObject* res; 2547 | if(block) 2548 | { 2549 | res = PyObject_CallObject( 2550 | detail::_interpreter::get().s_python_function_show, 2551 | detail::_interpreter::get().s_python_empty_tuple); 2552 | } 2553 | else 2554 | { 2555 | PyObject *kwargs = PyDict_New(); 2556 | PyDict_SetItemString(kwargs, "block", Py_False); 2557 | res = PyObject_Call( detail::_interpreter::get().s_python_function_show, detail::_interpreter::get().s_python_empty_tuple, kwargs); 2558 | Py_DECREF(kwargs); 2559 | } 2560 | 2561 | 2562 | if (!res) throw std::runtime_error("Call to show() failed."); 2563 | 2564 | Py_DECREF(res); 2565 | } 2566 | 2567 | inline void close() 2568 | { 2569 | detail::_interpreter::get(); 2570 | 2571 | PyObject* res = PyObject_CallObject( 2572 | detail::_interpreter::get().s_python_function_close, 2573 | detail::_interpreter::get().s_python_empty_tuple); 2574 | 2575 | if (!res) throw std::runtime_error("Call to close() failed."); 2576 | 2577 | Py_DECREF(res); 2578 | } 2579 | 2580 | inline void xkcd() { 2581 | detail::_interpreter::get(); 2582 | 2583 | PyObject* res; 2584 | PyObject *kwargs = PyDict_New(); 2585 | 2586 | res = PyObject_Call(detail::_interpreter::get().s_python_function_xkcd, 2587 | detail::_interpreter::get().s_python_empty_tuple, kwargs); 2588 | 2589 | Py_DECREF(kwargs); 2590 | 2591 | if (!res) 2592 | throw std::runtime_error("Call to show() failed."); 2593 | 2594 | Py_DECREF(res); 2595 | } 2596 | 2597 | inline void draw() 2598 | { 2599 | detail::_interpreter::get(); 2600 | 
2601 | PyObject* res = PyObject_CallObject( 2602 | detail::_interpreter::get().s_python_function_draw, 2603 | detail::_interpreter::get().s_python_empty_tuple); 2604 | 2605 | if (!res) throw std::runtime_error("Call to draw() failed."); 2606 | 2607 | Py_DECREF(res); 2608 | } 2609 | 2610 | template 2611 | inline void pause(Numeric interval) 2612 | { 2613 | detail::_interpreter::get(); 2614 | 2615 | PyObject* args = PyTuple_New(1); 2616 | PyTuple_SetItem(args, 0, PyFloat_FromDouble(interval)); 2617 | 2618 | PyObject* res = PyObject_CallObject(detail::_interpreter::get().s_python_function_pause, args); 2619 | if(!res) throw std::runtime_error("Call to pause() failed."); 2620 | 2621 | Py_DECREF(args); 2622 | Py_DECREF(res); 2623 | } 2624 | 2625 | inline void save(const std::string& filename, const int dpi=0) 2626 | { 2627 | detail::_interpreter::get(); 2628 | 2629 | PyObject* pyfilename = PyString_FromString(filename.c_str()); 2630 | 2631 | PyObject* args = PyTuple_New(1); 2632 | PyTuple_SetItem(args, 0, pyfilename); 2633 | 2634 | PyObject* kwargs = PyDict_New(); 2635 | 2636 | if(dpi > 0) 2637 | { 2638 | PyDict_SetItemString(kwargs, "dpi", PyLong_FromLong(dpi)); 2639 | } 2640 | 2641 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_save, args, kwargs); 2642 | if (!res) throw std::runtime_error("Call to save() failed."); 2643 | 2644 | Py_DECREF(args); 2645 | Py_DECREF(kwargs); 2646 | Py_DECREF(res); 2647 | } 2648 | 2649 | inline void rcparams(const std::map& keywords = {}) { 2650 | detail::_interpreter::get(); 2651 | PyObject* args = PyTuple_New(0); 2652 | PyObject* kwargs = PyDict_New(); 2653 | for (auto it = keywords.begin(); it != keywords.end(); ++it) { 2654 | if ("text.usetex" == it->first) 2655 | PyDict_SetItemString(kwargs, it->first.c_str(), PyLong_FromLong(std::stoi(it->second.c_str()))); 2656 | else PyDict_SetItemString(kwargs, it->first.c_str(), PyString_FromString(it->second.c_str())); 2657 | } 2658 | 2659 | PyObject * update = 
PyObject_GetAttrString(detail::_interpreter::get().s_python_function_rcparams, "update"); 2660 | PyObject * res = PyObject_Call(update, args, kwargs); 2661 | if(!res) throw std::runtime_error("Call to rcParams.update() failed."); 2662 | Py_DECREF(args); 2663 | Py_DECREF(kwargs); 2664 | Py_DECREF(update); 2665 | Py_DECREF(res); 2666 | } 2667 | 2668 | inline void clf() { 2669 | detail::_interpreter::get(); 2670 | 2671 | PyObject *res = PyObject_CallObject( 2672 | detail::_interpreter::get().s_python_function_clf, 2673 | detail::_interpreter::get().s_python_empty_tuple); 2674 | 2675 | if (!res) throw std::runtime_error("Call to clf() failed."); 2676 | 2677 | Py_DECREF(res); 2678 | } 2679 | 2680 | inline void cla() { 2681 | detail::_interpreter::get(); 2682 | 2683 | PyObject* res = PyObject_CallObject(detail::_interpreter::get().s_python_function_cla, 2684 | detail::_interpreter::get().s_python_empty_tuple); 2685 | 2686 | if (!res) 2687 | throw std::runtime_error("Call to cla() failed."); 2688 | 2689 | Py_DECREF(res); 2690 | } 2691 | 2692 | inline void ion() { 2693 | detail::_interpreter::get(); 2694 | 2695 | PyObject *res = PyObject_CallObject( 2696 | detail::_interpreter::get().s_python_function_ion, 2697 | detail::_interpreter::get().s_python_empty_tuple); 2698 | 2699 | if (!res) throw std::runtime_error("Call to ion() failed."); 2700 | 2701 | Py_DECREF(res); 2702 | } 2703 | 2704 | inline std::vector> ginput(const int numClicks = 1, const std::map& keywords = {}) 2705 | { 2706 | detail::_interpreter::get(); 2707 | 2708 | PyObject *args = PyTuple_New(1); 2709 | PyTuple_SetItem(args, 0, PyLong_FromLong(numClicks)); 2710 | 2711 | // construct keyword args 2712 | PyObject* kwargs = PyDict_New(); 2713 | for(std::map::const_iterator it = keywords.begin(); it != keywords.end(); ++it) 2714 | { 2715 | PyDict_SetItemString(kwargs, it->first.c_str(), PyUnicode_FromString(it->second.c_str())); 2716 | } 2717 | 2718 | PyObject* res = PyObject_Call( 2719 | 
detail::_interpreter::get().s_python_function_ginput, args, kwargs); 2720 | 2721 | Py_DECREF(kwargs); 2722 | Py_DECREF(args); 2723 | if (!res) throw std::runtime_error("Call to ginput() failed."); 2724 | 2725 | const size_t len = PyList_Size(res); 2726 | std::vector> out; 2727 | out.reserve(len); 2728 | for (size_t i = 0; i < len; i++) { 2729 | PyObject *current = PyList_GetItem(res, i); 2730 | std::array position; 2731 | position[0] = PyFloat_AsDouble(PyTuple_GetItem(current, 0)); 2732 | position[1] = PyFloat_AsDouble(PyTuple_GetItem(current, 1)); 2733 | out.push_back(position); 2734 | } 2735 | Py_DECREF(res); 2736 | 2737 | return out; 2738 | } 2739 | 2740 | // Actually, is there any reason not to call this automatically for every plot? 2741 | inline void tight_layout() { 2742 | detail::_interpreter::get(); 2743 | 2744 | PyObject *res = PyObject_CallObject( 2745 | detail::_interpreter::get().s_python_function_tight_layout, 2746 | detail::_interpreter::get().s_python_empty_tuple); 2747 | 2748 | if (!res) throw std::runtime_error("Call to tight_layout() failed."); 2749 | 2750 | Py_DECREF(res); 2751 | } 2752 | 2753 | // Support for variadic plot() and initializer lists: 2754 | 2755 | namespace detail { 2756 | 2757 | template 2758 | using is_function = typename std::is_function>>::type; 2759 | 2760 | template 2761 | struct is_callable_impl; 2762 | 2763 | template 2764 | struct is_callable_impl 2765 | { 2766 | typedef is_function type; 2767 | }; // a non-object is callable iff it is a function 2768 | 2769 | template 2770 | struct is_callable_impl 2771 | { 2772 | struct Fallback { void operator()(); }; 2773 | struct Derived : T, Fallback { }; 2774 | 2775 | template struct Check; 2776 | 2777 | template 2778 | static std::true_type test( ... 
); // use a variadic function to make sure (1) it accepts everything and (2) its always the worst match 2779 | 2780 | template 2781 | static std::false_type test( Check* ); 2782 | 2783 | public: 2784 | typedef decltype(test(nullptr)) type; 2785 | typedef decltype(&Fallback::operator()) dtype; 2786 | static constexpr bool value = type::value; 2787 | }; // an object is callable iff it defines operator() 2788 | 2789 | template 2790 | struct is_callable 2791 | { 2792 | // dispatch to is_callable_impl or is_callable_impl depending on whether T is of class type or not 2793 | typedef typename is_callable_impl::value, T>::type type; 2794 | }; 2795 | 2796 | template 2797 | struct plot_impl { }; 2798 | 2799 | template<> 2800 | struct plot_impl 2801 | { 2802 | template 2803 | bool operator()(const IterableX& x, const IterableY& y, const std::string& format) 2804 | { 2805 | detail::_interpreter::get(); 2806 | 2807 | // 2-phase lookup for distance, begin, end 2808 | using std::distance; 2809 | using std::begin; 2810 | using std::end; 2811 | 2812 | auto xs = distance(begin(x), end(x)); 2813 | auto ys = distance(begin(y), end(y)); 2814 | assert(xs == ys && "x and y data must have the same number of elements!"); 2815 | 2816 | PyObject* xlist = PyList_New(xs); 2817 | PyObject* ylist = PyList_New(ys); 2818 | PyObject* pystring = PyString_FromString(format.c_str()); 2819 | 2820 | auto itx = begin(x), ity = begin(y); 2821 | for(size_t i = 0; i < xs; ++i) { 2822 | PyList_SetItem(xlist, i, PyFloat_FromDouble(*itx++)); 2823 | PyList_SetItem(ylist, i, PyFloat_FromDouble(*ity++)); 2824 | } 2825 | 2826 | PyObject* plot_args = PyTuple_New(3); 2827 | PyTuple_SetItem(plot_args, 0, xlist); 2828 | PyTuple_SetItem(plot_args, 1, ylist); 2829 | PyTuple_SetItem(plot_args, 2, pystring); 2830 | 2831 | PyObject* res = PyObject_CallObject(detail::_interpreter::get().s_python_function_plot, plot_args); 2832 | 2833 | Py_DECREF(plot_args); 2834 | if(res) Py_DECREF(res); 2835 | 2836 | return res; 2837 | } 
2838 | }; 2839 | 2840 | template<> 2841 | struct plot_impl 2842 | { 2843 | template 2844 | bool operator()(const Iterable& ticks, const Callable& f, const std::string& format) 2845 | { 2846 | if(begin(ticks) == end(ticks)) return true; 2847 | 2848 | // We could use additional meta-programming to deduce the correct element type of y, 2849 | // but all values have to be convertible to double anyways 2850 | std::vector y; 2851 | for(auto x : ticks) y.push_back(f(x)); 2852 | return plot_impl()(ticks,y,format); 2853 | } 2854 | }; 2855 | 2856 | } // end namespace detail 2857 | 2858 | // recursion stop for the above 2859 | template 2860 | bool plot() { return true; } 2861 | 2862 | template 2863 | bool plot(const A& a, const B& b, const std::string& format, Args... args) 2864 | { 2865 | return detail::plot_impl::type>()(a,b,format) && plot(args...); 2866 | } 2867 | 2868 | /* 2869 | * This group of plot() functions is needed to support initializer lists, i.e. calling 2870 | * plot( {1,2,3,4} ) 2871 | */ 2872 | inline bool plot(const std::vector& x, const std::vector& y, const std::string& format = "") { 2873 | return plot(x,y,format); 2874 | } 2875 | 2876 | inline bool plot(const std::vector& y, const std::string& format = "") { 2877 | return plot(y,format); 2878 | } 2879 | 2880 | inline bool plot(const std::vector& x, const std::vector& y, const std::map& keywords) { 2881 | return plot(x,y,keywords); 2882 | } 2883 | 2884 | /* 2885 | * This class allows dynamic plots, ie changing the plotted data without clearing and re-plotting 2886 | */ 2887 | class Plot 2888 | { 2889 | public: 2890 | // default initialization with plot label, some data and format 2891 | template 2892 | Plot(const std::string& name, const std::vector& x, const std::vector& y, const std::string& format = "") { 2893 | detail::_interpreter::get(); 2894 | 2895 | assert(x.size() == y.size()); 2896 | 2897 | PyObject* kwargs = PyDict_New(); 2898 | if(name != "") 2899 | PyDict_SetItemString(kwargs, "label", 
PyString_FromString(name.c_str())); 2900 | 2901 | PyObject* xarray = detail::get_array(x); 2902 | PyObject* yarray = detail::get_array(y); 2903 | 2904 | PyObject* pystring = PyString_FromString(format.c_str()); 2905 | 2906 | PyObject* plot_args = PyTuple_New(3); 2907 | PyTuple_SetItem(plot_args, 0, xarray); 2908 | PyTuple_SetItem(plot_args, 1, yarray); 2909 | PyTuple_SetItem(plot_args, 2, pystring); 2910 | 2911 | PyObject* res = PyObject_Call(detail::_interpreter::get().s_python_function_plot, plot_args, kwargs); 2912 | 2913 | Py_DECREF(kwargs); 2914 | Py_DECREF(plot_args); 2915 | 2916 | if(res) 2917 | { 2918 | line= PyList_GetItem(res, 0); 2919 | 2920 | if(line) 2921 | set_data_fct = PyObject_GetAttrString(line,"set_data"); 2922 | else 2923 | Py_DECREF(line); 2924 | Py_DECREF(res); 2925 | } 2926 | } 2927 | 2928 | // shorter initialization with name or format only 2929 | // basically calls line, = plot([], []) 2930 | Plot(const std::string& name = "", const std::string& format = "") 2931 | : Plot(name, std::vector(), std::vector(), format) {} 2932 | 2933 | template 2934 | bool update(const std::vector& x, const std::vector& y) { 2935 | assert(x.size() == y.size()); 2936 | if(set_data_fct) 2937 | { 2938 | PyObject* xarray = detail::get_array(x); 2939 | PyObject* yarray = detail::get_array(y); 2940 | 2941 | PyObject* plot_args = PyTuple_New(2); 2942 | PyTuple_SetItem(plot_args, 0, xarray); 2943 | PyTuple_SetItem(plot_args, 1, yarray); 2944 | 2945 | PyObject* res = PyObject_CallObject(set_data_fct, plot_args); 2946 | if (res) Py_DECREF(res); 2947 | return res; 2948 | } 2949 | return false; 2950 | } 2951 | 2952 | // clears the plot but keep it available 2953 | bool clear() { 2954 | return update(std::vector(), std::vector()); 2955 | } 2956 | 2957 | // definitely remove this line 2958 | void remove() { 2959 | if(line) 2960 | { 2961 | auto remove_fct = PyObject_GetAttrString(line,"remove"); 2962 | PyObject* args = PyTuple_New(0); 2963 | PyObject* res = 
PyObject_CallObject(remove_fct, args); 2964 | if (res) Py_DECREF(res); 2965 | } 2966 | decref(); 2967 | } 2968 | 2969 | ~Plot() { 2970 | decref(); 2971 | } 2972 | private: 2973 | 2974 | void decref() { 2975 | if(line) 2976 | Py_DECREF(line); 2977 | if(set_data_fct) 2978 | Py_DECREF(set_data_fct); 2979 | } 2980 | 2981 | 2982 | PyObject* line = nullptr; 2983 | PyObject* set_data_fct = nullptr; 2984 | }; 2985 | 2986 | } // end namespace matplotlibcpp 2987 | --------------------------------------------------------------------------------