Entrenar una red con características numéricas
En este ejemplo se muestra cómo crear y entrenar una red neuronal sencilla para la clasificación de datos de características mediante deep learning.
Si tiene un conjunto de datos de características numéricas (por ejemplo, una colección de datos numéricos sin dimensiones espaciales ni temporales), puede entrenar una red de deep learning utilizando una capa de entrada de características. Para ver un ejemplo de cómo entrenar una red para la clasificación de imágenes, consulte Crear una red neuronal de deep learning sencilla para clasificación.
En este ejemplo se muestra cómo entrenar una red para clasificar la condición de los dientes del engranaje de un sistema de transmisión con una mezcla de lecturas numéricas de sensores, estadísticas y etiquetas categóricas.
Cargar datos
Cargue el conjunto de datos de la caja de engranajes para el entrenamiento. Este conjunto de datos está formado por 208 lecturas sintéticas de un sistema de engranajes formado por 18 lecturas numéricas y 3 etiquetas categóricas:
SigMean
: media de la señal de vibraciónSigMedian
: mediana de la señal de vibraciónSigRMS
: RMS de la señal de vibraciónSigVar
: varianza de la señal de vibraciónSigPeak
: pico de la señal de vibraciónSigPeak2Peak
: pico a pico de la señal de vibraciónSigSkewness
: asimetría de la señal de vibraciónSigKurtosis
: curtosis de la señal de vibraciónSigCrestFactor
: factor de cresta de la señal de vibraciónSigMAD
: MAD de la señal de vibraciónSigRangeCumSum
: suma de intervalos de la señal de vibraciónSigCorrDimension
: dimensión de correlación de la señal de vibraciónSigApproxEntropy
: entropía aproximada de la señal de vibraciónSigLyapExponent
: exponente de Lyapunov de la señal de vibraciónPeakFreq
: frecuencia picoHighFreqPower
: potencia de frecuencia altaEnvPower
: potencia de entornoPeakSpecKurtosis
: frecuencia pico de curtosis espectralSensorCondition
: condición de sensor, especificada como "desvío de sensor" o "sin desvío de sensor"ShaftCondition
: condición de eje, especificada como "desgaste de eje" o "sin desgaste de eje"GearToothCondition
: condición de diente de engranaje, especificada como "diente con error" o "diente sin error"
Lea los datos de la caja de engranajes del archivo CSV "transmissionCasingData.csv"
.
filename = "transmissionCasingData.csv"; tbl = readtable(filename,TextType="String");
Convierta las etiquetas para la predicción en categóricas utilizando la función convertvars
.
labelName = "GearToothCondition"; tbl = convertvars(tbl,labelName,"categorical");
Visualice las primeras filas de la tabla.
head(tbl)
SigMean SigMedian SigRMS SigVar SigPeak SigPeak2Peak SigSkewness SigKurtosis SigCrestFactor SigMAD SigRangeCumSum SigCorrDimension SigApproxEntropy SigLyapExponent PeakFreq HighFreqPower EnvPower PeakSpecKurtosis SensorCondition ShaftCondition GearToothCondition ________ _________ ______ _______ _______ ____________ ___________ ___________ ______________ _______ ______________ ________________ ________________ _______________ ________ _____________ ________ ________________ _______________ _______________ __________________ -0.94876 -0.9722 1.3726 0.98387 0.81571 3.6314 -0.041525 2.2666 2.0514 0.8081 28562 1.1429 0.031581 79.931 0 6.75e-06 3.23e-07 162.13 "Sensor Drift" "No Shaft Wear" No Tooth Fault -0.97537 -0.98958 1.3937 0.99105 0.81571 3.6314 -0.023777 2.2598 2.0203 0.81017 29418 1.1362 0.037835 70.325 0 5.08e-08 9.16e-08 226.12 "Sensor Drift" "No Shaft Wear" No Tooth Fault 1.0502 1.0267 1.4449 0.98491 2.8157 3.6314 -0.04162 2.2658 1.9487 0.80853 31710 1.1479 0.031565 125.19 0 6.74e-06 2.85e-07 162.13 "Sensor Drift" "Shaft Wear" No Tooth Fault 1.0227 1.0045 1.4288 0.99553 2.8157 3.6314 -0.016356 2.2483 1.9707 0.81324 30984 1.1472 0.032088 112.5 0 4.99e-06 2.4e-07 162.13 "Sensor Drift" "Shaft Wear" No Tooth Fault 1.0123 1.0024 1.4202 0.99233 2.8157 3.6314 -0.014701 2.2542 1.9826 0.81156 30661 1.1469 0.03287 108.86 0 3.62e-06 2.28e-07 230.39 "Sensor Drift" "Shaft Wear" No Tooth Fault 1.0275 1.0102 1.4338 1.0001 2.8157 3.6314 -0.02659 2.2439 1.9638 0.81589 31102 1.0985 0.033427 64.576 0 2.55e-06 1.65e-07 230.39 "Sensor Drift" "Shaft Wear" No Tooth Fault 1.0464 1.0275 1.4477 1.0011 2.8157 3.6314 -0.042849 2.2455 1.9449 0.81595 31665 1.1417 0.034159 98.838 0 1.73e-06 1.55e-07 230.39 "Sensor Drift" "Shaft Wear" No Tooth Fault 1.0459 1.0257 1.4402 0.98047 2.8157 3.6314 -0.035405 2.2757 1.955 0.80583 31554 1.1345 0.0353 44.223 0 1.11e-06 1.39e-07 230.39 "Sensor Drift" "Shaft Wear" No Tooth Fault
Para entrenar una red utilizando características categóricas, primero debe convertir las características categóricas en numéricas. Primero, convierta los predictores categóricos en numéricos con la función convertvars
especificando un arreglo de cadena que contenga los nombres de todas las variables de entrada categórica. En este conjunto de datos, hay dos características categóricas con los nombres "SensorCondition"
y "ShaftCondition"
.
categoricalInputNames = ["SensorCondition" "ShaftCondition"]; tbl = convertvars(tbl,categoricalInputNames,"categorical");
Forme un lazo con las variables de entrada categórica. Para cada variable:
Convierta los valores categóricos en vectores codificados one-hot usando la función
onehotencode
.Añada los vectores one-hot a la tabla utilizando la función
addvars
. Especifique que los vectores se inserten después de la columna que contiene los datos categóricos correspondientes.Elimine la columna correspondiente que contiene los datos categóricos.
for i = 1:numel(categoricalInputNames) name = categoricalInputNames(i); oh = onehotencode(tbl(:,name)); tbl = addvars(tbl,oh,After=name); tbl(:,name) = []; end
Divida los vectores en columnas independientes utilizando la función splitvars
.
tbl = splitvars(tbl);
Visualice las primeras filas de la tabla. Observe que los predictores categóricos se han dividido en varias columnas con los valores categóricos como los nombres de las variables.
head(tbl)
SigMean SigMedian SigRMS SigVar SigPeak SigPeak2Peak SigSkewness SigKurtosis SigCrestFactor SigMAD SigRangeCumSum SigCorrDimension SigApproxEntropy SigLyapExponent PeakFreq HighFreqPower EnvPower PeakSpecKurtosis No Sensor Drift Sensor Drift No Shaft Wear Shaft Wear GearToothCondition ________ _________ ______ _______ _______ ____________ ___________ ___________ ______________ _______ ______________ ________________ ________________ _______________ ________ _____________ ________ ________________ _______________ ____________ _____________ __________ __________________ -0.94876 -0.9722 1.3726 0.98387 0.81571 3.6314 -0.041525 2.2666 2.0514 0.8081 28562 1.1429 0.031581 79.931 0 6.75e-06 3.23e-07 162.13 0 1 1 0 No Tooth Fault -0.97537 -0.98958 1.3937 0.99105 0.81571 3.6314 -0.023777 2.2598 2.0203 0.81017 29418 1.1362 0.037835 70.325 0 5.08e-08 9.16e-08 226.12 0 1 1 0 No Tooth Fault 1.0502 1.0267 1.4449 0.98491 2.8157 3.6314 -0.04162 2.2658 1.9487 0.80853 31710 1.1479 0.031565 125.19 0 6.74e-06 2.85e-07 162.13 0 1 0 1 No Tooth Fault 1.0227 1.0045 1.4288 0.99553 2.8157 3.6314 -0.016356 2.2483 1.9707 0.81324 30984 1.1472 0.032088 112.5 0 4.99e-06 2.4e-07 162.13 0 1 0 1 No Tooth Fault 1.0123 1.0024 1.4202 0.99233 2.8157 3.6314 -0.014701 2.2542 1.9826 0.81156 30661 1.1469 0.03287 108.86 0 3.62e-06 2.28e-07 230.39 0 1 0 1 No Tooth Fault 1.0275 1.0102 1.4338 1.0001 2.8157 3.6314 -0.02659 2.2439 1.9638 0.81589 31102 1.0985 0.033427 64.576 0 2.55e-06 1.65e-07 230.39 0 1 0 1 No Tooth Fault 1.0464 1.0275 1.4477 1.0011 2.8157 3.6314 -0.042849 2.2455 1.9449 0.81595 31665 1.1417 0.034159 98.838 0 1.73e-06 1.55e-07 230.39 0 1 0 1 No Tooth Fault 1.0459 1.0257 1.4402 0.98047 2.8157 3.6314 -0.035405 2.2757 1.955 0.80583 31554 1.1345 0.0353 44.223 0 1.11e-06 1.39e-07 230.39 0 1 0 1 No Tooth Fault
Visualice los nombres de las clases del conjunto de datos.
classNames = categories(tbl{:,labelName})
classNames = 2x1 cell
{'No Tooth Fault'}
{'Tooth Fault' }
Dividir un conjunto de datos en conjuntos de entrenamiento y de validación
Divida el conjunto de datos en particiones de entrenamiento, de validación y de prueba. Reserve el 15% de los datos para la validación y otro 15% para las pruebas.
Visualice el número de observaciones del conjunto de datos.
numObservations = size(tbl,1)
numObservations = 208
Determine el número de observaciones para cada partición.
numObservationsTrain = floor(0.7*numObservations)
numObservationsTrain = 145
numObservationsValidation = floor(0.15*numObservations)
numObservationsValidation = 31
numObservationsTest = numObservations - numObservationsTrain - numObservationsValidation
numObservationsTest = 32
Cree un arreglo de índices aleatorios que se corresponda con las observaciones y divídalo utilizando los tamaños de partición.
idx = randperm(numObservations); idxTrain = idx(1:numObservationsTrain); idxValidation = idx(numObservationsTrain+1:numObservationsTrain+numObservationsValidation); idxTest = idx(numObservationsTrain+numObservationsValidation+1:end);
Divida la tabla de datos en particiones de entrenamiento, de validación y de prueba utilizando los índices.
tblTrain = tbl(idxTrain,:); tblValidation = tbl(idxValidation,:); tblTest = tbl(idxTest,:);
Definir la arquitectura de red
Defina la red para la clasificación.
Defina una red con una capa de entrada de características y especifique el número de características. Configure también la capa de entrada para normalizar los datos utilizando la normalización de puntuación Z. A continuación, incluya una capa completamente conectada con un tamaño de salida de 50, seguida de una capa de normalización de lotes y una capa ReLU. Para la clasificación, especifique otra capa totalmente conectada con un tamaño de salida que se corresponda con el número de clases, seguida de una capa softmax.
numFeatures = size(tbl,2) - 1;
numClasses = numel(classNames);
layers = [
featureInputLayer(numFeatures,Normalization="zscore")
fullyConnectedLayer(50)
batchNormalizationLayer
reluLayer
fullyConnectedLayer(numClasses)
softmaxLayer];
Especificar las opciones de entrenamiento
Especifique las opciones de entrenamiento.
Entrene la red con Adam.
Realice el entrenamiento empleando minilotes de un tamaño de 16.
Cambie el orden de los datos en cada época.
Monitorice la precisión de la red durante el entrenamiento especificando datos de validación.
Muestre el progreso del entrenamiento en una gráfica y omita la salida de la ventana de comandos detallada.
El software entrena la red según los datos de entrenamiento y calcula la precisión de los datos de validación en intervalos regulares durante el entrenamiento. Los datos de validación no se utilizan para actualizar los pesos de la red.
miniBatchSize = 16; options = trainingOptions("adam", ... MiniBatchSize=miniBatchSize, ... Shuffle="every-epoch", ... ValidationData=tblValidation, ... Plots="training-progress", ... Metrics="accuracy", ... Verbose=false);
Entrenar la red
Entrene la red con la arquitectura definida por layers
, los datos de entrenamiento y las opciones de entrenamiento. De forma predeterminada, trainnet
usa una GPU en caso de que esté disponible. De lo contrario, usa una CPU. Entrenar en una GPU requiere Parallel Computing Toolbox™ y un dispositivo GPU compatible. Para obtener información sobre los dispositivos compatibles, consulte GPU Computing Requirements (Parallel Computing Toolbox). También puede especificar el entorno de ejecución con el argumento nombre-valor ExecutionEnvironment
de trainingOptions
.
La gráfica de progreso del entrenamiento muestra la pérdida y la precisión de minilotes y la pérdida y la precisión de validación. Para obtener más información sobre la gráfica de progreso del entrenamiento, consulte Monitorizar el progreso del entrenamiento de deep learning.
net = trainnet(tblTrain,layers,"crossentropy",options);
Probar la red
Prediga las etiquetas de los datos de prueba con la red entrenada y calcule la precisión. Especifique el mismo tamaño de minilote utilizado para el entrenamiento.
scores = minibatchpredict(net,tblTest(:,1:end-1),MiniBatchSize=miniBatchSize); YPred = scores2label(scores,classNames);
Calcule la precisión de clasificación. La precisión es la proporción de etiquetas que la red predice correctamente.
YTest = tblTest{:,labelName}; accuracy = sum(YPred == YTest)/numel(YTest)
accuracy = 0.9375
Visualice los resultados en una matriz de confusión.
figure confusionchart(YTest,YPred)
Consulte también
trainnet
| trainingOptions
| dlnetwork
| fullyConnectedLayer
| Deep Network Designer | featureInputLayer
Ejemplos relacionados
- Crear una red neuronal de deep learning sencilla para clasificación
- Entrenar una red neuronal convolucional para regresión