sqlparse で SQL から更新対象のカラムを抽出

sqlparse を使って SQL の UPDATE 文から更新対象のカラムを抽出してみます。

ソースコードこちら

はじめに

更新対象のカラムを抽出するにはパース結果のトークンの中から該当部分を探して値を取得します。

例えば、以下のような UPDATE 文をパースすると

1.sql
UPDATE ITEMS SET price = 1000 WHERE id = '1'

次のようなトークン構成となり、更新するカラムの部分は Comparison となります。

1.sql のパース結果
[<DML 'UPDATE' at 0x200357A3280>, 
 <Whitespace ' ' at 0x200357A32E0>, 
 <Identifier 'ITEMS' at 0x2003579E7A0>, 
 <Whitespace ' ' at 0x200357A33A0>, 
 <Keyword 'SET' at 0x200357A3400>, 
 <Whitespace ' ' at 0x200357A3460>, 
 <Comparison 'price ...' at 0x2003579E8F0>, 
 <Whitespace ' ' at 0x200357A36A0>, 
 <Where 'WHERE ...' at 0x2003579E730>]

また、次のように複数のカラムを更新する場合は IdentifierList (複数の Comparison を持っている)となります。

2.sql
update ITEMS as i set i.price = 200, i.updated_at = NOW(), i.rev = i.rev + 1 where id = '2'
2.sql のパース結果
[<DML 'update' at 0x1EA8A403160>, 
 <Whitespace ' ' at 0x1EA8A403280>, 
 <Identifier 'ITEMS ...' at 0x1EA8A3FEA40>, 
 <Whitespace ' ' at 0x1EA8A4034C0>, 
 <Keyword 'set' at 0x1EA8A403520>, <Whitespace ' ' at 0x1EA8A403580>, 
 <IdentifierList 'i.pric...' at 0x1EA8A3FEB90>, 
 <Whitespace ' ' at 0x1EA8A44C280>, 
 <Where 'where ...' at 0x1EA8A3FE810>]

更新カラムの抽出

それでは、更新対象のカラム(ついでにテーブル名も付与)を抽出してみます。

トークンの型を判定する事になるので、ここでは Python 3.10 でサポートされたパターンマッチを使っています。

テーブル名にエイリアスを使っていると get_name() ではエイリアスが返ってくるため、get_real_name() を使うようにしました。

sample1.py
import sqlparse
from sqlparse.sql import Identifier, IdentifierList, Comparison, Token
from sqlparse.tokens import DML

import sys

sql = sys.stdin.read()

def fields_to_update(st):
    is_update = False
    table = ''

    for t in st.tokens:
        match t:
            # UPDATE 文の場合
            case Token(ttype=tt, value=v) if tt == DML and v.upper() == 'UPDATE':
                is_update = True
            # 複数カラム更新時
            case IdentifierList() if is_update:
                for c in t.tokens:
                    match c:
                        case Comparison(left=Identifier() as l) if is_update:
                            yield f"{table}.{l.get_real_name()}"
            # テーブル名の取得
            case Identifier() if is_update:
                table = t.get_real_name()
            # 単体カラム更新時
            case Comparison(left=Identifier() as l) if is_update:
                yield f"{table}.{l.get_real_name()}"

for s in sqlparse.parse(sql):
    fields = list(fields_to_update(s))
    print(fields)

上記の冗長な部分を再帰処理に変えて改良すると以下のようになりました。

sample2.py
import sqlparse
from sqlparse.sql import Identifier, IdentifierList, Comparison, Token
from sqlparse.tokens import DML

import sys

sql = sys.stdin.read()

def fields_to_update(st):
    def process(ts, table = '', is_update = False):
        for t in ts.tokens:
            match t:
                case Token(ttype=tt, value=v) if tt == DML and v.upper() == 'UPDATE':
                    is_update = True
                case IdentifierList() if is_update:
                    yield from process(t, table, is_update)
                case Identifier() if is_update:
                    table = t.get_real_name()
                case Comparison(left=Identifier() as l) if is_update:
                    yield f"{table}.{l.get_real_name()}"
    
    yield from process(st)

for s in sqlparse.parse(sql):
    fields = list(fields_to_update(s))
    print(fields)

動作確認

以下の SQL を使って動作確認してみます。

3.sql
UPDATE ITEMS SET price = 1000 WHERE id = '1';
SELECT * FROM sample.ITEMS WHERE price > 1000;

update
  sample.ITEMS as i 
set
  i.price = 200, 
  i.updated_at = NOW(), 
  i.rev = i.rev + 1
where
  id = '2';

delete from ITEMS where price <= 0;

実行結果は以下の通りです。

sample1.py 実行結果
$ python sample1.py < 3.sql
['ITEMS.price']
[]
['ITEMS.price', 'ITEMS.updated_at', 'ITEMS.rev']
[]
sample2.py 実行結果
$ python sample2.py < 3.sql
['ITEMS.price']
[]
['ITEMS.price', 'ITEMS.updated_at', 'ITEMS.rev']
[]

辞書ベースの日本語 Tokenizer - Kuromoji, Sudachi, Fugashi, Kagome, Lindera

辞書をベースに処理する日本語 Tokenizer のいくつかをコードを書いて実行してみました。

  • (a) Lucene Kuromoji
  • (b) atilika Kuromoji
  • (c) Sudachi
  • (d) Kuromoji.js
  • (e) Fugashi
  • (f) Kagome
  • (g) Lindera

今回は以下の文を処理して分割された単語と品詞を出力します。

処理対象文
 WebAssemblyがサーバーレス分野へ大きな影響を与えるだろうと答えた回答者は全体の56%だった。

システム辞書だけを使用し、分割モードを指定する場合は固有名詞などをそのままにする(細かく分割しない)モードを選ぶ事にします。

ソースコードhttps://github.com/fits/try_samples/tree/master/blog/20220106/

(a) Lucene Kuromoji

Lucene に組み込まれた Kuromoji で Elasticsearch や Solr で使われます。

kuromoji.gradle を見ると、システム辞書は以下のどちらかを使うようになっているようです。

a1

lucene-analyzers-kuromoji の JapaneseTokenizer を使います。 辞書は IPADIC のようです。

lucene/a1.groovy
@Grab('org.apache.lucene:lucene-analyzers-kuromoji:8.11.1')
import org.apache.lucene.analysis.ja.JapaneseTokenizer;
import org.apache.lucene.analysis.ja.JapaneseTokenizer.Mode
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute
import org.apache.lucene.analysis.ja.tokenattributes.PartOfSpeechAttribute

def text = args[0]

new JapaneseTokenizer(null, false, Mode.NORMAL).withCloseable { tokenizer ->
    def term = tokenizer.addAttribute(CharTermAttribute)
    def pos = tokenizer.addAttribute(PartOfSpeechAttribute)

    tokenizer.reader = new StringReader(text)
    tokenizer.reset()

    while(tokenizer.incrementToken()) {
        println "term=${term}, partOfSpeech=${pos.partOfSpeech}"
    }
}
a1 結果
> groovy a1.groovy "WebAssemblyがサーバーレス分野へ大きな影響を与えるだろうと答えた回答者は全体の56%だった。"

term=WebAssembly, partOfSpeech=名詞-固有名詞-組織
term=が, partOfSpeech=助詞-格助詞-一般
term=サーバー, partOfSpeech=名詞-一般
term=レス, partOfSpeech=名詞-サ変接続
term=分野, partOfSpeech=名詞-一般
term=へ, partOfSpeech=助詞-格助詞-一般
term=大きな, partOfSpeech=連体詞
term=影響, partOfSpeech=名詞-サ変接続
term=を, partOfSpeech=助詞-格助詞-一般
term=与える, partOfSpeech=動詞-自立
term=だろ, partOfSpeech=助動詞
term=う, partOfSpeech=助動詞
term=と, partOfSpeech=助詞-格助詞-引用
term=答え, partOfSpeech=動詞-自立
term=た, partOfSpeech=助動詞
term=回答, partOfSpeech=名詞-サ変接続
term=者, partOfSpeech=名詞-接尾-一般
term=は, partOfSpeech=助詞-係助詞
term=全体, partOfSpeech=名詞-副詞可能
term=の, partOfSpeech=助詞-連体化
term=5, partOfSpeech=名詞-数
term=6, partOfSpeech=名詞-数
term=%, partOfSpeech=名詞-接尾-助数詞
term=だっ, partOfSpeech=助動詞
term=た, partOfSpeech=助動詞
term=。, partOfSpeech=記号-句点

a2

org.codelibs が上記の ipadic-neologd 版を提供していたので、ついでに試してみました。

処理内容はそのままで、モジュールとパッケージ名を変えるだけです。

lucene/a2.groovy
@GrabResolver('https://maven.codelibs.org/')
@Grab('org.codelibs:lucene-analyzers-kuromoji-ipadic-neologd:8.2.0-20200120')
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute
import org.codelibs.neologd.ipadic.lucene.analysis.ja.JapaneseTokenizer
import org.codelibs.neologd.ipadic.lucene.analysis.ja.JapaneseTokenizer.Mode
import org.codelibs.neologd.ipadic.lucene.analysis.ja.tokenattributes.PartOfSpeechAttribute

def text = args[0]

new JapaneseTokenizer(null, false, Mode.NORMAL).withCloseable { tokenizer ->
    ・・・
}
a2 結果
> groovy a2.groovy "WebAssemblyがサーバーレス分野へ大きな影響を与えるだろうと答えた回答者は全体の56%だった。"

term=WebAssembly, partOfSpeech=名詞-固有名詞-組織
term=が, partOfSpeech=助詞-格助詞-一般
term=サーバーレス, partOfSpeech=名詞-固有名詞-一般
term=分野, partOfSpeech=名詞-一般
term=へ, partOfSpeech=助詞-格助詞-一般
term=大きな, partOfSpeech=連体詞
term=影響, partOfSpeech=名詞-サ変接続
term=を, partOfSpeech=助詞-格助詞-一般
term=与える, partOfSpeech=動詞-自立
term=だろ, partOfSpeech=助動詞
term=う, partOfSpeech=助動詞
term=と, partOfSpeech=助詞-格助詞-引用
term=答え, partOfSpeech=動詞-自立
term=た, partOfSpeech=助動詞
term=回答者, partOfSpeech=名詞-固有名詞-一般
term=は, partOfSpeech=助詞-係助詞
term=全体, partOfSpeech=名詞-副詞可能
term=の, partOfSpeech=助詞-連体化
term=5, partOfSpeech=名詞-数
term=6, partOfSpeech=名詞-数
term=%, partOfSpeech=名詞-接尾-助数詞
term=だっ, partOfSpeech=助動詞
term=た, partOfSpeech=助動詞
term=。, partOfSpeech=記号-句点

a1 と違って "サーバーレス" や "回答者" となりました。

(b) atilika Kuromoji

https://github.com/atilika/kuromoji

Lucene Kuromoji のベースとなった Kuromoji。 更新は途絶えているようですが、色々な辞書に対応しています。

ここでは以下の 2種類の辞書を試してみました。

  • UniDic(2.1.2)
  • JUMAN(7.0-20130310)

b1

まずは、UniDic 版です。

kuromoji/b1.groovy
@Grab('com.atilika.kuromoji:kuromoji-unidic:0.9.0')
import com.atilika.kuromoji.unidic.Tokenizer

def text = args[0]
def tokenizer = new Tokenizer()

tokenizer.tokenize(args[0]).each {
    def pos = [
        it.partOfSpeechLevel1,
        it.partOfSpeechLevel2,
        it.partOfSpeechLevel3,
        it.partOfSpeechLevel4
    ]

    println "term=${it.surface}, partOfSpeech=${pos}"
}
b1 結果
> groovy b1.groovy "WebAssemblyがサーバーレス分野へ大きな影響を与えるだろうと答えた回答者は全体の56%だ った。"

term=WebAssembly, partOfSpeech=[名詞, 普通名詞, 一般, *]
term=が, partOfSpeech=[助詞, 格助詞, *, *]
term=サーバー, partOfSpeech=[名詞, 普通名詞, 一般, *]
term=レス, partOfSpeech=[名詞, 普通名詞, 一般, *]
term=分野, partOfSpeech=[名詞, 普通名詞, 一般, *]
term=へ, partOfSpeech=[助詞, 格助詞, *, *]
term=大きな, partOfSpeech=[連体詞, *, *, *]
term=影響, partOfSpeech=[名詞, 普通名詞, サ変可能, *]
term=を, partOfSpeech=[助詞, 格助詞, *, *]
term=与える, partOfSpeech=[動詞, 一般, *, *]
term=だろう, partOfSpeech=[助動詞, *, *, *]
term=と, partOfSpeech=[助詞, 格助詞, *, *]
term=答え, partOfSpeech=[動詞, 一般, *, *]
term=た, partOfSpeech=[助動詞, *, *, *]
term=回答, partOfSpeech=[名詞, 普通名詞, サ変可能, *]
term=者, partOfSpeech=[接尾辞, 名詞的, 一般, *]
term=は, partOfSpeech=[助詞, 係助詞, *, *]
term=全体, partOfSpeech=[名詞, 普通名詞, 一般, *]
term=の, partOfSpeech=[助詞, 格助詞, *, *]
term=5, partOfSpeech=[名詞, 数詞, *, *]
term=6, partOfSpeech=[名詞, 数詞, *, *]
term=%, partOfSpeech=[名詞, 普通名詞, 助数詞可能, *]
term=だっ, partOfSpeech=[助動詞, *, *, *]
term=た, partOfSpeech=[助動詞, *, *, *]
term=。, partOfSpeech=[補助記号, 句点, *, *]

"だろう" が分割されていないのが特徴。

b2

JUMAN 辞書版です。

kuromoji/b2.groovy
@Grab('com.atilika.kuromoji:kuromoji-jumandic:0.9.0')
import com.atilika.kuromoji.jumandic.Tokenizer

def text = args[0]
def tokenizer = new Tokenizer()

・・・
b2 結果
> groovy b2.groovy "WebAssemblyがサーバーレス分野へ大きな影響を与えるだろうと答えた回答者は全体の56%だ った。"

term=WebAssembly, partOfSpeech=[名詞, 組織名, *, *]
term=が, partOfSpeech=[助詞, 格助詞, *, *]
term=サーバーレス, partOfSpeech=[名詞, 人名, *, *]
term=分野, partOfSpeech=[名詞, 普通名詞, *, *]
term=へ, partOfSpeech=[助詞, 格助詞, *, *]
term=大きな, partOfSpeech=[連体詞, *, *, *]
term=影響, partOfSpeech=[名詞, サ変名詞, *, *]
term=を, partOfSpeech=[助詞, 格助詞, *, *]
term=与える, partOfSpeech=[動詞, *, 母音動詞, 基本形]
term=だろう, partOfSpeech=[助動詞, *, 助動詞だろう型, 基本形]
term=と, partOfSpeech=[助詞, 格助詞, *, *]
term=答えた, partOfSpeech=[動詞, *, 母音動詞, タ形]
term=回答, partOfSpeech=[名詞, サ変名詞, *, *]
term=者, partOfSpeech=[接尾辞, 名詞性名詞接尾辞, *, *]
term=は, partOfSpeech=[助詞, 副助詞, *, *]
term=全体, partOfSpeech=[名詞, 普通名詞, *, *]
term=の, partOfSpeech=[助詞, 接続助詞, *, *]
term=56, partOfSpeech=[名詞, 数詞, *, *]
term=%, partOfSpeech=[接尾辞, 名詞性名詞助数辞, *, *]
term=だった, partOfSpeech=[判定詞, *, 判定詞, ダ列タ形]
term=。, partOfSpeech=[特殊, 句点, *, *]

"サーバーレス"、"だろう"、"答えた"、"56"、"だった" が分割されていないのが特徴。 "サーバーレス" が人名となっているのは不思議。

(c) Sudachi

https://github.com/WorksApplications/Sudachi

Rust 版 もありますが、Java 版を使いました。

辞書は UniDic と NEologd をベースに調整したものらしく、3種類(Small, Core, Full)用意されています。

辞書が継続的にメンテナンスされており最新のものを使えるのが魅力だと思います。

ここではデフォルトの Core 辞書を使いました。(system_core.dic ファイルをカレントディレクトリに配置して実行)

  • Core 辞書(20211220 版)

また、Elasticsearch 用のプラグイン analysis-sudachi も用意されています。

sudachi/c1.groovy
@Grab('com.worksap.nlp:sudachi:0.5.3')
import com.worksap.nlp.sudachi.DictionaryFactory
import com.worksap.nlp.sudachi.Tokenizer

def text = args[0]

new DictionaryFactory().create().withCloseable { dic ->
    def tokenizer = dic.create()
    def ts = tokenizer.tokenize(Tokenizer.SplitMode.C, text)

    ts.each { t ->
        def pos = dic.getPartOfSpeechString(t.partOfSpeechId())

        println "term=${t.surface()}, partOfSpeech=${pos}"
    }
}
c1 結果
> groovy c1.groovy "WebAssemblyがサーバーレス分野へ大きな影響を与えるだろうと答えた回答者は全体の56%だっ た。"

term=WebAssembly, partOfSpeech=[名詞, 普通名詞, 一般, *, *, *]
term=が, partOfSpeech=[助詞, 格助詞, *, *, *, *]
term=サーバーレス, partOfSpeech=[名詞, 普通名詞, 一般, *, *, *]
term=分野, partOfSpeech=[名詞, 普通名詞, 一般, *, *, *]
term=へ, partOfSpeech=[助詞, 格助詞, *, *, *, *]
term=大きな, partOfSpeech=[連体詞, *, *, *, *, *]
term=影響, partOfSpeech=[名詞, 普通名詞, サ変可能, *, *, *]
term=を, partOfSpeech=[助詞, 格助詞, *, *, *, *]
term=与える, partOfSpeech=[動詞, 一般, *, *, 下一段-ア行, 終止形-一般]
term=だろう, partOfSpeech=[助動詞, *, *, *, 助動詞-ダ, 意志推量形]
term=と, partOfSpeech=[助詞, 格助詞, *, *, *, *]
term=答え, partOfSpeech=[動詞, 一般, *, *, 下一段-ア行, 連用形-一般]
term=た, partOfSpeech=[助動詞, *, *, *, 助動詞-タ, 連体形-一般]
term=回答者, partOfSpeech=[名詞, 普通名詞, 一般, *, *, *]
term=は, partOfSpeech=[助詞, 係助詞, *, *, *, *]
term=全体, partOfSpeech=[名詞, 普通名詞, 一般, *, *, *]
term=の, partOfSpeech=[助詞, 格助詞, *, *, *, *]
term=56, partOfSpeech=[名詞, 数詞, *, *, *, *]
term=%, partOfSpeech=[名詞, 普通名詞, 助数詞可能, *, *, *]
term=だっ, partOfSpeech=[助動詞, *, *, *, 助動詞-ダ, 連用形-促音便]
term=た, partOfSpeech=[助動詞, *, *, *, 助動詞-タ, 終止形-一般]
term=。, partOfSpeech=[補助記号, 句点, *, *, *, *]

"サーバーレス"、"だろう"、"回答者"、"56" となっているのが特徴。

(d) Kuromoji.js

https://github.com/takuyaa/kuromoji.js/

Kuromoji の JavaScript 実装。

kuromoji.js/d1.mjs
import kuromoji from 'kuromoji'

const dicPath = 'node_modules/kuromoji/dict'
const text = process.argv[2]

kuromoji.builder({ dicPath }).build((err, tokenizer) => {
    if (err) {
        console.error(err)
        return
    }

    const ts = tokenizer.tokenize(text)

    for (const t of ts) {
        const pos = [t.pos, t.pos_detail_1, t.pos_detail_2, t.pos_detail_3]

        console.log(`term=${t.surface_form}, partOfSpeech=${pos}`)
    }
})
d1 結果
> node d1.mjs "WebAssemblyがサーバーレス分野へ大きな影響を与えるだろうと答えた回答者は全体の56%だった。"

term=WebAssembly, partOfSpeech=名詞,固有名詞,組織,*
term=が, partOfSpeech=助詞,格助詞,一般,*
term=サーバー, partOfSpeech=名詞,一般,*,*
term=レス, partOfSpeech=名詞,サ変接続,*,*
term=分野, partOfSpeech=名詞,一般,*,*
term=へ, partOfSpeech=助詞,格助詞,一般,*
term=大きな, partOfSpeech=連体詞,*,*,*
term=影響, partOfSpeech=名詞,サ変接続,*,*
term=を, partOfSpeech=助詞,格助詞,一般,*
term=与える, partOfSpeech=動詞,自立,*,*
term=だろ, partOfSpeech=助動詞,*,*,*
term=う, partOfSpeech=助動詞,*,*,*
term=と, partOfSpeech=助詞,格助詞,引用,*
term=答え, partOfSpeech=動詞,自立,*,*
term=た, partOfSpeech=助動詞,*,*,*
term=回答, partOfSpeech=名詞,サ変接続,*,*
term=者, partOfSpeech=名詞,接尾,一般,*
term=は, partOfSpeech=助詞,係助詞,*,*
term=全体, partOfSpeech=名詞,副詞可能,*,*
term=の, partOfSpeech=助詞,連体化,*,*
term=5, partOfSpeech=名詞,数,*,*
term=6, partOfSpeech=名詞,数,*,*
term=%, partOfSpeech=名詞,接尾,助数詞,*
term=だっ, partOfSpeech=助動詞,*,*,*
term=た, partOfSpeech=助動詞,*,*,*
term=。, partOfSpeech=記号,句点,*,*

a1 と同じ結果になりました。

(e) Fugashi

https://github.com/polm/fugashi

MeCabPython 用ラッパー。

辞書として unidic-lite と unidic のパッケージが用意されていましたが、 ここでは JUMAN 辞書を使いました。

  • JUMAN
fugashi/e1.py
from fugashi import GenericTagger
import sys

text = sys.argv[1]

tagger = GenericTagger()

for t in tagger(text):
    pos = t.feature[0:4]
    print(f"term={t.surface}, partOfSpeech={pos}")
e1 結果
> python e1.py "WebAssemblyがサーバーレス分野へ大きな影響を与えるだろうと答えた回答者は全体の56%だった。"

term=WebAssembly, partOfSpeech=('名詞', '組織名', '*', '*')
term=が, partOfSpeech=('助詞', '格助詞', '*', '*')
term=サーバーレス, partOfSpeech=('名詞', '人名', '*', '*')
term=分野, partOfSpeech=('名詞', '普通名詞', '*', '*')
term=へ, partOfSpeech=('助詞', '格助詞', '*', '*')
term=大きな, partOfSpeech=('連体詞', '*', '*', '*')
term=影響, partOfSpeech=('名詞', 'サ変名詞', '*', '*')
term=を, partOfSpeech=('助詞', '格助詞', '*', '*')
term=与える, partOfSpeech=('動詞', '*', '母音動詞', '基本形')
term=だろう, partOfSpeech=('助動詞', '*', '助動詞だろう型', '基本形')
term=と, partOfSpeech=('助詞', '格助詞', '*', '*')
term=答えた, partOfSpeech=('動詞', '*', '母音動詞', 'タ形')
term=回答, partOfSpeech=('名詞', 'サ変名詞', '*', '*')
term=者, partOfSpeech=('接尾辞', '名詞性名詞接尾辞', '*', '*')
term=は, partOfSpeech=('助詞', '副助詞', '*', '*')
term=全体, partOfSpeech=('名詞', '普通名詞', '*', '*')
term=の, partOfSpeech=('助詞', '接続助詞', '*', '*')
term=56, partOfSpeech=('名詞', '数詞', '*', '*')
term=%, partOfSpeech=('接尾辞', '名詞性名詞助数辞', '*', '*')
term=だった, partOfSpeech=('判定詞', '*', '判定詞', 'ダ列タ形')
term=。, partOfSpeech=('特殊', '句点', '*', '*')

同じ辞書を使っている b2 と同じ結果になりました。("サーバーレス" が人名なのも同じ)。

(f) Kagome

https://github.com/ikawaha/kagome

ここでは以下の辞書を使用します。

  • IPADIC(mecab-ipadic-2.7.0-20070801)
  • UniDic(2.1.2)

なお、Tokenize を呼び出した場合は、分割モードとして Normal が適用されるようです。

f1

IPADIC 版

kagome/f1.go
package main

import (
    "fmt"
    "os"

    "github.com/ikawaha/kagome-dict/ipa"
    "github.com/ikawaha/kagome/v2/tokenizer"
)

func main() {
    text := os.Args[1]

    t, err := tokenizer.New(ipa.Dict(), tokenizer.OmitBosEos())

    if err != nil {
        panic(err)
    }
    // 分割モード Normal
    ts := t.Tokenize(text)

    for _, t := range ts {
        fmt.Printf("term=%s, partOfSpeech=%v\n", t.Surface, t.POS())
    }
}
f1 結果
> go run f1.go "WebAssemblyがサーバーレス分野へ大きな影響を与えるだろうと答えた回答者は全体の56%だった。"

term=WebAssembly, partOfSpeech=[名詞 固有名詞 組織 *]
term=が, partOfSpeech=[助詞 格助詞 一般 *]
term=サーバー, partOfSpeech=[名詞 一般 * *]
term=レス, partOfSpeech=[名詞 サ変接続 * *]
term=分野, partOfSpeech=[名詞 一般 * *]
term=へ, partOfSpeech=[助詞 格助詞 一般 *]
term=大きな, partOfSpeech=[連体詞 * * *]
term=影響, partOfSpeech=[名詞 サ変接続 * *]
term=を, partOfSpeech=[助詞 格助詞 一般 *]
term=与える, partOfSpeech=[動詞 自立 * *]
term=だろ, partOfSpeech=[助動詞 * * *]
term=う, partOfSpeech=[助動詞 * * *]
term=と, partOfSpeech=[助詞 格助詞 引用 *]
term=答え, partOfSpeech=[動詞 自立 * *]
term=た, partOfSpeech=[助動詞 * * *]
term=回答, partOfSpeech=[名詞 サ変接続 * *]
term=者, partOfSpeech=[名詞 接尾 一般 *]
term=は, partOfSpeech=[助詞 係助詞 * *]
term=全体, partOfSpeech=[名詞 副詞可能 * *]
term=の, partOfSpeech=[助詞 連体化 * *]
term=5, partOfSpeech=[名詞 数 * *]
term=6, partOfSpeech=[名詞 数 * *]
term=%, partOfSpeech=[名詞 接尾 助数詞 *]
term=だっ, partOfSpeech=[助動詞 * * *]
term=た, partOfSpeech=[助動詞 * * *]
term=。, partOfSpeech=[記号 句点 * *]

同じ辞書を使っている a1 と同じ結果になりました。

f2

UniDic 版

kagome/f2.go
package main

import (
    "fmt"
    "os"

    "github.com/ikawaha/kagome-dict/uni"
    "github.com/ikawaha/kagome/v2/tokenizer"
)

func main() {
    text := os.Args[1]

    t, err := tokenizer.New(uni.Dict(), tokenizer.OmitBosEos())

    ・・・
}
f2 結果
> go run f2.go "WebAssemblyがサーバーレス分野へ大きな影響を与えるだろうと答えた回答者は全体の56%だった。"

term=WebAssembly, partOfSpeech=[名詞 普通名詞 一般 *]
term=が, partOfSpeech=[助詞 格助詞 * *]
term=サーバー, partOfSpeech=[名詞 普通名詞 一般 *]
term=レス, partOfSpeech=[名詞 普通名詞 一般 *]
term=分野, partOfSpeech=[名詞 普通名詞 一般 *]
term=へ, partOfSpeech=[助詞 格助詞 * *]
term=大きな, partOfSpeech=[連体詞 * * *]
term=影響, partOfSpeech=[名詞 普通名詞 サ変可能 *]
term=を, partOfSpeech=[助詞 格助詞 * *]
term=与える, partOfSpeech=[動詞 一般 * *]
term=だろう, partOfSpeech=[助動詞 * * *]
term=と, partOfSpeech=[助詞 格助詞 * *]
term=答え, partOfSpeech=[動詞 一般 * *]
term=た, partOfSpeech=[助動詞 * * *]
term=回答, partOfSpeech=[名詞 普通名詞 サ変可能 *]
term=者, partOfSpeech=[接尾辞 名詞的 一般 *]
term=は, partOfSpeech=[助詞 係助詞 * *]
term=全体, partOfSpeech=[名詞 普通名詞 一般 *]
term=の, partOfSpeech=[助詞 格助詞 * *]
term=5, partOfSpeech=[名詞 数詞 * *]
term=6, partOfSpeech=[名詞 数詞 * *]
term=%, partOfSpeech=[名詞 普通名詞 助数詞可能 *]
term=だっ, partOfSpeech=[助動詞 * * *]
term=た, partOfSpeech=[助動詞 * * *]
term=。, partOfSpeech=[補助記号 句点 * *]

同じ辞書を使っている b1 と同じ結果になりました。

(g) Lindera

https://github.com/lindera-morphology/lindera

kuromoji-rs のフォークのようで、辞書は IPADIC です。

  • IPADIC(mecab-ipadic-2.7.0-20070801)
lindera/src/main.rs
use lindera::tokenizer::Tokenizer;
use lindera_core::LinderaResult;

use std::env;

fn main() -> LinderaResult<()> {
    let text = env::args().nth(1).unwrap_or("".to_string());

    let mut tokenizer = Tokenizer::new()?;
    let ts = tokenizer.tokenize(&text)?;

    for t in ts {
        let pos = t.detail.get(0..4).unwrap_or(&t.detail);

        println!("text={}, partOfSpeech={:?}", t.text, pos);
    }

    Ok(())
}
結果
> cargo run "WebAssemblyがサーバーレス分野へ大きな影響を与えるだろうと答えた回答者は全体の56%だった。"

・・・
text=WebAssembly, partOfSpeech=["UNK"]
text=が, partOfSpeech=["助詞", "格助詞", "一般", "*"]
text=サーバー, partOfSpeech=["名詞", "一般", "*", "*"]
text=レス, partOfSpeech=["名詞", "サ変接続", "*", "*"]
text=分野, partOfSpeech=["名詞", "一般", "*", "*"]
text=へ, partOfSpeech=["助詞", "格助詞", "一般", "*"]
text=大きな, partOfSpeech=["連体詞", "*", "*", "*"]
text=影響, partOfSpeech=["名詞", "サ変接続", "*", "*"]
text=を, partOfSpeech=["助詞", "格助詞", "一般", "*"]
text=与える, partOfSpeech=["動詞", "自立", "*", "*"]
text=だろ, partOfSpeech=["助動詞", "*", "*", "*"]
text=う, partOfSpeech=["助動詞", "*", "*", "*"]
text=と, partOfSpeech=["助詞", "格助詞", "引用", "*"]
text=答え, partOfSpeech=["動詞", "自立", "*", "*"]
text=た, partOfSpeech=["助動詞", "*", "*", "*"]
text=回答, partOfSpeech=["名詞", "サ変接続", "*", "*"]
text=者, partOfSpeech=["名詞", "接尾", "一般", "*"]
text=は, partOfSpeech=["助詞", "係助詞", "*", "*"]
text=全体, partOfSpeech=["名詞", "副詞可能", "*", "*"]
text=の, partOfSpeech=["助詞", "連体化", "*", "*"]
text=5, partOfSpeech=["名詞", "数", "*", "*"]
text=6, partOfSpeech=["名詞", "数", "*", "*"]
text=%, partOfSpeech=["名詞", "接尾", "助数詞", "*"]
text=だっ, partOfSpeech=["助動詞", "*", "*", "*"]
text=た, partOfSpeech=["助動詞", "*", "*", "*"]
text=。, partOfSpeech=["記号", "句点", "*", "*"]

a1 と概ね同じ結果となりましたが、"WebAssembly" が名詞になっていないのが特徴。

RLlib を使ってナップサック問題を強化学習2

局所最適に陥っていたと思われる 前回 に対して、以下の改善案 ※ を思いついたので試してみました。

  • より困難な目標を達成した場合に報酬(価値)へボーナスを加算
 ※ 局所最適から脱して、より良い結果を目指す効果を期待

今回のサンプルコードは http://github.com/fits/try_samples/tree/master/blog/20201019/

サンプル1 改良版(ボーナス加算)

単一操作(品物の 1つを -1 or +1 するか何もしない)を行動とした(前回の)サンプル1 にボーナスを加算する処理を加えてみました。

とりあえず、価値の合計が 375(0-1 ナップサック問題としての最適解)を超えた場合に報酬へ +200 するようにしてみます。

前回から、変数や関数名を一部変更していますが、基本的な処理内容に変更はありません。

また、PPOTrainer では episode_reward_mean / vf_clip_param の値が 200 を超えると警告ログを出すようなので(ppo.pywarn_about_bad_reward_scales)、config で vf_clip_param(デフォルト値は 10)の値を変更するようにしています。

sample1_bonus.ipynb
・・・
def next_state(items, state, action):
    idx = action // 2
    act = action % 2

    if idx < len(items):
        state[idx] += (1 if act == 1 else -1)

    return state

def calc_value(items, state, max_weight, burst_value):
    reward = 0
    weight = 0
    
    for i in range(len(state)):
        reward += items[i][0] * state[i]
        weight += items[i][1] * state[i]
    
    if weight > max_weight or min(state) < 0:
        reward = burst_value
    
    return reward, weight

class Knapsack(gym.Env):
    def __init__(self, config):
        self.items = config["items"]
        self.max_weight = config["max_weight"]
        self.episode_steps = config["episode_steps"]
        self.burst_reward = config["burst_reward"]
        self.bonus_rules = config["bonus_rules"]
        
        n = self.episode_steps
        
        self.action_space = Discrete(len(self.items) * 2 + 1)
        self.observation_space = Box(low = -n, high = n, shape = (len(self.items), ))
        
        self.reset()

    def reset(self):
        self.current_steps = 0
        self.state = [0 for _ in self.items]
        
        return self.state

    def step(self, action):
        self.state = next_state(self.items, self.state, action)
        
        r, _ = calc_value(self.items, self.state, self.max_weight, self.burst_reward)
        reward = r
        
        # 段階的なボーナス加算
        for (v, b) in self.bonus_rules:
            if r > v:
                reward += b
        
        self.current_steps += 1
        done = self.current_steps >= self.episode_steps
        
        return self.state, reward, done, {}

items = [
    [105, 10],
    [74, 7],
    [164, 15],
    [32, 3],
    [235, 22]
]

config = {
    "env": Knapsack, 
    "vf_clip_param": 60,
    "env_config": {
        "items": items, "episode_steps": 10, "max_weight": 35, "burst_reward": -100, 
        # ボーナスの設定
        "bonus_rules": [ (375, 200) ]
    }
}

・・・

trainer = PPOTrainer(config = config)

・・・
# 30回の学習
for _ in range(30):
    r = trainer.train()
    ・・・

・・・

rs = []

# 1000回試行
for _ in range(1000):
    
    s = [0 for _ in range(len(items))]
    r_tmp = config["env_config"]["burst_reward"]

    for _ in range(config["env_config"]["episode_steps"]):
        a = trainer.compute_action(s)
        s = next_state(items, s, a)

        r, w = calc_value(items, s, config["env_config"]["max_weight"], config["env_config"]["burst_reward"])
        
        r_tmp = max(r, r_tmp)

    rs.append(r_tmp)

collections.Counter(rs)

上記の結果(30回の学習後に 1000回試行してそれぞれの最高値をカウント)は以下のようになりました。

結果
Counter({376: 957, 375: 42, 334: 1})

最適解の 376 が出るようになっており、ボーナスの効果はそれなりにありそうです。

ただし、毎回このような結果になるわけではなく、前回と同じように 375(0-1 ナップサック問題としての最適解)止まりとなる場合もあります。

検証

次に、ナップサック問題の内容を変えて検証してみます。

ここでは、「2.5 ナップサック問題 - 数理システム」の例題を題材として、状態の範囲やボーナスの内容を変えると結果にどのような差が生じるのかを確認します。

ナップサック問題の内容

価値 サイズ
120 10
130 12
80 7
100 9
250 21
185 16

最大容量(サイズ) 65 における最適解は以下の通りです。

価値 770 の組み合わせ(最適解)
3, 0, 2, 0, 1, 0
価値 745 の組み合わせ(0-1 ナップサック問題の最適解)
0, 1, 1, 1, 1, 1

1. 単一操作(品物の 1つを -1 or +1 するか何もしない)

行動は前回の サンプル1 と同様の以下とします。

  • 品物のどれか 1つを -1 or +1 するか、何も変更しない

ここでは、以下のような状態範囲とボーナスを試しました。

状態範囲(品物毎の個数の範囲)
状態タイプ 最小値 最大値
a -10 10
b 0 5
c 0 3
ボーナス定義
ボーナス定義タイプ v > 750 v > 760 v > 765
0 0 0 0
1 100 100 100
2 100 200 400

ボーナスは段階的に加算し、ボーナス定義タイプ 2 で価値が仮に 770 だった場合は、700(100 + 200 + 400)を加算する事にします。

また、状態(品物毎の個数)はその範囲を超えないよう最小値もしくは最大値で止まるようにしました。

なお、ここからは Jupyter Notebook ではなく Python スクリプトとして実行します。

test1.py
import sys
import numpy as np

import gym
from gym.spaces import Discrete, Box

import ray
from ray.rllib.agents.ppo import PPOTrainer

import collections

N = int(sys.argv[1])
EPISODE_STEPS = int(sys.argv[2])
STATE_TYPE = sys.argv[3]
BONUS_TYPE = sys.argv[4]

items = [
    [120, 10],
    [130, 12],
    [80, 7],
    [100, 9],
    [250, 21],
    [185, 16]
]

state_types = {
    "a": (-10, 10),
    "b": (0, 5),
    "c": (0, 3)
}

bonus_types = {
    "0": [],
    "1": [(750, 100), (760, 100), (765, 100)],
    "2": [(750, 100), (760, 200), (765, 400)]
}

vf_clip_params = {
    "0": 800,
    "1": 1100,
    "2": 1500
}

def next_state(items, state, action, state_range):
    idx = action // 2
    act = action % 2

    if idx < len(items):
        v = state[idx] + (1 if act == 1 else -1)
        # 状態が範囲内に収まるように調整
        state[idx] = min(state_range[1], max(state_range[0], v))

    return state

def calc_value(items, state, max_weight, burst_value):
    reward = 0
    weight = 0
    
    for i in range(len(state)):
        reward += items[i][0] * state[i]
        weight += items[i][1] * state[i]
    
    if weight > max_weight or min(state) < 0:
        reward = burst_value
    
    return reward, weight

class Knapsack(gym.Env):
    def __init__(self, config):
        self.items = config["items"]
        self.max_weight = config["max_weight"]
        self.episode_steps = config["episode_steps"]
        self.burst_reward = config["burst_reward"]
        self.state_range = config["state_range"]
        self.bonus_rules = config["bonus_rules"]
        
        self.action_space = Discrete(len(self.items) * 2 + 1)
        
        self.observation_space = Box(
            low = self.state_range[0], 
            high = self.state_range[1], 
            shape = (len(self.items), )
        )

        self.reset()

    def reset(self):
        self.current_steps = 0
        self.state = [0 for _ in self.items]
        
        return self.state

    def step(self, action):
        self.state = next_state(self.items, self.state, action, self.state_range)
        
        r, _ = calc_value(self.items, self.state, self.max_weight, self.burst_reward)
        reward = r

        for (v, b) in self.bonus_rules:
            if r > v:
                reward += b
        
        self.current_steps += 1
        done = self.current_steps >= self.episode_steps
        
        return self.state, reward, done, {}

config = {
    "env": Knapsack, 
    "vf_clip_param": vf_clip_params[BONUS_TYPE],
    "env_config": {
        "items": items, "max_weight": 65, "burst_reward": -100, 
        "episode_steps": EPISODE_STEPS, 
        "state_range": state_types[STATE_TYPE], 
        "bonus_rules": bonus_types[BONUS_TYPE]
    }
}

ray.init()

trainer = PPOTrainer(config = config)

for _ in range(N):
    r = trainer.train()
    print(f'iter = {r["training_iteration"]}')

print(f'N = {N}, EPISODE_STEPS = {EPISODE_STEPS}, state_type = {STATE_TYPE}, bonus_type = {BONUS_TYPE}')

rs = []

for _ in range(1000):
    s = [0 for _ in range(len(items))]
    r_tmp = config["env_config"]["burst_reward"]

    for _ in range(config["env_config"]["episode_steps"]):
        a = trainer.compute_action(s)
        s = next_state(items, s, a, config["env_config"]["state_range"])

        r, w = calc_value(
            items, s, 
            config["env_config"]["max_weight"], config["env_config"]["burst_reward"]
        )
        
        r_tmp = max(r, r_tmp)

    rs.append(r_tmp)

print( collections.Counter(rs) )

ray.shutdown()
実行例
> python test1.py 50 10 a 0

学習回数 50、1エピソードのステップ数 10 で学習した後、1000回の試行で最も件数の多かった価値を列挙する処理を 3回実施した結果です。(() 内の数値は 1000回の内にその値が最高値だった件数)

結果(学習回数 = 50、エピソードのステップ数 = 10)
状態タイプ ボーナス定義タイプ 状態の最小値 状態の最大値 765 超過時の総ボーナス 1回目 2回目 3回目
a-0 a 0 -10 10 0 735 (994) 735 (935) 735 (916)
a-1 a 1 -10 10 +300 745 (965) 745 (977) 735 (976)
a-2 a 2 -10 10 +700 735 (945) 735 (971) 770 (1000)
b-0 b 0 0 5 0 750 (931) 750 (829) 750 (1000)
b-1 b 1 0 5 +300 765 (995) 765 (998) 750 (609)
b-2 b 2 0 5 +700 765 (1000) 765 (995) 765 (998)
c-0 c 0 0 3 0 750 (998) 750 (996) 750 (1000)
c-1 c 1 0 3 +300 765 (1000) 750 (993) 765 (1000)
c-2 c 2 0 3 +700 765 (999) 765 (1000) 770 (999)

やはり、ボーナスは有効そうですが、状態タイプ a のように状態の範囲が広く、ボーナスの発生頻度が低くなるようなケースでは有効に働かない可能性も高そうです。

ボーナス定義タイプ 2 で最適解の 770 が出るようになっているものの、頻出するようなものでも無く、たまたま学習が上手くいった場合にのみ発生しているような印象でした。

また、b-1 の 3回目で件数が他と比べて低くなっていますが、こちらは学習が(順調に進まずに)足りていない状態だと考えられます。

2. 一括操作(全品物をそれぞれ -1 or 0 or +1 する)

次に、行動を以下のように変えて同じように検証してみます。

  • 全ての品物を対象にそれぞれを -1 or 0 or +1 する

こちらは、ボーナス加算タイプを 1種類追加しました。

状態範囲(品物毎の個数の範囲)
状態タイプ 最小値 最大値
a -10 10
b 0 5
c 0 3
ボーナス加算
ボーナス加算タイプ v > 750 v > 760 v > 765
0 0 0 0
1 100 100 100
2 100 200 400
3 200 400 800

行動の変更に伴い action_spaceBox で定義しています。

test2.py
・・・

state_types = {
    "a": (-10, 10),
    "b": (0, 5),
    "c": (0, 3)
}

bonus_types = {
    "0": [],
    "1": [(750, 100), (760, 100), (765, 100)],
    "2": [(750, 100), (760, 200), (765, 400)],
    "3": [(750, 200), (760, 400), (765, 800)]
}

・・・

def next_state(items, state, action, state_range):
    for i in range(len(action)):
        v = state[i] + round(action[i])
        state[i] = min(state_range[1], max(state_range[0], v))

    return state

・・・

class Knapsack(gym.Env):
    def __init__(self, config):
        self.items = config["items"]
        self.max_weight = config["max_weight"]
        self.episode_steps = config["episode_steps"]
        self.burst_reward = config["burst_reward"]
        self.state_range = config["state_range"]
        self.bonus_rules = config["bonus_rules"]
        # 品物毎の -1 ~ 1
        self.action_space = Box(low = -1, high = 1, shape = (len(self.items), ))
        
        self.observation_space = Box(
            low = self.state_range[0], 
            high = self.state_range[1], 
            shape = (len(self.items), )
        )

        self.reset()

    def reset(self):
        self.current_steps = 0
        self.state = [0 for _ in self.items]
        
        return self.state

    def step(self, action):
        self.state = next_state(self.items, self.state, action, self.state_range)
        
        r, _ = calc_value(self.items, self.state, self.max_weight, self.burst_reward)
        reward = r

        for (v, b) in self.bonus_rules:
            if r > v:
                reward += b
        
        self.current_steps += 1
        done = self.current_steps >= self.episode_steps
        
        return self.state, reward, done, {}

・・・

こちらの方法では、学習回数 50 では明らかに足りなかったので 100 にして実施しました。

結果(学習回数 = 100、エピソードのステップ数 = 10)
状態タイプ ボーナス加算タイプ 状態の最小値 状態の最大値 765 超過時の総ボーナス 1回目 2回目 3回目
a-0 a 0 -10 10 0 735 (477) 735 (531) 735 (714)
a-1 a 1 -10 10 +300 735 (689) 735 (951) 745 (666)
a-2 a 2 -10 10 +700 735 (544) 735 (666) 735 (719)
a-3 a 3 -10 10 +1400 745 (633) 735 (735) 735 (875)
b-0 b 0 0 5 0 735 (364) 760 (716) 740 (590)
b-1 b 1 0 5 +300 735 (935) 760 (988) 655 (685)
b-2 b 2 0 5 +700 760 (1000) 735 (310) 770 (963)
b-3 b 3 0 5 +1400 675 (254) 770 (1000) 770 (909)
c-0 c 0 0 3 0 735 (762) 740 (975) 740 (669)
c-1 c 1 0 3 +300 740 (935) 740 (842) 735 (963)
c-2 c 2 0 3 +700 770 (999) 770 (1000) 715 (508)
c-3 c 3 0 3 +1400 770 (1000) 770 (1000) 770 (1000)

学習の足りていない所が散見されますが(学習も不安定)、特定のタイプで最適解の 770 が割と頻繁に出るようになりました。

ただ、c-3 の場合でも 770 が出やすくなっているものの、確実にそのように学習するわけではありませんでした。

結局のところ、状態・行動・報酬の設計次第という事かもしれません。

RLlib を使ってナップサック問題を強化学習

ナップサック問題強化学習を適用すると、どうなるのか気になったので試してみました。

強化学習には、Ray に含まれている RLlib を使い、Jupyter Notebook 上で実行します。

今回のサンプルコードは http://github.com/fits/try_samples/tree/master/blog/20200922/

はじめに

以下のようにして Ray と RLlib をインストールしておきます。(TensorFlow も事前にインストールしておく)

Ray インストール
> pip install ray[rllib]

ナップサック問題

今回は、以下のような価値と重さを持った品物に対して、重さの合計が 35 以下で価値の合計を最大化する品物の組み合わせを探索する事にします。

価値 重さ
105 10
74 7
164 15
32 3
235 22

品物をそれぞれ 1個までしか選べない場合(0-1 ナップサック問題)の最適な組み合わせは、以下のように 5番目以外を 1個ずつ選ぶ事です。(価値の合計は 375

価値 375 の組み合わせ(0-1 ナップサック問題の最適解)
1, 1, 1, 1, 0

また、同じ品物をいくらでも選べる場合の最適な組み合わせは以下のようになります。(価値の合計は 376

価値 376 の組み合わせ
0, 2, 1, 2, 0

強化学習でこのような組み合わせを導き出す事ができるのか確認します。

1. サンプル1 - sample1.ipynb

とりあえず、強化学習における状態・行動・報酬を以下のようにしてみました。 エピソードは、指定した回数(今回は 10回)の行動を行う事で終了とします。

状態 行動 (即時)報酬
品物毎の個数 品物の個数を操作(-1, +1) 価値の合計

行動は以下のような 0 ~ 10 の数値で表現する事にします。

  • 0 = 1番目の品物の個数を -1
  • 1 = 1番目の品物の個数を +1
  • ・・・
  • 8 = 5番目の品物の個数を -1
  • 9 = 5番目の品物の個数を +1
  • 10 = 個数を変更しない(現状維持)

これらを OpenAI Gym で定義したのが次のコードです。

環境は gym.Env を継承し、__init__ で行動空間(action_space)と状態空間(observation_space)を定義、reset で状態の初期化、step で状態の更新と報酬の算出を行うように実装します。

step の戻り値は、更新後の状態報酬エピソード終了か否か(デバッグ用途等の)情報 となっています。

Discrete(n) は 0 ~ n - 1 の整数値、Box は low ~ high の実数値の多次元配列となっており、、行動空間と状態空間の定義にそれぞれ使用しています。

環境定義
import gym
from gym.spaces import Discrete, Box

def next_state(items, state, action):
    idx = action // 2
    act = action % 2

    if idx < len(items):
        state[idx] += (1 if act == 1 else -1)

    return state

# 報酬の算出
def calc_reward(items, state, max_weight, burst_reward):
    reward = 0
    weight = 0
    
    for i in range(len(state)):
        reward += items[i][0] * state[i]
        weight += items[i][1] * state[i]
    
    if weight > max_weight or min(state) < 0:
        reward = burst_reward
    
    return reward, weight

class Knapsack(gym.Env):
    def __init__(self, config):
        self.items = config["items"]
        # 重さの上限値
        self.max_weight = config["max_weight"]
        # 行動の回数
        self.max_count = config["max_count"]
        # 重さが超過するか、個数が負の数となった場合の報酬
        self.burst_reward = config["burst_reward"]
        
        n = self.max_count
        
        # 行動空間の定義
        self.action_space = Discrete(len(self.items) * 2 + 1)
        # 状態空間の定義
        self.observation_space = Box(low = -n, high = n, shape = (len(self.items), ))
        
        self.reset()

    def reset(self):
        self.count = 0
        self.state = [0 for _ in self.items]
        
        return self.state

    def step(self, action):
        # 状態の更新
        self.state = next_state(self.items, self.state, action)
        # 報酬の算出
        reward, _ = calc_reward(self.items, self.state, self.max_weight, self.burst_reward)
        
        self.count += 1
        # エピソード完了の判定
        done = self.count >= self.max_count
        
        return self.state, reward, done, {}

次に上記環境のための設定を行います。 env_config の内容が __init__ の config 引数となります。

重さの上限値を超えた場合や個数が負の数となった場合の報酬(burst_reward)をとりあえず -100 としています。

なお、基本的に RLlib のデフォルト設定値を使う事にします。

設定
items = [
    [105, 10],
    [74, 7],
    [164, 15],
    [32, 3],
    [235, 22]
]

config = {
    "env": Knapsack, 
    "env_config": {"items": items, "max_count": 10, "max_weight": 35, "burst_reward": -100}
}

強化学習を実施する前に、Ray を初期化しておきます。

Ray 初期化
import ray

ray.init()

(a) PPO(Proximal Policy Optimization)

PPO アルゴリズムを試してみます。

トレーナーの定義 - PPO
from ray.rllib.agents.ppo import PPOTrainer

trainer = PPOTrainer(config = config)

まずは、学習(train の実行)を 10回繰り返してみます。

後で学習時の状況を確認するために episode_reward_max 等の値を保持するようにしています。

学習
r_max = []
r_min = []
r_mean = []
from ray.tune.logger import pretty_print

for _ in range(10):
    r = trainer.train()
    print(pretty_print(r))

    r_max.append(r["episode_reward_max"])   # 最大
    r_min.append(r["episode_reward_min"])   # 最小
    r_mean.append(r["episode_reward_mean"]) # 平均

以下のコードで結果を確認してみます。 compute_action を呼び出す事で、指定した状態に対する行動を取得できます。

評価1
s = [0 for _ in range(len(items))]

for _ in range(config["env_config"]["max_count"]):
    a = trainer.compute_action(s)
    
    s = next_state(items, s, a)
    
    r, w = calc_reward(items, s, config["env_config"]["max_weight"], config["env_config"]["burst_reward"])
    
    print(f"{a}, {s}, {r}, {w}")

下記のように、価値の合計が 375 となる組み合わせ(0-1 ナップサック問題とした場合の最適解)が現れており、ある程度は学習できているように見えます。

評価1の結果 - PPO 学習 10回
1, [1, 0, 0, 0, 0], 105, 10
3, [1, 1, 0, 0, 0], 179, 17
5, [1, 1, 1, 0, 0], 343, 32
7, [1, 1, 1, 1, 0], 375, 35
7, [1, 1, 1, 2, 0], -100, 38
6, [1, 1, 1, 1, 0], 375, 35
0, [0, 1, 1, 1, 0], 270, 25
1, [1, 1, 1, 1, 0], 375, 35
6, [1, 1, 1, 0, 0], 343, 32
7, [1, 1, 1, 1, 0], 375, 35

これだけだとよく分からないので、100回繰り返してそれぞれの報酬の最高値をカウントしてみます。

評価2
import collections

rs = []

for _ in range(100):
    
    s = [0 for _ in range(len(items))]
    r_tmp = config["env_config"]["burst_reward"]

    for _ in range(config["env_config"]["max_count"]):
        a = trainer.compute_action(s)
        s = next_state(items, s, a)

        r, w = calc_reward(items, s, config["env_config"]["max_weight"], config["env_config"]["burst_reward"])
        
        r_tmp = max(r, r_tmp)

    rs.append(r_tmp)

collections.Counter(rs)

結果は以下のようになりました。 最高値の 376 ではなく 375 の組み合わせへ向かうように学習が進んでいるように見えます。

評価2の結果 - PPO 学習 10回
Counter({302: 6,
         309: 2,
         284: 1,
         343: 12,
         375: 39,
         270: 4,
         341: 7,
         373: 1,
         340: 2,
         269: 1,
         372: 6,
         301: 4,
         376: 1,
         334: 2,
         333: 2,
         344: 1,
         299: 4,
         317: 1,
         275: 1,
         349: 1,
         316: 1,
         312: 1})

更に、学習を 10回繰り返した後の結果です。 375 へ向かって収束しているように見えます。

評価1の結果 - PPO 学習 20回
1, [1, 0, 0, 0, 0], 105, 10
5, [1, 0, 1, 0, 0], 269, 25
3, [1, 1, 1, 0, 0], 343, 32
7, [1, 1, 1, 1, 0], 375, 35
10, [1, 1, 1, 1, 0], 375, 35
10, [1, 1, 1, 1, 0], 375, 35
10, [1, 1, 1, 1, 0], 375, 35
2, [1, 0, 1, 1, 0], 301, 28
3, [1, 1, 1, 1, 0], 375, 35
0, [0, 1, 1, 1, 0], 270, 25
評価2の結果 - PPO 学習 20回
Counter({375: 69, 373: 18, 376: 3, 341: 5, 372: 1, 302: 1, 344: 3})

50回学習した後の結果は以下のようになりました。

評価2の結果 - PPO 学習 50回
Counter({375: 98, 373: 2})

ここまでの(学習時の)報酬状況をグラフ化してみます。

学習回数と報酬のグラフ化
%matplotlib inline

import matplotlib.pyplot as plt

plt.plot(r_max, label = "reward_max", color = "red")
plt.plot(r_min, label = "reward_min", color = "green")
plt.plot(r_mean, label = "reward_mean", color = "blue")

plt.legend(loc = "upper left")

plt.ylim([-1000, 3700])
plt.ylabel("reward")

plt.show()
学習時の報酬グラフ - PPO 学習 50回

f:id:fits:20200922223349p:plain

エピソード内の報酬を高めていくには重量超過やマイナス個数を避けるのが重要、それには各品物の個数を 0 か 1 にしておくのが無難なため、0-1 ナップサック問題としての最適解へ向かっていくのかもしれません。

(b) DQN(Deep Q-Network)

比較のために DQN でも実行してみます。 PPOTrainer の代わりに DQNTrainer を使うだけです。

トレーナーの定義 - DQN
from ray.rllib.agents.dqn import DQNTrainer

trainer = DQNTrainer(config = config)

10回の学習では以下のような結果となりました。

評価1の結果 - DQN 学習 10回
5, [0, 0, 1, 0, 0], 164, 15
7, [0, 0, 1, 1, 0], 196, 18
2, [0, -1, 1, 1, 0], -100, 11
5, [0, -1, 2, 1, 0], -100, 26
2, [0, -2, 2, 1, 0], -100, 19
0, [-1, -2, 2, 1, 0], -100, 9
1, [0, -2, 2, 1, 0], -100, 19
3, [0, -1, 2, 1, 0], -100, 26
5, [0, -1, 3, 1, 0], -100, 41
6, [0, -1, 3, 0, 0], -100, 38
評価2の結果 - DQN 学習 10回
Counter({-100: 33,
         360: 1,
         105: 8,
         270: 2,
         372: 5,
         315: 1,
         0: 5,
         374: 1,
         74: 2,
         274: 1,
         235: 7,
         340: 3,
         32: 4,
         164: 5,
         238: 1,
         309: 2,
         106: 1,
         228: 2,
         169: 1,
         267: 3,
         343: 1,
         196: 1,
         331: 1,
         363: 1,
         254: 1,
         299: 1,
         317: 1,
         212: 1,
         301: 1,
         328: 1,
         138: 1,
         64: 1})

DQN は PPO に比べて episodes_total(学習で実施したエピソード数の合計)が 1/4 程度と少なかったので、40回まで実施してみました。

評価1の結果 - DQN 学習 40回
6, [0, 0, 0, -1, 0], -100, -3
7, [0, 0, 0, 0, 0], 0, 0
1, [1, 0, 0, 0, 0], 105, 10
4, [1, 0, -1, 0, 0], -100, -5
4, [1, 0, -2, 0, 0], -100, -20
8, [1, 0, -2, 0, -1], -100, -42
7, [1, 0, -2, 1, -1], -100, -39
1, [2, 0, -2, 1, -1], -100, -29
2, [2, -1, -2, 1, -1], -100, -36
4, [2, -1, -3, 1, -1], -100, -51
評価2の結果 - DQN 学習 40回
Counter({0: 5,
         328: 4,
         164: 8,
         -100: 28,
         340: 5,
         228: 1,
         309: 4,
         74: 1,
         238: 4,
         235: 7,
         32: 3,
         358: 2,
         343: 2,
         196: 2,
         269: 3,
         372: 2,
         365: 1,
         179: 2,
         374: 1,
         302: 1,
         106: 3,
         312: 1,
         105: 5,
         270: 2,
         148: 2,
         360: 1})

80回学習した結果です。

評価1 - DQN 学習 80回
5, [0, 0, 1, 0, 0], 164, 15
1, [1, 0, 1, 0, 0], 269, 25
7, [1, 0, 1, 1, 0], 301, 28
4, [1, 0, 0, 1, 0], 137, 13
0, [0, 0, 0, 1, 0], 32, 3
2, [0, -1, 0, 1, 0], -100, -4
7, [0, -1, 0, 2, 0], -100, -1
1, [1, -1, 0, 2, 0], -100, 9
0, [0, -1, 0, 2, 0], -100, -1
4, [0, -1, -1, 2, 0], -100, -16
評価2 - DQN 学習 80回
Counter({-100: 5,
         269: 8,
         164: 20,
         301: 3,
         211: 2,
         235: 7,
         105: 3,
         372: 2,
         333: 1,
         238: 2,
         316: 2,
         196: 2,
         253: 1,
         242: 2,
         312: 3,
         343: 3,
         340: 9,
         179: 2,
         267: 3,
         270: 1,
         74: 3,
         328: 9,
         309: 2,
         284: 1,
         137: 1,
         285: 1,
         360: 1,
         374: 1})

200回学習した結果です。

評価2 - DQN 学習 200回
Counter({-100: 8,
         238: 4,
         340: 10,
         0: 4,
         267: 6,
         106: 1,
         74: 7,
         105: 1,
         328: 4,
         235: 33,
         309: 9,
         228: 1,
         270: 2,
         32: 1,
         301: 1,
         196: 3,
         341: 2,
         269: 1,
         374: 1,
         372: 1})

学習時の報酬グラフは以下の通り、PPO のようにスムーズに学習が進んでおらず、DQN は本件に向いていないのかもしれません。

学習時の報酬グラフ - DQN 学習 200回

f:id:fits:20200922223834p:plain

これは、報酬のクリッピング ※ に因るものかとも思いましたが、RLlib における報酬クリッピングの設定 clip_rewards はデフォルトで None であり、DQN のデフォルト設定 (dqn.py の DEFAULT_CONFIG) においても有効化しているような箇所は見当たりませんでした。

他の箇所で実施しているのかもしれませんが、今回は確認できませんでした。

※ 基本的には、
   元の報酬の符号に合わせて 1, -1, 0 のいずれかに報酬の値が変更されてしまい、
   報酬の大小は考慮されなくなる

   本件であれば、
   重量超過やマイナス個数のように報酬がマイナスにならない行動であれば
   どれでも良い事になってしまうと思われる

ちなみに、clip_rewardsTrue に設定した PPOTrainer を使ってみたところ、結果が著しく悪化しました。(マイナス報酬を避けるだけになった)

2. サンプル2 - sample2.ipynb

次に、行動によってのみ次の状態を決定するようにしてみます。 エピソードは 1回の行動で終了とします。

状態 行動 (即時)報酬
品物毎の個数 品物毎の個数 価値の合計

この場合、action_spaceBox で定義する事になります。 最軽量の品物が重要制限内で選択できる最大の個数を Box の最大値としています。

環境定義
def next_state(action):
    return [round(d) for d in action]

def calc_reward(items, state, max_weight, burst_reward):
    reward = 0
    weight = 0
    
    for i in range(len(state)):
        reward += items[i][0] * state[i]
        weight += items[i][1] * state[i]
    
    if weight > max_weight or min(state) < 0:
        reward = burst_reward
    
    return reward, weight

class Knapsack(gym.Env):
    def __init__(self, config):
        self.items = config["items"]
        self.max_weight = config["max_weight"]
        self.burst_reward = config["burst_reward"]
        # 個数の最大値
        n = self.max_weight // min(np.array(self.items)[:, 1])
        
        self.action_space = Box(0, n, shape = (len(self.items), ))
        self.observation_space = Box(0, n, shape = (len(self.items), ))

    def reset(self):
        return [0 for _ in self.items]

    def step(self, action):
        state = next_state(action)
        
        reward, _ = calc_reward(self.items, state, self.max_weight, self.burst_reward)
        
        return state, reward, True, {}
設定
items = [
    [105, 10],
    [74, 7],
    [164, 15],
    [32, 3],
    [235, 22]
]

config = {
    "env": Knapsack, 
    "env_config": {"items": items, "max_weight": 35, "burst_reward": -100}
}

(a) PPO(Proximal Policy Optimization)

PPO で実施してみます。

基本的な処理内容は 1. サンプル1 と同じですが、行動 1回でエピソードが終了するため、以下のコードで結果を確認する事にします。

評価
import collections

rs = []

for _ in range(100):
    s = [0 for _ in range(len(items))]
    a = trainer.compute_action(s)
    
    s = next_state(a)
    
    r, _ = calc_reward(items, s, config["env_config"]["max_weight"], config["env_config"]["burst_reward"])
    
    rs.append(r)

collections.Counter(rs)

70回学習後の結果です。

評価結果 - PPO 学習 70回
Counter({375.0: 100})

学習時の状況は以下のようになりました。

学習時の報酬グラフ - PPO 学習 70回

f:id:fits:20200922224006p:plain

(b) DDPG(Deep Deterministic Policy Gradient)

action_space が Box の場合に、DQNTrainer が UnsupportedSpaceException エラーを発生させるので、DQN は使えませんでした。

そこで、DDPG を使ってみました。

トレーナーの定義 - DDPG
from ray.rllib.agents.ddpg import DDPGTrainer

trainer = DDPGTrainer(config = config)

こちらは、PPO 等と比べて処理に時間がかかる上に、80回学習しても順調とはいえない結果となりました。

学習時の報酬グラフ - DDPG 学習 80回

f:id:fits:20200922224049p:plain

評価結果 - DDPG 学習 80回
Counter({347.0: 18, -100: 79, 242.0: 3})

以下のように評価の回数を 1000 にして再度確認してみます。

for _ in range(1000):
    s = [0 for _ in range(len(items))]
    a = trainer.compute_action(s)
    
    ・・・

collections.Counter(rs)
評価結果(1000回) - DDPG 学習 80回
Counter({347.0: 232,
         -100: 718,
         242.0: 24,
         315.0: 14,
         210.0: 1,
         316.0: 7,
         274.0: 3,
         374.0: 1})

347(3, 0, 0, 1, 0)が多くなっているのはよく分かりませんが、DDPG も本件に向いていないのかもしれません。

多腕バンディット問題とトンプソンサンプリング

多腕バンディット問題に対してベイズ的な手法をとるトンプソンサンプリングに興味を惹かれたので、「テストの実行 - C# を使用した Thompson Sampling」 を参考に Python で実装してみました。

ソースは http://github.com/fits/try_samples/tree/master/blog/20190603/

実装

ここで実装するのは、当たった場合の報酬を 1、外れた場合の報酬を 0 とするベルヌーイ試行によるバンディット問題です。

この場合、ベータ分布を 当たり数 + 1外れ数 + 1 のパラメータで ※ サンプリングした結果からアクションを選ぶだけのようです。

 ※ +1 するのは 0 にしないための措置だと思う

引数 arms でアクションの選択肢とその当たる(報酬が得られる)確率を渡すようにしています。

import numpy as np

def thompson_sampling(arms, n = 1000):
    states = [(0, 0) for _ in arms]

    # "当たり数 + 1" と "外れ数 + 1" でベータ分布からサンプリングし
    # アクションを決定する処理
    action = lambda: np.argmax([np.random.beta(s[0] + 1, s[1] + 1) for s in states])

    for _ in range(n):
        # アクションの決定
        a = action()
        # 当たり・外れの判定(報酬の算出)
        r = 1 if np.random.rand() < arms[a] else 0

        # 当たり・外れ数の更新
        states[a] = (states[a][0] + r, states[a][1] + 1 - r)

    return states

上記の結果を出力する処理は以下です。

def summary(states):
    for s in states:
        print(f'win={s[0]}, lose={s[1]}, p={s[0] / sum(s)}')

実行

アクションと報酬の得られる確率が以下のような構成に対して、試行回数 1,000 で実行してみます。

  • (a) 0.2
  • (b) 0.5
  • (c) 0.7
実行例1
summary( thompson_sampling([0.2, 0.5, 0.7]) )
win=2, lose=7, p=0.2222222222222222
win=7, lose=12, p=0.3684210526315789
win=673, lose=299, p=0.6923868312757202
実行例2
summary( thompson_sampling([0.2, 0.5, 0.7]) )
win=1, lose=7, p=0.125
win=11, lose=13, p=0.4583333333333333
win=690, lose=278, p=0.7128099173553719

実行の度に結果は異なりますが、アクションが選択された回数(win と lose の合計)に注目してみると、報酬の得られる確率が最も高い 3番目 (c) に集中している事が分かります。

つまり、最も報酬の得られる(確率の高い)アクションを選んでいる事になります。

実装2

報酬の得られる確率を thompson_sampling 関数の引数として与える上記の実装は微妙な気がするので、少し改良してみます。

確率を直接与える代わりに、アクションの選択肢と報酬を算出する関数を引数として与えるようにしてみました。

import numpy as np

def thompson_sampling(acts, reward_func, n = 1000):
    states = {a: (0, 0) for a in acts}
    
    def action():
        bs = {a: np.random.beta(s[0] + 1, s[1] + 1) for a, s in states.items()}
        return max(bs, key = bs.get)
    
    for _ in range(n):
        a = action()
        r = reward_func(a)
        
        states[a] = (states[a][0] + r, states[a][1] + 1 - r)
    
    return states

def probability_reward_func(probs):
    return lambda a: 1 if np.random.rand() < probs[a] else 0

def summary(states):
    for a, s in states.items():
        print(f'{a}: win={s[0]}, lose={s[1]}, p={s[0] / sum(s)}')
実行例1
probs1 = { 'a': 0.2, 'b': 0.5, 'c': 0.7 }

summary( thompson_sampling(probs1.keys(), probability_reward_func(probs1)) )
a: win=1, lose=7, p=0.125
b: win=17, lose=19, p=0.4722222222222222
c: win=674, lose=282, p=0.7050209205020921
実行例2
probs2 = { 'a': 0.2, 'b': 0.5, 'c': 0.7, 'd': 0.1, 'e': 0.8 }

summary( thompson_sampling(probs2.keys(), probability_reward_func(probs2)) )
a: win=4, lose=7, p=0.36363636363636365
b: win=10, lose=10, p=0.5
c: win=60, lose=25, p=0.7058823529411765
d: win=0, lose=4, p=0.0
e: win=701, lose=179, p=0.7965909090909091

CNN でランドマーク検出

前回の「CNNで輪郭の検出」 で試した手法を工夫し、ランドマーク(特徴点)検出へ適用してみました。

  • Keras + Tensorflow
  • Jupyter Notebook

ソースは http://github.com/fits/try_samples/tree/master/blog/20190217/

輪郭の検出では画像をピクセル単位で二値分類(輪郭以外 = 0, 輪郭 = 1)しましたが、今回はこれを多クラス分類(ランドマーク以外 = 0, ランドマーク1 = 1, ランドマーク2 = 2, ・・・)へ変更します。

ちなみに、Deeplearning でランドマーク検出を行うような場合、ランドマークの座標を直接予測するような手法が考えられますが、今回試してみた限りでは納得のいく結果(座標の精度や汎用性など)を出せなくて、代わりに思いついたのが今回の手法となっています。

はじめに

データセット

今回は、DeepFashion: In-shop Clothes Retrieval のランドマーク用データセットから以下の条件を満たすものに限定して使います。

  • clothes_type の値が 1 (upper-body clothes)
  • variation_type の値が 1 (normal pose)
  • landmark_visibility_1 ~ 6 の値が 0(visible)

ランドマークには 6種類 (landmark_location_x_1 ~ 6、landmark_location_y_1 ~ 6) の座標を使います。

教師データ

入力データには画像を使うため、データ形状は (<バッチサイズ>, 256, 256, 3) ※ となります。

 ※ (<バッチサイズ>, <高さ>, <幅>, <チャンネル数>)

ラベルデータは landmark_location 1 ~ 6 の値を元に動的に生成します。

ピクセル単位でランドマーク以外(= 0)とランドマーク 1 ~ 6 の多クラス分類を行うため、データ形状は (<バッチサイズ>, 256, 256, 7) とします。

ランドマーク毎に 1ピクセルだけランドマークへ分類しても上手く学習できないので ※、一定の大きさ(範囲)をランドマークへ分類する必要があります。

 ※ 全てをランドマーク以外(= 0)とするようになってしまう

そこで、ランドマーク周辺の一定範囲をランドマークへ分類するとともに、以下の図(中心がランドマーク)のようにランドマークから離れると確率値が下がるように工夫します。

f:id:fits:20190217023108p:plain

学習

学習処理は Jupyter Notebook 上で実行しました。

(1) 入力データの準備

まずは、list_landmarks_inshop.txt ファイルを読み込んで必要なデータを抜き出します。

今回は学習時間の短縮のため、先頭から 100件だけを使用しています。

データ読み込みとフィルタリング
import pandas as pd

df = pd.read_table('list_landmarks_inshop.txt', sep = '\s+', skiprows = 1)

s = 100

dfa = df[(df['clothes_type'] == 1) & (df['variation_type'] == 1) &
         (df['landmark_visibility_1'] == 0) & (df['landmark_visibility_2'] == 0) & 
         (df['landmark_visibility_3'] == 0) & (df['landmark_visibility_4'] == 0) &
         (df['landmark_visibility_5'] == 0) & (df['landmark_visibility_6'] == 0)][:s]

次に、入力データとして使う画像を読み込みます。

入力データ(画像)読み込み
import numpy as np
from keras.preprocessing.image import load_img, img_to_array

imgs = np.array([ img_to_array(load_img(f)) for f in dfa['image_name']])

入力データの形状は以下のようになります。

imgs.shape
(100, 256, 256, 3)

(2) ラベルデータの生成

先述したように landmark_location の値から得られたランドマーク座標の周辺に確率値を設定していきます。

ここでは、確率の構成内容や確率値の設定対象とする周辺座標の取得処理を引数で指定できるようにしてみました。

また、他のランドマークの範囲と重なった場合、今回は単純に上書き(後勝ち)するようにしましたが、確率値の大きい方を選択するか確率値を分配するようにした方が望ましいと思われます。

ラベルデータ作成処理
cols = [f'landmark_location_{t}_{i + 1}' for i in range(6) for t in ['x', 'y'] ]
labels_t = dfa[cols].values.astype(int)

def gen_labels(prob, around_func):
    res = np.zeros(imgs.shape[:-1] + (int(len(cols) / 2) + 1,))
    res[:, :, :, 0] = 1.0
    
    for i in range(len(res)):
        r = res[i]
        
        # ランドマーク毎の設定
        for j in range(0, len(labels_t[i]), 2):
            # ランドマークの座標
            x = labels_t[i, j]
            y = labels_t[i, j + 1]
            
            # ランドマークの分類(1 ~ 6)
            c = int(j / 2) + 1
            
            for k in range(len(prob)):
                p = prob[k]
                
                # (相対的な)周辺座標の取得
                for a in around_func(k):
                    ax = x + a[0]
                    ay = y + a[1]
                    
                    if ax >= 0 and ax < imgs.shape[2] and ay >= 0 and ay < imgs.shape[1]:
                        # 他のランドマークと範囲が重なった場合への対応(設定値のクリア)
                        r[ay, ax, :] = 0.0
                        
                        # ランドマーク c へ該当する確率
                        r[ay, ax, c] = p
                        # ランドマーク以外へ該当する確率
                        r[ay, ax, 0] = 1.0 - p

    return res

今回は以下のような内容でラベルデータを作りました。

ラベルデータ作成
def around_square(n):
    return [(x, y) for x in range(-n, n + 1) for y in range(-n, n + 1) if abs(x) == n or abs(y) == n]

labels = gen_labels([1.0, 1.0, 1.0, 0.8, 0.8, 0.7, 0.7, 0.6, 0.6, 0.5], around_square)

ラベルデータの形状は以下の通りです。

labels.shape
(100, 256, 256, 7)

ラベルデータの内容確認

ランドマーク周辺の値を見てみると以下のようになっており、問題無さそうです。

labels[0, 59, 105:126]
array([[1. , 0. , 0. , 0. , 0. , 0. , 0. ],
       [0.5, 0.5, 0. , 0. , 0. , 0. , 0. ],
       [0.4, 0.6, 0. , 0. , 0. , 0. , 0. ],
       [0.4, 0.6, 0. , 0. , 0. , 0. , 0. ],
       [0.3, 0.7, 0. , 0. , 0. , 0. , 0. ],
       [0.3, 0.7, 0. , 0. , 0. , 0. , 0. ],
       [0.2, 0.8, 0. , 0. , 0. , 0. , 0. ],
       [0.2, 0.8, 0. , 0. , 0. , 0. , 0. ],
       [0. , 1. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 1. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 1. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 1. , 0. , 0. , 0. , 0. , 0. ],
       [0. , 1. , 0. , 0. , 0. , 0. , 0. ],
       [0.2, 0.8, 0. , 0. , 0. , 0. , 0. ],
       [0.2, 0.8, 0. , 0. , 0. , 0. , 0. ],
       [0.3, 0.7, 0. , 0. , 0. , 0. , 0. ],
       [0.3, 0.7, 0. , 0. , 0. , 0. , 0. ],
       [0.4, 0.6, 0. , 0. , 0. , 0. , 0. ],
       [0.4, 0.6, 0. , 0. , 0. , 0. , 0. ],
       [0.5, 0.5, 0. , 0. , 0. , 0. , 0. ],
       [1. , 0. , 0. , 0. , 0. , 0. , 0. ]])
labels[0, 50:71, 149]
array([[1. , 0. , 0. , 0. , 0. , 0. , 0. ],
       [0.5, 0. , 0.5, 0. , 0. , 0. , 0. ],
       [0.4, 0. , 0.6, 0. , 0. , 0. , 0. ],
       [0.4, 0. , 0.6, 0. , 0. , 0. , 0. ],
       [0.3, 0. , 0.7, 0. , 0. , 0. , 0. ],
       [0.3, 0. , 0.7, 0. , 0. , 0. , 0. ],
       [0.2, 0. , 0.8, 0. , 0. , 0. , 0. ],
       [0.2, 0. , 0.8, 0. , 0. , 0. , 0. ],
       [0. , 0. , 1. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 1. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 1. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 1. , 0. , 0. , 0. , 0. ],
       [0. , 0. , 1. , 0. , 0. , 0. , 0. ],
       [0.2, 0. , 0.8, 0. , 0. , 0. , 0. ],
       [0.2, 0. , 0.8, 0. , 0. , 0. , 0. ],
       [0.3, 0. , 0.7, 0. , 0. , 0. , 0. ],
       [0.3, 0. , 0.7, 0. , 0. , 0. , 0. ],
       [0.4, 0. , 0.6, 0. , 0. , 0. , 0. ],
       [0.4, 0. , 0.6, 0. , 0. , 0. , 0. ],
       [0.5, 0. , 0.5, 0. , 0. , 0. , 0. ],
       [1. , 0. , 0. , 0. , 0. , 0. , 0. ]])

これだけだと分かり難いので、単純な可視化を行ってみます。(ランドマークの該当確率をピクセル毎に合計しているだけ)

ラベルデータの可視化処理
matplotlib inline

import matplotlib.pyplot as plt

def imshow_label(index):
    plt.imshow(labels[index, :, :, 1:].sum(axis = -1), cmap = 'gray')
imshow_label(0)

f:id:fits:20190217023149p:plain

imshow_label(1)

f:id:fits:20190217023204p:plain

特に問題は無さそうです。

(3) CNN モデル

前回 と同様に Encoder-Decoder の構成を採用し、Encoder・Decoder をそれぞれ 1段階深くしました。(4段階に縮小して拡大)

多クラス分類を行うために、出力層の活性化関数を softmax にして、損失関数を categorical_crossentropy としています。

モデル内容
from keras.models import Model
from keras.layers import Input, Dense, Dropout, UpSampling2D
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPool2D
from keras.layers.normalization import BatchNormalization

input = Input(shape = imgs.shape[1:])

x = input

x = BatchNormalization()(x)

x = Conv2D(16, 3, padding='same', activation = 'relu')(x)
x = Conv2D(16, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(128, 3, padding='same', activation = 'relu')(x)
x = Conv2D(128, 3, padding='same', activation = 'relu')(x)
x = Conv2D(128, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(256, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = UpSampling2D()(x)
x = Conv2DTranspose(128, 3, padding = 'same', activation = 'relu')(x)
x = Conv2DTranspose(128, 3, padding = 'same', activation = 'relu')(x)
x = Conv2DTranspose(128, 3, padding = 'same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = UpSampling2D()(x)
x = Conv2DTranspose(64, 3, padding = 'same', activation = 'relu')(x)
x = Conv2DTranspose(64, 3, padding = 'same', activation = 'relu')(x)
x = Conv2DTranspose(64, 3, padding = 'same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = UpSampling2D()(x)
x = Conv2DTranspose(32, 3, padding = 'same', activation = 'relu')(x)
x = Conv2DTranspose(32, 3, padding = 'same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = UpSampling2D()(x)
x = Conv2DTranspose(16, 3, padding = 'same', activation = 'relu')(x)
x = Conv2DTranspose(16, 3, padding = 'same', activation = 'relu')(x)

x = Dropout(0.3)(x)

output = Dense(labels.shape[-1], activation = 'softmax')(x)

model = Model(inputs = input, outputs = output)

model.compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics = ['acc'])

model.summary()
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_7 (InputLayer)         (None, 256, 256, 3)       0         
_________________________________________________________________
batch_normalization_46 (Batc (None, 256, 256, 3)       12        
_________________________________________________________________
conv2d_56 (Conv2D)           (None, 256, 256, 16)      448       
_________________________________________________________________
conv2d_57 (Conv2D)           (None, 256, 256, 16)      2320      
_________________________________________________________________
max_pooling2d_20 (MaxPooling (None, 128, 128, 16)      0         
_________________________________________________________________
batch_normalization_47 (Batc (None, 128, 128, 16)      64        
_________________________________________________________________
dropout_46 (Dropout)         (None, 128, 128, 16)      0         
_________________________________________________________________
conv2d_58 (Conv2D)           (None, 128, 128, 32)      4640      
_________________________________________________________________
conv2d_59 (Conv2D)           (None, 128, 128, 32)      9248      
_________________________________________________________________
conv2d_60 (Conv2D)           (None, 128, 128, 32)      9248      
_________________________________________________________________
max_pooling2d_21 (MaxPooling (None, 64, 64, 32)        0         
_________________________________________________________________
batch_normalization_48 (Batc (None, 64, 64, 32)        128       
_________________________________________________________________
dropout_47 (Dropout)         (None, 64, 64, 32)        0         
_________________________________________________________________
conv2d_61 (Conv2D)           (None, 64, 64, 64)        18496     
_________________________________________________________________
conv2d_62 (Conv2D)           (None, 64, 64, 64)        36928     
_________________________________________________________________
conv2d_63 (Conv2D)           (None, 64, 64, 64)        36928     
_________________________________________________________________
max_pooling2d_22 (MaxPooling (None, 32, 32, 64)        0         
_________________________________________________________________
batch_normalization_49 (Batc (None, 32, 32, 64)        256       
_________________________________________________________________
dropout_48 (Dropout)         (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_64 (Conv2D)           (None, 32, 32, 128)       73856     
_________________________________________________________________
conv2d_65 (Conv2D)           (None, 32, 32, 128)       147584    
_________________________________________________________________
conv2d_66 (Conv2D)           (None, 32, 32, 128)       147584    
_________________________________________________________________
max_pooling2d_23 (MaxPooling (None, 16, 16, 128)       0         
_________________________________________________________________
batch_normalization_50 (Batc (None, 16, 16, 128)       512       
_________________________________________________________________
dropout_49 (Dropout)         (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_67 (Conv2D)           (None, 16, 16, 256)       295168    
_________________________________________________________________
batch_normalization_51 (Batc (None, 16, 16, 256)       1024      
_________________________________________________________________
dropout_50 (Dropout)         (None, 16, 16, 256)       0         
_________________________________________________________________
up_sampling2d_20 (UpSampling (None, 32, 32, 256)       0         
_________________________________________________________________
conv2d_transpose_44 (Conv2DT (None, 32, 32, 128)       295040    
_________________________________________________________________
conv2d_transpose_45 (Conv2DT (None, 32, 32, 128)       147584    
_________________________________________________________________
conv2d_transpose_46 (Conv2DT (None, 32, 32, 128)       147584    
_________________________________________________________________
batch_normalization_52 (Batc (None, 32, 32, 128)       512       
_________________________________________________________________
dropout_51 (Dropout)         (None, 32, 32, 128)       0         
_________________________________________________________________
up_sampling2d_21 (UpSampling (None, 64, 64, 128)       0         
_________________________________________________________________
conv2d_transpose_47 (Conv2DT (None, 64, 64, 64)        73792     
_________________________________________________________________
conv2d_transpose_48 (Conv2DT (None, 64, 64, 64)        36928     
_________________________________________________________________
conv2d_transpose_49 (Conv2DT (None, 64, 64, 64)        36928     
_________________________________________________________________
batch_normalization_53 (Batc (None, 64, 64, 64)        256       
_________________________________________________________________
dropout_52 (Dropout)         (None, 64, 64, 64)        0         
_________________________________________________________________
up_sampling2d_22 (UpSampling (None, 128, 128, 64)      0         
_________________________________________________________________
conv2d_transpose_50 (Conv2DT (None, 128, 128, 32)      18464     
_________________________________________________________________
conv2d_transpose_51 (Conv2DT (None, 128, 128, 32)      9248      
_________________________________________________________________
batch_normalization_54 (Batc (None, 128, 128, 32)      128       
_________________________________________________________________
dropout_53 (Dropout)         (None, 128, 128, 32)      0         
_________________________________________________________________
up_sampling2d_23 (UpSampling (None, 256, 256, 32)      0         
_________________________________________________________________
conv2d_transpose_52 (Conv2DT (None, 256, 256, 16)      4624      
_________________________________________________________________
conv2d_transpose_53 (Conv2DT (None, 256, 256, 16)      2320      
_________________________________________________________________
dropout_54 (Dropout)         (None, 256, 256, 16)      0         
_________________________________________________________________
dense_13 (Dense)             (None, 256, 256, 7)       119       
=================================================================
Total params: 1,557,971
Trainable params: 1,556,525
Non-trainable params: 1,446
_________________________________________________________________

(4) 学習

教師データ 100件では少なすぎると思いますが、今回はその中の 80件のみ学習に使用して 20件を検証に使ってみます。(validation_split で指定)

ここで、ランドマークとそれ以外でデータ数に大きな偏りがあるため(ランドマーク以外が大多数)、そのままでは上手く学習できない恐れがあります。

以下では class_weight を使ってランドマーク分類の重みを大きくしています。

実行例(351 ~ 400 エポック)
# 分類毎の重みを定義(ランドマークは 256*256 に設定)
wg = np.ones(labels.shape[-1]) * (imgs.shape[1] * imgs.shape[2])
# ランドマーク以外(= 0)の重み設定
wg[0] = 1

hist = model.fit(imgs, labels, initial_epoch = 350, epochs = 400, batch_size = 10, class_weight = wg, validation_split = 0.2)
結果例
Train on 80 samples, validate on 20 samples
Epoch 351/400
80/80 [===・・・ - loss: 0.0261 - acc: 0.9924 - val_loss: 0.1644 - val_acc: 0.9782
Epoch 352/400
80/80 [===・・・ - loss: 0.0263 - acc: 0.9924 - val_loss: 0.1638 - val_acc: 0.9784
・・・
Epoch 399/400
80/80 [===・・・ - loss: 0.0255 - acc: 0.9930 - val_loss: 0.1719 - val_acc: 0.9775
Epoch 400/400
80/80 [===・・・ - loss: 0.0253 - acc: 0.9931 - val_loss: 0.1720 - val_acc: 0.9777

(5) 確認

fit の戻り値から学習・検証の lossacc の値をそれぞれグラフ化してみます。

fit 結果表示
%matplotlib inline

import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = (16, 4)

plt.subplot(1, 4, 1)
plt.plot(hist.history['loss'])

plt.subplot(1, 4, 2)
plt.plot(hist.history['acc'])

plt.subplot(1, 4, 3)
plt.plot(hist.history['val_loss'])

plt.subplot(1, 4, 4)
plt.plot(hist.history['val_acc'])
結果例(351 ~ 400 エポック)

f:id:fits:20190217023309p:plain

val_lossval_acc の値が良くないのは、データ量が少なすぎる点にあると考えています。

(6) ランドマーク検出

下記 4種類の画像を出力して、ランドマーク検出(predict)結果とラベルデータ(正解)を比較してみます。

  • (a) ラベルデータの分類(ピクセル毎に確率値が最大の分類で色分け)
  • (b) 予測結果(predict)の分類(ピクセル毎に確率値が最大の分類で色分け)
  • (c) 元画像と (b) の重ね合わせ
  • (d) ランドマークの描画(各ランドマークの確率値が最大の座標へ円を描画)

今回はランドマークは分類毎に 1点のみなので、確率が最大値の座標がランドマークと判断できます。

ランドマーク検出と結果出力
import cv2

colors = [(255, 255, 255), (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255), (255, 0, 255), (255, 165, 0), (210, 180, 140)]

def predict(index, n = 0, c_size = 5, s = 5.0):
    plt.rcParams['figure.figsize'] = (s * 4, s)
    
    img = imgs[index]

    # 予測結果(ランドマーク分類結果)
    p = model.predict(np.array([img]))[0]

    # (a) ラベルデータの分類(ピクセル毎に確率値が最大の分類で色分け)
    img1 = np.apply_along_axis(lambda x: colors[x.argmax()], -1, labels[index])
    # (b) 予測結果の分類(ピクセル毎に確率値が最大の分類で色分け)
    img2 = np.apply_along_axis(lambda x: colors[x.argmax()], -1, p)
    # (c) 元画像への重ね合わせ
    img3 = cv2.addWeighted(img.astype(int), 0.4, img2, 0.6, 0)
    
    plt.subplot(1, 4, 1)
    plt.imshow(img1)
    
    plt.subplot(1, 4, 2)
    plt.imshow(img2)
    
    plt.subplot(1, 4, 3)
    plt.imshow(img3)

    img4 = img.astype(int)

    pdf = pd.DataFrame(
        [[np.argmax(vx), x, y, np.max(vx)] for y, vy in enumerate(p) for x, vx in enumerate(vy)], 
        columns = ['landmark', 'x', 'y', 'prob']
    )
    
    for c, v in pdf[pdf['landmark'] > 0].sort_values('prob', ascending = False).groupby('landmark'):
        # (d) ランドマークを描画(確率値が最大の座標へ円を描画)
        img4 = cv2.circle(img4, tuple(v[['x', 'y']].values[0]), c_size, colors[c], -1)
        
        if n > 0:
            print(f"landmark {c} : x = {labels_t[index, (c - 1) * 2]}, {labels_t[index, (c - 1) * 2 + 1]}")
            print(v[:n])

    plt.subplot(1, 4, 4)
    plt.imshow(img4)

学習データの結果例

左から (a) ラベルデータの分類、(b) 予測結果の分類、(c) 元画像との重ね合わせ、(d) ランドマーク検出結果となっています。

f:id:fits:20190217035001p:plain f:id:fits:20190217035116p:plain

ラベルデータにかなり近い結果が出ているように見えます。

下記のように、ランドマーク毎の確率値 TOP 3 とラベルデータを数値で比較してみると、かなり近い値になっている事を確認できました。

predict(0, n = 3)
landmark 1 : x = 115, 59
       landmark    x   y      prob
15475         1  115  60  0.893763
15476         1  116  60  0.893605
15220         1  116  59  0.893044

landmark 2 : x = 149, 60
       landmark    x   y      prob
15510         2  150  60  0.878173
15766         2  150  61  0.872413
15509         2  149  60  0.872222

landmark 3 : x = 82, 153
       landmark   x    y      prob
39250         3  82  153  0.882741
39249         3  81  153  0.881362
39248         3  80  153  0.879979

landmark 4 : x = 185, 150
       landmark    x    y      prob
38841         4  185  151  0.836826
38585         4  185  150  0.836212
38840         4  184  151  0.836164

landmark 5 : x = 93, 198
       landmark   x    y      prob
50782         5  94  198  0.829380
50526         5  94  197  0.825815
51038         5  94  199  0.825342

landmark 6 : x = 171, 197
       landmark    x    y      prob
50602         6  170  197  0.881702
50603         6  171  197  0.880731
50858         6  170  198  0.877772
predict(40, n = 3)
landmark 1 : x = 120, 42
      landmark    x   y      prob
8820         1  116  34  0.568582
9075         1  115  35  0.566257
9074         1  114  35  0.561259

landmark 2 : x = 134, 40
       landmark    x   y      prob
10372         2  132  40  0.812515
10371         2  131  40  0.807980
10628         2  132  41  0.807899

landmark 3 : x = 109, 48
       landmark    x   y      prob
12652         3  108  49  0.839624
12653         3  109  49  0.838190
12396         3  108  48  0.837235

landmark 4 : x = 148, 43
       landmark    x   y      prob
11156         4  148  43  0.837879
10900         4  148  42  0.837810
11157         4  149  43  0.836910

landmark 5 : x = 107, 176
       landmark    x    y      prob
45164         5  108  176  0.845494
45420         5  108  177  0.841054
45163         5  107  176  0.839846

landmark 6 : x = 154, 182
       landmark    x    y      prob
46746         6  154  182  0.865920
46747         6  155  182  0.863970
46490         6  154  181  0.862724

なお、predict(40) におけるランドマーク 1(赤色)の結果が振るわないのは、ラベルデータの作り方の問題だと考えられます。(上書きでは無く確率値が大きい方を採用する等で改善するはず)

検証データの結果例

f:id:fits:20190217035136p:plain f:id:fits:20190217035148p:plain

当然ながら、学習に使っていないこちらのデータでは結果が悪化していますが、それなりに正しそうな位置を部分的に検出しているように見えます。

学習に使ったデータ量の少なさを考えると、かなり良好な結果が出ているようにも思います。

そもそも、predict(-3) のようなランドマークの左右が反転している背面からの画像なんてのは無理があるように思いますし、predict(-8) のランドマーク 5(水色)はラベルデータの方が間違っている(検出結果の方が正しい)ような気もします。

predict(-1, n = 3)
landmark 1 : x = 96, 60
       landmark   x   y      prob
15969         1  97  62  0.872259
16225         1  97  63  0.869837
15970         1  98  62  0.869681

landmark 2 : x = 126, 59
       landmark    x   y      prob
16254         2  126  63  0.866628
16255         2  127  63  0.865502
15998         2  126  62  0.864939

landmark 3 : x = 66, 125
       landmark   x    y      prob
30521         3  57  119  0.832024
30520         3  56  119  0.831721
30777         3  57  120  0.829537

landmark 4 : x = 157, 117
       landmark    x    y      prob
29099         4  171  113  0.814012
29098         4  170  113  0.813680
28843         4  171  112  0.812420
predict(-8, n = 3)
landmark 1 : x = 133, 40
       landmark    x   y      prob
10629         1  133  41  0.812287
10628         1  132  41  0.810564
10373         1  133  40  0.808298

landmark 2 : x = 157, 47
       landmark    x   y      prob
12704         2  160  49  0.767413
12448         2  160  48  0.764577
12703         2  159  49  0.762571

landmark 3 : x = 105, 77
       landmark    x   y     prob
19300         3  100  75  0.79014
19301         3  101  75  0.78945
19556         3  100  76  0.78496

landmark 4 : x = 181, 86
       landmark    x    y      prob
56242         4  178  219  0.768471
55986         4  178  218  0.768215
56243         4  179  219  0.766977

landmark 5 : x = 137, 211
       landmark    x    y      prob
54370         5   98  212  0.710897
54626         5   98  213  0.707652
54372         5  100  212  0.707127

CNN で輪郭の検出

画像内の物体の輪郭検出を CNN(畳み込みニューラルネット)で試してみました。

  • Keras + Tensorflow
  • Jupyter Notebook

ソースは http://github.com/fits/try_samples/tree/master/blog/20190114/

概要

今回は、画像をピクセル単位で輪郭か否かに分類する事(輪郭 = 1, 輪郭以外 = 0)で輪郭を検出できないか試しました。

そこで、教師データとして以下のような衣服単体の画像(jpg)と衣服の輪郭部分だけを白く塗りつぶした画像(png)を用意しました。

f:id:fits:20190114225837j:plain

教師データを大量に用意するのは困難だったため、240x288 の画像 160 ファイルで学習を行っています。

学習

学習の処理は Jupyter Notebook 上で実行しました。

(1) 入力データの準備

まずは、入力画像(jpg)を読み込みます。(教師データの画像は img ディレクトリへ配置しています)

import glob
import numpy as np
from keras.preprocessing.image import load_img, img_to_array

files = glob.glob('img/*.jpg')

imgs = np.array([img_to_array(load_img(f)) for f in files])

imgs.shape

入力データの形状は以下の通りです。

imgs.shape 結果
(160, 288, 240, 3)

(2) ラベルデータの準備

輪郭画像(png)を読み込み、128 を境にして二値化(輪郭 = 1、輪郭以外 = 0)します。

import os

th = 128

labels = np.array([img_to_array(load_img(f"{os.path.splitext(f)[0]}.png", color_mode = 'grayscale')) for f in files])

labels[labels < th] = 0
labels[labels >= th] = 1

labels.shape

ラベルデータの形状は以下の通りです。

labels.shape 結果
(160, 288, 240, 1)

(3) CNN モデル

どのようなネットワーク構成が適しているのか分からなかったので、セマンティックセグメンテーション等で用いられている Encoder-Decoder の構成を参考にしてみました。

30x36 まで段階的に縮小して(Encoder)、元の大きさ 240x288 まで段階的に拡大する(Decoder)ようにしています。

最終層の活性化関数に sigmoid を使って 0 ~ 1 の値となるようにしています。

損失関数は binary_crossentropy を使うと進捗が遅そうに見えたので ※、代わりに mean_squared_error を使っています。

 ※ 今回の場合、輪郭(= 1)に該当するピクセルの方が少なくなるため
    binary_crossentropy を使用する場合は fit の class_weight 引数で
    調整する必要があったと思われます

また、参考のため mean_absolute_error(mae) の値も出力するように metrics で指定しています。

from keras.models import Model
from keras.layers import Input, Dropout
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPool2D
from keras.layers.normalization import BatchNormalization

input = Input(shape = imgs.shape[1:])

x = input

x = BatchNormalization()(x)

# Encoder

x = Conv2D(16, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = MaxPool2D()(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2D(128, 3, padding='same', activation = 'relu')(x)
x = Conv2D(128, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

# Decoder

x = Conv2DTranspose(64, 3, strides = 2, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)
x = Conv2D(64, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2DTranspose(32, 3, strides = 2, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)
x = Conv2D(32, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

x = Conv2DTranspose(16, 3, strides = 2, padding='same', activation = 'relu')(x)
x = Conv2D(16, 3, padding='same', activation = 'relu')(x)

x = BatchNormalization()(x)
x = Dropout(0.3)(x)

output = Conv2D(1, 1, activation = 'sigmoid')(x)

model = Model(inputs = input, outputs = output)

model.compile(loss = 'mse', optimizer = 'adam', metrics = ['mae'])

model.summary()
model.summary() 結果
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 288, 240, 3)       0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 288, 240, 3)       12        
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 288, 240, 16)      448       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 144, 120, 16)      0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 144, 120, 16)      64        
_________________________________________________________________
dropout_1 (Dropout)          (None, 144, 120, 16)      0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 144, 120, 32)      4640      
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 144, 120, 32)      9248      
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 144, 120, 32)      9248      
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 72, 60, 32)        0         
_________________________________________________________________
batch_normalization_3 (Batch (None, 72, 60, 32)        128       
_________________________________________________________________
dropout_2 (Dropout)          (None, 72, 60, 32)        0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 72, 60, 64)        18496     
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 72, 60, 64)        36928     
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 72, 60, 64)        36928     
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 36, 30, 64)        0         
_________________________________________________________________
batch_normalization_4 (Batch (None, 36, 30, 64)        256       
_________________________________________________________________
dropout_3 (Dropout)          (None, 36, 30, 64)        0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 36, 30, 128)       73856     
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 36, 30, 128)       147584    
_________________________________________________________________
batch_normalization_5 (Batch (None, 36, 30, 128)       512       
_________________________________________________________________
dropout_4 (Dropout)          (None, 36, 30, 128)       0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 72, 60, 64)        73792     
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 72, 60, 64)        36928     
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 72, 60, 64)        36928     
_________________________________________________________________
conv2d_12 (Conv2D)           (None, 72, 60, 64)        36928     
_________________________________________________________________
batch_normalization_6 (Batch (None, 72, 60, 64)        256       
_________________________________________________________________
dropout_5 (Dropout)          (None, 72, 60, 64)        0         
_________________________________________________________________
conv2d_transpose_2 (Conv2DTr (None, 144, 120, 32)      18464     
_________________________________________________________________
conv2d_13 (Conv2D)           (None, 144, 120, 32)      9248      
_________________________________________________________________
conv2d_14 (Conv2D)           (None, 144, 120, 32)      9248      
_________________________________________________________________
conv2d_15 (Conv2D)           (None, 144, 120, 32)      9248      
_________________________________________________________________
batch_normalization_7 (Batch (None, 144, 120, 32)      128       
_________________________________________________________________
dropout_6 (Dropout)          (None, 144, 120, 32)      0         
_________________________________________________________________
conv2d_transpose_3 (Conv2DTr (None, 288, 240, 16)      4624      
_________________________________________________________________
conv2d_16 (Conv2D)           (None, 288, 240, 16)      2320      
_________________________________________________________________
batch_normalization_8 (Batch (None, 288, 240, 16)      64        
_________________________________________________________________
dropout_7 (Dropout)          (None, 288, 240, 16)      0         
_________________________________________________________________
conv2d_17 (Conv2D)           (None, 288, 240, 1)       17        
=================================================================
Total params: 576,541
Trainable params: 575,831
Non-trainable params: 710
_________________________________________________________________

(4) 学習

教師データが少ないため、全て学習で使う事にします。

実行例(441 ~ 480 エポック)
hist = model.fit(imgs, labels, initial_epoch = 440, epochs = 480, batch_size = 10)

Keras では fit を繰り返し呼び出すと学習を(続きから)再開できるので、40 エポックを何回か繰り返しました。(バッチサイズは 20 で始めて途中で 10 へ変えたりしています)

その場合、正しいエポックを出力するには initial_epochepochs の値を調整する必要があります ※。

 ※ initial_epoch を指定しなくても
    fit を繰り返し実行するだけで学習は継続されますが、
    その場合は出力されるエポックの値がクリアされます(1 からのカウントとなる)
結果例
Epoch 441/480
160/160 [=====・・・ - loss: 0.0048 - mean_absolute_error: 0.0126
Epoch 442/480
160/160 [=====・・・ - loss: 0.0048 - mean_absolute_error: 0.0125
Epoch 443/480
160/160 [=====・・・ - loss: 0.0048 - mean_absolute_error: 0.0126
・・・
Epoch 478/480
160/160 [=====・・・ - loss: 0.0045 - mean_absolute_error: 0.0116
Epoch 479/480
160/160 [=====・・・ - loss: 0.0046 - mean_absolute_error: 0.0117
Epoch 480/480
160/160 [=====・・・ - loss: 0.0044 - mean_absolute_error: 0.0115

(5) 確認

fit の戻り値から mean_squared_error(loss)と mean_absolute_error の値の遷移をグラフ化してみます。

%matplotlib inline

import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = (8, 4)

plt.subplot(1, 2, 1)
plt.plot(hist.history['loss'])

plt.subplot(1, 2, 2)
plt.plot(hist.history['mean_absolute_error'])
結果例(441 ~ 480 エポック)

f:id:fits:20190114231037p:plain

(6) 検証

(a) 教師データ

教師データの入力画像と輪郭画像(ラベルデータ)、model.predict の結果を並べて表示してみます。

def predict(index, s = 6.0):
    plt.rcParams['figure.figsize'] = (s, s)

    sh = imgs.shape[1:-1]

    # 輪郭の検出(予測処理)
    pred = model.predict(np.array([imgs[index]]))[0]
    pred *= 255

    plt.subplot(1, 3, 1)
    # 入力画像の表示
    plt.imshow(imgs[index].astype(int))

    plt.subplot(1, 3, 2)
    # 輪郭画像(ラベルデータ)の表示
    plt.imshow(labels[index].reshape(sh), cmap = 'gray')

    plt.subplot(1, 3, 3)
    # predict の結果表示
    plt.imshow(pred.reshape(sh).astype(int), cmap = 'gray')
結果例(480 エポック): 入力画像, 輪郭画像, model.predict 結果

f:id:fits:20190114231107j:plain

概ね教師データに近い結果が出るようになっています。

(b) 教師データ以外

教師データとして使っていない画像に対して model.predict を実施し、輪郭の検出を行ってみます。

def predict_eval(file, s = 4.0):
    plt.rcParams['figure.figsize'] = (s, s)

    img = img_to_array(load_img(file))

    # 輪郭の検出(予測処理)
    pred = model.predict(np.array([img]))[0]
    pred *= 255

    plt.subplot(1, 2, 1)
    # 入力画像の表示
    plt.imshow(img.astype(int))

    plt.subplot(1, 2, 2)
    # predict の結果表示
    plt.imshow(pred.reshape(pred.shape[:-1]).astype(int), cmap = 'gray')
結果例(480 エポック): 入力画像, model.predict 結果

f:id:fits:20190114231154j:plain

所々で途切れたりしていますが、ある程度の輪郭は検出できているように見えます。

(7) 保存

学習したモデルを保存します。

model.save('model/c1_480.h5')

輪郭検出

学習済みモデルを使って輪郭検出を行う処理をスクリプト化してみました。

predict_contours.py
import sys
import os
import glob
import numpy as np
from keras.preprocessing.image import load_img, img_to_array
from keras.models import load_model
import cv2

model_file = sys.argv[1]
img_files = sys.argv[2]
dest_dir = sys.argv[3]

model = load_model(model_file)

for f in glob.glob(img_files):
    img = img_to_array(load_img(f))

    # 輪郭の検出
    pred = model.predict(np.array([img]))[0]
    pred *= 255

    file, ext = os.path.splitext(os.path.basename(f))

    # 画像の保存
    cv2.imwrite(f"{dest_dir}/{file}_predict.png", pred)

    print(f"done: {f}")
実行例
python predict_contours.py model/c1_480.h5 img_eval2/*.jpg result

480 エポックの学習モデルを使って、教師データに無いタイプの背景を使った画像(影の影響もある)に試してみました。

輪郭検出結果例
入力画像 処理結果
f:id:fits:20181229154654j:plain f:id:fits:20190114231530p:plain
f:id:fits:20181229155049j:plain f:id:fits:20190114231547p:plain

こちらは難しかったようです。

なお、2つ目の画像は 120 エポックの学習モデルの方が良好な結果(輪郭がより多く検出されていた)でした。