Esta página aún no se ha traducido para esta versión. Puede ver la versión más reciente de esta página en inglés.

Regularización de conjunto

La regularización es un proceso de elegir menos estudiantes débiles para un conjunto de una manera que no disminuya el rendimiento predictivo. Actualmente se pueden regularizar conjuntos de regresión. (También puede regularizar un clasificador de análisis discriminante en un contexto que no es de conjunto; ver.)Regularice un clasificador de análisis discriminante

el regularize método encuentra un conjunto óptimo de pesos del alumno Αt que minimizan

n=1Nwng((t=1Tαtht(xn)),yn)+λt=1T|αt|.

Aquí

  • λ ≥ 0 es un parámetro que usted proporciona, llamado el parámetro de lazo.

  • Ht es un aprendiz débil en el conjunto entrenado en observaciones con predictoresN XnRespuestas yn, y pesos Wn.

  • g(f,y) = (fy)2 es el error cuadrado.

El conjunto se regulariza en el mismo (Xn,yn,Wn) datos utilizados para la formación, por lo que

n=1Nwng((t=1Tαtht(xn)),yn)

es el error de reenvío del conjunto. El error se mide por el error cuadrado medio (MSE).

Si utiliza λ = 0, regularize encuentra las ponderaciones débiles del alumno minimizando el MSE de reenvío. Los conjuntos tienden a sobreentrenarse. En otras palabras, el error de reenvío suele ser menor que el error de generalización real. Al hacer el error de reenvío aún más pequeño, es probable que la precisión del conjunto sea peor en lugar de mejorarla. Por otro lado, los valores positivos de empujar la magnitud de laλ Αt coeficientes a 0. Esto a menudo mejora el error de generalización. Por supuesto, si usted elige demasiado grande, todos los coeficientes óptimos son 0, y el conjunto no tiene ninguna precisión.λ Por lo general, se puede encontrar un rango óptimo en el que la precisión del conjunto regularizado es mejor o comparable a la del conjunto completo sin regularización.λ

Una buena característica de la regularización de lazo es su capacidad para conducir los coeficientes optimizados precisamente a 0. Si el peso de un alumno Αt es 0, este alumno puede excluirse del conjunto regularizado. Al final, obtienes un conjunto con una precisión mejorada y menos estudiantes.

Regularice un conjunto de regresión

Este ejemplo utiliza datos para predecir el riesgo del seguro de un automóvil en función de sus muchos atributos.

Cargue los datos en el espacio de trabajo de MATLAB.imports-85

load imports-85;

Mire una descripción de los datos para encontrar las variables categóricas y los nombres predictores.

Description
Description = 9x79 char array
    '1985 Auto Imports Database from the UCI repository                             '
    'http://archive.ics.uci.edu/ml/machine-learning-databases/autos/imports-85.names'
    'Variables have been reordered to place variables with numeric values (referred '
    'to as "continuous" on the UCI site) to the left and categorical values to the  '
    'right. Specifically, variables 1:16 are: symboling, normalized-losses,         '
    'wheel-base, length, width, height, curb-weight, engine-size, bore, stroke,     '
    'compression-ratio, horsepower, peak-rpm, city-mpg, highway-mpg, and price.     '
    'Variables 17:26 are: make, fuel-type, aspiration, num-of-doors, body-style,    '
    'drive-wheels, engine-location, engine-type, num-of-cylinders, and fuel-system. '

El objetivo de este proceso es predecir el "symboling", la primera variable en los datos, de los otros predictores. el "symboling" es un número entero de (buen riesgo de seguro) a (riesgo de seguro deficiente).-33 Podría utilizar un conjunto de clasificación para predecir este riesgo en lugar de un conjunto de regresión. Si tiene la opción de elegir entre la regresión y la clasificación, primero debe intentar la regresión.

Prepare los datos para el ajuste del conjunto.

Y = X(:,1); X(:,1) = []; VarNames = {'normalized-losses' 'wheel-base' 'length' 'width' 'height' ...   'curb-weight' 'engine-size' 'bore' 'stroke' 'compression-ratio' ...   'horsepower' 'peak-rpm' 'city-mpg' 'highway-mpg' 'price' 'make' ...   'fuel-type' 'aspiration' 'num-of-doors' 'body-style' 'drive-wheels' ...   'engine-location' 'engine-type' 'num-of-cylinders' 'fuel-system'}; catidx = 16:25; % indices of categorical predictors

Cree un conjunto de regresión a partir de los datos utilizando 300 árboles.

ls = fitrensemble(X,Y,'Method','LSBoost','NumLearningCycles',300, ...     'LearnRate',0.1,'PredictorNames',VarNames, ...     'ResponseName','Symboling','CategoricalPredictors',catidx)
ls =    classreg.learning.regr.RegressionEnsemble            PredictorNames: {1x25 cell}              ResponseName: 'Symboling'     CategoricalPredictors: [16 17 18 19 20 21 22 23 24 25]         ResponseTransform: 'none'           NumObservations: 205                NumTrained: 300                    Method: 'LSBoost'              LearnerNames: {'Tree'}      ReasonForTermination: 'Terminated normally after completing the requested number of training cycles.'                   FitInfo: [300x1 double]        FitInfoDescription: {2x1 cell}            Regularization: []     Properties, Methods  

La línea final, está vacía ([]).Regularization Para regularizar el conjunto, usted tiene que utilizar el método.regularize

cv = crossval(ls,'KFold',5); figure; plot(kfoldLoss(cv,'Mode','Cumulative')); xlabel('Number of trees'); ylabel('Cross-validated MSE'); ylim([0.2,2])

Parece que usted puede obtener un rendimiento satisfactorio de un conjunto más pequeño, tal vez uno que contiene de 50 a 100 árboles.

Llame al método para tratar de encontrar los árboles que se pueden quitar del conjunto.regularize De forma predeterminada, examina 10 valores del parámetro Lasso () espaciados exponencialmente.regularizeLambda

ls = regularize(ls)
ls =    classreg.learning.regr.RegressionEnsemble            PredictorNames: {1x25 cell}              ResponseName: 'Symboling'     CategoricalPredictors: [16 17 18 19 20 21 22 23 24 25]         ResponseTransform: 'none'           NumObservations: 205                NumTrained: 300                    Method: 'LSBoost'              LearnerNames: {'Tree'}      ReasonForTermination: 'Terminated normally after completing the requested number of training cycles.'                   FitInfo: [300x1 double]        FitInfoDescription: {2x1 cell}            Regularization: [1x1 struct]     Properties, Methods  

La propiedad ya no está vacía.Regularization

Trace el error cuadrático de reenvío (MSE) y el número de alumnos con pesos distintos de cero en el parámetro Lasso. Trace el valor por separado en.Lambda = 0 Utilice una escala logarítmica porque los valores de están espaciados exponencialmente.Lambda

figure; semilogx(ls.Regularization.Lambda,ls.Regularization.ResubstitutionMSE, ...     'bx-','Markersize',10); line([1e-3 1e-3],[ls.Regularization.ResubstitutionMSE(1) ...      ls.Regularization.ResubstitutionMSE(1)],...     'Marker','x','Markersize',10,'Color','b'); r0 = resubLoss(ls); line([ls.Regularization.Lambda(2) ls.Regularization.Lambda(end)],...      [r0 r0],'Color','r','LineStyle','--'); xlabel('Lambda'); ylabel('Resubstitution MSE'); annotation('textbox',[0.5 0.22 0.5 0.05],'String','unregularized ensemble', ...     'Color','r','FontSize',14,'LineStyle','none');

 figure; loglog(ls.Regularization.Lambda,sum(ls.Regularization.TrainedWeights>0,1)); line([1e-3 1e-3],...     [sum(ls.Regularization.TrainedWeights(:,1)>0) ...     sum(ls.Regularization.TrainedWeights(:,1)>0)],...     'marker','x','markersize',10,'color','b'); line([ls.Regularization.Lambda(2) ls.Regularization.Lambda(end)],...     [ls.NTrained ls.NTrained],...     'color','r','LineStyle','--'); xlabel('Lambda'); ylabel('Number of learners'); annotation('textbox',[0.3 0.8 0.5 0.05],'String','unregularized ensemble',...     'color','r','FontSize',14,'LineStyle','none');

Es probable que los valores MSE de reenvío sean excesivamente optimistas. Para obtener estimaciones más fiables del error asociado con varios valores de, valide el conjunto usando.Lambdacvshrink Trace la pérdida de validación cruzada resultante (MSE) y el número de alumnos en contra.Lambda

rng(0,'Twister') % for reproducibility [mse,nlearn] = cvshrink(ls,'Lambda',ls.Regularization.Lambda,'KFold',5);
Warning: Some folds do not have any trained weak learners. 
 figure; semilogx(ls.Regularization.Lambda,ls.Regularization.ResubstitutionMSE, ...     'bx-','Markersize',10); hold on; semilogx(ls.Regularization.Lambda,mse,'ro-','Markersize',10); hold off; xlabel('Lambda'); ylabel('Mean squared error'); legend('resubstitution','cross-validation','Location','NW'); line([1e-3 1e-3],[ls.Regularization.ResubstitutionMSE(1) ...      ls.Regularization.ResubstitutionMSE(1)],...     'Marker','x','Markersize',10,'Color','b','HandleVisibility','off'); line([1e-3 1e-3],[mse(1) mse(1)],'Marker','o',...     'Markersize',10,'Color','r','LineStyle','--','HandleVisibility','off');

 figure; loglog(ls.Regularization.Lambda,sum(ls.Regularization.TrainedWeights>0,1)); hold;
Current plot held 
loglog(ls.Regularization.Lambda,nlearn,'r--'); hold off; xlabel('Lambda'); ylabel('Number of learners'); legend('resubstitution','cross-validation','Location','NE'); line([1e-3 1e-3],...     [sum(ls.Regularization.TrainedWeights(:,1)>0) ...     sum(ls.Regularization.TrainedWeights(:,1)>0)],...     'Marker','x','Markersize',10,'Color','b','HandleVisibility','off'); line([1e-3 1e-3],[nlearn(1) nlearn(1)],'marker','o',...     'Markersize',10,'Color','r','LineStyle','--','HandleVisibility','off');

El examen del error validado cruzado muestra que el MSE de la validación cruzada es casi llano para hasta un poco encima.Lambda1e-2

Examine para encontrar el valor más alto que da MSE en la región plana (hasta un poco encima).ls.Regularization.Lambda1e-2

jj = 1:length(ls.Regularization.Lambda); [jj;ls.Regularization.Lambda]
ans = 2×10

    1.0000    2.0000    3.0000    4.0000    5.0000    6.0000    7.0000    8.0000    9.0000   10.0000
         0    0.0019    0.0045    0.0107    0.0254    0.0602    0.1428    0.3387    0.8033    1.9048

Elemento de tiene valor, el más grande en el rango plano.5ls.Regularization.Lambda0.0254

Reduzca el tamaño del conjunto utilizando el método. Devuelve un conjunto compacto sin datos de entrenamiento.shrinkshrink El error de generalización para el nuevo conjunto compacto ya se estimó mediante la validación cruzada.mse(5)

cmp = shrink(ls,'weightcolumn',5)
cmp =    classreg.learning.regr.CompactRegressionEnsemble            PredictorNames: {1x25 cell}              ResponseName: 'Symboling'     CategoricalPredictors: [16 17 18 19 20 21 22 23 24 25]         ResponseTransform: 'none'                NumTrained: 8     Properties, Methods  

El número de árboles en el nuevo conjunto ha disminuido notablemente de la 300 en.ls

Compare los tamaños de los conjuntos.

sz(1) = whos('cmp'); sz(2) = whos('ls'); [sz(1).bytes sz(2).bytes]
ans = 1×2

       91536     3237183

El tamaño del conjunto reducido es una fracción del tamaño del original. Tenga en cuenta que los tamaños de conjunto pueden variar en función del sistema operativo.

Compare el MSE del conjunto reducido con el del conjunto original.

figure; plot(kfoldLoss(cv,'mode','cumulative')); hold on plot(cmp.NTrained,mse(5),'ro','MarkerSize',10); xlabel('Number of trees'); ylabel('Cross-validated MSE'); legend('unregularized ensemble','regularized ensemble',...     'Location','NE'); hold off

El conjunto reducido da baja pérdida mientras se utilizan muchos menos árboles.

Consulte también

| | | | | |

Temas relacionados