MNIST データセットをパースする

MNIST データセットは、THE MNIST DATABASE of handwritten digits からダウンロード可能な手書き数字のデータです。

機械学習ライブラリ等に標準で用意されてたりしますが、 今回は Node.js と Java でパースしてみました。

ソースは http://github.com/fits/try_samples/tree/master/blog/20160307/

概要

MNIST は以下のようなデータです。

  • 0 ~ 9 の手書き数字のグレースケール画像
  • 1つの画像は 28x28 のサイズ
  • 学習用に 6万件、テスト用に 1万件が用意されている
  • 画像データとラベルデータのファイルが対になっている

下記のように 4種類のファイルが用意されています。

画像データファイル ラベルデータファイル
学習用 train-images.idx3-ubyte train-labels.idx1-ubyte
テスト用 t10k-images.idx3-ubyte t10k-labels.idx1-ubyte

ファイルフォーマットは以下の通りです。

画像データのファイルフォーマット

オフセット タイプ 内容
0 32bit integer 2051 マジックナンバー
4 32bit integer 60000 or 10000 画像の数
8 32bit integer 28 行の数
12 32bit integer 28 列の数
16~ unsigned byte 0~255 ※ 1バイトずつピクセル値が連続

先頭 16バイトがヘッダー部分、それ以降に 1画像 28 x 28 = 784 バイトのデータが 6万件もしくは 1万件続きます。

 ※ 0 が白、255 が黒

ラベルデータのファイルフォーマット

オフセット タイプ 内容
0 32bit integer 2049 マジックナンバー
4 32bit integer 60000 or 10000 ラベルの数
8~ unsigned byte 0~9 1バイトずつラベル値が連続

先頭 8バイトがヘッダー部分、それ以降にラベルデータが 6万件もしくは 1万件続きます。

(a) ConvNetJS 用に変換 (Node.js)

それでは、Node.js を使ってパースしてみます。

今回は ConvNetJS で使えるように、画像データを convnetjs.Vol として作成するようにしました。

Buffer を使ってバイナリデータを処理します。

MNIST のデータはビッグエンディアンのようなので 32bit integer を読み込む際は readInt32BE を使います。

なお、マジックナンバーに関してはチェックしていません。

load_mnist.js
var Promise = require('bluebird');
var convnetjs = require('convnetjs');

var fs = require('fs');

var readFile = Promise.promisify(fs.readFile);
// ファイル内容から Buffer 作成
var readToBuffer = file => readFile(file).then(r => new Buffer(r, 'binary'));

// 画像データのロード
var loadImages = file =>
    readToBuffer(file)
        .then(buf => {
            var magicNum = buf.readInt32BE(0);
            var num = buf.readInt32BE(4);
            var rowNum = buf.readInt32BE(8);
            var colNum = buf.readInt32BE(12);

            // 画像データ部分を分離(ヘッダー部分を除外)
            var dataBuf = buf.slice(16);

            var res = Array(num);

            var offset = 0;

            for (var i = 0; i < num; i++) {
                var data = new convnetjs.Vol(colNum, rowNum, 1, 0);

                for (var y = 0; y < rowNum; y++) {
                    for (var x = 0; x < colNum; x++) {
                        var value = dataBuf.readUInt8(offset++);
                        data.set(x, y, 0, value);
                    }
                }

                res[i] = data;
            }

            return res;
        });

// ラベルデータのロード
var loadLabels = file =>
    readToBuffer(file)
        .then(buf => {
            var magicNum = buf.readInt32BE(0);
            var num = buf.readInt32BE(4);

            // ラベルデータ部分を分離(ヘッダー部分を除外)
            var dataBuf = buf.slice(8);

            var res = Array(num);

            for (var i = 0; i < num; i++) {
                res[i] = dataBuf.readUInt8(i);
            }

            return res;
        });

// 画像・ラベルデータをロード
module.exports.loadMnist = (imgFile, labelFile) => 
    Promise.all([
        loadImages(imgFile),
        loadLabels(labelFile)
    ]).spread( (r1, r2) => 
        r2.map((label, i) => {
            return { values: r1[i], label: label };
        })
    );

bluebirdconvnetjs パッケージを使っています。

package.json
{
  "name": "mnist_parse_sample",
  "version": "1.0.0",
  "description": "",
  "dependencies": {
    "bluebird": "^3.3.3",
    "convnetjs": "^0.3.0"
  }
}

動作確認

以下のテストコードを使って簡単な動作確認を行います。

手書き数字の大まかな形状を確認できるように、ピクセル値が 0 より大きければ # へ、それ以外は へ変換し出力してみました。

test_load_mnist.js
var mnist = require('./load_mnist');

var printData = d => {
    console.log(`***** number = ${d.label} *****`);

    var v = d.values;

    for (var y = 0; y < v.sy; y++) {
        var r = Array(v.sx);

        for (var x = 0; x < v.sx; x++) {
            // ピクセル値が 0より大きいと '#'、それ以外は ' '
            r[x] = v.get(x, y, 0) > 0 ? '#' : ' ';
        }

        // 文字で表現した画像を出力
        console.log(r.join(''));
    }

    // ピクセル値を出力
    console.log(d.values.w.join(','));
};

mnist.loadMnist(process.argv[2], process.argv[3])
    .then(ds => {
        console.log(`size: ${ds.length}`);

        printData(ds[0]);

        console.log('----------');

        printData(ds[1]);
    });

MNIST 学習用データを使った実行結果は以下の通りです。

実行結果
> node test_load_mnist.js train-images.idx3-ubyte train-labels.idx1-ubyte

size: 60000
***** number = 5 *****





            ############
        ################
       ################
       ###########
        ####### ##
         #####
           ####
           ####
            ######
             ######
              ######
               #####
                 ####
              #######
            ########
          #########
        ##########
      ##########
    ##########
    ########


0,・・・,3,18,18,18,126,136,175,26,166,255,・・・,0
----------
***** number = 0 *****




               #####
              ######
             #########
           ###########
           ###########
          ############
         #########  ###
        ######      ###
       #######      ###
       ####         ###
       ###          ###
      ####          ###
      ####        #####
      ###        #####
      ###       ####
      ###      ####
      #############
      ###########
      #########
       #######




0,・・・,51,159,253,159,50,・・・,0

(b) Deeplearning4J 用に変換 (Java

次は Deeplearning4J で使えるように Javaorg.nd4j.linalg.dataset.DataSet へ変換します。

画像データとラベルデータをそれぞれ org.nd4j.linalg.api.ndarray.INDArray として作成し、最後に DataSet へまとめます。

1件の画像データは 784 (= 28 x 28) 要素のフラットな INDArray として作成します ※。

 ※ Deeplearning4J では、ConvolutionLayerSetup を使って
    画像の高さ・幅を指定できるため、フラットなデータにして問題ありません

ラベルの値は FeatureUtil.toOutcomeVector() メソッドで作成します。

このメソッドによって、ラベルの種類と同じ数 (今回は 10) の要素を持ち、該当するインデックスの値だけが 1 で他は 0 になった配列 (INDArray) を得られます。

このように、変換後のデータセットの構造は (a) のケースとは異なります。

また、Java には unsigned byte という型がなく、ピクセル値を byte として読み込むと -128 ~ 127 になってしまいます。そこで & 0xff して正しい値 (0 ~ 255) となるように変換しています。

ちなみに、Deeplearning4J の org.deeplearning4j.datasets.DataSets.mnist() メソッドを使えば MNIST データセットを取得できますが、その場合はピクセル値が正規化 ※ されて 0 か 1 の値となります。 (シャッフルも実施される)

 ※ ピクセル値 > 30 の場合は 1、それ以外は 0 になる
src/main/java/MnistLoader.java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.FeatureUtil;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Paths;
import java.util.concurrent.CompletableFuture;

public class MnistLoader {
    private final static int LABELS_NUM = 10;

    public static CompletableFuture<DataSet> loadMnist(String imageFileName, String labelFileName) {

        return CompletableFuture.supplyAsync(() -> loadImages(imageFileName))
            .thenCombineAsync(
                CompletableFuture.supplyAsync(() -> loadLabels(labelFileName)),
                DataSet::new
            );
    }

    private static INDArray loadImages(String fileName) {
        try (FileChannel fc = FileChannel.open(Paths.get(fileName))) {

            ByteBuffer headerBuf = ByteBuffer.allocateDirect(16);
            // ヘッダー部分の読み込み
            fc.read(headerBuf);

            headerBuf.rewind();

            int magicNum = headerBuf.getInt();
            int num = headerBuf.getInt();
            int rowNum = headerBuf.getInt();
            int colNum = headerBuf.getInt();

            ByteBuffer buf = ByteBuffer.allocateDirect(num * rowNum * colNum);
            // 画像データ部分の読み込み
            fc.read(buf);

            buf.rewind();

            int dataSize = rowNum * colNum;

            INDArray res = Nd4j.create(num, dataSize);

            for (int n = 0; n < num; n++) {
                INDArray d = Nd4j.create(1, dataSize);

                for (int i = 0; i < dataSize; i++) {
                    // & 0xff する事で unsigned の値へ変換
                    d.putScalar(i, buf.get() & 0xff);
                }

                res.putRow(n, d);
            }
            return res;

        } catch(IOException ex) {
            throw new RuntimeException(ex);
        }
    }

    private static INDArray loadLabels(String fileName) {
        try (FileChannel fc = FileChannel.open(Paths.get(fileName))) {

            ByteBuffer headerBuf = ByteBuffer.allocateDirect(8);
            // ヘッダー部分の読み込み
            fc.read(headerBuf);

            headerBuf.rewind();

            int magicNum = headerBuf.getInt();
            int num = headerBuf.getInt();

            ByteBuffer buf = ByteBuffer.allocateDirect(num);
            // ラベルデータ部分の読み込み
            fc.read(buf);

            buf.rewind();

            INDArray res = Nd4j.create(num, LABELS_NUM);

            for (int i = 0; i < num; i++) {
                res.putRow(i, FeatureUtil.toOutcomeVector(buf.get(), LABELS_NUM));
            }

            return res;

        } catch(IOException ex) {
            throw new RuntimeException(ex);
        }
    }
}

動作確認

Gradle で以下のテストコードを実行し簡単な動作確認を行います。

src/main/java/TestMnistLoader.java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;

import java.util.stream.IntStream;

public class TestMnistLoader {
    public static void main(String... args) {

        MnistLoader.loadMnist(args[0], args[1])
            .thenAccept(ds -> {
                System.out.println("size: " + ds.numExamples());

                printData(ds.get(0));

                System.out.println("----------");

                printData(ds.get(1));
            })
            .join();
    }

    private static void printData(DataSet d) {
        System.out.println("***** labels = " + d.getLabels());

        INDArray v = d.getFeatures();

        IntStream.range(0, 28).forEach( y -> {
            IntStream.range(0, 28).forEach ( x -> {
                System.out.print( v.getInt(x + y * 28) > 0 ? "#" : " " );
            });

            System.out.println();
        });

        System.out.println(d.getFeatures());
    }
}
build.gradle
apply plugin: 'application'

tasks.withType(AbstractCompile)*.options*.encoding = 'UTF-8'

mainClassName = 'TestMnistLoader'

repositories {
    jcenter()
}

dependencies {
    compile 'org.nd4j:nd4j-x86:0.4-rc3.8'
    runtime 'org.slf4j:slf4j-nop:1.7.18'
}

run {
    if (project.hasProperty('args')) {
        args project.args.split(' ')
    }
}

MNIST 学習用データを使った実行結果は以下の通りです。

実行結果
> gradle run -Pargs="train-images.idx3-ubyte train-labels.idx1-ubyte"

・・・
:run
3 06, 2016 10:22:09 午後 com.github.fommil.netlib.BLAS <clinit>
警告: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
3 06, 2016 10:22:10 午後 com.github.fommil.jni.JniLoader liberalLoad
情報: successfully loaded ・・・\Temp\jniloader211398199777152626netlib-native_ref-win-x86_64.dll
****************************************************************
WARNING: COULD NOT LOAD NATIVE SYSTEM BLAS
ND4J performance WILL be reduced
Please install native BLAS library such as OpenBLAS or IntelMKL
See http://nd4j.org/getstarted.html#open for further details
****************************************************************
size: 60000
***** labels = [ 0.00, 0.00, 0.00, 0.00, 0.00, 1.00, 0.00, 0.00, 0.00, 0.00]





            ############
        ################
       ################
       ###########
        ####### ##
         #####
           ####
           ####
            ######
             ######
              ######
               #####
                 ####
              #######
            ########
          #########
        ##########
      ##########
    ##########
    ########



[ 0.00, ・・・, 3.00, 18.00, 18.00, 18.00, 126.00, 136.00, 175.00, 26.00, 166.00, 255.00, ・・・, 0.00]
----------
***** labels = [ 1.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00]




               #####
              ######
             #########
           ###########
           ###########
          ############
         #########  ###
        ######      ###
       #######      ###
       ####         ###
       ###          ###
      ####          ###
      ####        #####
      ###        #####
      ###       ####
      ###      ####
      #############
      ###########
      #########
       #######




[ 0.00, ・・・, 51.00, 159.00, 253.00, 159.00, 50.00, ・・・, 0.00]

今回は DataSet を作っているだけなので WARNING を気にする必要は特にありませんが、WARNING を消して正しい状態で実行するには OpenBlas 等の BLAS ライブラリをインストールしてから実行します。

Windows 環境であれば、http://nd4j.org/getstarted.html#open のリンクから ND4J_Win64_OpenBLAS-v0.2.14.zip をダウンロード・解凍し、環境変数 PATH へ設定してから実行するのが簡単だと思います。

また、JniLoader が TEMP ディレクトリへ netlib-native_system-win-x86_64.dll をダウンロードするのを防止するには、この dll も環境変数 PATH へ設定した場所へ配置しておきます。 (dll は Maven の Central Repository からダウンロードできます ※)

 ※ ダウンロードしたファイル名では都合が悪いようなので、
    バージョン番号を除いた netlib-native_system-win-x86_64.dll という
    ファイル名へ変更します
OpenBlas 設定後の実行結果
> set PATH=C:\ND4J_Win64_OpenBLAS-v0.2.14;%PATH%
> gradle run -Pargs="train-images.idx3-ubyte train-labels.idx1-ubyte"

・・・
:run
3 06, 2016 11:31:22 午後 com.github.fommil.jni.JniLoader liberalLoad
情報: successfully loaded ・・・\netlib-native_system-win-x86_64.dll
size: 60000
***** labels = [ 0.00, 0.00, 0.00, 0.00, 0.00, 1.00, 0.00, 0.00, 0.00, 0.00]
・・・