Java で行列の演算 - nd4j, commons-math, la4j, ujmp, jblas, colt

Java で以下のような行列の演算を複数のライブラリで試しました。

  • (a) 和
  • (b) 積
  • (c) 転置

とりあえず今回は、更新日が比較的新しめのライブラリを試してみました。

また、あくまでも個人的な印象ですが、手軽に使いたいなら la4j か Commons Math、性能重視なら ND4J か jblas、可視化や DB との連携を考慮するなら UJMP を使えば良さそうです。

ソースは http://github.com/fits/try_samples/tree/master/blog/20161031/

ND4J

Deeplearning4J で使用しているライブラリ。

build.gradle
apply plugin: 'application'

mainClassName = 'SampleApp'

repositories {
    jcenter()
}

dependencies {
    compile 'org.nd4j:nd4j-native-platform:0.6.0'
    runtime 'org.slf4j:slf4j-nop:1.7.21'
}
src/main/java/SampleApp.java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class SampleApp {
    public static void main(String... args) {
        INDArray x = Nd4j.create(new double[][] {
            {1, 2},
            {3, 4}
        });

        INDArray y = Nd4j.create(new double[][] {
            {5, 6},
            {7, 8}
        });
        // (a)
        System.out.println( x.add(y) );

        System.out.println("-----");
        // (b)
        System.out.println( x.mmul(y) );

        System.out.println("-----");
        // (c)
        System.out.println( x.transpose() );
    }
}
実行結果
> gradle -q run

[[ 6.00,  8.00],
 [10.00, 12.00]]
-----
[[19.00, 22.00],
 [43.00, 50.00]]
-----
[[1.00, 3.00],
 [2.00, 4.00]]

Commons Math

Apache Commons のライブラリ。

build.gradle
・・・
dependencies {
    compile 'org.apache.commons:commons-math3:3.6.1'
}
src/main/java/SampleApp.java
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;

public class SampleApp {
    public static void main(String... args) {
        RealMatrix x = MatrixUtils.createRealMatrix(new double[][] {
            {1, 2},
            {3, 4}
        });

        RealMatrix y = MatrixUtils.createRealMatrix(new double[][] {
            {5, 6},
            {7, 8}
        });
        // (a)
        System.out.println( x.add(y) );

        System.out.println("-----");
        // (b)
        System.out.println( x.multiply(y) );

        System.out.println("-----");
        // (c)
        System.out.println( x.transpose() );
    }
}
実行結果
> gradle -q run

Array2DRowRealMatrix{{6.0,8.0},{10.0,12.0}}
-----
Array2DRowRealMatrix{{19.0,22.0},{43.0,50.0}}
-----
Array2DRowRealMatrix{{1.0,3.0},{2.0,4.0}}

la4j

Java のみで実装された軽量なライブラリ。

build.gradle
・・・
dependencies {
    compile 'org.la4j:la4j:0.6.0'
}
src/main/java/SampleApp.java
import org.la4j.Matrix;

public class SampleApp {
    public static void main(String... args) {
        Matrix x = Matrix.from2DArray(new double[][] {
            {1, 2},
            {3, 4}
        });

        Matrix y = Matrix.from2DArray(new double[][] {
            {5, 6},
            {7, 8}
        });
        // (a)
        System.out.println( x.add(y) );

        System.out.println("-----");
        // (b)
        System.out.println( x.multiply(y) );

        System.out.println("-----");
        // (c)
        System.out.println( x.transpose() );
    }
}
実行結果
> gradle -q run

 6.000  8.000
10.000 12.000

-----
19.000 22.000
43.000 50.000

-----
1.000 3.000
2.000 4.000

UJMP

データ可視化や JDBC との連携など、機能豊富そうなライブラリ。 Colt や jblas 等ともプラグインモジュールで連携できる模様。

build.gradle
・・・
dependencies {
    compile 'org.ujmp:ujmp-core:0.3.0'
}
src/main/java/SampleApp.java
import org.ujmp.core.DenseMatrix;
import org.ujmp.core.Matrix;

public class SampleApp {
    public static void main(String... args) {
        Matrix x = DenseMatrix.Factory.linkToArray(
            new double[] {1, 2},
            new double[] {3, 4}
        );

        Matrix y = DenseMatrix.Factory.linkToArray(
            new double[] {5, 6},
            new double[] {7, 8}
        );
        // (a)
        System.out.println( x.plus(y) );

        System.out.println("-----");
        // (b)
        System.out.println( x.mtimes(y) );

        System.out.println("-----");
        // (c)
        System.out.println( x.transpose() );
    }
}
実行結果
> gradle -q run

    6.0000     8.0000
   10.0000    12.0000

-----
   19.0000    22.0000
   43.0000    50.0000

-----
    1.0000     3.0000
    2.0000     4.0000

jblas

BLAS/LAPACK をベースとしたライブラリ。 ネイティブライブラリを使用する。

build.gradle
・・・
dependencies {
    compile 'org.jblas:jblas:1.2.4'
}
src/main/java/SampleApp.java
import org.jblas.DoubleMatrix;

public class SampleApp {
    public static void main(String... args) {
        DoubleMatrix x = new DoubleMatrix(new double[][] {
            {1, 2},
            {3, 4}
        });

        DoubleMatrix y = new DoubleMatrix(new double[][] {
            {5, 6},
            {7, 8}
        });
        // (a)
        System.out.println( x.add(y) );

        System.out.println("-----");
        // (b)
        System.out.println( x.mmul(y) );

        System.out.println("-----");
        // (c)
        System.out.println( x.transpose() );
    }
}
実行結果
> gradle -q run

[6.000000, 8.000000; 10.000000, 12.000000]
-----
[19.000000, 22.000000; 43.000000, 50.000000]
-----
[1.000000, 3.000000; 2.000000, 4.000000]
-- org.jblas INFO Starting temp DLL cleanup task.
-- org.jblas INFO Deleted 4 unused temp DLL libraries from ・・・

Colt Blazegraph 版

Colt は長らく更新されていないようなので、今回は Blazegraph による fork 版? を使いました。

build.gradle
・・・
dependencies {
    compile 'com.blazegraph:colt:2.1.4'
}
src/main/java/SampleApp.java
import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.jet.math.Functions;

public class SampleApp {
    public static void main(String... args) {
        DoubleMatrix2D x = DoubleFactory2D.dense.make(new double[][] {
            {1, 2},
            {3, 4}
        });

        DoubleMatrix2D y = DoubleFactory2D.dense.make(new double[][] {
            {5, 6},
            {7, 8}
        });
        // (a)
        System.out.println( x.copy().assign(y, Functions.plus) );

        System.out.println("-----");

        Algebra algebra = new Algebra();
        // (b)
        System.out.println( algebra.mult(x, y) );

        System.out.println("-----");
        // (c)
        System.out.println( x.viewDice() );
    }
}

assign を使うと自身の値を更新するため copy を使っています。

実行結果
> gradle -q run

2 x 2 matrix
 6  8
10 12
-----
2 x 2 matrix
19 22
43 50
-----
2 x 2 matrix
1 3
2 4

Lucene API で Solr と Elasticsearch のインデックスを確認

Groovy で LuceneAPI を使用して Solr や Elasticsearch のインデックスの内容を確認してみました。(Lucene 6.2.1 の API を使用)

ソースは http://github.com/fits/try_samples/tree/master/blog/20161024/

(a) ドキュメントの内容を出力

まずは、ドキュメントに属するフィールドの内容を出力する処理です。

DirectoryReader から Document を取得し、フィールド IndexableField の内容を出力しています。

dump_docs.groovy
@Grab('org.apache.lucene:lucene-core:6.2.1')
import org.apache.lucene.index.DirectoryReader
import org.apache.lucene.store.FSDirectory
import java.nio.file.Paths

def dir = FSDirectory.open(Paths.get(args[0]))

DirectoryReader.open(dir).withCloseable { reader ->
    println "numDocs = ${reader.numDocs()}"

    (0..<reader.numDocs()).each {
        // ドキュメントの取得
        def doc = reader.document(it)

        println "---------- doc: ${it} ----------"

        // ドキュメント内のフィールドを出力
        doc.fields.each { f -> 
            def value = f.binaryValue()? f.binaryValue().utf8ToString(): f.stringValue()

            println "<field> name=${f.name}, value=${value}, class=${f.class}"
        }
    }
}

(b) Term の内容を出力

インデックス内のフィールド情報と Term を出力する処理です。 Term は基本的な検索の単位となっており、Term の内容を見れば単語の分割状況を確認できます。

これらの情報を取得するには LeafReader を使います。

Term の内容 BytesRefTermsEnum から取得できます。

LeafReader から terms メソッドで該当フィールドの Terms を取得し、iterator メソッドで TermsEnum を取得します。

dump_terms.groovy
@Grab('org.apache.lucene:lucene-core:6.2.1')
import org.apache.lucene.index.DirectoryReader
import org.apache.lucene.store.FSDirectory
import java.nio.file.Paths

def dir = FSDirectory.open(Paths.get(args[0]))

DirectoryReader.open(dir).withCloseable { reader ->

    reader.leaves().each { ctx ->
        // LeafReader の取得
        def leafReader = ctx.reader()

        println "---------- leaf: ${leafReader} ----------"

        // フィールド情報の出力
        leafReader.getFieldInfos().each { fi ->
            println "<fieldInfo> name: ${fi.name}, valueType: ${fi.docValuesType}, indexOptions: ${fi.indexOptions}"
        }

        leafReader.fields().each { name ->
            // 指定のフィールド名に対する TermsEnum を取得
            def termsEnum = leafReader.terms(name).iterator()

            println ''
            println "===== <term> name=${name} ====="

            try {
                while(termsEnum.next() != null) {
                    // Term の内容を出力
                    println "term=${termsEnum.term().utf8ToString()}, freq=${termsEnum.docFreq()}"
                }
            } catch(e) {
            }
        }
    }
}

動作確認

Lucene のバージョンが以下のようになっていたので、今回は Solr 6.2.1 と Elasticsearch 5.0.0 RC1 のインデックス内容を確認してみます。

プロダクト 使用している Lucene のバージョン
Solr 6.2.1 Lucene 6.2.1
Elasticsearch 2.4.1 Lucene 5.5.2
Elasticsearch 5.0.0 RC1 Lucene 6.2.0

(1) Solr 6.2.1

Solr 6.2.1 のインデックスから確認してみます。

準備

インデックスを作成してドキュメントを登録しておきます。

1. Solr 起動とインデックスの作成

> solr start

・・・

> solr create -c sample

{
  "responseHeader":{
    "status":0,
    "QTime":9197},
  "core":"sample"}

2. スキーマの登録

schema.json
{
    "add-field": {
        "name": "title",
        "type": "string"
    },
    "add-field": {
        "name": "num",
        "type": "int"
    },
    "add-field": {
        "name": "rdate",
        "type": "date"
    }
}
スキーマ登録
$ curl -s http://localhost:8983/solr/sample/schema --data-binary @schema.json

{
  "responseHeader":{
    "status":0,
    "QTime":554}}

3. ドキュメントの登録

data1.json
{
    "title": "item1",
    "num": 11,
    "rdate": "2016-10-20T13:45:00Z"
}
ドキュメント登録
$ curl -s http://localhost:8983/solr/sample/update/json/docs --data-binary @data1.json

{"responseHeader":{"status":0,"QTime":199}}

ちなみに、コミットしなくてもインデックスファイルには反映されるようです。

インデックスの内容確認

それでは、インデックスの内容を確認します。

該当するインデックスのディレクトリ(例. C:\solr-6.2.1\server\solr\sample\data\index)を引数に指定して実行します。

(a) ドキュメントの内容
> groovy dump_docs.groovy C:\solr-6.2.1\server\solr\sample\data\index

numDocs = 1
---------- doc: 0 ----------
<field> name=title, value=item1, class=class org.apache.lucene.document.StoredField
<field> name=num, value=11, class=class org.apache.lucene.document.StoredField
<field> name=rdate, value=1476971100000, class=class org.apache.lucene.document.StoredField
<field> name=id, value=2b1080dd-0cd3-43c6-a3ff-ab618ad00113, class=class org.apache.lucene.document.StoredField
(b) Term の内容
> groovy dump_terms.groovy C:\solr-6.2.1\server\solr\sample\data\index

---------- leaf: _0(6.2.1):C1 ----------
<fieldInfo> name: title, valueType: SORTED, indexOptions: DOCS
<fieldInfo> name: _text_, valueType: NONE, indexOptions: DOCS_AND_FREQS_AND_POSITIONS
<fieldInfo> name: num, valueType: NUMERIC, indexOptions: DOCS
<fieldInfo> name: rdate, valueType: NUMERIC, indexOptions: DOCS
<fieldInfo> name: id, valueType: SORTED, indexOptions: DOCS
<fieldInfo> name: _version_, valueType: NUMERIC, indexOptions: DOCS

===== <term> name=_text_ =====
term=00, freq=1
term=0cd3, freq=1
term=11, freq=1
term=13, freq=1
term=1548710288432300032, freq=1
term=20, freq=1
term=2016, freq=1
term=2b1080dd, freq=1
term=43c6, freq=1
term=45, freq=1
term=a3ff, freq=1
term=ab618ad00113, freq=1
term=item1, freq=1
term=oct, freq=1
term=thu, freq=1
term=utc, freq=1

===== <term> name=_version_ =====
term= ?yTP   , freq=1

===== <term> name=id =====
term=2b1080dd-0cd3-43c6-a3ff-ab618ad00113, freq=1

===== <term> name=num =====
term=   , freq=1

===== <term> name=rdate =====
term=    *~Yn`, freq=1

===== <term> name=title =====
term=item1, freq=1

version・num・rdate の値が文字化けしているように見えますが、これは org.apache.lucene.util.LegacyNumericUtils.intToPrefixCoded() メソッド等で処理されてバイナリデータとなっているためです。

実際の値を復元するには LegacyNumericUtils.prefixCodedToInt() 等を適用する必要があるようです。

なお、LegacyNumericUtils クラスは Lucene 6.2.1 API で deprecated となっていますが、Solr は未だ使っているようです。

(2) Elasticsearch 5.0.0 RC1

次は Elasticsearch です。

準備

インデックスを作成してドキュメントを登録しておきます。

1. Elasticsearch 起動

> elasticsearch

・・・

2. インデックスの作成とスキーマ登録

schema.json
{
    "mappings": {
        "data": {
            "properties": {
                "title": {
                    "type": "string",
                    "index": "not_analyzed"
                },
                "num": { "type": "integer" },
                "rdate": { "type": "date" }
            }
        }
    }
}
インデックス作成とスキーマ登録
$ curl -s -XPUT http://localhost:9200/sample --data-binary @schema.json

{"acknowledged":true,"shards_acknowledged":true}

3. ドキュメントの登録

data1.json
{
    "title": "item1",
    "num": 11,
    "rdate": "2016-10-20T13:45:00Z"
}
ドキュメント登録
$ curl -s http://localhost:9200/sample/data --data-binary @data1.json

{"_index":"sample","_type":"data","_id":"AVfwTjEQFnFWQdd5V9p5","_version":1,"result":"created","_shards":{"total":2,"successful":1,"failed":0},"created":true}

なお、すぐにインデックスファイルへ反映されない場合は flush を実施します。

flush 例
$ curl -s http://localhost:9200/sample/_flush

インデックスの内容確認

インデックスの内容を確認します。

Elasticsearch の場合はデフォルトで複数の shard に分かれているため、ドキュメントを登録した shard を確認しておきます。

shard の確認
$ curl -s http://localhost:9200/_cat/shards/sample?v

index  shard prirep state      docs store ip        node
sample 1     p      STARTED       0  130b 127.0.0.1 iUp_FE_
・・・
sample 4     p      STARTED       1 3.8kb 127.0.0.1 iUp_FE_
sample 4     r      UNASSIGNED
sample 0     p      STARTED       0  130b 127.0.0.1 iUp_FE_
sample 0     r      UNASSIGNED

shard 4 にドキュメントが登録されています。

Elasticsearch のインデックスディレクトリは data/nodes/<ノード番号>/indices/<インデックスのuuid>/<shard番号>/index となっているようで、今回は data\nodes\0\indices\QBXMjcCFSWy26Gow1Y9ItQ\4\index でした。(インデックスの uuid は QBXMjcCFSWy26Gow1Y9ItQ)

(a) ドキュメントの内容
> groovy dump_docs.groovy C:\elasticsearch-5.0.0-rc1\data\nodes\0\indices\QBXMjcCFSWy26Gow1Y9ItQ\4\index

numDocs = 1
---------- doc: 0 ----------
<field> name=_source, value={
        "title": "item1",
        "num": 11,
        "rdate": "2016-10-20T13:45:00Z"
}
, class=class org.apache.lucene.document.StoredField
<field> name=_uid, value=data#AVfwTjEQFnFWQdd5V9p5, class=class org.apache.lucene.document.StoredField
(b) Term の内容
> groovy dump_terms.groovy C:\elasticsearch-5.0.0-rc1\data\nodes\0\indices\QBXMjcCFSWy26Gow1Y9ItQ\4\index

---------- leaf: _0(6.2.0):c1 ----------
<fieldInfo> name: _source, valueType: NONE, indexOptions: NONE
<fieldInfo> name: _type, valueType: SORTED_SET, indexOptions: DOCS
<fieldInfo> name: _uid, valueType: NONE, indexOptions: DOCS
<fieldInfo> name: _version, valueType: NUMERIC, indexOptions: NONE
<fieldInfo> name: title, valueType: SORTED_SET, indexOptions: DOCS
<fieldInfo> name: num, valueType: SORTED_NUMERIC, indexOptions: NONE
<fieldInfo> name: rdate, valueType: SORTED_NUMERIC, indexOptions: NONE
<fieldInfo> name: _all, valueType: NONE, indexOptions: DOCS_AND_FREQS_AND_POSITIONS
<fieldInfo> name: _field_names, valueType: NONE, indexOptions: DOCS

===== <term> name=_all =====
term=00z, freq=1
term=10, freq=1
term=11, freq=1
term=2016, freq=1
term=20t13, freq=1
term=45, freq=1
term=item1, freq=1

===== <term> name=_field_names =====
term=_all, freq=1
term=_source, freq=1
term=_type, freq=1
term=_uid, freq=1
term=_version, freq=1
term=num, freq=1
term=rdate, freq=1
term=title, freq=1

===== <term> name=_type =====
term=data, freq=1

===== <term> name=_uid =====
term=data#AVfwTjEQFnFWQdd5V9p5, freq=1

===== <term> name=title =====
term=item1, freq=1

Solr とは、かなり違った結果になっています。

Deeplearning4J で MNIST を分類

Deeplearning4J で iris を分類」 に続いて、畳み込みニューラルネットを使った MNIST の分類を試します。

Deeplearning4J のバージョンが上がって、@Grab を使った Groovy 上での実行が上手くいかなかったので、今回は Kotlin で実装し Gradle で実行します。

  • Gradle 3.1
  • Kotlin 1.0.4

ソースは http://github.com/fits/try_samples/tree/master/blog/20161010/

はじめに

今回は、Gradle のマルチプロジェクトを使って以下の処理をサブプロジェクトとしました。

Gradle のビルド定義は以下の通り。

build.gradle
def slf4jVer = '1.7.21'

buildscript {
    repositories {
        jcenter()
    }

    dependencies {
        classpath 'org.jetbrains.kotlin:kotlin-gradle-plugin:1.0.4'
    }
}
// サブプロジェクトの共通設定
subprojects {
    apply plugin: 'kotlin'
    apply plugin: 'application'

    repositories {
        jcenter()
    }

    dependencies {
        compile 'org.jetbrains.kotlin:kotlin-stdlib:1.0.4'

        compile('org.deeplearning4j:deeplearning4j-core:0.6.0') {
            // エラーの回避策 (Could not find javacpp-presets-${os.name}-${os.arch}.jar)
            exclude group: 'org.bytedeco', module: 'javacpp-presets'
        }

        runtime 'org.nd4j:nd4j-native-platform:0.6.0'
    }

    run {
        // 実行時引数
        if (project.hasProperty('args')) {
            args project.args.split(' ')
        }
    }
}
// (1) 畳み込みニューラルネットのモデル作成(JSON で出力)
project(':conv_model') {
    mainClassName = 'ConvModelKt'

    dependencies {
        runtime "org.slf4j:slf4j-nop:${slf4jVer}"
    }
}
// (2) 学習処理
project(':learn_mnist') {
    mainClassName = 'LearnMnistKt'

    dependencies {
        runtime "org.slf4j:slf4j-simple:${slf4jVer}"
    }
}
// (3) 評価処理
project(':eval_mnist') {
    mainClassName = 'EvalMnistKt'

    dependencies {
        runtime "org.slf4j:slf4j-nop:${slf4jVer}"
    }
}
settings.gradle
include 'conv_model', 'learn_mnist', 'eval_mnist'

ファイル構成は以下の通りです。

ファイル構成
  • build.gradle
  • settings.gradle
  • conv_model/src/main/kotlin/convModel.kt
  • learn_mnist/src/main/kotlin/learnMnist.kt
  • eval_mnist/src/main/kotlin/evalMnist.kt

(1) 畳み込みニューラルネットのモデル作成(JSON で出力)

畳み込みニューラルネットの構成情報を JSON 化して標準出力します。

MnistDataSetIterator の MNIST データセットに対して畳み込みニューラルネットを行うには InputType.convolutionalFlat(28, 28, 1)setInputType します。

conv_model/src/main/kotlin/convModel.kt
import org.deeplearning4j.nn.conf.NeuralNetConfiguration
import org.deeplearning4j.nn.conf.inputs.InputType
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer
import org.deeplearning4j.nn.conf.layers.OutputLayer
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer
import org.nd4j.linalg.lossfunctions.LossFunctions

fun main(args: Array<String>) {
    val builder = NeuralNetConfiguration.Builder()
        .iterations(3) // 3回繰り返し
        .list(
            // 畳み込み層
            ConvolutionLayer.Builder(5, 5)
                .nIn(1)
                .nOut(8)
                .padding(2, 2)
                .activation("relu")
                .build()
            ,
            // プーリング層(最大プーリング)
            SubsamplingLayer.Builder(
                SubsamplingLayer.PoolingType.MAX, intArrayOf(2, 2))
                .stride(2, 2)
                .build()
            ,
            // 畳み込み層
            ConvolutionLayer.Builder(5, 5)
                .nOut(16)
                .padding(1, 1)
                .activation("relu")
                .build()
            ,
            // プーリング層(最大プーリング)
            SubsamplingLayer.Builder(
                SubsamplingLayer.PoolingType.MAX, intArrayOf(3, 3))
                .stride(3, 3)
                .build()
            ,
            OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .nOut(10)
                .activation("softmax")
                .build()
        )
        // MNIST データセットに対する畳み込みニューラルネット用の入力タイプ設定
        .setInputType(InputType.convolutionalFlat(28, 28, 1))

    // JSON 化して出力
    println(builder.build().toJson())
}

(2) 学習処理

(1) で出力した JSON ファイルを使って学習を実施し、学習後のパラメータ(重み)を JSON ファイルへ出力します。

MNIST データセットには MnistDataSetIterator を使いました ※。

 ※ train 引数(今回使用したコンストラクタの第2引数)が true の場合に学習用、
    false の場合に評価用のデータセットとなります
learn_mnist/src/main/kotlin/learnMnist.kt
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.deeplearning4j.nn.conf.MultiLayerConfiguration
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.optimize.listeners.ScoreIterationListener

import org.nd4j.linalg.api.ndarray.NdArrayJSONWriter

import java.io.File

fun main(args: Array<String>) {
    // ニューラルネットのモデルファイル(JSON)
    val confJson = File(args[0]).readText()
    // パラメータ(重み)の出力ファイル名
    val destFile = args[1]

    val conf = MultiLayerConfiguration.fromJson(confJson)
    val network = MultiLayerNetwork(conf)

    network.init()
    // スコア(誤差)の出力
    network.setListeners(ScoreIterationListener())

    // MNIST 学習用データ(バッチサイズ 100)
    val trainData = MnistDataSetIterator(100, true, 0)
    // 学習
    network.fit(trainData)
    // 学習後のパラメータ(重み)を JSON ファイルへ出力
    NdArrayJSONWriter.write(network.params(), destFile)
}

(3) 評価処理

(1) と (2) の結果を使って評価します。

eval_mnist/src/main/kotlin/evalMnist.kt
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.deeplearning4j.nn.conf.MultiLayerConfiguration
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork

import org.nd4j.linalg.api.ndarray.NdArrayJSONReader

import java.io.File

fun main(args: Array<String>) {
    // ニューラルネットのモデルファイル(JSON)
    val confJson = File(args[0]).readText()
    // パラメータ(重み)のロード
    val params = NdArrayJSONReader().read(File(args[1]))

    val conf = MultiLayerConfiguration.fromJson(confJson)
    val network = MultiLayerNetwork(conf, params)
    // MNIST 評価用データ
    val testData = MnistDataSetIterator(1, false, 0)
    // 評価
    val res = network.evaluate(testData)

    println(res.stats())
}

実行

まず、畳み込みニューラルネットの構成を JSON ファイルへ出力します。

(1) 畳み込みニューラルネットのモデル作成
> gradle -q :conv_model:run > conv_model.json

次に、学習を実施します。

org.reflections.Reflections - could not create Vfs.Dir from url という警告ログが出力されましたが、特に支障は無さそうです。

(2) 学習
> gradle -q :learn_mnist:run -Pargs="../conv_model.json ../conv_params.json"

・・・
[main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 1790 is 0.07888245582580566
[main] INFO org.deeplearning4j.nn.multilayer.MultiLayerNetwork - Finetune phase
[main] INFO org.deeplearning4j.nn.multilayer.MultiLayerNetwork - Finetune phase
[main] INFO org.deeplearning4j.nn.multilayer.MultiLayerNetwork - Finetune phase

最後に、評価を実施します。

(3) 評価
> gradle -q :eval_mnist:run -Pargs="../conv_model.json ../conv_params.json"

Examples labeled as 0 classified by model as 0: 944 times
Examples labeled as 0 classified by model as 2: 7 times
Examples labeled as 0 classified by model as 5: 4 times
・・・
Examples labeled as 9 classified by model as 7: 24 times
Examples labeled as 9 classified by model as 8: 22 times
Examples labeled as 9 classified by model as 9: 937 times


==========================Scores========================================
 Accuracy:  0.9432
 Precision: 0.9463
 Recall:    0.9427
 F1 Score:  0.9445
========================================================================

Keras で MNIST を分類

Keras で iris を分類」 に続き、今回は Keras で畳み込みニューラルネットを使った MNIST の分類を試してみました。

ソースは http://github.com/fits/try_samples/tree/master/blog/20160920/

準備

Docker で実行するための Docker イメージを作成します。

Docker イメージ作成

/vagrant/work/keras/Dockerfile
FROM python

RUN apt-get update && apt-get upgrade -y

RUN pip install --upgrade pip

RUN pip install keras
RUN pip install h5py

RUN apt-get clean

h5py は Keras の save_model 関数を使うために必要でした。

今回のバージョンでは keras をインストールしても h5py は自動的にインストールされなかったので、別途インストールするようにしています。

上記を docker build して Docker イメージを作成しておきます。

Docker ビルド
$ cd /vagrant/work/keras
$ docker build --no-cache -t sample/py-keras:0.2 .

(1) MNIST データセットの取得

keras.datasets の mnist を使うと MNIST データセットを取得できます。(S3 からダウンロードするようになっている)

load_data 関数で取得したデータセット[[<学習用画像データ>, <学習用ラベルデータ>], [<評価用画像データ>, <評価用ラベルデータ>]] のような内容となっていましたが(画素の値は 0 ~ 255)、そのままでは今回の用途に使えなかったので numpy を使って変換しています。

/vagrant/work/mnist_helper.py
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils

# 学習用のデータセット取得
def train_mnist():
    return convert_mnist(mnist.load_data()[0])

# 評価用のデータセット取得
def test_mnist():
    return convert_mnist(mnist.load_data()[1])

def convert_mnist(tpl):
    # 画像データの加工
    features = tpl[0].reshape(tpl[0].shape[0], 1, 28, 28).astype(np.float32)
    features /= 255

    # ラベルデータの加工 (10種類の分類)
    labels = np_utils.to_categorical(tpl[1], 10)

    return (features, labels)

(2) 畳み込みニューラルネットモデル

畳み込みニューラルネットのモデルを作成してバイナリファイルとして保存する処理です。

畳み込みのレイヤー構成は 「ConvNetJS で MNIST を分類2」 と同じ様にしてみました。 (活性化関数は relu を使用)

/vagrant/work/create_layer_conv.py
import sys

from keras.models import Sequential, save_model
from keras.layers.core import Dense, Activation, Flatten
from keras.layers import Convolution2D, MaxPooling2D

model_dest_file = sys.argv[1]

model = Sequential()

# 1つ目の畳み込み層(5x5 で 8個出力)
model.add(Convolution2D(8, 5, 5, input_shape = (1, 28, 28)))
model.add(Activation('relu'))

# 1つ目のプーリング層 (最大プーリング)
model.add(MaxPooling2D(pool_size = (2, 2), strides = (2, 2)))

# 2つ目の畳み込み層(5x5 で 16個出力)
model.add(Convolution2D(16, 5, 5))
model.add(Activation('relu'))

# 2つ目のプーリング層 (最大プーリング)
model.add(MaxPooling2D(pool_size = (3, 3), strides = (3, 3)))

model.add(Flatten())

model.add(Dense(10))
model.add(Activation('softmax'))

model.compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics = ['accuracy'])

# モデルの保存
save_model(model, model_dest_file)

(3) 学習処理

学習処理は以下の通りです。

学習用の MNIST データセットを使って fit を実行します。

/vagrant/work/learn_mnist.py
import sys

from keras.models import save_model, load_model
from mnist_helper import train_mnist

epoch = int(sys.argv[1])
mini_batch = int(sys.argv[2])

model_file = sys.argv[3]
model_dest_file = sys.argv[4]

# モデルの読み込み
model = load_model(model_file)

# 学習用 MNIST データセット取得
(x_train, y_train) = train_mnist()

# 学習
model.fit(x_train, y_train, nb_epoch = epoch, batch_size = mini_batch)

# 学習後のモデルを保存
save_model(model, model_dest_file)

(4) 評価処理

評価処理は以下の通りです。

評価用の MNIST データセットを使って evaluate を実行します。 verbose を 0 にすれば途中経過を出力しなくなるようです。

/vagrant/work/eval_mnist.py
import sys

from keras.models import load_model
from mnist_helper import test_mnist

model_file = sys.argv[1]

model = load_model(model_file)

# 評価用 MNIST データセット取得
(x_test, y_test) = test_mnist()

# 評価
(loss, acc) = model.evaluate(x_test, y_test, verbose = 0)

print("loss = %f, accuracy = %f" % (loss, acc))

実行

まずは、作成した Docker イメージ(py-keras)を使って Docker コンテナを起動します。

Docker コンテナ起動
$ docker run --rm -it -v /vagrant/work:/work sample/py-keras:0.2 bash

# cd /work

起動した Docker コンテナで、畳み込みニューラルネットのモデルを作成してファイルへ保存します。

1. モデル作成
# python create_layer_conv.py 1.model

Using Theano backend.

保存したファイルを使って学習を行います。 今回はミニバッチサイズ 200 で 3 回繰り返してみます。

初回実行時は MNIST データセットのダウンロードが行われます。

2. 学習
# python learn_mnist.py 3 200 1.model 1a.model

Using Theano backend.
Downloading data from https://s3.amazonaws.com/img-datasets/mnist.pkl.gz
・・・
Epoch 1/3
60000/60000 [==============================] - 116s - loss: 0.7228 - acc: 0.7931
Epoch 2/3
60000/60000 [==============================] - 116s - loss: 0.1855 - acc: 0.9458
Epoch 3/3
60000/60000 [==============================] - 115s - loss: 0.1352 - acc: 0.9591

最後に、学習後のモデルを使って評価用のデータセットを評価します。

3. 評価
# python eval_mnist.py 1a.model

Using Theano backend.
loss = 0.102997, accuracy = 0.968600

Gradle で ScalaPB を使う

前回と同様の処理を ScalaPB で行ってみました。

ScalaPB であればビルドツールに sbt を使う方が簡単かもしれませんが、引き続き Gradle を使います。

今回作成したソースは http://github.com/fits/try_samples/tree/master/blog/20160905/

proto ファイル

前回と同じファイルですが、ファイル名に - を含むと都合が悪いようなので ※ ファイル名だけ変えています。

※ ScalaPB 0.5.40 では、デフォルトで proto ファイル名が
   そのままパッケージ名の一部となりました
   (パッケージ名は <java_package オプションの値>.<protoファイル名> )

ちなみに、java_outer_classname のオプション設定は無視されるようです。

proto/addressbook.proto (.proto ファイル)
syntax = "proto3";

package sample;

option java_package = "sample.model";
option java_outer_classname = "AddressBookProtos";

message Person {
  string name = 1;
  int32 id = 2;
  string email = 3;

  enum PhoneType {
    MOBILE = 0;
    HOME = 1;
    WORK = 2;
  }

  message PhoneNumber {
    string number = 1;
    PhoneType type = 2;
  }

  repeated PhoneNumber phone = 4;
}

message AddressBook {
  repeated Person person = 1;
}

Gradle ビルド定義

基本的な構成は前回と同じですが、ScalaPBC を実行してソースを生成する等、Scala 用に変えています。

build.gradle
apply plugin: 'scala'
apply plugin: 'application'

// protoc によるソースの自動生成先
def protoDestDir = 'src/main/protoc-generated'
// proto ファイル名
def protoFile = 'proto/addressbook.proto'

mainClassName = 'SampleApp'

repositories {
    jcenter()
}

configurations {
    scalapbc
}

dependencies {
    scalapbc 'com.trueaccord.scalapb:scalapbc_2.11:0.5.40'

    compile 'org.scala-lang:scala-library:2.11.8'
    compile 'com.trueaccord.scalapb:scalapb-runtime_2.11:0.5.40'
}

task scalapbc << {
    mkdir(protoDestDir)

    javaexec {
        main = 'com.trueaccord.scalapb.ScalaPBC'
        classpath = configurations.scalapbc
        args = [ protoFile, "--scala_out=${protoDestDir}" ]
    }
}

compileScala {
    dependsOn scalapbc
    source protoDestDir
}

clean {
    delete protoDestDir
}

サンプルアプリケーション

こちらも前回と同じ処理内容ですが、ScalaPB 用の実装となっています。

src/main/scala/SampleApp.scala
import java.io.ByteArrayOutputStream

import sample.model.addressbook.Person
import Person.PhoneNumber
import Person.PhoneType._

object SampleApp extends App {

    val phone = PhoneNumber("000-1234-5678", HOME)
    val person = Person(name = "sample1", phone = Seq(phone))

    println(person)

    val output = new ByteArrayOutputStream()

    try {
        person.writeTo(output)

        println("----------")

        val restoredPerson = Person.parseFrom(output.toByteArray)

        println(restoredPerson)

    } finally {
        output.close
    }
}

ビルドと実行

ビルドと実行の結果は以下の通りです。

前回と違って今回のビルド(scalapbc の実施)には python コマンドが必要でした。※

※ python コマンドを呼び出せるように環境変数 PATH 等を設定しておきます
   今回は Python 2.7 を使用しました
実行結果
> gradle run

・・・
:scalapbc
protoc-jar: protoc version: 300, detected platform: windows 10/amd64
protoc-jar: executing: [・・・\Temp\protoc8428481850206377506.exe, --plugin=protoc-gen-scala=・・・\Temp\protocbridge9000836851429371052.bat, proto/addressbook.proto, --scala_out=src/main/protoc-generated]
:compileScala
・・・
:run
name: "sample1"
phone {
  number: "000-1234-5678"
  type: HOME
}

----------
name: "sample1"
phone {
  number: "000-1234-5678"
  type: HOME
}


BUILD SUCCESSFUL

scalapbc タスクの実行によって以下のようなソースが生成されました。

  • src/main/protoc-generated/sample/model/addressbook/AddressBook.scala
  • src/main/protoc-generated/sample/model/addressbook/AddressbookProto.scala
  • src/main/protoc-generated/sample/model/addressbook/Person.scala

java_package オプションの設定値は反映されていますが、java_outer_classname オプション設定の方は無視されているようです。

Gradle で Protocol Buffers を使う - Java

Gradle を使って Protocol Buffers の protoc で Java ソースコードを生成し、ビルドしてみます。

Gradle から protoc コマンドを呼び出す方法もありますが、今回は protoc-jar を使いました。

protoc-jar を使うと、プラットフォームの環境に応じた protoc コマンドを TEMP ディレクトリへ一時的に生成して実行してくれます。

今回作成したソースは http://github.com/fits/try_samples/tree/master/blog/20160829/

proto ファイル

今回は以下の proto ファイル(version 3)を使います。

Protocol Buffers では、proto ファイルを protoc コマンドで処理する事で任意のプログラム言語のソースコードを自動生成します。

java_packagejava_outer_classname オプションで Java ソースコードを生成した際のパッケージ名とクラス名をそれぞれ指定できます。

proto/address-book.proto (.proto ファイル)
syntax = "proto3";

package sample;

option java_package = "sample.model";
option java_outer_classname = "AddressBookProtos";

message Person {
  string name = 1;
  int32 id = 2;
  string email = 3;

  enum PhoneType {
    MOBILE = 0;
    HOME = 1;
    WORK = 2;
  }

  message PhoneNumber {
    string number = 1;
    PhoneType type = 2;
  }

  repeated PhoneNumber phone = 4;
}

message AddressBook {
  repeated Person person = 1;
}

Gradle ビルド定義

compileJava タスクの実行前に com.github.os72.protocjar.Protoc を実行して src/main/protoc-generated へソースを自動生成する protoc タスクを定義しました。

protoc-jar モジュールをクラスパスへ指定するため protoc 用の configurations を定義しています。

なお、今回は Java 用のソースコードを生成するため --java_out オプションを使っています。

build.gradle
apply plugin: 'application'

// protoc によるソースの自動生成先
def protoDestDir = 'src/main/protoc-generated'
// proto ファイル名
def protoFile = 'proto/address-book.proto'

mainClassName = 'SampleApp'

repositories {
    jcenter()
}

configurations {
    protoc
}

dependencies {
    protoc 'com.github.os72:protoc-jar:3.0.0'

    compileOnly 'org.projectlombok:lombok:1.16.10'

    compile 'com.google.protobuf:protobuf-java:3.0.0'
}
// protoc の実行タスク
task protoc << {
    mkdir(protoDestDir)

    // protoc の実行
    javaexec {
        main = 'com.github.os72.protocjar.Protoc'
        classpath = configurations.protoc
        args = [ protoFile, "--java_out=${protoDestDir}" ]
    }
}

compileJava {
    dependsOn protoc
    source protoDestDir
}

clean {
    delete protoDestDir
}

サンプルアプリケーション

protoc で自動生成したクラスを動作確認するための簡単なサンプルを用意しました。

src/main/java/SampleApp.java
import static sample.model.AddressBookProtos.Person.PhoneType.*;

import lombok.val;

import java.io.ByteArrayOutputStream;

import sample.model.AddressBookProtos.Person;
import sample.model.AddressBookProtos.Person.PhoneNumber;

class SampleApp {
    public static void main(String... args) throws Exception {

        val phone = PhoneNumber.newBuilder()
                        .setNumber("000-1234-5678")
                        .setType(HOME)
                        .build();

        val person = Person.newBuilder()
                        .setName("sample1")
                        .addPhone(phone)
                        .build();

        System.out.println(person);

        try (val output = new ByteArrayOutputStream()) {
            // シリアライズ処理(バイト配列化)
            person.writeTo(output);

            System.out.println("----------");

            // デシリアライズ処理(バイト配列から復元)
            val restoredPerson = Person.newBuilder()
                                    .mergeFrom(output.toByteArray())
                                    .build();

            System.out.println(restoredPerson);
        }
    }
}

ビルドと実行

ビルドと実行の結果は以下の通りです。

実行結果
> gradle run

:protoc
protoc-jar: protoc version: 300, detected platform: windows 10/amd64
protoc-jar: executing: [・・・\Local\Temp\protoc3178938487369694690.exe, proto/address-book.proto, --java_out=src/main/protoc-generated]
:compileJava
・・・
:run
name: "sample1"
phone {
  number: "000-1234-5678"
  type: HOME
}

----------
name: "sample1"
phone {
  number: "000-1234-5678"
  type: HOME
}


BUILD SUCCESSFUL

なお、protoc で以下のようなコードが生成されました。

src/main/protoc-generated/sample/model/AddressBookProtos.java
package sample.model;

public final class AddressBookProtos {
  private AddressBookProtos() {}
  ・・・
  public  static final class Person extends
      com.google.protobuf.GeneratedMessageV3 implements
      // @@protoc_insertion_point(message_implements:sample.Person)
      PersonOrBuilder {
    ・・・
    public enum PhoneType
        implements com.google.protobuf.ProtocolMessageEnum {
      /**
       * <code>MOBILE = 0;</code>
       */
      MOBILE(0),
      /**
       * <code>HOME = 1;</code>
       */
      HOME(1),
      /**
       * <code>WORK = 2;</code>
       */
      WORK(2),
      UNRECOGNIZED(-1),
      ;
      ・・・
    }
    ・・・
    public  static final class PhoneNumber extends
        com.google.protobuf.GeneratedMessageV3 implements
        // @@protoc_insertion_point(message_implements:sample.Person.PhoneNumber)
        PhoneNumberOrBuilder {
      ・・・
    }
    ・・・
  }
  ・・・
}

Axon Framework でイベントソーシング

Axon Fraework のイベントソーシング機能を軽く試してみました。

今回のソースは http://github.com/fits/try_samples/tree/master/blog/20160815/

はじめに

今回は以下のような Gradle 用ビルド定義を使います。

lombok は必須ではありませんが、便利なので使っています。 compileOnlyコンパイル時にのみ使用するモジュールを指定できます。

build.gradle
apply plugin: 'application'

repositories {
    jcenter()
}

dependencies {
    compileOnly 'org.projectlombok:lombok:1.16.10'

    compile 'org.axonframework:axon-core:3.0-M3'

    runtime 'org.slf4j:slf4j-simple:1.7.21'
}

mainClassName = 'SampleApp'

Axon Framework によるイベントソーシング

DDD や CQRS におけるイベントソーシングでは、永続化したイベント群を順次適用して集約の状態を復元します。

Axon Framework では以下のように処理すれば、イベントソーシングを実現できるようです。

  • (a) CommandHandler でコマンドからイベントを生成し適用
  • (b) EventSourcingHandler でイベントの内容からモデルの状態を変更

なお、これらのハンドラはアノテーションで指定できるようになっています。

コマンドの作成

まずは、在庫作成のコマンドを実装します。

処理対象の識別子を設定するフィールドに @TargetAggregateIdentifier を付けますが、CreateInventoryItem には特に付けなくても問題無さそうでした。(次の CheckInItemsToInventory では必須)

なお、lombok.Value を使っているため、コンストラクタや getter メソッドが自動的に生成されます。

src/main/java/sample/commands/CreateInventoryItem.java
package sample.commands;

import org.axonframework.commandhandling.TargetAggregateIdentifier;
import lombok.Value;

// 在庫作成コマンド
@Value
public class CreateInventoryItem {

    @TargetAggregateIdentifier // このアノテーションは必須では無さそう
    private String id;

    private String name;
}

次に、在庫数を加えるためのコマンドです。

src/main/java/sample/commands/CheckInItemsToInventory.java
package sample.commands;

import org.axonframework.commandhandling.TargetAggregateIdentifier;
import lombok.Value;

// 在庫数追加コマンド
@Value
public class CheckInItemsToInventory {

    @TargetAggregateIdentifier // このアノテーションは必須
    private String id;

    private int count;
}

イベントの作成

次は、コマンドによって生じるイベントを作成します。

今回は、CreateInventoryItem コマンドから 2つのイベント (InventoryItemCreated と InventoryItemRenamed) が生じるような仕様で考えてみました。

src/main/java/sample/events/InventoryItemCreated.java
package sample.events;

import lombok.Value;

// 在庫作成イベント
@Value
public class InventoryItemCreated {
    private String id;
}
src/main/java/sample/events/InventoryItemRenamed.java
package sample.events;

import lombok.Value;

// 名前変更イベント
@Value
public class InventoryItemRenamed {
    private String newName;
}
src/main/java/sample/events/ItemsCheckedInToInventory.java
package sample.events;

import lombok.Value;

// 在庫数追加イベント
@Value
public class ItemsCheckedInToInventory {
    private int count;
}

エンティティの作成

在庫エンティティを作成します。

一意の識別子を設定するフィールドへ @AggregateIdentifier を付与します。

コマンドを処理するコンストラクタ ※ やメソッドへ @CommandHandler を付与し、その処理内でイベントを作成して AggregateLifecycle.apply メソッドへ渡せば、@EventSourcingHandler を付与した該当メソッドが呼び出されます。

イベントの内容に合わせてエンティティの内部状態を更新する事で、イベントソーシングを実現できます。

 ※ 新規作成のコマンドを処理する場合に、コンストラクタを使います

なお、引数なしのデフォルトコンストラクタは必須なようです。

src/main/java/sample/models/InventoryItem.java
package sample.models;

import org.axonframework.commandhandling.CommandHandler;
import org.axonframework.commandhandling.model.AggregateIdentifier;
import org.axonframework.commandhandling.model.AggregateLifecycle;
import org.axonframework.commandhandling.model.ApplyMore;

import org.axonframework.eventsourcing.EventSourcingHandler;

import lombok.Getter;
import lombok.ToString;

import sample.commands.CreateInventoryItem;
import sample.commands.CheckInItemsToInventory;
import sample.events.InventoryItemCreated;
import sample.events.InventoryItemRenamed;
import sample.events.ItemsCheckedInToInventory;

@ToString
public class InventoryItem {

    @AggregateIdentifier
    @Getter
    private String id;

    @Getter
    private String name;

    @Getter
    private int count;

    // デフォルトコンストラクタは必須
    public InventoryItem() {
    }

    // 在庫作成コマンド処理
    @CommandHandler
    public InventoryItem(CreateInventoryItem cmd) {
        System.out.println("C call new: " + cmd);
        // 在庫作成イベントの作成と適用
        AggregateLifecycle.apply(new InventoryItemCreated(cmd.getId()));
        // 名前変更イベントの作成と適用
        AggregateLifecycle.apply(new InventoryItemRenamed(cmd.getName()));
    }

    // 在庫数追加コマンド処理
    @CommandHandler
    private ApplyMore updateCount(CheckInItemsToInventory cmd) {
        System.out.println("C call updateCount: " + cmd);
        // 在庫数追加イベントの作成と適用
        return AggregateLifecycle.apply(new ItemsCheckedInToInventory(cmd.getCount()));
    }

    // 在庫作成イベントの適用処理
    @EventSourcingHandler
    private void applyCreated(InventoryItemCreated event) {
        System.out.println("E call applyCreated: " + event);

        this.id = event.getId();
    }

    // 名前変更イベントの適用処理
    @EventSourcingHandler
    private void applyRenamed(InventoryItemRenamed event) {
        System.out.println("E call applyRenamed: " + event);

        this.name = event.getNewName();
    }

    // 在庫数追加イベントの適用処理
    @EventSourcingHandler
    private void applyCheckedIn(ItemsCheckedInToInventory event) {
        System.out.println("E call applyCheckedIn: " + event);

        this.count += event.getCount();
    }
}

実行クラスの作成

動作確認のための実行クラスを作成します。

CommandGateway へコマンドを send すれば処理が流れるように CommandBus・Repository・EventStore 等を組み合わせます。

イベントソーシングには EventSourcingRepository を使用します。 アノテーションを使ったコマンドハンドラを適用するには AggregateAnnotationCommandHandler を使用します。

また、今回はインメモリでイベントを保持する InMemoryEventStorageEngine を使っています。

src/main/java/SampleApp.java
import org.axonframework.commandhandling.AggregateAnnotationCommandHandler;
import org.axonframework.commandhandling.SimpleCommandBus;
import org.axonframework.commandhandling.gateway.DefaultCommandGateway;

import org.axonframework.eventsourcing.EventSourcedAggregate;
import org.axonframework.eventsourcing.EventSourcingRepository;
import org.axonframework.eventsourcing.eventstore.EmbeddedEventStore;
import org.axonframework.eventsourcing.eventstore.inmemory.InMemoryEventStorageEngine;

import lombok.val;

import sample.commands.CreateInventoryItem;
import sample.commands.CheckInItemsToInventory;
import sample.models.InventoryItem;

public class SampleApp {
    public static void main(String... args) {

        val cmdBus = new SimpleCommandBus();
        val gateway = new DefaultCommandGateway(cmdBus);

        val es = new EmbeddedEventStore(new InMemoryEventStorageEngine());

        // イベントソーシング用の Repository
        val repository = new EventSourcingRepository<>(InventoryItem.class, es);

        // アノテーションによるコマンドハンドラを適用
        new AggregateAnnotationCommandHandler<>(InventoryItem.class, 
                                                repository).subscribe(cmdBus);

        String r1 = gateway.sendAndWait(new CreateInventoryItem("s1", "sample1"));
        System.out.println("id: " + r1);

        System.out.println("----------");

        EventSourcedAggregate<InventoryItem> r2 = 
            gateway.sendAndWait(new CheckInItemsToInventory("s1", 5));

        printAggregate(r2);

        System.out.println("----------");

        EventSourcedAggregate<InventoryItem> r3 = 
            gateway.sendAndWait(new CheckInItemsToInventory("s1", 3));

        printAggregate(r3);
    }

    private static void printAggregate(EventSourcedAggregate<InventoryItem> esag) {
        System.out.println(esag.getAggregateRoot());
    }
}

なお、CreateInventoryItem を send した後に、同じ ID (今回は "s1")を使って再度 CreateInventoryItem を send してみたところ、特に重複チェックなどが実施されるわけではなく、普通にイベント (InventoryItemCreated と InventoryItemRenamed) が追加されました。

実行

実行結果は以下の通りです。

コマンドを送信する度に、これまでのイベントを適用してエンティティの状態を復元した後、新しいコマンドを処理している動作を確認できました。

> gradle run

・・・
:run
C call new: CreateInventoryItem(id=s1, name=sample1)
E call applyCreated: InventoryItemCreated(id=s1)
E call applyRenamed: InventoryItemRenamed(newName=sample1)
id: s1
----------
E call applyCreated: InventoryItemCreated(id=s1)
E call applyRenamed: InventoryItemRenamed(newName=sample1)
C call updateCount: CheckInItemsToInventory(id=s1, count=5)
E call applyCheckedIn: ItemsCheckedInToInventory(count=5)
InventoryItem(id=s1, name=sample1, count=5)
----------
E call applyCreated: InventoryItemCreated(id=s1)
E call applyRenamed: InventoryItemRenamed(newName=sample1)
E call applyCheckedIn: ItemsCheckedInToInventory(count=5)
C call updateCount: CheckInItemsToInventory(id=s1, count=3)
E call applyCheckedIn: ItemsCheckedInToInventory(count=3)
InventoryItem(id=s1, name=sample1, count=8)

備考 - スナップショット

イベント数が多くなると、イベントを毎回初めから適用して状態を復元するのはパフォーマンス的に厳しくなるため、スナップショットを使います。

Axon Framework にもスナップショットの機能が用意されており、EventCountSnapshotterTrigger 等を EventSourcingRepository へ設定すれば使えるようです。

ただし、今回のサンプル (SampleApp.java) では EventCountSnapshotterTrigger を機能させられませんでした。

というのも、EventCountSnapshotterTrigger では decorateForAppend メソッドの実行時にスナップショット化を行うようですが ※、今回のサンプルでは decorateForAppend は一度も実行せず decorateForRead メソッドのみが実行されるようでした。

 ※ ソースをざっと見た限りでは、カウンターのカウントアップ等も
    decorateForAppend の中でのみ実施しているようだった