ConvNetJS で MNIST を分類2 - 畳み込みニューラルネット
前回 の続きです。 今回は畳み込みニューラルネットを使って MNIST の手書き数字を分類してみます。
- Node.js 5.8.0
- ConvNetJS 0.3.0
ソースは http://github.com/fits/try_samples/tree/master/blog/20160328/
準備
誤差・正解率のグラフ化と畳み込みフィルタの画像化を行うため、前回の構成へ d3
等を追加しています。
package.json
{ "name": "convnetjs_mnist_conv_sample", "version": "1.0.0", "description": "", "main": "index.js", "dependencies": { "basic-csv": "0.0.2", "bluebird": "^3.3.4", "convnetjs": "^0.3.0", "d3": "^3.5.16", "jsdom": "^8.1.0", "shuffle-array": "^0.1.2" } }
インストール例
> npm install
(b) 畳み込みニューラルネット
畳み込みニューラルネットでは畳み込み層とプーリング層を組み合わせてレイヤーを構築します。 (実際は全結合層も使います)
名称 | 処理 | ConvNetJS の layer_type |
---|---|---|
畳み込み層 | 入力画像へフィルターを適用し特徴量を抽出 | conv |
プーリング層 | 入力画像へプーリング演算(フィルター内)を適用 | pool |
ConvNetJS の畳み込み層・プーリング層は以下のように設定します。
sx
とsy
でフィルターのサイズを指定 (sy
を省略するとsx
と同じ値を適用)pad
で入力画像の周囲にゼロパディング(0埋め)する数を指定stride
でフィルタの適用位置を縦横に移動する数を指定 (1 の場合は縦横に 1画素ずつずらしてフィルターを適用)
畳み込み層では filters
で適用するフィルターの数を指定します。
プーリング層では 「最大プーリング」 (フィルターの値の最大値を採用) を行うようになっており、今回試したバージョンでは (「平均プーリング」等へ) プーリング方法を変更する機能は無さそうでした。
今回は、以下のように畳み込み層・プーリング層が 2回続くような構成にしてみました。
create_layer_conv.js (畳み込みニューラルネットのモデル構築と保存処理)
'use strict'; // 畳み込み層の活性化関数 const act = process.argv[2]; // 出力ファイル名 const jsonDestFile = process.argv[3]; require('./save_model').saveModel( [ { type: 'input', out_sx: 28, out_sy: 28, out_depth: 1 }, // 1つ目の畳み込み層 { type: 'conv', sx: 5, filters: 8, stride: 1, pad: 2, activation: act }, // 1つ目のプーリング層 { type: 'pool', sx: 2, stride: 2 }, // 2つ目の畳み込み層 { type: 'conv', sx: 5, filters: 16, stride: 1, pad: 2, activation: act }, // 2つ目のプーリング層 { type: 'pool', sx: 3, stride: 3 }, { type: 'softmax', num_classes: 10 } ], jsonDestFile );
活性化関数へ relu を指定した場合の内部的なレイヤー構成は以下のようになりました。
学習モデルの内部的なレイヤー構成例
input -> conv -> relu -> pool -> conv -> relu -> pool -> fc -> softmax
各レイヤーの出力サイズは以下の通りです。
layer_type | out_sx | out_sy | out_depth |
---|---|---|---|
input | 28 | 28 | 1 |
conv | 28 | 28 | 8 |
relu | 28 | 28 | 8 |
pool | 14 | 14 | 8 |
conv | 14 | 14 | 16 |
relu | 14 | 14 | 16 |
pool | 4 | 4 | 16 |
fc | 1 | 1 | 10 |
softmax | 1 | 1 | 10 |
学習と評価
前回 作成した共通処理 (learn_mnist.js 等) を使って学習と評価を実施します。
学習回数を前回と同じ 15回にすると相当時間がかかってしまうので、今回は以下の 4種類で学習・評価を試してみました。
- 活性化関数 = relu, 学習回数 = 5
- 活性化関数 = relu, 学習回数 = 10
- 活性化関数 = sigmoid, 学習回数 = 5
- 活性化関数 = sigmoid, 学習回数 = 10
学習回数以外は前回と同じパラメータを使います。
- 学習回数 = 15
- バッチサイズ = 100
- 学習係数 = 0.001
- 学習係数の決定方法 = adadelta
処理時間は学習回数 5回で 1.5時間、10回で 3時間程度でした。
PC の性能にも依存すると思いますが、1つの CPU で処理するので比較的遅めだと思います。
1. 活性化関数 = relu, 学習回数 = 5 (バッチサイズ = 100, 学習係数 = 0.001, adadelta)
> node create_layer_conv.js relu models/conv_relu.json > node learn_mnist.js 5 100 0.001 adadelta models/conv_relu.json results/b-1_conv_relu.json > logs/b-1_conv_relu.log > node validate_mnist.js results/b-1_conv_relu.json data size: 10000 accuracy: 0.9785
学習時の誤差と正解率
学習後のフィルター
2. 活性化関数 = relu, 学習回数 = 10 (バッチサイズ = 100, 学習係数 = 0.001, adadelta)
> node learn_mnist.js 10 100 0.001 adadelta models/conv_relu.json results/b-2_conv_relu.json > logs/b-2_conv_relu.log > node validate_mnist.js results/b-2_conv_relu.json data size: 10000 accuracy: 0.9786
学習時の誤差と正解率
学習後のフィルター
3. 活性化関数 = sigmoid, 学習回数 = 5 (バッチサイズ = 100, 学習係数 = 0.001, adadelta)
> node create_layer_conv.js sigmoid models/conv_sigmoid.json > node learn_mnist.js 5 100 0.001 adadelta models/conv_sigmoid.json results/b-3_conv_sigmoid.json > logs/b-3_conv_sigmoid.log > node validate_mnist.js results/b-3_conv_sigmoid.json data size: 10000 accuracy: 0.9812
学習時の誤差と正解率
学習後のフィルター
4. 活性化関数 = sigmoid, 学習回数 = 10 (バッチサイズ = 100, 学習係数 = 0.001, adadelta)
> node learn_mnist.js 10 100 0.001 adadelta models/conv_sigmoid.json results/b-4_conv_sigmoid.json > logs/b-4_conv_sigmoid.log > node validate_mnist.js results/b-4_conv_sigmoid.json data size: 10000 accuracy: 0.9862
学習時の誤差と正解率
学習後のフィルター
結果のまとめ
番号 | 活性化関数 | 学習回数 | 正解率 |
---|---|---|---|
1 | relu | 5 | 0.9785 |
2 | relu | 10 | 0.9786 |
3 | sigmoid | 5 | 0.9812 |
4 | sigmoid | 10 | 0.9862 |
前回の結果よりも高い正解率となりました。
補足
(1) 誤差と正解率のグラフ化
誤差と正解率のログは以下のようなスクリプトでグラフ化しました。
誤差の値が Infinity
となる事があったので ※、その場合はとりあえず固定値(以下では 1000)で置換するようにしています。
※ ただし、Infinity となったのは前回の階層型ニューラルネットの結果で、 畳み込みニューラルネットの結果では発生していません
line_chart.js
'use strict'; const Promise = require('bluebird'); const d3 = require('d3'); const jsdom = require('jsdom').jsdom; const readCSV = Promise.promisify(require('basic-csv').readCSV); const w = 300; const h = 300; const margin = { top: 20, bottom: 50, left: 50, right: 20 }; const xLabels = ['バッチ回数', 'バッチ回数']; const yLabels = ['誤差', '正解率']; readCSV(process.argv[2]).then( ds => { const document = jsdom(); const chartLayout = (xnum, w, h, margin) => { const borderWidth = w + margin.left + margin.right; const borderHeight = h + margin.top + margin.bottom; const svg = d3.select(document.body).append('svg') .attr('xmlns', 'http://www.w3.org/2000/svg') .attr('width', xnum * borderWidth) .attr('height', borderHeight); return Array(xnum).fill(0).map( (n, i) => svg.append('g') .attr('transform', `translate(${i * borderWidth + margin.left}, ${margin.top})`) ); }; const xDomain = [0, ds.length]; const yDomain = [1, 0]; // スケールの定義 const x = d3.scale.linear().range([0, w]).domain(xDomain); const y = d3.scale.linear().range([0, h]).domain(yDomain); // 軸の定義 const xAxis = d3.svg.axis().scale(x).orient('bottom').ticks(5); const yAxis = d3.svg.axis().scale(y).orient('left'); // 折れ線の作成 const createLine = d3.svg.line() .x((d, i) => x(i + 1)) .y(d => { // Infinity の際に固定値を設定 if (d == 'Infinity') { d = 1000; } return y(d); }); // 折れ線の描画 const drawLine = (g, data, colIndex, color) => { g.append('path') .attr('d', createLine(data.map(d => d[colIndex]))) .attr('stroke', color) .attr('fill', 'none'); }; const gs = chartLayout(2, w, h, margin); // X・Y軸の描画 gs.forEach( (g, i) => { g.append('g') .attr('transform', `translate(0, ${h})`) .call(xAxis) .append('text') .attr('x', w / 2) .attr('y', 35) .style('font-family', 'Sans') .text(xLabels[i]); g.append('g') .call(yAxis) .append('text') .attr('x', -h / 2) .attr('y', -35) .attr('transform', 'rotate(-90)') .style('font-family', 'Sans') .text(yLabels[i]); }); drawLine(gs[0], ds, 2, 'blue'); drawLine(gs[1], ds, 3, 'blue'); return document.body.innerHTML; }).then( html => console.log(html) );
このスクリプトを使って svg ファイルへ出力し、ImageMagick で png ファイルへ変換しました。
実行例
> node line_chart.js logs/b-1_conv_relu.log > img/b-1.svg > convert img/b-1.svg img/b-1.png
(2) 畳み込み層のフィルターを svg 化
学習後の畳み込み層のフィルターを以下のようなスクリプトで画像化(svg)してみました。
実際のフィルターサイズは 5x5 と小さくて分かり難いので、d3.scale
を使って 50x50 へ拡大しています。
また、フィルターを可視化するため、d3.scale
でフィルターの値が 0 ~ 255 となるように変換しています。
conv_filter_svg.js
'use strict'; const Promise = require('bluebird'); const convnetjs = require('convnetjs'); const d3 = require('d3'); const jsdom = require('jsdom').jsdom; const readFile = Promise.promisify(require('fs').readFile); const size = 50; const margin = 5; const modelJsonFile = process.argv[2]; // フィルター内の最小値と最大値を抽出 const valueRange = fs => fs.reduce( (acc, f) => [ Math.min(acc[0], Math.min.apply(null, f.w)), Math.max(acc[1], Math.max.apply(null, f.w)) ], [0, 0] ); readFile(modelJsonFile).then( json => { const net = new convnetjs.Net(); net.fromJSON(JSON.parse(json)); return net.layers; }).then( layers => layers.reduce( (acc, v) => { // 畳み込み層のフィルターを抽出 if (v.layer_type == 'conv' && v.filters) { acc.push( v.filters ); } return acc; }, []) ).then( filtersList => { const document = jsdom(); const svg = d3.select(document.body) .append('svg') .attr('xmlns', 'http://www.w3.org/2000/svg'); filtersList.forEach( (fs, j) => { const yPos = (size + margin) * j; // フィルターの数値を 0 ~ 255 の値へ変換 const pixelScale = d3.scale.linear() .range([0, 255]).domain(valueRange(fs)); fs.forEach( (f, i) => { const xPos = (size + margin) * i; const g = svg.append('g') .attr('transform', `translate(${xPos}, ${yPos})`); const xScale = d3.scale.linear() .range([0, size]).domain([0, f.sx]); const yScale = d3.scale.linear() .range([0, size]).domain([0, f.sy]); for (let y = 0; y < f.sy; y++) { for (let x = 0; x < f.sx; x++) { const p = pixelScale( f.get(x, y, 0) ); g.append('rect') .attr('x', xScale(x)) .attr('y', yScale(y)) .attr('width', xScale(1)) .attr('height', yScale(1)) .attr('fill', d3.rgb(p, p, p)); } } }); }); return document.body.innerHTML; }).then( svg => console.log(svg) ).catch( e => console.error(e) );
こちらも、svg ファイルとして出力し、ImageMagick で png ファイルへ変換しました。
実行例
> node conv_filter_svg.js results/b-1_conv_relu.json > img/b-1_filters.svg > convert img/b-1_filters.svg img/b-1_filters.png