Realizar predicciones con un objeto dlnetwork
En este ejemplo se muestra cómo realizar predicciones con un objeto dlnetwork haciendo bucles con minilotes.
Para conjuntos grandes de datos o cuando se predice con hardware con memoria limitada, realice predicciones haciendo bucles con minilotes de los datos usando la función minibatchpredict.
Cargar un objeto dlnetwork
Cargue un objeto dlnetwork entrenado y los nombres de clase correspondientes en el área de trabajo. La red neuronal tiene una entrada y dos salidas. Toma imágenes de dígitos escritos a mano como entrada y predice la etiqueta de los dígitos y el ángulo de rotación.
load dlnetDigitsCargar datos para predicción
Cargue los datos de prueba de dígitos para la predicción.
load DigitsDataTestVisualice los nombres de las clases.
classNames
classNames = 10×1 cell
{'0'}
{'1'}
{'2'}
{'3'}
{'4'}
{'5'}
{'6'}
{'7'}
{'8'}
{'9'}
Visualice algunas de las imágenes y las etiquetas y ángulos de rotación correspondientes.
numObservations = size(XTest,4); numPlots = 9; idx = randperm(numObservations,numPlots); figure for i = 1:numPlots nexttile(i) I = XTest(:,:,:,idx(i)); label = labelsTest(idx(i)); imshow(I) title("Label: " + string(label) + newline + "Angle: " + anglesTest(idx(i))) end

Hacer predicciones
Realice predicciones con la función minibatchpredict y convierta las puntuaciones de clasificación en etiquetas con la función scores2label. De forma predeterminada, la función minibatchpredict usa una GPU en caso de que esté disponible. Para utilizar una GPU se requiere una licencia de Parallel Computing Toolbox™ y un dispositivo GPU compatible. Para obtener información sobre los dispositivos compatibles, consulte GPU Computing Requirements (Parallel Computing Toolbox). De lo contrario, la función usa la CPU. Para seleccionar el entorno de ejecución manualmente, utilice el argumento ExecutionEnvironment de la función minibatchpredict.
[scoresTest,Y2Test] = minibatchpredict(net,XTest); Y1Test = scores2label(scoresTest,classNames);
Visualice algunas de las predicciones.
idx = randperm(numObservations,numPlots); figure for i = 1:numPlots nexttile(i) I = XTest(:,:,:,idx(i)); label = Y1Test(idx(i)); imshow(I) title("Label: " + string(label) + newline + "Angle: " + Y2Test(idx(i))) end

Consulte también
dlarray | dlnetwork | predict | minibatchqueue | onehotdecode
Temas
- Entrenar redes generativas antagónicas (GAN)
- Entrenar una red con un bucle de entrenamiento personalizado
- Define Model Loss Function for Custom Training Loop
- Update Batch Normalization Statistics in Custom Training Loop
- Define Custom Training Loops, Loss Functions, and Networks
- Make Predictions Using Model Function
- Specify Training Options in Custom Training Loop
- Lista de capas de deep learning
- Trucos y consejos de deep learning
- Información sobre la diferenciación automática