MNIST データセットをパースする
MNIST データセットは、THE MNIST DATABASE of handwritten digits からダウンロード可能な手書き数字のデータです。
機械学習ライブラリ等に標準で用意されてたりしますが、 今回は Node.js と Java でパースしてみました。
ソースは http://github.com/fits/try_samples/tree/master/blog/20160307/
概要
MNIST は以下のようなデータです。
- 0 ~ 9 の手書き数字のグレースケール画像
- 1つの画像は 28x28 のサイズ
- 学習用に 6万件、テスト用に 1万件が用意されている
- 画像データとラベルデータのファイルが対になっている
下記のように 4種類のファイルが用意されています。
画像データファイル | ラベルデータファイル | |
---|---|---|
学習用 | train-images.idx3-ubyte | train-labels.idx1-ubyte |
テスト用 | t10k-images.idx3-ubyte | t10k-labels.idx1-ubyte |
ファイルフォーマットは以下の通りです。
画像データのファイルフォーマット
オフセット | タイプ | 値 | 内容 |
---|---|---|---|
0 | 32bit integer | 2051 | マジックナンバー |
4 | 32bit integer | 60000 or 10000 | 画像の数 |
8 | 32bit integer | 28 | 行の数 |
12 | 32bit integer | 28 | 列の数 |
16~ | unsigned byte | 0~255 ※ | 1バイトずつピクセル値が連続 |
先頭 16バイトがヘッダー部分、それ以降に 1画像 28 x 28 = 784 バイトのデータが 6万件もしくは 1万件続きます。
※ 0 が白、255 が黒
ラベルデータのファイルフォーマット
オフセット | タイプ | 値 | 内容 |
---|---|---|---|
0 | 32bit integer | 2049 | マジックナンバー |
4 | 32bit integer | 60000 or 10000 | ラベルの数 |
8~ | unsigned byte | 0~9 | 1バイトずつラベル値が連続 |
先頭 8バイトがヘッダー部分、それ以降にラベルデータが 6万件もしくは 1万件続きます。
(a) ConvNetJS 用に変換 (Node.js)
それでは、Node.js を使ってパースしてみます。
今回は ConvNetJS で使えるように、画像データを convnetjs.Vol
として作成するようにしました。
Buffer
を使ってバイナリデータを処理します。
MNIST のデータはビッグエンディアンのようなので 32bit integer を読み込む際は readInt32BE
を使います。
なお、マジックナンバーに関してはチェックしていません。
load_mnist.js
var Promise = require('bluebird'); var convnetjs = require('convnetjs'); var fs = require('fs'); var readFile = Promise.promisify(fs.readFile); // ファイル内容から Buffer 作成 var readToBuffer = file => readFile(file).then(r => new Buffer(r, 'binary')); // 画像データのロード var loadImages = file => readToBuffer(file) .then(buf => { var magicNum = buf.readInt32BE(0); var num = buf.readInt32BE(4); var rowNum = buf.readInt32BE(8); var colNum = buf.readInt32BE(12); // 画像データ部分を分離(ヘッダー部分を除外) var dataBuf = buf.slice(16); var res = Array(num); var offset = 0; for (var i = 0; i < num; i++) { var data = new convnetjs.Vol(colNum, rowNum, 1, 0); for (var y = 0; y < rowNum; y++) { for (var x = 0; x < colNum; x++) { var value = dataBuf.readUInt8(offset++); data.set(x, y, 0, value); } } res[i] = data; } return res; }); // ラベルデータのロード var loadLabels = file => readToBuffer(file) .then(buf => { var magicNum = buf.readInt32BE(0); var num = buf.readInt32BE(4); // ラベルデータ部分を分離(ヘッダー部分を除外) var dataBuf = buf.slice(8); var res = Array(num); for (var i = 0; i < num; i++) { res[i] = dataBuf.readUInt8(i); } return res; }); // 画像・ラベルデータをロード module.exports.loadMnist = (imgFile, labelFile) => Promise.all([ loadImages(imgFile), loadLabels(labelFile) ]).spread( (r1, r2) => r2.map((label, i) => { return { values: r1[i], label: label }; }) );
bluebird
と convnetjs
パッケージを使っています。
package.json
{ "name": "mnist_parse_sample", "version": "1.0.0", "description": "", "dependencies": { "bluebird": "^3.3.3", "convnetjs": "^0.3.0" } }
動作確認
以下のテストコードを使って簡単な動作確認を行います。
手書き数字の大まかな形状を確認できるように、ピクセル値が 0 より大きければ #
へ、それ以外は へ変換し出力してみました。
test_load_mnist.js
var mnist = require('./load_mnist'); var printData = d => { console.log(`***** number = ${d.label} *****`); var v = d.values; for (var y = 0; y < v.sy; y++) { var r = Array(v.sx); for (var x = 0; x < v.sx; x++) { // ピクセル値が 0より大きいと '#'、それ以外は ' ' r[x] = v.get(x, y, 0) > 0 ? '#' : ' '; } // 文字で表現した画像を出力 console.log(r.join('')); } // ピクセル値を出力 console.log(d.values.w.join(',')); }; mnist.loadMnist(process.argv[2], process.argv[3]) .then(ds => { console.log(`size: ${ds.length}`); printData(ds[0]); console.log('----------'); printData(ds[1]); });
MNIST 学習用データを使った実行結果は以下の通りです。
実行結果
> node test_load_mnist.js train-images.idx3-ubyte train-labels.idx1-ubyte size: 60000 ***** number = 5 ***** ############ ################ ################ ########### ####### ## ##### #### #### ###### ###### ###### ##### #### ####### ######## ######### ########## ########## ########## ######## 0,・・・,3,18,18,18,126,136,175,26,166,255,・・・,0 ---------- ***** number = 0 ***** ##### ###### ######### ########### ########### ############ ######### ### ###### ### ####### ### #### ### ### ### #### ### #### ##### ### ##### ### #### ### #### ############# ########### ######### ####### 0,・・・,51,159,253,159,50,・・・,0
(b) Deeplearning4J 用に変換 (Java)
次は Deeplearning4J で使えるように Java で org.nd4j.linalg.dataset.DataSet
へ変換します。
画像データとラベルデータをそれぞれ org.nd4j.linalg.api.ndarray.INDArray
として作成し、最後に DataSet へまとめます。
1件の画像データは 784 (= 28 x 28) 要素のフラットな INDArray として作成します ※。
※ Deeplearning4J では、ConvolutionLayerSetup を使って 画像の高さ・幅を指定できるため、フラットなデータにして問題ありません
ラベルの値は FeatureUtil.toOutcomeVector()
メソッドで作成します。
このメソッドによって、ラベルの種類と同じ数 (今回は 10) の要素を持ち、該当するインデックスの値だけが 1 で他は 0 になった配列 (INDArray) を得られます。
このように、変換後のデータセットの構造は (a) のケースとは異なります。
また、Java には unsigned byte
という型がなく、ピクセル値を byte として読み込むと -128 ~ 127 になってしまいます。そこで & 0xff
して正しい値 (0 ~ 255) となるように変換しています。
ちなみに、Deeplearning4J の org.deeplearning4j.datasets.DataSets.mnist()
メソッドを使えば MNIST データセットを取得できますが、その場合はピクセル値が正規化 ※ されて 0 か 1 の値となります。 (シャッフルも実施される)
※ ピクセル値 > 30 の場合は 1、それ以外は 0 になる
src/main/java/MnistLoader.java
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.util.FeatureUtil; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; import java.nio.file.Paths; import java.util.concurrent.CompletableFuture; public class MnistLoader { private final static int LABELS_NUM = 10; public static CompletableFuture<DataSet> loadMnist(String imageFileName, String labelFileName) { return CompletableFuture.supplyAsync(() -> loadImages(imageFileName)) .thenCombineAsync( CompletableFuture.supplyAsync(() -> loadLabels(labelFileName)), DataSet::new ); } private static INDArray loadImages(String fileName) { try (FileChannel fc = FileChannel.open(Paths.get(fileName))) { ByteBuffer headerBuf = ByteBuffer.allocateDirect(16); // ヘッダー部分の読み込み fc.read(headerBuf); headerBuf.rewind(); int magicNum = headerBuf.getInt(); int num = headerBuf.getInt(); int rowNum = headerBuf.getInt(); int colNum = headerBuf.getInt(); ByteBuffer buf = ByteBuffer.allocateDirect(num * rowNum * colNum); // 画像データ部分の読み込み fc.read(buf); buf.rewind(); int dataSize = rowNum * colNum; INDArray res = Nd4j.create(num, dataSize); for (int n = 0; n < num; n++) { INDArray d = Nd4j.create(1, dataSize); for (int i = 0; i < dataSize; i++) { // & 0xff する事で unsigned の値へ変換 d.putScalar(i, buf.get() & 0xff); } res.putRow(n, d); } return res; } catch(IOException ex) { throw new RuntimeException(ex); } } private static INDArray loadLabels(String fileName) { try (FileChannel fc = FileChannel.open(Paths.get(fileName))) { ByteBuffer headerBuf = ByteBuffer.allocateDirect(8); // ヘッダー部分の読み込み fc.read(headerBuf); headerBuf.rewind(); int magicNum = headerBuf.getInt(); int num = headerBuf.getInt(); ByteBuffer buf = ByteBuffer.allocateDirect(num); // ラベルデータ部分の読み込み fc.read(buf); buf.rewind(); INDArray res = Nd4j.create(num, LABELS_NUM); for (int i = 0; i < num; i++) { res.putRow(i, FeatureUtil.toOutcomeVector(buf.get(), LABELS_NUM)); } return res; } catch(IOException ex) { throw new RuntimeException(ex); } } }
動作確認
Gradle で以下のテストコードを実行し簡単な動作確認を行います。
src/main/java/TestMnistLoader.java
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import java.util.stream.IntStream; public class TestMnistLoader { public static void main(String... args) { MnistLoader.loadMnist(args[0], args[1]) .thenAccept(ds -> { System.out.println("size: " + ds.numExamples()); printData(ds.get(0)); System.out.println("----------"); printData(ds.get(1)); }) .join(); } private static void printData(DataSet d) { System.out.println("***** labels = " + d.getLabels()); INDArray v = d.getFeatures(); IntStream.range(0, 28).forEach( y -> { IntStream.range(0, 28).forEach ( x -> { System.out.print( v.getInt(x + y * 28) > 0 ? "#" : " " ); }); System.out.println(); }); System.out.println(d.getFeatures()); } }
build.gradle
apply plugin: 'application' tasks.withType(AbstractCompile)*.options*.encoding = 'UTF-8' mainClassName = 'TestMnistLoader' repositories { jcenter() } dependencies { compile 'org.nd4j:nd4j-x86:0.4-rc3.8' runtime 'org.slf4j:slf4j-nop:1.7.18' } run { if (project.hasProperty('args')) { args project.args.split(' ') } }
MNIST 学習用データを使った実行結果は以下の通りです。
実行結果
> gradle run -Pargs="train-images.idx3-ubyte train-labels.idx1-ubyte" ・・・ :run 3 06, 2016 10:22:09 午後 com.github.fommil.netlib.BLAS <clinit> 警告: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS 3 06, 2016 10:22:10 午後 com.github.fommil.jni.JniLoader liberalLoad 情報: successfully loaded ・・・\Temp\jniloader211398199777152626netlib-native_ref-win-x86_64.dll **************************************************************** WARNING: COULD NOT LOAD NATIVE SYSTEM BLAS ND4J performance WILL be reduced Please install native BLAS library such as OpenBLAS or IntelMKL See http://nd4j.org/getstarted.html#open for further details **************************************************************** size: 60000 ***** labels = [ 0.00, 0.00, 0.00, 0.00, 0.00, 1.00, 0.00, 0.00, 0.00, 0.00] ############ ################ ################ ########### ####### ## ##### #### #### ###### ###### ###### ##### #### ####### ######## ######### ########## ########## ########## ######## [ 0.00, ・・・, 3.00, 18.00, 18.00, 18.00, 126.00, 136.00, 175.00, 26.00, 166.00, 255.00, ・・・, 0.00] ---------- ***** labels = [ 1.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00] ##### ###### ######### ########### ########### ############ ######### ### ###### ### ####### ### #### ### ### ### #### ### #### ##### ### ##### ### #### ### #### ############# ########### ######### ####### [ 0.00, ・・・, 51.00, 159.00, 253.00, 159.00, 50.00, ・・・, 0.00]
今回は DataSet を作っているだけなので WARNING を気にする必要は特にありませんが、WARNING を消して正しい状態で実行するには OpenBlas 等の BLAS ライブラリをインストールしてから実行します。
Windows 環境であれば、http://nd4j.org/getstarted.html#open のリンクから ND4J_Win64_OpenBLAS-v0.2.14.zip をダウンロード・解凍し、環境変数 PATH へ設定してから実行するのが簡単だと思います。
また、JniLoader が TEMP ディレクトリへ netlib-native_system-win-x86_64.dll
をダウンロードするのを防止するには、この dll も環境変数 PATH へ設定した場所へ配置しておきます。 (dll は Maven の Central Repository からダウンロードできます ※)
※ ダウンロードしたファイル名では都合が悪いようなので、 バージョン番号を除いた netlib-native_system-win-x86_64.dll という ファイル名へ変更します
OpenBlas 設定後の実行結果
> set PATH=C:\ND4J_Win64_OpenBLAS-v0.2.14;%PATH% > gradle run -Pargs="train-images.idx3-ubyte train-labels.idx1-ubyte" ・・・ :run 3 06, 2016 11:31:22 午後 com.github.fommil.jni.JniLoader liberalLoad 情報: successfully loaded ・・・\netlib-native_system-win-x86_64.dll size: 60000 ***** labels = [ 0.00, 0.00, 0.00, 0.00, 0.00, 1.00, 0.00, 0.00, 0.00, 0.00] ・・・