CNNへの交差検定(​Cross-Vali​dation)の導入​の仕方

19 visualizaciones (últimos 30 días)
ssk
ssk el 7 de Feb. de 2019
Editada: ssk el 11 de Feb. de 2019
プログラミング初心者です。
現在、チュートリアルのコードを微修正して動かしており、以下のコードに交差検定の追加を検討しております。
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.5,'randomize');
help crossvarで検索すると、以下のようにでてきました。
TESTVAL = FUN(XTRAIN,XTEST)
こちらを、TESTVAL = FUN(imdsTrain, imdsValidation)とすると交差検定を導入できるという認識で
コンパイルしたのですが動きませんでした。
Undefined function or variable 'FUN'.
というエラーが出てしまいます。
交差検定の正しいやり方につきましてご教示いただけますと幸いです。
どうぞよろしくお願いいたします。

Respuesta aceptada

Tohru Kikawada
Tohru Kikawada el 9 de Feb. de 2019
crossvalのドキュメントに記載のある下記は指定する関数の戻り値と引数の一例です。
TESTVAL = FUN(XTRAIN,XTEST)
ドキュメントにあるいくつかの例題は試してみましたでしょうか。crossvalは様々な機械学習のアルゴリズムで使えるように汎用性のある関数ハンドルの受け渡しで実行されます。CNNで交差検定を実行する場合も下記のようにCNNのクラス分類結果を返すような関数を関数ハンドルとして渡してあげる必要があります。
%% データセットの読み込み
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
'IncludeSubfolders',true,'LabelSource','foldernames');
%% ダミーのトレーニングインデックスを生成
X = (1:imds.numpartitions)';
y = imds.Labels;
%% 交差検定にCNNの予測ラベル関数のポインタを渡す
mcr = crossval('mcr',X,y,'Predfun',@(xtrain,ytrain,xtest)myCNNPredict(xtrain,ytrain,xtest,imds))
%% CNNを学習し、予測ラベルを出力する関数
function ypred = myCNNPredict(xtrain,ytrain,xtest,imds)
% 結果が一意になるように乱数シードをデフォルト値に設定
rng('default');
% ダミーの変数ベクトルを受けてimageDatastoreを学習用とテスト用に分割
imdsTrain = imageDatastore(imds.Files(xtrain));
imdsTrain.Labels = ytrain;
imdsValidation = imageDatastore(imds.Files(xtest));
% レイヤーの設定
layers = [
imageInputLayer([28 28 1])
convolution2dLayer(3,8,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,16,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,32,'Padding','same')
batchNormalizationLayer
reluLayer
fullyConnectedLayer(10)
softmaxLayer
classificationLayer];
options = trainingOptions('sgdm', ...
'InitialLearnRate',0.01, ...
'MaxEpochs',4, ...
'Shuffle','every-epoch', ...
'Verbose',false);
net = trainNetwork(imdsTrain,layers,options);
ypred = classify(net,imdsValidation);
end
  4 comentarios
ssk
ssk el 11 de Feb. de 2019
ご回答ありがとうございます。おかげさまでチュートリアルのコードを無事、コンパイルすることができました。ありがとうございます。
DICOMファイルでも交差検定が使えるかどうか試したところ、以下のようなエラーが出てしまいます。
The function '@(xtrain,ytrain,xtest)myCNNPredict(xtrain,ytrain,xtest,imds)' generated
the following error:
Input folders or files contain non-standard file extensions.
拡張子が違うのが原因かもしれません。
currentdirectory = pwd;
% set categories of subdirectory
categories = {'a', 'b', 'c','d'};
imds = imageDatastore(fullfile(currentdirectory, categories),'IncludeSubfolders',true,'FileExtensions','.dcm','LabelSource', 'foldernames','ReadFcn',@dicomread);
mcr = crossval('mcr',X,y,'Predfun',(xtrain,ytrain,xtest)myCNNPredict(xtrain,ytrain,xtest,imds))
作成したコードは上記の通りですが、DICOMファイルでの交差検定の仕方につきまして、ご教示頂けますと幸いです。
どうぞよろしくお願いいたします。
ssk
ssk el 11 de Feb. de 2019
Editada: ssk el 11 de Feb. de 2019
五月雨式のコメント失礼いたします。
頂いた回答につきまして以下の質問がございます。
%% ダミーのトレーニングインデックスを生成
X = (1:imds.numpartitions)';
(1)なぜ、ダミーのトレーニングインデックスを生成しているのか、
(2)なぜ、numpartitions(おそらくnumber of partition)を使っているのか、
(3)(1:imds.numpatition)の意味につきましてもご教示いただけますと幸いです。
@(xtrain,ytrain,xtest)myCNNPredict(xtrain,ytrain,xtest,imds)
また、mcrの意味につきましては、 misclassification rateの略語という意味でお間違えないでしょうか。
どうぞよろしくお願いいたします。

Iniciar sesión para comentar.

Más respuestas (0)

Categorías

Más información sobre Deep Learning Toolbox en Help Center y File Exchange.

Etiquetas

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!