Deeplearning4J で MNIST を分類

Deeplearning4J で iris を分類」 に続いて、畳み込みニューラルネットを使った MNIST の分類を試します。

Deeplearning4J のバージョンが上がって、@Grab を使った Groovy 上での実行が上手くいかなかったので、今回は Kotlin で実装し Gradle で実行します。

  • Gradle 3.1
  • Kotlin 1.0.4

ソースは http://github.com/fits/try_samples/tree/master/blog/20161010/

はじめに

今回は、Gradle のマルチプロジェクトを使って以下の処理をサブプロジェクトとしました。

Gradle のビルド定義は以下の通り。

build.gradle
def slf4jVer = '1.7.21'

buildscript {
    repositories {
        jcenter()
    }

    dependencies {
        classpath 'org.jetbrains.kotlin:kotlin-gradle-plugin:1.0.4'
    }
}
// サブプロジェクトの共通設定
subprojects {
    apply plugin: 'kotlin'
    apply plugin: 'application'

    repositories {
        jcenter()
    }

    dependencies {
        compile 'org.jetbrains.kotlin:kotlin-stdlib:1.0.4'

        compile('org.deeplearning4j:deeplearning4j-core:0.6.0') {
            // エラーの回避策 (Could not find javacpp-presets-${os.name}-${os.arch}.jar)
            exclude group: 'org.bytedeco', module: 'javacpp-presets'
        }

        runtime 'org.nd4j:nd4j-native-platform:0.6.0'
    }

    run {
        // 実行時引数
        if (project.hasProperty('args')) {
            args project.args.split(' ')
        }
    }
}
// (1) 畳み込みニューラルネットのモデル作成(JSON で出力)
project(':conv_model') {
    mainClassName = 'ConvModelKt'

    dependencies {
        runtime "org.slf4j:slf4j-nop:${slf4jVer}"
    }
}
// (2) 学習処理
project(':learn_mnist') {
    mainClassName = 'LearnMnistKt'

    dependencies {
        runtime "org.slf4j:slf4j-simple:${slf4jVer}"
    }
}
// (3) 評価処理
project(':eval_mnist') {
    mainClassName = 'EvalMnistKt'

    dependencies {
        runtime "org.slf4j:slf4j-nop:${slf4jVer}"
    }
}
settings.gradle
include 'conv_model', 'learn_mnist', 'eval_mnist'

ファイル構成は以下の通りです。

ファイル構成
  • build.gradle
  • settings.gradle
  • conv_model/src/main/kotlin/convModel.kt
  • learn_mnist/src/main/kotlin/learnMnist.kt
  • eval_mnist/src/main/kotlin/evalMnist.kt

(1) 畳み込みニューラルネットのモデル作成(JSON で出力)

畳み込みニューラルネットの構成情報を JSON 化して標準出力します。

MnistDataSetIterator の MNIST データセットに対して畳み込みニューラルネットを行うには InputType.convolutionalFlat(28, 28, 1)setInputType します。

conv_model/src/main/kotlin/convModel.kt
import org.deeplearning4j.nn.conf.NeuralNetConfiguration
import org.deeplearning4j.nn.conf.inputs.InputType
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer
import org.deeplearning4j.nn.conf.layers.OutputLayer
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer
import org.nd4j.linalg.lossfunctions.LossFunctions

fun main(args: Array<String>) {
    val builder = NeuralNetConfiguration.Builder()
        .iterations(3) // 3回繰り返し
        .list(
            // 畳み込み層
            ConvolutionLayer.Builder(5, 5)
                .nIn(1)
                .nOut(8)
                .padding(2, 2)
                .activation("relu")
                .build()
            ,
            // プーリング層(最大プーリング)
            SubsamplingLayer.Builder(
                SubsamplingLayer.PoolingType.MAX, intArrayOf(2, 2))
                .stride(2, 2)
                .build()
            ,
            // 畳み込み層
            ConvolutionLayer.Builder(5, 5)
                .nOut(16)
                .padding(1, 1)
                .activation("relu")
                .build()
            ,
            // プーリング層(最大プーリング)
            SubsamplingLayer.Builder(
                SubsamplingLayer.PoolingType.MAX, intArrayOf(3, 3))
                .stride(3, 3)
                .build()
            ,
            OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .nOut(10)
                .activation("softmax")
                .build()
        )
        // MNIST データセットに対する畳み込みニューラルネット用の入力タイプ設定
        .setInputType(InputType.convolutionalFlat(28, 28, 1))

    // JSON 化して出力
    println(builder.build().toJson())
}

(2) 学習処理

(1) で出力した JSON ファイルを使って学習を実施し、学習後のパラメータ(重み)を JSON ファイルへ出力します。

MNIST データセットには MnistDataSetIterator を使いました ※。

 ※ train 引数(今回使用したコンストラクタの第2引数)が true の場合に学習用、
    false の場合に評価用のデータセットとなります
learn_mnist/src/main/kotlin/learnMnist.kt
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.deeplearning4j.nn.conf.MultiLayerConfiguration
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.optimize.listeners.ScoreIterationListener

import org.nd4j.linalg.api.ndarray.NdArrayJSONWriter

import java.io.File

fun main(args: Array<String>) {
    // ニューラルネットのモデルファイル(JSON)
    val confJson = File(args[0]).readText()
    // パラメータ(重み)の出力ファイル名
    val destFile = args[1]

    val conf = MultiLayerConfiguration.fromJson(confJson)
    val network = MultiLayerNetwork(conf)

    network.init()
    // スコア(誤差)の出力
    network.setListeners(ScoreIterationListener())

    // MNIST 学習用データ(バッチサイズ 100)
    val trainData = MnistDataSetIterator(100, true, 0)
    // 学習
    network.fit(trainData)
    // 学習後のパラメータ(重み)を JSON ファイルへ出力
    NdArrayJSONWriter.write(network.params(), destFile)
}

(3) 評価処理

(1) と (2) の結果を使って評価します。

eval_mnist/src/main/kotlin/evalMnist.kt
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.deeplearning4j.nn.conf.MultiLayerConfiguration
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork

import org.nd4j.linalg.api.ndarray.NdArrayJSONReader

import java.io.File

fun main(args: Array<String>) {
    // ニューラルネットのモデルファイル(JSON)
    val confJson = File(args[0]).readText()
    // パラメータ(重み)のロード
    val params = NdArrayJSONReader().read(File(args[1]))

    val conf = MultiLayerConfiguration.fromJson(confJson)
    val network = MultiLayerNetwork(conf, params)
    // MNIST 評価用データ
    val testData = MnistDataSetIterator(1, false, 0)
    // 評価
    val res = network.evaluate(testData)

    println(res.stats())
}

実行

まず、畳み込みニューラルネットの構成を JSON ファイルへ出力します。

(1) 畳み込みニューラルネットのモデル作成
> gradle -q :conv_model:run > conv_model.json

次に、学習を実施します。

org.reflections.Reflections - could not create Vfs.Dir from url という警告ログが出力されましたが、特に支障は無さそうです。

(2) 学習
> gradle -q :learn_mnist:run -Pargs="../conv_model.json ../conv_params.json"

・・・
[main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 1790 is 0.07888245582580566
[main] INFO org.deeplearning4j.nn.multilayer.MultiLayerNetwork - Finetune phase
[main] INFO org.deeplearning4j.nn.multilayer.MultiLayerNetwork - Finetune phase
[main] INFO org.deeplearning4j.nn.multilayer.MultiLayerNetwork - Finetune phase

最後に、評価を実施します。

(3) 評価
> gradle -q :eval_mnist:run -Pargs="../conv_model.json ../conv_params.json"

Examples labeled as 0 classified by model as 0: 944 times
Examples labeled as 0 classified by model as 2: 7 times
Examples labeled as 0 classified by model as 5: 4 times
・・・
Examples labeled as 9 classified by model as 7: 24 times
Examples labeled as 9 classified by model as 8: 22 times
Examples labeled as 9 classified by model as 9: 937 times


==========================Scores========================================
 Accuracy:  0.9432
 Precision: 0.9463
 Recall:    0.9427
 F1 Score:  0.9445
========================================================================