(5) EPOCH数を増やして正解率上昇

実験内容

シンプル構成初版では、1EPOCHで90%前後の正解率だった。
http://nnet.dogrow.net/?p=43

今回は EPOCH数を増やして正解率の向上を目論む。
対象は、シンプル構成初版の実験で 90.8%(1EPOCH)のスコアを出した [784]-[64]-[10]の層構成とする。
シンプル構成初版プログラムからの変更は、学習とテストを指定EPOCH数分繰り返すようにしただけ。

実験結果

初回の正解率が 90.8%、34回目には 96.5%まで上昇した。
http://nnet.dogrow.net/wp-content/uploads/2013/10/20131006_01.png

EPOCH No.1
[ 0]  949 /  980 ( 96.8%) 
[ 1] 1109 / 1135 ( 97.7%) 
[ 2]  907 / 1032 ( 87.9%) 
[ 3]  901 / 1010 ( 89.2%) 
[ 4]  898 /  982 ( 91.4%) 
[ 5]  745 /  892 ( 83.5%) 
[ 6]  898 /  958 ( 93.7%) 
[ 7]  929 / 1028 ( 90.4%) 
[ 8]  859 /  974 ( 88.2%) 
[ 9]  888 / 1009 ( 88.0%) 
Total  9083 / 10000 ( 90.8%)
EPOCH No.34
[ 0]  966 /  980 ( 98.6%) 
[ 1] 1120 / 1135 ( 98.7%) 
[ 2]  991 / 1032 ( 96.0%) 
[ 3]  974 / 1010 ( 96.4%) 
[ 4]  956 /  982 ( 97.4%) 
[ 5]  860 /  892 ( 96.4%) 
[ 6]  929 /  958 ( 97.0%) 
[ 7]  983 / 1028 ( 95.6%) 
[ 8]  923 /  974 ( 94.8%) 
[ 9]  953 / 1009 ( 94.4%) 
Total  9655 / 10000 ( 96.5%)

プログラムのソースコード

シンプル構成初版からの変更点は下記の2点のみ。
(1)コマンドライン引数でEPOCH数を指定可能とした。
(2)学習、テストを指定EPOCH数分繰り返すようにした。

NNET_control.m ※シンプル構成初版からの変更はこの1ファイルのみ。

function NNET_control( num_unit_of_each_layer, num_EPOCH )

  % 学習画像・ラベル、テスト画像・ラベルをファイルから読み込み
  [train_img, train_lbl] = load_MNIST( '../data/train-images-idx3-ubyte', '../data/train-labels-idx1-ubyte' );
  [test_img,  test_lbl ] = load_MNIST( '../data/t10k-images-idx3-ubyte',  '../data/t10k-labels-idx1-ubyte'  );

  % 各画像データを 0.0~1.0の範囲に正規化
  train_img = train_img / 255;
  test_img  = test_img  / 255;

  % 指定された層数、ユニット数でニューラルネットワークを作成
  nn = NNET_setup( num_unit_of_each_layer );

  % 正解率記録用領域を確保
  accuracy = zeros(num_EPOCH,1);

  % 指定EPOCH回数だけ繰り返す
  for epoch=1 : num_EPOCH

    % 学習実行
    nn = NNET_learn( nn, train_img, train_lbl );

    % テスト実行
    result = NNET_test( nn, test_img, test_lbl );

    % テスト結果を表示
    printf('\nEPOCH No.%d\n', epoch);
    for i=1: 10
      printf('[%2d] %4d / %4d (%5.1f%%) \n', i-1, result(i,2), result(i,1), result(i,2)/result(i,1)*100);
    end
    sum_result = sum(result,1);
    printf('Total %5d / %5d (%5.1f%%) \n', sum_result(2), sum_result(1), sum_result(2)/sum_result(1)*100);
    fflush(1);

    % 正解率をCSVファイルに出力
    if exist('accuracy.csv','file')~=0
      delete('accuracy.csv');
    end
    accuracy(epoch) = sum_result(2)/sum_result(1);  % 正解率
    csvwrite('accuracy.csv', accuracy);

  end
end

シンプル構成初版のプログラムソースコードはこちら。
http://nnet.dogrow.net/?p=43

00002_simpleNN_02

Leave a Comment


NOTE - You can use these HTML tags and attributes:
<a href="" title=""> <abbr title=""> <acronym title=""> <b> <blockquote cite=""> <cite> <code> <del datetime=""> <em> <i> <q cite=""> <s> <strike> <strong>

*