Ramda で階層的なグルーピング

JavaScript 用の関数型ライブラリ Ramda で階層的なグルーピングを行ってみます。

ソースは http://github.com/fits/try_samples/tree/master/blog/20180220/

はじめに

概要

今回は、以下のデータに対して階層的なグルーピングと集計処理を適用します。

データ
const data = [
    {category: 'A', item: 'A01', date: '2018-02-01', value: 1},
    {category: 'A', item: 'A02', date: '2018-02-01', value: 1},
    {category: 'A', item: 'A01', date: '2018-02-01', value: 1},
    {category: 'A', item: 'A01', date: '2018-02-02', value: 20},
    {category: 'A', item: 'A03', date: '2018-02-03', value: 2},
    {category: 'B', item: 'B01', date: '2018-02-02', value: 1},
    {category: 'A', item: 'A03', date: '2018-02-03', value: 5},
    {category: 'A', item: 'A01', date: '2018-02-02', value: 2},
    {category: 'B', item: 'B01', date: '2018-02-03', value: 3},
    {category: 'B', item: 'B01', date: '2018-02-04', value: 1},
    {category: 'C', item: 'C01', date: '2018-02-01', value: 1},
    {category: 'B', item: 'B01', date: '2018-02-04', value: 10}
]

具体的には、上記category item date の順に階層的にグルーピングした後、value の合計値を算出して以下のようにします。

処理結果
{
  A: {
     A01: { '2018-02-01': 2, '2018-02-02': 22 },
     A02: { '2018-02-01': 1 },
     A03: { '2018-02-03': 7 }
  },
  B: { B01: { '2018-02-02': 1, '2018-02-03': 3, '2018-02-04': 11 } },
  C: { C01: { '2018-02-01': 1 } }
}

Ramda インストール

Ramda は以下のようにインストールしておきます。

> npm install ramda

実装

(a) 階層的なグルーピングと集計

まずは、処理方法を確認するため、順番に処理を実施してみます。

1. category でグルーピング(1層目)

指定項目によるグルーピング処理は R.groupBy で行えます。 category でグルーピングする処理は以下のようになります。

category グルーピング処理
const R = require('ramda')

const data = [
    {category: 'A', item: 'A01', date: '2018-02-01', value: 1},
    {category: 'A', item: 'A02', date: '2018-02-01', value: 1},
    {category: 'A', item: 'A01', date: '2018-02-01', value: 1},
    {category: 'A', item: 'A01', date: '2018-02-02', value: 20},
    {category: 'A', item: 'A03', date: '2018-02-03', value: 2},
    {category: 'B', item: 'B01', date: '2018-02-02', value: 1},
    {category: 'A', item: 'A03', date: '2018-02-03', value: 5},
    {category: 'A', item: 'A01', date: '2018-02-02', value: 2},
    {category: 'B', item: 'B01', date: '2018-02-03', value: 3},
    {category: 'B', item: 'B01', date: '2018-02-04', value: 1},
    {category: 'C', item: 'C01', date: '2018-02-01', value: 1},
    {category: 'B', item: 'B01', date: '2018-02-04', value: 10}
]

const res1 = R.groupBy(R.prop('category'), data)

console.log(res1)
category グルーピング結果
{ A:
   [ { category: 'A', item: 'A01', date: '2018-02-01', value: 1 },
     { category: 'A', item: 'A02', date: '2018-02-01', value: 1 },
     { category: 'A', item: 'A01', date: '2018-02-01', value: 1 },
     { category: 'A', item: 'A01', date: '2018-02-02', value: 20 },
     { category: 'A', item: 'A03', date: '2018-02-03', value: 2 },
     { category: 'A', item: 'A03', date: '2018-02-03', value: 5 },
     { category: 'A', item: 'A01', date: '2018-02-02', value: 2 } ],
  B:
   [ { category: 'B', item: 'B01', date: '2018-02-02', value: 1 },
     { category: 'B', item: 'B01', date: '2018-02-03', value: 3 },
     { category: 'B', item: 'B01', date: '2018-02-04', value: 1 },
     { category: 'B', item: 'B01', date: '2018-02-04', value: 10 } ],
  C:
   [ { category: 'C', item: 'C01', date: '2018-02-01', value: 1 } ] }

2. item でグルーピング(2層目)

category のグルーピング結果を item で更にグルーピングするには res1 の値部分 ([ { category: 'A', ・・・}, ・・・ ] 等) に R.groupBy を適用します。

これは R.mapObjIndexed で実施できます。

item グルーピング処理
const res2 = R.mapObjIndexed(R.groupBy(R.prop('item')), res1)

console.log(res2)
item グルーピング結果
{ A:
   { A01: [ [Object], [Object], [Object], [Object] ],
     A02: [ [Object] ],
     A03: [ [Object], [Object] ] },
  B: { B01: [ [Object], [Object], [Object], [Object] ] },
  C: { C01: [ [Object] ] } }

3. date でグルーピング(3層目)

更に date でグルーピングするには R.mapObjIndexed を重ねて R.groupBy を適用します。

date グルーピング処理
const res3 = R.mapObjIndexed(R.mapObjIndexed(R.groupBy(R.prop('date'))), res2)

console.log(res3)
date グルーピング結果
{ A:
   { A01: { '2018-02-01': [Array], '2018-02-02': [Array] },
     A02: { '2018-02-01': [Array] },
     A03: { '2018-02-03': [Array] } },
  B:
   { B01:
      { '2018-02-02': [Array],
        '2018-02-03': [Array],
        '2018-02-04': [Array] } },
  C: { C01: { '2018-02-01': [Array] } } }

4. value の合計

最後に、R.groupBy の代わりに value を合計する処理(以下の sumValue)へ R.mapObjIndexed を階層分だけ重ねて適用すれば完成です。

value 合計処理
const sumValue = R.reduce((a, b) => a + b.value, 0)

const res4 = R.mapObjIndexed(R.mapObjIndexed(R.mapObjIndexed(sumValue)), res3)

console.log(res4)
value 合計結果
{ A:
   { A01: { '2018-02-01': 2, '2018-02-02': 22 },
     A02: { '2018-02-01': 1 },
     A03: { '2018-02-03': 7 } },
  B: { B01: { '2018-02-02': 1, '2018-02-03': 3, '2018-02-04': 11 } },
  C: { C01: { '2018-02-01': 1 } } }

(b) N階層のグルーピングと集計

次は、汎用的に使えるような実装にしてみます。

任意の処理に対して指定回数だけ R.mapObjIndexed を重ねる処理があると便利なので applyObjIndexedN として実装しました。

(a) で実施したように、階層的なグルーピングは R.mapObjIndexed を階層分重ねた R.groupBy を繰り返し適用していくだけですので R.reduce で実装できます。(以下の groupByMulti

ちなみに、階層的にグルーピングする実装例は Ramda の Cookbook(groupByMultiple) にありましたが、変数へ再代入したりと手続き的な実装内容になっているのが気になりました。

sample.js
const R = require('ramda')

const data = [
    {category: 'A', item: 'A01', date: '2018-02-01', value: 1},
    {category: 'A', item: 'A02', date: '2018-02-01', value: 1},
    {category: 'A', item: 'A01', date: '2018-02-01', value: 1},
    {category: 'A', item: 'A01', date: '2018-02-02', value: 20},
    {category: 'A', item: 'A03', date: '2018-02-03', value: 2},
    {category: 'B', item: 'B01', date: '2018-02-02', value: 1},
    {category: 'A', item: 'A03', date: '2018-02-03', value: 5},
    {category: 'A', item: 'A01', date: '2018-02-02', value: 2},
    {category: 'B', item: 'B01', date: '2018-02-03', value: 3},
    {category: 'B', item: 'B01', date: '2018-02-04', value: 1},
    {category: 'C', item: 'C01', date: '2018-02-01', value: 1},
    {category: 'B', item: 'B01', date: '2018-02-04', value: 10}
]

/* 
  指定回数(n)だけ R.mapObjIndexed を重ねた任意の処理(fn)を
  data を引数にして実行する処理
*/
const applyObjIndexedN = R.curry((n, fn, data) =>
    R.reduce(
        (a, b) => R.mapObjIndexed(a), 
        fn, 
        R.range(0, n)
    )(data)
)

// 階層的なグルーピング処理
const groupByMulti = R.curry((fields, data) => 
    R.reduce(
        (a, b) => applyObjIndexedN(b, R.groupBy(R.prop(fields[b])), a),
        data, 
        R.range(0, fields.length)
    )
)

const cols = ['category', 'item', 'date']

const sumValue = R.reduce((a, b) => a + b.value, 0)

const sumMultiGroups = R.pipe(
    groupByMulti(cols), // グルーピング処理
    applyObjIndexedN(cols.length, sumValue) // 合計処理
)

console.log( sumMultiGroups(data) )
実行結果
> node sample.js

{ A:
   { A01: { '2018-02-01': 2, '2018-02-02': 22 },
     A02: { '2018-02-01': 1 },
     A03: { '2018-02-03': 7 } },
  B: { B01: { '2018-02-02': 1, '2018-02-03': 3, '2018-02-04': 11 } },
  C: { C01: { '2018-02-01': 1 } } }

D3.js で HAR ファイルから散布図を作成

前回、HAR (HTTP ARchive) ファイルから Python で作成した散布図を D3.js を使って SVG として作ってみました。

ソースは http://github.com/fits/try_samples/tree/master/blog/20170515/

はじめに

Node.js で D3.js を使用するために d3jsdom をインストールしておきます。

依存ライブラリのインストー
npm install --save d3 jsdom
package.json
{
  ・・・
  "dependencies": {
    "d3": "^4.8.0",
    "jsdom": "^10.1.0"
  }
}

散布図(SVG)の作成

標準入力から HAR ファイルの内容(JSON)を受け取り、time と bodySize の該当位置へ円を描画する事で散布図を表現してみました。

円の色をサブタイプ毎に分けることで色分けしています。

ついでに、ツールチップの表示やマウスイベント処理を追加してみました。

Web ブラウザで SVG ファイルを開いた際にマウスイベントを処理できるように attr を使ってマウスイベント処理(該当する円のサイズを変更)を設定しています。

index.js
const d3 = require('d3');

const jsdom = require('jsdom');
const { JSDOM } = jsdom;

const w = 800;
const h = 500;
const margin = { left: 80, top: 20, bottom: 50, right: 200 };
const legendMargin = { top: 30 };

const fontSize = '12pt';
const circleRadius = 5;

// X軸(time)の値の範囲
const xDomain = [0, 5000];
// Y軸(bodySize)の値の範囲
const yDomain = [0, 180000];

const colorMap = {
    'javascript': 'blue',
    'x-javascript': 'blue',
    'json': 'green',
    'gif': 'tomato',
    'jpeg': 'red',
    'png': 'pink',
    'html': 'lime',
    'css': 'turquoise'
};

const dom = new JSDOM();
const document = dom.window.document;

const toSubType = mtype => mtype.split(';')[0].split('/').pop();
const toColor = stype => colorMap[stype] ? colorMap[stype] : 'black';

process.stdin.resume();

let json = '';

process.stdin.on('data', chunk => json += chunk);

process.stdin.on('end', () => {
    const data = JSON.parse(json);

    const df = data.log.entries.map( d => {
        return {
            'url': d.request.url,
            'subType': toSubType(d.response.content.mimeType),
            'bodySize': d.response.bodySize,
            'time': d.time
        };
    });

    const svg = d3.select(document.body)
        .append('svg')
            .attr('xmlns', 'http://www.w3.org/2000/svg')
            .attr('width', w + margin.left + margin.right)
            .attr('height', h + margin.top + margin.bottom)
            .append('g')
                .attr('transform', `translate(${margin.left}, ${margin.top})`);

    const x = d3.scaleLinear().range([0, w]).domain(xDomain);
    const y = d3.scaleLinear().range([h, 0]).domain(yDomain);

    const xAxis = d3.axisBottom(x);
    const yAxis = d3.axisLeft(y);

    // X軸
    svg.append('g')
        .attr('transform', `translate(0, ${h})`)
        .call(xAxis);
    // Y軸
    svg.append('g')
        .call(yAxis);

    // X軸ラベル
    svg.append('text')
        .attr('x', w / 2)
        .attr('y', h + 40)
        .style('font-size', fontSize)
        .text('Time (ms)');
    // Y軸ラベル
    svg.append('text')
        .attr('x', -h / 2)
        .attr('y', -(margin.left) / 1.5)
        .style('font-size', fontSize)
        .attr('transform', 'rotate(-90)')
        .text('Body Size');

    // 円
    const point = svg.selectAll('circle')
        .data(df)
        .enter()
            .append('circle');
    // 円の設定
    point.attr('class', d => d.subType)
        .attr('cx', d => x(d.time))
        .attr('cy', d => y(d.bodySize))
        .attr('r', circleRadius)
        .attr('fill', d => toColor(d.subType))
        .append('title') // ツールチップの設定
            .text(d => d.url);

    // 凡例
    const legend = svg.selectAll('.legend')
        .data(d3.entries(colorMap))
        .enter()
            .append('g')
                .attr('class', 'legend')
                .attr('transform', (d, i) => {
                    const left = w + margin.left;
                    const top = margin.top + i * legendMargin.top;
                    return `translate(${left}, ${top})`;
                });

    legend.append('circle')
        .attr('r', circleRadius)
        .attr('fill', d => d.value);

    legend.append('text')
        .attr('x', circleRadius * 2)
        .attr('y', 4)
        .style('font-size', fontSize)
        // マウスイベント処理(該当する円のサイズを変更)
        .attr('onmouseover', d => 
            `document.querySelectorAll('circle.${d.key}').forEach(d => d.setAttribute('r', ${circleRadius} * 2))`)
        .attr('onmouseout', d => 
            `document.querySelectorAll('circle.${d.key}').forEach(d => d.setAttribute('r', ${circleRadius}))`)
        // 凡例のラベル
        .text(d => d.key);

    // SVG の出力
    console.log(document.body.innerHTML);
});
実行例
node index.js < a.har > a.svg

実行結果例

a.svg

f:id:fits:20170515220405p:plain

b.svg

f:id:fits:20170515220428p:plain

node-ffi で OpenCL を使う2 - 演算の実行

node-ffi で OpenCL を使う」 に続き、Node.js を使って OpenCL の演算を実施してみます。

サンプルソースhttp://github.com/fits/try_samples/tree/master/blog/20160725/

はじめに

演算の実行には ref-array モジュールを使った方が便利なため、node-ffi をインストールした環境へ追加でインストールしておきます。

ref-array インストール例
> npm install ref-array

OpenCL の演算実行サンプル

今回は配列の要素を 3乗する OpenCL のコード(以下)を Node.js から実行する事にします。

cube.cl
__kernel void cube(
    __global float* input,
    __global float* output,
    const unsigned int count)
{
    int i = get_global_id(0);

    if (i < count) {
        output[i] = input[i] * input[i] * input[i];
    }
}

サンプルコード概要

上記 cube.cl を実行する Node.js サンプルコードの全体像です。(OpenCLAPI は try-finally 内で呼び出しています)

OpenCL 演算の入力値として data 変数の値を使用します。OpenCL のコードはファイルから読み込んで code 変数へ設定しています。

OpenCL APIclCreateXXX で作成したリソースは clReleaseXXX で解放するようなので、解放処理を都度 releaseList へ追加しておき、finally で実行するようにしています。

なお、OpenCL API のエラーコード取得には以下の 2通りがあります。(使用する関数による)

  • 関数の戻り値でエラーコードを取得
  • 関数の引数(ポインタ)でエラーコードを取得
calc.js (全体)
'use strict';

const fs = require('fs');
const ref = require('ref');
const ArrayType = require('ref-array');
const ffi = require('ffi');

const CL_DEVICE_TYPE_DEFAULT = 1;

const CL_MEM_READ_WRITE = (1 << 0);
const CL_MEM_WRITE_ONLY = (1 << 1);
const CL_MEM_READ_ONLY = (1 << 2);
const CL_MEM_USE_HOST_PTR = (1 << 3);
const CL_MEM_ALLOC_HOST_PTR = (1 << 4);
const CL_MEM_COPY_HOST_PTR = (1 << 5);

const intPtr = ref.refType(ref.types.int32);
const uintPtr = ref.refType(ref.types.uint32);
const sizeTPtr = ref.refType('size_t');
const StringArray = ArrayType('string');

const clLib = (process.platform == 'win32') ? 'OpenCL' : 'libOpenCL';

// 使用する OpenCL の関数定義
const openCl = ffi.Library(clLib, {
    'clGetPlatformIDs': ['int', ['uint', sizeTPtr, uintPtr]],
    'clGetDeviceIDs': ['int', ['size_t', 'ulong', 'uint', sizeTPtr, uintPtr]],
    'clCreateContext': ['pointer', ['pointer', 'uint', sizeTPtr, 'pointer', 'pointer', intPtr]],
    'clReleaseContext': ['int', ['pointer']],
    'clCreateProgramWithSource': ['pointer', ['pointer', 'uint', StringArray, sizeTPtr, intPtr]],
    'clBuildProgram': ['int', ['pointer', 'uint', sizeTPtr, 'string', 'pointer', 'pointer']],
    'clReleaseProgram': ['int', ['pointer']],
    'clCreateKernel': ['pointer', ['pointer', 'string', intPtr]],
    'clReleaseKernel': ['int', ['pointer']],
    'clCreateBuffer': ['pointer', ['pointer', 'ulong', 'size_t', 'pointer', intPtr]],
    'clReleaseMemObject': ['int', ['pointer']],
    'clSetKernelArg': ['int', ['pointer', 'uint', 'size_t', 'pointer']],
    'clCreateCommandQueue': ['pointer', ['pointer', 'size_t', 'ulong', intPtr]],
    'clReleaseCommandQueue': ['int', ['pointer']],
    'clEnqueueReadBuffer': ['int', ['pointer', 'pointer', 'bool', 'size_t', 'size_t', 'pointer', 'uint', 'pointer', 'pointer']],
    'clEnqueueNDRangeKernel': ['int', ['pointer', 'pointer', 'uint', sizeTPtr, sizeTPtr, sizeTPtr, 'uint', 'pointer', 'pointer']]
});

// エラーチェック
const checkError = (err, title = '') => {
    if (err instanceof Buffer) {
        // ポインタの場合はエラーコードを取り出す
        err = intPtr.get(err);
    }

    if (err != 0) {
        throw new Error(`${title} Error: ${err}`);
    }
};

// 演算対象データ
const data = [1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9];

const functionName = process.argv[2]
// OpenCL コードの読み込み(ファイルから)
const code = fs.readFileSync(process.argv[3]);

const releaseList = [];

try {
    ・・・ OpenCL API の呼び出し処理 ・・・

} finally {
    // リソースの解放
    releaseList.reverse().forEach( f => f() );
}

clCreateProgramWithSourceOpenCL のコードを渡す際に、ref-array で作成した String の配列 StringArray を使っています。

OpenCL 処理部分

今回は以下のような OpenCL API を使っています。

番号 概要 OpenCL 関数名
(1) プラットフォーム取得 clGetPlatformIDs
(2) バイス取得 clGetDeviceIDs
(3) コンテキスト作成 clCreateContext
(4) コマンドキュー作成 clCreateCommandQueue
(5) プログラム作成 clCreateProgramWithSource
(6) プログラムのビルド clBuildProgram
(7) カーネル作成 clCreateKernel
(8) 引数用のバッファ作成 clCreateBuffer
(9) 引数の設定 clSetKernelArg
(10) 処理の実行 clEnqueueNDRangeKernel
(11) 結果の取得 clEnqueueReadBuffer

OpenCL のコードを実行するには (6) のように API を使ってビルドする必要があります。

Node.js と OpenCL 間で配列データ等をやりとりするには (8) で作ったバッファを使います。(入力値をバッファへ書き込んで、出力値をバッファから読み出す)

また、今回は clEnqueueNDRangeKernel を使って実行しましたが、clEnqueueTask を使って実行する方法もあります。

calc.js (OpenCL 処理部分)
・・・
try {
    const platformIdsPtr = ref.alloc(sizeTPtr);
    // (1) プラットフォーム取得
    let res = openCl.clGetPlatformIDs(1, platformIdsPtr, null);

    checkError(res, 'clGetPlatformIDs');

    const platformId = sizeTPtr.get(platformIdsPtr);

    const deviceIdsPtr = ref.alloc(sizeTPtr);
    // (2) デバイス取得 (デフォルトをとりあえず使用)
    res = openCl.clGetDeviceIDs(platformId, CL_DEVICE_TYPE_DEFAULT, 1, deviceIdsPtr, null);

    checkError(res, 'clGetDeviceIDs');

    const deviceId = sizeTPtr.get(deviceIdsPtr);

    const errPtr = ref.alloc(intPtr);
    // (3) コンテキスト作成
    const ctx = openCl.clCreateContext(null, 1, deviceIdsPtr, null, null, errPtr);

    checkError(errPtr, 'clCreateContext');
    releaseList.push( () => openCl.clReleaseContext(ctx) );
    // (4) コマンドキュー作成
    const queue = openCl.clCreateCommandQueue(ctx, deviceId, 0, errPtr);

    checkError(errPtr, 'clCreateCommandQueue');
    releaseList.push( () => openCl.clReleaseCommandQueue(queue) );

    const codeArray = new StringArray([code.toString()]);
    // (5) プログラム作成
    const program = openCl.clCreateProgramWithSource(ctx, 1, codeArray, null, errPtr);

    checkError(errPtr, 'clCreateProgramWithSource');
    releaseList.push( () => openCl.clReleaseProgram(program) );
    // (6) プログラムのビルド
    res = openCl.clBuildProgram(program, 1, deviceIdsPtr, null, null, null)

    checkError(res, 'clBuildProgram');
    // (7) カーネル作成
    const kernel = openCl.clCreateKernel(program, functionName, errPtr);

    checkError(errPtr, 'clCreateKernel');
    releaseList.push( () => openCl.clReleaseKernel(kernel) );

    const FixedFloatArray = ArrayType('float', data.length);
    // 入力データ
    const inputData = new FixedFloatArray(data);

    const bufSize = inputData.buffer.length;
    // (8) 引数用のバッファ作成(入力用)し inputData の内容を書き込む
    const inClBuf = openCl.clCreateBuffer(ctx, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, bufSize, inputData.buffer, errPtr);

    checkError(errPtr, 'clCreateBuffer In');
    releaseList.push( () => openCl.clReleaseMemObject(inClBuf) );

    // (8) 引数用のバッファ作成(出力用)
    const outClBuf = openCl.clCreateBuffer(ctx, CL_MEM_WRITE_ONLY, bufSize, null, errPtr);

    checkError(errPtr, 'clCreateBuffer Out');
    releaseList.push( () => openCl.clReleaseMemObject(outClBuf) );

    const inClBufRef = inClBuf.ref();
    // (9) 引数の設定
    res = openCl.clSetKernelArg(kernel, 0, inClBufRef.length, inClBufRef);

    checkError(res, 'clSetKernelArg 0');

    const outClBufRef = outClBuf.ref();
    // (9) 引数の設定
    res = openCl.clSetKernelArg(kernel, 1, outClBufRef.length, outClBufRef);

    checkError(res, 'clSetKernelArg 1');

    const ct = ref.alloc(ref.types.uint32, data.length);

    // (9) 引数の設定
    res = openCl.clSetKernelArg(kernel, 2, ct.length, ct);

    checkError(res, 'clSetKernelArg 2');

    const globalPtr = ref.alloc(sizeTPtr);
    sizeTPtr.set(globalPtr, 0, data.length);
    // (10) 処理の実行
    res = openCl.clEnqueueNDRangeKernel(queue, kernel, 1, null, globalPtr, null, 0, null, null);

    checkError(res, 'clEnqueueNDRangeKernel');

    const resData = new FixedFloatArray();

    // (11) 結果の取得 (outClBuf の内容を resData へ)
    res = openCl.clEnqueueReadBuffer(queue, outClBuf, true, 0, resData.buffer.length, resData.buffer, 0, null, null);

    checkError(res, 'clEnqueueReadBuffer');

    // 結果出力
    for (let i = 0; i < resData.length; i++) {
        console.log(resData[i]);
    }

} finally {
    // リソースの解放
    releaseList.reverse().forEach( f => f() );
}

動作確認

今回は以下の Node.js を使って WindowsLinux の両方で動作確認します。

  • Node.js v6.3.1

(a) Windows で実行

node-ffi で OpenCL を使う」 で構築した環境へ ref-array をインストールして実行しました。

実行結果 (Windows
> node calc.js cube cube.cl

1.3310000896453857
10.648000717163086
35.93699645996094
85.18400573730469
166.375
287.4959716796875
456.532958984375
681.4720458984375
970.2988891601562

(b) Linux で実行

前回 の Docker イメージを使って実行します。

calc.js と cube.cl を /vagrant/work へ配置し、Docker コンテナからは /work でアクセスできるようにマッピングしました。

Docker コンテナ実行
$ docker run --rm -it -v /vagrant/work:/work sample/opencl:0.1 bash
Node.js と必要なモジュールのインストール
# curl -o- https://raw.githubusercontent.com/creationix/nvm/v0.31.2/install.sh | bash
・・・
# source ~/.bashrc

# nvm install v6.3.1
・・・
# npm install -g node-gyp
・・・
# cd /work
# npm install ffi ref-array
・・・

/work 内で実行します。

実行結果 (Linux
# node calc.js cube cube.cl

1.3310000896453857
10.648000717163086
35.93699645996094
85.18400573730469
166.375
287.4959716796875
456.532958984375
681.4720458984375
970.2988891601562

node-ffi で OpenCL を使う

Windows 環境で node-ffi (Node.js Foreign Function Interface) を使って OpenCLAPI を呼び出してみました。

サンプルソースhttp://github.com/fits/try_samples/tree/master/blog/20160627/

なお、OpenCL 上での演算は今回扱いませんが、単純な演算のサンプルは ここ に置いてます。

はじめに

node-ffi のインストール

まずは、node-gyp をインストールしておきます。 node-gyp を Windows 環境で使うには VC++Python 2.7 が必要です。

node-gyp インストール例
> npm install -g node-gyp

node-ffi をインストールします。(モジュール名は node-ffi ではなく ffi です)

node-ffi インストール例
> npm install ffi

node-ffi の使い方

node-ffi では Library 関数を使ってネイティブライブラリの関数をマッピングします。

ffi.Library(<ライブラリ名>, {
    <関数名>: [<戻り値の型>, [<第1引数の型>, <第2引数の型>, ・・・]],
    ・・・
})

引数の型などはライブラリのヘッダーファイルなどを参考にして設定します。

例えば、OpenCL.dll (Windows 環境の場合) の clGetPlatformIDs 関数を Node.js から openCl.clGetPlatformIDs(・・・) で呼び出すには以下のようにします。

Library の使用例
const openCl = ffi.Library('OpenCL', {
    'clGetPlatformIDs': ['int', ['uint', sizeTPtr, uintPtr]],
    ・・・
});

ref モジュールの refType でポインタ用の型を定義する事が可能です。

refType の使用例
const uintPtr = ref.refType(ref.types.uint32);
const sizeTPtr = ref.refType('size_t');

OpenCL の利用

それでは、下記 OpenCL ランタイムをインストールした Windows 環境で、OpenCLAPI を 3つほど呼び出してみます。

1. OpenCL のデバイスID取得

まずは、以下を実施してみます。

  • (1) clGetPlatformIDs を使ってプラットフォームIDを取得
  • (2) clGetDeviceIDs を使ってデバイスIDを取得

OpenCL (v1.2) のヘッダーファイルを見てみると、プラットフォームIDの型 cl_platform_id やデバイスIDの型 cl_device_id はこれ自体がポインタのようなので ※、これらに該当する型は size_t としました。

※ そのため、プラットフォームID や デバイスID という表現は
   適切ではないかもしれません

node-ffi ではポインタを扱うために Buffer を使います。

そのための補助関数が ref モジュールに用意されており、下記サンプルでは以下を使っています。

  • ref モジュールの alloc を使って指定した型に応じた Buffer を作成
  • 定義した型の get を使って Buffer から値を取得

get を使えば、型のサイズやエンディアンに応じた値を Buffer から取り出してくれます。 (例えば、int32 なら Buffer の readInt32LE や readInt32BE を使って値を取得する)

なお、エラーの有無は clGetPlatformIDs・clGetDeviceIDs の戻り値が 0 かどうかで判定します。(0: 成功、0以外: エラー)

get_device_id.js
'use strict';

const ffi = require('ffi');
const ref = require('ref');

// 定数の定義
const CL_DEVICE_TYPE_DEFAULT = 1;

// ポインタ用の型定義
const uintPtr = ref.refType(ref.types.uint32);
const sizeTPtr = ref.refType('size_t');

// OpenCL の関数定義
const openCl = ffi.Library('OpenCL', {
    'clGetPlatformIDs': ['int', ['uint', sizeTPtr, uintPtr]],
    'clGetDeviceIDs': ['int', ['size_t', 'ulong', 'uint', sizeTPtr, uintPtr]]
});

// エラーチェック処理
const checkError = (errCode, title = '') => {
    if (errCode != 0) {
        throw new Error(`${title} Error: ${errCode}`);
    }
};

const platformIdsPtr = ref.alloc(sizeTPtr);

// (1) プラットフォームIDを(1つ)取得
let res = openCl.clGetPlatformIDs(1, platformIdsPtr, null);

checkError(res, 'clGetPlatformIDs');

// プラットフォームID(get を使って platformIdsPtr の先頭の値を取得)
const platformId = sizeTPtr.get(platformIdsPtr);

console.log(`platformId: ${platformId}`);

const deviceIdsPtr = ref.alloc(sizeTPtr);

// (2) デバイスIDを(1つ)取得
res = openCl.clGetDeviceIDs(platformId, CL_DEVICE_TYPE_DEFAULT, 1, deviceIdsPtr, null);

checkError(res, 'clGetDeviceIDs');

// デバイスID(get を使って deviceIdsPtr の先頭の値を取得)
const deviceId = sizeTPtr.get(deviceIdsPtr);

console.log(`deviceId: ${deviceId}`);
実行結果
> node get_device_id.js

platformId: 47812336
deviceId: 4404320

2. OpenCL のプラットフォーム情報取得

次は OpenCL のプラットフォーム情報を取得してみます。 プラットフォーム情報は clGetPlatformInfo を使って取得します。

  • (1) clGetPlatformInfo でデータサイズを取得
  • (2) バッファを確保
  • (3) clGetPlatformInfo でデータを取得
platform_info.js
'use strict';

const ffi = require('ffi');
const ref = require('ref');

// 定数の定義
const CL_PLATFORM_PROFILE = 0x0900;
const CL_PLATFORM_VERSION = 0x0901;
const CL_PLATFORM_NAME = 0x0902;
const CL_PLATFORM_VENDOR = 0x0903;
const CL_PLATFORM_EXTENSIONS = 0x0904;
const CL_PLATFORM_HOST_TIMER_RESOLUTION = 0x0905;

const uintPtr = ref.refType(ref.types.uint32);
const sizeTPtr = ref.refType('size_t');

const openCl = ffi.Library('OpenCL', {
    'clGetPlatformIDs': ['int', ['uint', sizeTPtr, uintPtr]],
    'clGetPlatformInfo': ['int', ['size_t', 'uint', 'size_t', 'pointer', sizeTPtr]]
});

const checkError = (errCode, title = '') => {
    if (errCode != 0) {
        throw new Error(`${title} Error: ${errCode}`);
    }
};

// プラットフォーム情報の出力
const printPlatformInfo = (pid, paramName) => {
    const sPtr = ref.alloc(sizeTPtr);

    // (1) データサイズを取得
    let res = openCl.clGetPlatformInfo(pid, paramName, 0, null, sPtr);

    checkError(res, 'clGetPlatformInfo size');

    // データサイズの値を取り出す
    const size = sizeTPtr.get(sPtr);

    // (2) バッファを確保
    const buf = Buffer.alloc(size);

    // (3) データを取得
    res = openCl.clGetPlatformInfo(pid, paramName, size, buf, null);

    checkError(res, 'clGetPlatformInfo data');

    // 出力
    console.log(buf.toString());
};

const platformIdsPtr = ref.alloc(sizeTPtr);

const res = openCl.clGetPlatformIDs(1, platformIdsPtr, null);

checkError(res, 'clGetPlatformIDs');

const platformId = sizeTPtr.get(platformIdsPtr);

[
    CL_PLATFORM_PROFILE,
    CL_PLATFORM_VERSION,
    CL_PLATFORM_NAME
].forEach( p => 
    printPlatformInfo(platformId, p)
);
実行結果
> node platform_info.js

FULL_PROFILE 
OpenCL 1.2  
Intel(R) OpenCL 

ConvNetJS で MNIST を分類2 - 畳み込みニューラルネット

前回 の続きです。 今回は畳み込みニューラルネットを使って MNIST の手書き数字を分類してみます。

ソースは http://github.com/fits/try_samples/tree/master/blog/20160328/

準備

誤差・正解率のグラフ化と畳み込みフィルタの画像化を行うため、前回の構成へ d3 等を追加しています。

package.json
{
  "name": "convnetjs_mnist_conv_sample",
  "version": "1.0.0",
  "description": "",
  "main": "index.js",
  "dependencies": {
    "basic-csv": "0.0.2",
    "bluebird": "^3.3.4",
    "convnetjs": "^0.3.0",
    "d3": "^3.5.16",
    "jsdom": "^8.1.0",
    "shuffle-array": "^0.1.2"
  }
}
インストール例
> npm install

(b) 畳み込みニューラルネット

畳み込みニューラルネットでは畳み込み層とプーリング層を組み合わせてレイヤーを構築します。 (実際は全結合層も使います)

名称 処理 ConvNetJS の layer_type
畳み込み層 入力画像へフィルターを適用し特徴量を抽出 conv
プーリング層 入力画像へプーリング演算(フィルター内)を適用 pool

ConvNetJS の畳み込み層・プーリング層は以下のように設定します。

  • sxsy でフィルターのサイズを指定 (sy を省略すると sx と同じ値を適用)
  • pad で入力画像の周囲にゼロパディング(0埋め)する数を指定
  • stride でフィルタの適用位置を縦横に移動する数を指定 (1 の場合は縦横に 1画素ずつずらしてフィルターを適用)

畳み込み層では filters で適用するフィルターの数を指定します。

プーリング層では 「最大プーリング」 (フィルターの値の最大値を採用) を行うようになっており、今回試したバージョンでは (「平均プーリング」等へ) プーリング方法を変更する機能は無さそうでした。

今回は、以下のように畳み込み層・プーリング層が 2回続くような構成にしてみました。

create_layer_conv.js (畳み込みニューラルネットのモデル構築と保存処理)
'use strict';

// 畳み込み層の活性化関数
const act = process.argv[2];
// 出力ファイル名
const jsonDestFile = process.argv[3];

require('./save_model').saveModel(
    [
        { type: 'input', out_sx: 28, out_sy: 28, out_depth: 1 },
        // 1つ目の畳み込み層
        { type: 'conv', sx: 5, filters: 8, stride: 1, pad: 2, activation: act },
        // 1つ目のプーリング層
        { type: 'pool', sx: 2, stride: 2 },
        // 2つ目の畳み込み層
        { type: 'conv', sx: 5, filters: 16, stride: 1, pad: 2, activation: act },
        // 2つ目のプーリング層
        { type: 'pool', sx: 3, stride: 3 },
        { type: 'softmax', num_classes: 10 }
    ],
    jsonDestFile
);

活性化関数へ relu を指定した場合の内部的なレイヤー構成は以下のようになりました。

学習モデルの内部的なレイヤー構成例
input -> conv -> relu -> pool -> conv -> relu -> pool -> fc -> softmax

各レイヤーの出力サイズは以下の通りです。

layer_type out_sx out_sy out_depth
input 28 28 1
conv 28 28 8
relu 28 28 8
pool 14 14 8
conv 14 14 16
relu 14 14 16
pool 4 4 16
fc 1 1 10
softmax 1 1 10

学習と評価

前回 作成した共通処理 (learn_mnist.js 等) を使って学習と評価を実施します。

学習回数を前回と同じ 15回にすると相当時間がかかってしまうので、今回は以下の 4種類で学習・評価を試してみました。

  1. 活性化関数 = relu, 学習回数 = 5
  2. 活性化関数 = relu, 学習回数 = 10
  3. 活性化関数 = sigmoid, 学習回数 = 5
  4. 活性化関数 = sigmoid, 学習回数 = 10

学習回数以外は前回と同じパラメータを使います。

  • 学習回数 = 15
  • バッチサイズ = 100
  • 学習係数 = 0.001
  • 学習係数の決定方法 = adadelta

処理時間は学習回数 5回で 1.5時間、10回で 3時間程度でした。

PC の性能にも依存すると思いますが、1つの CPU で処理するので比較的遅めだと思います。

1. 活性化関数 = relu, 学習回数 = 5 (バッチサイズ = 100, 学習係数 = 0.001, adadelta)

> node create_layer_conv.js relu models/conv_relu.json

> node learn_mnist.js 5 100 0.001 adadelta models/conv_relu.json results/b-1_conv_relu.json > logs/b-1_conv_relu.log

> node validate_mnist.js results/b-1_conv_relu.json

data size: 10000
accuracy: 0.9785
学習時の誤差と正解率

f:id:fits:20160328195647p:plain

学習後のフィルター

f:id:fits:20160328195700p:plain

2. 活性化関数 = relu, 学習回数 = 10 (バッチサイズ = 100, 学習係数 = 0.001, adadelta)

> node learn_mnist.js 10 100 0.001 adadelta models/conv_relu.json results/b-2_conv_relu.json > logs/b-2_conv_relu.log

> node validate_mnist.js results/b-2_conv_relu.json

data size: 10000
accuracy: 0.9786
学習時の誤差と正解率

f:id:fits:20160328195714p:plain

学習後のフィルター

f:id:fits:20160328195724p:plain

3. 活性化関数 = sigmoid, 学習回数 = 5 (バッチサイズ = 100, 学習係数 = 0.001, adadelta)

> node create_layer_conv.js sigmoid models/conv_sigmoid.json

> node learn_mnist.js 5 100 0.001 adadelta models/conv_sigmoid.json results/b-3_conv_sigmoid.json > logs/b-3_conv_sigmoid.log

> node validate_mnist.js results/b-3_conv_sigmoid.json

data size: 10000
accuracy: 0.9812
学習時の誤差と正解率

f:id:fits:20160328195741p:plain

学習後のフィルター

f:id:fits:20160328195749p:plain

4. 活性化関数 = sigmoid, 学習回数 = 10 (バッチサイズ = 100, 学習係数 = 0.001, adadelta)

> node learn_mnist.js 10 100 0.001 adadelta models/conv_sigmoid.json results/b-4_conv_sigmoid.json > logs/b-4_conv_sigmoid.log

> node validate_mnist.js results/b-4_conv_sigmoid.json

data size: 10000
accuracy: 0.9862
学習時の誤差と正解率

f:id:fits:20160328195802p:plain

学習後のフィルター

f:id:fits:20160328195811p:plain

結果のまとめ

番号 活性化関数 学習回数 正解率
1 relu 5 0.9785
2 relu 10 0.9786
3 sigmoid 5 0.9812
4 sigmoid 10 0.9862

前回の結果よりも高い正解率となりました。

補足

(1) 誤差と正解率のグラフ化

誤差と正解率のログは以下のようなスクリプトでグラフ化しました。

誤差の値が Infinity となる事があったので ※、その場合はとりあえず固定値(以下では 1000)で置換するようにしています。

 ※ ただし、Infinity となったのは前回の階層型ニューラルネットの結果で、
    畳み込みニューラルネットの結果では発生していません
line_chart.js
'use strict';

const Promise = require('bluebird');
const d3 = require('d3');
const jsdom = require('jsdom').jsdom;

const readCSV = Promise.promisify(require('basic-csv').readCSV);

const w = 300;
const h = 300;
const margin = { top: 20, bottom: 50, left: 50, right: 20 };

const xLabels = ['バッチ回数', 'バッチ回数'];
const yLabels = ['誤差', '正解率'];

readCSV(process.argv[2]).then( ds => {
    const document = jsdom();

    const chartLayout = (xnum, w, h, margin) => {
        const borderWidth = w + margin.left + margin.right;
        const borderHeight = h + margin.top + margin.bottom;

        const svg = d3.select(document.body).append('svg')
            .attr('xmlns', 'http://www.w3.org/2000/svg')
            .attr('width', xnum * borderWidth)
            .attr('height', borderHeight);

        return Array(xnum).fill(0).map( (n, i) =>
            svg.append('g')
                .attr('transform', `translate(${i * borderWidth + margin.left}, ${margin.top})`)
        );
    };

    const xDomain = [0, ds.length];
    const yDomain = [1, 0];

    // スケールの定義
    const x = d3.scale.linear().range([0, w]).domain(xDomain);
    const y = d3.scale.linear().range([0, h]).domain(yDomain);

    // 軸の定義
    const xAxis = d3.svg.axis().scale(x).orient('bottom').ticks(5);
    const yAxis = d3.svg.axis().scale(y).orient('left');

    // 折れ線の作成
    const createLine = d3.svg.line()
        .x((d, i) => x(i + 1))
        .y(d => {
            // Infinity の際に固定値を設定
            if (d == 'Infinity') {
                d = 1000;
            }
            return y(d);
        });

    // 折れ線の描画
    const drawLine = (g, data, colIndex, color) => {
        g.append('path')
            .attr('d', createLine(data.map(d => d[colIndex])))
            .attr('stroke', color)
            .attr('fill', 'none');
    };

    const gs = chartLayout(2, w, h, margin);

    // X・Y軸の描画
    gs.forEach( (g, i) => {
        g.append('g')
            .attr('transform', `translate(0, ${h})`)
            .call(xAxis)
            .append('text')
                .attr('x', w / 2)
                .attr('y', 35)
                .style('font-family', 'Sans')
                .text(xLabels[i]);

        g.append('g')
            .call(yAxis)
            .append('text')
                .attr('x', -h / 2)
                .attr('y', -35)
                .attr('transform', 'rotate(-90)')
                .style('font-family', 'Sans')
                .text(yLabels[i]);
    });

    drawLine(gs[0], ds, 2, 'blue');
    drawLine(gs[1], ds, 3, 'blue');

    return document.body.innerHTML;

}).then( html => 
    console.log(html)
);

このスクリプトを使って svg ファイルへ出力し、ImageMagickpng ファイルへ変換しました。

実行例
> node line_chart.js logs/b-1_conv_relu.log > img/b-1.svg

> convert img/b-1.svg img/b-1.png

(2) 畳み込み層のフィルターを svg

学習後の畳み込み層のフィルターを以下のようなスクリプトで画像化(svg)してみました。

実際のフィルターサイズは 5x5 と小さくて分かり難いので、d3.scale を使って 50x50 へ拡大しています。

また、フィルターを可視化するため、d3.scale でフィルターの値が 0 ~ 255 となるように変換しています。

conv_filter_svg.js
'use strict';

const Promise = require('bluebird');
const convnetjs = require('convnetjs');
const d3 = require('d3');
const jsdom = require('jsdom').jsdom;

const readFile = Promise.promisify(require('fs').readFile);

const size = 50;
const margin = 5;

const modelJsonFile = process.argv[2];

// フィルター内の最小値と最大値を抽出
const valueRange = fs => fs.reduce( (acc, f) =>
    [
        Math.min(acc[0], Math.min.apply(null, f.w)),
        Math.max(acc[1], Math.max.apply(null, f.w))
    ], 
    [0, 0]
);

readFile(modelJsonFile).then( json => {
    const net = new convnetjs.Net();
    net.fromJSON(JSON.parse(json));

    return net.layers;
}).then( layers =>
    layers.reduce( (acc, v) => {
        // 畳み込み層のフィルターを抽出
        if (v.layer_type == 'conv' && v.filters) {
            acc.push( v.filters );
        }

        return acc;
    }, [])
).then( filtersList => {
    const document = jsdom();

    const svg = d3.select(document.body)
                    .append('svg')
                    .attr('xmlns', 'http://www.w3.org/2000/svg');

    filtersList.forEach( (fs, j) => {
        const yPos = (size + margin) * j;

        // フィルターの数値を 0 ~ 255 の値へ変換
        const pixelScale = d3.scale.linear()
                            .range([0, 255]).domain(valueRange(fs));

        fs.forEach( (f, i) => {
            const xPos = (size + margin) * i;

            const g = svg.append('g')
                        .attr('transform', `translate(${xPos}, ${yPos})`);

            const xScale = d3.scale.linear()
                            .range([0, size]).domain([0, f.sx]);

            const yScale = d3.scale.linear()
                            .range([0, size]).domain([0, f.sy]);

            for (let y = 0; y < f.sy; y++) {
                for (let x = 0; x < f.sx; x++) {
                    const p = pixelScale( f.get(x, y, 0) );

                    g.append('rect')
                        .attr('x', xScale(x))
                        .attr('y', yScale(y))
                        .attr('width', xScale(1))
                        .attr('height', yScale(1))
                        .attr('fill', d3.rgb(p, p, p));
                }
            }
        });
    });

    return document.body.innerHTML;

}).then( svg =>
    console.log(svg)
).catch( e => 
    console.error(e)
);

こちらも、svg ファイルとして出力し、ImageMagickpng ファイルへ変換しました。

実行例
> node conv_filter_svg.js results/b-1_conv_relu.json > img/b-1_filters.svg

> convert img/b-1_filters.svg img/b-1_filters.png

ConvNetJS で MNIST を分類1 - 階層型ニューラルネット

Node.js で ConvNetJS を使って MNIST の手書き数字を分類してみます。

今回は階層型ニューラルネット、次回は畳み込みニューラルネットを試す予定です。

ソースは http://github.com/fits/try_samples/tree/master/blog/20160322/

準備

npm で convnetjs 等のモジュールをインストールします。(概ね「ConvNetJS で iris を分類 」と同じ構成です)

package.json
{
  "name": "convnetjs_mnist_sample",
  "version": "1.0.0",
  "description": "",
  "main": "index.js",
  "dependencies": {
    "bluebird": "^3.3.4",
    "convnetjs": "^0.3.0",
    "shuffle-array": "^0.1.2"
  }
}
インストール例
> npm install

共通処理

ConvNetJS には、ニューラルネットの学習モデルを JSON で入出力する機能(fromJSONtoJSON)がありますので、今回はこの機能を使って学習前と後のモデルを JSON ファイルで扱う事にします。

MNIST には学習用のデータセット 6万件と評価用のデータセット 1万件がそれぞれ用意されていますので、今回は 6万件の学習データセットを全て学習に使い、1万件の評価データセットで評価する事にしました。

今回と次回で共通に使う処理として以下のようなスクリプトを作成しました。

(1) MNIST データセットのロード

MNIST の学習・評価データセットをロードする処理です。(処理内容に関しては 前回 を参照)

load_mnist.js
'use strict';

const Promise = require('bluebird');
const convnetjs = require('convnetjs');
const fs = require('fs');

const readFile = Promise.promisify(fs.readFile);
const readToBuffer = file => readFile(file).then(r => new Buffer(r, 'binary'));

const loadImages = file =>
    readToBuffer(file)
        .then(buf => {
            const magicNum = buf.readInt32BE(0);

            const num = buf.readInt32BE(4);
            const rowNum = buf.readInt32BE(8);
            const colNum = buf.readInt32BE(12);

            const dataBuf = buf.slice(16);

            const res = Array(num);

            let offset = 0;

            for (let i = 0; i < num; i++) {
                const data = new convnetjs.Vol(colNum, rowNum, 1, 0);

                for (let y = 0; y < rowNum; y++) {
                    for (let x = 0; x < colNum; x++) {

                        const value = dataBuf.readUInt8(offset++);

                        data.set(x, y, 0, value);
                    }
                }

                res[i] = data;
            }

            return res;
        });

const loadLabels = file =>
    readToBuffer(file)
        .then(buf => {
            const magicNum = buf.readInt32BE(0);

            const num = buf.readInt32BE(4);

            const dataBuf = buf.slice(8);

            const res = Array(num);

            for (let i = 0; i < num; i++) {
                res[i] = dataBuf.readUInt8(i);
            }

            return res;
        });

module.exports.loadMnist = (imgFile, labelFile) => 
    Promise.all([
        loadImages(imgFile),
        loadLabels(labelFile)
    ]).spread( (r1, r2) => 
        r2.map((label, i) => new Object({ values: r1[i], label: label }))
    );

(2) 学習モデルの保存

ニューラルネットの学習モデルを JSON ファイルへ保存する処理です。

save_model.js
'use strict';

const Promise = require('bluebird');
const convnetjs = require('convnetjs');

const writeFile = Promise.promisify(require('fs').writeFile);

module.exports.saveModel = (layers, destFile) => {

    const net = new convnetjs.Net();
    // 内部的なレイヤーの構築
    net.makeLayers(layers);

    // JSON 化してファイルへ保存
    writeFile(destFile, JSON.stringify(net.toJSON()))
        .catch( e => console.error(e) );
};

(3) 学習

指定の学習モデル (JSON) を MNIST の学習データセットで学習する処理です。

処理の進行状況を確認できるように batchSize 毎に誤差の平均と正解率を出力するようにしました。 (ただし、以下の方法では batchSize 次第で Trainer によるパラメータの更新タイミングと合わなくなります)

また、MNIST データセットの配列を直接シャッフルする代わりに、0 ~ 59999 のインデックス値から成る配列を用意して、それをシャッフルするようにしています。

learn_mnist.js
'use strict';

const Promise = require('bluebird');
const fs = require('fs');
const shuffle = require('shuffle-array');
const convnetjs = require('convnetjs');
const readFile = Promise.promisify(fs.readFile);
const writeFile = Promise.promisify(fs.writeFile);

const mnist = require('./load_mnist');

const epoch = parseInt(process.argv[2]);
const batchSize = parseInt(process.argv[3]);
const learningRate = parseFloat(process.argv[4]);
const trainMethod = process.argv[5];

const modelJsonFile = process.argv[6];
const modelJsonDestFile = process.argv[7];

// 0 ~ n - 1 を要素とする配列作成
const range = n => {
    const res = Array(n);

    for (let i = 0; i < n; i++) {
        res[i] = i;
    }

    return res;
};
// 指定サイズ毎に誤差の平均と正解率を出力する処理
const createLogger = (logSize, logFunc) => {
    let list = [];
    let counter = 0;

    return (loss, accuracy) => {
        list.push({loss: loss, accuracy: accuracy});

        const size = list.length;

        if (size >= logSize) {
            const res = list.reduce(
                (acc, d) => {
                    acc.loss += d.loss;
                    acc.accuracy += d.accuracy;

                    return acc;
                },
                { loss: 0.0, accuracy: 0 }
            );
            // 出力処理の実行
            logFunc(
                res.loss / size,
                res.accuracy / size,
                counter++
            );

            list = [];
        }
    };
};


Promise.all([
    readFile(modelJsonFile),
    mnist.loadMnist('train-images.idx3-ubyte', 'train-labels.idx1-ubyte')
]).spread( (json, data) => {
    const net = new convnetjs.Net();
    // JSON から学習モデルを復元
    net.fromJSON(JSON.parse(json));

    const trainer = new convnetjs.Trainer(net, {
        method: trainMethod, 
        batch_size: batchSize, 
        learning_rate: learningRate
    });

    range(epoch).forEach(ep => {
        // ログ出力処理の作成
        const log = createLogger(batchSize, (loss, acc, counter) =>
            console.log( [ep, counter, loss, acc].join(',') )
        );

        // インデックス値の配列を作成しシャッフル
        shuffle(range(data.length)).forEach(i => {
            // 該当するデータを取得
            const d = data[i];
            // 学習
            const stats = trainer.train(d.values, d.label);

            log(
                stats.loss,
                (net.getPrediction() == d.label)? 1: 0
            );
        });
    });

    return net;

}).then( net => 
    // 学習モデルの保存
    writeFile(modelJsonDestFile, JSON.stringify(net.toJSON()))
).catch( e => 
    console.error(e)
);

(4) 評価(テスト)

指定の学習モデル (JSON) で MNIST の評価データセットを処理し、正解率を出力する処理です。

validate_mnist.js
'use strict';

const Promise = require('bluebird');
const convnetjs = require('convnetjs');
const readFile = Promise.promisify(require('fs').readFile);

const mnist = require('./load_mnist');

const modelJsonFile = process.argv[2];

Promise.all([
    readFile(modelJsonFile),
    mnist.loadMnist('t10k-images.idx3-ubyte', 't10k-labels.idx1-ubyte')
]).spread( (json, data) => {

    const net = new convnetjs.Net();
    // JSON から学習モデルを復元
    net.fromJSON(JSON.parse(json));

    const accuCount = data.reduce((acc, d) => {
        net.forward(d.values);
        // 正解数のカウント
        return acc + (d.label == net.getPrediction()? 1: 0);
    }, 0);

    console.log(`data size: ${data.length}`);
    // 正解率の出力
    console.log(`accuracy: ${accuCount / data.length}`);
});

(a) 階層型ニューラルネット

MNIST の画像サイズは 28x28 のため、入力層 (type = input) の out_sxout_sy へそれぞれ 28 を設定し、画素値は 1バイトのため out_depth へ 1 を設定します。

出力層 (type = output)は 0 ~ 9 の分類 (10種類) となるため、typesoftmax にして num_classes へ 10 を設定します。

今回は、隠れ層を 1層にして活性化関数とニューロン数をコマンドライン引数で指定できるようにしました。

create_layer_hnn.js
'use strict';

const act = process.argv[2];
const numNeurons = parseInt(process.argv[3]);
const jsonDestFile = process.argv[4];

require('./save_model').saveModel(
    [
        { type: 'input', out_sx: 28, out_sy: 28, out_depth: 1 },
        { type: 'fc', activation: act, num_neurons: numNeurons },
        { type: 'softmax', num_classes: 10 }
    ], 
    jsonDestFile
);

例えば、上記の活性化関数へ relu を指定した場合の makeLayers の結果は下記のようになります。

学習モデルの内部的なレイヤー構成例
input -> fc -> relu -> fc -> softmax

type へ softmax を指定した場合、fc 層が差し込まれるようになっています。

学習と評価

今回は以下の 4種類で学習・評価を試してみました。

  1. 活性化関数 = relu, ニューロン数 = 50
  2. 活性化関数 = relu, ニューロン数 = 300
  3. 活性化関数 = sigmoid, ニューロン数 = 50
  4. 活性化関数 = sigmoid, ニューロン数 = 300

学習回数などのパラメータはとりあえず下記で実行します。

  • 学習回数 = 15
  • バッチサイズ = 100
  • 学習係数 = 0.001
  • 学習係数の決定方法 = adadelta

バッチサイズを 100 とする事で、学習データ 100件毎にパラメータ(重み)の更新が実施されます。

なお、処理時間はニューロン数や学習回数・バッチサイズなどに影響されます。(今回、ニューロン数 50 では 10分程度、300 では 50分程度でした)

1. 活性化関数 = relu, ニューロン数 = 50 (学習回数 = 15, バッチサイズ = 100, 学習係数 = 0.001, adadelta)

> node create_layer_hnn.js relu 50 models/relu_50.json

> node learn_mnist.js 15 100 0.001 adadelta models/relu_50.json results/a-1_relu_50.json > logs/a-1_relu_50.log

> node validate_mnist.js results/a-1_relu_50.json

data size: 10000
accuracy: 0.9455
学習時の誤差と正解率

f:id:fits:20160324201523p:plain

2. 活性化関数 = relu, ニューロン数 = 300 (学習回数 = 15, バッチサイズ = 100, 学習係数 = 0.001, adadelta)

> node create_layer_hnn.js relu 300 models/relu_300.json

> node learn_mnist.js 15 100 0.001 adadelta models/relu_300.json results/a-2_relu_300.json > logs/a-2_relu_300.log

> node validate_mnist.js results/a-2_relu_300.json

data size: 10000
accuracy: 0.965
学習時の誤差と正解率

f:id:fits:20160324201542p:plain

3. 活性化関数 = sigmoid, ニューロン数 = 50 (学習回数 = 15, バッチサイズ = 100, 学習係数 = 0.001, adadelta)

> node create_layer_hnn.js sigmoid 50 models/sigmoid_50.json

> node learn_mnist.js 15 100 0.001 adadelta models/sigmoid_50.json results/a-3_sigmoid_50.json > logs/a-3_sigmoid_50.log

> node validate_mnist.js results/a-3_sigmoid_50.json

data size: 10000
accuracy: 0.9368
学習時の誤差と正解率

f:id:fits:20160324201601p:plain

4. 活性化関数 = sigmoid, ニューロン数 = 300 (学習回数 = 15, バッチサイズ = 100, 学習係数 = 0.001, adadelta)

> node create_layer_hnn.js sigmoid 300 models/sigmoid_300.json

> node learn_mnist.js 15 100 0.001 adadelta models/sigmoid_300.json results/a-4_sigmoid_300.json > logs/a-4_sigmoid_300.log

> node validate_mnist.js results/a-4_sigmoid_300.json

data size: 10000
accuracy: 0.9631
学習時の誤差と正解率

f:id:fits:20160324201617p:plain

結果のまとめ

番号 活性化関数(隠れ層) ニューロン数(隠れ層) 正解率
1 relu 50 0.9455
2 relu 300 0.965
3 sigmoid 50 0.9368
4 sigmoid 300 0.9631

MNIST データセットをパースする

MNIST データセットは、THE MNIST DATABASE of handwritten digits からダウンロード可能な手書き数字のデータです。

機械学習ライブラリ等に標準で用意されてたりしますが、 今回は Node.js と Java でパースしてみました。

ソースは http://github.com/fits/try_samples/tree/master/blog/20160307/

概要

MNIST は以下のようなデータです。

  • 0 ~ 9 の手書き数字のグレースケール画像
  • 1つの画像は 28x28 のサイズ
  • 学習用に 6万件、テスト用に 1万件が用意されている
  • 画像データとラベルデータのファイルが対になっている

下記のように 4種類のファイルが用意されています。

画像データファイル ラベルデータファイル
学習用 train-images.idx3-ubyte train-labels.idx1-ubyte
テスト用 t10k-images.idx3-ubyte t10k-labels.idx1-ubyte

ファイルフォーマットは以下の通りです。

画像データのファイルフォーマット

オフセット タイプ 内容
0 32bit integer 2051 マジックナンバー
4 32bit integer 60000 or 10000 画像の数
8 32bit integer 28 行の数
12 32bit integer 28 列の数
16~ unsigned byte 0~255 ※ 1バイトずつピクセル値が連続

先頭 16バイトがヘッダー部分、それ以降に 1画像 28 x 28 = 784 バイトのデータが 6万件もしくは 1万件続きます。

 ※ 0 が白、255 が黒

ラベルデータのファイルフォーマット

オフセット タイプ 内容
0 32bit integer 2049 マジックナンバー
4 32bit integer 60000 or 10000 ラベルの数
8~ unsigned byte 0~9 1バイトずつラベル値が連続

先頭 8バイトがヘッダー部分、それ以降にラベルデータが 6万件もしくは 1万件続きます。

(a) ConvNetJS 用に変換 (Node.js)

それでは、Node.js を使ってパースしてみます。

今回は ConvNetJS で使えるように、画像データを convnetjs.Vol として作成するようにしました。

Buffer を使ってバイナリデータを処理します。

MNIST のデータはビッグエンディアンのようなので 32bit integer を読み込む際は readInt32BE を使います。

なお、マジックナンバーに関してはチェックしていません。

load_mnist.js
var Promise = require('bluebird');
var convnetjs = require('convnetjs');

var fs = require('fs');

var readFile = Promise.promisify(fs.readFile);
// ファイル内容から Buffer 作成
var readToBuffer = file => readFile(file).then(r => new Buffer(r, 'binary'));

// 画像データのロード
var loadImages = file =>
    readToBuffer(file)
        .then(buf => {
            var magicNum = buf.readInt32BE(0);
            var num = buf.readInt32BE(4);
            var rowNum = buf.readInt32BE(8);
            var colNum = buf.readInt32BE(12);

            // 画像データ部分を分離(ヘッダー部分を除外)
            var dataBuf = buf.slice(16);

            var res = Array(num);

            var offset = 0;

            for (var i = 0; i < num; i++) {
                var data = new convnetjs.Vol(colNum, rowNum, 1, 0);

                for (var y = 0; y < rowNum; y++) {
                    for (var x = 0; x < colNum; x++) {
                        var value = dataBuf.readUInt8(offset++);
                        data.set(x, y, 0, value);
                    }
                }

                res[i] = data;
            }

            return res;
        });

// ラベルデータのロード
var loadLabels = file =>
    readToBuffer(file)
        .then(buf => {
            var magicNum = buf.readInt32BE(0);
            var num = buf.readInt32BE(4);

            // ラベルデータ部分を分離(ヘッダー部分を除外)
            var dataBuf = buf.slice(8);

            var res = Array(num);

            for (var i = 0; i < num; i++) {
                res[i] = dataBuf.readUInt8(i);
            }

            return res;
        });

// 画像・ラベルデータをロード
module.exports.loadMnist = (imgFile, labelFile) => 
    Promise.all([
        loadImages(imgFile),
        loadLabels(labelFile)
    ]).spread( (r1, r2) => 
        r2.map((label, i) => {
            return { values: r1[i], label: label };
        })
    );

bluebirdconvnetjs パッケージを使っています。

package.json
{
  "name": "mnist_parse_sample",
  "version": "1.0.0",
  "description": "",
  "dependencies": {
    "bluebird": "^3.3.3",
    "convnetjs": "^0.3.0"
  }
}

動作確認

以下のテストコードを使って簡単な動作確認を行います。

手書き数字の大まかな形状を確認できるように、ピクセル値が 0 より大きければ # へ、それ以外は へ変換し出力してみました。

test_load_mnist.js
var mnist = require('./load_mnist');

var printData = d => {
    console.log(`***** number = ${d.label} *****`);

    var v = d.values;

    for (var y = 0; y < v.sy; y++) {
        var r = Array(v.sx);

        for (var x = 0; x < v.sx; x++) {
            // ピクセル値が 0より大きいと '#'、それ以外は ' '
            r[x] = v.get(x, y, 0) > 0 ? '#' : ' ';
        }

        // 文字で表現した画像を出力
        console.log(r.join(''));
    }

    // ピクセル値を出力
    console.log(d.values.w.join(','));
};

mnist.loadMnist(process.argv[2], process.argv[3])
    .then(ds => {
        console.log(`size: ${ds.length}`);

        printData(ds[0]);

        console.log('----------');

        printData(ds[1]);
    });

MNIST 学習用データを使った実行結果は以下の通りです。

実行結果
> node test_load_mnist.js train-images.idx3-ubyte train-labels.idx1-ubyte

size: 60000
***** number = 5 *****





            ############
        ################
       ################
       ###########
        ####### ##
         #####
           ####
           ####
            ######
             ######
              ######
               #####
                 ####
              #######
            ########
          #########
        ##########
      ##########
    ##########
    ########


0,・・・,3,18,18,18,126,136,175,26,166,255,・・・,0
----------
***** number = 0 *****




               #####
              ######
             #########
           ###########
           ###########
          ############
         #########  ###
        ######      ###
       #######      ###
       ####         ###
       ###          ###
      ####          ###
      ####        #####
      ###        #####
      ###       ####
      ###      ####
      #############
      ###########
      #########
       #######




0,・・・,51,159,253,159,50,・・・,0

(b) Deeplearning4J 用に変換 (Java

次は Deeplearning4J で使えるように Javaorg.nd4j.linalg.dataset.DataSet へ変換します。

画像データとラベルデータをそれぞれ org.nd4j.linalg.api.ndarray.INDArray として作成し、最後に DataSet へまとめます。

1件の画像データは 784 (= 28 x 28) 要素のフラットな INDArray として作成します ※。

 ※ Deeplearning4J では、ConvolutionLayerSetup を使って
    画像の高さ・幅を指定できるため、フラットなデータにして問題ありません

ラベルの値は FeatureUtil.toOutcomeVector() メソッドで作成します。

このメソッドによって、ラベルの種類と同じ数 (今回は 10) の要素を持ち、該当するインデックスの値だけが 1 で他は 0 になった配列 (INDArray) を得られます。

このように、変換後のデータセットの構造は (a) のケースとは異なります。

また、Java には unsigned byte という型がなく、ピクセル値を byte として読み込むと -128 ~ 127 になってしまいます。そこで & 0xff して正しい値 (0 ~ 255) となるように変換しています。

ちなみに、Deeplearning4J の org.deeplearning4j.datasets.DataSets.mnist() メソッドを使えば MNIST データセットを取得できますが、その場合はピクセル値が正規化 ※ されて 0 か 1 の値となります。 (シャッフルも実施される)

 ※ ピクセル値 > 30 の場合は 1、それ以外は 0 になる
src/main/java/MnistLoader.java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Paths;
import java.util.concurrent.CompletableFuture;

public class MnistLoader {
    private final static int LABELS_NUM = 10;

    public static CompletableFuture<DataSet> loadMnist(String imageFileName, String labelFileName) {

        return CompletableFuture.supplyAsync(() -> loadImages(imageFileName))
            .thenCombineAsync(
                CompletableFuture.supplyAsync(() -> loadLabels(labelFileName)),
                DataSet::new
            );
    }

    private static INDArray loadImages(String fileName) {
        try (FileChannel fc = FileChannel.open(Paths.get(fileName))) {

            ByteBuffer headerBuf = ByteBuffer.allocateDirect(16);
            // ヘッダー部分の読み込み
            fc.read(headerBuf);

            headerBuf.rewind();

            int magicNum = headerBuf.getInt();
            int num = headerBuf.getInt();
            int rowNum = headerBuf.getInt();
            int colNum = headerBuf.getInt();

            ByteBuffer buf = ByteBuffer.allocateDirect(num * rowNum * colNum);
            // 画像データ部分の読み込み
            fc.read(buf);

            buf.rewind();

            int dataSize = rowNum * colNum;

            INDArray res = Nd4j.create(num, dataSize);

            for (int n = 0; n < num; n++) {
                INDArray d = Nd4j.create(1, dataSize);

                for (int i = 0; i < dataSize; i++) {
                    // & 0xff する事で unsigned の値へ変換
                    d.putScalar(i, buf.get() & 0xff);
                }

                res.putRow(n, d);
            }
            return res;

        } catch(IOException ex) {
            throw new RuntimeException(ex);
        }
    }

    private static INDArray loadLabels(String fileName) {
        try (FileChannel fc = FileChannel.open(Paths.get(fileName))) {

            ByteBuffer headerBuf = ByteBuffer.allocateDirect(8);
            // ヘッダー部分の読み込み
            fc.read(headerBuf);

            headerBuf.rewind();

            int magicNum = headerBuf.getInt();
            int num = headerBuf.getInt();

            ByteBuffer buf = ByteBuffer.allocateDirect(num);
            // ラベルデータ部分の読み込み
            fc.read(buf);

            buf.rewind();

            INDArray res = Nd4j.create(num, LABELS_NUM);

            for (int i = 0; i < num; i++) {
                res.putRow(i, FeatureUtil.toOutcomeVector(buf.get(), LABELS_NUM));
            }

            return res;

        } catch(IOException ex) {
            throw new RuntimeException(ex);
        }
    }
}

動作確認

Gradle で以下のテストコードを実行し簡単な動作確認を行います。

src/main/java/TestMnistLoader.java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;

import java.util.stream.IntStream;

public class TestMnistLoader {
    public static void main(String... args) {

        MnistLoader.loadMnist(args[0], args[1])
            .thenAccept(ds -> {
                System.out.println("size: " + ds.numExamples());

                printData(ds.get(0));

                System.out.println("----------");

                printData(ds.get(1));
            })
            .join();
    }

    private static void printData(DataSet d) {
        System.out.println("***** labels = " + d.getLabels());

        INDArray v = d.getFeatures();

        IntStream.range(0, 28).forEach( y -> {
            IntStream.range(0, 28).forEach ( x -> {
                System.out.print( v.getInt(x + y * 28) > 0 ? "#" : " " );
            });

            System.out.println();
        });

        System.out.println(d.getFeatures());
    }
}
build.gradle
apply plugin: 'application'

tasks.withType(AbstractCompile)*.options*.encoding = 'UTF-8'

mainClassName = 'TestMnistLoader'

repositories {
    jcenter()
}

dependencies {
    compile 'org.nd4j:nd4j-x86:0.4-rc3.8'
    runtime 'org.slf4j:slf4j-nop:1.7.18'
}

run {
    if (project.hasProperty('args')) {
        args project.args.split(' ')
    }
}

MNIST 学習用データを使った実行結果は以下の通りです。

実行結果
> gradle run -Pargs="train-images.idx3-ubyte train-labels.idx1-ubyte"

・・・
:run
3 06, 2016 10:22:09 午後 com.github.fommil.netlib.BLAS <clinit>
警告: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
3 06, 2016 10:22:10 午後 com.github.fommil.jni.JniLoader liberalLoad
情報: successfully loaded ・・・\Temp\jniloader211398199777152626netlib-native_ref-win-x86_64.dll
****************************************************************
WARNING: COULD NOT LOAD NATIVE SYSTEM BLAS
ND4J performance WILL be reduced
Please install native BLAS library such as OpenBLAS or IntelMKL
See http://nd4j.org/getstarted.html#open for further details
****************************************************************
size: 60000
***** labels = [ 0.00, 0.00, 0.00, 0.00, 0.00, 1.00, 0.00, 0.00, 0.00, 0.00]





            ############
        ################
       ################
       ###########
        ####### ##
         #####
           ####
           ####
            ######
             ######
              ######
               #####
                 ####
              #######
            ########
          #########
        ##########
      ##########
    ##########
    ########



[ 0.00, ・・・, 3.00, 18.00, 18.00, 18.00, 126.00, 136.00, 175.00, 26.00, 166.00, 255.00, ・・・, 0.00]
----------
***** labels = [ 1.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00]




               #####
              ######
             #########
           ###########
           ###########
          ############
         #########  ###
        ######      ###
       #######      ###
       ####         ###
       ###          ###
      ####          ###
      ####        #####
      ###        #####
      ###       ####
      ###      ####
      #############
      ###########
      #########
       #######




[ 0.00, ・・・, 51.00, 159.00, 253.00, 159.00, 50.00, ・・・, 0.00]

今回は DataSet を作っているだけなので WARNING を気にする必要は特にありませんが、WARNING を消して正しい状態で実行するには OpenBlas 等の BLAS ライブラリをインストールしてから実行します。

Windows 環境であれば、http://nd4j.org/getstarted.html#open のリンクから ND4J_Win64_OpenBLAS-v0.2.14.zip をダウンロード・解凍し、環境変数 PATH へ設定してから実行するのが簡単だと思います。

また、JniLoader が TEMP ディレクトリへ netlib-native_system-win-x86_64.dll をダウンロードするのを防止するには、この dll も環境変数 PATH へ設定した場所へ配置しておきます。 (dll は Maven の Central Repository からダウンロードできます ※)

 ※ ダウンロードしたファイル名では都合が悪いようなので、
    バージョン番号を除いた netlib-native_system-win-x86_64.dll という
    ファイル名へ変更します
OpenBlas 設定後の実行結果
> set PATH=C:\ND4J_Win64_OpenBLAS-v0.2.14;%PATH%
> gradle run -Pargs="train-images.idx3-ubyte train-labels.idx1-ubyte"

・・・
:run
3 06, 2016 11:31:22 午後 com.github.fommil.jni.JniLoader liberalLoad
情報: successfully loaded ・・・\netlib-native_system-win-x86_64.dll
size: 60000
***** labels = [ 0.00, 0.00, 0.00, 0.00, 0.00, 1.00, 0.00, 0.00, 0.00, 0.00]
・・・