├── my_ADRC
│   ├── td.m
│   ├── adrc.m
│   ├── eso.m
│   ├── test_adrc.m
│   ├── test_pid.slx
│   ├── slprj
│   │   ├── sl_proj.tmw
│   │   └── grt
│   │       └── untitled
│   │           └── tmwinternal
│   │               └── minfo.mat
│   ├── untitled_grt_rtw
│   │   └── build_exception.mat
│   ├── leso3.m
│   ├── nlsef3.m
│   ├── td3.m
│   └── eso3.m
├── my_nnpid
│   ├── rnn.m
│   ├── bp_nn.m
│   ├── pid_nn.m
│   ├── test_nn.m
│   ├── my_nn_pid.m
│   └── lstm.m
├── images
│   ├── TD_i_d.PNG
│   ├── TD_i_t.PNG
│   ├── TD_i_d_e.PNG
│   ├── TD_i_t_e.PNG
│   ├── pid_test.PNG
│   ├── adrc_test.PNG
│   ├── adrc_test_s_e.PNG
│   ├── pid_test_s_e.PNG
│   └── transfer_func.PNG
├── README.md
├── doc_2.md
├── .gitignore
├── LICENSE
└── doc_1.md
/my_ADRC/td.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_ADRC/td.m
--------------------------------------------------------------------------------
/my_ADRC/adrc.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_ADRC/adrc.m
--------------------------------------------------------------------------------
/my_ADRC/eso.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_ADRC/eso.m
--------------------------------------------------------------------------------
/my_nnpid/rnn.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_nnpid/rnn.m
--------------------------------------------------------------------------------
/images/TD_i_d.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/images/TD_i_d.PNG
--------------------------------------------------------------------------------
/images/TD_i_t.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/images/TD_i_t.PNG
--------------------------------------------------------------------------------
/my_nnpid/bp_nn.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_nnpid/bp_nn.m
--------------------------------------------------------------------------------
/my_nnpid/pid_nn.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_nnpid/pid_nn.m
--------------------------------------------------------------------------------
/images/TD_i_d_e.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/images/TD_i_d_e.PNG
--------------------------------------------------------------------------------
/images/TD_i_t_e.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/images/TD_i_t_e.PNG
--------------------------------------------------------------------------------
/images/pid_test.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/images/pid_test.PNG
--------------------------------------------------------------------------------
/my_ADRC/test_adrc.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_ADRC/test_adrc.m
--------------------------------------------------------------------------------
/my_nnpid/test_nn.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_nnpid/test_nn.m
--------------------------------------------------------------------------------
/images/adrc_test.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/images/adrc_test.PNG
--------------------------------------------------------------------------------
/my_ADRC/test_pid.slx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_ADRC/test_pid.slx
--------------------------------------------------------------------------------
/my_nnpid/my_nn_pid.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_nnpid/my_nn_pid.m
--------------------------------------------------------------------------------
/images/adrc_test_s_e.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/images/adrc_test_s_e.PNG
--------------------------------------------------------------------------------
/images/pid_test_s_e.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/images/pid_test_s_e.PNG
--------------------------------------------------------------------------------
/images/transfer_func.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/images/transfer_func.PNG
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ADRC-matlab
2 | ADRC programs written in MATLAB,
3 | plus a few accompanying documents in the style of reading notes,
4 | for example:
5 |
6 | ### [Understanding the ADRC controller, for beginners](./doc_1.md)
--------------------------------------------------------------------------------
/my_ADRC/slprj/sl_proj.tmw:
--------------------------------------------------------------------------------
1 | Simulink Coder project marker file. Please don't change it.
2 | slprjVersion: 8.7_029
--------------------------------------------------------------------------------
/my_ADRC/untitled_grt_rtw/build_exception.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_ADRC/untitled_grt_rtw/build_exception.mat
--------------------------------------------------------------------------------
/my_ADRC/slprj/grt/untitled/tmwinternal/minfo.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_ADRC/slprj/grt/untitled/tmwinternal/minfo.mat
--------------------------------------------------------------------------------
/doc_2.md:
--------------------------------------------------------------------------------
1 | Understanding neural-network PID, for beginners
2 | ==========
3 | This article contains only vintage formulas; I have no idea where the fresh ones are, so please tell me if you do :stuck_out_tongue_winking_eye:.
4 |
5 | ## Introduction
6 | The PID controller already works well enough,
7 | but how do we bolt on something fun ~~and mystical~~ to make it look more impressive,
8 | in the hope of padding out a few more papers?
9 | That is when we bring out the big gun: neural networks.
10 |
11 | In reality, though, PID neural networks are something the experts were done playing with almost 20 years ago; I wrote this just to see where the actual performance and limitations lie.
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Prerequisites
2 | *.d
3 |
4 | # Compiled Object files
5 | *.slo
6 | *.lo
7 | *.o
8 | *.obj
9 |
10 | # Precompiled Headers
11 | *.gch
12 | *.pch
13 |
14 | # Compiled Dynamic libraries
15 | *.so
16 | *.dylib
17 | *.dll
18 |
19 | # Fortran module files
20 | *.mod
21 | *.smod
22 |
23 | # Compiled Static libraries
24 | *.lai
25 | *.la
26 | *.a
27 | *.lib
28 |
29 | # Executables
30 | *.exe
31 | *.out
32 | *.app
33 |
--------------------------------------------------------------------------------
/my_ADRC/leso3.m:
--------------------------------------------------------------------------------
1 | function [z1_new,z2_new,z3_new] = leso3(z1_last,z2_last,z3_last, ...
2 | w, ...
3 | input,b0,output,h)
4 | beta_01 = 3*w;   % observer gains from the bandwidth parameterisation:
5 | beta_02 = 3*w^2; % (s+w)^3 = s^3 + 3*w*s^2 + 3*w^2*s + w^3
6 | beta_03 = w^3;
7 | e = z1_last-output;                                % output estimation error
8 | z1_new = z1_last+h*(z2_last-beta_01*e);            % z1 tracks the output
9 | z2_new = z2_last + h*(z3_last-beta_02*e+b0*input); % z2 tracks its derivative
10 | z3_new = z3_last + h*(-beta_03*e);                 % z3 tracks the total disturbance
11 | end
--------------------------------------------------------------------------------
/my_ADRC/nlsef3.m:
--------------------------------------------------------------------------------
1 | function u0 = nlsef3(e1,e2,c,r,h1) % nonlinear state-error feedback built on fhan; e1/e2 are the tracking errors, c weights the derivative error
2 | u0 = -fhan(e1,c*e2,r,h1);
3 | end
4 |
5 | function fh = fhan(x1_last,x2_last,r,h0) % Han's time-optimal (fastest) control synthesis function
6 | d = r*h0;
7 | d0 = h0*d;
8 | x1_new = x1_last+h0*x2_last;
9 | a0 = sqrt(d^2+8*r*abs(x1_new));
10 | if abs(x1_new)>d0
11 | a=x2_last+(a0-d)/2*sign(x1_new);
12 | else
13 | a = x2_last+x1_new/h0;
14 | end
15 | fh = -r*sat(a,d);
16 | end
17 |
18 | function M=sat(x,delta)
19 | if abs(x)<=delta
20 | M=x/delta;
21 | else
22 | M=sign(x);
23 | end
24 | end
--------------------------------------------------------------------------------
/my_ADRC/td3.m:
--------------------------------------------------------------------------------
1 | function [x1_new,x2_new] = td3(x1_last,x2_last,input,r,h,h0)
2 | x1_new = x1_last+h*x2_last;                          % x1 tracks the input signal
3 | x2_new = x2_last+h*fhan(x1_last-input,x2_last,r,h0); % x2 is the filtered derivative
4 | end
5 |
6 | function fh = fhan(x1_last,x2_last,r,h0) % Han's time-optimal (fastest) control synthesis function
7 | d = r*h0;
8 | d0 = h0*d;
9 | x1_new = x1_last+h0*x2_last;
10 | a0 = sqrt(d^2+8*r*abs(x1_new));
11 | if abs(x1_new)>d0
12 | a=x2_last+(a0-d)/2*sign(x1_new);
13 | else
14 | a = x2_last+x1_new/h0;
15 | end
16 | fh = -r*sat(a,d);
17 | end
18 |
19 | function M=sat(x,delta)
20 | if abs(x)<=delta
21 | M=x/delta;
22 | else
23 | M=sign(x);
24 | end
25 | end
--------------------------------------------------------------------------------
/my_ADRC/eso3.m:
--------------------------------------------------------------------------------
1 | function [z1_new,z2_new,z3_new] = eso3(z1_last,z2_last,z3_last, ...
2 | beta_01,beta_02,beta_03, ...
3 | input,b0,output,h,threshold)
4 | e = z1_last-output;                                  % output estimation error
5 | fe = fal(e,0.5,threshold);                           % nonlinear gain on the error
6 | fe1 = fal(e,0.25,threshold);
7 | z1_new = z1_last+h*(z2_last-beta_01*e);              % z1 tracks the output
8 | z2_new = z2_last + h*(z3_last-beta_02*fe+b0*input);  % z2 tracks its derivative
9 | z3_new = z3_last + h*(-beta_03*fe1);                 % z3 tracks the total disturbance
10 | end
11 |
12 | function fe = fal(error,pow,threshold) % fal: fractional-power gain outside, linear near zero
13 | if abs(error) > threshold
14 | fe = abs(error)^pow*sign(error);
15 | else
16 | fe = error/threshold^(1-pow); % exponent 1-pow keeps fal continuous at |error| = threshold
17 | end
18 | end
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 lvniqi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/doc_1.md:
--------------------------------------------------------------------------------
1 | Understanding the ADRC controller, for beginners
2 | ==========
3 | This article contains only meaningless formulas; all the meaningful ones are in the book, so beginners can read on without worry :stuck_out_tongue_winking_eye:
4 |
5 | ## Introduction
6 | ### Problems with conventional PID
7 | Traditionally, for systems with an uncertain model we (or is it just me?) like to use PID control.
8 | Even when the model structure is clear, we often cannot obtain some of its parameters, so we keep falling back on a PID controller.
9 |
10 | The PID control we usually see looks like this:
11 |
12 |
13 | *(block diagram of a conventional PID control loop)*
14 |
15 |
16 | If we write such a PID controller as a transfer function, it is
17 |
18 | $G_{PID}(s) = P + I\frac{1}{s} + Ds$
19 |
20 | and after the Z-transform it becomes
21 |
22 | $G_{PID}(z) = P + I \cdot T_s \frac{1}{z-1} + D \frac{z-1}{z \cdot T_s}$
23 |
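Read as a difference equation, the Z-domain expression above is just the usual positional PID update. A minimal MATLAB sketch (the gains, sampling period and error sequence are made-up example values, not anything from this repository):

```matlab
% Positional discrete PID: u(k) = P*e(k) + I*Ts*sum(e(1..k)) + D*(e(k)-e(k-1))/Ts
P = 2; I = 0.5; D = 0.1;            % example gains (made up)
Ts = 0.01;                          % sampling period [s]
e  = [ones(1,50) zeros(1,50)];      % made-up error sequence: a step that later vanishes
u  = zeros(size(e));
e_int = 0; e_prev = 0;
for k = 1:numel(e)
    e_int = e_int + Ts*e(k);                        % the I*Ts/(z-1) term
    u(k)  = P*e(k) + I*e_int + D*(e(k)-e_prev)/Ts;  % the D*(z-1)/(z*Ts) term
    e_prev = e(k);
end
stairs(u); xlabel('sample k'); ylabel('u(k)');
```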
24 | Hmm, that looks pretty good. So where is the problem?
25 |
26 | #### The derivative is unusable
27 | In practice the PID loop above has one big problem: the D component is usually unusable. The plant output c(t) is noisy, and naively differentiating it injects that noise into the loop and can destabilize the system.
28 |
29 | #### Overshoot and the transient process
30 |
31 |
32 | The ultimate goal of any controller is to make the output c(t) track the input r(t) as quickly as possible. In most cases we also want the tracking to be free of overshoot and oscillation. Can that wish actually come true?
33 |
34 | Consider a second-order system:
35 |
36 | $G(s) = \frac{a_1}{s^2 + a_2 s + a_1}$
37 |
38 | The system is critically damped if and only if $a_2^2 = 4a_1$; only then is it free of overshoot and oscillation while still tracking reasonably fast.
39 | The controller we build is essentially trying to reshape the original transfer function so that the closed loop behaves like this critically damped case. For most second-order systems, PID control can achieve that.
40 |
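For completeness, the critical-damping condition quoted above follows from matching the denominator to the standard second-order form:

$$
s^2 + a_2 s + a_1 = s^2 + 2\zeta\omega_n s + \omega_n^2
\;\Rightarrow\;
\omega_n = \sqrt{a_1},\quad
\zeta = \frac{a_2}{2\sqrt{a_1}},\qquad
\zeta = 1 \;\Leftrightarrow\; a_2^2 = 4a_1 .
$$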
41 | But what if a1 and a2 drift a little while the system is running?
42 | Er, then plain PID runs into trouble. (As for why: my math is not great, so whoever is interested can do the derivation :joy:)
43 |
44 | Also, the reference r(t) we feed in is usually discontinuous (a step, for instance).
45 | During the settling time ts we are really asking the impossible of the controller, because the output simply cannot reach the commanded value yet.
46 | Without doing the theory (skipping the derivation again, I deserve to be hit for this :dizzy_face:), intuition alone says that a system with inertia is very likely to overshoot while e(t) is large.
47 | If we change our thinking and design a reasonable transient process for the reference, the controller's burden is eased: the initial u(t) will not be excessive, and part of the overshoot disappears.
48 |
49 | #### Problems with integral feedback
50 | The I term in PID is there to reject constant disturbances and reduce steady-state error. But adding it can degrade the dynamic response and brings integral windup. Is it really fine to leave this term hanging around like that?
51 |
52 | #### The PID controller itself
53 | PID hands the plant a linear weighted sum of the proportional, integral and derivative terms (a bit like a single neuron with no bias? :point_left:). Is that appropriate? Would a nonlinear combination be better?
54 |
55 | ## The ADRC controller
56 | The structure of an ADRC controller is shown in the figure below:
57 |
58 | *(ADRC block diagram, drawn in the "scruffy" style)*
59 |
60 | Looks messy, doesn't it? Ha, I mostly just wanted to plug this
61 | [diagramming tool](https://yuml.me/diagram/scruffy/class/samples).
62 |
63 | Drawn normally, it looks like this:
64 |
65 |
66 | *(ADRC block diagram: transient process generator (TD), nonlinear combination, and extended state observer around the plant)*
67 |
68 |
69 | Quite different from those PID loops, isn't it?
70 | The way I understand it, the original PID corresponds to the **nonlinear combination** here,
71 | while the **transient process generator** and the **extended state observer** are the newly added modules.
72 | Setting the **nonlinear combination** aside, let's look at the other two first.
73 |
74 | #### Building the transient process (tracking differentiator, TD)
75 | In ADRC the transient process is produced by a tracking differentiator.
76 | The tracking differentiator was not originally meant for that; it was designed to extract a derivative from a noise-polluted signal.
77 | Integrating that derivative then yields a tracking signal.
78 | Because the derivative is obtained with lag and filtering, the bandwidth of the resulting tracking signal is limited as well.
79 | The tracking signal obtained this way can thus be regarded as a **transient process** for the original signal.
80 |
81 | Traditionally, a derivative can be obtained from the transfer function
82 |
83 | $D(s) = \frac{1}{T}\left(1 - \frac{1}{Ts+1}\right)$
84 |
85 | where T is the sampling interval; in other words, derivative = (now - past) / sampling interval. Clearly, if the input signal is noisy, such a derivative amplifies the noise until the derivative signal is drowned out.
86 |
87 | In the tracking differentiator, the derivative transfer function is changed to
88 |
89 | $D(s) = \frac{1}{\tau_2 - \tau_1}\left(\frac{1}{\tau_1 s + 1} - \frac{1}{\tau_2 s + 1}\right)$
90 |
91 | i.e. the difference of two first-order lags, the point being that the lags suppress the noise.
92 | Rearranging the expression and letting τ2 approach τ1 gives
93 |
94 | $D(s) = \frac{s}{\tau^2 s^2 + 2\tau s + 1}$
95 |
96 | and letting r = 1/τ we get
97 |
98 | $D(s) = \frac{r^2 s}{s^2 + 2rs + r^2}$
99 |
100 | Now look at what this is.
101 | It is exactly the derivative of a critically damped second-order system.
102 |
103 | The state-variable equations are
104 |
105 | $\dot{x}_1 = x_2,\qquad \dot{x}_2 = -r^2(x_1 - v) - 2r x_2$
106 |
107 | and written as a discrete recursion this becomes a tracking differentiator you can actually use:
108 |
109 | $x_1(k+1) = x_1(k) + h\,x_2(k)$
110 | $x_2(k+1) = x_2(k) + h\left[-r^2\,(x_1(k) - v(k)) - 2r\,x_2(k)\right]$
111 | Here x1 is the tracking output, x2 the derivative output, h the sampling period, and v the input.
112 | The derivative output looks like this:
113 |
114 | ![](images/TD_i_d.PNG)
115 |
116 | The tracking output looks like this:
117 |
118 | ![](images/TD_i_t.PNG)
119 |
120 | Looks good. So what price do we pay?
121 | Sharp-eyed readers will have noticed it already: the tracking lags. In fact, both the derivative and the tracking signal end up delayed.
122 | Let's make r a bit smaller and see.
123 |
124 | ![](images/TD_i_d_e.PNG)
125 |
126 | ![](images/TD_i_t_e.PNG)
127 |
128 | See? The derivative now lags so much that it has almost no lead left, while the tracking lags by 90° and its amplitude has shrunk. The system's bandwidth has been limited.
129 | Seen from another angle, this limitation gives the system a **transient process**.
130 | Since the system's inertia means it can never track the input exactly anyway, a properly designed transient process reduces overshoot.
131 | Note also that what I described here is the linear tracking differentiator;
132 | the time-optimal tracking differentiator is modified somewhat, but the idea is the same. Have a look at the book if you want the details.
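As a concrete illustration, here is a minimal sketch of the linear discrete TD derived above (not the repository's td3.m, which uses the nonlinear fhan function); the test signal, noise level and parameter values are made up for the example:

```matlab
% Linear discrete tracking differentiator: x1 tracks v, x2 approximates dv/dt.
h = 0.001;                                  % sampling period [s]
r = 100;                                    % speed factor (r = 1/tau); smaller r -> more lag
t = (0:h:1)';
v = sin(2*pi*5*t) + 0.05*randn(size(t));    % noisy input signal

x1 = zeros(size(t)); x2 = zeros(size(t));   % tracking and derivative outputs
for k = 1:numel(t)-1
    x1(k+1) = x1(k) + h*x2(k);
    x2(k+1) = x2(k) + h*(-r^2*(x1(k) - v(k)) - 2*r*x2(k));
end

plot(t, v, t, x1, t, x2/(2*pi*5));          % derivative rescaled for comparison
legend('noisy input', 'TD tracking x1', 'TD derivative x2 (scaled)');
```

Increasing r widens the bandwidth (less lag, more noise let through); decreasing it does the opposite, which is exactly the trade-off shown in the figures above.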
133 |
134 | #### Extended state observer (ESO)
135 |
136 | ##### State observer (SO)
137 |
138 | I find the book rather vague on this part (then again, the whole book is organized around experiments), so I'll allow myself to be even vaguer.
139 | A state observer estimates the internal state of a system from its input and output.
140 | Consider a second-order system
141 |
142 | $\dot{x}_1 = x_2,\qquad \dot{x}_2 = -a_1 x_1 - a_2 x_2 + u,\qquad y = x_1$
143 |
144 | Since a1 and a2 are not known exactly, all we can do is estimate the complete state vector z from the output y and the input u.
145 | A full-order state observer looks like this, corrected through the gains l1 and l2:
146 |
147 | $\dot{z}_1 = z_2 - l_1(z_1 - y)$
148 | $\dot{z}_2 = -a_1 z_1 - a_2 z_2 + u - l_2(z_1 - y)$
149 |
150 | Let $e_1 = z_1 - x_1$ and $e_2 = z_2 - x_2$; then
151 |
152 | $\dot{e}_1 = e_2 - l_1 e_1$
153 |
154 | $\dot{e}_2 = -l_2 e_1 - (a_1 e_1 + a_2 e_2)$
155 |
156 |
157 | So as long as l1 and l2 are chosen properly, e -> 0 and the observer output z converges to the true state x.
158 | Of course, if it makes you happy you can also pass the e's through some function to squeeze out more performance.
159 |
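In matrix form the error dynamics above are ė = (A − LC)e, so choosing l1 and l2 is just an eigenvalue-placement problem. A small sketch (a1, a2 and the desired observer poles are made-up values; `place` needs the Control System Toolbox):

```matlab
% Observer gain selection by pole placement for de/dt = (A - L*C)*e
a1 = 4; a2 = 1;                      % made-up plant coefficients
A  = [0 1; -a1 -a2];                 % state matrix of the second-order system
C  = [1 0];                          % only x1 is measured
L  = place(A', C', [-20 -25])';      % L = [l1; l2], via the usual duality trick
disp(eig(A - L*C))                   % confirms the poles ended up at -20 and -25
```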
160 |
161 | ##### Extended state observer (ESO)
162 |
163 | So what exactly has been **extended** in the **extended** state observer?
164 | For a second-order system, take the f(x1, x2) in
165 |
166 | $\dot{x}_2 = f(x_1, x_2) + u$
167 |
168 | and extend it into a state of its own, so that
169 |
170 | $x_3 = f(x_1, x_2)$
171 |
172 | $\dot{x}_3 = w(t)$
173 |
174 | Right: this extra x3 is the so-called extended state, and an observer that includes x3 and outputs z = {z1, z2, z3} is an extended state observer.
175 |
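To make the "extended state" concrete, here is a minimal sketch that runs the repository's linear ESO (leso3.m, listed later in this dump) against a toy plant whose dynamics the observer does not know; the plant, its parameters and the observer bandwidth are all made up for the illustration. The point is that z3 should roughly follow the total disturbance f + (b − b0)·u.

```matlab
% Toy plant: x1' = x2, x2' = f(t,x) + b*u, y = x1, with f unknown to the observer.
h  = 0.001; N = 5000;
b  = 1.2; b0 = 1.0;                  % true input gain vs. gain assumed by the ESO
w  = 60;                             % observer bandwidth passed to leso3.m
x  = [0; 0]; z = [0; 0; 0];
u  = 1;                              % constant test input
z3_log = zeros(1, N); f_log = zeros(1, N);
for k = 1:N
    t = k*h;
    f = -2*x(1) - 3*x(2) + sin(t);               % "unknown" dynamics + disturbance
    x = x + h*[x(2); f + b*u];                   % plant update (forward Euler)
    [z(1), z(2), z(3)] = leso3(z(1), z(2), z(3), w, u, b0, x(1), h);
    z3_log(k) = z(3);
    f_log(k)  = f + (b - b0)*u;                  % total disturbance the ESO should see
end
plot((1:N)*h, f_log, (1:N)*h, z3_log);
legend('total disturbance', 'ESO estimate z3');
```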
176 | ### Odds and ends
177 |
178 | I can't be bothered to write up the remaining details: the TD, the control combination and the ESO each come in linear and nonlinear flavours (~~think activation functions :joy: big stretch~~). A sketch of how the pieces in this repository fit together is given below.
179 |
180 |
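For reference, here is one possible way of wiring the repository's td3.m, eso3.m and nlsef3.m into a complete ADRC loop (test_adrc.m is the authoritative example; the toy plant and every tuning value below are made up for illustration):

```matlab
h = 0.001; N = 5000;
r_td = 100; h0 = 0.005;                   % TD speed factor and filter step
b0 = 1; w = 80;                           % assumed input gain and ESO "bandwidth"
beta = [3*w, 3*w^2, w^3]; thr = 0.01;     % ESO gains and fal threshold
c = 1.5; r_nl = 50; h1 = 0.005;           % nonlinear feedback parameters

v = [0 0]; z = [0 0 0]; x = [0 0];        % TD, ESO and plant states
ref = 1; u = 0; y_log = zeros(1, N);      % step reference
for k = 1:N
    % 1) transient process: v1 tracks ref, v2 is its derivative
    [v(1), v(2)] = td3(v(1), v(2), ref, r_td, h, h0);
    % 2) extended state observer driven by u and the measured output
    [z(1), z(2), z(3)] = eso3(z(1), z(2), z(3), beta(1), beta(2), beta(3), ...
                              u, b0, x(1), h, thr);
    % 3) nonlinear state-error feedback plus disturbance compensation
    u0 = nlsef3(v(1) - z(1), v(2) - z(2), c, r_nl, h1);
    u  = u0 - z(3)/b0;
    % toy plant standing in for the real object: x1' = x2, x2' = -x1 - 2*x2 + u
    x = x + h*[x(2), -x(1) - 2*x(2) + u];
    y_log(k) = x(1);
end
plot((1:N)*h, y_log); xlabel('t [s]'); ylabel('y');
```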
181 | ## Try it out
182 | Take the following noisy second-order system:
183 |
184 | ![](images/transfer_func.PNG)
185 |
186 | With PID plus a quick auto-tune, the output looks like this:
187 |
188 | ![](images/pid_test.PNG)
189 |
190 |
191 | With ADRC and some quick manual tuning, the output looks like this:
192 |
193 |
194 | ![](images/adrc_test.PNG)
195 |
196 |
197 | ADRC looks pretty impressive, doesn't it!
198 | ~~There is a hidden worry though: compare the noise in the steady-state output.~~
199 | ~~The steady-state noise differs by roughly a factor of 10. My guesses at the reason:~~
200 | * ~~no integral term?~~
201 | * ~~all the nonlinearities?~~
202 |
203 | **My MATLAB code had a bug: the noise was fed back into the transfer function and recomputed. After fixing it the noise no longer differs by a factor of ten, sorry. The corrected plots are below.**
204 |
205 | PID:
206 |
207 | ![](images/pid_test_s_e.PNG)
208 |
209 | ADRC:
210 |
211 | ![](images/adrc_test_s_e.PNG)
212 |
213 |
214 | That's all.
215 |
216 | ## ~~References~~
217 | [1] Han Jingqing. Active Disturbance Rejection Control Technique: The Technique for Estimating and Compensating the Uncertainties [M]. National Defense Industry Press, 2008.
--------------------------------------------------------------------------------
/my_nnpid/lstm.m:
--------------------------------------------------------------------------------
1 | function lstm()
2 | % implementation of LSTM
3 | clc
4 | clear
5 | close all
6 |
7 | %% training dataset generation
8 | binary_dim = 8;
9 |
10 | largest_number = 2^binary_dim - 1;
11 | binary = cell(largest_number + 1, 1);  % one cell per value 0..largest_number
12 |
13 | for i = 1:largest_number + 1
14 | binary{i} = dec2bin(i-1, binary_dim);
15 | int2binary{i} = binary{i};
16 | end
17 |
18 | %% input variables
19 | alpha = 0.05;
20 | input_dim = 2;
21 | hidden_dim = 32;
22 | output_dim = 1;
23 |
24 | %% initialize neural network weights
25 | % in_gate = sigmoid(X(t) * U_i + H(t-1) * W_i) ------- (1)
26 | U_i = 2 * rand(input_dim, hidden_dim) - 1;
27 | W_i = 2 * rand(hidden_dim, hidden_dim) - 1;
28 | U_i_update = zeros(size(U_i));
29 | W_i_update = zeros(size(W_i));
30 |
31 | % forget_gate = sigmoid(X(t) * U_f + H(t-1) * W_f) ------- (2)
32 | U_f = 2 * rand(input_dim, hidden_dim) - 1;
33 | W_f = 2 * rand(hidden_dim, hidden_dim) - 1;
34 | U_f_update = zeros(size(U_f));
35 | W_f_update = zeros(size(W_f));
36 |
37 | % out_gate = sigmoid(X(t) * U_o + H(t-1) * W_o) ------- (3)
38 | U_o = 2 * rand(input_dim, hidden_dim) - 1;
39 | W_o = 2 * rand(hidden_dim, hidden_dim) - 1;
40 | U_o_update = zeros(size(U_o));
41 | W_o_update = zeros(size(W_o));
42 |
43 | % g_gate = tanh(X(t) * U_g + H(t-1) * W_g) ------- (4)
44 | U_g = 2 * rand(input_dim, hidden_dim) - 1;
45 | W_g = 2 * rand(hidden_dim, hidden_dim) - 1;
46 | U_g_update = zeros(size(U_g));
47 | W_g_update = zeros(size(W_g));
48 |
49 | out_para = 2 * rand(hidden_dim, output_dim) - 1;
50 | out_para_update = zeros(size(out_para));
51 | % C(t) = C(t-1) .* forget_gate + g_gate .* in_gate ------- (5)
52 | % S(t) = tanh(C(t)) .* out_gate ------- (6)
53 | % Out = sigmoid(S(t) * out_para) ------- (7)
54 | % Note: Equations (1)-(6) are cores of LSTM in forward, and equation (7) is
55 | % used to transfer the hidden layer to the predicted output, i.e., the output layer.
56 | % (Sometimes you can use softmax for equation (7))
57 |
58 | %% train
59 | iter = 999999; % training iterations
60 | for j = 1:iter
61 | % generate a simple addition problem (a + b = c)
62 | a_int = randi(round(largest_number/2)); % int version
63 | a = int2binary{a_int+1}; % binary encoding
64 |
65 | b_int = randi(floor(largest_number/2)); % int version
66 | b = int2binary{b_int+1}; % binary encoding
67 |
68 | % true answer
69 | c_int = a_int + b_int; % int version
70 | c = int2binary{c_int+1}; % binary encoding
71 |
72 | % where we'll store our best guess (binary encoded)
73 | d = zeros(size(c));
74 | if length(d)<8
75 | pause;
76 | end
77 |
78 | % total error
79 | overallError = 0;
80 |
81 | % difference in output layer, i.e., (target - out)
82 | output_deltas = [];
83 |
84 | % values of hidden layer, i.e., S(t)
85 | hidden_layer_values = [];
86 | cell_gate_values = [];
87 | % initialize S(0) as a zero-vector
88 | hidden_layer_values = [hidden_layer_values; zeros(1, hidden_dim)];
89 | cell_gate_values = [cell_gate_values; zeros(1, hidden_dim)];
90 |
91 | % initialize memory gate
92 | % hidden layer
93 | H = [];
94 | H = [H; zeros(1, hidden_dim)];
95 | % cell gate
96 | C = [];
97 | C = [C; zeros(1, hidden_dim)];
98 | % in gate
99 | I = [];
100 | % forget gate
101 | F = [];
102 | % out gate
103 | O = [];
104 | % g gate
105 | G = [];
106 |
107 | % start to process a sequence, i.e., a forward pass
108 | % Note: the output of a LSTM cell is the hidden_layer, and you need to
109 | % transfer it to predicted output
110 | for position = 0:binary_dim-1
111 | % X ------> input, size: 1 x input_dim
112 | X = [a(binary_dim - position)-'0' b(binary_dim - position)-'0'];
113 |
114 | % y ------> label, size: 1 x output_dim
115 | y = [c(binary_dim - position)-'0']';
116 |
117 | % use equations (1)-(7) in a forward pass. here we do not use bias
118 | in_gate = sigmoid(X * U_i + H(end, :) * W_i); % equation (1)
119 | forget_gate = sigmoid(X * U_f + H(end, :) * W_f); % equation (2)
120 | out_gate = sigmoid(X * U_o + H(end, :) * W_o); % equation (3)
121 | g_gate = tan_h(X * U_g + H(end, :) * W_g); % equation (4)
122 | C_t = C(end, :) .* forget_gate + g_gate .* in_gate; % equation (5)
123 | H_t = tan_h(C_t) .* out_gate; % equation (6)
124 |
125 | % store these memory gates
126 | I = [I; in_gate];
127 | F = [F; forget_gate];
128 | O = [O; out_gate];
129 | G = [G; g_gate];
130 | C = [C; C_t];
131 | H = [H; H_t];
132 |
133 | % compute predict output
134 | pred_out = sigmoid(H_t * out_para);
135 |
136 | % compute error in output layer
137 | output_error = y - pred_out;
138 |
139 | % compute difference in output layer using derivative
140 | % output_diff = output_error * sigmoid_output_to_derivative(pred_out);
141 | output_deltas = [output_deltas; output_error];
142 |
143 | % compute total error
144 | % note that if the size of pred_out or target is 1 x n or m x n,
145 | % you should use other approach to compute error. here the dimension
146 | % of pred_out is 1 x 1
147 | overallError = overallError + abs(output_error(1));
148 |
149 | % decode estimate so we can print it out
150 | d(binary_dim - position) = round(pred_out);
151 | end
152 |
153 | % from the last LSTM cell, you need an initial hidden layer difference
154 | future_H_diff = zeros(1, hidden_dim);
155 |
156 | % start back-propagation, i.e., a backward pass
157 | % the goal is to compute differences and use them to update weights
158 | % start from the last LSTM cell
159 | for position = 0:binary_dim-1
160 | X = [a(position+1)-'0' b(position+1)-'0'];
161 |
162 | % hidden layer
163 | H_t = H(end-position, :); % H(t)
164 | % previous hidden layer
165 | H_t_1 = H(end-position-1, :); % H(t-1)
166 | C_t = C(end-position, :); % C(t)
167 | C_t_1 = C(end-position-1, :); % C(t-1)
168 | O_t = O(end-position, :);
169 | F_t = F(end-position, :);
170 | G_t = G(end-position, :);
171 | I_t = I(end-position, :);
172 |
173 | % output layer difference
174 | output_diff = output_deltas(end-position, :);
175 |
176 | % hidden layer difference
177 | % note that here we consider one hidden layer is input to both
178 | % output layer and next LSTM cell. Thus its difference also comes
179 | % from two sources. In some other method, only one source is taken
180 | % into consideration.
181 | % use the equation: delta(l) = (delta(l+1) * W(l+1)) .* f'(z) to
182 | % compute difference in previous layers. look for more about the
183 | % proof at http://neuralnetworksanddeeplearning.com/chap2.html
184 | % H_t_diff = (future_H_diff * (W_i' + W_o' + W_f' + W_g') + output_diff * out_para') ...
185 | % .* sigmoid_output_to_derivative(H_t);
186 |
187 | % H_t_diff = output_diff * (out_para') .* sigmoid_output_to_derivative(H_t);
188 | H_t_diff = output_diff * (out_para') .* sigmoid_output_to_derivative(H_t);
189 |
190 | % out_para_diff = output_diff * (H_t) * sigmoid_output_to_derivative(out_para);
191 | out_para_diff = (H_t') * output_diff;
192 |
193 | % out_gate difference
194 | O_t_diff = H_t_diff .* tan_h(C_t) .* sigmoid_output_to_derivative(O_t);
195 |
196 | % C_t difference
197 | C_t_diff = H_t_diff .* O_t .* tan_h_output_to_derivative(C_t);
198 |
199 | % % C(t-1) difference
200 | % C_t_1_diff = C_t_diff .* F_t;
201 |
202 | % forget_gate difference
203 | F_t_diff = C_t_diff .* C_t_1 .* sigmoid_output_to_derivative(F_t);
204 |
205 | % in_gate difference
206 | I_t_diff = C_t_diff .* G_t .* sigmoid_output_to_derivative(I_t);
207 |
208 | % g_gate difference
209 | G_t_diff = C_t_diff .* I_t .* tan_h_output_to_derivative(G_t);
210 |
211 | % differences of U_i and W_i
212 | U_i_diff = X' * I_t_diff .* sigmoid_output_to_derivative(U_i);
213 | W_i_diff = (H_t_1)' * I_t_diff .* sigmoid_output_to_derivative(W_i);
214 |
215 | % differences of U_o and W_o
216 | U_o_diff = X' * O_t_diff .* sigmoid_output_to_derivative(U_o);
217 | W_o_diff = (H_t_1)' * O_t_diff .* sigmoid_output_to_derivative(W_o);
218 |
219 | % differences of U_f and W_f
220 | U_f_diff = X' * F_t_diff .* sigmoid_output_to_derivative(U_f);
221 | W_f_diff = (H_t_1)' * F_t_diff .* sigmoid_output_to_derivative(W_f);
222 |
223 | % differences of U_g and W_g
224 | U_g_diff = X' * G_t_diff .* tan_h_output_to_derivative(U_g);
225 | W_g_diff = (H_t_1)' * G_t_diff .* tan_h_output_to_derivative(W_g);
226 |
227 | % update
228 | U_i_update = U_i_update + U_i_diff;
229 | W_i_update = W_i_update + W_i_diff;
230 | U_o_update = U_o_update + U_o_diff;
231 | W_o_update = W_o_update + W_o_diff;
232 | U_f_update = U_f_update + U_f_diff;
233 | W_f_update = W_f_update + W_f_diff;
234 | U_g_update = U_g_update + U_g_diff;
235 | W_g_update = W_g_update + W_g_diff;
236 | out_para_update = out_para_update + out_para_diff;
237 | end
238 |
239 | U_i = U_i + U_i_update * alpha;
240 | W_i = W_i + W_i_update * alpha;
241 | U_o = U_o + U_o_update * alpha;
242 | W_o = W_o + W_o_update * alpha;
243 | U_f = U_f + U_f_update * alpha;
244 | W_f = W_f + W_f_update * alpha;
245 | U_g = U_g + U_g_update * alpha;
246 | W_g = W_g + W_g_update * alpha;
247 | out_para = out_para + out_para_update * alpha;
248 |
249 | U_i_update = U_i_update * 0;
250 | W_i_update = W_i_update * 0;
251 | U_o_update = U_o_update * 0;
252 | W_o_update = W_o_update * 0;
253 | U_f_update = U_f_update * 0;
254 | W_f_update = W_f_update * 0;
255 | U_g_update = U_g_update * 0;
256 | W_g_update = W_g_update * 0;
257 | out_para_update = out_para_update * 0;
258 |
259 | if(mod(j,1000) == 0)
260 | err = sprintf('Error:%s\n', num2str(overallError)); fprintf(err);
261 | d = bin2dec(num2str(d));
262 | pred = sprintf('Pred:%s\n',dec2bin(d,8)); fprintf(pred);
263 | Tru = sprintf('True:%s\n', num2str(c)); fprintf(Tru);
264 | % out = 0;
265 | % sep = sprintf('-------------\n'); fprintf(sep);
266 | end
267 | end
268 | end
269 | function output = tan_h(x)
270 | output =tanh(x);
271 | end
272 | function gradient = tan_h_output_to_derivative(output)
273 | gradient =1-output.*output;
274 | end
275 | function output = sigmoid(x)
276 | output =1./(1+exp(-x));
277 | end
278 | function gradient = sigmoid_output_to_derivative(output)
279 | gradient =output.*(1-output);
280 | end
--------------------------------------------------------------------------------