├── LICENSE.txt
├── README.md
├── examples
│   ├── circle.png
│   ├── infty.png
│   ├── lorenz.png
│   ├── mg17.png
│   └── temp
├── experiments
│   ├── change_speed.m
│   ├── double_arm.m
│   ├── iter_noise.m
│   ├── normal_success_rate.m
│   ├── speed_test.m
│   ├── speed_test_main.m
│   ├── success_rate.m
│   ├── success_rate_1.m
│   ├── success_rate_bias.m
│   ├── temp
│   ├── uncertain_test_saferegion.m
│   └── uncertain_test_saferegion_main.m
├── func_desired_traj.m
├── func_double_arm.m
├── func_reservoir_train.m
├── func_reservoir_validate.m
├── func_train_val.m
├── main.m
├── rand_traj_control.m
├── read_data
│   └── temp
├── robot_data_generator.m
├── save_file
│   ├── all_traj_06282022.mat
│   └── temp
├── tools
│   ├── func_limits.m
│   ├── func_plot_figures.m
│   ├── func_plot_movie.m
│   └── func_rmse.m
└── val_and_update.m
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Zheng-Meng Zhai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Tracking Control
2 | Code accompanying a paper published in Nature Communications.
3 |
4 | [Model-free tracking control of complex dynamical trajectories with machine learning](https://www.nature.com/articles/s41467-023-41379-3) has been published in Nature Communications!
5 |
6 | # Requirements
7 | Please download the dataset from https://doi.org/10.5281/zenodo.8044994.
8 | Move the chaotic trajectories into the `read_data` folder; the periodic trajectories are generated by the code.
9 |
10 | Note that the code uses the MATLAB add-on 'matsplit', which has to be installed separately: click 'Home', choose 'Add-Ons', search for 'matsplit', and install it before running the code.
11 |
12 | # Example
13 | Run 'main.m' with traj_type = 'circle' to obtain the ground-truth and tracked trajectories shown in the figure below:
14 |
15 | 
16 |
17 | Change traj_type to others to track different trajectories, e.g., traj_type = 'lorenz'.
18 |
19 | # Citation
20 | This work is available at [https://www.nature.com/articles/s41467-023-41379-3](https://www.nature.com/articles/s41467-023-41379-3), and can be cited with the following BibTeX entry:
21 | ```
22 | @article{zhai2023model,
23 | title={Model-free tracking control of complex dynamical trajectories with machine learning},
24 | author={Zhai, Zheng-Meng and Moradi, Mohammadamin and Kong, Ling-Wei and Glaz, Bryan and Haile, Mulugeta and Lai, Ying-Cheng},
25 | journal={Nature Communications},
26 | volume={14},
27 | number={1},
28 | pages={5698},
29 | year={2023},
30 | publisher={Nature Publishing Group UK London}
31 | }
32 | ```
33 |
--------------------------------------------------------------------------------
/examples/circle.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng-Meng/Tracking-Control/a431a837e3cdf2890c5098d54bd68fa53e1e9878/examples/circle.png
--------------------------------------------------------------------------------
/examples/infty.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng-Meng/Tracking-Control/a431a837e3cdf2890c5098d54bd68fa53e1e9878/examples/infty.png
--------------------------------------------------------------------------------
/examples/lorenz.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng-Meng/Tracking-Control/a431a837e3cdf2890c5098d54bd68fa53e1e9878/examples/lorenz.png
--------------------------------------------------------------------------------
/examples/mg17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng-Meng/Tracking-Control/a431a837e3cdf2890c5098d54bd68fa53e1e9878/examples/mg17.png
--------------------------------------------------------------------------------
/examples/temp:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/experiments/change_speed.m:
--------------------------------------------------------------------------------
1 | clear all
2 | close all
3 | clc
4 |
5 | %% sweep the frequency of the infty-symbol trajectory and record the tracking RMSE
6 |
7 | addpath('./tools/')
8 |
9 | time_today = datestr(now, 'mmddyyyy');
10 |
11 | % traj_frequency_set = round(exp(linspace(log(1), log(500), 10)));
12 | traj_frequency_set = round(linspace(1, 400, 20));
13 |
14 | iteration = 50;          % independent repeats per frequency
15 | val_length_all = 300000; % validation length (time steps)
16 |
17 | disturbance = 0.00;
18 | measurement_noise = 0.00;
19 |
20 | plot_movie = 0;
21 | traj_type = 'infty';
22 | bridge_type = 'cubic';
23 | failure.type = 'none';
24 | blur.blur = 0;
25 |
26 | % RMSE window: from 3/5 of the run up to 100 steps before its end
27 | rmse_start_time = round(val_length_all * 3/5);
28 | rmse_end_time = val_length_all - 100;
29 | rmse_frequency_set = zeros(length(traj_frequency_set), iteration); % one row of sorted RMSEs per frequency
30 |
31 | for tfs = 1:length(traj_frequency_set)
32 |     traj_frequency = traj_frequency_set(tfs);
33 |
34 |     rmse_repeat_set = zeros(1, iteration);
35 |     for repeat_i = 1:iteration
36 |         load('./save_file/all_traj_06282022.mat') % re-load the saved model for every repeat
37 |         rng('shuffle')
38 |         save_rend=0;
39 |         idx=1;
40 |         time_infor.val_length=val_length_all; % override the val_length stored in the loaded time_infor
41 |
42 |         val_and_update; % script; expected to leave data_pred / data_control in the workspace
43 |
44 |         rmse_repeat_set(repeat_i) = func_rmse(data_pred, data_control, rmse_start_time, rmse_end_time);
45 |
46 |         bb=1;
47 |     end
48 |
49 |     rmse_repeat_set = sort(rmse_repeat_set);
50 |
51 |     rmse_frequency_set(tfs, :) = rmse_repeat_set;
52 |
53 |     % num2str is required: ['text', number] appends the character with that code, not the digits
54 |     disp(['traj frequency: ', num2str(traj_frequency)])
55 |
56 |     aa = 1;
57 | end
58 |
59 | % the repository ships save_file/ but not save_data/, so create it before saving
60 | if ~exist('save_data', 'dir'), mkdir('save_data'); end
61 | save_frequency_iter.traj_type = traj_type;
62 | save_frequency_iter.val_length = val_length_all;
63 | save_frequency_iter.traj_frequency_set = traj_frequency_set;
64 | save_frequency_iter.rmse_frequency_set = rmse_frequency_set;
65 |
66 | save(['save_data/save_frequency_iter_', time_today, '_' num2str(randi(999)) '.mat'], "save_frequency_iter")
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
--------------------------------------------------------------------------------
/experiments/double_arm.m:
--------------------------------------------------------------------------------
1 | %% clear all (asks for confirmation first, because this wipes the workspace)
2 | % strcmpi instead of element-wise m ~= 'Y': the latter rejects 'y' yet accepts e.g. 'YN'
3 | disp('Preparing for clearing all and restart...')
4 | m=input('Do you want to continue, Y/N [Y]:', 's');
5 | if ~isempty(m) && ~any(strcmpi(strtrim(m), {'Y', 'YES'})) % empty input = default [Y]
6 |     return
7 | end
8 |
9 | clear all
10 | close all
11 | clc
12 |
13 | addpath('./tools/')
14 |
15 | %% pre training
16 | dt=0.01; % time step
17 | % input_infor={'q'};
18 | input_infor={'xy', 'qdt'}; % signals fed to the reservoir (alternatives commented around)
19 | % input_infor={'q', 'qdt'};
20 | % input_infor={'xy', 'qdt', 'q2dt'};
21 |
22 | % in-out dimension
23 | dim_in=length(input_infor) * 4; % 4 input dimensions per entry of input_infor
24 | dim_out=2;
25 |
26 | % double robot arm properties (presumably masses m, link lengths l, centre-of-mass offsets lc, inertias I -- confirm)
27 | m1=1;m2=1;
28 | l1=0.5;l2=0.5;
29 | lc1=0.25;lc2=0.25;
30 | I1=0.03;I2=0.03;
31 | properties=[m1, m2, l1, l2, lc1, lc2, I1, I2]; % packed in this order for robot_data_generator
32 |
33 | reset_t=80; % in time units; converted to section_len below
34 | train_t = 200000; % training duration in time units
35 | val_t=500; % validation duration in time units
36 | noise_level=2.0*10^(-2); % passed to robot_data_generator
37 | % add perturbations and measurement noise to the model (and uncertainty),
38 | % proportional to the real value.
39 | disturbance = 0.00;
40 | measurement_noise = 0.00;
41 |
42 | n=200; % reservoir size
43 | hyperpara_set = [0.756250, 0.756250, 0.843750, -3.125, 106.71875, 2.0]; % [eig_rho, W_in_a, alpha, log10(beta), k per 200 nodes, kb]
44 | eig_rho = hyperpara_set(1); % target spectral radius of res_net
45 | W_in_a = hyperpara_set(2); % input-weight scale
46 | alpha = hyperpara_set(3);
47 | beta = 10^hyperpara_set(4); % hyperpara_set stores log10(beta)
48 | k = round( hyperpara_set(5)/200*n); % scaled with n relative to a 200-node reservoir
49 | kb = hyperpara_set(6);
50 |
51 | W_in = W_in_a*(2*rand(n,dim_in)-1); % uniform in [-W_in_a, W_in_a]
52 | res_net=sprandsym(n,k/n); % random sparse symmetric matrix with density k/n
53 | eig_D=eigs(res_net,1); % eigenvalue of largest magnitude
54 | res_net=(eig_rho/(abs(eig_D))).*res_net; % rescale so the spectral radius equals eig_rho
55 | res_net=full(res_net);
56 |
57 | section_len=round(reset_t/dt); % = 8000 steps for reset_t=80, dt=0.01
58 | washup_length=round(1002/dt); % = 100200 steps
59 | train_length=round(train_t/dt)+round(5/dt); % = 20000500 steps
60 | val_length=round(val_t/dt); % = 50000 steps
61 | time_length = train_length + 2 * val_length + 3 * washup_length + 100; % = 20401200 steps with the values above
62 |
63 | res_infor=struct('W_in', W_in, 'res_net', res_net, 'alpha', alpha, 'kb', kb, 'beta', beta, 'n', n);
64 | time_infor=struct('section_len', section_len, 'washup_length', washup_length, ...
65 | 'train_length', train_length, 'val_length', val_length, 'time_length', time_length);
66 |
67 | % generate training and validation data
68 | [xy, q, qdt, q2dt, tau] = robot_data_generator(time_infor, noise_level, dt, properties);
69 | xy=xy(washup_length:end, :); % discard the first washup_length-1 rows
70 | q=q(washup_length:end, :);
71 | qdt=qdt(washup_length:end, :);
72 | q2dt=q2dt(washup_length:end, :);
73 | tau=tau(washup_length:end, :);
74 | data_reservoir = struct('xy', xy, 'q', q, 'qdt', qdt, 'q2dt', q2dt, 'tau', tau);
75 |
76 | clearvars xy q qdt q2dt tau % copies now live in data_reservoir
77 |
78 | %% training
79 | tic;
80 | [Wout, r_end] = func_reservoir_train(data_reservoir, time_infor, input_infor, res_infor, dim_in, dim_out);
81 | toc; % prints the elapsed training time
82 | clear data_reservoir % the training data is not used after this point
83 | %% load data
84 | % NOTE(review): load replaces every workspace variable stored in the .mat file --
85 | % presumably including Wout/r_end/res_infor/time_infor, i.e. the model trained above
86 | load_data = 1;
87 | if isequal(load_data, 1), load('./save_file/all_traj_06282022.mat'); end
88 |
89 |
90 | %% validating
91 |
92 | rng('shuffle')
93 | % TODO: write the code that can update and pause
94 | % traj_type = 'infty';
95 | % traj_type = 'lorenz';
96 | % traj_type = 'mg17';
97 | % traj_type = 'mg30';
98 | % traj_type = 'circle'
99 |
100 | failure.type = 'none';
101 | blur.blur = 0;
102 |
103 | plot_val_and_update=1;
104 |
105 | disturbance = 0.05;       % replaces the 0.00 used in the pre-training section
106 | measurement_noise = 0.05;
107 | plot_movie = 0;
108 |
109 | % idx = 1
110 | traj_type = 'lorenz';
111 | bridge_type = 'cubic';
112 | time_infor.val_length=250000;
113 |
114 | save_rend=0;
115 | idx=1;
116 |
117 | val_and_update;
118 |
119 | % idx = 2
120 | traj_type = 'circle';
121 | bridge_type = 'cubic';
122 | traj_frequency = 150;
123 |
124 | time_infor.val_length=250000;
125 | save_rend=1;
126 | idx=2;
127 |
128 | val_and_update;
129 |
130 | % idx = 3
131 | traj_type = 'mg17';
132 | bridge_type = 'cubic';
133 |
134 | time_infor.val_length=250000;
135 | save_rend=1;
136 | idx=3;
137 |
138 | val_and_update;
139 |
140 | % idx = 4
141 | traj_type = 'infty';
142 | bridge_type = 'cubic';
143 | traj_frequency = 100;
144 |
145 | time_infor.val_length=250000;
146 | save_rend=1;
147 | idx=4;
148 |
149 | val_and_update;
150 |
151 | % RMSE over the first 1000 time units of the last run above (idx = 4)
152 | rmse_length = round(1000/dt); % integer index bound
153 | track_err = data_control(1:rmse_length, :) - data_pred(1:rmse_length, :); % not named "error": that would shadow error()
154 | rmse = sqrt(mean(mean(track_err.^2, 2)));
155 |
156 |
157 | %% plot figure
158 |
159 | % start_time = 1;
160 | % end_time = 290000;
161 | %
162 | % % plot trajectory
163 | % figure();
164 | % hold on
165 | % plot(data_control(start_time:end_time, 1), data_control(start_time:end_time, 2),'r');
166 | % plot(data_pred(start_time:end_time, 1), data_pred(start_time:end_time, 2),'b--');
167 | % xlabel('x')
168 | % ylabel('y')
169 | % line([0, 0], [-1, 1], 'Color', 'black', 'LineStyle', '--')
170 | % line([-1, 1], [0, 0], 'Color', 'black', 'LineStyle', '--')
171 | % xlim([-1, 1])
172 | % ylim([-1, 1])
173 | % legend('desired trajectory', 'pred trajectory')
174 | % %
175 | % % figure();
176 | % % hold on
177 | % % plot(data_control(100000:120000, 1), data_control(100000:120000, 2),'r');
178 | % % plot(data_pred(100000:120000, 1), data_pred(100000:120000, 2),'b--');
179 | % % xlabel('x')
180 | % % ylabel('y')
181 | % % line([0, 0], [-1, 1], 'Color', 'black', 'LineStyle', '--')
182 | % % line([-1, 1], [0, 0], 'Color', 'black', 'LineStyle', '--')
183 | % % xlim([-1, 1])
184 | % % ylim([-1, 1])
185 | % % legend('desired trajectory', 'pred trajectory')
186 | %
187 | % % plot q
188 | % q_control_plot=mod(q_control, pi);
189 | % q_pred_plot=mod(q_pred, pi);
190 | %
191 | % figure();
192 | % hold on
193 | % plot(q_control_plot(start_time:end_time,1), 'r')
194 | % plot(q_pred_plot(start_time:end_time,1), 'b')
195 | % xlabel('time step')
196 | % ylabel('q(1)')
197 | % legend('desired', 'pred')
198 | %
199 | % figure()
200 | % hold on
201 | % plot(q_control_plot(start_time:end_time,2), 'r')
202 | % plot(q_pred_plot(start_time:end_time,2), 'b')
203 | % xlabel('time step')
204 | % ylabel('q(2)')
205 | % legend('desired', 'pred')
206 | %
207 | % % plot dq/dt
208 | % figure()
209 | % hold on
210 | % plot(qdt_control(start_time:end_time,1), 'r')
211 | % plot(qdt_pred(start_time:end_time,1), 'b')
212 | % xlabel('time step')
213 | % ylabel('dq/dt(1)')
214 | % legend('desired', 'pred')
215 | %
216 | % figure()
217 | % hold on
218 | % plot(qdt_control(start_time:end_time,2), 'r')
219 | % plot(qdt_pred(start_time:end_time,2), 'b')
220 | % xlabel('time step')
221 | % ylabel('dq/dt(2)')
222 | % legend('desired', 'pred')
223 | %
224 | % % plot d2q/dt2
225 | % remove_transient = 1000;
226 | %
227 | % figure()
228 | % hold on
229 | % plot(q2dt_control(start_time+remove_transient:end_time,1), 'r')
230 | % plot(q2dt_pred(start_time+remove_transient:end_time,1), 'b')
231 | % xlabel('time step')
232 | % ylabel('d2q/dt2(1)')
233 | % legend('desired', 'pred')
234 | %
235 | % figure()
236 | % hold on
237 | % plot(q2dt_control(start_time+remove_transient:end_time,2), 'r')
238 | % plot(q2dt_pred(start_time+remove_transient:end_time,2), 'b')
239 | % xlabel('time step')
240 | % ylabel('d2q/dt2(2)')
241 | % legend('desired', 'pred')
242 | %
243 | % figure()
244 | % hold on
245 | % plot(tau_control(start_time+remove_transient:end_time,1), 'r')
246 | % plot(tau_pred(start_time+remove_transient:end_time,1), 'b')
247 | % xlabel('time step')
248 | % ylabel('tau(1)')
249 | % legend('desired', 'pred')
250 | %
251 | % figure()
252 | % hold on
253 | % plot(tau_control(start_time+remove_transient:end_time,2), 'r')
254 | % plot(tau_pred(start_time+remove_transient:end_time,2), 'b')
255 | % xlabel('time step')
256 | % ylabel('tau(2)')
257 | % legend('desired', 'pred')
258 | %
259 | % % figure()
260 | % % hold on
261 | % % plot(tau_control(start_time:end_time,1), 'r')
262 | % % plot(tau_control(start_time:end_time,2), 'b')
263 | % % xlabel('time step')
264 | % % ylabel('tau')
265 | % % legend('tau1', 'tau2')
266 | %
267 | % % plot movie for validation
268 | % plot_movie_val = 0;
269 | % start_step=start_time;
270 | % movie_step=500;
271 | % time_all=end_time;
272 | % line_property='dotted';
273 | % q1=q_pred(:, 1);
274 | % q2=q_pred(:, 2);
275 | % if plot_movie_val == 1
276 | % func_plot_movie(start_step, movie_step, time_all, q1, q2, properties, line_property)
277 | % end
278 |
279 |
280 | %% save data
281 |
282 | % save('./save_file/8d_input_04222022_lorenz_recover_from_failure.mat')
283 | % save('./save_file/8d_input_04222022_mg17.mat')
284 | % save('./save_file/reset_to_same_point_05182022_infty.mat')
285 | % save('./save_file/reset_to_random_point_05242022_infty.mat')
286 |
287 | % save('./save_file/reto_random_work_for_all_traj_05242022.mat')
288 | %
289 | % save('./save_data/4_trajectory_08182022.mat', "save_all_traj")
290 |
291 | % save('./save_file/all_traj_06282022.mat', 'time_infor', 'input_infor', 'res_infor', 'properties', 'dim_in', 'dim_out', 'Wout', 'r_end', 'dt', 'reset_t', 'noise_level')
292 | % save('./save_file/all_traj_06282022_SI.mat', 'data_reservoir')
293 |
294 |
295 | %% success test
296 |
297 | % aa = 1;
298 | %
299 | % plot_movie = 0;
300 | % traj_type = 'infty';
301 | % bridge_type = 'cubic';
302 | % plot_val_and_update=0;
303 | %
304 | % time_infor.val_length=150000;
305 | % rmse_start_time = round(time_infor.val_length * 2/5);
306 | % rmse_end_time = time_infor.val_length - 100;
307 | %
308 | % iteration = 50;
309 | % rmse_parfor_set = zeros(1, iteration);
310 | % for repeat_i = 1:iteration
311 | % load('./save_file/all_traj_06282022.mat')
312 | % val_and_update;
313 | % data_pred=output_infor.data_pred;
314 | % data_control=control_infor.data_control;
315 | % rmse_parfor_set(repeat_i) = func_rmse(data_pred, data_control, rmse_start_time, rmse_end_time);
316 | % end
317 | %
318 | % rmse_parfor_set = sort(rmse_parfor_set);
319 |
320 |
321 |
322 |
323 |
324 |
325 |
326 |
327 |
328 |
329 |
330 |
331 |
332 |
333 |
334 |
335 |
336 |
337 |
338 |
339 |
--------------------------------------------------------------------------------
/experiments/iter_noise.m:
--------------------------------------------------------------------------------
1 | % Noise-robustness sweeps; expects the trained model (dt, time_infor, input_infor,
2 | % res_infor, properties, dim_in, dim_out, Wout) to already be in the workspace.
3 | % (earlier manual settings: disturbance = 1.00; measurement_noise = 0.00;)
4 |
5 | % try for only the first one (disturbance) and then only the second one
6 | % (measurement noise), then try them together, then try uncertainty.
7 | % error metric: squared tracking error averaged with weights growing linearly in time
8 |
9 | time_today = datestr(now, 'mmddyyyy');
10 |
11 | %% only disturbance
12 | measurement_noise = 0.00;
13 |
14 | % disturbance_set = exp(linspace(log(0.01), log(10), 20));
15 | disturbance_set = linspace(1, 5, 20);
16 | error_set = zeros(1, length(disturbance_set));
17 |
18 | rmse_length = round(1000/dt);  % steps compared: the first 1000 time units
19 | weight = (1:rmse_length)';     % weight of step ii is ii ...
20 | % alpha_w=0.2;
21 | weight = weight/sum(weight);   % ... normalised so the weights sum to 1
22 |
23 |
24 |
25 |
26 | for di_id = 1:length(disturbance_set)
27 |     disturbance = disturbance_set(di_id);
28 |     plot_movie = 0;
29 |     traj_type = 'infty';
30 |     % traj_type = 'lorenz';
31 |     % traj_type = 'mg17';
32 |     % traj_type = 'mg30';
33 |     bridge_type = 'cubic';
34 |
35 |     time_infor.val_length=120000;
36 |
37 |     [control_infor, output_infor, time_infor] = func_reservoir_validate(traj_type,...
38 |         bridge_type, time_infor, input_infor, res_infor, properties, dim_in, dim_out, ...
39 |         Wout, dt, plot_movie, disturbance, measurement_noise);
40 |
41 |     data_pred=output_infor.data_pred;
42 |     q_pred=output_infor.q_pred;
43 |     qdt_pred=output_infor.qdt_pred;
44 |     q2dt_pred=output_infor.q2dt_pred;
45 |     tau_pred=output_infor.tau_pred;
46 |
47 |     q_control=control_infor.q_control;
48 |     qdt_control=control_infor.qdt_control;
49 |     q2dt_control=control_infor.q2dt_control;
50 |     tau_control=control_infor.tau_control;
51 |     % q_control_all=control_infor.q_control_all;
52 |     % qdt_control_all=control_infor.qdt_control_all;
53 |     data_control=control_infor.data_control;
54 |
55 |     val_length=time_infor.val_length;
56 |
57 |     % weighted MSE; named err_w so MATLAB's error() is not shadowed
58 |     err_w = data_control(1:rmse_length, :) - data_pred(1:rmse_length, :);
59 |     err_w = sum(weight.*mean(err_w.^2, 2));
60 |
61 |     error_set(di_id) = err_w;
62 | end
63 |
64 | nan_id = isnan(error_set);
65 | error_set_plot = error_set;
66 | error_set_plot(nan_id) = 5*max(error_set, [], 'omitnan'); % NaN runs drawn at 5x the largest non-NaN error
67 | if ~exist('./save_data', 'dir'), mkdir('./save_data'); end % folder is not shipped with the repo
68 | save_data.disturbance_set = disturbance_set;
69 | save_data.disturbance_error_set = error_set;
70 | save_data.disturbance_error_set_plot = error_set_plot;
71 |
72 | save(['./save_data/disturbance_result_normscale_', time_today, '.mat'], "save_data")
73 |
74 |
75 | %% for plot: error vs. disturbance strength
76 | figure();
77 | plot(disturbance_set, error_set_plot, 'o-') % linear axis; disturbance_set is built with linspace
78 | xlabel('gaussian disturbance, \sigma')
79 | ylabel('error')
80 | % alternative for log-spaced sets:
81 | % hold on
82 | % semilogx(disturbance_set, error_set_plot, 'o-')
83 |
84 |
85 | %% only measurement
86 | disturbance = 0.00;
87 | measurement_set = exp(linspace(log(0.01), log(10), 20)); % 20 log-spaced values in [0.01, 10]
88 | % measurement_set = linspace(1, 5, 20);
89 | error_set = zeros(1, length(measurement_set));
90 |
91 | rmse_length = round(1000/dt);  % steps compared: the first 1000 time units
92 | weight = (1:rmse_length)';     % weight of step ii is ii ...
93 | weight = weight/sum(weight);   % ... normalised so the weights sum to 1
94 |
95 |
96 |
97 |
98 | for mn_id = 1:length(measurement_set)
99 |     measurement_noise = measurement_set(mn_id);
100 |     plot_movie = 0;
101 |     traj_type = 'infty';
102 |     % traj_type = 'lorenz';
103 |     % traj_type = 'mg17';
104 |     % traj_type = 'mg30';
105 |     bridge_type = 'cubic';
106 |
107 |     time_infor.val_length=120000;
108 |
109 |     [control_infor, output_infor, time_infor] = func_reservoir_validate(traj_type,...
110 |         bridge_type, time_infor, input_infor, res_infor, properties, dim_in, dim_out, ...
111 |         Wout, dt, plot_movie, disturbance, measurement_noise);
112 |
113 |     data_pred=output_infor.data_pred;
114 |     q_pred=output_infor.q_pred;
115 |     qdt_pred=output_infor.qdt_pred;
116 |     q2dt_pred=output_infor.q2dt_pred;
117 |     tau_pred=output_infor.tau_pred;
118 |
119 |     q_control=control_infor.q_control;
120 |     qdt_control=control_infor.qdt_control;
121 |     q2dt_control=control_infor.q2dt_control;
122 |     tau_control=control_infor.tau_control;
123 |     % q_control_all=control_infor.q_control_all;
124 |     % qdt_control_all=control_infor.qdt_control_all;
125 |     data_control=control_infor.data_control;
126 |
127 |     val_length=time_infor.val_length;
128 |
129 |     % weighted MSE; named err_w so MATLAB's error() is not shadowed
130 |     err_w = data_control(1:rmse_length, :) - data_pred(1:rmse_length, :);
131 |     err_w = sum(weight.*mean(err_w.^2, 2));
132 |
133 |     error_set(mn_id) = err_w;
134 | end
135 |
136 | nan_id = isnan(error_set);
137 | error_set_plot = error_set;
138 | error_set_plot(nan_id) = 5*max(error_set, [], 'omitnan'); % NaN runs drawn at 5x the largest non-NaN error
139 | if ~exist('./save_data', 'dir'), mkdir('./save_data'); end % folder is not shipped with the repo
140 | save_data.measurement_set = measurement_set;
141 | save_data.measurement_error_set = error_set;
142 | save_data.measurement_error_set_plot = error_set_plot;
143 |
144 | save(['./save_data/measurement_result_logscale_', time_today, '.mat'], "save_data")
145 |
146 | %% for plot: error vs. measurement noise, using the data held in save_data
147 | % (to plot an earlier run instead, first run
148 | %  load('./save_data/measurement_result_logscale_<mmddyyyy>.mat'))
149 |
150 | error_set_plot = save_data.measurement_error_set_plot; % NaN entries already replaced by 5x max
151 |
152 | figure();
153 | % hold on
154 | semilogx(save_data.measurement_set, error_set_plot, 'o-') % x and y from the same run
155 | % plot(save_data.measurement_set, error_set_plot, 'o-')
156 | xlabel('gaussian measurement noise, \sigma')
157 | ylabel('error')
158 |
159 |
160 |
161 | %% disturbance and measurement noise
162 |
163 | disturbance_set = exp(linspace(log(0.01), log(10), 20));
164 | measurement_set = exp(linspace(log(0.01), log(10), 20));
165 |
166 | error_set = zeros(length(disturbance_set), length(measurement_set)); % rows: disturbance, cols: measurement noise
167 |
168 | rmse_length = round(1000/dt);  % steps compared: the first 1000 time units
169 | weight = (1:rmse_length)';     % weight of step ii is ii ...
170 | weight = weight/sum(weight);   % ... normalised so the weights sum to 1
171 |
172 |
173 |
174 |
175 | for di_id = 1:length(disturbance_set)
176 |     for mn_id = 1:length(measurement_set)
177 |         disturbance = disturbance_set(di_id);
178 |         measurement_noise = measurement_set(mn_id);
179 |         plot_movie = 0;
180 |         traj_type = 'infty';
181 |         % traj_type = 'lorenz';
182 |         % traj_type = 'mg17';
183 |         % traj_type = 'mg30';
184 |         bridge_type = 'cubic';
185 |
186 |         time_infor.val_length=120000;
187 |
188 |         [control_infor, output_infor, time_infor] = func_reservoir_validate(traj_type,...
189 |             bridge_type, time_infor, input_infor, res_infor, properties, dim_in, dim_out, ...
190 |             Wout, dt, plot_movie, disturbance, measurement_noise);
191 |
192 |         data_pred=output_infor.data_pred;
193 |         q_pred=output_infor.q_pred;
194 |         qdt_pred=output_infor.qdt_pred;
195 |         q2dt_pred=output_infor.q2dt_pred;
196 |         tau_pred=output_infor.tau_pred;
197 |
198 |         q_control=control_infor.q_control;
199 |         qdt_control=control_infor.qdt_control;
200 |         q2dt_control=control_infor.q2dt_control;
201 |         tau_control=control_infor.tau_control;
202 |         % q_control_all=control_infor.q_control_all;
203 |         % qdt_control_all=control_infor.qdt_control_all;
204 |         data_control=control_infor.data_control;
205 |
206 |         val_length=time_infor.val_length;
207 |
208 |         % weighted MSE; named err_w so MATLAB's error() is not shadowed
209 |         err_w = data_control(1:rmse_length, :) - data_pred(1:rmse_length, :);
210 |         err_w = sum(weight.*mean(err_w.^2, 2));
211 |
212 |         error_set(di_id, mn_id) = err_w;
213 |     end
214 | end
215 |
216 | nan_id = isnan(error_set);
217 | error_set_plot = error_set;
218 | error_set_plot(nan_id) = 3*max(max(error_set)); % max skips NaN by default
219 | if ~exist('./save_data', 'dir'), mkdir('./save_data'); end % folder is not shipped with the repo
220 | save_data.disturbance_set_heat = disturbance_set;
221 | save_data.measurement_set_heat = measurement_set;
222 | save_data.heat_error_set = error_set;
223 | save_data.heat_error_set_plot = error_set_plot;
224 |
225 | save(['./save_data/heat_result_logscale_', time_today, '.mat'], "save_data")
226 |
227 | %% for plot
228 | % error_set_plot is (disturbance x measurement); surf(X, Y, Z) puts Z rows along Y, columns along X
229 | figure();
230 | surf(measurement_set, disturbance_set, error_set_plot);
231 | xlabel('measurement noise, \sigma')
232 | ylabel('disturbance noise, \sigma')
233 | set(gca, 'YScale', 'log');
234 | set(gca, 'XScale', 'log');
235 | colorbar
236 | view(0, 90)
237 | title('error')
238 |
239 |
240 |
241 |
242 |
243 |
244 |
245 |
246 |
247 |
248 |
249 |
250 |
251 |
252 |
253 |
254 |
255 |
256 |
257 |
--------------------------------------------------------------------------------
/experiments/normal_success_rate.m:
--------------------------------------------------------------------------------
1 | clear all
2 | close all
3 | clc
4 |
5 | addpath('./tools/')
6 |
7 | % load training data
8 | load('./save_file/all_traj_06282022.mat')
9 | time_today = datestr(now, 'mmddyyyy');
10 |
11 | plot_movie = 0;
12 | if exist('traj_type','var') == 0
13 |     traj_type = 'infty';
14 | end
15 |
16 | traj_set = ["lorenz", "circle", "mg17", "infty", "astroid", "fermat", ...
17 |     "lissajous", "talbot", "heart", "chua", "rossler", ...
18 |     "sprott_1", "sprott_4", "mg30", "epitrochoid"];
19 | % traj_set = ["lorenz"];
20 | iteration = 100; % repeats per trajectory
21 |
22 | bridge_type = 'cubic';
23 | time_infor.val_length=250000;
24 | val_length_rm = time_infor.val_length; % remembered: the load inside the loop replaces time_infor
25 |
26 | failure.type = 'none';
27 | blur.last_time = 1000;
28 | blur.recover_time = time_infor.val_length;
29 | blur.blur = 0;
30 |
31 | failure.type = 'all'; % NOTE(review): overrides the 'none' set above -- confirm which is intended
32 | failure.amplitude = 0.1;
33 | failure.amplitude_2 = 0.1;
34 |
35 | rmse_start_time = round(val_length_rm * 3/5);
36 | rmse_end_time = time_infor.val_length - 100;
37 |
38 | rmse_set = zeros(length(traj_set), iteration); % sorted RMSEs, one row per trajectory
39 |
40 | idx=1;
41 | for traj_id = 1:length(traj_set)
42 |     % char(): traj_set holds strings, and ['...', <string>] builds a string array instead of concatenating
43 |     traj_type = char(traj_set(traj_id));
44 |     if strcmp(traj_type, 'circle')
45 |         traj_frequency = 150;
46 |     elseif strcmp(traj_type, 'infty')
47 |         traj_frequency = 75;
48 |     end
49 |     rmse_parfor_set = zeros(1, iteration);
50 |     for repeat_i = 1:iteration
51 |         load('./save_file/all_traj_06282022.mat')
52 |         time_infor.val_length = val_length_rm;
53 |         save_rend=0;
54 |         val_and_update;
55 |         rmse_parfor_set(repeat_i) = func_rmse(data_pred, data_control, rmse_start_time, rmse_end_time);
56 |         aaa = 1;
57 |     end
58 |     rmse_parfor_set = sort(rmse_parfor_set);
59 |     rmse_set(traj_id, :) = rmse_parfor_set;
60 | end
61 |
62 | save_success.rmse_set = rmse_set;
63 | save_success.val_length = time_infor.val_length;
64 | save_success.traj_set = traj_set;
65 |
66 | % save(['./save_data/save_saferegion_success_rate_noise_' time_today, '_' num2str(randi(999)) '.mat'], "save_success")
67 |
68 | %% success rate per trajectory
69 | % load('./save_data/save_normal_success_rate_08022022_92.mat')
70 |
71 | % rmse_set = save_success.rmse_set;
72 |
73 | rmse_threshold = 0.18;
74 | % A run succeeds when its RMSE is at or below the threshold. The test is written as
75 | % "<=" so that a NaN RMSE counts as a failure; "if rmse > threshold -> 0, else -> 1"
76 | % would send NaN to the else branch and count it as a success.
77 | rmse_logic = double(rmse_set <= rmse_threshold);
78 |
79 | rmse_count = mean(rmse_logic, 2); % success rate, one entry per row of rmse_set
80 |
81 | figure()
82 | plot(rmse_count, 'o')
83 | xlabel('trajectory index')
84 | ylabel('success rate')
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
--------------------------------------------------------------------------------
/experiments/speed_test.m:
--------------------------------------------------------------------------------
1 | % clear all
2 | % close all
3 | % clc
4 | % (no clearing: the caller, e.g. speed_test_main.m, may have set traj_type)
5 | addpath('./tools/')
6 |
7 | %%
8 |
9 | load('./save_file/all_traj_06282022.mat')
10 |
11 | time_today = datestr(now, 'mmddyyyy');
12 |
13 | plot_val_and_update = 0;
14 |
15 | blur.blur = 0;
16 | plot_movie = 0;
17 |
18 | % default only when unset -- never overwrite the traj_type chosen by the caller
19 | if exist('traj_type','var') == 0
20 |     traj_type = 'infty';
21 | end
22 |
23 |
24 |
25 | failure.type = 'all';
26 | failure.amplitude = 0.1;
27 | failure.amplitude_2 = 0.1;
28 |
29 | bridge_type = 'cubic';
30 |
31 | time_infor.val_length=250000;
32 | val_length_rm = time_infor.val_length; % remembered: each load below replaces time_infor
33 |
34 | rmse_start_time = round(val_length_rm * 3/5);
35 | rmse_end_time = time_infor.val_length - 100;
36 |
37 | if strcmp(traj_type, 'circle')
38 |     frequency_set = round(linspace(10, 500, 15));
39 | elseif strcmp(traj_type, 'infty')
40 |     frequency_set = round(linspace(10, 500, 15) / 2);
41 | else
42 |     error('speed_test: no frequency set defined for traj_type ''%s''', traj_type);
43 | end
44 | iteration = 50; % repeats per frequency
45 |
46 | failure.type = 'none'; % replaces the 'all' set above
47 | rmse_set = zeros(length(frequency_set), iteration); % sorted RMSEs, one row per frequency
48 |
49 | idx=1;
50 |
51 | for f_idx = 1:length(frequency_set)
52 |     rmse_parfor_set = zeros(1, iteration);
53 |     for repeat_i = 1:iteration
54 |         save_rend=0;
55 |         load('./save_file/all_traj_06282022.mat')
56 |         time_infor.val_length = val_length_rm;
57 |         traj_frequency = frequency_set(f_idx);
58 |
59 |         val_and_update;
60 |
61 |         rmse_parfor_set(repeat_i) = func_rmse(data_pred, data_control, rmse_start_time, rmse_end_time);
62 |         aaa = 1;
63 |     end
64 |     rmse_parfor_set = sort(rmse_parfor_set);
65 |     rmse_set(f_idx, :) = rmse_parfor_set;
66 | end
67 |
68 | if ~exist('save_data', 'dir'), mkdir('save_data'); end % folder is not shipped with the repo
69 | save_speed_iter.traj_type = traj_type;
70 |
71 | save_speed_iter.(['frequency_set', num2str(idx)]) = frequency_set;
72 | save_speed_iter.(['rmse_set', num2str(idx)]) = rmse_set;
73 |
74 | save(['save_data/save_speed_iter_' traj_type, '_' time_today, '_' num2str(randi(999)) '.mat'], "save_speed_iter")
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
--------------------------------------------------------------------------------
/experiments/speed_test_main.m:
--------------------------------------------------------------------------------
1 | clear all
2 | close all
3 | clc
4 | time_today = datestr(now, 'mmddyyyy'); % date stamp for this session
5 |
6 | % run the speed_test script with traj_type set to each periodic trajectory in turn
7 | traj_type = 'circle';  speed_test
8 | traj_type = 'infty';   speed_test
9 |
--------------------------------------------------------------------------------
/experiments/success_rate.m:
--------------------------------------------------------------------------------
1 | clear all
2 | close all
3 | clc
4 |
5 | addpath('./tools/')
6 |
7 | %% test for network size, training length,
8 | % reset time and noise level (only network size and training length are swept below)
9 | % please record the rmse and running time
10 |
11 | % network_set = round(linspace(20, 500, 10));
12 | % training_length_set = round(linspace(20000, 200000, 10));
13 |
14 | network_set = round(exp(linspace(log(20), log(250), 10)));              % 10 log-spaced reservoir sizes n
15 | training_length_set = round(exp(linspace(log(2000), log(150000), 10))); % 10 log-spaced train_t values
16 |
17 | % network_set = fliplr(network_set);
18 | % training_length_set = fliplr(training_length_set);
19 |
20 | reset_t_set = round(linspace(10, 150, 10));  % not used below
21 | noise_level_set = 10 .^ linspace(-3, 0, 10); % not used below
22 |
23 | time_today = datestr(now, 'mmddyyyy');
24 |
25 |
26 | %% first to test the n_set and train_len_set
27 |
28 | reset_t = 80;
29 | noise_level = 2.0 * 10 ^ (-2);
30 |
31 | bias = 2.0;
32 |
33 | % delete(gcp('nocreate'))
34 | % parpool('local',10)
35 |
36 | iteration = 50;
37 | % results indexed as (network size, training length, repeat)
38 | rmse_set_lorenz = zeros(length(network_set), length(training_length_set), iteration);
39 | rmse_set_circle = zeros(length(network_set), length(training_length_set), iteration);
40 | rmse_set_mg17 = zeros(length(network_set), length(training_length_set), iteration);
41 | rmse_set_infty = zeros(length(network_set), length(training_length_set), iteration);
42 | time_set = zeros(length(network_set), length(training_length_set), iteration);
43 |
44 | for ns = 1:length(network_set)
45 |     n = network_set(ns);
46 |     for tls = 1:length(training_length_set)
47 |         train_t = training_length_set(tls);
48 |
49 |         rmse_parfor_set_lorenz = zeros(1, iteration);
50 |         rmse_parfor_set_circle = zeros(1, iteration);
51 |         rmse_parfor_set_mg17 = zeros(1, iteration);
52 |         rmse_parfor_set_infty = zeros(1, iteration);
53 |
54 |         time_parfor_set = zeros(1, iteration);
55 |         for repeat_i = 1:iteration % per-repeat buffers presumably keep this loop parfor-compatible
56 |             [rmse_l, rmse_c, rmse_m, rmse_i, t_repeat_i] = func_train_val(n, train_t, reset_t, noise_level, bias);
57 |             rmse_parfor_set_lorenz(repeat_i) = rmse_l;
58 |             rmse_parfor_set_circle(repeat_i) = rmse_c;
59 |             rmse_parfor_set_mg17(repeat_i) = rmse_m;
60 |             rmse_parfor_set_infty(repeat_i) = rmse_i;
61 |
62 |             time_parfor_set(repeat_i) = t_repeat_i;
63 |         end
64 |
65 |         rmse_set_lorenz(ns, tls, :) = rmse_parfor_set_lorenz;
66 |         rmse_set_circle(ns, tls, :) = rmse_parfor_set_circle;
67 |         rmse_set_mg17(ns, tls, :) = rmse_parfor_set_mg17;
68 |         rmse_set_infty(ns, tls, :) = rmse_parfor_set_infty;
69 |
70 |         time_set(ns, tls, :) = time_parfor_set;
71 |     end
72 | end
73 | save_success_rate.bias = bias; % also record the bias passed to func_train_val
74 | save_success_rate.reset_t = reset_t;
75 | save_success_rate.noise_level = noise_level;
76 | save_success_rate.network_set = network_set;
77 | save_success_rate.training_length_set = training_length_set;
78 | save_success_rate.rmse_set_lorenz = rmse_set_lorenz;
79 | save_success_rate.rmse_set_circle = rmse_set_circle;
80 | save_success_rate.rmse_set_mg17 = rmse_set_mg17;
81 | save_success_rate.rmse_set_infty = rmse_set_infty;
82 | save_success_rate.time_set = time_set;
83 | if ~exist('save_data', 'dir'), mkdir('save_data'); end % folder is not shipped with the repo
84 | save(['save_data/save_success_rate_nt_', time_today, '_' num2str(randi(999)) '.mat'], 'save_success_rate')
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
--------------------------------------------------------------------------------
/experiments/success_rate_1.m:
--------------------------------------------------------------------------------
1 | clear all
2 | close all
3 | clc
4 |
5 | addpath('./tools/')
6 |
7 | %% Sweep: reset time x noise level
8 | % At a fixed network size n and training length train_t, func_train_val is
9 | % run `iteration` times for every (reset_t, noise_level) pair; its four RMSE
10 | % outputs (lorenz, circle, mg17, infty) and its timing output are recorded
11 | % and saved to save_data/.
12 |
13 | % network_set = round(exp(linspace(log(20), log(270), 10)));
14 | % training_length_set = round(exp(linspace(log(2000), log(150000), 10)));
15 |
16 | reset_t_set = round(linspace(10, 150, 10));
17 | noise_level_set = 10 .^ linspace(-3, 0, 10);   % log-spaced, 1e-3 .. 1
18 |
19 | time_today = datestr(now, 'mmddyyyy');
20 |
21 | % save_data/ is not part of the repository; create it before the sweep so a
22 | % long run does not fail at the final save().
23 | if ~exist('save_data', 'dir')
24 |     mkdir('save_data');
25 | end
26 |
27 | %% Parameters held fixed in this sweep
28 | n = 100;
29 | train_t = 10000;
30 | bias = 2.0;
31 |
32 | % delete(gcp('nocreate'))
33 | % parpool('local',6)
34 |
35 | iteration = 50;   % repeats per grid point
36 |
37 | % Results: (reset time index, noise level index, repeat)
38 | rmse_set_lorenz = zeros(length(reset_t_set), length(noise_level_set), iteration);
39 | rmse_set_circle = zeros(length(reset_t_set), length(noise_level_set), iteration);
40 | rmse_set_mg17 = zeros(length(reset_t_set), length(noise_level_set), iteration);
41 | rmse_set_infty = zeros(length(reset_t_set), length(noise_level_set), iteration);
42 | time_set = zeros(length(reset_t_set), length(noise_level_set), iteration);
43 |
44 | for rts = 1:length(reset_t_set)
45 |     reset_t = reset_t_set(rts);
46 |     for nls = 1:length(noise_level_set)
47 |         noise_level = noise_level_set(nls);
48 |
49 |         rmse_parfor_set_lorenz = zeros(1, iteration);
50 |         rmse_parfor_set_circle = zeros(1, iteration);
51 |         rmse_parfor_set_mg17 = zeros(1, iteration);
52 |         rmse_parfor_set_infty = zeros(1, iteration);
53 |         time_parfor_set = zeros(1, iteration);
54 |
55 |         for repeat_i = 1:iteration
56 |             [rmse_l, rmse_c, rmse_m, rmse_i, t_repeat_i] = func_train_val(n, train_t, reset_t, noise_level, bias);
57 |             rmse_parfor_set_lorenz(repeat_i) = rmse_l;
58 |             rmse_parfor_set_circle(repeat_i) = rmse_c;
59 |             rmse_parfor_set_mg17(repeat_i) = rmse_m;
60 |             rmse_parfor_set_infty(repeat_i) = rmse_i;
61 |             time_parfor_set(repeat_i) = t_repeat_i;
62 |         end
63 |
64 |         rmse_set_lorenz(rts, nls, :) = rmse_parfor_set_lorenz;
65 |         rmse_set_circle(rts, nls, :) = rmse_parfor_set_circle;
66 |         rmse_set_mg17(rts, nls, :) = rmse_parfor_set_mg17;
67 |         rmse_set_infty(rts, nls, :) = rmse_parfor_set_infty;
68 |         time_set(rts, nls, :) = time_parfor_set;
69 |     end
70 | end
71 |
72 | save_success_rate.reset_t_set = reset_t_set;
73 | save_success_rate.noise_level_set = noise_level_set;
74 | save_success_rate.n = n;
75 | save_success_rate.train_t = train_t;
76 | save_success_rate.bias = bias;
77 | save_success_rate.iteration = iteration;
78 | save_success_rate.rmse_set_lorenz = rmse_set_lorenz;
79 | save_success_rate.rmse_set_circle = rmse_set_circle;
80 | save_success_rate.rmse_set_mg17 = rmse_set_mg17;
81 | save_success_rate.rmse_set_infty = rmse_set_infty;
82 | save_success_rate.time_set = time_set;
83 |
84 | save(['save_data/save_success_rate_rn_', time_today, '_' num2str(randi(999)) '.mat'], 'save_success_rate')
85 |
--------------------------------------------------------------------------------
/experiments/success_rate_bias.m:
--------------------------------------------------------------------------------
1 | clear all
2 | close all
3 | clc
4 |
5 | addpath('./tools/')
6 |
7 | %% Sweep: reservoir bias
8 | % For each bias value, func_train_val is run `iteration` times at fixed
9 | % n, train_t, reset_t and noise_level; its four RMSE outputs (lorenz,
10 | % circle, mg17, infty) and its timing output are recorded and saved to
11 | % save_data/.
12 |
13 | bias_set = linspace(0, 3, 7);   % 0, 0.5, ..., 3
14 |
15 | time_today = datestr(now, 'mmddyyyy');
16 |
17 | % save_data/ is not part of the repository; create it before the sweep so a
18 | % long run does not fail at the final save().
19 | if ~exist('save_data', 'dir')
20 |     mkdir('save_data');
21 | end
22 |
23 | %% Parameters held fixed in this sweep
24 | reset_t = 80;
25 | noise_level = 2.0 * 10 ^ (-2);
26 | n = 200;
27 | train_t = 150000;
28 |
29 | iteration = 10;   % repeats per bias value
30 |
31 | % Results: (bias index, repeat)
32 | rmse_set_lorenz = zeros(length(bias_set), iteration);
33 | rmse_set_circle = zeros(length(bias_set), iteration);
34 | rmse_set_mg17 = zeros(length(bias_set), iteration);
35 | rmse_set_infty = zeros(length(bias_set), iteration);
36 | time_set = zeros(length(bias_set), iteration);
37 |
38 | for bs = 1:length(bias_set)
39 |     bias = bias_set(bs);
40 |
41 |     rmse_parfor_set_lorenz = zeros(1, iteration);
42 |     rmse_parfor_set_circle = zeros(1, iteration);
43 |     rmse_parfor_set_mg17 = zeros(1, iteration);
44 |     rmse_parfor_set_infty = zeros(1, iteration);
45 |     time_parfor_set = zeros(1, iteration);
46 |
47 |     for repeat_i = 1:iteration
48 |         [rmse_l, rmse_c, rmse_m, rmse_i, t_repeat_i] = func_train_val(n, train_t, reset_t, noise_level, bias);
49 |         rmse_parfor_set_lorenz(repeat_i) = rmse_l;
50 |         rmse_parfor_set_circle(repeat_i) = rmse_c;
51 |         rmse_parfor_set_mg17(repeat_i) = rmse_m;
52 |         rmse_parfor_set_infty(repeat_i) = rmse_i;
53 |         time_parfor_set(repeat_i) = t_repeat_i;
54 |     end
55 |
56 |     rmse_set_lorenz(bs, :) = rmse_parfor_set_lorenz;
57 |     rmse_set_circle(bs, :) = rmse_parfor_set_circle;
58 |     rmse_set_mg17(bs, :) = rmse_parfor_set_mg17;
59 |     rmse_set_infty(bs, :) = rmse_parfor_set_infty;
60 |     time_set(bs, :) = time_parfor_set;
61 | end
62 |
63 | save_success_rate.reset_t = reset_t;
64 | save_success_rate.noise_level = noise_level;
65 | save_success_rate.n = n;
66 | save_success_rate.train_t = train_t;
67 | save_success_rate.iteration = iteration;
68 | save_success_rate.bias_set = bias_set;
69 | save_success_rate.rmse_set_lorenz = rmse_set_lorenz;
70 | save_success_rate.rmse_set_circle = rmse_set_circle;
71 | save_success_rate.rmse_set_mg17 = rmse_set_mg17;
72 | save_success_rate.rmse_set_infty = rmse_set_infty;
73 | save_success_rate.time_set = time_set;
74 |
75 | save(['save_data/save_success_rate_bias_', time_today, '_' num2str(randi(999)) '.mat'], 'save_success_rate')
76 |
--------------------------------------------------------------------------------
/experiments/temp:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/experiments/uncertain_test_saferegion.m:
--------------------------------------------------------------------------------
1 | % clear all
2 | % close all
3 | % clc
4 | % (Not cleared: variables set by a calling script -- e.g. traj_type and
5 | % time_today in uncertain_test_saferegion_main.m -- must survive.)
6 |
7 | addpath('./tools/')
8 |
9 | %% Sensitivity of a trained controller to link-length changes
10 | % Loads a trained controller, sets the link lengths properties(3:4) to every
11 | % (l1, l2) pair on a grid, runs val_and_update `iteration` times per pair
12 | % and records the tracking RMSE from func_rmse. The RMSE grid is saved to
13 | % save_data/ and its mean over the repeats is plotted.
14 |
15 | % Saved model, robot properties and timing; reloaded in every repeat below.
16 | load('./save_file/all_traj_06282022.mat')
17 |
18 | % Defaults for a standalone run (the _main driver sets both beforehand).
19 | if exist('time_today', 'var') == 0
20 |     time_today = datestr(now, 'mmddyyyy');
21 | end
22 |
23 | % traj_type = 'lorenz';
24 | if exist('traj_type', 'var') == 0
25 |     traj_type = 'infty';
26 | end
27 |
28 | % Settings presumably read by the val_and_update script.
29 | plot_val_and_update = 0;
30 | blur.blur = 0;
31 | plot_movie = 0;
32 | bridge_type = 'cubic';
33 |
34 | time_infor.val_length = 250000;
35 | val_length_rm = time_infor.val_length;   % re-applied after every reload
36 |
37 | % RMSE window: from 3/5 of the validation length to 100 steps before its end.
38 | rmse_start_time = round(val_length_rm * 3/5);
39 | rmse_end_time = time_infor.val_length - 100;
40 |
41 | % Period parameter of the reference; func_desired_traj uses it only for
42 | % the 'circle' and 'infty' types.
43 | if strcmp(traj_type, 'circle') == 1
44 |     traj_frequency = 150;
45 | else
46 |     traj_frequency = 75;
47 | end
48 |
49 | % Link lengths to test; properties(3:4) hold [l1, l2].
50 | l1_set = linspace(0.5, 0.55, 10);
51 | l2_set = linspace(0.5, 0.55, 10);
52 |
53 | iteration = 50;   % repeats per (l1, l2) pair
54 |
55 | failure.type = 'all';
56 | failure.amplitude = 0.1;
57 | failure.amplitude_2 = 0.1;
58 |
59 | % length
60 | uncertain_type = 'l';
61 |
62 | % rmse_set(l1 index, l2 index, :) holds that pair's RMSEs in ascending order.
63 | rmse_set = zeros(length(l1_set), length(l2_set), iteration);
64 |
65 | idx = 2;   % suffix of the saved field names (type2, l1_set2, ...)
66 |
67 | % save_data/ is not part of the repository; create it before the long loop.
68 | if ~exist('save_data', 'dir')
69 |     mkdir('save_data');
70 | end
71 |
72 | for l1_idx = 1:length(l1_set)
73 |     for l2_idx = 1:length(l2_set)
74 |         rmse_parfor_set = zeros(1, iteration);
75 |         for repeat_i = 1:iteration
76 |             save_rend = 0;
77 |
78 |             % Reload the saved workspace every repeat (presumably to discard
79 |             % state left by the previous val_and_update run), then re-apply
80 |             % the validation length and the link lengths under test.
81 |             load('./save_file/all_traj_06282022.mat')
82 |             time_infor.val_length = val_length_rm;
83 |             properties(3:4) = [l1_set(l1_idx), l2_set(l2_idx)];
84 |
85 |             val_and_update;
86 |
87 |             rmse_parfor_set(repeat_i) = func_rmse(data_pred, data_control, rmse_start_time, rmse_end_time);
88 |         end
89 |         rmse_set(l1_idx, l2_idx, :) = sort(rmse_parfor_set);
90 |     end
91 | end
92 |
93 | save_uncertain_iter.traj_type = traj_type;
94 |
95 | save_uncertain_iter.(['type', num2str(idx)]) = uncertain_type;
96 | save_uncertain_iter.(['l1_set', num2str(idx)]) = l1_set;
97 | save_uncertain_iter.(['l2_set', num2str(idx)]) = l2_set;
98 | save_uncertain_iter.(['rmse_set', num2str(idx)]) = rmse_set;
99 |
100 | save(['save_data/save_uncertain_iter_' traj_type, '_' time_today, '_' num2str(randi(999)) '.mat'], "save_uncertain_iter")
101 |
102 | %% Mean RMSE over the repeats as a function of (l1, l2)
103 | rmse_set_l = save_uncertain_iter.(['rmse_set', num2str(idx)]);
104 | rmse_l_mean = mean(rmse_set_l, 3);   % rows: l1, columns: l2
105 |
106 | % imagesc(x, y, C) puts the columns of C on the x-axis and its rows on the
107 | % y-axis, so l2 goes on x and l1 on y.
108 | figure();
109 | imagesc(l2_set, l1_set, rmse_l_mean);
110 | xlabel('l_2')
111 | ylabel('l_1')
112 | colorbar
113 |
114 | % Same map against the relative deviation from the length 0.5, with l1
115 | % increasing upwards (YDir 'normal' rather than flipping the rows, so the
116 | % tick labels stay aligned with the data).
117 | l1_set_plot = (l1_set - 0.5) / 0.5;
118 | l2_set_plot = (l2_set - 0.5) / 0.5;
119 |
120 | figure();
121 | imagesc(l2_set_plot, l1_set_plot, rmse_l_mean);
122 | set(gca, 'YDir', 'normal')
123 | xlabel('(l_2 - 0.5) / 0.5')
124 | ylabel('(l_1 - 0.5) / 0.5')
125 | colorbar
126 |
--------------------------------------------------------------------------------
/experiments/uncertain_test_saferegion_main.m:
--------------------------------------------------------------------------------
1 | clear all
2 | close all
3 | clc
4 | % Driver for uncertain_test_saferegion: choose traj_type, then run the script.
5 | time_today = datestr(now, 'mmddyyyy');
6 | % uncertain_test_saferegion runs as a script in this workspace; it reads traj_type and uses time_today in its output file name.
7 | % traj_type = 'lorenz';
8 | % uncertain_test_saferegion
9 |
10 | % traj_type = 'circle';
11 | % uncertain_test_saferegion
12 |
13 | % traj_type = 'mg17';
14 | % uncertain_test_saferegion
15 | % Only the 'infty' run below is currently enabled; uncomment a pair above to add others.
16 | traj_type = 'infty';
17 | uncertain_test_saferegion
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
--------------------------------------------------------------------------------
/func_desired_traj.m:
--------------------------------------------------------------------------------
1 | function [control_infor, time_infor] = func_desired_traj(traj_type, bridge_type, time_infor, control_infor, properties, dt, plot_movie, traj_frequency)
2 | %FUNC_DESIRED_TRAJ Build the joint-space reference and torques for a two-link arm.
3 | %   [control_infor, time_infor] = func_desired_traj(traj_type, bridge_type, time_infor, ...
4 | %       control_infor, properties, dt, plot_movie, traj_frequency)
5 | %
6 | %   Generates a Cartesian reference (x, y) of type traj_type, prepends a
7 | %   bridge from the arm's current end-effector position to the nearest
8 | %   reference point, converts the path to joint angles (q1, q2) with the
9 | %   two-link inverse kinematics, and fills control_infor with the joint
10 | %   angles, finite-difference velocities and accelerations (step dt) and
11 | %   the joint torques from the arm's dynamic model.
12 | %
13 | %   traj_type       'infty', 'circle', 'astroid', 'heart', 'epitrochoid',
14 | %                   'fermat', 'lissajous', 'talbot', 'lorenz', 'chua',
15 | %                   'rossler', 'sprott_1', 'sprott_4', 'mg17', 'mg30' or
16 | %                   'lorenz96'; 'lorenz' onwards are loaded from ./read_data/
17 | %   bridge_type     'cubic' (cubic polynomial in joint space) or 'linear'
18 | %                   (straight Cartesian segment via waypointTrajectory)
19 | %   time_infor      struct; reads val_length and train_length and returns
20 | %                   val_length (extended by the bridge for 'linear' only)
21 | %   control_infor   struct with q_control, qdt_control, q2dt_control and
22 | %                   tau_control; row 1 of q_control/qdt_control is the
23 | %                   arm's current state
24 | %   properties      [m1, m2, l1, l2, lc1, lc2, I1, I2]
25 | %   dt              time step
26 | %   plot_movie      1 to animate the result with func_plot_movie
27 | %   traj_frequency  period parameter of the 'infty' and 'circle' references
28 | %                   (unused by the other types)
29 | %
30 | %   An unknown traj_type or bridge_type raises an error.
31 |
32 |     plot_figures_inside = 0;   % set to 1 for diagnostic plots
33 |
34 |     val_length = time_infor.val_length;
35 |     train_length = time_infor.train_length;
36 |     % Get the two-arm properties
37 |     [m1, m2, l1, l2, lc1, lc2, I1, I2] = matsplit(properties);
38 |     % Get the initial information
39 |     q_control = control_infor.q_control;
40 |     qdt_control = control_infor.qdt_control;
41 |     q2dt_control = control_infor.q2dt_control;
42 |     tau_control = control_infor.tau_control;
43 |
44 |     % Current end-effector position (forward kinematics of row 1).
45 |     value_q = q_control(1,:);
46 |     x_start = l1*cos(value_q(1)) + l2*cos(value_q(1)+value_q(2));
47 |     y_start = l1*sin(value_q(1)) + l2*sin(value_q(1)+value_q(2));
48 |
49 |     t = 0:dt:val_length;
50 |
51 |     %% Different traj_types, the readers can add their own trajectories
52 |     % Every branch leaves row vectors x and y with 2*val_length+1 samples.
53 |     if strcmp(traj_type, 'infty')==1
54 |         x = 0.25*sin(2*pi*t*1/(2*traj_frequency));
55 |         y = 0.15*sin(2*pi*t*1/traj_frequency);
56 |
57 |         x = x(1:2*val_length+1);
58 |         y = y(1:2*val_length+1);
59 |     elseif strcmp(traj_type, 'circle')
60 |         x = 0.5 * cos(2 * pi * t * 1 / traj_frequency);
61 |         y = 0.5 * sin(2 * pi * t * 1 / traj_frequency);
62 |
63 |         x = x(1:2*val_length+1);
64 |         y = y(1:2*val_length+1);
65 |     elseif strcmp(traj_type, 'astroid')
66 |         x = 0.4 * (cos(2 * pi * t * 1 / 250).^3);
67 |         y = 0.4 * (sin(2 * pi * t * 1 / 250).^3);
68 |
69 |         x = x(1:2*val_length+1);
70 |         y = y(1:2*val_length+1);
71 |
72 |     elseif strcmp(traj_type, 'heart')
73 |         f = 250;
74 |         a = 0.4;
75 |         theta = 2 * pi * t * 1 / f;
76 |
77 |         r = 1 - sin(theta);
78 |
79 |         x = r.*cos(theta);
80 |         y = r.*sin(theta);
81 |
82 |         x = normalize(x, 'range', [-a, a]);
83 |         y = normalize(y, 'range', [-a, a]);
84 |
85 |         x = x(1:2*val_length+1);
86 |         y = y(1:2*val_length+1);
87 |
88 |     elseif strcmp(traj_type, 'epitrochoid')
89 |         a = 5;
90 |         b = 3;
91 |         c = 5;
92 |
93 |         f = 200;
94 |
95 |         x = (a + b) * cos(2 * pi * t * 1 / f) - c * cos((a/b + 1) * 2 * pi * t * 1 / f);
96 |         y = (a + b) * sin(2 * pi * t * 1 / f) - c * sin((a/b + 1) * 2 * pi * t * 1 / f);
97 |
98 |         x = normalize(x, 'range', [-0.4, 0.4]);
99 |         y = normalize(y, 'range', [-0.4, 0.4]);
100 |
101 |         x = x(1:2*val_length+1);
102 |         y = y(1:2*val_length+1);
103 |
104 |     elseif strcmp(traj_type, 'fermat')
105 |         f = 100;
106 |         a = 0.5;
107 |         theta = 2 * pi * t * 1 / f;
108 |
109 |         r = sqrt(a^2 * theta);
110 |
111 |         x = r.*cos(theta);
112 |         y = r.*sin(theta);
113 |
114 |         x = normalize(x, 'range', [-4, 4]);
115 |         y = normalize(y, 'range', [-4, 4]);
116 |
117 |         x = x(1:2*val_length+1);
118 |         y = y(1:2*val_length+1);
119 |
120 |     elseif strcmp(traj_type, 'lissajous')
121 |         f = 300;
122 |         a = 1;
123 |         b = 3;
124 |         A = 1;
125 |         B = 1;
126 |
127 |         x = A * sin(a * t * 2 * pi / f + pi / 4);
128 |         y = B * sin(b * t * 2 * pi / f) ;
129 |
130 |         xx = normalize(x, 'range', [-0.3, 0.3]);
131 |         yy = normalize(y, 'range', [-0.3, 0.3]);
132 |
133 |         % Note: the components are swapped (x takes yy, y takes xx).
134 |         x = yy(1:2*val_length+1);
135 |         y = xx(1:2*val_length+1);
136 |
137 |     elseif strcmp(traj_type, 'talbot')
138 |         f = 400;
139 |         a = 1.1;
140 |         b = 0.666;
141 |         ff = 1;
142 |
143 |         x = (a^2 + ff^2 .* sin(t * 2 * pi / f).^2) .* cos(t * 2 * pi / f) / a;
144 |         y = (a^2 - 2*ff^2 + ff^2 .* sin(t * 2 * pi / f).^2) .* sin(t * 2 * pi / f) / b;
145 |
146 |         x = normalize(x, 'range', [-0.4, 0.4]);
147 |         y = normalize(y, 'range', [-0.3, 0.3]);
148 |
149 |         x = x(1:2*val_length+1);
150 |         y = y(1:2*val_length+1);
151 |
152 |     elseif strcmp(traj_type, 'lorenz')
153 |         load('./read_data/lorenz.mat')
154 |         lorenz_xy = ts_train(1000:1000+round(val_length*2.1), 1:2);
155 |
156 |         lorenz_x = normalize(lorenz_xy(:, 1), 'range', [-0.5, 0.5]);
157 |         lorenz_y = normalize(lorenz_xy(:, 2), 'range', [-0.5, 0.5]);
158 |
159 |         x = lorenz_x(1:val_length*2+1, :)';
160 |         y = lorenz_y(1:val_length*2+1, :)';
161 |
162 |         y = y - 0.3;
163 |
164 |     elseif strcmp(traj_type, 'chua')
165 |         load('./read_data/chua.mat')
166 |         chua_xy = ts_train(10000:10000+round(val_length*2.1), 1:2);
167 |
168 |         chua_x = normalize(chua_xy(:, 1), 'range', [-0.5, 0.5]);
169 |         chua_y = normalize(chua_xy(:, 2), 'range', [-0.3, 0.3]);
170 |
171 |         chua_x = interp1(1:length(chua_x), chua_x, 0:0.3:length(chua_x), 'pchip');
172 |         chua_y = interp1(1:length(chua_y), chua_y, 0:0.3:length(chua_y), 'pchip');
173 |
174 |         x = chua_x(1:val_length*2+1)';
175 |         y = chua_y(1:val_length*2+1)';
176 |
177 |         x = reshape(x, 1, []);
178 |         y = reshape(y, 1, []);
179 |
180 |     elseif strcmp(traj_type, 'rossler')
181 |         load('./read_data/rossler.mat')
182 |         rossler_xy = ts_train(10000:10000+round(val_length*2.1), 1:2);
183 |
184 |         rossler_x = normalize(rossler_xy(:, 1), 'range', [-0.35, 0.35]);
185 |         rossler_y = normalize(rossler_xy(:, 2), 'range', [-0.35, 0.35]);
186 |
187 |         rossler_x = interp1(1:length(rossler_x), rossler_x, 0:0.55:length(rossler_x), 'pchip');
188 |         rossler_y = interp1(1:length(rossler_y), rossler_y, 0:0.55:length(rossler_y), 'pchip');
189 |
190 |         x = rossler_x(1:val_length*2+1)';
191 |         y = rossler_y(1:val_length*2+1)';
192 |
193 |         x = reshape(x, 1, []);
194 |         y = reshape(y, 1, []);
195 |
196 |     elseif strcmp(traj_type, 'sprott_1')
197 |         load('./read_data/sprott_1.mat')
198 |         sprott_xy = ts_train(10000:10000+round(val_length*2.1), 1:2);
199 |
200 |         sprott_x = normalize(sprott_xy(:, 1), 'range', [-0.4, 0.4]);
201 |         sprott_y = normalize(sprott_xy(:, 2), 'range', [-0.4, 0.4]);
202 |
203 |         sprott_x = interp1(1:length(sprott_x), sprott_x, 0:0.6:length(sprott_x), 'pchip');
204 |         sprott_y = interp1(1:length(sprott_y), sprott_y, 0:0.6:length(sprott_y), 'pchip');
205 |
206 |         x = sprott_x(1:val_length*2+1)';
207 |         y = sprott_y(1:val_length*2+1)';
208 |
209 |         x = reshape(x, 1, []);
210 |         y = reshape(y, 1, []);
211 |
212 |     elseif strcmp(traj_type, 'sprott_4')
213 |         load('./read_data/sprott_4.mat')
214 |         sprott_xy = ts_train(10000:10000+round(val_length*2.1), 1:2);
215 |
216 |         sprott_x = normalize(sprott_xy(:, 1), 'range', [-0.4, 0.4]);
217 |         sprott_y = normalize(sprott_xy(:, 2), 'range', [-0.4, 0.4]);
218 |
219 |         sprott_x = interp1(1:length(sprott_x), sprott_x, 0:0.6:length(sprott_x), 'pchip');
220 |         sprott_y = interp1(1:length(sprott_y), sprott_y, 0:0.6:length(sprott_y), 'pchip');
221 |
222 |         x = sprott_x(1:val_length*2+1)';
223 |         y = sprott_y(1:val_length*2+1)';
224 |
225 |         x = reshape(x, 1, []);
226 |         y = reshape(y, 1, []);
227 |
228 |     elseif strcmp(traj_type, 'mg17')
229 |         load('./read_data/MG17.mat')
230 |         mg = ts_train(10000:10000+round(val_length*2.1));
231 |
232 |         % x and y are the same series, offset by 1700 samples.
233 |         mg_x = normalize(mg(10000-1700:end-1700), 'range', [-0.5, 0.5]);
234 |         mg_y = normalize(mg(10000:end), 'range', [-0.5, 0.5]);
235 |
236 |         mg_x = interp1(1:length(mg_x), mg_x, 0:0.25:length(mg_x), 'pchip');
237 |         mg_y = interp1(1:length(mg_y), mg_y, 0:0.25:length(mg_y), 'pchip');
238 |
239 |         x = mg_x(1:val_length*2+1)';
240 |         y = mg_y(1:val_length*2+1)';
241 |
242 |         x = reshape(x, 1, []);
243 |         y = reshape(y, 1, []);
244 |
245 |         x = x - 0.1;
246 |         y = y - 0.1;
247 |     elseif strcmp(traj_type, 'mg30')
248 |         load('./read_data/MG30.mat')
249 |         mg = ts_train(10000:10000+round(val_length*2.1));
250 |
251 |         % x and y are the same series, offset by 3000 samples.
252 |         mg_x = normalize(mg(10000:end), 'range', [-0.4, 0.4]);
253 |         mg_y = normalize(mg(10000-3000:end-3000), 'range', [-0.4, 0.4]);
254 |
255 |         mg_x = interp1(1:length(mg_x), mg_x, 0:0.25:length(mg_x), 'pchip');
256 |         mg_y = interp1(1:length(mg_y), mg_y, 0:0.25:length(mg_y), 'pchip');
257 |
258 |         x = mg_x(1:val_length*2+1)';
259 |         y = mg_y(1:val_length*2+1)';
260 |
261 |         x = reshape(x, 1, []);
262 |         y = reshape(y, 1, []);
263 |
264 |         x = x + 0.25;
265 |
266 |     elseif strcmp(traj_type, 'lorenz96')
267 |         load('./read_data/lorenz96.mat')
268 |
269 |         lorenz_highD = ts_train(10000:10000+round(val_length*2.1), :);
270 |
271 |         lorenz_x = normalize(lorenz_highD(:, 1), 'range', [-0.6, 0.6]);
272 |         lorenz_y = normalize(lorenz_highD(:, 2), 'range', [-0.7, 0.7]);
273 |
274 |         x = lorenz_x(1:val_length*2+1)';
275 |         y = lorenz_y(1:val_length*2+1)';
276 |
277 |         x = reshape(x, 1, []);
278 |         y = reshape(y, 1, []);
279 |
280 |         x = x + 0.1;
281 |     else
282 |         error('func_desired_traj:unknownTrajType', ...
283 |             'error: please input the traj type! (got ''%s'')', traj_type);
284 |     end
285 |
286 |     %% Build bridge to the reference trajectory
287 |     % Closest reference sample (among the first min(100000, val_length)) to
288 |     % the current end-effector position; samples at the origin are skipped
289 |     % because atan(y./x) below would be 0/0 there.
290 |     add_id = 1;
291 |     closest_value = Inf;
292 |     for i = 1:min(100000, val_length)
293 |         dis_value = sqrt((x_start-x(i)).^2 + (y_start-y(i)).^2);
294 |         if dis_value < closest_value
295 |             if round(x(i), 6)==0 && round(y(i), 6)==0
296 |                 continue
297 |             end
298 |             closest_value = dis_value;
299 |             add_id = i;
300 |         end
301 |     end
302 |     bridge_point = [x(add_id), y(add_id)];
303 |     bridge_len = sqrt((bridge_point(1)-x_start).^2 + (bridge_point(2)-y_start).^2);
304 |     bridge_time = round(bridge_len * 1 / dt);
305 |     % We provide two methods to build the bridge: cubic and linear
306 |     if strcmp(bridge_type, 'cubic') == 1
307 |         % Cubic polynomial in joint space from the current state
308 |         % (theta0, theta0_dot) to the IK solution at add_id (theta1), with the
309 |         % finite-difference velocity towards add_id+1 as end velocity.
310 |         t_bg = 0:dt:bridge_time;
311 |         theta0 = q_control(1,:);
312 |         theta0_dot = qdt_control(1,:);
313 |         theta0 = mod(theta0, 2*pi);
314 |         if theta0(2) > pi
315 |             theta0(2) = theta0(2)-2*pi;
316 |         end
317 |         q2_bg(1) = acos((x(add_id).^2+y(add_id).^2-l1^2-l2^2)/(2*l1*l2));
318 |         q2_bg(2) = acos((x(add_id+1).^2+y(add_id+1).^2-l1^2-l2^2)/(2*l1*l2));
319 |         % Take the elbow branch whose sign matches the current q2.
320 |         if theta0(2)<0
321 |             q2_bg = -q2_bg;
322 |         end
323 |
324 |         q1_bg(1) = atan(y(add_id)./x(add_id))-atan(l2*sin(q2_bg(1))./(l1+l2*cos(q2_bg(1))));
325 |         q1_bg(2) = atan(y(add_id+1)./x(add_id+1))-atan(l2*sin(q2_bg(2))./(l1+l2*cos(q2_bg(2))));
326 |
327 |         % atan only returns (-pi/2, pi/2); shift q1 according to the quadrant.
328 |         for ij = 1:length(q1_bg)
329 |             if x(add_id + ij-1) < 0 && y(add_id + ij-1) > 0
330 |                 q1_bg(ij) = q1_bg(ij)+pi;
331 |             elseif x(add_id + ij-1) < 0 && y(add_id + ij-1) < 0
332 |                 q1_bg(ij) = q1_bg(ij)+pi;
333 |             elseif x(add_id + ij-1) > 0 && y(add_id + ij-1) < 0
334 |                 q1_bg(ij) = q1_bg(ij)+2*pi;
335 |             end
336 |         end
337 |
338 |         theta1 = [q1_bg(1), q2_bg(1)];
339 |         theta1_dot = [(q1_bg(2)-q1_bg(1))/dt, (q2_bg(2)-q2_bg(1))/dt];
340 |
341 |         a0 = theta0;
342 |         a1 = theta0_dot;
343 |         a2 = 3.*(theta1-theta0)./(bridge_time^2)-2.*theta0_dot./bridge_time-theta1_dot/bridge_time;
344 |         a3 = -2.*(theta1-theta0)./(bridge_time^3)+(theta1_dot+theta0_dot)/(bridge_time^2);
345 |
346 |         q1_bridge = a0(1) + a1(1).*t_bg + a2(1).*(t_bg.^2) + a3(1).*(t_bg.^3);
347 |         q2_bridge = a0(2) + a1(2).*t_bg + a2(2).*(t_bg.^2) + a3(2).*(t_bg.^3);
348 |
349 |         truePosition = zeros(2, length(q1_bridge));
350 |         truePosition(1,:) = l1*cos(q1_bridge)+l2*cos(q1_bridge+q2_bridge);
351 |         truePosition(2,:) = l1*sin(q1_bridge)+l2*sin(q1_bridge+q2_bridge);
352 |
353 |         % Prepend the interior bridge points to the reference.
354 |         x = [truePosition(1, 2:end-1), reshape(x(add_id:end), 1, [])];
355 |         y = [truePosition(2, 2:end-1), reshape(y(add_id:end), 1, [])];
356 |     elseif strcmp(bridge_type, 'linear') == 1
357 |         bridge = waypointTrajectory([x_start, y_start, 0; bridge_point(1), bridge_point(2), 0], 'TimeOfArrival', [0, bridge_time], 'SampleRate', round(bridge_time/dt));
358 |
359 |         truePosition = zeros(bridge.SampleRate * bridge.TimeOfArrival(end)-1, 3);
360 |         count = 1;
361 |         while ~isDone(bridge)
362 |             truePosition(count, :) = bridge();
363 |             count = count+1;
364 |         end
365 |
366 |         x = [truePosition(2:end-1, 1)', x(add_id:end)];
367 |         y = [truePosition(2:end-1, 2)', y(add_id:end)];
368 |     else
369 |         error('func_desired_traj:unknownBridgeType', ...
370 |             'error: please input the bridge_type! (got ''%s'')', bridge_type);
371 |     end
372 |
373 |     % NOTE(review): truePosition is (samples x 3) for 'linear' but (2 x samples)
374 |     % for 'cubic'. For 'cubic', truePosition(2:end-1, 1) is therefore empty and
375 |     % val_length is NOT extended by the prepended bridge points -- confirm
376 |     % whether that is intended before changing it (it changes the returned
377 |     % time_infor.val_length).
378 |     val_length = val_length + length(truePosition(2:end-1, 1));
379 |
380 |     %% Desired reference: inverse kinematics of the two-link arm
381 |     q2 = acos((x.^2+y.^2-l1^2-l2^2)/(2*l1*l2));
382 |
383 |     % Flip the elbow sign for the following samples whenever q2 is exactly
384 |     % 0 or +/-pi while x keeps moving in the same direction.
385 |     symb = 1;
386 |     for ij = 2:length(x)-1
387 |         q2(ij) = symb * q2(ij);
388 |         if q2(ij) == pi || q2(ij) == -pi || q2(ij) == 0
389 |             if sign(x(ij+1) - x(ij)) == sign(x(ij) - x(ij-1))
390 |                 symb = -symb;
391 |             end
392 |         end
393 |     end
394 |
395 |     q1 = atan(y./x)-atan(l2*sin(q2)./(l1+l2*cos(q2)));
396 |
397 |     % If this IK solution does not reproduce the current shoulder position,
398 |     % switch to the other elbow branch, fix the atan quadrant and unwrap.
399 |     x1 = l1*cos(q1(1));
400 |     if round(x1, 2) ~= round(l1*cos(value_q(1)), 2)
401 |         q2 = -q2;
402 |         q1 = atan(y./x)-atan(l2*sin(q2)./(l1+l2*cos(q2)));
403 |         for ij = 1:length(q1)
404 |             if x(ij) < 0 && y(ij) > 0
405 |                 q1(ij) = q1(ij)+pi;
406 |             elseif x(ij) < 0 && y(ij) < 0
407 |                 q1(ij) = q1(ij)+pi;
408 |             elseif x(ij) > 0 && y(ij) < 0
409 |                 q1(ij) = q1(ij)+2*pi;
410 |             end
411 |         end
412 |
413 |         % Remove jumps of about +/-pi and of more than pi (2*pi wraps)
414 |         % between consecutive samples.
415 |         for ij = 2:length(q1)
416 |             if q1(ij)-q1(ij-1)<pi+0.1 && q1(ij)-q1(ij-1)>pi-0.1
417 |                 q1(ij) = q1(ij)-pi;
418 |             end
419 |             if q1(ij)-q1(ij-1)<-pi+0.1 && q1(ij)-q1(ij-1)>-pi-0.1
420 |                 q1(ij) = q1(ij)+pi;
421 |             end
422 |             if q1(ij)-q1(ij-1)>pi
423 |                 q1(ij:end) = q1(ij:end)-2*pi;
424 |             end
425 |             if q1(ij)-q1(ij-1)<-pi
426 |                 q1(ij:end) = q1(ij:end)+2*pi;
427 |             end
428 |             if q2(ij)-q2(ij-1)>pi
429 |                 q2(ij:end) = q2(ij:end)-2*pi;
430 |             end
431 |             if q2(ij)-q2(ij-1)<-pi
432 |                 q2(ij:end) = q2(ij:end)+2*pi;
433 |             end
434 |         end
435 |     end
436 |
437 |     % Joint velocities (forward differences) and accelerations.
438 |     wave_qdt = [(q1(2:val_length+1)-q1(1:val_length))', (q2(2:val_length+1)-q2(1:val_length))'];
439 |     wave_qdt = wave_qdt./dt;
440 |
441 |     q_control(2:val_length+1, :) = [q1(1:val_length)', q2(1:val_length)'];
442 |
443 |     qdt_control(2:val_length+1, :) = wave_qdt(1:val_length,:);
444 |
445 |     q2dt_control(1:val_length, :) = ...
446 |         (qdt_control(2:val_length+1,:)-qdt_control(1:val_length,:))/dt;
447 |
448 |     % Torques from the two-link dynamics: inertia matrix H(q) times the
449 |     % acceleration plus the velocity-product terms part_1 / part_2.
450 |     for ii = 1:val_length
451 |         H11 = m1*lc1^2+I1+m2*(l1^2+lc2^2+2*l1*lc2*cos(q_control(ii,2)))+I2;
452 |         H12 = m2*l1*lc2*cos(q_control(ii,2))+m2*lc2^2+I2;
453 |         H21 = H12;
454 |         H22 = m2*lc2^2+I2;
455 |         h = m2*l1*lc2*sin(q_control(ii,2));
456 |
457 |         part_1 = -h*qdt_control(ii,2)*qdt_control(ii,1)-h*(qdt_control(ii,1)+qdt_control(ii,2))*qdt_control(ii,2);
458 |         part_2 = h*qdt_control(ii,1)*qdt_control(ii,1);
459 |
460 |         tau_control(ii,1) = H11*q2dt_control(ii,1)+H12*q2dt_control(ii,2)+part_1;
461 |         tau_control(ii,2) = H21*q2dt_control(ii,1)+H22*q2dt_control(ii,2)+part_2;
462 |     end
463 |
464 |     % Optional diagnostic plots (set plot_figures_inside = 1 above).
465 |     if plot_figures_inside == 1
466 |         figure();
467 |         hold on
468 |         plot(tau_control(2:val_length, 1))
469 |         plot(tau_control(2:val_length, 2))
470 |         ylabel('tau')
471 |     end
472 |
473 |     if plot_figures_inside == 1
474 |         figure();
475 |         hold on
476 |         plot(qdt_control(1:val_length-1, 1), 'r')
477 |         plot(qdt_control(1:val_length-1, 2), 'b')
478 |         xlabel('time step')
479 |         ylabel('dq/dt(control)')
480 |         legend('dq/dt(1)', 'dq/dt(2)')
481 |     end
482 |
483 |     x_control = l1*cos(q_control(:,1))+l2*cos(q_control(:,1)+q_control(:,2));
484 |     y_control = l1*sin(q_control(:,1))+l2*sin(q_control(:,1)+q_control(:,2));
485 |
486 |     if plot_figures_inside == 1
487 |         figure();
488 |         hold on
489 |         plot(x(1:val_length), y(1:val_length), 'r', 'LineWidth', 1)
490 |         plot(x_control(1:val_length), y_control(1:val_length), 'b--', 'LineWidth', 2)
491 |         xlim([-1, 1])
492 |         ylim([-1, 1])
493 |         legend('real', 'desired')
494 |     end
495 |
496 |     start_step = train_length -1000;
497 |     movie_step = 100;
498 |     time_all = train_length + 10000;
499 |
500 |     line_prop = 'solid';
501 |     % Animate the joint reference built above. NOTE(review): start_step and
502 |     % time_all are offsets based on train_length -- confirm they fall inside
503 |     % q_control.
504 |     if plot_movie == 1
505 |         func_plot_movie(start_step, movie_step, time_all, q_control(:, 1), q_control(:, 2), properties, line_prop)
506 |     end
507 |
508 |     control_infor.q_control = q_control;
509 |     control_infor.qdt_control = qdt_control;
510 |     control_infor.q2dt_control = q2dt_control;
511 |     control_infor.tau_control = tau_control;
512 |
513 |     time_infor.val_length = val_length;
514 |
515 | end
516 |
--------------------------------------------------------------------------------
/func_double_arm.m:
--------------------------------------------------------------------------------
1 | function [] = func_double_arm()
2 | %FUNC_DOUBLE_ARM Train one reservoir controller for the two-link arm and score it.
3 | %   Generates training data with robot_data_generator, fits the readout
4 | %   with func_reservoir_train, then runs the val_and_update script on the
5 | %   'lorenz', 'circle', 'mg17' and 'infty' references. The trained model
6 | %   and the mean of the four RMSEs (a_rmse) are saved to ./choose_file/.
7 | %
8 | %   val_and_update is a script, so it executes in this function's
9 | %   workspace: it presumably reads traj_type, bridge_type, traj_frequency,
10 | %   save_rend, idx, failure, blur, ... set below, and it must define
11 | %   data_pred and data_control, which func_rmse consumes afterwards.
12 |
13 |     rng('shuffle')
14 |
15 |     time_today = datestr(now, 'mmddyyyy');
16 |
17 |     % ./choose_file/ is not part of the repository; create it up front so the
18 |     % final save() does not fail after the long training run.
19 |     if ~exist('./choose_file', 'dir')
20 |         mkdir('./choose_file');
21 |     end
22 |
23 |     dt = 0.01;
24 |     input_infor = {'xy', 'qdt'};
25 |
26 |     % in-out dimension: each input quantity has 2 columns and is fed at two
27 |     % consecutive time steps (see func_reservoir_train), hence * 4.
28 |     dim_in = length(input_infor) * 4;
29 |     dim_out = 2;
30 |
31 |     % double robot arm properties
32 |     m1 = 1; m2 = 1;
33 |     l1 = 0.5; l2 = 0.5;
34 |     lc1 = 0.25; lc2 = 0.25;
35 |     I1 = 0.03; I2 = 0.03;
36 |     properties = [m1, m2, l1, l2, lc1, lc2, I1, I2];
37 |
38 |     reset_t = 80; % 70 to get good results for infty symbol, and noise level:0.8*0.5*10^(-2)
39 |     train_t = 200000;
40 |     val_t = 500;
41 |
42 |     noise_level = 2.0*10^(-2);
43 |
44 |     % Reservoir hyperparameters:
45 |     % [spectral radius, input scale, leak rate alpha, log10(ridge beta),
46 |     %  average nonzeros per row at n = 200 (scaled with n), bias kb]
47 |     n = 200;
48 |     hyperpara_set = [0.756250, 0.756250, 0.843750, -3.125, 106.71875, 2.0];
49 |     eig_rho = hyperpara_set(1);
50 |     W_in_a = hyperpara_set(2);
51 |     alpha = hyperpara_set(3);
52 |     beta = 10^hyperpara_set(4);
53 |     k = round(hyperpara_set(5)/200*n);
54 |     kb = hyperpara_set(6);
55 |
56 |     W_in = W_in_a*(2*rand(n, dim_in)-1);
57 |     res_net = sprandsym(n, k/n);
58 |     eig_D = eigs(res_net, 1);
59 |     res_net = (eig_rho/(abs(eig_D))).*res_net;   % rescale so |largest eigenvalue| = eig_rho
60 |     res_net = full(res_net);
61 |
62 |     % Lengths in time steps.
63 |     section_len = round(reset_t/dt);
64 |     washup_length = round(1002/dt);
65 |     train_length = round(train_t/dt) + round(5/dt);
66 |     val_length = round(val_t/dt);
67 |     time_length = train_length + 2 * val_length + 3 * washup_length + 100;
68 |
69 |     res_infor = struct('W_in', W_in, 'res_net', res_net, 'alpha', alpha, 'kb', kb, 'beta', beta, 'n', n);
70 |     time_infor = struct('section_len', section_len, 'washup_length', washup_length, ...
71 |         'train_length', train_length, 'val_length', val_length, 'time_length', time_length);
72 |
73 |     % generate training and validation data, dropping rows before washup_length
74 |     [xy, q, qdt, q2dt, tau] = robot_data_generator(time_infor, noise_level, dt, properties);
75 |     xy = xy(washup_length:end, :);
76 |     q = q(washup_length:end, :);
77 |     qdt = qdt(washup_length:end, :);
78 |     q2dt = q2dt(washup_length:end, :);
79 |     tau = tau(washup_length:end, :);
80 |     data_reservoir = struct('xy', xy, 'q', q, 'qdt', qdt, 'q2dt', q2dt, 'tau', tau);
81 |
82 |     clearvars xy q qdt q2dt tau
83 |
84 |     tic;
85 |     [Wout, r_end] = func_reservoir_train(data_reservoir, time_infor, input_infor, res_infor, dim_in, dim_out);
86 |     toc;
87 |
88 |     clearvars data_reservoir
89 |
90 |     rng('shuffle')
91 |
92 |     %% Validation settings shared by the four runs below
93 |     failure.type = 'none';
94 |     blur.blur = 0;
95 |
96 |     plot_val_and_update = 0;
97 |
98 |     disturbance = 0.00;
99 |     measurement_noise = 0.00;
100 |
101 |     rmse_start_time = 200000;
102 |     rmse_end_time = 300000-100;
103 |
104 |     plot_movie = 0;
105 |
106 |     % func_desired_traj takes a traj_frequency argument; it is set per run
107 |     % with the same values as experiments/uncertain_test_saferegion.m
108 |     % (150 for 'circle', 75 otherwise). NOTE(review): confirm that
109 |     % val_and_update does not expect a different value.
110 |
111 |     % idx = 1
112 |     traj_type = 'lorenz';
113 |     bridge_type = 'cubic';
114 |     traj_frequency = 75;
115 |
116 |     time_infor.val_length = 300000;
117 |     save_rend = 0;
118 |     idx = 1;
119 |
120 |     val_and_update;
121 |
122 |     rmse_1 = func_rmse(data_pred, data_control, rmse_start_time, rmse_end_time);
123 |
124 |     % idx = 2
125 |     traj_type = 'circle';
126 |     bridge_type = 'cubic';
127 |     traj_frequency = 150;
128 |
129 |     time_infor.val_length = 300000;
130 |     save_rend = 1;
131 |     idx = 2;
132 |
133 |     val_and_update;
134 |
135 |     rmse_2 = func_rmse(data_pred, data_control, rmse_start_time, rmse_end_time);
136 |
137 |     % idx = 3
138 |     traj_type = 'mg17';
139 |     bridge_type = 'cubic';
140 |     traj_frequency = 75;
141 |
142 |     time_infor.val_length = 300000;
143 |     save_rend = 1;
144 |     idx = 3;
145 |
146 |     val_and_update;
147 |
148 |     rmse_3 = func_rmse(data_pred, data_control, rmse_start_time, rmse_end_time);
149 |
150 |     % idx = 4
151 |     traj_type = 'infty';
152 |     bridge_type = 'cubic';
153 |     traj_frequency = 75;
154 |
155 |     time_infor.val_length = 300000;
156 |     save_rend = 1;
157 |     idx = 4;
158 |
159 |     val_and_update;
160 |
161 |     rmse_4 = func_rmse(data_pred, data_control, rmse_start_time, rmse_end_time);
162 |
163 |     % Mean tracking RMSE over the four references.
164 |     a_rmse = (rmse_1 + rmse_2 + rmse_3 + rmse_4) / 4;
165 |
166 |     save(['./choose_file/all_traj_', time_today, '_' num2str(randi(9999)) '_' num2str(randi(9999)) '.mat'], 'time_infor', 'input_infor', 'res_infor', 'properties', 'dim_in', 'dim_out', 'Wout', 'r_end', 'dt', 'reset_t', 'noise_level', 'a_rmse')
167 |
168 | end
169 |
--------------------------------------------------------------------------------
/func_reservoir_train.m:
--------------------------------------------------------------------------------
1 | function [Wout, r_end] = func_reservoir_train(data_reservoir, time_infor, input_infor, res_infor, dim_in, dim_out)
2 | %FUNC_RESERVOIR_TRAIN Drive the reservoir with the training inputs and fit a ridge readout.
3 | %   [Wout, r_end] = func_reservoir_train(data_reservoir, time_infor, input_infor, res_infor, dim_in, dim_out)
4 | %
5 | %   data_reservoir  struct with fields xy, q, qdt, q2dt, tau (one row per time
6 | %                   step; rows 1..train_length+1 are read)
7 | %   time_infor      struct; train_length = number of driven steps,
8 | %                   washup_length = initial steps excluded from the fit
9 | %   input_infor     cell array selecting the input quantities; supported:
10 | %                   {'xy','qdt'}, {'q'}, {'q','qdt'}, {'xy','qdt','q2dt'}
11 | %   res_infor       struct with W_in, res_net, alpha (leak rate), kb (bias),
12 | %                   beta (ridge parameter) and n (reservoir size)
13 | %   dim_in          number of input columns built from input_infor
14 | %   dim_out         number of columns of tau used as targets
15 | %
16 | %   Wout is the (dim_out x n) readout; r_end is the reservoir state after the
17 | %   last training step (func_reservoir_validate takes it as an input).
18 | %   An unsupported input_infor raises an error.
19 |     rng('shuffle');   % no random numbers are drawn below
20 |
21 |     xy = data_reservoir.xy;
22 |     q = data_reservoir.q;
23 |     qdt = data_reservoir.qdt;
24 |     q2dt = data_reservoir.q2dt;
25 |     tau = data_reservoir.tau;
26 |
27 |     washup_length = time_infor.washup_length;
28 |     train_length = time_infor.train_length;
29 |
30 |     W_in = res_infor.W_in;
31 |     res_net = res_infor.res_net;
32 |     alpha = res_infor.alpha;
33 |     kb = res_infor.kb;
34 |     beta = res_infor.beta;
35 |     n = res_infor.n;
36 |
37 |     % Input at step t: the selected quantities at steps t and t+1, side by side.
38 |     t0 = 1:train_length;
39 |     t1 = 2:train_length+1;
40 |     train_x = zeros(train_length, dim_in);   % the (:,:) assignments check the width
41 |     if length(input_infor)==2 && strcmp(input_infor(1), 'xy') && strcmp(input_infor(2), 'qdt')
42 |         train_x(:,:) = [xy(t0,:), xy(t1,:), qdt(t0,:), qdt(t1,:)];
43 |     elseif length(input_infor)==1 && strcmp(input_infor(1), 'q')
44 |         train_x(:,:) = [q(t0,:), q(t1,:)];
45 |     elseif length(input_infor)==2 && strcmp(input_infor(1), 'q') && strcmp(input_infor(2), 'qdt')
46 |         train_x(:,:) = [q(t0,:), q(t1,:), qdt(t0,:), qdt(t1,:)];
47 |     elseif length(input_infor)==3 && strcmp(input_infor(1), 'xy') && strcmp(input_infor(2), 'qdt') && strcmp(input_infor(3), 'q2dt')
48 |         % NOTE(review): selected by 'xy' but feeds q. Left as-is because the
49 |         % validation side presumably builds its input the same way -- confirm
50 |         % against func_reservoir_validate before changing either.
51 |         train_x(:,:) = [q(t0,:), q(t1,:), qdt(t0,:), qdt(t1,:), q2dt(t0,:), q2dt(t1,:)];
52 |     else
53 |         % Fail loudly instead of training on an all-zero input matrix.
54 |         error('func_reservoir_train:unsupportedInput', ...
55 |             'Unsupported input_infor combination.');
56 |     end
57 |
58 |     % Targets: torque at step t (tau must have dim_out columns).
59 |     train_y = zeros(train_length, dim_out);
60 |     train_y(:,:) = tau(t0,:);
61 |     train_x = train_x';
62 |     train_y = train_y';
63 |
64 |     % Leaky tanh reservoir update, starting from the zero state.
65 |     r_all = zeros(n, train_length+1);
66 |     bias_in = kb*ones(n, 1);
67 |     for ti = 1:train_length
68 |         r_all(:,ti+1) = (1-alpha)*r_all(:,ti) + alpha*tanh(res_net*r_all(:,ti) + W_in*train_x(:,ti) + bias_in);
69 |     end
70 |
71 |     % Keep the states after the washout (state ti+1 is paired with target ti)
72 |     % and square every second component (rows 2, 4, ...) before the fit.
73 |     r_train = r_all(:, washup_length+2:end);
74 |     r_train(2:2:end,:) = r_train(2:2:end,:).^2;
75 |     r_end(:) = r_all(:,end);
76 |
77 |     y_train = train_y(:, washup_length+1:end);
78 |
79 |     % Ridge regression Wout = Y*R'*(R*R' + beta*I)^-1, solved with mrdivide
80 |     % instead of forming the inverse explicitly.
81 |     Wout = (y_train*r_train') / (r_train*r_train' + beta*eye(n));
82 |
83 | end
84 |
--------------------------------------------------------------------------------
/func_reservoir_validate.m:
--------------------------------------------------------------------------------
function [control_infor, output_infor, time_infor, r_end] = func_reservoir_validate(traj_type, bridge_type, time_infor, ...
    input_infor, res_infor, start_info, properties, dim_in, dim_out, Wout, r_end, dt, plot_movie, save_rend, ...
    failure, blur, traj_frequency)
% FUNC_RESERVOIR_VALIDATE  Closed-loop tracking test of the trained reservoir
% controller on the two-link robot arm.
%
% func_desired_traj generates the reference trajectory. At each step:
%   1. The reservoir receives the arm's current end-effector position and
%      joint velocities (optionally corrupted by measurement noise),
%      together with the reference position and velocity two steps ahead.
%   2. The reservoir predicts the two joint torques through Wout.
%   3. The arm dynamics advance by one forward-Euler step driven by that
%      torque.
%
% Inputs
%   traj_type, bridge_type, plot_movie, traj_frequency
%               : forwarded to func_desired_traj.
%   time_infor  : struct; val_length is read. The struct returned by
%                 func_desired_traj is handed back to the caller.
%   input_infor : cell array of input names. Only {'xy', 'qdt'} is
%                 implemented here, and dim_in must then be 8.
%   res_infor   : struct with fields W_in, res_net, alpha, kb, n.
%   start_info  : struct with q, qdt, q2dt, tau (each written into one row
%                 of two joint values). Used as the initial arm state when
%                 save_rend == 1.
%   properties  : [m1, m2, l1, l2, lc1, lc2, I1, I2].
%   dim_out     : accepted for interface compatibility; not read here.
%   Wout        : readout matrix (e.g. from func_reservoir_train).
%   r_end       : reservoir state to resume from when save_rend == 1.
%   dt          : integration step.
%   save_rend   : 1 -> continue from start_info / r_end. Otherwise start at
%                 rest from a random joint configuration and a zero
%                 reservoir state.
%   failure     : struct with field type ('none', 'disturbance',
%                 'measurement' or 'all') and amplitude. For 'all',
%                 amplitude_2 scales the measurement noise.
%   blur        : accepted for interface compatibility; not read here.
%
% Outputs
%   control_infor : reference from func_desired_traj plus data_control,
%                   the reference end-effector path [x, y].
%   output_infor  : data_pred, q_pred, qdt_pred, q2dt_pred, tau_pred.
%   time_infor    : as returned by func_desired_traj.
%   r_end         : reservoir state after the last step.

%% check the input configuration
% The closed loop below only knows how to assemble the {'xy', 'qdt'} input
% vector. Previously any other configuration crashed later with an
% undefined-variable error; fail up front with a clear message instead.
if ~(length(input_infor)==2 && strcmp(input_infor(1), 'xy') && strcmp(input_infor(2), 'qdt'))
    error('func_reservoir_validate:unsupportedInput', ...
        'Only input_infor = {''xy'', ''qdt''} is supported.');
end

%% read parameters
% save_rend = 0: the arm starts from rest (zero velocity, acceleration, torque).
% save_rend = 1: the arm continues from the state stored in start_info.
if save_rend==1
    q=start_info.q;
    qdt=start_info.qdt;
    q2dt=start_info.q2dt;
    tau=start_info.tau;
end

val_length=time_infor.val_length;

[m1, m2, l1, l2, lc1, lc2, I1, I2] = matsplit(properties);

W_in=res_infor.W_in;
res_net=res_infor.res_net;
alpha=res_infor.alpha;
kb=res_infor.kb;
n=res_infor.n;

failure_type = failure.type;
if strcmp(failure_type, 'none') == 0
    failure_amplitude = failure.amplitude;
    if strcmp(failure_type, 'all') == 1
        failure_amplitude_2 = failure.amplitude_2;
    end
end

%% initial states
if save_rend == 1
    r=reshape(r_end', n, 1);
else
    r=zeros(size(res_net, 1), 1);
end
u=zeros(dim_in,1);

% Padded by 100 rows; the loop reads the reference two steps ahead.
q_control=zeros(val_length+100, 2);
qdt_control=zeros(val_length+100, 2);
q2dt_control=zeros(val_length+100, 2);
tau_control=zeros(val_length+100, 2);

if save_rend == 1
    q_control(1, :) = q;
    qdt_control(1,:) = qdt;
    q2dt_control(1,:) = q2dt;
    tau_control(1,:) = tau;
else
    % random initial joint configuration (arm at rest)
    if strcmp(traj_type, 'infty') == 1
        q_control(1, 1) = (3-1)*rand(1) + 1;
        q_control(1, 2) = (0.0 - 2.4) * rand(1);
        % q_control(1, 1) = (4-3)*rand(1) + 3;
        % q_control(1, 2) = (1.5 - 1.0) * rand(1) + 1.0;
    else
        q_control(1, 1) = (6-4)*rand(1) + 4;
        q_control(1, 2) = (0.0 - 2.4) * rand(1) - 0.1;
    end
end

% The predicted trajectory starts from the same initial state.
% (to stop on NaN/Inf while debugging: dbstop if naninf)
q_pred=q_control;
qdt_pred=qdt_control;
q2dt_pred=q2dt_control;
tau_pred=tau_control;

control_infor = struct('q_control', q_control, 'qdt_control', qdt_control, ...
    'q2dt_control', q2dt_control, 'tau_control', tau_control);

%% generate desired trajectory
[control_infor, time_infor] = func_desired_traj(traj_type, bridge_type, time_infor, control_infor, properties, dt, plot_movie, traj_frequency);

q_control=control_infor.q_control;
qdt_control=control_infor.qdt_control;

val_length=time_infor.val_length;

% reference end-effector path from forward kinematics
x_control=l1*cos(q_control(:,1))+l2*cos(q_control(:,1)+q_control(:,2));
y_control=l1*sin(q_control(:,1))+l2*sin(q_control(:,1)+q_control(:,2));
data_control=[x_control, y_control];

control_infor.data_control=data_control;

%% validation
% Input layout: u = [xy now; xy reference next; qdt now; qdt reference next].
u(:)=[data_control(1,:)';data_control(2,:)';qdt_control(1,:)';qdt_control(2,:)'];

data_pred = data_control;

% Time-based seed for the noise realisations. floor() keeps the seed a
% nonnegative integer, as documented for rng; `now` is sampled only once.
t_now = now;
rng(floor((t_now*1000 - floor(t_now*1000))*100000))

% Relative (multiplicative) Gaussian perturbations; zero when disabled.
disturbance_failure = zeros(2, val_length);
measurement_failure = zeros(round(dim_in/2), val_length);
if strcmp(failure_type, 'disturbance')
    disturbance_failure = randn(2, val_length) * failure_amplitude;
elseif strcmp(failure_type, 'measurement')
    measurement_failure = randn(round(dim_in/2), val_length) * failure_amplitude;
elseif strcmp(failure_type, 'all')
    disturbance_failure = randn(2, val_length) * failure_amplitude;
    measurement_failure = randn(round(dim_in/2), val_length) * failure_amplitude_2;
end
% limit on the per-step change of the predicted tau (scaled by dt)
taudt_threshold = [-5e-2, 5e-2];

bias = kb*ones(n,1);
H22 = m2*lc2^2+I2;   % constant entry of the inertia matrix

% In each step the predicted torque drives the arm, which then evolves
% according to its own dynamics.
for t_i = 1:val_length-3
    % reservoir update and readout (even-indexed nodes squared, as in training)
    r = (1-alpha)*r + alpha*tanh(res_net*r + W_in*u + bias);
    r_out = r;
    r_out(2:2:end,1) = r_out(2:2:end,1).^2;
    predict_value = Wout * r_out;

    disturbance_f_value = predict_value .* disturbance_failure(:, t_i);
    predict_value = predict_value + disturbance_f_value;

    % clamp the torque change relative to the previous step
    time_li = max(t_i - 1, 1);
    for li = 1:2
        if predict_value(li) - tau_pred(time_li, li) > taudt_threshold(2)*dt
            predict_value(li) = tau_pred(time_li, li) + taudt_threshold(2)*dt;
        end
        if predict_value(li) - tau_pred(time_li, li) < taudt_threshold(1)*dt
            predict_value(li) = tau_pred(time_li, li) + taudt_threshold(1)*dt;
        end
    end
    tau_pred(t_i, :) = predict_value;

    time_now=t_i;

    % two-link arm dynamics: inertia matrix [H11 H12; H21 H22]
    H11=m1*lc1^2+I1+m2*(l1^2+lc2^2+2*l1*lc2*cos(q_pred(time_now,2)))+I2;
    H12=m2*l1*lc2*cos(q_pred(time_now,2))+m2*lc2^2+I2;
    H21=H12;
    h=m2*l1*lc2*sin(q_pred(time_now,2));

    part_1=-h*qdt_pred(time_now,2)*qdt_pred(time_now,1)-h*(qdt_pred(time_now,1)+qdt_pred(time_now,2))*qdt_pred(time_now,2);
    part_2=h*qdt_pred(time_now,1)*qdt_pred(time_now,1);
    denominator=H12*H21-H11*H22;

    % solve H * q2dt = tau - [part_1; part_2] for the joint accelerations
    q2dt_pred(time_now,1)=-(-part_1*H22+H12*part_2-H12*predict_value(2)+H22*predict_value(1))/denominator;
    q2dt_pred(time_now,2)=-(part_1*H21-H11*part_2+H11*predict_value(2)-H21*predict_value(1))/denominator;

    % forward Euler step
    q_pred(time_now+1,:)=q_pred(time_now,:)+qdt_pred(time_now,:)*dt;
    qdt_pred(time_now+1,:)=qdt_pred(time_now,:)+q2dt_pred(time_now,:)*dt;

    x_pred=l1*cos(q_pred(time_now+1,1))+l2*cos(q_pred(time_now+1,1)+q_pred(time_now+1,2));
    y_pred=l1*sin(q_pred(time_now+1,1))+l2*sin(q_pred(time_now+1,1)+q_pred(time_now+1,2));

    % measurement noise only affects what the reservoir sees, not the stored path
    x_measurement_f_value = x_pred .* measurement_failure(1, t_i);
    y_measurement_f_value = y_pred .* measurement_failure(2, t_i);
    qdt_measurement_f_value = qdt_pred(time_now+1, :) .* measurement_failure(3:4, t_i)';

    x_pred_measurement = x_pred + x_measurement_f_value;
    y_pred_measurement = y_pred + y_measurement_f_value;
    qdt_pred_measurement = qdt_pred(time_now+1, :) + qdt_measurement_f_value;

    data_pred(time_now+1,:)=[x_pred, y_pred];

    u(1:2) = [x_pred_measurement;y_pred_measurement];
    u(3:4) = data_control(time_now+2,:);
    u(5:6) = qdt_pred_measurement;
    u(7:8) = qdt_control(time_now+2,:);
end

%% output
output_infor.data_pred = data_pred;
output_infor.q_pred = q_pred;
output_infor.qdt_pred = qdt_pred;
output_infor.q2dt_pred = q2dt_pred;
output_infor.tau_pred = tau_pred;

r_end = r;

end
189 |
190 |
--------------------------------------------------------------------------------
/func_train_val.m:
--------------------------------------------------------------------------------
function [rmse_l, rmse_c, rmse_m, rmse_i, t] = func_train_val(n, train_t, reset_t, noise_level, bias)
% FUNC_TRAIN_VAL  Build a random reservoir, train it on random-torque data of
% the two-link arm, and measure tracking error on four reference trajectories.
%
% Inputs
%   n           : number of reservoir nodes.
%   train_t     : training time; train_length = round(train_t/dt) + round(5/dt).
%   reset_t     : time between random resets of the arm during data
%                 generation (section_len = round(reset_t/dt) steps).
%   noise_level : amplitude of the random training torque.
%   bias        : reservoir bias kb.
%
% Outputs
%   rmse_l, rmse_c, rmse_m, rmse_i : tracking RMSE for the 'lorenz',
%       'circle', 'mg17' and 'infty' references. Each is measured from 3/5
%       of the validation length up to 100 steps before its end.
%   t : time (tic/toc, seconds) spent in func_reservoir_train.

dt=0.01;
input_infor={'xy', 'qdt'};

dim_in=length(input_infor) * 4;
dim_out=2;

% two-link arm: masses, link lengths, centre-of-mass offsets, inertias
m1=1;m2=1;
l1=0.5;l2=0.5;
lc1=0.25;lc2=0.25;
I1=0.03;I2=0.03;
properties=[m1, m2, l1, l2, lc1, lc2, I1, I2];

val_t=500;
% Reservoir hyperparameters, given by Bayesian optimization:
% [spectral radius, input scaling, leak rate, log10(ridge beta),
%  connectivity (rescaled by n/200 to k), bias]
hyperpara_set = [0.756250, 0.756250, 0.843750, -3.125, 106.71875, bias];
eig_rho = hyperpara_set(1);
W_in_a = hyperpara_set(2);
alpha = hyperpara_set(3);
beta = 10^hyperpara_set(4);
k = round( hyperpara_set(5)/200*n);
kb = hyperpara_set(6);

% random input matrix and symmetric sparse reservoir rescaled to spectral radius eig_rho
W_in = W_in_a*(2*rand(n,dim_in)-1);
res_net=sprandsym(n,k/n);
eig_D=eigs(res_net,1);
res_net=(eig_rho/(abs(eig_D))).*res_net;
res_net=full(res_net);

section_len=round(reset_t/dt);
washup_length=round(1002/dt);
train_length=round(train_t/dt)+round(5/dt);
val_length=round(val_t/dt);
time_length = train_length + 2 * val_length + 3 * washup_length + 100;

res_infor=struct('W_in', W_in, 'res_net', res_net, 'alpha', alpha, 'kb', kb, 'beta', beta, 'n', n);
time_infor=struct('section_len', section_len, 'washup_length', washup_length, ...
    'train_length', train_length, 'val_length', val_length, 'time_length', time_length);

% generate training data and drop the first washup_length-1 samples
[xy, q, qdt, q2dt, tau] = robot_data_generator(time_infor, noise_level, dt, properties);
data_reservoir = struct('xy', xy(washup_length:end, :), 'q', q(washup_length:end, :), ...
    'qdt', qdt(washup_length:end, :), 'q2dt', q2dt(washup_length:end, :), ...
    'tau', tau(washup_length:end, :));

clearvars xy q qdt q2dt tau

tic;
[Wout, r_end] = func_reservoir_train(data_reservoir, time_infor, input_infor, res_infor, dim_in, dim_out);
t = toc;

clearvars data_reservoir

% After training, test on four trajectories: lorenz, circle, mg17 and the
% figure-eight ('infty'). time_infor and r_end are threaded through the runs
% exactly as before.
traj_names = {'lorenz', 'circle', 'mg17', 'infty'};
rmse_all = zeros(1, numel(traj_names));
for ii = 1:numel(traj_names)
    time_infor.val_length = 150000;
    [rmse_all(ii), time_infor, r_end] = run_tracking(traj_names{ii}, time_infor, input_infor, ...
        res_infor, properties, dim_in, dim_out, Wout, r_end, dt);
end

rmse_l = rmse_all(1);
rmse_c = rmse_all(2);
rmse_m = rmse_all(3);
rmse_i = rmse_all(4);

end

function [rmse, time_infor, r_end] = run_tracking(traj_type, time_infor, input_infor, res_infor, ...
    properties, dim_in, dim_out, Wout, r_end, dt)
% Track one reference from rest (cubic bridge, no failure, traj_frequency 75)
% and return the end-effector RMSE over the last 2/5 of the run.
plot_movie = 0;
bridge_type = 'cubic';
save_rend = 0;
traj_frequency = 75;
blur.blur = 0;
failure.type = 'none';
start_info = struct('q', 0, 'qdt', 0, 'q2dt', 0, 'tau', 0);

[control_infor, output_infor, time_infor, r_end] = func_reservoir_validate(traj_type, ...
    bridge_type, time_infor, input_infor, res_infor, start_info, properties, dim_in, dim_out, ...
    Wout, r_end, dt, plot_movie, save_rend, failure, blur, traj_frequency);

rmse_start_time = round(time_infor.val_length * 3/5);
rmse_end_time = time_infor.val_length - 100;
rmse = func_rmse(output_infor.data_pred, control_infor.data_control, rmse_start_time, rmse_end_time);
end
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 |
231 |
232 |
233 |
234 |
235 |
236 |
--------------------------------------------------------------------------------
/main.m:
--------------------------------------------------------------------------------
% MAIN  Let the trained reservoir controller follow one reference trajectory.
%
% Loads the saved controller data and sets the run options below, then
% calls the val_and_update script. That script performs the closed-loop
% run and, with plot_val_and_update = 1, plots the tracked path.

clear all
close all
clc

addpath('./tools/')

% Saved controller data. val_and_update expects variables such as Wout,
% res_infor, input_infor, properties, dt, dim_in and dim_out from here.
load('./save_file/all_traj_06282022.mat')

% ---- reference trajectory ------------------------------------------
% Available choices:
%   'lorenz', 'circle', 'mg17', 'infty', 'fermat', 'astroid', 'heart',
%   'epitrochoid', 'lissajous', 'talbot', 'chua', 'rossler',
%   'sprott_1', 'sprott_4', 'mg30', 'lorenz96'
traj_type = 'circle';
bridge_type = 'cubic';
% traj_frequency = 75;   % uncomment to override the default picked in val_and_update

% ---- run settings ----------------------------------------------------
time_infor.val_length = 200000;
save_rend = 0;   % start the arm from rest
idx = 0;
plot_movie = 0;
plot_val_and_update = 1;

% ---- failure / noise settings (all disabled) -------------------------
failure.type = 'none';
blur.blur = 0;
disturbance = 0.00;
measurement_noise = 0.00;

% let the well-trained machine follow the given reference
val_and_update
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
--------------------------------------------------------------------------------
/rand_traj_control.m:
--------------------------------------------------------------------------------
% RAND_TRAJ_CONTROL  Track all reference trajectories one after another, in a
% random order, with a single trained reservoir controller.
% The first trajectory starts the arm from rest. Each later one continues
% from the state where the previous one ended (save_rend = 1).

clear all
close all
clc

addpath('./tools/')

% load the trained controller data used by val_and_update
load('./save_file/all_traj_06282022.mat')

%%

traj_set = ["infty", "circle", "astroid", "fermat", ...
    "lissajous", "talbot", "heart", "lorenz", "chua", "rossler", ...
    "sprott_1", "sprott_4", "mg17", "mg30", "epitrochoid"];

% random visiting order
order = randperm(length(traj_set));
traj_set = traj_set(order);

plot_val_and_update = 1;
disturbance = 0.1;
measurement_noise = 0.1;
plot_movie = 0;
bridge_type = 'cubic';
failure.type = 'none';
blur.blur = 0;
idx = 1;

val_length_all = 150000;

for ii = 1:length(traj_set)
    rng('shuffle')
    traj_type = traj_set(ii);
    time_infor.val_length = val_length_all;

    % only the first trajectory starts from rest
    if ii == 1
        save_rend = 0;
    else
        save_rend = 1;
    end

    % If traj_frequency does not exist yet, val_and_update picks a default
    % on the first pass. Once it exists, refresh it here for each trajectory.
    if exist('traj_frequency','var') == 1
        if strcmp(traj_type, 'lorenz') == 1
            traj_frequency = 100;
        elseif strcmp(traj_type, 'circle') == 1   % was misspelled 'cirlce', so circle never got 150
            traj_frequency = 150;
        else
            traj_frequency = 75;
        end
    end

    val_and_update;

    idx = idx + 1;
end

time_infor.val_length = val_length_all;

time_today = datestr(now, 'mmddyyyy');
% save(['./save_data/15traj_', time_today, '_', num2str(randi(999)), '.mat'], "val_length_all", "save_all_traj", "traj_set")
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
--------------------------------------------------------------------------------
/read_data/temp:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/robot_data_generator.m:
--------------------------------------------------------------------------------
function [xy, q, qdt, q2dt, tau] = robot_data_generator(time_infor, noise_level, dt, properties)
% ROBOT_DATA_GENERATOR  Simulate the two-link arm under a smoothed random
% torque, producing time series for training and validation.
%
% Inputs
%   time_infor  : struct with time_length (number of samples) and
%                 section_len (reset period, in steps).
%   noise_level : half-width of the uniform torque noise before smoothing.
%   dt          : forward-Euler integration step.
%   properties  : [m1, m2, l1, l2, lc1, lc2, I1, I2].
%
% Outputs (time_length rows each)
%   xy           : end-effector position [x, y].
%   q, qdt, q2dt : joint angles, velocities and accelerations (2 columns).
%   tau          : applied joint torques (2 columns).
%
% To keep the random walk from drifting too far, every section_len steps
% the arm is re-initialised at rest in a random configuration. At that
% step its acceleration and torque are set to zero.

[m1, m2, l1, l2, lc1, lc2, I1, I2] = matsplit(properties);

n_steps = time_infor.time_length;
reset_every = time_infor.section_len;

% Uniform noise in [-noise_level, noise_level]. The first 99 rows are
% dropped and the rest is smoothed with a Gaussian window so that the
% control torque is continuous.
raw_torque = -noise_level + 2*noise_level*rand(n_steps*2, 2);
smooth_torque = smoothdata(raw_torque(100:end, :), 'gaussian', 50);

tau = zeros(n_steps, 2);
tau(:, :) = smooth_torque(1:n_steps, :);
q = zeros(n_steps, 2);
qdt = zeros(n_steps, 2);
q2dt = zeros(n_steps, 2);

rng('shuffle')

M22 = m2*lc2^2+I2;   % constant entry of the 2x2 inertia matrix
for k = 1:n_steps-1
    is_reset = mod(k, reset_every) == 0;
    if is_reset
        q(k, 1) = 2 * pi * rand(1);
        q(k, 2) = 2 * pi * rand(1) - pi;
        qdt(k, :) = 0;
    end

    % inertia matrix [M11 M12; M12 M22] and Coriolis coefficient hc
    M11 = m1*lc1^2+I1+m2*(l1^2+lc2^2+2*l1*lc2*cos(q(k,2)))+I2;
    M12 = m2*l1*lc2*cos(q(k,2))+m2*lc2^2+I2;
    hc = m2*l1*lc2*sin(q(k,2));

    w1 = qdt(k,1);
    w2 = qdt(k,2);
    c1 = -hc*w2*w1-hc*(w1+w2)*w2;
    c2 = hc*w1*w1;
    det_term = M12*M12-M11*M22;

    % solve M * q2dt = tau - [c1; c2] by Cramer's rule
    q2dt(k,1) = -(-c1*M22+M12*c2-M12*tau(k,2)+M22*tau(k,1))/det_term;
    q2dt(k,2) = -(c1*M12-M11*c2+M11*tau(k,2)-M12*tau(k,1))/det_term;

    if is_reset
        q2dt(k, :) = 0;
        tau(k, :) = 0;
    end

    % forward Euler step
    q(k+1,:) = q(k,:)+qdt(k,:)*dt;
    qdt(k+1,:) = qdt(k,:)+q2dt(k,:)*dt;
end

% forward kinematics of the end effector
xy = [l1*cos(q(:,1))+l2*cos(q(:,1)+q(:,2)), l1*sin(q(:,1))+l2*sin(q(:,1)+q(:,2))];

end
65 |
66 |
--------------------------------------------------------------------------------
/save_file/all_traj_06282022.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Zheng-Meng/Tracking-Control/a431a837e3cdf2890c5098d54bd68fa53e1e9878/save_file/all_traj_06282022.mat
--------------------------------------------------------------------------------
/save_file/temp:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/tools/func_limits.m:
--------------------------------------------------------------------------------
function [value] = func_limits(value, limit_type)
% FUNC_LIMITS  Clip the first two entries of VALUE to a fixed symmetric range.
%
%   limit_type == 1 : range for dq/dt,   [-0.05, 0.05]
%   limit_type == 2 : range for d2q/dt2, [-0.3,  0.3]
%   any other value : VALUE is returned unchanged.
%
% Only value(1) and value(2) are touched. Entries that compare false
% against both bounds (e.g. NaN) are left as they are.

if limit_type == 1
    bounds = [-0.05, 0.05];
elseif limit_type == 2
    bounds = [-0.3, 0.3];
else
    return
end

pair = value(1:2);
pair(pair > bounds(2)) = bounds(2);   % upper bound first, then lower
pair(pair < bounds(1)) = bounds(1);
value(1:2) = pair;

end
40 |
41 |
--------------------------------------------------------------------------------
/tools/func_plot_figures.m:
--------------------------------------------------------------------------------
function [] = func_plot_figures(control_infor, output_infor, val_length, plot_movie, properties)
% FUNC_PLOT_FIGURES  Plot the closed-loop tracking result of one validation run.
%
% Draws, over time steps 1 .. val_length-10000:
%   - the desired vs. predicted end-effector path;
%   - q (wrapped with mod(., pi)) and dq/dt of both joints, desired vs.
%     predicted;
%   - the predicted d2q/dt2 and tau of both joints, skipping the first
%     1000 steps.
% When plot_movie == 1, func_plot_movie also animates the predicted arm
% motion.
%
% Inputs
%   control_infor : struct with q_control, qdt_control, data_control.
%   output_infor  : struct with data_pred, q_pred, qdt_pred, q2dt_pred, tau_pred.
%   val_length    : validation length in steps.
%   plot_movie    : 1 to also render the movie.
%   properties    : (optional) arm parameter vector passed to
%                   func_plot_movie; required when plot_movie == 1.
%                   Previously this name was never defined inside the
%                   function, so the movie branch could not run.

data_pred=output_infor.data_pred;
q_pred=output_infor.q_pred;
qdt_pred=output_infor.qdt_pred;
q2dt_pred=output_infor.q2dt_pred;
tau_pred=output_infor.tau_pred;

q_control=control_infor.q_control;
qdt_control=control_infor.qdt_control;
data_control=control_infor.data_control;

start_time = 1;
end_time = val_length - 10000;
span = start_time:end_time;
% d2q/dt2 and tau are shown after an initial transient
remove_transient = 1000;
span_late = start_time+remove_transient:end_time;

% plot trajectory
figure();
hold on
plot(data_control(span, 1), data_control(span, 2),'r');
plot(data_pred(span, 1), data_pred(span, 2),'b--');
xlabel('x')
ylabel('y')
line([0, 0], [-1, 1], 'Color', 'black', 'LineStyle', '--')
line([-1, 1], [0, 0], 'Color', 'black', 'LineStyle', '--')
xlim([-1, 1])
ylim([-1, 1])
legend('desired trajectory', 'pred trajectory')

% joint angles, wrapped with mod(., pi)
q_control_plot=mod(q_control, pi);
q_pred_plot=mod(q_pred, pi);
plot_desired_vs_pred(q_control_plot(span,1), q_pred_plot(span,1), 'q(1)');
plot_desired_vs_pred(q_control_plot(span,2), q_pred_plot(span,2), 'q(2)');

% dq/dt
plot_desired_vs_pred(qdt_control(span,1), qdt_pred(span,1), 'dq/dt(1)');
plot_desired_vs_pred(qdt_control(span,2), qdt_pred(span,2), 'dq/dt(2)');

% d2q/dt2 and tau: only the predictions are plotted
plot_pred_only(q2dt_pred(span_late,1), 'd2q/dt2(1)');
plot_pred_only(q2dt_pred(span_late,2), 'd2q/dt2(2)');
plot_pred_only(tau_pred(span_late,1), 'tau(1)');
plot_pred_only(tau_pred(span_late,2), 'tau(2)');

% movie of the predicted motion
if plot_movie == 1
    if nargin < 5
        error('func_plot_figures:missingProperties', ...
            'plot_movie == 1 requires the arm properties vector as the 5th argument.');
    end
    func_plot_movie(start_time, 500, end_time, q_pred(:, 1), q_pred(:, 2), properties, 'dotted')
end

end

function plot_desired_vs_pred(desired, pred, y_label)
% Open a new figure comparing a desired (red) and a predicted (blue) signal.
figure()
hold on
plot(desired, 'r')
plot(pred, 'b')
xlabel('time step')
ylabel(y_label)
legend('desired', 'pred')
end

function plot_pred_only(pred, y_label)
% Open a new figure showing only a predicted signal (blue).
figure()
hold on
plot(pred, 'b')
xlabel('time step')
ylabel(y_label)
legend('pred')
end
--------------------------------------------------------------------------------
/tools/func_plot_movie.m:
--------------------------------------------------------------------------------
function [] = func_plot_movie(start_step, movie_step, time_all, q1, q2, properties, line_prop, dt, filename)
% FUNC_PLOT_MOVIE  Animate the two-link arm and save the animation as a GIF.
%
% Frames are drawn at steps start_step:movie_step:time_all. Each frame shows
% both links plus the end-effector path over (at most) the last
% 1000*movie_step steps.
%
% Inputs
%   start_step, movie_step, time_all : frame range (time-step indices).
%   q1, q2     : joint-angle time series.
%   properties : arm parameter vector; only properties(3) = l1 and
%                properties(4) = l2 are used.
%   line_prop  : 'solid' or 'dotted' style for the end-effector path.
%   dt         : (optional) time step used in the frame title, default 0.01.
%   filename   : (optional) output GIF path, default
%                './results/double_arm.gif'. Its folder is created if it
%                does not exist.

if nargin < 8 || isempty(dt)
    dt = 0.01;
end
if nargin < 9 || isempty(filename)
    filename = './results/double_arm.gif';
end

l1=properties(3);
l2=properties(4);

% end-effector and elbow positions
value_x_test=l1*cos(q1)+l2*cos(q1+q2);
value_y_test=l1*sin(q1)+l2*sin(q1+q2);
value_x1=l1*cos(q1);
value_y1=l1*sin(q1);

switch line_prop
    case 'solid'
        trace_style = 'b';
    case 'dotted'
        trace_style = 'b--';
    otherwise
        % Previously this only printed a message and then failed on an
        % undefined variable inside the loop.
        error('func_plot_movie:lineProperty', ...
            'line_prop must be ''solid'' or ''dotted''.');
end

% make sure the output folder exists, otherwise imwrite fails
out_dir = fileparts(filename);
if ~isempty(out_dir) && ~exist(out_dir, 'dir')
    mkdir(out_dir);
end

% length (in steps) of the end-effector trail drawn in each frame
trace=1000;
trail = trace * movie_step;

figure()
for i=start_step:movie_step:time_all
    clf;
    hold on

    if i > trail
        first = i - trail;
    else
        first = 1;
    end
    plot(value_x_test(first:i), value_y_test(first:i), trace_style, 'LineWidth', 1);

    line([0,value_x1(i)], [0, value_y1(i)], 'LineWidth', 3);
    line([value_x1(i),value_x_test(i)], [value_y1(i),value_y_test(i)], 'LineWidth', 3)
    line([0, 0], [-1, 1], 'Color', 'black', 'LineStyle', '--')
    line([-1, 1], [0, 0], 'Color', 'black', 'LineStyle', '--')
    xlabel('x')
    ylabel('y')
    xlim([-1,1])
    ylim([-1,1])
    title(['time = ' ,num2str(i*dt),'s'])
    drawnow

    % the first frame creates the GIF, later frames are appended
    frame=getframe(gcf);
    im=frame2im(frame);
    [imind,cm]=rgb2ind(im,256);
    if i==start_step
        imwrite(imind,cm,filename,'gif', 'Loopcount',inf);
    else
        imwrite(imind,cm,filename,'gif','WriteMode','append');
    end
end

end
65 |
66 |
--------------------------------------------------------------------------------
/tools/func_rmse.m:
--------------------------------------------------------------------------------
function [rmse] = func_rmse(a, b, time_start, time_end)
% FUNC_RMSE  Root-mean-square of the Euclidean error between two trajectories.
%
%   rmse = func_rmse(a, b)
%   rmse = func_rmse(a, b, time_start, time_end)
%
% A and B each hold one trajectory. The longer dimension of A is taken as
% the time axis, and each array is transposed if needed so that time runs
% along its columns. For every step in time_start:time_end, the squared
% error is summed over the components (rows). The result is then averaged
% over time and square-rooted.
%
% time_start and time_end are optional and default to the first and last
% time step of A.
%
% Note: for a square A the columns are treated as time steps.

[n_rows, n_cols] = size(a);
len = max(n_rows, n_cols);

if size(a, 2) ~= len
    a = a';
end
if size(b, 2) ~= len
    b = b';
end

if nargin < 3 || isempty(time_start)
    time_start = 1;
end
if nargin < 4 || isempty(time_end)
    time_end = len;
end

err = a(:, time_start:time_end) - b(:, time_start:time_end);
rmse = sqrt(mean(sum(err.^2, 1)));

end
--------------------------------------------------------------------------------
/val_and_update.m:
--------------------------------------------------------------------------------
% VAL_AND_UPDATE  Script: run one closed-loop validation with
% func_reservoir_validate, then store and (optionally) plot the result.
%
% Reads from the workspace:
%   save_rend, res_infor, traj_type, bridge_type, time_infor, input_infor,
%   properties, dim_in, dim_out, Wout, dt, plot_movie, failure, blur, idx;
%   start_info and r_end when save_rend == 1;
%   optionally traj_frequency and plot_val_and_update.
%
% Writes to the workspace:
%   the prediction and reference arrays;
%   start_info (state at step val_length-3, so a following run can
%   continue from it);
%   r_end;
%   save_all_traj.control_<idx> and save_all_traj.output_<idx>.

% validate
if save_rend == 0
    % fresh start: zero initial arm state and zero reservoir state
    start_info.q=0;
    start_info.qdt=0;
    start_info.q2dt=0;
    start_info.tau=0;

    r_end = zeros(size(res_infor.res_net, 1), 1);
end

% pick a default trajectory frequency only if the caller has not set one
if exist('traj_frequency','var') == 0
    if strcmp(traj_type, 'lorenz') == 1
        traj_frequency = 100;
    elseif strcmp(traj_type, 'circle') == 1   % was misspelled 'cirlce', so circle fell through to 75
        traj_frequency = 150;
    else
        traj_frequency = 75;
    end
end

rng('shuffle')

[control_infor, output_infor, time_infor, r_end] = func_reservoir_validate(traj_type,...
    bridge_type, time_infor, input_infor, res_infor, start_info, properties, dim_in, dim_out, ...
    Wout, r_end, dt, plot_movie, save_rend,failure,blur,traj_frequency);

% update
data_pred=output_infor.data_pred;
q_pred=output_infor.q_pred;
qdt_pred=output_infor.qdt_pred;
q2dt_pred=output_infor.q2dt_pred;
tau_pred=output_infor.tau_pred;

q_control=control_infor.q_control;
qdt_control=control_infor.qdt_control;
q2dt_control=control_infor.q2dt_control;
tau_control=control_infor.tau_control;
data_control=control_infor.data_control;

val_length=time_infor.val_length;

% State at the last step written by the validation loop (t_i runs to
% val_length-3); a following run with save_rend == 1 starts from here.
start_info.q=q_pred(val_length-3,:);
start_info.qdt=qdt_pred(val_length-3, :);
start_info.q2dt=q2dt_pred(val_length-3,:);
start_info.tau=tau_pred(val_length-3,:);

save_all_traj.(['control_', num2str(idx)]) = control_infor;
save_all_traj.(['output_', num2str(idx)]) = output_infor;

% plot trajectory

start_time=1;
end_time=val_length-100;

if exist('plot_val_and_update','var') == 0
    plot_val_and_update = 0;
end

if plot_val_and_update==1
    figure();
    hold on
    plot(data_control(start_time:end_time, 1), data_control(start_time:end_time, 2),'r');
    plot(data_pred(start_time:end_time, 1), data_pred(start_time:end_time, 2),'b--');
    xlabel('x')
    ylabel('y')
    line([0, 0], [-1, 1], 'Color', 'black', 'LineStyle', '--')
    line([-1, 1], [0, 0], 'Color', 'black', 'LineStyle', '--')
    xlim([-1, 1])
    ylim([-1, 1])
    legend('desired trajectory', 'pred trajectory')
end
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
--------------------------------------------------------------------------------