2019年7月4日木曜日

Matlab : DeepLearning入門(ネットワーク構築)

畳み込みニューラルネットワーク(ConvNet)を使います。一からネットワークを作っていきます。Matlabの資料を参考にしていきます。
https://jp.mathworks.com/help/deeplearning/examples/create-simple-deep-learning-network-for-classification.html
今回は資料にもあるように、手書き文字の分類をしていきます。

(1) データの用意
http://yann.lecun.com/exdb/mnist/
よりデータをダウンロードして展開して使えるようにします。
https://www.kaggle.com/
も活用すると便利です。
MatlabのLiveScriptファイルを用意します。今回はsimple01.mlxとします。
データの取り込みをします。
imds = imageDatastore('/Users/mizuno/Documents/MATLAB/deeplearning/SimpleDeepLearning/trainingSet/','IncludeSubfolders',true,'LabelSource','foldernames');
view raw gistfile1.txt hosted with ❤ by GitHub

 ラベル数をカウントします。
labelCount = countEachLabel(imds)
view raw gistfile1.txt hosted with ❤ by GitHub


約42000のデータが取り込まれています。次にどのようなデータなのか確認します。
perm = randperm(40000,20);
montage(imds, 'Indices', perm);
view raw gistfile1.txt hosted with ❤ by GitHub


データサイズを確認します。今回は28×28×1(グレースケール)になります。
img = readimage(imds,1);
size(img)
view raw gistfile1.txt hosted with ❤ by GitHub

学習データと検証データに分割します。今回は8:2に分けます。
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.8);
view raw gistfile1.txt hosted with ❤ by GitHub


(2) ネットワークアーキテクチャの定義
ネットワーク全体の役割の概念を確認するには、下記を参考にします。
https://jp.mathworks.com/help/deeplearning/ug/introduction-to-convolutional-neural-networks.html
今回のネットワークを定義します。まず最小のネットワークでやってみます。
layers = [imageInputLayer([28 28 1])
convolution2dLayer(5,20)
reluLayer
maxPooling2dLayer(2,'Stride',2)
fullyConnectedLayer(10)
softmaxLayer
classificationLayer];
view raw gistfile1.txt hosted with ❤ by GitHub


この中で
convolution2dLayer(5,20)
・畳み込み層の最初の引数は filterSize です。これは、イメージのスキャン時に学習関数によって使用されるフィルターの高さと幅を示します。この例では、5 という数字によってフィルター サイズが 5 x 5 であることを示しています。
・2 つ目の引数 numFilters はフィルターの数です。
・convolution2dLayer(3,8,'Padding','same')のようにすると、'Padding' を使用して、入力の特徴マップにパディングを追加
・既定のストライドが 1 の畳み込み層の場合、'same' パディングによって空間の出力サイズが入力サイズ

batchNormalizationLayer
 バッチ正規化層は、ニューラル ネットワークを通じて伝播される活性化と勾配を正規化します。これにより、ネットワークの学習は簡単な最適化問題になります。

reluLayer
最も一般的な活性化関数は、正規化線形ユニット (ReLU) です。

maxPooling2dLayer(2,'Stride',2)
この例では、矩形領域のサイズは [2,2] です。名前と値のペアの引数 'Stride' は、入力に沿ってスキャンするときに学習関数が取るステップ サイズを指定します。[2,2]に最大値をとることで情報を集約しています。

fullyConnectedLayer(10)
畳み込み層とダウンサンプリング層の後には、1 つ以上の全結合層を配置します。最後の全結合層の OutputSize パラメーターは、ターゲット データのクラスの数と等しくなります。

softmaxLayer
ソフトマックス活性化関数は、全結合層の出力を正規化します。ソフトマックス層の出力は合計が 1 になる正の数値で構成されており、分類層で分類の確率として使用できます。

classificationLayer
 最後の層は分類層です。この層は、ソフトマックス活性化関数によって各入力について返された確率を使用して、互いに排他的なクラスの 1 つに入力を割り当て、損失を計算します。
https://jp.mathworks.com/help/deeplearning/ug/layers-of-a-convolutional-neural-network.html

学習オプションを指定します。今回の学習オプションは例のものを使います。学習率を0.001としています。(ちなみに学習オプションを何も指定しない場合、今回収束しなくなります。)
options = trainingOptions('sgdm', ...
'InitialLearnRate',0.001, ...
'MaxEpochs',4, ...
'Shuffle','every-epoch', ...
'ValidationData',imdsValidation, ...
'ValidationFrequency',30, ...
'Verbose',false, ...
'Plots','training-progress');
view raw gistfile1.txt hosted with ❤ by GitHub

https://jp.mathworks.com/help/deeplearning/ug/setting-up-parameters-and-training-of-a-convnet.html

学習を実行します。
net = trainNetwork(imdsTrain,layers,options);
view raw gistfile1.txt hosted with ❤ by GitHub

精度が良くないです。特にLossが大きいです。

(3)検証
YPred = classify(net,imdsValidation);
YValidation = imdsValidation.Labels;
accuracy = sum(YPred == YValidation)/numel(YValidation)
view raw gistfile1.txt hosted with ❤ by GitHub

で検証をします。
89%となりもっと精度を上げられます。
ヒートマップを書きます。
[numberconf, numbernames] = confusionmat(YValidation, YPred)
heatmap(numbernames, numbernames, numberconf);
view raw gistfile1.txt hosted with ❤ by GitHub

ここまでのコード
imds = imageDatastore('/Users/mizuno/Documents/MATLAB/deeplearning/SimpleDeepLearning/trainingSet/','IncludeSubfolders',true,'LabelSource','foldernames');
labelCount = countEachLabel(imds)
perm = randperm(40000,20);
montage(imds, 'Indices', perm);
img = readimage(imds,1);
size(img)
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.8);
layers = [imageInputLayer([28 28 1])
convolution2dLayer(5,20)
reluLayer
maxPooling2dLayer(2,'Stride',2)
fullyConnectedLayer(10)
softmaxLayer
classificationLayer];
options = trainingOptions('sgdm', ...
'InitialLearnRate',0.001, ...
'MaxEpochs',4, ...
'Shuffle','every-epoch', ...
'ValidationData',imdsValidation, ...
'ValidationFrequency',30, ...
'Verbose',false, ...
'Plots','training-progress');
net = trainNetwork(imdsTrain,layers,options);
YPred = classify(net,imdsValidation);
YValidation = imdsValidation.Labels;
accuracy = sum(YPred == YValidation)/numel(YValidation)
[numberconf, numbernames] = confusionmat(YValidation, YPred);
heatmap(numbernames, numbernames, numberconf);
view raw gistfile1.txt hosted with ❤ by GitHub


 (4) 精度の向上
畳み込み層を増やして、精度を向上させます。
変更後のLayerの例です。レイヤが増える度に、フィルタ数を2倍にしています。プーリングサイズは基本[2,2]です。
layers = [imageInputLayer([28 28 1])
convolution2dLayer(3,8,'Padding','same')
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,16,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,32,'Padding','same')
batchNormalizationLayer
reluLayer
fullyConnectedLayer(10)
softmaxLayer
classificationLayer];
view raw gistfile1.txt hosted with ❤ by GitHub

学習してみます。特にLossが減少しています。
精度を確認します。98%となり精度が向上しました。
ここまでのコード
imds = imageDatastore('/Users/mizuno/Documents/MATLAB/deeplearning/SimpleDeepLearning/trainingSet/','IncludeSubfolders',true,'LabelSource','foldernames');
labelCount = countEachLabel(imds)
perm = randperm(40000,20);
montage(imds, 'Indices', perm);
img = readimage(imds,1);
size(img)
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.8);
layers = [imageInputLayer([28 28 1])
convolution2dLayer(3,8,'Padding','same')
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,16,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,32,'Padding','same')
batchNormalizationLayer
reluLayer
fullyConnectedLayer(10)
softmaxLayer
classificationLayer];
options = trainingOptions('sgdm', ...
'InitialLearnRate',0.001, ...
'MaxEpochs',4, ...
'Shuffle','every-epoch', ...
'ValidationData',imdsValidation, ...
'ValidationFrequency',30, ...
'Verbose',false, ...
'Plots','training-progress');
net = trainNetwork(imdsTrain,layers,options);
YPred = classify(net,imdsValidation);
YValidation = imdsValidation.Labels;
accuracy = sum(YPred == YValidation)/numel(YValidation)
[numberconf, numbernames] = confusionmat(YValidation, YPred);
heatmap(numbernames, numbernames, numberconf);
view raw gistfile1.txt hosted with ❤ by GitHub

0 件のコメント:

コメントを投稿