MATLAB Answers

SqueezeNet model not training in MATLAB R2017b

6 views (last 30 days)
BHUSHAN MUTHIYAN
BHUSHAN MUTHIYAN on 12 Oct 2017
Edited: BHUSHAN MUTHIYAN on 19 Nov 2017
Hello,
I have built a basic (vanilla) SqueezeNet model in MATLAB R2017b. My implementation is exactly the same as the Caffe implementation of SqueezeNet.
Below is my MATLAB code. I am using an image datastore object with 10 classes, stored in "indsRand10.mat", which is a subset of the ImageNet dataset.
%% SqueezeNet network
% Author: Bhushan Muthiyan
%
% Load the image datastore saved in indsRand10.mat (the MAT-file is expected
% to contain a variable named imdsTrain) and split it into training and test
% subsets. splitEachLabel with an integer count takes that many files from
% EACH label, so the counts below are per class, not totals.
imdbPath = fullfile(pwd, 'indsRand10.mat');
if ~exist(imdbPath, 'file')
    % Every later step needs imdb; stop here with a clear message instead of
    % failing further down with "Undefined variable imdb".
    error('Image datastore file not found: %s', imdbPath);
end
imdb = load(imdbPath);
trainingNumFiles = 768;   % training images per class
valNumFiles = 64;         % held-out (test) images per class
rng(1) % For reproducibility
[imdb.trainDigitData, imdb.testDigitData] = splitEachLabel(imdb.imdsTrain, ...
    trainingNumFiles, 'randomize');
% From the remaining files, keep valNumFiles per class as the test set; the
% rest of the remainder is not used.
imdb.testDigitData = splitEachLabel(imdb.testDigitData, ...
    valNumFiles, 'randomize');
% Preview a random sample of up to 20 training images in a 4x5 grid.
numImages = numel(imdb.trainDigitData.Files);
numShow = min(20, numImages);   % randperm(n,k) errors when k > n
idx = randperm(numImages, numShow);
figure
for k = 1:numShow               % k, not i, to avoid shadowing the imaginary unit
    subplot(4, 5, k)
    I = readimage(imdb.trainDigitData, idx(k));
    imshow(I)
end
% Sequential backbone of the vanilla SqueezeNet v1.0 model. Each fire module
% below is a 1x1 squeeze convolution + ReLU, followed by a 1x1 expand
% convolution + ReLU that feeds input 1 of a depth concatenation layer. The
% parallel 3x3 expand branch of each module is connected to the second input
% of that concatenation layer after the layer graph is built.
%
% Spatial size for a 224x224 input: conv_1 ('same', stride 2) -> 112,
% pool_1 -> 55, pool_2 -> 27, pool_3 -> 13, so the 13x13 average pooling
% reduces each class map to 1x1.
numClasses = numel(categories(imdb.trainDigitData.Labels));
layers = [
    imageInputLayer([224 224 3],'Name','input')
    convolution2dLayer(7,96,'Padding','same','Stride',2,'Name','conv_1')
    reluLayer('Name','relu_1')
    maxPooling2dLayer(3,'Stride',2,'Name','pool_1')
    % fire 2
    convolution2dLayer(1,16,'Padding','same','Stride',1,'Name','conv_2')    % squeeze
    reluLayer('Name','relu_11')
    convolution2dLayer(1,64,'Padding','same','Stride',1,'Name','conv_3')    % expand 1x1
    reluLayer('Name','relu_2')
    depthConcatenationLayer(2,'Name','concat_1')
    % fire 3
    convolution2dLayer(1,16,'Padding','same','Stride',1,'Name','conv_6')    % squeeze
    reluLayer('Name','relu_12')
    convolution2dLayer(1,64,'Padding','same','Stride',1,'Name','conv_7')    % expand 1x1
    reluLayer('Name','relu_3')
    depthConcatenationLayer(2,'Name','concat_3')
    maxPooling2dLayer(3,'Stride',2,'Name','pool_2')
    % fire 4
    convolution2dLayer(1,32,'Padding','same','Stride',1,'Name','conv_9')    % squeeze
    reluLayer('Name','relu_13')
    convolution2dLayer(1,128,'Padding','same','Stride',1,'Name','conv_10')  % expand 1x1
    reluLayer('Name','relu_4')
    depthConcatenationLayer(2,'Name','concat_5')
    % fire 5
    convolution2dLayer(1,32,'Padding','same','Stride',1,'Name','conv_12')   % squeeze
    reluLayer('Name','relu_14')
    convolution2dLayer(1,128,'Padding','same','Stride',1,'Name','conv_13')  % expand 1x1
    reluLayer('Name','relu_5')
    depthConcatenationLayer(2,'Name','concat_6')
    maxPooling2dLayer(3,'Stride',2,'Name','pool_3')
    % fire 6
    convolution2dLayer(1,48,'Padding','same','Stride',1,'Name','conv_15')   % squeeze
    reluLayer('Name','relu_15')
    convolution2dLayer(1,192,'Padding','same','Stride',1,'Name','conv_16')  % expand 1x1
    reluLayer('Name','relu_6')
    depthConcatenationLayer(2,'Name','concat_8')
    % fire 7
    convolution2dLayer(1,48,'Padding','same','Stride',1,'Name','conv_18')   % squeeze
    reluLayer('Name','relu_16')
    convolution2dLayer(1,192,'Padding','same','Stride',1,'Name','conv_19')  % expand 1x1
    reluLayer('Name','relu_7')
    depthConcatenationLayer(2,'Name','concat_9')
    % fire 8
    convolution2dLayer(1,64,'Padding','same','Stride',1,'Name','conv_21')   % squeeze
    reluLayer('Name','relu_17')
    convolution2dLayer(1,256,'Padding','same','Stride',1,'Name','conv_22')  % expand 1x1
    reluLayer('Name','relu_8')
    depthConcatenationLayer(2,'Name','concat_11')
    % fire 9
    convolution2dLayer(1,64,'Padding','same','Stride',1,'Name','conv_24')   % squeeze
    reluLayer('Name','relu_18')
    convolution2dLayer(1,256,'Padding','same','Stride',1,'Name','conv_25')  % expand 1x1
    % Rectify the 1x1 expand output like every other fire module and like
    % this module's own 3x3 branch (relu_26); without it, half of concat_12
    % would be unrectified.
    reluLayer('Name','relu_9')
    depthConcatenationLayer(2,'Name','concat_12')
    dropoutLayer(0.5,'Name','Drop_1')
    % Classifier: 1x1 conv to one map per class, ReLU, global average pooling.
    convolution2dLayer(1,numClasses,'Padding','same','Stride',1,'Name','conv_27')
    reluLayer('Name','relu_10')
    averagePooling2dLayer(13,'Stride',1,'Name','avg_pool_4')
    softmaxLayer('Name','softmax')
    classificationLayer('Name','classOutput')];
% Turn the sequential layer array into a layer graph, then attach the 3x3
% expand branch of each fire module: squeeze ReLU -> 3x3 conv -> ReLU ->
% second input of the module's depth concatenation layer.
lgraph = layerGraph(layers);
figure
plot(lgraph)

% One row per fire module:
% {squeeze ReLU, 3x3 conv name, #filters, branch ReLU name, concat input}
branchSpec = { ...
    'relu_11', 'conv_4',  64,  'relu_19', 'concat_1/in2'
    'relu_12', 'conv_8',  64,  'relu_20', 'concat_3/in2'
    'relu_13', 'conv_11', 128, 'relu_21', 'concat_5/in2'
    'relu_14', 'conv_14', 128, 'relu_22', 'concat_6/in2'
    'relu_15', 'conv_17', 192, 'relu_23', 'concat_8/in2'
    'relu_16', 'conv_20', 192, 'relu_24', 'concat_9/in2'
    'relu_17', 'conv_23', 256, 'relu_25', 'concat_11/in2'
    'relu_18', 'conv_26', 256, 'relu_26', 'concat_12/in2'};
nBranch = size(branchSpec, 1);

% Insert every 3x3 convolution first, then every branch ReLU, keeping the
% same layer insertion order as adding them one at a time.
for b = 1:nBranch
    lgraph = addLayers(lgraph, convolution2dLayer(3, branchSpec{b,3}, ...
        'Padding', 1, 'Stride', 1, 'Name', branchSpec{b,2}));
end
for b = 1:nBranch
    lgraph = addLayers(lgraph, reluLayer('Name', branchSpec{b,4}));
end

% Wire each branch in module order.
for b = 1:nBranch
    lgraph = connectLayers(lgraph, branchSpec{b,1}, branchSpec{b,2});
    lgraph = connectLayers(lgraph, branchSpec{b,2}, branchSpec{b,4});
    lgraph = connectLayers(lgraph, branchSpec{b,4}, branchSpec{b,5});
end

figure
plot(lgraph);
% Train the network from scratch with SGDM and report test-set accuracy.
% LearnRateDropFactor and LearnRateDropPeriod only take effect when
% LearnRateSchedule is 'piecewise'; without it the learning rate would stay
% at InitialLearnRate for all 25 epochs.
optionsTransfer = trainingOptions('sgdm', ...
    'MaxEpochs', 25, ...
    'MiniBatchSize', 64, ...
    'InitialLearnRate', 0.04, ...
    'LearnRateSchedule', 'piecewise', ...
    'LearnRateDropFactor', 0.2, ...
    'LearnRateDropPeriod', 5, ...
    'Plots', 'training-progress', ...
    'ExecutionEnvironment', 'auto');
netTransfer = trainNetwork(imdb.trainDigitData, lgraph, optionsTransfer);

% Classify the held-out split and compute the fraction predicted correctly.
YPred = classify(netTransfer, imdb.testDigitData);
YTest = imdb.testDigitData.Labels;
accuracy = mean(YPred == YTest);
fprintf('accuracy = %f\n', accuracy);
% Show up to 20 random test images in a new 4x5 grid. A fresh figure keeps
% subplot from deleting the axes of whichever figure is current (e.g. the
% layer-graph plot).
numImages = numel(imdb.testDigitData.Files);
numShow = min(20, numImages);   % randperm(n,k) errors when k > n
idx = randperm(numImages, numShow);
figure
for k = 1:numShow               % k, not i, to avoid shadowing the imaginary unit
    subplot(4, 5, k)
    I = readimage(imdb.testDigitData, idx(k));
    imshow(I)
end
Can someone tell me the reason behind this?
Attached is an image of the vanilla SqueezeNet model structure.

  0 Comments

Sign in to comment.

Answers (1)

Mickaël Tits
Mickaël Tits on 14 Nov 2017
Hi,
If I understand correctly, you are trying to train your SqueezeNet model from scratch with only 768 training images per class? You need a pretrained model if you want it to have a chance of working.
You can get a pretrained SqueezeNet here and use it for transfer learning as you wish: https://github.com/titsitits/Squeezenet-Matlab-Keras
Mickaël Tits

  1 Comment

BHUSHAN MUTHIYAN
BHUSHAN MUTHIYAN on 19 Nov 2017
Hello Mickaël,
I had a look at the link you provided.
However, the .json file description says that the Keras model generated there has a 3x3x64 first convolution layer, whereas the original Caffe implementation uses a 7x7x96 filter.
Could you please provide the exact implementation of the SqueezeNet model (.h5 file) that matches the original SqueezeNet implementation?
Thanks!!

Sign in to comment.

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!

Translated by