Keras.js によるランドマーク検出の Web アプリケーション化

前回の 「CNN でランドマーク検出」 の学習済みモデルを Keras.js を使って Web ブラウザ上で実行できるようにしてみます。

ソースは http://github.com/fits/try_samples/tree/master/blog/20190331/

準備

npm で Keras.js をインストールします。

Keras.js インストール
> npm install --save keras-js

Keras.js に含まれている encoder.py スクリプトを使って、Python の Keras で学習したモデル(model/cnn_landmark_400.h5)を Keras.js 用に変換します。

モデルファイル(HDF5 形式)を Keras.js 用に変換
> python node_modules/keras-js/python/encoder.py model/cnn_landmark_400.h5

生成された .bin ファイル(model/cnn_landmark_400.bin)のパス(URL)を KerasJS.Model へ指定して使う事になります。

ついでに、webpack もインストールしておきます。(webpack コマンドを使うには webpack-cli も必要)

webpack インストール
> npm install --save-dev webpack webpack-cli

Web アプリケーション作成

今回、作成する Web アプリケーションのファイル構成は以下の通りです。

  • index.html
  • js/bundle_app.js
  • js/bundle_worker.js
  • model/cnn_landmark_400.bin

処理は全て Web ブラウザ上で実行するようにし、Keras.js の処理(今回のランドマーク検出)はそれなりに重いので Web Worker として実行します。

index.html の内容は以下の通りです。

index.html
<!DOCTYPE html>
<html>
<head>
    <meta charset="UTF-8">
    <style type="text/css">
        canvas {
            border: 1px solid;
        }

        canvas.dragging {
            border: 3px solid red;
        }

        table {
            text-align: center;
            border-collapse: collapse;
        }

        table th, table td {
            border: 1px solid;
            padding: 8px;
        }
    </style>
</head>
<body>
    <dialog id="load-dialog">loading model ...</dialog>
    <dialog id="detect-dialog">detecting landmarks ...</dialog>
    <dialog id="error-dialog">ERROR</dialog>

    <div>
        <canvas width="256" height="256"></canvas>
    </div>

    <br>

    <div>
        <div id="landmarks"></div>
    </div>

    <script src="./js/bundle_app.js"></script>
</body>
</html>

bundle_xxx.js を生成するため、以下のような webpack 設定ファイルを用意します。

fs: 'empty' の箇所は Keras.js を webpack で処理するために必要な設定で、これが無いと Module not found: Error: Can't resolve 'fs' のようなエラーが出る事になります。

webpack.config.js
module.exports = {
    entry: {
        bundle_app: __dirname + '/src/app.js',
        bundle_worker: __dirname + '/src/worker.js'
    },
    output: {
        path: __dirname + '/js',
        filename: '[name].js',
    },
    // Keras.js 用の設定
    node: {
        fs: 'empty'
    }
}

Web Worker では Actor モデルのようにメッセージパッシング(postMessage で送信、onmessage で受信)を使ってメインの UI 処理とのデータ連携を行います。

今回は、以下のようなメッセージを Web Worker(Keras.js を使った処理)とやり取りするようにします。

Web Worker(Keras.js の処理)とのメッセージ内容
処理 送信メッセージ 受信メッセージ(成功時) 受信メッセージ(エラー時)
初期化 {type: 'init', url: <モデルファイルのURL>} {type: 'init'} {type: 'init', error: <エラーメッセージ>}
ランドマーク検出 {type: 'predict', input: <入力データ>} {type: 'predict', output: <ランドマーク検出結果>} {type: 'predict', error: <エラーメッセージ>}

(a) ランドマーク検出処理(src/worker.js)

Web Worker として実装するため、postMessage で UI 処理へメッセージを送信し、onmessage でメッセージを受信するようにします。

Keras.js 1.0.3 における Dense 処理の問題

実は、Keras.js 1.0.3 では今回の CNN モデルを正しく処理できません。

というのも、Keras.js 1.0.3 における Dense の処理(GPU を使わない場合)は以下のようになっています。

node_modules/keras-js/lib/layers/core/Dense.js の問題個所
  _callCPU(x) {
    this.output = new _Tensor.default([], [this.units]);

    ・・・
  }

今回の CNN モデルでは Dense の結果が 3次元 (256, 256, 7) になる必要がありますが、上記 Dense 処理では (7) のように 1次元になってしまい正しい結果を得られません。 ※

 ※ ついでに、Keras.js の softmax 処理にも不都合な点があった

そこで、今回は(GPU を使わない事を前提に)Dense_callCPU を実行時に書き換える事で回避しました。

処理内容としては、元の処理を 2重ループ内で実施するようにしています。

Dense 問題の回避措置(src/worker.js)
import KerasJS from 'keras-js'
import { gemv } from 'ndarray-blas-level2'
import ops from 'ndarray-ops'

・・・

// Dense の _callCPU を実行時に変更
KerasJS.layers.Dense.prototype._callCPU = function(x) {
    const h = x.tensor.shape[0]
    const w = x.tensor.shape[1]

    this.output = new KerasJS.Tensor([], [h, w, this.units])

    for (let i = 0; i < h; i++) {
        for (let j = 0; j < w; j++) {

            const xt = x.tensor.pick(i, j)
            const ot = this.output.tensor.pick(i, j)

            if (this.use_bias) {
                ops.assign(ot, this.weights['bias'].tensor)
            }

            gemv(1, this.weights['kernel'].tensor.transpose(1, 0), xt, 1, ot)

            this.activationFunc({tensor: ot})
        }
    }
}

ランドマーク検出の実装

KerasJS.Modelpredict へ入力データを渡したり結果を取り出すにはレイヤー名を指定する必要があり、これらのレイヤー名は iuputLayerNamesoutputLayerNames でそれぞれ取得できます。

predict の結果は、各座標のランドマーク該当確率 (256, 256, 7) となるので、ここではランドマーク毎 ※ に最も確率の高かった座標のみを結果として返すようにしています。

 ※ ランドマーク 0 はランドマークに該当しなかった場合なので結果に含めていない
src/worker.js
import KerasJS from 'keras-js'
import { gemv } from 'ndarray-blas-level2'
import ops from 'ndarray-ops'

let model = null

// モデルデータの読み込み
const loadModel = file => {
    const model = new KerasJS.Model({ filepath: file })

    return model.ready().then(r => model)
}

// Keras.js の Dense 問題への対応
KerasJS.layers.Dense.prototype._callCPU = function(x) {
    ・・・
}

// predict の結果を処理(ランドマーク毎に最も確率の高い座標を抽出)
const detectLandmarks = ts => {
    const res = {}

    for (let h = 0; h < ts.tensor.shape[0]; h++) {
        for (let w = 0; w < ts.tensor.shape[1]; w++) {
            const t = ts.tensor.pick(h, w)

            const wrkProb = {landmark: 0, prob: 0, x: w, y: h}

            for (let c = 0; c < t.shape[0]; c++) {
                const prob = t.get(c)

                if (prob > wrkProb.prob) {
                    wrkProb.landmark = c
                    wrkProb.prob = prob
                }
            }
            // ランドマーク 0 (ランドマークでは無い)は除外
            if (wrkProb.landmark > 0) {
                const curProb = res[wrkProb.landmark]

                if (!curProb || curProb.prob < wrkProb.prob) {
                    res[wrkProb.landmark] = wrkProb
                }
            }
        }
    }

    return res
}

// UI 処理からのメッセージ受信
onmessage = ev => {
    switch (ev.data.type) {
        case 'init':
            loadModel(ev.data.url)
                .then(m => {
                    model = m
                    postMessage({type: ev.data.type})
                })
                .catch(err => {
                    console.log(err)
                    postMessage({type: ev.data.type, error: err.message})
                })

            break
        case 'predict':
            const outputLayerName = model.outputLayerNames[0]

            const shape = model.modelLayersMap.get(outputLayerName)
                                                .output.tensor.shape

            const data = {}
            // 入力データの設定
            data[model.inputLayerNames[0]] = ev.data.input

            Promise.resolve(model.predict(data))
                .then(r => new KerasJS.Tensor(r[outputLayerName], shape)) // predict 実行結果の取り出し
                .then(detectLandmarks)
                .then(r => 
                    // UI 処理へ結果送信
                    postMessage({type: ev.data.type, output: r})
                )
                .catch(err => {
                    console.log(err)
                    postMessage({type: ev.data.type, error: err.message})
                })

            break
    }
}

(b) UI 処理(src/app.js)

画像データの変換(入力データの作成)

KerasJS.Model で predict するために、今回のケースでは画像データを 256(高さ)× 256(幅)× 3(RGB) サイズの 1次元配列 Float32Array へ変換する必要があります。

今回は以下のように canvas を利用して変換を行いました。

ImageData.data は RGBA 並びの 1次元配列 Uint8ClampedArray となっているので、RGB 部分のみを取り出して(A の内容を除外する)Float32Array を生成しています。

ちなみに、今回の CNN モデル自体は画像サイズに依存しない(Fully Convolutional Networks 的な)構成になっています。

そのため、任意サイズの画像を処理する事もできるのですが、現時点の Keras.js ではそんな事を考慮してくれていないので、実現するにはそれなりの工夫が必要になります。(一応、実現は可能でした)

ここでは、単純に canvas へ描画した 256x256 範囲の内容だけ(つまりは固定サイズ)を使うようにしています。※

 ※ この方法では 256x256 以外のサイズで欠けや余白の入り込みが発生する
画像データ変換部分(src/app.js)
・・・
const imageTypes = ['image/jpeg']

const canvas = document.getElementsByTagName('canvas')[0]
const ctx = canvas.getContext('2d')

・・・

// RGBA 並びの Uint8ClampedArray を RGB 並びの Float32Array へ変換
const imgToArray = imgData => new Float32Array(
    imgData.data.reduce(
        (acc, v, i) => {
            // RGBA の A 部分を除外
            if (i % 4 != 3) {
                acc.push(v)
            }
            return acc
        },
        []
     )
)
// 画像の読み込み
const loadImage = url => new Promise(resolve => {
    const img = new Image()

    img.addEventListener('load', () => {
        ctx.clearRect(0, 0, canvas.width, canvas.height)

        // 画像サイズが canvas よりも小さい場合の考慮
        const w = Math.min(img.width, canvas.width)
        const h = Math.min(img.height, canvas.height)

        // canvas へ画像を描画
        ctx.drawImage(img, 0, 0, w, h, 0, 0, w, h)

        // ImageData の取得
        const d = ctx.getImageData(0, 0, canvas.width, canvas.height)

        resolve(imgToArray(d))
    })

    img.src = url
})

・・・

// モデルデータ読み込み完了時の処理
const ready = () => {
    canvas.addEventListener('dragover', ev => {
        ev.preventDefault()
        canvas.classList.add('dragging')
    }, false)

    canvas.addEventListener('dragleave', ev => {
        canvas.classList.remove('dragging')
    }, false)

    // ドロップ時の処理
    canvas.addEventListener('drop', ev => {
        ev.preventDefault()
        canvas.classList.remove('dragging')

        const file = ev.dataTransfer.files[0]

        if (imageTypes.includes(file.type)) {
            ・・・
            const reader = new FileReader()

            reader.onload = ev => {
                loadImage(reader.result)
                    .then(img => {
                        ・・・
                    })
            }

            reader.readAsDataURL(file)
        }
    }, false)
}

・・・

Web Worker との連携

Web Worker とメッセージをやり取りし、ランドマークの検出結果を描画する部分の実装です。

Web Worker との連携部分(src/app.js)
const colors = ['rgb(255, 255, 255)', 'rgb(255, 0, 0)', 'rgb(0, 255, 0)', 'rgb(0, 0, 255)', 'rgb(255, 255, 0)', 'rgb(0, 255, 255)', 'rgb(255, 0, 255)']

const radius = 5
const imageTypes = ['image/jpeg']

const modelFile = '../model/cnn_landmark_400.bin'

// Web Worker の作成
const worker = new Worker('./js/bundle_worker.js')

・・・

// 検出したランドマークを canvas へ描画
const drawLandmarks = lms => {
    Object.values(lms).forEach(v => {
        ctx.fillStyle = colors[v.landmark]
        ctx.beginPath()
        ctx.arc(v.x, v.y, radius, 0, Math.PI * 2, false)
        ctx.fill()
    })
}

・・・

// 検出したランドマークの内容を table(HTML)化して表示
const showLandmarksInfo = lms => {
    ・・・

    infoNode.innerHTML = `
      <table>
        <tr>
          <th>landmark</th>
          <th>coordinate</th>
          <th>prob</th>
        </tr>
        ${rowsHtml}
      </table>
    `
}

// モデルデータ読み込み完了後
const ready = () => {
    ・・・
    canvas.addEventListener('drop', ev => {
        ・・・
        if (imageTypes.includes(file.type)) {
            ・・・
            reader.onload = ev => {
                loadImage(reader.result)
                    .then(img => {
                        detectDialog.showModal()
                        // Web Worker へのランドマーク検出指示
                        worker.postMessage({type: 'predict', input: img})
                    })
            }

            reader.readAsDataURL(file)
        }
    }, false)
}

// Web Worker からのメッセージ受信
worker.onmessage = ev => {
    if (ev.data.error) {
        ・・・
    }
    else {
        switch (ev.data.type) {
            case 'init':
                ready()

                loadDialog.close()

                break
            case 'predict':
                const res = ev.data.output

                console.log(res)
                detectDialog.close()

                drawLandmarks(res)
                showLandmarksInfo(res)

                break
        }
    }
}

loadDialog.showModal()
// Web Worker へのモデルデータ読み込み指示
worker.postMessage({type: 'init', url: modelFile})

(c) ビルド

webpack コマンドを実行し、js/bundle_app.js と js/bundle_worker.js を生成します。

webpack によるビルド
> webpack

(d) 動作確認

HTTP サーバーを使って動作確認を行います。 今回は http-server を使って実行しました。

http-server 実行
> http-server

Starting up http-server, serving ./
Available on:
  ・・・
  http://127.0.0.1:8080
Hit CTRL-C to stop the server

http://localhost:8080/index.htmlChrome ※ でアクセスして画像ファイルをドラッグアンドドロップすると以下のような結果となりました。

f:id:fits:20190331192210j:plain

 ※ HTMLDialogElement.showModal() を使っている関係で
    現時点では Chrome でしか動作しませんが、
    dialog 以外の部分(Keras.js の処理等)は
    Firefox でも動作するようになっています

CNN でランドマーク検出

前回の「CNNで輪郭の検出」 で試した手法を工夫し、ランドマーク(特徴点)検出へ適用してみました。

  • Keras + Tensorflow
  • Jupyter Notebook

ソースは http://github.com/fits/try_samples/tree/master/blog/20190217/

輪郭の検出では画像をピクセル単位で二値分類(輪郭以外 = 0, 輪郭 = 1)しましたが、今回はこれを多クラス分類(ランドマーク以外 = 0, ランドマーク1 = 1, ランドマーク2 = 2, ・・・)へ変更します。

ちなみに、Deeplearning でランドマーク検出を行うような場合、ランドマークの座標を直接予測するような手法が考えられますが、今回試してみた限りでは納得のいく結果(座標の精度や汎用性など)を出せなくて、代わりに思いついたのが今回の手法となっています。

はじめに

データセット

今回は、DeepFashion: In-shop Clothes Retrieval のランドマーク用データセットから以下の条件を満たすものに限定して使います。

  • clothes_type の値が 1 (upper-body clothes)
  • variation_type の値が 1 (normal pose)
  • landmark_visibility_1 ~ 6 の値が 0(visible)

ランドマークには 6種類 (landmark_location_x_1 ~ 6、landmark_location_y_1 ~ 6) の座標を使います。

教師データ

入力データには画像を使うため、データ形状は (<バッチサイズ>, 256, 256, 3) ※ となります。

 ※ (<バッチサイズ>, <高さ>, <幅>, <チャンネル数>)

ラベルデータは landmark_location 1 ~ 6 の値を元に動的に生成します。

ピクセル単位でランドマーク以外(= 0)とランドマーク 1 ~ 6 の多クラス分類を行うため、データ形状は (<バッチサイズ>, 256, 256, 7) とします。

ランドマーク毎に 1ピクセルだけランドマークへ分類しても上手く学習できないので ※、一定の大きさ(範囲)をランドマークへ分類する必要があります。

 ※ 全てをランドマーク以外(= 0)とするようになってしまう

そこで、ランドマーク周辺の一定範囲をランドマークへ分類するとともに、以下の図(中心がランドマーク)のようにランドマークから離れると確率値が下がるように工夫します。

f:id:fits:20190217023108p:plain

学習

学習処理は Jupyter Notebook 上で実行しました。

(1) 入力データの準備

まずは、list_landmarks_inshop.txt ファイルを読み込んで必要なデータを抜き出します。

今回は学習時間の短縮のため、先頭から 100件だけを使用しています。

データ読み込みとフィルタリング
import pandas as pd

df = pd.read_table('list_landmarks_inshop.txt', sep = '\s+', skiprows = 1)

s = 100

dfa = df[(df['clothes_type'] == 1) & (df['variation_type'] == 1) &
         (df['landmark_visibility_1'] == 0) & (df['landmark_visibility_2'] == 0) & 
         (df['landmark_visibility_3'] == 0) & (df['landmark_visibility_4'] == 0) &
         (df['landmark_visibility_5'] == 0) & (df['landmark_visibility_6'] == 0)][:s]

次に、入力データとして使う画像を読み込みます。

入力データ(画像)読み込み
import numpy as np
from keras.preprocessing.image import load_img, img_to_array

imgs = np.array([ img_to_array(load_img(f)) for f in dfa['image_name']])

入力データの形状は以下のようになります。

imgs.shape
(100, 256, 256, 3)

(2) ラベルデータの生成

先述したように landmark_location の値から得られたランドマーク座標の周辺に確率値を設定していきます。

ここでは、確率の構成内容や確率値の設定対象とする周辺座標の取得処理を引数で指定できるようにしてみました。

また、他のランドマークの範囲と重なった場合、今回は単純に上書き(後勝ち)するようにしましたが、確率値の大きい方を選択するか確率値を分配するようにした方が望ましいと思われます。

ラベルデータ作成処理
cols = [f'landmark_location_{t}_{i + 1}' for i in range(6) for t in ['x', 'y'] ]
labels_t = dfa[cols].values.astype(int)

def gen_labels(prob, around_func):
    res = np.zeros(imgs.shape[:-1] + (int(len(cols) / 2) + 1,))
    res[:, :, :, 0] = 1.0
    
    for i in range(len(res)):
        r = res[i]
        
        # ランドマーク毎の設定
        for j in range(0, len(labels_t[i]), 2):
            # ランドマークの座標
            x = labels_t[i, j]
            y = labels_t[i, j + 1]
            
            # ランドマークの分類(1 ~ 6)
            c = int(j / 2) + 1
            
            for k in range(len(prob)):
                p = prob[k]
                
                # (相対的な)周辺座標の取得
                for a in around_func(k):
                    ax = x + a[0]
                    ay = y + a[1]
                    
                    if ax >= 0 and ax < imgs.shape[2] and ay >= 0 and ay < imgs.shape[1]:
                        # 他のランドマークと範囲が重なった場合への対応(設定値のクリア)
                        r[ay, ax, :] = 0.0
                        
                        # ランドマーク c へ該当する確率
                        r[ay, ax, c] = p
                        # ランドマーク以外へ該当する確率
                        r[ay, ax, 0] = 1.0 - p

    return res

今回は以下のような内容でラベルデータを作りました。

ラベルデータ作成
def around_square(n):
    return [(x, y) for x in range(-n, n + 1) for y in range(-n, n + 1) if abs(x) == n or abs(y) == n]

labels = gen_labels([1.0, 1.0, 1.0, 0.8, 0.8, 0.7, 0.7, 0.6, 0.6, 0.5], around_square)

ラベルデータの形状は以下の通りです。

labels.shape
(100, 256, 256, 7)

ラベルデータの内容確認

ランドマーク周辺の値を見てみると以下のようになっており、問題無さそうです。

labels[0, 59, 105:126]
array([[1. , 0. , 0. , 0. , 0. , 0. , 0. ],
       [0.5, 0.5, 0. , 0. , 0. , 0. , 0. ],
       [0.4, 0.6, 0. , 0. , 0. , 0. , 0. ],
       [0.4, 0.6, 0. , 0. , 0. , 0. , 0. ],
       [0.3, 0.7, 0. , 0. , 0. , 0. , 0. ],
       [0.3, 0.7, 0. , 0. , 0. , 0. , 0. ],
       [0.2, 0.8, 0. , 0. , 0. , 0. , 0. ],
       [0.2, 0.8, 0. , 0. , 0. , 0. , 0. ],
       [0. , 1. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 1. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 1. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 1. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 1. , 0. , 0. , 0. , 0. , 0. ],
       [0.2, 0.8, 0. , 0. , 0. , 0. , 0. ],
       [0.2, 0.8, 0. , 0. , 0. , 0. , 0. ],
       [0.3, 0.7, 0. , 0. , 0. , 0. , 0. ],
       [0.3, 0.7, 0. , 0. , 0. , 0. , 0. ],
       [0.4, 0.6, 0. , 0. , 0. , 0. , 0. ],
       [0.4, 0.6, 0. , 0. , 0. , 0. , 0. ],
       [0.5, 0.5, 0. , 0. , 0. , 0. , 0. ],
       [1. , 0. , 0. , 0. , 0. , 0. , 0. ]])
labels[0, 50:71, 149]
array([[1. , 0. , 0. , 0. , 0. , 0. , 0. ],
       [0.5, 0. , 0.5, 0. , 0. , 0. , 0. ],
       [0.4, 0. , 0.6, 0. , 0. , 0. , 0. ],
       [0.4, 0. , 0.6, 0. , 0. , 0. , 0. ],
       [0.3, 0. , 0.7, 0. , 0. , 0. , 0. ],
       [0.3, 0. , 0.7, 0. , 0. , 0. , 0. ],
       [0.2, 0. , 0.8, 0. , 0. , 0. , 0. ],
       [0.2, 0. , 0.8, 0. , 0. , 0. , 0. ],
       [0. , 0. , 1. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 1. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 1. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 1. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 1. , 0. , 0. , 0. , 0. ],
       [0.2, 0. , 0.8, 0. , 0. , 0. , 0. ],
       [0.2, 0. , 0.8, 0. , 0. , 0. , 0. ],
       [0.3, 0. , 0.7, 0. , 0. , 0. , 0. ],
       [0.3, 0. , 0.7, 0. , 0. , 0. , 0. ],
       [0.4, 0. , 0.6, 0. , 0. , 0. , 0. ],
       [0.4, 0. , 0.6, 0. , 0. , 0. , 0. ],
       [0.5, 0. , 0.5, 0. , 0. , 0. , 0. ],
       [1. , 0. , 0. , 0. , 0. , 0. , 0. ]])

これだけだと分かり難いので、単純な可視化を行ってみます。(ランドマークの該当確率をピクセル毎に合計しているだけ)

ラベルデータの可視化処理
matplotlib inline

import matplotlib.pyplot as plt

def imshow_label(index):
    plt.imshow(labels[index, :, :, 1:].sum(axis = -1), cmap = 'gray')
imshow_label(0)

f:id:fits:20190217023149p:plain

imshow_label(1)

f:id:fits:20190217023204p:plain

特に問題は無さそうです。

(3) CNN モデル

前回 と同様に Encoder-Decoder の構成を採用し、Encoder・Decoder をそれぞれ 1段階深くしました。(4段階に縮小して拡大)

多クラス分類を行うために、出力層の活性化関数を softmax にして、損失関数を categorical_crossentropy としています。

モデル内容
from keras.models import Model
from keras.layers import Input, Dense, Dropout, UpSampling2D
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPool2D
from keras.layers.normalization import BatchNormalization

input = Input(shape = imgs.shape[1:])

x = input

x = BatchNormalization()(x)

x = Conv2D(16, 3, padding='same', activation = 'relu')(x)
x = Conv2D(16, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(128, 3, padding='same', activation = 'relu')(x)
x = Conv2D(128, 3, padding='same', activation = 'relu')(x)
x = Conv2D(128, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(256, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = UpSampling2D()(x)
x = Conv2DTranspose(128, 3, padding = 'same', activation = 'relu')(x)
x = Conv2DTranspose(128, 3, padding = 'same', activation = 'relu')(x)
x = Conv2DTranspose(128, 3, padding = 'same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = UpSampling2D()(x)
x = Conv2DTranspose(64, 3, padding = 'same', activation = 'relu')(x)
x = Conv2DTranspose(64, 3, padding = 'same', activation = 'relu')(x)
x = Conv2DTranspose(64, 3, padding = 'same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = UpSampling2D()(x)
x = Conv2DTranspose(32, 3, padding = 'same', activation = 'relu')(x)
x = Conv2DTranspose(32, 3, padding = 'same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = UpSampling2D()(x)
x = Conv2DTranspose(16, 3, padding = 'same', activation = 'relu')(x)
x = Conv2DTranspose(16, 3, padding = 'same', activation = 'relu')(x)

x = Dropout(0.3)(x)

output = Dense(labels.shape[-1], activation = 'softmax')(x)

model = Model(inputs = input, outputs = output)

model.compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics = ['acc'])

model.summary()
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_7 (InputLayer)         (None, 256, 256, 3)       0         
_________________________________________________________________
batch_normalization_46 (Batc (None, 256, 256, 3)       12        
_________________________________________________________________
conv2d_56 (Conv2D)           (None, 256, 256, 16)      448       
_________________________________________________________________
conv2d_57 (Conv2D)           (None, 256, 256, 16)      2320      
_________________________________________________________________
max_pooling2d_20 (MaxPooling (None, 128, 128, 16)      0         
_________________________________________________________________
batch_normalization_47 (Batc (None, 128, 128, 16)      64        
_________________________________________________________________
dropout_46 (Dropout)         (None, 128, 128, 16)      0         
_________________________________________________________________
conv2d_58 (Conv2D)           (None, 128, 128, 32)      4640      
_________________________________________________________________
conv2d_59 (Conv2D)           (None, 128, 128, 32)      9248      
_________________________________________________________________
conv2d_60 (Conv2D)           (None, 128, 128, 32)      9248      
_________________________________________________________________
max_pooling2d_21 (MaxPooling (None, 64, 64, 32)        0         
_________________________________________________________________
batch_normalization_48 (Batc (None, 64, 64, 32)        128       
_________________________________________________________________
dropout_47 (Dropout)         (None, 64, 64, 32)        0         
_________________________________________________________________
conv2d_61 (Conv2D)           (None, 64, 64, 64)        18496     
_________________________________________________________________
conv2d_62 (Conv2D)           (None, 64, 64, 64)        36928     
_________________________________________________________________
conv2d_63 (Conv2D)           (None, 64, 64, 64)        36928     
_________________________________________________________________
max_pooling2d_22 (MaxPooling (None, 32, 32, 64)        0         
_________________________________________________________________
batch_normalization_49 (Batc (None, 32, 32, 64)        256       
_________________________________________________________________
dropout_48 (Dropout)         (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_64 (Conv2D)           (None, 32, 32, 128)       73856     
_________________________________________________________________
conv2d_65 (Conv2D)           (None, 32, 32, 128)       147584    
_________________________________________________________________
conv2d_66 (Conv2D)           (None, 32, 32, 128)       147584    
_________________________________________________________________
max_pooling2d_23 (MaxPooling (None, 16, 16, 128)       0         
_________________________________________________________________
batch_normalization_50 (Batc (None, 16, 16, 128)       512       
_________________________________________________________________
dropout_49 (Dropout)         (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_67 (Conv2D)           (None, 16, 16, 256)       295168    
_________________________________________________________________
batch_normalization_51 (Batc (None, 16, 16, 256)       1024      
_________________________________________________________________
dropout_50 (Dropout)         (None, 16, 16, 256)       0         
_________________________________________________________________
up_sampling2d_20 (UpSampling (None, 32, 32, 256)       0         
_________________________________________________________________
conv2d_transpose_44 (Conv2DT (None, 32, 32, 128)       295040    
_________________________________________________________________
conv2d_transpose_45 (Conv2DT (None, 32, 32, 128)       147584    
_________________________________________________________________
conv2d_transpose_46 (Conv2DT (None, 32, 32, 128)       147584    
_________________________________________________________________
batch_normalization_52 (Batc (None, 32, 32, 128)       512       
_________________________________________________________________
dropout_51 (Dropout)         (None, 32, 32, 128)       0         
_________________________________________________________________
up_sampling2d_21 (UpSampling (None, 64, 64, 128)       0         
_________________________________________________________________
conv2d_transpose_47 (Conv2DT (None, 64, 64, 64)        73792     
_________________________________________________________________
conv2d_transpose_48 (Conv2DT (None, 64, 64, 64)        36928     
_________________________________________________________________
conv2d_transpose_49 (Conv2DT (None, 64, 64, 64)        36928     
_________________________________________________________________
batch_normalization_53 (Batc (None, 64, 64, 64)        256       
_________________________________________________________________
dropout_52 (Dropout)         (None, 64, 64, 64)        0         
_________________________________________________________________
up_sampling2d_22 (UpSampling (None, 128, 128, 64)      0         
_________________________________________________________________
conv2d_transpose_50 (Conv2DT (None, 128, 128, 32)      18464     
_________________________________________________________________
conv2d_transpose_51 (Conv2DT (None, 128, 128, 32)      9248      
_________________________________________________________________
batch_normalization_54 (Batc (None, 128, 128, 32)      128       
_________________________________________________________________
dropout_53 (Dropout)         (None, 128, 128, 32)      0         
_________________________________________________________________
up_sampling2d_23 (UpSampling (None, 256, 256, 32)      0         
_________________________________________________________________
conv2d_transpose_52 (Conv2DT (None, 256, 256, 16)      4624      
_________________________________________________________________
conv2d_transpose_53 (Conv2DT (None, 256, 256, 16)      2320      
_________________________________________________________________
dropout_54 (Dropout)         (None, 256, 256, 16)      0         
_________________________________________________________________
dense_13 (Dense)             (None, 256, 256, 7)       119       
=================================================================
Total params: 1,557,971
Trainable params: 1,556,525
Non-trainable params: 1,446
_________________________________________________________________

(4) 学習

教師データ 100件では少なすぎると思いますが、今回はその中の 80件のみ学習に使用して 20件を検証に使ってみます。(validation_split で指定)

ここで、ランドマークとそれ以外でデータ数に大きな偏りがあるため(ランドマーク以外が大多数)、そのままでは上手く学習できない恐れがあります。

以下では class_weight を使ってランドマーク分類の重みを大きくしています。

実行例(351 ~ 400 エポック)
# 分類毎の重みを定義(ランドマークは 256*256 に設定)
wg = np.ones(labels.shape[-1]) * (imgs.shape[1] * imgs.shape[2])
# ランドマーク以外(= 0)の重み設定
wg[0] = 1

hist = model.fit(imgs, labels, initial_epoch = 350, epochs = 400, batch_size = 10, class_weight = wg, validation_split = 0.2)
結果例
Train on 80 samples, validate on 20 samples
Epoch 351/400
80/80 [===・・・ - loss: 0.0261 - acc: 0.9924 - val_loss: 0.1644 - val_acc: 0.9782
Epoch 352/400
80/80 [===・・・ - loss: 0.0263 - acc: 0.9924 - val_loss: 0.1638 - val_acc: 0.9784
・・・
Epoch 399/400
80/80 [===・・・ - loss: 0.0255 - acc: 0.9930 - val_loss: 0.1719 - val_acc: 0.9775
Epoch 400/400
80/80 [===・・・ - loss: 0.0253 - acc: 0.9931 - val_loss: 0.1720 - val_acc: 0.9777

(5) 確認

fit の戻り値から学習・検証の lossacc の値をそれぞれグラフ化してみます。

fit 結果表示
%matplotlib inline

import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = (16, 4)

plt.subplot(1, 4, 1)
plt.plot(hist.history['loss'])

plt.subplot(1, 4, 2)
plt.plot(hist.history['acc'])

plt.subplot(1, 4, 3)
plt.plot(hist.history['val_loss'])

plt.subplot(1, 4, 4)
plt.plot(hist.history['val_acc'])
結果例(351 ~ 400 エポック)

f:id:fits:20190217023309p:plain

val_lossval_acc の値が良くないのは、データ量が少なすぎる点にあると考えています。

(6) ランドマーク検出

下記 4種類の画像を出力して、ランドマーク検出(predict)結果とラベルデータ(正解)を比較してみます。

  • (a) ラベルデータの分類(ピクセル毎に確率値が最大の分類で色分け)
  • (b) 予測結果(predict)の分類(ピクセル毎に確率値が最大の分類で色分け)
  • (c) 元画像と (b) の重ね合わせ
  • (d) ランドマークの描画(各ランドマークの確率値が最大の座標へ円を描画)

今回はランドマークは分類毎に 1点のみなので、確率が最大値の座標がランドマークと判断できます。

ランドマーク検出と結果出力
import cv2

colors = [(255, 255, 255), (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255), (255, 0, 255), (255, 165, 0), (210, 180, 140)]

def predict(index, n = 0, c_size = 5, s = 5.0):
    plt.rcParams['figure.figsize'] = (s * 4, s)
    
    img = imgs[index]

    # 予測結果(ランドマーク分類結果)
    p = model.predict(np.array([img]))[0]

    # (a) ラベルデータの分類(ピクセル毎に確率値が最大の分類で色分け)
    img1 = np.apply_along_axis(lambda x: colors[x.argmax()], -1, labels[index])
    # (b) 予測結果の分類(ピクセル毎に確率値が最大の分類で色分け)
    img2 = np.apply_along_axis(lambda x: colors[x.argmax()], -1, p)
    # (c) 元画像への重ね合わせ
    img3 = cv2.addWeighted(img.astype(int), 0.4, img2, 0.6, 0)
    
    plt.subplot(1, 4, 1)
    plt.imshow(img1)
    
    plt.subplot(1, 4, 2)
    plt.imshow(img2)
    
    plt.subplot(1, 4, 3)
    plt.imshow(img3)

    img4 = img.astype(int)

    pdf = pd.DataFrame(
        [[np.argmax(vx), x, y, np.max(vx)] for y, vy in enumerate(p) for x, vx in enumerate(vy)], 
        columns = ['landmark', 'x', 'y', 'prob']
    )
    
    for c, v in pdf[pdf['landmark'] > 0].sort_values('prob', ascending = False).groupby('landmark'):
        # (d) ランドマークを描画(確率値が最大の座標へ円を描画)
        img4 = cv2.circle(img4, tuple(v[['x', 'y']].values[0]), c_size, colors[c], -1)
        
        if n > 0:
            print(f"landmark {c} : x = {labels_t[index, (c - 1) * 2]}, {labels_t[index, (c - 1) * 2 + 1]}")
            print(v[:n])

    plt.subplot(1, 4, 4)
    plt.imshow(img4)

学習データの結果例

左から (a) ラベルデータの分類、(b) 予測結果の分類、(c) 元画像との重ね合わせ、(d) ランドマーク検出結果となっています。

f:id:fits:20190217035001p:plain f:id:fits:20190217035116p:plain

ラベルデータにかなり近い結果が出ているように見えます。

下記のように、ランドマーク毎の確率値 TOP 3 とラベルデータを数値で比較してみると、かなり近い値になっている事を確認できました。

predict(0, n = 3)
landmark 1 : x = 115, 59
       landmark    x   y      prob
15475         1  115  60  0.893763
15476         1  116  60  0.893605
15220         1  116  59  0.893044

landmark 2 : x = 149, 60
       landmark    x   y      prob
15510         2  150  60  0.878173
15766         2  150  61  0.872413
15509         2  149  60  0.872222

landmark 3 : x = 82, 153
       landmark   x    y      prob
39250         3  82  153  0.882741
39249         3  81  153  0.881362
39248         3  80  153  0.879979

landmark 4 : x = 185, 150
       landmark    x    y      prob
38841         4  185  151  0.836826
38585         4  185  150  0.836212
38840         4  184  151  0.836164

landmark 5 : x = 93, 198
       landmark   x    y      prob
50782         5  94  198  0.829380
50526         5  94  197  0.825815
51038         5  94  199  0.825342

landmark 6 : x = 171, 197
       landmark    x    y      prob
50602         6  170  197  0.881702
50603         6  171  197  0.880731
50858         6  170  198  0.877772
predict(40, n = 3)
landmark 1 : x = 120, 42
      landmark    x   y      prob
8820         1  116  34  0.568582
9075         1  115  35  0.566257
9074         1  114  35  0.561259

landmark 2 : x = 134, 40
       landmark    x   y      prob
10372         2  132  40  0.812515
10371         2  131  40  0.807980
10628         2  132  41  0.807899

landmark 3 : x = 109, 48
       landmark    x   y      prob
12652         3  108  49  0.839624
12653         3  109  49  0.838190
12396         3  108  48  0.837235

landmark 4 : x = 148, 43
       landmark    x   y      prob
11156         4  148  43  0.837879
10900         4  148  42  0.837810
11157         4  149  43  0.836910

landmark 5 : x = 107, 176
       landmark    x    y      prob
45164         5  108  176  0.845494
45420         5  108  177  0.841054
45163         5  107  176  0.839846

landmark 6 : x = 154, 182
       landmark    x    y      prob
46746         6  154  182  0.865920
46747         6  155  182  0.863970
46490         6  154  181  0.862724

なお、predict(40) におけるランドマーク 1(赤色)の結果が振るわないのは、ラベルデータの作り方の問題だと考えられます。(上書きでは無く確率値が大きい方を採用する等で改善するはず)

検証データの結果例

f:id:fits:20190217035136p:plain f:id:fits:20190217035148p:plain

当然ながら、学習に使っていないこちらのデータでは結果が悪化していますが、それなりに正しそうな位置を部分的に検出しているように見えます。

学習に使ったデータ量の少なさを考えると、かなり良好な結果が出ているようにも思います。

そもそも、predict(-3) のようなランドマークの左右が反転している背面からの画像なんてのは無理があるように思いますし、predict(-8) のランドマーク 5(水色)はラベルデータの方が間違っている(検出結果の方が正しい)ような気もします。

predict(-1, n = 3)
landmark 1 : x = 96, 60
       landmark   x   y      prob
15969         1  97  62  0.872259
16225         1  97  63  0.869837
15970         1  98  62  0.869681

landmark 2 : x = 126, 59
       landmark    x   y      prob
16254         2  126  63  0.866628
16255         2  127  63  0.865502
15998         2  126  62  0.864939

landmark 3 : x = 66, 125
       landmark   x    y      prob
30521         3  57  119  0.832024
30520         3  56  119  0.831721
30777         3  57  120  0.829537

landmark 4 : x = 157, 117
       landmark    x    y      prob
29099         4  171  113  0.814012
29098         4  170  113  0.813680
28843         4  171  112  0.812420
predict(-8, n = 3)
landmark 1 : x = 133, 40
       landmark    x   y      prob
10629         1  133  41  0.812287
10628         1  132  41  0.810564
10373         1  133  40  0.808298

landmark 2 : x = 157, 47
       landmark    x   y      prob
12704         2  160  49  0.767413
12448         2  160  48  0.764577
12703         2  159  49  0.762571

landmark 3 : x = 105, 77
       landmark    x   y     prob
19300         3  100  75  0.79014
19301         3  101  75  0.78945
19556         3  100  76  0.78496

landmark 4 : x = 181, 86
       landmark    x    y      prob
56242         4  178  219  0.768471
55986         4  178  218  0.768215
56243         4  179  219  0.766977

landmark 5 : x = 137, 211
       landmark    x    y      prob
54370         5   98  212  0.710897
54626         5   98  213  0.707652
54372         5  100  212  0.707127

CNN で輪郭の検出

画像内の物体の輪郭検出を CNN(畳み込みニューラルネット)で試してみました。

  • Keras + Tensorflow
  • Jupyter Notebook

ソースは http://github.com/fits/try_samples/tree/master/blog/20190114/

概要

今回は、画像をピクセル単位で輪郭か否かに分類する事(輪郭 = 1, 輪郭以外 = 0)で輪郭を検出できないか試しました。

そこで、教師データとして以下のような衣服単体の画像(jpg)と衣服の輪郭部分だけを白く塗りつぶした画像(png)を用意しました。

f:id:fits:20190114225837j:plain

教師データを大量に用意するのは困難だったため、240x288 の画像 160 ファイルで学習を行っています。

学習

学習の処理は Jupyter Notebook 上で実行しました。

(1) 入力データの準備

まずは、入力画像(jpg)を読み込みます。(教師データの画像は img ディレクトリへ配置しています)

import glob
import numpy as np
from keras.preprocessing.image import load_img, img_to_array

files = glob.glob('img/*.jpg')

imgs = np.array([img_to_array(load_img(f)) for f in files])

imgs.shape

入力データの形状は以下の通りです。

imgs.shape 結果
(160, 288, 240, 3)

(2) ラベルデータの準備

輪郭画像(png)を読み込み、128 を境にして二値化(輪郭 = 1、輪郭以外 = 0)します。

import os

th = 128

labels = np.array([img_to_array(load_img(f"{os.path.splitext(f)[0]}.png", color_mode = 'grayscale')) for f in files])

labels[labels < th] = 0
labels[labels >= th] = 1

labels.shape

ラベルデータの形状は以下の通りです。

labels.shape 結果
(160, 288, 240, 1)

(3) CNN モデル

どのようなネットワーク構成が適しているのか分からなかったので、セマンティックセグメンテーション等で用いられている Encoder-Decoder の構成を参考にしてみました。

30x36 まで段階的に縮小して(Encoder)、元の大きさ 240x288 まで段階的に拡大する(Decoder)ようにしています。

最終層の活性化関数に sigmoid を使って 0 ~ 1 の値となるようにしています。

損失関数は binary_crossentropy を使うと進捗が遅そうに見えたので ※、代わりに mean_squared_error を使っています。

 ※ 今回の場合、輪郭(= 1)に該当するピクセルの方が少なくなるため
    binary_crossentropy を使用する場合は fit の class_weight 引数で
    調整する必要があったと思われます

また、参考のため mean_absolute_error(mae) の値も出力するように metrics で指定しています。

from keras.models import Model
from keras.layers import Input, Dropout
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPool2D
from keras.layers.normalization import BatchNormalization

input = Input(shape = imgs.shape[1:])

x = input

x = BatchNormalization()(x)

# Encoder

x = Conv2D(16, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(128, 3, padding='same', activation = 'relu')(x)
x = Conv2D(128, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

# Decoder

x = Conv2DTranspose(64, 3, strides = 2, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2DTranspose(32, 3, strides = 2, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2DTranspose(16, 3, strides = 2, padding='same', activation = 'relu')(x)
x = Conv2D(16, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

output = Conv2D(1, 1, activation = 'sigmoid')(x)

model = Model(inputs = input, outputs = output)

model.compile(loss = 'mse', optimizer = 'adam', metrics = ['mae'])

model.summary()
model.summary() 結果
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 288, 240, 3)       0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 288, 240, 3)       12        
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 288, 240, 16)      448       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 144, 120, 16)      0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 144, 120, 16)      64        
_________________________________________________________________
dropout_1 (Dropout)          (None, 144, 120, 16)      0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 144, 120, 32)      4640      
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 144, 120, 32)      9248      
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 144, 120, 32)      9248      
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 72, 60, 32)        0         
_________________________________________________________________
batch_normalization_3 (Batch (None, 72, 60, 32)        128       
_________________________________________________________________
dropout_2 (Dropout)          (None, 72, 60, 32)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 72, 60, 64)        18496     
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 72, 60, 64)        36928     
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 72, 60, 64)        36928     
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 36, 30, 64)        0         
_________________________________________________________________
batch_normalization_4 (Batch (None, 36, 30, 64)        256       
_________________________________________________________________
dropout_3 (Dropout)          (None, 36, 30, 64)        0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 36, 30, 128)       73856     
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 36, 30, 128)       147584    
_________________________________________________________________
batch_normalization_5 (Batch (None, 36, 30, 128)       512       
_________________________________________________________________
dropout_4 (Dropout)          (None, 36, 30, 128)       0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 72, 60, 64)        73792     
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 72, 60, 64)        36928     
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 72, 60, 64)        36928     
_________________________________________________________________
conv2d_12 (Conv2D)           (None, 72, 60, 64)        36928     
_________________________________________________________________
batch_normalization_6 (Batch (None, 72, 60, 64)        256       
_________________________________________________________________
dropout_5 (Dropout)          (None, 72, 60, 64)        0         
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 144, 120, 32)      18464     
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 144, 120, 32)      9248      
_________________________________________________________________
conv2d_14 (Conv2D)           (None, 144, 120, 32)      9248      
_________________________________________________________________
conv2d_15 (Conv2D)           (None, 144, 120, 32)      9248      
_________________________________________________________________
batch_normalization_7 (Batch (None, 144, 120, 32)      128       
_________________________________________________________________
dropout_6 (Dropout)          (None, 144, 120, 32)      0         
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 288, 240, 16)      4624      
_________________________________________________________________
conv2d_16 (Conv2D)           (None, 288, 240, 16)      2320      
_________________________________________________________________
batch_normalization_8 (Batch (None, 288, 240, 16)      64        
_________________________________________________________________
dropout_7 (Dropout)          (None, 288, 240, 16)      0         
_________________________________________________________________
conv2d_17 (Conv2D)           (None, 288, 240, 1)       17        
=================================================================
Total params: 576,541
Trainable params: 575,831
Non-trainable params: 710
_________________________________________________________________

(4) 学習

教師データが少ないため、全て学習で使う事にします。

実行例(441 ~ 480 エポック)
hist = model.fit(imgs, labels, initial_epoch = 440, epochs = 480, batch_size = 10)

Keras では fit を繰り返し呼び出すと学習を(続きから)再開できるので、40 エポックを何回か繰り返しました。(バッチサイズは 20 で始めて途中で 10 へ変えたりしています)

その場合、正しいエポックを出力するには initial_epochepochs の値を調整する必要があります ※。

 ※ initial_epoch を指定しなくても
    fit を繰り返し実行するだけで学習は継続されますが、
    その場合は出力されるエポックの値がクリアされます(1 からのカウントとなる)
結果例
Epoch 441/480
160/160 [=====・・・ - loss: 0.0048 - mean_absolute_error: 0.0126
Epoch 442/480
160/160 [=====・・・ - loss: 0.0048 - mean_absolute_error: 0.0125
Epoch 443/480
160/160 [=====・・・ - loss: 0.0048 - mean_absolute_error: 0.0126
・・・
Epoch 478/480
160/160 [=====・・・ - loss: 0.0045 - mean_absolute_error: 0.0116
Epoch 479/480
160/160 [=====・・・ - loss: 0.0046 - mean_absolute_error: 0.0117
Epoch 480/480
160/160 [=====・・・ - loss: 0.0044 - mean_absolute_error: 0.0115

(5) 確認

fit の戻り値から mean_squared_error(loss)と mean_absolute_error の値の遷移をグラフ化してみます。

%matplotlib inline

import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = (8, 4)

plt.subplot(1, 2, 1)
plt.plot(hist.history['loss'])

plt.subplot(1, 2, 2)
plt.plot(hist.history['mean_absolute_error'])
結果例(441 ~ 480 エポック)

f:id:fits:20190114231037p:plain

(6) 検証

(a) 教師データ

教師データの入力画像と輪郭画像(ラベルデータ)、model.predict の結果を並べて表示してみます。

def predict(index, s = 6.0):
    plt.rcParams['figure.figsize'] = (s, s)

    sh = imgs.shape[1:-1]

    # 輪郭の検出(予測処理)
    pred = model.predict(np.array([imgs[index]]))[0]
    pred *= 255

    plt.subplot(1, 3, 1)
    # 入力画像の表示
    plt.imshow(imgs[index].astype(int))

    plt.subplot(1, 3, 2)
    # 輪郭画像(ラベルデータ)の表示
    plt.imshow(labels[index].reshape(sh), cmap = 'gray')

    plt.subplot(1, 3, 3)
    # predict の結果表示
    plt.imshow(pred.reshape(sh).astype(int), cmap = 'gray')
結果例(480 エポック): 入力画像, 輪郭画像, model.predict 結果

f:id:fits:20190114231107j:plain

概ね教師データに近い結果が出るようになっています。

(b) 教師データ以外

教師データとして使っていない画像に対して model.predict を実施し、輪郭の検出を行ってみます。

def predict_eval(file, s = 4.0):
    plt.rcParams['figure.figsize'] = (s, s)

    img = img_to_array(load_img(file))

    # 輪郭の検出(予測処理)
    pred = model.predict(np.array([img]))[0]
    pred *= 255

    plt.subplot(1, 2, 1)
    # 入力画像の表示
    plt.imshow(img.astype(int))

    plt.subplot(1, 2, 2)
    # predict の結果表示
    plt.imshow(pred.reshape(pred.shape[:-1]).astype(int), cmap = 'gray')
結果例(480 エポック): 入力画像, model.predict 結果

f:id:fits:20190114231154j:plain

所々で途切れたりしていますが、ある程度の輪郭は検出できているように見えます。

(7) 保存

学習したモデルを保存します。

model.save('model/c1_480.h5')

輪郭検出

学習済みモデルを使って輪郭検出を行う処理をスクリプト化してみました。

predict_contours.py
import sys
import os
import glob
import numpy as np
from keras.preprocessing.image import load_img, img_to_array
from keras.models import load_model
import cv2

model_file = sys.argv[1]
img_files = sys.argv[2]
dest_dir = sys.argv[3]

model = load_model(model_file)

for f in glob.glob(img_files):
    img = img_to_array(load_img(f))

    # 輪郭の検出
    pred = model.predict(np.array([img]))[0]
    pred *= 255

    file, ext = os.path.splitext(os.path.basename(f))

    # 画像の保存
    cv2.imwrite(f"{dest_dir}/{file}_predict.png", pred)

    print(f"done: {f}")
実行例
python predict_contours.py model/c1_480.h5 img_eval2/*.jpg result

480 エポックの学習モデルを使って、教師データに無いタイプの背景を使った画像(影の影響もある)に試してみました。

輪郭検出結果例
入力画像 処理結果
f:id:fits:20181229154654j:plain f:id:fits:20190114231530p:plain
f:id:fits:20181229155049j:plain f:id:fits:20190114231547p:plain

こちらは難しかったようです。

なお、2つ目の画像は 120 エポックの学習モデルの方が良好な結果(輪郭がより多く検出されていた)でした。

MongoDB で条件に合致する子要素を抽出

MongoDB で指定の条件に合致する子要素のみを抽出する方法を調査してみました。

  • MongoDB 4.0.4

はじめに、下記 3つのドキュメントが sample コレクションへ登録されているとします。

ドキュメント内容
{ "_id" : 1, "items" : [
    { "color" : "black", "size" : "S" }, 
    { "color" : "white", "size" : "S" }
] }

{ "_id" : 2, "items" : [
    { "color" : "red",   "size" : "L" }, 
    { "color" : "blue",  "size" : "S" }
] }

{ "_id" : 3, "items" : [
    { "color" : "white", "size" : "L" }, 
    { "color" : "red",   "size" : "L" }, 
    { "color" : "white", "size" : "S" }
] }

ここで、items.colorwhite のものだけを抽出し、以下の結果を得る事を目指してみます。(items の中身が white のものだけを含むようにする)

目標とする検索結果
{ "_id" : 1, "items" : [
    { "color" : "white", "size" : "S" }
] }

{ "_id" : 3, "items" : [
    { "color" : "white", "size" : "L" }, 
    { "color" : "white", "size" : "S" }
] }

(a) items.color で条件指定

まずは {"items.color": "white"} の条件で find した結果です。

white を持つドキュメントだけを抽出できましたが、ドキュメントの内容はそのままなので black 等の余計なものも含んでしまいます。

> db.sample.find({"items.color": "white"})

{ "_id" : 1, "items" : [ { "color" : "black", "size" : "S" }, { "color" : "white", "size" : "S" } ] }
{ "_id" : 3, "items" : [ { "color" : "white", "size" : "L" }, { "color" : "red", "size" : "L" }, { "color" : "white", "size" : "S" } ] }

(b) $elemMatch 使用

次に $elemMatch を使ってみます。

$elemMatch を find の query(第一引数)で使うか、projection(第二引数)で使うかで結果が変わります。

query で使う場合は先程の (a) と同じ結果になります。

query で使用
> db.sample.find({"items": {$elemMatch: {"color": "white"}}})

{ "_id" : 1, "items" : [ { "color" : "black", "size" : "S" }, { "color" : "white", "size" : "S" } ] }
{ "_id" : 3, "items" : [ { "color" : "white", "size" : "L" }, { "color" : "red", "size" : "L" }, { "color" : "white", "size" : "S" } ] }

query の条件は指定せずに projection で $elemMatch を使った場合の結果は以下です。

全ドキュメントを対象に items の中身がフィルタリングされていますが、条件に合致する全ての子要素が抽出されるわけでは無く、(条件に合致する)先頭の要素しか含まれていません。

projection で使用1
> db.sample.find({}, {"items": {$elemMatch: {"color": "white"}}})

{ "_id" : 1, "items" : [ { "color" : "white", "size" : "S" } ] }
{ "_id" : 2 }
{ "_id" : 3, "items" : [ { "color" : "white", "size" : "L" } ] }

query 条件を指定する事で不要なドキュメントを除く事はできますが、条件に合致する全ての子要素が抽出されない事に変わりはありません。

projection で使用2
> db.sample.find({"items.color": "white"}, {"items": {$elemMatch: {"color": "white"}}})

{ "_id" : 1, "items" : [ { "color" : "white", "size" : "S" } ] }
{ "_id" : 3, "items" : [ { "color" : "white", "size" : "L" } ] }

このように、今回のようなドキュメントに対して $elemMatch を使うと、条件に合致する最初の子要素だけが抽出されるようです。(何らかの回避策があるのかもしれませんが)

(c) aggregate 使用

最後に aggregate を使ってみます。

$unwind を使うと配列の個々の要素を処理できるので、$match で white のみに限定した後、$group でグルーピングすれば良さそうです。

> db.sample.aggregate([
  {$unwind: "$items"}, 
  {$match: {"items.color": "white"}}, 
  {$group: {_id: "$_id", "items": {$push: "$items"}}}
])

{ "_id" : 3, "items" : [ { "color" : "white", "size" : "L" }, { "color" : "white", "size" : "S" } ] }
{ "_id" : 1, "items" : [ { "color" : "white", "size" : "S" } ] }

これで目指した結果は一応得られました。

なお、対象を最初に絞り込むようにしてソートを付けると以下のようになります。

> db.sample.aggregate([
  {$match: {"items.color": "white"}},
  {$unwind: "$items"},
  {$match: {"items.color": "white"}},
  {$group: {_id: "$_id", "items": {$push: "$items"}}},
  {$sort:  {_id: 1}}
])

{ "_id" : 1, "items" : [ { "color" : "white", "size" : "S" } ] }
{ "_id" : 3, "items" : [ { "color" : "white", "size" : "L" }, { "color" : "white", "size" : "S" } ] }

Scala のケースクラスに制約を持たせる

Scala のケースクラスで値に制約を持たせたい場合にどうするか。

例えば、以下のケースクラスで amount の値を 0 以上となるように制限し、0 未満ならインスタンス化を失敗させる事を考えてみます。

case class Quantity(amount: Int)

使用した環境は以下

ソースは http://github.com/fits/try_samples/tree/master/blog/20181009/

ケースクラスの値を制限

まず、最も単純なのは以下のような実装だと思います。

case class Quantity(amount: Int) {
  if (amount < 0)
    throw new IllegalArgumentException(s"amount($amount) < 0")
}

これだと、例外が throw されてしまい関数プログラミングで扱い難いので Try[Quantity]Option[Quantity] 等を返すようにしたいところです。

そこで、以下のようにケースクラスを abstract 化して、コンパニオンオブジェクトへ生成関数を定義する方法を使ってみました。

sample.scala
import scala.util.{Try, Success, Failure}

sealed abstract case class Quantity private (amount: Int)

object Quantity {
  def apply(amount: Int): Try[Quantity] =
    if (amount >= 0)
      Success(new Quantity(amount){})
    else
      Failure(new IllegalArgumentException(s"amount($amount) < 0"))
}

println(Quantity(1))
println(Quantity(0))
println(Quantity(-1))

// この方法では Quantity へ copy がデフォルト定義されないため
// copy は使えません(error: value copy is not a member of this.Quantity)
//
// println(Quantity(1).map(_.copy(-1)))

実行結果は以下の通りです。

実行結果
> scala sample.scala

Success(Quantity(1))
Success(Quantity(0))
Failure(java.lang.IllegalArgumentException: amount(-1) < 0)

上記 sample.scala では、以下を直接呼び出せないようにしてケースクラスの勝手なインスタンス化を防止しています。

  • (a) コンストラクタ(new)
  • (b) コンパニオンオブジェクトへデフォルト定義される apply
  • (c) ケースクラスへデフォルト定義される copy

そのために、下記 2点を実施しています。

  • (1) コンストラクタの private 化 : (a) の防止
  • (2) ケースクラスの abstract 化 : (b) (c) の防止

(1) コンストラクタの private 化

以下のように private を付ける事でコンストラクタを private 化できます。

コンストラクタの private 化
case class Quantity private (amount: Int)

これで (a) new Quantity(・・・) の実行を防止できますが、以下のように (b) の apply や (c) の copy を実行できてしまいます。

検証例
scala> case class Quantity private (amount: Int)
defined class Quantity

scala> new Quantity(1)
<console>:14: error: constructor Quantity in class Quantity cannot be accessed in object $iw
       new Quantity(1)

scala> Quantity(1)
res1: Quantity = Quantity(1)

scala> Quantity.apply(2)
res2: Quantity = Quantity(2)

scala> Quantity(3).copy(30)
res3: Quantity = Quantity(30)

(2) ケースクラスの abstract 化

ケースクラスを abstract 化すると、通常ならデフォルト定義されるコンパニオンオブジェクトの apply やケースクラスの copy を防止できるようです。

そのため、(1) と組み合わせることで (a) ~ (c) を防止できます。

ケースクラスの abstract 化とコンストラクタの private 化
sealed abstract case class Quantity private (amount: Int)

以下のように Quantity.apply は定義されなくなります。

検証例
scala> sealed abstract case class Quantity private (amount: Int)
defined class Quantity

scala> new Quantity(1){}
<console>:14: error: constructor Quantity in class Quantity cannot be accessed in <$anon: Quantity>
       new Quantity(1){}
           ^

scala> Quantity.apply(1)
<console>:14: error: value apply is not a member of object Quantity
       Quantity.apply(1)

このままだと何もできなくなるため、実際はコンパニオンオブジェクトへ生成用の関数が必要になります。

sealed abstract case class Quantity private (amount: Int)

object Quantity {
  def create(amount: Int): Quantity = new Quantity(amount){}
}

備考

今回の方法は、以下の書籍に記載されているような ADTs(algebraic data types)と Smart constructors をより安全に定義するために活用できると考えています。

Functional and Reactive Domain Modeling

Functional and Reactive Domain Modeling

Kotlin の関数型プログラミング用ライブラリ Λrrow を試してみる

Kotlin で ScalaScalazCats のような関数型プログラミング用のライブラリを探していたところ、以下を見つけたので試してみました。

ソースは http://github.com/fits/try_samples/tree/master/blog/20180822/

はじめに

Λrrow は以下のような要素で構成されており、Haskell の型クラス・インスタンスのような仕組みを実現しているようです。

  • Datatypes (Option, Either, Try, Kleisli, StateT, WriterT 等)
  • Typeclasses (Functor, Applicative, Monad, Monoid, Eq, Show 等)
  • Instances (OptionMonadInstance 等)

そして、その実現にはアノテーションプロセッサが活用されているようです。(@higherkind@instance から補助的なコードを生成)

Option のソースコード

例えば、Optionソースコードは以下のようになっていますが、OptionOf の定義は github のソース上には見当たりません。

arrow/core/Option.kt
@higherkind
sealed class Option<out A> : OptionOf<A> {
    ・・・
}

OptionOf はアノテーションプロセッサ ※ で生成したコード(以下)で定義されています。(OptionOf<A>Kind<ForOption, A> の型エイリアス

higherkind.arrow.core.Option.kt
package arrow.core

class ForOption private constructor() { companion object }
typealias OptionOf<A> = arrow.Kind<ForOption, A>

@Suppress("UNCHECKED_CAST", "NOTHING_TO_INLINE")
inline fun <A> OptionOf<A>.fix(): Option<A> =
  this as Option<A>
 ※ @higherkind に対する処理は
    arrow.higherkinds.HigherKindsFileGenerator に実装されているようです

ここで、ForOption というクラスが定義されていますが、この For + 型名 はコンテナ型を表現するための型のようです。

ForOption がある事で Kind<Option<A>, A> ではなく Kind<ForOption, A> と定義できるようになっており、Functor<F>Monad<F> における F の具体型として使われています。(例. Monad<ForOption>

Either のソースコード

次に Eitherソースコードも見てみます。

Either のように型パラメータが複数の場合はアノテーションプロセッサで生成されるコードの内容に多少の違いがあるようです。

arrow/core/Either.kt
@higherkind
sealed class Either<out A, out B> : EitherOf<A, B> {
    ・・・
}

アノテーションプロセッサで生成されたコードは以下のように EitherOf の他に EitherPartialOf が型エイリアスとして定義されています。

higherkind.arrow.core.Either.kt
package arrow.core

class ForEither private constructor() { companion object }
typealias EitherOf<A, B> = arrow.Kind2<ForEither, A, B>
typealias EitherPartialOf<A> = arrow.Kind<ForEither, A>

@Suppress("UNCHECKED_CAST", "NOTHING_TO_INLINE")
inline fun <A, B> EitherOf<A, B>.fix(): Either<A, B> =
  this as Either<A, B>

また、上記で定義されている ForEither クラスとは別に arrow-instances-core モジュールには ForEither 関数 も用意されています。

arrow/instances/either.kt
・・・
class EitherContext<L> : EitherMonadErrorInstance<L>, EitherTraverseInstance<L>, EitherSemigroupKInstance<L> {
  override fun <A, B> Kind<EitherPartialOf<L>, A>.map(f: (A) -> B): Either<L, B> =
    fix().map(f)
}

class EitherContextPartiallyApplied<L> {
  infix fun <A> extensions(f: EitherContext<L>.() -> A): A =
    f(EitherContext())
}

fun <L> ForEither(): EitherContextPartiallyApplied<L> =
  EitherContextPartiallyApplied()

この ForEither 関数は extensions を呼び出す際に使用する事になります。

サンプルコード

Option と Either をそれぞれ使ったサンプルを作成してみます。

(a) Option の利用

Option の SomeNone の作成にはいくつかの方法が用意されています。

Monad の拡張関数として定義されている binding を使うと Haskell の do 記法や Scala の for 式のようにモナドを処理できるようです。(Kotlin の coroutines 機能で実現)

binding の呼び出し方はいくつか考えられますが、ForOption extensions { ・・・ } を使うのが基本のようです。

ForOption extensions へ渡す処理内では this が OptionContext となるため、OptionContext の処理を呼び出せるようになっています。

また、binding へ渡す処理内では MonadContinuation<ForOption, *> が this となります。

bind は MonadContinuation 内で suspend fun <B> Kind<F, B>.bind(): B と定義されており、Kind<F, B> から B の値を取り出す処理となっています。

src/main/kotlin/App.kt
import arrow.core.*
import arrow.instances.*
import arrow.typeclasses.binding

fun main(args: Array<String>) {
    // Some
    val d1: Option<Int> = Option.just(10)
    val d2: Option<Int> = 5.some()
    val d3: Option<Int> = Some(2)
    // None
    val d4: Option<Int> = Option.empty()
    val d5: Option<Int> = none()
    val d6: Option<Int> = None

    // Some(15)
    val r1 = d1.flatMap { a ->
        d2.map { b -> a + b }
    }

    println(r1)

    // Some(17)
    val r2 = Option.monad().binding { // this: MonadContinuation<ForOption, *>
        val a = d1.bind() // 10
        val b = d2.bind() //  5
        val c = d3.bind() //  2
        a + b + c
    }

    println(r2)

    // Some(17)
    val r3 = ForOption extensions { // this: OptionContext
        println(this) // arrow.instances.OptionContext@3ffc5af1

        binding { // this: MonadContinuation<ForOption, *>
            println(this) // arrow.typeclasses.MonadContinuation@26653222

            val a = d1.bind()
            val b = d2.bind()
            val c = d3.bind()
            a + b + c
        }
    }

    println(r3)

    // Some(17)
    val r4 = OptionContext.binding {
        val a = d1.bind()
        val b = d2.bind()
        val c = d3.bind()
        a + b + c
    }

    println(r4)

    // None
    val r5 = Option.monad().binding {
        val a = d1.bind()
        val b = d4.bind()
        a + b
    }

    println(r5)

    // None
    val r6 = ForOption extensions {
        binding {
            val a = d5.bind()
            val b = d2.bind()
            val c = d6.bind()
            a + b + c
        }
    }

    println(r6)
}

Option のような基本的な型を使うだけであれば、依存ライブラリとして arrow-instances-core を指定するだけで問題なさそうです。

そのため、Gradle ビルド定義ファイルは下記のような内容になります。

ここでは The feature "coroutines" is experimental (see: https://kotlinlang.org/docs/diagnostics/experimental-coroutines) 警告ログを出力しないように coroutines を有効化しています。

build.gradle (Gradle ビルド定義)
plugins {
    id 'org.jetbrains.kotlin.jvm' version '1.2.51'
    id 'application'
}

mainClassName = 'AppKt' // 実行するクラス名

repositories {
    jcenter()
}

dependencies {
    compile 'org.jetbrains.kotlin:kotlin-stdlib-jdk8'
    compile 'io.arrow-kt:arrow-instances-core:0.7.3'
}

// coroutines の有効化(ビルド時の警告を抑制)
kotlin {
    experimental {
        coroutines 'enable'
    }
}

実行結果は以下の通りです。

実行結果
> gradle run

・・・
Some(15)
Some(17)
arrow.instances.OptionContext@3ffc5af1
arrow.typeclasses.MonadContinuation@26653222
Some(17)
Some(17)
None
None

(b) Either の利用

Either の場合、extensions の呼び出し方が Option とは多少異なります。

Option の場合は ForOption の extensions を呼び出しましたが、Either の場合は ForEither に extensions は用意されておらず、代わりに EitherContextPartiallyApplied<L> クラス内に定義されています。

EitherContextPartiallyApplied オブジェクトは ForEither 関数 で取得できるので、これを使って ForEither<Left 側の型>() extensions { ・・・ } のようにします。

ForEither<String>() extensions へ渡す処理内の this は EitherContext<String> で、binding へ渡す処理内の this は MonadContinuation<EitherPartialOf<String>, *> となります。

src/main/kotlin/App.kt
import arrow.core.*
import arrow.instances.*
import arrow.typeclasses.binding

fun main(args: Array<String>) {
    // Right
    val d1: Either<String, Int> = Either.right(10)
    val d2: Either<String, Int> = 5.right()
    val d3: Either<String, Int> = Right(2)
    // Left
    val d4: Either<String, Int> = Either.left("error data")

    // Right(b=15)
    val r1 = d1.flatMap { a ->
        d2.map { b -> a + b }
    }

    println(r1)

    // Right(b=17)
    val r2 = Either.monad<String>().binding { // this: MonadContinuation<EitherPartialOf<String>, *>
        val a = d1.bind()
        val b = d2.bind()
        val c = d3.bind()
        a + b + c
    }

    println(r2)

    // Right(b=17)
    // ForEither<String>() 関数の呼び出し
    val r3 = ForEither<String>() extensions { // this: EitherContext<String>
        binding { // this: MonadContinuation<EitherPartialOf<String>, *>
            val a = d1.bind()
            val b = d2.bind()
            val c = d3.bind()
            a + b + c
        }
    }

    println(r3)

    // Left(a=error data)
    val r4 = ForEither<String>() extensions {
        binding {
            val a = d1.bind()
            val b = d4.bind()
            val c = d3.bind()
            a + b + c
        }
    }

    println(r4)
}

(a) と同じ内容の build.gradle を使って実行します。

実行結果
> gradle run

・・・
Right(b=15)
Right(b=17)
Right(b=17)
Left(a=error data)

TypeScript で funfix を使用 - tsc, FuseBox

funfixJavaScript, TypeScript, Flow の関数型プログラミング用ライブラリで、Fantasy LandStatic Land ※ に準拠し Scala の Option, Either, Try, Future 等と同等の型が用意されているようです。

 ※ JavaScript 用に Monoid や Monad 等の代数的構造に関する仕様を定義したもの

今回は Option を使った単純な処理を TypeScript で実装し Node.js で実行してみます。

ソースは http://github.com/fits/try_samples/tree/master/blog/20180730/

はじめに

Option を使った下記サンプルをコンパイルして実行します。

サンプルソース
import { Option, Some } from 'funfix'

const f = (ma, mb) => ma.flatMap(a => mb.map(b => `${a} + ${b} = ${a + b}`))

const d1 = Some(10)
const d2 = Some(2)

console.log( d1 )

console.log('-----')

console.log( f(d1, d2) )
console.log( f(d1, Option.none()) )

console.log('-----')

console.log( f(d1, d2).getOrElse('none') )
console.log( f(d1, Option.none()).getOrElse('none') )

ビルドと実行

上記ソースファイルを以下の 2通りでビルドして実行してみます。

  • (a) tsc 利用
  • (b) FuseBox 利用

(a) tsc を利用する場合

tsc コマンドを使って TypeScript のソースをコンパイルします。

まずは typescript と funfix モジュールをそれぞれインストールします。

typescript インストール
> npm install --save-dev typescript
funfix インストール
> npm install --save funfix

この状態で sample.ts ファイルをコンパイルしてみると、型関係のエラーが出るものの sample.js は正常に作られました。

コンパイル1
> tsc sample.ts

node_modules/funfix-core/dist/disjunctions.d.ts:775:14 - error TS2416: Property 'value' in type 'TNone' is not assignable to the same property in base type 'Option<never>'.
  Type 'undefined' is not assignable to type 'never'.

775     readonly value: undefined;
                 ~~~~~


node_modules/funfix-effect/dist/eval.d.ts:256:42 - error TS2304: Cannot find name 'Iterable'.

256     static sequence<A>(list: Eval<A>[] | Iterable<Eval<A>>): Eval<A[]>;
                                             ~~~~~~~~
・・・

sample.js を実行してみると特に問題無く動作します。

実行1
> node sample.js

TSome { _isEmpty: false, value: 10 }
-----
TSome { _isEmpty: false, value: '10 + 2 = 12' }
TNone { _isEmpty: true, value: undefined }
-----
10 + 2 = 12
none

これで一応は動いたわけですが、コンパイル時にエラーが出るというのも望ましい状態ではないので、エラー部分を解決してみます。

他にも方法があるかもしれませんが、ここでは以下のように対応します。

  • (1) Property 'value' in type 'TNone' ・・・ 'Option<never>' のエラーに対して tsc 実行時に --strictNullChecks オプションを指定して対応
  • (2) Cannot find name 'Iterable' 等のエラーに対して @types/node をインストールして対応

strictNullChecks は tsc の実行時オプションで指定する以外にも設定ファイル tsconfig.json で設定する事もできるので、ここでは tsconfig.json ファイルを使います。

(1) tsconfig.json
{
  "compilerOptions": {
    "strictNullChecks": true
  }
}

次に @types/node をインストールします。

@types/node には Node.js で実行するための型定義(Node.js 依存の API 等)が TypeScript 用に定義されています。

(2) @types/node インストール
> npm install --save-dev @types/node

この状態で tsc を実行すると先程のようなエラーは出なくなりました。(tsconfig.json を適用するため tsc コマンドへ引数を指定していない点に注意)

コンパイル2
> tsc

実行結果にも差異はありません。

実行2
> node sample.js

TSome { _isEmpty: false, value: 10 }
-----
TSome { _isEmpty: false, value: '10 + 2 = 12' }
TNone { _isEmpty: true, value: undefined }
-----
10 + 2 = 12
none

最終的な package.json の内容は以下の通りです。

package.json
{
  "name": "sample",
  "version": "1.0.0",
  "devDependencies": {
    "@types/node": "^10.5.4",
    "typescript": "^2.9.2"
  },
  "dependencies": {
    "funfix": "^7.0.1"
  }
}

(b) FuseBox を利用する場合

次に、モジュールバンドラーの FuseBox を使用してみます。(以降は (a) とは異なるディレクトリで実施)

なお、ここでは npm の代わりに yarn を使っていますが、npm でも特に問題はありません。

yarn のインストール例(npm 使用)
> npm install -g yarn

(b-1) 型チェック無し

まずは typescript, fuse-box, funfix をそれぞれインストールしておきます。

typescript と fuse-box インストール
> yarn add --dev typescript fuse-box
funfix インストール
> yarn add funfix

FuseBox ではビルド定義を JavaScript のコードで記載します。 とりあえずは必要最小限の設定を行いました。

bundle で指定した名称が init の $name に適用されるため、*.tsコンパイル結果と依存モジュールの内容をバンドルして bundle.js へ出力する事になります。

なお、> でロード時に実行する(コードを記載した)ファイルを指定します。

fuse.js (FuseBox ビルド定義)
const {FuseBox} = require('fuse-box')

const fuse = FuseBox.init({
    output: '$name.js'
})

fuse.bundle('bundle').instructions('> *.ts')

fuse.run()

上記スクリプトを実行してビルド(TypeScript のコンパイルとバンドル)を行います。

ビルド
> node fuse.js

--- FuseBox 3.4.0 ---
  → Generating recommended tsconfig.json:  ・・・\sample_fusebox1\tsconfig.json
  → Typescript script target: ES7

--------------------------
Bundle "bundle"

    sample.js
└──  (1 files,  700 Bytes) default
└── funfix-core 34.4 kB (1 files)
└── funfix-effect 43.1 kB (1 files)
└── funfix-exec 79.5 kB (1 files)
└── funfix 1 kB (1 files)
size: 158.7 kB in 765ms

初回実行時にデフォルト設定の tsconfig.json が作られました。(tsconfig.json が存在しない場合)

tsc の時のような型関係のエラーは出ていませんが、これは FuseBox がデフォルトで TypeScript の型チェックをしていない事が原因のようです。

型チェックを実施するには fuse-box-typechecker プラグインを使う必要がありそうです。

実行
> node bundle.js

TSome { _isEmpty: false, value: 10 }
-----
TSome { _isEmpty: false, value: '10 + 2 = 12' }
TNone { _isEmpty: true, value: undefined }
-----
10 + 2 = 12
none

package.json の内容は以下の通りです。

package.json
{
  "name": "sample_fusebox1",
  "version": "1.0.0",
  "main": "bundle.js",
  "license": "MIT",
  "devDependencies": {
    "fuse-box": "^3.4.0",
    "typescript": "^2.9.2"
  },
  "dependencies": {
    "funfix": "^7.0.1"
  }
}

(b-2) 型チェック有り

TypeScript の型チェックを行うようにしてみます。

まずは、(b-1) と同じ構成に fuse-box-typechecker プラグインを加えます。

fuse-box-typechecker を追加インストール
> yarn add --dev fuse-box-typechecker

次に、fuse.js へ fuse-box-typechecker プラグインの設定を追加します。

TypeChecker で型チェックにエラーがあった場合、例外が throw されるようにはなっていないため、ここではエラーがあった場合に Error を throw して fuse.run() を実行しないようにしてみました。

ただし、こうすると tsconfig.json を予め用意しておく必要があります。(TypeChecker に tsconfig.json が必要)

fuse.js (FuseBox ビルド定義)
const {FuseBox} = require('fuse-box')
const {TypeChecker} = require('fuse-box-typechecker')

// fuse-box-typechecker の設定
const testSync = TypeChecker({
    tsConfig: './tsconfig.json'
})

const fuse = FuseBox.init({
    output: '$name.js'
})

fuse.bundle('bundle').instructions('> *.ts')

testSync.runPromise()
    .then(n => {
        if (n != 0) {
            // 型チェックでエラーがあった場合
            throw new Error(n)
        }
        // 型チェックでエラーがなかった場合
        return fuse.run()
    })
    .catch(console.error)

これで、ビルド時に (a) と同様の型エラーが出るようになりました。

ビルド1
> node fuse.js

・・・
--- FuseBox 3.4.0 ---

Typechecker plugin(promisesync) .
Time:Sun Jul 29 2018 12:40:47 GMT+0900 (GMT+09:00)

File errors:
└── .\node_modules\funfix-core\dist\disjunctions.d.ts
   | ・・・\sample_fusebox2\node_modules\funfix-core\dist\disjunctions.d.ts (775,14) (Error:TS2416) Property 'value' in type 'TNone' is not assignable to the same property in base type 'Option<never>'.
  Type 'undefined' is not assignable to type 'never'.

Errors:1
└── Options: 0
└── Global: 0
└── Syntactic: 0
└── Semantic: 1
└── TsLint: 0

Typechecking time: 4116ms
Quitting typechecker

・・・

ここで、Iterable の型エラーが出ていないのは fuse-box-typechecker のインストール時に @types/node もインストールされているためです。

(a) と同様に strictNullChecks の設定を tsconfig.json追記して、このエラーを解決します。

tsconfig.json へ strictNullChecks の設定を追加
{
  "compilerOptions": {
    "module": "commonjs",
    "target": "ES7",
    ・・・
    "strictNullChecks": true
  }
}

これでビルドが成功するようになりました。

ビルド2
> node fuse.js

・・・
Typechecker name: undefined
Typechecker basepath: ・・・\sample_fusebox2
Typechecker tsconfig: ・・・\sample_fusebox2\tsconfig.json
--- FuseBox 3.4.0 ---

Typechecker plugin(promisesync) .
Time:Sun Jul 29 2018 12:44:57 GMT+0900 (GMT+09:00)
All good, no errors :-)
Typechecking time: 4103ms
Quitting typechecker

killing worker  → Typescript config file:  \tsconfig.json
  → Typescript script target: ES7

--------------------------
Bundle "bundle"

    sample.js
└──  (1 files,  700 Bytes) default
└── funfix-core 34.4 kB (1 files)
└── funfix-effect 43.1 kB (1 files)
└── funfix-exec 79.5 kB (1 files)
└── funfix 1 kB (1 files)
size: 158.7 kB in 664ms
実行結果
> node bundle.js

TSome { _isEmpty: false, value: 10 }
-----
TSome { _isEmpty: false, value: '10 + 2 = 12' }
TNone { _isEmpty: true, value: undefined }
-----
10 + 2 = 12
none

package.json の内容は以下の通りです。

package.json
{
  "name": "sample_fusebox2",
  "version": "1.0.0",
  "main": "index.js",
  "license": "MIT",
  "devDependencies": {
    "fuse-box": "^3.4.0",
    "fuse-box-typechecker": "^2.10.0",
    "typescript": "^2.9.2"
  },
  "dependencies": {
    "funfix": "^7.0.1"
  }
}