├── .gitignore ├── CustomTerminalCostMPC.m ├── LICENSE ├── LQRController.m ├── LinearSystem.m ├── MPC.m ├── NonlinearVehicle.m ├── README.md ├── Untitled.m ├── ValueFunctionApproximator.m ├── main_RLMPC.asv ├── main_RLMPC.m ├── plot_comparison_linear.m ├── plot_comparison_nonlinear.m ├── plot_comparison_nonlinear_extended.m └── simulate_system.m /.gitignore: -------------------------------------------------------------------------------- 1 | # Prerequisites 2 | *.d 3 | 4 | # Compiled Object files 5 | *.slo 6 | *.lo 7 | *.o 8 | *.obj 9 | 10 | # Precompiled Headers 11 | *.gch 12 | *.pch 13 | 14 | # Compiled Dynamic libraries 15 | *.so 16 | *.dylib 17 | *.dll 18 | 19 | # Fortran module files 20 | *.mod 21 | *.smod 22 | 23 | # Compiled Static libraries 24 | *.lai 25 | *.la 26 | *.a 27 | *.lib 28 | 29 | # Executables 30 | *.exe 31 | *.out 32 | *.app 33 | -------------------------------------------------------------------------------- /CustomTerminalCostMPC.m: -------------------------------------------------------------------------------- 1 | classdef CustomTerminalCostMPC < handle 2 | % CustomTerminalCostMPC: MPC implementation with custom terminal cost 3 | % This class implements MPC with a custom terminal cost function 4 | 5 | properties 6 | system % Dynamic system model 7 | horizon % Prediction horizon (N) 8 | terminal_cost_func % Terminal cost function handle 9 | last_solution % Store the last solution for warm start 10 | end 11 | 12 | methods 13 | function obj = CustomTerminalCostMPC(system, horizon, terminal_cost_func) 14 | % Constructor 15 | obj.system = system; 16 | obj.horizon = horizon; 17 | obj.terminal_cost_func = terminal_cost_func; 18 | obj.last_solution = []; 19 | end 20 | 21 | function [u_seq, cost_seq, x_traj] = solve(obj, x0) 22 | % Solve the MPC optimization problem 23 | % Returns the optimal control sequence, cost sequence, and state trajectory 24 | 25 | % Get system dimensions 26 | n = obj.system.n; % State dimension 27 | m = obj.system.m; % Input dimension 28 | N = obj.horizon; % Prediction horizon 29 | 30 | % Initialize optimization variables 31 | num_vars = m * N; % Number of control inputs over horizon 32 | 33 | % Initial state 34 | x_init = x0; 35 | 36 | % Define the objective function 37 | function [J, grad] = objective_function(U) 38 | % Reshape U vector to control sequence 39 | U_seq = reshape(U, [m, N]); 40 | 41 | % Initialize cost and state trajectory 42 | J = 0; 43 | x = x_init; 44 | x_traj_local = zeros(n, N+1); 45 | x_traj_local(:, 1) = x; 46 | 47 | % Initialize gradient if needed 48 | if nargout > 1 49 | grad = zeros(size(U)); 50 | end 51 | 52 | % Simulate the system over the horizon and compute cost 53 | cost_seq_local = zeros(1, N); 54 | for i = 1:N 55 | u_i = U_seq(:, i); 56 | 57 | % Stage cost 58 | stage_cost = obj.system.stage_cost(x, u_i); 59 | cost_seq_local(i) = stage_cost; 60 | J = J + stage_cost; 61 | 62 | % System dynamics 63 | x = obj.system.dynamics(x, u_i); 64 | x_traj_local(:, i+1) = x; 65 | end 66 | 67 | % Add terminal cost 68 | terminal_cost = obj.terminal_cost_func(x); 69 | J = J + terminal_cost; 70 | 71 | % We rely on numeric approximation for the gradient 72 | end 73 | 74 | % Define the constraints function for nonlinear constraints 75 | function [c, ceq, gradc, gradceq] = constraint_function(U) 76 | % No inequality constraints for the basic implementation 77 | c = []; 78 | gradc = []; 79 | 80 | % No equality constraints 81 | ceq = []; 82 | gradceq = []; 83 | end 84 | 85 | % Set optimization options (improved for convergence) 86 | options 
= optimoptions('fmincon', ... 87 | 'Display', 'off', ... 88 | 'Algorithm', 'sqp', ... 89 | 'MaxIterations', 500, ... % Increased from 200 90 | 'MaxFunctionEvaluations', 10000, ... % Increased from 5000 91 | 'OptimalityTolerance', 1e-3, ... % Relaxed from 1e-4 92 | 'ConstraintTolerance', 1e-3, ... % Relaxed from 1e-4 93 | 'StepTolerance', 1e-5, ... % Relaxed from 1e-6 94 | 'SpecifyObjectiveGradient', false, ... 95 | 'CheckGradients', false, ... 96 | 'ScaleProblem', 'obj-and-constr', ... % Added scaling 97 | 'HessianApproximation', 'bfgs'); % Added hessian approximation 98 | 99 | % Initial guess for control sequence 100 | if isempty(obj.last_solution) || length(obj.last_solution) ~= num_vars 101 | % If no previous solution or horizon changed, use zeros 102 | U0 = zeros(num_vars, 1); 103 | 104 | % For nonlinear vehicle, use a better initial guess 105 | if isa(obj.system, 'NonlinearVehicle') 106 | % Simple position-based heuristic for vehicle 107 | target = [0; 0; 0]; 108 | curr_pos = x_init(1:2); 109 | curr_theta = x_init(3); 110 | 111 | % Vector to target 112 | vec_to_target = target(1:2) - curr_pos; 113 | dist_to_target = norm(vec_to_target); 114 | 115 | if dist_to_target > 0.1 116 | % Desired heading angle 117 | desired_theta = atan2(vec_to_target(2), vec_to_target(1)); 118 | 119 | % Angular difference (shortest path) 120 | ang_diff = mod(desired_theta - curr_theta + pi, 2*pi) - pi; 121 | 122 | % Initial control: move toward target with appropriate turning 123 | v_init = min(0.5, dist_to_target); 124 | omega_init = 0.5 * ang_diff; 125 | 126 | % Clip to bounds 127 | v_init = max(min(v_init, obj.system.v_bounds(2)), obj.system.v_bounds(1)); 128 | omega_init = max(min(omega_init, obj.system.omega_bounds(2)), obj.system.omega_bounds(1)); 129 | 130 | % Set initial guess 131 | for i = 1:N 132 | U0((i-1)*m+1:(i-1)*m+m) = [v_init; omega_init]; 133 | end 134 | end 135 | end 136 | else 137 | % Warm start from last solution (shifted) 138 | U0 = [obj.last_solution(m+1:end); zeros(m, 1)]; 139 | end 140 | 141 | % Set bounds for control inputs (if applicable) 142 | lb = []; 143 | ub = []; 144 | 145 | % For nonlinear vehicle with control constraints 146 | if isa(obj.system, 'NonlinearVehicle') 147 | lb = repmat([obj.system.v_bounds(1); obj.system.omega_bounds(1)], [N, 1]); 148 | ub = repmat([obj.system.v_bounds(2); obj.system.omega_bounds(2)], [N, 1]); 149 | end 150 | 151 | % Try multiple optimization attempts with different options if needed 152 | exitflag = -99; 153 | U_opt = U0; 154 | 155 | % First attempt: SQP algorithm 156 | try 157 | [U_opt, ~, exitflag] = fmincon(@objective_function, U0, [], [], [], [], ... 158 | lb, ub, @constraint_function, options); 159 | catch 160 | warning('First optimization attempt failed. Trying fallback method.'); 161 | end 162 | 163 | % If first attempt failed, try with interior-point algorithm 164 | if exitflag <= 0 165 | options.Algorithm = 'interior-point'; 166 | options.HessianApproximation = 'bfgs'; 167 | options.InitBarrierParam = 0.1; % Added for interior-point 168 | options.InitTrustRegionRadius = 1; % Added for better convergence 169 | try 170 | [U_opt, ~, exitflag] = fmincon(@objective_function, U0, [], [], [], [], ... 171 | lb, ub, @constraint_function, options); 172 | catch 173 | warning('Second optimization attempt failed. 
Using best available solution.'); 174 | end 175 | 176 | % If that still fails, try active-set as last resort with relaxed tolerances 177 | if exitflag <= 0 178 | options.Algorithm = 'active-set'; 179 | options.OptimalityTolerance = 1e-2; % Further relaxed 180 | options.ConstraintTolerance = 1e-2; % Further relaxed 181 | try 182 | [U_opt, ~, exitflag] = fmincon(@objective_function, U0, [], [], [], [], ... 183 | lb, ub, @constraint_function, options); 184 | catch 185 | warning('Third optimization attempt failed. Using best available solution.'); 186 | end 187 | end 188 | end 189 | 190 | % Check if optimization was successful 191 | if exitflag <= 0 192 | warning('MPC optimization did not converge. Exitflag: %d', exitflag); 193 | % We'll still use the best solution found so far 194 | 195 | % For nonlinear vehicle, use a simple fallback strategy if optimization completely fails 196 | if isa(obj.system, 'NonlinearVehicle') && norm(U_opt - U0) < 1e-6 197 | % If solution is essentially unchanged from initial guess 198 | % Generate a simple control sequence to gradually slow down and stabilize 199 | for i = 1:N 200 | % Start with current guess and gradually reduce inputs 201 | decay = max(0, 1 - i/(N/2)); % Decay factor 202 | if i <= 2 203 | % Keep first couple inputs 204 | continue; 205 | elseif m == 2 % Assuming this is velocity and angular velocity 206 | % Gradually reduce velocity and steering 207 | U_opt((i-1)*m+1) = U_opt((i-1)*m+1) * decay; % Reduce velocity 208 | U_opt((i-1)*m+2) = U_opt((i-1)*m+2) * decay^2; % Reduce steering faster 209 | else 210 | % Generic reduction for other systems 211 | U_opt((i-1)*m+1:i*m) = U_opt((i-1)*m+1:i*m) * decay; 212 | end 213 | end 214 | warning('Using fallback control strategy due to optimization failure.'); 215 | end 216 | end 217 | 218 | % Store solution for warm start next time 219 | obj.last_solution = U_opt; 220 | 221 | % Extract optimal control sequence 222 | U_opt_seq = reshape(U_opt, [m, N]); 223 | u_seq = U_opt_seq; 224 | 225 | % Compute trajectory and cost sequence with optimal inputs 226 | x_traj = zeros(n, N+1); 227 | x_traj(:, 1) = x0; 228 | cost_seq = zeros(1, N); 229 | 230 | for i = 1:N 231 | % Stage cost 232 | cost_seq(i) = obj.system.stage_cost(x_traj(:, i), u_seq(:, i)); 233 | 234 | % System dynamics 235 | x_traj(:, i+1) = obj.system.dynamics(x_traj(:, i), u_seq(:, i)); 236 | end 237 | end 238 | end 239 | end -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 lmcgg 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LQRController.m: -------------------------------------------------------------------------------- 1 | classdef LQRController < handle 2 | % LQRController: Implementation of the Linear Quadratic Regulator controller 3 | 4 | properties 5 | K % Feedback gain matrix 6 | end 7 | 8 | methods 9 | function obj = LQRController(K) 10 | % Constructor 11 | obj.K = K; 12 | end 13 | 14 | function [u_seq, cost_seq, x_traj] = solve(obj, x0) 15 | % Apply LQR control law 16 | % For compatibility with the MPC interface 17 | 18 | % Just return the current control input for the current state 19 | u = -obj.K * x0; 20 | 21 | % Return single step of control input, cost, and state trajectory 22 | u_seq = u; 23 | cost_seq = x0' * eye(length(x0)) * x0 + u' * 0.5 * u; % Approximate cost 24 | x_traj = [x0, x0]; % Placeholder for state trajectory 25 | end 26 | end 27 | end -------------------------------------------------------------------------------- /LinearSystem.m: -------------------------------------------------------------------------------- 1 | classdef LinearSystem < handle 2 | % LinearSystem: Class implementing a discrete-time linear system 3 | 4 | properties 5 | A % System matrix 6 | B % Input matrix 7 | Q % State cost matrix 8 | R % Input cost matrix 9 | n % State dimension 10 | m % Input dimension 11 | x % Current state 12 | end 13 | 14 | methods 15 | function obj = LinearSystem(A, B, Q, R) 16 | % Constructor 17 | obj.A = A; 18 | obj.B = B; 19 | obj.Q = Q; 20 | obj.R = R; 21 | 22 | % Determine dimensions 23 | obj.n = size(A, 1); 24 | obj.m = size(B, 2); 25 | 26 | % Initialize state 27 | obj.x = zeros(obj.n, 1); 28 | end 29 | 30 | function x_next = dynamics(obj, x, u) 31 | % System dynamics: x_{k+1} = Ax_k + Bu_k 32 | x_next = obj.A * x + obj.B * u; 33 | end 34 | 35 | function cost = stage_cost(obj, x, u) 36 | % Stage cost: l(x, u) = x'Qx + u'Ru 37 | cost = x' * obj.Q * x + u' * obj.R * u; 38 | end 39 | 40 | function x_next = step(obj, u) 41 | % Take a step in the system with input u 42 | obj.x = obj.dynamics(obj.x, u); 43 | x_next = obj.x; 44 | end 45 | 46 | function x = get_state(obj) 47 | % Get current state 48 | x = obj.x; 49 | end 50 | 51 | function set_state(obj, x0) 52 | % Set state 53 | obj.x = x0; 54 | end 55 | 56 | function x0 = get_initial_state(obj) 57 | % Get initial state for simulations 58 | x0 = [1; 1]; % Example initial state 59 | end 60 | 61 | function samples = generate_samples(obj, num_samples) 62 | % Generate random state samples for policy evaluation 63 | % For the linear system, sample within a reasonable range 64 | samples = 2 * rand(obj.n, num_samples) - 1; % Range [-1, 1] 65 | end 66 | 67 | function is_stable = is_stable(obj) 68 | % Check if the system is stable 69 | eigenvalues = eig(obj.A); 70 | is_stable = all(abs(eigenvalues) < 1); 71 | end 72 | end 73 | end -------------------------------------------------------------------------------- /MPC.m: -------------------------------------------------------------------------------- 1 | classdef MPC < handle 2 | % MPC: Model Predictive Control implementation 3 | % This class implements MPC with the terminal cost derived from RLMPC 4 | 5 | properties 6 | system % Dynamic system 
model 7 | horizon % Prediction horizon (N) 8 | vfa % Value function approximator 9 | W % Value function weights 10 | has_terminal_cost % Flag to indicate if terminal cost is used 11 | last_solution % Store the last solution for warm start 12 | end 13 | 14 | methods 15 | function obj = MPC(system, horizon, vfa, W) 16 | % Constructor 17 | obj.system = system; 18 | obj.horizon = horizon; 19 | obj.vfa = vfa; 20 | obj.W = W; 21 | obj.last_solution = []; 22 | 23 | % Check if terminal cost is used 24 | if isempty(vfa) || isempty(W) 25 | obj.has_terminal_cost = false; 26 | else 27 | obj.has_terminal_cost = true; 28 | end 29 | end 30 | 31 | function [u_seq, cost_seq, x_traj] = solve(obj, x0) 32 | % Solve the MPC optimization problem 33 | % Returns the optimal control sequence, cost sequence, and state trajectory 34 | 35 | % Get system dimensions 36 | n = obj.system.n; % State dimension 37 | m = obj.system.m; % Input dimension 38 | N = obj.horizon; % Prediction horizon 39 | 40 | % Initialize optimization variables 41 | num_vars = m * N; % Number of control inputs over horizon 42 | 43 | % Initial state 44 | x_init = x0; 45 | 46 | % Define the objective function 47 | function [J, grad] = objective_function(U) 48 | % Reshape U vector to control sequence 49 | U_seq = reshape(U, [m, N]); 50 | 51 | % Initialize cost and state trajectory 52 | J = 0; 53 | x = x_init; 54 | x_traj_local = zeros(n, N+1); 55 | x_traj_local(:, 1) = x; 56 | 57 | % Initialize gradient if needed 58 | if nargout > 1 59 | grad = zeros(size(U)); 60 | end 61 | 62 | % Simulate the system over the horizon and compute cost 63 | cost_seq_local = zeros(1, N); 64 | for i = 1:N 65 | u_i = U_seq(:, i); 66 | 67 | % Stage cost 68 | stage_cost = obj.system.stage_cost(x, u_i); 69 | cost_seq_local(i) = stage_cost; 70 | J = J + stage_cost; 71 | 72 | % System dynamics 73 | x = obj.system.dynamics(x, u_i); 74 | x_traj_local(:, i+1) = x; 75 | end 76 | 77 | % Add terminal cost if available 78 | if obj.has_terminal_cost 79 | terminal_cost = obj.vfa.evaluate(x, obj.W); 80 | J = J + terminal_cost; 81 | end 82 | 83 | % We rely on numeric approximation for the gradient 84 | end 85 | 86 | % Define the constraints function for nonlinear constraints 87 | function [c, ceq, gradc, gradceq] = constraint_function(U) 88 | % No inequality constraints for the basic implementation 89 | c = []; 90 | gradc = []; 91 | 92 | % No equality constraints 93 | ceq = []; 94 | gradceq = []; 95 | end 96 | 97 | % Set optimization options (improved for convergence) 98 | options = optimoptions('fmincon', ... 99 | 'Display', 'off', ... 100 | 'Algorithm', 'sqp', ... 101 | 'MaxIterations', 500, ... % Increased from 200 102 | 'MaxFunctionEvaluations', 10000, ... % Increased from 5000 103 | 'OptimalityTolerance', 1e-3, ... % Relaxed from 1e-4 104 | 'ConstraintTolerance', 1e-3, ... % Relaxed from 1e-4 105 | 'StepTolerance', 1e-5, ... % Relaxed from 1e-6 106 | 'SpecifyObjectiveGradient', false, ... 107 | 'CheckGradients', false, ... 108 | 'ScaleProblem', 'obj-and-constr', ... 
% Added scaling 109 | 'HessianApproximation', 'bfgs'); % Added hessian approximation 110 | 111 | % Initial guess for control sequence 112 | if isempty(obj.last_solution) || length(obj.last_solution) ~= num_vars 113 | % If no previous solution or horizon changed, use zeros 114 | U0 = zeros(num_vars, 1); 115 | 116 | % For nonlinear vehicle, use a better initial guess 117 | if isa(obj.system, 'NonlinearVehicle') 118 | % Simple position-based heuristic for vehicle 119 | target = [0; 0; 0]; 120 | curr_pos = x_init(1:2); 121 | curr_theta = x_init(3); 122 | 123 | % Vector to target 124 | vec_to_target = target(1:2) - curr_pos; 125 | dist_to_target = norm(vec_to_target); 126 | 127 | if dist_to_target > 0.1 128 | % Desired heading angle 129 | desired_theta = atan2(vec_to_target(2), vec_to_target(1)); 130 | 131 | % Angular difference (shortest path) 132 | ang_diff = mod(desired_theta - curr_theta + pi, 2*pi) - pi; 133 | 134 | % Initial control: move toward target with appropriate turning 135 | v_init = min(0.5, dist_to_target); 136 | omega_init = 0.5 * ang_diff; 137 | 138 | % Clip to bounds 139 | v_init = max(min(v_init, obj.system.v_bounds(2)), obj.system.v_bounds(1)); 140 | omega_init = max(min(omega_init, obj.system.omega_bounds(2)), obj.system.omega_bounds(1)); 141 | 142 | % Set initial guess 143 | for i = 1:N 144 | U0((i-1)*m+1:(i-1)*m+m) = [v_init; omega_init]; 145 | end 146 | end 147 | end 148 | else 149 | % Warm start from last solution (shifted) 150 | U0 = [obj.last_solution(m+1:end); zeros(m, 1)]; 151 | end 152 | 153 | % Set bounds for control inputs (if applicable) 154 | lb = []; 155 | ub = []; 156 | 157 | % For nonlinear vehicle with control constraints 158 | if isa(obj.system, 'NonlinearVehicle') 159 | lb = repmat([obj.system.v_bounds(1); obj.system.omega_bounds(1)], [N, 1]); 160 | ub = repmat([obj.system.v_bounds(2); obj.system.omega_bounds(2)], [N, 1]); 161 | end 162 | 163 | % Try multiple optimization attempts with different options if needed 164 | exitflag = -99; 165 | U_opt = U0; 166 | 167 | % First attempt: SQP algorithm 168 | try 169 | [U_opt, ~, exitflag] = fmincon(@objective_function, U0, [], [], [], [], ... 170 | lb, ub, @constraint_function, options); 171 | catch 172 | warning('First optimization attempt failed. Trying fallback method.'); 173 | end 174 | 175 | % If first attempt failed, try with interior-point algorithm 176 | if exitflag <= 0 177 | options.Algorithm = 'interior-point'; 178 | options.HessianApproximation = 'bfgs'; 179 | options.InitBarrierParam = 0.1; % Added for interior-point 180 | options.InitTrustRegionRadius = 1; % Added for better convergence 181 | try 182 | [U_opt, ~, exitflag] = fmincon(@objective_function, U0, [], [], [], [], ... 183 | lb, ub, @constraint_function, options); 184 | catch 185 | warning('Second optimization attempt failed. Using best available solution.'); 186 | end 187 | 188 | % If that still fails, try active-set as last resort with relaxed tolerances 189 | if exitflag <= 0 190 | options.Algorithm = 'active-set'; 191 | options.OptimalityTolerance = 1e-2; % Further relaxed 192 | options.ConstraintTolerance = 1e-2; % Further relaxed 193 | try 194 | [U_opt, ~, exitflag] = fmincon(@objective_function, U0, [], [], [], [], ... 195 | lb, ub, @constraint_function, options); 196 | catch 197 | warning('Third optimization attempt failed. Using best available solution.'); 198 | end 199 | end 200 | end 201 | 202 | % Check if optimization was successful 203 | if exitflag <= 0 204 | warning('MPC optimization did not converge. 
Exitflag: %d', exitflag); 205 | % We'll still use the best solution found so far 206 | 207 | % For nonlinear vehicle, use a simple fallback strategy if optimization completely fails 208 | if isa(obj.system, 'NonlinearVehicle') && norm(U_opt - U0) < 1e-6 209 | % If solution is essentially unchanged from initial guess 210 | % Generate a simple control sequence to gradually slow down and stabilize 211 | for i = 1:N 212 | % Start with current guess and gradually reduce inputs 213 | decay = max(0, 1 - i/(N/2)); % Decay factor 214 | if i <= 2 215 | % Keep first couple inputs 216 | continue; 217 | elseif m == 2 % Assuming this is velocity and angular velocity 218 | % Gradually reduce velocity and steering 219 | U_opt((i-1)*m+1) = U_opt((i-1)*m+1) * decay; % Reduce velocity 220 | U_opt((i-1)*m+2) = U_opt((i-1)*m+2) * decay^2; % Reduce steering faster 221 | else 222 | % Generic reduction for other systems 223 | U_opt((i-1)*m+1:i*m) = U_opt((i-1)*m+1:i*m) * decay; 224 | end 225 | end 226 | warning('Using fallback control strategy due to optimization failure.'); 227 | end 228 | end 229 | 230 | % Store solution for warm start next time 231 | obj.last_solution = U_opt; 232 | 233 | % Extract optimal control sequence 234 | U_opt_seq = reshape(U_opt, [m, N]); 235 | u_seq = U_opt_seq; 236 | 237 | % Compute trajectory and cost sequence with optimal inputs 238 | x_traj = zeros(n, N+1); 239 | x_traj(:, 1) = x0; 240 | cost_seq = zeros(1, N); 241 | 242 | for i = 1:N 243 | % Stage cost 244 | cost_seq(i) = obj.system.stage_cost(x_traj(:, i), u_seq(:, i)); 245 | 246 | % System dynamics 247 | x_traj(:, i+1) = obj.system.dynamics(x_traj(:, i), u_seq(:, i)); 248 | end 249 | end 250 | end 251 | end -------------------------------------------------------------------------------- /NonlinearVehicle.m: -------------------------------------------------------------------------------- 1 | classdef NonlinearVehicle < handle 2 | % NonlinearVehicle: Class implementing a non-holonomic vehicle system 3 | % States: [x, y, θ]' where (x,y) is position and θ is orientation 4 | % Inputs: [v, ω]' where v is linear velocity and ω is angular velocity 5 | 6 | properties 7 | x_bounds % State bounds [min, max] 8 | v_bounds % Linear velocity bounds [min, max] 9 | omega_bounds % Angular velocity bounds [min, max] 10 | dt % Time step 11 | Q % State cost matrix 12 | R % Input cost matrix 13 | x % Current state [x, y, θ]' 14 | n % State dimension 15 | m % Input dimension 16 | y_bounds % Y position bounds [min, max] 17 | theta_bounds % Orientation bounds [min, max] 18 | end 19 | 20 | methods 21 | function obj = NonlinearVehicle(x_bounds, v_bounds, omega_bounds) 22 | % Constructor 23 | obj.x_bounds = x_bounds; 24 | obj.v_bounds = v_bounds; 25 | obj.omega_bounds = omega_bounds; 26 | obj.dt = 0.1; % 100ms time step 27 | 28 | % State and input dimensions 29 | obj.n = 3; % [x, y, θ] 30 | obj.m = 2; % [v, ω] 31 | 32 | % Default bounds for y and theta 33 | obj.y_bounds = x_bounds; % Same bounds as x by default 34 | obj.theta_bounds = [-pi, pi]; % Full orientation range 35 | 36 | % Cost matrices - adjusted to better balance state vs input costs 37 | obj.Q = diag([1, 1, 0.2]); % State cost 38 | obj.R = diag([0.05, 0.01]); % Input cost 39 | 40 | % Initialize state 41 | obj.x = zeros(obj.n, 1); 42 | end 43 | 44 | function x_next = dynamics(obj, x, u) 45 | % Non-holonomic vehicle dynamics 46 | % x_{k+1} = x_k + g(x_k)u_k 47 | 48 | % Extract states and inputs 49 | theta = x(3); 50 | v = u(1); % Linear velocity 51 | omega = u(2); % Angular velocity 52 | 53 | % 
Apply input constraints with smooth saturation
            v = obj.smooth_saturation(v, obj.v_bounds(1), obj.v_bounds(2));
            omega = obj.smooth_saturation(omega, obj.omega_bounds(1), obj.omega_bounds(2));

            % Improve numerical integration for better accuracy
            % Use Runge-Kutta method (RK4) for more accurate dynamics
            k1 = obj.vehicle_ode(x, [v; omega]);
            k2 = obj.vehicle_ode(x + obj.dt/2 * k1, [v; omega]);
            k3 = obj.vehicle_ode(x + obj.dt/2 * k2, [v; omega]);
            k4 = obj.vehicle_ode(x + obj.dt * k3, [v; omega]);

            % Update state using RK4 integration
            x_next = x + obj.dt/6 * (k1 + 2*k2 + 2*k3 + k4);

            % Apply state constraints with smooth saturation
            x_next(1) = obj.smooth_saturation(x_next(1), obj.x_bounds(1), obj.x_bounds(2));
            x_next(2) = obj.smooth_saturation(x_next(2), obj.y_bounds(1), obj.y_bounds(2));

            % Normalize angle to [-π, π]
            x_next(3) = mod(x_next(3) + pi, 2*pi) - pi;
        end

        function dxdt = vehicle_ode(obj, x, u)
            % ODE function for vehicle dynamics
            theta = x(3);
            v = u(1);
            omega = u(2);

            dxdt = [v * cos(theta);
                    v * sin(theta);
                    omega];
        end

        function val = smooth_saturation(obj, val, min_val, max_val)
            % Smooth saturation function to avoid discontinuities in gradient
            % Uses tanh-based soft saturation inside a small boundary layer.
            % The blending starts exactly at the edge of the buffer zone, so
            % the map is continuous there and the output stays strictly
            % within [min_val, max_val].
            buffer = 0.01 * (max_val - min_val);
            if val > max_val - buffer
                val = (max_val - buffer) + buffer * tanh((val - (max_val - buffer)) / buffer);
            elseif val < min_val + buffer
                val = (min_val + buffer) - buffer * tanh(((min_val + buffer) - val) / buffer);
            end
        end

        function cost = stage_cost(obj, x, u)
            % Calculate stage cost
            % Penalize deviation from target (origin) and control effort

            % Get target state (origin)
            x_target = [0; 0; 0];

            % Calculate error
            error = x - x_target;

            % Normalize angle error to [-π, π]
            error(3) = mod(error(3) + pi, 2*pi) - pi;

            % Stage cost with an additional smoothness term; this is active
            % only when a sequence of inputs (one column per step) is passed
            % in. For a single m-by-1 input the term is zero.
            if size(u, 2) > 1
                u_prev = u(:, 1);
                u_curr = u(:, end);
                smoothness_cost = 0.01 * norm(u_curr - u_prev)^2;
            else
                smoothness_cost = 0;
            end

            cost = error' * obj.Q * error + u' * obj.R * u + smoothness_cost;
        end

        function x_next = step(obj, u)
            % Take a step in the system with input u
            obj.x = obj.dynamics(obj.x, u);
            x_next = obj.x;
        end

        function x = get_state(obj)
            % Get current state
            x = obj.x;
        end

        function set_state(obj, x0)
            % Set state
            obj.x = x0;
        end

        function x0 = get_initial_state(obj)
            % Get initial state for simulations
            x0 = [1.5; 1.5; pi/4]; % Example initial state
        end

        function samples = generate_samples(obj, num_samples)
            % Generate random state samples for policy evaluation

            % Position samples within bounds
            x_samples = obj.x_bounds(1) + (obj.x_bounds(2) - obj.x_bounds(1)) * rand(1, num_samples);
            y_samples = obj.y_bounds(1) + (obj.y_bounds(2) - obj.y_bounds(1)) * rand(1, num_samples);

            % Orientation samples between -pi and pi
            theta_samples = 2 * pi * rand(1, num_samples) - pi;

            % Combine into state samples
            samples = [x_samples; y_samples; theta_samples];

            % Add some samples near the target for better convergence
            num_near_target = min(50, floor(num_samples/4));
            if num_near_target > 0
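                % Oversampling states near the origin concentrates the
                % value-function fit where trajectories converge; purely
                % uniform samples would be dominated by states far from the
                % target and leave the terminal cost poorly resolved there.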
target_vicinity = 0.5; % Range around target 160 | samples(:, 1:num_near_target) = [ 161 | target_vicinity * (2*rand(1, num_near_target) - 1); 162 | target_vicinity * (2*rand(1, num_near_target) - 1); 163 | pi * (2*rand(1, num_near_target) - 1) 164 | ]; 165 | end 166 | end 167 | 168 | function [A, B] = linearize(obj, x, u) 169 | % Linearize the system at a given point (x, u) 170 | % Returns Jacobians A = df/dx, B = df/du 171 | 172 | % Extract states and inputs 173 | theta = x(3); 174 | v = u(1); 175 | 176 | % Jacobian with respect to state (discretized) 177 | A = eye(3) + obj.dt * [0, 0, -v*sin(theta); 178 | 0, 0, v*cos(theta); 179 | 0, 0, 0]; 180 | 181 | % Jacobian with respect to input (discretized) 182 | B = obj.dt * [cos(theta), 0; 183 | sin(theta), 0; 184 | 0, 1]; 185 | end 186 | end 187 | end -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RLMPC - Reinforcement Learning Model Predictive Control 2 | 3 | This MATLAB implementation demonstrates the RLMPC (Reinforcement Learning Model Predictive Control) framework, which combines Reinforcement Learning (RL) with Model Predictive Control (MPC) through Policy Iteration (PI). 4 | 5 | ## Overview 6 | 7 | The RLMPC framework integrates RL's ability to learn optimal policies from data with MPC's optimization-based control strategy. The core innovation is using RL to learn an optimal terminal cost function for MPC, which improves control performance without requiring explicit terminal constraints or conditions. 8 | 9 | 10 | ### Key Features 11 | 12 | - **Policy Iteration Framework**: Alternates between policy generation (MPC) and policy evaluation (RL) 13 | - **Online Learning**: Learns the terminal cost function while controlling the system 14 | - **Polynomial Basis Functions**: Approximates the value function using polynomial features 15 | - **System Support**: Implementations for both linear systems and nonlinear non-holonomic vehicle systems 16 | - **Comparative Analysis**: Performance benchmarks against standard LQR and MPC approaches 17 | 18 | ## Algorithm 19 | 20 | The RLMPC algorithm implements the following key steps: 21 | 22 | 1. **Initialization**: Start with a zero terminal cost function 23 | 2. **Policy Generation**: Solve MPC optimization with current value function as terminal cost 24 | 3. **Policy Evaluation**: Sample states and evaluate current policy to learn value function 25 | 4. **Policy Improvement**: Update value function weights and use as new terminal cost 26 | 5. 
**Iterate**: Repeat until convergence or for a fixed number of iterations

## System Requirements

- MATLAB R2019b or later
- Optimization Toolbox
- Statistics and Machine Learning Toolbox (optional, for some advanced features)

## Repository Structure

- `main_RLMPC.m`: Main script to run RLMPC simulations
- `LinearSystem.m`: Class implementing linear system dynamics and costs
- `NonlinearVehicle.m`: Class implementing nonlinear non-holonomic vehicle dynamics
- `ValueFunctionApproximator.m`: Class for polynomial basis function approximation
- `MPC.m`: Model Predictive Control implementation with learned terminal cost
- `CustomTerminalCostMPC.m`: MPC implementation with manually specified terminal cost
- `LQRController.m`: Linear Quadratic Regulator controller for comparison
- `simulate_system.m`: Function to simulate systems with various controllers
- Visualization utilities:
  - `plot_comparison_linear.m`: Visualization for linear system results
  - `plot_comparison_nonlinear.m`: Visualization for nonlinear system results
  - `plot_comparison_nonlinear_extended.m`: Enhanced visualization for nonlinear systems

## Usage

To run the simulations:

```matlab
% Run the main script
main_RLMPC
```

You can select the system type by modifying the `system_type` variable in the main script:

```matlab
system_type = 'linear';    % For linear system
% OR
system_type = 'nonlinear'; % For non-holonomic vehicle
```

## Configuration Parameters

### Linear System
- System matrices: A = [1, 0.5; 0.1, 0.5], B = [1; 0]
- Cost matrices: Q = eye(2), R = 0.5
- Prediction horizon: N = 3
- Polynomial basis order: 2
- Learning rate (alpha): 1e-2
- Regularization (lambda): 0.01

### Nonlinear Vehicle System
- State: [x, y, θ] (position and orientation)
- Input: [v, ω] (linear and angular velocity)
- Input constraints: |v| ≤ 1, |ω| ≤ 4
- State constraints: x, y bounded within [0, 2]
- Prediction horizon: N = 5 (RLMPC), N_mpc = 20 (traditional MPC)
- Polynomial basis order: 4 (35 features)
- Learning rate (alpha): 2e-3
- Regularization (lambda): 0.01
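The feature counts above follow from a standard combinatorial identity: a polynomial basis of total degree at most `d` over an `n`-dimensional state contains `nchoosek(n + d, d)` monomials, including the constant term. A quick sanity check in MATLAB (not part of the repository code):

```matlab
% Feature count of a degree-<=d polynomial basis in n variables
nchoosek(3 + 4, 4)   % nonlinear vehicle: n = 3, order 4 -> 35
nchoosek(3 + 6, 6)   % order 6, hinted at in main_RLMPC.m -> 84
nchoosek(2 + 2, 2)   % linear system:     n = 2, order 2 -> 6
```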
## Implementation Details

### Learning Process
The implementation uses a hybrid approach for learning:
- **Online Learning**: The main loop implements online learning where the value function is updated at each time step as the system evolves
- **Stochastic Gradient Descent**: SGD is used to update the value function weights (a sketch of one update follows below)
- **Gradient Clipping**: Prevents exploding gradients for stable learning
- **L2 Regularization**: Prevents overfitting and stabilizes learning
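As a rough sketch of one such update (condensed from `main_RLMPC.m`; `cost_seq_sample` and `x_traj_sample` denote the stage costs and predicted trajectory returned by the MPC solve at a sampled state `x_j_sample`, and `vfa`, `W`, `alpha_sgd`, `lambda_reg_sgd` are as configured in the script):

```matlab
% One SGD step on the value-function weights W for a sampled state x_j_sample
phi_j    = vfa.get_features(x_j_sample);                 % polynomial features
J_target = sum(cost_seq_sample) ...                      % N-step MPC cost
         + W' * vfa.get_features(x_traj_sample(:, end)); % + bootstrapped tail
td_error = J_target - W' * phi_j;                        % TD error
delta_W  = alpha_sgd * td_error * phi_j ...
         - alpha_sgd * lambda_reg_sgd * W;               % gradient step + L2
max_grad_norm = 1.0;                                     % clip large steps
if norm(delta_W) > max_grad_norm
    delta_W = delta_W * (max_grad_norm / norm(delta_W));
end
W = W + delta_W;
```

The clip is applied to the step before the weights are updated, so a single badly scaled sample cannot destabilize the fit.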
### MPC Implementation
The MPC controller is implemented with several features:
- **Multiple Solvers**: Attempts different algorithms (SQP, interior-point, active-set) if optimization fails
- **Warm Starting**: Uses the previous solution to initialize the new optimization
- **Robust Constraints**: Handles both state and input constraints
- **Adaptive Terminal Cost**: Incorporates the learned value function as terminal cost

## Experimental Results

The implementation produces several visualizations:
- State trajectories
- Control inputs
- Accumulated costs
- Policy iteration convergence
- Weight convergence during learning

![image](https://github.com/user-attachments/assets/cd22c41d-de86-47bf-b812-0002d1ee8bbe)
![image](https://github.com/user-attachments/assets/f67f5a30-45d3-4b03-9cfd-d15304e60a09)
![image](https://github.com/user-attachments/assets/2bb0f49d-c23f-4b60-8af9-d2428c95f604)
![image](https://github.com/user-attachments/assets/cbba04b0-5883-491d-a244-7cd2a416ff16)

Performance Metrics (Nonlinear Vehicle with MPC-TC):
- RLMPC Total Cost: 38.3546
- MPC-Long Total Cost: 37.7445
- MPC w/o TC Total Cost: 41.2104
- MPC-TC Total Cost: 38.3546
- RLMPC vs MPC-Long improvement: -1.6164%
- RLMPC vs MPC w/o TC improvement: 6.9299%
- RLMPC vs MPC-TC improvement: 0%
- MPC-TC vs MPC w/o TC improvement: 6.9299%

### Linear System Results
For linear systems, RLMPC is compared with:
- Standard LQR controller
- MPC without terminal cost
- MPC with manually tuned terminal cost

### Nonlinear System Results
For nonlinear systems, RLMPC is compared with:
- Long-horizon MPC (N_mpc = 20)
- Short-horizon MPC without terminal cost (N = 5)
- MPC with manually tuned terminal cost (N = 5)

## Citation

This code is based on "Reinforcement Learning-Based Model Predictive Control for Discrete-Time Systems," IEEE Transactions on Neural Networks and Learning Systems, vol. 35, no. 3, March 2024.

This code is a simple test of the paper's method, so it may well contain mistakes. If you find an error or see a possible improvement, I would be glad if you reported it or changed the code.

## Star History

[![Star History Chart](https://api.star-history.com/svg?repos=lmcggg/RL-based-MPC-for-dts&type=Date)](https://www.star-history.com/#lmcggg/RL-based-MPC-for-dts&Date)

## License

This project is licensed under the MIT License - see the LICENSE file for details.

## Acknowledgments

- This implementation is based on the RLMPC framework described in relevant literature
- Special thanks to contributors and researchers in the field of learning-based control
--------------------------------------------------------------------------------
/Untitled.m:
--------------------------------------------------------------------------------
plot_comparison_nonlinear(x_rlmpc, x_mpc_long, x_mpc_wtc, u_rlmpc, u_mpc_long, u_mpc_wtc, ...
    total_cost_rlmpc, total_cost_mpc_long, total_cost_mpc_wtc, ...
    W_history, cost_history, iter_history);
% Assume w is a 35x200 double array
% w = rand(35, 200); % example data; replace with your own data in actual use

% Create the time axis
time = 0:size(W_history, 2) - 1;

% Plot the history of each variable
figure;
hold on; % hold the figure so that multiple curves can be drawn
for i = 1:size(W_history, 1)
    plot(time, W_history(i, :), 'DisplayName', ['Variable ', num2str(i)]);
end
hold off;

% Add legend and labels
legend('show');
xlabel('Time');
ylabel('Value');
title('History of 35 Variables');
grid on; % add grid
--------------------------------------------------------------------------------
/ValueFunctionApproximator.m:
--------------------------------------------------------------------------------
classdef ValueFunctionApproximator < handle
    % ValueFunctionApproximator: Class for approximating the value function
    % using polynomial basis functions

    properties
        basis_order      % Maximum total degree of the polynomial basis
        state_dim        % Dimension of the state vector
        num_features     % Number of basis functions (monomials)
        feature_indices  % Exponent matrix: one row of exponents per feature
    end

    methods
        function obj = ValueFunctionApproximator(basis_order, state_dim)
            % Constructor
            obj.basis_order = basis_order;
            obj.state_dim = state_dim;

            obj.feature_indices = obj.generate_feature_indices();
            obj.num_features = size(obj.feature_indices, 1);
        end

        function indices = generate_feature_indices(obj)
            % Generate the exponent vectors of all monomials with total
            % degree from 0 up to basis_order. The first row (all zeros)
            % is the constant feature.
            indices = zeros(1, obj.state_dim);

            for order = 1:obj.basis_order
                new_indices = obj.generate_combinations_of_order(order);
                indices = [indices; new_indices];
            end
        end

        function combinations = generate_combinations_of_order(obj, order)
            % Enumerate all exponent vectors whose entries sum to the given
            % total order (combinations with repetition over state indices).
            combinations = [];

            function recurse_combinations(curr_comb, remain_order, start_idx)
                if remain_order == 0
                    combinations = [combinations; curr_comb];
                    return;
                end

                for i = start_idx:obj.state_dim
                    new_comb = curr_comb;
                    new_comb(i) = new_comb(i) + 1;
                    recurse_combinations(new_comb, remain_order - 1, i);
                end
            end

            recurse_combinations(zeros(1, obj.state_dim), order, 1);
        end

        function phi = get_features(obj, x)
            % Compute the feature vector Φ(x): one monomial per row of
            % feature_indices, evaluated at state x.
            x = x(:);

            phi = ones(obj.num_features, 1);

            for i = 2:obj.num_features
                term = 1;
                for j = 1:obj.state_dim
                    if obj.feature_indices(i, j) > 0
                        term = term * x(j)^obj.feature_indices(i, j);
                    end
                end
                phi(i) = term;
            end
        end

        function value = evaluate(obj, x, W)
            % Evaluate the value function at state x with weights W
            % V(x) = W'Φ(x)
            phi = obj.get_features(x);
            value = W' * phi;
        end

        function n = get_num_features(obj)
            % Return the number of basis functions
            n = obj.num_features;
        end
    end
end
--------------------------------------------------------------------------------
/main_RLMPC.asv:
--------------------------------------------------------------------------------
%% RLMPC - Reinforcement Learning Model Predictive Control
% This script implements the RLMPC framework by combining RL and MPC


clear; clc; close all;

%% Select system type
system_type = 'nonlinear'; % Options: 'linear', 'nonlinear'

%% Parameters
% Common parameters
num_samples = 500; % Number of samples for policy evaluation

%% System setup based on type
if strcmp(system_type, 'linear')
    % Linear system parameters
    A = [1, 0.5; 0.1, 0.5];
    B = [1; 0];
    Q = eye(2);
    R = 0.5;
    N = 3; % Prediction horizon

    % Define polynomial basis functions
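    % For an n-dimensional state and basis order d, the approximator builds
    % nchoosek(n + d, d) monomial features (here n = 2, d = 2 gives 6).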
basis_order = 2; 25 | alpha = 1e-2; % 26 | lambda_reg = 0.01; % Added for consistency 27 | 28 | 29 | sys = LinearSystem(A, B, Q, R); 30 | 31 | [K_lqr, P_lqr] = dlqr(A, B, Q, R); 32 | 33 | 34 | P_terminal = 5 * eye(2); 35 | 36 | elseif strcmp(system_type, 'nonlinear') 37 | % Nonlinear system parameters (non-holonomic vehicle) 38 | N = 5; % Prediction horizon for RLMPC 39 | N_mpc = 20; 40 | 41 | % State and input constraints 42 | x_bounds = [0, 2]; 43 | y_bounds = x_bounds; % Assuming symmetric bounds for y 44 | v_bounds = [-1, 1]; 45 | omega_bounds = [-4, 4]; 46 | 47 | 48 | basis_order = 4; 49 | % Original: 4 (35 features), Paper's hint: 6 (84 features) 50 | 51 | 52 | alpha = 2e-3; % More reasonable alpha since we'll average the gradients 53 | 54 | 55 | lambda_reg = 0.01; 56 | 57 | % Initialize system 58 | sys = NonlinearVehicle(x_bounds, v_bounds, omega_bounds); 59 | sys.y_bounds = y_bounds; % Explicitly set if not handled by constructor 60 | 61 | % Define terminal cost matrix for MPC with terminal cost 62 | % Using a higher weight than the stage cost Q for better terminal behavior 63 | P_terminal = diag([5, 5, 1]); % Higher weights on position than orientation 64 | else 65 | error('Invalid system type. Choose "linear" or "nonlinear".'); 66 | end 67 | 68 | %% Initialize value function approximation 69 | vfa = ValueFunctionApproximator(basis_order, size(sys.get_state(), 1)); 70 | 71 | %% RLMPC Algorithm - Online Learning (Closer to Algorithm 1) 72 | disp('Starting RLMPC online learning and simulation...'); 73 | 74 | % Simulation parameters 75 | sim_steps_total = 200; % Total steps for the system to run (learning + execution) 76 | learning_phase_steps = 100; % Number of steps during which learning (W update) happens 77 | % This value needs tuning. It's an analogue to 78 | % how long Flag == 0 in Algorithm 1. 79 | % After this, W is fixed. 
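                         % In addition, the loop below may stop learning earlier
                         % if W stagnates (see epsilon_sgd_W_change and
                         % max_consecutive_no_W_change), which plays the role
                         % of setting Flag = 1 in Algorithm 1.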
80 | 81 | % Initialize system 82 | sys.set_state(sys.get_initial_state()); % Set initial state of the system 83 | x_k = sys.get_state(); 84 | 85 | % History for plotting actual system trajectory during learning/execution 86 | x_history_actual = zeros(size(sys.get_state(), 1), sim_steps_total + 1); 87 | if strcmp(system_type, 'linear') 88 | input_dim = size(B, 2); % For linear system, input dimension is columns of B 89 | else 90 | input_dim = sys.m; % For nonlinear system, assume sys.m exists 91 | end 92 | u_history_actual = zeros(input_dim, sim_steps_total); 93 | cost_history_actual = zeros(1, sim_steps_total); 94 | x_history_actual(:, 1) = x_k; 95 | 96 | W = zeros(vfa.get_num_features(), 1); % Initial weights 97 | W_history_online = zeros(vfa.get_num_features(), sim_steps_total); % Store W at each system step k 98 | 99 | % MPC controller instance (reused) 100 | mpc_learner = MPC(sys, N, vfa, W); % N is the MPC horizon for RLMPC 101 | 102 | % Learning control parameters 103 | alpha_sgd = 1e-5; % SGD learning rate, LIKELY NEEDS TUNING - smaller than batch alpha 104 | lambda_reg_sgd = 0.001; % Regularization for SGD 105 | epsilon_sgd_W_change = 1e-6; % Convergence threshold for W change if we want early stop 106 | max_sgd_updates_per_k = num_samples; % How many SGD steps per system step k on samples from Sk 107 | % Algorithm 1 implies Sk is processed, so this should be num_samples 108 | 109 | learning_flag = true; % Corresponds to Flag == 0 in Algorithm 1 initially 110 | consecutive_no_W_change_count = 0; 111 | max_consecutive_no_W_change = 5; % Stop learning if W doesn't change much for this many k steps 112 | 113 | for k = 1:sim_steps_total 114 | fprintf('System Step k = %d/%d\n', k, sim_steps_total); 115 | W_history_online(:, k) = W; % Store current W 116 | 117 | % --- Policy Generation (MPC solve) --- 118 | mpc_learner.W = W; % Update MPC with current W (from W_k to generate policy pi_k) 119 | try 120 | [u_sequence_k, cost_sequence_k, x_predicted_traj_k] = mpc_learner.solve(x_k); 121 | u_0_k = u_sequence_k(:, 1); % First control input 122 | catch e 123 | warning('MPC solve failed at system step k=%d: %s. Using zero input.', k, e.message); 124 | u_0_k = zeros(sys.m, 1); % Fallback control 125 | end 126 | 127 | % Apply u_0_k to the actual system 128 | actual_cost_k = sys.stage_cost(x_k, u_0_k); 129 | x_k_plus_1 = sys.step(u_0_k); % System moves to next state 130 | 131 | % Store actual system evolution 132 | u_history_actual(:, k) = u_0_k; 133 | cost_history_actual(k) = actual_cost_k; 134 | x_history_actual(:, k+1) = x_k_plus_1; 135 | 136 | % --- Policy Evaluation and W Update (if in learning phase) --- 137 | if learning_flag && k <= learning_phase_steps 138 | fprintf(' Learning phase: Updating W...\n'); 139 | W_before_update_this_k = W; 140 | 141 | % 1. Sample Sk (as in your original code) 142 | % It's better if Sk is sampled fresh or partially updated each k 143 | % For simplicity here, let's resample it. 144 | samples_Sk = sys.generate_samples(num_samples); % num_samples defined earlier 145 | 146 | for s_idx = 1:num_samples % Iterate through each sample in Sk for SGD 147 | x_j_sample = samples_Sk(:, s_idx); 148 | 149 | % 2. 
For each x_j in Sk, generate training data J'(x_j, u_j*)
            % This J' is calculated using the SAME W that was used for u_0_k (i.e., W before this inner loop's updates)
            % Or, more aligned with PI, J' should be for policy pi_k (based on W at start of step k)
            % So, mpc_learner.W should be W_before_update_this_k
            mpc_temp_eval = MPC(sys, N, vfa, W_before_update_this_k); % Use W from start of step k for J' calc
            try
                [u_seq_sample, cost_seq_sample, x_traj_sample] = mpc_temp_eval.solve(x_j_sample);

                % Calculate J_target = sum(costs) + V_pi_k(x_N_sample)
                % V_pi_k(x_N_sample) uses W_before_update_this_k
                terminal_value_approx_sample = W_before_update_this_k' * vfa.get_features(x_traj_sample(:, end));
                J_target_for_sample_j = sum(cost_seq_sample) + terminal_value_approx_sample;

                % 3. SGD Update for W for this sample x_j_sample
                phi_j = vfa.get_features(x_j_sample); % Features for the sampled state x_j

                % TD error: J_target is for policy pi_k, W is current estimate V_pi_k_approx
                % We want W to approximate J_target
                td_error = J_target_for_sample_j - (W' * phi_j); % W is W_{t} being updated to W_{t+1}

                % SGD update rule (Eq. 16)
                delta_W_sgd = alpha_sgd * td_error * phi_j;

                % Add regularization (optional, but often good)
                regularization_term_sgd = alpha_sgd * lambda_reg_sgd * W;
                delta_W_sgd = delta_W_sgd - regularization_term_sgd;

                % Gradient norm clipping (helps prevent single large updates);
                % applied BEFORE the weight update so it actually limits the step
                max_grad_norm = 1.0; % More conservative clipping threshold
                current_grad_norm = norm(delta_W_sgd);
                if current_grad_norm > max_grad_norm
                    delta_W_sgd = delta_W_sgd * (max_grad_norm / current_grad_norm);
                    fprintf('  Gradient norm clipped from %e to %e\n', current_grad_norm, max_grad_norm);
                end

                W = W + delta_W_sgd; % Update W immediately

                % Clipping W as a last resort for stability
                W = min(max(W, -100.0), 100.0); % Stricter clipping range

            catch e_sgd
                % warning('MPC solve for SGD sample %d in Sk (k=%d) failed: %s', s_idx, k, e_sgd.message);
                % Skip update for this sample if MPC fails
            end
        end % End of SGD loop over Sk

        % Check for W convergence or stagnation
        w_change_norm_this_k = norm(W - W_before_update_this_k);
        fprintf('  Norm of W update (delta_W_total_for_k): %e, Norm of W: %e\n', ...
            w_change_norm_this_k, norm(W));
        if w_change_norm_this_k < epsilon_sgd_W_change
            consecutive_no_W_change_count = consecutive_no_W_change_count + 1;
            if consecutive_no_W_change_count >= max_consecutive_no_W_change
                fprintf('  W has not changed significantly for %d steps. Stopping learning.\n', max_consecutive_no_W_change);
                learning_flag = false; % Corresponds to Flag = 1
            end
        else
            consecutive_no_W_change_count = 0; % Reset counter
        end
        % Check for NaN/Inf in W
        if any(isnan(W)) || any(isinf(W))
            error('Error: W contains NaN or Inf values. System step k=%d. 
Stopping.', k); 211 | end 212 | end % End of learning_flag check 213 | 214 | % Update current state for next system step 215 | x_k = x_k_plus_1; 216 | 217 | % If target reached or system becomes unstable, you might want to break 218 | if norm(x_k(1:2)) < 0.01 && strcmp(system_type, 'nonlinear') % Example for vehicle 219 | fprintf('Target reached at step k=%d.\n', k); 220 | break; 221 | end 222 | if norm(x_k) > 100 % Generic instability check 223 | fprintf('System seems unstable at step k=%d. State norm: %f\n', k, norm(x_k)); 224 | break; 225 | end 226 | 227 | end % End of system simulation loop (k) 228 | 229 | % Trim histories if simulation ended early 230 | actual_sim_steps = k; 231 | x_history_actual = x_history_actual(:, 1:actual_sim_steps+1); 232 | u_history_actual = u_history_actual(:, 1:actual_sim_steps); 233 | cost_history_actual = cost_history_actual(1:actual_sim_steps); 234 | W_history_online = W_history_online(:, 1:actual_sim_steps); 235 | 236 | %% Evaluation of the FINAL learned policy (RLMPC Epi. 2 from paper) 237 | disp('Evaluating final policy (RLMPC Epi. 2)...'); 238 | mpc_final_eval = MPC(sys, N, vfa, W); % MPC with the FINAL learned W 239 | [x_rlmpc_epi2, u_rlmpc_epi2, cost_rlmpc_epi2_traj] = simulate_system(sys, mpc_final_eval, sys.get_initial_state(), 50); 240 | total_cost_rlmpc_epi2 = sum(cost_rlmpc_epi2_traj); 241 | fprintf('Total cost RLMPC (Epi. 2 with final W): %f\n', total_cost_rlmpc_epi2); 242 | 243 | %% Evaluate final policy and compare with baselines 244 | disp('Comparing with baseline controllers...'); 245 | 246 | if strcmp(system_type, 'linear') 247 | lqr_controller = LQRController(K_lqr); 248 | [x_lqr, u_lqr, cost_lqr_traj] = simulate_system(sys, lqr_controller, sys.get_initial_state(), 50); 249 | total_cost_lqr = sum(cost_lqr_traj); 250 | fprintf('Total cost LQR: %f\n', total_cost_lqr); 251 | 252 | mpc_wtc = MPC(sys, N, [], []); % MPC without terminal cost 253 | [x_mpc_wtc, u_mpc_wtc, cost_mpc_wtc_traj] = simulate_system(sys, mpc_wtc, sys.get_initial_state(), 50); 254 | total_cost_mpc_wtc = sum(cost_mpc_wtc_traj); 255 | fprintf('Total cost MPC_WTC: %f\n', total_cost_mpc_wtc); 256 | 257 | % Create a custom MPC with manual terminal cost 258 | % Create a custom quadratic terminal cost VFA using the P_terminal matrix 259 | terminal_cost_func = @(x) x' * P_terminal * x; 260 | mpc_tc = CustomTerminalCostMPC(sys, N, terminal_cost_func); 261 | [x_mpc_tc, u_mpc_tc, cost_mpc_tc_traj] = simulate_system(sys, mpc_tc, sys.get_initial_state(), 50); 262 | total_cost_mpc_tc = sum(cost_mpc_tc_traj); 263 | fprintf('Total cost MPC with terminal cost: %f\n', total_cost_mpc_tc); 264 | 265 | % Create a dummy iter_history for the plotting function if it expects one 266 | iter_history_dummy = 1:actual_sim_steps; 267 | 268 | % Plot results - adaptation of original plot_comparison_linear 269 | % First, plot W evolution during online learning 270 | figure; 271 | subplot(2,1,1); 272 | plot(W_history_online'); title('W weights evolution during online learning (k)'); xlabel('System Step k'); ylabel('Weight Value'); 273 | legend_entries = arrayfun(@(i) sprintf('W%d', i), 1:size(W_history_online,1), 'UniformOutput', false); 274 | legend(legend_entries); 275 | subplot(2,1,2); 276 | % Plot norm of W change 277 | W_norm_change = vecnorm(diff(W_history_online,1,2)); 278 | plot(W_norm_change); title('Norm of W change per step k'); xlabel('System Step k'); ylabel('||W_{k+1} - W_k||'); 279 | 280 | % Then, plot trajectory comparisons using original plot_comparison_linear but with new data 281 | 
plot_comparison_linear(x_rlmpc_epi2, x_lqr, x_mpc_wtc, u_rlmpc_epi2, u_lqr, u_mpc_wtc, ...
        total_cost_rlmpc_epi2, total_cost_lqr, total_cost_mpc_wtc, ...
        W_history_online, cost_history_actual, iter_history_dummy);

elseif strcmp(system_type, 'nonlinear')
    % Add a try-catch block around the MPC_long simulation to prevent complete failure
    try
        fprintf('Running traditional MPC with horizon %d...\n', N_mpc);
        mpc_long = MPC(sys, N_mpc, [], []); % Traditional MPC (long horizon, no VFA)
        [x_mpc_long, u_mpc_long, cost_mpc_long_traj] = simulate_system(sys, mpc_long, sys.get_initial_state(), 50);
        total_cost_mpc_long = sum(cost_mpc_long_traj);
        fprintf('Total cost MPC_long: %f\n', total_cost_mpc_long);
    catch
        % Try with a shorter horizon as fallback
        try
            N_mpc_fallback = 15;
            fprintf('Retrying with reduced horizon %d...\n', N_mpc_fallback);
            mpc_long = MPC(sys, N_mpc_fallback, [], []);
            [x_mpc_long, u_mpc_long, cost_mpc_long_traj] = simulate_system(sys, mpc_long, sys.get_initial_state(), 50);
            total_cost_mpc_long = sum(cost_mpc_long_traj);
            fprintf('Total cost MPC_long (reduced horizon): %f\n', total_cost_mpc_long);
        catch
            % In case of complete failure, just duplicate RLMPC results for the comparison
            x_mpc_long = x_rlmpc_epi2;
            u_mpc_long = u_rlmpc_epi2;
            cost_mpc_long_traj = cost_rlmpc_epi2_traj;
            total_cost_mpc_long = total_cost_rlmpc_epi2;
            fprintf('Using RLMPC results as fallback for comparison.\n');
        end
    end

    try
        mpc_wtc = MPC(sys, N, [], []); % MPC without terminal cost (same short horizon as RLMPC)
        [x_mpc_wtc, u_mpc_wtc, cost_mpc_wtc_traj] = simulate_system(sys, mpc_wtc, sys.get_initial_state(), 50);
        total_cost_mpc_wtc = sum(cost_mpc_wtc_traj);
        fprintf('Total cost MPC_WTC (N=%d): %f\n', N, total_cost_mpc_wtc);
    catch
        % In case of failure, just duplicate RLMPC results for the comparison
        x_mpc_wtc = x_rlmpc_epi2;
        u_mpc_wtc = u_rlmpc_epi2;
        cost_mpc_wtc_traj = cost_rlmpc_epi2_traj;
        total_cost_mpc_wtc = total_cost_rlmpc_epi2;
        fprintf('Using RLMPC results as fallback for comparison.\n');
    end

    % Add a new MPC controller with custom terminal cost
    try
        % Create a custom quadratic terminal cost function
        % x'*P_terminal*x where x is relative to target (origin)
        terminal_cost_func = @(x) x' * P_terminal * x;

        % Create MPC with custom terminal cost
        mpc_tc = CustomTerminalCostMPC(sys, N, terminal_cost_func);
        [x_mpc_tc, u_mpc_tc, cost_mpc_tc_traj] = simulate_system(sys, mpc_tc, sys.get_initial_state(), 50);
        total_cost_mpc_tc = sum(cost_mpc_tc_traj);
        fprintf('Total cost MPC with terminal cost: %f\n', total_cost_mpc_tc);
    catch
        % In case of failure, use RLMPC results
        x_mpc_tc = x_rlmpc_epi2;
        u_mpc_tc = u_rlmpc_epi2;
        cost_mpc_tc_traj = cost_rlmpc_epi2_traj;
        total_cost_mpc_tc = total_cost_rlmpc_epi2;
    end

    % Create a dummy iter_history for the plotting function if it expects one
    iter_history_dummy = 1:actual_sim_steps;

    % Plot W evolution during online learning
    figure;
    subplot(2,1,1);
    plot(W_history_online'); title('W weights evolution during online learning (k)'); xlabel('System Step k'); ylabel('Weight Value');
    legend_entries = arrayfun(@(i) sprintf('W%d', i), 1:size(W_history_online,1), 'UniformOutput', false);
legend(legend_entries); 354 | subplot(2,1,2); 355 | % Plot norm of W change 356 | W_norm_change = vecnorm(diff(W_history_online,1,2)); 357 | plot(W_norm_change); title('Norm of W change per step k'); xlabel('System Step k'); ylabel('||W_{k+1} - W_k||'); 358 | 359 | % Plot results with the Episode 2 (final policy) performance 360 | plot_comparison_nonlinear_extended(x_rlmpc_epi2, x_mpc_long, x_mpc_wtc, x_mpc_tc, ... 361 | u_rlmpc_epi2, u_mpc_long, u_mpc_wtc, u_mpc_tc, ... 362 | total_cost_rlmpc_epi2, total_cost_mpc_long, total_cost_mpc_wtc, total_cost_mpc_tc, ... 363 | W_history_online, cost_history_actual, iter_history_dummy); 364 | 365 | % Additional plot for learning phase trajectory (Episode 1) 366 | figure; 367 | subplot(2,1,1); 368 | plot(x_history_actual'); title('RLMPC Episode 1 - Learning Phase Trajectory'); xlabel('System Step k'); ylabel('State Value'); 369 | state_labels = arrayfun(@(i) sprintf('State%d', i), 1:size(x_history_actual,1), 'UniformOutput', false); 370 | legend(state_labels); 371 | subplot(2,1,2); 372 | plot(u_history_actual'); title('RLMPC Episode 1 - Learning Phase Control'); xlabel('System Step k'); ylabel('Control Value'); 373 | control_labels = arrayfun(@(i) sprintf('Control%d', i), 1:size(u_history_actual,1), 'UniformOutput', false); 374 | legend(control_labels); 375 | end 376 | 377 | disp('RLMPC simulation completed.'); -------------------------------------------------------------------------------- /main_RLMPC.m: -------------------------------------------------------------------------------- 1 | %% RLMPC - Reinforcement Learning Model Predictive Control 2 | % This script implements the RLMPC framework by combining RL and MPC 3 | 4 | 5 | clear; clc; close all; 6 | 7 | %% Select system type 8 | system_type = 'nonlinear'; % Options: 'linear', 'nonlinear' 9 | 10 | %% Parameters 11 | % Common parameters 12 | num_samples = 500; % Number of samples for policy evaluation 13 | 14 | %% System setup based on type 15 | if strcmp(system_type, 'linear') 16 | % Linear system parameters 17 | A = [1, 0.5; 0.1, 0.5]; 18 | B = [1; 0]; 19 | Q = eye(2); 20 | R = 0.5; 21 | N = 3; % Prediction horizon 22 | 23 | % Define polynomial basis functions 24 | basis_order = 2; 25 | alpha = 1e-2; % 26 | lambda_reg = 0.01; % Added for consistency 27 | 28 | 29 | sys = LinearSystem(A, B, Q, R); 30 | 31 | [K_lqr, P_lqr] = dlqr(A, B, Q, R); 32 | 33 | 34 | P_terminal = 5 * eye(2); 35 | 36 | elseif strcmp(system_type, 'nonlinear') 37 | % Nonlinear system parameters (non-holonomic vehicle) 38 | N = 5; % Prediction horizon for RLMPC 39 | N_mpc = 20; 40 | 41 | % State and input constraints 42 | x_bounds = [0, 2]; 43 | y_bounds = x_bounds; % Assuming symmetric bounds for y 44 | v_bounds = [-1, 1]; 45 | omega_bounds = [-4, 4]; 46 | 47 | 48 | basis_order = 4; 49 | % Original: 4 (35 features), Paper's hint: 6 (84 features) 50 | 51 | 52 | alpha = 2e-3; % More reasonable alpha since we'll average the gradients 53 | 54 | 55 | lambda_reg = 0.01; 56 | 57 | 58 | sys = NonlinearVehicle(x_bounds, v_bounds, omega_bounds); 59 | sys.y_bounds = y_bounds; 60 | 61 | 62 | P_terminal = diag([5, 5, 1]); 63 | else 64 | error('Invalid system type. 
Choose "linear" or "nonlinear".'); 65 | end 66 | 67 | %% Initialize value function approximation 68 | vfa = ValueFunctionApproximator(basis_order, size(sys.get_state(), 1)); 69 | 70 | %% RLMPC Algorithm - Online Learning (Closer to Algorithm 1) 71 | disp('Starting RLMPC online learning and simulation...'); 72 | 73 | % Simulation parameters 74 | sim_steps_total = 200; % Total steps for the system to run (learning + execution) 75 | learning_phase_steps = 100; % Number of steps during which learning (W update) happens 76 | 77 | 78 | 79 | % Initialize system 80 | sys.set_state(sys.get_initial_state()); % Set initial state of the system 81 | x_k = sys.get_state(); 82 | 83 | % History for plotting actual system trajectory during learning/execution 84 | x_history_actual = zeros(size(sys.get_state(), 1), sim_steps_total + 1); 85 | if strcmp(system_type, 'linear') 86 | input_dim = size(B, 2); 87 | else 88 | input_dim = sys.m; 89 | end 90 | u_history_actual = zeros(input_dim, sim_steps_total); 91 | cost_history_actual = zeros(1, sim_steps_total); 92 | x_history_actual(:, 1) = x_k; 93 | 94 | W = zeros(vfa.get_num_features(), 1); 95 | W_history_online = zeros(vfa.get_num_features(), sim_steps_total); 96 | 97 | % MPC controller instance (reused) 98 | mpc_learner = MPC(sys, N, vfa, W); % N is the MPC horizon for RLMPC 99 | 100 | % Learning control parameters 101 | alpha_sgd = 1e-5; % SGD learning rate, liner=1e-3,nonliner=1e-5 102 | lambda_reg_sgd = 0.001; 103 | epsilon_sgd_W_change = 1e-6; % Convergence threshold for W change if we want early stop 104 | max_sgd_updates_per_k = num_samples; 105 | 106 | learning_flag = true; 107 | consecutive_no_W_change_count = 0; 108 | max_consecutive_no_W_change = 5; 109 | 110 | for k = 1:sim_steps_total 111 | fprintf('System Step k = %d/%d\n', k, sim_steps_total); 112 | W_history_online(:, k) = W; 113 | 114 | % --- Policy Generation (MPC solve) --- 115 | mpc_learner.W = W; 116 | try 117 | [u_sequence_k, cost_sequence_k, x_predicted_traj_k] = mpc_learner.solve(x_k); 118 | u_0_k = u_sequence_k(:, 1); 119 | catch e 120 | warning('MPC solve failed at system step k=%d: %s. Using zero input.', k, e.message); 121 | u_0_k = zeros(sys.m, 1); 122 | end 123 | 124 | 125 | actual_cost_k = sys.stage_cost(x_k, u_0_k); 126 | x_k_plus_1 = sys.step(u_0_k); 127 | 128 | 129 | u_history_actual(:, k) = u_0_k; 130 | cost_history_actual(k) = actual_cost_k; 131 | x_history_actual(:, k+1) = x_k_plus_1; 132 | 133 | % --- Policy Evaluation and W Update (if in learning phase) --- 134 | if learning_flag && k <= learning_phase_steps 135 | fprintf(' Learning phase: Updating W...\n'); 136 | W_before_update_this_k = W; 137 | 138 | % 1. Sample Sk (as in your original code) 139 | % It's better if Sk is sampled fresh or partially updated each k 140 | % For simplicity here, let's resample it. 141 | samples_Sk = sys.generate_samples(num_samples); 142 | 143 | for s_idx = 1:num_samples 144 | x_j_sample = samples_Sk(:, s_idx); 145 | 146 | % 2. 
146 | % 2. For each x_j in Sk, generate the training target J'(x_j, u_j*) 147 | % J' is computed with the weights frozen at the start of step k 148 | % (W_before_update_this_k), so every target in this sweep evaluates 149 | % the same policy pi_k, as required by policy iteration 150 | mpc_temp_eval = MPC(sys, N, vfa, W_before_update_this_k); 151 | try 152 | [u_seq_sample, cost_seq_sample, x_traj_sample] = mpc_temp_eval.solve(x_j_sample); 153 | 154 | 155 | terminal_value_approx_sample = W_before_update_this_k' * vfa.get_features(x_traj_sample(:, end)); 156 | J_target_for_sample_j = sum(cost_seq_sample) + terminal_value_approx_sample; 157 | 158 | % 3. SGD Update for W for this sample x_j_sample 159 | phi_j = vfa.get_features(x_j_sample); 160 | 161 | td_error = J_target_for_sample_j - (W' * phi_j); 162 | 163 | 164 | delta_W_sgd = alpha_sgd * td_error * phi_j; 165 | 166 | % L2 regularization (weight decay) 167 | regularization_term_sgd = alpha_sgd * lambda_reg_sgd * W; 168 | delta_W_sgd = delta_W_sgd - regularization_term_sgd; 169 | 170 | % Clip the update BEFORE applying it, so the clipping actually limits the step 171 | max_grad_norm = 1.0; 172 | current_grad_norm = norm(delta_W_sgd); 173 | if current_grad_norm > max_grad_norm 174 | delta_W_sgd = delta_W_sgd * (max_grad_norm / current_grad_norm); 175 | fprintf(' Gradient norm clipped from %e to %e\n', current_grad_norm, max_grad_norm); 176 | end 177 | 178 | W = W + delta_W_sgd; % Update W immediately with the (possibly clipped) step 179 | 180 | % Keep weights bounded as a safeguard 181 | W = min(max(W, -100.0), 100.0); 182 | 183 | catch e_sgd 184 | % Skip this sample if its MPC evaluation fails 185 | end 186 | end 187 | 188 | 189 | w_change_norm_this_k = norm(W - W_before_update_this_k); 190 | fprintf(' Norm of W update (delta_W_total_for_k): %e, Norm of W: %e\n', ... 191 | w_change_norm_this_k, norm(W)); 192 | if w_change_norm_this_k < epsilon_sgd_W_change 193 | consecutive_no_W_change_count = consecutive_no_W_change_count + 1; 194 | if consecutive_no_W_change_count >= max_consecutive_no_W_change 195 | fprintf(' W has not changed significantly for %d steps. Stopping learning.\n', max_consecutive_no_W_change); 196 | learning_flag = false; % Corresponds to Flag = 1 197 | end 198 | else 199 | consecutive_no_W_change_count = 0; % Reset counter 200 | end 201 | % Check for NaN/Inf in W 202 | if any(isnan(W)) || any(isinf(W)) 203 | error('Error: W contains NaN or Inf values. System step k=%d. Stopping.', k); 204 | end 205 | end % End of learning_flag check 206 | 207 | % Update current state for next system step 208 | x_k = x_k_plus_1; 209 | 210 | % Break early if the target is reached or the system becomes unstable 211 | if norm(x_k(1:2)) < 0.01 && strcmp(system_type, 'nonlinear') % Example for vehicle 212 | fprintf('Target reached at step k=%d.\n', k); 213 | break; 214 | end 215 | if norm(x_k) > 100 % Generic instability check 216 | fprintf('System seems unstable at step k=%d. State norm: %f\n', k, norm(x_k)); 217 | break; 218 | end 219 | 220 | end % End of system simulation loop (k) 221 | 222 | actual_sim_steps = k; 223 | x_history_actual = x_history_actual(:, 1:actual_sim_steps+1); 224 | u_history_actual = u_history_actual(:, 1:actual_sim_steps); 225 | cost_history_actual = cost_history_actual(1:actual_sim_steps); 226 | W_history_online = W_history_online(:, 1:actual_sim_steps); 227 |
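% Illustrative probe (a sketch; V_hat is a helper name introduced here for
% illustration): the learned terminal value function is the linear-in-features
% model V_hat(x) = W' * Phi(x), so the approximate cost-to-go can be inspected:
V_hat = @(x) W' * vfa.get_features(x);   % approximate cost-to-go at state x
fprintf('Approx. cost-to-go at the initial state: %f\n', V_hat(sys.get_initial_state()));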
228 | %% Evaluation of the FINAL learned policy (RLMPC Epi. 2 from paper) 229 | disp('Evaluating final policy (RLMPC Epi. 2)...'); 230 | mpc_final_eval = MPC(sys, N, vfa, W); % MPC with the FINAL learned W 231 | [x_rlmpc_epi2, u_rlmpc_epi2, cost_rlmpc_epi2_traj] = simulate_system(sys, mpc_final_eval, sys.get_initial_state(), 50); 232 | total_cost_rlmpc_epi2 = sum(cost_rlmpc_epi2_traj); 233 | fprintf('Total cost RLMPC (Epi. 2 with final W): %f\n', total_cost_rlmpc_epi2); 234 | 235 | %% Evaluate final policy and compare with baselines 236 | disp('Comparing with baseline controllers...'); 237 | 238 | if strcmp(system_type, 'linear') 239 | lqr_controller = LQRController(K_lqr); 240 | [x_lqr, u_lqr, cost_lqr_traj] = simulate_system(sys, lqr_controller, sys.get_initial_state(), 50); 241 | total_cost_lqr = sum(cost_lqr_traj); 242 | fprintf('Total cost LQR: %f\n', total_cost_lqr); 243 | 244 | mpc_wtc = MPC(sys, N, [], []); 245 | [x_mpc_wtc, u_mpc_wtc, cost_mpc_wtc_traj] = simulate_system(sys, mpc_wtc, sys.get_initial_state(), 50); 246 | total_cost_mpc_wtc = sum(cost_mpc_wtc_traj); 247 | fprintf('Total cost MPC_WTC: %f\n', total_cost_mpc_wtc); 248 | 249 | 250 | terminal_cost_func = @(x) x' * P_terminal * x; 251 | mpc_tc = CustomTerminalCostMPC(sys, N, terminal_cost_func); 252 | [x_mpc_tc, u_mpc_tc, cost_mpc_tc_traj] = simulate_system(sys, mpc_tc, sys.get_initial_state(), 50); 253 | total_cost_mpc_tc = sum(cost_mpc_tc_traj); 254 | fprintf('Total cost MPC with terminal cost: %f\n', total_cost_mpc_tc); 255 | 256 | 257 | iter_history_dummy = 1:actual_sim_steps; 258 | 259 | figure; 260 | subplot(2,1,1); 261 | plot(W_history_online'); title('W weights evolution during online learning (k)'); xlabel('System Step k'); ylabel('Weight Value'); 262 | legend_entries = arrayfun(@(i) sprintf('W%d', i), 1:size(W_history_online,1), 'UniformOutput', false); 263 | legend(legend_entries); 264 | subplot(2,1,2); 265 | 266 | W_norm_change = vecnorm(diff(W_history_online,1,2)); 267 | plot(W_norm_change); title('Norm of W change per step k'); xlabel('System Step k'); ylabel('||W_{k+1} - W_k||'); 268 | 269 | % Pass per-step stage-cost trajectories (the plot function accumulates them) 270 | plot_comparison_linear(x_rlmpc_epi2, x_lqr, x_mpc_wtc, u_rlmpc_epi2, u_lqr, u_mpc_wtc, ... 271 | cost_rlmpc_epi2_traj, cost_lqr_traj, cost_mpc_wtc_traj, ...
272 | W_history_online, cost_history_actual, iter_history_dummy); 273 | 274 | elseif strcmp(system_type, 'nonlinear') 275 | 276 | try 277 | fprintf('Running traditional MPC with horizon %d...\n', N_mpc); 278 | mpc_long = MPC(sys, N_mpc, [], []); % Traditional MPC (long horizon, no VFA) 279 | [x_mpc_long, u_mpc_long, cost_mpc_long_traj] = simulate_system(sys, mpc_long, sys.get_initial_state(), 50); 280 | total_cost_mpc_long = sum(cost_mpc_long_traj); 281 | fprintf('Total cost MPC_long: %f\n', total_cost_mpc_long); 282 | catch 283 | % Try a shorter horizon as a fallback 284 | try 285 | N_mpc_fallback = 15; 286 | fprintf('Retrying with reduced horizon %d...\n', N_mpc_fallback); 287 | mpc_long = MPC(sys, N_mpc_fallback, [], []); 288 | [x_mpc_long, u_mpc_long, cost_mpc_long_traj] = simulate_system(sys, mpc_long, sys.get_initial_state(), 50); 289 | total_cost_mpc_long = sum(cost_mpc_long_traj); 290 | fprintf('Total cost MPC_long (reduced horizon): %f\n', total_cost_mpc_long); 291 | catch 292 | % Complete failure: reuse the RLMPC results for the comparison 293 | x_mpc_long = x_rlmpc_epi2; 294 | u_mpc_long = u_rlmpc_epi2; 295 | cost_mpc_long_traj = cost_rlmpc_epi2_traj; 296 | total_cost_mpc_long = total_cost_rlmpc_epi2; 297 | fprintf('Using RLMPC results as fallback for comparison.\n'); 298 | end 299 | end 300 | 301 | try 302 | mpc_wtc = MPC(sys, N, [], []); 303 | [x_mpc_wtc, u_mpc_wtc, cost_mpc_wtc_traj] = simulate_system(sys, mpc_wtc, sys.get_initial_state(), 50); 304 | total_cost_mpc_wtc = sum(cost_mpc_wtc_traj); 305 | fprintf('Total cost MPC_WTC (N=%d): %f\n', N, total_cost_mpc_wtc); 306 | catch 307 | % On failure, reuse the RLMPC results for the comparison 308 | x_mpc_wtc = x_rlmpc_epi2; 309 | u_mpc_wtc = u_rlmpc_epi2; 310 | cost_mpc_wtc_traj = cost_rlmpc_epi2_traj; 311 | total_cost_mpc_wtc = total_cost_rlmpc_epi2; 312 | fprintf('Using RLMPC results as fallback for comparison.\n'); 313 | end 314 | 315 | 316 | try 317 | % Quadratic terminal cost x'*P_terminal*x about the origin target 318 | terminal_cost_func = @(x) x' * P_terminal * x; 319 | 320 | mpc_tc = CustomTerminalCostMPC(sys, N, terminal_cost_func); 321 | [x_mpc_tc, u_mpc_tc, cost_mpc_tc_traj] = simulate_system(sys, mpc_tc, sys.get_initial_state(), 50); 322 | total_cost_mpc_tc = sum(cost_mpc_tc_traj); 323 | fprintf('Total cost MPC with terminal cost: %f\n', total_cost_mpc_tc); 324 | catch 325 | % On failure, reuse the RLMPC results for the comparison 326 | x_mpc_tc = x_rlmpc_epi2; 327 | u_mpc_tc = u_rlmpc_epi2; 328 | cost_mpc_tc_traj = cost_rlmpc_epi2_traj; 329 | total_cost_mpc_tc = total_cost_rlmpc_epi2; 330 | end 331 | 332 | iter_history_dummy = 1:actual_sim_steps; 333 | 334 | figure; 335 | subplot(2,1,1); 336 | plot(W_history_online'); title('W weights evolution during online learning (k)'); xlabel('System Step k'); ylabel('Weight Value'); 337 | legend_entries = arrayfun(@(i) sprintf('W%d', i), 1:size(W_history_online,1), 'UniformOutput', false); 338 | legend(legend_entries); 339 | subplot(2,1,2); 340 | 341 | W_norm_change = vecnorm(diff(W_history_online,1,2)); 342 | plot(W_norm_change); title('Norm of W change per step k'); xlabel('System Step k'); ylabel('||W_{k+1} - W_k||'); 343 | 344 | % Pass per-step stage-cost trajectories (the plot function accumulates them) 345 | plot_comparison_nonlinear_extended(x_rlmpc_epi2, x_mpc_long, x_mpc_wtc, x_mpc_tc, ... u_rlmpc_epi2, u_mpc_long, u_mpc_wtc, u_mpc_tc, ... cost_rlmpc_epi2_traj, cost_mpc_long_traj, cost_mpc_wtc_traj, cost_mpc_tc_traj, ...
346 | W_history_online, cost_history_actual, iter_history_dummy); 347 | 348 | 349 | figure; 350 | subplot(2,1,1); 351 | plot(x_history_actual'); title('RLMPC Episode 1 - Learning Phase Trajectory'); xlabel('System Step k'); ylabel('State Value'); 352 | state_labels = arrayfun(@(i) sprintf('State%d', i), 1:size(x_history_actual,1), 'UniformOutput', false); 353 | legend(state_labels); 354 | subplot(2,1,2); 355 | plot(u_history_actual'); title('RLMPC Episode 1 - Learning Phase Control'); xlabel('System Step k'); ylabel('Control Value'); 356 | control_labels = arrayfun(@(i) sprintf('Control%d', i), 1:size(u_history_actual,1), 'UniformOutput', false); 357 | legend(control_labels); 358 | end 359 | 360 | disp('RLMPC simulation completed.'); -------------------------------------------------------------------------------- /plot_comparison_linear.m: -------------------------------------------------------------------------------- 1 | function plot_comparison_linear(x_rlmpc, x_lqr, x_mpc_wtc, u_rlmpc, u_lqr, u_mpc_wtc, ... 2 | cost_rlmpc, cost_lqr, cost_mpc_wtc, W_history, cost_history, iter_history) 3 | % PLOT_COMPARISON_LINEAR Plot comparison results for the linear system 4 | % 5 | % This function plots state trajectories, control inputs, accumulated costs, 6 | % and value function weight convergence for the linear system case 7 | 8 | % Calculate time vector 9 | sim_steps = size(x_rlmpc, 2) - 1; 10 | time = 0:sim_steps; 11 | 12 | % Calculate accumulated costs 13 | acc_cost_rlmpc = cumsum(cost_rlmpc); 14 | acc_cost_lqr = cumsum(cost_lqr); 15 | acc_cost_mpc_wtc = cumsum(cost_mpc_wtc); 16 | 17 | % Create figure for state trajectories 18 | figure('Name', 'Linear System: State Trajectories', 'Position', [100, 100, 1000, 600]); 19 | 20 | % Plot state trajectories 21 | subplot(2, 2, 1); 22 | hold on; 23 | plot(time, x_rlmpc(1, :), 'b-', 'LineWidth', 2); 24 | plot(time, x_lqr(1, :), 'r--', 'LineWidth', 1.5); 25 | plot(time, x_mpc_wtc(1, :), 'g-.', 'LineWidth', 1.5); 26 | hold off; 27 | grid on; 28 | xlabel('Time step'); 29 | ylabel('State x_1'); 30 | title('State x_1 Trajectory'); 31 | legend('RLMPC', 'LQR', 'MPC w/o TC'); 32 | 33 | subplot(2, 2, 2); 34 | hold on; 35 | plot(time, x_rlmpc(2, :), 'b-', 'LineWidth', 2); 36 | plot(time, x_lqr(2, :), 'r--', 'LineWidth', 1.5); 37 | plot(time, x_mpc_wtc(2, :), 'g-.', 'LineWidth', 1.5); 38 | hold off; 39 | grid on; 40 | xlabel('Time step'); 41 | ylabel('State x_2'); 42 | title('State x_2 Trajectory'); 43 | legend('RLMPC', 'LQR', 'MPC w/o TC'); 44 | 45 | % Plot state phase portrait 46 | subplot(2, 2, 3); 47 | hold on; 48 | plot(x_rlmpc(1, :), x_rlmpc(2, :), 'b-', 'LineWidth', 2); 49 | plot(x_lqr(1, :), x_lqr(2, :), 'r--', 'LineWidth', 1.5); 50 | plot(x_mpc_wtc(1, :), x_mpc_wtc(2, :), 'g-.', 'LineWidth', 1.5); 51 | plot(x_rlmpc(1, 1), x_rlmpc(2, 1), 'ko', 'MarkerSize', 8, 'MarkerFaceColor', 'k'); 52 | plot(0, 0, 'kx', 'MarkerSize', 8, 'LineWidth', 2); 53 | hold off; 54 | grid on; 55 | xlabel('State x_1'); 56 | ylabel('State x_2'); 57 | title('State Phase Portrait'); 58 | legend('RLMPC', 'LQR', 'MPC w/o TC', 'Initial State', 'Target State'); 59 | 60 | % Plot control inputs 61 | subplot(2, 2, 4); 62 | hold on; 63 | stairs(0:sim_steps-1, u_rlmpc, 'b-', 'LineWidth', 2); 64 | stairs(0:sim_steps-1, u_lqr, 'r--', 'LineWidth', 1.5); 65 | stairs(0:sim_steps-1, u_mpc_wtc, 'g-.', 'LineWidth', 1.5); 66 | hold off; 67 | grid on; 68 | xlabel('Time step'); 69 | ylabel('Control input u'); 70 | title('Control Inputs'); 71 | legend('RLMPC', 'LQR', 'MPC w/o TC'); 72 | 73 | % Create 
figure for accumulated costs 74 | figure('Name', 'Linear System: Performance Comparison', 'Position', [100, 700, 1000, 600]); 75 | 76 | % Plot accumulated costs 77 | subplot(2, 2, 1); 78 | hold on; 79 | % Use time range that matches the actual trajectory length 80 | plot_time = 0:length(acc_cost_rlmpc)-1; 81 | plot(plot_time, acc_cost_rlmpc, 'b-', 'LineWidth', 2); 82 | plot(plot_time, acc_cost_lqr, 'r--', 'LineWidth', 1.5); 83 | plot(plot_time, acc_cost_mpc_wtc, 'g-.', 'LineWidth', 1.5); 84 | hold off; 85 | grid on; 86 | xlabel('Time step'); 87 | ylabel('Accumulated cost'); 88 | title('Accumulated Costs (ACC)'); 89 | legend('RLMPC', 'LQR', 'MPC w/o TC'); 90 | % Set y-axis to start from 0 for better comparison 91 | y_lim = get(gca, 'YLim'); 92 | set(gca, 'YLim', [0, y_lim(2)]); 93 | 94 | % Plot final accumulated costs as bar chart 95 | subplot(2, 2, 2); 96 | final_costs = [acc_cost_rlmpc(end), acc_cost_lqr(end), acc_cost_mpc_wtc(end)]; 97 | bar(final_costs); 98 | set(gca, 'XTickLabel', {'RLMPC', 'LQR', 'MPC w/o TC'}); 99 | grid on; 100 | ylabel('Total cost'); 101 | title('Total Accumulated Cost'); 102 | 103 | % Plot average cost per iteration during training 104 | if ~isempty(cost_history) 105 | subplot(2, 2, 3); 106 | % Filter out any NaN values that might be present in cost_history 107 | valid_indices = ~isnan(cost_history); 108 | valid_costs = cost_history(valid_indices); 109 | 110 | % Prepare iterations data 111 | if exist('iter_history', 'var') && ~isempty(iter_history) 112 | iterations = iter_history(valid_indices); 113 | else 114 | iterations = find(valid_indices); 115 | end 116 | 117 | % If there's at least one valid data point, plot it 118 | if ~isempty(valid_costs) 119 | plot(iterations, valid_costs, 'b-o', 'LineWidth', 2); 120 | grid on; 121 | xlabel('Policy iteration'); 122 | ylabel('Average cost'); 123 | title('Average Cost per Policy Iteration'); 124 | else 125 | text(0.5, 0.5, 'No valid cost data available', 'HorizontalAlignment', 'center'); 126 | axis off; 127 | end 128 | end 129 | 130 | % Plot weight convergence during training 131 | if ~isempty(W_history) 132 | subplot(2, 2, 4); 133 | % Use iter_history if provided, otherwise use 1:size(W_history, 2) 134 | if exist('iter_history', 'var') && ~isempty(iter_history) 135 | iterations = iter_history(1:size(W_history, 2)); 136 | else 137 | iterations = 1:size(W_history, 2); 138 | end 139 | 140 | % Only plot if there are at least 2 iterations of data 141 | if length(iterations) > 1 142 | plot(iterations, W_history', 'LineWidth', 1.5); 143 | grid on; 144 | xlabel('Policy iteration'); 145 | ylabel('Weight value'); 146 | title('Value Function Weight Convergence'); 147 | legend_str = cell(1, size(W_history, 1)); 148 | for i = 1:size(W_history, 1) 149 | legend_str{i} = ['W_' num2str(i)]; 150 | end 151 | legend(legend_str, 'Location', 'eastoutside'); 152 | else 153 | text(0.5, 0.5, 'Insufficient weight history data', 'HorizontalAlignment', 'center'); 154 | axis off; 155 | end 156 | end 157 | 158 | % Display performance metrics 159 | disp('Performance Metrics:'); 160 | disp(['RLMPC Total Cost: ', num2str(acc_cost_rlmpc(end))]); 161 | disp(['LQR Total Cost: ', num2str(acc_cost_lqr(end))]); 162 | disp(['MPC w/o TC Total Cost: ', num2str(acc_cost_mpc_wtc(end))]); 163 | disp(['RLMPC vs LQR improvement: ', num2str((acc_cost_lqr(end) - acc_cost_rlmpc(end))/acc_cost_lqr(end)*100), '%']); 164 | disp(['RLMPC vs MPC w/o TC improvement: ', num2str((acc_cost_mpc_wtc(end) - acc_cost_rlmpc(end))/acc_cost_mpc_wtc(end)*100), '%']); 165 | 166 | end 
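A minimal toy call of plot_comparison_linear, shown as a sketch only: the cost arguments are per-step stage-cost trajectories (the function accumulates them with cumsum), states are n x (T+1), and inputs are m x T, matching what simulate_system returns. All names here are illustrative.

% Toy-data call sketch for plot_comparison_linear (illustrative only)
T = 50;
x_demo = cumsum(0.01 * randn(2, T+1), 2);   % toy state trajectories (2 x T+1)
u_demo = randn(1, T);                        % toy scalar input trajectory (1 x T)
c_demo = rand(1, T);                         % toy per-step stage costs (1 x T)
plot_comparison_linear(x_demo, x_demo, x_demo, u_demo, u_demo, u_demo, ...
    c_demo, c_demo, c_demo, [], c_demo, 1:T);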
-------------------------------------------------------------------------------- /plot_comparison_nonlinear.m: -------------------------------------------------------------------------------- 1 | function plot_comparison_nonlinear(x_rlmpc, x_mpc_long, x_mpc_wtc, u_rlmpc, u_mpc_long, u_mpc_wtc, ... 2 | cost_rlmpc, cost_mpc_long, cost_mpc_wtc, W_history, cost_history, iter_history) 3 | % PLOT_COMPARISON_NONLINEAR Plot comparison results for the nonlinear vehicle system 4 | % 5 | % This function plots state trajectories, control inputs, accumulated costs, 6 | % and value function weight convergence for the nonlinear vehicle case 7 | 8 | % Calculate time vector 9 | sim_steps = size(x_rlmpc, 2) - 1; 10 | time = 0:sim_steps; 11 | 12 | % Calculate accumulated costs 13 | acc_cost_rlmpc = cumsum(cost_rlmpc); 14 | acc_cost_mpc_long = cumsum(cost_mpc_long); 15 | acc_cost_mpc_wtc = cumsum(cost_mpc_wtc); 16 | 17 | % Create figure for state trajectories 18 | figure('Name', 'Nonlinear Vehicle: State Trajectories', 'Position', [100, 100, 1000, 600]); 19 | 20 | % Plot position trajectory in xy-plane 21 | subplot(2, 2, 1); 22 | hold on; 23 | plot(x_rlmpc(1, :), x_rlmpc(2, :), 'b-', 'LineWidth', 2); 24 | plot(x_mpc_long(1, :), x_mpc_long(2, :), 'r--', 'LineWidth', 1.5); 25 | plot(x_mpc_wtc(1, :), x_mpc_wtc(2, :), 'g-.', 'LineWidth', 1.5); 26 | plot(x_rlmpc(1, 1), x_rlmpc(2, 1), 'ko', 'MarkerSize', 8, 'MarkerFaceColor', 'k'); 27 | plot(0, 0, 'kx', 'MarkerSize', 8, 'LineWidth', 2); 28 | % Plot arrows to show orientation at intervals 29 | interval = max(1, floor(sim_steps/10)); 30 | for i = 1:interval:sim_steps 31 | % RLMPC 32 | arrow_length = 0.1; 33 | theta = x_rlmpc(3, i); 34 | quiver(x_rlmpc(1, i), x_rlmpc(2, i), arrow_length*cos(theta), arrow_length*sin(theta), 0, 'b', 'LineWidth', 1.5); 35 | end 36 | hold off; 37 | grid on; 38 | xlabel('x position'); 39 | ylabel('y position'); 40 | title('Position Trajectory (x-y plane)'); 41 | legend('RLMPC', 'MPC-Long', 'MPC w/o TC', 'Initial Position', 'Target Position'); 42 | 43 | % Plot x position over time 44 | subplot(2, 2, 2); 45 | hold on; 46 | plot(time, x_rlmpc(1, :), 'b-', 'LineWidth', 2); 47 | plot(time, x_mpc_long(1, :), 'r--', 'LineWidth', 1.5); 48 | plot(time, x_mpc_wtc(1, :), 'g-.', 'LineWidth', 1.5); 49 | hold off; 50 | grid on; 51 | xlabel('Time step'); 52 | ylabel('x position'); 53 | title('x Position over Time'); 54 | legend('RLMPC', 'MPC-Long', 'MPC w/o TC'); 55 | 56 | % Plot y position over time 57 | subplot(2, 2, 3); 58 | hold on; 59 | plot(time, x_rlmpc(2, :), 'b-', 'LineWidth', 2); 60 | plot(time, x_mpc_long(2, :), 'r--', 'LineWidth', 1.5); 61 | plot(time, x_mpc_wtc(2, :), 'g-.', 'LineWidth', 1.5); 62 | hold off; 63 | grid on; 64 | xlabel('Time step'); 65 | ylabel('y position'); 66 | title('y Position over Time'); 67 | legend('RLMPC', 'MPC-Long', 'MPC w/o TC'); 68 | 69 | % Plot orientation over time 70 | subplot(2, 2, 4); 71 | hold on; 72 | plot(time, x_rlmpc(3, :), 'b-', 'LineWidth', 2); 73 | plot(time, x_mpc_long(3, :), 'r--', 'LineWidth', 1.5); 74 | plot(time, x_mpc_wtc(3, :), 'g-.', 'LineWidth', 1.5); 75 | hold off; 76 | grid on; 77 | xlabel('Time step'); 78 | ylabel('Orientation \theta (rad)'); 79 | title('Orientation over Time'); 80 | legend('RLMPC', 'MPC-Long', 'MPC w/o TC'); 81 | 82 | % Create figure for control inputs 83 | figure('Name', 'Nonlinear Vehicle: Control Inputs', 'Position', [100, 700, 1000, 400]); 84 | 85 | % Plot linear velocity input 86 | subplot(1, 2, 1); 87 | hold on; 88 | stairs(0:sim_steps-1, u_rlmpc(1, :), 'b-', 
'LineWidth', 2); 89 | stairs(0:sim_steps-1, u_mpc_long(1, :), 'r--', 'LineWidth', 1.5); 90 | stairs(0:sim_steps-1, u_mpc_wtc(1, :), 'g-.', 'LineWidth', 1.5); 91 | yline(1, 'k--'); % Upper bound 92 | yline(-1, 'k--'); % Lower bound 93 | hold off; 94 | grid on; 95 | xlabel('Time step'); 96 | ylabel('Linear velocity v'); 97 | title('Linear Velocity Input'); 98 | legend('RLMPC', 'MPC-Long', 'MPC w/o TC', 'Bounds'); 99 | 100 | % Plot angular velocity input 101 | subplot(1, 2, 2); 102 | hold on; 103 | stairs(0:sim_steps-1, u_rlmpc(2, :), 'b-', 'LineWidth', 2); 104 | stairs(0:sim_steps-1, u_mpc_long(2, :), 'r--', 'LineWidth', 1.5); 105 | stairs(0:sim_steps-1, u_mpc_wtc(2, :), 'g-.', 'LineWidth', 1.5); 106 | yline(4, 'k--'); % Upper bound 107 | yline(-4, 'k--'); % Lower bound 108 | hold off; 109 | grid on; 110 | xlabel('Time step'); 111 | ylabel('Angular velocity \omega'); 112 | title('Angular Velocity Input'); 113 | legend('RLMPC', 'MPC-Long', 'MPC w/o TC', 'Bounds'); 114 | 115 | % Create figure for accumulated costs and training progress 116 | figure('Name', 'Nonlinear Vehicle: Performance Comparison', 'Position', [100, 1100, 1000, 600]); 117 | 118 | % Plot accumulated costs 119 | subplot(2, 2, 1); 120 | hold on; 121 | % Use time range that matches the actual trajectory length 122 | plot_time = 0:length(acc_cost_rlmpc)-1; 123 | plot(plot_time, acc_cost_rlmpc, 'b-', 'LineWidth', 2); 124 | plot(plot_time, acc_cost_mpc_long, 'r--', 'LineWidth', 1.5); 125 | plot(plot_time, acc_cost_mpc_wtc, 'g-.', 'LineWidth', 1.5); 126 | hold off; 127 | grid on; 128 | xlabel('Time step'); 129 | ylabel('Accumulated cost'); 130 | title('Accumulated Costs (ACC)'); 131 | legend('RLMPC', 'MPC-Long', 'MPC w/o TC'); 132 | % Set y-axis to start from 0 for better comparison 133 | y_lim = get(gca, 'YLim'); 134 | set(gca, 'YLim', [0, y_lim(2)]); 135 | 136 | % Plot final accumulated costs as bar chart 137 | subplot(2, 2, 2); 138 | final_costs = [acc_cost_rlmpc(end), acc_cost_mpc_long(end), acc_cost_mpc_wtc(end)]; 139 | bar(final_costs); 140 | set(gca, 'XTickLabel', {'RLMPC', 'MPC-Long', 'MPC w/o TC'}); 141 | grid on; 142 | ylabel('Total cost'); 143 | title('Total Accumulated Cost'); 144 | 145 | % Plot average cost per iteration during training 146 | if ~isempty(cost_history) 147 | subplot(2, 2, 3); 148 | % Filter out any NaN values that might be present in cost_history 149 | valid_indices = ~isnan(cost_history); 150 | valid_costs = cost_history(valid_indices); 151 | 152 | % Prepare iterations data 153 | if exist('iter_history', 'var') && ~isempty(iter_history) 154 | iterations = iter_history(valid_indices); 155 | else 156 | iterations = find(valid_indices); 157 | end 158 | 159 | % If there's at least one valid data point, plot it 160 | if ~isempty(valid_costs) 161 | plot(iterations, valid_costs, 'b-o', 'LineWidth', 2); 162 | grid on; 163 | xlabel('Policy iteration'); 164 | ylabel('Average cost'); 165 | title('Average Cost per Policy Iteration'); 166 | else 167 | text(0.5, 0.5, 'No valid cost data available', 'HorizontalAlignment', 'center'); 168 | axis off; 169 | end 170 | end 171 | 172 | % Plot weight convergence during training 173 | if ~isempty(W_history) 174 | subplot(2, 2, 4); 175 | % If there are many weights, only plot a subset 176 | if size(W_history, 1) > 10 177 | % Select a representative subset of weights to plot 178 | indices = round(linspace(1, size(W_history, 1), 10)); 179 | W_plot = W_history(indices, :); 180 | 181 | % Use iter_history if provided, otherwise use 1:size(W_history, 2) 182 | if 
exist('iter_history', 'var') && ~isempty(iter_history) 183 | iterations = iter_history(1:size(W_history, 2)); 184 | else 185 | iterations = 1:size(W_history, 2); 186 | end 187 | 188 | % Only plot if there are at least 2 iterations of data 189 | if length(iterations) > 1 190 | plot(iterations, W_plot', 'LineWidth', 1.5); 191 | grid on; 192 | xlabel('Policy iteration'); 193 | ylabel('Weight value'); 194 | title('Value Function Weight Convergence (Sample)'); 195 | 196 | legend_str = cell(1, length(indices)); 197 | for i = 1:length(indices) 198 | legend_str{i} = ['W_' num2str(indices(i))]; 199 | end 200 | legend(legend_str, 'Location', 'eastoutside'); 201 | else 202 | text(0.5, 0.5, 'Insufficient weight history data', 'HorizontalAlignment', 'center'); 203 | axis off; 204 | end 205 | else 206 | % Use iter_history if provided, otherwise use 1:size(W_history, 2) 207 | if exist('iter_history', 'var') && ~isempty(iter_history) 208 | iterations = iter_history(1:size(W_history, 2)); 209 | else 210 | iterations = 1:size(W_history, 2); 211 | end 212 | 213 | % Only plot if there are at least 2 iterations of data 214 | if length(iterations) > 1 215 | plot(iterations, W_history', 'LineWidth', 1.5); 216 | grid on; 217 | xlabel('Policy iteration'); 218 | ylabel('Weight value'); 219 | title('Value Function Weight Convergence'); 220 | 221 | legend_str = cell(1, size(W_history, 1)); 222 | for i = 1:size(W_history, 1) 223 | legend_str{i} = ['W_' num2str(i)]; 224 | end 225 | legend(legend_str, 'Location', 'eastoutside'); 226 | else 227 | text(0.5, 0.5, 'Insufficient weight history data', 'HorizontalAlignment', 'center'); 228 | axis off; 229 | end 230 | end 231 | end 232 | 233 | % Display performance metrics 234 | disp('Performance Metrics (Nonlinear Vehicle):'); 235 | disp(['RLMPC Total Cost: ', num2str(acc_cost_rlmpc(end))]); 236 | disp(['MPC-Long Total Cost: ', num2str(acc_cost_mpc_long(end))]); 237 | disp(['MPC w/o TC Total Cost: ', num2str(acc_cost_mpc_wtc(end))]); 238 | disp(['RLMPC vs MPC-Long improvement: ', num2str((acc_cost_mpc_long(end) - acc_cost_rlmpc(end))/acc_cost_mpc_long(end)*100), '%']); 239 | disp(['RLMPC vs MPC w/o TC improvement: ', num2str((acc_cost_mpc_wtc(end) - acc_cost_rlmpc(end))/acc_cost_mpc_wtc(end)*100), '%']); 240 | 241 | end -------------------------------------------------------------------------------- /plot_comparison_nonlinear_extended.m: -------------------------------------------------------------------------------- 1 | function plot_comparison_nonlinear_extended(x_rlmpc, x_mpc_long, x_mpc_wtc, x_mpc_tc, ... 2 | u_rlmpc, u_mpc_long, u_mpc_wtc, u_mpc_tc, ... 3 | cost_rlmpc, cost_mpc_long, cost_mpc_wtc, cost_mpc_tc, ... 
4 | W_history, cost_history, iter_history) 5 | % PLOT_COMPARISON_NONLINEAR_EXTENDED Plot comparison results for the nonlinear vehicle system 6 | % including the additional MPC with terminal cost controller 7 | % 8 | % This function is an extended version of plot_comparison_nonlinear that 9 | % includes a fourth controller (MPC with terminal cost) in the comparison 10 | 11 | % Calculate time vector 12 | sim_steps = size(x_rlmpc, 2) - 1; 13 | time = 0:sim_steps; 14 | 15 | % Calculate accumulated costs 16 | acc_cost_rlmpc = cumsum(cost_rlmpc); 17 | acc_cost_mpc_long = cumsum(cost_mpc_long); 18 | acc_cost_mpc_wtc = cumsum(cost_mpc_wtc); 19 | acc_cost_mpc_tc = cumsum(cost_mpc_tc); 20 | 21 | % Create figure for state trajectories 22 | figure('Name', 'Nonlinear Vehicle: State Trajectories', 'Position', [100, 100, 1000, 600]); 23 | 24 | % Plot position trajectory in xy-plane 25 | subplot(2, 2, 1); 26 | hold on; 27 | plot(x_rlmpc(1, :), x_rlmpc(2, :), 'b-', 'LineWidth', 2); 28 | plot(x_mpc_long(1, :), x_mpc_long(2, :), 'r--', 'LineWidth', 1.5); 29 | plot(x_mpc_wtc(1, :), x_mpc_wtc(2, :), 'g-.', 'LineWidth', 1.5); 30 | plot(x_mpc_tc(1, :), x_mpc_tc(2, :), 'm:', 'LineWidth', 2); % Add MPC-TC 31 | plot(x_rlmpc(1, 1), x_rlmpc(2, 1), 'ko', 'MarkerSize', 8, 'MarkerFaceColor', 'k'); 32 | plot(0, 0, 'kx', 'MarkerSize', 8, 'LineWidth', 2); 33 | % Plot arrows to show orientation at intervals 34 | interval = max(1, floor(sim_steps/10)); 35 | for i = 1:interval:sim_steps 36 | % RLMPC 37 | arrow_length = 0.1; 38 | theta = x_rlmpc(3, i); 39 | quiver(x_rlmpc(1, i), x_rlmpc(2, i), arrow_length*cos(theta), arrow_length*sin(theta), 0, 'b', 'LineWidth', 1.5); 40 | end 41 | hold off; 42 | grid on; 43 | xlabel('x position'); 44 | ylabel('y position'); 45 | title('Position Trajectory (x-y plane)'); 46 | legend('RLMPC', 'MPC-Long', 'MPC w/o TC', 'MPC-TC', 'Initial Position', 'Target Position'); 47 | 48 | % Plot x position over time 49 | subplot(2, 2, 2); 50 | hold on; 51 | plot(time, x_rlmpc(1, :), 'b-', 'LineWidth', 2); 52 | plot(time, x_mpc_long(1, :), 'r--', 'LineWidth', 1.5); 53 | plot(time, x_mpc_wtc(1, :), 'g-.', 'LineWidth', 1.5); 54 | plot(time, x_mpc_tc(1, :), 'm:', 'LineWidth', 2); % Add MPC-TC 55 | hold off; 56 | grid on; 57 | xlabel('Time step'); 58 | ylabel('x position'); 59 | title('x Position over Time'); 60 | legend('RLMPC', 'MPC-Long', 'MPC w/o TC', 'MPC-TC'); 61 | 62 | % Plot y position over time 63 | subplot(2, 2, 3); 64 | hold on; 65 | plot(time, x_rlmpc(2, :), 'b-', 'LineWidth', 2); 66 | plot(time, x_mpc_long(2, :), 'r--', 'LineWidth', 1.5); 67 | plot(time, x_mpc_wtc(2, :), 'g-.', 'LineWidth', 1.5); 68 | plot(time, x_mpc_tc(2, :), 'm:', 'LineWidth', 2); % Add MPC-TC 69 | hold off; 70 | grid on; 71 | xlabel('Time step'); 72 | ylabel('y position'); 73 | title('y Position over Time'); 74 | legend('RLMPC', 'MPC-Long', 'MPC w/o TC', 'MPC-TC'); 75 | 76 | % Plot orientation over time 77 | subplot(2, 2, 4); 78 | hold on; 79 | plot(time, x_rlmpc(3, :), 'b-', 'LineWidth', 2); 80 | plot(time, x_mpc_long(3, :), 'r--', 'LineWidth', 1.5); 81 | plot(time, x_mpc_wtc(3, :), 'g-.', 'LineWidth', 1.5); 82 | plot(time, x_mpc_tc(3, :), 'm:', 'LineWidth', 2); % Add MPC-TC 83 | hold off; 84 | grid on; 85 | xlabel('Time step'); 86 | ylabel('Orientation \theta (rad)'); 87 | title('Orientation over Time'); 88 | legend('RLMPC', 'MPC-Long', 'MPC w/o TC', 'MPC-TC'); 89 | 90 | % Create figure for control inputs 91 | figure('Name', 'Nonlinear Vehicle: Control Inputs', 'Position', [100, 700, 1000, 400]); 92 | 93 | % Plot linear 
velocity input 94 | subplot(1, 2, 1); 95 | hold on; 96 | stairs(0:sim_steps-1, u_rlmpc(1, :), 'b-', 'LineWidth', 2); 97 | stairs(0:sim_steps-1, u_mpc_long(1, :), 'r--', 'LineWidth', 1.5); 98 | stairs(0:sim_steps-1, u_mpc_wtc(1, :), 'g-.', 'LineWidth', 1.5); 99 | stairs(0:sim_steps-1, u_mpc_tc(1, :), 'm:', 'LineWidth', 2); % Add MPC-TC 100 | yline(1, 'k--'); % Upper bound 101 | yline(-1, 'k--'); % Lower bound 102 | hold off; 103 | grid on; 104 | xlabel('Time step'); 105 | ylabel('Linear velocity v'); 106 | title('Linear Velocity Input'); 107 | legend('RLMPC', 'MPC-Long', 'MPC w/o TC', 'MPC-TC', 'Bounds'); 108 | 109 | % Plot angular velocity input 110 | subplot(1, 2, 2); 111 | hold on; 112 | stairs(0:sim_steps-1, u_rlmpc(2, :), 'b-', 'LineWidth', 2); 113 | stairs(0:sim_steps-1, u_mpc_long(2, :), 'r--', 'LineWidth', 1.5); 114 | stairs(0:sim_steps-1, u_mpc_wtc(2, :), 'g-.', 'LineWidth', 1.5); 115 | stairs(0:sim_steps-1, u_mpc_tc(2, :), 'm:', 'LineWidth', 2); % Add MPC-TC 116 | yline(4, 'k--'); % Upper bound 117 | yline(-4, 'k--'); % Lower bound 118 | hold off; 119 | grid on; 120 | xlabel('Time step'); 121 | ylabel('Angular velocity \omega'); 122 | title('Angular Velocity Input'); 123 | legend('RLMPC', 'MPC-Long', 'MPC w/o TC', 'MPC-TC', 'Bounds'); 124 | 125 | % Create figure for accumulated costs and training progress 126 | figure('Name', 'Nonlinear Vehicle: Performance Comparison', 'Position', [100, 1100, 1000, 600]); 127 | 128 | % Plot accumulated costs 129 | subplot(2, 2, 1); 130 | hold on; 131 | % Use time range that matches the actual trajectory length 132 | plot_time = 0:length(acc_cost_rlmpc)-1; 133 | plot(plot_time, acc_cost_rlmpc, 'b-', 'LineWidth', 2); 134 | plot(plot_time, acc_cost_mpc_long, 'r--', 'LineWidth', 1.5); 135 | plot(plot_time, acc_cost_mpc_wtc, 'g-.', 'LineWidth', 1.5); 136 | plot(plot_time, acc_cost_mpc_tc, 'm:', 'LineWidth', 2); % Add MPC-TC 137 | hold off; 138 | grid on; 139 | xlabel('Time step'); 140 | ylabel('Accumulated cost'); 141 | title('Accumulated Costs (ACC)'); 142 | legend('RLMPC', 'MPC-Long', 'MPC w/o TC', 'MPC-TC'); 143 | % Set y-axis to start from 0 for better comparison 144 | y_lim = get(gca, 'YLim'); 145 | set(gca, 'YLim', [0, y_lim(2)]); 146 | 147 | % Plot final accumulated costs as bar chart 148 | subplot(2, 2, 2); 149 | final_costs = [acc_cost_rlmpc(end), acc_cost_mpc_long(end), acc_cost_mpc_wtc(end), acc_cost_mpc_tc(end)]; 150 | bar(final_costs); 151 | set(gca, 'XTickLabel', {'RLMPC', 'MPC-Long', 'MPC w/o TC', 'MPC-TC'}); 152 | grid on; 153 | ylabel('Total cost'); 154 | title('Total Accumulated Cost'); 155 | 156 | % Plot average cost per iteration during training 157 | if ~isempty(cost_history) 158 | subplot(2, 2, 3); 159 | % Filter out any NaN values that might be present in cost_history 160 | valid_indices = ~isnan(cost_history); 161 | valid_costs = cost_history(valid_indices); 162 | 163 | % Prepare iterations data 164 | if exist('iter_history', 'var') && ~isempty(iter_history) 165 | iterations = iter_history(valid_indices); 166 | else 167 | iterations = find(valid_indices); 168 | end 169 | 170 | % If there's at least one valid data point, plot it 171 | if ~isempty(valid_costs) 172 | plot(iterations, valid_costs, 'b-o', 'LineWidth', 2); 173 | grid on; 174 | xlabel('Policy iteration'); 175 | ylabel('Average cost'); 176 | title('Average Cost per Policy Iteration'); 177 | else 178 | text(0.5, 0.5, 'No valid cost data available', 'HorizontalAlignment', 'center'); 179 | axis off; 180 | end 181 | end 182 | 183 | % Plot value function weight convergence 
history 184 | subplot(2, 2, 4); 185 | if isempty(W_history) 186 | text(0.5, 0.5, 'No weight history available', 'HorizontalAlignment', 'center'); 187 | axis off; 188 | else 189 | % Use iter_history if provided, otherwise use 1:size(W_history, 2) 190 | if exist('iter_history', 'var') && ~isempty(iter_history) 191 | iterations = iter_history(1:size(W_history, 2)); 192 | else 193 | iterations = 1:size(W_history, 2); 194 | end 195 | 196 | % Only plot if there are at least 2 iterations of data 197 | if length(iterations) > 1 198 | plot(iterations, W_history', 'LineWidth', 1.5); 199 | grid on; 200 | xlabel('Policy iteration'); 201 | ylabel('Weight value'); 202 | title('Value Function Weight Convergence'); 203 | 204 | legend_str = cell(1, size(W_history, 1)); 205 | for i = 1:size(W_history, 1) 206 | legend_str{i} = ['W_' num2str(i)]; 207 | end 208 | legend(legend_str, 'Location', 'eastoutside'); 209 | else 210 | text(0.5, 0.5, 'Insufficient weight history data', 'HorizontalAlignment', 'center'); 211 | axis off; 212 | end 213 | end 214 | 215 | % Display performance metrics 216 | disp('Performance Metrics (Nonlinear Vehicle with MPC-TC):'); 217 | disp(['RLMPC Total Cost: ', num2str(acc_cost_rlmpc(end))]); 218 | disp(['MPC-Long Total Cost: ', num2str(acc_cost_mpc_long(end))]); 219 | disp(['MPC w/o TC Total Cost: ', num2str(acc_cost_mpc_wtc(end))]); 220 | disp(['MPC-TC Total Cost: ', num2str(acc_cost_mpc_tc(end))]); 221 | disp(['RLMPC vs MPC-Long improvement: ', num2str((acc_cost_mpc_long(end) - acc_cost_rlmpc(end))/acc_cost_mpc_long(end)*100), '%']); 222 | disp(['RLMPC vs MPC w/o TC improvement: ', num2str((acc_cost_mpc_wtc(end) - acc_cost_rlmpc(end))/acc_cost_mpc_wtc(end)*100), '%']); 223 | disp(['RLMPC vs MPC-TC improvement: ', num2str((acc_cost_mpc_tc(end) - acc_cost_rlmpc(end))/acc_cost_mpc_tc(end)*100), '%']); 224 | disp(['MPC-TC vs MPC w/o TC improvement: ', num2str((acc_cost_mpc_wtc(end) - acc_cost_mpc_tc(end))/acc_cost_mpc_wtc(end)*100), '%']); 225 | 226 | end -------------------------------------------------------------------------------- /simulate_system.m: -------------------------------------------------------------------------------- 1 | function [x_history, u_history, cost_history] = simulate_system(system, controller, x0, sim_steps) 2 | % SIMULATE_SYSTEM Simulates a dynamical system with a given controller 3 | % 4 | % Inputs: 5 | % system - System object with dynamics and cost methods 6 | % controller - Controller object with solve method 7 | % x0 - Initial state 8 | % sim_steps - Number of simulation steps 9 | % 10 | % Outputs: 11 | % x_history - State trajectory history (n x sim_steps+1) 12 | % u_history - Control input history (m x sim_steps) 13 | % cost_history - Cost history (1 x sim_steps) 14 | 15 | % Counter of consecutive controller failures, local to this simulation run 16 | consecutive_failures = 0; 17 | 18 | % Get system dimensions 19 | n = size(x0, 1); 20 | m = system.m; 21 | 22 | % Initialize history arrays 23 | x_history = zeros(n, sim_steps+1); 24 | u_history = zeros(m, sim_steps); 25 | cost_history = zeros(1, sim_steps); 26 | 27 | % Set initial state 28 | x_history(:, 1) = x0; 29 | system.set_state(x0); 30 | 31 | % Keep track of the last valid control input for fallback 32 | last_valid_u = zeros(m, 1); 33 | 34 | % Note: the failure counter is a plain local variable (reset above), so one 35 | % simulation run cannot inherit failure counts from a previous one 36 | 37 | 38 | % Simulation loop 39 | for k = 1:sim_steps 40 | % Get current state 41 | x_k = system.get_state(); 42 | 43 | 44 |
% Try to compute control input using controller 45 | try 46 | [u_seq, ~, ~] = controller.solve(x_k); 47 | 48 | % Apply first control input 49 | if size(u_seq, 2) >= 1 50 | u_k = u_seq(:, 1); 51 | last_valid_u = u_k; % Update last valid control 52 | consecutive_failures = 0; % Reset failures counter 53 | else 54 | % If no control sequence returned, use last valid control 55 | u_k = last_valid_u; 56 | consecutive_failures = consecutive_failures + 1; 57 | warning('Empty control sequence at step %d, using last valid control', k); 58 | end 59 | catch e 60 | % If controller fails, use last valid control as fallback 61 | u_k = last_valid_u; 62 | consecutive_failures = consecutive_failures + 1; 63 | 64 | % More descriptive error message based on failure type 65 | if contains(e.message, 'converge') 66 | warning('Controller optimization failed to converge at step %d. Using last valid control.', k); 67 | elseif contains(e.message, 'infeasible') 68 | warning('Controller problem is infeasible at step %d. Using last valid control.', k); 69 | else 70 | warning('Controller failed at step %d: %s\nUsing last valid control.', k, e.message); 71 | end 72 | 73 | % If we have too many consecutive failures, create a safer fallback 74 | if consecutive_failures > 3 75 | % Create a safer fallback that gradually reduces control input 76 | decay_factor = 0.8; 77 | u_k = u_k * decay_factor; 78 | warning('Multiple consecutive failures detected. Reducing control input magnitude for safety.'); 79 | 80 | % If using nonlinear vehicle, try to create a stabilizing input 81 | if isa(system, 'NonlinearVehicle') && length(u_k) > 1 82 | % If velocity is non-zero, reduce it further 83 | if abs(u_k(1)) > 0.1 84 | u_k(1) = u_k(1) * 0.7; % Reduce velocity more aggressively 85 | end 86 | 87 | % Reduce angular velocity to stabilize 88 | if length(u_k) > 1 89 | u_k(2) = u_k(2) * 0.5; % Reduce turn rate 90 | end 91 | end 92 | end 93 | end 94 | 95 | % Apply input limits if this is a nonlinear vehicle (safety check) 96 | if isa(system, 'NonlinearVehicle') 97 | u_k(1) = max(min(u_k(1), system.v_bounds(2)), system.v_bounds(1)); 98 | if length(u_k) > 1 99 | u_k(2) = max(min(u_k(2), system.omega_bounds(2)), system.omega_bounds(1)); 100 | end 101 | end 102 | 103 | % Compute cost 104 | cost_history(k) = system.stage_cost(x_k, u_k); 105 | 106 | % Store control input 107 | u_history(:, k) = u_k; 108 | 109 | % Simulate system for one step 110 | x_next = system.step(u_k); 111 | 112 | % Store next state 113 | x_history(:, k+1) = x_next; 114 | 115 | % Provide progress update for long simulations 116 | if mod(k, 10) == 0 117 | fprintf('Simulation: %d/%d steps completed\n', k, sim_steps); 118 | end 119 | end 120 | 121 | end --------------------------------------------------------------------------------
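A minimal end-to-end sketch of how these pieces compose. The parameter values mirror the linear-system settings in main_RLMPC.m; the variable names are illustrative, and only constructors that appear in this repository are used.

% Minimal usage sketch for simulate_system (illustrative names and values)
A = [1, 0.5; 0.1, 0.5]; B = [1; 0]; Q = eye(2); R = 0.5;
sys = LinearSystem(A, B, Q, R);
ctrl = MPC(sys, 3, [], []);   % short-horizon MPC without a learned terminal cost
[x_h, u_h, c_h] = simulate_system(sys, ctrl, sys.get_initial_state(), 50);
fprintf('Total closed-loop cost: %f\n', sum(c_h));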