ANN for multi-class classification

Aleef Fouzi on 13 Jan 2022
Answered: Prince Kumar on 17 Jan 2022
I'm having trouble computing the AUC. Can someone help me? I read somewhere that the perfcurve function cannot be used for multi-class classification. I tried trapz, but it still gives an error.
inputs = data_inputs';
targets = data_targets';
% Create a Pattern Recognition Network
hiddenLayerSize = 10;
net = patternnet(hiddenLayerSize, 'trainscg');
% Setup Division of Data for Training, Validation, Testing
net.divideParam.trainRatio = 70/100;
net.divideParam.valRatio = 15/100;
net.divideParam.testRatio = 15/100;
% Train the Network
[net,tr] = train(net,inputs,targets);
% Test the Network
outputs = net(inputs);
e = gsubtract(targets,outputs);
performance = perform(net,targets,outputs)
% View the Network
view(net)
% c = fraction of misclassified samples, cm = confusion matrix (counts behind TP, TN, FP, FN)
[c,cm]=confusion(targets,outputs);
% Print percentage for correct classification (Accuracy)
fprintf('Percentage Correct Classification: %f%%\n', (1-c)* 100);
% Print percentage for incorrect classification (error rate)
fprintf('Percentage Incorrect Classification: %f%%\n', 100*c);
% Print the error performance
fprintf('Error performance: %f\n',performance)
% Plots
figure, plotperform(tr)
figure, plotconfusion(targets,outputs)
figure, plotroc(targets,outputs)
%% Obtain the AUC
[tpr,fpr] = roc(targets,outputs);
colAUC = perfcurve(fpr,tpr);
This is my code. Please ignore the %% Obtain the AUC section, because it is not working.
Not enough input arguments.
Error in perfcurve (line 450)
posClass,negClass,trueNames);
Error in ann (line 44)
colAUC = perfcurve(fpr,tpr);
And this is the error I get when I run the code above.

Answers (1)

Prince Kumar on 17 Jan 2022
Hi,
According to your code, 'perfcurve' is being given the true positive rate (tpr) and false positive rate (fpr) as input, but 'perfcurve' takes the true class labels, the scores, and the positive class label as input.
The following snippet shows the correct syntax:
[X,Y] = perfcurve(labels,scores,posclass)
You can visualize the performance curve using 'plot(X,Y)'.
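For example, with a few made-up labels and scores (hypothetical values, not taken from the question), the fourth output of perfcurve is the area under the ROC curve, and X and Y can be passed straight to plot. This assumes the Statistics and Machine Learning Toolbox, which provides perfcurve:

```matlab
% Hypothetical binary example: true labels (1 = positive) and classifier scores
labels   = [0 0 1 1];
scores   = [0.1 0.4 0.35 0.8];
posclass = 1;

% With default criteria, X = false positive rate, Y = true positive rate,
% and the fourth output is the area under the ROC curve
[X, Y, ~, AUC] = perfcurve(labels, scores, posclass);

fprintf('AUC = %.4f\n', AUC);

% Visualize the ROC curve
plot(X, Y)
xlabel('False positive rate')
ylabel('True positive rate')
title('ROC curve')
```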
For more information, please refer to the documentation: perfcurve
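perfcurve evaluates one binary problem at a time, but it can still be applied to a multi-class network by treating each class in turn as the positive class and all other classes as negative (one-vs-rest). The sketch below assumes targets and outputs have the layout patternnet uses (numClasses-by-numSamples, one-hot targets, one output row per class). The small 3-class matrices are made up so the snippet runs on its own; in the question's script you would drop them and use the real targets and outputs = net(inputs):

```matlab
% Hypothetical 3-class data (columns = samples); replace with the real
% targets and network outputs in practice.
targets = [1 1 0 0 0 0;      % class 1
           0 0 1 1 0 0;      % class 2
           0 0 0 0 1 1];     % class 3
outputs = [0.8 0.5  0.1 0.3 0.2 0.1;
           0.1 0.45 0.7 0.4 0.1 0.2;
           0.1 0.05 0.2 0.3 0.7 0.7];

numClasses = size(targets, 1);
auc = zeros(numClasses, 1);
for k = 1:numClasses
    labels = targets(k, :);   % 1 where the sample belongs to class k, 0 otherwise
    scores = outputs(k, :);   % network score for class k
    % Class k is the positive class; all other classes count as negative
    [~, ~, ~, auc(k)] = perfcurve(labels, scores, 1);
end

% Print one AUC per class, then the unweighted mean over classes
fprintf('Class %d AUC: %.4f\n', [(1:numClasses); auc']);
fprintf('Mean (macro) AUC: %.4f\n', mean(auc));
```

Each auc(k) is the AUC for class k versus the rest. Averaging with mean(auc) gives an unweighted (macro) average, which is one common way to summarize multi-class AUC but not the only one. A class that never appears in targets has no positive samples, so its AUC cannot be meaningfully computed.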

Release

R2018a