node-ffi で OpenCL を使う
Windows 環境で node-ffi (Node.js Foreign Function Interface) を使って OpenCL の API を呼び出してみました。
サンプルソースは http://github.com/fits/try_samples/tree/master/blog/20160627/
なお、OpenCL 上での演算は今回扱いませんが、単純な演算のサンプルは ここ に置いてます。
はじめに
node-ffi のインストール
まずは、node-gyp をインストールしておきます。 node-gyp を Windows 環境で使うには VC++ や Python 2.7 が必要です。
node-gyp インストール例
> npm install -g node-gyp
node-ffi をインストールします。(モジュール名は node-ffi ではなく ffi
です)
node-ffi インストール例
> npm install ffi
node-ffi の使い方
node-ffi では Library
関数を使ってネイティブライブラリの関数をマッピングします。
ffi.Library(<ライブラリ名>, { <関数名>: [<戻り値の型>, [<第1引数の型>, <第2引数の型>, ・・・]], ・・・ })
引数の型などはライブラリのヘッダーファイルなどを参考にして設定します。
例えば、OpenCL.dll (Windows 環境の場合) の clGetPlatformIDs 関数を Node.js から openCl.clGetPlatformIDs(・・・)
で呼び出すには以下のようにします。
Library の使用例
const openCl = ffi.Library('OpenCL', { 'clGetPlatformIDs': ['int', ['uint', sizeTPtr, uintPtr]], ・・・ });
ref モジュールの refType
でポインタ用の型を定義する事が可能です。
refType の使用例
const uintPtr = ref.refType(ref.types.uint32); const sizeTPtr = ref.refType('size_t');
OpenCL の利用
それでは、下記 OpenCL ランタイムをインストールした Windows 環境で、OpenCL の API を 3つほど呼び出してみます。
1. OpenCL のデバイスID取得
まずは、以下を実施してみます。
- (1)
clGetPlatformIDs
を使ってプラットフォームIDを取得 - (2)
clGetDeviceIDs
を使ってデバイスIDを取得
OpenCL (v1.2) のヘッダーファイルを見てみると、プラットフォームIDの型 cl_platform_id
やデバイスIDの型 cl_device_id
はこれ自体がポインタのようなので ※、これらに該当する型は size_t
としました。
※ そのため、プラットフォームID や デバイスID という表現は 適切ではないかもしれません
node-ffi ではポインタを扱うために Buffer
を使います。
そのための補助関数が ref モジュールに用意されており、下記サンプルでは以下を使っています。
- ref モジュールの
alloc
を使って指定した型に応じた Buffer を作成 - 定義した型の
get
を使って Buffer から値を取得
get
を使えば、型のサイズやエンディアンに応じた値を Buffer から取り出してくれます。 (例えば、int32 なら Buffer の readInt32LE や readInt32BE を使って値を取得する)
なお、エラーの有無は clGetPlatformIDs・clGetDeviceIDs の戻り値が 0 かどうかで判定します。(0: 成功、0以外: エラー)
get_device_id.js
'use strict'; const ffi = require('ffi'); const ref = require('ref'); // 定数の定義 const CL_DEVICE_TYPE_DEFAULT = 1; // ポインタ用の型定義 const uintPtr = ref.refType(ref.types.uint32); const sizeTPtr = ref.refType('size_t'); // OpenCL の関数定義 const openCl = ffi.Library('OpenCL', { 'clGetPlatformIDs': ['int', ['uint', sizeTPtr, uintPtr]], 'clGetDeviceIDs': ['int', ['size_t', 'ulong', 'uint', sizeTPtr, uintPtr]] }); // エラーチェック処理 const checkError = (errCode, title = '') => { if (errCode != 0) { throw new Error(`${title} Error: ${errCode}`); } }; const platformIdsPtr = ref.alloc(sizeTPtr); // (1) プラットフォームIDを(1つ)取得 let res = openCl.clGetPlatformIDs(1, platformIdsPtr, null); checkError(res, 'clGetPlatformIDs'); // プラットフォームID(get を使って platformIdsPtr の先頭の値を取得) const platformId = sizeTPtr.get(platformIdsPtr); console.log(`platformId: ${platformId}`); const deviceIdsPtr = ref.alloc(sizeTPtr); // (2) デバイスIDを(1つ)取得 res = openCl.clGetDeviceIDs(platformId, CL_DEVICE_TYPE_DEFAULT, 1, deviceIdsPtr, null); checkError(res, 'clGetDeviceIDs'); // デバイスID(get を使って deviceIdsPtr の先頭の値を取得) const deviceId = sizeTPtr.get(deviceIdsPtr); console.log(`deviceId: ${deviceId}`);
実行結果
> node get_device_id.js platformId: 47812336 deviceId: 4404320
2. OpenCL のプラットフォーム情報取得
次は OpenCL のプラットフォーム情報を取得してみます。
プラットフォーム情報は clGetPlatformInfo
を使って取得します。
- (1)
clGetPlatformInfo
でデータサイズを取得 - (2) バッファを確保
- (3)
clGetPlatformInfo
でデータを取得
platform_info.js
'use strict'; const ffi = require('ffi'); const ref = require('ref'); // 定数の定義 const CL_PLATFORM_PROFILE = 0x0900; const CL_PLATFORM_VERSION = 0x0901; const CL_PLATFORM_NAME = 0x0902; const CL_PLATFORM_VENDOR = 0x0903; const CL_PLATFORM_EXTENSIONS = 0x0904; const CL_PLATFORM_HOST_TIMER_RESOLUTION = 0x0905; const uintPtr = ref.refType(ref.types.uint32); const sizeTPtr = ref.refType('size_t'); const openCl = ffi.Library('OpenCL', { 'clGetPlatformIDs': ['int', ['uint', sizeTPtr, uintPtr]], 'clGetPlatformInfo': ['int', ['size_t', 'uint', 'size_t', 'pointer', sizeTPtr]] }); const checkError = (errCode, title = '') => { if (errCode != 0) { throw new Error(`${title} Error: ${errCode}`); } }; // プラットフォーム情報の出力 const printPlatformInfo = (pid, paramName) => { const sPtr = ref.alloc(sizeTPtr); // (1) データサイズを取得 let res = openCl.clGetPlatformInfo(pid, paramName, 0, null, sPtr); checkError(res, 'clGetPlatformInfo size'); // データサイズの値を取り出す const size = sizeTPtr.get(sPtr); // (2) バッファを確保 const buf = Buffer.alloc(size); // (3) データを取得 res = openCl.clGetPlatformInfo(pid, paramName, size, buf, null); checkError(res, 'clGetPlatformInfo data'); // 出力 console.log(buf.toString()); }; const platformIdsPtr = ref.alloc(sizeTPtr); const res = openCl.clGetPlatformIDs(1, platformIdsPtr, null); checkError(res, 'clGetPlatformIDs'); const platformId = sizeTPtr.get(platformIdsPtr); [ CL_PLATFORM_PROFILE, CL_PLATFORM_VERSION, CL_PLATFORM_NAME ].forEach( p => printPlatformInfo(platformId, p) );
実行結果
> node platform_info.js FULL_PROFILE OpenCL 1.2 Intel(R) OpenCL
Keras で iris を分類
Theano・TensorFlow 用のディープラーニングライブラリ Keras を使って、階層型ニューラルネットによる iris の分類を試してみました。
- Keras 1.0.3
- Python 3.5.1
ソースは http://github.com/fits/try_samples/tree/master/blog/20160531/
準備
今回は Docker コンテナで Keras を実行するため、Docker の公式イメージ python をベースに Keras をインストールした Docker イメージを用意します。
Docker イメージ作成
/vagrant/work/keras/Dockerfile
FROM python RUN apt-get update && apt-get upgrade -y RUN pip install --upgrade pip RUN pip install keras RUN pip install scikit-learn RUN apt-get clean
sklearn の iris データセットを使うために scikit-learn もインストールしていますが、Keras を使うだけなら不要です。
また、Theano はデフォルトでインストールされるようですが、TensorFlow を使う場合は別途インストールする必要がありそうです。(今回は Theano を使います)
Dockerfile に対して docker build を実行して Docker イメージを作成します。
Docker ビルド
$ cd /vagrant/work/keras $ docker build --no-cache -t sample/py-keras:0.1 .
(1) 学習
まずは、iris のデータセットを全て使って学習してみます。
「ConvnetJS で iris を分類」 で実施したように、出力層の活性化関数はソフトマックス、損失関数(誤差関数)に交差エントロピーを使います。
sklearn の iris データセットのように、ラベルデータ (target
の値) が数値 (0 ~ 2) の場合 ※ に交差エントロピーを実施するには compile
の引数で loss = 'sparse_categorical_crossentropy'
と指定すれば良さそうです。
※ iris.target の内容 [0 0 0 0 0 0 0 ・・・ 0 0 0 0 0 0 0 ・・・ 1 1 1 1 1 1 1 ・・・ 2 2 2 2 2 2 2 ・・・ 2 2]
とりあえず、入力層 - 隠れ層 - 出力層
(隠れ層のニューロン数 8)というレイヤー構成を使い、学習処理(fit
)はミニバッチサイズを 1 として 50 回繰り返すように指定しました。
/vagrant/work/iris_sample1.py
from keras.models import Sequential from keras.layers.core import Dense, Activation from sklearn import datasets # モデルの定義 model = Sequential() # 隠れ層の定義 model.add(Dense(input_dim = 4, output_dim = 8)) # 隠れ層の活性化関数 model.add(Activation('relu')) # 出力層の定義 model.add(Dense(output_dim = 3)) # 出力層の活性化関数 model.add(Activation('softmax')) model.compile(loss = 'sparse_categorical_crossentropy', optimizer = 'sgd', metrics = ['accuracy']) iris = datasets.load_iris() # 学習 model.fit(iris.data, iris.target, nb_epoch = 50, batch_size = 1)
docker run で Keras 用の Docker コンテナを起動した後、コンテナ内で上記を実行します。
実行例
$ docker run --rm -it -v /vagrant/work:/work sample/py-keras:0.1 bash # cd /work # python iris_sample1.py Using Theano backend. Epoch 1/50 1/150 [..............................] - ETA: 0s - loss: 3.4213 - acc: 0.0000e 2/150 [..............................] - ETA: 0s - loss: 2.2539 - acc: 0.0000e・・・ Epoch 49/50 150/150 [==============================] - 0s - loss: 0.1225 - acc: 0.9533 Epoch 50/50 150/150 [==============================] - 0s - loss: 0.1525 - acc: 0.9333
誤差(loss)と正解率(acc)が出力されました。
(2) 学習と評価
次は、iris データセットを学習用と評価用に分割して学習と評価をそれぞれ実行してみます。
データセットを直接シャッフルする代わりに、0 ~ 149 の数値をランダムに配置した配列を numpy の random.permutation
で作成し、学習・評価用のデータ分割に利用しました。
/vagrant/work/iris_sample2.py
import sys from keras.models import Sequential from keras.layers.core import Dense, Activation from sklearn import datasets import numpy as np # 学習・評価用のデータ分割率 trainEvalRate = 0.7 # 学習の繰り返し回数 epoch = int(sys.argv[1]) # 隠れ層のニューロン数 neuNum = int(sys.argv[2]) # 隠れ層の活性化関数 act = sys.argv[3] optm = sys.argv[4] model = Sequential() model.add(Dense(input_dim = 4, output_dim = neuNum)) model.add(Activation(act)) model.add(Dense(output_dim = 3)) model.add(Activation('softmax')) model.compile( loss = 'sparse_categorical_crossentropy', optimizer = optm, metrics = ['accuracy'] ) iris = datasets.load_iris() data_size = len(iris.data) train_size = int(data_size * trainEvalRate) perm = np.random.permutation(data_size) # 学習用データ x_train = iris.data[ perm[0:train_size] ] y_train = iris.target[ perm[0:train_size] ] # 学習 model.fit(x_train, y_train, nb_epoch = epoch, batch_size = 1) print('-----') # 評価用データ x_test = iris.data[ perm[train_size:] ] y_test = iris.target[ perm[train_size:] ] # 評価 res = model.evaluate(x_test, y_test, batch_size = 1) print(res)
実行例
# python iris_sample2.py 50 6 sigmoid adam Using Theano backend. Epoch 1/50 105/105 [==============================] - 0s - loss: 1.0751 - acc: 0.3524 Epoch 2/50 105/105 [==============================] - 0s - loss: 1.0417 - acc: 0.3524 ・・・ Epoch 49/50 105/105 [==============================] - 0s - loss: 0.3503 - acc: 0.9714 Epoch 50/50 105/105 [==============================] - 0s - loss: 0.3458 - acc: 0.9714 ----- 45/45 [==============================] - 0s [0.35189295262098313, 0.97777777777777775]
JMX で Java Flight Recorder (JFR) を実行する
Java Flight Recorder (JFR) は Java Mission Control (jmc) や jcmd コマンドから実行できますが、今回は以下の MBean を使って JMX から実行してみます。
- com.sun.management:type=DiagnosticCommand
この MBean は以下のような操作を備えており(戻り値は全て String)、jcmd コマンドと同じ事ができるようです。
- jfrCheck
- jfrDump
- jfrStop
- jfrStart
- vmCheckCommercialFeatures
- vmCommandLine
- vmFlags
- vmSystemProperties
- vmUnlockCommercialFeatures
- vmUptime
- vmVersion
- vmNativeMemory
- gcRotateLog
- gcRun
- gcRunFinalization
- gcClassHistogram
- gcClassStats
- threadPrint
(a) JFR の実行
JMX を使う方法はいくつかありますが、今回は Attach API でローカルの VM へアタッチし、startLocalManagementAgent
メソッドで JMX エージェントを適用する方法を用いました。
DiagnosticCommand には java.lang.management.ThreadMXBean のようなラッパーが用意されていないようなので GroovyMBean
を使う事にします。
jfrStart
の引数は jcmd コマンドと同じものを String 配列にして渡すだけのようです。(jfrStart 以外も基本的に同じ)
また、JFR の実行には Commercial Features のアンロックが必要です。
jfr_run.groovy
import com.sun.tools.attach.VirtualMachine import javax.management.remote.JMXConnectorFactory import javax.management.remote.JMXServiceURL def pid = args[0] def duration = args[1] def fileName = args[2] // 指定の JVM プロセスへアタッチ def vm = VirtualMachine.attach(pid) try { // JMX エージェントを適用 def jmxuri = vm.startLocalManagementAgent() JMXConnectorFactory.connect(new JMXServiceURL(jmxuri)).withCloseable { def server = it.getMBeanServerConnection() // MBean の取得 def bean = new GroovyMBean(server, 'com.sun.management:type=DiagnosticCommand') // Commercial Features のアンロック (JFR の実行に必要) println bean.vmUnlockCommercialFeatures() // JFR の開始 println bean.jfrStart([ "duration=${duration}", "filename=${fileName}", 'delay=10s' ] as String[]) } } finally { vm.detach() }
実行例
apache-tomcat-9.0.0.M4 へ適用してみます。
Tomcat 実行
> startup
以下の環境で実行しました。
- Groovy 2.4.6
- Java SE 8u92 64bit版
JFR 実行
> jps 4576 Jps 2924 Bootstrap > groovy jfr_run.groovy 2924 1m sample1.jfr Commercial Features now unlocked. Recording 1 scheduled to start in 10 s. The result will be written to: C:\・・・\apache-tomcat-9.0.0.M4\apache-tomcat-9.0.0.M4\bin\sample1.jfr
jfrStart は JFR の完了を待たずに戻り値を返すため、JFR の実行状況は別途確認する事になります。
出力結果 Recording 1 scheduled
の 1
が recoding の番号で、この番号を使って JFR の状態を確認できます。
ファイル名を相対パスで指定すると対象プロセスのカレントディレクトリへ出力されるようです。 (今回は Tomcat の bin ディレクトリへ出力されました)
(b) JFR の状態確認
JFR の実行状況を確認するには jfrCheck
を使います。
下記では recording の番号を指定し、該当する JFR の実行状況を出力しています。
jfrCheck の引数が null の場合は全ての JFR 実行状態を取得するようです。
jfr_check.groovy
import com.sun.tools.attach.VirtualMachine import javax.management.remote.JMXConnectorFactory import javax.management.remote.JMXServiceURL def pid = args[0] String[] params = (args.length > 1)? ["recording=${args[1]}"]: null def vm = VirtualMachine.attach(pid) try { def jmxuri = vm.startLocalManagementAgent() JMXConnectorFactory.connect(new JMXServiceURL(jmxuri)).withCloseable { def server = it.getMBeanServerConnection() def bean = new GroovyMBean(server, 'com.sun.management:type=DiagnosticCommand') println bean.jfrCheck(params) } } finally { vm.detach() }
実行例
recording 番号(下記では 1
)を指定して実行します。
実行例1 (JFR 実行中)
> groovy jfr_check.groovy 2924 1 Recording: recording=1 name="sample1.jfr" duration=1m filename="sample1.jfr" compress=false (running)
実行例2 (JFR 完了後)
> groovy jfr_check.groovy 2924 1 Recording: recording=1 name="sample1.jfr" duration=1m filename="sample1.jfr" compress=false (stopped)
今回作成したサンプルのソースは http://github.com/fits/try_samples/tree/master/blog/20160519/
JDI でオブジェクトの世代(Young・Old)を判別する2
前回 の処理を sun.jvm.hotspot.oops.ObjectHeap
を使って高速化してみたいと思います。(世代の判別方法などは前回と同じ)
使用した環境は前回と同じです。
- Groovy 2.4.6
- Java SE 8u92 64bit版 (JDK)
ソースは http://github.com/fits/try_samples/tree/master/blog/20160506/
ObjectHeap で Oop を取得
ObjectReference の代わりに、sun.jvm.hotspot.oops.ObjectHeap
の iterate(HeapVisitor)
メソッドを使えば Oop を取得できます。
今回のような方法では、以下の理由で iterate メソッドの引数へ SAJDIClassLoader がロードした sun.jvm.hotspot.oops.HeapVisitor
インターフェースの実装オブジェクトを与える必要があります。
- JDI の内部で管理している Serviceability Agent API は
sun.jvm.hotspot.jdi.SAJDIClassLoader
によってロードされている
下記サンプルでは SAJDIClassLoader がロードした HeapVisitor を入手し、asType
を使って実装オブジェクトを作成しています。
また、HeapVisitor の doObj
で false を返すと処理を継続し、true を返すと中止 ※ するようです。
※ 厳密には、 対象としている Address 範囲の while ループを break するだけで、 その外側の(liveRegions に対する)for ループは継続するようです (ObjectHeap の iterateLiveRegions メソッドのソース参照)
なお、ObjectHeap は sun.jvm.hotspot.jdi.VirtualMachineImpl
から saObjectHeap()
で取得するか、sun.jvm.hotspot.runtime.VM
から取得します。
check_gen2.groovy
import com.sun.jdi.Bootstrap def pid = args[0] def prefix = (args.length > 1)? args[1]: '' def manager = Bootstrap.virtualMachineManager() def connector = manager.attachingConnectors().find { it.name() == 'sun.jvm.hotspot.jdi.SAPIDAttachingConnector' } def params = connector.defaultArguments() params.get('pid').setValue(pid) def vm = connector.attach(params) // 世代の判定処理を返す generation = { heap -> def hasYoungGen = heap.metaClass.getMetaMethod('youngGen') != null [ young: hasYoungGen? heap.youngGen(): heap.getGen(0), old: hasYoungGen? heap.oldGen(): heap.getGen(1) ] } try { def uv = vm.saVM.universe def gen = generation(uv.heap()) def youngGen = gen.young def oldGen = gen.old println "*** youngGen=${youngGen}, oldGen=${oldGen}" println '' def objHeap = vm.saObjectHeap() // 以下でも可 //def objHeap = vm.saVM.objectHeap // SAJDIClassLoader がロードした HeapVisitor インターフェースを取得 def heapVisitorCls = uv.class.classLoader.loadClass('sun.jvm.hotspot.oops.HeapVisitor') // SAJDIClassLoader がロードした HeapVisitor インターフェースを実装 def heapVisitor = [ prologue: { size -> }, epilogue: {}, doObj: { oop -> def clsName = oop.klass.name.asString() if (clsName.startsWith(prefix)) { def age = oop.mark.age() // 世代の判別 def inYoung = youngGen.isIn(oop.handle) def inOld = oldGen.isIn(oop.handle) def identityHash = '' try { identityHash = Long.toHexString(oop.identityHash()) } catch (e) { } println "class=${clsName}, hash=${identityHash}, handle=${oop.handle}, age=${age}, inYoung=${inYoung}, inOld=${inOld}" } // 処理を継続する場合は false を返す false } ].asType(heapVisitorCls) objHeap.iterate(heapVisitor) } finally { vm.dispose() }
動作確認
前回と同じように、実行中の apache-tomcat-9.0.0.M4 へ適用してみました。
前回と異なり、クラス名が '/' で区切られている点に注意
実行例1 (Windows の場合)
> jps 3604 Bootstrap 4516 Jps
> groovy -cp %JAVA_HOME%/lib/sa-jdi.jar check_gen2.groovy 3604 org/apache/catalina/core/StandardContext *** youngGen=sun.jvm.hotspot.gc_implementation.parallelScavenge.PSYoungGen@0x0000000002149ab0, oldGen=sun.jvm.hotspot.gc_implementation.parallelScavenge.PSOldGen@0x0000000002149b40 class=org/apache/catalina/core/StandardContextValve, hash=0, handle=0x00000000c3a577d0, age=1, inYoung=false, inOld=true class=org/apache/catalina/core/StandardContext$NoPluggabilityServletContext, hash=0, handle=0x00000000c3a633d8, age=0, inYoung=false, inOld=true class=org/apache/catalina/core/StandardContext$ContextFilterMaps, hash=0, handle=0x00000000c3a63ef0, age=1, inYoung=false, inOld=true class=org/apache/catalina/core/StandardContext$NoPluggabilityServletContext, hash=0, handle=0x00000000ebc46da0, age=0, inYoung=true, inOld=false class=org/apache/catalina/core/StandardContext, hash=6f2d2815, handle=0x00000000eddfeaa0, age=1, inYoung=true, inOld=false class=org/apache/catalina/core/StandardContext, hash=21f2e66b, handle=0x00000000eddff238, age=3, inYoung=true, inOld=false ・・・
実行例2 (Linux の場合)
$ jps 2778 Jps 2766 Bootstrap
$ groovy -cp $JAVA_HOME/lib/sa-jdi.jar check_gen2.groovy 2766 org/apache/catalina/core/StandardContext *** youngGen=sun.jvm.hotspot.memory.DefNewGeneration@0x00007f0760019cb0, oldGen=sun.jvm.hotspot.memory.TenuredGeneration@0x00007f076001bfc0 class=org/apache/catalina/core/StandardContext, hash=497fe2c4, handle=0x00000000f821bf90, age=0, inYoung=true, inOld=false class=org/apache/catalina/core/StandardContext$ContextFilterMaps, hash=0, handle=0x00000000f821c5d8, age=0, inYoung=true, inOld=false class=org/apache/catalina/core/StandardContextValve, hash=0, handle=0x00000000f821ca60, age=0, inYoung=true, inOld=false ・・・ class=org/apache/catalina/core/StandardContext, hash=5478de1a, handle=0x00000000fb12b310, age=1, inYoung=false, inOld=true class=org/apache/catalina/core/StandardContext$NoPluggabilityServletContext, hash=0, handle=0x00000000fb12f6b0, age=0, inYoung=false, inOld=true class=org/apache/catalina/core/StandardContext$ContextFilterMaps, hash=0, handle=0x00000000fb131a80, age=0, inYoung=false, inOld=true class=org/apache/catalina/core/StandardContextValve, hash=0, handle=0x00000000fb1398b0, age=0, inYoung=false, inOld=true
JDI でオブジェクトの世代(Young・Old)を判別する
前回、オブジェクトの age を取得しましたが、同様の方法で今回はオブジェクトが Young 世代(New 領域)と Old 世代(Old 領域) のどちらに割り当てられているかを判別してみたいと思います。 (ただし、結果の正否は確認できていません)
使用した環境は前回と同じです。
- Groovy 2.4.6
- Java SE 8u92 64bit版 (JDK)
ソースは http://github.com/fits/try_samples/tree/master/blog/20160430/
Young・Old 世代の判別
さて、Young・Old の判別方法ですが。
Serviceability Agent API を見てみると sun.jvm.hotspot.gc_implementation.parallelScavenge
パッケージに PSYoungGen
と PSOldGen
というクラスがあり、isIn(Address)
メソッドで判定できそうです。
更に PSYoungGen と PSOldGen は sun.jvm.hotspot.gc_implementation.parallelScavenge.ParallelScavengeHeap
から取得できます。
Address (sun.jvm.hotspot.debugger
パッケージ所属) は sun.jvm.hotspot.oops.Oop
の getHandle()
か getMark().getAddress()
で取得できるので (下記サンプルでは getHandle を使用)、ParallelScavengeHeap を取得すれば何とかなりそうです。
実際に試してみたところ、ParallelScavengeHeap を取得できたのは Windows 環境で、Linux 環境では GenCollectedHeap を使った別の方法 (getGen
メソッドを使う) が必要でした。 (GC の設定等によって更に変わるかもしれません)
世代の判定クラス
実行環境 | ヒープクラス ※ | Young 世代の判定クラス | Old 世代の判定クラス |
---|---|---|---|
Windows | ParallelScavengeHeap | PSYoungGen | PSOldGen |
Linux | GenCollectedHeap | DefNewGeneration | TenuredGeneration |
※ Universe の heap() メソッド戻り値の実際の型 CollectedHeap のサブクラス
上記を踏まえて、前回の処理をベースに以下を追加してみました。
- (1)
sun.jvm.hotspot.jdi.VirtualMachineImpl
からsun.jvm.hotspot.runtime.VM
オブジェクトを取り出す ※1 - (2) VM オブジェクトから
sun.jvm.hotspot.memory.Universe
オブジェクトを取得 - (3) Universe オブジェクトから CollectedHeap (のサブクラス) を取得 ※2
- (4) (3) の結果から世代を判定するオブジェクトをそれぞれ取得
(4) で妥当な条件分岐の仕方が分からなかったので、とりあえず youngGen メソッドが無ければ GenCollectedHeap として処理するようにしました。
※1 private フィールドの saVM か、package メソッドの saVM() で取得 ※2 今回のやり方では、Windows は ParallelScavengeHeap、 Linux は GenCollectedHeap でした
JDI の SAPIDAttachingConnector で attach した結果が VirtualMachineImpl となります。
また、SAPIDAttachingConnector でデバッグ接続した場合 (読み取り専用のデバッグ接続)、デバッグ対象オブジェクトのメソッド (hashCode や toString 等) を呼び出せないようなので、オブジェクトを識別するための情報を得るため identityHash
を使ってみました。 (ただし、戻り値が 0 になるものが多数ありました)
check_gen.groovy
import com.sun.jdi.Bootstrap def pid = args[0] def prefix = args[1] def manager = Bootstrap.virtualMachineManager() def connector = manager.attachingConnectors().find { it.name() == 'sun.jvm.hotspot.jdi.SAPIDAttachingConnector' } def params = connector.defaultArguments() params.get('pid').setValue(pid) def vm = connector.attach(params) // (4) 世代を判定するためのオブジェクトを取得 generation = { heap -> def hasYoungGen = heap.metaClass.getMetaMethod('youngGen') != null [ // Young 世代の判定オブジェクト(PSYoungGen or DefNewGeneration) young: hasYoungGen? heap.youngGen(): heap.getGen(0), // Old 世代の判定オブジェクト(PSOldGen or TenuredGeneration) old: hasYoungGen? heap.oldGen(): heap.getGen(1) ] } try { if (vm.canGetInstanceInfo()) { // (1) (2) def uv = vm.saVM.universe // (3) def gen = generation(uv.heap()) def youngGen = gen.young def oldGen = gen.old println "*** youngGen=${youngGen}, oldGen=${oldGen}" println '' vm.allClasses().findAll { it.name().startsWith(prefix) }.each { cls -> println cls.name() cls.instances(0).each { inst -> def oop = inst.ref() def age = oop.mark.age() // 世代の判別 def inYoung = youngGen.isIn(oop.handle) def inOld = oldGen.isIn(oop.handle) def identityHash = '' try { identityHash = Long.toHexString(oop.identityHash()) } catch (e) { } println " hash=${identityHash}, handle=${oop.handle}, age=${age}, inYoung=${inYoung}, inOld=${inOld}" } } } } finally { vm.dispose() }
動作確認
前回と同じように、実行中の apache-tomcat-9.0.0.M4 へ適用してみました。
実行例1 (Windows の場合)
> jps 2836 Bootstrap 5944 Jps
> groovy -cp %JAVA_HOME%/lib/sa-jdi.jar check_gen.groovy 2836 org.apache.catalina.core.StandardContext *** youngGen=sun.jvm.hotspot.gc_implementation.parallelScavenge.PSYoungGen@0x0000000002049ad0, oldGen=sun.jvm.hotspot.gc_implementation.parallelScavenge.PSOldGen@0x0000000002049b60 org.apache.catalina.core.StandardContext hash=66dfd722, handle=0x00000000c394a990, age=0, inYoung=false, inOld=true hash=39504d4e, handle=0x00000000edea7cf8, age=3, inYoung=true, inOld=false hash=194311fa, handle=0x00000000edea8e90, age=1, inYoung=true, inOld=false hash=2b28e016, handle=0x00000000edf0c130, age=2, inYoung=true, inOld=false hash=578787b8, handle=0x00000000edf457c0, age=1, inYoung=true, inOld=false org.apache.catalina.core.StandardContext$ContextFilterMaps hash=0, handle=0x00000000c394e7d0, age=0, inYoung=false, inOld=true hash=0, handle=0x00000000c396ec90, age=2, inYoung=false, inOld=true hash=0, handle=0x00000000c3988eb0, age=1, inYoung=false, inOld=true hash=0, handle=0x00000000edf04320, age=1, inYoung=true, inOld=false hash=0, handle=0x00000000edf70988, age=1, inYoung=true, inOld=false ・・・
> groovy -cp %JAVA_HOME%/lib/sa-jdi.jar check_gen.groovy 2836 org.apache.catalina.LifecycleEvent *** youngGen=sun.jvm.hotspot.gc_implementation.parallelScavenge.PSYoungGen@0x0000000002049ad0, oldGen=sun.jvm.hotspot.gc_implementation.parallelScavenge.PSOldGen@0x0000000002049b60 org.apache.catalina.LifecycleEvent hash=0, handle=0x00000000c37459c0, age=0, inYoung=false, inOld=true hash=0, handle=0x00000000c374ed40, age=1, inYoung=false, inOld=true hash=0, handle=0x00000000c39ff950, age=0, inYoung=false, inOld=true hash=0, handle=0x00000000ebb8ef90, age=0, inYoung=true, inOld=false hash=0, handle=0x00000000ebb90490, age=0, inYoung=true, inOld=false hash=0, handle=0x00000000ebb904c0, age=0, inYoung=true, inOld=false ・・・
実行例2 (Linux の場合)
$ jps 2801 Jps 2790 Bootstrap
$ groovy -cp $JAVA_HOME/lib/sa-jdi.jar check_gen.groovy 2790 org.apache.catalina.core.StandardContext *** youngGen=sun.jvm.hotspot.memory.DefNewGeneration@0x00007fca50019cb0, oldGen=sun.jvm.hotspot.memory.TenuredGeneration@0x00007fca5001bfc0 org.apache.catalina.core.StandardContext hash=27055bff, handle=0x00000000fb025d38, age=1, inYoung=false, inOld=true hash=5638a30f, handle=0x00000000fb1270a8, age=1, inYoung=false, inOld=true hash=15fad243, handle=0x00000000fb296730, age=1, inYoung=false, inOld=true hash=36c4d4a0, handle=0x00000000fb2f3cf0, age=1, inYoung=false, inOld=true hash=33309557, handle=0x00000000fb2f3ef8, age=1, inYoung=false, inOld=true org.apache.catalina.core.StandardContextValve hash=0, handle=0x00000000fb045ad8, age=0, inYoung=false, inOld=true hash=0, handle=0x00000000fb135300, age=0, inYoung=false, inOld=true hash=0, handle=0x00000000fb2ad568, age=1, inYoung=false, inOld=true hash=0, handle=0x00000000fb3022b0, age=1, inYoung=false, inOld=true hash=0, handle=0x00000000fb3050e8, age=1, inYoung=false, inOld=true ・・・
$ groovy -cp $JAVA_HOME/lib/sa-jdi.jar check_gen.groovy 2790 org.apache.catalina.LifecycleEvent *** youngGen=sun.jvm.hotspot.memory.DefNewGeneration@0x00007fca50019cb0, oldGen=sun.jvm.hotspot.memory.TenuredGeneration@0x00007fca5001bfc0 org.apache.catalina.LifecycleEvent hash=0, handle=0x00000000f82079a8, age=0, inYoung=true, inOld=false hash=0, handle=0x00000000f8207a00, age=0, inYoung=true, inOld=false hash=0, handle=0x00000000f8210470, age=0, inYoung=true, inOld=false ・・・ hash=0, handle=0x00000000f8506568, age=0, inYoung=true, inOld=false hash=0, handle=0x00000000fb003370, age=0, inYoung=false, inOld=true
この結果の正否はともかく、一応は判別できているように見えます。
ちなみに、前回と同様に処理が遅い(重い)点に関しては、Oop を Serviceability Agent API の sun.jvm.hotspot.oops.ObjectHeap
で取得するように変更すれば改善できます。
注意点
今回のように JDI の内部で管理している Serviceability Agent API を取り出して使う場合の注意点は以下の通りです。
- JDI 内部の Serviceability Agent API のクラス(インターフェースも含む)は
sun.jvm.hotspot.jdi.SAJDIClassLoader
クラスローダーによってロードされる
同じ名称のクラスでもロードするクラスローダーが異なれば別物となりますので、Java で今回のような処理を実装しようとすると、クラスのキャストができずリフレクション等を多用する事になると思います。
また、Groovy でも HeapVisitor 等を使う場合に多少の工夫が必要になります。
JDI でオブジェクトの年齢(age)を取得
HotSpot VM の世代別 GC において、オブジェクト(インスタンス)には年齢 (age) が設定されており、Minor GC が適用される度にカウントアップされ、長命オブジェクトかどうかの判定に使われるとされています。
そこで今回は、Groovy で JDI (Java Debug Interface) と Serviceability Agent API を使って、実行中の Java アプリケーションへアタッチし、オブジェクトの年齢を取得してみたいと思います。
使用した環境は以下の通りです。
- Groovy 2.4.6
- Java SE 8u92 64bit版 (JDK)
ソースは http://github.com/fits/try_samples/tree/master/blog/20160425/
はじめに
今回は、JDI の SA PID コネクタ (sun.jvm.hotspot.jdi.SAPIDAttachingConnector
) を使ってデバッグ接続する事にします。
SAPIDAttachingConnector の特徴は以下のようになります。
SAPIDAttachingConnector の特徴
利点 | 欠点 |
---|---|
デバッグ対象アプリケーションの実行時に -agentlib:jdwp のようなオプション指定が不要 |
読み取り専用で VM へアタッチするため、実行可能な API が限定される。デバッグ接続中は対象のプロセスが中断する ※ |
※ 全スレッドの処理が中断した状態となり、 ステップ実行のような状態を変化させる API は使えません。 (VMCannotBeModifiedException が throw される API は使えない)
デバッグ接続中は全ての処理が完全に中断するので、運用中のサーバーアプリケーションなどへの適用には不向きだと思います。
(a) JDI でクラスの一覧を取得
まずは SAPIDAttachingConnector の動作確認も兼ねて、実行中のアプリケーションへデバッグ接続して、クラスの一覧を出力してみます。
SAPIDAttachingConnector を使ったデバッグ接続手順は以下のようになります。
- (1) Bootstrap から
com.sun.jdi.VirtualMachineManager
を取得 - (2) VirtualMachineManager から利用可能な JDI コネクタの一覧を
attachingConnectors
メソッドで取得し、SAPIDAttachingConnector を抽出 - (3) デバッグ対象アプリケーションのプロセスIDをパラメータへ設定しデバッグ接続
SAPIDAttachingConnector を使うには、JDI を使用するアプリケーション(下記スクリプト)の実行時のクラスパスへ lib/sa-jdi.jar を含めておかなければならない点に注意が必要です ※。
そうしないと attachingConnectors の結果に SAPIDAttachingConnector が含まれず、下記スクリプトでは NullPointerException となります。 (find の結果が null になるので)
※ JDI を使うには lib/tools.jar が必要ですが、 Groovy の場合は groovy-starter.conf のデフォルト設定で tools.jar をロードするようになっています
class_list.groovy
import com.sun.jdi.Bootstrap def pid = args[0] // (1) def manager = Bootstrap.virtualMachineManager() // (2) SAPIDAttachingConnector を抽出 def connector = manager.attachingConnectors().find { it.name() == 'sun.jvm.hotspot.jdi.SAPIDAttachingConnector' } // パラメータの設定 def params = connector.defaultArguments() params.get('pid').setValue(pid) // (3) デバッグ接続 def vm = connector.attach(params) try { // クラス情報 com.sun.jdi.ReferenceType の取得 vm.allClasses().each { cls -> println cls.name() } } finally { vm.dispose() }
動作確認
今回は apache-tomcat-9.0.0.M4 を起動しておき、上記スクリプトでデバッグ接続してクラスの情報を取得してみます。
実行例1 (Windows の場合)
まずはデバッグ対象の Tomcat を実行しておきます。(設定等はデフォルトのまま)
Tomcat 起動
> startup
jps 等でプロセス ID を調べて上記スクリプトを実行します。 JDK の lib/sa-jdi.jar をクラスパスへ指定する必要があります。
class_list.groovy の実行
> jps 804 Bootstrap 5244 Jps > groovy -cp %JAVA_HOME%/lib/sa-jdi.jar class_list.groovy 804 sun.management.HotSpotDiagnostic java.security.CodeSigner java.security.CodeSigner[] java.lang.Character java.lang.Character[] ・・・ short[] short[][] long[] float[] double[]
実行例2 (Linux の場合)
Linux でも同じです。
Tomcat 起動
$ ./startup.sh
class_list.groovy の実行
$ jps 2388 Jps 2363 Bootstrap $ groovy -cp $JAVA_HOME/lib/sa-jdi.jar class_list.groovy 2363 sun.nio.ch.SelectorProviderImpl java.util.jar.JarInputStream org.apache.catalina.mapper.Mapper$ContextVersion org.apache.catalina.mapper.Mapper$ContextVersion[] java.nio.channels.SeekableByteChannel ・・・ short[] short[][] long[] float[] double[]
SAPIDAttachingConnector により読み取り専用でデバッグ接続するため、デバッグ接続中に Tomcat へアクセスしても応答が返ってこなくなる点にご注意下さい。
(b) JDI でオブジェクトの年齢を取得
それでは、本題の年齢を取得してみます。
クラス情報 ReferenceType を取得するまでは上記 (a) と同じで、その後は以下のように JDI の API から Serviceability Agent API を取り出せばオブジェクトの年齢を取得できます。
- (1) ReferenceType の
instances
メソッドでインスタンス情報 ObjectReference を取得 (引数を 0 にすると全インスタンスを取得) - (2) (1) の実体が
sun.jvm.hotspot.jdi.ObjectReferenceImpl
なので、protected メソッドのref()
を呼び出してsun.jvm.hotspot.oops.Oop
(Serviceability Agent API) を取得 - (3) Oop の
getMark()
メソッドでsun.jvm.hotspot.oops.Mark
を取得し、age()
メソッドで年齢を取得
instances メソッドは非常に重い処理だったので、今回は対象クラスをコマンドライン引数 (第2引数) で指定した名称で始まるクラス名に限定するようにしました。 (下記 findAll の箇所)
また、Oop は他の方法 (sun.jvm.hotspot.oops.ObjectHeap
を使用) でも取得できるようです。
age_list.groovy
import com.sun.jdi.Bootstrap def pid = args[0] def prefix = args[1] def manager = Bootstrap.virtualMachineManager() def connector = manager.attachingConnectors().find { it.name() == 'sun.jvm.hotspot.jdi.SAPIDAttachingConnector' } def params = connector.defaultArguments() params.get('pid').setValue(pid) def vm = connector.attach(params) try { if (vm.canGetInstanceInfo()) { vm.allClasses().findAll { it.name().startsWith(prefix) }.each { cls -> println cls.name() // (1) インスタンス情報 ObjectReference の取得 cls.instances(0).each { inst -> // (2) Oop の取得 def oop = inst.ref() // (3) Mark を取得して年齢を取得 def age = oop.mark.age() println " handle=${oop.handle}, age=${age}" } } } } finally { vm.dispose() }
ここで、HotSpot VM のソース share/vm/oops/oop.inline.hpp の oopDesc::age()
を見ると、has_displaced_mark()
が true の時は displaced_mark
の age を取得しているのですが、同じように実装(下記)すると hasDisplacedMarkHelper が true の場合に sun.jvm.hotspot.debugger.UnalignedAddressException
が発生したため (Windows・Linux の両方で発生)、とりあえず displacedMarkHelper は今回無視するようにしました。
hasDisplacedMarkHelper が true の場合に UnalignedAddressException が発生したコード例
def mark = oop.mark def age = mark.hasDisplacedMarkHelper()? mark.displacedMarkHelper().age(): mark.age()
動作確認
先程と同じように実行中の apache-tomcat-9.0.0.M4 に対して適用してみます。
全クラスを対象にすると時間がかかり過ぎるので、クラス名が org.apache.catalina.core.ApplicationContext で始まるものに限定してみました。
実行例1 (Windows の場合)
> groovy -cp %JAVA_HOME%/lib/sa-jdi.jar age_list.groovy 804 org.apache.catalina.core.ApplicationContext org.apache.catalina.core.ApplicationContextFacade handle=0x00000000c39602f0, age=0 handle=0x00000000c3962460, age=1 handle=0x00000000c3971718, age=0 handle=0x00000000ecf913b0, age=3 handle=0x00000000ecf99950, age=1 org.apache.catalina.core.ApplicationContext handle=0x00000000c39607e0, age=0 handle=0x00000000c3966960, age=1 handle=0x00000000c3971c08, age=0 handle=0x00000000ecf918a0, age=3 handle=0x00000000ecf99680, age=1
実行例2 (Linux の場合)
$ groovy -cp $JAVA_HOME/lib/sa-jdi.jar age_list.groovy 2363 org.apache.catalina.core.ApplicationContext org.apache.catalina.core.ApplicationContextFacade handle=0x00000000fb041d08, age=1 handle=0x00000000fb1288d8, age=0 handle=0x00000000fb2b1df0, age=1 handle=0x00000000fb307c98, age=1 handle=0x00000000fb318aa8, age=1 org.apache.catalina.core.ApplicationContext handle=0x00000000fb0338c0, age=1 handle=0x00000000fb12ed50, age=0 handle=0x00000000fb2a56a0, age=1 handle=0x00000000fb2f8da8, age=1 handle=0x00000000fb312a80, age=1
Deeplearning4J で iris を分類
Deeplearning4J (DL4J) を使って 「ConvNetJS で iris を分類」 と同様に iris を分類してみました。
今回は Groovy を使って実行します。
ソースは http://github.com/fits/try_samples/tree/master/blog/20160412/
準備
iris データセット
iris のデータセットは org.deeplearning4j.datasets.DataSets.iris()
メソッドで取得できるため、ConvNetJS の時のようにダウンロードする必要はありません。
OpenBlas のインストール(Windows)
Deeplearning4J が利用する ND4J を効果的に使うには OpenBLAS などのネイティブの BLAS ライブラリが必要です。
「MNIST データセットのパース」 でも少し書きましたが、OpenBLAS を Windows 環境へインストールするには以下のようにします。
- http://nd4j.org/getstarted.html#open のリンクから ND4J_Win64_OpenBLAS-v0.2.14.zip をダウンロード・解凍し、環境変数 PATH へ設定
また、JniLoader が netlib-native_system-win-x86_64.dll
を %TEMP% ディレクトリへダウンロードするのを防止するには、この dll も環境変数 PATH へ設定した場所へ配置しておきます。 (dll は Maven の Central Repository からダウンロードできます)
今回は、ND4J_Win64_OpenBLAS-v0.2.14.zip の解凍先へ netlib-native_system-win-x86_64.dll も配置しました。
- C:\ND4J_Win64_OpenBLAS-v0.2.14
環境変数 PATH 設定例
> set PATH=C:\ND4J_Win64_OpenBLAS-v0.2.14;%PATH%
共通処理の作成
Deeplearning4J には学習モデルを JSON 化するような処理は用意されていないようです。
ただ、MultiLayerNetwork
自体が Serializable
なので、今回は Java のシリアライズ機能を使って保存等を行いました。
1. 学習モデルのセーブ・ロード処理
ObjectInputStream・ObjectOutputStream を使って MultiLayerNetwork
を保存・復元する処理です。
ModelManager.groovy
@Grab('org.deeplearning4j:deeplearning4j-core:0.4-rc3.8') @Grab('org.nd4j:nd4j-x86:0.4-rc3.8') import org.deeplearning4j.nn.conf.MultiLayerConfiguration import org.deeplearning4j.nn.multilayer.MultiLayerNetwork abstract class ModelManager extends Script { // MultiLayerNetwork の復元 def loadModel(String fileName) { new File(fileName).withObjectInputStream(this.class.classLoader) { it.readObject() as MultiLayerNetwork } } // MultiLayerNetwork の生成と保存 def saveModel(String fileName, MultiLayerConfiguration conf) { // MultiLayerNetwork の生成と初期構築 def model = new MultiLayerNetwork(conf) model.init() new File(fileName).withObjectOutputStream { it.writeObject model } } }
2. 学習・評価処理
以前 ConvNetJS で実装したものと同じような出力結果となるように実装してみました。
今回使用した MultiLayerNetwork
のメソッドは以下の通りです。
メソッド | 備考 |
---|---|
fit | 学習の実施(誤差の算出、重みの調整等) |
score | 誤差の取得(引数次第で誤差の算出も実施) |
output | ニューラルネットの処理結果を取得 |
output
は正解率の算出に使うだけなので、第 2引数を false にして評価モード(TrainingMode.TEST)で実行します。
今回は、学習時と評価時の処理を共通化するため score
メソッドで誤差を算出しましたが、setListeners
メソッドで ScoreIterationListener
を設定すれば誤差をログ出力できます。
Evaluation
の eval
メソッドへ正解のラベルデータとニューラルネットの処理結果を渡すと正解率などを算出してくれ、accuracy
メソッドで正解率を取得できます。 (stats
メソッドを使えば結果を文字列で取得する事もできます)
org.nd4j.linalg.dataset.api.DataSet
は splitTestAndTrain
メソッドで学習用と評価用にデータを分割 ※、batchBy
メソッドでミニバッチへ分割できます。
※ getTrain で学習用、getTest で評価用のデータセットを取得できます
iris_train.groovy
@Grab('org.deeplearning4j:deeplearning4j-core:0.4-rc3.8') @Grab('org.nd4j:nd4j-x86:0.4-rc3.8') import org.deeplearning4j.datasets.DataSets import org.deeplearning4j.eval.Evaluation import org.nd4j.linalg.dataset.SplitTestAndTrain import groovy.transform.BaseScript @BaseScript ModelManager baseScript def epoch = args[0] as int def trainRate = 0.7 def batchSize = 1 def modelFile = args[1] // 誤差・正解率の算出 class SimpleEvaluator { private def model private def ev = new Evaluation(3) // iris の品種の数 3 を設定 private def lossList = [] SimpleEvaluator(model) { this.model = model } def eval(d) { // 誤差の算出とリストへの追加 lossList << model.score(d) // 正解率などの算出 ev.eval(d.labels, model.output(d.featureMatrix, false)) } // 誤差(平均値) def loss() { lossList.sum() / lossList.size() } // 正解率 def accuracy() { ev.accuracy() } } // 学習モデル (MultiLayerNetwork) のロード def model = loadModel(modelFile) // iris データセットの取得 def data = DataSets.iris() (0..<epoch).each { def ev = [ train: new SimpleEvaluator(model), test: new SimpleEvaluator(model) ] data.shuffle() // 学習用とテスト用にデータセットを分割 def testAndTrain = data.splitTestAndTrain(trainRate) // 学習用データセットをミニバッチへ分割 testAndTrain.train.batchBy(batchSize).each { ev.train.eval(it) // 学習 model.fit(it) } // テスト用データセットを評価 ev.test.eval(testAndTrain.test) // 結果の出力 println([ ev.train.loss(), ev.train.accuracy(), ev.test.loss(), ev.test.accuracy() ].join(',')) }
また、デフォルトではログを標準出力するので、今回は設定ファイルを用意して無効化しました。
logback.xml
<configuration> <root level="OFF"></root> </configuration>
(a) 単純な構成 (入力層 - 出力層)
入力層と出力層だけの単純なニューラルネットを試します。
iris データセットは、4つの変数を使って 3つの品種に分類する事になりますので、OutputLayer.Builder
へ以下のように設定します。
nIn
へ変数の数 4 を設定nOut
へ分類の数 3 を設定
ソフトマックスと 3種類以上の交差エントロピー(Cross Entropy)を実施するには、以下のように設定すれば良いみたいです。
activation
へ softmax を設定OutputLayer.Builder
でLossFunctions.LossFunction.MCXENT
(Multiclass Cross Entropy) を設定
また、list には設定するレイヤーの数 (以下では 1) を設定します ※。
※ ただし、最新のソースで list(int) は Deprecated となっており、 代わりにレイヤー数を指定しなくても済む list() を使うようです また、list(int) は list() を呼ぶだけの実装へ変わっていました
create_iris_hnn1.groovy
@Grab('org.deeplearning4j:deeplearning4j-core:0.4-rc3.8') @Grab('org.nd4j:nd4j-x86:0.4-rc3.8') import org.deeplearning4j.nn.conf.NeuralNetConfiguration import org.deeplearning4j.nn.conf.Updater import org.deeplearning4j.nn.conf.layers.OutputLayer import org.nd4j.linalg.lossfunctions.LossFunctions import groovy.transform.BaseScript @BaseScript ModelManager baseScript def learningRate = args[0] as double def updateMethod = Updater.valueOf(args[1].toUpperCase()) def destFile = args[2] def conf = new NeuralNetConfiguration.Builder() .iterations(1) // 最適化の繰り返し回数(デフォルト設定は 5) .updater(updateMethod) .learningRate(learningRate) .list(1) .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .nIn(4) .nOut(3) .activation('softmax') .build() ) .build() // 保存 saveModel(destFile, conf)
学習・評価
更新方法 (updateMethod) だけを変えたモデルを作って学習・評価を実施してみました。
(a-1) learningRate = 0.01, updateMethod = adam, epoch = 50
実行例
> groovy create_iris_hnn1.groovy 0.01 adam models/a-1_adam.ser ・・・ > groovy iris_train.groovy 50 models/a-1_adam.ser > results/a-1.csv ・・・
(a-2) learningRate = 0.01, updateMethod = adadelta, epoch = 50
実行例
> groovy create_iris_hnn1.groovy 0.01 adadelta models/a-2_adadelta.ser ・・・ > groovy iris_train.groovy 50 models/a-2_adadelta.ser > results/a-2.csv ・・・
(a-3) learningRate = 0.01, updateMethod = sgd, epoch = 50
実行例
> groovy create_iris_hnn1.groovy 0.01 sgd models/a-3_sgd.ser ・・・ > groovy iris_train.groovy 50 models/a-3_sgd.ser > results/a-3.csv ・・・
(b) 隠れ層を追加 (入力層 - 隠れ層 - 出力層)
次に、隠れ層を追加してみます。
DenseLayer
を追加して nIn
と nOut
を調整します。
create_iris_hnn2.groovy
@Grab('org.deeplearning4j:deeplearning4j-core:0.4-rc3.8') @Grab('org.nd4j:nd4j-x86:0.4-rc3.8') import org.deeplearning4j.nn.conf.NeuralNetConfiguration import org.deeplearning4j.nn.conf.Updater import org.deeplearning4j.nn.conf.layers.DenseLayer import org.deeplearning4j.nn.conf.layers.OutputLayer import org.nd4j.linalg.lossfunctions.LossFunctions import groovy.transform.BaseScript @BaseScript ModelManager baseScript def learningRate = args[0] as double def updateMethod = Updater.valueOf(args[1].toUpperCase()) // 隠れ層のニューロン数 def fcNeuNum = args[2] as int // 隠れ層の活性化関数 def fcAct = args[3] def destFile = args[4] def conf = new NeuralNetConfiguration.Builder() .iterations(1) .updater(updateMethod) .learningRate(learningRate) .list(2) // 隠れ層 .layer(0, new DenseLayer.Builder() .nIn(4) .nOut(fcNeuNum) .activation(fcAct) .build() ) .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .nIn(fcNeuNum) .nOut(3) .activation('softmax') .build() ) .build() // 保存 saveModel(destFile, conf)
学習・評価
隠れ層の活性化関数 (fcAct) だけを変えたモデルを作って学習・評価を実施してみました。
(b-1) fcNeuNum = 8, fcAct = relu, learningRate = 0.01, updateMethod = adam, epoch = 50
実行例
> groovy create_iris_hnn2.groovy 0.01 adam 8 relu models/b-1_adam_relu.ser ・・・ > groovy iris_train.groovy 50 models/b-1_adam_relu.ser > results/b-1.csv ・・・
(b-2) fcNeuNum = 8, fcAct = sigmoid, learningRate = 0.01, updateMethod = adam, epoch = 50
実行例
> groovy create_iris_hnn2.groovy 0.01 adam 8 sigmoid models/b-2_adam_sigmoid.ser ・・・ > groovy iris_train.groovy 50 models/b-2_adam_sigmoid.ser > results/b-2.csv ・・・