Conditional GAN Training Error in trainGAN Function

Yang Liu on 19 Jun 2024
Commented: Yang Liu on 28 Jun 2024
I am trying to get conditional GAN training to work with a 2-D matrix input of size 14*8.
I am adapting the "GenerateSyntheticPumpSignalsUsingCGANExample" example by replacing its vector input with a 2-D matrix input.
The following error message appears:
It seems that there is a size mismatch in the modelGradients function. Since this is an official example, I have no idea how to revise it. Can someone give me a hint?
The input data is attached as test.mat.
The training script is attached as untitled3.m. I have also pasted it below.
clear;
%% Load the data
% LSTM_Reform_Data_SeriesData1_20210315_data001_for_GAN;
% load('LoadedData_20210315_data001_for_GAN.mat')
load('test.mat');
% load('test2.mat');
%% Generator Network
numFilters = 4;
numLatentInputs = 120;
projectionSize = [2 1 63];
numClasses = 2;
embeddingDimension = 120;
layersGenerator = [
    imageInputLayer([1 1 numLatentInputs],'Normalization','none','Name','Input')
    projectAndReshapeLayer(projectionSize,numLatentInputs,'ProjReshape');
    concatenationLayer(3,2,'Name','Concate1');
    transposedConv2dLayer([3 2],8*numFilters,'Stride',1,'Name','TransConv1') % 4*2*32
    batchNormalizationLayer('Name','BN1','Epsilon',1e-5)
    reluLayer('Name','Relu1')
    transposedConv2dLayer([2 2],4*numFilters,'Stride',2,'Name','TransConv2') % 8*4*16
    batchNormalizationLayer('Name','BN2','Epsilon',1e-5)
    reluLayer('Name','Relu2')
    transposedConv2dLayer([2 2],2*numFilters,'Stride',2,'Cropping',[2 1],'Name','TransConv3') % 12*6*8
    batchNormalizationLayer('Name','BN3','Epsilon',1e-5)
    reluLayer('Name','Relu3')
    transposedConv2dLayer([3 3],2*numFilters,'Stride',1,'Name','TransConv4') % 14*8*1
];
lgraphGenerator = layerGraph(layersGenerator);
layers = [
    imageInputLayer([1 1],'Name','Labels','Normalization','none')
    embedAndReshapeLayer(projectionSize(1:2),embeddingDimension,numClasses,'EmbedReshape1')];
lgraphGenerator = addLayers(lgraphGenerator,layers);
lgraphGenerator = connectLayers(lgraphGenerator,'EmbedReshape1','Concate1/in2');
subplot(1,2,1);
plot(lgraphGenerator);
dlnetGenerator = dlnetwork(lgraphGenerator);
%% Discriminator Network
scale = 0.2;
Input_Num_Feature = [14 8 1]; % The input data is [14 8 1]
layersDiscriminator = [
    imageInputLayer(Input_Num_Feature,'Normalization','none','Name','Input')
    concatenationLayer(3,2,'Name','Concate2')
    convolution2dLayer([2 2],4*numFilters,'Stride',1,'DilationFactor',2,'Padding',[0 0],'Name','Conv1')
    leakyReluLayer(scale,'Name','LeakyRelu1')
    convolution2dLayer([2 4],2*numFilters,'Stride',2,'DilationFactor',1,'Padding',[2 2],'Name','Conv2')
    leakyReluLayer(scale,'Name','LeakyRelu2')
    convolution2dLayer([2 2],numFilters,'Stride',2,'DilationFactor',1,'Padding',[0 0],'Name','Conv3')
    leakyReluLayer(scale,'Name','LeakyRelu3')
    convolution2dLayer([2 1],numFilters/2,'Stride',1,'DilationFactor',2,'Padding',[0 0],'Name','Conv4')
    leakyReluLayer(scale,'Name','LeakyRelu4')
    convolution2dLayer([2 2],numFilters/4,'Stride',1,'DilationFactor',1,'Padding',[0 0],'Name','Conv5')
];
lgraphDiscriminator = layerGraph(layersDiscriminator);
layers = [
    imageInputLayer([1 1],'Name','Labels','Normalization','none')
    embedAndReshapeLayer(Input_Num_Feature,embeddingDimension,numClasses,'EmbedReshape2')];
lgraphDiscriminator = addLayers(lgraphDiscriminator,layers);
lgraphDiscriminator = connectLayers(lgraphDiscriminator,'EmbedReshape2','Concate2/in2');
subplot(1,2,2);
plot(lgraphDiscriminator);
dlnetDiscriminator = dlnetwork(lgraphDiscriminator);
%% Train model
params.numLatentInputs = numLatentInputs;
params.numClasses = numClasses;
params.sizeData = [Input_Num_Feature length(Series_Fused_Label)];
params.numEpochs = 50;
params.miniBatchSize = 512;
% Specify the options for Adam optimizer
params.learnRate = 0.0002;
params.gradientDecayFactor = 0.5;
params.squaredGradientDecayFactor = 0.999;
executionEnvironment = "cpu";
params.executionEnvironment = executionEnvironment;
% for test, 14*8*30779
[dlnetGenerator,dlnetDiscriminator] =...
    trainGAN(dlnetGenerator,dlnetDiscriminator,Series_Fused_Expand_Norm_Input,Series_Fused_Label,params);

Accepted Answer

Garmit Pant on 21 Jun 2024
Hello Yang Liu,
From what I understand, you are following the “Generate Synthetic Signals Using Conditional GAN” MATLAB example to train a conditional GAN with a 2-D input.
The error occurs because the output size of the generator network does not match the input size of the discriminator network.
The discriminator network has been adapted correctly for your use case and expects a 14x8x1 input. However, the last transposed convolution layer of the generator ('TransConv4') has its number of filters set to 2*numFilters, which is 8, so the generator outputs a 14x8x8 array. Change the number of filters in that layer to 1 so that the generator outputs a single-channel 14x8x1 array:
transposedConv2dLayer([3 3],1,'Stride',1,'Name','TransConv4') % 14*8*1
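To confirm that the spatial sizes already line up and only the channel count was wrong, the height and width can be traced layer by layer. This is a minimal sketch: tconvOut and convOut are hypothetical helper names (not toolbox functions), and the size relations in the comments are assumptions taken from the transposedConv2dLayer and convolution2dLayer documentation, so check them there.

```matlab
% Assumed transposedConv2dLayer size relation (per spatial dimension):
%   stride*(inputSize - 1) + filterSize - total cropping,
% where 'Cropping',[a b] removes a rows top and bottom and b columns
% left and right, i.e. a total crop of 2*[a b].
tconvOut = @(in, filt, stride, cropTotal) stride.*(in - 1) + filt - cropTotal;

% Assumed convolution2dLayer size relation (per spatial dimension):
%   floor((inputSize + total padding - dilated filter size)/stride) + 1,
% where the dilated filter size is (filterSize - 1)*dilation + 1.
convOut = @(in, filt, stride, dil, padTotal) ...
    floor((in + padTotal - ((filt - 1).*dil + 1))./stride) + 1;

% Generator: start from the 2-by-1 projection, projectionSize(1:2)
g = [2 1];
g = tconvOut(g, [3 2], 1, [0 0]);       % TransConv1 -> [4 2]
g = tconvOut(g, [2 2], 2, [0 0]);       % TransConv2 -> [8 4]
g = tconvOut(g, [2 2], 2, 2*[2 1]);     % TransConv3 -> [12 6]
g = tconvOut(g, [3 3], 1, [0 0]);       % TransConv4 -> [14 8]

% Discriminator: start from the 14-by-8 input
d = [14 8];
d = convOut(d, [2 2], 1, 2, [0 0]);     % Conv1 -> [12 6]
d = convOut(d, [2 4], 2, 1, 2*[2 2]);   % Conv2 ('Padding',[2 2]) -> [8 4]
d = convOut(d, [2 2], 2, 1, [0 0]);     % Conv3 -> [4 2]
d = convOut(d, [2 1], 1, 2, [0 0]);     % Conv4 -> [2 2]
d = convOut(d, [2 2], 1, 1, [0 0]);     % Conv5 -> [1 1]

fprintf('Generator output: %d-by-%d; discriminator output: %d-by-%d\n', g, d);
```

Both traces end where expected (14-by-8 for the generator, 1-by-1 for the discriminator), so the number of channels, which equals the number of filters in TransConv4, is the only mismatch. For the actual networks, analyzeNetwork(lgraphGenerator) also lists the activation size of every layer.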
Additionally, comment out or remove the following lines in the ‘trainGAN.m’ file. They plot the spectra of the example's pump signals, so they will throw an error with your 2-D data.
% if mod(ct,50) == 0 || ct == 1
% % Generate signals using held-out generator input
% dlXGeneratedValidation = predict(dlnetGenerator, dlZValidation, dlTValidation);
% dlXGeneratedValidation = squeeze(extractdata(gather(dlXGeneratedValidation)));
%
% % Display spectra of validation signals
% subplot(1,2,1);
% pspectrum(dlXGeneratedValidation);
% set(gca, 'XScale', 'log')
% legend('Healthy', 'Faulty')
% title("Spectra of Generated Signals")
% end
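If you still want to watch the generator during training, one option is to replace the spectrum plot with an image of the generated 2-D samples. This is a sketch under stated assumptions, not part of the original example: XGen is a random stand-in for the generator output, which after the TransConv4 fix is assumed to be 14-by-8-by-1-by-N, with one sample per validation input in dlZValidation and dlTValidation.

```matlab
% Hypothetical stand-in for the generator output. Inside trainGAN, use instead:
%   XGen = extractdata(gather(predict(dlnetGenerator, dlZValidation, dlTValidation)));
XGen = rand(14, 8, 1, 2, 'single');

% Column-major reshape places the N samples side by side: 14-by-(8*N)
tiled = reshape(XGen, size(XGen, 1), []);

subplot(1, 2, 1);
imagesc(tiled);
colorbar
title("Generated Samples (one per validation input)")
```

Inside trainGAN, keep the original if mod(ct,50) == 0 || ct == 1 guard around this code so the plot refreshes only periodically.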
For further details, refer to the following MathWorks documentation:
  1. The “Input Arguments” section of the ‘transposedConv2dLayer’ reference page, which describes the layer's parameters: https://www.mathworks.com/help/releases/R2023b/deeplearning/ref/transposedconv2dlayer.html
I hope you find the above explanation and suggestions useful!
1 comment
Yang Liu on 28 Jun 2024
Dear Garmit,
Thank you so much for your kind help! That was my mistake; I overlooked it. I should have checked the last layer of the generator more carefully.
Yes, the signal-spectrum part of trainGAN should be commented out; otherwise it reports other errors. I will revise that part, since I want to observe the generated signals while I train the GAN.
Thanks again for your kind help!
Yang Liu



Category: Measurements and Feature Extraction

Release: R2023b
