Why are all the gradients of the generator zero (from the beginning to the end of training) when training a GAN?

I want to train a generator that produces samples of a sine function. However, when I train the GAN, all of the generator's gradients are zero. I do not know what the problem is. Could anyone help me?
The code is as follows:
batch_size = 64;
n_ideas = 5;
art_components = 15;
step = 2/(art_components-1);
points = -1:step:1;
paint_points = repmat(points,batch_size,1);
Generator = [
    featureInputLayer(n_ideas)
    fullyConnectedLayer(128)
    reluLayer
    fullyConnectedLayer(art_components)
    ];
Discriminator = [
    featureInputLayer(art_components)
    fullyConnectedLayer(128)
    reluLayer
    fullyConnectedLayer(1)
    sigmoidLayer
    ];
net_g = dlnetwork(Generator);
net_d = dlnetwork(Discriminator);
lr = 0.0001;
decay = 0.90;
sqdecay = 0.999;
avg_decay_g = [];
avd_sqdecay_g = [];
avg_decay_d = [];
avd_sqdecay_d = [];
for e=1:10000
    artis_paintings = dlarray(single(artist_work(art_components,paint_points)),'BC');
    % update learnable parameters of discriminator
    g_ideas = dlarray(single(randn(batch_size,n_ideas)),'BC');
    g_paintings = forward(net_g,g_ideas);
    [loss_d,gradient_d,score_d] = ...
        dlfeval(@d_loss,net_d,artis_paintings,g_paintings);
    [net_d, avg_decay_d, avd_sqdecay_d] = ...
        adamupdate(net_d,gradient_d,avg_decay_d,avd_sqdecay_d,e,lr,decay,sqdecay);
    % update learnable parameters of generator
    g_ideas = dlarray(single(randn(batch_size,n_ideas)),'BC');
    g_paintings = forward(net_g,g_ideas);
    prob_artist1 = forward(net_d,g_paintings);
    [loss_g,gradient_g,score_g] = ...
        dlfeval(@g_loss,net_g,prob_artist1);
    [net_g, avg_decay_g, avd_sqdecay_g] = ...
        adamupdate(net_g,gradient_g,avg_decay_g,avd_sqdecay_g,e,lr,decay,sqdecay);
end
function [loss_d,gradient_d,score_d] = ...
    d_loss(net_d,artis_paintings,g_paintings)
    % calculate loss
    prob_artist0 = forward(net_d,artis_paintings);
    prob_artist1 = forward(net_d,g_paintings);
    score_d = mean(1-prob_artist1);
    loss_d = -mean(log(prob_artist0)) - mean(log(1-prob_artist1));
    % calculate gradients
    gradient_d = dlgradient(loss_d, net_d.Learnables);
end
function [loss_g,gradient_g,score_g] = ...
    g_loss(net_g,prob_artist1)
    score_g = mean(prob_artist1);
    % calculate gradients
    loss_g = -mean(log(prob_artist1));
    gradient_g = dlgradient(loss_g, net_g.Learnables);
end
function paintings=artist_work(art_components,paint_points)
    r = 0.02 * randn(1,art_components);
    paintings = sin(paint_points *pi) + r;
end

Accepted Answer

Richard on 5 Nov 2022
All of the calculations that lie "between" the variables you want gradients with respect to and the loss value need to be contained inside the function that you pass to dlfeval. If they are not, dlgradient does not know they have occurred and sees no dependency between the outputs and the inputs, so the gradients are all zero.
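To see the effect in isolation, here is a minimal toy sketch, assuming only Deep Learning Toolbox (`dlarray`, `dlfeval`, `dlgradient`). The local function names `lossFromPrecomputed` and `lossFromInput` are hypothetical. The first call mirrors the question's generator step: `y` is computed outside `dlfeval` and passed in, so the traced loss has no path back to `x`. The second call computes `y` inside the traced function.

```matlab
% Toy illustration of dlfeval tracing (assumes Deep Learning Toolbox).
x = dlarray([1 2 3]);

% Mirrors the question: y is computed OUTSIDE dlfeval and passed in,
% so the trace records no path from x to the loss.
y = x.^2;
gOutside = dlfeval(@lossFromPrecomputed, x, y);

% Mirrors the answer: y is computed INSIDE the function given to dlfeval.
gInside = dlfeval(@lossFromInput, x);

disp(extractdata(gOutside))   % all zeros, the same symptom as in the question
disp(extractdata(gInside))    % d(sum(x.^2))/dx = 2*x, i.e. 2 4 6

% Hypothetical helper: the loss depends only on y, not on the traced x.
function g = lossFromPrecomputed(x, y)
    loss = sum(y);
    g = dlgradient(loss, x);
end

% Hypothetical helper: x.^2 is evaluated here, so it is recorded on the trace.
function g = lossFromInput(x)
    y = x.^2;
    loss = sum(y);
    g = dlgradient(loss, x);
end
```

The only difference between the two helpers is where `x.^2` is evaluated, which is exactly the difference between the question's `g_loss` and the corrected one.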
In this case, you must ensure that the "forward(net)" calls are inside the loss functions. You have done this correctly for the discriminator loss, but for the generator loss you need to pass in both the generator and discriminator networks and call forward on each one inside g_loss:
function [loss_g,gradient_g,score_g] = g_loss(net_g,net_d,g_ideas)
    % both forward passes run inside dlfeval, so they are recorded on the trace
    g_paintings = forward(net_g,g_ideas);
    prob_artist1 = forward(net_d,g_paintings);
    score_g = mean(prob_artist1);
    % generator loss
    loss_g = -mean(log(prob_artist1));
    % gradients with respect to the generator's learnables only
    gradient_g = dlgradient(loss_g, net_g.Learnables);
end
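With this signature, the generator step in the training loop passes the noise batch `g_ideas` to `dlfeval` instead of the precomputed `prob_artist1`. The following is a self-contained sketch of a single such step, assuming Deep Learning Toolbox and reusing the question's layer sizes; the `gradNorms` check is an illustrative addition, not part of the original answer.

```matlab
% Sketch: one generator step with the corrected g_loss signature.
% Assumes Deep Learning Toolbox; layer sizes copied from the question.
batch_size = 64;
n_ideas = 5;
art_components = 15;

Generator = [
    featureInputLayer(n_ideas)
    fullyConnectedLayer(128)
    reluLayer
    fullyConnectedLayer(art_components)
    ];
Discriminator = [
    featureInputLayer(art_components)
    fullyConnectedLayer(128)
    reluLayer
    fullyConnectedLayer(1)
    sigmoidLayer
    ];
net_g = dlnetwork(Generator);
net_d = dlnetwork(Discriminator);

% Pass the networks and the noise batch; both forward calls now happen in g_loss.
g_ideas = dlarray(single(randn(batch_size,n_ideas)),'BC');
[loss_g,gradient_g,score_g] = dlfeval(@g_loss,net_g,net_d,g_ideas);

% Illustrative check: total absolute gradient for each generator learnable.
% Every entry should now be positive rather than zero.
gradNorms = cellfun(@(g) sum(abs(extractdata(g)),'all'), gradient_g.Value);
disp(gradNorms')

function [loss_g,gradient_g,score_g] = g_loss(net_g,net_d,g_ideas)
    % both forward passes run inside dlfeval, so they are recorded on the trace
    g_paintings = forward(net_g,g_ideas);
    prob_artist1 = forward(net_d,g_paintings);
    score_g = mean(prob_artist1);
    loss_g = -mean(log(prob_artist1));
    gradient_g = dlgradient(loss_g, net_g.Learnables);
end
```

In the question's loop, the two `forward` calls just before the generator's `dlfeval` then become unnecessary, and the following `adamupdate` call stays as it is.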
2 comments
You Jinkun on 6 Nov 2022
Thank you for the helpful suggestions. It works! I had not noticed this when reading the documentation.
Richard on 6 Nov 2022
Thanks for the feedback @You Jinkun. I will submit a documentation enhancement request regarding this aspect of the dlfeval/dlgradient interaction.
You can also use the "How useful was this information?" section at the bottom of any of our doc pages to directly submit feedback to our doc team if there is a specific page that you think could be improved (clicking on the rating opens a text field for submitting a specific comment).
