Evaluar el rendimiento de redes neuronales de regresión
Cree un modelo de red neuronal predictiva de regresión con capas interconectadas usando fitrnet. Utilice datos de validación para detener el proceso de entrenamiento antes de tiempo a fin de evitar el sobreajuste del modelo. Después, utilice las funciones de objeto del modelo para evaluar su rendimiento en los datos de prueba.
Cargar datos de muestra
Cargue el conjunto de datos carbig, que contiene mediciones de coches fabricados en la década de los 70 y a principios de la década de los 80.
load carbigConvierta la variable Origin a una variable categórica. Después, cree una tabla que contenga las variables predictoras Acceleration, Displacement, etc., así como la variable de respuesta MPG. Cada fila contiene las mediciones de un solo coche. Elimine las filas de la tabla en las que la tabla tenga valores faltantes.
Origin = categorical(cellstr(Origin));
Tbl = table(Acceleration,Displacement,Horsepower, ...
Model_Year,Origin,Weight,MPG);
Tbl = rmmissing(Tbl);Dividir datos
Divida los datos en conjuntos de entrenamiento, validación y prueba. En primer lugar, reserve aproximadamente un tercio de las observaciones para el conjunto de prueba. Después, divida por la mitad los datos restantes para crear los conjuntos de entrenamiento y validación.
rng("default") % For reproducibility of the data partitions cvp1 = cvpartition(size(Tbl,1),"Holdout",1/3); testTbl = Tbl(test(cvp1),:); remainingTbl = Tbl(training(cvp1),:); cvp2 = cvpartition(size(remainingTbl,1),"Holdout",1/2); validationTbl = remainingTbl(test(cvp2),:); trainTbl = remainingTbl(training(cvp2),:);
Entrenar una red neuronal
Entrene un modelo de red neuronal de regresión utilizando el conjunto de entrenamiento. Especifique la columna MPG de tblTrain como la variable de respuesta y estandarice los predictores numéricos. Evalúe el modelo en cada iteración usando el conjunto de validación. Especifique si desea mostrar la información de entrenamiento en cada iteración utilizando el argumento nombre-valor Verbose. De forma predeterminada, el proceso de entrenamiento termina antes de tiempo si la pérdida de validación es mayor que o igual a la pérdida mínima de validación calculada hasta ese momento, seis veces seguidas. Para cambiar el número de veces que se permite que la pérdida de validación sea mayor que o igual al valor mínimo, especifique el argumento nombre-valor ValidationPatience.
Mdl = fitrnet(trainTbl,"MPG","Standardize",true, ... "ValidationData",validationTbl, ... "Verbose",1);
|==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 1| 102.962345| 46.853164| 6.700877| 0.032779| 115.730384| 0| | 2| 55.403995| 22.171181| 1.811805| 0.020571| 53.086379| 0| | 3| 37.588848| 11.135231| 0.782861| 0.005298| 38.580002| 0| | 4| 29.713458| 8.379231| 0.392009| 0.003921| 31.021379| 0| | 5| 17.523851| 9.958164| 2.137584| 0.003729| 17.594863| 0| | 6| 12.700624| 2.957771| 0.744551| 0.003962| 14.209019| 0| | 7| 11.841152| 1.907378| 0.201770| 0.003880| 13.159899| 0| | 8| 10.162988| 2.542555| 0.576907| 0.003956| 11.352490| 0| | 9| 8.889095| 2.779980| 0.615716| 0.002668| 10.446334| 0| | 10| 7.670335| 2.400272| 0.648711| 0.011382| 10.424337| 0| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 11| 7.416274| 0.505111| 0.214707| 0.005407| 10.522517| 1| | 12| 7.338923| 0.880655| 0.119085| 0.004414| 10.648031| 2| | 13| 7.149407| 1.784821| 0.277908| 0.002899| 10.800952| 3| | 14| 6.866385| 1.904480| 0.472190| 0.005637| 10.839202| 4| | 15| 6.815575| 3.339285| 0.943063| 0.002956| 10.031692| 0| | 16| 6.428137| 0.684771| 0.133729| 0.003287| 9.867819| 0| | 17| 6.363299| 0.456606| 0.125363| 0.006535| 9.720076| 0| | 18| 6.289887| 0.742923| 0.152290| 0.009971| 9.576588| 0| | 19| 6.215407| 0.964684| 0.183503| 0.002971| 9.422910| 0| | 20| 6.078333| 2.124971| 0.566948| 0.002843| 9.599573| 1| |==========================================================================================| | Iteration | Train Loss | Gradient | Step | Iteration | Validation | Validation | | | | | | Time (sec) | Loss | Checks | |==========================================================================================| | 21| 5.947923| 1.217291| 0.583867| 0.003745| 9.618400| 2| | 22| 5.855505| 0.671774| 0.285123| 0.002729| 9.734680| 3| | 23| 5.831802| 1.882061| 0.657368| 0.001770| 10.365968| 4| | 24| 5.713261| 1.004072| 0.134719| 0.001882| 10.314258| 5| | 25| 5.520766| 0.967032| 0.290156| 0.001891| 10.177322| 6| |==========================================================================================|
Utilice la información que se encuentra dentro de la propiedad TrainingHistory del objeto Mdl para comprobar la iteración que se corresponde con el valor mínimo del error cuadrático medio (MSE) de la validación. El modelo devuelto final Mdl es el modelo entrenado en esta iteración.
iteration = Mdl.TrainingHistory.Iteration; valLosses = Mdl.TrainingHistory.ValidationLoss; [~,minIdx] = min(valLosses); iteration(minIdx)
ans = 19
Evaluar el rendimiento de un conjunto de prueba
Evalúe el rendimiento del modelo entrenado Mdl en el conjunto de prueba testTbl utilizando las funciones de objeto loss y predict.
Calcule el error cuadrático medio (MSE) del conjunto de prueba. Los valores de MSE menores indican un mejor rendimiento.
mse = loss(Mdl,testTbl,"MPG")mse = 7.4101
Compare los valores de respuesta del conjunto de prueba pronosticados con los valores de respuesta verdaderos. Represente las millas por galón (MPG) pronosticadas a lo largo del eje vertical y las MPG verdaderas a lo largo del eje horizontal. Los puntos de la línea de referencia indican predicciones correctas. Un buen modelo produce predicciones que aparecen dispersas cerca de la línea.
predictedY = predict(Mdl,testTbl); plot(testTbl.MPG,predictedY,".") hold on plot(testTbl.MPG,testTbl.MPG) hold off xlabel("True Miles Per Gallon (MPG)") ylabel("Predicted Miles Per Gallon (MPG)")

Utilice gráficas de caja para comparar la distribución de los valores de MPG pronosticados y verdaderos por país de origen. Cree las gráficas de caja utilizando la función boxchart. Cada gráfica de caja muestra la mediana, el cuartil inferior y el cuartil superior, cualquier valor atípico (calculado usando el rango intercuartil) y los valores mínimo y máximo que no son valores atípicos. En concreto, la línea que se encuentra dentro de cada caja es la mediana de la muestra y los marcadores circulares indican valores atípicos.
Para cada país de origen, compare la gráfica de caja roja (que muestra la distribución de los valores de MPG pronosticados) con la gráfica de caja azul (que muestra la distribución de los valores de MPG verdaderos). Distribuciones similares para los valores de MPG predichos y verdaderos indican buenas predicciones.
boxchart(testTbl.Origin,testTbl.MPG) hold on boxchart(testTbl.Origin,predictedY) hold off legend(["True MPG","Predicted MPG"]) xlabel("Country of Origin") ylabel("Miles Per Gallon (MPG)")

Para la mayoría de países, los valores de MPG predichos y verdaderos tienen distribuciones similares. Algunas discrepancias pueden deberse al número reducido de coches en los conjuntos de entrenamiento y prueba.
Compare el rango de valores de MPG para los coches de los conjuntos de entrenamiento y prueba.
trainSummary = grpstats(trainTbl(:,["MPG","Origin"]),"Origin", ... "range")
trainSummary=6×3 table
Origin GroupCount range_MPG
_______ __________ _________
France France 2 1.2
Germany Germany 12 23.4
Italy Italy 1 0
Japan Japan 26 26.6
Sweden Sweden 4 8
USA USA 86 27
testSummary = grpstats(testTbl(:,["MPG","Origin"]),"Origin", ... "range")
testSummary=6×3 table
Origin GroupCount range_MPG
_______ __________ _________
France France 4 19.8
Germany Germany 13 20.3
Italy Italy 4 11.3
Japan Japan 26 25.6
Sweden Sweden 1 0
USA USA 82 29
Para países como Francia, Italia y Suecia, que tienen pocos coches en los conjuntos de entrenamiento y prueba, el rango de los valores de MPG varían significativamente en ambos conjuntos.
Represente los valores residuales del conjunto de prueba. Normalmente, un buen modelo tiene valores residuales dispersos de forma casi simétrica en torno a 0. Los patrones claros en los valores residuales son una señal de que puede mejorar el modelo.
residuals = testTbl.MPG - predictedY; plot(testTbl.MPG,residuals,".") hold on yline(0) hold off xlabel("True Miles Per Gallon (MPG)") ylabel("MPG Residuals")

La gráfica sugiere que los valores residuales están bien distribuidos.
Puede obtener más información sobre las observaciones con los mayores valores residuales, en términos de valor absoluto.
[~,residualIdx] = sort(residuals,"descend", ... "ComparisonMethod","abs"); residuals(residualIdx)
ans = 130×1
-8.8469
8.4427
8.0493
7.8996
-6.2220
5.8589
5.7007
-5.6733
-5.4545
5.1899
-4.9175
-4.8600
4.5415
-4.3959
-4.3915
⋮
Muestre las tres observaciones con los mayores valores residuales, es decir, con magnitudes superiores a 8.
testTbl(residualIdx(1:3),:)
ans=3×7 table
Acceleration Displacement Horsepower Model_Year Origin Weight MPG
____________ ____________ __________ __________ ______ ______ ____
17.6 91 68 82 Japan 1970 31
11.4 168 132 80 Japan 2910 32.7
13.8 91 67 80 Japan 1850 44.6
Consulte también
fitrnet | loss | predict | RegressionNeuralNetwork | boxchart