├── GussianKernal.m ├── README.md ├── SBLsinc.m └── data.m /GussianKernal.m: -------------------------------------------------------------------------------- 1 | function GK = GussianKernal( r1, r2 ) 2 | 3 | ler1=size(r1,1); 4 | ler2=size(r2,1); 5 | GK=ones(ler1,ler2); 6 | 7 | for j=2:ler2 8 | for i = 1:ler1 9 | GK(i,j)=exp(-(r1(i)-r2(j))^2); 10 | end 11 | end 12 | 13 | end 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Sparse-Bayesian-Learning 2 | SBL matlab code 3 | 验证sinc函数,加入gussian白噪声 4 | -------------------------------------------------------------------------------- /SBLsinc.m: -------------------------------------------------------------------------------- 1 | clc; 2 | clear; 3 | Ntrain = 200; %训练集样本数 4 | Ntest = 500; %测试集样本数 5 | [x,t] = data(Ntrain); %生成训练数据集 6 | f = GussianKernal(x,x); %生成高斯核函数 7 | 8 | m=size(f,2); 9 | alpha=rand(1,m); 10 | beta=rand(); %初始化学习参数 11 | 12 | for ep = 1:3000 13 | A = diag(alpha); 14 | F = pinv(beta*f'*f+A); 15 | u = beta*F*f'*t; 16 | gamma = 1-alpha.*diag(F)'; %更新参数 17 | alpha_old=alpha; 18 | beta_old=beta; 19 | 20 | index=abs(alpha)<1e3; 21 | alpha(index)=gamma(index)./(u(index)'.^2); 22 | beta=(Ntrain-sum(gamma))/((t-f*u)'*(t-f*u));%将数值小于1000的列标去除,只更新此部分 23 | 24 | err=max(abs(alpha(index)-alpha_old(index))./abs(alpha(index)))+abs(beta-beta_old)/abs(beta); 25 | if err<0.1 26 | break; 27 | end %判断如果满足收敛条件,停止迭代 28 | end 29 | 30 | tpre=f(:,index)*u(index); %计算预测数据集 31 | 32 | figure(1); 33 | plot(x,t,'r+'); 34 | hold on; 35 | plot(x,tpre,'b*'); 36 | hold on; 37 | plot(x(index(2:end)),t(index(2:end)),'ko'); %分别绘制原始数据、预测数据和相关向量 38 | 39 | title('Testing Data1'); 40 | xlabel('x'); 41 | ylabel('t'); 42 | legend('训练数据','预测数据','相关向量');%标注 43 | 44 | [xte,tte]=data(Ntest);%生成500个测试数据集 45 | ftest = GussianKernal(xte,x);%生成关联矩阵 46 | ttepre=ftest(:,index)*u(index);%预测 47 | figure(2) 48 | plot(xte,tte,'.'); 49 | hold on; 50 | plot(xte,ttepre,'+'); 51 | hold on; 52 | plot(x(index(2:end)),t(index(2:end)),'ko');%分别绘制原始数据、预测数据和关联向量(由200个数据给出) 53 | 54 | title('Testing Data2'); 55 | xlabel('x'); 56 | ylabel('t'); 57 | legend('训练数据','预测数据','相关向量');%标注 58 | -------------------------------------------------------------------------------- /data.m: -------------------------------------------------------------------------------- 1 | function [input, output ] = data( N ) 2 | input=(rand(N,1)-0.5)*4; 3 | output=sin(input*5)+randn(N,1)/4; 4 | end 5 | --------------------------------------------------------------------------------