node-ffi で OpenCL を使う

Windows 環境で node-ffi (Node.js Foreign Function Interface) を使って OpenCLAPI を呼び出してみました。

サンプルソースhttp://github.com/fits/try_samples/tree/master/blog/20160627/

なお、OpenCL 上での演算は今回扱いませんが、単純な演算のサンプルは ここ に置いてます。

はじめに

node-ffi のインストール

まずは、node-gyp をインストールしておきます。 node-gyp を Windows 環境で使うには VC++Python 2.7 が必要です。

node-gyp インストール例
> npm install -g node-gyp

node-ffi をインストールします。(モジュール名は node-ffi ではなく ffi です)

node-ffi インストール例
> npm install ffi

node-ffi の使い方

node-ffi では Library 関数を使ってネイティブライブラリの関数をマッピングします。

ffi.Library(<ライブラリ名>, {
    <関数名>: [<戻り値の型>, [<第1引数の型>, <第2引数の型>, ・・・]],
    ・・・
})

引数の型などはライブラリのヘッダーファイルなどを参考にして設定します。

例えば、OpenCL.dll (Windows 環境の場合) の clGetPlatformIDs 関数を Node.js から openCl.clGetPlatformIDs(・・・) で呼び出すには以下のようにします。

Library の使用例
const openCl = ffi.Library('OpenCL', {
    'clGetPlatformIDs': ['int', ['uint', sizeTPtr, uintPtr]],
    ・・・
});

ref モジュールの refType でポインタ用の型を定義する事が可能です。

refType の使用例
const uintPtr = ref.refType(ref.types.uint32);
const sizeTPtr = ref.refType('size_t');

OpenCL の利用

それでは、下記 OpenCL ランタイムをインストールした Windows 環境で、OpenCLAPI を 3つほど呼び出してみます。

1. OpenCL のデバイスID取得

まずは、以下を実施してみます。

  • (1) clGetPlatformIDs を使ってプラットフォームIDを取得
  • (2) clGetDeviceIDs を使ってデバイスIDを取得

OpenCL (v1.2) のヘッダーファイルを見てみると、プラットフォームIDの型 cl_platform_id やデバイスIDの型 cl_device_id はこれ自体がポインタのようなので ※、これらに該当する型は size_t としました。

※ そのため、プラットフォームID や デバイスID という表現は
   適切ではないかもしれません

node-ffi ではポインタを扱うために Buffer を使います。

そのための補助関数が ref モジュールに用意されており、下記サンプルでは以下を使っています。

  • ref モジュールの alloc を使って指定した型に応じた Buffer を作成
  • 定義した型の get を使って Buffer から値を取得

get を使えば、型のサイズやエンディアンに応じた値を Buffer から取り出してくれます。 (例えば、int32 なら Buffer の readInt32LE や readInt32BE を使って値を取得する)

なお、エラーの有無は clGetPlatformIDs・clGetDeviceIDs の戻り値が 0 かどうかで判定します。(0: 成功、0以外: エラー)

get_device_id.js
'use strict';

const ffi = require('ffi');
const ref = require('ref');

// 定数の定義
const CL_DEVICE_TYPE_DEFAULT = 1;

// ポインタ用の型定義
const uintPtr = ref.refType(ref.types.uint32);
const sizeTPtr = ref.refType('size_t');

// OpenCL の関数定義
const openCl = ffi.Library('OpenCL', {
    'clGetPlatformIDs': ['int', ['uint', sizeTPtr, uintPtr]],
    'clGetDeviceIDs': ['int', ['size_t', 'ulong', 'uint', sizeTPtr, uintPtr]]
});

// エラーチェック処理
const checkError = (errCode, title = '') => {
    if (errCode != 0) {
        throw new Error(`${title} Error: ${errCode}`);
    }
};

const platformIdsPtr = ref.alloc(sizeTPtr);

// (1) プラットフォームIDを(1つ)取得
let res = openCl.clGetPlatformIDs(1, platformIdsPtr, null);

checkError(res, 'clGetPlatformIDs');

// プラットフォームID(get を使って platformIdsPtr の先頭の値を取得)
const platformId = sizeTPtr.get(platformIdsPtr);

console.log(`platformId: ${platformId}`);

const deviceIdsPtr = ref.alloc(sizeTPtr);

// (2) デバイスIDを(1つ)取得
res = openCl.clGetDeviceIDs(platformId, CL_DEVICE_TYPE_DEFAULT, 1, deviceIdsPtr, null);

checkError(res, 'clGetDeviceIDs');

// デバイスID(get を使って deviceIdsPtr の先頭の値を取得)
const deviceId = sizeTPtr.get(deviceIdsPtr);

console.log(`deviceId: ${deviceId}`);
実行結果
> node get_device_id.js

platformId: 47812336
deviceId: 4404320

2. OpenCL のプラットフォーム情報取得

次は OpenCL のプラットフォーム情報を取得してみます。 プラットフォーム情報は clGetPlatformInfo を使って取得します。

  • (1) clGetPlatformInfo でデータサイズを取得
  • (2) バッファを確保
  • (3) clGetPlatformInfo でデータを取得
platform_info.js
'use strict';

const ffi = require('ffi');
const ref = require('ref');

// 定数の定義
const CL_PLATFORM_PROFILE = 0x0900;
const CL_PLATFORM_VERSION = 0x0901;
const CL_PLATFORM_NAME = 0x0902;
const CL_PLATFORM_VENDOR = 0x0903;
const CL_PLATFORM_EXTENSIONS = 0x0904;
const CL_PLATFORM_HOST_TIMER_RESOLUTION = 0x0905;

const uintPtr = ref.refType(ref.types.uint32);
const sizeTPtr = ref.refType('size_t');

const openCl = ffi.Library('OpenCL', {
    'clGetPlatformIDs': ['int', ['uint', sizeTPtr, uintPtr]],
    'clGetPlatformInfo': ['int', ['size_t', 'uint', 'size_t', 'pointer', sizeTPtr]]
});

const checkError = (errCode, title = '') => {
    if (errCode != 0) {
        throw new Error(`${title} Error: ${errCode}`);
    }
};

// プラットフォーム情報の出力
const printPlatformInfo = (pid, paramName) => {
    const sPtr = ref.alloc(sizeTPtr);

    // (1) データサイズを取得
    let res = openCl.clGetPlatformInfo(pid, paramName, 0, null, sPtr);

    checkError(res, 'clGetPlatformInfo size');

    // データサイズの値を取り出す
    const size = sizeTPtr.get(sPtr);

    // (2) バッファを確保
    const buf = Buffer.alloc(size);

    // (3) データを取得
    res = openCl.clGetPlatformInfo(pid, paramName, size, buf, null);

    checkError(res, 'clGetPlatformInfo data');

    // 出力
    console.log(buf.toString());
};

const platformIdsPtr = ref.alloc(sizeTPtr);

const res = openCl.clGetPlatformIDs(1, platformIdsPtr, null);

checkError(res, 'clGetPlatformIDs');

const platformId = sizeTPtr.get(platformIdsPtr);

[
    CL_PLATFORM_PROFILE,
    CL_PLATFORM_VERSION,
    CL_PLATFORM_NAME
].forEach( p => 
    printPlatformInfo(platformId, p)
);
実行結果
> node platform_info.js

FULL_PROFILE 
OpenCL 1.2  
Intel(R) OpenCL 

Keras で iris を分類

Theano・TensorFlow 用のディープラーニングライブラリ Keras を使って、階層型ニューラルネットによる iris の分類を試してみました。

ソースは http://github.com/fits/try_samples/tree/master/blog/20160531/

準備

今回は Docker コンテナで Keras を実行するため、Docker の公式イメージ python をベースに Keras をインストールした Docker イメージを用意します。

Docker イメージ作成

/vagrant/work/keras/Dockerfile
FROM python

RUN apt-get update && apt-get upgrade -y

RUN pip install --upgrade pip

RUN pip install keras
RUN pip install scikit-learn

RUN apt-get clean

sklearn の iris データセットを使うために scikit-learn もインストールしていますが、Keras を使うだけなら不要です。

また、Theano はデフォルトでインストールされるようですが、TensorFlow を使う場合は別途インストールする必要がありそうです。(今回は Theano を使います)

Dockerfile に対して docker build を実行して Docker イメージを作成します。

Docker ビルド
$ cd /vagrant/work/keras
$ docker build --no-cache -t sample/py-keras:0.1 .

(1) 学習

まずは、iris のデータセットを全て使って学習してみます。

ConvnetJS で iris を分類」 で実施したように、出力層の活性化関数はソフトマックス、損失関数(誤差関数)に交差エントロピーを使います。

sklearn の iris データセットのように、ラベルデータ (target の値) が数値 (0 ~ 2) の場合 ※ に交差エントロピーを実施するには compile の引数で loss = 'sparse_categorical_crossentropy' と指定すれば良さそうです。

※ iris.target の内容

    [0 0 0 0 0 0 0 ・・・
     0 0 0 0 0 0 0 ・・・
     1 1 1 1 1 1 1 ・・・
     2 2 2 2 2 2 2 ・・・
     2 2]

とりあえず、入力層 - 隠れ層 - 出力層 (隠れ層のニューロン数 8)というレイヤー構成を使い、学習処理(fit)はミニバッチサイズを 1 として 50 回繰り返すように指定しました。

/vagrant/work/iris_sample1.py
from keras.models import Sequential
from keras.layers.core import Dense, Activation
from sklearn import datasets

# モデルの定義
model = Sequential()

# 隠れ層の定義
model.add(Dense(input_dim = 4, output_dim = 8))
# 隠れ層の活性化関数
model.add(Activation('relu'))

# 出力層の定義
model.add(Dense(output_dim = 3))
# 出力層の活性化関数
model.add(Activation('softmax'))

model.compile(loss = 'sparse_categorical_crossentropy', optimizer = 'sgd', metrics = ['accuracy'])

iris = datasets.load_iris()

# 学習
model.fit(iris.data, iris.target, nb_epoch = 50, batch_size = 1)

docker run で Keras 用の Docker コンテナを起動した後、コンテナ内で上記を実行します。

実行例
$ docker run --rm -it -v /vagrant/work:/work sample/py-keras:0.1 bash

# cd /work
# python iris_sample1.py

Using Theano backend.
Epoch 1/50
  1/150 [..............................] - ETA: 0s - loss: 3.4213 - acc: 0.0000e  2/150 [..............................] - ETA: 0s - loss: 2.2539 - acc: 0.0000e・・・
Epoch 49/50
150/150 [==============================] - 0s - loss: 0.1225 - acc: 0.9533
Epoch 50/50
150/150 [==============================] - 0s - loss: 0.1525 - acc: 0.9333

誤差(loss)と正解率(acc)が出力されました。

(2) 学習と評価

次は、iris データセットを学習用と評価用に分割して学習と評価をそれぞれ実行してみます。

データセットを直接シャッフルする代わりに、0 ~ 149 の数値をランダムに配置した配列を numpy の random.permutation で作成し、学習・評価用のデータ分割に利用しました。

/vagrant/work/iris_sample2.py
import sys

from keras.models import Sequential
from keras.layers.core import Dense, Activation
from sklearn import datasets
import numpy as np

# 学習・評価用のデータ分割率
trainEvalRate = 0.7

# 学習の繰り返し回数
epoch = int(sys.argv[1])

# 隠れ層のニューロン数
neuNum = int(sys.argv[2])
# 隠れ層の活性化関数
act = sys.argv[3]

optm = sys.argv[4]

model = Sequential()

model.add(Dense(input_dim = 4, output_dim = neuNum))
model.add(Activation(act))

model.add(Dense(output_dim = 3))
model.add(Activation('softmax'))

model.compile(
    loss = 'sparse_categorical_crossentropy', 
    optimizer = optm, 
    metrics = ['accuracy']
)

iris = datasets.load_iris()

data_size = len(iris.data)
train_size = int(data_size * trainEvalRate)

perm = np.random.permutation(data_size)

# 学習用データ
x_train = iris.data[ perm[0:train_size] ]
y_train = iris.target[ perm[0:train_size] ]

# 学習
model.fit(x_train, y_train, nb_epoch = epoch, batch_size = 1)

print('-----')

# 評価用データ
x_test = iris.data[ perm[train_size:] ]
y_test = iris.target[ perm[train_size:] ]

# 評価
res = model.evaluate(x_test, y_test, batch_size = 1)

print(res)
実行例
# python iris_sample2.py 50 6 sigmoid adam

Using Theano backend.
Epoch 1/50
105/105 [==============================] - 0s - loss: 1.0751 - acc: 0.3524
Epoch 2/50
105/105 [==============================] - 0s - loss: 1.0417 - acc: 0.3524
・・・
Epoch 49/50
105/105 [==============================] - 0s - loss: 0.3503 - acc: 0.9714
Epoch 50/50
105/105 [==============================] - 0s - loss: 0.3458 - acc: 0.9714
-----
45/45 [==============================] - 0s
[0.35189295262098313, 0.97777777777777775]

JMX で Java Flight Recorder (JFR) を実行する

Java Flight Recorder (JFR) は Java Mission Control (jmc) や jcmd コマンドから実行できますが、今回は以下の MBean を使って JMX から実行してみます。

  • com.sun.management:type=DiagnosticCommand

この MBean は以下のような操作を備えており(戻り値は全て String)、jcmd コマンドと同じ事ができるようです。

  • jfrCheck
  • jfrDump
  • jfrStop
  • jfrStart
  • vmCheckCommercialFeatures
  • vmCommandLine
  • vmFlags
  • vmSystemProperties
  • vmUnlockCommercialFeatures
  • vmUptime
  • vmVersion
  • vmNativeMemory
  • gcRotateLog
  • gcRun
  • gcRunFinalization
  • gcClassHistogram
  • gcClassStats
  • threadPrint

(a) JFR の実行

JMX を使う方法はいくつかありますが、今回は Attach API でローカルの VM へアタッチし、startLocalManagementAgent メソッドJMX エージェントを適用する方法を用いました。

DiagnosticCommand には java.lang.management.ThreadMXBean のようなラッパーが用意されていないようなので GroovyMBean を使う事にします。

jfrStart の引数は jcmd コマンドと同じものを String 配列にして渡すだけのようです。(jfrStart 以外も基本的に同じ)

また、JFR の実行には Commercial Features のアンロックが必要です。

jfr_run.groovy
import com.sun.tools.attach.VirtualMachine

import javax.management.remote.JMXConnectorFactory
import javax.management.remote.JMXServiceURL

def pid = args[0]
def duration = args[1]
def fileName = args[2]

// 指定の JVM プロセスへアタッチ
def vm = VirtualMachine.attach(pid)

try {
    // JMX エージェントを適用
    def jmxuri = vm.startLocalManagementAgent()

    JMXConnectorFactory.connect(new JMXServiceURL(jmxuri)).withCloseable {
        def server = it.getMBeanServerConnection()

        // MBean の取得
        def bean = new GroovyMBean(server, 'com.sun.management:type=DiagnosticCommand')

        // Commercial Features のアンロック (JFR の実行に必要)
        println bean.vmUnlockCommercialFeatures()

        // JFR の開始
        println bean.jfrStart([
            "duration=${duration}",
            "filename=${fileName}",
            'delay=10s'
        ] as String[])
    }
} finally {
    vm.detach()
}

実行例

apache-tomcat-9.0.0.M4 へ適用してみます。

Tomcat 実行
> startup

以下の環境で実行しました。

  • Groovy 2.4.6
  • Java SE 8u92 64bit版
JFR 実行
> jps

4576 Jps
2924 Bootstrap

> groovy jfr_run.groovy 2924 1m sample1.jfr

Commercial Features now unlocked.

Recording 1 scheduled to start in 10 s. The result will be written to:

C:\・・・\apache-tomcat-9.0.0.M4\apache-tomcat-9.0.0.M4\bin\sample1.jfr

jfrStart は JFR の完了を待たずに戻り値を返すため、JFR の実行状況は別途確認する事になります。

出力結果 Recording 1 scheduled1 が recoding の番号で、この番号を使って JFR の状態を確認できます。

ファイル名を相対パスで指定すると対象プロセスのカレントディレクトリへ出力されるようです。 (今回は Tomcat の bin ディレクトリへ出力されました)

(b) JFR の状態確認

JFR の実行状況を確認するには jfrCheck を使います。

下記では recording の番号を指定し、該当する JFR の実行状況を出力しています。

jfrCheck の引数が null の場合は全ての JFR 実行状態を取得するようです。

jfr_check.groovy
import com.sun.tools.attach.VirtualMachine

import javax.management.remote.JMXConnectorFactory
import javax.management.remote.JMXServiceURL

def pid = args[0]
String[] params = (args.length > 1)? ["recording=${args[1]}"]: null

def vm = VirtualMachine.attach(pid)

try {
    def jmxuri = vm.startLocalManagementAgent()

    JMXConnectorFactory.connect(new JMXServiceURL(jmxuri)).withCloseable {
        def server = it.getMBeanServerConnection()

        def bean = new GroovyMBean(server, 'com.sun.management:type=DiagnosticCommand')

        println bean.jfrCheck(params)
    }

} finally {
    vm.detach()
}

実行例

recording 番号(下記では 1)を指定して実行します。

実行例1 (JFR 実行中)
> groovy jfr_check.groovy 2924 1

Recording: recording=1 name="sample1.jfr" duration=1m filename="sample1.jfr" compress=false (running)
実行例2 (JFR 完了後)
> groovy jfr_check.groovy 2924 1

Recording: recording=1 name="sample1.jfr" duration=1m filename="sample1.jfr" compress=false (stopped)

今回作成したサンプルのソースは http://github.com/fits/try_samples/tree/master/blog/20160519/

JDI でオブジェクトの世代(Young・Old)を判別する2

前回 の処理を sun.jvm.hotspot.oops.ObjectHeap を使って高速化してみたいと思います。(世代の判別方法などは前回と同じ)

使用した環境は前回と同じです。

ソースは http://github.com/fits/try_samples/tree/master/blog/20160506/

ObjectHeap で Oop を取得

ObjectReference の代わりに、sun.jvm.hotspot.oops.ObjectHeapiterate(HeapVisitor) メソッドを使えば Oop を取得できます。

今回のような方法では、以下の理由で iterate メソッドの引数へ SAJDIClassLoader がロードした sun.jvm.hotspot.oops.HeapVisitor インターフェースの実装オブジェクトを与える必要があります。

  • JDI の内部で管理している Serviceability Agent APIsun.jvm.hotspot.jdi.SAJDIClassLoader によってロードされている

下記サンプルでは SAJDIClassLoader がロードした HeapVisitor を入手し、asType を使って実装オブジェクトを作成しています。

また、HeapVisitor の doObj で false を返すと処理を継続し、true を返すと中止 ※ するようです。

 ※ 厳密には、
    対象としている Address 範囲の while ループを break するだけで、
    その外側の(liveRegions に対する)for ループは継続するようです
    (ObjectHeap の iterateLiveRegions メソッドのソース参照)

なお、ObjectHeap は sun.jvm.hotspot.jdi.VirtualMachineImpl から saObjectHeap() で取得するか、sun.jvm.hotspot.runtime.VM から取得します。

check_gen2.groovy
import com.sun.jdi.Bootstrap

def pid = args[0]
def prefix = (args.length > 1)? args[1]: ''

def manager = Bootstrap.virtualMachineManager()

def connector = manager.attachingConnectors().find {
    it.name() == 'sun.jvm.hotspot.jdi.SAPIDAttachingConnector'
}

def params = connector.defaultArguments()
params.get('pid').setValue(pid)

def vm = connector.attach(params)

// 世代の判定処理を返す
generation = { heap ->
    def hasYoungGen = heap.metaClass.getMetaMethod('youngGen') != null

    [
        young: hasYoungGen? heap.youngGen(): heap.getGen(0),
        old: hasYoungGen? heap.oldGen(): heap.getGen(1)
    ]
}

try {
    def uv = vm.saVM.universe

    def gen = generation(uv.heap())

    def youngGen = gen.young
    def oldGen = gen.old

    println "*** youngGen=${youngGen}, oldGen=${oldGen}"
    println ''

    def objHeap = vm.saObjectHeap()
    // 以下でも可
    //def objHeap = vm.saVM.objectHeap

    // SAJDIClassLoader がロードした HeapVisitor インターフェースを取得
    def heapVisitorCls = uv.class.classLoader.loadClass('sun.jvm.hotspot.oops.HeapVisitor')

    // SAJDIClassLoader がロードした HeapVisitor インターフェースを実装
    def heapVisitor = [
        prologue: { size -> },
        epilogue: {},
        doObj: { oop ->
            def clsName = oop.klass.name.asString()

            if (clsName.startsWith(prefix)) {
                def age = oop.mark.age()

                // 世代の判別
                def inYoung = youngGen.isIn(oop.handle)
                def inOld = oldGen.isIn(oop.handle)

                def identityHash = ''

                try {
                    identityHash = Long.toHexString(oop.identityHash())
                } catch (e) {
                }

                println "class=${clsName}, hash=${identityHash}, handle=${oop.handle}, age=${age}, inYoung=${inYoung}, inOld=${inOld}"
            }

            // 処理を継続する場合は false を返す
            false
        }
    ].asType(heapVisitorCls)

    objHeap.iterate(heapVisitor)

} finally {
    vm.dispose()
}

動作確認

前回と同じように、実行中の apache-tomcat-9.0.0.M4 へ適用してみました。

前回と異なり、クラス名が '/' で区切られている点に注意

実行例1 (Windows の場合)
> jps

3604 Bootstrap
4516 Jps
> groovy -cp %JAVA_HOME%/lib/sa-jdi.jar check_gen2.groovy 3604 org/apache/catalina/core/StandardContext

*** youngGen=sun.jvm.hotspot.gc_implementation.parallelScavenge.PSYoungGen@0x0000000002149ab0, oldGen=sun.jvm.hotspot.gc_implementation.parallelScavenge.PSOldGen@0x0000000002149b40

class=org/apache/catalina/core/StandardContextValve, hash=0, handle=0x00000000c3a577d0, age=1, inYoung=false, inOld=true
class=org/apache/catalina/core/StandardContext$NoPluggabilityServletContext, hash=0, handle=0x00000000c3a633d8, age=0, inYoung=false, inOld=true
class=org/apache/catalina/core/StandardContext$ContextFilterMaps, hash=0, handle=0x00000000c3a63ef0, age=1, inYoung=false, inOld=true
class=org/apache/catalina/core/StandardContext$NoPluggabilityServletContext, hash=0, handle=0x00000000ebc46da0, age=0, inYoung=true, inOld=false
class=org/apache/catalina/core/StandardContext, hash=6f2d2815, handle=0x00000000eddfeaa0, age=1, inYoung=true, inOld=false
class=org/apache/catalina/core/StandardContext, hash=21f2e66b, handle=0x00000000eddff238, age=3, inYoung=true, inOld=false
・・・
実行例2 (Linux の場合)
$ jps

2778 Jps
2766 Bootstrap
$ groovy -cp $JAVA_HOME/lib/sa-jdi.jar check_gen2.groovy 2766 org/apache/catalina/core/StandardContext

*** youngGen=sun.jvm.hotspot.memory.DefNewGeneration@0x00007f0760019cb0, oldGen=sun.jvm.hotspot.memory.TenuredGeneration@0x00007f076001bfc0

class=org/apache/catalina/core/StandardContext, hash=497fe2c4, handle=0x00000000f821bf90, age=0, inYoung=true, inOld=false
class=org/apache/catalina/core/StandardContext$ContextFilterMaps, hash=0, handle=0x00000000f821c5d8, age=0, inYoung=true, inOld=false
class=org/apache/catalina/core/StandardContextValve, hash=0, handle=0x00000000f821ca60, age=0, inYoung=true, inOld=false
・・・
class=org/apache/catalina/core/StandardContext, hash=5478de1a, handle=0x00000000fb12b310, age=1, inYoung=false, inOld=true
class=org/apache/catalina/core/StandardContext$NoPluggabilityServletContext, hash=0, handle=0x00000000fb12f6b0, age=0, inYoung=false, inOld=true
class=org/apache/catalina/core/StandardContext$ContextFilterMaps, hash=0, handle=0x00000000fb131a80, age=0, inYoung=false, inOld=true
class=org/apache/catalina/core/StandardContextValve, hash=0, handle=0x00000000fb1398b0, age=0, inYoung=false, inOld=true

JDI でオブジェクトの世代(Young・Old)を判別する

前回、オブジェクトの age を取得しましたが、同様の方法で今回はオブジェクトが Young 世代(New 領域)と Old 世代(Old 領域) のどちらに割り当てられているかを判別してみたいと思います。 (ただし、結果の正否は確認できていません)

使用した環境は前回と同じです。

ソースは http://github.com/fits/try_samples/tree/master/blog/20160430/

Young・Old 世代の判別

さて、Young・Old の判別方法ですが。

Serviceability Agent API を見てみると sun.jvm.hotspot.gc_implementation.parallelScavenge パッケージに PSYoungGenPSOldGen というクラスがあり、isIn(Address) メソッドで判定できそうです。

更に PSYoungGen と PSOldGen は sun.jvm.hotspot.gc_implementation.parallelScavenge.ParallelScavengeHeap から取得できます。

Address (sun.jvm.hotspot.debugger パッケージ所属) は sun.jvm.hotspot.oops.OopgetHandle()getMark().getAddress() で取得できるので (下記サンプルでは getHandle を使用)、ParallelScavengeHeap を取得すれば何とかなりそうです。

実際に試してみたところ、ParallelScavengeHeap を取得できたのは Windows 環境で、Linux 環境では GenCollectedHeap を使った別の方法 (getGen メソッドを使う) が必要でした。 (GC の設定等によって更に変わるかもしれません)

世代の判定クラス
実行環境 ヒープクラス ※ Young 世代の判定クラス Old 世代の判定クラス
Windows ParallelScavengeHeap PSYoungGen PSOldGen
Linux GenCollectedHeap DefNewGeneration TenuredGeneration
 ※ Universe の heap() メソッド戻り値の実際の型
    CollectedHeap のサブクラス

上記を踏まえて、前回の処理をベースに以下を追加してみました。

  • (1) sun.jvm.hotspot.jdi.VirtualMachineImpl から sun.jvm.hotspot.runtime.VM オブジェクトを取り出す ※1
  • (2) VM オブジェクトから sun.jvm.hotspot.memory.Universe オブジェクトを取得
  • (3) Universe オブジェクトから CollectedHeap (のサブクラス) を取得 ※2
  • (4) (3) の結果から世代を判定するオブジェクトをそれぞれ取得

(4) で妥当な条件分岐の仕方が分からなかったので、とりあえず youngGen メソッドが無ければ GenCollectedHeap として処理するようにしました。

 ※1 private フィールドの saVM か、package メソッドの saVM() で取得

 ※2 今回のやり方では、Windows は ParallelScavengeHeap、
     Linux は GenCollectedHeap でした

JDI の SAPIDAttachingConnector で attach した結果が VirtualMachineImpl となります。

また、SAPIDAttachingConnector でデバッグ接続した場合 (読み取り専用のデバッグ接続)、デバッグ対象オブジェクトのメソッド (hashCode や toString 等) を呼び出せないようなので、オブジェクトを識別するための情報を得るため identityHash を使ってみました。 (ただし、戻り値が 0 になるものが多数ありました)

check_gen.groovy
import com.sun.jdi.Bootstrap

def pid = args[0]
def prefix = args[1]

def manager = Bootstrap.virtualMachineManager()

def connector = manager.attachingConnectors().find {
    it.name() == 'sun.jvm.hotspot.jdi.SAPIDAttachingConnector'
}

def params = connector.defaultArguments()
params.get('pid').setValue(pid)

def vm = connector.attach(params)

// (4) 世代を判定するためのオブジェクトを取得
generation = { heap ->
    def hasYoungGen = heap.metaClass.getMetaMethod('youngGen') != null

    [
        // Young 世代の判定オブジェクト(PSYoungGen or DefNewGeneration)
        young: hasYoungGen? heap.youngGen(): heap.getGen(0),
        // Old 世代の判定オブジェクト(PSOldGen or TenuredGeneration)
        old: hasYoungGen? heap.oldGen(): heap.getGen(1)
    ]
}

try {
    if (vm.canGetInstanceInfo()) {

        // (1) (2)
        def uv = vm.saVM.universe

        // (3)
        def gen = generation(uv.heap())

        def youngGen = gen.young
        def oldGen = gen.old

        println "*** youngGen=${youngGen}, oldGen=${oldGen}"
        println ''

        vm.allClasses().findAll { it.name().startsWith(prefix) }.each { cls ->
            println cls.name()

            cls.instances(0).each { inst ->
                def oop = inst.ref()
                def age = oop.mark.age()

                // 世代の判別
                def inYoung = youngGen.isIn(oop.handle)
                def inOld = oldGen.isIn(oop.handle)

                def identityHash = ''

                try {
                    identityHash = Long.toHexString(oop.identityHash())
                } catch (e) {
                }

                println "  hash=${identityHash}, handle=${oop.handle}, age=${age}, inYoung=${inYoung}, inOld=${inOld}"
            }
        }
    }
} finally {
    vm.dispose()
}

動作確認

前回と同じように、実行中の apache-tomcat-9.0.0.M4 へ適用してみました。

実行例1 (Windows の場合)
> jps

2836 Bootstrap
5944 Jps
> groovy -cp %JAVA_HOME%/lib/sa-jdi.jar check_gen.groovy 2836 org.apache.catalina.core.StandardContext

*** youngGen=sun.jvm.hotspot.gc_implementation.parallelScavenge.PSYoungGen@0x0000000002049ad0, oldGen=sun.jvm.hotspot.gc_implementation.parallelScavenge.PSOldGen@0x0000000002049b60

org.apache.catalina.core.StandardContext
  hash=66dfd722, handle=0x00000000c394a990, age=0, inYoung=false, inOld=true
  hash=39504d4e, handle=0x00000000edea7cf8, age=3, inYoung=true, inOld=false
  hash=194311fa, handle=0x00000000edea8e90, age=1, inYoung=true, inOld=false
  hash=2b28e016, handle=0x00000000edf0c130, age=2, inYoung=true, inOld=false
  hash=578787b8, handle=0x00000000edf457c0, age=1, inYoung=true, inOld=false
org.apache.catalina.core.StandardContext$ContextFilterMaps
  hash=0, handle=0x00000000c394e7d0, age=0, inYoung=false, inOld=true
  hash=0, handle=0x00000000c396ec90, age=2, inYoung=false, inOld=true
  hash=0, handle=0x00000000c3988eb0, age=1, inYoung=false, inOld=true
  hash=0, handle=0x00000000edf04320, age=1, inYoung=true, inOld=false
  hash=0, handle=0x00000000edf70988, age=1, inYoung=true, inOld=false
・・・
> groovy -cp %JAVA_HOME%/lib/sa-jdi.jar check_gen.groovy 2836 org.apache.catalina.LifecycleEvent

*** youngGen=sun.jvm.hotspot.gc_implementation.parallelScavenge.PSYoungGen@0x0000000002049ad0, oldGen=sun.jvm.hotspot.gc_implementation.parallelScavenge.PSOldGen@0x0000000002049b60

org.apache.catalina.LifecycleEvent
  hash=0, handle=0x00000000c37459c0, age=0, inYoung=false, inOld=true
  hash=0, handle=0x00000000c374ed40, age=1, inYoung=false, inOld=true
  hash=0, handle=0x00000000c39ff950, age=0, inYoung=false, inOld=true
  hash=0, handle=0x00000000ebb8ef90, age=0, inYoung=true, inOld=false
  hash=0, handle=0x00000000ebb90490, age=0, inYoung=true, inOld=false
  hash=0, handle=0x00000000ebb904c0, age=0, inYoung=true, inOld=false
  ・・・
実行例2 (Linux の場合)
$ jps

2801 Jps
2790 Bootstrap
$ groovy -cp $JAVA_HOME/lib/sa-jdi.jar check_gen.groovy 2790 org.apache.catalina.core.StandardContext

*** youngGen=sun.jvm.hotspot.memory.DefNewGeneration@0x00007fca50019cb0, oldGen=sun.jvm.hotspot.memory.TenuredGeneration@0x00007fca5001bfc0

org.apache.catalina.core.StandardContext
  hash=27055bff, handle=0x00000000fb025d38, age=1, inYoung=false, inOld=true
  hash=5638a30f, handle=0x00000000fb1270a8, age=1, inYoung=false, inOld=true
  hash=15fad243, handle=0x00000000fb296730, age=1, inYoung=false, inOld=true
  hash=36c4d4a0, handle=0x00000000fb2f3cf0, age=1, inYoung=false, inOld=true
  hash=33309557, handle=0x00000000fb2f3ef8, age=1, inYoung=false, inOld=true
org.apache.catalina.core.StandardContextValve
  hash=0, handle=0x00000000fb045ad8, age=0, inYoung=false, inOld=true
  hash=0, handle=0x00000000fb135300, age=0, inYoung=false, inOld=true
  hash=0, handle=0x00000000fb2ad568, age=1, inYoung=false, inOld=true
  hash=0, handle=0x00000000fb3022b0, age=1, inYoung=false, inOld=true
  hash=0, handle=0x00000000fb3050e8, age=1, inYoung=false, inOld=true
・・・
$ groovy -cp $JAVA_HOME/lib/sa-jdi.jar check_gen.groovy 2790 org.apache.catalina.LifecycleEvent

*** youngGen=sun.jvm.hotspot.memory.DefNewGeneration@0x00007fca50019cb0, oldGen=sun.jvm.hotspot.memory.TenuredGeneration@0x00007fca5001bfc0

org.apache.catalina.LifecycleEvent
  hash=0, handle=0x00000000f82079a8, age=0, inYoung=true, inOld=false
  hash=0, handle=0x00000000f8207a00, age=0, inYoung=true, inOld=false
  hash=0, handle=0x00000000f8210470, age=0, inYoung=true, inOld=false
  ・・・
  hash=0, handle=0x00000000f8506568, age=0, inYoung=true, inOld=false
  hash=0, handle=0x00000000fb003370, age=0, inYoung=false, inOld=true

この結果の正否はともかく、一応は判別できているように見えます。

ちなみに、前回と同様に処理が遅い(重い)点に関しては、Oop を Serviceability Agent APIsun.jvm.hotspot.oops.ObjectHeap で取得するように変更すれば改善できます。

注意点

今回のように JDI の内部で管理している Serviceability Agent API を取り出して使う場合の注意点は以下の通りです。

  • JDI 内部の Serviceability Agent API のクラス(インターフェースも含む)は sun.jvm.hotspot.jdi.SAJDIClassLoader クラスローダーによってロードされる

同じ名称のクラスでもロードするクラスローダーが異なれば別物となりますので、Java で今回のような処理を実装しようとすると、クラスのキャストができずリフレクション等を多用する事になると思います。

また、Groovy でも HeapVisitor 等を使う場合に多少の工夫が必要になります。

JDI でオブジェクトの年齢(age)を取得

HotSpot VM の世代別 GC において、オブジェクト(インスタンス)には年齢 (age) が設定されており、Minor GC が適用される度にカウントアップされ、長命オブジェクトかどうかの判定に使われるとされています。

そこで今回は、Groovy で JDI (Java Debug Interface)Serviceability Agent API を使って、実行中の Java アプリケーションへアタッチし、オブジェクトの年齢を取得してみたいと思います。

使用した環境は以下の通りです。

ソースは http://github.com/fits/try_samples/tree/master/blog/20160425/

はじめに

今回は、JDI の SA PID コネクタsun.jvm.hotspot.jdi.SAPIDAttachingConnector) を使ってデバッグ接続する事にします。

SAPIDAttachingConnector の特徴は以下のようになります。

SAPIDAttachingConnector の特徴
利点 欠点
デバッグ対象アプリケーションの実行時に -agentlib:jdwp のようなオプション指定が不要 読み取り専用で VM へアタッチするため、実行可能な API が限定される。デバッグ接続中は対象のプロセスが中断する ※
 ※ 全スレッドの処理が中断した状態となり、
    ステップ実行のような状態を変化させる API は使えません。
    (VMCannotBeModifiedException が throw される API は使えない)

デバッグ接続中は全ての処理が完全に中断するので、運用中のサーバーアプリケーションなどへの適用には不向きだと思います。

(a) JDI でクラスの一覧を取得

まずは SAPIDAttachingConnector の動作確認も兼ねて、実行中のアプリケーションへデバッグ接続して、クラスの一覧を出力してみます。

SAPIDAttachingConnector を使ったデバッグ接続手順は以下のようになります。

  • (1) Bootstrap から com.sun.jdi.VirtualMachineManager を取得
  • (2) VirtualMachineManager から利用可能な JDI コネクタの一覧を attachingConnectors メソッドで取得し、SAPIDAttachingConnector を抽出
  • (3) デバッグ対象アプリケーションのプロセスIDをパラメータへ設定しデバッグ接続

SAPIDAttachingConnector を使うには、JDI を使用するアプリケーション(下記スクリプト)の実行時のクラスパスへ lib/sa-jdi.jar を含めておかなければならない点に注意が必要です ※。

そうしないと attachingConnectors の結果に SAPIDAttachingConnector が含まれず、下記スクリプトでは NullPointerException となります。 (find の結果が null になるので)

 ※ JDI を使うには lib/tools.jar が必要ですが、
    Groovy の場合は groovy-starter.conf のデフォルト設定で
    tools.jar をロードするようになっています
class_list.groovy
import com.sun.jdi.Bootstrap

def pid = args[0]

// (1)
def manager = Bootstrap.virtualMachineManager()

// (2) SAPIDAttachingConnector を抽出
def connector = manager.attachingConnectors().find {
    it.name() == 'sun.jvm.hotspot.jdi.SAPIDAttachingConnector'
}

// パラメータの設定
def params = connector.defaultArguments()
params.get('pid').setValue(pid)

// (3) デバッグ接続
def vm = connector.attach(params)

try {
    // クラス情報 com.sun.jdi.ReferenceType の取得
    vm.allClasses().each { cls ->
        println cls.name()
    }
} finally {
    vm.dispose()
}

動作確認

今回は apache-tomcat-9.0.0.M4 を起動しておき、上記スクリプトデバッグ接続してクラスの情報を取得してみます。

実行例1 (Windows の場合)

まずはデバッグ対象の Tomcat を実行しておきます。(設定等はデフォルトのまま)

Tomcat 起動
> startup

jps 等でプロセス ID を調べて上記スクリプトを実行します。 JDK の lib/sa-jdi.jar をクラスパスへ指定する必要があります。

class_list.groovy の実行
> jps

804 Bootstrap
5244 Jps

> groovy -cp %JAVA_HOME%/lib/sa-jdi.jar class_list.groovy 804

sun.management.HotSpotDiagnostic
java.security.CodeSigner
java.security.CodeSigner[]
java.lang.Character
java.lang.Character[]
・・・
short[]
short[][]
long[]
float[]
double[]

実行例2 (Linux の場合)

Linux でも同じです。

Tomcat 起動
$ ./startup.sh
class_list.groovy の実行
$ jps

2388 Jps
2363 Bootstrap

$ groovy -cp $JAVA_HOME/lib/sa-jdi.jar class_list.groovy 2363

sun.nio.ch.SelectorProviderImpl
java.util.jar.JarInputStream
org.apache.catalina.mapper.Mapper$ContextVersion
org.apache.catalina.mapper.Mapper$ContextVersion[]
java.nio.channels.SeekableByteChannel
・・・
short[]
short[][]
long[]
float[]
double[]

SAPIDAttachingConnector により読み取り専用でデバッグ接続するため、デバッグ接続中に Tomcat へアクセスしても応答が返ってこなくなる点にご注意下さい。

(b) JDI でオブジェクトの年齢を取得

それでは、本題の年齢を取得してみます。

クラス情報 ReferenceType を取得するまでは上記 (a) と同じで、その後は以下のように JDI の API から Serviceability Agent API を取り出せばオブジェクトの年齢を取得できます。

  • (1) ReferenceType の instances メソッドインスタンス情報 ObjectReference を取得 (引数を 0 にすると全インスタンスを取得)
  • (2) (1) の実体が sun.jvm.hotspot.jdi.ObjectReferenceImpl なので、protected メソッドref() を呼び出して sun.jvm.hotspot.oops.Oop (Serviceability Agent API) を取得
  • (3) OopgetMark() メソッドsun.jvm.hotspot.oops.Mark を取得し、age() メソッドで年齢を取得

instances メソッドは非常に重い処理だったので、今回は対象クラスをコマンドライン引数 (第2引数) で指定した名称で始まるクラス名に限定するようにしました。 (下記 findAll の箇所)

また、Oop は他の方法 (sun.jvm.hotspot.oops.ObjectHeap を使用) でも取得できるようです。

age_list.groovy
import com.sun.jdi.Bootstrap

def pid = args[0]
def prefix = args[1]

def manager = Bootstrap.virtualMachineManager()

def connector = manager.attachingConnectors().find {
    it.name() == 'sun.jvm.hotspot.jdi.SAPIDAttachingConnector'
}

def params = connector.defaultArguments()
params.get('pid').setValue(pid)

def vm = connector.attach(params)

try {
    if (vm.canGetInstanceInfo()) {
        vm.allClasses().findAll { it.name().startsWith(prefix) }.each { cls ->
            println cls.name()

            // (1) インスタンス情報 ObjectReference の取得
            cls.instances(0).each { inst ->
                // (2) Oop の取得
                def oop = inst.ref()
                // (3) Mark を取得して年齢を取得
                def age = oop.mark.age()

                println "  handle=${oop.handle}, age=${age}"
            }
        }
    }
} finally {
    vm.dispose()
}

ここで、HotSpot VM のソース share/vm/oops/oop.inline.hpp の oopDesc::age() を見ると、has_displaced_mark() が true の時は displaced_mark の age を取得しているのですが、同じように実装(下記)すると hasDisplacedMarkHelper が true の場合に sun.jvm.hotspot.debugger.UnalignedAddressException が発生したため (WindowsLinux の両方で発生)、とりあえず displacedMarkHelper は今回無視するようにしました。

hasDisplacedMarkHelper が true の場合に UnalignedAddressException が発生したコード例
def mark = oop.mark
def age = mark.hasDisplacedMarkHelper()? mark.displacedMarkHelper().age(): mark.age()

動作確認

先程と同じように実行中の apache-tomcat-9.0.0.M4 に対して適用してみます。

全クラスを対象にすると時間がかかり過ぎるので、クラス名が org.apache.catalina.core.ApplicationContext で始まるものに限定してみました。

実行例1 (Windows の場合)

> groovy -cp %JAVA_HOME%/lib/sa-jdi.jar age_list.groovy 804 org.apache.catalina.core.ApplicationContext

org.apache.catalina.core.ApplicationContextFacade
  handle=0x00000000c39602f0, age=0
  handle=0x00000000c3962460, age=1
  handle=0x00000000c3971718, age=0
  handle=0x00000000ecf913b0, age=3
  handle=0x00000000ecf99950, age=1
org.apache.catalina.core.ApplicationContext
  handle=0x00000000c39607e0, age=0
  handle=0x00000000c3966960, age=1
  handle=0x00000000c3971c08, age=0
  handle=0x00000000ecf918a0, age=3
  handle=0x00000000ecf99680, age=1

実行例2 (Linux の場合)

$ groovy -cp $JAVA_HOME/lib/sa-jdi.jar age_list.groovy 2363 org.apache.catalina.core.ApplicationContext

org.apache.catalina.core.ApplicationContextFacade
  handle=0x00000000fb041d08, age=1
  handle=0x00000000fb1288d8, age=0
  handle=0x00000000fb2b1df0, age=1
  handle=0x00000000fb307c98, age=1
  handle=0x00000000fb318aa8, age=1
org.apache.catalina.core.ApplicationContext
  handle=0x00000000fb0338c0, age=1
  handle=0x00000000fb12ed50, age=0
  handle=0x00000000fb2a56a0, age=1
  handle=0x00000000fb2f8da8, age=1
  handle=0x00000000fb312a80, age=1

Deeplearning4J で iris を分類

Deeplearning4J (DL4J) を使って 「ConvNetJS で iris を分類」 と同様に iris を分類してみました。

今回は Groovy を使って実行します。

ソースは http://github.com/fits/try_samples/tree/master/blog/20160412/

準備

iris データセット

iris のデータセットorg.deeplearning4j.datasets.DataSets.iris() メソッドで取得できるため、ConvNetJS の時のようにダウンロードする必要はありません。

OpenBlas のインストール(Windows

Deeplearning4J が利用する ND4J を効果的に使うには OpenBLAS などのネイティブの BLAS ライブラリが必要です。

MNIST データセットのパース」 でも少し書きましたが、OpenBLAS を Windows 環境へインストールするには以下のようにします。

また、JniLoader が netlib-native_system-win-x86_64.dll を %TEMP% ディレクトリへダウンロードするのを防止するには、この dll も環境変数 PATH へ設定した場所へ配置しておきます。 (dll は Maven の Central Repository からダウンロードできます)

今回は、ND4J_Win64_OpenBLAS-v0.2.14.zip の解凍先へ netlib-native_system-win-x86_64.dll も配置しました。

  • C:\ND4J_Win64_OpenBLAS-v0.2.14
    • libblas3.dll
    • libgcc_s_seh-1.dll
    • libgfortran-3.dll
    • liblapack3.dll
    • libopenblas.dll
    • libquadmath-0.dll
    • netlib-native_system-win-x86_64.dll
環境変数 PATH 設定例
> set PATH=C:\ND4J_Win64_OpenBLAS-v0.2.14;%PATH%

共通処理の作成

Deeplearning4J には学習モデルを JSON 化するような処理は用意されていないようです。

ただ、MultiLayerNetwork 自体が Serializable なので、今回は Javaシリアライズ機能を使って保存等を行いました。

1. 学習モデルのセーブ・ロード処理

ObjectInputStream・ObjectOutputStream を使って MultiLayerNetwork を保存・復元する処理です。

ModelManager.groovy
@Grab('org.deeplearning4j:deeplearning4j-core:0.4-rc3.8')
@Grab('org.nd4j:nd4j-x86:0.4-rc3.8')
import org.deeplearning4j.nn.conf.MultiLayerConfiguration
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork

abstract class ModelManager extends Script {

    // MultiLayerNetwork の復元
    def loadModel(String fileName) {
        new File(fileName).withObjectInputStream(this.class.classLoader) {
            it.readObject() as MultiLayerNetwork
        }
    }

    // MultiLayerNetwork の生成と保存
    def saveModel(String fileName, MultiLayerConfiguration conf) {
        // MultiLayerNetwork の生成と初期構築
        def model = new MultiLayerNetwork(conf)
        model.init()

        new File(fileName).withObjectOutputStream {
            it.writeObject model
        }
    }
}

2. 学習・評価処理

以前 ConvNetJS で実装したものと同じような出力結果となるように実装してみました。

今回使用した MultiLayerNetworkメソッドは以下の通りです。

メソッド 備考
fit 学習の実施(誤差の算出、重みの調整等)
score 誤差の取得(引数次第で誤差の算出も実施)
output ニューラルネットの処理結果を取得

output は正解率の算出に使うだけなので、第 2引数を false にして評価モード(TrainingMode.TEST)で実行します。

今回は、学習時と評価時の処理を共通化するため score メソッドで誤差を算出しましたが、setListeners メソッドScoreIterationListener を設定すれば誤差をログ出力できます。

Evaluationeval メソッドへ正解のラベルデータとニューラルネットの処理結果を渡すと正解率などを算出してくれ、accuracy メソッドで正解率を取得できます。 (stats メソッドを使えば結果を文字列で取得する事もできます)

org.nd4j.linalg.dataset.api.DataSetsplitTestAndTrain メソッドで学習用と評価用にデータを分割 ※、batchBy メソッドでミニバッチへ分割できます。

 ※ getTrain で学習用、getTest で評価用のデータセットを取得できます
iris_train.groovy
@Grab('org.deeplearning4j:deeplearning4j-core:0.4-rc3.8')
@Grab('org.nd4j:nd4j-x86:0.4-rc3.8')
import org.deeplearning4j.datasets.DataSets
import org.deeplearning4j.eval.Evaluation
import org.nd4j.linalg.dataset.SplitTestAndTrain

import groovy.transform.BaseScript

@BaseScript ModelManager baseScript

def epoch = args[0] as int
def trainRate = 0.7
def batchSize = 1

def modelFile = args[1]

// 誤差・正解率の算出
class SimpleEvaluator {
    private def model
    private def ev = new Evaluation(3) // iris の品種の数 3 を設定
    private def lossList = []

    SimpleEvaluator(model) {
        this.model = model
    }

    def eval(d) {
        // 誤差の算出とリストへの追加
        lossList << model.score(d)
        // 正解率などの算出
        ev.eval(d.labels, model.output(d.featureMatrix, false))
    }
    // 誤差(平均値)
    def loss() {
        lossList.sum() / lossList.size()
    }
    // 正解率
    def accuracy() {
        ev.accuracy()
    }
}

// 学習モデル (MultiLayerNetwork) のロード
def model = loadModel(modelFile)
// iris データセットの取得
def data = DataSets.iris()

(0..<epoch).each {
    def ev = [
        train: new SimpleEvaluator(model),
        test: new SimpleEvaluator(model)
    ]

    data.shuffle()

    // 学習用とテスト用にデータセットを分割
    def testAndTrain = data.splitTestAndTrain(trainRate)

    // 学習用データセットをミニバッチへ分割
    testAndTrain.train.batchBy(batchSize).each {
        ev.train.eval(it)
        // 学習
        model.fit(it)
    }

    // テスト用データセットを評価
    ev.test.eval(testAndTrain.test)

    // 結果の出力
    println([
        ev.train.loss(), ev.train.accuracy(),
        ev.test.loss(), ev.test.accuracy()
    ].join(','))
}

また、デフォルトではログを標準出力するので、今回は設定ファイルを用意して無効化しました。

logback.xml
<configuration>
  <root level="OFF"></root>
</configuration>

(a) 単純な構成 (入力層 - 出力層)

入力層と出力層だけの単純なニューラルネットを試します。

iris データセットは、4つの変数を使って 3つの品種に分類する事になりますので、OutputLayer.Builder へ以下のように設定します。

  • nIn へ変数の数 4 を設定
  • nOut へ分類の数 3 を設定

ソフトマックスと 3種類以上の交差エントロピー(Cross Entropy)を実施するには、以下のように設定すれば良いみたいです。

  • activation へ softmax を設定
  • OutputLayer.BuilderLossFunctions.LossFunction.MCXENT (Multiclass Cross Entropy) を設定

また、list には設定するレイヤーの数 (以下では 1) を設定します ※。

 ※ ただし、最新のソースで list(int) は Deprecated となっており、
    代わりにレイヤー数を指定しなくても済む list() を使うようです

    また、list(int) は list() を呼ぶだけの実装へ変わっていました
create_iris_hnn1.groovy
@Grab('org.deeplearning4j:deeplearning4j-core:0.4-rc3.8')
@Grab('org.nd4j:nd4j-x86:0.4-rc3.8')
import org.deeplearning4j.nn.conf.NeuralNetConfiguration
import org.deeplearning4j.nn.conf.Updater
import org.deeplearning4j.nn.conf.layers.OutputLayer
import org.nd4j.linalg.lossfunctions.LossFunctions

import groovy.transform.BaseScript

@BaseScript ModelManager baseScript

def learningRate = args[0] as double
def updateMethod = Updater.valueOf(args[1].toUpperCase())
def destFile = args[2]

def conf = new NeuralNetConfiguration.Builder()
    .iterations(1) // 最適化の繰り返し回数(デフォルト設定は 5)
    .updater(updateMethod)
    .learningRate(learningRate)
    .list(1)
    .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
        .nIn(4)
        .nOut(3)
        .activation('softmax')
        .build()
    )
    .build()

// 保存
saveModel(destFile, conf)

学習・評価

更新方法 (updateMethod) だけを変えたモデルを作って学習・評価を実施してみました。

(a-1) learningRate = 0.01, updateMethod = adam, epoch = 50

実行例
> groovy create_iris_hnn1.groovy 0.01 adam models/a-1_adam.ser
・・・
> groovy iris_train.groovy 50 models/a-1_adam.ser > results/a-1.csv
・・・

f:id:fits:20160412211058p:plain

(a-2) learningRate = 0.01, updateMethod = adadelta, epoch = 50

実行例
> groovy create_iris_hnn1.groovy 0.01 adadelta models/a-2_adadelta.ser
・・・
> groovy iris_train.groovy 50 models/a-2_adadelta.ser > results/a-2.csv
・・・

f:id:fits:20160412211111p:plain

(a-3) learningRate = 0.01, updateMethod = sgd, epoch = 50

実行例
> groovy create_iris_hnn1.groovy 0.01 sgd models/a-3_sgd.ser
・・・
> groovy iris_train.groovy 50 models/a-3_sgd.ser > results/a-3.csv
・・・

f:id:fits:20160412211124p:plain

(b) 隠れ層を追加 (入力層 - 隠れ層 - 出力層)

次に、隠れ層を追加してみます。

DenseLayer を追加して nInnOut を調整します。

create_iris_hnn2.groovy
@Grab('org.deeplearning4j:deeplearning4j-core:0.4-rc3.8')
@Grab('org.nd4j:nd4j-x86:0.4-rc3.8')
import org.deeplearning4j.nn.conf.NeuralNetConfiguration
import org.deeplearning4j.nn.conf.Updater
import org.deeplearning4j.nn.conf.layers.DenseLayer
import org.deeplearning4j.nn.conf.layers.OutputLayer
import org.nd4j.linalg.lossfunctions.LossFunctions

import groovy.transform.BaseScript

@BaseScript ModelManager baseScript

def learningRate = args[0] as double
def updateMethod = Updater.valueOf(args[1].toUpperCase())

// 隠れ層のニューロン数
def fcNeuNum = args[2] as int
// 隠れ層の活性化関数
def fcAct = args[3]

def destFile = args[4]

def conf = new NeuralNetConfiguration.Builder()
    .iterations(1)
    .updater(updateMethod)
    .learningRate(learningRate)
    .list(2)
    // 隠れ層
    .layer(0, new DenseLayer.Builder()
        .nIn(4)
        .nOut(fcNeuNum)
        .activation(fcAct)
        .build()
    )
    .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
        .nIn(fcNeuNum)
        .nOut(3)
        .activation('softmax')
        .build()
    )
    .build()

// 保存
saveModel(destFile, conf)

学習・評価

隠れ層の活性化関数 (fcAct) だけを変えたモデルを作って学習・評価を実施してみました。

(b-1) fcNeuNum = 8, fcAct = relu, learningRate = 0.01, updateMethod = adam, epoch = 50

実行例
> groovy create_iris_hnn2.groovy 0.01 adam 8 relu models/b-1_adam_relu.ser
・・・
> groovy iris_train.groovy 50 models/b-1_adam_relu.ser > results/b-1.csv
・・・

f:id:fits:20160412211144p:plain

(b-2) fcNeuNum = 8, fcAct = sigmoid, learningRate = 0.01, updateMethod = adam, epoch = 50

実行例
> groovy create_iris_hnn2.groovy 0.01 adam 8 sigmoid models/b-2_adam_sigmoid.ser
・・・
> groovy iris_train.groovy 50 models/b-2_adam_sigmoid.ser > results/b-2.csv
・・・

f:id:fits:20160412211155p:plain