読者です 読者をやめる 読者になる 読者になる

Spring Web Reactive を試す

Java Reactive Spring

Spring 5 で導入される Spring Web Reactive を試してみました。

本来なら Spring Boot で実行する事になると思いますが、今回は Spring Boot を使わずに Undertow で直接実行してみます。

ソースは http://github.com/fits/try_samples/tree/master/blog/20161115/

サンプル作成

ビルド定義

現時点では Spring Web Reactive の正式版はリリースされていないようなのでスナップショット版を使います。

Undertow を実行するために undertow-coreJSON で結果を返す処理を試すために jackson-databind を依存関係へ設定しています。

build.gradle
apply plugin: 'application'

mainClassName = 'sample.App'

repositories {
    jcenter()

    maven {
        url 'http://repo.spring.io/snapshot/'
    }
}

dependencies {
    compile 'org.projectlombok:lombok:1.16.10'

    // Spring Web Reactive
    compile 'org.springframework:spring-web-reactive:5.0.0.BUILD-SNAPSHOT'
    // Undertow
    compile 'io.undertow:undertow-core:2.0.0.Alpha1'

    // JSON 用
    runtime 'com.fasterxml.jackson.core:jackson-databind:2.8.4'
}

設定クラス

@EnableWebReactive を付与すれば Spring Web Reactive を有効にできるようです。

src/main/java/sample/config/AppConfig.java
package sample.config;

import org.springframework.context.annotation.ComponentScan;
import org.springframework.web.reactive.config.EnableWebReactive;

@EnableWebReactive
@ComponentScan("sample.controller")
public class AppConfig {
}

コントローラークラス

Spring Web Reactive は Spring Web (MVC) のプログラミングスタイルを踏襲しているようです。

Spring Web (MVC) と同じアノテーションでコントローラークラスを定義し、メソッドの戻り値に Reactor の Mono / Flux を使えばよさそうです。

クラス 概要
Mono 単一の結果を返す場合に使用
Flux 複数の結果を返す場合に使用
src/main/java/sample/controller/SampleController.java
package sample.controller;

import lombok.Value;

import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

@RestController
public class SampleController {

    @RequestMapping("/")
    public Mono<String> sample() {
        return Mono.just("sample");
    }

    @RequestMapping(value = "/data/{name}", produces = "application/json")
    public Flux<Data> dataList(@PathVariable("name") String name) {
        return Flux.fromArray(new Data[] {
            new Data(name + "-1", 1),
            new Data(name + "-2", 2),
            new Data(name + "-3", 3)
        });
    }

    @Value
    class Data {
        private String name;
        private int value;
    }
}

実行クラス

まずは、Spring Web Reactive を有効化した ApplicationContext を使って HttpHandler を作成します。 (DispatcherHandlerWebHttpHandlerBuilder を使用)

あとは、対象の Web サーバー ※ に合わせて実行します。

 ※ 今のところ Servlet 3.1 対応コンテナ、Netty、Undertow を
    サポートしているようです

Undertow の場合、アダプタークラス UndertowHttpHandlerAdapter で HttpHandler をラッピングして実行するだけです。

src/main/java/sample/App.java
package sample;

import io.undertow.Undertow;

import lombok.val;

import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.http.server.reactive.UndertowHttpHandlerAdapter;
import org.springframework.web.reactive.DispatcherHandler;
import org.springframework.web.server.adapter.WebHttpHandlerBuilder;

import sample.config.AppConfig;

public class App {
    public static void main(String... args) {

        val ctx = new AnnotationConfigApplicationContext();
        // @EnableWebReactive を付与したクラスを登録
        ctx.register(AppConfig.class);
        ctx.refresh();

        val handler = new DispatcherHandler();
        handler.setApplicationContext(ctx);

        // HttpHandler の作成
        val httpHandler = WebHttpHandlerBuilder.webHandler(handler).build();

        val server = Undertow.builder()
                .addHttpListener(8080, "localhost")
                // Undertow 用アダプターでラッピングして設定
                .setHandler(new UndertowHttpHandlerAdapter(httpHandler))
                .build();

        server.start();
    }
}

実行

Gradle で実行
> gradle -q run

・・・
情報: Mapped "{[/data/{name}],produces=[application/json]}" onto public reactor.core.publisher.Flux<sample.controller.SampleController$Data> sample.controller.SampleController.dataList(java.lang.String)
11 14, 2016 9:32:23 午後 org.springframework.web.reactive.result.method.annotation.RequestMappingHandlerMapping register
情報: Mapped "{[/]}" onto public reactor.core.publisher.Mono<java.lang.String> sample.controller.SampleController.sample()
11 14, 2016 9:32:23 午後 org.xnio.Xnio <clinit>
INFO: XNIO version 3.3.6.Final
11 14, 2016 9:32:23 午後 org.xnio.nio.NioXnio <clinit>
INFO: XNIO NIO Implementation Version 3.3.6.Final
動作確認1
$ curl -s http://localhost:8080/
sample
動作確認2
$ curl -s http://localhost:8080/data/a
[{"name":"a-1","value":1},{"name":"a-2","value":2},{"name":"a-3","value":3}]

Java で行列の演算 - nd4j, commons-math, la4j, ujmp, jblas, colt

Java Matrix

Java で以下のような行列の演算を複数のライブラリで試しました。

  • (a) 和
  • (b) 積
  • (c) 転置

とりあえず今回は、更新日が比較的新しめのライブラリを試してみました。

また、あくまでも個人的な印象ですが、手軽に使いたいなら la4j か Commons Math、性能重視なら ND4J か jblas、可視化や DB との連携を考慮するなら UJMP を使えば良さそうです。

ソースは http://github.com/fits/try_samples/tree/master/blog/20161031/

ND4J

Deeplearning4J で使用しているライブラリ。

build.gradle
apply plugin: 'application'

mainClassName = 'SampleApp'

repositories {
    jcenter()
}

dependencies {
    compile 'org.nd4j:nd4j-native-platform:0.6.0'
    runtime 'org.slf4j:slf4j-nop:1.7.21'
}
src/main/java/SampleApp.java
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class SampleApp {
    public static void main(String... args) {
        INDArray x = Nd4j.create(new double[][] {
            {1, 2},
            {3, 4}
        });

        INDArray y = Nd4j.create(new double[][] {
            {5, 6},
            {7, 8}
        });
        // (a)
        System.out.println( x.add(y) );

        System.out.println("-----");
        // (b)
        System.out.println( x.mmul(y) );

        System.out.println("-----");
        // (c)
        System.out.println( x.transpose() );
    }
}
実行結果
> gradle -q run

[[ 6.00,  8.00],
 [10.00, 12.00]]
-----
[[19.00, 22.00],
 [43.00, 50.00]]
-----
[[1.00, 3.00],
 [2.00, 4.00]]

Commons Math

Apache Commons のライブラリ。

build.gradle
・・・
dependencies {
    compile 'org.apache.commons:commons-math3:3.6.1'
}
src/main/java/SampleApp.java
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;

public class SampleApp {
    public static void main(String... args) {
        RealMatrix x = MatrixUtils.createRealMatrix(new double[][] {
            {1, 2},
            {3, 4}
        });

        RealMatrix y = MatrixUtils.createRealMatrix(new double[][] {
            {5, 6},
            {7, 8}
        });
        // (a)
        System.out.println( x.add(y) );

        System.out.println("-----");
        // (b)
        System.out.println( x.multiply(y) );

        System.out.println("-----");
        // (c)
        System.out.println( x.transpose() );
    }
}
実行結果
> gradle -q run

Array2DRowRealMatrix{{6.0,8.0},{10.0,12.0}}
-----
Array2DRowRealMatrix{{19.0,22.0},{43.0,50.0}}
-----
Array2DRowRealMatrix{{1.0,3.0},{2.0,4.0}}

la4j

Java のみで実装された軽量なライブラリ。

build.gradle
・・・
dependencies {
    compile 'org.la4j:la4j:0.6.0'
}
src/main/java/SampleApp.java
import org.la4j.Matrix;

public class SampleApp {
    public static void main(String... args) {
        Matrix x = Matrix.from2DArray(new double[][] {
            {1, 2},
            {3, 4}
        });

        Matrix y = Matrix.from2DArray(new double[][] {
            {5, 6},
            {7, 8}
        });
        // (a)
        System.out.println( x.add(y) );

        System.out.println("-----");
        // (b)
        System.out.println( x.multiply(y) );

        System.out.println("-----");
        // (c)
        System.out.println( x.transpose() );
    }
}
実行結果
> gradle -q run

 6.000  8.000
10.000 12.000

-----
19.000 22.000
43.000 50.000

-----
1.000 3.000
2.000 4.000

UJMP

データ可視化や JDBC との連携など、機能豊富そうなライブラリ。 Colt や jblas 等ともプラグインモジュールで連携できる模様。

build.gradle
・・・
dependencies {
    compile 'org.ujmp:ujmp-core:0.3.0'
}
src/main/java/SampleApp.java
import org.ujmp.core.DenseMatrix;
import org.ujmp.core.Matrix;

public class SampleApp {
    public static void main(String... args) {
        Matrix x = DenseMatrix.Factory.linkToArray(
            new double[] {1, 2},
            new double[] {3, 4}
        );

        Matrix y = DenseMatrix.Factory.linkToArray(
            new double[] {5, 6},
            new double[] {7, 8}
        );
        // (a)
        System.out.println( x.plus(y) );

        System.out.println("-----");
        // (b)
        System.out.println( x.mtimes(y) );

        System.out.println("-----");
        // (c)
        System.out.println( x.transpose() );
    }
}
実行結果
> gradle -q run

    6.0000     8.0000
   10.0000    12.0000

-----
   19.0000    22.0000
   43.0000    50.0000

-----
    1.0000     3.0000
    2.0000     4.0000

jblas

BLAS/LAPACK をベースとしたライブラリ。 ネイティブライブラリを使用する。

build.gradle
・・・
dependencies {
    compile 'org.jblas:jblas:1.2.4'
}
src/main/java/SampleApp.java
import org.jblas.DoubleMatrix;

public class SampleApp {
    public static void main(String... args) {
        DoubleMatrix x = new DoubleMatrix(new double[][] {
            {1, 2},
            {3, 4}
        });

        DoubleMatrix y = new DoubleMatrix(new double[][] {
            {5, 6},
            {7, 8}
        });
        // (a)
        System.out.println( x.add(y) );

        System.out.println("-----");
        // (b)
        System.out.println( x.mmul(y) );

        System.out.println("-----");
        // (c)
        System.out.println( x.transpose() );
    }
}
実行結果
> gradle -q run

[6.000000, 8.000000; 10.000000, 12.000000]
-----
[19.000000, 22.000000; 43.000000, 50.000000]
-----
[1.000000, 3.000000; 2.000000, 4.000000]
-- org.jblas INFO Starting temp DLL cleanup task.
-- org.jblas INFO Deleted 4 unused temp DLL libraries from ・・・

Colt Blazegraph 版

Colt は長らく更新されていないようなので、今回は Blazegraph による fork 版? を使いました。

build.gradle
・・・
dependencies {
    compile 'com.blazegraph:colt:2.1.4'
}
src/main/java/SampleApp.java
import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.jet.math.Functions;

public class SampleApp {
    public static void main(String... args) {
        DoubleMatrix2D x = DoubleFactory2D.dense.make(new double[][] {
            {1, 2},
            {3, 4}
        });

        DoubleMatrix2D y = DoubleFactory2D.dense.make(new double[][] {
            {5, 6},
            {7, 8}
        });
        // (a)
        System.out.println( x.copy().assign(y, Functions.plus) );

        System.out.println("-----");

        Algebra algebra = new Algebra();
        // (b)
        System.out.println( algebra.mult(x, y) );

        System.out.println("-----");
        // (c)
        System.out.println( x.viewDice() );
    }
}

assign を使うと自身の値を更新するため copy を使っています。

実行結果
> gradle -q run

2 x 2 matrix
 6  8
10 12
-----
2 x 2 matrix
19 22
43 50
-----
2 x 2 matrix
1 3
2 4

Lucene API で Solr と Elasticsearch のインデックスを確認

Java Groovy Lucene

Groovy で LuceneAPI を使用して Solr や Elasticsearch のインデックスの内容を確認してみました。(Lucene 6.2.1 の API を使用)

ソースは http://github.com/fits/try_samples/tree/master/blog/20161024/

(a) ドキュメントの内容を出力

まずは、ドキュメントに属するフィールドの内容を出力する処理です。

DirectoryReader から Document を取得し、フィールド IndexableField の内容を出力しています。

dump_docs.groovy
@Grab('org.apache.lucene:lucene-core:6.2.1')
import org.apache.lucene.index.DirectoryReader
import org.apache.lucene.store.FSDirectory
import java.nio.file.Paths

def dir = FSDirectory.open(Paths.get(args[0]))

DirectoryReader.open(dir).withCloseable { reader ->
    println "numDocs = ${reader.numDocs()}"

    (0..<reader.numDocs()).each {
        // ドキュメントの取得
        def doc = reader.document(it)

        println "---------- doc: ${it} ----------"

        // ドキュメント内のフィールドを出力
        doc.fields.each { f -> 
            def value = f.binaryValue()? f.binaryValue().utf8ToString(): f.stringValue()

            println "<field> name=${f.name}, value=${value}, class=${f.class}"
        }
    }
}

(b) Term の内容を出力

インデックス内のフィールド情報と Term を出力する処理です。 Term は基本的な検索の単位となっており、Term の内容を見れば単語の分割状況を確認できます。

これらの情報を取得するには LeafReader を使います。

Term の内容 BytesRefTermsEnum から取得できます。

LeafReader から terms メソッドで該当フィールドの Terms を取得し、iterator メソッドで TermsEnum を取得します。

dump_terms.groovy
@Grab('org.apache.lucene:lucene-core:6.2.1')
import org.apache.lucene.index.DirectoryReader
import org.apache.lucene.store.FSDirectory
import java.nio.file.Paths

def dir = FSDirectory.open(Paths.get(args[0]))

DirectoryReader.open(dir).withCloseable { reader ->

    reader.leaves().each { ctx ->
        // LeafReader の取得
        def leafReader = ctx.reader()

        println "---------- leaf: ${leafReader} ----------"

        // フィールド情報の出力
        leafReader.getFieldInfos().each { fi ->
            println "<fieldInfo> name: ${fi.name}, valueType: ${fi.docValuesType}, indexOptions: ${fi.indexOptions}"
        }

        leafReader.fields().each { name ->
            // 指定のフィールド名に対する TermsEnum を取得
            def termsEnum = leafReader.terms(name).iterator()

            println ''
            println "===== <term> name=${name} ====="

            try {
                while(termsEnum.next() != null) {
                    // Term の内容を出力
                    println "term=${termsEnum.term().utf8ToString()}, freq=${termsEnum.docFreq()}"
                }
            } catch(e) {
            }
        }
    }
}

動作確認

Lucene のバージョンが以下のようになっていたので、今回は Solr 6.2.1 と Elasticsearch 5.0.0 RC1 のインデックス内容を確認してみます。

プロダクト 使用している Lucene のバージョン
Solr 6.2.1 Lucene 6.2.1
Elasticsearch 2.4.1 Lucene 5.5.2
Elasticsearch 5.0.0 RC1 Lucene 6.2.0

(1) Solr 6.2.1

Solr 6.2.1 のインデックスから確認してみます。

準備

インデックスを作成してドキュメントを登録しておきます。

1. Solr 起動とインデックスの作成

> solr start

・・・

> solr create -c sample

{
  "responseHeader":{
    "status":0,
    "QTime":9197},
  "core":"sample"}

2. スキーマの登録

schema.json
{
    "add-field": {
        "name": "title",
        "type": "string"
    },
    "add-field": {
        "name": "num",
        "type": "int"
    },
    "add-field": {
        "name": "rdate",
        "type": "date"
    }
}
スキーマ登録
$ curl -s http://localhost:8983/solr/sample/schema --data-binary @schema.json

{
  "responseHeader":{
    "status":0,
    "QTime":554}}

3. ドキュメントの登録

data1.json
{
    "title": "item1",
    "num": 11,
    "rdate": "2016-10-20T13:45:00Z"
}
ドキュメント登録
$ curl -s http://localhost:8983/solr/sample/update/json/docs --data-binary @data1.json

{"responseHeader":{"status":0,"QTime":199}}

ちなみに、コミットしなくてもインデックスファイルには反映されるようです。

インデックスの内容確認

それでは、インデックスの内容を確認します。

該当するインデックスのディレクトリ(例. C:\solr-6.2.1\server\solr\sample\data\index)を引数に指定して実行します。

(a) ドキュメントの内容
> groovy dump_docs.groovy C:\solr-6.2.1\server\solr\sample\data\index

numDocs = 1
---------- doc: 0 ----------
<field> name=title, value=item1, class=class org.apache.lucene.document.StoredField
<field> name=num, value=11, class=class org.apache.lucene.document.StoredField
<field> name=rdate, value=1476971100000, class=class org.apache.lucene.document.StoredField
<field> name=id, value=2b1080dd-0cd3-43c6-a3ff-ab618ad00113, class=class org.apache.lucene.document.StoredField
(b) Term の内容
> groovy dump_terms.groovy C:\solr-6.2.1\server\solr\sample\data\index

---------- leaf: _0(6.2.1):C1 ----------
<fieldInfo> name: title, valueType: SORTED, indexOptions: DOCS
<fieldInfo> name: _text_, valueType: NONE, indexOptions: DOCS_AND_FREQS_AND_POSITIONS
<fieldInfo> name: num, valueType: NUMERIC, indexOptions: DOCS
<fieldInfo> name: rdate, valueType: NUMERIC, indexOptions: DOCS
<fieldInfo> name: id, valueType: SORTED, indexOptions: DOCS
<fieldInfo> name: _version_, valueType: NUMERIC, indexOptions: DOCS

===== <term> name=_text_ =====
term=00, freq=1
term=0cd3, freq=1
term=11, freq=1
term=13, freq=1
term=1548710288432300032, freq=1
term=20, freq=1
term=2016, freq=1
term=2b1080dd, freq=1
term=43c6, freq=1
term=45, freq=1
term=a3ff, freq=1
term=ab618ad00113, freq=1
term=item1, freq=1
term=oct, freq=1
term=thu, freq=1
term=utc, freq=1

===== <term> name=_version_ =====
term= ?yTP   , freq=1

===== <term> name=id =====
term=2b1080dd-0cd3-43c6-a3ff-ab618ad00113, freq=1

===== <term> name=num =====
term=   , freq=1

===== <term> name=rdate =====
term=    *~Yn`, freq=1

===== <term> name=title =====
term=item1, freq=1

version・num・rdate の値が文字化けしているように見えますが、これは org.apache.lucene.util.LegacyNumericUtils.intToPrefixCoded() メソッド等で処理されてバイナリデータとなっているためです。

実際の値を復元するには LegacyNumericUtils.prefixCodedToInt() 等を適用する必要があるようです。

なお、LegacyNumericUtils クラスは Lucene 6.2.1 API で deprecated となっていますが、Solr は未だ使っているようです。

(2) Elasticsearch 5.0.0 RC1

次は Elasticsearch です。

準備

インデックスを作成してドキュメントを登録しておきます。

1. Elasticsearch 起動

> elasticsearch

・・・

2. インデックスの作成とスキーマ登録

schema.json
{
    "mappings": {
        "data": {
            "properties": {
                "title": {
                    "type": "string",
                    "index": "not_analyzed"
                },
                "num": { "type": "integer" },
                "rdate": { "type": "date" }
            }
        }
    }
}
インデックス作成とスキーマ登録
$ curl -s -XPUT http://localhost:9200/sample --data-binary @schema.json

{"acknowledged":true,"shards_acknowledged":true}

3. ドキュメントの登録

data1.json
{
    "title": "item1",
    "num": 11,
    "rdate": "2016-10-20T13:45:00Z"
}
ドキュメント登録
$ curl -s http://localhost:9200/sample/data --data-binary @data1.json

{"_index":"sample","_type":"data","_id":"AVfwTjEQFnFWQdd5V9p5","_version":1,"result":"created","_shards":{"total":2,"successful":1,"failed":0},"created":true}

なお、すぐにインデックスファイルへ反映されない場合は flush を実施します。

flush 例
$ curl -s http://localhost:9200/sample/_flush

インデックスの内容確認

インデックスの内容を確認します。

Elasticsearch の場合はデフォルトで複数の shard に分かれているため、ドキュメントを登録した shard を確認しておきます。

shard の確認
$ curl -s http://localhost:9200/_cat/shards/sample?v

index  shard prirep state      docs store ip        node
sample 1     p      STARTED       0  130b 127.0.0.1 iUp_FE_
・・・
sample 4     p      STARTED       1 3.8kb 127.0.0.1 iUp_FE_
sample 4     r      UNASSIGNED
sample 0     p      STARTED       0  130b 127.0.0.1 iUp_FE_
sample 0     r      UNASSIGNED

shard 4 にドキュメントが登録されています。

Elasticsearch のインデックスディレクトリは data/nodes/<ノード番号>/indices/<インデックスのuuid>/<shard番号>/index となっているようで、今回は data\nodes\0\indices\QBXMjcCFSWy26Gow1Y9ItQ\4\index でした。(インデックスの uuid は QBXMjcCFSWy26Gow1Y9ItQ)

(a) ドキュメントの内容
> groovy dump_docs.groovy C:\elasticsearch-5.0.0-rc1\data\nodes\0\indices\QBXMjcCFSWy26Gow1Y9ItQ\4\index

numDocs = 1
---------- doc: 0 ----------
<field> name=_source, value={
        "title": "item1",
        "num": 11,
        "rdate": "2016-10-20T13:45:00Z"
}
, class=class org.apache.lucene.document.StoredField
<field> name=_uid, value=data#AVfwTjEQFnFWQdd5V9p5, class=class org.apache.lucene.document.StoredField
(b) Term の内容
> groovy dump_terms.groovy C:\elasticsearch-5.0.0-rc1\data\nodes\0\indices\QBXMjcCFSWy26Gow1Y9ItQ\4\index

---------- leaf: _0(6.2.0):c1 ----------
<fieldInfo> name: _source, valueType: NONE, indexOptions: NONE
<fieldInfo> name: _type, valueType: SORTED_SET, indexOptions: DOCS
<fieldInfo> name: _uid, valueType: NONE, indexOptions: DOCS
<fieldInfo> name: _version, valueType: NUMERIC, indexOptions: NONE
<fieldInfo> name: title, valueType: SORTED_SET, indexOptions: DOCS
<fieldInfo> name: num, valueType: SORTED_NUMERIC, indexOptions: NONE
<fieldInfo> name: rdate, valueType: SORTED_NUMERIC, indexOptions: NONE
<fieldInfo> name: _all, valueType: NONE, indexOptions: DOCS_AND_FREQS_AND_POSITIONS
<fieldInfo> name: _field_names, valueType: NONE, indexOptions: DOCS

===== <term> name=_all =====
term=00z, freq=1
term=10, freq=1
term=11, freq=1
term=2016, freq=1
term=20t13, freq=1
term=45, freq=1
term=item1, freq=1

===== <term> name=_field_names =====
term=_all, freq=1
term=_source, freq=1
term=_type, freq=1
term=_uid, freq=1
term=_version, freq=1
term=num, freq=1
term=rdate, freq=1
term=title, freq=1

===== <term> name=_type =====
term=data, freq=1

===== <term> name=_uid =====
term=data#AVfwTjEQFnFWQdd5V9p5, freq=1

===== <term> name=title =====
term=item1, freq=1

Solr とは、かなり違った結果になっています。

Deeplearning4J で MNIST を分類

Deeplearning Java Kotlin

Deeplearning4J で iris を分類」 に続いて、畳み込みニューラルネットを使った MNIST の分類を試します。

Deeplearning4J のバージョンが上がって、@Grab を使った Groovy 上での実行が上手くいかなかったので、今回は Kotlin で実装し Gradle で実行します。

  • Gradle 3.1
  • Kotlin 1.0.4

ソースは http://github.com/fits/try_samples/tree/master/blog/20161010/

はじめに

今回は、Gradle のマルチプロジェクトを使って以下の処理をサブプロジェクトとしました。

Gradle のビルド定義は以下の通り。

build.gradle
def slf4jVer = '1.7.21'

buildscript {
    repositories {
        jcenter()
    }

    dependencies {
        classpath 'org.jetbrains.kotlin:kotlin-gradle-plugin:1.0.4'
    }
}
// サブプロジェクトの共通設定
subprojects {
    apply plugin: 'kotlin'
    apply plugin: 'application'

    repositories {
        jcenter()
    }

    dependencies {
        compile 'org.jetbrains.kotlin:kotlin-stdlib:1.0.4'

        compile('org.deeplearning4j:deeplearning4j-core:0.6.0') {
            // エラーの回避策 (Could not find javacpp-presets-${os.name}-${os.arch}.jar)
            exclude group: 'org.bytedeco', module: 'javacpp-presets'
        }

        runtime 'org.nd4j:nd4j-native-platform:0.6.0'
    }

    run {
        // 実行時引数
        if (project.hasProperty('args')) {
            args project.args.split(' ')
        }
    }
}
// (1) 畳み込みニューラルネットのモデル作成(JSON で出力)
project(':conv_model') {
    mainClassName = 'ConvModelKt'

    dependencies {
        runtime "org.slf4j:slf4j-nop:${slf4jVer}"
    }
}
// (2) 学習処理
project(':learn_mnist') {
    mainClassName = 'LearnMnistKt'

    dependencies {
        runtime "org.slf4j:slf4j-simple:${slf4jVer}"
    }
}
// (3) 評価処理
project(':eval_mnist') {
    mainClassName = 'EvalMnistKt'

    dependencies {
        runtime "org.slf4j:slf4j-nop:${slf4jVer}"
    }
}
settings.gradle
include 'conv_model', 'learn_mnist', 'eval_mnist'

ファイル構成は以下の通りです。

ファイル構成
  • build.gradle
  • settings.gradle
  • conv_model/src/main/kotlin/convModel.kt
  • learn_mnist/src/main/kotlin/learnMnist.kt
  • eval_mnist/src/main/kotlin/evalMnist.kt

(1) 畳み込みニューラルネットのモデル作成(JSON で出力)

畳み込みニューラルネットの構成情報を JSON 化して標準出力します。

MnistDataSetIterator の MNIST データセットに対して畳み込みニューラルネットを行うには InputType.convolutionalFlat(28, 28, 1)setInputType します。

conv_model/src/main/kotlin/convModel.kt
import org.deeplearning4j.nn.conf.NeuralNetConfiguration
import org.deeplearning4j.nn.conf.inputs.InputType
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer
import org.deeplearning4j.nn.conf.layers.OutputLayer
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer
import org.nd4j.linalg.lossfunctions.LossFunctions

fun main(args: Array<String>) {
    val builder = NeuralNetConfiguration.Builder()
        .iterations(3) // 3回繰り返し
        .list(
            // 畳み込み層
            ConvolutionLayer.Builder(5, 5)
                .nIn(1)
                .nOut(8)
                .padding(2, 2)
                .activation("relu")
                .build()
            ,
            // プーリング層(最大プーリング)
            SubsamplingLayer.Builder(
                SubsamplingLayer.PoolingType.MAX, intArrayOf(2, 2))
                .stride(2, 2)
                .build()
            ,
            // 畳み込み層
            ConvolutionLayer.Builder(5, 5)
                .nOut(16)
                .padding(1, 1)
                .activation("relu")
                .build()
            ,
            // プーリング層(最大プーリング)
            SubsamplingLayer.Builder(
                SubsamplingLayer.PoolingType.MAX, intArrayOf(3, 3))
                .stride(3, 3)
                .build()
            ,
            OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .nOut(10)
                .activation("softmax")
                .build()
        )
        // MNIST データセットに対する畳み込みニューラルネット用の入力タイプ設定
        .setInputType(InputType.convolutionalFlat(28, 28, 1))

    // JSON 化して出力
    println(builder.build().toJson())
}

(2) 学習処理

(1) で出力した JSON ファイルを使って学習を実施し、学習後のパラメータ(重み)を JSON ファイルへ出力します。

MNIST データセットには MnistDataSetIterator を使いました ※。

 ※ train 引数(今回使用したコンストラクタの第2引数)が true の場合に学習用、
    false の場合に評価用のデータセットとなります
learn_mnist/src/main/kotlin/learnMnist.kt
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.deeplearning4j.nn.conf.MultiLayerConfiguration
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.optimize.listeners.ScoreIterationListener

import org.nd4j.linalg.api.ndarray.NdArrayJSONWriter

import java.io.File

fun main(args: Array<String>) {
    // ニューラルネットのモデルファイル(JSON)
    val confJson = File(args[0]).readText()
    // パラメータ(重み)の出力ファイル名
    val destFile = args[1]

    val conf = MultiLayerConfiguration.fromJson(confJson)
    val network = MultiLayerNetwork(conf)

    network.init()
    // スコア(誤差)の出力
    network.setListeners(ScoreIterationListener())

    // MNIST 学習用データ(バッチサイズ 100)
    val trainData = MnistDataSetIterator(100, true, 0)
    // 学習
    network.fit(trainData)
    // 学習後のパラメータ(重み)を JSON ファイルへ出力
    NdArrayJSONWriter.write(network.params(), destFile)
}

(3) 評価処理

(1) と (2) の結果を使って評価します。

eval_mnist/src/main/kotlin/evalMnist.kt
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.deeplearning4j.nn.conf.MultiLayerConfiguration
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork

import org.nd4j.linalg.api.ndarray.NdArrayJSONReader

import java.io.File

fun main(args: Array<String>) {
    // ニューラルネットのモデルファイル(JSON)
    val confJson = File(args[0]).readText()
    // パラメータ(重み)のロード
    val params = NdArrayJSONReader().read(File(args[1]))

    val conf = MultiLayerConfiguration.fromJson(confJson)
    val network = MultiLayerNetwork(conf, params)
    // MNIST 評価用データ
    val testData = MnistDataSetIterator(1, false, 0)
    // 評価
    val res = network.evaluate(testData)

    println(res.stats())
}

実行

まず、畳み込みニューラルネットの構成を JSON ファイルへ出力します。

(1) 畳み込みニューラルネットのモデル作成
> gradle -q :conv_model:run > conv_model.json

次に、学習を実施します。

org.reflections.Reflections - could not create Vfs.Dir from url という警告ログが出力されましたが、特に支障は無さそうです。

(2) 学習
> gradle -q :learn_mnist:run -Pargs="../conv_model.json ../conv_params.json"

・・・
[main] INFO org.deeplearning4j.optimize.listeners.ScoreIterationListener - Score at iteration 1790 is 0.07888245582580566
[main] INFO org.deeplearning4j.nn.multilayer.MultiLayerNetwork - Finetune phase
[main] INFO org.deeplearning4j.nn.multilayer.MultiLayerNetwork - Finetune phase
[main] INFO org.deeplearning4j.nn.multilayer.MultiLayerNetwork - Finetune phase

最後に、評価を実施します。

(3) 評価
> gradle -q :eval_mnist:run -Pargs="../conv_model.json ../conv_params.json"

Examples labeled as 0 classified by model as 0: 944 times
Examples labeled as 0 classified by model as 2: 7 times
Examples labeled as 0 classified by model as 5: 4 times
・・・
Examples labeled as 9 classified by model as 7: 24 times
Examples labeled as 9 classified by model as 8: 22 times
Examples labeled as 9 classified by model as 9: 937 times


==========================Scores========================================
 Accuracy:  0.9432
 Precision: 0.9463
 Recall:    0.9427
 F1 Score:  0.9445
========================================================================

Keras で MNIST を分類

Python DeepLearning Docker

Keras で iris を分類」 に続き、今回は Keras で畳み込みニューラルネットを使った MNIST の分類を試してみました。

ソースは http://github.com/fits/try_samples/tree/master/blog/20160920/

準備

Docker で実行するための Docker イメージを作成します。

Docker イメージ作成

/vagrant/work/keras/Dockerfile
FROM python

RUN apt-get update && apt-get upgrade -y

RUN pip install --upgrade pip

RUN pip install keras
RUN pip install h5py

RUN apt-get clean

h5py は Keras の save_model 関数を使うために必要でした。

今回のバージョンでは keras をインストールしても h5py は自動的にインストールされなかったので、別途インストールするようにしています。

上記を docker build して Docker イメージを作成しておきます。

Docker ビルド
$ cd /vagrant/work/keras
$ docker build --no-cache -t sample/py-keras:0.2 .

(1) MNIST データセットの取得

keras.datasets の mnist を使うと MNIST データセットを取得できます。(S3 からダウンロードするようになっている)

load_data 関数で取得したデータセット[[<学習用画像データ>, <学習用ラベルデータ>], [<評価用画像データ>, <評価用ラベルデータ>]] のような内容となっていましたが(画素の値は 0 ~ 255)、そのままでは今回の用途に使えなかったので numpy を使って変換しています。

/vagrant/work/mnist_helper.py
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils

# 学習用のデータセット取得
def train_mnist():
    return convert_mnist(mnist.load_data()[0])

# 評価用のデータセット取得
def test_mnist():
    return convert_mnist(mnist.load_data()[1])

def convert_mnist(tpl):
    # 画像データの加工
    features = tpl[0].reshape(tpl[0].shape[0], 1, 28, 28).astype(np.float32)
    features /= 255

    # ラベルデータの加工 (10種類の分類)
    labels = np_utils.to_categorical(tpl[1], 10)

    return (features, labels)

(2) 畳み込みニューラルネットモデル

畳み込みニューラルネットのモデルを作成してバイナリファイルとして保存する処理です。

畳み込みのレイヤー構成は 「ConvNetJS で MNIST を分類2」 と同じ様にしてみました。 (活性化関数は relu を使用)

/vagrant/work/create_layer_conv.py
import sys

from keras.models import Sequential, save_model
from keras.layers.core import Dense, Activation, Flatten
from keras.layers import Convolution2D, MaxPooling2D

model_dest_file = sys.argv[1]

model = Sequential()

# 1つ目の畳み込み層(5x5 で 8個出力)
model.add(Convolution2D(8, 5, 5, input_shape = (1, 28, 28)))
model.add(Activation('relu'))

# 1つ目のプーリング層 (最大プーリング)
model.add(MaxPooling2D(pool_size = (2, 2), strides = (2, 2)))

# 2つ目の畳み込み層(5x5 で 16個出力)
model.add(Convolution2D(16, 5, 5))
model.add(Activation('relu'))

# 2つ目のプーリング層 (最大プーリング)
model.add(MaxPooling2D(pool_size = (3, 3), strides = (3, 3)))

model.add(Flatten())

model.add(Dense(10))
model.add(Activation('softmax'))

model.compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics = ['accuracy'])

# モデルの保存
save_model(model, model_dest_file)

(3) 学習処理

学習処理は以下の通りです。

学習用の MNIST データセットを使って fit を実行します。

/vagrant/work/learn_mnist.py
import sys

from keras.models import save_model, load_model
from mnist_helper import train_mnist

epoch = int(sys.argv[1])
mini_batch = int(sys.argv[2])

model_file = sys.argv[3]
model_dest_file = sys.argv[4]

# モデルの読み込み
model = load_model(model_file)

# 学習用 MNIST データセット取得
(x_train, y_train) = train_mnist()

# 学習
model.fit(x_train, y_train, nb_epoch = epoch, batch_size = mini_batch)

# 学習後のモデルを保存
save_model(model, model_dest_file)

(4) 評価処理

評価処理は以下の通りです。

評価用の MNIST データセットを使って evaluate を実行します。 verbose を 0 にすれば途中経過を出力しなくなるようです。

/vagrant/work/eval_mnist.py
import sys

from keras.models import load_model
from mnist_helper import test_mnist

model_file = sys.argv[1]

model = load_model(model_file)

# 評価用 MNIST データセット取得
(x_test, y_test) = test_mnist()

# 評価
(loss, acc) = model.evaluate(x_test, y_test, verbose = 0)

print("loss = %f, accuracy = %f" % (loss, acc))

実行

まずは、作成した Docker イメージ(py-keras)を使って Docker コンテナを起動します。

Docker コンテナ起動
$ docker run --rm -it -v /vagrant/work:/work sample/py-keras:0.2 bash

# cd /work

起動した Docker コンテナで、畳み込みニューラルネットのモデルを作成してファイルへ保存します。

1. モデル作成
# python create_layer_conv.py 1.model

Using Theano backend.

保存したファイルを使って学習を行います。 今回はミニバッチサイズ 200 で 3 回繰り返してみます。

初回実行時は MNIST データセットのダウンロードが行われます。

2. 学習
# python learn_mnist.py 3 200 1.model 1a.model

Using Theano backend.
Downloading data from https://s3.amazonaws.com/img-datasets/mnist.pkl.gz
・・・
Epoch 1/3
60000/60000 [==============================] - 116s - loss: 0.7228 - acc: 0.7931
Epoch 2/3
60000/60000 [==============================] - 116s - loss: 0.1855 - acc: 0.9458
Epoch 3/3
60000/60000 [==============================] - 115s - loss: 0.1352 - acc: 0.9591

最後に、学習後のモデルを使って評価用のデータセットを評価します。

3. 評価
# python eval_mnist.py 1a.model

Using Theano backend.
loss = 0.102997, accuracy = 0.968600

Gradle で ScalaPB を使う

Scala protobuf

前回と同様の処理を ScalaPB で行ってみました。

ScalaPB であればビルドツールに sbt を使う方が簡単かもしれませんが、引き続き Gradle を使います。

今回作成したソースは http://github.com/fits/try_samples/tree/master/blog/20160905/

proto ファイル

前回と同じファイルですが、ファイル名に - を含むと都合が悪いようなので ※ ファイル名だけ変えています。

※ ScalaPB 0.5.40 では、デフォルトで proto ファイル名が
   そのままパッケージ名の一部となりました
   (パッケージ名は <java_package オプションの値>.<protoファイル名> )

ちなみに、java_outer_classname のオプション設定は無視されるようです。

proto/addressbook.proto (.proto ファイル)
syntax = "proto3";

package sample;

option java_package = "sample.model";
option java_outer_classname = "AddressBookProtos";

message Person {
  string name = 1;
  int32 id = 2;
  string email = 3;

  enum PhoneType {
    MOBILE = 0;
    HOME = 1;
    WORK = 2;
  }

  message PhoneNumber {
    string number = 1;
    PhoneType type = 2;
  }

  repeated PhoneNumber phone = 4;
}

message AddressBook {
  repeated Person person = 1;
}

Gradle ビルド定義

基本的な構成は前回と同じですが、ScalaPBC を実行してソースを生成する等、Scala 用に変えています。

build.gradle
apply plugin: 'scala'
apply plugin: 'application'

// protoc によるソースの自動生成先
def protoDestDir = 'src/main/protoc-generated'
// proto ファイル名
def protoFile = 'proto/addressbook.proto'

mainClassName = 'SampleApp'

repositories {
    jcenter()
}

configurations {
    scalapbc
}

dependencies {
    scalapbc 'com.trueaccord.scalapb:scalapbc_2.11:0.5.40'

    compile 'org.scala-lang:scala-library:2.11.8'
    compile 'com.trueaccord.scalapb:scalapb-runtime_2.11:0.5.40'
}

task scalapbc << {
    mkdir(protoDestDir)

    javaexec {
        main = 'com.trueaccord.scalapb.ScalaPBC'
        classpath = configurations.scalapbc
        args = [ protoFile, "--scala_out=${protoDestDir}" ]
    }
}

compileScala {
    dependsOn scalapbc
    source protoDestDir
}

clean {
    delete protoDestDir
}

サンプルアプリケーション

こちらも前回と同じ処理内容ですが、ScalaPB 用の実装となっています。

src/main/scala/SampleApp.scala
import java.io.ByteArrayOutputStream

import sample.model.addressbook.Person
import Person.PhoneNumber
import Person.PhoneType._

object SampleApp extends App {

    val phone = PhoneNumber("000-1234-5678", HOME)
    val person = Person(name = "sample1", phone = Seq(phone))

    println(person)

    val output = new ByteArrayOutputStream()

    try {
        person.writeTo(output)

        println("----------")

        val restoredPerson = Person.parseFrom(output.toByteArray)

        println(restoredPerson)

    } finally {
        output.close
    }
}

ビルドと実行

ビルドと実行の結果は以下の通りです。

前回と違って今回のビルド(scalapbc の実施)には python コマンドが必要でした。※

※ python コマンドを呼び出せるように環境変数 PATH 等を設定しておきます
   今回は Python 2.7 を使用しました
実行結果
> gradle run

・・・
:scalapbc
protoc-jar: protoc version: 300, detected platform: windows 10/amd64
protoc-jar: executing: [・・・\Temp\protoc8428481850206377506.exe, --plugin=protoc-gen-scala=・・・\Temp\protocbridge9000836851429371052.bat, proto/addressbook.proto, --scala_out=src/main/protoc-generated]
:compileScala
・・・
:run
name: "sample1"
phone {
  number: "000-1234-5678"
  type: HOME
}

----------
name: "sample1"
phone {
  number: "000-1234-5678"
  type: HOME
}


BUILD SUCCESSFUL

scalapbc タスクの実行によって以下のようなソースが生成されました。

  • src/main/protoc-generated/sample/model/addressbook/AddressBook.scala
  • src/main/protoc-generated/sample/model/addressbook/AddressbookProto.scala
  • src/main/protoc-generated/sample/model/addressbook/Person.scala

java_package オプションの設定値は反映されていますが、java_outer_classname オプション設定の方は無視されているようです。

Gradle で Protocol Buffers を使う - Java

Java protobuf

Gradle を使って Protocol Buffers の protoc で Java ソースコードを生成し、ビルドしてみます。

Gradle から protoc コマンドを呼び出す方法もありますが、今回は protoc-jar を使いました。

protoc-jar を使うと、プラットフォームの環境に応じた protoc コマンドを TEMP ディレクトリへ一時的に生成して実行してくれます。

今回作成したソースは http://github.com/fits/try_samples/tree/master/blog/20160829/

proto ファイル

今回は以下の proto ファイル(version 3)を使います。

Protocol Buffers では、proto ファイルを protoc コマンドで処理する事で任意のプログラム言語のソースコードを自動生成します。

java_packagejava_outer_classname オプションで Java ソースコードを生成した際のパッケージ名とクラス名をそれぞれ指定できます。

proto/address-book.proto (.proto ファイル)
syntax = "proto3";

package sample;

option java_package = "sample.model";
option java_outer_classname = "AddressBookProtos";

message Person {
  string name = 1;
  int32 id = 2;
  string email = 3;

  enum PhoneType {
    MOBILE = 0;
    HOME = 1;
    WORK = 2;
  }

  message PhoneNumber {
    string number = 1;
    PhoneType type = 2;
  }

  repeated PhoneNumber phone = 4;
}

message AddressBook {
  repeated Person person = 1;
}

Gradle ビルド定義

compileJava タスクの実行前に com.github.os72.protocjar.Protoc を実行して src/main/protoc-generated へソースを自動生成する protoc タスクを定義しました。

protoc-jar モジュールをクラスパスへ指定するため protoc 用の configurations を定義しています。

なお、今回は Java 用のソースコードを生成するため --java_out オプションを使っています。

build.gradle
apply plugin: 'application'

// protoc によるソースの自動生成先
def protoDestDir = 'src/main/protoc-generated'
// proto ファイル名
def protoFile = 'proto/address-book.proto'

mainClassName = 'SampleApp'

repositories {
    jcenter()
}

configurations {
    protoc
}

dependencies {
    protoc 'com.github.os72:protoc-jar:3.0.0'

    compileOnly 'org.projectlombok:lombok:1.16.10'

    compile 'com.google.protobuf:protobuf-java:3.0.0'
}
// protoc の実行タスク
task protoc << {
    mkdir(protoDestDir)

    // protoc の実行
    javaexec {
        main = 'com.github.os72.protocjar.Protoc'
        classpath = configurations.protoc
        args = [ protoFile, "--java_out=${protoDestDir}" ]
    }
}

compileJava {
    dependsOn protoc
    source protoDestDir
}

clean {
    delete protoDestDir
}

サンプルアプリケーション

protoc で自動生成したクラスを動作確認するための簡単なサンプルを用意しました。

src/main/java/SampleApp.java
import static sample.model.AddressBookProtos.Person.PhoneType.*;

import lombok.val;

import java.io.ByteArrayOutputStream;

import sample.model.AddressBookProtos.Person;
import sample.model.AddressBookProtos.Person.PhoneNumber;

class SampleApp {
    public static void main(String... args) throws Exception {

        val phone = PhoneNumber.newBuilder()
                        .setNumber("000-1234-5678")
                        .setType(HOME)
                        .build();

        val person = Person.newBuilder()
                        .setName("sample1")
                        .addPhone(phone)
                        .build();

        System.out.println(person);

        try (val output = new ByteArrayOutputStream()) {
            // シリアライズ処理(バイト配列化)
            person.writeTo(output);

            System.out.println("----------");

            // デシリアライズ処理(バイト配列から復元)
            val restoredPerson = Person.newBuilder()
                                    .mergeFrom(output.toByteArray())
                                    .build();

            System.out.println(restoredPerson);
        }
    }
}

ビルドと実行

ビルドと実行の結果は以下の通りです。

実行結果
> gradle run

:protoc
protoc-jar: protoc version: 300, detected platform: windows 10/amd64
protoc-jar: executing: [・・・\Local\Temp\protoc3178938487369694690.exe, proto/address-book.proto, --java_out=src/main/protoc-generated]
:compileJava
・・・
:run
name: "sample1"
phone {
  number: "000-1234-5678"
  type: HOME
}

----------
name: "sample1"
phone {
  number: "000-1234-5678"
  type: HOME
}


BUILD SUCCESSFUL

なお、protoc で以下のようなコードが生成されました。

src/main/protoc-generated/sample/model/AddressBookProtos.java
package sample.model;

public final class AddressBookProtos {
  private AddressBookProtos() {}
  ・・・
  public  static final class Person extends
      com.google.protobuf.GeneratedMessageV3 implements
      // @@protoc_insertion_point(message_implements:sample.Person)
      PersonOrBuilder {
    ・・・
    public enum PhoneType
        implements com.google.protobuf.ProtocolMessageEnum {
      /**
       * <code>MOBILE = 0;</code>
       */
      MOBILE(0),
      /**
       * <code>HOME = 1;</code>
       */
      HOME(1),
      /**
       * <code>WORK = 2;</code>
       */
      WORK(2),
      UNRECOGNIZED(-1),
      ;
      ・・・
    }
    ・・・
    public  static final class PhoneNumber extends
        com.google.protobuf.GeneratedMessageV3 implements
        // @@protoc_insertion_point(message_implements:sample.Person.PhoneNumber)
        PhoneNumberOrBuilder {
      ・・・
    }
    ・・・
  }
  ・・・
}