├── my_ADRC
│   ├── td.m
│   ├── adrc.m
│   ├── eso.m
│   ├── test_adrc.m
│   ├── test_pid.slx
│   ├── slprj
│   │   ├── sl_proj.tmw
│   │   └── grt
│   │       └── untitled
│   │           └── tmwinternal
│   │               └── minfo.mat
│   ├── untitled_grt_rtw
│   │   └── build_exception.mat
│   ├── leso3.m
│   ├── nlsef3.m
│   ├── td3.m
│   └── eso3.m
├── my_nnpid
│   ├── rnn.m
│   ├── bp_nn.m
│   ├── pid_nn.m
│   ├── test_nn.m
│   ├── my_nn_pid.m
│   └── lstm.m
├── images
│   ├── TD_i_d.PNG
│   ├── TD_i_t.PNG
│   ├── TD_i_d_e.PNG
│   ├── TD_i_t_e.PNG
│   ├── pid_test.PNG
│   ├── adrc_test.PNG
│   ├── adrc_test_s_e.PNG
│   ├── pid_test_s_e.PNG
│   └── transfer_func.PNG
├── README.md
├── doc_2.md
├── .gitignore
├── LICENSE
└── doc_1.md

/my_ADRC/td.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_ADRC/td.m
--------------------------------------------------------------------------------
/my_ADRC/adrc.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_ADRC/adrc.m
--------------------------------------------------------------------------------
/my_ADRC/eso.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_ADRC/eso.m
--------------------------------------------------------------------------------
/my_nnpid/rnn.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_nnpid/rnn.m
--------------------------------------------------------------------------------
/images/TD_i_d.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/images/TD_i_d.PNG
--------------------------------------------------------------------------------
/images/TD_i_t.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/images/TD_i_t.PNG
--------------------------------------------------------------------------------
/my_nnpid/bp_nn.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_nnpid/bp_nn.m
--------------------------------------------------------------------------------
/my_nnpid/pid_nn.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_nnpid/pid_nn.m
--------------------------------------------------------------------------------
/images/TD_i_d_e.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/images/TD_i_d_e.PNG
--------------------------------------------------------------------------------
/images/TD_i_t_e.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/images/TD_i_t_e.PNG
--------------------------------------------------------------------------------
/images/pid_test.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/images/pid_test.PNG
--------------------------------------------------------------------------------
/my_ADRC/test_adrc.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_ADRC/test_adrc.m
--------------------------------------------------------------------------------
/my_nnpid/test_nn.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_nnpid/test_nn.m
--------------------------------------------------------------------------------
/images/adrc_test.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/images/adrc_test.PNG
--------------------------------------------------------------------------------
/my_ADRC/test_pid.slx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_ADRC/test_pid.slx
--------------------------------------------------------------------------------
/my_nnpid/my_nn_pid.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_nnpid/my_nn_pid.m
--------------------------------------------------------------------------------
/images/adrc_test_s_e.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/images/adrc_test_s_e.PNG
--------------------------------------------------------------------------------
/images/pid_test_s_e.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/images/pid_test_s_e.PNG
--------------------------------------------------------------------------------
/images/transfer_func.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/images/transfer_func.PNG
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ADRC-matlab
ADRC programs written in MATLAB.
There are also some related documents that read a bit like book reports,
for example:

### [Understanding the ADRC Controller, for Beginners](./doc_1.md)
--------------------------------------------------------------------------------
/my_ADRC/slprj/sl_proj.tmw:
--------------------------------------------------------------------------------
Simulink Coder project marker file. Please don't change it.
slprjVersion: 8.7_029
--------------------------------------------------------------------------------
/my_ADRC/untitled_grt_rtw/build_exception.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_ADRC/untitled_grt_rtw/build_exception.mat
--------------------------------------------------------------------------------
/my_ADRC/slprj/grt/untitled/tmwinternal/minfo.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lvniqi/ADRC-matlab/HEAD/my_ADRC/slprj/grt/untitled/tmwinternal/minfo.mat
--------------------------------------------------------------------------------
/doc_2.md:
--------------------------------------------------------------------------------
Neural-Network PID for Beginners
==========
This whole article contains only antique formulas. I don't know where the fresh ones are, so if you do, please tell me :stuck_out_tongue_winking_eye:.

## Introduction
The PID controller works quite well,
but how can we add something fun (~~black magic~~) to make it look fancier,
so that it helps us pad out papers more effectively?
This is when the big gun, the neural network, comes in.

Actually, PID neural networks are something the experts played with and moved on from almost 20 years ago. I'm writing this just to see how well they work in practice and where their limits lie.

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Prerequisites
*.d

# Compiled Object files
*.slo
*.lo
*.o
*.obj

# Precompiled Headers
*.gch
*.pch

# Compiled Dynamic libraries
*.so
*.dylib
*.dll

# Fortran module files
*.mod
*.smod

# Compiled Static libraries
*.lai
*.la
*.a
*.lib

# Executables
*.exe
*.out
*.app

--------------------------------------------------------------------------------
/my_ADRC/leso3.m:
--------------------------------------------------------------------------------
function [z1_new,z2_new,z3_new] = leso3(z1_last,z2_last,z3_last, ...
                                        w, ...
                                        input,b0,output,h)
% LESO3  One forward-Euler step of a third-order linear ESO.
%   Observer gains use the bandwidth parameterization with bandwidth w.
beta_01 = 3*w;
beta_02 = 3*w^2;
beta_03 = w^3;
e = z1_last-output;                                 % estimation error of z1
z1_new = z1_last+h*(z2_last-beta_01*e);
z2_new = z2_last + h*(z3_last-beta_02*e+b0*input);
z3_new = z3_last + h*(-beta_03*e);                  % extended state (total disturbance)
end
--------------------------------------------------------------------------------
/my_ADRC/nlsef3.m:
--------------------------------------------------------------------------------
function u0 = nlsef3(e1,e2,c,r,h1)
% NLSEF3  Nonlinear state-error feedback: u0 = -fhan(e1, c*e2, r, h1).
u0 = -fhan(e1,c*e2,r,h1);
end

function fh = fhan(x1_last,x2_last,r,h0)
% FHAN  Han's time-optimal synthesis function (r: speed factor, h0: filter factor).
d = r*h0;
d0 = h0*d;
x1_new = x1_last+h0*x2_last;
a0 = sqrt(d^2+8*r*abs(x1_new));
if abs(x1_new)>d0
    a=x2_last+(a0-d)/2*sign(x1_new);
else
    a = x2_last+x1_new/h0;
end
fh = -r*sat(a,d);
end

function M=sat(x,delta)
% SAT  x/delta inside [-delta, delta], sign(x) outside.
if abs(x)<=delta
    M=x/delta;
else
    M=sign(x);
end
end
--------------------------------------------------------------------------------
/my_ADRC/td3.m:
--------------------------------------------------------------------------------
function [x1_new,x2_new] = td3(x1_last,x2_last,input,r,h,h0)
% TD3  One step of the fhan-based tracking differentiator.
%   x1 tracks input, x2 approximates its derivative, h is the step size.
x1_new = x1_last+h*x2_last;
x2_new = x2_last+h*fhan(x1_last-input,x2_last,r,h0);
end

function fh = fhan(x1_last,x2_last,r,h0)
% FHAN  Han's time-optimal synthesis function (r: speed factor, h0: filter factor).
d = r*h0;
d0 = h0*d;
x1_new = x1_last+h0*x2_last;
a0 = sqrt(d^2+8*r*abs(x1_new));
if abs(x1_new)>d0
    a=x2_last+(a0-d)/2*sign(x1_new);
else
    a = x2_last+x1_new/h0;
end
fh = -r*sat(a,d);
end

function M=sat(x,delta)
% SAT  x/delta inside [-delta, delta], sign(x) outside.
if abs(x)<=delta
    M=x/delta;
else
    M=sign(x);
end
end
--------------------------------------------------------------------------------
/my_ADRC/eso3.m:
--------------------------------------------------------------------------------
function [z1_new,z2_new,z3_new] = eso3(z1_last,z2_last,z3_last, ...
                                       beta_01,beta_02,beta_03, ...
                                       input,b0,output,h,threshold)
% ESO3  One forward-Euler step of a third-order nonlinear ESO using fal().
e = z1_last-output;
fe = fal(e,0.5,threshold);
fe1 = fal(e,0.25,threshold);
z1_new = z1_last+h*(z2_last-beta_01*e);
z2_new = z2_last + h*(z3_last-beta_02*fe+b0*input);
z3_new = z3_last + h*(-beta_03*fe1);
end

function fe = fal(error,pow,threshold)
% FAL  |e|^pow*sign(e) outside [-threshold, threshold], linear inside.
%   The linear branch divides by threshold^(1-pow) so that both branches
%   meet at |e| = threshold.
if abs(error) > threshold
    fe = abs(error)^pow*sign(error);
else
    fe = error/threshold^(1-pow);
end
end
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017 lvniqi

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/doc_1.md:
--------------------------------------------------------------------------------
Understanding the ADRC Controller, for Beginners
==========
This whole article contains only meaningless formulas; the meaningful ones are all in the book, so beginners can dig in without worry :stuck_out_tongue_winking_eye:

## Introduction
### Problems with traditional PID
Traditionally, for systems with uncertain models, we (or maybe just I?) like to use PID control.
Even when the model is known, some of its parameters are often hard to obtain, so we stick with a PID controller anyway.

The PID control we usually see looks like this:

[Figure: typical PID control loop]
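
In code, the controller in such a loop is a short per-sample update. Below is a minimal MATLAB sketch. It assumes the common discretization with a forward-Euler integral and a backward-difference derivative, which corresponds to the z-domain form W(z)=P+I·Ts/(z−1)+D(z−1)/(z·Ts). The gains `P`, `I`, `D`, the period `Ts`, and the constant error sequence are hypothetical values chosen only to keep the arithmetic easy to follow.

```matlab
% Discrete PID sketch (hypothetical gains), matching
%   W(z) = P + I*Ts/(z-1) + D*(z-1)/(z*Ts):
%   integral term:   S(k) = S(k-1) + Ts*e(k-1)   (forward Euler)
%   derivative term: (e(k) - e(k-1))/Ts          (backward difference)
P = 2; I = 1; D = 0.5;      % hypothetical controller gains
Ts = 0.1;                   % sampling period [s]
e_seq = [1 1 1];            % constant error over three samples
S = 0;                      % integral state
e_prev = 0;                 % previous error sample
u_seq = zeros(size(e_seq)); % controller outputs
for k = 1:numel(e_seq)
    e = e_seq(k);
    S = S + Ts*e_prev;                           % I*Ts/(z-1) uses e(k-1)
    u_seq(k) = P*e + I*S + D*(e - e_prev)/Ts;    % P + I + D terms
    e_prev = e;
end
disp(u_seq)   % u_seq is [7 2.1 2.2]
```

With a constant error, the derivative term contributes only on the first sample (the 5 in 7 = 2 + 0 + 5), while the integral term grows by I·Ts·e per sample. The same (e(k)−e(k−1))/Ts difference is what amplifies measurement noise when Ts is small.
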
15 | 16 | 如果把这样的pid控制器写成传递函数,就是 17 | 18 | ![W(s)=P+I\\frac{1}{s}+Ds](http://latex.codecogs.com/png.latex?W(s)=P+I\\frac{1}{s}+Ds) 19 | 20 | Z变换的话是变成 21 | 22 | ![W(z)=P+I\cdot{T_s}\\frac{1}{z-1}+D\\frac{z-1}{z\cdot{T_s}}](http://latex.codecogs.com/png.latex?W(z)=P+I\cdot{T_s}\\frac{1}{z-1}+D\\frac{z-1}{z\cdot{T_s}}) 23 | 24 | 嗯~看起来挺不错的,那么问题在哪儿呢? 25 | 26 | #### 微分不可用 27 | 在实际场景中,上面的pid控制上存在一个巨大的问题,即D分量一般并不可用。因为对象输出c(t)存在噪声,简单做微分会导致噪声引入系统导致系统不稳定。 28 | 29 | #### 超调&&过渡过程 30 | 31 | 32 | 任何一个控制器,其最终目的都是希望输出c(t)尽可能快地跟踪输入r(t)。在大多数情况下,我们希望这个跟踪没有超调和振荡。然而这个愿望可能实现吗? 33 | 34 | 我们观察一个二阶系统: 35 | 36 | ![](http://latex.codecogs.com/png.latex?G(s)=\\frac{a_1}{s^2+a_2s+a_1}) 37 | 38 | 当且仅当![](http://latex.codecogs.com/png.latex?{a_2}=2\\sqrt{a_1})时,系统是临界阻尼状态,此时系统既没有超调振荡,跟踪速度也是较快的。 39 | 而我们制作的控制器也是希望能修改原有的传递函数,使得其能实现类似于临界阻尼这样的状态。对于大多数二阶系统而言,PID控制就能做到。 40 | 41 | 但如果a1和a2在系统运行中有些许变化怎么办? 42 | 额,这时候简单的PID就要出事了。(至于为什么,额,我数学不好,谁有兴趣谁推导:joy:) 43 | 44 | 其实在我们输入r(t)的时候,输入的期望是不连续的(比如阶跃)。 45 | 在调节时间ts内,我们实际上在强控制器所难,因为它的输出本来就不可能到达我们期望的输入。 46 | 如果不作理论推导(又不做推导,我这样是要被打的:dizzy_face:),单凭直觉,我们很容易想象,一个含有惯性过程的系统,在e(t)较大的情况下很可能出现超调的情况。 47 | 如果我们换个思路,为控制器设计一个合理的过渡过程,是不是能减轻下这个控制器的负担,初始u(t)不会过大,超调也会减少一部分。 48 | 49 | #### 积分反馈的问题 50 | PID中的I项是用于抑制常值扰动,减少稳态误差的。然而添加I项可能导致系统动态特性变差,又有积分饱和的问题。这项这么扔着真的好么。 51 | 52 | #### PID控制器本身 53 | PID是将比例、积分、微分做线性加权和作为调节器输出扔给被控对象的(类似单个无偏置神经元?:point_left:)。这么做是不是合适?用非线性组合是不是更好? 54 | 55 | ## ADRC控制器 56 | ADRC控制器结构如下图所示: 57 | 58 | 59 | 60 | 是不是看着很乱?哈哈~我只是想安利一下这个 61 | [画图工具](https://yuml.me/diagram/scruffy/class/samples)。 62 | 63 | 正常的画风是这样的: 64 | 65 |
[Figure: ADRC block diagram]
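
To make the block diagram concrete, here is a minimal MATLAB sketch that wires the blocks into a loop. It is a *linear* ADRC, not the repository's `adrc.m`. It combines a linear TD, a linear ESO with the same update and bandwidth-parameterized gains as `leso3.m`, and a PD-like linear combination in place of the nonlinear one. The plant y'' = −2y' − y + d + u, the constant disturbance `d`, and all gains are hypothetical. The control law u = (u0 − z3)/b0, which cancels the ESO's estimate of the total disturbance, is the usual ADRC form and is assumed for this sketch.

```matlab
% Minimal linear ADRC loop (sketch, hypothetical plant and gains):
% TD -> linear combination -> disturbance compensation -> plant,
% with a linear ESO estimating [y; dy/dt; total disturbance].
h  = 0.001;                % step size [s]
N  = 3000;                 % number of steps (3 s)
v  = 1;                    % setpoint: unit step
r  = 20;                   % TD speed factor
w0 = 50;                   % ESO bandwidth: beta = 3*w0, 3*w0^2, w0^3 (as in leso3.m)
wc = 10;                   % controller bandwidth
kp = wc^2;  kd = 2*wc;     % gains of the linear state-error combination
b0 = 1;                    % assumed input gain
d  = 0.5;                  % hypothetical constant disturbance, unknown to the controller
x  = [0; 0];               % plant state [y; dy/dt]
v1 = 0;  v2 = 0;           % TD outputs: transient profile and its derivative
z  = [0; 0; 0];            % ESO states: z1 ~ y, z2 ~ dy/dt, z3 ~ total disturbance
y_hist = zeros(1, N);
for k = 1:N
    y = x(1);                                    % measured output
    % 1) Linear TD driven by the setpoint (forward Euler)
    v1_next = v1 + h*v2;
    v2 = v2 + h*(-r^2*(v1 - v) - 2*r*v2);
    v1 = v1_next;
    % 2) Linear combination of state errors, then cancel the estimated disturbance
    u0 = kp*(v1 - z(1)) + kd*(v2 - z(2));
    u  = (u0 - z(3))/b0;
    % 3) Linear ESO step (simultaneous update, as in leso3.m)
    e = z(1) - y;
    z = z + h*[z(2) - 3*w0*e;
               z(3) - 3*w0^2*e + b0*u;
               -w0^3*e];
    % 4) Hypothetical plant y'' = -2*y' - y + d + u (forward Euler)
    x = x + h*[x(2); -2*x(2) - x(1) + d + u];
    y_hist(k) = x(1);
end
fprintf('final y = %.4f, final z3 = %.4f\n', y_hist(end), z(3));
```

If the loop is wired as intended, y settles at the setpoint. z3 settles near the total disturbance −2y' − y + d = −0.5, even though the controller never uses the plant coefficients or d directly.
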
68 | 69 | 看着是不是和那些个pid好不一样。 70 | 照我的理解,原有的PID相当于这之中的**非线性组合**, 71 | 而**过渡过程构建**、**扩张状态观测器**这两个是新添加的模块。 72 | 抛去**非线性组合**,我们先来看看剩下的两个。 73 | 74 | #### 过渡过程构建(跟踪微分器TD) 75 | ADRC中的动态过程是通过跟踪微分器来实现的。 76 | 跟踪微分器本不是用于实现动态过程,是为了从被污染的信号中求取微分而设计的。 77 | 然后对微分进行积分,得到跟踪信号。 78 | 由于这个微分是进行滞后和滤波的,所以所得跟踪信号的带宽也有所限制。 79 | 这样得到的跟踪信号即可看做是原有信号的**动态过程**。 80 | 81 | 传统上,微分可以由传递函数 82 | 83 | ![](http://latex.codecogs.com/png.latex?W(s)=\\frac{1}{T}\(1-\\frac{1}{Ts+1}\)) 84 | 85 | 得到,式中,T为采样间隔时间。即微分 = (现在-过去)/采样间隔。显然,在输入信号含噪的情况下,这样的微分会放大噪声直至微分信号被淹没。 86 | 87 | 跟踪微分器中,将微分的传递函数变成了 88 | 89 | ![](http://latex.codecogs.com/png.latex?W(s)=\\frac{1}{\tau_2-\tau_1}\(\\frac{1}{\tau_1s+1}-\\frac{1}{\tau_2s+1}\)) 90 | 91 | 即两个惯性环节相减,目的是通过惯性环节降低噪声。 92 | 进一步地,整理上式并令τ2无限接近于τ1,可得 93 | 94 | ![](http://latex.codecogs.com/png.latex?W_d(s)=\\frac{s}{\tau^2s^2+2\tau{s}+1}) 95 | 96 | 令r = τ可得 97 | 98 | ![](http://latex.codecogs.com/png.latex?W_d(s)=\\frac{r^2s}{s^2+2rs+r^2}) 99 | 100 | 看下这是个啥。 101 | 实际上,这就是一个二阶临界阻尼系统的微分。 102 | 103 | 状态变量方程为 104 | 105 | 106 | 107 | 写成离散方程的样子,就变成一个实际可用的跟踪微分器,即 108 | 109 | 110 | 111 | x1为跟踪输出,x2为微分输出,h为采样周期 112 | 差分效果如下所示 113 | 114 | ![差分效果](./images/TD_i_d.PNG) 115 | 116 | 跟踪效果如下所示 117 | 118 | ![跟踪效果](./images/TD_i_t.PNG) 119 | 120 | 看起来不错,那么这么做的代价是什么。 121 | 眼尖的同学可能已经看出来了,跟踪有滞后。其实这么做,微分和跟踪都会有滞后。 122 | 我们将r的值改小一点看看。 123 |
[Figures: derivative and tracking results with a smaller r]
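
The linear TD can be simulated directly to reproduce this trade-off. Below is a minimal MATLAB sketch. It assumes a unit sine input at 1 rad/s and h = 0.001 s, both hypothetical choices. With r = 50 the tracking output keeps essentially full amplitude. With r equal to the input frequency (r = 1), the transfer function r²/(s+r)² gives amplitude 1/2 and a 90° lag. The derivative output r²s/(s+r)² then becomes 0.5·sin(t), in phase with the input, so no lead is left.

```matlab
% Linear tracking differentiator (TD), forward-Euler form:
%   x1(k+1) = x1(k) + h*x2(k)
%   x2(k+1) = x2(k) + h*(-r^2*(x1(k) - v(k)) - 2*r*x2(k))
h = 0.001;                     % sampling period [s] (hypothetical)
t = 0:h:40;                    % simulation time [s]
v = sin(t);                    % test input: unit sine at 1 rad/s (hypothetical)
r_values = [50, 1];            % a large r, and r equal to the input frequency
last = t >= t(end) - 2*pi;     % samples in the final input period (steady state)
amp = zeros(size(r_values));   % steady-state amplitude of the tracking output x1
deriv_err = 0;                 % max |x2 - 0.5*sin(t)| in steady state for r = 1
for i = 1:numel(r_values)
    r = r_values(i);
    x1 = zeros(size(t));
    x2 = zeros(size(t));
    for k = 1:numel(t)-1
        x1(k+1) = x1(k) + h*x2(k);
        x2(k+1) = x2(k) + h*(-r^2*(x1(k) - v(k)) - 2*r*x2(k));
    end
    amp(i) = max(abs(x1(last)));
    if r == 1
        % At r = omega the derivative output should be ~0.5*sin(t): no lead left
        deriv_err = max(abs(x2(last) - 0.5*sin(t(last))));
    end
    fprintf('r = %g: tracking amplitude %.3f\n', r, amp(i));
end
```

Raising r widens the TD's bandwidth and reduces the lag, but it also lets more noise through. That is the same trade-off the figures illustrate.
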
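
See? The derivative now lags so much that almost no phase lead is left, while the tracking output lags by 90° and its amplitude has shrunk: the bandwidth of the system is limited.
Seen from another angle, this limitation adds a **transient process** to the system.
Because inertia makes it impossible for the system to track the input perfectly anyway, a well-designed transient process can reduce overshoot.
Also, what I've described here is the linear tracking differentiator;
the optimal tracking differentiator differs somewhat, but it follows the same line of thinking. If you want to learn more, see the book.

#### Extended state observer (ESO)

##### State observer (SO)

I find the book rather vague on this part (the whole book is driven mainly by experiments, after all), so I'll explain it even more vaguely.
A state observer estimates the system's state from its input and output.
Consider a second-order system

![](http://latex.codecogs.com/png.latex?\dot{x_1}=x_2)

![](http://latex.codecogs.com/png.latex?\dot{x_2}=a_1x_1+a_2x_2+u)

![](http://latex.codecogs.com/png.latex?y=x_1)

Since a1 and a2 are unknown, we can only use the output y and the input u to estimate the entire state vector; call the estimate z.
A full-order state observer looks like this, with l1 and l2 introduced as correction gains:

![](http://latex.codecogs.com/png.latex?\dot{z_1}=z_2-l_1(z_1-y))

![](http://latex.codecogs.com/png.latex?\dot{z_2}=a_1z_1+a_2z_2+u-l_2(z_1-y))

Let e1 = z1 − x1 and

![](http://latex.codecogs.com/png.latex?e_2=z_2-x_2)

Then

![](http://latex.codecogs.com/png.latex?\dot{e_1}=e_2-l_1e_1)

![](http://latex.codecogs.com/png.latex?\dot{e_2}=(-l_2+a_1)e_1+a_2e_2)

So as long as l1 and l2 are chosen properly, e → 0 and the observer output z approaches the true state x.
Of course, if you feel like it, you can pass the errors e through a function to improve performance.


##### Extended state observer (ESO)

So where exactly is the **extended state observer** **extended**?
For a second-order system, suppose

![](http://latex.codecogs.com/png.latex?\dot{x_2}=f(x_1,x_2)+u)

and extend f(x1,x2) into a new state, so that

![](http://latex.codecogs.com/png.latex?x_3(t)=f(x_1(t),x_2(t)))

![](http://latex.codecogs.com/png.latex?\dot{x_3}(t)=w(t))

Right: this extra x3 is the so-called extended state, and an observer that includes x3, with output z = {z1, z2, z3}, is the extended state observer.

### Minor details

I can't be bothered to write up the other minor details. The TD, the control combination, and the ESO can each be linear or nonlinear (~~compare activation functions :joy: just kidding~~).


## A quick test
Consider the following second-order system with noise:

![](images/transfer_func.PNG)

With PID and simple automatic tuning, the output looks like this:
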
[Figure: PID test output]

With ADRC and simple manual tuning, the output looks like this:

[Figure: ADRC test output]

ADRC performs impressively, doesn't it!
~~But there is a hidden concern: compare the noise in the steady-state outputs.~~
~~The steady-state noise differs by a factor of 10! My guesses at the cause:~~
* ~~No integral action?~~
* ~~All the nonlinearities?~~

**I made a mistake in my MATLAB code: the noise was fed back into the transfer function and recomputed. After fixing the bug, the noise levels no longer differ by a factor of ten. Sorry about that. The corrected figures are below.**

PID:

![](./images/pid_test_s_e.PNG)

ADRC:

![](./images/adrc_test_s_e.PNG)


That's all.

## ~~References~~
[1] 韩京清 (Han Jingqing). 自抗扰控制技术: 估计补偿不确定因素的控制技术 (Active Disturbance Rejection Control Technique: The Technique for Estimating and Compensating the Uncertainties) [M]. 国防工业出版社 (National Defense Industry Press), 2008.
--------------------------------------------------------------------------------
/my_nnpid/lstm.m:
--------------------------------------------------------------------------------
function lstm()
% Implementation of an LSTM that learns 8-bit binary addition
clc
clear
close all

%% training dataset generation
binary_dim = 8;

largest_number = 2^binary_dim - 1;
int2binary = cell(largest_number + 1, 1);  % int2binary{n+1} is the 8-bit string of n

for i = 1:largest_number + 1
    int2binary{i} = dec2bin(i-1, binary_dim);
end

%% input variables
alpha = 0.05;
input_dim = 2;
hidden_dim = 32;
output_dim = 1;

%% initialize neural network weights
% in_gate = sigmoid(X(t) * U_i + H(t-1) * W_i) ------- (1)
U_i = 2 * rand(input_dim, hidden_dim) - 1;
W_i = 2 * rand(hidden_dim, hidden_dim) - 1;
U_i_update = zeros(size(U_i));
W_i_update = zeros(size(W_i));

% forget_gate = sigmoid(X(t) * U_f + H(t-1) * W_f) ------- (2)
U_f = 2 * rand(input_dim, hidden_dim) - 1;
W_f = 2 * rand(hidden_dim, hidden_dim) - 1;
U_f_update = zeros(size(U_f));
W_f_update = zeros(size(W_f));

% out_gate = sigmoid(X(t) * U_o + H(t-1) * W_o) ------- (3)
U_o = 2 * rand(input_dim, hidden_dim) - 1;
W_o = 2 * rand(hidden_dim, hidden_dim) - 1;
U_o_update = zeros(size(U_o));
W_o_update = zeros(size(W_o));

% g_gate = tanh(X(t) * U_g + H(t-1) * W_g) ------- (4)
U_g = 2 * rand(input_dim, hidden_dim) - 1;
W_g = 2 * rand(hidden_dim, hidden_dim) - 1;
U_g_update = zeros(size(U_g));
W_g_update = zeros(size(W_g));

out_para = 2 * rand(hidden_dim, output_dim) - 1;
out_para_update = zeros(size(out_para));
% C(t) = C(t-1) .* forget_gate + g_gate .* in_gate ------- (5)
% H(t) = tanh(C(t)) .* out_gate ------- (6)
% Out = sigmoid(H(t) * out_para) ------- (7)
% Note: Equations (1)-(6) are the core of the LSTM forward pass, and equation
% (7) maps the hidden layer to the predicted output, i.e., the output layer.
% (Sometimes you can use softmax for equation (7).)

%% train
iter = 999999; % training iterations
for j = 1:iter
    % generate a simple addition problem (a + b = c)
    a_int = randi(round(largest_number/2)); % int version
    a = int2binary{a_int+1}; % binary encoding

    b_int = randi(floor(largest_number/2)); % int version
    b = int2binary{b_int+1}; % binary encoding

    % true answer
    c_int = a_int + b_int; % int version
    c = int2binary{c_int+1}; % binary encoding

    % where we'll store our best guess (binary encoded)
    d = zeros(size(c));

    % total error
    overallError = 0;

    % output-layer error for each time step, i.e., (target - out)
    output_deltas = [];

    % hidden layer H(t) and cell state C(t); the first row holds H(0) = C(0) = 0
    H = zeros(1, hidden_dim);
    C = zeros(1, hidden_dim);
    % in, forget, out and g gate activations for each time step
    I = [];
    F = [];
    O = [];
    G = [];

    % forward pass over the sequence, least significant bit first
    % Note: the output of an LSTM cell is the hidden layer, which still has to
    % be mapped to the predicted output
    for position = 0:binary_dim-1
        % X ------> input, size: 1 x input_dim
        X = [a(binary_dim - position)-'0' b(binary_dim - position)-'0'];

        % y ------> label, size: 1 x output_dim
        y = c(binary_dim - position)-'0';

        % equations (1)-(6) of the forward pass; no bias terms are used
        in_gate = sigmoid(X * U_i + H(end, :) * W_i);        % equation (1)
        forget_gate = sigmoid(X * U_f + H(end, :) * W_f);    % equation (2)
        out_gate = sigmoid(X * U_o + H(end, :) * W_o);       % equation (3)
        g_gate = tan_h(X * U_g + H(end, :) * W_g);           % equation (4)
        C_t = C(end, :) .* forget_gate + g_gate .* in_gate;  % equation (5)
        H_t = tan_h(C_t) .* out_gate;                        % equation (6)

        % store the gate activations and states
        I = [I; in_gate];
        F = [F; forget_gate];
        O = [O; out_gate];
        G = [G; g_gate];
        C = [C; C_t];
        H = [H; H_t];

        % predicted output, equation (7)
        pred_out = sigmoid(H_t * out_para);

        % error in the output layer; with a sigmoid output and cross-entropy
        % loss this is also the negative gradient w.r.t. the pre-activation
        output_error = y - pred_out;
        output_deltas = [output_deltas; output_error];

        % total error (pred_out is 1 x 1 here)
        overallError = overallError + abs(output_error(1));

        % decode estimate so we can print it out
        d(binary_dim - position) = round(pred_out);
    end

    % differences flowing back from the next time step (zero after the last cell)
    future_H_diff = zeros(1, hidden_dim);
    future_C_diff = zeros(1, hidden_dim);

    % backward pass (backpropagation through time), starting from the last
    % LSTM cell; the *_diff terms are negative gradients, so the weight
    % update below adds them
    for position = 0:binary_dim-1
        X = [a(position+1)-'0' b(position+1)-'0'];

        H_t = H(end-position, :);      % H(t)
        H_t_1 = H(end-position-1, :);  % H(t-1)
        C_t = C(end-position, :);      % C(t)
        C_t_1 = C(end-position-1, :);  % C(t-1)
        O_t = O(end-position, :);
        F_t = F(end-position, :);
        G_t = G(end-position, :);
        I_t = I(end-position, :);

        % output layer difference
        output_diff = output_deltas(end-position, :);

        % H(t) feeds both the output layer and the next LSTM cell, so its
        % difference comes from both sources
        H_t_diff = output_diff * (out_para') + future_H_diff;

        out_para_diff = (H_t') * output_diff;

        % out_gate difference
        O_t_diff = H_t_diff .* tan_h(C_t) .* sigmoid_output_to_derivative(O_t);

        % C(t) difference: through tanh in equation (6), plus the path via C(t+1)
        C_t_diff = H_t_diff .* O_t .* tan_h_output_to_derivative(tan_h(C_t)) ...
            + future_C_diff;

        % forget_gate difference
        F_t_diff = C_t_diff .* C_t_1 .* sigmoid_output_to_derivative(F_t);

        % in_gate difference
        I_t_diff = C_t_diff .* G_t .* sigmoid_output_to_derivative(I_t);

        % g_gate difference
        G_t_diff = C_t_diff .* I_t .* tan_h_output_to_derivative(G_t);

        % accumulate differences of input weights U_* and recurrent weights W_*
        U_i_update = U_i_update + X' * I_t_diff;
        W_i_update = W_i_update + H_t_1' * I_t_diff;
        U_o_update = U_o_update + X' * O_t_diff;
        W_o_update = W_o_update + H_t_1' * O_t_diff;
        U_f_update = U_f_update + X' * F_t_diff;
        W_f_update = W_f_update + H_t_1' * F_t_diff;
        U_g_update = U_g_update + X' * G_t_diff;
        W_g_update = W_g_update + H_t_1' * G_t_diff;
        out_para_update = out_para_update + out_para_diff;

        % pass the differences on to the previous time step
        future_H_diff = I_t_diff * W_i' + F_t_diff * W_f' ...
            + O_t_diff * W_o' + G_t_diff * W_g';
        future_C_diff = C_t_diff .* F_t;
    end

    U_i = U_i + U_i_update * alpha;
    W_i = W_i + W_i_update * alpha;
    U_o = U_o + U_o_update * alpha;
    W_o = W_o + W_o_update * alpha;
    U_f = U_f + U_f_update * alpha;
    W_f = W_f + W_f_update * alpha;
    U_g = U_g + U_g_update * alpha;
    W_g = W_g + W_g_update * alpha;
    out_para = out_para + out_para_update * alpha;

    U_i_update = U_i_update * 0;
    W_i_update = W_i_update * 0;
    U_o_update = U_o_update * 0;
    W_o_update = W_o_update * 0;
    U_f_update = U_f_update * 0;
    W_f_update = W_f_update * 0;
    U_g_update = U_g_update * 0;
    W_g_update = W_g_update * 0;
    out_para_update = out_para_update * 0;

    if mod(j, 1000) == 0
        fprintf('Error:%s\n', num2str(overallError));
        fprintf('Pred:%s\n', sprintf('%d', d));
        fprintf('True:%s\n', c);
    end
end
end

function output = tan_h(x)
output = tanh(x);
end

function gradient = tan_h_output_to_derivative(output)
% derivative of tanh, expressed through the tanh output
gradient = 1 - output.*output;
end

function output = sigmoid(x)
output = 1./(1+exp(-x));
end

function gradient = sigmoid_output_to_derivative(output)
% derivative of the sigmoid, expressed through the sigmoid output
gradient = output.*(1-output);
end
--------------------------------------------------------------------------------