Esta página aún no se ha traducido para esta versión. Puede ver la versión más reciente de esta página en inglés.

Optimice un clasificador SVM con validación cruzada mediantebayesopt

En este ejemplo se muestra cómo optimizar una clasificación de SVM mediante la función.bayesopt La clasificación funciona en ubicaciones de puntos de un modelo de mezcla Gaussiana. En, Hastie, Tibshirani y Friedman (2009), página 17 describe el modelo.The Elements of Statistical Learning El modelo comienza con la generación de 10 puntos de base para una clase "verde", distribuida como normales 2-D independientes con media (1, 0) y varianza de unidad. También genera 10 puntos base para una clase "roja", distribuida como normales 2-D independientes con media (0,1) y varianza de unidad. Para cada clase (verde y rojo), genere 100 puntos aleatorios de la siguiente manera:

  1. Elija un punto base del color adecuado uniformemente al azar.m

  2. Genere un punto aleatorio independiente con una distribución normal en 2-D con la media y la varianza I/5, donde I es la matriz de identidad 2 por 2.m En este ejemplo, utilice una varianza I/50 para mostrar la ventaja de la optimización con más claridad.

Después de generar 100 puntos verdes y 100 rojos, clasificarlos usando.fitcsvm A continuación, utilízese para optimizar los parámetros del modelo SVM resultante con respecto a la validación cruzada.bayesopt

Genere los puntos y clasificador

Genere los 10 puntos base para cada clase.

rng default grnpop = mvnrnd([1,0],eye(2),10); redpop = mvnrnd([0,1],eye(2),10);

Ver los puntos base.

plot(grnpop(:,1),grnpop(:,2),'go') hold on plot(redpop(:,1),redpop(:,2),'ro') hold off

Dado que algunos puntos de base rojos están cerca de puntos de base verdes, puede ser difícil clasificar los puntos de datos en función de la ubicación por sí solo.

Genere los 100 puntos de datos de cada clase.

redpts = zeros(100,2);grnpts = redpts; for i = 1:100     grnpts(i,:) = mvnrnd(grnpop(randi(10),:),eye(2)*0.02);     redpts(i,:) = mvnrnd(redpop(randi(10),:),eye(2)*0.02); end

Ver los puntos de datos.

figure plot(grnpts(:,1),grnpts(:,2),'go') hold on plot(redpts(:,1),redpts(:,2),'ro') hold off

Preparar datos para la clasificación

Coloque los datos en una matriz y haga un vector que etiqueta la clase de cada punto.grp

cdata = [grnpts;redpts]; grp = ones(200,1); % Green label 1, red label -1 grp(101:200) = -1;

Prepare la validación cruzada

Configure una partición para la validación cruzada. Este paso corrige los conjuntos de trenes y pruebas que utiliza la optimización en cada paso.

c = cvpartition(200,'KFold',10);

Prepare variables para la optimización bayesiana

Configurar una función que toma una entrada y devuelve el valor de pérdida de validación cruzada de.z = [rbf_sigma,boxconstraint]z Tome los componentes de variables positivas, de transformación logaritmo entre y.z1e-51e5 Elija un amplio rango, porque no sabe qué valores son propensos a ser buenos.

sigma = optimizableVariable('sigma',[1e-5,1e5],'Transform','log'); box = optimizableVariable('box',[1e-5,1e5],'Transform','log');

Función objetivo

Este identificador de función calcula la pérdida de validación cruzada en los parámetros.[sigma,box] Para obtener más información, consulte.kfoldLoss

pasa la variable a la función objetivo como una tabla de una fila.bayesoptz

minfn = @(z)kfoldLoss(fitcsvm(cdata,grp,'CVPartition',c,...     'KernelFunction','rbf','BoxConstraint',z.box,...     'KernelScale',z.sigma));

Optimizar clasificador

Busca los mejores parámetros usando.[sigma,box]bayesopt Para reproducibilidad, seleccione la función de adquisición.'expected-improvement-plus' La función de adquisición predeterminada depende del tiempo de ejecución y, por lo tanto, puede dar resultados variables.

results = bayesopt(minfn,[sigma,box],'IsObjectiveDeterministic',true,...     'AcquisitionFunctionName','expected-improvement-plus')

|=====================================================================================================| | Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |        sigma |          box | |      | result |             | runtime     | (observed)  | (estim.)    |              |              | |=====================================================================================================| |    1 | Best   |        0.61 |     0.37034 |        0.61 |        0.61 |   0.00013375 |        13929 | |    2 | Best   |       0.345 |     0.37734 |       0.345 |       0.345 |        24526 |        1.936 | |    3 | Accept |        0.61 |     0.38481 |       0.345 |       0.345 |    0.0026459 |   0.00084929 | |    4 | Accept |       0.345 |     0.43188 |       0.345 |       0.345 |       3506.3 |   6.7427e-05 | |    5 | Accept |       0.345 |     0.26215 |       0.345 |       0.345 |       9135.2 |       571.87 | |    6 | Accept |       0.345 |      0.2454 |       0.345 |       0.345 |        99701 |        10223 | |    7 | Best   |       0.295 |     0.30424 |       0.295 |       0.295 |       455.88 |       9957.4 | |    8 | Best   |        0.24 |      4.2242 |        0.24 |        0.24 |        31.56 |        99389 | |    9 | Accept |        0.24 |      5.2582 |        0.24 |        0.24 |       10.451 |        64429 | |   10 | Accept |        0.35 |     0.22117 |        0.24 |        0.24 |       17.331 |   1.0264e-05 | |   11 | Best   |        0.23 |      3.3915 |        0.23 |        0.23 |       16.005 |        90155 | |   12 | Best   |         0.1 |     0.44594 |         0.1 |         0.1 |      0.36562 |        80878 | |   13 | Accept |       0.115 |     0.37155 |         0.1 |         0.1 |       0.1793 |        68459 | |   14 | Accept |       0.105 |     0.27767 |         0.1 |         0.1 |       0.2267 |        95421 | |   15 | Best   |       0.095 |     0.24329 |       0.095 |       0.095 |      0.28999 |    0.0058227 | |   16 | Best   |       0.075 |     0.38443 |       0.075 |       0.075 |      0.30554 |       8.9017 | |   17 | Accept |       0.085 |      0.3546 |       0.075 |       0.075 |      0.41122 |       4.4476 | |   18 | Accept |       0.085 |     0.21907 |       0.075 |       0.075 |      0.25565 |       7.8038 | |   19 | Accept |       0.075 |     0.21965 |       0.075 |       0.075 |      0.32869 |       18.076 | |   20 | Accept |       0.085 |     0.27585 |       0.075 |       0.075 |      0.32442 |       5.2118 | |=====================================================================================================| | Iter | Eval   | Objective   | Objective   | BestSoFar   | BestSoFar   |        sigma |          box | |      | result |             | runtime     | (observed)  | (estim.)    |              |              | |=====================================================================================================| |   21 | Accept |         0.3 |     0.24937 |       0.075 |       0.075 |       1.3592 |    0.0098067 | |   22 | Accept |        0.12 |     0.25383 |       0.075 |       0.075 |      0.17515 |   0.00070913 | |   23 | Accept |       0.175 |     0.21193 |       0.075 |       0.075 |       0.1252 |     0.010749 | |   24 | Accept |       0.105 |     0.21899 |       0.075 |       0.075 |       1.1664 |        31.13 | |   25 | Accept |         0.1 |     0.24818 |       0.075 |       0.075 |      0.57465 |       2013.8 | |   26 | Accept |        0.12 |     0.17319 |       0.075 |       0.075 |      0.42922 |   1.1602e-05 | |   27 | Accept |        0.12 |     0.23357 |       0.075 |       0.075 |      0.42956 |   0.00027218 | |   28 | Accept |       0.095 |     0.14202 |       0.075 |       0.075 |       0.4806 |       13.452 | |   29 | Accept |       0.105 |     0.25958 |       0.075 |       0.075 |      0.19755 |       943.87 | |   30 | Accept |       0.205 |     0.26937 |       0.075 |       0.075 |       3.5051 |       93.492 |  __________________________________________________________ Optimization completed. MaxObjectiveEvaluations of 30 reached. Total function evaluations: 30 Total elapsed time: 87.0198 seconds. Total objective function evaluation time: 20.5233  Best observed feasible point:      sigma      box       _______    ______      0.30554    8.9017  Observed objective function value = 0.075 Estimated objective function value = 0.075 Function evaluation time = 0.38443  Best estimated feasible point (according to models):      sigma      box       _______    ______      0.32869    18.076  Estimated objective function value = 0.075 Estimated function evaluation time = 0.2349 
results =    BayesianOptimization with properties:                        ObjectiveFcn: [function_handle]               VariableDescriptions: [1x2 optimizableVariable]                            Options: [1x1 struct]                       MinObjective: 0.0750                    XAtMinObjective: [1x2 table]              MinEstimatedObjective: 0.0750           XAtMinEstimatedObjective: [1x2 table]            NumObjectiveEvaluations: 30                   TotalElapsedTime: 87.0198                          NextPoint: [1x2 table]                             XTrace: [30x2 table]                     ObjectiveTrace: [30x1 double]                   ConstraintsTrace: []                      UserDataTrace: {30x1 cell}       ObjectiveEvaluationTimeTrace: [30x1 double]                 IterationTimeTrace: [30x1 double]                         ErrorTrace: [30x1 double]                   FeasibilityTrace: [30x1 logical]        FeasibilityProbabilityTrace: [30x1 double]                IndexOfMinimumTrace: [30x1 double]              ObjectiveMinimumTrace: [30x1 double]     EstimatedObjectiveMinimumTrace: [30x1 double]  

Utilice los resultados para entrenar un nuevo clasificador SVM optimizado.

z(1) = results.XAtMinObjective.sigma; z(2) = results.XAtMinObjective.box; SVMModel = fitcsvm(cdata,grp,'KernelFunction','rbf',...     'KernelScale',z(1),'BoxConstraint',z(2));

Trace los límites de clasificación. Para visualizar el clasificador de vectores de soporte, predecir puntuaciones sobre una rejilla.

d = 0.02; [x1Grid,x2Grid] = meshgrid(min(cdata(:,1)):d:max(cdata(:,1)),...     min(cdata(:,2)):d:max(cdata(:,2))); xGrid = [x1Grid(:),x2Grid(:)]; [~,scores] = predict(SVMModel,xGrid);  h = nan(3,1); % Preallocation figure; h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*'); hold on h(3) = plot(cdata(SVMModel.IsSupportVector,1),...     cdata(SVMModel.IsSupportVector,2),'ko'); contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k'); legend(h,{'-1','+1','Support Vectors'},'Location','Southeast'); axis equal hold off

Evalúe la precisión en nuevos datos

Genere y clasifique algunos puntos de datos nuevos.

grnobj = gmdistribution(grnpop,.2*eye(2)); redobj = gmdistribution(redpop,.2*eye(2));  newData = random(grnobj,10); newData = [newData;random(redobj,10)]; grpData = ones(20,1); grpData(11:20) = -1; % red = -1  v = predict(SVMModel,newData);  g = nan(7,1); figure; h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*'); hold on h(3:4) = gscatter(newData(:,1),newData(:,2),v,'mc','**'); h(5) = plot(cdata(SVMModel.IsSupportVector,1),...     cdata(SVMModel.IsSupportVector,2),'ko'); contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k'); legend(h(1:5),{'-1 (training)','+1 (training)','-1 (classified)',...     '+1 (classified)','Support Vectors'},'Location','Southeast'); axis equal hold off

Vea qué nuevos puntos de datos se clasifican correctamente. Circule los puntos clasificados correctamente en rojo y los puntos clasificados incorrectamente en negro.

mydiff = (v == grpData); % Classified correctly figure; h(1:2) = gscatter(cdata(:,1),cdata(:,2),grp,'rg','+*'); hold on h(3:4) = gscatter(newData(:,1),newData(:,2),v,'mc','**'); h(5) = plot(cdata(SVMModel.IsSupportVector,1),...     cdata(SVMModel.IsSupportVector,2),'ko'); contour(x1Grid,x2Grid,reshape(scores(:,2),size(x1Grid)),[0 0],'k');  for ii = mydiff % Plot red squares around correct pts     h(6) = plot(newData(ii,1),newData(ii,2),'rs','MarkerSize',12); end  for ii = not(mydiff) % Plot black squares around incorrect pts     h(7) = plot(newData(ii,1),newData(ii,2),'ks','MarkerSize',12); end legend(h,{'-1 (training)','+1 (training)','-1 (classified)',...     '+1 (classified)','Support Vectors','Correctly Classified',...     'Misclassified'},'Location','Southeast'); hold off

Consulte también

|

Temas relacionados