https://jp.mathworks.com/help/deeplearning/examples/create-simple-deep-learning-network-for-classification.html
今回は資料にもあるように、手書き文字の分類をしていきます。
(1) データの用意
http://yann.lecun.com/exdb/mnist/
よりデータをダウンロードして展開して使えるようにします。
https://www.kaggle.com/
も活用すると便利です。
MatlabのLiveScriptファイルを用意します。今回はsimple01.mlxとします。
データの取り込みをします。
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
imds = imageDatastore('/Users/mizuno/Documents/MATLAB/deeplearning/SimpleDeepLearning/trainingSet/','IncludeSubfolders',true,'LabelSource','foldernames'); |
ラベル数をカウントします。
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
labelCount = countEachLabel(imds) |
約42000のデータが取り込まれています。次にどのようなデータなのか確認します。
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
perm = randperm(40000,20); | |
montage(imds, 'Indices', perm); |
データサイズを確認します。今回は28×28×1(グレースケール)になります。
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
img = readimage(imds,1); | |
size(img) |
学習データと検証データに分割します。今回は8:2に分けます。
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.8); |
(2) ネットワークアーキテクチャの定義
ネットワーク全体の役割の概念を確認するには、下記を参考にします。
https://jp.mathworks.com/help/deeplearning/ug/introduction-to-convolutional-neural-networks.html
今回のネットワークを定義します。まず最小のネットワークでやってみます。
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
layers = [imageInputLayer([28 28 1]) | |
convolution2dLayer(5,20) | |
reluLayer | |
maxPooling2dLayer(2,'Stride',2) | |
fullyConnectedLayer(10) | |
softmaxLayer | |
classificationLayer]; |
この中で
convolution2dLayer(5,20)
・畳み込み層の最初の引数は filterSize です。これは、イメージのスキャン時に学習関数によって使用されるフィルターの高さと幅を示します。この例では、5 という数字によってフィルター サイズが 5 x 5 であることを示しています。
・2 つ目の引数 numFilters はフィルターの数です。
・convolution2dLayer(3,8,'Padding','same')のようにすると、'Padding' を使用して、入力の特徴マップにパディングを追加
・既定のストライドが 1 の畳み込み層の場合、'same' パディングによって空間の出力サイズが入力サイズ
batchNormalizationLayer
バッチ正規化層は、ニューラル ネットワークを通じて伝播される活性化と勾配を正規化します。これにより、ネットワークの学習は簡単な最適化問題になります。
reluLayer
最も一般的な活性化関数は、正規化線形ユニット (ReLU) です。
maxPooling2dLayer(2,'Stride',2)
この例では、矩形領域のサイズは [2,2] です。名前と値のペアの引数 'Stride' は、入力に沿ってスキャンするときに学習関数が取るステップ サイズを指定します。[2,2]に最大値をとることで情報を集約しています。
fullyConnectedLayer(10)
畳み込み層とダウンサンプリング層の後には、1 つ以上の全結合層を配置します。最後の全結合層の OutputSize パラメーターは、ターゲット データのクラスの数と等しくなります。
softmaxLayer
ソフトマックス活性化関数は、全結合層の出力を正規化します。ソフトマックス層の出力は合計が 1 になる正の数値で構成されており、分類層で分類の確率として使用できます。
classificationLayer
最後の層は分類層です。この層は、ソフトマックス活性化関数によって各入力について返された確率を使用して、互いに排他的なクラスの 1 つに入力を割り当て、損失を計算します。
https://jp.mathworks.com/help/deeplearning/ug/layers-of-a-convolutional-neural-network.html
学習オプションを指定します。今回の学習オプションは例のものを使います。学習率を0.001としています。(ちなみに学習オプションを何も指定しない場合、今回収束しなくなります。)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
options = trainingOptions('sgdm', ... | |
'InitialLearnRate',0.001, ... | |
'MaxEpochs',4, ... | |
'Shuffle','every-epoch', ... | |
'ValidationData',imdsValidation, ... | |
'ValidationFrequency',30, ... | |
'Verbose',false, ... | |
'Plots','training-progress'); |
https://jp.mathworks.com/help/deeplearning/ug/setting-up-parameters-and-training-of-a-convnet.html
学習を実行します。
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
net = trainNetwork(imdsTrain,layers,options); |
精度が良くないです。特にLossが大きいです。
(3)検証
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
YPred = classify(net,imdsValidation); | |
YValidation = imdsValidation.Labels; | |
accuracy = sum(YPred == YValidation)/numel(YValidation) |
で検証をします。
89%となりもっと精度を上げられます。
ヒートマップを書きます。
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[numberconf, numbernames] = confusionmat(YValidation, YPred) | |
heatmap(numbernames, numbernames, numberconf); |
ここまでのコード
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
imds = imageDatastore('/Users/mizuno/Documents/MATLAB/deeplearning/SimpleDeepLearning/trainingSet/','IncludeSubfolders',true,'LabelSource','foldernames'); | |
labelCount = countEachLabel(imds) | |
perm = randperm(40000,20); | |
montage(imds, 'Indices', perm); | |
img = readimage(imds,1); | |
size(img) | |
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.8); | |
layers = [imageInputLayer([28 28 1]) | |
convolution2dLayer(5,20) | |
reluLayer | |
maxPooling2dLayer(2,'Stride',2) | |
fullyConnectedLayer(10) | |
softmaxLayer | |
classificationLayer]; | |
options = trainingOptions('sgdm', ... | |
'InitialLearnRate',0.001, ... | |
'MaxEpochs',4, ... | |
'Shuffle','every-epoch', ... | |
'ValidationData',imdsValidation, ... | |
'ValidationFrequency',30, ... | |
'Verbose',false, ... | |
'Plots','training-progress'); | |
net = trainNetwork(imdsTrain,layers,options); | |
YPred = classify(net,imdsValidation); | |
YValidation = imdsValidation.Labels; | |
accuracy = sum(YPred == YValidation)/numel(YValidation) | |
[numberconf, numbernames] = confusionmat(YValidation, YPred); | |
heatmap(numbernames, numbernames, numberconf); |
(4) 精度の向上
畳み込み層を増やして、精度を向上させます。
変更後のLayerの例です。レイヤが増える度に、フィルタ数を2倍にしています。プーリングサイズは基本[2,2]です。
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
layers = [imageInputLayer([28 28 1]) | |
convolution2dLayer(3,8,'Padding','same') | |
reluLayer | |
maxPooling2dLayer(2,'Stride',2) | |
convolution2dLayer(3,16,'Padding','same') | |
batchNormalizationLayer | |
reluLayer | |
maxPooling2dLayer(2,'Stride',2) | |
convolution2dLayer(3,32,'Padding','same') | |
batchNormalizationLayer | |
reluLayer | |
fullyConnectedLayer(10) | |
softmaxLayer | |
classificationLayer]; |
学習してみます。特にLossが減少しています。
精度を確認します。98%となり精度が向上しました。
ここまでのコード
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
imds = imageDatastore('/Users/mizuno/Documents/MATLAB/deeplearning/SimpleDeepLearning/trainingSet/','IncludeSubfolders',true,'LabelSource','foldernames'); | |
labelCount = countEachLabel(imds) | |
perm = randperm(40000,20); | |
montage(imds, 'Indices', perm); | |
img = readimage(imds,1); | |
size(img) | |
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.8); | |
layers = [imageInputLayer([28 28 1]) | |
convolution2dLayer(3,8,'Padding','same') | |
reluLayer | |
maxPooling2dLayer(2,'Stride',2) | |
convolution2dLayer(3,16,'Padding','same') | |
batchNormalizationLayer | |
reluLayer | |
maxPooling2dLayer(2,'Stride',2) | |
convolution2dLayer(3,32,'Padding','same') | |
batchNormalizationLayer | |
reluLayer | |
fullyConnectedLayer(10) | |
softmaxLayer | |
classificationLayer]; | |
options = trainingOptions('sgdm', ... | |
'InitialLearnRate',0.001, ... | |
'MaxEpochs',4, ... | |
'Shuffle','every-epoch', ... | |
'ValidationData',imdsValidation, ... | |
'ValidationFrequency',30, ... | |
'Verbose',false, ... | |
'Plots','training-progress'); | |
net = trainNetwork(imdsTrain,layers,options); | |
YPred = classify(net,imdsValidation); | |
YValidation = imdsValidation.Labels; | |
accuracy = sum(YPred == YValidation)/numel(YValidation) | |
[numberconf, numbernames] = confusionmat(YValidation, YPred); | |
heatmap(numbernames, numbernames, numberconf); |
0 件のコメント:
コメントを投稿