ConvNetJS で MNIST を分類2 - 畳み込みニューラルネット

前回 の続きです。 今回は畳み込みニューラルネットを使って MNIST の手書き数字を分類してみます。

ソースは http://github.com/fits/try_samples/tree/master/blog/20160328/

準備

誤差・正解率のグラフ化と畳み込みフィルタの画像化を行うため、前回の構成へ d3 等を追加しています。

package.json
{
  "name": "convnetjs_mnist_conv_sample",
  "version": "1.0.0",
  "description": "",
  "main": "index.js",
  "dependencies": {
    "basic-csv": "0.0.2",
    "bluebird": "^3.3.4",
    "convnetjs": "^0.3.0",
    "d3": "^3.5.16",
    "jsdom": "^8.1.0",
    "shuffle-array": "^0.1.2"
  }
}
インストール例
> npm install

(b) 畳み込みニューラルネット

畳み込みニューラルネットでは畳み込み層とプーリング層を組み合わせてレイヤーを構築します。 (実際は全結合層も使います)

名称 処理 ConvNetJS の layer_type
畳み込み層 入力画像へフィルターを適用し特徴量を抽出 conv
プーリング層 入力画像へプーリング演算(フィルター内)を適用 pool

ConvNetJS の畳み込み層・プーリング層は以下のように設定します。

  • sxsy でフィルターのサイズを指定 (sy を省略すると sx と同じ値を適用)
  • pad で入力画像の周囲にゼロパディング(0埋め)する数を指定
  • stride でフィルタの適用位置を縦横に移動する数を指定 (1 の場合は縦横に 1画素ずつずらしてフィルターを適用)

畳み込み層では filters で適用するフィルターの数を指定します。

プーリング層では 「最大プーリング」 (フィルターの値の最大値を採用) を行うようになっており、今回試したバージョンでは (「平均プーリング」等へ) プーリング方法を変更する機能は無さそうでした。

今回は、以下のように畳み込み層・プーリング層が 2回続くような構成にしてみました。

create_layer_conv.js (畳み込みニューラルネットのモデル構築と保存処理)
'use strict';

// 畳み込み層の活性化関数
const act = process.argv[2];
// 出力ファイル名
const jsonDestFile = process.argv[3];

require('./save_model').saveModel(
    [
        { type: 'input', out_sx: 28, out_sy: 28, out_depth: 1 },
        // 1つ目の畳み込み層
        { type: 'conv', sx: 5, filters: 8, stride: 1, pad: 2, activation: act },
        // 1つ目のプーリング層
        { type: 'pool', sx: 2, stride: 2 },
        // 2つ目の畳み込み層
        { type: 'conv', sx: 5, filters: 16, stride: 1, pad: 2, activation: act },
        // 2つ目のプーリング層
        { type: 'pool', sx: 3, stride: 3 },
        { type: 'softmax', num_classes: 10 }
    ],
    jsonDestFile
);

活性化関数へ relu を指定した場合の内部的なレイヤー構成は以下のようになりました。

学習モデルの内部的なレイヤー構成例
input -> conv -> relu -> pool -> conv -> relu -> pool -> fc -> softmax

各レイヤーの出力サイズは以下の通りです。

layer_type out_sx out_sy out_depth
input 28 28 1
conv 28 28 8
relu 28 28 8
pool 14 14 8
conv 14 14 16
relu 14 14 16
pool 4 4 16
fc 1 1 10
softmax 1 1 10

学習と評価

前回 作成した共通処理 (learn_mnist.js 等) を使って学習と評価を実施します。

学習回数を前回と同じ 15回にすると相当時間がかかってしまうので、今回は以下の 4種類で学習・評価を試してみました。

  1. 活性化関数 = relu, 学習回数 = 5
  2. 活性化関数 = relu, 学習回数 = 10
  3. 活性化関数 = sigmoid, 学習回数 = 5
  4. 活性化関数 = sigmoid, 学習回数 = 10

学習回数以外は前回と同じパラメータを使います。

  • 学習回数 = 15
  • バッチサイズ = 100
  • 学習係数 = 0.001
  • 学習係数の決定方法 = adadelta

処理時間は学習回数 5回で 1.5時間、10回で 3時間程度でした。

PC の性能にも依存すると思いますが、1つの CPU で処理するので比較的遅めだと思います。

1. 活性化関数 = relu, 学習回数 = 5 (バッチサイズ = 100, 学習係数 = 0.001, adadelta)

> node create_layer_conv.js relu models/conv_relu.json

> node learn_mnist.js 5 100 0.001 adadelta models/conv_relu.json results/b-1_conv_relu.json > logs/b-1_conv_relu.log

> node validate_mnist.js results/b-1_conv_relu.json

data size: 10000
accuracy: 0.9785
学習時の誤差と正解率

f:id:fits:20160328195647p:plain

学習後のフィルター

f:id:fits:20160328195700p:plain

2. 活性化関数 = relu, 学習回数 = 10 (バッチサイズ = 100, 学習係数 = 0.001, adadelta)

> node learn_mnist.js 10 100 0.001 adadelta models/conv_relu.json results/b-2_conv_relu.json > logs/b-2_conv_relu.log

> node validate_mnist.js results/b-2_conv_relu.json

data size: 10000
accuracy: 0.9786
学習時の誤差と正解率

f:id:fits:20160328195714p:plain

学習後のフィルター

f:id:fits:20160328195724p:plain

3. 活性化関数 = sigmoid, 学習回数 = 5 (バッチサイズ = 100, 学習係数 = 0.001, adadelta)

> node create_layer_conv.js sigmoid models/conv_sigmoid.json

> node learn_mnist.js 5 100 0.001 adadelta models/conv_sigmoid.json results/b-3_conv_sigmoid.json > logs/b-3_conv_sigmoid.log

> node validate_mnist.js results/b-3_conv_sigmoid.json

data size: 10000
accuracy: 0.9812
学習時の誤差と正解率

f:id:fits:20160328195741p:plain

学習後のフィルター

f:id:fits:20160328195749p:plain

4. 活性化関数 = sigmoid, 学習回数 = 10 (バッチサイズ = 100, 学習係数 = 0.001, adadelta)

> node learn_mnist.js 10 100 0.001 adadelta models/conv_sigmoid.json results/b-4_conv_sigmoid.json > logs/b-4_conv_sigmoid.log

> node validate_mnist.js results/b-4_conv_sigmoid.json

data size: 10000
accuracy: 0.9862
学習時の誤差と正解率

f:id:fits:20160328195802p:plain

学習後のフィルター

f:id:fits:20160328195811p:plain

結果のまとめ

番号 活性化関数 学習回数 正解率
1 relu 5 0.9785
2 relu 10 0.9786
3 sigmoid 5 0.9812
4 sigmoid 10 0.9862

前回の結果よりも高い正解率となりました。

補足

(1) 誤差と正解率のグラフ化

誤差と正解率のログは以下のようなスクリプトでグラフ化しました。

誤差の値が Infinity となる事があったので ※、その場合はとりあえず固定値(以下では 1000)で置換するようにしています。

 ※ ただし、Infinity となったのは前回の階層型ニューラルネットの結果で、
    畳み込みニューラルネットの結果では発生していません
line_chart.js
'use strict';

const Promise = require('bluebird');
const d3 = require('d3');
const jsdom = require('jsdom').jsdom;

const readCSV = Promise.promisify(require('basic-csv').readCSV);

const w = 300;
const h = 300;
const margin = { top: 20, bottom: 50, left: 50, right: 20 };

const xLabels = ['バッチ回数', 'バッチ回数'];
const yLabels = ['誤差', '正解率'];

readCSV(process.argv[2]).then( ds => {
    const document = jsdom();

    const chartLayout = (xnum, w, h, margin) => {
        const borderWidth = w + margin.left + margin.right;
        const borderHeight = h + margin.top + margin.bottom;

        const svg = d3.select(document.body).append('svg')
            .attr('xmlns', 'http://www.w3.org/2000/svg')
            .attr('width', xnum * borderWidth)
            .attr('height', borderHeight);

        return Array(xnum).fill(0).map( (n, i) =>
            svg.append('g')
                .attr('transform', `translate(${i * borderWidth + margin.left}, ${margin.top})`)
        );
    };

    const xDomain = [0, ds.length];
    const yDomain = [1, 0];

    // スケールの定義
    const x = d3.scale.linear().range([0, w]).domain(xDomain);
    const y = d3.scale.linear().range([0, h]).domain(yDomain);

    // 軸の定義
    const xAxis = d3.svg.axis().scale(x).orient('bottom').ticks(5);
    const yAxis = d3.svg.axis().scale(y).orient('left');

    // 折れ線の作成
    const createLine = d3.svg.line()
        .x((d, i) => x(i + 1))
        .y(d => {
            // Infinity の際に固定値を設定
            if (d == 'Infinity') {
                d = 1000;
            }
            return y(d);
        });

    // 折れ線の描画
    const drawLine = (g, data, colIndex, color) => {
        g.append('path')
            .attr('d', createLine(data.map(d => d[colIndex])))
            .attr('stroke', color)
            .attr('fill', 'none');
    };

    const gs = chartLayout(2, w, h, margin);

    // X・Y軸の描画
    gs.forEach( (g, i) => {
        g.append('g')
            .attr('transform', `translate(0, ${h})`)
            .call(xAxis)
            .append('text')
                .attr('x', w / 2)
                .attr('y', 35)
                .style('font-family', 'Sans')
                .text(xLabels[i]);

        g.append('g')
            .call(yAxis)
            .append('text')
                .attr('x', -h / 2)
                .attr('y', -35)
                .attr('transform', 'rotate(-90)')
                .style('font-family', 'Sans')
                .text(yLabels[i]);
    });

    drawLine(gs[0], ds, 2, 'blue');
    drawLine(gs[1], ds, 3, 'blue');

    return document.body.innerHTML;

}).then( html => 
    console.log(html)
);

このスクリプトを使って svg ファイルへ出力し、ImageMagickpng ファイルへ変換しました。

実行例
> node line_chart.js logs/b-1_conv_relu.log > img/b-1.svg

> convert img/b-1.svg img/b-1.png

(2) 畳み込み層のフィルターを svg

学習後の畳み込み層のフィルターを以下のようなスクリプトで画像化(svg)してみました。

実際のフィルターサイズは 5x5 と小さくて分かり難いので、d3.scale を使って 50x50 へ拡大しています。

また、フィルターを可視化するため、d3.scale でフィルターの値が 0 ~ 255 となるように変換しています。

conv_filter_svg.js
'use strict';

const Promise = require('bluebird');
const convnetjs = require('convnetjs');
const d3 = require('d3');
const jsdom = require('jsdom').jsdom;

const readFile = Promise.promisify(require('fs').readFile);

const size = 50;
const margin = 5;

const modelJsonFile = process.argv[2];

// フィルター内の最小値と最大値を抽出
const valueRange = fs => fs.reduce( (acc, f) =>
    [
        Math.min(acc[0], Math.min.apply(null, f.w)),
        Math.max(acc[1], Math.max.apply(null, f.w))
    ], 
    [0, 0]
);

readFile(modelJsonFile).then( json => {
    const net = new convnetjs.Net();
    net.fromJSON(JSON.parse(json));

    return net.layers;
}).then( layers =>
    layers.reduce( (acc, v) => {
        // 畳み込み層のフィルターを抽出
        if (v.layer_type == 'conv' && v.filters) {
            acc.push( v.filters );
        }

        return acc;
    }, [])
).then( filtersList => {
    const document = jsdom();

    const svg = d3.select(document.body)
                    .append('svg')
                    .attr('xmlns', 'http://www.w3.org/2000/svg');

    filtersList.forEach( (fs, j) => {
        const yPos = (size + margin) * j;

        // フィルターの数値を 0 ~ 255 の値へ変換
        const pixelScale = d3.scale.linear()
                            .range([0, 255]).domain(valueRange(fs));

        fs.forEach( (f, i) => {
            const xPos = (size + margin) * i;

            const g = svg.append('g')
                        .attr('transform', `translate(${xPos}, ${yPos})`);

            const xScale = d3.scale.linear()
                            .range([0, size]).domain([0, f.sx]);

            const yScale = d3.scale.linear()
                            .range([0, size]).domain([0, f.sy]);

            for (let y = 0; y < f.sy; y++) {
                for (let x = 0; x < f.sx; x++) {
                    const p = pixelScale( f.get(x, y, 0) );

                    g.append('rect')
                        .attr('x', xScale(x))
                        .attr('y', yScale(y))
                        .attr('width', xScale(1))
                        .attr('height', yScale(1))
                        .attr('fill', d3.rgb(p, p, p));
                }
            }
        });
    });

    return document.body.innerHTML;

}).then( svg =>
    console.log(svg)
).catch( e => 
    console.error(e)
);

こちらも、svg ファイルとして出力し、ImageMagickpng ファイルへ変換しました。

実行例
> node conv_filter_svg.js results/b-1_conv_relu.json > img/b-1_filters.svg

> convert img/b-1_filters.svg img/b-1_filters.png