├── ReLU.m
├── main.m
├── sigmoid.m
├── backward.m
├── result1.png
├── result2.png
├── README.md
├── forward.m
└── upgrading.m

/ReLU.m:
--------------------------------------------------------------------------------
function Y = ReLU(X)
%RELU Element-wise rectified linear unit.
%   Y = RELU(X) returns an array the same size as X in which every entry
%   below zero is replaced by zero, i.e. max(X, 0) taken element-wise.
lowerBound = 0;
Y = max(X, lowerBound);
end
--------------------------------------------------------------------------------
/main.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cxiang26/Simple_GAN/HEAD/main.m
--------------------------------------------------------------------------------
/sigmoid.m:
--------------------------------------------------------------------------------
function Y = sigmoid(X)
%SIGMOID Element-wise logistic function 1 ./ (1 + exp(-X)).
%   Y = SIGMOID(X) returns an array the same size as X. For real X the
%   values lie in [0, 1]; when exp(-X) overflows to Inf (very negative X)
%   the result is exactly 0.
denominator = 1 + exp(-X);
Y = 1 ./ denominator;
end
--------------------------------------------------------------------------------
/backward.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cxiang26/Simple_GAN/HEAD/backward.m
--------------------------------------------------------------------------------
/result1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cxiang26/Simple_GAN/HEAD/result1.png
--------------------------------------------------------------------------------
/result2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cxiang26/Simple_GAN/HEAD/result2.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Simple_GAN
Implemented in MATLAB, this project is a simple generative adversarial network (GAN) for a Gaussian distribution.
You can find all the details in this project.
--------------------------------------------------------------------------------
/forward.m:
--------------------------------------------------------------------------------
function net = forward(net, batch_x)
%FORWARD Forward pass of a two-layer fully connected network with sigmoid units.
%   NET = FORWARD(NET, BATCH_X) stores BATCH_X in NET.x and computes
%       NET.h   = NET.w * NET.x   + NET.wb   (bias added to every column)
%       NET.h_o = sigmoid(NET.h)
%       NET.o   = NET.v * NET.h_o + NET.vb   (bias added to every column)
%       NET.o_o = sigmoid(NET.o)
%   NET.wb and NET.vb are broadcast across the columns of BATCH_X
%   (presumably one column per sample -- confirm against main.m).
net.x = batch_x;
%net.y = batch_y';

% Layer 1: affine transform followed by sigmoid.
% bsxfun adds the bias column to every column without the temporary
% full-size copy that repmat would allocate; the result is identical.
net.h = bsxfun(@plus, net.w * net.x, net.wb);
net.h_o = sigmoid(net.h);

% Layer 2: affine transform followed by sigmoid.
net.o = bsxfun(@plus, net.v * net.h_o, net.vb);
net.o_o = sigmoid(net.o);

% Squared-error loss (disabled; would need batch_y stored in net.y):
%net.loss_temp = net.o_o - net.y;
%net.loss = sum(sum(1/2*(net.loss_temp).^2))/size(net.x,2);
end
--------------------------------------------------------------------------------
/upgrading.m:
--------------------------------------------------------------------------------
function net = upgrading(net)
%UPGRADING Apply one parameter update step to NET.
%   NET = UPGRADING(NET) reads the gradients NET.d_w, NET.d_v, NET.d_wb,
%   NET.d_vb, the learning rate NET.lr, NET.batch_size and the coefficient
%   NET.moment, and returns NET with NET.w, NET.v, NET.wb, NET.vb updated:
%       step_w = lr * d_w / batch_size
%       w      = w - step_w - moment * mw      (mw = previous call's step_w)
%   and likewise for v / mv. The biases take a plain gradient step on the
%   sum over columns of their gradients, with no momentum term.
%
%   NOTE(review): mw/mv store only the previous call's gradient step, not an
%   accumulated velocity, so this is a one-step look-back rather than
%   classical heavy-ball momentum. Kept unchanged to preserve training.
%
%   Optional: if NET.clip is a positive scalar, NET.w and NET.v are clipped
%   element-wise to [-NET.clip, NET.clip] after the update (WGAN-style
%   weight clipping; the original experiment used 0.1). Without the field,
%   no clipping is done.

% Momentum buffers default to zero so the very first call does not fail
% when the caller has not initialized them.
if ~isfield(net, 'mw') || isempty(net.mw)
    net.mw = zeros(size(net.w));
end
if ~isfield(net, 'mv') || isempty(net.mv)
    net.mv = zeros(size(net.v));
end

% This iteration's plain gradient step, computed once and reused below
% (same evaluation order as before: (lr * d) / batch_size).
step_w = net.lr * net.d_w / net.batch_size;
step_v = net.lr * net.d_v / net.batch_size;

net.w = net.w - step_w - net.moment * net.mw;
net.v = net.v - step_v - net.moment * net.mv;
net.wb = net.wb - net.lr * sum(net.d_wb, 2) / net.batch_size;
net.vb = net.vb - net.lr * sum(net.d_vb, 2) / net.batch_size;

% Remember this step for the next call's momentum term.
net.mw = step_w;
net.mv = step_v;

% WGAN-style parameter clipping (the original author noted a clear
% improvement with it). The earlier commented-out version compared the
% wrong variable and used the invalid operator '=>'.
if isfield(net, 'clip') && isscalar(net.clip) && net.clip > 0
    net.w = min(max(net.w, -net.clip), net.clip);
    net.v = min(max(net.v, -net.clip), net.clip);
end

end
--------------------------------------------------------------------------------