MATLAB Answers

implementation of mini-batch stochastic gradient descent

konoha on 28 Mar 2021
Edited: konoha on 2 Apr 2021
I implemented mini-batch stochastic gradient descent but couldn't find the bug in my code.
I used this implementation for a classification problem, but all my final predictions are 0.
```matlab
% Network: 2 inputs -> 5 -> 5 -> 5 -> 1 output, sigmoid activation in every layer.
% Assumption (not stated in the question): x1, x2, label and tx1, tx2, tlabel
% are 1-by-N row vectors holding the training and test data.

% weights and biases initialized uniformly in [-1, 1]
W2 = -1+2*rand(5,2); W3 = -1+2*rand(5,5);
W4 = -1+2*rand(5,5); W5 = -1+2*rand(1,5);
b2 = -1+2*rand(5,1); b3 = -1+2*rand(5,1);
b4 = -1+2*rand(5,1); b5 = -1+2*rand(1,1);
eta = 5e-3;       % learning rate
iter = 1000;      % number of epochs (full passes over the training set)
batch_size = 50;  % num_data must be a multiple of batch_size for the reshape below
num_data = length(label);
loss_vec = zeros(1,iter);
tloss_vec = zeros(1,iter);
for it = 1:iter
    % mini-batch method: shuffle, then one column of rand_idx per batch
    rand_idx = randperm(num_data);
    rand_idx = reshape(rand_idx,[],num_data/batch_size);
    for idx = rand_idx  % a for loop over a matrix visits its columns, i.e. the batches
        % forward pass
        a2 = activate([x1(:,idx);x2(:,idx)], W2, b2);
        a3 = activate(a2,W3,b3);
        a4 = activate(a3,W4,b4);
        a5 = activate(a4,W5,b5);
        % backward pass: each delta is dLoss/d(pre-activation) of its layer
        delta5 = a5.*(1-a5).*(a5-label(idx));
        delta4 = a4.*(1-a4).*(W5'*delta5);
        delta3 = a3.*(1-a3).*(W4'*delta4);
        delta2 = a2.*(1-a2).*(W3'*delta3);
        % update weights and biases (gradients averaged over the batch)
        W2 = W2 - 1/length(idx)*eta*delta2*[x1(:,idx);x2(:,idx)]';
        W3 = W3 - 1/length(idx)*eta*delta3*a2';
        W4 = W4 - 1/length(idx)*eta*delta4*a3';
        W5 = W5 - 1/length(idx)*eta*delta5*a4';
        b2 = b2 - 1/length(idx)*eta*sum(delta2,2);
        b3 = b3 - 1/length(idx)*eta*sum(delta3,2);
        b4 = b4 - 1/length(idx)*eta*sum(delta4,2);
        b5 = b5 - 1/length(idx)*eta*sum(delta5,2);
    end
    % compute train loss and test loss once per epoch
    loss_vec(it) = 1/(2*num_data)*LossFunc(W2,W3,W4,W5,b2,b3,b4,b5,[x1;x2],label);
    tloss_vec(it) = 1/(2*numel(tlabel))*LossFunc(W2,W3,W4,W5,b2,b3,b4,b5,[tx1;tx2],tlabel);
end

%% cost function: sum of squared errors over all samples
function loss = LossFunc(W2,W3,W4,W5,b2,b3,b4,b5,x,y)
    a2 = activate(x, W2, b2);
    a3 = activate(a2, W3, b3);
    a4 = activate(a3, W4, b4);
    a5 = activate(a4, W5, b5);
    loss = norm(a5-y,2)^2;
end

%% prediction: round the sigmoid output to a 0/1 class label
function pred = predict(W2,W3,W4,W5,b2,b3,b4,b5,x)
    a2 = activate(x, W2, b2);
    a3 = activate(a2, W3, b3);
    a4 = activate(a3, W4, b4);
    a5 = activate(a4, W5, b5);
    pred = round(a5);
end

%% activation function: one fully connected layer followed by a sigmoid
function y = activate(x,W,b)
    y = 1./(1+exp(-(W*x+b)));
end
```
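The mini-batch split in the training loop relies on two MATLAB behaviours: `reshape` turns the shuffled index vector into a matrix with one column per batch, and a `for` loop over a matrix visits its columns one at a time. The sketch below uses hypothetical sizes (200 samples, batches of 50, not taken from the question) and checks that every sample lands in exactly one batch. It also makes the requirement explicit that `num_data` be a multiple of `batch_size`, otherwise the `reshape` call fails.

```matlab
% Sketch of how one epoch is split into mini-batches.
% num_data and batch_size are hypothetical example values.
num_data   = 200;
batch_size = 50;

% reshape only works if the batches divide the data evenly
assert(mod(num_data, batch_size) == 0, 'num_data must be a multiple of batch_size');

rand_idx = randperm(num_data);                 % shuffled indices 1..num_data
rand_idx = reshape(rand_idx, batch_size, []);  % one column per mini-batch

seen = [];
for idx = rand_idx                             % a for loop over a matrix visits its columns
    % idx is a batch_size-by-1 column of sample indices for this batch
    seen = [seen; idx];                        %#ok<AGROW> collect indices to verify coverage
end
```

`reshape(rand_idx, batch_size, [])` is equivalent to the question's `reshape(rand_idx,[],num_data/batch_size)`; both produce a `batch_size`-by-(`num_data/batch_size`) matrix.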

Answers (1)

Mahesh Taparia on 2 Apr 2021
You mentioned that you are implementing a classification network. In your code, you are using the square of the L2 norm to calculate the loss, and the loss derivative is also not correct during backpropagation. Moreover, since this is a classification network, use a classification loss such as cross-entropy or focal cross-entropy instead of the norm. Maybe this is the reason you are getting 0 every time.
Also, you can use MATLAB's built-in functions to perform backpropagation. For this, you can refer to the link given below:
Hope it will help!
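The cross-entropy suggestion can be made concrete for this network's sigmoid output. For a = sigmoid(z) and the binary cross-entropy L = -(y*log(a) + (1-y)*log(1-a)), the chain rule gives dL/dz = a - y, so the `a5.*(1-a5)` factor drops out of `delta5`; `delta4`, `delta3`, `delta2` and the weight updates keep the same form. The sketch below is an illustration under assumptions: the pre-activations `z5` and labels `y` are made-up example values, not data from the question. It compares the two output deltas and checks the cross-entropy one against a central finite difference.

```matlab
% Sketch: output-layer delta for a sigmoid output with binary cross-entropy (BCE)
% instead of squared error. z5 and y are hypothetical example values.
sig = @(z) 1 ./ (1 + exp(-z));
bce = @(a, y) -(y .* log(a) + (1 - y) .* log(1 - a));   % per-sample BCE loss

z5 = [-8, -0.5, 0.3, 1.7];   % pre-activations W5*a4 + b5 for 4 samples
y  = [ 1,  1,   0,   1  ];   % 0/1 labels
a5 = sig(z5);

delta5_mse = a5 .* (1 - a5) .* (a5 - y);   % delta used in the question's code
delta5_bce = a5 - y;                       % BCE: the sigmoid derivative cancels

% central finite-difference estimate of dL/dz5 for the BCE loss
h = 1e-6;
delta5_num = (bce(sig(z5 + h), y) - bce(sig(z5 - h), y)) / (2 * h);
fprintf('max |analytic - numeric| = %.2e\n', max(abs(delta5_bce - delta5_num)));
disp([delta5_mse; delta5_bce]);            % row 1: squared error, row 2: BCE
```

The first sample shows the saturation effect: its output is close to 0 while its label is 1, yet `a5.*(1-a5)` is tiny there, so the squared-error delta is orders of magnitude smaller than the cross-entropy one. Whether this explains the all-zero predictions depends on the data, which the thread does not show.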
konoha on 2 Apr 2021
The derivative of the MSE is -(y-f(x))f'(x). I don't follow your suggestions.
Thank you.
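One way to settle whether the backward pass is correct is a finite-difference gradient check. With the mean half squared error L = 1/(2B) * sum((a5 - y).^2) over a batch of B samples, the quoted derivative -(y-f(x))f'(x) becomes (a5 - y).*a5.*(1-a5) for a sigmoid output, which is exactly `delta5` in the code, and the question's updates step along the gradient of this L. The sketch below assumes random data, a hypothetical batch of 8 samples and the question's 2-5-5-5-1 layer sizes. It compares the question's backpropagation formulas with central differences of L; a very small relative error (for example below 1e-6) indicates the two agree. Anonymous functions replace the local functions so the snippet runs as a plain script.

```matlab
% Sketch: finite-difference check of the question's backpropagation.
act  = @(x, W, b) 1 ./ (1 + exp(-(W*x + b)));   % one sigmoid layer, as in activate()
fwd  = @(P, x) act(act(act(act(x, P{1}, P{2}), P{3}, P{4}), P{5}, P{6}), P{7}, P{8});
loss = @(P, x, y) 0.5 / size(x, 2) * sum((fwd(P, x) - y).^2);  % mean half squared error

% random parameters in [-1, 1], ordered {W2, b2, W3, b3, W4, b4, W5, b5}
sz = {[5 2], [5 1], [5 5], [5 1], [5 5], [5 1], [1 5], [1 1]};
P  = cellfun(@(s) -1 + 2*rand(s), sz, 'UniformOutput', false);

B = 8;                          % hypothetical batch size
x = rand(2, B);                 % 2 input features per sample
y = double(rand(1, B) > 0.5);   % random 0/1 labels

% analytic gradients with the same deltas as the question's code
a2 = act(x, P{1}, P{2});  a3 = act(a2, P{3}, P{4});
a4 = act(a3, P{5}, P{6}); a5 = act(a4, P{7}, P{8});
d5 = a5 .* (1 - a5) .* (a5 - y);
d4 = a4 .* (1 - a4) .* (P{7}' * d5);
d3 = a3 .* (1 - a3) .* (P{5}' * d4);
d2 = a2 .* (1 - a2) .* (P{3}' * d3);
G = {d2*x', sum(d2,2), d3*a2', sum(d3,2), d4*a3', sum(d4,2), d5*a4', sum(d5,2)};
G = cellfun(@(g) g / B, G, 'UniformOutput', false);   % matches the 1/length(idx) scaling

% central finite differences, one parameter entry at a time
h = 1e-5;
G_num = G;                      % same shapes, overwritten below
for k = 1:numel(P)
    for j = 1:numel(P{k})
        Pp = P; Pp{k}(j) = Pp{k}(j) + h;
        Pm = P; Pm{k}(j) = Pm{k}(j) - h;
        G_num{k}(j) = (loss(Pp, x, y) - loss(Pm, x, y)) / (2*h);
    end
end

% stack all gradients into single column vectors and compare
g     = cell2mat(cellfun(@(t) t(:), G,     'UniformOutput', false)');
g_num = cell2mat(cellfun(@(t) t(:), G_num, 'UniformOutput', false)');
rel_err = norm(g - g_num) / norm(g + g_num);
fprintf('relative error: %.2e\n', rel_err);
```

A check like this only confirms that the gradients match the chosen loss; it says nothing about whether squared error or cross-entropy is the better loss for the classification task.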
