├── LICENSE.txt ├── README.md ├── examples ├── circle.png ├── infty.png ├── lorenz.png ├── mg17.png └── temp ├── experiments ├── change_speed.m ├── double_arm.m ├── iter_noise.m ├── normal_success_rate.m ├── speed_test.m ├── speed_test_main.m ├── success_rate.m ├── success_rate_1.m ├── success_rate_bias.m ├── temp ├── uncertain_test_saferegion.m └── uncertain_test_saferegion_main.m ├── func_desired_traj.m ├── func_double_arm.m ├── func_reservoir_train.m ├── func_reservoir_validate.m ├── func_train_val.m ├── main.m ├── rand_traj_control.m ├── read_data └── temp ├── robot_data_generator.m ├── save_file ├── all_traj_06282022.mat └── temp ├── tools ├── func_limits.m ├── func_plot_figures.m ├── func_plot_movie.m └── func_rmse.m └── val_and_update.m /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Zheng-Meng Zhai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tracking Control 2 | Code accompanying our manuscript in Nature Communications. 3 | 4 | [Model-free tracking control of complex dynamical trajectories with machine learning](https://www.nature.com/articles/s41467-023-41379-3) has been published in Nature Communications! 5 | 6 | # Requirements 7 | Please download the dataset from https://doi.org/10.5281/zenodo.8044994
8 | The chaotic trajectories should be moved into the folder `read_data`. The periodic trajectories are generated by the code itself. 9 | 10 | Note that the code uses the MATLAB add-on 'matsplit', which is not built in. To install it, click 'Home', choose 'Add-Ons', search for this package, and install it before running the code. 11 | 12 | # Example 13 | Run 'main.m' with traj_type = 'circle' to get the ground-truth and tracked trajectories shown in the picture below: 14 | 15 | ![Circle](./examples/circle.png) 16 | 17 | Set traj_type to other values to track different trajectories, e.g., traj_type = 'lorenz'. 18 | 19 | # Citation 20 | This work is available at [https://www.nature.com/articles/s41467-023-41379-3](https://www.nature.com/articles/s41467-023-41379-3), and can be cited with the following BibTeX entry: 21 | ``` 22 | @article{zhai2023model, 23 | title={Model-free tracking control of complex dynamical trajectories with machine learning}, 24 | author={Zhai, Zheng-Meng and Moradi, Mohammadamin and Kong, Ling-Wei and Glaz, Bryan and Haile, Mulugeta and Lai, Ying-Cheng}, 25 | journal={Nature Communications}, 26 | volume={14}, 27 | number={1}, 28 | pages={5698}, 29 | year={2023}, 30 | publisher={Nature Publishing Group UK London} 31 | } 32 | ``` 33 | -------------------------------------------------------------------------------- /examples/circle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng-Meng/Tracking-Control/a431a837e3cdf2890c5098d54bd68fa53e1e9878/examples/circle.png -------------------------------------------------------------------------------- /examples/infty.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng-Meng/Tracking-Control/a431a837e3cdf2890c5098d54bd68fa53e1e9878/examples/infty.png -------------------------------------------------------------------------------- /examples/lorenz.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng-Meng/Tracking-Control/a431a837e3cdf2890c5098d54bd68fa53e1e9878/examples/lorenz.png -------------------------------------------------------------------------------- /examples/mg17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng-Meng/Tracking-Control/a431a837e3cdf2890c5098d54bd68fa53e1e9878/examples/mg17.png -------------------------------------------------------------------------------- /examples/temp: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /experiments/change_speed.m: -------------------------------------------------------------------------------- 1 | clear all 2 | close all 3 | clc 4 | 5 | %% change the frequency of infty symbol 6 | 7 | addpath('./tools/') 8 | 9 | time_today = datestr(now, 'mmddyyyy'); 10 | 11 | % traj_frequency_set = round(exp(linspace(log(1), log(500), 10))); 12 | traj_frequency_set = round(linspace(1, 400, 20)); 13 | 14 | iteration = 50; 15 | val_length_all = 300000; 16 | 17 | disturbance = 0.00; 18 | measurement_noise = 0.00; 19 | 20 | plot_movie = 0; 21 | traj_type = 'infty'; 22 | bridge_type = 'cubic'; 23 | failure.type = 'none'; 24 | blur.blur = 0; 25 | 26 | rmse_start_time = round(val_length_all * 3/5); 27 | rmse_end_time = val_length_all - 100; 28 | 29 | rmse_frequency_set = zeros(length(traj_frequency_set), iteration); 30 | 31 | for tfs = 1:length(traj_frequency_set) 32 | traj_frequency = traj_frequency_set(tfs); 33 | 34 | rmse_repeat_set = zeros(1, iteration); 35 | for repeat_i = 1:iteration 36 | load('./save_file/all_traj_06282022.mat') 37 | rng('shuffle') 38 | save_rend=0; 39 | idx=1; 40 | time_infor.val_length=val_length_all; 41 | 42 | val_and_update; 43 | 44 | rmse_repeat_set(repeat_i) = func_rmse(data_pred, 
data_control, rmse_start_time, rmse_end_time); 45 | 46 | bb=1; 47 | end 48 | 49 | rmse_repeat_set = sort(rmse_repeat_set); 50 | 51 | rmse_frequency_set(tfs, :) = rmse_repeat_set; 52 | 53 | xxx = ['traj frequency:', num2str(traj_frequency)]; % num2str: char+double concatenation would print a character code, not the number 54 | disp(xxx) 55 | 56 | aa = 1; 57 | end 58 | 59 | 60 | save_frequency_iter.traj_type = traj_type; 61 | 62 | save_frequency_iter.val_length = val_length_all; 63 | save_frequency_iter.traj_frequency_set = traj_frequency_set; 64 | save_frequency_iter.rmse_frequency_set = rmse_frequency_set; 65 | 66 | save(['save_data/save_frequency_iter_', time_today, '_' num2str(randi(999)) '.mat'], "save_frequency_iter") 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /experiments/double_arm.m: -------------------------------------------------------------------------------- 1 | %% clear all 2 | 3 | disp('Preparing for clearing all and restart...') 4 | m=input('Do you want to continue, Y/N [Y]:', 's'); 5 | if ~isempty(m) && ~strcmpi(m, 'Y') % empty input means the default [Y]; '~=' on char arrays compares element-wise 6 | return 7 | end 8 | 9 | clear all 10 | close all 11 | clc 12 | 13 | addpath('./tools/') 14 | 15 | %% pre training 16 | dt=0.01; 17 | % input_infor={'q'}; 18 | input_infor={'xy', 'qdt'}; 19 | % input_infor={'q', 'qdt'}; 20 | % input_infor={'xy', 'qdt', 'q2dt'}; 21 | 22 | % in-out dimension 23 | dim_in=length(input_infor) * 4; 24 | dim_out=2; 25 | 26 | % double robot arm properties 27 | m1=1;m2=1; 28 | l1=0.5;l2=0.5; 29 | lc1=0.25;lc2=0.25; 30 | I1=0.03;I2=0.03; 31 | properties=[m1, m2, l1, l2, lc1, lc2, I1, I2]; 32 | 33 | reset_t=80; 34 | train_t = 200000; 35 | val_t=500; 36 | noise_level=2.0*10^(-2); 37 | % add perturbations and measurement noise to the model(and uncertainty) 38 | % proportional to the real value.
39 | disturbance = 0.00; 40 | measurement_noise = 0.00; 41 | 42 | n=200; 43 | hyperpara_set = [0.756250, 0.756250, 0.843750, -3.125, 106.71875, 2.0]; 44 | eig_rho = hyperpara_set(1); 45 | W_in_a = hyperpara_set(2); 46 | alpha = hyperpara_set(3); 47 | beta = 10^hyperpara_set(4); 48 | k = round( hyperpara_set(5)/200*n); 49 | kb = hyperpara_set(6); 50 | 51 | W_in = W_in_a*(2*rand(n,dim_in)-1); 52 | res_net=sprandsym(n,k/n); 53 | eig_D=eigs(res_net,1); 54 | res_net=(eig_rho/(abs(eig_D))).*res_net; 55 | res_net=full(res_net); 56 | 57 | section_len=round(reset_t/dt); % 30 58 | washup_length=round(1002/dt); 59 | train_length=round(train_t/dt)+round(5/dt); % 50000 60 | val_length=round(val_t/dt); % 300 61 | time_length = train_length + 2 * val_length + 3 * washup_length + 100; 62 | 63 | res_infor=struct('W_in', W_in, 'res_net', res_net, 'alpha', alpha, 'kb', kb, 'beta', beta, 'n', n); 64 | time_infor=struct('section_len', section_len, 'washup_length', washup_length, ... 65 | 'train_length', train_length, 'val_length', val_length, 'time_length', time_length); 66 | 67 | % generate training and validation data 68 | [xy, q, qdt, q2dt, tau] = robot_data_generator(time_infor, noise_level, dt, properties); 69 | xy=xy(washup_length:end, :); 70 | q=q(washup_length:end, :); 71 | qdt=qdt(washup_length:end, :); 72 | q2dt=q2dt(washup_length:end, :); 73 | tau=tau(washup_length:end, :); 74 | data_reservoir = struct('xy', xy, 'q', q, 'qdt', qdt, 'q2dt', q2dt, 'tau', tau); 75 | 76 | clearvars xy q qdt q2dt tau 77 | 78 | %% training 79 | tic; 80 | [Wout, r_end] = func_reservoir_train(data_reservoir, time_infor, input_infor, res_infor, dim_in, dim_out); 81 | toc; 82 | 83 | clearvars data_reservoir 84 | %% load data 85 | load_data = 1; 86 | if load_data==1 87 | load('./save_file/all_traj_06282022.mat') 88 | end 89 | 90 | %% validating 91 | 92 | rng('shuffle') 93 | % write the code that can update and pause 94 | % traj_type = 'infty'; 95 | % traj_type = 'lorenz'; 96 | % traj_type = 'mg17'; 97 
| % traj_type = 'mg30'; 98 | % traj_type = 'circle' 99 | 100 | failure.type = 'none'; 101 | blur.blur = 0; 102 | 103 | plot_val_and_update=1; 104 | 105 | disturbance = 0.05; 106 | measurement_noise = 0.05; 107 | plot_movie = 0; 108 | 109 | traj_type = 'lorenz'; 110 | bridge_type = 'cubic'; 111 | 112 | time_infor.val_length=250000; 113 | 114 | save_rend=0; 115 | idx=1; 116 | 117 | val_and_update; 118 | 119 | % idx = 2 120 | traj_type = 'circle'; 121 | bridge_type = 'cubic'; 122 | traj_frequency = 150; 123 | 124 | time_infor.val_length=250000; 125 | save_rend=1; 126 | idx=2; 127 | 128 | val_and_update; 129 | 130 | % idx = 3 131 | traj_type = 'mg17'; 132 | bridge_type = 'cubic'; 133 | 134 | time_infor.val_length=250000; 135 | save_rend=1; 136 | idx=3; 137 | 138 | val_and_update; 139 | 140 | % idx = 4 141 | traj_type = 'infty'; 142 | bridge_type = 'cubic'; 143 | traj_frequency = 100; 144 | 145 | time_infor.val_length=250000; 146 | save_rend=1; 147 | idx=4; 148 | 149 | val_and_update; 150 | 151 | % normal method to calculate rmse 152 | rmse_length = 1000/dt; 153 | error = abs(data_control(1:rmse_length, :) - data_pred(1:rmse_length, :)); 154 | rmse = sqrt(mean(mean(error.^2, 2))); 155 | 156 | 157 | %% plot figure 158 | 159 | % start_time = 1; 160 | % end_time = 290000; 161 | % 162 | % % plot trajectory 163 | % figure(); 164 | % hold on 165 | % plot(data_control(start_time:end_time, 1), data_control(start_time:end_time, 2),'r'); 166 | % plot(data_pred(start_time:end_time, 1), data_pred(start_time:end_time, 2),'b--'); 167 | % xlabel('x') 168 | % ylabel('y') 169 | % line([0, 0], [-1, 1], 'Color', 'black', 'LineStyle', '--') 170 | % line([-1, 1], [0, 0], 'Color', 'black', 'LineStyle', '--') 171 | % xlim([-1, 1]) 172 | % ylim([-1, 1]) 173 | % legend('desired trajectory', 'pred trajectory') 174 | % % 175 | % % figure(); 176 | % % hold on 177 | % % plot(data_control(100000:120000, 1), data_control(100000:120000, 2),'r'); 178 | % % plot(data_pred(100000:120000, 1), 
data_pred(100000:120000, 2),'b--'); 179 | % % xlabel('x') 180 | % % ylabel('y') 181 | % % line([0, 0], [-1, 1], 'Color', 'black', 'LineStyle', '--') 182 | % % line([-1, 1], [0, 0], 'Color', 'black', 'LineStyle', '--') 183 | % % xlim([-1, 1]) 184 | % % ylim([-1, 1]) 185 | % % legend('desired trajectory', 'pred trajectory') 186 | % 187 | % % plot q 188 | % q_control_plot=mod(q_control, pi); 189 | % q_pred_plot=mod(q_pred, pi); 190 | % 191 | % figure(); 192 | % hold on 193 | % plot(q_control_plot(start_time:end_time,1), 'r') 194 | % plot(q_pred_plot(start_time:end_time,1), 'b') 195 | % xlabel('time step') 196 | % ylabel('q(1)') 197 | % legend('desired', 'pred') 198 | % 199 | % figure() 200 | % hold on 201 | % plot(q_control_plot(start_time:end_time,2), 'r') 202 | % plot(q_pred_plot(start_time:end_time,2), 'b') 203 | % xlabel('time step') 204 | % ylabel('q(2)') 205 | % legend('desired', 'pred') 206 | % 207 | % % plot dq/dt 208 | % figure() 209 | % hold on 210 | % plot(qdt_control(start_time:end_time,1), 'r') 211 | % plot(qdt_pred(start_time:end_time,1), 'b') 212 | % xlabel('time step') 213 | % ylabel('dq/dt(1)') 214 | % legend('desired', 'pred') 215 | % 216 | % figure() 217 | % hold on 218 | % plot(qdt_control(start_time:end_time,2), 'r') 219 | % plot(qdt_pred(start_time:end_time,2), 'b') 220 | % xlabel('time step') 221 | % ylabel('dq/dt(2)') 222 | % legend('desired', 'pred') 223 | % 224 | % % plot d2q/dt2 225 | % remove_transient = 1000; 226 | % 227 | % figure() 228 | % hold on 229 | % plot(q2dt_control(start_time+remove_transient:end_time,1), 'r') 230 | % plot(q2dt_pred(start_time+remove_transient:end_time,1), 'b') 231 | % xlabel('time step') 232 | % ylabel('d2q/dt2(1)') 233 | % legend('desired', 'pred') 234 | % 235 | % figure() 236 | % hold on 237 | % plot(q2dt_control(start_time+remove_transient:end_time,2), 'r') 238 | % plot(q2dt_pred(start_time+remove_transient:end_time,2), 'b') 239 | % xlabel('time step') 240 | % ylabel('d2q/dt2(2)') 241 | % legend('desired', 
'pred') 242 | % 243 | % figure() 244 | % hold on 245 | % plot(tau_control(start_time+remove_transient:end_time,1), 'r') 246 | % plot(tau_pred(start_time+remove_transient:end_time,1), 'b') 247 | % xlabel('time step') 248 | % ylabel('tau(1)') 249 | % legend('desired', 'pred') 250 | % 251 | % figure() 252 | % hold on 253 | % plot(tau_control(start_time+remove_transient:end_time,2), 'r') 254 | % plot(tau_pred(start_time+remove_transient:end_time,2), 'b') 255 | % xlabel('time step') 256 | % ylabel('tau(2)') 257 | % legend('desired', 'pred') 258 | % 259 | % % figure() 260 | % % hold on 261 | % % plot(tau_control(start_time:end_time,1), 'r') 262 | % % plot(tau_control(start_time:end_time,2), 'b') 263 | % % xlabel('time step') 264 | % % ylabel('tau') 265 | % % legend('tau1', 'tau2') 266 | % 267 | % % plot movie for validation 268 | % plot_movie_val = 0; 269 | % start_step=start_time; 270 | % movie_step=500; 271 | % time_all=end_time; 272 | % line_property='dotted'; 273 | % q1=q_pred(:, 1); 274 | % q2=q_pred(:, 2); 275 | % if plot_movie_val == 1 276 | % func_plot_movie(start_step, movie_step, time_all, q1, q2, properties, line_property) 277 | % end 278 | 279 | 280 | %% save data 281 | 282 | % save('./save_file/8d_input_04222022_lorenz_recover_from_failure.mat') 283 | % save('./save_file/8d_input_04222022_mg17.mat') 284 | % save('./save_file/reset_to_same_point_05182022_infty.mat') 285 | % save('./save_file/reset_to_random_point_05242022_infty.mat') 286 | 287 | % save('./save_file/reto_random_work_for_all_traj_05242022.mat') 288 | % 289 | % save('./save_data/4_trajectory_08182022.mat', "save_all_traj") 290 | 291 | % save('./save_file/all_traj_06282022.mat', 'time_infor', 'input_infor', 'res_infor', 'properties', 'dim_in', 'dim_out', 'Wout', 'r_end', 'dt', 'reset_t', 'noise_level') 292 | % save('./save_file/all_traj_06282022_SI.mat', 'data_reservoir') 293 | 294 | 295 | %% success test 296 | 297 | % aa = 1; 298 | % 299 | % plot_movie = 0; 300 | % traj_type = 'infty'; 301 | % 
bridge_type = 'cubic'; 302 | % plot_val_and_update=0; 303 | % 304 | % time_infor.val_length=150000; 305 | % rmse_start_time = round(time_infor.val_length * 2/5); 306 | % rmse_end_time = time_infor.val_length - 100; 307 | % 308 | % iteration = 50; 309 | % rmse_parfor_set = zeros(1, iteration); 310 | % for repeat_i = 1:iteration 311 | % load('./save_file/all_traj_06282022.mat') 312 | % val_and_update; 313 | % data_pred=output_infor.data_pred; 314 | % data_control=control_infor.data_control; 315 | % rmse_parfor_set(repeat_i) = func_rmse(data_pred, data_control, rmse_start_time, rmse_end_time); 316 | % end 317 | % 318 | % rmse_parfor_set = sort(rmse_parfor_set); 319 | 320 | 321 | 322 | 323 | 324 | 325 | 326 | 327 | 328 | 329 | 330 | 331 | 332 | 333 | 334 | 335 | 336 | 337 | 338 | 339 | -------------------------------------------------------------------------------- /experiments/iter_noise.m: -------------------------------------------------------------------------------- 1 | % 2 | % disturbance = 1.00; 3 | % measurement_noise = 0.00; 4 | 5 | % try for only the first one and then only the second one 6 | % then try them together 7 | % then try uncertainty 8 | 9 | time_today = datestr(now, 'mmddyyyy'); 10 | 11 | %% only disturbance 12 | measurement_noise = 0.00; 13 | 14 | % disturbance_set = exp(linspace(log(0.01), log(10), 20)); 15 | disturbance_set = linspace(1, 5, 20); 16 | error_set = zeros(1, length(disturbance_set)); 17 | 18 | rmse_length = 1000/dt; 19 | weight=zeros(rmse_length,1); 20 | % alpha_w=0.2; 21 | for ii=1:rmse_length 22 | weight(ii)=ii; 23 | end 24 | weight = weight/norm(weight, 1); 25 | 26 | for di_id = 1:length(disturbance_set) 27 | disturbance = disturbance_set(di_id); 28 | plot_movie = 0; 29 | traj_type = 'infty'; 30 | % traj_type = 'lorenz'; 31 | % traj_type = 'mg17'; 32 | % traj_type = 'mg30'; 33 | bridge_type = 'cubic'; 34 | 35 | time_infor.val_length=120000; 36 | 37 | [control_infor, output_infor, time_infor] = 
func_reservoir_validate(traj_type,... 38 | bridge_type, time_infor, input_infor, res_infor, properties, dim_in, dim_out, ... 39 | Wout, dt, plot_movie, disturbance, measurement_noise); 40 | 41 | data_pred=output_infor.data_pred; 42 | q_pred=output_infor.q_pred; 43 | qdt_pred=output_infor.qdt_pred; 44 | q2dt_pred=output_infor.q2dt_pred; 45 | tau_pred=output_infor.tau_pred; 46 | 47 | q_control=control_infor.q_control; 48 | qdt_control=control_infor.qdt_control; 49 | q2dt_control=control_infor.q2dt_control; 50 | tau_control=control_infor.tau_control; 51 | % q_control_all=control_infor.q_control_all; 52 | % qdt_control_all=control_infor.qdt_control_all; 53 | data_control=control_infor.data_control; 54 | 55 | val_length=time_infor.val_length; 56 | 57 | % calculate rmse 58 | error = abs(data_control(1:rmse_length, :) - data_pred(1:rmse_length, :)); 59 | error = sum(weight.*mean(error.^2, 2)); 60 | 61 | error_set(di_id) = error; 62 | end 63 | 64 | nan_id = isnan(error_set); 65 | error_set_plot = error_set; 66 | error_set_plot(nan_id) = 5*max(error_set, [], 'omitnan'); 67 | 68 | save_data.disturbance_set = disturbance_set; 69 | save_data.disturbance_error_set = error_set; 70 | save_data.disturbance_error_set_plot = error_set_plot; 71 | 72 | save(['./save_data/disturbance_result_normscale_', time_today, '.mat'], "save_data") 73 | 74 | 75 | %% for plot 76 | 77 | figure(); 78 | % hold on 79 | % semilogx(disturbance_set, error_set_plot, 'o-') 80 | plot(disturbance_set, error_set_plot, 'o-') 81 | xlabel('gaussian disturbance, \sigma') 82 | ylabel('error') 83 | 84 | 85 | %% only measurement 86 | disturbance = 0.00; 87 | measurement_set = exp(linspace(log(0.01), log(10), 20)); 88 | % measurement_set = linspace(1, 5, 20); 89 | error_set = zeros(1, length(measurement_set)); 90 | 91 | rmse_length = 1000/dt; 92 | weight=zeros(rmse_length,1); 93 | for ii=1:rmse_length 94 | weight(ii)=ii; 95 | end 96 | weight = weight/norm(weight, 1); 97 | 98 | for di_id = 1:length(measurement_set) 99 
| measurement_noise = measurement_set(di_id); 100 | plot_movie = 0; 101 | traj_type = 'infty'; 102 | % traj_type = 'lorenz'; 103 | % traj_type = 'mg17'; 104 | % traj_type = 'mg30'; 105 | bridge_type = 'cubic'; 106 | 107 | time_infor.val_length=120000; 108 | 109 | [control_infor, output_infor, time_infor] = func_reservoir_validate(traj_type,... 110 | bridge_type, time_infor, input_infor, res_infor, properties, dim_in, dim_out, ... 111 | Wout, dt, plot_movie, disturbance, measurement_noise); 112 | 113 | data_pred=output_infor.data_pred; 114 | q_pred=output_infor.q_pred; 115 | qdt_pred=output_infor.qdt_pred; 116 | q2dt_pred=output_infor.q2dt_pred; 117 | tau_pred=output_infor.tau_pred; 118 | 119 | q_control=control_infor.q_control; 120 | qdt_control=control_infor.qdt_control; 121 | q2dt_control=control_infor.q2dt_control; 122 | tau_control=control_infor.tau_control; 123 | % q_control_all=control_infor.q_control_all; 124 | % qdt_control_all=control_infor.qdt_control_all; 125 | data_control=control_infor.data_control; 126 | 127 | val_length=time_infor.val_length; 128 | 129 | % calculate rmse 130 | error = abs(data_control(1:rmse_length, :) - data_pred(1:rmse_length, :)); 131 | error = sum(weight.*mean(error.^2, 2)); 132 | 133 | error_set(di_id) = error; 134 | end 135 | 136 | nan_id = isnan(error_set); 137 | error_set_plot = error_set; 138 | error_set_plot(nan_id) = 5*max(error_set, [], 'omitnan'); 139 | 140 | save_data.measurement_set = measurement_set; 141 | save_data.measurement_error_set = error_set; 142 | save_data.measurement_error_set_plot = error_set_plot; 143 | 144 | save(['./save_data/measurement_result_logscale_', time_today, '.mat'], "save_data") 145 | 146 | %% for plot 147 | 148 | % load('./save_data/measurement_result_logscale_05202022.mat') % only to re-plot an older saved run; loading it here would overwrite the run just computed 149 | 150 | error_set_plot = save_data.measurement_error_set_plot; % NaN (diverged) points replaced, as in the disturbance plot 151 | 152 | figure(); 153 | % hold on 154 | semilogx(measurement_set, error_set_plot, 'o-') 155 | % plot(measurement_set, error_set_plot, 'o-') 156 |
xlabel('gaussian measurement noise, \sigma') 157 | ylabel('error') 158 | 159 | 160 | 161 | %% disturbance and measurement noise 162 | 163 | disturbance_set = exp(linspace(log(0.01), log(10), 20)); 164 | measurement_set = exp(linspace(log(0.01), log(10), 20)); 165 | 166 | error_set = zeros(length(disturbance_set), length(measurement_set)); 167 | 168 | rmse_length = 1000/dt; 169 | weight=zeros(rmse_length,1); 170 | for ii=1:rmse_length 171 | weight(ii)=ii; 172 | end 173 | weight = weight/norm(weight, 1); 174 | 175 | for di_id = 1:length(disturbance_set) 176 | for mn_id = 1:length(measurement_set) 177 | disturbance = disturbance_set(di_id); 178 | measurement_noise = measurement_set(mn_id); 179 | plot_movie = 0; 180 | traj_type = 'infty'; 181 | % traj_type = 'lorenz'; 182 | % traj_type = 'mg17'; 183 | % traj_type = 'mg30'; 184 | bridge_type = 'cubic'; 185 | 186 | time_infor.val_length=120000; 187 | 188 | [control_infor, output_infor, time_infor] = func_reservoir_validate(traj_type,... 189 | bridge_type, time_infor, input_infor, res_infor, properties, dim_in, dim_out, ... 
190 | Wout, dt, plot_movie, disturbance, measurement_noise); 191 | 192 | data_pred=output_infor.data_pred; 193 | q_pred=output_infor.q_pred; 194 | qdt_pred=output_infor.qdt_pred; 195 | q2dt_pred=output_infor.q2dt_pred; 196 | tau_pred=output_infor.tau_pred; 197 | 198 | q_control=control_infor.q_control; 199 | qdt_control=control_infor.qdt_control; 200 | q2dt_control=control_infor.q2dt_control; 201 | tau_control=control_infor.tau_control; 202 | % q_control_all=control_infor.q_control_all; 203 | % qdt_control_all=control_infor.qdt_control_all; 204 | data_control=control_infor.data_control; 205 | 206 | val_length=time_infor.val_length; 207 | 208 | % calculate rmse 209 | error = abs(data_control(1:rmse_length, :) - data_pred(1:rmse_length, :)); 210 | error = sum(weight.*mean(error.^2, 2)); 211 | 212 | error_set(di_id, mn_id) = error; 213 | end 214 | end 215 | 216 | nan_id = isnan(error_set); 217 | error_set_plot = error_set; 218 | error_set_plot(nan_id) = 3*max(error_set(:), [], 'omitnan'); 219 | 220 | save_data.disturbance_set_heat = disturbance_set; 221 | save_data.measurement_set_heat = measurement_set; 222 | save_data.heat_error_set = error_set; 223 | save_data.heat_error_set_plot = error_set_plot; 224 | 225 | save(['./save_data/heat_result_logscale_', time_today, '.mat'], "save_data") 226 | 227 | %% for plot 228 | 229 | figure(); 230 | surf(measurement_set, disturbance_set, error_set_plot); % rows of error_set index disturbance (Y), columns index measurement (X) 231 | xlabel('measurement noise, \sigma') 232 | ylabel('disturbance noise, \sigma') 233 | set(gca, 'YScale', 'log'); 234 | set(gca, 'XScale', 'log'); 235 | colorbar 236 | view(0, 90) 237 | title('error') 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | -------------------------------------------------------------------------------- /experiments/normal_success_rate.m: -------------------------------------------------------------------------------- 1 | clear all 2 | close all 3 | clc 4 | 5 | addpath('./tools/') 6 | 7 | %
load training data 8 | load('./save_file/all_traj_06282022.mat') 9 | time_today = datestr(now, 'mmddyyyy'); 10 | 11 | plot_movie = 0; 12 | if exist('traj_type','var') == 0 13 | traj_type = 'infty'; 14 | end 15 | 16 | traj_set = ["lorenz", "circle", "mg17", "infty", "astroid", "fermat", ... 17 | "lissajous", "talbot", "heart", "chua", "rossler", ... 18 | "sprott_1", "sprott_4", "mg30", "epitrochoid"]; 19 | % traj_set = ["lorenz"]; 20 | iteration = 100; 21 | 22 | bridge_type = 'cubic'; 23 | time_infor.val_length=250000; 24 | val_length_rm = time_infor.val_length; 25 | 26 | failure.type = 'none'; 27 | blur.last_time = 1000; 28 | blur.recover_time = time_infor.val_length; 29 | blur.blur = 0; 30 | 31 | failure.type = 'all'; 32 | failure.amplitude = 0.1; 33 | failure.amplitude_2 = 0.1; 34 | 35 | rmse_start_time = round(val_length_rm * 3/5); 36 | rmse_end_time = time_infor.val_length - 100; 37 | 38 | rmse_set = zeros(length(traj_set), iteration); 39 | 40 | idx=1; 41 | for traj_id = 1:length(traj_set) 42 | traj_type = traj_set(traj_id); 43 | if strcmp(traj_type, 'circle') == 1 44 | traj_frequency = 150; 45 | elseif strcmp(traj_type, 'infty') == 1 46 | traj_frequency = 75; 47 | end 48 | 49 | rmse_parfor_set = zeros(1, iteration); 50 | for repeat_i = 1:iteration 51 | load('./save_file/all_traj_06282022.mat') 52 | time_infor.val_length = val_length_rm; 53 | save_rend=0; 54 | val_and_update; 55 | rmse_parfor_set(repeat_i) = func_rmse(data_pred, data_control, rmse_start_time, rmse_end_time); 56 | aaa = 1; 57 | end 58 | rmse_parfor_set = sort(rmse_parfor_set); 59 | rmse_set(traj_id, :) = rmse_parfor_set; 60 | end 61 | 62 | save_success.rmse_set = rmse_set; 63 | save_success.val_length = time_infor.val_length; 64 | save_success.traj_set = traj_set; 65 | 66 | % save(['./save_data/save_saferegion_success_rate_noise_' time_today, '_' num2str(randi(999)) '.mat'], "save_success") 67 | 68 | %% 69 | % load('./save_data/save_normal_success_rate_08022022_92.mat') 70 | 71 | % rmse_set = 
save_success.rmse_set; 72 | 73 | rmse_threshold = 0.18; 74 | rmse_logic = zeros(size(rmse_set)); 75 | for i = 1:size(rmse_set, 1) 76 | for j = 1:size(rmse_set, 2) 77 | if rmse_set(i, j) > rmse_threshold || isnan(rmse_set(i, j)) % a NaN rmse (diverged run) must count as a failure 78 | rmse_logic(i, j) = 0; 79 | else 80 | rmse_logic(i, j) = 1; 81 | end 82 | end 83 | end 84 | 85 | rmse_count = mean(rmse_logic, 2); 86 | 87 | figure() 88 | plot(rmse_count, 'o') 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /experiments/speed_test.m: -------------------------------------------------------------------------------- 1 | % clear all 2 | % close all 3 | % clc 4 | 5 | addpath('./tools/') 6 | 7 | %% 8 | 9 | load('./save_file/all_traj_06282022.mat') 10 | 11 | time_today = datestr(now, 'mmddyyyy'); 12 | 13 | plot_val_and_update = 0; 14 | 15 | blur.blur = 0; 16 | plot_movie = 0; 17 | % traj_type is normally preset by the caller (e.g. speed_test_main.m); do not overwrite it here 18 | 19 | if exist('traj_type','var') == 0 20 | traj_type = 'infty'; 21 | end 22 | 23 | 24 | 25 | failure.type = 'all'; 26 | failure.amplitude = 0.1; 27 | failure.amplitude_2 = 0.1; 28 | 29 | bridge_type = 'cubic'; 30 | 31 | time_infor.val_length=250000; 32 | val_length_rm = time_infor.val_length; 33 | 34 | rmse_start_time = round(val_length_rm * 3/5); 35 | rmse_end_time = time_infor.val_length - 100; 36 | 37 | if strcmp(traj_type, 'circle') == 1 38 | frequency_set = round(linspace(10, 500, 15)); 39 | elseif strcmp(traj_type, 'infty') == 1 40 | frequency_set = round(linspace(10, 500, 15) / 2); 41 | else, error('speed_test:traj_type', 'speed_test only supports traj_type ''circle'' or ''infty''.'); end 42 | 43 | iteration = 50; 44 | 45 | failure.type = 'none'; 46 | 47 | rmse_set = zeros(length(frequency_set), iteration); 48 | 49 | idx=1; 50 | 51 | for f_idx = 1:length(frequency_set) 52 | rmse_parfor_set = zeros(1, iteration); 53 | for repeat_i = 1:iteration 54 | save_rend=0; 55 | load('./save_file/all_traj_06282022.mat') 56 | time_infor.val_length = val_length_rm; 57 | traj_frequency =
frequency_set(f_idx); 58 | 59 | val_and_update; 60 | 61 | rmse_parfor_set(repeat_i) = func_rmse(data_pred, data_control, rmse_start_time, rmse_end_time); 62 | aaa = 1; 63 | end 64 | rmse_parfor_set = sort(rmse_parfor_set); 65 | rmse_set(f_idx, :) = rmse_parfor_set; 66 | end 67 | 68 | 69 | save_speed_iter.traj_type = traj_type; 70 | 71 | save_speed_iter.(['frequency_set', num2str(idx)]) = frequency_set; 72 | save_speed_iter.(['rmse_set', num2str(idx)]) = rmse_set; 73 | 74 | save(['save_data/save_speed_iter_' traj_type, '_' time_today, '_' num2str(randi(999)) '.mat'], "save_speed_iter") 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /experiments/speed_test_main.m: -------------------------------------------------------------------------------- 1 | clear all 2 | close all 3 | clc 4 | 5 | time_today = datestr(now, 'mmddyyyy'); 6 | 7 | traj_type = 'circle'; 8 | speed_test 9 | 10 | traj_type = 'infty'; 11 | speed_test -------------------------------------------------------------------------------- /experiments/success_rate.m: -------------------------------------------------------------------------------- 1 | clear all 2 | close all 3 | clc 4 | 5 | addpath('./tools/') 6 | 7 | %% test for network size, training length, 8 | % reset time and noise level 9 | % please record the rmse and running time 10 | 11 | % network_set = round(linspace(20, 500, 10)); 12 | % training_length_set = round(linspace(20000, 200000, 10)); 13 | 14 | network_set = round(exp(linspace(log(20), log(250), 10))); 15 | training_length_set = round(exp(linspace(log(2000), log(150000), 10))); 16 | 17 | % network_set = fliplr(network_set); 18 | % training_length_set = fliplr(training_length_set); 19 | 20 | reset_t_set = round(linspace(10, 150, 10)); 21 | noise_level_set = 10 .^ linspace(-3, 0, 10); 22 | 
23 | time_today = datestr(now, 'mmddyyyy'); 24 | 25 | 26 | %% first to test the n_set and train_len_set 27 | 28 | reset_t = 80; 29 | noise_level = 2.0 * 10 ^ (-2); 30 | 31 | bias = 2.0; 32 | 33 | % delete(gcp('nocreate')) 34 | % parpool('local',10) 35 | 36 | iteration = 50; 37 | 38 | rmse_set_lorenz = zeros(length(network_set), length(training_length_set), iteration); 39 | rmse_set_circle = zeros(length(network_set), length(training_length_set), iteration); 40 | rmse_set_mg17 = zeros(length(network_set), length(training_length_set), iteration); 41 | rmse_set_infty = zeros(length(network_set), length(training_length_set), iteration); 42 | time_set = zeros(length(network_set), length(training_length_set), iteration); 43 | 44 | for ns = 1:length(network_set) 45 | n = network_set(ns); 46 | for tls = 1:length(training_length_set) 47 | train_t = training_length_set(tls); 48 | 49 | rmse_parfor_set_lorenz = zeros(1, iteration); 50 | rmse_parfor_set_circle = zeros(1, iteration); 51 | rmse_parfor_set_mg17 = zeros(1, iteration); 52 | rmse_parfor_set_infty = zeros(1, iteration); 53 | 54 | time_parfor_set = zeros(1, iteration); 55 | for repeat_i = 1:iteration 56 | [rmse_l, rmse_c, rmse_m, rmse_i, t_repeat_i] = func_train_val(n, train_t, reset_t, noise_level, bias); 57 | rmse_parfor_set_lorenz(repeat_i) = rmse_l; 58 | rmse_parfor_set_circle(repeat_i) = rmse_c; 59 | rmse_parfor_set_mg17(repeat_i) = rmse_m; 60 | rmse_parfor_set_infty(repeat_i) = rmse_i; 61 | 62 | time_parfor_set(repeat_i) = t_repeat_i; 63 | end 64 | 65 | rmse_set_lorenz(ns, tls, :) = rmse_parfor_set_lorenz; 66 | rmse_set_circle(ns, tls, :) = rmse_parfor_set_circle; 67 | rmse_set_mg17(ns, tls, :) = rmse_parfor_set_mg17; 68 | rmse_set_infty(ns, tls, :) = rmse_parfor_set_infty; 69 | 70 | time_set(ns, tls, :) = time_parfor_set; 71 | end 72 | end 73 | 74 | save_success_rate.reset_t = reset_t; 75 | save_success_rate.noise_level = noise_level; 76 | save_success_rate.network_set = network_set; 77 | 
save_success_rate.training_length_set = training_length_set; 78 | save_success_rate.rmse_set_lorenz = rmse_set_lorenz; 79 | save_success_rate.rmse_set_circle = rmse_set_circle; 80 | save_success_rate.rmse_set_mg17 = rmse_set_mg17; 81 | save_success_rate.rmse_set_infty = rmse_set_infty; 82 | save_success_rate.time_set = time_set; 83 | 84 | save(['save_data/save_success_rate_nt_', time_today, '_' num2str(randi(999)) '.mat'], 'save_success_rate') % the save_data folder must already exist 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | -------------------------------------------------------------------------------- /experiments/success_rate_1.m: -------------------------------------------------------------------------------- 1 | clear all 2 | close all 3 | clc 4 | 5 | addpath('./tools/') 6 | 7 | %% test for network size, training length, 8 | % reset time and noise level 9 | % please record the rmse and running time 10 | 11 | % network_set = round(exp(linspace(log(20), log(270), 10))); 12 | % training_length_set = round(exp(linspace(log(2000), log(150000), 10))); 13 | 14 | reset_t_set = round(linspace(10, 150, 10)); % 10 reset times, 10 to 150 15 | noise_level_set = 10 .^ linspace(-3, 0, 10); % 10 log-spaced noise levels, 1e-3 to 1 16 | 17 | time_today = datestr(now, 'mmddyyyy'); % date stamp used in the output file name 18 | 19 | 20 | %% first to test the n_set and train_len_set 21 | 22 | % reset_t = 80; 23 | % noise_level = 2.0 * 10 ^ (-2); 24 | n = 100; % reservoir size and training length are held fixed in this sweep 25 | train_t = 10000; 26 | 27 | bias = 2.0; 28 | 29 | % delete(gcp('nocreate')) 30 | % parpool('local',6) 31 | 32 | iteration = 50; % repeats per (reset time, noise level) pair 33 | 34 | rmse_set_lorenz = zeros(length(reset_t_set), length(noise_level_set), iteration); % indexed (reset time, noise level, repeat) 35 | rmse_set_circle = zeros(length(reset_t_set), length(noise_level_set), iteration); 36 | rmse_set_mg17 = zeros(length(reset_t_set), length(noise_level_set), iteration); 37 | rmse_set_infty =
zeros(length(reset_t_set), length(noise_level_set), iteration); 38 | time_set = zeros(length(reset_t_set), length(noise_level_set), iteration); 39 | 40 | for rts = 1:length(reset_t_set) 41 | reset_t = reset_t_set(rts); 42 | for nls = 1:length(noise_level_set) 43 | noise_level = noise_level_set(nls); 44 | 45 | rmse_parfor_set_lorenz = zeros(1, iteration); 46 | rmse_parfor_set_circle = zeros(1, iteration); 47 | rmse_parfor_set_mg17 = zeros(1, iteration); 48 | rmse_parfor_set_infty = zeros(1, iteration); 49 | time_parfor_set = zeros(1, iteration); 50 | 51 | for repeat_i = 1:iteration 52 | [rmse_l, rmse_c, rmse_m, rmse_i, t_repeat_i] = func_train_val(n, train_t, reset_t, noise_level, bias); % one RMSE per test trajectory (lorenz, circle, mg17, infty) plus a time value 53 | rmse_parfor_set_lorenz(repeat_i) = rmse_l; 54 | rmse_parfor_set_circle(repeat_i) = rmse_c; 55 | rmse_parfor_set_mg17(repeat_i) = rmse_m; 56 | rmse_parfor_set_infty(repeat_i) = rmse_i; 57 | 58 | time_parfor_set(repeat_i) = t_repeat_i; 59 | end 60 | rmse_set_lorenz(rts, nls, :) = rmse_parfor_set_lorenz; 61 | rmse_set_circle(rts, nls, :) = rmse_parfor_set_circle; 62 | rmse_set_mg17(rts, nls, :) = rmse_parfor_set_mg17; 63 | rmse_set_infty(rts, nls, :) = rmse_parfor_set_infty; 64 | 65 | time_set(rts, nls, :) = time_parfor_set; 66 | end 67 | end 68 | 69 | save_success_rate.reset_t_set = reset_t_set; 70 | save_success_rate.noise_level_set = noise_level_set; 71 | save_success_rate.n = n; 72 | save_success_rate.train_t = train_t; 73 | save_success_rate.rmse_set_lorenz = rmse_set_lorenz; 74 | save_success_rate.rmse_set_circle = rmse_set_circle; 75 | save_success_rate.rmse_set_mg17 = rmse_set_mg17; 76 | save_success_rate.rmse_set_infty = rmse_set_infty; 77 | save_success_rate.time_set = time_set; 78 | 79 | save(['save_data/save_success_rate_rn_', time_today, '_' num2str(randi(999)) '.mat'], 'save_success_rate') % the save_data folder must already exist 80 | 81 | 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /experiments/success_rate_bias.m:
-------------------------------------------------------------------------------- 1 | clear all 2 | close all 3 | clc 4 | 5 | addpath('./tools/') 6 | 7 | bias_set = linspace(0, 3, 7); % 7 bias values: 0, 0.5, ..., 3 8 | 9 | time_today = datestr(now, 'mmddyyyy'); % date stamp used in the output file name 10 | 11 | reset_t = 80; % reset time, noise level, reservoir size and training length are held fixed 12 | noise_level = 2.0 * 10 ^ (-2); 13 | n=200; 14 | train_t = 150000; 15 | 16 | iteration = 10; % repeats per bias value 17 | 18 | rmse_set_lorenz = zeros(length(bias_set), iteration); % indexed (bias, repeat) 19 | rmse_set_circle = zeros(length(bias_set), iteration); 20 | rmse_set_mg17 = zeros(length(bias_set), iteration); 21 | rmse_set_infty = zeros(length(bias_set), iteration); 22 | 23 | for bs = 1:length(bias_set) 24 | bias = bias_set(bs); 25 | rmse_parfor_set_lorenz = zeros(1, iteration); 26 | rmse_parfor_set_circle = zeros(1, iteration); 27 | rmse_parfor_set_mg17 = zeros(1, iteration); 28 | rmse_parfor_set_infty = zeros(1, iteration); 29 | 30 | for repeat_i = 1:iteration 31 | [rmse_l, rmse_c, rmse_m, rmse_i, t_repeat_i] = func_train_val(n, train_t, reset_t, noise_level, bias); % t_repeat_i is not recorded by this script 32 | rmse_parfor_set_lorenz(repeat_i) = rmse_l; 33 | rmse_parfor_set_circle(repeat_i) = rmse_c; 34 | rmse_parfor_set_mg17(repeat_i) = rmse_m; 35 | rmse_parfor_set_infty(repeat_i) = rmse_i; 36 | end 37 | 38 | rmse_set_lorenz(bs, :) = rmse_parfor_set_lorenz; 39 | rmse_set_circle(bs, :) = rmse_parfor_set_circle; 40 | rmse_set_mg17(bs, :) = rmse_parfor_set_mg17; 41 | rmse_set_infty(bs, :) = rmse_parfor_set_infty; 42 | end 43 | 44 | save_success_rate.reset_t = reset_t; 45 | save_success_rate.noise_level = noise_level; 46 | save_success_rate.n = n; 47 | save_success_rate.train_t = train_t; 48 | save_success_rate.bias_set = bias_set; 49 | save_success_rate.rmse_set_lorenz = rmse_set_lorenz; 50 | save_success_rate.rmse_set_circle = rmse_set_circle; 51 | save_success_rate.rmse_set_mg17 = rmse_set_mg17; 52 | save_success_rate.rmse_set_infty = rmse_set_infty; 53 | 54 | 55 | save(['save_data/save_success_rate_bias_', time_today, '_' num2str(randi(999)) '.mat'], 'save_success_rate') % the save_data folder must already exist 56 | 57 | 58 |
59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /experiments/temp: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /experiments/uncertain_test_saferegion.m: -------------------------------------------------------------------------------- 1 | % clear all 2 | % close all 3 | % clc 4 | 5 | addpath('./tools/') 6 | 7 | %% sweep the link lengths l1, l2 and record the tracking RMSE for every pair 8 | 9 | load('./save_file/all_traj_06282022.mat') 10 | % 11 | if exist('time_today','var') == 0, time_today = datestr(now, 'mmddyyyy'); end % normally set by uncertain_test_saferegion_main; default for standalone runs 12 | 13 | plot_val_and_update = 0; 14 | 15 | blur.blur = 0; 16 | plot_movie = 0; 17 | % traj_type = 'lorenz'; 18 | 19 | if exist('traj_type','var') == 0 20 | traj_type = 'infty'; 21 | end 22 | 23 | bridge_type = 'cubic'; 24 | 25 | time_infor.val_length=250000; 26 | val_length_rm = time_infor.val_length; 27 | 28 | rmse_start_time = round(val_length_rm * 3/5); % RMSE window: from 3/5 of the run to 100 steps before its end 29 | rmse_end_time = time_infor.val_length - 100; 30 | 31 | if strcmp(traj_type, 'circle') == 1 32 | traj_frequency = 150; 33 | elseif strcmp(traj_type, 'infty') == 1 34 | traj_frequency = 75; 35 | else 36 | traj_frequency = 75; 37 | end 38 | 39 | l1_set = linspace(0.5, 0.55, 10); 40 | l2_set = linspace(0.5, 0.55, 10); 41 | 42 | iteration = 50; 43 | 44 | failure.type = 'all'; 45 | failure.amplitude = 0.1; 46 | failure.amplitude_2 = 0.1; 47 | 48 | % length 49 | uncertain_type = 'l'; 50 | 51 | rmse_set = zeros(length(l1_set), length(l2_set), iteration); % indexed (l1, l2, repeat) 52 | 53 | idx=2; 54 | 55 | for l1_idx = 1:length(l1_set) 56 | for l2_idx = 1:length(l2_set) 57 | rmse_parfor_set = zeros(1, iteration); 58 | for repeat_i = 1:iteration 59 | save_rend=0; 60 | 61 | load('./save_file/all_traj_06282022.mat') % reload the saved model every repeat (presumably resets state modified by val_and_update) 62 | time_infor.val_length = val_length_rm; 63 | properties(3:4) = [l1_set(l1_idx), l2_set(l2_idx)]; % entries 3:4 of properties are the link lengths l1, l2 64 | 65 | val_and_update; 66
| 67 | rmse_parfor_set(repeat_i) = func_rmse(data_pred, data_control, rmse_start_time, rmse_end_time); 68 | aaa = 1; 69 | end 70 | rmse_parfor_set = sort(rmse_parfor_set); % repeats sorted in ascending order 71 | rmse_set(l1_idx, l2_idx, :) = rmse_parfor_set; 72 | end 73 | end 74 | 75 | save_uncertain_iter.traj_type = traj_type; 76 | 77 | save_uncertain_iter.(['type', num2str(idx)]) = uncertain_type; 78 | save_uncertain_iter.(['l1_set', num2str(idx)]) = l1_set; 79 | save_uncertain_iter.(['l2_set', num2str(idx)]) = l2_set; 80 | save_uncertain_iter.(['rmse_set', num2str(idx)]) = rmse_set; 81 | 82 | save(['save_data/save_uncertain_iter_' traj_type, '_' time_today, '_' num2str(randi(999)) '.mat'], "save_uncertain_iter") 83 | 84 | 85 | 86 | %% rmse set 87 | 88 | rmse_set_l = save_uncertain_iter.(['rmse_set', num2str(idx)]); % read back the field written above for this idx 89 | 90 | rmse_l_mean = mean(rmse_set_l, 3); % mean over repeats: rows index l1, columns index l2 91 | 92 | figure(); 93 | imagesc(l2_set, l1_set, rmse_l_mean); % imagesc maps columns to x and rows to y, so x is l2 and y is l1 94 | colorbar 95 | 96 | l1_set_plot = (l1_set - 0.5)/ 0.5; % relative change w.r.t. the nominal length 0.5 97 | l2_set_plot = (l2_set - 0.5)/ 0.5; 98 | % rows are not flipped: axis xy below puts l1 increasing upward while keeping the tick labels consistent with the data 99 | 100 | figure(); 101 | imagesc(l2_set_plot, l1_set_plot, rmse_l_mean); axis xy 102 | colorbar 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | -------------------------------------------------------------------------------- /experiments/uncertain_test_saferegion_main.m: -------------------------------------------------------------------------------- 1 | clear all 2 | close all 3 | clc 4 | 5 | time_today = datestr(now, 'mmddyyyy'); 6 | 7 | % traj_type = 'lorenz'; 8 | % uncertain_test_saferegion 9 | 10 | % traj_type = 'circle'; 11 | % uncertain_test_saferegion 12 | 13 | % traj_type = 'mg17'; 14 | % uncertain_test_saferegion 15 | 16 | traj_type = 'infty'; 17 | uncertain_test_saferegion 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 |
-------------------------------------------------------------------------------- /func_desired_traj.m: -------------------------------------------------------------------------------- 1 | function [control_infor, time_infor] = func_desired_traj(traj_type, bridge_type, time_infor, control_infor, properties, dt, plot_movie, traj_frequency) % Build the joint-space reference of the two-link arm for the end-effector path traj_type 2 | % A bridge from the current arm pose to the path is prepended; q/qdt/q2dt/tau_control are filled and the resulting length is written to time_infor.val_length. 3 | plot_figures_inside = 0; % set to 1 to show diagnostic plots 4 | % traj_frequency is only used by the 'infty' and 'circle' paths; an unknown traj_type or bridge_type raises an error. 5 | val_length=time_infor.val_length; 6 | train_length=time_infor.train_length; 7 | % Get the two-arms property 8 | [m1, m2, l1, l2, lc1, lc2, I1, I2] = matsplit(properties); 9 | % Get the initial information 10 | q_control=control_infor.q_control; 11 | qdt_control=control_infor.qdt_control; 12 | q2dt_control=control_infor.q2dt_control; 13 | tau_control=control_infor.tau_control; 14 | 15 | value_q=q_control(1,:); % initial joint angles 16 | x_start=l1*cos(value_q(1))+l2*cos(value_q(1)+value_q(2)); % initial end-effector position 17 | y_start=l1*sin(value_q(1))+l2*sin(value_q(1)+value_q(2)); 18 | 19 | t=0:dt:val_length; % time grid; every path below is cut to its first 2*val_length+1 samples 20 | 21 | %% Different traj_types, the readers can add their own trajectories 22 | if strcmp(traj_type, 'infty')==1 23 | x = 0.25*sin(2*pi*t*1/(2*traj_frequency)); 24 | y = 0.15*sin(2*pi*t*1/traj_frequency); 25 | 26 | x = x(1:2*val_length+1); 27 | y = y(1:2*val_length+1); 28 | elseif strcmp(traj_type, 'circle') 29 | x = 0.5 * cos(2 * pi * t * 1 / traj_frequency); 30 | y = 0.5 * sin(2 * pi * t * 1 / traj_frequency); 31 | 32 | x = x(1:2*val_length+1); 33 | y = y(1:2*val_length+1); 34 | elseif strcmp(traj_type, 'astroid') 35 | x = 0.4 * (cos(2 * pi * t * 1 / 250).^3); 36 | y = 0.4 * (sin(2 * pi * t * 1 / 250).^3); 37 | 38 | x = x(1:2*val_length+1); 39 | y = y(1:2*val_length+1); 40 | 41 | elseif strcmp(traj_type, 'heart') 42 | f = 250; 43 | a = 0.4; 44 | theta = 2 * pi * t * 1 / f; 45 | 46 | r = 1 - sin(theta); 47 | 48 | x = r.*cos(theta); 49 | y = r.*sin(theta); 50 | 51 | x = normalize(x, 'range', [-a, a]); 52 | y = normalize(y, 'range', [-a, a]); 53 | 54 | x = x(1:2*val_length+1); 55 | y = y(1:2*val_length+1); 56 | 57 | elseif
strcmp(traj_type, 'epitrochoid') 58 | a = 5; 59 | b = 3; 60 | c = 5; 61 | 62 | f = 200; 63 | 64 | x = (a + b) * cos(2 * pi * t * 1 / f) - c * cos((a/b + 1) * 2 * pi * t * 1 / f); 65 | y = (a + b) * sin(2 * pi * t * 1 / f) - c * sin((a/b + 1) * 2 * pi * t * 1 / f); 66 | 67 | x = normalize(x, 'range', [-0.4, 0.4]); 68 | y = normalize(y, 'range', [-0.4, 0.4]); 69 | 70 | x = x(1:2*val_length+1); 71 | y = y(1:2*val_length+1); 72 | 73 | elseif strcmp(traj_type, 'fermat') 74 | f = 100; 75 | a = 0.5; 76 | theta = 2 * pi * t * 1 / f; 77 | 78 | r = sqrt(a^2 * theta); 79 | 80 | x = r.*cos(theta); 81 | y = r.*sin(theta); 82 | 83 | x = normalize(x, 'range', [-4, 4]); % NOTE(review): [-4, 4] lies far outside the reach l1+l2 (1.0 for l1 = l2 = 0.5), so the acos in the inverse kinematics would return complex values; confirm the intended range 84 | y = normalize(y, 'range', [-4, 4]); 85 | 86 | x = x(1:2*val_length+1); 87 | y = y(1:2*val_length+1); 88 | 89 | elseif strcmp(traj_type, 'lissajous') 90 | f = 300; 91 | a = 1; 92 | b = 3; 93 | A = 1; 94 | B = 1; 95 | 96 | x = A * sin(a * t * 2 * pi / f + pi / 4); 97 | y = B * sin(b * t * 2 * pi / f) ; 98 | 99 | xx = normalize(x, 'range', [-0.3, 0.3]); 100 | yy = normalize(y, 'range', [-0.3, 0.3]); 101 | 102 | x = yy(1:2*val_length+1); 103 | y = xx(1:2*val_length+1); 104 | 105 | elseif strcmp(traj_type, 'talbot') 106 | f = 400; 107 | a = 1.1; 108 | b = 0.666; 109 | ff = 1; 110 | 111 | x = (a^2 + ff^2 .* sin(t * 2 * pi / f).^2) .* cos(t * 2 * pi / f) / a; 112 | y = (a^2 - 2*ff^2 + ff^2 .* sin(t * 2 * pi / f).^2) .* sin(t * 2 * pi / f) / b; 113 | 114 | x = normalize(x, 'range', [-0.4, 0.4]); 115 | y = normalize(y, 'range', [-0.3, 0.3]); 116 | 117 | x = x(1:2*val_length+1); 118 | y = y(1:2*val_length+1); 119 | 120 | elseif strcmp(traj_type, 'lorenz') 121 | load('./read_data/lorenz.mat') 122 | lorenz_xy=ts_train(1000:1000+round(val_length*2.1), 1:2); 123 | 124 | lorenz_x = normalize(lorenz_xy(:, 1), 'range', [-0.5, 0.5]); 125 | lorenz_y = normalize(lorenz_xy(:, 2), 'range', [-0.5, 0.5]); 126 | 127 | x = lorenz_x(1:val_length*2+1, :)'; 128 | y = lorenz_y(1:val_length*2+1, :)'; 129 | 130 | y = y - 0.3; 131 | 132 |
elseif strcmp(traj_type, 'chua') 133 | load('./read_data/chua.mat') 134 | chua_xy = ts_train(10000:10000+round(val_length*2.1), 1:2); 135 | 136 | chua_x = normalize(chua_xy(:, 1), 'range', [-0.5, 0.5]); 137 | chua_y = normalize(chua_xy(:, 2), 'range', [-0.3, 0.3]); 138 | 139 | chua_x = interp1(1:length(chua_x), chua_x, 0:0.3:length(chua_x), 'pchip'); 140 | chua_y = interp1(1:length(chua_y), chua_y, 0:0.3:length(chua_y), 'pchip'); 141 | 142 | x = chua_x(1:val_length*2+1)'; 143 | y = chua_y(1:val_length*2+1)'; 144 | 145 | x = reshape(x, 1, []); 146 | y = reshape(y, 1, []); 147 | 148 | elseif strcmp(traj_type, 'rossler') 149 | load('./read_data/rossler.mat') 150 | rossler_xy = ts_train(10000:10000+round(val_length*2.1), 1:2); 151 | 152 | rossler_x = normalize(rossler_xy(:, 1), 'range', [-0.35, 0.35]); 153 | rossler_y = normalize(rossler_xy(:, 2), 'range', [-0.35, 0.35]); 154 | 155 | rossler_x = interp1(1:length(rossler_x), rossler_x, 0:0.55:length(rossler_x), 'pchip'); 156 | rossler_y = interp1(1:length(rossler_y), rossler_y, 0:0.55:length(rossler_y), 'pchip'); 157 | 158 | x = rossler_x(1:val_length*2+1)'; 159 | y = rossler_y(1:val_length*2+1)'; 160 | 161 | x = reshape(x, 1, []); 162 | y = reshape(y, 1, []); 163 | 164 | elseif strcmp(traj_type, 'sprott_1') 165 | load('./read_data/sprott_1.mat') 166 | sprott_xy = ts_train(10000:10000+round(val_length*2.1), 1:2); 167 | 168 | sprott_x = normalize(sprott_xy(:, 1), 'range', [-0.4, 0.4]); 169 | sprott_y = normalize(sprott_xy(:, 2), 'range', [-0.4, 0.4]); 170 | 171 | sprott_x = interp1(1:length(sprott_x), sprott_x, 0:0.6:length(sprott_x), 'pchip'); 172 | sprott_y = interp1(1:length(sprott_y), sprott_y, 0:0.6:length(sprott_y), 'pchip'); 173 | 174 | x = sprott_x(1:val_length*2+1)'; 175 | y = sprott_y(1:val_length*2+1)'; 176 | 177 | x = reshape(x, 1, []); 178 | y = reshape(y, 1, []); 179 | 180 | elseif strcmp(traj_type, 'sprott_4') 181 | load('./read_data/sprott_4.mat') 182 | sprott_xy =
ts_train(10000:10000+round(val_length*2.1), 1:2); 183 | 184 | sprott_x = normalize(sprott_xy(:, 1), 'range', [-0.4, 0.4]); 185 | sprott_y = normalize(sprott_xy(:, 2), 'range', [-0.4, 0.4]); 186 | 187 | sprott_x = interp1(1:length(sprott_x), sprott_x, 0:0.6:length(sprott_x), 'pchip'); 188 | sprott_y = interp1(1:length(sprott_y), sprott_y, 0:0.6:length(sprott_y), 'pchip'); 189 | 190 | x = sprott_x(1:val_length*2+1)'; 191 | y = sprott_y(1:val_length*2+1)'; 192 | 193 | x = reshape(x, 1, []); 194 | y = reshape(y, 1, []); 195 | 196 | 197 | elseif strcmp(traj_type, 'mg17') 198 | load('./read_data/MG17.mat') 199 | mg=ts_train(10000:10000+round(val_length*2.1)); 200 | 201 | mg_x = normalize(mg(10000-1700:end-1700), 'range', [-0.5, 0.5]); 202 | mg_y = normalize(mg(10000:end), 'range', [-0.5, 0.5]); 203 | 204 | mg_x = interp1(1:length(mg_x), mg_x, 0:0.25:length(mg_x), 'pchip'); 205 | mg_y = interp1(1:length(mg_y), mg_y, 0:0.25:length(mg_y), 'pchip'); 206 | 207 | x = mg_x(1:val_length*2+1)'; 208 | y = mg_y(1:val_length*2+1)'; 209 | 210 | x = reshape(x, 1, []); 211 | y = reshape(y, 1, []); 212 | 213 | x = x - 0.1; 214 | y = y - 0.1; 215 | elseif strcmp(traj_type, 'mg30') 216 | load('./read_data/MG30.mat') 217 | mg=ts_train(10000:10000+round(val_length*2.1)); 218 | 219 | mg_x = normalize(mg(10000:end), 'range', [-0.4, 0.4]); 220 | mg_y = normalize(mg(10000-3000:end-3000), 'range', [-0.4, 0.4]); 221 | 222 | mg_x = interp1(1:length(mg_x), mg_x, 0:0.25:length(mg_x), 'pchip'); 223 | mg_y = interp1(1:length(mg_y), mg_y, 0:0.25:length(mg_y), 'pchip'); 224 | 225 | x = mg_x(1:val_length*2+1)'; 226 | y = mg_y(1:val_length*2+1)'; 227 | 228 | x = reshape(x, 1, []); 229 | y = reshape(y, 1, []); 230 | 231 | x = x + 0.25; 232 | y = y; 233 | 234 | elseif strcmp(traj_type, 'lorenz96') 235 | load('./read_data/lorenz96.mat') 236 | 237 | lorenz_highD = ts_train(10000:10000+round(val_length*2.1), :); 238 | 239 | lorenz_x = normalize(lorenz_highD(:, 1), 'range', [-0.6, 0.6]); 240 | lorenz_y =
normalize(lorenz_highD(:, 2), 'range', [-0.7, 0.7]); 241 | 242 | x = lorenz_x(1:val_length*2+1)'; 243 | y = lorenz_y(1:val_length*2+1)'; 244 | 245 | x = reshape(x, 1, []); 246 | y = reshape(y, 1, []); 247 | 248 | x = x + 0.1; 249 | y = y; 250 | else 251 | error('func_desired_traj:unknownTrajType', 'Unknown traj_type: %s', traj_type) % fail here instead of later on the undefined x 252 | end 253 | 254 | %% Build bridge to the reference trajectory 255 | add_id=1; 256 | closet_value = Inf; 257 | for i =1:min(100000, val_length) % nearest path point (within the first 100000 samples) to the current end-effector position, skipping points at the origin 258 | dis_value = sqrt((x_start-x(i)).^2+(y_start-y(i)).^2); 259 | if dis_value < closet_value 260 | if round(x(i), 6)==0 && round(y(i), 6)==0 261 | continue 262 | end 263 | closet_value = dis_value; 264 | add_id=i; 265 | end 266 | end 267 | bridge_point = [x(add_id), y(add_id)]; 268 | bridge_len = sqrt((bridge_point(1)-x_start).^2 + (bridge_point(2)-y_start).^2); 269 | bridge_time = round(bridge_len * 1 / dt); % duration of the bridge (end time of the cubic segment below) 270 | % We provide two methods to build the bridge: cubic and linear 271 | if strcmp(bridge_type, 'cubic') == 1 272 | t_bg = 0:dt:bridge_time; 273 | theta0=q_control(1,:); 274 | theta0_dot=qdt_control(1,:); 275 | theta0=mod(theta0, 2*pi); 276 | if theta0(2) > pi 277 | theta0(2) = theta0(2)-2*pi; 278 | end 279 | q2_bg(1)=acos((x(add_id).^2+y(add_id).^2-l1^2-l2^2)/(2*l1*l2)); % elbow angle at the bridge target; the next sample gives the target joint velocity 280 | q2_bg(2)=acos((x(add_id+1).^2+y(add_id+1).^2-l1^2-l2^2)/(2*l1*l2)); 281 | if theta0(2)<0 282 | q2_bg=-q2_bg; 283 | end 284 | 285 | q1_bg(1)=atan(y(add_id)./x(add_id))-atan(l2*sin(q2_bg(1))./(l1+l2*cos(q2_bg(1)))); 286 | q1_bg(2)=atan(y(add_id+1)./x(add_id+1))-atan(l2*sin(q2_bg(2))./(l1+l2*cos(q2_bg(2)))); 287 | 288 | for ij = 1:length(q1_bg) % quadrant correction of atan 289 | if x(add_id + ij-1) < 0 && y(add_id + ij-1) > 0 290 | q1_bg(ij)=q1_bg(ij)+pi; 291 | elseif x(add_id + ij-1) < 0 && y(add_id + ij-1) < 0 292 | q1_bg(ij)=q1_bg(ij)+pi; 293 | elseif x(add_id + ij-1) > 0 && y(add_id + ij-1) < 0 294 | q1_bg(ij)=q1_bg(ij)+2*pi; 295 | end 296 | end 297 | 298 | theta1=[q1_bg(1), q2_bg(1)]; 299 | 300 | x_back_test = l1*cos(theta1(1)) + l2*cos(theta1(1)+theta1(2)); 301 |
y_back_test = l1*sin(theta1(1)) + l2*sin(theta1(1)+theta1(2)); 302 | 303 | theta1_dot=[(q1_bg(2)-q1_bg(1))/dt, (q2_bg(2)-q2_bg(1))/dt]; 304 | 305 | a0=theta0; % cubic in time matching joint angles and velocities at t = 0 and t = bridge_time 306 | a1=theta0_dot; 307 | a2=3.*(theta1-theta0)./(bridge_time^2)-2.*theta0_dot./bridge_time-theta1_dot/bridge_time; 308 | a3=-2.*(theta1-theta0)./(bridge_time^3)+(theta1_dot+theta0_dot)/(bridge_time^2); 309 | 310 | q1_bridge=a0(1) + a1(1).*t_bg + a2(1).*(t_bg.^2) + a3(1).*(t_bg.^3); 311 | q2_bridge=a0(2) + a1(2).*t_bg + a2(2).*(t_bg.^2) + a3(2).*(t_bg.^3); 312 | 313 | truePosition=zeros(2, length(q1_bridge)); % 2-by-N: rows are x and y 314 | truePosition(1,:)=l1*cos(q1_bridge)+l2*cos(q1_bridge+q2_bridge); 315 | truePosition(2,:)=l1*sin(q1_bridge)+l2*sin(q1_bridge+q2_bridge); 316 | 317 | x=[truePosition(1, 2:end-1), reshape(x(add_id:end), 1, [])]; 318 | y=[truePosition(2, 2:end-1), reshape(y(add_id:end), 1, [])]; 319 | elseif strcmp(bridge_type, 'linear') == 1 320 | bridge = waypointTrajectory([x_start, y_start, 0; bridge_point(1), bridge_point(2), 0], 'TimeOfArrival', [0, bridge_time], 'SampleRate', round(bridge_time/dt)); 321 | 322 | truePosition = zeros(bridge.SampleRate * bridge.TimeOfArrival(end)-1, 3); % N-by-3: columns are x, y, z 323 | count=1; 324 | while ~isDone(bridge) 325 | truePosition(count, :)=bridge(); 326 | count=count+1; 327 | end 328 | 329 | x=[truePosition(2:end-1, 1)', x(add_id:end)]; 330 | y=[truePosition(2:end-1, 2)', y(add_id:end)]; 331 | else 332 | error('func_desired_traj:unknownBridgeType', 'Unknown bridge_type: %s', bridge_type) % fail here instead of later on the undefined truePosition 333 | end 334 | 335 | val_length = val_length + length(truePosition(2:end-1, 1)); % NOTE(review): for the cubic bridge truePosition is 2-by-N, so this adds 0 instead of the N-2 prepended samples (the N-by-3 linear bridge is counted); confirm which is intended 336 | 337 | %% Desired reference 338 | q2=acos((x.^2+y.^2-l1^2-l2^2)/(2*l1*l2)); % inverse kinematics: elbow angle 339 | 340 | symb = 1; 341 | for ij = 2:length(x)-1 % flip the elbow sign after q2 hits exactly 0 or +/-pi while x keeps moving in the same direction 342 | q2(ij) = symb * q2(ij); 343 | if q2(ij) == pi || q2(ij) == -pi || q2(ij) == 0 344 | if sign(x(ij+1) - x(ij)) == sign(x(ij) - x(ij-1)) 345 | symb = -symb; 346 | end 347 | end 348 | end 349 | 350 | q1=atan(y./x)-atan(l2*sin(q2)./(l1+l2*cos(q2))); % inverse kinematics: shoulder angle 351 | 352 | x1 = l1*cos(q1(1)); 353 | if round(x1, 2) ~= round(l1*cos(value_q(1)), 2) % start does not match the initial arm pose: use the other elbow solution and correct the quadrants 354 |
symbol_change_record = 1; 355 | q2 = -q2; 356 | q1=atan(y./x)-atan(l2*sin(q2)./(l1+l2*cos(q2))); 357 | for ij = 1:length(q1) 358 | if x(ij) < 0 && y(ij) > 0 359 | q1(ij)=q1(ij)+pi; 360 | elseif x(ij) < 0 && y(ij) < 0 361 | q1(ij)=q1(ij)+pi; 362 | elseif x(ij) > 0 && y(ij) < 0 363 | q1(ij)=q1(ij)+2*pi; 364 | end 365 | end 366 | 367 | for ij = 2:length(q1) % remove jumps of about +/-pi and beyond +/-pi between consecutive samples 368 | if q1(ij)-q1(ij-1) < pi+0.1 && q1(ij)-q1(ij-1) > pi-0.1 % jump of about +pi (mirror of the -pi check below) 369 | q1(ij)=q1(ij)-pi; 370 | end 371 | if q1(ij)-q1(ij-1)<-pi+0.1 && q1(ij)-q1(ij-1)>-pi-0.1 372 | q1(ij)=q1(ij)+pi; 373 | end 374 | if q1(ij)-q1(ij-1)>pi 375 | q1(ij:end)=q1(ij:end)-2*pi; 376 | end 377 | if q1(ij)-q1(ij-1)<-pi 378 | q1(ij:end)=q1(ij:end)+2*pi; 379 | end 380 | if q2(ij)-q2(ij-1)>pi 381 | q2(ij:end)=q2(ij:end)-2*pi; 382 | end 383 | if q2(ij)-q2(ij-1)<-pi 384 | q2(ij:end)=q2(ij:end)+2*pi; 385 | end 386 | end 387 | end 388 | 389 | wave_qdt=[(q1(2:val_length+1)-q1(1:val_length))', (q2(2:val_length+1)-q2(1:val_length))']; % joint velocities by forward differences 390 | wave_qdt=wave_qdt./dt; 391 | 392 | q_control(2:val_length+1, :)=[q1(1:val_length)', q2(1:val_length)']; 393 | 394 | qdt_control(2:val_length+1, :)=wave_qdt(1:val_length,:); 395 | 396 | q2dt_control(1:val_length, :)=...
397 | (qdt_control(2:val_length+1,:)-qdt_control(1:val_length,:))/dt; 398 | 399 | for ii = 1:val_length % inverse dynamics (no gravity term): tau = H(q)*q2dt + velocity-product terms 400 | H11=m1*lc1^2+I1+m2*(l1^2+lc2^2+2*l1*lc2*cos(q_control(ii,2)))+I2; 401 | H12=m2*l1*lc2*cos(q_control(ii,2))+m2*lc2^2+I2; 402 | H21=H12; 403 | H22=m2*lc2^2+I2; 404 | h=m2*l1*lc2*sin(q_control(ii,2)); 405 | 406 | part_1=-h*qdt_control(ii,2)*qdt_control(ii,1)-h*(qdt_control(ii,1)+qdt_control(ii,2))*qdt_control(ii,2); 407 | part_2=h*qdt_control(ii,1)*qdt_control(ii,1); 408 | 409 | tau_control(ii,1)=H11*q2dt_control(ii,1)+H12*q2dt_control(ii,2)+part_1; 410 | tau_control(ii,2)=H21*q2dt_control(ii,1)+H22*q2dt_control(ii,2)+part_2; 411 | end 412 | 413 | % plot_figures_inside = 1; 414 | % if we need to visualization 415 | if plot_figures_inside == 1 416 | figure(); 417 | hold on 418 | plot(tau_control(2:val_length, 1)) 419 | plot(tau_control(2:val_length, 2)) 420 | ylabel('tau') 421 | end 422 | 423 | if plot_figures_inside == 1 424 | figure(); 425 | hold on 426 | plot(qdt_control(1:val_length-1, 1), 'r') 427 | plot(qdt_control(1:val_length-1, 2), 'b') 428 | xlabel('time step') 429 | ylabel('dq/dt(control)') 430 | legend('dq/dt(1)', 'dq/dt(2)') 431 | end 432 | 433 | x_control=l1*cos(q_control(:,1))+l2*cos(q_control(:,1)+q_control(:,2)); 434 | y_control=l1*sin(q_control(:,1))+l2*sin(q_control(:,1)+q_control(:,2)); 435 | 436 | if plot_figures_inside == 1 437 | figure(); 438 | hold on 439 | plot(x(1:val_length), y(1:val_length), 'r', 'LineWidth', 1) 440 | plot(x_control(1:val_length), y_control(1:val_length), 'b--', 'LineWidth', 2) 441 | xlim([-1, 1]) 442 | ylim([-1, 1]) 443 | legend('real', 'desired') 444 | end 445 | 446 | start_step = train_length -1000; 447 | movie_step = 100; 448 | time_all = train_length + 10000; 449 | 450 | line_prop='solid'; 451 | % plot animation of the tracking control 452 | if plot_movie == 1 453 | func_plot_movie(start_step, movie_step, time_all, q_control(:, 1), q_control(:, 2), properties, line_prop) % uses the local q_control (q_control_all is not defined in this function) 454 | end 455 | 456 |
control_infor.q_control=q_control; 457 | control_infor.qdt_control=qdt_control; 458 | control_infor.q2dt_control=q2dt_control; 459 | control_infor.tau_control=tau_control; 460 | 461 | time_infor.val_length=val_length; % length of the reference actually built 462 | 463 | end 464 | 465 | -------------------------------------------------------------------------------- /func_double_arm.m: -------------------------------------------------------------------------------- 1 | function [] = func_double_arm() % Train one reservoir controller, test it on the lorenz, circle, mg17 and infty paths, and save it with the mean RMSE 2 | 3 | rng('shuffle') 4 | 5 | time_today = datestr(now, 'mmddyyyy'); 6 | 7 | dt=0.01; 8 | input_infor={'xy', 'qdt'}; 9 | 10 | % in-out dimension 11 | dim_in=length(input_infor) * 4; 12 | dim_out=2; 13 | 14 | % double robot arm properties 15 | m1=1;m2=1; 16 | l1=0.5;l2=0.5; 17 | lc1=0.25;lc2=0.25; 18 | I1=0.03;I2=0.03; 19 | properties=[m1, m2, l1, l2, lc1, lc2, I1, I2]; 20 | 21 | reset_t=80; % 70 to get good results for infty symbol, and noise level:0.8*0.5*10^(-2) 22 | train_t = 200000; 23 | val_t=500; 24 | 25 | noise_level=2.0*10^(-2); 26 | disturbance = 0.00; 27 | measurement_noise = 0.00; 28 | 29 | n=200; 30 | hyperpara_set = [0.756250, 0.756250, 0.843750, -3.125, 106.71875, 2.0]; 31 | eig_rho = hyperpara_set(1); 32 | W_in_a = hyperpara_set(2); 33 | alpha = hyperpara_set(3); 34 | beta = 10^hyperpara_set(4); 35 | k = round( hyperpara_set(5)/200*n); % approximate nonzeros per row of the reservoir matrix, scaled with n 36 | kb = hyperpara_set(6); 37 | 38 | W_in = W_in_a*(2*rand(n,dim_in)-1); 39 | res_net=sprandsym(n,k/n); 40 | eig_D=eigs(res_net,1); 41 | res_net=(eig_rho/(abs(eig_D))).*res_net; % rescale to spectral radius eig_rho 42 | res_net=full(res_net); 43 | 44 | section_len=round(reset_t/dt); % 30 45 | washup_length=round(1002/dt); 46 | train_length=round(train_t/dt)+round(5/dt); % 50000 47 | val_length=round(val_t/dt); % 300 48 | time_length = train_length + 2 * val_length + 3 * washup_length + 100; 49 | 50 | res_infor=struct('W_in', W_in, 'res_net', res_net, 'alpha', alpha, 'kb', kb, 'beta', beta, 'n', n); 51 | time_infor=struct('section_len', section_len, 'washup_length', washup_length, ...
52 | 'train_length', train_length, 'val_length', val_length, 'time_length', time_length); 53 | 54 | % generate training and validation data 55 | [xy, q, qdt, q2dt, tau] = robot_data_generator(time_infor, noise_level, dt, properties); 56 | xy=xy(washup_length:end, :); 57 | q=q(washup_length:end, :); 58 | qdt=qdt(washup_length:end, :); 59 | q2dt=q2dt(washup_length:end, :); 60 | tau=tau(washup_length:end, :); 61 | data_reservoir = struct('xy', xy, 'q', q, 'qdt', qdt, 'q2dt', q2dt, 'tau', tau); 62 | 63 | clearvars xy q qdt q2dt tau 64 | 65 | tic; 66 | [Wout, r_end] = func_reservoir_train(data_reservoir, time_infor, input_infor, res_infor, dim_in, dim_out); 67 | toc; 68 | 69 | clearvars data_reservoir 70 | 71 | rng('shuffle') 72 | 73 | failure.type = 'none'; 74 | blur.blur = 0; 75 | 76 | plot_val_and_update=0; 77 | 78 | disturbance = 0.00; 79 | measurement_noise = 0.00; 80 | 81 | rmse_start_time = 200000; 82 | rmse_end_time = 300000-100; 83 | 84 | plot_movie = 0; 85 | traj_type = 'lorenz'; 86 | bridge_type = 'cubic'; 87 | % NOTE(review): traj_frequency is not set in this function, although other callers of val_and_update (e.g. experiments/uncertain_test_saferegion.m) set it; confirm val_and_update provides it 88 | time_infor.val_length=300000; 89 | 90 | save_rend=0; 91 | idx=1; 92 | 93 | val_and_update; 94 | 95 | rmse_1 = func_rmse(data_pred, data_control, rmse_start_time, rmse_end_time); 96 | 97 | % idx = 2 98 | traj_type = 'circle'; 99 | bridge_type = 'cubic'; 100 | 101 | time_infor.val_length=300000; 102 | save_rend=1; 103 | idx=2; 104 | 105 | val_and_update; 106 | 107 | rmse_2 = func_rmse(data_pred, data_control, rmse_start_time, rmse_end_time); 108 | 109 | % idx = 3 110 | traj_type = 'mg17'; 111 | bridge_type = 'cubic'; 112 | 113 | time_infor.val_length=300000; 114 | save_rend=1; 115 | idx=3; 116 | 117 | val_and_update; 118 | 119 | rmse_3 = func_rmse(data_pred, data_control, rmse_start_time, rmse_end_time); 120 | 121 | % idx = 4 122 | traj_type = 'infty'; 123 | bridge_type = 'cubic'; 124 | 125 | time_infor.val_length=300000; 126 | save_rend=1; 127 | idx=4; 128 | 129 | val_and_update; 130 | 131 | rmse_4 = func_rmse(data_pred, data_control,
rmse_start_time, rmse_end_time); 132 | 133 | a_rmse = (rmse_1 + rmse_2 + rmse_3 + rmse_4) / 4; % mean RMSE over the four test trajectories 134 | 135 | save(['./choose_file/all_traj_', time_today, '_' num2str(randi(9999)) '_' num2str(randi(9999)) '.mat'], 'time_infor', 'input_infor', 'res_infor', 'properties', 'dim_in', 'dim_out', 'Wout', 'r_end', 'dt', 'reset_t', 'noise_level', 'a_rmse') 136 | 137 | end 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | -------------------------------------------------------------------------------- /func_reservoir_train.m: -------------------------------------------------------------------------------- 1 | function [Wout, r_end] = func_reservoir_train(data_reservoir, time_infor, input_infor, res_infor, dim_in, dim_out) % Drive the reservoir with the training inputs and fit the readout Wout by ridge regression; r_end is the last reservoir state 2 | rng('shuffle'); 3 | % train the reservoir computing by given the input and output. 4 | xy=data_reservoir.xy; 5 | q=data_reservoir.q; 6 | qdt=data_reservoir.qdt; 7 | q2dt=data_reservoir.q2dt; 8 | tau=data_reservoir.tau; 9 | 10 | washup_length=time_infor.washup_length; 11 | train_length=time_infor.train_length; % 50000 12 | 13 | W_in=res_infor.W_in; 14 | res_net=res_infor.res_net; 15 | alpha=res_infor.alpha; 16 | kb=res_infor.kb; 17 | beta=res_infor.beta; 18 | n=res_infor.n; 19 | 20 | r_train=zeros(n,train_length-washup_length); 21 | y_train=zeros(dim_out,train_length-washup_length); 22 | 23 | train_x=zeros(train_length,dim_in); 24 | train_y=zeros(train_length,dim_out); 25 | if length(input_infor)==2 && strcmp(input_infor(1), 'xy') == 1 && strcmp(input_infor(2), 'qdt') == 1 26 | train_x(:,:)=[xy(1:train_length,:),xy(2:train_length+1,:), qdt(1:train_length,:),qdt(2:train_length+1,:)]; 27 | elseif length(input_infor)==1 && strcmp(input_infor(1), 'q') == 1 28 | train_x(:,:)=[q(1:train_length,:),q(2:train_length+1,:)]; 29 | elseif length(input_infor)==2 && strcmp(input_infor(1), 'q') == 1 && strcmp(input_infor(2), 'qdt') == 1 30 | train_x(:,:)=[q(1:train_length,:),q(2:train_length+1,:), qdt(1:train_length,:),qdt(2:train_length+1,:)]; 31 | elseif
length(input_infor)==3 && strcmp(input_infor(1), 'xy') == 1 && strcmp(input_infor(2), 'qdt') == 1 && strcmp(input_infor(3), 'q2dt') == 1 32 | train_x(:,:)=[q(1:train_length,:),q(2:train_length+1,:), qdt(1:train_length,:),qdt(2:train_length+1,:), q2dt(1:train_length,:),q2dt(2:train_length+1,:)]; % NOTE(review): labelled 'xy' but built from q; left unchanged because validation must feed the same inputs, confirm before changing 33 | else, error('func_reservoir_train:unsupportedInput', 'Unsupported input_infor combination.'); end % previously fell through and trained on an all-zero input 34 | 35 | train_y(:,:)=tau(1:train_length,:); 36 | train_x=train_x'; 37 | train_y=train_y'; 38 | 39 | r_all=zeros(n,train_length+1);%2*rand(n,1)-1;% 40 | for ti=1:train_length 41 | r_all(:,ti+1)=(1-alpha)*r_all(:,ti) + alpha*tanh(res_net*r_all(:,ti)+W_in*train_x(:,ti)+kb*ones(n,1)); 42 | end 43 | 44 | r_out=r_all(:,washup_length+2:end); % n * (train_length - washup_length) 45 | r_out(2:2:end,:)=r_out(2:2:end,:).^2; % square every second reservoir state before the readout 46 | r_end(:)=r_all(:,end); % n * 1 47 | 48 | r_train(:,:) = r_out; 49 | y_train(:,:) = train_y(1:dim_out,washup_length+1:end); 50 | % linear regression 51 | Wout=(y_train*r_train')/(r_train*r_train'+beta*eye(n)); % ridge regression; mrdivide solves the linear system instead of forming an explicit inverse 52 | 53 | end 54 | 55 | -------------------------------------------------------------------------------- /func_reservoir_validate.m: -------------------------------------------------------------------------------- 1 | function [control_infor, output_infor, time_infor, r_end] = func_reservoir_validate(traj_type, bridge_type, time_infor, ... 2 | input_infor, res_infor, start_info, properties, dim_in, dim_out, Wout, r_end, dt, plot_movie, save_rend, ... 3 | failure, blur, traj_frequency) 4 | %% read parameters 5 | % save_rend = 0 means, the robot arm start from all states to be zero, 6 | % while save_rend = 1 means, the robot arm continue to move.
7 | if save_rend==1 8 | q=start_info.q; 9 | qdt=start_info.qdt; 10 | q2dt=start_info.q2dt; 11 | tau=start_info.tau; 12 | end 13 | 14 | val_length=time_infor.val_length; 15 | 16 | [m1, m2, l1, l2, lc1, lc2, I1, I2] = matsplit(properties); 17 | 18 | W_in=res_infor.W_in; 19 | res_net=res_infor.res_net; 20 | alpha=res_infor.alpha; 21 | kb=res_infor.kb; 22 | n=res_infor.n; 23 | 24 | failure_type = failure.type; 25 | if strcmp(failure_type, 'none') == 0 26 | failure_amplitude = failure.amplitude; 27 | if strcmp(failure_type, 'all') == 1 28 | failure_amplitude_2 = failure.amplitude_2; 29 | end 30 | end 31 | 32 | %% read data 33 | val_pred_y=zeros(val_length, dim_out); 34 | val_real_y=zeros(val_length, dim_out); 35 | 36 | if save_rend == 1 37 | r=reshape(r_end', n, 1); 38 | else 39 | r=zeros(size(res_net, 1), 1); 40 | end 41 | u=zeros(dim_in,1); 42 | 43 | q_control=zeros(val_length+100, 2); 44 | qdt_control=zeros(val_length+100, 2); 45 | q2dt_control=zeros(val_length+100, 2); 46 | tau_control=zeros(val_length+100, 2); 47 | 48 | if save_rend == 1 49 | q_control(1, :) = q; 50 | qdt_control(1,:) = qdt; 51 | q2dt_control(1,:) = q2dt; 52 | tau_control(1,:) = tau; 53 | else 54 | if strcmp(traj_type, 'infty') == 1 55 | q_control(1, 1) = (3-1)*rand(1) + 1; 56 | q_control(1, 2) = (0.0 - 2.4) * rand(1); 57 | % q_control(1, 1) = (4-3)*rand(1) + 3; 58 | % q_control(1, 2) = (1.5 - 1.0) * rand(1) + 1.0; 59 | else 60 | q_control(1, 1) = (6-4)*rand(1) + 4; 61 | q_control(1, 2) = (0.0 - 2.4) * rand(1) - 0.1; 62 | end 63 | end 64 | 65 | % judge if there is any nan number: 66 | % dbstop if naninf 67 | q_pred=q_control; 68 | qdt_pred=qdt_control; 69 | q2dt_pred=q2dt_control; 70 | tau_pred=tau_control; 71 | 72 | control_infor = struct('q_control', q_control, 'qdt_control', qdt_control, ...
73 | 'q2dt_control', q2dt_control, 'tau_control', tau_control); 74 | 75 | %% generate desired trajectory 76 | 77 | [control_infor, time_infor] = func_desired_traj(traj_type, bridge_type, time_infor, control_infor, properties, dt, plot_movie, traj_frequency); 78 | 79 | q_control=control_infor.q_control; 80 | qdt_control=control_infor.qdt_control; 81 | 82 | val_length=time_infor.val_length; 83 | 84 | x_control=l1*cos(q_control(:,1))+l2*cos(q_control(:,1)+q_control(:,2)); 85 | y_control=l1*sin(q_control(:,1))+l2*sin(q_control(:,1)+q_control(:,2)); 86 | data_control=[x_control, y_control]; 87 | 88 | control_infor.data_control=data_control; 89 | 90 | 91 | %% validation 92 | if length(input_infor)==2 && strcmp(input_infor(1), 'xy') == 1 && strcmp(input_infor(2), 'qdt') == 1 93 | input_infor_label = 1; 94 | u(:)=[data_control(1,:)';data_control(2,:)';qdt_control(1,:)';qdt_control(2,:)']; 95 | end 96 | 97 | data_pred = data_control; 98 | % generate gaussian noise matrix for noise testing 99 | rng((now*1000-floor(now*1000))*100000) 100 | disturbance_failure = zeros(2, val_length); 101 | measurement_failure = zeros(round(dim_in/2), val_length); 102 | if strcmp(failure_type, 'disturbance') 103 | disturbance_failure = randn(2, val_length) * failure_amplitude; 104 | elseif strcmp(failure_type, 'measurement') 105 | measurement_failure = randn(round(dim_in/2), val_length) * failure_amplitude; 106 | elseif strcmp(failure_type, 'all') 107 | disturbance_failure = randn(2, val_length) * failure_amplitude; 108 | measurement_failure = randn(round(dim_in/2), val_length) * failure_amplitude_2; 109 | end 110 | % limit the predicted tau 111 | taudt_threshold = [-5e-2, 5e-2]; 112 | 113 | % In each step, according to the predicted value, the system evolves 114 | % according to its inherent rule.
115 | for t_i = 1:val_length-3 116 | r = (1-alpha)*r + alpha*tanh(res_net*r + W_in*u + +kb*ones(n,1)); 117 | r_out = r; 118 | r_out(2:2:end,1) = r_out(2:2:end,1).^2; %even number -> squared 119 | predict_value = Wout * r_out; 120 | 121 | disturbance_f_value = predict_value .* disturbance_failure(:, t_i); 122 | predict_value = predict_value + disturbance_f_value; 123 | if t_i == 1 124 | time_li = 1; 125 | else 126 | time_li = t_i - 1; 127 | end 128 | for li = 1:2 129 | if predict_value(li) - tau_pred(time_li, li) > taudt_threshold(2)*dt 130 | predict_value(li) = tau_pred(time_li, li) + taudt_threshold(2)*dt; 131 | end 132 | if predict_value(li) - tau_pred(time_li, li) < taudt_threshold(1)*dt 133 | predict_value(li) = tau_pred(time_li, li) + taudt_threshold(1)*dt; 134 | end 135 | end 136 | tau_pred(t_i, :) = predict_value; 137 | 138 | time_now=t_i; 139 | 140 | H11=m1*lc1^2+I1+m2*(l1^2+lc2^2+2*l1*lc2*cos(q_pred(time_now,2)))+I2; 141 | H12=m2*l1*lc2*cos(q_pred(time_now,2))+m2*lc2^2+I2; 142 | 143 | H21=H12; 144 | H22=m2*lc2^2+I2; 145 | h=m2*l1*lc2*sin(q_pred(time_now,2)); 146 | 147 | part_1=-h*qdt_pred(time_now,2)*qdt_pred(time_now,1)-h*(qdt_pred(time_now,1)+qdt_pred(time_now,2))*qdt_pred(time_now,2); 148 | part_2=h*qdt_pred(time_now,1)*qdt_pred(time_now,1); 149 | denominator=H12*H21-H11*H22; 150 | 151 | q2dt_pred(time_now,1)=-(-part_1*H22+H12*part_2-H12*predict_value(2)+H22*predict_value(1))/denominator; 152 | q2dt_pred(time_now,2)=-(part_1*H21-H11*part_2+H11*predict_value(2)-H21*predict_value(1))/denominator; 153 | 154 | q_pred(time_now+1,:)=q_pred(time_now,:)+qdt_pred(time_now,:)*dt; 155 | qdt_pred(time_now+1,:)=qdt_pred(time_now,:)+q2dt_pred(time_now,:)*dt; 156 | 157 | x_pred=l1*cos(q_pred(time_now+1,1))+l2*cos(q_pred(time_now+1,1)+q_pred(time_now+1,2)); 158 | y_pred=l1*sin(q_pred(time_now+1,1))+l2*sin(q_pred(time_now+1,1)+q_pred(time_now+1,2)); 159 | 160 | x_measurement_f_value = x_pred .* measurement_failure(1, t_i); 161 | y_measurement_f_value = y_pred .* 
measurement_failure(2, t_i); 162 | qdt_measurement_f_value = qdt_pred(time_now+1, :) .* measurement_failure(3:4, t_i)'; 163 | 164 | x_pred_measurement = x_pred + x_measurement_f_value; 165 | y_pred_measurement = y_pred + y_measurement_f_value; 166 | qdt_pred_measurement = qdt_pred(time_now+1, :) + qdt_measurement_f_value; 167 | 168 | data_pred(time_now+1,:)=[x_pred, y_pred]; 169 | 170 | if input_infor_label == 1 171 | u(1:2) = [x_pred_measurement;y_pred_measurement]; 172 | u(3:4) = data_control(time_now+2,:); 173 | u(5:6) = qdt_pred_measurement; 174 | u(7:8) = qdt_control(time_now+2,:); 175 | end 176 | end 177 | 178 | %% output 179 | 180 | output_infor.data_pred = data_pred; 181 | output_infor.q_pred = q_pred; 182 | output_infor.qdt_pred = qdt_pred; 183 | output_infor.q2dt_pred = q2dt_pred; 184 | output_infor.tau_pred = tau_pred; 185 | 186 | r_end = r; 187 | 188 | end 189 | 190 | -------------------------------------------------------------------------------- /func_train_val.m: -------------------------------------------------------------------------------- 1 | function [rmse_l, rmse_c, rmse_m, rmse_i, t] = func_train_val(n, train_t, reset_t, noise_level, bias) 2 | 3 | dt=0.01; 4 | input_infor={'xy', 'qdt'}; 5 | 6 | dim_in=length(input_infor) * 4; 7 | dim_out=2; 8 | 9 | m1=1;m2=1; 10 | l1=0.5;l2=0.5; 11 | lc1=0.25;lc2=0.25; 12 | I1=0.03;I2=0.03; 13 | properties=[m1, m2, l1, l2, lc1, lc2, I1, I2]; 14 | 15 | val_t=500; 16 | % prepare for the reservoir computing, the hyperparameters are given by the 17 | % Bayesian optimization. 
18 | hyperpara_set = [0.756250, 0.756250, 0.843750, -3.125, 106.71875, bias]; 19 | eig_rho = hyperpara_set(1); 20 | W_in_a = hyperpara_set(2); 21 | alpha = hyperpara_set(3); 22 | beta = 10^hyperpara_set(4); 23 | k = round( hyperpara_set(5)/200*n); 24 | kb = hyperpara_set(6); 25 | 26 | W_in = W_in_a*(2*rand(n,dim_in)-1); 27 | res_net=sprandsym(n,k/n); 28 | eig_D=eigs(res_net,1); 29 | res_net=(eig_rho/(abs(eig_D))).*res_net; 30 | res_net=full(res_net); 31 | 32 | section_len=round(reset_t/dt); 33 | washup_length=round(1002/dt); 34 | train_length=round(train_t/dt)+round(5/dt); % 50000 35 | val_length=round(val_t/dt); % 300 36 | time_length = train_length + 2 * val_length + 3 * washup_length + 100; 37 | 38 | res_infor=struct('W_in', W_in, 'res_net', res_net, 'alpha', alpha, 'kb', kb, 'beta', beta, 'n', n); 39 | time_infor=struct('section_len', section_len, 'washup_length', washup_length, ... 40 | 'train_length', train_length, 'val_length', val_length, 'time_length', time_length); 41 | 42 | % generate training and validation data 43 | [xy, q, qdt, q2dt, tau] = robot_data_generator(time_infor, noise_level, dt, properties); 44 | xy=xy(washup_length:end, :); 45 | q=q(washup_length:end, :); 46 | qdt=qdt(washup_length:end, :); 47 | q2dt=q2dt(washup_length:end, :); 48 | tau=tau(washup_length:end, :); 49 | data_reservoir = struct('xy', xy, 'q', q, 'qdt', qdt, 'q2dt', q2dt, 'tau', tau); 50 | 51 | clearvars q qdt q2dt tau 52 | 53 | tic; 54 | [Wout, r_end] = func_reservoir_train(data_reservoir, time_infor, input_infor, res_infor, dim_in, dim_out); 55 | t = toc; 56 | 57 | clearvars data_reservoir 58 | 59 | % after training, we test on four trajectories: lorenz, circle, mg17 and 60 | % eight symbol. 
61 | 62 | % lorenz 63 | disturbance = 0.00; 64 | measurement_noise = 0.00; 65 | plot_movie = 0; 66 | 67 | time_infor.val_length=150000; 68 | 69 | traj_type = 'lorenz'; 70 | bridge_type = 'cubic'; 71 | save_rend=0; 72 | idx=1; 73 | 74 | if exist('traj_frequency','var') == 0 75 | traj_frequency = 75; 76 | end 77 | 78 | blur.blur = 0; 79 | failure.type = 'none'; 80 | 81 | if save_rend == 0 82 | start_info.q=0; 83 | start_info.qdt=0; 84 | start_info.q2dt=0; 85 | start_info.tau=0; 86 | end 87 | 88 | [control_infor, output_infor, time_infor, r_end] = func_reservoir_validate(traj_type,... 89 | bridge_type, time_infor, input_infor, res_infor, start_info, properties, dim_in, dim_out, ... 90 | Wout, r_end, dt, plot_movie, save_rend,failure,blur, traj_frequency); 91 | 92 | % update 93 | data_pred=output_infor.data_pred; 94 | data_control=control_infor.data_control; 95 | 96 | rmse_start_time = round(time_infor.val_length * 3/5); 97 | rmse_end_time = time_infor.val_length - 100; 98 | 99 | rmse_l = func_rmse(data_pred, data_control, rmse_start_time, rmse_end_time); 100 | 101 | 102 | % circle 103 | disturbance = 0.00; 104 | measurement_noise = 0.00; 105 | plot_movie = 0; 106 | 107 | time_infor.val_length=150000; 108 | 109 | traj_type = 'circle'; 110 | bridge_type = 'cubic'; 111 | save_rend=0; 112 | idx=1; 113 | 114 | blur.blur = 0; 115 | failure.type = 'none'; 116 | 117 | if save_rend == 0 118 | start_info.q=0; 119 | start_info.qdt=0; 120 | start_info.q2dt=0; 121 | start_info.tau=0; 122 | end 123 | 124 | [control_infor, output_infor, time_infor, r_end] = func_reservoir_validate(traj_type,... 125 | bridge_type, time_infor, input_infor, res_infor, start_info, properties, dim_in, dim_out, ... 
126 | Wout, r_end, dt, plot_movie, save_rend,failure,blur, traj_frequency); 127 | 128 | % update 129 | data_pred=output_infor.data_pred; 130 | data_control=control_infor.data_control; 131 | 132 | rmse_start_time = round(time_infor.val_length * 3/5); 133 | rmse_end_time = time_infor.val_length - 100; 134 | 135 | rmse_c = func_rmse(data_pred, data_control, rmse_start_time, rmse_end_time); 136 | 137 | 138 | % mg17 139 | disturbance = 0.00; 140 | measurement_noise = 0.00; 141 | plot_movie = 0; 142 | 143 | time_infor.val_length=150000; 144 | 145 | traj_type = 'mg17'; 146 | bridge_type = 'cubic'; 147 | save_rend=0; 148 | idx=1; 149 | 150 | blur.blur = 0; 151 | failure.type = 'none'; 152 | 153 | if save_rend == 0 154 | start_info.q=0; 155 | start_info.qdt=0; 156 | start_info.q2dt=0; 157 | start_info.tau=0; 158 | end 159 | 160 | [control_infor, output_infor, time_infor, r_end] = func_reservoir_validate(traj_type,... 161 | bridge_type, time_infor, input_infor, res_infor, start_info, properties, dim_in, dim_out, ... 162 | Wout, r_end, dt, plot_movie, save_rend,failure,blur, traj_frequency); 163 | 164 | % update 165 | data_pred=output_infor.data_pred; 166 | data_control=control_infor.data_control; 167 | 168 | rmse_start_time = round(time_infor.val_length * 3/5); 169 | rmse_end_time = time_infor.val_length - 100; 170 | 171 | rmse_m = func_rmse(data_pred, data_control, rmse_start_time, rmse_end_time); 172 | 173 | % infty 174 | disturbance = 0.00; 175 | measurement_noise = 0.00; 176 | plot_movie = 0; 177 | 178 | time_infor.val_length=150000; 179 | 180 | traj_type = 'infty'; 181 | bridge_type = 'cubic'; 182 | save_rend=0; 183 | idx=1; 184 | 185 | blur.blur = 0; 186 | failure.type = 'none'; 187 | 188 | if save_rend == 0 189 | start_info.q=0; 190 | start_info.qdt=0; 191 | start_info.q2dt=0; 192 | start_info.tau=0; 193 | end 194 | 195 | [control_infor, output_infor, time_infor, r_end] = func_reservoir_validate(traj_type,... 
196 | bridge_type, time_infor, input_infor, res_infor, start_info, properties, dim_in, dim_out, ... 197 | Wout, r_end, dt, plot_movie, save_rend,failure,blur,traj_frequency); 198 | 199 | % update 200 | data_pred=output_infor.data_pred; 201 | data_control=control_infor.data_control; 202 | 203 | rmse_start_time = round(time_infor.val_length * 3/5); 204 | rmse_end_time = time_infor.val_length - 100; 205 | 206 | rmse_i = func_rmse(data_pred, data_control, rmse_start_time, rmse_end_time); 207 | 208 | 209 | end 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | -------------------------------------------------------------------------------- /main.m: -------------------------------------------------------------------------------- 1 | clear all 2 | close all 3 | clc 4 | 5 | addpath('./tools/') 6 | 7 | load('./save_file/all_traj_06282022.mat') 8 | 9 | % choose the reference trajectory 10 | % traj_type = 'lorenz'; 11 | traj_type = 'circle'; 12 | % traj_type = 'mg17'; 13 | % traj_type = 'infty'; 14 | % traj_type = 'fermat'; 15 | % traj_type = 'astroid'; 16 | % traj_type = 'heart'; 17 | % traj_type = 'epitrochoid'; 18 | % traj_type = 'lissajous'; 19 | % traj_type = 'talbot'; 20 | % traj_type = 'chua'; 21 | % traj_type = 'rossler'; 22 | % traj_type = 'sprott_1'; 23 | % traj_type = 'sprott_4'; 24 | % traj_type = 'mg30'; 25 | % traj_type = 'lorenz96'; 26 | 27 | % traj_frequency = 75; 28 | 29 | time_infor.val_length=200000; 30 | bridge_type = 'cubic'; 31 | failure.type = 'none'; 32 | disturbance = 0.00; 33 | measurement_noise = 0.00; 34 | plot_movie = 0; 35 | blur.blur = 0; 36 | 37 | save_rend = 0; 38 | idx = 0; 39 | 40 | plot_val_and_update = 1; 41 | 42 | % let the well-trained machine to follow the given reference. 
43 | val_and_update 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /rand_traj_control.m: -------------------------------------------------------------------------------- 1 | clear all 2 | close all 3 | clc 4 | 5 | addpath('./tools/') 6 | 7 | % load training data 8 | load('./save_file/all_traj_06282022.mat') 9 | 10 | %% 11 | 12 | traj_set = ["infty", "circle", "astroid", "fermat", ... 13 | "lissajous", "talbot", "heart", "lorenz", "chua", "rossler", ... 14 | "sprott_1", "sprott_4", "mg17", "mg30", "epitrochoid"]; 15 | 16 | order = randperm(length(traj_set)); 17 | traj_set= traj_set(order); 18 | 19 | plot_val_and_update = 1; 20 | disturbance = 0.1; 21 | measurement_noise = 0.1; 22 | plot_movie = 0; 23 | bridge_type = 'cubic'; 24 | failure.type = 'none'; 25 | blur.blur = 0; 26 | idx=1; 27 | 28 | val_length_all=150000; 29 | 30 | for ii = 1:length(traj_set) 31 | rng('shuffle') 32 | traj_type = traj_set(ii); 33 | time_infor.val_length=val_length_all; 34 | 35 | if ii == 1 36 | save_rend=0; 37 | else 38 | save_rend=1; 39 | end 40 | 41 | if exist('traj_frequency','var') == 1 42 | if strcmp(traj_type, 'lorenz') == 1 43 | traj_frequency = 100; 44 | elseif strcmp(traj_type, 'cirlce') == 1 45 | traj_frequency = 150; 46 | else 47 | traj_frequency = 75; 48 | end 49 | end 50 | 51 | val_and_update; 52 | 53 | idx = idx+1; 54 | if ii == 1 55 | aaa = 1; 56 | end 57 | end 58 | 59 | time_infor.val_length=val_length_all; 60 | 61 | time_today = datestr(now, 'mmddyyyy'); 62 | % save(['./save_data/15traj_', time_today, '_', num2str(randi(999)), '.mat'], "val_length_all", "save_all_traj", "traj_set") 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 
-------------------------------------------------------------------------------- /read_data/temp: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /robot_data_generator.m: -------------------------------------------------------------------------------- 1 | function [xy, q, qdt, q2dt, tau] = robot_data_generator(time_infor, noise_level, dt, properties) 2 | 3 | % generate time series for training and validation 4 | 5 | [m1, m2, l1, l2, lc1, lc2, I1, I2] = matsplit(properties); 6 | 7 | section_len = time_infor.section_len; 8 | time_length = time_infor.time_length; 9 | 10 | % generate noise as the control signal, note that we should smooth the data 11 | % to make the generated signal continuous. 12 | noise_interval=noise_level; % 3*10^(-2) 13 | pert_length=time_length*2; 14 | pert=-noise_interval+2*noise_interval*rand(pert_length,2); 15 | 16 | pert=pert(100:end, :); 17 | 18 | BB = smoothdata(pert, 'gaussian', 50); 19 | pert = BB; 20 | 21 | q=zeros(time_length,2); 22 | qdt=zeros(time_length,2); 23 | q2dt=zeros(time_length,2); 24 | tau=zeros(time_length,2); 25 | tau(:,:)=pert(1:time_length,:); 26 | 27 | rng('shuffle') 28 | 29 | % To avoid the values too large in random walk, every 'section_len' step we 30 | % will reset the states of the two link robot arm. 
31 | for t_i = 1:time_length-1 32 | if mod(t_i, section_len)==0 33 | q(t_i, 1) = 2 * pi * rand(1); 34 | q(t_i, 2) = 2 * pi * rand(1) - pi; 35 | qdt(t_i,:)=[0,0]; 36 | end 37 | 38 | H11=m1*lc1^2+I1+m2*(l1^2+lc2^2+2*l1*lc2*cos(q(t_i,2)))+I2; 39 | H12=m2*l1*lc2*cos(q(t_i,2))+m2*lc2^2+I2; 40 | H21=H12; 41 | H22=m2*lc2^2+I2; 42 | h=m2*l1*lc2*sin(q(t_i,2)); 43 | 44 | part_1=-h*qdt(t_i,2)*qdt(t_i,1)-h*(qdt(t_i,1)+qdt(t_i,2))*qdt(t_i,2); 45 | part_2=h*qdt(t_i,1)*qdt(t_i,1); 46 | denominator=H12*H21-H11*H22; 47 | 48 | q2dt(t_i,1)=-(-part_1*H22+H12*part_2-H12*tau(t_i,2)+H22*tau(t_i,1))/denominator; 49 | q2dt(t_i,2)=-(part_1*H21-H11*part_2+H11*tau(t_i,2)-H21*tau(t_i,1))/denominator; 50 | 51 | if mod(t_i, section_len)==0 52 | q2dt(t_i,:)=[0,0]; 53 | tau(t_i,:)=[0,0]; 54 | end 55 | 56 | q(t_i+1,:)=q(t_i,:)+qdt(t_i,:)*dt; 57 | qdt(t_i+1,:)=qdt(t_i,:)+q2dt(t_i,:)*dt; 58 | end 59 | 60 | x=l1*cos(q(:,1))+l2*cos(q(:,1)+q(:,2)); 61 | y=l1*sin(q(:,1))+l2*sin(q(:,1)+q(:,2)); 62 | xy=[x, y]; 63 | 64 | end 65 | 66 | -------------------------------------------------------------------------------- /save_file/all_traj_06282022.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zheng-Meng/Tracking-Control/a431a837e3cdf2890c5098d54bd68fa53e1e9878/save_file/all_traj_06282022.mat -------------------------------------------------------------------------------- /save_file/temp: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /tools/func_limits.m: -------------------------------------------------------------------------------- 1 | function [value] = func_limits(value, limit_type) 2 | %% limitations 3 | 4 | qdt_lim = [-0.05, 0.05]; 5 | q2dt_lim = [-0.3, 0.3]; 6 | 7 | % limit type: 8 | % 1 : dq/dt 9 | % 2 : d2q/dt2 10 | 11 | if limit_type == 1 12 | if value(1) > qdt_lim(2) 13 | value(1) = qdt_lim(2); 14 | end 15 
| if value(1) < qdt_lim(1) 16 | value(1) = qdt_lim(1); 17 | end 18 | if value(2) > qdt_lim(2) 19 | value(2) = qdt_lim(2); 20 | end 21 | if value(2) < qdt_lim(1) 22 | value(2) = qdt_lim(1); 23 | end 24 | elseif limit_type == 2 25 | if value(1) > q2dt_lim(2) 26 | value(1) = q2dt_lim(2); 27 | end 28 | if value(1) < q2dt_lim(1) 29 | value(1) = q2dt_lim(1); 30 | end 31 | if value(2) > q2dt_lim(2) 32 | value(2) = q2dt_lim(2); 33 | end 34 | if value(2) < q2dt_lim(1) 35 | value(2) = q2dt_lim(1); 36 | end 37 | end 38 | 39 | end 40 | 41 | -------------------------------------------------------------------------------- /tools/func_plot_figures.m: -------------------------------------------------------------------------------- 1 | function [] = func_plot_figures(control_infor, output_infor, val_length, plot_movie) 2 | 3 | data_pred=output_infor.data_pred; 4 | q_pred=output_infor.q_pred; 5 | qdt_pred=output_infor.qdt_pred; 6 | q2dt_pred=output_infor.q2dt_pred; 7 | tau_pred=output_infor.tau_pred; 8 | 9 | q_control=control_infor.q_control; 10 | qdt_control=control_infor.qdt_control; 11 | q2dt_control=control_infor.q2dt_control; 12 | tau_control=control_infor.tau_control; 13 | data_control=control_infor.data_control; 14 | 15 | start_time = 1; 16 | end_time = val_length - 10000; 17 | 18 | % plot trajectory 19 | figure(); 20 | hold on 21 | plot(data_control(start_time:end_time, 1), data_control(start_time:end_time, 2),'r'); 22 | plot(data_pred(start_time:end_time, 1), data_pred(start_time:end_time, 2),'b--'); 23 | xlabel('x') 24 | ylabel('y') 25 | line([0, 0], [-1, 1], 'Color', 'black', 'LineStyle', '--') 26 | line([-1, 1], [0, 0], 'Color', 'black', 'LineStyle', '--') 27 | xlim([-1, 1]) 28 | ylim([-1, 1]) 29 | legend('desired trajectory', 'pred trajectory') 30 | 31 | % plot q 32 | q_control_plot=mod(q_control, pi); 33 | q_pred_plot=mod(q_pred, pi); 34 | 35 | figure(); 36 | hold on 37 | plot(q_control_plot(start_time:end_time,1), 'r') 38 | plot(q_pred_plot(start_time:end_time,1), 
'b') 39 | xlabel('time step') 40 | ylabel('q(1)') 41 | legend('desired', 'pred') 42 | 43 | figure() 44 | hold on 45 | plot(q_control_plot(start_time:end_time,2), 'r') 46 | plot(q_pred_plot(start_time:end_time,2), 'b') 47 | xlabel('time step') 48 | ylabel('q(2)') 49 | legend('desired', 'pred') 50 | 51 | % plot dq/dt 52 | figure() 53 | hold on 54 | plot(qdt_control(start_time:end_time,1), 'r') 55 | plot(qdt_pred(start_time:end_time,1), 'b') 56 | xlabel('time step') 57 | ylabel('dq/dt(1)') 58 | legend('desired', 'pred') 59 | 60 | figure() 61 | hold on 62 | plot(qdt_control(start_time:end_time,2), 'r') 63 | plot(qdt_pred(start_time:end_time,2), 'b') 64 | xlabel('time step') 65 | ylabel('dq/dt(2)') 66 | legend('desired', 'pred') 67 | 68 | % plot d2q/dt2 69 | remove_transient = 1000; 70 | 71 | figure() 72 | hold on 73 | % plot(q2dt_control(start_time+remove_transient:end_time,1), 'r') 74 | plot(q2dt_pred(start_time+remove_transient:end_time,1), 'b') 75 | xlabel('time step') 76 | ylabel('d2q/dt2(1)') 77 | legend('pred') 78 | 79 | figure() 80 | hold on 81 | % plot(q2dt_control(start_time+remove_transient:end_time,2), 'r') 82 | plot(q2dt_pred(start_time+remove_transient:end_time,2), 'b') 83 | xlabel('time step') 84 | ylabel('d2q/dt2(2)') 85 | legend('pred') 86 | 87 | figure() 88 | hold on 89 | % plot(tau_control(start_time+remove_transient:end_time,1), 'r') 90 | plot(tau_pred(start_time+remove_transient:end_time,1), 'b') 91 | xlabel('time step') 92 | ylabel('tau(1)') 93 | legend('pred') 94 | 95 | figure() 96 | hold on 97 | % plot(tau_control(start_time+remove_transient:end_time,2), 'r') 98 | plot(tau_pred(start_time+remove_transient:end_time,2), 'b') 99 | xlabel('time step') 100 | ylabel('tau(2)') 101 | legend('pred') 102 | 103 | % figure() 104 | % hold on 105 | % plot(tau_control(start_time:end_time,1), 'r') 106 | % plot(tau_control(start_time:end_time,2), 'b') 107 | % xlabel('time step') 108 | % ylabel('tau') 109 | % legend('tau1', 'tau2') 110 | 111 | % plot movie for 
validation 112 | plot_movie_val = plot_movie; 113 | start_step=start_time; 114 | movie_step=500; 115 | time_all=end_time; 116 | line_property='dotted'; 117 | q1=q_pred(:, 1); 118 | q2=q_pred(:, 2); 119 | if plot_movie_val == 1 120 | func_plot_movie(start_step, movie_step, time_all, q1, q2, properties, line_property) 121 | end 122 | 123 | 124 | 125 | 126 | end -------------------------------------------------------------------------------- /tools/func_plot_movie.m: -------------------------------------------------------------------------------- 1 | function [] = func_plot_movie(start_step, movie_step, time_all, q1, q2, properties, line_prop) 2 | 3 | l1=properties(3); 4 | l2=properties(4); 5 | 6 | value_x_test=l1*cos(q1)+l2*cos(q1+q2); 7 | value_y_test=l1*sin(q1)+l2*sin(q1+q2); 8 | 9 | value_x1=l1*cos(q1); 10 | value_y1=l1*sin(q1); 11 | 12 | if strcmp(line_prop, 'solid') == 1 13 | line_property = 1; 14 | elseif strcmp(line_prop, 'dotted') == 1 15 | line_property = 2; 16 | else 17 | disp('error: please input the line property!') 18 | end 19 | 20 | filename='./results/double_arm.gif'; 21 | figure() 22 | for i=start_step:movie_step:time_all 23 | clf; 24 | hold on 25 | 26 | trace=1000; 27 | 28 | if i > trace * movie_step 29 | if line_property == 1 30 | plot(value_x_test(i-trace*movie_step:i), value_y_test(i-trace*movie_step:i), 'b', 'LineWidth', 1); 31 | elseif line_property == 2 32 | plot(value_x_test(i-trace*movie_step:i), value_y_test(i-trace*movie_step:i), 'b--', 'LineWidth', 1); 33 | end 34 | 35 | else 36 | if line_property == 1 37 | plot(value_x_test(1:i), value_y_test(1:i), 'b', 'LineWidth', 1); 38 | elseif line_property == 2 39 | plot(value_x_test(1:i), value_y_test(1:i), 'b--', 'LineWidth', 1); 40 | end 41 | end 42 | 43 | line([0,value_x1(i)], [0, value_y1(i)], 'LineWidth', 3); 44 | line([value_x1(i),value_x_test(i)], [value_y1(i),value_y_test(i)], 'LineWidth', 3) 45 | line([0, 0], [-1, 1], 'Color', 'black', 'LineStyle', '--') 46 | line([-1, 1], [0, 0], 'Color', 
'black', 'LineStyle', '--') 47 | xlabel('x') 48 | ylabel('y') 49 | xlim([-1,1]) 50 | ylim([-1,1]) 51 | title(['time = ' ,num2str(i*0.01),'s']) 52 | drawnow 53 | frame=getframe(gcf); 54 | im=frame2im(frame); 55 | [imind,cm]=rgb2ind(im,256); 56 | if i==start_step 57 | imwrite(imind,cm,filename,'gif', 'Loopcount',inf); 58 | else 59 | imwrite(imind,cm,filename,'gif','WriteMode','append'); 60 | end 61 | 62 | end 63 | 64 | end 65 | 66 | -------------------------------------------------------------------------------- /tools/func_rmse.m: -------------------------------------------------------------------------------- 1 | function [rmse] = func_rmse(a, b, time_start, time_end) 2 | [c,d] = size(a); 3 | len = max(c, d); 4 | 5 | if size(a, 2) ~= len 6 | a = a'; 7 | end 8 | 9 | if size(b, 2) ~= len 10 | b = b'; 11 | end 12 | 13 | rmse = sqrt(mean( sum( (a(:, time_start:time_end)-b(:, time_start:time_end)).^2, 1 ))); 14 | 15 | end -------------------------------------------------------------------------------- /val_and_update.m: -------------------------------------------------------------------------------- 1 | % validate 2 | if save_rend == 0 3 | start_info.q=0; 4 | start_info.qdt=0; 5 | start_info.q2dt=0; 6 | start_info.tau=0; 7 | 8 | r_end = zeros(size(res_infor.res_net, 1), 1); 9 | end 10 | 11 | if exist('traj_frequency','var') == 0 12 | if strcmp(traj_type, 'lorenz') == 1 13 | traj_frequency = 100; 14 | elseif strcmp(traj_type, 'cirlce') == 1 15 | traj_frequency = 150; 16 | else 17 | traj_frequency = 75; 18 | end 19 | end 20 | 21 | rng('shuffle') 22 | 23 | [control_infor, output_infor, time_infor, r_end] = func_reservoir_validate(traj_type,... 24 | bridge_type, time_infor, input_infor, res_infor, start_info, properties, dim_in, dim_out, ... 
25 | Wout, r_end, dt, plot_movie, save_rend,failure,blur,traj_frequency); 26 | 27 | % update 28 | data_pred=output_infor.data_pred; 29 | q_pred=output_infor.q_pred; 30 | qdt_pred=output_infor.qdt_pred; 31 | q2dt_pred=output_infor.q2dt_pred; 32 | tau_pred=output_infor.tau_pred; 33 | 34 | q_control=control_infor.q_control; 35 | qdt_control=control_infor.qdt_control; 36 | q2dt_control=control_infor.q2dt_control; 37 | tau_control=control_infor.tau_control; 38 | data_control=control_infor.data_control; 39 | 40 | val_length=time_infor.val_length; 41 | 42 | start_info.q=q_pred(val_length-3,:); 43 | start_info.qdt=qdt_pred(val_length-3, :); 44 | start_info.q2dt=q2dt_pred(val_length-3,:); 45 | start_info.tau=tau_pred(val_length-3,:); 46 | 47 | save_all_traj.(['control_', num2str(idx)]) = control_infor; 48 | save_all_traj.(['output_', num2str(idx)]) = output_infor; 49 | 50 | % plot trajectory 51 | 52 | start_time=1; 53 | end_time=val_length-100; 54 | 55 | if exist('plot_val_and_update','var') == 0 56 | plot_val_and_update = 0; 57 | end 58 | 59 | if plot_val_and_update==1 60 | figure(); 61 | hold on 62 | plot(data_control(start_time:end_time, 1), data_control(start_time:end_time, 2),'r'); 63 | plot(data_pred(start_time:end_time, 1), data_pred(start_time:end_time, 2),'b--'); 64 | xlabel('x') 65 | ylabel('y') 66 | line([0, 0], [-1, 1], 'Color', 'black', 'LineStyle', '--') 67 | line([-1, 1], [0, 0], 'Color', 'black', 'LineStyle', '--') 68 | xlim([-1, 1]) 69 | ylim([-1, 1]) 70 | legend('desired trajectory', 'pred trajectory') 71 | end 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | --------------------------------------------------------------------------------