Keras.js によるランドマーク検出の Web アプリケーション化

前回の 「CNN でランドマーク検出」 の学習済みモデルを Keras.js を使って Web ブラウザ上で実行できるようにしてみます。

ソースは http://github.com/fits/try_samples/tree/master/blog/20190331/

準備

npm で Keras.js をインストールします。

Keras.js インストール
> npm install --save keras-js

Keras.js に含まれている encoder.py スクリプトを使って、Python の Keras で学習したモデル(model/cnn_landmark_400.h5)を Keras.js 用に変換します。

モデルファイル(HDF5 形式)を Keras.js 用に変換
> python node_modules/keras-js/python/encoder.py model/cnn_landmark_400.h5

生成された .bin ファイル(model/cnn_landmark_400.bin)のパス(URL)を KerasJS.Model へ指定して使う事になります。

ついでに、webpack もインストールしておきます。(webpack コマンドを使うには webpack-cli も必要)

webpack インストール
> npm install --save-dev webpack webpack-cli

Web アプリケーション作成

今回、作成する Web アプリケーションのファイル構成は以下の通りです。

  • index.html
  • js/bundle_app.js
  • js/bundle_worker.js
  • model/cnn_landmark_400.bin

処理は全て Web ブラウザ上で実行するようにし、Keras.js の処理(今回のランドマーク検出)はそれなりに重いので Web Worker として実行します。

index.html の内容は以下の通りです。

index.html
<!DOCTYPE html>
<html>
<head>
    <meta charset="UTF-8">
    <style type="text/css">
        canvas {
            border: 1px solid;
        }

        canvas.dragging {
            border: 3px solid red;
        }

        table {
            text-align: center;
            border-collapse: collapse;
        }

        table th, table td {
            border: 1px solid;
            padding: 8px;
        }
    </style>
</head>
<body>
    <dialog id="load-dialog">loading model ...</dialog>
    <dialog id="detect-dialog">detecting landmarks ...</dialog>
    <dialog id="error-dialog">ERROR</dialog>

    <div>
        <canvas width="256" height="256"></canvas>
    </div>

    <br>

    <div>
        <div id="landmarks"></div>
    </div>

    <script src="./js/bundle_app.js"></script>
</body>
</html>

bundle_xxx.js を生成するため、以下のような webpack 設定ファイルを用意します。

fs: 'empty' の箇所は Keras.js を webpack で処理するために必要な設定で、これが無いと Module not found: Error: Can't resolve 'fs' のようなエラーが出る事になります。

webpack.config.js
module.exports = {
    entry: {
        bundle_app: __dirname + '/src/app.js',
        bundle_worker: __dirname + '/src/worker.js'
    },
    output: {
        path: __dirname + '/js',
        filename: '[name].js',
    },
    // Keras.js 用の設定
    node: {
        fs: 'empty'
    }
}

Web Worker では Actor モデルのようにメッセージパッシング(postMessage で送信、onmessage で受信)を使ってメインの UI 処理とのデータ連携を行います。

今回は、以下のようなメッセージを Web Worker(Keras.js を使った処理)とやり取りするようにします。

Web Worker(Keras.js の処理)とのメッセージ内容
処理 送信メッセージ 受信メッセージ(成功時) 受信メッセージ(エラー時)
初期化 {type: 'init', url: <モデルファイルのURL>} {type: 'init'} {type: 'init', error: <エラーメッセージ>}
ランドマーク検出 {type: 'predict', input: <入力データ>} {type: 'predict', output: <ランドマーク検出結果>} {type: 'predict', error: <エラーメッセージ>}

(a) ランドマーク検出処理(src/worker.js)

Web Worker として実装するため、postMessage で UI 処理へメッセージを送信し、onmessage でメッセージを受信するようにします。

Keras.js 1.0.3 における Dense 処理の問題

実は、Keras.js 1.0.3 では今回の CNN モデルを正しく処理できません。

というのも、Keras.js 1.0.3 における Dense の処理(GPU を使わない場合)は以下のようになっています。

node_modules/keras-js/lib/layers/core/Dense.js の問題個所
  _callCPU(x) {
    this.output = new _Tensor.default([], [this.units]);

    ・・・
  }

今回の CNN モデルでは Dense の結果が 3次元 (256, 256, 7) になる必要がありますが、上記 Dense 処理では (7) のように 1次元になってしまい正しい結果を得られません。 ※

 ※ ついでに、Keras.js の softmax 処理にも不都合な点があった

そこで、今回は(GPU を使わない事を前提に)Dense_callCPU を実行時に書き換える事で回避しました。

処理内容としては、元の処理を 2重ループ内で実施するようにしています。

Dense 問題の回避措置(src/worker.js)
import KerasJS from 'keras-js'
import { gemv } from 'ndarray-blas-level2'
import ops from 'ndarray-ops'

・・・

// Dense の _callCPU を実行時に変更
KerasJS.layers.Dense.prototype._callCPU = function(x) {
    const h = x.tensor.shape[0]
    const w = x.tensor.shape[1]

    this.output = new KerasJS.Tensor([], [h, w, this.units])

    for (let i = 0; i < h; i++) {
        for (let j = 0; j < w; j++) {

            const xt = x.tensor.pick(i, j)
            const ot = this.output.tensor.pick(i, j)

            if (this.use_bias) {
                ops.assign(ot, this.weights['bias'].tensor)
            }

            gemv(1, this.weights['kernel'].tensor.transpose(1, 0), xt, 1, ot)

            this.activationFunc({tensor: ot})
        }
    }
}

ランドマーク検出の実装

KerasJS.Modelpredict へ入力データを渡したり結果を取り出すにはレイヤー名を指定する必要があり、これらのレイヤー名は iuputLayerNamesoutputLayerNames でそれぞれ取得できます。

predict の結果は、各座標のランドマーク該当確率 (256, 256, 7) となるので、ここではランドマーク毎 ※ に最も確率の高かった座標のみを結果として返すようにしています。

 ※ ランドマーク 0 はランドマークに該当しなかった場合なので結果に含めていない
src/worker.js
import KerasJS from 'keras-js'
import { gemv } from 'ndarray-blas-level2'
import ops from 'ndarray-ops'

let model = null

// モデルデータの読み込み
const loadModel = file => {
    const model = new KerasJS.Model({ filepath: file })

    return model.ready().then(r => model)
}

// Keras.js の Dense 問題への対応
KerasJS.layers.Dense.prototype._callCPU = function(x) {
    ・・・
}

// predict の結果を処理(ランドマーク毎に最も確率の高い座標を抽出)
const detectLandmarks = ts => {
    const res = {}

    for (let h = 0; h < ts.tensor.shape[0]; h++) {
        for (let w = 0; w < ts.tensor.shape[1]; w++) {
            const t = ts.tensor.pick(h, w)

            const wrkProb = {landmark: 0, prob: 0, x: w, y: h}

            for (let c = 0; c < t.shape[0]; c++) {
                const prob = t.get(c)

                if (prob > wrkProb.prob) {
                    wrkProb.landmark = c
                    wrkProb.prob = prob
                }
            }
            // ランドマーク 0 (ランドマークでは無い)は除外
            if (wrkProb.landmark > 0) {
                const curProb = res[wrkProb.landmark]

                if (!curProb || curProb.prob < wrkProb.prob) {
                    res[wrkProb.landmark] = wrkProb
                }
            }
        }
    }

    return res
}

// UI 処理からのメッセージ受信
onmessage = ev => {
    switch (ev.data.type) {
        case 'init':
            loadModel(ev.data.url)
                .then(m => {
                    model = m
                    postMessage({type: ev.data.type})
                })
                .catch(err => {
                    console.log(err)
                    postMessage({type: ev.data.type, error: err.message})
                })

            break
        case 'predict':
            const outputLayerName = model.outputLayerNames[0]

            const shape = model.modelLayersMap.get(outputLayerName)
                                                .output.tensor.shape

            const data = {}
            // 入力データの設定
            data[model.inputLayerNames[0]] = ev.data.input

            Promise.resolve(model.predict(data))
                .then(r => new KerasJS.Tensor(r[outputLayerName], shape)) // predict 実行結果の取り出し
                .then(detectLandmarks)
                .then(r => 
                    // UI 処理へ結果送信
                    postMessage({type: ev.data.type, output: r})
                )
                .catch(err => {
                    console.log(err)
                    postMessage({type: ev.data.type, error: err.message})
                })

            break
    }
}

(b) UI 処理(src/app.js)

画像データの変換(入力データの作成)

KerasJS.Model で predict するために、今回のケースでは画像データを 256(高さ)× 256(幅)× 3(RGB) サイズの 1次元配列 Float32Array へ変換する必要があります。

今回は以下のように canvas を利用して変換を行いました。

ImageData.data は RGBA 並びの 1次元配列 Uint8ClampedArray となっているので、RGB 部分のみを取り出して(A の内容を除外する)Float32Array を生成しています。

ちなみに、今回の CNN モデル自体は画像サイズに依存しない(Fully Convolutional Networks 的な)構成になっています。

そのため、任意サイズの画像を処理する事もできるのですが、現時点の Keras.js ではそんな事を考慮してくれていないので、実現するにはそれなりの工夫が必要になります。(一応、実現は可能でした)

ここでは、単純に canvas へ描画した 256x256 範囲の内容だけ(つまりは固定サイズ)を使うようにしています。※

 ※ この方法では 256x256 以外のサイズで欠けや余白の入り込みが発生する
画像データ変換部分(src/app.js)
・・・
const imageTypes = ['image/jpeg']

const canvas = document.getElementsByTagName('canvas')[0]
const ctx = canvas.getContext('2d')

・・・

// RGBA 並びの Uint8ClampedArray を RGB 並びの Float32Array へ変換
const imgToArray = imgData => new Float32Array(
    imgData.data.reduce(
        (acc, v, i) => {
            // RGBA の A 部分を除外
            if (i % 4 != 3) {
                acc.push(v)
            }
            return acc
        },
        []
     )
)
// 画像の読み込み
const loadImage = url => new Promise(resolve => {
    const img = new Image()

    img.addEventListener('load', () => {
        ctx.clearRect(0, 0, canvas.width, canvas.height)

        // 画像サイズが canvas よりも小さい場合の考慮
        const w = Math.min(img.width, canvas.width)
        const h = Math.min(img.height, canvas.height)

        // canvas へ画像を描画
        ctx.drawImage(img, 0, 0, w, h, 0, 0, w, h)

        // ImageData の取得
        const d = ctx.getImageData(0, 0, canvas.width, canvas.height)

        resolve(imgToArray(d))
    })

    img.src = url
})

・・・

// モデルデータ読み込み完了時の処理
const ready = () => {
    canvas.addEventListener('dragover', ev => {
        ev.preventDefault()
        canvas.classList.add('dragging')
    }, false)

    canvas.addEventListener('dragleave', ev => {
        canvas.classList.remove('dragging')
    }, false)

    // ドロップ時の処理
    canvas.addEventListener('drop', ev => {
        ev.preventDefault()
        canvas.classList.remove('dragging')

        const file = ev.dataTransfer.files[0]

        if (imageTypes.includes(file.type)) {
            ・・・
            const reader = new FileReader()

            reader.onload = ev => {
                loadImage(reader.result)
                    .then(img => {
                        ・・・
                    })
            }

            reader.readAsDataURL(file)
        }
    }, false)
}

・・・

Web Worker との連携

Web Worker とメッセージをやり取りし、ランドマークの検出結果を描画する部分の実装です。

Web Worker との連携部分(src/app.js)
const colors = ['rgb(255, 255, 255)', 'rgb(255, 0, 0)', 'rgb(0, 255, 0)', 'rgb(0, 0, 255)', 'rgb(255, 255, 0)', 'rgb(0, 255, 255)', 'rgb(255, 0, 255)']

const radius = 5
const imageTypes = ['image/jpeg']

const modelFile = '../model/cnn_landmark_400.bin'

// Web Worker の作成
const worker = new Worker('./js/bundle_worker.js')

・・・

// 検出したランドマークを canvas へ描画
const drawLandmarks = lms => {
    Object.values(lms).forEach(v => {
        ctx.fillStyle = colors[v.landmark]
        ctx.beginPath()
        ctx.arc(v.x, v.y, radius, 0, Math.PI * 2, false)
        ctx.fill()
    })
}

・・・

// 検出したランドマークの内容を table(HTML)化して表示
const showLandmarksInfo = lms => {
    ・・・

    infoNode.innerHTML = `
      <table>
        <tr>
          <th>landmark</th>
          <th>coordinate</th>
          <th>prob</th>
        </tr>
        ${rowsHtml}
      </table>
    `
}

// モデルデータ読み込み完了後
const ready = () => {
    ・・・
    canvas.addEventListener('drop', ev => {
        ・・・
        if (imageTypes.includes(file.type)) {
            ・・・
            reader.onload = ev => {
                loadImage(reader.result)
                    .then(img => {
                        detectDialog.showModal()
                        // Web Worker へのランドマーク検出指示
                        worker.postMessage({type: 'predict', input: img})
                    })
            }

            reader.readAsDataURL(file)
        }
    }, false)
}

// Web Worker からのメッセージ受信
worker.onmessage = ev => {
    if (ev.data.error) {
        ・・・
    }
    else {
        switch (ev.data.type) {
            case 'init':
                ready()

                loadDialog.close()

                break
            case 'predict':
                const res = ev.data.output

                console.log(res)
                detectDialog.close()

                drawLandmarks(res)
                showLandmarksInfo(res)

                break
        }
    }
}

loadDialog.showModal()
// Web Worker へのモデルデータ読み込み指示
worker.postMessage({type: 'init', url: modelFile})

(c) ビルド

webpack コマンドを実行し、js/bundle_app.js と js/bundle_worker.js を生成します。

webpack によるビルド
> webpack

(d) 動作確認

HTTP サーバーを使って動作確認を行います。 今回は http-server を使って実行しました。

http-server 実行
> http-server

Starting up http-server, serving ./
Available on:
  ・・・
  http://127.0.0.1:8080
Hit CTRL-C to stop the server

http://localhost:8080/index.htmlChrome ※ でアクセスして画像ファイルをドラッグアンドドロップすると以下のような結果となりました。

f:id:fits:20190331192210j:plain

 ※ HTMLDialogElement.showModal() を使っている関係で
    現時点では Chrome でしか動作しませんが、
    dialog 以外の部分(Keras.js の処理等)は
    Firefox でも動作するようになっています