Entrenar redes generativas antagónicas condicionales (CGAN)
En este ejemplo se muestra cómo entrenar una red generativa antagónica condicional para generar imágenes.
Una red generativa antagónica (GAN) es un tipo de red de deep learning que puede generar datos con características similares a las de los datos de entrenamiento de entrada.
Una GAN consta de dos redes que se entrenan juntas:
Generador: dado un vector de valores aleatorios como entrada, esta red genera datos con la misma estructura que los datos de entrenamiento.
Discriminador: dados los lotes de datos que contienen observaciones de los datos de entrenamiento y de los datos generados por el generador, esta red intenta clasificar las observaciones como
"real"o"generated".

Una red generativa antagónica condicional (CGAN) es un tipo de GAN que también aprovecha las etiquetas durante el proceso de entrenamiento.
Generador: dados una etiqueta y un arreglo aleatorios como entrada, esta red genera datos con la misma estructura que las observaciones de datos de entrenamiento correspondientes a la misma etiqueta.
Discriminador: dados los lotes de datos etiquetados que contienen observaciones de los datos de entrenamiento y de los datos generados por el generador, esta red intenta clasificar las observaciones como
"real"o"generated".

Para entrenar una GAN condicional, entrene las dos redes simultáneamente para maximizar el rendimiento de ambas:
Entrene el generador para generar datos que "engañen" al discriminador.
Entrene el discriminador para distinguir entre datos reales y generados.
Para maximizar el rendimiento del generador, maximice la pérdida del discriminador cuando se proporcionen datos etiquetados generados. Es decir, el objetivo del generador es generar datos etiquetados que el discriminador clasifique como "real".
Para maximizar el rendimiento del discriminador, minimice la pérdida del discriminador cuando se proporcionen lotes de datos reales y etiquetados generados. Es decir, el objetivo del discriminador es no ser "engañado" por el generador.
Idealmente, estas estrategias dan como resultado un generador que genera datos convincentemente realistas que corresponden a las etiquetas de entrada y un discriminador que ha aprendido representaciones de características fuertes que son representativas de los datos de entrenamiento de cada etiqueta.
Cargar los datos de entrenamiento
Descargue y extraiga el conjunto de datos Flowers [1].
url = "http://download.tensorflow.org/example_images/flower_photos.tgz"; downloadFolder = tempdir; filename = fullfile(downloadFolder,"flower_dataset.tgz"); imageFolder = fullfile(downloadFolder,"flower_photos"); if ~exist(imageFolder,"dir") disp("Downloading Flowers data set (218 MB)...") websave(filename,url); untar(filename,downloadFolder) end
Cree un almacén de datos de imágenes que contenga las fotos de las flores.
datasetFolder = fullfile(imageFolder);
imds = imageDatastore(datasetFolder,IncludeSubfolders=true,LabelSource="foldernames");Visualice el número de clases.
classes = categories(imds.Labels); numClasses = numel(classes)
numClasses = 5
Aumente los datos para incluir volteo horizontal aleatorio y cambie el tamaño de las imágenes para que tengan un tamaño de 64 por 64.
augmenter = imageDataAugmenter(RandXReflection=true); augimds = augmentedImageDatastore([64 64],imds,DataAugmentation=augmenter);
Definir la red del generador
Defina la siguiente red de dos entradas, que genera imágenes dados vectores aleatorios de tamaño 100 y las etiquetas correspondientes.

Esta red:
Convierte los vectores aleatorios de tamaño 100 a arreglos de 4 por 4 por 1024 utilizando una capa totalmente conectada seguida de una operación de remodelación.
Convierte las etiquetas categóricas a vectores de incrustación y cambia su forma a un arreglo de 4 por 4.
Concatena las imágenes resultantes de las dos entradas en la dimensión de canal. La salida es un arreglo de 4 por 4 por 1025.
Mejora los arreglos resultantes para que sean arreglos de 64 por 64 por 3 usando una serie de capas de convolución traspuesta con normalización de lotes y capas ReLU.
Defina esta arquitectura de la red y especifique las siguientes propiedades de red.
Para las entradas categóricas, utilice una dimensión de incrustación de 50.
Para las capas de convolución traspuesta, especifique filtros de 5 por 5 con un número descendente de filtros para cada capa, un tramo de 2 y recorte
"same"de la salida.Para la última capa de convolución traspuesta, especifique un filtro de 3 por 5 por 5 que corresponda a los tres canales RGB de las imágenes generadas.
Al final de la red, incluya una capa tanh.
Para proyectar y remodelar la entrada de ruido, use una capa totalmente conectada seguida de una operación de remodelación especificada como una capa de función con la función dada por la función feature2image, que se adjunta a este ejemplo como archivo de soporte. Para incrustar las etiquetas categóricas, use la capa personalizada embeddingLayer que se adjunta a este ejemplo como archivo de soporte. Para acceder a estos archivos de soporte, abra el ejemplo como un script en vivo.
numLatentInputs = 100;
embeddingDimension = 50;
numFilters = 64;
filterSize = 5;
projectionSize = [4 4 1024];
netG = dlnetwork;
layers = [
featureInputLayer(numLatentInputs)
fullyConnectedLayer(prod(projectionSize))
functionLayer(@(X) feature2image(X,projectionSize),Formattable=true)
concatenationLayer(3,2,Name="cat");
transposedConv2dLayer(filterSize,4*numFilters)
batchNormalizationLayer
reluLayer
transposedConv2dLayer(filterSize,2*numFilters,Stride=2,Cropping="same")
batchNormalizationLayer
reluLayer
transposedConv2dLayer(filterSize,numFilters,Stride=2,Cropping="same")
batchNormalizationLayer
reluLayer
transposedConv2dLayer(filterSize,3,Stride=2,Cropping="same")
tanhLayer];
netG = addLayers(netG,layers);
layers = [
featureInputLayer(1)
embeddingLayer(embeddingDimension,numClasses)
fullyConnectedLayer(prod(projectionSize(1:2)))
functionLayer(@(X) feature2image(X,[projectionSize(1:2) 1]),Formattable=true,Name="emb_reshape")];
netG = addLayers(netG,layers);
netG = connectLayers(netG,"emb_reshape","cat/in2");Para entrenar la red con un bucle de entrenamiento personalizado, inicialice el objeto dlnetwork.
netG = initialize(netG)
netG =
dlnetwork with properties:
Layers: [19×1 nnet.cnn.layer.Layer]
Connections: [18×2 table]
Learnables: [19×3 table]
State: [6×3 table]
InputNames: {'input' 'input_1'}
OutputNames: {'layer_2'}
Initialized: 1
View summary with summary.
Definir la red del discriminador
Defina la siguiente red de dos entradas, que clasifica imágenes de 64 por 64 reales y generadas dado un conjunto de imágenes y las etiquetas correspondientes.

Cree una red que tome como entrada imágenes de 64 por 64 por 1 y las etiquetas correspondientes y que genere como salida una puntuación de predicción escalar usando una serie de capas de convolución con normalización de lotes y capas ReLU con fugas. Añada ruido a las imágenes de entrada mediante abandono.
Para la capa de abandono, especifique una probabilidad de abandono de 0.75.
Para las capas de convolución, especifique filtros de 5 por 5 con un número ascendente de filtros para cada capa. Especifique también un tramo de 2 y el relleno de la salida en cada borde.
Para las capas ReLU con fugas, especifique una escala de 0.2.
Para la capa final, especifique una capa de convolución con un filtro de 4 por 4.
dropoutProb = 0.75;
numFilters = 64;
scale = 0.2;
inputSize = [64 64 3];
filterSize = 5;
netD = dlnetwork;
layers = [
imageInputLayer(inputSize,Normalization="none")
dropoutLayer(dropoutProb)
concatenationLayer(3,2,Name="cat")
convolution2dLayer(filterSize,numFilters,Stride=2,Padding="same")
leakyReluLayer(scale)
convolution2dLayer(filterSize,2*numFilters,Stride=2,Padding="same")
batchNormalizationLayer
leakyReluLayer(scale)
convolution2dLayer(filterSize,4*numFilters,Stride=2,Padding="same")
batchNormalizationLayer
leakyReluLayer(scale)
convolution2dLayer(filterSize,8*numFilters,Stride=2,Padding="same")
batchNormalizationLayer
leakyReluLayer(scale)
convolution2dLayer(4,1)];
netD = addLayers(netD,layers);
layers = [
featureInputLayer(1)
embeddingLayer(embeddingDimension,numClasses)
fullyConnectedLayer(prod(inputSize(1:2)))
functionLayer(@(X) feature2image(X,[inputSize(1:2) 1]),Formattable=true,Name="emb_reshape")];
netD = addLayers(netD,layers);
netD = connectLayers(netD,"emb_reshape","cat/in2");Para entrenar la red con un bucle de entrenamiento personalizado y habilitar la diferenciación automática, inicialice el objeto dlnetwork.
netD = initialize(netD)
netD =
dlnetwork with properties:
Layers: [19×1 nnet.cnn.layer.Layer]
Connections: [18×2 table]
Learnables: [19×3 table]
State: [6×3 table]
InputNames: {'imageinput' 'input'}
OutputNames: {'conv_5'}
Initialized: 1
View summary with summary.
Definir las funciones de pérdida del modelo
Cree la función modelLoss, enumerada en la sección Función de pérdida del modelo del ejemplo, que toma como entrada las redes del generador y el discriminador, un minilote de datos de entrada y un arreglo de valores aleatorios, y devuelve los gradientes de la pérdida con respecto a los parámetros que se pueden aprender en las redes y un arreglo de imágenes generadas.
Especificar las opciones de entrenamiento
Entrene con un tamaño de minilote de 128 durante 500 épocas.
numEpochs = 500; miniBatchSize = 128;
Especifique las opciones para la optimización de Adam. Para ambas redes, utilice:
Una tasa de aprendizaje de 0.0002
Un factor de decaimiento de gradiente de 0.5
Un factor de decaimiento de gradiente cuadrado de 0.999
learnRate = 0.0002; gradientDecayFactor = 0.5; squaredGradientDecayFactor = 0.999;
Actualice la gráfica del progreso del entrenamiento cada 100 iteraciones.
validationFrequency = 100;
Si el discriminador aprende a discriminar entre imágenes reales y generadas demasiado rápidamente, es posible que el generador no sea capaz de entrenarse. Para equilibrar mejor el aprendizaje del discriminador y el generador, voltee aleatoriamente las etiquetas de una parte de las imágenes reales. Especifique un factor de volteo de 0.5.
flipFactor = 0.5;
Entrenar un modelo
Entrene el modelo con un bucle de entrenamiento personalizado. Pase en bucle por los datos de entrenamiento y actualice los parámetros de la red en cada iteración. Para monitorizar el progreso del entrenamiento, muestre un lote de imágenes generadas usando un arreglo de retención de valores aleatorios para introducir en el generador y las puntuaciones de la red.
Utilice minibatchqueue para procesar y gestionar los minilotes de imágenes durante el entrenamiento. Para cada minilote:
Utilice la función de preprocesamiento de minilotes personalizada
preprocessMiniBatch(definida al final de este ejemplo) para volver a escalar las imágenes en el intervalo[-1,1].Descarte cualquier minilote parcial con menos de 128 observaciones.
Dé formato a los datos de imagen con las etiquetas de dimensión
"SSCB"(espacial, espacial, canal, lote).Dé formato a los datos de etiqueta con las etiquetas de dimensión
"BC"(lote, canal).Entrene en una GPU, si se dispone de ella. Cuando la opción
OutputEnvironmentdeminibatchqueuees"auto",minibatchqueueconvierte cada salida a ungpuArraysi hay una GPU disponible. Utilizar una GPU requiere Parallel Computing Toolbox™ y un dispositivo GPU compatible. Para obtener información sobre los dispositivos compatibles, consulte GPU Computing Requirements (Parallel Computing Toolbox).
De forma predeterminada, el objeto minibatchqueue convierte los datos a objetos dlarray con el tipo subyacente single.
augimds.MiniBatchSize = miniBatchSize; executionEnvironment = "auto"; mbq = minibatchqueue(augimds, ... MiniBatchSize=miniBatchSize, ... PartialMiniBatch="discard", ... MiniBatchFcn=@preprocessData, ... MiniBatchFormat=["SSCB" "BC"], ... OutputEnvironment=executionEnvironment);
Inicialice los parámetros para el optimizador Adam.
velocityD = []; trailingAvgG = []; trailingAvgSqG = []; trailingAvgD = []; trailingAvgSqD = [];
Para monitorizar el progreso de entrenamiento, cree un lote de retención de 25 vectores aleatorios y un conjunto correspondiente de etiquetas de 1 a 5 (correspondientes a las clases) repetido cinco veces.
numValidationImagesPerClass = 5;
ZValidation = randn(numLatentInputs,numValidationImagesPerClass*numClasses,"single");
TValidation = single(repmat(1:numClasses,[1 numValidationImagesPerClass]));Convierta los datos a objetos dlarray y especifique las etiquetas de dimensión "CB" (canal, lote).
ZValidation = dlarray(ZValidation,"CB"); TValidation = dlarray(TValidation,"CB");
Para el entrenamiento en GPU, convierta los datos a objetos gpuArray.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" ZValidation = gpuArray(ZValidation); TValidation = gpuArray(TValidation); end
Para realizar un seguimiento de las puntuaciones del generador y del discriminador, use un objeto trainingProgressMonitor. Calcule el número total de iteraciones para la monitorización.
numObservationsTrain = numel(imds.Files); numIterationsPerEpoch = floor(numObservationsTrain / miniBatchSize); numIterations = numEpochs * numIterationsPerEpoch;
Inicialice el objeto TrainingProgressMonitor. Dado que el cronómetro empieza cuando crea el objeto de monitorización, asegúrese de crear el objeto cerca del bucle de entrenamiento.
monitor = trainingProgressMonitor( ... Metrics=["GeneratorScore","DiscriminatorScore"], ... Info=["Epoch","Iteration"], ... XLabel="Iteration"); groupSubPlot(monitor,Score=["GeneratorScore","DiscriminatorScore"])
Entrene la GAN condicional. Para cada época, cambie el orden de los datos y pase en bucle por minilotes de datos.
Para cada minilote:
Se detiene si la propiedad
Stopdel objetoTrainingProgressMonitorestrue. La propiedadStopcambia atruecuando hace clic en el botón Stop.Evalúe los gradientes de la pérdida con respecto a los parámetros que se pueden aprender, el estado del generador y las puntuaciones de la red usando
dlfevaly la funciónmodelLoss.Actualice los parámetros de red con la función
adamupdate.Represente las puntuaciones de las dos redes.
Después de cada
validationFrequencyiteraciones, muestre un lote de imágenes generadas para una entrada de generador de retención fija.
La ejecución del entrenamiento puede tardar algún tiempo.
epoch = 0; iteration = 0; % Loop over epochs. while epoch < numEpochs && ~monitor.Stop epoch = epoch + 1; % Reset and shuffle data. shuffle(mbq); % Loop over mini-batches. while hasdata(mbq) && ~monitor.Stop iteration = iteration + 1; % Read mini-batch of data. [X,T] = next(mbq); % Generate latent inputs for the generator network. Convert to % dlarray and specify the dimension labels "CB" (channel, batch). % If training on a GPU, then convert latent inputs to gpuArray. Z = randn(numLatentInputs,miniBatchSize,"single"); Z = dlarray(Z,"CB"); if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" Z = gpuArray(Z); end % Evaluate the gradients of the loss with respect to the learnable % parameters, the generator state, and the network scores using % dlfeval and the modelLoss function. [~,~,gradientsG,gradientsD,stateG,scoreG,scoreD] = ... dlfeval(@modelLoss,netG,netD,X,T,Z,flipFactor); netG.State = stateG; % Update the discriminator network parameters. [netD,trailingAvgD,trailingAvgSqD] = adamupdate(netD, gradientsD, ... trailingAvgD, trailingAvgSqD, iteration, ... learnRate, gradientDecayFactor, squaredGradientDecayFactor); % Update the generator network parameters. [netG,trailingAvgG,trailingAvgSqG] = ... adamupdate(netG, gradientsG, ... trailingAvgG, trailingAvgSqG, iteration, ... learnRate, gradientDecayFactor, squaredGradientDecayFactor); % Every validationFrequency iterations, display batch of generated images using the % held-out generator input. if mod(iteration,validationFrequency) == 0 || iteration == 1 % Generate images using the held-out generator input. XGeneratedValidation = predict(netG,ZValidation,TValidation); % Tile and rescale the images in the range [0 1]. I = imtile(extractdata(XGeneratedValidation), ... GridSize=[numValidationImagesPerClass numClasses]); I = rescale(I); % Display the images. image(I) xticklabels([]); yticklabels([]); title("Generated Images"); end % Update the training progress monitor. recordMetrics(monitor,iteration, ... GeneratorScore=scoreG, ... DiscriminatorScore=scoreD); updateInfo(monitor,Epoch=epoch,Iteration=iteration); monitor.Progress = 100*iteration/numIterations; end end


En este caso, el discriminador ha aprendido una representación fuerte que identifica imágenes reales entre las imágenes generadas. A su vez, el generador ha aprendido una representación de características de similar fuerza que permite generar imágenes parecidas a los datos de entrenamiento. Cada columna corresponde a una sola clase.
La gráfica de entrenamiento muestra las puntuaciones de las redes del generador y el discriminador. Para obtener más información sobre cómo interpretar las puntuaciones de las redes, consulte Monitor GAN Training Progress and Identify Common Failure Modes.
Generar imágenes nuevas
Para generar imágenes nuevas de una clase concreta, use la función predict en el generador con un objeto dlarray que contenga un lote de vectores aleatorios y un arreglo de etiquetas correspondientes a las clases deseadas. Convierta los datos a objetos dlarray y especifique las etiquetas de dimensión "CB" (canal, lote). Para la predicción en GPU, convierta los datos a objetos gpuArray. Para mostrar las imágenes juntas, use la función imtile y vuelva a escalar las imágenes con la función rescale.
Cree un arreglo de 36 vectores de valores aleatorios correspondientes a la primera clase.
numObservationsNew = 36;
idxClass = 1;
ZNew = randn(numLatentInputs,numObservationsNew,"single");
TNew = repmat(single(idxClass),[1 numObservationsNew]);Convierta los datos a objetos dlarray con las etiquetas de dimensión "SSCB" (espacial, espacial, canales, lote).
ZNew = dlarray(ZNew,"CB"); TNew = dlarray(TNew,"CB");
Para generar imágenes con la GPU, también hay que convertir los datos a objetos gpuArray.
if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" ZNew = gpuArray(ZNew); TNew = gpuArray(TNew); end
Genere imágenes usando la función predict con la red del generador.
XGeneratedNew = predict(netG,ZNew,TNew);
Muestre las imágenes generadas en una gráfica.
figure
I = imtile(extractdata(XGeneratedNew));
I = rescale(I);
imshow(I)
title("Class: " + classes(idxClass))
En este caso, la red del generador genera imágenes condicionadas en la clase especificada.
Función de pérdida del modelo
La función modelLoss toma como entrada los objetos dlnetwork del generador y el discriminador (netG y netD), un minilote de datos de entrada X, las etiquetas correspondientes T y un arreglo de valores aleatorios Z, y devuelve los gradientes de la pérdida con respecto a los parámetros que se pueden aprender en las redes, el estado del generador y las puntuaciones de red.
Si el discriminador aprende a discriminar entre imágenes reales y generadas demasiado rápidamente, es posible que el generador no sea capaz de entrenarse. Para equilibrar mejor el aprendizaje del discriminador y el generador, voltee aleatoriamente las etiquetas de una parte de las imágenes reales.
function [lossG,lossD,gradientsG,gradientsD,stateG,scoreG,scoreD] = ... modelLoss(netG,netD,X,T,Z,flipFactor) % Calculate the predictions for real data with the discriminator network. YReal = forward(netD,X,T); % Calculate the predictions for generated data with the discriminator network. [XGenerated,stateG] = forward(netG,Z,T); YGenerated = forward(netD,XGenerated,T); % Calculate probabilities. probGenerated = sigmoid(YGenerated); probReal = sigmoid(YReal); % Calculate the generator and discriminator scores. scoreG = mean(probGenerated); scoreD = (mean(probReal) + mean(1-probGenerated)) / 2; % Flip labels. numObservations = size(YReal,4); idx = randperm(numObservations,floor(flipFactor * numObservations)); probReal(:,:,:,idx) = 1 - probReal(:,:,:,idx); % Calculate the GAN loss. [lossG, lossD] = ganLoss(probReal,probGenerated); % For each network, calculate the gradients with respect to the loss. gradientsG = dlgradient(lossG,netG.Learnables,RetainData=true); gradientsD = dlgradient(lossD,netD.Learnables); end
Función de pérdida GAN
El objetivo del generador es generar datos que el discriminador clasifique como "real". Para maximizar la probabilidad de que las imágenes del generador sean clasificadas como reales por el discriminador, minimice la función de verosimilitud logarítmica negativa.
Dada la salida del discriminador:
es la probabilidad de que la imagen de entrada pertenezca a la clase
"real".es la probabilidad de que la imagen de entrada pertenezca a la clase
"generated".
Observe que la operación sigmoide tiene lugar en la función modelLoss. La función de pérdida para el generador viene dada por
donde contiene las probabilidades de salida del discriminador para las imágenes generadas.
El objetivo del discriminador es no ser "engañado" por el generador. Para maximizar la probabilidad de que el discriminador discrimine correctamente entre las imágenes reales y generadas, minimice la suma de las correspondientes funciones de verosimilitud logarítmica negativa. La función de pérdida para el discriminador viene dada por
donde contiene las probabilidades de salida del discriminador para las imágenes reales.
function [lossG, lossD] = ganLoss(scoresReal,scoresGenerated) % Calculate losses for the discriminator network. lossGenerated = -mean(log(1 - scoresGenerated)); lossReal = -mean(log(scoresReal)); % Combine the losses for the discriminator network. lossD = lossReal + lossGenerated; % Calculate the loss for the generator network. lossG = -mean(log(scoresGenerated)); end
Función de preprocesamiento de minilotes
La función preprocessMiniBatch preprocesa los datos dando los siguientes pasos:
Extraer los datos de imagen y etiqueta de los arreglos de celdas de entrada y concatenarlos en arreglos numéricos.
Volver a escalar las imágenes para que estén en el intervalo
[-1,1].
function [X,T] = preprocessData(XCell,TCell) % Extract image data from cell and concatenate X = cat(4,XCell{:}); % Extract label data from cell and concatenate T = cat(1,TCell{:}); % Rescale the images in the range [-1 1]. X = rescale(X,-1,1,InputMin=0,InputMax=255); end
Referencias
The TensorFlow Team. Flowers http://download.tensorflow.org/example_images/flower_photos.tgz
Consulte también
dlnetwork | forward | predict | dlarray | dlgradient | dlfeval | adamupdate
Temas
- Entrenar redes generativas antagónicas (GAN)
- Monitor GAN Training Progress and Identify Common Failure Modes
- Train Fast Style Transfer Network
- Generate Images Using Diffusion
- Define Custom Training Loops, Loss Functions, and Networks
- Entrenar una red con un bucle de entrenamiento personalizado
- Specify Training Options in Custom Training Loop
- Lista de capas de deep learning