Esta página aún no se ha traducido para esta versión. Puede ver la versión más reciente de esta página en inglés.

Clasificación con datos desequilibrados

Este ejemplo muestra cómo realizar la clasificación cuando una clase tiene muchas más observaciones que otra. Utilice el algoritmo primero, porque está diseñado para controlar este caso.RUSBoost Otra forma de controlar los datos desequilibrados es usar los argumentos de par nombre-valor o.'Prior''Cost' Para obtener más información, consulte.Manejar datos desequilibrados o costos desiguales de clasificación errónea en conjuntos de clasificación

Este ejemplo utiliza los datos de "tipo de portada" del archivo de aprendizaje automático de UCI, que se describe en.https://archive.ics.uci.edu/ml/datasets/Covertype Los datos clasifican los tipos de bosque (cubierta de tierra), basados en predictores como la elevación, el tipo de suelo y la distancia al agua. Los datos tienen más de 500.000 observaciones y más de 50 predictores, por lo que el entrenamiento y el uso de un clasificador consumen mucho tiempo.

Blackard y Dean describen una clasificación de red neuronal de estos datos.[1] Citan una precisión de clasificación del 70,6%. obtiene más de 81% de precisión de clasificación.RUSBoost

Obtener los datos

Importe los datos en el espacio de trabajo. Extraiga la última columna de datos en una variable denominada.Y

gunzip('https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz') load covtype.data Y = covtype(:,end); covtype(:,end) = [];

Examine los datos de respuesta

tabulate(Y)
  Value    Count   Percent       1    211840     36.46%       2    283301     48.76%       3    35754      6.15%       4     2747      0.47%       5     9493      1.63%       6    17367      2.99%       7    20510      3.53% 

Hay cientos de miles de puntos de datos. Los de la clase 4 son menos del 0,5% del total. Este desequilibrio indica que es un algoritmo apropiado.RUSBoost

Particionar los datos para la evaluación de calidad

Utilice la mitad de los datos para ajustar un clasificador y la mitad para examinar la calidad del clasificador resultante.

rng(10,'twister')         % For reproducibility part = cvpartition(Y,'Holdout',0.5); istrain = training(part); % Data for fitting istest = test(part);      % Data for quality assessment tabulate(Y(istrain))
  Value    Count   Percent       1    105919     36.46%       2    141651     48.76%       3    17877      6.15%       4     1374      0.47%       5     4747      1.63%       6     8684      2.99%       7    10254      3.53% 

Crea el conjunto

Utilice árboles profundos para una mayor precisión del conjunto. Para ello, establezca que los árboles tengan el número máximo de divisiones de decisión, donde está el número de observaciones en la muestra de formación.NN Fije para alcanzar la mayor precisión también.LearnRate0.1 Los datos son grandes y, con árboles profundos, crear el conjunto consume mucho tiempo.

N = sum(istrain);         % Number of observations in the training sample t = templateTree('MaxNumSplits',N); tic rusTree = fitcensemble(covtype(istrain,:),Y(istrain),'Method','RUSBoost', ...     'NumLearningCycles',1000,'Learners',t,'LearnRate',0.1,'nprint',100);
Training RUSBoost... Grown weak learners: 100 Grown weak learners: 200 Grown weak learners: 300 Grown weak learners: 400 Grown weak learners: 500 Grown weak learners: 600 Grown weak learners: 700 Grown weak learners: 800 Grown weak learners: 900 Grown weak learners: 1000 
toc
Elapsed time is 411.187279 seconds. 

Inspeccione el error de clasificación

Trace el error de clasificación en el número de miembros del conjunto.

figure; tic plot(loss(rusTree,covtype(istest,:),Y(istest),'mode','cumulative')); toc
Elapsed time is 192.732718 seconds. 
grid on; xlabel('Number of trees'); ylabel('Test classification error');

El conjunto logra un error de clasificación de menos del 20% usando 116 o más árboles. Para 500 o más árboles, el error de clasificación disminuye a una velocidad más lenta.

Examine la matriz de confusión para cada clase como un porcentaje de la clase true.

tic Yfit = predict(rusTree,covtype(istest,:)); toc
Elapsed time is 158.760052 seconds. 
confusionchart(Y(istest),Yfit,'Normalization','row-normalized','RowSummary','row-normalized');

Todas las clases excepto la clase 2 tienen más de 90% de precisión de clasificación. Pero la clase 2 hace cerca de la mitad de los datos, por lo que la precisión general no es tan alta.

Compacte el conjunto

El conjunto es grande. Elimine los datos mediante el método.compact

cmpctRus = compact(rusTree);  sz(1) = whos('rusTree'); sz(2) = whos('cmpctRus'); [sz(1).bytes sz(2).bytes]
ans = 1×2
109 ×

    1.6578    0.9421

El conjunto compactado es aproximadamente la mitad del tamaño del original.

Quita la mitad de los árboles.cmpctRus Es probable que esta acción tenga un efecto mínimo en el rendimiento predictivo, basándose en la observación de que 500 de 1000 árboles dan una precisión casi óptima.

cmpctRus = removeLearners(cmpctRus,[500:1000]);  sz(3) = whos('cmpctRus'); sz(3).bytes
ans = 452813153 

El conjunto compacto reducido lleva aproximadamente un cuarto de la memoria del conjunto completo. Su tasa de pérdida global es de menos del 19%:

L = loss(cmpctRus,covtype(istest,:),Y(istest))
L = 0.1833 

La precisión predictiva en los nuevos datos puede diferir, porque la precisión del conjunto puede ser sesgada. El sesgo surge porque los mismos datos utilizados para evaluar el conjunto se utilizaron para reducir el tamaño del conjunto. Para obtener una estimación imparcial del tamaño de conjunto necesario, debe usar la validación cruzada. Sin embargo, ese procedimiento consume mucho tiempo.

Referencias

[1] Blackard, J. A. and D. J. Dean. "Comparative accuracies of artificial neural networks and discriminant analysis in predicting forest cover types from cartographic variables". Computers and Electronics in Agriculture Vol. 24, Issue 3, 1999, pp. 131–151.

Consulte también

| | | | | | | | | | |

Temas relacionados